Skip to content

Commit f334321

Browse files
holtskinnercopybara-github
authored andcommitted
feat: Grounding - Allow initialization of grounding.VertexAISearch with full resource name or data store ID, project ID, and location.
PiperOrigin-RevId: 667990245
1 parent fef5e4d commit f334321

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

tests/unit/vertexai/test_generative_models.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -1023,15 +1023,38 @@ def test_generate_content_grounding_google_search_retriever(self):
10231023
)
10241024
def test_generate_content_grounding_vertex_ai_search_retriever(self):
10251025
model = preview_generative_models.GenerativeModel("gemini-pro")
1026-
google_search_retriever_tool = preview_generative_models.Tool.from_retrieval(
1026+
vertex_ai_search_retriever_tool = preview_generative_models.Tool.from_retrieval(
10271027
retrieval=preview_generative_models.grounding.Retrieval(
10281028
source=preview_generative_models.grounding.VertexAISearch(
10291029
datastore=f"projects/{_TEST_PROJECT}/locations/global/collections/default_collection/dataStores/test-datastore",
10301030
)
10311031
)
10321032
)
10331033
response = model.generate_content(
1034-
"Why is sky blue?", tools=[google_search_retriever_tool]
1034+
"Why is sky blue?", tools=[vertex_ai_search_retriever_tool]
1035+
)
1036+
assert response.text
1037+
1038+
@mock.patch.object(
1039+
target=prediction_service.PredictionServiceClient,
1040+
attribute="generate_content",
1041+
new=mock_generate_content,
1042+
)
1043+
def test_generate_content_grounding_vertex_ai_search_retriever_with_project_and_location(
1044+
self,
1045+
):
1046+
model = preview_generative_models.GenerativeModel("gemini-pro")
1047+
vertex_ai_search_retriever_tool = preview_generative_models.Tool.from_retrieval(
1048+
retrieval=preview_generative_models.grounding.Retrieval(
1049+
source=preview_generative_models.grounding.VertexAISearch(
1050+
datastore="test-datastore",
1051+
project=_TEST_PROJECT,
1052+
location="global",
1053+
)
1054+
)
1055+
)
1056+
response = model.generate_content(
1057+
"Why is sky blue?", tools=[vertex_ai_search_retriever_tool]
10351058
)
10361059
assert response.text
10371060

vertexai/generative_models/_generative_models.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import io
2121
import json
2222
import pathlib
23+
import re
2324
from typing import (
2425
Any,
2526
AsyncIterable,
@@ -2273,7 +2274,7 @@ def __init__(
22732274
source (VertexAISearch):
22742275
Set to use data source powered by Vertex AI Search.
22752276
disable_attribution (bool):
2276-
Optional. Disable using the result from this
2277+
Deprecated. Disable using the result from this
22772278
tool in detecting grounding attribution. This
22782279
does not affect how the result is given to the
22792280
model for generation.
@@ -2284,22 +2285,38 @@ def __init__(
22842285
)
22852286

22862287
class VertexAISearch:
2287-
r"""Retrieve from Vertex AI Search datastore for grounding.
2288-
See https://cloud.google.com/vertex-ai-search-and-conversation
2288+
r"""Retrieve from Vertex AI Search data store for grounding.
2289+
See https://cloud.google.com/products/agent-builder
22892290
"""
22902291

22912292
def __init__(
22922293
self,
22932294
datastore: str,
2295+
*,
2296+
project: Optional[str] = None,
2297+
location: Optional[str] = None,
22942298
):
22952299
"""Initializes a Vertex AI Search tool.
22962300
22972301
Args:
22982302
datastore (str):
2299-
Required. Fully-qualified Vertex AI Search's
2300-
datastore resource ID.
2301-
projects/<>/locations/<>/collections/<>/dataStores/<>
2303+
Required. Vertex AI Search data store resource name. Format:
2304+
``projects/{project}/locations/{location}/collections/default_collection/dataStores/{data_store}``
2305+
or ``{data_store}``.
2306+
project (str):
2307+
Optional. Project ID of the data store. Must provide either the full data store resource name or data store id, project ID, and location.
2308+
location (str):
2309+
Optional. Location of the data store. Must provide either the full data store resource name or data store id, project ID, and location.
23022310
"""
2311+
if not re.fullmatch(
2312+
r"^projects/[a-z0-9-]*/locations/[a-z0-9][a-z0-9-]*/collections/[a-z0-9][a-z0-9-_]*/dataStores/[a-z0-9][a-z0-9-_]*$",
2313+
datastore,
2314+
):
2315+
if not project or not location:
2316+
raise ValueError(
2317+
"Must provide either the full data store resource name or data store id, project ID, and location."
2318+
)
2319+
datastore = f"projects/{project}/locations/{location}/collections/default_collection/dataStores/{datastore}"
23032320
self._raw_vertex_ai_search = gapic_tool_types.VertexAISearch(
23042321
datastore=datastore,
23052322
)

0 commit comments

Comments
 (0)