
Adaptive Perturbation-Based Gradient Estimation for Discrete Latent Variable Models

by Pasquale Minervini, et al.

The integration of discrete algorithmic components in deep learning architectures has numerous applications. Recently, Implicit Maximum Likelihood Estimation (IMLE, Niepert, Minervini, and Franceschi 2021), a class of gradient estimators for discrete exponential family distributions, was proposed by combining implicit differentiation through perturbation with the path-wise gradient estimator. However, due to the finite difference approximation of the gradients, it is especially sensitive to the choice of the finite difference step size, which needs to be specified by the user. In this work, we present Adaptive IMLE (AIMLE), the first adaptive gradient estimator for complex discrete distributions: it adaptively identifies the target distribution for IMLE by trading off the density of gradient information with the degree of bias in the gradient estimates. We empirically evaluate our estimator on synthetic examples, as well as on Learning to Explain, Discrete Variational Auto-Encoders, and Neural Relational Inference tasks. In our experiments, we show that our adaptive gradient estimator can produce faithful estimates while requiring orders of magnitude fewer samples than other gradient estimators.





1 Introduction

Figure 1: (Top) Cosine similarity between the true and estimated gradients, where estimates are computed using IMLE (niepert21imle), the Straight-Through Estimator (STE, bengio2013estimating), and the Score Function Estimator (SFE, DBLP:journals/ml/Williams92). (Bottom) Sparsity (% of zero elements) of the estimated IMLE gradient: for $\lambda \to 0$, the IMLE gradient estimate is $0$, while increasing $\lambda$ leads to increasingly biased gradient estimates.

There is growing interest in end-to-end learnable models that incorporate discrete algorithms, for example to sample from discrete latent distributions (jang2016categorical; paulus2020gradient) or to solve combinatorial optimisation problems (poganvcic2019differentiation; Mandi_Guns:2020; niepert21imle). These discrete components are not continuously differentiable, and an important problem is to efficiently estimate gradients with respect to their inputs in order to perform backpropagation. Reinforcement learning, discrete Energy-Based Models (EBMs, lecun2006tutorial), learning to explain (chen2018learning), discrete Variational Auto-Encoders (VAEs, DBLP:journals/corr/KingmaW13), and discrete world models (hafner2020mastering) are additional examples of neural network-based architectures that require the ability to back-propagate through expectations of discrete probability distributions.

The main challenge these approaches have in common is the problem of (approximately) computing gradients of an expectation of a continuously differentiable function $f$:

$$\nabla_{\theta}\, \mathbb{E}_{z \sim p(z; \theta)}\left[ f(z) \right], \tag{1}$$

where the expectation is taken over a complex discrete probability distribution with intractable marginals and normalisation constant.

In principle, one could use the Score Function Estimator (SFE, DBLP:journals/ml/Williams92). Unfortunately, it suffers from high variance, which is typically exacerbated by the distribution $p(z; \theta)$ being intractable.
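To make the score function estimator concrete, the following is a minimal sketch for the simplest tractable case, a categorical distribution parameterised by softmax logits, where the exact gradient can be obtained by enumeration and compared against the estimate. The function names and the toy loss (a squared distance to a fixed vector `b`) are illustrative assumptions, not taken from the paper.

```python
import math
import random


def softmax(theta):
    """Numerically stable softmax of a list of logits."""
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    total = sum(exps)
    return [e / total for e in exps]


def one_hot(i, n):
    return [1.0 if j == i else 0.0 for j in range(n)]


def sfe_gradient(theta, f, num_samples, rng):
    """Score function estimate of the gradient of E_{z ~ p(z; theta)}[f(z)]."""
    probs = softmax(theta)
    n = len(theta)
    grad = [0.0] * n
    for _ in range(num_samples):
        i = rng.choices(range(n), weights=probs)[0]
        z = one_hot(i, n)
        fz = f(z)
        # For a softmax-parameterised categorical, grad_theta log p(z) = z - probs.
        for j in range(n):
            grad[j] += fz * (z[j] - probs[j]) / num_samples
    return grad


def exact_gradient(theta, f):
    """Exact gradient by enumerating all one-hot states (tractable case only)."""
    probs = softmax(theta)
    n = len(theta)
    values = [f(one_hot(i, n)) for i in range(n)]
    mean = sum(p * v for p, v in zip(probs, values))
    return [probs[j] * (values[j] - mean) for j in range(n)]


def toy_loss(z, b=(0.2, -0.5, 1.0)):
    """Hypothetical downstream function: squared distance to a fixed vector b."""
    return sum((zi - bi) ** 2 for zi, bi in zip(z, b))
```

Each sample contributes `f(z)` times the score `z - probs`, so the estimate is unbiased but its variance grows with the magnitude of `f`, which is why many samples are typically needed before it approaches `exact_gradient`.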

Implicit Maximum Likelihood Estimation (IMLE, niepert21imle), a recently proposed general-purpose gradient estimation technique, has shown lower variance and outperformed other existing methods, including the score function estimator and problem-specific continuous relaxations, in several settings (niepert21imle; betz:2021; qian2022ordered). For instance, for the synthetic problem in Fig. 1, the gradient estimate produced by SFE is worse than the estimate produced by IMLE from two orders of magnitude fewer samples, due to the high variance of the SFE. IMLE combines Perturb-and-MAP sampling with a finite difference method for implicit differentiation originally developed for loss functions defined over marginals (domke:2010). In IMLE, gradients are approximated as:

$$\nabla_{\theta}\, \mathbb{E}_{z \sim p(z; \theta)}\left[ f(z) \right] \approx \mathbb{E}_{\epsilon \sim \rho(\epsilon)}\left[ \frac{1}{\lambda} \left\{ \mathrm{MAP}(\theta + \epsilon) - \mathrm{MAP}(\theta' + \epsilon) \right\} \right], \tag{2}$$

where $\mathrm{MAP}(\theta)$ is a maximum-probability state of the distribution $p(z; \theta)$, $\epsilon \sim \rho(\epsilon)$ is a perturbation drawn from a noise distribution, and $\theta' = \theta - \lambda \nabla_{z} f(z)$ with $z = \mathrm{MAP}(\theta + \epsilon)$. Computing the MAP states instead of sampling from $p(z; \theta)$ is especially interesting since, in many cases, it has lower computational complexity than sampling from the corresponding distribution (niepert21imle).
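As an illustration of the estimator just described, here is a minimal sketch for a categorical distribution, where the MAP state is the one-hot arg max and Gumbel noise is used for the perturbation. The helper names and the toy gradient (that of a squared distance to a fixed vector `b`) are assumptions made for this example only.

```python
import math
import random


def sample_gumbel(rng):
    """Standard Gumbel(0, 1) sample via the inverse CDF."""
    u = max(rng.random(), 1e-12)  # guard against log(0)
    return -math.log(-math.log(u))


def map_one_hot(theta):
    """MAP state of a categorical distribution: one-hot vector at the arg max."""
    best = max(range(len(theta)), key=lambda i: theta[i])
    return [1.0 if i == best else 0.0 for i in range(len(theta))]


def imle_gradient(theta, grad_f, lam, num_samples, rng):
    """Forward-difference IMLE estimate, averaged over num_samples perturbations."""
    n = len(theta)
    estimate = [0.0] * n
    for _ in range(num_samples):
        eps = [sample_gumbel(rng) for _ in range(n)]
        z = map_one_hot([t + e for t, e in zip(theta, eps)])
        g = grad_f(z)
        # Target parameters theta' = theta - lam * grad_z f(z), same noise eps.
        z_target = map_one_hot([t - lam * gi + e for t, gi, e in zip(theta, g, eps)])
        for j in range(n):
            estimate[j] += (z[j] - z_target[j]) / (lam * num_samples)
    return estimate


def toy_grad(z, b=(0.0, 0.0, 1.0)):
    """Gradient of the hypothetical loss ||z - b||^2 with respect to z."""
    return [2.0 * (zi - bi) for zi, bi in zip(z, b)]
```

With a very small `lam` the two MAP states almost always coincide and the estimate is all zeros, whereas a large `lam` makes the difference non-zero but increasingly detached from the local behaviour of the loss.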

Crucially, the parameter $\lambda$ determines the step size of the finite difference approximation. When the input to $f$ is the continuously differentiable marginals of $p(z; \theta)$, smaller values of $\lambda$ lead to less biased estimates. Hence, in this setting $\lambda$ is typically set to a value that depends on the machine precision so as to avoid numerical instabilities (domke:2010). In the setting we consider, however, the input to $f$ is discrete and discontinuous. Setting $\lambda$ to a very small value, in this case, results in zero gradients. This is illustrated in Fig. 2 (right) for the forward difference method. Hence, the crucial insight is that $\lambda$ trades off the bias and sparsity of the gradient approximation. In Fig. 1 (top) and (bottom) we plot, respectively, the bias and the sparsity of the gradient estimates for different values of $\lambda$ on a toy optimisation problem. As we can see, larger values of $\lambda$ result in a higher bias, and low values of $\lambda$ result in gradient estimates almost always being zero. In this paper, we propose to make the parameter $\lambda$ adaptive. We also provide empirical results showing that making $\lambda$ adaptive reduces the bias and improves the results on several benchmark problems.

2 Problem Definition

We consider the problem of computing the gradients of an expectation over a discrete probability distribution of a continuously differentiable function $f$, that is,

$$\nabla_{\theta}\, \mathbb{E}_{z \sim p(z; \theta)}\left[ f(z) \right], \tag{3}$$

where $p(z; \theta)$ is a discrete probability distribution over binary vectors $z$ with parameters $\theta$. More specifically, we are concerned with settings where $p(z; \theta)$ is a discrete probability distribution with an intractable normalisation constant. Moreover, we assume that the function $f$ is a parameterized, non-trivial, continuously differentiable function, which makes existing approaches such as direct loss minimisation and perturbed optimizers (berthet2020learning) not directly applicable.

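More formally, let $\theta$ be a real-valued parameter vector. The probability mass function (PMF) of a discrete constrained exponential family r.v. is:

$$p(z; \theta) = \begin{cases} \exp\left( \langle z, \theta \rangle - A(\theta) \right) & \text{if } z \in \mathcal{C}, \\ 0 & \text{otherwise.} \end{cases} \tag{4}$$

Here, $\langle \cdot, \cdot \rangle$ is the inner product, $A(\theta)$ is the log-partition function, defined as $A(\theta) = \log \sum_{z \in \mathcal{C}} \exp\left( \langle z, \theta \rangle \right)$, and $\mathcal{C}$ is an integral polytope of feasible configurations $z$. We call $\langle z, \theta \rangle$ the weight of the state $z$. The marginals (expected value, mean) of the r.v.s are defined as $\mu(\theta) = \mathbb{E}_{z \sim p(z; \theta)}[z]$. Finally, the most probable states, also referred to as the Maximum A-Posteriori (MAP) states, are defined as $\mathrm{MAP}(\theta) = \arg\max_{z \in \mathcal{C}} \langle z, \theta \rangle$.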
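For a feasible set small enough to enumerate, the quantities defined above (weight, log-partition function, PMF, marginals, MAP state) can be computed by brute force. The sketch below does so for binary vectors with exactly `k` ones; the function names are illustrative, and enumeration is only viable for tiny problems, which is precisely why the paper relies on a MAP oracle instead.

```python
import itertools
import math


def k_subset_states(n, k):
    """All binary vectors of length n with exactly k ones (the feasible set)."""
    states = []
    for chosen in itertools.combinations(range(n), k):
        states.append(tuple(1 if i in chosen else 0 for i in range(n)))
    return states


def weight(z, theta):
    """Weight <z, theta> of the state z."""
    return sum(zi * ti for zi, ti in zip(z, theta))


def log_partition(theta, states):
    """Log of the sum of exponentiated weights over all feasible states."""
    return math.log(sum(math.exp(weight(z, theta)) for z in states))


def pmf(z, theta, states):
    """Probability of z; zero outside the feasible set."""
    if z not in states:
        return 0.0
    return math.exp(weight(z, theta) - log_partition(theta, states))


def marginals(theta, states):
    """Expected value of each component of z under the distribution."""
    n = len(theta)
    return [sum(pmf(z, theta, states) * z[i] for z in states) for i in range(n)]


def map_state(theta, states):
    """Feasible state with the largest weight."""
    return max(states, key=lambda z: weight(z, theta))
```

For example, with `states = k_subset_states(4, 2)` and `theta = [2.0, -1.0, 0.5, 0.0]`, `map_state(theta, states)` returns `(1, 0, 1, 0)`, the two largest-weight positions, and the marginals sum to `k = 2`.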

3 AIMLE: Adaptive Implicit Maximum-Likelihood Learning

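We base our estimator on a finite difference method for implicit differentiation (domke:2010; niepert21imle), which is generally applicable to any discrete distribution as defined in Eq. 4. For the initial derivation, we make the assumption that we can compute exact samples from the distribution $p(z; \theta)$ using Perturb-and-MAP (Papandreou:2011) with noise distribution $\rho(\epsilon)$. We write $\theta + \epsilon$ for a perturbation of the parameters $\theta$ by a sample $\epsilon$ from the noise distribution $\rho(\epsilon)$. Since, in general, exact sampling is not possible for complex distributions, using approximate Perturb-and-MAP samples introduces a bias in the gradient estimates. For now, however, we make the assumption that these perturbations are exact, i.e., that $\mathrm{MAP}(\theta + \epsilon) \sim p(z; \theta)$. Under these assumptions, and by invoking the law of the unconscious statistician (mohamed2019monte), we obtain:

$$\nabla_{\theta}\, \mathbb{E}_{z \sim p(z; \theta)}\left[ f(z) \right] = \nabla_{\theta}\, \mathbb{E}_{\epsilon \sim \rho(\epsilon)}\left[ f(\mathrm{MAP}(\theta + \epsilon)) \right].$$

We now approximate $\mathrm{MAP}(\theta + \epsilon)$ by $\mu_{\tau}(\theta + \epsilon) := \mu((\theta + \epsilon)/\tau)$ for a small $\tau > 0$. That is, we replace a MAP state with a corresponding vector representing the marginal probabilities, where we can make the probabilities increasingly spikier by lowering a temperature parameter $\tau$. The approximation error between these two terms can be made arbitrarily small, since:

$$\mathrm{MAP}(\theta + \epsilon) = \lim_{\tau \to 0} \mu_{\tau}(\theta + \epsilon).$$

The above equality holds almost everywhere if the noise distribution is such that the probability of two or more components of $\theta + \epsilon$ being equal is zero. This is the case in the standard setting where $\epsilon \sim \mathrm{Gumbel}(0, 1)$ (Papandreou:2011). Therefore, we can write that, for some $\tau > 0$,

$$\nabla_{\theta}\, \mathbb{E}_{\epsilon \sim \rho(\epsilon)}\left[ f(\mathrm{MAP}(\theta + \epsilon)) \right] \approx \nabla_{\theta}\, \mathbb{E}_{\epsilon \sim \rho(\epsilon)}\left[ f(\mu_{\tau}(\theta + \epsilon)) \right].$$

Writing the expectation as an integral, we have:

$$\nabla_{\theta}\, \mathbb{E}_{\epsilon \sim \rho(\epsilon)}\left[ f(\mu_{\tau}(\theta + \epsilon)) \right] = \nabla_{\theta} \int \rho(\epsilon)\, f(\mu_{\tau}(\theta + \epsilon))\, d\epsilon.$$

Now we can exchange differentiation and integration since, for finite $\tau$, $\mu_{\tau}$ and $f$ are continuously differentiable:

$$\nabla_{\theta} \int \rho(\epsilon)\, f(\mu_{\tau}(\theta + \epsilon))\, d\epsilon = \int \rho(\epsilon)\, \nabla_{\theta} f(\mu_{\tau}(\theta + \epsilon))\, d\epsilon \approx \int \rho(\epsilon)\, \frac{1}{\lambda} \left\{ \mu_{\tau}(\theta + \epsilon) - \mu_{\tau}\left(\theta + \epsilon - \lambda \nabla_{\mu} f(\mu_{\tau}(\theta + \epsilon))\right) \right\} d\epsilon.$$

The last step uses implicit differentiation by perturbation (domke:2010), which is a finite difference method for computing the gradients of a function defined on marginals. Finally, we again approximate $\mu_{\tau}(\cdot)$ with $\mathrm{MAP}(\cdot)$ and obtain:

$$\nabla_{\theta}\, \mathbb{E}_{z \sim p(z; \theta)}\left[ f(z) \right] \approx \mathbb{E}_{\epsilon \sim \rho(\epsilon)}\left[ \frac{1}{\lambda} \left\{ \mathrm{MAP}(\theta + \epsilon) - \mathrm{MAP}\left(\theta + \epsilon - \lambda \nabla_{z} f(z)\right) \right\} \right], \quad z = \mathrm{MAP}(\theta + \epsilon). \tag{6}$$

Again, the approximation error of the above expression is arbitrarily small (but not zero), because the derivation shown here is valid for any $\tau > 0$. A finite sample approximation of Eq. 6 results in the IMLE gradient estimator of Eq. 2 given in the introduction. While we could in principle use the marginals as input to the function $f$ as in relaxed gradient estimators (maddison2016concrete; jang2016categorical; paulus2020gradient), computing marginals for the complex distributions we consider here is not tractable in general, and we have to use approximate Perturb-and-MAP samples.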
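The assumption that Perturb-and-MAP yields exact samples does hold in the categorical case with Gumbel noise, which is easy to check empirically. The following sketch, with illustrative function names, compares the frequencies of the perturbed arg max against the softmax probabilities.

```python
import math
import random


def softmax(theta):
    """Numerically stable softmax of a list of logits."""
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    total = sum(exps)
    return [e / total for e in exps]


def sample_gumbel(rng):
    """Standard Gumbel(0, 1) sample via the inverse CDF."""
    u = max(rng.random(), 1e-12)  # guard against log(0)
    return -math.log(-math.log(u))


def perturb_and_map(theta, rng):
    """Index of the MAP state of theta + eps, with i.i.d. Gumbel(0, 1) noise."""
    perturbed = [t + sample_gumbel(rng) for t in theta]
    return max(range(len(theta)), key=lambda i: perturbed[i])


def empirical_frequencies(theta, num_samples, rng):
    """Fraction of Perturb-and-MAP samples landing on each category."""
    counts = [0] * len(theta)
    for _ in range(num_samples):
        counts[perturb_and_map(theta, rng)] += 1
    return [c / num_samples for c in counts]
```

For structured distributions such as subsets or spanning trees, perturbing each component independently in this way is in general only an approximation, which is the source of bias mentioned in the text.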

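function Init
   // Initial value of $\alpha$
   // Initial gradient norm estimate
   // Update step $\eta$ for $\alpha$
function ForwardPass($\theta$)
   // Sample $\epsilon$ from the noise distribution $\rho(\epsilon)$
   // MAP states of perturbed $\theta$: $z = \mathrm{MAP}(\theta + \epsilon)$
   save $\theta$, $\epsilon$, and $z$
function BackwardPass($\nabla_{z} f(z)$)
   load $\theta$, $\epsilon$, and $z$
   $\lambda \leftarrow \alpha \|\theta\|_2 / \|\nabla_{z} f(z)\|_2$ (see Eq. 7)
   // Central difference estimate from the MAP states of $\theta \pm \lambda \nabla_{z} f(z)$, perturbed by $\epsilon$
   // Moving average of the gradient norm
   // Update $\alpha$ to make the average gradient norm closer to the target
Algorithm 1 Central Difference Perturbation-based Adaptive Implicit Maximum Likelihood Estimation (AIMLE).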

3.1 An Adaptive Optimiser for Finite-Difference based Implicit Differentiation

An important observation that motivates the proposed adaptive version of IMLE is that we need to choose a hyperparameter $\lambda$ for Eq. 2. Choosing a very small $\lambda$ leads to most gradients being zero. As a consequence, the gradients back-propagated to the upstream model are zero, which prevents the upstream model from being trained. If we choose $\lambda$ too large, we obtain less sparse gradients, but the gradients are also more biased. Hence, we propose an optimiser that adapts $\lambda$ during training so as to trade off non-zero but biased gradients against sparse but unbiased ones. Similar to adaptive first-order optimizers in deep learning, we replace a single hyperparameter with a set of new ones, but show that we obtain consistently better results when using default hyperparameters for the adaptive method.

Normalisation of the perturbation strength.

Our first observation is that the magnitude of the perturbation in the direction of the negative downstream gradient in Eq. 6 highly depends on $\nabla_{z} f(z)$, the gradients of the downstream function $f$. To mitigate the variations in the downstream gradient norm relative to the parameters $\theta$, we propose to set the perturbation magnitude (the norm of the difference between $\theta$ and the perturbed $\theta'$) to be a fraction of the norm of the parameter vector $\theta$. In particular, let $\alpha$ be such a fraction; then we seek $\lambda$ such that:

$$\|\theta - \theta'\|_2 = \lambda \|\nabla_{z} f(z)\|_2 = \alpha \|\theta\|_2 \quad \Longrightarrow \quad \lambda = \alpha \frac{\|\theta\|_2}{\|\nabla_{z} f(z)\|_2}. \tag{7}$$

This way we ensure that a global value for $\alpha$ roughly translates to the same input-specific magnitude of the perturbation in the direction of the negative gradient.

Trading off Bias and Sparsity of the Gradient Estimates.

For computing $\alpha$ as in Eq. 7, we track the sparsity of the gradient estimator with an exponential moving average of the gradient norm. Since the entries of the gradients in Eq. 2, i.e. of the difference between the two MAP states, are always in $\{-1, 0, 1\}$, we take the $L_0$-norm, which here is equivalent to the number of non-zero gradients.

Consider a batch of $N$ inputs during training. Let $\hat{g}_j$ with $j \in \{1, \ldots, N\}$ be a single-sample gradient estimate from Eq. 2 for an input data point $x_j$ with input parameters $\theta_j$ and without the scaling factor $1/\lambda$. We compute, in every training iteration $t$ and using a discount factor $\gamma$, the exponential moving average of the number of non-zero gradients per training example:

$$c_{t+1} = \gamma\, c_t + (1 - \gamma)\, \frac{1}{N} \sum_{j=1}^{N} \|\hat{g}_j\|_0.$$

Similarly to adaptive optimisation algorithms for neural networks, we introduce an update rule for $\alpha$. Let $c^{*}$ be the desired learning rate, expressed as the number of non-zero gradients per example. This is the target learning rate, that is, the desired number of gradients we obtain on average per example. A typical value is $c^{*} = 1$, meaning that we aim to adapt the value for $\alpha$ such that we obtain, on average, at least one non-zero gradient per example. We now use the following update rule for some fixed $\eta > 0$:

$$\alpha_{t+1} = \begin{cases} \alpha_t + \eta & \text{if } c_{t+1} \leq c^{*}, \\ \alpha_t - \eta & \text{otherwise.} \end{cases}$$

Hence, by increasing or decreasing $\alpha$ by a constant step and based on the current exponential moving average of the gradient sparsity, we adapt $\lambda$ through Eq. 7, which relates $\alpha$ and $\lambda$. Algorithm 1 lists the gradient estimator as a layer in a neural network with a forward and backward pass.
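The forward and backward passes described above can be sketched as a small stateful layer. This is a simplified illustration under several assumptions that are not from the paper: it handles a single categorical variable with one perturbation sample and a batch of size one, it uses the central difference form, and the default update step, discount factor, and initial fraction are placeholder values. Only the target of one non-zero gradient per example follows the text.

```python
import math
import random


def sample_gumbel(rng):
    """Standard Gumbel(0, 1) sample via the inverse CDF."""
    u = max(rng.random(), 1e-12)  # guard against log(0)
    return -math.log(-math.log(u))


def map_one_hot(theta):
    """MAP state of a categorical distribution: one-hot vector at the arg max."""
    best = max(range(len(theta)), key=lambda i: theta[i])
    return [1.0 if i == best else 0.0 for i in range(len(theta))]


def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))


class AimleLayer:
    """Single-sample, central-difference AIMLE sketch for a categorical MAP."""

    def __init__(self, target_norm=1.0, update_step=1e-3, discount=0.9, seed=0):
        self.alpha = 0.0      # assumed initial perturbation fraction
        self.norm_ema = 0.0   # moving average of the gradient L0 norm
        self.target_norm = target_norm
        self.update_step = update_step  # placeholder value
        self.discount = discount        # placeholder value
        self.rng = random.Random(seed)

    def forward(self, theta):
        self.theta = list(theta)
        self.eps = [sample_gumbel(self.rng) for _ in theta]
        return map_one_hot([t + e for t, e in zip(self.theta, self.eps)])

    def backward(self, grad_z):
        grad_norm = l2_norm(grad_z)
        # Step size from the perturbation fraction alpha (cf. Eq. 7).
        lam = self.alpha * l2_norm(self.theta) / grad_norm if grad_norm > 0 else 0.0
        diff = [0.0] * len(self.theta)
        if lam > 0.0:
            plus = [t + lam * g + e for t, g, e in zip(self.theta, grad_z, self.eps)]
            minus = [t - lam * g + e for t, g, e in zip(self.theta, grad_z, self.eps)]
            diff = [p - m for p, m in zip(map_one_hot(plus), map_one_hot(minus))]
        # Track sparsity: number of non-zero entries of the unscaled estimate.
        nonzero = sum(1 for d in diff if d != 0.0)
        self.norm_ema = self.discount * self.norm_ema + (1.0 - self.discount) * nonzero
        # Move alpha so that the moving average approaches the target.
        if self.norm_ema <= self.target_norm:
            self.alpha += self.update_step
        else:
            self.alpha = max(0.0, self.alpha - self.update_step)
        return [d / (2.0 * lam) for d in diff] if lam > 0.0 else diff
```

In a training loop, `forward(theta)` would be called to obtain the discrete state, the downstream gradient with respect to that state would be passed to `backward`, and the returned vector would be used as the gradient with respect to `theta`. Starting from a fraction of zero means the first few gradients are all zeros until the fraction has grown, which corresponds to the warm-up period discussed in the conclusions.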

Figure 2: Finite difference approximation of a downstream function $f$ on continuous marginals (left) and discrete samples (right). The step size $\lambda$ trades off bias and sparsity of the gradient approximations for discrete samples, and we propose to make the step size adaptive.

Forward and centred finite difference approximation.

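Gradient estimation in IMLE, as outlined in Eq. 2, is analogous to gradient estimation with forward (one-sided) finite difference approximations, where $f'(x) \approx \frac{f(x + h) - f(x)}{h}$. A better approximation can be obtained by the centred (two-sided) difference formula $f'(x) \approx \frac{f(x + h) - f(x - h)}{2h}$, which is a second-order approximation to the first derivative (Olver2013-sj). Following this intuition, we replace $\frac{1}{\lambda} \left\{ \mathrm{MAP}(\theta + \epsilon) - \mathrm{MAP}(\theta' + \epsilon) \right\}$ in Eq. 2 with $\frac{1}{2\lambda} \left\{ \mathrm{MAP}(\theta'_{+} + \epsilon) - \mathrm{MAP}(\theta'_{-} + \epsilon) \right\}$, where $\theta'_{\pm} = \theta \pm \lambda \nabla_{z} f(z)$, leading to the update equation in Algorithm 1.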
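The difference in accuracy between the two formulas is easy to see on a smooth scalar function. The sketch below is a generic numerical illustration, not part of the AIMLE estimator itself: for step size `h`, the forward difference has an error of order `h`, the central difference of order `h` squared.

```python
import math


def forward_difference(f, x, h):
    """One-sided approximation of f'(x); error is O(h)."""
    return (f(x + h) - f(x)) / h


def central_difference(f, x, h):
    """Two-sided approximation of f'(x); error is O(h^2)."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


def approximation_errors(f, f_prime, x, h):
    """Absolute errors of both approximations against a known derivative."""
    exact = f_prime(x)
    return (
        abs(forward_difference(f, x, h) - exact),
        abs(central_difference(f, x, h) - exact),
    )
```

Calling `approximation_errors(math.sin, math.cos, 1.0, 1e-3)` shows the central difference error to be several orders of magnitude smaller than the forward one at the same step size.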

4 Related Work

Continuous relaxations.

Several papers address the gradient estimation problem for discrete random variables, often resorting to continuous relaxations. maddison2016concrete and jang2016categorical propose the Gumbel-Softmax (or concrete) distribution to relax categorical random variables, which was later extended to more complex probability distributions. The Gumbel-Softmax distribution is only directly applicable to categorical variables: for more complex distributions, one has to come up with tailor-made relaxations, or use the STE or SFE (see, e.g., kim2016exact and grover2019stochastic). REBAR (tucker2017rebar) and RELAX (grathwohl2017backpropagation) use parameterized control variates based on continuous relaxations for the SFE. In this work, we focus explicitly on problems where only discrete samples are used during training. Furthermore, REBAR is tailored to categorical distributions, while IMLE and AIMLE are intended for models with complex distributions and multiple constraints. Approaches that do not rely on relaxations are specific to certain distributions (bengio2013estimating; franceschi2019learning; liu2019rao) or assume knowledge of the constraints (kool2020estimating). AIMLE and IMLE provide a general-purpose framework that does not require access to the linear constraints and the corresponding integer polytope $\mathcal{C}$. SparseMAP (Niculae2018SparseMAPDS) is an approach to structured prediction and latent variables, replacing an exponential distribution with a sparser distribution; similarly to our work, it only presupposes the availability of a MAP oracle. LP-SparseMAP (Niculae2020LPSparseMAPDR) is an extension of SparseMAP that uses a relaxation of the underlying optimisation problem.

Differentiating through combinatorial solvers.

A series of works about differentiating through combinatorial optimisation problems (wilder2019melding; elmachtoub2020smart; ferber2020mipaal; DBLP:conf/nips/MandiG20) relax ILPs by adding a regularisation term, and differentiate through the KKT conditions deriving from the application of the cutting plane or the interior-point methods. These approaches are conceptually linked to techniques for differentiating through smooth programs (amos2017optnet; donti2017task; agrawal2019differentiable; chen2020understanding; domke2012generic; franceschi2018bilevel) that arise in modelling, hyperparameter optimisation, and meta-learning. Black-box Backprop (poganvcic2019differentiation; rolinek2020deep) and DPO (berthet2020learning) are methods that are not tied to a specific ILP solver. Black-box Backprop, originally derived from a continuous interpolation argument, can be interpreted as a special instantiation of IMLE and AIMLE. DPO addresses the theory of perturbed optimizers and discusses perturb-and-MAP in the context of the Fenchel-Young losses. All the combinatorial optimisation-related works assume that either optimal costs or solutions are given as training data, while IMLE and AIMLE can also be applied in the absence of such supervision by making use of implicitly generated target distributions. Other authors focus on devising differentiable relaxations for specific combinatorial problems such as SAT (evans2018learning) or MaxSAT (wang2019satnet). Machine learning also intersects with combinatorial optimisation in other contexts, e.g. in learning heuristics to improve the performance of combinatorial solvers; we refer to bengio2020machine for further details.

Direct Loss Minimisation (DLM, mcallester2010direct; song2016training) is also related to our work, but it relies on the assumption that examples of optimal states are given. lorberbom:2019 extend the DLM framework to discrete VAEs using coupled perturbations: their approach is tailored to VAEs, and is not general-purpose. From a methodological viewpoint, IMLE inherits from classical MLE (wainwright2008graphical) and Perturb-and-MAP (Papandreou:2011). The theory of perturb-and-MAP was used to derive general-purpose upper bounds for log-partition functions (hazan2012partition; shpakova:2016).

5 Experiments

Similarly to niepert21imle, we conducted three different types of experiments. First, we analyse and compare the behaviour of AIMLE with other gradient estimators (STE, SFE, IMLE) in a synthetic setting. Second, we consider a setting where the distribution parameters $\theta$ are produced by an upstream neural model (cf. Eq. 3), and the optimal discrete structure is not available during training. Finally, we consider the problem of differentiating through black-box combinatorial solvers, where we use the target distribution derived in Eq. 6. In all our experiments, we fix the AIMLE hyper-parameters, namely the target gradient norm and the update step, based on the AIMLE implementation described in Algorithm 1. More experimental details are available in the appendix.

Figure 3: Cosine similarity between the estimated gradient and the true gradient ($y$-axis) using several estimators (IMLE, AIMLE, STE, and SFE) with a varying number of samples $S$ ($x$-axis). The gradient estimates produced by AIMLE, both in its forward and central difference versions, are significantly more similar to the true gradient than the estimates produced by other methods.
Figure 4: Cosine similarity between the estimated gradient and the true gradient ($y$-axis) using several estimators, namely IMLE with a varying $\lambda$ ($x$-axis), AIMLE, STE, and SFE.

Synthetic Experiments.

We conducted a series of experiments with a tractable categorical distribution where with and . We set the loss to , where .

In Fig. 3, we plot the cosine similarity between the true gradient and the gradient estimates produced by the Straight-Through Estimator (STE, bengio2013estimating), the Score Function Estimator (SFE, DBLP:journals/ml/Williams92), Implicit Maximum Likelihood Estimation (IMLE, niepert21imle), and AIMLE. For STE and IMLE, we use Perturb-and-MAP with Gumbel noise. For IMLE and AIMLE, we evaluated both their forward difference (Forward) and central difference (Central) versions. We evaluated all estimators using a varying number of samples $S$, and report how $S$ influences the cosine similarity between the gradient estimate and the true gradient. Statistics are computed over multiple runs. From the results, outlined in Fig. 3, we can see that AIMLE, both in its central and forward difference versions, produces significantly more accurate estimates of the true gradient compared to IMLE, STE, and SFE, with orders of magnitude fewer samples. Furthermore, we report the cosine similarity between the true and the estimated gradient for AIMLE, STE, SFE, and IMLE with a varying value of $\lambda$, where all estimators use the same number of samples. Results are outlined in Fig. 4: we can see that AIMLE is able to produce gradient estimates that are comparable to the best estimates produced by IMLE, without the need to tune a hyper-parameter.
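The comparison metric used in these experiments is straightforward to compute. The sketch below is a plain implementation of cosine similarity between two gradient vectors; returning zero for an all-zero estimate is a convention assumed here, since the text does not say how such estimates are scored.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two gradient vectors of equal length."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        # Convention assumed for this sketch: an all-zero estimate carries
        # no directional information, so it scores zero.
        return 0.0
    return dot / (norm_u * norm_v)
```

Because the metric ignores scale, an estimate that points in the right direction scores close to one even when its magnitude is off, which suits estimators such as IMLE whose scale depends on the step size.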

Method             Test MSE           Subset Prec.
                   Mean     SD        Mean      SD
Aspect: aroma
SoftSub            2.515    0.087     55.453    2.338
STE                4.660    0.053     44.593    0.523
SST                4.788    0.486     56.854    3.752
IMLE (Forward)     2.413    0.055     53.744    5.635
IMLE (Central)     2.266    0.050     50.888    5.453
AIMLE (Forward)    2.499    0.089     44.668    6.936
AIMLE (Central)    2.385    0.049     62.056    2.107
Aspect: aroma
SoftSub            2.543    0.044     44.513    2.958
STE                4.310    0.039     39.635    0.281
SST                5.213    0.295     24.328    12.463
IMLE (Forward)     2.368    0.075     48.215    2.182
IMLE (Central)     2.256    0.043     45.339    3.115
AIMLE (Forward)    2.402    0.042     48.397    1.967
AIMLE (Central)    2.419    0.061     53.260    2.271
Aspect: aroma
SoftSub            2.711    0.035     37.202    1.374
STE                4.062    0.054     36.267    0.161
SST                5.787    0.517     24.551    9.827
IMLE (Forward)     2.411    0.087     41.850    1.477
IMLE (Central)     2.508    0.396     40.057    7.172
AIMLE (Forward)    2.408    0.064     41.688    2.246
AIMLE (Central)    2.470    0.026     47.109    2.863
Table 1: Detailed results for the aspect aroma. Test MSE and subset precision (mean and standard deviation); each block corresponds to a different subset size.

Learning to Explain.

The BeerAdvocate dataset (McAuley2012LearningAA) consists of free-text reviews and ratings for different aspects of beer: appearance, aroma, palate, and taste. Each sentence in the test set has annotations providing the words that best describe the various aspects. Following the experimental setting in paulus2020gradient and niepert21imle, we address the problem, introduced by chen2018learning, of learning a distribution over $k$-subsets of words that best explain a given aspect rating. The complexity of the MAP problem for the $k$-subset distribution is linear in the number of words $n$.
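For the subset distribution, the MAP state simply selects the highest-scoring entries. A minimal sketch follows; it uses `heapq.nlargest`, which runs in roughly `n log k` time rather than strictly linear time, so it should be read as a convenient stand-in for a linear-time selection algorithm. The function name is illustrative.

```python
import heapq


def map_k_subset(theta, k):
    """Binary mask with ones at the k largest entries of theta."""
    top = heapq.nlargest(k, range(len(theta)), key=lambda i: theta[i])
    mask = [0] * len(theta)
    for i in top:
        mask[i] = 1
    return mask
```

For example, `map_k_subset([0.1, 2.0, -1.0, 0.7, 0.3], 2)` returns `[0, 1, 0, 1, 0]`, selecting the two highest-scoring positions.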

The training set has 80,000 reviews for the aspect appearance and 70,000 reviews for all other aspects. Since the original dataset (McAuley2012LearningAA) did not provide separate validation and test sets, following niepert21imle, we compute 10 different evenly sized validation and test splits of the held-out set of 10,000 reviews and compute mean and standard deviation over 10 models, each trained on one split. Subset precision was computed using a subset of 993 annotated reviews. We use pre-trained word embeddings from Lei2016RationalizingNP. We extend the implementations provided by niepert21imle, which use a neural network following the architecture introduced by paulus2020gradient with four convolutional layers and one dense layer. This neural network outputs the parameters of the distribution over $k$-hot binary latent masks.

We compare AIMLE (both the forward and central difference versions) to the relaxation-based baselines L2X (chen2018learning) and SoftSub (Xie2019ReparameterizableSS); to STE with Gumbel perturbations; and to IMLE (niepert21imle) with Gumbel perturbations. We used the standard hyperparameter settings of chen2018learning and chose the temperature parameter $\tau$ for all methods. For IMLE we chose $\lambda$ based on the validation MSE. We trained separate models for each aspect using MSE as the training loss, using the Adam (DBLP:journals/corr/KingmaB14) optimiser with its default hyper-parameters.

Table 1 lists detailed results for the aspect aroma. We can see that AIMLE, in its central difference version, systematically produces the highest subset precision values while yielding test MSE values comparable to those produced by IMLE, without requiring tuning of the $\lambda$ hyper-parameter. In the appendix, we report the results for the other aspects, where we notice that AIMLE produces significantly higher subset precision values in all other aspects as well.

Figure 5: Training dynamics for a DVAE using AIMLE and IMLE for (top) and (bottom).
Figure 6: Training dynamics for a DVAE using AIMLE and Gumbel-Softmax for .
Figure 7: Training dynamics for a DVAE using AIMLE and STE for (top) and (bottom).
Figure 8: Training dynamics for a DVAE using AIMLE with a different number of samples (i.e. ) for (top) and (bottom).
                   T=10                                                   T=20
Edge Distribution  ELBO                Edge Prec.      Edge Rec.        ELBO                Edge Prec.      Edge Rec.
SST (Hard)         -2301.47 ± 85.86    33.75 ± 9.44    60.40 ± 23.23    -3407.89 ± 221.53   57.40 ± 17.87   70.42 ± 8.22
IMLE (Forward)     -2289.94 ± 4.31     23.94 ± 0.03    95.75 ± 0.14     -3820.68 ± 25.32    20.28 ± 0.12    20.28 ± 0.12
IMLE (Central)     -2341.71 ± 41.68    43.95 ± 7.22    43.95 ± 7.22     -3447.29 ± 550.38   40.25 ± 14.26   40.25 ± 14.26
AIMLE (Forward)    -2039.00 ± 265.03   45.67 ± 12.04   45.67 ± 12.04    -1871.22 ± 120.63   43.21 ± 3.96    43.21 ± 3.96
AIMLE (Central)    -1980.46 ± 316.15   46.48 ± 14.25   46.48 ± 14.25    -1571.98 ± 30.59    82.51 ± 1.30    82.51 ± 1.30
Table 2: Latent Graph Structure Recovery. Stochastic Softmax Tricks (SST, paulus2020gradient) defining a spanning tree over undirected edges (with hard sampling), IMLE (niepert21imle), and AIMLE, where the MAP function is computed by Kruskal’s algorithm (Kruskal1956). Values are reported as mean ± standard deviation. AIMLE yields the highest test ELBO values, both in the $T=10$ (shorter sequences) and the $T=20$ (longer sequences) settings.

Discrete Variational Auto-Encoder.

Following niepert21imle, we compare IMLE, STE, and the Gumbel-Softmax trick (maddison2016concrete; jang2016categorical) using a discrete $k$-subset Variational Auto-Encoder (VAE). The latent variables model a probability distribution over $k$-subsets of (or top-$k$ assignments to) binary vectors of length 20; note that, for $k = 1$, this is equivalent to a categorical variable with 20 categories. We follow the implementation details in niepert21imle, where the encoder and the decoder of the VAE consist of three dense layers, with encoder and decoder activations of sizes 512-256-20×20 and 256-512-784, respectively. The loss is the sum of the reconstruction loss (binary cross-entropy on the output pixels) and the KL divergence between the marginals of the variables and the uniform distribution.
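The loss described above can be sketched for a single example as follows. This is an illustrative simplification that assumes the categorical case, where the marginals form a probability vector over the categories; the function names and the clipping constant `eps` are choices made for this sketch and are not taken from the reference implementation.

```python
import math


def binary_cross_entropy(pixels, reconstruction, eps=1e-7):
    """Summed binary cross-entropy between target pixels and predictions."""
    total = 0.0
    for x, x_hat in zip(pixels, reconstruction):
        x_hat = min(max(x_hat, eps), 1.0 - eps)  # avoid log(0)
        total -= x * math.log(x_hat) + (1.0 - x) * math.log(1.0 - x_hat)
    return total


def kl_to_uniform(marginals):
    """KL divergence from a probability vector to the uniform distribution."""
    n = len(marginals)
    return sum(q * math.log(q * n) for q in marginals if q > 0.0)


def dvae_loss(pixels, reconstruction, marginals):
    """Reconstruction term plus KL regulariser for one example."""
    return binary_cross_entropy(pixels, reconstruction) + kl_to_uniform(marginals)
```

The KL term is zero when the marginals are exactly uniform and grows to the logarithm of the number of categories when all mass sits on a single category.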

For AIMLE, IMLE, and STE, we use perturbations for , and Sum-of-Gamma (SoG, niepert21imle) perturbations for , with a temperature of . For IMLE, we select the hyper-parameter on a held-out validation set, based on the validation loss. For IMLE and AIMLE, we report the results for both the forward and central difference versions. We train the DVAE for epochs and report the loss on a held-out test set.

In Fig. 5 we show the test losses for IMLE and AIMLE (forward and central difference versions), for and . We can see that, for , AIMLE produces significantly lower test loss values than IMLE and, for , the central difference version of AIMLE produces test loss values comparable to IMLE. We also compare AIMLE to the Gumbel-Softmax trick (see Fig. 6) and STE (see Fig. 7): AIMLE produces significantly lower test losses than STE () while producing higher test losses than Gumbel-Softmax (). We also experimented with increasing the number of samples in AIMLE (see Fig. 8), finding that higher values of produce significantly lower test loss values, to the point that, for , AIMLE with produces lower test loss values than Gumbel-Softmax, but with a higher computational cost.

Neural Relational Inference for Recovering Latent Graph Structures.

In this experiment, we investigated the use of AIMLE for recovering latent graph structures and predicting the evolution of a dynamical system. In Neural Relational Inference (NRI, DBLP:conf/icml/KipfFWWZ18), a Graph Neural Network (GNN, DBLP:journals/tnn/Micheli09; DBLP:journals/tnn/ScarselliGTHM09) encoder is used to generate a latent interaction graph, which is then used to produce a distribution over an interacting particle system. NRI is trained as a variational auto-encoder to maximise a lower bound (ELBO) on the marginal log-likelihood of the time series. Based on the implementation provided by paulus2020gradient, we compared the Stochastic Softmax Tricks (SST, paulus2020gradient) encoder that induces a spanning tree over undirected edges, with an encoder producing a maximum spanning tree using Kruskal’s algorithm (Kruskal1956), and using either IMLE or AIMLE to back-propagate through it. In this setting, Kruskal’s algorithm represents the MAP estimator for a distribution over latent graph structures. Our dataset consisted of latent prior spanning trees over 10 vertices sampled from the prior. Given a tree, we embed the vertices in by applying iterations of a force-directed algorithm (DBLP:journals/spe/FruchtermanR91). The model saw particle locations at each iteration, not the underlying spanning tree. We found that AIMLE performed best, improving on both ELBO and the recovery of latent structure over the structured SST baseline proposed by paulus2020gradient.
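Kruskal's algorithm as a MAP oracle for spanning trees can be sketched as follows: edges are visited in order of decreasing weight and kept whenever they connect two different components. The function name and the dictionary-based edge representation are assumptions for this example; the implementation used in the experiments may differ.

```python
def maximum_spanning_tree(num_vertices, edge_weights):
    """Kruskal's algorithm on a dict mapping (u, v) edges to weights."""
    parent = list(range(num_vertices))

    def find(v):
        # Union-find lookup with path halving.
        while parent[v] != v:
            parent[v] = parent[parent[v]]
            v = parent[v]
        return v

    tree = []
    ordered = sorted(edge_weights.items(), key=lambda kv: kv[1], reverse=True)
    for (u, v), _ in ordered:
        root_u, root_v = find(u), find(v)
        if root_u != root_v:
            # The edge joins two components, so it belongs to the tree.
            parent[root_u] = root_v
            tree.append((u, v))
    return sorted(tree)
```

With edge weights `{(0, 1): 4.0, (0, 2): 1.0, (1, 2): 3.0, (1, 3): 2.0, (2, 3): 5.0}` on four vertices, the function returns `[(0, 1), (1, 2), (2, 3)]`. In the setting described above, the weights would be the perturbed edge parameters produced by the encoder.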

6 Conclusions

We introduced Adaptive Implicit Maximum Likelihood Estimation (AIMLE), an efficient, simple-to-implement, and general-purpose framework for learning hybrid models. AIMLE is an extension of IMLE (niepert21imle) that, during training, can dynamically select the optimal target distribution by identifying the update step that yields the desired gradient norm. Furthermore, AIMLE incorporates insights from finite difference methods, improving its effectiveness in gradient estimation tasks. In our experiments, we show that AIMLE produces better results than relaxation-based approaches for discrete latent variable models, and approaches that back-propagate through black-box combinatorial solvers.

A limitation of this work, and a potential future extension, is that it relies on a warm-up period for selecting the optimal $\lambda$, whose duration varies depending on the update step, which we keep fixed. A potential solution to this problem is to use adaptive update steps, for instance by using momentum (DBLP:journals/nn/Qian99).


Appendix A Learning to Explain Experiments

Method             Test MSE           Subset Prec.
                   Mean     SD        Mean      SD
Aspect: appearance
SoftSub            2.327    0.051     72.183    5.242
STE                5.315    0.070     29.134    1.503
SST                4.157    0.562     76.059    7.957
IMLE (Forward)     2.420    0.066     54.457    7.165
IMLE (Central)     2.263    0.086     69.679    12.590
AIMLE (Forward)    2.512    0.058     56.904    4.149
AIMLE (Central)    2.266    0.071     81.119    2.345
Aspect: appearance
SoftSub            2.349    0.064     63.667    5.215
STE                4.967    0.049     33.629    2.268
SST                4.973    0.443     60.363    12.225
IMLE (Forward)     2.300    0.071     54.667    6.989
IMLE (Central)     2.168    0.051     63.249    6.251
AIMLE (Forward)    2.305    0.065     59.840    4.859
AIMLE (Central)    2.191    0.037     72.167    2.744
Aspect: appearance
SoftSub            2.534    0.132     54.097    7.032
STE                4.493    0.083     37.737    3.298
SST                5.537    0.605     36.217    18.978
IMLE (Forward)     2.245    0.059     52.654    3.913
IMLE (Central)     2.192    0.047     51.246    9.288
AIMLE (Forward)    2.282    0.068     52.188    4.306
AIMLE (Central)    2.195    0.046     68.155    3.206
Table 3: Detailed results for the aspect appearance. Test MSE and subset precision (mean and standard deviation); each block corresponds to a different subset size.
Method | Test MSE Mean | Test MSE SD | Subset Prec. Mean | Subset Prec. SD
Aspect: palate,
SoftSub () | 2.857 | 0.049 | 51.435 | 2.258
STE () | 4.456 | 0.038 | 30.650 | 0.449
SST () | 4.357 | 0.596 | 50.634 | 4.587
IMLE (Forward, , ) | 2.867 | 0.042 | 50.066 | 1.262
IMLE (Central, , ) | 2.684 | 0.047 | 54.347 | 1.320
AIMLE (Forward, ) | 2.856 | 0.050 | 47.818 | 4.231
AIMLE (Central, ) | 2.683 | 0.037 | 56.043 | 1.540
Aspect: palate,
SoftSub () | 2.957 | 0.048 | 35.604 | 1.818
STE () | 4.065 | 0.039 | 32.096 | 0.501
SST () | 4.754 | 0.260 | 37.302 | 7.304
IMLE (Forward, , ) | 2.833 | 0.039 | 43.305 | 2.993
IMLE (Central, , ) | 2.669 | 0.046 | 45.258 | 2.335
AIMLE (Forward, ) | 2.837 | 0.062 | 40.777 | 3.054
AIMLE (Central, ) | 2.666 | 0.020 | 49.895 | 3.750
Aspect: palate,
SoftSub () | 3.138 | 0.051 | 27.927 | 1.651
STE () | 3.849 | 0.047 | 30.727 | 0.947
SST () | 5.171 | 0.645 | 24.095 | 9.535
IMLE (Forward, , ) | 2.892 | 0.058 | 35.700 | 2.847
IMLE (Central, , ) | 2.808 | 0.210 | 36.916 | 6.323
AIMLE (Forward, ) | 2.885 | 0.027 | 36.070 | 2.279
AIMLE (Central, ) | 2.708 | 0.031 | 44.860 | 1.549
Table 5: Detailed results for the aspect palate. Test MSE and subset precision, both , for .
Method | Test MSE Mean | Test MSE SD | Subset Prec. Mean | Subset Prec. SD
Aspect: taste,
SoftSub () | 2.196 | 0.045 | 42.762 | 1.785
STE () | 4.591 | 0.047 | 39.360 | 0.402
SST () | 3.942 | 0.354 | 35.376 | 4.146
IMLE (Forward, , ) | 2.253 | 0.036 | 39.846 | 1.855
IMLE (Central, , ) | 2.124 | 0.060 | 41.105 | 1.054
AIMLE (Forward, ) | 2.196 | 0.054 | 38.344 | 2.007
AIMLE (Central, ) | 2.132 | 0.043 | 43.539 | 1.175
Aspect: taste,
SoftSub () | 2.137 | 0.048 | 42.716 | 1.085
STE () | 4.283 | 0.053 | 38.089 | 0.089
SST () | 4.422 | 0.176 | 30.581 | 4.077
IMLE (Forward, , ) | 2.200 | 0.052 | 40.673 | 1.756
IMLE (Central, , ) | 2.081 | 0.054 | 39.636 | 1.219
AIMLE (Forward, ) | 2.190 | 0.030 | 41.098 | 1.525
AIMLE (Central, ) | 2.111 | 0.034 | 45.593 | 1.222
Aspect: taste,
SoftSub () | 2.173 | 0.043 | 40.189 | 1.318
STE () | 4.009 | 0.035 | 37.364 | 0.058
SST () | 4.868 | 0.264 | 32.507 | 3.649
IMLE (Forward, , ) | 2.201 | 0.052 | 40.636 | 1.479
IMLE (Central, , ) | 2.193 | 0.270 | 38.987 | 1.680
AIMLE (Forward, ) | 2.196 | 0.046 | 40.993 | 1.627
AIMLE (Central, ) | 2.138 | 0.041 | 43.089 | 1.780
Table 6: Detailed results for the aspect taste. Test MSE and subset precision, both , for .

Here we report the results on all aspects of the Learning to Explain task with the BeerAdvocate dataset (McAuley et al. 2012), namely appearance (Table 3), aroma (Table 4), palate (Table 5) and taste (Table 6). In every setting, AIMLE (Central) attains the highest mean subset precision, while its test MSE remains comparable with that of IMLE.

Each experiment was repeated 10 times, for 20 training epochs each; the batch size was set to 40, with a kernel size of 3, a hidden dimension of 250, and a maximum sequence length of 350. All models were trained using the Adam optimiser (Kingma and Ba 2014), with a learning rate of , by fitting an MSE loss between the predicted and true ratings. For all methods, the noise temperature was selected in . In IMLE, the hyper-parameter was selected in . In AIMLE, the hyper-parameters (the target gradient norm) and (the update step) were fixed to their default values, namely and .
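For readers unfamiliar with the metric, the sketch below shows how the table entries could be computed. It assumes the common definition of subset precision (the share of selected tokens that fall inside the human-annotated rationale), a percentage scale, and the sample standard deviation across runs; none of these details is stated in the text above, and the function names are ours.

```python
from statistics import mean, stdev


def subset_precision(selected, annotated):
    """Percentage of selected token indices that lie in the annotated rationale."""
    selected = set(selected)
    if not selected:
        return 0.0
    return 100.0 * len(selected & set(annotated)) / len(selected)


def summarize(values):
    """Mean and sample standard deviation over repeated runs."""
    return mean(values), stdev(values)


# Toy usage with made-up numbers (not results from the paper).
per_run_precision = [
    subset_precision([1, 2, 3, 4], [2, 3, 9]),
    subset_precision([0, 2, 3, 9], [2, 3, 9]),
]
print(summarize(per_run_precision))
```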

Appendix B Discrete Variational Auto-Encoder Experiments

The dataset can be loaded in PyTorch with torchvision.datasets.MNIST. As in prior work, we use a batch size of and train for epochs, plotting the test loss after each epoch. We use the standard Adam settings in PyTorch and no learning rate schedule. The MNIST dataset consists of black-and-white 28 × 28 pixel images of hand-written digits. The encoder network consists of an input layer with dimension 784, since we flatten the images; a dense layer with dimension and ReLU (Nair and Hinton 2010) activations; a dense layer with dimension and ReLU activations; and a dense layer with dimension () which outputs the and no non-linearity. The layer implementing AIMLE receives as input and outputs a discrete latent code of shape . The decoder, which takes this discrete latent code as input, consists of a dense layer with dimension and ReLU activation; a dense layer with dimension and ReLU activations; and finally a dense layer with dimension 784 returning the logits for the output pixels. Sigmoid non-linearities are applied to these logits and used to compute the binary cross-entropy.
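The reconstruction term described above (a sigmoid applied to the decoder logits, followed by binary cross-entropy) can be sketched in plain Python as follows. This is a minimal illustration under our own naming, using the standard numerically stable form that works on logits directly; it is not the authors' implementation.

```python
import math


def bce_with_logits(logit, target):
    """Binary cross-entropy between sigmoid(logit) and a target in [0, 1].

    Uses max(x, 0) - x * t + log(1 + exp(-|x|)), which avoids overflow
    for logits of large magnitude.
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def reconstruction_loss(logits, pixels):
    """Sum of per-pixel losses for one flattened image."""
    return sum(bce_with_logits(x, t) for x, t in zip(logits, pixels))


# Toy usage on a three-pixel "image" (made-up values).
print(reconstruction_loss([2.0, -1.0, 0.0], [1.0, 0.0, 1.0]))
```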

Appendix C Neural Relational Inference Experiments

We use the dataset of latent prior spanning trees over 10 vertices proposed by Paulus et al. (2020): latent spanning trees were sampled by applying Kruskal’s algorithm (Kruskal 1956) to for a fully-connected graph. Initial vertex locations were sampled from in . Given the initial locations and the latent tree, the dynamical observations were obtained by applying a force-directed algorithm for graph layout for iterations. Then, the initial vertex positions were discarded, because the first iteration of the layout algorithm typically results in large relocations. Hence, the final dataset used for training consisted of 10 and 20 location observations in for each of the 10 vertices. Following this procedure, we generated a training set of size 50,000 and validation and test sets of size 10,000.
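The tree-sampling step can be sketched as follows: draw a random weight for every edge of the complete graph, then run Kruskal's algorithm on those weights. The weight distribution is not recoverable from the text above, so the uniform weights used here are a placeholder assumption, and the function name is ours.

```python
import random


def sample_spanning_tree(n, rng):
    """Sample a spanning tree of the complete graph on n vertices.

    Edge weights are drawn uniformly at random (an assumption), and
    Kruskal's algorithm keeps the lightest edges that do not form a cycle.
    """
    edges = [(i, j) for i in range(n) for j in range(i + 1, n)]
    weights = {edge: rng.random() for edge in edges}

    parent = list(range(n))  # union-find forest over the vertices

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    tree = []
    for u, v in sorted(edges, key=weights.get):
        root_u, root_v = find(u), find(v)
        if root_u != root_v:  # the edge joins two components, so keep it
            parent[root_u] = root_v
            tree.append((u, v))
    return tree


tree = sample_spanning_tree(10, random.Random(0))
print(len(tree), "edges:", tree)
```

A spanning tree over 10 vertices always has 9 edges, which gives a quick sanity check on the output.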


We follow the design and implementation provided by Paulus et al. (2020), where the NRI model consists of two neural modules, an encoder and a decoder. The encoder GNN passes messages over the fully connected directed graph with nodes. We take the final edge representations produced by the GNN and use them as . The final edge representations are in , where over the 90 directed edges (since we do not consider self-connections). Given the previous time-step data, the decoder GNN passes messages over the sampled graph adjacency matrix and predicts future node positions. As in Kipf et al. (2018) and Paulus et al. (2020), we used teacher-forcing every 10 time-steps. , in this case, was a directed adjacency matrix over the graph where are the nodes: is interpreted as there being an edge from to , and otherwise. represents the symmetric, directed adjacency matrix with edges in both directions for each undirected edge. The decoder passes messages between both connected and not-connected nodes. When considering a message from to , it uses one network for the edges such that , and another network for the edges such that .
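The bookkeeping in this paragraph can be illustrated with a short sketch. A fully connected directed graph over 10 nodes without self-connections has 10 × 9 = 90 ordered node pairs, and each undirected tree edge contributes two entries to the symmetric adjacency matrix. The function names and the toy edge list are our own assumptions.

```python
def directed_edge_list(n):
    """All ordered node pairs (i, j) with i != j, i.e. no self-connections."""
    return [(i, j) for i in range(n) for j in range(n) if i != j]


def symmetric_adjacency(n, undirected_edges):
    """Binary adjacency matrix with an entry in both directions per edge."""
    adj = [[0] * n for _ in range(n)]
    for u, v in undirected_edges:
        adj[u][v] = 1
        adj[v][u] = 1
    return adj


def split_by_connectivity(adj):
    """Partition ordered pairs into connected and not-connected ones,
    mirroring the two message networks used by the decoder."""
    n = len(adj)
    connected = [(i, j) for i, j in directed_edge_list(n) if adj[i][j] == 1]
    not_connected = [(i, j) for i, j in directed_edge_list(n) if adj[i][j] == 0]
    return connected, not_connected


# Toy usage: a path graph over 10 nodes stands in for a sampled tree.
path_edges = [(i, i + 1) for i in range(9)]
adjacency = symmetric_adjacency(10, path_edges)
connected, not_connected = split_by_connectivity(adjacency)
print(len(directed_edge_list(10)), len(connected), len(not_connected))
```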


We selected the hyper-parameters for SST (Spanning Tree, Hard), IMLE, and AIMLE using a grid search, choosing the configuration that yields the highest ELBO on the validation set. For SST, we searched the learning rate in , and the temperature in . For IMLE, we searched the learning rate in , , and the noise temperature in . For AIMLE, we searched the learning rate in , and the noise temperature in .
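The selection procedure amounts to an exhaustive search over a small grid, keeping the configuration with the highest validation score. A minimal sketch follows; the grid values and the scoring function are placeholders we made up, since the actual search ranges are not recoverable from the text above.

```python
from itertools import product


def grid_search(grid, evaluate):
    """Return the configuration in `grid` that maximizes `evaluate`.

    `grid` maps hyper-parameter names to candidate values; `evaluate`
    maps a configuration dict to a score such as the validation ELBO.
    """
    names = list(grid)
    best_config, best_score = None, float("-inf")
    for values in product(*(grid[name] for name in names)):
        config = dict(zip(names, values))
        score = evaluate(config)
        if score > best_score:
            best_config, best_score = config, score
    return best_config, best_score


# Hypothetical grid and a stand-in objective (not the paper's values).
toy_grid = {"learning_rate": [0.1, 0.01], "noise_temperature": [0.5, 1.0]}


def toy_validation_elbo(config):
    return -((config["learning_rate"] - 0.01) ** 2) - (config["noise_temperature"] - 1.0) ** 2


print(grid_search(toy_grid, toy_validation_elbo))
```

In practice `evaluate` would train a model with the given configuration and return its validation ELBO, so the cost grows with the product of the grid sizes.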