File tree 2 files changed +16
-11
lines changed
2 files changed +16
-11
lines changed Original file line number Diff line number Diff line change 164
164
}
165
165
166
166
_TEST_CHAT_GENERATION_PREDICTION1 = {
167
- "safetyAttributes" : {
168
- "scores" : [],
169
- "blocked" : False ,
170
- "categories" : [],
171
- },
167
+ "safetyAttributes" : [
168
+ {
169
+ "scores" : [],
170
+ "blocked" : False ,
171
+ "categories" : [],
172
+ }
173
+ ],
172
174
"candidates" : [
173
175
{
174
176
"author" : "1" ,
177
179
],
178
180
}
179
181
_TEST_CHAT_GENERATION_PREDICTION2 = {
180
- "safetyAttributes" : {
181
- "scores" : [],
182
- "blocked" : False ,
183
- "categories" : [],
184
- },
182
+ "safetyAttributes" : [
183
+ {
184
+ "scores" : [],
185
+ "blocked" : False ,
186
+ "categories" : [],
187
+ }
188
+ ],
185
189
"candidates" : [
186
190
{
187
191
"author" : "1" ,
Original file line number Diff line number Diff line change @@ -799,7 +799,8 @@ def send_message(
799
799
)
800
800
801
801
prediction = prediction_response .predictions [0 ]
802
- safety_attributes = prediction ["safetyAttributes" ]
802
+ # ! Note: For chat models, the safetyAttributes is a list.
803
+ safety_attributes = prediction ["safetyAttributes" ][0 ]
803
804
response_obj = TextGenerationResponse (
804
805
text = prediction ["candidates" ][0 ]["content" ]
805
806
if prediction .get ("candidates" )
You can’t perform that action at this time.
0 commit comments