
Commit a368538

Ark-kun authored and copybara-github committed
feat: LLM - Support for Batch Prediction for the textembedding models (preview)
Usage:

```python
model = TextEmbeddingModel.from_pretrained("textembedding-gecko@001")
job = model.batch_predict(
    dataset="gs://<bucket>/dataset.jsonl",
    destination_uri_prefix="gs://<bucket>/batch_prediction/",
    # Optional:
    model_parameters={},
)
```

PiperOrigin-RevId: 551663844
1 parent 7d72bd1 commit a368538
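The `dataset` argument in the usage above points at a JSON Lines file in Cloud Storage. As an illustration only (the `"content"` key is an assumption based on the text embedding instance schema referenced in the unit-test fixtures below; confirm it against the instance schema for your model version), such a file could be produced like this:

```python
# Hypothetical helper for building a batch input file. The "content" key is an
# assumption drawn from the text embedding instance schema; verify it against
# the instance_schema_uri for your model version before relying on it.
import json

texts = ["What is life?", "How do embeddings differ from generations?"]
with open("dataset.jsonl", "w") as f:
    for text in texts:
        f.write(json.dumps({"content": text}) + "\n")

# Upload dataset.jsonl to gs://<bucket>/dataset.jsonl before calling batch_predict.
```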

File tree

3 files changed: +59 -7 lines changed


tests/system/aiplatform/test_language_models.py

+23 -3

```diff
@@ -19,7 +19,7 @@
 
 from google.cloud import aiplatform
 from google.cloud.aiplatform.compat.types import (
-    job_state_v1beta1 as gca_job_state_v1beta1,
+    job_state as gca_job_state,
 )
 from tests.system.aiplatform import e2e_base
 from vertexai.preview.language_models import (
@@ -160,7 +160,7 @@ def test_tuning(self, shared_state):
         )
         assert tuned_model_response.text
 
-    def test_batch_prediction(self):
+    def test_batch_prediction_for_text_generation(self):
         source_uri = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/batch_prediction_prompts1.jsonl"
         destination_uri_prefix = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/predictions/text-bison@001_"
 
@@ -178,4 +178,24 @@ def test_batch_prediction(self):
         gapic_job = job._gca_resource
         job.delete()
 
-        assert gapic_job.state == gca_job_state_v1beta1.JobState.JOB_STATE_SUCCEEDED
+        assert gapic_job.state == gca_job_state.JobState.JOB_STATE_SUCCEEDED
+
+    def test_batch_prediction_for_textembedding(self):
+        source_uri = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/batch_prediction_prompts1.jsonl"
+        destination_uri_prefix = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/predictions/textembedding-gecko@001_"
+
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        model = TextEmbeddingModel.from_pretrained("textembedding-gecko")
+        job = model.batch_predict(
+            dataset=source_uri,
+            destination_uri_prefix=destination_uri_prefix,
+            model_parameters={},
+        )
+
+        job.wait_for_resource_creation()
+        job.wait()
+        gapic_job = job._gca_resource
+        job.delete()
+
+        assert gapic_job.state == gca_job_state.JobState.JOB_STATE_SUCCEEDED
```
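Outside the test harness, the flow this system test exercises looks roughly like the sketch below. It is a minimal example: the project, location, and bucket paths are placeholders, and the blocking calls mirror what the test does before inspecting the job state.

```python
# Minimal sketch of the batch-embedding flow exercised by the system test
# above; "my-project", "us-central1", and the gs:// paths are placeholders.
from google.cloud import aiplatform
from vertexai.preview.language_models import TextEmbeddingModel

aiplatform.init(project="my-project", location="us-central1")

model = TextEmbeddingModel.from_pretrained("textembedding-gecko@001")

# Batch prediction runs asynchronously and writes results under the
# destination prefix in Cloud Storage.
job = model.batch_predict(
    dataset="gs://my-bucket/embedding_prompts.jsonl",
    destination_uri_prefix="gs://my-bucket/embedding_predictions/",
    model_parameters={},
)

# Block until the job finishes, as the test does, then inspect its state.
job.wait_for_resource_creation()
job.wait()
print(job.state)
```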

tests/unit/aiplatform/test_language_models.py

+35 -1

```diff
@@ -141,7 +141,7 @@
     "version_id": "001",
     "open_source_category": "PROPRIETARY",
     "launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA,
-    "publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/chat-bison@001",
+    "publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/textembedding-gecko@001",
     "predict_schemata": {
         "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml",
         "parameters_schema_uri": "gs://google-cloud-aiplatfrom/schema/predict/params/text_generation_1.0.0.yaml",
@@ -1323,3 +1323,37 @@ def test_batch_prediction(self):
                 gcs_destination_prefix="gs://test-bucket/results/",
                 model_parameters={"temperature": 0.1},
             )
+
+    def test_batch_prediction_for_text_embedding(self):
+        """Tests batch prediction."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_language_models.TextEmbeddingModel.from_pretrained(
+                "textembedding-gecko@001"
+            )
+
+        with mock.patch.object(
+            target=aiplatform.BatchPredictionJob,
+            attribute="create",
+        ) as mock_create:
+            model.batch_predict(
+                dataset="gs://test-bucket/test_table.jsonl",
+                destination_uri_prefix="gs://test-bucket/results/",
+                model_parameters={},
+            )
+            mock_create.assert_called_once_with(
+                model_name="publishers/google/models/textembedding-gecko@001",
+                job_display_name=None,
+                gcs_source="gs://test-bucket/test_table.jsonl",
+                gcs_destination_prefix="gs://test-bucket/results/",
+                model_parameters={},
+            )
```

vertexai/language_models/_language_models.py

+1 -3

```diff
@@ -586,9 +586,7 @@ def get_embeddings(self, texts: List[str]) -> List["TextEmbedding"]:
         ]
 
 
-class _PreviewTextEmbeddingModel(TextEmbeddingModel):
-    """Preview text embedding model."""
-
+class _PreviewTextEmbeddingModel(TextEmbeddingModel, _ModelWithBatchPredict):
     _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
 
 
```
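The `_PreviewTextEmbeddingModel` change above is the whole production-code diff: the class now inherits `batch_predict` from the `_ModelWithBatchPredict` mixin instead of only exposing embedding calls. The mixin itself is not shown in this commit; purely as an illustration, a stripped-down stand-in that matches the call asserted in the unit test would look like the sketch below (the `_model_resource_name` attribute is an assumed placeholder for however the SDK stores the publisher model name).

```python
# Illustrative stand-in only; not the SDK's actual _ModelWithBatchPredict.
# It forwards to aiplatform.BatchPredictionJob.create with the same keyword
# arguments the unit test above asserts on.
from typing import Any, Dict, Optional

from google.cloud import aiplatform


class BatchPredictMixinSketch:
    # Assumed attribute holding the full publisher model resource name,
    # e.g. "publishers/google/models/textembedding-gecko@001".
    _model_resource_name: str

    def batch_predict(
        self,
        *,
        dataset: str,
        destination_uri_prefix: str,
        model_parameters: Optional[Dict[str, Any]] = None,
    ) -> aiplatform.BatchPredictionJob:
        """Starts an asynchronous batch prediction job over a JSONL dataset in GCS."""
        return aiplatform.BatchPredictionJob.create(
            model_name=self._model_resource_name,
            job_display_name=None,
            gcs_source=dataset,
            gcs_destination_prefix=destination_uri_prefix,
            model_parameters=model_parameters,
        )
```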