Skip to content

Commit 0c3e294

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: GenAI - Added support for Grounding
PiperOrigin-RevId: 607165697
1 parent dd80b69 commit 0c3e294

File tree

4 files changed

+208
-5
lines changed

4 files changed

+208
-5
lines changed

tests/system/vertexai/test_generative_models.py

+15
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from google.cloud import aiplatform
2525
from tests.system.aiplatform import e2e_base
2626
from vertexai import generative_models
27+
from vertexai.preview import generative_models as preview_generative_models
2728

2829

2930
class TestGenerativeModels(e2e_base.TestEndToEnd):
@@ -134,6 +135,20 @@ def test_generate_content_from_text_and_remote_video(self):
134135
assert response.text
135136
assert "Zootopia" in response.text
136137

138+
def test_grounding_google_search_retriever(self):
139+
model = preview_generative_models.GenerativeModel("gemini-pro")
140+
google_search_retriever_tool = (
141+
preview_generative_models.Tool.from_google_search_retrieval(
142+
preview_generative_models.grounding.GoogleSearchRetrieval(
143+
disable_attribution=False
144+
)
145+
)
146+
)
147+
response = model.generate_content(
148+
"Why is sky blue?", tools=[google_search_retriever_tool]
149+
)
150+
assert response.text
151+
137152
# Chat
138153

139154
def test_send_message_from_text(self):

tests/unit/vertexai/test_generative_models.py

+76-5
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,27 @@ def mock_generate_content(
121121
contents: Optional[MutableSequence[gapic_content_types.Content]] = None,
122122
) -> Iterable[gapic_prediction_service_types.GenerateContentResponse]:
123123
is_continued_chat = len(request.contents) > 1
124-
has_tools = bool(request.tools)
124+
has_retrieval = any(
125+
tool.retrieval or tool.google_search_retrieval for tool in request.tools
126+
)
127+
has_function_declarations = any(
128+
tool.function_declarations for tool in request.tools
129+
)
130+
has_function_request = any(
131+
content.parts[0].function_call for content in request.contents
132+
)
133+
has_function_response = any(
134+
content.parts[0].function_response for content in request.contents
135+
)
125136

126-
if has_tools:
127-
has_function_response = any(
128-
"function_response" in content.parts[0] for content in request.contents
129-
)
137+
if has_function_request:
138+
assert has_function_response
139+
140+
if has_function_response:
141+
assert has_function_request
142+
assert has_function_declarations
143+
144+
if has_function_declarations:
130145
needs_function_call = not has_function_response
131146
if needs_function_call:
132147
response_part_struct = _RESPONSE_FUNCTION_CALL_PART_STRUCT
@@ -158,6 +173,24 @@ def mock_generate_content(
158173
gapic_content_types.Citation(_RESPONSE_CITATION_STRUCT),
159174
]
160175
),
176+
grounding_metadata=gapic_content_types.GroundingMetadata(
177+
web_search_queries=[request.contents[0].parts[0].text],
178+
grounding_attributions=[
179+
gapic_content_types.GroundingAttribution(
180+
segment=gapic_content_types.Segment(
181+
start_index=0,
182+
end_index=67,
183+
),
184+
confidence_score=0.69857746,
185+
web=gapic_content_types.GroundingAttribution.Web(
186+
uri="https://math.ucr.edu/home/baez/physics/General/BlueSky/blue_sky.html",
187+
title="Why is the sky blue? - UCR Math",
188+
),
189+
),
190+
],
191+
)
192+
if has_retrieval and request.contents[0].parts[0].text
193+
else None,
161194
),
162195
],
163196
)
@@ -288,3 +321,41 @@ def test_chat_function_calling(self, generative_models: generative_models):
288321
),
289322
)
290323
assert response2.text == "The weather in Boston is super nice!"
324+
325+
@mock.patch.object(
326+
target=prediction_service.PredictionServiceClient,
327+
attribute="generate_content",
328+
new=mock_generate_content,
329+
)
330+
def test_generate_content_grounding_google_search_retriever(self):
331+
model = preview_generative_models.GenerativeModel("gemini-pro")
332+
google_search_retriever_tool = (
333+
preview_generative_models.Tool.from_google_search_retrieval(
334+
preview_generative_models.grounding.GoogleSearchRetrieval(
335+
disable_attribution=False
336+
)
337+
)
338+
)
339+
response = model.generate_content(
340+
"Why is sky blue?", tools=[google_search_retriever_tool]
341+
)
342+
assert response.text
343+
344+
@mock.patch.object(
345+
target=prediction_service.PredictionServiceClient,
346+
attribute="generate_content",
347+
new=mock_generate_content,
348+
)
349+
def test_generate_content_grounding_vertex_ai_search_retriever(self):
350+
model = preview_generative_models.GenerativeModel("gemini-pro")
351+
google_search_retriever_tool = preview_generative_models.Tool.from_retrieval(
352+
retrieval=preview_generative_models.grounding.Retrieval(
353+
source=preview_generative_models.grounding.VertexAISearch(
354+
datastore=f"projects/{_TEST_PROJECT}/locations/global/collections/default_collection/dataStores/test-datastore",
355+
)
356+
)
357+
)
358+
response = model.generate_content(
359+
"Why is sky blue?", tools=[google_search_retriever_tool]
360+
)
361+
assert response.text

vertexai/generative_models/_generative_models.py

+115
Original file line numberDiff line numberDiff line change
@@ -1132,6 +1132,40 @@ def __init__(
11321132
function_declarations=gapic_function_declarations
11331133
)
11341134

1135+
@classmethod
1136+
def from_function_declarations(
1137+
cls,
1138+
function_declarations: List["FunctionDeclaration"],
1139+
):
1140+
gapic_function_declarations = [
1141+
function_declaration._raw_function_declaration
1142+
for function_declaration in function_declarations
1143+
]
1144+
raw_tool = gapic_tool_types.Tool(
1145+
function_declarations=gapic_function_declarations
1146+
)
1147+
return cls._from_gapic(raw_tool=raw_tool)
1148+
1149+
@classmethod
1150+
def from_retrieval(
1151+
cls,
1152+
retrieval: "Retrieval",
1153+
):
1154+
raw_tool = gapic_tool_types.Tool(
1155+
retrieval=retrieval._raw_retrieval
1156+
)
1157+
return cls._from_gapic(raw_tool=raw_tool)
1158+
1159+
@classmethod
1160+
def from_google_search_retrieval(
1161+
cls,
1162+
google_search_retrieval: "GoogleSearchRetrieval",
1163+
):
1164+
raw_tool = gapic_tool_types.Tool(
1165+
google_search_retrieval=google_search_retrieval._raw_google_search_retrieval
1166+
)
1167+
return cls._from_gapic(raw_tool=raw_tool)
1168+
11351169
@classmethod
11361170
def _from_gapic(
11371171
cls,
@@ -1520,6 +1554,87 @@ def _image(self) -> "Image":
15201554
return Image.from_bytes(data=self._raw_part.inline_data.data)
15211555

15221556

1557+
class grounding: # pylint: disable=invalid-name
1558+
"""Grounding namespace."""
1559+
1560+
def __init__(self):
1561+
raise RuntimeError("This class must not be instantiated.")
1562+
1563+
class Retrieval:
1564+
"""Defines a retrieval tool that model can call to access external knowledge."""
1565+
1566+
def __init__(
1567+
self,
1568+
source: Union["grounding.VertexAISearch"],
1569+
disable_attribution: Optional[bool] = None,
1570+
):
1571+
"""Initializes a Retrieval tool.
1572+
1573+
Args:
1574+
source (VertexAISearch):
1575+
Set to use data source powered by Vertex AI Search.
1576+
disable_attribution (bool):
1577+
Optional. Disable using the result from this
1578+
tool in detecting grounding attribution. This
1579+
does not affect how the result is given to the
1580+
model for generation.
1581+
"""
1582+
self._raw_retrieval = gapic_tool_types.Retrieval(
1583+
vertex_ai_search=source._raw_vertex_ai_search,
1584+
disable_attribution=disable_attribution,
1585+
)
1586+
1587+
class VertexAISearch:
1588+
r"""Retrieve from Vertex AI Search datastore for grounding.
1589+
See https://cloud.google.com/vertex-ai-search-and-conversation
1590+
"""
1591+
1592+
def __init__(
1593+
self,
1594+
datastore: str,
1595+
):
1596+
"""Initializes a Vertex AI Search tool.
1597+
1598+
Args:
1599+
datastore (str):
1600+
Required. Fully-qualified Vertex AI Search's
1601+
datastore resource ID.
1602+
projects/<>/locations/<>/collections/<>/dataStores/<>
1603+
"""
1604+
self._raw_vertex_ai_search = gapic_tool_types.VertexAISearch(
1605+
datastore=datastore,
1606+
)
1607+
1608+
class GoogleSearchRetrieval:
1609+
r"""Tool to retrieve public web data for grounding, powered by
1610+
Google Search.
1611+
1612+
Attributes:
1613+
disable_attribution (bool):
1614+
Optional. Disable using the result from this
1615+
tool in detecting grounding attribution. This
1616+
does not affect how the result is given to the
1617+
model for generation.
1618+
"""
1619+
1620+
def __init__(
1621+
self,
1622+
disable_attribution: Optional[bool] = None,
1623+
):
1624+
"""Initializes a Google Search Retrieval tool.
1625+
1626+
Args:
1627+
disable_attribution (bool):
1628+
Optional. Disable using the result from this
1629+
tool in detecting grounding attribution. This
1630+
does not affect how the result is given to the
1631+
model for generation.
1632+
"""
1633+
self._raw_google_search_retrieval = gapic_tool_types.GoogleSearchRetrieval(
1634+
disable_attribution=disable_attribution,
1635+
)
1636+
1637+
15231638
def _to_content(
15241639
value: Union[
15251640
gapic_content_types.Content,

vertexai/preview/generative_models.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# We just want to re-export certain classes
1818
# pylint: disable=g-multiple-import,g-importing-member
1919
from vertexai.generative_models._generative_models import (
20+
grounding,
2021
_PreviewGenerativeModel,
2122
GenerationConfig,
2223
GenerationResponse,
@@ -39,6 +40,7 @@ class GenerativeModel(_PreviewGenerativeModel):
3940

4041

4142
__all__ = [
43+
"grounding",
4244
"GenerationConfig",
4345
"GenerativeModel",
4446
"GenerationResponse",

0 commit comments

Comments
 (0)