Intuitively speaking, a disentangled representation can be defined as a (usually low-dimensional) encoding of a data sample, where distinct components of the encoding are responsible for encoding a specific generative factor of the data. Despite several attempts in the literature, coming up with a formal definition of what disentanglement actually is has proven more difficult than expected [do2019theory]. Several works simply assume that a disentangled representation is one in which a single latent dimension is responsible for encoding a single generative factor of the data. This definition, while easy to formalise mathematically, has proven too restrictive in general. Recently, [2022_valenti_leveraging] relaxed this definition by introducing weakly disentangled representations, where each generative factor can be encoded in a different region of the latent space without imposing additional limitations on their dimensionality. Despite the advantages of this new approach, the initial implementation of [2022_valenti_leveraging] suffered from the fact that the number of annotations required for achieving weak disentanglement grew very quickly with the number of generative factors.
In this paper, we address this limitation by introducing modular representations for weak disentanglement. In a modular representation, each partition of the latent space encodes its respective generative factor with its own adaptive prior distribution, independent of the others. We show that models that use modular representations are able to accurately perform controlled manipulations of the learned generative factors of the data without increasing the amount of supervised information.
2 Related Works
Early methods for disentanglement are mainly concerned with increasing the prior regularisation in the loss function [2017_higgins_beta-vae, 2018_burgess_understanding]. Another line of work [2018_kim_factor_vae, 2018_chen_beta-TCVAE, 2019_zhao_infovae] penalises different terms of the same loss function in various ways. These works define disentanglement using simple mathematical notions (e.g. total correlation of the latent dimensions). After the results of [2019_locatello_challenging], showing that purely unsupervised disentanglement is in general impossible to achieve, many works started using various degrees of supervised information [2017_lample_fader], either in the form of complete supervision on a small subset of training data [2019_locatello_few-labels] or partial annotations on a subset of generative factors [gabbay2021image].
Some works use additional classifier networks on the latent space in order to separate different parts of the latent codes. While useful, these methods are not practical when multiple factors of variation need to be disentangled at the same time. Other methods for introducing implicit supervision involve dropping the i.i.d. assumption by leveraging relational information between the samples. The relational information can be group-wise [locatello2020without_compromises], pair-wise [chen2020pairwise_similarity], or structural [bai2021contrastively]. Recently, [2022_valenti_leveraging] introduced the concept of weak disentanglement, overcoming many of the above limitations. However, their method requires an increasing amount of supervision as the number of generative factors increases.
3 Modular Representations for Weak Disentanglement
A general overview of the model’s architecture is illustrated in Fig. 1. We frame our representation learning problem as an auto-encoding task: given a data sample, we want to output a faithful reconstruction of it. The encoder network takes a data sample as input and produces as many latent codes as there are generative factors in the data. Conversely, the decoder network combines these partial latent codes to reconstruct the initial input. The encoder and the decoder each have their own parameters. The resulting Modular AutoEncoder (M-AE) model is then trained using a maximum likelihood objective (Eq. 1).
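As a reading aid, one plausible form of the objective in Eq. 1 is sketched below. The notation is assumed rather than taken from the source: $x$ is a data sample, $z_1,\dots,z_K$ are the $K$ partial latent codes, $q_\phi$ and $p_\theta$ are the encoder and decoder distributions, $q_\phi(z_i)$ is the aggregate posterior of partition $i$, $p(z_i)$ is its prior, and $\beta$ is a weighting parameter whose exact placement is a guess.

```latex
\mathcal{L}_{\text{M-AE}}(\phi, \theta) =
  \mathbb{E}_{x}\Big[\, \mathbb{E}_{q_\phi(z_1,\dots,z_K \mid x)} \big[ \log p_\theta(x \mid z_1,\dots,z_K) \big] \Big]
  \;-\; \beta \sum_{i=1}^{K} \mathrm{KL}\big( q_\phi(z_i) \,\|\, p(z_i) \big)
```

In this form the objective is maximised: the first term rewards faithful reconstructions and the weighted sum penalises partitions whose aggregate posterior drifts away from its prior.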
The first term of Eq. 1 is directly responsible for ensuring a good reconstruction of the original input. The second term, a sum of KL divergences between the aggregate posteriors and the priors of the latent partitions, encourages each partition of the latent space to follow a specific prior distribution. These priors are directly inspired by the data and are enforced adversarially, similarly to GANs [2014_goodfellow_gan]. In particular, this term is optimised via an additional discriminator network with its own parameters.
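The adversarial estimate of the prior-matching term can be sketched in the style of adversarial autoencoders. The symbols are assumptions: $D_\psi$ is the discriminator, $e_\phi(x)_i$ is the $i$-th partial code produced by the encoder, and $p(z_i)$ is the prior of partition $i$. Whether a single discriminator is shared across partitions is not stated in the text, so the shared form below is only one possibility.

```latex
\min_{\phi}\, \max_{\psi}\; \sum_{i=1}^{K} \Big(
  \mathbb{E}_{z_i \sim p(z_i)} \big[ \log D_\psi(z_i) \big]
  + \mathbb{E}_{x} \big[ \log \big( 1 - D_\psi( e_\phi(x)_i ) \big) \big]
\Big)
```

The discriminator learns to tell prior samples from encoded samples, while the encoder is updated so that its partial codes become indistinguishable from samples of the chosen prior.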
This adversarial loss allows us to choose the most suitable prior distribution for each partition. In particular, since our goal is to identify all possible real-world instances of a particular value of a generative factor, we model the prior of each partition as a mixture of normal distributions, with one component for each value that the corresponding factor can take. We build a different mixture for each partition. Specifically, the parameters of each prior’s components are empirically estimated using a small subset of annotated samples: each component is estimated from the subset of (encoded) supervised samples where the factor takes the corresponding value. Since each latent partition encodes a different generative factor, we can re-use the same annotated samples for computing the different parts of the prior.
The second part of the model is the Relational Learner (ReL). During training, the ReL learns how to perform controlled changes to specific properties of a data sample by leveraging the representations learned by the M-AE. The ReL is composed of a relational sub-network with its own parameters. We assume that the relation to be learned affects only the value of a single factor of variation.
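One way to write a relational objective that matches the description in this section is sketched below. This is an assumed form, not the paper's equation: $z_1,\dots,z_K$ are the input partitions, $z'_1,\dots,z'_K$ are the partitions output by the relational learner $r_\omega$, $i$ is the index of the factor affected by the relation, $k$ is its current value, $f$ maps $k$ to the value produced by the relation, and $\mu_{i,f(k)}, \Sigma_{i,f(k)}$ are the parameters of the corresponding prior component.

```latex
\mathcal{L}_{\text{ReL}}(\omega) =
  -\log \mathcal{N}\big( z'_i \,;\, \mu_{i,f(k)},\, \Sigma_{i,f(k)} \big)
  \;+\; \sum_{j \neq i} \big\lVert z'_j - z_j \big\rVert_2^2
```

Under this reading, the first term pulls the affected partition towards the component of the new value, and the second keeps the remaining partitions unchanged.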
The relational objective is computed on the output of the relational learner. It relies on a function that defines the “connections” between the prior components that correspond to a specific relation. This loss function encourages the partition affected by the relation to match the prior of the new value of that factor, while the other partitions remain unchanged, and it can easily be extended to relations that affect multiple factors. The correspondence between components of the prior and generative factor values is made possible by the representation learned by the M-AE. Finally, training is done end-to-end by combining the previous losses.
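To make the prior construction and the relation mechanism of this section more concrete, the following NumPy sketch estimates the mixture components of one partition from annotated codes and builds a target code for a relation. It is a minimal illustration under stated assumptions: diagonal covariances, a target drawn from the new component, and the helper names `estimate_components`, `sample_component`, `relation_target` and `connect` are all hypothetical and do not come from the released code.

```python
import numpy as np


def estimate_components(codes, labels, num_values):
    """Estimate one Gaussian component per factor value for a single partition.

    codes:  (n, d) encoded annotated samples of this partition.
    labels: (n,) integer factor values in [0, num_values).
    Assumes a diagonal covariance and at least one sample per value.
    """
    means = np.stack([codes[labels == k].mean(axis=0) for k in range(num_values)])
    stds = np.stack([codes[labels == k].std(axis=0) for k in range(num_values)])
    return means, stds


def sample_component(means, stds, k, rng):
    """Draw one latent code from component k of the mixture."""
    return rng.normal(means[k], stds[k])


def relation_target(partitions, priors, factor, value, connect, rng):
    """Build a target code for a relation that changes a single factor.

    partitions: list of (d,) arrays, one per generative factor.
    priors:     list of (means, stds) pairs, one per partition.
    connect:    maps the current value index to the value after the relation
                (the "connections" between prior components).
    Only the affected partition changes; the others are copied as they are.
    """
    new_value = connect(value)
    target = [p.copy() for p in partitions]
    means, stds = priors[factor]
    target[factor] = sample_component(means, stds, new_value, rng)
    return target, new_value
```

Because every partition has its own mixture, the same annotated samples can be passed once per factor with a different `labels` array, which mirrors the re-use of annotations described above.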
We consider two disentanglement tasks based on the dSprites [dsprites17] and Shapes3D [3dshapes18] datasets, containing respectively 2D and 3D images of shapes that express different combinations of generative factors (shape, x/y-position, scale, and orientation for dSprites; floor-color, shape-color, background-color, and orientation for Shapes3D). We consider all the relations that change a single generative factor of the data (e.g. move-left, move-right, +hue, change-shape, etc.). No restriction is imposed on the nuisance factors, which are free to vary when relations are applied to the latent codes. For each dataset, we construct three versions of increasing complexity, characterised by different choices of relevant and nuisance factors.
The M-AE encoder and decoder are implemented as CNNs (all the code of the models and the experiments is publicly available at https://github.com/Andrea-V/Weak-Disentanglement), while the prior’s discriminator and the ReL are 3-layer MLPs with 1024 units each. We use 8-dimensional latent codes for each generative factor, for a maximum size of the latent space of . All tasks use a batch size of 1024 for the M-AE and 128 for the ReL. The parameter of Eq. 1 is set to . The optimiser used for all modules is Adam with a learning rate of . Training is divided into two stages. In the first stage, called warmup, only the M-AE is trained. The prior is set to . After epochs we enter the full training stage, where the prior of each latent partition is set to the adaptive prior described in Sec. 3. We construct a different prior for each generative factor, leveraging the annotations of the supervised subset. At the same time, the training of the ReL begins: the input data samples are constructed as triples made of a relation and the encoded input and output samples for that relation. The latent codes are sampled from their respective components in the latent space. The concurrent training of the M-AE and the ReL is carried out during the full training phase for 5000 additional epochs.
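The two-stage schedule described above can be summarised with a small helper. This is only a sketch: `StagePlan`, `stage_plan` and `latent_size` are hypothetical names, and the warmup length is a required argument because its value is not given here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StagePlan:
    train_mae: bool   # the M-AE is trained in both stages
    train_rel: bool   # the ReL is trained only in the full stage
    prior: str        # "warmup" prior or per-partition "adaptive" prior


def latent_size(num_factors, code_dim=8):
    """Total latent size with one code of `code_dim` dimensions per factor."""
    return num_factors * code_dim


def stage_plan(epoch, warmup_epochs):
    """Return what is trained, and with which prior, at a given epoch."""
    if epoch < warmup_epochs:
        return StagePlan(train_mae=True, train_rel=False, prior="warmup")
    return StagePlan(train_mae=True, train_rel=True, prior="adaptive")
```

A training loop would call `stage_plan(epoch, warmup_epochs)` at the start of each epoch and run for `warmup_epochs + 5000` epochs in total, following the description above.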
|Previous Work [2022_valenti_leveraging]||This Work|
Latent Codes Manipulation.
In this first set of experiments, we analyse how well suited modular representations are for performing controlled changes of generative factors in the latent codes. We compute the relation accuracy of the ReL by first sampling a latent code from the prior, then applying a random relation and checking the outcome. The results are reported in Table 1 and compared with the previous work of [2022_valenti_leveraging]. They show that modular representations are beneficial for the accuracy of the ReL, while not requiring an increasing amount of supervised data when the number of factor value combinations increases.
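The relation accuracy procedure can be sketched as follows for a single partition and a single relation. How the outcome is checked is not detailed here, so the nearest-component rule below is an assumption, and `nearest_component`, `relation_accuracy`, `rel_fn` and `connect` are hypothetical names.

```python
import numpy as np


def nearest_component(code, means):
    """Index of the prior component whose mean is closest to `code`."""
    return int(np.argmin(np.linalg.norm(means - code, axis=1)))


def relation_accuracy(rel_fn, means, stds, connect, num_trials, rng):
    """Fraction of sampled codes that land in the expected component.

    rel_fn:  maps a latent code to the code after the relation (the ReL).
    means:   (n_values, d) component means of the partition prior.
    stds:    (n_values, d) component standard deviations (diagonal).
    connect: maps the current value index to the expected new value index.
    """
    hits = 0
    num_values = means.shape[0]
    for _ in range(num_trials):
        k = int(rng.integers(num_values))        # pick a factor value
        z = rng.normal(means[k], stds[k])        # sample a code from the prior
        z_out = rel_fn(z)                        # apply the relation
        hits += nearest_component(z_out, means) == connect(k)
    return hits / num_trials
```

With several partitions and relations, the same check would be repeated for the affected partition of each randomly chosen relation.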
|Locatello et al. [2019_locatello_few-labels]||0.533||0.01||0.01||0.48||0.05||0.08|
|Gabbay et al. [gabbay2021image]||0.8366||0.14||0.57||1.0||0.3||1.0|
|Valenti et al. [2022_valenti_leveraging]||0.9543||0.994||0.7728||0.6921||0.6897||0.5007|
We compare the SAP [2017_kumar_variational], DCI [2018_eastwook_framework] and MIG [2018_chen_beta-TCVAE] disentanglement scores against several models from the literature. Following the approach of [2022_valenti_leveraging], we convert our modular representations into their corresponding generative factor values before computing the scores. This step can be done at no additional computational cost. The results are reported in Table 2, showing that modular representations have a beneficial impact on all the scores, especially the challenging SAP score. This is a strong sign that the modular separation of weakly disentangled representations is indeed able to improve the disentanglement performance of generative models.
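Once the representations have been converted into discrete factor values, a score such as MIG can be computed directly from counts. The sketch below follows the usual definition of the mutual information gap (for each factor, the gap between the two most informative code dimensions, normalised by the factor entropy); the function names are hypothetical and this is not the evaluation code used for Table 2.

```python
import numpy as np


def entropy(x):
    """Shannon entropy (in nats) of a discrete 1-D array."""
    _, counts = np.unique(x, return_counts=True)
    p = counts / counts.sum()
    return float(-(p * np.log(p)).sum())


def mutual_info(a, b):
    """I(a; b) = H(a) + H(b) - H(a, b) for discrete 1-D arrays."""
    joint = np.stack([a, b], axis=1)
    _, counts = np.unique(joint, axis=0, return_counts=True)
    p = counts / counts.sum()
    h_joint = float(-(p * np.log(p)).sum())
    return entropy(a) + entropy(b) - h_joint


def mig(codes, factors):
    """Mutual information gap for discrete codes and factors.

    codes:   (n, m) integer array with m >= 2 code dimensions.
    factors: (n, k) integer array; every factor must take at least two values.
    """
    gaps = []
    for k in range(factors.shape[1]):
        mi = sorted(
            (mutual_info(codes[:, j], factors[:, k]) for j in range(codes.shape[1])),
            reverse=True,
        )
        gaps.append((mi[0] - mi[1]) / entropy(factors[:, k]))
    return float(np.mean(gaps))
```

A code that copies each factor into its own dimension reaches the maximum score, while a code whose dimensions all carry the same factor has a gap of zero.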
In this paper, we introduced a novel framework for learning modular weakly disentangled representations. Modular representations encode each generative factor into a separate partition of the latent space, thus removing the need for additional supervision when the number of value combinations of the generative factors increases. The experiments show that modular representations make it possible to perform controlled manipulations of selected generative factors with high accuracy. This, in turn, results in high disentanglement scores. In the future, we wish to further enhance the expressivity of our methods by finding ways to encode continuous generative factors in a weakly disentangled way.