Robust conditional GANs under missing or uncertain labels

06/09/2019 ∙ by Kiran Koshy Thekumparampil, et al.

Matching the performance of conditional Generative Adversarial Networks with little supervision is an important task, especially when venturing into new domains. We design a new training algorithm that is robust to missing or ambiguous labels. The main idea is to intentionally corrupt the labels of generated examples to match the statistics of the real data, and have a discriminator process the real and generated examples with corrupted labels. We showcase the robustness of this proposed approach both theoretically and empirically. We show that minimizing the proposed loss is equivalent to minimizing the true divergence between real and generated data up to a multiplicative factor, and characterize this multiplicative factor as a function of the statistics of the uncertain labels. Experiments on the MNIST dataset demonstrate that the proposed architecture achieves high accuracy in generating examples faithful to the class, even with only a few examples per class.


1 Introduction

Conditional GANs (cGANs) have been applied to several domains for various tasks, such as improving image quality, reinforcement learning, and category transformation (Mirza & Osindero, 2014; Ledig et al., 2016; Zhu et al., 2017; Odena et al., 2016). As opposed to a standard GAN, a conditional GAN is trained using labeled samples, which provide additional information that can be used to generate better-quality samples (Brock et al., 2018). However, it is costly to obtain accurate class labels for all the samples. Instead, we might choose to collect accurate labels for a few examples, and either leave most examples without labels or find cheaper ways to collect less accurate labels. In this paper, we consider a class of such economically collected labels, which we call uncertain labels. We provide a robust cGAN architecture with finite sample performance guarantees and empirically verify its performance for the case of missing labels.

Notation. , is the all-ones vector, is the -th standard basis vector (with appropriate dimensions), is the identity matrix, denotes a diagonal matrix with as the diagonal, and for we define .

Uncertainty model. Let be a data point having a true label drawn from a joint distribution . We consider a semi-supervised setting, where we observe only a few examples with correct labels. The remaining examples have labels that are corrupted by uncertainty. Concretely, there is an additional set of labels . Having an example with observed label , for example, means we are uncertain about the true label , but we have some information about it according to the observed label . A common example is the standard semi-supervised setting where , and the class indicates that the label is missing. Another example is when the crowd is asked to give a membership, instead of a definite class, where a label might mean that the example has one of three labels but we are uncertain about which one. We refer to the set of true labels as class labels and the set of corrupted labels as uncertain labels.

We assume that each data point is corrupted independently and with a certain probability conditioned on the true label by an erasure channel. Formally, each is drawn according to a confusion matrix where . Unlike the standard noisy label setting, we only consider uncertain labels; if you observe one of the class labels, then you are certain that it is the correct label. Otherwise, each uncertain label has an uncertainty set that the label could have been generated from. Formally, an uncertain label is parameterized by a vector , where if and if . The zeros follow from the fact that the true label cannot be an uncertain label. It immediately follows that . Under such an uncertainty model, the confusion matrix can be written as

(1)

This captures a variety of label corruption models:

  1. Missing labels: If portion of the samples have their labels missing, then we can incorporate the missing labels into our model as the uncertain class , with .

  2. Complementary labels (Ishida et al., 2017): A complementary label specifies that a sample does not belong to a particular class. Suppose every sample from each class is assigned a complementary label uniformly at random from . Then the complementary label which specifies the exclusion from class can be denoted by the uncertain label with .

  3. Group (membership) labels: Group labels specify whether a sample belongs to a subset of classes. For example, if the original classes are car, bus, horse, and cat, then we could divide them into two super group labels: automobile and animal. It can easily be shown that this is a special case of our uncertainty model.
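The corruption models listed above can be made concrete with a small sketch. It assumes one particular layout that is not spelled out in this text: the confusion matrix is stored row-wise, with row y holding the probabilities of each observed label given true class y, and with the class-label columns placed before the uncertain-label columns. The function names and the per-class erasure probabilities `alpha` are illustrative choices.

```python
import random


def missing_label_confusion(alpha):
    """Confusion matrix for missing labels (assumed layout).

    alpha[y] is the probability that the label of a class-y sample is erased.
    Columns 0..m-1 are the class labels; column m is the single 'missing' label.
    """
    m = len(alpha)
    C = []
    for y in range(m):
        row = [0.0] * (m + 1)
        row[y] = 1.0 - alpha[y]  # the class label survives, and is then correct
        row[m] = alpha[y]        # the class label is erased
        C.append(row)
    return C


def complementary_label_confusion(m, alpha):
    """Confusion matrix for complementary labels (assumed layout).

    With probability alpha the label of a class-y sample becomes 'not class k',
    with k drawn uniformly from the other m - 1 classes.
    Columns 0..m-1 are class labels; column m + k means 'not class k'.
    """
    C = []
    for y in range(m):
        row = [0.0] * (2 * m)
        row[y] = 1.0 - alpha
        for k in range(m):
            if k != y:
                row[m + k] = alpha / (m - 1)
        C.append(row)
    return C


def corrupt_labels(true_labels, C, rng):
    """Draw one observed label per true label from the matching row of C."""
    observed = list(range(len(C[0])))
    return [rng.choices(observed, weights=C[y], k=1)[0] for y in true_labels]
```

Because every off-diagonal entry of the class-label block is zero, an observed class label is always the true one, which is the property that separates this model from the usual noisy-label setting.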

Contribution. In this paper, we design a new adversarial training method for deep generative models that is robust to the uncertainty models discussed above. The main idea is to intentionally corrupt the label of generated examples, and have a discriminator distinguish the real and generated : data example and corrupted label , jointly. We showcase the robustness of this proposed approach both theoretically and empirically. First, we show that minimizing the proposed loss is equivalent to minimizing the true divergence between real and generated up to a multiplicative factor (Theorems 1 and 2). This multiplicative factor characterizes how the performance depends on the uncertainty parameters ’s. We further provide the sample complexity of achieving the same guarantee in Theorem 3. Experiments on the MNIST dataset demonstrate that the proposed architecture achieves 97% accuracy in generating examples faithful to the class, even with only a few labeled examples per digit.

Related work. As semi-supervised learning was one of the initial motivations of training deep generative models, training a GAN with a few labeled examples has been an important topic of interest. Salimans et al. (2016) used (unconditional) GAN as a proxy for training a semi-supervised classifier. Sricharan et al. (2017) proposed training conditional GANs, but using two discriminators: one for distinguishing real and generated and another for distinguishing real and generated . Lucic et al. (2019) proposed training a conditional GAN by first training a classifier using off-the-shelf semi-supervised techniques, and then using this to complete the missing labels with the help of an additional self-supervised discriminator. They get high-fidelity images, trained on ImageNet data. Xu et al. (2019) studied training classifiers under complementary labels.

For the rest of the manuscript, if is the distribution of the true labeled data, then denotes the distribution of the corrupt labeled data, corrupted by the uncertainty model represented by in eq. (1).

2 Robust cGAN (RCGAN) architecture

We suppose that we know the confusion matrix . It is easy to estimate, for example, when the only uncertain label is the missing label (assuming known marginal as usual for cGANs). We propose the robust conditional GAN (RCGAN) architecture, inspired by the RCGAN for noisy labeled data (Thekumparampil et al., 2018). RCGAN uses the following adversarial loss :

(2)

where is the conditional discriminator, is the conditional generator, is the domain of input latent , and and are some loss functions. The discriminator and generator update steps (in order) are given by:

where is the family of conditional discriminators, and is the family of conditional generators. Note that the generated sample is a function of the latent vector with distribution and is conditioned on the true label generated according to the true marginal .

The first expectation is estimated with the corrupted real labeled samples, whose distribution is . The second expectation is taken over the generator input latent () distribution , the true class marginal , and the distribution, (-th row of the confusion matrix), of the corrupted label given the true label . That is, the true label , of the generator samples is artificially corrupted to , by the same uncertainty model which corrupted the real data. Thus the discriminator computes a distance between the corrupted real labeled distribution and the corrupted generated labeled distribution, denoted by . In Section 2.1 we explain why minimizing this distance also minimizes the distance between the true real and generated distributions . For this loss we use the projection discriminator (Miyato & Koyama, 2018) of the form described in Section 2.1.

Figure 1: RCGAN: The output of the generator is paired with an uncertain label , which is corrupted by the same uncertainty model, , which corrupted the uncertain real label . The discriminator estimates whether a given labeled sample is coming from the real data or the generated data .
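The sampling scheme behind the two expectations can be sketched as a Monte Carlo estimate. This is a minimal illustration under stated assumptions rather than the training code used for the experiments: the two loss functions are taken to be the standard GAN log losses, the latent input is one-dimensional, the confusion matrix `C` is stored with row y holding the distribution of the corrupted label given true class y, and `D`, `G`, and `rcgan_batch_loss` are hypothetical callables and names.

```python
import math
import random


def rcgan_batch_loss(D, G, real_batch, C, class_prior, num_generated, rng):
    """Estimate a GAN loss in which D only ever sees corrupted labels.

    real_batch: list of (x, y_tilde) pairs, already corrupted by data collection.
    C: confusion matrix, C[y][j] = P(observed label j | true class y).
    class_prior: marginal distribution of the true classes.
    """
    classes = list(range(len(C)))      # true class labels
    observed = list(range(len(C[0])))  # class labels followed by uncertain labels

    # Real term: the labels are used exactly as observed.
    real_term = sum(math.log(D(x, y_tilde)) for x, y_tilde in real_batch)
    real_term /= len(real_batch)

    # Generated term: draw a true class, generate, then corrupt the label.
    fake_term = 0.0
    for _ in range(num_generated):
        y = rng.choices(classes, weights=class_prior, k=1)[0]
        z = rng.gauss(0.0, 1.0)  # one-dimensional latent, for brevity
        y_tilde = rng.choices(observed, weights=C[y], k=1)[0]  # row y of C
        fake_term += math.log(1.0 - D(G(z, y), y_tilde))
    fake_term /= num_generated

    # With these log losses, D takes ascent steps and G takes descent steps.
    return real_term + fake_term
```

The generator is conditioned on the clean class y, while the discriminator receives only the corrupted label, so both of its inputs are distributed under the same corruption.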

2.1 Theoretical Analysis of RCGAN

We see that our proposed RCGAN loss (2) minimizes a divergence, between the distribution, , of the given corrupt real samples and distribution, , of the generated samples whose labels are artificially corrupted by the same uncertainty model, , which corrupted the real data, where,

(3)

When is the set of all functions with range , this divergence reduces to the standard GAN losses: (a) the total variation distance when (up to some scaling and shifting) and (b) the Jensen-Shannon divergence when ( is the Kullback-Leibler divergence). Next, we provide some approximation guarantees on these divergences to motivate our proposed architecture, which corrupts the generated samples.

Theorem 1.

Let and be two distributions on . Let and be the corresponding distributions when samples from are passed through the erasure channel given by the confusion matrix (eq. (1)). If is full-rank (), and , we get,

(4)
(5)

A proof is provided in Appendix A.1.1. These bounds imply that minimizing the divergences between the corrupt distributions will minimize the divergence between the true distributions . However, these divergences do not generalize under finite sample assumptions; we therefore study a more practical GAN loss, called the neural network distance, which can generalize (Arora et al., 2017). We say that the divergence is a neural network distance when the class of discriminators is parameterized by a finite set of variables (as in a neural network). For simplicity, we assume that .

To derive approximation bounds similar to those in Theorem 1, we make the simple Assumption 1 (Appendix A.1.2) on the discriminator function class (Thekumparampil et al., 2018). It is easy to show that the state-of-the-art projection discriminator (Miyato & Koyama, 2018) satisfies the assumption when it has the following form:

where , are any neural networks parameterized by , , and such that (Thekumparampil et al., 2018). This constraint on can be easily implemented through weight clipping. Next, we show that the neural network distance satisfies guarantees similar to those of the total variation distance.
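For readers unfamiliar with projection discriminators, the following sketch shows the general shape described by Miyato & Koyama (2018): an inner product between a label embedding and a feature vector, plus a label-independent term. It assumes, as an illustration only, that the constrained parameter is the label-embedding matrix `V`, that the bound is 1, and that the label-independent term is linear; the names `V`, `w`, and `features` are hypothetical.

```python
def clip(value, bound=1.0):
    """Clamp a scalar to the interval [-bound, bound]."""
    return max(-bound, min(bound, value))


def clip_label_embeddings(V, bound=1.0):
    """Weight clipping: force every entry of the embedding matrix into [-bound, bound]."""
    return [[clip(v, bound) for v in row] for row in V]


def projection_discriminator(features, label, V, w):
    """Score a (sample, label) pair.

    features: feature vector of the sample (output of a feature network).
    label: index of the observed label, used to select a row of V.
    V: one embedding row per observed label (the clipped parameter here).
    w: weights of a linear, label-independent scoring term.
    """
    projection = sum(v * f for v, f in zip(V[label], features))
    unconditional = sum(wi * f for wi, f in zip(w, features))
    return projection + unconditional
```

In a training loop, `clip_label_embeddings` would be applied after each discriminator update so that the constraint holds throughout training.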

Theorem 2.

Under the same assumptions as in Theorem 1, if a class of functions satisfies Assumption 1, then

(6)

where .

As with Theorem 1, a proof of the above theorem follows from Thekumparampil et al. (2018, Theorem 2). This justifies using the proposed RCGAN architecture to learn the true conditional distribution from corrupted labels. However, in practice, we observe only samples from each of the distributions , , and we minimize the empirical divergence between the empirical distributions, , of these samples (Thekumparampil et al., 2018). Using recent generalization results (Arora et al., 2017), we can show that minimizing this empirical neural network distance would minimize the distance between the true distributions up to an additive error which vanishes with , as follows.

Theorem 3.

Under the same assumptions as in Theorem 2, for any class of bounded functions, which is parameterized by and is -Lipschitz in , satisfying Assumption 1, there exists a universal constant such that

with probability at least for any and large enough, , where .

A proof of this result follows directly from Thekumparampil et al. (2018, Theorem 3) and Theorem 2. For more details and discussion of these results, see Thekumparampil et al. (2018). Next, we study some special cases of uncertainties.

2.2 Learning from few labels

Assume that the true label of a sample is erased by an erasure channel with probability . As mentioned in Section 1, these missing labels could be captured by an uncertainty model with a single uncertain label , defined by the vector , and confusion matrix given by

(7)

From Theorems 1 and 2 we can get the following corollary.

Corollary 1.

Under the same assumptions as in Theorems 1 and 2 with given by eq. (7), if , we get,

(8)
(9)
(10)

If for all classes , , then the RHS becomes , which is expected since in this case the labels are independent of the samples and recovery of the true distribution is infeasible. As a special case, when a fraction of the labels is missing uniformly at random, we have .

2.3 Complementary labels

Here, we assume that a fraction of the real class labels are changed to one of their corresponding complementary labels at random, i.e., for a real sample , with probability its label is changed to an uncertain label saying ‘ is not from the class ’ where is selected uniformly at random from . As discussed in Section 1, we can capture this corruption by an uncertainty model with a set of uncertain classes, , such that , and a confusion matrix,

(11)

Again using Theorems 1 and 2, we get the following guarantee.

Corollary 2.

Under the same assumptions as in Theorems 1 and 2 with given by eq. (11), if , and and , we get,

(12)
(13)
(14)

The multiplicative factor can be tightened further with additional simple assumptions on the discriminator architecture.

3 Experiments

#labels ()    RCGAN            S3-GAN
80            0.977 ± 0.001    0.851 ± 0.014
60            0.974 ± 0.001    0.795 ± 0.018
40            0.978 ± 0.000    0.758 ± 0.031
30            0.971 ± 0.004    0.726 ± 0.025
20            0.918 ± 0.029    0.596 ± 0.031
10            0.838 ± 0.044    0.414 ± 0.027
ClusterGAN (permutation corrected): 0.901 ± 0.014
(a) Generated label accuracy

#labels ()    RCGAN            S3-GAN
80            0.916 ± 0.005    0.880 ± 0.006
60            0.908 ± 0.005    0.842 ± 0.013
40            0.913 ± 0.007    0.799 ± 0.023
30            0.910 ± 0.009    0.769 ± 0.019
20            0.874 ± 0.024    0.644 ± 0.040
10            0.791 ± 0.042    0.474 ± 0.023
ClusterGAN (permutation corrected): 0.855 ± 0.015
(b) Label recovery accuracy

Table 1: Average metrics (± standard error) for RCGAN and S3-GAN trained on the MNIST dataset with a very small number of labels ().

To evaluate the empirical performance of RCGAN, we consider the case of uniformly missing true class labels (Section 2.2) in the MNIST dataset of handwritten digits (LeCun, 1998). For training we use all the k samples of MNIST; however, only a fraction of these are labeled. We use two different metrics to evaluate the trained conditional generators: (a) generated label accuracy; and (b) label recovery accuracy. For more details on the architectures, training hyper-parameters, and evaluation metrics, as well as more results, please refer to Appendix A.2.

As a proof of concept, we first show that RCGAN learns the true conditional distribution even when only a small fraction () of the samples have labels. RCGAN attains about 99% generated label accuracy and about 92% label recovery accuracy even when only 20% of the samples are labeled (Table 2). However, when the labeled fraction is 5% or lower, performance is poor; we address this in the next section.

Fraction labeled ()    Generated label accuracy    Label recovery accuracy
1.0                    0.992                       0.924
0.8                    0.993                       0.926
0.6                    0.991                       0.908
0.4                    0.994                       0.916
0.2                    0.988                       0.926
0.1                    0.983                       0.910
0.05                   0.162                       0.420
0.025                  0.122                       0.234

Table 2: Generated label accuracy and label recovery accuracy of RCGAN trained on the MNIST dataset with only an fraction of samples being labeled (1 trial for each setting).

3.1 Learning from extremely few labels

In this section we look at the case when only a very small number, , of samples are labeled. Since the fraction of labeled samples is extremely small, we use the following modified loss function, RCGAN(), to boost the signal from the labeled samples.

(15)

where . It is easy to show that, in expectation, this loss is equivalent to the RCGAN loss when fraction of the labels are missing. Therefore, with a sufficient number of samples, the above loss can recover the true conditional distributions. In our experiments, we use , and the first two expectations are computed with all the available real and generated samples, and the latter two expectations are computed with only the labeled real and generated samples. Note that all the terms use the same discriminator network.
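One way to read the four-term loss is as an unlabeled pair of terms plus a weighted labeled pair. The sketch below assumes this reading, standard GAN log losses, and a weight `lam` on the labeled terms; none of these specifics are confirmed by the text above. Under uniform erasure with probability alpha, the real term of the RCGAN loss splits into alpha times an expectation over samples paired with the missing label, plus (1 - alpha) times an expectation over correctly labeled samples, and likewise for the generated term. Dividing by alpha gives weights 1 and (1 - alpha) / alpha, which is the correspondence computed by the hypothetical helper `lam_from_missing_fraction`.

```python
import math


def lam_from_missing_fraction(alpha):
    """Weight on the labeled terms matching a uniform erasure probability alpha."""
    return (1.0 - alpha) / alpha


def mean(values):
    values = list(values)
    return sum(values) / len(values)


def weighted_rcgan_loss(D, real_all, real_labeled, generated, missing, lam):
    """Four-term loss: unlabeled terms plus lam times labeled terms.

    real_all: all real samples x (labels ignored).
    real_labeled: (x, y) pairs for the few labeled real samples.
    generated: (x_gen, y) pairs, with y the class the generator was conditioned on.
    missing: index of the 'missing' label fed to D for the unlabeled terms.
    """
    phi = math.log                     # assumed loss on real samples
    psi = lambda d: math.log(1.0 - d)  # assumed loss on generated samples

    unlabeled_terms = (mean(phi(D(x, missing)) for x in real_all)
                       + mean(psi(D(x, missing)) for x, _ in generated))
    labeled_terms = (mean(phi(D(x, y)) for x, y in real_labeled)
                     + mean(psi(D(x, y)) for x, y in generated))
    # The same discriminator D scores every term.
    return unlabeled_terms + lam * labeled_terms
```

Choosing `lam` larger than the value implied by the true labeled fraction up-weights the scarce labeled samples, which is the stated purpose of the modification.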

As a baseline, we consider the recently proposed S3-GAN (Lucic et al., 2019), which uses self(-semi)-supervised learning techniques and a projection discriminator to achieve state-of-the-art image quality metrics from few labels on the ImageNet dataset. We also provide the permutation-corrected metrics achieved by the unsupervised ClusterGAN (Mukherjee et al., 2018), which learns a conditional GAN from unlabeled data. We see that RCGAN consistently outperforms S3-GAN on both metrics (Tables 1(a) and 1(b)). We also note that RCGAN is easier to implement than S3-GAN because of the latter’s pre-processing step, and that S3-GAN is slower to converge.

In Figure 2 (in Appendix A.2), we provide the samples generated by the RCGAN and S3-GAN architectures for . In each setting, each row corresponds to a class learned by the corresponding conditional generator. We see that RCGAN produces more high-quality samples from the correct classes than S3-GAN, which produces more low-quality samples from the wrong classes.

We hypothesize that this gain of RCGAN over the baselines would be more pronounced on more complex datasets such as CIFAR (Krizhevsky & Hinton, 2009) and ImageNet (Russakovsky et al., 2015).

4 Conclusion

We proposed a robust conditional GAN (RCGAN) architecture which was theoretically shown to be robust to a general class of uncertain labels. This class of uncertain labels can capture a variety of label corruption models such as missing labels, complementary labels, and group membership labels. Further, we empirically verified its robustness on the MNIST dataset when only a few labels are given. RCGAN was able to achieve 97% accuracy even with a few labeled examples per class.

References

Appendix A

A.1 Additional theoretical results and proofs

A.1.1 Proof of Theorem 1

Proof.

From Thekumparampil et al. (2018, Theorem 1), we get that, . Next, using the Woodbury matrix inversion identity (Henderson & Searle, 1981) on (1), we can show that , which implies that . We can further tighten the upper bound by noting that . Inequalities for the Jensen-Shannon divergence also follow from the same reasoning. ∎

A.1.2 Invariance Assumption

For deriving similar approximation bounds as in Theorem 1, we make the following simple assumptions on the discriminator function class (Thekumparampil et al., 2018). First, we define an operation over a matrix and a class of functions of the form as

(16)
Assumption 1.

The class of discriminator functions can be decomposed into three parts such that is any constant and

  • , for all ,

  • there exists a class of functions over such that,

A.2 Experimental details and additional results

For the experiments in Section 3, with only fraction of the samples labeled, we generate the corrupted dataset by independently labeling each sample with probability . We only report results from 1 trial for each of the settings. Assuming that the prior of the true classes is known, it is easy to estimate the confusion matrix (7), which will be .
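The estimation step can be illustrated with a short sketch. It rests on one identity that follows from the missing-label model: the probability of observing class label y equals the prior of class y times the probability that the label is not erased. The function name, the integer encoding of labels, and the clamping to [0, 1] are illustrative assumptions.

```python
from collections import Counter


def estimate_erasure_probs(observed_labels, class_prior, missing):
    """Estimate the per-class erasure probability from corrupted labels.

    observed_labels: observed label per sample; class labels are 0..m-1 and
        `missing` marks an erased label.
    class_prior: known marginal probability of each true class.
    """
    n = len(observed_labels)
    counts = Counter(observed_labels)
    alphas = []
    for y, p_y in enumerate(class_prior):
        kept = counts.get(y, 0) / n  # estimates p_y * (1 - alpha_y)
        alpha_y = 1.0 - kept / p_y
        alphas.append(min(1.0, max(0.0, alpha_y)))  # guard against sampling noise
    return alphas
```

For example, with two equally likely classes and 100 samples of which 10 carry label 0 and 20 carry label 1, the estimates are 0.8 and 0.6. The `missing` argument is kept only to document the label encoding; the estimate uses the counts of the surviving class labels.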

For the experiments in Section 3.1 with a very small number of labeled samples, we allocate the labeled samples equally across the classes, and within each class the labeled samples are selected uniformly at random (). For each setting we provide the mean and standard error over 5 trials, except for RCGAN when , for which we ran 10 trials.
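The summary statistic reported per setting can be computed as follows. This sketch assumes the standard error is the sample standard deviation (with the n - 1 correction) divided by the square root of the number of trials; the text does not state which convention was used.

```python
import math
import statistics


def mean_and_standard_error(trial_values):
    """Return (mean, standard error of the mean) over independent trials."""
    n = len(trial_values)
    mean = statistics.mean(trial_values)
    # Sample standard deviation (n - 1 in the denominator), scaled by sqrt(n).
    std_err = statistics.stdev(trial_values) / math.sqrt(n)
    return mean, std_err
```

With 5 trials the standard error is roughly 0.45 times the sample standard deviation, so the intervals in Table 1 describe the uncertainty of the mean rather than the spread of individual runs.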

For RCGAN, S3-GAN (Lucic et al., 2019), and ClusterGAN (Mukherjee et al., 2018) we use the same underlying discriminator and generator architectures as Thekumparampil et al. (2018). For the modified loss (15) we use after a simple parameter search. For S3-GAN we use (Lucic et al., 2019). S3-GAN uses a self(-semi)-supervised pre-processing step to estimate the true labels, for which we used (Lucic et al., 2019). For the pre-processing step, we use a standard CNN classifier architecture which can get 99+% accuracy on the fully labeled MNIST dataset. For ClusterGAN, we use (Mukherjee et al., 2018). We train RCGAN and ClusterGAN for 30 epochs, and S3-GAN for 100 epochs since it was slow to converge.

The two metrics were proposed by Thekumparampil et al. (2018). Generated label accuracy is the accuracy of the generated labels, as per a pre-trained classifier with a high accuracy (99.2%) as mentioned in Thekumparampil et al. (2018). We use this classifier to predict the labels of the generated images, which are then compared with the generated labels to compute this accuracy. This is a measure of correctness of the class label () conditioning in the generator output. Label recovery accuracy is the accuracy with which the learned generator can be used to recover the true class labels of the unlabeled samples in the training data, using simple back-propagation on the conditional generator (Thekumparampil et al., 2018). This is a measure of the quality and coverage of the generated samples (given the generated label accuracy is high).

Since ClusterGAN is trained without any labels in an unsupervised fashion, we report the same metrics for it after permutation correction. That is, we report the maximum metric values attainable over all possible permutations of the classes learned by the conditional generator.
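Permutation correction can be sketched as a search over relabelings of the generator's class indices. The sketch reads the correction as picking the relabeling most favourable to the unsupervised model, i.e. the one with the highest accuracy, and uses brute-force enumeration; the function and argument names are illustrative.

```python
from itertools import permutations


def permutation_corrected_accuracy(predicted, conditioned, num_classes):
    """Best accuracy over all relabelings of the generator's classes.

    predicted: class assigned to each generated sample by a reference classifier.
    conditioned: class index the generator was conditioned on for that sample.
    """
    # counts[c][p]: samples conditioned on c that the classifier labels as p.
    counts = [[0] * num_classes for _ in range(num_classes)]
    for p, c in zip(predicted, conditioned):
        counts[c][p] += 1

    # Try every mapping from generator class c to a true class perm[c].
    best_matches = max(
        sum(counts[c][perm[c]] for c in range(num_classes))
        for perm in permutations(range(num_classes))
    )
    return best_matches / len(predicted)
```

Enumeration visits num_classes! mappings, which is about 3.6 million for the ten MNIST digits. An assignment solver such as SciPy's `linear_sum_assignment`, applied to the same count matrix, finds the same optimum far more quickly.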

#labels ()    S3-GAN
100           0.725 ± 0.012
80            0.673 ± 0.009
60            0.625 ± 0.010
40            0.580 ± 0.017
30            0.544 ± 0.018
20            0.439 ± 0.019
10            0.305 ± 0.019

Table 3: Average accuracy (± standard error) of the self(-semi)-supervised classifier used in the pre-processing step of S3-GAN trained on the MNIST dataset with a very small number of labels ().

Finally, we report the accuracy of the self(-semi)-supervised classifier from the pre-processing step of S3-GAN as a measure of its ability to identify the true classes of the unlabeled training data. We see that the classifier has low accuracy when very few samples are labeled (Table 3), which could explain the low performance of S3-GAN compared to RCGAN.

Figure 2: Samples generated by RCGAN and S3-GAN when trained on MNIST dataset with labels. Each row is one class as learned by the corresponding conditional generator.