Skip to content

Commit 43488fe

Browse files
jaycee-licopybara-github
authored andcommitted
chore: add _PublisherModel class in preview module
PiperOrigin-RevId: 528939747
1 parent fa7d3a0 commit 43488fe

File tree

6 files changed

+198
-1
lines changed

6 files changed

+198
-1
lines changed

google/cloud/aiplatform/compat/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
services.featurestore_service_client = services.featurestore_service_client_v1beta1
3737
services.job_service_client = services.job_service_client_v1beta1
3838
services.model_service_client = services.model_service_client_v1beta1
39+
services.model_garden_service_client = services.model_garden_service_client_v1beta1
3940
services.pipeline_service_client = services.pipeline_service_client_v1beta1
4041
services.prediction_service_client = services.prediction_service_client_v1beta1
4142
services.specialist_pool_service_client = (
@@ -103,6 +104,7 @@
103104
types.model_deployment_monitoring_job = (
104105
types.model_deployment_monitoring_job_v1beta1
105106
)
107+
types.model_garden_service = types.model_garden_service_v1beta1
106108
types.model_monitoring = types.model_monitoring_v1beta1
107109
types.model_service = types.model_service_v1beta1
108110
types.operation = types.operation_v1beta1
@@ -111,6 +113,7 @@
111113
types.pipeline_service = types.pipeline_service_v1beta1
112114
types.pipeline_state = types.pipeline_state_v1beta1
113115
types.prediction_service = types.prediction_service_v1beta1
116+
types.publisher_model = types.publisher_model_v1beta1
114117
types.specialist_pool = types.specialist_pool_v1beta1
115118
types.specialist_pool_service = types.specialist_pool_service_v1beta1
116119
types.study = types.study_v1beta1

google/cloud/aiplatform/compat/services/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@
4545
from google.cloud.aiplatform_v1beta1.services.metadata_service import (
4646
client as metadata_service_client_v1beta1,
4747
)
48+
from google.cloud.aiplatform_v1beta1.services.model_garden_service import (
49+
client as model_garden_service_client_v1beta1,
50+
)
4851
from google.cloud.aiplatform_v1beta1.services.model_service import (
4952
client as model_service_client_v1beta1,
5053
)
@@ -133,6 +136,7 @@
133136
index_endpoint_service_client_v1beta1,
134137
job_service_client_v1beta1,
135138
match_service_client_v1beta1,
139+
model_garden_service_client_v1beta1,
136140
model_service_client_v1beta1,
137141
pipeline_service_client_v1beta1,
138142
prediction_service_client_v1beta1,

google/cloud/aiplatform/compat/types/__init__.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
model_evaluation as model_evaluation_v1beta1,
6666
model_evaluation_slice as model_evaluation_slice_v1beta1,
6767
model_deployment_monitoring_job as model_deployment_monitoring_job_v1beta1,
68+
model_garden_service as model_garden_service_v1beta1,
6869
model_service as model_service_v1beta1,
6970
model_monitoring as model_monitoring_v1beta1,
7071
operation as operation_v1beta1,
@@ -73,6 +74,7 @@
7374
pipeline_service as pipeline_service_v1beta1,
7475
pipeline_state as pipeline_state_v1beta1,
7576
prediction_service as prediction_service_v1beta1,
77+
publisher_model as publisher_model_v1beta1,
7678
specialist_pool as specialist_pool_v1beta1,
7779
specialist_pool_service as specialist_pool_service_v1beta1,
7880
study as study_v1beta1,
@@ -204,7 +206,7 @@
204206
model_service_v1,
205207
model_monitoring_v1,
206208
operation_v1,
207-
pipeline_failure_policy_v1beta1,
209+
pipeline_failure_policy_v1,
208210
pipeline_job_v1,
209211
pipeline_service_v1,
210212
pipeline_state_v1,
@@ -219,6 +221,8 @@
219221
tensorboard_time_series_v1,
220222
training_pipeline_v1,
221223
types_v1,
224+
study_v1,
225+
vizier_service_v1,
222226
# v1beta1
223227
accelerator_type_v1beta1,
224228
annotation_v1beta1,
@@ -269,6 +273,7 @@
269273
model_evaluation_v1beta1,
270274
model_evaluation_slice_v1beta1,
271275
model_deployment_monitoring_job_v1beta1,
276+
model_garden_service_v1beta1,
272277
model_service_v1beta1,
273278
model_monitoring_v1beta1,
274279
operation_v1beta1,
@@ -277,8 +282,10 @@
277282
pipeline_service_v1beta1,
278283
pipeline_state_v1beta1,
279284
prediction_service_v1beta1,
285+
publisher_model_v1beta1,
280286
specialist_pool_v1beta1,
281287
specialist_pool_service_v1beta1,
288+
study_v1beta1,
282289
tensorboard_v1beta1,
283290
tensorboard_data_v1beta1,
284291
tensorboard_experiment_v1beta1,
@@ -287,4 +294,5 @@
287294
tensorboard_time_series_v1beta1,
288295
training_pipeline_v1beta1,
289296
types_v1beta1,
297+
vizier_service_v1beta1,
290298
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2023 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import re
19+
from typing import Optional
20+
21+
from google.auth import credentials as auth_credentials
22+
from google.cloud.aiplatform import base
23+
from google.cloud.aiplatform import utils
24+
25+
26+
class _PublisherModel(base.VertexAiResourceNoun):
27+
"""Publisher Model Resource for Vertex AI."""
28+
29+
client_class = utils.ModelGardenClientWithOverride
30+
31+
_resource_noun = "publisher_model"
32+
_getter_method = "get_publisher_model"
33+
_delete_method = None
34+
_parse_resource_name_method = "parse_publisher_model_path"
35+
_format_resource_name_method = "publisher_model_path"
36+
37+
def __init__(
38+
self,
39+
resource_name: str,
40+
project: Optional[str] = None,
41+
location: Optional[str] = None,
42+
credentials: Optional[auth_credentials.Credentials] = None,
43+
):
44+
"""Retrieves an existing PublisherModel resource given a resource name or model garden id.
45+
46+
Args:
47+
resource_name (str):
48+
Required. A fully-qualified PublisherModel resource name or
49+
model garden id. Format:
50+
`publishers/{publisher}/models/{publisher_model}` or
51+
`{publisher}/{publisher_model}`.
52+
project (str):
53+
Optional. Project to retrieve the resource from. If not set,
54+
project set in aiplatform.init will be used.
55+
location (str):
56+
Optional. Location to retrieve the resource from. If not set,
57+
location set in aiplatform.init will be used.
58+
credentials (auth_credentials.Credentials):
59+
Optional. Custom credentials to use to retrieve the resource.
60+
Overrides credentials set in aiplatform.init.
61+
"""
62+
63+
super().__init__(project=project, location=location, credentials=credentials)
64+
65+
if self._parse_resource_name(resource_name):
66+
full_resource_name = resource_name
67+
else:
68+
m = re.match(r"^(?P<publisher>.+?)/(?P<model>.+?)$", resource_name)
69+
if m:
70+
full_resource_name = self._format_resource_name(**m.groupdict())
71+
else:
72+
raise ValueError(
73+
f"`{resource_name}` is not a valid PublisherModel resource "
74+
"name or model garden id."
75+
)
76+
77+
self._gca_resource = getattr(self.api_client, self._getter_method)(
78+
name=full_resource_name, retry=base._DEFAULT_RETRY
79+
)

google/cloud/aiplatform/utils/__init__.py

+10
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
prediction_service_client_v1beta1,
5252
tensorboard_service_client_v1beta1,
5353
vizier_service_client_v1beta1,
54+
model_garden_service_client_v1beta1,
5455
)
5556
from google.cloud.aiplatform.compat.services import (
5657
dataset_service_client_v1,
@@ -633,6 +634,14 @@ class VizierClientWithOverride(ClientWithOverride):
633634
)
634635

635636

637+
class ModelGardenClientWithOverride(ClientWithOverride):
638+
_is_temporary = True
639+
_default_version = compat.V1BETA1
640+
_version_map = (
641+
(compat.V1BETA1, model_garden_service_client_v1beta1.ModelGardenServiceClient),
642+
)
643+
644+
636645
VertexAiServiceClientWithOverride = TypeVar(
637646
"VertexAiServiceClientWithOverride",
638647
DatasetClientWithOverride,
@@ -647,6 +656,7 @@ class VizierClientWithOverride(ClientWithOverride):
647656
MetadataClientWithOverride,
648657
TensorboardClientWithOverride,
649658
VizierClientWithOverride,
659+
ModelGardenClientWithOverride,
650660
)
651661

652662

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2023 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import pytest
19+
20+
from unittest import mock
21+
from importlib import reload
22+
23+
from google.cloud import aiplatform
24+
from google.cloud.aiplatform import base
25+
from google.cloud.aiplatform import initializer
26+
from google.cloud.aiplatform.preview import _publisher_model
27+
28+
from google.cloud.aiplatform.compat.services import (
29+
model_garden_service_client_v1beta1,
30+
)
31+
32+
33+
_TEST_PROJECT = "test-project"
34+
_TEST_LOCATION = "us-central1"
35+
36+
_TEST_RESOURCE_NAME = "publishers/google/models/chat-bison@001"
37+
_TEST_MODEL_GARDEN_ID = "google/chat-bison@001"
38+
_TEST_INVALID_RESOURCE_NAME = "google.chat-bison@001"
39+
40+
41+
@pytest.fixture
42+
def mock_get_publisher_model():
43+
with mock.patch.object(
44+
model_garden_service_client_v1beta1.ModelGardenServiceClient,
45+
"get_publisher_model",
46+
) as mock_get_publisher_model:
47+
yield mock_get_publisher_model
48+
49+
50+
@pytest.mark.usefixtures("google_auth_mock")
51+
class TestPublisherModel:
52+
def setup_method(self):
53+
reload(initializer)
54+
reload(aiplatform)
55+
56+
def teardown_method(self):
57+
initializer.global_pool.shutdown(wait=True)
58+
59+
def test_init_publisher_model_with_resource_name(self, mock_get_publisher_model):
60+
aiplatform.init(
61+
project=_TEST_PROJECT,
62+
location=_TEST_LOCATION,
63+
)
64+
_ = _publisher_model._PublisherModel(_TEST_RESOURCE_NAME)
65+
mock_get_publisher_model.assert_called_once_with(
66+
name=_TEST_RESOURCE_NAME, retry=base._DEFAULT_RETRY
67+
)
68+
69+
def test_init_publisher_model_with_model_garden_id(self, mock_get_publisher_model):
70+
aiplatform.init(
71+
project=_TEST_PROJECT,
72+
location=_TEST_LOCATION,
73+
)
74+
_ = _publisher_model._PublisherModel(_TEST_MODEL_GARDEN_ID)
75+
mock_get_publisher_model.assert_called_once_with(
76+
name=_TEST_RESOURCE_NAME, retry=base._DEFAULT_RETRY
77+
)
78+
79+
def test_init_publisher_model_with_invalid_resource_name(
80+
self, mock_get_publisher_model
81+
):
82+
aiplatform.init(
83+
project=_TEST_PROJECT,
84+
location=_TEST_LOCATION,
85+
)
86+
with pytest.raises(
87+
ValueError,
88+
match=(
89+
f"`{_TEST_INVALID_RESOURCE_NAME}` is not a valid PublisherModel "
90+
"resource name or model garden id."
91+
),
92+
):
93+
_ = _publisher_model._PublisherModel(_TEST_INVALID_RESOURCE_NAME)

0 commit comments

Comments
 (0)