Skip to content

Commit f821e45

Browse files
Ark-kun authored and copybara-github committed
chore: LLM - Removed the model launch stage limitations
Model Garden models can have various launch stages: GA, public preview, and private preview. Previously, the `_ModelGardenModel` base class required a model interface class to be marked as preview in order to use preview models. This resulted in maintenance issues and an unnecessary proliferation of classes, because each class had to be duplicated up to 3 times to accommodate models in different launch stages. For example, some people with a private preview model had to either change their model's launch stage or create a special private preview version of the model interface class just to work around the loading restrictions. The launch stage restriction has also confused some of our users: a user changes `text-bison` to `text-bison-32k` and the code starts throwing errors. With this change, the model interface classes no longer reject preview models. The breaking change guarantees now apply separately to the model interface classes and to the models. A preview class can change when going to GA. A preview model can change when going to GA. PiperOrigin-RevId: 606430510
1 parent 4661e58 commit f821e45

File tree

5 files changed

+0
-109
lines changed

5 files changed

+0
-109
lines changed

tests/unit/aiplatform/test_model_garden_models.py

-22
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ class TestModelGardenModels:
5454

5555
class FakeModelGardenModel(_model_garden_models._ModelGardenModel):
5656

57-
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
58-
5957
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml"
6058

6159
def setup_method(self):
@@ -84,23 +82,3 @@ def test_init_model_garden_model_with_from_pretrained(self):
8482
name="publishers/google/models/text-bison@001",
8583
retry=base._DEFAULT_RETRY,
8684
)
87-
88-
def test_init_preview_model_raises_with_ga_launch_stage_set(self):
89-
"""Tests the text generation model."""
90-
aiplatform.init(
91-
project=test_constants.ProjectConstants._TEST_PROJECT,
92-
location=test_constants.ProjectConstants._TEST_LOCATION,
93-
)
94-
with mock.patch.object(
95-
target=model_garden_service_client_v1.ModelGardenServiceClient,
96-
attribute="get_publisher_model",
97-
return_value=gca_publisher_model.PublisherModel(
98-
_TEXT_BISON_PUBLISHER_MODEL_DICT
99-
),
100-
):
101-
self.FakeModelGardenModel._LAUNCH_STAGE = (
102-
_model_garden_models._SDK_GA_LAUNCH_STAGE
103-
)
104-
105-
with pytest.raises(ValueError):
106-
self.FakeModelGardenModel.from_pretrained("text-bison@001")

vertexai/_model_garden/_model_garden_models.py

-44
Original file line numberDiff line numberDiff line change
@@ -45,21 +45,6 @@
4545
"codechat-bison-32k": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
4646
}
4747

48-
_SDK_PRIVATE_PREVIEW_LAUNCH_STAGE = frozenset(
49-
[
50-
gca_publisher_model.PublisherModel.LaunchStage.PRIVATE_PREVIEW,
51-
gca_publisher_model.PublisherModel.LaunchStage.PUBLIC_PREVIEW,
52-
gca_publisher_model.PublisherModel.LaunchStage.GA,
53-
]
54-
)
55-
_SDK_PUBLIC_PREVIEW_LAUNCH_STAGE = frozenset(
56-
[
57-
gca_publisher_model.PublisherModel.LaunchStage.PUBLIC_PREVIEW,
58-
gca_publisher_model.PublisherModel.LaunchStage.GA,
59-
]
60-
)
61-
_SDK_GA_LAUNCH_STAGE = frozenset([gca_publisher_model.PublisherModel.LaunchStage.GA])
62-
6348
_LOGGER = base.Logger(__name__)
6449

6550
T = TypeVar("T", bound="_ModelGardenModel")
@@ -241,10 +226,6 @@ def _from_pretrained(
241226
f"{model_name} is of type {model_info.interface_class.__name__} not of type {interface_class.__name__}"
242227
)
243228

244-
interface_class._validate_launch_stage(
245-
interface_class, model_info.publisher_model_resource
246-
)
247-
248229
return model_info.interface_class(
249230
model_id=model_name,
250231
endpoint_name=model_info.endpoint_name,
@@ -254,31 +235,6 @@ def _from_pretrained(
254235
class _ModelGardenModel:
255236
"""Base class for shared methods and properties across Model Garden models."""
256237

257-
_LAUNCH_STAGE: gca_publisher_model.PublisherModel.LaunchStage = (
258-
_SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
259-
)
260-
261-
def _validate_launch_stage(
262-
self,
263-
publisher_model_resource: gca_publisher_model.PublisherModel,
264-
) -> None:
265-
"""Validates the model class _LAUNCH_STAGE matches the PublisherModel resource's launch stage.
266-
267-
Args:
268-
publisher_model_resource (gca_publisher_model.PublisherModel
269-
The GAPIC PublisherModel resource for this model.
270-
"""
271-
272-
publisher_launch_stage = publisher_model_resource.launch_stage
273-
274-
if publisher_launch_stage not in self._LAUNCH_STAGE:
275-
raise ValueError(
276-
f"The model you are trying to instantiate has launch stage '{publisher_launch_stage.name}'"
277-
f", but the '{type(self).__module__}.{type(self).__name__}' class"
278-
f" only supports the following launch stages: {self._LAUNCH_STAGE}."
279-
" For preview models please use the classes from the `vertexai.preview.*` namespace."
280-
)
281-
282238
# Subclasses override this attribute to specify their instance schema
283239
_INSTANCE_SCHEMA_URI: Optional[str] = None
284240

vertexai/language_models/_language_models.py

-20
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,6 @@ def get_tuned_model(cls, tuned_model_name: str) -> "_LanguageModel":
205205
model_id=base_model_id,
206206
schema_to_class_map={cls._INSTANCE_SCHEMA_URI: cls},
207207
)
208-
cls._validate_launch_stage(cls, model_info.publisher_model_resource)
209-
210208
model = model_info.interface_class(
211209
model_id=base_model_id,
212210
endpoint_name=endpoint_name,
@@ -1241,8 +1239,6 @@ class _TextGenerationModel(_LanguageModel):
12411239
model.predict("What is life?")
12421240
"""
12431241

1244-
_LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE
1245-
12461242
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml"
12471243

12481244
_DEFAULT_MAX_OUTPUT_TOKENS = 128
@@ -1984,8 +1980,6 @@ class TextEmbeddingModel(_LanguageModel):
19841980

19851981
__module__ = "vertexai.language_models"
19861982

1987-
_LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE
1988-
19891983
_INSTANCE_SCHEMA_URI = (
19901984
"gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml"
19911985
)
@@ -2121,8 +2115,6 @@ class _PreviewTextEmbeddingModel(
21212115
__name__ = "TextEmbeddingModel"
21222116
__module__ = "vertexai.preview.language_models"
21232117

2124-
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
2125-
21262118

21272119
@dataclasses.dataclass
21282120
class TextEmbeddingStatistics:
@@ -2173,8 +2165,6 @@ class ChatMessage:
21732165
class _ChatModelBase(_LanguageModel):
21742166
"""_ChatModelBase is a base class for chat models."""
21752167

2176-
_LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE
2177-
21782168
def start_chat(
21792169
self,
21802170
*,
@@ -2251,8 +2241,6 @@ class _PreviewChatModel(ChatModel, _PreviewTunableChatModelMixin):
22512241
__name__ = "ChatModel"
22522242
__module__ = "vertexai.preview.language_models"
22532243

2254-
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
2255-
22562244
def start_chat(
22572245
self,
22582246
*,
@@ -2313,7 +2301,6 @@ class CodeChatModel(_ChatModelBase, _TunableChatModelMixin):
23132301
__module__ = "vertexai.language_models"
23142302

23152303
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/codechat_generation_1.0.0.yaml"
2316-
_LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE
23172304

23182305
def start_chat(
23192306
self,
@@ -2351,8 +2338,6 @@ class _PreviewCodeChatModel(CodeChatModel):
23512338
__name__ = "CodeChatModel"
23522339
__module__ = "vertexai.preview.language_models"
23532340

2354-
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
2355-
23562341
def start_chat(
23572342
self,
23582343
*,
@@ -3122,7 +3107,6 @@ class _CodeGenerationModel(_LanguageModel):
31223107

31233108
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml"
31243109

3125-
_LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE
31263110

31273111
def _create_prediction_request(
31283112
self,
@@ -3390,8 +3374,6 @@ class _PreviewCodeGenerationModel(CodeGenerationModel, _CountTokensCodeGeneratio
33903374
__name__ = "CodeGenerationModel"
33913375
__module__ = "vertexai.preview.language_models"
33923376

3393-
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
3394-
33953377

33963378
###### Model tuning
33973379
# Currently, tuning can only work in this location
@@ -3710,5 +3692,3 @@ class _PreviewTextGenerationModel(
37103692
# Do not add docstring so that it's inherited from the base class.
37113693
__name__ = "TextGenerationModel"
37123694
__module__ = "vertexai.preview.language_models"
3713-
3714-
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE

vertexai/preview/vision_models.py

-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""Classes for working with vision models."""
1616

1717
from vertexai.vision_models._vision_models import (
18-
_PreviewImageTextModel,
1918
Image,
2019
ImageGenerationModel,
2120
ImageGenerationResponse,
@@ -27,8 +26,6 @@
2726
MultiModalEmbeddingResponse,
2827
)
2928

30-
ImageTextModel = _PreviewImageTextModel
31-
3229
__all__ = [
3330
"Image",
3431
"ImageGenerationModel",

vertexai/vision_models/_vision_models.py

-20
Original file line numberDiff line numberDiff line change
@@ -599,9 +599,6 @@ class ImageCaptioningModel(
599599
__module__ = "vertexai.vision_models"
600600

601601
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/vision_reasoning_model_1.0.0.yaml"
602-
_LAUNCH_STAGE = (
603-
_model_garden_models._SDK_GA_LAUNCH_STAGE # pylint: disable=protected-access
604-
)
605602

606603
def get_captions(
607604
self,
@@ -667,9 +664,6 @@ class ImageQnAModel(
667664
__module__ = "vertexai.vision_models"
668665

669666
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/vision_reasoning_model_1.0.0.yaml"
670-
_LAUNCH_STAGE = (
671-
_model_garden_models._SDK_GA_LAUNCH_STAGE # pylint: disable=protected-access
672-
)
673667

674668
def ask_question(
675669
self,
@@ -729,10 +723,6 @@ class MultiModalEmbeddingModel(_model_garden_models._ModelGardenModel):
729723

730724
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/vision_embedding_model_1.0.0.yaml"
731725

732-
_LAUNCH_STAGE = (
733-
_model_garden_models._SDK_GA_LAUNCH_STAGE # pylint: disable=protected-access
734-
)
735-
736726
def get_embeddings(
737727
self,
738728
image: Optional[Image] = None,
@@ -847,13 +837,3 @@ class ImageTextModel(ImageCaptioningModel, ImageQnAModel):
847837
# since SDK Model Garden classes should follow the design pattern of exactly 1 SDK class to 1 Model Garden schema URI
848838

849839
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/vision_reasoning_model_1.0.0.yaml"
850-
_LAUNCH_STAGE = (
851-
_model_garden_models._SDK_GA_LAUNCH_STAGE # pylint: disable=protected-access
852-
)
853-
854-
855-
class _PreviewImageTextModel(ImageTextModel):
856-
857-
__module__ = "vertexai.preview.vision_models"
858-
859-
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE

0 commit comments

Comments
 (0)