Augment and Reduce: Stochastic Inference for Large Categorical Distributions

Categorical distributions are ubiquitous in machine learning, e.g., in classification, language models, and recommendation systems. They are also at the core of discrete choice models. However, when the number of possible outcomes is very large, using categorical distributions becomes computationally expensive, as the complexity scales linearly with the number of outcomes. To address this problem, we propose augment and reduce (A&R), a method to alleviate the computational complexity. A&R uses two ideas: latent variable augmentation and stochastic variational inference. It maximizes a lower bound on the marginal likelihood of the data. Unlike existing methods which are specific to softmax, A&R is more general and is amenable to other categorical models, such as multinomial probit. On several large-scale classification problems, we show that A&R provides a tighter bound on the marginal likelihood and has better predictive performance than existing approaches.





1 Introduction

Categorical distributions are fundamental to many areas of machine learning. Examples include classification (Gupta et al., 2014), language models (Bengio et al., 2006), recommendation systems (Marlin & Zemel, 2004), reinforcement learning (Sutton & Barto, 1998), and neural attention models (Bahdanau et al., 2015). They also play an important role in discrete choice models (McFadden, 1978).

A categorical is a die with K sides: a discrete random variable that takes on one of K unordered outcomes; a categorical distribution gives the probability of each possible outcome. Categorical variables are challenging to use when there are many possible outcomes. Such large categoricals appear in common applications such as image classification with many classes, recommendation systems with many items, and language models over large vocabularies. In this paper, we develop a new method for fitting and using large categorical distributions.

The most common way to form a categorical is through the softmax transformation, which maps a K-vector of reals to a distribution over K outcomes. Let f = (f_1, …, f_K) be a real-valued K-vector. The softmax transformation is

p(y = k | f) = exp(f_k) / Σ_j exp(f_j).   (1)

Note the softmax is not the only way to map real vectors to categorical distributions; for example, the multinomial probit (Albert & Chib, 1993) is an alternative. Also note that in many applications, such as multiclass classification, the parameter f is a function of per-sample features x. For example, a linear classifier forms a categorical over K classes through a linear combination,

f_k = w_k^T x,   k = 1, …, K.
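As a concrete illustration, the transformation and the linear parameterization above can be sketched in a few lines of Python (the function names are ours, not from any released code):

```python
import math

def softmax(f):
    """Map a real K-vector f to a categorical distribution (Eq. 1)."""
    m = max(f)  # subtract the max for numerical stability
    exp_f = [math.exp(fk - m) for fk in f]
    z = sum(exp_f)
    return [e / z for e in exp_f]

def linear_logits(W, x):
    """Mean utilities f_k = w_k^T x for a linear classifier."""
    return [sum(w_d * x_d for w_d, x_d in zip(w_k, x)) for w_k in W]

# Three classes, two features:
probs = softmax(linear_logits([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]], [2.0, 1.0]))
```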
We usually fit a categorical with maximum likelihood estimation or a closely related strategy. Given a dataset y_{1:N} of categorical data, where each y_n is one of K values, we aim to maximize the log likelihood,

L(θ) = Σ_n log p(y_n | f(x_n; θ)).   (2)

Fitting this objective requires evaluating both the log probability and its gradient.

Eqs. 1 and 2 reveal the challenge of using large categoricals. Evaluating the log probability and evaluating its gradient are both O(K) operations. This is a problem because most algorithms for fitting categoricals, such as stochastic gradient ascent, require repeated evaluations of both gradients and probabilities. When K is large, these algorithms are prohibitively expensive.

Here we develop a method for fitting large categorical distributions, including the softmax but also more general models. It is called augment and reduce (A&R). A&R rewrites the categorical distribution with an auxiliary variable ε,

p(y | f) = ∫ p(y, ε | f) dε.   (3)

A&R then replaces the expensive log probability with a variational bound on the integral in Eq. 3. Using stochastic variational methods (Hoffman et al., 2013), the cost to evaluate the bound (or its gradient) is far below O(K).

Because it relies on variational methods, A&R provides a lower bound on the marginal likelihood of the data. With this bound, we can embed A&R in a larger algorithm for fitting a categorical, e.g., a (stochastic) VEM algorithm (Beal, 2003). Though we focus on maximum likelihood, we can also use A&R in other algorithms that require log p(y | f) or its gradient, e.g., fully Bayesian approaches (Gelman et al., 2003) or the reinforce algorithm (Williams, 1992).

We study A&R on linear classification tasks with very large numbers of classes. On simulated and real data, we find that it provides accurate estimates of the categorical probabilities and gives better performance than existing approaches.

Related work.  There are many methods to reduce the cost of large categorical distributions, particularly under the softmax transformation. These include methods that approximate the exact computations (Gopal & Yang, 2013; Vijayanarasimhan et al., 2014), those that rely on sampling (Bengio & Sénécal, 2003; Mikolov et al., 2013; Devlin et al., 2014; Ji et al., 2016; Botev et al., 2017), those that use approximations and distributed computing (Grave et al., 2017), double-sum formulations (Raman et al., 2017; Fagan & Iyengar, 2018), and those that avail themselves of other techniques such as noise contrastive estimation (Smith & Jason, 2005; Gutmann & Hyvärinen, 2010) or random nearest neighbor search (Mussmann et al., 2017).

Other methods change the model. They might replace the softmax transformation with a hierarchical or stick-breaking model (Kurzynski, 1988; Morin & Bengio, 2005; Tsoumakas et al., 2008; Beygelzimer et al., 2009; Dembczyński et al., 2010; Khan et al., 2012). These approaches can be successful, but the structure of the hierarchy may influence the learned probabilities. Other methods replace the softmax with a scalable spherical family of losses (Vincent et al., 2015; de Brébisson & Vincent, 2016).

A&R is different from all of these techniques. Unlike many of them, it provides a lower bound on the log probability rather than an approximation. The bound is useful because it can naturally be embedded in algorithms like stochastic VEM. Further, the A&R methodology applies to transformations beyond the softmax. In this paper, we study large categoricals via softmax, multinomial probit, and multinomial logistic. A&R is the first scalable approach for the latter two models. It accelerates any transformation that can be recast as an additive noise model (e.g., Gumbel, 1954; Albert & Chib, 1993).

The approach most closely related to A&R is the one-vs-each (OVE) bound of Titsias (2016), which is a lower bound on the softmax. Like the other related methods, it is narrower than A&R in that it does not apply to transformations beyond the softmax. We also empirically compare A&R to OVE in Section 4. A&R provides a tighter lower bound and yields better predictive performance.

2 Augment and Reduce

We develop A&R, a method for computing with large categorical random variables.

The utility perspective.  A&R uses the additive noise model perspective on the categorical, which we refer to as the utility perspective. Define a mean utility f_k for each possible outcome k. To draw a variable from a categorical, we draw a zero-mean noise term ε_k for each possible outcome and then choose the value that maximizes the realized utility f_k + ε_k. This corresponds to the following process,

ε_k ~ φ(ε) for k = 1, …, K,   y = argmax_k (f_k + ε_k).   (4)

Note the errors ε_k are drawn fresh each time we draw a variable y. We assume that the errors are independent of each other, independent of the mean utilities f_k, and identically distributed according to some distribution φ(ε).

Now consider the model where we marginalize the errors from Eq. 4. This results in a distribution p(y | f), a categorical that transforms f to the simplex. Depending on the distribution of the errors, this induces different transformations. For example, a standard Gumbel distribution recovers the softmax transformation; a standard Gaussian recovers the multinomial probit transformation; a standard logistic recovers the multinomial logistic transformation.
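The Gumbel case can be checked numerically: drawing standard Gumbel noise and taking the argmax of f_k + ε_k reproduces the softmax probabilities. A pure-Python sketch (the names are ours):

```python
import math, random

def sample_gumbel_max(f, rng):
    """One draw from the categorical via the additive-noise process (Eq. 4)."""
    utilities = []
    for fk in f:
        u = rng.random() or 1e-300          # guard against log(0)
        utilities.append(fk - math.log(-math.log(u)))  # f_k + Gumbel(0, 1) noise
    return max(range(len(f)), key=lambda k: utilities[k])

rng = random.Random(0)
f = [0.5, 1.5, -0.2]
n = 200_000
counts = [0] * len(f)
for _ in range(n):
    counts[sample_gumbel_max(f, rng)] += 1
empirical = [c / n for c in counts]

z = sum(math.exp(fk) for fk in f)
exact = [math.exp(fk) / z for fk in f]      # softmax probabilities (Eq. 1)
```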

Typically, the mean utility is a function of observed features x, e.g., f_k = w_k^T x in linear models or f_k = f_k(x; θ) in non-linear settings. In both cases, the weights w_k (or the parameters θ) are model parameters, relating the features to mean utilities.

Let us focus momentarily on a linear classification problem under the softmax model. For each observation (x_n, y_n), the mean utilities are f_{nk} = w_k^T x_n and the random errors are Gumbel distributed. After marginalizing out the errors, the probability that observation n is in class k is given by Eq. 1, p(y_n = k | x_n, w) = exp(f_{nk}) / Σ_j exp(f_{nj}). Fitting the classifier involves learning the weights w that parameterize the mean utilities. For example, maximum likelihood uses gradient ascent to maximize Σ_n log p(y_n | x_n, w) with respect to w.

Large categoricals.  When the number of outcomes K is large, the normalizing constant of the softmax is a computational burden; it is O(K). Consequently, it is burdensome to calculate useful quantities like log p(y_n | x_n, w) and its gradient. As an ultimate consequence, maximum likelihood estimation is slow; it needs to evaluate the gradient for each observation at each iteration.

This difficulty in scaling is not unique to the softmax. Similar issues arise for the multinomial probit and multinomial logistic. With these transformations as well, evaluating likelihoods and related quantities is O(K).

2.1 Augment and reduce

We introduce A&R to relieve this burden. A&R accelerates training in models with categorical distributions and a large number of outcomes.

Rather than operating directly on the marginal p(y | f), A&R augments the model with one of the error terms and forms a joint p(y, ε | f). (We drop the subscript of ε to avoid cluttered notation.) This augmented model has a desirable property: its log-joint is a sum over all the possible outcomes. A&R then reduces: it subsamples a subset of outcomes to construct estimates of the log-joint and its gradient. As a result, its complexity relates to the size of the subsample, not the total number of outcomes K.

The augmented model.  Let φ(ε) be the distribution over the error terms, and Φ(ε) the corresponding CDF. The marginal probability of outcome k is the probability that its realized utility (f_k + ε_k) is greater than all others,

p(y = k | f) = p(f_k + ε_k ≥ f_j + ε_j for all j ≠ k).

We write this probability as an integral over the k-th error using the CDF of the other errors,

p(y = k | f) = ∫ φ(ε) Π_{j≠k} Φ(ε + f_k − f_j) dε.   (5)

(We renamed the dummy variable ε_k as ε to avoid clutter.) Eq. 5 is the same expression found by Girolami & Rogers (2006) for the multinomial probit model, although we do not assume a Gaussian density φ(ε). Rather, we only assume that we can evaluate both φ(ε) and Φ(ε).

We derived Eq. 5 from the utility perspective, which encompasses many common models. We obtain the softmax by choosing a standard Gumbel distribution for φ(ε), in which case Eqs. 1 and 5 are equivalent. We obtain the multinomial probit by choosing a standard Gaussian distribution over the errors, in which case the integral in Eq. 5 does not have a closed form. Similarly, we obtain the multinomial logistic by choosing a standard logistic distribution φ(ε). What is important is that, regardless of the model, the cost to compute the marginal probability is O(K).
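Although the multinomial probit integral has no closed form, it is one-dimensional and easy to evaluate numerically. The sketch below checks Eq. 5 with a simple trapezoidal rule; the function names and grid settings are our own choices:

```python
import math

def normal_pdf(e):
    return math.exp(-0.5 * e * e) / math.sqrt(2.0 * math.pi)

def normal_cdf(e):
    return 0.5 * (1.0 + math.erf(e / math.sqrt(2.0)))

def probit_prob(f, k, lo=-8.0, hi=8.0, steps=4000):
    """p(y = k | f) as the 1-D integral of Eq. 5 for standard Gaussian errors."""
    h = (hi - lo) / steps
    total = 0.0
    for i in range(steps + 1):
        eps = lo + i * h
        val = normal_pdf(eps)
        for j, fj in enumerate(f):
            if j != k:
                val *= normal_cdf(eps + f[k] - fj)
        total += val * (0.5 * h if i in (0, steps) else h)  # trapezoidal weights
    return total

f = [0.3, -0.5, 1.1, 0.0]
probs = [probit_prob(f, k) for k in range(len(f))]  # should sum to one
```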

We now augment the model with the auxiliary latent variable ε to form the joint distribution

p(y = k, ε | f) = φ(ε) Π_{j≠k} Φ(ε + f_k − f_j).   (6)

This is a model that includes the k-th error term from Eq. 4 but marginalizes out all the other errors. By construction, marginalizing ε from Eq. 6 recovers the original model in Eq. 5. Figure 1 illustrates this idea.

Riihimäki et al. (2013) used Eq. 6 in nested expectation propagation for Gaussian process classification. We use it to scale learning with categorical distributions.

Figure 1: (a) Illustration of the parameterization of a categorical model in terms of the utilities f_k + ε_k, where f_k is the mean utility and ε_k is an error term. The observed outcome is y = argmax_k (f_k + ε_k). (b) In this model, the error terms have been marginalized out. This is the most common model for categorical distributions; it includes the softmax and multinomial probit. (c) The augmented model that we consider for A&R. All error terms have been integrated out, except one. In this model, the log-joint involves a summation over the K possible outcomes, enabling fast unbiased estimates of the log probability and its gradient.

The variational bound.  The augmented model in Eq. 6 involves one latent variable ε. But our goal is to calculate the marginal p(y | f) and its gradient. A&R derives a variational lower bound on log p(y | f) using the joint in Eq. 6. Define q(ε) to be a variational distribution on the auxiliary variable. The bound is log p(y | f) ≥ L, where

L = E_q[log p(y, ε | f) − log q(ε)]
  = E_q[log φ(ε) − log q(ε)] + Σ_{k≠y} E_q[log Φ(ε + f_y − f_k)].   (7)

In Eq. 7, L is the ELBO; it is tight when q(ε) is equal to the posterior of ε given y, p(ε | y, f) (Jordan et al., 1999; Blei et al., 2017).

The ELBO contains a summation over the K outcomes. A&R exploits this property to reduce complexity, as we describe below. Next we show how to use the bound in a VEM procedure and we describe the reduce step of A&R.

Variational expectation maximization.  Consider again a linear classification task, where we have a dataset of features x_n and labels y_n for n = 1, …, N. The mean utility for each observation is f_{nk} = w_k^T x_n, and the goal is to learn the weights w by maximizing the log likelihood Σ_n log p(y_n | x_n, w).

A&R replaces each term log p(y_n | x_n, w) in the data log likelihood with its bound L_n from Eq. 7. The objective becomes Σ_n L_n. Maximizing this objective requires an iterative process with two steps. In one step, A&R optimizes the objective with respect to the weights w. In the other step, A&R optimizes each L_n with respect to the variational distribution q_n(ε_n). The resulting procedure takes the form of a VEM algorithm (Beal, 2003).

The VEM algorithm requires optimizing the ELBO with respect to w and the variational distributions.¹ This is challenging for two reasons. First, the expectations in Eq. 7 might not be tractable. Second, the cost to compute the gradients of Eq. 7 is still O(K).

¹Note that maximizing the ELBO in Eq. 7 with respect to the distribution q(ε) is equivalent to minimizing the Kullback-Leibler divergence from q(ε) to the posterior p(ε | y, f).

Section 3 addresses these issues. To sidestep the intractable expectations, A&R forms unbiased Monte Carlo estimates of the gradient of the ELBO. To alleviate the computational complexity, A&R uses stochastic optimization, subsampling a set of outcomes S ⊂ {1, …, K}.

Reduce by subsampling.  The subsampling step in the VEM procedure is one of the key ideas behind A&R. Since Eq. 7 contains a summation over the K outcomes, we can apply stochastic optimization techniques to obtain unbiased estimates of the ELBO and its gradient.

More specifically, consider the gradient of the ELBO in Eq. 7 with respect to w (the parameters of f). It is

∇_w L = Σ_{k≠y} E_q[∇_w log Φ(ε + f_y − f_k)].

A&R estimates this by first randomly sampling a subset S of outcomes, of size |S|. A&R then uses the outcomes in S to approximate the gradient,

∇_w L ≈ ((K − 1) / |S|) Σ_{k∈S} E_q[∇_w log Φ(ε + f_y − f_k)].

This is an unbiased estimator² of the gradient ∇_w L. Crucially, A&R only needs to iterate over |S| outcomes to obtain it, reducing the complexity to O(|S|).

²This is not the only way to construct an unbiased estimator. Alternatively, we can draw the outcomes using importance sampling, taking into account the frequency of each class. We leave this for future work.

The reduce step is also applicable to optimize the ELBO with respect to q(ε). Section 3 gives further details about the stochastic VEM procedure in different settings.
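The unbiasedness behind the reduce step can be seen on a plain sum: scaling a uniformly subsampled sum by K/|S| recovers the full sum over classes in expectation. A small numerical check (illustrative names, not the paper's code):

```python
import random

rng = random.Random(1)
K, m = 200, 40                                       # number of classes, subsample size
terms = [rng.uniform(-1.0, 1.0) for _ in range(K)]   # stand-ins for per-class terms
full_sum = sum(terms)

def subsampled_estimate(terms, m, rng):
    """(K / |S|) times the sum over a uniform subset S of size m: unbiased for the full sum."""
    subset = rng.sample(range(len(terms)), m)
    return (len(terms) / m) * sum(terms[i] for i in subset)

# Averaging many independent estimates recovers the full sum:
avg = sum(subsampled_estimate(terms, m, rng) for _ in range(20_000)) / 20_000
```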

3 Algorithm Description

Here we provide the details to run the VEM algorithm for the softmax model (Section 3.1) and for more general models, including the multinomial probit and multinomial logistic (Section 3.2). These models only differ in the distribution φ(ε) over the errors.

A&R is not limited to point-mass estimation of the parameters w. It is straightforward to extend the algorithm to perform posterior inference on w via stochastic variational inference, but for simplicity we describe maximum likelihood estimation.

3.1 Augment and Reduce for Softmax

In the softmax model, the distribution over the error terms is a standard Gumbel (Gumbel, 1954),

φ(ε) = exp(−ε − exp(−ε)).

In this model, the optimal distribution q*(ε), which achieves equality in the bound, has a closed-form expression: it is a Gumbel distribution,

q*(ε) = Gumbel(η*),   with η* = log Σ_k exp(f_k − f_y).

However, even though q*(ε) has an analytic form, its parameter η* is computationally expensive to obtain because it involves a summation over K classes. Instead, we set q(ε) = Gumbel(η) and fit the parameter η.

Substituting this choice for q(ε) into Eq. 7 gives the following ELBO:

L = 1 − η − exp(−η) Σ_k exp(f_k − f_y).   (8)

Eq. 8 coincides with the log-concavity bound (Bouchard, 2007; Blei & Lafferty, 2007), although we have derived it from a completely different perspective. This derivation allows us to optimize η efficiently, as we describe next.

The Gumbel distribution q(ε) = Gumbel(η) is an exponential-family distribution whose natural parameter is exp(η). This allows us to use natural gradients in the stochastic inference procedure. A&R iterates between a local step, in which we update q(ε_n), and a global step, in which we update the parameters w.

In the local step (E step), we optimize q(ε) by taking a step in the direction of the noisy natural gradient, yielding exp(η^(t)) = (1 − ρ_t) exp(η^(t−1)) + ρ_t exp(η̂). Here, η̂ is an estimate of the optimal parameter η*, which we obtain using a random set S of outcomes, i.e., η̂ = log(1 + ((K − 1)/|S|) Σ_{k∈S} exp(f_k − f_y)), where S ⊆ {1, …, K} \ {y}. The parameter ρ_t is the step size; it must satisfy the Robbins-Monro conditions (Robbins & Monro, 1951; Hoffman et al., 2013).

In the global step (M step), we take a gradient step with respect to w (the parameters of f), holding q(ε) fixed. Similarly, we can estimate the gradient of Eq. 8 with complexity O(|S|) by leveraging stochastic optimization.

Algorithm 1 summarizes the procedure for a classification task. In this example, the dataset consists of N datapoints (x_n, y_n), where x_n is a feature vector and y_n ∈ {1, …, K} is the class label. Each observation is associated with its mean utilities, e.g., f_{nk} = w_k^T x_n. We posit a softmax likelihood, and we wish to infer the weights w via maximum likelihood using A&R. Thus, the objective function is Σ_n L_n. (It is straightforward to obtain the maximum a posteriori solution by adding a regularizer.) At each iteration, we process a random subset of observations as well as a random subset of classes for each one.
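Eq. 8 can be verified numerically: for any η it lower-bounds the exact log-softmax probability, with equality at the optimal η* = log Σ_k exp(f_k − f_y). A small check in our own notation:

```python
import math

def log_softmax(f, y):
    m = max(f)
    return (f[y] - m) - math.log(sum(math.exp(fk - m) for fk in f))

def softmax_elbo(f, y, eta):
    """The A&R softmax bound of Eq. 8 (the log-concavity bound)."""
    s = sum(math.exp(fk - f[y]) for fk in f)
    return 1.0 - eta - math.exp(-eta) * s

f, y = [0.2, -1.0, 0.7, 0.1], 2
exact = log_softmax(f, y)
eta_star = math.log(sum(math.exp(fk - f[y]) for fk in f))
# The gap is nonnegative for every eta and zero at eta_star:
gaps = [exact - softmax_elbo(f, y, eta) for eta in (-2.0, 0.0, eta_star, 3.0)]
```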

Finally, note that we can perform posterior inference on the parameters w (instead of maximum likelihood) using A&R. One way is to consider a variational distribution q(w) and take gradient steps with respect to the variational parameters of q(w) in the global step, using the reparameterization trick (Rezende et al., 2014; Titsias & Lázaro-Gredilla, 2014; Kingma & Welling, 2014) to approximate that gradient. In the local step, we only need to evaluate the moment generating function of the mean utilities under q(w) to estimate the optimal natural parameter.

  Input: data {(x_n, y_n)}, minibatch sizes |B| and |S|
  Output: weights w
  Initialize all weights w and natural parameters η_n
  for iteration t = 1, 2, … do
     # Sample minibatches:
     Sample a minibatch of data, B ⊆ {1, …, N}
     for n ∈ B do
        Sample a set of labels, S_n ⊆ {1, …, K} \ {y_n}
     end for
     # Local step (E step):
     for n ∈ B do
        Update natural param., exp(η_n) ← (1 − ρ_t) exp(η_n) + ρ_t exp(η̂_n)
     end for
     # Global step (M step):
     Gradient step on the weights w (gradient of Eq. 8, estimated on B and S_n)
  end for
Algorithm 1 Softmax A&R for classification
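To make the procedure concrete, here is a toy, self-contained rendition of Algorithm 1 that fits only the biases b_k of a softmax from sampled labels (no covariates, as in the synthetic setup of Section 4.1). The hyperparameters and step-size schedules are our own choices for this sketch, not values from the paper:

```python
import math, random

rng = random.Random(0)
K, N = 10, 5000
true_b = [rng.uniform(-1.0, 1.0) for _ in range(K)]
z = sum(math.exp(t) for t in true_b)
true_p = [math.exp(t) / z for t in true_b]
data = rng.choices(range(K), weights=true_p, k=N)
emp_p = [data.count(k) / N for k in range(K)]    # empirical class frequencies

b = [0.0] * K                  # mean utilities (biases only)
eta = [0.0] * N                # one local variational parameter per datapoint
B, S = 100, 3                  # minibatch sizes |B| and |S|

for t in range(1, 3001):
    rho, lr = 0.3, 1.0 / math.sqrt(t)            # step sizes (our choice)
    batch = rng.sample(range(N), B)
    grad = [0.0] * K
    for n in batch:
        y = data[n]
        classes = rng.sample([k for k in range(K) if k != y], S)
        # Local step: noisy estimate of eta* = log sum_k exp(b_k - b_y),
        # blended with the old value in natural-parameter (exp(eta)) space.
        s_hat = 1.0 + ((K - 1) / S) * sum(math.exp(b[k] - b[y]) for k in classes)
        eta[n] = math.log((1 - rho) * math.exp(eta[n]) + rho * s_hat)
        # Global step gradient of Eq. 8 with respect to the biases:
        coef = math.exp(-eta[n]) * (K - 1) / S
        for k in classes:
            g = coef * math.exp(b[k] - b[y])
            grad[k] -= g      # d/d b_k of -exp(-eta) * exp(b_k - b_y)
            grad[y] += g      # d/d b_y picks up the opposite sign
    b = [bk + lr * gk / B for bk, gk in zip(b, grad)]

z_hat = sum(math.exp(bk) for bk in b)
est_p = [math.exp(bk) / z_hat for bk in b]
tv = 0.5 * sum(abs(e - p) for e, p in zip(est_p, emp_p))  # total variation distance
```

The fitted probabilities should closely track the empirical class frequencies, which is the maximum likelihood solution in this bias-only setting.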

3.2 Augment and Reduce for Other Models

For most models, the expectations of the ELBO in Eq. 7 are intractable, and there is no closed-form solution for the optimal variational distribution q*(ε). Fortunately, we can still apply A&R, using the reparameterization trick to build Monte Carlo estimates of the gradient of the ELBO with respect to the variational parameters (Rezende et al., 2014; Titsias & Lázaro-Gredilla, 2014; Kingma & Welling, 2014).

In more detail, consider the variational distribution q(ε; α), parameterized by some variational parameters α. We assume that this distribution is reparameterizable, i.e., we can sample from q(ε; α) by first sampling an auxiliary variable u ~ q_0(u) and then setting ε = h(u; α).

In the local step, we fit q(ε; α) by taking a gradient step of the ELBO with respect to the variational parameters α. Since the expectations in Eq. 7 are not tractable, we obtain Monte Carlo estimates by sampling from the variational distribution. To sample ε, we sample u ~ q_0(u) and set ε = h(u; α). To alleviate the computational complexity, we apply the reduce step, sampling a random subset S of outcomes. We thus form a one-sample gradient estimator as

∇̂_α L = ∇_ε L̂(ε) ∇_α h(u; α) + ∇_α H[q(ε; α)],   (9)

where H[q(ε; α)] is the entropy of the variational distribution,³ and L̂(ε) is an estimate of the log-joint,

L̂(ε) = log φ(ε) + ((K − 1)/|S|) Σ_{k∈S} log Φ(ε + f_y − f_k).

In the global step, we estimate the gradient of the ELBO with respect to w. Following a similar approach, we obtain an unbiased one-sample gradient estimator as ∇̂_w L = ∇_w L̂(ε).

³We can estimate the gradient of the entropy when it is not available analytically. Even when it is, the Monte Carlo estimator may have lower variance (Roeder et al., 2017).
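As an illustration for the multinomial probit (Gaussian errors, Gaussian q(ε; α) = N(μ, σ²)), the one-sample estimator of Eq. 9 can be written out by hand. For clarity we fix the subsampled classes and the auxiliary draw u, which in practice would be sampled; all names are ours:

```python
import math

def normal_pdf(e):
    return math.exp(-0.5 * e * e) / math.sqrt(2.0 * math.pi)

def normal_cdf(e):
    return 0.5 * (1.0 + math.erf(e / math.sqrt(2.0)))

def probit_elbo_grads(f, y, mu, sigma, subset, u):
    """One-sample estimate (Eq. 9) of the ELBO gradients w.r.t. (mu, sigma)."""
    K = len(f)
    eps = mu + sigma * u                  # reparameterization eps = h(u; mu, sigma)
    d_eps = -eps                          # d/d eps of log phi(eps), standard Gaussian
    scale = (K - 1) / len(subset)
    for k in subset:
        d = eps + f[y] - f[k]
        d_eps += scale * normal_pdf(d) / normal_cdf(d)   # d/d eps of log Phi(d)
    grad_mu = d_eps                       # dh/dmu = 1
    grad_sigma = d_eps * u + 1.0 / sigma  # dh/dsigma = u; entropy gradient is 1/sigma
    return grad_mu, grad_sigma

f, y = [0.4, -0.3, 1.2, 0.0, 0.1], 2      # K = 5 mean utilities, observed class y = 2
g_mu, g_sigma = probit_elbo_grads(f, y, mu=0.0, sigma=1.0, subset=[0, 4], u=0.7)
```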

  Input: data {(x_n, y_n)}, minibatch sizes |B| and |S|
  Output: weights w
  Initialize all weights w and local variational parameters α_n
  for iteration t = 1, 2, … do
     # Sample minibatches:
     Sample a minibatch of data, B ⊆ {1, …, N}
     for n ∈ B do
        Sample a set of labels, S_n ⊆ {1, …, K} \ {y_n}
     end for
     # Local step (E step):
     for n ∈ B do
        Sample auxiliary variable u_n ~ q_0(u)
        Transform auxiliary variable, ε_n = h(u_n; α_n)
        Estimate the gradient (Eq. 9)
        Update variational param., α_n ← α_n + ρ_t ∇̂_{α_n} L_n
     end for
     # Global step (M step):
     Sample ε_n for all n ∈ B
     Gradient step on the weights w (estimated on B and S_n)
  end for
Algorithm 2 General A&R for classification

Algorithm 2 summarizes the procedure to efficiently run maximum likelihood on a classification problem. We subsample observations and classes at each iteration.

Finally, note that we can perform posterior inference on the parameters w by positing a variational distribution q(w) and taking gradient steps with respect to the variational parameters of q(w) in the global step. In this case, the reparameterization trick is needed in both the local and global steps to obtain Monte Carlo estimates of the gradient.

We now particularize A&R for the multinomial probit and multinomial logistic models.

A&R for multinomial probit.  Consider a standard Gaussian distribution over the error terms,

φ(ε) = N(ε; 0, 1).

A&R chooses a Gaussian variational distribution q(ε; α) = N(ε; μ, σ²) and fits the variational parameters α = (μ, σ). The Gaussian is reparameterizable in terms of a standard Gaussian, i.e., u ~ N(0, 1). The transformation is ε = h(u; α) = μ + σu. Thus, the gradients in Eq. 9 are ∇_α h(u; α) = (1, u) and ∇_α H[q(ε; α)] = (0, 1/σ).

A&R for multinomial logistic.  Consider now a standard logistic distribution over the errors,

φ(ε) = σ(ε)(1 − σ(ε)),

where σ(·) is the sigmoid function. (The logistic distribution has heavier tails than the Gaussian.) Under this model, the CDF is Φ(ε) = σ(ε), and the ELBO in Eq. 7 takes the form

L = E_q[log φ(ε) − log q(ε)] + Σ_{k≠y} E_q[log σ(ε + f_y − f_k)].

Note the close resemblance between this expression and the OVE bound of Titsias (2016),

log p(y | f) ≥ Σ_{k≠y} log σ(f_y − f_k).   (10)

However, while the former is a bound on the multinomial logistic model, the OVE is a bound on the softmax.

A&R sets q(ε; α) to be a logistic distribution with location μ and scale s; the variational parameters are α = (μ, s). The logistic distribution is reparameterizable, with u ~ Uniform(0, 1) and transformation ε = h(u; α) = μ + s log(u / (1 − u)). The gradient of the entropy in Eq. 9 is ∇_α H[q(ε; α)] = (0, 1/s).
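The OVE expression in Eq. 10 is indeed a lower bound on the exact softmax log probability for arbitrary utilities, which can be spot-checked numerically (illustrative sketch):

```python
import math, random

def log_sigmoid(x):
    return -math.log1p(math.exp(-x))

def ove_bound(f, y):
    """OVE: sum over k != y of log sigmoid(f_y - f_k) (Titsias, 2016), Eq. 10."""
    return sum(log_sigmoid(f[y] - fk) for k, fk in enumerate(f) if k != y)

def log_softmax(f, y):
    m = max(f)
    return (f[y] - m) - math.log(sum(math.exp(fk - m) for fk in f))

rng = random.Random(0)
checks = []
for _ in range(100):
    f = [rng.uniform(-3.0, 3.0) for _ in range(8)]
    y = rng.randrange(8)
    checks.append(ove_bound(f, y) <= log_softmax(f, y) + 1e-12)
```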

4 Experiments

We showcase A&R on linear classification tasks. Our goal is to assess the predictive performance of A&R, to assess the quality of the bound on the marginal likelihood of the data, and to compare its complexity⁴ with existing approaches.

⁴We focus on runtime cost. A&R requires additional memory to store the local variational parameters.

We run A&R for three different models of categorical distributions (softmax, multinomial probit, and multinomial logistic).⁵ For the softmax model, we compare A&R against the OVE bound (Titsias, 2016). Just like A&R, OVE is a rigorous lower bound on the marginal likelihood. It can also run on a single machine,⁶ and it has been shown to outperform other approaches.

⁵Code for A&R is available online.

⁶A&R is amenable to an embarrassingly parallel algorithm, but we focus on single-core procedures.

For softmax, A&R runs nearly as fast as OVE but has better predictive performance and provides a tighter bound on the marginal likelihood. On two small datasets, the A&R bound closely approaches the marginal likelihood of exact softmax maximum likelihood estimation.

We now describe the experimental settings. In Section 4.1, we analyze synthetic data with a large number of classes. In Section 4.2, we analyze five real datasets.

Experimental setup.  We consider linear classification, where the mean utilities are f_{nk} = w_k^T x_n + b_k. We fit the model parameters (weights and biases) via maximum likelihood estimation, using stochastic gradient ascent. We initialize the weights and biases randomly, drawing from a Gaussian distribution with zero mean and small standard deviation. For each experiment, we use the same initialization across all methods.

Algorithms 1 and 2 require setting a step size schedule for the gradient steps on the weights. We use the adaptive step size sequence proposed by Kucukelbir et al. (2017), which combines rmsprop (Tieleman & Hinton, 2012) and Adagrad (Duchi et al., 2011), with its default parameters; we additionally decrease the step size by a constant factor at regular intervals. We use the same step size sequence for OVE.

We set the step size ρ_t in Algorithm 1 following the default values suggested by Hoffman et al. (2013), and we proceed similarly for the step size in Algorithm 2. For the multinomial logistic and multinomial probit A&R, we parameterize the variational distributions in terms of their means and an unconstrained parameter that is transformed to yield a positive scale.

4.1 Synthetic Dataset

We mimic the toy experiment of Titsias (2016) to assess how well A&R estimates the categorical probabilities. We generate a dataset with a large number of classes, assigning each observation the label k with probability p_k, where each p_k is proportional to a draw from a uniform distribution. After generating the data, we count the number of effective classes (classes that appear at least once) and use this value for K. In this simple setting, there are no observed covariates x_n.

We estimate the probabilities via maximum likelihood on the biases b_k. We posit a softmax model, and we apply both the VEM of Section 3.1 and the OVE bound. For both approaches, we choose a fixed minibatch size of observations and classes, and we run a fixed number of iterations.

Table 1: Statistics and experimental settings of the considered datasets (number of covariates, number of classes, training and test set sizes, minibatch sizes, and number of iterations). The number of classes is the value resulting after the preprocessing step (see text). The minibatch sizes correspond to |B| and |S|, respectively.
Figure 2: Evolution of the ELBO as a function of wall-clock time. The softmax A&R method provides a tighter bound than OVE (Titsias, 2016) for almost all the considered datasets.

We run each approach on one CPU core. On average, the wall-clock time per epoch is similar for softmax A&R and for OVE. A&R is slightly slower because of the local step that OVE does not require; however, the bound on the marginal log likelihood is tighter (by orders of magnitude) for A&R than for OVE. The estimated probabilities are similar for both methods: the difference in average absolute error is not statistically significant.

4.2 Real Datasets

We now turn to real datasets. We consider MNIST and Bibtex (Katakis et al., 2008; Prabhu & Varma, 2014), where we can compare against the exact softmax. We also analyze Omniglot (Lake et al., 2015), EURLex-4K (Mencia & Furnkranz, 2008; Bhatia et al., 2015), and AmazonCat-13K (McAuley & Leskovec, 2013).⁷ Table 1 gives information about the structure of these datasets.

⁷All five datasets are publicly available online.

We run each method for a fixed number of iterations. We set the minibatch sizes and beforehand. The specific values for each dataset are also in Table 1.

Table 2: Average time per epoch for each method (OVE (Titsias, 2016) and A&R under softmax, multinomial probit, and multinomial logistic) on each dataset. Softmax A&R (Section 3.1) is almost as fast as OVE. The A&R approaches of Section 3.2 take longer because they require some additional computations, but they are still competitive.
Table 3: Test log likelihood and accuracy for each method (exact softmax, OVE (Titsias, 2016), and A&R under softmax, multinomial probit, and multinomial logistic) on each dataset. Among the softmax-based approaches, softmax A&R outperforms OVE on four out of the five datasets. A&R is also competitive for the other models (multinomial probit and multinomial logistic).

Data preprocessing.  For MNIST, we divide the pixel values by 255 so that the maximum value is one. For Omniglot, following other works in the literature (e.g., Burda et al., 2016), we downsize the images. For EURLex-4K and AmazonCat-13K, we normalize the covariates, dividing each by its maximum value.

Bibtex, EURLex-4K, and AmazonCat-13K are multi-label datasets, i.e., each observation may be assigned more than one label. Following Titsias (2016), we keep only the first non-zero label for each data point. See Table 1 for the resulting number of classes in each case.

Evaluation.  For the softmax, we compare A&R against the OVE bound.⁸ We also compare against the exact softmax on MNIST and Bibtex, where the number of classes is small. For the multinomial probit and multinomial logistic models, we also report the predictive performance of A&R.

⁸We also implemented the approach of Botev et al. (2017), but we do not report the results because it did not outperform OVE in terms of test log likelihood on four out of the five considered datasets. On the fifth dataset, softmax A&R was still superior.

We evaluate performance with test log likelihood and accuracy. The accuracy is the fraction of correctly classified instances, assuming that we assign the most likely label (i.e., the one with the highest mean utility). To compute the test log likelihood, we use Eq. 1 for the softmax and Eq. 5 for the multinomial probit and multinomial logistic models. We approximate the integral in Eq. 5 using importance sampling with a Gaussian proposal.
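The importance-sampling approximation of Eq. 5 can be sketched as follows; the proposal parameters and sample count here are our own choices, not the values used in the experiments:

```python
import math, random

def normal_pdf(e, mu=0.0, sd=1.0):
    z = (e - mu) / sd
    return math.exp(-0.5 * z * z) / (sd * math.sqrt(2.0 * math.pi))

def normal_cdf(e):
    return 0.5 * (1.0 + math.erf(e / math.sqrt(2.0)))

def probit_prob_is(f, k, n_samples=100_000, rng=random.Random(0)):
    """Importance-sampling estimate of Eq. 5 for the multinomial probit."""
    total = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 2.0)                        # proposal N(0, 2^2), our choice
        w = normal_pdf(eps) / normal_pdf(eps, 0.0, 2.0)  # importance weight phi/proposal
        for j, fj in enumerate(f):
            if j != k:
                w *= normal_cdf(eps + f[k] - fj)
        total += w
    return total / n_samples

f = [0.3, -0.5, 1.1]
probs = [probit_prob_is(f, k) for k in range(3)]  # should approximately sum to one
```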

Results.  Table 2 shows the wall-clock time per epoch for each method and dataset. In general, softmax A&R is almost as fast as OVE because the extra local step can be performed without additional expensive operations; it only requires evaluating exponential functions that can be reused in the global step. Multinomial probit A&R and multinomial logistic A&R are slightly slower because of their local steps, but they are still competitive.

For the five datasets, Figure 2 shows the evolution of the ELBO as a function of wall-clock time for softmax A&R (Eq. 8), compared to OVE (Eq. 10). For easier visualization, we plot a smoothed version of the bounds after applying a moving-average window. (For AmazonCat-13K, we compute the ELBO only periodically and use a larger window.) Softmax A&R provides a significantly tighter bound for most datasets (except for Bibtex, where the ELBO of A&R is close to the OVE bound). For MNIST and Bibtex, we also plot the marginal likelihood obtained after running maximum likelihood estimation on the exact softmax model. The ELBO of A&R nearly achieves this value.

Finally, Table 3 shows the predictive performance for all methods across all datasets. We report test log likelihood and accuracy. Softmax A&R outperforms OVE in both metrics on all datasets except EURLex-4K. Although our goal is not to compare performance across different models, for completeness Table 3 also shows the predictive performance of multinomial probit A&R and multinomial logistic A&R. In general, softmax A&R provides the highest test log likelihood, but multinomial probit A&R outperforms all other methods on EURLex-4K and AmazonCat-13K. Additionally, multinomial logistic A&R presents better predictive performance than OVE on Omniglot and Bibtex.

5 Conclusion

We have introduced A&R, a scalable method for fitting models that involve categorical distributions. A&R is general and applicable to many models, including the softmax and the multinomial probit. On classification tasks, we found that A&R outperforms state-of-the-art algorithms at little extra computational cost.


Acknowledgments

This work was supported by ONR N00014-15-1-2209, ONR 133691-5102004, NIH 5100481-5500001084, NSF CCF-1740833, the Alfred P. Sloan Foundation, the John Simon Guggenheim Foundation, Facebook, Amazon, and IBM. Francisco J. R. Ruiz is supported by the EU Horizon 2020 programme (Marie Skłodowska-Curie Individual Fellowship, grant agreement 706760). We also thank Victor Elvira and Pablo Moreno for their comments and help.


References

  • Albert & Chib (1993) Albert, J. H. and Chib, S. Bayesian analysis of binary and polychotomous response data. Journal of the American Statistical Association, 88(422):669–679, 1993.
  • Bahdanau et al. (2015) Bahdanau, D., Cho, K., and Bengio, Y. Neural machine translation by jointly learning to align and translate. In International Conference on Learning Representations, 2015.
  • Beal (2003) Beal, M. J. Variational algorithms for approximate Bayesian inference. PhD thesis, Gatsby Computational Neuroscience Unit, University College London, 2003.
  • Bengio & Sénécal (2003) Bengio, Y. and Sénécal, J.-S. Quick training of probabilistic neural nets by importance sampling. In Artificial Intelligence and Statistics, 2003.
  • Bengio et al. (2006) Bengio, Y., Schwenk, H., Senécal, J.-S., Morin, F., and Gauvain, J.-L. Neural probabilistic language models. In Innovations in Machine Learning. Springer, 2006.
  • Beygelzimer et al. (2009) Beygelzimer, A., Langford, J., Lifshits, Y., Sorkin, G. B., and Strehl, A. L. Conditional probability tree estimation analysis and algorithms. In Uncertainty in Artificial Intelligence, 2009.
  • Bhatia et al. (2015) Bhatia, K., Jain, H., Kar, P., Varma, M., and Jain, P. Sparse local embeddings for extreme multi-label classification. In Advances in Neural Information Processing Systems, 2015.
  • Blei et al. (2017) Blei, D. M., Kucukelbir, A., and McAuliffe, J. D. Variational inference: A review for statisticians. Journal of the American Statistical Association, 112(518):859–877, 2017.
  • Blei & Lafferty (2007) Blei, D. M. and Lafferty, J. D. A correlated topic model of Science. The Annals of Applied Statistics, 1(1):17–35, 2007.
  • Botev et al. (2017) Botev, A., Zheng, B., and Barber, D. Complementary sum sampling for likelihood approximation in large scale classification. In Artificial Intelligence and Statistics, 2017.
  • Bouchard (2007) Bouchard, G. Efficient bounds for the softmax and applications to approximate inference in hybrid models. In Advances in Neural Information Processing Systems, Workshop on Approximate Inference in Hybrid Models, 2007.
  • Burda et al. (2016) Burda, Y., Grosse, R., and Salakhutdinov, R. Importance weighted autoencoders. In International Conference on Learning Representations, 2016.
  • de Brébisson & Vincent (2016) de Brébisson, A. and Vincent, P. An exploration of softmax alternatives belonging to the spherical loss family. In International Conference on Learning Representations, 2016.
  • Dembczyński et al. (2010) Dembczyński, K., Cheng, W., and Hüllermeier, E. Bayes optimal multilabel classification via probabilistic classifier chains. In International Conference on Machine Learning, 2010.
  • Devlin et al. (2014) Devlin, J., Zbib, R., Huang, Z., Lamar, T., Schwartz, R., and Makhoul, J. Fast and robust neural network joint models for statistical machine translation. In Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), 2014.
  • Duchi et al. (2011) Duchi, J., Hazan, E., and Singer, Y. Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research, 12:2121–2159, Jul 2011.
  • Fagan & Iyengar (2018) Fagan, F. and Iyengar, G. Unbiased scalable softmax optimization. arXiv:1803.08577, 2018.
  • Gelman et al. (2003) Gelman, A., Carlin, J. B., Stern, H. S., and Rubin, D. B. Bayesian data analysis. Chapman and Hall/CRC, 2003.
  • Girolami & Rogers (2006) Girolami, M. and Rogers, S. Variational Bayesian multinomial probit regression with Gaussian process priors. Neural Computation, 18(8):1790–1817, 2006.
  • Gopal & Yang (2013) Gopal, S. and Yang, Y. Distributed training of large-scale logistic models. In International Conference on Machine Learning, 2013.
  • Grave et al. (2017) Grave, E., Joulin, A., Cissé, M., Grangier, D., and Jégou, H. Efficient softmax approximation for GPUs. arXiv:1609.04309, 2017.
  • Gumbel (1954) Gumbel, E. J. Statistical theory of extreme values and some practical applications: A series of lectures. U. S. Govt. Print. Office, 1954.
  • Gupta et al. (2014) Gupta, M. R., Bengio, S., and Weston, J. Training highly multiclass classifiers. Journal of Machine Learning Research, 15(1):1461–1492, 2014.
  • Gutmann & Hyvärinen (2010) Gutmann, M. and Hyvärinen, A. Noise-contrastive estimation: A new estimation principle for unnormalized statistical models. In Artificial Intelligence and Statistics, 2010.
  • Hoffman et al. (2013) Hoffman, M. D., Blei, D. M., Wang, C., and Paisley, J. Stochastic variational inference. Journal of Machine Learning Research, 14:1303–1347, May 2013.
  • Ji et al. (2016) Ji, S., Vishwanathan, S. V. N., Satish, N., Anderson, M. J., and Dubey, P. BlackOut: Speeding up recurrent neural network language models with very large vocabularies. In International Conference on Learning Representations, 2016.
  • Jordan et al. (1999) Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., and Saul, L. K. An introduction to variational methods for graphical models. Machine Learning, 37(2):183–233, November 1999.
  • Katakis et al. (2008) Katakis, I., Tsoumakas, G., and Vlahavas, I. Multilabel text classification for automated tag suggestion. In ECML/PKDD Discovery Challenge, 2008.
  • Khan et al. (2012) Khan, M. E., Mohamed, S., Marlin, B. M., and Murphy, K. P. A stick-breaking likelihood for categorical data analysis with latent Gaussian models. In Artificial Intelligence and Statistics, 2012.
  • Kingma & Welling (2014) Kingma, D. P. and Welling, M. Auto-encoding variational Bayes. In International Conference on Learning Representations, 2014.
  • Kucukelbir et al. (2017) Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., and Blei, D. M. Automatic differentiation variational inference. Journal of Machine Learning Research, 18(14):1–45, 2017.
  • Kurzynski (1988) Kurzynski, M. On the multistage Bayes classifier. Pattern Recognition, 21(4):355–365, 1988.
  • Lake et al. (2015) Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B. Human-level concept learning through probabilistic program induction. Science, 350(6266):1332–1338, 2015.
  • Marlin & Zemel (2004) Marlin, B. M. and Zemel, R. S. The multiple multiplicative factor model for collaborative filtering. In International Conference on Machine Learning, 2004.
  • McAuley & Leskovec (2013) McAuley, J. and Leskovec, J. Hidden factors and hidden topics: understanding rating dimensions with review text. In ACM Conference on Recommender Systems, 2013.
  • McFadden (1978) McFadden, D. Modeling the choice of residential location. In Spatial Interaction Theory and Residential Location, 1978.
  • Mencía & Fürnkranz (2008) Mencía, E. L. and Fürnkranz, J. Efficient pairwise multilabel classification for large-scale problems in the legal domain. In ECML/PKDD, 2008.
  • Mikolov et al. (2013) Mikolov, T., Sutskever, I., Chen, K., Corrado, G. S., and Dean, J. Distributed representations of words and phrases and their compositionality. In Advances in Neural Information Processing Systems, 2013.
  • Morin & Bengio (2005) Morin, F. and Bengio, Y. Hierarchical probabilistic neural network language model. In Artificial Intelligence and Statistics, 2005.
  • Mussmann et al. (2017) Mussmann, S., Levy, D., and Ermon, S. Fast amortized inference and learning in log-linear models with randomly perturbed nearest neighbor search. In Uncertainty in Artificial Intelligence, 2017.
  • Prabhu & Varma (2014) Prabhu, Y. and Varma, M. FastXML: Fast, accurate and stable tree-classifier for eXtreme multi-label learning. In KDD, 2014.
  • Raman et al. (2017) Raman, P., Matsushima, S., Zhang, X., Yun, H., and Vishwanathan, S. V. N. DS-MLR: exploiting double separability for scaling up distributed multinomial logistic regression. arXiv:1604.04706v2, 2017.
  • Rezende et al. (2014) Rezende, D. J., Mohamed, S., and Wierstra, D. Stochastic backpropagation and approximate inference in deep generative models. In International Conference on Machine Learning, 2014.
  • Riihimäki et al. (2013) Riihimäki, J., Jylänki, P., and Vehtari, A. Nested expectation propagation for Gaussian process classification with a multinomial probit likelihood. Journal of Machine Learning Research, 14:75–109, 2013.
  • Robbins & Monro (1951) Robbins, H. and Monro, S. A stochastic approximation method. The Annals of Mathematical Statistics, 22(3):400–407, 1951.
  • Roeder et al. (2017) Roeder, G., Wu, Y., and Duvenaud, D. Sticking the landing: Simple, lower-variance gradient estimators for variational inference. In Advances in Neural Information Processing Systems, 2017.
  • Smith & Eisner (2005) Smith, N. A. and Eisner, J. Contrastive estimation: Training log-linear models on unlabeled data. In Annual Meeting of the Association for Computational Linguistics, 2005.
  • Sutton & Barto (1998) Sutton, R. S. and Barto, A. G. Reinforcement Learning: An Introduction. The MIT Press, Cambridge, MA, 1998.
  • Tieleman & Hinton (2012) Tieleman, T. and Hinton, G. Lecture 6.5-RMSPROP: Divide the gradient by a running average of its recent magnitude. Coursera: Neural Networks for Machine Learning, 2012.
  • Titsias (2016) Titsias, M. K. One-vs-each approximation to softmax for scalable estimation of probabilities. In Advances in Neural Information Processing Systems, 2016.
  • Titsias & Lázaro-Gredilla (2014) Titsias, M. K. and Lázaro-Gredilla, M. Doubly stochastic variational Bayes for non-conjugate inference. In International Conference on Machine Learning, 2014.
  • Tsoumakas et al. (2008) Tsoumakas, G., Katakis, I., and Vlahavas, I. Effective and efficient multilabel classification in domains with large number of labels. In ECML/PKDD Workshop on Mining Multidimensional Data, 2008.
  • Vijayanarasimhan et al. (2014) Vijayanarasimhan, S., Shlens, J., Monga, R., and Yagnik, J. Deep networks with large output spaces. arXiv:1412.7479, 2014.
  • Vincent et al. (2015) Vincent, P., de Brébisson, A., and Bouthillier, X. Efficient exact gradient update for training deep networks with very large sparse targets. In Advances in Neural Information Processing Systems, 2015.
  • Williams (1992) Williams, R. J. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning, 8(3–4):229–256, 1992.