Skip to content

Commit fbf2f7c

Browse files
Ark-kun authored and copybara-github committed
feat: LLM - Add support for batch prediction to CodeGenerationModel (code-bison)
PiperOrigin-RevId: 609627761
1 parent 0b55762 commit fbf2f7c

File tree

3 files changed

+56
-1
lines changed

3 files changed

+56
-1
lines changed

tests/system/aiplatform/test_language_models.py

+21
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
)
3434
from vertexai.preview.language_models import (
3535
ChatModel,
36+
CodeGenerationModel,
3637
InputOutputTextPair,
3738
TextGenerationModel,
3839
TextGenerationResponse,
@@ -434,6 +435,26 @@ def test_batch_prediction_for_textembedding(self):
434435

435436
assert gapic_job.state == gca_job_state.JobState.JOB_STATE_SUCCEEDED
436437

438+
def test_batch_prediction_for_code_generation(self):
439+
source_uri = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/code-bison.batch_prediction_prompts.1.jsonl"
440+
destination_uri_prefix = "gs://ucaip-samples-us-central1/model/llm/batch_prediction/predictions/code-bison@001_"
441+
442+
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
443+
444+
model = CodeGenerationModel.from_pretrained("code-bison@001")
445+
job = model.batch_predict(
446+
dataset=source_uri,
447+
destination_uri_prefix=destination_uri_prefix,
448+
model_parameters={"temperature": 0},
449+
)
450+
451+
job.wait_for_resource_creation()
452+
job.wait()
453+
gapic_job = job._gca_resource
454+
job.delete()
455+
456+
assert gapic_job.state == gca_job_state.JobState.JOB_STATE_SUCCEEDED
457+
437458
def test_code_generation_streaming(self):
438459
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
439460

tests/unit/aiplatform/test_language_models.py

+30
Original file line numberDiff line numberDiff line change
@@ -4324,6 +4324,36 @@ def test_batch_prediction(
43244324
model_parameters={"temperature": 0.1},
43254325
)
43264326

4327+
def test_batch_prediction_for_code_generation(self):
4328+
"""Tests batch prediction."""
4329+
with mock.patch.object(
4330+
target=model_garden_service_client.ModelGardenServiceClient,
4331+
attribute="get_publisher_model",
4332+
return_value=gca_publisher_model.PublisherModel(
4333+
_CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT
4334+
),
4335+
):
4336+
model = preview_language_models.CodeGenerationModel.from_pretrained(
4337+
"code-bison@001"
4338+
)
4339+
4340+
with mock.patch.object(
4341+
target=aiplatform.BatchPredictionJob,
4342+
attribute="create",
4343+
) as mock_create:
4344+
model.batch_predict(
4345+
dataset="gs://test-bucket/test_table.jsonl",
4346+
destination_uri_prefix="gs://test-bucket/results/",
4347+
model_parameters={},
4348+
)
4349+
mock_create.assert_called_once_with(
4350+
model_name=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/code-bison@001",
4351+
job_display_name=None,
4352+
gcs_source="gs://test-bucket/test_table.jsonl",
4353+
gcs_destination_prefix="gs://test-bucket/results/",
4354+
model_parameters={},
4355+
)
4356+
43274357
def test_batch_prediction_for_text_embedding(self):
43284358
"""Tests batch prediction."""
43294359
aiplatform.init(

vertexai/language_models/_language_models.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -3366,7 +3366,11 @@ def count_tokens(
33663366
)
33673367

33683368

3369-
class CodeGenerationModel(_CodeGenerationModel, _TunableTextModelMixin):
3369+
class CodeGenerationModel(
3370+
_CodeGenerationModel,
3371+
_TunableTextModelMixin,
3372+
_ModelWithBatchPredict,
3373+
):
33703374
pass
33713375

33723376

0 commit comments

Comments (0)