Skip to content

Commit 4aa7745

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: LLM - Support tuning in the "us-central1" location
PiperOrigin-RevId: 547655421
1 parent c903e7d commit 4aa7745

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

vertexai/language_models/_language_models.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ def tune_model(
153153
The dataset must have the "input_text" and "output_text" columns.
154154
train_steps: Number of training batches to tune on (batch size is 8 samples).
155155
learning_rate: Learning rate for the tuning
156-
tuning_job_location: GCP location where the tuning job should be run. Only "europe-west4" is supported for now.
156+
tuning_job_location: GCP location where the tuning job should be run.
157+
Only "europe-west4" and "us-central1" locations are supported for now.
157158
tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
158159
model_display_name: Custom display name for the tuned model.
159160
@@ -166,9 +167,10 @@ def tune_model(
166167
ValueError: If the "tuned_model_location" value is not supported
167168
RuntimeError: If the model does not support tuning
168169
"""
169-
if tuning_job_location != _TUNING_LOCATION:
170+
if tuning_job_location not in _TUNING_LOCATIONS:
170171
raise ValueError(
171-
f'Tuning is only supported in the following locations: tuning_job_location="{_TUNING_LOCATION}"'
172+
"Please specify the tuning job location (`tuning_job_location`)."
173+
f"Tuning is supported in the following locations: {_TUNING_LOCATIONS}"
172174
)
173175
if tuned_model_location != _TUNED_MODEL_LOCATION:
174176
raise ValueError(
@@ -187,6 +189,7 @@ def tune_model(
187189
tuning_pipeline_uri=model_info.tuning_pipeline_uri,
188190
model_display_name=model_display_name,
189191
learning_rate=learning_rate,
192+
tuning_job_location=tuning_job_location,
190193
)
191194

192195
job = _LanguageModelTuningJob(
@@ -965,7 +968,7 @@ def predict(
965968

966969
###### Model tuning
967970
# Currently, tuning can only work in this location
968-
_TUNING_LOCATION = "europe-west4"
971+
_TUNING_LOCATIONS = ("europe-west4", "us-central1")
969972
# Currently, deployment can only work in this location
970973
_TUNED_MODEL_LOCATION = "us-central1"
971974

@@ -1051,6 +1054,7 @@ def _launch_tuning_job(
10511054
train_steps: Optional[int] = None,
10521055
model_display_name: Optional[str] = None,
10531056
learning_rate: Optional[float] = None,
1057+
tuning_job_location: str = _TUNING_LOCATIONS[0],
10541058
) -> aiplatform.PipelineJob:
10551059
output_dir_uri = _generate_tuned_model_dir_uri(model_id=model_id)
10561060
if isinstance(training_data, str):
@@ -1073,6 +1077,7 @@ def _launch_tuning_job(
10731077
tuning_pipeline_uri=tuning_pipeline_uri,
10741078
model_display_name=model_display_name,
10751079
learning_rate=learning_rate,
1080+
tuning_job_location=tuning_job_location,
10761081
)
10771082
return job
10781083

@@ -1084,6 +1089,7 @@ def _launch_tuning_job_on_jsonl_data(
10841089
train_steps: Optional[int] = None,
10851090
learning_rate: Optional[float] = None,
10861091
model_display_name: Optional[str] = None,
1092+
tuning_job_location: str = _TUNING_LOCATIONS[0],
10871093
) -> aiplatform.PipelineJob:
10881094
if not model_display_name:
10891095
# Creating a human-readable model display name
@@ -1126,7 +1132,7 @@ def _launch_tuning_job_on_jsonl_data(
11261132
display_name=None,
11271133
parameter_values=pipeline_arguments,
11281134
# TODO(b/275444101): Remove the explicit location once model can be deployed in all regions
1129-
location=_TUNING_LOCATION,
1135+
location=tuning_job_location,
11301136
)
11311137
job.submit()
11321138
return job

0 commit comments

Comments
 (0)