Skip to content

Commit 8ca9cdf

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Deploy a tuned text embedding model -- it doesn't matter, if it's tuned using Node.js, or curl.
PiperOrigin-RevId: 629619980
1 parent b22a8b8 commit 8ca9cdf

File tree

1 file changed

+64
-1
lines changed

1 file changed

+64
-1
lines changed

vertexai/language_models/_language_models.py

+64-1
Original file line numberDiff line numberDiff line change
@@ -1996,7 +1996,7 @@ class TextEmbeddingInput:
19961996
title: Optional[str] = None
19971997

19981998

1999-
class TextEmbeddingModel(_LanguageModel):
1999+
class _TextEmbeddingModel(_LanguageModel):
20002000
"""TextEmbeddingModel class calculates embeddings for the given texts.
20012001
20022002
Examples::
@@ -2126,6 +2126,69 @@ async def get_embeddings_async(
21262126
]
21272127

21282128

2129+
class _TunableTextEmbeddingModelMixin(_TunableModelMixin):
2130+
@classmethod
2131+
def get_tuned_model():
2132+
raise NotImplementedError(
2133+
"Use deploy_tuned_model instead to get the tuned model."
2134+
)
2135+
2136+
# IMPORTANT: Keep this method supported even if you end up deploying the tuned model as part of the tuning pipeline template.
2137+
@classmethod
2138+
def deploy_tuned_model(
2139+
cls,
2140+
tuned_model_name: str,
2141+
machine_type: Optional[str] = None,
2142+
accelerator: Optional[str] = None,
2143+
accelerator_count: Optional[int] = None,
2144+
) -> "_LanguageModel":
2145+
"""Loads the specified tuned language model.
2146+
2147+
Args:
2148+
tuned_model_name: Tuned model's resource name.
2149+
machine_type: Machine type. E.g., "a2-highgpu-1g". See also: https://cloud.google.com/vertex-ai/docs/training/configure-compute.
2150+
accelerator: Kind of accelerator. E.g., "NVIDIA_TESLA_A100". See also: https://cloud.google.com/vertex-ai/docs/training/configure-compute.
2151+
accelerator_count: Count of accelerators.
2152+
2153+
Returns:
2154+
Tuned `LanguageModel` object.
2155+
"""
2156+
tuned_vertex_model = aiplatform.Model(tuned_model_name)
2157+
tuned_model_labels = tuned_vertex_model.labels
2158+
2159+
if _TUNING_BASE_MODEL_ID_LABEL_KEY not in tuned_model_labels:
2160+
raise ValueError(
2161+
f"The provided model {tuned_model_name} does not have a base model ID."
2162+
)
2163+
2164+
tuning_model_id = tuned_vertex_model.labels[_TUNING_BASE_MODEL_ID_LABEL_KEY]
2165+
tuned_model_deployments = tuned_vertex_model.gca_resource.deployed_models
2166+
if len(tuned_model_deployments) == 0:
2167+
# Deploying a model to an endpoint requires a resource quota.
2168+
endpoint_name = tuned_vertex_model.deploy(
2169+
machine_type=machine_type,
2170+
accelerator_type=accelerator,
2171+
accelerator_count=accelerator_count,
2172+
).resource_name
2173+
else:
2174+
endpoint_name = tuned_model_deployments[0].endpoint
2175+
2176+
base_model_id = _get_model_id_from_tuning_model_id(tuning_model_id)
2177+
model_info = _model_garden_models._get_model_info(
2178+
model_id=base_model_id,
2179+
schema_to_class_map={cls._INSTANCE_SCHEMA_URI: cls},
2180+
)
2181+
model = model_info.interface_class(
2182+
model_id=base_model_id,
2183+
endpoint_name=endpoint_name,
2184+
)
2185+
return model
2186+
2187+
2188+
class TextEmbeddingModel(_TextEmbeddingModel, _TunableTextEmbeddingModelMixin):
2189+
__module__ = "vertexai.language_models"
2190+
2191+
21292192
class _PreviewTextEmbeddingModel(
21302193
TextEmbeddingModel, _ModelWithBatchPredict, _CountTokensMixin
21312194
):

0 commit comments

Comments
 (0)