 
 import asyncio
 import queue
+import time
 import unittest
 from functools import partial
-import time
 
 import numpy as np
 import test_util as tu
@@ -54,6 +54,7 @@ def callback(user_data, result, error):
     else:
         user_data._completed_requests.put(result)
 
+
 class ClientCancellationTest(tu.TestResultCollector):
     def setUp(self):
         self.model_name_ = "custom_identity_int32"
@@ -69,13 +70,13 @@ def _record_end_time_ms(self):
 
     def _test_runtime_duration(self, upper_limit):
         self.assertTrue(
-            (self._end_time_ms - self._start_time_ms) < upper_limit,
-            "test runtime expected less than "
-            + str(upper_limit)
-            + "ms response time, got "
-            + str(self._end_time_ms - self._start_time_ms)
-            + " ms",
-        )
+            (self._end_time_ms - self._start_time_ms) < upper_limit,
+            "test runtime expected less than "
+            + str(upper_limit)
+            + "ms response time, got "
+            + str(self._end_time_ms - self._start_time_ms)
+            + " ms",
+        )
 
     def _prepare_request(self):
         self.inputs_ = []
@@ -85,7 +86,6 @@ def _prepare_request(self):
 
         self.inputs_[0].set_data_from_numpy(self.input0_data_)
 
-
     def test_grpc_async_infer(self):
         # Sends a request using async_infer to a
         # model that takes 10s to execute. Issues
@@ -115,13 +115,13 @@ def test_grpc_async_infer(self):
         # Wait until the result is captured via callback
         data_item = user_data._completed_requests.get()
         self.assertEqual(type(data_item), grpcclient.CancelledError)
-
+
         self._record_end_time_ms()
         self._test_runtime_duration(5000)
 
     def test_grpc_stream_infer(self):
         # Sends a request using async_stream_infer to a
-        # model that takes 10s to execute. Issues stream
+        # model that takes 10s to execute. Issues stream
         # closure with cancel_requests=True. The client
         # should return with appropriate exception within
         # 5s.
@@ -134,9 +134,7 @@ def test_grpc_stream_infer(self):
 
         # The model is configured to take three seconds to send the
         # response. Expect an exception for small timeout values.
-        triton_client.start_stream(
-            callback=partial(callback, user_data)
-        )
+        triton_client.start_stream(callback=partial(callback, user_data))
         self._record_start_time_ms()
         for i in range(1):
             triton_client.async_stream_infer(
@@ -148,11 +146,10 @@ def test_grpc_stream_infer(self):
 
         data_item = user_data._completed_requests.get()
         self.assertEqual(type(data_item), grpcclient.CancelledError)
-
+
         self._record_end_time_ms()
         self._test_runtime_duration(5000)
 
-
     def test_aio_grpc_async_infer(self):
         # Sends a request using infer of grpc.aio to a
         # model that takes 10s to execute. Issues
@@ -187,7 +184,6 @@ async def test_aio_infer(self):
             self._record_end_time_ms()
             self._test_runtime_duration(5000)
 
-
         asyncio.run(test_aio_infer(self))
 
     def test_aio_grpc_stream_infer(self):
@@ -198,17 +194,23 @@ def test_aio_grpc_stream_infer(self):
         # 5s.
         async def test_aio_streaming_infer(self):
             async with aiogrpcclient.InferenceServerClient(
-                url="localhost:8001", verbose=True) as triton_client:
+                url="localhost:8001", verbose=True
+            ) as triton_client:
+
                 async def async_request_iterator():
                     for i in range(1):
                         await asyncio.sleep(1)
-                        yield {"model_name": self.model_name_,
+                        yield {
+                            "model_name": self.model_name_,
                             "inputs": self.inputs_,
-                            "outputs": self.outputs_}
+                            "outputs": self.outputs_,
+                        }
 
                 self._prepare_request()
                 self._record_start_time_ms()
-                response_iterator = triton_client.stream_infer(inputs_iterator=async_request_iterator(), get_call_obj=True)
+                response_iterator = triton_client.stream_infer(
+                    inputs_iterator=async_request_iterator(), get_call_obj=True
+                )
                 streaming_call = await response_iterator.__anext__()
 
                 async def cancel_streaming(streaming_call):
@@ -228,5 +230,6 @@ async def handle_response(response_iterator):
 
         asyncio.run(test_aio_streaming_infer(self))
 
+
 if __name__ == "__main__":
     unittest.main()
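
For reference, the cancellation flow these formatting-only hunks touch can be sketched outside the test harness. This is a minimal sketch, not part of the commit: it assumes a tritonclient.grpc build with request-cancellation support (where async_infer returns a handle exposing cancel(), and stop_stream accepts cancel_requests=True, as the test comments above describe); the INPUT0 tensor name and [1, 1] shape for the custom_identity_int32 model are likewise assumptions.

import queue
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient


class UserData:
    def __init__(self):
        # Callback results (or errors) are funneled through this queue.
        self._completed_requests = queue.Queue()


def callback(user_data, result, error):
    # On cancellation the client delivers an error, not a result.
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)


user_data = UserData()
triton_client = grpcclient.InferenceServerClient(url="localhost:8001")

# INPUT0 / [1, 1] are assumed details of the custom_identity_int32 model.
inputs = [grpcclient.InferInput("INPUT0", [1, 1], "INT32")]
inputs[0].set_data_from_numpy(np.array([[10]], dtype=np.int32))

# Async path: async_infer returns a handle; cancel() issues gRPC-level
# cancellation while the slow model is still executing.
future = triton_client.async_infer(
    model_name="custom_identity_int32",
    inputs=inputs,
    callback=partial(callback, user_data),
)
future.cancel()

# The callback should then receive a CancelledError instead of a result.
data_item = user_data._completed_requests.get()
assert type(data_item) == grpcclient.CancelledError

# Streaming path: closing the stream with cancel_requests=True cancels
# any requests still in flight on that stream.
triton_client.start_stream(callback=partial(callback, user_data))
triton_client.async_stream_infer(model_name="custom_identity_int32", inputs=inputs)
triton_client.stop_stream(cancel_requests=True)

The aio variants in the hunks above follow the same shape: the plain infer call is wrapped in an asyncio task and cancelled via the task, while stream_infer(..., get_call_obj=True) yields the underlying streaming call object first so that cancel_streaming can cancel the stream directly.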