
Commit e51c977

vertex-sdk-bot authored and copybara-github committed
feat: Support accelerator_type in RLHF tuning
PiperOrigin-RevId: 617572810
1 parent 805bd40 commit e51c977
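
For orientation, a minimal sketch of how the new argument is exercised from the preview SDK once this change lands. The project, model name, and dataset URIs are placeholders, and the prompt_data/preference_data parameter names are assumed from the preview API rather than shown in this diff:

    import vertexai
    from vertexai.preview.language_models import TextGenerationModel

    vertexai.init(project="my-project", location="us-central1")  # placeholder project

    model = TextGenerationModel.from_pretrained("text-bison@001")
    tuning_job = model.tune_model_rlhf(
        prompt_data="gs://bucket/prompts.jsonl",          # placeholder URIs
        preference_data="gs://bucket/preferences.jsonl",
        accelerator_type="TPU",  # new in this commit; "GPU" is the other documented value
    )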

File tree

- tests/unit/aiplatform/test_language_models.py
- vertexai/language_models/_language_models.py

2 files changed: +20 -8 lines changed


tests/unit/aiplatform/test_language_models.py (+8)

@@ -723,6 +723,11 @@ def reverse_string_2(s):""",
                 "isOptional": True,
                 "parameterType": "NUMBER_INTEGER",
             },
+            "accelerator_type": {
+                "defaultValue": "",
+                "isOptional": True,
+                "parameterType": "STRING",
+            },
             "tensorboard_resource_id": {
                 "isOptional": True,
                 "parameterType": "STRING",
@@ -2696,6 +2701,7 @@ def test_tune_text_generation_model_rlhf(
         kl_coeff = 0.3
         tensorboard_resource_id = _get_test_tensorboard_resource_id()
         eval_dataset = "gs://bucket/eval.jsonl"
+        accelerator_type = "TPU"

         with mock.patch.object(
             target=model_garden_service_client.ModelGardenServiceClient,
@@ -2718,6 +2724,7 @@ def test_tune_text_generation_model_rlhf(
                 reward_model_train_steps=reward_model_train_steps,
                 reinforcement_learning_train_steps=reinforcement_learning_train_steps,
                 kl_coeff=kl_coeff,
+                accelerator_type=accelerator_type,
                 tuning_evaluation_spec=preview_language_models.TuningEvaluationSpec(
                     tensorboard=tensorboard_resource_id,
                     evaluation_data=eval_dataset,
@@ -2756,6 +2763,7 @@ def test_tune_text_generation_model_rlhf(
             pipeline_arguments["tensorboard_resource_id"] == tensorboard_resource_id
         )
         assert pipeline_arguments["eval_dataset"] == eval_dataset
+        assert pipeline_arguments["accelerator_type"] == "TPU"

     @pytest.mark.parametrize(
         "job_spec",

vertexai/language_models/_language_models.py (+12 -8)

@@ -413,6 +413,7 @@ class _RlhfTuningParameters:
     deploy_model: Optional[bool] = None
     eval_dataset: Optional[str] = None
     project: Optional[str] = None
+    accelerator_type: Optional[_ACCELERATOR_TYPE_TYPE] = None
     tensorboard_resource_id: Optional[str] = None

     def asdict(self) -> Dict[str, Any]:
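
The new field defaults to None, and asdict() (visible as trailing context above) presumably filters out unset fields, so an omitted accelerator_type never reaches the pipeline and the component's empty-string default from the test schema applies. A self-contained sketch of that filtering pattern, under that assumption:

    import dataclasses
    from typing import Any, Dict, Optional

    @dataclasses.dataclass
    class _RlhfTuningParameters:  # trimmed to the one relevant field
        accelerator_type: Optional[str] = None

        def asdict(self) -> Dict[str, Any]:
            # Assumed behavior: drop None-valued fields so optional pipeline
            # parameters such as accelerator_type fall back to their defaults.
            return {k: v for k, v in dataclasses.asdict(self).items() if v is not None}
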
@@ -439,6 +440,7 @@ def tune_model_rlhf(
         kl_coeff: Optional[float] = None,
         default_context: Optional[str] = None,
         tuning_job_location: Optional[str] = None,
+        accelerator_type: Optional[_ACCELERATOR_TYPE_TYPE] = None,
         tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
     ) -> "_LanguageModelTuningJob":
         """Tunes a model using reinforcement learning from human feedback.
@@ -491,6 +493,7 @@ def tune_model_rlhf(
                 negative" or "Translate this sentence to Danish". Do not specify this
                 if your dataset already prepends the instruction to the inputs field.
             tuning_job_location: GCP location where the tuning job should be run.
+            accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
             tuning_evaluation_spec: Evaluation settings to use during tuning.

         Returns:
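
The hunks above reference _ACCELERATOR_TYPE_TYPE, and the validation below uses _ACCELERATOR_TYPES, but the diff never shows their definitions. Given the docstring's two documented values, they presumably look something like this (a reconstruction, not code from the commit):

    from typing import Literal

    _ACCELERATOR_TYPE_TYPE = Literal["TPU", "GPU"]  # assumed shape of the type alias
    _ACCELERATOR_TYPES = ("TPU", "GPU")             # assumed membership set used by the guard
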
@@ -527,6 +530,13 @@ def tune_model_rlhf(
             model_id=self._model_id,
         )

+        if accelerator_type:
+            if accelerator_type not in _ACCELERATOR_TYPES:
+                raise ValueError(
+                    f"Unsupported accelerator type: {accelerator_type}."
+                    f" Supported types: {_ACCELERATOR_TYPES}"
+                )
+
         tuning_parameters = _RlhfTuningParameters(
             prompt_dataset=prompt_dataset_uri,
             preference_dataset=preference_dataset_uri,
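
The guard's effect, reusing the hypothetical call from the first sketch: an unsupported value now fails fast, before any pipeline job is created.

    model.tune_model_rlhf(
        prompt_data="gs://bucket/prompts.jsonl",         # placeholders as before
        preference_data="gs://bucket/preferences.jsonl",
        accelerator_type="CPU",  # not in _ACCELERATOR_TYPES
    )
    # Raises ValueError: "Unsupported accelerator type: CPU. Supported types: ..."
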
@@ -542,6 +552,7 @@ def tune_model_rlhf(
             kl_coeff=kl_coeff,
             instruction=default_context,
             eval_dataset=eval_dataset,
+            accelerator_type=accelerator_type,
             tensorboard_resource_id=tensorboard_resource_id,
         )

@@ -574,7 +585,7 @@ def _tune_model_rlhf(
             raise ValueError(
                 _get_invalid_tuning_location_msg(
                     requested_location=tuning_parameters.location,
-                    valid_locations=_SUPPORTED_RLHF_LOCATIONS,
+                    valid_locations=_TUNING_LOCATIONS,
                 )
             )
         if self._model_id not in _SUPPORTED_RLHF_MODELS:
@@ -3433,13 +3444,6 @@ class _PreviewCodeGenerationModel(CodeGenerationModel, _CountTokensCodeGeneratio
 # Currently, deployment can only work in these locations
 _TUNED_MODEL_LOCATIONS = _SUPPORTED_LOCATIONS

-# TODO(b/318874365): Use _SUPPORTED_LOCATIONS defined above once DRZ for RLHF is
-# implemented.
-_SUPPORTED_RLHF_LOCATIONS = {
-    "us-central1",
-    "europe-west4",
-}
-
 # All models supported by RLHF that can also be used for online and batch prediction:
 _SUPPORTED_RLHF_MODELS = {
     "text-bison@001",
