The goal of this library is to provide an easy-to-use interface for measuring uncertainty across Google and the open-source community.
Machine learning models often produce miscalibrated (over- or underconfident) probabilities. In real-world decision-making systems, classification models must not only be accurate but should also indicate when they are likely to be incorrect. For example, one important property is calibration: the idea that a model's predicted probabilities of outcomes reflect the true probabilities of those outcomes. Intuitively, for class predictions, calibration means that if a model assigns a class 90% probability, that class should appear 90% of the time.
pip install uncertainty_metrics
To install the latest development version, run
pip install "git+https://github.com/google/uncertainty_metrics.git#egg=uncertainty_metrics"
There is not yet a stable version (nor an official release of this library). All APIs are subject to change.
Here are some examples to get you started.
Expected Calibration Error.
import uncertainty_metrics.numpy as um
probabilities = ...
labels = ...
ece = um.ece(labels, probabilities, num_bins=30)
Reliability Diagram.
import uncertainty_metrics.numpy as um
probabilities = ...
labels = ...
diagram = um.reliability_diagram(labels, probabilities)
Brier Score.
import uncertainty_metrics as um
tf_probabilities = ...
labels = ...
bs = um.brier_score(labels=labels, probabilities=tf_probabilities)
How to diagnose miscalibration. Calibration is one of the most important properties of a trained model beyond accuracy. We demonstrate how to calculate calibration measures and diagnose miscalibration with the help of this library. One typical measure of calibration is Expected Calibration Error (ECE) (Guo et al., 2017). To calculate ECE, we group predictions into M bins (M=15 in our example) according to their confidence, which in ECE is the value of the max softmax output, and compute the accuracy in each bin. Let B_m be the set of examples whose predicted confidence falls into the m-th interval. The Acc and the Conf of bin B_m are
$$\textrm{Acc}(B_m) = \frac{1}{|B_m|} \sum_{i \in B_m} \mathbb{1}(\hat{y}_i = y_i), \quad \textrm{Conf}(B_m) = \frac{1}{|B_m|} \sum_{i \in B_m} \hat{p}_i,$$
where $$\hat{y}_i$$ is the predicted label, $$y_i$$ is the true label, and $$\hat{p}_i$$ is the confidence of example $$i$$.
ECE is defined as the weighted sum over bins of the absolute difference between Acc and Conf, where each bin is weighted by the fraction of the $$n$$ examples it contains: $$\textrm{ECE} = \sum_{m=1}^{M} \frac{|B_m|}{n} \left|\textrm{Acc}(B_m) - \textrm{Conf}(B_m)\right|$$. ECE is thus designed to measure the alignment between accuracy and confidence, which provides a quantitative way to measure calibration. Better calibration leads to lower ECE.
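The binning procedure described above can be sketched in a few lines of NumPy. This is an illustrative sketch under stated assumptions, not the library's implementation: the function name `expected_calibration_error`, the equal-width bins over [0, 1], and the toy data are all made up for this example.

```python
import numpy as np


def expected_calibration_error(labels, probs, num_bins=15):
    """ECE with equal-width confidence bins (illustrative sketch)."""
    labels = np.asarray(labels)
    probs = np.asarray(probs, dtype=float)
    confidences = probs.max(axis=1)      # max softmax output per example
    predictions = probs.argmax(axis=1)   # predicted class per example
    correct = (predictions == labels).astype(float)

    # Map each confidence to one of num_bins equal-width bins over [0, 1].
    bin_ids = np.minimum((confidences * num_bins).astype(int), num_bins - 1)

    ece = 0.0
    for m in range(num_bins):
        in_bin = bin_ids == m
        if not in_bin.any():
            continue  # empty bins contribute nothing
        acc = correct[in_bin].mean()
        conf = confidences[in_bin].mean()
        # Weight each bin by the fraction of examples that fall into it.
        ece += in_bin.mean() * abs(acc - conf)
    return ece


# Toy data: two confident correct predictions, two less confident wrong ones.
labels = np.array([0, 1, 1, 0])
probs = np.array([[0.9, 0.1], [0.15, 0.85], [0.65, 0.35], [0.3, 0.7]])
print(round(expected_calibration_error(labels, probs, num_bins=5), 4))
```

Empty bins are skipped, so only bins that actually contain predictions contribute to the sum.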
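In this example, we also need to introduce mixup (Zhang et al., 2017), a data-augmentation technique for image classification that improves both the accuracy and the calibration of a single model. Mixup applies the following transformation during training only,
$$\tilde{x} = \lambda x_i + (1 - \lambda) x_j, \quad \tilde{y} = \lambda y_i + (1 - \lambda) y_j, \quad \lambda \sim \textrm{Beta}(\alpha, \alpha),$$
where $$(x_i, y_i)$$ and $$(x_j, y_j)$$ are two training examples with one-hot labels.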
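As a rough illustration, mixup forms convex combinations of pairs of training examples and of their one-hot labels, with a mixing weight drawn from a Beta(alpha, alpha) distribution (Zhang et al., 2017). The NumPy sketch below is a minimal version under assumptions: the helper name `mixup_batch`, the use of one mixing weight per batch, and the random toy data are choices made here for illustration, not details taken from the experiments in this document.

```python
import numpy as np


def mixup_batch(images, one_hot_labels, alpha, rng):
    """Mixes each example with a randomly chosen partner from the same batch."""
    lam = rng.beta(alpha, alpha)           # mixing weight in [0, 1]
    perm = rng.permutation(len(images))    # partner index for every example
    mixed_images = lam * images + (1.0 - lam) * images[perm]
    mixed_labels = lam * one_hot_labels + (1.0 - lam) * one_hot_labels[perm]
    return mixed_images, mixed_labels


rng = np.random.default_rng(0)
images = rng.random((8, 32, 32, 3))                        # toy image batch
one_hot_labels = np.eye(10)[rng.integers(0, 10, size=8)]   # toy one-hot labels
mixed_images, mixed_labels = mixup_batch(
    images, one_hot_labels, alpha=0.2, rng=rng)
```

A small alpha concentrates the mixing weight near 0 or 1 (weak augmentation), while alpha = 1 draws it uniformly from [0, 1] (strong augmentation).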
We focus on the calibration (measured by ECE) of Mixup + BatchEnsemble (Wen et al., 2020). We first calculate the ECE of some fully trained models using this library.
import tensorflow as tf
import uncertainty_metrics.numpy as um
# Load and preprocess a dataset. Also load the model.
test_images, test_labels = ...
model = ...
# Obtain predictive probabilities.
# Obtain predictive probabilities. If the model is an ensemble of 4 individual
# models, probs has shape [4, testset_size, num_classes].
ensemble_size = 4
probs = model(test_images, training=False)
ensemble_probs = tf.reduce_mean(probs, axis=0)
# Calculate the calibration error of each individual member.
individual_eces = []
for i in range(ensemble_size):
    individual_eces.append(um.ece(test_labels, probs[i], num_bins=15))
# Calculate the calibration error of the ensemble prediction.
ensemble_ece = um.ece(test_labels, ensemble_probs, num_bins=15)
We collect the ECE in the following table.
Method | Prediction | CIFAR-10 Acc (%) | CIFAR-10 ECE | CIFAR-100 Acc (%) | CIFAR-100 ECE
---|---|---|---|---|---
BatchEnsemble | In | 95.88 | 2.3% | 80.64 | 8.7%
BatchEnsemble | En | 96.22 | 1.8% | 81.85 | 2.8%
Mixup0.2 BE | In | 96.43 | 0.8% | 81.44 | 1.5%
Mixup0.2 BE | En | 96.75 | 1.5% | 82.79 | 3.9%
Mixup1 BE | In | 96.67 | 5.5% | 81.32 | 6.6%
Mixup1 BE | En | 96.98 | 6.4% | 83.12 | 9.7%
In the above table, In stands for the individual models and En stands for the ensemble. Mixup0.2 denotes weak mixup augmentation, while Mixup1 denotes strong mixup augmentation. Ensembling typically improves both accuracy and calibration, but the calibration benefit does not hold with mixup: for both mixup settings, the ensemble has a higher ECE than the individual models. Scalars obscure useful information, so we look for more insight by examining the per-bin results.
ensemble_metric = um.GeneralCalibrationError(
num_bins=15,
binning_scheme='even',
class_conditional=False,
max_prob=True,
norm='l1')
ensemble_metric.update_state(test_labels, ensemble_probs)
individual_metric = um.GeneralCalibrationError(
num_bins=15,
binning_scheme='even',
class_conditional=False,
max_prob=True,
norm='l1')
for i in range(4):
    individual_metric.update_state(test_labels, probs[i])
ensemble_reliability = ensemble_metric.accuracies - ensemble_metric.confidences
individual_reliability = (
individual_metric.accuracies - individual_metric.confidences)
Now we can plot the reliability diagram, which shows calibration in more detail. The backbone model in the following figure is BatchEnsemble with ensemble size 4. The plot has 6 lines: we trained three independent BatchEnsemble models with large, small, and no Mixup, and for each model we compute the calibration of both the ensemble and the individual predictions. The plot shows that only the Mixup models have positive (Acc - Conf) values on the test set, which suggests that Mixup encourages underconfidence. The Mixup ensemble's positive value is also greater than that of the Mixup individual models. This suggests that Mixup and ensembling compound in encouraging underconfidence, leading to worse calibration than without ensembling. With the help of this library, we have thus identified why Mixup+Ensemble leads to worse calibration.
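One mechanism behind this compounding can be checked numerically: the confidence of an averaged prediction (the max of the mean probabilities) can never exceed the average confidence of the members (the mean of their max probabilities). So when accuracy does not drop, averaging can only push (Acc - Conf) upward. The sketch below illustrates this with synthetic softmax outputs; the random logits and array names are assumptions for illustration and are unrelated to the trained models discussed above.

```python
import numpy as np

rng = np.random.default_rng(0)
ensemble_size, num_examples, num_classes = 4, 1000, 10

# Synthetic member predictions: random logits pushed through a softmax.
logits = rng.normal(size=(ensemble_size, num_examples, num_classes))
exp_logits = np.exp(logits - logits.max(axis=-1, keepdims=True))
member_probs = exp_logits / exp_logits.sum(axis=-1, keepdims=True)

# Confidence is the max probability. Compare the members with their average.
member_conf = member_probs.max(axis=-1).mean(axis=0)     # mean over members
ensemble_conf = member_probs.mean(axis=0).max(axis=-1)   # averaged prediction

print("mean individual confidence:", member_conf.mean())
print("ensemble confidence:       ", ensemble_conf.mean())
```

The inequality holds per example, so the ensemble line in a reliability diagram is shifted toward lower confidence relative to the individual members.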
Uncertainty Metrics provides several types of measures of probabilistic error:
- Calibration error
- Proper scoring rules
- Information criteria
- Diversity
- AUC/Rejection
- Visualization tools
We outline each type below.
Calibration refers to a frequentist property of probabilistic predictions being correct on average. Intuitively, when predicting a binary outcome using a model, if we group all predictions where the outcome is believed to be 80 percent probable, then within this group we should, on average, see this outcome become true 80 percent of the time. A model with this property is said to be well-calibrated.
Formally, calibration can best be defined in terms of measuring the difference between a conditional predictive distribution and the true conditional distribution, where the conditioning is done with respect to a set defined as a function of the prediction.
For a prediction value $$\gamma$$, let $$\mathcal{X}_{\gamma} := \{x \in \mathcal{X} : f(x) = \gamma\}$$ be the set of inputs for which the prediction function $$f$$ outputs $$\gamma$$.
Then, the true conditional distribution $$q(y|\gamma)$$ is defined as
$$q(y|\gamma) = \frac{\mathbb{E}_{x \sim Q}[\mathbb{1}_{\{f(x)=\gamma\}} \, q(y|x)]}{Q(\mathcal{X}_{\gamma})}.$$
The model predictive conditional distribution $$p(y|\gamma)$$ is defined as
$$p(y|\gamma) = \frac{\mathbb{E}_{x \sim Q}[\mathbb{1}_{\{f(x)=\gamma\}} \, p(y|x)]}{Q(\mathcal{X}_{\gamma})}.$$
The reliability of a probabilistic prediction system is now defined as an expected difference between the two quantities,
A reliability of zero means that the model is perfectly calibrated: the predictions are on average correct. In practice the reliability needs to be estimated, because none of the expectations are available analytically. Most importantly, estimating the true conditional distribution requires discretizing the set of predictions.
We support the following calibration metrics:
- Expected Calibration Error [3]
- Root-Mean-Squared Calibration Error [14]
- Static Calibration Error [2]
- Adaptive Calibration Error / Thresholded Adaptive Calibration Error [2]
- General Calibration Error (a space of calibration metrics) [2]
- Class-conditional / Class-conflationary versions of all of the above. [2]
- Bayesian Expected Calibration Error
- Semiparametric Calibration Error
We describe examples below.
Example: Expected Calibration Error. The expected calibration error (ECE) is a scalar summary statistic between zero (perfect calibration) and one (fully miscalibrated). It is widely reported in the literature as a standard summary statistic for classification.
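ECE works by binning the probability of the decision label into a fixed number of `num_bins` bins. Typically `num_bins` is chosen to be 5 or 10. For example, a binning into 5 bins would yield a partition,
$$\{[0, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, 1.0]\},$$
and all counting and frequency estimation of probabilities is done using this binning.
For `num_bins` bins $$B_1, \dots, B_M$$ and a total of $$n$$ predictions, the ECE is
$$\textrm{ECE} = \sum_{m=1}^{M} \frac{|B_m|}{n} \left|\textrm{acc}(B_m) - \textrm{conf}(B_m)\right|,$$
where $$\textrm{acc}(B_m)$$ is the fraction of predictions in bin $$B_m$$ whose decision label is correct and $$\textrm{conf}(B_m)$$ is the average predicted probability of the decision label in that bin.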
To compute the ECE metric you can pass in the decision labels using the `labels_predicted` keyword argument. In case you do not use `labels_predicted`, the argmax label will be automatically inferred from the `logits`. (Therefore, in the following code we could remove `labels_predicted`.)
features, labels = ... # get from minibatch
probs = model(features)
ece = um.ece(labels=labels, probs=probs, num_bins=10)
Example: Bayesian Expected Calibration Error. ECE is a scalar summary statistic of miscalibration evaluated on a finite sample of validation data. Because the result is a single scalar, the sampling variation due to the limited amount of validation data is hidden. This can lead to significant over- or under-estimation of the ECE, and to wrongly concluding that the ECE differs significantly between models.
To address these issues, a Bayesian estimator of the ECE can be used. The resulting estimate is not a single scalar summary but instead a probability distribution over possible ECE values.
The generative model is given by the following generative mechanism.
The distribution
The per-bin means are generated by a truncated Normal distribution,
For inference, we are given a list of
- $$\hat{y}_i$$ is the decision label of the prediction system, typically the $$\textrm{argmax}$$ over the predicted probabilities;
- $$y_i$$ is the true observed label; and
- $$p_i$$ is the probability $$p(\hat{y}_i|x_i)$$ provided by the prediction system.
Each probability
The posterior
The first factor
Given the posterior
Using the probability distribution of ECE values we can make statistical conclusions. For example, we can assess whether the ECE really is significantly different between two or more models by reasoning about the overlap of the ECE distributions of each model, using for example a Wilcoxon signed rank test.
The Bayesian ECE can be used like the normal ECE, as in the following code:
import tensorflow_probability as tfp

# labels_true is a tf.int32 Tensor
logits = model(validation_data)
ece_samples = um.bayesian_expected_calibration_error(
    10, logits=logits, labels_true=labels_true)
# Summarize the distribution of ECE values by its 10%/50%/90% quantiles.
ece_quantiles = tfp.stats.percentile(ece_samples, [10, 50, 90])
The above code also includes an example of using the samples to infer 10%/50%/90% quantiles of the distribution of possible ECE values.
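Once ECE samples are available for more than one model, their overlap can be summarized with ordinary array operations. The NumPy sketch below is only an illustration under assumptions: the two Beta-distributed arrays are synthetic stand-ins for real `ece_samples`, the names `ece_samples_a` and `ece_samples_b` are hypothetical, and the two sets of samples are treated as independent.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder draws standing in for the ECE samples of two models.
ece_samples_a = rng.beta(2.0, 60.0, size=5000)
ece_samples_b = rng.beta(3.0, 40.0, size=5000)

# 10%/50%/90% quantiles of each distribution of possible ECE values.
quantiles_a = np.percentile(ece_samples_a, [10, 50, 90])
quantiles_b = np.percentile(ece_samples_b, [10, 50, 90])

# Fraction of (independently paired) draws in which model A has the lower ECE.
prob_a_better = np.mean(ece_samples_a < ece_samples_b)

print("model A quantiles:", quantiles_a)
print("model B quantiles:", quantiles_b)
print("P(ECE_A < ECE_B) ~", prob_a_better)
```

A value of `prob_a_better` close to 0.5 indicates heavily overlapping distributions, in which case the data do not support a claim that one model is better calibrated.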
Proper scoring rules are loss functions for probabilistic predictions.
Formally, a proper scoring rule is a function which assigns a numerical score
Example: Brier score.
The Brier score for discrete labels is defined as follows: given a predicted
probability vector
Here is an example of how to use the Brier score as loss function for a classifier. Suppose you have a classifier implemented in Tensorflow and your current training code looks like
per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=target_labels, logits=logits)
loss = tf.reduce_mean(per_example_loss)
Then you can alternatively use the API-compatible Brier loss as follows:
per_example_loss = um.brier_score(labels=target_labels, logits=logits)
loss = tf.reduce_mean(per_example_loss)
Compared with the cross-entropy loss, the Brier score assigns a smaller penalty when an outcome that was predicted with low probability does occur.
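The difference in penalties can be seen on a toy example. The sketch below uses the textbook form of the Brier score, the squared distance between the predicted probability vector and the one-hot label; this is an assumption for illustration, and the library's own convention may differ, for example by a constant offset. The helper names and the toy probabilities are made up here.

```python
import numpy as np


def brier_score(labels, probs):
    """Squared distance between the probability vector and the one-hot label."""
    one_hot = np.eye(probs.shape[1])[labels]
    return np.sum((probs - one_hot) ** 2, axis=1)


def cross_entropy(labels, probs):
    """Negative log-probability of the observed label."""
    return -np.log(probs[np.arange(len(labels)), labels])


labels = np.array([0, 0, 0])
# Probability assigned to the observed class shrinks: 0.5, 0.1, 0.001.
probs = np.array([[0.5, 0.5], [0.1, 0.9], [0.001, 0.999]])

print("Brier:        ", brier_score(labels, probs))    # stays below 2
print("cross-entropy:", cross_entropy(labels, probs))  # keeps growing
```

In this form the Brier score is bounded, whereas the cross-entropy grows without bound as the probability of the observed class approaches zero.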
Example: Brier score's decomposition. Here is an example of how to compute calibration metrics for a classifier. Suppose you evaluate the accuracy of your classifier on the validation set,
logits = model(validation_data)
class_prediction = tf.argmax(logits, 1)
accuracy = tf.metrics.accuracy(validation_labels, class_prediction)
Using the so-called Brier decomposition, you can compute additional metrics that quantify prediction uncertainty, resolution, and reliability by appending the following code,
uncert, resol, reliab = um.brier_decomposition(
    labels=validation_labels, logits=logits)
In particular, the reliability (`reliab` in the above line) is a measure of calibration error, with a value between 0 and 2, where zero means perfect calibration.
Example: Continuous Ranked Probability Score (CRPS). The continuous ranked probability score (CRPS) has several equivalent definitions.
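Definition 1: CRPS measures the integrated squared difference between an arbitrary cumulative distribution function $$F$$ and a step function at the realization $$y$$,
$$S_{\textrm{CRPS}}(F,y) = \int_{-\infty}^{\infty} \left(F(z) - \mathbb{1}_{\{z \geq y\}}\right)^2 \, \textrm{d}z.$$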
Definition 2: CRPS measures the expected distance to the realization minus one half the expected distance between samples,
$$S_{\textrm{CRPS}}(F,y) = \mathbb{E}_{z \sim F}[|z-y|] - \frac{1}{2} \mathbb{E}_{z,z' \sim F}[|z-z'|].$$
CRPS has two desirable properties:
- It generalizes the absolute error loss and recovers the absolute error if a predicted distribution $$F$$ is deterministic.
- It is reported in the same units as the predicted quantity.
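Definition 2 translates directly into a sample-based estimate: replace both expectations by averages over draws from the predictive distribution. The NumPy sketch below is a simple plug-in version under assumptions: the function name `crps_from_samples` is hypothetical, and the pairwise term averages over all pairs including identical ones, which may differ from the estimator the library uses.

```python
import numpy as np


def crps_from_samples(y, samples):
    """Plug-in estimate of E|z - y| - 0.5 * E|z - z'| from predictive samples."""
    samples = np.asarray(samples, dtype=float)
    distance_to_target = np.mean(np.abs(samples - y))
    # Average absolute difference over all pairs of samples.
    spread = np.mean(np.abs(samples[:, None] - samples[None, :]))
    return distance_to_target - 0.5 * spread


# A deterministic prediction (all samples equal) recovers the absolute error.
print(crps_from_samples(3.0, np.full(100, 1.5)))  # 1.5

# A spread-out prediction centered on the target.
rng = np.random.default_rng(0)
print(crps_from_samples(0.0, rng.normal(size=2000)))
```

The first call shows the first property in the list above: with no predictive spread, CRPS equals the absolute error |3.0 - 1.5|.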
To compute CRPS we either need to make an assumption regarding the form of $$F$$ or approximate the expectations over $$F$$ using samples from the predictive distribution.
For a regression model which predicts Normal distributions with mean `means` and standard deviation `stddevs`, we can compute CRPS as follows:
squared_errors = tf.square(target_labels - pred_means)
per_example_crps = um.crps_normal_score(
    labels=target_labels,
    means=pred_means,
    stddevs=pred_stddevs)
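For reference, the CRPS of a Normal forecast has a well-known closed form (see Gneiting and Raftery [18]). The sketch below evaluates it with the standard library only; the function name `crps_normal` and its scalar interface are assumptions for illustration and are not the library's `crps_normal_score`.

```python
import math


def crps_normal(y, mean, stddev):
    """Closed-form CRPS of a Normal(mean, stddev) forecast at observation y."""
    z = (y - mean) / stddev
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)   # standard normal pdf
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))          # standard normal cdf
    return stddev * (z * (2.0 * cdf - 1.0) + 2.0 * pdf - 1.0 / math.sqrt(math.pi))


print(crps_normal(y=0.0, mean=0.0, stddev=1.0))   # about 0.2337
print(crps_normal(y=3.0, mean=0.0, stddev=1e-6))  # close to 3, the absolute error
```

The second call shows that as the predicted standard deviation shrinks toward zero, the score approaches the absolute error of the mean prediction.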
For non-Normal models, as long as we can sample predictions, we can construct a Tensor `predictive_samples` of size `(ninstances, npredictive_samples)` and evaluate the Monte Carlo CRPS against the true targets `target_labels` using the following code,
per_example_crps = um.crps_score(
labels=target_labels,
predictive_samples=predictive_samples)
Information criteria are used after or during model training to estimate the predictive performance on future holdout data. They can be useful for selecting among multiple possible models or to perform hyperparameter optimization. There are also strong connections between cross validation estimates and some information criteria.
We estimate information criteria using log-likelihood values on training samples. In particular, for both the WAIC and the ISCV criteria we assume that we have an ensemble of models with equal weights, such that the average predictive distribution over the ensemble is a good approximation to the true Bayesian posterior predictive distribution.
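For a set of $$n$$ training instances and $$m$$ ensemble members, we use a log-likelihood matrix of shape $$(n, m)$$ whose `[i,j]` element contains the log-likelihood of the $$i$$'th training instance under the $$j$$'th ensemble element. For example, the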
If the
Example: Mutual information.
Mutual information is a way to measure the spread or disagreement of an ensemble.
Formally, given an ensemble model $$\{p(y|x^*,\theta^{(m)})\}_{m=1}^M$$ trained on a finite dataset $$D$$, model uncertainty for a test input $$x^*$$ is defined as:
$$ MI[y, \theta|x^*, D] = \mathcal{H}(\mathbb{E}_{p(\theta|D)}[p(y|x^*,\theta)]) - \mathbb{E}_{p(\theta|D)}[\mathcal{H}(p(y|x^*, \theta))] $$
The total uncertainty will be high whenever the model is uncertain - both in regions of severe class overlap and out-of-domain. However, model uncertainty, the difference between total and expected data uncertainty, will be non-zero iff the ensemble disagrees.
Knowledge uncertainty estimates the mutual information between the categorical output $$y$$ and the categorical distribution $$\pi$$:
$$ MI[y, \pi|x^*, \hat{\theta}] = \mathcal{H}(\mathbb{E}_{p(\pi|x^*, \hat{\theta})}[p(y|\pi)]) - \mathbb{E}_{p(\pi|x^*,\hat{\theta})}[\mathcal{H}(p(y|\pi))] $$
Model uncertainty measures the mutual information between the categorical output $$y$$ and the model parameters $$\theta$$.
logits = model(validation_data)
model_uncert, total_uncert, avg_data_uncert = um.model_uncertainty(logits)
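The decomposition above can be sketched in NumPy for an ensemble with equal weights: total uncertainty is the entropy of the averaged prediction, expected data uncertainty is the average entropy of the members, and their difference is the mutual information. This is an illustrative sketch; the function names and the assumed input shape `[ensemble_size, num_examples, num_classes]` of probabilities (rather than logits) are assumptions, not the library's interface.

```python
import numpy as np


def entropy(probs, axis=-1):
    """Shannon entropy in nats; 0 * log(0) is treated as 0."""
    safe = np.where(probs > 0, probs, 1.0)
    return -np.sum(probs * np.log(safe), axis=axis)


def uncertainty_decomposition(member_probs):
    """member_probs has shape [ensemble_size, num_examples, num_classes]."""
    total = entropy(member_probs.mean(axis=0))   # entropy of the averaged prediction
    data = entropy(member_probs).mean(axis=0)    # average entropy of the members
    return total - data, total, data             # model, total, data uncertainty


# Two members, one example, two classes.
agree = np.array([[[0.5, 0.5]], [[0.5, 0.5]]])      # members agree, both unsure
disagree = np.array([[[1.0, 0.0]], [[0.0, 1.0]]])   # members confident, opposite

mi_agree, total_agree, _ = uncertainty_decomposition(agree)
mi_disagree, total_disagree, _ = uncertainty_decomposition(disagree)
print(mi_agree, total_agree)        # zero model uncertainty, total is log 2
print(mi_disagree, total_disagree)  # model uncertainty equals total, log 2
```

Both toy cases have the same total uncertainty, but only the second has non-zero model uncertainty, because only there do the ensemble members disagree.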
Example: Widely/Watanabe Applicable Information Criterion (WAIC). The negative WAIC criterion estimates log-likelihood on future observables. It is given as
$$\textrm{nWAIC}_1 := \frac{1}{n} \sum_{i=1}^n \left( \log \frac{1}{m} \sum_{j=1}^m p(y_i|x_i,\theta_j) - \hat{\mathbb{V}}_i \right),$$
where $$\hat{\mathbb{V}}_i := \frac{1}{m-1} \sum_{j=1}^m \left( \log p(y_i|x_i,\theta_j) - \frac{1}{m} \sum_{k=1}^m \log p(y_i|x_i,\theta_k)\right)^2$$.
An alternative form of the WAIC, called type-2 WAIC, is computed as
$$\textrm{nWAIC}_2 := \frac{1}{n} \sum_{i=1}^n \left( \frac{2}{m} \sum_{j=1}^m \log p(y_i|x_i,\theta_j) - \log \frac{1}{m} \sum_{j=1}^m p(y_i|x_i,\theta_j) \right).$$
Both nWAIC criteria have comparable properties, but Watanabe recommends
# logp has shape (n,m), n instances, m ensemble members
neg_waic, neg_waic_sem = um.negative_waic(logp, waic_type="waic1")
You can select the type of nWAIC to estimate using `waic_type="waic1"` or `waic_type="waic2"`. The method returns the scalar estimate as well as the standard error of the mean of the estimate.
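The two nWAIC formulas can be written down directly for an `(n, m)` matrix of log-likelihoods. The NumPy sketch below follows those formulas term by term; the function name `negative_waic_sketch`, the stable `log_mean_exp` helper, and the way the standard error is computed (sample standard deviation of the per-instance terms divided by the square root of n) are assumptions for illustration and may differ from the library's implementation.

```python
import numpy as np


def log_mean_exp(logp, axis):
    """Numerically stable log of the mean of exp(logp) along an axis."""
    max_logp = logp.max(axis=axis, keepdims=True)
    return np.squeeze(max_logp, axis=axis) + np.log(
        np.mean(np.exp(logp - max_logp), axis=axis))


def negative_waic_sketch(logp, waic_type="waic1"):
    """logp has shape (n, m): n instances, m ensemble members."""
    log_pred = log_mean_exp(logp, axis=1)   # log (1/m) sum_j p(y_i|x_i,theta_j)
    if waic_type == "waic1":
        # Subtract the sample variance of the log-likelihoods over members.
        per_instance = log_pred - logp.var(axis=1, ddof=1)
    elif waic_type == "waic2":
        per_instance = 2.0 * logp.mean(axis=1) - log_pred
    else:
        raise ValueError(f"unknown waic_type: {waic_type}")
    n = logp.shape[0]
    return per_instance.mean(), per_instance.std(ddof=1) / np.sqrt(n)


# Synthetic log-likelihoods for 100 instances and 8 ensemble members.
rng = np.random.default_rng(0)
logp = np.log(rng.uniform(0.2, 0.9, size=(100, 8)))
estimate, sem = negative_waic_sketch(logp, waic_type="waic1")
print(estimate, sem)
```

When all ensemble members assign identical likelihoods, the variance term vanishes and both types reduce to the average log-likelihood.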
Example: Importance Sampling Cross Validation Criterion (ISCV).
Like the negative WAIC, the ISCV criterion estimates the holdout log-likelihood on future observables using the training data
We can estimate the ISCV using the following code:
# logp has shape (n,m), n instances, m ensemble members
iscv, iscv_sem = um.importance_sampling_cross_validation(logp)
- Add the paper reference to the `References` section below.
- Add the metric definition to the numpy/ dir for a numpy based metric or to the tensorflow/ dir for a tensorflow based metric.
- Add the metric class or function to the corresponding `__init__.py` file.
- Add a test that at a minimum implements the metric using `import uncertainty_metrics as um` and `um.your_metric` and checks that the value is in the appropriate range.
[1] Guo, C., Pleiss, G., Sun, Y., & Weinberger, K. Q. (2017, August). On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning. Paper Link.
[2] Nixon, J., Dusenberry, M. W., Zhang, L., Jerfel, G., & Tran, D. (2019). Measuring Calibration in Deep Learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops (pp. 38-41). Paper Link.
[3] Naeini, Mahdi Pakdaman, Gregory Cooper, and Milos Hauskrecht. "Obtaining well calibrated probabilities using bayesian binning." Twenty-Ninth AAAI Conference on Artificial Intelligence. 2015. Paper Link.
[4] Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht. "Binary classifier calibration using a Bayesian non-parametric approach." Proceedings of the 2015 SIAM International Conference on Data Mining. Society for Industrial and Applied Mathematics, 2015. Paper Link.
[5] J. Platt. Probabilistic outputs for support vector machines and comparisons to regularized likelihood methods. Advances in Large Margin Classifiers, 10(3):61–74, 1999. Paper Link.
[6] Kumar, A., Liang, P. S., & Ma, T. (2019). Verified uncertainty calibration. In Advances in Neural Information Processing Systems (pp. 3787-3798). Paper Link.
[7] Kumar, A., Sarawagi, S., & Jain, U. (2018, July). Trainable calibration measures for neural networks from kernel mean embeddings. In International Conference on Machine Learning (pp. 2805-2814). Paper Link.
[8] Calibrating Neural Networks Documentation. Link.
[9] Zadrozny, Bianca, and Charles Elkan. "Transforming classifier scores into accurate multiclass probability estimates." Proceedings of the eighth ACM SIGKDD international conference on Knowledge discovery and data mining. 2002. Paper Link.
[10] Müller, Rafael, Simon Kornblith, and Geoffrey E. Hinton. "When does label smoothing help?." Advances in Neural Information Processing Systems. 2019. Paper Link.
[11] Pereyra, Gabriel, et al. "Regularizing neural networks by penalizing confident output distributions." arXiv preprint arXiv:1701.06548 (2017). Paper Link.
[12] Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In NIPS, pp. 6405–6416. 2017. Paper Link.
[13] Louizos, C. and Welling, M. Multiplicative normalizing flows for variational Bayesian neural networks. In ICML, volume 70, pp. 2218–2227, 2017. Paper Link.
[14] Hendrycks, D., Mu, N., Cubuk, E. D., Zoph, B., Gilmer, J., & Lakshminarayanan, B. (2019). AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty. Paper Link.
[15] Jochen Brocker, "Reliability, sufficiency, and the decomposition of proper scores", Quarterly Journal of the Royal Meteorological Society, 2009. (PDF)
[16] Stefan Depeweg, José Miguel Hernández-Lobato, Finale Doshi-Velez, and Steffen Udluft, "Decomposition of uncertainty for active learning and reliable reinforcement learning in stochastic systems", stat 1050, 2017. (PDF)
[17] Alan E. Gelfand, Dipak K. Dey, and Hong Chang. "Model determination using predictive distributions with implementation via sampling-based methods", Technical report No. 462, Department of Statistics, Stanford University, 1992. [(PDF)](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.860.3702&rep=rep1&type=pdf)
[18] Tilmann Gneiting and Adrian E. Raftery, "Strictly Proper Scoring Rules, Prediction, and Estimation", Journal of the American Statistical Association (JASA), 2007. [(PDF)](https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf)
[19] Andrey Malinin, Bruno Mlodozeniec and Mark Gales, "Ensemble Distribution Distillation.", arXiv:1905.00076, 2019. (PDF)
[20] Aki Vehtari, Andrew Gelman, and Jonah Gabry. "Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC", arXiv:1507.04544, (PDF)
[21] Sumio Watanabe, "Mathematical Theory of Bayesian Statistics", CRC Press, 2018.