Skip to content

Commit 80d0d6b

Browse files
authored
common : add -hfd option for the draft model (ggml-org#11318)
* common : add -hfd option for the draft model * cont : fix env var * cont : more fixes
1 parent aea8ddd commit 80d0d6b

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

common/arg.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,8 @@ static void common_params_handle_model_default(
133133
const std::string & model_url,
134134
std::string & hf_repo,
135135
std::string & hf_file,
136-
const std::string & hf_token) {
136+
const std::string & hf_token,
137+
const std::string & model_default) {
137138
if (!hf_repo.empty()) {
138139
// short-hand to avoid specifying --hf-file -> default it to --model
139140
if (hf_file.empty()) {
@@ -163,7 +164,7 @@ static void common_params_handle_model_default(
163164
model = fs_get_cache_file(string_split<std::string>(f, '/').back());
164165
}
165166
} else if (model.empty()) {
166-
model = DEFAULT_MODEL_PATH;
167+
model = model_default;
167168
}
168169
}
169170

@@ -299,8 +300,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
299300
}
300301

301302
// TODO: refactor model params in a common struct
302-
common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token);
303-
common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token);
303+
common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token, DEFAULT_MODEL_PATH);
304+
common_params_handle_model_default(params.speculative.model, params.speculative.model_url, params.speculative.hf_repo, params.speculative.hf_file, params.hf_token, "");
305+
common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token, "");
304306

305307
if (params.escape) {
306308
string_process_escapes(params.prompt);
@@ -1629,6 +1631,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
16291631
params.hf_repo = value;
16301632
}
16311633
).set_env("LLAMA_ARG_HF_REPO"));
1634+
add_opt(common_arg(
1635+
{"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
1636+
"Same as --hf-repo, but for the draft model (default: unused)",
1637+
[](common_params & params, const std::string & value) {
1638+
params.speculative.hf_repo = value;
1639+
}
1640+
).set_env("LLAMA_ARG_HFD_REPO"));
16321641
add_opt(common_arg(
16331642
{"-hff", "--hf-file"}, "FILE",
16341643
"Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)",

common/common.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,11 @@ struct common_params_speculative {
175175
struct cpu_params cpuparams;
176176
struct cpu_params cpuparams_batch;
177177

178-
std::string model = ""; // draft model for speculative decoding // NOLINT
178+
std::string hf_repo = ""; // HF repo // NOLINT
179+
std::string hf_file = ""; // HF file // NOLINT
180+
181+
std::string model = ""; // draft model for speculative decoding // NOLINT
182+
std::string model_url = ""; // model url to download // NOLINT
179183
};
180184

181185
struct common_params_vocoder {
@@ -508,12 +512,14 @@ struct llama_model * common_load_model_from_url(
508512
const std::string & local_path,
509513
const std::string & hf_token,
510514
const struct llama_model_params & params);
515+
511516
struct llama_model * common_load_model_from_hf(
512517
const std::string & repo,
513518
const std::string & remote_path,
514519
const std::string & local_path,
515520
const std::string & hf_token,
516521
const struct llama_model_params & params);
522+
517523
std::pair<std::string, std::string> common_get_hf_file(
518524
const std::string & hf_repo_with_tag,
519525
const std::string & hf_token);

examples/server/server.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1728,13 +1728,16 @@ struct server_context {
17281728
add_bos_token = llama_vocab_get_add_bos(vocab);
17291729
has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
17301730

1731-
if (!params_base.speculative.model.empty()) {
1731+
if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) {
17321732
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
17331733

17341734
auto params_dft = params_base;
17351735

17361736
params_dft.devices = params_base.speculative.devices;
1737+
params_dft.hf_file = params_base.speculative.hf_file;
1738+
params_dft.hf_repo = params_base.speculative.hf_repo;
17371739
params_dft.model = params_base.speculative.model;
1740+
params_dft.model_url = params_base.speculative.model_url;
17381741
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
17391742
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
17401743
params_dft.n_parallel = 1;

0 commit comments

Comments
 (0)