Many machine learning tasks such as multiple instance learning, 3D shape recognition and few-shot image classification are defined on sets of instances. Since solutions to such problems do not depend on the permutation of elements of the set, models used to address them should be permutation invariant. We present an attention-based neural network module, the Set Transformer, specifically designed to model interactions among elements in the input set. The model consists of an encoder and a decoder, both of which rely on attention mechanisms. In an effort to reduce computational complexity, we introduce an attention scheme inspired by inducing point methods from sparse Gaussian process literature. It reduces computation time of self-attention from quadratic to linear in the number of elements in the set. We show that our model is theoretically attractive and we evaluate it on a range of tasks, demonstrating increased performance compared to recent methods for set-structured data.
Learning representations has proven to be an essential problem for deep learning and its many success stories. The majority of problems tackled by deep learning are instance-based and take the form of mapping a fixed-dimensional input tensor to its corresponding target value (Krizhevsky et al., 2012; Graves et al., 2013). For some applications, we are required to process set-structured data. Multiple instance learning (Dietterich et al., 1997; Maron & Lozano-Pérez, 1998) is an example of such a set-input problem, where a set of instances is given as an input and the corresponding target is a label for the entire set. Other problems such as 3D shape recognition (Wu et al., 2015; Shi et al., 2015; Su et al., 2015; Charles et al., 2017), sequence ordering (Vinyals et al., 2016), and various set operations (Muandet et al., 2012; Edwards & Storkey, 2017; Zaheer et al., 2017) can also be viewed as such set-input problems. Moreover, many meta-learning (Thrun & Pratt, 1998; Schmidhuber, 1987) problems which learn using a set of different but related tasks may also be treated as set-input tasks where an input set corresponds to the training dataset of a single task. For example, few-shot image classification (Finn et al., 2017; Snell et al., 2017; Lee & Choi, 2018) operates by building a classifier using a support set of images, which is evaluated with query images.
A model for set-input problems should satisfy two critical requirements. First, it should be permutation invariant: the output of the model should not change under any permutation of the elements in the input set. Second, such a model should be able to process input sets of any size. While these requirements stem from the definition of a set, they are not easily satisfied in neural-network-based models: classical feed-forward neural networks violate both requirements, and RNNs are sensitive to input order.
Recently, Edwards & Storkey (2017) and Zaheer et al. (2017) proposed neural network architectures which meet both criteria, which we call set pooling methods. In this model, each element in a set is first independently fed into a feed-forward neural network that takes fixed-size inputs. The resulting feature-space embeddings are then aggregated using a pooling operation (mean, sum, max, or similar). The final output is obtained by further non-linear processing of the aggregated embedding. This remarkably simple architecture satisfies both aforementioned requirements and, more importantly, is proven to be a universal approximator for any set function (Zaheer et al., 2017). Thanks to this property, it is possible to learn complex mappings between input sets and their target outputs in a black-box fashion, much like with feed-forward or recurrent neural networks.
Even though this set pooling approach is theoretically attractive, it remains unclear whether we can approximate complex mappings well using only instance-based feature extractors and simple pooling operations. Since every element in a set is processed independently in a set pooling operation, some information regarding interactions between elements has to be necessarily discarded. This can make some classes of problems unnecessarily difficult to solve.
Consider the problem of meta-clustering: we would like to learn a parametric mapping from an input set of points to the centers of the clusters in the set, for many such sets. Even though a neural network with a set pooling operation can approximate such a mapping by learning to quantize space, this quantization cannot depend on the contents of the set. This limits the quality of the solution and may make optimization of such a model more difficult; we show empirically in Section 4 that it leads to under-fitting.
In this paper, we propose a novel set-input deep neural network architecture called the Set Transformer (cf. Transformer, Vaswani et al. (2017)). The novelty of the Set Transformer comes from three important design choices: 1) We use a self-attention mechanism based on the Transformer to process every element in an input set, which allows our approach to naturally encode pairwise or higher-order interactions between elements in the set. 2) We propose a method to reduce the computation time of Transformers from $O(n^2)$ to $O(nm)$, where $n$ is the number of elements in the set and $m$ is a fixed hyperparameter. 3) We use a self-attention mechanism to aggregate features, which is especially beneficial when the problem of interest requires multiple dependent outputs, such as the problem of meta-clustering, where the meaning of each cluster center heavily depends on its location relative to the other clusters. We apply the Set Transformer to several set-input problems and empirically demonstrate the importance and effectiveness of these design choices.
This paper is organized as follows. In Section 2, we briefly review the concept of set functions, existing architectures, and the self-attention mechanism. In Section 3, we introduce Set Transformers, our novel neural network architecture for set functions. In Section 4, we present various experiments that demonstrate the benefits of the Set Transformer. We discuss related works in Section 5 and conclude the paper in Section 6.
Problems involving a set of objects have the permutation invariance property: the target value for a given set is the same regardless of the order of objects in the set. A simple example of a permutation invariant model is a network that performs pooling over embeddings extracted from the elements of a set. More formally,
$$\mathrm{net}(\{x_1, \dots, x_n\}) = \rho(\mathrm{pool}(\{\phi(x_1), \dots, \phi(x_n)\})).$$
Zaheer et al. (2017) have proven that all permutation invariant functions can be represented in this form when $\mathrm{pool}$ is the sum operator and $\rho$, $\phi$ are any continuous functions, thus justifying the use of this architecture for set-input problems.
Note that we can deconstruct this form into two parts: an encoder ($\phi$) which independently acts on each element of a set of $n$ items, and a decoder ($\rho(\mathrm{pool}(\cdot))$) which aggregates these encoded features and produces our desired output. Most network architectures for set-structured data follow this encoder-decoder structure. Our proposed method is also composed of an encoder and a decoder, but our embedding function does not act independently on each item; it considers the whole set to obtain the embedding. Additionally, instead of a fixed function such as the mean, our aggregating function $\mathrm{pool}(\cdot)$ is parameterized and can thus adapt to the problem at hand.
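The encoder-pool-decoder structure described above can be sketched in a few lines. This is a minimal illustration, not the authors' implementation: the one-layer tanh encoder `phi`, the linear decoder, the random weights and the helper name `set_pool` are all assumptions chosen for brevity.

```python
import math
import random

def phi(x, w, b):
    """Instance-wise encoder: one tanh layer applied to a single element."""
    return [math.tanh(sum(wi * xi for wi, xi in zip(row, x)) + bj)
            for row, bj in zip(w, b)]

def set_pool(xs, w, b, v, pool="sum"):
    """Compute rho(pool({phi(x_1), ..., phi(x_n)})) with a linear rho."""
    feats = [phi(x, w, b) for x in xs]
    cols = list(zip(*feats))  # one tuple per feature dimension
    if pool == "sum":
        agg = [sum(c) for c in cols]
    elif pool == "mean":
        agg = [sum(c) / len(c) for c in cols]
    elif pool == "max":
        agg = [max(c) for c in cols]
    else:
        raise ValueError(pool)
    return sum(vi * ai for vi, ai in zip(v, agg))

rng = random.Random(0)
d_in, d_hid = 2, 4  # hypothetical sizes
w = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(d_hid)]
b = [rng.gauss(0, 1) for _ in range(d_hid)]
v = [rng.gauss(0, 1) for _ in range(d_hid)]

xs = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(5)]
shuffled = xs[:]
rng.shuffle(shuffled)

out_a = set_pool(xs, w, b, v)            # original order
out_b = set_pool(shuffled, w, b, v)      # same set, different order
out_small = set_pool(xs[:3], w, b, v)    # a smaller set is handled unchanged
```

Because each element is encoded on its own and the pooling step is symmetric, `out_a` and `out_b` agree up to floating-point rounding, and the same weights accept sets of any size.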
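Assume we have $n$ query vectors (corresponding to $n$ points in an input set), each with dimension $d_q$: $Q \in \mathbb{R}^{n \times d_q}$. An attention function $\mathrm{Att}(Q, K, V)$ is a function that maps queries $Q$ to outputs using $n_v$ key-value pairs $K \in \mathbb{R}^{n_v \times d_q}$, $V \in \mathbb{R}^{n_v \times d_v}$:
$$\mathrm{Att}(Q, K, V; \omega) = \omega(QK^\top)V.$$
The pairwise dot product $QK^\top \in \mathbb{R}^{n \times n_v}$ measures how similar each pair of query and key vectors is, with weights computed with an activation function $\omega$. The output $\omega(QK^\top)V$ is a weighted sum of $V$ where a value gets more weight if its corresponding key has a larger dot product with the query.
Multi-head attention, originally introduced in Vaswani et al. (2017), is an extension of the previous attention scheme. Instead of computing a single attention function, this method first projects $Q, K, V$ onto $h$ different $d_q^M$-, $d_q^M$-, and $d_v^M$-dimensional vectors, respectively. An attention function ($\mathrm{Att}(\cdot; \omega_j)$) is applied to each of these $h$ projections. The output is a linear transformation of the concatenation of all attention outputs:
$$\mathrm{Multihead}(Q, K, V; \lambda, \omega) = \mathrm{concat}(O_1, \dots, O_h)W^O, \quad \text{where } O_j = \mathrm{Att}(QW_j^Q, KW_j^K, VW_j^V; \omega_j).$$
Note that $\mathrm{Multihead}(\cdot, \cdot, \cdot; \lambda)$ has learnable parameters $\lambda = \{W_j^Q, W_j^K, W_j^V\}_{j=1}^h$, where $W_j^Q, W_j^K \in \mathbb{R}^{d_q \times d_q^M}$, $W_j^V \in \mathbb{R}^{d_v \times d_v^M}$, $W^O \in \mathbb{R}^{h d_v^M \times d}$. A typical choice for the dimension hyperparameters is $d_q^M = d_q/h$, $d_v^M = d_v/h$, $d = d_q$. For brevity, we set $d_q = d_v = d$ and $d_q^M = d_v^M = d/h$ throughout the rest of the paper. Unless specified otherwise, we use the scaled softmax $\omega_j(\cdot) = \mathrm{softmax}(\cdot/\sqrt{d})$, which our experiments showed worked robustly in most settings.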
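The attention and multi-head attention functions described above can be sketched with plain Python lists. This is an illustrative sketch under stated assumptions: the scaling by the square root of the model dimension, the per-head width `d // h`, the random weights and all helper names are choices made here, not details confirmed by the text.

```python
import math
import random

def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(q, k, scale):
    """Row i holds softmax(q_i . k_j / scale) over all keys j."""
    return [softmax([sum(a * b for a, b in zip(qi, kj)) / scale for kj in k])
            for qi in q]

def att(q, k, v, scale):
    """Att(Q, K, V) = softmax(Q K^T / scale) V."""
    return matmul(attention_weights(q, k, scale), v)

def multihead(q, k, v, heads, w_o, d):
    """Project per head, attend, concatenate the head outputs, apply W^O."""
    outs = [att(matmul(q, wq), matmul(k, wk), matmul(v, wv), math.sqrt(d))
            for wq, wk, wv in heads]
    concat = [[x for o in outs for x in o[i]] for i in range(len(q))]
    return matmul(concat, w_o)

def rand(rng, rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
n, n_v, d, h = 3, 5, 4, 2  # hypothetical sizes
Q, K, V = rand(rng, n, d), rand(rng, n_v, d), rand(rng, n_v, d)
heads = [(rand(rng, d, d // h), rand(rng, d, d // h), rand(rng, d, d // h))
         for _ in range(h)]
W_O = rand(rng, d, d)

out = multihead(Q, K, V, heads, W_O, d)          # n rows, d columns
weights = attention_weights(Q, K, math.sqrt(d))  # each row sums to 1

# Reordering the key-value pairs together leaves the output unchanged.
order = [4, 2, 0, 3, 1]
out_perm = multihead(Q, [K[i] for i in order], [V[i] for i in order], heads, W_O, d)
```

The last check hints at why attention suits sets: the output for each query is a weighted sum over key-value pairs, so the order in which those pairs are listed does not matter.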
Set operations | Time complexity | High-order interactions | Permutation invariant |
---|---|---|---|
Recurrent | $O(n)$ | Yes | No |
Pooling (Zaheer et al., 2017) | $O(n)$ | No | Yes |
Relational Networks (Santoro et al., 2017) | $O(n^2)$ | Yes | Yes |
Set Transformer (SAB + PMA, ours) | $O(n^2)$ | Yes | Yes |
Set Transformer (ISAB + PMA, ours) | $O(nm)$ | Yes | Yes |
In this section, we motivate and describe the Set Transformer: an attention-based neural network architecture that is designed to process sets of data. A Set Transformer consists of an encoder followed by a decoder (cf. Section 2.1). The encoder transforms a set of instances into a set of features, which the decoder transforms into the desired fixed-dimensional output.
We begin by defining our attention-based set operations. While existing pooling methods for sets obtain instance features independently of other instances, we use self-attention to concurrently encode the whole set. This gives the Set Transformer the ability to preserve pairwise as well as higher-order interactions among instances during the encoding process. For this purpose, we adapt the multihead attention mechanism used in Transformer. We emphasize that all blocks introduced here are neural network blocks with their own parameters, and not fixed functions.
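Given matrices $X, Y \in \mathbb{R}^{n \times d}$ which represent two sets of $d$-dimensional vectors, we define the Multihead Attention Block (MAB) with parameters $\omega$ as follows:
$$\mathrm{MAB}(X, Y) = \mathrm{LayerNorm}(H + \mathrm{rFF}(H)), \quad \text{where } H = \mathrm{LayerNorm}(X + \mathrm{Multihead}(X, Y, Y; \omega)),$$
where $\mathrm{rFF}$ is any row-wise feedforward layer (i.e., it processes each instance independently and identically), and $\mathrm{LayerNorm}$ is layer normalization (Ba et al., 2016). The MAB is an adaptation of the encoder block of the Transformer (Vaswani et al., 2017) without positional encoding and dropout. Using the MAB, we define the Set Attention Block (SAB) as
$$\mathrm{SAB}(X) := \mathrm{MAB}(X, X).$$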
In other words, an SAB takes a set and performs self-attention between the elements in the set, resulting in a set of equal size. Since the output of an SAB contains information about pairwise interactions between the elements in the input set $X$, we can stack multiple SABs to encode higher-order interactions. Note that while the SAB involves a multihead attention operation with $Q = K = V = X$, it could reduce to applying a residual block on $X$. In practice, it learns more complicated functions due to the linear projections of $X$ inside the attention heads.
A potential problem with using SABs for set-structured data is the quadratic time complexity $O(n^2)$, which may be too expensive for large sets ($n \gg 1$). We thus introduce the Induced Set Attention Block (ISAB), which bypasses this problem. Along with the set $X \in \mathbb{R}^{n \times d}$, additionally define $m$ $d$-dimensional vectors $I \in \mathbb{R}^{m \times d}$, which we call inducing points. Inducing points are part of the ISAB itself, and they are trainable parameters which we train along with other parameters of the network. An ISAB with $m$ inducing points $I$ is defined as:
$$\mathrm{ISAB}_m(X) = \mathrm{MAB}(X, H) \in \mathbb{R}^{n \times d}, \quad \text{where } H = \mathrm{MAB}(I, X) \in \mathbb{R}^{m \times d}.$$
The ISAB first transforms $I$ into $H$ by attending to the input set. The set of transformed inducing points $H$, which contains information about the input set $X$, is again attended to by the input set $X$ to finally produce a set of $n$ elements. This is analogous to low-rank projection or autoencoder models, where inputs ($X$) are first projected onto a low-dimensional object ($H$) and then reconstructed to produce outputs. The difference is that the goal of these methods is reconstruction whereas ISAB aims to obtain good features for the final task. We expect the learned inducing points to encode some global structure which helps explain the inputs $X$. As an example, think of a clustering problem on a 2D plane. The inducing points could be appropriately distributed points on the 2D plane so that the encoder can compare elements in the query dataset indirectly through their proximity to these grid points.
Note that in both MABs of the ISAB, attention is computed between a set of size $m$ and a set of size $n$. Therefore, the time complexity of $\mathrm{ISAB}_m(X)$ is $O(nm)$, where $m$ is a hyperparameter, an improvement over the quadratic complexity of the SAB. We compare characteristics of various set operations in Table 1. We also emphasize that both of our set operations are permutation equivariant:
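We say a function $f: X^n \to Y^n$ is permutation equivariant iff for any permutation $\pi \in S_n$, $f(\pi x) = \pi f(x)$. Here $S_n$ is the set of all permutations of indices $\{1, \dots, n\}$.
Both $\mathrm{SAB}(X)$ and $\mathrm{ISAB}_m(X)$ are permutation equivariant.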
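The blocks above and their permutation equivariance can be checked numerically with a small sketch. It is written under several assumptions that are not taken from the text: the row-wise feedforward layer is a single ReLU layer, layer normalization has no learnable gain or bias, weights are random rather than trained, and the class and helper names are illustrative.

```python
import math
import random

def matmul(a, b):
    return [[sum(x * y for x, y in zip(r, c)) for c in zip(*b)] for r in a]

def softmax(row):
    m = max(row)
    e = [math.exp(x - m) for x in row]
    s = sum(e)
    return [x / s for x in e]

def layer_norm(a, eps=1e-5):
    """Row-wise normalisation; the learnable gain and bias are omitted."""
    out = []
    for r in a:
        mu = sum(r) / len(r)
        var = sum((x - mu) ** 2 for x in r) / len(r)
        out.append([(x - mu) / math.sqrt(var + eps) for x in r])
    return out

def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def rand(rng, rows, cols):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

class MAB:
    """MAB(X, Y): X attends to Y, with residual connections and layer norm."""
    def __init__(self, d, h, rng):
        self.d = d
        self.heads = [tuple(rand(rng, d, d // h) for _ in range(3)) for _ in range(h)]
        self.w_o = rand(rng, d, d)
        self.w_ff = rand(rng, d, d)

    def multihead(self, x, y):
        outs = []
        for wq, wk, wv in self.heads:
            q, k, v = matmul(x, wq), matmul(y, wk), matmul(y, wv)
            w = [softmax([sum(a * b for a, b in zip(qi, kj)) / math.sqrt(self.d)
                          for kj in k]) for qi in q]
            outs.append(matmul(w, v))
        concat = [[val for o in outs for val in o[i]] for i in range(len(x))]
        return matmul(concat, self.w_o)

    def __call__(self, x, y):
        hid = layer_norm(add(x, self.multihead(x, y)))
        ff = [[max(0.0, val) for val in row] for row in matmul(hid, self.w_ff)]
        return layer_norm(add(hid, ff))

class SAB:
    def __init__(self, d, h, rng):
        self.mab = MAB(d, h, rng)
    def __call__(self, x):
        return self.mab(x, x)

class ISAB:
    def __init__(self, d, h, m, rng):
        self.inducing = rand(rng, m, d)  # m inducing points (trainable in practice)
        self.mab_in, self.mab_out = MAB(d, h, rng), MAB(d, h, rng)
    def __call__(self, x):
        hid = self.mab_in(self.inducing, x)  # m x d summary of the set
        return self.mab_out(x, hid)          # n x d output

rng = random.Random(0)
n, d, h, m = 6, 4, 2, 3  # hypothetical sizes
X = rand(rng, n, d)
perm = [3, 0, 5, 1, 4, 2]
X_perm = [X[i] for i in perm]

def max_equivariance_gap(block):
    """Largest difference between block(perm(X)) and perm(block(X))."""
    out, out_perm = block(X), block(X_perm)
    return max(abs(a - b) for i, p in enumerate(perm)
               for a, b in zip(out_perm[i], out[p]))

sab_gap = max_equivariance_gap(SAB(d, h, rng))
isab_gap = max_equivariance_gap(ISAB(d, h, m, rng))
```

Both gaps are at the level of floating-point rounding. In the ISAB, each of the two attention steps pairs `n` set elements with `m` inducing points, which is where the `n * m` cost comes from.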
Using the SAB and ISAB defined above, we construct the encoder of the Set Transformer by stacking multiple SABs or multiple ISABs, for example:
$$\mathrm{Encoder}(X) = \mathrm{SAB}(\mathrm{SAB}(X)), \qquad \mathrm{Encoder}(X) = \mathrm{ISAB}_m(\mathrm{ISAB}_m(X)).$$
We point out again that the time complexities for stacks of $\ell$ SABs and ISABs are $O(\ell n^2)$ and $O(\ell nm)$, respectively. This can result in much lower processing times when using ISAB (as compared to SAB), while still maintaining high representational power.
After the encoder transforms data $X \in \mathbb{R}^{n \times d_x}$ into features $Z \in \mathbb{R}^{n \times d}$, the decoder aggregates them into a single vector which is fed into a feed-forward network to get final outputs. A common aggregation scheme is to simply take the average or dimension-wise maximum of the feature vectors (cf. Section 1). We instead aggregate features by applying multihead attention on a learnable set of $k$ seed vectors $S \in \mathbb{R}^{k \times d}$. We call this scheme Pooling by Multihead Attention (PMA):
$$\mathrm{Decoder}(Z) = \mathrm{rFF}(\mathrm{SAB}(\mathrm{PMA}_k(Z))), \quad \text{where } \mathrm{PMA}_k(Z) = \mathrm{MAB}(S, \mathrm{rFF}(Z)).$$
Note that the output of $\mathrm{PMA}_k$ is a set of $k$ items. In most cases, using one seed vector ($k = 1$) and no SAB sufficed. However, when the problem of interest requires $k$ correlated outputs, the natural thing to do is to use $k$ seed vectors. An example of such a problem is clustering, where the desired output is $k$ centers. In this case, the additional SAB was crucial because it allowed the network to directly take the correlation between the $k$ pooled features into account. Intuitively, feature aggregation using attention should be beneficial because the influence of each instance on the target is not necessarily equal. For example, consider a problem where the target value is the maximum value of a set of real numbers. Since the target can be recovered using only a single instance (the largest), finding and attending to that instance during aggregation will be advantageous. In the next subsection, we analyze both the encoder and decoder structures more rigorously.
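Pooling by attention over learnable seed vectors can be sketched as follows. The sketch assumes a single self-attention encoder block followed by a PMA with `k` seeds and leaves out the optional SAB and final feedforward layer of the decoder; the single-layer ReLU feedforward, gain-free layer normalization, random weights and all names are illustrative assumptions.

```python
import math
import random

def matmul(a, b):
    return [[sum(x * y for x, y in zip(r, c)) for c in zip(*b)] for r in a]

def softmax(row):
    m = max(row)
    e = [math.exp(x - m) for x in row]
    s = sum(e)
    return [x / s for x in e]

def layer_norm(a, eps=1e-5):
    out = []
    for r in a:
        mu = sum(r) / len(r)
        var = sum((x - mu) ** 2 for x in r) / len(r)
        out.append([(x - mu) / math.sqrt(var + eps) for x in r])
    return out

def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def relu_layer(a, w):
    return [[max(0.0, val) for val in row] for row in matmul(a, w)]

def rand(rng, rows, cols):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

class MAB:
    """MAB(X, Y): X attends to Y, with residual connections and layer norm."""
    def __init__(self, d, h, rng):
        self.d = d
        self.heads = [tuple(rand(rng, d, d // h) for _ in range(3)) for _ in range(h)]
        self.w_o, self.w_ff = rand(rng, d, d), rand(rng, d, d)

    def __call__(self, x, y):
        outs = []
        for wq, wk, wv in self.heads:
            q, k, v = matmul(x, wq), matmul(y, wk), matmul(y, wv)
            w = [softmax([sum(a * b for a, b in zip(qi, kj)) / math.sqrt(self.d)
                          for kj in k]) for qi in q]
            outs.append(matmul(w, v))
        concat = [[val for o in outs for val in o[i]] for i in range(len(x))]
        hid = layer_norm(add(x, matmul(concat, self.w_o)))
        return layer_norm(add(hid, relu_layer(hid, self.w_ff)))

class PMA:
    """k seed vectors attend to the encoded set and return k pooled vectors."""
    def __init__(self, d, h, k, rng):
        self.seeds = rand(rng, k, d)   # trainable in practice
        self.w_ff = rand(rng, d, d)
        self.mab = MAB(d, h, rng)
    def __call__(self, z):
        return self.mab(self.seeds, relu_layer(z, self.w_ff))

rng = random.Random(1)
n, d, h, k = 7, 4, 2, 2  # hypothetical sizes
encoder = MAB(d, h, rng)  # used as a self-attention block: MAB(X, X)
pma = PMA(d, h, k, rng)

def set_transformer(x):
    return pma(encoder(x, x))

X = rand(rng, n, d)
X_perm = [X[i] for i in [6, 2, 4, 0, 5, 1, 3]]
out, out_perm = set_transformer(X), set_transformer(X_perm)
invariance_gap = max(abs(a - b) for ra, rb in zip(out, out_perm)
                     for a, b in zip(ra, rb))
out_small = set_transformer(X[:3])  # a smaller set still yields k vectors
```

The encoder output is reordered when the input is, but the seeds only see it through attention-weighted sums, so the pooled result is the same for any ordering and always has `k` rows.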
Since the blocks used to construct the encoder (i.e., SAB, ISAB) are permutation equivariant, the mapping of the encoder is permutation equivariant as well. Combined with the fact that the PMA in the decoder is a permutation invariant transformation, we have the following:
The Set Transformer is permutation invariant.
Being able to approximate any function is a desirable property, especially for black-box models such as deep neural networks. Building on previous results about the universal approximation of permutation invariant functions, we prove the universality of Set Transformers:
The Set Transformer is a universal approximator of permutation invariant functions.
See Appendix A. ∎
To evaluate the Set Transformer, we apply it to a suite of tasks involving sets of data points. We repeat all experiments five times and report performance metrics evaluated on corresponding test datasets. We compared various architectures arising from the combination of the choices of having attention in encoders and decoders, each of which roughly represents existing works as its special cases. Unless specified otherwise, "simple pooling" means average pooling.
rFF + Pooling (Zaheer et al. (2017)): rFF layers in encoder and simple pooling + rFF layers in decoder.
rFF + PMA (includes Ilse et al. (2018) as special cases): rFF layers in encoder and PMA (followed by stack of SABs) in decoder.
Set Transformer: Stack of SABs (ISABs) in encoder and PMA (followed by stack of SABs) in decoder.
To demonstrate the advantage of attention-based set aggregation over simple pooling operations, we consider a toy problem: regression to the maximum value of a given set. Given a set of real numbers $\{x_1, \dots, x_n\}$, the goal is to return $\max(x_1, \dots, x_n)$. Given prediction $p$, we use the mean absolute error $|p - \max(x_1, \dots, x_n)|$ as the loss function. We constructed simple pooling architectures with three different pooling operations: max, mean, and sum. We report loss values after training in Table 2. Mean- and sum-pooling architectures result in a high mean absolute error (MAE). The model with max-pooling can predict the output perfectly by learning its encoder to be an identity function, and thus achieves the highest performance. Notably, the Set Transformer achieves performance comparable to the max-pooling model, which underlines the importance of the additional flexibility granted by attention mechanisms: it can learn to find and attend to the maximum element.
Architecture | MAE |
---|---|
rFF + Pooling (mean) | 2.133 ± 0.190 |
rFF + Pooling (sum) | 1.902 ± 0.137 |
rFF + Pooling (max) | 0.1355 ± 0.0074 |
Set Transformer | 0.2085 ± 0.0127 |
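The argument that max-pooling with an identity encoder solves this task exactly can be illustrated without any training. The sketch below is not a reproduction of the experiment: the set sizes, the value range and the untrained identity encoder and decoder are assumptions made here, so its numbers are not comparable with the table above.

```python
import random

def sample_set(rng, max_size=10, low=0.0, high=100.0):
    """Draw one set of real numbers; size and value ranges are illustrative."""
    n = rng.randint(1, max_size)
    return [rng.uniform(low, high) for _ in range(n)]

def mae(predict, sets):
    """Mean absolute error between a prediction and the true maximum."""
    return sum(abs(predict(s) - max(s)) for s in sets) / len(sets)

rng = random.Random(0)
sets = [sample_set(rng) for _ in range(1000)]

# With identity encoder and decoder, the pooling choice fixes the prediction.
mae_max = mae(max, sets)
mae_mean = mae(lambda s: sum(s) / len(s), sets)
mae_sum = mae(sum, sets)
```

Max-pooling recovers the target exactly, while mean- and sum-pooling would need the surrounding layers to undo the aggregation, which is the difficulty the text describes.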
Architecture | Error |
---|---|
rFF + Pooling | 0.5618 ± 0.0072 |
rFF + PMA | 0.5428 ± 0.0076 |
SAB + Pooling | 0.4477 ± 0.0077 |
Set Transformer | 0.4178 ± 0.0075 |
In order to test the ability of modelling interactions between objects in a set, we introduce a new task of counting unique elements in an input set. We use the Omniglot (Lake et al., 2015) dataset, which consists of 1,623 different handwritten characters from various alphabets, where each character is represented by 20 different images.
We split all characters (and corresponding images) into train, validation, and test sets and only train using images from the train character classes. We generate input sets by sampling between 6 and 10 images and we train the model to predict the number of different characters inside the set. We used a Poisson regression model to predict this number, with the rate given as the output of a neural network. We maximized the log likelihood of this model using stochastic gradient ascent.
We evaluated model performance using sets of images sampled from the test set of characters. Table 3 reports accuracy, measured as the frequency at which the mode of the Poisson distribution chosen by the network is equal to the number of characters inside the input set.
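The Poisson objective and the mode-based accuracy described above can be written down directly. This is a small sketch with hypothetical rates and counts; in the paper the rate comes from a neural network, which is not modelled here.

```python
import math

def poisson_log_likelihood(count, rate):
    """log p(count | rate) for a Poisson distribution with the given rate."""
    return count * math.log(rate) - rate - math.lgamma(count + 1)

def poisson_mode(rate):
    """Largest mode of a Poisson distribution (floor of the rate)."""
    return math.floor(rate)

def accuracy(rates, counts):
    """Fraction of sets whose Poisson mode equals the true unique count."""
    hits = sum(poisson_mode(r) == c for r, c in zip(rates, counts))
    return hits / len(counts)

# Hypothetical predicted rates and true numbers of unique characters.
rates = [3.7, 2.2, 5.9]
counts = [3, 2, 6]
acc = accuracy(rates, counts)
ll = sum(poisson_log_likelihood(c, r) for c, r in zip(counts, rates))
```

A prediction counts as correct whenever the rate lies in the interval from the true count up to, but not including, the next integer. When the rate is an exact integer, the distribution has two modes and this sketch picks the larger one.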
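We applied the set-input networks to the task of maximum likelihood of mixture of Gaussians (MoGs). The log-likelihood of a dataset $X = \{x_1, \dots, x_n\}$ generated from an MoG with $k$ components is
$$\log p(X; \theta) = \sum_{i=1}^{n} \log \sum_{j=1}^{k} \pi_j \, \mathcal{N}(x_i; \mu_j, \mathrm{diag}(\sigma_j^2)).$$
The goal is to learn the optimal parameters $\theta^*(X) = \arg\max_\theta \log p(X; \theta)$. The typical approach to this problem is to run an iterative algorithm such as Expectation-Maximisation (EM) until convergence. Instead, we aim to learn a generic meta-algorithm that directly maps the input set $X$ to $\theta^*(X)$. One can also view this as amortized maximum likelihood learning. Specifically, given a dataset $X$, we train a neural network to output parameters $f(X; \lambda) = \{\pi(X), \{\mu_j(X), \sigma_j(X)\}_{j=1}^{k}\}$ which maximize
$$\mathbb{E}_X\left[\sum_{i=1}^{|X|} \log \sum_{j=1}^{k} \pi_j(X) \, \mathcal{N}(x_i; \mu_j(X), \mathrm{diag}(\sigma_j^2(X)))\right].$$
We structured $f(\cdot; \lambda)$ as a set-input neural network and learned its parameters using stochastic gradient ascent, where we approximate gradients using minibatches of datasets.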
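The mixture-of-Gaussians log-likelihood that the network is trained to maximize can be evaluated with a numerically stable log-sum-exp. The sketch assumes diagonal covariances given as per-dimension standard deviations; the data and parameter values are made up for illustration, and the function names are not from the paper.

```python
import math

def log_normal_diag(x, mu, sigma):
    """Log density of a Gaussian with diagonal covariance diag(sigma^2)."""
    return sum(-0.5 * math.log(2.0 * math.pi) - math.log(s)
               - 0.5 * ((xi - m) / s) ** 2
               for xi, m, s in zip(x, mu, sigma))

def logsumexp(vals):
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))

def mog_log_likelihood(data, pis, mus, sigmas):
    """Sum over points of log sum_j pi_j N(x_i; mu_j, diag(sigma_j^2))."""
    total = 0.0
    for x in data:
        total += logsumexp([math.log(p) + log_normal_diag(x, mu, s)
                            for p, mu, s in zip(pis, mus, sigmas)])
    return total

# Hypothetical 2D data drawn around two centers.
data = [[0.1, -0.2], [-0.1, 0.3], [5.2, 4.9], [4.8, 5.1]]
unit = [[1.0, 1.0], [1.0, 1.0]]
good = mog_log_likelihood(data, [0.5, 0.5], [[0.0, 0.0], [5.0, 5.0]], unit)
bad = mog_log_likelihood(data, [0.5, 0.5], [[2.0, 2.0], [3.0, 3.0]], unit)
```

Parameters whose centers sit on the two groups of points score a higher log-likelihood than misplaced centers. In the amortized setting, those parameters would be the output of the set-input network rather than fixed values.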
We tested Set Transformers along with other set-input networks on two types of datasets. We used four seed vectors for the PMA ($k = 4$), the same as the number of clusters.
Synthetic 2D mixtures of Gaussians: Each dataset contains points on a 2D plane, each sampled from one of four Gaussians.
CIFAR-100 meta-clustering: Each dataset contains images sampled from four random classes in the CIFAR-100 dataset. Each image is represented by a 512-dim vector obtained from a pretrained VGG net (Simonyan & Zisserman, 2014).
Architecture | Synthetic LL0/data | Synthetic LL1/data | CIFAR-100 ARI0 | CIFAR-100 ARI1 |
---|---|---|---|---|
Oracle | -1.4726 | | 0.9150 | |
rFF + Pooling | -2.0006 ± 0.0123 | -1.6186 ± 0.0042 | 0.5593 ± 0.0149 | 0.5693 ± 0.0171 |
SAB + Pooling | -1.6772 ± 0.0066 | -1.5070 ± 0.0115 | 0.5831 ± 0.0341 | 0.5943 ± 0.0337 |
ISAB (16) + Pooling | -1.6955 ± 0.0730 | -1.4742 ± 0.0158 | 0.5672 ± 0.0124 | 0.5805 ± 0.0122 |
rFF + PMA | -1.6680 ± 0.0040 | -1.5409 ± 0.0037 | 0.7612 ± 0.0237 | 0.7670 ± 0.0231 |
Set Transformer | -1.5145 ± 0.0046 | -1.4619 ± 0.0048 | 0.9015 ± 0.0097 | 0.9024 ± 0.0097 |
Set Transformer (16) | -1.5009 ± 0.0068 | -1.4530 ± 0.0037 | 0.9210 ± 0.0055 | 0.9223 ± 0.0056 |
Meta clustering results. The number inside parenthesis indicates the number of inducing points used in ISABs of encoders. We show average likelihood per data for the synthetic dataset and the adjusted rand index (ARI) for the CIFAR-100 experiment. LL1/data, ARI1 are the evaluation metrics after a single EM update step. The oracle for the synthetic dataset is the log likelihood of the actual parameters used to generate the set, and the CIFAR oracle was computed by running EM until convergence.
We report the performance of the oracle and of different models in Table 4. The table also contains scores attained by all models after a single EM update. Overall, the Set Transformer found accurate parameters and even outperformed the oracles after a single EM update. This can be explained by the relatively small size of the input sets, which leads to some clusters having fewer than 10 points. In this regime, sample statistics can differ from population statistics, which limits the performance of the oracle, but the Set Transformer can adapt accordingly. Notably, the Set Transformer with only 16 inducing points showed the best performance, even outperforming the full Set Transformer. We believe this is due to the knowledge transfer and regularization via inducing points, helping the network to learn global structures. Our results also imply that the improvement from using the PMA is more significant than that from using the SAB, supporting our claim of the importance of attention-based decoders. We provide detailed generative processes, network architectures, and training schemes along with additional experiments with various numbers of inducing points in Section B.3.
Architecture | Test AUROC | Test AUPR |
---|---|---|
Random guess | 0.5 | 0.125 |
rFF + Pooling | 0.5643 ± 0.0139 | 0.4126 ± 0.0108 |
SAB + Pooling | 0.5757 ± 0.0143 | 0.4189 ± 0.0167 |
rFF + PMA | 0.5756 ± 0.0130 | 0.4227 ± 0.0127 |
Set Transformer | 0.5941 ± 0.0170 | 0.4386 ± 0.0089 |
Meta set anomaly results. Each architecture is evaluated using average of test area under receiver operating characteristic curve (AUROC) and test area under precision-recall curve (AUPR).
We evaluate our methods on the task of meta-anomaly detection within a set using the CelebA dataset. The dataset consists of 202,599 images with a total of 40 attributes. We randomly sample 1,000 sets of images. For every set, we select two attributes at random and construct the set by selecting seven images containing both attributes and one image with neither. The goal of this task is to find the image that does not belong to the set. We give a detailed description of the experimental setup in Section B.4. Table 5 contains empirical results, which show that Set Transformers outperformed all other methods by a significant margin.
We evaluated Set Transformers on a classification task using the ModelNet40 (Chang et al., 2015) dataset, containing 40 categories of three-dimensional objects. Each object is represented as a point cloud, which we treat as a set of $n$ elements in $\mathbb{R}^3$. Table 6 contains experimental results on point clouds with $n$ points each. (The point-cloud dataset used in this experiment was obtained directly from the authors of Zaheer et al. (2017).) In this setting, MABs turned out to be prohibitively expensive due to their quadratic time complexity. Additional results with other numbers of points and experiment details are available in Section B.5. Note that ISAB (16) + Pooling outperformed Set Transformers (ISAB (16) + PMA (1)) by a large margin. Our interpretation is that the class of a point cloud object could be efficiently represented by simple aggregation of point features, and the PMA suffered from an optimization issue in this setting. We would like to point out that PMA outperformed simple pooling in all other experiments.
Architecture | Accuracy |
---|---|
rFF + Pooling | 0.8551 ± 0.0142 |
rFF + PMA (1) | 0.8534 ± 0.0152 |
ISAB (16) + Pooling | 0.8915 ± 0.0144 |
Set Transformer (16) | 0.8662 ± 0.0149 |
rFF + Pooling (Zaheer et al., 2017) | 0.83 ± 0.01 |
rFF + Pooling + tricks (Zaheer et al., 2017) | 0.87 ± 0.01 |
Pooling architectures for permutation invariant mappings. Pooling architectures for sets have been used in various problems such as 3D shape recognition (Shi et al., 2015; Su et al., 2015), discovering causality (Lopez-Paz et al., 2016), learning the statistics of a set (Edwards & Storkey, 2017), few-shot image classification (Snell et al., 2017), and conditional regression and classification (Garnelo et al., 2018). Zaheer et al. (2017) discuss the structure in general and provide a partial proof of the universality of the pooling architecture.
Attention-based approaches for sets. Vinyals et al. (2016) propose an architecture to map sets into sequences, where elements in a set are pooled by a weighted average with weights computed from an attention mechanism. Several recent works have highlighted the competency of attention mechanisms in modeling sets. Yang et al. (2018) propose AttSets for multi-view 3D reconstruction, where attention is applied to the encoded features of elements in sets before pooling. Similarly, Ilse et al. (2018) use attention in pooling for multiple instance learning. Although not permutation invariant, Mishra et al. (2018) have attention as one of the core components of their method to meta-learn to solve various tasks using sequences of inputs.
Modeling interactions between elements in sets. An important reason to use the Transformer is to explicitly model higher-order interactions among the elements in a set. Santoro et al. (2017) propose the relational network, a simple architecture that sum-pools all pairwise interactions of elements in a given set, but not higher-order interactions. Similarly to our work, Ma et al. (2018) use the Transformer to model interactions between the objects in a video. They use mean-pooling to obtain aggregated features which they feed into an LSTM.
Inducing point methods. The idea of letting trainable vectors directly interact with datapoints is loosely based on the inducing point methods used in sparse Gaussian processes (Quiñonero-Candela & Rasmussen, 2005) and the Nyström method for matrix decomposition (Fowlkes et al., 2004). Trainable inducing points can also be seen as independent memory cells accessed with an attention mechanism. The Differential Neural Dictionary (Pritzel et al., 2017) stores previous experience as key-value pairs and uses this to process queries. One can view the ISAB as the inversion of this idea, where queries are stored and the input features are used as key-value pairs.
In this paper, we introduced the Set Transformer, an attention-based set-input neural network architecture. Our proposed method uses attention mechanisms for both encoding and aggregating features, and we have empirically validated that both of them are necessary for modelling complicated interactions among elements of a set. We also proposed an inducing point method for self-attention, which makes our approach scalable to large sets. We also showed useful theoretical properties of our model, including the fact that it is a universal approximator for permutation invariant functions. To the best of our knowledge, no previous work has successfully trained a neural network to perform amortized clustering in a single forward pass. An interesting topic for future work would be to apply Set Transformers to meta-learning problems other than meta-clustering. In particular, using Set Transformers to meta-learn posterior inference in Bayesian models seems like a promising line of research. Another exciting extension of our work would be to model the uncertainty in set functions by injecting noise variables into Set Transformers in a principled way.
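Lemma 1. The mean operator $\mathrm{mean}(\{x_1, \dots, x_n\}) = \frac{1}{n}\sum_{i=1}^{n} x_i$ is a special case of dot-product attention with softmax.
Proof. Let $s = 0 \in \mathbb{R}^d$ and $X \in \mathbb{R}^{n \times d}$. Then
$$\mathrm{Att}(s, X, X; \mathrm{softmax}) = \mathrm{softmax}\left(\frac{sX^\top}{\sqrt{d}}\right)X = \frac{1}{n}\sum_{i=1}^{n} x_i.$$
∎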
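The statement that the mean is a special case of softmax attention can be checked numerically: a zero query gives every key the same score, so the softmax weights are uniform. The sketch below uses made-up values and a hypothetical helper name.

```python
import math

def softmax(row):
    m = max(row)
    e = [math.exp(x - m) for x in row]
    s = sum(e)
    return [x / s for x in e]

def att_single_query(s, xs):
    """Scaled dot-product attention of one query s over the rows of xs."""
    d = len(s)
    scores = [sum(a * b for a, b in zip(s, x)) / math.sqrt(d) for x in xs]
    w = softmax(scores)
    return [sum(wi * x[j] for wi, x in zip(w, xs)) for j in range(d)]

X = [[1.0, 2.0], [3.0, 4.0], [5.0, 9.0]]  # hypothetical set of three points
s = [0.0, 0.0]                            # zero query
attended = att_single_query(s, X)
mean = [sum(col) / len(X) for col in zip(*X)]
```

Here `attended` and `mean` agree up to rounding, since all scores are zero and each of the three rows receives weight one third.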
Lemma 2. The decoder of a Set Transformer, given enough nodes, can express any element-wise function of the form $\left(\frac{1}{n}\sum_{i=1}^{n} z_i^p\right)^{1/p}$.
Proof. We first note that we can view the decoder as the composition of functions
$$\mathrm{Decoder}(Z) = \mathrm{rFF}(H), \quad \text{where } H = \mathrm{rFF}(\mathrm{MAB}(Z, \mathrm{rFF}(Z))).$$
We focus on the MAB in the second expression. Since feed-forward networks are universal function approximators at the limit of infinite nodes, let the feed-forward layers in front and back of the MAB encode the element-wise functions $z \mapsto z^p$ and $z \mapsto z^{1/p}$, respectively. We let $h = d$, so the number of heads is the same as the dimensionality of the inputs, and each head is one-dimensional. Let the projection matrices in multi-head attention ($W_j^Q, W_j^K, W_j^V$) represent projections onto the $j$th dimension and the output matrix ($W^O$) the identity matrix. Since the mean operator is a special case of dot-product attention, by simple composition, we see that an MAB can express any dimension-wise function of the form
$$M_p(z_1, \dots, z_n) = \left(\frac{1}{n}\sum_{i=1}^{n} z_i^p\right)^{1/p}.$$
∎
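As a numerical illustration of the lemma above, the sketch below assumes the function family in question is the dimension-wise power mean, that is, the mean of the p-th powers followed by the p-th root. That reading, the example values and the function name are assumptions made here.

```python
def power_mean(zs, p):
    """Dimension-wise (1/n * sum_i z_i^p)^(1/p) for non-negative inputs."""
    n = len(zs)
    return [(sum(z[j] ** p for z in zs) / n) ** (1.0 / p)
            for j in range(len(zs[0]))]

Z = [[1.0, 4.0], [2.0, 0.5], [3.0, 2.0]]  # hypothetical non-negative features
mean_like = power_mean(Z, 1)    # p = 1 recovers the ordinary mean
max_like = power_mean(Z, 50)    # large p approaches the dimension-wise maximum
```

With `p = 1` the result is the mean of each column, and as `p` grows it moves toward the column-wise maximum, so a family of this shape spans the fixed pooling operations used by the baselines.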
Lemma 3. A PMA, given enough nodes, can express sum pooling $\sum_{i=1}^{n} z_i$.
Proof. We prove this by construction. Set the seed $s$ to a zero vector and let $\omega(\cdot) = 1 + f(\cdot)$, where $f$ is any activation function such that $f(0) = 0$. The identity, sigmoid, or ReLU functions are suitable choices for $f$. The output of the multihead attention is then simply a sum of the values. ∎
We additionally have the following universality theorem for pooling architectures:
Theorem 1. Models of the form $\mathrm{rFF}(\mathrm{sum}(\mathrm{rFF}(\cdot)))$ are universal function approximators in the space of permutation invariant functions.
See Appendix A of Zaheer et al. (2017). ∎
By Lemma 3, we know that the decoder can express any function of the form $\mathrm{rFF}(\mathrm{sum}(Z))$. Using this fact along with Theorem 1, we can prove the universality of Set Transformers:
The Set Transformer is a universal function approximator in the space of permutation invariant functions.
Proof. By setting the matrix $W^O$ to a zero matrix in every SAB and ISAB, we can ignore all pairwise interaction terms in the encoder. Therefore, the encoder can express any instance-wise feed-forward network ($Z = \mathrm{rFF}(X)$). Directly invoking Theorem 1 concludes this proof. ∎
While this proof required us to ignore the pairwise interaction terms inside the SABs and ISABs to prove that Set Transformers are universal function approximators, our experiments indicated that self-attention in the encoder was crucial for good performance.
In all implementations, we omit the feed-forward layer at the beginning of the decoder ($\mathrm{rFF}(Z)$) because the end of the previous block contains a feed-forward layer. All MABs (inside SAB, ISAB and PMA) use fully-connected layers with ReLU activations for rFF layers.
In the architecture descriptions, $\mathrm{FC}(d, f)$ denotes the fully-connected layer with $d$ units and activation function $f$. $\mathrm{SAB}(d, h)$ denotes the SAB with $d$ units and $h$ heads. $\mathrm{ISAB}_m(d, h)$ denotes the ISAB with $d$ units, $h$ heads and $m$ inducing points. $\mathrm{PMA}_k(d, h)$ denotes the PMA with $d$ units, $h$ heads and $k$ vectors.
Given a set of real numbers , the goal of this task is to return the maximum value in the set . We construct training data as follows. We first sample a dataset size uniformly from the set of integers . We then sample real numbers independently from the interval . Given the network’s prediction , we use the actual maximum value to compute the mean absolute error . We don’t explicitly consider splits of train and test data, since we sample a new set at each time step.
Encoder | | Decoder | |
---|---|---|---|
FF | SAB | Pooling | PMA |
Architecture | Accuracy |
---|---|
rFF + Pooling | 0.4366 ± 0.0071 |
rFF + PMA | 0.4617 ± 0.0073 |
SAB + Pooling | 0.5659 ± 0.0067 |
Set Transformers (SAB + PMA (1)) | 0.6037 ± 0.0072 |
Set Transformers (SAB + PMA (2)) | 0.5806 ± 0.0075 |
Set Transformers (SAB + PMA (4)) | 0.5945 ± 0.0072 |
Set Transformers (SAB + PMA (8)) | 0.6001 ± 0.0078 |
Encoder | | Decoder | |
---|---|---|---|
rFF | SAB | Pooling | PMA |
The task generation procedure is as follows. We first sample a set size uniformly from the set of integers . We then sample the number of characters uniformly from . We sample characters from the training set of characters, and randomly sample instances of each character so that the total number of instances sums to and each set of characters has at least one instance in the resulting set.
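A minimal sketch of this procedure is given below. The integer ranges are missing from the text, so `min_size`, `max_size` and the choice of drawing the number of characters from 1 up to the set size are assumptions. The sketch returns character labels only; looking up an image instance for each label is omitted.

```python
import numpy as np

def sample_counting_task(rng, num_classes, min_size=6, max_size=10):
    """Return character labels for one set and the number of unique characters.

    min_size, max_size and the range of the character count are assumed values.
    """
    n = int(rng.integers(min_size, max_size + 1))   # set size
    c = int(rng.integers(1, n + 1))                 # number of distinct characters
    chars = rng.choice(num_classes, size=c, replace=False)
    # One guaranteed instance per character, so every character appears at least once.
    extra = chars[rng.integers(0, c, size=n - c)]   # remaining slots, chosen at random
    labels = np.concatenate([chars, extra])
    rng.shuffle(labels)
    return labels, c

rng = np.random.default_rng(0)
tasks = [sample_counting_task(rng, num_classes=100) for _ in range(50)]
```

The target for each set is the count `c`, which by construction equals the number of unique labels in the set.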
We show the detailed architectures used for the experiments in Table 9. For both architectures, the resulting -dimensional output is passed through a activation to produce the Poisson parameter . The role of is to ensure that is always positive.
The loss function we optimize, as previously mentioned, is the log likelihood . We chose this loss function over mean squared error or mean absolute error because it seemed like the more logical choice when trying to make a real number match a target integer. Early experiments showed that directly optimizing for mean absolute error had roughly the same result as optimizing in this way and measuring . We train using the Adam optimizer with a constant learning rate of for batches each with batch size .
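For reference, the Poisson log likelihood of a count n under rate λ is n log λ − λ − log n!. The sketch below evaluates it for a single prediction. The activation that maps the network output to a positive rate is not recoverable from the text; softplus is used here purely as an assumed example of a positivity-enforcing function, and `raw_output` is a hypothetical network output.

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x)); always positive (assumed activation)."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def poisson_log_likelihood(count, rate):
    """log p(count | rate) = count * log(rate) - rate - log(count!)."""
    return count * math.log(rate) - rate - math.lgamma(count + 1)

raw_output = 1.3                      # hypothetical unconstrained network output
rate = softplus(raw_output)           # Poisson parameter, guaranteed positive
ll = poisson_log_likelihood(4, rate)  # log likelihood of a true count of 4
loss = -ll                            # maximizing ll is minimizing its negative
```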
We generated the datasets according to the following generative process.
1. Generate the number of data points, .
2. Generate centers.
3. Generate cluster labels.
4. Generate data from spherical Gaussian.
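The four steps above can be sketched as follows. The distributions and their parameters are not recoverable from the text, so every default below (number of clusters, range of the number of points, dimensionality, range of the centers, standard deviation) is a hypothetical choice made only to keep the example runnable.

```python
import numpy as np

def sample_mog_dataset(rng, k=4, min_points=100, max_points=500, dim=2,
                       center_range=4.0, sigma=0.3):
    """Sample one synthetic clustering dataset; all defaults are assumed values."""
    n = int(rng.integers(min_points, max_points + 1))                  # 1. number of points
    centers = rng.uniform(-center_range, center_range, size=(k, dim))  # 2. centers
    labels = rng.integers(0, k, size=n)                                # 3. cluster labels
    X = centers[labels] + sigma * rng.normal(size=(n, dim))            # 4. spherical Gaussian noise
    return X, labels, centers

rng = np.random.default_rng(0)
# A batch of 10 freshly sampled datasets, as used at each training step.
batch = [sample_mog_dataset(rng) for _ in range(10)]
```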
Table 10 summarizes the architectures used for the experiments. For all architectures, at each training step, we generate 10 random datasets according to the above generative process and update the parameters with the Adam optimizer with initial learning rate . We trained all the algorithms for steps, and decayed the learning rate to after steps. Table 11 summarizes the detailed results for various numbers of inducing points in the ISAB. Figure 3 shows the actual clustering results based on the predicted parameters.
Encoder | | | Decoder | |
---|---|---|---|---|
rFF | SAB | ISAB | Pooling | PMA |
Architecture | LL0/data | LL1/data |
---|---|---|
Oracle | -1.4726 | |
rFF + Pooling | -2.0006 ± 0.0123 | -1.6186 ± 0.0042 |
SAB + Pooling | -1.6772 ± 0.0066 | -1.5070 ± 0.0115 |
ISAB (16) + Pooling | -1.6955 ± 0.0730 | -1.4742 ± 0.0158 |
ISAB (32) + Pooling | -1.6353 ± 0.0182 | -1.4681 ± 0.0038 |
ISAB (64) + Pooling | -1.6349 ± 0.0429 | -1.4664 ± 0.0080 |
rFF + PMA | -1.6680 ± 0.0040 | -1.5409 ± 0.0037 |
Set Transformer | -1.5145 ± 0.0046 | -1.4619 ± 0.0048 |
Set Transformer (16) | -1.5009 ± 0.0068 | -1.4530 ± 0.0037 |
Set Transformer (32) | -1.4963 ± 0.0064 | -1.4524 ± 0.0044 |
Set Transformer (64) | -1.5042 ± 0.0158 | -1.4535 ± 0.0053 |
We pretrained a VGG network (Simonyan & Zisserman, 2014) on CIFAR-100 and obtained a test accuracy of 68.54%. We then extracted feature vectors for the 50k CIFAR-100 training images from the 512-dimensional hidden layer of the VGG network (the layer just before the last layer). Given these feature vectors, the generative process of datasets is as follows.
1. Generate the number of data points, .
2. Uniformly sample four classes among 100 classes.
3. Uniformly sample data points among four sampled classes.
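The sampling steps above can be sketched on top of precomputed feature vectors. The range for the number of data points is missing from the text, so `min_points` and `max_points` are assumptions, as is sampling without replacement. Random vectors stand in for the real VGG features so that the example runs on its own.

```python
import numpy as np

def sample_feature_dataset(rng, features, class_ids, k=4,
                           min_points=100, max_points=500):
    """Sample one clustering dataset from precomputed features (ranges are assumed)."""
    n = int(rng.integers(min_points, max_points + 1))                   # number of points
    classes = rng.choice(np.unique(class_ids), size=k, replace=False)   # k of the classes
    pool = np.flatnonzero(np.isin(class_ids, classes))                  # rows in those classes
    idx = rng.choice(pool, size=n, replace=False)                       # uniform over the pool
    return features[idx], class_ids[idx]

rng = np.random.default_rng(0)
# Stand-in data: 100 classes with 20 random 512-dimensional vectors each.
class_ids = np.repeat(np.arange(100), 20)
features = rng.normal(size=(class_ids.size, 512))
X, y = sample_feature_dataset(rng, features, class_ids, min_points=10, max_points=50)
```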
Table 12 summarizes the architectures used for the experiments. For all architectures, at each training step, we generate 10 random datasets according to the above generative process and update the parameters with the Adam optimizer with initial learning rate . We trained all the algorithms for steps, and decayed the learning rate to after steps. Table 13 summarizes the detailed results for various numbers of inducing points in the ISAB.
Encoder | | | Decoder | |
---|---|---|---|---|
rFF | SAB | ISAB | rFF | PMA |
Architecture | ARI0 | ARI1 |
---|---|---|
Oracle | 0.9151 | |
rFF + Pooling | 0.5593 ± 0.0149 | 0.5693 ± 0.0171 |
SAB + Pooling | 0.5831 ± 0.0341 | 0.5943 ± 0.0337 |
ISAB (16) + Pooling | 0.5672 ± 0.0124 | 0.5805 ± 0.0122 |
ISAB (32) + Pooling | 0.5587 ± 0.0104 | 0.5700 ± 0.0134 |
ISAB (64) + Pooling | 0.5586 ± 0.0205 | 0.5708 ± 0.0183 |
rFF + PMA | 0.7612 ± 0.0237 | 0.7670 ± 0.0231 |
Set Transformer | 0.9015 ± 0.0097 | 0.9024 ± 0.0097 |
Set Transformer (16) | 0.9210 ± 0.0055 | 0.9223 ± 0.0056 |
Set Transformer (32) | 0.9103 ± 0.0061 | 0.9119 ± 0.0052 |
Set Transformer (64) | 0.9141 ± 0.0040 | 0.9153 ± 0.0041 |
Encoder | | Decoder | |
---|---|---|---|
rFF | SAB | Pooling | PMA |
Table 14 describes the architecture for the meta set anomaly experiments. We trained all models with the Adam optimizer with learning rate and exponential decay of the learning rate for 1,000 iterations. We used 1,000 datasets subsampled from the CelebA dataset (see Figure 4) to train and test all the methods, splitting them into 800 training datasets and 200 test datasets.
We used the ModelNet40 dataset for our point cloud classification experiments. This dataset consists of 3-dimensional representations of 9,843 training and 2,468 test objects, each of which belongs to one of 40 object classes. As input to our architectures, we produce point clouds with points each (each point is represented by coordinates). For generalization, we randomly rotate and scale each set during training.
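The random rotation and scaling can be sketched as below. The text does not specify the rotation axis or the scale range, so rotating about the vertical (z) axis and drawing a uniform scale from `scale_range` are assumptions, and the random `cloud` is only a placeholder for a real ModelNet40 sample.

```python
import numpy as np

def augment_point_cloud(rng, points, scale_range=(0.8, 1.25)):
    """Randomly rotate about the z axis and uniformly scale an (n, 3) point cloud.

    The rotation axis and the scale range are assumed, not taken from the text.
    """
    theta = rng.uniform(0.0, 2.0 * np.pi)
    c, s = np.cos(theta), np.sin(theta)
    rot = np.array([[c, -s, 0.0],
                    [s,  c, 0.0],
                    [0.0, 0.0, 1.0]])
    scale = rng.uniform(*scale_range)
    # The same rotation and scale are applied to every point in the set.
    return scale * points @ rot.T

rng = np.random.default_rng(0)
cloud = rng.normal(size=(100, 3))   # placeholder point cloud
aug = augment_point_cloud(rng, cloud)
```

Since the same transform is applied to every point, the augmentation preserves the shape of the object up to a global rotation and scale.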
We show results for our architectures in Table 15 and additional experiments which used points in Table 16. We trained using the Adam optimizer with an initial learning rate of , which we decayed by a factor of every steps.
Encoder | | Decoder | |
---|---|---|---|
rFF | ISAB | Pooling | PMA |
Architecture | Accuracy |
---|---|
rFF + Pooling | 0.7951 ± 0.0166 |
rFF + PMA (1) | 0.8076 ± 0.0160 |
ISAB (16) + Pooling | 0.8273 ± 0.0159 |
Set Transformer (16) | 0.8454 ± 0.0144 |
rFF + Pooling + tricks (Zaheer et al., 2017) | 0.82 ± 0.02 |