Skip to content

Commit 0599ca1

Browse files
holtskinner authored and copybara-github committed
feat: Add additional parameters for GenerationConfig
- `presence_penalty` - `frequency_penalty` - `response_mime_type` PiperOrigin-RevId: 627067043
1 parent c0e7acc commit 0599ca1

File tree

3 files changed

+75
-1
lines changed

3 files changed

+75
-1
lines changed

tests/system/vertexai/test_generative_models.py

+27
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# pylint: disable=protected-access, g-multiple-import
1818
"""System tests for generative models."""
1919

20+
import json
2021
import pytest
2122

2223
# Google imports
@@ -30,6 +31,7 @@
3031

3132
GEMINI_MODEL_NAME = "gemini-1.0-pro-002"
3233
GEMINI_VISION_MODEL_NAME = "gemini-1.0-pro-vision"
34+
GEMINI_15_MODEL_NAME = "gemini-1.5-pro-preview-0409"
3335

3436

3537
# A dummy function for function calling
@@ -150,6 +152,31 @@ def test_generate_content_with_parameters(self):
150152
)
151153
assert response.text
152154

155+
def test_generate_content_with_gemini_15_parameters(self):
156+
model = generative_models.GenerativeModel(GEMINI_15_MODEL_NAME)
157+
response = model.generate_content(
158+
contents="Why is sky blue? Respond in JSON Format.",
159+
generation_config=generative_models.GenerationConfig(
160+
temperature=0,
161+
top_p=0.95,
162+
top_k=20,
163+
candidate_count=1,
164+
max_output_tokens=100,
165+
stop_sequences=["STOP!"],
166+
presence_penalty=0.0,
167+
frequency_penalty=0.0,
168+
response_mime_type="application/json",
169+
),
170+
safety_settings={
171+
generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
172+
generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
173+
generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
174+
generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_NONE,
175+
},
176+
)
177+
assert response.text
178+
assert json.loads(response.text)
179+
153180
def test_generate_content_from_list_of_content_dict(self):
154181
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
155182
response = model.generate_content(

tests/unit/vertexai/test_generative_models.py

+29
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,8 @@ def test_generate_content(self, generative_models: generative_models):
404404
candidate_count=1,
405405
max_output_tokens=200,
406406
stop_sequences=["\n\n\n"],
407+
presence_penalty=0.0,
408+
frequency_penalty=0.0,
407409
),
408410
safety_settings=[
409411
generative_models.SafetySetting(
@@ -420,6 +422,33 @@ def test_generate_content(self, generative_models: generative_models):
420422
)
421423
assert response2.text
422424

425+
model3 = generative_models.GenerativeModel("gemini-1.5-pro-preview-0409")
426+
response3 = model3.generate_content(
427+
"Why is sky blue? Respond in JSON.",
428+
generation_config=generative_models.GenerationConfig(
429+
temperature=0.2,
430+
top_p=0.9,
431+
top_k=20,
432+
candidate_count=1,
433+
max_output_tokens=200,
434+
stop_sequences=["\n\n\n"],
435+
response_mime_type="application/json",
436+
),
437+
safety_settings=[
438+
generative_models.SafetySetting(
439+
category=generative_models.SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
440+
threshold=generative_models.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
441+
method=generative_models.SafetySetting.HarmBlockMethod.SEVERITY,
442+
),
443+
generative_models.SafetySetting(
444+
category=generative_models.SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
445+
threshold=generative_models.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH,
446+
method=generative_models.SafetySetting.HarmBlockMethod.PROBABILITY,
447+
),
448+
],
449+
)
450+
assert response3.text
451+
423452
@mock.patch.object(
424453
target=prediction_service.PredictionServiceClient,
425454
attribute="stream_generate_content",

vertexai/generative_models/_generative_models.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -1189,6 +1189,9 @@ def __init__(
11891189
candidate_count: Optional[int] = None,
11901190
max_output_tokens: Optional[int] = None,
11911191
stop_sequences: Optional[List[str]] = None,
1192+
presence_penalty: Optional[float] = None,
1193+
frequency_penalty: Optional[float] = None,
1194+
response_mime_type: Optional[str] = None,
11921195
):
11931196
r"""Constructs a GenerationConfig object.
11941197
@@ -1199,6 +1202,18 @@ def __init__(
11991202
candidate_count: Number of candidates to generate.
12001203
max_output_tokens: The maximum number of output tokens to generate per message.
12011204
stop_sequences: A list of stop sequences.
1205+
presence_penalty: Positive values penalize tokens that have appeared in the generated text,
1206+
thus increasing the possibility of generating more diversed topics. Range: [-2.0, 2.0]
1207+
frequency_penalty: Positive values penalize tokens that repeatedly appear in the generated
1208+
text, thus decreasing the possibility of repeating the same content. Range: [-2.0, 2.0]
1209+
response_mime_type: Output response mimetype of the generated
1210+
candidate text. Supported mimetypes:
1211+
1212+
- ``text/plain``: (default) Text output.
1213+
- ``application/json``: JSON response in the candidates.
1214+
1215+
The model needs to be prompted to output the appropriate
1216+
response type, otherwise the behavior is undefined.
12021217
12031218
Usage:
12041219
```
@@ -1222,6 +1237,9 @@ def __init__(
12221237
candidate_count=candidate_count,
12231238
max_output_tokens=max_output_tokens,
12241239
stop_sequences=stop_sequences,
1240+
presence_penalty=presence_penalty,
1241+
frequency_penalty=frequency_penalty,
1242+
response_mime_type=response_mime_type,
12251243
)
12261244

12271245
@classmethod
@@ -1650,7 +1668,7 @@ def prompt_feedback(
16501668

16511669
@property
16521670
def usage_metadata(
1653-
self
1671+
self,
16541672
) -> gapic_prediction_service_types.GenerateContentResponse.UsageMetadata:
16551673
return self._raw_response.usage_metadata
16561674

0 commit comments

Comments (0)