from tritonclient.utils import InferenceServerException
from models.model_init_del.util import (get_count, reset_count, set_delay,
                                        update_instance_group,
+                                        update_sequence_batching,
                                        update_model_file, enable_batching,
                                        disable_batching)

@@ -43,9 +44,21 @@ class TestInstanceUpdate(unittest.TestCase):
    __model_name = "model_init_del"

    def setUp(self):
-        # Initialize client
+        self.__reset_model()
        self.__triton = grpcclient.InferenceServerClient("localhost:8001")

+    def __reset_model(self):
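+        # Restore the shared util-module state (call counters, batching mode,
+        # delays, sequence batching config) so each test starts clean.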
+        # Reset counters
+        reset_count("initialize")
+        reset_count("finalize")
+        # Reset batching
+        disable_batching()
+        # Reset delays
+        set_delay("initialize", 0)
+        set_delay("infer", 0)
+        # Reset sequence batching
+        update_sequence_batching("")
+
    def __get_inputs(self, batching=False):
        self.assertIsInstance(batching, bool)
        if batching:
@@ -85,14 +98,8 @@ def __check_count(self, kind, expected_count, poll=False):
        self.assertEqual(get_count(kind), expected_count)

    def __load_model(self, instance_count, instance_config="", batching=False):
-        # Reset counters
-        reset_count("initialize")
-        reset_count("finalize")
        # Set batching
        enable_batching() if batching else disable_batching()
-        # Reset delays
-        set_delay("initialize", 0)
-        set_delay("infer", 0)
        # Load model
        self.__update_instance_count(instance_count,
                                     0,
@@ -143,6 +150,7 @@ def test_add_rm_add_instance(self):
        self.__update_instance_count(1, 0, batching=batching)  # add
        stop()
        self.__unload_model(batching=batching)
+        self.__reset_model()  # for next iteration

    # Test remove -> add -> remove an instance
    def test_rm_add_rm_instance(self):
@@ -154,6 +162,7 @@ def test_rm_add_rm_instance(self):
        self.__update_instance_count(0, 1, batching=batching)  # remove
        stop()
        self.__unload_model(batching=batching)
+        self.__reset_model()  # for next iteration

    # Test reduce instance count to zero
    def test_rm_instance_to_zero(self):
@@ -341,15 +350,89 @@ def infer():
        # Unload model
        self.__unload_model()

-    # Test for instance update on direct sequence scheduling
-    @unittest.skip("Sequence will not continue after update [FIXME: DLIS-4820]")
-    def test_instance_update_on_direct_sequence_scheduling(self):
-        pass
-
-    # Test for instance update on oldest sequence scheduling
-    @unittest.skip("Sequence will not continue after update [FIXME: DLIS-4820]")
-    def test_instance_update_on_oldest_sequence_scheduling(self):
-        pass
+    # Test waiting for in-flight sequence completion and blocking new sequences
+    def test_sequence_instance_update(self):
+        for sequence_batching_type in [
+                "direct { }\n max_sequence_idle_microseconds: 10000000",
+                "oldest { max_candidate_sequences: 4 }\n max_sequence_idle_microseconds: 10000000"
+        ]:
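+            # Both sequence batcher strategies are exercised; the 10 second
+            # idle timeout keeps sequences alive across the 2 second sleeps
+            # used further below.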
+            # Load model
+            update_instance_group("{\n count: 2\n kind: KIND_CPU\n }")
+            update_sequence_batching(sequence_batching_type)
+            self.__triton.load_model(self.__model_name)
+            self.__check_count("initialize", 2)
+            self.__check_count("finalize", 0)
+            # Basic sequence inference
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1,
+                                sequence_start=True)
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1)
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1,
+                                sequence_end=True)
+            # Update instance
+            update_instance_group("{\n count: 4\n kind: KIND_CPU\n }")
+            self.__triton.load_model(self.__model_name)
+            self.__check_count("initialize", 4)
+            self.__check_count("finalize", 0)
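+            # Growing from 2 to 4 instances removes no instance, so this
+            # update is expected to complete without waiting on sequences.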
+            # Start an in-flight sequence
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1,
+                                sequence_start=True)
+            # Check that the instance update waits for in-flight sequence
+            # completion and blocks new sequences from starting.
+            update_instance_group("{\n count: 3\n kind: KIND_CPU\n }")
+            update_complete = [False]
+            def update():
+                self.__triton.load_model(self.__model_name)
+                update_complete[0] = True
+                self.__check_count("initialize", 4)
+                self.__check_count("finalize", 1)
+            infer_complete = [False]
+            def infer():
+                self.__triton.infer(self.__model_name,
+                                    self.__get_inputs(),
+                                    sequence_id=2,
+                                    sequence_start=True)
+                infer_complete[0] = True
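+            # update() and infer() run on worker threads below; each flips its
+            # one-element flag list so the main thread can poll for completion.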
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                # Update should wait until sequence 1 ends
+                update_thread = pool.submit(update)
+                time.sleep(2)  # make sure update has started
+                self.assertFalse(update_complete[0],
+                                 "Unexpected update completion")
+                # New sequence should wait until the update completes
+                infer_thread = pool.submit(infer)
+                time.sleep(2)  # make sure infer has started
+                self.assertFalse(infer_complete[0],
+                                 "Unexpected infer completion")
+                # Ending sequence 1 should unblock the update
+                self.__triton.infer(self.__model_name,
+                                    self.__get_inputs(),
+                                    sequence_id=1,
+                                    sequence_end=True)
+                time.sleep(2)  # make sure update has returned
+                self.assertTrue(update_complete[0], "Update possibly stuck")
+                update_thread.result()
+                # Update completion should unblock the new sequence
+                time.sleep(2)  # make sure infer has returned
+                self.assertTrue(infer_complete[0], "Infer possibly stuck")
+                infer_thread.result()
+            # End sequence 2
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=2,
+                                sequence_end=True)
+            # Unload model
+            self.__triton.unload_model(self.__model_name)
+            self.__check_count("initialize", 4)
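+            # poll=True presumably lets this check retry until the finalize
+            # count catches up, since instances finalize asynchronously on
+            # unload.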
+            self.__check_count("finalize", 4, True)
+            self.__reset_model()


if __name__ == "__main__":