@@ -1996,7 +1996,7 @@ class TextEmbeddingInput:
1996
1996
title : Optional [str ] = None
1997
1997
1998
1998
1999
- class TextEmbeddingModel (_LanguageModel ):
1999
+ class _TextEmbeddingModel (_LanguageModel ):
2000
2000
"""TextEmbeddingModel class calculates embeddings for the given texts.
2001
2001
2002
2002
Examples::
@@ -2126,6 +2126,69 @@ async def get_embeddings_async(
2126
2126
]
2127
2127
2128
2128
2129
+ class _TunableTextEmbeddingModelMixin (_TunableModelMixin ):
2130
+ @classmethod
2131
+ def get_tuned_model ():
2132
+ raise NotImplementedError (
2133
+ "Use deploy_tuned_model instead to get the tuned model."
2134
+ )
2135
+
2136
+ # IMPORTANT: Keep this method supported even if you end up deploying the tuned model as part of the tuning pipeline template.
2137
+ @classmethod
2138
+ def deploy_tuned_model (
2139
+ cls ,
2140
+ tuned_model_name : str ,
2141
+ machine_type : Optional [str ] = None ,
2142
+ accelerator : Optional [str ] = None ,
2143
+ accelerator_count : Optional [int ] = None ,
2144
+ ) -> "_LanguageModel" :
2145
+ """Loads the specified tuned language model.
2146
+
2147
+ Args:
2148
+ tuned_model_name: Tuned model's resource name.
2149
+ machine_type: Machine type. E.g., "a2-highgpu-1g". See also: https://cloud.google.com/vertex-ai/docs/training/configure-compute.
2150
+ accelerator: Kind of accelerator. E.g., "NVIDIA_TESLA_A100". See also: https://cloud.google.com/vertex-ai/docs/training/configure-compute.
2151
+ accelerator_count: Count of accelerators.
2152
+
2153
+ Returns:
2154
+ Tuned `LanguageModel` object.
2155
+ """
2156
+ tuned_vertex_model = aiplatform .Model (tuned_model_name )
2157
+ tuned_model_labels = tuned_vertex_model .labels
2158
+
2159
+ if _TUNING_BASE_MODEL_ID_LABEL_KEY not in tuned_model_labels :
2160
+ raise ValueError (
2161
+ f"The provided model { tuned_model_name } does not have a base model ID."
2162
+ )
2163
+
2164
+ tuning_model_id = tuned_vertex_model .labels [_TUNING_BASE_MODEL_ID_LABEL_KEY ]
2165
+ tuned_model_deployments = tuned_vertex_model .gca_resource .deployed_models
2166
+ if len (tuned_model_deployments ) == 0 :
2167
+ # Deploying a model to an endpoint requires a resource quota.
2168
+ endpoint_name = tuned_vertex_model .deploy (
2169
+ machine_type = machine_type ,
2170
+ accelerator_type = accelerator ,
2171
+ accelerator_count = accelerator_count ,
2172
+ ).resource_name
2173
+ else :
2174
+ endpoint_name = tuned_model_deployments [0 ].endpoint
2175
+
2176
+ base_model_id = _get_model_id_from_tuning_model_id (tuning_model_id )
2177
+ model_info = _model_garden_models ._get_model_info (
2178
+ model_id = base_model_id ,
2179
+ schema_to_class_map = {cls ._INSTANCE_SCHEMA_URI : cls },
2180
+ )
2181
+ model = model_info .interface_class (
2182
+ model_id = base_model_id ,
2183
+ endpoint_name = endpoint_name ,
2184
+ )
2185
+ return model
2186
+
2187
+
2188
+ class TextEmbeddingModel (_TextEmbeddingModel , _TunableTextEmbeddingModelMixin ):
2189
+ __module__ = "vertexai.language_models"
2190
+
2191
+
2129
2192
class _PreviewTextEmbeddingModel (
2130
2193
TextEmbeddingModel , _ModelWithBatchPredict , _CountTokensMixin
2131
2194
):
0 commit comments