Skip to content

Commit cc8bc96

Browse files
vertex-sdk-bot and copybara-github
authored and committed
feat: LLM - Text embedding - Added the output_dimensionality and learning_rate_multiplier parameters to text embedding tuning (Preview only)
PiperOrigin-RevId: 631976561
1 parent 6150322 commit cc8bc96

File tree

2 files changed

+128
-12
lines changed

2 files changed

+128
-12
lines changed

tests/unit/aiplatform/test_language_models.py

+13 -3
Original file line number | Diff line number | Diff line change
@@ -1661,7 +1661,7 @@ def get_endpoint_mock():
16611661
@pytest.fixture
16621662
def mock_deploy_tuned_embedding_model(get_endpoint_mock):
16631663
with mock.patch.object(
1664-
_language_models._TunableTextEmbeddingModelMixin, "deploy_tuned_model"
1664+
_language_models._PreviewTunableTextEmbeddingModelMixin, "deploy_tuned_model"
16651665
) as mock_text_generation_model:
16661666
mock_text_generation_model.return_value._model_id = (
16671667
test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
@@ -2289,10 +2289,11 @@ def test_text_generation_response_repr(self):
22892289
indirect=True,
22902290
)
22912291
@pytest.mark.parametrize(
2292-
"base_model_version_id,tune_args,expected_pipeline_args",
2292+
"base_model_version_id,use_preview_module,tune_args,expected_pipeline_args",
22932293
[ # Do not pass any optional parameters.
22942294
(
22952295
"textembedding-gecko@003",
2296+
False,
22962297
dict(
22972298
training_data="gs://bucket/training.tsv",
22982299
corpus_data="gs://bucket/corpus.jsonl",
@@ -2309,6 +2310,7 @@ def test_text_generation_response_repr(self):
23092310
# Pass all optional parameters.
23102311
(
23112312
"text-multilingual-embedding-002",
2313+
True,
23122314
dict(
23132315
training_data="gs://bucket/training.tsv",
23142316
corpus_data="gs://bucket/corpus.jsonl",
@@ -2323,6 +2325,8 @@ def test_text_generation_response_repr(self):
23232325
accelerator_count=1,
23242326
machine_type="n1-highmem-16",
23252327
task_type="DEFAULT",
2328+
output_dimensionality=128,
2329+
learning_rate_multiplier=0.1,
23262330
),
23272331
dict(
23282332
train_steps=30,
@@ -2339,6 +2343,8 @@ def test_text_generation_response_repr(self):
23392343
validation_label_path="gs://bucket/validation.tsv",
23402344
encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
23412345
task_type="DEFAULT",
2346+
output_dimensionality=128,
2347+
learning_rate_multiplier=0.1,
23422348
),
23432349
),
23442350
],
@@ -2357,6 +2363,7 @@ def test_tune_text_embedding_model(
23572363
tune_args,
23582364
expected_pipeline_args,
23592365
base_model_version_id,
2366+
use_preview_module,
23602367
):
23612368
"""Tests tuning the text embedding model."""
23622369
aiplatform.init(
@@ -2371,7 +2378,10 @@ def test_tune_text_embedding_model(
23712378
_TEXT_GECKO_PUBLISHER_MODEL_DICT
23722379
),
23732380
):
2374-
model = language_models.TextEmbeddingModel.from_pretrained(
2381+
language_models_module = (
2382+
preview_language_models if use_preview_module else language_models
2383+
)
2384+
model = language_models_module.TextEmbeddingModel.from_pretrained(
23752385
base_model_version_id
23762386
)
23772387
tuning_job = model.tune_model(**tune_args)

vertexai/language_models/_language_models.py

+115 -9
Original file line number | Diff line number | Diff line change
@@ -239,6 +239,7 @@ def tune_model(
239239
accelerator_count: Optional[int] = None,
240240
accelerator_type: Optional[_ACCELERATOR_TYPE_TYPE] = None,
241241
max_context_length: Optional[str] = None,
242+
output_dimensionality: Optional[int] = None,
242243
) -> "_LanguageModelTuningJob":
243244
"""Tunes a model based on training data.
244245
@@ -273,6 +274,8 @@ def tune_model(
273274
accelerator_type: Type of accelerator to use. Type can be "TPU" or "GPU". Type is ignored if accelerator is specified.
274275
max_context_length: The max context length used for tuning.
275276
Can be either '8k' or '32k'
277+
output_dimensionality: The output dimensionality of the tuned model,
278+
for text embedding tuning.
276279
277280
Returns:
278281
A `LanguageModelTuningJob` object that represents the tuning job.
@@ -293,6 +296,8 @@ def tune_model(
293296
tuning_parameters["batch_size"] = batch_size
294297
if train_steps is not None:
295298
tuning_parameters["train_steps"] = train_steps
299+
if output_dimensionality is not None:
300+
tuning_parameters["output_dimensionality"] = output_dimensionality
296301
if learning_rate is not None:
297302
_LOGGER.warning(
298303
"The learning_rate parameter is deprecated."
@@ -2189,7 +2194,7 @@ async def get_embeddings_async(
21892194
# for corpus, queries, test and validation data.
21902195
# TODO(b/625884109): Validate input args, batch_size >0 and train_steps >30, and
21912196
# task_type must be 'DEFAULT' or None if _model_id is textembedding-gecko@001.
2192-
class _TunableTextEmbeddingModelMixin(_TunableModelMixin):
2197+
class _PreviewTunableTextEmbeddingModelMixin(_TunableModelMixin):
21932198
@classmethod
21942199
def get_tuned_model(cls, *args, **kwargs):
21952200
del args, kwargs # Unused.
@@ -2213,7 +2218,9 @@ def tune_model(
22132218
machine_type: Optional[str] = None,
22142219
accelerator: Optional[str] = None,
22152220
accelerator_count: Optional[int] = None,
2216-
) -> "_LanguageModelTuningJob":
2221+
output_dimensionality: Optional[int] = None,
2222+
learning_rate_multiplier: Optional[float] = None,
2223+
) -> "_TextEmbeddingModelTuningJob":
22172224
"""Tunes a model based on training data.
22182225
22192226
This method launches and returns an asynchronous model tuning job.
@@ -2229,14 +2236,30 @@ def tune_model(
22292236
queries_data: URI pointing to data in JSON lines format.
22302237
test_data: URI pointing to data in TSV format.
22312238
validation_data: URI pointing to data in TSV format.
2232-
batch_size: Size of batch.
2233-
train_steps: Number of training batches to tune on.
2239+
batch_size: The training batch size.
2240+
train_steps: The number of steps to perform model tuning. Must
2241+
be greater than 30.
22342242
tuned_model_location: GCP location where the tuned model should be deployed.
22352243
model_display_name: Custom display name for the tuned model.
2236-
task_type: Type of task. Can be "RETRIEVAL_QUERY", "RETRIEVAL_DOCUMENT", "SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "QUESTION_ANSWERING", or "FACT_VERIFICATION".
2237-
machine_type: Machine type. E.g., "a2-highgpu-1g". See also: https://cloud.google.com/vertex-ai/docs/training/configure-compute.
2238-
accelerator_count: Count of accelerators.
2239-
accelerator: Kind of accelerator. E.g., "NVIDIA_TESLA_A100". See also: https://cloud.google.com/vertex-ai/docs/training/configure-compute.
2244+
task_type: The task type expected to be used during inference.
2245+
Valid values are `DEFAULT`, `RETRIEVAL_QUERY`, `RETRIEVAL_DOCUMENT`,
2246+
`SEMANTIC_SIMILARITY`, `CLASSIFICATION`, `CLUSTERING`,
2247+
`FACT_VERIFICATION`, and `QUESTION_ANSWERING`.
2248+
machine_type: The machine type to use for training. For information
2249+
about selecting the machine type that matches the accelerator
2250+
type and count you have selected, see
2251+
https://cloud.google.com/compute/docs/gpus.
2252+
accelerator: The accelerator type to use for tuning, for example
2253+
`NVIDIA_TESLA_V100`. For possible values, see
2254+
https://cloud.google.com/vertex-ai/generative-ai/docs/models/tune-embeddings#using-accelerators.
2255+
accelerator_count: The number of accelerators to use when training.
2256+
Using a greater number of accelerators may make training faster,
2257+
but has no effect on quality.
2258+
output_dimensionality: The desired embedding dimension of your
2259+
tuned model, up to 768. This is only supported for models
2260+
`text-embedding-004` and `text-multilingual-embedding-002`.
2261+
learning_rate_multiplier: A multiplier to apply to the
2262+
recommended learning rate during tuning.
22402263
Returns:
22412264
A `_TextEmbeddingModelTuningJob` object that represents the tuning job.
22422265
Calling `job.get_tuned_model()` blocks until the tuning is complete and returns the tuned model.
@@ -2260,6 +2283,8 @@ def tune_model(
22602283
machine_type=machine_type,
22612284
accelerator=accelerator,
22622285
accelerator_count=accelerator_count,
2286+
output_dimensionality=output_dimensionality,
2287+
learning_rate_multiplier=learning_rate_multiplier,
22632288
)
22642289

22652290
def _bundle_up_tuning_job(self, pipeline_job):
@@ -2318,14 +2343,95 @@ def deploy_tuned_model(
23182343
return model
23192344

23202345

2346+
class _TunableTextEmbeddingModelMixin(_PreviewTunableTextEmbeddingModelMixin):
2347+
def tune_model(
2348+
self,
2349+
*,
2350+
training_data: Optional[str] = None,
2351+
corpus_data: Optional[str] = None,
2352+
queries_data: Optional[str] = None,
2353+
test_data: Optional[str] = None,
2354+
validation_data: Optional[str] = None,
2355+
batch_size: Optional[int] = None,
2356+
train_steps: Optional[int] = None,
2357+
tuned_model_location: Optional[str] = None,
2358+
model_display_name: Optional[str] = None,
2359+
task_type: Optional[str] = None,
2360+
machine_type: Optional[str] = None,
2361+
accelerator: Optional[str] = None,
2362+
accelerator_count: Optional[int] = None,
2363+
) -> "_TextEmbeddingModelTuningJob":
2364+
"""Tunes a model based on training data.
2365+
2366+
This method launches and returns an asynchronous model tuning job.
2367+
Usage:
2368+
```
2369+
tuning_job = model.tune_model(...)
2370+
... do some other work
2371+
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
2372+
2373+
Args:
2374+
training_data: URI pointing to training data in TSV format.
2375+
corpus_data: URI pointing to data in JSON lines format.
2376+
queries_data: URI pointing to data in JSON lines format.
2377+
test_data: URI pointing to data in TSV format.
2378+
validation_data: URI pointing to data in TSV format.
2379+
batch_size: The training batch size.
2380+
train_steps: The number of steps to perform model tuning. Must
2381+
be greater than 30.
2382+
tuned_model_location: GCP location where the tuned model should be deployed.
2383+
model_display_name: Custom display name for the tuned model.
2384+
task_type: The task type expected to be used during inference.
2385+
Valid values are `DEFAULT`, `RETRIEVAL_QUERY`, `RETRIEVAL_DOCUMENT`,
2386+
`SEMANTIC_SIMILARITY`, `CLASSIFICATION`, `CLUSTERING`,
2387+
`FACT_VERIFICATION`, and `QUESTION_ANSWERING`.
2388+
machine_type: The machine type to use for training. For information
2389+
about selecting the machine type that matches the accelerator
2390+
type and count you have selected, see
2391+
https://cloud.google.com/compute/docs/gpus.
2392+
accelerator: The accelerator type to use for tuning, for example
2393+
`NVIDIA_TESLA_V100`. For possible values, see
2394+
https://cloud.google.com/vertex-ai/generative-ai/docs/models/tune-embeddings#using-accelerators.
2395+
accelerator_count: The number of accelerators to use when training.
2396+
Using a greater number of accelerators may make training faster,
2397+
but has no effect on quality.
2398+
Returns:
2399+
A `_TextEmbeddingModelTuningJob` object that represents the tuning job.
2400+
Calling `job.get_tuned_model()` blocks until the tuning is complete and
2401+
returns the tuned model.
2402+
2403+
Raises:
2404+
ValueError: If the provided parameter combinations or values are not
2405+
supported.
2406+
RuntimeError: If the model does not support tuning.
2407+
"""
2408+
2409+
return super().tune_model(
2410+
training_data=training_data,
2411+
corpus_data=corpus_data,
2412+
queries_data=queries_data,
2413+
test_data=test_data,
2414+
validation_data=validation_data,
2415+
task_type=task_type,
2416+
batch_size=batch_size,
2417+
train_steps=train_steps,
2418+
tuned_model_location=tuned_model_location,
2419+
model_display_name=model_display_name,
2420+
machine_type=machine_type,
2421+
accelerator=accelerator,
2422+
accelerator_count=accelerator_count,
2423+
)
2424+
2425+
23212426
class TextEmbeddingModel(_TextEmbeddingModel, _TunableTextEmbeddingModelMixin):
23222427
__module__ = "vertexai.language_models"
23232428

23242429

23252430
class _PreviewTextEmbeddingModel(
2326-
TextEmbeddingModel,
2431+
_TextEmbeddingModel,
23272432
_ModelWithBatchPredict,
23282433
_CountTokensMixin,
2434+
_PreviewTunableTextEmbeddingModelMixin,
23292435
):
23302436
__name__ = "TextEmbeddingModel"
23312437
__module__ = "vertexai.preview.language_models"

0 commit comments

Comments
 (0)