Commit 04ac975

Population based training (#1833)
* docs: update new algorithm service details
* feat: trial augmentation strategy
* feat: pbt suggestion service
* feat: PbtTemplate and associated test image
* feat: introduce annotation field to trial specifications
* feat: trial assignment changes to support annotations from suggestion
  - Add new Annotation types to suggestion_types.go
  - Add Annotation object and update Trial parser in trial.py
* feat: update pbt suggestion to use new Annotation api
  - Suggestion uses exact match to track spawned trials
  - Trials that get transmitted, but not created (or added to experiment) are added back to the respawn pool (population_size consistency)
* chore: gofmt and black run across PBT changes
* feedback: remove tf summary export, change default print unit, reduce range to be percentage compatible.
* feedback: move PBT template to example.
* feedback: changes to inject_webhook and utils.
  - Rename mutateVolume to mutateMetricsCollectorVolume
  - Add addContainerVolumeMount
  - Add getPrimaryContainerIndex
* feedback: change suggestion mutation mount variable name and add to consts
* feedback: Add trial_names to GetSuggestionsReply and change suggestion path to <experiment>/<trial>
* feedback: removed unnecessary checks and moved to async pbt implementation
* feedback: update trial name override location and change annotations override to labels.
* feedback: add pbt to github workflow
* feedback: move labels to ParameterAssignments in GetSuggestionsReply and cleanup pbt.yaml.
* feedback: remove operator changes
* feedback: GHA updates
* feedback: new formatting changes
* feedback: add suggestion-pbt to gh-actions build-load.sh.
* fix: missing pbt->simple-pbt name changes, add simple-pbt to update-images.sh update yaml function (causing failing gha).
* feedback: add pointer to website from main readme for pbt
1 parent f7261de commit 04ac975

File tree

45 files changed, with 1526 additions and 232 deletions

.github/workflows/publish-algorithm-images.yaml

Lines changed: 2 additions & 0 deletions
@@ -41,6 +41,8 @@ jobs:
             dockerfile: cmd/suggestion/goptuna/v1beta1/Dockerfile
           - component-name: suggestion-optuna
             dockerfile: cmd/suggestion/optuna/v1beta1/Dockerfile
+          - component-name: suggestion-pbt
+            dockerfile: cmd/suggestion/pbt/v1beta1/Dockerfile
           - component-name: suggestion-enas
             dockerfile: cmd/suggestion/nas/enas/v1beta1/Dockerfile
           - component-name: suggestion-darts

.github/workflows/publish-trial-images.yaml

Lines changed: 2 additions & 0 deletions
@@ -43,3 +43,5 @@ jobs:
             dockerfile: examples/v1beta1/trial-images/darts-cnn-cifar10/Dockerfile.cpu
           - trial-name: darts-cnn-cifar10-gpu
             dockerfile: examples/v1beta1/trial-images/darts-cnn-cifar10/Dockerfile.gpu
+          - trial-name: simple-pbt
+            dockerfile: examples/v1beta1/trial-images/simple-pbt/Dockerfile
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+name: E2E Test with simple-pbt
+on:
+  - pull_request
+
+env:
+  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+
+jobs:
+  e2e:
+    runs-on: ubuntu-20.04
+    timeout-minutes: 120
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v2
+
+      - name: Setup Test Env
+        uses: ./.github/workflows/template-setup-e2e-test
+        with:
+          kubernetes-version: ${{ matrix.kubernetes-version }}
+
+      - name: Run e2e test with ${{ matrix.experiments }} experiments
+        uses: ./.github/workflows/template-e2e-test
+        with:
+          experiments: ${{ matrix.experiments }}
+          # Comma Delimited
+          trial-images: simple-pbt
+
+    strategy:
+      fail-fast: false
+      matrix:
+        # Detail: https://hub.docker.com/r/kindest/node
+        # TODO (tenzen-y): We need to consider running tests on more kubernetes versions.
+        # kubernetes-version: ["v1.20.15", "v1.21.12", "v1.22.9", "v1.23.6", "v1.24.1"]
+        kubernetes-version: ["v1.21.12", "v1.22.9", "v1.23.6"]
+        # Comma Delimited
+        experiments: ["simple-pbt"]

Makefile

Lines changed: 1 addition & 0 deletions
@@ -128,6 +128,7 @@ prepare-pytest:
 	pip install -r cmd/suggestion/hyperband/v1beta1/requirements.txt
 	pip install -r cmd/suggestion/nas/enas/v1beta1/requirements.txt
 	pip install -r cmd/suggestion/nas/darts/v1beta1/requirements.txt
+	pip install -r cmd/suggestion/pbt/v1beta1/requirements.txt
 	pip install -r cmd/earlystopping/medianstop/v1beta1/requirements.txt
 	pip install -r cmd/metricscollector/v1beta1/tfevent-metricscollector/requirements.txt
 

README.md

Lines changed: 9 additions & 0 deletions
@@ -125,6 +125,15 @@ custom algorithm.
 <td>
 </td>
 </tr>
+<tr align="center">
+<td>
+<a href="https://www.kubeflow.org/docs/components/katib/experiment/#pbt">Population Based Training</a>
+</td>
+<td>
+</td>
+<td>
+</td>
+</tr>
 </tbody>
 </table>
 

cmd/suggestion/pbt/v1beta1/Dockerfile

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+FROM python:3.9-slim
+
+ENV TARGET_DIR /opt/katib
+ENV SUGGESTION_DIR cmd/suggestion/pbt/v1beta1
+ENV GRPC_HEALTH_PROBE_VERSION v0.4.6
+
+RUN if [ "$(uname -m)" = "ppc64le" ] || [ "$(uname -m)" = "aarch64" ]; then \
+    apt-get -y update && \
+    apt-get -y install gfortran libopenblas-dev liblapack-dev wget && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*; \
+    else \
+    apt-get -y update && \
+    apt-get -y install wget && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*; \
+    fi
+RUN if [ "$(uname -m)" = "ppc64le" ]; then \
+    wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-ppc64le; \
+    elif [ "$(uname -m)" = "aarch64" ]; then \
+    wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-arm64; \
+    else \
+    wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-amd64; \
+    fi && \
+    chmod +x /bin/grpc_health_probe
+
+ADD ./pkg/ ${TARGET_DIR}/pkg/
+ADD ./${SUGGESTION_DIR}/ ${TARGET_DIR}/${SUGGESTION_DIR}/
+WORKDIR ${TARGET_DIR}/${SUGGESTION_DIR}
+RUN pip install --no-cache-dir -r requirements.txt
+
+RUN chgrp -R 0 ${TARGET_DIR} \
+    && chmod -R g+rwX ${TARGET_DIR}
+
+ENV PYTHONPATH ${TARGET_DIR}:${TARGET_DIR}/pkg/apis/manager/v1beta1/python:${TARGET_DIR}/pkg/apis/manager/health/python
+
+ENTRYPOINT ["python", "main.py"]

cmd/suggestion/pbt/v1beta1/main.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+# Copyright 2022 The Kubeflow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import grpc
+import time
+from pkg.apis.manager.v1beta1.python import api_pb2_grpc
+from pkg.apis.manager.health.python import health_pb2_grpc
+from pkg.suggestion.v1beta1.pbt.service import PbtService
+from concurrent import futures
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+DEFAULT_PORT = "0.0.0.0:6789"
+
+
+def serve():
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    service = PbtService()
+    api_pb2_grpc.add_SuggestionServicer_to_server(service, server)
+    health_pb2_grpc.add_HealthServicer_to_server(service, server)
+
+    server.add_insecure_port(DEFAULT_PORT)
+    print("Listening...")
+    server.start()
+    try:
+        while True:
+            time.sleep(_ONE_DAY_IN_SECONDS)
+    except KeyboardInterrupt:
+        server.stop(0)
+
+
+if __name__ == "__main__":
+    serve()
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+grpcio==1.41.1
+protobuf==3.19.1
+googleapis-common-protos==1.53.0
+numpy==1.22.2

docs/new-algorithm-service.md

Lines changed: 2 additions & 4 deletions
@@ -18,7 +18,7 @@ from pkg.apis.manager.v1beta1.python import api_pb2_grpc
 from pkg.suggestion.v1beta1.internal.search_space import HyperParameter, HyperParameterSearchSpace
 from pkg.suggestion.v1beta1.internal.trial import Trial, Assignment
 from pkg.suggestion.v1beta1.hyperopt.base_service import BaseHyperoptService
-from pkg.suggestion.v1beta1.base_health_service import HealthServicer
+from pkg.suggestion.v1beta1.internal.base_health_service import HealthServicer
 
 
 # Inherit SuggestionServicer and implement GetSuggestions.
@@ -90,9 +90,7 @@ Then build the Docker image.
 
 ### Use the algorithm in Katib.
 
-Update the [Katib config](../manifests/v1beta1/components/controller/katib-config.yaml)
-and [Katib config patch](../manifests/v1beta1/installs/katib-standalone/katib-config-patch.yaml)
-with the new algorithm entity:
+Update the [Katib config](../manifests/v1beta1/components/controller/katib-config.yaml) and [operator](../operators/katib-controller/src/suggestion.json) with the new algorithm entity:
 
 ```diff
 suggestion: |-

examples/v1beta1/README.md

Lines changed: 4 additions & 0 deletions
@@ -44,6 +44,8 @@ Experiments for the following algorithms:
 
 - [HyperBand](./hp-tuning/hyperband.yaml)
 
+- [PBT](./hp-tuning/simple-pbt.yaml)
+
 ### Neural Architecture Search
 
 Check the [Neural Architecture Search](https://www.kubeflow.org/docs/components/katib/overview/#neural-architecture-search)
@@ -110,6 +112,8 @@ Check the following images for the Trial containers:
 
 - [DARTS PyTorch CNN CIFAR-10](./trial-images/darts-cnn-cifar10)
 
+- [PBT proof of concept](./trial-images/simple-pbt)
+
 ## Katib with Kubeflow Training Operator
 
 Katib has out of the box support for the [Kubeflow Training Operators](https://github.com/kubeflow/training-operator) to
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+apiVersion: kubeflow.org/v1beta1
+kind: Experiment
+metadata:
+  namespace: kubeflow
+  name: simple-pbt
+spec:
+  maxTrialCount: 2
+  parallelTrialCount: 2
+  maxFailedTrialCount: 3
+  resumePolicy: FromVolume
+  objective:
+    type: maximize
+    goal: 0.99
+    objectiveMetricName: Validation-accuracy
+  algorithm:
+    algorithmName: pbt
+    algorithmSettings:
+      - name: suggestion_trial_dir
+        value: /var/log/katib/checkpoints/
+      - name: n_population
+        value: '40'
+      - name: truncation_threshold
+        value: '0.2'
+  parameters:
+    - name: lr
+      parameterType: double
+      feasibleSpace:
+        min: '0.0001'
+        max: '0.02'
+        step: '0.0001'
+  trialTemplate:
+    primaryContainerName: training-container
+    trialParameters:
+      - name: learningRate
+        description: Learning rate for training the model
+        reference: lr
+    trialSpec:
+      apiVersion: batch/v1
+      kind: Job
+      spec:
+        template:
+          spec:
+            containers:
+              - name: training-container
+                image: docker.io/kubeflowkatib/simple-pbt:latest
+                command:
+                  - "python3"
+                  - "/opt/pbt/pbt_test.py"
+                  - "--epochs=20"
+                  - "--lr=${trialParameters.learningRate}"
+                  - "--checkpoint=/var/log/katib/checkpoints/"
+            restartPolicy: Never
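
The experiment above sets `n_population` to 40 and `truncation_threshold` to 0.2. In the PBT literature, truncation selection means that the worst-performing fraction of the population is replaced by copies of the best-performing fraction. How the Katib suggestion service applies this setting is not visible in this part of the diff, so the following is only a generic sketch under that assumption; `truncation_split` and the sample scores are hypothetical.

```python
def truncation_split(scores, truncation_threshold):
    """Split trial names into (bottom, top) groups by objective value.

    scores: dict mapping trial name -> objective value (higher is better).
    Returns the lowest-scoring and highest-scoring fractions of the
    population, each of size floor(len(scores) * truncation_threshold),
    with a minimum of one trial per group.
    """
    if not 0.0 < truncation_threshold <= 0.5:
        raise ValueError("truncation_threshold must be in (0, 0.5]")
    ranked = sorted(scores, key=scores.get)  # worst trial first
    cutoff = max(1, int(len(ranked) * truncation_threshold))
    return ranked[:cutoff], ranked[-cutoff:]


# Hypothetical population of 10 trials with made-up objective values.
scores = {f"trial-{i}": i / 10 for i in range(10)}
bottom, top = truncation_split(scores, 0.2)
print("candidates to replace:", bottom)
print("candidates to copy from:", top)
```

Under this reading, a population of 40 with a threshold of 0.2 would put 8 trials in each group.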
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+FROM python:3.9-slim
+
+ADD examples/v1beta1/trial-images/simple-pbt /opt/pbt
+WORKDIR /opt/pbt
+
+RUN python3 -m pip install -r requirements.txt
+
+RUN chgrp -R 0 /opt/pbt \
+    && chmod -R g+rwX /opt/pbt
+
+ENTRYPOINT ["python3", "/opt/pbt/pbt_test.py"]
Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+#!/usr/bin/env python
+
+# Implementation based on:
+# https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_example.py
+
+import argparse
+import numpy as np
+import os
+import pickle
+import random
+import time
+
+# Ensure job runs for at least this long (secs) to allow metrics collector to
+# read PID correctly before cleanup
+_METRICS_COLLECTOR_SPAWN_LATENCY = 7
+
+
+class PBTBenchmarkExample:
+    """Toy PBT problem for benchmarking adaptive learning rate.
+    The goal is to optimize this trainable's accuracy. The accuracy increases
+    fastest at the optimal lr, which is a function of the current accuracy.
+    The optimal lr schedule for this problem is the triangle wave as follows.
+    Note that many lr schedules for real models also follow this shape:
+     best lr
+      ^
+      |    /\
+      |   /  \
+      |  /    \
+      | /      \
+      ------------> accuracy
+    In this problem, using PBT with a population of 2-4 is sufficient to
+    roughly approximate this lr schedule. Higher population sizes will yield
+    faster convergence. Training will not converge without PBT.
+    """
+
+    def __init__(self, lr, checkpoint: str):
+        self._lr = lr
+
+        self._checkpoint_file = os.path.join(checkpoint, "training.ckpt")
+        if os.path.exists(self._checkpoint_file):
+            with open(self._checkpoint_file, "rb") as fin:
+                checkpoint_data = pickle.load(fin)
+            self._accuracy = checkpoint_data["accuracy"]
+            self._step = checkpoint_data["step"]
+        else:
+            os.makedirs(checkpoint, exist_ok=True)
+            self._step = 1
+            self._accuracy = 0.0
+
+    def save_checkpoint(self):
+        with open(self._checkpoint_file, "wb") as fout:
+            pickle.dump({"step": self._step, "accuracy": self._accuracy}, fout)
+
+    def step(self):
+        midpoint = 50  # lr starts decreasing after acc > midpoint
+        q_tolerance = 3  # penalize exceeding lr by more than this multiple
+        noise_level = 2  # add gaussian noise to the acc increase
+        # triangle wave:
+        # - start at 0.001 @ t=0,
+        # - peak at 0.01 @ t=midpoint,
+        # - end at 0.001 @ t=midpoint * 2,
+        if self._accuracy < midpoint:
+            optimal_lr = 0.01 * self._accuracy / midpoint
+        else:
+            optimal_lr = 0.01 - 0.01 * (self._accuracy - midpoint) / midpoint
+        optimal_lr = min(0.01, max(0.001, optimal_lr))
+
+        # compute accuracy increase
+        q_err = max(self._lr, optimal_lr) / (
+            min(self._lr, optimal_lr) + np.finfo(float).eps
+        )
+        if q_err < q_tolerance:
+            self._accuracy += (1.0 / q_err) * random.random()
+        elif self._lr > optimal_lr:
+            self._accuracy -= (q_err - q_tolerance) * random.random()
+        self._accuracy += noise_level * np.random.normal()
+        self._accuracy = max(0, min(100, self._accuracy))
+
+        self._step += 1
+
+    def __repr__(self):
+        return "epoch {}:\nlr={:0.4f}\nValidation-accuracy={:0.4f}".format(
+            self._step, self._lr, self._accuracy / 100
+        )
+
+
+if __name__ == "__main__":
+    # Parse CLI arguments
+    parser = argparse.ArgumentParser(description="PBT Basic Test")
+    parser.add_argument(
+        "--lr", type=float, default=0.0001, help="learning rate (default: 0.0001)"
+    )
+    parser.add_argument(
+        "--epochs", type=int, default=20, help="number of epochs to train (default: 20)"
+    )
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        default="/var/log/katib/checkpoints/",
+        help="checkpoint directory (resume and save)",
+    )
+    opt = parser.parse_args()
+
+    benchmark = PBTBenchmarkExample(opt.lr, opt.checkpoint)
+
+    start_time = time.time()
+    for i in range(opt.epochs):
+        benchmark.step()
+    exec_time_thresh = time.time() - start_time - _METRICS_COLLECTOR_SPAWN_LATENCY
+    if exec_time_thresh < 0:
+        time.sleep(abs(exec_time_thresh))
+    benchmark.save_checkpoint()
+
+    print(benchmark)
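
To make the triangle-wave schedule in `step()` easier to inspect, the deterministic part can be pulled out into standalone functions. This is a sketch for illustration, not code from the commit: the names `optimal_lr` and `q_error` and their keyword parameters are assumptions, while the constants (midpoint 50, bounds 0.001 and 0.01) are taken from the file above.

```python
import sys


def optimal_lr(accuracy, midpoint=50.0, lr_min=0.001, lr_max=0.01):
    """Triangle-wave target learning rate as a function of accuracy (0-100).

    Rises linearly up to lr_max at the midpoint, falls linearly afterwards,
    and is clamped to [lr_min, lr_max], mirroring the arithmetic in step().
    """
    if accuracy < midpoint:
        lr = lr_max * accuracy / midpoint
    else:
        lr = lr_max - lr_max * (accuracy - midpoint) / midpoint
    return min(lr_max, max(lr_min, lr))


def q_error(lr, target, eps=sys.float_info.epsilon):
    """Ratio of the larger to the smaller of lr and target (always >= ~1)."""
    return max(lr, target) / (min(lr, target) + eps)


# Print the target lr and the q-error of a fixed lr of 0.02 at a few points.
for acc in (0, 25, 50, 75, 100):
    target = optimal_lr(acc)
    print(acc, target, q_error(0.02, target))
```

In `step()`, accuracy only grows through the deterministic term while this ratio stays below `q_tolerance` (3), which is why a fixed learning rate tends to stall once the target moves away from it.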
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+numpy==1.22.2
