
Commit 5a300c1

vertex-sdk-bot authored and copybara-github committed
feat: LLM - Text Embedding - Added validation for text embedding tuning parameters.
PiperOrigin-RevId: 632301450
1 parent cb8b10f · commit 5a300c1

2 files changed: +83 −3 lines changed


tests/unit/aiplatform/test_language_models.py (+52 lines)
@@ -2407,6 +2407,58 @@ def test_tune_text_embedding_model(
             == test_constants.EndpointConstants._TEST_ENDPOINT_NAME
         )

+    @pytest.mark.parametrize(
+        "optional_tune_args,error_regex",
+        [
+            (
+                dict(test_data="/tmp/bucket/test.tsv"),
+                "Each tuning dataset file must be a Google Cloud Storage URI starting with 'gs://'.",
+            ),
+            (
+                dict(output_dimensionality=-1),
+                "output_dimensionality must be an integer between 1 and 768",
+            ),
+            (
+                dict(learning_rate_multiplier=0),
+                "learning_rate_multiplier must be greater than 0",
+            ),
+            (
+                dict(train_steps=29),
+                "train_steps must be greater than or equal to 30",
+            ),
+            (
+                dict(batch_size=2048),
+                "batch_size must be between 1 and 1024",
+            ),
+        ],
+    )
+    def test_tune_text_embedding_model_invalid_values(
+        self, optional_tune_args, error_regex
+    ):
+        """Tests that certain embedding tuning values fail validation."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _TEXT_GECKO_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_language_models.TextEmbeddingModel.from_pretrained(
+                "text-multilingual-embedding-002"
+            )
+        with pytest.raises(ValueError, match=error_regex):
+            model.tune_model(
+                training_data="gs://bucket/training.tsv",
+                corpus_data="gs://bucket/corpus.jsonl",
+                queries_data="gs://bucket/queries.jsonl",
+                **optional_tune_args,
+            )
+
     @pytest.mark.parametrize(
         "job_spec",
         [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_JOB],
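To run only the new parametrized cases locally, here is a minimal sketch using pytest's Python entry point (this is not part of the commit; it assumes pytest and the repo's unit-test dependencies are installed and that the command runs from the repository root):

# Hypothetical local run limited to the new validation tests.
import pytest

pytest.main([
    "tests/unit/aiplatform/test_language_models.py",
    "-k", "test_tune_text_embedding_model_invalid_values",
    "-q",  # quiet output: one status character per parametrized case
])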

vertexai/language_models/_language_models.py (+31 −3 lines)
@@ -2192,8 +2192,6 @@ async def get_embeddings_async(

 # TODO(b/625884109): Support Union[str, "pandas.core.frame.DataFrame"]
 # for corpus, queries, test and validation data.
-# TODO(b/625884109): Validate input args, batch_size >0 and train_steps >30, and
-# task_type must be 'DEFAULT' or None if _model_id is textembedding-gecko@001.
 class _PreviewTunableTextEmbeddingModelMixin(_TunableModelMixin):
     @classmethod
     def get_tuned_model(cls, *args, **kwargs):

@@ -2265,9 +2263,39 @@ def tune_model(
         Calling `job.result()` blocks until the tuning is complete and returns a `LanguageModel` object.

         Raises:
-            ValueError: If the "tuned_model_location" value is not supported
+            ValueError: If the provided parameter combinations or values are not
+                supported.
             RuntimeError: If the model does not support tuning
         """
+        if batch_size is not None and batch_size not in range(1, 1024):
+            raise ValueError(
+                f"batch_size must be between 1 and 1024. Given {batch_size}."
+            )
+        if train_steps is not None and train_steps < 30:
+            raise ValueError(
+                f"train_steps must be greater than or equal to 30. Given {train_steps}."
+            )
+        if learning_rate_multiplier is not None and learning_rate_multiplier <= 0:
+            raise ValueError(
+                f"learning_rate_multiplier must be greater than 0. Given {learning_rate_multiplier}."
+            )
+        if output_dimensionality is not None and output_dimensionality not in range(
+            1, 769
+        ):
+            raise ValueError(
+                f"output_dimensionality must be an integer between 1 and 768. Given {output_dimensionality}."
+            )
+        for dataset in [
+            training_data,
+            corpus_data,
+            queries_data,
+            test_data,
+            validation_data,
+        ]:
+            if dataset is not None and not dataset.startswith("gs://"):
+                raise ValueError(
+                    f"Each tuning dataset file must be a Google Cloud Storage URI starting with 'gs://'. Given {dataset}."
+                )

         return super().tune_model(
             training_data=training_data,
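For context, a minimal sketch of the caller-visible effect of these checks (not part of the commit; the model name matches the test above, the gs:// paths are placeholders, and a configured Vertex AI environment is assumed):

# Illustrative only: after this change, an out-of-range tuning argument fails
# fast with a ValueError instead of launching a tuning pipeline.
from vertexai.preview.language_models import TextEmbeddingModel

model = TextEmbeddingModel.from_pretrained("text-multilingual-embedding-002")
try:
    model.tune_model(
        training_data="gs://my-bucket/training.tsv",  # placeholder URI
        corpus_data="gs://my-bucket/corpus.jsonl",    # placeholder URI
        queries_data="gs://my-bucket/queries.jsonl",  # placeholder URI
        batch_size=4096,  # outside the accepted 1-1024 range
    )
except ValueError as err:
    print(err)  # batch_size must be between 1 and 1024. Given 4096.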
