Skip to content

Commit b5e2c02

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: GenAI - Added the response_schema parameter to the GenerationConfig class
PiperOrigin-RevId: 637930285
1 parent ac17d87 commit b5e2c02

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

tests/system/vertexai/test_generative_models.py

+11
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,16 @@ def get_current_weather(location: str, unit: str = "centigrade"):
7070
"required": ["location"],
7171
}
7272

73+
_RESPONSE_SCHEMA_STRUCT = {
74+
"type": "object",
75+
"properties": {
76+
"location": {
77+
"type": "string",
78+
},
79+
},
80+
"required": ["location"],
81+
}
82+
7383

7484
class TestGenerativeModels(e2e_base.TestEndToEnd):
7585
"""System tests for generative models."""
@@ -174,6 +184,7 @@ def test_generate_content_with_gemini_15_parameters(self):
174184
presence_penalty=0.0,
175185
frequency_penalty=0.0,
176186
response_mime_type="application/json",
187+
response_schema=_RESPONSE_SCHEMA_STRUCT,
177188
),
178189
safety_settings={
179190
generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,

vertexai/generative_models/_generative_models.py

+9
Original file line numberDiff line numberDiff line change
@@ -1194,6 +1194,7 @@ def __init__(
11941194
presence_penalty: Optional[float] = None,
11951195
frequency_penalty: Optional[float] = None,
11961196
response_mime_type: Optional[str] = None,
1197+
response_schema: Optional[Dict[str, Any]] = None,
11971198
):
11981199
r"""Constructs a GenerationConfig object.
11991200
@@ -1216,6 +1217,8 @@ def __init__(
12161217
12171218
The model needs to be prompted to output the appropriate
12181219
response type, otherwise the behavior is undefined.
1220+
response_schema: Output response schema of the genreated candidate text. Only valid when
1221+
response_mime_type is application/json.
12191222
12201223
Usage:
12211224
```
@@ -1232,6 +1235,11 @@ def __init__(
12321235
)
12331236
```
12341237
"""
1238+
if response_schema is None:
1239+
raw_schema = None
1240+
else:
1241+
gapic_schema_dict = _convert_schema_dict_to_gapic(response_schema)
1242+
raw_schema = aiplatform_types.Schema(gapic_schema_dict)
12351243
self._raw_generation_config = gapic_content_types.GenerationConfig(
12361244
temperature=temperature,
12371245
top_p=top_p,
@@ -1242,6 +1250,7 @@ def __init__(
12421250
presence_penalty=presence_penalty,
12431251
frequency_penalty=frequency_penalty,
12441252
response_mime_type=response_mime_type,
1253+
response_schema=raw_schema,
12451254
)
12461255

12471256
@classmethod

0 commit comments

Comments
 (0)