Skip to content

Commit 76465e2

Browse files
Ark-kun authored and copybara-github committed
feat: LLM - Released the LLM SDK to GA
This commit also refactors the machinery around the mapping between schemas and interface classes. The central table is no longer needed: each interface class now knows its own instance schema. This architecture makes it much easier to experiment with different interface classes, since `MyClass.from_pretrained` now returns an instance of `MyClass` (if the model's instance schema matches). This change makes it trivial to have multiple versions of an interface class (e.g. GA, preview, etc.). That was much harder with a centralized table, which could only hold a single interface class for each instance schema. PiperOrigin-RevId: 538100676
1 parent ce5dee4 commit 76465e2

File tree

5 files changed

+148
-72
lines changed

5 files changed

+148
-72
lines changed

tests/unit/aiplatform/test_language_models.py

+103-15
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@
5353
model as gca_model,
5454
)
5555

56-
from vertexai.preview import language_models
56+
from vertexai.preview import (
57+
language_models as preview_language_models,
58+
)
59+
from vertexai import language_models
5760
from google.cloud.aiplatform_v1 import Execution as GapicExecution
5861
from google.cloud.aiplatform.compat.types import (
5962
encryption_spec as gca_encryption_spec,
@@ -456,7 +459,7 @@ def get_endpoint_mock():
456459
@pytest.fixture
457460
def mock_get_tuned_model(get_endpoint_mock):
458461
with mock.patch.object(
459-
language_models.TextGenerationModel, "get_tuned_model"
462+
preview_language_models.TextGenerationModel, "get_tuned_model"
460463
) as mock_text_generation_model:
461464
mock_text_generation_model._model_id = (
462465
test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
@@ -519,6 +522,50 @@ def teardown_method(self):
519522
initializer.global_pool.shutdown(wait=True)
520523

521524
def test_text_generation(self):
525+
"""Tests the text generation model."""
526+
aiplatform.init(
527+
project=_TEST_PROJECT,
528+
location=_TEST_LOCATION,
529+
)
530+
with mock.patch.object(
531+
target=model_garden_service_client.ModelGardenServiceClient,
532+
attribute="get_publisher_model",
533+
return_value=gca_publisher_model.PublisherModel(
534+
_TEXT_BISON_PUBLISHER_MODEL_DICT
535+
),
536+
) as mock_get_publisher_model:
537+
model = preview_language_models.TextGenerationModel.from_pretrained(
538+
"text-bison@001"
539+
)
540+
541+
mock_get_publisher_model.assert_called_once_with(
542+
name="publishers/google/models/text-bison@001", retry=base._DEFAULT_RETRY
543+
)
544+
545+
assert (
546+
model._model_resource_name
547+
== f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/text-bison@001"
548+
)
549+
550+
gca_predict_response = gca_prediction_service.PredictResponse()
551+
gca_predict_response.predictions.append(_TEST_TEXT_GENERATION_PREDICTION)
552+
553+
with mock.patch.object(
554+
target=prediction_service_client.PredictionServiceClient,
555+
attribute="predict",
556+
return_value=gca_predict_response,
557+
):
558+
response = model.predict(
559+
"What is the best recipe for banana bread? Recipe:",
560+
max_output_tokens=128,
561+
temperature=0,
562+
top_p=1,
563+
top_k=5,
564+
)
565+
566+
assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
567+
568+
def test_text_generation_ga(self):
522569
"""Tests the text generation model."""
523570
aiplatform.init(
524571
project=_TEST_PROJECT,
@@ -596,7 +643,7 @@ def test_tune_model(
596643
_TEXT_BISON_PUBLISHER_MODEL_DICT
597644
),
598645
):
599-
model = language_models.TextGenerationModel.from_pretrained(
646+
model = preview_language_models.TextGenerationModel.from_pretrained(
600647
"text-bison@001"
601648
)
602649

@@ -631,7 +678,7 @@ def test_get_tuned_model(
631678
_TEXT_BISON_PUBLISHER_MODEL_DICT
632679
),
633680
):
634-
tuned_model = language_models.TextGenerationModel.get_tuned_model(
681+
tuned_model = preview_language_models.TextGenerationModel.get_tuned_model(
635682
test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
636683
)
637684

@@ -651,7 +698,7 @@ def get_tuned_model_raises_if_not_called_with_mg_model(self):
651698
)
652699

653700
with pytest.raises(ValueError):
654-
language_models.TextGenerationModel.get_tuned_model(
701+
preview_language_models.TextGenerationModel.get_tuned_model(
655702
test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
656703
)
657704

@@ -668,7 +715,7 @@ def test_chat(self):
668715
_CHAT_BISON_PUBLISHER_MODEL_DICT
669716
),
670717
) as mock_get_publisher_model:
671-
model = language_models.ChatModel.from_pretrained("chat-bison@001")
718+
model = preview_language_models.ChatModel.from_pretrained("chat-bison@001")
672719

673720
mock_get_publisher_model.assert_called_once_with(
674721
name="publishers/google/models/chat-bison@001", retry=base._DEFAULT_RETRY
@@ -681,11 +728,11 @@ def test_chat(self):
681728
My favorite movies are Lord of the Rings and Hobbit.
682729
""",
683730
examples=[
684-
language_models.InputOutputTextPair(
731+
preview_language_models.InputOutputTextPair(
685732
input_text="Who do you work for?",
686733
output_text="I work for Ned.",
687734
),
688-
language_models.InputOutputTextPair(
735+
preview_language_models.InputOutputTextPair(
689736
input_text="What do I like?",
690737
output_text="Ned likes watching movies.",
691738
),
@@ -786,7 +833,7 @@ def test_code_chat(self):
786833
_CODECHAT_BISON_PUBLISHER_MODEL_DICT
787834
),
788835
) as mock_get_publisher_model:
789-
model = language_models.CodeChatModel.from_pretrained(
836+
model = preview_language_models.CodeChatModel.from_pretrained(
790837
"google/codechat-bison@001"
791838
)
792839

@@ -882,7 +929,7 @@ def test_code_generation(self):
882929
_CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT
883930
),
884931
) as mock_get_publisher_model:
885-
model = language_models.CodeGenerationModel.from_pretrained(
932+
model = preview_language_models.CodeGenerationModel.from_pretrained(
886933
"google/code-bison@001"
887934
)
888935

@@ -909,9 +956,11 @@ def test_code_generation(self):
909956
# Validating the parameters
910957
predict_temperature = 0.1
911958
predict_max_output_tokens = 100
912-
default_temperature = language_models.CodeGenerationModel._DEFAULT_TEMPERATURE
959+
default_temperature = (
960+
preview_language_models.CodeGenerationModel._DEFAULT_TEMPERATURE
961+
)
913962
default_max_output_tokens = (
914-
language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
963+
preview_language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
915964
)
916965

917966
with mock.patch.object(
@@ -948,7 +997,7 @@ def test_code_completion(self):
948997
_CODE_COMPLETION_BISON_PUBLISHER_MODEL_DICT
949998
),
950999
) as mock_get_publisher_model:
951-
model = language_models.CodeGenerationModel.from_pretrained(
1000+
model = preview_language_models.CodeGenerationModel.from_pretrained(
9521001
"google/code-gecko@001"
9531002
)
9541003

@@ -975,9 +1024,11 @@ def test_code_completion(self):
9751024
# Validating the parameters
9761025
predict_temperature = 0.1
9771026
predict_max_output_tokens = 100
978-
default_temperature = language_models.CodeGenerationModel._DEFAULT_TEMPERATURE
1027+
default_temperature = (
1028+
preview_language_models.CodeGenerationModel._DEFAULT_TEMPERATURE
1029+
)
9791030
default_max_output_tokens = (
980-
language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
1031+
preview_language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
9811032
)
9821033

9831034
with mock.patch.object(
@@ -1002,6 +1053,43 @@ def test_code_completion(self):
10021053
assert prediction_parameters["maxOutputTokens"] == default_max_output_tokens
10031054

10041055
def test_text_embedding(self):
1056+
"""Tests the text embedding model."""
1057+
aiplatform.init(
1058+
project=_TEST_PROJECT,
1059+
location=_TEST_LOCATION,
1060+
)
1061+
with mock.patch.object(
1062+
target=model_garden_service_client.ModelGardenServiceClient,
1063+
attribute="get_publisher_model",
1064+
return_value=gca_publisher_model.PublisherModel(
1065+
_TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
1066+
),
1067+
) as mock_get_publisher_model:
1068+
model = preview_language_models.TextEmbeddingModel.from_pretrained(
1069+
"textembedding-gecko@001"
1070+
)
1071+
1072+
mock_get_publisher_model.assert_called_once_with(
1073+
name="publishers/google/models/textembedding-gecko@001",
1074+
retry=base._DEFAULT_RETRY,
1075+
)
1076+
1077+
gca_predict_response = gca_prediction_service.PredictResponse()
1078+
gca_predict_response.predictions.append(_TEST_TEXT_EMBEDDING_PREDICTION)
1079+
1080+
with mock.patch.object(
1081+
target=prediction_service_client.PredictionServiceClient,
1082+
attribute="predict",
1083+
return_value=gca_predict_response,
1084+
):
1085+
embeddings = model.get_embeddings(["What is life?"])
1086+
assert embeddings
1087+
for embedding in embeddings:
1088+
vector = embedding.values
1089+
assert len(vector) == _TEXT_EMBEDDING_VECTOR_LENGTH
1090+
assert vector == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"]
1091+
1092+
def test_text_embedding_ga(self):
10051093
"""Tests the text embedding model."""
10061094
aiplatform.init(
10071095
project=_TEST_PROJECT,

tests/unit/aiplatform/test_model_garden_models.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import pytest
1919
from importlib import reload
2020
from unittest import mock
21-
from typing import Dict, Type
2221

2322
from google.cloud import aiplatform
2423
from google.cloud.aiplatform import base
@@ -53,14 +52,7 @@ class TestModelGardenModels:
5352
"""Unit tests for the _ModelGardenModel base class."""
5453

5554
class FakeModelGardenModel(_model_garden_models._ModelGardenModel):
56-
@staticmethod
57-
def _get_public_preview_class_map() -> Dict[
58-
str, Type[_model_garden_models._ModelGardenModel]
59-
]:
60-
test_map = {
61-
"gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml": TestModelGardenModels.FakeModelGardenModel
62-
}
63-
return test_map
55+
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml"
6456

6557
def setup_method(self):
6658
reload(initializer)

vertexai/_model_garden/_model_garden_models.py

+11-14
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ def _get_model_info(
107107
)
108108

109109
if not interface_class:
110-
raise ValueError(f"Unknown model {publisher_model_res.name}")
110+
raise ValueError(
111+
f"Unknown model {publisher_model_res.name}; {schema_to_class_map}"
112+
)
111113

112114
return _ModelInfo(
113115
endpoint_name=endpoint_name,
@@ -120,18 +122,8 @@ def _get_model_info(
120122
class _ModelGardenModel:
121123
"""Base class for shared methods and properties across Model Garden models."""
122124

123-
@staticmethod
124-
@abc.abstractmethod
125-
def _get_public_preview_class_map() -> Dict[str, Type["_ModelGardenModel"]]:
126-
"""Returns a Dict mapping schema URI to model class.
127-
128-
Subclasses should implement this method. Example mapping:
129-
130-
{
131-
"gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml": TextGenerationModel
132-
}
133-
"""
134-
pass
125+
# Subclasses override this attribute to specify their instance schema
126+
_INSTANCE_SCHEMA_URI: Optional[str] = None
135127

136128
def __init__(self, model_id: str, endpoint_name: Optional[str] = None):
137129
"""Creates a _ModelGardenModel.
@@ -168,8 +160,13 @@ def from_pretrained(cls, model_name: str) -> "_ModelGardenModel":
168160
ValueError: If model does not support this class.
169161
"""
170162

163+
if not cls._INSTANCE_SCHEMA_URI:
164+
raise ValueError(
165+
f"Class {cls} is not a correct model interface class since it does not have an instance schema URI."
166+
)
167+
171168
model_info = _get_model_info(
172-
model_id=model_name, schema_to_class_map=cls._get_public_preview_class_map()
169+
model_id=model_name, schema_to_class_map={cls._INSTANCE_SCHEMA_URI: cls}
173170
)
174171

175172
if not issubclass(model_info.interface_class, cls):

vertexai/language_models/__init__.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
#
15+
"""Classes for working with language models."""
1516

16-
from vertexai.language_models import _language_models
17+
from vertexai.language_models._language_models import (
18+
InputOutputTextPair,
19+
TextEmbedding,
20+
TextEmbeddingModel,
21+
TextGenerationModel,
22+
TextGenerationResponse,
23+
)
24+
25+
__all__ = [
26+
"InputOutputTextPair",
27+
"TextEmbedding",
28+
"TextEmbeddingModel",
29+
"TextGenerationModel",
30+
"TextGenerationResponse",
31+
]

0 commit comments

Comments
 (0)