
Commit eabe720

Ark-kun authored and copybara-github committed
chore: LLM - Renamed batch_predict function's source_uri to dataset.
This way we can support Pandas DataFrames and local datasets in the future.

PiperOrigin-RevId: 548256736
1 parent aed8c76 commit eabe720
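
After this rename, a GA-surface call looks like the following minimal sketch. The bucket paths, model parameters, and file names below are illustrative placeholders, not values from this commit, and running it would start a real batch prediction job against your project:

from vertexai.language_models import TextGenerationModel

model = TextGenerationModel.from_pretrained("text-bison@001")

# The batch input location is now passed as `dataset` (formerly `source_uri`).
# Both gs:// and bq:// URIs are accepted, per the docstring in the diff below.
job = model.batch_predict(
    dataset="gs://my-bucket/prompts.jsonl",             # placeholder path
    destination_uri_prefix="gs://my-bucket/results/",   # placeholder path
    model_parameters={"temperature": 0.2},
)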

File tree

3 files changed (+55 -14 lines changed):

tests/system/aiplatform/test_language_models.py
tests/unit/aiplatform/test_language_models.py
vertexai/language_models/_language_models.py


tests/system/aiplatform/test_language_models.py (+1 -1)

@@ -168,7 +168,7 @@ def test_batch_prediction(self):
 
         model = TextGenerationModel.from_pretrained("text-bison@001")
         job = model.batch_predict(
-            source_uri=source_uri,
+            dataset=source_uri,
             destination_uri_prefix=destination_uri_prefix,
             model_parameters={"temperature": 0, "top_p": 1, "top_k": 5},
         )

tests/unit/aiplatform/test_language_models.py (+1 -1)

@@ -1311,7 +1311,7 @@ def test_batch_prediction(self):
             attribute="create",
         ) as mock_create:
             model.batch_predict(
-                source_uri="gs://test-bucket/test_table.jsonl",
+                dataset="gs://test-bucket/test_table.jsonl",
                 destination_uri_prefix="gs://test-bucket/results/",
                 model_parameters={"temperature": 0.1},
             )

vertexai/language_models/_language_models.py (+53 -12)

@@ -16,6 +16,7 @@
 
 import dataclasses
 from typing import Any, Dict, List, Optional, Sequence, Union
+import warnings
 
 from google.cloud import aiplatform
 from google.cloud.aiplatform import base

@@ -332,14 +333,14 @@ class _ModelWithBatchPredict(_LanguageModel):
     def batch_predict(
         self,
         *,
-        source_uri: Union[str, List[str]],
+        dataset: Union[str, List[str]],
         destination_uri_prefix: str,
         model_parameters: Optional[Dict] = None,
     ) -> aiplatform.BatchPredictionJob:
         """Starts a batch prediction job with the model.
 
         Args:
-            source_uri: The location of the dataset.
+            dataset: The location of the dataset.
                 `gs://` and `bq://` URIs are supported.
             destination_uri_prefix: The URI prefix for the prediction.
                 `gs://` and `bq://` URIs are supported.

@@ -351,22 +352,22 @@ def batch_predict(
             ValueError: When source or destination URI is not supported.
         """
         arguments = {}
-        first_source_uri = source_uri if isinstance(source_uri, str) else source_uri[0]
+        first_source_uri = dataset if isinstance(dataset, str) else dataset[0]
         if first_source_uri.startswith("gs://"):
-            if not isinstance(source_uri, str):
-                if not all(uri.startswith("gs://") for uri in source_uri):
+            if not isinstance(dataset, str):
+                if not all(uri.startswith("gs://") for uri in dataset):
                     raise ValueError(
-                        f"All URIs in the list must start with 'gs://': {source_uri}"
+                        f"All URIs in the list must start with 'gs://': {dataset}"
                     )
-            arguments["gcs_source"] = source_uri
+            arguments["gcs_source"] = dataset
         elif first_source_uri.startswith("bq://"):
-            if not isinstance(source_uri, str):
+            if not isinstance(dataset, str):
                 raise ValueError(
-                    f"Only single BigQuery source can be specified: {source_uri}"
+                    f"Only single BigQuery source can be specified: {dataset}"
                 )
-            arguments["bigquery_source"] = source_uri
+            arguments["bigquery_source"] = dataset
         else:
-            raise ValueError(f"Unsupported source_uri: {source_uri}")
+            raise ValueError(f"Unsupported source_uri: {dataset}")
 
         if destination_uri_prefix.startswith("gs://"):
             arguments["gcs_destination_prefix"] = destination_uri_prefix

@@ -391,8 +392,48 @@ def batch_predict(
         return job
 
 
+class _PreviewModelWithBatchPredict(_ModelWithBatchPredict):
+    """Model that supports batch prediction."""
+
+    def batch_predict(
+        self,
+        *,
+        destination_uri_prefix: str,
+        dataset: Optional[Union[str, List[str]]] = None,
+        model_parameters: Optional[Dict] = None,
+        **_kwargs: Optional[Dict[str, Any]],
+    ) -> aiplatform.BatchPredictionJob:
+        """Starts a batch prediction job with the model.
+
+        Args:
+            dataset: Required. The location of the dataset.
+                `gs://` and `bq://` URIs are supported.
+            destination_uri_prefix: The URI prefix for the prediction.
+                `gs://` and `bq://` URIs are supported.
+            model_parameters: Model-specific parameters to send to the model.
+            **_kwargs: Deprecated.
+
+        Returns:
+            A `BatchPredictionJob` object
+        Raises:
+            ValueError: When source or destination URI is not supported.
+        """
+        if "source_uri" in _kwargs:
+            warnings.warn("source_uri is deprecated, use dataset instead.")
+            if dataset:
+                raise ValueError("source_uri is deprecated, use dataset instead.")
+            dataset = _kwargs["source_uri"]
+        if not dataset:
+            raise ValueError("dataset must be specified")
+        return super().batch_predict(
+            dataset=dataset,
+            destination_uri_prefix=destination_uri_prefix,
+            model_parameters=model_parameters,
+        )
+
+
 class _PreviewTextGenerationModel(
-    TextGenerationModel, _TunableModelMixin, _ModelWithBatchPredict
+    TextGenerationModel, _TunableModelMixin, _PreviewModelWithBatchPredict
 ):
     """Preview text generation model."""
 
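The preview shim above keeps the old keyword working: passing `source_uri` emits a warning and is forwarded as `dataset`, while passing both raises `ValueError`. Below is a minimal sketch of that behavior, assuming the preview model is imported from `vertexai.preview.language_models`; the gs:// paths are placeholders and the calls would launch real batch prediction jobs:

import warnings

from vertexai.preview.language_models import TextGenerationModel

model = TextGenerationModel.from_pretrained("text-bison@001")

# Old-style keyword: still accepted, but the shim warns and rewrites it to `dataset`.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    job = model.batch_predict(
        source_uri="gs://my-bucket/prompts.jsonl",          # placeholder path
        destination_uri_prefix="gs://my-bucket/results/",   # placeholder path
    )
assert any("source_uri is deprecated" in str(w.message) for w in caught)

# Passing both keywords is rejected by the shim.
try:
    model.batch_predict(
        dataset="gs://my-bucket/prompts.jsonl",
        source_uri="gs://my-bucket/other.jsonl",
        destination_uri_prefix="gs://my-bucket/results/",
    )
except ValueError as exc:
    print(exc)  # "source_uri is deprecated, use dataset instead."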