
Commit 4278a26

Add test for sequence model instance update
1 parent 4f487a0 · commit 4278a26

File tree (3 files changed: +123, -19 lines)

  qa/L0_model_update/instance_update_test.py
  qa/python_models/model_init_del/config.pbtxt
  qa/python_models/model_init_del/util.py

qa/L0_model_update/instance_update_test.py (+99, -16)
@@ -34,6 +34,7 @@
 from tritonclient.utils import InferenceServerException
 from models.model_init_del.util import (get_count, reset_count, set_delay,
                                         update_instance_group,
+                                        update_sequence_batching,
                                         update_model_file, enable_batching,
                                         disable_batching)
 
@@ -43,9 +44,21 @@ class TestInstanceUpdate(unittest.TestCase):
     __model_name = "model_init_del"
 
     def setUp(self):
-        # Initialize client
+        self.__reset_model()
         self.__triton = grpcclient.InferenceServerClient("localhost:8001")
 
+    def __reset_model(self):
+        # Reset counters
+        reset_count("initialize")
+        reset_count("finalize")
+        # Reset batching
+        disable_batching()
+        # Reset delays
+        set_delay("initialize", 0)
+        set_delay("infer", 0)
+        # Reset sequence batching
+        update_sequence_batching("")
+
     def __get_inputs(self, batching=False):
         self.assertIsInstance(batching, bool)
         if batching:
@@ -85,14 +98,8 @@ def __check_count(self, kind, expected_count, poll=False):
         self.assertEqual(get_count(kind), expected_count)
 
     def __load_model(self, instance_count, instance_config="", batching=False):
-        # Reset counters
-        reset_count("initialize")
-        reset_count("finalize")
         # Set batching
         enable_batching() if batching else disable_batching()
-        # Reset delays
-        set_delay("initialize", 0)
-        set_delay("infer", 0)
         # Load model
         self.__update_instance_count(instance_count,
                                      0,
@@ -143,6 +150,7 @@ def test_add_rm_add_instance(self):
         self.__update_instance_count(1, 0, batching=batching)  # add
         stop()
         self.__unload_model(batching=batching)
+        self.__reset_model()  # for next iteration
 
     # Test remove -> add -> remove an instance
     def test_rm_add_rm_instance(self):
@@ -154,6 +162,7 @@ def test_rm_add_rm_instance(self):
         self.__update_instance_count(0, 1, batching=batching)  # remove
         stop()
         self.__unload_model(batching=batching)
+        self.__reset_model()  # for next iteration
 
     # Test reduce instance count to zero
     def test_rm_instance_to_zero(self):
@@ -341,15 +350,89 @@ def infer():
         # Unload model
         self.__unload_model()
 
-    # Test for instance update on direct sequence scheduling
-    @unittest.skip("Sequence will not continue after update [FIXME: DLIS-4820]")
-    def test_instance_update_on_direct_sequence_scheduling(self):
-        pass
-
-    # Test for instance update on oldest sequence scheduling
-    @unittest.skip("Sequence will not continue after update [FIXME: DLIS-4820]")
-    def test_instance_update_on_oldest_sequence_scheduling(self):
-        pass
+    # Test wait for in-flight sequence completion and block new sequence
+    def test_sequence_instance_update(self):
+        for sequence_batching_type in [
+                "direct { }\nmax_sequence_idle_microseconds: 10000000",
+                "oldest { max_candidate_sequences: 4 }\nmax_sequence_idle_microseconds: 10000000"
+        ]:
+            # Load model
+            update_instance_group("{\ncount: 2\nkind: KIND_CPU\n}")
+            update_sequence_batching(sequence_batching_type)
+            self.__triton.load_model(self.__model_name)
+            self.__check_count("initialize", 2)
+            self.__check_count("finalize", 0)
+            # Basic sequence inference
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1,
+                                sequence_start=True)
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1)
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1,
+                                sequence_end=True)
+            # Update instance
+            update_instance_group("{\ncount: 4\nkind: KIND_CPU\n}")
+            self.__triton.load_model(self.__model_name)
+            self.__check_count("initialize", 4)
+            self.__check_count("finalize", 0)
+            # Start an in-flight sequence
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1,
+                                sequence_start=True)
+            # Check update instance will wait for in-flight sequence completion
+            # and block new sequence from starting.
+            update_instance_group("{\ncount: 3\nkind: KIND_CPU\n}")
+            update_complete = [False]
+            def update():
+                self.__triton.load_model(self.__model_name)
+                update_complete[0] = True
+                self.__check_count("initialize", 4)
+                self.__check_count("finalize", 1)
+            infer_complete = [False]
+            def infer():
+                self.__triton.infer(self.__model_name,
+                                    self.__get_inputs(),
+                                    sequence_id=2,
+                                    sequence_start=True)
+                infer_complete[0] = True
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                # Update should wait until sequence 1 end
+                update_thread = pool.submit(update)
+                time.sleep(2)  # make sure update has started
+                self.assertFalse(update_complete[0],
+                                 "Unexpected update completion")
+                # New sequence should wait until update complete
+                infer_thread = pool.submit(infer)
+                time.sleep(2)  # make sure infer has started
+                self.assertFalse(infer_complete[0],
+                                 "Unexpected infer completion")
+                # End sequence 1 should unblock update
+                self.__triton.infer(self.__model_name,
+                                    self.__get_inputs(),
+                                    sequence_id=1,
+                                    sequence_end=True)
+                time.sleep(2)  # make sure update has returned
+                self.assertTrue(update_complete[0], "Update possibly stuck")
+                update_thread.result()
+                # Update completion should unblock new sequence
+                time.sleep(2)  # make sure infer has returned
+                self.assertTrue(infer_complete[0], "Infer possibly stuck")
+                infer_thread.result()
+            # End sequence 2
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=2,
+                                sequence_end=True)
+            # Unload model
+            self.__triton.unload_model(self.__model_name)
+            self.__check_count("initialize", 4)
+            self.__check_count("finalize", 4, True)
+            self.__reset_model()
 
 
 if __name__ == "__main__":
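
Aside: the concurrency checks in the new test follow a simple pattern: submit a blocking call to a thread pool, share a one-element list as a completion flag, sleep briefly, and assert on the flag to distinguish "still blocked" from "returned". A minimal, Triton-independent sketch of that pattern (the slow_call helper and the 4/2 second timings are illustrative, not taken from the test):

import concurrent.futures
import time

def demo_blocking_check():
    completed = [False]  # one-element list so the worker can mutate shared state

    def slow_call():
        time.sleep(4)        # stands in for a call that blocks, e.g. a model update
        completed[0] = True  # set only after the blocking call returns

    with concurrent.futures.ThreadPoolExecutor() as pool:
        future = pool.submit(slow_call)
        time.sleep(2)            # give the worker time to start
        assert not completed[0]  # the call is still blocked at this point
        future.result()          # wait for completion
        assert completed[0]      # the flag is now set

demo_blocking_check()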

qa/python_models/model_init_del/config.pbtxt (+1, -1)
@@ -49,4 +49,4 @@ instance_group [
     count: 1
     kind: KIND_CPU
   }
-]
+] # end instance_group
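
The trailing comment gives the closing bracket a unique marker that the test utilities below can split on. As a rough illustration (assembled from the strings used by the test, not copied from a generated file), after the test calls update_instance_group("{\ncount: 2\nkind: KIND_CPU\n}") and update_sequence_batching("direct { }\nmax_sequence_idle_microseconds: 10000000"), the rewritten blocks of config.pbtxt would look like:

instance_group [
{
count: 2
kind: KIND_CPU
}
] # end instance_group

sequence_batching {
direct { }
max_sequence_idle_microseconds: 10000000
} # end sequence_batching

The sequence_batching block is appended to the end of the file the first time it is written, rewritten in place on later calls, and removed again when update_sequence_batching("") is called (as __reset_model() does).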

qa/python_models/model_init_del/util.py (+23, -2)
@@ -113,10 +113,31 @@ def update_instance_group(instance_group_str):
     full_path = os.path.join(os.path.dirname(__file__), "config.pbtxt")
     with open(full_path, mode="r+", encoding="utf-8", errors="strict") as f:
         txt = f.read()
-        txt = txt.split("instance_group [")[0]
+        txt, post_match = txt.split("instance_group [")
         txt += "instance_group [\n"
         txt += instance_group_str
-        txt += "\n]\n"
+        txt += "\n] # end instance_group\n"
+        txt += post_match.split("\n] # end instance_group\n")[1]
+        f.truncate(0)
+        f.seek(0)
+        f.write(txt)
+    return txt
+
+def update_sequence_batching(sequence_batching_str):
+    full_path = os.path.join(os.path.dirname(__file__), "config.pbtxt")
+    with open(full_path, mode="r+", encoding="utf-8", errors="strict") as f:
+        txt = f.read()
+        if "sequence_batching {" in txt:
+            txt, post_match = txt.split("sequence_batching {")
+            if sequence_batching_str != "":
+                txt += "sequence_batching {\n"
+                txt += sequence_batching_str
+                txt += "\n} # end sequence_batching\n"
+            txt += post_match.split("\n} # end sequence_batching\n")[1]
+        elif sequence_batching_str != "":
+            txt += "\nsequence_batching {\n"
+            txt += sequence_batching_str
+            txt += "\n} # end sequence_batching\n"
         f.truncate(0)
         f.seek(0)
         f.write(txt)