Skip to content

Commit 74c2066

Browse files
sararobcopybara-github
authored andcommitted
chore: add model garden and publisher model v1 support
PiperOrigin-RevId: 537955149
1 parent c5d62fb commit 74c2066

File tree

7 files changed

+28
-18
lines changed

7 files changed

+28
-18
lines changed

google/cloud/aiplatform/compat/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@
140140
)
141141
services.featurestore_service_client = services.featurestore_service_client_v1
142142
services.job_service_client = services.job_service_client_v1
143+
services.model_garden_service_client = services.model_garden_service_client_v1
143144
services.model_service_client = services.model_service_client_v1
144145
services.pipeline_service_client = services.pipeline_service_client_v1
145146
services.prediction_service_client = services.prediction_service_client_v1
@@ -206,6 +207,7 @@
206207
types.pipeline_service = types.pipeline_service_v1
207208
types.pipeline_state = types.pipeline_state_v1
208209
types.prediction_service = types.prediction_service_v1
210+
types.publisher_model = types.publisher_model_v1
209211
types.specialist_pool = types.specialist_pool_v1
210212
types.specialist_pool_service = types.specialist_pool_service_v1
211213
types.study = types.study_v1

google/cloud/aiplatform/compat/services/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@
9494
from google.cloud.aiplatform_v1.services.metadata_service import (
9595
client as metadata_service_client_v1,
9696
)
97+
from google.cloud.aiplatform_v1.services.model_garden_service import (
98+
client as model_garden_service_client_v1,
99+
)
97100
from google.cloud.aiplatform_v1.services.model_service import (
98101
client as model_service_client_v1,
99102
)
@@ -123,6 +126,7 @@
123126
index_endpoint_service_client_v1,
124127
job_service_client_v1,
125128
metadata_service_client_v1,
129+
model_garden_service_client_v1,
126130
model_service_client_v1,
127131
pipeline_service_client_v1,
128132
prediction_service_client_v1,

google/cloud/aiplatform/compat/types/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@
144144
pipeline_service as pipeline_service_v1,
145145
pipeline_state as pipeline_state_v1,
146146
prediction_service as prediction_service_v1,
147+
publisher_model as publisher_model_v1,
147148
specialist_pool as specialist_pool_v1,
148149
specialist_pool_service as specialist_pool_service_v1,
149150
study as study_v1,
@@ -213,6 +214,7 @@
213214
pipeline_service_v1,
214215
pipeline_state_v1,
215216
prediction_service_v1,
217+
publisher_model_v1,
216218
specialist_pool_v1,
217219
specialist_pool_service_v1,
218220
tensorboard_v1,

google/cloud/aiplatform/utils/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
index_endpoint_service_client_v1,
6464
job_service_client_v1,
6565
metadata_service_client_v1,
66+
model_garden_service_client_v1,
6667
model_service_client_v1,
6768
pipeline_service_client_v1,
6869
prediction_service_client_v1,
@@ -646,8 +647,9 @@ class VizierClientWithOverride(ClientWithOverride):
646647

647648
class ModelGardenClientWithOverride(ClientWithOverride):
648649
_is_temporary = True
649-
_default_version = compat.V1BETA1
650+
_default_version = compat.DEFAULT_VERSION
650651
_version_map = (
652+
(compat.V1, model_garden_service_client_v1.ModelGardenServiceClient),
651653
(compat.V1BETA1, model_garden_service_client_v1beta1.ModelGardenServiceClient),
652654
)
653655

tests/unit/aiplatform/test_language_models.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
import constants as test_constants
3535

3636
from google.cloud.aiplatform.compat.services import (
37-
model_garden_service_client_v1beta1,
37+
model_garden_service_client,
3838
endpoint_service_client,
3939
model_service_client,
4040
pipeline_service_client,
@@ -46,9 +46,9 @@
4646
endpoint as gca_endpoint,
4747
pipeline_job as gca_pipeline_job,
4848
pipeline_state as gca_pipeline_state,
49-
deployed_model_ref_v1beta1,
49+
deployed_model_ref_v1,
5050
)
51-
from google.cloud.aiplatform_v1beta1.types import (
51+
from google.cloud.aiplatform_v1.types import (
5252
publisher_model as gca_publisher_model,
5353
model as gca_model,
5454
)
@@ -478,7 +478,7 @@ def get_model_with_tuned_version_label_mock():
478478
name=test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME,
479479
labels={"google-vertex-llm-tuning-base-model-id": "text-bison-001"},
480480
deployed_models=[
481-
deployed_model_ref_v1beta1.DeployedModelRef(
481+
deployed_model_ref_v1.DeployedModelRef(
482482
endpoint=test_constants.EndpointConstants._TEST_ENDPOINT_NAME,
483483
deployed_model_id=test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME,
484484
)
@@ -525,7 +525,7 @@ def test_text_generation(self):
525525
location=_TEST_LOCATION,
526526
)
527527
with mock.patch.object(
528-
target=model_garden_service_client_v1beta1.ModelGardenServiceClient,
528+
target=model_garden_service_client.ModelGardenServiceClient,
529529
attribute="get_publisher_model",
530530
return_value=gca_publisher_model.PublisherModel(
531531
_TEXT_BISON_PUBLISHER_MODEL_DICT
@@ -590,7 +590,7 @@ def test_tune_model(
590590
encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
591591
)
592592
with mock.patch.object(
593-
target=model_garden_service_client_v1beta1.ModelGardenServiceClient,
593+
target=model_garden_service_client.ModelGardenServiceClient,
594594
attribute="get_publisher_model",
595595
return_value=gca_publisher_model.PublisherModel(
596596
_TEXT_BISON_PUBLISHER_MODEL_DICT
@@ -625,7 +625,7 @@ def test_get_tuned_model(
625625
)
626626

627627
with mock.patch.object(
628-
target=model_garden_service_client_v1beta1.ModelGardenServiceClient,
628+
target=model_garden_service_client.ModelGardenServiceClient,
629629
attribute="get_publisher_model",
630630
return_value=gca_publisher_model.PublisherModel(
631631
_TEXT_BISON_PUBLISHER_MODEL_DICT
@@ -662,7 +662,7 @@ def test_chat(self):
662662
location=_TEST_LOCATION,
663663
)
664664
with mock.patch.object(
665-
target=model_garden_service_client_v1beta1.ModelGardenServiceClient,
665+
target=model_garden_service_client.ModelGardenServiceClient,
666666
attribute="get_publisher_model",
667667
return_value=gca_publisher_model.PublisherModel(
668668
_CHAT_BISON_PUBLISHER_MODEL_DICT
@@ -780,7 +780,7 @@ def test_code_chat(self):
780780
location=_TEST_LOCATION,
781781
)
782782
with mock.patch.object(
783-
target=model_garden_service_client_v1beta1.ModelGardenServiceClient,
783+
target=model_garden_service_client.ModelGardenServiceClient,
784784
attribute="get_publisher_model",
785785
return_value=gca_publisher_model.PublisherModel(
786786
_CODECHAT_BISON_PUBLISHER_MODEL_DICT
@@ -876,7 +876,7 @@ def test_code_generation(self):
876876
location=_TEST_LOCATION,
877877
)
878878
with mock.patch.object(
879-
target=model_garden_service_client_v1beta1.ModelGardenServiceClient,
879+
target=model_garden_service_client.ModelGardenServiceClient,
880880
attribute="get_publisher_model",
881881
return_value=gca_publisher_model.PublisherModel(
882882
_CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT
@@ -942,7 +942,7 @@ def test_code_completion(self):
942942
location=_TEST_LOCATION,
943943
)
944944
with mock.patch.object(
945-
target=model_garden_service_client_v1beta1.ModelGardenServiceClient,
945+
target=model_garden_service_client.ModelGardenServiceClient,
946946
attribute="get_publisher_model",
947947
return_value=gca_publisher_model.PublisherModel(
948948
_CODE_COMPLETION_BISON_PUBLISHER_MODEL_DICT
@@ -1008,7 +1008,7 @@ def test_text_embedding(self):
10081008
location=_TEST_LOCATION,
10091009
)
10101010
with mock.patch.object(
1011-
target=model_garden_service_client_v1beta1.ModelGardenServiceClient,
1011+
target=model_garden_service_client.ModelGardenServiceClient,
10121012
attribute="get_publisher_model",
10131013
return_value=gca_publisher_model.PublisherModel(
10141014
_TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT

tests/unit/aiplatform/test_model_garden_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
import constants as test_constants
2727

2828
from google.cloud.aiplatform.compat.services import (
29-
model_garden_service_client_v1beta1,
29+
model_garden_service_client_v1,
3030
)
3131

32-
from google.cloud.aiplatform_v1beta1.types import (
32+
from google.cloud.aiplatform_v1.types import (
3333
publisher_model as gca_publisher_model,
3434
)
3535

@@ -76,7 +76,7 @@ def test_init_model_garden_model_with_from_pretrained(self):
7676
location=test_constants.ProjectConstants._TEST_LOCATION,
7777
)
7878
with mock.patch.object(
79-
target=model_garden_service_client_v1beta1.ModelGardenServiceClient,
79+
target=model_garden_service_client_v1.ModelGardenServiceClient,
8080
attribute="get_publisher_model",
8181
return_value=gca_publisher_model.PublisherModel(
8282
_TEXT_BISON_PUBLISHER_MODEL_DICT

tests/unit/aiplatform/test_publisher_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from google.cloud.aiplatform import _publisher_models
2727

2828
from google.cloud.aiplatform.compat.services import (
29-
model_garden_service_client_v1beta1,
29+
model_garden_service_client_v1,
3030
)
3131

3232

@@ -41,7 +41,7 @@
4141
@pytest.fixture
4242
def mock_get_publisher_model():
4343
with mock.patch.object(
44-
model_garden_service_client_v1beta1.ModelGardenServiceClient,
44+
model_garden_service_client_v1.ModelGardenServiceClient,
4545
"get_publisher_model",
4646
) as mock_get_publisher_model:
4747
yield mock_get_publisher_model

0 commit comments

Comments
 (0)