1 Introduction
Learning distributions in discrete domains is a fundamental problem in machine learning. This problem can be formulated in general as minimizing the following expected cost
$$\min_{\theta}\; \mathbb{E}_{z \sim p_\theta(z)}\big[f(z)\big], \qquad (1)$$
where $f(\cdot)$ is the cost function and $z$ is a discrete (latent) random variable whose distribution $p_\theta(z)$ is parameterized by $\theta$. Typically, $\theta$ is obtained as the output of a neural network (NN), whose weights are learned by backpropagating the gradients $\nabla_\theta \mathbb{E}_{z \sim p_\theta(z)}[f(z)]$ through the discrete random variables $z$. In practice, directly computing the gradient through the discrete random variables, $\nabla_\theta \mathbb{E}_{z \sim p_\theta(z)}[f(z)] = \sum_{z} f(z)\,\nabla_\theta p_\theta(z)$, suffers from the curse of dimensionality, since it requires traversing all possible joint configurations of the latent variable, whose number grows exponentially with the latent dimension. Due to this limitation, existing approaches resort to estimating the gradient by approximating its expectation, typically with Monte Carlo sampling.
The Straight-Through (ST) estimator [14, 3] is widely applied due to its simplicity and effectiveness. The idea of ST is to directly use the gradients with respect to the discrete samples as the gradients with respect to the distribution parameters. Since discrete samples can be generated as the output of hard threshold functions that take the distribution parameters as input, Bengio et al. [3] explain the ST estimator by setting the gradients of the hard threshold functions to 1. However, this explanation lacks a theoretical justification for the gradients of hard threshold functions.
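The hard-threshold reading of ST can be made concrete for a single Bernoulli latent. The following is a minimal pure-Python sketch, assuming $z = \mathbb{1}[u < \theta]$ with $u \sim \mathcal{U}(0, 1)$; the function names and the quadratic example cost are hypothetical choices for illustration, not part of the original method description.

```python
import random
from typing import Callable


def hard_threshold(theta: float, u: float) -> int:
    """Forward pass: z = 1[u < theta] with u ~ Uniform(0, 1), so that z ~ Bern(theta)."""
    return 1 if u < theta else 0


def st_gradient(theta: float, grad_f: Callable[[float], float],
                num_samples: int, rng: random.Random) -> float:
    """Monte Carlo ST estimate of d/dtheta E_{z ~ Bern(theta)}[f(z)].

    The true derivative of the hard threshold w.r.t. theta is zero almost
    everywhere; ST replaces it by 1, so the chain rule yields f'(z) * 1.
    """
    total = 0.0
    for _ in range(num_samples):
        z = hard_threshold(theta, rng.random())
        total += grad_f(z) * 1.0  # surrogate derivative dz/dtheta := 1
    return total / num_samples


def expected_st_gradient(theta: float, grad_f: Callable[[float], float]) -> float:
    """Expectation of the ST estimate: theta * f'(1) + (1 - theta) * f'(0)."""
    return theta * grad_f(1) + (1.0 - theta) * grad_f(0)


def exact_gradient(f: Callable[[float], float]) -> float:
    """E[f(z)] = theta * f(1) + (1 - theta) * f(0), whose theta-derivative is f(1) - f(0)."""
    return f(1) - f(0)


def example_cost(z: float) -> float:
    """Hypothetical differentiable cost used only for illustration."""
    return (z - 0.45) ** 2


def example_cost_grad(z: float) -> float:
    return 2.0 * (z - 0.45)
```

For the example cost at $\theta = 0.3$, the ST estimate concentrates around $\theta f'(1) + (1 - \theta) f'(0) = -0.3$, whereas the exact derivative is $f(1) - f(0) = 0.1$; the two coincide for affine $f$ but not in general.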
In this paper, we show that ST can be interpreted as simulating the projected Wasserstein gradient flow (pWGF) of a functional $F(q) = \mathbb{E}_{q(z)}[f(z)]$, where $q$ is a distribution in the target discrete distribution family with density $p_\theta$ parameterized by $\theta$. Further, a more general optimization scheme for (1) is introduced. Instead of directly updating $q$ in the discrete distribution family, $q$ is first updated to $q'$ on a larger Wasserstein distribution space, where gradients are easier to compute. Then, we project $q'$ back to the discrete distribution family as the updated distribution. Moreover, the projection follows the descending direction of $F$ in the discrete distribution family, which justifies the effectiveness of ST. This pWGF-based updating scheme also motivates another variant that achieves faster convergence when the desired family of distributions has infinite support, e.g., Poisson.
2 Proposed Algorithm
Denote $\mathcal{Q}_\theta$ as the $d$-dimensional discrete distribution family parameterized by $\theta$. With $F(q) = \mathbb{E}_{q(z)}[f(z)]$, the task (1) can be rewritten as
$$\min_{q \in \mathcal{Q}_\theta} F(q), \qquad (2)$$
where $f$ is assumed to be differentiable. To solve (2), directly calculating the gradient $\nabla_\theta F(p_\theta)$ is challenging, because the discrete distribution family is very restrictive on the gradients. Alternatively, if we relax the discrete constraint and perform updates in an appropriate larger space $\mathcal{P}_2(\mathbb{R}^d)$, the gradient becomes much easier to calculate. Therefore, as shown in Fig. 1, in the $t$-th updating iteration, we first update the current distribution $q_t$ to $q'_{t+1}$ with step size $\varepsilon$ in the larger 2-Wasserstein space $\mathcal{P}_2(\mathbb{R}^d)$ [24], and then project $q'_{t+1}$ back to $\mathcal{Q}_\theta$ as the updated discrete distribution $q_{t+1}$. Theorem C.2 in the supplement guarantees that our updating scheme converges with a small enough step size $\varepsilon$.
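With the Wasserstein gradient flow (WGF) [24], we show (in the Appendix) that the gradient in the larger space is $\operatorname{grad} F(q) = \nabla f$, which means that if $q_t$ is represented by a group of its samples $\{z_i\}_{i=1}^N$, then $\{z_i - \varepsilon \nabla f(z_i)\}_{i=1}^N$ can be treated as a group of samples from $q'_{t+1}$. Therefore, we can update $q_t$ to $q'_{t+1}$ along the WGF simply by updating its samples. To project $q'_{t+1}$ back to $\mathcal{Q}_\theta$ as $q_{t+1}$, we need to solve $q_{t+1} = \arg\min_{q \in \mathcal{Q}_\theta} W_2(q, q'_{t+1})$, which is equivalent to solving $\theta_{t+1} = \arg\min_{\theta} W_2^2(p_\theta, q'_{t+1})$, where $W_2^2$ is the square of the 2-Wasserstein distance [24]. Consequently, our pWGF algorithm proceeds in three steps, shown in Fig. 2: (A) draw samples $\{z_i\}$ from the current distribution $q_t$; (B) update them to $\{z_i - \varepsilon \nabla f(z_i)\}$ as samples from $q'_{t+1}$; (C) project $q'_{t+1}$ back to $\mathcal{Q}_\theta$ by minimizing the Wasserstein distance.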
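Steps (A)-(C) can be sketched for a one-dimensional Poisson family $p_\lambda = \mathrm{Poisson}(\lambda)$. This is a minimal pure-Python sketch under assumptions that are not in the text: a hypothetical quadratic cost, and a projection computed by brute-force grid search over $\lambda$, using the fact that in one dimension the monotone (quantile) coupling is optimal, so $W_2^2(p, q) = \int_0^1 |F_p^{-1}(u) - F_q^{-1}(u)|^2\,du$ with $F^{-1}$ the quantile functions. The grid, sample size, and step size are arbitrary.

```python
import math
import random
from typing import Callable, List, Sequence


def sample_poisson(lam: float, rng: random.Random) -> int:
    """Draw one Poisson(lam) sample with Knuth's multiplication method (fine for small lam)."""
    threshold = math.exp(-lam)
    k, prod = 0, rng.random()
    while prod > threshold:
        k += 1
        prod *= rng.random()
    return k


def poisson_quantiles(lam: float, levels: Sequence[float]) -> List[int]:
    """Poisson(lam) quantile function evaluated at increasing levels in (0, 1)."""
    out: List[int] = []
    k = 0
    pmf = math.exp(-lam)
    cdf = pmf
    for u in levels:
        while cdf < u:  # advance to the smallest k with CDF(k) >= u
            k += 1
            pmf *= lam / k
            cdf += pmf
        out.append(k)
    return out


def w2_sq_to_poisson(samples: Sequence[float], lam: float) -> float:
    """Approximate W_2^2(Poisson(lam), empirical distribution of samples).

    In 1-D the monotone (quantile) coupling is optimal, so
    W_2^2 = int_0^1 (F^{-1}(u) - G^{-1}(u))^2 du, evaluated at u_i = (i + 0.5) / N.
    """
    ys = sorted(samples)
    n = len(ys)
    levels = [(i + 0.5) / n for i in range(n)]
    qs = poisson_quantiles(lam, levels)
    return sum((y - q) ** 2 for y, q in zip(ys, qs)) / n


def pwgf_step(lam: float, grad_f: Callable[[float], float], eps: float,
              n: int, rng: random.Random, grid: Sequence[float]) -> float:
    """One pWGF iteration for the Poisson family, following steps (A)-(C)."""
    # (A) draw samples from the current distribution q_t = Poisson(lam)
    zs = [sample_poisson(lam, rng) for _ in range(n)]
    # (B) move every sample along -grad f; the results are samples of q'_{t+1}
    pushed = [z - eps * grad_f(z) for z in zs]
    # (C) project q'_{t+1} back onto the Poisson family by minimizing W_2^2 over the grid
    return min(grid, key=lambda cand: w2_sq_to_poisson(pushed, cand))
```

For the hypothetical cost $f(z) = (z - 3)^2 / 2$, one step from $\lambda = 6$ with $\varepsilon = 0.5$ moves the samples toward 3, and the projection returns a smaller $\lambda$.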
Since distributions in $\mathcal{Q}_\theta$ are multidimensional, the exact Wasserstein distance is difficult to derive. We make a standard assumption [8] that $p_\theta$ and $q'_{t+1}$ are factorized distributions. With this assumption, we prove in Theorem 2.1 that minimizing the Wasserstein distance between factorized distributions is equivalent to minimizing the marginal distance on every dimension. Therefore, for simplicity, we describe our projection step using one-dimensional distributions. As the updated distribution $q'_{t+1}$ is implicit, we cannot obtain a closed form of the Wasserstein distance $W_2(p_\theta, q'_{t+1})$. Therefore, we consider two approximations of $W_2(p_\theta, q'_{t+1})$.
Theorem 2.1.
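If $d$-dimensional distributions $p$ and $q$ are factorized, then $W_2^2(p, q) = \sum_{i=1}^{d} W_2^2(p_i, q_i)$, where $p_i$ and $q_i$ are the marginal distributions of $p$ and $q$ on the $i$-th dimension, respectively.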
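Under Theorem 2.1, a factorized $W_2^2$ is a sum of one-dimensional terms, and each term between finite discrete distributions can be computed exactly by coupling the sorted atoms monotonically (a north-west-corner sweep over the two CDFs), which is optimal for the squared cost on the real line. A minimal sketch; the function names and the (atoms, probabilities) representation are assumptions for illustration.

```python
from typing import List, Sequence, Tuple

Marginal = Tuple[Sequence[float], Sequence[float]]  # (atoms, probabilities)


def w2_squared_1d(xs: Sequence[float], ps: Sequence[float],
                  ys: Sequence[float], qs: Sequence[float], tol: float = 1e-12) -> float:
    """Exact W_2^2 between two finite discrete distributions on the real line."""
    p_atoms = sorted(zip(xs, ps))
    q_atoms = sorted(zip(ys, qs))
    i, j = 0, 0
    p_left, q_left = p_atoms[0][1], q_atoms[0][1]
    total = 0.0
    while i < len(p_atoms) and j < len(q_atoms):
        mass = min(p_left, q_left)  # move as much mass as both current atoms allow
        total += mass * (p_atoms[i][0] - q_atoms[j][0]) ** 2
        p_left -= mass
        q_left -= mass
        if p_left <= tol:  # current atom of p exhausted: go to the next one
            i += 1
            if i < len(p_atoms):
                p_left = p_atoms[i][1]
        if q_left <= tol:  # current atom of q exhausted: go to the next one
            j += 1
            if j < len(q_atoms):
                q_left = q_atoms[j][1]
    return total


def w2_squared_factorized(p_marginals: List[Marginal], q_marginals: List[Marginal]) -> float:
    """W_2^2 between factorized distributions as a sum over dimensions (Theorem 2.1)."""
    return sum(w2_squared_1d(px, pp, qx, qp)
               for (px, pp), (qx, qp) in zip(p_marginals, q_marginals))
```

For example, $\mathrm{Bern}(0.3)$ versus $\mathrm{Bern}(0.7)$ gives $0.4$, and appending an identical second dimension leaves the factorized total unchanged.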
2.1 ST estimator: Absolute Difference of Expectation
We find that the Straight-Through (ST) estimator [3] is a special case of pWGF when the Wasserstein distance is approximated by its lower bound, the absolute difference of expectations.
Theorem 2.2.
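For two one-dimensional distributions $p, q$, the absolute difference between $\mathbb{E}_{p}[z]$ and $\mathbb{E}_{q}[z]$ is a lower bound of $W_2(p, q)$, i.e., $\big|\mathbb{E}_{p}[z] - \mathbb{E}_{q}[z]\big| \le W_2(p, q)$.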
Remark.
If $p$ and $q$ are Bernoulli, then $W_2^2(p, q) = \big|\mathbb{E}_{p}[z] - \mathbb{E}_{q}[z]\big|$, which means minimizing the expectation difference is equivalent to minimizing the 2-Wasserstein distance in the Bernoulli case.
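A worked instance of the Remark and of Theorem 2.2, with hypothetical parameters $a = 0.3$ and $b = 0.7$, writing $\gamma_{ij}$ for the probability that the coupling sends $x = i$ under $p = \mathrm{Bern}(a)$ to $y = j$ under $q = \mathrm{Bern}(b)$, so that only $\gamma_{01}$ and $\gamma_{10}$ incur the unit cost $(i - j)^2 = 1$:

```latex
\begin{aligned}
&\gamma_{00} + \gamma_{01} = 1 - a = 0.7, \qquad \gamma_{10} + \gamma_{11} = a = 0.3,\\
&\gamma_{00} + \gamma_{10} = 1 - b = 0.3, \qquad \gamma_{01} + \gamma_{11} = b = 0.7,\\
&\Rightarrow\ \gamma_{01} - \gamma_{10} = b - a = 0.4,\\
&W_2^2(p, q) = \min_{\gamma \ge 0}\,(\gamma_{01} + \gamma_{10}) = 0.4 \quad \text{(attained at } \gamma_{10} = 0\text{)},\\
&\big|\mathbb{E}_p[z] - \mathbb{E}_q[z]\big| = |a - b| = 0.4 = W_2^2(p, q) \le W_2(p, q) = \sqrt{0.4} \approx 0.632.
\end{aligned}
```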
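For a one-dimensional Bernoulli distribution $p_\theta = \mathrm{Bern}(\theta)$, minimizing $\big|\mathbb{E}_{p_\theta}[z] - \mathbb{E}_{q'_{t+1}}[z]\big|$ over $\theta$ gives $\theta_{t+1} = \mathbb{E}_{q'_{t+1}}[z]$. Noting that $\mathbb{E}_{p_{\theta_t}}[z] = \theta_t$ and that $\{z_i - \varepsilon \nabla f(z_i)\}_{i=1}^N$ are samples of $q'_{t+1}$, we approximate the parameter update by:
$$\theta_{t+1} \approx \frac{1}{N}\sum_{i=1}^{N}\big(z_i - \varepsilon \nabla f(z_i)\big).$$
To reduce the variance caused by the sample mean $\frac{1}{N}\sum_{i=1}^{N} z_i$, we use the control variate method [4], replacing it with its expectation $\theta_t$, and write
$$\theta_{t+1} = \theta_t - \varepsilon\,\frac{1}{N}\sum_{i=1}^{N} \nabla f(z_i).$$
Thus, we have derived the pWGF estimator with the expectation-difference approximation, which has the same form as a multi-sample version of the ST estimator [3]. Parameter gradients for Poisson and Categorical distributions can be derived in a similar way.

2.2 Proposed estimator: Maximum Mean Discrepancy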
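A more principled way to approximate the Wasserstein distance is to use the Maximum Mean Discrepancy (MMD) [12]: $\mathrm{MMD}^2(p, q) = \mathbb{E}_{x, x' \sim p}[k(x, x')] - 2\,\mathbb{E}_{x \sim p,\, y \sim q}[k(x, y)] + \mathbb{E}_{y, y' \sim q}[k(y, y')]$, where $k(\cdot, \cdot)$ is a selected kernel. In practice, instead of minimizing $\mathrm{MMD}^2(p_\theta, q'_{t+1})$, we can minimize its empirical estimate $\widehat{\mathrm{MMD}}^2(p_\theta, q'_{t+1})$, computed with the updated samples $\{z_i - \varepsilon \nabla f(z_i)\}_{i=1}^N$ in place of $q'_{t+1}$. Details on parameter gradients are shown in the supplement.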
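The empirical MMD can be sketched with the standard biased (V-statistic) estimator, which averages the kernel over all sample pairs. A minimal sketch, assuming an RBF kernel with a hypothetical `bandwidth` parameter; in the pWGF projection step, `ys` would be the updated samples $z_i - \varepsilon \nabla f(z_i)$ and `xs` would represent $p_\theta$.

```python
import math
from typing import Sequence


def rbf_kernel(x: float, y: float, bandwidth: float = 1.0) -> float:
    """k(x, y) = exp(-(x - y)^2 / (2 * bandwidth^2))."""
    return math.exp(-((x - y) ** 2) / (2.0 * bandwidth ** 2))


def mean_kernel(a: Sequence[float], b: Sequence[float], bandwidth: float) -> float:
    """Average of k(u, v) over all pairs (u, v) with u in a and v in b."""
    return sum(rbf_kernel(u, v, bandwidth) for u in a for v in b) / (len(a) * len(b))


def mmd2_biased(xs: Sequence[float], ys: Sequence[float], bandwidth: float = 1.0) -> float:
    """Biased (V-statistic) estimate of MMD^2(p, q) from xs ~ p and ys ~ q.

    It equals a squared RKHS norm of the difference of empirical mean embeddings,
    so it is nonnegative up to floating-point rounding.
    """
    return (mean_kernel(xs, xs, bandwidth)
            - 2.0 * mean_kernel(xs, ys, bandwidth)
            + mean_kernel(ys, ys, bandwidth))
```

The bandwidth controls which scale of discrepancy the kernel is sensitive to; its choice is left open here.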
3 Experiments
We demonstrate the advantage of pWGF in updating Poisson distributions, and report benchmark performance with a binary latent model in the supplement. Since the only difference between our pWGF version of ST and the original ST is a learning-rate scalar, unless otherwise specified we refer to both pWGF-ST and the original ST as ST, and call our MMD version pWGF.
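The statement that pWGF-ST and the original ST differ only by a learning-rate scalar can be checked on the Bernoulli update derived in Section 2.1. A minimal sketch, assuming a single Bernoulli latent and a hypothetical cost; the function names are illustrative, and the sketch does not clip $\theta$ to $[0, 1]$.

```python
import math
import random
from typing import Callable, Sequence


def mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def pwgf_projection_raw(grad_f: Callable[[float], float], eps: float,
                        zs: Sequence[int]) -> float:
    """Mean matching without a control variate: (1/N) sum_i (z_i - eps * f'(z_i))."""
    return mean([z - eps * grad_f(z) for z in zs])


def pwgf_st_update(theta: float, grad_f: Callable[[float], float], eps: float,
                   zs: Sequence[int]) -> float:
    """Control-variate version: the sample mean of z_i is replaced by its expectation theta."""
    return theta - eps * mean([grad_f(z) for z in zs])


def st_sgd_update(theta: float, grad_f: Callable[[float], float], lr: float,
                  zs: Sequence[int]) -> float:
    """Multi-sample ST gradient (average of f'(z_i)) followed by one SGD step with rate lr."""
    st_grad = mean([grad_f(z) for z in zs])
    return theta - lr * st_grad
```

With `lr = eps`, `pwgf_st_update` and `st_sgd_update` return the same value, while `pwgf_projection_raw` differs from them by exactly the sample-mean noise $\frac{1}{N}\sum_i z_i - \theta_t$ that the control variate removes.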
3.1 Poisson Parameter Estimation
We apply pWGF to infer the parameter $\lambda$ of a one-dimensional Poisson distribution. We use the true distribution $\mathrm{Poisson}(\lambda^\ast)$ to generate data samples $\{x_j\}$, and use a generative adversarial learning framework to learn the model parameter. A generator is constructed as $z \sim \mathrm{Poisson}(\lambda)$. A discriminator $D(\cdot)$ is a network used to distinguish true/fake samples, which outputs the probability that the data comes from the true distribution. During adversarial training, the generator aims to increase $D(z)$, while the discriminator tries to decrease $D(z)$ and increase $D(x)$. We can rewrite the training process as a min-max game with objective function:
$$\min_{\lambda}\max_{D}\; \mathbb{E}_{x \sim \mathrm{Poisson}(\lambda^\ast)}\big[\log D(x)\big] + \mathbb{E}_{z \sim \mathrm{Poisson}(\lambda)}\big[\log(1 - D(z))\big].$$
Similar to the observation in [10], the training process should finally converge to $\lambda = \lambda^\ast$. Therefore, for the generator, learning $\lambda$ becomes optimizing $\min_{\lambda} \mathbb{E}_{z \sim \mathrm{Poisson}(\lambda)}[\log(1 - D(z))]$, which has the form of (1) with $f(z) = \log(1 - D(z))$. We compare our pWGF against ST, Reinforce and Muprop [13], and show the learning curves of the $\lambda$ estimates in Figure 3. pWGF converges faster than the others and exhibits much smaller oscillation. In Table 3.1, we report the mean and the standard deviation of the inferred parameter at the end of training, where our pWGF exhibits higher inference accuracy and lower variance.
Table 3.1: Mean and standard deviation of the inferred Poisson parameter.
Method      Mean     Std
pWGF        5.0076   0.013
ST          5.1049   0.161
Muprop      5.0196   0.159
Reinforce   4.9452   0.173

4 Conclusion
We presented a theoretical foundation to justify the superior empirical performance of the Straight-Through (ST) estimator for backpropagating gradients through discrete latent variables. Specifically, we show that ST can be interpreted as the simulation of the projected gradient flow in Wasserstein space. Based on this theoretical framework, we further propose another gradient estimator for learning discrete variables, which exhibits even better performance when applied to distributions with infinite support, e.g., Poisson.
References
[1] (2008) Gradient flows: in metric spaces and in the space of probability measures. Springer Science & Business Media.
[2] (2000) A computational fluid mechanics solution to the Monge-Kantorovich mass transfer problem. Numerische Mathematik 84 (3), pp. 375–393.
[3] (2013) Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432.
[4] (1977) Options: a Monte Carlo approach. Journal of Financial Economics 4 (3), pp. 323–338.
[5] (2017) Continuous-time flows for deep generative models. arXiv preprint arXiv:1709.01179.
[6] (2018) A unified particle-optimization framework for scalable Bayesian sampling. UAI submission.
[7] (2017) Particle optimization in stochastic gradient MCMC. arXiv preprint arXiv:1711.10927.
[8] (2016) InfoGAN: interpretable representation learning by information maximizing generative adversarial nets. In Advances in Neural Information Processing Systems, pp. 2172–2180.
[9] (2016) Tutorial on variational autoencoders. arXiv preprint arXiv:1606.05908.
[10] (2014) Generative adversarial nets. In Advances in Neural Information Processing Systems, pp. 2672–2680.
[11] (2017) Backpropagation through the void: optimizing control variates for black-box gradient estimation. arXiv preprint arXiv:1711.00123.
[12] (2007) A kernel method for the two-sample-problem. In Advances in Neural Information Processing Systems, pp. 513–520.
[13] (2015) MuProp: unbiased backpropagation for stochastic neural networks. arXiv preprint arXiv:1511.05176.
[14] (2012) Neural networks for machine learning, video lectures. Coursera.
[15] (2017) Categorical reparameterization with Gumbel-Softmax. In International Conference on Learning Representations 2017.
[16] (2014) Auto-encoding variational Bayes. stat 1050, pp. 1.
[17] (2018) Accelerated first-order methods on the Wasserstein space for Bayesian inference. arXiv preprint arXiv:1807.01750.
[18] (2016) Stein variational gradient descent: a general purpose Bayesian inference algorithm. In Advances in Neural Information Processing Systems, pp. 2378–2386.
[19] (2017) Stein variational gradient descent as gradient flow. In Advances in Neural Information Processing Systems, pp. 3118–3126.
[20] (2010) Learning gradients on manifolds. Bernoulli 16 (1), pp. 181–207.
[21] (2001) The geometry of dissipative evolution equations: the porous medium equation.
[22] (2014) Stochastic backpropagation and approximate inference in deep generative models. In International Conference on Machine Learning, pp. 1278–1286.
[23] (2017) Rebar: low-variance, unbiased gradient estimates for discrete latent variable models. In Advances in Neural Information Processing Systems, pp. 2627–2636.
[24] (2008) Optimal transport: old and new. Vol. 338, Springer Science & Business Media.
[25] (2018) ARM: augment-reinforce-merge gradient for discrete latent variable models. arXiv preprint arXiv:1807.11143.
[26] (2018) Policy optimization as Wasserstein gradient flows. arXiv preprint arXiv:1808.03030.
Appendix A Background
To minimize the expected cost in (1), we assume that , if the cost function depends on . For instance, in the Variational Autoencoder (VAE) [9], we seek to maximize the Evidence Lower Bound (ELBO) as , where depends on parameter through the variational posterior approximation . Since , we have .
As described above, there are two types of updating methods for $\theta$ under (1), namely, estimation of the parameter gradient $\nabla_\theta \mathbb{E}_{z \sim p_\theta(z)}[f(z)]$, and continuous relaxation of the discrete variable $z$.
A.1 Continuous relaxation
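Another approach used to obtain updates for $\theta$ in (1) is to approximate samples of $z$ by a deterministic function, $z = g(\theta, \epsilon)$, of $\theta$ and an independent random variable $\epsilon$ with a simple distribution $p(\epsilon)$, e.g., uniform or normal, so that $z = g(\theta, \epsilon) \sim p_\theta(z)$. Then we can use the chain rule to derive the gradient of (1) as
$$\nabla_\theta \mathbb{E}_{z \sim p_\theta(z)}[f(z)] = \nabla_\theta \mathbb{E}_{\epsilon \sim p(\epsilon)}\big[f(g(\theta, \epsilon))\big] = \mathbb{E}_{\epsilon \sim p(\epsilon)}\big[\nabla_\theta f(g(\theta, \epsilon))\big].$$
We can take the expectation of the gradients, which is very convenient because $\nabla_\theta f(g(\theta, \epsilon))$ can be computed by the chain rule, noting that $p(\epsilon)$ does not directly depend on $\theta$. This reparameterization trick works quite well when $z$ originates from a continuous distribution. For example, given a normal distribution $z \sim \mathcal{N}(\mu, \sigma^2)$, we can rewrite $z = \mu + \sigma\epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$ and directly obtain $\partial z / \partial \mu = 1$ and $\partial z / \partial \sigma = \epsilon$. This reparameterization has been widely used in the training of variational autoencoders with latent Gaussian priors [16, 22]. In the discrete case, it becomes very difficult to find a differentiable deterministic function to generate samples from $p_\theta(z)$. For the categorical distribution, [15] introduced the Gumbel-Softmax distribution to relax the one-hot-vector encoding commonly used for categorical variables. For the multidimensional (factorized) Bernoulli distribution with parameter $\theta$, the Straight-Through (ST) estimator [14, 3], which directly uses the gradient with respect to the samples of $z$ as the gradient with respect to the parameter $\theta$, can also be explained by setting the derivative of the discrete function $z = g(\theta, \epsilon)$ (coordinatewise) directly to the identity matrix [3].

A.2 Wasserstein gradient flow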
Wasserstein gradient flows (WGF) [6, 1, 24] have become popular in machine learning due to their generality over parametric distribution families and their tractable computational efficiency. The Wasserstein space is a metric space of distributions. The WGF defines a family of steepest descending curves. It has been applied to Bayesian inference, where the KL divergence of an approximating distribution to a target one is minimized by simulating its gradient flow. [6] developed a unified framework to simulate the WGF, including Stein Variational Gradient Descent (SVGD) [18, 19] and Stochastic Gradient MCMC as its special cases. [7] and [17] proposed acceleration frameworks for these methods. WGFs have also been applied to deep generative models [5] and policy optimization in reinforcement learning [26]. However, all previous methods focus on simulating WGFs to approximate distributions in continuous domains. There has been little, if any, research reported on WGFs for discrete domains.

Appendix B Updating via Wasserstein gradient flow
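Gradient computation and Wasserstein Gradient Flow (WGF) simulation are made possible by the Riemannian structure of $\mathcal{P}_2(\mathbb{R}^d)$, which consists of a proper inner product in the tangent space that is consistent with the Wasserstein distance [2, 21]. The tangent space of $\mathcal{P}_2(\mathbb{R}^d)$ at $q$ can be represented by a subspace of vector fields on $\mathbb{R}^d$ ([24], Thm 13.8; [1], Thm 8.3.1, Prop 8.4.5):
$$T_q\mathcal{P}_2 = \overline{\{\nabla\varphi : \varphi \in C_c^\infty(\mathbb{R}^d)\}}^{\,\mathcal{L}^2_q},$$
where $C_c^\infty(\mathbb{R}^d)$ is the set of compactly supported smooth functions on $\mathbb{R}^d$, $\mathcal{L}^2_q$ is a Hilbert space with inner product $\langle u, v \rangle_{\mathcal{L}^2_q} = \mathbb{E}_{q(z)}[u(z) \cdot v(z)]$, and the overline represents taking the closure in $\mathcal{L}^2_q$.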
With the inner product inherited from $\mathcal{L}^2_q$, $\mathcal{P}_2$ becomes a Riemannian manifold consistent with the Wasserstein distance due to the Benamou-Brenier formula [2]. We can then express the gradient of a function on $\mathcal{P}_2$ in the Riemannian sense. The explicit expression was intuitively proposed as Otto's calculus ([21]; [24], Chapter 15) and rigorously verified by subsequent work, e.g., [24], Thm 23.18; [1], Lem 10.4.1. Specifically, they showed that given a functional $F$ with $F(q) = \mathbb{E}_{q(z)}[f(z)]$, its gradient is $\operatorname{grad} F(q) = \nabla f$, a vector field on $\mathbb{R}^d$. This means that we can, in principle, compute the desired gradient using $\nabla f$.
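Another convenient property of $\mathcal{P}_2$, based on the physical interpretation of its tangent vectors, makes the gradient flow simulation possible. Consider a smooth curve of absolutely continuous measures, $(q_t)_t$, with corresponding tangent vector $v_t \in T_{q_t}\mathcal{P}_2$, for which the gradient flow is simulated (iteratively) at discrete values of $t$ to estimate $q^\ast$ (the target distribution). For any $q_t$ and $\varepsilon$, Proposition 8.4.6 of [1] guarantees that $W_2\big(q_{t+\varepsilon}, (\mathrm{id} + \varepsilon v_t)_\# q_t\big) = o(\varepsilon)$, where $\mathrm{id} + \varepsilon v_t$ is a transformation on $\mathbb{R}^d$ ($\mathrm{id}$ is the identity map and $v_t$ is a vector field on $\mathbb{R}^d$), and $(\mathrm{id} + \varepsilon v_t)_\# q_t$ is the pushed-forward measure of $q_t$ that moves $q_t$ along the tangent vector $v_t$ by distance $\varepsilon$, see Figure 1. When $(q_t)_t$ is a gradient flow (steepest descending curve) of $F$ defined in the form above, $v_t = -\operatorname{grad} F(q_t) = -\nabla f$, as described before. Then, if $\{z_i\}_{i=1}^N$ is a set of samples of $q_t$, by the definition of the pushed-forward measure [1], $\{z_i - \varepsilon \nabla f(z_i)\}_{i=1}^N$ is a set of samples of $(\mathrm{id} - \varepsilon \nabla f)_\# q_t$, which forms a first-order approximation of $q_{t+\varepsilon}$. Since $(\mathrm{id} - \varepsilon \nabla f)_\# q_t$ is a good approximation of $q_{t+\varepsilon}$ (the optimal measure along the WGF), as discussed above, we can use it as the intermediate distribution $q'_{t+1}$, which we then project onto $\mathcal{Q}_\theta$. Then, per Theorem C.2, with small enough positive $\varepsilon$, we can always get a set of samples whose distribution improves $F$, the functional of the cost in (1).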
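The first-order claim can be checked numerically for $F(q) = \mathbb{E}_q[f(z)]$: moving samples by $-\varepsilon \nabla f$ should change $F$ by $-\varepsilon\,\mathbb{E}_q[\|\nabla f\|^2]$ up to $O(\varepsilon^2)$. A minimal sketch in the continuous setting, assuming a hypothetical quadratic cost $f(z) = z^2 / 2$ and Gaussian samples; for this $f$ the gap to the first-order prediction is exactly $\varepsilon^2$ times the empirical value of $F$.

```python
import random
from typing import List, Tuple


def f(z: float) -> float:
    """Hypothetical cost f(z) = z^2 / 2."""
    return 0.5 * z * z


def grad_f(z: float) -> float:
    return z


def pushforward_check(eps: float, n: int = 10_000, seed: int = 0) -> Tuple[float, float, float]:
    """Return (F(q), F((id - eps grad f)_# q), first-order prediction), all from samples of q."""
    rng = random.Random(seed)
    zs: List[float] = [rng.gauss(1.0, 1.0) for _ in range(n)]  # samples of q = N(1, 1)
    f_old = sum(f(z) for z in zs) / n
    pushed = [z - eps * grad_f(z) for z in zs]  # samples of the pushed-forward measure
    f_new = sum(f(y) for y in pushed) / n
    predicted = f_old - eps * sum(grad_f(z) ** 2 for z in zs) / n  # F - eps * E|grad f|^2
    return f_old, f_new, predicted
```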
Appendix C Proofs
Theorem C.1.
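Let $F$ be a differentiable function on a manifold $\mathcal{M}$ and $\mathcal{N}$ a submanifold of $\mathcal{M}$, $\mathcal{N} \subset \mathcal{M}$; then at any $q \in \mathcal{N}$,
$$\operatorname{grad}_{\mathcal{N}} F(q) = P_q\big(\operatorname{grad}_{\mathcal{M}} F(q)\big),$$
where $P_q$ is the projection of $T_q\mathcal{M}$ onto $T_q\mathcal{N}$.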
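Theorem C.1 can be illustrated in a Euclidean toy setting, where $\mathcal{M} = \mathbb{R}^2$ with the standard inner product and $\mathcal{N}$ is a line through the origin. The sketch compares the intrinsic gradient on $\mathcal{N}$ (from a finite-difference derivative along the line) with the orthogonal projection of the ambient gradient; the function `F`, the direction `D` of the line, and the evaluation point are hypothetical.

```python
import math
from typing import Tuple

Vec = Tuple[float, float]

# Unit tangent direction of the line N = {s * D : s in R} (hypothetical choice).
D: Vec = (1.0 / math.sqrt(5.0), 2.0 / math.sqrt(5.0))


def F(x: float, y: float) -> float:
    return x * x + 3.0 * y


def ambient_grad(x: float, y: float) -> Vec:
    """grad_M F in R^2 with the Euclidean metric."""
    return (2.0 * x, 3.0)


def projected_grad(s: float) -> Vec:
    """P_q(grad_M F(q)) at q = s * D: orthogonal projection onto span{D}."""
    gx, gy = ambient_grad(s * D[0], s * D[1])
    coeff = gx * D[0] + gy * D[1]
    return (coeff * D[0], coeff * D[1])


def intrinsic_grad(s: float, h: float = 1e-6) -> Vec:
    """grad_N F(q): derivative of F restricted to N (central differences) times D."""
    plus = F((s + h) * D[0], (s + h) * D[1])
    minus = F((s - h) * D[0], (s - h) * D[1])
    deriv = (plus - minus) / (2.0 * h)
    return (deriv * D[0], deriv * D[1])
```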
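Proof of Theorem C.1.
By the definition of $\operatorname{grad}_{\mathcal{M}}$ [20], for any vector $v \in T_q\mathcal{M}$,
$$\big\langle \operatorname{grad}_{\mathcal{M}} F(q), v \big\rangle = dF_q(v). \qquad (3)$$
By the definition of $\operatorname{grad}_{\mathcal{N}}$, for any vector $v \in T_q\mathcal{N}$,
$$\big\langle \operatorname{grad}_{\mathcal{N}} F(q), v \big\rangle = dF_q(v). \qquad (4)$$
Since $T_q\mathcal{N}$ is a subspace of $T_q\mathcal{M}$, by the definition of $P_q$ we have, for any $v \in T_q\mathcal{N}$,
$$\big\langle P_q(\operatorname{grad}_{\mathcal{M}} F(q)), v \big\rangle = \big\langle \operatorname{grad}_{\mathcal{M}} F(q), v \big\rangle = dF_q(v) = \big\langle \operatorname{grad}_{\mathcal{N}} F(q), v \big\rangle. \qquad (5)$$
Since both $P_q(\operatorname{grad}_{\mathcal{M}} F(q))$ and $\operatorname{grad}_{\mathcal{N}} F(q)$ lie in $T_q\mathcal{N}$, they are equal. ∎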
Theorem C.2.
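Let $\mathcal{Q}_\theta \subset \mathcal{P}_2(\mathbb{R}^d)$ and $W_2$ be the 2-Wasserstein distance in $\mathcal{P}_2(\mathbb{R}^d)$. Update $q_t \in \mathcal{Q}_\theta$ in $\mathcal{P}_2$ along the direction $-\operatorname{grad} F(q_t)$ to $q'_{t+1} = \mathrm{Exp}_{q_t}\big(-\varepsilon \operatorname{grad} F(q_t)\big)$ (exponential map [20]), then project back to $\mathcal{Q}_\theta$ as $q_{t+1} = \arg\min_{q \in \mathcal{Q}_\theta} W_2(q, q'_{t+1})$. If $\operatorname{grad} F$ is Lipschitz continuous, then there exists $\varepsilon_0 > 0$ such that for any $\varepsilon \in (0, \varepsilon_0)$, $F(q_{t+1}) \le F(q_t)$.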
Proof of Theorem 2.1.
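(1) First, we show that $W_2^2(p, q) \le \sum_{i=1}^{d} W_2^2(p_i, q_i)$, where $\Pi(p, q)$ denotes the set of couplings of $p$ and $q$.
Arbitrarily selecting $\gamma_i \in \Pi(p_i, q_i)$, $i = 1, \dots, d$, we define $\gamma(x, y) = \prod_{i=1}^{d} \gamma_i(x_i, y_i)$. Since $p(x) = \prod_{i=1}^{d} p_i(x_i)$, we have
$$\int \gamma(x, y)\,dy = \prod_{i=1}^{d} \int \gamma_i(x_i, y_i)\,dy_i = \prod_{i=1}^{d} p_i(x_i) = p(x), \qquad (6)$$
which means the marginal distribution of $\gamma$ on $x$ is $p$. Similarly, the marginal distribution of $\gamma$ on $y$ is $q$. Therefore, $\gamma \in \Pi(p, q)$. Then
$$W_2^2(p, q) \le \int \|x - y\|^2\,d\gamma(x, y) = \sum_{i=1}^{d} \int |x_i - y_i|^2\,d\gamma_i(x_i, y_i). \qquad (7)$$
Taking the infimum over each $\gamma_i \in \Pi(p_i, q_i)$ gives $W_2^2(p, q) \le \sum_{i=1}^{d} W_2^2(p_i, q_i)$.
(2) Then we show $W_2^2(p, q) \ge \sum_{i=1}^{d} W_2^2(p_i, q_i)$.
Note that for any $\gamma \in \Pi(p, q)$,
$$\int \|x - y\|^2\,d\gamma(x, y) = \sum_{i=1}^{d} \int |x_i - y_i|^2\,d\gamma^{(i)}(x_i, y_i), \qquad (11)$$
where $\gamma^{(i)}$ is the marginal distribution of $\gamma$ over $(x_i, y_i)$.
By Fubini's Theorem, writing $x_{-i}$ for all coordinates of $x$ except the $i$-th,
$$\int \gamma^{(i)}(x_i, y_i)\,dy_i = \int \gamma(x, y)\,dx_{-i}\,dy = \int p(x)\,dx_{-i} = p_i(x_i). \qquad (12)$$
Similarly, $\int \gamma^{(i)}(x_i, y_i)\,dx_i = q_i(y_i)$. Therefore, $\gamma^{(i)} \in \Pi(p_i, q_i)$. Then
$$\int |x_i - y_i|^2\,d\gamma^{(i)}(x_i, y_i) \ge W_2^2(p_i, q_i), \qquad (13)$$
$$\int \|x - y\|^2\,d\gamma(x, y) \ge \sum_{i=1}^{d} W_2^2(p_i, q_i). \qquad (14)$$
Taking the infimum over $\gamma \in \Pi(p, q)$ on both sides,
$$W_2^2(p, q) \ge \sum_{i=1}^{d} W_2^2(p_i, q_i).$$
Therefore, $W_2^2(p, q) = \sum_{i=1}^{d} W_2^2(p_i, q_i)$. ∎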
Proof of Remark 2.2.
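For $p = \mathrm{Bern}(a)$ and $q = \mathrm{Bern}(b)$,
$$W_2^2(p, q) = \min_{\gamma_{ij} \ge 0} \sum_{i, j \in \{0, 1\}} \gamma_{ij}(i - j)^2 = \min_{\gamma_{ij} \ge 0}\,(\gamma_{01} + \gamma_{10}) \quad \text{s.t.} \quad \gamma_{i0} + \gamma_{i1} = p(i),\ \ \gamma_{0j} + \gamma_{1j} = q(j), \qquad (15)$$
where $p(0) = 1 - a$, $p(1) = a$, and $q(0) = 1 - b$, $q(1) = b$.
Problem (15) is a linear program. Subtracting the constraint $\gamma_{00} + \gamma_{10} = 1 - b$ from $\gamma_{00} + \gamma_{01} = 1 - a$ gives $\gamma_{01} - \gamma_{10} = b - a$, so $\gamma_{01} + \gamma_{10} \ge |a - b|$, with equality when the smaller of $\gamma_{01}, \gamma_{10}$ is zero. Hence the minimum value of (15) is $|a - b| = \big|\mathbb{E}_p[z] - \mathbb{E}_q[z]\big|$. ∎

Lemma C.3.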
Let be an arbitrary distribution and be a Bernoulli distribution. Then
(16) 
where .
Appendix D Gradient For MMD Projection
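We take the radial basis function (RBF) kernel $k(x, y) = \exp\big(-\|x - y\|^2 / (2h^2)\big)$ with bandwidth $h$ for instance. For the Bernoulli distribution $p_\theta = \mathrm{Bern}(\theta)$, the expectation over $p_\theta$ is a sum over $z \in \{0, 1\}$, so with updated samples $y_j = z_j - \varepsilon \nabla f(z_j)$, $j = 1, \dots, N$,
$$\nabla_\theta \widehat{\mathrm{MMD}}^2(p_\theta, q'_{t+1}) = -2\big(1 - k(0, 1)\big)(1 - 2\theta) - \frac{2}{N}\sum_{j=1}^{N}\big[k(1, y_j) - k(0, y_j)\big].$$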
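The Bernoulli case can be implemented in closed form because the expectation over $p_\theta$ is a two-term sum over $z \in \{0, 1\}$. This is a minimal sketch of our own derivation, not code from the supplement, assuming the RBF kernel $k(x, y) = \exp(-(x - y)^2 / (2h^2))$ with a hypothetical bandwidth `h`, and with `ys` standing for the updated samples $z_j - \varepsilon \nabla f(z_j)$. Because the empirical objective is a convex quadratic in $\theta$, the projection also has a closed-form minimizer.

```python
import math
from typing import Sequence


def rbf(x: float, y: float, h: float = 1.0) -> float:
    return math.exp(-((x - y) ** 2) / (2.0 * h * h))


def mmd2_bernoulli(theta: float, ys: Sequence[float], h: float = 1.0) -> float:
    """MMD^2(Bern(theta), empirical(ys)): exact over p_theta, empirical over ys."""
    c = rbf(0.0, 1.0, h)
    pp = theta ** 2 + 2.0 * theta * (1.0 - theta) * c + (1.0 - theta) ** 2
    pq = sum(theta * rbf(1.0, y, h) + (1.0 - theta) * rbf(0.0, y, h) for y in ys) / len(ys)
    qq = sum(rbf(a, b, h) for a in ys for b in ys) / len(ys) ** 2  # constant in theta
    return pp - 2.0 * pq + qq


def mmd2_bernoulli_grad(theta: float, ys: Sequence[float], h: float = 1.0) -> float:
    """d/dtheta MMD^2 = -2(1 - c)(1 - 2 theta) - (2/N) sum_j [k(1, y_j) - k(0, y_j)]."""
    c = rbf(0.0, 1.0, h)
    diff = sum(rbf(1.0, y, h) - rbf(0.0, y, h) for y in ys) / len(ys)
    return -2.0 * (1.0 - c) * (1.0 - 2.0 * theta) - 2.0 * diff


def mmd_project_bernoulli(ys: Sequence[float], h: float = 1.0) -> float:
    """Minimizer of the convex quadratic MMD^2 in theta, clipped to [0, 1]."""
    c = rbf(0.0, 1.0, h)
    diff = sum(rbf(1.0, y, h) - rbf(0.0, y, h) for y in ys) / len(ys)
    theta = 0.5 + diff / (2.0 * (1.0 - c))
    return min(1.0, max(0.0, theta))
```

When the updated samples are themselves in $\{0, 1\}$ with a fraction $m$ of ones, the closed-form projection returns exactly $\theta = m$, matching the mean-matching solution of Section 2.1.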
Appendix E Binary Latent Models
As most previously proposed algorithms are specifically designed for discrete variables with finite support, we use a binary latent model as the benchmark: a variational autoencoder (VAE) [16] with Bernoulli latent variables (Bernoulli VAE). We compare pWGF with the baseline methods ST and Gumbel-Softmax [15], as well as three state-of-the-art algorithms: Rebar [23], Relax [11] and ARM [25]. Following the settings in [25], we build the model with different network architectures. We apply all methods and architectures to the MNIST dataset, and show the results in Table 2. From the results, pWGF is comparable with ST, and both pWGF and ST outperform the other competing methods except ARM in all tested network architectures.
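Table 2: Bernoulli VAE results on MNIST for different network architectures (lower is better).
Architecture   pWGF    ST      ARM     RELAX   REBAR   Gumbel-Softmax
Linear         119.8   119.1   110.3   122.1   123.2   129.2
Two Layers     108.3   107.6    98.2   114     113.7   NA
Nonlinear      104.6   104.2   101.3   110.9   111.6   112.5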