Skip to content

Commit 249a5fa

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Expose new text embedding tuning parameters in GA namespace.
PiperOrigin-RevId: 646160221
1 parent d4b0091 commit 249a5fa

File tree

2 files changed

+4
-86
lines changed

2 files changed

+4
-86
lines changed

tests/unit/aiplatform/test_language_models.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -2290,11 +2290,10 @@ def test_text_generation_response_repr(self):
22902290
indirect=True,
22912291
)
22922292
@pytest.mark.parametrize(
2293-
"base_model_version_id,use_preview_module,tune_args,expected_pipeline_args",
2293+
"base_model_version_id,tune_args,expected_pipeline_args",
22942294
[ # Do not pass any optional parameters.
22952295
(
22962296
"textembedding-gecko@003",
2297-
False,
22982297
dict(
22992298
training_data="gs://bucket/training.tsv",
23002299
corpus_data="gs://bucket/corpus.jsonl",
@@ -2311,7 +2310,6 @@ def test_text_generation_response_repr(self):
23112310
# Pass all optional parameters.
23122311
(
23132312
"text-multilingual-embedding-002",
2314-
True,
23152313
dict(
23162314
training_data="gs://bucket/training.tsv",
23172315
corpus_data="gs://bucket/corpus.jsonl",
@@ -2364,7 +2362,6 @@ def test_tune_text_embedding_model(
23642362
tune_args,
23652363
expected_pipeline_args,
23662364
base_model_version_id,
2367-
use_preview_module,
23682365
):
23692366
"""Tests tuning the text embedding model."""
23702367
aiplatform.init(
@@ -2379,10 +2376,7 @@ def test_tune_text_embedding_model(
23792376
_TEXT_GECKO_PUBLISHER_MODEL_DICT
23802377
),
23812378
):
2382-
language_models_module = (
2383-
preview_language_models if use_preview_module else language_models
2384-
)
2385-
model = language_models_module.TextEmbeddingModel.from_pretrained(
2379+
model = language_models.TextEmbeddingModel.from_pretrained(
23862380
base_model_version_id
23872381
)
23882382
tuning_result = model.tune_model(**tune_args)

vertexai/language_models/_language_models.py

+2-78
Original file line numberDiff line numberDiff line change
@@ -2268,7 +2268,7 @@ def tune_model(
22682268
```
22692269
tuning_job = model.tune_model(...)
22702270
... do some other work
2271-
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
2271+
tuned_model = tuning_job.deploy_tuned_model() # Blocks until tuning is complete
22722272
22732273
Args:
22742274
training_data: URI pointing to training data in TSV format.
@@ -2414,83 +2414,7 @@ def deploy_tuned_model(
24142414

24152415

24162416
class _TunableTextEmbeddingModelMixin(_PreviewTunableTextEmbeddingModelMixin):
2417-
def tune_model(
2418-
self,
2419-
*,
2420-
training_data: Optional[str] = None,
2421-
corpus_data: Optional[str] = None,
2422-
queries_data: Optional[str] = None,
2423-
test_data: Optional[str] = None,
2424-
validation_data: Optional[str] = None,
2425-
batch_size: Optional[int] = None,
2426-
train_steps: Optional[int] = None,
2427-
tuned_model_location: Optional[str] = None,
2428-
model_display_name: Optional[str] = None,
2429-
task_type: Optional[str] = None,
2430-
machine_type: Optional[str] = None,
2431-
accelerator: Optional[str] = None,
2432-
accelerator_count: Optional[int] = None,
2433-
) -> "_TextEmbeddingModelTuningJob":
2434-
"""Tunes a model based on training data.
2435-
2436-
This method launches and returns an asynchronous model tuning job.
2437-
Usage:
2438-
```
2439-
tuning_job = model.tune_model(...)
2440-
... do some other work
2441-
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
2442-
2443-
Args:
2444-
training_data: URI pointing to training data in TSV format.
2445-
corpus_data: URI pointing to data in JSON lines format.
2446-
queries_data: URI pointing to data in JSON lines format.
2447-
test_data: URI pointing to data in TSV format.
2448-
validation_data: URI pointing to data in TSV format.
2449-
batch_size: The training batch size.
2450-
train_steps: The number of steps to perform model tuning. Must
2451-
be greater than 30.
2452-
tuned_model_location: GCP location where the tuned model should be deployed.
2453-
model_display_name: Custom display name for the tuned model.
2454-
task_type: The task type expected to be used during inference.
2455-
Valid values are `DEFAULT`, `RETRIEVAL_QUERY`, `RETRIEVAL_DOCUMENT`,
2456-
`SEMANTIC_SIMILARITY`, `CLASSIFICATION`, `CLUSTERING`,
2457-
`FACT_VERIFICATION`, and `QUESTION_ANSWERING`.
2458-
machine_type: The machine type to use for training. For information
2459-
about selecting the machine type that matches the accelerator
2460-
type and count you have selected, see
2461-
https://cloud.google.com/compute/docs/gpus.
2462-
accelerator: The accelerator type to use for tuning, for example
2463-
`NVIDIA_TESLA_V100`. For possible values, see
2464-
https://cloud.google.com/vertex-ai/generative-ai/docs/models/tune-embeddings#using-accelerators.
2465-
accelerator_count: The number of accelerators to use when training.
2466-
Using a greater number of accelerators may make training faster,
2467-
but has no effect on quality.
2468-
Returns:
2469-
A `LanguageModelTuningJob` object that represents the tuning job.
2470-
Calling `job.result()` blocks until the tuning is complete and
2471-
returns a `LanguageModel` object.
2472-
2473-
Raises:
2474-
ValueError: If the provided parameter combinations or values are not
2475-
supported.
2476-
RuntimeError: If the model does not support tuning
2477-
"""
2478-
2479-
return super().tune_model(
2480-
training_data=training_data,
2481-
corpus_data=corpus_data,
2482-
queries_data=queries_data,
2483-
test_data=test_data,
2484-
validation_data=validation_data,
2485-
task_type=task_type,
2486-
batch_size=batch_size,
2487-
train_steps=train_steps,
2488-
tuned_model_location=tuned_model_location,
2489-
model_display_name=model_display_name,
2490-
machine_type=machine_type,
2491-
accelerator=accelerator,
2492-
accelerator_count=accelerator_count,
2493-
)
2417+
pass
24942418

24952419

24962420
class TextEmbeddingModel(_TextEmbeddingModel, _TunableTextEmbeddingModelMixin):

0 commit comments

Comments
 (0)