Symmetry has proved to be a powerful inductive bias for improving generalization in supervised and unsupervised learning. A symmetry group defines equivalence classes of inputs and outputs in terms of transformations that can be performed on the input along with corresponding transformations for the output. Recent years have seen many proposed equivariant models that incorporate symmetries into deep neural networks (cohen2016group; cohen2016steerable; cohen2019gauge; weiler2019e2cnn; weiler2018learning; kondor2018generalization; bao2019equivariant; worrall2017harmonic). This results in models that are often more parameter efficient, more sample efficient, and safer to use because they behave more consistently in new environments.
However, the applicability of equivariant models is impeded in that it is not always obvious how a symmetry group acts on input data. For example, consider the pairs of images in Figure 1. On the left, we have MNIST digits, where a 2D rotation in pixel space induces a corresponding rotation in feature space. Here a rotation-equivariant network achieves state-of-the-art accuracy (weiler2019e2cnn). In contrast, exploiting the underlying symmetry is challenging for the images on the right, which show the same object in two orientations. Although there is also an underlying symmetry group of rotations, it is not easy to characterize the transformation in pixel space associated with a particular rotation.
In this paper, we consider the task of learning symmetric representations of data in domains where transformations cannot be hard-coded, i.e. the group action is unknown. We train a network that maps an input space, for which the group action is difficult to characterize, onto a latent space, where the action is known. We refer to this network as a symmetric embedding network (SEN). Our goal is to learn an SEN that is equivariant: for any pair of inputs related by a transformation in input space, the outputs should be related by a corresponding transformation in feature space.
Learning the group action from data requires supervision or inductive biases. In certain domains, we can learn an SEN in a supervised manner by pairing it with an equivariant classifier. We demonstrate the feasibility of this approach in Section 3. However, our main interest is learning SENs in domains where direct supervision is not available. As a concrete instantiation of this setting, we focus on world models, i.e. models that encode the effects of actions in the state space of a Markov decision process (MDP). We propose a meta-architecture that pairs an SEN with an equivariant transition network; the two are trained jointly by minimizing a contrastive loss. The intuition behind this approach is that the symmetry group of the transition model can help induce an approximately equivariant embedding network.
To our knowledge, the idea of training an SEN (a standard, unconstrained network) jointly with an equivariant task network has not previously been proposed or demonstrated. To test this idea, we consider 5 domains with 3 different symmetry groups and 3 different equivariant architectures (see Table 1 for details). While not the main contribution of this paper, these domains do require innovations in architecture design. Most notably, we combine message passing neural networks (MPNNs) with group convolutions in domains with multiple objects, resulting in a novel architecture; a similar architecture has been concurrently proposed by brandstetter2021geometric. However, our primary contribution is to demonstrate that SENs can extend the applicability of equivariant networks to new domains with unknown group actions.
We summarize our contributions as follows:
We propose SENs that map from input data, for which symmetries are difficult to characterize, to a feature space with a known symmetry. This network implicitly learns the group action in the input space.
We show proof-of-concept results for supervised learning of SENs on sequence labeling.
Using world models as a test case, we show SENs can be trained end-to-end by minimizing a contrastive loss. We develop 5 domains with 3 different symmetry groups using architectures that are representative of (and improve upon) the state of the art in equivariant deep learning.
Our experiments show that world models with SENs can make equivariant architectures applicable to previously inaccessible domains and can yield improvements in accuracy and generalization performance.
2 Related Work
Equivariant Neural Networks
A multitude of equivariant neural networks have been devised to impose symmetry with respect to various groups across a variety of data types. These require that the group is known and that the group action on the input, output, and hidden spaces is explicitly constructed. Examples include group convolution (cohen2016group), steerable convolution (cohen2016steerable; weiler2019e2cnn), tensor product and Clebsch-Gordan decomposition (thomas2018tensor), and convolution in the Fourier domain (esteves2017polar). These models have been applied to gridded data (weiler2019e2cnn), spherical data (cohen2018spherical), point clouds (dym2020universality), and sets (maron2020learning). They have found applications in many domains, including molecular dynamics (anderson2019cormorant), particle physics (bogatskiy2020lorentz), and trajectory prediction (walters2020trajectory). In particular, ravindran2004algebraic consider symmetry in the context of Markov decision processes (MDPs), and van2020mdp construct equivariant policy networks for policy learning. Our work also considers MDPs with symmetry but focuses on learning equivariant world models (see Appendix B).
Our work occupies a middle ground between equivariant neural networks with known group actions and symmetry discovery models. Symmetry discovery methods attempt to learn both the group and actions from data. For example, zhou2020meta learn equivariance by learning a parameter sharing scheme using meta-learning. dehmamy2021automatic similarly learn a basis for a Lie algebra generating a symmetry group while simultaneously learning parameters for the equivariant convolution over this symmetry group. benton2020learning propose an adaptive data augmentation scheme, where they learn which group of spatial transformations best supports data augmentation.
higgins2018definition define disentangled representations based on symmetry, with latent factors considered disentangled if they are independently transformed by commuting subgroups. Within this definition, quessard learn the underlying symmetry group by interacting with the environment, where the action space is a group of symmetry transformations. Except for the 3D Teapot domain, we handle the more general case where the action space may be different from the symmetry group. Their latent transition is given by multiplication with a group element, whereas our transition model may be an equivariant neural network.
Structured Latent World Models
World models learn state representations by ignoring information that is unrelated to predicting environment dynamics. Such models are frequently used for high-dimensional image inputs, and usually employ either (1) a reconstruction loss (ha2018world; NIPS2015_a1afc58c; hafner2019learning; hafner2020mastering) or (2) a contrastive loss. Minimizing a contrastive loss is known to be less computationally costly than minimizing a reconstruction loss and can produce good representations from high-dimensional inputs (oord2018representation; anand2019unsupervised; chen2020simple; srinivas2020curl; van2020plannable), so we use it here. We take inspiration from kipf20, who learn object-factored representations for structured world modeling with GNNs, which respect permutation symmetry.
3 Illustrative Example: Sequence Labeling
Our goal is to use an equivariant task network as an inductive bias for learning an SEN. While the SEN is itself not equivariant by construction, we may be able to learn an equivariant SEN by training both networks end-to-end. To validate this approach, we first consider a simple supervised learning problem in the form of a sequence labeling task.
In this simulated task, we detect local maxima in time series. Our training data (Figure 11) comprise sine waves sampled at a fixed number of points, where each time series is shifted by a random offset. For each time point, there is a binary label indicating whether the point is a local maximum.
This domain has a clear translational equivariance with a known action: shifting the input by some number of time points shifts the predictions by the same number of time points. For this reason, 1D convolutions are commonly used in sequence labeling (pmlr-v32-santos14; ma-hovy-2016-end).
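The labeling task and its shift equivariance can be sketched in a few lines of Python. This is a minimal sketch under stated assumptions, not the data generator used in the paper. The series length (64), the number of periods (4), the integer offsets, and the cyclic boundary handling are all hypothetical choices, as are the helper names `sine_wave`, `local_max_labels`, and `roll`.

```python
import math
import random


def sine_wave(n=64, cycles=4, offset=0):
    # Hypothetical generator: `cycles` full periods over n points, shifted by `offset` samples.
    return [math.sin(2 * math.pi * cycles * (t + offset) / n) for t in range(n)]


def local_max_labels(x):
    # Label 1 where x[t] is strictly larger than both neighbours (cyclic boundary), else 0.
    n = len(x)
    return [int(x[t] > x[t - 1] and x[t] > x[(t + 1) % n]) for t in range(n)]


def roll(seq, k):
    # Cyclic shift to the right by k positions: roll(seq, k)[t] == seq[t - k].
    k %= len(seq)
    return seq[-k:] + seq[:-k]


rng = random.Random(0)
x = sine_wave(offset=rng.randrange(64))
y = local_max_labels(x)
print(sum(y))  # 4: one maximum per period
# Shifting the input by k points shifts the labels by the same k points.
for k in range(len(x)):
    assert local_max_labels(roll(x, k)) == roll(y, k)
```

With cyclic boundaries, the shift is an exact group action on the sequence. With zero-padded boundaries, equivariance would only hold away from the edges.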
To test whether we can learn the group action, we compose a fully-connected (FC) layer, which acts as our non-equivariant SEN, with two translation-equivariant 1D convolutional layers. We compare this network against a non-equivariant network with three FC layers. Both networks use ReLU activations. Each convolutional layer has a single kernel, which makes the visualizations easier to interpret.
Figure 2 shows the weights of the first FC layer in both networks after end-to-end supervised training. The learned weights in the FC+Conv model exhibit an approximately circulant structure (Fig. 2a), i.e. each column is a shifted copy of the preceding column. This is in excellent agreement with the idealized form that we would expect for a perfectly equivariant layer (Fig. 2c). By contrast, the weights in the non-equivariant model (Fig. 2b) do not exhibit the same structure.
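The link between circulant structure and translation equivariance can be checked directly. The sketch below uses small integer matrices in place of learned weights, and the helper names are hypothetical. A matrix with entries W[i][j] = c[(i - j) mod n] commutes with every cyclic shift, while a generic unconstrained matrix does not.

```python
import random


def circulant(c):
    # W[i][j] = c[(i - j) % n]: column j is column 0 cyclically shifted down by j.
    n = len(c)
    return [[c[(i - j) % n] for j in range(n)] for i in range(n)]


def matvec(w, x):
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def roll(seq, k):
    # Cyclic shift to the right by k positions.
    k %= len(seq)
    return seq[-k:] + seq[:-k]


def column(w, j):
    return [row[j] for row in w]


def commutes_with_shift(w):
    # Check w @ roll(x, k) == roll(w @ x, k) for every shift k on one integer probe vector.
    # The cyclic shifts of this probe span R^n, so passing for all k implies equivariance
    # for every input, not just the probe.
    x = list(range(1, len(w) + 1))
    return all(matvec(w, roll(x, k)) == roll(matvec(w, x), k) for k in range(len(w)))


rng = random.Random(0)
n = 6
w_circ = circulant([rng.randint(-3, 3) for _ in range(n)])
# Each column is the previous column shifted down by one position.
assert all(column(w_circ, j) == roll(column(w_circ, j - 1), 1) for j in range(1, n))
print(commutes_with_shift(w_circ))  # True
w_free = [[rng.randint(-3, 3) for _ in range(n)] for _ in range(n)]
print(commutes_with_shift(w_free))  # almost surely False for a random matrix
```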
4 Symmetric Embeddings for World Models
The supervised learning results on toy data are encouraging: an equivariant task network can indeed induce an approximately equivariant SEN, which implicitly learns the group action in the input space. To demonstrate the feasibility of learning SENs in more challenging domains, we consider world models. These models are an excellent use case for equivariant representation learning. Interactions with the physical world often exhibit symmetries, such as permutation equivariance (when interacting with multiple objects), or rotational and translational symmetries (when interacting with individual objects). Incorporating these symmetries can aid generalization across the combinatorial explosion of possible object arrangements, which grows exponentially with the number of objects in a scene.
World models are also a good test bed for learning SENs in that they allow us to control the difficulty of the learning problem. In an equivariant world model there are three interrelated notions of “action”: (1) the MDP action in the world model, (2) the learned action of the symmetry group on the input space, and (3) the known action of the symmetry group in the latent space. In certain domains there is a direct correspondence between these notions, such as when MDP actions perform rotations on a single object. In other domains the correspondence is more indirect, such as when MDP actions apply forces to joints in a robot arm. The MDP action in a world model thus provides a form of “distant” supervision that relates either directly or indirectly to the underlying symmetry.
To establish notation, we first define the difference between an abstract symmetry group with known action and one with unknown action. We then define a meta-architecture for contrastive training of SENs and equivariant world models, and discuss implementations of this architecture for specific domains with different underlying symmetries.
A group $G$ is a set with a binary operation that satisfies associativity, $(gh)k = g(hk)$, existence of an identity, $eg = ge = g$, and existence of an inverse, $gg^{-1} = g^{-1}g = e$. An action of a symmetry group associates a transformation with each $g \in G$. We define an action of $G$ on a set $X$ as a map $\rho: G \times X \to X$ that is compatible with composition of group elements, which is to say that $\rho(g, \rho(h, x)) = \rho(gh, x)$ for all $g, h \in G$ and $x \in X$.
If $X$ is a vector space and the map $\rho(g, \cdot)$ on $X$ is linear, then we say that $\rho$ is a group representation. This representation associates an $n \times n$ matrix with each $g \in G$, which we denote $\rho(g)$. The same group may have different actions on different sets. For example, a cyclic group of rotations has a simple representation by $2 \times 2$ rotation matrices on vectors in $\mathbb{R}^2$, but a more complicated action by much larger matrices on the pixels of MNIST images.
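The point that one group can act in different ways can be made concrete with the cyclic group C4, rotations by multiples of 90 degrees. This is a sketch with hypothetical helper names. Quarter turns are chosen here, as an assumption for illustration, so that the image action is an exact permutation of pixels and needs no interpolation. The same group element acts on R^2 by a 2x2 matrix and on an H x W image by rearranging its H·W pixels. Both actions satisfy the compatibility condition with composition.

```python
def rot2(k):
    # Representation of C4 on R^2: 2x2 integer matrix for a rotation by k * 90 degrees.
    c, s = [(1, 0), (0, 1), (-1, 0), (0, -1)][k % 4]
    return [[c, -s], [s, c]]


def matmul2(a, b):
    return [[sum(a[i][m] * b[m][j] for m in range(2)) for j in range(2)] for i in range(2)]


def rot_image(k, img):
    # Action of C4 on an image given as a list of rows: k quarter turns counterclockwise.
    for _ in range(k % 4):
        img = [list(row) for row in zip(*img)][::-1]
    return img


img = [[1, 2, 3],
       [4, 5, 6]]
for g in range(4):
    for h in range(4):
        # Compatibility with composition: acting by h and then by g equals acting by g + h (mod 4).
        assert matmul2(rot2(g), rot2(h)) == rot2(g + h)
        assert rot_image(g, rot_image(h, img)) == rot_image(g + h, img)
print(rot_image(1, img))  # [[3, 6], [2, 5], [1, 4]]
```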
Equivariant Networks and Equivariance Learning
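Given a group $G$ with representations $\rho_X$ and $\rho_Y$ acting on $X$ and $Y$, we say that a function $f: X \to Y$ is equivariant if, for all $g \in G$ and $x \in X$,
$$f(\rho_X(g)\,x) = \rho_Y(g)\,f(x).$$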
This means that group transformations commute with function application: transforming the input before applying $f$ is equivalent to transforming the output after applying $f$. The mapping $f$ thus preserves the symmetry group but alters the way in which the group acts.
Equivariant neural networks define parametric families of equivariant functions by composing layers that are individually equivariant with respect to the same group. To ensure that equivariance is satisfied by construction for any choice of network weights, these networks require that we explicitly know both $\rho_X$ and $\rho_Y$. An example would be classification of rotated MNIST digits, as in Figure 1.
In this paper, we are interested in cases where we have a known output action $\rho_Y$, but the input action $\rho_X$ is not known, as with the images of rotated cars in Figure 1. In this setting, we are interested in learning equivariance using an unconstrained network $f$, which we refer to as a symmetric embedding network (SEN). Given a triple $(g, x, x')$ such that $x' = \rho_X(g)\,x$, this network should satisfy
$$f(x') \approx \rho_Y(g)\,f(x).$$
In other words, our goal is to learn a network that is not equivariant by construction, but is as equivariant as possible.
Equivariant World Models
World models define a transition function $T: S \times A \to S$ on a state space $S$ and action space $A$, which outputs the next state $s' = T(s, a)$ given the current state $s$ and action $a$. In an equivariant world model, we assume a symmetry group $G$ that jointly transforms states and actions by representations $\rho_S$ and $\rho_A$, respectively. For example, in the 2D Shapes domain shown in Table 1, a rotation of the scene moves the blocks and permutes the actions. The transition function is equivariant with respect to $G$ in the sense that
$$T(\rho_S(g)\,s,\ \rho_A(g)\,a) = \rho_S(g)\,T(s, a) \quad \text{for all } g \in G,\ s \in S,\ a \in A.$$
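As an illustration of this equivariance condition, the sketch below uses a hypothetical single-object grid world. It is not one of the paper's environments: there is one object, no blocking, and the helper names are illustrative. Rotating the grid by a quarter turn rotates the position and cyclically permutes the four move actions. The check confirms T(ρ_S(g)s, ρ_A(g)a) = ρ_S(g)T(s, a) for every state, action, and rotation.

```python
# Hypothetical single-object grid world on the square [-M, M]^2 (not one of the paper's environments).
M = 2
MOVES = [(1, 0), (0, 1), (-1, 0), (0, -1)]  # right, up, left, down


def rotate(p, k):
    # rho_S: rotate a position by k * 90 degrees counterclockwise about the grid centre.
    x, y = p
    for _ in range(k % 4):
        x, y = -y, x
    return (x, y)


def rho_a(a, k):
    # rho_A: rotating the scene by k quarter turns cyclically permutes the four move actions.
    return (a + k) % 4


def clip(v):
    return max(-M, min(M, v))


def transition(s, a):
    # T(s, a): move one cell in direction a, staying inside the grid.
    dx, dy = MOVES[a]
    return (clip(s[0] + dx), clip(s[1] + dy))


states = [(x, y) for x in range(-M, M + 1) for y in range(-M, M + 1)]
ok = all(
    transition(rotate(s, k), rho_a(a, k)) == rotate(transition(s, a), k)
    for s in states for a in range(4) for k in range(4)
)
print(ok)  # True
```

Clipping to a square grid centred at the origin commutes with quarter-turn rotations. This is why the boundary does not break equivariance in this toy.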
As with other equivariant approaches, recent work on equivariant world models has required that both $\rho_S$ and $\rho_A$ are known (van2020mdp).
4.2 Meta-Architecture and Contrastive Loss
In this paper, we use SENs to define approximately equivariant world models that can be trained without access to $\rho_S$. To do so, we define a meta-architecture that combines a symmetric embedding network with an equivariant world model, as illustrated in Figure 3. This architecture comprises three domain-dependent components:
1. A symmetric embedding network, which maps states in pixel space to an intermediate space on which an explicit symmetry group action is known.
2. An equivariant encoder, which extracts the subset of features that are necessary to predict dynamics in a latent space with a known group action.
3. An equivariant transition model, which serves as an inductive bias by defining dynamics that satisfy the relation in Equation 4 with respect to the known group representations on the latent space and on the action space.
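We employ the self-supervised contrastive loss introduced by kipf20 for training. We assume access to an offline dataset of triplets $(s_t, a_t, s_{t+1})$ consisting of the current state $s_t$, the action $a_t$, and the next state $s_{t+1}$. We combine each ground-truth transition triplet with a negative state $\tilde{s}$, which is randomly sampled from the triplets within the minibatch. The contrastive loss is a hinge loss with two terms. The first is the distance between the next latent state predicted by the transition model and the latent encoding of $s_{t+1}$. The second is a hinge term that penalizes the latent encoding of $\tilde{s}$ for lying within a fixed margin of the latent encoding of $s_{t+1}$. All latent encodings are obtained by applying the symmetric embedding network followed by the equivariant encoder. Minimizing this loss pulls the predicted next latent state towards the encoding of the true next state and pushes the encoding of the true next state away from that of the negative sample.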
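A minimal sketch of a hinge-style contrastive loss in the spirit of kipf20 follows. The squared Euclidean distance, the margin of 1.0, and the helper names are assumptions of this sketch, not details confirmed by the text. Latent vectors are plain Python lists standing in for the outputs of the SEN, the encoder, and the transition model.

```python
import random


def sq_dist(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))


def contrastive_loss(z_pred, z_next, z_neg, margin=1.0):
    # Batch mean of d(z_pred, z_next) + max(0, margin - d(z_neg, z_next)),
    # with d the squared Euclidean distance (an assumption of this sketch).
    total = 0.0
    for p, q, n in zip(z_pred, z_next, z_neg):
        total += sq_dist(p, q) + max(0.0, margin - sq_dist(n, q))
    return total / len(z_pred)


def sample_negatives(z_batch, rng):
    # Negatives are encodings of states from randomly chosen triplets in the same minibatch.
    # A random permutation may occasionally pair a sample with itself; this sketch ignores that.
    idx = list(range(len(z_batch)))
    rng.shuffle(idx)
    return [z_batch[i] for i in idx]


rng = random.Random(0)
z_next = [[0.0, 0.0], [3.0, 0.0], [0.0, 3.0]]
z_pred = [[0.1, 0.0], [3.0, 0.2], [0.0, 2.9]]  # hypothetical transition-model outputs
z_neg = sample_negatives(z_next, rng)
print(contrastive_loss(z_pred, z_next, z_neg))
```

The loss is zero only when every prediction matches its next state and every negative lies at least the margin away. A negative that coincides with the next state contributes exactly the margin.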
4.3 Environments and Architectures
|Environment||2D Shapes & 3D Blocks||Rush Hour||Reacher||3D Teapot|
|2-layer CNN||7-layer CNN||4 conv, 3 FC layers|
|Equ. Encoder||MLP + -conv||MLP + -conv||
|MPNN + -conv||
We consider 5 environments with varying symmetries. Table 1 shows an overview of symmetries, representation types, and model architectures. The first two environments, 2D Shapes and 3D Blocks, are grid worlds with 5 moving objects (kipf20). Rush Hour is a variant of 2D Shapes in which objects move relative to their orientation. In these 3 domains, we consider symmetry to rotations and object permutations. The fourth domain is a continuous control domain, the Reacher-v2 MuJoCo environment (todorov2012mujoco), which is symmetric under rotations, flips, and translations. The last domain is a 3D teapot, where actions are 3D rotations in the group $SO(3)$. All environments use images as observed states.
Transition Model
The transition model defines the main inductive bias. In 2D Shapes, 3D Blocks, and Rush Hour, we use a message-passing neural network. This defines an object-factored representation that is equivariant to permutations and models pairwise interactions (i.e. the movement of one object can be blocked by another object). We extend the architecture proposed by kipf20 to incorporate rotational symmetry using group convolutions, resulting in a network similar to the one that has concurrently been proposed by brandstetter2021geometric.
The Reacher and 3D Teapot environments do not use model components that consider permutations. In Reacher, the transition model is an equivariant MLP (EMLP) built from 1x1 convolutions in the $E(2)$-CNN framework (weiler2019e2cnn).
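For 3D Teapot, the action space $A$, the symmetry group $G$, and the latent space $Z$ are all $SO(3)$. Since $SO(3)$ is not a vector space, the action on $Z$ is a non-linear group action. Semantically, the MDP action is a rotation matrix and the latent state is a positively oriented orthonormal coordinate frame. Though $A = Z = SO(3)$, these interpretations lead to differing actions, with $\rho_Z(g)\,z = gz$ on latent states but $\rho_A(g)\,a = g a g^{-1}$ on MDP actions (see Figure 4 for an illustration). If the symmetric embedding network is correctly learned, then the transition model can be implemented as a matrix multiplication $T(z, a) = az$, which is equivariant:
$$T(\rho_Z(g)\,z,\ \rho_A(g)\,a) = g a g^{-1} g z = g a z = \rho_Z(g)\,T(z, a).$$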
This method, which we label MatMul, is similar to the one in quessard, except that in our framework the ground-truth MDP actions, given as rotation matrices, are provided to aid learning the symmetric embedding network.
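The equivariance of a matrix-multiplication transition on SO(3) can be verified numerically. This sketch assumes that the group acts on latent frames by left multiplication (z ↦ gz) and on MDP actions by conjugation (a ↦ g a g^{-1}). The helper names and the particular angles are hypothetical. The last line shows that letting the group act on actions by plain left multiplication instead would break equivariance.

```python
import math


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def transpose(a):
    return [list(row) for row in zip(*a)]


def rot_x(t):
    c, s = math.cos(t), math.sin(t)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]


def rot_z(t):
    c, s = math.cos(t), math.sin(t)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def matmul_transition(z, a):
    # MatMul transition: apply the action rotation a to the latent frame z.
    return matmul(a, z)


def rho_z(g, z):
    # Assumed action on latent frames: left multiplication, g z.
    return matmul(g, z)


def rho_a(g, a):
    # Assumed action on MDP actions: conjugation g a g^{-1}, using g^{-1} = g^T for rotations.
    return matmul(matmul(g, a), transpose(g))


def max_abs_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


z = matmul(rot_z(0.3), rot_x(1.1))   # latent pose (frame)
a = rot_x(0.7)                       # MDP action (rotation)
g = matmul(rot_z(2.0), rot_x(-0.4))  # symmetry transformation
lhs = matmul_transition(rho_z(g, z), rho_a(g, a))
rhs = rho_z(g, matmul_transition(z, a))
print(max_abs_diff(lhs, rhs) < 1e-12)  # True: equivariant up to round-off
# Acting on actions by left multiplication instead breaks equivariance.
print(max_abs_diff(matmul_transition(rho_z(g, z), matmul(g, a)), rhs) > 1e-3)  # True
```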
Equivariant encoder
The encoder in the object-centric environments is shared over all 5 objects and uses group convolution over the rotation group (cohen2016group), thus achieving rotation equivariance. In the Reacher environment, we combine 3 equivariant layers with a 3-layer EMLP. For 3D Teapot, no encoder is needed.
Symmetric Embedding Network
The SEN in each environment is based on a convolutional network. It maps the input image to an image on which the group action is easy to describe in terms of pixel manipulation. In the object-centric environments, the output is a down-sampled image with 5 channels, one per object. Object permutations permute the channels, while rotations rotate the image. In the Reacher environment, the output is a down-sampled image that the group action rotates, flips, and translates.
In the 3D Teapot environment, we expect the SEN to detect the pose of the object in 3D. We use a two-part network: a down-sampling CNN encodes the image, and its output is passed to an MLP and then converted to an element of $SO(3)$. To force the output of the symmetric embedding network to be an element of $SO(3)$, the last layer outputs 2 vectors. We then perform Gram-Schmidt orthogonalization to construct a positively oriented orthonormal frame (see Appendix E). This method is also used by falorsi2018explorations, who conclude that it produces less topological distortion than alternatives such as quaternions.
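The construction from two output vectors to a rotation matrix can be sketched as follows. This is the standard Gram-Schmidt recipe: normalize the first vector, remove its component from the second vector and normalize, then complete the frame with a cross product. The exact variant in Appendix E may differ, and the helper names are hypothetical. Because the third axis is the cross product of the first two, the frame is right-handed and the determinant is +1.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def normalize(v):
    n = math.sqrt(dot(v, v))
    return [x / n for x in v]


def cross(u, v):
    return [u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0]]


def gram_schmidt_so3(v1, v2):
    # Turn two non-parallel 3-vectors (e.g. a network's last-layer outputs) into a rotation matrix.
    e1 = normalize(v1)
    e2 = normalize([b - dot(e1, v2) * a for a, b in zip(e1, v2)])  # remove the e1 component
    e3 = cross(e1, e2)  # right-handed completion, so det = +1
    return [[e1[i], e2[i], e3[i]] for i in range(3)]  # columns e1, e2, e3


def det3(m):
    # Scalar triple product of the rows.
    return dot(m[0], cross(m[1], m[2]))


r = gram_schmidt_so3([1.0, 2.0, 3.0], [-1.0, 0.5, 2.0])
cols = [[row[j] for row in r] for j in range(3)]
print(round(det3(r), 12))  # 1.0
```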
4.4 Generalizing over the MDP Action Space
In settings where data collection is costly, equivariance can improve sample efficiency and generalization. While it is difficult to generalize over high-dimensional states without explicit symmetry, the MDP actions are low-dimensional and have a clear symmetry. Furthermore, the MDP action bypasses the non-equivariant SEN and is passed directly to the transition model, which is explicitly equivariant. This means we can train using only a proper subset of the action space and then test on the entire action space. In other words, our model has the added benefit of generalizing to unseen actions when trained on only a fraction of the data, which we demonstrate in Section 5.4. We state a proposition that bounds the model error over the entire action space when the model is trained on a subset of it (proof in Appendix H).
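Proposition 4.1. Let $A' \subset A$ be the subset of MDP actions seen during training. Assume $GA' = A$, i.e. every MDP action is a $G$-transformed version of one in $A'$. Consider a dataset $D$ of transitions sampled with actions from $A'$, and denote by $GD$ the set of all $G$-transforms of all samples in $D$. Assume a $G$-invariant norm, a bounded model error on $D$, and bounded equivariance errors for all $g \in G$ and all samples in $D$. Then the model error over $GD$ is also bounded.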
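In the idealized, error-free case, the mechanism behind this proposition can be illustrated with a hypothetical single-object grid world. Only 'up' transitions are observed, and the environment, names, and lookup-table "model" are assumptions for illustration. Because every action is 'up' rotated by some quarter turn, equivariance alone extends the observed transitions to the full action space via T(s, a) = ρ_S(k) T(ρ_S(k)^{-1} s, up). The proposition concerns the harder case where model and equivariance errors are only bounded.

```python
# Hypothetical single-object grid world on [-M, M]^2; only 'up' transitions are observed.
M = 2
MOVES = [(1, 0), (0, 1), (-1, 0), (0, -1)]  # right, up, left, down
UP = 1


def rotate(p, k):
    # rho_S: rotate a position by k * 90 degrees counterclockwise about the grid centre.
    x, y = p
    for _ in range(k % 4):
        x, y = -y, x
    return (x, y)


def true_transition(s, a):
    dx, dy = MOVES[a]
    return (max(-M, min(M, s[0] + dx)), max(-M, min(M, s[1] + dy)))


GRID = [(x, y) for x in range(-M, M + 1) for y in range(-M, M + 1)]
# Stand-in for a model fit on the training subset A' = {up}: a lookup table of observed transitions.
observed_up = {s: true_transition(s, UP) for s in GRID}


def equivariant_transition(s, a):
    # Every action is 'up' rotated by k quarter turns (G A' = A), so by equivariance
    # T(s, a) = rho_S(k) T(rho_S(k)^{-1} s, up).
    k = (a - UP) % 4
    return rotate(observed_up[rotate(s, -k)], k)


print(all(equivariant_transition(s, a) == true_transition(s, a)
          for s in GRID for a in range(4)))  # True
```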
For all experiments, we consider three types of models: (a) a non-equivariant model with no enforced symmetry, (b) a fully equivariant model whose input action is chosen to be the closest explicit pixel-level transformation to the actual symmetry, and (c) our method. For 3D Teapot, we forgo the fully equivariant baseline, as it is hard to define an action on the 2D image space that approximates the true group action. We instead include a comparison to the Homeomorphic VAE (falorsi2018explorations), which is trained on images of teapots without any actions. Since its latent space is the same as that of our model, we can use the MatMul transition model to predict the effect of actions. We keep the total number of parameters comparable across all models by reducing the hidden dimensions of the equivariant networks. Other details are provided in Appendices E and F.
|3D Teapot||Homeomorphic VAE|
To evaluate the model without state reconstructions, we use two types of metrics. The first are accuracy metrics adapted from kipf20 and the second are equivariance metrics to measure the degree of equivariance.
Hits, Hard Hits, Traversal Hits, and MRR
Given a dataset of triplets $(s_t, a_t, s_{t+1})$, ranking metrics compute the distance between each predicted next latent state and the encodings of all next states in the dataset. Hits at Rank k (H@k) is the proportion of triplets for which the encoding of the corresponding next state is among the k nearest neighbors of the prediction. The mean reciprocal rank (MRR) is the average inverse rank of the true next state. We also compute Hard Hits at Rank k (HH@k). For this metric, we generate negative samples of states close to the true next state and compute the proportion of samples for which the prediction is closer to the true next state than to the negative samples. This is a harder version of H@k, as the model must distinguish between similar negative samples and the true positive sample. For Traversal Hits (TH@k), which we use for the Teapot experiments, we use rotation increments along three axes of rotation (yaw, pitch, roll) as the negative states. We measure whether the latent state at each increment can be distinguished from the latent states of the other points along the traversal. For example, the traversal shown in Figure 6 reaches 100%, whereas a model mapping every increment to the same latent state has TH@1 = 0%.
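The H@k and MRR computations can be sketched as follows, in the style of kipf20. For each triplet, the true next-state encoding is ranked among all next-state encodings by distance to the prediction. The squared Euclidean distance, the tie-breaking rule (ties counted in the model's favour), and the toy embeddings are assumptions of this sketch.

```python
def sq_dist(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))


def ranking_metrics(z_pred, z_next, k=1):
    # For each triplet i, rank the true next state z_next[i] among all next states
    # by distance to the prediction z_pred[i]; ties are resolved in the model's favour.
    hits, rr = 0, 0.0
    for i, p in enumerate(z_pred):
        d_true = sq_dist(p, z_next[i])
        rank = 1 + sum(sq_dist(p, q) < d_true for j, q in enumerate(z_next) if j != i)
        hits += rank <= k
        rr += 1.0 / rank
    n = len(z_pred)
    return hits / n, rr / n  # (H@k, MRR)


z_next = [[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]]
z_pred = [[0.1, 0.0], [9.0, 0.0], [6.0, 4.0]]  # the third prediction is nearer to wrong states
print(ranking_metrics(z_pred, z_next, k=1))  # H@1 = 2/3, MRR = (1 + 1 + 1/3) / 3
```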
Equivariance Error
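To evaluate the degree to which the learned SEN is equivariant, we generate triplets $(g, x, x')$ for which a known element $g \in G$ acts on the state $x$. This yields images $x$ and $x' = \rho_X(g)\,x$ during generation. These allow us to compute the equivariance error as the distance between the embedding of the transformed image and the transformed embedding of the original image, $\lVert f(x') - \rho_Y(g)\,f(x) \rVert$, averaged over triplets.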
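A sketch of this metric on a hypothetical toy setting follows. The group acts on 1D signals and on embeddings by cyclic shifts, and the two embedding functions stand in for learned SENs. Every name is illustrative, as is the use of the Euclidean norm. A pointwise map has zero equivariance error, while a running sum does not.

```python
import math
from itertools import accumulate


def roll(seq, k):
    # Cyclic shift to the right by k positions.
    k %= len(seq)
    return seq[-k:] + seq[:-k]


def l2(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def equivariance_error(f, triplets, rho_out):
    # Mean of ||f(x') - rho_out(g, f(x))|| over triplets (g, x, x') with x' = rho_in(g) x.
    return sum(l2(f(x2), rho_out(g, f(x))) for g, x, x2 in triplets) / len(triplets)


def shift(g, z):
    # Known output action in this toy: cyclic shift by g.
    return roll(z, g)


def f_equiv(v):
    # Pointwise map: shift-equivariant.
    return [2.0 * a for a in v]


def f_cumsum(v):
    # Running sum: not shift-equivariant.
    return list(accumulate(v))


x = [1.0, 3.0, 2.0, 0.0]
triplets = [(g, x, roll(x, g)) for g in range(4)]
print(equivariance_error(f_equiv, triplets, shift))       # 0.0
print(equivariance_error(f_cumsum, triplets, shift) > 0)  # True
```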
Distance Invariance Error
The equivariance error can be computed when the output space is spatial and we can manually perform group actions on the outputs. However, it cannot be applied to the latent space of non-equivariant models, since the group action on the latent space cannot be meaningfully defined.
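We therefore propose a proxy for the equivariance error using invariant distances. Consider a pair of input states $x_1, x_2$ and a group element $g$. An equivariant model $f$ will have the same distances $\lVert f(x_1) - f(x_2) \rVert$ and $\lVert f(\rho_X(g)\,x_1) - f(\rho_X(g)\,x_2) \rVert$, assuming the action of $G$ on the output space is norm preserving, as it is for all transformations considered in the paper. (The action is assumed to be linear.) Due to the linearity of the action, $\lVert \rho_Y(g)\,f(x_1) - \rho_Y(g)\,f(x_2) \rVert = \lVert \rho_Y(g)\,(f(x_1) - f(x_2)) \rVert = \lVert f(x_1) - f(x_2) \rVert$. The distance invariance error is computed as the absolute difference between these two distances, averaged over pairs and group elements:
$$\Big|\, \lVert f(x_1) - f(x_2) \rVert - \lVert f(\rho_X(g)\,x_1) - f(\rho_X(g)\,x_2) \rVert \,\Big|.$$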
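The sketch below illustrates why this proxy needs only the input action. It uses a hypothetical 1D toy where the group acts on inputs by cyclic shifts, and all names are illustrative. The reversal map is equivariant but induces a latent action (a shift in the opposite direction) that is never specified. Its distance invariance error is still zero. A position-dependent scaling is not equivariant, and distances change.

```python
import math


def roll(seq, k):
    # Cyclic shift to the right by k positions.
    k %= len(seq)
    return seq[-k:] + seq[:-k]


def l2(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def distance_invariance_error(f, pairs, rho_in, group):
    # Mean over pairs and group elements of |d(f(x1), f(x2)) - d(f(g x1), f(g x2))|.
    # Only the input action rho_in is needed; the latent action may be unknown.
    errs = [abs(l2(f(x1), f(x2)) - l2(f(rho_in(g, x1)), f(rho_in(g, x2))))
            for x1, x2 in pairs for g in group]
    return sum(errs) / len(errs)


def shift(g, x):
    return roll(x, g)


def f_reverse(x):
    # Equivariant, but the induced latent action is a shift in the opposite direction.
    return x[::-1]


def f_weighted(x):
    # Position-dependent scaling: not equivariant, so distances are not preserved.
    return [(i + 1) * a for i, a in enumerate(x)]


pairs = [([1.0, 3.0, 2.0, 0.0], [0.0, 1.0, 1.0, 5.0])]
print(distance_invariance_error(f_reverse, pairs, shift, range(4)))       # 0.0
print(distance_invariance_error(f_weighted, pairs, shift, range(4)) > 0)  # True
```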
5.2 Model performance comparison
Tables 2, 5, and 6 summarize the performance of models and baselines, with additional results in Appendix C. In general, the ranking metrics show that all models are accurate on 3D Blocks, Rush Hour, and Reacher.
Surprisingly, the fully equivariant model performs very well even when the group action on the input space is not accurate. Due to the skewed perspective, the simple pixel-level transformation maps training data to out-of-distribution images that are never seen by the model. We hypothesize that this inaccurate equivariance constraint does not hamper performance on the training data, but only constrains the model's ability to extrapolate to out-of-distribution samples.
In Table 2, we observe that both “None” and our model are accurate on hard hits (HH@1), but “None” performs poorly on traversal hits (TH@1). This baseline is sensitive only to pitch and roll and completely ignores the yaw of the teapot. The Homeomorphic VAE results indicate that this model makes only coarse distinctions between different orientations of the teapot.
For the equivariance metrics, SEN-based models outperform the baselines for 3D Blocks and for Teapot, in each case on one of the two metrics. For the other environments, the equivariance metrics are relatively similar across all models.
5.3 Latent visualizations
We visualize the latent embeddings of our model to qualitatively analyze the learned representations. Figure 6 shows a sample from the evaluation dataset in both pixel and latent space. We factor the encoded state and next state into irreducible representations (irreps) of the group before visualizing (see hall2003lie). Some irreps are 1-dimensional and are plotted as a line. The 2-dimensional irreps show a clear circular pattern, matching the joint rotations of the environment.
We also transform the embedding with the group action and visualize the corresponding pixel-level outputs. Figure 6 shows the traversal of rotations in pixel and latent space for 3D Teapot. The latent space is free to choose its own base coordinate frame, which here is oriented downwards. We can clearly see that the effective rotations relative to the objects’ orientation perfectly align, demonstrating that the learned embeddings correctly encode 3D poses and rotations. For 3D Blocks and Reacher, we freeze our model and then train a separate decoder to map the latent states into pixel space. Figures 9 and 10 show that our model implicitly learns a reasonable group action in input space, which corresponds to the group action in latent space.
5.4 Generalization from limited actions
In these experiments, we train on a subset of actions and evaluate on datasets generated with the full action space. These experiments aim to verify that our model, even with components not constrained to be equivariant, can learn a good equivariant representation which can generalize to actions that were not seen during training. We perform experiments on the 2D Shapes, 3D Blocks, and Reacher domains. For 2D Shapes, the training data contains only ‘up’ actions and for 3D Blocks, we omit the left action in training. For Reacher, we restrict the actuation force of the second joint to be positive, meaning that the second arm rotates in only one direction.
Tables 3 and 4 show results for 2D Shapes, 3D Blocks, and Reacher. Our method generalizes successfully to unseen actions, outperforming both the non-equivariant and the fully equivariant baselines. The non-equivariant baseline in particular performs poorly on all domains, achieving only 2.8% on Hits@1 and 5.3% on MRR for 2D Shapes. The fully equivariant model performs worse than our method for 3D Blocks and achieves similar performance on Reacher. The fully equivariant model performs well when trained on all actions but not as well in these generalization experiments. These results lend support to our hypothesis that the inaccurate pixel-level equivariance bias limits its ability to extrapolate to out-of-distribution samples, such as transitions under the actions withheld during training.
Figure 8 in the Appendix shows embeddings for all states in the evaluation dataset for our model and the non-equivariant model trained on only the up action. Our model shows a clear grid, while the non-equivariant model learns a degenerate solution (possibly encoding only the row index).
6 Conclusion and Future Work
This work demonstrates that an equivariant world model can be paired with a symmetric embedding network, which itself is not equivariant by construction, to learn a model that is approximately equivariant. This makes it possible to use equivariant neural networks in domains where the symmetry is known but the transformation properties of the input data cannot be described explicitly. We consider a variety of domains and equivariant neural network architectures, for which we demonstrate generalization to actions outside the training distribution. Future work will include applying symmetric embeddings to tasks beyond world models, and using them to develop disentangled and more interpretable features in domains whose symmetry is known but difficult to isolate.
Appendix A Outline
Our appendix is organized as follows. First, in Section B, we provide additional details on the problem setup. Additional experimental results are presented in Section C, followed by the details of environments and network architectures in Sections D and E. We further explain the notation and definitions in Section G. The proof of Proposition 4.1 is in Section H.
Appendix B Setup: Equivariant World Models
In this section, we provide technical background for building equivariant world models, which we use in learning symmetric representations.
We model our interactive environments as Markov decision processes. A (deterministic) Markov decision process (MDP) is a 5-tuple $\mathcal{M} = (S, A, T, R, \gamma)$, with state space $S$, action space $A$, (deterministic) transition function $T: S \times A \to S$, reward function $R: S \times A \to \mathbb{R}$, and discount factor $\gamma$.
Symmetry can appear in MDPs naturally (zinkevich2001symmetry; narayanamurthy2008hardness; ravindran2004algebraic; van2020mdp), which we can exploit using equivariant networks. For example, van2020mdp study geometric transformations, such as reflections and rotations. ravindran2004algebraic study group symmetry in MDPs as a special case of MDP homomorphisms.
Symmetry in MDPs.
Symmetry in MDPs is defined by the automorphism group of an MDP, where an automorphism is an MDP homomorphism that maps the MDP to itself and thus preserves its structure. zinkevich2001symmetry show the invariance of the value function for an MDP with symmetry. narayanamurthy2008hardness prove that finding exact symmetry in MDPs is graph-isomorphism complete.
ravindran2004algebraic provide a comprehensive overview of using MDP homomorphisms for state abstraction and study symmetry in MDPs as a special case. More recent work by van2020mdp builds upon the notion of an MDP homomorphism induced by group symmetry and uses it in the inverse direction: they assume that the MDP homomorphism induced by a symmetry group is known and exploit it. In contrast to our work, their focus is on policy learning. Policy learning needs to preserve both the transition and reward structure and thus has optimal value equivalence (ravindran2004algebraic).
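More formally, an MDP homomorphism $h$ from an MDP $\mathcal{M} = (S, A, T, R, \gamma)$ to an MDP $\mathcal{M}' = (S', A', T', R', \gamma)$ is a mapping that preserves the transition and reward structure (ravindran2004algebraic). The mapping consists of a tuple of surjective maps $h = (\sigma, \{\alpha_s \mid s \in S\})$, where $\sigma: S \to S'$ is the state mapping and $\alpha_s: A \to A'$ is the state-dependent action mapping. The mappings are constructed to satisfy the following conditions: (1) the transition function is preserved, $T'(\sigma(s), \alpha_s(a)) = \sigma(T(s, a))$, and (2) the reward function is preserved, $R'(\sigma(s), \alpha_s(a)) = R(s, a)$, for all $s \in S$ and all $a \in A$.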
An MDP isomorphism from an MDP $\mathcal{M}$ to itself is called an automorphism of $\mathcal{M}$. The collection of all automorphisms of $\mathcal{M}$, together with composition of homomorphisms, forms the automorphism group of $\mathcal{M}$, denoted $\mathrm{Aut}(\mathcal{M})$.
We specifically care about a subgroup of