Skip to content

Commit 6166152

Browse files
jaycee-li authored and copybara-github committed
feat: GenAI - Added Llama3 support in GenAI batch prediction
PiperOrigin-RevId: 669193397
1 parent 72fcc06 commit 6166152

File tree

2 files changed

+90
-2
lines changed

2 files changed

+90
-2
lines changed

tests/unit/vertexai/test_batch_prediction.py

+77-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME = "projects/123/locations/us-central1/models/456"
5151
_TEST_PALM_MODEL_NAME = "text-bison"
5252
_TEST_PALM_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_PALM_MODEL_NAME}"
53+
_TEST_LLAMA_MODEL_NAME = "llama3-405b-instruct-maas"
54+
_TEST_LLAMA_MODEL_RESOURCE_NAME = f"publishers/meta/models/{_TEST_LLAMA_MODEL_NAME}"
5355

5456
_TEST_GCS_INPUT_URI = "gs://test-bucket/test-input.jsonl"
5557
_TEST_GCS_INPUT_URI_2 = "gs://test-bucket/test-input-2.jsonl"
@@ -127,6 +129,23 @@ def get_batch_prediction_job_with_gcs_output_mock():
127129
yield get_job_mock
128130

129131

132+
@pytest.fixture
133+
def get_batch_prediction_job_with_llama_model_mock():
134+
with mock.patch.object(
135+
job_service_client.JobServiceClient, "get_batch_prediction_job"
136+
) as get_job_mock:
137+
get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob(
138+
name=_TEST_BATCH_PREDICTION_JOB_NAME,
139+
display_name=_TEST_DISPLAY_NAME,
140+
model=_TEST_LLAMA_MODEL_RESOURCE_NAME,
141+
state=_TEST_JOB_STATE_SUCCESS,
142+
output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo(
143+
gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX
144+
),
145+
)
146+
yield get_job_mock
147+
148+
130149
@pytest.fixture
131150
def get_batch_prediction_job_with_tuned_gemini_model_mock():
132151
with mock.patch.object(
@@ -252,6 +271,16 @@ def test_init_batch_prediction_job(
252271
name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY
253272
)
254273

274+
def test_init_batch_prediction_job_with_llama_model(
275+
self,
276+
get_batch_prediction_job_with_llama_model_mock,
277+
):
278+
batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)
279+
280+
get_batch_prediction_job_with_llama_model_mock.assert_called_once_with(
281+
name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY
282+
)
283+
255284
def test_init_batch_prediction_job_with_tuned_gemini_model(
256285
self,
257286
get_batch_prediction_job_with_tuned_gemini_model_mock,
@@ -447,6 +476,39 @@ def test_submit_batch_prediction_job_with_bq_input_without_output_uri_prefix(
447476
timeout=None,
448477
)
449478

479+
def test_submit_batch_prediction_job_with_llama_model(
480+
self,
481+
create_batch_prediction_job_mock,
482+
):
483+
job = batch_prediction.BatchPredictionJob.submit(
484+
source_model=_TEST_LLAMA_MODEL_RESOURCE_NAME,
485+
input_dataset=_TEST_BQ_INPUT_URI,
486+
)
487+
488+
assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB
489+
490+
expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob(
491+
display_name=_TEST_DISPLAY_NAME,
492+
model=_TEST_LLAMA_MODEL_RESOURCE_NAME,
493+
input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig(
494+
instances_format="bigquery",
495+
bigquery_source=gca_io_compat.BigQuerySource(
496+
input_uri=_TEST_BQ_INPUT_URI
497+
),
498+
),
499+
output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig(
500+
bigquery_destination=gca_io_compat.BigQueryDestination(
501+
output_uri=_TEST_BQ_OUTPUT_PREFIX
502+
),
503+
predictions_format="bigquery",
504+
),
505+
)
506+
create_batch_prediction_job_mock.assert_called_once_with(
507+
parent=_TEST_PARENT,
508+
batch_prediction_job=expected_gapic_batch_prediction_job,
509+
timeout=None,
510+
)
511+
450512
@pytest.mark.usefixtures("create_batch_prediction_job_mock")
451513
def test_submit_batch_prediction_job_with_tuned_model(
452514
self,
@@ -467,14 +529,28 @@ def test_submit_batch_prediction_job_with_invalid_source_model(self):
467529
with pytest.raises(
468530
ValueError,
469531
match=(
470-
f"Model '{_TEST_PALM_MODEL_RESOURCE_NAME}' is not a Generative AI model."
532+
"Abbreviated model names are only supported for Gemini models. "
533+
"Please provide the full publisher model name."
471534
),
472535
):
473536
batch_prediction.BatchPredictionJob.submit(
474537
source_model=_TEST_PALM_MODEL_NAME,
475538
input_dataset=_TEST_GCS_INPUT_URI,
476539
)
477540

541+
def test_submit_batch_prediction_job_with_invalid_abbreviated_model_name(self):
542+
with pytest.raises(
543+
ValueError,
544+
match=(
545+
"Abbreviated model names are only supported for Gemini models. "
546+
"Please provide the full publisher model name."
547+
),
548+
):
549+
batch_prediction.BatchPredictionJob.submit(
550+
source_model=_TEST_LLAMA_MODEL_NAME,
551+
input_dataset=_TEST_GCS_INPUT_URI,
552+
)
553+
478554
@pytest.mark.usefixtures("get_non_gemini_model_mock")
479555
def test_submit_batch_prediction_job_with_non_gemini_tuned_model(self):
480556
with pytest.raises(

vertexai/batch_prediction/_batch_prediction.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
_LOGGER = aiplatform_base.Logger(__name__)
3434

3535
_GEMINI_MODEL_PATTERN = r"publishers/google/models/gemini"
36+
_LLAMA_MODEL_PATTERN = r"publishers/meta/models/llama"
3637
_GEMINI_TUNED_MODEL_PATTERN = r"^projects/[0-9]+?/locations/[0-9a-z-]+?/models/[0-9]+?$"
3738

3839

@@ -272,13 +273,20 @@ def _reconcile_model_name(cls, model_name: str) -> str:
272273

273274
if "/" not in model_name:
274275
# model name (e.g., gemini-1.0-pro)
275-
model_name = "publishers/google/models/" + model_name
276+
if model_name.startswith("gemini"):
277+
model_name = "publishers/google/models/" + model_name
278+
else:
279+
raise ValueError(
280+
"Abbreviated model names are only supported for Gemini models. "
281+
"Please provide the full publisher model name."
282+
)
276283
elif model_name.startswith("models/"):
277284
# publisher model name (e.g., models/gemini-1.0-pro)
278285
model_name = "publishers/google/" + model_name
279286
elif (
280287
# publisher model full name
281288
not model_name.startswith("publishers/google/models/")
289+
and not model_name.startswith("publishers/meta/models/")
282290
# tuned model full resource name
283291
and not re.search(_GEMINI_TUNED_MODEL_PATTERN, model_name)
284292
):
@@ -302,6 +310,10 @@ def _is_genai_model(cls, model_name: str) -> bool:
302310
# Model is a tuned Gemini model.
303311
return True
304312

313+
if re.search(_LLAMA_MODEL_PATTERN, model_name):
314+
# Model is a Llama3 model.
315+
return True
316+
305317
return False
306318

307319
@classmethod

0 commit comments

Comments (0)