Commit cf7284f

Author: Joan Fontanals
feat: add flush all option to dynamic batching configuration (#6179)
1 parent e0f620d commit cf7284f

6 files changed, +114 -29 lines

jina/serve/executors/decorators.py
Lines changed: 5 additions & 1 deletion

@@ -416,6 +416,7 @@ def dynamic_batching(
     *,
     preferred_batch_size: Optional[int] = None,
     timeout: Optional[float] = 10_000,
+    flush_all: bool = False
 ):
     """
     `@dynamic_batching` defines the dynamic batching behavior of an Executor.
@@ -426,11 +427,13 @@ def dynamic_batching(

     :param func: the method to decorate
     :param preferred_batch_size: target number of Documents in a batch. The batcher will collect requests until `preferred_batch_size` is reached,
-        or until `timeout` is reached. Therefore, the actual batch size can be smaller or larger than `preferred_batch_size`.
+        or until `timeout` is reached. Therefore, the actual batch size can be smaller than or equal to `preferred_batch_size`, except when `flush_all` is set to True.
     :param timeout: maximum time in milliseconds to wait for a request to be assigned to a batch.
         If the oldest request in the queue reaches a waiting time of `timeout`, the batch will be passed to the Executor,
        even if it contains fewer than `preferred_batch_size` Documents.
         Default is 10_000ms (10 seconds).
+    :param flush_all: Determines whether, once a batch is triggered by `timeout` or `preferred_batch_size`, the function receives everything the batcher has accumulated so far.
+        If True, `preferred_batch_size` acts only as a trigger mechanism.
     :return: decorated function
     """

@@ -476,6 +479,7 @@ def _inject_owner_attrs(self, owner, name):
             owner.dynamic_batching[fn_name][
                 'preferred_batch_size'
             ] = preferred_batch_size
             owner.dynamic_batching[fn_name]['timeout'] = timeout
+            owner.dynamic_batching[fn_name]['flush_all'] = flush_all
             setattr(owner, name, self.fn)

         def __set_name__(self, owner, name):
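
For reference, a minimal usage sketch of the decorator with the new flush_all option; the Executor name, endpoint, and parameter values below are illustrative and not part of this commit.

# Hypothetical usage example; MyExecutor and /encode are made-up names.
from jina import Executor, requests, dynamic_batching


class MyExecutor(Executor):
    @dynamic_batching(preferred_batch_size=16, timeout=2_000, flush_all=True)
    @requests(on='/encode')
    def encode(self, docs, **kwargs):
        # With flush_all=True, `docs` may be larger than preferred_batch_size:
        # reaching 16 queued Documents (or the 2 s timeout) only triggers the
        # flush, and everything accumulated so far is passed in a single call.
        for doc in docs:
            doc.text = f'batch of {len(docs)}'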

jina/serve/runtimes/worker/batch_queue.py
Lines changed: 7 additions & 2 deletions

@@ -23,6 +23,7 @@ def __init__(
         response_docarray_cls,
         output_array_type: Optional[str] = None,
         params: Optional[Dict] = None,
+        flush_all: bool = False,
         preferred_batch_size: int = 4,
         timeout: int = 10_000,
     ) -> None:
@@ -35,6 +36,7 @@ def __init__(
         self.params = params
         self._request_docarray_cls = request_docarray_cls
         self._response_docarray_cls = response_docarray_cls
+        self._flush_all = flush_all
         self._preferred_batch_size: int = preferred_batch_size
         self._timeout: int = timeout
         self._reset()
@@ -205,7 +207,10 @@ async def _assign_results(

            return num_assigned_docs

-        def batch(iterable_1, iterable_2, n=1):
+        def batch(iterable_1, iterable_2, n: Optional[int] = 1):
+            if n is None:
+                yield iterable_1, iterable_2
+                return
            items = len(iterable_1)
            for ndx in range(0, items, n):
                yield iterable_1[ndx : min(ndx + n, items)], iterable_2[
@@ -229,7 +234,7 @@ def batch(iterable_1, iterable_2, n=1):
        non_assigned_to_response_request_idxs = []
        sum_from_previous_first_req_idx = 0
        for docs_inner_batch, req_idxs in batch(
-            self._big_doc, self._request_idxs, self._preferred_batch_size
+            self._big_doc, self._request_idxs, self._preferred_batch_size if not self._flush_all else None
        ):
            involved_requests_min_indx = req_idxs[0]
            involved_requests_max_indx = req_idxs[-1]
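
The call-site change above passes None to the batch helper when flush_all is enabled, so the entire accumulated queue is yielded as one slice instead of preferred_batch_size-sized chunks. Below is a standalone sketch of that behavior; the helper is re-implemented here purely for illustration and is not the module code itself.

from typing import List, Optional

def batch(iterable_1: List, iterable_2: List, n: Optional[int] = 1):
    # n=None means "flush everything": yield both sequences whole, in one shot.
    if n is None:
        yield iterable_1, iterable_2
        return
    # Otherwise yield aligned slices of at most n items each.
    items = len(iterable_1)
    for ndx in range(0, items, n):
        yield iterable_1[ndx : min(ndx + n, items)], iterable_2[ndx : min(ndx + n, items)]

docs = list('abcdefg')       # stand-in for the accumulated Documents
req_idxs = list(range(7))    # request index associated with each Document

print([len(d) for d, _ in batch(docs, req_idxs, 3)])     # like flush_all=False -> [3, 3, 1]
print([len(d) for d, _ in batch(docs, req_idxs, None)])  # like flush_all=True  -> [7]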

jina/serve/runtimes/worker/http_fastapi_app.py
Lines changed: 3 additions & 1 deletion

@@ -7,6 +7,7 @@
 from jina.serve.networking.sse import EventSourceResponse
 from jina.types.request.data import DataRequest

+
 if TYPE_CHECKING:
     from jina.logging.logger import JinaLogger

@@ -88,7 +89,6 @@ def add_post_route(

     @app.api_route(**app_kwargs)
     async def post(body: input_model, response: Response):
-
         req = DataRequest()
         if body.header is not None:
             req.header.request_id = body.header.request_id
@@ -122,7 +122,9 @@ async def post(body: input_model, response: Response):
             docs_response = resp.docs.to_dict()
         else:
             docs_response = resp.docs
+
         ret = output_model(data=docs_response, parameters=resp.parameters)
+
         return ret

 def add_streaming_routes(

tests/integration/dynamic_batching/test_dynamic_batching.py
Lines changed: 56 additions & 4 deletions

@@ -629,11 +629,17 @@ def test_failure_propagation():
     )


-@pytest.mark.repeat(10)
-def test_exception_handling_in_dynamic_batch():
+@pytest.mark.parametrize(
+    'flush_all',
+    [
+        False,
+        True
+    ],
+)
+def test_exception_handling_in_dynamic_batch(flush_all):
     class SlowExecutorWithException(Executor):

-        @dynamic_batching(preferred_batch_size=3, timeout=1000)
+        @dynamic_batching(preferred_batch_size=3, timeout=5000, flush_all=flush_all)
         @requests(on='/foo')
         def foo(self, docs, **kwargs):
             for doc in docs:
@@ -659,4 +665,50 @@ def foo(self, docs, **kwargs):
         if r.header.status.code == jina_pb2.StatusProto.StatusCode.ERROR:
             num_failed_requests += 1

-    assert 1 <= num_failed_requests <= 3  # 3 requests in the dynamic batch failing
+    if not flush_all:
+        assert 1 <= num_failed_requests <= 3  # 3 requests in the dynamic batch failing
+    else:
+        assert 1 <= num_failed_requests <= len(da)  # with flush_all, every request in the flushed batch may fail
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    'flush_all',
+    [
+        False,
+        True
+    ],
+)
+async def test_num_docs_processed_in_exec(flush_all):
+    class DynBatchProcessor(Executor):
+
+        @dynamic_batching(preferred_batch_size=5, timeout=5000, flush_all=flush_all)
+        @requests(on='/foo')
+        def foo(self, docs, **kwargs):
+            for doc in docs:
+                doc.text = f"{len(docs)}"
+
+    depl = Deployment(uses=DynBatchProcessor, protocol='http')
+
+    with depl:
+        da = DocumentArray([Document(text='good') for _ in range(50)])
+        cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True)
+        res = []
+        async for r in cl.post(
+            on='/foo',
+            inputs=da,
+            request_size=7,
+            continue_on_error=True,
+            results_in_order=True,
+        ):
+            res.extend(r)
+        assert len(res) == 50  # one response Document per input Document
+        if not flush_all:
+            for d in res:
+                assert int(d.text) <= 5
+        else:
+            larger_than_5 = 0
+            for d in res:
+                if int(d.text) > 5:
+                    larger_than_5 += 1
+                assert int(d.text) >= 5
+            assert larger_than_5 > 0

tests/unit/serve/dynamic_batching/test_batch_queue.py
Lines changed: 35 additions & 13 deletions

@@ -9,7 +9,8 @@


 @pytest.mark.asyncio
-async def test_batch_queue_timeout():
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_batch_queue_timeout(flush_all):
     async def foo(docs, **kwargs):
         await asyncio.sleep(0.1)
         return DocumentArray([Document(text='Done') for _ in docs])
@@ -20,6 +21,7 @@ async def foo(docs, **kwargs):
         response_docarray_cls=DocumentArray,
         preferred_batch_size=4,
         timeout=2000,
+        flush_all=flush_all,
     )

     three_data_requests = [DataRequest() for _ in range(3)]
@@ -59,7 +61,8 @@ async def process_request(req):


 @pytest.mark.asyncio
-async def test_batch_queue_timeout_does_not_wait_previous_batch():
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_batch_queue_timeout_does_not_wait_previous_batch(flush_all):
     batches_lengths_computed = []

     async def foo(docs, **kwargs):
@@ -73,6 +76,7 @@ async def foo(docs, **kwargs):
         response_docarray_cls=DocumentArray,
         preferred_batch_size=5,
         timeout=3000,
+        flush_all=flush_all
     )

     data_requests = [DataRequest() for _ in range(3)]
@@ -93,19 +97,28 @@ async def process_request(req, sleep=0):
     init_time = time.time()
     tasks = [asyncio.create_task(process_request(req)) for req in data_requests]
     tasks.append(asyncio.create_task(process_request(extra_data_request, sleep=2)))
-    responses = await asyncio.gather(*tasks)
+    _ = await asyncio.gather(*tasks)
     time_spent = (time.time() - init_time) * 1000
-    # TIME TAKEN: 8000 for first batch of requests, plus 4000 for second batch that is fired immediately
-    # BEFORE FIX in https://github.com/jina-ai/jina/pull/6071, this would take: 8000 + 3000 + 4000 (Timeout would start counting too late)
-    assert time_spent >= 12000
-    assert time_spent <= 12500
-    assert batches_lengths_computed == [5, 1, 2]
+
+    if flush_all is False:
+        # TIME TAKEN: 8000 for first batch of requests, plus 4000 for second batch that is fired immediately
+        # BEFORE FIX in https://github.com/jina-ai/jina/pull/6071, this would take: 8000 + 3000 + 4000 (Timeout would start counting too late)
+        assert time_spent >= 12000
+        assert time_spent <= 12500
+    else:
+        assert time_spent >= 8000
+        assert time_spent <= 8500
+    if flush_all is False:
+        assert batches_lengths_computed == [5, 1, 2]
+    else:
+        assert batches_lengths_computed == [6, 2]

     await bq.close()


 @pytest.mark.asyncio
-async def test_batch_queue_req_length_larger_than_preferred():
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_batch_queue_req_length_larger_than_preferred(flush_all):
     async def foo(docs, **kwargs):
         await asyncio.sleep(0.1)
         return DocumentArray([Document(text='Done') for _ in docs])
@@ -116,6 +129,7 @@ async def foo(docs, **kwargs):
         response_docarray_cls=DocumentArray,
         preferred_batch_size=4,
         timeout=2000,
+        flush_all=flush_all,
     )

     data_requests = [DataRequest() for _ in range(3)]
@@ -240,7 +254,8 @@ async def process_request(req):


 @pytest.mark.asyncio
-async def test_exception_all():
+@pytest.mark.parametrize('flush_all', [False, True])
+async def test_exception_all(flush_all):
     async def foo(docs, **kwargs):
         raise AssertionError

@@ -249,6 +264,7 @@ async def foo(docs, **kwargs):
         request_docarray_cls=DocumentArray,
         response_docarray_cls=DocumentArray,
         preferred_batch_size=2,
+        flush_all=flush_all,
         timeout=500,
     )

@@ -284,14 +300,19 @@ async def foo(docs, **kwargs):
     assert repr(bq) == str(bq)


-@pytest.mark.parametrize('num_requests', [61, 127, 100])
-@pytest.mark.parametrize('preferred_batch_size', [7, 27, 61, 73, 100])
+@pytest.mark.parametrize('num_requests', [33, 127, 100])
+@pytest.mark.parametrize('preferred_batch_size', [7, 61, 100])
 @pytest.mark.parametrize('timeout', [0.3, 500])
+@pytest.mark.parametrize('flush_all', [False, True])
 @pytest.mark.asyncio
-async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout):
+async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout, flush_all):
     import random

     async def foo(docs, **kwargs):
+        if not flush_all:
+            assert len(docs) <= preferred_batch_size
+        else:
+            assert len(docs) >= preferred_batch_size
         await asyncio.sleep(0.1)
         for doc in docs:
             doc.text += ' Processed'
@@ -301,6 +322,7 @@ async def foo(docs, **kwargs):
         request_docarray_cls=DocumentArray,
         response_docarray_cls=DocumentArray,
         preferred_batch_size=preferred_batch_size,
+        flush_all=flush_all,
         timeout=timeout,
     )

tests/unit/serve/executors/test_executor.py
Lines changed: 8 additions & 8 deletions

@@ -614,15 +614,15 @@ class C(B):
     [
         (
             dict(preferred_batch_size=4, timeout=5_000),
-            dict(preferred_batch_size=4, timeout=5_000),
+            dict(preferred_batch_size=4, timeout=5_000, flush_all=False),
         ),
         (
-            dict(preferred_batch_size=4, timeout=5_000),
-            dict(preferred_batch_size=4, timeout=5_000),
+            dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
+            dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
         ),
         (
             dict(preferred_batch_size=4),
-            dict(preferred_batch_size=4, timeout=10_000),
+            dict(preferred_batch_size=4, timeout=10_000, flush_all=False),
         ),
     ],
 )
@@ -641,15 +641,15 @@ def foo(self, docs, **kwargs):
     [
         (
             dict(preferred_batch_size=4, timeout=5_000),
-            dict(preferred_batch_size=4, timeout=5_000),
+            dict(preferred_batch_size=4, timeout=5_000, flush_all=False),
         ),
         (
-            dict(preferred_batch_size=4, timeout=5_000),
-            dict(preferred_batch_size=4, timeout=5_000),
+            dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
+            dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
         ),
         (
             dict(preferred_batch_size=4),
-            dict(preferred_batch_size=4, timeout=10_000),
+            dict(preferred_batch_size=4, timeout=10_000, flush_all=False),
         ),
     ],
 )
