@@ -59,6 +59,13 @@ def _get_model_id_from_tuning_model_id(tuning_model_id: str) -> str:
59
59
return f"publishers/google/models/{ model_name } @{ version } "
60
60
61
61
62
+ @dataclasses .dataclass
63
+ class _PredictionRequest :
64
+ """A single-instance prediction request."""
65
+ instance : Dict [str , Any ]
66
+ parameters : Optional [Dict [str , Any ]] = None
67
+
68
+
62
69
class _LanguageModel (_model_garden_models ._ModelGardenModel ):
63
70
"""_LanguageModel is a base class for all language models."""
64
71
@@ -1250,15 +1257,15 @@ class CodeGenerationModel(_LanguageModel):
1250
1257
_LAUNCH_STAGE = _model_garden_models ._SDK_GA_LAUNCH_STAGE
1251
1258
_DEFAULT_MAX_OUTPUT_TOKENS = 128
1252
1259
1253
- def predict (
1260
+ def _create_prediction_request (
1254
1261
self ,
1255
1262
prefix : str ,
1256
1263
suffix : Optional [str ] = None ,
1257
1264
* ,
1258
1265
max_output_tokens : Optional [int ] = _DEFAULT_MAX_OUTPUT_TOKENS ,
1259
1266
temperature : Optional [float ] = None ,
1260
- ) -> "TextGenerationResponse" :
1261
- """Gets model response for a single prompt .
1267
+ ) -> _PredictionRequest :
1268
+ """Creates a code generation prediction request .
1262
1269
1263
1270
Args:
1264
1271
prefix: Code before the current point.
@@ -1281,16 +1288,89 @@ def predict(
1281
1288
if max_output_tokens :
1282
1289
prediction_parameters ["maxOutputTokens" ] = max_output_tokens
1283
1290
1291
+ return _PredictionRequest (instance = instance , parameters = prediction_parameters )
1292
+
1293
+ def predict (
1294
+ self ,
1295
+ prefix : str ,
1296
+ suffix : Optional [str ] = None ,
1297
+ * ,
1298
+ max_output_tokens : Optional [int ] = _DEFAULT_MAX_OUTPUT_TOKENS ,
1299
+ temperature : Optional [float ] = None ,
1300
+ ) -> "TextGenerationResponse" :
1301
+ """Gets model response for a single prompt.
1302
+
1303
+ Args:
1304
+ prefix: Code before the current point.
1305
+ suffix: Code after the current point.
1306
+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
1307
+ temperature: Controls the randomness of predictions. Range: [0, 1].
1308
+
1309
+ Returns:
1310
+ A `TextGenerationResponse` object that contains the text produced by the model.
1311
+ """
1312
+ prediction_request = self ._create_prediction_request (
1313
+ prefix = prefix ,
1314
+ suffix = suffix ,
1315
+ max_output_tokens = max_output_tokens ,
1316
+ temperature = temperature ,
1317
+ )
1318
+
1284
1319
prediction_response = self ._endpoint .predict (
1285
- instances = [instance ],
1286
- parameters = prediction_parameters ,
1320
+ instances = [prediction_request . instance ],
1321
+ parameters = prediction_request . parameters ,
1287
1322
)
1288
1323
1289
1324
return TextGenerationResponse (
1290
1325
text = prediction_response .predictions [0 ]["content" ],
1291
1326
_prediction_response = prediction_response ,
1292
1327
)
1293
1328
1329
+ def predict_streaming (
1330
+ self ,
1331
+ prefix : str ,
1332
+ suffix : Optional [str ] = None ,
1333
+ * ,
1334
+ max_output_tokens : Optional [int ] = _DEFAULT_MAX_OUTPUT_TOKENS ,
1335
+ temperature : Optional [float ] = None ,
1336
+ ) -> Iterator [TextGenerationResponse ]:
1337
+ """Predicts the code based on previous code.
1338
+
1339
+ The result is a stream (generator) of partial responses.
1340
+
1341
+ Args:
1342
+ prefix: Code before the current point.
1343
+ suffix: Code after the current point.
1344
+ max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
1345
+ temperature: Controls the randomness of predictions. Range: [0, 1].
1346
+
1347
+ Yields:
1348
+ A stream of `TextGenerationResponse` objects that contain partial
1349
+ responses produced by the model.
1350
+ """
1351
+ prediction_request = self ._create_prediction_request (
1352
+ prefix = prefix ,
1353
+ suffix = suffix ,
1354
+ max_output_tokens = max_output_tokens ,
1355
+ temperature = temperature ,
1356
+ )
1357
+
1358
+ prediction_service_client = self ._endpoint ._prediction_client
1359
+ for prediction_dict in _streaming_prediction .predict_stream_of_dicts_from_single_dict (
1360
+ prediction_service_client = prediction_service_client ,
1361
+ endpoint_name = self ._endpoint_name ,
1362
+ instance = prediction_request .instance ,
1363
+ parameters = prediction_request .parameters ,
1364
+ ):
1365
+ prediction_obj = aiplatform .models .Prediction (
1366
+ predictions = [prediction_dict ],
1367
+ deployed_model_id = "" ,
1368
+ )
1369
+ yield TextGenerationResponse (
1370
+ text = prediction_dict ["content" ],
1371
+ _prediction_response = prediction_obj ,
1372
+ )
1373
+
1294
1374
1295
1375
class _PreviewCodeGenerationModel (CodeGenerationModel , _TunableModelMixin ):
1296
1376
_LAUNCH_STAGE = _model_garden_models ._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
0 commit comments