Skip to content

Commit abf08da

Browse files
shawn-yang-googlecopybara-github
authored andcommitted
feat: Add the initial version of the AG2 agent prebuilt template.
PiperOrigin-RevId: 731075667
1 parent 4998c1a commit abf08da

File tree

8 files changed

+825
-0
lines changed

8 files changed

+825
-0
lines changed

noxfile.py

+29
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353

5454
UNIT_TEST_PYTHON_VERSIONS = ["3.8", "3.9", "3.10", "3.11", "3.12"]
5555
UNIT_TEST_LANGCHAIN_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12"]
56+
UNIT_TEST_AG2_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12"]
5657
UNIT_TEST_STANDARD_DEPENDENCIES = [
5758
"mock",
5859
"asyncmock",
@@ -91,6 +92,7 @@
9192
"unit",
9293
"unit_ray",
9394
"unit_langchain",
95+
"unit_ag2",
9496
"system",
9597
"cover",
9698
"lint",
@@ -205,6 +207,7 @@ def default(session):
205207
"--cov-fail-under=0",
206208
"--ignore=tests/unit/vertex_ray",
207209
"--ignore=tests/unit/vertex_langchain",
210+
"--ignore=tests/unit/vertex_ag2",
208211
"--ignore=tests/unit/architecture",
209212
os.path.join("tests", "unit"),
210213
*session.posargs,
@@ -302,6 +305,32 @@ def unit_langchain(session):
302305
)
303306

304307

308+
@nox.session(python=UNIT_TEST_AG2_PYTHON_VERSIONS)
309+
def unit_ag2(session):
310+
# Install all test dependencies, then install this package in-place.
311+
312+
constraints_path = str(CURRENT_DIRECTORY / "testing" / "constraints-ag2.txt")
313+
standard_deps = UNIT_TEST_STANDARD_DEPENDENCIES + UNIT_TEST_DEPENDENCIES
314+
session.install(*standard_deps, "-c", constraints_path)
315+
316+
# Install ag2 extras
317+
session.install("-e", ".[ag2_testing]", "-c", constraints_path)
318+
319+
# Run py.test against the unit tests.
320+
session.run(
321+
"py.test",
322+
"--quiet",
323+
"--junitxml=unit_ag2_sponge_log.xml",
324+
"--cov=google",
325+
"--cov-append",
326+
"--cov-config=.coveragerc",
327+
"--cov-report=",
328+
"--cov-fail-under=0",
329+
os.path.join("tests", "unit", "vertex_ag2"),
330+
*session.posargs,
331+
)
332+
333+
305334
def install_systemtest_dependencies(session, *constraints):
306335
# Use pre-release gRPC for system tests.
307336
# Exclude version 1.52.0rc1 which has a known issue.

setup.py

+12
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,16 @@
171171
)
172172
)
173173

174+
ag2_extra_require = [
175+
"ag2[gemini]",
176+
]
177+
178+
ag2_testing_extra_require = list(
179+
set(
180+
ag2_extra_require + reasoning_engine_extra_require + ["absl-py", "pytest-xdist"]
181+
)
182+
)
183+
174184
tokenization_extra_require = ["sentencepiece >= 0.2.0"]
175185
tokenization_testing_extra_require = tokenization_extra_require + ["nltk"]
176186

@@ -284,6 +294,8 @@
284294
"langchain": langchain_extra_require,
285295
"langchain_testing": langchain_testing_extra_require,
286296
"tokenization": tokenization_extra_require,
297+
"ag2": ag2_extra_require,
298+
"ag2_testing": ag2_testing_extra_require,
287299
},
288300
python_requires=">=3.8",
289301
classifiers=[

testing/constraints-ag2.txt

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pydantic<2.10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
import dataclasses
16+
import importlib
17+
from typing import Optional
18+
from unittest import mock
19+
20+
from google import auth
21+
import vertexai
22+
from google.cloud.aiplatform import initializer
23+
from vertexai.preview import reasoning_engines
24+
from vertexai.reasoning_engines import _utils
25+
import pytest
26+
27+
28+
_DEFAULT_PLACE_TOOL_ACTIVITY = "museums"
29+
_DEFAULT_PLACE_TOOL_PAGE_SIZE = 3
30+
_DEFAULT_PLACE_PHOTO_MAXWIDTH = 400
31+
_TEST_LOCATION = "us-central1"
32+
_TEST_PROJECT = "test-project"
33+
_TEST_MODEL = "gemini-1.0-pro"
34+
_TEST_SYSTEM_INSTRUCTION = "You are a helpful bot."
35+
36+
37+
def place_tool_query(
38+
city: str,
39+
activity: str = _DEFAULT_PLACE_TOOL_ACTIVITY,
40+
page_size: int = _DEFAULT_PLACE_TOOL_PAGE_SIZE,
41+
):
42+
"""Searches the city for recommendations on the activity."""
43+
return {"city": city, "activity": activity, "page_size": page_size}
44+
45+
46+
def place_photo_query(
47+
photo_reference: str,
48+
maxwidth: int = _DEFAULT_PLACE_PHOTO_MAXWIDTH,
49+
maxheight: Optional[int] = None,
50+
):
51+
"""Returns the photo for a given reference."""
52+
result = {"photo_reference": photo_reference, "maxwidth": maxwidth}
53+
if maxheight:
54+
result["maxheight"] = maxheight
55+
return result
56+
57+
58+
@pytest.fixture(scope="module")
59+
def google_auth_mock():
60+
with mock.patch.object(auth, "default") as google_auth_mock:
61+
credentials_mock = mock.Mock()
62+
credentials_mock.with_quota_project.return_value = None
63+
google_auth_mock.return_value = (
64+
credentials_mock,
65+
_TEST_PROJECT,
66+
)
67+
yield google_auth_mock
68+
69+
70+
@pytest.fixture
71+
def vertexai_init_mock():
72+
with mock.patch.object(vertexai, "init") as vertexai_init_mock:
73+
yield vertexai_init_mock
74+
75+
76+
@pytest.fixture
77+
def dataclasses_asdict_mock():
78+
with mock.patch.object(dataclasses, "asdict") as dataclasses_asdict_mock:
79+
dataclasses_asdict_mock.return_value = {}
80+
yield dataclasses_asdict_mock
81+
82+
83+
@pytest.fixture
84+
def cloud_trace_exporter_mock():
85+
with mock.patch.object(
86+
_utils,
87+
"_import_cloud_trace_exporter_or_warn",
88+
) as cloud_trace_exporter_mock:
89+
yield cloud_trace_exporter_mock
90+
91+
92+
@pytest.fixture
93+
def tracer_provider_mock():
94+
with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock:
95+
yield tracer_provider_mock
96+
97+
98+
@pytest.fixture
99+
def simple_span_processor_mock():
100+
with mock.patch(
101+
"opentelemetry.sdk.trace.export.SimpleSpanProcessor"
102+
) as simple_span_processor_mock:
103+
yield simple_span_processor_mock
104+
105+
106+
@pytest.fixture
107+
def autogen_tools_mock():
108+
with mock.patch.object(
109+
_utils,
110+
"_import_autogen_tools_or_warn",
111+
) as autogen_tools_mock:
112+
autogen_tools_mock.return_value = mock.MagicMock()
113+
yield autogen_tools_mock
114+
115+
116+
@pytest.mark.usefixtures("google_auth_mock")
117+
class TestAG2Agent:
118+
def setup_method(self):
119+
importlib.reload(initializer)
120+
importlib.reload(vertexai)
121+
vertexai.init(
122+
project=_TEST_PROJECT,
123+
location=_TEST_LOCATION,
124+
)
125+
126+
def teardown_method(self):
127+
initializer.global_pool.shutdown(wait=True)
128+
129+
def test_initialization(self):
130+
agent = reasoning_engines.AG2Agent(model=_TEST_MODEL)
131+
assert agent._model_name == _TEST_MODEL
132+
assert agent._project == _TEST_PROJECT
133+
assert agent._location == _TEST_LOCATION
134+
assert agent._runnable is None
135+
136+
def test_initialization_with_tools(self, autogen_tools_mock):
137+
tools = [
138+
place_tool_query,
139+
place_photo_query,
140+
]
141+
agent = reasoning_engines.AG2Agent(
142+
model=_TEST_MODEL,
143+
system_instruction=_TEST_SYSTEM_INSTRUCTION,
144+
tools=tools,
145+
runnable_builder=lambda **kwargs: kwargs,
146+
)
147+
assert agent._runnable is None
148+
assert agent._tools
149+
assert not agent._ag2_tool_objects
150+
agent.set_up()
151+
assert agent._runnable is not None
152+
assert agent._ag2_tool_objects
153+
154+
def test_set_up(self):
155+
agent = reasoning_engines.AG2Agent(
156+
model=_TEST_MODEL,
157+
runnable_builder=lambda **kwargs: kwargs,
158+
)
159+
assert agent._runnable is None
160+
agent.set_up()
161+
assert agent._runnable is not None
162+
163+
def test_clone(self):
164+
agent = reasoning_engines.AG2Agent(
165+
model=_TEST_MODEL,
166+
runnable_builder=lambda **kwargs: kwargs,
167+
)
168+
agent.set_up()
169+
assert agent._runnable is not None
170+
agent_clone = agent.clone()
171+
assert agent._runnable is not None
172+
assert agent_clone._runnable is None
173+
agent_clone.set_up()
174+
assert agent_clone._runnable is not None
175+
176+
def test_query(self, dataclasses_asdict_mock):
177+
agent = reasoning_engines.AG2Agent(
178+
model=_TEST_MODEL,
179+
)
180+
agent._runnable = mock.Mock()
181+
mocks = mock.Mock()
182+
mocks.attach_mock(mock=agent._runnable, attribute="run")
183+
agent.query(input="test query")
184+
mocks.assert_has_calls(
185+
[
186+
mock.call.run.run(
187+
{"content": "test query"},
188+
user_input=False,
189+
tools=[],
190+
max_turns=None,
191+
)
192+
]
193+
)
194+
195+
@pytest.mark.usefixtures("caplog")
196+
def test_enable_tracing(
197+
self,
198+
caplog,
199+
cloud_trace_exporter_mock,
200+
tracer_provider_mock,
201+
simple_span_processor_mock,
202+
):
203+
agent = reasoning_engines.AG2Agent(
204+
model=_TEST_MODEL,
205+
enable_tracing=True,
206+
)
207+
assert agent._enable_tracing is True
208+
# TODO(b/384730642): Re-enable this test once the parent issue is fixed.
209+
# agent.set_up()
210+
# assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text
211+
212+
213+
def _return_input_no_typing(input_):
214+
"""Returns input back to user."""
215+
return input_
216+
217+
218+
class TestConvertToolsOrRaiseErrors:
219+
def test_raise_untyped_input_args(self, vertexai_init_mock):
220+
with pytest.raises(TypeError, match=r"has untyped input_arg"):
221+
reasoning_engines.AG2Agent(
222+
model=_TEST_MODEL,
223+
tools=[_return_input_no_typing],
224+
)

0 commit comments

Comments
 (0)