Gradient Estimation with Stochastic Softmax Tricks

by Max B. Paulus, et al.
ETH Zurich

The Gumbel-Max trick is the basis of many relaxed gradient estimators. These estimators are easy to implement and have low variance, but the goal of scaling them comprehensively to large combinatorial distributions is still outstanding. Working within the perturbation model framework, we introduce stochastic softmax tricks, which generalize the Gumbel-Softmax trick to combinatorial spaces. Our framework is a unified perspective on existing relaxed estimators for perturbation models, and it contains many novel relaxations. We design structured relaxations for subset selection, spanning trees, arborescences, and others. When compared to less structured baselines, we find that stochastic softmax tricks can be used to train latent variable models that perform better and discover more latent structure.




1 Introduction

Gradient computation is the methodological backbone of deep learning, but computing gradients is not always easy. Gradients with respect to parameters of the density of an integral are generally intractable, and one must resort to gradient estimators (asmussen2007stochastic, ; mohamed2019gradientest, ). Typical examples of objectives over densities are returns in reinforcement learning (sutton2018reinforcement, ) or variational objectives for latent variable models (e.g., kingma2014auto, ; rezende2014stochastic, ). In this paper, we address gradient estimation for discrete distributions with an emphasis on latent variable models. We introduce a relaxed gradient estimation framework for combinatorial discrete distributions that generalizes the Gumbel-Softmax and related estimators (maddison2016concrete, ; jang2016categorical, ).

Relaxed gradient estimators incorporate bias in order to reduce variance. Most relaxed estimators are based on the Gumbel-Max trick (luce1959individual, ; maddison2014astarsamp, ), which reparameterizes distributions over one-hot binary vectors. The Gumbel-Softmax estimator is the simplest; it continuously approximates the Gumbel-Max trick to admit a reparameterization gradient (kingma2014auto, ; rezende2014stochastic, ; ruiz2016generalized, ). This is used to optimize the "soft" approximation of the loss as a surrogate for the "hard" discrete objective.

Adding structured latent variables to deep learning models is a promising direction for addressing a number of challenges: improving interpretability (e.g., via latent variables for subset selection (chen2018learning, ) or parse trees corro2018differentiable ), incorporating problem-specific constraints (e.g., via enforcing alignments mena2018learning ), and improving generalization (e.g., by modeling known algorithmic structure graves2014neural ). Unfortunately, the vanilla Gumbel-Softmax cannot scale to distributions over large state spaces, and the development of structured relaxations has been piecemeal.

We introduce stochastic softmax tricks (SSTs), which are a unified framework for designing structured relaxations of combinatorial distributions. They include relaxations for the above applications, as well as many novel ones. To use an SST, a modeler chooses from a class of models that we call stochastic argmax tricks (SMTs). These are instances of perturbation models (e.g., papandreou2011perturb, ; hazan2012partition, ; tarlow2012randoms, ; gane2014learning, ), and they induce a distribution over a finite set $\mathcal{X}$ by optimizing a linear objective (defined by a random utility $U$) over $\mathcal{X}$. An SST relaxes this SMT by combining a strongly convex regularizer with the random linear objective. The regularizer makes the solution a continuous, a.e. differentiable function of $U$ and appropriate for estimating gradients with respect to the parameters of $U$. The Gumbel-Softmax is a special case. Fig. 1 provides a summary.

We test our relaxations in the Neural Relational Inference (NRI) (kipf2018neural, ) and L2X chen2018learning frameworks. Both NRI and L2X use variational losses over latent combinatorial distributions. When the latent structure in the model matches the true latent structure, we find that our relaxations encourage the unsupervised discovery of this combinatorial structure. This leads to models that are more interpretable and achieve stronger performance than less structured baselines. All proofs are in the Appendix.

[Figure 1 panels: Finite set $\mathcal{X}$; Random utility $U$; Stoch. Argmax Trick $X$; Stoch. Softmax Trick $X_t$]

Figure 1: Stochastic softmax tricks relax discrete distributions that can be reparameterized as random linear programs. $X$ is the solution of a random linear program defined by a finite set $\mathcal{X}$ and a random utility $U$ with parameters $\theta$. To design relaxed gradient estimators with respect to $\theta$, $X_t$ is the solution of a random convex program that continuously approximates $X$ from within the convex hull of $\mathcal{X}$. The Gumbel-Softmax (maddison2016concrete, ; jang2016categorical, ) is an example of a stochastic softmax trick.

2 Problem Statement

Let $\mathcal{Y}$ be a non-empty, finite set of combinatorial objects, e.g. the spanning trees of a graph. To represent $\mathcal{Y}$, define the embeddings $\mathcal{X} \subseteq \mathbb{R}^n$ of $\mathcal{Y}$ to be the image $\mathcal{X} = \{\mathrm{rep}(y) : y \in \mathcal{Y}\}$ of some embedding function $\mathrm{rep}: \mathcal{Y} \to \mathbb{R}^n$.¹ For example, if $\mathcal{Y}$ is the set of spanning trees of a graph with edges $E$, then we could enumerate $y_1, \ldots, y_{|\mathcal{Y}|}$ in $\mathcal{Y}$ and let $\mathrm{rep}(y)$ be the one-hot binary vector of length $|\mathcal{Y}|$, with $\mathrm{rep}(y)_i = 1$ iff $y = y_i$. This requires a very large ambient dimension $n = |\mathcal{Y}|$. Alternatively, in this case we could use a more efficient, structured representation: $\mathrm{rep}(y)$ could be a binary indicator vector of length $|E|$, with $\mathrm{rep}(y)_e = 1$ iff edge $e$ is in the tree $y$. We assume that $\mathcal{X}$ is convex independent.²

¹ This is equivalent to the notion of sufficient statistics (wainwright2008graphical, ). We draw a distinction only to avoid confusion, because the distributions that we ultimately consider are not necessarily from the exponential family.
² Convex independence is the analog of linear independence for convex combinations.
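A minimal sketch of the two embeddings for the spanning trees of the complete graph on four vertices, in plain Python. The helper name `spanning_trees` is hypothetical, and brute-force enumeration is only feasible for tiny graphs; it is used here purely to make $|\mathcal{Y}|$ and $|E|$ concrete.

```python
import itertools

def spanning_trees(num_vertices, edges):
    """Enumerate spanning trees by brute force (only viable for tiny graphs)."""
    trees = []
    for subset in itertools.combinations(edges, num_vertices - 1):
        parent = list(range(num_vertices))   # union-find forest

        def find(v):
            while parent[v] != v:
                v = parent[v]
            return v

        acyclic = True
        for a, b in subset:
            ra, rb = find(a), find(b)
            if ra == rb:          # this edge would close a cycle
                acyclic = False
                break
            parent[ra] = rb
        if acyclic:               # n - 1 acyclic edges on n vertices form a spanning tree
            trees.append(subset)
    return trees

num_vertices = 4
edges = list(itertools.combinations(range(num_vertices), 2))   # K4 has 6 edges
trees = spanning_trees(num_vertices, edges)

# One-hot embedding: one coordinate per tree, so n = |Y|.
one_hot = [[1 if j == i else 0 for j in range(len(trees))] for i in range(len(trees))]
# Structured embedding: one coordinate per edge, so n = |E|.
edge_indicator = [[1 if e in tree else 0 for e in edges] for tree in trees]

print(len(trees), len(one_hot[0]), len(edge_indicator[0]))   # 16 16 6
```

Cayley's formula gives $4^{4-2} = 16$ spanning trees for $K_4$, so the one-hot embedding already needs 16 dimensions while the edge indicators need 6. The gap grows super-exponentially with the number of vertices.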

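Given a probability mass function $p_\theta$ on $\mathcal{X}$ that is differentiable in $\theta$, a loss function $\mathcal{L}: \mathbb{R}^n \to \mathbb{R}$, and $X \sim p_\theta$, our goal is to estimate

$$\frac{d}{d\theta}\mathbb{E}[\mathcal{L}(X)] = \frac{d}{d\theta}\sum_{x \in \mathcal{X}} \mathcal{L}(x)\, p_\theta(x). \qquad (1)$$

We are interested in this because our ultimate goal is gradient-based optimization of $\mathbb{E}[\mathcal{L}(X)]$.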
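For intuition, when $\mathcal{X}$ is small the target (1) can be computed exactly by enumeration. The sketch below assumes one-hot $\mathcal{X}$ with $p_\theta(x) \propto \exp(x^\top\theta)$ and a hypothetical table of loss values. For this family the derivative of the sum in (1) has the closed form $p_j(\mathcal{L}(x_j) - \mathbb{E}[\mathcal{L}(X)])$, which is checked here against central finite differences.

```python
import numpy as np

def softmax(theta):
    z = np.exp(theta - theta.max())
    return z / z.sum()

def expected_loss(theta, losses):
    # E[L(X)] = sum_x L(x) p_theta(x), with X one-hot over len(theta) states.
    return softmax(theta) @ losses

def exact_gradient(theta, losses):
    # d/dtheta_j sum_i L_i p_i = p_j (L_j - E[L]) for p = softmax(theta).
    p = softmax(theta)
    return p * (losses - p @ losses)

theta = np.array([0.5, -1.0, 2.0, 0.0])
losses = np.array([1.0, 3.0, -2.0, 0.5])   # hypothetical L(x) for each one-hot x

# Central finite differences as a sanity check of the closed form.
eps = 1e-6
fd = np.array([
    (expected_loss(theta + eps * e, losses) - expected_loss(theta - eps * e, losses)) / (2 * eps)
    for e in np.eye(theta.size)
])
print(np.allclose(fd, exact_gradient(theta, losses), atol=1e-6))
```

Enumeration needs one evaluation of $\mathcal{L}$ per element of $\mathcal{X}$, which is exactly what becomes infeasible for large combinatorial spaces.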

3 Background on Gradient Estimation

Relaxed gradient estimators assume that $\mathcal{L}$ is differentiable and use a change of variables to remove the dependence of $p_\theta$ on $\theta$, known as the reparameterization trick (kingma2014auto, ; rezende2014stochastic, ). The Gumbel-Softmax trick (GST) (maddison2016concrete, ; jang2016categorical, ) is a simple relaxed gradient estimator for one-hot embeddings, which is based on the Gumbel-Max trick (GMT) (luce1959individual, ; maddison2014astarsamp, ). Let $\mathcal{X}$ be the one-hot embeddings of $\mathcal{Y}$ and $p_\theta(x) \propto \exp(x^\top\theta)$. The GMT is the following identity: for $X \sim p_\theta$ and independent $G_i \sim \mathrm{Gumbel}(0, 1)$,

$$X \overset{d}{=} \arg\max_{x \in \mathcal{X}} (G + \theta)^\top x. \qquad (2)$$
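The identity (2) can be checked by simulation. A minimal NumPy sketch, with an arbitrary parameter vector and sample size chosen for illustration: the empirical frequencies of $\arg\max_i (G_i + \theta_i)$ should match $\mathrm{softmax}(\theta)$ up to Monte Carlo error.

```python
import numpy as np

rng = np.random.default_rng(0)
theta = np.array([1.0, 0.0, -0.5, 2.0])
num_samples = 200_000

# G_i ~ Gumbel(0, 1) i.i.d., so G + theta ~ Gumbel(theta) coordinate-wise.
gumbels = rng.gumbel(size=(num_samples, theta.size))
samples = np.argmax(gumbels + theta, axis=1)   # index of the one-hot sample X

empirical = np.bincount(samples, minlength=theta.size) / num_samples
target = np.exp(theta) / np.exp(theta).sum()   # p_theta for one-hot embeddings
print(np.round(empirical, 3), np.round(target, 3))
```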


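Ideally, one would have a reparameterization estimator, $\mathbb{E}[d\mathcal{L}(X)/d\theta] = d\mathbb{E}[\mathcal{L}(X)]/d\theta$,³ using the right-hand expression in (2). Unfortunately, this fails. The problem is not the lack of differentiability, as normally reported. In fact, the argmax is differentiable almost everywhere. Instead, it is the jump discontinuities in the argmax that invalidate this particular exchange of expectation and differentiation (lee2018reparameterization, ; asmussen2007stochastic, , Chap. 7.2). The GST estimator (maddison2016concrete, ; jang2016categorical, ) overcomes this by using the tempered softmax, $\mathrm{softmax}_t(u)_i = \exp(u_i/t) / \sum_{j=1}^n \exp(u_j/t)$ for $u \in \mathbb{R}^n$ and $t > 0$, to continuously approximate $X$,

$$X_t = \mathrm{softmax}_t(G + \theta). \qquad (3)$$

The relaxed estimator is $d\mathcal{L}(X_t)/d\theta$. While this is a biased estimator of (1), it is an unbiased estimator of $d\mathbb{E}[\mathcal{L}(X_t)]/d\theta$, and $X_t \to X$ a.s. as $t \to 0$. Thus, $d\mathcal{L}(X_t)/d\theta$ is used for optimizing $\mathbb{E}[\mathcal{L}(X_t)]$ as a surrogate for $\mathbb{E}[\mathcal{L}(X)]$, on which the final model is evaluated.

³ For a function $f(x_1, x_2)$, $\partial_1 f(z_1, z_2)$ is the partial derivative (e.g., a gradient vector) of $f$ in the first variable evaluated at $(z_1, z_2)$, and $df(z_1, z_2)/dx_1$ is the total derivative of $f$ in $x_1$ evaluated at $(z_1, z_2)$. For example, if $x = f(\theta)$, then $dg(x, \theta)/d\theta = \partial_1 g(x, \theta)\, df(\theta)/d\theta + \partial_2 g(x, \theta)$.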
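A sketch of (3) in NumPy: a single Gumbel draw is shared between the hard sample $X$ and its relaxation $X_t$, and the distance between them shrinks as $t$ decreases. The function name `tempered_softmax` and the chosen temperatures are illustrative. In practice this would be written in an autodiff framework so that $d\mathcal{L}(X_t)/d\theta$ is obtained by backpropagation through the softmax.

```python
import numpy as np

def tempered_softmax(u, t):
    # softmax_t(u)_i = exp(u_i / t) / sum_j exp(u_j / t); subtracting max(u) avoids overflow.
    w = np.exp((u - u.max()) / t)
    return w / w.sum()

rng = np.random.default_rng(1)
theta = np.array([1.0, 0.0, -0.5, 2.0])
u = theta + rng.gumbel(size=theta.size)      # one draw of the utility G + theta

x_hard = np.eye(theta.size)[np.argmax(u)]    # Gumbel-Max sample X, as in (2)
errors = []
for t in [1.0, 0.1, 0.01]:
    x_soft = tempered_softmax(u, t)          # Gumbel-Softmax relaxation X_t, as in (3)
    errors.append(np.abs(x_soft - x_hard).max())
print(errors)   # shrinks as t decreases
```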

The score function estimator (glynn1990likelihood, ; williams1992simple, ), $\mathcal{L}(X)\, \partial \log p_\theta(X)/\partial\theta$, is the classical alternative. It is a simple, unbiased estimator, but without highly engineered control variates, it suffers from high variance (mnih2014neural, ). Building on the score function estimator are a variety of estimators that require multiple evaluations of $\mathcal{L}$ to reduce variance (DBLP:journals/corr/GuLSM15, ; tucker2017rebar, ; grathwohl2018backpropagation, ; yin2018arm, ; Kool2020Estimating, ; aueb2015local, ). The advantages of relaxed estimators are the following: they only require a single evaluation of $\mathcal{L}$, they are easy to implement using modern software packages (abadi2016tensorflow, ; paszke2017automatic, ; jax2018github, ), and, as reparameterization gradients, they tend to have low variance (gal2016uncertainty, ).
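To illustrate the trade-off, the sketch below implements the score function estimator for the same hypothetical categorical example (one-hot $\mathcal{X}$, $p_\theta(x) \propto \exp(x^\top\theta)$, for which $\partial \log p_\theta(x)/\partial\theta = x - \mathrm{softmax}(\theta)$). Averaged over many samples it approaches the exact gradient, while individual single-sample estimates are widely spread. The loss values and sample size are assumptions for illustration.

```python
import numpy as np

def softmax(theta):
    z = np.exp(theta - theta.max())
    return z / z.sum()

rng = np.random.default_rng(2)
theta = np.array([0.5, -1.0, 2.0, 0.0])
losses = np.array([1.0, 3.0, -2.0, 0.5])   # hypothetical L(x) for each one-hot x
p = softmax(theta)

# Sample categorical indices; X = one-hot(index).
num_samples = 100_000
idx = rng.choice(theta.size, size=num_samples, p=p)
one_hot = np.eye(theta.size)[idx]

# For p_theta(x) proportional to exp(x^T theta): d log p_theta(x) / d theta = x - p.
score = one_hot - p
estimates = losses[idx][:, None] * score    # L(X) * d log p_theta(X) / d theta

exact = p * (losses - p @ losses)           # closed-form gradient of (1) for this family
print(estimates.mean(axis=0), exact)
print(estimates.std(axis=0))                # spread of single-sample estimates
```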

4 Stochastic Argmax Tricks

Simulating a GST requires enumerating $|\mathcal{Y}|$ random variables, so it cannot scale. We overcome this by identifying generalizations of the GMT that can be relaxed and that scale to large $\mathcal{Y}$ by exploiting structured embeddings $\mathcal{X}$. We call these stochastic argmax tricks (SMTs), because they are perturbation models (tarlow2012randoms, ; gane2014learning, ), which can be relaxed into stochastic softmax tricks (Section 5).

Definition 4 (Stochastic argmax trick). Given a non-empty, convex independent, finite set $\mathcal{X} \subseteq \mathbb{R}^n$ and a random utility $U$ whose distribution is parameterized by $\theta$, a stochastic argmax trick for $X$ is the linear program,

$$X = \arg\max_{x \in \mathcal{X}} U^\top x. \qquad (4)$$

The GMT is recovered with one-hot $\mathcal{X}$ and $U \sim \mathrm{Gumbel}(\theta)$. We assume that (4) is a.s. unique, which is guaranteed if $U$ a.s. never lands in any particular lower dimensional subspace (Prop. A, App. A). Because efficient linear solvers are known for many structured $\mathcal{X}$, SMTs are capable of scaling to very large $\mathcal{Y}$ (schrijver2003combinatorial, ; kolmogorov2006convergent, ; koller2009probabilistic, ). For example, if $\mathcal{X}$ are the edge indicator vectors of spanning trees $\mathcal{Y}$, then (4) is the maximum spanning tree problem, which is solved by Kruskal's algorithm (kruskal1956shortest, ).
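A sketch of the SMT (4) for undirected spanning trees, assuming Gumbel edge utilities $U_e = \theta_e + G_e$ purely for illustration (the text treats the utility distribution as a modeling choice). Kruskal's algorithm solves the linear program by visiting edges in decreasing order of utility and keeping each edge that joins two components. The function name and parameter values are hypothetical.

```python
import itertools
import math
import random

def max_spanning_tree(num_vertices, edges, utilities):
    """Kruskal's algorithm for argmax_{x in X} U^T x over spanning-tree edge indicators."""
    parent = list(range(num_vertices))

    def find(v):
        while parent[v] != v:
            parent[v] = parent[parent[v]]   # path halving
            v = parent[v]
        return v

    tree = []
    # Visit edges in decreasing order of utility.
    for i in sorted(range(len(edges)), key=lambda j: -utilities[j]):
        a, b = edges[i]
        ra, rb = find(a), find(b)
        if ra != rb:                        # keep the edge only if it joins two components
            parent[ra] = rb
            tree.append(edges[i])
    return tree

random.seed(0)
num_vertices = 4
edges = list(itertools.combinations(range(num_vertices), 2))
theta = [0.3, -0.2, 1.0, 0.0, 0.5, -1.0]
# If E ~ Exp(1), then -log(E) ~ Gumbel(0, 1), so each U_e ~ Gumbel(theta_e).
utilities = [th - math.log(random.expovariate(1.0)) for th in theta]

tree = max_spanning_tree(num_vertices, edges, utilities)
x = [1 if e in tree else 0 for e in edges]   # the SMT sample X as an edge indicator vector
print(tree, x)
```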

The role of the SMT in our framework is to reparameterize $p_\theta$ in (1). Ideally, given $p_\theta$, there would be an efficient method for simulating some $U$ such that the marginal of $X$ in (4) is $p_\theta$. The GMT shows that this is possible for one-hot $\mathcal{X}$, but the situation is not so simple for structured $\mathcal{X}$. Characterizing the marginal of $X$ in general is difficult (tarlow2012randoms ; hazan2013perturb ), but $U$ that are efficient to sample from typically induce conditional independencies in $p_\theta$ (gane2014learning, ). Therefore, we are not able to reparameterize an arbitrary $p_\theta$ on structured $\mathcal{X}$. Instead, for structured $\mathcal{X}$ we assume that $p_\theta$ is reparameterized by (4), and treat $U$ as a modeling choice. Thus, we caution against the standard approach of taking Gumbel or Gaussian utilities without further analysis. Practically, in experiments we show that the difference in noise distribution can have a large impact on quantitative results. Theoretically, we show in App. B that an SMT over directed spanning trees with negative exponential utilities has a more interpretable structure than the same SMT with Gumbel utilities.

5 Stochastic Softmax Tricks

If we assume that $p_\theta$ is reparameterized as an SMT, then a stochastic softmax trick (SST) is a random convex program with a solution that relaxes $X$. An SST has a valid reparameterization gradient estimator. Thus, we propose using SSTs as surrogates for estimating gradients of (1), a generalization of the Gumbel-Softmax approach. Because we want gradients with respect to $\theta$, we assume that $U$ is also reparameterizable.

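Given an SMT, an SST adds a strongly convex regularizer to the linear objective, and expands the state space to the convex hull of the embeddings $\mathcal{X}$,

$$P := \mathrm{conv}(\mathcal{X}) = \Big\{\sum_{x \in \mathcal{X}} \lambda_x x \;:\; \lambda_x \ge 0,\ \sum_{x \in \mathcal{X}} \lambda_x = 1\Big\}. \qquad (5)$$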


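Expanding the state space to a convex polytope makes it path-connected, and the strongly convex regularizer ensures that the solutions are continuous over the polytope.

Definition 5 (Stochastic softmax trick). Given a stochastic argmax trick $(\mathcal{X}, U)$ where $P := \mathrm{conv}(\mathcal{X})$ and a proper, closed, strongly convex function $f: \mathbb{R}^n \to \mathbb{R} \cup \{\infty\}$ whose domain contains the relative interior of $P$, a stochastic softmax trick for $X$ at temperature $t > 0$ is the convex program,

$$X_t = \arg\max_{x \in P}\; U^\top x - t f(x). \qquad (6)$$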


For one-hot $\mathcal{X}$, the Gumbel-Softmax is a special case of an SST where $P$ is the probability simplex, $U \sim \mathrm{Gumbel}(\theta)$, and $f(x) = \sum_i x_i \log(x_i)$. Objectives like (6) have a long history in convex analysis (e.g., rockafellar1970convex, , Chap. 12) and machine learning (e.g., wainwright2008graphical, , Chap. 3). In general, the difficulty of computing the SST will depend on the interaction between $P$ and $f$.
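The Gumbel-Softmax special case can be checked numerically. The sketch below, with an arbitrary $\theta$ and $t$, evaluates the objective of (6) with $f(x) = \sum_i x_i \log x_i$ on the simplex: $\mathrm{softmax}(U/t)$ attains the closed-form optimum $t \log \sum_i \exp(U_i/t)$, and no randomly drawn point of the simplex does better. The helper names are illustrative.

```python
import numpy as np

def neg_entropy(x):
    # f(x) = sum_i x_i log x_i, using the convention 0 log 0 = 0.
    safe = np.where(x > 0, x, 1.0)
    return np.sum(np.where(x > 0, x * np.log(safe), 0.0))

def sst_objective(x, u, t):
    # Objective of the convex program (6): U^T x - t f(x).
    return u @ x - t * neg_entropy(x)

rng = np.random.default_rng(3)
theta = np.array([1.0, 0.0, -0.5, 2.0])
u = theta + rng.gumbel(size=theta.size)    # U ~ Gumbel(theta)
t = 0.5

x_t = np.exp((u - u.max()) / t)
x_t /= x_t.sum()                           # Gumbel-Softmax sample softmax_t(U)

z = u / t
opt_value = t * (z.max() + np.log(np.exp(z - z.max()).sum()))   # t * logsumexp(U / t)

candidates = rng.dirichlet(np.ones(theta.size), size=10_000)    # random points of the simplex
best_candidate = max(sst_objective(c, u, t) for c in candidates)
print(sst_objective(x_t, u, t), opt_value, best_candidate)
```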

$X_t$ is suitable as an approximation of $X$. At positive temperatures $t$, $X_t$ is a function of $U$ that ranges over the faces and relative interior of $P$. The degree of approximation is controlled by the temperature parameter, and as $t \to 0^+$, $X_t$ is driven to $X$ a.s.

Proposition (approximation). If $X$ in Def. 4 is a.s. unique, then for $X_t$ in Def. 5, $\lim_{t \to 0^+} X_t = X$ a.s. If additionally $\mathcal{L}: P \to \mathbb{R}$ is bounded and continuous, then $\lim_{t \to 0^+} \mathbb{E}[\mathcal{L}(X_t)] = \mathbb{E}[\mathcal{L}(X)]$.

It is common to consider temperature parameters that interpolate between marginal inference and a deterministic, most probable state. While superficially similar, our relaxation framework is different; as $t \to 0^+$, an SST approaches a sample from the SMT model as opposed to a deterministic state.

$X_t$ also admits a reparameterization trick. The SST reparameterization gradient estimator is given by

$$\frac{d\mathcal{L}(X_t)}{d\theta} = \frac{\partial \mathcal{L}(X_t)}{\partial X_t}\,\frac{\partial X_t}{\partial U}\,\frac{dU}{d\theta}. \qquad (7)$$

If $\mathcal{L}$ is differentiable on $P$, then this is an unbiased estimator⁴ of the gradient $d\mathbb{E}[\mathcal{L}(X_t)]/d\theta$, because $X_t$ is continuous and a.e. differentiable:

Proposition (relaxation). $X_t$ in Def. 5 exists, is unique, and is a.e. differentiable and continuous in $U$.

In general, the Jacobian $\partial X_t/\partial U$ will need to be derived separately given a choice of $f$ and $P$. However, as pointed out by (domke2010impdiff, ), because the Jacobian of $X_t$ is symmetric (rockafellar1999second, , Cor. 2.9), local finite difference approximations can be used to approximate $d\mathcal{L}(X_t)/dU$ (App. D). These finite difference approximations only require two additional calls to a solver for (6) and do not require additional evaluations of $\mathcal{L}$. We found them to be helpful in a few experiments (cf. Section 8).

⁴ Technically, one needs an additional local Lipschitz condition for $\mathcal{L}(X_t)$ in $\theta$ (asmussen2007stochastic, , Prop. 2.3, Chap. 7).
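A sketch of the finite-difference idea under stated assumptions: the relaxation is the categorical-entropy SST on the simplex (so a solver for (6) is just $\mathrm{softmax}(U/t)$, and its analytic Jacobian is available for comparison), and the loss is a hypothetical squared error. Because the Jacobian is symmetric, $d\mathcal{L}(X_t)/dU = (\partial X_t/\partial U)^\top \partial\mathcal{L}/\partial x$ equals a directional derivative of $X_t$, which a central difference estimates with two extra solver calls. The exact scheme in App. D may differ.

```python
import numpy as np

def solve_sst(u, t):
    # Categorical-entropy SST on the simplex: X_t = softmax(u / t).
    w = np.exp((u - u.max()) / t)
    return w / w.sum()

def loss_grad(x, target):
    # Hypothetical loss L(x) = 0.5 * ||x - target||^2, so dL/dx = x - target.
    return x - target

def fd_grad_wrt_u(u, t, target, eps=1e-5):
    """Approximate dL(X_t)/dU = J^T dL/dx via J = J^T and two extra solver calls."""
    x_t = solve_sst(u, t)
    v = loss_grad(x_t, target)
    return (solve_sst(u + eps * v, t) - solve_sst(u - eps * v, t)) / (2 * eps)

u = np.array([0.3, 1.2, -0.4, 0.8])
t = 0.7
target = np.array([0.0, 1.0, 0.0, 0.0])

x_t = solve_sst(u, t)
jacobian = (np.diag(x_t) - np.outer(x_t, x_t)) / t   # analytic softmax Jacobian dX_t/dU
exact = jacobian.T @ loss_grad(x_t, target)
print(np.allclose(fd_grad_wrt_u(u, t, target), exact, atol=1e-7))
```

No further evaluations of $\mathcal{L}$ are needed beyond its gradient at $X_t$, matching the property described above.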

There are many well-studied $f$ for which (6) is efficiently solvable. If $f(x) = \frac{1}{2}\|x\|^2$, then $X_t$ is the Euclidean projection of $U/t$ onto $P$. Efficient projection algorithms exist for some convex sets (see wolfe1976finding, ; duchi2008efficient, ; liu2009efficient, ; blondel2019structured, , and references therein), and more generic algorithms exist that only call linear solvers as subroutines (niculae2018sparsemap, ). In some of the settings we consider, generic negative-entropy-based relaxations are also applicable. We refer to relaxations with $f(x) = \sum_i x_i \log(x_i)$ as categorical entropy relaxations (e.g., blondel2019structured, ; blondel2020learning, ). We refer to relaxations with $f(x) = \sum_i x_i \log(x_i) + (1 - x_i)\log(1 - x_i)$ as binary entropy relaxations (e.g., amos2019limited, ).
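For $P$ equal to the probability simplex, the Euclidean relaxation reduces to a projection. A minimal sketch using the standard sort-based simplex projection (one of several known algorithms; the function names are illustrative) shows that $X_t$ can lie in the relative interior, on a face, or at a vertex of $P$ depending on $t$.

```python
import numpy as np

def project_simplex(v):
    """Euclidean projection of v onto the probability simplex (sort-based algorithm)."""
    n = v.size
    mu = np.sort(v)[::-1]                  # sort in decreasing order
    cumsum = np.cumsum(mu)
    ks = np.arange(1, n + 1)
    # Largest k such that mu_k - (cumsum_k - 1) / k > 0 determines the threshold.
    k = ks[mu - (cumsum - 1) / ks > 0][-1]
    tau = (cumsum[k - 1] - 1) / k
    return np.maximum(v - tau, 0.0)

def euclidean_sst(u, t):
    # argmax_{x in simplex} u^T x - (t / 2) ||x||^2 = projection of u / t onto the simplex.
    return project_simplex(u / t)

u = np.array([1.0, 0.8, -0.5, 0.2])
for t in [10.0, 1.0, 0.1]:
    print(t, euclidean_sst(u, t))
```

For these utilities the sketch returns a point with all entries positive at $t = 10$, the point $(0.6, 0.4, 0, 0)$ on an edge of the simplex at $t = 1$, and the vertex $(1, 0, 0, 0)$ at $t = 0.1$. Unlike the softmax, the Euclidean relaxation reaches faces of $P$ at positive temperatures.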

Marginal inference in exponential families is a rich source of SST relaxations. Consider an exponential family over the finite set $\mathcal{X}$ with natural parameters $u \in \mathbb{R}^n$ such that the probability of $x \in \mathcal{X}$ is proportional to $\exp(u^\top x)$. The marginals $\mu(u)$ of this family are solutions of a convex program in exactly the form (6) (wainwright2008graphical, ), i.e., there exists $A^*: \mathbb{R}^n \to \mathbb{R} \cup \{\infty\}$ such that,

$$\mu(u) := \sum_{x \in \mathcal{X}} \frac{x \exp(u^\top x)}{\sum_{y \in \mathcal{X}} \exp(u^\top y)} = \arg\max_{x \in P}\; u^\top x - A^*(x). \qquad (8)$$

The definition of $A^*$, which generates $\mu$ in (8), can be found in (wainwright2008graphical, , Thm. 3.4). $A^*$ is a kind of negative entropy, and in our case it satisfies the assumptions in Def. 5. Computing $X_t = \mu(U/t)$ amounts to marginal inference in the exponential family, and efficient algorithms are known in many cases (see wainwright2008graphical, ; koller2009probabilistic, ), including those we consider. We call $f = A^*$ the exponential family entropy relaxation.
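For small $\mathcal{X}$, $\mu(U/t)$ in (8) can be computed by brute-force enumeration, which is handy for testing the efficient marginal inference routines mentioned above. The sketch below does this for $k$-hot vectors; the function name and the utilities are hypothetical, and enumeration stands in for the efficient algorithms, not as a replacement for them.

```python
import itertools
import numpy as np

def k_subset_marginals(u, k):
    """mu(u) = sum_x x exp(u^T x) / Z over k-hot binary vectors x (brute force)."""
    n = u.size
    states = np.array([[1.0 if i in c else 0.0 for i in range(n)]
                       for c in itertools.combinations(range(n), k)])
    scores = states @ u
    weights = np.exp(scores - scores.max())   # shift for numerical stability
    return weights @ states / weights.sum()

u = np.array([0.9, -0.3, 1.5, 0.1, 0.4])
k = 2
for t in [1.0, 0.05]:
    print(t, np.round(k_subset_marginals(u / t, k), 3))   # X_t = mu(U / t)
```

The marginals always sum to $k$, and as $t$ shrinks they concentrate on the top-$k$ indicator of $U$, consistent with the approximation proposition.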

Taken together, the approximation and relaxation propositions above suggest our proposed use for SSTs: optimize $\mathbb{E}[\mathcal{L}(X_t)]$ at a positive temperature, where unbiased gradient estimation is available, but evaluate $\mathbb{E}[\mathcal{L}(X)]$. We find that this works well in practice if the temperature used during optimization is treated as a hyperparameter and selected over a validation set. It is worth emphasizing that the choice of relaxation $f$ is unrelated to the distribution $p_\theta$ of $X$ in the corresponding SMT. $f$ is not only a modeling choice; it is a computational choice that will affect the cost of computing (6) and the quality of the gradient estimator.

6 Examples of Stochastic Softmax Tricks

Figure 2: An example realization of a spanning tree SST for an undirected graph. Middle: Random undirected edge utilities. Left: The random soft spanning tree $X_t$, represented as a weighted adjacency matrix, can be computed via Kirchhoff's matrix-tree theorem. Right: The random spanning tree $X$, represented as an adjacency matrix, can be computed with Kruskal's algorithm.

The Gumbel-Softmax (maddison2016concrete, ; jang2016categorical, ) introduced neither the Gumbel-Max trick nor the softmax. The novelty of this work is neither the perturbation model framework nor the relaxation framework in isolation, but their combined use for gradient estimation. Here we lay out some example SSTs, organized by the set $\mathcal{Y}$ with a choice of embeddings $\mathcal{X}$. Bold italics indicates previously described relaxations, most of which are bespoke and not describable in our framework. Italics indicates our novel SSTs used in our experiments; some of these are also novel perturbation models. A complete discussion is in App. B.

Subset selection. $\mathcal{X}$ is the set of binary vectors indicating membership in the subsets of a finite set $E$. Indep. uses independent per-element utilities $U$ and a binary entropy relaxation. $X$ and $X_t$ are computed with a dimension-wise step function or sigmoid, respectively.
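A sketch of Indep. under an explicit assumption: the utilities are taken to be $U_i = \theta_i + L_i$ with i.i.d. standard logistic noise $L_i$, a common choice that the text above does not fix. Then $X$ is a dimension-wise step function of $U$ with $P(X_i = 1) = \mathrm{sigmoid}(\theta_i)$, and the binary entropy relaxation gives $X_t = \mathrm{sigmoid}(U/t)$ dimension-wise.

```python
import numpy as np

rng = np.random.default_rng(4)
theta = np.array([1.5, -0.5, 0.0, 2.5])
num_samples = 100_000

# Assumption for illustration: U = theta + logistic noise, i.e. U ~ Logistic(theta).
u = theta + rng.logistic(size=(num_samples, theta.size))

x_hard = (u > 0).astype(float)           # X: dimension-wise step function
t = 0.5
x_soft = 1.0 / (1.0 + np.exp(-u / t))    # X_t: dimension-wise sigmoid(U / t)

print(x_hard.mean(axis=0))               # close to sigmoid(theta)
print(1.0 / (1.0 + np.exp(-theta)))
```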

$k$-Subset selection. $\mathcal{X}$ is the set of $k$-hot binary vectors indicating membership in a $k$-subset of a finite set $E$. All of the following SMTs share one utility distribution for $U$. Our SSTs use the following relaxations: Euclidean (amos2017optnet, ), and categorical (martins2017learning, ), binary (amos2019limited, ), and exponential family (swersky2012cardinality, ) entropies. $X$ is computed by sorting $U$ and setting the top $k$ elements to 1 (blondel2019structured, ). Top $k$, prefixed by the name of the relaxation (e.g., Euclid. Top $k$), refers to our SST with that relaxation. L2X (chen2018learning, ) and SoftSub (xie2019reparameterizable, ) are bespoke relaxations.
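A sketch of the hard top-$k$ SMT and of Bin. Ent. Top $k$. For the binary entropy over $\{x \in [0,1]^n : \sum_i x_i = k\}$, stationarity of the Lagrangian gives $x_i = \mathrm{sigmoid}((U_i - \nu)/t)$ for a scalar multiplier $\nu$, which the sketch finds by bisection. This is one way to solve the program, not necessarily the solver used in the experiments, and the function names are hypothetical.

```python
import numpy as np

def top_k_hard(u, k):
    # X: set the k largest utilities to 1.
    x = np.zeros_like(u)
    x[np.argsort(-u)[:k]] = 1.0
    return x

def bin_ent_top_k(u, k, t, iters=100):
    """argmax over {0 <= x <= 1, sum(x) = k} of u^T x - t * sum(x log x + (1 - x) log(1 - x)).

    Stationarity gives x_i = sigmoid((u_i - nu) / t); bisection finds nu with sum(x) = k.
    """
    def sigmoid(z):
        return 0.5 * (1.0 + np.tanh(0.5 * z))   # overflow-free form of 1 / (1 + exp(-z))

    lo, hi = u.min() - 50.0 * t, u.max() + 50.0 * t   # sum(x) is about n at lo, about 0 at hi
    for _ in range(iters):
        nu = 0.5 * (lo + hi)
        if sigmoid((u - nu) / t).sum() > k:
            lo = nu          # too much mass: raise the threshold
        else:
            hi = nu
    return sigmoid((u - 0.5 * (lo + hi)) / t)

u = np.array([0.9, -0.3, 1.5, 0.1, 0.4])
k = 2
print(top_k_hard(u, k))                  # [1. 0. 1. 0. 0.]
for t in [1.0, 0.01]:
    x_t = bin_ent_top_k(u, k, t)
    print(t, np.round(x_t, 3), x_t.sum())
```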

Correlated $k$-subset selection. $\mathcal{X}$ is the set of $(2n-1)$-dimensional binary vectors with a $k$-hot cardinality constraint on the first $n$ dimensions and a constraint that the remaining $n-1$ dimensions indicate correlations between adjacent dimensions in the first $n$, i.e. the vertices of the correlation polytope of a chain (wainwright2008graphical, , Ex. 3.8) with an added cardinality constraint (mezuman2013tighter, ). Corr. Top $k$ pairs a random utility over these $2n-1$ dimensions with the exponential family entropy relaxation. $X$ and $X_t$ can be computed with dynamic programs (tarlow2012fast, ), see App. B.

Perfect Bipartite Matchings. $\mathcal{X}$ is the set of $n \times n$ permutation matrices representing the perfect matchings of the complete bipartite graph $K_{n,n}$. The Gumbel-Sinkhorn (mena2018learning, ) uses Gumbel utilities and a Shannon entropy relaxation. $X$ can be computed with the Hungarian method (kuhn1955hungarian, ) and $X_t$ with the Sinkhorn algorithm (sinkhorn1967concerning, ). Stochastic NeuralSort (grover2018stochastic, ) uses correlated Gumbel-based utilities that induce a Plackett-Luce model and a bespoke relaxation.
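A sketch of the Gumbel-Sinkhorn pair on a small $4 \times 4$ problem, assuming Gumbel utilities: $X_t$ is approximated by log-domain Sinkhorn iterations on $\exp(U/t)$ and, because $n$ is tiny, $X$ is found by brute force over all permutations instead of with the Hungarian method. Iteration counts and names are illustrative; at low temperatures Sinkhorn typically needs more iterations to converge.

```python
import itertools
import numpy as np

def logsumexp(a, axis):
    m = a.max(axis=axis, keepdims=True)
    return m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))

def sinkhorn(u, t, iters=200):
    """Approximate X_t: alternately normalize rows and columns of exp(u / t) in log space."""
    log_alpha = u / t
    for _ in range(iters):
        log_alpha = log_alpha - logsumexp(log_alpha, axis=1)   # rows sum to 1
        log_alpha = log_alpha - logsumexp(log_alpha, axis=0)   # columns sum to 1
    return np.exp(log_alpha)

def best_permutation(u):
    # X by brute force over permutations (the Hungarian method scales far better).
    n = u.shape[0]
    best = max(itertools.permutations(range(n)),
               key=lambda p: sum(u[i, p[i]] for i in range(n)))
    return np.eye(n)[list(best)]       # row i has its 1 in column best[i]

rng = np.random.default_rng(5)
theta = rng.normal(size=(4, 4))
u = theta + rng.gumbel(size=(4, 4))    # U ~ Gumbel(theta), entrywise
x_hard = best_permutation(u)           # the SMT sample X
for t in [5.0, 1.0, 0.1]:
    x_t = sinkhorn(u, t)               # approximately doubly stochastic X_t
    print(t, np.abs(x_t - x_hard).max())
```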

Undirected spanning trees. Given a graph $(V, E)$, $\mathcal{X}$ is the set of binary indicator vectors of the edge sets $T \subseteq E$ of undirected spanning trees. Spanning Tree uses random edge utilities $U$ and the exponential family entropy relaxation. $X$ can be computed with Kruskal's algorithm (kruskal1956shortest, ), $X_t$ with Kirchhoff's matrix-tree theorem (koo2007matrixtree, , Sec. 3.3), and both are represented as adjacency matrices, Fig. 2.
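A sketch of $X_t = \mu(U/t)$ for undirected spanning trees via Kirchhoff's matrix-tree theorem. The partition function is the determinant of the weighted Laplacian $\hat{L}$ with one vertex's row and column removed, so each edge marginal is $\mu_e = w_e\, b_e^\top \hat{L}^{-1} b_e$ with $w_e = \exp(U_e/t)$ and $b_e$ the reduced incidence vector of edge $e$. Dense inversion, the Gumbel utilities, and the brute-force check on $K_4$ are assumptions for a small example, not the implementation used in the experiments.

```python
import itertools
import numpy as np

def spanning_tree_marginals(num_vertices, edges, u, t):
    """Edge marginals of p(T) proportional to exp(sum_{e in T} U_e / t), via Kirchhoff."""
    w = np.exp((u - u.max()) / t)   # a common rescaling of all weights leaves mu unchanged
    lap = np.zeros((num_vertices, num_vertices))
    for (i, j), we in zip(edges, w):
        lap[i, i] += we
        lap[j, j] += we
        lap[i, j] -= we
        lap[j, i] -= we
    inv = np.linalg.inv(lap[1:, 1:])   # reduced Laplacian: drop vertex 0
    mu = np.empty(len(edges))
    for k, (i, j) in enumerate(edges):
        b = np.zeros(num_vertices)
        b[i], b[j] = 1.0, -1.0
        b = b[1:]                      # drop the removed vertex
        mu[k] = w[k] * b @ inv @ b     # d log Z / d log w_e
    return mu

def is_spanning_tree(num_vertices, tree_edges):
    # n - 1 edges form a spanning tree iff their incidence vectors have rank n - 1.
    inc = np.zeros((len(tree_edges), num_vertices))
    for r, (i, j) in enumerate(tree_edges):
        inc[r, i], inc[r, j] = 1.0, -1.0
    return len(tree_edges) == num_vertices - 1 and np.linalg.matrix_rank(inc) == num_vertices - 1

def brute_force_marginals(num_vertices, edges, u, t):
    mu, z = np.zeros(len(edges)), 0.0
    for subset in itertools.combinations(range(len(edges)), num_vertices - 1):
        if is_spanning_tree(num_vertices, [edges[k] for k in subset]):
            weight = np.exp(u[list(subset)].sum() / t)
            z += weight
            mu[list(subset)] += weight
    return mu / z

num_vertices = 4
edges = list(itertools.combinations(range(num_vertices), 2))
rng = np.random.default_rng(6)
u = rng.gumbel(size=len(edges))   # illustrative edge utilities
t = 0.5
mu = spanning_tree_marginals(num_vertices, edges, u, t)
print(np.allclose(mu, brute_force_marginals(num_vertices, edges, u, t)), mu.sum())
```

The marginals sum to $|V| - 1$, the number of edges in every spanning tree, which is a quick sanity check for any implementation.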

Rooted directed spanning trees. Given a graph $(V, E)$ and a root $r \in V$, $\mathcal{X}$ is the set of binary indicator vectors of the edge sets $T \subseteq E$ of $r$-rooted, directed spanning trees. Arborescence uses Gumbel, negative exponential, or Gaussian utilities and an exponential family entropy relaxation. $X$ can be computed with the Chu-Liu-Edmonds algorithm (chu1965shortest, ; edmonds1967optimum, ), $X_t$ with a directed version of Kirchhoff's matrix-tree theorem (koo2007matrixtree, , Sec. 3.3), and both are represented as adjacency matrices. Perturb & Parse (corro2018differentiable, ) further restricts $\mathcal{X}$ to be projective trees, uses Gumbel utilities, and uses a bespoke relaxation.

7 Related Work

Here we review perturbation models (PMs) and methods for relaxation more generally. SSTs are a subclass of PMs, which draw samples by optimizing a random objective. Perhaps the earliest example comes from Thurstonian ranking models thurstone1927law , where a distribution over rankings is formed by sorting a vector of noisy scores. Perturb & MAP models papandreou2011perturb ; hazan2012partition were designed to approximate the Gibbs distribution over a combinatorial output space using low-order, additive Gumbel noise. Randomized Optimum models tarlow2012randoms ; gane2014learning are the most general class, which include non-additive noise distributions and non-linear objectives. Recent work (lorberbom2019direct, ) uses PMs to construct finite difference approximations of the gradient of the expected loss. It requires optimizing a non-linear objective over $\mathcal{X}$, and making this applicable to our settings would require significant innovation.

Using SSTs for gradient estimation requires differentiating through a convex program. This idea is not ours and is enjoying renewed interest in cvxpylayers2019 ; agrawal2019differentiating ; amos2019differentiable . In addition, specialized solutions have been proposed for quadratic programs amos2017optnet ; martins2016softmax ; blondel2020fast and linear programs with entropic regularizers over various domains martins2017learning ; amos2019limited ; adams2011ranking ; mena2018learning ; blondel2020fast . In graphical modeling, several works have explored differentiating through marginal inference domke2010impdiff ; ross-cvpr-11 ; poon2011sum ; domke2013learning ; swersky2012cardinality ; djolonga2017differentiable and our exponential family entropy relaxation builds on this work. The most superficially similar work is (2020arXiv200208676B, ), which uses noisy utilities to smooth the solutions of linear programs. In (2020arXiv200208676B, ), the noise is a tool for approximately relaxing a deterministic linear program. Our framework uses relaxations to approximate stochastic linear programs.

8 Experiments

Our goal in these experiments was to evaluate the use of SSTs for learning distributions over structured latent spaces in deep structured models. We chose frameworks (NRI (kipf2018neural, ), L2X (chen2018learning, ), and a latent parse tree task) in which relaxed gradient estimators are the methods of choice, and investigated the effects of $\mathcal{X}$, $U$, and $f$ on the task objective and on unsupervised structure discovery. For NRI, we also implemented the standard single-loss-evaluation score function estimators (REINFORCE (williams1992simple, ) and NVIL (mnih2014neural, )), but struggled to achieve competitive results, see App. C. All SST models were trained with the "soft" SST and evaluated with the "hard" SMT. We optimized hyperparameters (including the fixed training temperature $t$) using random search over multiple independent runs. We selected models on a validation set according to the best objective value obtained during training. All reported values are measured on a test set. Error bars are bootstrap standard errors over the model selection process. We refer to SSTs defined in Section 6 with italics. Details are in App. D.

8.1 Neural Relational Inference (NRI) for Graph Layout

| Edge Distribution | ELBO | Edge Prec. | Edge Rec. | ELBO | Edge Prec. | Edge Rec. |
|---|---|---|---|---|---|---|
| Indep. Directed Edges (kipf2018neural, ) | | | | | | |
| E.F. Ent. Top $k$ | | | | | | |
| Spanning Tree | | | | | | |

[Figure panels: Ground Truth; Indep. Directed Edges; E.F. Ent. Top $k$; Spanning Tree]

Table 1: Spanning Tree performs best on structure recovery, despite being trained on the ELBO. Test ELBO and structure recovery metrics are shown from models selected on valid. ELBO. Below: Test set example where Spanning Tree recovers the ground truth latent graph perfectly.

With NRI we investigated the use of SSTs for latent structure recovery and final performance. NRI is a graph neural network (GNN) model that samples a latent interaction graph and passes messages over its adjacency matrix to produce a distribution over an interacting particle system. NRI is trained as a variational autoencoder to maximize a lower bound (ELBO) on the marginal log-likelihood of the time series. We experimented with three SSTs for the encoder distribution: Indep. Directed Edges, which is the baseline NRI encoder (kipf2018neural, ), E.F. Ent. Top $k$ over undirected edges, and Spanning Tree over undirected edges. Our dataset consisted of latent spanning trees over 10 vertices sampled from the prior. Given a tree, we embed the vertices by applying iterations of a force-directed algorithm (fruchterman1991graph, ). The model saw particle locations at each iteration, not the underlying spanning tree.

We found that Spanning Tree performed best, improving on both ELBO and the recovery of latent structure over the baseline (kipf2018neural, ). For structure recovery, we measured edge precision and recall against the ground truth adjacency matrix. It recovered the edge structure well even when given only a short series (Fig. 1). Less structured baselines were only competitive on longer time series.
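The exact evaluation code is not given here; a minimal sketch of edge precision and recall against a ground-truth adjacency matrix might look as follows, where `edge_precision_recall`, `adjacency`, and the toy graphs are hypothetical. For undirected graphs each edge is counted once via the upper triangle, and self-loops are ignored.

```python
import numpy as np

def edge_precision_recall(pred_adj, true_adj, directed=False):
    """Precision and recall of predicted edges against a ground-truth adjacency matrix."""
    pred, true = np.asarray(pred_adj, bool), np.asarray(true_adj, bool)
    if directed:
        mask = ~np.eye(pred.shape[0], dtype=bool)               # ignore self-loops
    else:
        mask = np.triu(np.ones(pred.shape, dtype=bool), k=1)   # count each edge once
    tp = np.sum(pred & true & mask)
    precision = tp / max(np.sum(pred & mask), 1)
    recall = tp / max(np.sum(true & mask), 1)
    return precision, recall

def adjacency(num_vertices, edge_list):
    a = np.zeros((num_vertices, num_vertices), dtype=int)
    for i, j in edge_list:
        a[i, j] = a[j, i] = 1
    return a

true_adj = adjacency(4, [(0, 1), (1, 2), (2, 3)])   # ground-truth path 0-1-2-3
pred_adj = adjacency(4, [(0, 1), (1, 2), (0, 3)])   # one wrong edge, one missed edge
print(edge_precision_recall(pred_adj, true_adj))
```

Here two of the three predicted edges are correct and two of the three true edges are recovered, so precision and recall both equal 2/3.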

8.2 Unsupervised Parsing on ListOps

We investigated the effect of the structure of $\mathcal{X}$ and of the utility distribution in a latent parse tree task. We used a simplified variant of the ListOps dataset nangia2018listops , which contains sequences of prefix arithmetic expressions, e.g., max[ 3 min[ 8 2 ]], that evaluate to an integer in $\{0, \ldots, 9\}$. The arithmetic syntax induces a directed spanning tree rooted at its first token with directed edges from operators to operands. We modified the data by removing the summod operator, capping the maximum depth of the ground truth dependency parse, and capping the maximum length of a sequence. This simplifies the task considerably, but it makes the problem accessible to GNN models of fixed depth. Our models used a bi-LSTM encoder to produce a distribution over edges (directed or undirected) between all pairs of tokens, which induced a latent (di)graph. Predictions were made from the final embedding of the first token after passing messages in a GNN architecture over the latent graph. For undirected graphs, messages were passed in both directions. We experimented with the following SSTs for the edge distribution: Indep. Undirected Edges, Spanning Tree, Indep. Directed Edges, and Arborescence (with three separate utility distributions). Arborescence was rooted at the first token. For baselines we used an unstructured LSTM and the GNN over the ground truth parse. All models were trained with cross-entropy to predict the integer evaluation of the sequence.

The best performing models were structured models whose structure better matched the true latent structure (Table 2). For each model, we measured the accuracy of its prediction (task accuracy). We measured both precision and recall with respect to the ground truth parse's adjacency matrix.⁵ Both tree-structured SSTs outperformed their independent edge counterparts on all metrics. Overall, Arborescence achieved the best performance in terms of task accuracy and structure recovery. We found that the utility distribution significantly affected performance (Table 2). For example, while negative exponential utilities induce an interpretable distribution over arborescences (App. B), we found that the multiplicative parameterization of exponentials made it difficult to train competitive models. Despite the LSTM baseline performing well on task accuracy, Arborescence additionally learns to recover much of the latent parse tree.

⁵ We exclude edges to and from the closing symbol "]". Its edge assignments cannot be learned from the task objective, because the correct evaluation of an operation does not depend on the closing symbol.

| Model | Edge Distribution | Task Acc. | Edge Precision | Edge Recall |
|---|---|---|---|---|
| GNN on latent graph | Indep. Undirected Edges | | | |
| | Spanning Tree | | | |
| GNN on latent digraph | Indep. Directed Edges | | | |
| | Arborescence, Neg. Exp. | | | |
| | Arborescence, Gaussian | | | |
| | Arborescence, Gumbel | | | |
| | Ground Truth Edges | | 100 | 100 |

Table 2: Matching ground truth structure (non-tree → tree) improves performance on ListOps. The utility distribution impacts performance. Test task accuracy and structure recovery metrics are shown from models selected on valid. task accuracy.

8.3 Learning To Explain (L2X) Aspect Ratings

With L2X we investigated the effect of the choice of relaxation. We used the BeerAdvocate dataset (mcauley2012learning, ), which contains reviews comprised of free-text feedback and ratings for multiple aspects (appearance, aroma, palate, and taste; Fig. 3). Each sentence in the test set is annotated with the aspects that it describes, allowing us to define structure recovery metrics. We considered the L2X task of learning a distribution over $k$-subsets of words that best explain a given aspect rating.⁶ Our model used word embeddings from (lei2016rationalizing, ) and convolutional neural networks with one (simple) and three (complex) layers to produce a distribution over $k$-hot binary latent masks. Given the latent masks, our model used a convolutional net to make predictions from masked embeddings. We used several values of $k$ and the following SSTs for the subset distribution: {Euclid., Cat. Ent., Bin. Ent., E.F. Ent.} Top $k$ and Corr. Top $k$. For baselines, we used bespoke relaxations designed for this task: L2X (chen2018learning, ) and SoftSub (xie2019reparameterizable, ). We trained separate models for each aspect using mean squared error (MSE).

⁶ While originally proposed for model interpretability, we used the original aspect ratings. This allowed us to use the sentence-level annotations for each aspect to facilitate comparisons between subset distributions.

We found that SSTs improve over bespoke relaxations (Table 3 for the aroma aspect, others in App. C). For unsupervised discovery, we used the sentence-level annotations for each aspect to define ground truth subsets against which the precision of the $k$-subsets was measured. SSTs tended to select subsets with higher precision across different architectures and cardinalities and achieved modest improvements in MSE. We did not find significant differences arising from the choice of regularizer $f$. Overall, the most structured SST, Corr. Top $k$, achieved the lowest MSE, highest precision and improved interpretability: the correlations in the model allowed it to select contiguous words, while subsets from less structured distributions were scattered (Fig. 3).

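| Model | Relaxation | MSE | Subs. Prec. | MSE | Subs. Prec. | MSE | Subs. Prec. |
|---|---|---|---|---|---|---|---|
| Simple | L2X (chen2018learning, ) | | | | | | |
| | SoftSub (xie2019reparameterizable, ) | | | | | | |
| | Euclid. Top $k$ | | | | | | |
| | Cat. Ent. Top $k$ | | | | | | |
| | Bin. Ent. Top $k$ | | | | | | |
| | E.F. Ent. Top $k$ | | | | | | |
| | Corr. Top $k$ | | | | | | |
| Complex | L2X (chen2018learning, ) | | | | | | |
| | SoftSub (xie2019reparameterizable, ) | | | | | | |
| | Euclid. Top $k$ | | | | | | |
| | Cat. Ent. Top $k$ | | | | | | |
| | Bin. Ent. Top $k$ | | | | | | |
| | E.F. Ent. Top $k$ | | | | | | |
| | Corr. Top $k$ | | | | | | |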

Pours a slight tangerine orange and straw yellow. The head is nice and bubbly but fades very quickly with a little lacing. Smells like Wheat and European hops, a little yeast in there too. There is some fruit in there too, but you have to take a good whiff to get it. The taste is of wheat, a bit of malt, and a little fruit flavour in there too. Almost feels like drinking Champagne, medium mouthful otherwise. Easy to drink, but not something I’d be trying every night. Appearance: 3.5 Aroma: 4.0 Palate: 4.5 Taste: 4.0 Overall: 4.0

Table 3: For k-subset selection on the aroma aspect, SSTs tend to outperform baseline relaxations. Test-set MSE and subset precision are shown for models selected on validation MSE. Bottom: Corr. Top k (red) selects contiguous words, while Top k (blue) picks scattered words.

9 Conclusion

We introduced stochastic softmax tricks, which are random convex programs that capture a large class of relaxed distributions over structured, combinatorial spaces. We designed stochastic softmax tricks for subset selection and a variety of spanning tree distributions. We tested their use in deep latent variable models and found that they can improve performance and encourage the unsupervised discovery of true latent structure. Several directions remain open. The relaxation framework could be generalized by modifying the constraint set or the utility distribution at positive temperatures. Some combinatorial objects might benefit from a more careful design of the utility distribution, while others, e.g., matchings, still await tricks of their own.

Broader Impact

This work introduces methods and theory that have the potential to improve the interpretability of latent variable models. While unfavorable consequences cannot be excluded, increased interpretability is generally considered a desirable property of machine learning models. Given that this is foundational, methodologically driven research, we refrain from speculating further.

Acknowledgments

We thank Daniel Johnson and Francisco Ruiz for their time and insightful feedback. We also thank Tamir Hazan, Yoon Kim, Andriy Mnih, and Rich Zemel for their valuable comments. MBP gratefully acknowledges support from the Max Planck ETH Center for Learning Systems. CJM is grateful for the support of the James D. Wolfensohn Fund at the Institute for Advanced Study in Princeton, NJ.

References

  • (1) Martín Abadi, Ashish Agarwal, Paul Barham, Eugene Brevdo, Zhifeng Chen, Craig Citro, Greg S. Corrado, Andy Davis, Jeffrey Dean, Matthieu Devin, Sanjay Ghemawat, Ian Goodfellow, Andrew Harp, Geoffrey Irving, Michael Isard, Yangqing Jia, Rafal Jozefowicz, Lukasz Kaiser, Manjunath Kudlur, Josh Levenberg, Dan Mane, Rajat Monga, Sherry Moore, Derek Murray, Chris Olah, Mike Schuster, Jonathon Shlens, Benoit Steiner, Ilya Sutskever, Kunal Talwar, Paul Tucker, Vincent Vanhoucke, Vijay Vasudevan, Fernanda Viegas, Oriol Vinyals, Pete Warden, Martin Wattenberg, Martin Wicke, Yuan Yu, and Xiaoqiang Zheng. TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems. arXiv e-prints, page arXiv:1603.04467, March 2016.
  • (2) Ryan Prescott Adams and Richard S Zemel. Ranking via sinkhorn propagation. arXiv preprint arXiv:1106.1925, 2011.
  • (3) A. Agrawal, B. Amos, S. Barratt, S. Boyd, S. Diamond, and Z. Kolter. Differentiable convex optimization layers. In Advances in Neural Information Processing Systems, 2019.
  • (4) Akshay Agrawal, Shane Barratt, Stephen Boyd, Enzo Busseti, and Walaa M Moursi. Differentiating through a conic program. arXiv preprint arXiv:1904.09043, 2019.
  • (5) Brandon Amos. Differentiable optimization-based modeling for machine learning. PhD thesis, Carnegie Mellon University, 2019.
  • (6) Brandon Amos and J Zico Kolter. Optnet: Differentiable optimization as a layer in neural networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 136–145. JMLR.org, 2017.
  • (7) Brandon Amos, Vladlen Koltun, and J. Zico Kolter. The Limited Multi-Label Projection Layer. arXiv e-prints, page arXiv:1906.08707, June 2019.
  • (8) Søren Asmussen and Peter W Glynn. Stochastic simulation: algorithms and analysis, volume 57. Springer Science & Business Media, 2007.
  • (9) Michalis Titsias and Miguel Lázaro-Gredilla. Local expectation gradients for black box variational inference. In Advances in neural information processing systems, pages 2638–2646, 2015.
  • (10) Amir Beck. First-Order Methods in Optimization. SIAM, 2017.
  • (11) Quentin Berthet, Mathieu Blondel, Olivier Teboul, Marco Cuturi, Jean-Philippe Vert, and Francis Bach. Learning with Differentiable Perturbed Optimizers. arXiv e-prints, page arXiv:2002.08676, February 2020.
  • (12) Dimitris Bertsimas and John N Tsitsiklis. Introduction to linear optimization, volume 6. Athena Scientific Belmont, MA, 1997.
  • (13) Mathieu Blondel. Structured prediction with projection oracles. In Advances in Neural Information Processing Systems, pages 12145–12156, 2019.
  • (14) Mathieu Blondel, André FT Martins, and Vlad Niculae. Learning with fenchel-young losses. Journal of Machine Learning Research, 21(35):1–69, 2020.
  • (15) Mathieu Blondel, Olivier Teboul, Quentin Berthet, and Josip Djolonga. Fast differentiable sorting and ranking. arXiv preprint arXiv:2002.08871, 2020.
  • (16) James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, and Skye Wanderman-Milne. JAX: composable transformations of Python+NumPy programs, 2018.
  • (17) Jianbo Chen, Le Song, Martin Wainwright, and Michael Jordan. Learning to explain: An information-theoretic perspective on model interpretation. In International Conference on Machine Learning, 2018.
  • (18) Y.J. Chu and T. H. Liu. On the shortest arborescence of a directed graph. Scientia Sinica, 14:1396–1400, 1965.
  • (19) Caio Corro and Ivan Titov. Differentiable perturb-and-parse: Semi-supervised parsing with a structured variational autoencoder. In International Conference on Learning Representations, 2019.
  • (20) Josip Djolonga and Andreas Krause. Differentiable learning of submodular models. In Advances in Neural Information Processing Systems, pages 1013–1023, 2017.
  • (21) Justin Domke. Implicit differentiation by perturbation. In J. D. Lafferty, C. K. I. Williams, J. Shawe-Taylor, R. S. Zemel, and A. Culotta, editors, Advances in Neural Information Processing Systems 23, pages 523–531. Curran Associates, Inc., 2010.
  • (22) Justin Domke. Learning graphical model parameters with approximate marginal inference. IEEE transactions on pattern analysis and machine intelligence, 35(10):2454–2467, 2013.
  • (23) John Duchi, Shai Shalev-Shwartz, Yoram Singer, and Tushar Chandra. Efficient projections onto the ℓ1-ball for learning in high dimensions. In Proceedings of the 25th international conference on Machine learning, pages 272–279, 2008.
  • (24) Jack Edmonds. Optimum branchings. Journal of Research of the National Bureau of Standards: Mathematics and mathematical physics. B, 71:233, 1967.
  • (25) Thomas MJ Fruchterman and Edward M Reingold. Graph drawing by force-directed placement. Software: Practice and experience, 21(11):1129–1164, 1991.
  • (26) Yarin Gal. Uncertainty in deep learning. PhD thesis, University of Cambridge, 2016.
  • (27) Andreea Gane, Tamir Hazan, and Tommi Jaakkola. Learning with maximum a-posteriori perturbation models. In Artificial Intelligence and Statistics, pages 247–256, 2014.
  • (28) Peter W Glynn. Likelihood ratio gradient estimation for stochastic systems. Communications of the ACM, 33(10):75–84, 1990.
  • (29) Will Grathwohl, Dami Choi, Yuhuai Wu, Geoff Roeder, and David Duvenaud. Backpropagation through the void: Optimizing control variates for black-box gradient estimation. In International Conference on Learning Representations, 2018.
  • (30) Alex Graves, Greg Wayne, and Ivo Danihelka. Neural turing machines. arXiv preprint arXiv:1410.5401, 2014.
  • (31) Aditya Grover, Eric Wang, Aaron Zweig, and Stefano Ermon. Stochastic optimization of sorting networks via continuous relaxations. In International Conference on Learning Representations, 2019.
  • (32) Shixiang Gu, Sergey Levine, Ilya Sutskever, and Andriy Mnih. Muprop: Unbiased backpropagation for stochastic neural networks. In ICLR, 2016.
  • (33) Tamir Hazan and Tommi Jaakkola. On the partition function and random maximum a-posteriori perturbations. In International Conference on Machine Learning, 2012.
  • (34) Tamir Hazan, Subhransu Maji, and Tommi Jaakkola. On Sampling from the Gibbs Distribution with Random Maximum A-Posteriori Perturbations. In Advances in Neural Information Processing Systems, 2013.
  • (35) Eric Jang, Shixiang Gu, and Ben Poole. Categorical reparameterization with gumbel-softmax. In International Conference on Learning Representations, 2016.
  • (36) Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.
  • (37) Diederik P Kingma and Max Welling. Auto-encoding variational bayes. In International Conference on Learning Representations, 2014.
  • (38) Thomas Kipf, Ethan Fetaya, Kuan-Chieh Wang, Max Welling, and Richard Zemel. Neural relational inference for interacting systems. In International Conference on Machine Learning, 2018.
  • (39) Jon Kleinberg and Éva Tardos. Algorithm Design. Pearson Education, 2006.
  • (40) Daphne Koller and Nir Friedman. Probabilistic graphical models: principles and techniques. 2009.
  • (41) Vladimir Kolmogorov. Convergent tree-reweighted message passing for energy minimization. IEEE transactions on pattern analysis and machine intelligence, 28(10):1568–1583, 2006.
  • (42) Terry Koo, Amir Globerson, Xavier Carreras, and Michael Collins. Structured prediction models via the matrix-tree theorem. In Proceedings of the 2007 Joint Conference on Empirical Methods in Natural Language Processing and Computational Natural Language Learning (EMNLP-CoNLL), pages 141–150, Prague, Czech Republic, June 2007. Association for Computational Linguistics.
  • (43) Wouter Kool, Herke van Hoof, and Max Welling. Ancestral gumbel-top-k sampling for sampling without replacement. Journal of Machine Learning Research, 21(47):1–36, 2020.
  • (44) Wouter Kool, Herke van Hoof, and Max Welling. Estimating gradients for discrete random variables by sampling without replacement. In International Conference on Learning Representations, 2020.
  • (45) Joseph B Kruskal. On the shortest spanning subtree of a graph and the traveling salesman problem. Proceedings of the American Mathematical society, 7(1):48–50, 1956.
  • (46) Harold W Kuhn. The hungarian method for the assignment problem. Naval research logistics quarterly, 2(1-2):83–97, 1955.
  • (47) Wonyeol Lee, Hangyeol Yu, and Hongseok Yang. Reparameterization gradient for non-differentiable models. In Advances in Neural Information Processing Systems, pages 5553–5563, 2018.
  • (48) Tao Lei, Regina Barzilay, and Tommi Jaakkola. Rationalizing neural predictions. arXiv preprint arXiv:1606.04155, 2016.
  • (49) Jun Liu and Jieping Ye. Efficient euclidean projections in linear time. In Proceedings of the 26th Annual International Conference on Machine Learning, pages 657–664, 2009.
  • (50) Guy Lorberbom, Andreea Gane, Tommi Jaakkola, and Tamir Hazan. Direct optimization through argmax for discrete variational auto-encoder. In Advances in Neural Information Processing Systems, pages 6200–6211, 2019.
  • (51) R Duncan Luce. Individual Choice Behavior: A Theoretical Analysis. New York: Wiley, 1959.
  • (52) Chris J Maddison, Andriy Mnih, and Yee Whye Teh. The concrete distribution: A continuous relaxation of discrete random variables. In International Conference on Learning Representations, 2017.
  • (53) Chris J Maddison, Daniel Tarlow, and Tom Minka. A* Sampling. In Advances in Neural Information Processing Systems, 2014.
  • (54) Andre Martins and Ramon Astudillo. From softmax to sparsemax: A sparse model of attention and multi-label classification. In International Conference on Machine Learning, pages 1614–1623, 2016.
  • (55) André FT Martins and Julia Kreutzer. Learning what’s easy: Fully differentiable neural easy-first taggers. In Proceedings of the 2017 conference on empirical methods in natural language processing, pages 349–362, 2017.
  • (56) Julian McAuley, Jure Leskovec, and Dan Jurafsky. Learning attitudes and attributes from multi-aspect reviews. In 2012 IEEE 12th International Conference on Data Mining, pages 1020–1025. IEEE, 2012.
  • (57) Gonzalo Mena, David Belanger, Scott Linderman, and Jasper Snoek. Learning latent permutations with gumbel-sinkhorn networks. In International Conference on Learning Representations, 2018.
  • (58) Elad Mezuman, Daniel Tarlow, Amir Globerson, and Yair Weiss. Tighter linear program relaxations for high order graphical models. In Proceedings of the Twenty-Ninth Conference on Uncertainty in Artificial Intelligence, pages 421–430, 2013.
  • (59) Andriy Mnih and Karol Gregor. Neural variational inference and learning in belief networks. In International Conference on Machine Learning, 2014.
  • (60) Shakir Mohamed, Mihaela Rosca, Michael Figurnov, and Andriy Mnih. Monte Carlo Gradient Estimation in Machine Learning. arXiv e-prints, page arXiv:1906.10652, June 2019.
  • (61) Nikita Nangia and Samuel R Bowman. Listops: A diagnostic dataset for latent tree learning. arXiv preprint arXiv:1804.06028, 2018.
  • (62) Vlad Niculae, André FT Martins, Mathieu Blondel, and Claire Cardie. Sparsemap: Differentiable sparse structured inference. arXiv preprint arXiv:1802.04223, 2018.
  • (63) G. Papandreou and A. Yuille. Perturb-and-MAP Random Fields: Using Discrete Optimization to Learn and Sample from Energy Models. In International Conference on Computer Vision, 2011.
  • (64) Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. Automatic differentiation in PyTorch.
  • (65) Robin L Plackett. The analysis of permutations. Journal of the Royal Statistical Society: Series C (Applied Statistics), 24(2):193–202, 1975.
  • (66) Hoifung Poon and Pedro Domingos. Sum-product networks: A new deep architecture. In 2011 IEEE International Conference on Computer Vision Workshops (ICCV Workshops), pages 689–690. IEEE, 2011.
  • (67) Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra. Stochastic backpropagation and approximate inference in deep generative models. In International Conference on Machine Learning, 2014.
  • (68) R. Tyrrell Rockafellar. Convex Analysis. Princeton University Press, 1970.
  • (69) R Tyrrell Rockafellar. Second-order convex analysis. J. Nonlinear Convex Anal, 1(1-16):84, 1999.
  • (70) Stephane Ross, Daniel Munoz, Martial Hebert, and J. Andrew Bagnell. Learning message-passing inference machines for structured prediction. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2011.
  • (71) Francisco JR Ruiz, Michalis K Titsias, and David M Blei. The generalized reparameterization gradient. In Advances in Neural Information Processing Systems, 2016.
  • (72) Alexander Schrijver. Combinatorial optimization: polyhedra and efficiency, volume 24. Springer Science & Business Media, 2003.
  • (73) Richard Sinkhorn and Paul Knopp. Concerning nonnegative matrices and doubly stochastic matrices. Pacific Journal of Mathematics, 21(2):343–348, 1967.
  • (74) Richard S Sutton and Andrew G Barto. Reinforcement learning: An introduction. MIT press, 2018.
  • (75) Kevin Swersky, Ilya Sutskever, Daniel Tarlow, Richard S Zemel, Russ R Salakhutdinov, and Ryan P Adams. Cardinality restricted Boltzmann machines. In Advances in neural information processing systems, pages 3293–3301, 2012.
  • (76) Daniel Tarlow, Ryan Adams, and Richard Zemel. Randomized optimum models for structured prediction. In Neil D. Lawrence and Mark Girolami, editors, Proceedings of the Fifteenth International Conference on Artificial Intelligence and Statistics, volume 22 of Proceedings of Machine Learning Research, pages 1221–1229, La Palma, Canary Islands, 21–23 Apr 2012. PMLR.
  • (77) Daniel Tarlow, Kevin Swersky, Richard S Zemel, Ryan P Adams, and Brendan J Frey. Fast exact inference for recursive cardinality models. In 28th Conference on Uncertainty in Artificial Intelligence, UAI 2012, pages 825–834, 2012.
  • (78) Louis L Thurstone. A law of comparative judgment. Psychological review, 34(4):273, 1927.
  • (79) George Tucker, Andriy Mnih, Chris J Maddison, John Lawson, and Jascha Sohl-Dickstein. Rebar: Low-variance, unbiased gradient estimates for discrete latent variable models. In Advances in Neural Information Processing Systems, pages 2627–2636, 2017.
  • (80) William T. Tutte. Graph Theory. Addison‐Wesley, 1984.
  • (81) Martin J Wainwright and Michael I Jordan. Graphical models, exponential families, and variational inference. Foundations and Trends in Machine Learning, 1(1–2):1–305, 2008.
  • (82) Ronald J Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8(3-4):229–256, 1992.
  • (83) Philip Wolfe. Finding the nearest point in a polytope. Mathematical Programming, 11(1):128–149, 1976.
  • (84) Sang Michael Xie and Stefano Ermon. Reparameterizable subset sampling via continuous relaxations. In International Joint Conference on Artificial Intelligence, 2019.
  • (85) Mingzhang Yin and Mingyuan Zhou. ARM: Augment-REINFORCE-merge gradient for stochastic binary networks. In International Conference on Learning Representations, 2019.

Appendix A Proofs for Stochastic Softmax Tricks

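Let $X \subseteq \mathbb{R}^n$ be a finite, non-empty set, $P = \operatorname{conv}(X)$, and $u \in \mathbb{R}^n$. We have,

$$\max_{x \in X} u^\top x = \max_{x \in P} u^\top x, \qquad \arg\max_{x \in P} u^\top x = \operatorname{conv}\Big(\arg\max_{x \in X} u^\top x\Big).$$

If $\arg\max_{x \in X} u^\top x$ has a unique solution $x^*$, then $x^*$ is also the unique solution of $\arg\max_{x \in P} u^\top x$.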
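The following is a quick numerical check of the claim that a linear program over a finite set X and one over its convex hull P = conv(X) have the same optimal value. Random points of P, drawn as Dirichlet-weighted convex combinations of the rows of X, never beat the best vertex. The random data and sample size are arbitrary choices made for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(6, 3))   # a finite set of 6 points in R^3, one per row
u = rng.normal(size=3)        # a utility vector

vertex_scores = X @ u
best_vertex = X[np.argmax(vertex_scores)]
max_over_X = vertex_scores.max()

# Points of P = conv(X): convex weights sampled from a flat Dirichlet distribution.
weights = rng.dirichlet(np.ones(len(X)), size=20000)
hull_points = weights @ X
max_over_samples = (hull_points @ u).max()
```

Every sampled value is a weighted average of vertex scores, so it cannot exceed the largest vertex score. The best vertex is itself a point of P, so it attains the common maximum.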


Assume w.l.o.g. that . Let .

First, let us consider the linear program over vs. . Clearly, . In the other direction, for any , we can write for such that , and


Hence . Thus .

Second, let us consider the linear program over vs. . The cases or are trivial, so assume otherwise. Since for , it suffices to show that for all there exists such that . To that end, take and , and define


by [68, Thm 6.1]. Thus, we get


Finally, suppose that is unique, but contains more than just . We will show this implies a contradiction. Let be the index such that . Let be such that . Then we may write for such that . But this leads to a contradiction,


Let be a non-empty convex polytope and an extreme point of . Define the set,


This is the set of utility vectors of a linear program over whose argmax is the minimal face . Then, for all , there exists an open set containing .


Let . Let be the set of extreme points (there are finitely many), and assume w.l.o.g. that . For each there exists such that . Thus, for all in the open ball of radius centered at , we have


Define . Note, for all . Now, let . Because is the convex hull of the [12, Thm. 2.9], we must have


for , with at least one for . Thus, for all


This implies that , which concludes the proof, as is open, convex, and contains . ∎
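The openness argument can be made concrete for a finite set of points. Suppose x* is the unique maximizer of u^T x over X, with margin δ over the runner-up, and every point of X has norm at most M. Then ||x* - x|| ≤ 2M, so any perturbation v with ||v|| < δ / (2M) satisfies (u + v)^T (x* - x) ≥ δ - 2M ||v|| > 0, and x* stays the unique maximizer. This radius is our own simple choice for the sketch, not necessarily the one used in the proof.

```python
import numpy as np


def stability_radius(X, u):
    """Return the index of the unique maximizer of u^T x over the rows of X and a
    radius r such that every u + v with ||v|| < r has the same unique maximizer."""
    scores = X @ u
    order = np.argsort(-scores)
    margin = scores[order[0]] - scores[order[1]]   # > 0 iff the maximizer is unique
    max_norm = np.linalg.norm(X, axis=1).max()     # M, assumed > 0
    return order[0], margin / (2.0 * max_norm)


rng = np.random.default_rng(2)
X = rng.normal(size=(8, 4))
u = rng.normal(size=4)
i_star, r = stability_radius(X, u)

# Perturb u inside the open ball of radius r and record the new maximizers.
directions = rng.normal(size=(5000, 4))
directions /= np.linalg.norm(directions, axis=1, keepdims=True)
radii = 0.999 * r * rng.uniform(size=(5000, 1))
perturbed_argmax = np.argmax((u + radii * directions) @ X.T, axis=1)
```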

Given a non-empty, finite set and a proper, closed, strongly convex function whose domain contains the relative interior of , let , and be the indicator function of the polytope ,


For , define


The following are true for ,

  1. (20) has a unique solution, is continuously differentiable, twice differentiable a.e., and

  2. If has a unique solution, then


Note, .

  1. Since is strongly convex [10, Lem. 5.20], (20) has a unique maximum [10, Thm. 5.25]. Moreover, is differentiable everywhere in and its gradient is Lipschitz continuous [10, Thm. 5.26]. By [68, Thm 25.5] is a continuous function on . By Rademacher’s theorem, is a.e. differentiable. (21) follows by standard properties of the convex conjugate [68, Thm. 23.5, Thm. 25.1].

  2. First, by Lemma A,


    Since is such that is uniquely maximized over , is differentiable at by [68, Thm. 23.5, Thm. 25.1]. Again by Lemma A we have


    Hence, our aim is to show . This is equivalent to showing that for any such that . Let be such a sequence.

    We will first show that . For any ,


    Since for all , we also have

    Thus .

    By Lemma A, there exists an open convex set containing such that for all , . Again, is differentiable on [68, Thm. 23.5, Thm. 25.1]. Using this and the fact that , we get [68, Thm. 25.7].
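For intuition about the regularized program (20), consider one special case, which we assume here. Let X be the set of one-hot vectors, so P is the probability simplex, and let the regularizer be the negative entropy f(x) = Σ_i x_i log x_i, which is strongly convex on the simplex. At temperature t, the optimal value is then t · logsumexp(u / t) and the unique maximizer is softmax(u / t); with Gumbel-perturbed u this is the Gumbel-Softmax relaxation. The sketch checks three things numerically: the maximizer equals the gradient of the optimal value, in the spirit of the convex-conjugate argument above; the maximizer attains the optimal value; and no random point of the simplex does better.

```python
import numpy as np


def softmax(z):
    """Numerically stable softmax of a 1-D array."""
    e = np.exp(z - z.max())
    return e / e.sum()


def optimal_value(u, t):
    """t * logsumexp(u / t), the value of max_{x in simplex} u^T x - t * sum_i x_i log x_i."""
    z = u / t
    return t * (z.max() + np.log(np.exp(z - z.max()).sum()))


def objective(x, u, t):
    """u^T x - t * sum_i x_i log x_i, with the convention 0 log 0 = 0."""
    safe = np.where(x > 0, x, 1.0)  # log(1) = 0 for zero entries
    return u @ x - t * np.sum(x * np.log(safe))


rng = np.random.default_rng(3)
u = rng.normal(size=5)
t = 0.5
x_t = softmax(u / t)

# Central finite differences of the optimal value with respect to u.
eps = 1e-6
grad = np.array([
    (optimal_value(u + eps * e, t) - optimal_value(u - eps * e, t)) / (2 * eps)
    for e in np.eye(len(u))
])
random_points = rng.dirichlet(np.ones(len(u)), size=5000)
random_values = np.array([objective(x, u, t) for x in random_points])
```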



For defined in (20), we have by Lemma A,


If is a.s. unique, then again by Lemma A

The last part of the proof follows from the dominated convergence theorem, since the loss is bounded on by assumption, so is surely bounded. ∎
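A Monte Carlo experiment in the categorical case illustrates the limiting statement. As before, it assumes Gumbel utilities and the softmax relaxation over the simplex. The loss is bounded: the squared distance to a fixed one-hot target, an arbitrary choice. Using the same noise for both, the average gap between the relaxed loss L(X_t) and the discrete loss L(X) shrinks as the temperature t decreases.

```python
import numpy as np


def relaxed_loss_gap(theta, target, t, n=20000, seed=0):
    """Monte Carlo estimate of E|L(X_t) - L(X)| for L(x) = ||x - target||^2,
    where X = one_hot(argmax(theta + G)) and X_t = softmax((theta + G) / t)."""
    rng = np.random.default_rng(seed)
    v = rng.uniform(low=np.finfo(float).tiny, high=1.0, size=(n, len(theta)))
    u = theta - np.log(-np.log(v))                   # Gumbel-perturbed utilities
    hard = np.eye(len(theta))[np.argmax(u, axis=1)]  # stochastic argmax trick
    z = u / t
    soft = np.exp(z - z.max(axis=1, keepdims=True))
    soft /= soft.sum(axis=1, keepdims=True)          # stochastic softmax trick

    def loss(x):
        # Bounded on the simplex: at most 2 for a one-hot target.
        return np.sum((x - target) ** 2, axis=1)

    return np.mean(np.abs(loss(soft) - loss(hard)))


theta = np.array([1.0, 0.0, -1.0])
target = np.array([1.0, 0.0, 0.0])
gaps = {t: relaxed_loss_gap(theta, target, t) for t in (1.0, 0.1, 0.01)}
```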



For defined in (20), we have by Lemma A,


Our result follows by the other results of Lemma A. ∎

If for all such that , then in Def. 4 is a.s. unique.


It suffices to show that for all subsets with , the event has zero measure. If , then we can pick two distinct points with . Now,


where . ∎
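The uniqueness condition can also be seen empirically. We read the hypothesis as requiring that the utilities U satisfy P(U^T x = U^T y) = 0 for distinct x, y in X. Under that reading, continuous Gumbel utilities essentially never produce a tied argmax over one-hot vectors, whereas integer-valued utilities, which violate the hypothesis, tie often. The utility distributions and sizes below are illustrative assumptions.

```python
import numpy as np


def tie_rate(utilities, X):
    """Fraction of utility samples (rows) whose argmax over the rows of X is not unique."""
    scores = utilities @ X.T
    n_max = (scores == scores.max(axis=1, keepdims=True)).sum(axis=1)
    return np.mean(n_max > 1)


rng = np.random.default_rng(4)
X = np.eye(4)  # one-hot embeddings of 4 categories
uniform = rng.uniform(low=np.finfo(float).tiny, high=1.0, size=(10000, 4))
gumbel = -np.log(-np.log(uniform))                            # continuous utilities
integer = rng.integers(0, 3, size=(10000, 4)).astype(float)   # values in {0, 1, 2}

continuous_rate = tie_rate(gumbel, X)
discrete_rate = tie_rate(integer, X)
```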

Let be a non-empty finite set. If is convex independent, i.e., for all , , then is the set of extreme points of . In particular, any non-empty set of binary vectors is convex independent and thus the set of extreme points of .
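The claim about binary vectors can also be seen through an explicit certificate, offered here as an illustration and not as a replacement for the proof. For a binary vector x, the utility c = 2x - 1 scores any binary y as c^T y = |{i : x_i = y_i = 1}| - |{i : x_i = 0, y_i = 1}|. This is at most the number of ones in x, with equality only when y = x. Each binary vector is therefore the unique maximizer of a linear program over any set of binary vectors that contains it. By the first lemma of this appendix, it is then the unique maximizer over the convex hull, and hence an extreme point. The sketch checks this exhaustively for small n.

```python
import itertools

import numpy as np


def is_exposed_by_sign_vector(x, X):
    """True if x is the unique maximizer of (2x - 1)^T y over the rows y of X."""
    c = 2 * x - 1
    scores = X @ c
    winners = X[scores == scores.max()]
    return len(winners) == 1 and np.array_equal(winners[0], x)


n = 4
all_binary = np.array(list(itertools.product([0, 1], repeat=n)))
checks_full = [is_exposed_by_sign_vector(x, all_binary) for x in all_binary]

# Any non-empty subset of binary vectors works too, since restriction keeps uniqueness.
rng = np.random.default_rng(5)
subset = all_binary[rng.choice(len(all_binary), size=7, replace=False)]
checks_subset = [is_exposed_by_sign_vector(x, subset) for x in subset]
```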


Let . The fact that the extreme points of are in is trivial. In the other direction, it is enough to show that is an extreme point. Assume is not an extreme point of . Then by definition, we can write for , with and . Then, we have that


for some sequences such that and . This is clearly a contradiction of our assumption that , since the weights in the summation (28) sum to unity. This implies that are the extreme points of .

Let . It is enough to show that . Assume this is not the case. Let , and note that for all when are distinct binary vectors. But this leads to a contradiction. By assumption we can express as a convex combination of . Thus, there exists such that , and


Appendix B An Abbreviated Field Guide to Stochastic Softmax Tricks

B.1 Introduction


This is a short field guide to some stochastic softmax tricks (SSTs) and their associated stochastic argmax tricks (SMTs). There are many potential SSTs not discussed here. We assume throughout this Appendix that readers are familiar with the main text and its notation; we do not review it. In particular, we follow the problem definition and notation of Section 2, the definition and notation of SMTs in Section 4, and the definition and notation of SSTs in Section 5.

This field guide is organized by the abstract set . For each , we identify an appropriate set of structured embeddings. We discuss utility distributions used in the experiments. In some cases, we can provide a simple, “closed-form”, categorical sampling process for , i.e., a generalization of the Gumbel-Max trick. We also cover potential relaxations used in the experiments. In the remainder of this introduction, we introduce basic concepts that recur throughout the field guide.
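For reference, the simplest such closed-form process is the Gumbel-Max trick itself. With one-hot embeddings and U_i = θ_i + G_i for i.i.d. standard Gumbel G_i, the argmax is distributed as softmax(θ). The sketch below compares empirical argmax frequencies with softmax(θ); the logits and sample size are arbitrary.

```python
import numpy as np


def gumbel_max_samples(theta, n, rng):
    """Draw n categorical samples via argmax(theta + G), G i.i.d. standard Gumbel."""
    v = rng.uniform(low=np.finfo(float).tiny, high=1.0, size=(n, len(theta)))
    return np.argmax(theta - np.log(-np.log(v)), axis=1)


theta = np.array([1.5, 0.0, -0.5, 0.7])
rng = np.random.default_rng(6)
samples = gumbel_max_samples(theta, 200000, rng)
empirical = np.bincount(samples, minlength=len(theta)) / len(samples)
exact = np.exp(theta - theta.max())
exact /= exact.sum()  # softmax(theta)
```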


Given a finite set , the indicator vector of a subset