Commit 2c4d8f8

Add test for sequence model instance update
1 parent 565f306 commit 2c4d8f8

3 files changed, +123 -19 lines changed

qa/L0_model_update/instance_update_test.py (+99 -16)

@@ -35,6 +35,7 @@
 from tritonclient.utils import InferenceServerException
 from models.model_init_del.util import (get_count, reset_count, set_delay,
                                         update_instance_group,
+                                        update_sequence_batching,
                                         update_model_file, enable_batching,
                                         disable_batching)

@@ -44,9 +45,21 @@ class TestInstanceUpdate(unittest.TestCase):
     __model_name = "model_init_del"

     def setUp(self):
-        # Initialize client
+        self.__reset_model()
         self.__triton = grpcclient.InferenceServerClient("localhost:8001")

+    def __reset_model(self):
+        # Reset counters
+        reset_count("initialize")
+        reset_count("finalize")
+        # Reset batching
+        disable_batching()
+        # Reset delays
+        set_delay("initialize", 0)
+        set_delay("infer", 0)
+        # Reset sequence batching
+        update_sequence_batching("")
+
     def __get_inputs(self, batching=False):
         self.assertIsInstance(batching, bool)
         if batching:
@@ -90,14 +103,8 @@ def __check_count(self, kind, expected_count, poll=False):
         self.assertEqual(get_count(kind), expected_count)

     def __load_model(self, instance_count, instance_config="", batching=False):
-        # Reset counters
-        reset_count("initialize")
-        reset_count("finalize")
         # Set batching
         enable_batching() if batching else disable_batching()
-        # Reset delays
-        set_delay("initialize", 0)
-        set_delay("infer", 0)
         # Load model
         self.__update_instance_count(instance_count,
                                      0,
@@ -148,6 +155,7 @@ def test_add_rm_add_instance(self):
             self.__update_instance_count(1, 0, batching=batching)  # add
             stop()
             self.__unload_model(batching=batching)
+            self.__reset_model()  # for next iteration

     # Test remove -> add -> remove an instance
     def test_rm_add_rm_instance(self):
@@ -159,6 +167,7 @@ def test_rm_add_rm_instance(self):
             self.__update_instance_count(0, 1, batching=batching)  # remove
             stop()
             self.__unload_model(batching=batching)
+            self.__reset_model()  # for next iteration

     # Test reduce instance count to zero
     def test_rm_instance_to_zero(self):
@@ -457,15 +466,89 @@ def test_instance_resource_decrease(self):
             # explicit limit of 10 is set.
             self.assertNotIn("Resource: R1\t Count: 3", f.read())

-    # Test for instance update on direct sequence scheduling
-    @unittest.skip("Sequence will not continue after update [FIXME: DLIS-4820]")
-    def test_instance_update_on_direct_sequence_scheduling(self):
-        pass
-
-    # Test for instance update on oldest sequence scheduling
-    @unittest.skip("Sequence will not continue after update [FIXME: DLIS-4820]")
-    def test_instance_update_on_oldest_sequence_scheduling(self):
-        pass
+    # Test wait for in-flight sequence completion and block new sequence
+    def test_sequence_instance_update(self):
+        for sequence_batching_type in [
+                "direct { }\nmax_sequence_idle_microseconds: 10000000",
+                "oldest { max_candidate_sequences: 4 }\nmax_sequence_idle_microseconds: 10000000"
+        ]:
+            # Load model
+            update_instance_group("{\ncount: 2\nkind: KIND_CPU\n}")
+            update_sequence_batching(sequence_batching_type)
+            self.__triton.load_model(self.__model_name)
+            self.__check_count("initialize", 2)
+            self.__check_count("finalize", 0)
+            # Basic sequence inference
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1,
+                                sequence_start=True)
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1)
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1,
+                                sequence_end=True)
+            # Update instance
+            update_instance_group("{\ncount: 4\nkind: KIND_CPU\n}")
+            self.__triton.load_model(self.__model_name)
+            self.__check_count("initialize", 4)
+            self.__check_count("finalize", 0)
+            # Start an in-flight sequence
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1,
+                                sequence_start=True)
+            # Check update instance will wait for in-flight sequence completion
+            # and block new sequence from starting.
+            update_instance_group("{\ncount: 3\nkind: KIND_CPU\n}")
+            update_complete = [False]
+            def update():
+                self.__triton.load_model(self.__model_name)
+                update_complete[0] = True
+                self.__check_count("initialize", 4)
+                self.__check_count("finalize", 1)
+            infer_complete = [False]
+            def infer():
+                self.__triton.infer(self.__model_name,
+                                    self.__get_inputs(),
+                                    sequence_id=2,
+                                    sequence_start=True)
+                infer_complete[0] = True
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                # Update should wait until sequence 1 end
+                update_thread = pool.submit(update)
+                time.sleep(2)  # make sure update has started
+                self.assertFalse(update_complete[0],
+                                 "Unexpected update completion")
+                # New sequence should wait until update complete
+                infer_thread = pool.submit(infer)
+                time.sleep(2)  # make sure infer has started
+                self.assertFalse(infer_complete[0],
+                                 "Unexpected infer completion")
+                # End sequence 1 should unblock update
+                self.__triton.infer(self.__model_name,
+                                    self.__get_inputs(),
+                                    sequence_id=1,
+                                    sequence_end=True)
+                time.sleep(2)  # make sure update has returned
+                self.assertTrue(update_complete[0], "Update possibly stuck")
+                update_thread.result()
+                # Update completion should unblock new sequence
+                time.sleep(2)  # make sure infer has returned
+                self.assertTrue(infer_complete[0], "Infer possibly stuck")
+                infer_thread.result()
+            # End sequence 2
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=2,
+                                sequence_end=True)
+            # Unload model
+            self.__triton.unload_model(self.__model_name)
+            self.__check_count("initialize", 4)
+            self.__check_count("finalize", 4, True)
+            self.__reset_model()


 if __name__ == "__main__":
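
The new test coordinates the blocking load_model() call and the blocked new sequence with a simple flag-and-sleep pattern: each worker flips a one-element list after its blocking call returns, and the main thread sleeps a fixed two seconds before asserting the flag is still unset (concurrent.futures and time are assumed to already be imported at the top of the test file, since the diff does not add them). A minimal, Triton-free sketch of that pattern, with purely illustrative names:

import concurrent.futures
import time


def demo_blocking_check():
    done = [False]  # mutable flag shared with the worker thread

    def worker():
        time.sleep(4)  # stands in for a blocking load_model()/infer() call
        done[0] = True

    with concurrent.futures.ThreadPoolExecutor() as pool:
        future = pool.submit(worker)
        time.sleep(2)  # give the worker time to start, as the test does
        assert not done[0], "worker returned earlier than expected"
        future.result()  # join the worker; re-raises any exception it hit
        assert done[0]


if __name__ == "__main__":
    demo_blocking_check()

Fixed sleeps make the check timing-dependent rather than strictly synchronized; the test trades precise ordering guarantees for simplicity.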

qa/python_models/model_init_del/config.pbtxt (+1 -1)

@@ -49,4 +49,4 @@ instance_group [
     count: 1
     kind: KIND_CPU
   }
-]
+] # end instance_group
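
The trailing "# end instance_group" comment is not cosmetic: update_instance_group() in util.py now splits the config on that exact marker string to preserve whatever follows the block, and update_sequence_batching() does the same with "# end sequence_batching". As a rough illustration only (assuming instance_group is the last block of the checked-in config, which is where a newly added sequence_batching block gets appended), the relevant region of config.pbtxt after the test calls both helpers would read:

instance_group [
{
count: 2
kind: KIND_CPU
}
] # end instance_group

sequence_batching {
direct { }
max_sequence_idle_microseconds: 10000000
} # end sequence_batching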

qa/python_models/model_init_del/util.py (+23 -2)

@@ -127,10 +127,31 @@ def update_instance_group(instance_group_str):
     full_path = os.path.join(os.path.dirname(__file__), "config.pbtxt")
     with open(full_path, mode="r+", encoding="utf-8", errors="strict") as f:
         txt = f.read()
-        txt = txt.split("instance_group [")[0]
+        txt, post_match = txt.split("instance_group [")
         txt += "instance_group [\n"
         txt += instance_group_str
-        txt += "\n]\n"
+        txt += "\n] # end instance_group\n"
+        txt += post_match.split("\n] # end instance_group\n")[1]
+        f.truncate(0)
+        f.seek(0)
+        f.write(txt)
+        return txt
+
+def update_sequence_batching(sequence_batching_str):
+    full_path = os.path.join(os.path.dirname(__file__), "config.pbtxt")
+    with open(full_path, mode="r+", encoding="utf-8", errors="strict") as f:
+        txt = f.read()
+        if "sequence_batching {" in txt:
+            txt, post_match = txt.split("sequence_batching {")
+            if sequence_batching_str != "":
+                txt += "sequence_batching {\n"
+                txt += sequence_batching_str
+                txt += "\n} # end sequence_batching\n"
+            txt += post_match.split("\n} # end sequence_batching\n")[1]
+        elif sequence_batching_str != "":
+            txt += "\nsequence_batching {\n"
+            txt += sequence_batching_str
+            txt += "\n} # end sequence_batching\n"
         f.truncate(0)
         f.seek(0)
         f.write(txt)