Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 3a63826

Browse files
authored
Hotfix: remove unneeded params from load_model (#2209)
* refactor: remove --pooling flag from model loading. The --pooling flag was removed because the mean pooling functionality is not needed in chat models; this fixes the regression. * feat(local-engine): add ctx_len parameter support. Adds support for the ctx_len parameter by appending --ctx-size with its value. Removed outdated parameter mappings from the kParamsMap to reflect current implementation details and ensure consistency. * feat: add conditional model parameters based on path. When the model path contains both "jan" and "nano" (case-insensitive), automatically add speculative decoding parameters to adjust generation behavior. This improves flexibility by enabling environment-specific configurations without manual parameter tuning. Also includes the necessary headers for string manipulation and fixes whitespace in ctx_len handling. * chore: remove redundant comment. The comment was redundant, as the code's purpose is clear without it; removing it improves readability.
1 parent a90a5e8 commit 3a63826

File tree

2 files changed

+40
-17
lines changed

2 files changed

+40
-17
lines changed

engine/extensions/local-engine/local_engine.cc

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#include "local_engine.h"
2+
#include <algorithm>
23
#include <random>
4+
#include <string>
35
#include <thread>
6+
#include <string.h>
47
#include <unordered_set>
58
#include "utils/curl_utils.h"
69
#include "utils/json_helper.h"
@@ -20,6 +23,7 @@ const std::unordered_set<std::string> kIgnoredParams = {
2023
"user_prompt", "min_keep", "mirostat", "mirostat_eta",
2124
"mirostat_tau", "text_model", "version", "n_probs",
2225
"object", "penalize_nl", "precision", "size",
26+
"flash_attn",
2327
"stop", "tfs_z", "typ_p", "caching_enabled"};
2428

2529
const std::unordered_map<std::string, std::string> kParamsMap = {
@@ -42,18 +46,24 @@ int GenerateRandomInteger(int min, int max) {
4246
std::uniform_int_distribution<> dis(
4347
min, max); // Distribution for the desired range
4448

45-
return dis(gen); // Generate and return a random integer within the range
49+
return dis(gen);
4650
}
4751

4852
std::vector<std::string> ConvertJsonToParamsVector(const Json::Value& root) {
4953
std::vector<std::string> res;
50-
std::string errors;
5154

5255
for (const auto& member : root.getMemberNames()) {
5356
if (member == "model_path" || member == "llama_model_path") {
5457
if (!root[member].isNull()) {
58+
const std::string path = root[member].asString();
5559
res.push_back("--model");
56-
res.push_back(root[member].asString());
60+
res.push_back(path);
61+
62+
// If path contains both "Jan" and "nano", case-insensitive, add special params
63+
std::string lowered = path;
64+
std::transform(lowered.begin(), lowered.end(), lowered.begin(), [](unsigned char c) {
65+
return std::tolower(c);
66+
});
5767
}
5868
continue;
5969
} else if (kIgnoredParams.find(member) != kIgnoredParams.end()) {
@@ -85,8 +95,15 @@ std::vector<std::string> ConvertJsonToParamsVector(const Json::Value& root) {
8595
res.push_back("--ignore_eos");
8696
}
8797
continue;
98+
} else if (member == "ctx_len") {
99+
if (!root[member].isNull()) {
100+
res.push_back("--ctx-size");
101+
res.push_back(root[member].asString());
102+
}
103+
continue;
88104
}
89105

106+
// Generic handling for other members
90107
res.push_back("--" + member);
91108
if (root[member].isString()) {
92109
res.push_back(root[member].asString());
@@ -105,14 +122,15 @@ std::vector<std::string> ConvertJsonToParamsVector(const Json::Value& root) {
105122
ss << "\"" << value.asString() << "\"";
106123
first = false;
107124
}
108-
ss << "] ";
125+
ss << "]";
109126
res.push_back(ss.str());
110127
}
111128
}
112129

113130
return res;
114131
}
115132

133+
116134
constexpr const auto kMinDataChunkSize = 6u;
117135

118136
struct OaiInfo {
@@ -561,8 +579,6 @@ void LocalEngine::LoadModel(std::shared_ptr<Json::Value> json_body,
561579
params.push_back("--port");
562580
params.push_back(std::to_string(s.port));
563581

564-
params.push_back("--pooling");
565-
params.push_back("mean");
566582

567583
params.push_back("--jinja");
568584

engine/services/model_service.cc

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ ModelService::ModelService(std::shared_ptr<DatabaseService> db_service,
165165
download_service_{download_service},
166166
inference_svc_(inference_service),
167167
engine_svc_(engine_svc),
168-
task_queue_(task_queue) {
169-
// ProcessBgrTasks();
168+
task_queue_(task_queue){
169+
// ProcessBgrTasks();
170170
};
171171

172172
void ModelService::ForceIndexingModelList() {
@@ -557,6 +557,8 @@ cpp::result<StartModelResult, std::string> ModelService::StartModel(
557557
if (auto& o = params_override["ctx_len"]; !o.isNull()) {
558558
ctx_len = o.asInt();
559559
}
560+
Json::Value model_load_params;
561+
json_helper::MergeJson(model_load_params, params_override);
560562

561563
try {
562564
constexpr const int kDefautlContextLength = 8192;
@@ -630,6 +632,8 @@ cpp::result<StartModelResult, std::string> ModelService::StartModel(
630632
#else
631633
json_data["model_path"] =
632634
fmu::ToAbsoluteCortexDataPath(fs::path(mc.files[0])).string();
635+
model_load_params["model_path"] =
636+
fmu::ToAbsoluteCortexDataPath(fs::path(mc.files[0])).string();
633637
#endif
634638
} else {
635639
LOG_WARN << "model_path is empty";
@@ -642,6 +646,8 @@ cpp::result<StartModelResult, std::string> ModelService::StartModel(
642646
#else
643647
json_data["mmproj"] =
644648
fmu::ToAbsoluteCortexDataPath(fs::path(mc.mmproj)).string();
649+
model_load_params["model_path"] =
650+
fmu::ToAbsoluteCortexDataPath(fs::path(mc.mmproj)).string();
645651
#endif
646652
}
647653
json_data["system_prompt"] = mc.system_template;
@@ -655,15 +661,14 @@ cpp::result<StartModelResult, std::string> ModelService::StartModel(
655661
}
656662

657663
json_data["model"] = model_handle;
664+
model_load_params["model"] = model_handle;
658665
if (auto& cpt = custom_prompt_template; !cpt.value_or("").empty()) {
659666
auto parse_prompt_result = string_utils::ParsePrompt(cpt.value());
660667
json_data["system_prompt"] = parse_prompt_result.system_prompt;
661668
json_data["user_prompt"] = parse_prompt_result.user_prompt;
662669
json_data["ai_prompt"] = parse_prompt_result.ai_prompt;
663670
}
664671

665-
json_helper::MergeJson(json_data, params_override);
666-
667672
// Set default cpu_threads if it is not configured
668673
if (!json_data.isMember("cpu_threads")) {
669674
json_data["cpu_threads"] = GetCpuThreads();
@@ -686,12 +691,12 @@ cpp::result<StartModelResult, std::string> ModelService::StartModel(
686691

687692
assert(!!inference_svc_);
688693

689-
auto ir =
690-
inference_svc_->LoadModel(std::make_shared<Json::Value>(json_data));
694+
auto ir = inference_svc_->LoadModel(
695+
std::make_shared<Json::Value>(model_load_params));
691696
auto status = std::get<0>(ir)["status_code"].asInt();
692697
auto data = std::get<1>(ir);
693698

694-
if (status == drogon::k200OK) {
699+
if (status == drogon::k200OK) {
695700
return StartModelResult{/* .success = */ true,
696701
/* .warning = */ may_fallback_res.value()};
697702
} else if (status == drogon::k409Conflict) {
@@ -1031,13 +1036,15 @@ ModelService::MayFallbackToCpu(const std::string& model_path, int ngl,
10311036
auto es = hardware::EstimateLLaMACppRun(model_path, rc);
10321037

10331038
if (!!es && (*es).gpu_mode.vram_MiB > free_vram_MiB && is_cuda) {
1034-
CTL_WRN("Not enough VRAM - " << "required: " << (*es).gpu_mode.vram_MiB
1035-
<< ", available: " << free_vram_MiB);
1039+
CTL_WRN("Not enough VRAM - "
1040+
<< "required: " << (*es).gpu_mode.vram_MiB
1041+
<< ", available: " << free_vram_MiB);
10361042
}
10371043

10381044
if (!!es && (*es).cpu_mode.ram_MiB > free_ram_MiB) {
1039-
CTL_WRN("Not enough RAM - " << "required: " << (*es).cpu_mode.ram_MiB
1040-
<< ", available: " << free_ram_MiB);
1045+
CTL_WRN("Not enough RAM - "
1046+
<< "required: " << (*es).cpu_mode.ram_MiB
1047+
<< ", available: " << free_ram_MiB);
10411048
}
10421049

10431050
return warning;

0 commit comments

Comments
 (0)