@@ -292,15 +292,18 @@ def _init_batchqueue_dict(self):
292
292
func .fn .__name__ : [] for func in self ._executor .requests .values ()
293
293
}
294
294
for endpoint , func in self ._executor .requests .items ():
295
- func_endpoints [func .fn .__name__ ].append (endpoint )
295
+ if func .fn .__name__ in func_endpoints :
296
+ # For SageMaker, not all endpoints are there
297
+ func_endpoints [func .fn .__name__ ].append (endpoint )
296
298
for func_name , dbatch_config in dbatch_functions :
297
- for endpoint in func_endpoints [func_name ]:
298
- if endpoint not in self ._batchqueue_config :
299
- self ._batchqueue_config [endpoint ] = dbatch_config
300
- else :
301
- # we need to eventually copy the `custom_metric`
302
- if dbatch_config .get ('custom_metric' , None ) is not None :
303
- self ._batchqueue_config [endpoint ]['custom_metric' ] = dbatch_config .get ('custom_metric' )
299
+ if func_name in func_endpoints : # For SageMaker, not all endpoints are there
300
+ for endpoint in func_endpoints [func_name ]:
301
+ if endpoint not in self ._batchqueue_config :
302
+ self ._batchqueue_config [endpoint ] = dbatch_config
303
+ else :
304
+ # we need to eventually copy the `custom_metric`
305
+ if dbatch_config .get ('custom_metric' , None ) is not None :
306
+ self ._batchqueue_config [endpoint ]['custom_metric' ] = dbatch_config .get ('custom_metric' )
304
307
305
308
keys_to_remove = []
306
309
for k , batch_config in self ._batchqueue_config .items ():
0 commit comments