Numerous studies (xu1996infants; kellman1983perception; spelke1995spatiotemporal; baillargeon1985object; saxe2006perception) show that infants quickly develop an understanding of intuitive physics, objects, and relations in an unsupervised manner. To solve real-world problems, intelligent agents should be able to acquire such knowledge (van2019perspective). However, artificial neural networks are still far from a human-level understanding of intuitive physics.
Existing approaches to unsupervised learning about objects and relations from visual data can be categorized as either parallel (greff2016tagger; greff2017neural; greff2019multi) or sequential (schmidhuber91artificial; schmidhuber93ratioicann; eslami2016attend; kosiorek2018sequential; crawford2019spatially; burgess2019monet; yuan2019generative), depending on the core mechanism responsible for inferring object representations from a single image. One model from the former group is Tagger (greff2016tagger), which applies the Ladder Network (rasmus2015semi) to perform perceptual grouping. RTagger (premont2017recurrent) replaces the Ladder Network with a Recurrent Ladder Network, thus extending Tagger to sequential settings. NEM (greff2017neural) learns object representations using a spatial mixture model, and its relational version R-NEM (van2018relational) endows it with a parallel relational mechanism. The recently proposed IODINE (greff2019multi) iteratively refines inferred objects and handles multi-modal inputs.
On the other hand, the sequential attention model AIR (eslami2016attend) learns to infer one object per iteration over a given image. Contrary to NEM, it extracts object glimpses through a hard attention mechanism (schmidhuber91artificial) and processes only the corresponding glimpse. Furthermore, it builds a probabilistic representation of the scene to model uncertainty. Many recent models have AIR at their core: SQAIR (kosiorek2018sequential) extends AIR to sequential settings, as does DDPAE (hsieh2018learning). SPAIR (crawford2019spatially) scales AIR to scenarios with many objects, and SuPAIR (stelzner2019faster) improves the speed and robustness of learning in AIR. The recent MONet (burgess2019monet) also uses a VAE and a recurrent neural network (RNN) to decompose scenes into multiple objects. These methods usually model relations with a sequential relational mechanism such as an RNN, which limits their relational reasoning capabilities (battaglia2018relational).
Here we present Relational Sequential Attend, Infer, Repeat (R-SQAIR), which learns a generative model of intuitive physics from video data. R-SQAIR builds on SQAIR, which we augment with a mechanism that has a strong relational inductive bias (battaglia2016interaction; van2018relational; santoro2018relational). Our explicit parallel model of pairwise relations between objects is conceptually simpler than a sequential RNN-based model, which keeps previous interactions in its memory and cannot directly model the effects of interactions of previously considered objects. Our experiments demonstrate improved generalization performance of trained models in new environments.
2 Relational Sequential Attend, Infer, Repeat
Attend, Infer, Repeat (AIR) (eslami2016attend) is a generative model that explicitly reasons about objects in a scene. It frames the problem of representing the scene as probabilistic inference in a structured VAE. At the core of the model is an RNN that processes objects one at a time and infers latent variables $z^1, \dots, z^n$, where $n$ is the number of objects. The continuous latent variable $z_{what}^i$ encodes the appearance of the object in the scene and $z_{where}^i$ encodes the coordinates according to which the object glimpse is scaled and shifted by a Spatial Transformer (jaderberg2015spatial). Given an image $x$, the generative model of AIR is defined as follows:
$$p_\theta(x) = \sum_{n=1}^{N} p_\theta(n) \int p_\theta^z(z \mid n)\, p_\theta^x(x \mid z)\, dz,$$
where $n$ represents the number of objects present in the scene, $p_\theta^z$ captures the prior assumptions about the underlying object and $p_\theta^x$ defines how it is rendered in the image. In general, inference for this model is intractable, so (eslami2016attend) employs amortized variational inference using a sequential algorithm, where an RNN is run for $n$ steps to infer the latent representation of one object at a time. The variational posterior is then:
$$q_\phi(z \mid x) = q_\phi(z_{pres}^{n+1} = 0 \mid z^{1:n}, x) \prod_{i=1}^{n} q_\phi(z^i, z_{pres}^i = 1 \mid z^{1:i-1}, x),$$
where $z^i = (z_{what}^i, z_{where}^i)$ and $z_{pres}^i$ is a binary variable indicating the presence of object $i$.
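To make the generative process concrete, here is a minimal numpy sketch of ancestral sampling from an AIR-style model. The dimensionalities, the uniform prior over the number of objects, and the standard-normal priors are illustrative assumptions, and the decoding of latents into an image via a spatial transformer is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy dimensions -- hypothetical choices for illustration only.
Z_WHAT_DIM, Z_WHERE_DIM, MAX_OBJECTS = 5, 3, 4

def sample_scene():
    """Ancestral sampling from an AIR-style generative model (sketch).

    n         ~ p(n): number of objects in the scene (here uniform)
    z_what^i  ~ N(0, I): appearance code, decoded into an object glimpse
    z_where^i ~ N(0, I): scale/shift used by the spatial transformer
    """
    n = rng.integers(1, MAX_OBJECTS + 1)            # sample object count
    z_what = rng.standard_normal((n, Z_WHAT_DIM))   # appearance latents
    z_where = rng.standard_normal((n, Z_WHERE_DIM)) # pose latents
    # A real model would decode z_what into glimpses and compose them
    # into an image via a spatial transformer; we return only latents.
    return n, z_what, z_where

n, z_what, z_where = sample_scene()
```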
Relational Sequential Attend, Infer, Repeat (R-SQAIR) augments SQAIR with a parallel relational mechanism. SQAIR extends AIR to the sequential setting by leveraging the temporal consistency of objects using a state-space model. It has two phases: Discovery (DISC) and Propagation (PROP). PROP is active from the second frame in the sequence onward, propagating or forgetting objects from the previous frame. It does so by combining an RNN, which learns the temporal dynamics of each object, with the AIR core, which iterates over previously propagated objects (explaining away phenomena). The DISC phase uses the AIR core, conditioned on propagated objects, to discover newly appearing objects. For a full description of AIR and SQAIR we refer to previous work (eslami2016attend; kosiorek2018sequential).
R-SQAIR retains the strengths of its predecessors and improves their relational capabilities. More specifically, SQAIR relies on AIR's core RNN to model relations. However, an RNN has only a weak relational inductive bias (battaglia2018relational), as it must compute pairwise interactions between objects sequentially, iterating over them in a specific order. R-SQAIR, on the other hand, employs networks with a strong relational inductive bias, which can model arbitrary relations between objects in parallel. To construct conceptually simple yet powerful architectures that support combinatorial generalization, we use the following two modules: the Interaction Network (IN) (van2018relational) and the Relational Memory Core (RMC) (santoro2018relational).
The R-SQAIR generative model is built by extending the PROP module of SQAIR to include relations $R(z_{t-1})$, where $R$ is the relational module and $z_{t-1}$ are the object representations from the previous timestep; the propagation prior is thus conditioned on the relational term, $p^P_\theta\big(z^P_t \mid z_{t-1}, R(z_{t-1})\big)$.
The discovery prior samples latent variables for new objects that enter the frame, conditioning on the propagated variables. The propagation prior samples latent variables for objects that are propagated from the previous frame and removes those that disappear. Both priors are learned during training. Omitting the relational term recovers the original SQAIR model. The inference model is changed analogously:
in the inference model, $h^T_t$ and $h^A_t$ denote the hidden states of the temporal and AIR core RNNs. Discovery is essentially the posterior of AIR. Again, the difference to SQAIR lies in the propagation module, which receives the relations $R(z_{t-1})$ as an additional input.
Our first relational module is the Interaction Network (IN) variant of R-NEM (van2018relational), depicted in Figure 2, which is closely related to the original Interaction Networks (battaglia2016interaction; watters2017visual). Here, the effect $E_k$ on object $k$ of all other objects is computed by the relational module, which in the case of IN is defined as follows (for simplicity we drop time indices):
$$\hat{\theta}_k = \mathrm{MLP}^{enc}(\theta_k), \quad \xi_{k,i} = \mathrm{MLP}^{emb}([\hat{\theta}_k; \hat{\theta}_i]), \quad \alpha_{k,i} = \mathrm{MLP}^{att}(\xi_{k,i}), \quad E_k = \sum_{i \neq k} \alpha_{k,i} \cdot \mathrm{MLP}^{eff}(\xi_{k,i}),$$
where $\theta_k$ are object representations from the previous time step. First, each object $\theta_k$ is transformed using $\mathrm{MLP}^{enc}$ to obtain $\hat{\theta}_k$, which is equivalent to a node embedding operation in a graph neural network. Then each pair $(\hat{\theta}_k, \hat{\theta}_i)$ is processed by another MLP, $\mathrm{MLP}^{emb}$, which corresponds to a node-to-edge operation by encoding the interaction between object $k$ and object $i$ in the embedding $\xi_{k,i}$. Note that the computed embedding is directional. Finally, an edge-to-node operation is performed, where the effect $E_k$ on object $k$ is computed by summing the individual effects of all other objects on that particular object. Note that the sum is weighted by an attention coefficient $\alpha_{k,i}$, which allows each individual object to consider only particular interactions. This technique also yields better combinatorial generalization to a higher number of objects, as it controls the magnitude of the sum.
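The node-to-edge and edge-to-node operations described above can be sketched in a few lines of numpy. The single-matrix tanh "MLPs", the dimensionalities, and the sigmoid attention coefficient are stand-in assumptions for illustration, not the exact R-NEM parameterization.

```python
import numpy as np

rng = np.random.default_rng(0)
K, D, E = 4, 5, 5   # number of objects, object dim, embedding dim (illustrative)

def mlp(w, x):
    """Single-layer tanh stand-in for the MLPs in the text."""
    return np.tanh(x @ w)

W_enc = rng.standard_normal((D, E))      # node embedding weights
W_emb = rng.standard_normal((2 * E, E))  # node-to-edge weights
W_eff = rng.standard_normal((E, D))      # effect weights
W_att = rng.standard_normal((E, 1))      # attention weights

def interaction_effects(theta):
    """Effect E_k on each object k from all other objects (IN sketch)."""
    theta_hat = mlp(W_enc, theta)        # node embedding for every object
    E_out = np.zeros_like(theta)
    for k in range(K):
        for i in range(K):
            if i == k:
                continue
            # node-to-edge: directional embedding of the pair (k, i)
            xi = mlp(W_emb, np.concatenate([theta_hat[k], theta_hat[i]]))
            # attention coefficient gates this particular interaction
            alpha = 1.0 / (1.0 + np.exp(-(xi @ W_att)))
            # edge-to-node: accumulate the weighted effect on object k
            E_out[k] += alpha[0] * mlp(W_eff, xi)
    return E_out

effects = interaction_effects(rng.standard_normal((K, D)))
```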
Relational Memory Core
We compare the effects modeled by IN to those learned by a Relational Memory Core (RMC) (santoro2018relational). RMC (Figure 2) learns to compartmentalize objects into memory slots, and can keep the state of an object and combine this information with the current object's representation. This is achieved by borrowing ideas from memory-augmented networks (sukhbaatar2015end; graves2014neural; graves2016hybrid) and interpreting memory slots as object representations. The interactions between objects are then computed by a multi-head self-attention mechanism (vaswani2017attention). Finally, recurrence for sequential interactions is introduced, resulting in an architecture akin to a 2-dimensional LSTM (hochreiter1997long), where rows of the memory matrix represent objects. The model parameters are shared across objects, so the number of memory slots can be changed without affecting the total number of model parameters. For a full description, we refer to previous work (santoro2018relational).
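A minimal numpy sketch of one multi-head self-attention pass over the memory matrix follows. The slot and head sizes match the configuration reported in our experiments; the random projection matrices and the omission of RMC's gating and recurrent (LSTM-style) machinery are simplifying assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
SLOTS, HEADS, HEAD_DIM = 4, 4, 10   # memory slots; 4 heads of dimensionality 10
D = HEADS * HEAD_DIM                # slot dimensionality (assumption)

# Query/key/value projections (random stand-ins for learned weights).
Wq, Wk, Wv = (rng.standard_normal((D, D)) * 0.1 for _ in range(3))

def rmc_attention(memory):
    """One multi-head self-attention pass over memory slots (RMC sketch).

    Each row of `memory` is one object slot; attention lets every slot
    read from every other slot in parallel (one message-passing phase).
    """
    def split(x):  # (SLOTS, D) -> (HEADS, SLOTS, HEAD_DIM)
        return x.reshape(SLOTS, HEADS, HEAD_DIM).transpose(1, 0, 2)

    q, k, v = split(memory @ Wq), split(memory @ Wk), split(memory @ Wv)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(HEAD_DIM)  # (HEADS, SLOTS, SLOTS)
    weights = np.exp(scores) / np.exp(scores).sum(-1, keepdims=True)  # softmax
    out = weights @ v                                      # (HEADS, SLOTS, HEAD_DIM)
    return out.transpose(1, 0, 2).reshape(SLOTS, D)        # merge heads

memory = rng.standard_normal((SLOTS, D))
updated = rmc_attention(memory)
```

Because the projection weights are shared across rows, adding or removing slots changes neither the code nor the parameter count, mirroring the property noted above.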
We analyze the physical reasoning capabilities of R-SQAIR on the bouncing balls dataset, which consists of video sequences of 64x64 images. As in the SQAIR experiments, we crop the central 50x50 pixels from each image, such that a ball can disappear and later re-appear. Although visually simple, this dataset contains highly complex physical dynamics and has been previously used for similar studies (R-NEM (van2018relational)). The model is trained in SQAIR-like fashion by maximizing the importance-weighted evidence lower bound (IWAE) (burda2015importance). Curriculum learning starts at sequence length 3, which is increased by one every 10000 iterations, up to a maximum length of 10. Early stopping is performed when the validation score has not improved for 10 epochs.
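The curriculum schedule described above is simple enough to state as code (the function name and signature are ours):

```python
def curriculum_length(iteration, start=3, step=10_000, max_len=10):
    """Training sequence length: starts at 3, grows by 1 every
    10000 iterations, and is capped at 10."""
    return min(start + iteration // step, max_len)

# Length 3 at the start, 4 after 10000 iterations, capped at 10 thereafter.
assert curriculum_length(0) == 3
assert curriculum_length(10_000) == 4
assert curriculum_length(200_000) == 10
```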
Qualitative evaluation of R-SQAIR is presented in Figure 3. Each column represents one time step in the video. The first row shows the R-SQAIR model trained and evaluated on videos with 4 balls, with object representations highlighted by bounding boxes of different colors. In the second row, the same model is evaluated on datasets with 6-8 balls. Note that R-SQAIR disentangles objects already in the first few frames and later only refines the learned representations. At each time step, it computes up to a fixed maximum number of object representations, by considering objects from the previous frame and the learned dynamics.
For all SQAIR hyperparameters we use default values, except for the dimensionality of the object latent variable, which is set to 5 instead of 50. This reflects the low visual complexity of individual objects in the scene. For similar reasons, the embedding dimensionality of IN is also set to 5. We use the version of the IN module with attention coefficients to compute the weighted sum of the effects. In total, this adds 9 389 parameters to the 2 726 166 of the default SQAIR implementation. This suggests that the improved performance is a result of learning a better propagation prior rather than simply increasing the number of model parameters.
RMC has more hyperparameters to choose from. We use self-attention with 4 heads, each of dimensionality 10. The number of memory slots is 4 and coincides with the total number of sequential attention steps we perform. Finally, RMC can perform several attention computations per time step, each corresponding to one message-passing phase. As we are interested only in collisions, we compute attention once per time step. This results in 98 880 additional parameters. Comparing with the size of the SQAIR model, we reach a conclusion similar to that for IN.
Note that the last frames in Figure 3 are sampled from the learned propagation prior. This enables us to evaluate the role of the relational module, as it is responsible for learning the object dynamics. Moreover, as the models are stochastic, we train 5 models for each architecture and sample 5 different last frames. We compare models in terms of data log-likelihood and relational log-likelihood, which takes into account only the objects that are currently colliding (ground truth is available in the dataset). The evaluation on the test set with 4 balls shows an increase in average data log-likelihood from 399.5 achieved by SQAIR (relational: 0.21) to 429.2 by R-SQAIR(IN) (relational: 1.95) and 457.32 by R-SQAIR(RMC) (relational: 3.62). Error bars in Figure 4 represent the standard deviation of the stochastic samples from the trained models.
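The relational log-likelihood above can be computed by masking the per-object log-likelihoods with the ground-truth collision labels; a minimal sketch (function and argument names are ours):

```python
import numpy as np

def relational_log_likelihood(log_lik_per_object, colliding_mask):
    """Sum per-object log-likelihoods over currently colliding objects only.

    `log_lik_per_object`: log-likelihood contribution of each object.
    `colliding_mask`: 1 where the dataset's ground truth marks a collision,
    0 elsewhere (a simplifying assumption about the label format).
    """
    log_lik = np.asarray(log_lik_per_object, dtype=float)
    mask = np.asarray(colliding_mask, dtype=float)
    return float(np.sum(log_lik * mask))

# Only the two colliding objects contribute: 2.0 + 3.0 = 5.0
score = relational_log_likelihood([1.0, 2.0, 3.0], [0, 1, 1])
```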
We test generalization of R-SQAIR by evaluating the models trained on sequences with 4 balls on a test set with videos of 6-8 balls. Both qualitative (Figure 3, bottom row) and quantitative results show that R-SQAIR is capable of generalizing, with an increase in relational log-likelihood from -164.1 achieved by SQAIR to -96.7 achieved by R-SQAIR(IN) and -97 achieved by R-SQAIR(RMC). The larger margins between the relational losses of R-SQAIR and SQAIR on the test set with 6-8 balls suggest better generalization capabilities of R-SQAIR.
Graph neural networks are promising candidates for combinatorial generalization, a central theme of AI research (battaglia2018relational; van2019perspective). We show that a sequential attention model can benefit from incorporating an explicit relational module which infers pairwise object interactions in parallel. Without retraining, the model generalizes to scenarios with more objects. Its learned generative model is potentially useful as part of a world simulator (schmidhuber90sandiego; schmidhuber90diffgenau; ha2018recurrent; watters2019cobra).
We would like to thank Adam R. Kosiorek, Hyunjik Kim, Ingmar Posner and Yee Whye Teh for making the codebase for the SQAIR model kosiorek2018sequential publicly available. This work was made possible by their commitment to open research practices. We thank Sjoerd van Steenkiste for helpful comments and fruitful discussions. This research was supported by the Swiss National Science Foundation grant 407540_167278 EVAC - Employing Video Analytics for Crisis Management. We are grateful to NVIDIA Corporation for a DGX-1 as part of the Pioneers of AI Research award, and to IBM for donating a “Minsky” machine.