Generalized Gumbel-Softmax Gradient Estimator for Various Discrete Random Variables

03/04/2020 · by Weonyoung Joo, et al.

Estimating the gradients of stochastic nodes is one of the crucial research questions in the deep generative modeling community. This estimation problem becomes further complicated when we regard the stochastic nodes as discrete, because pathwise derivative techniques cannot be applied. Hence, the gradient estimation requires score function methods or continuous relaxation of the discrete random variables. This paper proposes a general version of the Gumbel-Softmax estimator with continuous relaxation, and this estimator is able to relax the discreteness of a broader range of probability distributions than current practice covers. In detail, we utilize the truncation of discrete random variables and the Gumbel-Softmax trick with a linear transformation for the relaxation. The proposed approach enables the relaxed discrete random variable to be reparameterized and to backpropagate through large-scale stochastic neural networks. Our experiments consist of synthetic data analyses, which show the efficacy of our methods, and topic model analyses, which demonstrate the value of the proposed estimation in practice.


1 Introduction

Stochastic neural networks, including deep generative models, are widely used for representation learning. Optimizing the network parameters through gradient methods requires estimating gradient values. There are two families of estimators: score function based methods and reparameterization trick methods, each with its own characteristics. For example, (1) score function based estimators tend to produce unbiased gradients with high variance; and (2) reparameterization estimators tend to produce biased gradients with low variance (Xu et al., 2019). Hence, to limit the negative aspect, the core technique of score function based estimators is reducing the variance of the gradients to achieve stable and fast optimization. Similarly, utilizing reparameterization estimators requires a differentiable non-centered parameterization (Kingma and Welling, 2014b) of the random variables.

If we focus on reparameterization estimators, one of the most popular examples is the reparameterization in the Gaussian Variational Autoencoder (VAE) (Kingma and Welling, 2014a), which has an exact reparameterization form. Other VAEs with explicit priors suggest reparameterization tricks with approximations. For example, the Stick-Breaking VAE (Nalisnick and Smyth, 2017) assumes a Griffiths-Engen-McCloskey (GEM) prior (Pitman, 2002), and the authors utilized the Kumaraswamy distribution (Kumaraswamy, 1980) to approximate the Beta distribution used by the GEM distribution. The Dirichlet VAE (Joo et al., 2019) assumes a Dirichlet prior, and the authors utilized an approximation of the inverse Gamma cumulative distribution function (Knowles, 2015) and the composition of Gamma random variables to form the Dirichlet distribution.

For continuous random variables, it is feasible to estimate gradients with recent methods, such as the optimal mass transport gradient (Jankowiak and Obermeyer, 2018), which utilizes a transport equation; or the implicit reparameterization gradients (Figurnov et al., 2018), which are based on automatic differentiation. However, these methods are not applicable to discrete random variables due to their non-differentiable characteristics.

To overcome this difficulty, some discrete random variables, such as Bernoulli or categorical random variables, have been well explored recently. (In this paper, a discrete random variable means a random variable with a discrete outcome, which may follow diverse distributions other than the Bernoulli or categorical case.) The authors of Jang et al. (2017) and Maddison et al. (2017) developed continuous relaxations of the Bernoulli and the categorical random variables through the Gumbel-Softmax and the Concrete distributions, respectively. Meanwhile, other discrete distributions, such as the Poisson, binomial, geometric, and negative binomial distributions, are not explored enough from the learning perspective.

This paper proposes a reparameterization trick for generic discrete random variables through continuous relaxation, which is a generalized version of the Gumbel-Softmax (GenGS) estimator. The key idea of relaxing the discrete random variables is (1) utilizing truncated discrete random variables as an approximation to the original discrete random variables; and (2) transforming a Gumbel-Softmax trick with a special form of a linear transformation. We present three theorems to theoretically substantiate that GenGS is applicable to a set of discrete random variables broader than the Bernoulli and categorical cases. Since we present a reparameterization trick for discrete random variables, we present two cases of practical usage through experiments. First, we show that the proposed GenGS is well applicable to variants of VAEs by diversifying the priors. Second, we illustrate the potential gains in topic modeling from a neural topic model with GenGS.

2 Preliminary

This section introduces the following: (1) a review of reparameterization tricks; and (2) the Gumbel-Softmax reparameterization of categorical random variables.

2.1 Backpropagation through Stochastic Nodes

Suppose that we have a stochastic node, or a latent variable, $z \sim p_\theta(z)$, where the distribution depends on $\theta$, and we want to optimize the loss function $\mathcal{L}(\theta) = \mathbb{E}_{z \sim p_\theta(z)}[f(z)]$, where $f$ is a continuous and differentiable function with respect to $z$, e.g. a neural network. To optimize the loss function in terms of $\theta$ through gradient methods, we need to find $\nabla_\theta \mathbb{E}_{z \sim p_\theta(z)}[f(z)]$, which cannot be directly computed in its original form.

2.1.1 Reparameterization Trick

To compute $\nabla_\theta \mathbb{E}_{z \sim p_\theta(z)}[f(z)]$, the reparameterization trick introduces an auxiliary variable $\epsilon \sim p(\epsilon)$, which takes over all randomness of the latent variable $z$, so the sampled value can be re-written as $z = g(\theta, \epsilon)$, where $g$ is a deterministic and differentiable function in terms of $\theta$. Here, the gradient of the loss function with respect to $\theta$ is derived as

$\nabla_\theta \mathbb{E}_{z \sim p_\theta(z)}[f(z)] = \nabla_\theta \mathbb{E}_{\epsilon \sim p(\epsilon)}[f(g(\theta, \epsilon))] = \mathbb{E}_{\epsilon \sim p(\epsilon)}[\nabla_\theta f(g(\theta, \epsilon))]$ (1)

where the term inside the expectation in Equation 1 is now computable. A condition enabling the reparameterization trick is the continuity of the random variable $z$, so the distribution of $z$ is limited to a class of continuous distributions. To utilize the reparameterization trick on discrete random variables, continuous relaxation can be applied: for example, a relaxation from the categorical distribution to the Gumbel-Softmax distribution.
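As a concrete instance of Equation 1, the Gaussian case admits the exact reparameterization $z = \mu + \sigma\epsilon$. A minimal NumPy sketch (the parameter values below are illustrative assumptions, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)

# Gaussian reparameterization: instead of sampling z ~ N(mu, sigma^2)
# directly, sample parameter-free noise eps ~ N(0, 1) and set
# z = g(theta, eps) = mu + sigma * eps, a deterministic and
# differentiable function of theta = (mu, sigma).
def reparameterize(mu, sigma, n_samples):
    eps = rng.standard_normal(n_samples)  # auxiliary variable eps
    return mu + sigma * eps               # z = g(theta, eps)

z = reparameterize(mu=2.0, sigma=0.5, n_samples=100_000)
# z behaves like N(2.0, 0.5^2), but its dependence on (mu, sigma)
# is now explicit, so gradients can flow through g.
```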

2.2 Reparameterization Trick on Categorical Random Variable

The Gumbel-Max trick (Gumbel, 1948) is a procedure for sampling a categorical one-hot value from the Gumbel distribution, instead of sampling directly from a categorical distribution. This implies that a categorical random variable $X \sim \mathrm{Cat}(\pi)$, where $\pi$ lies on the $(n-1)$-dimensional simplex $\Delta^{n-1}$, can be reparameterized by the Gumbel-Max trick: (1) sample $u_i \sim \mathrm{Uniform}(0, 1)$ i.i.d. to generate a Gumbel sample $g_i = -\log(-\log u_i)$ for each $i$; and (2) compute $k = \arg\max_i (\log \pi_i + g_i)$, where $\pi$ is the categorical parameter. This procedure generates a one-hot sample $x$ such that $x_k = 1$ and $x_i = 0$ for $i \neq k$. We denote $GM(\pi)$ to be the distribution whose samples are generated by the Gumbel-Max trick.

The Gumbel-Softmax trick (Jang et al., 2017; Maddison et al., 2017) is a variant of the Gumbel-Max trick that relaxes a categorical random variable into a continuous one. The key of the Gumbel-Softmax is using the softmax activation with a temperature $\tau$, instead of the argmax, in the sampling process, which enables (1) relaxing the discreteness of the categorical random variable in the one-hot form to a continuous value $x = \mathrm{softmax}((\log \pi + g) / \tau)$; and (2) approximating the Gumbel-Max by taking $\tau$ small enough. Recently, the Gumbel-Softmax estimator has been widely used to reparameterize categorical random variables, for example, RelaxedOneHotCategorical in TensorFlow (Abadi et al., 2016). We denote the distribution generated by the Gumbel-Softmax trick as $GS(\pi, \tau)$.
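The two tricks above can be sketched in a few lines of NumPy; the categorical parameter below is an illustrative assumption:

```python
import numpy as np

rng = np.random.default_rng(0)
pi = np.array([0.2, 0.5, 0.3])  # categorical parameter (illustrative)

def gumbel_noise(shape):
    # u ~ Uniform(0,1)  ->  g = -log(-log u) is a Gumbel(0,1) sample
    return -np.log(-np.log(rng.uniform(size=shape)))

def gumbel_max(pi):
    # Gumbel-Max: argmax(log pi + g) is an exact categorical sample
    onehot = np.zeros_like(pi)
    onehot[np.argmax(np.log(pi) + gumbel_noise(pi.shape))] = 1.0
    return onehot

def gumbel_softmax(pi, tau):
    # Gumbel-Softmax: softmax((log pi + g) / tau) relaxes the argmax
    y = (np.log(pi) + gumbel_noise(pi.shape)) / tau
    e = np.exp(y - y.max())  # numerically stable softmax
    return e / e.sum()

# Empirical Gumbel-Max frequencies recover pi; the relaxed sample
# stays on the simplex and approaches one-hot as tau -> 0.
freq = np.mean([gumbel_max(pi) for _ in range(20_000)], axis=0)
soft = gumbel_softmax(pi, tau=0.1)
```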

3 Approximation on Discrete Random Variables with Truncations

This section provides the theoretical background of our reparameterization methodology. Our theoretical framework consists of two steps: the first step relaxes a class of discrete random variables, which we define as truncatable discrete random variables, into continuous random variables. The second step applies a reparameterization trick that is more general than the practice reviewed in Section 2.2.

To support our reparameterization methodology, this section provides three theorems on reparameterization. The first theorem approximates an original discrete distribution with its truncated version. Next, the second theorem enables the truncated distribution to be reparameterized by the Gumbel-Max trick. Finally, the third theorem shows that the Gumbel-Softmax function converges to the Gumbel-Max function under the suggested linear transformation. Figure 1 illustrates the transformation relations for the reparameterization of truncatable discrete random variables.

We would like to note that our proposed reparameterization trick bounds the applicable discrete random variables by the truncation, and we generalize the reparameterization with the Gumbel-Softmax function through the introduction of the linear transformation. The combination of these two contributions yields a reparameterization trick that is both broader and theoretically grounded.

Figure 1: On the $x$-axis, as the truncation level $c \to \infty$, the distribution $TD \to D$ by Theorem 3. $TD$ can be reparameterized by the Gumbel-Max trick and a linear transformation $T$ as in Theorem 4. On the $y$-axis, as the temperature $\tau \to 0$, $T(GS(\pi, \tau)) \to T(GM(\pi))$, where $\pi$ is the computed PMF value vector of $TD$, by Theorem 5.

3.1 Truncatable Discrete Random Variables

We first define the class of discrete random variables that delimits the feasible distributions of our reparameterization trick. Definition 1 specifies a truncated discrete random variable.

Definition 1.

A truncated discrete random variable $TD$ of a non-negative discrete random variable $D$ is a discrete random variable with the point probability mass $P(TD = k) = P(D = k)$ for $k = 0, \ldots, c - 1$ and $P(TD = c) = 1 - \sum_{k=0}^{c-1} P(D = k)$. The random variable $TD$ is said to follow a truncated discrete distribution with a parameter $\theta$ and a truncation level $c$.

Note that Definition 1 can be easily extended to truncate the left-hand side or both sides of a distribution. However, we focus on non-negative distributions since most popularly used discrete random variables have the support $\{0, 1, 2, \ldots\}$.

Recall a popular inequality in probability theory, namely Markov's inequality, stated in Proposition 2.

Proposition 2.

(Markov's Inequality)  If $D$ is a non-negative random variable, then for any $a > 0$, $P(D \ge a) \le \mathbb{E}[D] / a$.

There are discrete distributions with a finite mean and a finite variance: for example, the binomial, Poisson, geometric, and negative binomial distributions. Since we focus on discrete distributions with a finite mean and a finite variance, the inequalities guarantee that only a small amount of mass lies at the tail of the distributions. In other words, if we take the truncation level far enough from zero, we can cover most of the possible outcomes that can be sampled from the distribution. This idea leads to Theorem 3.

Theorem 3.

For a non-negative discrete random variable $D$ with parameter $\theta$, which has a finite mean and a finite variance, define its truncated distribution $TD$ with a truncation level $c$ as in Definition 1. Then, $TD$ converges to $D$ in probability as $c \to \infty$. We say that the distribution $D$ is truncatable if the theorem holds for the truncated distribution $TD$.

Theorem 3 provides the theoretical basis for truncating a discrete random variable $D$ into a truncated one $TD$, and Appendix A.1 shows the detailed proof. (A similar statement can be proven for truncating both sides of a distribution using Chebyshev's inequality, and Appendix A.2 shows the detailed definition and proof for the two-tail case.)
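The tail-mass argument behind Theorem 3 can be checked numerically: Markov's inequality gives a crude bound, while the exact tail beyond the truncation level is far smaller. A small stdlib sketch with an illustrative Poisson rate of 3 and truncation level 30 (both values are assumptions for this example):

```python
import math

# Markov's inequality: P(D >= a) <= E[D] / a for non-negative D.
rate, level = 3.0, 30

# Poisson PMF values; terms beyond k = 100 are negligible for rate 3
pmf = [math.exp(-rate) * rate**k / math.factorial(k) for k in range(100)]
tail = sum(pmf[level:])        # exact P(D >= 30), up to negligible k >= 100
markov_bound = rate / level    # E[D] / a from Proposition 2, here 0.1
```

The bound (0.1) is loose, but the exact tail mass is astronomically small, which is what licenses ignoring outcomes beyond the truncation level.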

One example of a truncatable discrete random variable is a discrete random variable following the Poisson distribution. Since the Poisson distribution has a finite mean and a finite variance, $\mathbb{E}[D] = \mathrm{Var}[D] = \lambda$, given a rate parameter $\lambda$, the Poisson distribution is truncatable by Theorem 3. Note that the Poisson distribution draws samples around the rate parameter $\lambda$, and the probability mass function (PMF) value of an outcome $k$ goes to zero as $k$ grows. Moreover, the summation of the PMF values beyond the truncation level converges to zero as $c \to \infty$. This property is crucial in our reparameterization method because it allows finitizing the number of possible outcomes by ignoring samples of extremely small probability. (When the rate parameter $\lambda$ is large, the distribution hardly samples small values. In this case, truncating both sides around the center $\lambda$, which is the mean and approximate mode, might be a better choice.) In some distributions, such as the binomial distribution, the summation of the PMF values beyond the truncation level becomes exactly zero for a sufficiently large level. Hence, in this case, we can simply take the truncation level as the number of trials, which is a parameter of the binomial distribution, and utilize the full spectrum of possible outcomes of the distribution.

The convergence-in-probability property of Theorem 3 ensures that our truncation method is probabilistically stable. By injecting the near-zero remaining probability into the last category right before the truncation level, the sum-to-one property remains satisfied. Through the truncation method, the discrete distribution is ready to be approximated by the Gumbel-Softmax trick.
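The truncation step can be sketched directly from this construction; a minimal Python helper, assuming a Poisson base distribution with illustrative parameters:

```python
import math

def truncated_poisson_pmf(rate, c):
    """PMF of a Poisson(rate) variable truncated at level c: keep
    P(D = k) for k < c and inject the near-zero remaining tail mass
    into the last category, so the values still sum to one."""
    pmf = [math.exp(-rate) * rate**k / math.factorial(k) for k in range(c + 1)]
    pmf[-1] += 1.0 - sum(pmf)  # fold P(D > c) into the last category
    return pmf

pi = truncated_poisson_pmf(rate=4.0, c=19)  # support {0, ..., 19}
```

The resulting vector `pi` is exactly the categorical parameter that feeds the Gumbel tricks in the next subsection.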

3.2 Reparameterization by Generalized Gumbel-Softmax Trick

Since widely utilized discrete distributions have explicit forms of their PMFs, we can directly compute the PMF values on the truncated support with a pre-defined truncation level $c$. Let $\pi = (\pi_0, \ldots, \pi_c)$ be the computed PMF of a truncated distribution $TD$, where $\pi_k = P(TD = k)$, of a truncatable distribution $D$. (Here, we assume that we have an explicit form of the PMF of the distribution $D$.) We now suggest a transformation $T$, such that $T(x) = v^\top x$, where $v$ is a constant outcome vector. (By diversifying the construction of $v$, it is possible to reparameterize any kind of truncatable discrete random variable with a finite mean and a finite variance.) Also, we denote the distributions generated by applying $T$ on GM and GS as $T(GM(\pi))$ and $T(GS(\pi, \tau))$, respectively. Afterwards, we can reparameterize $TD$ with $T(GM(\pi))$ as stated in Theorem 4, proved in Appendix A.3.

Theorem 4.

For any truncated discrete random variable $TD$ of a truncatable distribution $D$ with a transformation $T(x) = v^\top x$, $TD$ can be reparameterized by $T(GM(\pi))$ if we set $\pi_k = P(TD = k)$.

Theorem 4 indicates that we can generate a sampled value of $TD$ by a linear transformation of a Gumbel-Max sample. Now, the randomness of $TD$ with respect to the parameter $\theta$ moves into the Gumbel sample of the Gumbel-Max trick, since the linear transformation is a continuous and deterministic function.

Then, we can apply the Gumbel-Softmax trick to the Gumbel-Max in $T(GM(\pi))$, as stated in Theorem 5 and proved in Appendix A.4. The theorem implies that we can relax the truncated discrete random variable by the Gumbel-Softmax and the linear transformation, $T(GS(\pi, \tau))$.

Theorem 5.

For a transformation $T(x) = v^\top x$ and a given categorical parameter $\pi$, the convergence of the Gumbel-Softmax to the Gumbel-Max still holds under the linear transformation $T$, i.e., $GS(\pi, \tau) \to GM(\pi)$ as $\tau \to 0$ implies $T(GS(\pi, \tau)) \to T(GM(\pi))$ as $\tau \to 0$.

The assumption of Theorem 5, that $GS(\pi, \tau) \to GM(\pi)$ as $\tau \to 0$, has not been proven mathematically in the literature where it was originally suggested (Jang et al., 2017; Maddison et al., 2017). Instead, the authors empirically show that $GS(\pi, \tau)$ eventually becomes $GM(\pi)$ as $\tau$ goes near zero.
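Putting Theorems 3 to 5 together, a relaxed sample is $v^\top \mathrm{softmax}((\log \pi + g)/\tau)$. A NumPy sketch with an illustrative truncated Poisson target (the rate, truncation level, and temperature are assumptions for this example):

```python
import math
import numpy as np

rng = np.random.default_rng(0)

rate, c, tau = 4.0, 19, 0.1  # illustrative rate, truncation level, temperature

# pi: PMF of the truncated Poisson (leftover tail folded into the last bin)
pi = np.array([math.exp(-rate) * rate**k / math.factorial(k) for k in range(c + 1)])
pi[-1] += 1.0 - pi.sum()
v = np.arange(c + 1, dtype=float)  # outcome vector for T(x) = v^T x

def gengs_sample(pi, v, tau):
    g = -np.log(-np.log(rng.uniform(size=pi.shape)))  # Gumbel(0,1) noise
    y = (np.log(pi) + g) / tau
    soft = np.exp(y - y.max())
    soft /= soft.sum()   # relaxed one-hot sample from GS(pi, tau)
    return v @ soft      # linear transformation: continuous relaxation of TD

samples = np.array([gengs_sample(pi, v, tau) for _ in range(20_000)])
# at low tau, the sample mean should sit near the Poisson mean (rate = 4)
```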

Figure 2: An illustration of various choices of the truncation level $c$ and the temperature $\tau$ for a Poisson distribution. As the sub-figures go from left to right, the truncation level grows, and hence the stick, denoting the remaining probability on the right side, disappears once the truncation level is large enough. As the sub-figures go from top to bottom, the temperature decreases, and the PMF of the truncated distribution becomes similar to the original distribution.

Figure 2 illustrates the approximated Poisson distribution by the Gumbel-Softmax estimator with the truncation level $c$ and the Gumbel-Softmax temperature $\tau$. We can observe that the approximation becomes closer to the Poisson distribution as we increase $c$. However, increasing $c$ is technically limited due to the finite neural network output for the inference. Additionally, decreasing $\tau$ brings the approximation closer to the Poisson distribution. However, an initially small $\tau$ leads to high variance of the gradients of the relaxed one-hot vector, which becomes problematic at the learning stage. Therefore, annealing $\tau$ from a large value to a small one is necessary to provide a learning chance. Having said that, the annealing process takes learning time, so the decrement of $\tau$ is limited by a time budget.

4 Algorithm and Characteristics of GenGS

So far, we discussed the theoretical approach of our reparameterization method for truncatable discrete random variables. In summary, the concept of our work is the following: (1) approximate a discrete distribution by truncating the distribution; (2) reparameterize the truncated distribution with the Gumbel-Max trick and the linear transformation $T$; and (3) relax the Gumbel-Max trick with the Gumbel-Softmax trick. For simplicity, we name our estimator the Generalized Gumbel-Softmax (GenGS) estimator. As other reparameterization tricks do, GenGS introduces an auxiliary random variable, which in our case is a uniform random variable composing the Gumbel sample. This Gumbel sample enables the reparameterized variable to be deterministic with respect to the parameter of the target distribution. This section describes how GenGS can be used in practice, how to sample discretized values, and how it applies to the general form of discrete random variables.

  Input: Loss $f$, PMF $P(TD = \cdot)$, temperature $\tau$, truncation level $c$, Gumbel-Softmax reparameterizer GS, transformation $T$, constant vector $v$.
  repeat
     for $k = 0$ to $c$ do
        Compute $\pi_k = P(TD = k)$.
     end for
     Compute $\pi = (\pi_0, \ldots, \pi_c)$.
     Sample a $(c+1)$-dimensional one-hot-like $x \sim GS(\pi, \tau)$.
     Compute the transformation $z = T(x) = v^\top x$.
     Compute the loss $f(z)$.
     Update $\theta$ via a stochastic gradient method.
  until convergence
Algorithm 1 GenGS: inference step on the parameter $\theta$ of distribution $TD$

4.1 Algorithm

We propose two ways of applying GenGS: (1) explicitly inferring the distribution through its parameter $\theta$; and (2) implicitly inferring the distribution through its PMF values $\pi$.

Explicit Inference.

This is the usual case of inferring the behavior of latent variables with a distribution parameter $\theta$ by assuming the same distribution form. For example, in VAEs, we choose a prior distribution, and we set an approximate posterior distribution of the same type as the prior distribution. This probabilistic model encodes the input to the approximate posterior parameter with the encoder network of the VAE. In this general framework of explicit mass modeling, Algorithm 1 shows how to apply the proposed reparameterization trick with a truncatable discrete distribution. Specifically, $\pi$ becomes the explicit point-mass model, and $T$ is adjusted by the choice of $v$; see Section 4.3 for an example of $v$. Finally, given $\tau$ and $c$, Algorithm 1 optimizes the inferred parameter $\theta$. Additionally, we provide Figure 3 to illustrate the process of the reparameterization by GenGS.

Implicit Inference.

Unlike continuous random variables, discrete random variables have discretized outcomes. If the number of possible outcomes is finite, by utilizing the truncation for example, we can directly infer the PMF values of the possible outcomes as a categorical parameter $\pi$, which becomes the input of the Gumbel tricks. In other words, we are now interested in inferring the shape of the distribution, not the distribution parameter $\theta$, by loosening the assumption on the approximate posterior. However, this inference approach needs a regularizer, such as the KL divergence term in VAEs, which encourages the distribution shape to resemble a prior distribution with a preset parameter. We found that loosening the approximate posterior assumption leads to a significant performance gain in our VAE experiments.

Figure 3: Visualization of the GenGS reparameterization. Shaded nodes denote auxiliary random nodes, which enable the reparameterized variable to be deterministic with respect to the distribution parameter through the transformation $T$.

4.2 Discretization

With GenGS, the reparameterized sample has a continuous value since we relax the discrete random variable into a continuous form. Utilizing the Straight-Through (ST) Gumbel-Softmax estimator (Bengio et al., 2013; Jang et al., 2017), instead of the naive Gumbel-Softmax, we can obtain a discrete sample as well. Since the ST Gumbel-Softmax discretizes the relaxed Gumbel-Softmax output with argmax in the forward pass while using the gradients obtained from the relaxed output in the backward pass, it could result in significant performance degradation.
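A forward-pass sketch of the ST trick in NumPy (no autograd here, so only the values are shown; in an autograd framework the `(onehot - y_soft)` term would be wrapped in a stop-gradient so the backward pass follows `y_soft`):

```python
import numpy as np

rng = np.random.default_rng(0)

def st_gumbel_softmax(pi, tau):
    # relaxed sample from GS(pi, tau)
    g = -np.log(-np.log(rng.uniform(size=pi.shape)))
    y = (np.log(pi) + g) / tau
    y_soft = np.exp(y - y.max())
    y_soft /= y_soft.sum()
    # discretize with argmax for the forward pass
    onehot = np.zeros_like(y_soft)
    onehot[y_soft.argmax()] = 1.0
    # ST estimator: the forward value is the one-hot sample; the backward
    # pass (in an autograd framework) would use the gradient of y_soft
    return y_soft + (onehot - y_soft)

s = st_gumbel_softmax(np.array([0.2, 0.5, 0.3]), tau=0.5)  # one-hot forward value
```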

4.3 Extension

Adjusting Truncation Level.

As we discussed in Section 3.1, the truncation does not necessarily start from zero. For example, consider a Poisson distribution with a large rate parameter $\lambda$. Though the distribution has support starting from zero, the PMF values can be disregarded probabilistically up to a certain point. Therefore, we can truncate both the left and the right side, instead of starting from zero. For example, we can choose left and right truncation levels whose PMF values are near zero, around 1e-8 and 6e-7, respectively.

Manipulating Transformation Constant.

The transformation $T$ includes a constant vector $v$. When we interpret $T(x) = v^\top x$, each entry of $v$ is a possible outcome of the truncated distribution $TD$. If the truncated discrete distribution has a possible outcome set with a corresponding PMF value set, modifying $v$ to list those outcomes would still give the same result. Example 6 shows that the Gumbel-Softmax estimator is a special case of GenGS.

Example 6.

The Gumbel-Softmax estimator of categorical random variables is a trivial case of GenGS. Assume the number of dimensions is three and a categorical parameter $\pi$ in this example. Then, the possible outcomes of the categorical random variable are the one-hot vectors $e_1$, $e_2$, and $e_3$. Afterwards, we draw a sample from $GS(\pi, \tau)$ for some temperature $\tau$, and the value becomes a relaxed one-hot form. If we construct $T$ with the identity matrix as the linear transformation, then $T(x) = x$. Hence, the Gumbel-Softmax trick can be written in the form of GenGS with the identity matrix in the linear transformation.
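Example 6 can be checked numerically: with the identity matrix as the transformation, the GenGS output coincides with the plain Gumbel-Softmax sample (the categorical parameter and temperature below are assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)

pi = np.array([0.3, 0.4, 0.3])  # illustrative categorical parameter
tau = 0.5

# one Gumbel-Softmax sample
g = -np.log(-np.log(rng.uniform(size=pi.shape)))
y = (np.log(pi) + g) / tau
soft = np.exp(y - y.max())
soft /= soft.sum()

# GenGS with T(x) = I x: the transformation leaves the sample unchanged,
# so the plain Gumbel-Softmax estimator is recovered as a special case
T = np.eye(3)
gengs = T @ soft
```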

Figure 4: Curves of losses (top) and variances of gradients (bottom) with confidence intervals for the synthetic example in log scale: (Left) , (Center) , (Right) .

5 Related Work

Note that GenGS is basically a single-sample gradient estimator like other reparameterization gradient estimators. Though GenGS could use multiple samples to compute stable gradients, we compare GenGS with the other estimators using a single sample to test the fundamental performance of the gradient estimators. NVIL (Mnih and Gregor, 2014) and MuProp (Gu et al., 2016) are single-sample gradient estimators. NVIL utilizes a neural network to introduce the optimal control variate, and MuProp utilizes the first-order Taylor expansion of the loss term as a control variate. The deterministic Rao-Blackwellized estimator (DetRB) (Liu et al., 2019) uses the weighted value of the exact gradients from selected categories and the estimated gradients from the remaining categories with respect to their odds to reduce the variance. The idea of the stochastic Rao-Blackwellized estimator (StoRB) (Kool et al., 2020) is fundamentally the same as DetRB, but the difference between the two gradient estimators is that DetRB uses fixed categories while StoRB randomly chooses the categories at each step. The authors of Kool et al. (2020) also suggest an unordered set gradient estimator (UnOrd), which also uses multiple sampled gradients, utilizing sampling without replacement. For DetRB, StoRB, and UnOrd, we use only one category that utilizes the exact gradient for a fair comparison. Also, we annotate model names for variants that use a built-in control variate introduced in Kool et al. (2020).

6 Experiment

This section presents the experiments on GenGS, compared to other gradient estimators. First, we compare the loss and its variance on synthetic examples. Second, we move to practical applications of our reparameterization trick: VAEs with various prior settings, and topic models in Section 6.3. Appendix B explains the details of the experiments.

6.1 Synthetic Example

Experimental Setting.

In this experiment, we first sample independently and identically from a discrete distribution with a chosen parameter, and we optimize the loss function with respect to the distribution parameter. We test three distribution settings in this experiment. For fair comparisons, we use exact gradients for the RBs and UnOrd. While it is possible to use more than one exact gradient in the synthetic example, if there is more than one latent dimension, it requires computing gradient combinations, which has high complexity. We also adapt the Rao-Blackwellization idea in GenGS, which utilizes the true gradient for selected categories and GenGS for the remainder, namely GenGS-RB.

Experimental Result.

We compare the log-loss and the log-variance of the estimated gradients from various estimators. Figure 4 shows these two properties of the estimated gradients. Here, the log-loss needs to be minimized to correctly estimate the back-propagated gradient value in the learning process. Also, the log-variance should be minimized to maintain the consistency of the gradients, so that gradient descent can be efficient. GenGS shows the best log-loss and the best log-variance when it keeps the continuous relaxation of the modeled discrete random variable. The results also show that adapting the Rao-Blackwellization idea, GenGS-RB, further reduces the log-variance of the gradients compared to plain GenGS. UnOrd shows the lowest gradient variance among the baselines in the binomial and negative binomial cases, but it fails to converge to the optimal parameter because of its approximation accuracy problem.

MNIST RF RF NVIL MuProp StoRB UnOrd GenGS ST (Ex.) GenGS (Ex.) GenGS ST (Im.) GenGS (Im.)
Poisson(2)
Poisson(3)
Geometric(0.25) ——
Geometric(0.5) ——
NegativeBinomial(3,0.5) —— ——
NegativeBinomial(5,0.3) —— ——
OMNIGLOT RF RF NVIL MuProp StoRB UnOrd GenGS ST (Ex.) GenGS (Ex.) GenGS ST (Im.) GenGS (Im.)
Poisson(2)
Poisson(3)
Geometric(0.25) ——
Geometric(0.5) ——
NegativeBinomial(3,0.5) —— ——
NegativeBinomial(5,0.3) —— ——
Table 1: Test negative ELBO on the MNIST and OMNIGLOT datasets. Lower is better for the negative ELBO. The symbol "——" implies that the estimator does not reach the optimal point.

6.2 Variational Autoencoders

Experimental Setting.

We choose the VAE as the application to test the performance of the gradient estimators. While previous categorical VAEs flatten the sampled one-hot categorical outputs of a latent variable into a single vector, we assume every single dimension of the latent variable follows the prior distribution. We utilize the (truncated) Poisson, geometric, and negative binomial distributions in this experiment. The objective function in Equation 2, which consists of a reconstruction part and a KL divergence part, is minimized during training.

$\mathcal{L}(\phi, \theta; x) = -\mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] + \mathrm{KL}(q_\phi(z \mid x) \,\|\, p(z))$ (2)

To optimize the evidence lower bound (ELBO) of the VAE, we must compute the KL divergence between the approximate posterior and the prior distribution. In GenGS, by truncating the original distribution, the KL divergence reduces to a KL divergence between categorical distributions, as the theorem below states, proved in Appendix A.5.

Theorem 7.

Assume two truncated distributions $TD_1$ and $TD_2$ with the same truncation level $c$, and let $\pi_{i,k} = P(TD_i = k)$ for $k = 0, \ldots, c$. Then, the KL divergence between $TD_1$ and $TD_2$ can be represented as the KL divergence between the categorical distributions: $\mathrm{KL}(TD_1 \| TD_2) = \mathrm{KL}(\mathrm{Cat}(\pi_1) \| \mathrm{Cat}(\pi_2))$.
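Theorem 7 can be sanity-checked numerically: the KL divergence between two truncated Poisson distributions is just the categorical KL between their PMF vectors (the rates 3 and 4 and the truncation level 19 below are illustrative assumptions):

```python
import math

def truncated_poisson_pmf(rate, c):
    # truncated PMF with the leftover tail folded into the last category
    pmf = [math.exp(-rate) * rate**k / math.factorial(k) for k in range(c + 1)]
    pmf[-1] += 1.0 - sum(pmf)
    return pmf

def categorical_kl(p, q):
    # KL(Cat(p) || Cat(q)) = sum_k p_k log(p_k / q_k)
    return sum(pk * math.log(pk / qk) for pk, qk in zip(p, q) if pk > 0.0)

p = truncated_poisson_pmf(3.0, 19)  # e.g. an approximate posterior
q = truncated_poisson_pmf(4.0, 19)  # e.g. a prior
kl = categorical_kl(p, q)
# close to the untruncated Poisson KL, rate1*log(rate1/rate2) + rate2 - rate1,
# since the truncated tail mass is negligible at this level
```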

We construct two fully-connected hidden layers for the encoder and the decoder, and we use the same latent dimension for both the MNIST and OMNIGLOT datasets. This task is more challenging than the synthetic example, since it requires computing (1) the gradients of the encoder neural network parameters through the latent distribution parameter; and (2) a single combination of gradients over the latent dimensions takes an extremely small portion of the entire set of gradient combinations when each latent dimension has many possible outcomes. Moreover, since we utilize fully-connected layers, each gradient at the latent layer affects all encoder parameters. Hence, a gradient deviating from the true gradient with respect to the latent distribution parameter could lead the encoder parameters away from the optimal point. This task utilizes the implicit inference discussed in Section 4.1, so the KL divergence term becomes a regularizer of the shape of the approximate posterior distribution.

Figure 5: Reconstructed images by VAEs with various gradient estimators. GenGS produces the clearest images and the best reconstructions among the gradient estimators.
Experimental Result.

Table 1 shows the negative ELBO results of the VAE experiments. We found that baselines such as DetRB and StoRB fail to reach the optimal point, so we exclude those estimators from the table. The variants of GenGS show the lowest negative ELBO in general. Figure 5 shows the reconstructed images by VAEs with various gradient estimators on MNIST and OMNIGLOT. GenGS draws the clearest images and the best reconstructions, which aligns with the quantitative results of the gradient estimators.

Figure 6: A graphical notation (left) and a neural network diagram (right) of NVPDEF. (Left) Solid lines denote the generative process, while dashed lines indicate the inference network. The multi-stacked latent layers share a prior distribution parameter. (Right) Diamond nodes indicate the auxiliary random variables for the reparameterization trick.

6.3 Topic Model Application

Experimental Setting.

This experiment shows the usefulness of GenGS in a topic model application. Among the discrete distributions discussed above, the Poisson distribution is one of the most important, as it models counts of outcomes. The authors of Deep Exponential Families (DEFs) (Ranganath et al., 2015) utilize the exponential family, including the Poisson distribution, in stacked latent layers. The architecture of a DEF consists of multiple top-to-bottom latent layers and an output layer. For instance, from the perspective of topic modeling, layers toward the top capture super-topics, and layers toward the bottom capture sub-topics. Utilizing multiple Poisson distributions at the latent layers enables capturing the number of topic appearances in each layer.

Our experimental setting focuses on PoissonDEF, which assumes Poisson latent layers, and we convert PoissonDEF into a neural network form that resembles the neural variational document model (NVDM) (Miao et al., 2016); see Figure 6 for the neural network and the probabilistic modeling structure. We utilize GenGS on PoissonDEF to sample the values of the latent variables, and we name this variant the neural variational Poisson DEF (NVPDEF). We use stacked layers for the 20Newsgroups and RCV1-V2 datasets. For the multi-stacked version of NVPDEF, we utilized multi-sampling on the latent layers for stable optimization over the consecutive latent samplings.

Experimental Result.

We enumerate the baselines and the variants of NVPDEF in Table 2, and we confirmed that NVPDEF shows the lowest perplexity overall on 20Newsgroups and RCV1-V2. Since NVPDEF and the original DEFs have different training and testing regimes, we compare NVPDEF against representative neural variational topic (document) models, listed in Table 2. Additionally, Appendix B.3 shows qualitative results from the topic models.

Model 20Newsgroups (Dim.) RCV1-V2 (Dim.)
LDA (Blei et al., 2003) (50) (200)
NVDM (Miao et al., 2016) (50) (200)
GSM (Miao et al., 2017) (50) (200)
NVLDA (Srivastava and Sutton, 2017) (50) (200)
ProdLDA (Srivastava and Sutton, 2017) (50) (200)
NVPDEF (50) (200)
Multi-Stacked NVPDEF (20-50) (50-200)
Table 2: Test perplexity on the 20Newsgroups and RCV1-V2 datasets.

7 Conclusion

This paper suggests a new gradient estimator for discrete random variables, GenGS, which is a generalized version of the Gumbel-Softmax estimator. To strengthen the practical usage of reparameterization tricks with the Gumbel-Softmax function, we provide a theoretical background and boundary for our reparameterization trick. Our finding is that a truncatable discrete random variable can always be reparameterized via the proposed GenGS algorithm. The limitation of GenGS is the setting of the truncation level and the Gumbel-Softmax temperature, which introduces a trade-off between gradient estimation accuracy and the time budget. Subsequently, we show the synthetic analysis as well as two applications of GenGS, a VAE and a topic model. We expect that GenGS clearly bounds and generalizes the reparameterization trick on discrete random variables.

References

  • M. Abadi, P. Barham, J. Chen, Z. Chen, A. Davis, J. Dean, M. Devin, S. Ghemawat, G. Irving, M. Isard, et al. (2016) Tensorflow: a system for large-scale machine learning. USENIX Symposium on Operating Systems Design and Implementation. Cited by: §2.2.
  • Y. Bengio, N. Leonard, and A. Courville (2013) Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432. Cited by: §4.2.
  • D. M. Blei, A. Y. Ng, and M. I. Jordan (2003) Latent dirichlet allocation. Journal of Machine Learning Research. Cited by: Table 2.
  • M. Figurnov, S. Mohamed, and A. Mnih (2018) Implicit reparameterization gradients. Advances in Neural Information Processing Systems. Cited by: §1.
  • S. Gu, S. Levine, I. Sutskever, and A. Mnih (2016) Muprop: unbiased backpropagation for stochastic neural networks. International Conference on Learning Representations. Cited by: §5.
  • E. J. Gumbel (1948) Statistical theory of extreme values and some practical applications: a series of lectures (vol. 33). US Government Printing Office. Cited by: §2.2.
  • E. Jang, S. Gu, and B. Poole (2017) Categorical reparameterization with gumbel-softmax. International Conference on Learning Representations. Cited by: §1, §2.2, §3.2, §4.2, footnote 7.
  • M. Jankowiak and F. Obermeyer (2018) Pathwise derivatives beyond the reparameterization trick. International Conference on Machine Learning. Cited by: §1.
  • W. Joo, W. Lee, S. Park, and I. C. Moon (2019) Dirichlet variational autoencoder. arXiv preprint arXiv:1901.02739. Cited by: §1.
  • D. P. Kingma and M. Welling (2014a) Auto-encoding variational bayes. International Conference on Learning Representations. Cited by: §1.
  • D. P. Kingma and M. Welling (2014b) Efficient gradient-based inference through transformations between bayes nets and neural nets. International Conference on Machine Learning. Cited by: §1.
  • D. A. Knowles (2015) Stochastic gradient variational bayes for gamma approximating distributions. arXiv preprint arXiv:1509.01631. Cited by: §1.
  • W. Kool, H. van Hoof, and M. Welling (2020) Estimating gradients for discrete random variables by sampling without replacement. International Conference on Learning Representations. Cited by: §5.
  • P. Kumaraswamy (1980) A generalized probability density function for double-bounded random processes. Journal of Hydrology. Cited by: §1.
  • R. Liu, J. Regier, N. Tripuraneni, M. I. Jordan, and J. McAuliffe (2019) Rao-blackwellized stochastic gradients for discrete distributions. International Conference on Machine Learning. Cited by: §5.
  • C. J. Maddison, A. Mnih, and Y. W. Teh (2017) The concrete distribution: a continuous relaxation of discrete random variables. International Conference on Learning Representations. Cited by: §1, §2.2, §3.2.
  • Y. Miao, E. Grefenstette, and P. Blunsom (2017) Discovering discrete latent topics with neural variational inference. International Conference on Machine Learning. Cited by: Table 2.
  • Y. Miao, L. Yu, and P. Blunsom (2016) Neural variational inference for text processing. International Conference on Machine Learning. Cited by: §B.3, §6.3, Table 2.
  • A. Mnih and K. Gregor (2014) Neural variational inference and learning in belief networks. International Conference on Machine Learning. Cited by: §5.
  • E. Nalisnick and P. Smyth (2017) Stick-breaking variational autoencoders. International Conference on Learning Representations. Cited by: §1.
  • J. Pitman (2002) Combinatorial stochastic processes. Technical report, UC Berkeley. Cited by: §1.
  • R. Ranganath, L. Tang, L. Charlin, and D. Blei (2015) Deep exponential families. Artificial Intelligence and Statistics. Cited by: §B.3, §6.3.
  • A. Srivastava and C. Sutton (2017) Autoencoding variational inference for topic models. International Conference on Learning Representations. Cited by: Table 2.
  • M. Xu, M. Quiroz, R. Kohn, and S. A. Sisson (2019) Variance reduction properties of the reparameterization trick. International Conference on Artificial Intelligence and Statistics. Cited by: §1.

Appendix A Appendix 1: Theorems and Proofs

This section provides the detailed proofs for theorems suggested in the main paper.

a.1 Proof of Theorem 3

Theorem 8.

For a non-negative discrete random variable D with the parameter , which has a finite mean and a finite variance, let us define its truncated distribution TD with a truncation level as in Definition . Then, converges to in probability as . We say that the distribution is truncatable if the theorem holds for its truncated distribution.

Proof.

Let and so that if , and if . Since has a finite mean, by Markov’s inequality in Proposition ,

Then, for small ,

where denotes the cumulative distribution function of . ∎
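For reference, Markov's inequality and the resulting tail bound can be written in standard notation (a reference sketch with generic symbols):

```latex
% Markov's inequality for a non-negative random variable X and a > 0:
\Pr(X \ge a) \le \frac{\mathbb{E}[X]}{a}
% Applied to the truncation level K, the mass beyond the support vanishes:
\Pr(D \ge K) \le \frac{\mathbb{E}[D]}{K} \xrightarrow{\;K \to \infty\;} 0
```

Since the truncated variable can differ from the original only on the vanishing tail event, convergence in probability follows.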

a.2 Theorem 3 for Truncating Both Sides

The theorem below is another version of Theorem 3 in the main paper, which we mentioned in Footnote 2. Before we state and prove the theorem, we first give an alternative definition of the truncated discrete distribution, for truncating both the left-hand and right-hand sides of the distribution.

Definition 9.

A truncated discrete random variable with center of a discrete random variable is a discrete random variable with point probability mass for and , . The random variable is said to follow a truncated discrete distribution with center , parameter and truncation level .

Theorem 10.

For any random variable D with the parameter , which has a finite mean and a finite variance, let us define its truncated distribution TD with center and truncation level as in the definition above. Then, converges to in probability as . We say that the distribution is truncatable with center if the theorem holds for its truncated distribution.

Proof.

Chebyshev’s inequality in probability theory states the following: for any random variable , and for any , . Since has a finite variance , for small , by Chebyshev’s inequality,

Hence, for small ,
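For reference, Chebyshev's inequality and the resulting two-sided tail bound read, in standard notation (a reference sketch with generic symbols):

```latex
% Chebyshev's inequality for a random variable X with finite variance:
\Pr\big(|X - \mathbb{E}[X]| \ge a\big) \le \frac{\operatorname{Var}(X)}{a^2}
% With the truncation level K around the center c, both tails vanish jointly:
\Pr\big(|D - c| \ge K\big) \le \frac{\operatorname{Var}(D)}{K^2} \xrightarrow{\;K \to \infty\;} 0
```

Because both tails vanish at rate $1/K^2$, the two-sided truncation also converges in probability.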

a.3 Proof of Theorem 4

Theorem 11.

For any truncatable discrete random variable with a transformation , can be reparameterized by if we set .

Proof.

Note that has two parameters and . By pre-defining the truncation level as a hyper-parameter, the randomness of is fully dependent on the distribution parameter . Now we introduce the Gumbel random variable where as an auxiliary random variable. Then, given a categorical parameter , any -dimensional one-hot vector , which has in one entry and in all other entries, can be reparameterized by the Gumbel-Max trick.

Suppose we have a sample from , and note that the PMF values of are known. Then, with the transformation , the following holds:

Since the transformation is also a deterministic function, by introducing the Gumbel random variable as an auxiliary random variable, we can replace the randomness of from (or in the implicit inference case) with the uniform random variable composing the Gumbel random variable. Hence, the truncatable discrete random variable can be reparameterized by the Gumbel-Max trick and the transformation . ∎
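The reparameterization in this proof can be illustrated concretely. The sketch below uses assumed settings (a Poisson rate of 5, truncation level 30, and a transformation mapping a one-hot vector to its index); it samples a truncated Poisson variable via the Gumbel-Max trick on the categorical PMF:

```python
import math
import numpy as np

def truncated_poisson_pmf(rate, K):
    """PMF of Poisson(rate) truncated to the support {0, ..., K-1}, renormalized."""
    pmf = np.array([math.exp(-rate) * rate**k / math.factorial(k) for k in range(K)])
    return pmf / pmf.sum()

def gumbel_max_one_hot(pi, rng):
    """Gumbel-Max trick: a one-hot categorical sample from probabilities pi."""
    gumbels = -np.log(-np.log(rng.uniform(size=pi.shape)))
    one_hot = np.zeros_like(pi)
    one_hot[np.argmax(np.log(pi) + gumbels)] = 1.0
    return one_hot

rng = np.random.default_rng(0)
K, rate = 30, 5.0
pi = truncated_poisson_pmf(rate, K)
support = np.arange(K, dtype=float)       # the transformation: one-hot -> value
samples = np.array([support @ gumbel_max_one_hot(pi, rng) for _ in range(5000)])
print(samples.mean())                     # close to rate = 5.0 since K >> rate
```

All randomness sits in the uniform draws behind the Gumbel samples, so the mapping from parameters to samples is a deterministic, differentiable-after-relaxation function.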

a.4 Proof of Theorem 5

Theorem 12.

For a transformation , and a given categorical parameter , as implies as .

Proof.

Suppose we are given a categorical parameter . Let be the Gumbel-Max trick function and be the Gumbel-Softmax trick function with a temperature , where both functions take the categorical parameter and a Gumbel sample as inputs. Note that returns a one-hot vector which has in the argmax entry after the Gumbel perturbation, and returns a one-hot-like softmax activation value with the temperature under the Gumbel perturbation.

Draw a random sample that defines the Gumbel sample for the Gumbel perturbation. Assume that for all . Then, for the Gumbel-Max trick, it is clear that . The statement that as implies , i.e., the following:

Then, the components of converge to zero or , since multiplication by a constant does no harm to the approximation. Hence, by taking the summation, which also does no harm to the approximation, . ∎
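A minimal numerical check of this convergence: hold one Gumbel perturbation fixed and shrink the temperature (the probabilities and temperature grid below are arbitrary choices):

```python
import numpy as np

def gumbel_softmax(log_pi, gumbels, tau):
    """Continuous relaxation of Gumbel-Max: softmax((log pi + g) / tau)."""
    z = (log_pi + gumbels) / tau
    z = z - z.max()                          # numerical stability
    e = np.exp(z)
    return e / e.sum()

rng = np.random.default_rng(1)
pi = np.array([0.2, 0.5, 0.3])
g = -np.log(-np.log(rng.uniform(size=3)))    # one fixed Gumbel perturbation
hard = np.zeros(3)
hard[np.argmax(np.log(pi) + g)] = 1.0        # Gumbel-Max one-hot sample

for tau in (1.0, 0.1, 0.01):
    soft = gumbel_softmax(np.log(pi), g, tau)
    print(tau, np.abs(soft - hard).max())    # gap shrinks as tau -> 0
```

Because the Gumbel sample is shared, the softmax output approaches the exact one-hot sample as the temperature goes to zero, which is the statement the proof formalizes.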

a.5 Proof of Theorem 7

Theorem 13.

Assume two truncated distributions and where , . Then, the KL divergence between and can be represented as the KL divergence between the corresponding categorical distributions, where .

Proof.

Appendix B Appendix 2: Experiment Details

For all experiments, we use an Intel Core i7-6700K CPU, 32GB RAM, and a Titan X. For dependencies, we use TensorFlow version 1.15.0 and TensorFlow Probability version 0.8.0.

Super Topic 1 Super Topic 2 Super Topic 3 Super Topic 4 Super Topic 5 Super Topic 6
Topic 1 Topic 2 Topic 3 Topic 4 Topic 5 Topic 6 Topic 7 Topic 8 Topic 9 Topic 10 Topic 11 Topic 12
lebanese hitler knife tire water probe brand honda flyers pitcher holy bible
lebanon jewish gun helemet air spacecraft outlet dealer braves montreal resurrection biblical
palestinian nazi police rider heat plane sale offer hitter score passage faith
arabs religion weapon bike cold shuttle shipping sell philadelphia season jesus prayer
israeli territory firearm gear oil nasa insurance condition detroit game worship doctrine
islamic sentence handgun motorcycle gas launch price purchase rangers player christ verse
regulation moral officer wheel noise fuel supply market minnesota tie sin god
Table 3: Super-topic, sub-topic, and word relationship obtained from two-layer-stacked NVPDEF in 20Newsgroups dataset.

b.1 Synthetic Example

In this experiment, we first sample independently and identically from a discrete distribution for a chosen , and we optimize the loss function with respect to where is . We use , , and in this experiment. For GenGS, we use truncation levels and for the Poisson and the negative binomial, respectively. Note that the binomial case does not require a truncation of the distribution. We use sampled targets for the Poisson and the binomial cases, and for the negative binomial case. To compute the variance of gradients, we sampled gradients for the Poisson and binomial, and gradients for the negative binomial. For a fair comparison, we use categories, which utilizes the exact gradient samples, for the RBs and UnOrd estimators. We run the experiments over times for each case.

b.2 Variational Autoencoders

We utilize the (truncated) Poisson, the (truncated) geometric, and the (truncated) negative binomial distributions in this experiment. Both MNIST and OMNIGLOT666https://github.com/yburda/iwae/tree/master/datasets/OMNIGLOT are hand-written gray-scale datasets of size . We split the MNIST dataset into {train : validation : test} = { : : }, and the OMNIGLOT dataset into { : : }.

We construct two fully-connected hidden layers of dimension for the encoder and the decoder, and we set the latent dimension for both the MNIST and OMNIGLOT datasets. We use the tanh activation function, learning rate 5e-4, and batch size , for MNIST and OMNIGLOT, respectively. For GenGS, we use temperature annealing777For GenGS ST, the temperature annealing is unnecessary, as with the ST Gumbel-Softmax estimator (Jang et al., 2017). from to , and truncation levels (1) for ; (2) for ; (3) for ; (4) for ; (5) for ; and (6) for . We run the experiments over times for each case.
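As one possible form of such an annealing schedule (the exponential shape and the constants here are illustrative, not the paper's exact scheme):

```python
import math

def annealed_temperature(step, tau_init=1.0, tau_min=0.1, decay=1e-4):
    """Exponential annealing: tau = max(tau_min, tau_init * exp(-decay * step)).
    Schedule form and constants are illustrative assumptions."""
    return max(tau_min, tau_init * math.exp(-decay * step))

print(annealed_temperature(0))       # starts at tau_init = 1.0
print(annealed_temperature(50_000))  # decays toward the floor tau_min = 0.1
```

Starting warm keeps gradients low-variance early in training, and the floor prevents the relaxation from becoming numerically degenerate.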

b.3 Topic Model Application

The 20Newsgroups and RCV1-V2 datasets from Miao et al. (2016) are used in this experiment. The 20Newsgroups dataset has a {train : test} = { : } split with a vocabulary size of , and RCV1-V2 has a {train : test} = { : } split with a vocabulary size of . We use stacked layers for 20Newsgroups, and stacked layers for the RCV1-V2 dataset. As the authors of Ranganath et al. (2015) used gamma weight samples, we utilize the softplus activation function to ensure the positivity of the decoder weights. For the multi-stacked version of NVPDEF, we utilized multi-sampling on the latent layers to stabilize the optimization over consecutive samplings. For all neural network models, we utilize two -dimensional hidden layers for the encoders. We set for the single-stacked case, and for the multi-stacked case. As a performance measure, we utilize perplexity, where is the number of words in document and is the total number of documents. We run the experiments over times for each case. Table 3 shows the relationship among super-topics, sub-topics, and words obtained from the two-layer-stacked NVPDEF on the 20Newsgroups dataset. We place the table in the Appendix due to space constraints in the main paper.
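The perplexity measure can be written out under the standard NVDM-style definition, exp(-(1/D) Σ_d log p(doc_d)/N_d); the function and variable names below are ours:

```python
import math

def perplexity(doc_log_probs, doc_lengths):
    """perplexity = exp(-(1/D) * sum_d log p(doc_d) / N_d),
    the per-word measure commonly used to evaluate NVDM-style topic models."""
    D = len(doc_log_probs)
    return math.exp(-sum(lp / n for lp, n in zip(doc_log_probs, doc_lengths)) / D)

# One document with log-likelihood -10 over 5 words:
print(perplexity([-10.0], [5]))  # exp(2) ~ 7.389
```

Lower perplexity means the model assigns higher per-word likelihood to held-out documents.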