@@ -413,6 +413,7 @@ class _RlhfTuningParameters:
413
413
deploy_model : Optional [bool ] = None
414
414
eval_dataset : Optional [str ] = None
415
415
project : Optional [str ] = None
416
+ accelerator_type : Optional [_ACCELERATOR_TYPE_TYPE ] = None
416
417
tensorboard_resource_id : Optional [str ] = None
417
418
418
419
def asdict (self ) -> Dict [str , Any ]:
@@ -439,6 +440,7 @@ def tune_model_rlhf(
439
440
kl_coeff : Optional [float ] = None ,
440
441
default_context : Optional [str ] = None ,
441
442
tuning_job_location : Optional [str ] = None ,
443
+ accelerator_type : Optional [_ACCELERATOR_TYPE_TYPE ] = None ,
442
444
tuning_evaluation_spec : Optional ["TuningEvaluationSpec" ] = None ,
443
445
) -> "_LanguageModelTuningJob" :
444
446
"""Tunes a model using reinforcement learning from human feedback.
@@ -491,6 +493,7 @@ def tune_model_rlhf(
491
493
negative" or "Translate this sentence to Danish". Do not specify this
492
494
if your dataset already prepends the instruction to the inputs field.
493
495
tuning_job_location: GCP location where the tuning job should be run.
496
+ accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
494
497
tuning_evaluation_spec: Evaluation settings to use during tuning.
495
498
496
499
Returns:
@@ -527,6 +530,13 @@ def tune_model_rlhf(
527
530
model_id = self ._model_id ,
528
531
)
529
532
533
+ if accelerator_type :
534
+ if accelerator_type not in _ACCELERATOR_TYPES :
535
+ raise ValueError (
536
+ f"Unsupported accelerator type: { accelerator_type } ."
537
+ f" Supported types: { _ACCELERATOR_TYPES } "
538
+ )
539
+
530
540
tuning_parameters = _RlhfTuningParameters (
531
541
prompt_dataset = prompt_dataset_uri ,
532
542
preference_dataset = preference_dataset_uri ,
@@ -542,6 +552,7 @@ def tune_model_rlhf(
542
552
kl_coeff = kl_coeff ,
543
553
instruction = default_context ,
544
554
eval_dataset = eval_dataset ,
555
+ accelerator_type = accelerator_type ,
545
556
tensorboard_resource_id = tensorboard_resource_id ,
546
557
)
547
558
@@ -574,7 +585,7 @@ def _tune_model_rlhf(
574
585
raise ValueError (
575
586
_get_invalid_tuning_location_msg (
576
587
requested_location = tuning_parameters .location ,
577
- valid_locations = _SUPPORTED_RLHF_LOCATIONS ,
588
+ valid_locations = _TUNING_LOCATIONS ,
578
589
)
579
590
)
580
591
if self ._model_id not in _SUPPORTED_RLHF_MODELS :
@@ -3433,13 +3444,6 @@ class _PreviewCodeGenerationModel(CodeGenerationModel, _CountTokensCodeGeneratio
3433
3444
# Currently, deployment can only work in these locations
3434
3445
_TUNED_MODEL_LOCATIONS = _SUPPORTED_LOCATIONS
3435
3446
3436
- # TODO(b/318874365): Use _SUPPORTED_LOCATIONS defined above once DRZ for RLHF is
3437
- # implemented.
3438
- _SUPPORTED_RLHF_LOCATIONS = {
3439
- "us-central1" ,
3440
- "europe-west4" ,
3441
- }
3442
-
3443
3447
# All models supported by RLHF that can also be used for online and batch prediction:
3444
3448
_SUPPORTED_RLHF_MODELS = {
3445
3449
"text-bison@001" ,
0 commit comments