13
13
# limitations under the License.
14
14
#
15
15
16
- from typing import Optional , Union
16
+ from typing import Literal , Optional , Union
17
17
18
18
from google .cloud .aiplatform_v1 .types import tuning_job as gca_tuning_job_types
19
19
@@ -29,6 +29,7 @@ def train(
29
29
tuned_model_display_name : Optional [str ] = None ,
30
30
epochs : Optional [int ] = None ,
31
31
learning_rate_multiplier : Optional [float ] = None ,
32
+ adapter_size : Optional [Literal [1 , 4 , 8 , 16 ]] = None ,
32
33
) -> "SupervisedTuningJob" :
33
34
"""Tunes a model using supervised training.
34
35
@@ -44,6 +45,7 @@ def train(
44
45
be up to 128 characters long and can consist of any UTF-8 characters.
45
46
epochs: Number of training epoches for this tuning job.
46
47
learning_rate_multiplier: Learning rate multiplier for tuning.
48
+ adapter_size: Adapter size for tuning.
47
49
48
50
Returns:
49
51
A `TuningJob` object.
@@ -54,6 +56,7 @@ def train(
54
56
hyper_parameters = gca_tuning_job_types .SupervisedHyperParameters (
55
57
epoch_count = epochs ,
56
58
learning_rate_multiplier = learning_rate_multiplier ,
59
+ adapter_size = adapter_size ,
57
60
),
58
61
)
59
62
0 commit comments