Semantic segmentation is an essential pixel-wise classification task that has achieved remarkable success in recent years. However, training such models is known to be data-hungry, and the pixel-level labelling process is particularly costly and time-consuming [ouali2020semi]. To tackle this limitation, semi-supervised semantic segmentation has become an important research direction that has drawn growing attention [ouali2020semi, chen2021semi, ke2020guided]. This problem relies on a small set of pixel-level labelled images and a large set of unlabelled images, where both types of images are drawn from the same data distribution. The challenge is how to extract additional, useful training signal from the unlabelled images so that the model generalises beyond the small labelled set.
Current state-of-the-art (SOTA) semi-supervised semantic segmentation models are based on consistency learning, which enforces agreement between the outputs from different views of unlabelled images [ouali2020semi, zou2020pseudoseg, french2019semi, chen2021semi]. These different views can be obtained via perturbations applied to the input image with data augmentation [zou2020pseudoseg] or to the feature space with noise injection [ouali2020semi]. Another way of obtaining different views is network perturbation, which encourages similar predictions between multiple models trained from different initialisations, and has been shown to enable better consistency regularization than input image and feature perturbations [chen2021semi, ke2020guided]. One potential weakness of consistency learning is that it assumes accurate predictions for unlabelled images, such that the perturbation does not push the image features to the wrong side of the true (hidden) classification decision boundary. Unfortunately, in practice this assumption is not always met by SOTA methods, making the training signal of consistency learning methods potentially incorrect. This problem is exacerbated for consistency learning based on network perturbation, because incorrect predictions from one model deteriorate the training of the other model, and vice versa. Another consequence of these inaccurate predictions is that consistency learning methods relying on a “strict” cross-entropy (CE) loss can easily overfit prediction mistakes, which can lead to confirmation bias.
In this paper, we address the prediction accuracy problem of consistency-based methods by extending the mean teacher (MT) model [tarvainen2017mean, french2019semi, chen2021semi, ke2020guided] with a new auxiliary teacher, and by replacing MT’s mean squared error (MSE) loss with a stricter confidence-weighted CE loss (Conf-CE) that has better training convergence. These more accurate predictions enable the use of more challenging perturbations, combining input image, feature and network perturbations to improve the generalisation of consistency learning. Furthermore, we propose a new type of adversarial feature perturbation, called T-VAT, which learns the perturbation applied to the student model with virtual adversarial training [miyato2018virtual] computed from the teachers, instead of injecting different types of noise into the image features [ouali2020semi]. To summarise, our contributions are:
A new consistency-based semi-supervised semantic segmentation MT model designed to improve the segmentation accuracy of unlabelled training images with a new auxiliary teacher and the replacement of MT’s MSE loss by a stricter confidence-weighted CE loss (Conf-CE) that allows stronger convergence and overall better training accuracy;
A new challenging combination of input data, feature and network perturbations to improve model generalisation; and
A new type of feature perturbation, called T-VAT, based on adversarial noise learned from both teachers of our MT model and applied to the student model, which generates challenging noise to promote effective training of the student model.
Our experimental evaluation shows that our approach achieves the best results on Pascal VOC 2012 [everingham2015pascal]. Our approach also shows the best performance on Cityscapes [cordts2016Cityscapes].
2 Related Work
Below, we first discuss supervised semantic segmentation, then semi-supervised learning, and then we describe pseudo-labelling and consistency-based SSL methods.
Supervised semantic segmentation. Semantic segmentation has been addressed with fully convolutional and encoder-decoder networks [long2015fully, badrinarayanan2017segnet, minaee2021image] and extensions that explore: 1) multi-scale aspects of the image [dai2015convolutional, lin2016efficient], 2) pyramidal feature maps [zhao2017pyramid, ghiasi2016laplacian, he2019adaptive], 3) dilated convolutions [yu2015multi, chen2018encoder, chen2017rethinking], and 4) attention mechanisms [chen2016attention, li2019expectation]. SOTA semi-supervised semantic segmentation models rely on the supervised semantic segmentation models DeeplabV3+ [chen2017deeplab] and PSPNet [zhao2017pyramid] as backbone architectures.
Semi-supervised learning (SSL). SSL trains a model using both labelled and unlabelled images. Current SSL solutions are formulated based on three assumptions [van2020survey]: 1) smoothness: similar images have similar labels; 2) low-density: the decision boundary does not pass through high-density areas of the feature space; and 3) manifold: samples on the same low-dimensional manifold embedded in the feature space have the same label. SSL methods can be loosely classified into pseudo-label based SSL [berthelot2019mixmatch, sohn2020fixmatch, 9207304] and consistency-based SSL [laine2016temporal, tarvainen2017mean, polyak1992acceleration], with the former generally achieving lower accuracy than the latter. We believe this is because pseudo-label methods disregard part of the unlabelled training set during training, which can reduce their generalisation ability. Below, we focus on consistency-based SSL given its superior accuracy on public benchmarks.
Consistency-based SSL methods aim to enforce agreement between the predictions of perturbed unlabelled images, where perturbations can be applied to the input image [zou2020pseudoseg], to the feature representation [ouali2020semi], or to the network [chen2021semi, ke2020guided]. The effectiveness of consistency-based SSL depends on the prediction accuracy for the unlabelled images and on how much the perturbations challenge the model training. In general, more challenging perturbations improve generalisation, but if they are applied to inaccurate predictions, the consistency-based method can learn from wrong labels. Prediction accuracy can be improved in many ways, but a simple and effective one is model ensembling [laine2016temporal, tarvainen2017mean]. Regardless of how they are applied, perturbations tend to be more effective when they challenge the classification process by, for example, moving perturbed features closer to the true (but hidden) classification boundaries, as with virtual adversarial training (VAT) [miyato2018virtual].
Consistency-based semi-supervised semantic segmentation methods have shown more competitive results [french2019semi, ouali2020semi, chen2021semi] than pseudo-label approaches. Among the SOTA consistency-based semi-supervised semantic segmentation methods, PseudoSeg [zou2020pseudoseg] relies on a new pseudo-labelling strategy and data augmentation consistency training to calibrate pseudo-labels, but its dependence on the generally inaccurate class activation maps can lead to poor training performance. Cross-consistency training (CCT) [ouali2020semi] applies different types of feature perturbations to enforce the agreement between their semantic segmentation results and the segmentation from the non-perturbed feature. Although the feature perturbations used in CCT are effective, more targeted and accurate adversarial noise can be more helpful for the consistency regularization. Other methods have explored network perturbation [feng2020semi, ke2020guided, mendel2020semi, chen2021semi], where consistency is enforced between the responses of differently initialised models. Perturbation models depend on the ability of the models to produce accurate segmentation results, and as mentioned above, such ability can be improved with the use of model ensembling [tarvainen2017mean]. French et al. [french2019semi] explore model ensembling [tarvainen2017mean] together with network perturbation and input image perturbation [yun2019cutmix]. This is one of the closest methods to our proposal, but we add a more effective model ensembling with multiple mean teachers, and a new adversarial feature perturbation with VAT [miyato2018virtual] and challenging input image perturbation with CutMix [yun2019cutmix] and Zoom In/Out [lin2018multi, chen2016attention] data augmentation. 
Also compared with [french2019semi], the more accurate segmentation results produced by our multiple mean teachers allow us to train the model on unlabelled images with the CE loss instead of the MSE loss used in [french2019semi], providing better training convergence and accuracy.
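Before describing our model and training process, we introduce the dataset notation for semi-supervised semantic segmentation. We have a small labelled training set $\mathcal{D}_L = \{(\mathbf{x}_i, \mathbf{y}_i)\}_{i=1}^{|\mathcal{D}_L|}$, where $\mathbf{x}_i \in \mathcal{X} \subset \mathbb{R}^{H \times W \times C}$ is the input image of size $H \times W$ with $C$ colour channels, and $\mathbf{y}_i \in \{0,1\}^{H \times W \times K}$ is the segmentation map, with the number of visual classes denoted by $K$. We also have a large unlabelled training set $\mathcal{D}_U = \{\mathbf{x}_i\}_{i=1}^{|\mathcal{D}_U|}$, with $|\mathcal{D}_L| \ll |\mathcal{D}_U|$. These datasets are used to train our proposed MT model with an auxiliary teacher, described in Sec. 3.1. The training of our new MT model, which explores network, feature and input image perturbations and a strict Conf-CE loss, is described in Sec. 3.2.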
3.1 Multiple Mean Teachers and Student Models
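As explained in Sec. 1, we aim to improve the accuracy of the segmentation of unlabelled training images. To achieve that, we propose the inclusion of an auxiliary teacher, exploring the idea of a double ensembling process to improve segmentation accuracy, namely the ensemble of teacher models, each representing a temporal ensemble of the student model [tarvainen2017mean]. The teacher and student models have the same network structure, denoted by $f_{\theta}: \mathcal{X} \to \mathbb{R}^{H \times W \times K}$, where $\theta$ is the model parameter. This model is decomposed into an encoder $h_{\phi}: \mathcal{X} \to \mathcal{Z}$ and a decoder $g_{\psi}: \mathcal{Z} \to \mathbb{R}^{H \times W \times K}$, where $\mathcal{Z} \subseteq \mathbb{R}^{D}$ represents the feature space of dimension $D$. Hence, $f_{\theta}(\mathbf{x}) = g_{\psi}(h_{\phi}(\mathbf{x}))$, where $\theta = \{\phi, \psi\}$. The probability output of the network is obtained by applying the pixel-wise softmax function $\sigma(\cdot)$ over the $K$ classes, as in $p_{\theta}(\mathbf{x}) = \sigma(f_{\theta}(\mathbf{x})) \in [0,1]^{H \times W \times K}$. The multiple mean teacher-student model is represented with the respective parameter superscripts: $\theta^{t_1}$ and $\theta^{t_2}$ for the teachers, and $\theta^{s}$ for the student.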
3.2 Training with Multiple Perturbations and a Strict Confidence-weighted CE Loss
In this section, we present the training process of our new MT model using a confidence-weighted CE loss, which is optimised with perturbations to the network, feature representations and input images.
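Training. The full training loss for the student model is
$$\ell(\mathcal{D}_L, \mathcal{D}_U; \theta^{s}) = \ell_{s}(\mathcal{D}_L; \theta^{s}) + \lambda\, \ell_{u}(\mathcal{D}_U; \theta^{s}), \tag{1}$$
where the first loss is the supervised segmentation loss, defined as:
$$\ell_{s}(\mathcal{D}_L; \theta^{s}) = \frac{1}{|\mathcal{D}_L|\,|\Omega|} \sum_{(\mathbf{x}, \mathbf{y}) \in \mathcal{D}_L} \sum_{\omega \in \Omega} \ell_{ce}\big(\mathbf{y}(\omega), p_{\theta^{s}}(\mathbf{x})(\omega)\big), \tag{2}$$
where $\Omega$ is the image lattice of size $H \times W$, and $\ell_{ce}(\mathbf{y}(\omega), p_{\theta^{s}}(\mathbf{x})(\omega))$ denotes the CE loss between the annotation $\mathbf{y}$ and the segmentation prediction from $f_{\theta^{s}}$ at pixel address $\omega \in \Omega$. The second term in (1) is the consistency loss, given by the confidence-weighted CE loss (Conf-CE), with $\lambda$ weighting its contribution and its definition being as follows:
$$\ell_{u}(\mathcal{D}_U; \theta^{s}) = \frac{1}{|\mathcal{D}_U|\,|\Omega|} \sum_{\mathbf{x} \in \mathcal{D}_U} \sum_{\omega \in \Omega} w(\omega)\, \ell_{ce}\big(\hat{\mathbf{y}}(\omega), \mathbf{y}^{s}(\omega)\big), \tag{3}$$
where $\ell_{ce}(\cdot, \cdot)$ represents the CE loss, $\omega$ denotes the pixel address of the output lattice $\Omega$ of the segmentation map, $\hat{\mathbf{y}}(\omega)$ is the segmentation prediction from the teacher models at $\omega$, $\mathbf{y}^{s}(\omega)$ is the student model segmentation prediction at $\omega$, and $w(\omega)$ represents the segmentation prediction confidence from the teacher models at $\omega$, computed below from (4).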
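To make the Conf-CE term in (3) concrete, here is a minimal per-pixel sketch in plain Python. It assumes the teacher-ensemble soft map and the student's softmax output are given as per-pixel probability lists. It uses the teacher's arg max as the hard target and the maximum probability, zeroed below a threshold, as the weight $w(\omega)$, as described for the teacher ensemble in (4). The function name `conf_ce` and the threshold value are illustrative assumptions, not the authors' implementation.

```python
import math

def conf_ce(teacher_probs, student_probs, gamma=0.5, tiny=1e-12):
    """Confidence-weighted CE (Conf-CE) averaged over the pixels of one image.

    teacher_probs: per-pixel soft teacher-ensemble probabilities (lists of K floats).
    student_probs: per-pixel student probabilities, same layout.
    gamma: minimum confidence (illustrative value, not taken from the paper).
    """
    total = 0.0
    for p_t, p_s in zip(teacher_probs, student_probs):
        conf = max(p_t)                        # max_k of the teacher soft map
        k_hat = p_t.index(conf)                # arg max -> hard one-hot target
        w = conf if conf > gamma else 0.0      # w = max(p) * 1(max(p) > gamma)
        total += w * -math.log(max(p_s[k_hat], tiny))  # CE against a one-hot target
    return total / len(teacher_probs)
```

Pixels where the teachers are unsure contribute nothing, so the strict CE is only applied where the teacher ensemble is confident.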
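The network perturbation is obtained from the predictions of the mean teacher models and the student model. The soft segmentation map produced by the ensemble of the mean teachers is estimated as:
$$\tilde{\mathbf{y}} = \sigma\Big(\tfrac{1}{2}\big(f_{\theta^{t_1}}(\mathbf{x}) + f_{\theta^{t_2}}(\mathbf{x})\big)\Big), \tag{4}$$
where $\sigma(\cdot)$ denotes the pixel-wise softmax function. The hard segmentation prediction by the ensemble of teachers, denoted by $\hat{\mathbf{y}}$ in (3), is the one-hot representation of $\arg\max_{k} \tilde{\mathbf{y}}(\omega)_k$ computed from $\tilde{\mathbf{y}}$ in (4). The segmentation prediction confidence in (3) is computed with $w(\omega) = \max_{k} \tilde{\mathbf{y}}(\omega)_k \cdot \mathbb{1}\big(\max_{k} \tilde{\mathbf{y}}(\omega)_k > \gamma\big)$, where $\mathbb{1}(\cdot)$ denotes an indicator function and $\gamma$ is the minimum confidence for $w(\omega)$ to be larger than zero. Following the MT framework [tarvainen2017mean], while the student model is trained via stochastic gradient descent (SGD) to minimise the cost function in (1), the teacher models are updated with an exponential moving average (EMA) [tarvainen2017mean] of the student model and batch norm (BN) parameters [cai2021exponential], with:
$$\theta^{t_i}_{e} = \alpha\, \theta^{t_i}_{e-1} + (1 - \alpha)\, \theta^{s}_{e}, \tag{5}$$
where $i \in \{1, 2\}$, $e$ denotes the training epoch, and $\alpha$ controls the transfer weight between epochs. For the training of the teacher models, we update the parameters of only one of the two teachers at each training epoch.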
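The EMA update in (5), with only one teacher refreshed per epoch, can be sketched as follows. Parameters are modelled as a hypothetical dict of floats. The alternating choice of which teacher to update (`epoch % 2`) and the value of `alpha` are assumptions; the text only states that one of the two teachers is updated at each epoch.

```python
def ema_update(teacher, student, alpha=0.99):
    """theta_t <- alpha * theta_t + (1 - alpha) * theta_s, parameter by parameter."""
    return {name: alpha * teacher[name] + (1.0 - alpha) * student[name]
            for name in teacher}

def update_teachers(teachers, student, epoch, alpha=0.99):
    """Refresh only one of the two teachers this epoch (alternation is an assumption)."""
    i = epoch % 2
    teachers[i] = ema_update(teachers[i], student, alpha)
    return teachers
```

In a real network, the same rule would be applied to every weight tensor and, per the text, to the BN parameters as well.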
The feature perturbation consists of a challenging adversarial feature perturbation that is designed to violate the cluster, or low-density, assumption [ouali2020semi, van2020survey] by pushing the image features, computed from the model encoder, toward the classification boundaries in the feature space. One effective way to produce such adversarial feature noise is virtual adversarial training (VAT) [miyato2018virtual], which optimises a perturbation vector to maximise the divergence between the clean and the adversarially perturbed predictions. Current methods estimate the adversarial noise using the same single network to which the consistency loss is applied [ouali2020semi]; to the best of our knowledge, the use of VAT to perturb the training of MT in semi-supervised semantic segmentation is new. In an MT model, the feature perturbation could be estimated with the student model, but given that the student produces less accurate predictions than the teachers, this approach may not be conducive to effective training. Hence, we propose to estimate the adversarial noise using the more accurate teachers, and then apply this estimated noise to the features of the student model; we call this feature perturbation T-VAT. The student output to be used in the loss in (3) is $\mathbf{y}^{s} = \sigma\big(g_{\psi^{s}}(h_{\phi^{s}}(\mathbf{x}) + \mathbf{r}_{adv})\big)$, where the adversarial feature perturbation $\mathbf{r}_{adv}$ is estimated from the response of the ensemble of teacher models with:
$$\mathbf{r}_{adv} = \arg\max_{\mathbf{r}:\, \|\mathbf{r}\|_2 \le \epsilon} \mathsf{D}\Big(\tilde{\mathbf{y}},\ \sigma\Big(\tfrac{1}{2}\textstyle\sum_{i=1}^{2} g_{\psi^{t_i}}\big(h_{\phi^{t_i}}(\mathbf{x}) + \mathbf{r}\big)\Big)\Big), \tag{6}$$
where $\epsilon$ bounds the norm of the perturbation, $\tilde{\mathbf{y}}$ is defined in (4), and $\mathsf{D}(\cdot, \cdot)$ is the sum of the pixel-wise Kullback-Leibler (KL) divergence between the original and perturbed pixel predictions.
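The sketch below illustrates T-VAT with one power-iteration step of VAT [miyato2018virtual] for a single feature vector. To keep it self-contained, each teacher decoder is assumed to be linear ($g(\mathbf{z}) = W\mathbf{z}$), and the ensemble averages the teachers' logits. Both are simplifying assumptions. For a linear decoder, the gradient of the KL divergence with respect to the perturbation is $\bar{W}^{\top}(\mathbf{q} - \mathbf{p})$, where $\bar{W}$ is the average decoder, $\mathbf{p}$ the clean prediction and $\mathbf{q}$ the perturbed prediction, so no autodiff library is needed. The resulting noise is then added to the student's own features.

```python
import math
import random

def softmax(v):
    m = max(v)
    e = [math.exp(x - m) for x in v]
    s = sum(e)
    return [x / s for x in e]

def matvec(W, v):
    return [sum(w_kj * v_j for w_kj, v_j in zip(row, v)) for row in W]

def normalise(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v] if n > 0 else list(v)

def teacher_logits(decoders, feats, r):
    """Average of the teachers' logits W_i (z_i + r) (linear decoders assumed)."""
    outs = [matvec(W, [z + d for z, d in zip(f, r)]) for W, f in zip(decoders, feats)]
    return [sum(col) / len(outs) for col in zip(*outs)]

def t_vat(decoders, feats, eps=1.0, xi=1e-3, seed=0):
    """One VAT power-iteration step computed on the teacher ensemble."""
    rng = random.Random(seed)
    dim, n = len(feats[0]), len(decoders)
    p = softmax(teacher_logits(decoders, feats, [0.0] * dim))   # clean prediction
    d = normalise([rng.gauss(0.0, 1.0) for _ in range(dim)])      # random direction
    q = softmax(teacher_logits(decoders, feats, [xi * x for x in d]))
    diff = [qk - pk for qk, pk in zip(q, p)]                       # dKL(p||q)/dlogits
    num_classes = len(p)
    grad = [sum(sum(W[k][j] for W in decoders) / n * diff[k] for k in range(num_classes))
            for j in range(dim)]                                   # chain rule via W_avg
    return [eps * g for g in normalise(grad)]

def student_probs(W_s, z_s, r_adv):
    """Student prediction on its own features shifted by the teacher-derived noise."""
    return softmax(matvec(W_s, [z + r for z, r in zip(z_s, r_adv)]))
```

The KL divergence has zero gradient at $\mathbf{r} = \mathbf{0}$, so VAT evaluates the gradient at a small random offset `xi * d`. The direction of that gradient approximates the direction in which the teachers' prediction is most sensitive.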
The input image perturbation is based on weak-strong augmentation pairs [lee2013pseudo]. Weak augmentations (image flipping, cropping and scaling) are applied to the images processed by the teacher models. In addition to those weak augmentations, strong augmentations [chen2021semi, ke2020guided] (e.g., colour jitter, random grayscale and blur) are applied to the images fed to the student model to improve the overall generalisation capability.
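On top of the strong augmentations, we also apply the CutMix [yun2019cutmix] and Zoom In/Out [lin2018multi, chen2016attention] data augmentations to the student model images. As defined in [french2019semi], the CutMix augmentation is achieved by applying a binary mask $\mathbf{m} \in \{0,1\}^{H \times W}$ that combines two images using the function $\mathrm{mix}(\mathbf{x}_a, \mathbf{x}_b, \mathbf{m}) = \mathbf{m} \odot \mathbf{x}_a + (\mathbf{1} - \mathbf{m}) \odot \mathbf{x}_b$. We can apply CutMix by combining two input images and minimising the consistency loss (3) with the teacher prediction (4) computed on the mixed image [chen2021semi] (referred to as CutMix before prediction). Alternatively, we can minimise the consistency loss using the CutMix combination of the images and of their separate teacher predictions, as in
$$\hat{\mathbf{y}} = \mathrm{onehot}\big(\mathrm{mix}(\tilde{\mathbf{y}}_a, \tilde{\mathbf{y}}_b, \mathbf{m})\big), \qquad \mathbf{y}^{s} = p_{\theta^{s}}\big(\mathrm{mix}(\mathbf{x}_a, \mathbf{x}_b, \mathbf{m})\big), \tag{7}$$
with $\tilde{\mathbf{y}}_a, \tilde{\mathbf{y}}_b$ the teacher predictions for $\mathbf{x}_a, \mathbf{x}_b$ defined in (4), $\mathrm{onehot}(\cdot)$ the pixel-wise one-hot encoding of the arg max, and $w(\omega)$ in (3) computed from $\mathrm{mix}(\tilde{\mathbf{y}}_a, \tilde{\mathbf{y}}_b, \mathbf{m})$. The perturbation used in (7) is referred to as CutMix after prediction. We argue that it produces a cleaner target for the consistency loss than CutMix before prediction, because the target does not contain the artifacts introduced by predicting on the CutMix images. The Zoom In/Out augmentation [lin2018multi, chen2016attention] is defined by the function $\mathrm{zoom}(\mathbf{x}, s)$, which zooms in or out of the image using the scale parameter $s$. For the zoom in/out augmentation, the consistency loss in (3) uses the ensemble results of the teacher models as follows:
$$\hat{\mathbf{y}} = \mathrm{onehot}\big(\mathrm{zoom}(\tilde{\mathbf{y}}, s)\big), \qquad \mathbf{y}^{s} = p_{\theta^{s}}\big(\mathrm{zoom}(\mathbf{x}, s)\big), \tag{8}$$
where $\tilde{\mathbf{y}}$ is defined in (4).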
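CutMix after prediction, as in (7), mixes the two teacher predictions computed on the unmixed images instead of predicting on the mixed image. A minimal sketch on nested lists follows. The single-box mask with a fixed area ratio and a uniformly sampled location is an assumption for illustration, since box sampling strategies vary across implementations.

```python
import random

def box_mask(h, w, area_ratio=0.5, rng=None):
    """Binary H x W mask: 0 inside one random box (taken from image b), 1 elsewhere."""
    rng = rng or random.Random(0)
    bh = max(1, int(h * area_ratio ** 0.5))
    bw = max(1, int(w * area_ratio ** 0.5))
    top, left = rng.randint(0, h - bh), rng.randint(0, w - bw)
    return [[0 if top <= i < top + bh and left <= j < left + bw else 1
             for j in range(w)] for i in range(h)]

def mix(a, b, m):
    """mix(a, b, m) = m * a + (1 - m) * b; with a binary m this is per-pixel selection,
    so pixel values can be RGB triples or K-class probability vectors alike."""
    return [[pa if mij else pb for pa, pb, mij in zip(ra, rb, rm)]
            for ra, rb, rm in zip(a, b, m)]

def cutmix_after_prediction(x_a, x_b, y_a, y_b, m):
    """Student input is the mixed image; the target mixes the teachers' predictions
    y_a, y_b that were computed on the *unmixed* images x_a, x_b."""
    return mix(x_a, x_b, m), mix(y_a, y_b, m)
```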
The consistency loss in (3) of previous consistency-based semi-supervised semantic segmentation models [french2019semi, chen2021semi] is usually based on the L2 (MSE) loss. The L2 loss is known to be robust, which is advantageous when dealing with the noisy predictions produced by consistency-based methods. However, it is also known to converge poorly and to possibly lead to vanishing gradients. Given the reliability of the segmentation predictions produced by our extended MT model, we instead use the more effective cross-entropy (CE) loss. Following the strategy applied in self-training approaches [yuan2021simple, he2021re, yang2021st++], this loss is constrained to regions of high-confidence segmentation results, represented by $w(\omega)$ in (3).
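A single-pixel comparison illustrates the vanishing-gradient argument. The sketch computes the gradient with respect to the student logits of two losses:

- the CE loss, whose gradient is $\mathbf{p} - \mathbf{y}$;
- a squared error on the softmax probabilities, whose gradient is $(\operatorname{diag}(\mathbf{p}) - \mathbf{p}\mathbf{p}^{\top})\, 2(\mathbf{p} - \mathbf{y})$.

Treating MSE as a sum over classes of squared probability errors is an assumption about its exact form.

```python
import math

def softmax(v):
    m = max(v)
    e = [math.exp(x - m) for x in v]
    s = sum(e)
    return [x / s for x in e]

def ce_grad(logits, target):
    """d/dlogits of -log softmax(logits)[target] = p - onehot(target)."""
    p = softmax(logits)
    return [pk - (1.0 if k == target else 0.0) for k, pk in enumerate(p)]

def mse_grad(logits, target):
    """d/dlogits of sum_k (p_k - y_k)^2, via the softmax Jacobian diag(p) - p p^T."""
    p = softmax(logits)
    y = [1.0 if k == target else 0.0 for k in range(len(p))]
    g_p = [2.0 * (pk - yk) for pk, yk in zip(p, y)]     # dL/dp
    inner = sum(pk * gk for pk, gk in zip(p, g_p))       # p^T dL/dp
    return [pj * (g_p[j] - inner) for j, pj in enumerate(p)]
```

Take logits `[8.0, 0.0, 0.0]` and a teacher label of class 1, so the student is confidently wrong. The CE gradient has norm close to $\sqrt{2}$. The MSE gradient, by contrast, is scaled down by the saturated softmax Jacobian to the order of $10^{-3}$.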
Inference. The semantic segmentation of a test image is obtained from the teachers, as described in (4).
We first introduce the experimental setting used to evaluate our approach. In Sec. 4.2, we evaluate our approach on both datasets under different partition protocols by comparing it with the supervised baselines and previous SOTA approaches. In Sec. 4.3, we report results on the official labelled set of Pascal VOC 2012, including a few-supervision study. Lastly, we present an ablation study in Sec. 4.4, qualitative results in Sec. 4.5, and an extension that exploits image-level data in Sec. 4.6.
4.1 Experimental Setup
|French et al.* [french2019semi]||68.90||70.70||72.46||74.49||72.56||72.69||74.25||75.89|
|Ours (sliding eval.)||ResNet50||75.76||76.92||77.64|
Datasets. Pascal VOC 2012 [everingham2015pascal] is a standard dataset for evaluating semi-supervised segmentation models. It contains 21 classes (20 object classes and background), providing 1,464 images with pixel-level labels for training, 1,449 images for validation and 1,456 for testing. Following previous papers [ouali2020semi, chen2021semi], we adopt the additional labels from [hariharan2011semantic], so our entire training set contains 10,582 images. Note that the labels from [hariharan2011semantic] are of lower quality and may contain noise. Cityscapes [cordts2016Cityscapes] is an urban driving scene dataset consisting of 2,975 images for training, 500 for validation and 1,525 for testing. Each image has a resolution of 1024 × 2048, and there are 19 classes in total.
Following [ke2020guided, chen2021semi], we randomly split the full labelled sets of Pascal VOC 2012 and Cityscapes with different ratios. We also provide results based on the official labelled set (with 1,464 images for Pascal VOC 2012) [zou2020pseudoseg, yuan2021simple].
Implementation details. Most results are based on using our method to train the DeeplabV3+ [chen2017deeplab] model. We initialise the backbone from an ImageNet pre-trained checkpoint, and the segmentation heads are initialised randomly. Following previous papers [ouali2020semi, he2021re, chen2021semi], we use the polynomial learning-rate decay $lr = lr_{base} \cdot \big(1 - \frac{iter}{max\_iter}\big)^{0.9}$. We also test our method on PSPNet [ouali2020semi, lai2021semi] to show the generality of our approach.
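The polynomial decay can be written as a one-line schedule. This sketch assumes the commonly used power of 0.9 and clamps the iteration count so the base of the power never goes negative. The function name and the example values are illustrative.

```python
def poly_lr(base_lr, it, max_iter, power=0.9):
    """Polynomial learning-rate decay: base_lr * (1 - it / max_iter) ** power."""
    frac = min(it, max_iter) / max_iter     # clamp so the base never goes negative
    return base_lr * (1.0 - frac) ** power

# Example: learning rates at a few points of a 1000-iteration schedule.
schedule = [poly_lr(0.01, it, 1000) for it in (0, 250, 500, 750, 1000)]
```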
During training, we apply data augmentation with random scaling in and random flipping to both labelled and unlabelled images. On Pascal VOC 2012, following [chen2021semi], we crop images to pixels for DeeplabV3+ and train for epochs with the base learning rate set to and batch size for both labelled and unlabelled images. For PSPNet, we follow [he2021re], crop images to pixels and use batch size . On Cityscapes, due to hardware limitations, we crop images to pixels and train for epochs with the base learning rate set to and batch size for both architectures. Because the teachers’ predictions are unstable at the early stage of training, we apply a Gaussian ramp-up to the consistency loss weight in (1). For both datasets, the supervised loss is the cross-entropy loss.
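The section does not give the exact ramp-up formula. A common choice, used in temporal ensembling [laine2016temporal] and the original MT [tarvainen2017mean], is $\lambda(t) = \lambda_{max}\exp(-5(1 - t)^2)$, with $t$ the clipped training progress. The sketch below assumes that form; the names are hypothetical.

```python
import math

def gaussian_rampup(step, rampup_length, max_weight=1.0):
    """Consistency weight max_weight * exp(-5 * (1 - t)^2), t = step / rampup_length in [0, 1]."""
    if rampup_length == 0:
        return max_weight
    t = min(max(step / rampup_length, 0.0), 1.0)
    return max_weight * math.exp(-5.0 * (1.0 - t) ** 2)
```

The weight starts near $e^{-5} \approx 0.0067$ of its maximum, which keeps the unreliable early teacher predictions from dominating the loss, and reaches the maximum once `step >= rampup_length`.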
Evaluation metrics. Following previous papers [chen2021semi, ke2020guided], we report the mean Intersection-over-Union (mIoU) on the validation set of both datasets. All results are based on single-scale inference.
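mIoU is computed from a confusion matrix accumulated over all validation pixels. The sketch below works on flat lists of class ids. Two conventions are assumptions here:

- it skips pixels with the ignore label 255, following the usual Pascal VOC convention;
- it skips classes that appear in neither the predictions nor the labels, which is one of several possible conventions.

```python
def mean_iou(preds, labels, num_classes, ignore_index=255):
    """mIoU from flat lists of predicted and ground-truth class ids."""
    conf = [[0] * num_classes for _ in range(num_classes)]  # conf[true][pred]
    for p, t in zip(preds, labels):
        if t == ignore_index:
            continue
        conf[t][p] += 1
    ious = []
    for c in range(num_classes):
        tp = conf[c][c]
        fp = sum(conf[r][c] for r in range(num_classes)) - tp
        fn = sum(conf[c]) - tp
        denom = tp + fp + fn
        if denom > 0:            # skip classes absent from both preds and labels
            ious.append(tp / denom)
    return sum(ious) / len(ious)
```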
4.2 Results on Different Partition Protocols
In this section, following [ke2020guided, chen2021semi], we evaluate our method by sub-sampling the datasets with different ratios for the labelled and unlabelled sets. Specifically, on Pascal VOC 2012, we split the entire training set (with 10,582 images) with the ratios of , , , for the labelled set. On Cityscapes, we similarly use the ratios , , for the labelled set. All the partition protocols are from [chen2021semi].
Improvements over Supervised Baselines. We first compare our results with supervised learning (trained with the same ratio of labelled images) using the DeepLabV3+ architecture, and illustrate the improvements in Fig. 3. This figure shows that our approach successfully exploits unlabelled data, with a large performance boost. On Pascal VOC 2012, Fig. 3-(a) shows that our approach outperforms the supervised baseline by a large gap, especially for small labelled partitions. Specifically, with the ratio (with labelled images), our approach surpasses the supervised baseline by and for ResNet50 and ResNet101, respectively. In the other settings, our approach also shows consistent improvements between and for ratios , , and . On Cityscapes, following [chen2021semi], we evaluate our final results with sliding evaluation. This is a well-known mechanism for high-resolution images, in which a window slides over the input image and the extracted patches are fed into the network separately. Fig. 3-(b) shows that our approach surpasses the supervised baseline by and for ResNet50 and ResNet101 for all protocols.
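Sliding evaluation can be sketched in three steps: tile the image with overlapping crops, predict each crop, and average the class scores where crops overlap. The crop size, the stride, the score averaging and the `predict_fn` callback are hypothetical choices for illustration; the text only states that the patches are fed to the network separately.

```python
def sliding_windows(height, width, crop, stride):
    """Top-left corners of crop x crop windows covering the image; the last window
    on each axis is snapped to the border so every pixel is covered."""
    def starts(size):
        if size <= crop:
            return [0]
        pos = list(range(0, size - crop + 1, stride))
        if pos[-1] != size - crop:
            pos.append(size - crop)
        return pos
    return [(y, x) for y in starts(height) for x in starts(width)]

def sliding_predict(image, crop, stride, predict_fn, num_classes):
    """image: H x W nested list; predict_fn(patch) -> h x w x K scores (hypothetical model)."""
    h, w = len(image), len(image[0])
    acc = [[[0.0] * num_classes for _ in range(w)] for _ in range(h)]
    cnt = [[0] * w for _ in range(h)]
    for y0, x0 in sliding_windows(h, w, crop, stride):
        patch = [row[x0:x0 + crop] for row in image[y0:y0 + crop]]
        for i, srow in enumerate(predict_fn(patch)):
            for j, s in enumerate(srow):
                cnt[y0 + i][x0 + j] += 1
                for k in range(num_classes):
                    acc[y0 + i][x0 + j][k] += s[k]
    # average the scores of all windows that covered each pixel
    return [[[v / cnt[i][j] for v in acc[i][j]] for j in range(w)] for i in range(h)]
```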
Comparison to SOTA. For Pascal VOC 2012, Tab. 1 shows that our approach is the best for all partition protocols, using DeepLabV3+ with both ResNet50 and ResNet101 backbones. Compared to French et al. [french2019semi], which is considered a strong baseline, our approach improves by to in all cases. Our approach also provides a significant boost over the original MT in all cases. The results also show that our approach is better than the current SOTA CPS [chen2021semi] by around to for all cases. In some partition protocols, our approach trained with fewer labelled samples outperforms CPS [chen2021semi]. For example, our approach trained with labelled images outperforms CPS [chen2021semi] trained with labelled images using both backbones. This demonstrates that our perturbed and strict mean teachers yield more accurate results than the other methods compared. On Cityscapes, we use the settings from [lai2021semi] and show results that use similar settings (in terms of image resolution, batch size, and supervised loss function) for a fair comparison. Our approach outperforms CAC [lai2021semi] by nearly and for the and partition protocols. The sliding evaluation also boosts our performance by approximately for all ratios, which shows that the sliding process improves the performance of our approach on high-resolution images.
4.3 Results on Official Labelled Set of Pascal VOC 2012
In this section, we report results based on the official high-quality labelled data of Pascal VOC 2012, comparing our approach under both the PSPNet and DeepLabV3+ architectures. Table 3 shows that our approach outperforms the SOTA methods for each architecture and backbone setting, by a large gap. For PSPNet, our approach outperforms DARS [he2021re] by mIoU and CCT [ouali2020semi] by . For DeepLabV3+, our approach outperforms Yuan et al. [yuan2021simple] by mIoU. Yuan et al. [yuan2021simple] use only a single network to produce the pseudo-labels in a self-training manner. This result therefore shows the value of our teacher and student networks and of our perturbation strategies.
Few-supervision study. We subsample the 1,464 official labelled images into partitions of , , and images. The remaining data are combined with the augmented set [hariharan2011semantic] (around images) to form the unlabelled data in these experiments. Table 4 shows that our approach yields the best mIoU results in all cases. For example, our approach outperforms CPS [chen2021semi] by for labelled images. We argue that our effective perturbations enable better generalisation of our model under these limited labelled data conditions.
|Yuan et al. [yuan2021simple]||DeeplabV3+||ResNet-101||75.00|
|French et al. [french2019semi]||69.84||68.36||63.20||55.58|
4.4 Ablation Study
In this section, we study the roles of the confidence-weighted CE loss (Conf-CE), the T-VAT perturbation and the auxiliary teacher (AT) in our approach. All experiments are run on Pascal VOC 2012 under the ratio, and we use DeeplabV3+ to evaluate our results.
Table 5 shows the improvement contributed by each of these components. The baseline is MT [tarvainen2017mean] trained with the input image perturbations from Sec. 3.2 and the MSE loss. Replacing MSE with our Conf-CE increases mIoU by and for ResNet50 and ResNet101. The T-VAT perturbation yields nearly improvement, showing the effectiveness of our proposed feature perturbation. The more accurate predictions of the auxiliary teacher allow a further improvement of and for the two backbones.
T-VAT perturbation. Fig. 4-(a) shows the performance under different types of feature perturbations, namely: original (no feature perturbation), uniform (feature noise randomly sampled from a uniform distribution), vat (VAT noise learned from the student model), and t-vat (T-VAT noise learned from the teacher models as in (6)). Our proposed T-VAT outperforms the uniform and VAT perturbations by and , respectively. Additionally, it surpasses original by .
CutMix before or after prediction. Fig. 4-(b) shows the validation mIoU during training when CutMix is applied before or after prediction, as described in (7). Applying CutMix before prediction may introduce extra semantic complexity and yield inaccurate pseudo-labels, leading to ineffective optimisation. In contrast, applying CutMix after prediction improves mIoU by around .
Average gradient magnitudes of MSE and Conf-CE. Fig. 5 shows the average gradient magnitudes per layer of the student model when trained with the MSE and Conf-CE losses to optimise the consistency in (3), during the last stages of training ( out of 80 epochs). Our Conf-CE shows larger gradient magnitudes than MSE [ke2020guided, chen2021semi, french2019semi], suggesting that it can lead to stronger convergence than MSE.
4.5 Qualitative Results
Figure 6 shows the supervised, student and mean-teacher results on Pascal VOC 2012 images. The supervised results in column (c) display the worst accuracy, caused by the limited labelled training samples. Our final results in column (e) significantly improve on the baseline, which demonstrates the effectiveness of our approach. Moreover, our final results in column (e) are also more accurate than the student results in column (d).
4.6 Combining Pseudo-label and Consistency-based losses
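Some consistency-based methods [ouali2020semi, huang2018weakly, lee2019ficklenet, lai2021semi] also include a pseudo-labelling loss, which uses class activation maps (CAM) from the model to generate pseudo-labels $\mathbf{y}^{cam}$. We follow a similar strategy to CCT [ouali2020semi] and add the CAM loss below to the cost function in (1) to train the student model:
$$\ell_{cam}(\mathcal{D}_U; \theta^{s}) = \frac{1}{|\mathcal{D}_U|\,|\Omega|} \sum_{\mathbf{x} \in \mathcal{D}_U} \sum_{\omega \in \Omega} \ell_{ce}\big(\mathbf{y}^{cam}(\omega), p_{\theta^{s}}(\mathbf{x})(\omega)\big). \tag{9}$$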
We run experiments based on the 1,464 official labelled images of Pascal VOC 2012, with the additional images used to minimise the CAM loss in (9). Results in Tab. 6 show that our method outperforms all previous works that use a similar strategy [ouali2020semi, lai2021semi, zou2020pseudoseg]. Moreover, adding the CAM loss in (9) to our cost function boosts the performance by and for the two backbones with the DeeplabV3+ architecture, compared with our previous results in Tab. 3.
In this paper, we proposed a new consistency-based semi-supervised semantic segmentation method. Among our contributions, we introduced a new MT model, based on multiple mean teachers and a student network. It produces more accurate predictions for unlabelled images, which facilitates consistency learning and allows us to use a stricter confidence-weighted CE loss than the original MT’s MSE loss. These more accurate predictions also allowed us to use a challenging combination of network, feature and input image perturbations that showed better generalisation. Furthermore, we proposed a new adversarial feature perturbation, called T-VAT, that further improved the generalisation of our approach. Our method outperforms previous methods on Pascal VOC 2012 and Cityscapes, establishing a new SOTA for semi-supervised semantic segmentation.
Regarding the limitations of our model, it can be argued that the strict Conf-CE loss has the potential to overfit the remaining prediction mistakes, so we will focus on improving the robustness of the Conf-CE loss. Another limitation that we plan to address is to work on an approach that can handle high-resolution images without using the time-consuming sliding evaluation.