
Commit 6963074

feat: PbtTemplate and associated test image
1 parent bfa52bb commit 6963074

5 files changed: +143 −0 lines changed
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
FROM tensorflow/tensorflow:2.8.0

ADD test/functional/v1beta1/suggestion/pbt /opt/pbt
WORKDIR /opt/pbt

RUN chgrp -R 0 /opt/pbt \
  && chmod -R g+rwX /opt/pbt

ENTRYPOINT ["python3", "/opt/pbt/pbt_test.py"]
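For a quick local smoke test outside the image, the entrypoint can be exercised directly. This is a hypothetical invocation, not part of this commit: it assumes pbt_test.py sits in the working directory and that the image's dependencies (TensorFlow 2.8, numpy) are installed locally.

# local_smoke_test.py -- hypothetical helper mirroring the image's ENTRYPOINT.
import subprocess

subprocess.run(
    [
        "python3", "pbt_test.py",
        "--epochs=5",
        "--lr=0.001",
        "--log-path=/tmp/katib-tfevent/",
        "--checkpoint=/tmp/katib-checkpoints/",
    ],
    check=True,  # raise if the test script exits non-zero
)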
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
#!/usr/bin/env python

# Implementation based on:
# https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_example.py

import argparse
import numpy as np
import os
import pickle
import random
import tensorflow as tf
import time


class PBTBenchmarkExample():
    """Toy PBT problem for benchmarking adaptive learning rate.

    The goal is to optimize this trainable's accuracy. The accuracy increases
    fastest at the optimal lr, which is a function of the current accuracy.
    The optimal lr schedule for this problem is the triangle wave as follows.
    Note that many lr schedules for real models also follow this shape:

     best lr
      ^
      |    /\
      |   /  \
      |  /    \
      | /      \
      ------------> accuracy

    In this problem, using PBT with a population of 2-4 is sufficient to
    roughly approximate this lr schedule. Higher population sizes will yield
    faster convergence. Training will not converge without PBT.
    """

    def __init__(self, lr, log_dir: str, log_interval: int, checkpoint: str):
        # Allow lazy creation of tfevent file
        self._log_dir = log_dir
        self._writer = None
        self._log_interval = log_interval
        self._lr = lr

        self._checkpoint_file = os.path.join(checkpoint, 'training.ckpt')
        if os.path.exists(self._checkpoint_file):
            with open(self._checkpoint_file, 'rb') as fin:
                checkpoint_data = pickle.load(fin)
            self._accuracy = checkpoint_data['accuracy']
            self._step = checkpoint_data['step']
        else:
            os.makedirs(checkpoint, exist_ok=True)
            self._step = 1
            self._accuracy = 0.0

    def save_checkpoint(self):
        with open(self._checkpoint_file, 'wb') as fout:
            pickle.dump({'step': self._step, 'accuracy': self._accuracy}, fout)

    def step(self):
        midpoint = 100  # lr starts decreasing after acc > midpoint
        q_tolerance = 3  # penalize exceeding lr by more than this multiple
        noise_level = 2  # add gaussian noise to the acc increase
        # triangle wave:
        #  - start at 0.001 @ t=0,
        #  - peak at 0.01 @ t=midpoint,
        #  - end at 0.001 @ t=midpoint * 2,
        if self._accuracy < midpoint:
            optimal_lr = 0.01 * self._accuracy / midpoint
        else:
            optimal_lr = 0.01 - 0.01 * (self._accuracy - midpoint) / midpoint
        optimal_lr = min(0.01, max(0.001, optimal_lr))

        # compute accuracy increase
        q_err = max(self._lr, optimal_lr) / min(self._lr, optimal_lr)
        if q_err < q_tolerance:
            self._accuracy += (1.0 / q_err) * random.random()
        elif self._lr > optimal_lr:
            self._accuracy -= (q_err - q_tolerance) * random.random()
        self._accuracy += noise_level * np.random.normal()
        self._accuracy = max(0, self._accuracy)

        if self._step == 1 or self._step % self._log_interval == 0:
            self.save_checkpoint()
            if not self._writer:
                self._writer = tf.summary.create_file_writer(self._log_dir)
            with self._writer.as_default():
                tf.summary.scalar("Validation-accuracy", self._accuracy, step=self._step)
                tf.summary.scalar("lr", self._lr, step=self._step)
            self._writer.flush()

        self._step += 1

    def __repr__(self):
        return "epoch {}:\nlr={:0.4f}\nValidation-accuracy={:0.4f}".format(
            self._step, self._lr, self._accuracy)


if __name__ == "__main__":
    # Parse CLI arguments
    parser = argparse.ArgumentParser(description='PBT Basic Test')
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--epochs', type=int, default=20,
                        help='number of epochs to train (default: 20)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status (default: 10)')
    parser.add_argument('--log-path', type=str, default="/var/log/katib/tfevent/",
                        help='tfevent output path (default: /var/log/katib/tfevent/)')
    parser.add_argument('--checkpoint', type=str, default="/var/log/katib/checkpoints/",
                        help='checkpoint directory (resume and save)')
    opt = parser.parse_args()

    benchmark = PBTBenchmarkExample(opt.lr, opt.log_path, opt.log_interval, opt.checkpoint)
    for i in range(opt.epochs):
        benchmark.step()
        time.sleep(0.2)
    print(benchmark)
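To make the docstring's claim concrete (a population of 2-4 is sufficient), here is a minimal sketch of an exploit/explore loop driving several PBTBenchmarkExample instances. It assumes pbt_test.py is importable and TensorFlow is installed; the truncation-selection rule is illustrative only, not Katib's PBT suggestion algorithm, and it reaches into the example's private attributes purely for brevity.

# pbt_population_sketch.py -- hypothetical, not part of this commit.
import random
import tempfile

from pbt_test import PBTBenchmarkExample

# Small population, as the docstring suggests (2-4 members is enough).
population = []
for _ in range(4):
    member = PBTBenchmarkExample(
        lr=random.uniform(0.0001, 0.02),
        log_dir=tempfile.mkdtemp(prefix="pbt-log-"),
        log_interval=10,
        checkpoint=tempfile.mkdtemp(prefix="pbt-ckpt-"),
    )
    population.append(member)

for generation in range(10):
    for member in population:
        for _ in range(20):
            member.step()
    # Exploit/explore via truncation selection: the worst member copies the
    # best member's state and perturbs its learning rate. Private attributes
    # are accessed only to keep the sketch short.
    population.sort(key=lambda m: m._accuracy)
    worst, best = population[0], population[-1]
    worst._accuracy = best._accuracy
    worst._step = best._step
    worst._lr = best._lr * random.choice([0.8, 1.25])

print(max(m._accuracy for m in population))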
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
numpy==1.22.2

manifests/v1beta1/components/controller/trial-templates.yaml

Lines changed: 20 additions & 0 deletions
@@ -76,3 +76,23 @@ data:
                 - "--epochs=1"
                 - "--lr=${trialParameters.learningRate}"
                 - "--momentum=${trialParameters.momentum}"
+  pbtTemplate.yaml: |-
+    apiVersion: batch/v1
+    kind: Job
+    spec:
+      template:
+        metadata:
+          annotations:
+            sidecar.istio.io/inject: "false"
+        spec:
+          containers:
+            - name: training-container
+              image: docker.io/kubeflowkatib/simple-pbt:latest
+              imagePullPolicy: Always
+              command:
+                - "python3"
+                - "/opt/pbt/pbt_test.py"
+                - "--epochs=20"
+                - "--lr=${trialParameters.learningRate}"
+                - "--checkpoint=/var/log/katib/checkpoints/"
+          restartPolicy: Never
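For reference, each ${trialParameters.*} placeholder in this template is substituted per trial before the Job is created. The snippet below is only a toy illustration of that rendering with a hypothetical suggestion value; the real substitution happens in Katib's trial controller.

# Illustration only: toy rendering of the ${trialParameters.*} placeholders
# from pbtTemplate.yaml. The suggestion value below is hypothetical.
template_command = [
    "python3",
    "/opt/pbt/pbt_test.py",
    "--epochs=20",
    "--lr=${trialParameters.learningRate}",
    "--checkpoint=/var/log/katib/checkpoints/",
]
trial_parameters = {"learningRate": "0.005"}  # hypothetical PBT suggestion

rendered = [
    arg.replace("${trialParameters.learningRate}", trial_parameters["learningRate"])
    for arg in template_command
]
print(rendered)
# ['python3', '/opt/pbt/pbt_test.py', '--epochs=20', '--lr=0.005',
#  '--checkpoint=/var/log/katib/checkpoints/']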

operators/katib-controller/src/charm.py

Lines changed: 1 addition & 0 deletions
@@ -195,6 +195,7 @@ def set_pod_spec(self, event):
                 for f, suffix in (
                     ("defaultTrialTemplate", ".yaml"),
                     ("enasCPUTemplate", ""),
+                    ("pbtTemplate", ""),
                     ("pytorchJobTemplate", ""),
                 )
             },
