Skip to content

Commit 7624490

Browse files
authored
Use targetmodel from header as model name in SageMaker (#6147)
* Use targetmodel from header as model name in SageMaker
* Update naming for model hash
1 parent 68e116a commit 7624490

File tree

1 file changed

+60
-37
lines changed

1 file changed

+60
-37
lines changed

src/sagemaker_server.cc

+60-37
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,21 @@ SagemakerAPIServer::Handle(evhtp_request_t* req)
226226
evhtp_send_reply(req, EVHTP_RES_NOTFOUND); /* 404*/
227227
return;
228228
}
229-
LOG_VERBOSE(1) << "SageMaker MME Custom Invoke Model Path"
230-
<< std::endl;
231-
SageMakerMMEHandleInfer(req, multi_model_name, model_version_str_);
229+
LOG_VERBOSE(1) << "SageMaker MME Custom Invoke Model Path";
230+
231+
/* Extract targetModel from the request headers; it is logged and passed on
 * as the model name for inference */
232+
const char* target_model =
233+
evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model");
234+
235+
/* If target_model is not available (e.g., in local testing) use
236+
* model_name_hash as target_model. */
237+
if (target_model == nullptr) {
238+
target_model = multi_model_name.c_str();
239+
}
240+
241+
LOG_INFO << "Invoking SageMaker TargetModel: " << target_model;
242+
243+
SageMakerMMEHandleInfer(req, target_model, model_version_str_);
232244
return;
233245
}
234246
if (action.empty()) {
@@ -330,17 +342,22 @@ SagemakerAPIServer::ParseSageMakerRequest(
330342
if (action == "load") {
331343
(*parse_map)["url"] = url_string.c_str();
332344
}
333-
(*parse_map)["model_name"] = model_name_string.c_str();
345+
(*parse_map)["model_name_hash"] = model_name_string.c_str();
334346

335-
/* Extract targetModel to log the associated archive */
347+
/* Extract target_model, specified in header, to log the associated archive */
348+
const char* target_model =
349+
evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model");
336350

337-
/* Read headers*/
338-
(*parse_map)["TargetModel"] = "targetModel.tar.gz";
339351

340-
const char* targetModel =
341-
evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model");
352+
/* If target_model is not available (e.g., in local testing) use
353+
* model_name_hash as target_model. */
354+
if (target_model != nullptr) {
355+
(*parse_map)["target_model"] = target_model;
356+
} else {
357+
(*parse_map)["target_model"] = model_name_string.c_str();
358+
}
342359

343-
LOG_INFO << "Loading SageMaker TargetModel: " << targetModel << std::endl;
360+
LOG_INFO << "Loading SageMaker TargetModel: " << target_model;
344361

345362
return;
346363
}
@@ -443,11 +460,6 @@ SagemakerAPIServer::SageMakerMMEHandleInfer(
443460
return;
444461
}
445462

446-
/* Extract targetModel to log the associated archive */
447-
const char* targetModel =
448-
evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model");
449-
LOG_INFO << "Invoking SageMaker TargetModel: " << targetModel << std::endl;
450-
451463
bool connection_paused = false;
452464

453465
int64_t requested_model_version;
@@ -687,7 +699,7 @@ SagemakerAPIServer::SageMakerMMECheckUnloadedModelIsUnavailable(
687699

688700
LOG_VERBOSE(1) << "Discovered model: " << name
689701
<< ", version: " << version << " in state: " << state
690-
<< "for the reason: " << reason;
702+
<< " for the reason: " << reason;
691703

692704
break;
693705
}
@@ -700,35 +712,43 @@ SagemakerAPIServer::SageMakerMMECheckUnloadedModelIsUnavailable(
700712

701713
void
702714
SagemakerAPIServer::SageMakerMMEUnloadModel(
703-
evhtp_request_t* req, const char* model_name)
715+
evhtp_request_t* req, const char* model_name_hash)
704716
{
705-
if (sagemaker_models_list_.find(model_name) == sagemaker_models_list_.end()) {
706-
LOG_VERBOSE(1) << "Model " << model_name << " is not loaded." << std::endl;
717+
/* Extract targetModel from the request headers; it is used in log messages
 * and as the model name passed to the unload call */
718+
const char* target_model =
719+
evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model");
720+
721+
/* If target_model is not available (e.g., in local testing) use
722+
* model_name_hash as target_model. */
723+
if (target_model == nullptr) {
724+
target_model = model_name_hash;
725+
}
726+
727+
if (sagemaker_models_list_.find(model_name_hash) ==
728+
sagemaker_models_list_.end()) {
729+
LOG_VERBOSE(1) << "Model " << target_model << " with model hash "
730+
<< model_name_hash << " is not loaded." << std::endl;
707731
evhtp_send_reply(req, EVHTP_RES_NOTFOUND); /* 404*/
708732
return;
709733
}
710734

711-
/* Extract targetModel to log the associated archive */
712-
const char* targetModel =
713-
evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model");
714-
715-
LOG_INFO << "Unloading SageMaker TargetModel: " << targetModel << std::endl;
735+
LOG_INFO << "Unloading SageMaker TargetModel: " << target_model << std::endl;
716736

717737
auto start_time = std::chrono::high_resolution_clock::now();
718738

719739
/* Always unload dependents as well - this is required to unload dependents in
720740
* ensemble */
721741
TRITONSERVER_Error* unload_err = nullptr;
722742
unload_err =
723-
TRITONSERVER_ServerUnloadModelAndDependents(server_.get(), model_name);
743+
TRITONSERVER_ServerUnloadModelAndDependents(server_.get(), target_model);
724744

725745
if (unload_err != nullptr) {
726746
EVBufferAddErrorJson(req->buffer_out, unload_err);
727747
evhtp_send_reply(req, EVHTP_RES_BADREQ);
728748

729749
LOG_ERROR
730750
<< "Error when unloading SageMaker Model with dependents for model: "
731-
<< model_name << std::endl;
751+
<< target_model << std::endl;
732752

733753
TRITONSERVER_ErrorDelete(unload_err);
734754
return;
@@ -745,13 +765,13 @@ SagemakerAPIServer::SageMakerMMEUnloadModel(
745765
succeeded.*/
746766
if (unload_err == nullptr) {
747767
LOG_VERBOSE(1) << "Using Model Repository Index during UNLOAD to check for "
748-
"status of model: "
749-
<< model_name;
768+
"status of model hash: "
769+
<< model_name_hash << " for model: " << target_model;
750770
while (is_model_unavailable == false &&
751771
unload_time_in_secs < UNLOAD_TIMEOUT_SECS_) {
752772
LOG_VERBOSE(1) << "In the loop to wait for model to be unavailable";
753773
unload_err = SageMakerMMECheckUnloadedModelIsUnavailable(
754-
model_name, &is_model_unavailable);
774+
target_model, &is_model_unavailable);
755775
if (unload_err != nullptr) {
756776
LOG_ERROR << "Error: Received non-zero exit code on checking for "
757777
"model unavailability. "
@@ -767,7 +787,7 @@ SagemakerAPIServer::SageMakerMMEUnloadModel(
767787
end_time - start_time)
768788
.count();
769789
}
770-
LOG_INFO << "UNLOAD for model " << model_name << " completed in "
790+
LOG_INFO << "UNLOAD for model " << target_model << " completed in "
771791
<< unload_time_in_secs << " seconds.";
772792
TRITONSERVER_ErrorDelete(unload_err);
773793
}
@@ -780,7 +800,7 @@ SagemakerAPIServer::SageMakerMMEUnloadModel(
780800
"result in SageMaker UNLOAD timeout.";
781801
}
782802

783-
std::string repo_parent_path = sagemaker_models_list_.at(model_name);
803+
std::string repo_parent_path = sagemaker_models_list_.at(model_name_hash);
784804

785805
TRITONSERVER_Error* unregister_err = nullptr;
786806

@@ -799,7 +819,7 @@ SagemakerAPIServer::SageMakerMMEUnloadModel(
799819
TRITONSERVER_ErrorDelete(unregister_err);
800820

801821
std::lock_guard<std::mutex> lock(models_list_mutex_);
802-
sagemaker_models_list_.erase(model_name);
822+
sagemaker_models_list_.erase(model_name_hash);
803823
}
804824

805825
void
@@ -946,7 +966,8 @@ SagemakerAPIServer::SageMakerMMELoadModel(
946966
const std::unordered_map<std::string, std::string> parse_map)
947967
{
948968
std::string repo_path = parse_map.at("url");
949-
std::string model_name = parse_map.at("model_name");
969+
std::string model_name_hash = parse_map.at("model_name_hash");
970+
std::string target_model = parse_map.at("target_model");
950971

951972
/* Check subdirs for models and find ensemble model within the repo_path
952973
* If only 1 model, that will be selected as model_subdir
@@ -1043,7 +1064,8 @@ SagemakerAPIServer::SageMakerMMELoadModel(
10431064
}
10441065

10451066
auto param = TRITONSERVER_ParameterNew(
1046-
model_subdir.c_str(), TRITONSERVER_PARAMETER_STRING, model_name.c_str());
1067+
model_subdir.c_str(), TRITONSERVER_PARAMETER_STRING,
1068+
target_model.c_str());
10471069

10481070
if (param != nullptr) {
10491071
subdir_modelname_map.emplace_back(param);
@@ -1076,7 +1098,7 @@ SagemakerAPIServer::SageMakerMMELoadModel(
10761098
return;
10771099
}
10781100

1079-
err = TRITONSERVER_ServerLoadModel(server_.get(), model_name.c_str());
1101+
err = TRITONSERVER_ServerLoadModel(server_.get(), target_model.c_str());
10801102

10811103
/* Unlikely after duplicate repo check, but in case Load Model also returns
10821104
* ALREADY_EXISTS error */
@@ -1091,7 +1113,8 @@ SagemakerAPIServer::SageMakerMMELoadModel(
10911113
} else {
10921114
std::lock_guard<std::mutex> lock(models_list_mutex_);
10931115

1094-
sagemaker_models_list_.emplace(model_name, repo_parent_path);
1116+
/* Use model name hash as expected in SageMaker MME contract */
1117+
sagemaker_models_list_.emplace(model_name_hash, repo_parent_path);
10951118
evhtp_send_reply(req, EVHTP_RES_OK);
10961119
}
10971120

@@ -1101,7 +1124,7 @@ SagemakerAPIServer::SageMakerMMELoadModel(
11011124
server_.get(), repo_parent_path.c_str());
11021125
LOG_VERBOSE(1)
11031126
<< "Unregistered model repository due to load failure for model: "
1104-
<< model_name << std::endl;
1127+
<< target_model << std::endl;
11051128
}
11061129

11071130
if (err != nullptr) {

0 commit comments

Comments
 (0)