Your Classifier is Secretly an Energy Based Model and You Should Treat it Like One

12/06/2019 ∙ by Will Grathwohl, et al. ∙ Google, University of Toronto

We propose to reinterpret a standard discriminative classifier of p(y|x) as an energy based model for the joint distribution p(x,y). In this setting, the standard class probabilities can be easily computed as well as unnormalized values of p(x) and p(x|y). Within this framework, standard discriminative architectures may be used and the model can also be trained on unlabeled data. We demonstrate that energy based training of the joint distribution improves calibration, robustness, and out-of-distribution detection while also enabling our models to generate samples rivaling the quality of recent GAN approaches. We improve upon recently proposed techniques for scaling up the training of energy based models and present an approach which adds little overhead compared to standard classification training. Our approach is the first to achieve performance rivaling the state-of-the-art in both generative and discriminative learning within one hybrid model.


1 Introduction

Figure 1: Visualization of our method, JEM, which defines a joint EBM from classifier architectures.

For decades, research on generative models has been motivated by the promise that generative models can benefit downstream problems such as semi-supervised learning, imputation of missing data, and calibration of uncertainty (e.g., Chapelle et al. (2006); Dempster et al. (1977)). Yet, most recent research on deep generative models ignores these problems, and instead focuses on qualitative sample quality and log-likelihood on held-out validation sets.

Currently, there is a large performance gap between the strongest generative modeling approach to downstream tasks of interest and hand-tailored solutions for each specific problem. One potential explanation is that most downstream tasks are discriminative in nature and state-of-the-art generative models have diverged quite heavily from state-of-the-art discriminative architectures. Thus, even when trained solely as classifiers, the performance of generative models is far below the performance of the best discriminative models. Hence, the potential benefit from the generative component of the model is far outweighed by the decrease in discriminative performance. Recent work (Behrmann et al., 2018; Chen et al., 2019) attempts to improve the discriminative performance of generative models by leveraging invertible architectures, but when trained jointly as generative models, these methods still underperform their purely discriminative counterparts.

This paper advocates the use of energy based models (EBMs) to help realize the potential of generative models on downstream discriminative problems. While EBMs are currently challenging to work with, they fit more naturally within a discriminative framework than other generative models and facilitate the use of modern classifier architectures. Figure 1 illustrates an overview of the architecture, where the logits of a classifier are re-interpreted to define the joint density of data points and labels and the density of data points alone.

The contributions of this paper can be summarized as: 1) We present a novel and intuitive framework for joint modeling of labels and data. 2) Our models considerably outperform previous state-of-the-art hybrid models at both generative and discriminative modeling. 3) We show that the incorporation of generative modeling gives our models improved calibration, out-of-distribution detection, and adversarial robustness, performing on par with or better than hand-tailored methods for each task.

2 Energy Based Models

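Energy based models (LeCun et al., 2006) hinge on the observation that any probability density $p(\mathbf{x})$ for $\mathbf{x} \in \mathbb{R}^D$ can be expressed as

$$p_\theta(\mathbf{x}) = \frac{\exp(-E_\theta(\mathbf{x}))}{Z(\theta)}, \qquad (1)$$

where $E_\theta(\mathbf{x}): \mathbb{R}^D \to \mathbb{R}$, known as the energy function, maps each point to a scalar, and $Z(\theta) = \int_{\mathbf{x}} \exp(-E_\theta(\mathbf{x}))\, d\mathbf{x}$ is the normalizing constant (with respect to $\mathbf{x}$), also known as the partition function. Thus, one can parameterize an EBM using any function that takes $\mathbf{x}$ as the input and returns a scalar.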
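As a minimal sketch of Eq. (1), assuming a hypothetical one-dimensional domain discretized to a finite grid so that $Z(\theta)$ becomes a plain sum, the snippet below normalizes $\exp(-E_\theta(x))$ directly. The quadratic energy is illustrative only; for the high-dimensional inputs considered here this sum (or integral) is intractable.

```python
import math


def ebm_density(energy, points):
    """Normalize exp(-E(x)) over a finite set of points: Eq. (1) with Z(theta) as a sum."""
    unnormalized = [math.exp(-energy(x)) for x in points]
    z = sum(unnormalized)  # partition function Z(theta)
    return [u / z for u in unnormalized], z


# Hypothetical toy energy on an 11-point grid in [0, 1]: a quadratic bowl centred at 0.5.
grid = [i / 10 for i in range(11)]
probs, z = ebm_density(lambda x: 4.0 * (x - 0.5) ** 2, grid)

# Adding a constant to the energy rescales Z but leaves the normalized density unchanged.
shifted_probs, z_shifted = ebm_density(lambda x: 4.0 * (x - 0.5) ** 2 + 3.0, grid)
```

The lowest-energy point (x = 0.5) receives the highest probability, and only differences in energy matter once $Z(\theta)$ is divided out.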

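For most choices of $E_\theta$, one cannot compute or even reliably estimate $Z(\theta)$, which means estimating the normalized densities is intractable and standard maximum likelihood estimation of the parameters, $\theta$, is not straightforward. Thus, we must rely on other methods to train EBMs. We note that the derivative of the log-likelihood for a single example $\mathbf{x}$ with respect to $\theta$ can be expressed as

$$\frac{\partial \log p_\theta(\mathbf{x})}{\partial \theta} = \mathbb{E}_{p_\theta(\mathbf{x}')}\left[\frac{\partial E_\theta(\mathbf{x}')}{\partial \theta}\right] - \frac{\partial E_\theta(\mathbf{x})}{\partial \theta}, \qquad (2)$$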

where the expectation is over the model distribution. Unfortunately, we cannot easily draw samples from $p_\theta(\mathbf{x})$, so we must resort to MCMC to use this gradient estimator. This approach was used to train some of the earliest EBMs. For example, Restricted Boltzmann Machines (Hinton, 2002) were trained using a block Gibbs sampler to approximate the expectation in Eq. (2).
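To check Eq. (2) numerically, the sketch below assumes a hypothetical finite domain and a toy energy $E_\theta(x) = \theta x^2$, so the model expectation can be computed exactly and compared with a finite-difference derivative of $\log p_\theta(x)$. The function `eq2_gradient_mc` replaces the exact expectation with an average over samples from $p_\theta$; here these are drawn with `random.choices`, standing in for the MCMC samples a real EBM would require.

```python
import math
import random

GRID = [i / 10 - 1.0 for i in range(21)]  # hypothetical finite domain: 21 points in [-1, 1]


def energy(theta, x):
    return theta * x ** 2  # toy energy E_theta(x) = theta * x^2


def dE_dtheta(theta, x):
    return x ** 2


def model_weights(theta):
    return [math.exp(-energy(theta, xp)) for xp in GRID]  # unnormalized p_theta on GRID


def log_prob(theta, x):
    return -energy(theta, x) - math.log(sum(model_weights(theta)))


def eq2_gradient_exact(theta, x):
    """Eq. (2) with the expectation over p_theta computed exactly by enumerating GRID."""
    w = model_weights(theta)
    expectation = sum(wi * dE_dtheta(theta, xp) for wi, xp in zip(w, GRID)) / sum(w)
    return expectation - dE_dtheta(theta, x)


def eq2_gradient_mc(theta, x, n_samples, rng):
    """Same estimator, with the expectation replaced by an average over model samples."""
    samples = rng.choices(GRID, weights=model_weights(theta), k=n_samples)
    return sum(dE_dtheta(theta, s) for s in samples) / n_samples - dE_dtheta(theta, x)


theta, x = 2.0, 0.3
h = 1e-5
finite_diff = (log_prob(theta + h, x) - log_prob(theta - h, x)) / (2 * h)
exact = eq2_gradient_exact(theta, x)
mc = eq2_gradient_mc(theta, x, 20000, random.Random(0))
```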

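Despite a long period of little development, there has been recent work using this method to train large-scale EBMs on high-dimensional data, parameterized by deep neural networks (Nijkamp et al., 2019b, a; Du and Mordatch, 2019; Xie et al., 2016). These recent successes have approximated the expectation in Eq. (2) using a sampler based on Stochastic Gradient Langevin Dynamics (SGLD) (Welling and Teh, 2011), which draws samples following

$$\mathbf{x}_0 \sim p_0(\mathbf{x}), \qquad \mathbf{x}_{i+1} = \mathbf{x}_i - \frac{\alpha}{2} \frac{\partial E_\theta(\mathbf{x}_i)}{\partial \mathbf{x}_i} + \epsilon, \qquad \epsilon \sim \mathcal{N}(0, \alpha), \qquad (3)$$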

where $p_0(\mathbf{x})$ is typically a Uniform distribution over the input domain and the step-size $\alpha$ should be decayed following a polynomial schedule. In practice, the step-size $\alpha$ and the standard deviation of $\epsilon$ are often chosen separately, leading to a biased sampler which allows for faster training. See Appendix H.1 for further discussion of samplers for EBM training.
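A minimal sketch of the SGLD update in Eq. (3), assuming a toy energy $E(x) = x^2/2$ whose target density is a standard normal. The optional `noise_std` argument (our naming) decouples the noise scale from the step size, mimicking the biased variant described above; the fixed step size and step count are illustrative, not the settings used in this work.

```python
import math
import random


def sgld(grad_energy, x0, step_size, n_steps, noise_std=None, rng=random):
    """Eq. (3): x <- x - (alpha / 2) * dE/dx + eps, with eps ~ N(0, alpha).

    Passing noise_std uses N(0, noise_std^2) noise instead, i.e. the biased variant.
    """
    std = math.sqrt(step_size) if noise_std is None else noise_std
    x = x0
    for _ in range(n_steps):
        x = x - 0.5 * step_size * grad_energy(x) + rng.gauss(0.0, std)
    return x


# Toy energy E(x) = x^2 / 2 (standard normal target), so dE/dx = x.
# Chains are seeded from a uniform p_0 over [-3, 3].
rng = random.Random(0)
samples = [sgld(lambda x: x, rng.uniform(-3.0, 3.0), 0.05, 200, rng=rng) for _ in range(2000)]
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
```

With the unbiased noise scale, the empirical mean and variance should land near 0 and 1; the small residual bias in the variance comes from the finite step size.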

3 What your classifier is hiding

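In modern machine learning, a classification problem with $K$ classes is typically addressed using a parametric function, $f_\theta: \mathbb{R}^D \to \mathbb{R}^K$, which maps each data point $\mathbf{x} \in \mathbb{R}^D$ to $K$ real-valued numbers known as logits. These logits are used to parameterize a categorical distribution using the so-called Softmax transfer function:

$$p_\theta(y \mid \mathbf{x}) = \frac{\exp(f_\theta(\mathbf{x})[y])}{\sum_{y'} \exp(f_\theta(\mathbf{x})[y'])}, \qquad (4)$$

where $f_\theta(\mathbf{x})[y]$ indicates the $y^{\text{th}}$ index of $f_\theta(\mathbf{x})$, i.e., the logit corresponding to the $y^{\text{th}}$ class label.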

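Our key observation in this work is that one can slightly re-interpret the logits obtained from $f_\theta$ to define $p(\mathbf{x}, y)$ and $p(\mathbf{x})$ as well. Without changing $f_\theta$, one can re-use the logits to define an energy based model of the joint distribution of data points and labels via:

$$p_\theta(\mathbf{x}, y) = \frac{\exp(f_\theta(\mathbf{x})[y])}{Z(\theta)}, \qquad (5)$$

where $Z(\theta)$ is the unknown normalizing constant and $E_\theta(\mathbf{x}, y) = -f_\theta(\mathbf{x})[y]$.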

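By marginalizing out $y$, we obtain an unnormalized density model for $\mathbf{x}$ as well,

$$p_\theta(\mathbf{x}) = \sum_y p_\theta(\mathbf{x}, y) = \frac{\sum_y \exp(f_\theta(\mathbf{x})[y])}{Z(\theta)}. \qquad (6)$$

Notice now that the $\text{LogSumExp}(\cdot)$ of the logits of any classifier can be re-used to define the energy function at a data point $\mathbf{x}$ as

$$E_\theta(\mathbf{x}) = -\text{LogSumExp}_y(f_\theta(\mathbf{x})[y]) = -\log \sum_y \exp(f_\theta(\mathbf{x})[y]). \qquad (7)$$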

Unlike typical classifiers, where shifting the logits by an arbitrary scalar does not affect the model at all, in our framework, shifting the logits for a data point will affect $\log p_\theta(\mathbf{x})$. Thus, we are making use of the extra degree of freedom hidden within the logits to define the density function over input examples as well as the joint density among examples and labels. Finally, when we compute $p_\theta(y \mid \mathbf{x})$ via $p_\theta(\mathbf{x}, y) / p_\theta(\mathbf{x})$ by dividing Eq. (5) by Eq. (6), the normalizing constant cancels out, yielding the standard Softmax parameterization in Eq. (4). Thus, we have found a generative model hidden within every standard discriminative model! Since our approach proposes to reinterpret a classifier as a Joint Energy based Model, we refer to it throughout this work as JEM.
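The re-interpretation in Eqs. (4) to (7) can be sketched directly on a vector of logits. The example assumes hypothetical logits for $K = 3$ classes and checks two claims from the text: subtracting the unnormalized $\log p_\theta(\mathbf{x})$ from the unnormalized $\log p_\theta(\mathbf{x}, y)$ recovers the Softmax of Eq. (4), because the unknown $\log Z(\theta)$ cancels; and shifting all logits by a constant leaves $p_\theta(y \mid \mathbf{x})$ unchanged while moving $\log p_\theta(\mathbf{x})$ by that same constant.

```python
import math


def logsumexp(values):
    m = max(values)  # subtract the max for numerical stability
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_p_y_given_x(logits):
    """Eq. (4): log-Softmax of the logits."""
    lse = logsumexp(logits)
    return [v - lse for v in logits]


def log_p_xy_unnormalized(logits, y):
    """Eq. (5) up to the constant -log Z(theta): f(x)[y] = -E(x, y)."""
    return logits[y]


def log_p_x_unnormalized(logits):
    """Eqs. (6) and (7) up to -log Z(theta): LogSumExp_y f(x)[y] = -E(x)."""
    return logsumexp(logits)


logits = [2.0, -1.0, 0.5]             # hypothetical f_theta(x) for K = 3 classes
shifted = [v + 10.0 for v in logits]  # the same logits shifted by a constant
```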

4 Optimization

We now wish to take advantage of our new interpretation of classifier architectures to gain the benefits of generative models while retaining strong discriminative performance. Since our model's parameterization of $p(y \mid \mathbf{x})$ is normalized over $y$, it is simple to maximize its likelihood as in standard classifier training. Since our models for $p(\mathbf{x})$ and $p(\mathbf{x}, y)$ are unnormalized, maximizing their likelihood is not as easy. There are many ways we could train $f_\theta$ to maximize the likelihood of the data under this model. We could apply the gradient estimator of Equation 2 to the likelihood under the joint distribution of Equation 5. Using Equations 6 and 4, we can also factor the likelihood as

$$\log p_\theta(\mathbf{x}, y) = \log p_\theta(\mathbf{x}) + \log p_\theta(y \mid \mathbf{x}). \qquad (8)$$

The estimator of Equation 2 is biased when using an MCMC sampler with a finite number of steps. Given that the goal of our work is to incorporate EBM training into the standard classification setting, the distribution of interest is $p(y \mid \mathbf{x})$. For this reason, we propose to train using the factorization of Equation 8 to ensure this distribution is being optimized with an unbiased objective. We optimize $p(y \mid \mathbf{x})$ using standard cross-entropy and optimize $\log p(\mathbf{x})$ using Equation 2 with SGLD, where gradients are taken with respect to $\text{LogSumExp}_y(f_\theta(\mathbf{x})[y])$. We find alternative factorings of the likelihood lead to considerably worse performance, as can be seen in Section 5.1.
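The factorized objective can be sketched as a per-example surrogate loss. Assuming the logits for a real input and for an SGLD sample $\hat{\mathbf{x}}$ are already available, the cross-entropy term handles $\log p(y \mid \mathbf{x})$, and the difference $E_\theta(\mathbf{x}) - E_\theta(\hat{\mathbf{x}})$ is a quantity whose gradient with respect to $\theta$ (holding $\hat{\mathbf{x}}$ fixed) is a one-sample estimate of $-\partial \log p_\theta(\mathbf{x}) / \partial \theta$ from Eq. (2). The function name and argument layout are our own; in an autodiff framework the sample would also be detached from the graph before its logits are computed.

```python
import math


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def jem_surrogate_loss(logits_data, label, logits_sample):
    """Hypothetical per-example surrogate loss for the factorization in Eq. (8).

    logits_data: f_theta(x) for a real input x with class `label`.
    logits_sample: f_theta(x_hat) for an SGLD sample x_hat, treated as a constant input.
    """
    # -log p(y|x): standard cross-entropy, Eq. (4).
    cross_entropy = logsumexp(logits_data) - logits_data[label]
    # E(x) = -LogSumExp_y f(x)[y], Eq. (7). The theta-gradient of E(x) - E(x_hat)
    # is a one-sample estimate of -d log p(x) / d theta from Eq. (2).
    energy_data = -logsumexp(logits_data)
    energy_sample = -logsumexp(logits_sample)
    return cross_entropy + (energy_data - energy_sample)


loss = jem_surrogate_loss([0.0, 0.0], 0, [1.0, 1.0])
```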

Following Du and Mordatch (2019), we use persistent contrastive divergence (Tieleman, 2008) to estimate the expectation in the right-hand side of Equation 2, since it gives an order of magnitude savings in computation compared to seeding new chains at each iteration as in Nijkamp et al. (2019b). This comes at the cost of decreased training stability. These trade-offs are discussed in Appendix H.2.
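Persistent contrastive divergence keeps SGLD chains alive across training iterations instead of restarting them. The sketch below is a hypothetical replay buffer for a one-dimensional toy energy: each call pulls a batch of stored chain states, occasionally reinitializes one from the uniform $p_0$, runs a few steps of Eq. (3), and writes the results back. The class name, the reinitialization probability, and the step counts are assumptions for illustration; the settings actually used are described in the appendices.

```python
import math
import random


class ReplayBuffer:
    """Hypothetical store of persistent SGLD chain states for PCD (1-D toy inputs)."""

    def __init__(self, size, low, high, reinit_prob, rng):
        self.low, self.high = low, high
        self.reinit_prob = reinit_prob
        self.rng = rng
        # Chains start from the uniform initial distribution p_0 of Eq. (3).
        self.samples = [rng.uniform(low, high) for _ in range(size)]

    def sample_chains(self, grad_energy, batch_size, step_size, n_steps):
        """Continue batch_size stored chains by n_steps SGLD steps and return their new states."""
        batch = []
        for i in self.rng.sample(range(len(self.samples)), batch_size):
            x = self.samples[i]
            if self.rng.random() < self.reinit_prob:
                x = self.rng.uniform(self.low, self.high)  # occasionally restart from p_0
            for _ in range(n_steps):
                noise = self.rng.gauss(0.0, math.sqrt(step_size))
                x = x - 0.5 * step_size * grad_energy(x) + noise
            self.samples[i] = x  # persist the chain for the next training iteration
            batch.append(x)
        return batch


buffer = ReplayBuffer(size=100, low=-3.0, high=3.0, reinit_prob=0.05, rng=random.Random(0))
# Toy energy E(x) = x^2 / 2, so dE/dx = x.
batch = buffer.sample_chains(lambda x: x, batch_size=16, step_size=0.05, n_steps=20)
```

Because each chain resumes where it left off, only a few SGLD steps per iteration are needed, which is the source of the computational savings mentioned above.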

5 Applications

We completed a thorough empirical investigation to demonstrate the benefits of JEM over standard classifiers. First, we achieved performance rivaling the state of the art in both discriminative and generative modeling. Even more interesting, we observed a number of benefits related to the practical application of discriminative models including improved uncertainty quantification, out-of-distribution detection, and robustness to adversarial examples. Generative models have been long-expected to provide these benefits but have never been demonstrated to do so at this scale.

All architectures used are based on Wide Residual Networks (Zagoruyko and Komodakis, 2016) where we have removed batch-normalization¹ to ensure that our models' outputs are deterministic functions of the input. This slightly increases classification error of a WRN-28-10 from to on CIFAR10 and from to on SVHN.

¹ This was done to remove sources of stochasticity in early experiments. Since then we have been able to successfully train Joint-EBMs with Batch Normalization and other forms of stochastic regularization (such as dropout) without issue. We leave the incorporation of these methods to future work.

All models were trained in the same way with the same hyper-parameters which were tuned on CIFAR10. Intriguingly, the SGLD sampler parameters found here generalized well across datasets and model architectures. All models are trained on a single GPU in approximately 36 hours. Full experimental details can be found in Appendix A.

Class   | Model         | Accuracy % | IS   | FID
Hybrid  | Residual Flow | 70.3       | 3.6  | 46.4
Hybrid  | Glow          | 67.6       | 3.92 | 48.9
Hybrid  | IGEBM         | 49.1       | 8.3  |
Hybrid  | JEM factored  | 30.1       | 6.36 | 61.8
Hybrid  | JEM (Ours)    |            |      | 38.4
Disc.   | Wide-Resnet   | 95.8       | N/A  | N/A
Gen.    | SNGAN         | N/A        | 8.59 | 25.5
Gen.    | NCSN          | N/A        | 8.91 | 25.32

Table 1: CIFAR10 hybrid modeling results. Residual Flow (Chen et al., 2019), Glow (Kingma and Dhariwal, 2018), IGEBM (Du and Mordatch, 2019), SNGAN (Miyato et al., 2018), NCSN (Song and Ermon, 2019).

Figure 2: CIFAR10 class-conditional samples.

5.1 Hybrid modeling

Figure 3: Class-conditional samples (panels: SVHN, CIFAR100).

First, we show that a given classifier architecture can be trained as an EBM to achieve competitive performance as both a classifier and a generative model. We train JEM on CIFAR10, SVHN, and CIFAR100 and compare against other hybrid models as well as standalone generative and discriminative models. We find JEM performs near the state of the art in both tasks simultaneously, outperforming other hybrid models (Table 1).

Given that we cannot compute normalized likelihoods, we present inception scores (IS) (Salimans et al., 2016) and Fréchet Inception Distance (FID) (Heusel et al., 2017) as a proxy for this quantity. We find that JEM is competitive with SOTA generative models on these metrics. These metrics are not commonly reported on CIFAR100 and SVHN, so we present accuracy and qualitative samples on these datasets. Our models achieve 96.7% and 72.2% accuracy on SVHN and CIFAR100, respectively. Samples from JEM can be seen in Figures 2, 3 and in Appendix C.

JEM is trained to maximize the likelihood factorization shown in Eq. 8. This was to ensure that no bias is added into our estimate of $\log p(y \mid \mathbf{x})$, which can be computed exactly in our setup. Prior work (Du and Mordatch, 2019; Xie et al., 2016) proposes to factorize the objective as $\log p(\mathbf{x}, y) = \log p(\mathbf{x} \mid y) + \log p(y)$. In these works, each $p(\mathbf{x} \mid y)$ is a separate EBM with a distinct, unknown normalizing constant, meaning that their model cannot be used to compute $p(y \mid \mathbf{x})$ or $p(\mathbf{x})$. This explains why the model of Du and Mordatch (2019) (we will refer to this model as IGEBM) is not a competitive classifier. As an ablation, we trained JEM to maximize this objective and found a considerable decrease in discriminative performance (see Table 1, row 4).

5.2 Calibration

Figure 4: CIFAR100 calibration results (axes: confidence vs. accuracy). ECE = Expected Calibration Error (Guo et al., 2017), see Appendix E.1.

A classifier is considered calibrated if its predictive confidence, $\max_y p(y \mid \mathbf{x})$, aligns with its misclassification rate. Thus, when a calibrated classifier predicts label $y$ with confidence $c$, it should have probability $c$ of being correct. This is an important property for a model deployed in real-world scenarios where outputting an incorrect decision can have catastrophic consequences. For example, the classifier's confidence can be used to decide when to output a prediction or defer to a human. Here, a well-calibrated but less accurate classifier can be considerably more useful than a more accurate but poorly calibrated model.

While classifiers have grown more accurate in recent years, they have also grown considerably less calibrated (Guo et al., 2017). Contrary to this behavior, we find that JEM notably improves calibration while retaining high accuracy.

We focus on CIFAR100 since SOTA classifiers achieve approximately accuracy. We train JEM on this dataset and compare to a baseline of the same architecture without EBM training. Our baseline model achieves accuracy and JEM achieves 72.2% (for reference, a ResNet-110 achieves accuracy (Zagoruyko and Komodakis, 2016)). We find the baseline model is very poorly calibrated, outputting highly over-confident predictions. Conversely, we find JEM produces a nearly perfectly calibrated classifier when measured with Expected Calibration Error (see Appendix E.1). Compared to other calibration methods such as Platt scaling (Guo et al., 2017), JEM requires no additional training data. Results can be seen in Figure 4 and additional results can be found in Appendix E.2.
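Expected Calibration Error, used in Figure 4, can be sketched as follows, assuming equal-width confidence bins (the number of bins is our choice here; Appendix E.1 gives the exact setup). Each bin contributes the gap between its accuracy and its mean confidence, weighted by the fraction of examples it contains.

```python
def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE = sum_b (|B_b| / n) * |acc(B_b) - conf(B_b)| over equal-width confidence bins."""
    n = len(confidences)
    bins = [[] for _ in range(n_bins)]
    for c, ok in zip(confidences, correct):
        b = min(int(c * n_bins), n_bins - 1)  # a confidence of exactly 1.0 goes in the last bin
        bins[b].append((c, ok))
    ece = 0.0
    for members in bins:
        if not members:
            continue
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(1 for _, ok in members if ok) / len(members)
        ece += len(members) / n * abs(accuracy - avg_conf)
    return ece


calibrated_ece = expected_calibration_error([0.75] * 4, [True, True, True, False])
overconfident_ece = expected_calibration_error([0.95] * 4, [True, False, True, False])
```

A classifier that is right 3 times out of 4 at confidence 0.75 has zero ECE, while one that is right half the time at confidence 0.95 has an ECE of 0.45, the over-confident pattern described for the baseline.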

5.3 Out-Of-Distribution Detection

In general, out-of-distribution (OOD) detection is a binary classification problem, where the model is required to produce a score $s_\theta(\mathbf{x}) \in \mathbb{R}$, where $\mathbf{x} \in \mathbb{R}^D$ is the query and $\theta$ is the set of learnable parameters. We desire that the scores for in-distribution examples are higher than those for out-of-distribution examples. Typically for evaluation, threshold-free metrics are used, such as the area under the receiver-operating characteristic curve (AUROC) (Hendrycks and Gimpel, 2016). There exist a number of distinct OOD detection approaches to which JEM can be applied. We expand on them below. Further results and experimental details can be found in Appendix F.2.
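AUROC needs no threshold because it equals the probability that a randomly chosen in-distribution example receives a higher score than a randomly chosen OOD example (ties count one half). A minimal pairwise sketch, quadratic in the number of examples and therefore only suitable for illustration:

```python
def auroc(in_scores, out_scores):
    """P(score of a random in-distribution input > score of a random OOD input), ties count 1/2."""
    wins = 0.0
    for s_in in in_scores:
        for s_out in out_scores:
            if s_in > s_out:
                wins += 1.0
            elif s_in == s_out:
                wins += 0.5
    return wins / (len(in_scores) * len(out_scores))
```

A score that perfectly separates the two sets gives 1.0, a perfectly inverted score gives 0.0, and an uninformative one gives about 0.5.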

5.3.1 Input Density

A natural approach to OOD detection is to fit a density model on the data and consider examples with low likelihood to be OOD. While intuitive, this approach is currently not competitive on high-dimensional data. Nalisnick et al. (2018) showed that tractable deep generative models such as Kingma and Dhariwal (2018) and Salimans et al. (2017) can assign higher densities to OOD examples than in-distribution examples. Further work (Nalisnick et al., 2019) shows examples where the densities of an OOD dataset are completely indistinguishable from the in-distribution set, e.g., see Table 2, column 1. Conversely, Du and Mordatch (2019) have shown that the likelihoods from EBMs can be reliably used as a predictor for OOD inputs. As can be seen in Table 2, column 2, JEM consistently assigns higher likelihoods to in-distribution data than OOD data. One possible explanation for JEM's further improvement over IGEBM is its ability to incorporate labeled information during training while also being able to derive a principled model of $p(\mathbf{x})$. Intriguingly, Glow does not appear to benefit in the same way from this supervision, as is demonstrated by the little difference between our unconditional and class-conditional Glow results. Quantitative results can be found in Table 3 (top).

Table 2: Histograms for OOD detection (columns: Glow $\log p(\mathbf{x})$, JEM $\log p(\mathbf{x})$, Approx. Mass JEM; rows: SVHN, CIFAR100, CelebA as the OOD dataset). All models trained on CIFAR10. Green corresponds to the score on (in-distribution) CIFAR10, and red corresponds to the score on the OOD dataset.
Model (trained on CIFAR10)   SVHN   Interp   CIFAR100   CelebA
Score $s_\theta(\mathbf{x}) = \log p(\mathbf{x})$ (top):
Unconditional Glow .05 .51 .55 .57
Class-Conditional Glow .07 .45 .51 .53
IGEBM .63 .50 .70
JEM (Ours) .65
Score $s_\theta(\mathbf{x}) = \max_y p(y \mid \mathbf{x})$ (middle):
Wide-ResNet .85 .62
Class-Conditional Glow .64 .61 .65 .54
IGEBM .43 .69 .54 .69
JEM (Ours) .89 .75
Score: approximate mass, Eq. (9) (bottom):
Unconditional Glow .27 .46 .29
Class-Conditional Glow .01 .52 .59
IGEBM .84 .65 .55 .66
JEM (Ours)
Table 3: OOD Detection Results. Models trained on CIFAR10. Values are AUROC.

5.3.2 Predictive Distribution

Many successful approaches have utilized a classifier's predictive distribution for OOD detection (Gal and Ghahramani, 2016; Wang et al., 2018; Liang et al., 2017). A useful OOD score that can be derived from this distribution is the maximum prediction probability: $s_\theta(\mathbf{x}) = \max_y p_\theta(y \mid \mathbf{x})$ (Hendrycks and Gimpel, 2016). It has been demonstrated that OOD performance using this score is highly correlated with a model's classification accuracy. Since JEM is a competitive classifier, we find it performs on par with (or beyond) a strong baseline classifier and considerably outperforms other generative models. Results can be seen in Table 3 (middle).
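Both scores discussed so far can be read off the logits of the same JEM model. The sketch assumes the logits are given as a plain list. It computes the unnormalized $\log p_\theta(\mathbf{x})$, dropping $\log Z(\theta)$ since it is shared by every input and so cannot change the ranking that AUROC measures, and the maximum prediction probability. Shifting all logits changes the first score but not the second, which is the extra degree of freedom discussed in Section 3.

```python
import math


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_px_score(logits):
    """Unnormalized log p(x) = LogSumExp_y f(x)[y]; the shared log Z(theta) does not affect ranking."""
    return logsumexp(logits)


def max_softmax_score(logits):
    """max_y p(y|x), computed from Eq. (4) in log space."""
    return math.exp(max(logits) - logsumexp(logits))


base = [0.0, 0.0]       # hypothetical logits for a two-class model
shifted = [10.0, 10.0]  # the same logits shifted by a constant
```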

5.3.3 A new score: Approximate Mass

It has been recently proposed that likelihood may not be enough for OOD detection in high dimensions (Nalisnick et al., 2019). It is possible for a point to have high likelihood under a distribution yet be nearly impossible to be sampled. Real samples from a distribution lie in what is known as the “typical” set. This is the area of high probability mass. A single point may have high density but if the surrounding areas have very low density, then that point is likely not in the typical set and therefore likely not a sample from the data distribution. For a high-likelihood datapoint outside of the typical set, we expect the density to change rapidly around it, thus the norm of the gradient of the log-density will be large compared to examples in the typical set (otherwise it would be in an area of high mass). We propose an alternative OOD score based on this quantity:

$$s_\theta(\mathbf{x}) = -\left\lVert \frac{\partial \log p_\theta(\mathbf{x})}{\partial \mathbf{x}} \right\rVert_2. \qquad (9)$$

For EBMs (JEM and IGEBM), we find this predictor greatly outperforms our own and other generative models' likelihoods (see Table 2, column 3). For tractable likelihood methods, we find this predictor is anti-correlated with the model's likelihood and neither is reliable for OOD detection. Results can be seen in Table 3 (bottom).
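Eq. (9) only needs the input gradient of $\log p_\theta(\mathbf{x})$, and since $\log Z(\theta)$ does not depend on $\mathbf{x}$, this equals the gradient of $\text{LogSumExp}_y(f_\theta(\mathbf{x})[y])$. To keep the gradient explicit without an autodiff library, the sketch assumes a hypothetical linear classifier $f_\theta(\mathbf{x}) = W\mathbf{x}$, for which that gradient is $W^\top \text{softmax}(W\mathbf{x})$; a real network would obtain it by backpropagation.

```python
import math


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def linear_logits(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]  # hypothetical f(x) = W x


def log_px_unnormalized(W, x):
    return logsumexp(linear_logits(W, x))  # log p(x) + log Z(theta)


def approx_mass_score(W, x):
    """Eq. (9): minus the L2 norm of d log p(x) / dx = W^T softmax(W x) for the linear model."""
    z = linear_logits(W, x)
    lse = logsumexp(z)
    p = [math.exp(v - lse) for v in z]  # softmax(W x)
    grad = [sum(p[k] * W[k][d] for k in range(len(W))) for d in range(len(x))]
    return -math.sqrt(sum(g * g for g in grad))


W = [[1.0, -2.0], [0.5, 3.0], [-1.0, 0.0]]
x = [0.2, -0.4]
score = approx_mass_score(W, x)
```

The closed-form gradient can be checked against finite differences of `log_px_unnormalized`, which is what the score would use if no analytic form were available.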

5.4 Robustness

Recent work (Athalye et al., 2017) has demonstrated that classifiers trained to be adversarially robust can be re-purposed to generate convincing images, do in-painting, and translate examples from one class to another. This is done through an iterative refinement procedure, quite similar to the SGLD used to sample from EBMs. We also note that adversarial training (Goodfellow et al., 2014) bears many similarities to SGLD training of EBMs. In both settings, we use a gradient-based optimization procedure to generate examples which activate a specific high-level network activation, then optimize the weights of the network to minimize the generated example’s effect on that activation. Further connections have been drawn between adversarial training and regularizing the gradients of the network’s activations around the data (Simon-Gabriel et al., 2018). This is similar to the objective of Score Matching (Hyvärinen, 2005) which can also be used to train EBMs (Kingma and Lecun, 2010; Song and Ermon, 2019).

Given these connections, one may wonder if a classifier derived from an EBM would be more robust to adversarial examples than a standard model. This behavior has been demonstrated in prior work on EBMs (Du and Mordatch, 2019), but their work did not produce a competitive discriminative model and is therefore of limited practical application for this purpose. Similarly, we find JEM achieves considerable robustness without sacrificing discriminative performance.

5.4.1 Improved Robustness Through EBM Training

Figure 5: Adversarial robustness results with PGD attacks; panels (a) and (b) each show robustness under one of the two attack norms. JEM adds considerable robustness.

A common threat model for adversarial robustness is that of perturbation-based adversarial examples with an $L_p$-norm constraint (Goodfellow et al., 2014). They are defined as perturbed inputs $\tilde{\mathbf{x}} = \mathbf{x} + \delta$, which change a model's prediction subject to $\lVert \mathbf{x} - \tilde{\mathbf{x}} \rVert_p < \epsilon$. These examples exploit semantically meaningless perturbations to which the model is overly sensitive. However, closeness to real inputs in terms of a given metric does not imply that adversarial examples reside within areas of high density according to the model distribution; hence, it is not surprising that the model makes mistakes when asked to classify inputs it has rarely or never encountered during training.

This insight has been used to detect and robustly classify adversarial examples with generative models (Song et al., 2017; Li et al., 2018; Fetaya et al., 2019). The state-of-the-art method for adversarial robustness on MNIST classifies by comparing an input to samples generated from a class-conditional generative model (Schott et al., 2018). This can be thought of as classifying an example similar to the input but from an area of higher density under the model’s learned distribution. This refined input resides in areas where the model has already “seen” sufficient data and is thus able to accurately classify. Albeit promising, this family of methods has not been able to scale beyond MNIST due to a lack of sufficiently powerful conditional generative models. We believe JEM can help close this gap. We propose to run a few iterations of our model’s sampling procedure seeded at a given input. This should be able to transform low-probability inputs to a nearby point of high probability, “undoing” any adversarial attack and enabling the model to classify robustly.
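The proposed defense can be sketched as: run a few SGLD steps on $E_\theta(\mathbf{x}) = -\text{LogSumExp}_y(f_\theta(\mathbf{x})[y])$ starting at the input, then classify the refined point. A hypothetical linear classifier $f_\theta(\mathbf{x}) = W\mathbf{x}$ stands in for the network so that the input gradient $-W^\top \text{softmax}(W\mathbf{x})$ can be written by hand (such a model does not define a normalizable density, so it only illustrates the mechanics), and `noise_std` is set separately from the step size as in the biased sampler. With `n_steps=0` the function reduces to the plain classifier.

```python
import math
import random


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def linear_logits(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def refine_and_classify(W, x, n_steps, step_size, noise_std, rng):
    """Hypothetical sketch: SGLD on E(x) = -LogSumExp_y f(x)[y] seeded at x, then argmax.

    f(x) = W x is a toy linear classifier, so dE/dx = -W^T softmax(W x) in closed form.
    """
    x = list(x)
    for _ in range(n_steps):
        p = softmax(linear_logits(W, x))
        grad_e = [-sum(p[k] * W[k][d] for k in range(len(W))) for d in range(len(x))]
        x = [v - 0.5 * step_size * g + rng.gauss(0.0, noise_std) for v, g in zip(x, grad_e)]
    logits = linear_logits(W, x)
    return max(range(len(logits)), key=lambda k: logits[k]), x


W = [[1.0, 0.0], [0.0, 1.0]]
label_plain, _ = refine_and_classify(W, [0.3, 0.1], n_steps=0, step_size=0.1,
                                     noise_std=0.01, rng=random.Random(0))
label_refined, x_refined = refine_and_classify(W, [0.3, 0.1], n_steps=5, step_size=0.1,
                                               noise_std=0.01, rng=random.Random(0))
```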

Perturbation Robustness   We run a number of powerful adversarial attacks on our CIFAR10 models. We run a white-box PGD attack, giving the attacker access to the gradients through our sampling procedure.² Because our sampling procedure is stochastic, we compute the “expectation over transformations” (Athalye et al., 2018), the expected gradient over multiple runs of the sampling procedure. We also run gradient-free black-box attacks: the boundary attack (Brendel et al., 2017) and the brute-force pointwise attack (Rauber et al., 2017). All attacks are run with respect to two different norms, and we test JEM with 0, 1, and 10 steps of sampling seeded at the input.

² In Du and Mordatch (2019) the attacker was not given access to the gradients of the refinement procedure. We re-run these stronger attacks on their model as well and provide a comparison in Appendix G.
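A minimal sketch of a PGD attack with expectation over transformations, assuming a gradient oracle `grad_fn` that returns the gradient of the attacker's loss through a (possibly stochastic) defense. Averaging `n_eot` gradients before each signed step is the EOT idea. The $L_\infty$ ball is used here only as an example norm, the sketch omits clipping to the valid pixel range for brevity, and the noisy toy gradient in the usage lines is hypothetical.

```python
import random


def pgd_linf(x0, grad_fn, eps, step_size, n_iter, n_eot=1):
    """Untargeted L-infinity PGD: signed gradient ascent on the loss, projected onto the eps-ball.

    grad_fn(x) returns a (possibly stochastic) gradient of the attacker's loss at x.
    With n_eot > 1, gradients from several calls are averaged before each step
    (expectation over transformations).
    """
    x = list(x0)
    for _ in range(n_iter):
        grads = [grad_fn(x) for _ in range(n_eot)]
        g = [sum(gr[d] for gr in grads) / n_eot for d in range(len(x))]
        for d in range(len(x)):
            step = step_size if g[d] > 0 else (-step_size if g[d] < 0 else 0.0)
            # Take the ascent step, then clip back into [x0 - eps, x0 + eps].
            x[d] = min(max(x[d] + step, x0[d] - eps), x0[d] + eps)
    return x


# Hypothetical stochastic defense: the true loss gradient is w, observed with Gaussian noise.
rng = random.Random(0)
w = [1.0, -1.0]
x_adv = pgd_linf([0.0, 0.0], lambda x: [wi + rng.gauss(0.0, 0.5) for wi in w],
                 eps=0.1, step_size=0.03, n_iter=10, n_eot=50)
```

Averaging many noisy gradients recovers the sign of the underlying gradient, so the attack still drives the input to the corner of the $\epsilon$-ball that increases the loss.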

Results from the PGD experiments can be seen in Figure 5. Experimental details and remaining results, including gradient-free attacks, can be found in Appendix G. Our model is considerably more robust than a baseline with standard classifier training. With respect to both norms, JEM performs competitively with, but slightly below, state-of-the-art adversarial training (Madry et al., 2017; Santurkar et al., 2019) and the state-of-the-art certified robustness method of Salman et al. (2019) (“RandAdvSmooth” in Figure 5). We note that each of these baseline methods is trained to be robust to the norm through which it is being attacked, and it has been shown that attacking a model adversarially trained for one norm with an adversary using a different norm decreases robustness considerably (Madry et al., 2017). However, we attack the same JEM model with both norms and observe competitive robustness in both cases.

JEM with 0 steps of refinement is noticeably more robust than the baseline model trained as a standard classifier; thus, simply adding EBM training can produce more robust models. We also find that increasing the number of refinement steps further increases robustness, to levels close to those of robustness-specific approaches. We expect that increasing the number of refinement steps further will lead to more robust models, but due to computational constraints we could not run attacks in this setting.

Figure 6: Distal Adversarials. Confidently classified images generated from noise, such that $p(y = \text{car} \mid \mathbf{x})$ is high.

Distal Adversarials   Another common failure mode of non-robust models is their tendency to classify nonsensical inputs with high confidence. To analyze this property, we follow Schott et al. (2018). Starting from noise, we generate images to maximize $p(y = \text{car} \mid \mathbf{x})$. Results are shown in Figure 6. The baseline confidently classifies unstructured noise images. The adversarially trained ResNet (Santurkar et al., 2019) confidently classifies somewhat structured, but unrealistic images. JEM does not confidently classify nonsensical images; instead, car attributes and natural image properties visibly emerge.

6 Limitations

Energy based models can be very challenging to work with. Since normalized likelihoods cannot be computed, it can be hard to verify that learning is taking place at all. When working in domains such as images, samples can be drawn and checked to assess learning, but this is far from a generalizable strategy. Moreover, these are only samples from an approximation to the model, so their usefulness is limited. Furthermore, the gradient estimators we use to train JEM are quite unstable and are prone to diverging if the sampling and optimization parameters are not tuned correctly. Regularizers may be added (Du and Mordatch, 2019) to increase stability, but it is not clear what effect they have on the final model. The models used to generate the results in this work regularly diverged throughout training, requiring them to be restarted with lower learning rates or with increased regularization. See Appendix H.3 for a detailed description of how these difficulties were handled.

While this may seem prohibitive, we believe the results presented in this work are sufficient to motivate the community to find solutions to these issues as any improvement in the training of energy based models will further improve the results we have presented in this work.

7 Related Work

Prior work (Xie et al., 2016) made a similar observation to ours about classifiers and EBMs but defines the model differently. They reinterpret the logits to define a class-conditional EBM $p(\mathbf{x} \mid y)$, similar to Du and Mordatch (2019). This setting requires additional parameters to be learned to derive a classifier and an unconditional model. We believe this subtle distinction is responsible for our model's success. The model of Song and Ou (2018) is similar as well but is trained using a GAN-like generator and is applied to different applications.

Our work builds heavily on Nijkamp et al. (2019b, a) and Du and Mordatch (2019), which scale the training of EBMs to high-dimensional data using Contrastive Divergence and SGLD. While these works have pushed the boundaries of the types of data to which we can apply EBMs, many issues still exist. These methods require many steps of SGLD to take place at each training iteration. Each step requires approximately the same amount of computation as one iteration of standard discriminative model training; therefore, training EBMs at this scale is orders of magnitude slower than training a classifier, limiting the size of problems we can attack with these methods. There exist orthogonal approaches to training EBMs which we believe have promise to scale more gracefully.

Score matching (Hyvärinen, 2005) attempts to match the derivative of the model's density with the derivative of the data density. This approach saw some development towards high-dimensional data (Kingma and Lecun, 2010) and has recently been successfully applied to large natural images (Song and Ermon, 2019). This approach requires a model to output the derivatives of the density function, not the density function itself, so it is unclear what utility this model can provide to the applications we have discussed in this work. Regardless, we believe this is a promising avenue for further research. Noise Contrastive Estimation (Gutmann and Hyvärinen, 2010) rephrases the density estimation problem as a classification problem, attempting to distinguish data from a known noise distribution. If the classifier is properly structured, then once the classification problem is solved, an unnormalized density estimator can be derived from the classifier and noise distribution. While this method has been recently extended (Ceylan and Gutmann, 2018), these methods are challenging to extend to high-dimensional data.

8 Conclusion and Further Work

In this work we have presented JEM, a novel reinterpretation of standard classifier architectures which retains the strong performance of SOTA discriminative models while adding the benefits of generative modeling approaches. Our work is enabled by recent work scaling techniques for training EBMs to high dimensional data. We have demonstrated the utility of incorporating this type of training into discriminative models. While there exist many issues in training EBMs we hope the results presented here will encourage the community to improve upon current approaches.

9 Acknowledgements

We would like to thank Ying Nian Wu and Mitch Hill for providing some EBM training tips and tricks which were crucial in getting this project off the ground. We would also like to thank Jeremy Cohen for his useful feedback which greatly strengthened our adversarial robustness results. We would like to thank Lukas Schott for feedback on the robustness evaluation, Alexander Meinke and Francesco Croce for spotting some typos and suggesting the transfer attack.

References

  • A. Athalye, N. Carlini, and D. Wagner (2018) Obfuscated gradients give a false sense of security: circumventing defenses to adversarial examples. arXiv preprint arXiv:1802.00420. Cited by: §G.1, §5.4.1.
  • A. Athalye, L. Engstrom, A. Ilyas, and K. Kwok (2017) Synthesizing robust adversarial examples. arXiv preprint arXiv:1707.07397. Cited by: §5.4.
  • S. Barratt and R. Sharma (2018) A note on the inception score. arXiv preprint arXiv:1801.01973. Cited by: Table 6, Appendix B.
  • J. Behrmann, W. Grathwohl, R. T. Chen, D. Duvenaud, and J. Jacobsen (2018) Invertible residual networks. arXiv preprint arXiv:1811.00995. Cited by: §1.
  • W. Brendel, J. Rauber, and M. Bethge (2017) Decision-based adversarial attacks: reliable attacks against black-box machine learning models. arXiv preprint arXiv:1712.04248. Cited by: §5.4.1.
  • C. Ceylan and M. U. Gutmann (2018) Conditional noise-contrastive estimation of unnormalised models. arXiv preprint arXiv:1806.03664. Cited by: §7.
  • O. Chapelle, B. Scholkopf, and A. Zien (2006) Semi-supervised learning. MIT Press. Cited by: §1.
  • R. T. Chen, J. Behrmann, D. Duvenaud, and J. Jacobsen (2019) Residual flows for invertible generative modeling. arXiv preprint arXiv:1906.02735. Cited by: Table 6, Appendix B, §1, Table 1.
  • A. P. Dempster, N. M. Laird, and D. B. Rubin (1977) Maximum likelihood from incomplete data via the em algorithm. Journal of the Royal Statistical Society: Series B (Methodological) 39 (1), pp. 1–22. Cited by: §1.
  • Y. Du and I. Mordatch (2019) Implicit generation and generalization in energy-based models. arXiv preprint arXiv:1903.08689. Cited by: Table 6, Appendix B, §F.1, Figure 15, Appendix G, §H.1, §2, §4, §5.1, §5.3.1, §5.4, Table 1, §6, §7, §7, footnote 2.
  • E. Fetaya, J. Jacobsen, and R. Zemel (2019) Conditional generative models are not robust. arXiv preprint arXiv:1906.01171. Cited by: §5.4.1.
  • Y. Gal and Z. Ghahramani (2016) Dropout as a Bayesian approximation: representing model uncertainty in deep learning. In International Conference on Machine Learning (ICML), pp. 1050–1059. Cited by: §5.3.2.
  • I. J. Goodfellow, J. Shlens, and C. Szegedy (2014) Explaining and harnessing adversarial examples. arXiv preprint arXiv:1412.6572. Cited by: §5.4.1, §5.4.
  • C. Guo, G. Pleiss, Y. Sun, and K. Q. Weinberger (2017) On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pp. 1321–1330. Cited by: Figure 4, §5.2, §5.2.
  • M. Gutmann and A. Hyvärinen (2010) Noise-contrastive estimation: a new estimation principle for unnormalized statistical models. In Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, pp. 297–304. Cited by: §7.
  • D. Hendrycks and K. Gimpel (2016) A baseline for detecting misclassified and out-of-distribution examples in neural networks. arXiv preprint arXiv:1610.02136. Cited by: §5.3.2, §5.3.
  • M. Heusel, H. Ramsauer, T. Unterthiner, B. Nessler, and S. Hochreiter (2017) GANs trained by a two time-scale update rule converge to a local Nash equilibrium. In Advances in Neural Information Processing Systems, pp. 6626–6637. Cited by: Table 6, Appendix B, §5.1.
  • G. E. Hinton (2002) Training products of experts by minimizing contrastive divergence. Neural computation 14 (8), pp. 1771–1800. Cited by: §2.
  • A. Hyvärinen (2005) Estimation of non-normalized statistical models by score matching. Journal of Machine Learning Research 6 (Apr), pp. 695–709. Cited by: §5.4, §7.
  • D. P. Kingma and J. Ba (2014) Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980. Cited by: Appendix A.
  • D. P. Kingma and P. Dhariwal (2018) Glow: generative flow with invertible 1x1 convolutions. In Advances in Neural Information Processing Systems, pp. 10215–10224. Cited by: §5.3.1, Table 1.
  • D. P. Kingma and Y. Lecun (2010) Regularized estimation of image statistics by score matching. In Advances in neural information processing systems, pp. 1126–1134. Cited by: §5.4, §7.
  • Y. LeCun, S. Chopra, R. Hadsell, M. Ranzato, and F. Huang (2006) A tutorial on energy-based learning. Predicting structured data 1 (0). Cited by: §2.
  • Y. Li, J. Bradshaw, and Y. Sharma (2018) Are generative classifiers more robust to adversarial attacks?. arXiv preprint arXiv:1802.06552. Cited by: §5.4.1.
  • S. Liang, Y. Li, and R. Srikant (2017) Enhancing the reliability of out-of-distribution image detection in neural networks. arXiv preprint arXiv:1706.02690. Cited by: §5.3.2.
  • A. Madry, A. Makelov, L. Schmidt, D. Tsipras, and A. Vladu (2017) Towards deep learning models resistant to adversarial attacks. arXiv preprint arXiv:1706.06083. Cited by: §5.4.1.
  • T. Miyato, T. Kataoka, M. Koyama, and Y. Yoshida (2018) Spectral normalization for generative adversarial networks. arXiv preprint arXiv:1802.05957. Cited by: Table 1.
  • E. Nalisnick, A. Matsukawa, Y. W. Teh, D. Gorur, and B. Lakshminarayanan (2018) Do deep generative models know what they don’t know?. arXiv preprint arXiv:1810.09136. Cited by: §5.3.1.
  • E. Nalisnick, A. Matsukawa, Y. W. Teh, and B. Lakshminarayanan (2019) Detecting out-of-distribution inputs to deep generative models using a test for typicality. arXiv preprint arXiv:1906.02994. Cited by: §5.3.1, §5.3.3.
  • E. Nijkamp, M. Hill, T. Han, S. Zhu, and Y. N. Wu (2019a) On the anatomy of MCMC-based maximum likelihood learning of energy-based models. arXiv preprint arXiv:1903.12370. Cited by: §H.1, §H.2, §H.3, §2, §7.
  • E. Nijkamp, S. Zhu, and Y. N. Wu (2019b) On learning non-convergent short-run MCMC toward energy-based model. arXiv preprint arXiv:1904.09770. Cited by: §H.1, §H.3, §2, §4, §7.
  • J. Rauber, W. Brendel, and M. Bethge (2017) Foolbox: a python toolbox to benchmark the robustness of machine learning models. arXiv preprint arXiv:1707.04131. Cited by: Appendix G, §5.4.1.
  • T. Salimans, I. Goodfellow, W. Zaremba, V. Cheung, A. Radford, and X. Chen (2016) Improved techniques for training gans. In Advances in neural information processing systems, pp. 2234–2242. Cited by: §5.1.
  • T. Salimans, A. Karpathy, X. Chen, and D. P. Kingma (2017) PixelCNN++: improving the PixelCNN with discretized logistic mixture likelihood and other modifications. arXiv preprint arXiv:1701.05517. Cited by: §5.3.1.
  • H. Salman, G. Yang, J. Li, P. Zhang, H. Zhang, I. Razenshteyn, and S. Bubeck (2019) Provably robust deep learning via adversarially trained smoothed classifiers. arXiv preprint arXiv:1906.04584. Cited by: §5.4.1.
  • S. Santurkar, D. Tsipras, B. Tran, A. Ilyas, L. Engstrom, and A. Madry (2019) Computer vision with a single (robust) classifier. CoRR abs/1906.09453. Cited by: §5.4.1, §5.4.1.
  • L. Schott, J. Rauber, M. Bethge, and W. Brendel (2018) Towards the first adversarially robust neural network model on MNIST. arXiv preprint arXiv:1805.09190. Cited by: §5.4.1, §5.4.1.
  • C. Simon-Gabriel, Y. Ollivier, L. Bottou, B. Schölkopf, and D. Lopez-Paz (2018) Adversarial vulnerability of neural networks increases with input dimension. arXiv preprint arXiv:1802.01421. Cited by: §5.4.
  • Y. Song and S. Ermon (2019) Generative modeling by estimating gradients of the data distribution. arXiv preprint arXiv:1907.05600. Cited by: §H.1, §5.4, Table 1, §7.
  • Y. Song, T. Kim, S. Nowozin, S. Ermon, and N. Kushman (2017) PixelDefend: leveraging generative models to understand and defend against adversarial examples. arXiv preprint arXiv:1710.10766. Cited by: §5.4.1.
  • Y. Song and Z. Ou (2018) Learning neural random fields with inclusive auxiliary generators. arXiv preprint arXiv:1806.00271. Cited by: §7.
  • T. Tieleman (2008) Training restricted Boltzmann machines using approximations to the likelihood gradient. In Proceedings of the 25th International Conference on Machine Learning, pp. 1064–1071. Cited by: §4.
  • K. Wang, P. Vicol, J. Lucas, L. Gu, R. Grosse, and R. Zemel (2018) Adversarial distillation of Bayesian neural network posteriors. In International Conference on Machine Learning (ICML). Cited by: §5.3.2.
  • M. Welling and Y. W. Teh (2011) Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688. Cited by: §H.1, §2.
  • J. Xie, Y. Lu, S. Zhu, and Y. Wu (2016) A theory of generative convnet. In International Conference on Machine Learning, pp. 2635–2644. Cited by: §2, §5.1, §7.
  • S. Zagoruyko and N. Komodakis (2016) Wide residual networks. arXiv preprint arXiv:1605.07146. Cited by: §5.2, §5.

Appendix A Training Details

We train all models with the Adam optimizer (Kingma and Ba, 2014) for 150 epochs through the dataset using a staircase decay schedule. All network architectures are based on WideResNet-28-10 with no batch normalization. We generate samples using PCD with the hyperparameters in Table 4. We evolve the chains with 20 steps of SGLD per iteration, and with probability $\rho = .05$ we reinitialize the chains with uniform random noise. For preprocessing, we scale images to the range $[-1, 1]$ and add Gaussian noise of stddev .03. Pseudo-code for our training procedure is in Algorithm 1.

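When training via contrastive divergence there are a few different ways one could potentially draw samples from $p_\theta(x)$. We could:

  1. Sample $y \sim p(y)$, then sample $x \sim p_\theta(x \mid y)$ via SGLD with energy $E_\theta(x \mid y) = -f_\theta(x)[y]$, then throw away $y$.

  2. Sample $x \sim p_\theta(x)$ via SGLD with energy $E_\theta(x) = -\mathrm{LogSumExp}_y f_\theta(x)[y]$.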

We experimented with both methods during training and found that while method 1 produced more visually appealing samples (from a human's perspective), method 2 produced slightly stronger discriminative performance: 92.9% vs. 91.2% accuracy on CIFAR10. For this reason we use method 2 in all results presented.
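Both energies above can be computed from a classifier's logits alone. The sketch below is a minimal illustration, assuming the logits $f_\theta(x)$ for one input are available as a plain Python list (a hypothetical stand-in for a network's output); the function names are ours, not the paper's.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def energy_marginal(logits):
    """Method 2: E(x) = -LogSumExp_y f(x)[y], the energy whose SGLD samples target p(x)."""
    return -logsumexp(logits)


def energy_conditional(logits, y):
    """Method 1: E(x|y) = -f(x)[y], the energy whose SGLD samples target p(x|y)."""
    return -logits[y]


def class_probs(logits):
    """p(y|x) = softmax(f(x)); the intractable normalizer Z(theta) cancels here."""
    lse = logsumexp(logits)
    return [math.exp(v - lse) for v in logits]


logits = [2.0, 0.0, -1.0]  # hypothetical f(x) for a single input and 3 classes
print(energy_marginal(logits), energy_conditional(logits, 0), class_probs(logits))
```

Since $E(x \mid y) - E(x) = -\log p(y \mid x) \ge 0$, the marginal energy used by method 2 is never above any class-conditional energy used by method 1.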

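Algorithm 1 JEM training: Given network $f_\theta$, SGLD step-size $\alpha$, SGLD noise $\sigma$, replay buffer $B$, SGLD steps $\eta$, reinitialization frequency $\rho$
1:while not converged do
2:     Sample $x$ and $y$ from dataset
3:     $L_{\mathrm{clf}}(\theta) = \mathrm{xent}(f_\theta(x), y)$
4:     Sample $\hat{x}_0 \sim B$ with probability $1 - \rho$, else $\hat{x}_0 \sim \mathcal{U}(-1, 1)$ ▷ Initialize SGLD
5:     for $t \in [1, 2, \ldots, \eta]$ do ▷ SGLD
6:         $\hat{x}_t = \hat{x}_{t-1} + \alpha \cdot \frac{\partial\, \mathrm{LogSumExp}_{y'}(f_\theta(\hat{x}_{t-1})[y'])}{\partial \hat{x}_{t-1}} + \sigma \cdot \mathcal{N}(0, I)$
7:     end for
8:     $L_{\mathrm{gen}}(\theta) = \mathrm{LogSumExp}_{y'}(f_\theta(\hat{x}_\eta)[y']) - \mathrm{LogSumExp}_{y'}(f_\theta(x)[y'])$ ▷ Surrogate for Eq 2
9:     $L(\theta) = L_{\mathrm{clf}}(\theta) + L_{\mathrm{gen}}(\theta)$
10:     Obtain gradients $\partial L(\theta) / \partial \theta$ for training
11:     Add $\hat{x}_\eta$ to $B$
12:end while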
Variable Value
initial learning rate .0001
learning epochs 150
learning rate decay .3
learning rate decay epochs 50, 100
SGLD steps ($\eta$) 20
Buffer-size 10000
reinitialization frequency ($\rho$) .05
SGLD step-size ($\alpha$) 1
SGLD noise ($\sigma$) .01
Table 4: Hyperparameters
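A minimal sketch of the sampling part of Algorithm 1 (lines 4 to 8, plus line 11), assuming a hypothetical linear classifier $f(x) = Wx + b$ on 2-D inputs so that $\partial\, \mathrm{LogSumExp}_y f(x)[y] / \partial x$ has a closed form. In the paper $f_\theta$ is a WideResNet and these gradients would come from automatic differentiation; the parameter update (line 10) is omitted. Because LogSumExp of linear logits is unbounded, long chains on this toy drift without bound, so it only illustrates the mechanics. The values in the usage follow Table 4; all names are illustrative.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logsumexp(logits):
    m = max(logits)
    return m + math.log(sum(math.exp(v - m) for v in logits))


def linear_logits(W, b, x):
    """Hypothetical stand-in for f_theta: f(x)[k] = W[k] . x + b[k]."""
    return [sum(w * xi for w, xi in zip(row, x)) + bk for row, bk in zip(W, b)]


def grad_lse_x(W, b, x):
    """d LogSumExp_y f(x)[y] / dx = sum_k softmax(f(x))[k] * W[k] for linear logits."""
    p = softmax(linear_logits(W, b, x))
    return [sum(p[k] * W[k][d] for k in range(len(W))) for d in range(len(x))]


def init_chain(buffer, rho, dim, rng):
    """Line 4: reuse a buffer entry w.p. 1 - rho, otherwise draw U(-1, 1) noise."""
    if buffer and rng.random() >= rho:
        return list(rng.choice(buffer))
    return [rng.uniform(-1.0, 1.0) for _ in range(dim)]


def sgld(W, b, x0, steps, alpha, sigma, rng):
    """Lines 5-7: relaxed SGLD that ascends LogSumExp, i.e. descends E(x)."""
    x = list(x0)
    for _ in range(steps):
        g = grad_lse_x(W, b, x)
        x = [xi + alpha * gi + sigma * rng.gauss(0.0, 1.0) for xi, gi in zip(x, g)]
    return x


def gen_loss(W, b, x_data, x_model):
    """Line 8: L_gen = LSE f(x_hat) - LSE f(x); its theta-gradient estimates that of -log p(x)."""
    return logsumexp(linear_logits(W, b, x_model)) - logsumexp(linear_logits(W, b, x_data))


rng = random.Random(0)
W, b = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]], [0.0, 0.0, 0.0]
buffer = []
x_hat = sgld(W, b, init_chain(buffer, 0.05, 2, rng), steps=20, alpha=1.0, sigma=0.01, rng=rng)
buffer.append(x_hat)  # line 11: store the refined sample for the next iteration
print(gen_loss(W, b, [0.5, -0.2], x_hat))
```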

Appendix B Sample Quality Evaluation

In this section we describe the details for reproducing the Inception Score (IS) and FID results reported in the paper. First, we note that both IS and FID are computed using a pretrained classifier network, and thus can depend heavily on the exact model and code repository used. For a more detailed discussion of the variability of IS, please refer to Barratt and Sharma (2018). To gauge our model against those of other papers, we document our attempt to fairly compare the scores across papers in Table 6. As a direct comparison of IS, we obtained 8.76 using the code provided by Du and Mordatch (2019), which is better than their best reported score of 8.3. For FID, we used the official implementation from Heusel et al. (2017). Note that FIDs computed with this repository are much worse than those reported in Chen et al. (2019).

Conditional vs unconditional samples.

Since we are interested in training a hybrid model, our model is, by definition, a conditional generative model, as it has access to label information. In Table 5, unconditional samples are samples obtained directly by running SGLD using $p(x)$. Conditional samples are obtained by taking the max of our model's $p(y \mid x)$. The reported scores are obtained by keeping the top 10 percentile of samples with the highest $p(y \mid x)$ values. Scores obtained on a “single” model are computed directly on the training replay buffer of the last checkpoint. “Ensemble” scores are obtained by lumping together 5 buffers from the last few epochs of training. As we initialize SGLD with uniform noise, using the training buffer is exactly the same as re-sampling from the model.
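One way the top-10-percentile filtering could be implemented is sketched below, assuming the ranking key is the classifier's confidence $\max_y p(y \mid x)$ and that each kept sample is labeled with its most probable class. The logits function is a hypothetical toy and the names are illustrative.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def keep_most_confident(samples, logits_fn, keep_frac=0.10):
    """Return (sample, predicted label, confidence) for the top keep_frac of samples,
    ranked by the classifier's confidence max_y p(y|x)."""
    scored = []
    for x in samples:
        probs = softmax(logits_fn(x))
        conf = max(probs)
        scored.append((x, probs.index(conf), conf))
    scored.sort(key=lambda item: item[2], reverse=True)
    k = max(1, math.ceil(keep_frac * len(scored)))
    return scored[:k]


def toy_logits(x):
    """Hypothetical two-class logits; confidence grows with |x[0]|."""
    return [x[0], -x[0]]


samples = [[float(v)] for v in range(-5, 5)]  # 10 hypothetical samples
print(keep_most_confident(samples, toy_logits))
```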

Conditional Unconditional
Method single ensemble single ensemble
JEM (Ours) - 8.76 7.82 7.79
EBM (D&M) 8.3 X 6.02 6.78
Table 5: Conditional vs. unconditional Inception Scores.
Inception Score FID
Method from paper B&S D&M from paper H D&M
Residual Flow X 3.6 - 46.4 - -
Glow X - 3.9 48.9* 107 -
JEM (Ours) X 7.13 8.76 X 38.4 -
JEM factored X - 6.36 X 61.8 -
EBM (D&M) 8.3 - 8.3 37.9 - 37.9
SNGAN 8.59 - - 25.5 - -
NCSN 8.91 - - 25.3 - -
Table 6: The headings B&S, D&M, and H denote scores computed using code provided by Barratt and Sharma (2018), Du and Mordatch (2019), and Heusel et al. (2017), respectively. * denotes numbers copied from Chen et al. (2019) rather than from the original papers. Although Inception Score and FID vary considerably across code repositories, this table still shows that our model performs well: using the D&M Inception Score code we beat their own model, and using the official FID repository we beat the Glow model (code taken from https://github.com/y0ast/Glow-PyTorch) by a big margin.

Appendix C Further Hybrid Model Samples

Additional samples from CIFAR10 and SVHN can be seen in Figure 7, and samples from CIFAR100 can be seen in Figure 8.

Figure 7: Class-conditional Samples. Left to right: CIFAR10, SVHN.
Figure 8: CIFAR100 Class-conditional Samples.

Appendix D Qualitative Analysis of Samples

Visual quality is difficult to quantify. For known metrics like IS and FID, using samples that have higher $p(y \mid x)$ values results in higher scores, but this is not necessarily the case for samples with higher $p(x)$. However, this is likely due to shortcomings of the evaluation metrics themselves rather than a reflection of true sample quality.

Based on our analysis (below), we find

  1. Our model assigns $p(x)$ values that cluster around different means for different classes. The class automobiles has the highest $p(x)$. Of all generated samples, the top 100 are all of this class.

  2. Given the class, the samples that have higher $p(x)$ values all have white backgrounds and centered objects, and lower-$p(x)$ samples have colorful (e.g., forest-like) backgrounds.

  3. Of all samples, higher $p(y \mid x)$ values correspond to clearly centered objects, and lower values otherwise.

Figure 9: Each row corresponds to one class; subfigures correspond to different values of $p(x)$. Left: highest, middle: random, right: lowest.
Figure 10: Histograms (oriented horizontally for easier visual alignment) of $p(x)$ arranged by class.
Figure 11: Left: samples with highest $p(x)$; right: samples with lowest $p(x)$.
Figure 12: Left: samples with highest $p(y \mid x)$; right: samples with lowest $p(y \mid x)$.

Appendix E Calibration

e.1 Expected Calibration Error

Expected Calibration Error (ECE) is a metric to measure the calibration of a classifier. It works by first computing the confidence, $\max_y p(y \mid x_i)$, for each $x_i$ in some dataset. We then group the items into $M$ equally spaced buckets $\{B_m\}_{m=1}^{M}$ based on the classifier's output confidence. For example, if $M = 20$, then $B_1$ would represent all examples for which the classifier's confidence was between $0$ and $0.05$.

We then define:

$$\mathrm{ECE} = \sum_{m=1}^{M} \frac{|B_m|}{n} \left| \mathrm{acc}(B_m) - \mathrm{conf}(B_m) \right| \qquad (10)$$

where $n$ is the number of examples in the dataset, $\mathrm{acc}(B_m)$ is the average accuracy of the classifier over all examples in $B_m$, and $\mathrm{conf}(B_m)$ is the average confidence over all examples in $B_m$.

For a perfectly calibrated classifier, this value will be 0 for any choice of $M$. In our analysis, we use a single fixed $M$ throughout.
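A direct transcription of Eq. (10), as a sketch rather than the evaluation code used for the paper. Bins here are half-open intervals $[m/M, (m+1)/M)$ with a confidence of exactly 1 placed in the last bin; this boundary convention is our choice, and other implementations may place boundary values differently.

```python
def expected_calibration_error(confidences, correct, num_bins):
    """ECE = sum_m |B_m| / n * |acc(B_m) - conf(B_m)| with num_bins equal-width bins on [0, 1]."""
    n = len(confidences)
    bins = [[] for _ in range(num_bins)]
    for c, ok in zip(confidences, correct):
        # Bin m holds confidences in [m / M, (m + 1) / M); c == 1.0 joins the last bin.
        m = min(int(c * num_bins), num_bins - 1)
        bins[m].append((c, ok))
    ece = 0.0
    for members in bins:
        if not members:
            continue  # empty bins contribute nothing
        acc = sum(1.0 for _, ok in members if ok) / len(members)
        conf = sum(c for c, _ in members) / len(members)
        ece += len(members) / n * abs(acc - conf)
    return ece


# Two hypothetical buckets, each off by 0.1 and holding half the data,
# so ECE is 0.1 up to float rounding.
print(expected_calibration_error([0.9, 0.9, 0.6, 0.6], [True, True, True, False], num_bins=20))
```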

e.2 Further results

We find that JEM also improves calibration on CIFAR10, as can be seen in Figure 13. There we see an improvement in calibration, but both classifiers are well calibrated because their accuracy is so high. In a more interesting experiment, we limit the size of the training set to 4,000 labeled examples. In this setting the accuracy drops to 78.0% and 74.9% for the baseline and JEM, respectively. Given that JEM can be trained on unlabeled data, we treat the remainder of the training set as unlabeled and train in a semi-supervised manner. We find this gives a noticeable boost in the classifier's calibration, as seen in Figure 13. Surprisingly, this did not improve generalization. We leave exploring this phenomenon for future work.

(a) CIFAR10 Baseline (b) CIFAR10 JEM
(c) CIFAR10 Baseline (4k labels) (d) CIFAR10 JEM (4k labels)
Figure 13: CIFAR10 Calibration results

Appendix F Out-of-Distribution Detection

f.1 Experimental details

To obtain OOD results for unconditional Glow, we used the pre-trained model and implementation of https://github.com/y0ast/Glow-PyTorch. We also trained a class-conditional model using this codebase, which was used to generate the class-conditional OOD results.

We obtained the IGEBM of Du and Mordatch (2019) from their open-source implementation at https://github.com/openai/ebm_code_release. For the likelihood and likelihood-gradient OOD scores we used their pre-trained cifar10_large_model_uncond model. We were able to replicate the likelihood-based OOD results presented in their work. We implemented our likelihood-gradient approximate-mass score on top of their codebase. For predictive-distribution-based OOD scores we used their cifar_cond model, which was the model used in their work to generate their robustness results.
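The three kinds of OOD score mentioned above (likelihood, predictive distribution, and likelihood-gradient approximate mass) can be sketched for any model whose logits define $\log p(x) = \mathrm{LogSumExp}_y f_\theta(x)[y] - \log Z(\theta)$. The sketch assumes a hypothetical linear classifier so that the input gradient is closed-form, and it assumes the convention that larger scores indicate in-distribution data, which is why the gradient norm is negated.

```python
import math


def linear_logits(W, b, x):
    """Hypothetical stand-in for f_theta(x): one weight row and bias per class."""
    return [sum(w * xi for w, xi in zip(row, x)) + bk for row, bk in zip(W, b)]


def ood_scores(W, b, x):
    """Return three scores where larger should mean 'more in-distribution':
    - unnormalized log-likelihood LogSumExp_y f(x)[y] (log p(x) up to -log Z),
    - predictive confidence max_y p(y|x),
    - approximate mass: minus the norm of d log p(x) / dx.
    """
    logits = linear_logits(W, b, x)
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    probs = [math.exp(v - lse) for v in logits]
    # For linear logits, d LogSumExp / dx = sum_k p(k|x) * W[k]; Z does not depend on x.
    grad = [sum(probs[k] * W[k][d] for k in range(len(W))) for d in range(len(x))]
    return lse, max(probs), -math.sqrt(sum(g * g for g in grad))


print(ood_scores([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0], [0.0, 0.0]))
```

The unknown $\log Z(\theta)$ shifts every likelihood score by the same constant, so it does not affect a ranking-based metric such as AUROC.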

f.2 Further results

Table 7 contains results on two datasets, Constant and Uniform, which were omitted from the main text for space. Most models perform very well on the Uniform dataset. On the Constant dataset (all examples = 0), generative models mostly fail, with JEM being the only one whose likelihoods can be used to derive a predictive score function for OOD detection. Interestingly, we could not obtain approximate-mass scores on this dataset from the Glow models due to numerical instability.

CIFAR10 (in-distribution)
Score Model SVHN Uniform Constant Interp CIFAR100 CelebA
$\log p(x)$ Unconditional Glow .05 1.0 0.0 .51 .55 .57
Glow Supervised .07 1.0 0.0 .45 .51 .53
IGEBM .63 1.0 .30 .50 .70
JEM (Ours) 1.0 .65
$\max_y p(y \mid x)$ WRN-baseline .85 .62
Class-Conditional Glow .64 0.0 .82 .61 .65 .54
IGEBM .43 .05 .60 .69 .54 .69
JEM (Ours) .89 .41 .84 .75
$\left\| \partial \log p(x) / \partial x \right\|$ Unconditional Glow .99 NaN .27 .46 .29
Class-Conditional Glow .99 NaN .01 .52 .59
IGEBM .84 .99 0.0 .65 .55 .66
JEM (Ours)
Table 7: OOD detection results with CIFAR10 as the in-distribution dataset. Values are AUROC.
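Every entry of Table 7 is an AUROC. For reference, a minimal sketch of how an AUROC can be computed from OOD scores, assuming larger scores should indicate in-distribution inputs; it uses the pairwise (Mann-Whitney) form, which equals the area under the ROC curve.

```python
def auroc(scores_in, scores_out):
    """AUROC for a score that should be larger on in-distribution inputs.

    Computed as the Mann-Whitney statistic: the fraction of (in, out) pairs ranked
    correctly, with ties counted as half.
    """
    wins = 0.0
    for s_in in scores_in:
        for s_out in scores_out:
            if s_in > s_out:
                wins += 1.0
            elif s_in == s_out:
                wins += 0.5
    return wins / (len(scores_in) * len(scores_out))


print(auroc([3.0, 4.0], [1.0, 2.0]), auroc([1.0], [1.0]), auroc([1.0, 2.0], [3.0, 4.0]))
```

A value of 0.5 is chance level, and 0.0 means the score ranks every OOD input above every in-distribution input. The pairwise loop costs O(n_in · n_out); a sort-based rank computation is preferable for full test sets.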

Appendix G Attack Details and Further Robustness Results

We use Foolbox (Rauber et al., 2017) for our experiments. PGD uses binary search to determine minimal epsilons for every input, and we plot the resulting robustness-distortion curves. PGD runs with 20 random restarts and 40 iterations. For the boundary attack, we run default Foolbox settings with one important difference: the random initialization often fails for JEM, so we initialize the attack with a correctly classified input of another class. This other class is chosen based on the top-2 prediction for the image to be attacked. As all our attacks are expensive to run, we only attacked 300 randomly chosen inputs, and the same inputs were used to attack each model.

In Figure 14 we see the results of the boundary attack and pointwise attack on JEM and a baseline. The main purpose of running these attacks was to demonstrate that our model was not able to “cheat” by having vanishing gradients through our gradient-based sampling procedure. Since PGD was more successful than these gradient-free methods, this is clearly not the case: the attacker was able to use the gradients of the sampling procedure to attack our model. Further, we observe the same behavior across all attacks: JEM with 0 steps of sampling is more robust than the baseline, and robustness increases as we add more steps of sampling.

We also compare JEM to the IGEBM of Du and Mordatch (2019) with 10 steps of sampling refinement, see Figure 15. We run the same gradient-based attacks on their model and find that despite not having competitive clean accuracy, it is quite robust to large attacks – especially with respect to the norm. After their model is more robust than ours and after it is more robust than the adversarial training baseline. With respect to the norm their model is more robust than the adversarial training baseline above but remains less robust than JEM until .

We believe these results demonstrate that EBMs are a compelling class of models to explore for further work on building robust models.

(a) Boundary (b) Boundary
(c) Pointwise (d) Pointwise
Figure 14: Gradient-free adversarial attacks.
(a) PGD (b) PGD
Figure 15: PGD attacks comparing JEM to the IGEBM of Du and Mordatch (2019).

g.1 Expectation Over Transformations

Our SGLD-based refinement procedure is stochastic in nature and it has been shown that stochastic defenses to adversarial attacks can provide a false sense of security (Athalye et al., 2018). To deal with this, when we attack our stochastically refined classifiers, we average the classifier’s predictions over multiple samples of this refinement procedure. This makes the defense more deterministic and easier to attack. We redefine the logits of our classifier as:

$$f_{\mathrm{EOT}}(x) = \frac{1}{n} \sum_{i=1}^{n} f_\theta\big(\mathrm{SGLD}_i(x, t)\big) \qquad (11)$$

where $\mathrm{SGLD}_i(x, t)$ denotes the $i$-th of $n$ independent SGLD chains run for $t$ steps seeded at $x$. Intuitively, we draw $n$ different samples from our model seeded at input $x$, then compute $f_\theta$ for each of these samples, then average the results. We then attack these averaged logits with PGD to generate the results in Figure 5. We experimented with different numbers of samples and found that 10 samples yields very similar results to 5 samples on JEM with one refinement step (see Figure 16). Because 10 samples took very long to run on the JEM model with ten refinement steps, we settled on using 5 samples in the results reported in the main body of the paper.
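A minimal sketch of the averaged logits in Eq. (11), assuming a hypothetical `toy_refine` in place of a real SGLD chain and a toy two-class logits function. In an actual attack the averaging would be done inside an autodiff framework so that PGD can differentiate through it; this sketch only shows the forward computation.

```python
import random


def eot_logits(logits_fn, refine_fn, x, n_samples, rng):
    """Average the logits over n_samples independent stochastic refinements of x."""
    total = None
    for _ in range(n_samples):
        z = logits_fn(refine_fn(x, rng))
        total = list(z) if total is None else [a + b for a, b in zip(total, z)]
    return [t / n_samples for t in total]


def toy_logits(x):
    """Hypothetical two-class logits."""
    return [x[0], -x[0]]


def toy_refine(x, rng, sigma=0.1):
    """Hypothetical stand-in for SGLD(x, t): a single Gaussian perturbation."""
    return [xi + sigma * rng.gauss(0.0, 1.0) for xi in x]


rng = random.Random(0)
print(eot_logits(toy_logits, toy_refine, [0.7], n_samples=5, rng=rng))
```

More samples reduce the variance of the averaged logits, which is why increasing from 5 to 10 samples makes the defense more deterministic and at least as easy to attack.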

(a) PGD (b) PGD
Figure 16: Comparing the effect of the number of samples in the EOT attack. We find negligible difference between 5 and 10 for JEM-1 (red and green curves).

g.2 Transfer Attacks

We would like to see whether JEM's refinement procedure can correct adversarially perturbed inputs, i.e., inputs which cause the model to fail. To do this, we generate a series of adversarial examples for JEM-0, with respect to the norm, and test the accuracy of JEM-{1,10} on these examples. Ideally, with further refinement the accuracy will increase. The results of this experiment can be seen in Figure 17. We see here that JEM's refinement procedure can correct for adversarial perturbations.

(a) PGD
Figure 17: PGD transfer attack . We attack JEM-0 and evaluate success of the same adversarial examples under JEM-1 and JEM-10. Whenever an adversarial example is refined back to its correct class, we set the distance to infinity. Note that the adversarial examples do not transfer well from JEM-0 to JEM-1/-10.

Appendix H A Discussion on Samplers

h.1 Improper SGLD

Recall the transition kernel of SGLD:

$$x_{i+1} = x_i - \alpha \frac{\partial E_\theta(x_i)}{\partial x_i} + \sigma \epsilon_i, \qquad \epsilon_i \sim \mathcal{N}(0, I)$$

In the proper formulation of this sampler (Welling and Teh, 2011), the step-size and the variance of the Gaussian noise are related: $\sigma^2 = 2\alpha$. If the step-size is decayed with a polynomial schedule, then samples from SGLD converge to samples from our unnormalized density as the number of steps goes to $\infty$.

In practice, we approximate these samples with a sampler that runs for a finite number of steps. When using the proper step-size to noise ratio, the signal from the gradient is overtaken by the noise when step-sizes are large enough to be informative. In practice the sampler is therefore typically “relaxed”: different values are used for the step-size and the amount of Gaussian noise added, and typically the amount of noise is significantly reduced.

While we are no longer working with a valid MCMC sampler, this approximation has been successfully applied in practice in most recent work scaling EBM training to high dimensional data (Nijkamp et al., 2019b, a; Du and Mordatch, 2019) with the exception of Song and Ermon (2019) (which develops a clever work-around). The model they train is actually an ensemble of models trained on data with different amounts of noise added. They use a proper SGLD sampler decaying the step size as they sample, moving from their high-noise models to their low-noise models. This provides one possible explanation for the compelling results of their model.

In our work we have set the step-size $\alpha = 1$ and the noise standard deviation $\sigma = .01$. We have found these parameters to work well across a variety of datasets, domains, architectures, and sampling procedures (persistent vs. short-run). We believe they are a decent “starting place” for energy functions parameterized by deep neural networks.
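A toy illustration of the difference between proper and relaxed SGLD, assuming a 1-D quadratic energy $E(x) = x^2/2$ (target $\mathcal{N}(0, 1)$) rather than a neural network energy. With $\sigma^2 = 2\alpha$ and a small step, the chains reproduce the target variance; with $\alpha = 1$, $\sigma = .01$, the deterministic part of each step lands exactly on the mode of this particular energy, so the chains collapse onto it plus tiny noise. On a deep energy the gradient scale differs, so this only shows qualitatively that the relaxed sampler does not sample the target density.

```python
import random


def sgld_1d(grad_energy, x0, steps, alpha, sigma, rng):
    """x <- x - alpha * dE/dx + sigma * N(0, 1), the kernel written above."""
    x = x0
    for _ in range(steps):
        x = x - alpha * grad_energy(x) + sigma * rng.gauss(0.0, 1.0)
    return x


def grad_energy(x):
    """E(x) = x^2 / 2, so the target density is a standard normal."""
    return x


def sample_variance(xs):
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / (len(xs) - 1)


def final_states(alpha, sigma, n_chains=1000, steps=600, seed=0):
    """Run independent chains from U(-1, 1) and return their final states."""
    rng = random.Random(seed)
    return [sgld_1d(grad_energy, rng.uniform(-1.0, 1.0), steps, alpha, sigma, rng)
            for _ in range(n_chains)]


# Proper: sigma^2 = 2 * alpha with a small step; the stationary variance is 2 / (2 - alpha), about 1.005.
proper = final_states(alpha=0.01, sigma=(2 * 0.01) ** 0.5)
# Relaxed, with the values used in this work: alpha = 1, sigma = .01.
relaxed = final_states(alpha=1.0, sigma=0.01)
print(sample_variance(proper), sample_variance(relaxed))
```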

h.2 Persistent or Short-run Chains?

Both persistent and short-run Markov chains have been able to successfully train EBMs.

Nijkamp et al. (2019a) presents a careful study of various samplers which can be used and the tradeoffs one makes when choosing one sampler over another. In our work we have found that if computation allows, short-run MCMC chains are preferable in terms of training stability. Given that each step of SGLD requires approximately the computation of 1 training iteration of a standard classifier we are incentivized to find a sampler which can stably train EBMs requiring as few steps as possible per training iteration.

In our experiments, the smallest number of SGLD steps with which we could stably train an EBM at the scale of this work was 80. Even so, these models would eventually diverge late in training. At 80 steps, we found the cost of training to be prohibitively high compared to a standard classifier.

We found that by using persistent Markov chains, we could further reduce the number of steps per iteration to 20 and still allow for relatively stable training. This gave a 4x speedup over our fastest short-run MCMC sampler. Still, this PCD sampler was noticeably less stable than the fastest short-run sampler we could use, but we found the multiple-factor increase in speed to be a worthwhile trade-off.

If time allows, we recommend using a short-run MCMC sampler with a large enough number of steps to be stable. Given that this is not always possible on problems of scale, PCD can be made to work more efficiently, but at the cost of a greater number of stability-related hyper-parameters. These additional parameters include the buffer size and the re-initialization frequency of the Markov chains. We found both to be important for training stability and found no general recipe for setting them. We ran most of our experiments with a re-initialization frequency of .05.
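The two PCD hyperparameters named above both live in the replay buffer. The sketch below is hypothetical (class and method names are ours): it draws distinct chains per batch, resets each with probability `rho`, and additionally records how many SGLD steps each chain has accumulated, which is one way to observe the chain-length effect described next. The usage values follow Table 4 (buffer size 10000, reinitialization frequency .05, 20 SGLD steps).

```python
import random


class ReplayBuffer:
    """Fixed-size PCD buffer that also records each chain's age in SGLD steps."""

    def __init__(self, size, dim, rng):
        self.rng = rng
        self.dim = dim
        # Every chain starts as U(-1, 1) noise with age 0.
        self.states = [self._noise() for _ in range(size)]
        self.ages = [0] * size

    def _noise(self):
        return [self.rng.uniform(-1.0, 1.0) for _ in range(self.dim)]

    def sample(self, batch_size, rho):
        """Pick distinct chains; each is reset to fresh noise with probability rho."""
        idx = self.rng.sample(range(len(self.states)), batch_size)
        starts = []
        for i in idx:
            if self.rng.random() < rho:
                self.states[i] = self._noise()
                self.ages[i] = 0
            starts.append(list(self.states[i]))
        return idx, starts

    def update(self, idx, new_states, steps):
        """Write refined states back and add the number of SGLD steps just taken."""
        for i, x in zip(idx, new_states):
            self.states[i] = list(x)
            self.ages[i] += steps


rng = random.Random(0)
buf = ReplayBuffer(size=10000, dim=3, rng=rng)
idx, starts = buf.sample(batch_size=64, rho=0.05)
# SGLD would refine `starts` here; this sketch writes them back unchanged.
buf.update(idx, starts, steps=20)
print(max(buf.ages))
```

Sampling distinct indices within a batch is a design choice of this sketch: it keeps each chain's age equal to the number of SGLD steps actually applied to it.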

A particularly interesting observation we made while using PCD is that the model would use the length of the Markov chains to encode semantic information. When training models on CIFAR10, we found that young chains could almost always be identified as frogs, while old chains could almost always be identified as cars. This behavior is likely some degeneracy of PCD which would not be possible with short-run MCMC, since all chains have the same length.

h.3 Dealing with Instability

Training a model with the gradient estimator of Eq. (2) can be quite unstable, especially when combined with other objectives, as was the case with all models presented in this work. There exists a “stable region” of sorts when training these models, where the energy values of the true data are in the same range as the energy values of the generated samples. Intuitively, if the generated samples have energies that are not trivially separated from those of the training data, then real learning has to take place. Nijkamp et al. (2019b, a) provide a careful analysis of this, and we refer the reader there for a more in-depth treatment.

We find that when using PCD, occasionally during training a sample will be drawn from the replay buffer that has a considerably higher-than-average energy (higher than the energy of a random initialization). This causes the gradients w.r.t. this example to be orders of magnitude larger than the gradients w.r.t. the rest of the examples, which causes the model to diverge. We tried a number of heuristic approaches such as gradient clipping, energy clipping, ignoring examples with atypical energy values, and many others, but could not find an approach that stabilized training without hurting generative and discriminative performance.

The only two approaches we found to consistently increase the stability of a model which has diverged are to 1) decrease the learning rate and 2) increase the number of SGLD steps in each PCD iteration. Unfortunately, both of these approaches slow down learning. We also had some success simply restarting models from a saved checkpoint with a different random seed. This was the main approach taken unless the model was late into training. In that case, random restarts were less effective, and we increased the number of SGLD steps from 20 to 40, which stabilized training.

While we are very optimistic about the future of large-scale EBMs, we believe these are the most important issues that must be addressed in order for these models to be successful.