@@ -35,6 +35,7 @@
 from tritonclient.utils import InferenceServerException
 from models.model_init_del.util import (get_count, reset_count, set_delay,
                                         update_instance_group,
+                                        update_sequence_batching,
                                         update_model_file, enable_batching,
                                         disable_batching)

@@ -44,9 +45,21 @@ class TestInstanceUpdate(unittest.TestCase):
     __model_name = "model_init_del"

     def setUp(self):
-        # Initialize client
+        self.__reset_model()
         self.__triton = grpcclient.InferenceServerClient("localhost:8001")

+    def __reset_model(self):
+        # Reset counters
+        reset_count("initialize")
+        reset_count("finalize")
+        # Reset batching
+        disable_batching()
+        # Reset delays
+        set_delay("initialize", 0)
+        set_delay("infer", 0)
+        # Reset sequence batching
+        update_sequence_batching("")
+
     def __get_inputs(self, batching=False):
         self.assertIsInstance(batching, bool)
         if batching:
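The new `__reset_model` helper leans on `update_sequence_batching` from `models/model_init_del/util.py`, where an empty string clears any sequence batching setting. As a minimal sketch of what such a config-rewriting helper could look like — assuming (like `update_instance_group`) it rewrites the model's `config.pbtxt` between generated markers; the path, markers, and logic below are illustrative assumptions, not the repo's actual implementation:

import os

# Hypothetical sketch only: swap the sequence_batching block of the model's
# config.pbtxt. The path and comment markers are assumptions for illustration.
_CONFIG_PATH = os.path.join("models", "model_init_del", "config.pbtxt")
_BEGIN = "# begin sequence_batching\n"
_END = "# end sequence_batching\n"

def update_sequence_batching(sequence_batching_str):
    with open(_CONFIG_PATH) as f:
        config = f.read()
    # Drop the previously generated block, if present.
    begin, end = config.find(_BEGIN), config.find(_END)
    if begin != -1 and end != -1:
        config = config[:begin] + config[end + len(_END):]
    # An empty string disables sequence batching; otherwise wrap the body,
    # e.g. "direct { }\nmax_sequence_idle_microseconds: 10000000".
    if sequence_batching_str:
        config += (_BEGIN + "sequence_batching {\n" + sequence_batching_str +
                   "\n}\n" + _END)
    with open(_CONFIG_PATH, "w") as f:
        f.write(config)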
@@ -90,14 +103,8 @@ def __check_count(self, kind, expected_count, poll=False):
         self.assertEqual(get_count(kind), expected_count)

     def __load_model(self, instance_count, instance_config="", batching=False):
-        # Reset counters
-        reset_count("initialize")
-        reset_count("finalize")
         # Set batching
         enable_batching() if batching else disable_batching()
-        # Reset delays
-        set_delay("initialize", 0)
-        set_delay("infer", 0)
         # Load model
         self.__update_instance_count(instance_count,
                                      0,
@@ -148,6 +155,7 @@ def test_add_rm_add_instance(self):
             self.__update_instance_count(1, 0, batching=batching)  # add
             stop()
             self.__unload_model(batching=batching)
+            self.__reset_model()  # for next iteration

     # Test remove -> add -> remove an instance
     def test_rm_add_rm_instance(self):
@@ -159,6 +167,7 @@ def test_rm_add_rm_instance(self):
             self.__update_instance_count(0, 1, batching=batching)  # remove
             stop()
             self.__unload_model(batching=batching)
+            self.__reset_model()  # for next iteration

     # Test reduce instance count to zero
     def test_rm_instance_to_zero(self):
@@ -457,15 +466,89 @@ def test_instance_resource_decrease(self):
         # explicit limit of 10 is set.
         self.assertNotIn("Resource: R1\t Count: 3", f.read())

-    # Test for instance update on direct sequence scheduling
-    @unittest.skip("Sequence will not continue after update [FIXME: DLIS-4820]")
-    def test_instance_update_on_direct_sequence_scheduling(self):
-        pass
-
-    # Test for instance update on oldest sequence scheduling
-    @unittest.skip("Sequence will not continue after update [FIXME: DLIS-4820]")
-    def test_instance_update_on_oldest_sequence_scheduling(self):
-        pass
+    # Test wait for in-flight sequence completion and block new sequence
+    def test_sequence_instance_update(self):
+        for sequence_batching_type in [
+                "direct { }\nmax_sequence_idle_microseconds: 10000000",
+                "oldest { max_candidate_sequences: 4 }\nmax_sequence_idle_microseconds: 10000000"
+        ]:
+            # Load model
+            update_instance_group("{\ncount: 2\nkind: KIND_CPU\n}")
+            update_sequence_batching(sequence_batching_type)
+            self.__triton.load_model(self.__model_name)
+            self.__check_count("initialize", 2)
+            self.__check_count("finalize", 0)
+            # Basic sequence inference
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1,
+                                sequence_start=True)
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1)
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1,
+                                sequence_end=True)
+            # Update instance
+            update_instance_group("{\ncount: 4\nkind: KIND_CPU\n}")
+            self.__triton.load_model(self.__model_name)
+            self.__check_count("initialize", 4)
+            self.__check_count("finalize", 0)
+            # Start an in-flight sequence
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1,
+                                sequence_start=True)
+            # Check that the instance update waits for in-flight sequence
+            # completion and blocks any new sequence from starting.
+            update_instance_group("{\ncount: 3\nkind: KIND_CPU\n}")
+            update_complete = [False]
+            def update():
+                self.__triton.load_model(self.__model_name)
+                update_complete[0] = True
+                self.__check_count("initialize", 4)
+                self.__check_count("finalize", 1)
+            infer_complete = [False]
+            def infer():
+                self.__triton.infer(self.__model_name,
+                                    self.__get_inputs(),
+                                    sequence_id=2,
+                                    sequence_start=True)
+                infer_complete[0] = True
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                # Update should wait until sequence 1 ends
+                update_thread = pool.submit(update)
+                time.sleep(2)  # make sure update has started
+                self.assertFalse(update_complete[0],
+                                 "Unexpected update completion")
+                # New sequence should wait until update completes
+                infer_thread = pool.submit(infer)
+                time.sleep(2)  # make sure infer has started
+                self.assertFalse(infer_complete[0],
+                                 "Unexpected infer completion")
+                # Ending sequence 1 should unblock the update
+                self.__triton.infer(self.__model_name,
+                                    self.__get_inputs(),
+                                    sequence_id=1,
+                                    sequence_end=True)
+                time.sleep(2)  # make sure update has returned
+                self.assertTrue(update_complete[0], "Update possibly stuck")
+                update_thread.result()
+                # Update completion should unblock the new sequence
+                time.sleep(2)  # make sure infer has returned
+                self.assertTrue(infer_complete[0], "Infer possibly stuck")
+                infer_thread.result()
+            # End sequence 2
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=2,
+                                sequence_end=True)
+            # Unload model
+            self.__triton.unload_model(self.__model_name)
+            self.__check_count("initialize", 4)
+            self.__check_count("finalize", 4, True)
+            self.__reset_model()


 if __name__ == "__main__":
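For reference, the first `sequence_batching_type` string in the loop above, once wrapped in a `sequence_batching` block (assuming the helper wraps it verbatim, as sketched earlier), yields model-config text along these lines:

sequence_batching {
  direct { }
  max_sequence_idle_microseconds: 10000000
}

`direct { }` selects the direct scheduler and `oldest { max_candidate_sequences: 4 }` the oldest-first scheduler, while the 10-second idle timeout keeps sequence 1 alive across the deliberate `time.sleep(2)` pauses until the test explicitly ends it.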