Relational Neural Expectation Maximization: Unsupervised Discovery of Objects and their Interactions

02/28/2018 ∙ by Sjoerd van Steenkiste, et al. ∙ IDSIA

Common-sense physical reasoning is an essential ingredient for any intelligent agent operating in the real world. For example, it can be used to simulate the environment, or to infer the state of parts of the world that are currently unobserved. In order to match real-world conditions this causal knowledge must be learned without access to supervised data. To address this problem we present a novel method that learns to discover objects and model their physical interactions from raw visual images in a purely unsupervised fashion. It incorporates prior knowledge about the compositional nature of human perception to factor interactions between object pairs and learn efficiently. On videos of bouncing balls we show the superior modelling capabilities of our method compared to other unsupervised neural approaches that do not incorporate such prior knowledge. We demonstrate its ability to handle occlusion and show that it can extrapolate learned knowledge to scenes with different numbers of objects.




1 Introduction

Humans rely on common-sense physical reasoning to solve many everyday physics-related tasks (Lake et al., 2016). For example, it enables them to foresee the consequences of their actions (simulation), or to infer the state of parts of the world that are currently unobserved. This causal understanding is an essential ingredient for any intelligent agent that is to operate within the world.

Common-sense physical reasoning is facilitated by the discovery and representation of objects. Young infants already display surprisingly detailed knowledge of the world around them, including a notion of intuitive physics (Spelke & Kinzler, 2007; Munakata et al., 1997), which makes object-based representations a natural prior: objects serve as primitives of a compositional system. They allow humans to decompose a complex visual scene into distinct parts, describe relations between them, and reason about their dynamics as well as the consequences of their interactions (Ullman et al., 2017; Battaglia et al., 2013; Lake et al., 2016).

The most successful machine learning approaches to common-sense physical reasoning incorporate such prior knowledge in their design. They maintain explicit object representations, which allow for general physical dynamics to be learned between object pairs in a compositional manner (Chang et al., 2016; Battaglia et al., 2016; Watters et al., 2017). However, in these approaches learning is supervised, as it relies on object representations from external sources (e.g. a physics simulator) that are typically unavailable in real-world scenarios.

Neural approaches that learn to directly model motion or physical interactions in pixel space offer an alternative solution (Sutskever et al., 2009; Srivastava et al., 2015). However, while unsupervised, these methods suffer from a lack of compositionality at the representational level of objects. This prevents such end-to-end neural approaches from efficiently learning functions that operate on multiple entities and generalize in a human-like way (cf. Lake et al. (2016); Santoro et al. (2017); Battaglia et al. (2013), but see Perez et al. (2017)).

In this work we propose Relational N-EM (R-NEM), a novel approach to common-sense physical reasoning that learns physical interactions between objects from raw visual images in a purely unsupervised fashion. At its core is N-EM, a method that allows for the discovery of compositional object representations, yet is unable to model interactions between objects. Therefore, we endow N-EM with a relational mechanism inspired by previous work (Santoro et al., 2017; Chang et al., 2016; Battaglia et al., 2016), enabling it to factor interactions between object pairs, learn efficiently, and generalize to visual scenes with a varying number of objects without re-training.

2 Method

Our goal is to learn common-sense physical reasoning in a purely unsupervised fashion directly from visual observations. We have argued that in order to solve this problem we need to exploit the compositional structure of a visual scene. Conventional unsupervised representation learning approaches (e.g. VAEs, Kingma & Welling (2013); GANs, Goodfellow et al. (2014)) learn a single distributed representation that superimposes information about the input, without imposing any structure regarding objects or other low-level primitives. These monolithic representations cannot factorize physical interactions between pairs of objects and therefore lack an essential inductive bias to learn these efficiently. Hence, we require an alternative approach that can discover object representations as primitives of a visual scene in an unsupervised fashion.

One such approach is Neural Expectation Maximization (N-EM; Greff et al. (2017)), which learns a separate distributed representation for each object described in terms of the same features through an iterative process of perceptual grouping and representation learning. The compositional nature of these representations enables us to formulate Relational N-EM (R-NEM): a novel unsupervised approach to common-sense physical reasoning that combines N-EM (Section 2.1) with an interaction function that models relations between objects efficiently (Section 2.2).

2.1 Neural Expectation Maximization

Neural Expectation Maximization (N-EM; Greff et al. (2017)) is a differentiable clustering method that learns a representation of a visual scene composed of primitive object representations. These representations adhere to many useful properties of a symbolic representation of objects, and can therefore be used as primitives of a compositional system (Hummel et al., 2004). They are described in the same format and each contain only information about the object in the visual scene that they correspond to. Together, they form a representation of a visual scene composed of objects that is learned in an unsupervised way, which therefore serves as a starting point for our approach.

The goal of N-EM is to group pixels in the input that belong to the same object (perceptual grouping) and capture this information efficiently in a distributed representation θ_k for each object. At a high level, the idea is that if we had access to the family of distributions P(x | θ_k) (a statistical model of images x given object representations θ_k), then we could formalize our objective as inference in a mixture of these distributions. By using EM to compute a Maximum Likelihood Estimate (MLE) of the parameters of this mixture, θ = (θ_1, …, θ_K), we obtain a grouping (clustering) of the pixels to each object (component) and their corresponding representation. In reality we do not have access to P(x | θ_k), which N-EM learns instead by parameterizing the mixture with a neural network and back-propagating through the iterations of the unrolled generalized EM procedure.

Following Greff et al. (2017), we model each image x ∈ R^D as a spatial mixture of K components parameterized by vectors θ_1, …, θ_K ∈ R^M. A neural network f_φ is used to transform these representations θ_k into parameters ψ_{i,k} = f_φ(θ_k)_i for separate pixel-wise distributions P(x_i | ψ_{i,k}). A set of binary latent variables Z ∈ {0,1}^{D×K} encodes the unknown true pixel assignments, such that z_{i,k} = 1 iff pixel i was generated by component k. The full likelihood for x given θ is given by:

P(x | θ) = ∏_{i=1}^{D} ∑_{k=1}^{K} P(z_{i,k} = 1) · P(x_i | ψ_{i,k}, z_{i,k} = 1).    (1)

Figure 1: Illustration of the different computational aspects of R-NEM when applied to a sequence of images of bouncing balls. Note that the Representations level corresponds to the θ_k's, computed from the soft-assignments γ (E-step) and the Group Reconstructions ψ from the previous time-step. Different colors correspond to different cluster components (object representations). The right side shows a computational overview of Υ^{R-NEM}, a function that computes the pair-wise interactions between the object representations.

If f_φ has learned a statistical model of images given object representations, then we can compute the object representations θ for a given image x by maximizing P(x | θ). Marginalization over z complicates this process, thus we use generalized EM to maximize the following lowerbound instead:

Q(θ, θ_old) = ∑_z P(z | x, ψ_old) · log P(x, z | ψ).    (2)
Each iteration of generalized EM consists of two steps: the E-step computes a new estimate of the posterior probability distribution over the latent variables, γ_{i,k} := P(z_{i,k} = 1 | x_i, ψ_{i,k}), given θ_old from the previous iteration. It yields a new soft-assignment of the pixels to the components (clusters), based on how accurately they model x. The generalized M-step updates θ by taking a gradient ascent step on (2), using the previously computed soft-assignments: θ_new = θ_old + η · ∂Q/∂θ.¹ (¹ We cannot compute the optimal θ analytically, due to the non-linearity of f_φ.)
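For concreteness, one E-step of such a mixture can be sketched in a few lines of numpy. This is a hedged toy illustration (not the authors' code), assuming Bernoulli pixel distributions with uniform mixing priors; the 4-pixel image `x` and the per-component predictions `psi` are invented for the example:

```python
import numpy as np

def e_step(x, psi, eps=1e-6):
    """Soft-assignments gamma[i, k] = P(z_ik = 1 | x_i, psi_ik) for binary
    pixels x under Bernoulli components with uniform mixing priors."""
    psi = np.clip(psi, eps, 1.0 - eps)
    # Likelihood of each pixel under each component's prediction.
    lik = np.where(x[:, None] == 1.0, psi, 1.0 - psi)  # shape (D, K)
    return lik / lik.sum(axis=1, keepdims=True)        # normalize over K

# Toy 4-pixel "image": component 0 is confident about pixels 0-1,
# component 1 about pixels 2-3; both are uncertain (0.5) elsewhere.
x = np.array([1.0, 1.0, 0.0, 0.0])
psi = np.array([[0.9, 0.5],
                [0.9, 0.5],
                [0.5, 0.1],
                [0.5, 0.1]])
gamma = e_step(x, psi)  # pixels 0-1 lean to component 0, pixels 2-3 to 1
```

Each component ends up responsible for the pixels it predicts accurately, which is exactly the perceptual-grouping signal that N-EM back-propagates through.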

The unrolled computational graph of the generalized EM steps is differentiable, which provides a means to train f_φ to implement a statistical model of images given object representations. Using back-propagation through time (e.g. Werbos (1988); Williams (1989)) we train f_φ to minimize the following loss:

L(x) = − ∑_{i=1}^{D} ∑_{k=1}^{K} γ_{i,k} · log P(x_i, z_{i,k} | ψ_{i,k}) − (1 − γ_{i,k}) · D_KL[ P(x_i) ‖ P(x_i | ψ_{i,k}, z_{i,k}) ].    (3)
The intra-cluster term is identical to (2), which credits each component for accurately representing pixels that have been assigned to it. The inter-cluster term ensures that each representation only captures the information about the pixels that have been assigned to it.
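The two terms can be made concrete with a small hedged sketch, again assuming Bernoulli pixels and a fixed pixel prior of 0.5 (the function names and toy values are illustrative, not the authors' implementation):

```python
import numpy as np

def bernoulli_kl(p, q, eps=1e-6):
    """KL divergence D_KL[Bern(p) || Bern(q)], elementwise."""
    p, q = np.clip(p, eps, 1 - eps), np.clip(q, eps, 1 - eps)
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

def nem_loss(x, psi, gamma, prior=0.5, eps=1e-6):
    psi = np.clip(psi, eps, 1 - eps)
    log_p = x[:, None] * np.log(psi) + (1 - x[:, None]) * np.log(1 - psi)
    intra = -(gamma * log_p).sum()                          # fit assigned pixels
    inter = ((1 - gamma) * bernoulli_kl(prior, psi)).sum()  # stay at the prior elsewhere
    return intra + inter

x = np.array([1.0, 0.0])
gamma = np.eye(2)                  # pixel 0 -> component 0, pixel 1 -> component 1
psi_good = np.array([[0.9, 0.5],   # confident and correct on its own pixel,
                     [0.5, 0.1]])  # at the prior on the other pixel
psi_bad = np.array([[0.1, 0.5],
                    [0.5, 0.9]])   # confidently wrong on its own pixel
```

A component that models its assigned pixels well and stays uninformative about the rest (`psi_good`) receives a lower loss than one that is confidently wrong (`psi_bad`).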

A more powerful variant of N-EM can be obtained (RNN-EM) by substituting the generalized M-step with a recurrent neural network having hidden state θ_k. In this case, the entirety of f_φ consists of a recurrent encoder-decoder architecture that receives γ_k ⊙ (x − ψ_k) as input at each step.

The learning objective in (3) is prone to trivial solutions in case of overcapacity, which could prevent the network from modelling the statistical regularities in the data that correspond to objects. By adding noise to the input image or reducing the dimensionality of θ we can guide learning to avoid this. Moreover, in the case of RNN-EM one can evaluate (3) at the following time-step (predictive coding) to encourage learning of object representations and their corresponding dynamics. One intuitive interpretation of using denoising or next-step prediction as part of the training objective is that it guides the network to learn about essential properties of objects, in this case those that correspond to the Gestalt principles of prägnanz and common fate (Hatfield & Epstein, 1985).

2.2 Relational Neural Expectation Maximization

RNN-EM (unlike N-EM) is able to capture the dynamics of individual objects through a parametrized recurrent connection that operates on the object representation across consecutive time-steps. However, the relations and interactions that take place between objects cannot be captured in this way. In order to overcome this shortcoming we propose Relational N-EM (R-NEM), which adds relational structure to the recurrence to model interactions between objects without violating key properties of the learned object representations.

Consider a generalized form of the standard RNN-EM dynamics equation, which computes the object representation θ_k at time t as a function of all object representations at the previous time-step through an interaction function Υ:

θ_k^(t) = σ( W · x̃^(t) + R · Υ_k(θ^(t−1)) ).    (4)

Here W and R are weight matrices, σ is the sigmoid activation function, and x̃^(t) is the input to the recurrent model at time t (possibly transformed by an encoder). When Υ_k(θ) = θ_k, this dynamics model coincides with a standard RNN update rule, thereby recovering the original RNN-EM formulation.
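A minimal numpy sketch of this generalized update, with randomly initialized weights and the no-interaction choice Υ_k(θ) = θ_k that recovers a plain RNN (all sizes are arbitrary placeholder choices, not the paper's architecture):

```python
import numpy as np

rng = np.random.default_rng(0)
K, M, D = 3, 8, 16        # components, representation size, input size
W = rng.normal(0.0, 0.1, (M, D))
R = rng.normal(0.0, 0.1, (M, M))

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def upsilon_rnn(theta):
    """No-interaction case: Upsilon_k(theta) = theta_k (plain RNN update)."""
    return theta

def step(theta, x, upsilon):
    # theta: (K, M) object representations, x: (D,) (encoded) input frame.
    return sigmoid(x @ W.T + upsilon(theta) @ R.T)

theta = rng.random((K, M))
x = rng.random(D)
theta_next = step(theta, x, upsilon_rnn)  # shape (K, M), entries in (0, 1)
```

Swapping in a different `upsilon` leaves the update rule untouched, which is what makes the interaction function a modular choice.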

The inductive bias incorporated in Υ reflects the modeling assumptions about the interactions between objects in the environment, and therefore the nature of the θ_k's interdependence. If Υ incorporates the assumption that no interaction takes place between objects, then the θ_k's are fully independent and we recover Υ_k(θ) = θ_k. On the other hand, if we do assume that interactions among objects take place, but assume very little about the structure of the interdependence between the θ_k's, then we forfeit useful properties of Υ such as compositionality. For example, if Υ operates on the concatenation of all θ_k's as a single monolithic state, we can no longer extrapolate learned knowledge to environments with more or fewer than K objects and lose overall data efficiency (Santoro et al., 2017). Instead, we can make efficient use of compositionality among the learned object representations to incorporate general but guiding constraints on how these may influence one another (Battaglia et al., 2016; Chang et al., 2016). In doing so we constrain Υ to capture the interdependence between the θ_k's in a compositional manner that enables physical dynamics to be learned efficiently, and allows for learned dynamics to be extrapolated to a variable number of objects.

We propose a parametrized interaction function Υ^{R-NEM} that incorporates these modeling assumptions and updates θ_k based on the pairwise effects of the objects i ≠ k on k:

Υ^{R-NEM}_k(θ) = [ θ̂_k ; E_k ],  where  θ̂_k = MLP_enc(θ_k),
ξ_{k,i} = MLP_emb([ θ̂_k ; θ̂_i ]),  α_{k,i} = MLP_att(ξ_{k,i}),
e_{k,i} = MLP_eff(ξ_{k,i}),  E_k = ∑_{i ≠ k} α_{k,i} · e_{k,i},    (5)

where [·;·] is the concatenation operator and each MLP corresponds to a multi-layer perceptron. First, each θ_i is transformed using MLP_enc to obtain θ̂_i, which enables information that is relevant for the object dynamics to be made more explicit in the representation. Next, each pair (θ̂_k, θ̂_i) is concatenated and processed by MLP_emb, which computes a shared embedding ξ_{k,i} that encodes the interaction between object k and object i. Notice that we opt for a clear separation between the focus object k and the context object i, as in previous work (Chang et al., 2016). From ξ_{k,i} we compute e_{k,i}: the effect of object i on object k; and an attention coefficient α_{k,i} that encodes whether interaction between object k and object i takes place. These attention coefficients (Bahdanau et al., 2014; Xu et al., 2015) help to select relevant context objects, and can be seen as a more flexible unsupervised replacement of the distance-based heuristic that was used in previous work (Chang et al., 2016). Finally, we compute the total effect E_k of all objects on object k as a sum of the pairwise effects weighted by their attention coefficients. A visual overview of Υ^{R-NEM} can be seen on the right side of Figure 1.
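The pairwise structure of this interaction function can be sketched as follows. This is a hedged toy version: single-layer ReLU nets stand in for the MLPs, a sigmoid readout stands in for the attention head, and all sizes and initializations are placeholder assumptions. Because the same networks are shared across all pairs, the function is permutation equivariant, which is what enables extrapolation to a different number of objects:

```python
import numpy as np

rng = np.random.default_rng(1)
K, M, H = 4, 6, 10        # objects, representation size, hidden size

def mlp(d_in, d_out):
    """Single ReLU layer standing in for a full MLP."""
    Wm = rng.normal(0.0, 0.3, (d_out, d_in))
    return lambda v: np.maximum(0.0, v @ Wm.T)

enc, emb, eff = mlp(M, H), mlp(2 * H, H), mlp(H, M)
att_w = rng.normal(0.0, 0.3, H)  # sigmoid readout for the attention head

def upsilon_rnem(theta):
    theta_hat = np.stack([enc(t) for t in theta])          # encode each object
    out = []
    for k in range(len(theta)):
        effect = np.zeros(M)
        for i in range(len(theta)):
            if i == k:
                continue
            pair = np.concatenate([theta_hat[k], theta_hat[i]])
            xi = emb(pair)                                  # shared pair embedding
            alpha = 1.0 / (1.0 + np.exp(-xi @ att_w))       # does i act on k?
            effect += alpha * eff(xi)                       # attention-weighted effect
        out.append(np.concatenate([theta_hat[k], effect]))  # [theta_hat_k; E_k]
    return np.stack(out)

theta = rng.random((K, M))
out = upsilon_rnem(theta)  # shape (K, H + M)
```

Permuting the rows of `theta` permutes the rows of the output in the same way, since every pair is processed by the same shared networks.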

3 Related Work

Machine learning approaches to common-sense physical reasoning can roughly be divided into two groups: symbolic approaches and approaches that perform state-to-state prediction. The former group performs inference over the parameters of a symbolic physics engine (Wu et al., 2015; Ullman et al., 2017; Battaglia et al., 2013), which restricts them to synthetic environments. The latter group employs machine learning methods to make state-to-state predictions, often describing the state of a system as a set of compact object descriptions that are either used as an input to the system (Battaglia et al., 2016; Fragkiadaki et al., 2015; Chang et al., 2016; Grzeszczuk et al., 1998) or for training purposes (Watters et al., 2017). By incorporating information (e.g. position, velocity) about objects these methods have achieved excellent generalization and simulation capabilities. Purely unsupervised approaches for state-to-state prediction (Sutskever et al., 2009; Michalski et al., 2014; Agrawal et al., 2016; Lerer et al., 2016) that use raw visual inputs as state-descriptions have yet to rival these capabilities. Our method is a purely unsupervised state-to-state prediction method that operates in pixel space, taking a first step towards unsupervised learning of common-sense reasoning in real-world environments.

The proposed interaction function can be seen as a type of Message Passing Neural Network (MPNN; Gilmer et al. (2017)) that incorporates a variant of neighborhood attention (Duan et al., 2017). In light of other recent work (Zaheer et al., 2017) it can be seen as a permutation equivariant set function.

R-NEM relies on N-EM (Greff et al., 2017) to discover a compositional object representation from raw visual inputs. A closely related approach to N-EM is the TAG framework (Greff et al., 2016), which utilizes a similar mechanism to perform inference over group representations, but in addition performs inference over the group assignments. In recent work TAG was combined with a recurrent ladder network (Ilin et al., 2017) to obtain a powerful model (RTagger) that can be applied to sequential data. However, the lack of a single compact representation that captures all information about a group (object) makes a compositional treatment of physical interactions more difficult. Other unsupervised approaches rely on attention to group together parts of the visual scene corresponding to objects (Eslami et al., 2016; Gregor et al., 2015). These approaches suffer from a similar problem in that their sequential nature prevents a coherent object representation from taking shape.

Other related work has also taken steps towards combining the learnability of neural networks with the compositionality of symbolic programs in modeling physics (Chang et al., 2016; Battaglia et al., 2016), playing games (Kansky et al., 2017; Denil et al., 2017), learning algorithms (Cai et al., 2017; Bošnjak et al., 2017; Reed & De Freitas, 2015; Li et al., 2016), visual understanding (Johnson et al., 2017; Ellis et al., 2017), and natural language processing (Andreas et al., 2016; Hu et al., 2017).

Figure 2: R-NEM applied to a sequence of bouncing balls. Each column corresponds to a time-step, which coincides with an EM step. At each time-step, R-NEM computes new representations θ_k according to (4) (see also Representations in Figure 1) from the input x with added noise (bottom row). From each new θ_k a group reconstruction ψ_k is produced (rows 2-6 from bottom) that predicts the state of the environment at the next time-step. Attention coefficients are visualized by overlaying a colored reconstruction of a context object on the white reconstruction of the focus object (see Attention in Section 4). Based on the prediction accuracy of ψ, the E-step (see Figure 1) computes new soft-assignments γ (row 7 from bottom), visualized by coloring each pixel according to its distribution over components. Row 8 visualizes the total prediction by the network (∑_k γ_k ⊙ ψ_k) and row 9 the ground-truth sequence at the next time-step.

4 Experiments

Figure 3: Performance of each method on the bouncing balls task. Each method was trained on a dataset with 4 balls, evaluated on a test set with 4 balls (left), and on a test set with 6-8 balls (middle). The losses are reported relative to the loss of a baseline for each dataset that always predicts the current frame. The ARI score (right) is used to evaluate the degree of compositionality that is achieved.

In this section we evaluate R-NEM on three different physical reasoning tasks that each vary in their dynamical and visual complexity: bouncing balls with variable mass, bouncing balls with an invisible curtain, and the Arcade Learning Environment (Bellemare et al., 2013). We compare R-NEM to other unsupervised neural methods that do not incorporate any inductive biases reflecting real-world dynamics and show that these biases are indeed beneficial.² (² Code is available online.)

All experiments use ADAM (Kingma & Ba, 2014)

with default parameters, on 50K train + 10K validation + 10K test sequences and early stopping with a patience of 10 epochs. For each of

we used a unique single layer neural network with 250 rectified linear units. For we used a two-layer neural network: 100 tanh units followed by a single sigmoid unit. A detailed overview of the experimental setup can be found in Appendix A.

Bouncing Balls

We study the physical reasoning capabilities of R-NEM on the bouncing balls task, a standard environment for evaluating physical reasoning that exhibits low visual complexity but complex non-linear physical dynamics.³ (³ Videos are available online.) We train R-NEM on sequences of binary images over 30 time-steps that contain four bouncing balls with different masses corresponding to their radii. The balls are initialized with random initial positions, masses and velocities. Balls bounce elastically against each other and the image window.
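A hedged sketch of how such trajectories might be generated (wall bounces only; the actual dataset additionally includes elastic ball-ball collisions, variable masses, and rendering to binary images):

```python
import numpy as np

def simulate(n_balls=4, steps=30, radius=0.1, seed=0):
    """Ball centers in the unit square; elastic bounces off the walls."""
    rng = np.random.default_rng(seed)
    pos = rng.uniform(radius, 1 - radius, (n_balls, 2))
    vel = rng.uniform(-0.05, 0.05, (n_balls, 2))
    traj = []
    for _ in range(steps):
        pos = pos + vel
        for d in range(2):  # reflect position and flip velocity at each wall
            low, high = pos[:, d] < radius, pos[:, d] > 1 - radius
            vel[low | high, d] *= -1.0
            pos[low, d] = 2 * radius - pos[low, d]
            pos[high, d] = 2 * (1 - radius) - pos[high, d]
        traj.append(pos.copy())
    return np.stack(traj)  # shape (steps, n_balls, 2)

traj = simulate()  # 30 steps of 4 balls; all centers stay inside the walls
```

Rendering each position as a filled disc on a binary grid would yield image sequences of the kind described above.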

Qualitative Evaluation

Figure 1 presents a qualitative evaluation of R-NEM on the bouncing balls task. After 10 time-steps it can be observed that the pixels that belong to each of the balls are grouped together and assigned to a unique component (with a saturated color), and that the background (colored grey) has been divided among all components (resulting in a grey coloring). This indicates that the representation θ_k from which each component produces its group reconstruction ψ_k does indeed only contain information about a unique object, such that together the θ_k's yield a compositional object representation of the scene. The total reconstruction (that combines the group reconstructions and the soft-assignments) displays an accurate reconstruction of the input sequence at the next time-step, indicating that R-NEM has learned to model the dynamics of bouncing balls.


We compare the modelling capabilities of R-NEM to an RNN, LSTM (Hochreiter & Schmidhuber, 1997; Gers et al., 1999) and RNN-EM in terms of the Binomial Cross-Entropy (BCE) loss between the predicted image and the ground-truth image of the last frame,⁴ as well as the relational BCE that only takes into account objects that currently take part in collision. Unless otherwise specified we use K = 5. (⁴ Since the E-step in R-NEM and RNN-EM utilizes the ground-truth for reconstruction, we substitute it with a simple operator. The resulting loss serves as an upper bound to the true BCE loss.)

On a test set with sequences containing four balls we observe that R-NEM produces markedly lower losses when compared to all other methods (left plot in Figure 3). Moreover, in order to validate that each component captures only a single ball (and thus compositionality is achieved), we report the Adjusted Rand Index (ARI; Hubert & Arabie (1985)) score between the soft-assignments γ and the ground-truth assignment of pixels to objects. In the left column of the ARI plot (right side in Figure 3) we find that R-NEM achieves an ARI score of 0.8, meaning that in roughly 80% of the cases each ball is modeled by a single component. This suggests that a compositional object representation is achieved for most of the sequences. Together these observations are in line with our qualitative evaluation and validate that incorporating real-world priors is greatly beneficial (compared to the RNN and LSTM) and that the interaction function Υ enables interactions to be modelled more accurately than RNN-EM in terms of the relational BCE.
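The ARI score on hard (argmax) assignments can be computed with the standard contingency-table formula; a stdlib-only sketch (the exact implementation used for the paper's numbers is not specified here):

```python
from collections import Counter
from math import comb

def ari(labels_true, labels_pred):
    """Adjusted Rand Index between two hard labelings (Hubert & Arabie, 1985).
    Note: undefined (0/0) when both labelings put everything in one cluster."""
    n = len(labels_true)
    contingency = Counter(zip(labels_true, labels_pred))
    a = Counter(labels_true)                  # row sums
    b = Counter(labels_pred)                  # column sums
    index = sum(comb(c, 2) for c in contingency.values())
    sum_a = sum(comb(c, 2) for c in a.values())
    sum_b = sum(comb(c, 2) for c in b.values())
    expected = sum_a * sum_b / comb(n, 2)
    max_index = (sum_a + sum_b) / 2
    return (index - expected) / (max_index - expected)

score = ari([0, 0, 1, 1], [1, 1, 0, 0])  # 1.0: identical up to label names
```

A score of 1 indicates a perfect grouping up to a relabeling of the components, 0 indicates chance-level agreement.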

Similar to Greff et al. (2017) we find that further increasing the number of components during training (leaving additional groups empty) increases the quality of the grouping; see the R-NEM variant with additional components in Figure 3. In addition we observe that the loss (in particular the relational BCE) is reduced further, which matches our hypothesis that compositional object representations are greatly beneficial for modelling physical interactions.

Figure 4: Left: Three sequences of 15 time-steps ground-truth (top), R-NEM (middle), RNN (bottom). The last ten time-steps of the sequences produced by R-NEM and RNN are simulated. Right: The BCE loss on the entire test-set for these same time-steps.

Extrapolating learned knowledge

We use a test set with sequences containing 6-8 balls to evaluate the ability of each method to extrapolate its learned knowledge about physical interactions between four balls to environments with more balls. We use K = 8 when evaluating R-NEM and RNN-EM on this test set in order to accommodate the increased number of objects. As can be seen from the middle plot in Figure 3, R-NEM again greatly outperforms all other methods. Notice that, since we report the loss relative to a baseline, we roughly factor out the increased complexity of the task. Perfect extrapolation of the learned knowledge would therefore amount to no change in relative performance. In contrast, we observe far worse performance for the LSTM (relative to the baseline) when evaluated on this dataset with extra balls. This suggests that the gating mechanism of the LSTM has allowed it to learn a sophisticated and overly specialized solution for sequences with four balls that does not generalize to a dataset with 6-8 balls.

R-NEM and RNN-EM scale markedly better to this dataset than the LSTM. Although the RNN similarly suffers from this type of “overfitting” to a lesser extent, this is most likely due to its inability to learn a reasonable solution on sequences of four balls to begin with. Hence, we conclude that the superior extrapolation capabilities of RNN-EM and R-NEM are inherent to their ability to factor a scene in terms of permutation-invariant object representations (see the right side of the right plot in Figure 3).


Further insight into the role of the attention mechanism can be gained by visualizing the attention coefficients, as is done in Figure 2. For each component k we draw the reconstruction of each context object i on top of the reconstruction of the focus object, colored according to the color of component i, whenever α_{k,i} is non-zero. These correspond to the colored balls (seen, for example, in time-steps 13, 14), which indicate whether component k took information about component i into account when computing its new state (recall (5)). It can be observed that the attention coefficient becomes non-zero whenever collision takes place, such that a colored ball lights up in the following time-steps. The attention mechanism learned by R-NEM thus assumes the role of the distance-based heuristic in previous work (Chang et al., 2016), matching our own intuitions of how this mechanism would best be utilized.

A quantitative evaluation of the attention mechanism is obtained by comparing R-NEM to a variant of itself that does not incorporate attention (R-NEM no att). Figure 3 shows that both methods perform equally well on the regular test set (4 balls), but that R-NEM no att performs worse at extrapolating from its learned knowledge (6-8 balls). A likely reason for this behavior is that the range of the sum in (5) changes with the number of objects K. Thus, when extrapolating to an environment with more balls the total sum may exceed previous boundaries and impede the learned dynamics.


Once a scene has been accurately modelled, R-NEM can approximately simulate its dynamics through recursive application of (4) for each θ_k.⁵ (⁵ Note that in this case the input to the neural network encoder in component k is computed from the network's own prediction ψ_k rather than the ground-truth next frame.) In Figure 4 we compare the simulation capabilities of R-NEM to RNN-EM and an RNN on the bouncing balls environment (see footnote 3). On the left it shows for R-NEM and an RNN a sequence with five normal steps followed by 10 simulation steps, as well as the ground-truth sequence. From the last frame in the sequence it can clearly be observed that R-NEM has managed to accurately simulate the environment. Each ball is approximately in the correct place, and the shape of each ball is preserved. The balls simulated by the RNN, on the other hand, deviate substantially from their ground-truth positions and their size has increased. In general we find that R-NEM produces mostly very accurate simulations, whereas the RNN consistently fails. Interestingly, we found that the cases in which R-NEM fails are frequently those for which a single component models more than one ball. The right side of Figure 4 summarizes the BCE loss for these same time-steps across the entire test set. Although this is a crude measure of simulation performance (since it does not take into account the identity of the balls), we still observe that R-NEM consistently outperforms RNN-EM and an RNN.
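The closed-loop simulation procedure can be sketched generically: condition on a few real frames, then feed the model's own predictions back as input. `model_step` below is a toy stand-in for one R-NEM update plus decoding, not the actual model:

```python
import numpy as np

def model_step(state, frame):
    """Toy stand-in for one recurrent update + decode (NOT the real model)."""
    new_state = 0.9 * state + 0.1 * frame   # dummy recurrent update
    prediction = new_state                  # dummy decoder
    return new_state, prediction

def rollout(frames, n_condition, n_simulate):
    state = np.zeros_like(frames[0])
    pred = frames[0]
    for t in range(n_condition):            # observe real frames
        state, pred = model_step(state, frames[t])
    preds = []
    for _ in range(n_simulate):             # then recurse on own predictions
        state, pred = model_step(state, pred)
        preds.append(pred)
    return np.stack(preds)

frames = np.ones((5, 4))                    # 5 conditioning "frames"
sim = rollout(frames, n_condition=5, n_simulate=10)
```

The same two-phase loop (observe, then recurse) is what produces the simulated steps shown in Figure 4.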

Figure 5: R-NEM applied to a sequence of bouncing balls with an invisible curtain. The ground truth sequence is displayed in the top row, followed by the prediction of R-NEM (middle) and the soft-assignments of pixels to components (bottom). R-NEM models objects, as well as its interactions, even when the object is completely occluded (step 36). Only a subset of the steps is shown.

Hidden Factors

Occlusion is abundant in the real world, and the ability to handle hidden factors is crucial for any physical reasoning system. We therefore evaluate the capability of R-NEM to handle occlusion using a variant of bouncing balls that contains an invisible “curtain.” Figure 5 shows that R-NEM accurately models the sequence and can maintain object states, even when confronted with occlusion (see footnote 3). For example, note that in step 36 the “blue” ball is completely occluded and is about to collide with the “orange” ball. In step 38 the ball is accurately predicted to re-appear at the bottom of the curtain (since collision took place) as opposed to the left side of the curtain. This demonstrates that R-NEM has a notion of object permanence and implies that it understands a scene on a level beyond pixels: it assigns persistence and identity to the objects.

In terms of test-set performance we find that R-NEM outperforms both an RNN and an LSTM, in terms of the BCE as well as the relational BCE.

Space Invaders

To test the performance of R-NEM in a visually more challenging environment, we train it on sequences of binarized images over 25 time-steps of game-play on Space Invaders from the Arcade Learning Environment (Bellemare et al., 2013).⁶ (⁶ Binarization ensures that the color group of the entities on the screen does not give away their grouping.) We use K = 4 and also feed the action of the agent to the interaction function. Figure 6 confirms that R-NEM is able to accurately model the environment, even though the visual complexity has increased. Notice that these visual scenes comprise a large number of (small) primitive objects that behave similarly. Since we trained R-NEM with four components it is unable to group pixels according to individual objects and is forced to consider a different grouping. We find that R-NEM frequently assigns different groups to every other column of the aliens (in some examples together with the spaceship), and to the three large “shields.” These groupings seem to be based on movement, which to some degree coincides with the semantic roles of the entities in the environment. Individual bullets and the spaceship are less frequently grouped separately, which may have to do with the action-noise of the environment (that controls the movement of the spaceship) and the small size of the bullets at the current resolution, which makes them less predictable.

Figure 6: R-NEM accurately models a sequence of frames obtained by an agent playing Space Invaders. A group no longer corresponds to an object, but instead assumes the role of high-level entities that engage in similar movement patterns.

5 Discussion and Conclusion

We have argued that the ability to discover and describe a scene in terms of objects provides an essential ingredient for common-sense physical reasoning. This is supported by converging evidence from cognitive science and developmental psychology that intuitive physics and reasoning capabilities are built upon the ability to perceive objects and their interactions (Ullman et al., 2017; Spelke, 1988). The fact that young infants already exhibit this ability, may even suggest an innate bias towards compositionality (Spelke & Kinzler, 2007; Lake et al., 2016; Munakata et al., 1997). Inspired by these observations we have proposed R-NEM, a method that incorporates inductive biases about the existence of objects and interactions, implemented by its clustering objective and interaction function respectively. The specific nature of the objects, and their dynamics and interactions can then be learned efficiently purely from visual observations.

In our experiments we find that R-NEM indeed captures the (physical) dynamics of various environments more accurately than other methods, and that it exhibits improved generalization to environments with different numbers of objects. It can be used as an approximate simulator of the environment, and to predict movement and collisions of objects, even when they are completely occluded. This demonstrates a notion of object permanence and aligns with evidence that young infants seem to infer that occluded objects move in connected paths and continue to maintain object-specific properties (Spelke, 1990). Moreover, young infants also appear to expect that objects only interact when they come into contact (Spelke, 1990), which is analogous to R-NEM attending to other objects only when a collision is imminent. In summary, we believe that our method presents an important step towards learning a more human-like model of the world in a completely unsupervised fashion.

Current limitations of our approach revolve around grouping and prediction. Which aspects of a scene humans group together typically varies with the task at hand. One may perceive a stack of chairs as a whole if the goal is to move them to another room, or as individual chairs if the goal is to count the number of chairs in the stack. In order to facilitate this dynamic grouping, one would need to incorporate top-down feedback from an agent into the grouping procedure to deviate from the built-in inductive biases. Another limitation of our approach is the need to incentivize R-NEM to produce useful groupings by injecting noise or reducing capacity. The former may prevent very small regularities in the input from being detected. Finally, the interaction among the groups in the E-step makes it difficult to increase the number of components beyond ten without causing harmful training instabilities. Due to the multitude of interactions and objectives in R-NEM (and RNN-EM), we find that they are sometimes challenging to train.

In terms of prediction, we have implicitly assumed that objects in the environment behave according to rules that can be inferred. This poses a challenge when objects deform in a manner that is difficult to predict (as is the case for objects in Space Invaders due to downsampling). In practice, however, we find that (once pixels have been grouped together) the masking of the input helps each component quickly adapt its representation to any unforeseen behaviour across consecutive time-steps. Perhaps a more severe limitation of R-NEM (and of RNN-EM in general) is that the second loss term of the outer training objective hinders the modelling of more complex, varying backgrounds, as the background group would have to predict the “pixel prior” for every other group.

We argue that the ability to engage in common-sense physical reasoning benefits any intelligent agent that needs to operate in a physical environment, which provides exciting future research opportunities. In future work we intend to investigate how top-down feedback from an agent could be incorporated in R-NEM to facilitate dynamic groupings, but also how the compositional representations produced by R-NEM can benefit a reinforcement learner, for example to learn a modular policy that easily generalizes to novel combinations of known objects. Other interactions between a controller C and a model of the world M (implemented by R-NEM) as posed in Schmidhuber (2015) constitute further research directions.


Acknowledgements

The authors wish to thank Tom Griffiths and the anonymous reviewers for helpful comments and constructive feedback. This research was supported by the Swiss National Science Foundation grant 200021_165675/1, the EU project “INPUT” (H2020-ICT-2015 grant no. 687795), and the Zeno Karl Schindler Foundation Summerschool Grant. Chang would like to thank Christiane Born, Sarah Craver, Cinzia Daldini, and the MIT MISTI Program for supporting his stay in Switzerland. We are grateful to NVIDIA Corporation for donating us a DGX-1 as part of the Pioneers of AI Research award, and to IBM for donating a “Minsky” machine.


Appendix A Experiment Details

In all experiments we train the networks using ADAM (Kingma & Ba, 2014) with default parameters, a batch size of 64, and train + validation + test inputs. The quality of the learned groupings is evaluated by computing the Adjusted Rand Index (ARI; Hubert & Arabie, 1985) with respect to the ground truth, while ignoring the background and overlap regions (consistent with earlier work; Greff et al., 2017). We apply early stopping when the validation loss has not improved for 10 epochs.
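As a concrete reference, the ARI evaluation can be sketched as follows. This is a minimal NumPy implementation of the Hubert & Arabie (1985) formula; the function name and the mask convention (True = excluded pixel) are our own.

```python
import numpy as np

def comb2(x):
    """Number of unordered pairs, n choose 2."""
    return x * (x - 1) / 2.0

def adjusted_rand_index(true_labels, pred_labels, ignore_mask=None):
    """ARI between two pixel-wise groupings; pixels where ignore_mask
    is True (e.g. background or overlap regions) are excluded."""
    t = np.asarray(true_labels).ravel()
    p = np.asarray(pred_labels).ravel()
    if ignore_mask is not None:
        keep = ~np.asarray(ignore_mask).ravel()
        t, p = t[keep], p[keep]
    # contingency table between the two labelings
    t_ids = np.unique(t, return_inverse=True)[1]
    p_ids = np.unique(p, return_inverse=True)[1]
    cont = np.zeros((t_ids.max() + 1, p_ids.max() + 1))
    np.add.at(cont, (t_ids, p_ids), 1)
    sum_ij = comb2(cont).sum()
    sum_a = comb2(cont.sum(axis=1)).sum()
    sum_b = comb2(cont.sum(axis=0)).sum()
    expected = sum_a * sum_b / comb2(t.size)
    max_index = 0.5 * (sum_a + sum_b)
    return (sum_ij - expected) / (max_index - expected)
```

An ARI of 1 indicates a perfect match up to a relabeling of the groups; chance-level groupings score around 0.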

a.1 Bouncing Balls

The bouncing balls data is similar to previous work (Sutskever et al., 2009) with a few modifications. The data consists of sequences of binary images over 30 time-steps, and balls are randomly sampled from two types: one ball is six times heavier and 1.25 times larger in radius than the other. The balls are initialized with random positions and velocities. Balls bounce elastically against each other and the image window.
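The wall-bounce dynamics can be sketched as follows. This is a minimal simulator of elastic bounces against the image window only; ball-ball collisions and the exact units are omitted, so it illustrates the setup rather than reproducing the actual data generator.

```python
import numpy as np

def step_balls(pos, vel, radius, dt=1.0, window=1.0):
    """Advance ball centres one Euler step; reflect position and
    velocity elastically when a ball would leave the [0, window]
    image window. Ball-ball collisions are omitted in this sketch."""
    pos = pos + vel * dt
    for axis in range(pos.shape[1]):
        low = pos[:, axis] < radius
        high = pos[:, axis] > window - radius
        vel[low | high, axis] *= -1.0                           # elastic bounce
        pos[low, axis] = 2 * radius - pos[low, axis]            # mirror at wall
        pos[high, axis] = 2 * (window - radius) - pos[high, axis]
    return pos, vel
```

`radius` may be a scalar or a per-ball array, which is how the two ball types (differing in radius and mass) could be represented.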

As in previous work (Greff et al., 2017) we use a convolutional encoder-decoder architecture with a recurrent neural network as a bottleneck, which is updated according to (4):

  1. conv. 16 ELU. stride 2. layer norm

  2. conv. 32 ELU. stride 2. layer norm

  3. conv. 64 ELU. stride 2. layer norm

  4. fully connected. 512 ELU. layer norm

  5. recurrent. 250 Sigmoid. layer norm on the output

  6. fully connected. 512 RELU. layer norm

  7. fully connected. RELU. layer norm

  8. reshape 2 nearest-neighbour, conv. 32 RELU. layer norm

  9. reshape 2 nearest-neighbour, conv. 16 RELU. layer norm

  10. reshape 2 nearest-neighbour, conv. 1 Sigmoid
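Assuming a 64x64 input and 'same' padding for the stride-2 convolutions (the kernel sizes and input resolution are not restated here), the spatial dimensions of the encoder-decoder above can be traced as follows:

```python
def conv_out(size, stride=2):
    """Spatial size after a stride-2 convolution with 'same' padding."""
    return (size + stride - 1) // stride

# encoder: three stride-2 convolutions on an assumed 64x64 input
enc_size = 64
for _ in range(3):
    enc_size = conv_out(enc_size)   # 64 -> 32 -> 16 -> 8

# decoder: three x2 nearest-neighbour reshapes restore the input size
dec_size = enc_size
for _ in range(3):
    dec_size *= 2                   # 8 -> 16 -> 32 -> 64
```

Under these assumptions the encoder produces 8x8 feature maps before the fully connected layers, and the three reshape-plus-convolution stages of the decoder restore the original resolution.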

Instead of using transposed convolutions (to implement the “de-convolution”), we first resize the image using nearest-neighbour interpolation and then apply a normal convolution, in order to avoid frequency artifacts (Odena et al., 2016). Note that we do not apply layer norm on the recurrent connection.
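The nearest-neighbour reshape can be sketched in NumPy. This illustrates only the upsampling step; in the model a learned convolution follows, which is what avoids the checkerboard artifacts of transposed convolutions.

```python
import numpy as np

def nn_upsample(x, factor=2):
    """Nearest-neighbour upsampling of an (H, W) feature map by
    repeating each pixel `factor` times along both spatial axes."""
    return np.repeat(np.repeat(x, factor, axis=0), factor, axis=1)
```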

At each timestep we feed as input to the network, where is the input with added bitflip noise (). Consistent with earlier work (Greff et al., 2017), R-NEM is trained with a next-step prediction objective, the prior for each pixel in the data is set to a Bernoulli distribution with , and we prevent conflicting gradient updates by not back-propagating any gradients through .
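The bitflip noise can be sketched as follows; the flip probability is a hyperparameter (its value is not restated here) and the function name is ours.

```python
import numpy as np

def bitflip_noise(x, p, rng=None):
    """Flip each binary pixel of x independently with probability p.
    For {0, 1} inputs, |x - flip| implements an elementwise XOR."""
    if rng is None:
        rng = np.random.default_rng(0)
    flips = (rng.random(x.shape) < p).astype(x.dtype)
    return np.abs(x - flips)
```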

The Interaction Function network is structured as follows:

  • : fully connected. 250 RELU. layer norm

  • : fully connected. 250 RELU. layer norm

  • : fully connected. 250 RELU. layer norm

  • : fully connected. 100 Tanh. layer norm - fully connected. 1 Sigmoid.

We experimented with deeper architectures, but were unable to observe significant improvement.
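To make the data flow concrete, a simplified version of the pairwise interaction with attention can be sketched in NumPy. Single linear layers stand in for the MLPs listed above, and the pairwise embedding doubles as the effect, so this illustrates the structure of the computation rather than the exact network.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def interaction(thetas, W_enc, W_emb, w_att):
    """Sketch of the interaction function: each component's state is
    encoded, paired with every other component, mapped to an effect,
    and the effects are summed weighted by a scalar attention value.
    thetas: (K, D) component states; W_enc: (D, H); W_emb: (2H, H);
    w_att: (H,). Returns a (K, H) total effect per component."""
    K = thetas.shape[0]
    enc = relu(thetas @ W_enc)                    # per-component encoding
    totals = np.zeros_like(enc)
    for k in range(K):
        for j in range(K):
            if j == k:
                continue
            pair = np.concatenate([enc[k], enc[j]])
            emb = relu(pair @ W_emb)              # pairwise embedding / effect
            att = sigmoid(emb @ w_att)            # scalar attention in (0, 1)
            totals[k] += att * emb                # attention-weighted sum
    return totals
```

The attention coefficient is what lets a component ignore other components until an interaction (e.g. a collision) becomes relevant.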

Comparison and Extrapolation

In the comparison experiment both R-NEM and RNN-EM are trained with (unless otherwise mentioned), following insights from Greff et al. (2017). On the extrapolation task we adjusted the number of components at test time to .

When comparing to RNN-EM we used . For comparing to RNN we set , and used , yielding a standard recurrent autoencoder that receives at each time-step the difference between the prediction and the noisy ground-truth as input. In the case of the LSTM, we additionally replace the recurrent layer with an LSTM update. The R-NEM no att model is the same as R-NEM, but without , such that


Simulation

Since the E-step relies on the ground-truth, which is not available during simulation, we instead used a thresholded version of at 0.1 (such that everything below becomes 0 and everything above becomes 1) as a replacement.


Occlusion

On the occlusion dataset we used three balls with equal mass. The curtain was spawned at a random location for each sequence. We trained R-NEM with .

a.2 Space Invaders

We used a pre-trained DQN to produce a dataset with sequences of 25 time-steps. The DQN receives a stack of four frames as input and we recorded every first frame of this stack. These frames were first pre-processed as in Mnih et al. (2013) and then thresholded at to obtain binary images.

Since the images are we used a different encoder and decoder, given by:

  1. conv. 16 ELU. stride 2. layer norm

  2. conv. 32 ELU. stride 2. layer norm

  3. conv. 32 ELU. stride 2. layer norm

  4. conv. 32 ELU. stride 2. layer norm

  5. fully connected. 512 ELU. layer norm

  6. recurrent. 250 Sigmoid. layer norm on the output

  7. fully connected. 512 RELU. layer norm

  8. fully connected. RELU. layer norm

  9. reshape 2 nearest-neighbour, conv. 32 RELU. layer norm

  10. reshape 2 nearest-neighbour, conv. 32 RELU. layer norm

  11. reshape 2 nearest-neighbour, conv. 16 RELU. layer norm

  12. reshape 2 nearest-neighbour, conv. 1 Sigmoid

We used the same architecture for , the only difference being that at each time-step we concatenate an embedding of the action produced by the agent to the hidden state. Here we use a single-layer MLP with 10 units and a ReLU activation function to compute this embedding.
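The action-conditioning step can be sketched as follows; the weight matrix and function names are illustrative, with only the 10-unit ReLU embedding taken from the text above.

```python
import numpy as np

def embed_action(action_id, num_actions, W):
    """One-hot encode the agent's action, then apply a single-layer
    ReLU MLP (10 units above) to obtain the action embedding.
    W: (num_actions, 10) weight matrix (illustrative)."""
    one_hot = np.zeros(num_actions)
    one_hot[action_id] = 1.0
    return np.maximum(one_hot @ W, 0.0)

def augment_hidden(h, action_emb):
    """Concatenate the action embedding to a component's hidden state."""
    return np.concatenate([h, action_emb])
```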

In the Atari experiment we trained with and reduced the input noise to 0.02, in order to preserve tiny elements such as bullets (that only occupy 1-2 pixels).