1 Introduction
There is growing interest in end-to-end learnable models that incorporate discrete algorithms, for example to sample from discrete latent distributions (jang2016categorical; paulus2020gradient) or to solve combinatorial optimisation problems (poganvcic2019differentiation; Mandi_Guns:2020; niepert21imle). These discrete components are not continuously differentiable, and an important problem is to efficiently estimate gradients with respect to their inputs in order to perform backpropagation. Reinforcement learning, discrete Energy-Based Models (EBMs, lecun2006tutorial), learning to explain (chen2018learning), discrete Variational Auto-Encoders (VAEs, DBLP:journals/corr/KingmaW13), and discrete world models (hafner2020mastering) are additional examples of neural network-based architectures that require the ability to back-propagate through expectations of discrete probability distributions.
The main challenge these approaches have in common is the problem of (approximately) computing gradients of an expectation of a continuously differentiable function $f$:
$$\nabla_{\theta}\, \mathbb{E}_{z \sim p(z;\theta)}\left[f(z)\right], \tag{1}$$
where the expectation is taken over a complex discrete probability distribution $p(z;\theta)$ with intractable marginals and normalisation constant.
In principle, one could use the Score Function Estimator (SFE, DBLP:journals/ml/Williams92). Unfortunately, it suffers from high variance, which is typically exacerbated by the distribution $p(z;\theta)$ being intractable. Implicit Maximum Likelihood Estimation (IMLE, niepert21imle), a recently proposed general-purpose gradient estimation technique, has shown lower variance and outperformed other existing methods, including the score function estimator and problem-specific continuous relaxations, in several settings (niepert21imle; betz:2021; qian2022ordered). For instance, for the synthetic problem in Fig. 1, the gradient estimate produced by SFE is worse than the estimate produced by IMLE with two orders of magnitude fewer samples, due to the high variance of the SFE. IMLE combines Perturb-and-MAP sampling with a finite difference method for implicit differentiation originally developed for loss functions defined over marginals (domke:2010). In IMLE, gradients are approximated as:
$$\nabla_{\theta}\, \mathbb{E}_{z \sim p(z;\theta)}\left[f(z)\right] \approx \frac{1}{\lambda}\Big\{ \mathrm{MAP}(\theta + \epsilon) - \mathrm{MAP}\big(\theta - \lambda \nabla_{z} f(z) + \epsilon\big) \Big\}, \tag{2}$$
where $z = \mathrm{MAP}(\theta + \epsilon)$ is a maximum-probability state of the distribution $p(z;\theta + \epsilon)$, $\epsilon \sim \rho(\epsilon)$ is a perturbation drawn from a noise distribution, and $\lambda > 0$. Computing the MAP states instead of sampling from $p(z;\theta)$ is especially interesting since, in many cases, it has lower computational complexity than sampling from the corresponding distribution (niepert21imle).
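As an illustration of the estimator described above, the following is a minimal sketch of a single-sample forward-difference estimate for a categorical distribution, where the MAP state is the one-hot arg max. The function names, the quadratic downstream function, and the step size `lam` are assumptions made for this example and are not taken from the authors' implementation.

```python
import math
import random


def sample_gumbel(rng):
    # Standard Gumbel(0, 1) sample via inverse transform sampling.
    u = rng.random()
    while u == 0.0:  # avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))


def map_one_hot(theta):
    # MAP state of a categorical distribution: one-hot vector at the arg max.
    best = max(range(len(theta)), key=lambda i: theta[i])
    return [1.0 if i == best else 0.0 for i in range(len(theta))]


def imle_forward(theta, grad_f, lam, rng):
    # Single-sample estimate:
    # (MAP(theta + eps) - MAP(theta - lam * grad_f(z) + eps)) / lam
    eps = [sample_gumbel(rng) for _ in theta]
    z = map_one_hot([t + e for t, e in zip(theta, eps)])
    g = grad_f(z)
    z_target = map_one_hot([t - lam * gi + e for t, gi, e in zip(theta, g, eps)])
    return [(a - b) / lam for a, b in zip(z, z_target)]


# Hypothetical downstream function f(z) = ||z - b||^2 with gradient 2 (z - b).
b = [1.0, 0.0, 0.0]
grad_f = lambda z: [2.0 * (zi - bi) for zi, bi in zip(z, b)]

rng = random.Random(0)
estimate = imle_forward([0.0, 0.5, -0.5], grad_f, lam=10.0, rng=rng)
print(estimate)
```

Both MAP states are computed with the same noise sample, so the estimate is non-zero only when the perturbation changes the arg max.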
Crucially, the parameter $\lambda$ determines the step size of the finite difference approximation. When the input to $f$ is given by the continuously differentiable marginals of $p(z;\theta)$, smaller values of $\lambda$ lead to less biased estimates. Hence, in this setting $\lambda$ is typically set to a value that depends on the machine precision so as to avoid numerical instabilities (domke:2010). In the setting we consider, however, the input to $f$ is discrete and discontinuous. Setting $\lambda$ to a very small value, in this case, results in zero gradients. This is illustrated in Fig. 2 (right) for the forward difference method. Hence, the crucial insight is that $\lambda$ trades off the bias and the sparsity of the gradient approximation. In Fig. 1 (top) and (bottom) we plot, respectively, the bias and the sparsity of the gradient estimates for different values of $\lambda$ on a toy optimisation problem. As we can see, larger values of $\lambda$ result in a higher bias, while low values of $\lambda$ result in gradient estimates that are almost always zero. In this paper, we propose to make the parameter $\lambda$ adaptive. We also provide empirical results showing that making $\lambda$ adaptive reduces the bias and improves the results on several benchmark problems.
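The sparsity side of this trade-off can be reproduced on a toy categorical problem. The sketch below measures the fraction of single-sample forward-difference estimates that are exactly zero as the step size grows; the parameters, the target vector `b`, and the step sizes are hypothetical values chosen only for illustration.

```python
import math
import random


def gumbel(rng):
    # Standard Gumbel(0, 1) sample; the clamp keeps both logarithms finite.
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))


def argmax(values):
    return max(range(len(values)), key=lambda i: values[i])


def zero_fraction(theta, b, lam, n_samples, seed=0):
    # Fraction of samples for which the perturbed MAP state equals the
    # unperturbed one, i.e. the forward-difference estimate is all zeros.
    rng = random.Random(seed)
    zeros = 0
    for _ in range(n_samples):
        eps = [gumbel(rng) for _ in theta]
        i = argmax([t + e for t, e in zip(theta, eps)])
        # Gradient of f(z) = ||z - b||^2 at the one-hot state z = e_i.
        g = [2.0 * ((1.0 if k == i else 0.0) - bk) for k, bk in enumerate(b)]
        j = argmax([t - lam * gk + e for t, gk, e in zip(theta, g, eps)])
        zeros += int(i == j)
    return zeros / n_samples


theta = [0.0, 0.5, -0.5]
b = [1.0, 0.0, 0.0]
for lam in (0.0, 1e-3, 1.0, 100.0):
    print(lam, zero_fraction(theta, b, lam, n_samples=200))
```

With a vanishing step size every estimate is zero, while larger step sizes move more perturbed MAP states away from the sampled one.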
2 Problem Definition
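We consider the problem of computing the gradients of an expectation over a discrete probability distribution of a continuously differentiable function $f$, that is,
$$\nabla_{\theta}\, \mathbb{E}_{z \sim p(z;\theta)}\left[f(z)\right], \tag{3}$$
where $p(z;\theta)$ is a discrete probability distribution over binary vectors $z$ with parameters $\theta$. More specifically, we are concerned with settings where $p(z;\theta)$ is a discrete probability distribution with an intractable normalisation constant. Moreover, we assume that the function $f$ is a parameterised, non-trivial, continuously differentiable function, which makes existing approaches such as direct loss minimisation and perturbed optimisers (berthet2020learning) not directly applicable.
More formally, let $\theta \in \mathbb{R}^{m}$ be a real-valued parameter vector. The probability mass function (PMF) of a discrete constrained exponential family r.v. is:
$$p(z;\theta) = \begin{cases} \exp\big(\langle z, \theta \rangle - A(\theta)\big) & \text{if } z \in \mathcal{C}, \\ 0 & \text{otherwise.} \end{cases} \tag{4}$$
Here, $\langle \cdot, \cdot \rangle$ is the inner product. $A(\theta)$ is the log-partition function, defined as $A(\theta) = \log \sum_{z \in \mathcal{C}} \exp(\langle z, \theta \rangle)$, and $\mathcal{C}$ is an integral polytope of feasible configurations $z$. We call $\langle z, \theta \rangle$ the weight of the state $z$. The marginals (expected value, mean) of the r.v.s are defined as $\mu(\theta) = \mathbb{E}_{z \sim p(z;\theta)}[z]$. Finally, the most probable states, also referred to as the Maximum A-Posteriori (MAP) states, are defined as $\mathrm{MAP}(\theta) = \arg\max_{z \in \mathcal{C}} \langle z, \theta \rangle$.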
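For a small feasible set, the quantities defined above (weights, log-partition function, PMF, marginals, and MAP state) can be computed by brute-force enumeration. The sketch below does so for binary vectors with exactly `k` ones; the choice of this feasible set and all function names are assumptions made for illustration, and enumeration is only viable for tiny problems.

```python
import math
from itertools import combinations


def states_k_subsets(n, k):
    # Feasible configurations: binary vectors of length n with exactly k ones.
    for idx in combinations(range(n), k):
        yield tuple(1 if i in idx else 0 for i in range(n))


def weight(z, theta):
    # Inner product between a state and the parameter vector.
    return sum(zi * ti for zi, ti in zip(z, theta))


def log_partition(theta, k):
    # Log-sum-exp over all feasible states, shifted for numerical stability.
    ws = [weight(z, theta) for z in states_k_subsets(len(theta), k)]
    m = max(ws)
    return m + math.log(sum(math.exp(w - m) for w in ws))


def pmf(z, theta, k):
    return math.exp(weight(z, theta) - log_partition(theta, k))


def marginals(theta, k):
    # Expected value of the state vector under the distribution.
    mu = [0.0] * len(theta)
    for z in states_k_subsets(len(theta), k):
        p = pmf(z, theta, k)
        for i, zi in enumerate(z):
            mu[i] += p * zi
    return mu


def map_state(theta, k):
    # Most probable state: the feasible state with the largest weight.
    return max(states_k_subsets(len(theta), k), key=lambda z: weight(z, theta))


theta = [1.0, -0.5, 2.0, 0.0]
print(map_state(theta, 2), marginals(theta, 2))
```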
3 AIMLE: Adaptive Implicit Maximum-Likelihood Learning
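We base our estimator on a finite difference method for implicit differentiation (domke:2010; niepert21imle), which is generally applicable to any discrete distribution as defined in Eq. 4. For the initial derivation, we assume that we can compute exact samples from the distribution $p(z;\theta)$ using Perturb-and-MAP (Papandreou:2011) with noise distribution $\rho(\epsilon)$. We write $\theta + \epsilon$ for a perturbation of the parameters $\theta$ by a sample $\epsilon$ from the noise distribution $\rho(\epsilon)$. In general, exact sampling is not possible for complex distributions, and using approximate Perturb-and-MAP samples introduces a bias in the gradient estimates. For now, however, we assume that these perturbations are exact, i.e., that $\mathrm{MAP}(\theta + \epsilon) \sim p(z;\theta)$. Under these assumptions, and by invoking the law of the unconscious statistician (mohamed2019monte), we obtain:
$$\nabla_{\theta}\, \mathbb{E}_{z \sim p(z;\theta)}\left[f(z)\right] = \nabla_{\theta}\, \mathbb{E}_{\epsilon \sim \rho(\epsilon)}\left[f\big(\mathrm{MAP}(\theta + \epsilon)\big)\right].$$
We now approximate $\mathrm{MAP}(\theta + \epsilon)$ by $\mu\big((\theta + \epsilon)/\tau\big)$ for a small $\tau > 0$. That is, we replace a MAP state with a corresponding vector representing the marginal probabilities, where we can make the probabilities increasingly spikier by lowering a temperature parameter $\tau$. The approximation error between these two terms can be made arbitrarily small, since:
$$\mathrm{MAP}(\theta + \epsilon) = \lim_{\tau \to 0} \mu\big((\theta + \epsilon)/\tau\big).$$
The above equality holds almost everywhere if the noise distribution is such that the probability of two or more components of $\theta + \epsilon$ being equal is zero. This is the case in the standard setting where $\epsilon \sim \mathrm{Gumbel}(0, 1)$ (Papandreou:2011). Therefore, we can write that, for some $\tau > 0$,
$$\nabla_{\theta}\, \mathbb{E}_{\epsilon \sim \rho(\epsilon)}\left[f\big(\mathrm{MAP}(\theta + \epsilon)\big)\right] \approx \nabla_{\theta}\, \mathbb{E}_{\epsilon \sim \rho(\epsilon)}\left[f\big(\mu((\theta + \epsilon)/\tau)\big)\right].$$
Writing the expectation as an integral, we have:
$$\nabla_{\theta}\, \mathbb{E}_{\epsilon \sim \rho(\epsilon)}\left[f\big(\mu((\theta + \epsilon)/\tau)\big)\right] = \nabla_{\theta} \int \rho(\epsilon)\, f\big(\mu((\theta + \epsilon)/\tau)\big)\, d\epsilon. \tag{5}$$
Now we can exchange differentiation and integration since, for finite $\tau > 0$, $f$ and $\mu$ are continuously differentiable:
$$\nabla_{\theta} \int \rho(\epsilon)\, f\big(\mu((\theta + \epsilon)/\tau)\big)\, d\epsilon = \int \rho(\epsilon)\, \nabla_{\theta} f\big(\mu((\theta + \epsilon)/\tau)\big)\, d\epsilon \approx \mathbb{E}_{\epsilon \sim \rho(\epsilon)}\left[\frac{1}{\lambda}\Big\{ \mu\big((\theta + \epsilon)/\tau\big) - \mu\big((\theta - \lambda \nabla_{\mu} f(\mu) + \epsilon)/\tau\big) \Big\}\right].$$
The last step uses implicit differentiation by perturbation (domke:2010), which is a finite difference method for computing the gradients of a function defined on marginals. Finally, we again approximate the expression $\mu(\cdot/\tau)$ with $\mathrm{MAP}(\cdot)$ and obtain:
$$\nabla_{\theta}\, \mathbb{E}_{z \sim p(z;\theta)}\left[f(z)\right] \approx \mathbb{E}_{\epsilon \sim \rho(\epsilon)}\left[\frac{1}{\lambda}\Big\{ \mathrm{MAP}(\theta + \epsilon) - \mathrm{MAP}\big(\theta - \lambda \nabla_{z} f(z) + \epsilon\big) \Big\}\right], \quad z = \mathrm{MAP}(\theta + \epsilon). \tag{6}$$
Again, the approximation error of the above expression is arbitrarily small (but not zero), because the derivation shown here is valid for any $\tau > 0$. A finite sample approximation of Eq. 6 results in the IMLE gradient estimator of Eq. 2 given in the introduction. While we could in principle use the marginals as input to the function $f$, as in relaxed gradient estimators (maddison2016concrete; jang2016categorical; paulus2020gradient), computing marginals for the complex distributions we consider here is not tractable in general, and we have to use approximate Perturb-and-MAP samples.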
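The temperature argument used in the derivation can be checked numerically in the simplest case. For a categorical distribution the marginals are the softmax of the parameters, so dividing the parameters by a decreasing temperature should drive the marginals towards the one-hot MAP state. The sketch below assumes this categorical case and uses hypothetical parameter values.

```python
import math


def softmax(scores):
    # Numerically stable softmax.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def marginals_at_temperature(theta, tau):
    # Marginals of a categorical distribution with parameters theta / tau.
    return softmax([t / tau for t in theta])


theta = [0.3, 1.2, -0.7]
for tau in (1.0, 0.1, 0.01):
    print(tau, [round(p, 4) for p in marginals_at_temperature(theta, tau)])
```

The limit is a one-hot vector only when the largest parameter is unique, which is the tie-free condition stated in the text.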
3.1 An Adaptive Optimiser for Finite-Difference based Implicit Differentiation
An important observation that motivates the proposed adaptive version of IMLE is that we need to choose a hyperparameter $\lambda$ for Eq. 2. Choosing a very small $\lambda$ leads to most gradients being zero. As a consequence, the gradients back-propagated to the upstream model are zero, which prevents the upstream model from being trained. If we choose $\lambda$ too large, we obtain less sparse gradients, but the gradients are also more biased. Hence, we propose an optimiser that adapts $\lambda$ during training so as to trade off non-zero but biased and sparse but unbiased gradients. Similar to adaptive first-order optimisers in deep learning, we replace a single hyperparameter with a set of new ones, but show that we obtain consistently better results when using default hyperparameters for the adaptive method.
Normalisation of the perturbation strength.
Our first observation is that the magnitude of the perturbation in the direction of the negative downstream gradient in Eq. 6 highly depends on $\nabla_{z} f(z)$, the gradient of the downstream function $f$. To mitigate the variations in the norm of the downstream gradient relative to the parameters $\theta$, we propose to set the perturbation magnitude (the norm of the difference between $\theta$ and the perturbed parameters $\theta' = \theta - \lambda \nabla_{z} f(z)$) to be a fraction of the norm of the parameter vector $\theta$. In particular, let $\alpha > 0$ be such a fraction; then we seek $\lambda$ such that:
$$\|\theta - \theta'\|_{2} = \alpha \|\theta\|_{2} \iff \lambda \|\nabla_{z} f(z)\|_{2} = \alpha \|\theta\|_{2} \iff \lambda = \alpha \frac{\|\theta\|_{2}}{\|\nabla_{z} f(z)\|_{2}}. \tag{7}$$
This way we ensure that a global value for $\alpha$ roughly translates to the same input-specific magnitude of the perturbation in the direction of the negative gradient.
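The normalisation can be read as solving for the step size that makes the perturbation norm equal to a fixed fraction of the parameter norm, which gives step size = fraction × ‖parameters‖ / ‖downstream gradient‖. The sketch below assumes the Euclidean norm and adds a small guard against a zero downstream gradient; both the guard and the names `alpha` and `normalised_step` are assumptions of this example.

```python
import math


def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))


def normalised_step(theta, grad, alpha, eps=1e-12):
    # Step size such that ||step * grad|| = alpha * ||theta||.
    # eps avoids division by zero (an assumption of this sketch).
    return alpha * l2_norm(theta) / max(l2_norm(grad), eps)


theta = [3.0, 4.0]   # ||theta|| = 5
grad = [0.0, 2.0]    # ||grad||  = 2
alpha = 0.1

lam = normalised_step(theta, grad, alpha)
theta_perturbed = [t - lam * g for t, g in zip(theta, grad)]
shift = l2_norm([t - tp for t, tp in zip(theta, theta_perturbed)])
print(lam, shift, alpha * l2_norm(theta))
```

The perturbation norm depends only on the fraction and the parameter norm, not on the scale of the downstream gradient.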
Trading off Bias and Sparsity of the Gradient Estimates.
To adapt the fraction $\alpha$ in Eq. 7, we track the sparsity of the gradient estimator with an exponential moving average of the gradient norm. Since the gradients in Eq. 2 (i.e. the differences between the two MAP states) are, for each component, always in $\{-1, 0, 1\}$, we take the $L_0$-norm, which is here equivalent to the number of non-zero gradients.
Consider a batch of $N$ inputs during training. Let $\hat{\nabla}_{\theta_j}$ with $j \in \{1, \ldots, N\}$ be a single-sample gradient estimate from Eq. 2 for an input data point with input parameters $\theta_j$ and without the scaling factor $1/\lambda$. We compute, in every training iteration $t$ and using a discount factor $\gamma \in (0, 1)$, the exponential moving average of the number of non-zero gradients per training example:
$$g_{t+1} = \gamma\, g_{t} + (1 - \gamma)\, \frac{1}{N} \sum_{j=1}^{N} \big\|\hat{\nabla}_{\theta_j}\big\|_{0}. \tag{8}$$
Similarly to adaptive optimisation algorithms for neural networks, we introduce an update rule for $\alpha$. Let $c$ be the target learning rate, expressed as the desired number of non-zero gradients we obtain on average per example. A typical value is $c = 1$, meaning that we aim to adapt the value of $\alpha$ such that we obtain, on average, at least one non-zero gradient per example. We now use the following update rule for some fixed $\eta > 0$:
$$\alpha_{t+1} = \begin{cases} \alpha_{t} + \eta & \text{if } g_{t+1} \leq c, \\ \alpha_{t} - \eta & \text{otherwise.} \end{cases} \tag{9}$$
Hence, by increasing or decreasing $\alpha$ by a constant step $\eta$, based on the current exponential moving average of the gradient sparsity, we adapt $\lambda$ through Eq. 7, which relates $\lambda$ and $\alpha$. Algorithm 1 lists the gradient estimator as a layer in a neural network with a forward and backward pass.
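The tracking and update logic described above can be sketched as a small stateful helper. This is one possible reading rather than the authors' Algorithm 1: the additive update, the comparison used in the threshold, the form of the moving average, the class name, and all default values below are assumptions made for illustration.

```python
def count_nonzero(grad):
    # L0 "norm": number of non-zero entries of an unscaled gradient estimate.
    return sum(1 for g in grad if g != 0)


class AdaptiveFraction:
    def __init__(self, fraction=0.0, target=1.0, step=1e-3, discount=0.9):
        self.fraction = fraction  # perturbation fraction being adapted
        self.target = target      # desired non-zero gradients per example
        self.step = step          # fixed update step (hypothetical default)
        self.discount = discount  # moving-average discount (hypothetical default)
        self.ema = 0.0

    def update(self, unscaled_grads):
        # Average number of non-zero gradients per example in the batch.
        batch_mean = sum(count_nonzero(g) for g in unscaled_grads) / len(unscaled_grads)
        self.ema = self.discount * self.ema + (1.0 - self.discount) * batch_mean
        # Too sparse: perturb more strongly. Dense enough: perturb less.
        if self.ema <= self.target:
            self.fraction += self.step
        else:
            self.fraction -= self.step
        return self.fraction


tracker = AdaptiveFraction(step=0.1, discount=0.5)
after_sparse = tracker.update([[0, 0, 0], [0, 0, 0]])
after_dense = tracker.update([[1, -1, 1], [1, 1, -1]])
print(after_sparse, after_dense)
```

All-zero gradient estimates push the fraction up, and dense estimates push it back down once the moving average exceeds the target.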
Forward and centred finite difference approximation.
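Gradient estimation in IMLE, as outlined in Eq. 2, is analogous to gradient estimation with forward (one-sided) finite difference approximations, where $f'(x) \approx \frac{f(x + h) - f(x)}{h}$. A better approximation can be obtained by the centred (two-sided) difference formula $f'(x) \approx \frac{f(x + h) - f(x - h)}{2h}$, which is a second-order approximation to the first derivative (Olver2013-sj). Following this intuition, we replace $\frac{1}{\lambda}\big\{ \mathrm{MAP}(\theta + \epsilon) - \mathrm{MAP}(\theta - \lambda \nabla_{z} f(z) + \epsilon) \big\}$ in Eq. 2 with $\frac{1}{2\lambda}\big\{ \mathrm{MAP}(\theta + \lambda \nabla_{z} f(z) + \epsilon) - \mathrm{MAP}(\theta - \lambda \nabla_{z} f(z) + \epsilon) \big\}$, leading to the update equation in Algorithm 1.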
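The difference in accuracy between the one-sided and two-sided formulas is easy to verify on a smooth scalar function. The sketch below uses the sine function, which is an arbitrary choice for illustration and unrelated to the discrete setting of the paper.

```python
import math


def forward_diff(f, x, h):
    # One-sided difference: error shrinks linearly with h.
    return (f(x + h) - f(x)) / h


def central_diff(f, x, h):
    # Two-sided difference: error shrinks quadratically with h.
    return (f(x + h) - f(x - h)) / (2.0 * h)


x = 1.0
exact = math.cos(x)  # derivative of sin at x

for h in (1e-1, 1e-2, 1e-3):
    err_forward = abs(forward_diff(math.sin, x, h) - exact)
    err_central = abs(central_diff(math.sin, x, h) - exact)
    print(h, err_forward, err_central)
```

Reducing the step by a factor of ten reduces the centred-difference error by roughly a factor of one hundred, compared with a factor of ten for the forward difference.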
4 Related Work
Continuous relaxations.
Several papers address the gradient estimation problem for discrete random variables, often resorting to continuous relaxations. maddison2016concrete and jang2016categorical propose the Gumbel-Softmax (or concrete) distribution to relax categorical random variables, which was extended by paulus2020gradient to more complex probability distributions. The Gumbel-Softmax distribution is only directly applicable to categorical variables: for more complex distributions, one has to come up with tailor-made relaxations, or use the STE or SFE (e.g., see kim2016exact and grover2019stochastic). REBAR (tucker2017rebar) and RELAX (grathwohl2017backpropagation) use parameterised control variates based on continuous relaxations for the SFE. In this work, we focus explicitly on problems where only discrete samples are used during training. Furthermore, REBAR is tailored to categorical distributions, while IMLE and AIMLE are intended for models with complex distributions and multiple constraints. Approaches that do not rely on relaxations are specific to certain distributions (bengio2013estimating; franceschi2019learning; liu2019rao) or assume knowledge of the constraints (kool2020estimating). AIMLE and IMLE provide a general-purpose framework that does not require access to the linear constraints and the corresponding integer polytope $\mathcal{C}$. SparseMAP (Niculae2018SparseMAPDS) is an approach to structured prediction and latent variables, replacing an exponential distribution with a sparser distribution; similarly to our work, it only presupposes the availability of a MAP oracle. LP-SparseMAP (Niculae2020LPSparseMAPDR) is an extension of SparseMAP that uses a relaxation of the underlying optimisation problem.
Differentiating through combinatorial solvers.
A series of works about differentiating through combinatorial optimisation problems (wilder2019melding; elmachtoub2020smart; ferber2020mipaal; DBLP:conf/nips/MandiG20) relax ILPs by adding a regularisation term, and differentiate through the KKT conditions deriving from the application of the cutting plane or the interior-point methods. These approaches are conceptually linked to techniques for differentiating through smooth programs (amos2017optnet; donti2017task; agrawal2019differentiable; chen2020understanding; domke2012generic; franceschi2018bilevel) that arise in modelling, hyperparameter optimisation, and meta-learning. Black-box Backprop (poganvcic2019differentiation; rolinek2020deep) and DPO (berthet2020learning) are methods that are not tied to a specific ILP solver. Black-box Backprop, originally derived from a continuous interpolation argument, can be interpreted as a special instantiation of IMLE and AIMLE. DPO addresses the theory of perturbed optimizers and discusses Perturb-and-MAP in the context of Fenchel-Young losses. All the combinatorial optimisation-related works assume that either optimal costs or solutions are given as training data, while IMLE and AIMLE can also be applied in the absence of such supervision by making use of implicitly generated target distributions. Other authors focus on devising differentiable relaxations for specific combinatorial problems such as SAT (evans2018learning) or MaxSAT (wang2019satnet). Machine learning also intersects with combinatorial optimisation in other contexts, e.g. in learning heuristics to improve the performance of combinatorial solvers; we refer to bengio2020machine for further details.
Direct Loss Minimisation (DLM, mcallester2010direct; song2016training) is also related to our work, but it relies on the assumption that examples of optimal states are given. lorberbom:2019 extend the DLM framework to discrete VAEs using coupled perturbations: their approach is tailored to VAEs, and is not general-purpose. From a methodological viewpoint, IMLE inherits from classical MLE (wainwright2008graphical) and Perturb-and-MAP (Papandreou:2011). The theory of Perturb-and-MAP was used to derive general-purpose upper bounds for log-partition functions (hazan2012partition; shpakova:2016).
5 Experiments
Similarly to niepert21imle, we conducted three different types of experiments. First, we analyse and compare the behaviour of AIMLE with other gradient estimators (STE, SFE, IMLE) in a synthetic setting. Second, we consider a setting where the distribution parameters $\theta$ in Eq. 3 are produced by an upstream neural model, and the optimal discrete structure is not available during training. Finally, we consider the problem of differentiating through black-box combinatorial solvers, where we use the target distribution derived in Eq. 6. In all our experiments, we fix the AIMLE hyper-parameters, namely the target gradient norm and the update step, to their default values, based on the AIMLE implementation described in Algorithm 1. More experimental details are available in the appendix.
Synthetic Experiments.
We conducted a series of experiments with a tractable categorical distribution where with and . We set the loss to , where .
In Fig. 3, we plot the cosine similarity between the true gradient and the gradient estimates produced by the Straight-Through Estimator (STE, bengio2013estimating), the Score Function Estimator (SFE, DBLP:journals/ml/Williams92), Implicit Maximum Likelihood Estimation (IMLE, niepert21imle), and AIMLE. For STE and IMLE, we use Perturb-and-MAP with Gumbel noise. For IMLE and AIMLE, we evaluated both their forward difference (Forward) and central difference (Central) versions. We evaluated all estimators with a varying number of samples, and report how the number of samples influences the cosine similarity between the gradient estimate and the true gradient. Statistics are computed over repeated runs. From the results, outlined in Fig. 3, we can see that AIMLE, both in its central and forward difference versions, produces significantly more accurate estimates of the true gradient compared to IMLE, STE, and SFE, with orders of magnitude fewer samples. Furthermore, we report the cosine similarity between the true and the estimated gradient for AIMLE, STE, SFE, and IMLE with a varying value of $\lambda$, where all estimators use the same number of samples. Results are outlined in Fig. 4: we can see that AIMLE is able to produce gradient estimates that are comparable to the best estimates produced by IMLE, without the need to tune a hyper-parameter.
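The evaluation protocol of comparing an estimate to the true gradient via cosine similarity can be sketched for a small categorical distribution, where the true gradient is available in closed form. The example below uses the score function estimator as the estimator under test; the parameter values, the per-state function values `f_vals`, and the sample count are hypothetical and do not reproduce the paper's setup.

```python
import math
import random


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def true_gradient(theta, f_vals):
    # Gradient of sum_i p_i f_i w.r.t. theta: p_j * (f_j - E[f]).
    p = softmax(theta)
    mean_f = sum(pi * fi for pi, fi in zip(p, f_vals))
    return [pi * (fi - mean_f) for pi, fi in zip(p, f_vals)]


def sfe_gradient(theta, f_vals, n_samples, rng):
    # Score function estimator: average of f(z) * grad log p(z).
    p = softmax(theta)
    n = len(theta)
    grad = [0.0] * n
    for _ in range(n_samples):
        i = rng.choices(range(n), weights=p)[0]
        for j in range(n):
            grad[j] += f_vals[i] * ((1.0 if j == i else 0.0) - p[j]) / n_samples
    return grad


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na > 0.0 and nb > 0.0 else 0.0


theta = [0.0, 0.5, -0.5, 1.0, -1.0]
f_vals = [1.0, 0.2, 2.0, 0.5, 1.5]
exact = true_gradient(theta, f_vals)
estimate = sfe_gradient(theta, f_vals, n_samples=5000, rng=random.Random(0))
print(cosine_similarity(estimate, exact))
```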
Method | Test MSE (Mean) | Test MSE (SD) | Subset Prec. (Mean) | Subset Prec. (SD) |
Aspect: aroma, | ||||
SoftSub () | 2.515 | 0.087 | 55.453 | 2.338 |
STE () | 4.660 | 0.053 | 44.593 | 0.523 |
SST () | 4.788 | 0.486 | 56.854 | 3.752 |
IMLE (Forward, , ) | 2.413 | 0.055 | 53.744 | 5.635 |
IMLE (Central, , ) | 2.266 | 0.050 | 50.888 | 5.453 |
AIMLE (Forward, ) | 2.499 | 0.089 | 44.668 | 6.936 |
AIMLE (Central, ) | 2.385 | 0.049 | 62.056 | 2.107 |
Aspect: aroma, | ||||
SoftSub () | 2.543 | 0.044 | 44.513 | 2.958 |
STE () | 4.310 | 0.039 | 39.635 | 0.281 |
SST () | 5.213 | 0.295 | 24.328 | 12.463 |
IMLE (Forward, , ) | 2.368 | 0.075 | 48.215 | 2.182 |
IMLE (Central, , ) | 2.256 | 0.043 | 45.339 | 3.115 |
AIMLE (Forward, ) | 2.402 | 0.042 | 48.397 | 1.967 |
AIMLE (Central, ) | 2.419 | 0.061 | 53.260 | 2.271 |
Aspect: aroma, | ||||
SoftSub () | 2.711 | 0.035 | 37.202 | 1.374 |
STE () | 4.062 | 0.054 | 36.267 | 0.161 |
SST () | 5.787 | 0.517 | 24.551 | 9.827 |
IMLE (Forward, , ) | 2.411 | 0.087 | 41.850 | 1.477 |
IMLE (Central, , ) | 2.508 | 0.396 | 40.057 | 7.172 |
AIMLE (Forward, ) | 2.408 | 0.064 | 41.688 | 2.246 |
AIMLE (Central, ) | 2.470 | 0.026 | 47.109 | 2.863 |
Learning to Explain.
The BeerAdvocate dataset (McAuley2012LearningAA) consists of free-text reviews and ratings for different aspects of beer: appearance, aroma, palate, and taste. Each sentence in the test set has annotations providing the words that best describe the various aspects. Following the experimental setting in paulus2020gradient; niepert21imle, we address the problem of learning a distribution over -subsets of words that best explain a given aspect rating, introduced by chen2018learning. The complexity of the MAP problem for the -subset distribution is linear in .
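For a distribution over subsets of fixed size, the MAP state is obtained by selecting the highest-scoring elements. A minimal sketch follows, with the subset size called `k` and the function name `top_k_map` chosen for this example. Note that `heapq.nlargest` runs in O(n log k) time for n scores, so a selection algorithm would be needed for a strictly linear-time variant.

```python
import heapq


def top_k_map(theta, k):
    # Indices of the k largest parameters form the most probable subset.
    top = set(heapq.nlargest(k, range(len(theta)), key=lambda i: theta[i]))
    # Encode the subset as a binary mask with exactly k ones.
    return [1 if i in top else 0 for i in range(len(theta))]


scores = [0.1, 2.0, -1.0, 0.7]
mask = top_k_map(scores, 2)
print(mask)
```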
The training set has 80,000 reviews for the aspect appearance and 70,000 reviews for all other aspects. Since the original dataset (McAuley2012LearningAA) did not provide separate validation and test sets, following niepert21imle, we compute 10 different evenly sized validation and test splits of the 10,000 held-out reviews and compute the mean and standard deviation over 10 models, each trained on one split. Subset precision was computed using a subset of 993 annotated reviews. We use pre-trained word embeddings from Lei2016RationalizingNP. We extend the implementations provided by niepert21imle, which use a neural network following the architecture introduced by paulus2020gradient with four convolutional layers and one dense layer. This neural network outputs the parameters $\theta$ of the distribution over $k$-hot binary latent masks.
We compare AIMLE (both the forward and central difference versions) to the relaxation-based baselines L2X (chen2018learning) and SoftSub (Xie2019ReparameterizableSS); to STE with Gumbel perturbations; and to IMLE (niepert21imle) with Gumbel perturbations. We used the standard hyperparameter settings of chen2018learning and selected the temperature parameter for all methods. For IMLE, we chose $\lambda$ based on the validation MSE. We trained separate models for each aspect using MSE as the training loss, using the Adam (DBLP:journals/corr/KingmaB14) optimiser with its default hyper-parameters.
Table 1 lists detailed results for the aspect aroma. We can see that AIMLE, in its central difference version, systematically produces the highest subset precision values while yielding test MSE values comparable to those produced by IMLE, without requiring tuning of the hyper-parameter $\lambda$. In the appendix, we report the results for the other aspects, where we notice that AIMLE produces significantly higher subset precision values in all other aspects as well.
Edge Distribution | T=10: ELBO | T=10: Edge Prec. | T=10: Edge Rec. | T=20: ELBO | T=20: Edge Prec. | T=20: Edge Rec. |
SST (Hard) | -2301.47 ± 85.86 | 33.75 ± 9.44 | 60.40 ± 23.23 | -3407.89 ± 221.53 | 57.40 ± 17.87 | 70.42 ± 8.22 |
IMLE (Forward) | -2289.94 ± 4.31 | 23.94 ± 0.03 | 95.75 ± 0.14 | -3820.68 ± 25.32 | 20.28 ± 0.12 | 20.28 ± 0.12 |
IMLE (Central) | -2341.71 ± 41.68 | 43.95 ± 7.22 | 43.95 ± 7.22 | -3447.29 ± 550.38 | 40.25 ± 14.26 | 40.25 ± 14.26 |
AIMLE (Forward) | -2039.00 ± 265.03 | 45.67 ± 12.04 | 45.67 ± 12.04 | -1871.22 ± 120.63 | 43.21 ± 3.96 | 43.21 ± 3.96 |
AIMLE (Central) | -1980.46 ± 316.15 | 46.48 ± 14.25 | 46.48 ± 14.25 | -1571.98 ± 30.59 | 82.51 ± 1.30 | 82.51 ± 1.30 |
Discrete Variational Auto-Encoder.
Following niepert21imle, we compare AIMLE with IMLE, STE, and the Gumbel-Softmax trick (maddison2016concrete; jang2016categorical) using a discrete $k$-subset Variational Auto-Encoder (VAE). The latent variables model a probability distribution over $k$-subsets of (or top-$k$ assignments to) binary vectors of length 20; note that, for $k = 1$, this is equivalent to a categorical variable with 20 categories. We follow the implementation details in niepert21imle, where the encoder and the decoder of the VAE each consist of three dense layers, with activations of sizes 512-256-20×20 and 256-512-784, respectively. The loss is the sum of the reconstruction loss (binary cross-entropy on the output pixels) and the KL divergence between the marginals of the variables and the uniform distribution.
For AIMLE, IMLE, and STE, we use perturbations for , and Sum-of-Gamma (SoG, niepert21imle) perturbations for , with a temperature of . For IMLE, we select the hyper-parameter on a held-out validation set, based on the validation loss. For IMLE and AIMLE, we report the results for both the forward and central difference versions. We train the DVAE for epochs and report the loss on a held-out test set.
In Fig. 5 we show the test losses for IMLE and AIMLE (forward and central difference versions), for and . We can see that, for , AIMLE produces significantly lower test loss values than IMLE and, for , the central difference version of AIMLE produces test loss values comparable to IMLE. We also compare AIMLE to the Gumbel-Softmax trick (see Fig. 6) and STE (see Fig. 7): AIMLE produces significantly lower test losses than STE () while producing higher test losses than Gumbel-Softmax (). We also experimented with increasing the number of samples in AIMLE (see Fig. 8), finding that higher values of produce significantly lower test loss values, to the point that, for , AIMLE with produces lower test loss values than Gumbel-Softmax, but with a higher computational cost.
Neural Relational Inference for Recovering Latent Graph Structures.
In this experiment, we investigated the use of AIMLE for recovering latent graph structures and predicting the evolution of a dynamical system. In Neural Relational Inference (NRI, DBLP:conf/icml/KipfFWWZ18), a Graph Neural Network (GNN, DBLP:journals/tnn/Micheli09; DBLP:journals/tnn/ScarselliGTHM09) encoder is used to generate a latent interaction graph, which is then used to produce a distribution over an interacting particle system. NRI is trained as a variational auto-encoder to maximise a lower bound (ELBO) on the marginal log-likelihood of the time series. Based on the implementation provided by paulus2020gradient, we compared the Stochastic Softmax Tricks (SST, paulus2020gradient) encoder that induces a spanning tree over undirected edges, with an encoder producing a maximum spanning tree using Kruskal’s algorithm (Kruskal1956), and using either IMLE or AIMLE to back-propagate through it. In this setting, Kruskal’s algorithm represents the MAP estimator for a distribution over latent graph structures. Our dataset consisted of latent prior spanning trees over 10 vertices sampled from the prior. Given a tree, we embed the vertices in by applying iterations of a force-directed algorithm (DBLP:journals/spe/FruchtermanR91). The model saw particle locations at each iteration, not the underlying spanning tree. We found that AIMLE performed best, improving on both ELBO and the recovery of latent structure over the structured SST baseline proposed by paulus2020gradient.
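Kruskal's algorithm for a maximum spanning tree, used here as the MAP solver over latent graph structures, can be sketched with a union-find structure. The edge-weight dictionary format and the function name are assumptions of this example and are not taken from the referenced implementation.

```python
def maximum_spanning_tree(n_vertices, weights):
    # weights: dict mapping an undirected edge (u, v) with u < v to its weight.
    parent = list(range(n_vertices))

    def find(x):
        # Find the component representative, with path halving.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    tree = []
    # Visit edges from heaviest to lightest and keep those joining two components.
    for (u, v), _w in sorted(weights.items(), key=lambda kv: kv[1], reverse=True):
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            tree.append((u, v))
    return tree


edge_weights = {
    (0, 1): 3.0, (0, 2): 1.0, (0, 3): 0.5,
    (1, 2): 2.0, (1, 3): 0.2, (2, 3): 4.0,
}
tree = maximum_spanning_tree(4, edge_weights)
print(tree)
```

Perturbing the edge weights with noise before running the solver yields Perturb-and-MAP samples of spanning trees.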
6 Conclusions
We introduced Adaptive Implicit Maximum Likelihood Estimation (AIMLE), an efficient, simple-to-implement, and general-purpose framework for learning hybrid models. AIMLE is an extension of IMLE (niepert21imle) that, during training, can dynamically select the optimal target distribution by identifying the update step that yields the desired gradient norm. Furthermore, AIMLE incorporates insights from finite difference methods, improving its effectiveness in gradient estimation tasks. In our experiments, we show that AIMLE produces better results than relaxation-based approaches for discrete latent variable models, and approaches that back-propagate through black-box combinatorial solvers.
A limitation of this work, and a potential future extension, is that it relies on a warm-up period for selecting the optimal $\lambda$, whose duration varies depending on the update step, which is fixed in our experiments. A potential solution to this problem is to use adaptive update steps, for instance by using momentum (DBLP:journals/nn/Qian99).
Appendix A Learning to Explain Experiments
Method | Test MSE (Mean) | Test MSE (SD) | Subset Prec. (Mean) | Subset Prec. (SD) |
Aspect: appearance, | ||||
SoftSub () | 2.327 | 0.051 | 72.183 | 5.242 |
STE () | 5.315 | 0.070 | 29.134 | 1.503 |
SST () | 4.157 | 0.562 | 76.059 | 7.957 |
IMLE (Forward, , ) | 2.420 | 0.066 | 54.457 | 7.165 |
IMLE (Central, , ) | 2.263 | 0.086 | 69.679 | 12.590 |
AIMLE (Forward, ) | 2.512 | 0.058 | 56.904 | 4.149 |
AIMLE (Central, ) | 2.266 | 0.071 | 81.119 | 2.345 |
Aspect: appearance, | ||||
SoftSub () | 2.349 | 0.064 | 63.667 | 5.215 |
STE () | 4.967 | 0.049 | 33.629 | 2.268 |
SST () | 4.973 | 0.443 | 60.363 | 12.225 |
IMLE (Forward, , ) | 2.300 | 0.071 | 54.667 | 6.989 |
IMLE (Central, , ) | 2.168 | 0.051 | 63.249 | 6.251 |
AIMLE (Forward, ) | 2.305 | 0.065 | 59.840 | 4.859 |
AIMLE (Central, ) | 2.191 | 0.037 | 72.167 | 2.744 |
Aspect: appearance, | ||||
SoftSub () | 2.534 | 0.132 | 54.097 | 7.032 |
STE () | 4.493 | 0.083 | 37.737 | 3.298 |
SST () | 5.537 | 0.605 | 36.217 | 18.978 |
IMLE (Forward, , ) | 2.245 | 0.059 | 52.654 | 3.913 |
IMLE (Central, , ) | 2.192 | 0.047 | 51.246 | 9.288 |
AIMLE (Forward, ) | 2.282 | 0.068 | 52.188 | 4.306 |
AIMLE (Central, ) | 2.195 | 0.046 | 68.155 | 3.206 |
Method | Test MSE (Mean) | Test MSE (SD) | Subset Prec. (Mean) | Subset Prec. (SD) |
Aspect: aroma, | ||||
SoftSub () | 2.515 | 0.087 | 55.453 | 2.338 |
STE () | 4.660 | 0.053 | 44.593 | 0.523 |
SST () | 4.788 | 0.486 | 56.854 | 3.752 |
IMLE (Forward, , ) | 2.413 | 0.055 | 53.744 | 5.635 |
IMLE (Central, , ) | 2.266 | 0.050 | 50.888 | 5.453 |
AIMLE (Forward, ) | 2.499 | 0.089 | 44.668 | 6.936 |
AIMLE (Central, ) | 2.385 | 0.049 | 62.056 | 2.107 |
Aspect: aroma, | ||||
SoftSub () | 2.543 | 0.044 | 44.513 | 2.958 |
STE () | 4.310 | 0.039 | 39.635 | 0.281 |
SST () | 5.213 | 0.295 | 24.328 | 12.463 |
IMLE (Forward, , ) | 2.368 | 0.075 | 48.215 | 2.182 |
IMLE (Central, , ) | 2.256 | 0.043 | 45.339 | 3.115 |
AIMLE (Forward, ) | 2.402 | 0.042 | 48.397 | 1.967 |
AIMLE (Central, ) | 2.419 | 0.061 | 53.260 | 2.271 |
Aspect: aroma, | ||||
SoftSub () | 2.711 | 0.035 | 37.202 | 1.374 |
STE () | 4.062 | 0.054 | 36.267 | 0.161 |
SST () | 5.787 | 0.517 | 24.551 | 9.827 |
IMLE (Forward, , ) | 2.411 | 0.087 | 41.850 | 1.477 |
IMLE (Central, , ) | 2.508 | 0.396 | 40.057 | 7.172 |
AIMLE (Forward, ) | 2.408 | 0.064 | 41.688 | 2.246 |
AIMLE (Central, ) | 2.470 | 0.026 | 47.109 | 2.863 |
Method | Test MSE (Mean) | Test MSE (SD) | Subset Prec. (Mean) | Subset Prec. (SD) |
Aspect: palate, | ||||
SoftSub () | 2.857 | 0.049 | 51.435 | 2.258 |
STE () | 4.456 | 0.038 | 30.650 | 0.449 |
SST () | 4.357 | 0.596 | 50.634 | 4.587 |
IMLE (Forward, , ) | 2.867 | 0.042 | 50.066 | 1.262 |
IMLE (Central, , ) | 2.684 | 0.047 | 54.347 | 1.320 |
AIMLE (Forward, ) | 2.856 | 0.050 | 47.818 | 4.231 |
AIMLE (Central, ) | 2.683 | 0.037 | 56.043 | 1.540 |
Aspect: palate, | ||||
SoftSub () | 2.957 | 0.048 | 35.604 | 1.818 |
STE () | 4.065 | 0.039 | 32.096 | 0.501 |
SST () | 4.754 | 0.260 | 37.302 | 7.304 |
IMLE (Forward, , ) | 2.833 | 0.039 | 43.305 | 2.993 |
IMLE (Central, , ) | 2.669 | 0.046 | 45.258 | 2.335 |
AIMLE (Forward, ) | 2.837 | 0.062 | 40.777 | 3.054 |
AIMLE (Central, ) | 2.666 | 0.020 | 49.895 | 3.750 |
Aspect: palate, | ||||
SoftSub () | 3.138 | 0.051 | 27.927 | 1.651 |
STE () | 3.849 | 0.047 | 30.727 | 0.947 |
SST () | 5.171 | 0.645 | 24.095 | 9.535 |
IMLE (Forward, , ) | 2.892 | 0.058 | 35.700 | 2.847 |
IMLE (Central, , ) | 2.808 | 0.210 | 36.916 | 6.323 |
AIMLE (Forward, ) | 2.885 | 0.027 | 36.070 | 2.279 |
AIMLE (Central, ) | 2.708 | 0.031 | 44.860 | 1.549 |
Method | Test MSE (Mean) | Test MSE (SD) | Subset Prec. (Mean) | Subset Prec. (SD) |
Aspect: taste, | ||||
SoftSub () | 2.196 | 0.045 | 42.762 | 1.785 |
STE () | 4.591 | 0.047 | 39.360 | 0.402 |
SST () | 3.942 | 0.354 | 35.376 | 4.146 |
IMLE (Forward, , ) | 2.253 | 0.036 | 39.846 | 1.855 |
IMLE (Central, , ) | 2.124 | 0.060 | 41.105 | 1.054 |
AIMLE (Forward, ) | 2.196 | 0.054 | 38.344 | 2.007 |
AIMLE (Central, ) | 2.132 | 0.043 | 43.539 | 1.175 |
Aspect: taste, | ||||
SoftSub () | 2.137 | 0.048 | 42.716 | 1.085 |
STE () | 4.283 | 0.053 | 38.089 | 0.089 |
SST () | 4.422 | 0.176 | 30.581 | 4.077 |
IMLE (Forward, , ) | 2.200 | 0.052 | 40.673 | 1.756 |
IMLE (Central, , ) | 2.081 | 0.054 | 39.636 | 1.219 |
AIMLE (Forward, ) | 2.190 | 0.030 | 41.098 | 1.525 |
AIMLE (Central, ) | 2.111 | 0.034 | 45.593 | 1.222 |
Aspect: taste, | ||||
SoftSub () | 2.173 | 0.043 | 40.189 | 1.318 |
STE () | 4.009 | 0.035 | 37.364 | 0.058 |
SST () | 4.868 | 0.264 | 32.507 | 3.649 |
IMLE (Forward, , ) | 2.201 | 0.052 | 40.636 | 1.479 |
IMLE (Central, , ) | 2.193 | 0.270 | 38.987 | 1.680 |
AIMLE (Forward, ) | 2.196 | 0.046 | 40.993 | 1.627 |
AIMLE (Central, ) | 2.138 | 0.041 | 43.089 | 1.780 |
Here we report the results on all aspects of the Learning to Explain task with the BeerAdvocate dataset (McAuley2012LearningAA), namely appearance (Table 3), aroma (Table 4), palate (Table 5) and taste (Table 6). In all cases, AIMLE produces the best subset precision results, while producing test MSE values that are comparable with those produced by IMLE.
Each experiment was re-run 10 times, for 20 training epochs; batch size was set to 40, with a kernel size of 3, hidden dimension of 250, and maximum sequence length of 350. All models were trained using the Adam (DBLP:journals/corr/KingmaB14) optimiser, with a learning rate of , by fitting an MSE loss between the predicted and true ratings. For all methods, the noise temperature was selected in . In IMLE, the hyper-parameter was selected in . In AIMLE, the hyper-parameters (the target gradient norm) and (the update step) were fixed to their default values, namely and .
Appendix B Discrete Variational Auto-Encoder Experiments
The data set can be loaded in PyTorch with torchvision.datasets.MNIST. We use the same batch size and number of training epochs as in prior work, plotting the test loss after each epoch. We use the standard Adam settings in PyTorch and no learning rate schedule. The MNIST dataset consists of black-and-white 28×28 pixel images of hand-written digits. The encoder network consists of an input layer with dimension 784, since we flatten the images; a dense layer with dimension 512 and ReLU (DBLP:conf/icml/NairH10) activations; a dense layer with dimension 256 and ReLU activations; and a dense layer with dimension 400 (20×20), which outputs the parameters $\theta$ and has no non-linearity. The layer implementing AIMLE receives $\theta$ as input and outputs a discrete latent code of shape 20×20. The decoder, which takes this discrete latent code as input, consists of a dense layer with dimension 256 and ReLU activations; a dense layer with dimension 512 and ReLU activations; and finally a dense layer with dimension 784 returning the logits for the output pixels. Sigmoid non-linearities are applied to these logits and used to compute the binary cross-entropy.
Appendix C Neural Relational Inference Experiments
We use the dataset of latent prior spanning trees over 10 vertices proposed by paulus2020gradient: latent spanning trees were sampled by applying Kruskal’s algorithm (Kruskal1956) to for a fully-connected graph. Initial vertex locations were sampled from in . Given the initial locations and the latent tree, the dynamical observations were obtained by applying a force-directed algorithm for graph layout for iterations. Then, the initial vertex positions were discarded, because the first iteration of the layout algorithm typically results in large relocations. Hence, the final dataset used for training consisted of 10 and 20 location observations in for each of the 10 vertices. Following this procedure, we generated a training set of size 50,000 and validation and test sets of size 10,000.
Model.
We follow the design and implementation provided by paulus2020gradient, where the NRI model consists of two neural modules, an encoder and a decoder. The encoder GNN passes messages over the fully connected directed graph with nodes. We took the final edge representations produced by the GNN, and use them as . The final edge representations are in , where over the 90 undirected edges (since we do not consider self-connections). Given the previous time-step data, the decoder GNN passes messages over the sampled graph adjacency matrix and predicts future node positions. As in DBLP:conf/icml/KipfFWWZ18 and paulus2020gradient, we used teacher-forcing every 10 time-steps. , in this case, was a directed adjacency matrix over the graph where are the nodes: is interpreted as there being an edge from to , and otherwise. represents the symmetric, directed adjacency matrix with edges in both directions for each undirected edge. The decoder passes messages between both connected and not-connected nodes. When considering a message from to , it uses one network for the edges such that , and another network for the edges such that .
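The symmetric adjacency matrix that the decoder consumes can be built from a set of undirected tree edges by setting both directions of each edge. The sketch below is a hypothetical helper written for illustration and is not part of the referenced implementation.

```python
def symmetric_adjacency(n_vertices, undirected_edges):
    # Directed adjacency matrix with an entry in both directions per edge.
    adj = [[0] * n_vertices for _ in range(n_vertices)]
    for u, v in undirected_edges:
        adj[u][v] = 1
        adj[v][u] = 1
    return adj


adj = symmetric_adjacency(3, [(0, 1), (1, 2)])
for row in adj:
    print(row)
```

Each entry can then be used to decide which of the two message networks handles the corresponding ordered pair of nodes.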
Hyper-parameters.
We selected the hyper-parameters for SST (Spanning Tree, Hard), IMLE, and AIMLE, using a grid-search, and selected the hyper-parameter configuration which yields the highest ELBO on the validation set. For SST, we searched the learning rate in , and the temperature in . For IMLE, we searched the learning rate in , , and the noise temperature in . For AIMLE, we searched the learning rate in , and the noise temperature in .