# Test constants: model resource names and GCS input URIs used across the
# batch-prediction job tests below.
_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME = "projects/123/locations/us-central1/models/456"
_TEST_PALM_MODEL_NAME = "text-bison"
# NOTE: no whitespace inside the braces or before the closing quote — the
# resource name must be exactly "publishers/google/models/text-bison".
_TEST_PALM_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_PALM_MODEL_NAME}"
_TEST_LLAMA_MODEL_NAME = "llama3-405b-instruct-maas"
_TEST_LLAMA_MODEL_RESOURCE_NAME = f"publishers/meta/models/{_TEST_LLAMA_MODEL_NAME}"

_TEST_GCS_INPUT_URI = "gs://test-bucket/test-input.jsonl"
_TEST_GCS_INPUT_URI_2 = "gs://test-bucket/test-input-2.jsonl"
@@ -127,6 +129,23 @@ def get_batch_prediction_job_with_gcs_output_mock():
127
129
yield get_job_mock
128
130
129
131
@pytest.fixture
def get_batch_prediction_job_with_llama_model_mock():
    """Patch ``JobServiceClient.get_batch_prediction_job`` to return a
    succeeded job whose model is the Llama publisher model resource.

    Yields the patched mock so tests can assert on the get call.
    """
    # Construct the canned response first, then install it on the patch.
    fake_job = gca_batch_prediction_job_compat.BatchPredictionJob(
        name=_TEST_BATCH_PREDICTION_JOB_NAME,
        display_name=_TEST_DISPLAY_NAME,
        model=_TEST_LLAMA_MODEL_RESOURCE_NAME,
        state=_TEST_JOB_STATE_SUCCESS,
        output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo(
            gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX
        ),
    )
    with mock.patch.object(
        job_service_client.JobServiceClient, "get_batch_prediction_job"
    ) as patched_get_job:
        patched_get_job.return_value = fake_job
        yield patched_get_job
130
149
@pytest .fixture
131
150
def get_batch_prediction_job_with_tuned_gemini_model_mock ():
132
151
with mock .patch .object (
@@ -252,6 +271,16 @@ def test_init_batch_prediction_job(
252
271
name = _TEST_BATCH_PREDICTION_JOB_NAME , retry = aiplatform_base ._DEFAULT_RETRY
253
272
)
254
273
274
+ def test_init_batch_prediction_job_with_llama_model (
275
+ self ,
276
+ get_batch_prediction_job_with_llama_model_mock ,
277
+ ):
278
+ batch_prediction .BatchPredictionJob (_TEST_BATCH_PREDICTION_JOB_ID )
279
+
280
+ get_batch_prediction_job_with_llama_model_mock .assert_called_once_with (
281
+ name = _TEST_BATCH_PREDICTION_JOB_NAME , retry = aiplatform_base ._DEFAULT_RETRY
282
+ )
283
+
255
284
def test_init_batch_prediction_job_with_tuned_gemini_model (
256
285
self ,
257
286
get_batch_prediction_job_with_tuned_gemini_model_mock ,
@@ -447,6 +476,39 @@ def test_submit_batch_prediction_job_with_bq_input_without_output_uri_prefix(
447
476
timeout = None ,
448
477
)
449
478
479
+ def test_submit_batch_prediction_job_with_llama_model (
480
+ self ,
481
+ create_batch_prediction_job_mock ,
482
+ ):
483
+ job = batch_prediction .BatchPredictionJob .submit (
484
+ source_model = _TEST_LLAMA_MODEL_RESOURCE_NAME ,
485
+ input_dataset = _TEST_BQ_INPUT_URI ,
486
+ )
487
+
488
+ assert job .gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB
489
+
490
+ expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat .BatchPredictionJob (
491
+ display_name = _TEST_DISPLAY_NAME ,
492
+ model = _TEST_LLAMA_MODEL_RESOURCE_NAME ,
493
+ input_config = gca_batch_prediction_job_compat .BatchPredictionJob .InputConfig (
494
+ instances_format = "bigquery" ,
495
+ bigquery_source = gca_io_compat .BigQuerySource (
496
+ input_uri = _TEST_BQ_INPUT_URI
497
+ ),
498
+ ),
499
+ output_config = gca_batch_prediction_job_compat .BatchPredictionJob .OutputConfig (
500
+ bigquery_destination = gca_io_compat .BigQueryDestination (
501
+ output_uri = _TEST_BQ_OUTPUT_PREFIX
502
+ ),
503
+ predictions_format = "bigquery" ,
504
+ ),
505
+ )
506
+ create_batch_prediction_job_mock .assert_called_once_with (
507
+ parent = _TEST_PARENT ,
508
+ batch_prediction_job = expected_gapic_batch_prediction_job ,
509
+ timeout = None ,
510
+ )
511
+
450
512
@pytest .mark .usefixtures ("create_batch_prediction_job_mock" )
451
513
def test_submit_batch_prediction_job_with_tuned_model (
452
514
self ,
@@ -467,14 +529,28 @@ def test_submit_batch_prediction_job_with_invalid_source_model(self):
467
529
with pytest .raises (
468
530
ValueError ,
469
531
match = (
470
- f"Model '{ _TEST_PALM_MODEL_RESOURCE_NAME } ' is not a Generative AI model."
532
+ "Abbreviated model names are only supported for Gemini models. "
533
+ "Please provide the full publisher model name."
471
534
),
472
535
):
473
536
batch_prediction .BatchPredictionJob .submit (
474
537
source_model = _TEST_PALM_MODEL_NAME ,
475
538
input_dataset = _TEST_GCS_INPUT_URI ,
476
539
)
477
540
541
+ def test_submit_batch_prediction_job_with_invalid_abbreviated_model_name (self ):
542
+ with pytest .raises (
543
+ ValueError ,
544
+ match = (
545
+ "Abbreviated model names are only supported for Gemini models. "
546
+ "Please provide the full publisher model name."
547
+ ),
548
+ ):
549
+ batch_prediction .BatchPredictionJob .submit (
550
+ source_model = _TEST_LLAMA_MODEL_NAME ,
551
+ input_dataset = _TEST_GCS_INPUT_URI ,
552
+ )
553
+
478
554
@pytest .mark .usefixtures ("get_non_gemini_model_mock" )
479
555
def test_submit_batch_prediction_job_with_non_gemini_tuned_model (self ):
480
556
with pytest .raises (
0 commit comments