Skip to content

Commit 3af9cc8

Browse files
KCFindstr authored and copybara-github committed
feat: Allow setting Vertex Model Garden source model name during model upload
PiperOrigin-RevId: 697786055
1 parent 3ab39a4 commit 3af9cc8

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

google/cloud/aiplatform/models.py

+14
Original file line number | Diff line number | Diff line change
@@ -4663,6 +4663,7 @@ def upload(
46634663
serving_container_health_probe_exec: Optional[Sequence[str]] = None,
46644664
serving_container_health_probe_period_seconds: Optional[int] = None,
46654665
serving_container_health_probe_timeout_seconds: Optional[int] = None,
4666+
model_garden_source_model_name: Optional[str] = None,
46664667
) -> "Model":
46674668
"""Uploads a model and returns a Model representing the uploaded Model
46684669
resource.
@@ -4875,6 +4876,10 @@ def upload(
48754876
serving_container_health_probe_timeout_seconds (int):
48764877
Optional. Number of seconds after which the health probe times
48774878
out. Defaults to 1 second. Minimum value is 1.
4879+
model_garden_source_model_name:
4880+
Optional. The model garden source model resource name if the
4881+
model is from Vertex Model Garden.
4882+
48784883
48794884
Returns:
48804885
model (aiplatform.Model):
@@ -5003,6 +5008,14 @@ def upload(
50035008
version_aliases=version_aliases, is_default_version=is_default_version
50045009
)
50055010

5011+
base_model_source = None
5012+
if model_garden_source_model_name:
5013+
base_model_source = gca_model_compat.Model.BaseModelSource(
5014+
model_garden_source=gca_model_compat.ModelGardenSource(
5015+
public_model_name=model_garden_source_model_name
5016+
)
5017+
)
5018+
50065019
managed_model = gca_model_compat.Model(
50075020
display_name=display_name,
50085021
description=description,
@@ -5012,6 +5025,7 @@ def upload(
50125025
predict_schemata=model_predict_schemata,
50135026
labels=labels,
50145027
encryption_spec=encryption_spec,
5028+
base_model_source=base_model_source,
50155029
)
50165030

50175031
if artifact_uri and not artifact_uri.startswith("gs://"):

tests/unit/aiplatform/test_models.py

+48
Original file line number | Diff line number | Diff line change
@@ -174,6 +174,7 @@
174174

175175
_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials())
176176
_TEST_SERVICE_ACCOUNT = "[email protected]"
177+
_TEST_MODEL_GARDEN_SOURCE_MODEL_NAME = "publishers/meta/models/llama3_1"
177178

178179

179180
_TEST_EXPLANATION_METADATA = explain.ExplanationMetadata(
@@ -1900,6 +1901,53 @@ def test_upload_uploads_and_gets_model_with_custom_location(
19001901
name=test_model_resource_name, retry=base._DEFAULT_RETRY
19011902
)
19021903

1904+
@pytest.mark.parametrize("sync", [True, False])
1905+
def test_upload_with_model_garden_source(
1906+
self, upload_model_mock, get_model_mock, sync
1907+
):
1908+
1909+
my_model = models.Model.upload(
1910+
display_name=_TEST_MODEL_NAME,
1911+
serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
1912+
serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
1913+
serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
1914+
sync=sync,
1915+
upload_request_timeout=None,
1916+
model_garden_source_model_name=_TEST_MODEL_GARDEN_SOURCE_MODEL_NAME,
1917+
)
1918+
1919+
if not sync:
1920+
my_model.wait()
1921+
1922+
container_spec = gca_model.ModelContainerSpec(
1923+
image_uri=_TEST_SERVING_CONTAINER_IMAGE,
1924+
predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
1925+
health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
1926+
)
1927+
1928+
managed_model = gca_model.Model(
1929+
display_name=_TEST_MODEL_NAME,
1930+
container_spec=container_spec,
1931+
version_aliases=["default"],
1932+
base_model_source=gca_model.Model.BaseModelSource(
1933+
model_garden_source=gca_model.ModelGardenSource(
1934+
public_model_name=_TEST_MODEL_GARDEN_SOURCE_MODEL_NAME
1935+
)
1936+
),
1937+
)
1938+
1939+
upload_model_mock.assert_called_once_with(
1940+
request=gca_model_service.UploadModelRequest(
1941+
parent=initializer.global_config.common_location_path(),
1942+
model=managed_model,
1943+
),
1944+
timeout=None,
1945+
)
1946+
1947+
get_model_mock.assert_called_once_with(
1948+
name=_TEST_MODEL_RESOURCE_NAME, retry=base._DEFAULT_RETRY
1949+
)
1950+
19031951
@pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
19041952
@pytest.mark.parametrize("sync", [True, False])
19051953
def test_deploy(self, deploy_model_mock, sync):

0 commit comments

Comments
 (0)