Skip to content

Commit 0866009

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: GenAI - Added Anthropic models support in GenAI batch prediction
PiperOrigin-RevId: 688995187
1 parent 025e3dc commit 0866009

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

tests/unit/vertexai/test_batch_prediction.py

+64
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,10 @@
5252
_TEST_PALM_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_PALM_MODEL_NAME}"
5353
_TEST_LLAMA_MODEL_NAME = "llama3-405b-instruct-maas"
5454
_TEST_LLAMA_MODEL_RESOURCE_NAME = f"publishers/meta/models/{_TEST_LLAMA_MODEL_NAME}"
55+
_TEST_CLAUDE_MODEL_NAME = "claude-3-opus"
56+
_TEST_CLAUDE_MODEL_RESOURCE_NAME = (
57+
f"publishers/anthropic/models/{_TEST_CLAUDE_MODEL_NAME}"
58+
)
5559

5660
_TEST_GCS_INPUT_URI = "gs://test-bucket/test-input.jsonl"
5761
_TEST_GCS_INPUT_URI_2 = "gs://test-bucket/test-input-2.jsonl"
@@ -146,6 +150,23 @@ def get_batch_prediction_job_with_llama_model_mock():
146150
yield get_job_mock
147151

148152

153+
@pytest.fixture
def get_batch_prediction_job_with_claude_model_mock():
    """Patch ``JobServiceClient.get_batch_prediction_job`` to return a Claude job.

    The canned job references ``_TEST_CLAUDE_MODEL_RESOURCE_NAME``, carries the
    state ``_TEST_JOB_STATE_SUCCESS`` and reports ``_TEST_GCS_OUTPUT_PREFIX`` as
    its GCS output directory.

    Yields:
        The mock that replaces ``get_batch_prediction_job`` while the fixture
        is active.
    """
    job_cls = gca_batch_prediction_job_compat.BatchPredictionJob
    claude_job = job_cls(
        name=_TEST_BATCH_PREDICTION_JOB_NAME,
        display_name=_TEST_DISPLAY_NAME,
        model=_TEST_CLAUDE_MODEL_RESOURCE_NAME,
        state=_TEST_JOB_STATE_SUCCESS,
        output_info=job_cls.OutputInfo(
            gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX,
        ),
    )
    # Hand the canned response to the patcher directly instead of assigning
    # ``return_value`` after the mock has been created.
    with mock.patch.object(
        job_service_client.JobServiceClient,
        "get_batch_prediction_job",
        return_value=claude_job,
    ) as get_job_mock:
        yield get_job_mock
168+
169+
149170
@pytest.fixture
150171
def get_batch_prediction_job_with_tuned_gemini_model_mock():
151172
with mock.patch.object(
@@ -281,6 +302,16 @@ def test_init_batch_prediction_job_with_llama_model(
281302
name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY
282303
)
283304

305+
def test_init_batch_prediction_job_with_claude_model(
    self,
    get_batch_prediction_job_with_claude_model_mock,
):
    """Constructing a job by ID fetches it exactly once with the default retry."""
    get_job_mock = get_batch_prediction_job_with_claude_model_mock
    expected_kwargs = {
        "name": _TEST_BATCH_PREDICTION_JOB_NAME,
        "retry": aiplatform_base._DEFAULT_RETRY,
    }

    batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)

    get_job_mock.assert_called_once_with(**expected_kwargs)
314+
284315
def test_init_batch_prediction_job_with_tuned_gemini_model(
285316
self,
286317
get_batch_prediction_job_with_tuned_gemini_model_mock,
@@ -509,6 +540,39 @@ def test_submit_batch_prediction_job_with_llama_model(
509540
timeout=None,
510541
)
511542

543+
def test_submit_batch_prediction_job_with_claude_model(
    self,
    create_batch_prediction_job_mock,
):
    """Submitting with a Claude publisher model issues the expected create call.

    The job is submitted with a BigQuery input and no explicit output prefix;
    the create request is expected to carry the Claude model resource name,
    a BigQuery source and a BigQuery destination at ``_TEST_BQ_OUTPUT_PREFIX``.
    """
    job_proto = gca_batch_prediction_job_compat.BatchPredictionJob

    submitted_job = batch_prediction.BatchPredictionJob.submit(
        source_model=_TEST_CLAUDE_MODEL_RESOURCE_NAME,
        input_dataset=_TEST_BQ_INPUT_URI,
    )
    assert submitted_job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB

    # Assemble the expected request piece by piece.
    bq_input_config = job_proto.InputConfig(
        instances_format="bigquery",
        bigquery_source=gca_io_compat.BigQuerySource(
            input_uri=_TEST_BQ_INPUT_URI,
        ),
    )
    bq_output_config = job_proto.OutputConfig(
        bigquery_destination=gca_io_compat.BigQueryDestination(
            output_uri=_TEST_BQ_OUTPUT_PREFIX,
        ),
        predictions_format="bigquery",
    )
    expected_request_job = job_proto(
        display_name=_TEST_DISPLAY_NAME,
        model=_TEST_CLAUDE_MODEL_RESOURCE_NAME,
        input_config=bq_input_config,
        output_config=bq_output_config,
    )

    create_batch_prediction_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        batch_prediction_job=expected_request_job,
        timeout=None,
    )
575+
512576
@pytest.mark.usefixtures("create_batch_prediction_job_mock")
513577
def test_submit_batch_prediction_job_with_tuned_model(
514578
self,

vertexai/batch_prediction/_batch_prediction.py

+6
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@
3434

3535
_GEMINI_MODEL_PATTERN = r"publishers/google/models/gemini"
3636
_LLAMA_MODEL_PATTERN = r"publishers/meta/models/llama"
37+
_CLAUDE_MODEL_PATTERN = r"publishers/anthropic/models/claude"
3738
_GEMINI_TUNED_MODEL_PATTERN = r"^projects/[0-9]+?/locations/[0-9a-z-]+?/models/[0-9]+?$"
3839

3940

@@ -287,6 +288,7 @@ def _reconcile_model_name(cls, model_name: str) -> str:
287288
# publisher model full name
288289
not model_name.startswith("publishers/google/models/")
289290
and not model_name.startswith("publishers/meta/models/")
291+
and not model_name.startswith("publishers/anthropic/models/")
290292
# tuned model full resource name
291293
and not re.search(_GEMINI_TUNED_MODEL_PATTERN, model_name)
292294
):
@@ -314,6 +316,10 @@ def _is_genai_model(cls, model_name: str) -> bool:
314316
# Model is a Llama3 model.
315317
return True
316318

319+
if re.search(_CLAUDE_MODEL_PATTERN, model_name):
320+
# Model is a claude model.
321+
return True
322+
317323
return False
318324

319325
@classmethod

0 commit comments

Comments
 (0)