SWARM
Set-Equivariant Deep Learning Models
In this work we propose a new neural network architecture that efficiently implements and learns general-purpose set-equivariant functions. Such a function f maps a set of entities x = {x_1, ..., x_n} from one domain to a set of the same cardinality y = f(x) = {y_1, ..., y_n} in another domain, regardless of the ordering of the entities. The architecture is based on a gated recurrent network which is iteratively applied to all entities individually and at the same time syncs with the progression of the whole population. Because this pattern is frequently observed in nature, we call our approach SWARM mapping. Set-equivariant and, more generally, permutation-invariant functions are important building blocks for many state-of-the-art machine learning approaches, even in applications where permutation invariance is not of primary interest, as seen in the recent success of attention-based transformer models (Vaswani et al. 2017). Accordingly, we demonstrate the power and usefulness of SWARM mappings in different applications. We compare the performance of our approach with another recently proposed set-equivariant function, the SetTransformer (Lee et al. 2018), and we demonstrate that a transformer based solely on SWARM layers gives state-of-the-art results.
Permutation invariant transformations have recently attracted growing attention in the research community. Today, there are numerous deep learning tasks where data comes unordered or in a non-meaningful order. Think of, for example, an image-based classification task, where the decision has to be made based on a collection of images. The order in the data batch is often arbitrary and non-informative, though the classifier may be sensitive to it. When empirically marginalizing over the ordering, this sensitivity is reflected in the variance of the classifier. We will demonstrate this effect in a small example below. But interesting applications are not limited to information pooling from collections. In principle, wherever information on a population of entities is processed, be it to take a decision upon the whole population or a decision on the individuals that is influenced by the population, permutation invariant or equivariant functions emerge. The population can be, as already mentioned, a collection of objects to classify jointly, the data points in a Bayesian experiment, the examples in a few-shot learning setup, and many more. The theory of invariant functions is well understood.
Zaheer et al. (2017) have introduced the notion of deep sets as learnable set functions. Yarotsky (2018) and Sannai et al. (2019) study generalizations of the universal approximation theorem for neural networks to invariant or equivariant mappings. However, it is not clear whether the theoretical results always provide a useful foundation for designing practically applicable set functions Wagstaff et al. (2019).

In this work, we propose a new approach to set-equivariant functions that works well and efficiently in practice, also under circumstances where approaches inspired by the universal approximation theorem do not. First we will introduce our model, which we call SWARM mappings. We will then introduce an amortized clustering task as a challenging performance benchmark. We compare SWARM mappings with other approaches to set-equivariant functions. Further, we demonstrate that SWARM mappings can also be used in a non-equivariant setting by setting up a 1-layer transformer architecture for the generation of images.
We study problems in which an unordered set or population of entities is processed simultaneously by a deep neural network. We use bold face symbols or a notation in parentheses to indicate the whole population of entities as a matrix or as a set, $\mathbf{X} = (x_1, \dots, x_n)$. Whenever there is no ambiguity, we may omit the subscripts $i$ for simplicity. Although the population of entities is a set of vectors, it makes sense to consider them in arbitrary but fixed order as the columns of a matrix. For a set-equivariant mapping we have to ensure that it can be carried out on an arbitrary number of entities and that their ordering doesn't matter.

A function $f$ is set-equivariant if it is defined for all population sizes $n$ and for all $\mathbf{X}$ the following holds:

$$f(\pi(\mathbf{X})) = \pi(f(\mathbf{X})) \qquad (1)$$

for arbitrary permutations $\pi$ of the columns of $\mathbf{X}$.
From eq. (1) it follows directly that for a repeated application of functions $f = f_L \circ \dots \circ f_1$ to be a set-equivariant mapping it is sufficient that every $f_l$ is set-equivariant. Thus, we can model arbitrarily complex functions in a hierarchical structure, just like in any other feed-forward neural network architecture, as long as we ensure that all components fulfill (1). Evidently, any function that maps entities individually is trivially set-equivariant. Standard non-linearities or entity-wise linear or non-linear operations (sometimes referred to as $1{\times}1$-convolutions) fall into that category. However, the family of functions that fulfill the definition is much richer than this. The simplest non-trivial one is the linear mapping

$$y_i = W x_i + V \sum_{j=1}^{n} x_j + b \qquad (2)$$

In fact, this is equivalent to two linear functions, one operating on all entities individually ($W$) and one working on all entities summed up ($V$), the output of which as well as the bias $b$ are shared by all entities. We will call such layers 'set-linear' layers in the following. In a feed-forward architecture with several such layers combined with appropriate non-linearities, significantly non-trivial set functions can be learned. It has been proven theoretically that such set pooling functions can approximate arbitrarily complex set-equivariant (and set-invariant) functions Zaheer et al. (2017). However, in practical applications, this structure seems to be too limiting Wagstaff et al. (2019). In our experiments, too, a model built solely from such set-linear layers failed to learn appropriately.
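To make the set-linear layer concrete, the following is a minimal NumPy sketch. It assumes the layer has the form $y_i = W x_i + V \sum_j x_j + b$ with entities stored as columns; the function name `set_linear` and all shapes are illustrative choices, not taken from the authors' code. The last lines check the equivariance property (1) numerically for one random permutation.

```python
import numpy as np

def set_linear(X, W, V, b):
    """Set-linear layer (assumed form): y_i = W x_i + V * sum_j x_j + b.

    X has shape (d_in, n); each column is one entity.
    """
    pooled = X.sum(axis=1, keepdims=True)    # (d_in, 1): sum over all entities
    return W @ X + V @ pooled + b[:, None]   # shared terms broadcast over columns

rng = np.random.default_rng(0)
d_in, d_out, n = 3, 4, 5
X = rng.normal(size=(d_in, n))
W = rng.normal(size=(d_out, d_in))
V = rng.normal(size=(d_out, d_in))
b = rng.normal(size=d_out)

# Permuting the input columns should permute the output columns the same way.
perm = rng.permutation(n)
Y = set_linear(X, W, V, b)
Y_perm = set_linear(X[:, perm], W, V, b)
is_equivariant = np.allclose(Y_perm, Y[:, perm])
```

Because the pooled sum does not depend on the column order, the only order-dependent part of the output is the entity-wise term, which is what makes the check pass.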
Our goal was to improve on the limited practical approximation capabilities of set-linear layers. Looking at their working principle (2), we see that all entities are processed individually with the same affine transformation and all entities receive the same additive population update. Our idea was to increase expressiveness by letting every single entity maintain its own memory of how it develops compared to the development of the whole population during adjacent transformation steps. The core idea of SWARM mappings is to implement exactly this entity-individual memory. In processing long sequences it is well known that gated network architectures like LSTM Hochreiter and Schmidhuber (1997); Gers et al. (1999) and GRU Cho et al. (2014) can carry information over long (temporal) distances. For non-temporal but layered architectures, Highway Networks Srivastava et al. (2015) have been shown to have the same positive effect on carrying information through many adjacent processing steps.
For SWARM we use a modified LSTM cell that receives as input to its gating networks the entity input $x_i$, the last output $h_i$, and additionally a population input $p$, where $p$ is a set-invariant population function of all entities. Thus, for the activation $a_i^g$ of gate $g$ of entity $i$, i.e. the input, output, and forget gates (and similarly for the cell update), we have an additional population term in the activation equation involving $p$,

$$a_i^g = W^g x_i + U^g h_i + V^g p + b^g.$$
For set-equivariance all parameters have to be shared among all entities, thus allowing for a variable number of entities and permutation invariance. The update of the memory cell then works as usual in any LSTM. Figure 1a) shows an illustration of the SWARM LSTM cell with population pooling. For processing in a SWARM layer, the cell is executed in parallel for all entities and repeatedly over several iterations. During the iterations the input to the cell, $x_i$, remains the same, but the cell's memory state is constantly updated with the feedback provided by the population. In the last iteration, the memory is sufficient to produce, together with the input, the right output for entity $i$. Figure 1b) depicts a SWARM layer as a recurrent processing unit. Initial values for the output $h_i$ and the memory state $c_i$ are set to zero (it may be useful to initialize them with small random values in order to break the symmetry of entities that happen to be identical). Taking the SWARM layer as a set-equivariant building block, nothing speaks against stacking several of them together or combining them with other set-equivariant blocks. In our experiments we used one or two SWARM layers with a non-linearity layer between them.
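The recurrent processing described above can be sketched in a few lines of NumPy. This is an illustrative reading of the text, not the authors' implementation: mean pooling over the last outputs as the population function, the gate ordering, the parameter layout, and the names `swarm_layer`, `Wx`, `Wh`, `Wp` are all assumptions.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def swarm_layer(X, params, n_iter=5):
    """Sketch of a SWARM layer. X: (n, d_in), one entity per row.

    params = (Wx, Wh, Wp, b) with shapes (d_in, 4H), (H, 4H), (H, 4H), (4H,),
    shared by all entities. Returns the outputs h of shape (n, H).
    """
    Wx, Wh, Wp, b = params
    n, H = X.shape[0], Wh.shape[0]
    h = np.zeros((n, H))  # last outputs, initialized to zero
    c = np.zeros((n, H))  # memory cells, initialized to zero
    for _ in range(n_iter):
        # Set-invariant population input (assumed here: mean over entities).
        p = h.mean(axis=0, keepdims=True)            # (1, H)
        # Gate activations: entity input + last output + population term.
        a = X @ Wx + h @ Wh + p @ Wp + b             # (n, 4H)
        i_g, f_g, o_g, g = np.split(a, 4, axis=1)
        c = sigmoid(f_g) * c + sigmoid(i_g) * np.tanh(g)   # usual LSTM cell update
        h = sigmoid(o_g) * np.tanh(c)
    return h

rng = np.random.default_rng(0)
n, d_in, H = 6, 3, 8
X = rng.normal(size=(n, d_in))
params = (0.1 * rng.normal(size=(d_in, 4 * H)),
          0.1 * rng.normal(size=(H, 4 * H)),
          0.1 * rng.normal(size=(H, 4 * H)),
          np.zeros(4 * H))

perm = rng.permutation(n)
out = swarm_layer(X, params)
out_perm = swarm_layer(X[perm], params)  # same entities, different order
```

Note that the input `X` is fed unchanged in every iteration; only the per-entity states `h` and `c` evolve, driven by the pooled population signal.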
We compare the performance of SWARM layers and other architectures in an amortized clustering experiment. In this task, the model is presented with a number of entities at a time and has to simultaneously assign every entity to one out of $K$ cluster indices. This is inspired by the experiment conducted in Lee et al. (2018); however, as we are primarily interested in set-equivariant rather than set-invariant mappings, we do not attempt to learn the parameters of the data-generating multi-modal distribution, but the cluster assignments directly. It turned out that this is a rather challenging task that is difficult for many models to solve.
Training amortized clustering is a supervised learning task similar to classification. We want to assign every entity to a class (or cluster). However, we are confronted with the permutation ambiguity of clustering: the model has no chance to figure out the original assignment other than by guessing. Therefore we cannot simply use the categorical cross entropy loss for the task. Fortunately, it turned out that we don't need to explore all possible assignment permutations but can do a greedy matching of logits to target cluster indices. The procedure of the greedy cross entropy loss is described in Algorithm 1.

The dataset comprises 10,000 tasks, each having a random number of entities. Entities are points that are drawn iid. from a mixture of Gaussians. For every task, the number of Gaussian clusters is drawn uniformly at random between 3 and 10, and the cluster centers are drawn iid. from the standard normal distribution. The cluster covariances are drawn iid. from the inverse Wishart distribution with 4 degrees of freedom and a fixed scale factor. Cluster assignments are uniform over the number of clusters. The dataset was sampled once with a fixed random seed and then used for all experiments. It was split into 9,000 tasks for training and 1,000 tasks for validation/testing.

We compare our approach to different architectures:
A recurrent architecture based on (potentially multi-layered) bidirectional LSTM Graves and Schmidhuber (2005), which treats the entities as an ordered sequence. This is of course not a set-equivariant operation; we have included this model as a baseline. We will also use it to study the effect of an implicit ordering of an otherwise unordered set. For training this model, the entities were explicitly shuffled at random to prevent the model from learning from any spurious ordering of the entities.
A model based on SetTransformers. These were recently proposed in Lee et al. (2018) and build upon multi-head attention and self-attention layers Vaswani et al. (2017). To be compatible with our setup we only use the encoder part of their model, which is a set-equivariant function before it gets pooled down to an output of fixed size in the decoder. From the different architectural building blocks described in the paper we have chosen the Induced Set Attention Blocks (ISAB) to compare with, as they were reported to perform best.
A SWARM layer network. SWARM layers can be used individually or stacked on top of each other to form a set-equivariant feed forward network. When more than one layer was used, a ReLU non-linearity was used between them.
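As an illustration of the greedy cross entropy loss used to train all of the models above, the following NumPy sketch shows one possible reading of the greedy matching idea. Algorithm 1 is not reproduced in this text, so the details are assumptions: a cost matrix of summed negative log-likelihoods between target clusters and output indices is built, and the cheapest remaining pair is matched until every target cluster present in the task is assigned. The function names are hypothetical.

```python
import numpy as np

def log_softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def greedy_cross_entropy(logits, targets):
    """Greedy matching loss (assumed variant, not the paper's Algorithm 1).

    logits: (n, K) model outputs; targets: (n,) integer cluster indices.
    Returns the mean negative log-likelihood under the greedy matching and
    the mapping from target cluster index to output index (-1 if unused).
    """
    n, K = logits.shape
    logp = log_softmax(logits)
    onehot = np.eye(K)[targets]                  # (n, K)
    # cost[l, k]: summed NLL if target cluster l is matched to output index k.
    cost = -(onehot.T @ logp)
    counts = onehot.sum(axis=0)
    cost[counts == 0, :] = np.inf                # ignore clusters absent from the task
    mapping = np.full(K, -1)
    total = 0.0
    for _ in range(int((counts > 0).sum())):
        l, k = np.unravel_index(np.argmin(cost), cost.shape)
        mapping[l] = k
        total += cost[l, k]
        cost[l, :] = np.inf                      # each target cluster matched once
        cost[:, k] = np.inf                      # each output index used once
    return total / n, mapping

# A prediction that is perfect up to a relabeling of the clusters.
targets = np.array([0, 0, 1, 1, 2, 2])
relabel = np.array([2, 0, 1])
logits = 20.0 * np.eye(3)[relabel[targets]]
loss, mapping = greedy_cross_entropy(logits, targets)
```

In this example a plain categorical cross entropy would be large, because every entity is assigned to the "wrong" index, whereas the greedy matching recovers the relabeling and yields a loss close to zero.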
In our experiment, we explored different hyperparameters to find the best performing model. These were the number of hidden units or memory cells, the number of layers, the number of inducing points in the ISAB blocks, and the number of iterations in the SWARM layer. Table 1 lists the range of explored hyperparameters and which of them were applicable to which model. All performance results are subject to fixed compute resources. All models were allotted 60 minutes of net training time (not including setup, intermediate validation, logging, and checkpoint generation) on one P100 GPU with two 3.2GHz Xeon CPU cores and 8GB of RAM. Training was done with batch size 50, no dropout, and the Adam optimizer.
To automate and stabilize bulk learning, we have developed a robust backtracking heuristic that prevents the model from diverging due to, for example, outliers or too large learning rates. In a nutshell, whenever the validation loss after an epoch is too large compared to the previous epochs, the model is set back to the checkpoint with the best validation loss and the learning rate is lowered. The procedure is detailed in Algorithm 2.

Table 1: Explored hyperparameters and best performing models per architecture.

| | | SetLinear | LSTM | SetTransformer | SWARM |
|---|---|---|---|---|---|
| code | | (a-d) | (a-d) | (a-b-d) | (a-c-d) |
| explored | a) no. of hidden units / cell states | 32, 64 | 16, 32, 64, 128 | 16, 32, 64, 128 | 16, 32, 64, 128 |
| | b) no. of inducing points | | | 10, 20, 40 | |
| | c) no. of iterations | | | | 2, 5, 10 |
| | d) no. of layers | 2, 4, 8 | 1, 2, 3 | 1, 2, 3 | 1, 2 |
| best | architecture | (64-8) | (128-3) | (32-40-3) | (128-10-1) |
| | no. of parameters | 51,604 | 928,266 | 53,802 | 135,178 |
| | validation loss | 1.498 ± 0.008 | 0.676 ± 0.023 | 0.565 ± 0.036 | 0.455 ± 0.009 |
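The backtracking heuristic described before Table 1 can be sketched as follows. Since Algorithm 2 and its parameters are not reproduced here, everything specific in this sketch is an assumption: "too large" is interpreted as exceeding the best validation loss so far by a factor `tolerance`, the learning rate is multiplied by `lr_decay` on each backtrack, and the class name `BackTracker` is hypothetical.

```python
import copy
import math

class BackTracker:
    """Checkpoint-and-backtrack helper (illustrative, hypothetical parameters)."""

    def __init__(self, lr, tolerance=1.5, lr_decay=0.5):
        self.lr = lr
        self.tolerance = tolerance    # assumed threshold factor for "too large"
        self.lr_decay = lr_decay      # assumed learning rate reduction factor
        self.best_loss = float("inf")
        self.best_state = None

    def step(self, val_loss, state):
        """Call once per epoch; returns the state to continue training from."""
        if val_loss < self.best_loss:
            # New best model: remember it as the checkpoint.
            self.best_loss = val_loss
            self.best_state = copy.deepcopy(state)
            return state
        too_large = (not math.isfinite(val_loss)
                     or val_loss > self.tolerance * self.best_loss)
        if too_large and self.best_state is not None:
            # Diverged: lower the learning rate and restore the best checkpoint.
            self.lr *= self.lr_decay
            return copy.deepcopy(self.best_state)
        return state

tracker = BackTracker(lr=1e-3)
state = tracker.step(1.0, {"w": 1})
state = tracker.step(0.8, {"w": 2})   # improvement, becomes the checkpoint
state = tracker.step(5.0, {"w": 3})   # far worse: restored to {"w": 2}, lr halved
```

In a real training loop, `state` would hold the model and optimizer parameters, and the optimizer's learning rate would be set from `tracker.lr` after each call.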
Figure 3 gives an overview of the performance of the different architectures. It plots the number of parameters, as a measure of model complexity, against the negative log-likelihood of the model on the validation set. Dashed lines in the respective color indicate the empirical frontier for the four model classes. For the models at the frontier, their architecture code is plotted next to the point. See Table 1 for a definition of the architecture codes. For the best performing models of every class, their negative log-likelihoods are listed at the bottom of Table 1. One can clearly see that SWARM layer models outperform all other model classes and constitute the overall frontier. Interestingly, the top performing SWARM models have just one layer. This confirms our empirical findings from before this study: in an early approach, we had tried larger stacks of layers or even different SWARM cells in every iteration, but none of these more heavily parametrized architectures worked particularly well. The LSTM models are significantly worse in validation loss. Note that even the worst SWARM models are below the LSTM frontier. Interestingly, the LSTM models seem to be quite robust in their setups: none of them is far from the frontier. This is in contrast to the SetTransformers, some of which (mainly the heavier ones) failed completely. It could be that they are more sensitive to the learning rate and batch size setup. Last but not least, the SetLinear models completely failed to learn the task. That is surprising and we still do not fully understand what is happening here. As the models initially learn and always robustly converge to the same value, it could be that they are just not flexible enough to solve this challenging task.
Results of the amortized clustering generated by a single SWARM layer with 128 units and 10 iterations can be seen in Figure 4. The left panel shows the ground truth data generated from a test set. This task has 8 clusters, shown in different colors together with the covariance ellipses of the generating covariance matrices. The right panel shows the clusters assigned by the SWARM layer. Note that the colors don't map one to one because of the above-mentioned permutation ambiguity in clustering. Apart from the dense center region with overlapping clusters, the model's cluster assignment is quite consistent. Note that this example was not hand-picked but chosen randomly and is quite representative. The gray shaded areas show the assignment confidence for a single entity $x^*$ that was added to the population. With $l^*$ being the logits of this entity after the transformation with the SWARM layer, the gray level corresponds to the entropy

$$H(x^*) = -\sum_{k} \mathrm{softmax}(l^*)_k \, \log \mathrm{softmax}(l^*)_k$$

at the respective position $x^*$. Darker regions are regions of higher entropy, thus lower assignment confidence. Note that for the center region, where the model actually makes a mistake, the entropy is comparatively large.
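The confidence map described above can be computed along the following lines. This is a sketch under assumptions: `model` stands for any callable that maps a population of shape (n, d) to logits of shape (n, K), and the names `assignment_entropy` and `confidence_map` are illustrative rather than taken from the paper.

```python
import numpy as np

def assignment_entropy(logits):
    """Entropy of the softmax distribution defined by the logits (last axis)."""
    z = logits - logits.max(axis=-1, keepdims=True)
    logp = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -(np.exp(logp) * logp).sum(axis=-1)

def confidence_map(model, X, grid_points):
    """Entropy of a probe entity's assignment at each position in grid_points.

    model: hypothetical callable, (n, d) population -> (n, K) logits.
    X: (n, d) population; grid_points: (m, d) probe positions.
    """
    entropies = np.empty(len(grid_points))
    for j, x_probe in enumerate(grid_points):
        X_aug = np.vstack([X, x_probe[None, :]])   # population plus one probe entity
        logits = model(X_aug)
        entropies[j] = assignment_entropy(logits[-1])  # logits of the probe only
    return entropies
```

High values of the returned entropies correspond to the darker, low-confidence regions of the plot; a uniform assignment over K clusters gives the maximum value log K.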
To conclude this experiment, we have done an ablation study in which we varied the number of iterations in a SWARM layer that was trained with 10 iterations. We see that the model starts with coarse structures and low confidence, which are iteratively refined more and more. When doing more iterations than were used during training, the model slowly starts to get overconfident and the performance degrades.
The rightmost panel shows the confidence map generated by a (64-8) SetLinear model. If these models are not capable of generating more complex structures than that, this could explain their weak performance. Compare this with the early iterations of the SWARM model.
Finally, we had a closer look at the performance of the LSTM models. Despite their weak performance in this task, we wanted to know how much it matters that they are not set-equivariant functions. For populations from 100 to 1,000 entities, clustering losses of the best LSTM model (128-3) were recorded for 1,000 different random shufflings of the entities per task. Figure 3 shows a violin distribution plot of the standard deviations per task, scaled by the average loss of that task. We see that it is in the range of 20% (slightly decreasing for larger populations), which is quite significant. Note that for the set-equivariant models this value is zero by construction.
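The sensitivity measure used here, the standard deviation of the loss over random shufflings divided by the mean loss, can be sketched as below. The callable `loss_fn` is a hypothetical stand-in for evaluating a trained model's clustering loss on one task; it is not part of the paper.

```python
import numpy as np

def relative_order_sensitivity(loss_fn, X, y, n_shuffles=1000, seed=0):
    """Std of the loss over random entity orderings, scaled by the mean loss.

    loss_fn: hypothetical callable (X, y) -> scalar loss for one task.
    X: (n, d) entities; y: (n,) targets, shuffled jointly with X.
    """
    rng = np.random.default_rng(seed)
    losses = np.empty(n_shuffles)
    for s in range(n_shuffles):
        perm = rng.permutation(len(X))
        losses[s] = loss_fn(X[perm], y[perm])
    return losses.std() / losses.mean()
```

For a set-equivariant model with a per-entity loss, every shuffle yields the same value up to floating point error, so the ratio is essentially zero; an order-sensitive model such as the LSTM gives a strictly positive value.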
A set-equivariant layer is the main building block of a powerful class of neural network architectures that has recently enjoyed increasing popularity: Transformers Vaswani et al. (2017). Scaled dot-product attention, self-attention, and multi-head self-attention are the ingredients of several models that are currently state of the art in many challenges. Surprisingly, set-equivariance is not actually needed there. To be precise, it is even explicitly eliminated by the introduction of positional encodings. Still, transformers are reported to frequently outperform recurrent or convolutional architectures. The question was whether a SWARM layer could also be used in a transformer-like setting. We investigate this with the task of image generation, where great success with transformers has recently been reported Parmar et al. (2018); Child et al. (2019). We have largely adopted the setup of the Image Transformer. To build a SWARM Transformer we had to replace the pooling operation in the SWARM layer with a causal mean pooling, where the entities are explicitly ordered along the scanlines of the image. We further used 256-dimensional fixed positional encodings, similar to those in the Image Transformer, and 256-dimensional trainable input and channel embeddings. As they are adjustable and are immediately followed by a linear layer operation in the SWARM cell, we have added them up instead of concatenating them, as suggested in Vaswani et al. (2017). Samples generated by the SWARM Transformer for MNIST LeCun et al. (1998) and FashionMNIST Xiao et al. (2017) look very convincing, and their likelihoods are state of the art (cf. Nalisnick et al. (2018)). The CIFAR10 Krizhevsky and Hinton (2009) results are further off (cf. the survey in van den Oord et al. (2016) and Parmar et al. (2018)) and the samples are also less visually appealing. We hope that refined architectures can improve on that.
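The causal mean pooling used in the SWARM Transformer can be illustrated with a short NumPy sketch. It assumes that the entities are stored in scanline order and that each position pools over itself and all earlier positions; whether the current position is included is an assumption, as is the function name `causal_mean_pool`.

```python
import numpy as np

def causal_mean_pool(h):
    """Running mean along the entity axis.

    h: (n, H) entity outputs in scanline order.
    Returns p with p[i] = mean(h[0], ..., h[i]), so position i never sees
    entities that come later in the ordering.
    """
    counts = np.arange(1, h.shape[0] + 1)[:, None]   # 1, 2, ..., n
    return np.cumsum(h, axis=0) / counts

h = np.array([[1.0], [3.0], [5.0]])
pooled = causal_mean_pool(h)   # running means 1, 2, 3
```

Unlike the plain mean over the whole population, this pooling yields a different population input for every position, which breaks set-equivariance on purpose and makes autoregressive generation possible.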
We have presented a powerful yet simple architecture for set-equivariant functions and demonstrated that it outperforms other state-of-the-art models. Notably, SWARM layers can be used as a direct replacement for attention and self-attention blocks if the pooling function is designed appropriately. We demonstrated that this can yield state-of-the-art performance in image-transformer-like tasks (MNIST and FashionMNIST) with much simpler architectures than attention-based image transformers. For future work it remains to systematically analyze in which areas SWARM mappings are advantageous over attention-based models. In particular, we want to better understand why the SWARM transformer performed so much better in our experiments on the 1-channel tasks MNIST and FashionMNIST than on CIFAR10.