Skip to content

Commit 317ab8f

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: GenAI - Added support for SafetySetting.method (probability or severity)
PiperOrigin-RevId: 621331507
1 parent 6770625 commit 317ab8f

File tree

4 files changed

+84
-9
lines changed

4 files changed

+84
-9
lines changed

tests/unit/vertexai/test_generative_models.py

+12
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,18 @@ def test_generate_content(self, generative_models: generative_models):
307307
max_output_tokens=200,
308308
stop_sequences=["\n\n\n"],
309309
),
310+
safety_settings=[
311+
generative_models.SafetySetting(
312+
category=generative_models.SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
313+
threshold=generative_models.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
314+
method=generative_models.SafetySetting.HarmBlockMethod.SEVERITY,
315+
),
316+
generative_models.SafetySetting(
317+
category=generative_models.SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
318+
threshold=generative_models.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH,
319+
method=generative_models.SafetySetting.HarmBlockMethod.PROBABILITY,
320+
),
321+
],
310322
)
311323
assert response2.text
312324

vertexai/generative_models/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
Image,
3131
Part,
3232
ResponseValidationError,
33+
SafetySetting,
3334
Tool,
3435
)
3536

@@ -47,5 +48,6 @@
4748
"Image",
4849
"Part",
4950
"ResponseValidationError",
51+
"SafetySetting",
5052
"Tool",
5153
]

vertexai/generative_models/_generative_models.py

+68-9
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@
8686
]
8787

8888
SafetySettingsType = Union[
89-
List[gapic_content_types.SafetySetting],
89+
List["SafetySetting"],
9090
Dict[
9191
gapic_content_types.HarmCategory,
9292
gapic_content_types.SafetySetting.HarmBlockThreshold,
@@ -258,17 +258,20 @@ def _prepare_request(
258258
raise TypeError(
259259
"generation_config must either be a GenerationConfig object or a dictionary representation of it."
260260
)
261+
261262
gapic_safety_settings = None
262263
if safety_settings:
263264
if isinstance(safety_settings, Sequence):
264-
if not all(
265-
isinstance(safety_setting, gapic_content_types.SafetySetting)
266-
for safety_setting in safety_settings
267-
):
268-
raise TypeError(
269-
"When passing a list with SafetySettings objects, every item in a list must be a SafetySetting object."
270-
)
271-
gapic_safety_settings = safety_settings
265+
gapic_safety_settings = []
266+
for safety_setting in safety_settings:
267+
if isinstance(safety_setting, gapic_content_types.SafetySetting):
268+
gapic_safety_settings.append(safety_setting)
269+
elif isinstance(safety_setting, SafetySetting):
270+
gapic_safety_settings.append(safety_setting._raw_safety_setting)
271+
else:
272+
raise TypeError(
273+
"When passing a list with SafetySettings objects, every item in a list must be a SafetySetting object."
274+
)
272275
elif isinstance(safety_settings, dict):
273276
gapic_safety_settings = [
274277
gapic_content_types.SafetySetting(
@@ -283,6 +286,7 @@ def _prepare_request(
283286
raise TypeError(
284287
"safety_settings must either be a list of SafetySettings objects or a dictionary mapping from HarmCategory to HarmBlockThreshold."
285288
)
289+
286290
gapic_tools = None
287291
if tools:
288292
gapic_tools = []
@@ -1738,6 +1742,61 @@ def _image(self) -> "Image":
17381742
return Image.from_bytes(data=self._raw_part.inline_data.data)
17391743

17401744

1745+
class SafetySetting:
1746+
"""Parameters for the generation."""
1747+
1748+
HarmCategory = gapic_content_types.HarmCategory
1749+
HarmBlockMethod = gapic_content_types.SafetySetting.HarmBlockMethod
1750+
HarmBlockThreshold = gapic_content_types.SafetySetting.HarmBlockThreshold
1751+
1752+
def __init__(
1753+
self,
1754+
*,
1755+
category: "SafetySetting.HarmCategory",
1756+
threshold: "SafetySetting.HarmBlockThreshold",
1757+
method: Optional["SafetySetting.HarmBlockMethod"] = None,
1758+
):
1759+
r"""Safety settings.
1760+
1761+
Args:
1762+
category: Harm category.
1763+
threshold: The harm block threshold.
1764+
method: Specify if the threshold is used for probability or severity
1765+
score. If not specified, the threshold is used for probability
1766+
score.
1767+
"""
1768+
self._raw_safety_setting = gapic_content_types.SafetySetting(
1769+
category=category,
1770+
threshold=threshold,
1771+
method=method,
1772+
)
1773+
1774+
@classmethod
1775+
def _from_gapic(
1776+
cls,
1777+
raw_safety_setting: gapic_content_types.SafetySetting,
1778+
) -> "SafetySetting":
1779+
response = cls(
1780+
category=raw_safety_setting.category,
1781+
threshold=raw_safety_setting.threshold,
1782+
)
1783+
response._raw_safety_setting = raw_safety_setting
1784+
return response
1785+
1786+
@classmethod
1787+
def from_dict(cls, safety_setting_dict: Dict[str, Any]) -> "SafetySetting":
1788+
raw_safety_setting = gapic_content_types.SafetySetting(
1789+
safety_setting_dict
1790+
)
1791+
return cls._from_gapic(raw_safety_setting=raw_safety_setting)
1792+
1793+
def to_dict(self) -> Dict[str, Any]:
1794+
return type(self._raw_safety_setting).to_dict(self._raw_safety_setting)
1795+
1796+
def __repr__(self):
1797+
return self._raw_safety_setting.__repr__()
1798+
1799+
17411800
class grounding: # pylint: disable=invalid-name
17421801
"""Grounding namespace."""
17431802

vertexai/preview/generative_models.py

+2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
Part,
3535
ResponseBlockedError,
3636
ResponseValidationError,
37+
SafetySetting,
3738
Tool,
3839
)
3940

@@ -64,6 +65,7 @@ class ChatSession(_PreviewChatSession):
6465
"Part",
6566
"ResponseBlockedError",
6667
"ResponseValidationError",
68+
"SafetySetting",
6769
"Tool",
6870
#
6971
]

0 commit comments

Comments
 (0)