@@ -789,6 +789,26 @@ def predict_streaming(
         )
 
 
+def _parse_text_generation_model_response(
+    prediction_response: aiplatform.models.Prediction,
+    prediction_idx: int = 0,
+) -> TextGenerationResponse:
+    """Converts the raw text_generation model response to `TextGenerationResponse`."""
+    prediction = prediction_response.predictions[prediction_idx]
+    safety_attributes_dict = prediction.get("safetyAttributes", {})
+    return TextGenerationResponse(
+        text=prediction["content"],
+        _prediction_response=prediction_response,
+        is_blocked=safety_attributes_dict.get("blocked", False),
+        safety_attributes=dict(
+            zip(
+                safety_attributes_dict.get("categories") or [],
+                safety_attributes_dict.get("scores") or [],
+            )
+        ),
+    )
+
+
 class _ModelWithBatchPredict(_LanguageModel):
     """Model that supports batch prediction."""
 
@@ -1754,11 +1774,7 @@ def predict(
             instances=[prediction_request.instance],
             parameters=prediction_request.parameters,
         )
-
-        return TextGenerationResponse(
-            text=prediction_response.predictions[0]["content"],
-            _prediction_response=prediction_response,
-        )
+        return _parse_text_generation_model_response(prediction_response)
 
     def predict_streaming(
         self,
@@ -1800,10 +1816,7 @@ def predict_streaming(
                 predictions=[prediction_dict],
                 deployed_model_id="",
             )
-            yield TextGenerationResponse(
-                text=prediction_dict["content"],
-                _prediction_response=prediction_obj,
-            )
+            yield _parse_text_generation_model_response(prediction_obj)
 
 
 class _PreviewCodeGenerationModel(CodeGenerationModel, _TunableModelMixin):
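The refactor makes the blocking and streaming paths share one parsing routine, so both now surface the `is_blocked` flag and the category-to-score `safety_attributes` mapping rather than only `text`. As a minimal sketch of what the conversion does, here is the zip step in isolation; the payload keys mirror the diff above, while the concrete values are invented for illustration:

```python
# Illustrative only: a raw prediction dict shaped like the one
# _parse_text_generation_model_response reads. The keys come from the
# diff above; the values are made up for this example.
raw_prediction = {
    "content": "Paris is the capital of France.",
    "safetyAttributes": {
        "blocked": False,
        "categories": ["Finance", "Health"],
        "scores": [0.1, 0.3],
    },
}

safety = raw_prediction.get("safetyAttributes", {})

# The helper pairs each category with its score; the `or []` guards mean a
# payload missing "categories" or "scores" yields an empty mapping instead
# of raising on None.
parsed_attributes = dict(
    zip(safety.get("categories") or [], safety.get("scores") or [])
)
assert parsed_attributes == {"Finance": 0.1, "Health": 0.3}
assert safety.get("blocked", False) is False
```

Callers of both `predict` and `predict_streaming` can then read `response.is_blocked` and `response.safety_attributes` alongside `response.text`, with missing safety data degrading to `False` and an empty dict respectively.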