
Commit 7da7f21

GhassenJ authored and edward-bot committed
Add diversity regularization to BatchEnsemble on CIFAR and ImageNet.
PiperOrigin-RevId: 314502866
1 parent 990e3e7 commit 7da7f21

File tree (2 files changed: +124 -50 lines changed)

  baselines/imagenet/batchensemble.py
  edward2/tensorflow/metrics.py

baselines/imagenet/batchensemble.py

Lines changed: 116 additions & 46 deletions
@@ -25,7 +25,8 @@
 import edward2 as ed
 import batchensemble_model  # local file import
 import utils  # local file import
-import tensorflow as tf
+from edward2.google.rank1_pert.ensemble_keras import utils as be_utils
+import tensorflow.compat.v2 as tf

 flags.DEFINE_integer('ensemble_size', 4, 'Size of ensemble.')
 flags.DEFINE_integer('per_core_batch_size', 128, 'Batch size per TPU core/GPU.')
@@ -39,16 +40,18 @@
                     'fast weights lr multiplier.')
 flags.DEFINE_string('data_dir', None, 'Path to training and testing data.')
 flags.mark_flag_as_required('data_dir')
-flags.DEFINE_string('output_dir', '/tmp/imagenet',
-                    'The directory where the model weights and '
-                    'training/evaluation summaries are stored.')
+flags.DEFINE_string(
+    'output_dir', '/tmp/imagenet', 'The directory where the model weights and '
+    'training/evaluation summaries are stored.')
 flags.DEFINE_integer('train_epochs', 135, 'Number of training epochs.')
-flags.DEFINE_integer('corruptions_interval', 135,
-                     'Number of epochs between evaluating on the corrupted '
-                     'test data. Use -1 to never evaluate.')
-flags.DEFINE_integer('checkpoint_interval', 27,
-                     'Number of epochs between saving checkpoints. Use -1 to '
-                     'never save checkpoints.')
+flags.DEFINE_integer(
+    'corruptions_interval', 135,
+    'Number of epochs between evaluating on the corrupted '
+    'test data. Use -1 to never evaluate.')
+flags.DEFINE_integer(
+    'checkpoint_interval', 27,
+    'Number of epochs between saving checkpoints. Use -1 to '
+    'never save checkpoints.')
 flags.DEFINE_string('alexnet_errors_path', None,
                     'Path to AlexNet corruption errors file.')
 flags.DEFINE_integer('num_bins', 15, 'Number of bins for ECE computation.')
@@ -60,6 +63,22 @@
 flags.DEFINE_integer('num_cores', 32, 'Number of TPU cores or number of GPUs.')
 flags.DEFINE_string('tpu', None,
                     'Name of the TPU. Only used if use_gpu is False.')
+flags.DEFINE_string('similarity_metric', 'cosine', 'Similarity metric in '
+                    '[cosine, dpp_logdet]')
+flags.DEFINE_string('dpp_kernel', 'linear', 'Kernel for DPP log determinant')
+flags.DEFINE_bool('use_output_similarity', False,
+                  'If true, compute similarity on the ensemble outputs.')
+flags.DEFINE_enum('diversity_scheduler', 'LinearAnnealing',
+                  ['LinearAnnealing', 'ExponentialDecay', 'Fixed'],
+                  'Diversity coefficient scheduler..')
+flags.DEFINE_float('annealing_epochs', 200,
+                   'Number of epochs over which to linearly anneal')
+flags.DEFINE_float('diversity_coeff', 0., 'Diversity loss coefficient.')
+flags.DEFINE_float('diversity_decay_epoch', 4, 'Diversity decay epoch.')
+flags.DEFINE_float('diversity_decay_rate', 0.97, 'Rate of exponential decay.')
+flags.DEFINE_integer('diversity_start_epoch', 100,
+                     'Diversity loss starting epoch')
+
 FLAGS = flags.FLAGS

 # Number of images in ImageNet-1k train dataset.
@@ -68,7 +87,7 @@
 IMAGENET_VALIDATION_IMAGES = 50000
 NUM_CLASSES = 1000

-_LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
+_LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
     (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
 ]

@@ -147,22 +166,53 @@ def main(argv):
   logging.info('Model number of weights: %s', model.count_params())
   # Scale learning rate and decay epochs by vanilla settings.
   base_lr = FLAGS.base_learning_rate * batch_size / 256
-  learning_rate = utils.LearningRateSchedule(steps_per_epoch,
-                                             base_lr,
-                                             FLAGS.train_epochs,
-                                             _LR_SCHEDULE)
-  optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
-                                      momentum=0.9,
-                                      nesterov=True)
+  learning_rate = utils.LearningRateSchedule(steps_per_epoch, base_lr,
+                                             FLAGS.train_epochs, _LR_SCHEDULE)
+  optimizer = tf.keras.optimizers.SGD(
+      learning_rate=learning_rate, momentum=0.9, nesterov=True)
+
+  if FLAGS.diversity_scheduler == 'ExponentialDecay':
+    diversity_schedule = be_utils.ExponentialDecay(
+        initial_coeff=FLAGS.diversity_coeff,
+        start_epoch=FLAGS.diversity_start_epoch,
+        decay_epoch=FLAGS.diversity_decay_epoch,
+        steps_per_epoch=steps_per_epoch,
+        decay_rate=FLAGS.diversity_decay_rate,
+        staircase=True)
+
+  elif FLAGS.diversity_scheduler == 'LinearAnnealing':
+    diversity_schedule = be_utils.LinearAnnealing(
+        initial_coeff=FLAGS.diversity_coeff,
+        annealing_epochs=FLAGS.annealing_epochs,
+        steps_per_epoch=steps_per_epoch)
+  else:
+    diversity_schedule = lambda x: FLAGS.diversity_coeff
+
   metrics = {
-      'train/negative_log_likelihood': tf.keras.metrics.Mean(),
-      'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
-      'train/loss': tf.keras.metrics.Mean(),
-      'train/ece': ed.metrics.ExpectedCalibrationError(
-          num_bins=FLAGS.num_bins),
-      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
-      'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
-      'test/ece': ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins)
+      'train/similarity_loss':
+          tf.keras.metrics.Mean(),
+      'train/weights_similarity':
+          tf.keras.metrics.Mean(),
+      'train/outputs_similarity':
+          tf.keras.metrics.Mean(),
+      'train/negative_log_likelihood':
+          tf.keras.metrics.Mean(),
+      'train/accuracy':
+          tf.keras.metrics.SparseCategoricalAccuracy(),
+      'train/loss':
+          tf.keras.metrics.Mean(),
+      'train/ece':
+          ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
+      'test/negative_log_likelihood':
+          tf.keras.metrics.Mean(),
+      'test/accuracy':
+          tf.keras.metrics.SparseCategoricalAccuracy(),
+      'test/ece':
+          ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
+      'test/weights_similarity':
+          tf.keras.metrics.Mean(),
+      'test/outputs_similarity':
+          tf.keras.metrics.Mean()
   }
   if FLAGS.corruptions_interval > 0:
     corrupt_metrics = {}
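
The be_utils schedulers constructed above come from edward2.google.rank1_pert.ensemble_keras and their implementation is not part of this diff. A minimal sketch of what the ExponentialDecay coefficient schedule might look like, assuming it follows tf.keras-style exponential decay and that the coefficient is held at initial_coeff before start_epoch (the diff does not show this detail):

import tensorflow.compat.v2 as tf


class ExponentialDecay:
  """Illustrative diversity-coefficient schedule; assumed semantics, not be_utils.

  Holds the coefficient at initial_coeff until start_epoch, then multiplies it
  by decay_rate once every decay_epoch epochs (staircase) or continuously.
  """

  def __init__(self, initial_coeff, start_epoch, decay_epoch, steps_per_epoch,
               decay_rate, staircase=True):
    self.initial_coeff = initial_coeff
    self.start_step = start_epoch * steps_per_epoch
    self.decay_steps = decay_epoch * steps_per_epoch
    self.decay_rate = decay_rate
    self.staircase = staircase

  def __call__(self, step):
    # step is the global optimizer step (optimizer.iterations in the diff).
    progress = tf.maximum(
        tf.cast(step, tf.float32) - self.start_step, 0.) / self.decay_steps
    if self.staircase:
      progress = tf.floor(progress)
    return self.initial_coeff * tf.pow(self.decay_rate, progress)

The Fixed branch shows the interface the training loop relies on: a schedule is just a callable mapping the current optimizer step to a scalar coefficient.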
@@ -208,6 +258,7 @@ def main(argv):
   @tf.function
   def train_step(iterator):
     """Training StepFn."""
+
     def step_fn(inputs):
       """Per-Replica StepFn."""
       images, labels = inputs
@@ -225,10 +276,20 @@ def step_fn(inputs):
         diversity_results = ed.metrics.average_pairwise_diversity(
             per_probs, FLAGS.ensemble_size)

+        # print(' > per_probs {}'.format(per_probs))
+        similarity_coeff, similarity_loss = be_utils.scaled_similarity_loss(
+            FLAGS.diversity_coeff, diversity_schedule, optimizer.iterations,
+            FLAGS.similarity_metric, FLAGS.dpp_kernel,
+            model.trainable_variables, FLAGS.use_output_similarity, per_probs)
+        weights_similarity = be_utils.fast_weights_similarity(
+            model.trainable_variables, FLAGS.similarity_metric,
+            FLAGS.dpp_kernel)
+        outputs_similarity = be_utils.outputs_similarity(
+            per_probs, FLAGS.similarity_metric, FLAGS.dpp_kernel)
+
         negative_log_likelihood = tf.reduce_mean(
-            tf.keras.losses.sparse_categorical_crossentropy(labels,
-                                                            logits,
-                                                            from_logits=True))
+            tf.keras.losses.sparse_categorical_crossentropy(
+                labels, logits, from_logits=True))
         filtered_variables = []
         for var in model.trainable_variables:
           # Apply l2 on the slow weights and bias terms. This excludes BN
@@ -239,7 +300,7 @@ def step_fn(inputs):

         l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
             tf.concat(filtered_variables, axis=0))
-        loss = negative_log_likelihood + l2_loss
+        loss = negative_log_likelihood + l2_loss + similarity_coeff * similarity_loss
         # Scale the loss given the TPUStrategy will reduce sum all gradients.
         scaled_loss = loss / strategy.num_replicas_in_sync

@@ -252,14 +313,18 @@ def step_fn(inputs):
           # Apply different learning rate on the fast weights. This excludes BN
           # and slow weights, but pay caution to the naming scheme.
           if ('batch_norm' not in var.name and 'kernel' not in var.name):
-            grads_and_vars.append((grad * FLAGS.fast_weight_lr_multiplier,
-                                   var))
+            grads_and_vars.append((grad * FLAGS.fast_weight_lr_multiplier, var))
           else:
             grads_and_vars.append((grad, var))
         optimizer.apply_gradients(grads_and_vars)
       else:
         optimizer.apply_gradients(zip(grads, model.trainable_variables))

+      metrics['train/similarity_loss'].update_state(similarity_coeff *
+                                                    similarity_loss)
+      metrics['train/weights_similarity'].update_state(weights_similarity)
+      metrics['train/outputs_similarity'].update_state(outputs_similarity)
+
       metrics['train/ece'].update_state(labels, probs)
       metrics['train/loss'].update_state(loss)
       metrics['train/negative_log_likelihood'].update_state(
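
scaled_similarity_loss, fast_weights_similarity and outputs_similarity used above also live in be_utils and are not shown here. A rough sketch of the cosine flavor of the weights similarity, assuming the fast weights are the rank-1 BatchEnsemble variables named alpha and gamma with shape [ensemble_size, dim]; the helper name and the name filter are illustrative, not the be_utils implementation:

import tensorflow.compat.v2 as tf


def cosine_fast_weights_similarity(trainable_variables, ensemble_size=4):
  """Illustrative mean pairwise cosine similarity over BatchEnsemble fast weights."""
  sims = []
  for var in trainable_variables:
    # Heuristic filter for the rank-1 fast weights; the real be_utils filter may differ.
    if 'alpha' in var.name or 'gamma' in var.name:
      rows = tf.math.l2_normalize(tf.reshape(var, [ensemble_size, -1]), axis=-1)
      gram = tf.matmul(rows, rows, transpose_b=True)  # [k, k] cosine similarities
      off_diag = gram - tf.linalg.diag(tf.linalg.diag_part(gram))
      sims.append(tf.reduce_sum(off_diag) / (ensemble_size * (ensemble_size - 1)))
  return tf.add_n(sims) / float(len(sims))

Under this reading, the new term similarity_coeff * similarity_loss in the loss penalizes ensemble members whose rank-1 factors point in similar directions, with the coefficient supplied by the schedule at the current optimizer step.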
@@ -273,6 +338,7 @@ def step_fn(inputs):
   @tf.function
   def test_step(iterator, dataset_name):
     """Evaluation StepFn."""
+
     def step_fn(inputs):
       """Per-Replica StepFn."""
       images, labels = inputs
@@ -287,6 +353,8 @@ def step_fn(inputs):
           probs, tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]], 0))
       diversity_results = ed.metrics.average_pairwise_diversity(
           per_probs_tensor, FLAGS.ensemble_size)
+      outputs_similarity = be_utils.outputs_similarity(
+          per_probs_tensor, FLAGS.similarity_metric, FLAGS.dpp_kernel)
       for k, v in diversity_results.items():
         test_diversity['test/' + k].update_state(v)

@@ -310,6 +378,11 @@ def step_fn(inputs):
             negative_log_likelihood)
         metrics['test/accuracy'].update_state(labels, probs)
         metrics['test/ece'].update_state(labels, probs)
+        weights_similarity = be_utils.fast_weights_similarity(
+            model.trainable_variables, FLAGS.similarity_metric,
+            FLAGS.dpp_kernel)
+        metrics['test/weights_similarity'].update_state(weights_similarity)
+        metrics['test/outputs_similarity'].update_state(outputs_similarity)
       else:
         corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
             negative_log_likelihood)
@@ -334,12 +407,8 @@ def step_fn(inputs):
       eta_seconds = (max_steps - current_step) / steps_per_sec
       message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                  'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
-                     current_step / max_steps,
-                     epoch + 1,
-                     FLAGS.train_epochs,
-                     steps_per_sec,
-                     eta_seconds / 60,
-                     time_elapsed / 60))
+                     current_step / max_steps, epoch + 1, FLAGS.train_epochs,
+                     steps_per_sec, eta_seconds / 60, time_elapsed / 60))
       if step % 20 == 0:
         logging.info(message)

@@ -352,8 +421,7 @@ def step_fn(inputs):
       logging.info('Testing on dataset %s', dataset_name)
       for step in range(steps_per_eval):
         if step % 20 == 0:
-          logging.info('Starting to run eval step %s of epoch: %s', step,
-                       epoch)
+          logging.info('Starting to run eval step %s of epoch: %s', step, epoch)
         test_step(test_iterator, dataset_name)
       logging.info('Done with testing on %s', dataset_name)

@@ -371,15 +439,16 @@ def step_fn(inputs):
                  metrics['test/negative_log_likelihood'].result(),
                  metrics['test/accuracy'].result() * 100)
     for i in range(FLAGS.ensemble_size):
-      logging.info('Member %d Test Loss: %.4f, Accuracy: %.2f%%',
-                   i, metrics['test/nll_member_{}'.format(i)].result(),
+      logging.info('Member %d Test Loss: %.4f, Accuracy: %.2f%%', i,
+                   metrics['test/nll_member_{}'.format(i)].result(),
                    metrics['test/accuracy_member_{}'.format(i)].result() * 100)

     total_metrics = metrics.copy()
     total_metrics.update(training_diversity)
     total_metrics.update(test_diversity)
-    total_results = {name: metric.result()
-                     for name, metric in total_metrics.items()}
+    total_results = {
+        name: metric.result() for name, metric in total_metrics.items()
+    }
     total_results.update(corrupt_results)
     with summary_writer.as_default():
       for name, result in total_results.items():
@@ -390,13 +459,14 @@ def step_fn(inputs):

     if (FLAGS.checkpoint_interval > 0 and
         (epoch + 1) % FLAGS.checkpoint_interval == 0):
-      checkpoint_name = checkpoint.save(os.path.join(
-          FLAGS.output_dir, 'checkpoint'))
+      checkpoint_name = checkpoint.save(
+          os.path.join(FLAGS.output_dir, 'checkpoint'))
       logging.info('Saved checkpoint to %s', checkpoint_name)

   final_checkpoint_name = checkpoint.save(
       os.path.join(FLAGS.output_dir, 'checkpoint'))
   logging.info('Saved last checkpoint to %s', final_checkpoint_name)

+
 if __name__ == '__main__':
   app.run(main)
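
The dpp_logdet similarity metric and the linear dpp_kernel are only referenced through flags in this file; their implementation lives in be_utils. As a sketch of the underlying idea, a determinantal-point-process style score is the log-determinant of a kernel matrix over ensemble members, which grows as the members become more diverse; the helper below is illustrative, and the sign convention used to turn it into a similarity penalty is an assumption:

import tensorflow.compat.v2 as tf


def dpp_logdet_diversity(member_features, jitter=1e-5):
  """Illustrative DPP-style diversity score with a linear kernel.

  member_features: [ensemble_size, feature_dim] tensor, one row per member
  (e.g. flattened fast weights or flattened per-member probabilities).
  """
  rows = tf.math.l2_normalize(member_features, axis=-1)
  kernel = tf.matmul(rows, rows, transpose_b=True)                # linear kernel, [k, k]
  kernel += jitter * tf.eye(tf.shape(kernel)[0], dtype=kernel.dtype)  # keep it positive definite
  return tf.linalg.logdet(kernel)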

edward2/tensorflow/metrics.py

Lines changed: 8 additions & 4 deletions
@@ -171,12 +171,13 @@ def logit_kl_divergence(logits_1, logits_2):
   return tf.reduce_mean(vals)


-def kl_divergence(p, q):
+def kl_divergence(p, q, clip=True):
   """Generalized KL divergence [1] for unnormalized distributions.

   Args:
     p: tf.Tensor.
-    q: tf.Tensor
+    q: tf.Tensor.
+    clip: bool.

   Returns:
     tf.Tensor of the Kullback-Leibler divergences per example.
@@ -187,7 +188,10 @@ def kl_divergence(p, q):
     matrix factorization." Advances in neural information processing systems.
     2001.
   """
-  return tf.reduce_sum(p * tf.math.log(p / q) - p + q, axis=-1)
+  if clip:
+    p = tf.clip_by_value(p, tf.keras.backend.epsilon(), 1)
+    q = tf.clip_by_value(q, tf.keras.backend.epsilon(), 1)
+  return tf.reduce_sum(p * tf.math.log(p / q), axis=-1)


 def lp_distance(x, y, p=1):
@@ -229,7 +233,7 @@ def average_pairwise_diversity(probs, num_models, error=None):
   # TODO(ghassen): we could also return max and min pairwise metrics.
   average_disagreement = tf.reduce_mean(tf.stack(pairwise_disagreement))
   if error is not None:
-    average_disagreement /= (error + tf.keras.backend.epsilon())
+    average_disagreement /= (1 - error + tf.keras.backend.epsilon())
   average_kl_divergence = tf.reduce_mean(tf.stack(pairwise_kl_divergence))
   average_cosine_distance = tf.reduce_mean(tf.stack(pairwise_cosine_distance))
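
The two behavioral changes above can be exercised directly: kl_divergence now clips its inputs to [epsilon, 1] by default and drops the unnormalized correction term (- p + q), so zero probabilities no longer produce NaNs, and average_pairwise_diversity normalizes disagreement by (1 - error), i.e. by accuracy, rather than by the error rate. A small usage sketch, with illustrative tensors and error value, assuming both functions are exported under ed.metrics as in the training script above:

import edward2 as ed
import tensorflow.compat.v2 as tf

p = tf.constant([[0.7, 0.3, 0.0]])    # the zero entry made the old, unclipped form return NaN
q = tf.constant([[0.5, 0.25, 0.25]])
kl_clipped = ed.metrics.kl_divergence(p, q)                # clip=True is the new default
kl_unclipped = ed.metrics.kl_divergence(p, q, clip=False)  # reproduces the old failure mode

# 4 ensemble members, batch of 8, 3 classes; disagreement is now divided by (1 - error).
probs = tf.random.uniform([4, 8, 3])
probs /= tf.reduce_sum(probs, axis=-1, keepdims=True)
diversity = ed.metrics.average_pairwise_diversity(probs, num_models=4, error=0.1)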
