
Commit 2a08535

Ark-kun authored and copybara-github committed
fix: LLM - Fixed batch prediction on tuned models
PiperOrigin-RevId: 560910428
1 parent 2e3090b commit 2a08535
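
For orientation, the user-facing flow this commit fixes is batch prediction invoked on a tuned language model rather than on the base publisher model. The sketch below is not code from this commit; the project ID, tuned model resource name, and GCS paths are placeholder assumptions. After the fix, the tuned model's full Model Registry resource name is what gets submitted to the batch prediction service.

# Minimal sketch of the affected flow (hypothetical project, model, and paths;
# not part of this commit).
import vertexai
from vertexai.language_models import TextGenerationModel

vertexai.init(project="my-project", location="us-central1")

# Load a previously tuned model by its Model Registry resource name.
tuned_model = TextGenerationModel.get_tuned_model(
    "projects/my-project/locations/us-central1/models/1234567890"
)

# This is the call that this commit fixes for tuned models: the job is now
# created with the tuned model's fully qualified resource name.
batch_job = tuned_model.batch_predict(
    dataset="gs://my-bucket/prompts.jsonl",
    destination_uri_prefix="gs://my-bucket/batch-results/",
    model_parameters={"temperature": 0.1},
)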

2 files changed: +34, -9 lines


tests/unit/aiplatform/test_language_models.py (+34, -4)
@@ -45,7 +45,7 @@
     artifact as gca_artifact,
     prediction_service as gca_prediction_service,
     context as gca_context,
-    endpoint as gca_endpoint,
+    endpoint_v1 as gca_endpoint,
     pipeline_job as gca_pipeline_job,
     pipeline_state as gca_pipeline_state,
     deployed_model_ref_v1,

@@ -1030,6 +1030,11 @@ def get_endpoint_mock():
         get_endpoint_mock.return_value = gca_endpoint.Endpoint(
             display_name="test-display-name",
             name=test_constants.EndpointConstants._TEST_ENDPOINT_NAME,
+            deployed_models=[
+                gca_endpoint.DeployedModel(
+                    model=test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
+                ),
+            ],
         )
         yield get_endpoint_mock

@@ -2420,7 +2425,10 @@ def test_text_embedding_ga(self):
         assert len(vector) == _TEXT_EMBEDDING_VECTOR_LENGTH
         assert vector == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"]

-    def test_batch_prediction(self):
+    def test_batch_prediction(
+        self,
+        get_endpoint_mock,
+    ):
         """Tests batch prediction."""
         aiplatform.init(
             project=_TEST_PROJECT,

@@ -2447,7 +2455,29 @@ def test_batch_prediction(self):
             model_parameters={"temperature": 0.1},
         )
         mock_create.assert_called_once_with(
-            model_name="publishers/google/models/text-bison@001",
+            model_name=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/text-bison@001",
+            job_display_name=None,
+            gcs_source="gs://test-bucket/test_table.jsonl",
+            gcs_destination_prefix="gs://test-bucket/results/",
+            model_parameters={"temperature": 0.1},
+        )
+
+        # Testing tuned model batch prediction
+        tuned_model = language_models.TextGenerationModel(
+            model_id=model._model_id,
+            endpoint_name=test_constants.EndpointConstants._TEST_ENDPOINT_NAME,
+        )
+        with mock.patch.object(
+            target=aiplatform.BatchPredictionJob,
+            attribute="create",
+        ) as mock_create:
+            tuned_model.batch_predict(
+                dataset="gs://test-bucket/test_table.jsonl",
+                destination_uri_prefix="gs://test-bucket/results/",
+                model_parameters={"temperature": 0.1},
+            )
+        mock_create.assert_called_once_with(
+            model_name=test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME,
             job_display_name=None,
             gcs_source="gs://test-bucket/test_table.jsonl",
             gcs_destination_prefix="gs://test-bucket/results/",

@@ -2481,7 +2511,7 @@ def test_batch_prediction_for_text_embedding(self):
             model_parameters={},
         )
         mock_create.assert_called_once_with(
-            model_name="publishers/google/models/textembedding-gecko@001",
+            model_name=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/textembedding-gecko@001",
             job_display_name=None,
             gcs_source="gs://test-bucket/test_table.jsonl",
             gcs_destination_prefix="gs://test-bucket/results/",
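
The new deployed_models entry on the mocked endpoint lets the test exercise how a tuned model's underlying Model Registry resource is found from its serving endpoint. Below is a rough, hypothetical sketch of that lookup using the public aiplatform.Endpoint surface; it mirrors the mocked proto fields above and is not the SDK's actual internal code.

# Hypothetical sketch: resolving a tuned model's resource name from its
# endpoint. Field names follow the Endpoint proto mocked above.
from google.cloud import aiplatform

def resolve_tuned_model_resource_name(endpoint_name: str) -> str:
    endpoint = aiplatform.Endpoint(endpoint_name)
    deployed = endpoint.gca_resource.deployed_models
    if not deployed:
        raise ValueError(f"No model is deployed to endpoint {endpoint_name}")
    # DeployedModel.model holds the Model Registry resource name, e.g.
    # "projects/.../locations/.../models/123".
    return deployed[0].model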

vertexai/language_models/_language_models.py (-5)
@@ -839,11 +839,6 @@ def batch_predict(
             raise ValueError(f"Unsupported destination_uri: {destination_uri_prefix}")

         model_name = self._model_resource_name
-        # TODO(b/284512065): Batch prediction service does not support
-        # fully qualified publisher model names yet
-        publishers_index = model_name.index("/publishers/")
-        if publishers_index > 0:
-            model_name = model_name[publishers_index + 1 :]

         job = aiplatform.BatchPredictionJob.create(
             model_name=model_name,
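
The deleted block was a workaround that trimmed the projects/{project}/locations/{location}/ prefix off publisher model names before creating the batch prediction job; a tuned model's Model Registry resource name typically has no /publishers/ segment, which is presumably why batch prediction on tuned models failed. With the workaround gone, model_name is passed to aiplatform.BatchPredictionJob.create() unchanged. A small illustration of the old string handling, with a hypothetical resource name:

# Illustration of the removed workaround (hypothetical value, not commit code).
model_name = (
    "projects/my-project/locations/us-central1/publishers/google/models/text-bison@001"
)

# Old behavior: truncate to the publisher-relative name.
publishers_index = model_name.index("/publishers/")  # raises ValueError for tuned models
if publishers_index > 0:
    model_name = model_name[publishers_index + 1 :]
# model_name is now "publishers/google/models/text-bison@001"

# New behavior: the fully qualified name is passed through unchanged.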
