From 7e777f3ae816c742500dbdfa4cf8294164cf88f8 Mon Sep 17 00:00:00 2001 From: Ghassen Jerfel Date: Fri, 19 Jun 2020 11:11:41 -0700 Subject: [PATCH] Track corrupted outputs_similarity. PiperOrigin-RevId: 317342104 --- baselines/cifar/utils.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/baselines/cifar/utils.py b/baselines/cifar/utils.py index 5a6d7398..c4331f12 100644 --- a/baselines/cifar/utils.py +++ b/baselines/cifar/utils.py @@ -283,7 +283,8 @@ def aggregate_corrupt_metrics(metrics, Dictionary of aggregated results. """ - diversity_keys = ['disagreement', 'cosine_similarity', 'average_kl'] + diversity_keys = ['disagreement', 'cosine_similarity', 'average_kl', + 'outputs_similarity'] results = { 'test/nll_mean_corrupted': 0., 'test/accuracy_mean_corrupted': 0., @@ -305,7 +306,7 @@ disagreement = np.zeros(len(corruption_types)) cosine_similarity = np.zeros(len(corruption_types)) average_kl = np.zeros(len(corruption_types)) - + outputs_similarity = np.zeros(len(corruption_types)) for i in range(len(corruption_types)): dataset_name = '{0}_{1}'.format(corruption_types[i], intensity) nll[i] = metrics['test/nll_{}'.format(dataset_name)].result() @@ -321,17 +322,21 @@ dataset_name)].result() member_ece[i] = 0. if corrupt_diversity is not None: + error = 1 - acc[i] + tf.keras.backend.epsilon() disagreement[i] = ( corrupt_diversity['corrupt_diversity/disagreement_{}'.format( - dataset_name)].result()) + dataset_name)].result()) / error # Normalize the corrupt disagreement by its error rate. 
- error = 1 - acc[i] + tf.keras.backend.epsilon() cosine_similarity[i] = ( corrupt_diversity['corrupt_diversity/cosine_similarity_{}'.format( - dataset_name)].result()) / error + dataset_name)].result()) average_kl[i] = ( corrupt_diversity['corrupt_diversity/average_kl_{}'.format( dataset_name)].result()) + outputs_similarity[i] = ( + corrupt_diversity['corrupt_diversity/outputs_similarity_{}'.format( + dataset_name)].result()) + if log_fine_metrics or output_dir is not None: fine_metrics_results['test/nll_{}'.format(dataset_name)] = nll[i] fine_metrics_results['test/accuracy_{}'.format(dataset_name)] = acc[i] @@ -343,6 +348,9 @@ def aggregate_corrupt_metrics(metrics, dataset_name)] = cosine_similarity[i] fine_metrics_results['corrupt_diversity/average_kl_{}'.format( dataset_name)] = average_kl[i] + fine_metrics_results['corrupt_diversity/outputs_similarity_{}'.format( + dataset_name)] = outputs_similarity[i] + avg_nll = np.mean(nll) avg_accuracy = np.mean(acc) avg_ece = np.mean(ece) @@ -363,7 +371,7 @@ def aggregate_corrupt_metrics(metrics, results['test/member_ece_mean_corrupted'] += avg_member_ece if corrupt_diversity is not None: avg_diversity_metrics = [np.mean(disagreement), np.mean( - cosine_similarity), np.mean(average_kl)] + cosine_similarity), np.mean(average_kl), np.mean(outputs_similarity)] for key, avg in zip(diversity_keys, avg_diversity_metrics): results['corrupt_diversity/{}_mean_{}'.format( key, intensity)] = avg