Skip to content

Commit 310ee49

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: GenAI - GAPIC - Added support for Grounding
feat: Add Retrieval feat: Add GoogleSearchRetrieval feat: Add VertexAiSearch feat: Add Tool.retrieval feat: Add Tool.google_search_retrieval feat: Add Candidate.grounding_metadata PiperOrigin-RevId: 606432042
1 parent f821e45 commit 310ee49

File tree

5 files changed

+241
-3
lines changed

5 files changed

+241
-3
lines changed

google/cloud/aiplatform_v1beta1/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,12 @@
9292
from .types.content import Content
9393
from .types.content import FileData
9494
from .types.content import GenerationConfig
95+
from .types.content import GroundingAttribution
96+
from .types.content import GroundingMetadata
9597
from .types.content import Part
9698
from .types.content import SafetyRating
9799
from .types.content import SafetySetting
100+
from .types.content import Segment
98101
from .types.content import VideoMetadata
99102
from .types.content import HarmCategory
100103
from .types.context import Context
@@ -695,7 +698,10 @@
695698
from .types.tool import FunctionCall
696699
from .types.tool import FunctionDeclaration
697700
from .types.tool import FunctionResponse
701+
from .types.tool import GoogleSearchRetrieval
702+
from .types.tool import Retrieval
698703
from .types.tool import Tool
704+
from .types.tool import VertexAISearch
699705
from .types.training_pipeline import FilterSplit
700706
from .types.training_pipeline import FractionSplit
701707
from .types.training_pipeline import InputDataConfig
@@ -1064,6 +1070,9 @@
10641070
"GetTensorboardTimeSeriesRequest",
10651071
"GetTrainingPipelineRequest",
10661072
"GetTrialRequest",
1073+
"GoogleSearchRetrieval",
1074+
"GroundingAttribution",
1075+
"GroundingMetadata",
10671076
"HarmCategory",
10681077
"HyperparameterTuningJob",
10691078
"IdMatcher",
@@ -1294,6 +1303,7 @@
12941303
"RestoreDatasetVersionRequest",
12951304
"ResumeModelDeploymentMonitoringJobRequest",
12961305
"ResumeScheduleRequest",
1306+
"Retrieval",
12971307
"SafetyRating",
12981308
"SafetySetting",
12991309
"SampleConfig",
@@ -1315,6 +1325,7 @@
13151325
"SearchModelDeploymentMonitoringStatsAnomaliesResponse",
13161326
"SearchNearestEntitiesRequest",
13171327
"SearchNearestEntitiesResponse",
1328+
"Segment",
13181329
"ServiceAccountSpec",
13191330
"SmoothGradConfig",
13201331
"SpecialistPool",
@@ -1410,6 +1421,7 @@
14101421
"UpsertDatapointsResponse",
14111422
"UserActionReference",
14121423
"Value",
1424+
"VertexAISearch",
14131425
"VideoMetadata",
14141426
"VizierServiceClient",
14151427
"WorkerPoolSpec",

google/cloud/aiplatform_v1beta1/types/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,12 @@
3636
Content,
3737
FileData,
3838
GenerationConfig,
39+
GroundingAttribution,
40+
GroundingMetadata,
3941
Part,
4042
SafetyRating,
4143
SafetySetting,
44+
Segment,
4245
VideoMetadata,
4346
HarmCategory,
4447
)
@@ -784,7 +787,10 @@
784787
FunctionCall,
785788
FunctionDeclaration,
786789
FunctionResponse,
790+
GoogleSearchRetrieval,
791+
Retrieval,
787792
Tool,
793+
VertexAISearch,
788794
)
789795
from .training_pipeline import (
790796
FilterSplit,
@@ -850,9 +856,12 @@
850856
"Content",
851857
"FileData",
852858
"GenerationConfig",
859+
"GroundingAttribution",
860+
"GroundingMetadata",
853861
"Part",
854862
"SafetyRating",
855863
"SafetySetting",
864+
"Segment",
856865
"VideoMetadata",
857866
"HarmCategory",
858867
"Context",
@@ -1437,7 +1446,10 @@
14371446
"FunctionCall",
14381447
"FunctionDeclaration",
14391448
"FunctionResponse",
1449+
"GoogleSearchRetrieval",
1450+
"Retrieval",
14401451
"Tool",
1452+
"VertexAISearch",
14411453
"FilterSplit",
14421454
"FractionSplit",
14431455
"InputDataConfig",

google/cloud/aiplatform_v1beta1/types/content.py

+125
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
"CitationMetadata",
4040
"Citation",
4141
"Candidate",
42+
"Segment",
43+
"GroundingAttribution",
44+
"GroundingMetadata",
4245
},
4346
)
4447

@@ -503,6 +506,9 @@ class Candidate(proto.Message):
503506
citation_metadata (google.cloud.aiplatform_v1beta1.types.CitationMetadata):
504507
Output only. Source attribution of the
505508
generated content.
509+
grounding_metadata (google.cloud.aiplatform_v1beta1.types.GroundingMetadata):
510+
Output only. Metadata specifies sources used
511+
to ground generated content.
506512
"""
507513

508514
class FinishReason(proto.Enum):
@@ -566,6 +572,125 @@ class FinishReason(proto.Enum):
566572
number=6,
567573
message="CitationMetadata",
568574
)
575+
grounding_metadata: "GroundingMetadata" = proto.Field(
576+
proto.MESSAGE,
577+
number=7,
578+
message="GroundingMetadata",
579+
)
580+
581+
582+
class Segment(proto.Message):
583+
r"""Segment of the content.
584+
585+
Attributes:
586+
part_index (int):
587+
Output only. The index of a Part object
588+
within its parent Content object.
589+
start_index (int):
590+
Output only. Start index in the given Part,
591+
measured in bytes. Offset from the start of the
592+
Part, inclusive, starting at zero.
593+
end_index (int):
594+
Output only. End index in the given Part,
595+
measured in bytes. Offset from the start of the
596+
Part, exclusive, starting at zero.
597+
"""
598+
599+
part_index: int = proto.Field(
600+
proto.INT32,
601+
number=1,
602+
)
603+
start_index: int = proto.Field(
604+
proto.INT32,
605+
number=2,
606+
)
607+
end_index: int = proto.Field(
608+
proto.INT32,
609+
number=3,
610+
)
611+
612+
613+
class GroundingAttribution(proto.Message):
614+
r"""Grounding attribution.
615+
616+
.. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
617+
618+
Attributes:
619+
web (google.cloud.aiplatform_v1beta1.types.GroundingAttribution.Web):
620+
Optional. Attribution from the web.
621+
622+
This field is a member of `oneof`_ ``reference``.
623+
segment (google.cloud.aiplatform_v1beta1.types.Segment):
624+
Output only. Segment of the content this
625+
attribution belongs to.
626+
confidence_score (float):
627+
Optional. Output only. Confidence score of
628+
the attribution. Ranges from 0 to 1. 1 is the
629+
most confident.
630+
631+
This field is a member of `oneof`_ ``_confidence_score``.
632+
"""
633+
634+
class Web(proto.Message):
635+
r"""Attribution from the web.
636+
637+
Attributes:
638+
uri (str):
639+
Output only. URI reference of the
640+
attribution.
641+
title (str):
642+
Output only. Title of the attribution.
643+
"""
644+
645+
uri: str = proto.Field(
646+
proto.STRING,
647+
number=1,
648+
)
649+
title: str = proto.Field(
650+
proto.STRING,
651+
number=2,
652+
)
653+
654+
web: Web = proto.Field(
655+
proto.MESSAGE,
656+
number=3,
657+
oneof="reference",
658+
message=Web,
659+
)
660+
segment: "Segment" = proto.Field(
661+
proto.MESSAGE,
662+
number=1,
663+
message="Segment",
664+
)
665+
confidence_score: float = proto.Field(
666+
proto.FLOAT,
667+
number=2,
668+
optional=True,
669+
)
670+
671+
672+
class GroundingMetadata(proto.Message):
673+
r"""Metadata returned to client when grounding is enabled.
674+
675+
Attributes:
676+
web_search_queries (MutableSequence[str]):
677+
Optional. Web search queries for the
678+
following-up web search.
679+
grounding_attributions (MutableSequence[google.cloud.aiplatform_v1beta1.types.GroundingAttribution]):
680+
Optional. List of grounding attributions.
681+
"""
682+
683+
web_search_queries: MutableSequence[str] = proto.RepeatedField(
684+
proto.STRING,
685+
number=1,
686+
)
687+
grounding_attributions: MutableSequence[
688+
"GroundingAttribution"
689+
] = proto.RepeatedField(
690+
proto.MESSAGE,
691+
number=2,
692+
message="GroundingAttribution",
693+
)
569694

570695

571696
__all__ = tuple(sorted(__protobuf__.manifest))

google/cloud/aiplatform_v1beta1/types/prediction_service.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -784,8 +784,7 @@ class GenerateContentRequest(proto.Message):
784784
785785
A ``Tool`` is a piece of code that enables the system to
786786
interact with external systems to perform an action, or set
787-
of actions, outside of knowledge and scope of the model. The
788-
only supported tool is currently ``Function``
787+
of actions, outside of knowledge and scope of the model.
789788
safety_settings (MutableSequence[google.cloud.aiplatform_v1beta1.types.SafetySetting]):
790789
Optional. Per request settings for blocking
791790
unsafe content. Enforced on

google/cloud/aiplatform_v1beta1/types/tool.py

+91-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
"FunctionDeclaration",
3131
"FunctionCall",
3232
"FunctionResponse",
33+
"Retrieval",
34+
"VertexAISearch",
35+
"GoogleSearchRetrieval",
3336
},
3437
)
3538

@@ -39,7 +42,8 @@ class Tool(proto.Message):
3942
4043
A ``Tool`` is a piece of code that enables the system to interact
4144
with external systems to perform an action, or set of actions,
42-
outside of knowledge and scope of the model.
45+
outside of knowledge and scope of the model. A Tool object should
46+
contain exactly one type of Tool.
4347
4448
Attributes:
4549
function_declarations (MutableSequence[google.cloud.aiplatform_v1beta1.types.FunctionDeclaration]):
@@ -52,13 +56,32 @@ class Tool(proto.Message):
5256
function call in the next turn. Based on the function
5357
responses, Model will generate the final response back to
5458
the user. Maximum 64 function declarations can be provided.
59+
retrieval (google.cloud.aiplatform_v1beta1.types.Retrieval):
60+
Optional. System will always execute the
61+
provided retrieval tool(s) to get external
62+
knowledge to answer the prompt. Retrieval
63+
results are presented to the model for
64+
generation.
65+
google_search_retrieval (google.cloud.aiplatform_v1beta1.types.GoogleSearchRetrieval):
66+
Optional. Specialized retrieval tool that is
67+
powered by Google search.
5568
"""
5669

5770
function_declarations: MutableSequence["FunctionDeclaration"] = proto.RepeatedField(
5871
proto.MESSAGE,
5972
number=1,
6073
message="FunctionDeclaration",
6174
)
75+
retrieval: "Retrieval" = proto.Field(
76+
proto.MESSAGE,
77+
number=2,
78+
message="Retrieval",
79+
)
80+
google_search_retrieval: "GoogleSearchRetrieval" = proto.Field(
81+
proto.MESSAGE,
82+
number=3,
83+
message="GoogleSearchRetrieval",
84+
)
6285

6386

6487
class FunctionDeclaration(proto.Message):
@@ -169,4 +192,71 @@ class FunctionResponse(proto.Message):
169192
)
170193

171194

195+
class Retrieval(proto.Message):
196+
r"""Defines a retrieval tool that model can call to access
197+
external knowledge.
198+
199+
200+
.. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
201+
202+
Attributes:
203+
vertex_ai_search (google.cloud.aiplatform_v1beta1.types.VertexAISearch):
204+
Set to use data source powered by Vertex AI
205+
Search.
206+
207+
This field is a member of `oneof`_ ``source``.
208+
disable_attribution (bool):
209+
Optional. Disable using the result from this
210+
tool in detecting grounding attribution. This
211+
does not affect how the result is given to the
212+
model for generation.
213+
"""
214+
215+
vertex_ai_search: "VertexAISearch" = proto.Field(
216+
proto.MESSAGE,
217+
number=2,
218+
oneof="source",
219+
message="VertexAISearch",
220+
)
221+
disable_attribution: bool = proto.Field(
222+
proto.BOOL,
223+
number=3,
224+
)
225+
226+
227+
class VertexAISearch(proto.Message):
228+
r"""Retrieve from Vertex AI Search datastore for grounding.
229+
See https://cloud.google.com/vertex-ai-search-and-conversation
230+
231+
Attributes:
232+
datastore (str):
233+
Required. Fully-qualified Vertex AI Search's
234+
datastore resource ID.
235+
projects/<>/locations/<>/collections/<>/dataStores/<>
236+
"""
237+
238+
datastore: str = proto.Field(
239+
proto.STRING,
240+
number=1,
241+
)
242+
243+
244+
class GoogleSearchRetrieval(proto.Message):
245+
r"""Tool to retrieve public web data for grounding, powered by
246+
Google.
247+
248+
Attributes:
249+
disable_attribution (bool):
250+
Optional. Disable using the result from this
251+
tool in detecting grounding attribution. This
252+
does not affect how the result is given to the
253+
model for generation.
254+
"""
255+
256+
disable_attribution: bool = proto.Field(
257+
proto.BOOL,
258+
number=1,
259+
)
260+
261+
172262
__all__ = tuple(sorted(__protobuf__.manifest))

0 commit comments

Comments
 (0)