
Commit f1f2ea5

shawn-yang-google authored and copybara-github committed
feat: Add the initial version of the LlamaIndex agent prebuilt template.
PiperOrigin-RevId: 738627111
1 parent 5da362f commit f1f2ea5

File tree

7 files changed: +867 −0 lines

noxfile.py (+31)

@@ -54,6 +54,7 @@
 UNIT_TEST_PYTHON_VERSIONS = ["3.8", "3.9", "3.10", "3.11", "3.12"]
 UNIT_TEST_LANGCHAIN_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12"]
 UNIT_TEST_AG2_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12"]
+UNIT_TEST_LLAMA_INDEX_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12"]
 UNIT_TEST_STANDARD_DEPENDENCIES = [
     "mock",
     "asyncmock",
@@ -93,6 +94,7 @@
     "unit_ray",
     "unit_langchain",
     "unit_ag2",
+    "unit_llama_index",
     "system",
     "cover",
     "lint",
@@ -208,6 +210,7 @@ def default(session):
         "--ignore=tests/unit/vertex_ray",
         "--ignore=tests/unit/vertex_langchain",
         "--ignore=tests/unit/vertex_ag2",
+        "--ignore=tests/unit/vertex_llama_index",
         "--ignore=tests/unit/architecture",
         os.path.join("tests", "unit"),
         *session.posargs,
@@ -331,6 +334,34 @@ def unit_ag2(session):
     )
 
 
+@nox.session(python=UNIT_TEST_LLAMA_INDEX_PYTHON_VERSIONS)
+def unit_llama_index(session):
+    # Install all test dependencies, then install this package in-place.
+
+    constraints_path = str(
+        CURRENT_DIRECTORY / "testing" / "constraints-llama-index.txt"
+    )
+    standard_deps = UNIT_TEST_STANDARD_DEPENDENCIES + UNIT_TEST_DEPENDENCIES
+    session.install(*standard_deps, "-c", constraints_path)
+
+    # Install llama_index extras
+    session.install("-e", ".[llama_index_testing]", "-c", constraints_path)
+
+    # Run py.test against the unit tests.
+    session.run(
+        "py.test",
+        "--quiet",
+        "--junitxml=unit_llama_index_sponge_log.xml",
+        "--cov=google",
+        "--cov-append",
+        "--cov-config=.coveragerc",
+        "--cov-report=",
+        "--cov-fail-under=0",
+        os.path.join("tests", "unit", "vertex_llama_index"),
+        *session.posargs,
+    )
+
+
 def install_systemtest_dependencies(session, *constraints):
     # Use pre-release gRPC for system tests.
     # Exclude version 1.52.0rc1 which has a known issue.
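
The new session follows nox's standard naming, so (assuming a local checkout of this repo with nox installed) it can be run for one interpreter or the whole matrix:

    nox -s unit_llama_index-3.12   # one Python version
    nox -s unit_llama_index        # every version in UNIT_TEST_LLAMA_INDEX_PYTHON_VERSIONS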

setup.py (+16)

@@ -193,6 +193,20 @@
     )
 )
 
+llama_index_extra_require = [
+    "llama-index",
+    "llama-index-llms-google-genai",
+    "openinference-instrumentation-llama-index >= 3.0, < 4.0",
+]
+
+llama_index_testing_extra_require = list(
+    set(
+        llama_index_extra_require
+        + reasoning_engine_extra_require
+        + ["absl-py", "pytest-xdist"]
+    )
+)
+
 tokenization_extra_require = ["sentencepiece >= 0.2.0"]
 tokenization_testing_extra_require = tokenization_extra_require + ["nltk"]
 
@@ -309,6 +323,8 @@
         "tokenization": tokenization_extra_require,
         "ag2": ag2_extra_require,
         "ag2_testing": ag2_testing_extra_require,
+        "llama_index": llama_index_extra_require,
+        "llama_index_testing": llama_index_testing_extra_require,
     },
     python_requires=">=3.8",
     classifiers=[
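
With the extras registered, the template's dependencies install through standard pip extras; a sketch, assuming the google-cloud-aiplatform distribution built from this setup.py:

    pip install "google-cloud-aiplatform[llama_index]"   # runtime dependencies
    pip install -e ".[llama_index_testing]"              # editable checkout plus test dependencies

Note that llama_index_testing_extra_require is deduplicated via list(set(...)), so the ordering of its entries is not stable across builds.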

testing/constraints-llama-index.txt (+1)

@@ -0,0 +1 @@
+pydantic<2.10
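
A pip constraints file only caps versions; it installs nothing by itself. Passing it with -c, as the unit_llama_index session does above, keeps pydantic below 2.10 for every install in that session, presumably to avoid a then-current incompatibility with the LlamaIndex stack. The equivalent manual invocation would be:

    pip install -e ".[llama_index_testing]" -c testing/constraints-llama-index.txt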

tests/unit/vertex_llama_index/ (new test file, +234)

@@ -0,0 +1,234 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import importlib
+from unittest import mock
+import json
+
+from google import auth
+import vertexai
+from google.cloud.aiplatform import initializer
+from vertexai.preview.reasoning_engines.templates import llama_index
+from vertexai.reasoning_engines import _utils
+import pytest
+
+from llama_index.core import prompts
+from llama_index.core.base.llms import types
+
+_TEST_LOCATION = "us-central1"
+_TEST_PROJECT = "test-project"
+_TEST_MODEL = "gemini-1.0-pro"
+_TEST_SYSTEM_INSTRUCTION = "You are a helpful bot."
+
+
+@pytest.fixture(scope="module")
+def google_auth_mock():
+    with mock.patch.object(auth, "default") as google_auth_mock:
+        credentials_mock = mock.Mock()
+        credentials_mock.with_quota_project.return_value = None
+        google_auth_mock.return_value = (
+            credentials_mock,
+            _TEST_PROJECT,
+        )
+        yield google_auth_mock
+
+
+@pytest.fixture
+def vertexai_init_mock():
+    with mock.patch.object(vertexai, "init") as vertexai_init_mock:
+        yield vertexai_init_mock
+
+
+@pytest.fixture
+def json_loads_mock():
+    with mock.patch.object(json, "loads") as json_loads_mock:
+        yield json_loads_mock
+
+
+@pytest.fixture
+def model_builder_mock():
+    with mock.patch.object(
+        llama_index,
+        "_default_model_builder",
+    ) as model_builder_mock:
+        yield model_builder_mock
+
+
+@pytest.fixture
+def cloud_trace_exporter_mock():
+    with mock.patch.object(
+        _utils,
+        "_import_cloud_trace_exporter_or_warn",
+    ) as cloud_trace_exporter_mock:
+        yield cloud_trace_exporter_mock
+
+
+@pytest.fixture
+def tracer_provider_mock():
+    with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock:
+        yield tracer_provider_mock
+
+
+@pytest.fixture
+def simple_span_processor_mock():
+    with mock.patch(
+        "opentelemetry.sdk.trace.export.SimpleSpanProcessor"
+    ) as simple_span_processor_mock:
+        yield simple_span_processor_mock
+
+
+@pytest.fixture
+def llama_index_instrumentor_mock():
+    with mock.patch.object(
+        _utils,
+        "_import_openinference_llama_index_or_warn",
+    ) as llama_index_instrumentor_mock:
+        yield llama_index_instrumentor_mock
+
+
+@pytest.fixture
+def llama_index_instrumentor_none_mock():
+    with mock.patch.object(
+        _utils,
+        "_import_openinference_llama_index_or_warn",
+    ) as llama_index_instrumentor_mock:
+        llama_index_instrumentor_mock.return_value = None
+        yield llama_index_instrumentor_mock
+
+
+@pytest.mark.usefixtures("google_auth_mock")
+class TestLlamaIndexQueryPipelineAgent:
+    def setup_method(self):
+        importlib.reload(initializer)
+        importlib.reload(vertexai)
+        vertexai.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        self.prompt = prompts.ChatPromptTemplate(
+            message_templates=[
+                types.ChatMessage(
+                    role=types.MessageRole.SYSTEM,
+                    content=_TEST_SYSTEM_INSTRUCTION,
+                ),
+                types.ChatMessage(
+                    role=types.MessageRole.USER,
+                    content="{input}",
+                ),
+            ],
+        )
+
+    def teardown_method(self):
+        initializer.global_pool.shutdown(wait=True)
+
+    def test_initialization(self):
+        agent = llama_index.LlamaIndexQueryPipelineAgent(model=_TEST_MODEL)
+        assert agent._model_name == _TEST_MODEL
+        assert agent._project == _TEST_PROJECT
+        assert agent._location == _TEST_LOCATION
+        assert agent._runnable is None
+
+    def test_set_up(self):
+        agent = llama_index.LlamaIndexQueryPipelineAgent(
+            model=_TEST_MODEL,
+            prompt=self.prompt,
+            model_builder=lambda **kwargs: kwargs,
+            runnable_builder=lambda **kwargs: kwargs,
+        )
+        assert agent._runnable is None
+        agent.set_up()
+        assert agent._runnable is not None
+
+    def test_clone(self):
+        agent = llama_index.LlamaIndexQueryPipelineAgent(
+            model=_TEST_MODEL,
+            prompt=self.prompt,
+            model_builder=lambda **kwargs: kwargs,
+            runnable_builder=lambda **kwargs: kwargs,
+        )
+        agent.set_up()
+        assert agent._runnable is not None
+        agent_clone = agent.clone()
+        assert agent._runnable is not None
+        assert agent_clone._runnable is None
+        agent_clone.set_up()
+        assert agent_clone._runnable is not None
+
+    def test_query(self, json_loads_mock):
+        agent = llama_index.LlamaIndexQueryPipelineAgent(
+            model=_TEST_MODEL,
+            prompt=self.prompt,
+        )
+        agent._runnable = mock.Mock()
+        mocks = mock.Mock()
+        mocks.attach_mock(mock=agent._runnable, attribute="run")
+        agent.query(input="test query")
+        mocks.assert_has_calls([mock.call.run.run(input="test query")])
+
+    def test_query_with_kwargs(self, json_loads_mock):
+        agent = llama_index.LlamaIndexQueryPipelineAgent(
+            model=_TEST_MODEL,
+            prompt=self.prompt,
+        )
+        agent._runnable = mock.Mock()
+        mocks = mock.Mock()
+        mocks.attach_mock(mock=agent._runnable, attribute="run")
+        agent.query(input="test query", test_arg=123)
+        mocks.assert_has_calls([mock.call.run.run(input="test query", test_arg=123)])
+
+    def test_query_with_kwargs_and_input_dict(self, json_loads_mock):
+        agent = llama_index.LlamaIndexQueryPipelineAgent(
+            model=_TEST_MODEL,
+            prompt=self.prompt,
+        )
+        agent._runnable = mock.Mock()
+        mocks = mock.Mock()
+        mocks.attach_mock(mock=agent._runnable, attribute="run")
+        agent.query(input={"input": "test query"})
+        mocks.assert_has_calls([mock.call.run.run(input="test query")])
+
+    @pytest.mark.usefixtures("caplog")
+    def test_enable_tracing(
+        self,
+        caplog,
+        cloud_trace_exporter_mock,
+        tracer_provider_mock,
+        simple_span_processor_mock,
+        llama_index_instrumentor_mock,
+    ):
+        agent = llama_index.LlamaIndexQueryPipelineAgent(
+            model=_TEST_MODEL,
+            prompt=self.prompt,
+            enable_tracing=True,
+        )
+        assert agent._instrumentor is None
+        # TODO(b/384730642): Re-enable this test once the parent issue is fixed.
+        # agent.set_up()
+        # assert agent._instrumentor is not None
+        # assert (
+        #     "enable_tracing=True but proceeding with tracing disabled"
+        #     not in caplog.text
+        # )
+
+    @pytest.mark.usefixtures("caplog")
+    def test_enable_tracing_warning(self, caplog, llama_index_instrumentor_none_mock):
+        agent = llama_index.LlamaIndexQueryPipelineAgent(
+            model=_TEST_MODEL,
+            prompt=self.prompt,
+            enable_tracing=True,
+        )
+        assert agent._instrumentor is None
+        # TODO(b/384730642): Re-enable this test once the parent issue is fixed.
+        # agent.set_up()
+        # assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text

vertexai/preview/reasoning_engines/__init__.py (+4)

@@ -29,11 +29,15 @@
 from vertexai.preview.reasoning_engines.templates.langgraph import (
     LanggraphAgent,
 )
+from vertexai.preview.reasoning_engines.templates.llama_index import (
+    LlamaIndexQueryPipelineAgent,
+)
 
 __all__ = (
     "AG2Agent",
     "LangchainAgent",
     "LanggraphAgent",
+    "LlamaIndexQueryPipelineAgent",
     "Queryable",
     "ReasoningEngine",
 )
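
Because the class is re-exported here, callers can import it from the preview namespace rather than reaching into the templates module, e.g.:

from vertexai.preview import reasoning_engines

agent = reasoning_engines.LlamaIndexQueryPipelineAgent(model="gemini-1.0-pro")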
