
Commit c6cdd10

Ark-kun authored and copybara-github committed

feat: LLM - Added support for learning_rate in tuning

PiperOrigin-RevId: 542784145

1 parent 750e161 commit c6cdd10

File tree

3 files changed: +18 −1 lines changed

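In practice, this change lets callers pass a learning rate straight through tune_model. A minimal usage sketch, assuming the preview language_models namespace the SDK exposed at the time of this commit; the model version and dataset URI are illustrative placeholders, not taken from this commit:

from vertexai.preview.language_models import TextGenerationModel

# Load a base model and start a tuning job with the new learning_rate
# parameter. The locations below match the values used in the tests;
# the model name and dataset URI are placeholders.
model = TextGenerationModel.from_pretrained("text-bison@001")
model.tune_model(
    training_data="gs://my-bucket/tuning-data.jsonl",  # placeholder URI
    train_steps=100,
    learning_rate=0.1,  # new in this commit
    tuning_job_location="europe-west4",
    tuned_model_location="us-central1",
)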

tests/system/aiplatform/test_language_models.py

+1
@@ -124,6 +124,7 @@ def test_tuning(self, shared_state):
             train_steps=1,
             tuning_job_location="europe-west4",
             tuned_model_location="us-central1",
+            learning_rate=2.0,
         )
         # According to the Pipelines design, external resources created by a pipeline
         # must not be modified or deleted. Otherwise caching will break next pipeline runs.

tests/unit/aiplatform/test_language_models.py

+5
@@ -661,8 +661,13 @@ def test_tune_model(
             training_data=_TEST_TEXT_BISON_TRAINING_DF,
             tuning_job_location="europe-west4",
             tuned_model_location="us-central1",
+            learning_rate=0.1,
         )
         call_kwargs = mock_pipeline_service_create.call_args[1]
+        pipeline_arguments = call_kwargs[
+            "pipeline_job"
+        ].runtime_config.parameter_values
+        assert pipeline_arguments["learning_rate"] == 0.1
         assert (
             call_kwargs["pipeline_job"].encryption_spec.kms_key_name
             == _TEST_ENCRYPTION_KEY_NAME

vertexai/language_models/_language_models.py

+12 −1

@@ -139,6 +139,7 @@ def tune_model(
         training_data: Union[str, "pandas.core.frame.DataFrame"],
         *,
         train_steps: int = 1000,
+        learning_rate: Optional[float] = None,
         tuning_job_location: Optional[str] = None,
         tuned_model_location: Optional[str] = None,
         model_display_name: Optional[str] = None,
@@ -151,6 +152,7 @@ def tune_model(
             training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
                 The dataset must have the "input_text" and "output_text" columns.
             train_steps: Number of training steps to perform.
+            learning_rate: Learning rate for the tuning.
             tuning_job_location: GCP location where the tuning job should be run. Only "europe-west4" is supported for now.
             tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
             model_display_name: Custom display name for the tuned model.
@@ -184,6 +186,7 @@ def tune_model(
             model_id=model_info.tuning_model_id,
             tuning_pipeline_uri=model_info.tuning_pipeline_uri,
             model_display_name=model_display_name,
+            learning_rate=learning_rate,
         )

         job = _LanguageModelTuningJob(
@@ -1041,6 +1044,7 @@ def _launch_tuning_job(
     tuning_pipeline_uri: str,
     train_steps: Optional[int] = None,
     model_display_name: Optional[str] = None,
+    learning_rate: Optional[float] = None,
 ) -> aiplatform.PipelineJob:
     output_dir_uri = _generate_tuned_model_dir_uri(model_id=model_id)
     if isinstance(training_data, str):
@@ -1062,6 +1066,7 @@ def _launch_tuning_job(
         train_steps=train_steps,
         tuning_pipeline_uri=tuning_pipeline_uri,
         model_display_name=model_display_name,
+        learning_rate=learning_rate,
     )
     return job

@@ -1071,11 +1076,15 @@ def _launch_tuning_job_on_jsonl_data(
     dataset_name_or_uri: str,
     tuning_pipeline_uri: str,
     train_steps: Optional[int] = None,
+    learning_rate: Optional[float] = None,
     model_display_name: Optional[str] = None,
 ) -> aiplatform.PipelineJob:
     if not model_display_name:
         # Creating a human-readable model display name
-        name = f"{model_id} tuned for {train_steps} steps on "
+        name = f"{model_id} tuned for {train_steps} steps"
+        if learning_rate:
+            name += f" with learning rate {learning_rate}"
+        name += " on "
         # Truncating the start of the dataset URI to keep total length <= 128.
         max_display_name_length = 128
         if len(dataset_name_or_uri + name) <= max_display_name_length:
@@ -1095,6 +1104,8 @@ def _launch_tuning_job_on_jsonl_data(
         "large_model_reference": model_id,
         "model_display_name": model_display_name,
     }
+    if learning_rate:
+        pipeline_arguments["learning_rate"] = learning_rate

     if dataset_name_or_uri.startswith("projects/"):
         pipeline_arguments["dataset_name"] = dataset_name_or_uri
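The display-name logic is easy to misread in the diff, so here is a standalone sketch of it; the helper name is hypothetical, extracted only for illustration, and shows the prefixes built before the dataset URI is appended and truncated:

# Hypothetical standalone extraction of the display-name prefix logic
# added above, to show the strings it produces.
def _display_name_prefix(model_id, train_steps, learning_rate=None):
    name = f"{model_id} tuned for {train_steps} steps"
    if learning_rate:
        name += f" with learning rate {learning_rate}"
    name += " on "
    return name

print(_display_name_prefix("text-bison@001", 1000))
# text-bison@001 tuned for 1000 steps on
print(_display_name_prefix("text-bison@001", 1000, 0.1))
# text-bison@001 tuned for 1000 steps with learning rate 0.1 on

Note that both new guards test truthiness (if learning_rate:) rather than "is not None", so an explicit learning_rate=0.0 is omitted from the display name and from the pipeline arguments, exactly as if it had not been passed.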
