import edward2 as ed
import batchensemble_model  # local file import
import utils  # local file import
-import tensorflow as tf
+from edward2.google.rank1_pert.ensemble_keras import utils as be_utils
+import tensorflow.compat.v2 as tf


flags.DEFINE_integer('ensemble_size', 4, 'Size of ensemble.')
flags.DEFINE_integer('per_core_batch_size', 128, 'Batch size per TPU core/GPU.')
                   'fast weights lr multiplier.')
flags.DEFINE_string('data_dir', None, 'Path to training and testing data.')
flags.mark_flag_as_required('data_dir')
-flags.DEFINE_string('output_dir', '/tmp/imagenet',
-                    'The directory where the model weights and '
-                    'training/evaluation summaries are stored.')
+flags.DEFINE_string(
+    'output_dir', '/tmp/imagenet', 'The directory where the model weights and '
+    'training/evaluation summaries are stored.')
flags.DEFINE_integer('train_epochs', 135, 'Number of training epochs.')
-flags.DEFINE_integer('corruptions_interval', 135,
-                     'Number of epochs between evaluating on the corrupted '
-                     'test data. Use -1 to never evaluate.')
-flags.DEFINE_integer('checkpoint_interval', 27,
-                     'Number of epochs between saving checkpoints. Use -1 to '
-                     'never save checkpoints.')
+flags.DEFINE_integer(
+    'corruptions_interval', 135,
+    'Number of epochs between evaluating on the corrupted '
+    'test data. Use -1 to never evaluate.')
+flags.DEFINE_integer(
+    'checkpoint_interval', 27,
+    'Number of epochs between saving checkpoints. Use -1 to '
+    'never save checkpoints.')
flags.DEFINE_string('alexnet_errors_path', None,
                    'Path to AlexNet corruption errors file.')
flags.DEFINE_integer('num_bins', 15, 'Number of bins for ECE computation.')
flags.DEFINE_integer('num_cores', 32, 'Number of TPU cores or number of GPUs.')
flags.DEFINE_string('tpu', None,
                    'Name of the TPU. Only used if use_gpu is False.')
+flags.DEFINE_string('similarity_metric', 'cosine', 'Similarity metric in '
+                    '[cosine, dpp_logdet].')
+flags.DEFINE_string('dpp_kernel', 'linear', 'Kernel for DPP log determinant.')
+flags.DEFINE_bool('use_output_similarity', False,
+                  'If true, compute similarity on the ensemble outputs.')
+flags.DEFINE_enum('diversity_scheduler', 'LinearAnnealing',
+                  ['LinearAnnealing', 'ExponentialDecay', 'Fixed'],
+                  'Diversity coefficient scheduler.')
+flags.DEFINE_float('annealing_epochs', 200,
+                   'Number of epochs over which to linearly anneal.')
+flags.DEFINE_float('diversity_coeff', 0., 'Diversity loss coefficient.')
+flags.DEFINE_float('diversity_decay_epoch', 4, 'Diversity decay epoch.')
+flags.DEFINE_float('diversity_decay_rate', 0.97, 'Rate of exponential decay.')
+flags.DEFINE_integer('diversity_start_epoch', 100,
+                     'Diversity loss starting epoch.')
+
FLAGS = flags.FLAGS


# Number of images in ImageNet-1k train dataset.
APPROX_IMAGENET_TRAIN_IMAGES = 1281167
IMAGENET_VALIDATION_IMAGES = 50000
NUM_CLASSES = 1000

-_LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
+_LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]
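A note on the (multiplier, epoch to start) tuples above: utils.LearningRateSchedule consumes them together with steps_per_epoch, base_lr, and train_epochs (see the call in main below). As a minimal sketch of how such a stepwise schedule is commonly resolved, assuming linear warmup up to the first entry's start epoch (illustrative only, not the utils implementation):

    def lr_for_epoch(base_lr, epoch, schedule=_LR_SCHEDULE):
      # Illustrative sketch, not utils.LearningRateSchedule itself.
      warmup_epochs = schedule[0][1]  # assumption: first entry ends warmup
      if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs  # linear warmup
      multiplier = 1.0
      for mult, start_epoch in schedule:
        if epoch >= start_epoch:
          multiplier = mult  # the last entry reached so far wins
      return base_lr * multiplier

With base_lr scaled by batch_size / 256 as in main below, this yields the usual step decays at epochs 30, 60, and 80.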
@@ -147,22 +166,53 @@ def main(argv):
  logging.info('Model number of weights: %s', model.count_params())
  # Scale learning rate and decay epochs by vanilla settings.
  base_lr = FLAGS.base_learning_rate * batch_size / 256
-  learning_rate = utils.LearningRateSchedule(steps_per_epoch,
-                                             base_lr,
-                                             FLAGS.train_epochs,
-                                             _LR_SCHEDULE)
-  optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
-                                      momentum=0.9,
-                                      nesterov=True)
+  learning_rate = utils.LearningRateSchedule(steps_per_epoch, base_lr,
+                                             FLAGS.train_epochs, _LR_SCHEDULE)
+  optimizer = tf.keras.optimizers.SGD(
+      learning_rate=learning_rate, momentum=0.9, nesterov=True)
+
+  if FLAGS.diversity_scheduler == 'ExponentialDecay':
+    diversity_schedule = be_utils.ExponentialDecay(
+        initial_coeff=FLAGS.diversity_coeff,
+        start_epoch=FLAGS.diversity_start_epoch,
+        decay_epoch=FLAGS.diversity_decay_epoch,
+        steps_per_epoch=steps_per_epoch,
+        decay_rate=FLAGS.diversity_decay_rate,
+        staircase=True)
+
+  elif FLAGS.diversity_scheduler == 'LinearAnnealing':
+    diversity_schedule = be_utils.LinearAnnealing(
+        initial_coeff=FLAGS.diversity_coeff,
+        annealing_epochs=FLAGS.annealing_epochs,
+        steps_per_epoch=steps_per_epoch)
+  else:
+    diversity_schedule = lambda x: FLAGS.diversity_coeff
+
  metrics = {
-      'train/negative_log_likelihood': tf.keras.metrics.Mean(),
-      'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
-      'train/loss': tf.keras.metrics.Mean(),
-      'train/ece': ed.metrics.ExpectedCalibrationError(
-          num_bins=FLAGS.num_bins),
-      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
-      'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
-      'test/ece': ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins)
+      'train/similarity_loss':
+          tf.keras.metrics.Mean(),
+      'train/weights_similarity':
+          tf.keras.metrics.Mean(),
+      'train/outputs_similarity':
+          tf.keras.metrics.Mean(),
+      'train/negative_log_likelihood':
+          tf.keras.metrics.Mean(),
+      'train/accuracy':
+          tf.keras.metrics.SparseCategoricalAccuracy(),
+      'train/loss':
+          tf.keras.metrics.Mean(),
+      'train/ece':
+          ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
+      'test/negative_log_likelihood':
+          tf.keras.metrics.Mean(),
+      'test/accuracy':
+          tf.keras.metrics.SparseCategoricalAccuracy(),
+      'test/ece':
+          ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
+      'test/weights_similarity':
+          tf.keras.metrics.Mean(),
+      'test/outputs_similarity':
+          tf.keras.metrics.Mean()
  }
  if FLAGS.corruptions_interval > 0:
    corrupt_metrics = {}
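Note: be_utils resolves to an internal edward2 path (edward2.google.rank1_pert.ensemble_keras), so ExponentialDecay and LinearAnnealing are not visible in this diff. From the call sites above, both act as callables mapping the global optimizer step to a diversity coefficient. The sketch below is consistent with those call sites only; the class names match, but the exact semantics (for instance, whether LinearAnnealing ramps the coefficient up from zero or anneals it down) are assumptions:

    import tensorflow.compat.v2 as tf


    class LinearAnnealing:
      """Sketch: linearly scale the coefficient over annealing_epochs."""

      def __init__(self, initial_coeff, annealing_epochs, steps_per_epoch):
        self.initial_coeff = initial_coeff
        self.total_steps = float(annealing_epochs * steps_per_epoch)

      def __call__(self, step):
        frac = tf.minimum(tf.cast(step, tf.float32) / self.total_steps, 1.)
        return self.initial_coeff * frac  # assumed: ramp up from zero


    class ExponentialDecay:
      """Sketch: exponentially decay the coefficient after start_epoch."""

      def __init__(self, initial_coeff, start_epoch, decay_epoch,
                   steps_per_epoch, decay_rate, staircase=True):
        self.initial_coeff = initial_coeff
        self.start_step = start_epoch * steps_per_epoch
        self.decay_steps = float(decay_epoch * steps_per_epoch)
        self.decay_rate = decay_rate
        self.staircase = staircase

      def __call__(self, step):
        t = tf.maximum(tf.cast(step, tf.float32) - self.start_step, 0.)
        p = t / self.decay_steps
        if self.staircase:
          p = tf.math.floor(p)  # decay in discrete steps per decay_epoch
        return self.initial_coeff * tf.pow(self.decay_rate, p)

Either scheduler is then evaluated at optimizer.iterations inside the train step, via scaled_similarity_loss in the hunks below.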
@@ -208,6 +258,7 @@ def main(argv):
  @tf.function
  def train_step(iterator):
    """Training StepFn."""
+
    def step_fn(inputs):
      """Per-Replica StepFn."""
      images, labels = inputs
@@ -225,10 +276,20 @@ def step_fn(inputs):
        diversity_results = ed.metrics.average_pairwise_diversity(
            per_probs, FLAGS.ensemble_size)

+        similarity_coeff, similarity_loss = be_utils.scaled_similarity_loss(
+            FLAGS.diversity_coeff, diversity_schedule, optimizer.iterations,
+            FLAGS.similarity_metric, FLAGS.dpp_kernel,
+            model.trainable_variables, FLAGS.use_output_similarity, per_probs)
+        weights_similarity = be_utils.fast_weights_similarity(
+            model.trainable_variables, FLAGS.similarity_metric,
+            FLAGS.dpp_kernel)
+        outputs_similarity = be_utils.outputs_similarity(
+            per_probs, FLAGS.similarity_metric, FLAGS.dpp_kernel)
+
        negative_log_likelihood = tf.reduce_mean(
-            tf.keras.losses.sparse_categorical_crossentropy(labels,
-                                                            logits,
-                                                            from_logits=True))
+            tf.keras.losses.sparse_categorical_crossentropy(
+                labels, logits, from_logits=True))
        filtered_variables = []
        for var in model.trainable_variables:
          # Apply l2 on the slow weights and bias terms. This excludes BN
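fast_weights_similarity, outputs_similarity, and scaled_similarity_loss also come from the internal be_utils module. Judging only from the arguments passed above, the cosine variant plausibly measures how aligned the ensemble members' rank-1 fast weights are (higher similarity means less diverse members), while scaled_similarity_loss evaluates the schedule at the current step and returns the (coefficient, penalty) pair added to the loss in the next hunk. A hypothetical sketch of the cosine case, reusing the same fast-weight name filter the train step applies for learning rates; this is not the actual internal implementation:

    import tensorflow.compat.v2 as tf


    def cosine_fast_weights_similarity(trainable_variables):
      # Hypothetical sketch. Assumes each fast weight stores one row per
      # ensemble member, i.e. shape [ensemble_size, ...], as in BatchEnsemble.
      fast_weights = [
          tf.reshape(v, [v.shape[0], -1])
          for v in trainable_variables
          # Same filter the train step uses for the fast-weight learning rate.
          if 'batch_norm' not in v.name and 'kernel' not in v.name
      ]
      sims = []
      for w in fast_weights:
        w = tf.math.l2_normalize(w, axis=-1)
        gram = tf.matmul(w, w, transpose_b=True)  # pairwise cosine similarity
        n = tf.cast(tf.shape(gram)[0], gram.dtype)
        off_diag = (tf.reduce_sum(gram) -
                    tf.reduce_sum(tf.linalg.diag_part(gram)))
        sims.append(off_diag / (n * (n - 1.)))  # mean over member pairs
      return tf.add_n(sims) / len(sims)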
@@ -239,7 +300,7 @@ def step_fn(inputs):
        l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
            tf.concat(filtered_variables, axis=0))
-        loss = negative_log_likelihood + l2_loss
+        loss = negative_log_likelihood + l2_loss + similarity_coeff * similarity_loss
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync
@@ -252,14 +313,18 @@ def step_fn(inputs):
          # Apply different learning rate on the fast weights. This excludes BN
          # and slow weights, but pay caution to the naming scheme.
          if ('batch_norm' not in var.name and 'kernel' not in var.name):
-            grads_and_vars.append((grad * FLAGS.fast_weight_lr_multiplier,
-                                   var))
+            grads_and_vars.append((grad * FLAGS.fast_weight_lr_multiplier, var))
          else:
            grads_and_vars.append((grad, var))
        optimizer.apply_gradients(grads_and_vars)
      else:
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

+      metrics['train/similarity_loss'].update_state(similarity_coeff *
+                                                    similarity_loss)
+      metrics['train/weights_similarity'].update_state(weights_similarity)
+      metrics['train/outputs_similarity'].update_state(outputs_similarity)
+
      metrics['train/ece'].update_state(labels, probs)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
@@ -273,6 +338,7 @@ def step_fn(inputs):
  @tf.function
  def test_step(iterator, dataset_name):
    """Evaluation StepFn."""
+
    def step_fn(inputs):
      """Per-Replica StepFn."""
      images, labels = inputs
@@ -287,6 +353,8 @@ def step_fn(inputs):
          probs, tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]], 0))
      diversity_results = ed.metrics.average_pairwise_diversity(
          per_probs_tensor, FLAGS.ensemble_size)
+      outputs_similarity = be_utils.outputs_similarity(
+          per_probs_tensor, FLAGS.similarity_metric, FLAGS.dpp_kernel)
      for k, v in diversity_results.items():
        test_diversity['test/' + k].update_state(v)
@@ -310,6 +378,11 @@ def step_fn(inputs):
            negative_log_likelihood)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].update_state(labels, probs)
+        weights_similarity = be_utils.fast_weights_similarity(
+            model.trainable_variables, FLAGS.similarity_metric,
+            FLAGS.dpp_kernel)
+        metrics['test/weights_similarity'].update_state(weights_similarity)
+        metrics['test/outputs_similarity'].update_state(outputs_similarity)
      else:
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
@@ -334,12 +407,8 @@ def step_fn(inputs):
      eta_seconds = (max_steps - current_step) / steps_per_sec
      message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
-                     current_step / max_steps,
-                     epoch + 1,
-                     FLAGS.train_epochs,
-                     steps_per_sec,
-                     eta_seconds / 60,
-                     time_elapsed / 60))
+                     current_step / max_steps, epoch + 1, FLAGS.train_epochs,
+                     steps_per_sec, eta_seconds / 60, time_elapsed / 60))
      if step % 20 == 0:
        logging.info(message)
@@ -352,8 +421,7 @@ def step_fn(inputs):
      logging.info('Testing on dataset %s', dataset_name)
      for step in range(steps_per_eval):
        if step % 20 == 0:
-          logging.info('Starting to run eval step %s of epoch: %s', step,
-                       epoch)
+          logging.info('Starting to run eval step %s of epoch: %s', step, epoch)
        test_step(test_iterator, dataset_name)
      logging.info('Done with testing on %s', dataset_name)
@@ -371,15 +439,16 @@ def step_fn(inputs):
                 metrics['test/negative_log_likelihood'].result(),
                 metrics['test/accuracy'].result() * 100)
    for i in range(FLAGS.ensemble_size):
-      logging.info('Member %d Test Loss: %.4f, Accuracy: %.2f%%',
-                   i, metrics['test/nll_member_{}'.format(i)].result(),
+      logging.info('Member %d Test Loss: %.4f, Accuracy: %.2f%%', i,
+                   metrics['test/nll_member_{}'.format(i)].result(),
                   metrics['test/accuracy_member_{}'.format(i)].result() * 100)

    total_metrics = metrics.copy()
    total_metrics.update(training_diversity)
    total_metrics.update(test_diversity)
-    total_results = {name: metric.result()
-                     for name, metric in total_metrics.items()}
+    total_results = {
+        name: metric.result() for name, metric in total_metrics.items()
+    }
    total_results.update(corrupt_results)
    with summary_writer.as_default():
      for name, result in total_results.items():
@@ -390,13 +459,14 @@ def step_fn(inputs):
    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
-      checkpoint_name = checkpoint.save(os.path.join(
-          FLAGS.output_dir, 'checkpoint'))
+      checkpoint_name = checkpoint.save(
+          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)

+
if __name__ == '__main__':
  app.run(main)
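For reference, a hypothetical launch command exercising the new diversity flags (the script name and paths are placeholders):

    python imagenet_batchensemble.py \
        --data_dir=/path/to/imagenet \
        --output_dir=/tmp/imagenet \
        --similarity_metric=cosine \
        --diversity_scheduler=LinearAnnealing \
        --diversity_coeff=0.1 \
        --annealing_epochs=200

With diversity_coeff left at its default of 0, the similarity penalty is a no-op and training reduces to the original BatchEnsemble objective.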