Skip to content

Commit c39334a

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: GenAI - Grounding - Added grounding dynamic_retrieval config to Vertex SDK
PiperOrigin-RevId: 696776459
1 parent 44587ec commit c39334a

File tree

2 files changed

+55
-5
lines changed

2 files changed

+55
-5
lines changed

tests/system/vertexai/test_generative_models.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -427,10 +427,33 @@ def test_generate_content_from_text_and_remote_audio(
427427
assert api_transport in get_client_api_transport(pro_model._prediction_client)
428428

429429
def test_grounding_google_search_retriever(self, api_endpoint_env_name):
430-
model = preview_generative_models.GenerativeModel(GEMINI_MODEL_NAME)
430+
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
431431
google_search_retriever_tool = (
432-
preview_generative_models.Tool.from_google_search_retrieval(
433-
preview_generative_models.grounding.GoogleSearchRetrieval()
432+
generative_models.Tool.from_google_search_retrieval(
433+
generative_models.grounding.GoogleSearchRetrieval()
434+
)
435+
)
436+
response = model.generate_content(
437+
"Why is sky blue?",
438+
tools=[google_search_retriever_tool],
439+
generation_config=generative_models.GenerationConfig(temperature=0),
440+
)
441+
assert (
442+
response.candidates[0].finish_reason
443+
is generative_models.FinishReason.RECITATION
444+
or response.text
445+
)
446+
447+
def test_grounding_google_search_retriever_with_dynamic_retrieval(
448+
self, api_endpoint_env_name
449+
):
450+
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
451+
google_search_retriever_tool = generative_models.Tool.from_google_search_retrieval(
452+
generative_models.grounding.GoogleSearchRetrieval(
453+
generative_models.grounding.DynamicRetrievalConfig(
454+
mode=generative_models.grounding.DynamicRetrievalConfig.Mode.MODE_DYNAMIC,
455+
dynamic_threshold=0.05,
456+
)
434457
)
435458
)
436459
response = model.generate_content(

vertexai/generative_models/_generative_models.py

+29-2
Original file line numberDiff line numberDiff line change
@@ -2745,14 +2745,41 @@ class grounding: # pylint: disable=invalid-name
27452745
def __init__(self):
27462746
raise RuntimeError("This class must not be instantiated.")
27472747

2748+
class DynamicRetrievalConfig:
2749+
"""Config for dynamic retrieval."""
2750+
2751+
Mode = gapic_tool_types.DynamicRetrievalConfig.Mode
2752+
2753+
def __init__(
2754+
self,
2755+
mode: Mode = Mode.MODE_UNSPECIFIED,
2756+
dynamic_threshold: Optional[float] = None,
2757+
):
2758+
"""Initializes a DynamicRetrievalConfig."""
2759+
self._raw_dynamic_retrieval_config = (
2760+
gapic_tool_types.DynamicRetrievalConfig(
2761+
mode=mode,
2762+
dynamic_threshold=dynamic_threshold,
2763+
)
2764+
)
2765+
27482766
class GoogleSearchRetrieval:
27492767
r"""Tool to retrieve public web data for grounding, powered by
27502768
Google Search.
27512769
"""
27522770

2753-
def __init__(self):
2771+
def __init__(
2772+
self,
2773+
dynamic_retrieval_config: Optional[
2774+
"grounding.DynamicRetrievalConfig"
2775+
] = None,
2776+
):
27542777
"""Initializes a Google Search Retrieval tool."""
2755-
self._raw_google_search_retrieval = gapic_tool_types.GoogleSearchRetrieval()
2778+
self._raw_google_search_retrieval = gapic_tool_types.GoogleSearchRetrieval(
2779+
dynamic_retrieval_config=dynamic_retrieval_config._raw_dynamic_retrieval_config
2780+
if dynamic_retrieval_config
2781+
else None
2782+
)
27562783

27572784

27582785
class preview_grounding: # pylint: disable=invalid-name

0 commit comments

Comments
 (0)