Skip to content

Commit 13b11c6

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: Support publisher models in BatchPredictionJob.create
PiperOrigin-RevId: 536581722
1 parent 0463678 commit 13b11c6

File tree

2 files changed

+43
-9
lines changed

2 files changed

+43
-9
lines changed

google/cloud/aiplatform/jobs.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
from google.cloud.aiplatform import hyperparameter_tuning
5656
from google.cloud.aiplatform import model_monitoring
5757
from google.cloud.aiplatform import utils
58+
from google.cloud.aiplatform.preview import _publisher_model
5859
from google.cloud.aiplatform.utils import console_utils
5960
from google.cloud.aiplatform.utils import source_utils
6061
from google.cloud.aiplatform.utils import worker_spec_utils
@@ -624,15 +625,22 @@ def create(
624625
utils.validate_labels(labels)
625626

626627
if isinstance(model_name, str):
627-
model_name = utils.full_resource_name(
628-
resource_name=model_name,
629-
resource_noun="models",
630-
parse_resource_name_method=aiplatform.Model._parse_resource_name,
631-
format_resource_name_method=aiplatform.Model._format_resource_name,
632-
project=project,
633-
location=location,
634-
resource_id_validator=super()._revisioned_resource_id_validator,
635-
)
628+
try:
629+
model_name = utils.full_resource_name(
630+
resource_name=model_name,
631+
resource_noun="models",
632+
parse_resource_name_method=aiplatform.Model._parse_resource_name,
633+
format_resource_name_method=aiplatform.Model._format_resource_name,
634+
project=project,
635+
location=location,
636+
resource_id_validator=super()._revisioned_resource_id_validator,
637+
)
638+
except ValueError:
639+
# Do not raise exception if model_name is a valid PublisherModel name
640+
if not _publisher_model._PublisherModel._parse_resource_name(
641+
model_name
642+
):
643+
raise
636644

637645
# Raise error if both or neither source URIs are provided
638646
if bool(gcs_source) == bool(bigquery_source):

tests/unit/aiplatform/test_jobs.py

+26
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@
8888
_TEST_MODEL_VERSION_ID = "2"
8989
_TEST_VERSIONED_MODEL_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_ALT_ID}@{_TEST_MODEL_VERSION_ID}"
9090

91+
_TEST_PUBLISHER_MODEL_NAME = (
92+
f"publishers/google/models/text-model-name@{_TEST_MODEL_VERSION_ID}"
93+
)
94+
9195
_TEST_BATCH_PREDICTION_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/batchPredictionJobs/{_TEST_ID}"
9296
_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME = "test-batch-prediction-job"
9397

@@ -1267,6 +1271,28 @@ def test_batch_predict_job_with_versioned_model(
12671271
== _TEST_VERSIONED_MODEL_NAME
12681272
)
12691273

1274+
@pytest.mark.usefixtures("get_batch_prediction_job_mock")
1275+
def test_batch_predict_job_with_publisher_model(
1276+
self, create_batch_prediction_job_mock
1277+
):
1278+
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
1279+
1280+
# Make SDK batch_predict method call
1281+
_ = jobs.BatchPredictionJob.create(
1282+
model_name=_TEST_PUBLISHER_MODEL_NAME,
1283+
job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
1284+
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
1285+
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
1286+
sync=True,
1287+
service_account=_TEST_SERVICE_ACCOUNT,
1288+
)
1289+
assert (
1290+
create_batch_prediction_job_mock.call_args_list[0][1][
1291+
"batch_prediction_job"
1292+
].model
1293+
== _TEST_PUBLISHER_MODEL_NAME
1294+
)
1295+
12701296

12711297
@pytest.fixture
12721298
def get_mdm_job_mock():

0 commit comments

Comments
 (0)