16
16
17
17
import dataclasses
18
18
from typing import Any , Dict , List , Optional , Sequence , Union
19
+ import warnings
19
20
20
21
from google .cloud import aiplatform
21
22
from google .cloud .aiplatform import base
@@ -332,14 +333,14 @@ class _ModelWithBatchPredict(_LanguageModel):
332
333
def batch_predict (
333
334
self ,
334
335
* ,
335
- source_uri : Union [str , List [str ]],
336
+ dataset : Union [str , List [str ]],
336
337
destination_uri_prefix : str ,
337
338
model_parameters : Optional [Dict ] = None ,
338
339
) -> aiplatform .BatchPredictionJob :
339
340
"""Starts a batch prediction job with the model.
340
341
341
342
Args:
342
- source_uri : The location of the dataset.
343
+ dataset : The location of the dataset.
343
344
`gs://` and `bq://` URIs are supported.
344
345
destination_uri_prefix: The URI prefix for the prediction.
345
346
`gs://` and `bq://` URIs are supported.
@@ -351,22 +352,22 @@ def batch_predict(
351
352
ValueError: When source or destination URI is not supported.
352
353
"""
353
354
arguments = {}
354
- first_source_uri = source_uri if isinstance (source_uri , str ) else source_uri [0 ]
355
+ first_source_uri = dataset if isinstance (dataset , str ) else dataset [0 ]
355
356
if first_source_uri .startswith ("gs://" ):
356
- if not isinstance (source_uri , str ):
357
- if not all (uri .startswith ("gs://" ) for uri in source_uri ):
357
+ if not isinstance (dataset , str ):
358
+ if not all (uri .startswith ("gs://" ) for uri in dataset ):
358
359
raise ValueError (
359
- f"All URIs in the list must start with 'gs://': { source_uri } "
360
+ f"All URIs in the list must start with 'gs://': { dataset } "
360
361
)
361
- arguments ["gcs_source" ] = source_uri
362
+ arguments ["gcs_source" ] = dataset
362
363
elif first_source_uri .startswith ("bq://" ):
363
- if not isinstance (source_uri , str ):
364
+ if not isinstance (dataset , str ):
364
365
raise ValueError (
365
- f"Only single BigQuery source can be specified: { source_uri } "
366
+ f"Only single BigQuery source can be specified: { dataset } "
366
367
)
367
- arguments ["bigquery_source" ] = source_uri
368
+ arguments ["bigquery_source" ] = dataset
368
369
else :
369
- raise ValueError (f"Unsupported source_uri: { source_uri } " )
370
+ raise ValueError (f"Unsupported source_uri: { dataset } " )
370
371
371
372
if destination_uri_prefix .startswith ("gs://" ):
372
373
arguments ["gcs_destination_prefix" ] = destination_uri_prefix
@@ -391,8 +392,48 @@ def batch_predict(
391
392
return job
392
393
393
394
395
+ class _PreviewModelWithBatchPredict (_ModelWithBatchPredict ):
396
+ """Model that supports batch prediction."""
397
+
398
+ def batch_predict (
399
+ self ,
400
+ * ,
401
+ destination_uri_prefix : str ,
402
+ dataset : Optional [Union [str , List [str ]]] = None ,
403
+ model_parameters : Optional [Dict ] = None ,
404
+ ** _kwargs : Optional [Dict [str , Any ]],
405
+ ) -> aiplatform .BatchPredictionJob :
406
+ """Starts a batch prediction job with the model.
407
+
408
+ Args:
409
+ dataset: Required. The location of the dataset.
410
+ `gs://` and `bq://` URIs are supported.
411
+ destination_uri_prefix: The URI prefix for the prediction.
412
+ `gs://` and `bq://` URIs are supported.
413
+ model_parameters: Model-specific parameters to send to the model.
414
+ **_kwargs: Deprecated.
415
+
416
+ Returns:
417
+ A `BatchPredictionJob` object
418
+ Raises:
419
+ ValueError: When source or destination URI is not supported.
420
+ """
421
+ if "source_uri" in _kwargs :
422
+ warnings .warn ("source_uri is deprecated, use dataset instead." )
423
+ if dataset :
424
+ raise ValueError ("source_uri is deprecated, use dataset instead." )
425
+ dataset = _kwargs ["source_uri" ]
426
+ if not dataset :
427
+ raise ValueError ("dataset must be specified" )
428
+ return super ().batch_predict (
429
+ dataset = dataset ,
430
+ destination_uri_prefix = destination_uri_prefix ,
431
+ model_parameters = model_parameters ,
432
+ )
433
+
434
+
394
435
class _PreviewTextGenerationModel (
395
- TextGenerationModel , _TunableModelMixin , _ModelWithBatchPredict
436
+ TextGenerationModel , _TunableModelMixin , _PreviewModelWithBatchPredict
396
437
):
397
438
"""Preview text generation model."""
398
439
0 commit comments