@@ -142,13 +142,8 @@ def check_status(self, model_name, batch_exec, request_cnt, infer_cnt):
142
142
stats = self .triton_client_ .get_inference_statistics (model_name , "1" )
143
143
self .assertEqual (len (stats .model_stats ), 1 , "expect 1 model stats" )
144
144
actual_exec_cnt = stats .model_stats [0 ].execution_count
145
- if actual_exec_cnt == exec_cnt :
145
+ if stats . model_stats [ 0 ]. execution_count > 0 :
146
146
break
147
- print (
148
- "WARNING: expect {} executions, got {} (attempt {})" .format (
149
- exec_cnt , actual_exec_cnt , i
150
- )
151
- )
152
147
time .sleep (1 )
153
148
154
149
self .assertEqual (
@@ -411,6 +406,40 @@ def test_ensemble_optional_pipeline(self):
411
406
except Exception as ex :
412
407
self .assertTrue (False , "unexpected error {}" .format (ex ))
413
408
409
+ def test_ensemble_optional_connecting_tensor (self ):
410
+ # The ensemble is a special case of pipelining models with optional
411
+ # inputs, where the request will only produce a subset of inputs
412
+ # for the second model while the ensemble graph connects all inputs of
413
+ # the second model (which is valid because the not-provided inputs
414
+ # are marked optional). See 'config.pbtxt' for detail.
415
+ self .model_name_ = "optional_connecting_tensor"
416
+
417
+ # Provide all inputs, send requests that don't form preferred batch
418
+ # so all requests should be returned after the queue delay
419
+ try :
420
+ provided_inputs = ("INPUT0" ,)
421
+ inputs = []
422
+ outputs = []
423
+ for provided_input in provided_inputs :
424
+ inputs .append (self .inputs_ [provided_input ])
425
+ outputs .append (self .outputs_ [provided_input ])
426
+
427
+ triton_client = grpcclient .InferenceServerClient ("localhost:8001" )
428
+ results = triton_client .infer (
429
+ model_name = self .model_name_ , inputs = inputs , outputs = outputs
430
+ )
431
+
432
+ expected = self .input_data_ ["INPUT0" ]
433
+ output_data = results .as_numpy ("OUTPUT0" )
434
+ self .assertTrue (
435
+ np .array_equal (output_data , expected ),
436
+ "{}, {}, expected: {}, got {}" .format (
437
+ self .model_name_ , "OUTPUT0" , expected , output_data
438
+ ),
439
+ )
440
+ except Exception as ex :
441
+ self .assertTrue (False , "unexpected error {}" .format (ex ))
442
+
414
443
415
444
if __name__ == "__main__" :
416
445
unittest .main ()
0 commit comments