Skip to content

Commit 01ba3ca

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: LLM - Exposed the safety attributes
PiperOrigin-RevId: 541035595
1 parent 21e48fe commit 01ba3ca

File tree

2 files changed

+43
-9
lines changed

2 files changed

+43
-9
lines changed

tests/unit/aiplatform/test_language_models.py

+4
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,10 @@ def test_text_generation(self):
570570
)
571571

572572
assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
573+
assert (
574+
response.safety_attributes["Violent"]
575+
== _TEST_TEXT_GENERATION_PREDICTION["safetyAttributes"]["scores"][0]
576+
)
573577

574578
def test_text_generation_ga(self):
575579
"""Tests the text generation model."""

vertexai/language_models/_language_models.py

+39-9
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Classes for working with language models."""
1616

1717
import dataclasses
18-
from typing import Any, List, Optional, Sequence, Union
18+
from typing import Any, Dict, List, Optional, Sequence, Union
1919

2020
from google.cloud import aiplatform
2121
from google.cloud.aiplatform import base
@@ -198,10 +198,19 @@ def tune_model(
198198

199199
@dataclasses.dataclass
200200
class TextGenerationResponse:
201-
"""TextGenerationResponse represents a response of a language model."""
201+
"""TextGenerationResponse represents a response of a language model.
202+
Attributes:
203+
text: The generated text
204+
is_blocked: Whether the the request was blocked.
205+
safety_attributes: Scores for safety attributes.
206+
Learn more about the safety attributes here:
207+
https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions
208+
"""
202209

203210
text: str
204211
_prediction_response: Any
212+
is_blocked: bool = False
213+
safety_attributes: Dict[str, float] = dataclasses.field(default_factory=dict)
205214

206215
def __repr__(self):
207216
return self.text
@@ -289,13 +298,23 @@ def _batch_predict(
289298
parameters=prediction_parameters,
290299
)
291300

292-
return [
293-
TextGenerationResponse(
294-
text=prediction["content"],
295-
_prediction_response=prediction_response,
301+
results = []
302+
for prediction in prediction_response.predictions:
303+
safety_attributes_dict = prediction.get("safetyAttributes", {})
304+
results.append(
305+
TextGenerationResponse(
306+
text=prediction["content"],
307+
_prediction_response=prediction_response,
308+
is_blocked=safety_attributes_dict.get("blocked", False),
309+
safety_attributes=dict(
310+
zip(
311+
safety_attributes_dict.get("categories", []),
312+
safety_attributes_dict.get("scores", []),
313+
)
314+
),
315+
)
296316
)
297-
for prediction in prediction_response.predictions
298-
]
317+
return results
299318

300319

301320
_TextGenerationModel = TextGenerationModel
@@ -690,9 +709,20 @@ def send_message(
690709
parameters=prediction_parameters,
691710
)
692711

712+
prediction = prediction_response.predictions[0]
713+
safety_attributes = prediction["safetyAttributes"]
693714
response_obj = TextGenerationResponse(
694-
text=prediction_response.predictions[0]["candidates"][0]["content"],
715+
text=prediction["candidates"][0]["content"]
716+
if prediction.get("candidates")
717+
else None,
695718
_prediction_response=prediction_response,
719+
is_blocked=safety_attributes.get("blocked", False),
720+
safety_attributes=dict(
721+
zip(
722+
safety_attributes.get("categories", []),
723+
safety_attributes.get("scores", []),
724+
)
725+
),
696726
)
697727
response_text = response_obj.text
698728

0 commit comments

Comments
 (0)