Skip to content

Commit 8a0e9ad

Browse files
committed
Slightly improved download script
1 parent bd13207 commit 8a0e9ad

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

scripts/download_all_models.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,32 @@
22

33
from __future__ import annotations
44

5+
import argparse
56
import logging
67
import sys
8+
from pathlib import Path
79

810
from tabpfn.model.loading import _user_cache_dir, download_all_models
911

1012

1113
def main() -> None:
1214
"""Download all TabPFN models and save to cache directory."""
15+
# Parse command-line arguments
16+
parser = argparse.ArgumentParser(description="Download all TabPFN models for offline use.")
17+
parser.add_argument(
18+
"--cache-dir",
19+
type=Path,
20+
default=None,
21+
help="Optional path to override the default cache directory.",
22+
)
23+
args = parser.parse_args()
24+
1325
# Configure logging
1426
logging.basicConfig(level=logging.INFO, format="%(message)s")
1527
logger = logging.getLogger(__name__)
1628

17-
# Get default cache directory using TabPFN's internal function
18-
cache_dir = _user_cache_dir(platform=sys.platform, appname="tabpfn")
29+
# Determine cache directory
30+
cache_dir = args.cache_dir or _user_cache_dir(platform=sys.platform, appname="tabpfn")
1931
cache_dir.mkdir(parents=True, exist_ok=True)
2032

2133
logger.info(f"Downloading all models to {cache_dir}")

0 commit comments

Comments
 (0)