Skip to content

Commit 744cc38

Browse files
feat: added support for conditional parameters in hyperparameter tuning (#1544)
* feat: added support for conditional parameters in hyperparameter tuning * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * fixing unit tests * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * fixed all failing tests * addressed PR comments * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 3526b3e commit 744cc38

File tree

2 files changed

+151
-17
lines changed

2 files changed

+151
-17
lines changed

google/cloud/aiplatform/hyperparameter_tuning.py

+75-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright 2021 Google LLC
3+
# Copyright 2022 Google LLC
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -29,6 +29,10 @@
2929
"unspecified": gca_study_compat.StudySpec.ParameterSpec.ScaleType.SCALE_TYPE_UNSPECIFIED,
3030
}
3131

32+
_INT_VALUE_SPEC = "integer_value_spec"
33+
_DISCRETE_VALUE_SPEC = "discrete_value_spec"
34+
_CATEGORICAL_VALUE_SPEC = "categorical_value_spec"
35+
3236

3337
class _ParameterSpec(metaclass=abc.ABCMeta):
3438
"""Base class represents a single parameter to optimize."""
@@ -77,10 +81,30 @@ def _to_parameter_spec(
7781
self, parameter_id: str
7882
) -> gca_study_compat.StudySpec.ParameterSpec:
7983
"""Converts this parameter to ParameterSpec."""
80-
# TODO: Conditional parameters
84+
conditions = []
85+
if self.conditional_parameter_spec is not None:
86+
for (conditional_param_id, spec) in self.conditional_parameter_spec.items():
87+
condition = (
88+
gca_study_compat.StudySpec.ParameterSpec.ConditionalParameterSpec()
89+
)
90+
if self._parameter_spec_value_key == _INT_VALUE_SPEC:
91+
condition.parent_int_values = gca_study_compat.StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition(
92+
values=spec.parent_values
93+
)
94+
elif self._parameter_spec_value_key == _CATEGORICAL_VALUE_SPEC:
95+
condition.parent_categorical_values = gca_study_compat.StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition(
96+
values=spec.parent_values
97+
)
98+
elif self._parameter_spec_value_key == _DISCRETE_VALUE_SPEC:
99+
condition.parent_discrete_values = gca_study_compat.StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition(
100+
values=spec.parent_values
101+
)
102+
condition.parameter_spec = spec._to_parameter_spec(conditional_param_id)
103+
conditions.append(condition)
81104
parameter_spec = gca_study_compat.StudySpec.ParameterSpec(
82105
parameter_id=parameter_id,
83106
scale_type=_SCALE_TYPE_MAP.get(getattr(self, "scale", "unspecified")),
107+
conditional_parameter_specs=conditions,
84108
)
85109

86110
setattr(
@@ -105,6 +129,8 @@ def __init__(
105129
min: float,
106130
max: float,
107131
scale: str,
132+
conditional_parameter_spec: Optional[Dict[str, "_ParameterSpec"]] = None,
133+
parent_values: Optional[Sequence[Union[int, float, str]]] = None,
108134
):
109135
"""
110136
Value specification for a parameter in ``DOUBLE`` type.
@@ -120,9 +146,16 @@ def __init__(
120146
Required. The type of scaling that should be applied to this parameter.
121147
122148
Accepts: 'linear', 'log', 'reverse_log'
149+
conditional_parameter_spec (Dict[str, _ParameterSpec]):
150+
Optional. The conditional parameters associated with the object. The dictionary key
151+
is the ID of the conditional parameter and the dictionary value is one of
152+
`IntegerParameterSpec`, `CategoricalParameterSpec`, or `DiscreteParameterSpec`
153+
parent_values (Sequence[Union[int, float, str]]):
154+
Optional. This argument is only needed when the object is a conditional parameter
155+
and specifies the parent parameter's values for which the condition applies.
123156
"""
124157

125-
super().__init__()
158+
super().__init__(conditional_parameter_spec, parent_values)
126159

127160
self.min = min
128161
self.max = max
@@ -142,6 +175,8 @@ def __init__(
142175
min: int,
143176
max: int,
144177
scale: str,
178+
conditional_parameter_spec: Optional[Dict[str, "_ParameterSpec"]] = None,
179+
parent_values: Optional[Sequence[Union[int, float, str]]] = None,
145180
):
146181
"""
147182
Value specification for a parameter in ``INTEGER`` type.
@@ -157,9 +192,18 @@ def __init__(
157192
Required. The type of scaling that should be applied to this parameter.
158193
159194
Accepts: 'linear', 'log', 'reverse_log'
195+
conditional_parameter_spec (Dict[str, _ParameterSpec]):
196+
Optional. The conditional parameters associated with the object. The dictionary key
197+
is the ID of the conditional parameter and the dictionary value is one of
198+
`IntegerParameterSpec`, `CategoricalParameterSpec`, or `DiscreteParameterSpec`
199+
parent_values (Sequence[int]):
200+
Optional. This argument is only needed when the object is a conditional parameter
201+
and specifies the parent parameter's values for which the condition applies.
160202
"""
161-
162-
super().__init__()
203+
super().__init__(
204+
conditional_parameter_spec=conditional_parameter_spec,
205+
parent_values=parent_values,
206+
)
163207

164208
self.min = min
165209
self.max = max
@@ -177,15 +221,26 @@ class CategoricalParameterSpec(_ParameterSpec):
177221
def __init__(
178222
self,
179223
values: Sequence[str],
224+
conditional_parameter_spec: Optional[Dict[str, "_ParameterSpec"]] = None,
225+
parent_values: Optional[Sequence[Union[int, float, str]]] = None,
180226
):
181227
"""Value specification for a parameter in ``CATEGORICAL`` type.
182228
183229
Args:
184230
values (Sequence[str]):
185231
Required. The list of possible categories.
232+
conditional_parameter_spec (Dict[str, _ParameterSpec]):
233+
Optional. The conditional parameters associated with the object. The dictionary key
234+
is the ID of the conditional parameter and the dictionary value is one of
235+
`IntegerParameterSpec`, `CategoricalParameterSpec`, or `DiscreteParameterSpec`
236+
parent_values (Sequence[str]):
237+
Optional. This argument is only needed when the object is a conditional parameter
238+
and specifies the parent parameter's values for which the condition applies.
186239
"""
187-
188-
super().__init__()
240+
super().__init__(
241+
conditional_parameter_spec=conditional_parameter_spec,
242+
parent_values=parent_values,
243+
)
189244

190245
self.values = values
191246

@@ -202,6 +257,8 @@ def __init__(
202257
self,
203258
values: Sequence[float],
204259
scale: str,
260+
conditional_parameter_spec: Optional[Dict[str, "_ParameterSpec"]] = None,
261+
parent_values: Optional[Sequence[Union[int, float, str]]] = None,
205262
):
206263
"""Value specification for a parameter in ``DISCRETE`` type.
207264
@@ -216,9 +273,18 @@ def __init__(
216273
Required. The type of scaling that should be applied to this parameter.
217274
218275
Accepts: 'linear', 'log', 'reverse_log'
276+
conditional_parameter_spec (Dict[str, _ParameterSpec]):
277+
Optional. The conditional parameters associated with the object. The dictionary key
278+
is the ID of the conditional parameter and the dictionary value is one of
279+
`IntegerParameterSpec`, `CategoricalParameterSpec`, or `DiscreteParameterSpec`
280+
parent_values (Sequence[float]):
281+
Optional. This argument is only needed when the object is a conditional parameter
282+
and specifies the parent parameter's values for which the condition applies.
219283
"""
220-
221-
super().__init__()
284+
super().__init__(
285+
conditional_parameter_spec=conditional_parameter_spec,
286+
parent_values=parent_values,
287+
)
222288

223289
self.values = values
224290
self.scale = scale

tests/unit/aiplatform/test_hyperparameter_tuning_job.py

+76-8
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@
7676

7777
_TEST_LABELS = {"my_hp_key": "my_hp_value"}
7878

79+
_TEST_CONDITIONAL_PARAMETER_DECAY = hpt.DoubleParameterSpec(
80+
min=1e-07, max=1, scale="linear", parent_values=[32, 64]
81+
)
82+
_TEST_CONDITIONAL_PARAMETER_LR = hpt.DoubleParameterSpec(
83+
min=1e-07, max=1, scale="linear", parent_values=[4, 8, 16]
84+
)
85+
7986
_TEST_BASE_HYPERPARAMETER_TUNING_JOB_PROTO = gca_hyperparameter_tuning_job_compat.HyperparameterTuningJob(
8087
display_name=_TEST_DISPLAY_NAME,
8188
study_spec=gca_study_compat.StudySpec(
@@ -109,8 +116,34 @@
109116
parameter_id="batch_size",
110117
scale_type=gca_study_compat.StudySpec.ParameterSpec.ScaleType.UNIT_LINEAR_SCALE,
111118
discrete_value_spec=gca_study_compat.StudySpec.ParameterSpec.DiscreteValueSpec(
112-
values=[16, 32]
119+
values=[4, 8, 16, 32, 64]
113120
),
121+
conditional_parameter_specs=[
122+
gca_study_compat.StudySpec.ParameterSpec.ConditionalParameterSpec(
123+
parent_discrete_values=gca_study_compat.StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition(
124+
values=[32, 64]
125+
),
126+
parameter_spec=gca_study_compat.StudySpec.ParameterSpec(
127+
double_value_spec=gca_study_compat.StudySpec.ParameterSpec.DoubleValueSpec(
128+
min_value=1e-07, max_value=1
129+
),
130+
scale_type=gca_study_compat.StudySpec.ParameterSpec.ScaleType.UNIT_LINEAR_SCALE,
131+
parameter_id="decay",
132+
),
133+
),
134+
gca_study_compat.StudySpec.ParameterSpec.ConditionalParameterSpec(
135+
parent_discrete_values=gca_study_compat.StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition(
136+
values=[4, 8, 16]
137+
),
138+
parameter_spec=gca_study_compat.StudySpec.ParameterSpec(
139+
double_value_spec=gca_study_compat.StudySpec.ParameterSpec.DoubleValueSpec(
140+
min_value=1e-07, max_value=1
141+
),
142+
scale_type=gca_study_compat.StudySpec.ParameterSpec.ScaleType.UNIT_LINEAR_SCALE,
143+
parameter_id="learning_rate",
144+
),
145+
),
146+
],
114147
),
115148
],
116149
algorithm=gca_study_compat.StudySpec.Algorithm.RANDOM_SEARCH,
@@ -388,7 +421,12 @@ def test_create_hyperparameter_tuning_job(
388421
values=["relu", "sigmoid", "elu", "selu", "tanh"]
389422
),
390423
"batch_size": hpt.DiscreteParameterSpec(
391-
values=[16, 32], scale="linear"
424+
values=[4, 8, 16, 32, 64],
425+
scale="linear",
426+
conditional_parameter_spec={
427+
"decay": _TEST_CONDITIONAL_PARAMETER_DECAY,
428+
"learning_rate": _TEST_CONDITIONAL_PARAMETER_LR,
429+
},
392430
),
393431
},
394432
parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT,
@@ -454,7 +492,12 @@ def test_create_hyperparameter_tuning_job_with_timeout(
454492
values=["relu", "sigmoid", "elu", "selu", "tanh"]
455493
),
456494
"batch_size": hpt.DiscreteParameterSpec(
457-
values=[16, 32], scale="linear"
495+
values=[4, 8, 16, 32, 64],
496+
scale="linear",
497+
conditional_parameter_spec={
498+
"decay": _TEST_CONDITIONAL_PARAMETER_DECAY,
499+
"learning_rate": _TEST_CONDITIONAL_PARAMETER_LR,
500+
},
458501
),
459502
},
460503
parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT,
@@ -515,7 +558,12 @@ def test_run_hyperparameter_tuning_job_with_fail_raises(
515558
values=["relu", "sigmoid", "elu", "selu", "tanh"]
516559
),
517560
"batch_size": hpt.DiscreteParameterSpec(
518-
values=[16, 32], scale="linear"
561+
values=[4, 8, 16, 32, 64],
562+
scale="linear",
563+
conditional_parameter_spec={
564+
"decay": _TEST_CONDITIONAL_PARAMETER_DECAY,
565+
"learning_rate": _TEST_CONDITIONAL_PARAMETER_LR,
566+
},
519567
),
520568
},
521569
parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT,
@@ -574,7 +622,12 @@ def test_run_hyperparameter_tuning_job_with_fail_at_creation(self):
574622
values=["relu", "sigmoid", "elu", "selu", "tanh"]
575623
),
576624
"batch_size": hpt.DiscreteParameterSpec(
577-
values=[16, 32], scale="linear"
625+
values=[4, 8, 16, 32, 64],
626+
scale="linear",
627+
conditional_parameter_spec={
628+
"decay": _TEST_CONDITIONAL_PARAMETER_DECAY,
629+
"learning_rate": _TEST_CONDITIONAL_PARAMETER_LR,
630+
},
578631
),
579632
},
580633
parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT,
@@ -639,7 +692,12 @@ def test_hyperparameter_tuning_job_get_state_raises_without_run(self):
639692
values=["relu", "sigmoid", "elu", "selu", "tanh"]
640693
),
641694
"batch_size": hpt.DiscreteParameterSpec(
642-
values=[16, 32, 64], scale="linear"
695+
values=[4, 8, 16, 32, 64],
696+
scale="linear",
697+
conditional_parameter_spec={
698+
"decay": _TEST_CONDITIONAL_PARAMETER_DECAY,
699+
"learning_rate": _TEST_CONDITIONAL_PARAMETER_LR,
700+
},
643701
),
644702
},
645703
parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT,
@@ -697,7 +755,12 @@ def test_create_hyperparameter_tuning_job_with_tensorboard(
697755
values=["relu", "sigmoid", "elu", "selu", "tanh"]
698756
),
699757
"batch_size": hpt.DiscreteParameterSpec(
700-
values=[16, 32], scale="linear"
758+
values=[4, 8, 16, 32, 64],
759+
scale="linear",
760+
conditional_parameter_spec={
761+
"decay": _TEST_CONDITIONAL_PARAMETER_DECAY,
762+
"learning_rate": _TEST_CONDITIONAL_PARAMETER_LR,
763+
},
701764
),
702765
},
703766
parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT,
@@ -769,7 +832,12 @@ def test_create_hyperparameter_tuning_job_with_enable_web_access(
769832
values=["relu", "sigmoid", "elu", "selu", "tanh"]
770833
),
771834
"batch_size": hpt.DiscreteParameterSpec(
772-
values=[16, 32], scale="linear"
835+
values=[4, 8, 16, 32, 64],
836+
scale="linear",
837+
conditional_parameter_spec={
838+
"decay": _TEST_CONDITIONAL_PARAMETER_DECAY,
839+
"learning_rate": _TEST_CONDITIONAL_PARAMETER_LR,
840+
},
773841
),
774842
},
775843
parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT,

0 commit comments

Comments
 (0)