Metric learning aims to construct a task-specific distance from supervised data; the learned distance metric can then be used for tasks such as classification, clustering, and information retrieval. Extensions of conventional (Mahalanobis) metric learning algorithms fall into two directions: models based on deep metric learning and models based on Bregman divergence learning. Deep metric learning uses neural networks to automatically learn discriminative features from samples and then computes a metric-based loss, such as the contrastive loss in Siamese networks koch2015siamese or the triplet loss in triplet networks hoffer2015deep. Bregman divergence learning aims to generalize measures such as the Euclidean distance bregman1967relaxation and the KL divergence painsky2019bregman. These methods learn the underlying generating function of the Bregman divergence using a piecewise linear approximation siahkamari2020learning, or by adding quantization rates to the existing basis of functional Bregman divergences for clustering liu2016clustering.
Recently, deep divergence learning cilingir2020deep; kampffmeyer2019deep was introduced to learn and parameterize functional Bregman divergences using linear neural networks. This method approximates functional Bregman divergences by training deep neural networks, aiming to reduce the distance between feature vectors of the same class and increase the distance between feature vectors of different classes. A key advantage of this method is that it shifts the focus of divergence learning from divergences between single points to divergences between distributions over sets of points cilingir2020deep. This can potentially introduce new desirable properties to a learned feature space. Nevertheless, Bregman divergence learning has not yet seen widespread adoption, and it remains challenging to apply to representation learning and self-supervision based on distance reduction (i.e., similarity maximization).
Self-supervised learning is among the most promising approaches for learning from limited labeled data. In contrast to supervised methods, these techniques learn a representation of the data without relying on human annotation. Early self-supervised methods focused on solving a pretext task such as predicting the rotation of images, predicting the relative position of patches, or colorization lee2017unsupervised; noroozi2016unsupervised; pathak2016context; zhang2016colorful. The most successful recent self-supervised methods use a contrastive loss and maximize the similarity of representations obtained from different distorted versions of a sample. Instance discrimination wu2018unsupervised, CPC henaff2019data; oord2018representation, Deep InfoMax hjelm2018learning, AMDIM bachman2019learning, CMC stojnic2021self, PIRL misra2020self, MoCo he2020momentum, and SimCLR chen2020simple; chen2020big are examples of contrastive learning methods that produce representations competitive with supervised ones. However, existing deep contrastive learning methods are not directly amenable to comparing data distributions.
In this paper, we extend existing contrastive learning approaches with a contrastive divergence learning strategy, which learns a generalized divergence of the data distribution by jointly adopting a functional Bregman divergence and a contrastive learning strategy. In summary, we train our framework end-to-end in four stages: (1) Transformation: we sequentially apply random color distortions, random rotations, random cropping followed by resizing back to the original size, and random Gaussian blur. We then train our network with two distorted images only. (2) Base network: a deep neural network learns representations on top of the distorted samples. These representations are projected onto a lower-dimensional space by the projection head. We call the output of the base network the representations and the output of the projector the embeddings. (3) Bregman divergence network: a set of deep linear neural subnetworks forms the convex generating function and computes the divergence between two distributions given their embeddings. (4) Contrastive divergence learning: our network is trained end-to-end by combining a traditional contrastive loss and a new divergence loss. As depicted in Fig. 1, the contrastive loss is computed over the embeddings, and our novel divergence loss is formulated over the outputs of the subnetworks. Our main findings and contributions can be summarized as follows:
We propose a novel framework for self-supervised learning of Bregman divergence of visual representations. The proposed method learns deep Bregman divergences which are beyond Euclidean distance and capable of capturing divergence over distributions while also benefiting from contrastive learning representation.
We propose a contrastive divergence loss which encourages each subnetwork to focus on different attributes of an input. The combination of new contrastive divergence loss with recent modified contrastive loss (i.e. NT-Xent, NT-Logistic) improves the performance of contrastive learning for multiple tasks including image classification and object detection over multiple baselines.
We show empirical results that highlight the benefits of learning representation from our deep Bregman divergence and contrastive divergence loss.
Our method is comparable with methods involving measuring or learning divergence over distributions such as contrastive divergence carreira2005contrastive, stochastic Bregman divergence dragomir2021fast, maximum mean discrepancy xing2002distance, and most recently deep divergence learning cilingir2020deep; kampffmeyer2019deep; kong2020rankmax; kato2021non. Another related work is Rankmax kong2020rankmax which studies an adaptive projection alternative to the softmax function that is based on a projection on the simplex with application in multi-class classification. However, our proposed framework differs in representation learning strategy and applications.
Much of the early work on self-supervised learning focused on learning embeddings without labels such that a linear classifier operating on the learned embeddings could achieve high classification accuracy doersch2015unsupervised. Later, some models aimed to learn representations using auxiliary handcrafted prediction tasks. Examples are image jigsaw puzzles noroozi2016unsupervised, relative patch prediction doersch2015unsupervised; doersch2017multi; pathak2016context, and image super-resolution ledig2017photo; yuksel2021latentclr.
Contrastive learning is among the most successful self-supervised approaches, achieving strong linear classification accuracy and outperforming supervised learning on some tasks through suitable architectures and losses caron2020unsupervised; chen2020simple; chen2020big; chen2020intriguing; zbontar2021barlow, pretraining in a task-agnostic fashion kolesnikov2019revisiting; shen2020mix, and fine-tuning on a labeled subset in a task-specific fashion wu2018unsupervised; henaff2019data. However, Hjelm et al. hjelm2018learning showed that the accuracy depends on a large number of negative samples in the training batch. Self-supervised training with a large batch size is computationally expensive for high-resolution images. BYOL grill2020bootstrap and SimSiam chen2021exploring mitigated this issue with an additional prediction head and by learning the latent representations of positive samples only. Robinson et al. robinson2020contrastive proposed a technique for selecting negative samples, and the whitening procedure of ermolov2021whitening, despite its success, is still sensitive to batch size. MoCo he2020momentum; chen2020improved alleviates this problem using a memory-efficient queue of the last visited negatives, together with a momentum encoder that preserves the intra-queue representation consistency. Most recently, SSL-HSIC li2021self proposed a method to maximize the dependence between representations of transformed versions of an image and the image identity, while minimizing the kernelized variance of those features. VICReg bardes2021vicreg proposed a regularization term on the variance of the embeddings to explicitly avoid the collapse problem. In this paper, we develop and study the impact of different parameterized convex linear neural networks with functional Bregman divergence on top of a simple contrastive learning framework chen2020simple; chen2020improved.
Problem Formulation and Approach
Our goal is twofold. First, we aim to learn representations using a contrastive loss on top of the embeddings produced by the representation network. Second, we learn a deep Bregman divergence by minimizing the divergence between samples from the same distribution and maximizing the divergence between samples from different classes and different distributions. As depicted in Fig. 1, our proposed method includes two sequentially connected neural networks: the representation network and the deep Bregman divergence network. In the following sections, we first describe contrastive learning in the context of visual representation. Then we discuss Bregman divergence, functional Bregman divergence, and the deep Bregman divergence network. Next, we describe our framework and the proposed loss.
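Given a randomly sampled mini-batch of images with $N$ samples, contrastive learning aims to learn an embedding function $f(\cdot)$ by contrasting positive pairs $(x, x^{+})$ against negative pairs $(x, x^{-})$. First, we generate a positive pair sharing the same semantics from each sample in a mini-batch by applying standard image transformation techniques. For each positive pair, there exist $2(N-1)$ negative examples in a mini-batch. The encoder network $f(\cdot)$ (e.g. ResNet-50 he2016deep) encodes distorted positive and negative samples to a set of corresponding features. These features are then transformed with a projection MLP head $g(\cdot)$ chen2020simple, which results in $z^{+}$ and $z^{-}$. The contrastive estimation for a positive pair of examples $(x_i, x_j)$ is defined as:

$$\ell_{i,j} = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)} \qquad (1)$$

where $\mathrm{sim}(\cdot, \cdot)$ is the cosine similarity between two vectors, $N$ is the number of samples in a mini-batch, and $\tau$ is a temperature scalar. The loss over all the pairs is formulated as:

$$\mathcal{L}_{\mathrm{contrastive}} = \frac{1}{2N} \sum_{k=1}^{N} \left[\ell_{2k-1,2k} + \ell_{2k,2k-1}\right] \qquad (2)$$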
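As a minimal sketch of the contrastive estimation described above, the following pure-Python code computes an NT-Xent-style loss in the form used by chen2020simple. The function names, the default temperature of 0.1, and the convention that views `2m` and `2m + 1` form a positive pair are assumptions made for illustration; they are not taken from the authors' implementation.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def nt_xent_pair(embeddings, i, j, temperature=0.1):
    """Loss for the positive pair (i, j); every other view k != i is in the denominator."""
    numerator = math.exp(cosine_similarity(embeddings[i], embeddings[j]) / temperature)
    denominator = sum(
        math.exp(cosine_similarity(embeddings[i], embeddings[k]) / temperature)
        for k in range(len(embeddings))
        if k != i
    )
    return -math.log(numerator / denominator)


def nt_xent_batch(embeddings, temperature=0.1):
    """Average loss over both orderings of every positive pair.

    Assumed layout (hypothetical): views 2m and 2m + 1 are two
    distortions of the same image.
    """
    n_pairs = len(embeddings) // 2
    total = 0.0
    for m in range(n_pairs):
        a, b = 2 * m, 2 * m + 1
        total += nt_xent_pair(embeddings, a, b, temperature)
        total += nt_xent_pair(embeddings, b, a, temperature)
    return total / (2 * n_pairs)
```

When the two views of each image have nearly identical embeddings and differ from the other images in the batch, the loss approaches zero; it grows when positives are no closer than negatives.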
Bregman Divergence Learning
A Bregman divergence is parametrized by a strictly convex function $\phi$ on a convex set $\Omega \subseteq \mathbb{R}^d$, where $\phi$ is continuously differentiable on the relative interior of $\Omega$. The Bregman divergence associated with $\phi$ for data points $x, y \in \Omega$ is calculated by:

$$D_{\phi}(x, y) = \phi(x) - \phi(y) - \langle x - y, \nabla\phi(y) \rangle \qquad (3)$$

Well-known examples of Bregman divergences are the squared Euclidean distance, parametrized by $\phi(x) = \|x\|_2^2$; the KL divergence, parametrized by $\phi(x) = \sum_i x_i \log x_i$; and the Itakura-Saito distance, parametrized by $\phi(x) = -\sum_i \log x_i$. Bregman divergences appear in various settings in machine learning and statistical learning. In optimization, Bregman bregman1967relaxation proposed Bregman divergences as part of constrained optimization. In unsupervised clustering, Bregman divergences provide a way to extend the K-means algorithm beyond the squared Euclidean distance banerjee2005clustering. In this paper, we use an extension of standard Bregman divergences called functional Bregman divergences.
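To make the three examples above concrete, the sketch below evaluates the standard vector Bregman divergence $\phi(x) - \phi(y) - \langle x - y, \nabla\phi(y) \rangle$ for the three generating functions. The function names are hypothetical and chosen for illustration; the KL case assumes strictly positive inputs that sum to one, and the Itakura-Saito case assumes strictly positive inputs.

```python
import math


def bregman_divergence(phi, grad_phi, x, y):
    """D_phi(x, y) = phi(x) - phi(y) - <x - y, grad phi(y)>."""
    inner = sum(g * (a - b) for g, a, b in zip(grad_phi(y), x, y))
    return phi(x) - phi(y) - inner


# phi(x) = ||x||^2 generates the squared Euclidean distance.
def squared_norm(v):
    return sum(a * a for a in v)


def squared_norm_grad(v):
    return [2.0 * a for a in v]


# phi(x) = sum_i x_i log x_i generates the KL divergence
# (for inputs on the probability simplex).
def negative_entropy(v):
    return sum(a * math.log(a) for a in v)


def negative_entropy_grad(v):
    return [math.log(a) + 1.0 for a in v]


# phi(x) = -sum_i log x_i generates the Itakura-Saito distance.
def burg_entropy(v):
    return -sum(math.log(a) for a in v)


def burg_entropy_grad(v):
    return [-1.0 / a for a in v]
```

For instance, `bregman_divergence(squared_norm, squared_norm_grad, x, y)` reduces algebraically to $\|x - y\|_2^2$, which is a quick way to sanity-check the generic formula.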
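A functional Bregman divergence frigyik2008functional; ovcharov2018proper generalizes the standard Bregman divergence for vectors and measures the divergence between two functions or distributions. Given two functions $p$ and $q$, and a strictly convex functional $\phi$ defined on a convex set of functions which output in $\mathbb{R}$, the functional Bregman divergence is formulated as:

$$D_{\phi}(p, q) = \phi(p) - \phi(q) - \int \left[p(x) - q(x)\right] \delta\phi(q)(x)\, dx \qquad (4)$$

where $\delta\phi(q)$ denotes the functional derivative of $\phi$ at $q$.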
Like the vector Bregman divergence, the functional Bregman divergence has several properties, including convexity, non-negativity, linearity, equivalence classes, linear separation, dual divergences, and a generalized Pythagorean inequality.
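Deep Bregman divergence cilingir2020deep parametrizes the functional Bregman divergence by weight functions $w_1, \dots, w_k$ and biases $b_1, \dots, b_k$, under the assumption that every generating convex functional can be expressed in terms of linear functionals. For a set of linear functionals $A$, $\phi$ is defined as:

$$\phi(p) = \sup_{(w, b_w) \in A} \int p(x)\, w(x)\, dx + b_w \qquad (5)$$

and, based on the underlying generating convex functional $\phi$, the functional Bregman divergence can be expressed as:

$$D_{\phi}(p, q) = \left(\int p(x)\, w_p(x)\, dx + b_{w_p}\right) - \left(\int p(x)\, w_q(x)\, dx + b_{w_q}\right) \qquad (6)$$

where $p$ and $q$ are given by empirical distributions over input points; $w_p = \arg\max_{w} \left[\int p(x)\, w(x)\, dx + b_w\right]$; and $w_q$ is defined in the same way as $w_p$.
Therefore, we can train a deep functional Bregman divergence if each of the weight and bias functions is given by a separate linear neural network (Fig. 2).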
Contrastive Divergence Learning
Given a randomly sampled mini-batch of images, our method draws two random augmentations from an augmentation set to produce two distorted images of each image. The distorted samples are encoded via the base network to generate the corresponding representations.
Next, we perform divergence learning by adopting a functional Bregman divergence (Fig. 2). Consider $p$ and $q$ as the empirical distributions over the embeddings of the two distorted views, respectively. We parametrize our deep divergence with weight functions $w_1, \dots, w_k$ and biases $b_1, \dots, b_k$. Each subnetwork takes an embedding and produces a single output. Let $p^{*}$ be the index of the maximum output for $p$ and $q^{*}$ the index of the maximum output for $q$ across the $k$ subnetworks. The divergence is then the difference between the output of subnetwork $p^{*}$ at $p$ and the output of subnetwork $q^{*}$ at $p$.
If each of the outputs is taken to correspond to a different class, the divergence is zero when both points achieve their maximum value for the same class, and it is non-zero otherwise. The divergence increases as the two outputs become more separated.
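The argmax-based computation described above can be sketched in a few lines. This is an illustrative simplification, not the authors' implementation: each subnetwork is reduced to a single affine function of one embedding vector (rather than a network applied to an empirical distribution), and all names are hypothetical.

```python
def affine_outputs(weights, biases, x):
    """Output of every affine subnetwork w_k . x + b_k for one embedding x."""
    return [
        sum(w_i * x_i for w_i, x_i in zip(w, x)) + b
        for w, b in zip(weights, biases)
    ]


def argmax(values):
    """Index of the largest value (first index on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def max_affine_divergence(weights, biases, p, q):
    """Divergence induced by the max-affine function phi(x) = max_k (w_k . x + b_k).

    p_star and q_star are the winning subnetworks for p and q. The result is
    the output of subnetwork p_star at p minus the output of subnetwork
    q_star at p, which is zero whenever both points select the same subnetwork.
    """
    out_p = affine_outputs(weights, biases, p)
    out_q = affine_outputs(weights, biases, q)
    p_star = argmax(out_p)
    q_star = argmax(out_q)
    return out_p[p_star] - out_p[q_star]
```

Because `p_star` maximizes the outputs at `p`, the returned value is never negative, which matches the non-negativity property of Bregman divergences.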
Our method is trained with a combination of two losses: one based on the discriminative features from the representation network and another based on the Bregman divergences output by the subnetworks. In this paper, we estimate the noise contrastive loss on top of the representation vectors, similar to Eq. 2. The output of the representation network yields well-separated inter-product differences, while the deep features learned by the different Bregman divergences yield discriminative features with compact intra-product variance. Therefore, the combination of the two is key to better visual search engines. In addition, learning such discriminative features enables the network to generalize well to unseen images.
We convert the Bregman divergence (Eq. 6) to a similarity using a Gaussian kernel with an adjustable parameter. All divergences obtained with the same network are treated as positive pairs, while divergences obtained with different networks are treated as negative pairs. We encourage each subnetwork to have a consistent but also orthogonal effect on the features. The divergence loss for a mini-batch is defined as:
and the total loss is calculated by:
One advantage of our framework is that it learns over distributions: we do not restrict ourselves to divergences between single points, and can also capture divergences between distributions of points, similar to the maximum mean discrepancy and the Wasserstein distance. Example applications are data generation, semi-supervised learning, unsupervised clustering, information retrieval, and ranking. To empirically compare our proposed framework with existing contrastive models, we follow standard self-supervised learning protocols and evaluate the learned representation on linear classification and semi-supervised tasks, as well as on transfer learning to different datasets and computer vision tasks.
Image augmentation. We define a random transformation function that applies a combination of crop, horizontal flip, color jitter, and grayscale. Similar to chen2020simple, we perform crops with a random size from to of the original area and a random aspect ratio from to
of the original aspect ratio. We also apply horizontal mirroring with a probability of. Then, we apply grayscale with probability , and color jittering with probability and with configuration
. However, for ImageNet, we define the stronger jittering, crop size from to , grayscale probability , and Gaussian blurring with probability and . In all the experiments, at the testing phase, we apply only resize and center crop (standard protocol).
Deep representation network architecture. Our base encoder is a convolutional residual network he2016deep with 18 or 50 layers, with minor changes. The final classification layer is removed; in its place, the network has a two-layer nonlinear multi-layer perceptron (MLP) with a rectified linear unit activation in between. This MLP consists of a linear layer with input size 1024 followed by batch normalization, a rectified linear unit activation, and a final linear layer with output dimension 128 as the embedding space. The embeddings are fed to the contrastive loss and used as input to our deep divergence networks.
Deep divergence network architecture (subnetworks). We implemented multiple adaptive subnetworks on top of the MLP projection head. Many architectures are suitable for this purpose; we consider a simple one in which each subnetwork is a 2-layer MLP with 128, 64, and 1 nodes, followed by batch normalization. We do not include activations between the layers, in order to preserve the Bregman properties and the convexity of the network. Each subnetwork has its own independent set of weights. We perform a Bayesian hyperparameter search to find the best number of subnetworks.
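The reason for omitting activations can be checked numerically: stacking linear layers without a nonlinearity collapses to a single affine map, so the maximum over subnetworks stays a maximum of affine functions and is therefore convex. The sketch below illustrates this with toy dimensions (3, 2, 1 instead of 128, 64, 1) and ignores batch normalization; the function names are hypothetical.

```python
def matvec(matrix, vector):
    """Matrix-vector product for nested-list matrices."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def matmul(a, b):
    """Matrix-matrix product for nested-list matrices."""
    columns = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in columns] for row in a]


def two_layer_linear(w1, b1, w2, b2, x):
    """Two stacked linear layers with no activation in between."""
    hidden = [h + b for h, b in zip(matvec(w1, x), b1)]
    return [o + b for o, b in zip(matvec(w2, hidden), b2)]


def collapse(w1, b1, w2, b2):
    """Equivalent single affine map: W = w2 @ w1 and b = w2 @ b1 + b2."""
    weight = matmul(w2, w1)
    bias = [o + c for o, c in zip(matvec(w2, b1), b2)]
    return weight, bias
```

Applying the collapsed pair `(weight, bias)` to any input gives the same output as `two_layer_linear`, which is what makes each subnetwork a valid affine piece of the generating function.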
Optimization We use the Adam optimizer kingma2014adam with a learning rate , , , and weight decay
without restarts, over 500 epochs (similar to the MoCo configuration). We convert Bregman divergences to similarity scores using a strictly monotone decreasing function. As explained in the Method section, we use a Gaussian kernel and set its parameter to 0.9 for all experiments (see the Appendix). The temperature is set to 0.1. We train our model with a mini-batch size of 512 on 2 GPUs for all small datasets, since our primary objective is to verify the impact of our proposed method rather than to surpass state-of-the-art results. For ImageNet, we use a mini-batch size of 256 on 4 GPUs (Tesla A-100) and an initial learning rate of 0.003. We train for 500 epochs with the learning rate multiplied by 0.1 at epochs 120 and 360, which takes around five days for ResNet-50.
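The ImageNet step schedule described above (initial learning rate 0.003, multiplied by 0.1 at epochs 120 and 360) can be written as a small function. This is a sketch only: the function name is hypothetical, and whether the drop takes effect exactly at the milestone epoch with zero-based indexing is an assumption.

```python
def step_learning_rate(epoch, base_lr=0.003, milestones=(120, 360), gamma=0.1):
    """Learning rate at a given epoch under a multi-step decay schedule.

    Assumption: epochs are zero-based and the decay applies from the
    milestone epoch onward.
    """
    drops = sum(1 for milestone in milestones if epoch >= milestone)
    return base_lr * gamma ** drops
```

Under these assumptions, the learning rate is 0.003 for epochs 0 to 119, about 0.0003 for epochs 120 to 359, and about 0.00003 from epoch 360 onward.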
Datasets and tasks. We use the following datasets in our experiments. CIFAR 10/100 krizhevsky2009learning are subsets of the tiny images dataset. Both datasets include 50,000 training images and 10,000 validation images of size 32×32, with 10 and 100 classes, respectively. STL 10 coates2011analysis consists of 5000 training images and 8000 test images in 10 classes, with a size of 96×96. This dataset also includes 100,000 unlabeled images for unsupervised learning. ImageNet deng2009imagenet, also known as ILSVRC 2012, contains 1000 classes, with 1.28 million training images and 50,000 validation images. ISIC-2018 codella2019skin is a challenge on the detection of seven different types of skin cancer and was part of the MICCAI-2018 conference. The organizers released 10,015 dermatology scans with a size of 600×450 pixels, collected from different clinics.
Experiments and Results
The common evaluation protocol for self-supervised learning is to freeze the base encoder after unsupervised pretraining and then train a supervised linear classifier on top of it. The linear classifier is a fully connected layer followed by softmax, which is placed on top of the base encoder after removing the MLP head and the divergence network. Our linear evaluation covers small and large datasets, including STL-10, CIFAR-10/100, ISIC-7, and ImageNet, with different approaches to explore and compare the effectiveness of our proposed method. Top-1 accuracy on the test set of ImageNet is reported in Table 1, and Table 2 compares our model against the self-supervised baselines on small datasets. Based on the performance reported in Table 1 and Table 2, our method obtains up to 1.5% improvement over the baselines MoCo-V2 and SimCLR. We also achieve a top-1 accuracy of 71.3% on ImageNet, which is comparable to state-of-the-art methods.
|SwAV caron2020unsupervised (w/o multi-crop)||71.8|
|Barlow Twins zbontar2021barlow||73.2|
We evaluate the performance of our method's representation on a semi-supervised image classification task. In this task, we pre-train the ResNet-50 model on unlabeled ImageNet examples and fine-tune a classification model using a labeled subset of ImageNet examples. We follow the semi-supervised protocol of chen2020big and use the same fixed splits of 1% and 10% of the labeled ImageNet training data. Table 3 compares our performance against state-of-the-art methods and baselines. The results indicate that the proposed method outperforms the baselines SimCLR and MoCo-v2. We also obtain slightly better or comparable results relative to state-of-the-art methods such as BYOL.
Transfer to Other Tasks
We further assess the generalization capacity of the learned representation on object detection. We train a Faster R-CNN faster2015towards model on Pascal VOC 2007 and Pascal VOC 2012 and evaluate on the test set of Pascal VOC. Table 4 compares the transfer learning performance of our self-supervised approach on the object detection task. We use ResNet-50 models pre-trained on ImageNet and perform object detection on the Pascal VOC07+12 dataset everingham2010pascal.
Ablation Studies and Discussions
To build intuition around the behavior and the observed performance of the proposed method, we further investigate the following aspects of our approach in multiple ablation studies: (i) the impact of the Bregman divergence network on the quality of representation; (ii) the number of subnetworks; (iii) the robustness of our algorithm to changes in augmentations and transformations. Further analysis of the combination of loss functions and the training behavior is available in the Appendix.
|Number of subnetworks||5||20||50||100||200||500||1000|
|Deep Divergence cilingir2020deep||71.9||77.8||79.4||80.0||77.4||74.1||70.8|
Impact of Bregman divergence network on quality of representations We visualize the representation features using t-SNE with the last convolution layer from ResNet-18 to explore the quality of learned features using our proposed method. Figure 5 compares the representation space learned by our proposed method (a) and SimCLR (b). As depicted in Fig. 5, our model shows better separation on clusters, especially for classes 2, 4, 6, and 7.
Number of subnetworks. We trained multiple individual deep neural networks on top of the embedding space. The inputs to the networks were similar, but the networks were parameterized with different weights and biases. Here, we provide more details on our classification experiments with different numbers of subnetworks. Fig. 6 shows the performance in terms of top-1 accuracy for CIFAR-10 and Tiny ImageNet le2015tiny. Based on the quantitative results in Fig. 6, performance on both CIFAR-10 and Tiny ImageNet improves as the number of subnetworks increases up to a certain point, after which it starts to drop, possibly due to over-parameterization. With a small number of subnetworks, the performance of our network is closer to that of the contrastive loss alone. For example, on CIFAR-10 (Fig. 6), our accuracy is almost 90% with few subnetworks and increases to 92.9% with a larger number. This shows that training our network with a suitable number of subnetworks yields a better representation of the data.
Table 5 compares the performance of our method with Cilingir et al. cilingir2020deep. Based on reported results in Table 5, the performance of the divergence learning framework depends on the representation feature space while our method significantly outperforms deep divergence.
Since the contrastive loss is sensitive to the choice of augmentation technique, and learned representations can be controlled by the specific set of distortions grill2020bootstrap, we also examined how robust our method is to removing some of the data augmentations. Figure 7 presents the decrease in top-1 accuracy (in percentage points) of our method and SimCLR at 300 epochs under linear evaluation on ImageNet (SimCLR numbers are taken from grill2020bootstrap). This figure shows that the representations learned by our proposed contrastive divergence are more robust to removing certain augmentations than those of SimCLR. For instance, SimCLR does not work well when image cropping is removed from its transformation set.
In this paper, we proposed and examined deep divergence for contrastive learning of visual representations. Our framework is composed of a representation learning network followed by multiple divergence learning networks. We train a functional Bregman divergence on top of the representation network using adaptive convex neural networks. The similarity matrix is formulated from the Bregman distances output by the ensemble of networks. The networks are then optimized end-to-end using our novel contrastive divergence loss. We improve over previous methods for deep metric learning, deep divergence learning, self-supervised learning, semi-supervised learning, and transfer learning. Empirical experiments demonstrate the efficacy of the proposed method on standard benchmarks as well as recent clinical datasets, on both classification and object detection tasks.
This work has been funded in part by the German Federal Ministry of Education and Research (BMBF) under Grant No.01IS18036A, Munich Center for Machine Learning (MCML). F. S., M. R., and B. B. were supported by the German Federal Ministry of Education and Research (BMBF) under Grant No. 01IS18036A. We would like to thank Ali Siahkamari for helpful discussions. We are also grateful to Simon Kornblith for valuable feedback on the manuscript.
Contrastive Divergence Loss
In this section, we provide more training details related to our proposed contrastive divergence loss. As mentioned previously, the divergence loss for a mini-batch is defined as:
and the total loss is calculated by:
where we convert the Bregman divergence (Eq. 6) to a similarity using a conversion function. The conversion function must be strictly monotone and can take many forms. In general, there are multiple options for converting divergences to similarity scores, ranging from a simple inverse function to a more complex Gaussian kernel. Table A.1 shows the functions that were examined in our experiments; we achieved the best performance with the Gaussian kernel. Our chosen conversion function can be adjusted using a parameter, which we set to 0.9 in our experiments.
|Strictly monotone function|
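To illustrate what such conversion functions look like, the sketch below implements two strictly decreasing maps from a non-negative divergence to a similarity score. The exact functional forms are assumptions made for illustration: the Gaussian kernel is written in the common form $\exp(-d^2 / 2\sigma^2)$ with $\sigma = 0.9$, and the inverse function includes a small constant to avoid division by zero. Neither is confirmed to match the entries of Table A.1.

```python
import math


def inverse_similarity(divergence, eps=1e-8):
    """Simple inverse conversion; eps (an assumption) avoids division by zero."""
    return 1.0 / (divergence + eps)


def gaussian_kernel_similarity(divergence, sigma=0.9):
    """Gaussian-kernel conversion, assumed form exp(-d^2 / (2 * sigma^2))."""
    return math.exp(-(divergence ** 2) / (2.0 * sigma ** 2))
```

Both functions map a zero divergence to the highest similarity and decrease strictly as the divergence grows; the Gaussian kernel additionally keeps the score bounded in (0, 1].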
Figure A.1 demonstrates the approximation of divergences between different classes. In this example, we show that the divergence of a random sample to a specific sample in the data is equal to zero when they are from the same category. In general, the outputs of the underlying convex generating functional for input samples from the same category will lie on the same hyperplane. We use the same strategy as siahkamari2019learning to produce the graph.
We use the Adam optimizer kingma2014adam with a learning rate , and weight decay without restarts, over 300 epochs for all of the pretraining. Figure A.2 shows a sample learning curve of our model using the ResNet-18 base network over 100 epochs, with the corresponding test loss and top-1 and top-5 accuracy on the CIFAR-10 dataset. For pre-training, we use a mini-batch size of 512 and 2 GPUs.
Metric learning and divergence learning are fundamental problems in machine learning that have attracted considerable research and many applications. These applications include (but are not limited to) uncertainty quantification, density estimation, image retrieval, unsupervised image clustering, program debugging, image generation, music analysis, and ranking. Fundamental studies of these problems will help to improve results in these applications and may lead to additional impact in new domains.