Skip to content

Commit 68460b8

Browse files
sararobpabloem
authored andcommitted
chore: add validation for model launch stage
PiperOrigin-RevId: 538471545
1 parent a6fb43c commit 68460b8

File tree

5 files changed

+87
-5
lines changed

5 files changed

+87
-5
lines changed

tests/unit/aiplatform/test_language_models.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
pipeline_state as gca_pipeline_state,
4949
deployed_model_ref_v1,
5050
)
51-
from google.cloud.aiplatform_v1.types import (
51+
from google.cloud.aiplatform.compat.types import (
5252
publisher_model as gca_publisher_model,
5353
model as gca_model,
5454
)
@@ -75,6 +75,7 @@
7575
"name": "publishers/google/models/text-bison",
7676
"version_id": "001",
7777
"open_source_category": "PROPRIETARY",
78+
"launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA,
7879
"publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/text-bison@001",
7980
"predict_schemata": {
8081
"instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml",
@@ -87,6 +88,7 @@
8788
"name": "publishers/google/models/chat-bison",
8889
"version_id": "001",
8990
"open_source_category": "PROPRIETARY",
91+
"launch_stage": gca_publisher_model.PublisherModel.LaunchStage.PUBLIC_PREVIEW,
9092
"publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/chat-bison@001",
9193
"predict_schemata": {
9294
"instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml",
@@ -99,6 +101,7 @@
99101
"name": "publishers/google/models/codechat-bison",
100102
"version_id": "001",
101103
"open_source_category": "PROPRIETARY",
104+
"launch_stage": gca_publisher_model.PublisherModel.LaunchStage.PUBLIC_PREVIEW,
102105
"publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/codechat-bison@001",
103106
"predict_schemata": {
104107
"instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/codechat_generation_1.0.0.yaml",
@@ -111,6 +114,7 @@
111114
"name": "publishers/google/models/code-bison",
112115
"version_id": "001",
113116
"open_source_category": "PROPRIETARY",
117+
"launch_stage": gca_publisher_model.PublisherModel.LaunchStage.PUBLIC_PREVIEW,
114118
"publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/code-bison@001",
115119
"predict_schemata": {
116120
"instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml",
@@ -123,6 +127,7 @@
123127
"name": "publishers/google/models/code-gecko",
124128
"version_id": "001",
125129
"open_source_category": "PROPRIETARY",
130+
"launch_stage": gca_publisher_model.PublisherModel.LaunchStage.PUBLIC_PREVIEW,
126131
"publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/code-gecko@001",
127132
"predict_schemata": {
128133
"instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml",
@@ -135,6 +140,7 @@
135140
"name": "publishers/google/models/textembedding-gecko",
136141
"version_id": "001",
137142
"open_source_category": "PROPRIETARY",
143+
"launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA,
138144
"publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/chat-bison@001",
139145
"predict_schemata": {
140146
"instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml",

tests/unit/aiplatform/test_model_garden_models.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
model_garden_service_client_v1,
2929
)
3030

31-
from google.cloud.aiplatform_v1.types import (
31+
from google.cloud.aiplatform.compat.types import (
3232
publisher_model as gca_publisher_model,
3333
)
3434

@@ -38,6 +38,7 @@
3838
"name": "publishers/google/models/text-bison",
3939
"version_id": "001",
4040
"open_source_category": "PROPRIETARY",
41+
"launch_stage": gca_publisher_model.PublisherModel.LaunchStage.PUBLIC_PREVIEW,
4142
"publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/text-bison@001",
4243
"predict_schemata": {
4344
"instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml",
@@ -52,6 +53,9 @@ class TestModelGardenModels:
5253
"""Unit tests for the _ModelGardenModel base class."""
5354

5455
class FakeModelGardenModel(_model_garden_models._ModelGardenModel):
56+
57+
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
58+
5559
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml"
5660

5761
def setup_method(self):
@@ -80,3 +84,23 @@ def test_init_model_garden_model_with_from_pretrained(self):
8084
name="publishers/google/models/text-bison@001",
8185
retry=base._DEFAULT_RETRY,
8286
)
87+
88+
def test_init_preview_model_raises_with_ga_launch_stage_set(self):
89+
"""Tests the text generation model."""
90+
aiplatform.init(
91+
project=test_constants.ProjectConstants._TEST_PROJECT,
92+
location=test_constants.ProjectConstants._TEST_LOCATION,
93+
)
94+
with mock.patch.object(
95+
target=model_garden_service_client_v1.ModelGardenServiceClient,
96+
attribute="get_publisher_model",
97+
return_value=gca_publisher_model.PublisherModel(
98+
_TEXT_BISON_PUBLISHER_MODEL_DICT
99+
),
100+
):
101+
self.FakeModelGardenModel._LAUNCH_STAGE = (
102+
_model_garden_models._SDK_GA_LAUNCH_STAGE
103+
)
104+
105+
with pytest.raises(ValueError):
106+
self.FakeModelGardenModel.from_pretrained("text-bison@001")

vertexai/_model_garden/_model_garden_models.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
"""Base class for working with Model Garden models."""
1717

18-
import abc
1918
import dataclasses
2019
from typing import Dict, Optional, Type
2120

@@ -25,20 +24,32 @@
2524
from google.cloud.aiplatform import models as aiplatform_models
2625
from google.cloud.aiplatform import _publisher_models
2726

27+
from google.cloud.aiplatform.compat.types import (
28+
publisher_model as gca_publisher_model,
29+
)
2830

2931
_SUPPORTED_PUBLISHERS = ["google"]
3032

3133
_SHORT_MODEL_ID_TO_TUNING_PIPELINE_MAP = {
3234
"text-bison": "https://us-kfp.pkg.dev/vertex-ai/large-language-model-pipelines/tune-large-model/sdk-1-25"
3335
}
3436

37+
_SDK_PUBLIC_PREVIEW_LAUNCH_STAGE = frozenset(
38+
[
39+
gca_publisher_model.PublisherModel.LaunchStage.PUBLIC_PREVIEW,
40+
gca_publisher_model.PublisherModel.LaunchStage.GA,
41+
]
42+
)
43+
_SDK_GA_LAUNCH_STAGE = frozenset([gca_publisher_model.PublisherModel.LaunchStage.GA])
44+
3545
_LOGGER = base.Logger(__name__)
3646

3747

3848
@dataclasses.dataclass
3949
class _ModelInfo:
4050
endpoint_name: str
4151
interface_class: Type["_ModelGardenModel"]
52+
publisher_model_resource: _publisher_models._PublisherModel
4253
tuning_pipeline_uri: Optional[str] = None
4354
tuning_model_id: Optional[str] = None
4455

@@ -114,6 +125,7 @@ def _get_model_info(
114125
return _ModelInfo(
115126
endpoint_name=endpoint_name,
116127
interface_class=interface_class,
128+
publisher_model_resource=publisher_model_res,
117129
tuning_pipeline_uri=tuning_pipeline_uri,
118130
tuning_model_id=tuning_model_id,
119131
)
@@ -122,6 +134,28 @@ def _get_model_info(
122134
class _ModelGardenModel:
123135
"""Base class for shared methods and properties across Model Garden models."""
124136

137+
_LAUNCH_STAGE: gca_publisher_model.PublisherModel.LaunchStage = (
138+
_SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
139+
)
140+
141+
def _validate_launch_stage(
142+
self,
143+
publisher_model_resource: gca_publisher_model.PublisherModel,
144+
) -> None:
145+
"""Validates the model class _LAUNCH_STAGE matches the PublisherModel resource's launch stage.
146+
147+
Args:
148+
publisher_model_resource (gca_publisher_model.PublisherModel
149+
The GAPIC PublisherModel resource for this model.
150+
"""
151+
152+
publisher_launch_stage = publisher_model_resource.launch_stage
153+
154+
if publisher_launch_stage not in self._LAUNCH_STAGE:
155+
raise ValueError(
156+
f"The model you are trying to instantiate does not support the launch stage: {publisher_launch_stage.name}"
157+
)
158+
125159
# Subclasses override this attribute to specify their instance schema
126160
_INSTANCE_SCHEMA_URI: Optional[str] = None
127161

@@ -174,6 +208,8 @@ def from_pretrained(cls, model_name: str) -> "_ModelGardenModel":
174208
f"{model_name} is of type {model_info.interface_class.__name__} not of type {cls.__name__}"
175209
)
176210

211+
cls._validate_launch_stage(cls, model_info.publisher_model_resource)
212+
177213
return model_info.interface_class(
178214
model_id=model_name,
179215
endpoint_name=model_info.endpoint_name,

vertexai/language_models/_language_models.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ def get_tuned_model(cls, tuned_model_name: str) -> "_LanguageModel":
126126
model_id=base_model_id,
127127
schema_to_class_map={cls._INSTANCE_SCHEMA_URI: cls},
128128
)
129+
cls._validate_launch_stage(cls, model_info.publisher_model_resource)
130+
129131
model = model_info.interface_class(
130132
model_id=base_model_id,
131133
endpoint_name=endpoint_name,
@@ -215,6 +217,8 @@ class TextGenerationModel(_LanguageModel):
215217
model.predict("What is life?")
216218
"""
217219

220+
_LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE
221+
218222
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml"
219223

220224
_DEFAULT_TEMPERATURE = 0.0
@@ -300,7 +304,7 @@ def _batch_predict(
300304
class _PreviewTextGenerationModel(TextGenerationModel, _TunableModelMixin):
301305
"""Tunable text generation model."""
302306

303-
pass
307+
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
304308

305309

306310
class _ChatModel(TextGenerationModel):
@@ -427,6 +431,8 @@ class TextEmbeddingModel(_LanguageModel):
427431
print(len(vector))
428432
"""
429433

434+
_LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE
435+
430436
_INSTANCE_SCHEMA_URI = (
431437
"gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml"
432438
)
@@ -447,6 +453,12 @@ def get_embeddings(self, texts: List[str]) -> List["TextEmbedding"]:
447453
]
448454

449455

456+
class _PreviewTextEmbeddingModel(TextEmbeddingModel):
457+
"""Preview text embedding model."""
458+
459+
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
460+
461+
450462
class TextEmbedding:
451463
"""Contains text embedding vector."""
452464

@@ -470,6 +482,8 @@ class InputOutputTextPair:
470482
class _ChatModelBase(_LanguageModel):
471483
"""_ChatModelBase is a base class for chat models."""
472484

485+
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
486+
473487
def start_chat(
474488
self,
475489
*,
@@ -777,6 +791,7 @@ class CodeGenerationModel(_LanguageModel):
777791

778792
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml"
779793

794+
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
780795
_DEFAULT_TEMPERATURE = 0.0
781796
_DEFAULT_MAX_OUTPUT_TOKENS = 128
782797

vertexai/preview/language_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Classes for working with language models."""
1616

1717
from vertexai.language_models._language_models import (
18+
_PreviewTextEmbeddingModel,
1819
_PreviewTextGenerationModel,
1920
ChatModel,
2021
ChatSession,
@@ -23,11 +24,11 @@
2324
CodeGenerationModel,
2425
InputOutputTextPair,
2526
TextEmbedding,
26-
TextEmbeddingModel,
2727
TextGenerationResponse,
2828
)
2929

3030
TextGenerationModel = _PreviewTextGenerationModel
31+
TextEmbeddingModel = _PreviewTextEmbeddingModel
3132

3233
__all__ = [
3334
"ChatModel",

0 commit comments

Comments
 (0)