Recently, semi-supervised learning (SSL) has been extensively studied to improve the generalization ability of deep neural networks for visual recognition by utilizing limited labelled images and massive unlabelled images. To leverage the unlabelled data, current SSL methods are usually based on a common density-based cluster assumption: samples lying in the same high-density region are likely to belong to the same class [chapelle2005semi, ICT2019]. An equivalent formulation is the low-density separation assumption, which states that the decision boundary should not cross high-density regions but should instead lie in low-density regions. Current methods mainly enforce the low-density separation assumption by encouraging invariant predictions under perturbations around each unlabelled data point, or consistent predictions for nearby samples. These include consistency-regularization based methods [laine2016temporal, tarvainen2017mean, ICT2019, berthelot2019mixmatch] and pseudo-label based methods [lee2013pseudo, shi2018transductive, iscen2019label].
In addition, the density-peak assumption [Rodriguez1492] states that high-density samples are more likely to be the center of a cluster. Samples with high density therefore encode more representative information about their cluster, which is a valuable clue for semi-supervised learning. However, current methods have neither considered such density information explicitly nor exploited it in depth. Moreover, the performance of current SSL frameworks is mainly determined by two aspects: feature learning and pseudo-label generation for unlabelled samples. When learning the feature embedding, current methods mainly rely on each single data sample and ignore the abundant neighborhood information, which is helpful for learning more discriminative features. For pseudo-label generation, existing methods often either directly use the current model predictions [lee2013pseudo], which can yield inaccurate pseudo labels, or perform label propagation [iscen2019label], which incurs a high matrix computation cost and is difficult to train end to end with the feature learning part.
Motivated by the above observations, this paper proposes a novel density-aware graph based SSL framework. By building a density-aware graph, the neighborhood information of each data sample can be easily utilized. More importantly, the feature learning part and the label propagation part can be trained end to end in the proposed framework. Furthermore, to better leverage the density information, we explicitly incorporate it into each of these two parts.
Specifically, given the labelled and unlabelled data samples, we first build a density-aware graph and define the density of each node in the graph. Then, for feature embedding learning, rather than relying only on each single data sample, we propose to aggregate neighborhood features to enhance the target features, a strategy that has proven very useful in modern Graph Neural Networks (GNNs) [scarselli2008graph, kipf2016semi, mne_iccv19]. However, in current aggregation schemes, the aggregation weights are often characterized only by the feature similarity between the target node and its neighbors. In that case, when the target node is equally similar to two neighbors belonging to two different clusters, it assigns the same weight to both. Motivated by the aforementioned density-peak assumption, we propose a novel Density-aware Neighborhood Aggregation (DNA) scheme. Concretely, besides the feature similarity, we also take the neighbor density into account when calculating the aggregation weights, so that higher-density neighbors receive higher importance. A simple illustration of this strategy is given in Fig. 1(a).
To generate pseudo labels for unlabelled samples more efficiently, we follow the basic label propagation scheme, which propagates labels from labelled samples to unlabelled samples. However, we do not perform label propagation through the linear system solver used in [iscen2019label, zhu2002learning], which incurs a high matrix computation cost and must run offline during training. Inspired again by the density assumption, we argue that, for an unlabelled sample, a pseudo label derived from neighboring samples with higher density is more likely to be correct than one derived from neighbors with lower density. Based on this insight, we further propose a novel Density-ascending Path-based Label Propagation (DPLP) module. Specifically, for each sample, we construct a density-ascending path along which the sample densities are in ascending order, and perform label propagation within this path. This is efficient and can be trained online together with feature learning in an end-to-end fashion. A graphical illustration is given in Fig. 1(b).
In summary, the main contributions of this work include: (1) We propose a novel density-aware graph based SSL framework. To the best of our knowledge, this is the first work which exploits the density information explicitly for deep semi-supervised visual recognition; (2) A new Density-aware Neighborhood Aggregation module and Density-ascending Path-based Label Propagation module are designed for better feature learning and pseudo label generation respectively. The two modules are integrated into a unified framework and can be end-to-end trained; (3) Extensive experiments demonstrate the effectiveness of the proposed framework, which significantly outperforms the current state-of-the-art methods.
2 Related Works
Consistency-regularization for SSL. These methods usually apply a consistency loss on the unlabelled data, which enforces invariant predictions under perturbations of the unlabelled data. For example, the Π-model [laine2016temporal] uses a consistency loss between the outputs of a network on random perturbations of the same image, while temporal ensembling [laine2016temporal] applies consistency constraints between the output of the current network and the temporal average of outputs during training. The mean teacher (MT) method [tarvainen2017mean] replaces output averaging by averaging of network parameters. To utilize the structural information among unlabelled data points, [shi2018transductive] applies a Min-Max Feature regularization loss that encourages networks to learn features with better between-class separability and within-class compactness. Similarly, Luo et al. [luo2018smooth] use a contrastive loss to enforce consistent predictions for neighboring points while pushing non-neighbors apart from each other. Although these methods exploit neighborhood and density information, they do so only in the form of regularization terms or loss functions. By contrast, our method aggregates neighborhood features to enhance the target feature in an explicit, density-aware manner.
Pseudo-labeling for SSL. To leverage unlabelled data, pseudo-label based methods assign pseudo labels to the unlabelled samples based on the labelled samples and then train the network in a fully supervised way. To generate precise pseudo labels, Lee [lee2013pseudo] uses the current network predictions with high confidence as pseudo labels for unlabelled examples. Shi et al. [shi2018transductive] use the network class predictions as hard labels for the unlabelled samples and introduce an uncertainty weight. More recently, Iscen et al. [iscen2019label] employ graph-based label propagation to infer labels for unlabelled samples. However, they perform label propagation offline through a linear system solver on the training set, which has a high computational cost and cannot be trained end to end. In this work, we propose to construct a density-ascending path and perform label propagation within this path, which is much more efficient and can be trained end to end.
Neighborhood Aggregation in GNNs.
Modern GNNs broadly follow a neighborhood aggregation scheme, in which each node aggregates the feature vectors of its neighbors to obtain a more representative feature vector [xu2018how]. Different GNNs vary in how they perform neighborhood aggregation. For example, Kipf and Welling [kipf2016semi] use mean-pooling based neighborhood aggregation, and Hamilton et al. [hamilton2017inductive] propose three aggregator functions: a mean aggregator, a max-pooling aggregator and an LSTM aggregator. Inspired by the self-attention mechanism, Veličković et al. [velivckovic2017graph] propose an attention-based aggregation architecture that learns adaptive aggregation weights. Li et al. [mne_iccv19] extend attention-based aggregation by supervising the attention weights with node-wise class relationships. We find that most of these neighborhood aggregation methods consider only the feature similarity between the target sample and its neighbors when defining the aggregation weights. However, density information is a very important clue for SSL. Therefore, besides feature similarity, this paper also takes the neighborhood density into consideration and proposes a novel density-aware neighborhood aggregation scheme.
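In semi-supervised learning, we are usually given a small set of labelled training samples $\mathcal{D}_l=\{(x_i, y_i)\}_{i=1}^{N_l}$ and a large set of unlabelled training samples $\mathcal{D}_u=\{x_i\}_{i=N_l+1}^{N_l+N_u}$, where $N_l$ and $N_u$ are the numbers of labelled and unlabelled samples respectively, and usually $N_l \ll N_u$. The goal of SSL is to leverage both $\mathcal{D}_l$ and $\mathcal{D}_u$ to train a better and more generalizable recognition model. Formally, let $N=N_l+N_u$ be the total number of training samples, $f_\theta$ be the feature extractor and $g_\phi$ be the classifier. Current deep SSL methods adopt a similar optimization formulation:

$$\min_{\theta,\phi}\ \sum_{i=1}^{N} \ell\big(g_\phi(f_\theta(x_i)),\, y_i\big) + \mathcal{R}(\theta,\phi), \tag{1}$$

where $\ell$ is a loss function such as the cross-entropy loss. For labelled data, $y_i$ is the ground-truth label, while for unlabelled data, $y_i$ can be a pseudo label. $\mathcal{R}$ is a regularization term that encourages the model to generalize better to unseen data. Inspired by [grandvalet2005semi, NIPS1991_440], we use the following regularization term:

$$\mathcal{R} = -\frac{1}{|\mathcal{B}|}\sum_{i\in\mathcal{B}}\sum_{c=1}^{C} p_{i,c}\log p_{i,c} + \sum_{c=1}^{C}\bar{p}_c\log\bar{p}_c, \qquad \bar{p}_c=\frac{1}{|\mathcal{B}|}\sum_{i\in\mathcal{B}} p_{i,c}, \tag{2}$$

where $C$ is the number of classes, $\mathcal{B}$ is the current training batch, $p_{i,c}$ is the softmax prediction of the model for category $c$ on sample $x_i$, and $\bar{p}_c$ is the mean softmax prediction for category $c$ across the batch. The first term is the entropy minimization objective defined in [grandvalet2005semi], which encourages the model output to have low entropy. The second term is the negative entropy of the mean prediction, and minimizing it encourages the model to predict each class with equal frequency on average [NIPS1991_440].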
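As a minimal NumPy sketch of the batch regularizer described above (entropy minimization plus a class-balance term), assuming `logits` holds the classifier outputs for one training batch. The function names are hypothetical, and the small `eps` inside the logarithms is our own addition for numerical safety.

```python
import numpy as np

def softmax(logits):
    # Numerically stable row-wise softmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

def ssl_regularizer(logits, eps=1e-12):
    """Entropy-minimization term plus class-balance term for one batch.

    logits: array of shape (B, C) for the current training batch.
    """
    p = softmax(logits)  # (B, C) per-sample class probabilities
    # Term 1: mean per-sample entropy; minimizing it yields confident outputs.
    entropy = -np.mean(np.sum(p * np.log(p + eps), axis=1))
    # Term 2: negative entropy of the batch-mean prediction p_bar;
    # minimizing it pushes p_bar toward the uniform distribution.
    p_bar = p.mean(axis=0)  # (C,)
    balance = np.sum(p_bar * np.log(p_bar + eps))
    return entropy + balance
```

For a batch of confident predictions spread evenly over the classes, the first term is close to zero and the second reaches its minimum of -log C. Collapsing all predictions onto a single class drives the second term back up toward zero, so the regularizer penalizes that collapse.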
Overview. In this work, we introduce a unified framework for joint feature learning and label propagation on a density-aware graph for semi-supervised visual recognition, which can be trained in an end-to-end fashion. A graphical overview of the proposed framework is depicted in Fig. 2. First, we construct a $K$-nearest neighbor graph and define the density of each node in the graph. Then, based on the density-aware graph, we perform feature embedding learning and pseudo-label generation simultaneously for each node. Specifically, for each target node, we sample its neighborhood sub-graph and learn the feature embedding on this sub-graph by incorporating neighborhood information with Density-aware Neighborhood Aggregation (DNA). For pseudo-label generation, we propose Density-ascending Path-based Label Propagation (DPLP), i.e., we build a density-ascending path for each node in the graph and propagate labels from labelled nodes to unlabelled nodes within this path.
4.1 Density-Aware Graph
Given a pre-trained feature extractor and classifier, we first extract the feature vectors and label predictions of all training samples and organize them into a Feature Bank and a Label Bank respectively, which can later be accessed by index. Based on the features in the Feature Bank, we construct the global $K$-nearest neighbor affinity graph ($K = 64$ in this work). We then define the density of each node $i$ in the graph as:

$$\rho_i = \frac{1}{K}\sum_{j\in\mathcal{N}_K(i)} \hat{\mathbf{z}}_i^{\top}\hat{\mathbf{z}}_j, \tag{3}$$

where $\mathcal{N}_K(i)$ is the set of $K$-nearest neighbors of node $i$, and $\hat{\mathbf{z}}_i$, $\hat{\mathbf{z}}_j$ are the L2-normalized feature embeddings of nodes $i$ and $j$. Intuitively, this formula expresses the density of each node as the average similarity between the node and its neighbors. Note that other definitions of density can also be considered, such as the number of neighbors whose similarity with the target node is greater than a predefined threshold [Rodriguez1492], but this is not the focus of this work. We refer to the graph equipped with these densities as the Density-Aware Graph.
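A brute-force NumPy sketch of the global K-NN graph and the density defined above, assuming the Feature Bank is stored as a row-wise array. It computes all pairwise similarities, which is only practical for small N; at full dataset scale a dedicated nearest-neighbor search library would replace the `argsort` step. The function names are hypothetical.

```python
import numpy as np

def build_knn_graph(features, K):
    """Return L2-normalized features and each node's K nearest neighbors.

    Similarity is the inner product of L2-normalized features (cosine
    similarity); a node is never listed as its own neighbor.
    """
    z = features / np.linalg.norm(features, axis=1, keepdims=True)
    sim = z @ z.T                      # (N, N) cosine similarities
    np.fill_diagonal(sim, -np.inf)     # exclude self-matches
    nbrs = np.argsort(-sim, axis=1)[:, :K]  # top-K most similar per row
    return z, nbrs

def node_density(z, nbrs):
    """rho_i = mean cosine similarity between node i and its K neighbors."""
    # z[nbrs] has shape (N, K, D); einsum takes the per-pair inner products.
    sims = np.einsum("nd,nkd->nk", z, z[nbrs])
    return sims.mean(axis=1)
```

On a toy set with three nearby points and one outlier, the outlier receives the lowest density, as the definition intends.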
4.2 Density-Aware Neighborhood Aggregation
Current methods mainly focus on regularizing the output to be locally smooth around each data sample, but learn the feature embedding based only on each single data sample. That is, they have not fully explored neighborhood information for feature learning, which has been demonstrated to be very useful in other tasks [Zhuang_2019_ICCV, mne_iccv19, Sabokrou_2019_ICCV, han2019once, zhou2019dup]. Motivated by this observation, we propose to enhance the feature embedding of each target sample by aggregating its neighborhood features. Specifically, for each sample in the current training batch, we first pass it through the backbone to get its feature vector and locate the corresponding node in the global density-aware graph (referred to as the target node). We then sample the neighborhood nodes of the target node from the global graph to obtain a sub-graph.
Sub-Graph Construction. To construct the sub-graph, we follow [mne_iccv19] and organize the neighbor nodes as a Tree-Graph. In particular, we take the target node as the root node and build the tree iteratively. At each iteration, we extend every leaf node by adding its $k$ nearest neighbors from the global graph as new leaf nodes. The Tree-Graph grows until it reaches a predefined depth $h$. Based on the Tree-Graph, we then iteratively perform feature aggregation among connected nodes and gradually propagate information within the sub-graph from the leaf nodes to the target node. In the experimental part, we study the effect of the number of sampled neighbors $k$ and the graph depth $h$.
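The Tree-Graph construction and the leaf-to-root pass can be sketched as follows, assuming `nbrs[i]` lists node i's global neighbors sorted by similarity (as a K-NN search would return them). We assume that a node may appear more than once in the tree and that every level uses the same `k`. The `mean_combine` function is only a placeholder for the learned aggregation discussed next, and all names are hypothetical.

```python
import numpy as np

def build_tree_subgraph(root, nbrs, k, h):
    """Return the Tree-Graph as a list of levels (level 0 = [root]).

    Level t holds k**t node ids; the children of the node at position p in
    level t-1 sit at positions p*k ... p*k + k - 1 in level t.
    """
    levels = [[root]]
    for _ in range(h):
        # Expand every current leaf with its k nearest global neighbors.
        levels.append([int(c) for leaf in levels[-1] for c in nbrs[leaf][:k]])
    return levels

def mean_combine(parent, children):
    # Placeholder aggregator: average the parent feature with its children.
    return (parent + children.sum(axis=0)) / (1 + len(children))

def aggregate_tree(levels, feats, k, combine=mean_combine):
    """Propagate features level by level from the leaves to the root."""
    cur = [feats[n] for n in levels[-1]]  # features of the deepest level
    for t in range(len(levels) - 2, -1, -1):
        cur = [combine(feats[p], np.stack(cur[i * k:(i + 1) * k]))
               for i, p in enumerate(levels[t])]
    return cur[0]  # enhanced feature of the target (root) node
```

With `h = 0` the tree is only the root and the target feature is returned unchanged, which corresponds to the no-aggregation baseline.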
Density-aware Neighborhood Aggregation (DNA). After sampling the sub-graph, we improve the target feature embedding by aggregating its neighbors' embeddings in the sub-graph. General aggregation strategies such as mean-pooling and max-pooling cannot determine which neighbors are more important. To aggregate features adaptively, [velivckovic2017graph, mne_iccv19] proposed attention-based architectures for neighborhood aggregation, whose aggregation weights are characterized by the feature similarity between the target node and its neighbors. Formally, the adaptive aggregation can be written as:

$$\tilde{\mathbf{z}}_i = \mathbf{W}\sum_{j\in\mathcal{N}(i)} \alpha_{ij}\,\mathbf{z}_j, \tag{4}$$

where $\mathcal{N}(i)$ denotes the neighbors of node $i$ in the sub-graph, $\mathbf{W}$ contains extra parameters for feature transformation, and $\alpha_{ij}$ is the aggregation weight denoting how much neighbor node $j$ contributes to the target node $i$:

$$\alpha_{ij} = \frac{\exp\big(\hat{\mathbf{z}}_i^{\top}\hat{\mathbf{z}}_j\big)}{\sum_{l\in\mathcal{N}(i)}\exp\big(\hat{\mathbf{z}}_i^{\top}\hat{\mathbf{z}}_l\big)}. \tag{5}$$

In detail, we first L2-normalize the features ($\hat{\mathbf{z}} = \mathbf{z}/\|\mathbf{z}\|_2$), then define the similarity as the inner product of the L2-normalized features, and obtain the final aggregation weights by normalizing the similarities with the softmax operation.
However, in SSL we find that computing the aggregation weights from feature similarity alone is sub-optimal. With such weights, if the target node is equally similar to two neighbors that belong to two different clusters, the two neighbors receive the same aggregation weight. In fact, based on the density-peak assumption, nodes with higher density are closer to the cluster center and more discriminative. Therefore, besides the feature similarity, we also incorporate the density of each neighbor when calculating the aggregation weights. By default, we simply combine the feature similarity and the density by element-wise summation and rewrite Eq. 5 as follows:

$$\alpha_{ij} = \frac{\exp\big(\hat{\mathbf{z}}_i^{\top}\hat{\mathbf{z}}_j + \rho_j\big)}{\sum_{l\in\mathcal{N}(i)}\exp\big(\hat{\mathbf{z}}_i^{\top}\hat{\mathbf{z}}_l + \rho_l\big)}, \tag{6}$$

where $\hat{\mathbf{z}}$ denotes L2-normalized features and $\rho_j$ is the density of neighbor node $j$ in the density-aware graph (Sec. 4.1).
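To make the difference between similarity-only weights and density-aware weights (similarity plus neighbor density, passed through a softmax) concrete, here is a NumPy sketch for a single target node. It assumes that the neighbor features and densities have already been gathered from the sub-graph. The function names and the plain matrix `W`, which stands in for the learned feature transformation, are hypothetical.

```python
import numpy as np

def aggregation_weights(z_target, z_nbrs, rho_nbrs=None):
    """Softmax aggregation weights of one target node over its neighbors.

    z_target: (D,) target feature; z_nbrs: (M, D) neighbor features;
    rho_nbrs: (M,) neighbor densities, or None for similarity-only weights.
    """
    zt = z_target / np.linalg.norm(z_target)
    zn = z_nbrs / np.linalg.norm(z_nbrs, axis=1, keepdims=True)
    score = zn @ zt                  # cosine similarity to each neighbor
    if rho_nbrs is not None:
        score = score + rho_nbrs     # element-wise sum with neighbor density
    score = score - score.max()      # numerical stability before exp
    w = np.exp(score)
    return w / w.sum()

def aggregate(z_target, z_nbrs, W, rho_nbrs=None):
    # Weighted sum of neighbor features, then a linear transformation W.
    alpha = aggregation_weights(z_target, z_nbrs, rho_nbrs)
    return W @ (alpha @ z_nbrs)
```

For example, with target [1, 0] and neighbors [1, 1] and [1, -1], similarity alone yields weights [0.5, 0.5]. Adding densities [0.9, 0.1] raises the first neighbor's weight to about 0.69, which is exactly the tie-breaking behavior that motivates DNA.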
4.3 Density-ascending Path for Label Propagation
Density-Ascending Path Construction. We construct a density-ascending path for each node in the global density-aware graph. More specifically, for node $v_i$, we initialize the path as the one-element set $\{v_i\}$. We then add the nearest neighbor node whose density is greater than that of the previously added node. This process is repeated until the distance between the candidate node and the last added node exceeds a predefined threshold.

For notational clarity, we define the density-ascending path as $\mathcal{P}_i = \{p_0, p_1, \ldots, p_m\}$ with $p_0 = v_i$, where $p_t$ is the node added to the path at the $t$-th step. Supposing the node added at the $(t-1)$-th step is $p_{t-1}$, the node to be added next is its nearest neighbor with higher density:

$$p_t = \mathop{\arg\min}_{v_j\in\mathcal{N}_K(p_{t-1}),\ \rho_{j} > \rho_{p_{t-1}}} d\big(p_{t-1}, v_j\big), \tag{7}$$

where $d(\cdot,\cdot)$ is a distance metric function. By default, we use the Euclidean (L2) distance in this work, i.e., $d(v_a, v_b) = \|\mathbf{z}_a - \mathbf{z}_b\|_2$. To alleviate the influence of irrelevant neighbors, we define a threshold $\tau$ that terminates the growth of the density-ascending path, i.e., every pair of consecutive nodes $(p_{t-1}, p_t)$ in $\mathcal{P}_i$ satisfies $d(p_{t-1}, p_t) \le \tau$.
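A direct sketch of the path construction, assuming that the candidates at each step are restricted to the last node's K nearest neighbors in the global graph. We also assume the path stops at a local density peak, i.e., a node with no higher-density neighbor. Here `tau` plays the role of the distance threshold, and the function name is hypothetical.

```python
import numpy as np

def density_ascending_path(start, feats, rho, nbrs, tau):
    """Grow a path from `start`: at each step move to the nearest (Euclidean)
    neighbor of the last node whose density is higher; stop when that distance
    exceeds tau or no higher-density neighbor exists."""
    path = [start]
    cur = start
    while True:
        cands = [j for j in nbrs[cur] if rho[j] > rho[cur]]
        if not cands:
            break  # assumed stopping rule: local density peak reached
        dists = [np.linalg.norm(feats[cur] - feats[j]) for j in cands]
        best = int(np.argmin(dists))
        if dists[best] > tau:
            break  # next step would exceed the distance threshold
        cur = int(cands[best])
        path.append(cur)
    return path
```

Because densities strictly increase along the path, no node can be visited twice, so the loop always terminates.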
Before entering the Density-ascending Path-based Label Propagation, we first introduce the following assumption.
Assumption: The labelled nodes with higher density in the density-ascending path are more likely to provide correct pseudo labels than those with lower density.
Explanation: As stated in the cluster assumption, samples with the same label are more likely to lie in the high-density region. Meanwhile, the density-peak assumption [Rodriguez1492] shows that high-density nodes are more likely to be the center of a cluster. Thus for one unlabelled node, the labelled nodes with higher density are more representative and more likely to provide correct pseudo labels than the ones having lower density in the same density-ascending path.
Based on the density-ascending path, we now introduce our Density-ascending Path-based Label Propagation (DPLP) algorithm. Specifically, we first sort all labelled nodes by density in descending order. Then, for each labelled node, we construct a density-ascending path and use the label of the highest-density labelled node on this path to update the Label Bank entries of all unlabelled nodes in the path. For each remaining unlabelled node, we also construct a density-ascending path and update its Label Bank entry with the label of the highest-density labelled node on the path. The detailed procedure is summarized in Alg. 1.
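A compact sketch of the two DPLP passes: LP-L over labelled seeds, followed by LP-LU over the remaining unlabelled nodes. It assumes the density-ascending paths have been precomputed with the rule above and stored in a dict. Alg. 1 is not reproduced here, so two details are our own assumptions, and both are marked in the comments. First, an unlabelled entry written by a higher-density seed is not overwritten by later seeds. Second, a node whose path contains no labelled node keeps its current Label Bank entry. The function name is hypothetical.

```python
def dplp(paths, rho, labelled, label_bank):
    """Return an updated copy of the Label Bank.

    paths:      dict node -> list of nodes on its density-ascending path
    rho:        indexable node densities
    labelled:   dict labelled node -> ground-truth label
    label_bank: dict node -> current (predicted) label
    """
    bank = dict(label_bank)
    bank.update(labelled)  # labelled entries always hold ground truth
    updated = set()

    def top_label(path):
        # Label of the highest-density labelled node on the path, or None.
        srcs = [n for n in path if n in labelled]
        return labelled[max(srcs, key=lambda n: rho[n])] if srcs else None

    # LP-L: labelled seeds in descending order of density.
    for seed in sorted(labelled, key=lambda n: rho[n], reverse=True):
        lab = top_label(paths[seed])  # never None: the seed itself is labelled
        for n in paths[seed]:
            # Assumption: keep a label already set by an earlier, denser seed.
            if n not in labelled and n not in updated:
                bank[n] = lab
                updated.add(n)

    # LP-LU: each remaining unlabelled node uses its own path.
    for node, path in paths.items():
        if node in labelled or node in updated:
            continue
        lab = top_label(path)
        if lab is not None:  # assumption: otherwise keep the existing entry
            bank[node] = lab
    return bank
```

Processing seeds in descending density order means that, under the first assumption, the densest labelled node reaching an unlabelled node determines its pseudo label.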
4.4 Density Aware Graph-based SSL Framework
The above Density-aware Neighborhood Aggregation and Density-ascending Path-based Label Propagation are integrated into a unified Density Aware Graph-based SSL framework (dubbed “DAG”), which can be trained in an end-to-end fashion. We summarize the whole training process in Alg. 2. Specifically, at the beginning of each epoch, we update the Feature Bank and Label Bank with the latest feature extractor and classifier. We then construct a new global density-aware graph from the current Feature Bank and perform density-ascending path label propagation based on this graph and the current Label Bank. Note that we always use ground-truth labels for the Label Bank entries of labelled samples. After this preparation, we train the framework by sampling batches of images and labels. Each batch of images is fed into the feature extractor, and the output features are enhanced by density-aware neighborhood aggregation on sub-graphs before being fed into the classifier.
| No. of labelled images | 1000 (CIFAR10) | 4000 (CIFAR100) |
|---|---|---|
| NA w/o density | 9.46 | 40.76 |
| NA with density | 9.18 | 40.33 |
5.1 Experimental Setup
Dataset Setup. To verify the effectiveness of our method, we conduct experiments on three popular datasets, namely CIFAR10 [krizhevsky2009learning], CIFAR100 [krizhevsky2009learning] and Mini-ImageNet [vinyals2016matching]. CIFAR10 and CIFAR100 both contain 50k training images and 10k test images with resolution 32×32, drawn from 10 and 100 classes respectively. Following the standard SSL setting, for CIFAR10 we perform experiments with 50, 100, 200 and 400 labelled images per class, and for CIFAR100 with 40 and 100 labelled images per class. Mini-ImageNet is a subset of ImageNet [deng2009imagenet] and consists of 100 classes with 600 images per class at resolution 84×84. Following the setting of [iscen2019label], we randomly assign 500 images from each class to the training set and 100 images to the test set, so the training and test sets contain 50k and 10k images respectively. We then experiment with 40 and 100 labelled images per class. We perform ablation studies on CIFAR10 and CIFAR100 with an independent validation set of 5k instances sampled from the training set, as in [athiwaratkun2018there, oliver2018realistic], and compare with state-of-the-art methods on the standard test set.
Implementations. For the model architecture, we use the same 13-layer CNN as [laine2016temporal, tarvainen2017mean, iscen2019label] for CIFAR10/100, and ResNet-18 [he2016deep] for Mini-ImageNet following [iscen2019label]. We use the SGD optimizer with an initial learning rate of 0.1 for CIFAR10/100 and 0.2 for Mini-ImageNet. We decay the learning rate by a factor of 0.1 at epochs 250 and 350, and obtain the final model after 400 epochs. We augment training images with random cropping and horizontal flipping as in [laine2016temporal, tarvainen2017mean]. Inspired by [zhang2018mixup, ICT2019, berthelot2019mixmatch], we also employ the Mixup strategy [zhang2018mixup] to augment the training samples, which gives us a stronger baseline. To build the Feature Bank and Label Bank at the first epoch, we use a model pre-trained only on the labelled training samples. During testing, for each test sample we directly construct the neighborhood sub-graph by retrieving neighbors from the Feature Bank built in the training stage.
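For reference, here is a minimal NumPy sketch of standard Mixup [zhang2018mixup]. A batch is mixed with a random permutation of itself using a weight drawn from Beta(alpha, alpha), and the one-hot or soft labels are mixed with the same weight. The value of `alpha` and the way Mixup interacts with pseudo labels in this work are not specified here, so both are left as assumptions. The function name is hypothetical.

```python
import numpy as np

def mixup(x, y, alpha, rng):
    """Mix a batch with a shuffled copy of itself (Mixup).

    x: (B, ...) inputs; y: (B, C) one-hot or soft labels;
    alpha: Beta(alpha, alpha) parameter (assumed hyperparameter);
    rng: a numpy.random.Generator.
    """
    lam = rng.beta(alpha, alpha)        # mixing weight in [0, 1]
    perm = rng.permutation(len(x))      # random partner for each sample
    x_mix = lam * x + (1.0 - lam) * x[perm]
    y_mix = lam * y + (1.0 - lam) * y[perm]
    return x_mix, y_mix, lam
```

Because each mixed label is a convex combination of two valid label distributions, it still sums to one and can be used directly with a cross-entropy loss on soft targets.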
5.2 Ablation Studies
5.2.1 Density-Aware Neighborhood Aggregation
| No. of labelled images | 1000 (CIFAR10) | 4000 (CIFAR100) |
|---|---|---|
| h = 0 (baseline) | 10.96 | 43.34 |
| h = 1 | 9.18 | 40.33 |
| h = 2 | 9.20 | 39.60 |
| h = 3 | 9.26 | 40.28 |
Effectiveness of Neighborhood Aggregation. We propose to aggregate neighboring features to enhance the feature of the target instance and thereby improve semi-supervised image classification. To show the effectiveness of neighborhood aggregation, we conduct a baseline experiment without neighborhood aggregation and provide the comparison results in Tab. 1. Neighborhood aggregation (“NA w/o density” and “NA with density”) significantly improves over the baseline without neighborhood aggregation (“Baseline”). For a deeper analysis, we further visualize the feature embeddings learned with neighborhood aggregation and those of the baseline on the CIFAR10 validation set in Fig. 3. The visualization clearly shows that incorporating neighborhood features helps learn more discriminative feature embeddings.
Does density improve neighborhood aggregation? Based on the density-peak assumption, we believe that when learning the aggregation weights, incorporating the density of each neighbor is as important as considering the feature similarity between the target node and its neighbors. To verify this, we compare the proposed density-aware neighborhood aggregation (“NA with density”) with the version that ignores density (“NA w/o density”). The results in Tab. 1 show that incorporating neighbor density into the aggregation weights achieves superior performance in both settings.
Study about Sub-Graph size. Experimentally, we find that selecting a good neighbor number $k$ and sub-graph depth $h$ (“hop”) is crucial for the best performance. First, to study the influence of $k$, we only vary the number of neighbors in the first hop ($h = 1$). The results in Fig. 4 show that both too large and too small a number of neighbors lead to inferior results. Too few neighbors do not provide sufficient neighborhood information, while too many neighbors introduce unrelated samples that may weaken the effectiveness of neighborhood aggregation, which is consistent with the results in [mne_iccv19]. We then study whether incorporating multi-hop neighbors ($h > 1$) brings a performance gain and show the results in Tab. 2. Incorporating two hops of neighbors brings an additional gain on CIFAR100 but yields no additional gain on CIFAR10. On the other hand, sampling a third hop of neighbors degrades the performance on both datasets. We think this is because more unrelated samples may be introduced as neighbors as the number of hops increases, thus impairing the target feature.
5.2.2 Density-Ascending Path Label Propagation
Table 5: Comparison with state-of-the-art methods on CIFAR10 (500, 1000, 2000 and 4000 labelled images) and CIFAR100 (4000 and 10000 labelled images). Average error rate and standard deviation of 5 runs with different labelled/unlabelled splits are reported.
| Method (Mini-ImageNet) | 4000 labelled images | 10000 labelled images |
|---|---|---|
| LP + MeanTeacher [iscen2019label] | 72.78 | 57.35 |
Effectiveness of Density-Ascending Path. In this part, we study the effectiveness of the proposed Density-ascending Path-based Label Propagation (DPLP), which consists of two sequential steps: first constructing a density-ascending path from each labelled sample (denoted LP-L), then constructing a density-ascending path from each remaining unlabelled sample (denoted LP-LU). We study these two steps separately in Tab. 3. The results show that: (i) Density-ascending Path-based Label Propagation significantly improves the classification accuracy; (ii) constructing density-ascending paths only from labelled samples (LP-L) already significantly improves over the baseline; and (iii) constructing density-ascending paths from the remaining unlabelled samples (LP-LU) brings an additional gain.
Distribution of the length of Density-Ascending Path. As elaborated before, our density-ascending path-based label propagation grows each path under the ascending-density constraint and terminates it according to the feature distance threshold. Having demonstrated its effectiveness, we further examine the distribution of the path lengths. Fig. 5 shows the distribution of density-ascending path lengths on CIFAR10 (1k labelled samples) and CIFAR100 (4k labelled samples). The two distributions are similar, and most paths have a length between 5 and 25. Therefore, label propagation along these paths can be performed efficiently online during training.
5.2.3 Density Aware Graph-based SSL Framework
In the previous subsections, we demonstrated the effectiveness of each individual component. We now study the effectiveness of the overall framework (DAG), i.e., the combination of Density-aware Neighborhood Aggregation (DNA) and Density-ascending Path-based Label Propagation (DPLP). The results on CIFAR10 and CIFAR100 are displayed in Tab. 4. DNA and DPLP are complementary, and combining them outperforms the baseline by a large margin. For example, on CIFAR100 (4k labelled samples), DNA and DPLP each reduce the error rate on their own, and their combination reduces it further.
Discussions about computational complexity. The main computation of our framework comes from the global graph construction, which involves kNN retrieval. However, many highly efficient nearest neighbor search algorithms and tools already exist. The default tool we use is Faiss [JDH17], which can perform efficient billion-scale similarity search on GPUs. For testing, we found that the overhead of kNN search is negligible compared to feature extraction, so our test time is almost identical to that of the baselines.
5.3 Comparison with state-of-the-art methods
We compare with state-of-the-art approaches in Tab. 5 and Tab. 6. For CIFAR10 and CIFAR100 (Tab. 5), we consider both consistency-regularization based methods [laine2016temporal, tarvainen2017mean, luo2018smooth, athiwaratkun2018there, ICT2019, ke2019dual, berthelot2019mixmatch] and pseudo-label based methods [shi2018transductive, iscen2019label]. Among them, ICT [ICT2019] and MixMatch [berthelot2019mixmatch] both leverage the Mixup data augmentation strategy [zhang2018mixup], which is also used in this work. Our method outperforms most state-of-the-art methods on CIFAR10 and CIFAR100 across different numbers of labelled samples. On the more challenging Mini-ImageNet benchmark, our method achieves the best performance and sets a new state of the art for both 4k and 10k labelled samples, beating the latest best results [iscen2019label] by 14.32% and 10.07% respectively.
Although existing SSL methods are based on the common density-based cluster assumption and achieve impressive results, we identify three limitations: 1) they do not exploit density information explicitly; 2) neighborhood information is not considered when learning features; 3) existing label propagation schemes can only be run offline and are difficult to train end to end. In this paper, we propose a novel and unified density-aware graph based framework for semi-supervised visual recognition. Specifically, we propose two novel density-aware modules targeting the two key SSL components respectively, i.e., Density-aware Neighborhood Aggregation and Density-ascending Path-based Label Propagation. These two modules can be jointly trained and work in a complementary way. Experiments demonstrate the superior performance of our method, which beats current state-of-the-art methods by a large margin.
This work is supported by the Fundamental Research Funds for the Central Universities (WK2100330002, WK3480000005), National Key Research and Development Program of China(2018YFB0804100).