Skip to content

Commit 459ba86

Browse files
Ark-kuncopybara-github
authored andcommitted
fix: LLM - Fixed the chat models failing due to safetyAttributes format
PiperOrigin-RevId: 544512632
1 parent 970970e commit 459ba86

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

tests/unit/aiplatform/test_language_models.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,13 @@
164164
}
165165

166166
_TEST_CHAT_GENERATION_PREDICTION1 = {
167-
"safetyAttributes": {
168-
"scores": [],
169-
"blocked": False,
170-
"categories": [],
171-
},
167+
"safetyAttributes": [
168+
{
169+
"scores": [],
170+
"blocked": False,
171+
"categories": [],
172+
}
173+
],
172174
"candidates": [
173175
{
174176
"author": "1",
@@ -177,11 +179,13 @@
177179
],
178180
}
179181
_TEST_CHAT_GENERATION_PREDICTION2 = {
180-
"safetyAttributes": {
181-
"scores": [],
182-
"blocked": False,
183-
"categories": [],
184-
},
182+
"safetyAttributes": [
183+
{
184+
"scores": [],
185+
"blocked": False,
186+
"categories": [],
187+
}
188+
],
185189
"candidates": [
186190
{
187191
"author": "1",

vertexai/language_models/_language_models.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,8 @@ def send_message(
799799
)
800800

801801
prediction = prediction_response.predictions[0]
802-
safety_attributes = prediction["safetyAttributes"]
802+
# ! Note: For chat models, the safetyAttributes is a list.
803+
safety_attributes = prediction["safetyAttributes"][0]
803804
response_obj = TextGenerationResponse(
804805
text=prediction["candidates"][0]["content"]
805806
if prediction.get("candidates")

0 commit comments

Comments
 (0)