File tree Expand file tree Collapse file tree 1 file changed +14
-2
lines changed Expand file tree Collapse file tree 1 file changed +14
-2
lines changed Original file line number Diff line number Diff line change 2
2
3
3
from __future__ import annotations
4
4
5
+ import argparse
5
6
import logging
6
7
import sys
8
+ from pathlib import Path
7
9
8
10
from tabpfn .model .loading import _user_cache_dir , download_all_models
9
11
10
12
11
13
def main () -> None :
12
14
"""Download all TabPFN models and save to cache directory."""
15
+ # Parse command-line arguments
16
+ parser = argparse .ArgumentParser (description = "Download all TabPFN models for offline use." )
17
+ parser .add_argument (
18
+ "--cache-dir" ,
19
+ type = Path ,
20
+ default = None ,
21
+ help = "Optional path to override the default cache directory." ,
22
+ )
23
+ args = parser .parse_args ()
24
+
13
25
# Configure logging
14
26
logging .basicConfig (level = logging .INFO , format = "%(message)s" )
15
27
logger = logging .getLogger (__name__ )
16
28
17
- # Get default cache directory using TabPFN's internal function
18
- cache_dir = _user_cache_dir (platform = sys .platform , appname = "tabpfn" )
29
+ # Determine cache directory
30
+ cache_dir = args . cache_dir or _user_cache_dir (platform = sys .platform , appname = "tabpfn" )
19
31
cache_dir .mkdir (parents = True , exist_ok = True )
20
32
21
33
logger .info (f"Downloading all models to { cache_dir } " )
You can’t perform that action at this time.
0 commit comments