Skip to content

Commit 755c3f9

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: LLM - Support model evaluation when tuning chat models (ChatModel, CodeChatModel)
PiperOrigin-RevId: 580611746
1 parent dcb6205 commit 755c3f9

File tree

2 files changed

+74
-1
lines changed

2 files changed

+74
-1
lines changed

tests/unit/aiplatform/test_language_models.py

+53
Original file line numberDiff line numberDiff line change
@@ -2125,12 +2125,18 @@ def test_tune_chat_model(
21252125
):
21262126
model = language_models.ChatModel.from_pretrained("chat-bison@001")
21272127

2128+
tuning_job_location = "europe-west4"
2129+
tensorboard_name = f"projects/{_TEST_PROJECT}/locations/{tuning_job_location}/tensorboards/123"
2130+
21282131
default_context = "Default context"
21292132
tuning_job = model.tune_model(
21302133
training_data=_TEST_TEXT_BISON_TRAINING_DF,
21312134
tuning_job_location="europe-west4",
21322135
tuned_model_location="us-central1",
21332136
default_context=default_context,
2137+
tuning_evaluation_spec=preview_language_models.TuningEvaluationSpec(
2138+
tensorboard=tensorboard_name,
2139+
),
21342140
accelerator_type="TPU",
21352141
)
21362142
call_kwargs = mock_pipeline_service_create.call_args[1]
@@ -2140,6 +2146,7 @@ def test_tune_chat_model(
21402146
assert pipeline_arguments["large_model_reference"] == "chat-bison@001"
21412147
assert pipeline_arguments["default_context"] == default_context
21422148
assert pipeline_arguments["accelerator_type"] == "TPU"
2149+
assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name
21432150

21442151
# Testing the tuned model
21452152
tuned_model = tuning_job.get_tuned_model()
@@ -2148,6 +2155,26 @@ def test_tune_chat_model(
21482155
== test_constants.EndpointConstants._TEST_ENDPOINT_NAME
21492156
)
21502157

2158+
unsupported_tuning_evaluation_spec_att = (
2159+
{"evaluation_data": "gs://bucket/eval.jsonl"},
2160+
{"evaluation_interval": 37},
2161+
{"enable_early_stopping": True},
2162+
{"enable_checkpoint_selection": True},
2163+
)
2164+
for unsupported_att in unsupported_tuning_evaluation_spec_att:
2165+
unsupported_tuning_evaluation_spec = (
2166+
preview_language_models.TuningEvaluationSpec(**unsupported_att)
2167+
)
2168+
with pytest.raises(AttributeError):
2169+
model.tune_model(
2170+
training_data=_TEST_TEXT_BISON_TRAINING_DF,
2171+
tuning_job_location="europe-west4",
2172+
tuned_model_location="us-central1",
2173+
default_context=default_context,
2174+
tuning_evaluation_spec=unsupported_tuning_evaluation_spec,
2175+
accelerator_type="TPU",
2176+
)
2177+
21512178
@pytest.mark.parametrize(
21522179
"job_spec",
21532180
[_TEST_PIPELINE_SPEC_JSON],
@@ -2228,12 +2255,18 @@ def test_tune_code_chat_model(
22282255
):
22292256
model = language_models.CodeChatModel.from_pretrained("codechat-bison@001")
22302257

2258+
tuning_job_location = "europe-west4"
2259+
tensorboard_name = f"projects/{_TEST_PROJECT}/locations/{tuning_job_location}/tensorboards/123"
2260+
22312261
# The tune_model call needs to be inside the PublisherModel mock
22322262
# since it gets a new PublisherModel when tuning completes.
22332263
model.tune_model(
22342264
training_data=_TEST_TEXT_BISON_TRAINING_DF,
22352265
tuning_job_location="europe-west4",
22362266
tuned_model_location="us-central1",
2267+
tuning_evaluation_spec=preview_language_models.TuningEvaluationSpec(
2268+
tensorboard=tensorboard_name,
2269+
),
22372270
accelerator_type="TPU",
22382271
)
22392272
call_kwargs = mock_pipeline_service_create.call_args[1]
@@ -2242,6 +2275,26 @@ def test_tune_code_chat_model(
22422275
].runtime_config.parameter_values
22432276
assert pipeline_arguments["large_model_reference"] == "codechat-bison@001"
22442277
assert pipeline_arguments["accelerator_type"] == "TPU"
2278+
assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name
2279+
2280+
unsupported_tuning_evaluation_spec_att = (
2281+
{"evaluation_data": "gs://bucket/eval.jsonl"},
2282+
{"evaluation_interval": 37},
2283+
{"enable_early_stopping": True},
2284+
{"enable_checkpoint_selection": True},
2285+
)
2286+
for unsupported_att in unsupported_tuning_evaluation_spec_att:
2287+
unsupported_tuning_evaluation_spec = (
2288+
preview_language_models.TuningEvaluationSpec(**unsupported_att)
2289+
)
2290+
with pytest.raises(AttributeError):
2291+
model.tune_model(
2292+
training_data=_TEST_TEXT_BISON_TRAINING_DF,
2293+
tuning_job_location="europe-west4",
2294+
tuned_model_location="us-central1",
2295+
tuning_evaluation_spec=unsupported_tuning_evaluation_spec,
2296+
accelerator_type="TPU",
2297+
)
22452298

22462299
@pytest.mark.usefixtures(
22472300
"get_model_with_tuned_version_label_mock",

vertexai/language_models/_language_models.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,7 @@ def tune_model(
496496
model_display_name: Optional[str] = None,
497497
default_context: Optional[str] = None,
498498
accelerator_type: Optional[_ACCELERATOR_TYPE_TYPE] = None,
499+
tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
499500
) -> "_LanguageModelTuningJob":
500501
"""Tunes a model based on training data.
501502
@@ -520,6 +521,7 @@ def tune_model(
520521
model_display_name: Custom display name for the tuned model.
521522
default_context: The context to use for all training samples by default.
522523
accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
524+
tuning_evaluation_spec: Specification for the model evaluation during tuning.
523525
524526
Returns:
525527
A `LanguageModelTuningJob` object that represents the tuning job.
@@ -529,8 +531,25 @@ def tune_model(
529531
ValueError: If the "tuning_job_location" value is not supported
530532
ValueError: If the "tuned_model_location" value is not supported
531533
RuntimeError: If the model does not support tuning
534+
AttributeError: If any attribute in the "tuning_evaluation_spec" is not supported
532535
"""
533-
# Note: Chat models do not support tuning_evaluation_spec
536+
537+
if tuning_evaluation_spec is not None:
538+
unsupported_chat_model_tuning_eval_spec = {
539+
"evaluation_data": tuning_evaluation_spec.evaluation_data,
540+
"evaluation_interval": tuning_evaluation_spec.evaluation_interval,
541+
"enable_early_stopping": tuning_evaluation_spec.enable_early_stopping,
542+
"enable_checkpoint_selection": tuning_evaluation_spec.enable_checkpoint_selection,
543+
}
544+
545+
for att_name, att_value in unsupported_chat_model_tuning_eval_spec.items():
546+
if not att_value is None:
547+
raise AttributeError(
548+
(
549+
f"ChatModel and CodeChatModel only support tensorboard as attribute for TuningEvaluationSpec"
550+
f"found attribute name {att_name} with value {att_value}, please leave {att_name} to None"
551+
)
552+
)
534553
return super().tune_model(
535554
training_data=training_data,
536555
train_steps=train_steps,
@@ -540,6 +559,7 @@ def tune_model(
540559
model_display_name=model_display_name,
541560
default_context=default_context,
542561
accelerator_type=accelerator_type,
562+
tuning_evaluation_spec=tuning_evaluation_spec,
543563
)
544564

545565

0 commit comments

Comments (0)