
Commit 87c6909

Author: Joan Martinez
Commit message: test: test batch queue flush all
Parent: 16a464d

1 file changed: tests/unit/serve/dynamic_batching/test_batch_queue.py
Lines changed: 76 additions & 7 deletions
@@ -9,7 +9,8 @@


 @pytest.mark.asyncio
-async def test_batch_queue_timeout():
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_batch_queue_timeout(flush_all):
     async def foo(docs, **kwargs):
         await asyncio.sleep(0.1)
         return DocumentArray([Document(text='Done') for _ in docs])
@@ -20,6 +21,7 @@ async def foo(docs, **kwargs):
         response_docarray_cls=DocumentArray,
         preferred_batch_size=4,
         timeout=2000,
+        flush_all=flush_all,
     )

     three_data_requests = [DataRequest() for _ in range(3)]
@@ -59,7 +61,8 @@ async def process_request(req):


 @pytest.mark.asyncio
-async def test_batch_queue_timeout_does_not_wait_previous_batch():
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_batch_queue_timeout_does_not_wait_previous_batch(flush_all):
     batches_lengths_computed = []

     async def foo(docs, **kwargs):
@@ -73,6 +76,7 @@ async def foo(docs, **kwargs):
         response_docarray_cls=DocumentArray,
         preferred_batch_size=5,
         timeout=3000,
+        flush_all=flush_all
     )

     data_requests = [DataRequest() for _ in range(3)]
@@ -105,7 +109,8 @@ async def process_request(req, sleep=0):


 @pytest.mark.asyncio
-async def test_batch_queue_req_length_larger_than_preferred():
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_batch_queue_req_length_larger_than_preferred(flush_all):
     async def foo(docs, **kwargs):
         await asyncio.sleep(0.1)
         return DocumentArray([Document(text='Done') for _ in docs])
@@ -116,6 +121,7 @@ async def foo(docs, **kwargs):
         response_docarray_cls=DocumentArray,
         preferred_batch_size=4,
         timeout=2000,
+        flush_all=flush_all,
     )

     data_requests = [DataRequest() for _ in range(3)]
@@ -142,7 +148,8 @@ async def process_request(req):


 @pytest.mark.asyncio
-async def test_exception():
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_exception(flush_all):
     BAD_REQUEST_IDX = [2, 6]

     async def foo(docs, **kwargs):
@@ -159,6 +166,7 @@ async def foo(docs, **kwargs):
         response_docarray_cls=DocumentArray,
         preferred_batch_size=1,
         timeout=500,
+        flush_all=flush_all,
     )

     data_requests = [DataRequest() for _ in range(35)]
@@ -188,7 +196,8 @@ async def process_request(req):


 @pytest.mark.asyncio
-async def test_exception_more_complex():
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_exception_more_complex(flush_all):
     TRIGGER_BAD_REQUEST_IDX = [2, 6]
     EXPECTED_BAD_REQUESTS = [2, 3, 6, 7]

@@ -208,6 +217,7 @@ async def foo(docs, **kwargs):
         request_docarray_cls=DocumentArray,
         response_docarray_cls=DocumentArray,
         preferred_batch_size=2,
+        flush_all=flush_all,
         timeout=500,
     )

@@ -240,7 +250,8 @@ async def process_request(req):


 @pytest.mark.asyncio
-async def test_exception_all():
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_exception_all(flush_all):
     async def foo(docs, **kwargs):
         raise AssertionError

@@ -249,6 +260,7 @@ async def foo(docs, **kwargs):
         request_docarray_cls=DocumentArray,
         response_docarray_cls=DocumentArray,
         preferred_batch_size=2,
+        flush_all=flush_all,
         timeout=500,
     )

@@ -287,8 +299,9 @@ async def foo(docs, **kwargs):
 @pytest.mark.parametrize('num_requests', [61, 127, 100])
 @pytest.mark.parametrize('preferred_batch_size', [7, 27, 61, 73, 100])
 @pytest.mark.parametrize('timeout', [0.3, 500])
+@pytest.mark.parametrize('flush_all', [False, True])
 @pytest.mark.asyncio
-async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout):
+async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout, flush_all):
     import random

     async def foo(docs, **kwargs):
@@ -301,6 +314,7 @@ async def foo(docs, **kwargs):
         request_docarray_cls=DocumentArray,
         response_docarray_cls=DocumentArray,
         preferred_batch_size=preferred_batch_size,
+        flush_all=flush_all,
         timeout=timeout,
     )

@@ -331,3 +345,58 @@ async def process_request(req):
         assert len(resp.docs) == length
         for j, d in enumerate(resp.docs):
             assert d.text == f'Text {j} from request {i} with len {length} Processed'
+
+
+@pytest.mark.parametrize('num_requests', [61, 127, 100])
+@pytest.mark.parametrize('preferred_batch_size', [7, 27, 61, 73, 100])
+@pytest.mark.parametrize('timeout', [0.3, 500])
+@pytest.mark.parametrize('flush_all', [False, True])
+@pytest.mark.asyncio
+async def test_length_processed_in_func(num_requests, preferred_batch_size, timeout, flush_all):
+    import random
+
+    async def foo(docs, **kwargs):
+        if not flush_all:
+            assert len(docs) <= preferred_batch_size
+        else:
+            assert len(docs) >= preferred_batch_size
+        await asyncio.sleep(0.1)
+        for doc in docs:
+            doc.text += ' Processed'
+
+    bq: BatchQueue = BatchQueue(
+        foo,
+        request_docarray_cls=DocumentArray,
+        response_docarray_cls=DocumentArray,
+        preferred_batch_size=preferred_batch_size,
+        flush_all=flush_all,
+        timeout=timeout,
+    )
+
+    data_requests = [DataRequest() for _ in range(num_requests)]
+    len_requests = []
+    for i, req in enumerate(data_requests):
+        len_request = random.randint(preferred_batch_size, preferred_batch_size * 10)
+        len_requests.append(len_request)
+        req.data.docs = DocumentArray(
+            [
+                Document(text=f'Text {j} from request {i} with len {len_request}')
+                for j in range(len_request)
+            ]
+        )
+
+    async def process_request(req):
+        q = await bq.push(req)
+        item = await q.get()
+        q.task_done()
+        return item
+
+    tasks = [asyncio.create_task(process_request(req)) for req in data_requests]
+    items = await asyncio.gather(*tasks)
+    for i, item in enumerate(items):
+        assert item is None
+
+    for i, (resp, length) in enumerate(zip(data_requests, len_requests)):
+        assert len(resp.docs) == length
+        for j, d in enumerate(resp.docs):
+            assert d.text == f'Text {j} from request {i} with len {length} Processed'
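
Note on the new test: the assertions inside `foo` in `test_length_processed_in_func` pin down what `flush_all` appears to change. With `flush_all=False`, every call to the wrapped function receives at most `preferred_batch_size` docs; with `flush_all=True`, each flush hands the function at least `preferred_batch_size` docs, i.e. a batch is allowed to grow past the preferred size. The sketch below is a standalone helper written for this note to mirror that batching rule; it is not the BatchQueue implementation itself, only an illustration of the size constraint the test asserts.

# Illustrative sketch only -- mirrors the batching rule the test asserts,
# not Jina's BatchQueue internals.
def split_into_batches(docs, preferred_batch_size, flush_all):
    """Yield the doc batches a flush would pass to the wrapped function."""
    if flush_all:
        # Hand over everything queued in one call; the batch may exceed
        # preferred_batch_size (the test only checks len(docs) >= preferred_batch_size).
        yield list(docs)
    else:
        # Cap every call at preferred_batch_size docs
        # (the test checks len(docs) <= preferred_batch_size).
        for start in range(0, len(docs), preferred_batch_size):
            yield list(docs[start:start + preferred_batch_size])


queued = list(range(10))
assert [len(b) for b in split_into_batches(queued, 4, flush_all=False)] == [4, 4, 2]
assert [len(b) for b in split_into_batches(queued, 4, flush_all=True)] == [10]

Either way, the final loop over `zip(data_requests, len_requests)` still requires every request to get back all of its docs with the ' Processed' suffix, so `flush_all` only affects how the docs are grouped per call, not which docs are returned to which request.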
