TzK: Flow-Based Conditional Generative Model

by Micha Livne, et al.

We formulate a new class of conditional generative models based on probability flows. Trained with maximum likelihood, it provides efficient inference and sampling from class-conditionals or the joint distribution, and does not require a priori knowledge of the number of classes or the relationships between classes. This allows one to train generative models from multiple, heterogeneous datasets, while retaining strong prior models over subsets of the data (e.g., from a single dataset, class label, or attribute). In this paper, in addition to end-to-end learning, we show how one can learn a single model from multiple datasets with a relatively weak Glow architecture, and then extend it by conditioning on different knowledge types (e.g., a single dataset). This yields log likelihoods comparable to the state of the art, and compelling samples from conditional priors.




1 Introduction

The goal of representation learning is to learn structured, meaningful latent representations given large-scale unlabelled datasets. It is widely assumed that such unsupervised learning will support myriad downstream tasks, some of which may not be known a priori (Bengio et al., 2013). To that end it is useful to be able to train on large amounts of heterogeneous data, but then use conditional priors that isolate specific sub-spaces or manifolds from the broader data distribution over the observation domain.

Building on probability flows (Dinh et al., 2014, 2016; Kingma & Dhariwal, 2018), this paper introduces a flexible form of conditional generative model. It is compositional in nature, without requiring a priori knowledge of the number of classes or the relationships between classes. Trained with maximum likelihood, the framework allows one to learn from heterogeneous datasets in an unsupervised fashion, with concurrent or subsequent specialization to sub-spaces or manifolds of the observation domain, e.g., conditioning on class labels or attributes. The resulting model thereby supports myriad downstream tasks, while providing efficient inference and sampling from the joint or conditional priors.

1.1 Background

There has been significant interest in learning generative models in recent years. Prominent models include variational auto-encoders (VAE), which maximize a variational lower bound on the data log likelihood (Rezende et al., 2014; Kingma & Welling, 2013; van den Berg et al., 2018; Papamakarios et al., 2017; Kingma et al., 2016), and generative adversarial networks (GAN), which use an adversarial discriminator to enforce a non-parametric data distribution on a parametric decoder or encoder (Goodfellow et al., 2014; Makhzani, 2018; Makhzani et al., 2015; Chen et al., 2016). Inference, however, remains challenging for VAEs and GANs, as neither model includes a probability density estimator (Schmah et al., 2009; Papamakarios et al., 2017; Dinh et al., 2016, 2014).

Auto-regressive models (Germain et al., 2015; Bengio & Bengio, 1999; Larochelle & Murray, 2011) and normalizing flows (Dinh et al., 2014, 2016; Rezende & Mohamed, 2015; Kingma & Dhariwal, 2018) train with maximum likelihood (ML), avoiding approximations by choosing a tractable parameterization of probability density. Auto-regressive models assume a conditional factorization of the density function, yielding a tractable joint probability model. Normalizing flows represent the joint distribution with a series of invertible transformations of a known base distribution, but are somewhat problematic in terms of the memory and computational costs associated with large volumes of high-dimensional data (e.g., images). While invertibility can be used to trade memory for compute (Chen et al., 2018; Gomez et al., 2017), training powerful density estimators remains challenging.

The attraction of unsupervised learning stems from a desire to exploit vast amounts of data, especially when downstream tasks are unknown a priori or when one lacks ample task-specific training data. And while samples from models trained on heterogeneous data may not resemble one's task domain per se, conditional models can be used to isolate manifolds or sub-spaces associated with particular classes or attributes. The TzK framework incorporates task-specific conditioning in a flexible manner. It supports end-to-end training of the full model. Alternatively, one can train a powerful density estimator once, retaining the ability to later extend it to new domains or specialize it to sub-domains of interest. We get the advantages of large heterogeneous datasets, while retaining the fidelity of such specialized conditional models.

Existing conditional generative models allow one to sample from sub-domains of interest (e.g., Makhzani, 2018; Chen et al., 2016; Dupont, 2018), but they often require that the structure of the data and latent representation be known a priori and embedded in the network architecture. For example, (Chen et al., 2016; Makhzani, 2018) allow unsupervised learning but assume the number of (disjoint) categories is given. In doing so they fix the structure of the latent representation to include a 1-hot vector over categories at the time of training. Such models must therefore be re-trained from scratch if labels change, or if new labels are added to augment the training data.

Kingma & Dhariwal (2018) train a conditional prior post hoc, given an existing Glow model. This allows them to condition an existing model on semantic attributes, but lacks the corresponding inference mechanism. A complementary formulation, augmenting a generative model with a post hoc discriminator, is shown in (Oliver et al., 2018).

Inspired by (Kingma & Dhariwal, 2018; Oliver et al., 2018), TzK incorporates conditional models with discriminators and generators, all trained jointly. The proposed framework can be trained unsupervised on large volumes of data, yielding a generic representation of the observation domain (e.g., images), while explicitly supporting the semi-supervised learning of new classes in an online fashion. Such conditional models are formulated to be compositional, without a priori knowledge of all classes, and they exploit similarity among classes through a joint latent representation.

Finally, the formulation below exhibits an interesting connection between the use of mutual information (MI) and ML in representation learning. The use of MI is prevalent in learning latent representations (Belghazi et al., 2018; Chen et al., 2016; Dupont, 2018; Klys et al., 2018), as it provides a measure of the dependence between random variables. Unfortunately, MI is hard to compute; it is typically approximated or estimated with non-parametric approaches. A detailed analysis is presented in (Belghazi et al., 2018), whose estimator scales with data dimensionality and sample size. While it is intuitive and easy to justify the use of MI to enforce a relationship between random variables (e.g., dependency (Chen et al., 2016) or independence (Klys et al., 2018)), MI is often used as a regularizer to extend an existing model. The TzK formulation offers another perspective, showing how MI arises naturally from the ML objective, following the assumption that a target distribution can be factored into (equally plausible) encoder and decoder models. We exploit a lower bound that allows indirect optimization of MI, without estimating MI directly.

Contributions: We introduce a conditional generative model based on probability density normalizing flows, which is flexible and extendable. It does not require that the number of classes be known a priori, or that classes are mutually exclusive. One can train a powerful generative model on unlabeled samples from multiple datasets, and then adapt the structure of the latent representation as a function of specific types of knowledge in an online fashion. The proposed model allows high parallelism when training multiple tasks while maintaining a joint distribution over the latent representation and observations, all with efficient inference and sampling.

2 TzK Framework

We model a joint distribution over an observation domain (e.g., images) and latent codes (e.g., attributes or class labels). Let observation $T$ be a random variable associated through a probability flow with a latent state variable $z$ (Dinh et al., 2014, 2016; Rezende & Mohamed, 2015; Kingma & Dhariwal, 2018). In particular, $T$ is mapped to $z$ through a smooth invertible mapping $f$, i.e., $z = f(T)$ and $T = f^{-1}(z)$. As such, $f^{-1}$ transforms a base distribution $p(z)$ (e.g., Normal) to a distribution over the observation domain. Normalizing flows can be formulated as a joint distribution $p(T, z)$, but for notational simplicity we can omit $T$ or $z$ from probability distributions by trivial marginalization of one or the other.
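The change of variables behind a probability flow can be made concrete with the simplest possible flow. This is a minimal sketch that assumes an elementwise affine map (a hypothetical `AffineFlow`, not the Glow architecture used in this paper) and a standard Normal base distribution. Evaluating the density of an observation $T$ only needs the forward map $z = f(T)$ plus the log-determinant of its Jacobian. Sampling uses $f^{-1}$ instead.

```python
import math


def std_normal_logpdf(z):
    """Log-density of a standard Normal base distribution, summed over dimensions."""
    return sum(-0.5 * (zi * zi + math.log(2 * math.pi)) for zi in z)


class AffineFlow:
    """Hypothetical elementwise affine flow: z = (t - shift) * exp(-log_scale)."""

    def __init__(self, shift, log_scale):
        self.shift = shift
        self.log_scale = log_scale

    def forward(self, t):
        # Map an observation T to its latent state z = f(T).
        return [(ti - b) * math.exp(-s) for ti, b, s in zip(t, self.shift, self.log_scale)]

    def inverse(self, z):
        # Map a latent state back to the observation domain, T = f^{-1}(z).
        return [zi * math.exp(s) + b for zi, b, s in zip(z, self.shift, self.log_scale)]

    def log_abs_det_jacobian(self):
        # dz/dt = exp(-log_scale) in each dimension, so log|det J| = -sum(log_scale).
        return -sum(self.log_scale)

    def log_prob(self, t):
        # Change of variables: log p(T) = log p(z) + log|det df/dT|.
        return std_normal_logpdf(self.forward(t)) + self.log_abs_det_jacobian()


flow = AffineFlow(shift=[1.0, -2.0], log_scale=[math.log(2.0), 0.0])
print(flow.log_prob([3.0, -1.0]))
```

For this flow, `log_prob` should equal the log-density of a diagonal Gaussian with mean `shift` and standard deviation `exp(log_scale)`. That gives an easy sanity check for any flow implementation.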

For conditional generative models within the TzK framework, the latent state $z$ is conditioned on a latent code $c_i$ (see Fig. 1b). As such, they capture distributions within the observation domain associated with subsets of the training data, or subsequent labelled data. To this end, let $k_i$ be a hybrid discrete/continuous random variable $k_i = (e_i, c_i)$, where $e_i \in \{0, 1\}$ and $c_i \in \mathbb{R}^{n_i}$, similar to (Chen et al., 2016; Dupont, 2018). We refer to $k_i$ as knowledge of type $i$, while $c_i$ is the latent code of knowledge $k_i$, a structured latent representation of $T$. We call $e_i$ the existence of knowledge $k_i$, a binary variable that serves to indicate whether or not $T$ can be generated by $c_i$.

To handle multiple types of knowledge, let $K = \{k_i\}_{i=1}^{N}$ denote the set of latent codes associated with $N$ knowledge types. Importantly, we do not assume that knowledge types correspond to mutually exclusive class labels. Rather, we allow varying levels of interaction between knowledge classes under the TzK framework. This lets a TzK model share a learned representation between similar classes, while still being able to represent distinct classes.

Figure 1: TzK framework models $p(T, K)$, a joint distribution over task domain $T$ and multiple latent codes $K = \{k_i\}$, with a dual encoder/decoder: (a) encoder, (b) decoder. The framework offers explicit representation of sub-domains of interest in $T$ by conditioning on the latent codes, which comprise a single compositional model.

2.1 Formulation

Our goal is to learn a probability density estimator of the joint distribution $p(T, K)$. In terms of an encoder-decoder, for effective inference and sample generation, we model $p(T, K)$ in terms of two factorizations, i.e.,

$$p(T, K) = p(K \mid T)\, p(T), \qquad (1)$$

$$p(T, K) = p(T \mid K)\, p(K). \qquad (2)$$

The encoder factorization in Eq. (1) makes $p(K \mid T)$ explicit, which is used for inference of the latent code given $T$. The decoder in Eq. (2) makes $p(T \mid K)$ explicit for generation of samples of $T$ given a latent code $K$.

As noted by (Kingma & Welling, 2013; Agakov & Barber, 2003; Rezende & Mohamed, 2015; Chen et al., 2016), inference with the general form of the posterior $p(K \mid T)$ is challenging. Common approaches resort to variational approximations (Kingma & Welling, 2013; Rezende & Mohamed, 2015; Kingma et al., 2016). A common relaxation in the case of discrete latent codes is the assumption of independence (e.g., $p(K \mid T) = \prod_i p(k_i \mid T)$). Alternatively, one can assume that such binary codes represent mutually exclusive classes, e.g., with a single categorical random variable. But this makes it difficult to model attributes, for which the presence or absence of one attribute may be independent of other attributes, or to allow for the fact that one image may belong to two different classes (e.g., it might be present in more than one database).

Here we design TzK to avoid the need for mutual exclusivity, or the need to specify the number of classes a priori, instead allowing the model to be extended with new classes, and to learn and exploit some degree of similarity between classes. To that end we assume that knowledge types exhibit statistical independence, expressed in terms of the following encoder factorization,

$$p^{\text{enc}}(T, K) = p(T) \prod_{i=1}^{N} p(k_i \mid T), \qquad (3)$$

and the corresponding decoder factorization

$$p^{\text{dec}}(T, K) = p(T \mid K) \prod_{i=1}^{N} p(k_i). \qquad (4)$$

It is by virtue of this particular factorization that a TzK model is easily extendable with different knowledge types (and conditional models) in an online fashion.
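The extensibility argument can be seen directly in log space. Under the independence assumption, the encoder log-joint is $\log p(T)$ plus one additive term per knowledge type. Adding a new type therefore adds one term and leaves the others untouched. This is a minimal sketch with hypothetical function names and placeholder log-probabilities. In a real model each term would come from a learned network.

```python
import math


def encoder_log_joint(log_p_t, log_p_k_given_t):
    """Encoder log-joint with independent knowledge types:
    log p(T) + sum_i log p(k_i | T)."""
    return log_p_t + sum(log_p_k_given_t.values())


def decoder_log_joint(log_p_t_given_k, log_p_k):
    """Decoder log-joint with independent knowledge types:
    log p(T | K) + sum_i log p(k_i)."""
    return log_p_t_given_k + sum(log_p_k.values())


# Placeholder per-type log-probabilities for one observation (hypothetical values).
terms = {"cifar10": math.log(0.9), "mnist": math.log(0.2)}
before = encoder_log_joint(-100.0, terms)
terms["svhn"] = math.log(0.5)  # extending the model adds one term; nothing else changes
after = encoder_log_joint(-100.0, terms)
print(after - before)
```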

Taking the hybrid form of knowledge codes into account, as in Fig. 1, the model is further factored as follows:


Here, $p(e_i \mid c_i)$ and $p(e_i \mid z)$ act as discriminators for binary variable $e_i$, conditioned on $c_i$ and $z$ respectively.

Finally, the factors of the encoder and decoder in (3) - (7) are parametrized in terms of neural networks. Accordingly, denoting the parameters of the encoder and decoder by $\theta$, in what follows we write the parametrized model encoder and decoder as $p^{\text{enc}}_\theta(T, K)$ and $p^{\text{dec}}_\theta(T, K)$. (We use this more concise notation for the encoder and decoder, except where we need the explicit factorization in terms of $z$, $c_i$, and $e_i$.) Details of our implementation are described in Sec. 3.

2.2 Learning

We would like to train a parametric model of the joint distribution $p(T, K)$ with the dual encoder/decoder factorization defined in Eqs. (3) - (7). Following the success of (Dinh et al., 2014, 2016; Kingma & Dhariwal, 2018) with high-dimensional distributions, we opt to estimate the model parameters using maximum likelihood.

We aim to learn a single probabilistic model of $p(T, K)$, comprising a consistent encoder-decoder, with the factorization given above, and a shared flow $f$. To do so, we define a joint distribution $p_\theta(T, K)$ with parameters $\theta$. Expressed as a linear mixture, $p_\theta$ is randomly selected to be $p^{\text{enc}}_\theta$ or $p^{\text{dec}}_\theta$ with equal probability, i.e.,

$$p_\theta(T, K) = \tfrac{1}{2}\, p^{\text{enc}}_\theta(T, K) + \tfrac{1}{2}\, p^{\text{dec}}_\theta(T, K). \qquad (8)$$

Choosing the mixing coefficients to be equal reflects our assumption of a dual encoder/decoder parametrization of the same underlying joint distribution. There are other ways to combine $p^{\text{enc}}_\theta$ and $p^{\text{dec}}_\theta$ into a single model; we chose this particular formulation because it yields a very effective learning algorithm.
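Evaluating the equal-weight mixture of Eq. (8) directly from densities underflows for high-dimensional data, because per-image log-densities are in the thousands of nats. Here is a minimal sketch of the standard log-sum-exp trick applied to the two log-densities. The function name is hypothetical.

```python
import math


def log_mixture(log_p_enc, log_p_dec):
    """Stable log(0.5 * p_enc + 0.5 * p_dec), given log-densities of the
    encoder and decoder models (log-sum-exp trick)."""
    m = max(log_p_enc, log_p_dec)
    # Subtracting the max keeps both exponentials in [0, 1], avoiding underflow.
    return m + math.log(0.5 * math.exp(log_p_enc - m) + 0.5 * math.exp(log_p_dec - m))


print(log_mixture(-3000.0, -3000.0))
```

Computing `math.log(0.5 * math.exp(-3000.0) + 0.5 * math.exp(-3000.0))` naively fails instead: both exponentials underflow to zero and `math.log(0.0)` raises an error.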

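Learning maximum likelihood parameters entails maximizing $\mathbb{E}_{p_{\text{data}}}[\log p_\theta(T, K)]$ with respect to $\theta$; equivalently,

$$\theta^* = \arg\max_\theta \, \mathbb{E}_{p_{\text{data}}}\left[\log p_\theta(T, K)\right] = \arg\min_\theta \, H(p_{\text{data}}, p_\theta), \qquad (9)$$

where $H(p_{\text{data}}, p_\theta) = -\mathbb{E}_{p_{\text{data}}}[\log p_\theta(T, K)]$ is the usual cross-entropy. Instead of optimizing the negative cross entropy directly, which can be numerically challenging, here we advocate the optimization of a lower bound on $-H(p_{\text{data}}, p_\theta)$. Using Jensen's inequality it is straightforward to show that $\log p_\theta(T, K) \ge \tfrac{1}{2} \log p^{\text{enc}}_\theta(T, K) + \tfrac{1}{2} \log p^{\text{dec}}_\theta(T, K)$, and as a consequence,

$$-H(p_{\text{data}}, p_\theta) \ge \tfrac{1}{2}\, \mathbb{E}_{p_{\text{data}}}\left[\log p^{\text{enc}}_\theta(T, K) + \log p^{\text{dec}}_\theta(T, K)\right]. \qquad (10)$$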
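The lower bound turns out to be very useful because, among other things, it encourages consistency between the encoder and decoder. To see this, we examine the bound in greater detail. With some algebraic manipulation, ignoring expectation in Eq. (10), one can derive the following:

$$\tfrac{1}{2} \log p^{\text{enc}}_\theta(T, K) + \tfrac{1}{2} \log p^{\text{dec}}_\theta(T, K) = \log p_\theta(T, K) + \log \frac{\sqrt{p^{\text{enc}}_\theta(T, K)\, p^{\text{dec}}_\theta(T, K)}}{\tfrac{1}{2}\left(p^{\text{enc}}_\theta(T, K) + p^{\text{dec}}_\theta(T, K)\right)}. \qquad (11)$$

This implies that maximization of the lower bound (the expectation of the LHS) entails maximization of the expectation of the two terms on the RHS, the first of which is $\log p_\theta(T, K)$. The expectation of the second term on the RHS of Eq. (11) can be viewed as a regularizer that encourages the encoder and decoder to assign similar probability density to each datapoint. Importantly, it obtains its upper bound of zero when $p^{\text{enc}}_\theta = p^{\text{dec}}_\theta$ (a geometric mean never exceeds the corresponding arithmetic mean), in which case the inequality in Eq. (10) becomes equality. In practice, we find the bound to be tight.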
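The per-datapoint decomposition can be checked numerically. The bound is the average of the encoder and decoder log-densities. The gap to the mixture log-density is the log of the geometric-to-arithmetic-mean ratio, which is never positive and vanishes exactly when the two densities agree. This is a minimal sketch with a hypothetical function name that operates on scalar log-densities.

```python
import math


def bound_terms(log_p_enc, log_p_dec):
    """Return (lower_bound, log_mixture, gap) for one datapoint, where
    lower_bound = 0.5 * log p_enc + 0.5 * log p_dec,
    log_mixture = log(0.5 * p_enc + 0.5 * p_dec),
    gap         = lower_bound - log_mixture = log(geometric mean / arithmetic mean) <= 0.
    """
    lower = 0.5 * log_p_enc + 0.5 * log_p_dec
    m = max(log_p_enc, log_p_dec)  # log-sum-exp for numerical stability
    log_mix = m + math.log(0.5 * math.exp(log_p_enc - m) + 0.5 * math.exp(log_p_dec - m))
    return lower, log_mix, lower - log_mix


for a, b in [(-5.0, -5.0), (-1.0, -4.0), (-700.0, -710.0)]:
    print(a, b, bound_terms(a, b)[2])
```

The gap printed for the equal pair is zero. For the unequal pairs it is negative and grows in magnitude with the disagreement between encoder and decoder. This is the regularizing effect described above.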

It is also interesting to note that $-H(p_{\text{data}}, p_\theta)$ itself is a lower bound on $-H(p_{\text{data}})$, since $H(p, q) \ge H(p)$ for any distributions $p$ and $q$. If $p_{\text{data}}$ satisfies the factorization of the TzK model in Eqs. (3) - (7), then the entropy of the joint distribution can be expressed as

$$H(T, K) = H(T) + \sum_{i=1}^{N} H(k_i) - \tfrac{1}{2}\, I(T; K) - \tfrac{1}{2} \sum_{i=1}^{N} I(T; k_i), \qquad (12)$$

where $H(T)$ and $H(k_i)$ denote entropy of marginal distributions, and $I(T; K)$ and $I(T; k_i)$ denote mutual information, for which all expectations are with respect to $p_{\text{data}}$. (The derivation of Eq. (12) is given in the supplemental material.) Eq. (12) suggests that maximizing the MI between observations and latent codes here follows from a design choice, for a model that can equally well "understand" (encode) and "express" (decode) an independent set of latent codes (as in Eqs. (3) and (4)), within a shared observation domain.
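An entropy decomposition of this kind can be sanity-checked on a toy discrete distribution in which knowledge types are both conditionally independent given the observation and marginally independent. The sketch assumes a hypothetical construction. The observation is a pair of independent bits $T = (a, b)$, and each knowledge variable is a noisy copy of one bit. Under these two assumptions, $H(T, K) = H(T) + \sum_i H(k_i) - \sum_i I(T; k_i)$ and $H(T, K) = H(T) + \sum_i H(k_i) - I(T; K)$ both hold, and their average gives the equal-weight form. The code computes every term from the joint table.

```python
import itertools
import math
from collections import defaultdict


def entropy(dist):
    """Shannon entropy in bits of a dict mapping outcomes to probabilities."""
    return -sum(p * math.log2(p) for p in dist.values() if p > 0)


def marginal(joint, idx):
    """Marginalize a joint dict keyed by tuples onto the positions in idx."""
    out = defaultdict(float)
    for outcome, p in joint.items():
        out[tuple(outcome[i] for i in idx)] += p
    return dict(out)


def toy_joint(pa=0.3, pb=0.6, q1=0.1, q2=0.25):
    """Hypothetical joint over (a, b, k1, k2) with observation T = (a, b).

    k1 is bit a flipped with probability q1 and k2 is bit b flipped with
    probability q2, so k1 and k2 are independent given T and also marginally.
    """
    joint = defaultdict(float)
    for a, b, n1, n2 in itertools.product((0, 1), repeat=4):
        p = (pa if a else 1 - pa) * (pb if b else 1 - pb)
        p *= (q1 if n1 else 1 - q1) * (q2 if n2 else 1 - q2)
        joint[(a, b, a ^ n1, b ^ n2)] += p
    return dict(joint)


T_IDX, K_IDX = (0, 1), (2, 3)
joint = toy_joint()
h_tk = entropy(joint)                                  # H(T, K)
h_t = entropy(marginal(joint, T_IDX))                  # H(T)
h_k = [entropy(marginal(joint, (j,))) for j in K_IDX]  # H(k_i)
mi_t_K = h_t + entropy(marginal(joint, K_IDX)) - h_tk  # I(T; K)
mi_t_k = [h_t + h - entropy(marginal(joint, T_IDX + (j,)))  # I(T; k_i)
          for h, j in zip(h_k, K_IDX)]
rhs = h_t + sum(h_k) - 0.5 * mi_t_K - 0.5 * sum(mi_t_k)
print(f"H(T,K) = {h_tk:.6f}, decomposition = {rhs:.6f}")
```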

We claim that the assumption of independent latent codes is relatively mild, and has little effect on the ability of the model to represent for a random variable over the same domain as . A sufficiently expressive flow will allow for for arbitrary , and (Dinh et al., 2014). Effectively, we approximate the relationship between factors of $K$ by learning the relation between conditional distributions of independent factors over the same observation domain. Although such an approximation may not exist for priors, it is effective when dealing with conditional distributions. As we demonstrate below, TzK can learn meaningful representations of the joint knowledge $K$.

3 Implementation

The TzK model comprises the probability distributions defined in Eqs. (3) - (7). Each can be treated as a black-box object with the functionality of a probability density estimator, returning $p(a \mid b)$ given $a$ and $b$, and a sampler, returning $a \sim p(a \mid b)$ given $b$. The implementation choices outlined below were made for ease and efficiency of training.

We adopt a Glow-based architecture for probability density estimators, using reparametrization (Papaspiliopoulos et al., 2003; Williams, 1992) and back-propagation with Monte Carlo (Rezende et al., 2014) for efficient gradient-based optimization (e.g., see (Rezende & Mohamed, 2015)). Our flow architecture uses a fixed shuffle permutation rather than the invertible 1×1 convolution used in (Kingma & Dhariwal, 2018), as we found the latter to suffer from accumulated numerical error. We implemented TzK in PyTorch, using the Swish non-linearity (Ramachandran et al., 2018) instead of ReLU as the activation function. We found that the ReLU-based implementation converged more slowly because of the truncation of gradients for negative values.

We implemented separate $p(c_i \mid z)$ and $p(z \mid c_i)$ for each knowledge type $k_i$, with regressors from $z$ and $c_i$ to parameters of distributions over $c_i$ and $z$, respectively. In practice, we regress to the mean and diagonal covariance of a multi-dimensional Gaussian density. We implemented $p(e_i \mid c_i)$ and $p(e_i \mid z)$, discriminators for binary variable $e_i$ conditioned on $c_i$ and $z$ respectively, with regressors from $c_i$ and $z$ followed by a sigmoid to normalize the output value to be in $[0, 1]$. We refer to the prior flow as the $T$-flow, and the flows in each conditional prior as a $z$-flow.
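The conditional densities and discriminators described above reduce to two primitives: a diagonal Gaussian whose mean and log-variance are regressed from an input, and a sigmoid discriminator for a binary variable. This is a minimal pure-Python sketch. A hypothetical `linear` layer stands in for the neural-network regressors. The parameter names are illustrative, and the $z$-flow stacked on top of each Gaussian is omitted.

```python
import math


def diag_gaussian_logpdf(x, mean, log_var):
    """log N(x; mean, diag(exp(log_var))), summed over dimensions."""
    return sum(
        -0.5 * (math.log(2 * math.pi) + lv + (xi - mu) ** 2 / math.exp(lv))
        for xi, mu, lv in zip(x, mean, log_var)
    )


def linear(weights, bias, inputs):
    """Hypothetical linear regressor standing in for a neural network."""
    return [sum(w * v for w, v in zip(row, inputs)) + b for row, b in zip(weights, bias)]


def softplus(v):
    """Numerically stable log(1 + exp(v))."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def bernoulli_logpmf(e, logit):
    """log p(e) for e in {0, 1}, with p(e = 1) = sigmoid(logit)."""
    # log sigmoid(l) = -softplus(-l) and log(1 - sigmoid(l)) = -softplus(l)
    return -softplus(-logit) if e == 1 else -softplus(logit)


def log_p_c_given_z(c, z, params):
    """Regress mean and log-variance of p(c_i | z), then evaluate the Gaussian."""
    mean = linear(params["w_mean"], params["b_mean"], z)
    log_var = linear(params["w_logvar"], params["b_logvar"], z)
    return diag_gaussian_logpdf(c, mean, log_var)


def log_p_e_given_z(e, z, params):
    """Sigmoid discriminator for the binary existence variable e_i given z."""
    (logit,) = linear(params["w_disc"], params["b_disc"], z)
    return bernoulli_logpmf(e, logit)
```

Working with logits and `softplus` rather than applying the sigmoid and then taking a log avoids `log(0)` for confident discriminators.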

All experiments were executed on a single NVIDIA TITAN Xp GPU with 12GB of memory, and optimized with the PyTorch Adam optimizer (Kingma & Lei Ba, 2014), with default parameters $\beta_1 = 0.9$ and $\beta_2 = 0.999$, a warm-up scheduler (Vaswani et al., 2017), and a mini-batch size of 10. Further details are included in the supplemental material.
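The text does not spell out the exact warm-up schedule. This sketch assumes the inverse-square-root schedule with linear warm-up from Vaswani et al. (2017), using that paper's `d_model = 512` and `warmup = 4000` as hypothetical defaults. The learning rate grows linearly for `warmup` steps and then decays as the inverse square root of the step.

```python
def noam_lr(step, d_model=512, warmup=4000):
    """Inverse-square-root learning-rate schedule with linear warm-up
    (Vaswani et al., 2017); d_model and warmup are assumed values."""
    step = max(step, 1)  # avoid 0 ** -0.5 on the very first call
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


for s in (1, 1000, 4000, 16000):
    print(s, noam_lr(s))
```

In PyTorch, a schedule like this could be wrapped in `torch.optim.lr_scheduler.LambdaLR`. That class scales the optimizer's base learning rate by a user-supplied factor at each step.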

4 Experiments

To demonstrate the versatility of TzK we train on up to six image datasets (Table 1), in unsupervised and semi-supervised settings. All images were resized to $32 \times 32$ as needed, and MNIST images were centered and padded to $32 \times 32$. When using grayscale (GS) images in an RGB setting, the GS channel was duplicated in the R, G, and B channels.
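Two of the preprocessing steps can be sketched on plain nested lists: centering and padding a grayscale image, and duplicating its channel into R, G and B. This minimal sketch assumes a $32 \times 32$ target (the CIFAR10 resolution), zero padding, and a channels-first layout. The function names are hypothetical.

```python
def pad_to(image, size=32, fill=0):
    """Center a 2D grayscale image (list of rows) on a size x size canvas."""
    h, w = len(image), len(image[0])
    top, left = (size - h) // 2, (size - w) // 2
    canvas = [[fill] * size for _ in range(size)]
    for r in range(h):
        canvas[top + r][left:left + w] = image[r]
    return canvas


def gray_to_rgb(image):
    """Duplicate the single grayscale channel into R, G and B (channels-first)."""
    return [[row[:] for row in image] for _ in range(3)]


digit = [[1] * 28 for _ in range(28)]  # stand-in for a 28 x 28 MNIST image
rgb = gray_to_rgb(pad_to(digit))
print(len(rgb), len(rgb[0]), len(rgb[0][0]))
```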

In all experiments below the images, $T$, and class labels, $e_i$, for different tasks are given. The latent codes, $c_i$, are not. In this semi-supervised context we sample the missing $c_i$ according to the model. Specifically, at every mini-batch, we randomly choose $p^{\text{enc}}_\theta$ or $p^{\text{dec}}_\theta$ with equal probability. When $p^{\text{enc}}_\theta$ is chosen we sample $c_i$ from $p(c_i \mid z)$, and for $p^{\text{dec}}_\theta$ we return the marginal over $c_i$ with respect to the observed binary variable $e_i$.

We chose CIFAR10 and MNIST as targets for conditional model learning. Each comprises just 3.2% of the entire multi-data training set of 1,892,916 images. Table 2 gives performance benchmarks in terms of negative log-likelihood in bits per dimension (NLL) for existing flow-based models.
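NLL in bits per dimension is the per-image negative log-likelihood divided by the number of dimensions, converted from nats to bits by dividing by $\ln 2$. This makes scores comparable across image sizes. This minimal sketch assumes $32 \times 32$ RGB images (3,072 dimensions). It omits the dequantization correction that flow implementations typically apply to 8-bit pixel data.

```python
import math


def bits_per_dim(nll_nats, num_dims):
    """Convert a per-image negative log-likelihood in nats to bits per dimension."""
    return nll_nats / (num_dims * math.log(2))


dims = 32 * 32 * 3
print(bits_per_dim(3.54 * dims * math.log(2), dims))
```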

Dataset | Type | # Images (train / val) | % | Classes
CIFAR10 | RGB | 50,000 / 10,000 | 3.2 | 10
MNIST | GS | 60,000 / 10,000 | 3.2 | 10
Omniglot † | RGB | 19,280 / 13,180 | 1.7 | NA
SVHN † | RGB | 73,257 / 26,032 | 5.3 | 10
ImageNet | Varying RGB | 1,281,167 / 150,000 | 75.8 | 1000
Celeba † | RGB | 200,000 / NA | 10.8 | NA
Table 1: Datasets marked with † were used in unsupervised settings only. GS denotes grayscale images. The multi-data training set consists of all six datasets, namely, CIFAR10 (Krizhevsky et al., 2009), MNIST (LeCun et al., 1998), Omniglot (Lake et al., 2015), SVHN (Netzer et al., 2011), ImageNet (Russakovsky et al., 2015), Celeba (Liu et al., 2015). There are 1,892,916 images in total. % gives each dataset's fraction of the entire multi-data training set.
Dataset | Glow | FFJORD | RealNVP | TzK Prior | TzK Cond.
CIFAR10 | 3.35 | 3.4 | 3.49 | 3.54 | 2.99 *
MNIST | 1.05 †† | 0.99 | 1.06 †† | 1.11 | 1.02 * †
Table 2: Comparison of negative log-likelihood in bits per dimension (NLL) on test data (lower is better). * Results of the dataset conditional prior. † Model was trained on all 6 datasets (Table 1). We compare to Glow (Kingma & Dhariwal, 2018), FFJORD (Grathwohl et al., 2018), and RealNVP (Dinh et al., 2016). Results marked with †† are taken from (Grathwohl et al., 2018).

All learning occurred in an online fashion, adding new conditional knowledge types as needed. When training begins, we start with a model with no knowledge, i.e., $K = \emptyset$, which is just a Glow probability density estimator. As data are sampled for learning, new knowledge types are added only when observed, in a semi-supervised manner, i.e., the class label is given, but the latent code is not. In most of the experiments below the only class label used is the identity of the dataset from which the image was drawn.
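The online growth of the model, starting from no knowledge types and adding one whenever a new label first appears in a mini-batch, can be sketched as a small registry. This is a minimal sketch. `KnowledgeRegistry` and the `make_conditional` factory are hypothetical names, and the factory stands in for building a new conditional model (regressors, discriminators and a $z$-flow).

```python
class KnowledgeRegistry:
    """Hypothetical registry that creates a conditional model the first time
    a knowledge type (e.g., a dataset identity) is observed."""

    def __init__(self, make_conditional):
        self.make_conditional = make_conditional  # factory for a per-type model
        self.models = {}  # knowledge type -> conditional model; empty means K is empty

    def observe(self, labels):
        """Register unseen knowledge types in a mini-batch; return each sample's model."""
        for label in labels:
            if label not in self.models:
                self.models[label] = self.make_conditional(label)
        return [self.models[label] for label in labels]


registry = KnowledgeRegistry(lambda name: {"name": name})
registry.observe(["cifar10", "mnist", "cifar10"])
print(sorted(registry.models))
```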

Figure 2: Random training samples from CIFAR10 and MNIST.
Figure 3: Random samples from two baseline models, each trained with a single dataset: (a) CIFAR10 only, (b) MNIST only. The NLL for the CIFAR10 model is 3.54. The NLL for the MNIST model is 1.11.
Figure 4: Samples from (a) a model trained on all six datasets (multi-data), and (b) one trained on MNIST+Omniglot. Sample quality is similar to models trained solely on CIFAR10 and MNIST (Fig. 3), despite slightly higher NLL (3.6 for the multi-data model and 1.28 for the MNIST+Omniglot model). Samples are more diverse, however, reflecting the greater heterogeneity of the training data.

4.1 Baselines

Two baseline models are trained on CIFAR10 and MNIST, training samples for which are shown in Fig. 2. Each used a Glow architecture for the $T$-flow, with 512 channels, 32 steps, and 3 layers. (See (Kingma & Dhariwal, 2018) for more details.) These models give test NLL values of 3.54 and 1.11, comparable to the state of the art for flow-based models. Differences between our NLL numbers and those reported for Glow by others in Table 2 are presumably due to implementation and optimization details. Fig. 3 shows random samples from the two models, the quality of which compares well with training samples (Fig. 2).

When we train the same architecture on all 6 datasets (i.e., multi-data), we obtain an NLL of 3.6 when testing on CIFAR10. Random samples from this model are shown in Fig. 4a. One can clearly see the greater diversity of the training data, with images resembling faces and grayscale characters, for example. When the same architecture is trained on the union of MNIST and Omniglot, and tested on MNIST, the NLL is 1.28. Random samples of this model (Fig. 4b) again show greater diversity. Although the NLL numbers with these models, both learned from larger training sets, are slightly worse, the image quality remains similar to models trained on a single dataset (Fig. 3).

Figure 5: (a) Given a model trained solely on CIFAR10, interpolation in $z$ between random samples from CIFAR10, MNIST, and SVHN. Interpolation reveals regions of $z$ that correspond to relatively poor quality images. This occurs even when the interpolated images are visually similar, and reflects relatively sparse coverage of the high-dimensional image space. (b) Given a model trained on all six datasets (multi-data), the interpolation results are much better than those in (a). With more training data we obtain a denser model with visually better interpolation.

4.2 Interpolation - Visualizing Flow Expressiveness

Insight into the nature of the generative model can be gleaned from latent space interpolation. Here, given four images (observations $T$), we obtain latent space coordinates, $z = f(T)$. We then linearly interpolate in $z$ before mapping back to $T$ for visualization. In a flow-based generative model with a Gaussian prior on $z$, we expect interpolated points to have probability density at least as high as the lower of the two endpoint densities, since the Gaussian log-density is concave along any line.
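The expectation about interpolant density can be checked numerically for a standard Normal prior. Because its log-density is concave, every point on the segment between two latent vectors should be at least as probable as the less probable endpoint. This minimal sketch assumes a toy 64-dimensional latent space with random endpoints. The helper names are hypothetical.

```python
import math
import random


def std_normal_logpdf(z):
    """Log-density of a standard Normal prior on z, summed over dimensions."""
    return sum(-0.5 * (v * v + math.log(2 * math.pi)) for v in z)


def lerp(z0, z1, alpha):
    """Linear interpolation between two latent vectors."""
    return [(1 - alpha) * a + alpha * b for a, b in zip(z0, z1)]


def min_gap_along_path(z0, z1, steps=11):
    """Smallest log p(z_alpha) - min(log p(z0), log p(z1)) along the path.

    A non-negative value means no interpolant is less probable than the
    less probable endpoint, as expected for a log-concave prior.
    """
    floor = min(std_normal_logpdf(z0), std_normal_logpdf(z1))
    return min(
        std_normal_logpdf(lerp(z0, z1, k / (steps - 1))) - floor for k in range(steps)
    )


rng = random.Random(0)
dim = 64  # toy latent dimensionality (hypothetical)
z0 = [rng.gauss(0.0, 1.0) for _ in range(dim)]
z1 = [rng.gauss(0.0, 1.0) for _ in range(dim)]
print(min_gap_along_path(z0, z1))
```

The density guarantee applies in $z$ only. As Fig. 5 illustrates, high prior density does not ensure that the decoded images look like the training data.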

Despite Glow being a powerful model, the results in Fig. 5a reveal deficiencies. Training on CIFAR10 data produces a model that yields interpolated images that are not always characteristic of CIFAR10 (e.g., the darkened images in Fig. 5a). Even with color images (e.g., SVHN), which are expected to be represented reasonably well by a CIFAR10 model, there are regions of low-quality interpolants.

One would expect a model trained on the entire multi-data training set, rather than just CIFAR10, to yield a better probability flow, exhibiting denser coverage of image space. Consistent with this, Fig. 5b shows superior interpolation in $z$.

4.3 Specializing a $T$-Flow

In this section we further explore the benefits of unsupervised training over large heterogeneous datasets and the use of TzK for learning conditional models in an online manner. To that end, we assume a $T$-flow has been learned and then remains fixed while we learn one or more conditional models, as one might with unknown downstream tasks. The Glow-like architecture used for the $T$-flow (i.e., for $p(T)$) had 512 channels, 20 steps, and 3 layers, a weaker model than those in (Kingma & Dhariwal, 2018) and the baseline models above with 3 layers of 32 steps. The architecture used for the $z$-flow, for each of the conditional models (i.e., for $p(z \mid c_i)$), had one layer with just 4 steps.

In the first experiment the $T$-flow is trained solely on CIFAR10 data, entirely unsupervised. The $T$-flow was then frozen, and conditional models were learned, one for CIFAR10 and one for MNIST. Doing so exploits just one bit of supervisory information, namely, whether each training image originated from CIFAR10 or MNIST. Although this is a relatively weak form of supervision, the benefits are significant. The MNIST images serve as negative samples for the conditional CIFAR10 model, and vice versa. This allows the discriminators of the respective conditional models to learn tight conditional distributions.

Indeed, the resulting CIFAR10 conditional model exhibits a significant performance gain, with an NLL of 2.99 when evaluated on the CIFAR10 test set, on par with or better than the state of the art for CIFAR10, and a great improvement over the baseline $T$-flow (with 20 steps per layer), the NLL for which was 3.71 on the same test set. Fig. 6a shows random samples from the conditional model.

Figure 6: The ability of TzK to learn tight conditional priors is demonstrated here by freezing a $T$-flow trained on CIFAR10 only, then learning conditional priors using CIFAR10 and MNIST. (a) Random samples from the CIFAR10 conditional. When tested on CIFAR10, the NLL for this model is just 2.99. (b) Random samples from the MNIST conditional, which are surprisingly good given that MNIST data was not used to learn the flow. The NLL for the MNIST conditional, tested on MNIST, is 1.33.
Figure 7: TzK offers a powerful framework to specialize a generative flow model trained in an unsupervised fashion on a large heterogeneous dataset. By learning tight conditional priors, these models are comparable to those trained end-to-end on a single dataset. Here, we train two conditional priors concurrently: (a) CIFAR10 conditional, (b) MNIST conditional. Although trained concurrently, both conditionals share the same latent representation $z$. The NLL for CIFAR10 (a) is 3.1. The NLL for MNIST (b) is 1.02.

Just as surprising is the performance of the MNIST conditional, even though the CIFAR10 data on which the $T$-flow was trained did not contain images resembling the grayscale data of MNIST. Despite this, the conditional model was able to isolate regions of the latent space representing predominantly grayscale MNIST-like images, random samples of which are shown in Fig. 6b. When evaluated on MNIST data, the conditional model produced an NLL of 1.33. While these results are impressive, one would not expect a flow trained on CIFAR10 to provide a good latent representation for many different image domains, like MNIST.

In the next experiment we train a much richer $T$-flow from the entire multi-data training set of 1,892,916 images, again unsupervised. Once it is frozen, we again learn conditional models for CIFAR10 and MNIST. Despite MNIST and Omniglot representing a small fraction of the training set, the MNIST conditional model exhibits state-of-the-art performance, with an NLL of 1.02 on the MNIST test set. Random samples of the model are shown in Fig. 7b. Similarly, the CIFAR10 conditional model exhibits state-of-the-art performance, with an NLL of 3.1. While slightly worse than the model trained from CIFAR10, it is still much better than our benchmark $T$-flow, with 3 layers of 32 steps and an NLL of 3.54. Random samples from this CIFAR10 conditional model are shown in Fig. 7a.

In terms of cost, the time required to train the conditional models is roughly half the time needed to train our baseline $T$-flow model (or equivalently Glow). Freezing the $T$-flow allows for asynchronous optimization of all conditional priors, resulting in significant gains in training time, while still maintaining a model of the joint probability. That is, conditional models can be trained in parallel, so training time does not scale with the number of knowledge types. End-to-end training also benefits from this parallelism, but requires synchronization for the shared $T$-flow. Finally, training a weaker $T$-flow with 20 steps and 3 layers is marginally faster than training a more expressive flow with 32 steps.
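Once the shared flow is frozen, each conditional prior depends only on the latents $z$ of its own data. Conditionals can therefore be fit without coordinating with each other. This sketch illustrates the pattern with a thread pool and a deliberately trivial stand-in for training: fitting a 1D Gaussian to precomputed latents. `fit_conditional` and `train_conditionals` are hypothetical names. Real training would run gradient-based optimization for each conditional, typically in separate processes or on separate devices.

```python
from concurrent.futures import ThreadPoolExecutor
from statistics import fmean, pvariance


def fit_conditional(latents):
    """Toy stand-in for training one conditional prior on frozen latents z."""
    return {"mean": fmean(latents), "var": pvariance(latents)}


def train_conditionals(latents_by_type, max_workers=4):
    """Fit every knowledge type's conditional independently and concurrently."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {name: pool.submit(fit_conditional, zs)
                   for name, zs in latents_by_type.items()}
        return {name: f.result() for name, f in futures.items()}


print(train_conditionals({"cifar10": [0.0, 2.0], "mnist": [1.0, 1.0, 1.0]}))
```

No conditional reads another's state, so adding knowledge types adds independent jobs rather than lengthening a single sequential run.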

4.4 End-to-End Hierarchical training

We next consider a hierarchical extension to TzK for learning larger models. Suppose, for example, one wanted a TzK model with 10 conditional priors, one for each MNIST digit. Conditioning on 10 classes in TzK would require 10 independent discriminators, and 40 independent regressors (2 for $p(c_i \mid z)$ and 2 for $p(z \mid c_i)$, per knowledge type $k_i$). This does not scale well to large numbers of conditional priors.

As an alternative, one can compose TzK hierarchically. For example, the first TzK model could learn a conditional prior for MNIST images in general, while the second model provides further specialization to digit-specific priors. In particular, as depicted in Fig. 8, the second TzK model takes as input observations the latent codes from the MNIST conditional model, and then learns a second TzK model comprising a new latent space, on which the 10 digit-specific priors are learned. The key advantage of this hierarchical TzK model is that the latent code space for the generic MNIST prior in the first TzK model is low-dimensional, so training the second TzK model with 10 conditional priors is much more efficient in terms of both training time and memory.

Figure 8: The modular nature of TzK allows one to build a hierarchical model, in which the latent code $c$ of one TzK model serves as the observation $T$ of a domain-specific TzK model: (a) encoder, (b) decoder. The joint objective is a simple summation of the multiple objectives. This results in a model that supports fine-grained control through sub-division of likelihood manifolds.
(a) MNIST "0"
(b) MNIST "1"
Figure 9: A -flow was first learned from all datasets and then frozen, after which an MNIST conditional over was learned (the task) jointly with an MNIST digit conditional over (the task domain). This procedure allows efficient joint training of 10 conditional priors over a low-dimensional , as depicted in Fig. 8. The NLL for the MNIST conditional prior (over ) is 1.17. The NLL for the digit class conditional prior (over ) is 1.06.

To implement this idea, the -flow in the first TzK model comprised 3 layers with 10 steps, and 512 channels, a weaker flow than those used in previous experiments with 20 or 32 steps. The -flow for the MNIST conditional model consisted of 512 channels and 1 layer of 4 steps, with a 10D latent code space for . The second stage TzK model then maps the 10D vectors to a latent space with a probability flow comprising 1 layer of 4 steps with 64 channels, on which 10 conditional priors are learned, each with a -flow comprising one more layer of 4 steps and 64 channels, with 2D latent codes . Fig. 8 depicts the model.
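The hyperparameters in the preceding paragraph can be collected into a small configuration structure. The class and field names below are hypothetical and chosen for illustration only; they are not the authors' code. Stage 2 consumes the 10D codes produced by stage 1.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class FlowConfig:
    """Glow-style flow hyperparameters (layers = multi-scale levels)."""
    layers: int
    steps: int
    channels: int


@dataclass(frozen=True)
class StageConfig:
    """One TzK stage: an input flow plus the flow(s) over latent codes."""
    input_flow: FlowConfig
    code_flow: FlowConfig
    code_dim: int
    num_conditional_priors: int


# Hypothetical encoding of the two-stage model described in the text.
stage1 = StageConfig(
    input_flow=FlowConfig(layers=3, steps=10, channels=512),  # trained first, then frozen
    code_flow=FlowConfig(layers=1, steps=4, channels=512),    # MNIST conditional
    code_dim=10,
    num_conditional_priors=1,
)
stage2 = StageConfig(
    input_flow=FlowConfig(layers=1, steps=4, channels=64),    # maps the 10D stage-1 codes
    code_flow=FlowConfig(layers=1, steps=4, channels=64),     # one per digit prior
    code_dim=2,
    num_conditional_priors=10,
)
hierarchy: List[StageConfig] = [stage1, stage2]
```

Frozen dataclasses keep the configuration immutable, so a stage description cannot be altered accidentally after the model is built.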

Training this two-stage TzK model entails optimization of the sum of two TzK losses, each defined in Eq. (10). The first -flow was learned from the multi-dataset training set and then fixed, as this flow is somewhat expensive to train. Everything else (i.e., the -flow in the first TzK model, and all components of the second TzK model) was trained jointly end-to-end. The trained model had NLL 1.17 for the MNIST prior (first TzK model), and 1.06 for digit-specific priors; class-conditional samples are shown in Fig. 9.
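The composition of the two objectives can be sketched with a toy in which each stage's loss is replaced by a plain flow NLL (an assumption: the real per-stage loss of Eq. (10) also involves discriminators and conditional priors, which are omitted). The function names are hypothetical. Stage 2 is evaluated on the latent output of stage 1, and the total is the sum of the two losses.

```python
import numpy as np


def affine_flow_nll(x, scale, shift):
    """Mean NLL (nats per example) of x under an elementwise affine flow to N(0, I).

    z = scale * x + shift, so log p(x) = log N(z; 0, I) + sum(log|scale|).
    Returns the mean NLL and the latent z.
    """
    z = scale * x + shift
    log_base = -0.5 * np.sum(z ** 2 + np.log(2.0 * np.pi), axis=-1)
    log_det = np.sum(np.log(np.abs(scale)))
    return -np.mean(log_base + log_det), z


def two_stage_loss(x, frozen_stage1, stage2):
    """Sum of the two stage losses; stage 2 models the latent output of stage 1.

    frozen_stage1 is treated as a constant: in an autodiff framework no
    gradient would be applied to it, and only stage2 would be updated.
    """
    loss1, z1 = affine_flow_nll(x, *frozen_stage1)
    loss2, _ = affine_flow_nll(z1, *stage2)
    return loss1 + loss2
```

Summing the losses keeps the joint model a single maximum-likelihood objective, while freezing stage 1 simply removes its parameters from the set being optimized.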

This experiment also provides empirical evidence of consistency between the encoder and decoder within a TzK model. Quantitatively, the digit-conditional discriminative model had 0.87 classification accuracy over . Although this is far from state-of-the-art classification accuracy, we note that we allowed more than one class to have high probability, rather than choosing the single category with the highest probability; instead, the model learned a joint distribution over 10 independent classes. Qualitatively, the consistency can be observed in the samples in Fig. 9, which are strongly correlated with the classification accuracy. In other words, the samples are roughly in line with the 0.87 classification accuracy. In effect, the dual structure of TzK results in consistency between the generative and discriminative components. More importantly, this allows a TzK generative model to be evaluated with multiple criteria (i.e., one per conditional prior) in addition to NLL, which is not necessarily a meaningful measure of the performance of a generative model.

We perform one additional experiment with a CIFAR10 conditional. This time the model predicted a 4-bit binary representation of the label (e.g., label 3 = 0011). The trained model had an NLL of 3.64 for the CIFAR10 conditional, and 0.74 classification accuracy for the domain task model. Accuracy was computed independently per bit, which acknowledges similarity between classes given the arbitrary division of classes induced by the binary representation of the label. This last experiment again demonstrates the ability of TzK to learn compositional structure that represents a joint distribution.
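The 4-bit label encoding and a per-bit accuracy can be sketched as follows. The most-significant-bit-first ordering and the 0.5 decision threshold are assumptions of this sketch, and the function names are hypothetical.

```python
import numpy as np


def labels_to_bits(labels, num_bits=4):
    """Encode integer labels as bit vectors, most significant bit first (3 -> 0011)."""
    labels = np.asarray(labels, dtype=np.int64)
    shifts = np.arange(num_bits - 1, -1, -1)
    return (labels[:, None] >> shifts) & 1


def per_bit_accuracy(bit_probs, labels, threshold=0.5):
    """Accuracy of each bit independently, given predicted P(bit = 1) per bit."""
    bit_probs = np.asarray(bit_probs)
    targets = labels_to_bits(labels, num_bits=bit_probs.shape[1])
    preds = (bit_probs > threshold).astype(np.int64)
    return (preds == targets).mean(axis=0)


if __name__ == "__main__":
    probs = np.array([[0.1, 0.2, 0.9, 0.8],   # predicts 0011 (label 3)
                      [0.9, 0.1, 0.2, 0.3]])  # predicts 1000, target 1001 (label 9)
    print(per_bit_accuracy(probs, [3, 9]))    # [1.  1.  1.  0.5]
```

Because each bit is scored on its own, a prediction that gets three of four bits right still earns credit on those bits, unlike a single 10-way accuracy.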

5 Conclusions

This paper introduces a versatile conditional generative model based on probability flows. It supports compositionality without a priori knowledge of the number of classes or the relationships between classes. Trained with maximum likelihood, it provides efficient inference and sampling from class-conditionals or the joint distribution. This allows one to train generative models from multiple heterogeneous datasets, while retaining strong prior models over subsets of the data (e.g., from a single dataset, class label, or attribute).

The resulting model is efficient to train, either end-to-end, in two phases (unsupervised flow followed by conditional models), or hierarchically. In addition, TzK offers an alternative motivation for the use of MI in ML models, as a natural term that arises from the assumption that the joint distribution over observations and multiple latent codes has two equally plausible factorizations, the encoder and the decoder. Our experiments focus on models learned from six different image datasets, with a relatively weak Glow architecture, conditioning on various types of knowledge, including the identity of the source dataset, or class labels. This yields log likelihood comparable to state-of-the-art, with compelling samples from conditional priors.


We thank Ethan Fetaya, James Lucas, Alireza Makhzani, Leonid Sigal, and Kevin Swersky for helpful comments on this work. We also thank the Canadian Institute for Advanced Research and NSERC Canada for financial support.


  • Agakov & Barber (2003) Agakov, F. and Barber, D. The IM algorithm: a variational approach to information maximization. In NIPS, pp. 201–208, 2003.
  • Belghazi et al. (2018) Belghazi, I., Rajeswar, S., Baratin, A., Hjelm, R. D., and Courville, A. MINE: Mutual information neural estimation. In ICML, 2018.
  • Bengio & Bengio (1999) Bengio, Y. and Bengio, S. Modeling high dimensional discrete data with multi-layer neural networks. In NIPS, pp. 400–406, 1999.
  • Bengio et al. (2013) Bengio, Y., Courville, A., and Vincent, P. Representation learning: A review and new perspectives. IEEE TPAMI, 35(8):1798–1828, 2013.
  • Chen et al. (2018) Chen, T. Q., Rubanova, Y., Bettencourt, J., and Duvenaud, D. Neural ordinary differential equations. In NIPS, 2018.
  • Chen et al. (2016) Chen, X., Duan, Y., Houthooft, R., Schulman, J., Sutskever, I., and Abbeel, P. InfoGAN: Interpretable representation learning by information maximizing generative adversarial nets. In NIPS, 2016.
  • Dinh et al. (2014) Dinh, L., Krueger, D., and Bengio, Y. NICE: Non-linear independent components estimation. arXiv:1410.8516, 2014.
  • Dinh et al. (2016) Dinh, L., Sohl-Dickstein, J., and Bengio, S. Density estimation using Real NVP. arXiv:1605.08803, 2016.
  • Dupont (2018) Dupont, E. Learning disentangled joint continuous and discrete representations. arXiv:1804.00104, 2018.
  • Germain et al. (2015) Germain, M., Gregor, K., Murray, I., and Larochelle, H. MADE: Masked autoencoder for distribution estimation. In ICML, 2015.
  • Gomez et al. (2017) Gomez, A. N., Ren, M., Urtasun, R., and Grosse, R. B. The Reversible Residual Network: Backpropagation Without Storing Activations. In NIPS, 2017.
  • Goodfellow et al. (2014) Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., and Bengio, Y. Generative adversarial nets. In NIPS, pp. 2672–2680, 2014.
  • Grathwohl et al. (2018) Grathwohl, W., Chen, R. T. Q., Bettencourt, J., Sutskever, I., and Duvenaud, D. K. FFJORD: free-form continuous dynamics for scalable reversible generative models. CoRR, abs/1810.01367, 2018.
  • Kingma & Dhariwal (2018) Kingma, D. P. and Dhariwal, P. Glow: Generative Flow with Invertible 1x1 Convolutions. In NIPS, 2018.
  • Kingma & Lei Ba (2014) Kingma, D. P. and Lei Ba, J. ADAM: A method for stochastic optimization. In ICLR, 2014.
  • Kingma & Welling (2013) Kingma, D. P. and Welling, M. Auto-Encoding Variational Bayes. In ICLR, 2013.
  • Kingma et al. (2016) Kingma, D. P., Salimans, T., Jozefowicz, R., Chen, X., Sutskever, I., and Welling, M. Improving variational inference with inverse autoregressive flow. In NIPS, 2016.
  • Klys et al. (2018) Klys, J., Snell, J., and Zemel, R. Learning latent subspaces in variational autoencoders. In NIPS, 2018.
  • Krizhevsky et al. (2009) Krizhevsky, A., Nair, V., and Hinton, G. CIFAR10. Technical report, University of Toronto, 2009.
  • Lake et al. (2015) Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B. Human-level concept learning through probabilistic program induction. Science, 350(6266):1332–8, 2015.
  • Larochelle & Murray (2011) Larochelle, H. and Murray, I. The Neural Autoregressive Distribution Estimator. In AISTATS, pp. 29–37, 2011.
  • LeCun et al. (1998) LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proc. IEEE, 86(11):2278–2324, 1998.
  • Liu et al. (2015) Liu, Z., Luo, P., Wang, X., and Tang, X. Deep learning face attributes in the wild. In ICCV, pp. 3730–3738, 2015.
  • Makhzani (2018) Makhzani, A. Implicit autoencoders. CoRR, abs/1805.09804, 2018.
  • Makhzani et al. (2015) Makhzani, A., Shlens, J., Jaitly, N., Goodfellow, I., and Frey, B. Adversarial autoencoders. In ICLR Workshop, 2015.
  • Netzer et al. (2011) Netzer, Y., Wang, T., Coates, A., Bissacco, A., Wu, B., and Ng, A. Reading digits in natural images with unsupervised feature learning. In NIPS Workshop on Deep Learning and Unsupervised Feature Learning, 2011.
  • Oliver et al. (2018) Oliver, A., Odena, A., Raffel, C., Cubuk, E. D., and Goodfellow, I. J. Realistic evaluation of deep semi-supervised learning algorithms. arXiv:1804.09170, 2018.
  • Papamakarios et al. (2017) Papamakarios, G., Pavlakou, T., and Murray, I. Masked autoregressive flow for density estimation. In NIPS, pp. 2335–2344, 2017.
  • Papaspiliopoulos et al. (2003) Papaspiliopoulos, O., Roberts, G. O., and Skold, M. Non-centered parameterisations for hierarchical models and data augmentation. In Bayesian Statistics, pp. 307–326. Oxford University Press, 2003.
  • Ramachandran et al. (2018) Ramachandran, P., Zoph, B., and Le, Q. V. Searching for activation functions. In ICLR, 2018.
  • Rezende & Mohamed (2015) Rezende, D. J. and Mohamed, S. Variational inference with normalizing flows. In ICML, 2015.
  • Rezende et al. (2014) Rezende, D. J., Mohamed, S., and Wierstra, D. Stochastic backpropagation and approximate inference in deep generative Models. In ICML, 2014.
  • Russakovsky et al. (2015) Russakovsky, O., Deng, J., Su, H., Krause, J., Satheesh, S., Ma, S., Huang, Z., Karpathy, A., Khosla, A., Bernstein, M., Berg, A. C., and Fei-Fei, L. ImageNet large scale visual recognition challenge. Int. J. Computer Vision, 115(3):211–252, 2015.
  • Schmah et al. (2009) Schmah, T., Hinton, G. E., Small, S. L., Strother, S., and Zemel, R. S. Generative versus discriminative training of RBMs for classification of fMRI images. In NIPS, pp. 1409–1416, 2009.
  • van den Berg et al. (2018) van den Berg, R., Hasenclever, L., Tomczak, J. M., and Welling, M. Sylvester normalizing flows for variational inference. arXiv:1803.05649, 2018.
  • Vaswani et al. (2017) Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., and Polosukhin, I. Attention is all you need. In NIPS, pp. 5998–6008, 2017.
  • Williams (1992) Williams, R. J. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning, 8(3-4):229–256, 1992.

Appendix A Formulation Details

A.1 Encoder and Decoder Consistency

Here we discuss in greater detail how the learning algorithm encourages consistency between the encoder and decoder of a TzK model, beyond the fact that they are fit to the same data, and have a consistent factorization (see Eqs. (3) - (7)). To this end we expand on several properties of the model and the optimization procedure.

One important property of the optimization follows directly from the difference between the gradient of the lower bound in Eq. (10) and the gradient of the cross-entropy objective. By moving the gradient operator inside the expectation using reparameterization, one can express the gradient of the lower bound in terms of the gradient of the and the regularization term in Eq. (11). That is, with some manipulation one obtains


Consistent with the regularization term in Eq. (11), this shows that for any data point where a gap exists, the gradient applied to grows with the gap, while correspondingly less weight is placed on the gradient applied to . The opposite is true when . In both cases, this behaviour encourages consistency between the encoder and decoder. Empirically, we find that the encoder and decoder become reasonably consistent very early in the optimization process.

Numerical Stability:

Instead of using a lower bound, one might consider direct optimization of Eq. (9). To do so, one must convert and to and . Unfortunately, this is likely to produce numerical errors, especially with 32-bit floating-point precision on GPUs. While various tricks can reduce numerical instability, we find that using the lower bound eliminates the problem while providing the additional benefits outlined above.
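To illustrate why converting log-probabilities into probabilities is fragile, the sketch below assumes, for illustration only, that two densities must be combined additively (a hypothetical stand-in, not necessarily the exact form of Eq. (9)). Image log-densities are typically thousands of nats in magnitude, so exponentiating them underflows to zero. The log-sum-exp identity, available in NumPy as `np.logaddexp`, is one standard trick that avoids this.

```python
import numpy as np

# Hypothetical per-image log-densities in nats; large negative values are
# typical for high-dimensional data.
log_p_a = np.float32(-3000.0)
log_p_b = np.float32(-3010.0)

# Naive: convert to probabilities, add, take the log.
# exp(-3000) underflows to 0 in float32 (and float64), so the result is -inf.
with np.errstate(divide="ignore"):
    naive = np.log(np.exp(log_p_a) + np.exp(log_p_b))

# Stable: log(e^a + e^b) = max(a, b) + log1p(exp(-|a - b|)), computed in log space.
stable = np.logaddexp(log_p_a, log_p_b)  # finite, approximately -3000.0
```

The stable value is dominated by the larger term; the correction log(1 + e^-10) is about 4.5e-5, which is below float32 resolution at this magnitude.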


As mentioned earlier, there may be several ways to combine the encoder and decoder into a single probability model. One possibility we considered, as an alternative to in Eq. (8), is


where is the partition function. One could then define the objective to be the cross-entropy as above with a regularizer to encourage to be close to 1, and hence consistency between the encoder and decoder. This, however, requires a good approximation to the partition function. Our choice of

avoids the need for an accurate approximation of the loss value by using reparameterization, which yields unbiased, low-variance gradients regardless of the accuracy of the approximation of the loss value.

A.2 TzK Entropy and Mutual Information

This section provides some context and a derivation for Eq. (12).

As discussed in Sec. 2.2, probability density normalizing flows allow efficient learning of arbitrary distributions by learning a mapping from independent components to a joint target distribution (Dinh et al., 2014). It follows that, for a sufficiently expressive -flow, the TzK model factorization assumed in Eqs. (3) - (7) does not pose a fundamental limitation when learning joint, conditional distributions. In other words, it is likely that there exist flows under which can be factored according to the encoder and decoder factorizations in TzK.
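The claim rests on the change-of-variables formula behind normalizing flows: if x = g(u) with u ~ N(0, I), then log p(x) = log N(g⁻¹(x); 0, I) + log|det ∂g⁻¹/∂x|. The sketch below checks this for the simplest flow, an elementwise affine map, where the result must match a diagonal Gaussian in closed form. It is an illustration of the formula, not the TzK flow; function names are hypothetical.

```python
import numpy as np


def affine_flow_log_density(x, mu, sigma):
    """log p(x) for x = mu + sigma * u, u ~ N(0, I), via change of variables.

    Inverse map: u = (x - mu) / sigma, with log|det du/dx| = -sum(log sigma).
    """
    u = (x - mu) / sigma
    log_base = -0.5 * np.sum(u ** 2 + np.log(2.0 * np.pi), axis=-1)
    log_det_inverse = -np.sum(np.log(sigma))
    return log_base + log_det_inverse


def diag_gaussian_log_density(x, mu, sigma):
    """Closed-form log N(x; mu, diag(sigma^2)) for comparison."""
    return -0.5 * np.sum(
        ((x - mu) / sigma) ** 2 + np.log(2.0 * np.pi * sigma ** 2), axis=-1
    )
```

Deeper flows compose many such invertible maps, and their log-determinants add, which is what lets an expressive flow reshape a factorized base density into a complex joint one.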

We can therefore assume that, with a sufficiently expressive model, the dual encoder/decoder model can fit the true data distribution reasonably well. In the ideal case, where the encoder and decoder are consistent and equal to the underlying data distribution, i.e. , we obtain the following result, which relates the entropy of the data distribution to the mutual information between the data distribution and the latent space representation:


Eq. (15) illustrates an interesting connection between ML and MI, assuming TzK to be the true underlying model. One can interpret ML learning of as a lower bound for the sum of the negative entropy of observations , the negative entropy of latent codes , and the MI between the observations and the latent codes , and between the latent state and the latent codes . This formulation arises naturally from the TzK representation of the data distribution, as opposed to several existing models that use MI as a regularizer (Belghazi et al., 2018; Chen et al., 2016; Dupont, 2018; Klys et al., 2018).
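The decomposition above rests on the standard identity H(X, K) = H(X) + H(K) − I(X; K), so the expected joint log-likelihood equals −H(X) − H(K) + I(X; K). The sketch checks this identity numerically for an arbitrary two-variable discrete joint distribution; it illustrates the two-variable case only and is not Eq. (15) itself, which involves more variables.

```python
import numpy as np


def entropy(p):
    """Shannon entropy in nats of a probability array (zero entries contribute 0)."""
    p = p[p > 0]
    return -np.sum(p * np.log(p))


# A small, arbitrary joint distribution p(x, k) over 3 x 2 outcomes.
p_xk = np.array([[0.30, 0.05],
                 [0.10, 0.20],
                 [0.05, 0.30]])
p_x = p_xk.sum(axis=1)
p_k = p_xk.sum(axis=0)

h_x, h_k = entropy(p_x), entropy(p_k)
mutual_info = np.sum(p_xk * np.log(p_xk / np.outer(p_x, p_k)))

# Expected joint log-likelihood E[log p(x, k)] = -H(X, K).
expected_loglik = np.sum(p_xk * np.log(p_xk))
assert np.isclose(expected_loglik, -h_x - h_k + mutual_info)
```

Maximizing the expected log-likelihood therefore rewards low marginal entropies and high mutual information at the same time, without adding MI as a separate regularizer.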

An important property of the TzK formulation is the absence of variational approximations, in which an auxiliary distribution is used to approximate . As a consequence, it is hoped that a more expressive will better approximate , leading to a tighter lower bound. In contrast, in variational inference (VI), a more expressive does not necessarily guarantee a tighter lower bound, as it is restricted to tractable families. In addition, since is unknown, VI does not offer a way to measure that gap.

Appendix B Architecture Details

The components of a TzK model, i.e., the factors in Eqs. (3) - (7), have been implemented in terms of parameterized deep networks. In somewhat more detail, the prior over was implemented with:

  • is our Glow-based implementation (Kingma & Dhariwal, 2018). Flow details are included with each experiment.

  • is parameterized in terms of a probability normalizing flow from a multivariate standard normal distribution


All probability density flows had 3 layers (multi-scale), with the number of steps defined in each experiment, and 512 channels for the regressors in the affine coupling transforms (Kingma & Dhariwal, 2018).
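A minimal sketch of an affine coupling transform on vectors is shown below. Assumptions: the 512-channel convolutional regressor is replaced by a toy one-hidden-layer tanh network, and the scale is passed through a sigmoid (with a +2 offset) to keep it positive, which is one common choice rather than necessarily the one used here. The class name is hypothetical. Half of the input conditions a shift and scale applied to the other half, so the Jacobian is triangular and its log-determinant is the sum of log-scales.

```python
import numpy as np


class AffineCoupling:
    """Toy affine coupling layer on vectors (stand-in for Glow's convolutional version)."""

    def __init__(self, dim, hidden=16, seed=0):
        assert dim % 2 == 0, "dim must be even"
        rng = np.random.default_rng(seed)
        self.half = dim // 2
        # Toy regressor: one tanh hidden layer instead of a 512-channel CNN.
        self.w1 = rng.normal(scale=0.1, size=(self.half, hidden))
        self.w2 = rng.normal(scale=0.1, size=(hidden, 2 * self.half))

    def _shift_scale(self, x_cond):
        h = np.tanh(x_cond @ self.w1) @ self.w2
        shift = h[..., : self.half]
        scale = 1.0 / (1.0 + np.exp(-(h[..., self.half:] + 2.0)))  # sigmoid, always > 0
        return shift, scale

    def forward(self, x):
        """Return y and log|det dy/dx| per example."""
        x_cond, x_trans = x[..., : self.half], x[..., self.half:]
        shift, scale = self._shift_scale(x_cond)
        y_trans = (x_trans + shift) * scale
        log_det = np.sum(np.log(scale), axis=-1)
        return np.concatenate([x_cond, y_trans], axis=-1), log_det

    def inverse(self, y):
        """Exact inverse: the conditioning half is unchanged, so shift/scale are recomputable."""
        y_cond, y_trans = y[..., : self.half], y[..., self.half:]
        shift, scale = self._shift_scale(y_cond)
        x_trans = y_trans / scale - shift
        return np.concatenate([y_cond, x_trans], axis=-1)
```

The regressor never needs to be inverted, which is why it can be an arbitrary network; in a full flow, coupling layers alternate with permutations (invertible 1x1 convolutions in Glow) so every dimension is eventually transformed.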

The priors over latent codes were implemented with:

  • , for , is parameterized in terms of a probability normalizing flow from a multivariate standard normal distribution.

For , we use a Glow architecture with 1 layer, 4 steps, and channels, where , unless specified otherwise in an experiment. was shaped to have all dimensions in a single channel .

The discriminators associated with different knowledge types , conditioned on observation or a latent code , were implemented with:

All discriminators from and had 3 convolutional layers with channels (unless specified otherwise in the experiment details), followed by a linear mapping to the target dimensionality of 1 and a sigmoid that maps the output to [0, 1].

The conditional priors over and are modeled with regressors from the corresponding and to the distribution parameters (e.g., mean and variance for a Gaussian), similar to a VAE. More explicitly, we implemented the priors with regressors to the mean and diagonal covariance matrix of a Gaussian, as described below:

  • is composed of a regressor to the mean and diagonal covariance of a Gaussian base distribution, and , an invertible function that serves as a probability flow. The two components comprise a single parametric representation of a generic probability distribution. We condition the density on by learning two separate sets of weights, i.e., for . The flow uses the same Glow architecture as , for which the details are given below. All had 4 flow steps, with dimensionality of , unless specified otherwise.

  • is composed of a regressor to the mean and diagonal covariance of a Gaussian base distribution, and . We condition the probability density on by learning two separate sets of weights for .
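Conditioning on a binary knowledge variable by learning two separate sets of weights can be sketched as below. Assumptions: the regressor is a single linear map to a mean and a log-variance, the whole batch shares one value of t, and the flow that follows the Gaussian base distribution in the text is omitted. The class name is hypothetical.

```python
import numpy as np


class BinaryConditionedGaussian:
    """Diagonal Gaussian whose mean and log-variance come from a linear regressor,
    with a separate weight set for each value of a binary knowledge variable t."""

    def __init__(self, in_dim, out_dim, seed=0):
        rng = np.random.default_rng(seed)
        # weights[t] maps the conditioning input to [mean, log_var], for t in {0, 1}.
        self.weights = rng.normal(scale=0.1, size=(2, in_dim, 2 * out_dim))
        self.biases = np.zeros((2, 2 * out_dim))
        self.out_dim = out_dim

    def params(self, cond, t):
        """Mean and log-variance for conditioning input cond and binary t."""
        out = cond @ self.weights[t] + self.biases[t]
        return out[..., : self.out_dim], out[..., self.out_dim:]

    def log_prob(self, x, cond, t):
        """Per-example log N(x; mean, diag(exp(log_var)))."""
        mean, log_var = self.params(cond, t)
        return -0.5 * np.sum(
            (x - mean) ** 2 / np.exp(log_var) + log_var + np.log(2.0 * np.pi), axis=-1
        )
```

Indexing the weights by t keeps the two conditionals fully independent while sharing the same architecture; predicting log-variance rather than variance keeps the covariance positive without constraints.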

All Gaussian regressors were implemented with a linear mapping , 3 convolutional layers with 512 channels, and a final layer with 192 channels, resulting in . All Gaussian regressors were implemented with a convolutional layer with 80 channels, followed by 3 layers of alternating squeeze (Kingma & Dhariwal, 2018) and convolution with 80 channels, and a final linear layer to . All regressors used ActNorm to initialize inputs to be zero-mean with unit variance, and the weights of the last layer were initialized to 0, as in (Kingma & Dhariwal, 2018).
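Two of the building blocks named above can be sketched in NumPy on NHWC tensors: the squeeze operation, which trades spatial resolution for channels, and ActNorm with data-dependent initialization, which sets a per-channel bias and scale so the first batch comes out zero-mean with unit variance. This is a simplified illustration (for example, no log-scale factor), not the reference Glow code; the names are hypothetical.

```python
import numpy as np


def squeeze2x2(x):
    """Squeeze on NHWC tensors: (N, H, W, C) -> (N, H/2, W/2, 4C)."""
    n, h, w, c = x.shape
    assert h % 2 == 0 and w % 2 == 0, "spatial dims must be even"
    x = x.reshape(n, h // 2, 2, w // 2, 2, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (N, H/2, W/2, 2, 2, C)
    return x.reshape(n, h // 2, w // 2, 4 * c)


class ActNorm:
    """Per-channel affine normalization with data-dependent initialization."""

    def __init__(self):
        self.bias = None
        self.log_scale = None

    def initialize(self, x, eps=1e-6):
        # Statistics over every axis except the channel axis (last, NHWC).
        axes = tuple(range(x.ndim - 1))
        self.bias = -x.mean(axis=axes)
        self.log_scale = -np.log(x.std(axis=axes) + eps)

    def __call__(self, x):
        if self.bias is None:  # initialize from the first batch seen
            self.initialize(x)
        return (x + self.bias) * np.exp(self.log_scale)
```

After initialization, the bias and log-scale are ordinary trainable parameters; the zero initialization of the last regressor layer (not shown) makes each coupling start close to the identity.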

Appendix C Model Evaluation

Here we explain in detail how we evaluate the model for the experiments. To that end, we consider the evaluation of the negative log likelihood of a data sample under the model, and the process for drawing a random sample from the model.

Given a set of test samples, the NLL is defined as the average negative log likelihood of the individual samples (i.e., assuming IID samples). Evaluating the NLL for is straightforward in terms of the flow and the latent Gaussian prior .
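A sketch of the unconditional evaluation: with z = f(x) from the flow, log p(x) = log N(z; 0, I) + log|det ∂f/∂x|, and the reported NLL is the mean of −log p(x) over IID test samples. Image flow models often report NLL in bits per dimension; this section does not state the units used here, so the conversion below (and the omission of any dequantization correction) is an assumption. Function names are hypothetical.

```python
import numpy as np


def flow_log_prob(z, log_det):
    """log p(x) = log N(z; 0, I) + log|det dz/dx|, with z = f(x) from the flow."""
    return -0.5 * np.sum(z ** 2 + np.log(2.0 * np.pi), axis=-1) + log_det


def mean_nll_nats(log_probs):
    """Average NLL over IID test samples, given per-sample log p(x) in nats."""
    return -float(np.mean(log_probs))


def nats_to_bits_per_dim(nll_nats, num_dims):
    """Convert a per-sample NLL in nats to bits per dimension."""
    return nll_nats / (num_dims * np.log(2.0))
```

Averaging per-sample log-likelihoods is exactly what the IID assumption licenses: the log of the product of sample likelihoods is the sum of their logs.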

To evaluate the NLL for a conditional distribution, given , we first draw a random sample . We then use that sample to build , with which we evaluate the log probability of .

We next explain the procedure for sampling from a class-conditional hierarchical model. To sample the digit "1" over a domain, we first sample from , followed by , and finally .
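The top-down order of this procedure is ancestral sampling: draw base noise for the most specific prior, then push it through each stage toward data space. The sketch below simplifies heavily, assuming every stage is an invertible map of fixed dimension; the real model also draws from conditional Gaussian regressors between stages and changes dimensionality, which is omitted. The class and function names are hypothetical.

```python
import numpy as np


class ElementwiseAffine:
    """Toy invertible map y = a * x + b (stand-in for a trained flow, data -> latent)."""

    def __init__(self, a, b):
        self.a = np.asarray(a, dtype=float)
        self.b = np.asarray(b, dtype=float)

    def forward(self, x):
        return self.a * x + self.b

    def inverse(self, y):
        return (y - self.b) / self.a


def sample_hierarchical(base_dim, flows_top_down, num_samples, seed=0):
    """Draw N(0, I) noise for the most specific prior, then apply each flow's
    inverse from the top of the hierarchy down to data space."""
    rng = np.random.default_rng(seed)
    h = rng.standard_normal((num_samples, base_dim))
    for flow in flows_top_down:
        h = flow.inverse(h)
    return h
```

Since each flow maps data toward its latent space, sampling runs the inverses in the opposite order of encoding; applying the forward maps bottom-up recovers the original noise.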