Recently, generative adversarial networks (GANs) have emerged as a scalable solution for generating a wide variety of data types, from images to text to biological samples (zhu2016generative; huang2017manifold). GANs mainly aim to generate data distributed similarly to the training data, which they achieve through an alternating minimax game between a generator and a probabilistic discriminator. However, in many applications, simply replicating the density is not very useful (grun2015single). This is because the data collection process is often biased and covers the state space unevenly, leaving important areas of the state space sparsely sampled (lindstrom2011miniaturization). This problem is exacerbated when aligning datasets from different batches, where the goal is to match the geometries of the datasets so that latent objects such as cell types are aligned. GANs may skew the alignment of the datasets based on differences in density. To address this, we develop a geometry-based data generation and matching GAN framework.
Generating from a manifold geometry model of the data can be much more useful than replicating density in many applications (lindenbaum2018geometry; coifman2005geometric). Such generation can i) “fill” missing samples in parts of the state space and ii) create additional samples in sparse areas with rare data points that may be of great interest to the application area, such as biology. Modeling the manifold geometry traditionally involves converting the data into a k-nearest neighbors graph (or some other density-independent graph) and then computing the eigenvectors of the associated graph Laplacian (belkin2008towards). Other variations turn the graph Laplacian into a random walk operator and diffuse over the graph (belkin2008towards). In either case, graph construction and eigendecomposition can become computationally infeasible for large datasets. Furthermore, it is often unclear how to compute meaningful distances when Euclidean distance is inappropriate. For example, a simple pixel-by-pixel L2 distance between two images is not an accurate measure of how similar they are, and developing meaningful distance measures remains an ongoing challenge (frogner2015learning; wang2005euclidean).
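As a rough illustration of the graph-based pipeline described above (not part of our method), the sketch below builds a symmetrized, unweighted k-nearest-neighbor graph and its unnormalized Laplacian $L = D - W$ in plain Python; the function name `knn_graph_laplacian` is hypothetical. Even before any eigendecomposition, the pairwise-distance step is quadratic in the number of points, and it relies on exactly the Euclidean distance whose meaningfulness is questioned above.

```python
import math


def knn_graph_laplacian(points, k):
    """Unnormalized Laplacian L = D - W of a symmetrized, unweighted kNN graph.

    points: list of equal-length coordinate tuples; k: neighbors per point.
    """
    n = len(points)
    w = [[0.0] * n for _ in range(n)]
    for i in range(n):
        # n - 1 Euclidean distances per point, O(n^2) overall.
        others = sorted((math.dist(points[i], points[j]), j) for j in range(n) if j != i)
        for _, j in others[:k]:
            # Symmetrize: i ~ j if either point is among the other's k nearest.
            w[i][j] = w[j][i] = 1.0
    lap = [[0.0] * n for _ in range(n)]
    for i in range(n):
        degree = sum(w[i])
        for j in range(n):
            lap[i][j] = (degree if i == j else 0.0) - w[i][j]
    return lap


pts = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (5.0, 5.0), (6.0, 5.0)]
lap = knn_graph_laplacian(pts, k=1)
```

For these five points and $k = 1$, the graph splits into two connected components, and every row of $L$ sums to zero, as it must for a graph Laplacian.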
Previous deep learning literature has shown that autoencoders are capable of learning the manifold geometry of a dataset without placing any assumptions on distance (vincent2010stacked; vincent2008extracting; amodio2017exploring; holden2015learning). Here we propose to augment the two-network system of a GAN (generator and discriminator) with a third network (an autoencoder) to generate data from the manifold geometry. We call this new framework the Manifold Geometry GAN (MG GAN); it uses a new manifold geometry loss function. We show that this loss is essential for aligning manifolds: density-based alignment is problematic because it can skew the alignment between samples that differ in density. We further show how to incorporate the geometric loss into a cycle-consistent GAN framework to form the Manifold Geometry Matching GAN (MGM GAN). The main contributions of our work are as follows:
a novel loss that facilitates using a GAN to sample from the manifold geometry
the cycle-consistent alignment framework of the MGM GAN
demonstration of the difference between density generation and geometry generation on many datasets
2 Previous Work
Much work has been devoted to modeling manifold geometry, largely focused on graph- and distance-based methods (tenenbaum2000global; lindenbaum2018geometry; wang2013manifold). These include Laplacian eigenmaps and diffusion maps (tu2012laplacian; belkin2003laplacian; ellingson2010validation; coifman2005geometric). The graph-based approach of (oztireli2010spectral) balances density through a resampling scheme of existing data points. With a model of the manifold geometry, two manifolds can be aligned through graph-based harmonic methods (bachmann2005exploiting; stanley2018manifold). Unlike our approach, these methods all face the difficulty, discussed above, of forming a meaningful graph representation of the data.
Existing research has demonstrated autoencoders to be effective at learning the data manifold, both in theory and practice, without placing restrictive assumptions like a distance measure between points (vincent2010stacked; vincent2008extracting; holden2015learning; amodio2017exploring; bengio2013representation; goodfellow2016deep). Augmented forms of autoencoders, including those with added adversaries, are a subject of continuing research in order to further leverage the powerful representations that autoencoders are able to learn (makhzani2015adversarial; wang2014generalized; wang2016auto).
GANs have been used previously for domain mapping, with the earliest methods requiring supervision (isola2017image; van2015transfer). Unsupervised domain mapping came to the forefront with the introduction of cycle-consistent GANs. Cycle-consistent GANs address the problem of unsupervised domain mapping by simultaneously learning two generative functions: one that maps from the first dataset to the second dataset and vice versa. Many such pairs of generative functions exist, however, so the key assumption constraining these architectures is that the two generative functions should be each other’s inverse. In practice, cycle-consistent GANs have achieved impressive results on a wide range of applications (hoffman2017cycada; zhu2017unpaired; chu2017cyclegan). However, problems with the density-based loss for domain mapping, including model ambiguity, have been identified (amodio2018magan; dumoulin2016adversarially; perera2018in2i; li2017alice).
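The inverse-function constraint is usually imposed as a reconstruction penalty. The minimal sketch below computes a mean L1 cycle-consistency loss for one direction, $\|G_{21}(G_{12}(x)) - x\|_1$, treating the generators as plain Python callables; the names `cycle_consistency_loss`, `shift`, and `unshift` are illustrative, and the L1 norm follows the common CycleGAN choice (zhu2017unpaired) rather than anything specified in this section.

```python
def cycle_consistency_loss(xs, g12, g21):
    """Mean L1 error between each x and its round trip G21(G12(x))."""
    total = 0.0
    for x in xs:
        x_rec = g21(g12(x))
        total += sum(abs(a - b) for a, b in zip(x_rec, x))
    return total / len(xs)


def shift(x):
    return tuple(v + 1.0 for v in x)


def unshift(x):
    return tuple(v - 1.0 for v in x)


batch = [(0.0, 0.0), (1.0, 2.0)]
exact = cycle_consistency_loss(batch, shift, unshift)       # inverse pair: zero loss
broken = cycle_consistency_loss(batch, shift, lambda x: x)  # not inverses: positive loss
```

The penalty singles out mutually inverse pairs among the many pairs of mappings that could fool both discriminators; it does not, by itself, say anything about how densities are matched.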
Importance sampling is a frequently used technique in the statistical literature for improving Monte Carlo simulations, Bayesian inference, and surveys (neal2001annealed; guo2002survey; gregoire2007sampling). In stratified survey sampling, known biases in the sampling process are corrected by weighting populations differently (nassiuma2000survey). Importance sampling has only rarely been used in deep learning, however, mostly with a focus on optimizing convergence (bengio2003quick).
3 MGM GAN Model
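To model the geometry of a single manifold with a GAN, we first let $X$ be a dataset domain with points $x_i \in X$, $i = 1, \dots, N$. We seek a generator $G$ that takes points $z$ from a noise domain $Z$ and maps them to $X$. To guide the generator into creating realistic points, we also train a discriminator network $D$ that tries to distinguish between real points and points mapped by the generator. Adversarial training leads the generator to try to fool the discriminator into classifying its points as real, with the following standard loss terms:
$$\mathcal{L}_D = -\mathbb{E}_{x \sim X}\left[\log D(x)\right] - \mathbb{E}_{z \sim Z}\left[\log\left(1 - D(G(z))\right)\right], \qquad \mathcal{L}_G = -\mathbb{E}_{z \sim Z}\left[\log D(G(z))\right].$$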
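For concreteness, the standard losses can be estimated on a minibatch from the discriminator's output probabilities. This is a minimal sketch assuming the non-saturating generator loss $-\mathbb{E}_{z}[\log D(G(z))]$; `gan_losses` is a hypothetical helper, and the small `eps` only guards against `log(0)`.

```python
import math


def gan_losses(d_real, d_fake, eps=1e-12):
    """Minibatch estimates of the discriminator and (non-saturating) generator losses.

    d_real: D(x) for real points; d_fake: D(G(z)) for generated points.
    """
    loss_d = (-sum(math.log(p + eps) for p in d_real) / len(d_real)
              - sum(math.log(1.0 - p + eps) for p in d_fake) / len(d_fake))
    loss_g = -sum(math.log(p + eps) for p in d_fake) / len(d_fake)
    return loss_d, loss_g


# A discriminator that outputs 1/2 everywhere cannot separate real from generated points.
loss_d, loss_g = gan_losses([0.5, 0.5], [0.5, 0.5, 0.5])
```

At that point the discriminator loss equals $2\log 2 \approx 1.386$, its value at the equilibrium of the minimax game.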
3.1 Importance Sampling
To make the GAN model the geometry instead of the density, we first obtain a representation of the data manifold by extracting a latent layer of a pre-trained autoencoder, letting $e(x)$ be the representation of $x$ on the manifold (Figure 1a). We then create a Voronoi partition of the embedded points $e(X)$ with k-means clustering, dividing the space into $K$ regions $R_1, \dots, R_K$. We assign each point a weight inversely proportional to the number of points in its dataset that fall in the same region of the manifold (Figure 1b):
$$w(x) = \frac{1}{N_k} \quad \text{for } e(x) \in R_k,$$
where $N_k$ is the number of points of the dataset containing $x$ whose embeddings fall in $R_k$.
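A minimal sketch of the weighting step, assuming the autoencoder embeddings $e(x)$ are already available as coordinate tuples. The k-means routine is plain Lloyd's algorithm with a deterministic farthest-point initialization, which is a choice made for this illustration; the text does not specify how the clustering is initialized. Generated points are assigned to the regions of the real data's partition (nearest centroid) and weighted by their own counts, so that each dataset is counted separately as described above. All function names are hypothetical.

```python
import math


def assign(points, centers):
    """Index of the nearest center (Voronoi cell) for each point."""
    return [min(range(len(centers)), key=lambda j: math.dist(p, centers[j])) for p in points]


def kmeans(points, k, iters=50):
    """Lloyd's k-means with deterministic farthest-point initialization."""
    centers = [points[0]]
    while len(centers) < k:
        # Next center: the point farthest from every center chosen so far.
        centers.append(max(points, key=lambda p: min(math.dist(p, c) for c in centers)))
    for _ in range(iters):
        labels = assign(points, centers)
        new_centers = []
        for j in range(k):
            members = [p for p, lab in zip(points, labels) if lab == j]
            # Keep the previous center if its cell happens to be empty.
            new_centers.append(tuple(sum(c) / len(members) for c in zip(*members))
                               if members else centers[j])
        if new_centers == centers:
            break
        centers = new_centers
    return centers


def region_weights(points, centers):
    """w(x) = 1 / (number of points of the same dataset in x's region)."""
    labels = assign(points, centers)
    counts = {}
    for lab in labels:
        counts[lab] = counts.get(lab, 0) + 1
    return [1.0 / counts[lab] for lab in labels]


# Four real embeddings in a dense region, one in a sparse region.
real = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0), (10.0, 10.0)]
centers = kmeans(real, k=2)
w_real = region_weights(real, centers)  # the four dense points share a total weight of 1
w_gen = region_weights([(0.5, 0.5), (9.0, 9.0), (10.0, 9.0)], centers)
```

Each occupied region thus contributes the same total weight for each dataset, however many points it contains, which is the property used in the analysis that follows.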
While these weights can be used to make a single generator sample from the geometry of a single domain manifold, we can also extend this to unsupervised domain mapping, where we have two datasets ($X_1$ and $X_2$), two generators ($G_{12}: X_1 \to X_2$ and $G_{21}: X_2 \to X_1$), and two discriminators ($D_1$ and $D_2$). Normally, the minimax game between a generator and a discriminator reaches equilibrium when the discriminator's probabilistic output rates each point as equally likely to come from the real sample or the generated sample. To accomplish this, if a shared region $r$ of the space is sparser in $X_1$ than in $X_2$, the generator $G_{12}$ must take points in $X_1$ that are not in $r$ and map them to points in $r$. In other words, the generators are not learning to align the manifolds of $X_1$ and $X_2$, but are learning to match the density in the data space.
This alignment warps the manifold geometry, taking points that were originally dissimilar (points in $r$ and points not in $r$) and projecting them close to each other. Since $r$ is already represented in both datasets, this region need not be altered at all. Instead, we want the generator to change points only if they do not look like realistic points in the other dataset. To accomplish this, we adopt the following importance sampling technique.
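Theorem. Under the traditional GAN loss, a generated distribution $G(Z)$ that occupies the same regions as the real distribution $X$ but with different densities is not a minimum.
Without loss of generality, consider the discriminator loss for a small region $r$ such that $D(x) = D_r$ is constant for all $x \in r$:
$$\mathcal{L}_D^{(r)} = -\sum_{x \in X \cap r} \log D(x) - \sum_{z:\, G(z) \in r} \log\left(1 - D(G(z))\right).$$
Extending the notation from above, let $N_r^X$ and $N_r^G$ be the numbers of points of $X$ and $G(Z)$ that are in $r$, respectively:
$$\mathcal{L}_D^{(r)} = -N_r^X \log D_r - N_r^G \log\left(1 - D_r\right).$$
Next we take the derivative with respect to $D_r$ and, assuming that this is a local optimum, set it equal to zero:
$$\frac{\partial \mathcal{L}_D^{(r)}}{\partial D_r} = -\frac{N_r^X}{D_r} + \frac{N_r^G}{1 - D_r} = 0 \;\Longrightarrow\; D_r = \frac{N_r^X}{N_r^X + N_r^G}.$$
At the equilibrium of the minimax game the discriminator cannot tell real from generated points, $D_r = \tfrac{1}{2}$, which requires $N_r^X = N_r^G$.
In general, the numbers of points in the real data and in the generated data that fall in the region will not be the same, and thus, by contradiction, this is not a local optimum. ∎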
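The argument can be checked numerically. The sketch below, with hypothetical helper names, evaluates the unweighted per-region discriminator loss $-N_r^X \log D_r - N_r^G \log(1 - D_r)$ and its derivative for an example with three times as many real as generated points in $r$.

```python
import math


def region_loss(d, n_real, n_gen):
    """Unweighted discriminator loss in region r when D(x) = d for every x in r."""
    return -n_real * math.log(d) - n_gen * math.log(1.0 - d)


def region_loss_grad(d, n_real, n_gen):
    """Derivative of region_loss with respect to d."""
    return -n_real / d + n_gen / (1.0 - d)


def optimal_d(n_real, n_gen):
    """Stationary point of region_loss: the discriminator's best constant output in r."""
    return n_real / (n_real + n_gen)


n_real, n_gen = 30, 10  # same region, different densities
d_star = optimal_d(n_real, n_gen)                    # 0.75, not the equilibrium value 1/2
grad_at_half = region_loss_grad(0.5, n_real, n_gen)  # nonzero, so D = 1/2 is not stationary
```

Because the discriminator beats chance in $r$ by outputting $0.75$, the generator keeps an incentive to move points into or out of $r$ until the counts match.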
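Theorem. Under the importance sampling GAN loss, a generated distribution $G(Z)$ that occupies the same regions as the real distribution $X$, even with different densities, is a minimum.
As before, we consider the discriminator loss for a small region $r$, except now with the importance sampling weights $1/N_r^X$ for real points and $1/N_r^G$ for generated points:
$$\mathcal{L}_{D,\mathrm{IS}}^{(r)} = -\frac{N_r^X}{N_r^X} \log D_r - \frac{N_r^G}{N_r^G} \log\left(1 - D_r\right) = -\log D_r - \log\left(1 - D_r\right).$$
For any values of $N_r^X$ and $N_r^G$ greater than zero, this function does not depend on the counts, so its derivative with respect to them is zero, and it is minimized at $D_r = \tfrac{1}{2}$. Thus, the importance sampling GAN loss in a region has a local optimum anywhere there are both real and generated points in that region, no matter their respective densities. ∎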
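With the importance weights $1/N_r^X$ for real points and $1/N_r^G$ for generated points, the same per-region computation no longer depends on the counts. A minimal check, again with a hypothetical helper name and a grid search standing in for the analytic minimization:

```python
import math


def weighted_region_loss(d, n_real, n_gen):
    """Importance-weighted loss in region r when D(x) = d for every x in r.

    Each real point carries weight 1/n_real and each generated point 1/n_gen.
    """
    w_real, w_gen = 1.0 / n_real, 1.0 / n_gen
    return -n_real * w_real * math.log(d) - n_gen * w_gen * math.log(1.0 - d)


grid = [i / 100 for i in range(1, 100)]
count_pairs = [(30, 10), (5, 500), (1, 1)]
# The loss at D = 1/2 is the same for every pair of counts ...
losses_at_half = [weighted_region_loss(0.5, nr, ng) for nr, ng in count_pairs]
# ... and D = 1/2 minimizes it over the grid for every pair of counts.
best_d = [min(grid, key=lambda d: weighted_region_loss(d, nr, ng)) for nr, ng in count_pairs]
```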
Note that the modified loss function cannot be lowered further for any region that has points in both the real and generated data. However, since the weights are constant with respect to the network parameters, the function is still fully differentiable and still provides a signal to generate points in regions where there is real data but where the generator does not yet output any points. We further note that while this method requires choosing the number of partitions $K$, we find that optimizing the Bayesian Information Criterion (BIC) over partitions of the manifold works well (Konishi2008).
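The text does not give the exact BIC formulation. One common choice for k-means, sketched below as an assumption, scores each $K$ with a spherical Gaussian model whose shared variance is the maximum-likelihood estimate from the within-cluster sum of squares, counts $Kd + 1$ free parameters (centroids plus the variance), and picks the $K$ with the lowest BIC. The k-means helper repeats a deterministic farthest-point variant so that the block runs on its own; all names are hypothetical.

```python
import math


def assign(points, centers):
    return [min(range(len(centers)), key=lambda j: math.dist(p, centers[j])) for p in points]


def kmeans(points, k, iters=50):
    """Lloyd's k-means with deterministic farthest-point initialization."""
    centers = [points[0]]
    while len(centers) < k:
        centers.append(max(points, key=lambda p: min(math.dist(p, c) for c in centers)))
    for _ in range(iters):
        labels = assign(points, centers)
        groups = [[p for p, lab in zip(points, labels) if lab == j] for j in range(k)]
        new = [tuple(sum(c) / len(g) for c in zip(*g)) if g else centers[j]
               for j, g in enumerate(groups)]
        if new == centers:
            break
        centers = new
    return centers


def kmeans_bic(points, k):
    """BIC (lower is better) of a k-cluster spherical Gaussian model with shared variance."""
    n, d = len(points), len(points[0])
    centers = kmeans(points, k)
    sse = sum(math.dist(p, centers[j]) ** 2 for p, j in zip(points, assign(points, centers)))
    variance = max(sse / (n * d), 1e-12)  # maximum-likelihood estimate, floored
    log_lik = -0.5 * n * d * (math.log(2 * math.pi * variance) + 1.0)
    n_params = k * d + 1  # k centroids in d dimensions plus one shared variance
    return -2.0 * log_lik + n_params * math.log(n)


# Three tight groups of four points around (0, 0), (10, 0), and (0, 10).
offsets = [(-0.5, -0.5), (-0.5, 0.5), (0.5, -0.5), (0.5, 0.5)]
pts = [(cx + dx, cy + dy) for cx, cy in [(0, 0), (10, 0), (0, 10)] for dx, dy in offsets]
scores = {k: kmeans_bic(pts, k) for k in range(1, 5)}
best_k = min(scores, key=scores.get)
```

On real embeddings, both the penalty term and the variance model affect which $K$ is selected, so this is one reasonable instantiation rather than the exact procedure behind the reported results.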
Under the importance sampling loss function, we have altered the loss landscape so that mappings that match the geometry of the two manifolds are optima. Because these mappings are now local optima, the GAN framework can potentially find them; in the traditional framework, optimization is guaranteed not to find them, since they are not local optima. However, this loss alone does not explicitly enforce manifold alignment, so we also introduce a global manifold alignment loss term for this purpose.
3.2 Manifold Geometry Loss
Preserving the manifold geometry requires preserving some notion of distance between points before and after transformation. However, the standard GAN loss function only looks at the data after transformation, and thus cannot enforce any relationship between points before and after. Thus, while the generated distribution $G_{12}(X_1)$ will look like $X_2$ at the distribution level, the relationship between a pair $G_{12}(x_i), G_{12}(x_j)$ might not match the relationship between the corresponding pair $x_i, x_j$.
We address this by introducing a loss that explicitly preserves manifold geometry, using the same manifold from the previous section. The manifold geometry loss is thus:
$$\mathcal{L}_{\mathrm{MG}} = \sum_{i,j} \left| d\big(e(x_i), e(x_j)\big) - d\big(e(G_{12}(x_i)), e(G_{12}(x_j))\big) \right|,$$
where $e(x)$ is the representation of $x$ on the manifold and $d$ is a distance function, here chosen to be Euclidean distance on the manifold. We use a coefficient $\lambda$ to control the emphasis placed on this term relative to the importance sampling GAN loss, and we keep it at the same value in all experiments.
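A minimal sketch of the distance-preservation idea, assuming the embeddings of a minibatch before mapping, $e(x_i)$, and after mapping, $e(G_{12}(x_i))$, are given as coordinate tuples, and assuming an absolute difference of Euclidean distances averaged over all pairs; the exact penalty (absolute versus squared, sum versus mean) is a choice of this illustration, and the function name is hypothetical.

```python
import math


def manifold_geometry_loss(emb_before, emb_after):
    """Mean |d(e(x_i), e(x_j)) - d(e(G(x_i)), e(G(x_j)))| over all pairs i < j."""
    n = len(emb_before)
    total, pairs = 0.0, 0
    for i in range(n):
        for j in range(i + 1, n):
            d_before = math.dist(emb_before[i], emb_before[j])
            d_after = math.dist(emb_after[i], emb_after[j])
            total += abs(d_before - d_after)
            pairs += 1
    return total / pairs


before = [(0.0, 0.0), (3.0, 4.0), (1.0, 1.0)]
shifted = [(x + 2.0, y - 1.0) for x, y in before]  # rigid shift: distances preserved
scaled = [(2.0 * x, 2.0 * y) for x, y in before]   # scaling: distances doubled
loss_shift = manifold_geometry_loss(before, shifted)
loss_scale = manifold_geometry_loss(before, scaled)
```

A rigid shift leaves the loss at zero while a rescaling is penalized; the full generator objective would add this term, scaled by $\lambda$, to the importance sampling GAN loss.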
In this section, we start by demonstrating an example of generating from the geometry. We then experiment on mapping between domains on simulated Gaussian mixture models, images from the canonical MNIST dataset, and mass cytometry data on T cell development in the mouse thymus, which measures the abundance of various proteins in individual cells. We compare the performance of the MGM GAN to both a traditional GAN and a cycle GAN. To illustrate the importance of our specific technique for calculating the importance sampling weights, we also compare to our model with weights drawn randomly from a uniform distribution (random weights). Further implementation details are in the supplemental material.
4.1 Simulated data
4.1.1 Geometry Generation
We first consider an experiment generating from the manifold geometry. The data in this experiment was simulated from a two-dimensional Gaussian mixture model consisting of three Gaussians sampled at different frequencies with a small number of points transitioning between them (Figure 3a). The traditional GAN penalty indeed teaches the generator to sample from the density, as can be seen in the kernel density plot in Figure 3b, which is dominated by the largest population and misses the transition points. The MGM GAN’s importance sampling upweights these points in the low-density region and downweights the points in the high-density region, allowing it to generate evenly over the geometry of the data (Figure 3c).
[Table: F-scores by domain; columns Domain 1 and Domain 2]
4.1.2 Unsupervised Domain Geometry Mapping
Next, we create two domains, each a mixture of three Gaussians, with one of the Gaussians slightly shifted between the domains so that it needs alignment. This unsupervised domain mapping presents a significant challenge for traditional GANs because the two domains sample each Gaussian at different frequencies (Figure 4a).
The traditional GAN penalty prevents aligning the two domains so that the shared Gaussians are matched to each other, since their densities along the manifold differ (Figure 4c-e). In contrast, the importance sampling weighting in the MGM GAN balances the densities, allowing the generator to converge to this alignment (Figure 4b). Furthermore, without the manifold geometry loss, points that were originally in different Gaussians are mapped to the same Gaussian. With this loss, the MGM GAN preserves the relationships between points before and after mapping, keeping the two representations (one in each domain) of similar points similar and those of different points different.
We next consider an experiment on image data from the canonical MNIST dataset (lecun1998mnist). We form the first domain by taking a random sample of equal size for each digit except the digit zero, which is oversampled. For the second domain, we do the same but oversample the digit one. Thus, even though the manifolds of the two domains cover the same support, the density along the manifold differs between the domains. In fact, the ones in the first domain are exact elements of the second domain, so it would be desirable to align the two domains such that the class of each element does not change.
As expected, the traditional GAN loss prevents these models from finding an alignment that preserves digit identity across domains. Since the ones are oversampled in domain two, most of them are turned into other digits in the other domain (Figure 5c-e). The oversampling of zeros in the other domain likewise forces the GANs to create zeros out of other digits to recreate their abundance in the target domain (last row in Figure 5c-e).
The MGM GAN importance sampling compensates for the differing densities and allows the identity function to be a local optimum of the GAN loss. The manifold geometry loss further encourages similar original images to remain similar after mapping, and different images to remain different. As Figure 5b shows, this allows the MGM GAN to preserve the identity of the digit through domain transfer.
To quantitatively assess performance, we consider a slightly different version of the classical task of domain adaptation, which we term unsupervised domain adaptation. In traditional domain adaptation, we have labels in one domain and wish to map points to another domain where we have no or few labels. The goal is to classify points in the target domain accurately. Unsupervised domain adaptation is harder because we do not presume to have labels in either domain. Instead, we wish to use unsupervised learning to align the data such that the class of a point is preserved by the mapping.
To evaluate performance at unsupervised domain adaptation, we use the ground truth labels in each original domain and a nearest neighbor classifier to assign labels to generated points in the target domain. We emphasize that these labels are used only to score the models and are not available to them during training. Figure 6 shows how many of each oversampled digit (zero for the first domain and one for the second domain) get mapped to each other digit. In the top row, we see that the other GANs must change many of the oversampled zeros to other digits, since the second domain has fewer zeros. This is notable because these zeros are elements of the other domain, yet they are changed by the domain mapping anyway. The same happens for the oversampled ones in the other domain. The MGM GAN balances these different densities and consequently performs significantly better at our unsupervised domain adaptation task.
To measure performance quantitatively, we use F-scores, which handle class imbalance by incorporating both precision and recall within each domain (van1979information). Scores are reported in Table 8, where, as expected, the MGM GAN significantly outperforms the traditional GANs in both domains.
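The evaluation protocol can be sketched under stated assumptions: each mapped point receives the label of its nearest neighbor among the labeled points of the target domain, and the F-score is taken as the unweighted (macro) mean of per-class F1 scores between a point's original label and its assigned label. The text does not specify how per-class scores are aggregated, so the macro average is an assumption, as are the helper names and the toy data.

```python
import math


def nn_labels(mapped, target_points, target_labels):
    """1-NN label transfer: each mapped point takes its nearest target point's label."""
    return [target_labels[min(range(len(target_points)),
                              key=lambda j: math.dist(m, target_points[j]))]
            for m in mapped]


def macro_f1(true_labels, predicted_labels):
    """Unweighted mean over classes of F1 = 2PR / (P + R)."""
    pairs = list(zip(true_labels, predicted_labels))
    scores = []
    for c in sorted(set(true_labels) | set(predicted_labels)):
        tp = sum(t == c and p == c for t, p in pairs)
        fp = sum(t != c and p == c for t, p in pairs)
        fn = sum(t == c and p != c for t, p in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores.append(f1)
    return sum(scores) / len(scores)


target = [(0.0, 0.0), (10.0, 10.0)]
target_labels = ["zero", "one"]
mapped = [(1.0, 1.0), (9.0, 9.0), (2.0, 0.0)]  # images of three source points
source_labels = ["zero", "one", "one"]         # labels before mapping, used only for scoring
predicted = nn_labels(mapped, target, target_labels)
score = macro_f1(source_labels, predicted)
```

Here the third point changes class under the mapping, which lowers both the precision of "zero" and the recall of "one".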
4.3 Biological Data
In this section, we highlight the importance of modeling the manifold geometry rather than the data density on a real dataset of biological measurements. The data consists of measurements of T cell development in mouse thymus from two individuals, downloaded from (setty2016wishbone). We would like to integrate measurements from both individuals together, so that further analysis can evaluate both of them together. However, there are multiple sources of variation between the samples that preclude naively combining them.
There are two categories of variation to consider: the first we want to correct (instrument error), and the second we want to be robust to (differential sampling). The first, instrument error, is inevitable when running complex machinery as in mass cytometry (shaham2017removal). Calibration, amount of reagent, and environmental conditions can all affect measurements, so whenever two samples are compared, these differences need to be reconciled. Most often, this can be seen as a part of the space that contains points from one sample but not the other (johnson2007adjusting). A desirable alignment of the two datasets would correct these differences in support.
The second, differential sampling, though, should not be corrected. While shifts in the support between samples are more likely to be instrument error, we fully expect cell types to be present at different frequencies in different samples. This expectation motivates our need to align manifolds without matching density along the manifolds. For example, we want Cell Type A in one domain to align to Cell Type A in the other domain. The traditional GAN loss would prevent this if Cell Type A is more abundant than Cell Type B in one domain, but the opposite is true in the other domain.
The two samples we wish to align both contain measurements of the abundance of the same set of proteins in individual cells. An embedding of each sample can be seen in Figure 7a. A difference in geometry between the samples can be seen by examining a particular cell type that is important in the study of mouse development (vosshenrich2006thymic). As a part of normal thymus development in mice, cells that are low in a protein called CD3 and high in a protein called CD8 express a protein called GATA3 (CD3-CD8+GATA3+ cells) (tai2013gata). These cells make up a larger fraction of all measured cells in the second sample than in the first. Moreover, their abundance of other proteins is completely different in the first sample. For example, a fraction of these cells in the second sample are also low in a protein called BCL11b, while the first sample contains no such cells at all. We would therefore first expect the alignment of the two manifolds to better match this cell type across the samples.
In Figure 9, we compare the mean abundance of each protein for CD3-CD8+GATA3+ cells in the original first sample and the second sample. We then make the same comparison between the transformed first sample and the original second sample. The transformation significantly improves the agreement between the samples for this cell type, as desired.
This first population confirms the MGM GAN's ability to change the geometry of the manifold in order to convincingly generate points in the opposite domain. However, these clear differences in the parts of the data space covered by the two samples were corrected by all of the GANs, due to the presence of the adversary. The further challenge is to ensure that areas of the geometry that have no batch effect, but possibly have different densities, are not unnecessarily changed. The importance of the MGM GAN in such cases is evident from what the transformations did to a second cell type population. In both samples, there exists a cell type high in the protein CD25 (CD25+ cells) (mousecdchart) with no difference in expression between samples. We would like our alignment to preserve these cells after the transformation. However, they are present in different proportions in the two samples. Under the traditional GAN loss, this means the generator cannot learn an alignment with a one-to-one mapping of CD25+ cells between the samples, as the discriminator would be able to classify this part of the space as preferentially belonging to true samples from one domain or generated samples from the other domain.
[Table: F-scores for CD25+ cells by domain; columns Domain 1 and Domain 2]
The MGM GAN’s importance sampling balances the differential frequencies and allows a mapping that preserves the CD25+ cells to optimally fool the discriminator. Table 10 shows F-scores for CD25+ cells, and we can see the traditional GAN loss forces the other models to move cells around to match the densities along the manifold. As a result, these CD25+ cells are aligned with different cell types, introducing error into any later analysis that uses the aligned data.
In this work we have introduced a novel GAN framework for generating from the manifold geometry. We demonstrate it both in the context of a single GAN generating a single domain and in the context of unsupervised domain mapping. We contribute a re-casting of the traditional density formulation of domain mapping as one of manifold geometry alignment. We model the geometry with an importance sampling technique that weights points based on their density on the manifolds and with a novel manifold geometry loss term. The ability to generate from the geometry of the manifold has widespread applications in biology, where biased sampling makes the density an unreliable representation of the data.
For the artificial dataset, the autoencoder had three encoder layers and three decoder layers, with linear activation on the embedding and output layers. The generator had the same structure as the autoencoder. The discriminator had three layers, with leaky ReLU activation on all layers except the last, which had a sigmoid. For the MNIST dataset, convolutional layers with equal padding were used in both the autoencoder and the generator. The U-Net architecture (ronneberger2015u) of skip connections between the encoder and decoder was used. The encoder layers had leaky ReLU activation, the decoder layers had ReLU activation, and the last layer used the hyperbolic tangent activation. For the biological dataset, the same autoencoder and generator structures were used as for the artificial dataset, except with wider layers. For all datasets, the same learning rate and minibatch size were used, and the training objective included both a cycle-consistency loss and an identity loss, each with its own coefficient.