Semi-Supervised Semantic Segmentation with Cross-Consistency Training

03/19/2020 ∙ by Yassine Ouali, et al.

In this paper, we present a novel cross-consistency based semi-supervised approach for semantic segmentation. Consistency training has proven to be a powerful semi-supervised learning framework for leveraging unlabeled data under the cluster assumption, in which the decision boundary should lie in low-density regions. In this work, we first observe that for semantic segmentation, the low-density regions are more apparent within the hidden representations than within the inputs. We thus propose cross-consistency training, where an invariance of the predictions is enforced over different perturbations applied to the outputs of the encoder. Concretely, a shared encoder and a main decoder are trained in a supervised manner using the available labeled examples. To leverage the unlabeled examples, we enforce a consistency between the main decoder predictions and those of the auxiliary decoders, which take as inputs different perturbed versions of the encoder's output, consequently improving the encoder's representations. The proposed method is simple and can easily be extended to use additional training signal, such as image-level labels or pixel-level labels across different domains. We perform an ablation study to tease apart the effectiveness of each component, and conduct extensive experiments to demonstrate that our method achieves state-of-the-art results on several datasets.

1 Introduction

In recent years, with the wide adoption of deep supervised learning within the computer vision community, significant strides have been made across various visual tasks, yielding impressive results. However, training deep learning models requires a large amount of labeled data, whose acquisition is often costly and time-consuming. In semantic segmentation, given how expensive and laborious the acquisition of pixel-level labels is, with a cost that is 15 times and 60 times larger than that of region-level and image-level labels respectively [COCO], the need for data-efficient semantic segmentation methods is even more evident.

As a result, growing attention has been drawn to deep Semi-Supervised Learning (SSL) to take advantage of a large amount of unlabeled data and limit the need for labeled examples. The current dominant SSL methods in deep learning are consistency training [ladder_nets, TEnsembling, MeanTeachers, VAT], pseudo-labeling [pseudo_labling], entropy minimization [entropy_mini] and bootstrapping [qiao2018deep]. Some newly introduced techniques are based on generative modeling [kumar2017semi, souly2017semi].

Figure 1: The proposed Cross-Consistency training (CCT). For the labeled examples, the encoder and the main decoder are trained in a supervised manner. For the unlabeled examples, a consistency between the main decoder’s predictions and those of the auxiliary decoders is enforced, over different types of perturbations applied to the inputs of the auxiliary decoders. Best viewed in color.

However, recent progress in SSL has been confined to classification tasks, and its application to semantic segmentation is still limited. Dominant approaches [DSRG, dilated_cam, erasing, lee2019ficklenet] focus on weakly-supervised learning, whose principle is to generate pseudo pixel-level labels by leveraging the weak labels; these can then be used, together with the limited strongly labeled examples, to train a segmentation network in a supervised manner. Generative Adversarial Networks (GANs) were also adapted to the SSL setting [souly2017semi, hung2018adversarial] by extending the generic GAN framework to pixel-level predictions. The discriminator is then jointly trained with an adversarial loss and a supervised loss over the labeled examples.

Nevertheless, these approaches suffer from some limitations. Weakly-supervised approaches require weakly labeled examples along with pixel-level labels; hence, they do not exploit unlabeled data to extract additional training signal. Methods based on adversarial training do exploit the unlabeled data, but can be harder to train.

To address these limitations, we propose a simple consistency-based semi-supervised method for semantic segmentation. The objective in consistency training is to enforce an invariance of the model's predictions over small perturbations applied to the inputs, so that the learned model is robust to such small changes. The effectiveness of consistency training depends heavily on the behavior of the data distribution, i.e., the cluster assumption, where the classes must be separated by low-density regions. In semantic segmentation, we do not observe low-density regions separating the classes within the inputs, but rather within the encoder's outputs. Based on this observation, we propose to enforce the consistency over different forms of perturbations applied to the encoder's output. Specifically, we consider a shared encoder and a main decoder that are trained using the labeled examples. To leverage unlabeled data, we then consider multiple auxiliary decoders whose inputs are perturbed versions of the output of the shared encoder. The consistency is imposed between the main decoder's predictions and those of the auxiliary decoders (see Fig. 1). This way, the shared encoder's representation is enhanced by using the additional training signal extracted from the unlabeled data. The added auxiliary decoders have a negligible number of parameters compared to the encoder, and during inference only the main decoder is used, keeping the computation overhead small both in training and inference.

The proposed method is simple and efficient; it is also flexible, since it can easily be extended to use additional weak labels and pixel-level labels across different domains in a semi-supervised domain adaptation setting. With extensive experiments, we demonstrate the effectiveness of our approach on PASCAL VOC [pascalvoc] in a semi-supervised setting, and on Cityscapes, CamVid [camvid] and SUN RGB-D [sun] in a semi-supervised domain adaptation setting. We obtain competitive results across different datasets and training settings.

Concretely, our contributions are four-fold:

  • We propose a cross-consistency training (CCT) method for semi-supervised semantic segmentation, where the invariance of the predictions is enforced over different perturbations injected into the encoder’s output.

  • We propose and conduct an exhaustive study of various types of perturbations.

  • We extend our approach to use weakly-labeled data, and exploit pixel-level labels across different domains to jointly train the segmentation network.

  • We demonstrate the effectiveness of our approach with extensive and detailed experimental results, including a comparison with the state of the art, as well as an in-depth analysis of our approach with a detailed ablation study.

2 Related Work

Semi-Supervised Learning. Recently, many efforts have been made to adapt classic SSL methods to deep learning, such as pseudo-labeling [pseudo_labling], entropy minimization [entropy_mini] and graph-based methods [liu2018deep, kipf2016semi]. In this work, we focus mainly on consistency training; we refer the reader to [chapelle2009semi] for a detailed overview of the field. Consistency training methods are based on the assumption that, if a realistic form of perturbation is applied to the unlabeled examples, the predictions should not change significantly. This favors models with decision boundaries that reside in low-density regions, giving consistent predictions for similar inputs. For example, the Π-Model [TEnsembling] enforces a consistency over two perturbed versions of the inputs under different data augmentations and dropout. A weighted moving average of either the previous predictions (i.e., Temporal Ensembling [TEnsembling]) or the model's parameters (i.e., Mean Teacher [MeanTeachers]) can be used to obtain more stable predictions over the unlabeled examples. Instead of relying on random perturbations, Virtual Adversarial Training (VAT) [VAT] approximates the perturbation that will alter the model's predictions the most.

Similarly, the proposed method enforces a consistency of predictions between the main decoder and the auxiliary decoders over different perturbations, which are applied to the encoder's outputs rather than to the inputs. Our work is also loosely related to Multi-View Learning [multi-view-leaning] and Cross-View Training [clark2018semi], where each input to the auxiliary decoders can be viewed as an alternative, but corrupted, representation of the unlabeled examples.

Semi-Supervised Semantic Segmentation. A significant number of approaches use a limited number of pixel-level labels together with a larger number of inexact annotations, e.g., region-level [song2019boxdriven, Dai_2015] or image-level labels [lee2019ficklenet, CAM, dilated_cam, wheretolook]. For image-level weak supervision, initial localization maps are generated using class activation mapping (CAM) [CAM]. After refining the generated maps, they can be used to train a segmentation network together with the available pixel-level labels in an SSL setting.

Generative modeling can also be used for semi-supervised semantic segmentation [souly2017semi, hung2018adversarial] to take advantage of the unlabeled examples. Under a GAN framework, the discriminator’s predictions are extended over pixel classes, and can then be jointly trained with a Cross-Entropy loss over the labeled examples and an adversarial loss over the whole dataset.

In comparison, the proposed method exploits the unlabeled examples by enforcing a consistency over multiple perturbations at the level of the hidden representations, enhancing the encoder's representation and the overall performance with a small additional cost in terms of computation and memory requirements.

Recently, a concurrent method, CowMix [french2019consistency], was introduced. Using MixUp [zhang2017mixup], CowMix enforces a consistency between the mixed outputs and the prediction over the mixed inputs. In this context, CCT differs as follows: (1) CowMix, like traditional consistency regularization methods, applies the perturbations to the inputs, but uses MixUp as a high-dimensional perturbation to overcome the absence of the cluster assumption. (2) CowMix requires multiple forward passes through the network for one training iteration. (3) Adapting CowMix to other settings (e.g., over multiple domains, or using weak labels) may require significant changes. CCT is efficient and can easily be extended to such settings.

Domain Adaptation. In many real-world cases, the discrepancy between the distribution of training data and that of testing data will often hinder performance. Domain adaptation aims to rectify this mismatch and tune the models for better generalization at test time [patel2015visual]. Various generative and discriminative domain adaptation methods have been proposed for classification [geng2011daml, ganin2014unsupervised, ganin2016domain, cao2018partial] and semantic segmentation [hoffman2016fcns, zhang2017curriculum, saleh2018effective, kalluri2018universal] tasks.

In this work, we show that enforcing a consistency across different domains can push the model toward better generalization, even in the extreme case of non-overlapping label spaces.

3 Method

3.1 The cluster assumption in semantic segmentation

We start with our observation and analysis of the cluster assumption in semantic segmentation, which motivates our cross-consistency training approach. A simple way to examine it is to estimate the local smoothness by measuring the local variations between the value of each pixel and its local neighbors. To this end, we compute the average euclidean distance between each spatial location and its 8 immediate neighbors, for both the inputs and the hidden representations (i.e., the ResNet [resnet] outputs of a DeepLab v3 [deeplabv3] trained on COCO [COCO]). For the inputs, following [french2019consistency], we compute the average distance between a patch centered at a given spatial location and its neighboring patches, to simulate a realistic receptive field. For the hidden representations, we first upsample the feature map to the input size, and then compute the average distance between neighboring activations (i.e., the feature vectors at neighboring spatial locations). The results are illustrated in Fig. 2. We observe that the cluster assumption is violated at the input level, given that the low-density regions do not align with the class boundaries. On the contrary, for the encoder's outputs, the cluster assumption is maintained: the class boundaries have a high average distance, thus corresponding to low-density regions. This observation motivates the following approach, in which the perturbations are applied to the encoder's outputs rather than to the inputs.
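For illustration, a minimal sketch of how such a local-smoothness map can be computed for the hidden representations, assuming a single feature map of shape (C, H, W); the function name and interface are ours, not part of the released implementation:

```python
import torch
import torch.nn.functional as F

def local_distance_map(feats: torch.Tensor) -> torch.Tensor:
    """Average euclidean distance between each spatial position and its
    8 immediate neighbors. `feats` has shape (C, H, W); the returned map has
    shape (H, W). High values indicate low-density (boundary) regions."""
    c, h, w = feats.shape
    # Pad by replication so that border pixels also have 8 "neighbors".
    padded = F.pad(feats.unsqueeze(0), (1, 1, 1, 1), mode="replicate")[0]
    dists = []
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            if dy == 0 and dx == 0:
                continue
            neighbor = padded[:, 1 + dy:1 + dy + h, 1 + dx:1 + dx + w]
            dists.append(torch.norm(feats - neighbor, dim=0))
    return torch.stack(dists, dim=0).mean(dim=0)

# For the hidden representations, the encoder output is first upsampled to the
# input size, e.g.:
#   z_up = F.interpolate(z, size=(H, W), mode="bilinear", align_corners=False)[0]
#   smoothness = local_distance_map(z_up)
```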

3.2 Cross-Consistency Training for semantic segmentation

3.2.1 Problem Definition

In SSL, we are provided with a small set of labeled training examples and a larger set of unlabeled training examples. Let $\mathcal{D}_l = \{(\mathbf{x}^l_1, y_1), \ldots, (\mathbf{x}^l_n, y_n)\}$ represent the $n$ labeled examples and $\mathcal{D}_u = \{\mathbf{x}^u_1, \ldots, \mathbf{x}^u_m\}$ the $m$ unlabeled examples, with $\mathbf{x}^u_i$ as the $i$-th unlabeled input image, and $\mathbf{x}^l_i$ as the $i$-th labeled input image with spatial dimensions $H \times W$ and its corresponding pixel-level label $y_i \in \{1, \ldots, C\}^{H \times W}$, where $C$ is the number of classes.

As discussed in the introduction, the objective is to exploit the larger number of unlabeled examples ($m \gg n$) to train a segmentation network $f$ that performs well on test data drawn from the same distribution as the training data. In this work, our architecture (see Fig. 3) is composed of a shared encoder $h$ and a main decoder $g$, which constitute the segmentation network $f = g \circ h$. We also introduce a set of $K$ auxiliary decoders $g^k_a$, with $k \in [1, K]$. While the segmentation network $f$ is trained on the labeled set $\mathcal{D}_l$ in a traditional supervised manner, the auxiliary networks are trained on the unlabeled set $\mathcal{D}_u$ by enforcing a consistency of predictions between the main decoder and the auxiliary decoders. Each auxiliary decoder takes as input a perturbed version of the encoder's output, while the main decoder is fed the uncorrupted intermediate representation. This way, the representation learning of the encoder $h$ is further enhanced using the unlabeled examples, and subsequently, that of the segmentation network $f$.

Figure 2: The cluster assumption in semantic segmentation. (a) Examples from the PASCAL VOC 2012 train set. (b) Pixel-level labels. (c) Input level: the average euclidean distance between a patch centered at a given spatial location in the input image and its 8 neighboring patches. (d) Hidden representation level: the average euclidean distance between the activation at each spatial location and its 8 neighbors. Darker regions indicate a higher average distance.

3.2.2 Cross-Consistency Training

As stated above, to extract additional training signal from the unlabeled set $\mathcal{D}_u$, we rely on enforcing a consistency between the outputs of the main decoder $g$ and those of the auxiliary decoders $g^k_a$. Formally, for a labeled training example $\mathbf{x}^l_i$ and its pixel-level label $y_i$, the segmentation network $f$ is trained using a Cross-Entropy (CE) based supervised loss $\mathcal{L}_s$:

$$\mathcal{L}_s = \frac{1}{|\mathcal{D}_l|} \sum_{\mathbf{x}^l_i, y_i \in \mathcal{D}_l} \mathbf{H}\big(y_i, f(\mathbf{x}^l_i)\big) \qquad (1)$$

with $\mathbf{H}$ as the CE. For an unlabeled example $\mathbf{x}^u_i$, an intermediate representation of the input is computed using the shared encoder, $\mathbf{z}_i = h(\mathbf{x}^u_i)$ (throughout the paper, $\mathbf{z}_i$ always refers to the output of the encoder corresponding to an unlabeled input image $\mathbf{x}^u_i$). Let us consider $R$ stochastic perturbation functions, denoted as $p_r$ with $r \in [1, R]$, where one perturbation function can be assigned to multiple auxiliary decoders. With various perturbation settings, we generate $K$ perturbed versions $\tilde{\mathbf{z}}^k_i$ of the intermediate representation $\mathbf{z}_i$, so that the $k$-th perturbed version is fed to the $k$-th auxiliary decoder. For consistency of notation, we consider the perturbation function as part of the auxiliary decoder (i.e., $g^k_a$ can be seen as $g^k_a \circ p_r$). The training objective is then to minimize the unsupervised loss $\mathcal{L}_u$, which measures the discrepancy between the main decoder's output and that of the auxiliary decoders:

$$\mathcal{L}_u = \frac{1}{|\mathcal{D}_u|} \frac{1}{K} \sum_{\mathbf{x}^u_i \in \mathcal{D}_u} \sum_{k=1}^{K} d\big(g(\mathbf{z}_i), \, g^k_a(\tilde{\mathbf{z}}^k_i)\big) \qquad (2)$$

with $d$ as a distance measure between two output probability distributions (i.e., the outputs of a softmax function applied over the channel dimension). In this work, we choose to use mean squared error (MSE) as the distance measure.

The combined loss for consistency based SSL is then computed as:

$$\mathcal{L} = \mathcal{L}_s + \omega_u \, \mathcal{L}_u \qquad (3)$$

where $\omega_u$ is an unsupervised loss weighting function. Following [TEnsembling], to avoid using the initial noisy predictions of the main decoder, $\omega_u$ ramps up from zero along a Gaussian curve up to a fixed weight $\lambda_u$. Concretely, at each training iteration, an equal number of examples is sampled from the labeled and unlabeled sets. The supervised loss is computed using the main decoder's output and the pixel-level labels. For the unlabeled examples, we compute the MSE between the prediction of each auxiliary decoder and that of the main decoder. The total loss is then computed and back-propagated to train the segmentation network $f$ and the auxiliary networks. Note that the unsupervised loss $\mathcal{L}_u$ is not back-propagated through the main decoder $g$; only the labeled examples are used to train $g$.
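For illustration, a minimal sketch of one training iteration under the definitions above; module and function names are ours, the perturbation functions are assumed to be callables acting on the encoder output, and 255 is used as the usual VOC ignore label:

```python
import torch
import torch.nn.functional as F

def cct_step(encoder, main_decoder, aux_decoders, perturbations,
             x_l, y_l, x_u, omega_u):
    """One cross-consistency training iteration (sketch).
    `aux_decoders[k]` is paired with `perturbations[k]`."""
    # Supervised branch: labeled image through the encoder and main decoder.
    pred_l = main_decoder(encoder(x_l))
    loss_s = F.cross_entropy(pred_l, y_l, ignore_index=255)

    # Unsupervised branch: shared encoder output for the unlabeled image.
    z = encoder(x_u)
    # The main prediction is the consistency target; it is detached so that
    # the unsupervised loss is not back-propagated through the main decoder.
    target = torch.softmax(main_decoder(z).detach(), dim=1)

    loss_u = 0.0
    for perturb, aux_dec in zip(perturbations, aux_decoders):
        pred_aux = torch.softmax(aux_dec(perturb(z)), dim=1)
        loss_u = loss_u + F.mse_loss(pred_aux, target)
    loss_u = loss_u / len(aux_decoders)

    # Combined loss (Eq. 3), with omega_u ramped up during training.
    return loss_s + omega_u * loss_u
```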

Figure 3: Illustration of our approach. For one training iteration, we sample a labeled input image $\mathbf{x}^l_i$ with its pixel-level label $y_i$, together with an unlabeled image $\mathbf{x}^u_i$. We pass both images through the encoder and the main decoder, obtaining the two main predictions. We compute the supervised loss $\mathcal{L}_s$ using the pixel-level label and the main prediction for $\mathbf{x}^l_i$. We apply various perturbations to $\mathbf{z}_i$, the output of the encoder for $\mathbf{x}^u_i$, and generate auxiliary predictions from the perturbed versions $\tilde{\mathbf{z}}^k_i$. The unsupervised loss $\mathcal{L}_u$ is then computed between the outputs of the auxiliary decoders and that of the main decoder.

3.2.3 Perturbation functions

An important factor in consistency training is the perturbations applied to the hidden representation, i.e., the encoder's output $\mathbf{z}$. We propose three types of perturbation functions $p_r$: feature-based, prediction-based and random.

Feature-based perturbations. They consist of either injecting noise into, or dropping some of the activations of, the encoder's output feature map $\mathbf{z}$ (a minimal code sketch of both is given after this list).

  • F-Noise: we uniformly sample a noise tensor $N$ of the same size as $\mathbf{z}$. After adjusting its amplitude by multiplying it with $\mathbf{z}$, the noise is then injected into the encoder's output to get $\tilde{\mathbf{z}} = (\mathbf{z} \odot N) + \mathbf{z}$. This way, the injected noise is proportional to each activation.

  • F-Drop: we first uniformly sample a threshold $\gamma$. After summing over the channel dimension and normalizing the feature map $\mathbf{z}$ to get $\mathbf{z}'$, we generate a mask $M_{drop} = \{\mathbf{z}' < \gamma\}$, where $\{\cdot\}$ is a boolean function outputting 1 if the condition is true and 0 otherwise. The mask is then used to obtain the perturbed version $\tilde{\mathbf{z}} = \mathbf{z} \odot M_{drop}$. This way, we mask the most active regions in the feature map.
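A minimal sketch of the two feature-based perturbations, as referenced above; the exact sampling ranges are assumptions where the text leaves them unspecified:

```python
import torch

def f_noise(z: torch.Tensor) -> torch.Tensor:
    """Inject uniform noise proportional to each activation of z (N, C, H, W)."""
    noise = torch.empty_like(z).uniform_(-0.3, 0.3)  # range is an assumption
    return z + z * noise

def f_drop(z: torch.Tensor) -> torch.Tensor:
    """Drop the most active regions of z, located by its channel-wise sum."""
    attention = z.sum(dim=1, keepdim=True)                           # (N, 1, H, W)
    attention = attention / (attention.amax(dim=(2, 3), keepdim=True) + 1e-6)
    # Per-example threshold; the sampling range is an assumption.
    threshold = torch.empty(z.size(0), 1, 1, 1, device=z.device).uniform_(0.6, 0.9)
    mask = (attention < threshold).float()  # keep only low-activation regions
    return z * mask
```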

Prediction-based perturbations. They consist of adding perturbations based on the main decoder's prediction or those of the auxiliary decoders. We consider masking-based perturbations (guided masking and guided cutout) in addition to adversarial perturbations (intermediate VAT); a sketch of the guided masking is given at the end of this subsection.

  • Guided Masking: given the importance of context relationships for complex scene understanding [oliva2007role], the network might become too reliant on these relationships. To limit this, we create two perturbed versions of $\mathbf{z}$ by masking either the detected objects or the context. Using the main decoder's prediction, we generate an object mask covering the detected foreground objects and a complementary context mask, which are then applied to $\mathbf{z}$.

  • Guided Cutout: in order to reduce the reliance on specific parts of the objects, and inspired by Cutout [cutout], which randomly masks some parts of the input image, we first find the possible spatial extent (i.e., bounding box) of each detected object using the main decoder's prediction. We then zero-out a random crop within each object's bounding box from the corresponding feature map $\mathbf{z}$.

  • Intermediate VAT: to further push the output distribution to be isotropically smooth around each data point, we investigate using VAT [VAT] as a perturbation function applied to $\mathbf{z}$ instead of the unlabeled inputs. For a given auxiliary decoder, we find the adversarial perturbation $r_{adv}$ that will alter its prediction the most. The noise is then injected into $\mathbf{z}$ to obtain the perturbed version $\tilde{\mathbf{z}} = \mathbf{z} + r_{adv}$.

Random perturbations. Spatial dropout [spatial_dropout] is also applied to $\mathbf{z}$ as a random perturbation.
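A minimal sketch of the prediction-guided object/context masking referenced above, assuming the background corresponds to class index 0 in the main prediction; the interface is illustrative:

```python
import torch
import torch.nn.functional as F

def guided_masks(z: torch.Tensor, main_logits: torch.Tensor, bg_index: int = 0):
    """Build object / context masks from the main decoder's prediction and
    apply them to the encoder output z (N, C, H, W). `main_logits` is the
    main decoder's output; treating channel `bg_index` as background is an
    assumption."""
    # Downsample the prediction to the spatial size of z.
    pred = F.interpolate(main_logits, size=z.shape[2:], mode="bilinear",
                         align_corners=False)
    labels = pred.argmax(dim=1, keepdim=True)        # (N, 1, h, w)
    object_mask = (labels != bg_index).float()       # 1 on detected objects
    context_mask = 1.0 - object_mask
    z_objects_hidden = z * context_mask   # masking the detected objects
    z_context_hidden = z * object_mask    # masking the context
    return z_objects_hidden, z_context_hidden
```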

3.2.4 Practical considerations

At each training iteration, we sample an equal number of labeled and unlabeled examples. As a consequence, we iterate over the labeled set $\mathcal{D}_l$ more times than over its unlabeled counterpart $\mathcal{D}_u$, thus risking overfitting of the labeled set.

Avoiding Overfitting. Motivated by [pohlen2017full], who observed improved results when computing the loss only over the hardest pixels, and by [xie2019unsupervised], who showed an improvement when gradually releasing the supervised training signal in an SSL setting, we propose an annealed version of the bootstrapped CE (ab-CE) of [pohlen2017full]. With an output $f(\mathbf{x}^l_i)$ in the form of a probability distribution over the classes at each pixel, we compute the supervised loss only over the pixels whose predicted probability (for the ground-truth class) is less than a threshold $\eta$:

$$\mathcal{L}_s = \frac{1}{|\mathcal{D}_l|} \sum_{\mathbf{x}^l_i, y_i \in \mathcal{D}_l} \big\{ f(\mathbf{x}^l_i) < \eta \big\} \odot \mathbf{H}\big(y_i, f(\mathbf{x}^l_i)\big) \qquad (4)$$

To gradually release the supervised training signal, the threshold parameter $\eta$ is increased from $\frac{1}{C}$ to $0.9$ during the beginning of training, with $C$ as the number of output classes.
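A minimal sketch of this annealed bootstrapped CE, assuming the threshold is applied to the predicted probability of the ground-truth class; names and the ignore index are illustrative:

```python
import torch
import torch.nn.functional as F

def annealed_bootstrapped_ce(logits, target, eta, ignore_index=255):
    """Cross-entropy computed only over the pixels the model is still unsure
    about. logits: (N, C, H, W), target: (N, H, W), eta in [1/C, 1)."""
    probs = torch.softmax(logits, dim=1)
    # Probability assigned to the ground-truth class at each pixel.
    gt = target.clone()
    valid = gt != ignore_index
    gt[~valid] = 0
    gt_prob = probs.gather(1, gt.unsqueeze(1)).squeeze(1)
    # Keep only "hard" pixels: ground-truth probability below the threshold.
    keep = valid & (gt_prob < eta)
    ce = F.cross_entropy(logits, target, ignore_index=ignore_index,
                         reduction="none")
    return (ce * keep.float()).sum() / keep.float().sum().clamp(min=1.0)
```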

3.3 Exploiting weak-labels

In some cases, we might be provided with additional training data that is less expensive to acquire than pixel-level labels, e.g., image-level labels. Formally, instead of an unlabeled set $\mathcal{D}_u$, we are provided with a weakly labeled set $\mathcal{D}_w$ alongside the pixel-level labeled set $\mathcal{D}_l$, where $y^w_i$ is the image-level label corresponding to the $i$-th weakly labeled input image $\mathbf{x}^w_i$. The objective is to extract additional information from the weakly labeled set $\mathcal{D}_w$ to further enhance the representations of the encoder $h$. To this end, we add a classification branch consisting of a global average pooling layer followed by a classification layer, and pretrain the encoder for a classification task using a binary CE loss.

Following previous works [ahn2018learning, lee2019ficklenet, DSRG], the pretrained encoder and the added classification branch can then be exploited to generate pseudo pixel-level labels $\hat{y}^{pl}$. We start by generating the CAMs as in [CAM]. We can then generate the pseudo labels $\hat{y}^{pl}$ using a background threshold $\theta_{bg}$ and a foreground threshold $\theta_{fg}$. The pixels with attention scores less than $\theta_{bg}$ are considered as background. The pixels with an attention score larger than $\theta_{fg}$ are assigned the class with the maximal attention score, and the rest of the pixels are ignored. After generating $\hat{y}^{pl}$, we conduct a final refinement step using dense CRF [krahenbuhl2011efficient].
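A minimal sketch of the pseudo-label generation from the CAMs under the thresholding rule above; the threshold values are illustrative placeholders, and the dense-CRF refinement step is omitted:

```python
import torch

def pseudo_labels_from_cams(cams, theta_bg=0.05, theta_fg=0.30, ignore_index=255):
    """cams: (num_fg_classes, H, W) attention maps normalized to [0, 1].
    Returns a pseudo pixel-level label map with 0 as background and
    `ignore_index` for pixels between the two thresholds.
    The threshold values here are illustrative placeholders."""
    max_score, max_class = cams.max(dim=0)           # (H, W)
    labels = torch.full_like(max_class, ignore_index)
    labels[max_score < theta_bg] = 0                  # confident background
    fg = max_score > theta_fg                         # confident foreground
    labels[fg] = max_class[fg] + 1                    # foreground classes start at 1
    return labels
```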

In addition to considering $\mathcal{D}_w$ as an unlabeled set and imposing a consistency over its examples, the pseudo-labels are used to train the auxiliary networks using a weakly supervised loss $\mathcal{L}_w$. In this case, the loss in Eq. 3 becomes:

$$\mathcal{L} = \mathcal{L}_s + \omega_u \, \mathcal{L}_u + \omega_w \, \mathcal{L}_w \qquad (5)$$

with $\mathcal{L}_w$ a CE loss computed between the generated pseudo-labels $\hat{y}^{pl}_i$ and the predictions of the auxiliary decoders:

$$\mathcal{L}_w = \frac{1}{|\mathcal{D}_w|} \frac{1}{K} \sum_{\mathbf{x}^w_i \in \mathcal{D}_w} \sum_{k=1}^{K} \mathbf{H}\big(\hat{y}^{pl}_i, \, g^k_a(\tilde{\mathbf{z}}^k_i)\big) \qquad (6)$$

3.4 Cross-Consistency Training on Multiple Domains

In this section, we extend the proposed framework to a semi-supervised domain adaptation setting. We consider the case of two datasets, $\mathcal{D}^{(1)}$ and $\mathcal{D}^{(2)}$, with partially or fully non-overlapping label spaces, each one containing a set of labeled and unlabeled examples. The objective is to simultaneously train a segmentation network to perform well on the test data of both datasets, which is drawn from the two different distributions.

Our assumption is that enforcing a consistency over both unlabeled sets $\mathcal{D}^{(1)}_u$ and $\mathcal{D}^{(2)}_u$ might impose an invariance of the encoder's representations across the two domains. To this end, on top of the shared encoder $h$, we add domain-specific main and auxiliary decoders. Specifically, as illustrated in Fig. 4, we add two main decoders and two sets of auxiliary decoders on top of the encoder $h$. During training, we alternate between the two datasets, at each iteration sampling an equal number of labeled and unlabeled examples, computing the loss in Eq. 3 and training the shared encoder and the corresponding main and auxiliary decoders.

Figure 4: CCT on multiple domains. On top of a shared encoder, we add domain-specific main and auxiliary decoders. During training, we alternate between the two domains, sampling labeled and unlabeled examples and training the corresponding decoders and the shared encoder at each iteration.

4 Experiments

Figure 5: Ablation studies on CamVid with 20, 50 and 100 labeled images. With different types of perturbations and a variable number of auxiliary decoders $K$, we compare the individual and the combined effectiveness of the perturbations to the baseline, in which the model is trained only on the labeled examples. CCT full refers to using all of the 7 perturbations together.
Figure 6: Ablation study on PASCAL VOC. Ablation study results with 1000 labeled examples using different perturbations and various numbers of auxiliary decoders $K$.

To evaluate the proposed method and investigate its effectiveness in different settings, we carry out detailed experiments. In Section 4.4, we present an extensive ablation study to highlight the contribution of each component within the proposed framework, and compare it to state-of-the-art methods in a semi-supervised setting. Additionally, in Section 4.5 we apply the proposed method in a semi-supervised domain adaptation setting and show performance above baseline methods.

4.1 Network Architecture

Encoder. For the following experiments, the encoder is based on a ResNet-50 [resnet] pretrained on ImageNet [imagenet], provided by [you2019torchcv], together with a PSP module [PSPnet]. Following previous works [PSPnet, DSRG, ahn2018learning], the last two strided convolutions of the ResNet are replaced with dilated convolutions.

Decoders. For the decoders, taking the efficiency and the number of parameters into consideration, we choose to use only 1×1 convolutions. After an initial 1×1 convolution to adapt the depth to the number of classes $C$, we apply a series of three sub-pixel convolutions [shi2016real] with ReLU non-linearities to upsample the outputs to the original input size.
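A minimal sketch of such a decoder, assuming an encoder output stride of 8 (hence three ×2 upsampling stages); the channel arrangement is an assumption:

```python
import torch.nn as nn

class LightDecoder(nn.Module):
    """1x1 conv to C classes followed by three sub-pixel (PixelShuffle)
    upsampling stages with ReLU, recovering the input resolution
    (sketch; an encoder output stride of 8 is assumed)."""
    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.reduce = nn.Conv2d(in_channels, num_classes, kernel_size=1)
        stages = []
        for _ in range(3):  # each stage upsamples by a factor of 2
            stages += [nn.Conv2d(num_classes, num_classes * 4, kernel_size=1),
                       nn.PixelShuffle(upscale_factor=2),
                       nn.ReLU(inplace=True)]
        self.upsample = nn.Sequential(*stages)

    def forward(self, z):
        return self.upsample(self.reduce(z))
```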

4.2 Datasets and Evaluation Metrics

Datasets. In the semi-supervised setting, we evaluate the proposed method on PASCAL VOC [pascalvoc], consisting of 21 classes (background included) and three splits, training, validation and testing, with 1,464, 1,449 and 1,456 images respectively. Following common practice [DSRG, PSPnet], we augment the training set with the additional images of [hariharan2011semantic], resulting in 10,582 training images in total. Note that the pixel-level labels are only extracted from the original training set.

For semi-supervised domain adaptation with partially overlapping label spaces, we train on both Cityscapes [cityscapes] and CamVid [camvid]. Cityscapes is a finely annotated autonomous driving dataset with 19 classes, provided with three splits, training, validation and testing, with 2,975, 500 and 1,525 images respectively. CamVid contains 367 training, 101 validation and 233 testing images; although the dataset is originally labeled with 32 classes, we use the 11-class version [segnet]. For experiments over non-overlapping label spaces, we train on Cityscapes and SUN RGB-D [sun]. SUN RGB-D is an indoor segmentation dataset with 37 classes, containing two splits, training and testing, with 5,285 and 5,050 images respectively. Similar to [kalluri2018universal], we train on the 13-class version [handa2016understanding].

Evaluation Metrics. We report the results using mIoU (i.e., mean of class-wise intersection over union) for all the datasets.

4.3 Implementation Details

Training Settings.

The implementation is based on the PyTorch 1.1 [pytorch] framework. For optimization, we train for 50 epochs (100 for Cityscapes & SUN RGB-D, see Table 12) using SGD with momentum. During training, the learning rate is annealed following the poly learning rate policy, where at each iteration the base learning rate is multiplied by $\left(1 - \frac{iter}{max\_iter}\right)^{power}$ with $power = 0.9$.

For PASCAL VOC, we take fixed-size crops and apply random rescaling and random horizontal flipping. For Cityscapes, CamVid and SUN RGB-D, following [kalluri2018universal, hung2018adversarial], we resize the input images to fixed resolutions, without any data augmentation.

Reproducibility. All experiments are conducted on V100 GPUs.

Inference Settings. For PASCAL VOC, Cityscapes and SUN RGB-D, we report the results obtained on the validation set; for CamVid, we report results on the test set.

4.4 Semi-Supervised Setting

4.4.1 Ablation Studies

The proposed method consists of several types of perturbations and a variable number of auxiliary decoders. We thus start by studying the effect of the perturbation functions with different numbers of auxiliary decoders $K$, in order to provide additional insight into their individual performance and their combined effectiveness. Specifically, we measure the effect of different values of $K$ for a given perturbation type. We refer to this setting of our method as “CCT {perturbation type}”, with seven possible perturbations. We also measure the combined effect of all perturbations, and refer to it as “CCT full”. Additionally, “CCT full + ab-CE” indicates the usage of the annealed bootstrapped CE as the supervised loss function. We compare these variants to the baseline, in which the model is trained using only the labeled examples.

Method Pixel-level Labeled Examples Image-level Labeled Examples Val
WSSL [papandreou2015weakly] 1.5k 9k 64.6
GAIN [wheretolook] 1.5k 9k 60.5
MDC [dilated_cam] 1.5k 9k 65.7
DSRG [DSRG] 1.5k 9k 64.3
Souly et al. [souly2017semi] 1.5k 9k 65.8
FickleNet [lee2019ficklenet] 1.5k 9k 65.8
Souly et al. [souly2017semi] 1.5k - 64.1
Hung et al. [hung2018adversarial] 1.5k - 68.4
CCT 1k - 64.0
CCT 1.5k - 69.4
CCT 1.5k 9k 73.2
Table 1: Comparison with the state of the art. CCT performance on PASCAL VOC compared to other semi-supervised approaches.

CamVid. We carried out the ablation on CamVid with 20, 50 and 100 labels; the results are shown in Fig. 5. We find that each perturbation outperforms the baseline, with the most dramatic differences in the 20-label setting. We also observe, somewhat surprisingly, an insignificant overall performance gap among the different perturbations, confirming the effectiveness of enforcing the consistency over the hidden representations for semantic segmentation, and highlighting the versatility of CCT and its success with numerous perturbations. Increasing $K$ results in a modest improvement overall, with the smallest change for the perturbations that lack stochasticity. Interestingly, we also observe a slight improvement when combining all of the perturbations, indicating that the encoder is able to generate representations that are consistent over many perturbations, subsequently improving the overall performance. Additionally, gradually releasing the training signal using ab-CE further increases the performance, which confirms that overfitting the labeled examples can cause a significant drop in performance.

PASCAL VOC. In order to investigate the success of CCT on larger datasets, we conduct additional ablation experiments on PASCAL VOC using 1000 labeled examples. The results are summarized in Fig. 6. We see similar behavior, where the proposed method improves over the baseline with all perturbation types. The combined perturbations yield a small additional increase in performance. Furthermore, similar to CamVid, when using the ab-CE loss, we see a significant gain compared to CCT full.

Based on the conducted ablation studies, for the rest of the experiments we use the “CCT full” setting, with a reduced number of auxiliary decoders for the perturbations that lack stochasticity and for intermediate VAT given its high computational cost, and a larger number for the remaining perturbations; we refer to this setting as “CCT”.

4.4.2 Comparison to Previous Work

To further explore the effectiveness of our framework, we quantitatively compare it with previous semi-supervised semantic segmentation methods on PASCAL VOC. Table 1 compares CCT with other semi-supervised approaches. Our approach outperforms previous works relying on the same level of supervision, and even methods which exploit image-level labels. We also observe an increase of 3.8 points when using additional image-level labels, affirming the flexibility of CCT and the possibility of using it with different types of labels without any learning conflicts.

4.5 Semi-Supervised Domain Adaptation Setting

In real-world applications, we are often provided with pixel-level labels collected from various sources, and thus from distinct data distributions. To examine the effectiveness of CCT when applied to multiple domains with a variable degree of label overlap, we train our model simultaneously on two datasets: Cityscapes (CS) + CamVid (CVD) for partially overlapping labels, and Cityscapes + SUN RGB-D (SUN) for the disjoint case.

Method n=50 n=100
CS CVD Avg. CS CVD Avg.
Kalluri et al. [kalluri2018universal] 34.0 53.2 43.6 41.0 54.6 47.8
Baseline 31.2 40.0 35.6 37.3 34.4 35.9
CCT 35.0 53.7 44.4 40.1 55.7 47.9
Table 2: CCT applied to CS+CVD. CCT performance when simultaneously trained on two datasets with overlapping label spaces, which are Cityscapes (CS) and CamVid (CVD).

Cityscapes + CamVid. The results of CCT on the Cityscapes and CamVid datasets with 50 and 100 labeled examples are given in Table 2. Similar to the SSL setting, CCT significantly outperforms the baseline, in which the model is trained using only the labeled examples, with up to 12 points of improvement for n=100; we even see a modest increase compared to previous work. This confirms our hypothesis that enforcing a consistency over different datasets does indeed push the encoder to produce invariant representations across the different domains, and consequently increases the performance over the baseline while delivering similar results on each domain individually.

Method Labeled Examples CS SUN Avg.
SceneNet [mccormac2017scenenet] Full (5.3k) - 49.8 -
Kalluri et al. [kalluri2018universal] 1.5k 58.0 31.5 44.8
Baseline 1.5k 54.3 38.1 46.2
CCT 1.5k 58.8 45.5 52.1
Table 3: CCT applied to CS+SUN. CCT performance when trained on both the Cityscapes (CS) and SUN RGB-D (SUN) datasets, for the case of non-overlapping label spaces.

Cityscapes + SUN RGB-D. For cross-domain experiments where the two domains have distinct label spaces, we train on both Cityscapes and SUN RGB-D to demonstrate the capability of CCT to extract useful visual relationships and perform knowledge transfer between dissimilar domains, even in completely different settings. The results are shown in Table 3. Interestingly, despite the distribution mismatch between the datasets and the relatively high number of labeled examples (1.5k), CCT still provides a meaningful boost over the baseline, with a difference of 5.9 points, and of 7.3 points compared to previous work. This shows that, by enforcing a consistency of predictions on the unlabeled sets of the two datasets over different perturbations, we can extract additional training signal and enhance the representation learning of the encoder, even in the extreme case of non-overlapping label spaces, without any performance drop when an invariance of representations across both datasets is enforced at the level of the encoder's outputs.

5 Conclusion

In this work, we present cross-consistency training (CCT), a simple, efficient and flexible method for consistency-based semi-supervised semantic segmentation, yielding state-of-the-art results. For future work, a possible direction is exploring the usage of other perturbations applied at different levels within the segmentation network. It would also be interesting to adapt and examine the effectiveness of CCT in other visual tasks and learning settings, such as unsupervised domain adaptation.

Acknowledgements. This work was supported by the Randstad corporate research chair. We would also like to thank the Saclay-IA platform of Université Paris-Saclay and the Mésocentre computing center of CentraleSupélec and École Normale Supérieure Paris-Saclay for providing computational resources.

References

A Comparison with Traditional Consistency Training Methods

In this section, we present experiments that validate the observation that, for semantic segmentation, enforcing a consistency over different perturbations applied to the encoder's outputs rather than the inputs is more aligned with the cluster assumption. To this end, we compare the proposed method with traditional consistency-based SSL methods. Specifically, we conduct experiments using VAT [VAT] and Mean Teachers [MeanTeachers]. In VAT, at each training iteration, the unsupervised loss is computed as the KL-divergence between the model's predictions for an unlabeled input and for its adversarially perturbed version. For Mean Teachers, the discrepancy is measured using the Mean Squared Error (MSE) between the prediction of the model and that of an exponential moving average of its weights; in this case, the noise arises implicitly from the stochasticity of training with SGD.

Splits n=500 n=1000
Baseline 51.4 59.2
Mean Teachers 51.3 59.4
VAT 50.0 57.9
CCT 58.6 64.4
Table 4: CCT compared to traditional consistency training methods. We conduct an ablation study on PASCAL VOC, comparing the performance of the baseline to the proposed method CCT, VAT and Mean Teachers. $n$ represents the number of labeled examples.

The results are presented in Table 4. We see that applying the adversarial noise to the inputs with VAT results in lower performance than the baseline. When using Mean Teachers, in which the noise is not explicitly added to the inputs, we obtain performance similar to the baseline. These results confirm our observation that enforcing a consistency over perturbations applied to the hidden representations is more aligned with the cluster assumption, thus yielding better results.

B Additional Results and Evaluations

B.1 Distance Measures

In the experiments presented in Section 4.4, MSE was used as the distance measure for the unsupervised loss $\mathcal{L}_u$, which measures the discrepancy between the main and auxiliary predictions. In this section, we investigate the effectiveness of other distance measures between the output probability distributions. Specifically, we compare the performance of MSE to the KL-divergence and the JS-divergence. For an unlabeled example $\mathbf{x}^u_i$, we obtain a main prediction $\hat{y}^u = g(\mathbf{z}_i)$ with the main decoder and an auxiliary prediction $\hat{y}_a = g^k_a(\tilde{\mathbf{z}}^k_i)$ with a given auxiliary decoder. We compare the following distance measures:

$$d_{MSE}(\hat{y}^u, \hat{y}_a) = \frac{1}{HW} \sum_{h,w} \big\| \hat{y}^u_{hw} - \hat{y}_{a,hw} \big\|_2^2 \qquad (7)$$
$$d_{KL}(\hat{y}^u, \hat{y}_a) = \frac{1}{HW} \sum_{h,w} \mathrm{KL}\big( \hat{y}^u_{hw} \,\|\, \hat{y}_{a,hw} \big) \qquad (8)$$
$$d_{JS}(\hat{y}^u, \hat{y}_a) = \frac{1}{2} d_{KL}(\hat{y}^u, m) + \frac{1}{2} d_{KL}(\hat{y}_a, m), \quad m = \tfrac{1}{2}\big(\hat{y}^u + \hat{y}_a\big) \qquad (9)$$

where $\hat{y}^u_{hw}$ and $\hat{y}_{a,hw}$ refer to the output probability distributions at a given spatial location $(h, w)$. The results of the comparison are shown in Table 5.

Splits n=500 n=1000
Baseline 51.4 59.2
CCT KL 54.0 62.5
CCT JS 58.4 64.3
CCT MSE 58.6 64.4
Table 5: CCT with different distance measures. We compare the performance of MSE to the KL-divergence and the JS-divergence on PASCAL VOC dataset.

We observe similar performance with MSE and JS, while we only obtain 2.6 and 3.3 points of gain over the baseline for $n=500$ and $n=1000$ respectively when using KL. The lower performance of KL might be due to its asymmetric nature. With KL, the auxiliary decoders are heavily penalized for sharp but wrong predictions, thus pushing them to produce uniform and uncertain outputs, and reducing the amount of training signal that can be extracted from the unlabeled examples. However, with JS, which is a symmetrized and smoothed version of KL, we can bypass the zero-avoiding nature of the KL-divergence. Similarly, MSE can be seen as a multi-class Brier score [MixMatch], which is less sensitive to completely incorrect predictions, giving it similar properties to JS with a lower computational cost.
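A minimal sketch of the three distance measures over per-pixel probability maps of shape (N, C, H, W); the direction of the KL term is an assumption:

```python
import torch
import torch.nn.functional as F

def d_mse(p_main, p_aux):
    """Mean squared error between the two probability maps."""
    return F.mse_loss(p_aux, p_main)

def d_kl(p, q, eps=1e-8):
    """KL(p || q), summed over the class dimension and averaged over pixels."""
    return (p * (torch.log(p + eps) - torch.log(q + eps))).sum(dim=1).mean()

def d_js(p_main, p_aux, eps=1e-8):
    """Jensen-Shannon divergence: symmetrized, smoothed KL."""
    m = 0.5 * (p_main + p_aux)
    return 0.5 * d_kl(p_main, m, eps) + 0.5 * d_kl(p_aux, m, eps)
```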

B.2 Confidence Masking and Pairwise Loss

Confidence Masking. When training on the unlabeled examples, we use the main predictions as the targets for consistency training, which may result in a corrupted training signal when the predictions are uncertain. A possible way to avoid this is to mask the uncertain predictions: given a main prediction in the form of a probability distribution over the classes at each spatial location, we compute the unsupervised loss only over the pixels whose maximal probability is greater than a fixed threshold (e.g., 0.5).

Pairwise Loss. In CCT, we enforce the consistency of predictions only between the main and auxiliary decoders, without any pairwise consistency between the auxiliary predictions themselves. To investigate the effectiveness of enforcing such an additional pairwise consistency, we add an additional loss term $\mathcal{L}_{pair}$ to the total loss in Eq. 3 to penalize the auxiliary predictive variance:

$$\mathcal{L}_{pair} = \frac{1}{K} \sum_{k=1}^{K} \big\| \hat{y}^k_a - \bar{y}_a \big\|_2^2 \qquad (10)$$

with $\bar{y}_a$ as the mean of the auxiliary predictions $\hat{y}^k_a$. Given $K$ auxiliary decoders, the cost of computing $\mathcal{L}_{pair}$ grows with $K$. To reduce it, at each training iteration we compute $\mathcal{L}_{pair}$ only over a randomly chosen subset of the auxiliary predictions.
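A minimal sketch of this pairwise variance term computed over a random subset of the auxiliary predictions; the subset size is illustrative:

```python
import random
import torch

def pairwise_variance_loss(aux_probs, subset_size=4):
    """aux_probs: list of K auxiliary probability maps, each (N, C, H, W).
    Penalizes the variance of a random subset of predictions around their mean."""
    subset = random.sample(aux_probs, min(subset_size, len(aux_probs)))
    stacked = torch.stack(subset, dim=0)           # (k, N, C, H, W)
    mean = stacked.mean(dim=0, keepdim=True)
    return ((stacked - mean) ** 2).mean()
```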

Table 6 shows the results of the experiments when using CCT with confidence masking and with $\mathcal{L}_{pair}$. Interestingly, we do not observe any gain over CCT when using confidence masking, indicating that using the uncertain main predictions to enforce the consistency does not hinder the performance. Additionally, adding the pairwise loss term results in lower performance compared to CCT, with 3.0 and 3.2 points of difference in the two settings, indicating that $\mathcal{L}_{pair}$ can potentially compel the auxiliary decoders to produce similar predictions regardless of the applied perturbation, thus diminishing the representation learning of the encoder, and the performance of the segmentation network as a whole.

Splits n=500 n=1000
Baseline 51.4 59.2
CCT + confidence masking 58.4 63.3
CCT + $\mathcal{L}_{pair}$ 55.6 61.2
CCT 58.6 64.4
Table 6: CCT with confidence masking and $\mathcal{L}_{pair}$. The effect of adding confidence masking to the unsupervised loss and a pairwise loss between the auxiliary predictions, on the PASCAL VOC val set.

C Algorithm

The proposed Cross-Consistency training method can be summarized by the following Algorithm:

Input: labeled image $\mathbf{x}^l$, its pixel-level label $y$, and unlabeled image $\mathbf{x}^u$
Require: shared encoder $h$, main decoder $g$, auxiliary decoders $\{g^k_a\}_{k=1}^{K}$
1) Forward the labeled input through the encoder and main decoder: $\hat{y}^l = g(h(\mathbf{x}^l))$
2) Forward the unlabeled input through the shared encoder: $\mathbf{z} = h(\mathbf{x}^u)$
3) Generate the main decoder's prediction for $\mathbf{x}^u$: $\hat{y}^u = g(\mathbf{z})$
4) Generate the aux. decoders' predictions for $\mathbf{x}^u$:
for $k$ in $[1, K]$ do
      - Apply a given perturbation: $\tilde{\mathbf{z}}^k = p_r(\mathbf{z})$
      - Forward through the aux. decoder $g^k_a$: $\hat{y}^k_a = g^k_a(\tilde{\mathbf{z}}^k)$
      end for
5) Train the network:
      $\mathcal{L} = \mathcal{L}_s(\hat{y}^l, y) + \omega_u \, \mathcal{L}_u\big(\hat{y}^u, \{\hat{y}^k_a\}_{k=1}^{K}\big)$
Update the network parameters by back-propagating $\nabla \mathcal{L}$
Algorithm 1 Cross-Consistency Training (CCT).

D Further Investigation of The Cluster Assumption

The learned features of a CNN are generally more homogeneous than the inputs: at higher layers, the network learns to compose low-level features into semantically meaningful representations while discarding high-frequency information (e.g., texture). However, the learned features of a segmentation network seem to have a unique property: the class boundaries correspond to low-density regions, which is not observed in networks trained on other visual tasks (e.g., classification, object detection). See Fig. 7 for an illustration of this difference.

Figure 7: The smoothness of CNN features trained on different tasks. (a) Examples from the PASCAL VOC 2012 train set. (b) Results for a segmentation network. (c) Results for an object detection network. (d) Results for a classification network.

E Adversarial Distribution Alignment

When applying CCT over multiple domains, and to further reduce the discrepancy between the encoder's representations of the two domains (i.e., the empirical distribution mismatch measured by the $\mathcal{H}$-divergence [ben2010theory]), we investigate the addition of a discriminator branch $g_d$, which takes as input the encoder's representation $\mathbf{z}_i$ and predicts 0 for examples from $\mathcal{D}^{(1)}$ and 1 for examples from $\mathcal{D}^{(2)}$. Hence, we add the following adversarial loss to the total training loss:

$$\mathcal{L}_{adv} = \frac{1}{|\mathcal{D}^{(1)}|} \sum_{\mathbf{x}_i \in \mathcal{D}^{(1)}} \log\big(g_d(\mathbf{z}_i)\big) + \frac{1}{|\mathcal{D}^{(2)}|} \sum_{\mathbf{x}_i \in \mathcal{D}^{(2)}} \log\big(1 - g_d(\mathbf{z}_i)\big)$$

The encoder and the discriminator branch are competitors within a min-max framework, i.e., the training objective is a min-max game between $h$ and $g_d$, which can be directly optimized using a gradient reversal layer as in [ganin2014unsupervised]. The total loss in this case is:

$$\mathcal{L} = \mathcal{L}_s + \omega_u \, \mathcal{L}_u + \omega_{adv} \, \mathcal{L}_{adv} \qquad (11)$$
Method n=50 n=100
CS CVD Avg. CS CVD Avg.
Baseline 31.2 40.0 35.6 37.3 34.4 35.9
CCT 35.0 53.7 44.4 40.1 55.7 47.9
CCT + $\mathcal{L}_{adv}$ 35.3 49.2 42.2 37.7 52.8 45.2
Table 7: CCT applied to CS+CVD with the additional adversarial loss $\mathcal{L}_{adv}$.

For the discriminator branch, similar to [hung2018adversarial], we use a fully convolutional discriminator with a series of convolutions and Leaky ReLU non-linearities, as shown in Table 8. The outputs of $g_d$ have the same spatial size as the encoder's outputs.

Description Resolution channels
Conv
LeakyReLU
Conv
LeakyReLU
Conv
LeakyReLU
Conv
LeakyReLU
Conv
Table 8: Discriminator Branch. The added discriminator branch on top of the encoder, in order to further push towards an invariance of the encoder’s representations between the different domains.

The results are shown in Table 7. Surprisingly, adding the discriminator branch diminishes the performance of the segmentation network, hinting at possible learning conflicts between CCT and the adversarial loss.

F Multi-scale Inference

To further enhance the predictions of our segmentation network, we conduct additional evaluations on PASCAL VOC using multi-scale inference, simulating a situation similar to training, where we apply random rescaling, random cropping and random horizontal flips. For a given test image, we create 5 rescaled versions, each of which is also flipped horizontally, resulting in 10 versions of the test image. The model's predictions are computed for each version, rescaled to the original size, and then aggregated by pixel-wise average pooling. The final result is obtained by taking the argmax over the classes at each spatial location.
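A minimal sketch of this multi-scale, flipped inference; the scale values below are illustrative and not necessarily the paper's exact choices:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def multi_scale_inference(model, image, scales=(0.5, 0.75, 1.0, 1.25, 1.5)):
    """image: (1, 3, H, W). Averages per-pixel class probabilities over rescaled
    and horizontally flipped versions, then takes the argmax."""
    _, _, h, w = image.shape
    accumulated = 0.0
    for s in scales:
        scaled = F.interpolate(image, scale_factor=s, mode="bilinear",
                               align_corners=False)
        for flip in (False, True):
            inp = torch.flip(scaled, dims=[3]) if flip else scaled
            probs = torch.softmax(model(inp), dim=1)
            if flip:
                probs = torch.flip(probs, dims=[3])
            probs = F.interpolate(probs, size=(h, w), mode="bilinear",
                                  align_corners=False)
            accumulated = accumulated + probs
    return (accumulated / (2 * len(scales))).argmax(dim=1)  # (1, H, W)
```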

In Table 9, we report the results obtained with multi-scale inference.

n mIoU
CCT 1000 67.3 (+3.3)
CCT 1500 73.4 (+4)
CCT + Image-level labels 1500 75.1 (+2.9)
Table 9: CCT results with multi-scale inference. The mIoU when we apply multi-scale inference on PASCAL VOC val set.

G Virtual Adversarial Training (VAT)

Without label information in a semi-supervised setting, VAT [VAT] lends itself as a consistency regularization technique. It trains the output distribution to be isotropically smooth around each data point by selectively smoothing the model in its most anisotropic direction. In our case, we apply the adversarial perturbation to the encoder output $\mathbf{z}$. For a given auxiliary decoder $g^k_a$, we would like to compute the adversarial perturbation $r_{adv}$ that will alter its predictions the most. We start by sampling a Gaussian noise tensor $r$ of the same size as $\mathbf{z}$, and compute its gradient $grad_r$ with respect to the loss between the two predictions obtained with and without the injection of the noise (the KL-divergence is used as the distance measure $d$). $r_{adv}$ can then be obtained by normalizing and scaling $grad_r$ by a hyperparameter $\epsilon$. This can be written as follows:

$$r \sim \mathcal{N}(0, I) \qquad (12)$$
$$grad_r = \nabla_{r} \, d_{KL}\big( g^k_a(\mathbf{z}), \, g^k_a(\mathbf{z} + r) \big) \qquad (13)$$
$$r_{adv} = \epsilon \, \frac{grad_r}{\|grad_r\|_2} \qquad (14)$$

Finally, the perturbed input to the auxiliary decoder is $\tilde{\mathbf{z}} = \mathbf{z} + r_{adv}$. The main drawback of this method is that it requires multiple forward and backward passes at each training iteration to compute $r_{adv}$. In our case, the amount of computation needed is reduced given the small size of the auxiliary decoders.
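A minimal sketch of this procedure for a given auxiliary decoder; the noise scale xi and radius eps values are illustrative:

```python
import torch

def intermediate_vat_perturbation(aux_decoder, z, xi=1e-6, eps=2.0):
    """Approximate the perturbation of z (N, C, H, W) that changes the auxiliary
    decoder's prediction the most (one power-iteration step, as in VAT).
    xi and eps are illustrative values."""
    with torch.no_grad():
        target = torch.softmax(aux_decoder(z), dim=1)
    # Start from a small random direction.
    r = torch.randn_like(z)
    r = xi * r / (r.flatten(1).norm(dim=1).view(-1, 1, 1, 1) + 1e-8)
    r.requires_grad_(True)
    perturbed = torch.softmax(aux_decoder(z + r), dim=1)
    # KL between predictions with and without the injected noise.
    kl = (target * (torch.log(target + 1e-8)
                    - torch.log(perturbed + 1e-8))).sum(dim=1).mean()
    grad = torch.autograd.grad(kl, r)[0]
    # Normalize the gradient direction and scale it by eps.
    r_adv = eps * grad / (grad.flatten(1).norm(dim=1).view(-1, 1, 1, 1) + 1e-8)
    return r_adv.detach()
```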

H Dataset sizes

For the size of each split of the datasets used in our experiments, see Table 10.

Splits Train Val Test
PASCAL VOC 10582 1449 1456
Cityscapes 2975 500 1525
CamVid 367 101 233
SUN RGB-D 5285 - 5050
Table 10: Semantic Segmentation Datasets. The size of each split of the datasets used in the experiments.

I Further Experimental Details

For the experiments throughout the paper, we used a ResNet-50 and a PSP module [PSPnet] for the encoder. As for the decoders, we used an initial 1×1 convolution to adapt the depth to the number of classes $C$, followed by a series of sub-pixel convolutions [shi2016real] (i.e., PixelShuffle) to upsample the feature maps to the original size. For details, see Table 11.

Encoder Decoder
Description Resolution channels Description Resolution channels
ResNet 50 Conv
PSPModule [PSPnet] Conv
PixelShuffle
Conv
PixelShuffle
Conv
PixelShuffle
Table 11: Encoder-Decoder architecture. Showing the layer type, the number of the outputs channels and the spatial resolution.

Inference Settings. For PASCAL VOC, during the ablation studies reported in Fig. 6, in order to reduce the training time, we trained on smaller images: we resize the longer side of each image to a fixed size and randomly take square crops. For the comparisons with the state of the art, we use larger resized images and crops, and conduct the inference on the original-sized images. For the rest of the datasets, the evaluation is conducted at the same sizes as those used during training.

J Hyperparameters

In order to present a realistic evaluation of the proposed method, and following the practical considerations mentioned in [oliver2018realistic], we avoid any form of intensive hyperparameter search, be it over the perturbation functions, the model architecture or the training settings. We chose, by hand, hyperparameters that resulted in stable training; we expect, however, that better performance could be achieved with a comprehensive search. The hyperparameter settings used in the experiments are summarized in Table 12.

Training
SGD
Learning rate
Momentum
Weight Decay
Number of training epochs
PASCAL VOC 50
CamVid 50
Cityscapes & CamVid 50
Cityscapes & SUN RGB-D 100
Losses
Unsupervised loss
Ramp-up period for $\omega_u$ 0.1
Final weight 30
Weakly-supervised loss
Ramp-up period for $\omega_w$ 0.1
Final weight 0.4
Annealed Cross-Entropy loss
Ramp-up period for $\eta$ 0.5
Final threshold 0.9
Adversarial loss
Weight
Perturbation Functions
VAT 2.0
VAT
Dropout rate
Area of the dropped region 0.4
Drop threshold range
The uniform noise range
Table 12: Hyperparameters. The hyperparameter settings used in our experiments.

K Ramp-up functions

For the unsupervised loss in Eq. 2, the weighting function $\omega_u$ is gradually increased from 0 up to a fixed final weight $\lambda_u$. The rate of increase can follow many possible schedules. Fig. 8 shows different ramp-up schedules. For our experiments, following [TEnsembling], $\omega_u$ ramps up following an exp-schedule:

$$\omega_u(t) = \lambda_u \, \exp\!\left(-5 \left(1 - \frac{t}{t_r}\right)^2\right) \qquad (15)$$

with $t$ as the current training iteration and $t_r$ as the desired ramp-up length (e.g., the first 10% of training time). Similarly, the threshold $\eta$ in the ab-CE loss (Eq. 4) is gradually increased starting from $\frac{1}{C}$, with $C$ as the number of classes, up to a final threshold (e.g., 0.9) within a ramp-up period (e.g., the first 50% of training time). For $\eta$, we use a log-schedule to quickly increase it at the beginning of training:

$$\eta(t) = \left(1 - \exp\!\left(-5 \, \frac{t}{t_r}\right)\right)\left(\eta_{max} - \frac{1}{C}\right) + \frac{1}{C} \qquad (16)$$
Figure 8: Different ramp-up schedules.
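A minimal sketch of the two schedules; the exp-schedule follows Eq. 15, while the log-style schedule for the threshold is an assumption constrained only by its endpoints and its fast initial increase:

```python
import math

def exp_rampup(step, rampup_steps, final_weight):
    """Exp-schedule of Eq. 15: ramps omega_u from 0 up to final_weight."""
    if step >= rampup_steps:
        return final_weight
    t = step / max(1, rampup_steps)
    return final_weight * math.exp(-5.0 * (1.0 - t) ** 2)

def log_rampup_threshold(step, rampup_steps, num_classes, final_threshold=0.9):
    """Threshold eta for the ab-CE loss, increased from 1/C to final_threshold.
    The exact log-style shape used here is an assumption; only the endpoints
    and the fast initial increase are given in the text."""
    t = min(1.0, step / max(1, rampup_steps))
    alpha = 1.0 - math.exp(-5.0 * t)   # fast increase early in training
    return 1.0 / num_classes + alpha * (final_threshold - 1.0 / num_classes)
```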

L Computational Overhead

Decoders Input size GPU memory (MB) GPU time (ms)
Main Decoder 139 2.0
139 2.6
157 3.0
175 45.7
463 82.3
463 2.7
463 2.4
463 3.3
Main Decoder 457 4.0
520 4.7
520 5.2
584 149.5
1592 176.0
1592 4.7
1592 4.6
1592 7.1
Table 13: Computation and memory statistics. Comparisons between the main and auxiliary decoders with different perturbation functions. The channel numbers of the input feature maps is . The lower the values, the better.

In order to present a comparison of the computational overhead of the different types of auxiliary decoders, we report various computation and memory statistics in Table 13. We observe that for the majority of the auxiliary decoders, the GPU time is similar to that of the main decoder. However, we see a significant increase for the intermediate VAT decoders, given the multiple forward and backward passes required to compute the adversarial perturbation; the sampling procedures of some other perturbations also add to the GPU time. For this reason, we reduce the number of such decoders in our experiments. For the noise-based perturbation, for an input tensor with batch size $B$, instead of sampling a noise tensor of the same size, we sample a single noise tensor and apply it over the whole batch, which significantly reduces the computation time without impacting the performance.

M Change Log

v1 Initial Release. Paper accepted to CVPR 2020; the implementation will soon be made available.

N Qualitative Results

Pseudo-Labels

Fig. 9 shows some qualitative results of the generated pseudo pixel-level labels obtained using the available image-level labels. We observe that when considering regions with high attention scores (i.e., scores above the foreground threshold $\theta_{fg}$), the assigned classes correspond in most cases to true positives.

Figure 9: The generated pseudo pixel-level labels. Instances of the generated pseudo pixel-level labels from PASCAL VOC train set. The white regions correspond to the ignored pixels.

Predictions

Qualitative results of CCT on PASCAL VOC val images for different numbers of labeled examples $n$ are presented in Fig. 10.

Figure 10: CCT results. Semantic Segmentation Results on the PASCAL VOC val images.