|
52 | 52 | _TEST_PALM_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_PALM_MODEL_NAME}"
|
53 | 53 | _TEST_LLAMA_MODEL_NAME = "llama3-405b-instruct-maas"
|
54 | 54 | _TEST_LLAMA_MODEL_RESOURCE_NAME = f"publishers/meta/models/{_TEST_LLAMA_MODEL_NAME}"
|
| 55 | +_TEST_CLAUDE_MODEL_NAME = "claude-3-opus" |
| 56 | +_TEST_CLAUDE_MODEL_RESOURCE_NAME = ( |
| 57 | + f"publishers/anthropic/models/{_TEST_CLAUDE_MODEL_NAME}" |
| 58 | +) |
55 | 59 |
|
56 | 60 | _TEST_GCS_INPUT_URI = "gs://test-bucket/test-input.jsonl"
|
57 | 61 | _TEST_GCS_INPUT_URI_2 = "gs://test-bucket/test-input-2.jsonl"
|
@@ -146,6 +150,23 @@ def get_batch_prediction_job_with_llama_model_mock():
|
146 | 150 | yield get_job_mock
|
147 | 151 |
|
148 | 152 |
|
| 153 | +@pytest.fixture |
| 154 | +def get_batch_prediction_job_with_claude_model_mock(): |
| 155 | + with mock.patch.object( |
| 156 | + job_service_client.JobServiceClient, "get_batch_prediction_job" |
| 157 | + ) as get_job_mock: |
| 158 | + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( |
| 159 | + name=_TEST_BATCH_PREDICTION_JOB_NAME, |
| 160 | + display_name=_TEST_DISPLAY_NAME, |
| 161 | + model=_TEST_CLAUDE_MODEL_RESOURCE_NAME, |
| 162 | + state=_TEST_JOB_STATE_SUCCESS, |
| 163 | + output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo( |
| 164 | + gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX |
| 165 | + ), |
| 166 | + ) |
| 167 | + yield get_job_mock |
| 168 | + |
| 169 | + |
149 | 170 | @pytest.fixture
|
150 | 171 | def get_batch_prediction_job_with_tuned_gemini_model_mock():
|
151 | 172 | with mock.patch.object(
|
@@ -281,6 +302,16 @@ def test_init_batch_prediction_job_with_llama_model(
|
281 | 302 | name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY
|
282 | 303 | )
|
283 | 304 |
|
| 305 | + def test_init_batch_prediction_job_with_claude_model( |
| 306 | + self, |
| 307 | + get_batch_prediction_job_with_claude_model_mock, |
| 308 | + ): |
| 309 | + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) |
| 310 | + |
| 311 | + get_batch_prediction_job_with_claude_model_mock.assert_called_once_with( |
| 312 | + name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY |
| 313 | + ) |
| 314 | + |
284 | 315 | def test_init_batch_prediction_job_with_tuned_gemini_model(
|
285 | 316 | self,
|
286 | 317 | get_batch_prediction_job_with_tuned_gemini_model_mock,
|
@@ -509,6 +540,39 @@ def test_submit_batch_prediction_job_with_llama_model(
|
509 | 540 | timeout=None,
|
510 | 541 | )
|
511 | 542 |
|
| 543 | + def test_submit_batch_prediction_job_with_claude_model( |
| 544 | + self, |
| 545 | + create_batch_prediction_job_mock, |
| 546 | + ): |
| 547 | + job = batch_prediction.BatchPredictionJob.submit( |
| 548 | + source_model=_TEST_CLAUDE_MODEL_RESOURCE_NAME, |
| 549 | + input_dataset=_TEST_BQ_INPUT_URI, |
| 550 | + ) |
| 551 | + |
| 552 | + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB |
| 553 | + |
| 554 | + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( |
| 555 | + display_name=_TEST_DISPLAY_NAME, |
| 556 | + model=_TEST_CLAUDE_MODEL_RESOURCE_NAME, |
| 557 | + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( |
| 558 | + instances_format="bigquery", |
| 559 | + bigquery_source=gca_io_compat.BigQuerySource( |
| 560 | + input_uri=_TEST_BQ_INPUT_URI |
| 561 | + ), |
| 562 | + ), |
| 563 | + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( |
| 564 | + bigquery_destination=gca_io_compat.BigQueryDestination( |
| 565 | + output_uri=_TEST_BQ_OUTPUT_PREFIX |
| 566 | + ), |
| 567 | + predictions_format="bigquery", |
| 568 | + ), |
| 569 | + ) |
| 570 | + create_batch_prediction_job_mock.assert_called_once_with( |
| 571 | + parent=_TEST_PARENT, |
| 572 | + batch_prediction_job=expected_gapic_batch_prediction_job, |
| 573 | + timeout=None, |
| 574 | + ) |
| 575 | + |
512 | 576 | @pytest.mark.usefixtures("create_batch_prediction_job_mock")
|
513 | 577 | def test_submit_batch_prediction_job_with_tuned_model(
|
514 | 578 | self,
|
|
0 commit comments