
Group-disentangled Representation Learning with Weakly-Supervised Regularization

by Linh Tran, et al.
Autodesk Inc.

Learning interpretable and human-controllable representations that uncover factors of variation in data remains an ongoing key challenge in representation learning. We investigate learning group-disentangled representations for groups of factors with weak supervision. Existing techniques to address this challenge merely constrain the approximate posterior by averaging over observations of a shared group. As a result, observations with a common set of variations are encoded to distinct latent representations, reducing their capacity to disentangle and generalize to downstream tasks. In contrast to previous works, we propose GroupVAE, a simple yet effective Kullback-Leibler (KL) divergence-based regularization across shared latent representations to enforce consistent and disentangled representations. We conduct a thorough evaluation and demonstrate that GroupVAE significantly improves group disentanglement. Further, we demonstrate that learning group-disentangled representations improves performance on downstream tasks, including fair classification and 3D shape-related tasks such as reconstruction, classification, and transfer learning, and is competitive with supervised methods.



1 Introduction

Decomposing data into disjoint independent factors of variation, i.e., learning disentangled representations, is essential for interpretable and controllable machine learning (bengio2013representation). Recent works have shown that disentangled representations are useful for abstract reasoning (SteenkisteLSB19), fairness (locatello2019fairness; creager2019flexibly), reinforcement learning (HigginsPRMBPBBL17), and general predictive performance (locatello2019challenging). While there is no consensus on the definition of disentanglement, existing works define it as learning to separate all factors of variation in the data (bengio2013representation). According to this definition, altering a single underlying factor of variation should only affect a single factor in the learned representation. However, works on learning disentangled representations (higgins2016beta; ChenLGD18; locatello2019challenging) have shown that this setting comes with a trade-off between the precision of the representation and the fidelity of the samples. Therefore, learning precise representations for fine-grained factors, i.e., for each factor of variation, may be neither practical nor desirable. We deviate from this stringent assumption and learn group-disentangled representations, in which a group may include several factors of variation that can covary. For instance, groups of interest may be content, style, or background. As a result, a change in one component may affect other variables in the same group but not variables in other groups.

We present GroupVAE, a VAE-based framework that leverages weak supervision to learn group-disentangled representations. In particular, we use paired observations that always share a group of factors. Existing group-disentanglement approaches (bouchacourt2018multi; hosoya2019group) enforce disentangled group representations by using an average or a product of approximate group posteriors. However, since the group representations depend on the observations used for the average or product, observations belonging to the same group may not be encoded to the same latent representation. We address this inconsistency with a simple but effective regularization based on the KL divergence: we maximize the ELBO of the VAE while minimizing the KL divergence between the latent variables that correspond to the group shared by the paired observations.

In summary, we make the following contributions:

  1. We propose a way of learning disentangled representations from paired observations that employs KL regularization between the corresponding groups of latent variables.

  2. We propose group-MIG, a mutual information-based metric for evaluating the effectiveness of group disentanglement methods.

  3. Through extensive evaluation, we show GroupVAE's effectiveness on a wide range of applications. Our evaluation shows significant improvements for group disentanglement, fair facial attribute classification, and 3D shape-related tasks, including reconstruction, classification, and transfer learning.

2 Background & Notation

(a) Model, (b) Generative, (c) Inference
Figure 1: GroupVAE's architecture and graphical model. We visualize the complete model, including model weights, in (a), and show the (b) generative and (c) inference parts as graphical models. The model visualization shows two paired inputs, one pair sharing "style" and the other sharing "content". The KL minimization depends on the group that is shared; for instance, for the input pair sharing style, the GroupVAE objective only minimizes the KL between the style latent variables. Shaded nodes denote observed quantities in (b) and (c), and unshaded nodes represent unobserved (latent) variables. Dotted arrows represent minimizing the KL divergence between variables during inference.

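Consider observations $x$ sampled i.i.d. from a distribution $p_{\text{data}}(x)$ and latent variables $z$. A VAE learns the joint distribution

$$ p_\theta(x, z) = p_\theta(x \mid z)\, p(z), $$

where $p_\theta(x \mid z)$ is the likelihood function of observations $x$ given $z$, $\theta$ are the model parameters of $p_\theta$, and $p(z)$ is the prior of the latent variable $z$. VAEs are trained to maximize the evidence lower bound (ELBO) on the log-likelihood $\log p_\theta(x)$. This objective, averaged over the empirical distribution, is given as

$$ \mathcal{L}_{\text{ELBO}} = \mathbb{E}_{p_{\text{data}}(x)}\Big[ \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big) \Big], \tag{1} $$

where $q_\phi(z \mid x)$ denotes the learned approximate posterior, $\phi$ the variational parameters of $q_\phi$, and KL denotes the Kullback-Leibler (KL) divergence. VAEs (KingmaW13) are frequently used for learning disentangled representations and serve as the basis of our approach.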
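To make (1) concrete, here is a minimal single-sample sketch in Python, assuming a diagonal Gaussian encoder that outputs `mu` and `logvar` and a Bernoulli decoder that outputs per-pixel `logits` for one sampled $z$. The networks themselves are not shown, and all function names are hypothetical.

```python
import math


def softplus(t: float) -> float:
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def bernoulli_log_likelihood(x, logits):
    """log p(x | z) for binary x under a Bernoulli decoder parameterized by logits."""
    # log sigmoid(l) = -softplus(-l) and log(1 - sigmoid(l)) = -softplus(l)
    return sum(-xi * softplus(-li) - (1 - xi) * softplus(li) for xi, li in zip(x, logits))


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I))."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, logvar))


def elbo(x, logits, mu, logvar):
    """Single-sample ELBO estimate: reconstruction term minus KL to the prior."""
    return bernoulli_log_likelihood(x, logits) - kl_to_standard_normal(mu, logvar)


# Two binary pixels, uninformative decoder logits, posterior equal to the prior
print(elbo([1, 0], [0.0, 0.0], [0.0], [0.0]))  # equals 2 * log(0.5)
```

Training maximizes the average of this quantity over minibatches, which is equivalent to minimizing its negative.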

Weakly-supervised group disentanglement.

We assume that the observations and the data-generating process can be described by $G$ distinct groups: the factors of variation split into $G$ disjoint partitions of arbitrary size, so each group consists of a non-overlapping set of factors of variation. For example, images of 3D shapes (3dshapes18), samples of which are shown in Figure 1, can be described through three groups: shape (containing factors such as shape category, shape size, and shape color), background (containing factors such as floor color and wall color), and view. Without loss of generality, we define two groups, $c$ (content) and $s$ (style, independent of content), to describe the generative and inference process. In this weakly-supervised setting, we assume paired observations $(x^{(1)}, x^{(2)})$ for training. Each pair of observations shares one group, i.e., in our case either content $c$ or style $s$. During inference, the exact values of content and style are unknown; only the fact that $x^{(1)}$ and $x^{(2)}$ share a certain group is known. For each observation $x$, we define two latent variables: $z_c$ for content and $z_s$ for style. The goal of group-based disentanglement is that representations of the same group are close to each other, which ensures consistency.

3 Learning Group-Disentangled Representations

In the following, we introduce GroupVAE, a deep generative model that learns disentangled representations for each group of factors. For simplicity, we limit the formulation of GroupVAE to two groups, content and style, although GroupVAE can be applied to any number of groups. This section first describes the generative and inference models and then introduces our main contributions: the KL regularization and the inference scheme. We visualize the generative and inference models in Figures 1(b) and 1(c).

Inference and generative model.

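Our model uses paired observations in a weakly-supervised setting. We sample $x^{(1)}$ from the empirical data distribution and conditionally sample $x^{(2)}$ in an i.i.d. manner, so that $x^{(1)}$ and $x^{(2)}$ belong to the same group $g \in \{c, s\}$, i.e.,

$$ x^{(1)} \sim p_{\text{data}}(x), \qquad x^{(2)} \sim p_{\text{data}}\big(x \mid x \text{ shares group } g \text{ with } x^{(1)}\big). \tag{2} $$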
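The pairing in (2) can be sketched as follows, assuming ground-truth factor tuples are available for building pairs (as in the synthetic benchmarks) and using a hypothetical content/style split of six factor indices. `GROUPS`, `group_key`, and `sample_pair` are illustrative names, not part of any released code; the model itself only receives the pair and the identity of the shared group, never the factor values.

```python
import random

# Hypothetical split of factor indices into the two groups, e.g. for factors
# (shape_category, shape_size, shape_color, floor_color, wall_color, view).
GROUPS = {"c": (0, 1, 2), "s": (3, 4, 5)}


def group_key(factors, group):
    """Values of the factors that belong to `group`."""
    return tuple(factors[k] for k in GROUPS[group])


def sample_pair(dataset_factors, group, rng=random):
    """Return indices (i, j) such that observations i and j share `group`.

    i is drawn uniformly (x^(1) ~ p_data); j is drawn uniformly among all
    observations with the same values for the factors of `group` (it can
    equal i if no other observation matches).
    """
    i = rng.randrange(len(dataset_factors))
    key = group_key(dataset_factors[i], group)
    candidates = [j for j, f in enumerate(dataset_factors) if group_key(f, group) == key]
    return i, rng.choice(candidates)


factors = [(0, 0, 0, 1, 1, 0), (0, 0, 0, 2, 2, 1), (1, 1, 1, 1, 1, 0)]
i, j = sample_pair(factors, "c", random.Random(0))
print(factors[i][:3] == factors[j][:3])  # True
```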


Given $x$, we define two latent variables, $z_c$ as content and $z_s$ as style variables. The data is explained by the generative process

$$ p_\theta(x, z_c, z_s) = p_\theta(x \mid z_c, z_s)\, p(z_c)\, p(z_s). \tag{3} $$

Both $z_c$ and $z_s$ are assumed to be independent of each other and are sampled from a Normal distribution with zero mean and identity covariance, $p(z_c) = p(z_s) = \mathcal{N}(0, I)$. The likelihood $p_\theta(x \mid z_c, z_s)$ is a suitable likelihood function (e.g., a Bernoulli likelihood for binary values or a Gaussian likelihood for continuous values) parameterized by a deep neural network. The generative model shown in Figure 1(b) corresponds to the decoding part in Figure 1(a).

To perform inference, we approximate the true posterior $p_\theta(z_c, z_s \mid x)$ with a factorized approximate posterior that uses a neural network to amortize the variational parameters. We specify the inference model as

$$ q_\phi(z_c, z_s \mid x) = q_\phi(z_c \mid x)\, q_\phi(z_s \mid x), \tag{4} $$

where both approximate posteriors are assumed to be factorized Normal distributions with mean $\mu_\phi(x)$ and diagonal covariance $\operatorname{diag}(\sigma^2_\phi(x))$. The inference model is visualized as a graphical model in Figure 1(c) and as the encoding part in Figure 1(a). The generative and inference models visualized in Figure 1 apply to $x^{(2)}$ as well.
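Sampling from (4) is usually done with the reparameterization trick. A minimal sketch, assuming (hypothetically) that the encoder emits one flat vector of means and log-variances whose first `content_dim` entries belong to $z_c$ and the rest to $z_s$; `split_posterior` and `reparameterize` are illustrative names.

```python
import math
import random


def split_posterior(mu, logvar, content_dim):
    """Split flat encoder outputs into (content, style) posterior parameters."""
    content = (mu[:content_dim], logvar[:content_dim])
    style = (mu[content_dim:], logvar[content_dim:])
    return content, style


def reparameterize(mu, logvar, rng=random):
    """Draw z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(logvar / 2)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


(mu_c, lv_c), (mu_s, lv_s) = split_posterior([0.1, -0.2, 0.3], [0.0, 0.0, -1.0], content_dim=2)
z_c = reparameterize(mu_c, lv_c, random.Random(0))
z_s = reparameterize(mu_s, lv_s, random.Random(1))
print(len(z_c), len(z_s))  # 2 1
```

Because the noise enters additively, gradients can flow through `mu` and `logvar`; in an autodiff framework the same expression is applied to tensors.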

VAE objective for paired observation.

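Given paired observations $(x^{(1)}, x^{(2)})$, the VAE framework maximizes the ELBO

$$ \mathcal{L}_{\text{ELBO}}(x^{(1)}, x^{(2)}) = \mathbb{E}_{q_\phi(z_c^{(1)}, z_s^{(1)} \mid x^{(1)})}\big[\log p_\theta(x^{(1)} \mid z_c^{(1)}, z_s^{(1)})\big] + \mathbb{E}_{q_\phi(z_c^{(2)}, z_s^{(2)} \mid x^{(2)})}\big[\log p_\theta(x^{(2)} \mid z_c^{(2)}, z_s^{(2)})\big] - \sum_{i=1}^{2} \mathrm{KL}\big(q_\phi(z_c^{(i)} \mid x^{(i)}) \,\|\, p(z_c)\big) - \sum_{i=1}^{2} \mathrm{KL}\big(q_\phi(z_s^{(i)} \mid x^{(i)}) \,\|\, p(z_s)\big), \tag{5} $$

which consists of the reconstruction terms for the observations $x^{(1)}$ and $x^{(2)}$ (first two terms) and the KL divergences between the approximate posteriors and the priors of the latent variables $z_c$ and $z_s$ (third and fourth terms). This is a straightforward application of the original ELBO in (1) to two sets of observations, $x^{(1)}$ and $x^{(2)}$.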

KL regularization for group similarity.

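Rather than defining an average representation for groups as in (bouchacourt2018multi; hosoya2019group), we propose to enforce consistency between the latent variables by minimizing the KL divergence between the latent variables $z_g^{(1)}$ and $z_g^{(2)}$. Here, $g$ denotes the group shared between observations $x^{(1)}$ and $x^{(2)}$, and $z_g^{(1)}$ and $z_g^{(2)}$ denote the corresponding group variables, e.g., if $x^{(1)}$ and $x^{(2)}$ share group $c$, then the corresponding latent variables are $z_c^{(1)}$ and $z_c^{(2)}$. Given paired observations from the same group $g$, our objective is to minimize

$$ \mathcal{R}_g(x^{(1)}, x^{(2)}) = \mathrm{KL}\big(q_\phi(z_g^{(1)} \mid x^{(1)}) \,\|\, q_\phi(z_g^{(2)} \mid x^{(2)})\big). \tag{6} $$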
The KL divergence has analytical solutions for Gaussian and Categorical approximate posteriors and does not depend on the number of observations that share the group. The analytical solutions are given in Appendix A.2.
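For Gaussian posteriors, the KL regularization can be evaluated in closed form, one latent dimension at a time, using the standard diagonal-Gaussian KL. A minimal sketch with plain Python lists, assuming both posteriors are parameterized by means and log-variances (the name `kl_diag_gaussians` is illustrative). The cost only involves the two posteriors of the pair, not a group-level aggregate.

```python
import math


def kl_diag_gaussians(mu1, logvar1, mu2, logvar2):
    """Closed-form KL(N(mu1, diag(exp(logvar1))) || N(mu2, diag(exp(logvar2))))."""
    total = 0.0
    for m1, lv1, m2, lv2 in zip(mu1, logvar1, mu2, logvar2):
        var1, var2 = math.exp(lv1), math.exp(lv2)
        total += 0.5 * (lv2 - lv1 + (var1 + (m1 - m2) ** 2) / var2 - 1.0)
    return total


# Group posteriors of x^(1) and x^(2) for the shared group g
print(kl_diag_gaussians([0.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]))  # 0.5
```

The order of the arguments matters because KL is not symmetric; whether the original implementation uses a single direction or a symmetrized version is not specified here.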

GroupVAE objective and inference.
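while training() do
    sample a group g ∈ {c, s} and a pair (x^(1), x^(2)) that shares group g
    # encode x^(1)
    q_φ(z_c^(1) | x^(1)), q_φ(z_s^(1) | x^(1)) ← encoder(x^(1))
    # encode x^(2)
    q_φ(z_c^(2) | x^(2)), q_φ(z_s^(2) | x^(2)) ← encoder(x^(2))
    # calculate loss according to (7)
    L ← −ELBO(x^(1), x^(2)) + λ · KL(q_φ(z_g^(1) | x^(1)) ‖ q_φ(z_g^(2) | x^(2)))
    # update gradient and parameters
    compute ∇_{θ,φ} L and take an optimizer step on θ, φ
end while
Algorithm 1 GroupVAE Inference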

Given a paired observation $(x^{(1)}, x^{(2)})$ sharing group $g$, we combine the ELBO in (5) and our proposed KL regularization in (6). Our proposed model, GroupVAE, has the following minimization objective

$$ \mathcal{L}_{\text{GroupVAE}}(x^{(1)}, x^{(2)}, g) = -\mathcal{L}_{\text{ELBO}}(x^{(1)}, x^{(2)}) + \lambda\, \mathcal{R}_g(x^{(1)}, x^{(2)}), \tag{7} $$

where we treat the degree of regularization $\lambda$ as a hyperparameter. We propose an alternating inference strategy to encourage variation in both latent variables. If we only used paired observations that share one group, e.g., pairs that always share content, the model could reach a trivial solution for the content latent variable by encoding a constant. We overcome this collapse by alternating the group that the paired observations share during training. In particular, during inference we randomly sample a group $g \in \{c, s\}$ and a paired observation according to group $g$. We then minimize the KL divergence of the corresponding latent variable. The pseudocode of this inference scheme is shown in Algorithm 1.
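Putting the paired ELBO, the KL regularization, and the combined objective together for one training pair gives the framework-free sketch below. It assumes Gaussian posteriors given as `(mu, logvar)` lists, precomputed reconstruction log-likelihoods, and uniform sampling of the shared group (the text only says the group is sampled randomly). All names are hypothetical.

```python
import math
import random


def kl_to_prior(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I))."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, logvar))


def kl_diag_gaussians(mu1, lv1, mu2, lv2):
    """KL between two diagonal Gaussians given by means and log-variances."""
    return 0.5 * sum(b - a + (math.exp(a) + (m - n) ** 2) / math.exp(b) - 1.0
                     for m, a, n, b in zip(mu1, lv1, mu2, lv2))


def groupvae_loss(post1, post2, loglik1, loglik2, group, lam):
    """Negative paired ELBO plus lam times the KL between shared-group posteriors.

    post1, post2: {"c": (mu, logvar), "s": (mu, logvar)} for x^(1) and x^(2).
    loglik1, loglik2: reconstruction log-likelihoods of x^(1) and x^(2).
    group: shared group g, either "c" or "s".
    """
    neg_elbo = -(loglik1 + loglik2)
    for post in (post1, post2):
        neg_elbo += kl_to_prior(*post["c"]) + kl_to_prior(*post["s"])
    reg = kl_diag_gaussians(*post1[group], *post2[group])
    return neg_elbo + lam * reg


def sample_group(rng=random):
    """Alternating scheme: draw the shared group for the next training pair."""
    return rng.choice(("c", "s"))


p1 = {"c": ([0.0], [0.0]), "s": ([0.0], [0.0])}
p2 = {"c": ([0.0], [0.0]), "s": ([1.0], [0.0])}
print(groupvae_loss(p1, p2, -1.0, -2.0, "c", lam=10.0))  # 3.5: content posteriors agree
print(groupvae_loss(p1, p2, -1.0, -2.0, "s", lam=10.0))  # 8.5: style mismatch is penalized
```

Only the latent variable of the sampled group enters the regularizer, so alternating `group` across steps regularizes both $z_c$ and $z_s$ over the course of training.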

3.1 Related Work

Unsupervised learning of disentangled representations.

Various regularization methods for unsupervised disentangled representation learning have been presented in existing works (higgins2016beta; kim2018disentangling; ChenLGD18). Even though unsupervised methods have shown promising results in learning disentangled representations, locatello2019challenging showed in a rigorous study that it is impossible to disentangle factors of variation without any supervision or inductive bias. Since then, there has been a shift towards weakly-supervised disentanglement learning. Our work follows this stream of works and focuses on the weakly-supervised regime instead of the unsupervised one.

Weakly-supervised learning of disentangled representations.

shu2019weakly investigated different types of weak supervision and provided a theoretical framework to evaluate disentangled representations. locatello2020weakly proposed to disentangle groups of variations by only knowing the number of shared groups, which can be considered a complementary component to our method. Similar to our method, both works follow a weakly-supervised setup. However, both approaches focus on the disentanglement of fine-grained factors, whereas we focus on disentangling groups. Before the concept of paired observations was coined "match pairing" by shu2019weakly, it was already used for geometry and appearance disentanglement (KossaifiTPP18; tran2019disentangling) and group-based disentanglement (bouchacourt2018multi; hosoya2019group). Closest to our work are MLVAE (bouchacourt2018multi) and GVAE (hosoya2019group). For group-disentangled representations, MLVAE uses a product of approximate posteriors, whereas GVAE uses an empirical average of the parameters of the approximate posteriors. A thorough analysis of both works is given in Appendix B. In contrast, we employ a simple and effective KL regularization that has no dependency on the batch size.

Alignment between factors of variations and learned representations.

Closely related to our work and to group-based disentanglement are studies that learn specific latent variables corresponding to one or several factors of variation (or labels). Dupont18 used both continuous and discrete latent variables to improve unsupervised disentanglement of mixed-type latent factors. creager2019flexibly proposed to minimize the mutual information between the sensitive latent variable and sensitive labels. Similarly, KlysSZ18 proposed to minimize the MI between the latent variable and a conditional subspace. Both works (creager2019flexibly; KlysSZ18) require supervision, either sensitive labels or conditions, to estimate the mutual information, whereas we only use weak supervision for learning disentangled group representations. Concurrent to our work, sinha2021consistency proposed a KL regularization for learning a VAE whose representations are consistent with augmented data. While sinha2021consistency use the KL regularization to make the encoding consistent with changes in the input, our goal is to split the representation into subspaces that correspond to the different groups of variation.

4 Evaluation

Here, we evaluate GroupVAE and compare it to existing approaches. We show that our approach outperforms existing approaches in group disentanglement and improves disentanglement on two of three standard disentanglement benchmarks. To evaluate group disentanglement, we propose an MI-based metric that assesses the degree of group disentanglement. Further, we demonstrate that our approach is generic and can be applied to various applications, including fair classification and 3D shape-related tasks (reconstruction, classification, and transfer learning).

4.1 Weakly-supervised group-disentanglement

Figure 2: Example of failed content-style disentanglement with high MIG. The heatmap shows the MI between each factor and each latent dimension. Although content and style have not been separated into the corresponding latent dimensions, the MIG is still very high. In contrast, group-MIG considers where the groups are captured, and thus the group-MIG is much lower.
Experimental settings.

We use three standard datasets for disentangled representation learning: 3D Cars (reed2014learning), 3D Shapes (3dshapes18), and dSprites (dsprites17). Although these image datasets are synthetic, disentangling their factors of variation remains a difficult and unresolved task (locatello2019challenging; locatello2020weakly). We use MIG (ChenLGD18) and our proposed metric, group-MIG, to quantitatively evaluate the different approaches. We compare GroupVAE to unsupervised methods (β-VAE (higgins2016beta) and FactorVAE (kim2018disentangling)) as well as weakly-supervised methods (AdaGVAE (locatello2020weakly), MLVAE (bouchacourt2018multi), and GVAE (hosoya2019group)). For all methods, we ran a hyperparameter sweep over the regularization strength with five different seeds and report the median group-MIG and MIG.

groupmig for evaluating group disentanglement.

The MIG (ChenLGD18) is a commonly used evaluation metric for disentanglement. For each factor, it measures the normalized difference between the MI of the latent dimensions with the highest and second-highest MI values. The higher the MIG, the greater the degree of disentanglement. However, the MIG can still be high if the style latent variable disentangles all factors of variation while the content variable collapses to a constant value. An example of such a failure in group disentanglement is shown in Figure 2. Therefore, we introduce group-MIG, a metric based on MIG that addresses this issue by quantitatively estimating the mutual information between groups and the corresponding latent variables. We define group-MIG as


where $G$ is the number of groups, $g$ is the ground-truth group, and $\hat{I}$ is an empirical estimate of the MI between a continuous latent variable and a ground-truth group. The group-MIG is small if the factors of a group are not represented in the corresponding latent variable, even if those factors are disentangled within the other latent variables.
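Since group-MIG builds on MIG, a minimal sketch of the standard MIG computation from a precomputed MI matrix helps illustrate the failure mode described above. It assumes the MI estimates between latent dimensions and factors and the factor entropies are already available (their estimation is not shown), and it does not implement group-MIG itself.

```python
def mig(mi, entropies):
    """Mutual Information Gap from a precomputed MI matrix.

    mi[k][j]: estimated MI between ground-truth factor k and latent dimension j.
    entropies[k]: entropy H(v_k) of factor k, used to normalize the gap.
    """
    gaps = []
    for row, h in zip(mi, entropies):
        top, second = sorted(row, reverse=True)[:2]
        gaps.append((top - second) / h)
    return sum(gaps) / len(gaps)


# Latent dims 0-1 are meant for content, 2-3 for style. Factor 0 is a content
# factor and factor 1 a style factor, but both are captured by style dimensions.
mi = [[0.0, 0.0, 1.0, 0.0],
      [0.0, 0.0, 0.0, 1.0]]
print(mig(mi, [1.0, 1.0]))  # 1.0, although content is not in the content block
```

MIG only asks whether each factor is captured by a single dimension, not whether that dimension belongs to the factor's group, which is the gap that group-MIG is designed to close.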

Group labeling.

We define the following groups based on the fine-grained factors for each dataset:

  • dSprites:

  • 3D Shapes:

  • 3D Cars:


We consistently outperform the weakly-supervised disentanglement models w.r.t. median group-MIG over five hyperparameter sweeps of different seeds, by at least 25% (3D Shapes). Further, we also improve disentanglement w.r.t. MIG on two out of three datasets (3D Cars, dSprites). In addition, Figure 3 shows interpolation samples of MLVAE, GVAE, and GroupVAE for 3D Shapes (for each method, we selected the model with median group-MIG over five hyperparameter sweeps of different seeds). Neither MLVAE nor GVAE captures azimuth in its latent representation. Moreover, GVAE encodes almost all factors into the style part and collapses to a constant representation in the content part. The interpolations of GroupVAE show content and style disentanglement, although some factors, such as object size and type for 3D Shapes, remain entangled. Since we assume that factors within a group can covary, this result is expected, as object size and type are in the same group.

                                  3D Cars            3D Shapes          dSprites
Type          Model               group-MIG  MIG     group-MIG  MIG     group-MIG  MIG
unsup.        β-VAE               –          0.08    –          0.22    –          0.10
unsup.        FactorVAE           –          0.10    –          0.27    –          0.14
weakly-sup.   AdaGVAE             –          0.15    –          0.56    –          0.26
weakly-sup.   MLVAE               0.24       0.07    0.47       0.32    0.11       0.22
weakly-sup.   GVAE                0.27       0.08    0.45       0.31    0.14       0.21
weakly-sup.   GroupVAE (ours)     0.48       0.18    0.60       0.31    0.54       0.27
Table 1: Quantitative disentanglement results. We report median group-MIG and median MIG over five hyperparameter sweeps of different seeds (higher is better). Since the unsupervised approaches and AdaGVAE do not learn group-disentangled representations, we cannot report group-MIG for them and denote it with –. We highlight the best results in bold.
(a) MLVAE, (b) GVAE, (c) GroupVAE (ours)
Figure 3: Interpolations of 3D Shapes. We show samples from our model GroupVAE and the baseline models (MLVAE and GVAE) with median group-MIG over five hyperparameter sweeps. For each subplot, we show random inputs (first column), their reconstructions (second column), and reconstructions when interpolating each latent dimension (row-wise, remaining columns). The factors annotated on the right side are those with a high level of mutual information with the corresponding latent dimension. For all three models, one block of latent dimensions is supposed to capture style (non-content), while the other block is supposed to capture content.

4.2 Application to fair classification

                                                   Demographic parity (DP)
Fair learning         Model       Test acc.     “shape”    “scale”
                      MLP         99.07         0.007      0.008
                      CNN         99.04         0.002      0.002
✓ (supervised)        FFVAE       98.60         0.004      0.004
✓ (weakly-superv.)    GroupVAE    99.18         0.002      0.002
(a) Results for dSpritesUnfair predicting “x-position”.

                                                   Demographic parity (DP)
Fair learning         Model       Test acc.     “Male”     “Young”
                      MLP         97.89         0.99       0.99
                      CNN         98.46         0.95       0.93
✓ (supervised)        FFVAE       97.79         0.04       0.04
✓ (weakly-superv.)    GroupVAE    98.23         0.01       0.02
(b) Results for CelebA predicting “bald”.

                                                   Demographic parity (DP)
Fair learning         Model       Test acc.     “BigNose”  “HeavyMakeup”  “Male”  “WearingLipstick”
                      MLP         77.24         0.09       0.15           0.06    0.04
                      CNN         79.90         0.11       0.15           0.03    0.06
✓ (supervised)        FFVAE       97.75         0.03       0.02           0.03    0.03
✓ (weakly-superv.)    GroupVAE    97.88         0.01       0.02           0.02    0.01
(c) Results for CelebA predicting “attractive”.

Table 2: Fair classification results on the test sets of dSpritesUnfair and CelebA. We report test accuracy and DP for each sensitive attribute, averaged over five experiments. We report the standard error for all test accuracies but omit it for the DP results. We highlight the best results in bold. The column Fair learning indicates whether a model uses any supervision during the fair representation learning phase; for the final classification, all models use full supervision.

We examine the problem of learning fair representations for classification as an application of our method. In particular, we want to learn fair group representations in which members of any (demographic) group have an equal probability of being assigned to the positive predicted class. Deep learning algorithms have been shown to be biased against specific demographic groups or populations (mehrabi2021survey). It is critical that classification models produce accurate predictions without discriminating against certain groups in high-stakes and safety-critical applications. In this context, we propose to learn fair representations by learning two distinct groups of representations: a predictive representation for the downstream task and a representation that accounts for the sensitive factors, e.g., gender- or age-specific attributes. The latter representation is used only during training, not for downstream tasks.

Learning fair representations consists of a two-step optimization scheme. First, we train GroupVAE with pairs of observations sharing either sensitive or non-sensitive attributes. Second, we train a simple MLP for attribute classification on the mean of the non-sensitive representation. We measure classification accuracy and DP. DP measures whether the predicted outcome is independent of a sensitive attribute: a completely fair model attains a DP of 0.0, whereas a biased model can have a DP of up to 1.0. We compare against MLP and CNN baselines, and against FFVAE (creager2019flexibly), which learns fair representations using a supervised loss on the sensitive attributes and a total correlation loss. We use two datasets: dSpritesUnfair (creager2019flexibly; trauble2020independence) and CelebA (liu2015deep). dSpritesUnfair is a modified version of dSprites with binarized factors of variation, sampled such that shape and x-position are highly correlated. For CelebA, an image dataset of celebrity faces with 40 binary attribute labels, we predict “bald” and “attractive” in two separate experiments. For predicting “bald”, we use the attributes “Male” and “Young” as sensitive attributes, whereas for predicting “attractive” we use “BigNose”, “HeavyMakeup”, “Male”, and “WearingLipstick”. We argue that these sensitive attributes are only weakly correlated with each other but strongly correlated with the predicted attribute. However, several CelebA attributes are significantly correlated, which makes this a difficult dataset for fair classification. We refer to Appendix C.2 for the detailed experimental settings.
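For binary predictions and a binary sensitive attribute, a common definition of the demographic parity gap is $|P(\hat{y}=1 \mid a=1) - P(\hat{y}=1 \mid a=0)|$, which matches the 0.0 to 1.0 range described above. The sketch below assumes this definition, which may differ in detail from the evaluation code behind Table 2; `demographic_parity_gap` is an illustrative name.

```python
def demographic_parity_gap(predictions, sensitive):
    """|P(y_hat=1 | a=1) - P(y_hat=1 | a=0)| for binary predictions and attribute."""
    positives = {0: 0, 1: 0}
    counts = {0: 0, 1: 0}
    for y_hat, a in zip(predictions, sensitive):
        counts[a] += 1
        positives[a] += y_hat
    if counts[0] == 0 or counts[1] == 0:
        raise ValueError("both attribute values must occur")
    return abs(positives[1] / counts[1] - positives[0] / counts[0])


print(demographic_parity_gap([1, 1, 0, 0], [1, 1, 0, 0]))  # 1.0: fully biased
print(demographic_parity_gap([1, 0, 1, 0], [1, 1, 0, 0]))  # 0.0: parity
```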


We report the fair classification results in Table 2. Overall, the results show that weakly-supervised fair representation learning (GroupVAE) outperforms supervised fair representation learning (FFVAE). Further, GroupVAE is competitive with or even outperforms the supervised baselines (MLP, CNN). Surprisingly, on dSpritesUnfair the demographic parity of all models is relatively low, and the strong correlation between shape and x-position does not seem to affect classification. The test accuracies and DP values of all models are very close to each other. Nevertheless, among all models, our method achieves the highest test accuracy and the lowest DP (tied with the CNN). For predicting “bald” on CelebA, even though both MLP and CNN baselines achieve high test accuracy, their DP values show an extremely biased classification with respect to the gender and age attributes. In contrast, GroupVAE achieves the lowest DP while attaining competitive classification accuracy, i.e., the second-highest test accuracy after the CNN. When predicting “attractive”, GroupVAE achieves the lowest or tied-lowest DP for all sensitive attributes and the highest test accuracy among all models.

4.3 Application to 3D point cloud tasks

In addition to image datasets, we evaluate on 3D point clouds for reconstruction and classification. We experiment with FoldingNet (yang2018foldingnet), a deep autoencoder that learns to reconstruct 3D point clouds in an unsupervised way. Unlike VAEs, the FoldingNet autoencoder is deterministic and does not model the representation as a probability distribution. Instead of converting the autoencoder into a VAE, we use an approach similar to ghosh2019variational: we assume the autoencoder's embedding to be Normally distributed with constant variance. Under this assumption, the KL divergence between the corresponding embeddings reduces to a simple L2 regularization, and we can inject noise to regularize the decoding. We evaluate three tasks: 3D point cloud reconstruction, classification, and transfer learning. We measure the Chamfer Distance (CD) and the Earth Mover's Distance (EMD) to assess reconstruction quality and report accuracy to assess classification and transfer learning performance. We compare to FoldingNet (unsupervised) and DGCNN (wang2019dynamic; supervised), a dynamic graph-based classification approach. To assess transfer learning capability, we train a linear SVM classifier on the extracted representations. We use two datasets for training: FG3D (liu2021fine) and ShapeNetV2 (chang2015shapenet). FG3D contains 24,730 shapes with annotations of basic categories (Airplane, Car, and Chair) and fine-grained sub-categories. ShapeNetV2 contains 51,127 shapes with annotations of 55 categories. For transfer learning, we also use ModelNet40 (wu20153d).
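The reduction to L2 follows from the diagonal-Gaussian KL: if both embeddings are modeled as $\mathcal{N}(\mu, \sigma^2 I)$ with the same constant $\sigma$, the log-variance and trace terms cancel and only $\|\mu^{(1)} - \mu^{(2)}\|^2 / (2\sigma^2)$ remains. A minimal sketch, with $\sigma$ a hypothetical hyperparameter and noise injection shown as adding Gaussian noise to the embedding before decoding (our assumption about the mechanism, not a detail confirmed by the text).

```python
import math
import random


def kl_diag_gaussians(mu1, lv1, mu2, lv2):
    """General KL between diagonal Gaussians (means, log-variances)."""
    return 0.5 * sum(b - a + (math.exp(a) + (m - n) ** 2) / math.exp(b) - 1.0
                     for m, a, n, b in zip(mu1, lv1, mu2, lv2))


def l2_group_regularizer(emb1, emb2, sigma):
    """KL(N(emb1, sigma^2 I) || N(emb2, sigma^2 I)) = ||emb1 - emb2||^2 / (2 sigma^2)."""
    return sum((a - b) ** 2 for a, b in zip(emb1, emb2)) / (2.0 * sigma ** 2)


def noisy_embedding(emb, sigma, rng=random):
    """Embedding plus N(0, sigma^2) noise, fed to the decoder during training."""
    return [e + sigma * rng.gauss(0.0, 1.0) for e in emb]


e1, e2, sigma = [1.0, 2.0], [0.0, 0.5], 0.5
lv = [2.0 * math.log(sigma)] * 2
print(l2_group_regularizer(e1, e2, sigma))  # 6.5
print(abs(kl_diag_gaussians(e1, lv, e2, lv) - 6.5) < 1e-9)  # True
```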


Table 3(a) shows that weakly-supervised training improves 3D point cloud reconstruction on both FG3D and ShapeNetV2. Table 3(b) shows the classification and transfer results. Our approach, GroupFoldingNet, improves point cloud classification over the original FoldingNet and is competitive with the supervised approach when training on FG3D. When training on FG3D, we outperform both the supervised and the unsupervised approach in transfer learning to ShapeNetV2 and ModelNet40. When training on ShapeNetV2, we are competitive with the supervised approach on ModelNet40. In particular, the transfer learning performance with FG3D as the training set highlights the capabilities of weakly-supervised group disentanglement: the model learns from 3D point clouds of only three classes and transfers to ShapeNetV2, a large-scale dataset with 55 classes. We also visualize point cloud reconstructions and interpolations of three different classes using our approach in Figure 4. The reconstructions show that our approach reconstructs finer details better than FoldingNet, and the interpolations show that our approach can learn an interpretable representation.

                                               FG3D                ShapeNetV2
Type             Model                         CD        EMD       CD        EMD
unsupervised     FoldingNet                    0.9539    0.9340    2.9867    1.5576
weakly-superv.   GroupFoldingNet (ours)        0.7519    0.8191    2.6891    1.3009
(a) Reconstruction results for FG3D and ShapeNetV2 (lower is better).

                                                                                                 Linear SVM ACC
Type             Model                    Training dataset   Test dataset   #classes   Test ACC   ShapeNetV2   ModelNet40
supervised       DGCNN                    FG3D               FG3D           3          99.26      50.53        74.25
unsupervised     FoldingNet               FG3D               FG3D           3          98.27      85.45        80.04
weakly-superv.   GroupFoldingNet (ours)   FG3D               FG3D           3          98.57      87.24        81.39
supervised       DGCNN                    ShapeNetV2         ShapeNetV2     55         94.4       –            90.02
unsupervised     FoldingNet               ShapeNetV2         ShapeNetV2     55         81.51      –            87.40
weakly-superv.   GroupFoldingNet (ours)   ShapeNetV2         ShapeNetV2     55         82.62      –            89.97
(b) Classification and transfer learning of representations.

Table 3: Evaluation of 3D point cloud reconstruction, classification, and transfer learning. We report Chamfer Distance (CD) and Earth Mover's Distance (EMD) for reconstruction quality and accuracy for classification and transfer learning tasks. Best results without full supervision are highlighted in bold.
(a) Reconstructions, (b) Interpolations between two different samples
Figure 4: Qualitative samples of ShapeNetV2. We show reconstructions of FoldingNet and our approach in (a) and interpolations of our approach in (b).

5 Conclusion & discussion

We proposed a simple KL regularization for VAEs to enforce group disentanglement through weak supervision. We empirically showed that our model outperforms existing approaches in group disentanglement. Further, we demonstrated that learning group-disentangled representations improves performance on fair image classification and 3D shape-related tasks (reconstruction, classification, and transfer learning) and is even competitive with supervised approaches.

There are several possible directions for future work. In comparison to unsupervised representation learning, weakly-supervised learning by definition requires some weak form of supervision. Although we only need to know whether two observations share a specific group, this requirement limits the approach. Further, we require group labels for the entire dataset for training and evaluation. In real-life applications, datasets may not be fully labeled, and performance may suffer in this setting. Investigating group disentanglement in a low-data or "semi" weakly-supervised regime could allow group disentanglement learning to transfer to large-scale and more realistic settings. Another promising direction is investigating models with more than two groups: even though we focused on applications with two groups in this work, our method can generalize to more than two groups.


We thank Hooman Shayani and Tonya Custis for useful discussions and comments on the paper.


Appendix A GroupVAE

a.1 Joint Learning of Continuous and Discrete Groups

The generative model defined in Section 3 assumes both content and style representations to be Gaussian distributed. However, many data-generating processes rely on discrete factors, which are usually difficult to capture with continuous variables. In these cases, we can define the generative model as


For inference, we use the Gumbel-Softmax reparameterization (JangGP17; MaddisonMT17), a continuous distribution on the simplex that can approximate categorical samples for the discrete latent variable. Similar to the KL divergence between two Normal distributions, the KL divergence between two Categorical distributions can be computed in closed form.
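A minimal sketch of drawing a relaxed one-hot sample with the Gumbel-Softmax reparameterization, assuming unnormalized log-probabilities `logits` and a temperature `tau` (lower values give samples closer to one-hot). The function name is illustrative.

```python
import math
import random


def gumbel_softmax_sample(logits, tau, rng=random):
    """Relaxed categorical sample softmax((logits + g) / tau), g_k ~ Gumbel(0, 1)."""
    # Gumbel(0, 1) noise via inverse CDF; clamp u away from 0 to avoid log(0)
    noise = [-math.log(-math.log(max(rng.random(), 1e-12))) for _ in logits]
    scaled = [(l + g) / tau for l, g in zip(logits, noise)]
    shift = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


y = gumbel_softmax_sample([math.log(0.7), math.log(0.2), math.log(0.1)], tau=0.5,
                          rng=random.Random(0))
print(round(sum(y), 6))  # 1.0
```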

a.2 Closed-form Solutions for the KL Regularization

When both distributions in the KL regularization are factorized Gaussians, the KL regularization has the analytical solution


When both distributions are Categorical, the KL has the analytical solution
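Both analytical solutions are the standard closed forms of the KL divergence between diagonal Gaussians and between categorical distributions. The NumPy sketch below illustrates them; the function names and the log-variance parameterization are assumptions of this sketch rather than details taken from the paper.

```python
import numpy as np


def kl_diag_gaussians(mu1, logvar1, mu2, logvar2):
    """KL(N(mu1, diag(exp(logvar1))) || N(mu2, diag(exp(logvar2)))), summed over the last axis."""
    var1, var2 = np.exp(logvar1), np.exp(logvar2)
    return 0.5 * np.sum(logvar2 - logvar1 + (var1 + (mu1 - mu2) ** 2) / var2 - 1.0, axis=-1)


def kl_categorical(p, q, eps=1e-12):
    """KL(Cat(p) || Cat(q)) = sum_k p_k (log p_k - log q_k), summed over the last axis.

    Probabilities are clipped to eps to avoid log(0); this slightly perturbs exact zeros.
    """
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    return np.sum(p * (np.log(p) - np.log(q)), axis=-1)


# KL of N(1, 1) from the standard normal is 0.5.
print(kl_diag_gaussians(np.array([1.0]), np.array([0.0]), np.array([0.0]), np.array([0.0])))
```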


Appendix B Analysis of Existing Group-Disentanglement Approaches

In this section, we give further details about the content approximate posteriors proposed by Bouchacourt et al. bouchacourt2018multi and Hosoya hosoya2019group. Further, we analyze the proposed approaches and show their limitations.

b.1 MLVAE and GVAE

As described in Section 3, we restrict ourselves to two groups and define corresponding content and style latent variables given an observation (in similar fashion, we define two latent variables for the paired observation). However, both works also apply to any number of groups. For paired observations that share a group factor, the loss objectives for MLVAE bouchacourt2018multi and GVAE hosoya2019group are


The two loss objectives are very similar; they differ only in the group approximate posterior used by MLVAE and by GVAE.

bouchacourt2018multi assume that the group approximate posterior is the product of the individual approximate posteriors of the observations sharing the same group:


The normalized product of two or more Normal densities is again a Normal density, so the KL term can still be computed in closed form.
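For diagonal Gaussians, the normalized product has a simple closed form: the precisions add, and the mean is the precision-weighted average of the individual means. A NumPy sketch follows; parameterizing each individual posterior by a mean and a log-variance is an assumption of this sketch, not a detail from bouchacourt2018multi.

```python
import numpy as np


def product_of_gaussians(mus, logvars):
    """Normalized product of K diagonal Gaussians, one per row of `mus` / `logvars` (K, D).

    Precisions add; the mean is the precision-weighted average of the means.
    Returns the mean and log-variance of the resulting Gaussian, each of shape (D,).
    """
    precisions = np.exp(-logvars)              # 1 / var_k, shape (K, D)
    var = 1.0 / precisions.sum(axis=0)         # combined variance, shape (D,)
    mu = var * (precisions * mus).sum(axis=0)  # precision-weighted mean, shape (D,)
    return mu, np.log(var)


# N(0, 1) and N(4, 3) combine to N(1, 0.75).
mu, logvar = product_of_gaussians(np.array([[0.0], [4.0]]), np.log(np.array([[1.0], [3.0]])))
```

Because precisions add, combining K identical posteriors shrinks the variance by a factor of K, which is one way to see that this group posterior depends on the number of observations in the group.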

hosoya2019group uses an empirical average over the parameters of the individual approximate posteriors. The group approximate posterior is defined as
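The parameter averaging can be sketched in the same way. This sketch assumes that the means and the variances (rather than the log-variances) are averaged; that choice and the function name are assumptions, not details confirmed by hosoya2019group.

```python
import numpy as np


def average_gaussian_params(mus, logvars):
    """Group posterior from the empirical average of K diagonal Gaussians' parameters.

    mus, logvars: arrays of shape (K, D), one row per observation in the group.
    Assumption of this sketch: variances (not log-variances) are averaged.
    """
    mu = mus.mean(axis=0)
    var = np.exp(logvars).mean(axis=0)
    return mu, np.log(var)


# N(1, 1) and N(3, 1) average to N(2, 1).
mu, logvar = average_gaussian_params(np.array([[1.0], [3.0]]), np.zeros((2, 1)))
```

Unlike the product of Gaussians, averaging K identical posteriors leaves the variance unchanged, so the two group posteriors react differently to the group size.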


b.2 Analysis

Both MLVAE and GVAE enforce disentanglement through the β-regularization in the last two terms of (14) and (15). This regularization is also used in β-VAE higgins2016beta, where it controls a trade-off between disentanglement and reconstruction. The two KL terms in (14) and (15) can be decomposed similarly to the ELBO and KL decompositions in hoffman2016elbo; ChenLGD18. We consider the objective in (14) averaged over the empirical data distribution. Each training sample is denoted by a unique index and treated as a random variable. We refer to the resulting marginal of the approximate posterior as the aggregated posterior hoffman2016elbo. We can decompose the first KL term in (14) as follows (the KL term of GVAE in (15) can be decomposed similarly):


We show the full derivation in Subsection B.3. Minimizing the averaged KL between the content and style latent variables and the prior therefore also minimizes the total correlation of the content variables and the style variables (the last two terms in (18)). The total correlation quantifies the amount of information shared among multiple random variables; a low total correlation indicates that the variables are close to independent. Even though this objective encourages disentangled content and style representations, the group representation depends on the number of samples used for averaging. Moreover, both bouchacourt2018multi and hosoya2019group average only over the content group. There are neither structural nor optimization constraints that prevent the style latent variable from encoding all factors of variation.
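The decomposition above builds on the standard identity of hoffman2016elbo and ChenLGD18. For a latent vector z with a factorized prior p(z) = ∏_j p(z_j), training-sample index n with p(n) = 1/N, and aggregated posterior q(z) = (1/N) Σ_n q(z | n), it reads

$$\mathbb{E}_{p(n)}\big[\mathrm{KL}(q(z \mid n)\,\|\,p(z))\big] = \mathrm{KL}\big(q(z, n)\,\|\,q(z)\,p(n)\big) + \mathrm{KL}\Big(q(z)\,\Big\|\,\prod_j q(z_j)\Big) + \sum_j \mathrm{KL}\big(q(z_j)\,\|\,p(z_j)\big),$$

that is, index-code mutual information, total correlation, and dimension-wise KL. The NumPy sketch below checks this generic identity numerically on a toy model with two discrete latent dimensions; the toy distributions are arbitrary, and the sketch does not reproduce the exact grouping into content and style terms used in (18).

```python
import numpy as np


def kl(a, b):
    """KL(a || b) for two discrete distributions given as arrays of equal shape."""
    return float(np.sum(a * np.log(a / b)))


def tc_decomposition(q_z_given_n, p1, p2):
    """Terms of E_n KL(q(z|n) || p(z)) = MI + TC + dimension-wise KL for z = (z1, z2).

    q_z_given_n: array (N, K1, K2); q_z_given_n[n] is a joint over (z1, z2) for sample n.
    p1, p2: marginals of the factorized prior p(z) = p1(z1) p2(z2).
    """
    p_z = np.outer(p1, p2)
    lhs = np.mean([kl(q, p_z) for q in q_z_given_n])  # p(n) = 1/N
    q_z = q_z_given_n.mean(axis=0)                     # aggregated posterior q(z)
    q1, q2 = q_z.sum(axis=1), q_z.sum(axis=0)          # marginals q(z1), q(z2)
    mi = np.mean([kl(q, q_z) for q in q_z_given_n])    # index-code MI, KL(q(z,n) || q(z)p(n))
    tc = kl(q_z, np.outer(q1, q2))                     # total correlation
    dw = kl(q1, p1) + kl(q2, p2)                       # dimension-wise KL
    return lhs, mi, tc, dw


# Arbitrary toy posteriors (N = 5 samples, K1 = 3 and K2 = 4 categories) and prior.
rng = np.random.default_rng(0)
q = rng.random((5, 3, 4)) + 0.1
q /= q.sum(axis=(1, 2), keepdims=True)
p1 = rng.random(3) + 0.1
p1 /= p1.sum()
p2 = rng.random(4) + 0.1
p2 /= p2.sum()

lhs, mi, tc, dw = tc_decomposition(q, p1, p2)
print(lhs, mi + tc + dw)  # equal up to floating-point error
```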

Figure 5: Collapse and sensitivity of existing weakly supervised group-disentanglement models. (a) 3DShapes: mutual information (MI, higher is better) between latent dimensions and factors of variation for a trained GVAE model. (b) dSprites: group-MIG (higher is better) of content and style information for all hyperparameter runs. (c) dSprites: MIG (higher is better) w.r.t. the number of shared observations for MLVAE and GVAE.
Sensitivity to group batch size.

MLVAE and GVAE use different types of averaging over group latent variables. In realistic settings, it can be difficult to guarantee a fixed number of observations that share the same group variations. When training MLVAE and GVAE on dSprites, for instance, both the performance and its variance correlate with the number of shared observations, as visualized in Figure 5 (c).

Visualization of collapse.

Figure 5 (a) visualizes this behavior for a GVAE model trained on 3DShapes with two groups of variations, {object color, object size, object type} and {floor color, wall color, azimuth}. Ideally, each of the two latent variables contains high mutual information with the factors of exactly one group. However, most information is captured in a single latent variable, whereas the other contains only a little information about object type.

b.3 KL Decomposition

Here, we show the full derivation of (18). For a given group, the KL term decomposes as follows:


where the expectation is taken with respect to the empirical data distribution.

Appendix C Experimental Setup

c.1 Disentanglement Study

All hyperparameters for optimization and the model architectures are listed in Table 4. We compare our approach, GroupVAE, to four models: β-VAE higgins2016beta, AdaGVAE locatello2020weakly, MLVAE bouchacourt2018multi and GVAE hosoya2019group. For a fair comparison, we used the same architecture and optimization settings for all models and varied only the range of the regularization hyperparameters. We ran five experiments with different random seeds for every hyperparameter setting, 240 experiments in total. Each experiment ran on a GPU cluster with Nvidia V100 or RTX 6000 GPUs and took approximately 2-3 hours.

Datasets and group sampling.

We evaluated our approach on three datasets: 3D Cars reed2014learning, 3D Shapes 3dshapes18 and dSprites dsprites17. All datasets contain images of size 64 × 64 with pixel values normalized between 0 and 1. For training, we sample a group uniformly from all groups and, for each observation, sample its paired observation uniformly from all observations that share the same values for that group.
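The pair sampling can be sketched as follows. The sketch assumes that ground-truth factor values are stored as an integer array with one row per observation, that each group is a list of factor indices, and that the first observation is drawn uniformly; these data structures and the function name `sample_pair` are hypothetical. As described above, the paired observation is drawn from all observations sharing the group values, so it may coincide with the first one.

```python
import numpy as np


def sample_pair(factors, groups, rng):
    """Sample a weakly supervised pair that shares the values of one group of factors.

    factors: int array (N, F) of ground-truth factor values per observation.
    groups:  list of lists of factor indices, e.g. [[0, 1], [2, 3, 4]].
    Returns (group index g, index i of the first observation, index j of its pair).
    """
    g = int(rng.integers(len(groups)))          # group, uniformly at random
    i = int(rng.integers(factors.shape[0]))     # first observation (assumed uniform)
    cols = groups[g]
    # All observations whose factors in group g equal those of observation i (includes i).
    shared = np.flatnonzero(np.all(factors[:, cols] == factors[i, cols], axis=1))
    j = int(rng.choice(shared))                 # paired observation, uniformly
    return g, i, j


# Toy dataset: every combination of three factors with 2, 3 and 2 values.
factors = np.indices((2, 3, 2)).reshape(3, -1).T
groups = [[0], [1, 2]]
g, i, j = sample_pair(factors, groups, np.random.default_rng(0))
```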

Evaluating disentanglement.

In addition to comparing group disentanglement, we also used the mutual information gap (MIG) ChenLGD18 to compare the models’ ability to disentangle all factors of variation. ChenLGD18 proposed MIG as an unbiased and hyperparameter-independent evaluation metric that measures the mutual information between each ground-truth factor and each dimension of the computed representation. MIG is the average, over factors, of the difference between the highest and the second-highest normalized mutual information. The score is computed as

$$\mathrm{MIG} = \frac{1}{K}\sum_{k=1}^{K}\frac{1}{H(v_k)}\Big(I\big(z_{j^{(k)}}; v_k\big) - \max_{j \neq j^{(k)}} I\big(z_j; v_k\big)\Big),$$

where $j^{(k)} = \arg\max_j I(z_j; v_k)$ and $K$ is the number of known factors.
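A minimal NumPy sketch of MIG follows. It swaps in a simple plug-in estimator: each latent dimension is discretized into equal-width bins (20 by default) and mutual information is computed from the resulting contingency table. ChenLGD18 estimate mutual information differently, and the binning scheme, bin count and function names here are assumptions of this sketch.

```python
import numpy as np


def discrete_mutual_info(a, b):
    """Plug-in estimate of I(a; b) in nats for two 1-D arrays of discrete values."""
    _, a_idx = np.unique(a, return_inverse=True)
    _, b_idx = np.unique(b, return_inverse=True)
    joint = np.zeros((a_idx.max() + 1, b_idx.max() + 1))
    np.add.at(joint, (a_idx, b_idx), 1.0)
    joint /= joint.sum()
    outer = joint.sum(axis=1, keepdims=True) @ joint.sum(axis=0, keepdims=True)
    nz = joint > 0
    return float(np.sum(joint[nz] * np.log(joint[nz] / outer[nz])))


def discretize(x, num_bins):
    """Map a continuous 1-D array to integer bins of equal width."""
    edges = np.histogram_bin_edges(x, bins=num_bins)
    return np.digitize(x, edges[1:-1])


def mig(latents, factors, num_bins=20):
    """Mutual information gap for latents (N, D), D >= 2, and discrete factors (N, K)."""
    codes = np.stack([discretize(latents[:, j], num_bins) for j in range(latents.shape[1])], axis=1)
    gaps = []
    for k in range(factors.shape[1]):
        v = factors[:, k]
        mi = np.sort([discrete_mutual_info(codes[:, j], v) for j in range(codes.shape[1])])[::-1]
        gaps.append((mi[0] - mi[1]) / discrete_mutual_info(v, v))  # I(v; v) = H(v)
    return float(np.mean(gaps))


rng = np.random.default_rng(0)
factors = rng.integers(0, 4, size=(2000, 2))
# One latent dimension per factor plus a noise dimension: close to perfectly disentangled.
good = np.column_stack([factors[:, 0], factors[:, 1], rng.normal(size=2000)]).astype(float)
print(mig(good, factors))  # close to 1
```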

Parameter            Value
Batch size           64
Latent dimension     10
Optimizer            Adam
Adam: beta1          0.9
Adam: beta2          0.999
Adam: epsilon        1e-8
Adam: learning rate  5e-4
Training iterations  300,000
(a) Common hyperparameters.

Encoder:
Conv (Stride 2), ReLU act.,
Conv (Stride 2), ReLU act.,
Conv (Stride 2), ReLU act.,
Conv (Stride 2), ReLU act.,
FC 256, ReLU act., FC 2×10
Decoder:
FC 1024, ReLU act., Reshape (64, 4, 4),
TransposeConv (Stride 2), ReLU act.,
TransposeConv (Stride 2), ReLU act.,
TransposeConv (Stride 2), ReLU act.,
TransposeConv (Stride 2)
(b) Common model architectures.

Model                           Parameter   Values
β-VAE higgins2016beta
AdaGVAE locatello2020weakly
MLVAE bouchacourt2018multi
GVAE hosoya2019group
(c) Model hyperparameters.
Table 4: Experimental setup for the disentanglement study. We list the common hyperparameters, the common model architectures and the model-specific hyperparameters.
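Table 4 (b) does not list kernel sizes or padding. Assuming 4×4 kernels with padding 1 (an assumption of this sketch, not a value from the paper), the standard output-size formulas confirm that four stride-2 convolutions reduce 64×64 inputs to 4×4 and that four stride-2 transposed convolutions map the decoder's Reshape (64, 4, 4), whose 64 · 4 · 4 = 1024 entries match FC 1024, back to 64×64.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a strided 2-D convolution (square inputs)."""
    return (size + 2 * padding - kernel) // stride + 1


def tconv_out(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a strided transposed convolution without output padding."""
    return (size - 1) * stride - 2 * padding + kernel


encoder_sizes = [64]
for _ in range(4):  # four Conv (Stride 2) layers
    encoder_sizes.append(conv_out(encoder_sizes[-1]))

decoder_sizes = [4]  # after FC 1024 and Reshape (64, 4, 4)
for _ in range(4):  # four TransposeConv (Stride 2) layers
    decoder_sizes.append(tconv_out(decoder_sizes[-1]))

print(encoder_sizes)  # [64, 32, 16, 8, 4]
print(decoder_sizes)  # [4, 8, 16, 32, 64]
```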

c.2 Fairness

FFVAE discriminator: FC 1000, LeakyReLU(0.2) act., FC 1000, LeakyReLU(0.2) act.,
FC 1000, LeakyReLU(0.2) act., FC 1000, LeakyReLU(0.2) act.,
FC 1000, LeakyReLU(0.2) act., FC 2
MLP baseline: FC 128, ReLU act., FC 128, ReLU act., FC 128, FC 2
CNN baseline: Conv, ReLU act., MaxPool,
Conv, ReLU act., MaxPool, FC 120,
ReLU act., FC 84, ReLU act., FC 2
(a) Additional model architectures.

Model   Parameter   Values
(b) Additional model hyperparameters.

Dataset          Parameter                                  Values
CelebA           Latent dimensions [sensitive, non-sensitive]
dSpritesUnfair   Latent dimensions [sensitive, non-sensitive]
(c) Dataset-specific hyperparameters.
Table 5: Experimental settings for fair classification. We list the hyperparameters of FFVAE and of the MLP and CNN baselines.

We ran five experiments with different random seeds for every hyperparameter setting, 550 experiments in total. Each experiment ran on a GPU cluster with Nvidia V100 or RTX 6000 GPUs and took approximately 2-3 hours.


For the fair classification experiments, we used the same common hyperparameters and model architecture as in the disentanglement study (Table 4 (a) and (b)) for GroupVAE, GVAE and MLVAE. In addition, we implemented two simple baselines, an MLP and a CNN, whose architectures are described in Table 5 (a). For supervised fair classification, we implemented FFVAE creager2019flexibly with the same encoder and decoder networks as in Table 4 (b) and the FFVAE discriminator in Table 5 (a). The baselines are trained with a cross-entropy loss between the logits of the network and the binary label “HeavyMakeup”. The number of latent dimensions differs between datasets, as shown in Table 5 (c).

Sensitive and non-sensitive latent variables.

Similar to the content and style disentanglement setup, we define two groups, sensitive and non-sensitive. GroupVAE can be optimized to learn from weakly supervised observations sharing either sensitive or non-sensitive group values. FFVAE creager2019flexibly can be seen as the supervised counterpart for learning sensitive and non-sensitive representations. FFVAE maximizes the ELBO objective (reconstruction loss and KL divergence between approximate posterior and prior). In addition, its objective contains a supervised term, weighted by a hyperparameter, that rewards the discriminative ability of the sensitive latent variable (how well can the model classify sensitive labels from the sensitive latent variable?), and a disentanglement term, weighted by a second hyperparameter (how well is the sensitive latent variable disentangled from the non-sensitive latent variable?).


For comparability with FFVAE creager2019flexibly, we used similar dataset settings for CelebA li2018deep and dSpritesUnfair. Both datasets contain 64 × 64 images with pixel values normalized between 0 and 1. We used the predefined train, validation and test split of CelebA li2018deep, whereas for dSpritesUnfair we used a random split of 80% train, 5% validation and 15% test.
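A random 80/5/15 split like the one used for dSpritesUnfair can be sketched as follows; the seed, the rounding rule and the function name are assumptions of this sketch.

```python
import numpy as np


def random_split(n, fractions=(0.80, 0.05, 0.15), seed=0):
    """Disjoint random train/validation/test index arrays covering range(n)."""
    assert abs(sum(fractions) - 1.0) < 1e-9
    perm = np.random.default_rng(seed).permutation(n)
    n_train = int(round(fractions[0] * n))
    n_val = int(round(fractions[1] * n))
    # The test split receives the remainder, so every index is used exactly once.
    return perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]


train_idx, val_idx, test_idx = random_split(1000)
print(len(train_idx), len(val_idx), len(test_idx))  # 800 50 150
```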


dSpritesUnfair is a modified version of dSprites dsprites17 with two changes: the factors of variation are binarized, and the data are drawn with biased sampling. dSprites contains images that are described by five factors of variation. We binarized the factors of variation following the criteria of creager2019flexibly:

  • Shape

  • Scale

  • Rotation

  • X-position

  • Y-position

Similar to trauble2020independence, we enforce a correlation between shape and x-position through biased sampling. In the training set, we sample these two factors from the joint distribution


where a single parameter determines the strength of the correlation: the smaller this parameter, the higher the correlation between the two factors.
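The exact joint distribution is not reproduced here. As an illustration only, the sketch below assumes a Gaussian-kernel form in the spirit of trauble2020independence, p(c1, c2) ∝ exp(−(c1 − c2)² / (2σ²)) on factors scaled to [0, 1], where σ (our notation) plays the role of the correlation-strength parameter. It reweights observations of a dataset in which all factor combinations appear equally often, as in dSprites; the function name and σ are assumptions.

```python
import numpy as np


def biased_sample_indices(c1, c2, sigma, num_samples, rng):
    """Draw dataset indices so that factors c1 and c2 become correlated.

    c1, c2: factor values per observation, scaled to [0, 1].
    Assumed kernel: weight proportional to exp(-(c1 - c2)^2 / (2 sigma^2));
    a smaller sigma gives a stronger correlation.
    """
    weights = np.exp(-((c1 - c2) ** 2) / (2.0 * sigma ** 2))
    return rng.choice(len(weights), size=num_samples, replace=True, p=weights / weights.sum())


# Binarized shape and x-position, with all four combinations equally represented.
c1 = np.repeat([0.0, 0.0, 1.0, 1.0], 250)
c2 = np.repeat([0.0, 1.0, 0.0, 1.0], 250)
idx = biased_sample_indices(c1, c2, sigma=0.3, num_samples=5000, rng=np.random.default_rng(0))
print(np.corrcoef(c1[idx], c2[idx])[0, 1])  # close to 1 for a small sigma
```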

Model selection.

As shown in creager2019flexibly, there is a trade-off between classification accuracy and demographic parity. Thus, model selection based on only one of these metrics compromises the other. We therefore propose to use the difference between the two metrics for model selection. We coin this metric FairGap (FG) and define it as


FG is high if the accuracy is high and the average demographic parity gap is low, i.e., for an accurate and fair classifier. We select models based on their FG on the validation set and report results on the test sets of CelebA and dSpritesUnfair.
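The sketch below illustrates FG under two assumptions that are ours, not taken from the paper: demographic parity is measured as the absolute difference in positive-prediction rates between the two values of a binary sensitive attribute, and the average is taken over all sensitive attributes. The function names are hypothetical.

```python
import numpy as np


def demographic_parity_gap(y_pred, a):
    """|P(y_hat = 1 | a = 1) - P(y_hat = 1 | a = 0)| for binary predictions and attribute."""
    return float(abs(y_pred[a == 1].mean() - y_pred[a == 0].mean()))


def fair_gap(y_true, y_pred, sensitive):
    """Accuracy minus the demographic-parity gap averaged over sensitive attributes.

    sensitive: array (N, A) with one binary column per sensitive attribute.
    """
    accuracy = float(np.mean(y_pred == y_true))
    gaps = [demographic_parity_gap(y_pred, sensitive[:, k]) for k in range(sensitive.shape[1])]
    return accuracy - float(np.mean(gaps))


y_true = np.array([1, 0, 1, 0])
y_pred = np.array([1, 0, 1, 0])
print(fair_gap(y_true, y_pred, np.array([[1], [1], [0], [0]])))  # 1.0: accurate, parity holds
print(fair_gap(y_true, y_pred, np.array([[1], [0], [1], [0]])))  # 0.0: predictions track the attribute
```

Under the model-selection protocol described above, FG would be computed from validation-set predictions of each trained model, and the model with the highest FG would be kept.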