Skip to content

Commit 37a55f9

Browse files
authored
Remove duplicate calls for api_dir_list (#1474)
* Remove duplicate calls for api_dir_list
* Support local cache for api_dir_list
* Fix home folder for metal
* Capitalized
1 parent f3b1afa commit 37a55f9

File tree

3 files changed

+79
-36
lines changed

3 files changed

+79
-36
lines changed

mistralrs-core/src/pipeline/macros.rs

Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,65 @@ macro_rules! api_dir_list {
2222
.collect::<Vec<String>>()
2323
.into_iter()
2424
} else {
25-
$api.info()
26-
.map(|repo| {
27-
repo.siblings
28-
.iter()
29-
.map(|x| x.rfilename.clone())
30-
.collect::<Vec<String>>()
31-
})
32-
.unwrap_or_else(|e| {
33-
if $should_panic {
34-
panic!("Could not get directory listing from API: {:?}", e)
35-
} else {
36-
tracing::warn!("Could not get directory listing from API: {:?}", e);
37-
Vec::<String>::new()
38-
}
39-
})
40-
.into_iter()
25+
let sanitized_id = std::path::Path::new($model_id)
26+
.display()
27+
.to_string()
28+
.replace("/", "-");
29+
30+
let home_folder = if dirs::home_dir().is_some() {
31+
let mut path = dirs::home_dir().unwrap();
32+
path.push(".cache/huggingface/hub/");
33+
if !path.exists() {
34+
let _ = std::fs::create_dir_all(&path);
35+
}
36+
path
37+
} else {
38+
"./".into()
39+
};
40+
41+
let cache_dir: std::path::PathBuf = std::env::var("HF_HUB_CACHE")
42+
.map(std::path::PathBuf::from)
43+
.unwrap_or(home_folder.into());
44+
let cache_file = cache_dir.join(format!("{sanitized_id}_repo_list.json"));
45+
if std::path::Path::new(&cache_file).exists() {
46+
use std::io::Read;
47+
// Read from cache
48+
let mut file = std::fs::File::open(&cache_file).expect("Could not open cache file");
49+
let mut contents = String::new();
50+
file.read_to_string(&mut contents)
51+
.expect("Could not read cache file");
52+
let cache: $crate::pipeline::FileListCache =
53+
serde_json::from_str(&contents).expect("Could not parse cache JSON");
54+
tracing::info!("Read from cache file {:?}", cache_file);
55+
cache.files.into_iter()
56+
} else {
57+
$api.info()
58+
.map(|repo| {
59+
let files: Vec<String> = repo
60+
.siblings
61+
.iter()
62+
.map(|x| x.rfilename.clone())
63+
.collect::<Vec<String>>();
64+
// Save to cache
65+
let cache = $crate::pipeline::FileListCache {
66+
files: files.clone(),
67+
};
68+
let json = serde_json::to_string_pretty(&cache)
69+
.expect("Could not serialize cache");
70+
let ret = std::fs::write(&cache_file, json);
71+
tracing::info!("Write to cache file {:?}, {:?}", cache_file, ret);
72+
files
73+
})
74+
.unwrap_or_else(|e| {
75+
if $should_panic {
76+
panic!("Could not get directory listing from API: {:?}", e)
77+
} else {
78+
tracing::warn!("Could not get directory listing from API: {:?}", e);
79+
Vec::<String>::new()
80+
}
81+
})
82+
.into_iter()
83+
}
4184
}
4285
};
4386
}
@@ -117,10 +160,9 @@ macro_rules! get_paths {
117160
revision.clone(),
118161
$this.xlora_order.as_ref(),
119162
)?;
120-
let gen_conf = if $crate::api_dir_list!(api, model_id, false)
121-
.collect::<Vec<_>>()
122-
.contains(&"generation_config.json".to_string())
123-
{
163+
let dir_list = $crate::api_dir_list!(api, model_id, false).collect::<Vec<_>>();
164+
165+
let gen_conf = if dir_list.contains(&"generation_config.json".to_string()) {
124166
info!("Loading `generation_config.json` at `{}`", $this.model_id);
125167
Some($crate::api_get_file!(
126168
api,
@@ -130,10 +172,7 @@ macro_rules! get_paths {
130172
} else {
131173
None
132174
};
133-
let preprocessor_config = if $crate::api_dir_list!(api, model_id, false)
134-
.collect::<Vec<_>>()
135-
.contains(&"preprocessor_config.json".to_string())
136-
{
175+
let preprocessor_config = if dir_list.contains(&"preprocessor_config.json".to_string()) {
137176
info!("Loading `preprocessor_config.json` at `{}`", $this.model_id);
138177
Some($crate::api_get_file!(
139178
api,
@@ -143,10 +182,7 @@ macro_rules! get_paths {
143182
} else {
144183
None
145184
};
146-
let processor_config = if $crate::api_dir_list!(api, model_id, false)
147-
.collect::<Vec<_>>()
148-
.contains(&"processor_config.json".to_string())
149-
{
185+
let processor_config = if dir_list.contains(&"processor_config.json".to_string()) {
150186
info!("Loading `processor_config.json` at `{}`", $this.model_id);
151187
Some($crate::api_get_file!(
152188
api,
@@ -167,10 +203,7 @@ macro_rules! get_paths {
167203
model_id
168204
))
169205
};
170-
let chat_template_json_filename = if $crate::api_dir_list!(api, model_id, false)
171-
.collect::<Vec<_>>()
172-
.contains(&"chat_template.json".to_string())
173-
{
206+
let chat_template_json_filename = if dir_list.contains(&"chat_template.json".to_string()) {
174207
info!("Loading `chat_template.json` at `{}`", $this.model_id);
175208
Some($crate::api_get_file!(api, "chat_template.json", model_id))
176209
} else {

mistralrs-core/src/pipeline/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,11 @@ impl ForwardInputsResult {
360360
}
361361
}
362362

363+
#[derive(serde::Serialize, serde::Deserialize)]
364+
pub(crate) struct FileListCache {
365+
files: Vec<String>,
366+
}
367+
363368
#[async_trait::async_trait]
364369
pub trait Pipeline:
365370
Send

mistralrs-core/src/pipeline/paths.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,11 @@ pub fn get_xlora_paths(
7878
revision,
7979
));
8080
let model_id = Path::new(&xlora_id);
81-
81+
let dir_list = api_dir_list!(api, model_id, true).collect::<Vec<_>>();
8282
// Get the path for the xlora classifier
83-
let xlora_classifier = &api_dir_list!(api, model_id, true)
83+
let xlora_classifier = &dir_list
84+
.clone()
85+
.into_iter()
8486
.filter(|x| x.contains("xlora_classifier.safetensors"))
8587
.collect::<Vec<_>>();
8688
if xlora_classifier.len() > 1 {
@@ -94,7 +96,9 @@ pub fn get_xlora_paths(
9496

9597
// Get the path for the xlora config by checking all for valid versions.
9698
// NOTE(EricLBuehler): Remove this functionality because all configs should be deserializable
97-
let xlora_configs = &api_dir_list!(api, model_id, true)
99+
let xlora_configs = &dir_list
100+
.clone()
101+
.into_iter()
98102
.filter(|x| x.contains("xlora_config.json"))
99103
.collect::<Vec<_>>();
100104
if xlora_configs.len() > 1 {
@@ -135,7 +139,8 @@ pub fn get_xlora_paths(
135139
});
136140

137141
// If there are adapters in the ordering file, get their names and remote paths
138-
let adapter_files = api_dir_list!(api, model_id, true)
142+
let adapter_files = dir_list
143+
.into_iter()
139144
.filter_map(|name| {
140145
if let Some(ref adapters) = xlora_order.adapters {
141146
for adapter_name in adapters {

0 commit comments

Comments
 (0)