Skip to content

Commit 7991b8a

Browse files
GhassenJedward-bot
authored and committed
Add diversity summaries to naive ensemble methods.
PiperOrigin-RevId: 317341908
1 parent 990e3e7 commit 7991b8a

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

baselines/cifar/ensemble.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,12 @@ def main(argv):
196196
corrupt_metrics['test/ece_{}'.format(name)] = (
197197
ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))
198198

199+
test_diversity = {
200+
'test/disagreement': tf.keras.metrics.Mean(),
201+
'test/average_kl': tf.keras.metrics.Mean(),
202+
'test/cosine_similarity': tf.keras.metrics.Mean(),
203+
}
204+
199205
# Evaluate model predictions.
200206
for n, (name, test_dataset) in enumerate(test_datasets.items()):
201207
logits_dataset = []
@@ -214,6 +220,11 @@ def main(argv):
214220
negative_log_likelihood = tf.reduce_mean(
215221
ensemble_negative_log_likelihood(labels, logits))
216222
per_probs = tf.nn.softmax(logits)
223+
diversity_results = ed.metrics.average_pairwise_diversity(
224+
per_probs, ensemble_size)
225+
for k, v in diversity_results.items():
226+
test_diversity['test/' + k].update_state(v)
227+
217228
probs = tf.reduce_mean(per_probs, axis=0)
218229
if name == 'clean':
219230
gibbs_ce = tf.reduce_mean(gibbs_cross_entropy(labels, logits))
@@ -234,6 +245,8 @@ def main(argv):
234245
(n + 1) / num_datasets, n + 1, num_datasets))
235246
logging.info(message)
236247

248+
total_metrics = metrics.copy()
249+
total_metrics.update(test_diversity)
237250
corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
238251
corruption_types,
239252
max_intensity)

baselines/imagenet/ensemble.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,11 @@ def main(argv):
195195
tf.keras.metrics.SparseCategoricalAccuracy())
196196
corrupt_metrics['test/ece_{}'.format(
197197
name)] = ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins)
198-
198+
test_diversity = {
199+
'test/disagreement': tf.keras.metrics.Mean(),
200+
'test/average_kl': tf.keras.metrics.Mean(),
201+
'test/cosine_similarity': tf.keras.metrics.Mean(),
202+
}
199203
# Evaluate model predictions.
200204
for n, (name, test_dataset) in enumerate(test_datasets.items()):
201205
logits_dataset = []
@@ -214,6 +218,10 @@ def main(argv):
214218
negative_log_likelihood = tf.reduce_mean(
215219
ensemble_negative_log_likelihood(labels, logits))
216220
per_probs = tf.nn.softmax(logits)
221+
diversity_results = ed.metrics.average_pairwise_diversity(
222+
per_probs, ensemble_size)
223+
for k, v in diversity_results.items():
224+
test_diversity['test/' + k].update_state(v)
217225
probs = tf.reduce_mean(per_probs, axis=0)
218226
if name == 'clean':
219227
gibbs_ce = tf.reduce_mean(gibbs_cross_entropy(labels, logits))
@@ -234,11 +242,15 @@ def main(argv):
234242
(n + 1) / num_datasets, n + 1, num_datasets))
235243
logging.info(message)
236244

245+
total_metrics = metrics.copy()
246+
total_metrics.update(test_diversity)
237247
corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
238248
corruption_types,
239249
max_intensity,
240250
FLAGS.alexnet_errors_path)
241-
total_results = {name: metric.result() for name, metric in metrics.items()}
251+
total_results = {
252+
name: metric.result() for name, metric in total_metrics.items()
253+
}
242254
total_results.update(corrupt_results)
243255
logging.info('Metrics: %s', total_results)
244256

0 commit comments

Comments
 (0)