
Commit bd13207

Improve code quality and suppress HF token warnings (#252)
1 parent: d2b79cc

13 files changed: +342 −231 lines

README.md

Lines changed: 7 additions & 23 deletions
@@ -208,33 +208,17 @@ TabPFN automatically downloads model weights when first used. For offline usage:
 - macOS: `~/Library/Caches/tabpfn/`
 - Linux: `~/.cache/tabpfn/`
 
-**Quick Download Script**
+**Using the Provided Download Script**
 
-```python
-import requests
-from tabpfn.utils import _user_cache_dir
-import sys
-
-# Get default cache directory using TabPFN's internal function
-cache_dir = _user_cache_dir(platform=sys.platform)
-cache_dir.mkdir(parents=True, exist_ok=True)
-
-# Define models to download
-models = {
-    "tabpfn-v2-classifier.ckpt": "https://huggingface.co/Prior-Labs/TabPFN-v2-clf/resolve/main/tabpfn-v2-classifier.ckpt",
-    "tabpfn-v2-regressor.ckpt": "https://huggingface.co/Prior-Labs/TabPFN-v2-reg/resolve/main/tabpfn-v2-regressor.ckpt",
-}
-
-# Download each model
-for name, url in models.items():
-    path = cache_dir / name
-    print(f"Downloading {name} to {path}")
-    with open(path, "wb") as f:
-        f.write(requests.get(url).content)
+If you have the TabPFN repository, you can use the included script to download all models (including ensemble variants):
 
-print(f"Models downloaded to {cache_dir}")
+```bash
+# After installing TabPFN
+python scripts/download_all_models.py
 ```
 
+This script will download the main classifier and regressor models, as well as all ensemble variant models to your system's default cache directory.
+
 **Q: I'm getting a `pickle` error when loading the model. What should I do?**
 A: Try the following:
 - Download the newest version of tabpfn `pip install tabpfn --upgrade`
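
A minimal sketch for checking which checkpoints ended up in the default cache directory after running the script. It reuses the internal `_user_cache_dir` helper that `scripts/download_all_models.py` below imports, so treat its location and signature as internal details that may change between releases:

```python
import sys

from tabpfn.model.loading import _user_cache_dir

# Resolve the same default cache directory the download script writes to,
# then list the checkpoint files found there.
cache_dir = _user_cache_dir(platform=sys.platform, appname="tabpfn")
for ckpt in sorted(cache_dir.glob("*.ckpt")):
    print(ckpt.name)
```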

examples/tabpfn_for_binary_classification.py

Lines changed: 5 additions & 0 deletions
@@ -1,4 +1,9 @@
 # Copyright (c) Prior Labs GmbH 2025.
+"""Example of using TabPFN for binary classification.
+
+This example demonstrates how to use TabPFNClassifier on a binary classification task
+using the breast cancer dataset from scikit-learn.
+"""
 
 from sklearn.datasets import load_breast_cancer
 from sklearn.metrics import accuracy_score, roc_auc_score
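
The new docstring summarizes the example; below is a minimal sketch of the workflow it describes, assuming TabPFN's scikit-learn-style `fit`/`predict_proba` API and using the metrics imported above (the train/test split settings are illustrative, not taken from the example file):

```python
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split

from tabpfn import TabPFNClassifier

# Binary-label breast cancer dataset with a held-out test split.
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

clf = TabPFNClassifier()
clf.fit(X_train, y_train)

# Score with the metrics the example imports.
print("ROC AUC:", roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1]))
print("Accuracy:", accuracy_score(y_test, clf.predict(X_test)))
```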

examples/tabpfn_for_multiclass_classification.py

Lines changed: 5 additions & 0 deletions
@@ -1,4 +1,9 @@
 # Copyright (c) Prior Labs GmbH 2025.
+"""Example of using TabPFN for multiclass classification.
+
+This example demonstrates how to use TabPFNClassifier on a multiclass classification task
+using the Iris dataset from scikit-learn.
+"""
 
 from sklearn.datasets import load_iris
 from sklearn.metrics import accuracy_score, roc_auc_score
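
As with the binary example, a minimal sketch of the multiclass workflow the docstring describes (same assumptions about TabPFN's scikit-learn-style API; the split settings are illustrative):

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split

from tabpfn import TabPFNClassifier

# Three-class Iris dataset with a held-out test split.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

clf = TabPFNClassifier()
clf.fit(X_train, y_train)

# Multiclass ROC AUC needs the full probability matrix and an explicit strategy.
proba = clf.predict_proba(X_test)
print("ROC AUC (OvR):", roc_auc_score(y_test, proba, multi_class="ovr"))
print("Accuracy:", accuracy_score(y_test, clf.predict(X_test)))
```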

examples/tabpfn_for_regression.py

Lines changed: 5 additions & 0 deletions
@@ -1,4 +1,9 @@
 # Copyright (c) Prior Labs GmbH 2025.
+"""Example of using TabPFN for regression.
+
+This example demonstrates how to use TabPFNRegressor on a regression task
+using the diabetes dataset from scikit-learn.
+"""
 
 from sklearn.datasets import load_diabetes
 from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
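
And the regression counterpart: a minimal sketch of what the docstring describes, assuming TabPFNRegressor follows the same scikit-learn-style `fit`/`predict` API (split settings again illustrative):

```python
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

from tabpfn import TabPFNRegressor

# Diabetes regression dataset with a held-out test split.
X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

reg = TabPFNRegressor()
reg.fit(X_train, y_train)

# Evaluate with the metrics the example imports.
pred = reg.predict(X_test)
print("MAE:", mean_absolute_error(y_test, pred))
print("MSE:", mean_squared_error(y_test, pred))
print("R^2:", r2_score(y_test, pred))
```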

scripts/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+"""Scripts package for TabPFN."""
+
+from __future__ import annotations

scripts/download_all_models.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+"""Download all TabPFN model files for offline use."""
+
+from __future__ import annotations
+
+import logging
+import sys
+
+from tabpfn.model.loading import _user_cache_dir, download_all_models
+
+
+def main() -> None:
+    """Download all TabPFN models and save to cache directory."""
+    # Configure logging
+    logging.basicConfig(level=logging.INFO, format="%(message)s")
+    logger = logging.getLogger(__name__)
+
+    # Get default cache directory using TabPFN's internal function
+    cache_dir = _user_cache_dir(platform=sys.platform, appname="tabpfn")
+    cache_dir.mkdir(parents=True, exist_ok=True)
+
+    logger.info(f"Downloading all models to {cache_dir}")
+    download_all_models(cache_dir)
+    logger.info(f"All models downloaded to {cache_dir}")
+
+
+if __name__ == "__main__":
+    main()

scripts/get_max_dependencies.py

Lines changed: 19 additions & 7 deletions
@@ -1,13 +1,24 @@
+"""Generate requirements.txt with maximum allowed dependency versions."""
+
+from __future__ import annotations
+
 import re
+from pathlib import Path
+
 
 def main() -> None:
-    with open('pyproject.toml', 'r') as f:
+    """Extract maximum dependency versions and write to requirements.txt."""
+    with Path("pyproject.toml").open() as f:
         content = f.read()
 
     # Find dependencies section using regex
-    deps_match = re.search(r'dependencies\s*=\s*\[(.*?)\]', content, re.DOTALL)
+    deps_match = re.search(r"dependencies\s*=\s*\[(.*?)\]", content, re.DOTALL)
     if deps_match:
-        deps = [d.strip(' "\'') for d in deps_match.group(1).strip().split('\n') if d.strip()]
+        deps = [
+            d.strip(' "\'')
+            for d in deps_match.group(1).strip().split("\n")
+            if d.strip()
+        ]
         max_reqs = []
         for dep in deps:
             # Check for maximum version constraint
@@ -18,11 +29,12 @@ def main() -> None:
                 max_reqs.append(f"{package}<{max_ver}")
             else:
                 # If no max version, just use the package name
-                package = re.match(r'([^>=<\s]+)', dep).group(1)
+                package = re.match(r"([^>=<\s]+)", dep).group(1)
                 max_reqs.append(package)
 
-    with open('requirements.txt', 'w') as f:
-        f.write('\n'.join(max_reqs))
+    with Path("requirements.txt").open("w") as f:
+        f.write("\n".join(max_reqs))
+
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
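
A minimal illustration of the fallback branch in the second hunk (the dependency string is made up, not taken from TabPFN's `pyproject.toml`): when no upper bound is present, only the bare package name is kept.

```python
import re

dep = "einops>=0.2"  # hypothetical dependency entry
# Same fallback pattern as above: capture everything before the first comparison operator.
package = re.match(r"([^>=<\s]+)", dep).group(1)
print(package)  # -> einops
```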

scripts/get_min_dependencies.py

Lines changed: 18 additions & 6 deletions
@@ -1,22 +1,34 @@
+"""Generate requirements.txt with minimum dependency versions."""
+
+from __future__ import annotations
+
 import re
+from pathlib import Path
+
 
 def main() -> None:
-    with open('pyproject.toml', 'r') as f:
+    """Extract minimum dependency versions and write to requirements.txt."""
+    with Path("pyproject.toml").open() as f:
         content = f.read()
 
     # Find dependencies section using regex
-    deps_match = re.search(r'dependencies\s*=\s*\[(.*?)\]', content, re.DOTALL)
+    deps_match = re.search(r"dependencies\s*=\s*\[(.*?)\]", content, re.DOTALL)
     if deps_match:
-        deps = [d.strip(' "\'') for d in deps_match.group(1).strip().split('\n') if d.strip()]
+        deps = [
+            d.strip(' "\'')
+            for d in deps_match.group(1).strip().split("\n")
+            if d.strip()
+        ]
         min_reqs = []
         for dep in deps:
             match = re.match(r'([^>=<\s]+)\s*>=\s*([^,\s"\']+)', dep)
             if match:
                 package, min_ver = match.groups()
                 min_reqs.append(f"{package}=={min_ver}")
 
-    with open('requirements.txt', 'w') as f:
-        f.write('\n'.join(min_reqs))
+    with Path("requirements.txt").open("w") as f:
+        f.write("\n".join(min_reqs))
+
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
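
A minimal illustration of the `>=` extraction kept as context above (again with a made-up dependency string): the declared lower bound becomes an exact pin for a minimum-versions test environment.

```python
import re

dep = "scikit-learn>=1.2.0,<1.7"  # hypothetical dependency entry
# Same pattern as in the script: package name, then the version after ">=".
match = re.match(r'([^>=<\s]+)\s*>=\s*([^,\s"\']+)', dep)
if match:
    package, min_ver = match.groups()
    print(f"{package}=={min_ver}")  # -> scikit-learn==1.2.0
```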

src/tabpfn/base.py

Lines changed: 2 additions & 4 deletions
@@ -25,10 +25,8 @@
     InferenceEngineCachePreprocessing,
     InferenceEngineOnDemand,
 )
-from tabpfn.utils import (
-    infer_fp16_inference_mode,
-    load_model_criterion_config,
-)
+from tabpfn.model.loading import load_model_criterion_config
+from tabpfn.utils import infer_fp16_inference_mode
 
 if TYPE_CHECKING:
     import numpy as np
