Diffusion-based generative models have demonstrated a capacity for perceptually impressive synthesis, but can they also be great likelihood-based models? We answer this in the affirmative, and introduce a family of diffusion-based generative models that obtain state-of-the-art likelihoods on standard image density estimation benchmarks. Unlike other diffusion-based models, our method allows for efficient optimization of the noise schedule jointly with the rest of the model. We show that the variational lower bound (VLB) simplifies to a remarkably short expression in terms of the signal-to-noise ratio of the diffused data, thereby improving our theoretical understanding of this model class. Using this insight, we prove an equivalence between several models proposed in the literature. In addition, we show that the continuous-time VLB is invariant to the noise schedule, except for the signal-to-noise ratio at its endpoints. This enables us to learn a noise schedule that minimizes the variance of the resulting VLB estimator, leading to faster optimization. Combining these advances with architectural improvements, we obtain state-of-the-art likelihoods on image density estimation benchmarks, outperforming autoregressive models that have dominated these benchmarks for many years, with often significantly faster optimization. In addition, we show how to turn the model into a bits-back compression scheme, and demonstrate lossless compression rates close to the theoretical optimum.
Likelihood-based generative modeling is a central task in machine learning that is the basis for a wide range of applications ranging from speech synthesis (Oord et al., 2016), to translation (Sutskever et al., 2014), to compression (MacKay, 2003), to many others. Autoregressive models have long been the dominant model class on this task due to their tractable likelihood and expressivity, as shown in Figure 1. Diffusion models have recently shown impressive results in image (Ho et al., 2020; Song et al., 2021b; Nichol and Dhariwal, 2021) and audio generation (Kong et al., 2020; Chen et al., 2020) in terms of perceptual quality, but have yet to match autoregressive models on density estimation benchmarks. In this paper we make several technical contributions that allow diffusion models to challenge the dominance of autoregressive models in this domain.
Our main contributions are as follows:
We introduce a flexible family of diffusion-based generative models that achieve new state-of-the-art log-likelihoods on standard image density estimation benchmarks (CIFAR-10 and ImageNet). This is enabled by incorporating Fourier features into the diffusion model and using a learnable specification of the diffusion process, among other modeling contributions.
We improve our theoretical understanding of density modeling using diffusion models by analyzing their variational lower bound (VLB), deriving a remarkably simple expression in terms of the signal-to-noise ratio of the diffusion process. This result delivers new insight into the model class: for the continuous-time (infinite-depth) setting we prove a novel invariance of the generative model and its VLB to the specification of the diffusion process, and we show that various diffusion models from the literature are equivalent up to a trivial time-dependent rescaling of the data.
Our work builds on diffusion probabilistic models (DPMs) (Sohl-Dickstein et al., 2015), or diffusion models in short. DPMs can be viewed as a type of variational autoencoder (VAE) (Kingma and Welling, 2013; Rezende et al., 2014), whose structure and loss function allow for efficient training of arbitrarily deep models. Interest in diffusion models has recently reignited due to their impressive image generation results (Ho et al., 2020; Song and Ermon, 2020).
Ho et al. (2020) introduced a number of model innovations to the original DPM, with impressive results on image generation quality benchmarks. They showed that the VLB objective, for a diffusion model with discrete time and diffusion variances shared across input dimensions, is equivalent to multi-scale denoising score matching, up to particular weightings per noise scale. Further improvements were proposed by Nichol and Dhariwal (2021), resulting in better log-likelihood scores.
Song and Ermon (2019) first proposed learning generative models through a multi-scale denoising score matching objective, with improved methods in Song and Ermon (2020). This was later extended to continuous-time diffusion with novel sampling algorithms based on reversing the diffusion process (Song et al., 2021b).
Concurrent to our work, Song et al. (2021a), Huang et al. (2021), and Vahdat et al. (2021) also derived variational lower bounds to the data likelihood under a continuous-time diffusion model. Where we consider the infinitely deep limit of a standard VAE, Song et al. (2021a) and Vahdat et al. (2021) present different derivations based on stochastic differential equations. Huang et al. (2021) consider both perspectives and discuss the similarities between the two approaches. An advantage of our analysis compared to these other works is that we present an intuitive expression of the VLB in terms of the signal-to-noise ratio of the diffused data, which then leads to new results on the invariance of the generative model and its VLB to the specification of the diffusion process. We empirically compare to these works, as well as others, in Table 1.
Previous approaches to diffusion probabilistic models fixed the diffusion process, while we consider flexible learned diffusion processes. This is enabled by directly parameterizing the mean and variance of the marginal $q(\mathbf{z}_t|\mathbf{x})$, where previous approaches instead parameterized the individual diffusion steps $q(\mathbf{z}_t|\mathbf{z}_s)$. In addition, our denoising models include several architecture changes, the most important of which is the use of Fourier features, which enable us to reach much better likelihoods than previous diffusion probabilistic models.
We will focus on the most basic case of generative modeling, where we have a dataset of observations of $\mathbf{x}$, and the task is to estimate the marginal distribution $p(\mathbf{x})$. As with most generative models, the described methods can be extended to the case of multiple observed variables, and/or the task of estimating conditional densities $p(\mathbf{x}|\mathbf{y})$. The proposed latent-variable model consists of a diffusion process (Section 3.1) that we invert to obtain a hierarchical generative model (Section 3.2). We optimize the model parameters by maximizing the variational lower bound of the marginal log-likelihood (Section 4). In contrast with earlier DPMs, we optimize the forward-time diffusion process jointly with the rest of the model. This turns the model into a type of VAE (Kingma and Welling, 2013; Rezende et al., 2014).
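The starting point for our generative model is a diffusion process that begins with the data $\mathbf{x}$, and then samples a sequence of latent variables $\mathbf{z}_t$ given $\mathbf{x}$, where $t$ runs forward in time from $t=0$ to $t=1$. The distribution of latent variable $\mathbf{z}_t$ conditioned on $\mathbf{x}$, for any $t \in [0,1]$, is given by:
$$q(\mathbf{z}_t|\mathbf{x}) = \mathcal{N}\left(\alpha_t \mathbf{x}, \sigma_t^2 \mathbf{I}\right), \tag{1}$$
where $\alpha_t$ and $\sigma_t^2$ are scalar-valued functions of $t$ that define the mean and variance of the marginal distributions, with domain $[0,1]$ and range $\mathbb{R}^+$. Both $\alpha_t$ and $\sigma_t^2$ are smooth, such that their derivatives with respect to time $t$ are finite. Furthermore, their ratio $\alpha_t^2/\sigma_t^2$ is strictly monotonically decreasing in $t$, such that $\alpha_t^2/\sigma_t^2 < \alpha_s^2/\sigma_s^2$ for any $t > s$.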
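The latent variables $(\mathbf{z}_s, \mathbf{z}_t, \mathbf{z}_u)$ at subsequent timesteps $0 \leq s < t < u \leq 1$ are jointly distributed as a first-order Markov chain, such that $q(\mathbf{z}_u|\mathbf{z}_t, \mathbf{z}_s) = q(\mathbf{z}_u|\mathbf{z}_t)$. The distribution of $\mathbf{z}_t$ given $\mathbf{z}_s$, for any $0 \leq s < t \leq 1$, is then:
$$q(\mathbf{z}_t|\mathbf{z}_s) = \mathcal{N}\left(\alpha_{t|s} \mathbf{z}_s, \sigma_{t|s}^2 \mathbf{I}\right), \quad \text{where } \alpha_{t|s} = \alpha_t/\alpha_s \text{ and } \sigma_{t|s}^2 = \sigma_t^2 - \alpha_{t|s}^2 \sigma_s^2. \tag{2}$$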
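Given the distributions above, it is straightforward to verify that the reverse-time inference distribution of $\mathbf{z}_s$ given $\mathbf{z}_t$ and $\mathbf{x}$, for any $0 \leq s < t \leq 1$, is also Gaussian and given by:
$$q(\mathbf{z}_s|\mathbf{z}_t, \mathbf{x}) = \mathcal{N}\left(\boldsymbol{\mu}_Q(\mathbf{z}_t, \mathbf{x}; s, t), \sigma_Q^2(s,t) \mathbf{I}\right), \quad \text{where } \sigma_Q^2(s,t) = \sigma_{t|s}^2 \sigma_s^2 / \sigma_t^2, \tag{3}$$
$$\text{and } \boldsymbol{\mu}_Q(\mathbf{z}_t, \mathbf{x}; s, t) = \frac{\alpha_{t|s} \sigma_s^2}{\sigma_t^2} \mathbf{z}_t + \frac{\alpha_s \sigma_{t|s}^2}{\sigma_t^2} \mathbf{x}. \tag{4}$$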
In Appendix E we provide an implementation of that is numerically stable for small .
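As a sanity check on Equations 1 to 4, the sketch below computes the transition and posterior parameters for scalar data and confirms that composing $q(\mathbf{z}_s|\mathbf{x})$ with $q(\mathbf{z}_t|\mathbf{z}_s)$ reproduces the marginal $q(\mathbf{z}_t|\mathbf{x})$. The function names and the numerical values of the schedule are illustrative assumptions, not taken from the paper's code.

```python
import math

def transition_params(alpha_s, sigma2_s, alpha_t, sigma2_t):
    """Return (alpha_{t|s}, sigma^2_{t|s}) of q(z_t | z_s), as in Equation 2."""
    alpha_ts = alpha_t / alpha_s
    sigma2_ts = sigma2_t - alpha_ts ** 2 * sigma2_s
    return alpha_ts, sigma2_ts

def posterior_params(z_t, x, alpha_s, sigma2_s, alpha_t, sigma2_t):
    """Return mean and variance of q(z_s | z_t, x), as in Equations 3 and 4."""
    alpha_ts, sigma2_ts = transition_params(alpha_s, sigma2_s, alpha_t, sigma2_t)
    var_q = sigma2_ts * sigma2_s / sigma2_t
    mean_q = (alpha_ts * sigma2_s / sigma2_t) * z_t + (alpha_s * sigma2_ts / sigma2_t) * x
    return mean_q, var_q

# Hypothetical variance-preserving values at two times s < t (assumed for illustration).
sigma2_s, sigma2_t = 0.2, 0.7
alpha_s, alpha_t = math.sqrt(1.0 - sigma2_s), math.sqrt(1.0 - sigma2_t)
alpha_ts, sigma2_ts = transition_params(alpha_s, sigma2_s, alpha_t, sigma2_t)

# Composing q(z_s | x) with q(z_t | z_s) should give back the marginal q(z_t | x).
composed_mean_coeff = alpha_ts * alpha_s
composed_var = alpha_ts ** 2 * sigma2_s + sigma2_ts

mean_q, var_q = posterior_params(0.3, 1.0, alpha_s, sigma2_s, alpha_t, sigma2_t)
```

The posterior variance is smaller than $\sigma_s^2$, as expected when conditioning on the additional observation $\mathbf{z}_t$.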
As we show in Section 4, the continuous-time VLB objective that we will propose optimizing is surprisingly invariant to the choice of functions $\alpha_t$ and $\sigma_t^2$, which we refer to as the noise schedule. Their only impact on our objective is through their ratio at times $t=0$ and $t=1$:
$$\text{SNR}(t) = \alpha_t^2 / \sigma_t^2, \tag{5}$$
which we call the signal-to-noise ratio.
The specific diffusion processes used by Song and Ermon (2019) and Sohl-Dickstein et al. (2015) can be seen as special cases of the proposed model. Song and Ermon (2019) use $\alpha_t = 1$, called variance-exploding diffusion processes by Song et al. (2021b). In our experiments, we choose to use variance-preserving diffusion processes as in (Sohl-Dickstein et al., 2015; Ho et al., 2020) where $\alpha_t = \sqrt{1 - \sigma_t^2}$. Written as a function of $\text{SNR}(t)$, this is:
$$\alpha_t^2 = \frac{\text{SNR}(t)}{1 + \text{SNR}(t)}, \qquad \sigma_t^2 = \frac{1}{1 + \text{SNR}(t)}. \tag{6}$$
In previous works the signal-to-noise ratio was a fixed function of time, but here we learn this function jointly with the rest of the model, as we explain in Section 5.
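The variance-preserving relation between the signal-to-noise ratio and the pair $(\alpha_t, \sigma_t^2)$ can be written as a small helper. This is a minimal sketch; the function name `vp_from_snr` and the test values are assumptions made for illustration.

```python
import math

def vp_from_snr(snr):
    """Variance-preserving (alpha_t, sigma_t^2) for a given SNR value (Equation 6)."""
    sigma2 = 1.0 / (1.0 + snr)
    alpha = math.sqrt(snr / (1.0 + snr))
    return alpha, sigma2

# Check both defining properties on a few hypothetical SNR values.
checks = []
for snr_value in (1e-3, 0.5, 1.0, 40.0, 1e4):
    alpha, sigma2 = vp_from_snr(snr_value)
    checks.append((alpha ** 2 + sigma2, alpha ** 2 / sigma2, snr_value))
```

Each entry of `checks` holds $\alpha_t^2 + \sigma_t^2$ (which should equal 1) and the recovered ratio $\alpha_t^2/\sigma_t^2$ (which should equal the input SNR).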
We define our generative model by inverting the diffusion process of Section 3.1, yielding a hierarchical generative model that samples a sequence of latents $\mathbf{z}_t$, with time running backward from $t=1$ to $t=0$. We consider both the case where this sequence consists of a finite number of steps $T$, as well as a continuous-time model corresponding to $T \to \infty$. We start by presenting the discrete-time case.
Given finite $T$, we discretize time uniformly into $T$ segments of width $\tau = 1/T$. Defining $s(i) = (i-1)/T$ and $t(i) = i/T$, our hierarchical generative model for data $\mathbf{x}$ is then given by:
$$p(\mathbf{x}) = \int_{\mathbf{z}} p(\mathbf{z}_1)\, p(\mathbf{x}|\mathbf{z}_0) \prod_{i=1}^{T} p(\mathbf{z}_{s(i)}|\mathbf{z}_{t(i)}). \tag{7}$$
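With the variance-preserving diffusion specification and sufficiently small $\text{SNR}(1)$, we have that $q(\mathbf{z}_1|\mathbf{x}) \approx \mathcal{N}(\mathbf{z}_1; \mathbf{0}, \mathbf{I})$. We therefore model the marginal distribution of $\mathbf{z}_1$ as a spherical Gaussian:
$$p(\mathbf{z}_1) = \mathcal{N}(\mathbf{z}_1; \mathbf{0}, \mathbf{I}). \tag{8}$$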
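For the reconstruction term, we wish to choose a model $p(\mathbf{x}|\mathbf{z}_0)$ that is close to the unknown $q(\mathbf{x}|\mathbf{z}_0)$. Let $x_i$ and $z_{0,i}$ be the $i$-th elements of $\mathbf{x}$ and $\mathbf{z}_0$, respectively. We then use a factorized distribution of the form:
$$p(\mathbf{x}|\mathbf{z}_0) = \prod_i p(x_i|z_{0,i}), \tag{9}$$
where we choose $p(x_i|z_{0,i}) \propto q(z_{0,i}|x_i)$. With sufficiently large $\text{SNR}(0)$, this becomes a very close approximation to the true $q(\mathbf{x}|\mathbf{z}_0)$, as the influence of the unknown data distribution $q(\mathbf{x})$ is overwhelmed by the likelihood $q(\mathbf{z}_0|\mathbf{x})$.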
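Finally, we choose the conditional model distributions as
$$p(\mathbf{z}_s|\mathbf{z}_t) = q(\mathbf{z}_s|\mathbf{z}_t, \mathbf{x} = \hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)), \tag{10}$$
i.e. the same as the reverse-time inference model $q(\mathbf{z}_s|\mathbf{z}_t, \mathbf{x})$, but with the original data $\mathbf{x}$ replaced by the output of a denoising model $\hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)$ that predicts $\mathbf{x}$ from its noisy version $\mathbf{z}_t$. We then have $p(\mathbf{z}_s|\mathbf{z}_t) = \mathcal{N}(\mathbf{z}_s; \boldsymbol{\mu}_\theta(\mathbf{z}_t; s, t), \sigma_Q^2(s,t)\mathbf{I})$, with
$$\boldsymbol{\mu}_\theta(\mathbf{z}_t; s, t) = \frac{\alpha_{t|s}\sigma_s^2}{\sigma_t^2}\mathbf{z}_t + \frac{\alpha_s \sigma_{t|s}^2}{\sigma_t^2}\hat{\mathbf{x}}_\theta(\mathbf{z}_t; t) = \frac{1}{\alpha_{t|s}}\mathbf{z}_t - \frac{\sigma_{t|s}^2}{\alpha_{t|s}\sigma_t}\hat{\boldsymbol{\epsilon}}_\theta(\mathbf{z}_t; t) = \frac{1}{\alpha_{t|s}}\mathbf{z}_t + \frac{\sigma_{t|s}^2}{\alpha_{t|s}}\mathbf{s}_\theta(\mathbf{z}_t; t), \tag{11}$$
with variance $\sigma_Q^2(s,t)$ the same as in Equation 3, and with
$$\hat{\boldsymbol{\epsilon}}_\theta(\mathbf{z}_t; t) = (\mathbf{z}_t - \alpha_t \hat{\mathbf{x}}_\theta(\mathbf{z}_t; t))/\sigma_t, \qquad \mathbf{s}_\theta(\mathbf{z}_t; t) = (\alpha_t \hat{\mathbf{x}}_\theta(\mathbf{z}_t; t) - \mathbf{z}_t)/\sigma_t^2. \tag{12}$$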
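Equation 11 shows that we can interpret our model in three different ways: 1) In terms of the denoising model $\hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)$ that recovers $\mathbf{x}$ from its corrupted version $\mathbf{z}_t$. 2) In terms of a noise prediction model $\hat{\boldsymbol{\epsilon}}_\theta(\mathbf{z}_t; t)$ that directly infers the noise $\boldsymbol{\epsilon}$ that was used to generate $\mathbf{z}_t$. 3) In terms of a score model $\mathbf{s}_\theta(\mathbf{z}_t; t)$, that at its optimum equals the scores of the marginal density: $\mathbf{s}^*(\mathbf{z}_t; t) = \nabla_{\mathbf{z}} \log q(\mathbf{z}_t)$; see Appendix G. These are three equally valid views on the same model class, which have been used interchangeably in the literature. We find the denoising interpretation the most intuitive, and will therefore mostly use $\hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)$ in this paper, although in practice we parameterize our model via $\hat{\boldsymbol{\epsilon}}_\theta(\mathbf{z}_t; t)$ following Ho et al. (2020).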
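The three views can be checked numerically: converting a denoised estimate into a noise estimate or a score estimate, and plugging each into its own expression for the model mean, should give the same value. The sketch below does this for scalars; the helper names and the numerical values are assumptions chosen only for illustration.

```python
import math

def eps_from_x(z_t, x_hat, alpha_t, sigma_t):
    """Noise estimate implied by a denoised estimate."""
    return (z_t - alpha_t * x_hat) / sigma_t

def score_from_x(z_t, x_hat, alpha_t, sigma_t):
    """Score estimate implied by a denoised estimate."""
    return (alpha_t * x_hat - z_t) / sigma_t ** 2

# Hypothetical variance-preserving schedule values at times s < t.
sigma2_s, sigma2_t = 0.2, 0.7
alpha_s, alpha_t = math.sqrt(1.0 - sigma2_s), math.sqrt(1.0 - sigma2_t)
sigma_t = math.sqrt(sigma2_t)
alpha_ts = alpha_t / alpha_s
sigma2_ts = sigma2_t - alpha_ts ** 2 * sigma2_s

z_t, x_hat = 0.3, 0.9  # arbitrary latent and denoiser output
eps_hat = eps_from_x(z_t, x_hat, alpha_t, sigma_t)
score_hat = score_from_x(z_t, x_hat, alpha_t, sigma_t)

# Model mean written in terms of the denoiser, the noise predictor, and the score.
mu_from_x = (alpha_ts * sigma2_s / sigma2_t) * z_t + (alpha_s * sigma2_ts / sigma2_t) * x_hat
mu_from_eps = z_t / alpha_ts - sigma2_ts / (alpha_ts * sigma_t) * eps_hat
mu_from_score = z_t / alpha_ts + (sigma2_ts / alpha_ts) * score_hat
```

All three values of the mean agree up to floating-point rounding, which is why the choice of parameterization is a matter of convenience and numerical stability rather than of model class.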
When we take the number of steps $T \to \infty$, our model for $p(\mathbf{z}_t)$ can best be described as a continuous-time diffusion process (Song et al., 2021b), governed by the stochastic differential equation
$$d\mathbf{z}_t = \left[f(t)\mathbf{z}_t - g^2(t)\,\mathbf{s}_\theta(\mathbf{z}_t; t)\right]dt + g(t)\,d\mathbf{W}_t, \tag{13}$$
with time running backwards from $t=1$ to $t=0$ and
$$f(t) = \frac{d \log \alpha_t}{dt}, \qquad g^2(t) = \frac{d\sigma_t^2}{dt} - 2\frac{d \log \alpha_t}{dt}\sigma_t^2. \tag{14}$$
As we argue in Section 4.2, we reach the best likelihoods with $T \to \infty$. For most practical purposes however, there is no difference between the continuous-time formulation of our model and the discrete-time formulation with large $T$. For simplicity we therefore use the discrete-time model to derive most of our results in the remaining discussion. For a more detailed discussion of generative modeling via stochastic differential equations see Song et al. (2021b).
Similar to the original DPMs (Sohl-Dickstein et al., 2015), we optimize the parameters $\theta$ towards the variational lower bound (VLB) of the marginal likelihood, also called the evidence lower bound (ELBO). Unlike earlier DPMs, but similar to VAEs (Kingma and Welling, 2013; Rezende et al., 2014), we optimize the inference model parameters (that define the forward-time diffusion process) jointly with the rest of the model.
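The negative marginal log-likelihood is bounded by:
$$-\log p(\mathbf{x}) \leq -\text{VLB}(\mathbf{x}) = \underbrace{D_{KL}(q(\mathbf{z}_1|\mathbf{x})\,||\,p(\mathbf{z}_1))}_{\text{Prior loss}} + \underbrace{\mathbb{E}_{q(\mathbf{z}_0|\mathbf{x})}\left[-\log p(\mathbf{x}|\mathbf{z}_0)\right]}_{\text{Reconstruction loss}} + \underbrace{\mathcal{L}_T(\mathbf{x})}_{\text{Diffusion loss}}. \tag{15}$$
The prior loss is a KL divergence between two Gaussians that can be computed in closed form; see Appendix B. The reconstruction loss can be evaluated and optimized using standard reparameterization gradients (Kingma and Welling, 2013). The diffusion loss, $\mathcal{L}_T(\mathbf{x})$, is more complicated, and depends on the hyperparameter $T$ that determines the depth of the generative model.
In the case of finite $T$, using $s(i) = (i-1)/T$, $t(i) = i/T$, the diffusion loss is:
$$\mathcal{L}_T(\mathbf{x}) = \sum_{i=1}^{T} \mathbb{E}_{q(\mathbf{z}_{t(i)}|\mathbf{x})}\, D_{KL}\left[q(\mathbf{z}_{s(i)}|\mathbf{z}_{t(i)}, \mathbf{x})\,||\,p(\mathbf{z}_{s(i)}|\mathbf{z}_{t(i)})\right]. \tag{16}$$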
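In Appendix B we show that this expression simplifies considerably, yielding:
$$\mathcal{L}_T(\mathbf{x}) = \frac{T}{2}\, \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I}),\, i \sim U\{1,T\}}\left[\left(\text{SNR}(s) - \text{SNR}(t)\right) ||\mathbf{x} - \hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)||_2^2\right], \tag{17}$$
where $U\{1,T\}$ is the uniform distribution on the integers $\{1, \ldots, T\}$, and $\mathbf{z}_t = \alpha_t \mathbf{x} + \sigma_t \boldsymbol{\epsilon}$.
A natural question to ask is what the number of time segments $T$ should be, and whether more segments is always better. In Appendix C we analyze the difference between the diffusion loss with $T$ segments, $\mathcal{L}_T(\mathbf{x})$, and the diffusion loss with double that number of segments, $\mathcal{L}_{2T}(\mathbf{x})$, and find that
$$\mathcal{L}_{2T}(\mathbf{x}) - \mathcal{L}_T(\mathbf{x}) = \mathbb{E}_{\boldsymbol{\epsilon},\, i \sim U\{1,T\}}\left[c(t')\left(||\mathbf{x} - \hat{\mathbf{x}}_\theta(\mathbf{z}_{t'}; t')||_2^2 - ||\mathbf{x} - \hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)||_2^2\right)\right], \tag{18}$$
with $c(t') = \frac{T}{2}\left(\text{SNR}(s) - \text{SNR}(t')\right) > 0$ and $t' = t - \frac{1}{2T}$. Since $t' < t$, latent $\mathbf{z}_{t'}$ is a less noisy version of the data from earlier in the diffusion process compared to $\mathbf{z}_t$, which means that predicting the uncorrupted data from $\mathbf{z}_{t'}$ is easier than from $\mathbf{z}_t$. If our trained model $\hat{\mathbf{x}}_\theta$ is sufficiently good, we should thus always have that $\mathcal{L}_{2T}(\mathbf{x}) < \mathcal{L}_T(\mathbf{x})$, i.e. that our VLB will always be better for a larger number of time segments.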
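To make the discrete-time diffusion loss concrete, the sketch below evaluates the sum over segments exactly for a toy setting in which the expected squared denoising error at time $t$ is supplied as a plain function. The schedule ($\gamma$ linear in $t$ between assumed endpoints $-5$ and $5$) and the toy error functions are assumptions for illustration; a trained denoiser would supply the error in practice.

```python
import math

def snr(t, gamma_min=-5.0, gamma_max=5.0):
    """Hypothetical schedule: SNR(t) = exp(-gamma(t)) with gamma linear in t."""
    return math.exp(-(gamma_min + (gamma_max - gamma_min) * t))

def diffusion_loss(T, sq_error):
    """Exact sum over segments of the discrete-time diffusion loss.

    sq_error(t) stands in for the expected squared error of the denoiser at time t.
    (T/2) * mean_i[...] equals (1/2) * sum_i[...].
    """
    total = 0.0
    for i in range(1, T + 1):
        s, t = (i - 1) / T, i / T
        total += (snr(s) - snr(t)) * sq_error(t)
    return 0.5 * total

# With a constant error the weights telescope, so the loss does not depend on T.
constant_loss = diffusion_loss(10, lambda t: 2.0)

# With an error that grows with the noise level, doubling T lowers the loss.
loss_T = diffusion_loss(4, lambda t: t)
loss_2T = diffusion_loss(8, lambda t: t)
```

In this toy case `loss_2T` is smaller than `loss_T` because the finer discretization evaluates the error at less noisy times with the same total weight, which mirrors the argument for taking more time segments.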
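Since taking more time steps leads to a better VLB, we take $T \to \infty$ in the remainder of this paper, effectively treating time $t$ as continuous rather than discrete. In Appendix B we show that in this limit the diffusion loss simplifies further. Letting $\text{SNR}'(t) = d\,\text{SNR}(t)/dt$, we have:
$$\mathcal{L}_\infty(\mathbf{x}) = -\frac{1}{2}\, \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(0,\mathbf{I})} \int_0^1 \text{SNR}'(t)\, ||\mathbf{x} - \hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)||_2^2\, dt \tag{19}$$
$$= -\frac{1}{2}\, \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(0,\mathbf{I}),\, t \sim U(0,1)}\left[\text{SNR}'(t)\, ||\mathbf{x} - \hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)||_2^2\right]. \tag{20}$$
In terms of predicting the noise $\boldsymbol{\epsilon}$, this can equivalently be written as
$$\mathcal{L}_\infty(\mathbf{x}) = -\frac{1}{2}\, \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(0,\mathbf{I})} \int_0^1 \text{log-SNR}'(t)\, ||\boldsymbol{\epsilon} - \hat{\boldsymbol{\epsilon}}_\theta(\mathbf{z}_t; t)||_2^2\, dt \tag{21}$$
$$= -\frac{1}{2}\, \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(0,\mathbf{I}),\, t \sim U(0,1)}\left[\text{log-SNR}'(t)\, ||\boldsymbol{\epsilon} - \hat{\boldsymbol{\epsilon}}_\theta(\mathbf{z}_t; t)||_2^2\right], \tag{22}$$
where $\text{log-SNR}'(t) = d \log[\text{SNR}(t)]/dt$.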
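Note that the signal-to-noise function $\text{SNR}(t)$ is invertible due to the monotonicity assumption in Section 3.1. Due to this invertibility, we can perform a change of variables, and make everything a function of $v \equiv \text{SNR}(t)$ instead of $t$, such that $t = \text{SNR}^{-1}(v)$. Let $\alpha_v$ and $\sigma_v$ be the functions $\alpha_t$ and $\sigma_t$ evaluated at $t = \text{SNR}^{-1}(v)$, and correspondingly let $\mathbf{z}_v = \alpha_v \mathbf{x} + \sigma_v \boldsymbol{\epsilon}$. Similarly, we rewrite our denoising model as $\tilde{\mathbf{x}}_\theta(\mathbf{z}, v) \equiv \hat{\mathbf{x}}_\theta(\mathbf{z}, \text{SNR}^{-1}(v))$. With this change of variables, our continuous-time loss in Equation 19 can equivalently be written as:
$$\mathcal{L}_\infty(\mathbf{x}) = \frac{1}{2}\, \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(0,\mathbf{I})} \int_{\text{SNR}_{\min}}^{\text{SNR}_{\max}} ||\mathbf{x} - \tilde{\mathbf{x}}_\theta(\mathbf{z}_v, v)||_2^2\, dv, \tag{23}$$
where instead of integrating w.r.t. time $t$ we now integrate w.r.t. the signal-to-noise ratio $v$, and where $\text{SNR}_{\min} = \text{SNR}(1)$ and $\text{SNR}_{\max} = \text{SNR}(0)$.
What this equation shows us is that the only effect the functions $\alpha(t)$ and $\sigma(t)$ have on the diffusion loss is through the values $\text{SNR}(t)$ at endpoints $t=0$ and $t=1$. Given these values $\text{SNR}_{\max}$ and $\text{SNR}_{\min}$, the diffusion loss is invariant to the shape of function $\text{SNR}(t)$ between $t=0$ and $t=1$.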
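The invariance can be illustrated numerically. The sketch below approximates the continuous-time diffusion loss with a fine discretization for two different signal-to-noise schedules that share the same endpoints, under the assumption that the squared denoising error depends only on the signal-to-noise ratio $v$. The two schedules, the toy error function $1/(1+v)$, and all names are hypothetical choices for illustration.

```python
import math

GAMMA_MIN, GAMMA_MAX = -5.0, 5.0  # assumed endpoints of -log SNR

def snr_linear(t):
    """SNR(t) = exp(-gamma(t)) with gamma linear in t."""
    return math.exp(-(GAMMA_MIN + (GAMMA_MAX - GAMMA_MIN) * t))

def snr_quadratic(t):
    """Same endpoints, but gamma quadratic in t."""
    return math.exp(-(GAMMA_MIN + (GAMMA_MAX - GAMMA_MIN) * t * t))

def sq_error(v):
    """Toy squared error that depends only on the signal-to-noise ratio v."""
    return 1.0 / (1.0 + v)

def continuous_loss(snr, num_segments=20000):
    """Fine discretization of (1/2) * integral of sq_error(v) dv along the schedule."""
    total = 0.0
    for i in range(1, num_segments + 1):
        s, t = (i - 1) / num_segments, i / num_segments
        total += (snr(s) - snr(t)) * sq_error(snr(0.5 * (s + t)))
    return 0.5 * total

loss_linear = continuous_loss(snr_linear)
loss_quadratic = continuous_loss(snr_quadratic)

# Closed form of (1/2) * integral of dv / (1 + v) between the shared endpoints.
exact = 0.5 * math.log((1.0 + snr_linear(0.0)) / (1.0 + snr_linear(1.0)))
```

Both schedules give the same loss up to discretization error, and both match the closed-form integral over $v$, even though they distribute time very differently between the endpoints.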
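In the last section, we showed that the VLB is only impacted by the function $\text{SNR}(t)$ through its endpoints $\text{SNR}_{\min}$ and $\text{SNR}_{\max}$. We will now show that, apart from these endpoints, the choice of $\alpha_v$ and $\sigma_v$ actually does not matter at all in continuous time. Since $v = \alpha_v^2/\sigma_v^2$, we have that $\sigma_v = \alpha_v/\sqrt{v}$, which means that $\mathbf{z}_v = \alpha_v \mathbf{x} + \sigma_v \boldsymbol{\epsilon} = \alpha_v(\mathbf{x} + \boldsymbol{\epsilon}/\sqrt{v})$. We can therefore adapt a model $\tilde{\mathbf{x}}^A_\theta$ trained with diffusion process $\{\alpha^A_v, \sigma^A_v, \mathbf{z}^A_v\}$, to a model $\tilde{\mathbf{x}}^B_\theta$ corresponding to a different diffusion process $\{\alpha^B_v, \sigma^B_v, \mathbf{z}^B_v\}$ simply by a time-dependent rescaling of its input: because $\mathbf{z}^A_v = (\alpha^A_v/\alpha^B_v)\,\mathbf{z}^B_v$, we can simply define $\tilde{\mathbf{x}}^B_\theta(\mathbf{z}^B_v, v) \equiv \tilde{\mathbf{x}}^A_\theta((\alpha^A_v/\alpha^B_v)\,\mathbf{z}^B_v, v)$. For , we now have that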
(24) 
In other words, diffusion processes A and B with their corresponding denoising models define the exact same generative model. Assuming these diffusion processes also start and stop at the same signal-to-noise ratios $\text{SNR}_{\max}$ and $\text{SNR}_{\min}$, the previous section tells us that $\mathcal{L}^A_\infty(\mathbf{x}) = \mathcal{L}^B_\infty(\mathbf{x})$, i.e. they also both produce the same diffusion loss in continuous time. Any two diffusion models A and B, under the mild constraints set in Section 3.1 (which include the variance-exploding and variance-preserving specifications), can thus be seen as equivalent in continuous time, up to a time-dependent rescaling of $\mathbf{z}$.
This equivalence between diffusion specifications continues to hold even if, instead of the VLB, these models optimize a weighted diffusion loss of the form:
$$\mathcal{L}_\infty(\mathbf{x}, w) = \frac{1}{2}\, \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(0,\mathbf{I})} \int_{\text{SNR}_{\min}}^{\text{SNR}_{\max}} w(v)\, ||\mathbf{x} - \tilde{\mathbf{x}}_\theta(\mathbf{z}_v, v)||_2^2\, dv, \tag{25}$$
which e.g. captures all the different objectives discussed by Song et al. (2021b), see Appendix F. Here, $w(v)$ is a weighting function that generally puts increased emphasis on the noisier data compared to the VLB, and which thereby can sometimes improve perceptual generation quality as measured by certain metrics like FID and Inception Score.
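Since calculating the integral $\mathcal{L}_\infty(\mathbf{x})$, or its generalization $\mathcal{L}_\infty(\mathbf{x}, w)$, is not analytically tractable, we use its unbiased Monte Carlo estimator in practice. To do this, we construct the VLB in terms of predicting $\boldsymbol{\epsilon}$, which is equivalent to predicting $\mathbf{x}$, but easier to implement in a numerically stable way:
$$\mathcal{L}^{MC}_\infty(\mathbf{x}, w, \gamma) = \frac{1}{2}\, \gamma'(t)\, w(\gamma(t))\, ||\boldsymbol{\epsilon} - \hat{\boldsymbol{\epsilon}}_\theta(\mathbf{z}_t; \gamma(t))||_2^2, \tag{26}$$
with $\gamma(t) = -\log \text{SNR}(t)$, and $\mathbf{z}_t = \alpha_t \mathbf{x} + \sigma_t \boldsymbol{\epsilon}$, with $\boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I})$, $t \sim U(0,1)$. For the models presented in this paper, we further use $w(\cdot) = 1$ as corresponding to the (unweighted) VLB. The resulting VLB estimate can then be optimized using stochastic gradient descent as usual. The noise schedule $\gamma(t)$ influences the variance of our Monte Carlo estimator, which is why we jointly optimize it with the rest of the model as described in Section 5.
Here, we also found it helpful to sample time $t$ using a low-discrepancy sampler. When processing a minibatch of $k$ examples $\mathbf{x}^i$, $i \in \{1, \ldots, k\}$, we require $k$ timesteps $t^i$ sampled from a uniform distribution. Instead of sampling these timesteps independently, we sample a single uniform random number $u_0 \sim U[0,1]$ and then set $t^i = \text{mod}(u_0 + i/k, 1)$. Each $t^i$ now has the correct uniform marginal distribution, but the minibatch of timesteps covers the space in $[0,1]$ more equally than when sampling independently, which we find to reduce the variance in our VLB estimate.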
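The low-discrepancy time sampler is simple enough to sketch directly. The function name and the fixed offset used in the example are assumptions for illustration; in training, the offset would be drawn fresh for every minibatch.

```python
import random

def low_discrepancy_times(k, u0=None, rng=random):
    """Return k timesteps t_i = (u0 + i/k) mod 1 for i = 1..k.

    A single uniform draw u0 is shared across the minibatch, so each t_i is
    marginally uniform on [0, 1) while the batch is evenly spread.
    """
    if u0 is None:
        u0 = rng.random()
    return [(u0 + i / k) % 1.0 for i in range(1, k + 1)]

# Example with a fixed offset so the result is reproducible.
times = low_discrepancy_times(8, u0=0.3)
sorted_times = sorted(times)
gaps = [b - a for a, b in zip(sorted_times, sorted_times[1:])]
```

After sorting, consecutive timesteps are exactly $1/k$ apart, so every minibatch covers the unit interval evenly regardless of the random offset.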
So far, we have not discussed how to select the signal-to-noise ratio function, or noise schedule, that governs our generative model. In previous work, $\text{SNR}(t)$ has a fixed form (see Appendix D, Fig. 4(a)). Here, we propose learning this schedule jointly with our denoising model $\hat{\mathbf{x}}_\theta$. We parameterize $\text{SNR}(t) = \exp(-\gamma_\eta(t))$ using a monotonically increasing neural network $\gamma_\eta(t)$, details of which are given in Appendix D.
In the discrete-time case, we then learn parameters $\eta$ by maximizing the VLB, together with our other model parameters. We find this to be especially beneficial when $T$ is small, as we show in Section 7.
The continuous-time case is different: As we showed in Section 4.3.1, the continuous-time diffusion loss is invariant to $\text{SNR}(t)$, except for its endpoints $\text{SNR}_{\max}$ and $\text{SNR}_{\min}$. For this case, we therefore only optimize the VLB with respect to $\text{SNR}_{\max}$ and $\text{SNR}_{\min}$, and not the parameters $\eta$ of the schedule interpolating between them.
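Although this interpolating function does not impact the value of the continuous-time diffusion loss, it does impact the variance of our stochastic estimate of it, given in Equation 26. We therefore propose to learn $\eta$ by minimizing the variance, which we do by performing stochastic gradient descent on our squared diffusion loss $\mathcal{L}^{MC}_\infty(\mathbf{x}, w, \gamma_\eta)^2$. We have that $\mathbb{E}_{t,\boldsymbol{\epsilon}}[\mathcal{L}^{MC}_\infty(\mathbf{x}, w, \gamma_\eta)^2] = \mathcal{L}_\infty(\mathbf{x}, w)^2 + \text{Var}_{t,\boldsymbol{\epsilon}}[\mathcal{L}^{MC}_\infty(\mathbf{x}, w, \gamma_\eta)]$, where the first part is independent of $\eta$, and hence that
$$\mathbb{E}_{t,\boldsymbol{\epsilon}}\left[\nabla_\eta \mathcal{L}^{MC}_\infty(\mathbf{x}, w, \gamma_\eta)^2\right] = \nabla_\eta \text{Var}_{t,\boldsymbol{\epsilon}}\left[\mathcal{L}^{MC}_\infty(\mathbf{x}, w, \gamma_\eta)\right]. \tag{27}$$
We can calculate this gradient with negligible computational overhead as a by-product of calculating the gradient of the VLB, details of which are given in Appendix D.
This strategy of minimizing the variance of our diffusion loss estimate remains valid for weighted diffusion losses, $\mathcal{L}_\infty(\mathbf{x}, w)$, not corresponding to the VLB, and we therefore expect it to be useful beyond the goal of optimizing for likelihood that we consider in this paper.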
Prior work on diffusion models has mainly focused on the perceptual quality of generated samples, which emphasizes coarse-scale patterns and global consistency of generated images. Here, we optimize for likelihood, which is sensitive to fine-scale details and exact values of individual pixels. Since our reconstruction model $p(\mathbf{x}|\mathbf{z}_0)$ given in Equation 9 is weak, the burden of modeling these fine-scale details falls on our denoising diffusion model $\hat{\mathbf{x}}_\theta$. In initial experiments, we found that the denoising model had a hard time accurately modeling these details. At larger noise levels, the latents $\mathbf{z}_t$ follow a smooth distribution due to the added Gaussian noise, but at the smallest noise levels the discrete nature of 8-bit image data leads to sharply peaked marginal distributions $q(\mathbf{z}_t)$.
To capture the fine-scale details of the data, we propose adding a set of Fourier features to the input of our denoising model $\hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)$. Such Fourier features consist of a linear projection of the original data onto a set of periodic basis functions with high frequency, which allows the network to more easily model high-frequency details of the data. Previous work (Tancik et al., 2020) has used these features for input coordinates to model high-frequency details across the spatial dimension, and for time embeddings to condition denoising networks over the temporal dimension (Song et al., 2021b). Here we apply it to color channels for single pixels, in order to model fine distributional details at the level of each scalar input.
Concretely, let $z_{i,j,k}$ be the scalar value in the $k$-th channel in the $(i,j)$ spatial position of network input $\mathbf{z}_t$. We then add additional channels to the input of the denoising model of the form
$$f^n_{i,j,k} = \sin(z_{i,j,k}\, 2^n \pi), \quad \text{and} \quad g^n_{i,j,k} = \cos(z_{i,j,k}\, 2^n \pi), \tag{28}$$
where $n$ ranges over a fixed set of integers. These additional channels are then concatenated to $\mathbf{z}_t$ before being used as input in a standard convolutional denoising model similar to that used by Ho et al. (2020). We find that the presence of these high-frequency features allows our network to learn with much higher values of $\text{SNR}_{\max}$, or conversely lower noise levels $\sigma_0^2$, than is otherwise optimal. This leads to large improvements in likelihood as demonstrated in Section 7 and Figure 4. We did not observe such improvements when incorporating Fourier features into autoregressive models.
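A minimal sketch of the Fourier feature construction for the channel values of a single pixel is given below. The frequency exponents `(6, 7)` are placeholders chosen for illustration, not the values used in the experiments, and the function name is likewise an assumption.

```python
import math

def fourier_features(z, exponents=(6, 7)):
    """Concatenate sin/cos features of each scalar input to the input itself.

    z is a flat list of scalar channel values; for every exponent n the
    features sin(z * 2^n * pi) and cos(z * 2^n * pi) are appended.
    The default exponents are hypothetical.
    """
    features = []
    for n in exponents:
        freq = (2 ** n) * math.pi
        features.extend(math.sin(v * freq) for v in z)
        features.extend(math.cos(v * freq) for v in z)
    return list(z) + features

pixel = [0.0, 0.5, -0.25]          # three channel values of one pixel
augmented = fourier_features(pixel)
```

With three input channels and two exponents, the augmented input has $3 + 2 \cdot 2 \cdot 3 = 15$ channels. In a convolutional model the same map would be applied independently at every spatial position before the first layer.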
We demonstrate our proposed class of diffusion models, which we call Variational Diffusion Models (VDMs), on the CIFAR-10 (Krizhevsky et al., 2009) dataset, and the downsampled ImageNet (Van Oord et al., 2016; Deng et al., 2009) dataset, where we focus on maximizing likelihood. The score models we use closely follow Ho et al. (2020), except that they process the data solely at the original resolution, without any internal downsampling or upsampling. Our score models are also deeper than those used by others in the literature. All reported models incorporate Fourier features (Section 6) as well as a learnable diffusion specification (Section 5). For our result with data augmentation we used random flips, 90-degree rotations, and color channel swapping. Complete details on our model specifications can be found in Appendix A.
| Model (bits per dim on test set) | Type | CIFAR-10, no data aug. | CIFAR-10, data aug. | ImageNet 32x32 | ImageNet 64x64 |
|---|---|---|---|---|---|
| ResNet VAE with IAF (Kingma et al., 2016) | VAE | 3.11 | | | |
| Very Deep VAE (Child, 2020) | VAE | 2.87 | | 3.80 | 3.52 |
| NVAE (Vahdat and Kautz, 2020) | VAE | 2.91 | | 3.92 | |
| CR-NVAE (Sinha and Dieng, 2021) | VAE | | | | |
| Glow (Kingma and Dhariwal, 2018) | Flow | | | 4.09 | 3.81 |
| Flow++ (Ho et al., 2019a) | Flow | 3.08 | | 3.86 | 3.69 |
| PixelCNN (Van Oord et al., 2016) | AR | 3.03 | | 3.83 | 3.57 |
| PixelCNN++ (Salimans et al., 2017) | AR | 2.92 | | | |
| Image Transformer (Parmar et al., 2018) | AR | 2.90 | | 3.77 | |
| SPN (Menick and Kalchbrenner, 2018) | AR | | | | 3.52 |
| Sparse Transformer (Child et al., 2019) | AR | 2.80 | | | 3.44 |
| Routing Transformer (Roy et al., 2021) | AR | | | | 3.43 |
| Sparse Transformer + DistAug (Jun et al., 2020) | AR | | | | |
| DDPM (Ho et al., 2020) | Diff | | | | |
| Score SDE (Song et al., 2021b) | Diff | 2.99 | | | |
| Improved DDPM (Nichol and Dhariwal, 2021) | Diff | 2.94 | | | 3.54 |
| LSGM (Vahdat et al., 2021) | Diff | 2.87 | | | |
| ScoreFlow (Song et al., 2021a) (variational bound) | Diff | 2.87 | | 3.84 | |
| ScoreFlow (Song et al., 2021a) (cont. norm. flow) | Diff | 2.74 | | 3.76 | |
| VDM (ours) (variational bound) | Diff | 2.65 | | 3.72 | 3.40 |
Table 1 shows our results on modeling the CIFAR-10 dataset, and the downsampled ImageNet dataset. We establish a new state-of-the-art in terms of test set likelihood on all the benchmarks without data augmentation, by a significant margin. Our model for CIFAR-10 without data augmentation surpasses the previous best result of 2.80 bits per dimension about 10x faster than it takes the Sparse Transformer to reach this, in wall-clock time on equivalent hardware.
On CIFAR-10 with data augmentation we tie the concurrent work by Sinha and Dieng (2021), which obtains impressive results by applying data augmentation and a consistency regularizer to VAEs. The data augmentation we considered is relatively simple compared to their work as we only use permutations of the data (flips, 90-degree rotations, channel shuffling) and not augmentations that change the data itself (zoom, non-integer shift, more general rotations). Training our model with the augmentation procedure used by Sinha and Dieng (2021) is an interesting direction for future work.
Our CIFAR-10 model, whose hyper-parameters were tuned for likelihood, results in an FID (perceptual quality) score of 7.41. This would have been state-of-the-art until recently, but is worse than recent diffusion models that specifically target FID scores (Nichol and Dhariwal, 2021; Song et al., 2021b; Ho et al., 2020). By instead using a weighted diffusion loss, with the weighting function used by Ho et al. (2020) and described in Appendix F, our FID score improves to 4.0. We did not pursue further tuning of the model to improve FID instead of likelihood.
Next, we investigate the relative importance of our contributions. In Table 4 we compare our discrete-time and continuous-time specifications of the diffusion model: When evaluating our model with a small number of steps, our discretely trained models perform better by learning the diffusion schedule to optimize the VLB. However, as argued theoretically in Section 4.2, we find experimentally that more steps indeed give better likelihood. When $T$ grows large, our continuously trained model performs best, helped by training its diffusion schedule to minimize variance instead.
Minimizing the variance also helps the continuous-time model to train faster, as shown in Figure 4. This effect is further examined in Table 4(b), where we find dramatic variance reductions compared to our baselines in continuous time. Figure 4(a) shows how this effect is achieved: Compared to the other schedules, our learned schedule spends much more time in the high $\text{SNR}(t)$ / low $\sigma_t^2$ range.
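| $T_{train}$ | $T_{eval}$ | BPD | Bits-Back Net BPD |
|---|---|---|---|
| 10 | 10 | 4.31 | |
| 100 | 100 | 2.84 | |
| 250 | 250 | 2.73 | |
| 500 | 500 | 2.68 | |
| 1000 | 1000 | 2.67 | |
| 10000 | 10000 | 2.66 | |
| $\infty$ | 10 | 7.54 | 7.54 |
| $\infty$ | 100 | 2.90 | 2.91 |
| $\infty$ | 250 | 2.74 | 2.76 |
| $\infty$ | 500 | 2.69 | 2.72 |
| $\infty$ | 1000 | 2.67 | 2.72 |
| $\infty$ | 10000 | 2.65 | |
| $\infty$ | $\infty$ | 2.65 | |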
[Figure: Fourier feature ablation on CIFAR-10]

In Figure 4 we further show training curves for our model including and excluding the Fourier features proposed in Section 6: With Fourier features enabled our model achieves much better likelihood. For comparison we also implemented Fourier features in a PixelCNN++ model (Salimans et al., 2017), where we do not see a benefit.
For a fixed number of evaluation timesteps $T_{eval}$, our diffusion model in discrete time is a hierarchical latent variable model that can be turned into a lossless compression algorithm using bits-back coding (Hinton and Van Camp, 1993). Bits-back coding encodes a latent and data together, with the latent sampled from the approximate posterior using auxiliary random bits. The net coding cost of bits-back coding is given by subtracting the number of bits needed to sample the latent from the number of bits needed to encode the latent and data using the reverse process, so the negative VLB of our discrete-time model is the theoretical expected coding cost for bits-back coding.
As a proof of concept for practical lossless compression using our model, Table 4 reports net codelengths on the CIFAR-10 test set for various settings of $T_{eval}$ using BB-ANS (Townsend et al., 2018), a practical implementation of bits-back coding based on asymmetric numeral systems (Duda, 2009). Details of our implementation are given in Appendix I. We achieve state-of-the-art net codelengths, showing that our model can be used as the basis of a practical lossless compression algorithm. However, for large $T_{eval}$ a gap remains with the theoretically optimal codelength corresponding to the negative VLB, and compression becomes computationally expensive due to the large number of neural network forward passes required. Closing this gap with more efficient implementations of bits-back coding suitable for very deep models is an interesting avenue for future work.
We presented state-of-the-art results on modeling the density of natural images using a new class of diffusion models that incorporates a learnable diffusion specification, Fourier features for fine-scale modeling, as well as other architectural innovations. In addition, we obtained new theoretical insight into likelihood-based generative modeling with diffusion models, showing a surprising invariance of the VLB to the forward-time diffusion process in continuous time, as well as an equivalence between various diffusion processes from the literature previously thought to be different.
We thank Yang Song, Kevin Murphy and Mohammad Norouzi for feedback on drafts of this paper.
Deng et al. (2009). ImageNet: A large-scale hierarchical image database. In 2009 IEEE Conference on Computer Vision and Pattern Recognition, pages 248–255. IEEE, 2009.
Hinton and Van Camp (1993). Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pages 5–13, 1993.
Rezende et al. (2014). Stochastic backpropagation and approximate inference in deep generative models. In International Conference on Machine Learning, pages 1278–1286, 2014.
Sohl-Dickstein et al. (2015). Deep unsupervised learning using nonequilibrium thermodynamics. In International Conference on Machine Learning, pages 2256–2265, 2015.
Van Oord et al. (2016). Pixel recurrent neural networks. In International Conference on Machine Learning, pages 1747–1756, 2016.
Vincent (2011). A connection between score matching and denoising autoencoders. Neural Computation, 23(7):1661–1674, 2011.
In this section we provide details on the exact setup for each of our experiments. In Section A.1 we describe the choices in common to each of our experiments. Hyperparameters specific to the individual experiments are given in Section A.2. We are currently working towards open sourcing our code.
Our denoising models are parameterized in terms of , where , and where $\gamma(t)$ is the negative log signal-to-noise ratio, i.e. $\gamma(t) = -\log \text{SNR}(t)$.
Our models closely follow the architecture used by Ho et al. [2020], which is based on a U-Net type neural net [Ronneberger et al., 2015] that maps from the input to output of the same dimension. As compared to their publicly available code at https://github.com/hojonathanho/diffusion, our implementation differs in the following ways:
Our networks don't perform any internal downsampling or upsampling: we process all the data at the original input resolution.
Instead of taking time $t$ as input to the denoising model, we use $\gamma(t)$, which we rescale to have approximately the same range as $t$ before using it to form 'time' embeddings in the same way as Ho et al. [2020].
Our models calculate Fourier features on the input data $\mathbf{z}_t$ as discussed in Section 6, which are then concatenated to $\mathbf{z}_t$ before being fed to the U-Net.
Apart from the middle attention block that connects the upward and downward branches of the U-Net, we remove all other attention blocks from the model. We found that these attention blocks made it more likely for the model to overfit to the training set.
We use the Adam optimizer with a learning rate of and exponential decay rates of . We found that higher values for resulted in training instabilities.
For evaluation, we use an exponential moving average of our parameters, calculated with an exponential decay rate of .
We implemented our evaluation of the VLB following Equation 58, which we find the easiest to implement in a numerically stable way. This is equivalent to the VLB expressions in terms of $\hat{\mathbf{x}}_\theta$ that we find more intuitive to reason about and which therefore form the core of our presentation here. We use the variance-preserving diffusion process with $\sigma_t^2 = \text{sigmoid}(\gamma(t))$, with $\text{sigmoid}(\cdot)$ the sigmoid function. Contrary to earlier works that used a fixed noise schedule, we learn the $\gamma$ function using the approach described in Sections 5 and D.
We regularly evaluate the variational bound on the likelihood on the validation set and find that our models do not overfit during training, using the current settings. We therefore do not use early stopping and instead allow the network to be optimized for 10 million parameter updates for CIFAR-10, and for 2 million updates for ImageNet, before obtaining the test set numbers reported in this paper. It looks like our models keep improving even after this number of updates, in terms of likelihood, but we did not explore this systematically due to resource constraints.
All of our models are trained on TPUv3 hardware (see https://cloud.google.com/tpu) using data parallelism. We also evaluated our trained models using CPU and GPU to check for robustness of our reported numbers to possible rounding errors. We found only very small differences when evaluating on these other hardware platforms.
Our model for CIFAR-10 with no data augmentation uses a U-Net of depth 32, consisting of 32 ResNet blocks in the forward direction and 32 ResNet blocks in the reverse direction, with a single attention layer and two additional ResNet blocks in the middle. We keep the number of channels constant throughout at 128. This model was trained on 8 TPUv3 chips, with a total batch size of 128 examples. Reaching a test-set BPD of 2.65 after 10 million updates takes 9 days, although our model already surpasses sparse transformers (the previous state-of-the-art) of 2.80 BPD after only hours of training.
For CIFAR-10 with data augmentation we used random flips, 90-degree rotations, and color channel swapping, which were previously shown to help for density estimation by Jun et al. [2020]. Each of the three augmentations was independently given a 50% probability of being applied to each example, which means that 1 in 8 training examples was not augmented at all. For this experiment, we doubled the number of channels in our model to 256, and decreased the dropout rate from to . Since overfitting was less of a problem with data augmentation, we add back the attention blocks after each ResNet block, following Ho et al. [2020]. We also experimented with conditioning our model on an additional binary feature that indicates whether or not the example was augmented, which can be seen as a simplified version of the augmentation conditioning proposed by Jun et al. [2020]. Conditioning made almost no difference to our results, which may be explained by the relatively large fraction (1/8) of clean data fed to our model during training. We trained our model for slightly over a week on 128 TPUv3 chips to obtain the reported result.
Our model for 32x32 ImageNet looks similar to that for CIFAR-10 without data augmentation, with a U-Net depth of 32, but uses double the number of channels at 256. It is trained using data parallelism on 32 TPUv3 chips, with a total batch size of 512.
Our model for 64x64 ImageNet uses double the depth at 64 ResNet layers in both the forward and backward direction in the U-Net. It also uses a constant number of channels of 256. This model is trained on 128 TPUv3 chips at a total batch size of 512 examples. The model passes the
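Similar to [Sohl-Dickstein et al., 2015], we decompose the negative variational lower bound (VLB) as:
$$-\text{VLB}(\mathbf{x}) = \mathcal{L}_T(\mathbf{x}) + D_{KL}(q(\mathbf{z}_1|\mathbf{x})\,||\,p(\mathbf{z}_1)) + \mathbb{E}_{q(\mathbf{z}_0|\mathbf{x})}\left[-\log p(\mathbf{x}|\mathbf{z}_0)\right], \tag{29}$$
$$\mathcal{L}_T(\mathbf{x}) = \sum_{i=1}^{T} \mathbb{E}_{q(\mathbf{z}_{t(i)}|\mathbf{x})}\, D_{KL}\left[q(\mathbf{z}_{s(i)}|\mathbf{z}_{t(i)}, \mathbf{x})\,||\,p(\mathbf{z}_{s(i)}|\mathbf{z}_{t(i)})\right], \tag{30}$$
where $s(i) = (i-1)/T$ and $t(i) = i/T$. The second and third right-hand side terms of Equation 29 can be evaluated and optimized using standard techniques. We will now derive an estimator for $\mathcal{L}_T(\mathbf{x})$, the remaining and more challenging term. We will first derive an expression of $D_{KL}(q(\mathbf{z}_s|\mathbf{z}_t, \mathbf{x})\,||\,p(\mathbf{z}_s|\mathbf{z}_t))$.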
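Recall that $p(\mathbf{z}_s|\mathbf{z}_t) = q(\mathbf{z}_s|\mathbf{z}_t, \mathbf{x} = \hat{\mathbf{x}}_\theta(\mathbf{z}_t; t))$, and thus $q(\mathbf{z}_s|\mathbf{z}_t, \mathbf{x}) = \mathcal{N}(\mathbf{z}_s; \boldsymbol{\mu}_Q(\mathbf{z}_t, \mathbf{x}; s, t), \sigma_Q^2(s,t)\mathbf{I})$ and $p(\mathbf{z}_s|\mathbf{z}_t) = \mathcal{N}(\mathbf{z}_s; \boldsymbol{\mu}_\theta(\mathbf{z}_t; s, t), \sigma_Q^2(s,t)\mathbf{I})$, with
$$\boldsymbol{\mu}_Q(\mathbf{z}_t, \mathbf{x}; s, t) = \frac{\alpha_{t|s}\sigma_s^2}{\sigma_t^2}\mathbf{z}_t + \frac{\alpha_s\sigma_{t|s}^2}{\sigma_t^2}\mathbf{x}, \tag{31}$$
$$\boldsymbol{\mu}_\theta(\mathbf{z}_t; s, t) = \frac{\alpha_{t|s}\sigma_s^2}{\sigma_t^2}\mathbf{z}_t + \frac{\alpha_s\sigma_{t|s}^2}{\sigma_t^2}\hat{\mathbf{x}}_\theta(\mathbf{z}_t; t), \tag{32}$$
$$\sigma_Q^2(s,t) = \sigma_{t|s}^2\sigma_s^2/\sigma_t^2. \tag{33}$$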
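Since $q(\mathbf{z}_s|\mathbf{z}_t, \mathbf{x})$ and $p(\mathbf{z}_s|\mathbf{z}_t)$ are Gaussians, their KL divergence is available in closed form as a function of their means and variances, which due to their equal variances simplifies as:
$$D_{KL}(q(\mathbf{z}_s|\mathbf{z}_t, \mathbf{x})\,||\,p(\mathbf{z}_s|\mathbf{z}_t)) = \frac{1}{2\sigma_Q^2(s,t)}\,||\boldsymbol{\mu}_Q - \boldsymbol{\mu}_\theta||_2^2 \tag{34}$$
$$= \frac{\sigma_t^2}{2\sigma_{t|s}^2\sigma_s^2}\,\frac{\alpha_s^2\sigma_{t|s}^4}{\sigma_t^4}\,||\mathbf{x} - \hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)||_2^2 \tag{35}$$
$$= \frac{1}{2\sigma_s^2}\,\frac{\alpha_s^2\sigma_{t|s}^2}{\sigma_t^2}\,||\mathbf{x} - \hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)||_2^2 \tag{36}$$
$$= \frac{1}{2\sigma_s^2}\,\frac{\alpha_s^2(\sigma_t^2 - \alpha_{t|s}^2\sigma_s^2)}{\sigma_t^2}\,||\mathbf{x} - \hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)||_2^2 \tag{37}$$
$$= \frac{1}{2}\,\frac{\alpha_s^2\sigma_t^2/\sigma_s^2 - \alpha_t^2}{\sigma_t^2}\,||\mathbf{x} - \hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)||_2^2 \tag{38}$$
$$= \frac{1}{2}\left(\frac{\alpha_s^2}{\sigma_s^2} - \frac{\alpha_t^2}{\sigma_t^2}\right)||\mathbf{x} - \hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)||_2^2 \tag{39}$$
$$= \frac{1}{2}\left(\text{SNR}(s) - \text{SNR}(t)\right)||\mathbf{x} - \hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)||_2^2. \tag{40}$$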
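The simplification of the KL divergence can be verified numerically for scalar data: the KL divergence between two Gaussians with equal variance, computed directly from the posterior means, should match half the difference in signal-to-noise ratio times the squared denoising error. The schedule values, the latent, and the denoiser output below are hypothetical numbers chosen for the check.

```python
import math

def kl_equal_variance(mean_q, mean_p, var):
    """KL divergence between two univariate Gaussians that share a variance."""
    return (mean_q - mean_p) ** 2 / (2.0 * var)

# Hypothetical variance-preserving schedule values at times s < t.
sigma2_s, sigma2_t = 0.2, 0.7
alpha_s, alpha_t = math.sqrt(1.0 - sigma2_s), math.sqrt(1.0 - sigma2_t)
alpha_ts = alpha_t / alpha_s
sigma2_ts = sigma2_t - alpha_ts ** 2 * sigma2_s
var_q = sigma2_ts * sigma2_s / sigma2_t

x, x_hat, z_t = 1.0, 0.6, 0.3  # data, denoiser output, latent (arbitrary)

# Means of q(z_s | z_t, x) and of the model p(z_s | z_t).
mean_q = (alpha_ts * sigma2_s / sigma2_t) * z_t + (alpha_s * sigma2_ts / sigma2_t) * x
mean_p = (alpha_ts * sigma2_s / sigma2_t) * z_t + (alpha_s * sigma2_ts / sigma2_t) * x_hat

kl_direct = kl_equal_variance(mean_q, mean_p, var_q)
snr_s, snr_t = alpha_s ** 2 / sigma2_s, alpha_t ** 2 / sigma2_t
kl_closed_form = 0.5 * (snr_s - snr_t) * (x - x_hat) ** 2
```

The two values agree to floating-point precision, and the result does not depend on the latent $\mathbf{z}_t$ beyond its effect on the denoiser output, since the $\mathbf{z}_t$ terms in the two means cancel.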
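Reparameterizing $\mathbf{z}_t \sim q(\mathbf{z}_t|\mathbf{x})$ as $\mathbf{z}_t = \alpha_t \mathbf{x} + \sigma_t \boldsymbol{\epsilon}$, where $\boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I})$, our diffusion loss becomes:
$$\mathcal{L}_T(\mathbf{x}) = \sum_{i=1}^{T} \mathbb{E}_{q(\mathbf{z}_t|\mathbf{x})}\left[D_{KL}(q(\mathbf{z}_s|\mathbf{z}_t, \mathbf{x})\,||\,p(\mathbf{z}_s|\mathbf{z}_t))\right] \tag{41}$$
$$= \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(0,\mathbf{I})}\left[\sum_{i=1}^{T} \frac{1}{2}\left(\text{SNR}(s) - \text{SNR}(t)\right)||\mathbf{x} - \hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)||_2^2\right]. \tag{42}$$
To avoid having to compute all $T$ terms when calculating the diffusion loss, we construct an unbiased estimator of $\mathcal{L}_T(\mathbf{x})$ using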