Bayesian Inference for Large Scale Image Classification

by Jonathan Heek, et al.

Bayesian inference promises to ground and improve the performance of deep neural networks. It promises to be robust to overfitting, to simplify the training procedure and the space of hyperparameters, and to provide a calibrated measure of uncertainty that can enhance decision making, agent exploration and prediction fairness. Markov Chain Monte Carlo (MCMC) methods enable Bayesian inference by generating samples from the posterior distribution over model parameters. Despite the theoretical advantages of Bayesian inference and the similarity between MCMC and optimization methods, the performance of sampling methods has so far lagged behind optimization methods for large scale deep learning tasks. We aim to fill this gap and introduce ATMC, an adaptive noise MCMC algorithm that estimates and is able to sample from the posterior of a neural network. ATMC dynamically adjusts the amount of momentum and noise applied to each parameter update in order to compensate for the use of stochastic gradients. We use a ResNet architecture without batch normalization to test ATMC on the Cifar10 benchmark and the large scale ImageNet benchmark and show that, despite the absence of batch normalization, ATMC outperforms a strong optimization baseline in terms of both classification accuracy and test log-likelihood. We show that ATMC is intrinsically robust to overfitting on the training data and that ATMC provides a better calibrated measure of uncertainty compared to the optimization baseline.


1 Introduction

In contrast to optimization approaches in machine learning that derive a single estimate for the weights of a neural network, Bayesian inference aims at deriving a posterior distribution over the weights of the network. This makes it possible to sample model instances from the distribution over the weights and offers unique advantages. Multiple model instances can be aggregated to obtain robust uncertainty estimates over the network's predictions; uncertainty estimates are crucial in domains such as medical diagnosis and autonomous driving, where following a model's incorrect predictions can result in catastrophe kendall2017uncertainties. Sampling from a distribution, as opposed to optimizing a loss, is less prone to overfitting, and additional training does not degrade test performance. Bayesian inference can also be applied to differential privacy, where each individual sample has increased privacy guarantees privacy, and to reinforcement learning, where one can leverage model uncertainty to balance exploration and exploitation thompson-sampling.

Markov Chain Monte Carlo (MCMC) methods are a standard class of methods for generating samples from the posterior distribution over model parameters. These methods are seldom applied in deep learning because they have traditionally failed to scale well with large datasets and many parameters big-mcmc. Stochastic Gradient MCMC (SG-MCMC) methods have fared somewhat better in scaling to large datasets due to their close relationship to stochastic optimization methods. For example, the SGLD sampler sgld amounts to performing stochastic gradient descent while adding Gaussian noise to each parameter update. Despite these improvements, samplers like SGLD are only guaranteed to converge to the correct distribution when the step size is annealed to zero; additional control variates have been developed to mitigate this to some extent sgfs ; bayes-thermostat.
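The SGLD update mentioned above can be illustrated with a minimal sketch. It assumes the commonly used form of the update (half a gradient step plus Gaussian noise whose variance equals the step size) and a toy one-dimensional target with energy theta^2 / 2; the function names, the noisy-gradient stand-in and all constants are hypothetical choices for illustration, not taken from the paper. Because the step size is held fixed rather than annealed, the sample variance carries a small bias.

```python
import math
import random

def sgld_step(theta, grad, step_size, rng):
    """One SGLD update: half a gradient step plus Gaussian noise of variance step_size."""
    noise = rng.gauss(0.0, math.sqrt(step_size))
    return theta - 0.5 * step_size * grad + noise

def noisy_grad(theta, rng, grad_noise_std=0.5):
    # Hypothetical stand-in for a mini-batch gradient of U(theta) = theta^2 / 2,
    # whose exact gradient is theta; the added noise mimics mini-batching.
    return theta + rng.gauss(0.0, grad_noise_std)

def run_sgld(num_steps=200000, step_size=0.05, burn_in=50000, seed=0):
    rng = random.Random(seed)
    theta, samples = 3.0, []
    for step in range(num_steps):
        theta = sgld_step(theta, noisy_grad(theta, rng), step_size, rng)
        if step >= burn_in:
            samples.append(theta)
    return samples

samples = run_sgld()
sample_mean = sum(samples) / len(samples)
sample_var = sum((s - sample_mean) ** 2 for s in samples) / len(samples)
print(sample_mean, sample_var)  # should be close to 0 and 1 for this toy target
```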

The objective of this work is to make Bayesian inference practical for deep learning by making SG-MCMC methods scale to large models and datasets. The contributions described in this work fall into three categories. First, we propose the Adaptive Thermostat Monte Carlo (ATMC) sampler that offers improved convergence and stability. ATMC dynamically adjusts the amount of momentum and noise applied to each model parameter. Second, we improve an existing second order numerical integration method that is needed for the ATMC sampler. Third, since ATMC, like other SG-MCMC samplers, is not directly compatible with stochastic regularization methods such as batch normalization (BatchNorm) and Dropout (see Sect. 4), we construct the ResNet++ network by taking the original ResNet architecture resnet , removing BatchNorm and introducing SELUs selu , Fixup initialization fixup and weight normalization weight-norm . We design ResNet++ so that its parameters are easy to sample from and the gradients are well-behaved even in the absence of BatchNorm.

We show that the ATMC sampler is able to outperform optimization methods in terms of accuracy, log-likelihood and uncertainty calibration in the following settings. First, when using the ResNet++ architecture for both the ATMC sampler and the optimization baseline, the ATMC sampler significantly outperforms the optimization baseline on both Cifar-10 and ImageNet. Secondly, when using the standard ResNet for the optimization baseline and the ResNet++ for the ATMC sampler, multiple samples of the ATMC that approximate the predictive posterior of the model are still able to outperform the optimization baseline on ImageNet. Using the ResNet++ architecture, the ATMC sampler reduces the need for hyper-parameter tuning since it does not require early stopping, does not use stochastic regularization, is not prone to over-fitting on the training data and avoids a carefully tuned learning rate decay schedule.

2 ATMC Sampler

Algorithm 1: The ATMC sampler. The algorithm accepts the initialized model parameters, step size, pre-conditioner, and momentum noise.

In this section we define the Stochastic Differential Equation (SDE) that gives rise to the ATMC sampler described in Algorithm 1. A detailed background and framework for constructing SDEs that converge to a target distribution can be found in mcmc-recipe .

2.1 General form of the SDE

Our starting point for constructing the ATMC sampler is the framework of Stochastic Differential Equations. We are interested in SDEs that converge to a distribution $p(z)$ for which we can evaluate $\nabla \log p(z)$. Because only the gradient of $\log p(z)$ is required, it is sufficient to define an energy function $H(z) = -\log p(z) + C$ up to a constant $C$. As a consequence, we can sample from the posterior distribution by only evaluating the energy function gradient $\nabla H(z)$. The general form of SDEs converging to $p(z)$ for which only the gradient of $H(z)$ is required is as follows mcmc-recipe :

$$\mathrm{d}z = -\left[D(z) + Q(z)\right] \nabla H(z)\,\mathrm{d}t + \Gamma(z)\,\mathrm{d}t + \sqrt{2 D(z)}\,\mathrm{d}W_t, \qquad \Gamma_i(z) = \sum_j \frac{\partial}{\partial z_j}\left(D_{ij}(z) + Q_{ij}(z)\right) \tag{1}$$

where $D(z)$ is a positive semi-definite matrix that determines the amount of noise, $Q(z)$ is a skew-symmetric matrix that mixes energy between variables, and $\Gamma(z)$ is a correction factor that compensates for dynamics that depend on the current state $z$. The ATMC sampler that we propose is an instance of (1) for specific definitions of $H(z)$, $D(z)$, and $Q(z)$.
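To make the roles of the noise matrix and the skew-symmetric matrix in (1) concrete, the sketch below simulates the general form with a plain Euler-Maruyama scheme. It assumes constant matrices, here called D and Q, so that the correction term vanishes, and uses the state z = (theta, p) with energy theta^2 / 2 + p^2 / 2, which gives ordinary underdamped Langevin dynamics. This is an illustrative instance of the framework under those assumptions, not the ATMC dynamics; all names and constants are hypothetical.

```python
import math
import random

def recipe_drift(D, Q, grad_H):
    """Drift -(D + Q) grad H(z) for constant D and Q, in which case Gamma(z) = 0."""
    n = len(grad_H)
    return [-sum((D[i][j] + Q[i][j]) * grad_H[j] for j in range(n)) for i in range(n)]

def euler_maruyama_step(z, grad_H_fn, D, Q, h, rng):
    """One Euler-Maruyama step; D is assumed diagonal so sqrt(2 D) acts elementwise."""
    drift = recipe_drift(D, Q, grad_H_fn(z))
    return [
        z[i] + h * drift[i] + math.sqrt(2.0 * D[i][i] * h) * rng.gauss(0.0, 1.0)
        for i in range(len(z))
    ]

def grad_H(z):
    # H(theta, p) = theta^2 / 2 + p^2 / 2, so the gradient is (theta, p).
    return [z[0], z[1]]

friction = 1.0
D = [[0.0, 0.0], [0.0, friction]]   # noise is injected into the momentum only
Q = [[0.0, -1.0], [1.0, 0.0]]       # skew-symmetric coupling of theta and p

rng = random.Random(0)
z, thetas = [2.0, 0.0], []
for step in range(100000):
    z = euler_maruyama_step(z, grad_H, D, Q, 0.02, rng)
    if step >= 10000:
        thetas.append(z[0])
var_theta = sum(t * t for t in thetas) / len(thetas)
print(var_theta)  # close to 1 for this standard normal target
```

With these choices the drift reduces to (p, -theta - friction * p), which is the familiar Langevin form with friction on the momentum.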

2.2 Energy Function

We start by defining the energy function . The energy function for the model posterior is defined by the loss function . Because the dataset is generally large, we would like to only evaluate a mini-batch loss . However, naively using a stochastic gradient in (1) will result in significant bias sghmc . Motivated by the Central Limit Theorem, the stochastic gradient is assumed to follow a Gaussian distribution where the covariance is additionally assumed to be diagonal and constant w.r.t. . The energy function for the ATMC sampler is defined as:


where is the momentum, defines the momentum distribution, and is a control variate referred to as the temperature. Both and have the same dimensionality as . The distribution of the control variate depends on the amount of noise in the stochastic gradient estimate .

2.3 Noise robust dynamics

Next we define the dynamics and such that the SDE that results from (1) can be simulated without the need to evaluate :


where is a non-negative function that determines how the temperature affects the amount of noise added to the momentum update.

We first illustrate the resulting SDE by using a simpler Gaussian momentum distribution (where is a hyper-parameter). We substitute the dynamics and defined in (3) and energy function defined in (2) into (1):


where we use to replace the gradient of the loss with the mini-batch estimate. The momentum is dampened by a friction term that depends on the choice of . The stochastic gradient noise does not show up in (4) due to the particular choice of energy function and dynamics . Note, however, that this analysis relies on the assumption that the covariance of the stochastic gradient noise is constant in and that a single temperature variable per parameter can only correct for a diagonal covariance . We do not expect this assumption to hold in practice, and the approximation will therefore lead to bias in the samples. However, annealing the step size will reduce the error due to mini-batching together with other sources of discretization error sgld .

2.4 Adaptive Noise Thermostat

Finally, we must choose a function which controls the amount of noise and momentum damping . Previous work uses the Nosé-Hoover thermostat that is defined by where is a constant determining the amount of noise added to the momentum update bayes-thermostat . Although the Nosé-Hoover thermostat is able to correct the stochastic gradient noise , the correction comes at the cost of slower convergence because additional friction is applied as increases. Another drawback of the Nosé-Hoover thermostat is that it causes negative friction when . In the negative friction phase , previous gradient terms are amplified rather than dampened. Although this behavior is mathematically sound we find that it can cause exploding momentum variables.
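For reference, the behaviour of a Nosé-Hoover thermostat under noisy gradients can be sketched as follows. The code assumes the widely used stochastic gradient Nosé-Hoover update with a scalar thermostat variable, unit mass and a constant injected noise level; it is not the ATMC update, and the variable names (xi for the thermostat, noise_a for the injected noise) and the toy target are hypothetical. The counter at the end records how often the thermostat variable is negative, which is the negative friction phase discussed above.

```python
import math
import random

def sgnht_step(theta, p, xi, grad, h, noise_a, rng):
    """One Euler step of a scalar Nose-Hoover thermostat sampler (unit mass)."""
    p = p - h * xi * p - h * grad + math.sqrt(2.0 * noise_a * h) * rng.gauss(0.0, 1.0)
    theta = theta + h * p
    xi = xi + h * (p * p - 1.0)  # thermostat drives the kinetic energy towards 1
    return theta, p, xi

def noisy_grad(theta, rng, grad_noise_std=5.0):
    # Hypothetical noisy gradient of U(theta) = theta^2 / 2.
    return theta + rng.gauss(0.0, grad_noise_std)

rng = random.Random(0)
theta, p, xi = 0.0, 0.0, 1.0
h, noise_a = 0.01, 1.0
num_steps, burn_in = 300000, 30000
thetas, negative_friction_steps = [], 0
for step in range(num_steps):
    theta, p, xi = sgnht_step(theta, p, xi, noisy_grad(theta, rng), h, noise_a, rng)
    if step >= burn_in:
        thetas.append(theta)
        if xi < 0.0:
            negative_friction_steps += 1

var_theta = sum(t * t for t in thetas) / len(thetas)
negative_fraction = negative_friction_steps / len(thetas)
print(var_theta, negative_fraction)
```

In this toy setting the thermostat absorbs the extra gradient noise, so the variance of theta stays near 1, while a noticeable fraction of steps run with negative friction.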

Our choice of is based on the idea that negative friction should not occur and convergence speed should not be reduced by the stochastic gradient noise. Based on this intuition, we define the ATMC sampler by . The ATMC sampler is best characterized by the various temperature stages. For the total amount of noise added to the momentum is and the friction coefficient . At this stage, the stochastic gradient noise is compensated for by adding less noise to the momentum update. If the dominant stage will be resulting in and zero noise being added to the momentum. Finally, when the friction coefficient and the noise added to the momentum is proportional to . Thus, the momentum always experiences a minimum amount of friction determined by the hyper-parameter and the noise added to the momentum update is automatically adjusted based on the amount of noise present in the stochastic gradients.

2.5 Momentum energy function

Following rhmc , we generalize the momentum energy function to the symmetric hyperbolic distribution which is defined as follows rhmc :


where and are hyper-parameters. The Gaussian kinetic energy is a special case obtained by taking the limit . The magnitude of parameter updates is determined by the gradient of the momentum:


Hence, the hyperbolic distribution results in relativistic momentum dynamics where the parameter updates are upper bounded by and the pre-conditioner depends on . The average update magnitude for . Consequently, the parameters and are interpretable hyper-parameters controlling the average and maximum parameter update per step together with the step size .
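The bounded update can be checked numerically. The sketch below assumes the relativistic kinetic energy in the form K(p) = m c^2 sqrt(|p|^2 / (m^2 c^2) + 1), as commonly used in relativistic Monte Carlo; the symbols m (a mass-like scale) and c (a maximum speed) are names chosen here and may not match the paper's parameterization. The gradient of this energy with respect to the momentum is the parameter update direction, and its norm stays below c however large the momentum becomes.

```python
import math

def relativistic_velocity(p, m, c):
    """Gradient of K(p) = m c^2 sqrt(|p|^2 / (m^2 c^2) + 1) with respect to p."""
    norm_sq = sum(pi * pi for pi in p)
    scale = m * math.sqrt(norm_sq / (m * m * c * c) + 1.0)
    return [pi / scale for pi in p]

def speed(v):
    return math.sqrt(sum(vi * vi for vi in v))

# Small momentum: the update is approximately p / m, as with a Gaussian momentum.
slow = relativistic_velocity([1e-3, 0.0], m=2.0, c=1.0)
# Huge momentum: the update magnitude saturates just below c.
fast = relativistic_velocity([1e6, 0.0], m=1.0, c=2.0)
print(slow[0], speed(fast))
```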

The SDE we derive in (4) and integrate in Sec. 3 uses a Gaussian momentum energy function for clarity. Deriving ATMC with a different momentum distribution like the hyperbolic distribution amounts to substituting (2), (3), and the alternative momentum distribution into (1). For the hyperbolic distribution, the dynamic friction coefficient will also depend on . For the numerical integration of (4) with a hyperbolic momentum distribution we assume to be constant in .

3 Improved numerical integrator for MCMC samplers

In this section we construct the numerical integrator required to numerically approximate the ATMC sampler defined in (4). An efficient numerical integrator can be constructed by splitting the SDE into two terms:


Hence, we obtain a linear ODE in part (A) that updates the parameters and the thermostats and a linear SDE in part (B) that updates the momentum . The operators that simulate these dynamics exactly for a time step are denoted and , respectively. Using the Strang splitting scheme yields a second order method split-integrator :


The first operator is given by


The second operator is an instance of the Ornstein–Uhlenbeck process which can also be computed exactly as follows:


Previous work split-integrator on higher order integrators for samplers splits the SDE into three parts where the third term is obtained from separating the friction term from the other terms in the momentum update . By integrating (10) exactly the gradient step and the noise and gradient term are directly affected by the friction. An exact momentum update provides additional robustness to large gradients because the temperature will increase in order to compensate for momentum updates that would lead to excessively large steps. Another advantage of a two-way split integrator is that the first and last steps in (8) can be fused together such that only a momentum update is performed per iteration. Algorithm 1 shows the pseudocode for the ATMC sampler with the split integrator defined in (9) and (10).
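A two-way split of this kind can be sketched for a simplified case. The code below assumes a constant friction and noise level, unit mass and no thermostat variable, so it illustrates the general technique (a position step combined with an exact Ornstein-Uhlenbeck momentum step in a Strang composition) rather than the operators in (9) and (10). The ordering of the half steps and all names are assumptions made for illustration.

```python
import math
import random

def ou_exact_step(p, grad, friction, noise, h, rng):
    """Exact solution over time h of dp = -(grad + friction * p) dt + sqrt(2 * noise) dW
    with grad held fixed (an Ornstein-Uhlenbeck process with constant forcing)."""
    decay = math.exp(-friction * h)
    mean = decay * p - grad * (1.0 - decay) / friction
    std = math.sqrt(noise * (1.0 - decay * decay) / friction)
    return mean + std * rng.gauss(0.0, 1.0)

def strang_step(theta, p, grad_fn, friction, noise, h, rng):
    """Half position step, full exact momentum step, half position step."""
    theta = theta + 0.5 * h * p
    p = ou_exact_step(p, grad_fn(theta), friction, noise, h, rng)
    theta = theta + 0.5 * h * p
    return theta, p

def grad_fn(theta):
    return theta  # gradient of the toy energy U(theta) = theta^2 / 2

rng = random.Random(0)
theta, p, thetas = 2.0, 0.0, []
for step in range(100000):
    theta, p = strang_step(theta, p, grad_fn, 1.0, 1.0, 0.1, rng)
    if step >= 10000:
        thetas.append(theta)
var_theta = sum(t * t for t in thetas) / len(thetas)
print(var_theta)  # close to 1 even with the fairly large step size 0.1
```

Because the friction acts inside the exact momentum update, the gradient and noise contributions are both damped by the same decay factor, which is the property the paragraph above attributes to the two-way split.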

4 The ResNet++ Architecture

Figure 1: Residual blocks in the ResNet and ResNet++ architectures, respectively.

The generalization performance of large neural nets trained using optimization depends on stochastic regularization methods like Dropout dropout and BatchNorm batchnorm. These methods implicitly add noise to the model parameters dropout-vi ; batchnorm-vi and significantly boost training performance and generalization for image classifiers. These methods can be interpreted as a coarse approximation of Bayesian inference dropout-vi ; batchnorm-vi. However, a stochastic gradient sampler like ATMC already adds the necessary amount of noise, and combining it with BatchNorm or Dropout leads to underfitting. We thus define a BatchNorm-free version of ResNet called ResNet++ that includes SELUs selu , Fixup initialization fixup and weight normalization weight-norm (see Fig. 1). We use ATMC to fill the significant gap in performance due to the absence of BatchNorm in ResNet++.

Figure 2: Calibration plot for Cifar10

Figure 3: Calibration plot for ImageNet

4.1 SELU

We find the SELU activation to work well in BatchNorm-free networks. SELU forces the statistics of the activations towards zero mean and unit variance. The SELU activation function additionally has a non-zero gradient everywhere, which could improve the mixing of the sampler by providing a more informative gradient.
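Both properties can be checked with a short sketch of the SELU function using its standard published constants. The helper names and the Monte Carlo check are illustrative: the check feeds standard normal inputs through the activation and estimates the output mean and variance, which should stay near 0 and 1.

```python
import math
import random

SELU_ALPHA = 1.6732632423543772
SELU_SCALE = 1.0507009873554805

def selu(x):
    """Scaled exponential linear unit."""
    return SELU_SCALE * (x if x > 0.0 else SELU_ALPHA * (math.exp(x) - 1.0))

def selu_grad(x):
    """Derivative of SELU; strictly positive for every finite input."""
    return SELU_SCALE * (1.0 if x > 0.0 else SELU_ALPHA * math.exp(x))

# Monte Carlo check of the self-normalizing property for standard normal inputs.
rng = random.Random(0)
outputs = [selu(rng.gauss(0.0, 1.0)) for _ in range(200000)]
out_mean = sum(outputs) / len(outputs)
out_var = sum((y - out_mean) ** 2 for y in outputs) / len(outputs)
print(out_mean, out_var, selu_grad(-5.0))
```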

4.2 Fixup initialization

ResNets are known to scale well with depth resnet. However, the additive effect of the residual branch causes the magnitudes of the activations to increase with the number of residual connections. Fixup is a recently proposed initialization method that mitigates the exploding residual branch problem without using BatchNorm fixup. We use a simplified version of Fixup by initializing the scales of the final layer in each residual branch to a small constant.

4.3 Weight normalization

We use weight normalization weight-norm to separate the direction and scale of each linear feature vector:

$$w = g \, \frac{v}{\lVert v \rVert}$$

where $v$ is the direction vector and $g$ is the magnitude of a feature vector $w$. Weight normalization does not depend on batch statistics and is compatible with MCMC methods.
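A minimal sketch of this reparameterization, assuming the standard weight normalization form in which a feature vector is the direction vector rescaled to unit length and multiplied by a scalar magnitude. The names v, g and weight_norm are chosen here for illustration.

```python
import math

def weight_norm(v, g):
    """Return the feature vector g * v / ||v|| for direction v and magnitude g."""
    norm = math.sqrt(sum(vi * vi for vi in v))
    return [g * vi / norm for vi in v]

w = weight_norm([3.0, 4.0], 2.0)
w_rescaled = weight_norm([30.0, 40.0], 2.0)  # same direction, ten times longer
print(w, w_rescaled)
```

Rescaling the direction vector leaves the resulting feature vector unchanged, so only the magnitude parameter controls the scale of the feature.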

The scale of the direction vector does not affect the outputs of the model. It does, however, affect the effective step size wngrad . Therefore, the prior on the direction vector is chosen such that it is forced to unit length.


The prior on the scales is problem specific and can for example be chosen to encode a preference for structurally sparse models.

5 Experiments

The experiments presented here aim to demonstrate that the ATMC sampler is competitive with a well-tuned optimization baseline for large-scale datasets and models. We use the TensorFlow official implementation of ResNet-56 and ResNet-50 on Cifar10 and ImageNet, respectively. We compare our ATMC sampler to an optimization baseline with and without BatchNorm. For the optimization baseline without BatchNorm we use the ResNet++ architecture as described in Sec. 4. For the baseline with BatchNorm we found standard ResNet with Xavier initialization and the ReLU non-linearity to work better.

For the ATMC sampler we report both the performance of a single sample and the estimated posterior predictive based on a finite number of samples. Similar to earlier work mcmc-cycle , we found that far fewer samples are needed when a cyclic step size with cycle length is used. The final sample in each cycle is used to estimate the posterior predictive.
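One way to realize a cyclic step size is sketched below. The shape of the schedule is not specified in this section, so the cosine form used here is an assumption borrowed from common cyclical SG-MCMC practice; the function names and the rule that keeps a sample on the last step of each cycle are likewise illustrative.

```python
import math

def cyclic_step_size(step, cycle_length, initial_step_size):
    """Cosine schedule that restarts at initial_step_size every cycle_length steps."""
    position = (step % cycle_length) / cycle_length  # fraction of the cycle, in [0, 1)
    return 0.5 * initial_step_size * (math.cos(math.pi * position) + 1.0)

def is_sample_step(step, cycle_length):
    """True on the last step of each cycle, where one posterior sample is kept."""
    return step % cycle_length == cycle_length - 1

schedule = [cyclic_step_size(s, 10, 0.1) for s in range(30)]
sample_steps = [s for s in range(30) if is_sample_step(s, 10)]
print(sample_steps)
```

The step size is largest at the start of a cycle, which encourages exploration, and smallest at the end, where the retained sample is taken.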

For ResNet++ we further use a group Laplace prior with to regularize the scales of each linear feature in ResNet++. The momentum noise is chosen as such that the friction applied to the momentum is at least .

5.1 Cifar 10

Setup                          Top 1 acc. [%]   NLL [Nats]
SGD + BatchNorm                94.4
ATMC (single sample)
ATMC (Posterior predictive)                     0.194
SGNHT (single sample)
SGNHT (Posterior predictive)
Table 1: Performance on Cifar10 with ResNet-56 model. The posterior predictive is estimated using a sample of the posterior parameters at the end of each learning rate cycle.

For Cifar10 we choose the step size and the cycle length is set to 50 epochs. The momentum hyper-parameters are and such that the average speed and maximum speed per step are and , respectively. The number of convolution filters is doubled to compared to the original ResNet-56 implementation. We use a single V100 GPU with a batch size of . The sampler runs for epochs and we start collecting samples for the posterior predictive after epochs. The optimization baseline converges in epochs. We also report the results of sampling with a sampler based on Nosé-Hoover thermostats (SGNHT) bayes-thermostat ; rhmc applied to the ResNet++ architecture.

Table 1 lists the test set performance for Cifar10. A single sample from the posterior already outperforms the baseline without BatchNorm by a significant margin in both test accuracy and log-likelihood. Using BatchNorm significantly improves the generalization of the optimization baseline. It outperforms the estimate of the posterior predictive in accuracy yet it does not have a better test log-likelihood.
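The difference between single-sample and posterior predictive results comes from averaging class probabilities over several parameter samples. The sketch below shows that averaging step and the negative log-likelihood computed from it; the function names and the toy probabilities are made up for illustration and are not results from the paper.

```python
import math

def posterior_predictive(per_sample_probs):
    """Average class probabilities over posterior samples.

    per_sample_probs[s][k] is the probability of class k under parameter sample s.
    """
    num_samples = len(per_sample_probs)
    num_classes = len(per_sample_probs[0])
    return [sum(p[k] for p in per_sample_probs) / num_samples for k in range(num_classes)]

def negative_log_likelihood(probs, label):
    """NLL in nats of the true label under the predicted distribution."""
    return -math.log(probs[label])

# Two hypothetical parameter samples that disagree on a two-class example.
per_sample = [[0.9, 0.1], [0.5, 0.5]]
label = 1
ensemble = posterior_predictive(per_sample)
ensemble_nll = negative_log_likelihood(ensemble, label)
mean_single_nll = sum(negative_log_likelihood(p, label) for p in per_sample) / len(per_sample)
print(ensemble, ensemble_nll, mean_single_nll)
```

By Jensen's inequality the NLL of the averaged prediction is never worse than the average NLL of the individual samples, which is one reason the ensemble estimate tends to score better on log-likelihood.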

To further analyze the quality of the uncertainty estimates, we group each model's predictions into 8 equally sized bins based on the confidence, defined as the probability assigned to the maximum probability class for each example. If the probabilities are well-calibrated, the average confidence should be close to the average accuracy. Figure 2 shows the calibration of the uncertainty estimates for the posterior predictive and optimization baselines. The posterior predictive is calibrated for the least confident predictions and shows less bias towards overconfidence compared to the models trained with SGD.
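The binning procedure can be sketched as follows. The code assumes that "equally sized" means each bin holds the same number of examples (rather than covering an equal confidence range), which is an interpretation; the function name and toy inputs are hypothetical.

```python
def calibration_bins(confidences, correct, num_bins=8):
    """Sort examples by confidence, split them into num_bins groups of (nearly)
    equal size, and return (mean confidence, accuracy) for each non-empty group."""
    order = sorted(range(len(confidences)), key=lambda i: confidences[i])
    bins = []
    for b in range(num_bins):
        start = b * len(order) // num_bins
        stop = (b + 1) * len(order) // num_bins
        idx = order[start:stop]
        if not idx:
            continue
        mean_conf = sum(confidences[i] for i in idx) / len(idx)
        accuracy = sum(1.0 for i in idx if correct[i]) / len(idx)
        bins.append((mean_conf, accuracy))
    return bins

# Toy example with four predictions and two bins.
bins = calibration_bins([0.6, 0.9, 0.7, 0.8], [True, True, False, True], num_bins=2)
print(bins)
```

A well-calibrated model yields pairs whose two entries are close; a mean confidence above the accuracy in a bin indicates overconfidence.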

5.2 ImageNet

Setup                          Top 1 acc. [%]   NLL [Nats]
SGD + BatchNorm
ATMC (single sample)
ATMC (Posterior predictive)    77.5             0.883
SGNHT (single sample)
SGNHT (Posterior predictive)
Table 2: Performance on ImageNet with ResNet-50 model. The posterior predictive is estimated using a sample of the posterior parameters at the end of each learning rate cycle.

For the ImageNet experiments we use an initial step size and a cycle length of 20 epochs. The other hyper-parameters for the sampler are the same as for the Cifar10 experiments. We use a single Google Cloud TPUv3 with a batch size of 1024. We did not observe a significant difference in wall clock time per epoch for SGD and ATMC. BatchNorm did result in an overhead of roughly compared to the ResNet++ model. Samples for the posterior predictive are collected after epochs and the sampler runs for epochs. The optimization baseline converges in epochs.

Table 2 lists the results for ImageNet classification. A single sample from the posterior outperforms the optimization baseline without BatchNorm. The posterior predictive based on ATMC outperforms the optimizer with BatchNorm by a wide margin in both accuracy and test log-likelihood. We note that the sampler runs significantly longer (x) compared to the optimization baseline because it takes a long time for the posterior predictive estimate to converge. However, the posterior predictive of ATMC matches the accuracy of the optimization baseline with BatchNorm () after epochs.

Figure 3 shows the quality of the uncertainty for various levels of confidence. Again, the ATMC-based posterior predictive produces much better calibrated predictions: it is almost perfectly calibrated for low-confidence predictions and shows less bias towards overconfidence than the optimization baseline.

6 Discussion

The empirical results show it is possible to sample the posterior distribution of neural networks on large scale image classification problems like ImageNet. A major obstacle for sampling the posterior of ResNets in particular is the lack of compatibility with BatchNorm. Using recent advances in initialization and the SELU activation function we are able to stabilize and speed up training of ResNets without resorting to BatchNorm. Nonetheless, we observe that BatchNorm still offers a unique advantage in terms of generalization performance. We hope that future work will allow the implicit inductive bias that BatchNorm has to be transferred into an explicit prior that is compatible with sampling methods.

Multiple posterior samples provide a much more accurate estimate of the posterior predictive, and consequently much better accuracy and uncertainty estimates. For inference, making predictions using a large ensemble of models sampled from the posterior can be costly. Variational Inference methods can be used to quickly characterize a local mode of the posterior bayes-by-backprop . More recent work shows that a running estimate of the mean and variance of the parameters during training can also be used to approximate a mode of the posterior swag . Methods like distillation could potentially be used to compress a high-quality ensemble into a single network with a limited computational budget distilation .

Although the form in (4) is very general, alternative methods for dealing with stochastic gradients have been proposed in the literature. One approach is to estimate the covariance of the stochastic gradient noise explicitly and use it to correct and pre-condition the sampling dynamics sgfs ; sgfs-adam .

Other sampling methods are not based on an SDE that converges to the target distribution. Under some conditions stochastic optimization methods can be interpreted as such a biased sampling method sgd-bi . Predictions based on multiple samples from the trajectory of SGD have been used successfully for obtaining uncertainty estimates in large scale Deep Learning swag . However, these methods rely on tuning hyper-parameters in such a way that just the right amount of noise is inserted.

7 Conclusion

This work introduces the ATMC sampler, a robust posterior sampling method that scales to large deep learning problems. To the best of our knowledge, we are the first to successfully train neural networks using MCMC on ImageNet. In a BatchNorm free setting, a single sample from the posterior generated by ATMC outperforms the optimization baseline. A posterior predictive estimate outperforms the optimization baseline with BatchNorm on ImageNet. Based on these empirical results we hope the ATMC sampler will enable new applications of Bayesian inference in deep learning.


We would like to thank Jascha Sohl-Dickstein and Sebastian Nowozin for helpful feedback. In particular we wish to thank Tim Salimans for his feedback and insightful discussions on MCMC methods.