Skip to content

Commit 16a464d

Browse files
author
Joan Martinez
committed
feat: add flush all option
1 parent e0f620d commit 16a464d

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

jina/serve/runtimes/worker/batch_queue.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
response_docarray_cls,
2424
output_array_type: Optional[str] = None,
2525
params: Optional[Dict] = None,
26+
flush_all: bool = False,
2627
preferred_batch_size: int = 4,
2728
timeout: int = 10_000,
2829
) -> None:
@@ -35,6 +36,7 @@ def __init__(
3536
self.params = params
3637
self._request_docarray_cls = request_docarray_cls
3738
self._response_docarray_cls = response_docarray_cls
39+
self._flush_all = flush_all
3840
self._preferred_batch_size: int = preferred_batch_size
3941
self._timeout: int = timeout
4042
self._reset()
@@ -205,7 +207,10 @@ async def _assign_results(
205207

206208
return num_assigned_docs
207209

208-
def batch(iterable_1, iterable_2, n=1):
210+
def batch(iterable_1, iterable_2, n:Optional[int] = 1):
211+
if n is None:
212+
yield iterable_1, iterable_2
213+
return
209214
items = len(iterable_1)
210215
for ndx in range(0, items, n):
211216
yield iterable_1[ndx : min(ndx + n, items)], iterable_2[
@@ -229,7 +234,7 @@ def batch(iterable_1, iterable_2, n=1):
229234
non_assigned_to_response_request_idxs = []
230235
sum_from_previous_first_req_idx = 0
231236
for docs_inner_batch, req_idxs in batch(
232-
self._big_doc, self._request_idxs, self._preferred_batch_size
237+
self._big_doc, self._request_idxs, self._preferred_batch_size if not self._flush_all else None
233238
):
234239
involved_requests_min_indx = req_idxs[0]
235240
involved_requests_max_indx = req_idxs[-1]

jina/serve/runtimes/worker/http_fastapi_app.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from jina.serve.networking.sse import EventSourceResponse
88
from jina.types.request.data import DataRequest
99

10+
1011
if TYPE_CHECKING:
1112
from jina.logging.logger import JinaLogger
1213

@@ -88,7 +89,6 @@ def add_post_route(
8889

8990
@app.api_route(**app_kwargs)
9091
async def post(body: input_model, response: Response):
91-
9292
req = DataRequest()
9393
if body.header is not None:
9494
req.header.request_id = body.header.request_id
@@ -122,7 +122,9 @@ async def post(body: input_model, response: Response):
122122
docs_response = resp.docs.to_dict()
123123
else:
124124
docs_response = resp.docs
125+
125126
ret = output_model(data=docs_response, parameters=resp.parameters)
127+
126128
return ret
127129

128130
def add_streaming_routes(

0 commit comments

Comments
 (0)