@@ -139,6 +139,7 @@ def tune_model(
    training_data: Union[str, "pandas.core.frame.DataFrame"],
    *,
    train_steps: int = 1000,
+    learning_rate: Optional[float] = None,
    tuning_job_location: Optional[str] = None,
    tuned_model_location: Optional[str] = None,
    model_display_name: Optional[str] = None,
@@ -151,6 +152,7 @@ def tune_model(
        training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
            The dataset must have the "input_text" and "output_text" columns.
        train_steps: Number of training steps to perform.
+        learning_rate: Learning rate for the tuning.
        tuning_job_location: GCP location where the tuning job should be run. Only "europe-west4" is supported for now.
        tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
        model_display_name: Custom display name for the tuned model.
@@ -184,6 +186,7 @@ def tune_model(
        model_id=model_info.tuning_model_id,
        tuning_pipeline_uri=model_info.tuning_pipeline_uri,
        model_display_name=model_display_name,
+        learning_rate=learning_rate,
    )

    job = _LanguageModelTuningJob(
@@ -1041,6 +1044,7 @@ def _launch_tuning_job(
    tuning_pipeline_uri: str,
    train_steps: Optional[int] = None,
    model_display_name: Optional[str] = None,
+    learning_rate: Optional[float] = None,
) -> aiplatform.PipelineJob:
    output_dir_uri = _generate_tuned_model_dir_uri(model_id=model_id)
    if isinstance(training_data, str):
@@ -1062,6 +1066,7 @@ def _launch_tuning_job(
        train_steps=train_steps,
        tuning_pipeline_uri=tuning_pipeline_uri,
        model_display_name=model_display_name,
+        learning_rate=learning_rate,
    )
    return job

@@ -1071,11 +1076,15 @@ def _launch_tuning_job_on_jsonl_data(
    dataset_name_or_uri: str,
    tuning_pipeline_uri: str,
    train_steps: Optional[int] = None,
+    learning_rate: Optional[float] = None,
    model_display_name: Optional[str] = None,
) -> aiplatform.PipelineJob:
    if not model_display_name:
        # Creating a human-readable model display name
-        name = f"{model_id} tuned for {train_steps} steps on "
+        name = f"{model_id} tuned for {train_steps} steps"
+        if learning_rate:
+            name += f" with learning rate {learning_rate}"
+        name += " on "
        # Truncating the start of the dataset URI to keep total length <= 128.
        max_display_name_length = 128
        if len(dataset_name_or_uri + name) <= max_display_name_length:
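For illustration (hypothetical values, not taken from the change itself): with `model_id="text-bison@001"`, `train_steps=100`, and `learning_rate=0.1`, the generated prefix becomes `text-bison@001 tuned for 100 steps with learning rate 0.1 on `, after which the dataset URI is appended and truncated so the full display name stays within 128 characters.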
@@ -1095,6 +1104,8 @@ def _launch_tuning_job_on_jsonl_data(
        "large_model_reference": model_id,
        "model_display_name": model_display_name,
    }
+    if learning_rate:
+        pipeline_arguments["learning_rate"] = learning_rate

    if dataset_name_or_uri.startswith("projects/"):
        pipeline_arguments["dataset_name"] = dataset_name_or_uri
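For reference, a minimal caller-side sketch of how the new `learning_rate` keyword could be passed through `tune_model`. The `TextGenerationModel` entry point, model name, and dataset URI below are assumptions for illustration, not part of this diff:

```python
# Hypothetical usage sketch; assumes this change lands in the Vertex AI SDK's
# language_models module and that tune_model is exposed on TextGenerationModel.
from vertexai.preview.language_models import TextGenerationModel

model = TextGenerationModel.from_pretrained("text-bison@001")  # illustrative model name

model.tune_model(
    training_data="gs://my-bucket/tuning_data.jsonl",  # hypothetical JSONL dataset URI
    train_steps=100,
    learning_rate=0.1,  # new optional keyword introduced by this change
    tuning_job_location="europe-west4",
    tuned_model_location="us-central1",
)
```

Leaving `learning_rate` at its default of `None` preserves the previous behavior, since the `learning_rate` pipeline argument is only set when a value is provided.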