Meta-Learning Deep Energy-Based Memory Models

10/07/2019 ∙ by Sergey Bartunov, et al. ∙ Google

We study the problem of learning associative memory – a system which is able to retrieve a remembered pattern based on its distorted or incomplete version. Attractor networks provide a sound model of associative memory: patterns are stored as attractors of the network dynamics and associative retrieval is performed by running the dynamics starting from a query pattern until it converges to an attractor. In such models the dynamics are often implemented as an optimization procedure that minimizes an energy function, such as in the classical Hopfield network. In general it is difficult to derive a writing rule for a given dynamics and energy that is both compressive and fast. Thus, most research in energy-based memory has been limited either to tractable energy models not expressive enough to handle complex high-dimensional objects such as natural images, or to models that do not offer fast writing. We present a novel meta-learning approach to energy-based memory models (EBMM) that allows one to use an arbitrary neural architecture as an energy model and quickly store patterns in its weights. We demonstrate experimentally that our EBMM approach can build compressed memories for synthetic and natural data, and is capable of associative retrieval that outperforms existing memory systems in terms of the reconstruction error and compression rate.


1 Introduction

Associative memory has long been of interest to the neuroscience and machine learning communities (Willshaw et al., 1969; Hopfield, 1982; Kanerva, 1988). This interest has generated many proposals for associative memory models, both biological and synthetic. These models address the problem of storing a set of patterns in such a way that a stored pattern can be retrieved based on a partially known or distorted version. This kind of retrieval from memory is known as auto-association.

Due to the generality of associative retrieval, successful implementations of associative memory models have the potential to impact many applications. Attractor networks provide one well-grounded foundation for associative memory models (Amit and Amit, 1992). Patterns are stored in such a way that they become attractors of the update dynamics defined by the network. Then, if a query pattern that preserves sufficient information for association lies in the basin of attraction for the original stored pattern, a trajectory initialized by the query will converge to the stored pattern.

A variety of implementations of the general attractor principle have been proposed. The classical Hopfield network (Hopfield, 1982), for example, defines a simple quadratic energy function whose parameters serve as a memory. The update dynamics in Hopfield networks iteratively minimize the energy by changing elements of the pattern until it converges to a minimum, typically corresponding to one of the stored patterns. The goal of the writing process is to find parameter values such that the stored patterns become attractors for the optimization process and such that, ideally, no spurious attractors are created.

Many different learning rules have been proposed for Hopfield energy models, and the simplicity of the model affords compelling closed-form analysis (Storkey and Valabregue, 1999). At the same time, Hopfield memory models have fundamental limitations: (1) It is not possible to add capacity for more stored patterns by increasing the number of parameters since the number of parameters in a Hopfield network is quadratic in the dimensionality of the patterns. (2) The model lacks a means of modelling the higher-order dependencies that exist in real-world data.

Figure 1: A schematic illustration of EBMM. The energy function $E(x; \theta)$ is modelled by a neural network. The writing rule is implemented as a weight update, producing parameters $\theta$ from the initialization $\theta_0$ such that the stored patterns $x_1, \dots, x_N$ become local minima of the energy (see section 3). Local minima are attractors for gradient descent, which implements associative retrieval starting from a query, in this case a distorted version of one of the stored patterns.

In domains such as natural images, the potentially large dimensionality of an input makes it both ineffective and often unnecessary to model global dependencies among raw input measurements. In fact, many auto-correlations that exist in real-world perceptual data can be efficiently compressed without significant sacrifice of fidelity using either algorithmic (Wallace, 1992; Candes and Tao, 2004) or machine learning tools (Gregor et al., 2016; Toderici et al., 2017). The success of existing deep learning techniques suggests a more efficient recipe for processing high-dimensional inputs: modelling a hierarchy of signals with restricted or local dependencies (LeCun et al., 1995). In this paper we use a similar idea for building an associative memory: we use a deep network's weights to store and retrieve data.

Fast writing rules

A variety of energy-based memory models have been proposed since the original Hopfield network to mitigate its limitations (Hinton et al., 2006a; Du and Mordatch, 2019). Restricted Boltzmann Machines (RBMs) (Hinton, 2012) add capacity to the model by introducing latent variables, and deep variants of RBMs (Hinton et al., 2006a; Salakhutdinov and Larochelle, 2010) afford more expressive energy functions. Unfortunately, training Boltzmann machines remains challenging, and while recent probabilistic models such as variational auto-encoders (Kingma and Welling, 2013; Rezende et al., 2014) are easier to train, they nevertheless pay the price for expressivity in the form of slow writing. While Hopfield networks memorize patterns quickly using a simple Hebbian rule, deep probabilistic models are slow in that they rely on gradient training that requires many updates (typically thousands or more) to settle new inputs into the weights of a network. Hence, writing memories via parametric gradient-based optimization is not straightforwardly applicable to memory problems where fast adaptation is a crucial requirement. In contrast, and by explicit design, our proposed method enjoys fast writing, requiring few parameter updates (we employ just 5 steps) to write new inputs into the weights of the net once meta-learning is complete. It also enjoys fast reading, requiring few gradient descent steps (again just 5 in our experiments) to retrieve a pattern. Furthermore, our writing rules are fast in the sense that they use $O(N)$ operations to store $N$ patterns in the memory; this scaling is the best one can hope for without additional assumptions.

We propose a novel approach that leverages meta-learning to enable fast storage of patterns into the weights of arbitrarily structured neural networks, as well as fast associative retrieval. Our networks output a single scalar value which we treat as an energy function whose parameters implement a distributed storage scheme. We use gradient-based reading dynamics and meta-learn a writing rule in the form of truncated gradient descent over the parameters defining the energy function. We show that the proposed approach enables compression via efficient utilization of network weights, as well as fast-converging attractor dynamics.

2 Retrieval in energy-based models

We focus on attractor networks as a basis for associative memory. Attractor networks define update dynamics for the iterative evolution of the input pattern: $x_{t+1} = f(x_t)$. For simplicity, we will assume that this process is discrete in time and deterministic; however, there are examples of both continuous-time (Yoon et al., 2013) and stochastic dynamics (Aarts and Korst, 1988). A fixed-point attractor of deterministic dynamics can be defined as a point $x^*$ to which the dynamics converge, i.e. $x^* = f(x^*)$. Learning associative memory in an attractor network is then equivalent to learning dynamics $f$ such that its fixed-point attractors are the stored patterns and the corresponding basins of attraction are sufficiently wide for retrieval.

An energy-based attractor network is defined by an energy function $E(x)$ mapping an input object $x$ to a real scalar value. A particular model may then impose additional requirements on the energy function. For example, if the model has a probabilistic interpretation, the energy is usually the negative unnormalized log-probability of the object, $E(x) = -\log p(x) + \mathrm{const}$, implying that the energy has to be well-behaved for the normalizing constant to exist. In our case no such constraints are put on the energy.

The attractor dynamics in energy-based models is often implemented either by iterative energy optimization (Hopfield, 1982) or by sampling (Aarts and Korst, 1988). In the optimization case, considered further in the paper, attractors are conveniently defined as local minimizers of the energy function.

While a particular energy function may suggest a number of different optimization schemes for retrieval, convergence to a local minimum of an arbitrary function is NP-hard. Thus, we consider the class of energy functions that are differentiable in the input and bounded from below, and define the update dynamics over $T$ steps via gradient descent:

(1)   $x_{t+1} = x_t - \gamma_t \nabla_x E(x_t), \quad t = 0, \dots, T-1,$

where $x_0$ is the query and $\gamma_t$ are step sizes. With appropriately set step sizes this procedure asymptotically converges to a local minimum of the energy (Nesterov, 2013). Since asymptotic convergence may not be enough for practical applications, we truncate the optimization procedure (1) at $T$ steps and treat $x_T$ as the result of the retrieval. While vanilla gradient descent (1) is sufficient to implement retrieval, in our experiments we employ a number of extensions, such as Nesterov momentum and projected gradients, which are thoroughly described in Appendix B.
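To make the retrieval dynamics concrete, the following is a minimal sketch of truncated gradient-descent reading written in Python with JAX. The quadratic toy energy, the constant step sizes and the names energy and read are illustrative assumptions rather than the architectures used in the experiments.

import jax
import jax.numpy as jnp

def energy(x, theta):
    # Toy quadratic energy with a single minimum at theta["mu"]; in EBMM this
    # would be an arbitrary scalar-output network E(x; theta).
    return 0.5 * jnp.sum((x - theta["mu"]) ** 2)

def read(x_query, theta, step_sizes):
    # Truncated gradient descent on the energy, as in eq. (1).
    x = x_query
    for gamma in step_sizes:
        x = x - gamma * jax.grad(energy)(x, theta)
    return x

theta = {"mu": jnp.array([1.0, -2.0, 0.5])}
x_query = jnp.zeros(3)                     # distorted / incomplete pattern
x_retrieved = read(x_query, theta, step_sizes=[0.5] * 5)

With five steps of size 0.5 the iterate approaches the stored pattern theta["mu"]; in the full model the step sizes are meta-learned rather than hand-set.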

Relying on the generic optimization procedure allows us to translate the problem of designing update dynamics with desirable properties to constructing an appropriate energy function, which in general is equally difficult. In the next section we discuss how to tackle this difficulty.

Figure 2: Visualization of gradient descent iterations during retrieval of Omniglot characters (largest model). 4 random images are shown from the batch of 64.

3 Meta-learning gradient-based writing rules

As discussed in previous sections, our ambition is to be able to use any scalar-output neural network as an energy function for associative retrieval. We assume a parametric model $E(x; \theta)$ that is differentiable in both $x$ and $\theta$ and bounded from below as a function of $x$. These are mild assumptions that are often met by existing neural architectures with an appropriate choice of activation functions, e.g. tanh.

The writing rule then compresses input patterns $x_1, \dots, x_N$ into parameters $\theta$ such that each of the stored patterns becomes a local minimum of $E(x; \theta)$ or, equivalently, creates a basin of attraction for gradient descent in the pattern space.

This property can be practically quantified by the reconstruction error, e.g. the mean squared error between a stored pattern and the pattern retrieved from a distorted version of it:

(2)   $L_{\mathrm{read}}(\theta) = \sum_{i=1}^{N} \mathbb{E}_{\tilde{x}_i} \left[ \| x_i - \mathrm{read}(\tilde{x}_i, \theta) \|^2 \right],$

where $\mathrm{read}(\cdot, \theta)$ denotes the retrieval procedure (1) and $\tilde{x}_i$ is a distorted version of $x_i$.

Here we assume a known, potentially stochastic distortion model, such as randomly erasing a certain number of dimensions, or salt-and-pepper noise. One could consider the loss (2) as a function of the network parameters and call its minimization with a conventional optimization method a writing rule, but this would require many optimization steps to obtain a satisfactory solution and thus does not fall under our definition of fast writing rules (Santoro et al., 2016).
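As an illustration of this criterion, the sketch below applies one possible distortion model (erasing a random subset of dimensions; the erasure count and helper names are assumptions) and measures the squared error between a stored pattern and the pattern retrieved by the read function from the previous sketch.

import jax
import jax.numpy as jnp

def distort(key, x, n_erase):
    # One possible distortion model: zero out a random subset of dimensions.
    idx = jax.random.choice(key, x.shape[0], shape=(n_erase,), replace=False)
    return x.at[idx].set(0.0)

def reconstruction_error(key, x, theta, step_sizes, n_erase):
    # Monte-Carlo estimate of one summand of eq. (2) for a single pattern.
    x_tilde = distort(key, x, n_erase)
    x_hat = read(x_tilde, theta, step_sizes)   # read() from the previous sketch
    return jnp.sum((x - x_hat) ** 2)

key = jax.random.PRNGKey(0)
x = jnp.array([1.0, -2.0, 0.5])
theta = {"mu": x}                              # toy memory that stores x exactly
err = reconstruction_error(key, x, theta, [0.5] * 5, n_erase=1)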

Hence, we explore a different approach to designing a fast writing rule, inspired by recently proposed gradient-based meta-learning techniques (Finn et al., 2017), which we call meta-learning energy-based memory models (EBMM). Namely, we perform many write and read optimization procedures with a small number of iterations for several sets of write and read observations, and backpropagate into the initial parameters of the energy network to learn a good starting location for fast optimization. As usual, we assume that we have access to an underlying data distribution over datasets of interest from which we can sample sufficiently many training datasets, even if the actual dataset our memory model will be used to store (at test time) is not available at training time (Santoro et al., 2016).

The straightforward application of gradient-based meta-learning to the loss (2) is problematic: we generally cannot evaluate or differentiate through the expectation over the stochasticity of the distortion model reliably enough for adaptation, since the number of possible (and representative) distortions grows exponentially with the dimensionality of the pattern space.

Instead, we define a different writing loss whose minimization serves as a proxy for ensuring that the input patterns are local minima of the energy, but does not require costly retrieval of an exponential number of distorted queries:

(3)   $L_{\mathrm{write}}(\theta) = \sum_{i=1}^{N} \left[ E(x_i; \theta) + \| \nabla_x E(x_i; \theta) \|^2 \right] + \| \theta - \theta_0 \|^2$

As one can see, the writing loss (3) consists of three terms. The first term is simply the energy value, which we would like to be small for stored patterns relative to non-stored patterns. The condition for $x_i$ to be a local minimum of $E(\cdot; \theta)$ is two-fold: first, the gradient at $x_i$ must be zero, which is captured by the second term of the writing loss, and, second, the Hessian must be positive-definite. The latter condition is difficult to express in a form that admits efficient optimization, and we found that meta-learning with just the first two terms of the writing loss is sufficient. Finally, the third term limits deviation from the initial (or prior) parameters $\theta_0$, which we found helpful from an optimization perspective.
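A minimal sketch of this writing loss in JAX is shown below, reusing the toy energy from the earlier retrieval sketch. The equal weighting of the three terms is an assumption made for brevity; in the full model the relative influence of each term is controlled by meta-learned per-term learning rates (Appendix D).

import jax
import jax.numpy as jnp

def writing_loss(theta, theta_0, patterns):
    # Eq. (3): energy term + squared gradient norm + deviation from theta_0.
    def per_pattern(x):
        e = energy(x, theta)                 # energy value at the stored pattern
        g = jax.grad(energy)(x, theta)       # gradient w.r.t. the pattern itself
        return e + jnp.sum(g ** 2)
    data_terms = jnp.sum(jax.vmap(per_pattern)(patterns))
    prior_term = sum(jnp.sum((theta[k] - theta_0[k]) ** 2) for k in theta)
    return data_terms + prior_term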

We use truncated gradient descent on the writing loss (3) to implement the writing rule:

(4)   $\theta_{k+1} = \theta_k - \eta_k \nabla_\theta L_{\mathrm{write}}(\theta_k), \quad k = 0, \dots, K-1$

To ensure that the gradient updates (4) are useful for minimizing the reconstruction error (2), we train the combination of retrieval and writing rules end-to-end, meta-learning the initial parameters $\theta_0$, the learning-rate schedules and the remaining meta-parameters to perform well on random sets of patterns sampled from the data distribution:

(5)   $\min_{\theta_0, \{\gamma_t\}, \{\eta_k\}} \; \mathbb{E}_{x_1, \dots, x_N} \left[ L_{\mathrm{read}}\big(\mathrm{write}(x_1, \dots, x_N)\big) \right],$

where $\mathrm{write}(\cdot)$ denotes $K$ steps of the update (4) starting from $\theta_0$.

Crucially, the proposed EBMM implements both the write and the read operations via truncated gradient descent, which can itself be differentiated through in order to set up a tractable meta-learning problem. While truncated gradient descent is not guaranteed to converge, the reading and writing rules are trained jointly to minimize the reconstruction error (2) and are thus encouraged to converge sufficiently fast. This turns a potential drawback of the method into an advantage over provably convergent but slow models. It also relaxes the requirement that stored patterns create perfectly well-behaved basins of attraction: if, for example, a stored pattern creates a nuisance attractor in dangerous proximity of the main one, gradient descent (1) may still pass over it with appropriately learned step sizes $\gamma_t$.
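The sketch below puts the pieces together: writing as a few gradient steps on the writing loss, reading as a few gradient steps on the energy, and an outer meta-learning step that differentiates the reconstruction error through both unrolled inner loops. It reuses the toy energy, read, distort and writing_loss functions from the previous sketches; the plain-SGD outer update, the constant inner step sizes and the step counts are illustrative assumptions.

import jax
import jax.numpy as jnp

def write(theta_0, patterns, eta=0.1, num_steps=5):
    # Eq. (4): truncated gradient descent on the writing loss, from theta_0.
    theta = theta_0
    for _ in range(num_steps):
        grads = jax.grad(writing_loss)(theta, theta_0, patterns)
        theta = {k: theta[k] - eta * grads[k] for k in theta}
    return theta

def meta_loss(theta_0, key, patterns, step_sizes):
    # Eq. (5): reconstruction error of reads from a memory written with theta_0.
    theta = write(theta_0, patterns)
    keys = jax.random.split(key, patterns.shape[0])
    x_tilde = jax.vmap(lambda k, x: distort(k, x, 1))(keys, patterns)
    x_hat = jax.vmap(lambda q: read(q, theta, step_sizes))(x_tilde)
    return jnp.mean(jnp.sum((patterns - x_hat) ** 2, axis=-1))

# One outer meta-learning update of theta_0 (plain SGD for brevity).
key = jax.random.PRNGKey(1)
patterns = jax.random.normal(key, (4, 3))
theta_0 = {"mu": jnp.zeros(3)}
meta_grads = jax.grad(meta_loss)(theta_0, key, patterns, [0.5] * 5)
theta_0 = {k: theta_0[k] - 1e-2 * meta_grads[k] for k in theta_0}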

4 Experiments

(a) Omniglot.
(b) CIFAR.
Figure 3: Distortion (reconstruction error) vs rate (memory size) analysis on batches of 64 images.

In this section we experimentally evaluate EBMM on a number of real-world image datasets. The performance of EBMM is compared to a set of relevant baselines: Long Short-Term Memory (LSTM) (Hochreiter and Schmidhuber, 1997), the classical Hopfield network (Hopfield, 1982), Memory-Augmented Neural Networks (MANN) (Santoro et al., 2016) (a variant of the Differentiable Neural Computer (Graves et al., 2016)), Memory Networks (Weston et al., 2014), the Differentiable Plasticity model of Miconi et al. (2018) (a generalization of the Fast-Weights RNN (Ba et al., 2016)) and the Dynamic Kanerva Machine (Wu et al., 2018). Some of these baselines failed to learn at all on real-world images. In Appendix A.2 we provide additional experiments with random binary strings and a larger set of representative models.

The experimental procedure is the following: we write a fixed-sized batch of images into a memory model, then corrupt a random block of the written image to form a query and let the model retrieve the originally stored image. By varying the memory size and repeating this procedure, we perform distortion/rate analysis, i.e. we measure how well a memory model can retrieve a remembered pattern for a given memory size.

We define memory size as the number of float32 values used to represent the modifiable part of the model. In the case of EBMM this is the subset of network weights that are modified by gradient descent (4); for the other models it is the size of the state, e.g. the number of slots times the slot size for a Memory Network. To ensure a fair comparison, all models use the same encoder (and decoder, where applicable) networks, whose architectures are described in Appendix C. In all experiments EBMM used 5 read iterations and 5 write iterations.
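As a concrete reading of this definition, the snippet below counts the writable float32 values in a parameter dictionary split into static and writable parts; the parameter names and shapes are purely illustrative.

import jax.numpy as jnp

params = {
    "encoder/conv_0": jnp.zeros((3, 3, 3, 32)),       # static, shared, never written
    "memory/conv_dynamic": jnp.zeros((3, 3, 64, 8)),  # writable: updated by eq. (4)
    "memory/linear_dynamic": jnp.zeros((512, 64)),    # writable
}
writable = {k: v for k, v in params.items() if k.startswith("memory/")}
memory_size = sum(v.size for v in writable.values())  # number of float32 values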

4.1 Omniglot characters

We begin with experiments on the Omniglot dataset (Lake et al., 2015), which is now a standard evaluation of fast adaptation models. For simplicity of comparison with other models, we downscaled the images and binarized them using a fixed threshold. We use Hamming distance as the evaluation metric. For training and evaluation we apply randomly positioned binary distortions (see figure 2 for an example).

We explored two versions of EBMM for this experiment that use parts of fully-connected (FC, see Appendix C.2) and convolutional (conv, Appendix C.3) layers in a 3-block ResNet (He et al., 2016) as writable memory.

Figure 3(a) contains the distortion-rate analysis of the different models, which in this case is the Hamming distance as a function of memory size. We can see two modes in the models' behaviour. For small memory sizes, lossless storage is a hard problem and all models have to find an efficient compression strategy; this is where most of the difference between models can be observed. However, beyond a certain critical memory size it becomes possible to rely on autoencoding alone, which for a relatively simple dataset such as Omniglot is handled efficiently by the ResNet architecture we are using. Hence, even Memory Networks, which do not employ any compression mechanism beyond distributed representations, can retrieve the original images almost perfectly. In this experiment MANN learned the most efficient compression strategy, but could not make use of larger memory. EBMM performed well in both the high and low compression regimes, with convolutional memory being more efficient than fully-connected memory. In the CIFAR and ImageNet experiments below we therefore use only the convolutional version of EBMM.

We visualize the process of associative retrieval in figure 2. The model successfully detected the distorted parts of the images and clearly managed to retrieve the original pixel intensities. We also show the energy levels of the distorted query image, of the recalled images through 5 read iterations, and of the original image. In most cases we found the energy of the retrieved image to match the energy of the original; when they do not match, a retrieval error typically occurs (see the green example).

4.2 Real images from CIFAR-10

Figure 4: Visualization of gradient descent iterations during retrieval of CIFAR images. The last column contains reconstructions from Memory networks (both models use 10k memory).

We conducted a similar study on the CIFAR dataset, using the same network architecture as in the Omniglot experiment. The only difference in the experimental setup is that we used squared error as the evaluation metric, since the data are continuous RGB images.

Figure 3(b) contains the corresponding distortion-rate analysis. EBMM clearly dominates the comparison. One important reason is the model's ability to detect the distorted part of the image, so it can avoid paying reconstruction loss on the rest of the image. Moreover, unlike Omniglot, where images can be almost perfectly reconstructed by an autoencoder with a large enough code, CIFAR images have much more variety and greater channel depth. This makes efficient joint storage of a batch as important as the ability to provide a good decoding of the stored original.

Gradient descent iterations shown in figure 4 demonstrate the successful application of the model to natural images. Due to the higher complexity of the dataset, the reconstructions are imperfect; however, the original patterns are clearly recognizable. Interestingly, the learned optimization schedule starts with one big gradient step that provides a coarse guess, which is then gradually refined.

4.3 ImageNet 64x64

We further investigate the ability of EBMM to handle complex visual datasets by applying the model to ImageNet. Similarly to the CIFAR experiment, we construct queries by corrupting a quarter of the image with randomly positioned masks. The model is based on a 4-block version of the CIFAR network. While the network itself is rather modest compared to existing ImageNet classifiers, the sequential training regime, which resembles that of large-state recurrent networks, prevents us from using anything significantly bigger than a CIFAR-scale model. Due to the prohibitive computational cost of experiments at this scale, we also had to decrease the batch size to 32.

(a) Distortion-rate analysis on ImageNet.
(b) Retrieval of 64x64 ImageNet images (all models have 18K memory).
Figure 5: ImageNet results.

The distortion-rate analysis (Figure 5(a)) shows behaviour similar to the CIFAR experiment. EBMM incurs a lower reconstruction error than the other models, and MANN demonstrates better performance than Memory Networks for smaller memory sizes; however, the asymptotic behaviour of these two models will likely match.

The qualitative results are shown in Figure 5(b). Despite the arguably more difficult images, EBMM is able to capture shape and color information, although not in high detail. We believe this could likely be mitigated by using larger models. Additionally, using techniques such as perceptual losses (Johnson et al., 2016) instead of naive pixel-wise reconstruction errors could improve visual quality with the existing architectures, but we leave these ideas for future work.

4.4 Analysis of energy levels

Figure 6: Energy distributions of different classes of patterns under an Omniglot model. Memories are the patterns written into memory, non-memories are other randomly sampled images and distorted memories are the written patterns distorted as during the retrieval. CIFAR images were produced by binarizing the original RGB images and serve as out-of-distribution samples.

We were also interested in whether the energy values provided by EBMM are interpretable and whether they can be used for associative retrieval. We took an Omniglot model and inspected the energy levels of different types of patterns. It appears that, despite not being explicitly trained to do so, EBMM could in many cases discriminate between in-memory and out-of-memory patterns (see Figure 6). Moreover, distorted patterns had even higher energy than simply unknown patterns. Out-of-distribution patterns, here modelled as binarized CIFAR images, can be seen as clear outliers.

5 Related work

Deep neural networks are capable of both compression (Parkhi et al., 2015; Kraska et al., 2018) and memorizing training patterns (Zhang et al., 2016). Taken together, these properties make deep networks an attractive candidate for memory models with both exact-recall and compressive capabilities. However, there exists a natural trade-off between the speed of writing and the realizable capacity of a model (Ba et al., 2016). Approaches similar to ours in their use of gradient-descent dynamics, but lacking fast writing, were proposed by Hinton et al. (2006b) and recently revisited by Du and Mordatch (2019), together with another stochastic deep energy model (Krotov and Hopfield, 2016). In general it is difficult to derive a writing rule for a given dynamics equation or energy model, a difficulty we attempt to address in this work.

The idea of meta-learning (Thrun and Pratt, 2012; Hochreiter et al., 2001) has found many successful applications in few-shot supervised (Santoro et al., 2016; Vinyals et al., 2016) and unsupervised learning (Bartunov and Vetrov, 2016; Reed et al., 2017). Our model is particularly influenced by the works of Andrychowicz et al. (2016) and Finn et al. (2017), which experiment with meta-learning efficient optimization schedules; our method can perhaps be seen as an extreme instance of this principle, since we implement both learning and inference procedures as optimization. Perhaps the most prominent existing application of meta-learning to associative retrieval is the Kanerva Machine (Wu et al., 2018), which combines a variational auto-encoder with a latent linear model to serve as an addressable memory. The Kanerva Machine benefits from the high-level representation extracted by the auto-encoder. However, its linear model can only represent convex combinations of memory slots and is thus less expressive than the distributed storage realizable in the weights of a deep network.

We described the literature on associative and energy-based memory in section 1, but other types of memory should be mentioned in connection with our work. Many recurrent architectures aim at maintaining an efficient compressive memory (Graves et al., 2016; Rae et al., 2018). The models developed by Ba et al. (2016) and Miconi et al. (2018) enable associative recall by combining standard RNNs with structures similar to the Hopfield network. Recently, Munkhdalai et al. (2019) explored the idea of using arbitrary feed-forward networks as a key-value storage.

Finally, the idea of learning a surrogate model to define a gradient field useful for a problem of interest has a number of incarnations. Putzky and Welling (2017) jointly learn an energy model and an optimizer to perform denoising or inpainting of images. Marino et al. (2018) use gradient descent on an energy defined by a variational lower bound to improve variational approximations. Belanger et al. (2017) formulate a generic framework for energy-based prediction driven by gradient-descent dynamics.

6 Conclusion

We introduced a novel learning method for deep associative memory systems. Our method benefits from the recent progress in deep learning so that we can use a very large class of neural networks both for learning representations and for storing patterns in network weights. At the same time, we are not bound by slow gradient learning thanks to meta-learning of fast writing rules. We showed that our method is applicable in a variety of domains from non-compressible (binary strings; see Appendix) to highly compressible (natural images) and that the resulting memory system uses available capacity efficiently. We believe that more elaborate architecture search could lead to stronger results on par with state-of-the-art generative models.

An existing limitation of EBMM is the batch writing assumption, which is in principle possible to relax. This would enable embedding the model in reinforcement learning agents or in other tasks requiring online memory updates. It would also be interesting to explore a stochastic variant of EBMM that could return different associations in the presence of uncertainty caused by compression. Finally, many general principles of learning attractor models with desired properties are yet to be discovered, and we believe our results provide good motivation for this line of research.

Acknowledgments

We would like to thank Yan Wu and Daan Wierstra for many insightful discussions that improved the paper. Yan Wu also helped us with setting up Dynamic Kanerva Machine and reviewing the manuscript.

References

  • E. Aarts and J. Korst (1988) Simulated annealing and boltzmann machines. Cited by: §2, §2.
  • D. J. Amit and D. J. Amit (1992) Modeling brain function: the world of attractor neural networks. Cambridge university press. Cited by: §1.
  • M. Andrychowicz, M. Denil, S. Gomez, M. W. Hoffman, D. Pfau, T. Schaul, B. Shillingford, and N. De Freitas (2016) Learning to learn by gradient descent by gradient descent. In Advances in Neural Information Processing Systems, pp. 3981–3989. Cited by: §5.
  • A. Antoniou, H. Edwards, and A. Storkey (2018) How to train your maml. arXiv preprint arXiv:1810.09502. Cited by: §B.4.
  • J. Ba, G. E. Hinton, V. Mnih, J. Z. Leibo, and C. Ionescu (2016) Using fast weights to attend to the recent past. In Advances in Neural Information Processing Systems, pp. 4331–4339. Cited by: §A.2, §4, §5, §5.
  • S. Bartunov and D. P. Vetrov (2016) Fast adaptation in generative models with generative matching networks. arXiv preprint arXiv:1612.02192. Cited by: §5.
  • D. Belanger, B. Yang, and A. McCallum (2017) End-to-end learning for structured prediction energy networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pp. 429–439. Cited by: §B.4, §5.
  • E. Candes and T. Tao (2004) Near optimal signal recovery from random projections: universal encoding strategies?. arXiv preprint math/0410542. Cited by: §1.
  • J. Chung, C. Gulcehre, K. Cho, and Y. Bengio (2014) Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555. Cited by: §C.1.
  • Y. Du and I. Mordatch (2019) Implicit generation and generalization in energy-based models. arXiv preprint arXiv:1903.08689. Cited by: §1, §5.
  • C. Finn, P. Abbeel, and S. Levine (2017) Model-agnostic meta-learning for fast adaptation of deep networks. arXiv preprint arXiv:1703.03400. Cited by: Appendix D, §3, §5.
  • A. Graves, G. Wayne, M. Reynolds, T. Harley, I. Danihelka, A. Grabska-Barwińska, S. G. Colmenarejo, E. Grefenstette, T. Ramalho, J. Agapiou, et al. (2016) Hybrid computing using a neural network with dynamic external memory. Nature 538 (7626), pp. 471. Cited by: §A.2, §4, §5.
  • K. Gregor, F. Besse, D. J. Rezende, I. Danihelka, and D. Wierstra (2016) Towards conceptual compression. In Advances In Neural Information Processing Systems, pp. 3549–3557. Cited by: §1.
  • K. He, X. Zhang, S. Ren, and J. Sun (2015) Delving deep into rectifiers: surpassing human-level performance on imagenet classification. In Proceedings of the IEEE international conference on computer vision, pp. 1026–1034. Cited by: Appendix A.
  • K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778. Cited by: §4.1.
  • G. E. Hinton, S. Osindero, and Y. Teh (2006a) A fast learning algorithm for deep belief nets. Neural computation 18 (7), pp. 1527–1554. Cited by: §1.
  • G. E. Hinton (2012) A practical guide to training restricted boltzmann machines. In Neural networks: Tricks of the trade, pp. 599–619. Cited by: §1.
  • G. Hinton, S. Osindero, M. Welling, and Y. Teh (2006b) Unsupervised discovery of nonlinear structure using contrastive backpropagation. Cognitive science 30 (4), pp. 725–731. Cited by: §5.
  • S. Hochreiter and J. Schmidhuber (1997) Long short-term memory. Neural computation 9 (8), pp. 1735–1780. Cited by: §A.2, Table 1, §4.
  • S. Hochreiter, A. S. Younger, and P. R. Conwell (2001) Learning to learn using gradient descent. In International Conference on Artificial Neural Networks, pp. 87–94. Cited by: §5.
  • J. J. Hopfield (1982) Neural networks and physical systems with emergent collective computational abilities. Proceedings of the national academy of sciences 79 (8), pp. 2554–2558. Cited by: §A.2, §1, §1, §2, §4.
  • J. Johnson, A. Alahi, and L. Fei-Fei (2016) Perceptual losses for real-time style transfer and super-resolution. In European conference on computer vision, pp. 694–711. Cited by: §4.3.
  • P. Kanerva (1988) Sparse distributed memory. MIT press. Cited by: §1.
  • D. P. Kingma and M. Welling (2013) Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114. Cited by: §1.
  • T. Kraska, A. Beutel, E. H. Chi, J. Dean, and N. Polyzotis (2018) The case for learned index structures. In Proceedings of the 2018 International Conference on Management of Data, pp. 489–504. Cited by: §5.
  • D. Krotov and J. J. Hopfield (2016) Dense associative memory for pattern recognition. In Advances in Neural Information Processing Systems, pp. 1172–1180. Cited by: §5.
  • B. M. Lake, R. Salakhutdinov, and J. B. Tenenbaum (2015) Human-level concept learning through probabilistic program induction. Science 350 (6266), pp. 1332–1338. Cited by: §4.1.
  • Y. LeCun, Y. Bengio, et al. (1995) Convolutional networks for images, speech, and time series. The handbook of brain theory and neural networks 3361 (10), pp. 1995. Cited by: §1.
  • I. Loshchilov and F. Hutter (2017) Fixing weight decay regularization in adam. arXiv preprint arXiv:1711.05101. Cited by: Appendix A.
  • J. Marino, Y. Yue, and S. Mandt (2018) Iterative amortized inference. arXiv preprint arXiv:1807.09356. Cited by: §5.
  • T. Miconi, J. Clune, and K. O. Stanley (2018) Differentiable plasticity: training plastic neural networks with backpropagation. arXiv preprint arXiv:1804.02464. Cited by: §A.2, Table 1, §4, §5.
  • T. Munkhdalai, A. Sordoni, T. Wang, and A. Trischler (2019) Metalearned neural memory. ArXiv abs/1907.09720. Cited by: §5.
  • Y. E. Nesterov (1983) A method for solving the convex programming problem with convergence rate O(1/k^2). In Dokl. Akad. Nauk SSSR, Vol. 269, pp. 543–547. Cited by: §B.2.
  • Y. Nesterov (2013) Introductory lectures on convex optimization: a basic course. Vol. 87, Springer Science & Business Media. Cited by: §2.
  • O. M. Parkhi, A. Vedaldi, A. Zisserman, et al. (2015) Deep face recognition. In BMVC, Vol. 1, pp. 6. Cited by: §5.
  • P. Putzky and M. Welling (2017) Recurrent inference machines for solving inverse problems. arXiv preprint arXiv:1706.04008. Cited by: §5.
  • J. W. Rae, S. Bartunov, and T. P. Lillicrap (2018) Meta-learning neural bloom filters. Cited by: §5.
  • S. Reed, Y. Chen, T. Paine, A. v. d. Oord, S. Eslami, D. Rezende, O. Vinyals, and N. de Freitas (2017) Few-shot autoregressive density estimation: towards learning to learn distributions. arXiv preprint arXiv:1710.10304. Cited by: §5.
  • D. J. Rezende, S. Mohamed, and D. Wierstra (2014) Stochastic backpropagation and approximate inference in deep generative models. arXiv preprint arXiv:1401.4082. Cited by: §1.
  • R. Salakhutdinov and H. Larochelle (2010) Efficient learning of deep boltzmann machines. In Proceedings of the thirteenth international conference on artificial intelligence and statistics, pp. 693–700. Cited by: §1.
  • A. Santoro, S. Bartunov, M. Botvinick, D. Wierstra, and T. Lillicrap (2016) Meta-learning with memory-augmented neural networks. In International conference on machine learning, pp. 1842–1850. Cited by: §A.2, Table 1, §3, §3, §4, §5.
  • A. J. Storkey and R. Valabregue (1999) The basins of attraction of a new hopfield learning rule. Neural Networks 12 (6), pp. 869–876. Cited by: §A.2, §1.
  • S. Thrun and L. Pratt (2012) Learning to learn. Springer Science & Business Media. Cited by: §5.
  • G. Toderici, D. Vincent, N. Johnston, S. Jin Hwang, D. Minnen, J. Shor, and M. Covell (2017) Full resolution image compression with recurrent neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5306–5314. Cited by: §1.
  • O. Vinyals, C. Blundell, T. Lillicrap, D. Wierstra, et al. (2016) Matching networks for one shot learning. In Advances in neural information processing systems, pp. 3630–3638. Cited by: §5.
  • G. K. Wallace (1992) The jpeg still picture compression standard. IEEE transactions on consumer electronics 38 (1), pp. xviii–xxxiv. Cited by: §1.
  • J. Weston, S. Chopra, and A. Bordes (2014) Memory networks. arXiv preprint arXiv:1410.3916. Cited by: §A.2, Table 1, §4.
  • D. J. Willshaw, O. P. Buneman, and H. C. Longuet-Higgins (1969) Non-holographic associative memory. Nature 222 (5197), pp. 960. Cited by: §1.
  • Y. Wu, G. Wayne, K. Gregor, and T. Lillicrap (2018) Learning attractor dynamics for generative memory. In Advances in Neural Information Processing Systems, pp. 9401–9410. Cited by: §4, §5.
  • K. Yoon, M. A. Buice, C. Barry, R. Hayman, N. Burgess, and I. R. Fiete (2013) Specific evidence of low-dimensional continuous attractor dynamics in grid cells. Nature neuroscience 16 (8), pp. 1077. Cited by: §2.
  • C. Zhang, S. Bengio, M. Hardt, B. Recht, and O. Vinyals (2016) Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530. Cited by: §5.

Appendix A Additional experimental details

We train all models using the AdamW optimizer (Loshchilov and Hutter, 2017) with fixed learning rate and weight decay values, and all other parameters set to the AdamW defaults. We also apply gradient clipping by global norm. All models were allowed to train for a fixed budget of gradient updates or one week, whichever ended first. All baseline models always made more updates than EBMM.

One instance of each model was trained. Error bars shown on the figures correspond to the 5th and 95th percentiles computed over 1000 random batches.

In all experiments we used the initialization scheme proposed by He et al. (2015).

A.1 Failure modes of baseline models

Image retrieval appeared to be difficult for a number of baselines.

The LSTM failed to train due to the quadratic growth of the hidden-to-hidden weight matrix with the size of the hidden state. Even moderately large hidden states were prohibitive for training on a modern GPU.

Differentiable plasticity additionally struggled to train when using a deep representation instead of the raw image data. We hypothesize that it was challenging for the encoder-decoder pair to train simultaneously with the recurrent memory, since in the binary experiment, while not performing best, the model did manage to learn a memorization strategy.

Finally, the Kanerva Machine could not handle the relatively strong noise we used in this task. By design, the Kanerva Machine is agnostic to the noise model and is trained simply to maximize the data likelihood, without meta-learning a particular de-noising scheme. In the presence of strong noise it failed to train on sequences longer than 4 images.

A.2 Experiments with random binary patterns


Method \ # patterns | 16 | 32 | 48 | 64 | 96
Hopfield network, Hebb rule | 0.4 | 5.0 | 9.8 | 13.0 | 16.5
Hopfield network, Storkey rule | 0.0 | 0.9 | 6.3 | 11.3 | 17.1
Hopfield network, pseudo-inverse rule | 0.0 | 0.0 | 0.3 | 4.3 | 22.5
Differentiable plasticity (Miconi et al., 2018) | 3.0 | 13.2 | 20.8 | 26.3 | 34.9
MANN (Santoro et al., 2016) | 0.1 | 0.2 | 1.8 | 4.25 | 9.6
LSTM (Hochreiter and Schmidhuber, 1997) | 30 | 58 | 63 | 64 | 64
Memory Networks (Weston et al., 2014) | 0.0 | 0.0 | 0.0 | 0.0 | 10.5
EBMM RNN | 0.0 | 0.0 | 0.1 | 0.5 | 4.2
Table 1: Number of error bits in retrieved binary patterns.

Besides highly structured patterns such as Omniglot or ImageNet images, we also conducted experiments on random binary patterns, the classical setting in which associative memory models have been evaluated. While such random patterns are not compressible in expectation due to the lack of any internal structure, this experiment examines the efficiency of a learned coding scheme, i.e. how well each of the models can store binary information in floating-point format.

We generate random 128-dimensional binary patterns, each dimension of which takes one of two values with equal probability, corrupt half of the bits and use the result as a query for associative retrieval. We compare EBMM employing a simple fully recurrent network (an RNN using the same input at each iteration, see Appendix C.1) as an energy model against a classical Hopfield network (Hopfield, 1982) using different writing rules (Storkey and Valabregue, 1999) and the recently proposed differentiable plasticity model (Miconi et al., 2018). It is worth noting that the differentiable plasticity model is a generalized variant of Fast Weights (Ba et al., 2016), where the plasticity of each activation is modulated separately. We also consider an LSTM (Hochreiter and Schmidhuber, 1997), a Memory Network (Weston et al., 2014) and the Memory-Augmented Neural Network (MANN) used by Santoro et al. (2016), which is a variant of the DNC (Graves et al., 2016).

Since the Hopfield network has a limited capacity that is strongly tied to the input dimensionality and cannot be increased without adding more inputs, we use its memory size as a reference and constrain all other baseline models to use the same amount of memory. For this task it equals $128 \cdot 127 / 2 + 128 = 8256$ numbers, parametrizing a symmetric matrix and a frequency vector (cf. the comment in Appendix C.1). We measure the Hamming distance between the original and the retrieved pattern for each system, varying the number of stored patterns. We found it difficult to train the recurrent baselines on this task, so we let all models clamp non-distorted bits to their true values at retrieval, which significantly stabilized training.
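A minimal sketch of this evaluation protocol (pattern generation, bit corruption and Hamming-distance scoring) is given below; the ±1 encoding, the way half of the bits are corrupted and the placeholder retrieve function are assumptions made for illustration.

import jax
import jax.numpy as jnp

def make_patterns(key, num_patterns, dim=128):
    # Random binary patterns; the +/-1 encoding here is one common convention.
    return jnp.sign(jax.random.normal(key, (num_patterns, dim)))

def corrupt_half(key, x):
    # Randomize half of the bits to form a retrieval query.
    k1, k2 = jax.random.split(key)
    dim = x.shape[0]
    idx = jax.random.permutation(k1, dim)[: dim // 2]
    noise = jnp.sign(jax.random.normal(k2, (dim // 2,)))
    return x.at[idx].set(noise)

def hamming(a, b):
    return jnp.sum(a != b)

key = jax.random.PRNGKey(0)
patterns = make_patterns(key, num_patterns=16)
query = corrupt_half(jax.random.PRNGKey(1), patterns[0])
# `retrieve` stands in for any of the compared memory systems:
# errors = hamming(patterns[0], retrieve(query))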

As we can see from the results shown in Table 1, EBMM learned a highly efficient associative memory. Only the EBMM and the memory network could achieve near-zero error when storing 64 vectors and even though EBMM could not handle 96 vectors with this number of parameters, it was the most accurate memory model.

Appendix B Reading in EBMM

B.1 Projected gradient descent

We described the basic reading procedure in section 2; however, there are a number of extensions we found useful in practice.

Since in all experiments we work with data constrained to the $[0, 1]$ interval, one has to ensure that the retrieved data also satisfy this constraint. One strategy often used in the literature is to model the output as the argument of a sigmoid function (logits). This may not work well for values close to the interval boundaries due to vanishing gradients, so instead we adopted projected gradient descent, i.e.

$x_{t+1} = \mathrm{proj}\big(x_t - \gamma_t \nabla_x E(x_t; \theta)\big),$

where the proj function clips the data to the interval.

Quite interestingly, this formulation allows more flexible behaviour of the energy function. If a stored pattern has one of its dimensions exactly on the feasible interval boundary, e.g. $x_d = 0$, then the partial derivative $\partial E / \partial x_d$ does not necessarily have to be zero, since the projected iterate will not be able to go beyond zero. We provide more information on the properties of stored patterns in the further appendices.
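A sketch of this projected update, reusing the toy energy from the retrieval sketch in section 2 and assuming the $[0, 1]$ interval, could look as follows.

import jax
import jax.numpy as jnp

def project(x):
    # Clip the iterate back onto the feasible interval.
    return jnp.clip(x, 0.0, 1.0)

def projected_read(x_query, theta, step_sizes):
    # A gradient step on the energy followed by projection onto the interval.
    x = x_query
    for gamma in step_sizes:
        x = project(x - gamma * jax.grad(energy)(x, theta))
    return x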

B.2 Nesterov momentum

Another extension we found useful is to incorporate Nesterov momentum (Nesterov, 1983) into the optimization scheme; we use it in all our experiments.

B.3 Step sizes

To encourage learning convergent attractor dynamics, we constrained the step sizes $\gamma_1 \geq \gamma_2 \geq \dots \geq \gamma_T$ to form a non-increasing sequence. The actual parameters to meta-learn are then the initial step size and the logits that determine each successive decrease. We apply a similar parametrization to the momentum coefficients.
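The exact parametrization is not reproduced here; one simple way to obtain a non-negative, non-increasing schedule from an initial value and per-step logits, consistent with the description above but an assumption nonetheless, is sketched below.

import jax
import jax.numpy as jnp

def step_size_schedule(log_gamma_0, logits):
    # softplus keeps the initial step size positive; each subsequent step is
    # multiplied by sigmoid(logit) in (0, 1), so the sequence never increases.
    gamma_0 = jnp.logaddexp(0.0, log_gamma_0)
    decay = jax.nn.sigmoid(jnp.asarray(logits))
    return gamma_0 * jnp.cumprod(decay)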

B.4 Step-wise reconstruction loss

As has often been found helpful in the literature (Belanger et al., 2017; Antoniou et al., 2018), we apply the reconstruction loss (2) not just to the final iterate of gradient descent, but to all iterates $x_1, \dots, x_T$ simultaneously, averaging the per-iterate reconstruction errors.

Appendix C Architecture details

Below we provide pseudocode for the computational graphs of the models used in the experiments. All modules containing memory parameters are specifically named memory.

C.1 Gated RNN

We used a fairly standard recurrent architecture, equipped only with an update gate as in (Chung et al., 2014). We unroll the RNN for 5 steps and compute the energy value from the last hidden state.

hidden_size = 1024
input_size = 128
# 128 * (128 - 1) / 2 + 128 parameters in total
dynamic_size = (input_size - 1) // 2

state = repeat_batch(zeros(hidden_size))
memory = Linear(input_size, dynamic_size)

gate = Sequential([
    Linear(input_size + hidden_size, hidden_size),
    sigmoid
])

static = Linear(input_size + hidden_size, hidden_size - dynamic_size)

for hop in xrange(5):
    z = concat(x, state)

    dynamic_part = memory(x)
    static_part = static(z)
    c = tanh(concat(dynamic_part, static_part))
    u = gate(z)
    state = u * c + (1 - u) * state

energy = Linear(1)(state)

C.2 ResNet, fully-connected memory

channels = 32
hidden_size = 512
representation_size = 512
static_size = representation_size - dynamic_size

state = repeat_batch(zeros(hidden_size))

encoder = Sequential([
    ResBlock(channels * 1, kernel=[3, 3], stride=2, downscale=False),
    ResBlock(channels * 2, kernel=[3, 3], stride=2, downscale=False),
    ResBlock(channels * 3, kernel=[3, 3], stride=2, downscale=False),
    flatten,
    Linear(256),
    LayerNorm()
])

gate = Sequential([
    Linear(hidden_size),
    sigmoid
])

hidden = Sequential([
    Linear(hidden_size),
    tanh
])

x = encoder(x)

# The memory module holds the writable (fully-connected) parameters.
memory = Linear(input_size, dynamic_size)
dynamic_part = memory(x)
static_part = Linear(static_size)(x)
x = tanh(concat(dynamic_part, static_part))

for hop in xrange(3):
    z = concat(x, state)
    c = hidden(z)
    c = LayerNorm()(c)
    u = gate(z)
    state = u * c + (1 - u) * state

h = tanh(Linear(1024)(state))
energy = Linear(1)(h)

The encoder module is also shared with all baseline models together with its transposed version as a decoder.

C.3 ResNet, convolutional memory

channels = 32
x = ResBlock(channels * 1, kernel=[3, 3], stride=2, downscale=True)(x)
x = ResBlock(channels * 2, kernel=[3, 3], stride=2, downscale=True)(x)

def resblock_bottleneck(x, channels, bottleneck_channels, downscale=False):
  static_size = channels - dynamic_size

  z = x

  x = Conv2D(bottleneck_channels, [1, 1])(x)
  x = LayerNorm()(x)
  x = tanh(x)

  if downscale:
    # Writable (memory) convolution alongside a static convolution.
    memory_part = Conv2D(dynamic_size, kernel=[3, 3], stride=2, downscale=True)(x)
    static_part = Conv2D(static_size, kernel=[3, 3], stride=2, downscale=True)(x)
  else:
    memory_part = Conv2D(dynamic_size, kernel=[3, 3], stride=1, downscale=False)(x)
    static_part = Conv2D(static_size, kernel=[3, 3], stride=1, downscale=False)(x)

  x = concat([static_part, memory_part], -1)
  x = LayerNorm()(x)
  x = tanh(x)

  z = Conv2D(channels, kernel=[1, 1])(z)
  if downscale:
    z = avg_pool(z, [3, 3] + [1], stride=2)
  x += z
  return x

x = resblock_bottleneck(x, channels * 4, channels * 2, False)
x = resblock_bottleneck(x, channels * 4, channels * 2, True)

hidden_size = 128

recurrent = Sequential([
  Conv2D(hidden_size, kernel=[3, 3], stride=1),
  LayerNorm(),
  tanh
])

update_gate = Sequential([
  Conv2D(hidden_size, kernel=[1, 1], stride=1),
  LayerNorm(),
  sigmoid
])

hidden_state = repeat_batch(zeros(4, 4, hidden_size))

for hop in xrange(3):
  z = concat([x, hidden_state], -1)
  candidate = recurrent(z)
  u = update_gate(z)
  hidden_state = u * candidate + (1. - u) * hidden_state

# Energy readout from the final recurrent state.
x = Linear(1024)(flatten(hidden_state))
x = tanh(x)
energy = Linear(1)(x)

C.4 ResNet, ImageNet

This network is effectively a slightly larger version of the ResNet with convolutional memory described above.

channels = 64
dynamic_size = 8

x = ResBlock(channels * 1, kernel=[3, 3], stride=2, downscale=True)(x)
x = ResBlock(channels * 2, kernel=[3, 3], stride=2, downscale=True)(x)

x = resblock_bottleneck(x, channels * 4, channels * 2, True)
x = resblock_bottleneck(x, channels * 4, channels * 2, True)

hidden_size = 256

recurrent = Sequential([
  Conv2D(hidden_size, kernel=[3, 3], stride=1),
  LayerNorm(),
  tanh
])

update_gate = Sequential([
  Conv2D(hidden_size, kernel=[1, 1], stride=1),
  LayerNorm(),
  sigmoid
])

hidden_state = repeat_batch(zeros(4, 4, hidden_size))

for hop in xrange(3):
  z = concat([x, hidden_state], -1)
  candidate = recurrent(z)
  u = update_gate(z)
  hidden_state = u * candidate + (1. - u) * hidden_state

# Energy readout from the final recurrent state.
x = Linear(1024)(flatten(hidden_state))
x = tanh(x)
energy = Linear(1)(x)

C.5 The role of skip-connections in energy models

Gradient-based meta-learning, and EBMM in particular, relies on the expressiveness not only of the forward pass of a network, but also of the backward pass that is used to compute gradients. This may require special consideration of the network architecture.

One may notice that all energy models considered above have an element of recurrency of some sort. While the recurrency itself is not crucial for good performance, skip-connections, of which recurrency is a special case, are.

We can illustrate this by considering an energy function of the following form:

$E(x) = \hat{E}(r), \qquad r = f(x) + g(f(x)).$

Here we can think of $r$ as a representation from which the energy is computed. We allow the representation to be first computed as $f(x)$ and then to be refined by adding $g(f(x))$.

During retrieval, we use the gradient of the energy with respect to $x$, which can be computed as

$\nabla_x E(x) = J_f(x)^\top \big(I + J_g(f(x))\big)^\top \nabla_r \hat{E}(r),$

where $J_f$ and $J_g$ denote the Jacobians of $f$ and $g$. One can see that, with a skip-connection, the model is able to refine the gradient together with the energy value.

A simple way of incorporating such skip-connections is via recurrent computation. We allow the model to use a gating mechanism that can modulate the refinement and prevent unnecessary updates. We found that a small number of recurrent steps (3-5) is usually enough for good performance.
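The effect can be seen in a toy JAX example: with the skip-connection, the gradient of the energy with respect to the input combines the direct path through f with the refinement path through g. The specific functions below are illustrative assumptions.

import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x)            # initial representation

def g(r):
    return 0.1 * jnp.tanh(r)      # refinement of the representation

def energy_skip(x):
    r = f(x) + g(f(x))            # representation with a skip-connection
    return jnp.sum(r ** 2)

def energy_no_skip(x):
    r = g(f(x))                   # refinement only, no skip
    return jnp.sum(r ** 2)

x = jnp.array([0.3, -0.7])
grad_skip = jax.grad(energy_skip)(x)      # gradient flows through both f and g
grad_no_skip = jax.grad(energy_no_skip)(x)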

Appendix D Explanations on the writing loss

Our setting deviates from standard gradient-based meta-learning as described in (Finn et al., 2017). In particular, we do not use the same loss function (naturally defined by the energy function) in the adaptation and inference phases. As we explain in section 3, the writing loss (3) contains, besides the energy term, a gradient term and a prior term.

Even though we found it sufficient to use just the energy value as the writing loss, minimizing the gradient norm, perhaps not surprisingly, appeared to help optimization, especially early in training (see figure 7), and led to better final results.

Figure 7: Effect of including the gradient-norm term in the writing loss (3) on Omniglot.

We use an individual learning rate for each writable layer and for each of the three loss terms, initialized to a fixed value and learned together with the other parameters. We use a softplus function to ensure that all learning rates remain non-negative.