Skip to content

Commit 0478f10

Browse files
yeesiancopybara-github
authored andcommitted
feat: Add new template for AgentEngine
PiperOrigin-RevId: 745364132
1 parent 126d10c commit 0478f10

15 files changed

+1083
-41
lines changed

noxfile.py

+1
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def default(session):
212212
"--cov-report=",
213213
"--cov-fail-under=0",
214214
"--ignore=tests/unit/vertex_ray",
215+
"--ignore=tests/unit/vertex_adk",
215216
"--ignore=tests/unit/vertex_langchain",
216217
"--ignore=tests/unit/vertex_ag2",
217218
"--ignore=tests/unit/vertex_llama_index",

setup.py

+5
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@
142142
"xgboost_ray",
143143
]
144144

145+
adk_extra_require = [
146+
"google-adk >= 0.0.2",
147+
]
148+
145149
reasoning_engine_extra_require = [
146150
"cloudpickle >= 3.0, < 4.0",
147151
"google-cloud-trace < 2",
@@ -320,6 +324,7 @@
320324
"preview": preview_extra_require,
321325
"ray": ray_extra_require,
322326
"ray_testing": ray_testing_extra_require,
327+
"adk": adk_extra_require,
323328
"reasoningengine": reasoning_engine_extra_require,
324329
"agent_engines": agent_engines_extra_require,
325330
"evaluation": evaluation_extra_require,

testing/constraints-3.10.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ ray==2.4.0 # Pinned until 2.9.3 is verified for Ray tests
1515
ipython==8.22.2 # Pinned to unbreak TypeAliasType import error
1616
scikit-learn!=1.4.1.post1 # Pin to unbreak test_sklearn (b/332610038)
1717
requests==2.31.0 # Pinned to unbreak http+docker error (b/342669351)
18-
google-vizier==0.1.21
18+
google-vizier==0.1.21
19+
google-adk==0.0.2

testing/constraints-3.11.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ packaging==24.1 # Increased to unbreak canonicalize_version error (b/377774673)
1111
pytest-xdist==3.3.1 # Pinned to unbreak unit tests
1212
ray==2.5.0 # Pinned until 2.9.3 is verified for Ray tests
1313
ipython==8.22.2 # Pinned to unbreak TypeAliasType import error
14-
google-vizier==0.1.21
14+
google-vizier==0.1.21
15+
google-adk==0.0.2

testing/constraints-3.12.txt

+1
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ packaging==24.1 # Increased to unbreak canonicalize_version error (b/377774673)
1111
pytest-xdist==3.3.1 # Pinned to unbreak unit tests
1212
ray==2.5.0 # Pinned until 2.9.3 is verified for Ray tests
1313
ipython==8.22.2 # Pinned to unbreak TypeAliasType import error
14+
google-adk==0.0.2

testing/constraints-3.8.txt

+1
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ grpcio-testing==1.34.0
1313
pytest-xdist==3.3.1 # Pinned to unbreak unit tests
1414
ray==2.4.0 # Pinned until 2.9.3 is verified for Ray tests
1515
google-vizier==0.1.21
16+
google-adk==0.0.2

testing/constraints-3.9.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ packaging==24.1 # Increased to unbreak canonicalize_version error (b/377774673)
1111
grpcio-testing==1.34.0
1212
pytest-xdist==3.3.1 # Pinned to unbreak unit tests
1313
ray==2.4.0 # Pinned until 2.9.3 is verified for Ray tests
14-
google-vizier==0.1.21
14+
google-vizier==0.1.21
15+
google-adk==0.0.2

tests/unit/architecture/test_vertexai_import.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_vertexai_import():
8484
assert sorted(new_modules_after_vertexai) == [vertexai_module_name]
8585

8686
assert vertexai_import_timedelta.total_seconds() < 0.005
87-
assert aip_import_timedelta.total_seconds() < 23
87+
assert aip_import_timedelta.total_seconds() < 40
8888

8989
# Testing that external modules are not loaded.
9090
new_modules = modules_after_vertexai - modules_before_aip
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
import importlib
16+
import json
17+
from unittest import mock
18+
19+
from google import auth
20+
import vertexai
21+
from google.cloud.aiplatform import initializer
22+
from vertexai.preview import reasoning_engines
23+
from vertexai.agent_engines import _utils
24+
import pytest
25+
26+
27+
try:
28+
from google.adk.agents import llm_agent
29+
30+
Agent = llm_agent.Agent
31+
except ImportError:
32+
33+
class Agent:
34+
def __init__(self, name: str, model: str):
35+
self.name = name
36+
self.model = model
37+
38+
39+
_TEST_LOCATION = "us-central1"
40+
_TEST_PROJECT = "test-project"
41+
_TEST_MODEL = "gemini-1.0-pro"
42+
43+
44+
@pytest.fixture(scope="module")
45+
def google_auth_mock():
46+
with mock.patch.object(auth, "default") as google_auth_mock:
47+
credentials_mock = mock.Mock()
48+
credentials_mock.with_quota_project.return_value = None
49+
google_auth_mock.return_value = (
50+
credentials_mock,
51+
_TEST_PROJECT,
52+
)
53+
yield google_auth_mock
54+
55+
56+
@pytest.fixture
57+
def vertexai_init_mock():
58+
with mock.patch.object(vertexai, "init") as vertexai_init_mock:
59+
yield vertexai_init_mock
60+
61+
62+
@pytest.fixture
63+
def cloud_trace_exporter_mock():
64+
with mock.patch.object(
65+
_utils,
66+
"_import_cloud_trace_exporter_or_warn",
67+
) as cloud_trace_exporter_mock:
68+
yield cloud_trace_exporter_mock
69+
70+
71+
@pytest.fixture
72+
def tracer_provider_mock():
73+
with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock:
74+
yield tracer_provider_mock
75+
76+
77+
@pytest.fixture
78+
def simple_span_processor_mock():
79+
with mock.patch(
80+
"opentelemetry.sdk.trace.export.SimpleSpanProcessor"
81+
) as simple_span_processor_mock:
82+
yield simple_span_processor_mock
83+
84+
85+
class _MockRunner:
86+
def run(self, *args, **kwargs):
87+
from google.adk.events import event
88+
89+
yield event.Event(
90+
**{
91+
"author": "currency_exchange_agent",
92+
"content": {
93+
"parts": [
94+
{
95+
"function_call": {
96+
"args": {
97+
"currency_date": "2025-04-03",
98+
"currency_from": "USD",
99+
"currency_to": "SEK",
100+
},
101+
"id": "af-c5a57692-9177-4091-a3df-098f834ee849",
102+
"name": "get_exchange_rate",
103+
}
104+
}
105+
],
106+
"role": "model",
107+
},
108+
"id": "9aaItGK9",
109+
"invocation_id": "e-6543c213-6417-484b-9551-b67915d1d5f7",
110+
}
111+
)
112+
113+
114+
@pytest.mark.usefixtures("google_auth_mock")
115+
class TestAdkApp:
116+
def setup_method(self):
117+
importlib.reload(initializer)
118+
importlib.reload(vertexai)
119+
vertexai.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
120+
121+
def teardown_method(self):
122+
initializer.global_pool.shutdown(wait=True)
123+
124+
def test_initialization(self):
125+
app = reasoning_engines.AdkApp(
126+
agent=Agent(name="test_agent", model=_TEST_MODEL),
127+
)
128+
assert app._tmpl_attrs.get("project") == _TEST_PROJECT
129+
assert app._tmpl_attrs.get("location") == _TEST_LOCATION
130+
assert app._tmpl_attrs.get("runner") is None
131+
132+
def test_set_up(self):
133+
app = reasoning_engines.AdkApp(
134+
agent=Agent(name="test_agent", model=_TEST_MODEL),
135+
)
136+
assert app._tmpl_attrs.get("runner") is None
137+
app.set_up()
138+
assert app._tmpl_attrs.get("runner") is not None
139+
140+
def test_clone(self):
141+
app = reasoning_engines.AdkApp(
142+
agent=Agent(name="test_agent", model=_TEST_MODEL),
143+
)
144+
app.set_up()
145+
assert app._tmpl_attrs.get("runner") is not None
146+
app_clone = app.clone()
147+
assert app._tmpl_attrs.get("runner") is not None
148+
assert app_clone._tmpl_attrs.get("runner") is None
149+
app_clone.set_up()
150+
assert app_clone._tmpl_attrs.get("runner") is not None
151+
152+
def test_register_operations(self):
153+
app = reasoning_engines.AdkApp(
154+
agent=Agent(name="test_agent", model=_TEST_MODEL),
155+
)
156+
for operations in app.register_operations().values():
157+
for operation in operations:
158+
assert operation in dir(app)
159+
160+
def test_stream_query(self):
161+
app = reasoning_engines.AdkApp(
162+
agent=Agent(name="test_agent", model=_TEST_MODEL)
163+
)
164+
assert app._tmpl_attrs.get("runner") is None
165+
app.set_up()
166+
app._tmpl_attrs["runner"] = _MockRunner()
167+
events = list(
168+
app.stream_query(
169+
user_id="test_user_id",
170+
message="test message",
171+
)
172+
)
173+
assert len(events) == 1
174+
175+
def test_streaming_agent_run_with_events(self):
176+
app = reasoning_engines.AdkApp(
177+
agent=Agent(name="test_agent", model=_TEST_MODEL)
178+
)
179+
app.set_up()
180+
app._tmpl_attrs["in_memory_runner"] = _MockRunner()
181+
request_json = json.dumps(
182+
{
183+
"artifacts": [
184+
{
185+
"file_name": "test_file_name",
186+
"versions": [{"version": "v1", "data": "v1data"}],
187+
}
188+
],
189+
"authorizations": {
190+
"test_user_id1": {"access_token": "test_access_token"},
191+
"test_user_id2": {"accessToken": "test-access-token"},
192+
},
193+
"user_id": "test_user_id",
194+
"message": {
195+
"parts": [{"text": "What is the exchange rate from USD to SEK?"}],
196+
"role": "user",
197+
},
198+
}
199+
)
200+
events = list(app.streaming_agent_run_with_events(request_json=request_json))
201+
assert len(events) == 1
202+
203+
def test_create_session(self):
204+
app = reasoning_engines.AdkApp(
205+
agent=Agent(name="test_agent", model=_TEST_MODEL)
206+
)
207+
session1 = app.create_session(user_id="test_user_id")
208+
assert session1.user_id == "test_user_id"
209+
session2 = app.create_session(
210+
user_id="test_user_id", session_id="test_session_id"
211+
)
212+
assert session2.user_id == "test_user_id"
213+
assert session2.id == "test_session_id"
214+
215+
def test_get_session(self):
216+
app = reasoning_engines.AdkApp(
217+
agent=Agent(name="test_agent", model=_TEST_MODEL)
218+
)
219+
session1 = app.create_session(user_id="test_user_id")
220+
session2 = app.get_session(
221+
user_id="test_user_id",
222+
session_id=session1.id,
223+
)
224+
assert session2.user_id == "test_user_id"
225+
assert session1.id == session2.id
226+
227+
def test_list_sessions(self):
228+
app = reasoning_engines.AdkApp(
229+
agent=Agent(name="test_agent", model=_TEST_MODEL)
230+
)
231+
response0 = app.list_sessions(user_id="test_user_id")
232+
assert not response0.sessions
233+
session = app.create_session(user_id="test_user_id")
234+
response1 = app.list_sessions(user_id="test_user_id")
235+
assert len(response1.sessions) == 1
236+
assert response1.sessions[0].id == session.id
237+
session2 = app.create_session(user_id="test_user_id")
238+
response2 = app.list_sessions(user_id="test_user_id")
239+
assert len(response2.sessions) == 2
240+
assert response2.sessions[0].id == session.id
241+
assert response2.sessions[1].id == session2.id
242+
243+
def test_delete_session(self):
244+
app = reasoning_engines.AdkApp(
245+
agent=Agent(name="test_agent", model=_TEST_MODEL)
246+
)
247+
response = app.delete_session(user_id="test_user_id", session_id="")
248+
assert not response
249+
session = app.create_session(user_id="test_user_id")
250+
response1 = app.list_sessions(user_id="test_user_id")
251+
assert len(response1.sessions) == 1
252+
app.delete_session(user_id="test_user_id", session_id=session.id)
253+
response0 = app.list_sessions(user_id="test_user_id")
254+
assert not response0.sessions
255+
256+
@pytest.mark.usefixtures("caplog")
257+
def test_enable_tracing(
258+
self,
259+
caplog,
260+
cloud_trace_exporter_mock,
261+
tracer_provider_mock,
262+
simple_span_processor_mock,
263+
):
264+
app = reasoning_engines.AdkApp(
265+
agent=Agent(name="test_agent", model=_TEST_MODEL),
266+
enable_tracing=True,
267+
)
268+
assert app._tmpl_attrs.get("instrumentor") is None
269+
# TODO(b/384730642): Re-enable this test once the parent issue is fixed.
270+
# agent.set_up()
271+
# assert agent._tmpl_attrs.get("instrumentor") is not None
272+
# assert (
273+
# "enable_tracing=True but proceeding with tracing disabled"
274+
# not in caplog.text
275+
# )
276+
277+
@pytest.mark.usefixtures("caplog")
278+
def test_enable_tracing_warning(self, caplog):
279+
app = reasoning_engines.AdkApp(
280+
agent=Agent(name="test_agent", model=_TEST_MODEL),
281+
enable_tracing=True,
282+
)
283+
assert app._tmpl_attrs.get("instrumentor") is None
284+
# TODO(b/384730642): Re-enable this test once the parent issue is fixed.
285+
# app.set_up()
286+
# assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text
287+
288+
289+
class TestAdkAppErrors:
290+
def test_raise_get_session_not_found_error(self):
291+
with pytest.raises(
292+
RuntimeError,
293+
match=r"Session not found. Please create it using .create_session()",
294+
):
295+
app = reasoning_engines.AdkApp(
296+
agent=Agent(name="test_agent", model=_TEST_MODEL),
297+
)
298+
app.get_session(
299+
user_id="non_existent_user",
300+
session_id="test_session_id",
301+
)

0 commit comments

Comments
 (0)