Forced to Learn: Discovering Disentangled Representations Without Exhaustive Labels

05/01/2017 ∙ by Alexey Romanov, et al.

Learning a better representation with neural networks is a challenging problem, which has been tackled extensively from different perspectives in the past few years. In this work, we focus on learning a representation that can be used for a clustering task and introduce two novel loss components that substantially improve the quality of the produced clusters, are simple to apply to an arbitrary model and cost function, and do not require a complicated training procedure. We evaluate them on the two most common types of models, Recurrent Neural Networks and Convolutional Neural Networks, showing that the approach we propose consistently improves the quality of KMeans clustering in terms of the Adjusted Mutual Information score and outperforms previously proposed methods.




1 Introduction

In the past few years, a substantial amount of work has been dedicated to learning a better representation of the input data that can be used either in downstream tasks, such as KMeans clustering, or to improve the generalizability or performance of the model. In general, these works can be divided into two categories:

  1. approaches that require a complicated training procedure;

  2. approaches that introduce a new loss component that can be easily applied to an arbitrary cost function.

For example, the approaches by Liao et al. (2016) and Xie et al. (2016) can be assigned to the first category, as they propose to iteratively refine the clusters during training. In contrast, the approaches by Cogswell et al. (2015) and Cheung et al. (2014) introduce new loss components that can be added to the cost function while training the model with a standard gradient descent algorithm. Our work belongs to the second category and focuses on the challenging problem of learning disentangled representations while having access to labels that do not fully reflect the underlying partitioning of the data, but still separate it into distinguishable groups.

For example, consider the case of predicting in-hospital mortality using multivariate physiological time series. This is a binary classification problem that can be solved using a Recurrent Neural Network model such as the one depicted in Figure 1. During a regular training procedure with a sigmoid cross-entropy loss, the model tends to learn weights that lead to a strong activation of one of the neurons in the penultimate layer for the instances that belong to the positive class and a strong activation of another neuron for the instances that belong to the negative class, whereas all other neurons tend to be inactive for both classes (see Figure 2(a)). However, we would like to separate patients into more than two groups by applying a clustering algorithm to the learned representations of the patients. Thus, we need the model to learn a disentangled representation that can differentiate not only between patients with different outcomes, but also between patients with the same outcome, using latent characteristics of the time series (see Figure 2(b)).

Figure 1: An RNN model with two fully-connected layers for binary classification of time series
(a) Without the proposed loss component
(b) With the proposed loss component
Figure 2: Number of samples for which each neuron on the axis was the most active in a binary classification task on the MNIST strokes sequences dataset. Classes 0-4 have the label 0, and classes 5-9 have the label 1. See Section 4.1 for details.

In order to force the network to learn such disentangled representations, we propose two novel loss components that can be applied to an arbitrary cost function. Although they can be used in any type of model, including autoencoders, this paper focuses on the task of learning disentangled representations in a binary classification setting.

2 Related Work

In the past few years we have witnessed an astonishing success of neural networks. Starting with ILSVRC in 2012 (Russakovsky et al., 2015), where a deep convolutional neural network model won the challenge by a wide margin (Krizhevsky et al., 2012), neural networks have achieved remarkable success in nearly every classification task. For a long time, it was unclear why deep neural networks, and convolutional neural networks (CNNs) in particular, work so well and what is happening inside this “black box”. The work of Zeiler & Fergus (2014) showed that it is possible to visualize which features and representations a network learns during training, which answered this question for CNNs, but not for Recurrent Neural Networks (RNNs). More recently, a few works have sought to perform a similar analysis for RNNs. In particular, Karpathy et al. (2015) showed what is happening inside an RNN cell during inference and the areas of responsibility of the neurons inside the cell. After that, Li et al. (2016) plotted the representations learned by an RNN trained on sentiment classification and showed that the network is able to learn local compositionality, embedding negated expressions (such as “not good”, “not nice”) into the space near words with a negative polarity (such as “bad”).

More relevant to the goal of this paper, Cheung et al. (2014) proposed a cross-covariance penalty (XCov) to force the network to produce representations with disentangled factors. The proposed penalty is, essentially, the cross-covariance between the predicted labels and the activations of the samples in a batch. Their experiments showed that the network can produce a representation with components that are responsible for different characteristics of the input data. For example, in the case of the MNIST dataset, there was a class-invariant factor that was responsible for the style of the digit, and in the case of the Toronto Faces Dataset (Susskind et al., 2010), there was a factor responsible for the subject’s identity. Similarly, but with a different goal in mind, Cogswell et al. (2015) proposed a new regularizer (DeCov) which minimizes the cross-covariance of hidden activations, leading to non-redundant representations and, consequently, less overfitting and better generalization. The DeCov loss minimizes the Frobenius norm of the covariance matrix between all pairs of activations in the given layer. The authors’ experiments showed that the proposed loss significantly reduced overfitting and led to better classification performance on a variety of datasets.
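To make the DeCov idea above concrete, here is a minimal NumPy sketch of a DeCov-style penalty. It assumes the commonly cited form of the regularizer, namely half the squared Frobenius norm of the off-diagonal part of the batch covariance matrix; the function name `decov_penalty` and the 1/N covariance normalization are illustrative choices, not taken from this paper.

```python
import numpy as np

def decov_penalty(h):
    """DeCov-style penalty for a batch of activations h with shape (N, D).

    Assumed form: 0.5 * (||C||_F^2 - ||diag(C)||_2^2), where C is the
    covariance matrix of the D activations estimated over the batch.
    """
    centered = h - h.mean(axis=0, keepdims=True)
    cov = centered.T @ centered / h.shape[0]   # (D, D) batch covariance
    off_diag = cov - np.diag(np.diag(cov))     # drop the per-unit variances
    return 0.5 * float(np.sum(off_diag ** 2))

# Two uncorrelated units: nothing to penalize.
uncorrelated = np.array([[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]])
# Two perfectly correlated units: redundant, so the penalty is positive.
correlated = np.array([[1.0, 1.0], [-1.0, -1.0]])
```

The penalty is zero when the hidden units are pairwise uncorrelated and grows as units become redundant, which is the behaviour the paragraph above attributes to DeCov.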

(a) Without the proposed loss component
(b) With the proposed loss component
Figure 3: Number of samples for which each neuron on the axis was the most active in a binary classification task on the CIFAR-10 dataset. See Section 4.2 for details.

While the aforementioned approaches do not require a complicated training procedure, other works have proposed more involved algorithms in order to obtain better representations that could be used in downstream tasks, such as clustering. Liao et al. (2016) proposed a method to learn parsimonious representations. Essentially, the proposed algorithm iteratively calculates cluster centroids, which are updated at a fixed interval of iterations and used in the cost function. The authors’ experiments showed that such an algorithm leads to better generalization and higher test performance of the model in supervised learning, as well as in unsupervised and even zero-shot learning. Similarly, Xie et al. (2016) proposed an iterative algorithm that first calculates soft cluster assignments, then updates the weights of the network and the cluster centroids. This process is repeated until convergence. In contrast to Liao et al. (2016), the authors specifically focused on the task of learning better representations for clustering, and showed that the proposed algorithm gives a significant improvement in clustering accuracy.

In contrast to the last two methods, our proposed loss components do not require such complicated training procedures and can be conveniently used in any cost function in a straightforward manner.

3 The proposed method

Inspired by the work of Cheung et al. (2014) and Cogswell et al. (2015), we propose two novel loss components which, despite their simplicity, significantly improve the quality of clustering over the representations produced by the model. The first loss component works on a single layer and does not affect the other layers in the network, which may be a desirable behaviour in some cases. The second loss component affects the entire network behind the target layer and forces it to produce disentangled representations in more complex and deep networks, in which the first loss may not give the desired improvements.

3.1 Single layer loss

Consider the model in Figure 1. The last fully-connected layer has an output size of 1 and produces the binary classification decision. The output of the preceding fully-connected layer is used to perform KMeans clustering. Recall from the example in the introduction that we want to force the model to produce divergent representations for samples that belong to the same class but are in fact substantively different from each other. One way to do this is to force the rows of the weight matrix of that preceding layer to be different from each other, leading to different patterns of activations in its output.

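Formally, it can be expressed as follows:

$$\mathcal{L}_{single} = \frac{1}{N^2} \sum_{i=1}^{N} \sum_{j=1}^{N} f_l(d_i, d_j) \tag{1}$$

where $N$ is the number of rows and $d_k$ are the normalized weights of row $k$ of the weight matrix $W$ of the given layer:

$$d_k = \mathrm{softmax}(W_k) \tag{2}$$

and $f_l(d_i, d_j)$ is a component of the loss between rows $i$ and $j$:

$$f_l(d_i, d_j) = \max\left(0,\; m - D_{KL}(d_i \,\|\, d_j)\right) \tag{3}$$

Here $m$ is a hyperparameter that defines the desired margin of the loss component, and $D_{KL}(d_i \,\|\, d_j)$ is the Kullback-Leibler divergence between the probability distributions $d_i$ and $d_j$.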
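The single layer loss can be sketched directly with two nested loops. The NumPy sketch below is an illustration under stated assumptions rather than the authors' implementation: it assumes each row of the weight matrix is normalized with a softmax, that the pair term is a hinge of the form max(0, m - KL), and that the result is averaged over all N² ordered pairs of rows. All function names are hypothetical.

```python
import numpy as np

def softmax(v):
    """Numerically stable softmax of a 1-D array."""
    e = np.exp(v - np.max(v))
    return e / e.sum()

def kl_divergence(p, q, eps=1e-12):
    """D_KL(p || q) for two discrete distributions; eps guards against log(0)."""
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

def pair_term(d_i, d_j, m):
    """Assumed hinge form: zero once the two rows differ by at least the margin m."""
    return max(0.0, m - kl_divergence(d_i, d_j))

def single_layer_loss(W, m):
    """Average pair term over all ordered pairs of rows of the weight matrix W."""
    d = [softmax(row) for row in W]
    n = len(d)
    # Pairs with i == j have zero divergence and each add a constant m.
    total = sum(pair_term(d[i], d[j], m) for i in range(n) for j in range(n))
    return total / n ** 2

W_same = np.ones((3, 4))                      # identical rows: maximal penalty
W_diff = np.array([[8.0, 0.0, 0.0, 0.0],      # clearly different rows
                   [0.0, 8.0, 0.0, 0.0],
                   [0.0, 0.0, 8.0, 0.0]])
```

With identical rows every pair is penalized by the full margin, while for well-separated rows only the constant diagonal terms remain, so the loss decreases as the rows diverge.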

Figure 4: A CNN model used in the CIFAR-10 experiments

3.2 Multilayer loss

Note that the $\mathcal{L}_{single}$ loss component affects only the weights of the specific layer, as it operates not on the outputs of the layer but directly on its weights, similar to, for example, $\ell_2$ regularization. Therefore, this loss component may help to learn a better representation only if the input to the target layer still contains information about the latent characteristics of the input data. This might be the case in simple shallow networks, but in very deep, complex networks the input data is non-linearly transformed so many times that only the information needed for binary classification is left, and all the remaining latent characteristics of the input data are lost as unimportant for binary classification (see Figure 3(a)). Indeed, as we can see from the experiments in Section 4, the loss component described above substantially improves the quality of the clustering in a simple baseline case. However, for a more complex model, this improvement is much less pronounced. Therefore, we also propose a loss component that can influence not only one specific layer, but all layers before it, in order to force the network to produce a better representation.

Recall again that we want to force the model to produce disentangled representations of the input data. Namely, these representations should be sufficiently different from each other even if two samples have the same label. We propose the following loss component in order to produce such properties:

$$\mathcal{L}_{multi} = \frac{1}{N_s} \sum_{i=1}^{N} \sum_{j=1}^{N} \begin{cases} f_l(d^s_i, d^s_j), & y_i = y_j \\ 0, & y_i \neq y_j \end{cases} \tag{4}$$

where $d^s_k$ is a normalized output of the target layer $h$ for the sample $k$:

$$d^s_k = \mathrm{softmax}(h^s_k) \tag{5}$$

$y_k$ is its ground truth label, $N$ is the number of samples in the batch, $N_s$ is the number of samples that have the same label, and $f_l$ is the function defined in Equation 3. Note that this loss component works on the outputs of the target layer, and therefore it affects the whole network behind the layer on which it is applied, overcoming the local properties of the $\mathcal{L}_{single}$ loss.
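The label-aware multilayer loss can be sketched in the same way. The NumPy sketch below is a hedged illustration: it assumes a softmax normalization of the target-layer activations, a hinge pair term max(0, m - KL), and it interprets the normalizer as the number of same-label pairs in the batch. These choices and the function names are assumptions made for the example.

```python
import numpy as np

def softmax_rows(h):
    """Row-wise, numerically stable softmax of a (N, D) array."""
    e = np.exp(h - h.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def multilayer_loss(h, y, m, eps=1e-12):
    """h: (N, D) target-layer activations, y: N labels, m: margin."""
    d = softmax_rows(np.asarray(h, dtype=float))
    log_d = np.log(d + eps)
    n = len(y)
    total, n_same = 0.0, 0
    for i in range(n):
        for j in range(n):
            if y[i] != y[j]:
                continue  # pairs with different labels contribute nothing
            kl = float(np.sum(d[i] * (log_d[i] - log_d[j])))
            total += max(0.0, m - kl)
            n_same += 1   # assumed normalizer: number of same-label pairs
    return total / n_same

h = np.array([[5.0, 0.0, 0.0],
              [5.0, 0.0, 0.0],
              [0.0, 5.0, 0.0],
              [0.0, 0.0, 5.0]])
loss_grouped = multilayer_loss(h, [0, 0, 1, 1], m=1.0)  # identical rows share a label
loss_mixed = multilayer_loss(h, [0, 1, 0, 1], m=1.0)    # same-label rows already differ
```

In this toy batch the loss is higher when two identical activation patterns share a label, which is exactly the situation the loss component is meant to discourage.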

3.3 Unsupervised learning

Although our main focus in the presented experiments is on a binary classification task, both of our proposed loss components can be used in unsupervised learning as well. The loss component $\mathcal{L}_{single}$ does not require any labels, so it can be used without modifications. The loss component $\mathcal{L}_{multi}$ can be applied to unlabeled data by simply taking the summations without considering the labels of the samples:

$$\mathcal{L}'_{multi} = \frac{1}{N^2} \sum_{i=1}^{N} \sum_{j=1}^{N} f_l(d^s_i, d^s_j) \tag{6}$$

For example, since autoencoder models are a common choice for learning representations to use in a downstream task, the proposed loss components can be easily applied to their cost function as follows:

$$\mathcal{L}_{ae} = (1 - \alpha)\,\frac{1}{N} \sum_{i=1}^{N} \lVert X_i - \hat{X}_i \rVert^2 + \alpha\,\mathcal{L}'_{multi} \tag{7}$$

where the first part is a standard reconstruction cost for an autoencoder, the second is the proposed loss component, and $\alpha$ is a hyperparameter reflecting how much importance is given to it.
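As a rough illustration of the unsupervised use described above, the NumPy sketch below combines a squared-error reconstruction cost with a label-free pairwise margin term computed on the autoencoder code. The convex weighting with `alpha`, the softmax normalization, the hinge form max(0, m - KL), and all function names are assumptions made for this example.

```python
import numpy as np

def pairwise_margin_loss(h, m, eps=1e-12):
    """Label-free variant: average hinge term over all ordered pairs in the batch."""
    e = np.exp(h - h.max(axis=1, keepdims=True))
    d = e / e.sum(axis=1, keepdims=True)          # assumed softmax normalization
    log_d = np.log(d + eps)
    # kl[i, j] = sum_k d[i, k] * (log d[i, k] - log d[j, k])
    kl = np.sum(d * log_d, axis=1)[:, None] - d @ log_d.T
    return float(np.mean(np.maximum(0.0, m - kl)))

def autoencoder_cost(x, x_hat, code, m, alpha):
    """Assumed weighting: (1 - alpha) * reconstruction + alpha * margin term."""
    reconstruction = float(np.mean(np.sum((x - x_hat) ** 2, axis=1)))
    return (1.0 - alpha) * reconstruction + alpha * pairwise_margin_loss(code, m)

x = np.array([[1.0, 0.0], [0.0, 1.0]])   # toy inputs
x_hat = np.zeros_like(x)                 # a poor reconstruction
code = np.zeros((2, 3))                  # identical codes for both samples
```

Setting `alpha` to zero recovers the plain reconstruction cost, and larger values put more weight on pushing the codes of different samples apart.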

| Model | AMI | NMI |
| --- | --- | --- |
| Baseline | 0.467 | 0.477 |
| Baseline + DeCov | 0.287 | 0.313 |
| Baseline + XCov | 0.525 | 0.547 |
| Baseline + $\mathcal{L}_{single}$ | 0.544 | 0.553 |
| Baseline + $\mathcal{L}_{multi}$ | 0.502 | 0.523 |

Table 1: Adjusted Mutual Information (AMI) and Normalized Mutual Information (NMI) scores for the MNIST strokes sequences experiments

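| Model | Validation set AMI | Validation set NMI | Test set AMI | Test set NMI |
| --- | --- | --- | --- | --- |
| Baseline | 0.198 | 0.204 | 0.194 | 0.199 |
| Baseline + DeCov | 0.175 | 0.191 | 0.176 | 0.191 |
| Baseline + XCov | 0.320 | 0.327 | 0.321 | 0.326 |
| Baseline + $\mathcal{L}_{single}$ | 0.238 | 0.245 | 0.239 | 0.246 |
| Baseline + $\mathcal{L}_{multi}$ | 0.376 | 0.384 | 0.376 | 0.385 |

Table 2: Adjusted Mutual Information (AMI) and Normalized Mutual Information (NMI) scores for the CIFAR-10 experiments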

3.4 Hyperparameter $m$

One important choice to be made while using the proposed loss components is the value of the margin hyperparameter $m$. A larger value of $m$ corresponds to a larger margin between the rows of the weight matrix in the case of $\mathcal{L}_{single}$ and a larger margin between the activations of the target layer in the case of $\mathcal{L}_{multi}$. The smaller the value of $m$, the less influence the proposed loss components have.

In our experiments, we found that the proposed loss component $\mathcal{L}_{single}$ is relatively stable with respect to the choice of $m$, and generally performs better with larger values (in the range 5-10). In the case of the loss component $\mathcal{L}_{multi}$, we found that even a small value of the margin (0.1-1) disentangles the learned representations better and consequently leads to substantial improvements in the AMI score.

In all of the reported experiments, we found that the proposed loss component with a reasonably chosen $m$ does not hurt the model’s performance in the classification task.
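The effect of the margin can be seen on a single pair of distributions. The short sketch below assumes the pair term is a hinge of the form max(0, m - KL); the distributions `p` and `q` are made up for illustration.

```python
import math

def kl_divergence(p, q):
    """D_KL(p || q) for two discrete distributions with strictly positive entries."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

def pair_term(p, q, m):
    """Assumed hinge form: the pair is penalized only while KL(p || q) < m."""
    return max(0.0, m - kl_divergence(p, q))

p = [0.9, 0.1]
q = [0.1, 0.9]                       # KL(p || q) = 0.8 * ln(9), roughly 1.76
margins = [0.1, 1.0, 5.0, 10.0]
terms = [pair_term(p, q, m) for m in margins]
```

For this pair, margins of 0.1 and 1 are already satisfied and produce no penalty, whereas margins of 5 and 10 keep pushing the two distributions further apart.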

(a) Without the proposed loss component, colored by binary labels
(b) Without the proposed loss component, colored by classes
(c) With the proposed loss component, colored by binary labels
(d) With the proposed loss component, colored by classes
Figure 5: PCA visualizations of the learned representations on the MNIST strokes sequences dataset. See Section 6 for details.

4 Experiments

We performed experiments on the MNIST strokes sequences dataset (de Jong, 2016) to validate our hypotheses in the case of an RNN model. This dataset contains pen strokes, automatically generated from the original MNIST dataset (LeCun et al., 1998). Although the generated sequences do not always reflect the choices a human would make in order to write a digit, the strokes are consistent across the dataset.

To validate our hypothesis on a more complex dataset, we also experimented with the CIFAR-10 dataset (Krizhevsky & Hinton, 2009) and a more complex CNN model based on the VGG-16 architecture (Simonyan & Zisserman, 2014).

Our experiments therefore cover two types of neural networks most commonly used in modern research: Recurrent Neural Networks and Convolutional Neural Networks, which are used as baseline models in our experiments.

We implemented the models used in all experiments with TensorFlow (Abadi et al., 2016) and used the Adam optimizer (Kingma & Ba, 2014) to train them. Hyperparameters for the baseline models were chosen based on the Adjusted Mutual Information (AMI) score on the validation set and then used in the subsequent models. Hence, the performance of the baseline system is close to an empirical maximum that these models are able to achieve.

4.1 MNIST strokes sequences experiments

For this experiment, we split the examples into two groups: samples belonging to the classes from 0 to 4 were assigned to the first group, and samples belonging to the classes from 5 to 9 were assigned to the second group. The model is trained to predict the group of a given sample and does not have any access to the underlying classes. This experiment is a simplified version of the example we discussed in the introduction where we wanted to cluster the patients into meaningful groups, while having only binary mortality outcomes, rather than the more fine-grained labels.
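The label setup described above can be sketched in a few lines of NumPy. The helper name `to_binary_group` and the toy label array are hypothetical; the point is only that the model sees the binary group, while the original classes are held back for evaluating the clustering.

```python
import numpy as np

def to_binary_group(class_labels, first_group=(0, 1, 2, 3, 4)):
    """Map fine-grained class ids to group 0 (in first_group) or group 1 (otherwise)."""
    class_labels = np.asarray(class_labels)
    return np.where(np.isin(class_labels, first_group), 0, 1)

classes = np.array([0, 3, 4, 5, 9, 7, 1])   # original labels, kept for evaluation only
groups = to_binary_group(classes)           # binary targets used for training
```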

We used the model depicted in Figure 1 for this experiment. After the models were trained on the binary classification task, we used the output of the penultimate layer to perform KMeans clustering and evaluated the quality of the produced clustering using the original class labels as ground truth assignments.

We compared our loss components $\mathcal{L}_{single}$ and $\mathcal{L}_{multi}$ with the DeCov regularizer (Cogswell et al., 2015) and the XCov penalty (Cheung et al., 2014), as these losses use similar ideas, even though they do not directly target the task of improving the quality of clustering. We did not compare with the work of Liao et al. (2016) and Xie et al. (2016), as they belong to the first category of methods, which require a more complicated training procedure, whereas the loss components proposed here are simple to apply to an arbitrary cost function and do not require any changes to the training procedure.

We report the average of the Adjusted Mutual Information (AMI) and Normalized Mutual Information (NMI) scores (Vinh et al., 2010) across three runs in Table 1.

4.2 CIFAR-10 experiments

As in the MNIST strokes sequences experiments, we split the examples into two groups: samples belonging to the classes “airplane”, “automobile”, “bird”, “cat”, and “deer” were assigned to the first group, and samples belonging to the classes “dog”, “frog”, “horse”, “ship”, and “truck” were assigned to the second group. Note that this assignment is quite arbitrary, as it simply reflects the order of the class labels in the dataset (namely, the labels 0-4 for the first group and the labels 5-9 for the second group). Both groups contain rather different types of objects, both natural and human-made.

For the experiments on the CIFAR-10 dataset, we used a CNN model based on the VGG-16 architecture, depicted in Figure 4. We discarded the fully connected and convolutional layers that follow the pool3 layer, as they are likely too large for this dataset. Instead, we appended three convolutional layers to the output of the pool3 layer, with 256, 128, and 8 filters, respectively. The first two layers use 3x3 convolutions, and the last layer uses 1x1 convolutions. After that, we pass the output through a fully-connected layer of size 15, which produces the representations used in clustering, and a fully-connected layer of size 1 with the sigmoid activation function, which produces the binary classification decision.

As in Section 4.1, we compare the proposed losses with DeCov and XCov and report the AMI and NMI scores in Table 2.

5 Implementation details

Although the proposed loss components can be implemented directly using two nested for loops, such an implementation would not be computationally efficient, as it leads to a large computational graph that operates on separate vectors without taking full advantage of highly optimized parallel matrix computations on the GPU. Therefore, it is desirable to have an efficient implementation that takes full advantage of modern GPUs. We have developed such an implementation, which significantly accelerates the computation of the loss component in return for higher memory consumption, by creating two matrices that contain all combinations of $d_i$ and $d_j$ from the summations in Equation 1 and performing the operations to calculate the loss on them. We have made our TensorFlow (Abadi et al., 2016) implementation publicly available on GitHub, alongside the aforementioned models from Section 4.1 and Section 4.2.

It is worth noting that since the loss component $\mathcal{L}_{single}$ operates directly on the weights of the target layer, its computational complexity does not depend on the size of the batch. Instead, it depends on the size of that layer. In contrast, $\mathcal{L}_{multi}$ operates on the activations of the target layer for all samples in the batch, and its computational complexity depends on the number of samples in the batch. In practice, using the implementation described above, we were able to train models with a batch size of 512 and higher without exhausting the GPU’s memory.
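The trade-off described in this section can be illustrated with a small NumPy sketch that compares a nested-loop computation with a vectorized one built from two matrices holding every ordered pair of rows. This is not the authors' TensorFlow code; the hinge form max(0, m - KL), the averaging over all pairs, and the function names are assumptions made for the example.

```python
import numpy as np

def loss_nested_loops(d, m, eps=1e-12):
    """Reference version: one pair of rows at a time."""
    n = d.shape[0]
    total = 0.0
    for i in range(n):
        for j in range(n):
            kl = np.sum(d[i] * (np.log(d[i] + eps) - np.log(d[j] + eps)))
            total += max(0.0, m - kl)
    return total / n ** 2

def loss_vectorized(d, m, eps=1e-12):
    """Same value, computed on two (N*N, D) matrices of all ordered pairs."""
    n = d.shape[0]
    left = np.repeat(d, n, axis=0)   # row i repeated n times: d_i for pair (i, j)
    right = np.tile(d, (n, 1))       # rows 0..n-1 cycled n times: d_j for pair (i, j)
    kl = np.sum(left * (np.log(left + eps) - np.log(right + eps)), axis=1)
    return float(np.mean(np.maximum(0.0, m - kl)))

rng = np.random.default_rng(0)
logits = rng.normal(size=(6, 5))
d = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # rows sum to 1
```

The vectorized version materializes N² rows at once, which mirrors the higher memory consumption mentioned above, in exchange for replacing the Python-level loops with a handful of array operations.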

6 Discussion

As we can see from Figure 2 and Figure 3, during the binary classification task on both datasets, the models trained without the proposed loss component tend to learn representations that are specific to the target binary label, even though the samples within one group come from different classes. The model learns to use mostly just two neurons to discriminate between the target groups and hardly uses the rest of the neurons in the layer. We observe this behaviour across different types of models and datasets: an RNN model applied to a time series dataset and a CNN model applied to an image classification dataset behave in exactly the same way. Both proposed loss components, $\mathcal{L}_{single}$ and $\mathcal{L}_{multi}$, force the model to produce disentangled representations, and we can see how this changes the patterns of activations in the target layer. It is easy to see in Figure 2(b) that the patterns of activations learned by the network roughly correspond to the underlying classes, despite the fact that the network did not have access to them during training. This pattern is not as easy to see in the case of the CIFAR-10 dataset (see Figure 3(b)), but we can observe that the proposed loss component nevertheless forced the network to activate different neurons for different classes, leading to a better AMI score on the clustering task.

In order to further investigate the representations learned by the model, we visualized the representations of samples from the MNIST strokes sequences dataset in Figure 5 using TensorBoard. Figure 5(a) and Figure 5(b) in the top row depict the representations learned by the baseline model, colored according to the binary label and the underlying classes, respectively. Figure 5(c) and Figure 5(d) in the bottom row depict the representations of the same samples, learned by the model with the proposed loss component, colored in the same way. It is easy to see that the proposed loss component indeed forced the model to learn disentangled representations of the input data. Note how the baseline model learned dense clusters of objects, with samples from the same group (but different classes) compactly packed in the same area. In contrast, the model with the proposed loss component learned considerably better representations, which disentangle samples belonging to different classes and place them more uniformly in the space.

Figure 6: Number of clusters and the corresponding AMI score on the CIFAR-10 dataset

In the real world, the number of clusters is rarely known beforehand. To systematically examine the stability of the proposed loss component, we plotted the Adjusted Mutual Information scores for the baseline methods and the proposed loss component with respect to the number of clusters in Figure 6, using the CIFAR-10 dataset. As can be seen from Figure 6, our loss component consistently outperforms the previously proposed methods regardless of the number of clusters.

7 Conclusion

In this paper, we proposed two novel loss components that substantially improve the quality of KMeans clustering using the representations of the input data learned by the model. We performed a comprehensive set of experiments using two important model types (RNNs and CNNs) and different datasets, demonstrating that the proposed loss components consistently increase the Adjusted Mutual Information score by a significant margin and outperform previously proposed methods. In addition, we analyzed the representations learned by the network by visualizing the activation patterns and relative positions of the samples in the learned space, and the visualizations show that the proposed loss components indeed force the network to learn disentangled representations.