@@ -195,7 +195,11 @@ def main(argv):
195
195
tf .keras .metrics .SparseCategoricalAccuracy ())
196
196
corrupt_metrics ['test/ece_{}' .format (
197
197
name )] = ed .metrics .ExpectedCalibrationError (num_bins = FLAGS .num_bins )
198
-
198
+ test_diversity = {
199
+ 'test/disagreement' : tf .keras .metrics .Mean (),
200
+ 'test/average_kl' : tf .keras .metrics .Mean (),
201
+ 'test/cosine_similarity' : tf .keras .metrics .Mean (),
202
+ }
199
203
# Evaluate model predictions.
200
204
for n , (name , test_dataset ) in enumerate (test_datasets .items ()):
201
205
logits_dataset = []
@@ -214,6 +218,10 @@ def main(argv):
214
218
negative_log_likelihood = tf .reduce_mean (
215
219
ensemble_negative_log_likelihood (labels , logits ))
216
220
per_probs = tf .nn .softmax (logits )
221
+ diversity_results = ed .metrics .average_pairwise_diversity (
222
+ per_probs , ensemble_size )
223
+ for k , v in diversity_results .items ():
224
+ test_diversity ['test/' + k ].update_state (v )
217
225
probs = tf .reduce_mean (per_probs , axis = 0 )
218
226
if name == 'clean' :
219
227
gibbs_ce = tf .reduce_mean (gibbs_cross_entropy (labels , logits ))
@@ -234,11 +242,15 @@ def main(argv):
234
242
(n + 1 ) / num_datasets , n + 1 , num_datasets ))
235
243
logging .info (message )
236
244
245
+ total_metrics = metrics .copy ()
246
+ total_metrics .update (test_diversity )
237
247
corrupt_results = utils .aggregate_corrupt_metrics (corrupt_metrics ,
238
248
corruption_types ,
239
249
max_intensity ,
240
250
FLAGS .alexnet_errors_path )
241
- total_results = {name : metric .result () for name , metric in metrics .items ()}
251
+ total_results = {
252
+ name : metric .result () for name , metric in total_metrics .items ()
253
+ }
242
254
total_results .update (corrupt_results )
243
255
logging .info ('Metrics: %s' , total_results )
244
256
0 commit comments