@@ -153,7 +153,8 @@ def tune_model(
153
153
The dataset must have the "input_text" and "output_text" columns.
154
154
train_steps: Number of training batches to tune on (batch size is 8 samples).
155
155
learning_rate: Learning rate for the tuning
156
- tuning_job_location: GCP location where the tuning job should be run. Only "europe-west4" is supported for now.
156
+ tuning_job_location: GCP location where the tuning job should be run.
157
+ Only "europe-west4" and "us-central1" locations are supported for now.
157
158
tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
158
159
model_display_name: Custom display name for the tuned model.
159
160
@@ -166,9 +167,10 @@ def tune_model(
166
167
ValueError: If the "tuned_model_location" value is not supported
167
168
RuntimeError: If the model does not support tuning
168
169
"""
169
- if tuning_job_location != _TUNING_LOCATION :
170
+ if tuning_job_location not in _TUNING_LOCATIONS :
170
171
raise ValueError (
171
- f'Tuning is only supported in the following locations: tuning_job_location="{ _TUNING_LOCATION } "'
172
+ "Please specify the tuning job location (`tuning_job_location`)."
173
+ f"Tuning is supported in the following locations: { _TUNING_LOCATIONS } "
172
174
)
173
175
if tuned_model_location != _TUNED_MODEL_LOCATION :
174
176
raise ValueError (
@@ -187,6 +189,7 @@ def tune_model(
187
189
tuning_pipeline_uri = model_info .tuning_pipeline_uri ,
188
190
model_display_name = model_display_name ,
189
191
learning_rate = learning_rate ,
192
+ tuning_job_location = tuning_job_location ,
190
193
)
191
194
192
195
job = _LanguageModelTuningJob (
@@ -965,7 +968,7 @@ def predict(
965
968
966
969
###### Model tuning
967
970
# Currently, tuning can only work in this location
968
- _TUNING_LOCATION = "europe-west4"
971
+ _TUNING_LOCATIONS = ( "europe-west4" , "us-central1" )
969
972
# Currently, deployment can only work in this location
970
973
_TUNED_MODEL_LOCATION = "us-central1"
971
974
@@ -1051,6 +1054,7 @@ def _launch_tuning_job(
1051
1054
train_steps : Optional [int ] = None ,
1052
1055
model_display_name : Optional [str ] = None ,
1053
1056
learning_rate : Optional [float ] = None ,
1057
+ tuning_job_location : str = _TUNING_LOCATIONS [0 ],
1054
1058
) -> aiplatform .PipelineJob :
1055
1059
output_dir_uri = _generate_tuned_model_dir_uri (model_id = model_id )
1056
1060
if isinstance (training_data , str ):
@@ -1073,6 +1077,7 @@ def _launch_tuning_job(
1073
1077
tuning_pipeline_uri = tuning_pipeline_uri ,
1074
1078
model_display_name = model_display_name ,
1075
1079
learning_rate = learning_rate ,
1080
+ tuning_job_location = tuning_job_location ,
1076
1081
)
1077
1082
return job
1078
1083
@@ -1084,6 +1089,7 @@ def _launch_tuning_job_on_jsonl_data(
1084
1089
train_steps : Optional [int ] = None ,
1085
1090
learning_rate : Optional [float ] = None ,
1086
1091
model_display_name : Optional [str ] = None ,
1092
+ tuning_job_location : str = _TUNING_LOCATIONS [0 ],
1087
1093
) -> aiplatform .PipelineJob :
1088
1094
if not model_display_name :
1089
1095
# Creating a human-readable model display name
@@ -1126,7 +1132,7 @@ def _launch_tuning_job_on_jsonl_data(
1126
1132
display_name = None ,
1127
1133
parameter_values = pipeline_arguments ,
1128
1134
# TODO(b/275444101): Remove the explicit location once model can be deployed in all regions
1129
- location = _TUNING_LOCATION ,
1135
+ location = tuning_job_location ,
1130
1136
)
1131
1137
job .submit ()
1132
1138
return job
0 commit comments