Groupified-VAE
Repo for Groupified VAE: "GroupifyVAE: from Group-based Definition to VAE-based Unsupervised Representation Disentanglement"
The key idea of the state-of-the-art VAE-based unsupervised representation disentanglement methods is to minimize the total correlation of the latent variable distributions. However, it has been proved that VAE-based unsupervised disentanglement cannot be achieved without introducing other inductive biases. In this paper, we address VAE-based unsupervised disentanglement by leveraging the constraints derived from the Group Theory based definition as the non-probabilistic inductive bias. More specifically, inspired by the nth dihedral group (the permutation group for regular polygons), we propose a specific form of the definition and prove its two equivalent conditions: isomorphism and "the constancy of permutations". We further provide an implementation of isomorphism based on two Group constraints: the Abel constraint for exchangeability and the Order constraint for cyclicity. We then convert them into a self-supervised training loss that can be incorporated into VAE-based models to bridge their gaps from the Group Theory based definition. We train 1800 models covering the most prominent VAE-based models on five datasets to verify the effectiveness of our method. Compared to the original models, the Groupified VAEs consistently achieve better mean performance with smaller variances, and make meaningful dimensions controllable.
Learning independent and semantic representations in which each individual dimension has an interpretable meaning, usually referred to as disentangled representation learning, is important for artificial intelligence research. Such disentangled representations are useful for many tasks (bengio2013representation), including domain adaptation (li2019cross; zou2020joint), zero-shot learning (lake2017building), and adversarial attacks (alemi2016deep). Intuitively, for a disentangled representation, each latent unit is sensitive only to changes in an individual generative factor. higgins2018towards propose a formal mathematical definition of disentangled representations from the perspective of Group Theory and Group Representation Theory, which is widely accepted (greff2019multi; mathieu2019disentangling; khemakhem2020variational). Recently, caselles2019symmetry; quessard2020learning propose to learn disentangled representations based on the group-based definition of higgins2018towards (the Group Theory terms used in this paper are explained in Appendix A). However, these methods only work on toy datasets and require the ground truth of the generative factors (the World States in higgins2018towards), as they rely on interaction with the environment. Subsequently, painter2020linear propose to estimate the actions (e.g., interactions with the environment) by policy gradient to get rid of the World States. However, interaction with the environment is still required.
Most of the state-of-the-art methods (higgins2016beta; burgess2018understanding; kim2018disentangling; chen2018isolating; kumar2017variational) are based on Variational Autoencoders (VAEs) (kingma2013auto), which learn disentangled representations from the perspective of probabilistic inference, i.e., by enforcing a factorized aggregated posterior. These methods are fully unsupervised and can be applied to a variety of complex datasets. However, based on Measure Theory, locatello2019challenging prove that VAE-based unsupervised disentanglement is fundamentally impossible without introducing inductive biases on both models and data. Therefore, it is necessary to find a non-probabilistic inductive bias and introduce it into VAE-based disentangled representation learning. Furthermore, these models do not consider the definition of higgins2018towards. Motivated by the above observations, we derive constraints from the group-based definition and introduce them into these VAE-based models as a non-probabilistic inductive bias to facilitate unsupervised disentanglement.

In Group Theory, the nth dihedral group (judson2020abstract) is the set of permutations of a regular polygon's vertices induced by its rotations and reflections, which forms a group under the operation of composition (miller1973symmetry). An nth dihedral group is generated by a few basic generators, a flip and a rotation, which can be regarded as the disentangled factors. Inspired by the nth dihedral group, which is generated by basic generators and is also a permutation group, we assume that the group $G$ in the definition of higgins2018towards (see Section 3.1) is a subgroup generated by disentangled factors. We further define a permutation group based on the existing VAE-based models using group actions (see Section 3.2 and Figure 2), which has the same form as the nth dihedral group. In this setting, as shown in Figure 1, we can find two conditions equivalent to the definition of higgins2018towards, as Theorem 1 in Section 3.2 shows: isomorphism (between $G$ and the permutation group) and the constancy of permutations (which requires the World States).
We then prove (in Section 3.4) that the isomorphism condition is equivalent to two Group constraints (which serve as self-supervision signals): the Abel constraint (for exchangeability) and the Order constraint (for cyclicity). We map the latent representation of the VAE into the group of n-th roots of unity by applying the sine and cosine functions (see Figure 2 (a) and Section 3.3), which makes the existing VAE-based models groupifiable (i.e., a permutation group can be defined based on them), so that these two constraints can be applied (see Figure 2 (b)) to meet one of the necessary conditions (i.e., isomorphism) of disentangled representation. With these two Group constraints as the non-probabilistic inductive bias, our method consistently achieves better performance on prominent metrics (higher means and lower variances for different VAE-based models on five datasets). Besides, we can even make the meaningful dimensions controllable for AnnealVAE (burgess2018understanding).
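To make the dihedral-group analogy concrete, the minimal Python sketch below (our illustration, not the authors' code) generates the nth dihedral group as permutations of the vertex labels 0..n-1 from two generators, a one-step rotation and a flip. The helper names `rotation`, `flip`, `compose`, and `generate_group` are hypothetical.

```python
def rotation(n):
    """Rotate a regular n-gon by one step: vertex k moves to position k + 1 (mod n)."""
    return tuple((k + 1) % n for k in range(n))


def flip(n):
    """Reflect the n-gon across the axis through vertex 0: vertex k -> -k (mod n)."""
    return tuple((-k) % n for k in range(n))


def compose(p, q):
    """Composition p ∘ q of permutations stored as tuples: apply q first, then p."""
    return tuple(p[q[k]] for k in range(len(q)))


def generate_group(generators):
    """Close a set of permutations under composition, starting from the identity.

    For a finite set of permutations this yields the subgroup they generate.
    """
    identity = tuple(range(len(generators[0])))
    group = {identity}
    frontier = [identity]
    while frontier:
        next_frontier = []
        for element in frontier:
            for s in generators:
                candidate = compose(s, element)
                if candidate not in group:
                    group.add(candidate)
                    next_frontier.append(candidate)
        frontier = next_frontier
    return group


n = 5
D_n = generate_group([rotation(n), flip(n)])
print(len(D_n))  # 10, i.e., 2n vertex permutations (out of 5! = 120)
print(compose(rotation(n), flip(n)) == compose(flip(n), rotation(n)))  # False
```

The two generators do not commute, so the dihedral group is non-abelian for n ≥ 3. The group $G$ assumed in this paper borrows only the "generated permutation group" form; being a direct product of cyclic groups, it is abelian.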
Our main contributions are summarized as:
For unsupervised representation disentanglement learning, to the best of our knowledge, we are the first to unify the formal group-based mathematical definition with the existing VAE-based probabilistic inference models by groupifying existing models.
We map the formal group-based mathematical definition into two specific conditions (isomorphism and constancy of permutations), where isomorphism acts as a non-probabilistic inductive bias, to facilitate unsupervised disentangled representation learning.
We propose to use the sine and cosine functions to make the existing VAE-based models groupifiable so that the Group constraints can be applied. We then convert the proposed isomorphism condition into a loss function specifically applicable to the VAE-based models.
Different definitions have been proposed for disentangled representation (bengio2013representation; higgins2018towards; suter2019robustly). However, only the group-based definition proposed by higgins2018towards focuses on the disentangled representation itself and is mathematically rigorous, and it is well accepted (caselles2019symmetry; quessard2020learning; painter2020linear; TopologDefects). Nevertheless, higgins2018towards do not propose a specific method based on their definition. Before this rigorous definition was proposed, there had been some success in identifying generative factors in static datasets (without an environment), e.g., β-VAE (higgins2016beta), AnnealVAE (burgess2018understanding), β-TCVAE (chen2018isolating), and FactorVAE (kim2018disentangling). These VAE-based unsupervised methods are based on probabilistic inference. Recently, locatello2019challenging proved via Measure Theory (the basis of Probability Theory) that these methods theoretically admit infinitely many solutions. Therefore, introducing a non-probabilistic inductive bias into them would help to reduce the solution space.
It is not straightforward to reconcile the probabilistic inference methods with the group-based definition framework (quessard2020learning). caselles2019symmetry; quessard2020learning; painter2020linear leverage interaction with the environment (assuming it is available) as supervision, instead of minimizing the total correlation as the VAE-based methods do. Consequently, the effectiveness of these methods is limited to datasets for which an environment is available. pfau2020disentangling propose a non-parametric method that learns, without supervision, linear disentangled planes in the data manifold under a metric by leveraging Lie groups. However, as pointed out by the authors, the method does not generalize to held-out data and performs poorly when trying to disentangle directly from pixels.

To summarize, the probabilistic inference methods lack theoretical support and a non-probabilistic inductive bias, while the application scope of existing methods based on the group-based mathematical definition (higgins2018towards) is limited. Therefore, unifying the probabilistic inference methods and the group-based definition framework is essential for learning disentangled representations. To the best of our knowledge, our work is the first to reconcile probabilistic generative methods with the inherently deterministic group-based definition framework of higgins2018towards.
We first review the group-based mathematical definition of disentangled representation (higgins2018towards) in Section 3.1. We then map the definition into two equivalent conditions: one related to the World States, the other not (isomorphism). The isomorphism is used as a non-probabilistic inductive bias for the VAE-based models (Section 3.2), after making the VAE-based models groupifiable (Section 3.3). Section 3.4 describes how to convert the isomorphism condition into a specific loss to be incorporated into existing VAE-based models.
We assume some basic familiarity with the fundamentals of Group Theory and Group Representation Theory. Please refer to Appendix A for some basic concepts. In this section, we briefly introduce the group-based definition of disentangled representation (higgins2018towards). Some of the settings and mathematical symbols in this section are used later when introducing our method.
Let $W$ be a set of World States. We assume that the data is obtained through a generation process $b: W \to O$, which maps from the World States to observations $O$ (we focus on images in this paper). Let $Z$ be a set of representations, and we have an inference process (done by the encoder) $h: O \to Z$, which maps from observations to representations. Then, we consider the function composition $f: W \to Z$, $f = h \circ b$.
Consider a group $G$ acting on $W$ and $Z$ via group actions $\cdot_W: G \times W \to W$ and $\cdot_Z: G \times Z \to Z$, respectively. We state that the mapping $f$ is equivariant between the actions on $W$ and $Z$ if
$$ f(g \cdot_W w) = g \cdot_Z f(w), \quad \forall g \in G, \; \forall w \in W. \qquad (1) $$
Assume $G$ can be decomposed as $G = G_1 \times G_2 \times \dots \times G_k$. The set $Z$ is disentangled with respect to $G$ if: (i) the mapping $f$ is equivariant between the actions on $W$ and $Z$; and (ii) there is a decomposition $Z = Z_1 \times Z_2 \times \dots \times Z_k$ such that each $Z_i$ is affected only by the corresponding $G_i$.
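As a toy illustration of Eq. (1) and the decomposition condition, the sketch below (our own example, not from the paper) takes two cyclic factors, $G = \mathbb{Z}_n \times \mathbb{Z}_n$, acting on World States by componentwise addition, and checks two hand-written maps $f: W \to Z$ directly (skipping the separate $b$ and $h$). All function names are hypothetical, and the decomposition check is a simplified version of condition (ii).

```python
from itertools import product

n = 4  # order of each cyclic factor (toy value)
W = list(product(range(n), repeat=2))  # World States, e.g. (x position, y position)
G = list(product(range(n), repeat=2))  # G = Z_n x Z_n = G_1 x G_2


def act_W(g, w):
    """Action of G on W: componentwise addition mod n."""
    return tuple((wi + gi) % n for wi, gi in zip(w, g))


def is_equivariant(f, act_Z):
    """Check Eq. (1): f(g ._W w) == g ._Z f(w) for every g in G and w in W."""
    return all(f(act_W(g, w)) == act_Z(g, f(w)) for g in G for w in W)


def only_own_factor_changes(f, act_Z):
    """Simplified condition (ii): every g in G_i (nonzero only in slot i) changes only z_i."""
    for i in range(2):
        for g in G:
            if any(g[j] != 0 for j in range(2) if j != i):
                continue  # g is not an element of G_i
            for w in W:
                z = f(w)
                z_new = act_Z(g, z)
                if any(z_new[j] != z[j] for j in range(2) if j != i):
                    return False
    return True


def f_good(w):
    """Disentangled toy map: each representation slot copies one World State slot."""
    return tuple(w)


def act_Z_good(g, z):
    """Componentwise addition on Z, mirroring act_W."""
    return tuple((zi + gi) % n for zi, gi in zip(z, g))


def f_bad(w):
    """Entangled toy map: the first slot mixes both World State coordinates."""
    return ((w[0] + w[1]) % n, w[1])


def act_Z_bad(g, z):
    """The action on Z that makes f_bad equivariant: G_2 also shifts the first slot."""
    return ((z[0] + g[0] + g[1]) % n, (z[1] + g[1]) % n)


print(is_equivariant(f_good, act_Z_good), only_own_factor_changes(f_good, act_Z_good))  # True True
print(is_equivariant(f_bad, act_Z_bad), only_own_factor_changes(f_bad, act_Z_bad))  # True False
```

The entangled map shows that equivariance alone is not enough: it can be made equivariant by choosing a suitable action on $Z$, but then a single subgroup affects more than one slot, so the decomposition condition fails.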
It is not straightforward to apply the definition above to existing VAE-based unsupervised disentanglement models. A specific form for the definition is required to introduce the definition into those VAE-based models as a non-probabilistic inductive bias.
As mentioned in Section 1, we assume $G$ is a subgroup generated by disentangled factors, i.e., a direct product of cyclic groups, one per factor (the number of factors is the factor size). To get a non-probabilistic inductive bias from the definition without requiring the World States, the key is to define a permutation group based on the VAE-based model, inspired by the nth dihedral group as discussed in Section 1. Considering an ideal VAE-based model, where the encoder and the decoder are inverses of each other, the group action of $G$ on the observations is defined as:
(2) |
where each element of $G$ induces a permutation on the observations, and an overline indicates a congruence class (an equivalence class under the congruence relation). The group action of $G$ on the representations is defined as addition on congruence classes. Therefore, the set of these permutations forms a permutation group under the operation of composition, as shown in Figure 2; it is a subgroup of the symmetric group on the observations. In this setting, the group-based definition is equivalent to two conditions, as shown in Theorem 1. The proof is provided in Appendix B.
Theorem 1. For the group $G$ and a decomposition of $G$, if there exists an isomorphism between $G$ and the permutation group, and the constancy of permutations holds: for and , we have , where is an unknown constant, then $Z$ is disentangled with respect to $G$.
denotes that is generated by generators (see the toy examples in Figure 2 (a)). The isomorphism condition requires no World States. Therefore, we can apply it to the existing VAE-based models as a self-supervision signal. The isomorphism condition is the non-probabilistic inductive bias.
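The isomorphism condition can be checked by brute force on a toy "ideal" model. The sketch below assumes, for illustration only, that an element $g$ acts on an image $o$ as `decode(g + encode(o))` with addition of congruence classes, and that the encoder is an exact bijection between a handful of labelled images and codes in $\mathbb{Z}_n \times \mathbb{Z}_n$. The image labels and helper names are hypothetical.

```python
import random
from itertools import product

n, k = 3, 2  # group order and number of factors (toy values)
codes = list(product(range(n), repeat=k))  # elements of Z_n x Z_n

# Hypothetical ideal VAE: a bijection between n**k toy "images" and codes,
# so the decoder is exactly the inverse of the encoder.
random.seed(0)
images = [f"img{idx}" for idx in range(len(codes))]
assigned = codes[:]
random.shuffle(assigned)
encode = dict(zip(images, assigned))        # encoder: image -> code
decode = {z: o for o, z in encode.items()}  # decoder: code -> image


def add(g, z):
    """Addition of congruence classes, componentwise mod n."""
    return tuple((a + b) % n for a, b in zip(g, z))


def psi(g):
    """Permutation of the images induced by g (assumed form): o -> decode(g + encode(o))."""
    return {o: decode[add(g, encode[o])] for o in images}


def compose(p, q):
    """p ∘ q: apply q first, then p."""
    return {o: p[q[o]] for o in q}


perms = {g: psi(g) for g in codes}
# g -> psi(g) respects the group operation ...
homomorphism = all(
    perms[add(g, h)] == compose(perms[g], perms[h]) for g in codes for h in codes
)
# ... and distinct elements of G give distinct permutations, so it is an
# isomorphism onto its image (the permutation group).
injective = len({tuple(sorted(p.items())) for p in perms.values()}) == len(codes)
print(homomorphism, injective)  # True True
```

With an exact encoder/decoder pair, both checks pass for any bijection. In a real VAE the decoder is only approximately the encoder's inverse, which is why the paper turns the isomorphism condition into a loss instead of assuming it.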
In this paper, “groupifying a VAE-based model” means forming a permutation group (isomorphic to $G$) based on the model, as discussed above. In Section 3.3, we discuss how to introduce congruence classes into existing VAE-based models to achieve the group action of $G$ on the representations, which is necessary for forming the permutation group. Then, in Section 3.4, we convert the isomorphism condition into a loss function, which can be applied to VAE-based models, as a necessary condition to achieve disentangled representations.
The existing VAE-based models cannot form the permutation group because they lack congruence classes. In this section, we map the representation of the VAE-based models to the group of n-th roots of unity to implement congruence classes in the models. Based on this, the permutation group can be defined and the isomorphism condition can be established.
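The congruence classes form a direct product of cyclic groups $\mathbb{Z}_n \times \dots \times \mathbb{Z}_n$, where $n$ is the order of each group. To implement congruence classes, we need a group that the representation can be mapped to with a differentiable function (to allow back-propagation) and that is isomorphic to $\mathbb{Z}_n$. From Group Theory, we know there is an isomorphism between $\mathbb{Z}_n$ and the group of n-th roots of unity $U_n = \{ u \in \mathbb{C} : u^n = 1 \}$:
$$ \phi: \mathbb{Z}_n \to U_n, \quad \phi(\overline{k}) = e^{2\pi i k / n}, \quad k = 0, 1, \dots, n-1. \qquad (3) $$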
Therefore, each dimension $z$ of the representation is mapped to the group of n-th roots of unity by the groupifiable function $z \mapsto e^{2\pi i z / n}$ (see Figure 2 (a)). However, this complex value cannot be fed to the decoder directly, so we use Euler's formula, $e^{i\theta} = \cos\theta + i \sin\theta$, to map it to its real and imaginary parts, i.e., the vectors $\cos(2\pi z / n)$ and $\sin(2\pi z / n)$ (computed elementwise over the representation). The two vectors are concatenated and fed to the decoder. In this way, congruence classes are implemented in the VAE-based models. In practice, a scaling coefficient is applied in the sine and cosine functions for better optimization (sitzmann2020implicit). We refer to such updated VAE-based models as groupifiable models.

The permutation group is similar to the nth dihedral group: it can be generated by several generators, which is guaranteed by the isomorphism condition (i.e., $G$ and the permutation group are isomorphic), and its elements are permutations. Inspired by this, we convert the isomorphism condition into two constraints on the generators, and further convert these constraints into an Isomorphism Loss on the groupifiable models we build. The group $G$ is a direct product of cyclic groups. Therefore, the permutation group, being isomorphic to $G$, is expected to be commutative, and each of its generators is expected to generate a cyclic subgroup of order $n$. The constraints on the generators thus include two parts: one for exchangeability, the other for cyclicity. Based on this observation, we derive two constraints that together are equivalent to the isomorphism condition, as shown in Theorem 2. Please refer to Appendix C for the proof.
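The mapping of Section 3.3 can be sketched in a few lines of Python. This is a minimal illustration assuming the per-dimension form $z \mapsto e^{2\pi i z / n}$; the names `phi` and `groupify` are hypothetical, and the scaling coefficient mentioned above is deliberately left out. It checks that Eq. (3) turns addition in $\mathbb{Z}_n$ into multiplication of roots of unity, and that congruent latent values give identical decoder inputs.

```python
import cmath
import math


def phi(k, n):
    """Eq. (3): map the congruence class of k in Z_n to e^{2*pi*i*k/n}."""
    return cmath.exp(2j * math.pi * k / n)


def groupify(z, n):
    """Map a real latent vector z to [cos(2*pi*z/n), sin(2*pi*z/n)], concatenated.

    These are the real and imaginary parts of e^{2*pi*i*z/n} (Euler's formula);
    the scaling coefficient mentioned in the text is omitted here.
    """
    theta = [2.0 * math.pi * zi / n for zi in z]
    return [math.cos(t) for t in theta] + [math.sin(t) for t in theta]


n = 5
# phi turns addition of classes in Z_n into multiplication of roots of unity.
hom_ok = all(
    cmath.isclose(phi((a + b) % n, n), phi(a, n) * phi(b, n), abs_tol=1e-12)
    for a in range(n)
    for b in range(n)
)
# Congruent latent values (z and z + n) produce the same decoder input.
z = [0.3, 2.7]
same = all(
    math.isclose(u, v, abs_tol=1e-9)
    for u, v in zip(groupify(z, n), groupify([zi + n for zi in z], n))
)
print(hom_ok, same)  # True True
```

Because sine and cosine are periodic and differentiable, the decoder sees only the congruence class of each latent value while gradients still flow back to the encoder.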
Theorem 2. The permutation group we defined is isomorphic to $G$ if and only if: (i) any two generators commute under composition; and (ii) composing any generator with itself $n$ times yields the identity element of the permutation group.
The first condition in Theorem 2 states that applying one generator permutation and then another to images should give the same result as applying them in the reverse order. Intuitively, a good disentanglement model should meet this constraint. The second states that applying a generator permutation to images $n$ times leaves them unchanged. We refer to these two conditions as the Group constraints.
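The two Group constraints can be checked exhaustively on small permutation groups. The sketch below (our illustration; helper names are hypothetical) represents permutations as dictionaries, tests pairwise commutation and the "n-fold composition is the identity" property, and contrasts unit shifts in $\mathbb{Z}_n \times \mathbb{Z}_n$, which satisfy both, with the dihedral rotation and flip, which satisfy neither.

```python
from itertools import product


def compose(p, q):
    """p ∘ q for permutations stored as dicts: apply q first, then p."""
    return {x: p[q[x]] for x in q}


def power(p, m):
    """Compose p with itself m times (m >= 0); power(p, 0) is the identity."""
    result = {x: x for x in p}
    for _ in range(m):
        result = compose(p, result)
    return result


def group_constraints(generators, n):
    """Return (abel_ok, order_ok) for the two conditions of Theorem 2."""
    identity = {x: x for x in generators[0]}
    abel_ok = all(compose(a, b) == compose(b, a) for a in generators for b in generators)
    order_ok = all(power(a, n) == identity for a in generators)
    return abel_ok, order_ok


n, k = 4, 2
points = list(product(range(n), repeat=k))  # toy images indexed by their codes
# Generator i shifts the i-th code coordinate by one step (mod n).
shifts = [
    {p: tuple((c + (j == i)) % n for j, c in enumerate(p)) for p in points}
    for i in range(k)
]
print(group_constraints(shifts, n))  # (True, True)

# Counterexample: the dihedral rotation and flip of a pentagon violate both.
rot = {v: (v + 1) % 5 for v in range(5)}
flip = {v: (-v) % 5 for v in range(5)}
print(group_constraints([rot, flip], 5))  # (False, False)
```

The flip fails the Order constraint because composing it with itself twice, not five times, gives the identity, so five compositions leave the flip itself.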
In the following, we convert the Group constraints into losses and apply them to the existing VAE-based groupifiable models. The first constraint concerns exchangeability, which requires the group to be abelian (judson2020abstract). Therefore, we name it the Abel constraint and the loss derived from it the Abel Loss. The second concerns cyclicity, a constraint on the order of elements. We thus name it the Order constraint and the loss derived from it the Order Loss.
Abel Loss. For a given VAE-based model, the generator of group is defined as , where is identical element of dimension in and denotes group actions of on , as shown in Figure 2 (a). We implement by adding the action scale on the -th dimension of , then mapping it to by the sine and cosine functions. Here we focus on the Abel constraint: , we have . We minimize to make the first condition satisfied.
Denote the set of factors learned by a VAE-based model as . The Abel Loss must enforce exchangeability for every pair of learned factors. Therefore, the Abel Loss is the sum of the losses over all pairwise combinations of factors. Denote the set containing these combinations as . Finally, the Abel Loss of the groupifiable VAE-based model is as follows:
(4) |
where represents the upper path of Figure 2 (b), and represents the lower path of Figure 2 (b). See Appendix E for details of and .
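The sketch below illustrates the idea of the Abel Loss under explicit assumptions: it is not the paper's Eq. (4) or its Appendix E implementation. We assume a generator acts on an image by encoding it, adding the action scale to one latent dimension, groupifying, and decoding. We also assume the loss is a squared distance between the "factor i then factor j" and "factor j then factor i" results, summed over all factor pairs. The toy encoder reads angles back from a groupified "image", and a `coupling` parameter (hypothetical) makes it entangled.

```python
import math
from itertools import combinations


def groupify(z, n):
    """Concatenate cos and sin of 2*pi*z/n (the mapping of Section 3.3)."""
    theta = [2 * math.pi * zi / n for zi in z]
    return [math.cos(t) for t in theta] + [math.sin(t) for t in theta]


def act(encode, decode, x, i, scale, n):
    """Generator i acting on an image: encode, add `scale` to latent dim i, groupify, decode."""
    z = list(encode(x))
    z[i] += scale
    return decode(groupify(z, n))


def abel_loss(encode, decode, x, num_factors, scale, n):
    """Sum over factor pairs (i, j) of the squared distance between the
    'i then j' result and the 'j then i' result."""
    loss = 0.0
    for i, j in combinations(range(num_factors), 2):
        i_then_j = act(encode, decode, act(encode, decode, x, i, scale, n), j, scale, n)
        j_then_i = act(encode, decode, act(encode, decode, x, j, scale, n), i, scale, n)
        loss += sum((u - v) ** 2 for u, v in zip(i_then_j, j_then_i))
    return loss


def make_encoder(n, k, coupling=0.0):
    """Toy encoder reading angles back from a groupified 'image';
    coupling > 0 leaks latent dim 1 into dim 0 (an entangled encoder)."""
    def encode(x):
        z = [n * math.atan2(x[k + i], x[i]) / (2 * math.pi) for i in range(k)]
        z[0] += coupling * z[1]
        return z
    return encode


def decode(v):
    """Toy decoder: the 'image' is the groupified vector itself."""
    return list(v)


n, k, scale = 5, 2, 1.0
x = groupify([0.4, 1.1], n)
print(abel_loss(make_encoder(n, k), decode, x, k, scale, n) < 1e-12)  # True
print(abel_loss(make_encoder(n, k, coupling=0.3), decode, x, k, scale, n) > 1e-3)  # True
```

For the exact encoder/decoder pair, the two orders agree and the loss vanishes. The coupled encoder makes the result depend on the order of the actions, which this loss penalizes.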
Order Loss. For the Order constraint: , we have , where is the identity element in group (identity mapping). is rewritten as , where the exponent denotes composition rather than multiplication. Note that with $n$ compositions, it is difficult for the gradient to back-propagate. Recall that the decoder is approximately the inverse of the encoder, so can be approximated as , where . With this approximation, only two compositions remain, which is easy to optimize. Besides, because
(5) |
where is the inverse of . Therefore, we implement by adding on the -th dimension of (let ), and mapping it to by the sine and cosine functions. Similar to Abel Loss, we minimize to make the second condition satisfied. The whole process is illustrated in Figure 2 (b). However, this conversion is not symmetrical, which leads to bias in the optimization process. Therefore, we convert the Order constraint into a symmetrical form: . There is an Order Loss for each element of . Therefore, the Order Loss is as follows:
(6) |
See Appendix E for the implementation details. With the above two loss functions optimized, the isomorphism condition is satisfied, as stated in Theorem 3. Please refer to Appendix D for the proof.
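To illustrate why two compositions can stand in for $n$ compositions, the sketch below compares an exact Order Loss, which applies a unit generator $n$ times through the model, with a two-composition surrogate that jumps $m$ steps and then $n - m$ steps in latent space. The surrogate, its symmetrization over both orders, and the parameter `m` are assumptions for illustration; the paper's actual form is given by Eqs. (5)-(6) and Appendix E, which we do not reproduce. The toy encoder and decoder follow the same assumptions as a groupified "image", and `gain` (hypothetical) miscalibrates the encoder.

```python
import math


def groupify(z, n):
    """Concatenate cos and sin of 2*pi*z/n (the mapping of Section 3.3)."""
    theta = [2 * math.pi * zi / n for zi in z]
    return [math.cos(t) for t in theta] + [math.sin(t) for t in theta]


def shift(encode, decode, x, i, steps, n):
    """One pass through the model: encode, add `steps` to latent dim i, groupify, decode."""
    z = list(encode(x))
    z[i] += steps
    return decode(groupify(z, n))


def sq_dist(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((u - v) ** 2 for u, v in zip(a, b))


def order_loss_exact(encode, decode, x, k, n):
    """Apply each unit generator n times (n compositions) and compare with x."""
    total = 0.0
    for i in range(k):
        y = x
        for _ in range(n):
            y = shift(encode, decode, y, i, 1, n)
        total += sq_dist(y, x)
    return total


def order_loss_two_step(encode, decode, x, k, n, m):
    """Hypothetical two-composition surrogate: jump m steps, then n - m steps,
    in both orders (a symmetrized form assumed for illustration)."""
    total = 0.0
    for i in range(k):
        a = shift(encode, decode, shift(encode, decode, x, i, m, n), i, n - m, n)
        b = shift(encode, decode, shift(encode, decode, x, i, n - m, n), i, m, n)
        total += sq_dist(a, x) + sq_dist(b, x)
    return total


def make_encoder(n, k, gain=1.0):
    """Toy encoder reading angles from a groupified 'image'; gain != 1 miscalibrates it."""
    def encode(x):
        return [gain * n * math.atan2(x[k + i], x[i]) / (2 * math.pi) for i in range(k)]
    return encode


def decode(v):
    """Toy decoder: the 'image' is the groupified vector itself."""
    return list(v)


n, k = 5, 2
x = groupify([0.4, 1.1], n)
ideal = make_encoder(n, k)
print(order_loss_exact(ideal, decode, x, k, n) < 1e-12)  # True
print(order_loss_two_step(ideal, decode, x, k, n, m=2) < 1e-12)  # True
print(order_loss_two_step(make_encoder(n, k, gain=0.9), decode, x, k, n, m=2) > 1e-3)  # True
```

With an exact encoder/decoder pair, both versions vanish. The surrogate needs only two passes through the model per factor instead of $n$, which keeps the gradient path short.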
Since the Abel Loss and the Order Loss are equally important for meeting the isomorphism condition, we assign them equal weights. Thus, the Isomorphism Loss is the sum of the Abel Loss and the Order Loss. We name a groupifiable VAE optimized with this loss a groupified VAE.
We first quantitatively verify the effectiveness of groupified models in enhancing the ability to learn disentangled representations on several datasets and models. Then, we show their effectiveness qualitatively on two typical datasets. After that, we perform a case study on the dSprites dataset to analyze the effectiveness, followed by ablation studies on the losses and hyperparameters. Furthermore, we evaluate the performance of two downstream tasks trained on the representations learned by the groupified models: abstract reasoning (van2019disentangled) and fairness evaluation (locatello2019fairness). For more comprehensive results, please see Appendix G.

To evaluate our method, we consider several datasets: dSprites (higgins2016beta), Shapes3D (kim2018disentangling), Cars3D (car3d), and two variants of dSprites introduced by locatello2019challenging: Color-dSprites and Noisy-dSprites. Please refer to Appendix F for the details of the datasets.
We choose the following four baseline methods, which have had broad impact, as representatives of the existing VAE-based models, and verify the effectiveness of our method after groupifying them. β-VAE (higgins2016beta) introduces a hyperparameter β in front of the KL regularizer of the VAE loss, which constrains the VAE information capacity to learn the most efficient representation. AnnealVAE (burgess2018understanding) progressively increases the bottleneck capacity so that the encoder learns new factors of variation while retaining disentanglement in previously learned factors. FactorVAE (kim2018disentangling) and β-TCVAE (chen2018isolating) both penalize the total correlation (watanabe1960information), but estimate it with adversarial training (nguyen2010estimating; sugiyama2012density) and a Monte Carlo estimator, respectively.
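For reference, the regularizers of the first two baselines can be written compactly. The sketch below uses the standard closed-form KL divergence between a diagonal Gaussian posterior and a standard normal prior. β-VAE scales it by β. AnnealVAE, following burgess2018understanding, penalizes $\gamma\,|\mathrm{KL} - C|$ with a capacity target $C$ that grows during training; we assume a linear schedule for $C$. The function names and the precomputed `recon_loss` argument are hypothetical. FactorVAE and β-TCVAE additionally require a total-correlation estimator and are not shown.

```python
import math


def gaussian_kl(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dimensions."""
    return sum(0.5 * (m * m + math.exp(lv) - lv - 1.0) for m, lv in zip(mu, logvar))


def beta_vae_loss(recon_loss, mu, logvar, beta):
    """beta-VAE: reconstruction term plus beta times the KL regularizer."""
    return recon_loss + beta * gaussian_kl(mu, logvar)


def capacity(step, c_max, total_steps):
    """Capacity target C, assumed to grow linearly from 0 and then stay at c_max."""
    return c_max * min(step / total_steps, 1.0)


def anneal_vae_loss(recon_loss, mu, logvar, gamma, step, c_max, total_steps):
    """AnnealVAE: penalize the gap |KL - C| with a gradually increasing C."""
    c = capacity(step, c_max, total_steps)
    return recon_loss + gamma * abs(gaussian_kl(mu, logvar) - c)


mu, logvar = [1.0, 0.0], [0.0, 0.0]  # KL = 0.5 for this toy posterior
print(beta_vae_loss(10.0, mu, logvar, beta=4.0))  # 12.0
print(anneal_vae_loss(10.0, mu, logvar, 100.0, step=0, c_max=0.5, total_steps=1000))  # 60.0
print(anneal_vae_loss(10.0, mu, logvar, 100.0, step=1000, c_max=0.5, total_steps=1000))  # 10.0
```

Early in AnnealVAE training, $C$ is small and any encoded information is penalized heavily. As $C$ grows, the model is allowed to encode more information, which is the "progressively increasing bottleneck capacity" described above.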
dSprites | DCI (Original) | DCI (Groupified) | BetaVAE (Original) | BetaVAE (Groupified) | MIG (Original) | MIG (Groupified) | FactorVAE (Original) | FactorVAE (Groupified)
---|---|---|---|---|---|---|---|---
β-VAE | | | | | | | |
AnnealVAE | | | | | | | |
FactorVAE | | | | | | | |
β-TCVAE | | | | | | | |
Cars3d | DCI (Original) | DCI (Groupified) | BetaVAE (Original) | BetaVAE (Groupified) | MIG (Original) | MIG (Groupified) | FactorVAE (Original) | FactorVAE (Groupified)
---|---|---|---|---|---|---|---|---
β-VAE | | | | | | | |
AnnealVAE | | | | | | | |
FactorVAE | | | | | | | |
β-TCVAE | | | | | | | |
Shapes3d | DCI (Original) | DCI (Groupified) | BetaVAE (Original) | BetaVAE (Groupified) | MIG (Original) | MIG (Groupified) | FactorVAE (Original) | FactorVAE (Groupified)
---|---|---|---|---|---|---|---|---
β-VAE | | | | | | | |
AnnealVAE | | | | | | | |
FactorVAE | | | | | | | |
β-TCVAE | | | | | | | |
Original | Group Order | Action Scale | Group Order , Action Scale | ||||||||
Groupifiable | w/o Abel | w/o Order | Groupified | ||||||||
DCI |
This section performs quantitative evaluations on the datasets and models introduced above with different random seeds and hyperparameters. We evaluate the performance of the original and groupified models in terms of several popular metrics: BetaVAE score (higgins2016beta), DCI disentanglement (eastwood2018framework) (DCI for short), MIG (chen2018isolating), and FactorVAE score (kim2018disentangling). We assign three or four hyperparameter settings to each model on each dataset, and run each hyperparameter setting with random seeds to minimize the influence of random seeds. Therefore, we run models in total. We evaluate each metric's mean and variance for each model on each dataset to demonstrate the effectiveness of our method. As shown in Table 1, the groupified models perform better (numbers marked bold in Table 1) than the original models in almost all cases. Please refer to Appendix G for more results.
On Shapes3d, the groupified models outperform the original ones on all the metrics except for BetaVAE scores, suggesting that there is some disagreement between BetaVAE scores and other metrics. Similar disagreement is also observed between the variances of MIG and other metrics on Cars3d. Note that the qualitative evaluation in Appendix H shows that the disentanglement ability of groupified VAEs is better on Shapes3d and Cars3d.
(a) Original. (b) Groupified.
We qualitatively show that the groupified VAEs achieve better disentanglement than the original ones. As shown in Figure 6, the traversal results of groupified β-TCVAE on Shapes3d and Cars3d are less entangled. For more qualitative evaluation, please refer to Appendix H.
To demystify how groupifying helps the VAE-based models improve their disentanglement ability, we take dSprites as an example, visualize the learned latent space, and show the typical score distributions of the metrics. First, we visualize the space spanned by the three most dominant factors (x position, y position, and scale). Figure 4 is plotted as follows: we traverse all the related World States to get the images, feed them into the encoder to get the representations, take the three corresponding dimensions as the coordinates of points in 3D space, and use different colors to indicate different images. As shown in Figure 4 (for more results, please refer to Appendix J), the spaces learned by the original models collapse, while the spaces of the groupified models bend only slightly. We attribute this mainly to the Isomorphism Loss, which serves as a self-supervision signal that suppresses distortion of the representation space and encourages disentanglement of the learned factors. As Figure 3 shows, the groupified models consistently achieve better mean performance with smaller variances. The Group constraints reduce the search space of the network so that the groupified model converges to the ideal disentanglement solution with higher probability.
An interesting observation is that meaningful dimensions can be controlled in groupified AnnealVAE. As shown in Figure 5, the KL divergence increases continuously on the assigned dimensions after the Isomorphism Loss is applied to them. Note that the KL divergence loss in AnnealVAE indicates the amount of information encoded. As Figure 5 (b) shows, the KL divergence of the assigned dimensions increases at the beginning of training, which means that the Isomorphism Loss makes the assigned dimensions meaningful. Finally, the first five assigned dimensions learn to encode x position, y position, scale, and orientations. For more results, please refer to Appendix I. The underlying reason is that the Isomorphism Loss acts as a cycle-consistency constraint: models are forced to encode information in the assigned dimensions to satisfy cycle consistency during optimization. The latent factors are learned and disentangled in the assigned dimensions due to the Isomorphism Loss and the total correlation penalty.
We perform an ablation study on the action scale, the Group order $n$, the Abel Loss, and the Order Loss. We take the AnnealVAE trained on dSprites as an example and only consider the DCI disentanglement metric here. To illustrate the influence of the action scale, we vary from to with set to , and the full Isomorphism Loss is applied. Similarly, we fix the action scale to and investigate the influence of the Group order $n$. Besides, to evaluate the effectiveness of the two constraints, we also evaluate models with the Abel Loss alone or the Order Loss alone. In this setting, we fix to 5 and to 10. We compute the mean and variance of the performance over 30 settings of hyperparameters and random seeds. Table 2 shows that a larger action scale leads to better performance, as a larger action scale makes the isomorphism condition harder to satisfy, which requires better disentanglement. The isomorphism condition acts as a cycle-consistency constraint in the latent space, leading to better disentanglement. The performance is robust to the Group order $n$, as the models learn to adapt to different values of $n$ during training. The models with only the Abel Loss or only the Order Loss applied improve over the originals. The former performs better than the latter, suggesting that exchangeability plays a more important role.
Abstract reasoning | Unfairness scores | |
---|---|---|
Original | ||
Groupified |
As pointed out by locatello2019challenging, the usefulness of disentangled representations for downstream tasks should also be verified. Therefore, we evaluate the representations learned by the groupified VAE-based models on Shapes3d in two downstream tasks: abstract reasoning (van2019disentangled) and fairness evaluation (locatello2019fairness). As Table 3 shows, the abstract reasoning models fine-tuned on the representations learned by the groupified FactorVAEs perform better than those fine-tuned on the representations of the original ones. In terms of fairness evaluation, the unfairness scores of the representations learned by the groupified FactorVAEs are lower than those of the original ones.
In this paper, we unify the group-based mathematical definition of disentangled representation with the existing VAE-based probabilistic inference models by groupifying them. Inspired by the nth dihedral group, we map the group-based mathematical definition of unsupervised representation disentanglement into two equivalent conditions: isomorphism and the constancy of permutations. We use the former as a self-supervision signal, which is converted into Group constraints. These constraints are converted into the Abel Loss for exchangeability and the Order Loss for cyclicity, whose sum is the Isomorphism Loss. After making the VAE-based models groupifiable by mapping the representation into the group of n-th roots of unity, we form a VAE-based permutation group and apply the Isomorphism Loss to meet the isomorphism condition. Our extensive experiments show that the Group constraints narrow down the optimization space and that, by incorporating this non-probabilistic inductive bias, the groupified VAEs achieve better average performance with smaller variances and controllable meaningful dimensions. For future work, extending GroupifyVAE by leveraging Lie groups (hall2015lie) (which are also manifolds) is an interesting direction.