Skip to content

Commit 788b19d

Browse files
tanmayv25 and rmccorm4 authored
Add tests for gRPC client-side cancellation (#6278)
Add tests for gRPC client-side cancellation (#6278)

* Add tests for gRPC client-side cancellation
* Fix CodeQL issues
* Formatting
* Update qa/L0_client_cancellation/client_cancellation_test.py (Co-authored-by: Ryan McCormick <[email protected]>)
* Move to L0_request_cancellation
* Address review comments
* Removing request cancellation support from asyncio version
* Format
* Update copyright
* Remove tests
* Handle cancellation notification in gRPC server (#6298)
  * Handle cancellation notification in gRPC server
  * Fix the request ptr initialization
  * Update src/grpc/infer_handler.h (Co-authored-by: Ryan McCormick <[email protected]>)
  * Address review comment
  * Fix logs
  * Fix request complete callback by removing reference to state
  * Improve documentation

Co-authored-by: Ryan McCormick <[email protected]>
1 parent 6748dc4 commit 788b19d

File tree

6 files changed

+623
-17
lines changed

6 files changed

+623
-17
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+
#
5+
# Redistribution and use in source and binary forms, with or without
6+
# modification, are permitted provided that the following conditions
7+
# are met:
8+
# * Redistributions of source code must retain the above copyright
9+
# notice, this list of conditions and the following disclaimer.
10+
# * Redistributions in binary form must reproduce the above copyright
11+
# notice, this list of conditions and the following disclaimer in the
12+
# documentation and/or other materials provided with the distribution.
13+
# * Neither the name of NVIDIA CORPORATION nor the names of its
14+
# contributors may be used to endorse or promote products derived
15+
# from this software without specific prior written permission.
16+
#
17+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
18+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
20+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
21+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
22+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
23+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
25+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
29+
import sys
30+
31+
sys.path.append("../common")
32+
33+
import asyncio
34+
import queue
35+
import time
36+
import unittest
37+
from functools import partial
38+
39+
import numpy as np
40+
import test_util as tu
41+
import tritonclient.grpc as grpcclient
42+
import tritonclient.grpc.aio as aiogrpcclient
43+
from tritonclient.utils import InferenceServerException
44+
45+
46+
class UserData:
    """Container for results and errors delivered by gRPC callbacks."""

    def __init__(self):
        # Thread-safe FIFO: the gRPC callback thread puts items, the
        # test thread blocks on get().
        self._completed_requests = queue.Queue()
51+
def callback(user_data, result, error):
    """Deliver a completed request into *user_data*'s queue.

    Queues the error when one is present (truthy), otherwise the result.
    """
    item = error if error else result
    user_data._completed_requests.put(item)
58+
class ClientCancellationTest(tu.TestResultCollector):
    """Tests that gRPC client-side cancellation interrupts in-flight requests.

    Each test targets ``custom_identity_int32``, a model configured with a
    10-second execution delay, cancels after ~2 seconds, and asserts the
    client returns a cancellation error well before the model would have
    finished (under 5 seconds).
    """

    def setUp(self):
        # Identity model configured with execute_delay_ms=10000 (see the
        # accompanying config.pbtxt).
        self.model_name_ = "custom_identity_int32"
        self.input0_data_ = np.array([[10]], dtype=np.int32)
        # Wall-clock bounds (milliseconds) checked by _test_runtime_duration().
        self._start_time_ms = 0
        self._end_time_ms = 0

    def _record_start_time_ms(self):
        self._start_time_ms = int(round(time.time() * 1000))

    def _record_end_time_ms(self):
        self._end_time_ms = int(round(time.time() * 1000))

    def _test_runtime_duration(self, upper_limit):
        # Assert the elapsed wall-clock time stayed under upper_limit ms,
        # i.e. cancellation returned long before the model's 10s delay.
        elapsed_ms = self._end_time_ms - self._start_time_ms
        self.assertTrue(
            elapsed_ms < upper_limit,
            "test runtime expected less than "
            + str(upper_limit)
            + "ms response time, got "
            + str(elapsed_ms)
            + " ms",
        )

    def _prepare_request(self):
        # Build a single INT32 [1, 1] input and request OUTPUT0 back.
        self.inputs_ = []
        self.inputs_.append(grpcclient.InferInput("INPUT0", [1, 1], "INT32"))
        self.outputs_ = []
        self.outputs_.append(grpcclient.InferRequestedOutput("OUTPUT0"))

        self.inputs_[0].set_data_from_numpy(self.input0_data_)

    def test_grpc_async_infer(self):
        # Sends a request using async_infer to a
        # model that takes 10s to execute. Issues
        # a cancellation request after 2s. The client
        # should return with appropriate exception within
        # 5s.
        triton_client = grpcclient.InferenceServerClient(
            url="localhost:8001", verbose=True
        )
        self._prepare_request()

        user_data = UserData()

        self._record_start_time_ms()

        with self.assertRaises(InferenceServerException) as cm:
            future = triton_client.async_infer(
                model_name=self.model_name_,
                inputs=self.inputs_,
                callback=partial(callback, user_data),
                outputs=self.outputs_,
            )
            time.sleep(2)
            future.cancel()

            # The callback delivers either a result or the cancellation
            # error; re-raise the error so assertRaises can observe it.
            data_item = user_data._completed_requests.get()
            if isinstance(data_item, InferenceServerException):
                raise data_item
        self.assertIn("Locally cancelled by application!", str(cm.exception))

        self._record_end_time_ms()
        self._test_runtime_duration(5000)

    def test_grpc_stream_infer(self):
        # Sends a request using async_stream_infer to a
        # model that takes 10s to execute. Issues stream
        # closure with cancel_requests=True. The client
        # should return with appropriate exception within
        # 5s.
        triton_client = grpcclient.InferenceServerClient(
            url="localhost:8001", verbose=True
        )

        self._prepare_request()
        user_data = UserData()

        triton_client.start_stream(callback=partial(callback, user_data))
        self._record_start_time_ms()

        with self.assertRaises(InferenceServerException) as cm:
            # Loop index is unused; kept as a loop so the request count is
            # easy to scale up.
            for _ in range(1):
                triton_client.async_stream_infer(
                    model_name=self.model_name_,
                    inputs=self.inputs_,
                    outputs=self.outputs_,
                )
            time.sleep(2)
            triton_client.stop_stream(cancel_requests=True)
            data_item = user_data._completed_requests.get()
            if isinstance(data_item, InferenceServerException):
                raise data_item
        self.assertIn("Locally cancelled by application!", str(cm.exception))

        self._record_end_time_ms()
        self._test_runtime_duration(5000)

    # Disabling AsyncIO cancellation testing. Enable once
    # DLIS-5476 is implemented.
    # def test_aio_grpc_async_infer(self):
    #     # Sends a request using infer of grpc.aio to a
    #     # model that takes 10s to execute. Issues
    #     # a cancellation request after 2s. The client
    #     # should return with appropriate exception within
    #     # 5s.
    #     async def cancel_request(call):
    #         await asyncio.sleep(2)
    #         self.assertTrue(call.cancel())
    #
    #     async def handle_response(generator):
    #         with self.assertRaises(asyncio.exceptions.CancelledError) as cm:
    #             _ = await anext(generator)
    #
    #     async def test_aio_infer(self):
    #         triton_client = aiogrpcclient.InferenceServerClient(
    #             url="localhost:8001", verbose=True
    #         )
    #         self._prepare_request()
    #         self._record_start_time_ms()
    #
    #         generator = triton_client.infer(
    #             model_name=self.model_name_,
    #             inputs=self.inputs_,
    #             outputs=self.outputs_,
    #             get_call_obj=True,
    #         )
    #         grpc_call = await anext(generator)
    #
    #         tasks = []
    #         tasks.append(asyncio.create_task(handle_response(generator)))
    #         tasks.append(asyncio.create_task(cancel_request(grpc_call)))
    #
    #         for task in tasks:
    #             await task
    #
    #         self._record_end_time_ms()
    #         self._test_runtime_duration(5000)
    #
    #     asyncio.run(test_aio_infer(self))
    #
    # def test_aio_grpc_stream_infer(self):
    #     # Sends a request using stream_infer of grpc.aio
    #     # library model that takes 10s to execute. Issues
    #     # stream closure with cancel_requests=True. The client
    #     # should return with appropriate exception within
    #     # 5s.
    #     async def test_aio_streaming_infer(self):
    #         async with aiogrpcclient.InferenceServerClient(
    #             url="localhost:8001", verbose=True
    #         ) as triton_client:
    #
    #             async def async_request_iterator():
    #                 for i in range(1):
    #                     await asyncio.sleep(1)
    #                     yield {
    #                         "model_name": self.model_name_,
    #                         "inputs": self.inputs_,
    #                         "outputs": self.outputs_,
    #                     }
    #
    #             self._prepare_request()
    #             self._record_start_time_ms()
    #             response_iterator = triton_client.stream_infer(
    #                 inputs_iterator=async_request_iterator(), get_call_obj=True
    #             )
    #             streaming_call = await anext(response_iterator)
    #
    #             async def cancel_streaming(streaming_call):
    #                 await asyncio.sleep(2)
    #                 streaming_call.cancel()
    #
    #             async def handle_response(response_iterator):
    #                 with self.assertRaises(asyncio.exceptions.CancelledError) as cm:
    #                     async for response in response_iterator:
    #                         self.assertTrue(False, "Received an unexpected response!")
    #
    #             tasks = []
    #             tasks.append(asyncio.create_task(handle_response(response_iterator)))
    #             tasks.append(asyncio.create_task(cancel_streaming(streaming_call)))
    #
    #             for task in tasks:
    #                 await task
    #
    #             self._record_end_time_ms()
    #             self._test_runtime_duration(5000)
    #
    #     asyncio.run(test_aio_streaming_infer(self))
246+
247+
248+
# Run the cancellation tests when invoked directly (test.sh selects
# individual tests by name, e.g. ClientCancellationTest.test_grpc_async_infer).
if __name__ == "__main__":
    unittest.main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
name: "custom_identity_int32"
28+
backend: "identity"
29+
max_batch_size: 1024
30+
version_policy: { latest { num_versions: 1 }}
31+
instance_group [ { kind: KIND_CPU } ]
32+
33+
input [
34+
{
35+
name: "INPUT0"
36+
data_type: TYPE_INT32
37+
dims: [ -1 ]
38+
39+
}
40+
]
41+
output [
42+
{
43+
name: "OUTPUT0"
44+
data_type: TYPE_INT32
45+
dims: [ -1 ]
46+
}
47+
]
48+
49+
parameters [
50+
{
51+
key: "execute_delay_ms"
52+
value: { string_value: "10000" }
53+
}
54+
]

qa/L0_request_cancellation/test.sh

+50
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,23 @@ DATADIR=${DATADIR:="/data/inferenceserver/${REPO_VERSION}"}
4242
RET=0
4343

4444
mkdir -p models/model/1
45+
mkdir -p $DATADIR/custom_identity_int32/1
46+
47+
export CUDA_VISIBLE_DEVICES=0
48+
49+
RET=0
50+
51+
CLIENT_CANCELLATION_TEST=client_cancellation_test.py
52+
TEST_RESULT_FILE='test_results.txt'
53+
54+
rm -f *.log
55+
rm -f *.log.*
56+
57+
CLIENT_LOG=`pwd`/client.log
58+
DATADIR=`pwd`/models
59+
SERVER=/opt/tritonserver/bin/tritonserver
60+
SERVER_ARGS="--model-repository=$DATADIR --log-verbose=1"
61+
source ../common/util.sh
4562

4663
SERVER_LOG=server.log
4764
LD_LIBRARY_PATH=/opt/tritonserver/lib:$LD_LIBRARY_PATH ./request_cancellation_test > $SERVER_LOG
@@ -50,6 +67,39 @@ if [ $? -ne 0 ]; then
5067
RET=1
5168
fi
5269

70+
# gRPC client-side cancellation tests...
71+
for i in test_grpc_async_infer \
72+
test_grpc_stream_infer \
73+
; do
74+
75+
SERVER_LOG=${i}.server.log
76+
run_server
77+
if [ "$SERVER_PID" == "0" ]; then
78+
echo -e "\n***\n*** Failed to start $SERVER\n***"
79+
cat $SERVER_LOG
80+
exit 1
81+
fi
82+
83+
set +e
84+
python $CLIENT_CANCELLATION_TEST ClientCancellationTest.$i >>$CLIENT_LOG 2>&1
85+
if [ $? -ne 0 ]; then
86+
echo -e "\n***\n*** Test $i Failed\n***" >>$CLIENT_LOG
87+
echo -e "\n***\n*** Test $i Failed\n***"
88+
RET=1
89+
else
90+
check_test_results $TEST_RESULT_FILE 1
91+
if [ $? -ne 0 ]; then
92+
cat $CLIENT_LOG
93+
echo -e "\n***\n*** Test Result Verification Failed\n***"
94+
RET=1
95+
fi
96+
fi
97+
98+
set -e
99+
kill $SERVER_PID
100+
wait $SERVER_PID
101+
done
102+
53103
if [ $RET -eq 0 ]; then
54104
echo -e "\n***\n*** Test Passed\n***"
55105
else

0 commit comments

Comments (0)