11
11
import tensorflow as tf
12
12
import time
13
13
14
- class PBTBenchmarkExample ():
14
+
15
+ class PBTBenchmarkExample :
15
16
"""Toy PBT problem for benchmarking adaptive learning rate.
16
17
The goal is to optimize this trainable's accuracy. The accuracy increases
17
18
fastest at the optimal lr, which is a function of the current accuracy.
@@ -36,24 +37,23 @@ def __init__(self, lr, log_dir: str, log_interval: int, checkpoint: str):
36
37
self ._log_interval = log_interval
37
38
self ._lr = lr
38
39
39
- self ._checkpoint_file = os .path .join (checkpoint , ' training.ckpt' )
40
+ self ._checkpoint_file = os .path .join (checkpoint , " training.ckpt" )
40
41
if os .path .exists (self ._checkpoint_file ):
41
- with open (self ._checkpoint_file , 'rb' ) as fin :
42
+ with open (self ._checkpoint_file , "rb" ) as fin :
42
43
checkpoint_data = pickle .load (fin )
43
- self ._accuracy = checkpoint_data [' accuracy' ]
44
- self ._step = checkpoint_data [' step' ]
44
+ self ._accuracy = checkpoint_data [" accuracy" ]
45
+ self ._step = checkpoint_data [" step" ]
45
46
else :
46
47
os .makedirs (checkpoint , exist_ok = True )
47
48
self ._step = 1
48
49
self ._accuracy = 0.0
49
-
50
50
51
51
def save_checkpoint (self ):
52
- with open (self ._checkpoint_file , 'wb' ) as fout :
53
- pickle .dump ({' step' : self ._step , ' accuracy' : self ._accuracy }, fout )
52
+ with open (self ._checkpoint_file , "wb" ) as fout :
53
+ pickle .dump ({" step" : self ._step , " accuracy" : self ._accuracy }, fout )
54
54
55
55
def step (self ):
56
- midpoint = 100 # lr starts decreasing after acc > midpoint
56
+ midpoint = 100 # lr starts decreasing after acc > midpoint
57
57
q_tolerance = 3 # penalize exceeding lr by more than this multiple
58
58
noise_level = 2 # add gaussian noise to the acc increase
59
59
# triangle wave:
@@ -80,32 +80,53 @@ def step(self):
80
80
if not self ._writer :
81
81
self ._writer = tf .summary .create_file_writer (self ._log_dir )
82
82
with self ._writer .as_default ():
83
- tf .summary .scalar ("Validation-accuracy" , self ._accuracy , step = self ._step )
83
+ tf .summary .scalar (
84
+ "Validation-accuracy" , self ._accuracy , step = self ._step
85
+ )
84
86
tf .summary .scalar ("lr" , self ._lr , step = self ._step )
85
87
self ._writer .flush ()
86
88
87
89
self ._step += 1
88
90
89
91
def __repr__ (self ):
90
- return "epoch {}:\n lr={:0.4f}\n Validation-accuracy={:0.4f}" .format (self ._step , self ._lr , self ._accuracy )
92
+ return "epoch {}:\n lr={:0.4f}\n Validation-accuracy={:0.4f}" .format (
93
+ self ._step , self ._lr , self ._accuracy
94
+ )
91
95
92
96
93
97
if __name__ == "__main__" :
94
98
# Parse CLI arguments
95
- parser = argparse .ArgumentParser (description = 'PBT Basic Test' )
96
- parser .add_argument ('--lr' , type = float , default = 0.0001 ,
97
- help = 'learning rate (default: 0.0001)' )
98
- parser .add_argument ('--epochs' , type = int , default = 20 ,
99
- help = 'number of epochs to train (default: 20)' )
100
- parser .add_argument ('--log-interval' , type = int , default = 10 , metavar = 'N' ,
101
- help = 'how many batches to wait before logging training status (default: 1)' )
102
- parser .add_argument ('--log-path' , type = str , default = "/var/log/katib/tfevent/" ,
103
- help = 'tfevent output path (default: /var/log/katib/tfevent/)' )
104
- parser .add_argument ('--checkpoint' , type = str , default = "/var/log/katib/checkpoints/" ,
105
- help = 'checkpoint directory (resume and save)' )
99
+ parser = argparse .ArgumentParser (description = "PBT Basic Test" )
100
+ parser .add_argument (
101
+ "--lr" , type = float , default = 0.0001 , help = "learning rate (default: 0.0001)"
102
+ )
103
+ parser .add_argument (
104
+ "--epochs" , type = int , default = 20 , help = "number of epochs to train (default: 20)"
105
+ )
106
+ parser .add_argument (
107
+ "--log-interval" ,
108
+ type = int ,
109
+ default = 10 ,
110
+ metavar = "N" ,
111
+ help = "how many batches to wait before logging training status (default: 1)" ,
112
+ )
113
+ parser .add_argument (
114
+ "--log-path" ,
115
+ type = str ,
116
+ default = "/var/log/katib/tfevent/" ,
117
+ help = "tfevent output path (default: /var/log/katib/tfevent/)" ,
118
+ )
119
+ parser .add_argument (
120
+ "--checkpoint" ,
121
+ type = str ,
122
+ default = "/var/log/katib/checkpoints/" ,
123
+ help = "checkpoint directory (resume and save)" ,
124
+ )
106
125
opt = parser .parse_args ()
107
126
108
- benchmark = PBTBenchmarkExample (opt .lr , opt .log_path , opt .log_interval , opt .checkpoint )
127
+ benchmark = PBTBenchmarkExample (
128
+ opt .lr , opt .log_path , opt .log_interval , opt .checkpoint
129
+ )
109
130
for i in range (opt .epochs ):
110
131
benchmark .step ()
111
132
time .sleep (0.2 )
0 commit comments