@@ -9,7 +9,8 @@
 
 
 @pytest.mark.asyncio
-async def test_batch_queue_timeout():
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_batch_queue_timeout(flush_all):
     async def foo(docs, **kwargs):
         await asyncio.sleep(0.1)
         return DocumentArray([Document(text='Done') for _ in docs])
@@ -20,6 +21,7 @@ async def foo(docs, **kwargs):
         response_docarray_cls=DocumentArray,
         preferred_batch_size=4,
         timeout=2000,
+        flush_all=flush_all,
     )
 
     three_data_requests = [DataRequest() for _ in range(3)]
@@ -59,7 +61,8 @@ async def process_request(req):
 
 
 @pytest.mark.asyncio
-async def test_batch_queue_timeout_does_not_wait_previous_batch():
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_batch_queue_timeout_does_not_wait_previous_batch(flush_all):
     batches_lengths_computed = []
 
     async def foo(docs, **kwargs):
@@ -73,6 +76,7 @@ async def foo(docs, **kwargs):
         response_docarray_cls=DocumentArray,
         preferred_batch_size=5,
         timeout=3000,
+        flush_all=flush_all
     )
 
     data_requests = [DataRequest() for _ in range(3)]
@@ -105,7 +109,8 @@ async def process_request(req, sleep=0):
 
 
 @pytest.mark.asyncio
-async def test_batch_queue_req_length_larger_than_preferred():
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_batch_queue_req_length_larger_than_preferred(flush_all):
     async def foo(docs, **kwargs):
         await asyncio.sleep(0.1)
         return DocumentArray([Document(text='Done') for _ in docs])
@@ -116,6 +121,7 @@ async def foo(docs, **kwargs):
         response_docarray_cls=DocumentArray,
         preferred_batch_size=4,
         timeout=2000,
+        flush_all=flush_all,
     )
 
     data_requests = [DataRequest() for _ in range(3)]
@@ -142,7 +148,8 @@ async def process_request(req):
 
 
 @pytest.mark.asyncio
-async def test_exception():
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_exception(flush_all):
     BAD_REQUEST_IDX = [2, 6]
 
     async def foo(docs, **kwargs):
@@ -159,6 +166,7 @@ async def foo(docs, **kwargs):
         response_docarray_cls=DocumentArray,
         preferred_batch_size=1,
         timeout=500,
+        flush_all=flush_all,
     )
 
     data_requests = [DataRequest() for _ in range(35)]
@@ -188,7 +196,8 @@ async def process_request(req):
 
 
 @pytest.mark.asyncio
-async def test_exception_more_complex():
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_exception_more_complex(flush_all):
     TRIGGER_BAD_REQUEST_IDX = [2, 6]
     EXPECTED_BAD_REQUESTS = [2, 3, 6, 7]
 
@@ -208,6 +217,7 @@ async def foo(docs, **kwargs):
         request_docarray_cls=DocumentArray,
         response_docarray_cls=DocumentArray,
         preferred_batch_size=2,
+        flush_all=flush_all,
         timeout=500,
     )
 
@@ -240,7 +250,8 @@ async def process_request(req):
 
 
 @pytest.mark.asyncio
-async def test_exception_all():
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_exception_all(flush_all):
     async def foo(docs, **kwargs):
         raise AssertionError
 
@@ -249,6 +260,7 @@ async def foo(docs, **kwargs):
         request_docarray_cls=DocumentArray,
         response_docarray_cls=DocumentArray,
         preferred_batch_size=2,
+        flush_all=flush_all,
         timeout=500,
     )
 
@@ -287,8 +299,9 @@ async def foo(docs, **kwargs):
 @pytest.mark.parametrize('num_requests', [61, 127, 100])
 @pytest.mark.parametrize('preferred_batch_size', [7, 27, 61, 73, 100])
 @pytest.mark.parametrize('timeout', [0.3, 500])
+@pytest.mark.parametrize('flush_all', [False, True])
 @pytest.mark.asyncio
-async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout):
+async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout, flush_all):
     import random
 
     async def foo(docs, **kwargs):
@@ -301,6 +314,7 @@ async def foo(docs, **kwargs):
         request_docarray_cls=DocumentArray,
         response_docarray_cls=DocumentArray,
         preferred_batch_size=preferred_batch_size,
+        flush_all=flush_all,
         timeout=timeout,
     )
 
@@ -331,3 +345,58 @@ async def process_request(req):
         assert len(resp.docs) == length
         for j, d in enumerate(resp.docs):
             assert d.text == f'Text {j} from request {i} with len {length} Processed'
+
+
+@pytest.mark.parametrize('num_requests', [61, 127, 100])
+@pytest.mark.parametrize('preferred_batch_size', [7, 27, 61, 73, 100])
+@pytest.mark.parametrize('timeout', [0.3, 500])
+@pytest.mark.parametrize('flush_all', [False, True])
+@pytest.mark.asyncio
+async def test_length_processed_in_func(num_requests, preferred_batch_size, timeout, flush_all):
+    import random
+
+    async def foo(docs, **kwargs):
+        if not flush_all:
+            assert len(docs) <= preferred_batch_size
+        else:
+            assert len(docs) >= preferred_batch_size
+        await asyncio.sleep(0.1)
+        for doc in docs:
+            doc.text += ' Processed'
+
+    bq: BatchQueue = BatchQueue(
+        foo,
+        request_docarray_cls=DocumentArray,
+        response_docarray_cls=DocumentArray,
+        preferred_batch_size=preferred_batch_size,
+        flush_all=flush_all,
+        timeout=timeout,
+    )
+
+    data_requests = [DataRequest() for _ in range(num_requests)]
+    len_requests = []
+    for i, req in enumerate(data_requests):
+        len_request = random.randint(preferred_batch_size, preferred_batch_size * 10)
+        len_requests.append(len_request)
+        req.data.docs = DocumentArray(
+            [
+                Document(text=f'Text {j} from request {i} with len {len_request}')
+                for j in range(len_request)
+            ]
+        )
+
+    async def process_request(req):
+        q = await bq.push(req)
+        item = await q.get()
+        q.task_done()
+        return item
+
+    tasks = [asyncio.create_task(process_request(req)) for req in data_requests]
+    items = await asyncio.gather(*tasks)
+    for i, item in enumerate(items):
+        assert item is None
+
+    for i, (resp, length) in enumerate(zip(data_requests, len_requests)):
+        assert len(resp.docs) == length
+        for j, d in enumerate(resp.docs):
+            assert d.text == f'Text {j} from request {i} with len {length} Processed'
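
The new test_length_processed_in_func pins down the behavioral contract of flush_all: with flush_all=False a flush hands the wrapped function at most preferred_batch_size docs, while with flush_all=True a flush drains everything queued at that moment, so batches are at least that size (given that every request carries at least preferred_batch_size docs). Below is a minimal standalone sketch of the same invariant outside pytest; the BatchQueue/DataRequest import paths are assumed from the surrounding test module, and the push/get protocol is the one shown in the tests above.

import asyncio

from docarray import Document, DocumentArray
from jina.serve.runtimes.worker.batch_queue import BatchQueue
from jina.types.request.data import DataRequest


async def main(flush_all: bool, preferred_batch_size: int = 4) -> None:
    batch_sizes = []

    async def foo(docs, **kwargs):
        # Record how many docs each flush hands to the function.
        batch_sizes.append(len(docs))

    bq = BatchQueue(
        foo,
        request_docarray_cls=DocumentArray,
        response_docarray_cls=DocumentArray,
        preferred_batch_size=preferred_batch_size,
        flush_all=flush_all,
        timeout=500,
    )

    # Mirror the tests above: every request carries at least
    # preferred_batch_size docs, so the >= assertion is meaningful.
    reqs = [DataRequest() for _ in range(6)]
    for i, req in enumerate(reqs):
        req.data.docs = DocumentArray(
            [Document(text=f'doc {i}-{j}') for j in range(preferred_batch_size)]
        )

    async def process(req):
        q = await bq.push(req)
        await q.get()  # resolves once this request's docs went through foo
        q.task_done()

    await asyncio.gather(*(process(r) for r in reqs))

    # The invariant the parametrized tests encode:
    if flush_all:
        assert all(s >= preferred_batch_size for s in batch_sizes)
    else:
        assert all(s <= preferred_batch_size for s in batch_sizes)


asyncio.run(main(flush_all=True))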