Skip to content

Commit c2c8a5e

Browse files
Ark-kun and copybara-github authored and committed
fix: LLM - CodeGenerationModel now supports safety attributes
PiperOrigin-RevId: 560967317
1 parent 2a08535 commit c2c8a5e

File tree

2 files changed

+36
-12
lines changed

2 files changed

+36
-12
lines changed

tests/unit/aiplatform/test_language_models.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,9 @@
257257

258258
_TEST_CODE_GENERATION_PREDICTION = {
259259
"safetyAttributes": {
260-
"categories": [],
261-
"blocked": False,
262-
"scores": [],
260+
"blocked": True,
261+
"categories": ["Finance"],
262+
"scores": [0.1],
263263
},
264264
"content": """
265265
```python
@@ -2188,6 +2188,17 @@ def test_code_generation(self):
21882188
temperature=0.2,
21892189
)
21902190
assert response.text == _TEST_CODE_GENERATION_PREDICTION["content"]
2191+
expected_safety_attributes_raw = _TEST_CODE_GENERATION_PREDICTION[
2192+
"safetyAttributes"
2193+
]
2194+
expected_safety_attributes = dict(
2195+
zip(
2196+
expected_safety_attributes_raw["categories"],
2197+
expected_safety_attributes_raw["scores"],
2198+
)
2199+
)
2200+
assert response.safety_attributes == expected_safety_attributes
2201+
assert response.is_blocked == expected_safety_attributes_raw["blocked"]
21912202

21922203
# Validating the parameters
21932204
predict_temperature = 0.1

vertexai/language_models/_language_models.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,26 @@ def predict_streaming(
789789
)
790790

791791

792+
def _parse_text_generation_model_response(
    prediction_response: aiplatform.models.Prediction,
    prediction_idx: int = 0,
) -> TextGenerationResponse:
    """Builds a `TextGenerationResponse` from a raw text_generation prediction.

    Args:
        prediction_response: The raw prediction response returned by the model.
        prediction_idx: Index of the prediction to parse. Defaults to 0.

    Returns:
        A `TextGenerationResponse` carrying the generated text together with
        the blocked flag and the per-category safety scores.
    """
    raw_prediction = prediction_response.predictions[prediction_idx]
    safety_info = raw_prediction.get("safetyAttributes", {})
    # "categories" and "scores" are parallel lists; pair them into a mapping.
    # The `or []` also guards against explicit `None` values in the payload.
    categories = safety_info.get("categories") or []
    scores = safety_info.get("scores") or []
    safety_attributes = {
        category: score for category, score in zip(categories, scores)
    }
    return TextGenerationResponse(
        text=raw_prediction["content"],
        _prediction_response=prediction_response,
        is_blocked=safety_info.get("blocked", False),
        safety_attributes=safety_attributes,
    )
810+
811+
792812
class _ModelWithBatchPredict(_LanguageModel):
793813
"""Model that supports batch prediction."""
794814

@@ -1754,11 +1774,7 @@ def predict(
17541774
instances=[prediction_request.instance],
17551775
parameters=prediction_request.parameters,
17561776
)
1757-
1758-
return TextGenerationResponse(
1759-
text=prediction_response.predictions[0]["content"],
1760-
_prediction_response=prediction_response,
1761-
)
1777+
return _parse_text_generation_model_response(prediction_response)
17621778

17631779
def predict_streaming(
17641780
self,
@@ -1800,10 +1816,7 @@ def predict_streaming(
18001816
predictions=[prediction_dict],
18011817
deployed_model_id="",
18021818
)
1803-
yield TextGenerationResponse(
1804-
text=prediction_dict["content"],
1805-
_prediction_response=prediction_obj,
1806-
)
1819+
yield _parse_text_generation_model_response(prediction_obj)
18071820

18081821

18091822
class _PreviewCodeGenerationModel(CodeGenerationModel, _TunableModelMixin):

0 commit comments

Comments (0)