Automatic Shortcut Removal for Self-Supervised Representation Learning

02/20/2020, by Matthias Minderer et al.

In self-supervised visual representation learning, a feature extractor is trained on a "pretext task" for which labels can be generated cheaply. A central challenge in this approach is that the feature extractor quickly learns to exploit low-level visual features such as color aberrations or watermarks and then fails to learn useful semantic representations. Much work has gone into identifying such "shortcut" features and hand-designing schemes to reduce their effect. Here, we propose a general framework for removing shortcut features automatically. Our key assumption is that those features which are the first to be exploited for solving the pretext task may also be the most vulnerable to an adversary trained to make the task harder. We show that this assumption holds across common pretext tasks and datasets by training a "lens" network to make small image changes that maximally reduce performance in the pretext task. Representations learned with the modified images outperform those learned without in all tested cases. Additionally, the modifications made by the lens reveal how the choice of pretext task and dataset affects the features learned by self-supervision.


1 Introduction

Figure 1: Example of automatic shortcut removal for the Rotation prediction pretext task. The lens learns to remove features that make it easy to solve the pretext task (concretely, it conceals watermarks in this example). Shortcut removal forces the network to learn higher-level features to solve the pretext task and improves representation quality.

In self-supervised visual representation learning, a neural network is trained using labels that can be generated cheaply from the data rather than requiring human labeling. These artificial labels are used to create a “pretext task” that ideally requires learning abstract, semantic features useful for a large variety of vision tasks. A network pre-trained on the pretext task can then be transferred to other vision tasks for which labels are more expensive to obtain, e.g. by learning a head or fine-tuning the network for the target task.

Self-supervision has recently led to significant advances in unsupervised visual representation learning, with the first self-supervised methods outperforming supervised ImageNet pre-training on selected vision benchmarks (Hénaff et al., 2019; He et al., 2019; Misra and van der Maaten, 2019).

Yet, defining sensible pretext tasks remains a major challenge because neural networks are biased towards exploiting the simplest features that allow solving the pretext task. This bias works against the goal of learning semantically meaningful representations that transfer well to a wide range of target tasks. Simple solutions to pretext tasks can be unintuitive and surprising. For example, in the self-supervised task of predicting the orientation of rotated training images (Gidaris et al., 2018), logos and watermarks allow the network to predict the orientation of the input image by learning simple text features, rather than transferable object representations (Figure 1). Similarly, nearly imperceptible color fringes introduced by chromatic aberrations of camera lenses provide a signal for context-based self-supervised methods that is strong enough to significantly reduce the quality of representations learned from these tasks, unless they are specifically addressed by augmentation schemes (Doersch et al., 2015; Noroozi and Favaro, 2016).

Many such data augmentation procedures have been proposed, but they have relied on the intuition and creativity of researchers to identify shortcut features. We aim to break this pattern and propose a simple method to remove shortcuts automatically. The key insight underlying our method is that the visual features a network most easily exploits to solve a pretext task may also be the features an adversary can most easily manipulate to make the task harder. We therefore propose to process images with a lightweight image-to-image translation network, called a "lens" (borrowing the terminology from Sajjadi et al., 2018), which is trained adversarially to reduce performance on the pretext task without deviating much from the original image (we address the unwanted removal of potentially useful features in Section 3.1). Once trained, the lens can be applied to unseen images, so it can be used downstream when transferring the network to a new task. We show that the lens leads to significant improvements across a variety of pretext tasks and datasets. Furthermore, the lens can be used to visualize shortcuts by inspecting the difference between input and output images. The changes made by the lens provide insights into how shortcuts differ across tasks and datasets, which we use to make suggestions for future task design.

In summary, we make the following contributions:


  • We propose a simple and general method for automated removal of shortcuts which can be used with virtually any pretext task.

  • We validate the proposed method on a wide range of pretext tasks and on two different training datasets (ImageNet and YouTube-8M frames), showing consistent improvements across all methods, upstream training datasets, and two downstream/evaluation datasets (ImageNet and Places205). In particular, our method can replace preprocessing procedures that were hand-engineered to remove shortcuts.

  • We use the lens to compare shortcuts across different pretext tasks and datasets. This analysis provides intriguing insights into dataset-specific and pretext-task-specific shortcuts.

2 Related work

Self-supervised learning.

Self-supervised learning (SSL) has attracted increasing interest in the computer vision and machine learning community over the past few years. Early approaches (pretext tasks) include exemplar classification (Dosovitskiy et al., 2014), predicting the relative location of image patches (Doersch et al., 2015), solving jigsaw puzzles of image patches (Noroozi and Favaro, 2016), image colorization (Zhang et al., 2016), object counting (Noroozi et al., 2017), predicting the orientation of images (Gidaris et al., 2018), and clustering (Caron et al., 2018). These methods are typically trained on ImageNet, and their performance is evaluated by training a linear classification head on the frozen representation.

New pretext tasks are usually proposed along with procedures to mitigate the effect of shortcuts. Common shortcuts include color aberrations (Doersch et al., 2015), re-sampling artifacts (Noroozi et al., 2017), compression artifacts (Wei et al., 2018), and salient lines or grid patterns (Doersch et al., 2015; Noroozi and Favaro, 2016). To counter the effects of these shortcuts, various pre-processing strategies have been developed, for example channel dropping (Doersch et al., 2015; Doersch and Zisserman, 2017), channel replication (Lee et al., 2017), conversion to gray scale (Noroozi et al., 2017), chroma blurring (Mundhenk et al., 2018), and spatial jittering (Doersch et al., 2015; Mundhenk et al., 2018).

One of the premises of SSL is that it can be applied to huge data sets for which human labeling would be too expensive. Goyal et al. (2019) explore this aspect and find that large-scale SSL can outperform supervised pretraining. Another research direction is to combine several pretext tasks into one, which often improves performance (Doersch and Zisserman, 2017; Feng et al., 2019; Misra and van der Maaten, 2019). Using multiple pretext tasks could reduce the effect of shortcuts that are not shared across tasks. These efforts may benefit from our comparison of shortcuts across tasks.

More recently, contrastive methods gained popularity (Oord et al., 2018; Hjelm et al., 2019; Bachman et al., 2019; Tian et al., 2019; Hénaff et al., 2019; He et al., 2019). These methods are based on the principle of learning a representation which associates different views (e.g. different augmentations) of the same training example with similar embeddings, and views of different training examples with dissimilar embeddings (Tschannen et al., 2020). There are numerous parallels between pretext task-based and contrastive methods (He et al., 2019), and our method in principle applies to both types of approaches.

Adversarial training.

Our method is related to adversarial training, and we give a brief overview of works relevant to this paper (see Akhtar and Mian (2018) for a survey). Adversarial examples are small, imperceptible perturbations to the input of a classifier that lead to a highly confident misclassification of the input (Szegedy et al., 2014). Deep neural networks are particularly susceptible to adversarial perturbations. A plethora of adversarial training techniques has been proposed to make them more robust, e.g. the fast gradient sign method (FGSM; Goodfellow et al., 2015) or the projected gradient descent defense (Madry et al., 2018).

Adversarial training can significantly improve semi-supervised learning when combined with label propagation (Miyato et al., 2018). However, only very recently did Xie et al. (2019) succeed in developing an adversarial training procedure that substantially improves classification accuracy on unperturbed images in the context of (fully) supervised learning. Somewhat related, Santurkar et al. (2019) present evidence that adversarially trained classifiers learn more abstract semantic features. We emphasize that all these works use hand-annotated ground-truth labels carrying much more information than pretext labels, and we believe that adversarial training has even greater potential for SSL.

3 Method

Figure 2: Model schematic. For the experiments in this paper, we use the U-Net architecture for the lens and the ResNet50 v2 architecture for the feature extractor. We use a simple pixel-wise reconstruction loss for simplicity, but other choices are possible.

We propose to improve self-supervised visual representations by processing images with a lightweight image-to-image translation network, or “lens”, that is trained adversarially to reduce performance of the feature extractor network on the pretext tasks. In this section, we outline our approach in more detail. We start by defining the notion of “shortcut” visual features for the purpose of this study.

3.1 Which visual features are shortcuts?

Intuitively, shortcut features may be defined in terms of both the pretext task and the downstream application of the learned representations: Shortcuts (i) allow solving the pretext task quickly and accurately by exploiting low-level visual features, but (ii) are useless for downstream applications and prevent learning of more useful representations.

Item (i) of the definition is easy to address: We can leverage an adversarial loss to learn to remove visual features that allow solving the pretext task quickly and accurately.

However, in a typical transfer learning scenario the downstream tasks are assumed unknown. It is therefore impossible to know a priori whether a feature is a "true" shortcut (such as the watermark in Figure 1, if the downstream task is ImageNet classification), or is useful downstream despite being an easy solution to the pretext task (such as the eyes of the dogs in Figure 4). In fact, for any potential shortcut, a downstream task could be conceived that depends on this feature (e.g. watermark detection). To address this issue (item (ii) above), we always combine representations obtained with and without shortcut removal, as described below. This ensures that automatic shortcut removal never reduces the information present in the representations. We can then define "shortcut" more leniently as any feature that allows the pretext task to be solved easily.

3.2 Automatic adversarial shortcut removal

We start by formalizing the common setup for pretext task-based SSL, and then describe how we modify this setup to prevent shortcuts.

In pretext task-based SSL, a neural network $F$ (sometimes also called "encoder" or "feature extractor") is trained to predict machine-generated pretext targets $y_i$ from inputs $x_i$. The pretext task is learned by minimizing a loss function $\mathcal{L}_F$ which usually takes the form $\mathcal{L}_F = \frac{1}{N} \sum_{i=1}^{N} \ell(F(x_i), y_i)$, where $N$ is the number of training examples and $\ell$ is often a cross-entropy loss.

To remove shortcuts, we introduce a lens network $L$ that (slightly) modifies its inputs and maps them back to the input space, before feeding them to the representation network $F$. When using the lens, the pretext task loss becomes $\mathcal{L}_F = \frac{1}{N} \sum_{i=1}^{N} \ell(F(L(x_i)), y_i)$, and $F$ is trained to minimize this loss, as before. As motivated in Section 3.1, we train the lens adversarially against $F$ to increase the difficulty of the pretext task. We consider two loss variants that were previously considered in the adversarial training literature (Kurakin et al., 2016): full and least likely.

The full adversarial loss is simply the negative task loss: $\mathcal{L}_L^{\mathrm{adv}} = -\mathcal{L}_F$. This type of adversarial training is applicable to any pretext task loss.

For classification pretext tasks, we can alternatively train the lens to bias the predicted class probabilities towards the least likely class $y_i^{\mathrm{LL}} = \arg\min_y \, p_F(y \mid L(x_i))$. The loss hence becomes $\mathcal{L}_L^{\mathrm{adv}} = \frac{1}{N} \sum_{i=1}^{N} \ell\big(F(L(x_i)), y_i^{\mathrm{LL}}\big)$.

The lens is also trained with a reconstruction loss to avoid trivial solutions: $\mathcal{L}_L = \mathcal{L}_L^{\mathrm{adv}} + \frac{\lambda}{N} \sum_{i=1}^{N} \mathcal{R}\big(L(x_i), x_i\big)$, where $\mathcal{R}$ is a pixel-wise reconstruction loss and $\lambda$ is a hyperparameter that trades off the strength of the adversarial attack against reconstruction quality.
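To make the objectives concrete, the following PyTorch-style sketch shows one possible way to implement them for a classification pretext task such as Rotation. It uses a simple alternating update and an MSE stand-in for the pixel-wise reconstruction loss; the names (`lens`, `feature_net`, `recon_scale`) are ours, and the original implementation (which trains both networks synchronously on TPUs) differs in detail.

```python
import torch
import torch.nn.functional as F


def training_step(lens, feature_net, x, y, opt_f, opt_lens, recon_scale):
    """One alternating update for the feature extractor and the adversarial lens.

    lens:        image-to-image network L (small U-Net)
    feature_net: feature extractor plus pretext classification head
    x, y:        batch of images and machine-generated pretext labels
    recon_scale: the reconstruction weight lambda
    """
    # Update the feature extractor: minimize the pretext loss on lensed images.
    x_lensed = lens(x)
    task_loss = F.cross_entropy(feature_net(x_lensed.detach()), y)
    opt_f.zero_grad()
    task_loss.backward()
    opt_f.step()

    # Update the lens: maximize the pretext loss ("full" adversarial variant)
    # while staying close to the input via a pixel-wise reconstruction term
    # (MSE used here as a stand-in).
    x_lensed = lens(x)
    adv_loss = -F.cross_entropy(feature_net(x_lensed), y)
    recon_loss = F.mse_loss(x_lensed, x)
    lens_loss = adv_loss + recon_scale * recon_loss
    opt_lens.zero_grad()
    lens_loss.backward()
    opt_lens.step()
    return task_loss.item(), lens_loss.item()
```

The "least likely" variant would replace the negated cross-entropy with a cross-entropy towards the class currently predicted to be least likely.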

3.3 Hyperparameters and design choices

Before presenting a comprehensive experimental evaluation of the proposed method, we first discuss major hyperparameters and design choices, and compare our method to standard adversarial training.

Reconstruction loss.

We use a pixel-wise loss for $\mathcal{R}$ due to its simplicity and stability. Other choices are possible and interesting, in particular losses that go beyond pixel-wise similarity and measure semantic similarity, such as feature matching (Salimans et al., 2016) or perceptual losses (Zhang et al., 2018). However, note that these losses themselves often require supervised pretraining, hence defeating the purpose of (unsupervised) SSL. We discuss the effect of $\mathcal{R}$ in the context of different pretext tasks in Section 4.3.

Selection of $\lambda$.

The reconstruction loss scale $\lambda$ is the most important hyperparameter for lens performance and the only one that we vary between tasks. The optimal value of $\lambda$ depends primarily on the scale of the pretext task loss, but also on the dataset and the data augmentation applied before the examples are fed to the lens.

Network architectures.

For the feature extraction network $F$, we use the default ResNet50 v2 architecture (He et al., 2016) (i.e. with a channel widening factor of 4 and a representation size of 2048 at the pre-logits layer) unless otherwise noted. For the lens $L$, we use a variant of the U-Net architecture (Ronneberger et al., 2015). The encoder and decoder are each a stack of residual blocks (ResNet50 v2 residual blocks) with down-sampling and up-sampling, respectively, after each block (a complete description can be found in the supplementary material). This lens architecture has 3.87M parameters, less than one sixth of the ResNet50 v2 network used to compute the representation (23.51M).

We emphasize that, besides the choice of the reconstruction loss, the structure of the lens architecture also matters for the type of visual features that are removed: its capacity determines the type and abstractness of the visual features the lens can manipulate. We deliberately chose an architecture with skip connections to make it easy for the lens to learn an identity mapping that minimizes the reconstruction loss.

Comparison with standard adversarial training.

One might wonder why we do not simply use standard adversarial training methods such as the FGSM (Goodfellow et al., 2015) instead of the lens (iterative techniques such as projected gradient descent (Madry et al., 2018) are too expensive for our purposes). Besides outperforming the FGSM empirically (see Section 4), the lens has several other advantages. First, standard adversarial training requires a loss to generate a perturbation, whereas our lens can be used independently of the pretext task used during training. Hence, the lens can be deployed downstream even when the pretext task loss and/or the feature extraction network are unavailable. Second, training the lens is of similar complexity to one-step adversarial training techniques, but deploying it downstream is much cheaper, as its application only requires a single forward pass through the shallow lens network (and not a forward and backward pass through the representation network as in adversarial training). The same advantages apply to visualization of the lens action. Finally, we believe that the lens learns to exploit similarities and structure shared between shortcuts, as it accumulates signal from all training examples.

4 Experiments

4.1 Proof of concept: Removing synthetic shortcuts

Figure 3: Top: Example images from CIFAR-10 with two different synthetic shortcuts for the Rotation task (best viewed on screen). The Arrow shortcut adds directional information to a small region of the image (note the arrow in the top left of the image). The Chromatic aberration shortcut shifts the color channels, adding directional information globally. The lens learns to remove the shortcuts. Bottom: Downstream test accuracy (in %) of a logistic regression model trained for CIFAR-10 classification on the frozen representations; accuracy on clean data is included for reference.

Our approach rests on the hypothesis that features which are easy to exploit for solving self-supervised tasks are also easy to learn for the adversarial lens to remove. We initially test this hypothesis in a controlled experimental setup, by adding synthetic shortcuts to the data.

We use the CIFAR-10 dataset and the Rotation pretext task (Gidaris et al., 2018). In this task, each input image is fed to the network in four copies, rotated by 0°, 90°, 180° and 270°, respectively. The network is trained to solve the four-way classification task of predicting the correct orientation of each image. The task is motivated by the hypothesis that the network needs to learn high-level object representations to predict image rotations.
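As an illustration of the task setup (not the reference implementation), the four rotated copies and their labels could be generated as follows; `make_rotation_batch` is a hypothetical helper.

```python
import torch


def make_rotation_batch(images):
    """Create a Rotation pretext task batch.

    images: tensor of shape (B, C, H, W).
    Returns four rotated copies of each image (0, 90, 180, 270 degrees)
    and the corresponding 4-way classification labels.
    """
    rotations = [torch.rot90(images, k, dims=(2, 3)) for k in range(4)]
    x = torch.cat(rotations, dim=0)                         # (4B, C, H, W)
    y = torch.arange(4).repeat_interleave(images.size(0))   # label = number of 90 degree turns
    return x, y
```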

Table 1: Evaluation of representations from models trained on ImageNet with different self-supervised pretext tasks (columns: Rotation, Exemplar, Relative patch location, Jigsaw; rows: Baseline, FGSM, and Lens, evaluated downstream on ImageNet and Places205). The scores are accuracies (in %) of a logistic regression model trained on representations obtained from the frozen models, reported as mean ± s.e.m. over three random initializations. Values in bold are better than the next-best method at a significance level of 0.05. Training images are preprocessed as suggested by the respective original works.

To test the lens, we add shortcuts to the input images that are designed to contain directional information and allow solving the Rotation task without the need to learn object-level features (Figure 3, top). Representations learned by the baseline network (without lens) from data with synthetic shortcuts perform poorly downstream (Figure 3, bottom). In contrast, feature extractors learned with the lens perform dramatically better. The lens output images reveal that the lens learns to remove the synthetic shortcuts. These results confirm our hypothesis that an adversarially trained lens can learn to remove shortcut features from the data.

4.2 Large-scale evaluation of the lens performance

To test the value of the lens as a general framework for improving self-supervised representation learning, we next evaluate the method on two large-scale datasets and four common pretext tasks for which open-source reference implementations are available (Kolesnikov et al., 2019; https://github.com/google/revisiting-self-supervised).

4.2.1 Pretext tasks

In addition to the Rotation task described above, we use the following pretext tasks:

Exemplar (Dosovitskiy et al., 2014): In the Exemplar task, eight copies of each image are created and augmented separately using random translation, scaling, brightness, and saturation changes, including conversion to grayscale with probability 0.66. The network is trained using a triplet loss (Schroff et al., 2015) to minimize the embedding distance between copies of the same image, while maximizing their distance to the other images in the batch.

Relative patch location (Doersch et al., 2015): Here, the task is to predict the relative location of an image patch in the 8-connected neighborhood of a reference patch from the same image (e.g. "below", "upper left", etc.).

Jigsaw (Noroozi and Favaro, 2016): For the Jigsaw task, the image is divided into a three-by-three grid of patches. The patch order is randomly permuted and the patches are fed through the network. The representations produced by the network are then passed through a two-layer perceptron, which predicts the patch permutation.

For the patch-based tasks, we obtain representations for evaluation by averaging the representations of nine non-augmented patches created by dividing the input image into a regular three-by-three grid (Kolesnikov et al., 2019).
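A minimal sketch of this evaluation-time pooling, assuming a feature extractor that maps a single patch to a representation vector; the helper name and the exact cropping are our own simplifications.

```python
import torch


def patch_grid_representation(feature_net, image, grid=3):
    """Average the representations of a grid of non-augmented patches.

    image: tensor of shape (C, H, W); H and W are assumed divisible by `grid`.
    Returns the averaged representation of shape (1, D).
    """
    c, h, w = image.shape
    ph, pw = h // grid, w // grid
    reps = []
    for i in range(grid):
        for j in range(grid):
            patch = image[:, i * ph:(i + 1) * ph, j * pw:(j + 1) * pw]
            reps.append(feature_net(patch.unsqueeze(0)))   # (1, D)
    return torch.stack(reps, dim=0).mean(dim=0)            # (1, D)
```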

4.2.2 Pretraining datasets and preprocessing

Self-supervised training is performed on ImageNet, which contains 1.3 million images, each belonging to one of 1000 object categories. Unless stated otherwise, we use the same preprocessing operations and batch size as Kolesnikov et al. (2019) for the respective tasks. To mitigate distribution shift between raw and lens-processed images, we feed both the batch of lens-processed and the raw images to the feature extraction network (Kurakin et al. (2016) similarly feed processed and raw images for adversarial training). This is done for all tasks except for Exemplar, for which memory constraints did not allow inclusion of unprocessed images. Training was performed on 128 TPU v3 cores for Rotation and Exemplar and 32 TPU v3 cores for Relative patch location and Jigsaw.

4.2.3 Pretraining hyperparameters

The feature extractor and lens are trained synchronously using the Adam optimizer for 35 epochs. The learning rate is linearly ramped up from zero to its peak value in the first epoch, held constant until the end of the 32nd epoch, and then linearly decayed to zero.

For each pretext task, we initially determined the approximate magnitude of $\lambda$ based on the magnitude of the pretext task loss. We then sweep over five values of $\lambda$, centered on this initial value, and report the accuracy for the best $\lambda$.

4.2.4 Evaluation protocol

To evaluate the quality of the representations learned by the feature extractor, we follow the common linear evaluation protocol (Dosovitskiy et al., 2014; Doersch et al., 2015; Noroozi and Favaro, 2016; Gidaris et al., 2018): We obtain image representations from the frozen feature extractor and use them to train a logistic regression model to solve multi-class image classification tasks. For networks trained with the lens, we concatenate features from the pre-logits layer obtained with and without applying the lens to the input image (see Section 3.1). To ensure a fair comparison, for baseline models (trained without lens), we concatenate two copies of the pre-logits features to match the representation size of the lens networks. Note that the representation size determines the number of free parameters of the logistic regression model, and we observed benefits in logistic regression accuracy (possibly due to a change in optimization dynamics) even though no new features are added by copying them. The logistic regression model is optimized by stochastic gradient descent (see supplementary material).
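The feature concatenation described above (and the duplication trick for baseline models) could look roughly as follows; the function is a hypothetical sketch, not the evaluation code used in the paper.

```python
import torch


@torch.no_grad()
def downstream_features(feature_net, lens, images):
    """Concatenate pre-logits features obtained with and without the lens.

    For baseline models (lens is None), the plain features are duplicated so
    that the linear classifier has the same number of parameters in both settings.
    """
    plain = feature_net(images)                                     # (B, D)
    lensed = feature_net(lens(images)) if lens is not None else plain
    return torch.cat([plain, lensed], dim=1)                        # (B, 2 * D)
```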

We report top-1 classification accuracy on the ImageNet validation set. In addition, to measure how well the learned representations transfer to unseen data, we also report downstream top-1 accuracy on the Places205 dataset. This dataset contains 2.5M images from 205 scene classes such as airfield, kitchen, coast etc.

4.2.5 Results

Table 1 shows evaluation results across tasks and datasets. For all tested tasks, adding the lens leads to a significant improvement over the baseline. The lens also outperforms adversarial training using the fast gradient sign method (FGSM; Goodfellow et al., 2015; details in the supplementary material). In particular, the lens outperforms the FGSM by a large margin when ImageNet-trained representations are transferred to the Places205 dataset (Table 1, bottom). The improved transfer performance suggests that the features modified by the lens are more general than those attacked by the FGSM.

Our results show that the benefit of automatic shortcut removal using an adversarially trained lens generalizes across pretext tasks and across datasets. Furthermore, we find that gains can be observed across a wide range of feature extractor capacities (model widths; see supplementary material).

Figure 4: Lens outputs for Rotation models trained with different reconstruction loss scales $\lambda$ (best viewed on screen). Decreasing $\lambda$ allows the lens to make increasingly large changes to the image and reveals the relative importance of features for the pretext task. For example, eyes and nose are affected at the highest $\lambda$ (2560), while the legs are affected only at lower values (320).

To understand what features the lens removes, we visualize the input image, the processed image, and their difference. The lens network produces visually interpretable modifications, in contrast to single-step adversarial attacks such as the FGSM (Goodfellow et al., 2015). Figure 4 illustrates how an image is modified by lens networks trained with different values of the reconstruction loss scale $\lambda$ on the Rotation task. The progression of image changes with decreasing $\lambda$ reveals which features are used by the feature extraction network to solve the Rotation task. The highest-quality representations in terms of logistic regression accuracy on ImageNet are obtained at intermediate $\lambda$, suggesting that the feature extractor learns richer representations when the lens discourages it from using shortcut features.
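The difference images used for these visualizations can be produced in a few lines; this is a generic sketch assuming image tensors scaled to [0, 1].

```python
import torch


@torch.no_grad()
def lens_difference(lens, image):
    """Visualize what the lens changes: per-pixel absolute difference.

    image: tensor of shape (1, C, H, W) with values in [0, 1].
    Returns the lensed image and a single-channel difference map.
    """
    lensed = lens(image).clamp(0.0, 1.0)
    diff = (lensed - image).abs().mean(dim=1, keepdim=True)  # average over channels
    return lensed, diff
```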

4.3 Comparing lens features across pretext tasks

Figure 5: Top: Three example images from ImageNet, processed by lenses trained on different pretext tasks (best viewed on screen). The dashed square on the input image shows the region used for the patch-based tasks (Relative patch location and Jigsaw). Bottom: Mean rec. loss across 1280 images randomly chosen from the test set. For display, values were clipped at the 95th percentile.
Qualitative assessment of shortcut features.

The trained lens represents a new tool for visualizing and comparing the features learned by different pretext tasks. Inspection of lens outputs confirms existing heuristics and provides new insights (Figure 5), as discussed in the following.

The features removed by the lens are the most semantically meaningful for the Rotation task, potentially explaining why its representations outperform the other tasks. The lens removes features such as head and feet of animals, and is generally biased towards the image center (see mean reconstruction loss images, Figure 5, bottom). Text and watermarks provide a strong orientation signal and are also concealed, which is reflected by high mean reconstruction loss values in the corners of the image, where logos are often found. These results confirm expectations and support the validity of lens outputs as an interpretability tool.

For the Exemplar task, the lens introduces full-field color changes, suggesting that this task is easily solved by matching images based on their average/dominant color. In contrast to Rotation, the mean reconstruction loss is biased away from the image center.

Figure 6: Downstream accuracy on ImageNet for tasks that use conversion to grayscale for augmentation in their reference implementations. The lens always outperforms this manual augmentation. For Relative patch location, we also ablated all other manual augmentations. The full augmentation pipeline for Relative patch location involves: (1) conversion to grayscale with probability 0.66, (2) independent jittering of patch locations by up to 21 pixels, and (3) independent color standardization of patches. The number above each bar indicates the optimal reconstruction loss scale $\lambda$ found in the sweep.

For patch-based tasks such as Relative patch location and Jigsaw, much effort has gone into identifying trivial solutions and designing augmentation schemes to reduce their effect. These shortcuts can be hard to identify. For example, in the paper introducing the Relative patch location task, Doersch et al. (2015) express their surprise at finding that convolutional neural networks learn to exploit chromatic aberrations to solve the task. To mitigate the shortcut, they drop color channels at random during training. Similar augmentations are proposed for the Jigsaw task (Noroozi and Favaro, 2016). More recently, refined color augmentation heuristics such as chroma blurring have been developed for patch-based pretext tasks (Mundhenk et al., 2018).

Our approach learns similar augmentations automatically, and provides additional insights. Specifically, the lens output (Figure 5) suggests that chroma blurring improves Relative patch location and Jigsaw for different reasons: for Relative patch location, color fringes at high-contrast edges (as caused by chromatic aberration) are the most prominent visual feature modified by the lens. In contrast, the lens effect for Jigsaw involves diffuse color changes towards the image borders, suggesting that color matching across patch borders is a major shortcut. The difference images also suggest that Jigsaw, but not Relative patch location, additionally benefits from luminance blurring, because luminance edges are prominent in the Jigsaw difference images.

Figure 7: Downstream ImageNet classification accuracy for models trained on ImageNet or YouTube1M. The lens recovers much of the accuracy lost when training on the less curated YouTube1M dataset. Error bars: mean ± s.e.m. over three random initializations.
Ablation of preprocessing operations.

Quantitatively, we find that random color dropping becomes less relevant when using the lens (Figure 6): In all cases, the lens can at least compensate for the missing color dropping operation, and for Exemplar even performs better without color dropping than with. For Relative patch location, we additionally perform an ablation analysis of the whole augmentation pipeline (Figure 6). We find that the lens can replace random color dropping and, to some degree, random patch jitter. However, color standardization remains important, likely because full-field color shifts are expensive under the reconstruction loss.

Figure 9: Comparison of lens outputs for models trained on ImageNet and YouTube1M (best viewed on screen). Images are ordered based on the difference in reconstruction loss between the ImageNet-trained and YouTube1M-trained models. The left block shows the top six images (i.e. higher reconstruction loss when trained on ImageNet); the right block shows the bottom six images (i.e. higher reconstruction loss when trained on YouTube1M). The modifications made by the lens can thus be used to identify dataset bias. The three images on the far right were hand-selected to contain text and illustrate that YouTube1M-trained models are more sensitive to this shortcut.
Figure 8: Pearson correlation of the per-image reconstruction loss between pretext tasks for models trained on ImageNet.
Similarity of shortcuts across tasks.

The lens output allows for a quantitative evaluation of shortcut similarity. Specifically, the correlations between the per-image reconstruction losses across tasks suggest that the patch-based tasks are vulnerable to similar shortcuts, whereas Rotation and Exemplar show anti-correlated reconstruction losses (Figure 8). This suggests that Rotation and Exemplar, but less so Jigsaw and Relative patch location, may combine synergistically in multi-task learning. Empirical gains from training multiple pretext tasks jointly have been observed by Doersch and Zisserman (2017).
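This analysis reduces to correlating per-image reconstruction losses between tasks; a sketch of the computation, assuming the per-image losses have already been collected for each task:

```python
import numpy as np


def task_shortcut_similarity(per_image_losses):
    """Pearson correlation of per-image reconstruction losses between tasks.

    per_image_losses: dict mapping task name -> array of shape (num_images,),
    where entry i is the lens reconstruction loss on image i for that task.
    Returns the task names and the (T, T) correlation matrix.
    """
    tasks = sorted(per_image_losses)
    matrix = np.stack([per_image_losses[t] for t in tasks])  # (T, num_images)
    return tasks, np.corrcoef(matrix)
```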

4.4 Comparing lens features across datasets

Beyond the choice of pretext task, the pretraining dataset is another major factor influencing the representations learned by self-supervision. Since self-supervision does not require human-annotated labels, it opens new possibilities for mining large datasets from abundant unlabeled sources such as videos from the web. However, such data sources may contain more shortcuts (and less informative images) than highly curated image datasets, e.g. logos of TV stations, text from credits, or black frames. As a result, certain pretext tasks suffer a significant performance drop when applied to frames mined from videos (Tschannen et al., 2019). We therefore tested the effectiveness of the lens for self-supervised learning on video frames, using the Rotation task. For training, we used 1 million randomly sampled frames from the YouTube-8M dataset (Abu-El-Haija et al., 2016), as in Tschannen et al. (2019) (referred to as YouTube1M). As expected, performance drops compared to training on ImageNet, but the reduction can be compensated at least partially by the lens (Figure 7). The lens recovers about half of the gap to ImageNet training when evaluating on ImageNet classification downstream. In the transfer setting, when evaluating on Places205, the lens performs even better, such that the YouTube1M-trained model with lens performs similarly to the ImageNet-trained baseline.

Inspecting the lens outputs for models trained on ImageNet and YouTube1M allows us to compare shortcut features across these datasets. In Figure 9, we show the images with the largest difference in reconstruction loss between models trained on ImageNet and YouTube1M. The images strikingly reveal the biases of the respective datasets: for ImageNet, all of the top images contain dogs on grass. Dogs are overrepresented in ImageNet, so this result suggests that the lens targets overrepresented classes and thereby counteracts dataset biases. For YouTube1M, the top images predominantly show high-contrast edges oriented along the cardinal directions. We speculate that this is because images with black bars (for aspect ratio conversion) are common in this video-derived dataset, and such bars are a strong shortcut for the Rotation task. We also found that the YouTube1M-trained lens is more sensitive to overlaid text than the ImageNet-trained lens (Figure 9, right). Overlaid text is common in YouTube1M (e.g. video credits), but less so in ImageNet.

Together, our quantitative and qualitative results show that the lens can be used to identify and remove pretext task-specific biases and shortcut features from datasets. The lens is therefore a promising tool for exploiting large non-curated data sources.

5 Conclusion

We proposed a method to improve self-supervised visual representations by learning to remove shortcut features from the data with an adversarially trained lens network. Training a feature extractor on data from which shortcuts have been removed forces it to learn higher-level features that transfer and generalize better, as shown in our experiments. By combining features obtained with and without shortcut removal into a single image representation, we ensure that our approach improves representation quality even if the removed features are relevant for the downstream task.

Apart from improved representations, our approach allows us to visualize, quantify and compare the features learned by self-supervision. We confirm that our approach detects and mitigates shortcuts observed in prior work and also sheds light on issues that were less known.

Future research should explore design choices such as the lens architecture and the image reconstruction loss. Furthermore, it would be interesting to explore whether and how the proposed technique can be applied to improve and/or visualize supervised learning algorithms.

References

  • S. Abu-El-Haija, N. Kothari, J. Lee, P. Natsev, G. Toderici, B. Varadarajan, and S. Vijayanarasimhan (2016) Youtube-8m: a large-scale video classification benchmark. arXiv:1609.08675. Cited by: §4.4.
  • N. Akhtar and A. Mian (2018) Threat of adversarial attacks on deep learning in computer vision: a survey. IEEE Access 6, pp. 14410–14430. Cited by: §2.
  • P. Bachman, R. D. Hjelm, and W. Buchwalter (2019) Learning representations by maximizing mutual information across views. In NeurIPS, Cited by: §2.
  • M. Caron, P. Bojanowski, A. Joulin, and M. Douze (2018) Deep clustering for unsupervised learning of visual features. Proc. ECCV. Cited by: §2.
  • C. Doersch, A. Gupta, and A. A. Efros (2015) Unsupervised visual representation learning by context prediction. In Proc. ICCV, pp. 1422–1430. Cited by: §1, §2, §2, §4.2.1, §4.2.4, §4.3.
  • C. Doersch and A. Zisserman (2017) Multi-task self-supervised visual learning. In Proc. ICCV, Cited by: §2, §2, §4.3.
  • A. Dosovitskiy, J. T. Springenberg, M. Riedmiller, and T. Brox (2014) Discriminative unsupervised feature learning with convolutional neural networks. In NeurIPS, pp. 766–774. Cited by: §2, §4.2.1, §4.2.4.
  • Z. Feng, C. Xu, and D. Tao (2019) Self-supervised representation learning by rotation feature decoupling. In Proc. CVPR, pp. 10364–10374. Cited by: §2.
  • S. Gidaris, P. Singh, and N. Komodakis (2018) Unsupervised representation learning by predicting image rotations. Proc. ICLR. Cited by: §1, §2, §4.1, §4.2.4.
  • I. J. Goodfellow, J. Shlens, and C. Szegedy (2015) Explaining and harnessing adversarial examples. Proc. ICLR. Cited by: §2, §3.3, §4.2.5, §4.2.5.
  • P. Goyal, D. Mahajan, A. Gupta, and I. Misra (2019) Scaling and benchmarking self-supervised visual representation learning. In Proc. ICCV, pp. 6391–6400. Cited by: §2.
  • K. He, H. Fan, Y. Wu, S. Xie, and R. Girshick (2019) Momentum contrast for unsupervised visual representation learning. arXiv:1911.05722. Cited by: §1, §2.
  • K. He, X. Zhang, S. Ren, and J. Sun (2016) Identity mappings in deep residual networks. In Proc. ECCV, Cited by: Appendix A, §3.3.
  • O. J. Hénaff, A. Razavi, C. Doersch, S. Eslami, and A. v. d. Oord (2019) Data-efficient image recognition with contrastive predictive coding. arXiv:1905.09272. Cited by: §1, §2.
  • R. D. Hjelm, A. Fedorov, S. Lavoie-Marchildon, K. Grewal, P. Bachman, A. Trischler, and Y. Bengio (2019) Learning deep representations by mutual information estimation and maximization. In Proc. ICLR, Cited by: §2.
  • A. Kolesnikov, X. Zhai, and L. Beyer (2019) Revisiting self-supervised visual representation learning. In Proc. CVPR, pp. 1920–1929. Cited by: Appendix A, Appendix C, §4.2.1, §4.2.2, §4.2.
  • A. Kurakin, I. Goodfellow, and S. Bengio (2016) Adversarial machine learning at scale. In Proc. ICLR, Cited by: Appendix D, §3.2, §4.2.2.
  • H. Lee, J. Huang, M. Singh, and M. Yang (2017) Unsupervised representation learning by sorting sequences. In Proc. ICCV, pp. 667–676. Cited by: §2.
  • A. Madry, A. Makelov, L. Schmidt, D. Tsipras, and A. Vladu (2018) Towards deep learning models resistant to adversarial attacks. In Proc. ICLR, Cited by: §2, footnote 2.
  • I. Misra and L. van der Maaten (2019) Self-supervised learning of pretext-invariant representations. arXiv:1912.01991. Cited by: §1, §2.
  • T. Miyato, S. Maeda, M. Koyama, and S. Ishii (2018) Virtual adversarial training: a regularization method for supervised and semi-supervised learning. IEEE Transactions on Pattern Analysis and Machine Intelligence 41 (8), pp. 1979–1993. Cited by: §2.
  • T. N. Mundhenk, D. Ho, and B. Y. Chen (2018) Improvements to context based self-supervised learning. In Proc. CVPR, Cited by: §2, §4.3.
  • M. Noroozi and P. Favaro (2016) Unsupervised learning of visual representations by solving jigsaw puzzles. In Proc. ECCV, pp. 69–84. Cited by: §1, §2, §2, §4.2.1, §4.2.4, §4.3.
  • M. Noroozi, H. Pirsiavash, and P. Favaro (2017) Representation learning by learning to count. In Proc. ICCV, Cited by: §2, §2.
  • A. v. d. Oord, Y. Li, and O. Vinyals (2018) Representation learning with contrastive predictive coding. arXiv:1807.03748. Cited by: §2.
  • O. Ronneberger, P. Fischer, and T. Brox (2015) U-net: convolutional networks for biomedical image segmentation. Med. Image Comput. Comput. Assist. Interv.. Cited by: Figure 10, Appendix A, §3.3.
  • M. S. Sajjadi, G. Parascandolo, A. Mehrjou, and B. Schölkopf (2018) Tempered adversarial networks. In Proc. ICML, pp. 4451–4459. Cited by: §1.
  • T. Salimans, I. Goodfellow, W. Zaremba, V. Cheung, A. Radford, and X. Chen (2016) Improved techniques for training gans. In NeurIPS, Cited by: §3.3.
  • S. Santurkar, A. Ilyas, D. Tsipras, L. Engstrom, B. Tran, and A. Madry (2019) Image synthesis with a single (robust) classifier. In NeurIPS, pp. 1260–1271. Cited by: §2.
  • F. Schroff, D. Kalenichenko, and J. Philbin (2015) Facenet: a unified embedding for face recognition and clustering. In Proc. CVPR, Cited by: §4.2.1.
  • C. Szegedy, W. Zaremba, I. Sutskever, J. Bruna, D. Erhan, I. Goodfellow, and R. Fergus (2014) Intriguing properties of neural networks. In Proc. ICLR, Cited by: §2.
  • Y. Tian, D. Krishnan, and P. Isola (2019) Contrastive multiview coding. arXiv:1906.05849. Cited by: §2.
  • M. Tschannen, J. Djolonga, M. Ritter, A. Mahendran, N. Houlsby, S. Gelly, and M. Lucic (2019) Self-supervised learning of video-induced visual invariances. arXiv:1912.02783. Cited by: §4.4.
  • M. Tschannen, J. Djolonga, P. K. Rubenstein, S. Gelly, and M. Lucic (2020) On mutual information maximization for representation learning. In Proc. ICLR, Cited by: §2.
  • D. Wei, J. J. Lim, A. Zisserman, and W. T. Freeman (2018) Learning and using the arrow of time. In Proc. CVPR, pp. 8052–8060. Cited by: §2.
  • C. Xie, M. Tan, B. Gong, J. Wang, A. Yuille, and Q. V. Le (2019) Adversarial examples improve image recognition. arXiv:1911.09665. Cited by: §2.
  • R. Zhang, P. Isola, A. A. Efros, E. Shechtman, and O. Wang (2018) The unreasonable effectiveness of deep features as a perceptual metric. In Proc. CVPR, pp. 586–595. Cited by: §3.3.
  • R. Zhang, P. Isola, and A. A. Efros (2016) Colorful image colorization. In Proc. ECCV, pp. 649–666. Cited by: §2.

Appendix A Architecture

For the feature extractor $F$, we use the ResNet50 v2 architecture (He et al., 2016; Kolesnikov et al., 2019) with the standard channel widening factor of 4 and a representation size of 2048 at the pre-logits layer, unless otherwise noted.

For the lens $L$, we use a variant of the U-Net architecture (Figure 10; Ronneberger et al., 2015). The lens consists of a convolutional encoder and decoder. The encoder and decoder are each a stack of residual units (with the same unit architecture as used for the feature extractor); the number of units and the number of channels of the first encoder unit are fixed across all experiments. Two additional residual units form the bottleneck between encoder and decoder (see Figure 10). After each unit in the encoder, the number of channels is doubled and the resolution is halved by max-pooling with stride 2. Conversely, after each decoder unit, the number of channels is halved and the resolution is doubled using bilinear interpolation. At each resolution level, skip connections are created between the encoder and decoder by concatenating the encoder representation channel-wise with the decoder representation before applying the next decoder unit. The output of the decoder has the same resolution as the input image and is reduced to three channels by a final convolutional layer. This map is combined by element-wise addition with the input image to produce the lens output.

We choose the U-Net architecture because it efficiently combines a large receptive field with a high output resolution. The feature maps at the bottleneck of the U-Net are downsampled several times relative to the input, so that a convolution at that level covers a large region of the input image and is able to capture large-scale image context. Furthermore, the skip connections of the U-Net make it trivial for the lens to reconstruct the input image by setting all internal weights to zero. This is important to ensure that the changes made by the lens to the image are not simply due to a lack of capacity.
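To make these architectural choices concrete, the following is a heavily simplified, single-level sketch of such a lens (channel-wise skip connection, bilinear upsampling, additive output). The real lens uses stacks of ResNet v2 residual units and several resolution levels, so treat this only as an illustration of the structure.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyLens(nn.Module):
    """Simplified U-Net-style lens with one down/up level and a skip connection.

    The additive output makes the identity mapping easy to learn: if all
    internal weights are zero, the lens returns the input unchanged.
    Input height and width are assumed to be even.
    """

    def __init__(self, channels=32):
        super().__init__()
        self.enc = nn.Sequential(nn.Conv2d(3, channels, 3, padding=1), nn.ReLU())
        self.bottleneck = nn.Sequential(
            nn.Conv2d(channels, 2 * channels, 3, padding=1), nn.ReLU())
        self.dec = nn.Sequential(
            nn.Conv2d(3 * channels, channels, 3, padding=1), nn.ReLU())
        self.out = nn.Conv2d(channels, 3, 1)

    def forward(self, x):
        e = self.enc(x)                                   # full resolution
        b = self.bottleneck(F.max_pool2d(e, 2))           # half resolution
        b = F.interpolate(b, scale_factor=2, mode="bilinear", align_corners=False)
        d = self.dec(torch.cat([e, b], dim=1))            # channel-wise skip connection
        return x + self.out(d)                            # additive residual output
```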

We find that a single setting of the lens depth and width yields good results in general, although initial experiments suggested that tuning the lens capacity individually for each pretext task and dataset may provide further gains.

For the experiments using CIFAR-10 (Figure 3), we used a smaller lens architecture consisting of a stack of five ResNet50 v2 residual units without down- or up-sampling.

Figure 9: Downstream accuracy for Rotation models trained on ImageNet with different feature extraction network widening factors. The performance gain remains large across model sizes.
Figure 10: Lens architecture. The number of channels is indicated above each block. Based on (Ronneberger et al., 2015).

Appendix B Robustness to model size

In addition to tasks and datasets (see main text), we also tested how the performance of the lens varies with the capacity of the feature extraction network. For the Rotation task and ImageNet, we trained models with different widening factors (channel number multipliers). As expected, wider networks perform better (Figure 9). We find that the lens improves accuracy across all model widths. The accuracy gain from applying the lens to a feature extraction network with a width factor of 4 is comparable to the gain obtained by further widening the network.

Appendix C Downstream evaluation

For downstream evaluation of learned representations, we follow the linear evaluation protocol with SGD from Kolesnikov et al. (2019). A logistic regression model for ImageNet or Places205 classification was trained using SGD on the representations obtained from the pre-trained self-supervised models.

For training the logistic regression, we preprocessed input images in the same way for all models: images were resized, randomly cropped, and the color values were rescaled to a fixed range. For evaluation, the random crop was replaced by a central crop.

Representations were then obtained by passing the images through the pre-trained models and extracting the pre-logits activations. For patch-based models, we obtained representations of the full image by averaging the representations of nine patches created from the full image. To create the patches, the central section of the input image was divided into a three-by-three grid of patches. Each patch was passed through the feature extraction network and the resulting representations were averaged.

The logistic regression model was trained with a batch size of 2048 and an initial learning rate of 0.8. We trained for 90 epochs and reduced the learning rate by a factor of 10 after epoch 50 and epoch 70. For both ImageNet and Places205, training was performed on the full training set and the performance is reported for the public validation set.

Appendix D Adversarial training with FGSM

For the comparison to adversarial training (Table 1), we used the fast gradient sign method (FGSM) as described by Kurakin et al. (2016). Analogously to our sweeps over $\lambda$ for the lens models, we swept over the perturbation scale $\epsilon$ and report the accuracy for the best $\epsilon$ in Table 1. As suggested by Kurakin et al. (2016), we randomized the perturbation scale for each image by using the absolute value of a sample from a truncated normal distribution with mean 0 and a standard deviation determined by $\epsilon$. Since this randomization already includes nearly unprocessed images, we do not include further unprocessed images during training.
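A sketch of this FGSM baseline as described above; the truncated-normal randomization is approximated with a clamped normal, and all names are our own.

```python
import torch


def fgsm_perturb(feature_net, loss_fn, x, y, eps_scale):
    """Fast gradient sign method with a per-image random perturbation scale.

    The per-image scale is the absolute value of a sample from a zero-mean
    normal clamped to [-2, 2] (a stand-in for a truncated normal), multiplied
    by the sweep parameter eps_scale.
    """
    eps = torch.randn(x.size(0), 1, 1, 1, device=x.device).clamp(-2, 2).abs() * eps_scale
    x_adv = x.clone().requires_grad_(True)
    loss = loss_fn(feature_net(x_adv), y)
    grad, = torch.autograd.grad(loss, x_adv)
    # In practice the result would also be clipped to the valid image range.
    return (x + eps * grad.sign()).detach()
```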

Figure 11: Further example lens outputs for models trained on ImageNet with the Rotation task. Images were randomly sampled from the ImageNet validation set.
Figure 12: Further example lens outputs for models trained on ImageNet with the Exemplar task. Images were randomly sampled from the ImageNet validation set.
Figure 13: Further example lens outputs for models trained on ImageNet with the Relative patch location task. Images were randomly sampled from the ImageNet validation set.
Figure 14: Further example lens outputs for models trained on ImageNet with the Jigsaw task. Images were randomly sampled from the ImageNet validation set.
Figure 15: Further example lens outputs for models trained on YouTube1M with the Rotation task. Outputs from ImageNet-trained models are provided for comparison. Images were randomly sampled from the ImageNet validation set.
Figure 16: Further example lens outputs for images containing text, comparing models trained on YouTube1M and ImageNet with the Rotation task. Images containing artificially overlaid text (logos, watermarks, etc.) were manually selected from a random sample of 1000 ImageNet validation images, before inspecting lens outputs. A random sample of these images is shown.