Skip to content

Commit afc2b90

Browse files
committed
arg : allow using -hf offline
1 parent e2e1ddb commit afc2b90

File tree

1 file changed

+80
-47
lines changed

1 file changed

+80
-47
lines changed

common/arg.cpp

Lines changed: 80 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,25 @@ std::initializer_list<enum llama_example> mmproj_examples = {
4343
// TODO: add LLAMA_EXAMPLE_SERVER when it's ready
4444
};
4545

46+
static std::string read_file(const std::string & fname) {
47+
std::ifstream file(fname);
48+
if (!file) {
49+
throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
50+
}
51+
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
52+
file.close();
53+
return content;
54+
}
55+
56+
static void write_file(const std::string & fname, const std::string & content) {
57+
std::ofstream file(fname);
58+
if (!file) {
59+
throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
60+
}
61+
file << content;
62+
file.close();
63+
}
64+
4665
common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> examples) {
4766
this->examples = std::move(examples);
4867
return *this;
@@ -200,9 +219,11 @@ struct curl_slist_ptr {
200219

201220
static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
202221
int remaining_attempts = max_attempts;
222+
char * method = nullptr;
223+
curl_easy_getinfo(curl, CURLINFO_EFFECTIVE_METHOD, &method);
203224

204225
while (remaining_attempts > 0) {
205-
LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
226+
LOG_INF("%s: %s %s (attempt %d of %d)...\n", __func__ , method, url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
206227

207228
CURLcode res = curl_easy_perform(curl);
208229
if (res == CURLE_OK) {
@@ -213,6 +234,7 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma
213234
LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
214235

215236
remaining_attempts--;
237+
if (remaining_attempts == 0) break;
216238
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
217239
}
218240

@@ -231,8 +253,6 @@ static bool common_download_file_single(const std::string & url, const std::stri
231253
return false;
232254
}
233255

234-
bool force_download = false;
235-
236256
// Set the URL, allow to follow http redirection
237257
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
238258
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
@@ -256,7 +276,7 @@ static bool common_download_file_single(const std::string & url, const std::stri
256276

257277
// If the file exists, check its JSON metadata companion file.
258278
std::string metadata_path = path + ".json";
259-
nlohmann::json metadata;
279+
nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead
260280
std::string etag;
261281
std::string last_modified;
262282

@@ -266,7 +286,7 @@ static bool common_download_file_single(const std::string & url, const std::stri
266286
if (metadata_in.good()) {
267287
try {
268288
metadata_in >> metadata;
269-
LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
289+
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
270290
if (metadata.contains("url") && metadata.at("url").is_string()) {
271291
auto previous_url = metadata.at("url").get<std::string>();
272292
if (previous_url != url) {
@@ -296,7 +316,10 @@ static bool common_download_file_single(const std::string & url, const std::stri
296316
};
297317

298318
common_load_model_from_url_headers headers;
319+
bool head_request_ok = false;
320+
bool should_download = !file_exists; // by default, we should download if the file does not exist
299321

322+
// get ETag to see if the remote file has changed
300323
{
301324
typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
302325
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
@@ -325,23 +348,25 @@ static bool common_download_file_single(const std::string & url, const std::stri
325348
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
326349
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
327350

328-
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
351+
// we only allow retrying once for HEAD requests
352+
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0);
329353
if (!was_perform_successful) {
330-
return false;
354+
head_request_ok = false;
331355
}
332356

333357
long http_code = 0;
334358
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
335-
if (http_code != 200) {
336-
// HEAD not supported, we don't know if the file has changed
337-
// force trigger downloading
338-
force_download = true;
339-
LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
359+
if (http_code == 200) {
360+
head_request_ok = true;
361+
} else {
362+
LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
363+
head_request_ok = false;
340364
}
341365
}
342366

343-
bool should_download = !file_exists || force_download;
344-
if (!should_download) {
367+
if (head_request_ok) {
368+
// check if ETag or Last-Modified headers are different
369+
// if it is, we need to download the file again
345370
if (!etag.empty() && etag != headers.etag) {
346371
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
347372
should_download = true;
@@ -350,6 +375,7 @@ static bool common_download_file_single(const std::string & url, const std::stri
350375
should_download = true;
351376
}
352377
}
378+
353379
if (should_download) {
354380
std::string path_temporary = path + ".downloadInProgress";
355381
if (file_exists) {
@@ -424,13 +450,15 @@ static bool common_download_file_single(const std::string & url, const std::stri
424450
{"etag", headers.etag},
425451
{"lastModified", headers.last_modified}
426452
});
427-
std::ofstream(metadata_path) << metadata.dump(4);
428-
LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
453+
write_file(metadata_path, metadata.dump(4));
454+
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
429455

430456
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
431457
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
432458
return false;
433459
}
460+
} else {
461+
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
434462
}
435463

436464
return true;
@@ -605,16 +633,37 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_
605633
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
606634
// User-Agent header is already set in common_remote_get_content, no need to set it here
607635

636+
// we use "=" to avoid clashing with other component, while still being allowed on windows
637+
std::string cached_response_fname = "manifest=" + hf_repo + "=" + tag + ".json";
638+
string_replace_all(cached_response_fname, "/", "_");
639+
std::string cached_response_path = fs_get_cache_file(cached_response_fname);
640+
608641
// make the request
609642
common_remote_params params;
610643
params.headers = headers;
611-
auto res = common_remote_get_content(url, params);
612-
long res_code = res.first;
613-
std::string res_str(res.second.data(), res.second.size());
644+
long res_code = 0;
645+
std::string res_str;
646+
bool use_cache = false;
647+
try {
648+
auto res = common_remote_get_content(url, params);
649+
res_code = res.first;
650+
res_str = std::string(res.second.data(), res.second.size());
651+
} catch (const std::exception & e) {
652+
LOG_WRN("error: failed to get manifest: %s\n", e.what());
653+
LOG_WRN("try reading from cache\n");
654+
// try to read from cache
655+
try {
656+
res_str = read_file(cached_response_path);
657+
res_code = 200;
658+
use_cache = true;
659+
} catch (const std::exception & e) {
660+
throw std::runtime_error("error: failed to get manifest (check your internet connection)");
661+
}
662+
}
614663
std::string ggufFile;
615664
std::string mmprojFile;
616665

617-
if (res_code == 200) {
666+
if (res_code == 200 || res_code == 304) {
618667
// extract ggufFile.rfilename in json, using regex
619668
{
620669
std::regex pattern("\"ggufFile\"[\\s\\S]*?\"rfilename\"\\s*:\\s*\"([^\"]+)\"");
@@ -631,6 +680,10 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_
631680
mmprojFile = match[1].str();
632681
}
633682
}
683+
if (!use_cache) {
684+
// if not using cached response, update the cache file
685+
write_file(cached_response_path, res_str);
686+
}
634687
} else if (res_code == 401) {
635688
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
636689
} else {
@@ -1142,6 +1195,9 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
11421195
fprintf(stderr, "%s\n", ex.what());
11431196
ctx_arg.params = params_org;
11441197
return false;
1198+
} catch (std::exception & ex) {
1199+
fprintf(stderr, "%s\n", ex.what());
1200+
exit(1); // for other exceptions, we exit with status code 1
11451201
}
11461202

11471203
return true;
@@ -1442,13 +1498,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
14421498
{"-f", "--file"}, "FNAME",
14431499
"a file containing the prompt (default: none)",
14441500
[](common_params & params, const std::string & value) {
1445-
std::ifstream file(value);
1446-
if (!file) {
1447-
throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
1448-
}
1501+
params.prompt = read_file(value);
14491502
// store the external file name in params
14501503
params.prompt_file = value;
1451-
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
14521504
if (!params.prompt.empty() && params.prompt.back() == '\n') {
14531505
params.prompt.pop_back();
14541506
}
@@ -1458,11 +1510,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
14581510
{"-sysf", "--system-prompt-file"}, "FNAME",
14591511
"a file containing the system prompt (default: none)",
14601512
[](common_params & params, const std::string & value) {
1461-
std::ifstream file(value);
1462-
if (!file) {
1463-
throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
1464-
}
1465-
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.system_prompt));
1513+
params.system_prompt = read_file(value);
14661514
if (!params.system_prompt.empty() && params.system_prompt.back() == '\n') {
14671515
params.system_prompt.pop_back();
14681516
}
@@ -1887,15 +1935,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
18871935
{"--grammar-file"}, "FNAME",
18881936
"file to read grammar from",
18891937
[](common_params & params, const std::string & value) {
1890-
std::ifstream file(value);
1891-
if (!file) {
1892-
throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
1893-
}
1894-
std::copy(
1895-
std::istreambuf_iterator<char>(file),
1896-
std::istreambuf_iterator<char>(),
1897-
std::back_inserter(params.sampling.grammar)
1898-
);
1938+
params.sampling.grammar = read_file(value);
18991939
}
19001940
).set_sparam());
19011941
add_opt(common_arg(
@@ -2815,14 +2855,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
28152855
"list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
28162856
),
28172857
[](common_params & params, const std::string & value) {
2818-
std::ifstream file(value);
2819-
if (!file) {
2820-
throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
2821-
}
2822-
std::copy(
2823-
std::istreambuf_iterator<char>(file),
2824-
std::istreambuf_iterator<char>(),
2825-
std::back_inserter(params.chat_template));
2858+
params.chat_template = read_file(value);
28262859
}
28272860
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
28282861
add_opt(common_arg(

0 commit comments

Comments
 (0)