Controllable Invariance through Adversarial Feature Learning

by Qizhe Xie, et al.
Carnegie Mellon University

Learning meaningful representations that maintain the content necessary for a particular task while filtering away detrimental variations is a problem of great interest in machine learning. In this paper, we tackle the problem of learning representations invariant to a specific factor or trait of data. The representation learning process is formulated as an adversarial minimax game. We analyze the optimal equilibrium of such a game and find that it amounts to maximizing the uncertainty of inferring the detrimental factor given the representation while maximizing the certainty of making task-specific predictions. On three benchmark tasks, namely fair and bias-free classification, language-independent generation, and lighting-independent image classification, we show that the proposed framework induces an invariant representation, and leads to better generalization evidenced by the improved performance.



1 Introduction

How to produce a data representation that maintains meaningful variations of data while eliminating noisy signals is a consistent theme of machine learning research. In the last few years, the dominant paradigm for finding such a representation has shifted from manual feature engineering based on specific domain knowledge to representation learning that is fully data-driven and often powered by deep neural networks (Bengio et al., 2013). Being universal function approximators (Cybenko, 1989), deep neural networks can easily uncover the complicated variations in data (Zhang et al., 2017), leading to powerful representations. However, how to systematically incorporate a desired invariance into the learned representation in a controllable way remains an open problem.

A possible avenue towards the solution is to devise a dedicated neural architecture that has the desired invariance property by construction. As a typical example, the parameter sharing scheme and pooling mechanism in modern deep convolutional neural networks (CNNs) (LeCun et al., 1998) take advantage of the spatial structure of image processing problems, allowing them to induce more generic feature representations than fully connected networks. Since the invariance we care about can vary greatly across tasks, this approach requires us to design a new architecture each time a new invariance desideratum shows up, which is time-consuming and inflexible.

When our belief of invariance is specific to some attribute of the input data, an alternative approach is to build a probabilistic model with a random variable corresponding to the attribute, and explicitly reason about the invariance. For instance, the variational fair auto-encoder (VFAE) (Louizos et al., 2016) employs the maximum mean discrepancy (MMD) to eliminate the negative influence of specific “nuisance variables”, such as removing the lighting conditions of images to predict the person’s identity. Similarly, in the setting of domain adaptation, the standard binary adversarial cost (Ganin and Lempitsky, 2015; Ganin et al., 2016) and central moment discrepancy (CMD) (Zellinger et al., 2017) have been used to learn domain-invariant features. However, all these invariance-inducing criteria share a drawback: they are defined to measure the divergence between a pair of distributions. Consequently, they can only express the invariance belief w.r.t. one pair of values of the random variable at a time. When the attribute is a multinomial variable that takes more than two values, a combinatorial number of pairs (specifically, $\binom{K}{2}$ for an attribute with $K$ possible values) has to be added to express the belief that the representation should be invariant to the attribute. The problem is even more dramatic when the attribute represents a structure that has exponentially many possible values (e.g. the parse tree of a sentence) or when the attribute is simply a continuous variable.
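To make this growth concrete, the following Python sketch counts the pairwise divergence terms, $\binom{K}{2}$, that a pairwise criterion would need for an attribute with $K$ values, compared with a single $K$-way discriminator. The helper name `num_pairwise_terms` is hypothetical and the snippet only illustrates the arithmetic.

```python
from math import comb


def num_pairwise_terms(num_values: int) -> int:
    """Pairwise divergence terms needed to cover every pair of attribute values."""
    return comb(num_values, 2)


# One K-way discriminator handles all K attribute values at once, whereas a
# pairwise criterion needs one divergence term per unordered pair of values.
for k in (2, 5, 10, 100):
    print(f"K={k}: {num_pairwise_terms(k)} pairwise terms vs. 1 discriminator")
```

For $K = 100$ this already amounts to 4950 pairwise terms.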

Motivated by the aforementioned drawbacks and difficulties, in this work, we consider the problem of learning a feature representation with the desired invariance. We aim at creating a unified framework that is (1) generic enough such that it can be easily plugged into different models, and (2) more flexible to express an invariance belief in quantities beyond discrete variables with limited value choices. Specifically, inspired by the recent advancement of adversarial learning (Goodfellow et al., 2014), we formulate the representation learning as a minimax game among three players: an encoder which maps the observed data deterministically into a feature space, a discriminator which looks at the representation and tries to identify a specific type of variation we hope to eliminate from the feature, and a predictor which makes use of the invariant representation to make predictions as in typical discriminative models. We provide theoretical analysis of the equilibrium condition of the minimax game, and give an intuitive interpretation. On three benchmark tasks from different domains, we show that the proposed approach not only improves upon vanilla discriminative approaches that do not encourage invariance, but also outperforms existing approaches that enforce invariant features.

2 Adversarial Invariant Feature Learning

In this section, we formulate our problem and then present the proposed framework of learning invariant features.

(a) $s$ and $y$ are marginally independent
(b) $s$ and $y$ are not marginally independent
Figure 1: Dependencies between $x$, $s$ and $y$, where $x$ is the observation and $y$ is the target to be predicted. $s$ is the attribute to which the prediction should be invariant.

Given an observation/input $x$, we are interested in the task of predicting the target $y$ based on the value of $x$ using a discriminative approach. In addition, we have access to some intrinsic attribute $s$ of $x$, as well as a prior belief that the prediction result should be invariant to $s$.

There are two possible dependency scenarios for $s$ and $y$: (1) $s$ and $y$ can be marginally independent. For example, in image classification, lighting conditions and identities of persons are independent. The data generation process is $s \rightarrow x \leftarrow y$. (2) In some cases, $s$ and $y$ are not marginally independent. For example, in fair classification, $s$ are the sensitive factors such as age and gender, and $y$ can be the savings, credit and health condition of a person. $s$ and $y$ are related due to the inherent bias within the data. Using a latent variable $z$ to model the dependency between $s$ and $y$, the data generation process becomes $s \leftarrow z \rightarrow y$ together with $s \rightarrow x \leftarrow y$. We show the corresponding dependency graphs in Figure 1.

Unlike vanilla discriminative models that output the conditional distribution $p(y \mid x)$, we model $p(y \mid x, s)$ to make predictions invariant to $s$. Our intuition is that, due to the explaining away effect, $y$ and $s$ are not independent when conditioned on $x$, although they can be marginally independent. Consequently, $p(y \mid x, s)$ is a more accurate estimate of $y$ than $p(y \mid x)$. Intuitively, this can inform and guide the model to remove information about undesired variations. For example, if we want to learn a representation of an image $x$ that is invariant to the lighting condition $s$, the model can learn to “brighten” the input if it knows the original picture is dark, and vice versa. Also, in multi-lingual machine translation, a word with the same surface form may have different meanings in different languages. For instance, “gift” means “present” in English but “poison” in German. Hence, knowing the language of a source sentence helps infer the meaning of the sentence and translate it.
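The explaining-away argument can be checked on a toy collider. The following Python sketch assumes a hypothetical distribution in which $s$ and $y$ are independent fair coins and the observation is $x = s \lor y$; it computes the relevant probabilities exactly with `fractions.Fraction`.

```python
from fractions import Fraction
from itertools import product

# Hypothetical collider s -> x <- y: s and y are independent fair coins and x = s OR y.
joint = {}
for s, y in product((0, 1), repeat=2):
    key = (s | y, s, y)  # (x, s, y)
    joint[key] = joint.get(key, Fraction(0)) + Fraction(1, 4)


def prob(event):
    """Total probability of the outcomes (x, s, y) for which event(x, s, y) is true."""
    return sum((p for (x, s, y), p in joint.items() if event(x, s, y)), Fraction(0))


def cond(event, given):
    """Conditional probability P(event | given)."""
    return prob(lambda x, s, y: event(x, s, y) and given(x, s, y)) / prob(given)


y_is_1 = lambda x, s, y: y == 1

# Marginally, s carries no information about y ...
p_y = prob(y_is_1)
p_y_given_s1 = cond(y_is_1, lambda x, s, y: s == 1)
# ... but once x is observed, knowing s changes the belief about y (explaining away).
p_y_given_x1 = cond(y_is_1, lambda x, s, y: x == 1)
p_y_given_x1_s0 = cond(y_is_1, lambda x, s, y: x == 1 and s == 0)
p_y_given_x1_s1 = cond(y_is_1, lambda x, s, y: x == 1 and s == 1)
print(p_y, p_y_given_s1, p_y_given_x1, p_y_given_x1_s0, p_y_given_x1_s1)
```

It prints `1/2 1/2 2/3 1 1/2`: $s$ alone says nothing about $y$, but given $x = 1$, learning $s = 0$ raises $p(y = 1)$ from 2/3 to 1, while $s = 1$ lowers it to 1/2.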

As the input $x$ can have a highly complicated structure, we employ a dedicated model or algorithm to extract an expressive representation $h$ from $x$. When extracting $h$ from $x$, we want $h$ to preserve the variations that are necessary to predict $y$ while eliminating information about $s$. To achieve this goal, we employ a deterministic encoder $E$ to obtain the representation by encoding $x$ and $s$ into $h$, namely, $h = E(x, s)$. Note that we use $s$ as an additional input. Given the obtained representation $h$, the target $y$ is predicted by a predictor $M$, which effectively models the distribution $q_M(y \mid h)$. By construction, instead of modeling $p(y \mid x)$ directly, the discriminative model we formulate captures the conditional distribution $p(y \mid x, s)$ with additional information coming from $s$.

Feeding $s$ into the encoder by itself by no means guarantees that the induced feature $h$ will be invariant to $s$. Thus, in order to enforce the desired invariance and eliminate variations of the factor $s$ from $h$, we set up an adversarial game by introducing a discriminator $D$ which inspects the representation $h$ and ensures that it is invariant to $s$. Concretely, the discriminator is trained to predict $s$ based on the encoded representation $h$, which effectively maximizes the likelihood $q_D(s \mid h)$. Simultaneously, the encoder fights to minimize the same likelihood of the discriminator inferring the correct $s$. Intuitively, the discriminator and the encoder form an adversarial game in which the discriminator tries to detect an attribute of the data while the encoder learns to conceal it.

Note that under our framework, in theory, $s$ can be any type of data as long as it represents an attribute of $x$. For example, $s$ can be a real-valued scalar or vector, which may take many possible values, or a complex sub-structure such as the parse tree of a natural language sentence. In this paper, however, we focus mainly on instances where $s$ is a discrete label with multiple choices. We plan to extend our framework to deal with continuous and structured $s$ in the future.

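Formally, $E$, $M$ and $D$ jointly play the following minimax game:

$$\min_{E, M} \max_{D} \; J(E, M, D) = \mathbb{E}_{x, s, y \sim q(x, s, y)} \Big[ \gamma \log q_D\big(s \mid h = E(x, s)\big) - \log q_M\big(y \mid h = E(x, s)\big) \Big] \qquad (1)$$

where $\gamma$ is a hyper-parameter that adjusts the strength of the invariance constraint, and $q(x, s, y)$ is the true underlying distribution from which the empirical observations are drawn.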

Note that the problem of domain adaptation can be seen as a special case of our problem, where $s$ is a Bernoulli variable representing the domain and the model only has access to the target $y$ when $s$ = “source domain” during training.

3 Theoretical Analysis

In this section, we theoretically analyze whether, given enough capacity and training time, such a minimax game will converge to an equilibrium where variations of $y$ are preserved and variations of $s$ are removed. The theoretical analysis is done in a non-parametric limit, i.e., we assume a model with infinite capacity. In addition, we discuss the equilibria of the minimax game when $s$ is independent of or dependent on $y$.

Since both the discriminator and the predictor only use $h$, which is transformed deterministically from $x$ and $s$, we can substitute $x$ with $h$ and define a joint distribution $\tilde{q}(h, s, y)$ of $h$, $s$ and $y$ as follows:

$$\tilde{q}(h, s, y) = \int_x q(x, s, y)\, p_E(h \mid x, s)\, dx = \int_x q(x, s, y)\, \delta\big(E(x, s) = h\big)\, dx$$

Here, we have used the fact that the encoder is a deterministic transformation, and thus the distribution $p_E(h \mid x, s)$ is merely a delta function denoted by $\delta(\cdot)$. Intuitively, $h$ absorbs the randomness in $x$ and has an implicit distribution of its own. Also, note that the joint distribution $\tilde{q}(h, s, y)$ depends on the transformation defined by the encoder.
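For discrete variables the integral becomes a sum, and the delta function simply moves the probability mass of each $(x, s, y)$ to the point $h = E(x, s)$. The Python sketch that follows illustrates this push-forward under stated assumptions: the toy observation model $x = y + 10s$ and both encoders are hypothetical and only serve to show how $\tilde{q}(h, s, y)$ depends on $E$.

```python
from collections import defaultdict
from fractions import Fraction


def induced_joint(q_xsy, encoder):
    """Discrete analogue of q~(h, s, y) = sum_x q(x, s, y) * delta(E(x, s) = h).

    q_xsy maps (x, s, y) to a probability; encoder is a deterministic function E(x, s).
    Each outcome's mass is moved to the point h = E(x, s), so several x may share one h.
    """
    q_hsy = defaultdict(Fraction)
    for (x, s, y), p in q_xsy.items():
        q_hsy[(encoder(x, s), s, y)] += p
    return dict(q_hsy)


# Hypothetical toy data: the observation x = y + 10 * s mixes an identity y with a lighting s.
q = {(y + 10 * s, s, y): Fraction(1, 4) for s in (0, 1) for y in (0, 1)}

# An encoder that also receives s can undo the lighting shift exactly ...
invariant = induced_joint(q, lambda x, s: x - 10 * s)
# ... whereas passing x through unchanged keeps the lighting information in h.
identity = induced_joint(q, lambda x, s: x)
print(invariant)
print(identity)
```

With the first encoder every point of $\tilde{q}$ has $h = y$, so $h$ is independent of $s$ in this toy; the identity encoder yields four distinct values of $h$ that still reveal $s$.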

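Thus, we can equivalently rewrite objective (1) as

$$\min_{E, M} \max_{D} \; J(E, M, D) = \mathbb{E}_{h, s, y \sim \tilde{q}(h, s, y)} \big[ \gamma \log q_D(s \mid h) - \log q_M(y \mid h) \big] \qquad (2)$$
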
To analyze the equilibrium condition of the new objective (2), we first deduce the optimal discriminator and the optimal predictor for a given encoder and then prove the global optimality of the minimax game.

Claim 1.

Given a fixed encoder $E$, the optimal discriminator outputs $q_D^*(s \mid h) = \tilde{q}(s \mid h)$ and the optimal predictor corresponds to $q_M^*(y \mid h) = \tilde{q}(y \mid h)$.

Proof. For a fixed encoder, the objective is functionally concave in $q_D$ (which the discriminator maximizes) and convex in $q_M$ (which the predictor minimizes), so by taking variations we obtain the stationary points for $q_D$ and $q_M$ as functions of $\tilde{q}(h, s, y)$. The detailed proof is included in Supplementary Material A. ∎
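Claim 1 is an instance of Gibbs' inequality: for each $h$, the expected log-likelihood is maximized when $q_D(\cdot \mid h)$ equals the true conditional, and the positive factor $\gamma$ does not change the maximizer. The Python sketch that follows checks this numerically on a small hypothetical joint $\tilde{q}(h, s)$, comparing the claimed optimum against 1000 randomly drawn discriminators; the same argument applies to $q_M$ and $\tilde{q}(y \mid h)$.

```python
import math
import random


def conditional(q_hs):
    """q~(s | h) obtained by normalising a joint {(h, s): p} over s for each h."""
    marginal = {}
    for (h, _), p in q_hs.items():
        marginal[h] = marginal.get(h, 0.0) + p
    cond = {}
    for (h, s), p in q_hs.items():
        cond.setdefault(h, {})[s] = p / marginal[h]
    return cond


def expected_log_lik(q_hs, q_d):
    """E_{(h, s) ~ q~}[log q_d(s | h)], the discriminator objective without the factor gamma."""
    return sum(p * math.log(q_d[h][s]) for (h, s), p in q_hs.items() if p > 0)


def random_discriminator(rng, hs, ss):
    """A random valid conditional distribution over s for each h."""
    out = {}
    for h in hs:
        weights = {s: rng.random() + 1e-9 for s in ss}
        total = sum(weights.values())
        out[h] = {s: w / total for s, w in weights.items()}
    return out


# Hypothetical joint q~(h, s) with two representation values and three attribute values.
q_hs = {(0, "a"): 0.30, (0, "b"): 0.10, (0, "c"): 0.10,
        (1, "a"): 0.05, (1, "b"): 0.15, (1, "c"): 0.30}
optimal = conditional(q_hs)
best = expected_log_lik(q_hs, optimal)

rng = random.Random(0)
never_beaten = all(
    expected_log_lik(q_hs, random_discriminator(rng, (0, 1), ("a", "b", "c"))) <= best + 1e-12
    for _ in range(1000)
)
print(optimal, never_beaten)
```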

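Note that the optimal $q_D^*$ and $q_M^*$ given in Claim 1 are both functions of the encoder $E$. Thus, by plugging $q_D^*$ and $q_M^*$ into the minimax objective (2), it can be simplified to a minimization problem only w.r.t. the encoder $E$, with the following form:

$$\min_E J(E) = -\gamma\, H\big(\tilde{q}(s \mid h)\big) + H\big(\tilde{q}(y \mid h)\big) \qquad (3)$$

where $H(\tilde{q}(s \mid h))$ is the conditional entropy of the distribution $\tilde{q}(s \mid h)$.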
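The step from objective (2) to objective (3) can be checked numerically. The following Python sketch, a minimal illustration on hypothetical discrete joints, evaluates objective (2) at the optimal $q_D^*$ and $q_M^*$ from Claim 1 and compares it with the entropy form $-\gamma H(\tilde{q}(s \mid h)) + H(\tilde{q}(y \mid h))$. Entropies are in nats, and $s$ and $y$ are assumed independent and uniform over $\{0, 1\}$.

```python
import math


def _marginals(q_hsy):
    """Marginals q~(h), q~(h, s) and q~(h, y) of a discrete joint {(h, s, y): p}."""
    q_h, q_hs, q_hy = {}, {}, {}
    for (h, s, y), p in q_hsy.items():
        q_h[h] = q_h.get(h, 0.0) + p
        q_hs[(h, s)] = q_hs.get((h, s), 0.0) + p
        q_hy[(h, y)] = q_hy.get((h, y), 0.0) + p
    return q_h, q_hs, q_hy


def entropy_form(q_hsy, gamma):
    """Objective (3): -gamma * H(s | h) + H(y | h), in nats."""
    q_h, q_hs, q_hy = _marginals(q_hsy)
    h_s = -sum(p * math.log(p / q_h[h]) for (h, _), p in q_hs.items() if p > 0)
    h_y = -sum(p * math.log(p / q_h[h]) for (h, _), p in q_hy.items() if p > 0)
    return -gamma * h_s + h_y


def plug_in_form(q_hsy, gamma):
    """Objective (2) evaluated at q_D = q~(s | h) and q_M = q~(y | h)."""
    q_h, q_hs, q_hy = _marginals(q_hsy)
    return sum(
        p * (gamma * math.log(q_hs[(h, s)] / q_h[h]) - math.log(q_hy[(h, y)] / q_h[h]))
        for (h, s, y), p in q_hsy.items() if p > 0
    )


# Hypothetical induced joints with s, y independent and uniform over {0, 1}.
keeps_s = {((y, s), s, y): 0.25 for s in (0, 1) for y in (0, 1)}  # h = (y, s)
invariant = {(y, s, y): 0.25 for s in (0, 1) for y in (0, 1)}     # h = y

for name, q in (("keeps_s", keeps_s), ("invariant", invariant)):
    print(name, entropy_form(q, gamma=1.0), plug_in_form(q, gamma=1.0))
```

Both forms agree, and the encoder with $h = y$ attains $J(E) = -\gamma \log 2$, lower than the value 0 reached by the encoder that keeps $s$ in $h$.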

Equilibrium Analysis

As we can see, objective (3) consists of two conditional entropies with opposite signs. Optimizing the first term amounts to maximizing the uncertainty of inferring $s$ based on $h$, which essentially filters out any information about $s$ from the representation. On the contrary, optimizing the second term increases the certainty of predicting $y$ based on $h$. Implicitly, the objective defines the equilibrium of the minimax game.

  • Win-win equilibrium: Firstly, for cases where the attribute $s$ is entirely irrelevant to the prediction task (corresponding to the dependency graph shown in Figure 1(a)), the two terms can reach their optima at the same time, leading to a win-win equilibrium. For example, with the lighting condition of an image removed, we can still (or even better) classify the identity of the people in that image. With enough model capacity, the optimal equilibrium solution would be the same regardless of the value of $\gamma$.

  • Competing equilibrium: However, there are cases where these two optimization objectives compete. For example, in fair classification, sensitive factors such as gender and age may help the overall prediction accuracy due to inherent biases within the data. In other words, knowing $s$ may help in predicting $y$, since $s$ and $y$ are not marginally independent (corresponding to the dependency graph shown in Figure 1(b)). In this case, learning a fair/invariant representation hurts prediction, the optima of the two entropies cannot be achieved simultaneously, and $\gamma$ defines the relative strengths of the two objectives in the final equilibrium.
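To illustrate the competing case, the following Python sketch compares two hypothetical candidate encoders under objective (3) on biased toy data where $s$ and $y$ are uniform binary variables that agree with probability 0.9. Encoder A keeps exactly $y$ in $h$, while encoder B discards everything. The sketch only compares these two candidates; it does not compute the true optimum of objective (3).

```python
import math


def binary_entropy_bits(p):
    """Entropy in bits of a Bernoulli(p) variable."""
    if p in (0.0, 1.0):
        return 0.0
    return -p * math.log2(p) - (1 - p) * math.log2(1 - p)


def objective(h_s_given_h, h_y_given_h, gamma):
    """Objective (3), J(E) = -gamma * H(s | h) + H(y | h), with entropies in bits."""
    return -gamma * h_s_given_h + h_y_given_h


# Hypothetical biased data: s and y are uniform binary and agree with probability 0.9.
AGREE = 0.9
# Encoder A keeps exactly y (h = y): y is fully predictable, but h still reveals s
# through the s-y correlation, so H(s | h) = H_b(0.9).
A = (binary_entropy_bits(AGREE), 0.0)
# Encoder B discards everything (h constant): s is hidden, y is a fair coin.
B = (1.0, 1.0)

for gamma in (0.5, 1.0, 2.0, 3.0):
    winner = "A" if objective(*A, gamma) < objective(*B, gamma) else "B"
    print(f"gamma={gamma}: encoder {winner} has the lower objective")

# A is preferred exactly when gamma < 1 / (1 - H_b(0.9)), roughly 1.88.
threshold = 1.0 / (1.0 - A[0])
```

Small values of $\gamma$ favour the accurate but leaky encoder, and large values favour hiding $s$ at the cost of prediction accuracy, which is the role of $\gamma$ described above.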

4 Parametric Instantiation of the Proposed Framework

4.1 Models

To show the general applicability of our framework, we experiment on three different tasks: sentence generation, image classification and fair classification. Due to the different natures of the data $x$ and $s$, we present here the specific model instantiations we use.

Sentence Generation

We use multi-lingual machine translation as the testbed for sentence generation. Concretely, we have translation pairs between several source languages and a target language. $x$ is the source sentence to be translated and $s$ is a scalar denoting which source language $x$ belongs to. $y$ is the translated sentence in the target language.

Recall that $s$ is used as an input of $E$ to obtain a language-invariant representation. To make full use of $s$, we employ a separate encoder $E_s$ for sentences in each language $s$. In other words, $h = E(x, s) = E_s(x)$, where each $E_s$ is a different encoder. The representation of a sentence is captured by the hidden states of an LSTM encoder (Hochreiter and Schmidhuber, 1997) at each time step.

We employ a single LSTM predictor for the different encoders. As is common in language generation, the probability $q_M(y \mid h)$ output by the predictor is parametrized by an autoregressive process, i.e.,

$$q_M(y_{1:T} \mid h) = \prod_{t=1}^{T} q_M(y_t \mid y_{<t}, h)$$

where we use an LSTM with an attention model (Bahdanau et al., 2015) to compute $q_M(y_t \mid y_{<t}, h)$.

The discriminator is also parameterized as an LSTM, which gives it enough capacity to deal with inputs of multiple time steps. $q_D(s \mid h)$ is instantiated with the multinomial distribution computed by a softmax layer on the last hidden state of the discriminator LSTM.


Classification

For our classification experiments, the input is either a picture or a feature vector. All three players in the minimax game are constructed as feedforward neural networks. We feed $s$ to the encoder as an embedding vector.

4.2 Optimization

There are two possible approaches to optimizing our framework in an adversarial setting. The first one is similar to the alternating approach used in Generative Adversarial Nets (GANs) (Goodfellow et al., 2014): we alternately train the two adversarial components while freezing the third one. This approach gives more control in balancing the encoder and the discriminator, which effectively avoids saturation. Another method is to train all three components together with a gradient reversal layer (Ganin and Lempitsky, 2015). In particular, the encoder admits gradients from both the discriminator and the predictor, with the gradient from the discriminator negated to push the encoder in the direction opposite to the one desired by the discriminator. Chen et al. (2016b) found the second approach easier to optimize, since the discriminator and the encoder stay fully in sync when optimized together. Hence we adopt the latter approach. In all of our experiments, we use Adam (Kingma and Ba, 2014) with a learning rate of .
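The gradient reversal scheme described above can be sketched in plain Python for a scalar representation $h$. This is a minimal illustration, not the code used in our experiments: the predictor and discriminator are assumed to be logistic models with hypothetical scalar weights `u` and `v`, and the reversal scale is assumed to equal $\gamma$, so that the gradient reaching the encoder is the gradient of the per-sample objective (1). The check compares the reversed gradient with a central finite difference of that objective.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def nll(logit, label):
    """Negative log-likelihood -log p(label) of a logistic model with the given logit."""
    p = sigmoid(logit)
    return -math.log(p if label == 1 else 1.0 - p)


def nll_grad_wrt_h(h, weight, label):
    """d/dh of nll(weight * h, label), which equals (sigmoid(weight * h) - label) * weight."""
    return (sigmoid(weight * h) - label) * weight


class GradientReversal:
    """Identity in the forward pass; multiplies the incoming gradient by -scale backward."""

    def __init__(self, scale):
        self.scale = scale

    def forward(self, h):
        return h

    def backward(self, grad):
        return -self.scale * grad


def encoder_grad(h, y, s, u, v, gamma):
    """Gradient the encoder receives at h when all components are trained jointly.

    u and v are hypothetical weights of a logistic predictor q_M(y | h) and a logistic
    discriminator q_D(s | h). Both heads minimise their own NLL; only the discriminator's
    gradient passes through the reversal layer on its way back to the encoder.
    """
    grl = GradientReversal(gamma)
    g_pred = nll_grad_wrt_h(h, u, y)
    g_disc = nll_grad_wrt_h(grl.forward(h), v, s)
    return g_pred + grl.backward(g_disc)


def objective(h, y, s, u, v, gamma):
    """Per-sample J = gamma * log q_D(s | h) - log q_M(y | h), as in objective (1)."""
    return -gamma * nll(v * h, s) + nll(u * h, y)


# The reversed gradient equals dJ/dh (checked with a central finite difference).
args = dict(y=1, s=0, u=1.5, v=-2.0, gamma=0.7)
eps = 1e-5
numeric = (objective(0.3 + eps, **args) - objective(0.3 - eps, **args)) / (2 * eps)
print(encoder_grad(0.3, **args), numeric)
```

The sketch covers only the gradient flowing into $h$; in full training the discriminator's own parameters are still updated with the unreversed gradient, so it keeps maximizing $\log q_D(s \mid h)$.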

5 Experiments

In this section, we perform experiments to evaluate the effectiveness of the proposed framework. We first introduce the tasks and the corresponding datasets we consider. Then, we present the quantitative results showing the superior performance of our proposed framework, and discuss qualitative analyses verifying that the learned representations have the desired invariance property.

5.1 Datasets

Our experiments include three tasks in different domains: (1) fair classification, in which predictions should be unaffected by nuisance factors; (2) language-independent generation which is conducted on the multi-lingual machine translation problem; (3) lighting-independent image classification.

Fair Classification

For fair classification, we use three datasets to predict the savings, credit ratings and health conditions of individuals, with variables such as gender or age specified as the “nuisance variable” that we would like to exclude from our decisions (Zemel et al., 2013; Louizos et al., 2016). The German dataset (Frank et al., 2010) is a small dataset with samples describing whether a person has a good credit rating. The sensitive nuisance variable to be factored out is gender. The Adult income dataset (Frank et al., 2010) has data points and the objective is to predict whether a person has savings of over dollars with the sensitive factor being age. The task for the Health dataset is to predict whether a person will spend any days in the hospital in the following year. The sensitive variable is also age, and the dataset contains entries. We follow the same -fold train/validation/test splits and feature preprocessing used in (Zemel et al., 2013; Louizos et al., 2016).

Both the encoder and the predictor are parameterized by single-layer neural networks. A three-layer neural network with batch normalization (Ioffe and Szegedy, 2015) is employed for the discriminator. We use a batch size of and the number of hidden units is set to . $\gamma$ is set to in our experiments.

Multi-lingual Machine Translation

For the multi-lingual machine translation task, we use French-to-English (fr-en) and German-to-English (de-en) pairs from the IWSLT 2015 dataset (Cettolo et al., 2012). There are pairs of fr-en sentences and pairs of de-en sentences in the training set. In the test set, there are pairs of fr-en sentences and pairs of de-en sentences. We evaluate BLEU scores (Papineni et al., 2002) using the standard Moses multi-bleu.perl script. Here, $s$ indicates the language of the source sentence.

We use OpenNMT (Klein et al., 2017) in our multi-lingual MT experiments (our MT code is publicly available). The encoder is a two-layer bidirectional LSTM with 256 units for each direction. The discriminator is a one-layer unidirectional LSTM with 256 units. The predictor is a two-layer LSTM with 512 units and an attention mechanism (Bahdanau et al., 2015). We follow Johnson et al. (2016) and use Byte Pair Encoding (BPE) subword units (Sennrich et al., 2016) as the cross-lingual input. Every model is run for epochs. $\gamma$ is set to and the batch size is set to .

Image Classification

We use the Extended Yale B dataset (Georghiades et al., 2001) for our image classification task. It comprises face images of people under different lighting conditions: upper right, lower right, lower left, upper left, or the front. The variable $s$ to be purged is the lighting condition. The label $y$ is the identity of the person. We follow the train/test split of Li et al. (2014) and Louizos et al. (2016), and no validation set is used: samples are used for training and all other data points are used for testing.

We use a one-layer neural network for the encoder and a one-layer neural network for prediction. $\gamma$ is set to . The discriminator is a two-layer neural network with batch normalization. The batch size is set to and the hidden size is set to .

(a) Accuracy on predicting $s$. The closer the result is to the majority line, the better the model is at eliminating the effect of nuisance variables.
(b) Accuracy on predicting $y$. High accuracy in predicting $y$ is desirable.
(c) Overall performance and performance on biased categories. Fair representations lead to high accuracy on biased categories.
Figure 2: Fair classification results on different representations. $x$ denotes directly using the observation $x$ as the representation. The black lines in the first and the second row show the performance of predicting the majority label. “Biased categories” in the third row are explained in the fourth paragraph of Section 5.2.

5.2 Results

Fair Classification

The results on the three fairness tasks are shown in Figure 2. We compare our model with two prior works on learning fair representations: Learning Fair Representations (LFR) (Zemel et al., 2013) and the Variational Fair Autoencoder (VFAE) (Louizos et al., 2016). Results of a VAE and of directly using $x$ as the representation are also shown.

We first study how much information about $s$ is retained in the learned representation $h$ by using logistic regression to predict the factor $s$. In the top row, we see that $s$ cannot be recognized from the representations learned by the three models targeting fair representations: the accuracy of classifying $s$ is similar to that of the trivial baseline predicting the majority label, shown by the black line.

The performance on predicting the label $y$ is shown in the second row. We see that LFR and VFAE suffer on the Adult and German datasets after removing information about $s$. In comparison, our model’s performance does not suffer even when making fair predictions. Specifically, on German, our model’s accuracy is compared to and achieved by VFAE and LFR. On Adult, our model’s accuracy is while VFAE and LFR have accuracies of and respectively. On the Health dataset, all models’ performances are barely better than the majority baseline. The unsatisfactory performances of all models may be due to the extreme imbalance of the dataset, in which of the data has the same label.

We also investigate how fair representations would alleviate biases of machine learning models. We measure the unbiasedness by evaluating models’ performances on identifying minority groups. For instance, suppose the task is to predict savings with the nuisance factor being age, with savings above a threshold of $ being adequate, otherwise being insufficient. If people of advanced age generally have fewer savings, then a biased model would tend to predict insufficient savings for those with an advanced age. In contrast, an unbiased model can better factor out age information and recognize people that do not fit into these stereotypes.

Concretely, for groups pooled by each possible value of , we look for the minority in each of these groups and define it as the biased category for the group. We then calculate the accuracy on each biased category and report the average over all categories. We do not compute the instance-level average, since one category may hold the dominant share of the data among all categories.
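The biased-category metric described above can be sketched as follows. This is a minimal Python illustration under stated assumptions: the record format and the function name are hypothetical, and the grouping variable is left as a parameter (`group_key`) because the sketch is only meant to show the per-category averaging. On the toy data, a stereotyped predictor ($\hat{y} = 1 - s$) reaches 0.75 overall accuracy yet 0.0 on the biased categories, which is the gap this metric is designed to expose.

```python
from collections import Counter


def biased_category_accuracy(records, group_key, minority_key):
    """Average accuracy over 'biased categories' (a per-category, not per-instance, mean).

    records: list of dicts with keys "s", "y" and "pred".
    For each value of record[group_key], the biased category is the subset of that group
    whose record[minority_key] takes its least frequent value within the group.
    """
    groups = {}
    for r in records:
        groups.setdefault(r[group_key], []).append(r)
    accuracies = []
    for members in groups.values():
        counts = Counter(r[minority_key] for r in members)
        minority = min(counts, key=lambda value: (counts[value], value))
        category = [r for r in members if r[minority_key] == minority]
        accuracies.append(sum(r["pred"] == r["y"] for r in category) / len(category))
    return sum(accuracies) / len(accuracies)


# Hypothetical data: (s, y) pairs where y mostly equals 1 - s, i.e. a stereotype.
pairs = [(0, 1), (0, 1), (0, 1), (0, 0), (1, 0), (1, 0), (1, 0), (1, 1)]
stereotyped = [{"s": s, "y": y, "pred": 1 - s} for s, y in pairs]
overall = sum(r["pred"] == r["y"] for r in stereotyped) / len(stereotyped)
print(overall, biased_category_accuracy(stereotyped, "s", "y"))
```

Averaging per category rather than per instance keeps a large majority category from masking failures on the small, stereotype-breaking ones.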

As shown in the third row of Figure 2, on German and Adult, we achieve higher accuracy on the biased categories, even though our overall accuracy is similar to or lower than the baseline which does not employ fairness constraints. Specifically, on Adult, our performance on the biased categories is while the baseline’s accuracy is . On German, our accuracy on biased categories is while the baseline achieves . The results show that our model is able to learn a more unbiased representation.

Multi-lingual Machine Translation

Model                                           test (fr-en)   test (de-en)
Bilingual Enc-Dec (Bahdanau et al., 2015)       35.2           27.3
Multi-lingual Enc-Dec (Johnson et al., 2016)    35.5           27.7
Our model                                       36.1           28.1
  w/o discriminator                             35.3           27.6
  w/o separate encoders                         35.4           27.7
Table 1: Results on multi-lingual machine translation (BLEU).

The results of the systems on multi-lingual machine translation are shown in Table 1. We compare our model with attention-based encoder-decoders trained on bilingual data (Bahdanau et al., 2015) and on multi-lingual data (Johnson et al., 2016). The encoder-decoder trained on multi-lingual data employs a single encoder for both source languages. Firstly, both multi-lingual systems outperform the bilingual encoder-decoder even though the multi-lingual systems use a similar number of parameters to translate two languages, which shows that learning an invariant representation leads to better generalization in this case. The better generalization may be due to transferring statistical strength between data in the two languages.

Comparing the two multi-lingual systems, our model outperforms the baseline multi-lingual system on both languages, with an improvement of 0.6 BLEU on French-to-English. We also verify the design decisions in our framework by ablation studies. Firstly, without the discriminator, the model performs worse than the standard multi-lingual system, which rules out the possibility that the gain of our model comes from the additional parameters of the separate encoders. Secondly, when we do not employ separate encoders, the model’s performance deteriorates and it is more difficult to learn a cross-lingual representation, which

  • verifies the theoretical advantage of modeling $p(y \mid x, s)$ instead of $p(y \mid x)$ as mentioned in Section 2. Intuitively, German and French have different grammars and vocabularies, so it is hard to obtain a unified semantic representation by performing the same operations on both.

  • means that the encoder needs to have enough capacity to reach the equilibrium of the minimax game. We also observe that the discriminator needs enough capacity to provide faithful gradients towards the equilibrium. Specifically, instantiating the discriminator with a feedforward neural network, with or without an attention mechanism (Bahdanau et al., 2015), does not work in our experiments.

Image Classification

Method                         Accuracy of classifying $s$   Accuracy of classifying $y$
Logistic regression            0.96                          0.78
NN + MMD (Li et al., 2014)     -                             0.82
VFAE (Louizos et al., 2016)    0.57                          0.85
Ours                           0.57                          0.89
Table 2: Results on the Extended Yale B dataset. A better representation has lower accuracy of classifying the factor $s$ and higher accuracy of classifying the label $y$.
(a) Using the original image as the representation
(b) Representation learned by our model
Figure 3: t-SNE visualizations of images in the Extended Yale B dataset. The original pictures are clustered by lighting conditions, while the representation learned by our model is clustered by the identities of individuals.

We report the results in Table 2 with two baselines (Li et al., 2014; Louizos et al., 2016) that use MMD regularization to remove lighting conditions. The advantage of factoring out lighting conditions is shown by the improved accuracy of 0.89 for classifying identities, while the best baseline achieves an accuracy of 0.85.

In terms of removing $s$, our framework can filter out the lighting conditions, since the accuracy of classifying $s$ drops from 0.96 to 0.57, as shown in Table 2. We also visualize the learned representation with t-SNE (Maaten and Hinton, 2008) in comparison to the visualization of the original pictures in Figure 3. We see that, without removing lighting conditions, the images are clustered based on the lighting conditions. After removing information about lighting conditions, images are clustered according to the identity of each person.

6 Related Work

As a special case of our problem where $s$ takes two values, domain adaptation has attracted a large amount of research interest. Domain adaptation aims to learn domain-invariant representations that are transferable to other domains. For example, in image classification, adversarial training has been shown to be able to learn an invariant representation across domains (Ganin and Lempitsky, 2015; Ganin et al., 2016; Bousmalis et al., 2016; Tzeng et al., 2017), enabling classifiers trained on the source domain to be applicable to the target domain. Moment discrepancy regularizations can also effectively remove domain-specific information (Zellinger et al., 2017; Bousmalis et al., 2016) for the same purpose. By learning language-invariant representations, classifiers trained on the source language can be applied to the target language (Chen et al., 2016b; Xu and Yang, 2017).

Works targeting the development of fair, bias-free classifiers also aim to learn representations invariant to “nuisance variables” that could induce bias, and hence make the predictions fair, as data-driven models trained on historical data easily inherit the bias exhibited in the data. Zemel et al. (2013) propose to regularize the distance between representation distributions for data with different nuisance variables to enforce fairness. The Variational Fair Autoencoder (Louizos et al., 2016) targets the problem with a Variational Autoencoder (Kingma and Welling, 2014; Rezende et al., 2014) approach with maximum mean discrepancy regularization.

Our work is also related to learning disentangled representations, where the aim is to separate different influencing factors of the input data into different parts of the representation. Ideally, each part of the learned representation is marginally independent of the others. An early work by Tenenbaum and Freeman (1997) proposes a bilinear model to learn a representation with style and content disentangled. From an information-theoretic perspective, Chen et al. (2016a) augment standard generative adversarial networks with an inference network, whose objective is to infer the part of the latent code that leads to the generated sample. This way, the information carried by the chosen part of the latent code is retained in the generated sample, leading to a disentangled representation.

As we have discussed in Section 1, these methods bear the same drawback: the cost used to regularize the representation is pairwise, which does not scale well when the attribute can take a large number of values. Louppe et al. (2016) propose an adversarial training framework to learn representations independent of a categorical or continuous variable. A basic assumption in their theoretical analysis is that the attribute is irrelevant to the prediction, which limits its applicability to analyzing fair classification.

7 Conclusion

In sum, we propose a generic framework to learn representations invariant to a specified factor or trait. We cast the representation learning problem as an adversarial game among an encoder, a discriminator, and a predictor. We theoretically analyze the optimal equilibrium of the minimax game and evaluate the performance of our framework on three tasks from different domains empirically. We show that an invariant representation is learned, resulting in better generalization and improvements on the three tasks.


We thank Shi Feng, Di Wang and Zhilin Yang for insightful discussions. This research was supported in part by DARPA grant FA8750-12-2-0342 funded under the DEFT program.


Appendix A Supplementary Material: Proofs

The proof for Claim 1:


Given a fixed encoder $E$, the optimal discriminator outputs $q_D^*(s \mid h) = \tilde{q}(s \mid h)$. The optimal predictor corresponds to $q_M^*(y \mid h) = \tilde{q}(y \mid h)$.


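We first derive the optimal discriminator. With a fixed encoder, and since the positive factor $\gamma$ does not change the maximizer, we have the following optimization problem:

$$\max_{q_D} \; \mathbb{E}_{h, s \sim \tilde{q}(h, s)} \big[ \log q_D(s \mid h) \big] \quad \text{s.t.} \quad \sum_s q_D(s \mid h) = 1, \;\; \forall h$$

Then $L(q_D) = \mathbb{E}_{h, s \sim \tilde{q}(h, s)} \big[ \log q_D(s \mid h) \big] - \int_h \lambda(h) \Big( \sum_s q_D(s \mid h) - 1 \Big) dh$ is the Lagrangian of the above optimization problem, where $\lambda(h)$ are the dual variables introduced for the equality constraints.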

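The optimal $q_D^*$ satisfies the following equations:

$$\begin{aligned}
\frac{\partial L}{\partial q_D(s \mid h)}\bigg|_{q_D = q_D^*} &= \frac{\tilde{q}(h, s)}{q_D^*(s \mid h)} - \lambda(h) = 0 \\
\tilde{q}(h, s) &= \lambda(h)\, q_D^*(s \mid h)
\end{aligned} \qquad (4)$$
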
Summing w.r.t. $s$ on both sides of the last line of Eqn. (4) and using the fact that $\sum_s q_D^*(s \mid h) = 1$, we get

$$\lambda(h) = \sum_s \tilde{q}(h, s) = \tilde{q}(h) \qquad (5)$$

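Substituting Eqn. (5) back into Eqn. (4), we can prove that the optimal discriminator is

$$q_D^*(s \mid h) = \frac{\tilde{q}(h, s)}{\tilde{q}(h)} = \tilde{q}(s \mid h).$$

Similarly, taking the derivative w.r.t. $q_M$ and setting it to 0, we can prove that $q_M^*(y \mid h) = \tilde{q}(y \mid h)$. ∎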