Skip to content

Commit 0537fec

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: GenAI - Grounding - Released VertexAiSearch and Retrieval to GA
PiperOrigin-RevId: 702481252
1 parent 07d6973 commit 0537fec

File tree

2 files changed

+37
-36
lines changed

2 files changed

+37
-36
lines changed

tests/unit/vertexai/test_generative_models.py

+27-24
Original file line numberDiff line numberDiff line change
@@ -1308,20 +1308,13 @@ def test_conversion_methods(self, generative_models: generative_models):
13081308
assert response.to_dict()["candidates"][0]["finish_reason"] == "STOP"
13091309

13101310
@patch_genai_services
1311-
def test_generate_content_grounding_google_search_retriever_preview(self):
1312-
model = preview_generative_models.GenerativeModel("gemini-pro")
1313-
google_search_retriever_tool = (
1314-
preview_generative_models.Tool.from_google_search_retrieval(
1315-
preview_generative_models.grounding.GoogleSearchRetrieval()
1316-
)
1317-
)
1318-
response = model.generate_content(
1319-
"Why is sky blue?", tools=[google_search_retriever_tool]
1320-
)
1321-
assert response.text
1322-
1323-
@patch_genai_services
1324-
def test_generate_content_grounding_google_search_retriever(self):
1311+
@pytest.mark.parametrize(
1312+
"generative_models",
1313+
[generative_models, preview_generative_models],
1314+
)
1315+
def test_generate_content_grounding_google_search_retriever(
1316+
self, generative_models: generative_models
1317+
):
13251318
model = generative_models.GenerativeModel("gemini-pro")
13261319
google_search_retriever_tool = (
13271320
generative_models.Tool.from_google_search_retrieval(
@@ -1334,11 +1327,17 @@ def test_generate_content_grounding_google_search_retriever(self):
13341327
assert response.text
13351328

13361329
@patch_genai_services
1337-
def test_generate_content_grounding_vertex_ai_search_retriever(self):
1338-
model = preview_generative_models.GenerativeModel("gemini-pro")
1339-
vertex_ai_search_retriever_tool = preview_generative_models.Tool.from_retrieval(
1340-
retrieval=preview_generative_models.grounding.Retrieval(
1341-
source=preview_generative_models.grounding.VertexAISearch(
1330+
@pytest.mark.parametrize(
1331+
"generative_models",
1332+
[generative_models, preview_generative_models],
1333+
)
1334+
def test_generate_content_grounding_vertex_ai_search_retriever(
1335+
self, generative_models: generative_models
1336+
):
1337+
model = generative_models.GenerativeModel("gemini-pro")
1338+
vertex_ai_search_retriever_tool = generative_models.Tool.from_retrieval(
1339+
retrieval=generative_models.grounding.Retrieval(
1340+
source=generative_models.grounding.VertexAISearch(
13421341
datastore=f"projects/{_TEST_PROJECT}/locations/global/collections/default_collection/dataStores/test-datastore",
13431342
)
13441343
)
@@ -1349,13 +1348,17 @@ def test_generate_content_grounding_vertex_ai_search_retriever(self):
13491348
assert response.text
13501349

13511350
@patch_genai_services
1351+
@pytest.mark.parametrize(
1352+
"generative_models",
1353+
[generative_models, preview_generative_models],
1354+
)
13521355
def test_generate_content_grounding_vertex_ai_search_retriever_with_project_and_location(
1353-
self,
1356+
self, generative_models: generative_models
13541357
):
1355-
model = preview_generative_models.GenerativeModel("gemini-pro")
1356-
vertex_ai_search_retriever_tool = preview_generative_models.Tool.from_retrieval(
1357-
retrieval=preview_generative_models.grounding.Retrieval(
1358-
source=preview_generative_models.grounding.VertexAISearch(
1358+
model = generative_models.GenerativeModel("gemini-pro")
1359+
vertex_ai_search_retriever_tool = generative_models.Tool.from_retrieval(
1360+
retrieval=generative_models.grounding.Retrieval(
1361+
source=generative_models.grounding.VertexAISearch(
13591362
datastore="test-datastore",
13601363
project=_TEST_PROJECT,
13611364
location="global",

vertexai/generative_models/_generative_models.py

+10-12
Original file line numberDiff line numberDiff line change
@@ -1931,7 +1931,7 @@ def from_function_declarations(
19311931
@classmethod
19321932
def from_retrieval(
19331933
cls,
1934-
retrieval: Union["preview_grounding.Retrieval"],
1934+
retrieval: Union["grounding.Retrieval"],
19351935
) -> "Tool":
19361936
raw_tool = gapic_tool_types.Tool(retrieval=retrieval._raw_retrieval)
19371937
return cls._from_gapic(raw_tool=raw_tool)
@@ -2767,16 +2767,6 @@ def __init__(
27672767
else None
27682768
)
27692769

2770-
2771-
class preview_grounding: # pylint: disable=invalid-name
2772-
"""Grounding namespace (preview)."""
2773-
2774-
__name__ = "grounding"
2775-
__module__ = "vertexai.preview.generative_models"
2776-
2777-
def __init__(self):
2778-
raise RuntimeError("This class must not be instantiated.")
2779-
27802770
class Retrieval:
27812771
"""Defines a retrieval tool that model can call to access external knowledge."""
27822772

@@ -2838,7 +2828,15 @@ def __init__(
28382828
datastore=datastore,
28392829
)
28402830

2841-
GoogleSearchRetrieval = grounding.GoogleSearchRetrieval
2831+
2832+
class preview_grounding(grounding): # pylint: disable=invalid-name
2833+
"""Grounding namespace (preview)."""
2834+
2835+
__name__ = "grounding"
2836+
__module__ = "vertexai.preview.generative_models"
2837+
2838+
def __init__(self):
2839+
raise RuntimeError("This class must not be instantiated.")
28422840

28432841

28442842
def _to_content(

0 commit comments

Comments
 (0)