As Yann LeCun said, “if intelligence is a cake, the bulk of the cake is unsupervised learning”. This sentence reflects that un-/self-supervised learning played a central role in the resurgence of deep learning. Common approaches focus on designing different pretext tasks devlin2018bert; wu2018unsupervised; he2019momentum; chen2020generative; chen2020simple; chen2020improved; chen2020big; grill2020bootstrap; caron2020unsupervised and aim to learn useful representations of the input data without relying on human annotations. These representations are then used in downstream tasks, such as image classification, object detection, and semantic segmentation.
In computer vision, previous methods focus on designing different pretext tasks. One of the most promising directions among them is contrastive learning/instance discrimination hjelm2018learning; oord2018representation, which regards each instance in the training dataset as a single category. Based on instance discrimination he2019momentum; chen2020simple; chen2020improved; chen2020big; grill2020bootstrap; caron2020unsupervised, some methods prove effective for image classification and successfully bridge the performance gap between self-supervised and fully supervised methods. However, almost all self-supervised learning methods formulate learning as image-level prediction using global features, and are therefore suboptimal for pixel-level predictions he2019momentum; caron2020unsupervised; grill2020bootstrap, such as object detection and semantic segmentation. Also, InfoMin zhao2021what finds that high-level features do not truly matter in transferring to dense prediction tasks. Current self-supervised learning may thus overfit to image classification while not being well suited to downstream tasks requiring dense prediction.
Meanwhile, large-scale pre-trained models have become the prevailing formula for a wide variety of Natural Language Processing (NLP) tasks due to their impressive empirical performance. These models typically abstract semantic information from massive unlabeled corpora in a self-supervised manner. Masked Language Modeling (MLM) has been widely used as the objective for pre-training language models. In the MLM setup, a certain percentage of tokens within the input sentence are randomly masked, and the objective is to predict the original information of the masked tokens based only on their context. In NLP tasks, we found that the mask strategy used in the MLM framework has a great impact on the performance of the model. In vision, however, images are higher-dimensional, noisier, and more redundant than text. The main information of an input image is randomly distributed across tokens, so masking tokens at random leads to poor performance. Some previous methods mask random tokens, such as iGPT chen2020generative and ViT dosovitskiy2021an. iGPT trains self-supervised Transformers with 6801M parameters and achieves 72.0% Top-1 accuracy on ImageNet by masking and reconstructing pixels, while ViT trains a ViT-B model on the JFT-300M dataset, with a result significantly lower than that of the supervised model.
Random MLM is prone to masking the tokens of crucial image regions, which leads to misunderstanding, so it is not suitable for direct application to self-supervised vision Transformers. To avoid masking the tokens of crucial regions, we propose a masked token strategy based on the multi-head self-attention map, which dynamically masks some patch tokens without damaging the crucial structure for self-supervised learning. Notably, the strategy does not increase the training time. In addition, predicting the original tokens alone may cause the model to over-emphasize local regions and therefore suppress its ability to recognize objects. Hence, in this paper, we present a novel Masked Self-supervised Transformer approach named MST, which can explicitly capture the local context of an image while preserving the global semantic information. A global image decoder is further exploited to recover the spatial information of the image, which makes the approach friendlier to downstream dense prediction tasks.
We validate our method on multiple visual tasks. In particular, on the ImageNet linear evaluation protocol, we reach 76.9% top-1 accuracy with DeiT-S and achieve state-of-the-art performance. Overall, we make the following contributions:
We propose a new masked self-supervised transformer approach called MST. It makes full use of the self-attention map to guide the masking of local patches, thus enhancing the understanding of local context semantics in pre-training without damaging the crucial structure.
Our method can effectively recover the spatial information of the image by a global image decoder, which is vital for the downstream dense prediction task and greatly improves the versatility and scalability of the pre-training model.
Extensive experiments demonstrate the effectiveness and transfer ability of our method. Specifically, the results on ImageNet deng2009imagenet, MS COCO lin2014microsoft and Cityscapes cordts2016the show that our method outperforms previous state-of-the-art methods.
2 Related Works
2.1 Self-supervised visual representation learning
Following the MLM paradigm in NLP devlin2018bert; radford2018improving, iGPT chen2020generative trains self-supervised Transformers by masking and reconstructing pixels, while ViT dosovitskiy2021an masks and reconstructs patches. Recently, the most competitive pretext task for self-supervised visual representation learning is instance discrimination he2019momentum; chen2020simple; chen2020improved; chen2020big; grill2020bootstrap; caron2020unsupervised. The learning objective is simply to learn representations by distinguishing each image from all others, which is quite intractable for large-scale datasets. MoCo he2019momentum improves the training of instance discrimination methods by storing representations from a momentum encoder instead of the trained network. SimCLR chen2020simple shows that the memory bank can be entirely replaced with elements from the same batch if the batch is large enough. To avoid comparing every pair of images, which may incur overfitting, BYOL grill2020bootstrap directly bootstraps the representations by attracting the different features from the same instance. SwAV caron2020unsupervised maps the image features to a set of trainable prototype vectors and proposes multi-crop data augmentation for self-supervised learning to increase the number of views of an image. MoCo v3 chen2021an and DINO caron2021emerging apply the self-supervised learning methods of computer vision to Transformers and achieve superior performance in image classification. These works achieve results comparable to supervised ImageNet deng2009imagenet pre-training. The success of these methods suggests that it is of central importance to learn invariant features by matching positive samples. However, almost all of these self-supervised learning methods formulate the learning process as image-level prediction using global features, so they lack the ability to pay attention to local features.
2.2 Self-supervised dense prediction learning
Building on instance discrimination, some researchers propose self-supervised dense prediction methods. Self-EMD liu2020self adopts Earth Mover’s Distance (EMD) to compute the similarity between two embeddings. Insloc yang2021instance pastes image instances at various locations and scales onto background images; its pretext task is to predict the instance category given the composited images as well as the foreground bounding boxes. PixPro xie2020propagate directly applies contrastive learning at the pixel level. DenseCL wang2020dense presents dense contrastive learning by optimizing a pairwise contrastive loss at the pixel level between two views of the input images. These methods are effective in detection and segmentation tasks but perform poorly on image classification. In short, they overfit to a single task and cannot train a general pre-training model.
The pipeline of our proposed MST is shown in Figure 1. We propose a Masked Self-supervised Transformer (MST) approach, which introduces an attention-guided mask strategy and uses it to complete an image restoration task. Our method is combined with some classical components of instance discrimination, such as the momentum design, asymmetric data augmentations, and multi-crop strategies. We first review the basic instance discrimination method in Section 3.1. Then, the mechanism and effect of our attention-guided mask strategy are explained in Section 3.2. Finally, we give an overview of the reconstruction branch and the training target of our method in Section 3.3.
3.1 The basic instance discrimination method
As noted in prior works chen2020simple; he2019momentum; grill2020bootstrap; wu2018unsupervised; caron2020unsupervised, many existing augmentation policies adopt random resized cropping, horizontal flipping, color jittering and so on. We generate multiple views for each image under random data augmentation according to multi-crop caron2020unsupervised. This operation yields two standard-resolution crops representing the global view and a number of sampled low-resolution crops representing partial views. The views are encoded by two encoders, a teacher network and a student network, each with its own parameters and its own output vector. Both encoders consist of a Transformer backbone and a projection head chen2020big; they share the same architecture but have different parameters. The parameters of the fixed teacher encoder are updated by the moving average of the student parameters according to Eq (1).
Given a fixed teacher network, the student network learns its parameters by minimizing the cross-entropy loss in Eq (2).
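The moving-average teacher update of Eq (1) and the cross-entropy objective of Eq (2) can be sketched in plain Python. This is only an illustration under stated assumptions, not the authors' implementation: the function names, the momentum value, and the two temperatures are hypothetical choices made for the example, and parameters and outputs are flattened into lists of floats.

```python
import math

def ema_update(teacher, student, momentum=0.996):
    """Move each teacher parameter toward the matching student parameter.

    momentum=0.996 is an illustrative value, not taken from the text.
    """
    return [momentum * t + (1.0 - momentum) * s for t, s in zip(teacher, student)]

def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax over a list of floats."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in scaled))
    return [x - log_norm for x in scaled]

def distillation_cross_entropy(teacher_out, student_out, t_temp=0.04, s_temp=0.1):
    """Cross entropy between the teacher and student output distributions.

    The teacher output is treated as a fixed target; both temperatures
    are assumed values.
    """
    p_teacher = [math.exp(v) for v in log_softmax(teacher_out, t_temp)]
    log_p_student = log_softmax(student_out, s_temp)
    return -sum(p * lq for p, lq in zip(p_teacher, log_p_student))
```

In a training loop under these assumptions, the loss would be back-propagated through the student only, and `ema_update` would be called once per step afterwards.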
3.2 Masked token strategy
Random mask strategy. Inspired by the MLM strategy for natural language pre-training, we apply the random mask strategy to self-supervised learning. Consider a dataset without manual annotations in which each image is represented as a sequence of tokens, and let a binary vector with one entry per token represent the mask over the image. Following BERT devlin2018bert, each mask entry is drawn with a fixed probability according to Eq (3), and this probability is 0.15 by default.
According to Eq (3), the tokens of crucial and nonessential regions have the same probability of being masked. As shown in Figure 2 (c), we observe that the random mask strategy may eliminate tokens of crucial regions that are responsible for recognizing objects, resulting in indistinguishable semantic features for input images. The random mask strategy is thus prone to masking crucial image regions, which suppresses the ability of the network to recognize objects. It is not suitable to apply this strategy directly to self-supervised vision Transformers, and the overall performance deteriorates if the mask strategy is not properly modulated.
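As a minimal sketch of the random mask strategy, each token can be masked independently with the default probability 0.15. The function name, the convention that 1 marks a masked token, and the 196-token example (a 14×14 patch grid) are assumptions made for illustration.

```python
import random

def random_mask(num_tokens, prob=0.15, rng=None):
    """Return a 0/1 list; 1 marks a masked token.

    Every token is drawn independently with the same probability, so
    crucial and nonessential regions are equally likely to be masked.
    """
    rng = rng or random.Random()
    return [1 if rng.random() < prob else 0 for _ in range(num_tokens)]

# Example: mask for a hypothetical 14x14 grid of patch tokens.
mask = random_mask(196, prob=0.15, rng=random.Random(0))
```

Because the draw ignores image content, nothing prevents the mask from covering the object itself, which is the failure mode described above.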
Attention-guided mask strategy. In this section, we propose our attention-guided mask strategy, which dynamically controls the fidelity of masked tokens and thereby decreases the probability of masking crucial regions in the self-supervised Transformer. Meanwhile, our strategy adds no extra time consumption. Our algorithm is shown in Alg. 1.
Our framework consists of two networks, a teacher network and a student network, with the same Transformer architecture. The input image is first projected to a sequence of 1-D tokens and then processed by several self-attention layers. Each self-attention layer holds three groups of embeddings for each token: query, key, and value. The attention map is calculated as the correlation between the query embedding of the class token and the key embeddings of all other patches, and is averaged over all heads as in Eq (4). We take the attention map from the last layer of the teacher network to guide our strategy.
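The head-averaged class-token attention can be sketched as follows. The sketch assumes that the per-head correlation is the usual scaled dot-product softmax; the exact form of Eq (4) is not reproduced in this text, and the function and argument names are hypothetical.

```python
import math

def cls_attention_map(cls_queries, patch_keys):
    """Average, over heads, the attention from the class token to each patch.

    cls_queries: list over heads of the class-token query vector.
    patch_keys:  list over heads of lists of patch key vectors.
    Scaled dot-product softmax is assumed as the per-head correlation.
    """
    num_heads = len(cls_queries)
    num_patches = len(patch_keys[0])
    averaged = [0.0] * num_patches
    for query, keys in zip(cls_queries, patch_keys):
        scale = math.sqrt(len(query))
        scores = [sum(a * b for a, b in zip(query, key)) / scale for key in keys]
        peak = max(scores)
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        for i, e in enumerate(exps):
            averaged[i] += e / (total * num_heads)
    return averaged
```

The result holds one score per patch; in this sketch the scores sum to one because each head contributes a normalized distribution.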
We sort the attention values of the patches of each image in ascending order and take the sorted value at a fixed fraction of the total tokens as the threshold. This means that the tokens with the lowest attention, up to that fraction of the total, are selected as the masked candidates. The student model receives the importance of the different patches and generates the mask for these candidates with a given probability, according to the Bernoulli distribution as in Eq (5).
The final masked tokens are defined in Eq (6). Following BERT devlin2018bert, the masked regions are filled with a learnable mask embedding. Our strategy ensures that the patches with the highest scores are always kept (see Figure 2).
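The selection and masking steps above can be sketched as follows. The sketch ranks patch indices directly, which matches thresholding whenever the attention values are distinct. The names `attention_guided_mask` and `apply_mask`, the default candidate fraction, and the default probability are assumptions for illustration, and the mask embedding is represented by a placeholder value rather than a learned vector.

```python
import random

def attention_guided_mask(attention, candidate_fraction=0.125, prob=0.5, rng=None):
    """Mask only low-attention patches.

    attention: per-patch attention scores from the teacher.
    candidate_fraction: share of lowest-attention patches eligible for
        masking (assumed value).
    prob: Bernoulli probability of masking each candidate (assumed value).
    Returns a 0/1 list where 1 marks a masked patch.
    """
    rng = rng or random.Random()
    num_candidates = int(len(attention) * candidate_fraction)
    # Patch indices in ascending order of attention.
    order = sorted(range(len(attention)), key=lambda i: attention[i])
    candidates = set(order[:num_candidates])
    return [1 if i in candidates and rng.random() < prob else 0
            for i in range(len(attention))]

def apply_mask(tokens, mask, mask_embedding):
    """Replace every masked token with the shared mask embedding."""
    return [mask_embedding if m else t for t, m in zip(tokens, mask)]
```

Patches outside the candidate set can never be masked in this sketch, which is how the highest-attention patches stay visible to the student.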
The attention-guided mask strategy can benefit pre-training models in two ways:
The models utilize contextual information to understand the relationship of different patches, thus preserving the global semantic information of the image while paying more attention to the local details of the image.
Our strategy can avoid masking crucial regions while replacing nonessential regions with the learnable mask embedding, making the models focus on the crucial regions.
3.3 Masked self-supervised transformer
In MLM, the unmasked positions form the complementary set of the masked positions. The loss function of the MLM pre-training strategy over one sample is given in Eq (7); it is built from the probability that the network correctly predicts an original token given the masked input. That is, the network only restores the masked tokens.
The masked positions form a sub-sequence of indices in which each index appears independently with the mask probability, and the overall loss function for training the network is given in Eq (8). In pre-training, the MLM strategy minimizes this overall loss over the pre-training dataset.
However, MLM only predicts the masked tokens according to Eq (8). Different from the original MLM, our method encourages the network to reconstruct the original input image. We argue that a pixel-level restoration task keeps the network from overfitting to patch prediction, thereby enhancing its ability to capture pixel-level information and to recover spatial structure at a finer grain. Since convolutional neural networks (CNNs) carry useful inductive biases, the restoration task adopts a CNN as the decoder module, with convolution layers and up-sampling operations alternately stacked. To mitigate the adverse effect as much as possible, each up-sampling operation is restricted to 2×. Hence, a total of 4 operations are needed to reach the full resolution from the 1/16-scale token feature map. The running mean and running variance of BN are only updated from the global crops. The global image decoder consists of the Transformer and the decoder. The restoration task is only performed on the student network. For a decoder with its own parameters, the loss function over an image and a mask is given in Eq (9).
The overall restoration loss for training the network is shown as Eq (10), and we only need the parameters of the student network.
Therefore, the total loss is shown as Eq (11), and the MST minimizes the loss over ImageNet deng2009imagenet dataset in pre-training.
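A small sketch may help tie the pieces together. It combines the two loss terms with the coefficients reported in the training configuration (1.0 for instance discrimination and 0.6 for restoration) and checks the decoder arithmetic of four 2× up-sampling steps. The use of a mean absolute pixel error for the restoration term and the 14×14 token grid are assumptions; the exact distance in Eq (9) is not reproduced in this text.

```python
def restoration_loss(original, reconstructed):
    """Mean absolute pixel error between an image and its reconstruction.

    An L1 distance is assumed here purely for illustration.
    """
    return sum(abs(a - b) for a, b in zip(original, reconstructed)) / len(original)

def total_loss(ce_loss, rest_loss, ce_weight=1.0, rest_weight=0.6):
    """Weighted sum of the instance-discrimination and restoration terms."""
    return ce_weight * ce_loss + rest_weight * rest_loss

def upsampled_size(token_grid, steps=4, factor=2):
    """Side length after `steps` up-sampling operations of `factor` each."""
    return token_grid * factor ** steps

# A hypothetical 14x14 token grid reaches 224x224 after four 2x steps.
full_resolution = upsampled_size(14)
```

Only the student receives gradients from this combined objective in the setup described above.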
Several experiments with MST are conducted in this section. We first train self-supervised models with different transformer architectures on ImageNet benchmark, and then examine their transfer capacity with downstream tasks like object detection and semantic segmentation. After that, ablation studies are introduced to elaborate on how our method could achieve state-of-the-art performance.
4.1 Pre-training settings
Dataset and Models. Our method is validated on the popular ImageNet 1k dataset deng2009imagenet. This dataset contains 1.28M images in the training set and 50K images in the validation set from 1000 classes. We only use the training set during self-supervised learning. As to models, we choose the classical DeiT-S deit2020 and the popular Swin-T liu2021swin as representatives of transformer-based architectures. After the backbone, a 3-layer MLP with hidden dimension 2048 is added as the projection head. When evaluating our pretrained model, we use both the k-NN algorithm and a linear classifier trained for 100 epochs, as in former works. Top-1 accuracy is reported.
Training Configurations. Our model is optimized by AdamW loshchilov2018decoupled with learning rate and batch size 1024. Weight decay is set to 0.04. We adopt learning rate warmup goyal2017accurate in the first 10 epochs, and after warmup the learning rate follows a cosine decay schedule loshchilov2016sgdr. The model uses multi-crop similar to caron2020unsupervised and data augmentations similar to grill2020bootstrap. The setting of momentum, temperature coefficient, and weight decay follows caron2021emerging. The coefficient of the basic instance discrimination task is set to 1.0, while that of the restoration task is set to 0.6.
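The warmup-then-cosine schedule described here can be sketched as a function of the epoch index. The linear shape of the warmup, the zero final learning rate, and the function name are assumptions; `base_lr` is left as a parameter because the learning rate value is not given in this text.

```python
import math

def learning_rate(epoch, base_lr, total_epochs, warmup_epochs=10, min_lr=0.0):
    """Linear warmup over the first epochs, then cosine decay to min_lr.

    epoch is zero-based; base_lr is the peak value reached after warmup.
    """
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# Per-epoch values for a 100-epoch run with a placeholder peak rate of 1.0.
schedule = [learning_rate(e, 1.0, 100) for e in range(100)]
```

In practice such a schedule is often evaluated per iteration rather than per epoch; the per-epoch form is used here only to keep the sketch short.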
4.2 Compared with other methods on ImageNet
We compare our method with other prevailing algorithms in Table 1. All these methods share the same backbone for fair comparison. Our 300-epoch model achieves 76.9% top-1 accuracy with linear probing. It outperforms the previous best algorithm, DINO, by 1.7% at the same number of training epochs, and even approaches the performance of DINO with a much longer training schedule (77.0% with 800 epochs). It should be emphasized that our algorithm relieves the need for extremely long training in self-supervised learning, and is able to obtain a decent result (75.0%) with only 100 epochs.
MST is general and can be applied to other transformer-based architectures. Here we use the popular Swin-T as an example. It has a similar number of parameters to DeiT-S. With the same number of training epochs, MST outperforms MoBY, a self-supervised learning method designed specifically for Swin-T, by 1.8%. Swin-T shares the same hyperparameters with DeiT-S; therefore it can still be improved by further tuning.
4.3 Object detection and instance segmentation
Since Swin-Transformer achieves state-of-the-art results under supervised training, it is adopted as the backbone to validate the transfer ability of our method on object detection and instance segmentation. We perform object detection experiments on the MS COCO lin2014microsoft dataset with the Mask R-CNN he2017mask detection framework. MS COCO is a popular benchmark for object detection, with 118K images in the training set and 5K images for validation. This dataset contains annotations for 81 classes. Box AP and mask AP are reported on the validation set. As to training settings, we follow the default 1x schedule with 12 epochs. The shorter edges of the input images are resized to 800 pixels and the longer edges are limited to 1333 pixels. The AdamW optimizer is used, and all hyper-parameters follow the original paper.
In Table 2, we show the performance of the representations learned by different self-supervised methods and by supervised training. For fair comparison, all these methods are pre-trained for 100 epochs. We observe that our method achieves the best results in both bbox mAP and mask mAP. It outperforms the ImageNet supervised model by 1.2% and 0.5%, and MoBY results by 1.2% and 0.5% with the same number of epochs. The results indicate that MST performs well not only on image classification but also on downstream dense prediction tasks. Therefore it has a strong transfer ability.
|box AP||mask AP|
4.4 Semantic segmentation
SETR zheng2020rethinking provides a semantic segmentation framework for the standard Vision Transformer. Hence, we adopt SETR as the semantic segmentation strategy on Cityscapes cordts2016the. Cityscapes contains 5000 images, with 19 object categories annotated at the pixel level. There are 2975, 500, and 1525 images in the training, validation, and testing sets respectively. We follow the training configuration of the original SETR. For fair comparison, we use the 300-epoch pretrained model for both DINO and our method.
Table 3 compares the supervised method, DINO, and our method on this evaluation. Our method achieves the highest mIoU 74.7% and mAcc 82.35%. It outperforms both the supervised results (+2.71% mIoU and +2.05% mAcc) and the DINO pretrained results (+1.08% mIoU and +1.03% mAcc). Our model therefore also transfers well to the semantic segmentation task.
4.5 Ablation studies
In this section, we conduct some ablation studies to elaborate on the effectiveness of our method. All ablation experiments are conducted under the 100-epoch setting. By default, only the cls token from the last layer is used to train the linear classifier.
4.5.1 Impact of different mask strategy
Table 5 shows the impact of different mask strategies. We train DeiT-S with the random mask strategy devlin2018bert, with the attention-guided mask strategy, and with no mask. For fair comparison, all methods mask with the same probability. The performance of the random mask strategy degrades (from 73.1 to 63.2), since this strategy probably suppresses the ability to recognize the object in the image. The random mask strategy may destroy tokens in crucial regions of the original image that are indispensable for recognizing the object, so the masked input may carry incomplete or even misleading information. On the contrary, our attention-guided mask strategy yields a steady improvement (from 73.1 to 73.7). Essential regions are mostly preserved, which supports our hypothesis.
4.5.2 Impact of different mask hyper-parameters
Table 5 reports the performance of different mask hyper-parameters under the attention-guided mask strategy. We sort the attention values of the patches of each image in ascending order and take the first, lowest-attention patches as the masked candidates. Removing these candidates forces the network to learn local features from adjacent patches, which strengthens its capacity to model local context without destroying the semantics. Each candidate is then masked with a given probability. Top-1 accuracy of linear evaluation on ImageNet is shown in Table 5. When the hyper-parameter that sets the number of candidates is 8, any choice of the mask probability gives a robust result, which suggests that the lowest-attention patches are relatively safe mask candidates.
|Mask Strategy||Top-1 acc (%)|
4.6 Impact of w/o BN
Former work caron2021emerging found that performance improves when BN is dropped from the projection head. We argue that the degradation is not caused by BN itself. As shown in Table 6, normal BN degrades the performance of the baseline model, while the update rule introduced in Section 3.3 helps improve top-1 accuracy slightly. This may be due to the need to keep a structure consistent with the global image decoder, since the image decoder consists of Conv-BN-ReLU.
Table 6: Impact of Batch Normalization.
|w/o BN||w/ BN|
In this paper, we investigate the two problems of current visual self-supervised learning, namely lack of local information extraction and loss of spatial information. To overcome the above problems, we propose a new self-supervised learning method based on transformer called MST. The proposed MST exploits an attention-guided mask strategy to capture the local relationships between patches while also preserving the global semantic information. It is noted that the attention-guided mask strategy is based on the multi-head self-attention map extracted from the teacher model and does not cause extra computation cost. In addition, a global image decoder is further used to assist the attention-guided mask strategy to recover the spatial information of the image, which is vital for dense prediction tasks. The proposed method shows good versatility and scalability in multiple downstream visual tasks.
Appendix A The setting of computation resources
In the ablation studies, MST is trained with a batch of 1024 images on 128 AMD DCUs that are publicly available in Sugon Cloud. To verify the generality of the results, the pre-trained model is validated in downstream experiments on 32 Nvidia Tesla V100 GPUs. The same random seed is set for fair comparison, and we report the average result over multiple runs.
Appendix B Data augmentation
The image augmentation pipeline consists of the following transformations: random resized cropping, horizontal flipping, color jittering, grayscale conversion, Gaussian blurring, solarization, and multi-crop. The random resized cropping and multi-crop transformations are always applied, while the remaining transformations are applied randomly, with some probability. For the blurring and solarization transformations, this probability differs between the two distorted views. We use the same augmentation parameters as BYOL, except for multi-crop. The multi-crop follows SwAV caron2020unsupervised and DINO caron2021emerging. Each input image is transformed twice to produce the two distorted views.
Appendix C BatchNorm
Following chen2020big; grill2020bootstrap; chen2021an, we adopt SyncBN as our default BatchNorm. In MST, the running mean and running variance of BN are updated only from different images in the same batch, whereas in SimCLR chen2020simple they are updated from all images in the teacher and student batches. The two kinds of BN influence the gradient variance, so the two implementations should lead to different results. Meanwhile, the running mean and running variance of BN are only updated from the global crops when our method adopts the masked self-supervised Transformer.
Appendix D k-NN classification
Following Wu et al. wu2018unsupervised, we evaluate the quality of features with a simple weighted nearest-neighbor classifier. We freeze the parameters of the pre-trained model and extract the class-embedding features for the training and validation datasets. As shown in Table 7, we evaluate different values of k and find that k = 10 consistently leads to the best accuracy across our runs. Top-1 accuracy is evaluated on the validation dataset.
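A weighted nearest-neighbor vote of this kind can be sketched as below. The sketch assumes L2-normalized features, so that the dot product is the cosine similarity, and it weights each neighbor by an exponential of its similarity; the temperature value and the function name are assumptions, and only k = 10 comes from the text.

```python
import math
from collections import defaultdict

def weighted_knn_predict(query, train_features, train_labels, k=10, temperature=0.07):
    """Predict a label by similarity-weighted voting among the k nearest features.

    Features are assumed L2-normalized; temperature=0.07 is an assumed value.
    """
    sims = [sum(a * b for a, b in zip(query, feat)) for feat in train_features]
    nearest = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)[:k]
    votes = defaultdict(float)
    for i in nearest:
        votes[train_labels[i]] += math.exp(sims[i] / temperature)
    return max(votes, key=votes.get)
```

The exponential weighting lets a few very close neighbors outvote a larger number of distant ones, which is the main difference from a plain majority vote.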
|Method||Architecture||epoch||k||k-NN Top-1 (%)|
Appendix E Linear probing
Following the popular setting of self-supervised learning, we evaluate the representation quality by linear probing. After self-supervised pre-training, we remove the MLP heads and train a supervised linear classifier on frozen features. We use the SGD optimizer with a batch size of 1024, weight decay of 0, and learning rate of 0.00024 for 100 epochs on the ImageNet training dataset, using only random resized cropping and flipping augmentation. We evaluate single-crop Top-1 accuracy on the validation dataset. For the linear probing of DeiT-S, we adopt the class token of the last layer as the input, following common practice. DINO caron2021emerging, however, concatenates the last few blocks as the input to the linear classifier. For fair comparison, we adopt the linear probing of DINO for the final result, while reporting common linear probing in the ablation studies. The results are shown in Table 8.
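The two probing inputs differ only in how many blocks contribute their class token. A minimal sketch, assuming the class token of every block is available as a list of floats; the function name and the choice of four blocks in the example are illustrative assumptions.

```python
def probe_input(block_cls_tokens, last_n=1):
    """Concatenate the class tokens of the last `last_n` blocks.

    last_n=1 gives the common setting (last-layer class token only);
    a larger value mimics concatenating the last few blocks.
    """
    feature = []
    for token in block_cls_tokens[-last_n:]:
        feature.extend(token)
    return feature

# Example with six hypothetical blocks and 2-dimensional class tokens.
tokens = [[float(b), float(b) + 0.5] for b in range(6)]
common = probe_input(tokens, last_n=1)
concatenated = probe_input(tokens, last_n=4)
```

The linear classifier trained on `concatenated` has a proportionally wider input than the one trained on `common`, so the two accuracies are not directly interchangeable.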
|Method||Architecture||epoch||Linear Top-1 (%)||k-NN Top-1 (%)|
Appendix F Impact of longer training
From Table 9, we observe that longer training improves the performance of our method with DeiT-S regardless of the kind of linear probing. This phenomenon is consistent with previous self-supervised learning methods.
Appendix G Implementation pseudo code
The complete algorithm of our method is shown in Alg. 2. Our model is optimized by AdamW loshchilov2018decoupled with learning rate and batch size 1024. The initial weight decay is set to 0.04. After warmup goyal2017accurate in the first 10 epochs, the learning rate follows a cosine decay schedule loshchilov2016sgdr. The model uses multi-crop similar to caron2020unsupervised and data augmentations similar to grill2020bootstrap. The setting of momentum, temperature coefficient, and weight decay follows caron2021emerging. The coefficient of the basic instance discrimination task is set to 1.0, while that of the restoration task is set to 0.6.
Appendix H Visualization of the attention maps
Figure 3 shows the attention maps of the supervised method and of our method. Each group consists of the original image, the attention map of the supervised method, and the attention map of our method. We observe that the attention maps of our method are clearer than those of the supervised method.