@@ -22,22 +22,65 @@ macro_rules! api_dir_list {
22
22
. collect:: <Vec <String >>( )
23
23
. into_iter( )
24
24
} else {
25
- $api. info( )
26
- . map( |repo| {
27
- repo. siblings
28
- . iter( )
29
- . map( |x| x. rfilename. clone( ) )
30
- . collect:: <Vec <String >>( )
31
- } )
32
- . unwrap_or_else( |e| {
33
- if $should_panic {
34
- panic!( "Could not get directory listing from API: {:?}" , e)
35
- } else {
36
- tracing:: warn!( "Could not get directory listing from API: {:?}" , e) ;
37
- Vec :: <String >:: new( )
38
- }
39
- } )
40
- . into_iter( )
25
+ let sanitized_id = std:: path:: Path :: new( $model_id)
26
+ . display( )
27
+ . to_string( )
28
+ . replace( "/" , "-" ) ;
29
+
30
+ let home_folder = if dirs:: home_dir( ) . is_some( ) {
31
+ let mut path = dirs:: home_dir( ) . unwrap( ) ;
32
+ path. push( ".cache/huggingface/hub/" ) ;
33
+ if !path. exists( ) {
34
+ let _ = std:: fs:: create_dir_all( & path) ;
35
+ }
36
+ path
37
+ } else {
38
+ "./" . into( )
39
+ } ;
40
+
41
+ let cache_dir: std:: path:: PathBuf = std:: env:: var( "HF_HUB_CACHE" )
42
+ . map( std:: path:: PathBuf :: from)
43
+ . unwrap_or( home_folder. into( ) ) ;
44
+ let cache_file = cache_dir. join( format!( "{sanitized_id}_repo_list.json" ) ) ;
45
+ if std:: path:: Path :: new( & cache_file) . exists( ) {
46
+ use std:: io:: Read ;
47
+ // Read from cache
48
+ let mut file = std:: fs:: File :: open( & cache_file) . expect( "Could not open cache file" ) ;
49
+ let mut contents = String :: new( ) ;
50
+ file. read_to_string( & mut contents)
51
+ . expect( "Could not read cache file" ) ;
52
+ let cache: $crate:: pipeline:: FileListCache =
53
+ serde_json:: from_str( & contents) . expect( "Could not parse cache JSON" ) ;
54
+ tracing:: info!( "Read from cache file {:?}" , cache_file) ;
55
+ cache. files. into_iter( )
56
+ } else {
57
+ $api. info( )
58
+ . map( |repo| {
59
+ let files: Vec <String > = repo
60
+ . siblings
61
+ . iter( )
62
+ . map( |x| x. rfilename. clone( ) )
63
+ . collect:: <Vec <String >>( ) ;
64
+ // Save to cache
65
+ let cache = $crate:: pipeline:: FileListCache {
66
+ files: files. clone( ) ,
67
+ } ;
68
+ let json = serde_json:: to_string_pretty( & cache)
69
+ . expect( "Could not serialize cache" ) ;
70
+ let ret = std:: fs:: write( & cache_file, json) ;
71
+ tracing:: info!( "Write to cache file {:?}, {:?}" , cache_file, ret) ;
72
+ files
73
+ } )
74
+ . unwrap_or_else( |e| {
75
+ if $should_panic {
76
+ panic!( "Could not get directory listing from API: {:?}" , e)
77
+ } else {
78
+ tracing:: warn!( "Could not get directory listing from API: {:?}" , e) ;
79
+ Vec :: <String >:: new( )
80
+ }
81
+ } )
82
+ . into_iter( )
83
+ }
41
84
}
42
85
} ;
43
86
}
@@ -117,10 +160,9 @@ macro_rules! get_paths {
117
160
revision. clone( ) ,
118
161
$this. xlora_order. as_ref( ) ,
119
162
) ?;
120
- let gen_conf = if $crate:: api_dir_list!( api, model_id, false )
121
- . collect:: <Vec <_>>( )
122
- . contains( & "generation_config.json" . to_string( ) )
123
- {
163
+ let dir_list = $crate:: api_dir_list!( api, model_id, false ) . collect:: <Vec <_>>( ) ;
164
+
165
+ let gen_conf = if dir_list. contains( & "generation_config.json" . to_string( ) ) {
124
166
info!( "Loading `generation_config.json` at `{}`" , $this. model_id) ;
125
167
Some ( $crate:: api_get_file!(
126
168
api,
@@ -130,10 +172,7 @@ macro_rules! get_paths {
130
172
} else {
131
173
None
132
174
} ;
133
- let preprocessor_config = if $crate:: api_dir_list!( api, model_id, false )
134
- . collect:: <Vec <_>>( )
135
- . contains( & "preprocessor_config.json" . to_string( ) )
136
- {
175
+ let preprocessor_config = if dir_list. contains( & "preprocessor_config.json" . to_string( ) ) {
137
176
info!( "Loading `preprocessor_config.json` at `{}`" , $this. model_id) ;
138
177
Some ( $crate:: api_get_file!(
139
178
api,
@@ -143,10 +182,7 @@ macro_rules! get_paths {
143
182
} else {
144
183
None
145
184
} ;
146
- let processor_config = if $crate:: api_dir_list!( api, model_id, false )
147
- . collect:: <Vec <_>>( )
148
- . contains( & "processor_config.json" . to_string( ) )
149
- {
185
+ let processor_config = if dir_list. contains( & "processor_config.json" . to_string( ) ) {
150
186
info!( "Loading `processor_config.json` at `{}`" , $this. model_id) ;
151
187
Some ( $crate:: api_get_file!(
152
188
api,
@@ -167,10 +203,7 @@ macro_rules! get_paths {
167
203
model_id
168
204
) )
169
205
} ;
170
- let chat_template_json_filename = if $crate:: api_dir_list!( api, model_id, false )
171
- . collect:: <Vec <_>>( )
172
- . contains( & "chat_template.json" . to_string( ) )
173
- {
206
+ let chat_template_json_filename = if dir_list. contains( & "chat_template.json" . to_string( ) ) {
174
207
info!( "Loading `chat_template.json` at `{}`" , $this. model_id) ;
175
208
Some ( $crate:: api_get_file!( api, "chat_template.json" , model_id) )
176
209
} else {
0 commit comments