Skip to content

Commit 653b759

Browse files
gericdongivanmkc
andauthored
feat: update the samples of hyperparameter tuning in the public doc (#1600)
* Update Hyperparameter tuning job samples * Update Hyperparameter tuning job samples 2 * feat: fix linting problems in the hyperparamter tuning job samples * feat: resolve conflict for the hyperparameter tuning samples * feat: fix lint problems Co-authored-by: Ivan Cheung <[email protected]>
1 parent 8539327 commit 653b759

10 files changed

+433
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_sdk_cancel_hyperparameter_tuning_job_sample]
16+
from google.cloud import aiplatform
17+
18+
19+
def cancel_hyperparameter_tuning_job_sample(
20+
project: str,
21+
hyperparameter_tuning_job_id: str,
22+
location: str = "us-central1",
23+
):
24+
25+
aiplatform.init(project=project, location=location)
26+
27+
hpt_job = aiplatform.HyperparameterTuningJob.get(
28+
resource_name=hyperparameter_tuning_job_id,
29+
)
30+
31+
hpt_job.cancel()
32+
33+
34+
# [END aiplatform_sdk_cancel_hyperparameter_tuning_job_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import cancel_hyperparameter_tuning_job_sample
16+
import test_constants as constants
17+
18+
19+
def test_cancel_hyperparameter_tuning_job_sample(
20+
mock_sdk_init,
21+
mock_hyperparameter_tuning_job_get,
22+
mock_hyperparameter_tuning_job_cancel,
23+
):
24+
25+
cancel_hyperparameter_tuning_job_sample.cancel_hyperparameter_tuning_job_sample(
26+
project=constants.PROJECT,
27+
location=constants.LOCATION,
28+
hyperparameter_tuning_job_id=constants.HYPERPARAMETER_TUNING_JOB_ID,
29+
)
30+
31+
mock_sdk_init.assert_called_once_with(
32+
project=constants.PROJECT,
33+
location=constants.LOCATION,
34+
)
35+
36+
mock_hyperparameter_tuning_job_get.assert_called_once_with(
37+
resource_name=constants.HYPERPARAMETER_TUNING_JOB_ID,
38+
)
39+
40+
mock_hyperparameter_tuning_job_cancel.assert_called_once_with()

samples/model-builder/conftest.py

+55
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,19 @@ def mock_run_custom_package_training_job(mock_custom_package_training_job):
327327
yield mock
328328

329329

330+
@pytest.fixture
331+
def mock_custom_job():
332+
mock = MagicMock(aiplatform.CustomJob)
333+
yield mock
334+
335+
336+
@pytest.fixture
337+
def mock_get_custom_job(mock_custom_job):
338+
with patch.object(aiplatform, "CustomJob") as mock:
339+
mock.return_value = mock_custom_job
340+
yield mock
341+
342+
330343
"""
331344
----------------------------------------------------------------------------
332345
Model Fixtures
@@ -419,6 +432,48 @@ def mock_endpoint_explain(mock_endpoint):
419432
mock_get_endpoint.return_value = mock_endpoint
420433
yield mock_endpoint_explain
421434

435+
# ----------------------------------------------------------------------------
436+
# Hyperparameter Tuning Job Fixtures
437+
# ----------------------------------------------------------------------------
438+
439+
440+
@pytest.fixture
441+
def mock_hyperparameter_tuning_job():
442+
mock = MagicMock(aiplatform.HyperparameterTuningJob)
443+
yield mock
444+
445+
446+
@pytest.fixture
447+
def mock_get_hyperparameter_tuning_job(mock_hyperparameter_tuning_job):
448+
with patch.object(aiplatform, "HyperparameterTuningJob") as mock:
449+
mock.return_value = mock_hyperparameter_tuning_job
450+
yield mock
451+
452+
453+
@pytest.fixture
454+
def mock_run_hyperparameter_tuning_job(mock_hyperparameter_tuning_job):
455+
with patch.object(mock_hyperparameter_tuning_job, "run") as mock:
456+
yield mock
457+
458+
459+
@pytest.fixture
460+
def mock_hyperparameter_tuning_job_get(mock_hyperparameter_tuning_job):
461+
with patch.object(aiplatform.HyperparameterTuningJob, "get") as mock_hyperparameter_tuning_job_get:
462+
mock_hyperparameter_tuning_job_get.return_value = mock_hyperparameter_tuning_job
463+
yield mock_hyperparameter_tuning_job_get
464+
465+
466+
@pytest.fixture
467+
def mock_hyperparameter_tuning_job_cancel(mock_hyperparameter_tuning_job):
468+
with patch.object(mock_hyperparameter_tuning_job, "cancel") as mock:
469+
yield mock
470+
471+
472+
@pytest.fixture
473+
def mock_hyperparameter_tuning_job_delete(mock_hyperparameter_tuning_job):
474+
with patch.object(mock_hyperparameter_tuning_job, "delete") as mock:
475+
yield mock
476+
422477

423478
"""
424479
----------------------------------------------------------------------------
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_sdk_create_hyperparameter_tuning_job_sample]
16+
from google.cloud import aiplatform
17+
18+
from google.cloud.aiplatform import hyperparameter_tuning as hpt
19+
20+
21+
def create_hyperparameter_tuning_job_sample(
22+
project: str,
23+
location: str,
24+
staging_bucket: str,
25+
display_name: str,
26+
container_uri: str,
27+
):
28+
aiplatform.init(project=project, location=location, staging_bucket=staging_bucket)
29+
30+
worker_pool_specs = [
31+
{
32+
"machine_spec": {
33+
"machine_type": "n1-standard-4",
34+
"accelerator_type": "NVIDIA_TESLA_K80",
35+
"accelerator_count": 1,
36+
},
37+
"replica_count": 1,
38+
"container_spec": {
39+
"image_uri": container_uri,
40+
"command": [],
41+
"args": [],
42+
},
43+
}
44+
]
45+
46+
custom_job = aiplatform.CustomJob(
47+
display_name='custom_job',
48+
worker_pool_specs=worker_pool_specs,
49+
)
50+
51+
hpt_job = aiplatform.HyperparameterTuningJob(
52+
display_name=display_name,
53+
custom_job=custom_job,
54+
metric_spec={
55+
'loss': 'minimize',
56+
},
57+
parameter_spec={
58+
'lr': hpt.DoubleParameterSpec(min=0.001, max=0.1, scale='log'),
59+
'units': hpt.IntegerParameterSpec(min=4, max=128, scale='linear'),
60+
'activation': hpt.CategoricalParameterSpec(values=['relu', 'selu']),
61+
'batch_size': hpt.DiscreteParameterSpec(values=[128, 256], scale='linear')
62+
},
63+
max_trial_count=128,
64+
parallel_trial_count=8,
65+
labels={'my_key': 'my_value'},
66+
)
67+
68+
hpt_job.run()
69+
70+
print(hpt_job.resource_name)
71+
return hpt_job
72+
73+
74+
# [END aiplatform_sdk_create_hyperparameter_tuning_job_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest.mock import ANY
16+
17+
import create_hyperparameter_tuning_job_sample
18+
19+
import test_constants as constants
20+
21+
22+
def test_create_hyperparameter_tuning_job_sample(
23+
mock_sdk_init,
24+
mock_custom_job,
25+
mock_get_custom_job,
26+
mock_get_hyperparameter_tuning_job,
27+
mock_run_hyperparameter_tuning_job,
28+
):
29+
30+
create_hyperparameter_tuning_job_sample.create_hyperparameter_tuning_job_sample(
31+
project=constants.PROJECT,
32+
location=constants.LOCATION,
33+
staging_bucket=constants.STAGING_BUCKET,
34+
display_name=constants.HYPERPARAMETER_TUNING_JOB_DISPLAY_NAME,
35+
container_uri=constants.CONTAINER_URI,
36+
)
37+
38+
mock_sdk_init.assert_called_once_with(
39+
project=constants.PROJECT,
40+
location=constants.LOCATION,
41+
staging_bucket=constants.STAGING_BUCKET,
42+
)
43+
44+
mock_get_custom_job.assert_called_once_with(
45+
display_name=constants.CUSTOM_JOB_DISPLAY_NAME,
46+
worker_pool_specs=constants.CUSTOM_JOB_WORKER_POOL_SPECS,
47+
)
48+
49+
mock_get_hyperparameter_tuning_job.assert_called_once_with(
50+
display_name=constants.HYPERPARAMETER_TUNING_JOB_DISPLAY_NAME,
51+
custom_job=mock_custom_job,
52+
metric_spec=constants.HYPERPARAMETER_TUNING_JOB_METRIC_SPEC,
53+
parameter_spec=ANY,
54+
max_trial_count=constants.HYPERPARAMETER_TUNING_JOB_MAX_TRIAL_COUNT,
55+
parallel_trial_count=constants.HYPERPARAMETER_TUNING_JOB_PARALLEL_TRIAL_COUNT,
56+
labels=constants.HYPERPARAMETER_TUNING_JOB_LABELS,
57+
)
58+
59+
mock_run_hyperparameter_tuning_job.assert_called_once()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_sdk_delete_hyperparameter_tuning_job_sample]
16+
from google.cloud import aiplatform
17+
18+
19+
def delete_hyperparameter_tuning_job_sample(
20+
project: str,
21+
hyperparameter_tuning_job_id: str,
22+
location: str = "us-central1",
23+
):
24+
25+
aiplatform.init(project=project, location=location)
26+
27+
hpt_job = aiplatform.HyperparameterTuningJob.get(
28+
resource_name=hyperparameter_tuning_job_id,
29+
)
30+
31+
hpt_job.delete()
32+
33+
34+
# [END aiplatform_sdk_delete_hyperparameter_tuning_job_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import delete_hyperparameter_tuning_job_sample
16+
import test_constants as constants
17+
18+
19+
def test_delete_hyperparameter_tuning_job_sample(
20+
mock_sdk_init,
21+
mock_hyperparameter_tuning_job_get,
22+
mock_hyperparameter_tuning_job_delete,
23+
):
24+
25+
delete_hyperparameter_tuning_job_sample.delete_hyperparameter_tuning_job_sample(
26+
project=constants.PROJECT,
27+
location=constants.LOCATION,
28+
hyperparameter_tuning_job_id=constants.HYPERPARAMETER_TUNING_JOB_ID,
29+
)
30+
31+
mock_sdk_init.assert_called_once_with(
32+
project=constants.PROJECT,
33+
location=constants.LOCATION,
34+
)
35+
36+
mock_hyperparameter_tuning_job_get.assert_called_once_with(
37+
resource_name=constants.HYPERPARAMETER_TUNING_JOB_ID,
38+
)
39+
40+
mock_hyperparameter_tuning_job_delete.assert_called_once_with()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_sdk_get_hyperparameter_tuning_job_sample]
16+
from google.cloud import aiplatform
17+
18+
19+
def get_hyperparameter_tuning_job_sample(
20+
project: str,
21+
hyperparameter_tuning_job_id: str,
22+
location: str = "us-central1",
23+
):
24+
25+
aiplatform.init(project=project, location=location)
26+
27+
hpt_job = aiplatform.HyperparameterTuningJob.get(
28+
resource_name=hyperparameter_tuning_job_id,
29+
)
30+
31+
return hpt_job
32+
33+
34+
# [END aiplatform_sdk_get_hyperparameter_tuning_job_sample]

0 commit comments

Comments
 (0)