Skip to content

Commit 14622d6

Browse files
shawn-yang-googlecopybara-github
authored andcommitted
chore: Support the batch mode in Llama Index Query Pipeline.
PiperOrigin-RevId: 742454899
1 parent 898109d commit 14622d6

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

tests/unit/vertex_llama_index/test_reasoning_engine_templates_llama_index.py

+23
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,15 @@ def llama_index_instrumentor_none_mock():
108108
yield llama_index_instrumentor_mock
109109

110110

111+
@pytest.fixture
112+
def nest_asyncio_apply_mock():
113+
with mock.patch.object(
114+
_utils,
115+
"_import_nest_asyncio_or_warn",
116+
) as nest_asyncio_apply_mock:
117+
yield nest_asyncio_apply_mock
118+
119+
111120
@pytest.mark.usefixtures("google_auth_mock")
112121
class TestLlamaIndexQueryPipelineAgent:
113122
def setup_method(self):
@@ -199,6 +208,20 @@ def test_query_with_kwargs_and_input_dict(self, json_loads_mock):
199208
agent.query(input={"input": "test query"})
200209
mocks.assert_has_calls([mock.call.run.run(input="test query")])
201210

211+
def test_query_with_batch_input(self, json_loads_mock, nest_asyncio_apply_mock):
212+
agent = llama_index.LlamaIndexQueryPipelineAgent(
213+
model=_TEST_MODEL,
214+
prompt=self.prompt,
215+
)
216+
agent._runnable = mock.Mock()
217+
mocks = mock.Mock()
218+
mocks.attach_mock(mock=agent._runnable, attribute="run")
219+
agent.query(input={"input": ["test query 1", "test query 2"]}, batch=True)
220+
mocks.assert_has_calls(
221+
[mock.call.run.run(input=["test query 1", "test query 2"], batch=True)]
222+
)
223+
nest_asyncio_apply_mock.assert_called_once()
224+
202225
@pytest.mark.usefixtures("caplog")
203226
def test_enable_tracing(
204227
self,

vertexai/preview/reasoning_engines/templates/llama_index.py

+4
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,10 @@ def query(
544544
if not self._runnable:
545545
self.set_up()
546546

547+
if kwargs.get("batch"):
548+
nest_asyncio = _utils._import_nest_asyncio_or_warn()
549+
nest_asyncio.apply()
550+
547551
return _utils.to_json_serializable_llama_index_object(
548552
self._runnable.run(**input, **kwargs)
549553
)

vertexai/reasoning_engines/_utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -472,3 +472,16 @@ def _import_autogen_tools_or_warn() -> Optional[types.ModuleType]:
472472
"autogen.tools is not installed. Please call: `pip install ag2[tools]`"
473473
)
474474
return None
475+
476+
477+
def _import_nest_asyncio_or_warn() -> Optional[types.ModuleType]:
478+
"""Tries to import the nest_asyncio module."""
479+
try:
480+
import nest_asyncio
481+
482+
return nest_asyncio
483+
except ImportError:
484+
_LOGGER.warning(
485+
"nest_asyncio is not installed. Please call: `pip install nest-asyncio`"
486+
)
487+
return None

0 commit comments

Comments
 (0)