|
15 | 15 | """Classes for working with language models."""
|
16 | 16 |
|
17 | 17 | import dataclasses
|
18 |
| -from typing import Any, List, Optional, Sequence, Union |
| 18 | +from typing import Any, Dict, List, Optional, Sequence, Union |
19 | 19 |
|
20 | 20 | from google.cloud import aiplatform
|
21 | 21 | from google.cloud.aiplatform import base
|
@@ -198,10 +198,19 @@ def tune_model(
|
198 | 198 |
|
199 | 199 | @dataclasses.dataclass
|
200 | 200 | class TextGenerationResponse:
|
201 |
| - """TextGenerationResponse represents a response of a language model.""" |
| 201 | + """TextGenerationResponse represents a response of a language model. |
| 202 | + Attributes: |
| 203 | + text: The generated text |
| 204 | + is_blocked: Whether the the request was blocked. |
| 205 | + safety_attributes: Scores for safety attributes. |
| 206 | + Learn more about the safety attributes here: |
| 207 | + https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions |
| 208 | + """ |
202 | 209 |
|
203 | 210 | text: str
|
204 | 211 | _prediction_response: Any
|
| 212 | + is_blocked: bool = False |
| 213 | + safety_attributes: Dict[str, float] = dataclasses.field(default_factory=dict) |
205 | 214 |
|
206 | 215 | def __repr__(self):
|
207 | 216 | return self.text
|
@@ -289,13 +298,23 @@ def _batch_predict(
|
289 | 298 | parameters=prediction_parameters,
|
290 | 299 | )
|
291 | 300 |
|
292 |
| - return [ |
293 |
| - TextGenerationResponse( |
294 |
| - text=prediction["content"], |
295 |
| - _prediction_response=prediction_response, |
| 301 | + results = [] |
| 302 | + for prediction in prediction_response.predictions: |
| 303 | + safety_attributes_dict = prediction.get("safetyAttributes", {}) |
| 304 | + results.append( |
| 305 | + TextGenerationResponse( |
| 306 | + text=prediction["content"], |
| 307 | + _prediction_response=prediction_response, |
| 308 | + is_blocked=safety_attributes_dict.get("blocked", False), |
| 309 | + safety_attributes=dict( |
| 310 | + zip( |
| 311 | + safety_attributes_dict.get("categories", []), |
| 312 | + safety_attributes_dict.get("scores", []), |
| 313 | + ) |
| 314 | + ), |
| 315 | + ) |
296 | 316 | )
|
297 |
| - for prediction in prediction_response.predictions |
298 |
| - ] |
| 317 | + return results |
299 | 318 |
|
300 | 319 |
|
301 | 320 | _TextGenerationModel = TextGenerationModel
|
@@ -690,9 +709,20 @@ def send_message(
|
690 | 709 | parameters=prediction_parameters,
|
691 | 710 | )
|
692 | 711 |
|
| 712 | + prediction = prediction_response.predictions[0] |
| 713 | + safety_attributes = prediction["safetyAttributes"] |
693 | 714 | response_obj = TextGenerationResponse(
|
694 |
| - text=prediction_response.predictions[0]["candidates"][0]["content"], |
| 715 | + text=prediction["candidates"][0]["content"] |
| 716 | + if prediction.get("candidates") |
| 717 | + else None, |
695 | 718 | _prediction_response=prediction_response,
|
| 719 | + is_blocked=safety_attributes.get("blocked", False), |
| 720 | + safety_attributes=dict( |
| 721 | + zip( |
| 722 | + safety_attributes.get("categories", []), |
| 723 | + safety_attributes.get("scores", []), |
| 724 | + ) |
| 725 | + ), |
696 | 726 | )
|
697 | 727 | response_text = response_obj.text
|
698 | 728 |
|
|
0 commit comments