
ARM: Augment-REINFORCE-Merge Gradient for Discrete Latent Variable Models

by Mingzhang Yin et al.

To backpropagate gradients through discrete stochastic layers, we encode the true gradient as the product of random noise and the difference between the same function evaluated at two different sets of discrete latent variables, both of which are correlated with that noise. Over iterations, the expectation of this product is zero most of the time, punctuated by occasional spikes. We want to modulate the frequency, amplitude, and sign of these spikes so that they capture the temporal evolution of the true gradient. To do so, we propose the augment-REINFORCE-merge (ARM) estimator, which combines data augmentation, the score-function estimator, permutation of the indices of the latent variables, and variance reduction for Monte Carlo integration using common random numbers. The ARM estimator provides unbiased, low-variance gradient estimates for the parameters of discrete distributions. It achieves state-of-the-art performance in both auto-encoding variational Bayes and maximum likelihood inference for discrete latent variable models with one or multiple discrete stochastic layers.
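The "noise times a difference of two correlated function evaluations" construction can be illustrated with a minimal NumPy sketch. It assumes independent Bernoulli latents z_k ~ Bernoulli(sigmoid(phi_k)) in a single stochastic layer and the ARM identity grad_phi E[f(z)] = E_u[(f(1[u > sigmoid(-phi)]) - f(1[u < sigmoid(phi)])) * (u - 1/2)] with u ~ Uniform(0, 1)^K. The helper names `arm_grad`, `reinforce_grad`, and `exact_grad` and the toy objective `f` are hypothetical and used only for illustration. REINFORCE is included purely as a baseline. The sketch does not cover multiple stochastic layers or the index-permutation step mentioned above.

```python
import itertools

import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def arm_grad(f, phi, n_samples, rng):
    """Per-sample ARM estimates of d/dphi E_{z ~ Bernoulli(sigmoid(phi))}[f(z)].

    Both binary vectors in a sample are built from the SAME uniform draw u
    (common random numbers), so f is evaluated twice per sample.
    Returns an array of shape (n_samples, len(phi)).
    """
    phi = np.asarray(phi, dtype=float)
    u = rng.uniform(size=(n_samples, phi.size))
    z_plus = (u > sigmoid(-phi)).astype(float)   # 1[u > sigma(-phi)]
    z_minus = (u < sigmoid(phi)).astype(float)   # 1[u < sigma(phi)]
    # The scalar difference f(z_plus) - f(z_minus) scales the centred noise (u - 1/2).
    return (f(z_plus) - f(z_minus))[:, None] * (u - 0.5)


def reinforce_grad(f, phi, n_samples, rng):
    """Per-sample score-function (REINFORCE) estimates, used only as a baseline."""
    phi = np.asarray(phi, dtype=float)
    p = sigmoid(phi)
    z = (rng.uniform(size=(n_samples, phi.size)) < p).astype(float)
    # d/dphi log Bernoulli(z; sigmoid(phi)) = z - sigmoid(phi)
    return f(z)[:, None] * (z - p)


def exact_grad(f, phi):
    """Exact gradient by enumerating all 2**K binary vectors (small K only)."""
    p = sigmoid(np.asarray(phi, dtype=float))
    Z = np.array(list(itertools.product([0.0, 1.0], repeat=p.size)))
    probs = np.prod(np.where(Z == 1.0, p, 1.0 - p), axis=1)  # q(z)
    # d q(z) / d phi_k = q(z) * (z_k - p_k)
    return np.sum((f(Z) * probs)[:, None] * (Z - p), axis=0)


# Hypothetical toy objective: f(z) = sum_k (z_k - 0.49)^2, evaluated row-wise.
f = lambda z: np.sum((z - 0.49) ** 2, axis=-1)
rng = np.random.default_rng(0)
phi = np.array([-1.0, 0.0, 1.0])

arm = arm_grad(f, phi, 100_000, rng)
rf = reinforce_grad(f, phi, 100_000, rng)
print("exact     ", exact_grad(f, phi))
print("ARM       ", arm.mean(axis=0), "var", arm.var(axis=0))
print("REINFORCE ", rf.mean(axis=0), "var", rf.var(axis=0))
```

In this toy setup, the true gradient per coordinate is only 0.02 * sigmoid(phi) * (1 - sigmoid(phi)), while f itself is around 0.75. Under these conditions a score-function estimator is very noisy. The ARM difference f(z_plus) - f(z_minus) is nonzero only when the two correlated samples disagree, which keeps its per-sample variance far smaller here.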


DisARM: An Antithetic Gradient Estimator for Binary Latent Variables

Training models with discrete latent variables is challenging due to the...

Augment-Reinforce-Merge Policy Gradient for Binary Stochastic Policy

Due to the high variance of policy gradients, on-policy optimization alg...

Latent Transformations for Discrete-Data Normalising Flows

Normalising flows (NFs) for discrete data are challenging because parame...

ARSM: Augment-REINFORCE-Swap-Merge Estimator for Gradient Backpropagation Through Categorical Variables

To address the challenge of backpropagating the gradient through categor...

Adaptive Perturbation-Based Gradient Estimation for Discrete Latent Variable Models

The integration of discrete algorithmic components in deep learning arch...

Efficient Marginalization of Discrete and Structured Latent Variables via Sparsity

Training neural network models with discrete (categorical or structured)...

Straight-Through Estimator as Projected Wasserstein Gradient Flow

The Straight-Through (ST) estimator is a widely used technique for back-...

Code Repositories


Low variance, unbiased gradient for discrete latent variable models
