Learning to learn generative programs with Memoised Wake-Sleep

07/06/2020 · by Luke B. Hewitt et al. (MIT)

We study a class of neuro-symbolic generative models in which neural networks are used both for inference and as priors over symbolic, data-generating programs. As generative models, these programs capture compositional structures in a naturally explainable form. To tackle the challenge of performing program induction as an 'inner-loop' to learning, we propose the Memoised Wake-Sleep (MWS) algorithm, which extends Wake Sleep by explicitly storing and reusing the best programs discovered by the inference network throughout training. We use MWS to learn accurate, explainable models in three challenging domains: stroke-based character modelling, cellular automata, and few-shot learning in a novel dataset of real-world string concepts.




1 Introduction

From the phonemes that make up a word to the nested goals and subgoals that make up a plan, many of our models of the world rely on symbolic structures such as categories, objects, and composition. Such explicit representations are desirable not only for interpretability, but also because models that use them are often highly flexible and robust. For example, in spreadsheet editing, FlashFill uses program inference to allow specification of batch operations by example (Gulwani et al., 2015). In character recognition, the stroke-based model of Lake et al. (2015) remains state-of-the-art at both few-shot classification and generation, despite competition from a variety of neural models (Lake et al., 2019).

In this work we focus on structured generative modelling: we aim to find symbolic generative programs to describe a set of observations, while also learning a prior over these programs and fitting any continuous model parameters. For example, we model handwritten characters by composing a sequence of strokes, drawn from a finite bank of stroke types which is itself learned.

Unfortunately, learning such models from scratch is a substantial challenge. A major barrier is the difficulty of search: discovering a latent program for any given observation is challenging due to the size of the space and sparsity of solutions. Furthermore, this inference must be revised at every iteration of learning.

We build on the Helmholtz Machine (Dayan et al., 1995), a longstanding approach to learning in which two models are trained together: a generative model p_θ(z, x) learns a joint distribution over latent and observed variables, while a recognition network q_φ(z|x) performs fast inference. This approach, including more recent variants such as VAEs (Kingma & Welling, 2014), is well-suited to learning neural generative models because, as noted by Hinton et al. (1995), "the algorithm adapts the generative weights" so as to make the true posterior close to the recognition distribution. That is, when the generative model is a neural network, the semantics of the latent space are highly unconstrained, and so can be learned to aid fast recognition.

Unlike such purely neural generative models, the models we consider have a more constrained and interpretable latent space. We take the latent variable z to be a sequence of discrete tokens representing a data-generating program. Our goal is to learn a prior p(z) over programs (which may be a neural network such as an LSTM), alongside parameters of a symbolic program evaluator p(x|z) and a program recognition network q(z|x). Figure 2 describes such a model for string concepts, in which z is a regular expression and p(x|z) is a symbolic regex parser.

Figure 2: A. In Memoised Wake-Sleep, we train the generative model using samples from a finite memory set containing the best programs found for each observation x. We use a recognition network q(z|x) to propose updates to this set. B. For our regex model, each observation x is a set of strings generated by a latent regex z. As a recognition network, we use a program-synthesis LSTM (RobustFill, Devlin et al. (2017)) to propose regexes for each set.

In this setting, training a network for fast and accurate inference is ambitious: even state-of-the-art program induction networks often must guess-and-check thousands of candidates. This is impractical for Helmholtz machine algorithms, which require inference to be repeated at each gradient step. We propose a novel algorithm, Memoised Wake-Sleep (MWS), which is better suited to the structured domains we study. Rather than restarting inference afresh at each iteration, MWS maintains for each instance a finite set containing the best programs discovered by the recognition network, remembering and reusing these to train the generative model in future iterations.

Our contributions in this paper are as follows. We first outline the MWS algorithm, and prove that it optimises a variational bound on the data likelihood in which the variational distribution has finite support. We then illustrate MWS with a simple nonparametric Gaussian mixture model and evaluate on three structure-learning domains (Fig. 1), finding that it greatly outperforms the more computationally intensive approaches of Reweighted Wake-Sleep (Bornschein & Bengio, 2015) and VIMCO (Mnih & Rezende, 2016), while often providing a significant speedup. We also develop a novel String-Concepts dataset, collected from publicly available spreadsheets, for our evaluation. This dataset contains 1500 few-shot learning problems, each involving a real-world string concept (such as date or email) to be inferred from a small set of examples.

2 Background

The Helmholtz Machine (Dayan et al., 1995) is a framework for learning generative models, in which a recognition network is used to provide fast inference during training. Formally, suppose we wish to learn a generative model p_θ(z, x), which is a joint distribution over latents z and observations x, and a recognition network q_φ(z|x), which approximates the posterior over latents given observations. The marginal likelihood of each observation is bounded by:

log p_θ(x) ≥ E_{z ∼ q_φ(z|x)}[log p_θ(z, x) − log q_φ(z|x)] = log p_θ(x) − D_KL(q_φ(z|x) ∥ p_θ(z|x))

where D_KL is the KL divergence from the true posterior p_θ(z|x) to the approximate posterior q_φ(z|x). Learning is typically framed as maximisation of this evidence lower bound (ELBO) by training the recognition network and generative model together.

Gradient-based maximisation of this objective with respect to the generative parameters θ is straightforward: an unbiased gradient estimate can be created by taking a single sample z ∼ q_φ(z|x) at each gradient step. However, maximising the bound with respect to the recognition parameters φ is more challenging, and two main approaches exist:

VAE. We may update φ using an unbiased estimate of the same bound, sampling z ∼ q_φ(z|x). However, if z is a discrete symbolic expression, then estimating this gradient requires the REINFORCE estimator (Williams, 1992; Mnih & Gregor, 2014). Despite advances in control-variate techniques, this estimator often suffers from high variance, which may lead to impractically slow training.
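To make the variance issue concrete, the following toy sketch (our illustration, not the paper's code) estimates the gradient of E_{z∼q}[f(z)] for a small categorical q using the score-function (REINFORCE) identity, and compares it against the analytic gradient:

```python
import numpy as np

rng = np.random.default_rng(0)

# Categorical q over 3 symbolic choices, parameterised by logits phi.
phi = np.array([0.2, -0.1, 0.5])
f = np.array([1.0, 3.0, -2.0])  # "reward" f(z) for each choice

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

def reinforce_grad(n_samples):
    """Score-function estimate of d/dphi E_q[f(z)]:
    average of f(z) * grad log q(z) over samples z ~ q."""
    q = softmax(phi)
    z = rng.choice(3, size=n_samples, p=q)
    grad = np.zeros_like(phi)
    for zi in z:
        # grad_phi log q(z) for a softmax categorical is one_hot(z) - q
        score = -q.copy()
        score[zi] += 1.0
        grad += f[zi] * score
    return grad / n_samples

# Analytic gradient for comparison: d/dphi_j sum_z q(z) f(z) = q_j (f_j - E[f])
q = softmax(phi)
analytic = q * (f - (q * f).sum())

estimate = reinforce_grad(200_000)
print(np.abs(estimate - analytic).max())  # small on average, high variance per sample
```

Even in this three-way toy problem, many samples are needed for a stable estimate; with compositional program spaces the variance problem is far more severe.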

[1] Many approaches for training discrete VAEs are inapplicable to the models we study here. Rolfe (2017) constructs architectures for which discrete variables can be marginalised out, while relaxation techniques (Jang et al., 2017) approximate discrete variables continuously to produce low-variance gradient estimators or control variates (Tucker et al., 2017; Grathwohl et al., 2018). These methods are intractable for the compositional symbolic models we consider, as they require exponentially many path evaluations (Le et al., 2019). The EC algorithm (Ellis et al., 2018) enables inference in such compositional models, but does not learn model parameters.

Wake-Sleep. Instead of minimising the KL term that appears in the bound above, we may update φ approximately by minimising the reversed KL divergence, D_KL(p_θ(z|x) ∥ q_φ(z|x)). In practice, this means updating q_φ at each iteration using data sampled from the model's prior p_θ(z). This yields an algorithm which is not in general convergent, yet often performs competitively even when z is discrete.

Beyond the optimisation difficulties that come with discrete latent variable modelling, a further challenge arises when the recognition model is simply incapable of matching the true posterior accurately. This is common even in deep generative models, which can flexibly adapt their latent representation. To address this, the above approaches may be extended by taking multiple samples from the recognition model at each training iteration, then using importance weighting to estimate the true posterior. For VAEs, this yields the Importance Weighted Autoencoder (IWAE, Burda et al. (2016)), and is often applied to discrete variables using multiple samples for variance reduction (VIMCO, Mnih & Rezende (2016)). For Wake-Sleep, it yields Reweighted Wake-Sleep (RWS, Bornschein & Bengio (2015)), in which the recognition model may be trained either from the generative model (RWS-sleep) or from the importance-weighted posterior (RWS-wake).
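The importance-weighting step shared by these methods can be sketched as follows (our illustration: weights w_k ∝ p(x, z_k)/q(z_k|x), self-normalised stably in log space; the numeric values are hypothetical):

```python
import numpy as np

def normalised_importance_weights(log_p_joint, log_q):
    """Self-normalised weights w_k proportional to p(x, z_k) / q(z_k | x),
    computed stably by subtracting the max log-weight before exponentiating."""
    log_w = np.asarray(log_p_joint) - np.asarray(log_q)
    log_w -= log_w.max()          # stabilise against overflow/underflow
    w = np.exp(log_w)
    return w / w.sum()

# K = 4 hypothetical samples from the recognition network
log_p = [-10.0, -12.5, -9.0, -30.0]   # log p(x, z_k) under the generative model
log_q = [-1.2, -0.7, -1.5, -0.3]      # log q(z_k | x) under the recognition network
w = normalised_importance_weights(log_p, log_q)
print(w.round(3))  # most mass falls on the best-explaining sample
```

When good programs are sparse, most of the K samples receive negligible weight, which is exactly the waste MWS is designed to avoid.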

3 Memoised Wake-Sleep

Our goal is learning and inference in rich neurosymbolic models such as that shown in Figure 2, for which all parameters are continuous and the latent variables are symbolic programs. These models pose a challenge for Helmholtz machines: given the strong constraints on p(x|z), it is common that only a small set of latent programs can well explain any given observation x, and these may be difficult for the recognition network to find quickly and reliably. The importance-weighted methods described above (RWS, VIMCO) may therefore require evaluating very many samples per iteration to train the model. This is computationally wasteful, as it amounts to re-solving the same hard search problems repeatedly.

We propose an alternative approach which turns the sparsity of good solutions to its advantage. In the Memoised Wake-Sleep algorithm we do not discard the result of inference after each training step. Instead, for each observation x we introduce a memory containing a set of the best distinct historical samples from the recognition model. Formally, we treat this memory as a variational distribution Q over programs, with finite support {z_1, …, z_M} and corresponding probabilities q_1, …, q_M. In the box below, we prove two statements which suggest a simple algorithm for updating Q, maximising the ELBO by minimising the KL divergence from Q to the true posterior:
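Written out explicitly (our reconstruction of the dropped display equation, in the notation above), the quantity being minimised is:

```latex
\mathcal{L}(Q)
  = D_{\mathrm{KL}}\!\left(Q \,\|\, p_\theta(z \mid x)\right)
  = \sum_{i=1}^{M} q_i \log \frac{q_i}{p_\theta(x, z_i)} + \log p_\theta(x)
```

Since log p_θ(x) does not depend on Q, minimising this loss only requires comparing the joint probabilities p_θ(x, z_i) of the programs in memory.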

Claim 1. Fixing the support of Q to {z_1, …, z_M}, the optimal weights are given by q_i ∝ p_θ(x, z_i). Proof. At optimality, there can be no pair (i, j) for which the loss is reduced by moving probability mass from q_i to q_j. Equating the corresponding partial derivatives of the loss yields q_i ∝ p_θ(x, z_i).
Claim 2. Fixing the weights q_i, we decrease the loss if we replace any z_i with a new value z′, such that p_θ(x, z′) > p_θ(x, z_i) and z′ is not already in the support of Q. Proof. Rewriting the loss, we see that the only dependence on z_i is through the term −q_i log p_θ(x, z_i). The replacement therefore strictly decreases the loss.

Repeated application of Claims 2 and 1 yields an intuitive algorithm for optimising Q. At every iteration, we sample a set of new programs from the recognition network and compare these to the programs already in memory. We then update the memory to contain the best unique elements from either the sampled programs or the existing memory elements, ranked by the joint probability p_θ(x, z). Finally, we resample a program from memory to train the generative model.
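The memory update just described can be sketched compactly (our illustration: programs here are strings scored by a hypothetical log-joint function standing in for log p_θ(x, z)):

```python
import numpy as np

def update_memory(memory, proposals, log_joint, M):
    """One MWS wake-phase update for a single observation x:
    keep the M unique programs with the highest log p(x, z),
    drawn from the union of current memory and new proposals (Claim 2)."""
    candidates = list(dict.fromkeys(memory + proposals))  # unique, order-preserving
    candidates.sort(key=log_joint, reverse=True)
    return candidates[:M]

def sample_from_memory(memory, log_joint, rng):
    """Resample a program for training the generative model,
    with weights q_i proportional to p(x, z_i) (Claim 1)."""
    logw = np.array([log_joint(z) for z in memory], dtype=float)
    w = np.exp(logw - logw.max())
    w /= w.sum()
    return memory[rng.choice(len(memory), p=w)]

# Toy example: regex-like programs scored by a hypothetical log-joint.
scores = {"ab": -1.0, "a+b": -0.5, "a*b": -3.0, "a|b": -0.2}
log_joint = lambda z: scores[z]
rng = np.random.default_rng(0)

mem = update_memory(["ab", "a*b"], ["a+b", "a|b", "ab"], log_joint, M=2)
print(mem)  # ['a|b', 'a+b']
```

Note that search effort is never discarded: a good program found once stays in memory until displaced by a strictly better one.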


To train the recognition network, we propose two variants of our algorithm. In MWS-fantasy, we train on pairs (z, x) sampled directly from the generative model, as in the sleep phase of the wake-sleep algorithm. In MWS-replay, we train on the same pair that was sampled from memory to train the generative model (analogous to RWS-wake, Bornschein & Bengio (2015)). In practice we find that the latter performs well, and is significantly faster as it requires no additional sampling. In this paper, we therefore refer to MWS-replay and RWS-wake as simply MWS and RWS, but include additional results for MWS-fantasy and RWS-sleep in the appendix.

The three phases of the algorithm (wake, sleep:replay and sleep:fantasy) are summarised in Figure 3, and the full algorithm is provided above. Unlike RWS and VIMCO, the memory usage of Memoised Wake-Sleep contains a term linear in the dataset size, due to maintaining a separate set of programs for each training instance. However, in practice this is typically negligible compared to the reduction in memory required for training the recognition network: MWS can achieve strong performance with many fewer recognition samples per iteration.

The memory size M may be chosen to trade off accuracy against efficiency, with M = 1 corresponding to MAP inference, and larger M approaching full Bayesian inference over programs. For modest values of M, achieving a small variational gap relies on sparsity in the true posterior, as Q will converge to the best M-support approximation of the posterior.[2]

[2] See appendix for more discussion of MWS limiting behaviour, including an empirical study of sparsity in the posterior.

In this paper, we use the same default memory size and recognition sample count across experiments, and report the total number of likelihood evaluations per iteration when presenting results. This means MWS is matched to baseline algorithms on the number of generative-model evaluations per iteration, but requires half as many recognition model evaluations.

Figure 3: MWS extends Wake-Sleep with a separate memory for each observation x. This memory is a discrete distribution Q defined over a finite set of programs. Each phase of MWS uses a sample from one model to update another: during wake, the recognition network samples a program z and, if p_θ(x, z) is large, z is stored in memory. During sleep:replay, z is sampled from memory and used to train both the generative model and the recognition network. Alternatively, the recognition network may be trained by sampling from the generative model (in a sleep:fantasy phase).
Figure 4: Quantitatively (top), MWS outperforms RWS and VIMCO in terms of both speed of convergence and quality of the learned model and inference, learning near-perfect model and inference parameters with only a few particles. Qualitatively (bottom), the neural amortized inference of RWS and VIMCO fails to cluster accurately. The model compensates by increasing the within-cluster variance, as seen in the more spread out posterior predictive (gray).

4 Experiments

4.1 Gaussian Mixture Model

We first validate the MWS algorithm for learning and inference in a simple nonparametric Gaussian mixture model, for which we can evaluate model performance exactly. In this model, the latent variable corresponds to a clustering of the datapoints. MWS is well-suited here because the latent space is discrete and exponentially large in the number of data points, while the true posterior is highly peaked on a few clusterings of the data.

We generated a synthetic dataset of 100 mini-datasets, each comprising a small set of two-dimensional data points, as illustrated in the bottom-left of Fig. 4. The latent variable for each mini-dataset is a sequence of cluster assignments, together with a mean vector for each cluster. The only learnable model parameter is the scale that parameterizes the cluster covariance.

We use a Chinese restaurant process (CRP) prior for the cluster assignments, in order to break the permutation invariance of clustering and to avoid fixing the number of clusters. The full generative model is a CRP mixture of Gaussians with learnable cluster covariance.
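The display equation for the full generative model was lost in extraction; a standard CRP Gaussian mixture of the kind described here (our reconstruction, with concentration α and covariance scale σ as assumed symbols) reads:

```latex
\begin{align}
z_{1:N} &\sim \mathrm{CRP}(\alpha) \\
\mu_c &\sim \mathcal{N}(0,\, I) \quad \text{for each cluster } c \\
x_n \mid z, \mu &\sim \mathcal{N}\!\left(\mu_{z_n},\, \sigma^2 I\right)
\end{align}
```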

In this model, the cluster means may be marginalized out analytically, allowing us to exactly calculate the likelihood during learning. For the recognition model we use a feedforward neural network with one hidden layer, whose output logits are masked to enforce valid assignment sequences under the CRP prior.

We train the model using Adam with default hyperparameters, and evaluate model quality using the average negative log marginal likelihood. For inference quality, we evaluate the KL divergence from the posterior under the true data-generating model to the approximate posterior: the memory-induced approximation for MWS, and the importance-weighted recognition-based approximation for RWS and VIMCO. In Figure 4 (top), we show medians and inter-quartile ranges of the marginal likelihood and this KL across training runs.

With a moderate number of particles, both RWS and VIMCO fail to cluster the data accurately using the recognition network. By contrast, MWS can maintain a persistent, high-quality approximation to the true posterior for each mini-dataset. This discrepancy in inference quality is shown by the sample clusterings in Figure 4 (bottom). In turn, the use of inaccurate inference during training causes RWS and VIMCO to converge to a model with poorer marginal likelihood.

Figure 5: A. Schematic for our stroke-based model of handwritten characters. The prior p(z) is a distribution over token sequences z, where each token points to a particular stroke from a finite set. The available strokes are learned parameters of the renderer, varying in length, direction, thickness, sharpness and curvature. Strokes chosen in z are placed end to end on a canvas (optionally marked as OFF for pen movement without drawing). B. Reconstructions of Omniglot characters. In each row we sample a program z, and visualise the canvas of the renderer at each step.

4.2 Drawing Handwritten Characters

Next, we build a generative model of handwritten characters using an explicit stroke-based renderer. Drawing inspiration from Lake et al. (2015), our model contains a finite bank of learnable stroke types, varying in parameters such as length, thickness, direction and curvature. Each latent variable is a sequence of integers which index into this bank. For generation, the renderer places the corresponding strokes sequentially onto a canvas, and is differentiable with respect to the stroke parameters (Fig. 5A). To calculate the likelihood p(x|z), we use this canvas to provide Bernoulli pixel probabilities, marginalising across a set of affine transforms in order to allow programs to be position invariant. The prior and the recognition network are LSTMs; the recognition network additionally takes as input an image embedding given by a convolutional network.
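The Bernoulli pixel likelihood can be sketched as follows (our illustration only; the paper additionally marginalises over affine transforms, which we omit here):

```python
import numpy as np

def bernoulli_log_likelihood(image, canvas, eps=1e-6):
    """log p(x | z): the rendered canvas gives per-pixel Bernoulli
    probabilities for the binary image x. Clipping avoids log(0)."""
    c = np.clip(canvas, eps, 1.0 - eps)
    return float(np.sum(image * np.log(c) + (1 - image) * np.log(1 - c)))

# Toy 2x2 example: a binary image scored against rendered pixel probabilities.
image = np.array([[1, 0], [1, 1]], dtype=float)
canvas = np.array([[0.9, 0.1], [0.8, 0.7]])
print(bernoulli_log_likelihood(image, canvas))
```

Because the canvas is differentiable in the stroke parameters, this log-likelihood provides gradients for learning the stroke bank directly from raw images.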

In Fig. 5B, we visualise the stroke sequences inferred by our model after training on a random subset of characters from the Omniglot dataset (approximately 10 characters per alphabet across 50 alphabets). For each character, we sample a program z from the memory of the MWS algorithm, and visualise the render canvas at each step of z. We find that our model is able to accurately reconstruct a wide variety of characters, and does so using a natural sequencing of pen strokes. In Table 1, we compare the performance of baseline algorithms applied to the same model. We find that MWS is able to learn effectively with very few particles, yet continues to outperform alternative algorithms even as the number of particles increases.

RWS 0.363 0.348 0.333 0.324 0.322
VIMCO 0.361 0.333 0.326 0.318 0.319
MWS 0.311 0.305 0.321 0.310 0.316
Table 1: Marginal NLL (bits/pixel, avg. of 3 runs). Columns correspond to increasing numbers of particles.

Our approach combines the strengths of previous work on modelling handwritten characters. Like Ganin et al. (2018), we learn only from raw images, aided by a neural recognition model. However, like Lake et al. (2015), we use a symbolic representation of characters: our model uses a limited symbolic vocabulary of 64 strokes rather than producing a free-form stroke at each time step, and we restrict each character to a maximum of only 10 strokes. This provides an inductive bias that should encourage reuse of strokes across characters, potentially allowing our model to make richer generalisations.

To illustrate this, we extend our model by conditioning the prior and the recognition model on the alphabet label, which we provide during training. Given the ten characters from an alphabet (red), this conditioned model can generate novel samples which somewhat capture its high-level style by reusing common patterns (Fig. 6C).

Figure 6: A. Samples from the unconditional model. B and C. Samples from the alphabet-conditional model, for both instance reconstruction and novel character generation.
Figure 7: Character classes (left) and operators (centre) included in our probabilistic regex model. Character classes: . (any character), \w (alphanumeric character), \d (digit), \u (uppercase character), \l (lowercase character), \s (whitespace character), each with learnable probabilities over its allowed characters. Operators: optional subexpression E?, repetition E* and E+ (with E+ equivalent to EE*), and either/or E|E, each with learnable production probabilities. These parameters determine the probability of a regex producing any given string, which can be calculated exactly by dynamic programming. Right: given five example strings, the model finds a plausible regex explanation which can be used to generate novel instances. The inferred repeating subexpression (, \u)* is highlighted in green.

4.3 Structured Text Concepts

We next apply MWS to learning short text concepts, such as date or email address, from a few example strings. This task is of interest because such concepts often have a highly compositional and interpretable structure, while each is itself a fairly rich generative model which can be used to generate new strings.

For this task we created a new dataset comprising 1500 concepts, each with 5 training strings and 5 test strings, collected by random sampling of spreadsheet columns crawled from public GitHub repositories. The data was filtered to remove columns that contain only numbers, English words longer than 5 characters, or common names (see Figure 1 and Appendix for dataset samples).

We aim to model this dataset by inferring a regular expression (regex) for each concept. This is a convenient choice because regexes can naturally express compositional relationships, and can be evaluated efficiently on any given string. Specifically, we consider probabilistic regexes: programs which generate strings according to a distribution, and for which the probability of any given string can be calculated exactly and efficiently.

The full model we develop for this domain is shown in Fig. 2. We use an LSTM prior over regexes, a program-synthesis LSTM network to infer regexes from strings (RobustFill, Devlin et al. (2017)), and a symbolic regex evaluator. The prior and recognition networks output a sequence of regex tokens, including characters (4, 7, etc.), character classes (\d for digit, \u for uppercase, etc.), operators (* for repetition, etc.) and brackets. In the regex evaluator, learnable parameters determine the assignment of probability to strings: for example, when * appears, the number of repeats is geometrically distributed with a learned parameter. The full set of parameters is shown in Figure 7.
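The generative semantics of such probabilistic regexes can be sketched with a toy interpreter (our illustration: the token set is a small subset of the model's vocabulary, and the continuation probability p_continue stands in for a learned parameter θ):

```python
import random
import string

rng = random.Random(0)

def sample_class(tok):
    """Sample one character from a character class; literal tokens map to themselves."""
    pools = {r"\d": string.digits,
             r"\u": string.ascii_uppercase,
             r"\l": string.ascii_lowercase}
    return rng.choice(pools.get(tok, tok))

def sample_repeat(tok, p_continue=0.4):
    """E*: the number of repeats is geometrically distributed,
    with a (learnable) continuation parameter."""
    out = []
    while rng.random() < p_continue:
        out.append(sample_class(tok))
    return "".join(out)

def sample_regex(tokens):
    """Interpret a flat token sequence such as ['#', '\\d', '\\d', ('\\u', '*')]."""
    out = []
    for t in tokens:
        if isinstance(t, tuple) and t[1] == "*":
            out.append(sample_repeat(t[0]))
        else:
            out.append(sample_class(t))
    return "".join(out)

s = sample_regex(["#", r"\d", r"\d", (r"\u", "*")])
print(s)  # '#', two digits, then a geometric run of uppercase letters
```

Scoring a string exactly under such a regex (rather than sampling) requires the dynamic programming mentioned above, which we do not reproduce here.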

Figure 8: Inferred regexes and posterior predictive samples from models trained on the String-Concepts dataset. Posterior samples are taken from memory in MWS, and from the recognition network with importance sampling in RWS and VIMCO.

We first present results from training this model using the MWS algorithm. Table 2 shows prior samples generated by the learned model. For each row, we draw a new concept z from the LSTM prior, and then generate several instances using the symbolic regex evaluator, sampling x ∼ p(x|z). This demonstrates qualitatively that our model generalises meaningfully from the training concepts: the invented concepts are highly plausible, containing motifs such as # preceding a string of digits, or % as a final character. The model's prior has learned high-level structures that are common to many concepts in the training data, but can compose these structures in novel ways.

In Table 3 we quantitatively evaluate the MWS, RWS and VIMCO algorithms for the same neurosymbolic model architecture (additional results with the MWS-fantasy and RWS-sleep variants are provided in the appendix). We

Prior Generated
c \s \d . \d+ c 0.6,   c 4.4,   c 6.0
\w \d \d \d \d- \d \d 56144-73,   60140-63
$ \d00 $600,   $300,   $000
\l \l \d hc8,   ft5,   vs9
# \d \d \u \u #57EP,   #11UW,   #26KR
\u0 \d \d \d \d \d \d B0522234,   M0142810
\u \u \u \d . \s \d0% TAP0. 70%,   THR6. 50%
R0< \d+ R0<3,   R0<9,   R0<80
\u+. EA.,   SD.,   CSB.
Table 2: Novel concepts sampled from the MWS model. In each row we sample a regex from the learned prior, then generate examples from this regex.

estimate the true marginal likelihood of all models on held-out test data using importance sampling. MWS not only exceeds the performance of baseline algorithms at large particle counts, but also achieves strong performance with only a few particles. This allows a model to be fit accurately at a very significant reduction in computational cost, requiring fewer likelihood evaluations per iteration. Such efficiency is particularly valuable for domains in which likelihood evaluation is costly, such as those requiring parsing or enumeration when scoring observations under programs.

Figure 8 shows qualitatively the inferences made by our model. Across a diverse set of concepts, the model learns programs with significant compositional structure. In many cases, this allows it to make strong generalisations from very few examples.

Furthermore, comparison across algorithms demonstrates that these more complex expressions are challenging for the recognition network to infer reliably. For example, while RWS and VIMCO are typically able to discover template-based programs (e.g. Figure 8, concept 4), MWS is the only algorithm to utilise operators such as alternation (|) or bracketed subexpressions for any concepts in the dataset (e.g. concepts 7 and 1).

RWS 87.1 86.5 85.2 85.0
VIMCO 97.5 89.1 84.8 83.5
MWS 84.1 83.1 82.8 82.5
Table 3: Marginal NLL (nats, avg. of 3 runs)
Distance from learned parameters to true model (easy task):
RWS 0.74 0.51 0.17 0.09
VIMCO 1.31 1.05 0.66 0.37
MWS 0.01 0.01 0.02 0.02

Distance from learned parameters to true model (hard task):
RWS 10.27 5.55 3.44 2.43
VIMCO 5.98 3.39 2.15 1.26
MWS 1.24 1.15 0.90 0.75
Figure 9: A. In our model, the latent variable describes a CA rule in canonical binary form (depicted is Rule 30, Wolfram (2002)). Images are generated from left to right, with each pixel stochastically conditioned on its three left-neighbours. B. MWS is able to infer the CA rule for each image, and learns a global noise parameter, which we then use to extrapolate the images. The model accurately matches the true generative noise, as is most visually apparent in row 1.

4.4 Noisy Cellular Automata

Finally, to demonstrate the use of MWS in estimating meaningful parameters, we consider the domain of cellular automata (CA). These processes have been studied by Wolfram (2002) and are often cited as a model of seashell pigmentation. We consider noisy, elementary automata: binary image-generating processes in which each pixel (cell) is sampled stochastically given its immediate neighbours in the previous column (left to right).

For this domain, we build a dataset of binary images generated by cellular automata. For each image we sample a "rule", which determines the value of each pixel given the configuration of its three immediate neighbours in the previous column. Each rule is represented canonically as a binary vector of length 2³ = 8 (see Figure 9). To generate an image, the leftmost column is sampled uniformly at random; each subsequent pixel is then determined by applying the rule to its neighbours, and flipped with a small, fixed corruption probability.
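This generation process can be sketched as follows (our reconstruction; the vertical wrap-around at the image border and the neighbour-indexing convention are assumptions):

```python
import numpy as np

def sample_ca_image(rule, height, width, eps, rng):
    """Generate a binary image column by column: each pixel follows
    rule[neighbour configuration] and is flipped with probability eps."""
    x = np.zeros((height, width), dtype=int)
    x[:, 0] = rng.integers(0, 2, size=height)      # leftmost column: uniform random
    for j in range(1, width):
        for i in range(height):
            # three left-neighbours in the previous column (wrapping vertically)
            n = (x[(i - 1) % height, j - 1],
                 x[i, j - 1],
                 x[(i + 1) % height, j - 1])
            idx = n[0] * 4 + n[1] * 2 + n[2]       # binary neighbour configuration
            bit = rule[idx]
            if rng.random() < eps:                 # corruption noise
                bit = 1 - bit
            x[i, j] = bit
    return x

# Wolfram's Rule 30 in canonical binary form (index = neighbour configuration)
rule30 = np.array([0, 1, 1, 1, 1, 0, 0, 0])
img = sample_ca_image(rule30, height=16, width=32, eps=0.01,
                      rng=np.random.default_rng(0))
print(img.shape)
```

The inference task is then to recover the length-8 rule vector (a discrete latent) and the shared noise level from images alone.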

We then build a generative model which matches this process, but where the rules are latent variables and the noise is replaced by a learnable parameter. We aim to perform joint learning and inference in this model, and use the learned noise parameter to estimate the true corruption probability. This estimation is a significant challenge, because any inaccurate inferences of the rules will cause the model to overestimate the noise to compensate. As a stringent test of the algorithm, we also evaluate on a harder domain of non-elementary automata, with cells depending on five neighbours (so each rule has length 2⁵ = 32).

In Figure 9B, we visualise the automata inferred by MWS by extrapolating images from the dataset. It is visually apparent that the model accurately captures the generative process for each image, including both the rule and the noise level. This is confirmed quantitatively: for both the easy and hard task variants, MWS is able to estimate the true generative noise parameter with significantly greater accuracy than alternative algorithms.

5 Discussion

This work sits at the intersection of program learning and neural generative modelling. From a dataset, we aim to infer a latent generative program to describe each observation, while simultaneously learning a neural prior over programs and any additional model parameters.

To tackle the challenge of performing program induction repeatedly during learning, we train our models by a novel algorithm called Memoised Wake-Sleep (MWS), and find that this improves the quality of learning across all domains we study. MWS builds upon existing Helmholtz machine algorithms by maintaining a memory of programs seen so far during training, which reduces the need for effective amortized inference. We optimise a variational bound previously proposed by Saeedi et al. (2017), extending their algorithm to include parameter learning for a recognition network and for the generative model. Algorithmically, our approach is also similar to memory-based reinforcement learning methods (Abolafia et al., 2018; Liang et al., 2018), which maintain a queue of best action-traces found during training.

In general, MWS can be applied to any models for which all parameters are continuous, and all latent variables are discrete. However, we particularly advocate its use for the class of ‘programmatic’ generative models we study, due to the difficulties of sparse inference that they often present. If learning can be made tractable in such models, they have the potential to greatly improve generalisation in many domains, discovering compositional structure that is rich and understandable.


This work was supported by AFOSR award FA9550-18-S-0003 and the MIT-IBM Watson AI Lab.


  • Abolafia et al. (2018) Daniel A Abolafia, Mohammad Norouzi, Jonathan Shen, Rui Zhao, and Quoc V Le. Neural program synthesis with priority queue training. arXiv:1801.03526, 2018.
  • Bornschein & Bengio (2015) Jörg Bornschein and Yoshua Bengio. Reweighted wake-sleep. In International Conference on Learning Representations, 2015.
  • Burda et al. (2016) Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. Importance weighted autoencoders. In International Conference on Learning Representations, 2016.
  • Dallenbach (1951) Karl M Dallenbach. A puzzle-picture with a new principle of concealment. The American Journal of Psychology, 1951.
  • Dayan et al. (1995) Peter Dayan, Geoffrey E Hinton, Radford M Neal, and Richard S Zemel. The Helmholtz machine. Neural computation, 7(5):889–904, 1995.
  • Devlin et al. (2017) Jacob Devlin, Jonathan Uesato, Surya Bhupatiraju, Rishabh Singh, Abdel-rahman Mohamed, and Pushmeet Kohli. Robustfill: Neural program learning under noisy I/O. In International Conference on Machine Learning, 2017.
  • Duvenaud et al. (2013) David Duvenaud, James Lloyd, Roger Grosse, Joshua Tenenbaum, and Ghahramani Zoubin. Structure discovery in nonparametric regression through compositional kernel search. In International Conference on Machine Learning, 2013.
  • Ellis et al. (2018) Kevin Ellis, Lucas Morales, Mathias Sablé-Meyer, Armando Solar-Lezama, and Josh Tenenbaum. Learning libraries of subroutines for neurally-guided Bayesian program induction. In Advances in Neural Information Processing Systems, 2018.
  • Ganin et al. (2018) Yaroslav Ganin, Tejas Kulkarni, Igor Babuschkin, SM Ali Eslami, and Oriol Vinyals. Synthesizing programs for images using reinforced adversarial learning. In International Conference on Machine Learning, 2018.
  • Grathwohl et al. (2018) Will Grathwohl, Dami Choi, Yuhuai Wu, Geoff Roeder, and David Duvenaud. Backpropagation through the void: Optimizing control variates for black-box gradient estimation. In International Conference on Learning Representations, 2018.
  • Gregory (1970) R.L. Gregory. The intelligent eye. McGraw-Hill paperbacks. McGraw-Hill, 1970.
  • Gulwani et al. (2015) Sumit Gulwani, José Hernández-Orallo, Emanuel Kitzelmann, Stephen H Muggleton, Ute Schmid, and Benjamin Zorn. Inductive programming meets the real world. Comms. of the ACM, 58(11):90–99, 2015.
  • Hinton et al. (1995) Geoffrey E Hinton, Peter Dayan, Brendan J Frey, and Radford M Neal. The ”wake-sleep” algorithm for unsupervised neural networks. Science, 1995.
  • Jang et al. (2017) Eric Jang, Shixiang Gu, and Ben Poole. Categorical reparameterization with Gumbel-softmax. In International Conference on Learning Representations, 2017.
  • Kingma & Welling (2014) Diederik P Kingma and Max Welling. Auto-encoding variational Bayes. In International Conference on Learning Representations, 2014.
  • Lake et al. (2015) Brenden M Lake, Ruslan Salakhutdinov, and Joshua B Tenenbaum. Human-level concept learning through probabilistic program induction. Science, 350(6266):1332–1338, 2015.
  • Lake et al. (2019) Brenden M Lake, Ruslan Salakhutdinov, and Joshua B Tenenbaum. The Omniglot challenge: a 3-year progress report. Current Opinion in Behavioral Sciences, 29:97–104, 2019.
  • Le et al. (2019) Tuan Anh Le, Adam R. Kosiorek, N. Siddharth, Yee Whye Teh, and Frank Wood. Revisiting reweighted wake-sleep for models with stochastic control flow. In Uncertainty in Artificial Intelligence, 2019. Le and Kosiorek contributed equally.
  • Liang et al. (2018) Chen Liang, Mohammad Norouzi, Jonathan Berant, Quoc V Le, and Ni Lao. Memory augmented policy optimization for program synthesis and semantic parsing. In Advances in Neural Information Processing Systems, 2018.
  • Mnih & Gregor (2014) Andriy Mnih and Karol Gregor. Neural variational inference and learning in belief networks. In International Conference on Machine Learning, 2014.
  • Mnih & Rezende (2016) Andriy Mnih and Danilo Rezende. Variational inference for Monte Carlo objectives. In International Conference on Machine Learning, 2016.
  • Ozair & Bengio (2014) Sherjil Ozair and Yoshua Bengio. Deep directed generative autoencoders. arXiv:1410.0630, 2014.
  • Rolfe (2017) Jason Tyler Rolfe. Discrete variational autoencoders. In International Conference on Learning Representations, 2017.
  • Saeedi et al. (2017) Ardavan Saeedi, Tejas D Kulkarni, Vikash K Mansinghka, and Samuel J Gershman. Variational particle approximations. The Journal of Machine Learning Research, 18(1):2328–2356, 2017.
  • Tucker et al. (2017) George Tucker, Andriy Mnih, Chris J Maddison, John Lawson, and Jascha Sohl-Dickstein. Rebar: Low-variance, unbiased gradient estimates for discrete latent variable models. In Advances in Neural Information Processing Systems, 2017.
  • Williams (1992) Ronald J Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8(3-4):229–256, 1992.
  • Wolfram (2002) Stephen Wolfram. A new kind of science, volume 5. Wolfram media Champaign, IL, 2002.

Appendix A Memory in Perceptual Recognition

Figure 10: What is depicted in these images? Viewers typically find it difficult to see the subject of these images on first presentation, suggesting that recognition involves solving a very difficult search problem. However, after initial recognition, this structure is immediately apparent on all future viewings, as though the result of previous inference has been stored for future reuse. Answers in the footnote.
Footnote: Left: a dalmatian sniffs the ground (Gregory, 1970). Right: a cow looks towards the camera (Dallenbach, 1951).

Appendix B MWS-Fantasy and RWS-Sleep

We present additional results for both MWS and RWS, when the recognition model is trained using prior samples.

B.1 String Concepts Experiment

RWS           87.1   86.5   85.2   85.0
RWS-sleep     88.9   87.5   86.7   85.5
MWS           84.1   83.1   82.8   82.5
MWS-fantasy   82.5   82.7   82.8   82.8

Table 4: Marginal NLL (lower is better).

B.2 Cellular Automata Experiment

             Easy                           Hard
RWS          0.74   0.51   0.17   0.09     10.27   5.55   3.44   2.43
RWS-sleep    3.39   3.64   2.84   3.12     23.07  20.54  18.05  15.53
MWS          0.01   0.01   0.02   0.02      1.24   1.15   0.90   0.75
MWS-fantasy  0.01   0.01   0.01   0.01      5.98   5.12   4.80   4.24

Table 5: Distance from the true model (lower is better).

Appendix C String Concepts Dataset

Our String-Concepts dataset was collected by crawling public GitHub repositories for files with the .csv extension. The data was then automatically processed to remove columns containing only numbers, English words longer than 5 characters, or common names, and to ensure that each column contained at least 10 elements. We then drew one random column from each file, while ensuring that no more than three columns shared the same column header. This allows us to reduce homogeneity (e.g. a large proportion of columns had the header 'state') while preserving some of the true variation (e.g. different formats of 'date'). The final dataset contains 5 training examples and 5 test examples for each of 1500 concepts. A sample is included below.
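The column-filtering criteria above might be sketched as follows. The function and word-list names are illustrative, not taken from the original pipeline; `english_words` and `common_names` are assumed to be pre-loaded sets.

```python
import re

def keep_column(values, english_words, common_names, min_rows=10):
    """Decide whether a CSV column passes the filters described above.

    A column is dropped if it has fewer than `min_rows` elements, if it
    contains only numbers, or if it contains long English words or
    common names. This is a hedged sketch of the stated criteria, not
    the original code.
    """
    values = [v for v in values if v.strip()]
    if len(values) < min_rows:
        return False  # each column must contain at least 10 elements
    if all(re.fullmatch(r"-?\d+(\.\d+)?", v) for v in values):
        return False  # numbers only
    lowered = [v.lower() for v in values]
    if any(v in english_words and len(v) > 5 for v in lowered):
        return False  # English words longer than 5 characters
    if any(v in common_names for v in lowered):
        return False  # common names
    return True
```

A per-header cap (at most three columns sharing a header) would then be applied across files, as described above.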

‘#574c1d’, ‘#603a0d’, ‘#926627’, ‘#9e662d’, ‘#9e8952’
‘(206) 221-2205’, ‘(206) 221-2252’, ‘(206) 393-1513’, ‘(206) 393-1973’, ‘(206) 882-1281’
‘ca’, ‘ja’, ‘rp’, ‘tw’, ‘vm’
‘EH15 2AT’, ‘EH17 8HP’, ‘EH20 9DU’, ‘EH48 4HH’, ‘EH54 6P’
‘Exi0000027’, ‘Exi0000217’, ‘Exi0000242’, ‘Exi0000244’, ‘Exi0000250’
‘GDPA178’, ‘GDPA223’, ‘GDPA289’, ‘GDPA428’, ‘GDPA632’
‘YE Dec 11’, ‘YE Dec 14’, ‘YE Mar 11’, ‘YE Sep 08’, ‘YE Sep 12’
‘$0’, ‘$330,000’, ‘$35,720,000’, ‘$4,505,000’, ‘$42,095,000’
‘ODI # 2672’, ‘ODI # 2750’, ‘ODI # 3294’, ‘ODI # 3372’, ‘ODI # 3439’
‘3000-3999’, ‘4000-4999’, ‘5000-5999’, ‘50000-59999’, ‘NA’
‘soc135’, ‘soc138’, ‘soc144’, ‘soc67’, ‘soc72’
‘9EFLFN31’, ‘FE87SA-9’, ‘LD7B0U27A1’, ‘PPB178’, ‘TL88008’
‘+1 212 255 7065’, ‘+1 212 431 9303’, ‘+1 212 477 1023’, ‘+1 212 693 0588’, ‘+1 212 693 1400’
‘CH2A’, ‘CH64’, ‘CH72’, ‘CH76’, ‘CH79A’
‘20140602-2346-00’, ‘20140603-1148-01’, ‘20140603-1929-04’, ‘O0601014802’, ‘O0603155802’
‘BUS M 277’, ‘BUS M 440’, ‘BUS M 490R F’, ‘BUS M 490R TTh’, ‘BUS M 498’
‘-2.9065552’, ‘-3.193863’, ‘-3.356659’, ‘-4.304764’, ‘-4.5729218’
‘#101’, ‘#4/2/95/2’, ‘#79’, ‘#8/110/3-2’, ‘#94/2’
"’1322563’", "’151792’", "’2853979’", "’5273420’", "’7470563’"
‘F150009124’, ‘F150009169’, ‘F150009180’, ‘F150009181’, ‘F150009346’
‘BA-CMNPR2’, ‘BCOUNS’, ‘JBGD’, ‘JDAE’, ‘OBSB51413’
‘b_1’, ‘e_1’, ‘g_2’, ‘k_1’, ‘o_1’
‘P.AC.010.999’, ‘P.IH.040.999’, ‘P.IH.240.029’, ‘P.PC.030.000’, ‘P.PC.290.999’
‘-00:16:05.9’, ‘-00:19:52.9’, ‘-00:24:25.0’, ‘-00:33:24.7’, ‘-00:44:02.3’
‘APT’, ‘FUN’, ‘JAK’, ‘KEX’, ‘NAP’
‘SC_L1_O3’, ‘SC_L3_O2’, ‘SC_L5_O3’, ‘SC_L6_O1’, ‘SC_L6_O2’
‘onsen_20’, ‘onsen_44’, ‘onsen_79’, ‘onsen_80’, ‘onsen_86’
‘SDP-00-005’, ‘SDP-02-106’, ‘SDP-04-079’, ‘SDP-06-067’, ‘SDP-08-045’
‘FM0001’, ‘FM0225’, ‘FM2500’, ‘SL0304’, ‘SS0339’
‘BEL’, ‘KOR’, ‘PAR’, ‘POL’, ‘RUS’
‘-0.5423’, ‘-0.702’, ‘0.2023’, ‘0.6124’, ‘0.6757’
‘R0353226’, ‘R0356653’, ‘R0397240’, ‘R0474859’, ‘R0488595’
‘CB’, ‘NA’, ‘SC’, ‘SE’, ‘WE’
‘|S127’, ‘|S23’, ‘|S3’, ‘|S4’, ‘|S5’
‘GO:0008238’, ‘GO:0009259’, ‘GO:0009896’, ‘GO:0034332’, ‘GO:0043270’
‘MN’, ‘MO’, ‘NE’, ‘SD’, ‘WY’
‘F1-D0N343656’, ‘F1-D0N343666’, ‘F1-D0N343669’, ‘F1-D0N343677’, ‘F1-D0N343680’
‘ABF’, ‘AF’, ‘CBA’, ‘CC’, ‘EAJ’
‘E’, ‘NNE’, ‘NNW’, ‘W’, ‘WSW’
‘A’, ‘F’, ‘G’, ‘Q’, ‘R’
‘bio11’, ‘bio14’, ‘bio16’, ‘bio19’, ‘tmin4’
‘ACT-B06’, ‘MS-ACT-C15’, ‘MS3-08’, ‘MS960931’, ‘MS960961’

Appendix D Convergence

In this section, we describe the limiting behaviour of the MWS algorithm, assuming a sufficiently expressive recognition network Q.

The MWS(-replay) algorithm advocated in the paper permits a strong convergence result as long as the encoding probability of each program is bounded below (Q(z|x) ≥ ε > 0) for the top-K programs in the posterior p(z|x). This ensures the optimal programs are eventually proposed. While it is simple to enforce such a constraint in the network structure, we have found this unnecessary in practice.

We consider convergence of p and Q in three cases, providing brief proofs below:

  1. Fixed p; large memory

    The encoder Q converges to the true posterior p(z|x).

  2. Fixed p; finite memory

    The memory and Q converge to the best K-support approximation of the posterior p(z|x).

    Q will therefore be accurate if p(z|x) is sufficiently sparse, with K elements covering much of the probability mass. We believe this condition is often realistic for program induction (including the models explored in our paper), and it is assumed by previous work (e.g. Ellis et al. 2018).

  3. Learned p

    p and Q converge to a local optimum of the ELBO, and Q matches the memory. This is a stronger result than exists for (R)WS: in (R)WS, optima of the ELBO are fixed points, but convergence is not guaranteed due to differing objectives for p and Q (although this is rarely a problem in practice).

Note that, for the MWS-fantasy variant of the algorithm, Q is trained directly on samples from the generative model, and so necessarily converges to the true posterior p(z|x).


  1. If the memory is large enough to contain the full domain of z, then maintaining it is equivalent to enumerating the full posterior, and Q is trained on samples from this posterior.

  2. The memory is updated in discrete steps using proposals from Q. Assuming that Q eventually proposes each of the top-K values of p(z|x), these proposals will be accepted, and the memory converges to the optimal set of K programs. Then, Q learns to match the memory.

  3. p and the memory are optimised with respect to the same objective (the ELBO), using SGD for p and discrete monotonic steps for the memory. The algorithm therefore converges to a local optimum of the ELBO, and Q converges to match the memory.
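The discrete memory update in case 2 can be sketched as follows. Names are illustrative; programs are assumed hashable (e.g. tuples of tokens), and `log_joint(z)` evaluates log p(x, z) for one observation x.

```python
def update_memory(memory, proposals, log_joint, K):
    """One memory update step, as a sketch of the argument above.

    `memory` holds up to K distinct programs; `proposals` are samples
    from the recognition network Q. A proposal is 'accepted' only if it
    displaces a program with lower joint probability, so the memory
    improves monotonically under the ELBO.
    """
    candidates = set(memory) | set(proposals)
    # Keep the K distinct programs with the highest joint probability.
    return set(sorted(candidates, key=log_joint, reverse=True)[:K])
```

Because the update is monotonic and the top-K programs are eventually proposed (by the boundedness assumption above), the memory converges to the optimal K-support set.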

Appendix E Sparsity

In order to provide empirical support for the theoretically founded suggestion that MWS relies on posterior sparsity, we ran an additional experiment building on the synthetic GMM model described in Section 4.1. In this experiment, we vary the cluster variance σ² in the dataset as a way to produce different levels of posterior sparsity. That is, a large σ² produces more overlap between clusters, which leads to less certainty over which cluster an observation came from (less posterior sparsity).

Below, we show the improvement in posterior accuracy gained by using MWS (compared to VIMCO/RWS) for a range of cluster variances σ² and particle counts K. We find significant improvement from MWS for low σ² (σ² = 0.03, as in Figure 4), which diminishes as σ² increases.
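The link between cluster variance and posterior sparsity can be illustrated with a toy stand-in for the GMM above: for a one-dimensional, equal-weight mixture, the entropy of the cluster posterior grows as the variance grows. All names and parameter values here are illustrative.

```python
import math

def responsibilities(x, means, var):
    """Posterior over cluster assignment for a 1-D equal-weight GMM."""
    logs = [-((x - m) ** 2) / (2 * var) for m in means]
    mx = max(logs)  # subtract the max for numerical stability
    ws = [math.exp(l - mx) for l in logs]
    s = sum(ws)
    return [w / s for w in ws]

def posterior_entropy(x, means, var):
    """Entropy of the cluster posterior; low entropy = sparse posterior."""
    return -sum(r * math.log(r)
                for r in responsibilities(x, means, var) if r > 0)
```

For well-separated means, a small variance such as 0.03 gives near-zero entropy (a sparse posterior), while a variance of 1.0 spreads mass across clusters, matching the trend in the table below.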

         σ² = 0.03    0.1     0.3     1.0
K = 2       10.26    3.35    0.99    0.07
K = 5       10.21    3.57    0.73   -0.06
K = 10       6.62    3.13    1.35    0.77
K = 20       5.42    3.25    2.24    2.13

Table 6: Improvement in posterior accuracy achieved by MWS over VIMCO/RWS, for each cluster variance σ² and number of particles K.

E.1 MAP Inference

Note that at the lower limit of the memory size, K = 1, MWS may be seen as simply an algorithm for learning with amortised MAP inference.

This task has also been approached with Deep Directed Generative Autoencoders (Ozair & Bengio, 2014), which are closely related to sparse autoencoders, and may also be seen as a special case of VAEs in which the recognition network is deterministic.

Appendix F Handwritten Characters Model

We use an LSTM for both the prior p(z) and the recognition model Q(z|x), where the output is a sequence of (stroke id, on/off) tuples. In the case of alphabet generalisation, both the prior and the recognition model are additionally conditioned on the alphabet index.

Each stroke in the stroke bank is parametrised by its total displacement, along with parameters that control the stroke width and sharpness. For the likelihood, we use a differentiable renderer to draw strokes on a canvas: each pixel in the image is assigned a value that depends smoothly on d, the shortest distance from that pixel to the stroke. During rendering, all strokes are placed end-to-end onto the canvas, and we take the output to be the logits of a Bernoulli likelihood.
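A minimal sketch of such a renderer for a single straight stroke is shown below. The exact functional form used in the paper is not reproduced here; we only assume that the Bernoulli logit falls off with the distance d, with `width` and `sharpness` as illustrative parameters.

```python
import math

def point_segment_distance(p, a, b):
    """Shortest distance from point p to the segment a-b."""
    ax, ay = a; bx, by = b; px, py = p
    dx, dy = bx - ax, by - ay
    denom = dx * dx + dy * dy
    t = 0.0 if denom == 0 else max(0.0, min(1.0, ((px - ax) * dx + (py - ay) * dy) / denom))
    cx, cy = ax + t * dx, ay + t * dy
    return math.hypot(px - cx, py - cy)

def render_segment(canvas_size, p0, p1, width, sharpness):
    """Render one stroke as a grid of Bernoulli logits (a sketch).

    The logit is high near the stroke and falls off with distance d;
    `sharpness` controls how quickly, `width` how far.
    """
    h, w = canvas_size
    logits = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            d = point_segment_distance((j + 0.5, i + 0.5), p0, p1)
            logits[i][j] = sharpness * (width - d)
    return logits
```

Because every operation is a smooth function of the stroke endpoints (away from the clamp boundaries), the same computation written in an autodiff framework gives gradients with respect to the stroke parameters.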

We include an additional parameter for each stroke, corresponding to its arc angle; the displacement parameters still give the straight-line distance from the start point to the end point. This angle can take either sign with no discontinuity, the two signs corresponding to the two possible orientations of the arc (clockwise or anticlockwise).

We train our model using Adam, with 10 characters per alphabet as training data.

To evaluate the log marginal likelihood, we fix the generative model and refine the recognition model with further training iterations. We use this refined recognition model to estimate the marginal likelihood using the IWAE bound. We run this procedure with RWS, VIMCO and Sleep-fantasy training, and report the best estimate.
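The IWAE bound used here is a log-mean-exp of importance weights. A minimal, numerically stable sketch:

```python
import math

def iwae_bound(log_weights):
    """IWAE estimate of log p(x) from S importance samples.

    log_weights[s] = log p(x, z_s) - log Q(z_s | x), with z_s ~ Q.
    Returns log((1/S) * sum_s exp(log_weights[s])), computed stably by
    subtracting the maximum before exponentiating.
    """
    mx = max(log_weights)
    return mx + math.log(sum(math.exp(lw - mx) for lw in log_weights) / len(log_weights))
```

This estimator is a lower bound on log p(x) in expectation, and tightens as the number of samples S grows.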

Figure 11: Learning to model the UCR time series dataset with Gaussian processes, by inferring a latent kernel for each observed time series. Blue (left) is a 256-timepoint observation, and orange (right) is a sampled extrapolation using the inferred kernel (symbolic representation above, natural-language representation below). During learning, we maintain a memory over the discrete structure of kernels, but use a variational Bayes inner loop to marginalise out a kernel's continuous variables whenever evaluation of the marginal likelihood is required.

Appendix G Timeseries Data

As a preliminary experiment, we applied MWS to the task of finding explainable models for time-series data. We draw inspiration from Duvenaud et al. (2013), who frame this problem as Gaussian process (GP) kernel learning. They describe a grammar for building kernels compositionally, and demonstrate that inference in this grammar can produce highly interpretable and generalisable descriptions of the structure in a time series.

We follow a similar approach, but depart from it by learning GP kernels jointly across all time series in a dataset, rather than individually. We start with data from the UCR Time Series Classification Archive, which contains 1-dimensional time series from a variety of sources (such as household electricity usage, apparent brightness of stars, and seismometer readings). In this work, we use 1000 time series randomly drawn from this archive, and normalise each to zero mean and unit variance.

For our model, we define the following simple grammar over kernels:

  • WN is the White Noise kernel,

  • SE is the Squared Exponential kernel,

  • Per is a Periodic kernel,

  • C is a Constant kernel.

We wish to learn a prior distribution over both the symbolic structure of a kernel and its continuous variables (lengthscales, periods, etc.). To represent the structure of a kernel as a latent variable, we use a symbolic kernel 'expression': a string over the kernel names above and the operators + and ×.

We define an LSTM prior over these kernel expressions, alongside parametric prior distributions over the continuous latent variables. As in previous work, exact evaluation of the marginal likelihood of a kernel expression is intractable, and so requires approximation. For this we use a simple variational inference scheme which cycles through coordinate updates to each continuous latent variable (up to 100 steps), and estimates a lower bound on the marginal likelihood using 10 samples from the variational distribution.
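A kernel expression can be evaluated recursively, with + and × acting as pointwise sum and product of kernels, as in Duvenaud et al. (2013). The sketch below fixes the continuous parameters for illustration, whereas the model instead learns priors over them.

```python
import math

# Terminal kernels from the grammar, with illustrative fixed parameters.
def WN(t1, t2): return 1.0 if t1 == t2 else 0.0
def SE(t1, t2, lengthscale=1.0):
    return math.exp(-((t1 - t2) ** 2) / (2 * lengthscale ** 2))
def Per(t1, t2, period=1.0, lengthscale=1.0):
    return math.exp(-2 * math.sin(math.pi * abs(t1 - t2) / period) ** 2
                    / lengthscale ** 2)
def C(t1, t2): return 1.0

def eval_kernel(expr, t1, t2):
    """Evaluate a kernel expression such as 'SE + Per' or 'SE * Per'.

    A minimal recursive-descent sketch: splitting on '+' before '*'
    gives sums the lower precedence, matching ordinary notation.
    """
    for op, fn in (("+", lambda a, b: a + b), ("*", lambda a, b: a * b)):
        if op in expr:
            left, _, right = expr.partition(op)
            return fn(eval_kernel(left, t1, t2), eval_kernel(right, t1, t2))
    return {"WN": WN, "SE": SE, "Per": Per, "C": C}[expr.strip()](t1, t2)
```

Evaluating `eval_kernel` over all pairs of timepoints yields the covariance matrix used to compute the GP likelihood of a time series under a given expression.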

Examples of latent programs discovered by our model are displayed in Figure 11. These programs describe meaningful compositional structure in the time series data, and can also be used to make highly plausible extrapolations.