@@ -226,9 +226,21 @@ SagemakerAPIServer::Handle(evhtp_request_t* req)
226
226
evhtp_send_reply (req, EVHTP_RES_NOTFOUND); /* 404*/
227
227
return ;
228
228
}
229
- LOG_VERBOSE (1 ) << " SageMaker MME Custom Invoke Model Path"
230
- << std::endl;
231
- SageMakerMMEHandleInfer (req, multi_model_name, model_version_str_);
229
+ LOG_VERBOSE (1 ) << " SageMaker MME Custom Invoke Model Path" ;
230
+
231
+ /* Extract targetModel to log the associated archive */
232
+ const char * target_model =
233
+ evhtp_kv_find (req->headers_in , " X-Amzn-SageMaker-Target-Model" );
234
+
235
+ /* If target_model is not available (e.g., in local testing) use
236
+ * model_name_hash as target_model) */
237
+ if (target_model == nullptr ) {
238
+ target_model = multi_model_name.c_str ();
239
+ }
240
+
241
+ LOG_INFO << " Invoking SageMaker TargetModel: " << target_model;
242
+
243
+ SageMakerMMEHandleInfer (req, target_model, model_version_str_);
232
244
return ;
233
245
}
234
246
if (action.empty ()) {
@@ -330,17 +342,22 @@ SagemakerAPIServer::ParseSageMakerRequest(
330
342
if (action == " load" ) {
331
343
(*parse_map)[" url" ] = url_string.c_str ();
332
344
}
333
- (*parse_map)[" model_name " ] = model_name_string.c_str ();
345
+ (*parse_map)[" model_name_hash " ] = model_name_string.c_str ();
334
346
335
- /* Extract targetModel to log the associated archive */
347
+ /* Extract target_model, specified in header, to log the associated archive */
348
+ const char * target_model =
349
+ evhtp_kv_find (req->headers_in , " X-Amzn-SageMaker-Target-Model" );
336
350
337
- /* Read headers*/
338
- (*parse_map)[" TargetModel" ] = " targetModel.tar.gz" ;
339
351
340
- const char * targetModel =
341
- evhtp_kv_find (req->headers_in , " X-Amzn-SageMaker-Target-Model" );
352
+ /* If target_model is not available (e.g., in local testing) use
353
+ * model_name_hash as target_model) */
354
+ if (target_model != nullptr ) {
355
+ (*parse_map)[" target_model" ] = target_model;
356
+ } else {
357
+ (*parse_map)[" target_model" ] = model_name_string.c_str ();
358
+ }
342
359
343
- LOG_INFO << " Loading SageMaker TargetModel: " << targetModel << std::endl ;
360
+ LOG_INFO << " Loading SageMaker TargetModel: " << target_model ;
344
361
345
362
return ;
346
363
}
@@ -443,11 +460,6 @@ SagemakerAPIServer::SageMakerMMEHandleInfer(
443
460
return ;
444
461
}
445
462
446
- /* Extract targetModel to log the associated archive */
447
- const char * targetModel =
448
- evhtp_kv_find (req->headers_in , " X-Amzn-SageMaker-Target-Model" );
449
- LOG_INFO << " Invoking SageMaker TargetModel: " << targetModel << std::endl;
450
-
451
463
bool connection_paused = false ;
452
464
453
465
int64_t requested_model_version;
@@ -687,7 +699,7 @@ SagemakerAPIServer::SageMakerMMECheckUnloadedModelIsUnavailable(
687
699
688
700
LOG_VERBOSE (1 ) << " Discovered model: " << name
689
701
<< " , version: " << version << " in state: " << state
690
- << " for the reason: " << reason;
702
+ << " for the reason: " << reason;
691
703
692
704
break ;
693
705
}
@@ -700,35 +712,43 @@ SagemakerAPIServer::SageMakerMMECheckUnloadedModelIsUnavailable(
700
712
701
713
void
702
714
SagemakerAPIServer::SageMakerMMEUnloadModel (
703
- evhtp_request_t * req, const char * model_name )
715
+ evhtp_request_t * req, const char * model_name_hash )
704
716
{
705
- if (sagemaker_models_list_.find (model_name) == sagemaker_models_list_.end ()) {
706
- LOG_VERBOSE (1 ) << " Model " << model_name << " is not loaded." << std::endl;
717
+ /* Extract targetModel to log the associated archive */
718
+ const char * target_model =
719
+ evhtp_kv_find (req->headers_in , " X-Amzn-SageMaker-Target-Model" );
720
+
721
+ /* If target_model is not available (e.g., in local testing) use
722
+ * model_name_hash as target_model) */
723
+ if (target_model == nullptr ) {
724
+ target_model = model_name_hash;
725
+ }
726
+
727
+ if (sagemaker_models_list_.find (model_name_hash) ==
728
+ sagemaker_models_list_.end ()) {
729
+ LOG_VERBOSE (1 ) << " Model " << target_model << " with model hash "
730
+ << model_name_hash << " is not loaded." << std::endl;
707
731
evhtp_send_reply (req, EVHTP_RES_NOTFOUND); /* 404*/
708
732
return ;
709
733
}
710
734
711
- /* Extract targetModel to log the associated archive */
712
- const char * targetModel =
713
- evhtp_kv_find (req->headers_in , " X-Amzn-SageMaker-Target-Model" );
714
-
715
- LOG_INFO << " Unloading SageMaker TargetModel: " << targetModel << std::endl;
735
+ LOG_INFO << " Unloading SageMaker TargetModel: " << target_model << std::endl;
716
736
717
737
auto start_time = std::chrono::high_resolution_clock::now ();
718
738
719
739
/* Always unload dependents as well - this is required to unload dependents in
720
740
* ensemble */
721
741
TRITONSERVER_Error* unload_err = nullptr ;
722
742
unload_err =
723
- TRITONSERVER_ServerUnloadModelAndDependents (server_.get (), model_name );
743
+ TRITONSERVER_ServerUnloadModelAndDependents (server_.get (), target_model );
724
744
725
745
if (unload_err != nullptr ) {
726
746
EVBufferAddErrorJson (req->buffer_out , unload_err);
727
747
evhtp_send_reply (req, EVHTP_RES_BADREQ);
728
748
729
749
LOG_ERROR
730
750
<< " Error when unloading SageMaker Model with dependents for model: "
731
- << model_name << std::endl;
751
+ << target_model << std::endl;
732
752
733
753
TRITONSERVER_ErrorDelete (unload_err);
734
754
return ;
@@ -745,13 +765,13 @@ SagemakerAPIServer::SageMakerMMEUnloadModel(
745
765
succeeded.*/
746
766
if (unload_err == nullptr ) {
747
767
LOG_VERBOSE (1 ) << " Using Model Repository Index during UNLOAD to check for "
748
- " status of model: "
749
- << model_name ;
768
+ " status of model hash : "
769
+ << model_name_hash << " for model: " << target_model ;
750
770
while (is_model_unavailable == false &&
751
771
unload_time_in_secs < UNLOAD_TIMEOUT_SECS_) {
752
772
LOG_VERBOSE (1 ) << " In the loop to wait for model to be unavailable" ;
753
773
unload_err = SageMakerMMECheckUnloadedModelIsUnavailable (
754
- model_name , &is_model_unavailable);
774
+ target_model , &is_model_unavailable);
755
775
if (unload_err != nullptr ) {
756
776
LOG_ERROR << " Error: Received non-zero exit code on checking for "
757
777
" model unavailability. "
@@ -767,7 +787,7 @@ SagemakerAPIServer::SageMakerMMEUnloadModel(
767
787
end_time - start_time)
768
788
.count ();
769
789
}
770
- LOG_INFO << " UNLOAD for model " << model_name << " completed in "
790
+ LOG_INFO << " UNLOAD for model " << target_model << " completed in "
771
791
<< unload_time_in_secs << " seconds." ;
772
792
TRITONSERVER_ErrorDelete (unload_err);
773
793
}
@@ -780,7 +800,7 @@ SagemakerAPIServer::SageMakerMMEUnloadModel(
780
800
" result in SageMaker UNLOAD timeout." ;
781
801
}
782
802
783
- std::string repo_parent_path = sagemaker_models_list_.at (model_name );
803
+ std::string repo_parent_path = sagemaker_models_list_.at (model_name_hash );
784
804
785
805
TRITONSERVER_Error* unregister_err = nullptr ;
786
806
@@ -799,7 +819,7 @@ SagemakerAPIServer::SageMakerMMEUnloadModel(
799
819
TRITONSERVER_ErrorDelete (unregister_err);
800
820
801
821
std::lock_guard<std::mutex> lock (models_list_mutex_);
802
- sagemaker_models_list_.erase (model_name );
822
+ sagemaker_models_list_.erase (model_name_hash );
803
823
}
804
824
805
825
void
@@ -946,7 +966,8 @@ SagemakerAPIServer::SageMakerMMELoadModel(
946
966
const std::unordered_map<std::string, std::string> parse_map)
947
967
{
948
968
std::string repo_path = parse_map.at (" url" );
949
- std::string model_name = parse_map.at (" model_name" );
969
+ std::string model_name_hash = parse_map.at (" model_name_hash" );
970
+ std::string target_model = parse_map.at (" target_model" );
950
971
951
972
/* Check subdirs for models and find ensemble model within the repo_path
952
973
* If only 1 model, that will be selected as model_subdir
@@ -1043,7 +1064,8 @@ SagemakerAPIServer::SageMakerMMELoadModel(
1043
1064
}
1044
1065
1045
1066
auto param = TRITONSERVER_ParameterNew (
1046
- model_subdir.c_str (), TRITONSERVER_PARAMETER_STRING, model_name.c_str ());
1067
+ model_subdir.c_str (), TRITONSERVER_PARAMETER_STRING,
1068
+ target_model.c_str ());
1047
1069
1048
1070
if (param != nullptr ) {
1049
1071
subdir_modelname_map.emplace_back (param);
@@ -1076,7 +1098,7 @@ SagemakerAPIServer::SageMakerMMELoadModel(
1076
1098
return ;
1077
1099
}
1078
1100
1079
- err = TRITONSERVER_ServerLoadModel (server_.get (), model_name .c_str ());
1101
+ err = TRITONSERVER_ServerLoadModel (server_.get (), target_model .c_str ());
1080
1102
1081
1103
/* Unlikely after duplicate repo check, but in case Load Model also returns
1082
1104
* ALREADY_EXISTS error */
@@ -1091,7 +1113,8 @@ SagemakerAPIServer::SageMakerMMELoadModel(
1091
1113
} else {
1092
1114
std::lock_guard<std::mutex> lock (models_list_mutex_);
1093
1115
1094
- sagemaker_models_list_.emplace (model_name, repo_parent_path);
1116
+ /* Use model name hash as expected in SageMaker MME contract */
1117
+ sagemaker_models_list_.emplace (model_name_hash, repo_parent_path);
1095
1118
evhtp_send_reply (req, EVHTP_RES_OK);
1096
1119
}
1097
1120
@@ -1101,7 +1124,7 @@ SagemakerAPIServer::SageMakerMMELoadModel(
1101
1124
server_.get (), repo_parent_path.c_str ());
1102
1125
LOG_VERBOSE (1 )
1103
1126
<< " Unregistered model repository due to load failure for model: "
1104
- << model_name << std::endl;
1127
+ << target_model << std::endl;
1105
1128
}
1106
1129
1107
1130
if (err != nullptr ) {
0 commit comments