Meta-Amortized Variational Inference and Learning

02/05/2019 · by Kristy Choi, et al.

How can we learn to do probabilistic inference in a way that generalizes between models? Amortized variational inference learns for a single model, sharing statistical strength across observations. This benefits scalability and model learning, but does not help with generalization to new models. We propose meta-amortized variational inference, a framework that amortizes the cost of inference over a family of generative models. We apply this approach to deep generative models by introducing the MetaVAE: a variational autoencoder that learns to generalize to new distributions and rapidly solve new unsupervised learning problems using only a small number of target examples. Empirically, we validate the approach by showing that the MetaVAE can: (1) capture relevant sufficient statistics for inference, (2) learn useful representations of data for downstream tasks such as clustering, and (3) perform meta-density estimation on unseen synthetic distributions and out-of-sample Omniglot alphabets.

Code repository: meta-inference-public, a PyTorch implementation of "Meta-Amortized Variational Inference and Learning" (https://arxiv.org/abs/1902.01950).

1 Introduction

A wide variety of problems in modern AI can be posed as probabilistic inference in generative models. While traditional inference techniques solve each inference independently, amortized inference (Gershman and Goodman, 2014) aims to solve multiple inferences for a given model together—learning to do inference for that model. This approach has been particularly fruitful when applied to variational inference (Jordan et al., 1999; Wainwright et al., 2008; Blei et al., 2017) where amortization across observations solves a serious problem with scaling to large data sets (Rezende et al., 2014; Kingma and Welling, 2013). In this paper we explore amortizing not just over the observations for a single model, but further amortizing the cost of inference over different generative models.

More precisely, suppose we have a family of generative models where for each family member, we would like to perform scalable inference. Then, we would ideally design an efficient, amortized inference model that takes as input: (1) a suitable representation of the target probabilistic model, (2) an inference query, and (3) observed data, and outputs an approximation of the desired posterior distribution. We note that this inference model is not intended to be universal, but rather tailored to a specific family where each model is similar in structure. Inspired by meta-learning, we denote this "doubly-amortized" inference problem as meta-inference, and we let a meta-distribution refer to the probability distribution over the family of probabilistic models.

The challenge is generalization: we wish to draw correct inferences efficiently on unseen distributions that are either sampled from the meta-distribution or “close” to it. This challenge is especially pertinent for latent variable models such as the variational autoencoder (VAE), where the amortized inference network is used to map data points to latent representations. In this work, we introduce the MetaVAE, a VAE that meta-amortizes the inference procedure across a family of generative models. We use the MetaVAE to perform: (1) meta-unsupervised learning, where we leverage the underlying meta-distribution to find good representations on previously unseen distributions for downstream tasks; and (2) meta-density estimation, where we can properly estimate the marginal distribution with very few data points from an unseen target distribution.

2 Preliminaries

2.1 Exact and Approximate Inference

Let p_data(x) be an (empirical) data distribution over the observed variables x ∈ X. In practice, this is often uniform over a training set of examples from p_data. We then define p(x, z) to be a joint distribution over a set of latent variables z ∈ Z and observed variables x ∈ X.

A typical inference query involves computing our posterior beliefs after incorporating evidence into the prior: p(z|x) = p(x, z)/p(x). This quantity is often intractable to compute, as the marginal likelihood p(x) requires integrating/summing over a potentially exponential number of configurations for z.

Instead, we leverage approximate inference techniques such as Markov chain Monte Carlo (MCMC) sampling (Hastings, 1970; Gelfand and Smith, 1990) and variational inference (VI) (Jordan et al., 1999; Wainwright et al., 2008; Blei et al., 2017) to estimate p(z|x). In VI, we posit a family of tractable distributions q_λ(z) parameterized by λ ∈ Λ over the latent variables and find the member (called the approximate posterior) that minimizes the Kullback-Leibler (KL) divergence between itself and the exact posterior:

    λ* = argmin_{λ ∈ Λ} KL( q_λ(z) ‖ p(z|x) ).

This then serves as a proxy for the true underlying posterior distribution. We note that the solution λ* will depend on the specific value of the observed (evidence) variables x we are conditioning on. For notational clarity, we rewrite the variational parameters as λ_x to make explicit their dependence on x. As noted earlier, one often needs to solve multiple inference queries of the same kind, conditioning on different values of the observed (evidence) variables x. The average quality of the variational approximations obtained can be quantified as follows:

    E_{x ∼ p_data} [ max_{λ_x} E_{q_{λ_x}(z)} [ log p(x, z) − log q_{λ_x}(z) ] ]    (1)

2.2 Amortized Variational Inference

Massively large training sets require yet another layer of efficiency, as the computational cost of VI in Eq. 1 scales linearly with the number of data points: a separate optimization over λ_x must be solved for each x. We thus leverage a technique known as amortization, in which we amortize the computational cost of the inference procedure by casting the per-sample optimization process in Eq. 1 as a supervised regression task. Specifically, rather than solving for an optimal λ_x for every data point x, we learn one deterministic mapping f_φ to predict λ_x as a function of x. Often, we choose to concisely represent f_φ as a conditional distribution, denoted by q_φ(z|x) = q_{f_φ(x)}(z).
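To make the per-sample-versus-amortized distinction concrete, consider a toy conjugate model of our own choosing (not one from the paper): z ~ N(0, 1), x|z ~ N(z, 1), for which the optimal per-sample variational mean is λ_x = x/2 in closed form. The sketch below regresses a single shared map f_φ(x) = a·x onto those per-sample solutions, which is the amortization idea in miniature:

```python
import random

# Toy conjugate model (illustrative assumption): z ~ N(0,1), x|z ~ N(z,1).
# The exact posterior is z|x ~ N(x/2, 1/2), so the optimal per-sample
# variational mean is lam_x = x/2.

def per_sample_solution(x):
    """Solve the inference problem separately for one x (the Eq. 1 regime)."""
    return 0.5 * x  # closed form for this conjugate toy model

def fit_amortized_map(xs, lr=0.05, steps=2000):
    """Amortization: learn one shared mapping f_phi(x) = a * x that predicts
    lam_x for every x at once, by regressing onto per-sample targets."""
    a = 0.0
    for _ in range(steps):
        x = random.choice(xs)
        err = a * x - per_sample_solution(x)
        a -= lr * err * x  # SGD on the squared regression error
    return a

random.seed(0)
data = [random.gauss(0, 1.5) for _ in range(200)]
a = fit_amortized_map(data)
print(round(a, 2))  # close to 0.5: one shared map replaces 200 optimizations
```

In a real VAE the linear map is replaced by a neural network and the regression targets are never computed explicitly; instead the mapping is trained end-to-end through the ELBO.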

This procedure introduces an amortization gap, in which the less flexible parameterization of the inference network results in replacing the original objective in Eq. 1 with the following lower bound:

    max_φ E_{x ∼ p_data} [ E_{q_φ(z|x)} [ log p(x, z) − log q_φ(z|x) ] ]    (2)

This gap refers to the suboptimality caused by amortizing the variational parameters over the entire training set, as opposed to optimizing them for each training example individually (which corresponds to moving the maximization inside the expectation in Eq. 2, recovering Eq. 1). This tradeoff in expressiveness, however, enables significant computational speedups and generalization to new values of the observed variables.

2.3 Latent Variable Models

Of particular importance to latent variable modeling is the variational autoencoder (VAE) (Kingma and Welling, 2013), a generative model trained to maximize the log marginal likelihood of the data:

    E_{x ∼ p_data} [ log p_θ(x) ]    (3)

as a function of a set of trainable parameters θ.

As an optimization objective, Eqn. 3 is intractable. Instead, we can derive the Evidence Lower Bound (ELBO) to Eqn. 3 using q_φ(z|x) as a tractable amortized inference model:

    E_{x ∼ p_data} [ E_{q_φ(z|x)} [ log p_θ(x, z) − log q_φ(z|x) ] ]    (4)

With Eqn. 4 as an objective, we can train the VAE by jointly optimizing (θ, φ). Post-optimization, the latent variables z are learned features inferred by q_φ(z|x) that can be used in generic unsupervised learning tasks (e.g., clustering).
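A Monte Carlo sketch of Eqn. 4, again on the conjugate toy model z ~ N(0,1), x|z ~ N(z,1) (our illustrative assumption, chosen because p(x) = N(x; 0, 2) is available in closed form, so the bound can be checked against the exact log marginal):

```python
import math, random

def log_normal(v, mean, var):
    """Log density of a univariate Gaussian."""
    return -0.5 * math.log(2 * math.pi * var) - (v - mean) ** 2 / (2 * var)

def elbo(x, q_mean, q_var, n_samples=5000):
    """Monte Carlo estimate of E_q[ log p(x,z) - log q(z|x) ] (Eqn. 4)."""
    total = 0.0
    for _ in range(n_samples):
        z = random.gauss(q_mean, math.sqrt(q_var))
        log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
        total += log_joint - log_normal(z, q_mean, q_var)
    return total / n_samples

random.seed(0)
x = 1.3
exact = log_normal(x, 0.0, 2.0)           # true log p(x) for this toy model
tight = elbo(x, q_mean=x / 2, q_var=0.5)  # q equals the exact posterior
loose = elbo(x, q_mean=0.0, q_var=1.0)    # a poor variational choice
print(exact, tight, loose)  # tight matches exact; loose is strictly smaller
```

When q is the exact posterior the gap (Eqn. 6's KL term) vanishes and the ELBO equals log p(x); any other q yields a strictly smaller value.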

We may also derive an alternative formulation of the ELBO: denoting the exact posterior p_θ(z|x) = p_θ(x, z)/p_θ(x), we get:

    E_{x ∼ p_data} [ E_{q_φ(z|x)} [ log p_θ(x, z) − log q_φ(z|x) ] ]    (5)
    = E_{x ∼ p_data} [ log p_θ(x) ] − E_{x ∼ p_data} [ KL( q_φ(z|x) ‖ p_θ(z|x) ) ]    (6)

Eqn. 6 comprises a maximum likelihood term and a regularization penalty that encourages the learned model to have posteriors that can be approximated by the amortized inference model (Shu et al., 2018).

3 Meta-Amortized Variational Inference

Recall a (singly-)amortized inference model q_φ(z|x) for a single joint distribution p(x, z), which attempts to approximate p(z|x) for various choices of x. This is the original setting, where we consider repeated inference queries from the same model but evaluated on many values of the observed variables x.

Now imagine that we are interested in not one but a set of models {p_i(x, z)}_{i=1}^N. We assume that the random variables in these models have the same domains (x ∈ X, z ∈ Z), but the relationships between the random variables may be different. Further, we make the key simplifying assumption that for each model, we care about the same query p_i(z|x). Finally, we assume to have some knowledge of typical values of the observed variables for each model in the set. Formally, we assume to have a set of marginal distributions {p_i(x)}_{i=1}^N over the observed variables, e.g., a set of data distributions. Let P_X denote the set of all possible marginal distributions over X, and let p_meta denote a distribution over P_X. For example, p_meta may be uniform over a finite number of training datasets. As p_meta is a distribution over distributions, we refer to it as a meta-distribution.

The standard approach to amortize over a set of models is:

    min_{φ_1, …, φ_N} Σ_{i=1}^N E_{x ∼ p_i(x)} [ KL( q_{φ_i}(z|x) ‖ p_i(z|x) ) ]    (7)

where we separately fit an amortized inference model q_{φ_i}(z|x) for each p_i. However, we propose to doubly-amortize the inference procedure as follows:

    min_φ Σ_{i=1}^N E_{x ∼ p_i(x)} [ KL( q_φ(z | p_i, x) ‖ p_i(z|x) ) ]    (8)

where the original mapping f_φ : X → Λ is replaced by an amortized mapping g_φ : P_X × X → Λ that takes the marginal distribution p_i(x) and an observation x and returns a posterior. Formally, we call such a mapping, g_φ, a meta-inference model. Given a single inference query, this doubly-amortized inference component must be robust across varying marginals and evidence. The hope is that g_φ will generalize over the models in the support of p_meta, and possibly to a larger set of sufficiently similar, but previously unseen models.

3.1 Meta-Amortized Variational Bayes

Assume we are given a set of fixed generative models {p_i(x, z)}_{i=1}^N with pre-defined parameters. Each model is built to capture a marginal distribution p_i(x). We propose the following procedure to do meta-amortized inference for such a set of models.

For a given meta-distribution p_meta, the optimal meta-inference model is defined as:

    q* = argmin_q E_{p ∼ p_meta} E_{x ∼ p} [ KL( q(z | p, x) ‖ p(z|x) ) ]    (9)

In practice, we represent each p_i(x) as a finite number of samples drawn from the distribution. We call this set of samples a dataset, written as D_i = {x_j}_{j=1}^M, where x_j ∼ p_i(x) and M = |D_i|. Then, Eqn. 9 yields,

    q̂* = argmin_q̂ E_{p ∼ p_meta} E_{x ∼ p} [ KL( q̂(z | D, x) ‖ p(z|x) ) ]    (10)

where q̂ is an empirical analogue to q that maps a dataset D and an observation x to a posterior.
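The dataset-as-surrogate idea in Eqn. 10 can be illustrated on a toy conjugate family of our own design (not from the paper): each model p_i has z ~ N(μ_i, 1) and x|z ~ N(z, 1), so the exact posterior is z|x ~ N((μ_i + x)/2, 1/2). The meta-inference model is never told μ_i; it must recover it from the dataset D_i, whose sample mean is a consistent estimate of μ_i:

```python
import random, statistics

def meta_inference(dataset, x):
    """Empirical meta-inference q(z | D, x): the dataset D_i acts as a
    surrogate for the unknown marginal p_i(x). Here the learned summary is
    replaced by the analytic one (the sample mean) for illustration."""
    mu_hat = statistics.fmean(dataset)   # summary of D_i, estimates mu_i
    return (mu_hat + x) / 2.0, 0.5       # posterior mean and variance

random.seed(0)
mu_true = 3.0                            # hidden parameter of an unseen model
# Marginal of x under this model is N(mu_true, 1 + 1): prior plus noise var.
dataset = [random.gauss(mu_true, 2.0 ** 0.5) for _ in range(500)]
mean, var = meta_inference(dataset, x=4.0)
print(mean, var)  # mean is near (3 + 4)/2 = 3.5, variance is 0.5
```

A trained meta-inference network replaces the hand-derived summary and posterior map with learned ones, but the interface (dataset, observation) → posterior is exactly this.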

3.2 Meta-Amortized Variational Learning

Obtaining such a set of closely related generative models is difficult. However, just as amortized variational inference works particularly well when learning the parameters of the generative model jointly with those of the amortized inference model, we can "meta-learn" a set of generative models jointly with a single doubly-amortized inference model.

To meta-learn a VAE, we can jointly optimize the parameters φ of the meta-inference network and the parameters θ_i of each generative network p_{θ_i}, according to this objective:

    max_{φ, θ_1, …, θ_N} Σ_{i=1}^N L_i(φ, θ_i)    (11)

where

    L_i(φ, θ_i) = E_{x ∼ p_i(x)} [ E_{q_φ(z | p_i, x)} [ log p_{θ_i}(x, z) − log q_φ(z | p_i, x) ] ]    (12)

and p_{θ_i}(x) denotes the marginal distribution defined implicitly by θ_i and the prior p(z). We denote this lower bound as the MetaELBO, and refer to the VAE with meta-inference as the MetaVAE.

We can rewrite the MetaELBO in a more interpretable form, as in Eqn. 6. Similar to f_φ, our doubly-amortized mapping g_φ can be represented as a conditional distribution, denoted q_φ(z | p_i, x). Then,

    L_i(φ, θ_i) = E_{x ∼ p_i(x)} [ log p_{θ_i}(x) ] − E_{x ∼ p_i(x)} [ KL( q_φ(z | p_i, x) ‖ p_{θ_i}(z|x) ) ]

As in Shu et al. (2018), this MetaELBO has a maximum likelihood term and a regularization term, but now one for each distribution p_i, thereby encouraging the meta-amortized inference model to perform well across distributions sampled from the meta-distribution p_meta.

Finally, we state a property of the MetaELBO: if N = 1 and p_meta places all of its mass on the single marginal p_1(x), then the MetaELBO decomposes to the standard ELBO and the meta-inference model q_φ(z | p_1, x) reduces to the standard amortized model q_φ(z|x).

3.3 Representing the Meta-Inference Model

In Eqn. 12, if we parameterize q_φ(z | p_i, x) as a neural network, it is not clear how to represent a distribution p_i(x) as an input. One of the main insights of this work is to "discretize" the marginal distribution as a finite set of samples, or a dataset. We can use a dataset D_i as a surrogate for p_i(x) and define an "empirical" analogue to q_φ(z | p_i, x), denoted q_φ(z | D_i, x), which maps a dataset with M samples and an observation x to a posterior. Then, there is an equivalent analogue of Eqn. 12 where each marginal p_i(x) is replaced by a dataset D_i.

In practice, for some dataset D_i and observation x, we set q_φ(z | D_i, x) = q_φ(z | h_i, x), where h_i is produced by two networks applied in sequence: a recurrent neural network (RNN) over an arbitrary ordering of the elements in D_i, followed by a two-layer multilayer perceptron (MLP). Each generative model p_{θ_i} is also parameterized by an MLP with an architecture identical to that of the inference network. We refer to the RNN as the summary network and to the MLP as the aggregation network.

3.4 Fully Bayesian VAE

The proposed MetaVAE has an interesting relationship to a fully Bayesian VAE, in which one would explicitly model a posterior distribution over parameters. More precisely, this involves the factorization of the joint:

    p(x, z, θ) = p(θ) p(z) p_θ(x|z)    (13)

where p(θ) is a prior distribution over the parameters. Then, the appropriate inference network would be q(z, θ | x), i.e., an inference model amortized over the family of generative models indexed by θ. If the support of p(θ) is a discrete set, then the fully Bayesian VAE is analogous to a MetaVAE.

In practice, the fully Bayesian VAE is difficult to train because Bayesian neural networks are extremely sensitive to hyperparameter choices and initializations. By discretizing p(θ) to a finite set, we make the optimization problem easier.

3.5 Instantiations of the MetaVAE

The meta-amortized inference procedure is flexible, meaning that it can be instantiated in a variety of ways depending on the probabilistic task. Here we describe two particular instantiations that are used in our experiments.

Figure 1: Plate diagrams for (a) a VAE, (b) a MetaVAE where an observation is a data point x, and (c) a MetaVAE where an observation is a dataset D. Let M be the number of observations.

The first setup (shown in Fig. 1b) is as described in Sec. 3: there exists a meta-inference model that takes as input an observation and a dataset. Unless otherwise stated, we default to this instantiation.

An alternative setup (Fig. 1c) imposes an additional layer of abstraction: a single observation is now a dataset D. The meta-inference model is conditioned on a sequence of datasets whose elements are drawn from p_i(x). This requires a second RNN that ingests a dataset and returns a single hidden vector. After applying the RNNs, the resulting hidden vector and sequence of hidden vectors are analogous to Fig. 1b, where we treat a hidden vector as an observed variable.

4 Related work

There exists a rich body of work on meta-learning, particularly in the supervised learning setting with the goal of rapid adaptation to unseen classification tasks (Ravi and Larochelle, 2016; Santoro et al., 2016; Vinyals et al., 2016; Snell et al., 2017). A popular line of work formulates proper initialization as the workhorse of successful meta-learning (Finn et al., 2017; Grant et al., 2018; Yoon et al., 2018). In many ways, our meta-amortized inference procedure can be thought of as learning a good initialization for an inference model on a new target distribution.

Meta-learning for unsupervised tasks has also been explored by Metz et al. (2018), who learn the weight updates for good representation learning. Several lines of work have tackled the problem of few-shot density estimation, with approaches ranging from attention mechanisms (Rezende et al., 2016) and memory-augmented models (Bornschein et al., 2017) to weight updates for conditional generative models (Reed et al., 2017) and hierarchical models (Edwards and Storkey, 2016; Hewitt et al., 2018). Our architecture shares similarities with both the Neural Statistician (Edwards and Storkey, 2016) and the Variational Homoencoder (Hewitt et al., 2018): we also derive salient features of each dataset with a summary network. Our model's distinguishing factor, then, is the double amortization of the inference procedure over a family of generative models. To the best of our knowledge, this is a novel contribution.

5 Experimental Results

First, we probe the characteristics and generalization ability of the meta-inference model in two synthetic settings, and then we demonstrate its applicability to meta-density estimation using OMNIGLOT.

Figure 2: Let {p_1, …, p_N} be the distributions used in meta-learning. Weak generalization refers to a new p ∼ p_meta with p ∉ {p_1, …, p_N}; strong generalization refers to a new p outside the support of p_meta.

5.1 2D Gaussian Datasets

In this experiment, the set of marginals {p_i(x)} is composed of two-dimensional distributions (e.g., Gaussian) with parameters (e.g., mean and variance) that vary within a fixed range. We amortize over 30 sampled marginals and, consequently, optimize 30 generative models. The meta-distribution is uniform over these 30 marginals. Critically, each generative model p_{θ_i} is parameter-free, thereby encouraging the latent variable z to capture the sufficient statistics of the true distribution p_i(x). Each generative model is also given the correct distribution family that p_i(x) belongs to. However, the meta-inference model is not given any prior knowledge; it is tasked with matching marginals with the correct families. As sufficient statistics only make sense across a set of observations, we use the second setup of the MetaVAE (Fig. 1c). The measure of success is then how closely we can infer the sufficient statistics with no additional training (zero-shot) for (1) unseen distributions from p_meta and (2) unseen distributions outside of p_meta. We refer to (1) as weak generalization and (2) as strong generalization. We first explore members of the exponential family one at a time, then proceed to multiple members of the exponential family at the same time.

Gaussian Marginals In this setting, each distribution is Gaussian with a fixed spherical covariance of 0.1 and a mean sampled uniformly from the [−5, 5] square centered at the origin (cf. Fig. 3). The summary network and aggregation network have 64 hidden dimensions for all layers. To measure weak generalization, we sample new, previously unseen means from the same square. For strong generalization, we sample means from outside it.
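The experimental setup, as we read it, can be sketched as follows (the [−5, 5] bounds are taken from Fig. 3 and the strong-generalization range is a placeholder assumption; exact values in the paper may differ):

```python
import random

def sample_marginal(rng, low=-5.0, high=5.0, std=0.1 ** 0.5):
    """One 2D Gaussian marginal p_i: spherical covariance 0.1, uniform mean."""
    mean = (rng.uniform(low, high), rng.uniform(low, high))
    def sampler(n):
        return [(rng.gauss(mean[0], std), rng.gauss(mean[1], std))
                for _ in range(n)]
    return mean, sampler

rng = random.Random(0)
# Meta-training: 30 marginals drawn from the meta-distribution.
train_marginals = [sample_marginal(rng) for _ in range(30)]
# Weak generalization: a new mean inside the same square.
weak_mean, weak = sample_marginal(rng)
# Strong generalization: a mean outside the square (range is illustrative).
strong_mean, strong = sample_marginal(rng, low=6.0, high=9.0)

data = weak(100)  # a dataset D for the meta-inference model to summarize
print(len(train_marginals), len(data))
```

Success is then measured by how closely the meta-inference model's inferred mean matches `weak_mean` and `strong_mean` without any further training.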

(a) Gaussian (Weak)
(b) Gaussian (Strong)
(c) Generalizability Spectrum
Figure 3: Colored circles represent 30 different marginals p_i; black dots represent the inferred means from the meta-inference model. (a) New distributions sampled from p_meta; (b) new distributions sampled from outside of p_meta. (c) plots the mean squared error between the true mean and the inferred mean as the true mean of p tiles the plane. The green region shows the span of the meta-distribution. The orange line shows a singly-amortized VAE trained on a single p_i with a randomly chosen mean.

We find that the MetaVAE successfully learns the means (the only varying sufficient statistic) of the underlying Gaussians. Interestingly, in Fig. 3a, inference quality decreases as we move closer to the boundary of the meta-distribution (see the purple Gaussian near the edge). In Fig. 3b, we can see that the meta-inference model's predictions remain almost entirely bounded within the [−5, 5] square centered at the origin. Finally, from Fig. 3c, we see that double amortization increases inference quality dramatically over a singly-amortized model, even for distributions far from the meta-distribution.

Log Normal and Exponential Marginals Similar to the above setting, we sample 30 log normal distributions with a fixed spherical covariance of 0.1 and means drawn from a fixed range; for strong generalization, we sample means from outside that range. We also study the exponential distribution by choosing 30 marginals with rates sampled uniformly from a fixed range; to measure strong generalization, we sample rates from outside it.

(a) Log Normal (Weak)
(b) Exponential (Weak)
(c) Log Normal Spectrum
(d) Exponential Spectrum
Figure 4: (a) Comparison of samples from an unseen distribution (red) and samples from the log normal distribution defined by the inferred sufficient statistic (blue). (b) Similar visualization for exponential distributions. (c) and (d) show the mean squared error between the true sufficient statistic and the inferred one (the mean for log normal, the rate for exponential). The orange line is a non-amortized VAE trained on a single randomly chosen distribution.

Many Exponential Families The natural next step is to amortize over many types of exponential families. We sample 30 Gaussian, 30 log normal, and 30 exponential marginals (with the same meta-distributions as above) and train a single meta-inference model. We measure weak and strong generalization as before, but we also measure an even stronger notion of generalization: can we do inference for unseen members of the exponential family?

Fig. 5 compares the performance of our 90-distribution-amortized MetaVAE to three different MetaVAEs, each of which is amortized over 30 distributions from a single exponential family. Fig. 5(a-c) show examples of weak generalization. As expected, the best-performing model on each family is the MetaVAE amortized on distributions only from that family. However, the 90-amortized MetaVAE performs only slightly worse, beating the remaining two models dramatically. Fig. 5(d-f) show results for 2D distributions over (1) Weibull distributions with a fixed scale of 1, (2) Laplace distributions with a fixed location of 0, and (3) Beta distributions with equal shape parameters. Critically, none of these distributions lie in the support of the meta-distribution. We find that the 90-amortized MetaVAE consistently outperforms the single-family baselines. This suggests that doing inference over the exponential families together enables the model to learn more robust representations.

(a) Gaussian
(b) Log Normal
(c) Exponential
(d) Beta(, )
(e) Weibull(shape, scale=1)
(f) Laplace (loc=0, scale)
Figure 5: Comparison of the generalization capabilities of a MetaVAE amortized over three members of the exponential family versus MetaVAEs amortized over only a single member. In each subplot, the blue line represents a meta-inference model trained on 30 Gaussian, 30 exponential, and 30 log normal distributions. The other dotted lines are each associated with a meta-inference model trained on 30 distributions of a single type. Each subplot shows a new unseen distribution drawn either from the meta-distribution (a-c) or from another exponential family entirely (d-f). No additional training was done on the unseen distribution.

5.2 2D Mixtures of Gaussians

Next, we test the MetaVAE's ability to perform clustering and density estimation. A distribution p_i is a mixture of two Gaussians, where each component has fixed isotropic covariance and a mean drawn from a fixed range, and the two components are mixed equally. We assign each mixture component a label of 0 or 1. We then amortize over {10, 30, 50, 100} such generated mixtures and evaluate whether our meta-inference model can successfully: (1) cluster each mixture component; and (2) estimate 3 unseen mixture densities whose means are drawn from ranges at increasing distance from the meta-distribution.

Figure 6: 30 mixtures drawn from the meta-distribution. We plot (in color) 3 unseen distributions whose parameters are drawn from ranges at increasing distance from the meta-distribution (left to right).

We model z as a binary latent variable that denotes mixture component membership. We note that the true clustering is exchangeable, and therefore only recoverable up to a permutation of the labels. We fix the number of hidden units to 10 for all neural networks, and optimize the ELBO using exact enumeration over z.
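Exact enumeration is possible here because z takes only two values, so the marginal and the posterior over components are exact sums rather than Monte Carlo estimates. A minimal 1D stand-in (our simplification of the 2D mixture) looks like this:

```python
import math

def log_normal(x, mean, var):
    """Log density of a univariate Gaussian."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

def log_marginal_and_posterior(x, means, var=0.5):
    """Enumerate the binary latent z in {0, 1} under an equal mixture prior:
    log p(x) = log sum_z p(z) p(x|z), computed via log-sum-exp."""
    log_joints = [math.log(0.5) + log_normal(x, m, var) for m in means]
    m = max(log_joints)
    log_px = m + math.log(sum(math.exp(lj - m) for lj in log_joints))
    posterior = [math.exp(lj - log_px) for lj in log_joints]  # p(z|x)
    return log_px, posterior

log_px, post = log_marginal_and_posterior(x=1.9, means=(-2.0, 2.0))
print(post)  # x = 1.9 is assigned almost entirely to the z = 1 component
```

With the exact posterior available, the ELBO needs no sampling over z, which removes gradient variance from the discrete latent entirely.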

Clustering We first test zero-shot clustering performance on an unseen dataset distributed according to the same underlying meta-distribution. We report average error rates for the zero-shot clustering setup in Table 1.

samples at test time    10      20      30      50
10 GMMs                 0.167   0.212   0.161   0.169
30 GMMs                 0.167   0.158   0.161   0.169
50 GMMs                 0.083   0.099   0.161   0.169
100 GMMs                0.084   0.114   0.161   0.169
Table 1: Zero-shot clustering performance on an unseen dataset. We report clustering error rates averaged over 1000 different sampled datasets, varying the number of samples used by the RNN to generate a summary statistic at test time.

We find that the MetaVAE successfully learns to cluster in the zero-shot case (roughly 8% error for the 50-GMM model). Amortizing over more meta-datasets seems to help improve clustering performance, although the model trained with 100 mixtures shows signs of overfitting to the meta-training set.

Next, we extract the pre-trained meta-inference models and train a new generative network on each of 3 unseen data distributions, evaluating the final clustering performance. We use only {5, 10, 15, 20}% of the test distribution for training. As shown in Fig. 7a, the model is able to weakly generalize across all levels of meta-training, outperforming the VAE baseline with the exception of the 100-GMM meta-encoder, a phenomenon consistent with the results shown in Table 1, i.e., overfitting to the meta-training set. However, Fig. 7(b,c) show that meta-training does not seem to provide significant gains in generalization performance on marginals far from the meta-distribution.

(a) Weak generalization
(b) Generalization
(c) Strong generalization
Figure 7: Each sub-figure shows the final clustering performance after training on {5, 10, 15, 20}% of the unseen data distribution. In (a), meta-training on 10, 30, and 50 datasets allows for perfect clustering, on par with the RNNVAE baseline and outperforming the VAE; the 100-GMM meta-trained model overfits. In (b), only the 50-GMM meta-trained model has successfully learned to cluster. In (c), the meta-clustering algorithm fails to generalize to an extremely out-of-sample distribution.

Density Estimation We repeat the same experimental procedure as before, but evaluate final test ELBOs after training the MetaVAEs on small proportions of the unseen data distribution. As expected, Fig. 8a shows that a meta-amortized model performs density estimation well on an instantiation of the same underlying meta-distribution; the performance improvements are very slight in the intermediate generalization case, and the model fails on completely out-of-distribution samples, as evidenced by Fig. 8(b,c).

Intuitively, these results indicate that the meta-encoder has successfully learned to take as input a dataset (representative of p_i), identify two clusters, and associate a query data point with the closest cluster. However, this clustering "algorithm", learned in an unsupervised way by the meta-inference network, is imperfect (e.g., inferior to k-means), as it shows signs of overfitting to the training meta-distribution.

(a) Weak generalization
(b) Generalization
(c) Strong generalization
Figure 8: Each sub-figure shows the final test ELBOs after training on {5, 10, 15, 20}% of the unseen data distribution. In (a), meta-training on 10, 30, and 50 datasets allows for good generalization performance on an unseen dataset drawn from the same underlying meta-distribution, as compared to baselines trained from scratch. In (b), we perform reasonably well on a slightly out-of-sample dataset. In (c), we cannot generalize to datasets that are extremely out-of-distribution.

5.3 MNIST Clustering and Density Estimation

To further test the MetaVAE, we construct a setup analogous to the mixtures of Gaussians experiment with MNIST (LeCun, 1998). Specifically, we hold out two digit classes for evaluation, and generate datasets comprised of pairs of the remaining digits for training. We select a subset of {5, 10, 20} combinations out of a total of 28 (8 choose 2) possibilities to pre-train the meta-amortized model. We then ask the model to: (1) cluster the two digit classes; and (2) perform density estimation on an unseen target dataset. We switch to a continuous variant of the MetaVAE with 40-dimensional latent variables to better model the complexity of the data.

We consider two scenarios for evaluation. For weak generalization, the new dataset is still drawn from . Concretely, this involves evaluating the MetaVAE on one of the eight remaining combinations (out of 28) that were unseen during training time across all amortized models. For strong generalization, we test clustering and density estimation performance on the pair of digits that were held-out for the entirety of the meta-training phase.

Zero-Shot Clustering: We extract the MetaVAE's latent representations z of the unseen training data, without additional gradient updates, and train a simple logistic regression model on them with the true labels (0/1 for each digit class). Intuitively, logistic regression finds the best linear split between two clusters in the latent space; note that for it to perform well, such a linear split must already exist. In this sense, the clustering is "zero-shot". To measure performance, we obtain the corresponding latent vectors for the test set and predict the true labels.

For weak generalization, Figs. 9a and 9b show the clustering results for two levels of difficulty: digits 1/6 (easy) and 4/9 (hard). For the former, an amortized MetaVAE outperforms a VAE trained on the full dataset of 1's and 6's; however, there does not seem to be much benefit to additional amortization (i.e., amortizing over 5 pairs performs as well as 20 pairs). For the more difficult task, adding more combinations improves clustering performance, and the MetaVAE outperforms a VAE trained on half of the target data. For strong generalization, Fig. 9c shows that the meta-inference network obtains less than 2% clustering error without adapting its encoder parameters to the unseen data distribution. Further, it outperforms a VAE trained on 100% of the target dataset of 3's and 7's. In other words, the meta-inference model has learned representations that allow for good zero-shot clustering performance on a new, unseen dataset.

Density Estimation: For strong generalization, we extract the pre-trained meta-inference network and train a new generative model using {1, 5, 10, 20}% of the target dataset. Fig. 9d shows that we reach a much better test log-likelihood across the board compared to a VAE trained from scratch, and Fig. 9 (e-g) show faster rates of learning for a MetaVAE amortized over 20 combinations of digit pairs as compared to those of a vanilla VAE.

(a) (Weak) Clustering: 1/6
(b) (Weak) Clustering: 4/9
(c) (Strong) Clustering: 3/7
(d) (Strong) Density estimation
(e) 1% of data
(f) 5% of data
(g) 10% of data
Figure 9: Clustering and density estimation on MNIST. We train MetaVAEs amortized over {5, 10, 20} pairs of digit classes and evaluate their performance on unseen pairs from within the meta-distribution (weak) and outside of it (strong). The first row shows that the MetaVAE achieves higher zero-shot clustering performance than a VAE trained on 100% or 50% of a target distribution still within the meta-distribution. The second row shows that the MetaVAE outperforms a VAE trained on 100% of the out-of-sample target distribution, and outperforms a VAE across the board for few-shot density estimation. The last row compares test log-likelihoods of the MetaVAE amortized over 20 combinations against those of a vanilla VAE. Colored lines denote MetaVAE models; black lines denote vanilla VAE models trained on the target data.

5.4 OMNIGLOT Transfer Density Estimation

To showcase meta-amortized inference beyond synthetic settings, we consider the challenge of transfer density estimation on OMNIGLOT.¹ We reserve the first 25 OMNIGLOT alphabets for a pre-training phase in which we optimize a MetaVAE amortized over linear combinations of the 25 alphabets. Precisely, we sample combination vectors, each summing to one, from a Dirichlet distribution over 25 categories. Each entry of a combination vector specifies the probability of sampling uniformly from that particular alphabet. Intuitively, we can think of the meta-distribution as constructing new alphabets. For each new alphabet, we sample 5000 times to create a dataset. The purpose of generating new alphabets in this manner is to expand the training dataset size, as some of the original OMNIGLOT alphabets are extremely small and cannot possibly yield good density estimation. The MetaVAE is then amortized across the sampled alphabet distributions; in practice, we use 50 of them (see Fig. 10).

¹We use the pre-processed version from Burda et al. (2015): https://github.com/yburda/iwae/blob/master/datasets/OMNIGLOT/chardata.mat
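The alphabet-mixing scheme can be sketched as below. The Dirichlet concentration value is our assumption (the paper does not state it in this excerpt); the Dirichlet draw itself uses the standard normalized-Gamma construction:

```python
import random

def sample_combination(rng, k=25, alpha=1.0):
    """One combination vector over k alphabets: a Dirichlet sample built
    from normalized Gamma draws. alpha = 1.0 is an illustrative choice."""
    gammas = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(gammas)
    return [g / total for g in gammas]

def sample_alphabet_index(rng, weights):
    """Pick which alphabet the next character is drawn (uniformly) from."""
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]

rng = random.Random(0)
w = sample_combination(rng)  # defines one new "alphabet" distribution
# Build the 5000-example dataset for this synthetic alphabet.
dataset_indices = [sample_alphabet_index(rng, w) for _ in range(5000)]
print(len(w), len(dataset_indices))
```

Each combination vector thus defines a new marginal over characters, and repeating the procedure 50 times yields the set of distributions the MetaVAE is amortized over.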

For transfer learning, we consider (case 1) an unseen distribution obtained by sampling a new combination vector over the 25 alphabets used in pre-training, and (case 2) an unseen distribution obtained by sampling a new combination vector over the remaining 35 alphabets not used in pre-training. For an unseen distribution, we initialize the inference network from the pre-trained meta-inference model and initialize the generative network from scratch. We allow the model to train to completion but limit the number of examples from the unseen distribution to a small number (making this few-shot). In our experiments, we limit this to 100 or 10 examples.
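The warm-start protocol described above can be sketched schematically. In this sketch, parameters are stand-in NumPy arrays rather than real networks, and all names are hypothetical:

```python
import numpy as np

def fresh_params(dim, seed):
    """Stand-in for randomly initialized network weights."""
    return np.random.default_rng(seed).normal(size=dim)

# Pretend these weights come from the meta-amortized pre-training phase.
meta_inference_params = fresh_params(64, seed=42)

def transfer_setup(meta_params, target_data, num_shots=100):
    """Warm-start the inference network (encoder) from the meta-inference
    model; the generative network (decoder) starts from scratch. Only
    num_shots target examples are kept, making the setup few-shot."""
    encoder = meta_params.copy()        # initialized from pre-training
    decoder = fresh_params(64, seed=7)  # random re-initialization
    few_shot_data = target_data[:num_shots]  # 100 (or 10) examples
    return encoder, decoder, few_shot_data
```

Training then proceeds to completion on `few_shot_data` alone, with only the encoder retaining its meta-trained starting point.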

Fig. 10 shows nine experiments: the first three in case 1, the second three in case 2, and the last three in case 2 but limited to 10 training examples. As expected, meta-amortizing reaches a higher log marginal (and does so faster) on new distributions within the meta-distribution. Moreover, we observe worse performance in Fig. 10d, as there are few guarantees for distributions outside the meta-distribution. What is surprising, however, is that meta-amortizing reaches a higher log marginal for the other two distributions composed of entirely new alphabets. This implies that the initialization must be useful for other alphabets, as most characters share similarities in design.

Finally, if we restrict the number of examples available for transfer to 10, we find that meta-amortizing is still effective (although less so). A MetaVAE and a vanilla VAE (trained on only the 10 examples) converge to the same final log likelihood, but the amortized model reaches the optimum faster. Interestingly, the distribution that did not benefit from meta-amortization with 100 examples now does, suggesting a relationship between relying on prior (amortized) knowledge and on observed data.

Figure 10: Transfer density estimation on OMNIGLOT. We train a MetaVAE amortized over 50 mixtures of 25 OMNIGLOT alphabets. This figure shows test log likelihoods from training independently on three unseen mixtures of the pre-training alphabets and three unseen mixtures of held-out OMNIGLOT alphabets. For each distribution, we let the model train on only 100 examples (and only 10 examples for the last row of the figure). We train for 100 epochs for experiments with 100 examples and 50 epochs for those with 10 examples. Subfigures (a-f) are plotted from epoch 10. Pink and red lines denote MetaVAE models; black lines denote vanilla VAE models trained on only the unseen distribution.

6 Discussion

Confronted with the theoretical hardness results for exact and approximate inference, the idea of learning approximate inference strategies that are not fully general but tailored to “typical” models of the world is appealing. Our meta-amortization method is particularly useful when we have a set of models with shared structure, and we want to leverage that structure for good few-shot generalization performance on a related target task. We now mention a few observations:

The interesting generalization behavior lies between weak and strong generalization. With a sufficiently large number of training distributions, we find that meta-amortization leads to weak generalization for distributions within the meta-distribution. This is sensible, as the training distributions build a convex hull that can span the meta-distribution. In most cases, we cannot hope for strong generalization, as distributions outside the meta-distribution can be wildly different. It is the distributions near but just outside the edge of the meta-distribution (e.g., in Fig. 2) where our method demonstrates promise.

Meta-amortization is subject to meta-overfitting. We observe a form of overfitting unique to doubly-amortized inference, which we call meta-overfitting. It can occur in two ways. In the first scenario, we fail to sufficiently cover the space of the meta-distribution by amortizing over too few datasets. When this happens, the meta-inference model fails to learn the algorithm of interest (e.g., density estimation) and cannot generalize even to other distributions within the meta-distribution. In the second scenario, the meta-inference model is trained on so many marginal distributions that, with its limited number of parameters, it fails to capture the correct marginal for any single generative model; intuitively, it has overfit to the underlying meta-distribution. This is exemplified by the inability of the MetaVAE pre-trained on 100 GMMs to cluster in the 2D Gaussian mixture experiments.

7 Conclusion

We introduce meta-amortized variational inference and learning. To the best of our knowledge, this is the first instance of amortizing inference over families of generative models. We observe promising results on density estimation and representation learning, where meta-training improves sample complexity, and we hope to explore meta-inference for zero-shot compilation of probabilistic programs.

References

  • D. M. Blei, A. Kucukelbir, and J. D. McAuliffe (2017) Variational inference: a review for statisticians. Journal of the American Statistical Association 112 (518), pp. 859–877. Cited by: §1, §2.1.
  • J. Bornschein, A. Mnih, D. Zoran, and D. J. Rezende (2017) Variational memory addressing in generative models. In Advances in Neural Information Processing Systems, pp. 3920–3929. Cited by: §4.
  • Y. Burda, R. Grosse, and R. Salakhutdinov (2015) Importance weighted autoencoders. arXiv preprint arXiv:1509.00519. Cited by: footnote 1.
  • H. Edwards and A. Storkey (2016) Towards a neural statistician. arXiv preprint arXiv:1606.02185. Cited by: §4.
  • A. E. Gelfand and A. F. Smith (1990) Sampling-based approaches to calculating marginal densities. Journal of the American statistical association 85 (410), pp. 398–409. Cited by: §2.1.
  • S. Gershman and N. Goodman (2014) Amortized inference in probabilistic reasoning. In Proceedings of the Annual Meeting of the Cognitive Science Society, Vol. 36. Cited by: §1.
  • E. Grant, C. Finn, S. Levine, T. Darrell, and T. Griffiths (2018) Recasting gradient-based meta-learning as hierarchical bayes. arXiv preprint arXiv:1801.08930. Cited by: §4.
  • W. K. Hastings (1970) Monte carlo sampling methods using markov chains and their applications. Cited by: §2.1.
  • L. B. Hewitt, M. I. Nye, A. Gane, T. Jaakkola, and J. B. Tenenbaum (2018) The variational homoencoder: learning to learn high capacity generative models from few examples. arXiv preprint arXiv:1807.08919. Cited by: §4.
  • M. I. Jordan, Z. Ghahramani, T. S. Jaakkola, and L. K. Saul (1999) An introduction to variational methods for graphical models. Machine learning 37 (2), pp. 183–233. Cited by: §1, §2.1.
  • D. P. Kingma and M. Welling (2013) Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114. Cited by: §1.
  • Y. LeCun (1998) The MNIST database of handwritten digits. http://yann.lecun.com/exdb/mnist/. Cited by: §5.3.
  • L. Metz, N. Maheswaranathan, B. Cheung, and J. Sohl-Dickstein (2018) Learning unsupervised learning rules. arXiv preprint arXiv:1804.00222. Cited by: §4.
  • S. Ravi and H. Larochelle (2016) Optimization as a model for few-shot learning. Cited by: §4.
  • S. Reed, Y. Chen, T. Paine, A. v. d. Oord, S. Eslami, D. Rezende, O. Vinyals, and N. de Freitas (2017) Few-shot autoregressive density estimation: towards learning to learn distributions. arXiv preprint arXiv:1710.10304. Cited by: §4.
  • D. J. Rezende, S. Mohamed, I. Danihelka, K. Gregor, and D. Wierstra (2016) One-shot generalization in deep generative models. arXiv preprint arXiv:1603.05106. Cited by: §4.
  • D. J. Rezende, S. Mohamed, and D. Wierstra (2014) Stochastic backpropagation and approximate inference in deep generative models. arXiv preprint arXiv:1401.4082. Cited by: §1.
  • A. Santoro, S. Bartunov, M. Botvinick, D. Wierstra, and T. Lillicrap (2016) One-shot learning with memory-augmented neural networks. arXiv preprint arXiv:1605.06065. Cited by: §4.
  • R. Shu, H. H. Bui, S. Zhao, M. J. Kochenderfer, and S. Ermon (2018) Amortized inference regularization. arXiv preprint arXiv:1805.08913. Cited by: §2.3, §3.2.
  • J. Snell, K. Swersky, and R. Zemel (2017) Prototypical networks for few-shot learning. In Advances in Neural Information Processing Systems, pp. 4077–4087. Cited by: §4.
  • O. Vinyals, C. Blundell, T. Lillicrap, D. Wierstra, et al. (2016) Matching networks for one shot learning. In Advances in Neural Information Processing Systems, pp. 3630–3638. Cited by: §4.
  • M. J. Wainwright, M. I. Jordan, et al. (2008) Graphical models, exponential families, and variational inference. Foundations and Trends® in Machine Learning 1 (1–2), pp. 1–305. Cited by: §1, §2.1.
  • J. Yoon, T. Kim, O. Dia, S. Kim, Y. Bengio, and S. Ahn (2018) Bayesian model-agnostic meta-learning. In Advances in Neural Information Processing Systems, pp. 7343–7353. Cited by: §4.