Skip to content

Commit 2235305

Browse files
Ark-kun and copybara-github
authored and committed
feat: LLM - Added batch prediction
PiperOrigin-RevId: 542106410
1 parent cd67734 commit 2235305

File tree

3 files changed

+126
-2
lines changed

3 files changed

+126
-2
lines changed

tests/system/aiplatform/test_language_models.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
# pylint: disable=protected-access, g-multiple-import
1919

2020
from google.cloud import aiplatform
21+
from google.cloud.aiplatform.compat.types import (
22+
job_state_v1beta1 as gca_job_state_v1beta1,
23+
)
2124
from tests.system.aiplatform import e2e_base
2225
from vertexai.preview.language_models import (
2326
ChatModel,
@@ -144,3 +147,23 @@ def test_tuning(self, shared_state):
144147
top_k=5,
145148
)
146149
assert tuned_model_response.text
150+
151+
def test_batch_prediction(self):
    """Runs a batch prediction job on text-bison@001 and expects it to succeed.

    The job reads prompts from a sample JSONL file in GCS, writes predictions
    under a GCS prefix, and is deleted before its final state is asserted.
    """
    aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

    prompts_uri = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/batch_prediction_prompts1.jsonl"
    predictions_prefix = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/predictions/text-bison@001_"

    text_model = TextGenerationModel.from_pretrained("text-bison@001")
    batch_job = text_model.batch_predict(
        source_uri=prompts_uri,
        destination_uri_prefix=predictions_prefix,
        model_parameters={"temperature": 0, "top_p": 1, "top_k": 5},
    )

    batch_job.wait_for_resource_creation()
    batch_job.wait()

    # Grab the underlying resource first: the job is deleted before the
    # assertion runs.
    finished_resource = batch_job._gca_resource
    batch_job.delete()

    assert (
        finished_resource.state
        == gca_job_state_v1beta1.JobState.JOB_STATE_SUCCEEDED
    )

tests/unit/aiplatform/test_language_models.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,3 +1135,37 @@ def test_text_embedding_ga(self):
11351135
vector = embedding.values
11361136
assert len(vector) == _TEXT_EMBEDDING_VECTOR_LENGTH
11371137
assert vector == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"]
1138+
1139+
def test_batch_prediction(self):
    """Checks the arguments `batch_predict` passes to `BatchPredictionJob.create`."""
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

    # Answer the publisher-model lookup from a canned dict.
    publisher_model = gca_publisher_model.PublisherModel(
        _TEXT_BISON_PUBLISHER_MODEL_DICT
    )
    with mock.patch.object(
        model_garden_service_client.ModelGardenServiceClient,
        "get_publisher_model",
        return_value=publisher_model,
    ):
        model = preview_language_models.TextGenerationModel.from_pretrained(
            "text-bison@001"
        )

    input_uri = "gs://test-bucket/test_table.jsonl"
    output_prefix = "gs://test-bucket/results/"

    with mock.patch.object(aiplatform.BatchPredictionJob, "create") as mock_create:
        model.batch_predict(
            source_uri=input_uri,
            destination_uri_prefix=output_prefix,
            model_parameters={"temperature": 0.1},
        )

    # The GCS URIs must map to the `gcs_*` keyword arguments and the model
    # name must be the short `publishers/...` form.
    mock_create.assert_called_once_with(
        model_name="publishers/google/models/text-bison@001",
        job_display_name=None,
        gcs_source=input_uri,
        gcs_destination_prefix=output_prefix,
        model_parameters={"temperature": 0.1},
    )

vertexai/language_models/_language_models.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,75 @@ def _batch_predict(
320320
_TextGenerationModel = TextGenerationModel
321321

322322

323-
class _PreviewTextGenerationModel(TextGenerationModel, _TunableModelMixin):
324-
"""Tunable text generation model."""
323+
class _ModelWithBatchPredict(_LanguageModel):
    """Model that supports batch prediction."""

    def batch_predict(
        self,
        *,
        source_uri: Union[str, List[str]],
        destination_uri_prefix: str,
        model_parameters: Optional[Dict] = None,
    ) -> aiplatform.BatchPredictionJob:
        """Starts a batch prediction job with the model.

        Args:
            source_uri: The location of the dataset.
                `gs://` and `bq://` URIs are supported.
                A list of URIs is accepted only when every entry is a
                `gs://` URI; a BigQuery source must be a single string.
            destination_uri_prefix: The URI prefix for the prediction.
                `gs://` and `bq://` URIs are supported.
            model_parameters: Model-specific parameters to send to the model.

        Returns:
            A `BatchPredictionJob` object

        Raises:
            ValueError: When source or destination URI is not supported,
                including when `source_uri` is an empty list.
        """
        arguments = {}

        # The scheme of the first URI decides which source argument is used.
        if isinstance(source_uri, str):
            first_source_uri = source_uri
        elif source_uri:
            first_source_uri = source_uri[0]
        else:
            # An empty list has no first element; report it as an unsupported
            # source instead of failing with an IndexError.
            raise ValueError(f"Unsupported source_uri: {source_uri}")

        if first_source_uri.startswith("gs://"):
            if not isinstance(source_uri, str):
                if not all(uri.startswith("gs://") for uri in source_uri):
                    raise ValueError(
                        f"All URIs in the list must start with 'gs://': {source_uri}"
                    )
            arguments["gcs_source"] = source_uri
        elif first_source_uri.startswith("bq://"):
            if not isinstance(source_uri, str):
                raise ValueError(
                    f"Only single BigQuery source can be specified: {source_uri}"
                )
            arguments["bigquery_source"] = source_uri
        else:
            raise ValueError(f"Unsupported source_uri: {source_uri}")

        if destination_uri_prefix.startswith("gs://"):
            arguments["gcs_destination_prefix"] = destination_uri_prefix
        elif destination_uri_prefix.startswith("bq://"):
            arguments["bigquery_destination_prefix"] = destination_uri_prefix
        else:
            raise ValueError(f"Unsupported destination_uri: {destination_uri_prefix}")

        model_name = self._model_resource_name
        # TODO(b/284512065): Batch prediction service does not support
        # fully qualified publisher model names yet
        # `find` returns -1 when the segment is absent, so names without
        # "/publishers/" are passed through unchanged rather than raising.
        publishers_index = model_name.find("/publishers/")
        if publishers_index > 0:
            # Drop everything before "publishers/" (skip the leading slash).
            model_name = model_name[publishers_index + 1 :]

        job = aiplatform.BatchPredictionJob.create(
            model_name=model_name,
            job_display_name=None,
            **arguments,
            model_parameters=model_parameters,
        )
        return job
386+
387+
388+
class _PreviewTextGenerationModel(
389+
TextGenerationModel, _TunableModelMixin, _ModelWithBatchPredict
390+
):
391+
"""Preview text generation model."""
325392

326393
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
327394

0 commit comments

Comments
 (0)