Importance Weighted Adversarial Variational Autoencoders for Spike Inference from Calcium Imaging Data

06/07/2019 · Daniel Jiwoong Im, et al. · Howard Hughes Medical Institute

The Importance Weighted Auto Encoder (IWAE) objective has been shown to improve the training of generative models over the standard Variational Auto Encoder (VAE) objective. Here, we derive importance weighted extensions to AVB and AAE. These latent variable models use implicitly defined inference networks whose approximate posterior density q_ϕ(z|x) cannot be directly evaluated, an essential ingredient for importance weighting. We show improved training and inference in latent variable models with our adversarially trained importance weighting method, and derive new theoretical connections between adversarial generative model training criteria and marginal likelihood based methods. We apply these methods to the important problem of inferring spiking neural activity from calcium imaging data, a challenging posterior inference problem in neuroscience, and show that posterior samples from the adversarial methods outperform factorized posteriors used in VAEs.


1 Introduction

The variational autoencoder (VAE) Kingma2014vae ; Rezende2014 has been used to train deep latent variable generative models which model a distribution $p_\theta(x)$ over observations $x$ with latent variables $z$ such that $p_\theta(x) = \int p_\theta(x|z)\,p(z)\,dz$, using a deep neural network which transforms samples from the prior $p(z)$ into samples from $p_\theta(x|z)$. This generative model is trained using approximate posterior samples from a simultaneously trained recognition network (or inference network) $q_\phi(z|x)$ to maximize the evidence lower bound (ELBO).

There are two ways to improve the quality of the learned deep generative model. The multi-sample objective used by the importance weighted autoencoder (IWAE) Burda2015 has been used to derive a tighter lower bound to the model evidence $\log p_\theta(x)$, leading to superior generative models. Optimizing this objective corresponds to implicitly reweighting the samples from the approximate posterior. A second way to improve the quality of the generative model is to explicitly improve the approximate posterior samples generated by the recognition network.

In the VAE framework, the recognition network is restricted to approximate posterior distributions under which the log probability of a sample and its derivatives can be evaluated in closed form. The adversarial autoencoder (AAE) Makhzani2016 and adversarial variational Bayes (AVB) Mescheder2017 show how this constraint can be relaxed, leading to more flexible posterior distributions which are implicitly represented by the recognition network. In this paper, we derive importance weighted extensions of these adversarial autoencoders, IW-AVB and IW-AAE, thus combining adversarial and importance weighting techniques for improving probabilistic modeling.

Spike inference is an important Bayesian inference problem in neuroscience berens2018community . Calcium imaging methods enable the indirect measurement of neural activity of large populations of neurons in the living brain in a minimally invasive manner. The intracellular calcium concentration, measured by fluorescence microscopy of a genetically encoded calcium sensor such as GCaMP6 chen2013ultrasensitive , is an indirect measure of the spiking activity of the neuron. VAEs have previously been used speiser2017fast ; Aitchison:2017wz to perform Bayesian inference of spiking activity by training inference networks to invert the known, biophysically described generative process which converts unobserved spikes into observed fluorescence time series.

The accuracy of a VAE-based spike inference method depends strongly on the quality of the posterior approximation used by the inference network. The posterior distribution over the binary latent spike train given the fluorescence time series has previously been approximated speiser2017fast using either a factorized Bernoulli distribution (VIMCO-FACT), where $q_\phi(z|x) = \prod_t q_\phi(z_t|x)$, or an autoregressive Bernoulli distribution (VIMCO-CORR). As we show, the correlated autoregressive posterior is more accurate, but slow to sample from. In contrast, the factorized posterior allows for fast parallel sampling, especially on a GPU, but ignores correlations in the posterior. Fast inference networks which sample from correlated posteriors over discrete binary spike trains would be a significant advance for VAE-based spike inference.
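To make this sampling trade-off concrete, the sketch below contrasts parallel sampling from a factorized Bernoulli posterior with sequential sampling from an autoregressive one. It is purely illustrative; logit_fn is a hypothetical stand-in for an autoregressive network, not the architecture used in speiser2017fast.

# Illustrative sketch only: factorized vs. autoregressive Bernoulli spike posteriors.
import torch

def sample_factorized(probs):
    # probs: (T,) per-bin spike probabilities; all bins are sampled in parallel.
    return torch.bernoulli(probs)

def sample_autoregressive(logit_fn, T):
    # logit_fn (hypothetical) maps the spike history to the next bin's logit;
    # sampling is inherently sequential, hence slow for long recordings.
    spikes = torch.zeros(T)
    for t in range(T):
        p_t = torch.sigmoid(logit_fn(spikes[:t]))
        spikes[t] = torch.bernoulli(p_t)
    return spikes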

Fast correlated distributions over time series can be constructed using normalizing flows for continuous random variables Rezende2015 , but this is considerably harder for discrete random variables Aitchison:2018vq . Thus an adversarial approach, in which an inference network that transforms noise samples into posterior samples can be trained without the need to evaluate the posterior density, is particularly appealing for modeling correlated distributions over discrete random variables. Here, we show that our adversarially trained inference networks produce correlated samples which outperform the factorized posterior trained in the conventional way as in speiser2017fast .

In addition to these practical advances, we derive theoretical results connecting the objective functions optimized by the importance weighted variants of AVB, AAE, and the VAE. The relationship between the AAE objective and the data log-likelihood is not fully understood. The AAE has been shown to be a special case of the Wasserstein autoencoder (WAE) under certain restricted conditions Bousquet2017 . However, the trade-offs between the standard log-likelihood and penalized optimal transport objectives are also not understood, and further theoretical insight is necessary to fully understand the trade-offs between the VAE and AAE.

The main contributions of the paper are the following:

  1. We propose IW-AVB and IW-AAE, which yield tighter lower bounds on the log-likelihood than AVB and AAE, and whose global solutions maximize the likelihood.

  2. We provide theoretical insights into the importance weighted adversarial objective functions. In particular, we relate AAE and IW-AAE objectives to log-likelihoods and Wasserstein autoencoder objectives.

  3. We develop standard and importance weighted adversarial neural spike inference for calcium imaging data, and show that adversarially trained inference networks outperform existing VAEs using factorized posteriors.

2 Background

The maximum likelihood estimation of the parameter $\theta$ for a model defined as $p_\theta(x) = \int p_\theta(x|z)\,p(z)\,dz$, where $z$ is a latent variable, is in general intractable. Variational methods maximize a lower bound of the log-likelihood. This lower bound is based on approximating the intractable posterior $p_\theta(z|x)$ by a tractable distribution $q_\phi(z|x)$ parameterized by variational parameters $\phi$. VAEs maximize the following lower bound of $\log p_\theta(x)$. To make the relationship with the proposed methods clear, we write this as

$\mathcal{L}_{\mathrm{VAE}}(\theta,\phi) = \mathbb{E}_{p_{\mathcal{D}}(x)}\Big[\,\mathbb{E}_{q_\phi(z|x)}\big[\log p_\theta(x|z)\big] - \mathrm{KL}\big(q_\phi(z|x)\,\|\,p(z)\big)\Big] \qquad (1)$

We do this for all criteria going forward.

To efficiently optimize this criterion with gradient descent, VAEs Kingma2014vae ; Rezende2014 define the approximate posterior such that $z = z_\phi(x, \epsilon)$ is a differentiable transformation of a noise variable $\epsilon$. It is common to assume $\epsilon \sim \mathcal{N}(0, I)$ and $q_\phi(z|x) = \mathcal{N}\big(z;\, \mu_\phi(x), \operatorname{diag}(\sigma^2_\phi(x))\big)$, and for $z_\phi$ to be a deep network with weights $\phi$.
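As a concrete reference point, a minimal sketch of this reparameterized Gaussian posterior follows; enc_mu and enc_logvar are placeholder encoder heads, not the paper's architecture.

# Minimal sketch of a reparameterized Gaussian approximate posterior (standard VAE).
import torch

def sample_posterior(x, enc_mu, enc_logvar):
    mu, logvar = enc_mu(x), enc_logvar(x)
    eps = torch.randn_like(mu)               # noise variable epsilon ~ N(0, I)
    z = mu + torch.exp(0.5 * logvar) * eps   # differentiable transform z_phi(x, eps)
    return z, mu, logvar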

Requiring that $q_\phi(z|x)$ can be analytically evaluated restricts the class of approximate posteriors and is a limitation of such approaches. Adversarial variational Bayes (AVB) Mescheder2017 maximizes the variational lower bound while implicitly approximating the KL divergence between the approximate posterior and the prior distribution by introducing a third neural network, $T_\psi(x, z)$. This neural network, known as the discriminator, implicitly estimates $\log q_\phi(z|x) - \log p(z)$:

$\mathcal{L}_{\mathrm{AVB}}(\theta,\phi) = \mathbb{E}_{p_{\mathcal{D}}(x)}\,\mathbb{E}_{q_\phi(z|x)}\big[\log p_\theta(x|z) - T_{\psi^*}(x,z)\big] \qquad (2)$
$\psi^* = \arg\max_\psi\; \mathbb{E}_{p_{\mathcal{D}}(x)}\Big[\mathbb{E}_{q_\phi(z|x)}\big[\log \sigma(T_\psi(x,z))\big] + \mathbb{E}_{p(z)}\big[\log\big(1-\sigma(T_\psi(x,z))\big)\big]\Big] \qquad (3)$

The three parametric models $p_\theta(x|z)$, $q_\phi(z|x)$, and $T_\psi(x,z)$ are jointly optimized using adversarial training. Unlike in the VAE and IWAE, in this framework we can use arbitrarily flexible approximate distributions $q_\phi(z|x)$.
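A minimal sketch of these two objectives is given below, under assumed PyTorch callables: T(x, z) returns the discriminator logit, decoder_logp(x, z) returns $\log p_\theta(x|z)$, and z samples come from a reparameterized encoder as sketched above.

# Sketch of the AVB objectives (Eqs. 2-3); names are assumptions, not the paper's code.
import torch.nn.functional as F

def avb_discriminator_loss(x, z_q, z_p, T):
    # Eq. 3: (x, z ~ q_phi(z|x)) are positives, (x, z ~ p(z)) are negatives.
    return -(F.logsigmoid(T(x, z_q)).mean() + F.logsigmoid(-T(x, z_p)).mean())

def avb_bound(x, z_q, T, decoder_logp):
    # Eq. 2: reconstruction term minus the discriminator's estimate of
    # log q_phi(z|x) - log p(z) (exact at the optimal discriminator).
    return (decoder_logp(x, z_q) - T(x, z_q)).mean()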

The adversarial autoencoder (AAE) Makhzani2016 is similar, except that the discriminative network depends only on $z$ instead of on $x$ and $z$. The AAE optimizes the following objective:

$\mathcal{L}_{\mathrm{AAE}}(\theta,\phi) = \mathbb{E}_{p_{\mathcal{D}}(x)}\,\mathbb{E}_{q_\phi(z|x)}\big[\log p_\theta(x|z) - T_{\psi^*}(z)\big] \qquad (4)$
$\psi^* = \arg\max_\psi\; \mathbb{E}_{q_\phi(z)}\big[\log \sigma(T_\psi(z))\big] + \mathbb{E}_{p(z)}\big[\log\big(1-\sigma(T_\psi(z))\big)\big] \qquad (5)$

The AAE replaces the KL divergence between the approximate posterior and the prior distribution in $\mathcal{L}_{\mathrm{VAE}}$ with an adversarial loss that tries to minimize the divergence between the aggregated posterior $q_\phi(z) = \mathbb{E}_{p_{\mathcal{D}}(x)}[q_\phi(z|x)]$ and the prior distribution $p(z)$.
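For contrast with AVB, a sketch of the AAE variant follows; the only structural change is that the (assumed) latent-only discriminator T_z sees the code alone, so it matches the aggregated posterior to the prior.

# Sketch of the AAE regularizer (Eqs. 4-5); T_z is an assumed latent-only discriminator.
import torch.nn.functional as F

def aae_discriminator_loss(z_q, z_p, T_z):
    return -(F.logsigmoid(T_z(z_q)).mean() + F.logsigmoid(-T_z(z_p)).mean())

def aae_bound(x, z_q, T_z, decoder_logp):
    # Reconstruction term with the per-sample KL replaced by the adversarial penalty on z.
    return (decoder_logp(x, z_q) - T_z(z_q)).mean()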

3 Importance Weighted Adversarial Training

The importance weighted autoencoder (IWAE) Burda2015 provides a tighter lower bound to $\log p_\theta(x)$,

$\mathcal{L}^k_{\mathrm{IWAE}}(\theta,\phi) = \mathbb{E}_{p_{\mathcal{D}}(x)}\,\mathbb{E}_{z_1,\dots,z_k\sim q_\phi(z|x)}\Big[\log\frac{1}{k}\sum_{i=1}^{k}\frac{p_\theta(x,z_i)}{q_\phi(z_i|x)}\Big] \qquad (6)$

Burda et al. Burda2015 show that $\mathcal{L}^k_{\mathrm{IWAE}} \le \mathcal{L}^{k+1}_{\mathrm{IWAE}} \le \log p_\theta(x)$, and that $\mathcal{L}^k_{\mathrm{IWAE}}$ approaches $\log p_\theta(x)$ as $k \to \infty$.
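A sketch of the k-sample bound for a single observation follows; log_joint and log_q are assumed callables for $\log p_\theta(x, z)$ and $\log q_\phi(z|x)$.

# Sketch of the k-sample IWAE bound (Eq. 6) for one observation x.
import math
import torch

def iwae_bound(x, z_samples, log_joint, log_q):
    # z_samples: list of k latent samples drawn from q_phi(z|x).
    log_w = torch.stack([log_joint(x, z) - log_q(z, x) for z in z_samples])  # shape (k,)
    return torch.logsumexp(log_w, dim=0) - math.log(len(z_samples))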

3.1 IW-AVB and IW-AAE

In AVB, generative adversarial training on the joint distributions of data and latent variables is applied to the variational lower bound $\mathcal{L}_{\mathrm{VAE}}$. In this work, we propose applying it to the importance weighted lower bound of $\log p_\theta(x)$,

$\mathcal{L}^k_{\mathrm{IW\text{-}AVB}}(\theta,\phi) = \mathbb{E}_{p_{\mathcal{D}}(x)}\,\mathbb{E}_{z_1,\dots,z_k\sim q_\phi(z|x)}\Big[\log\frac{1}{k}\sum_{i=1}^{k}\exp\big(\log p_\theta(x|z_i) - T_{\psi^*}(x,z_i)\big)\Big] \qquad (7)$

where $T_{\psi^*}$ is defined as in Equation 3. We call this the importance weighted adversarial variational Bayes (IW-AVB) bound. For the optimal discriminator it coincides with the IWAE bound, and it approaches $\log p_\theta(x)$ as $k \to \infty$.

The main advantage of IW-AVB over AVB is that, when the true posterior distribution is not in the class of approximate posterior functions (as is generally the case), IW-AVB uses a tighter lower bound than AVB Burda2015 .

Similarly, we can apply importance weighting to improve the AAE:

$\mathcal{L}^k_{\mathrm{IW\text{-}AAE}}(\theta,\phi) = \mathbb{E}_{p_{\mathcal{D}}(x)}\,\mathbb{E}_{z_1,\dots,z_k\sim q_\phi(z|x)}\Big[\log\frac{1}{k}\sum_{i=1}^{k}\exp\big(\log p_\theta(x|z_i) - T_{\psi^*}(z_i)\big)\Big] \qquad (8)$

where $T_{\psi^*}$ is defined as in Equation 5.

The IW-AVB and IW-AAE objectives can be described as a minimax adversarial game between three neural networks: the generative network $p_\theta(x|z)$, the inference network $q_\phi(z|x)$, and the discriminative network $T_\psi$. The inference network maps an input $x$ to the latent space, and the generative network maps latent samples $z$ back to the data space. The inference and generative networks are jointly trained to minimize the reconstruction error and the adversarially estimated KL divergence term in the bound. The discriminator network differentiates samples from the joint distribution of data and the approximate posterior (positive samples) from samples of the joint distribution of data and the prior over latents (negative samples).

Recent work Rainforth2018 has shown that optimizing the importance weighted bound can degrade the overall learning of the inference network, because the signal-to-noise ratio (SNR) of the gradient estimates converges at the rates $O(\sqrt{Mk})$ and $O(\sqrt{M/k})$ for the generative and inference networks, respectively, where $M$ is the number of Monte Carlo samples used in the gradient estimate and $k$ is the number of importance samples. The SNR converges to 0 for the inference network as $k \to \infty$, and the gradient estimates of $\phi$ become completely random. To mitigate this, we apply the importance weighted bound when updating the parameters $\theta$ of the generative network and the variational lower bound when updating the parameters $\phi$ of the inference network. Hence, we maximize the following:

$\max_\theta\; \mathcal{L}^k_{\mathrm{IW\text{-}AVB}}(\theta, \phi) \quad\text{and}\quad \max_\phi\; \mathcal{L}_{\mathrm{AVB}}(\theta, \phi) \qquad (9)$

We do this for IW-AAE as well.

We alternate between updating the inference-generative pair $(\phi, \theta)$ and the adversarial discriminator $\psi$. The training procedures for IW-AVB and IW-AAE are shown in Algorithms 1 and 2.

1: Initialize $\theta$, $\phi$, and $\psi$.
2: while $\theta$ has not converged do
3:     Sample $x^{(1)}, \dots, x^{(m)}$ from the data distribution $p_{\mathcal{D}}(x)$.
4:     Sample noise $\epsilon^{(i,1)}, \dots, \epsilon^{(i,k)} \sim \mathcal{N}(0, I)$ for each $i$.
5:     Sample $z^{(1)}, \dots, z^{(m)}$ from the prior $p(z)$.
6:     Compute posterior samples $z^{(i,j)} = z_\phi(x^{(i)}, \epsilon^{(i,j)})$.
7:     Compute the gradient w.r.t. $\theta$ in Eq. 9:
8:         $g_\theta \leftarrow \nabla_\theta \frac{1}{m}\sum_i \log \frac{1}{k}\sum_j \exp\big(\log p_\theta(x^{(i)}|z^{(i,j)}) - T_\psi(x^{(i)}, z^{(i,j)})\big)$
9:     Update $\theta$ using $g_\theta$.
10:    Compute the gradient w.r.t. $\phi$ of the variational lower bound in Eq. 9 and update $\phi$.
11:    Compute the gradient w.r.t. $\psi$ in Eq. 3:
12:        $g_\psi \leftarrow \nabla_\psi \frac{1}{m}\sum_i \big[\log \sigma(T_\psi(x^{(i)}, z^{(i,1)})) + \log\big(1 - \sigma(T_\psi(x^{(i)}, z^{(i)}))\big)\big]$
13:    Update $\psi$ using $g_\psi$.
Algorithm 1: IW-AVB
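For readers who prefer code to pseudocode, a compact, hypothetical PyTorch rendering of one Algorithm 1 step is sketched below; encoder.sample, decoder.log_prob, and the optimizer/parameter split are assumptions rather than the authors' implementation.

# Hypothetical sketch of one IW-AVB training step (Algorithm 1), not the authors' code.
import math
import torch
import torch.nn.functional as F

def iw_avb_step(x, encoder, decoder, T, opt_theta, opt_phi, opt_psi, k=5):
    # Discriminator update (Eq. 3): (x, z ~ q_phi) positive vs. (x, z ~ p(z)) negative.
    z_q = encoder.sample(x)                       # reparameterized posterior sample
    z_p = torch.randn_like(z_q)                   # prior sample, p(z) = N(0, I)
    d_loss = -(F.logsigmoid(T(x, z_q.detach())).mean()
               + F.logsigmoid(-T(x, z_p)).mean())
    opt_psi.zero_grad(); d_loss.backward(); opt_psi.step()

    # Importance weighted bound (Eq. 7) from k posterior samples.
    zs = [encoder.sample(x) for _ in range(k)]
    log_w = torch.stack([decoder.log_prob(x, z) - T(x, z) for z in zs])   # (k, batch)
    iw_bound = (torch.logsumexp(log_w, dim=0) - math.log(k)).mean()

    # Eq. 9: theta follows the IW bound; phi follows the single-sample bound (Eq. 2),
    # assuming opt_theta/opt_phi hold only the decoder/encoder parameters respectively.
    opt_theta.zero_grad(); (-iw_bound).backward(retain_graph=True); opt_theta.step()
    elbo = (decoder.log_prob(x, zs[0]) - T(x, zs[0])).mean()
    opt_phi.zero_grad(); (-elbo).backward(); opt_phi.step()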

3.2 Analysis

An important reason to update $\phi$ using the variational lower bound in Equation 9 is that it guarantees $T_{\psi^*}(x,z) = \log q_\phi(z|x) - \log p(z)$ for the optimal discriminator network Mescheder2017 . Since $T_{\psi^*}$ indirectly depends on $\phi$, we want the gradient w.r.t. $\phi$ of the discriminator to be disentangled from the gradients of Equation 9. Thus, we only use the importance weighted bound for the generative model. Empirically, we find that this still improves performance (Section 5).

The following proposition shows that the global Nash equilibria of IW-AVB's adversarial game yield global optima of the objective function in Equation 9.

Proposition 1.

Assume that $T_\psi$ can represent any function of two variables. If $(\theta^*, \phi^*, \psi^*)$ is a Nash equilibrium of the two-player game for IW-AVB, then $T_{\psi^*}(x,z) = \log q_{\phi^*}(z|x) - \log p(z)$ and $(\theta^*, \phi^*)$ is a global optimum of the importance weighted lower bound in Equation 9.

See the Appendix for the proof. This proposition tells us that the solution to Equation 9 gives the solution to the importance weighted bound, in which $\theta^*$ becomes the maximum likelihood assignment.

A similar property holds for AAE and IW-AAE with the discriminator $T_\psi(z)$.

Proposition 2.

Assume that $T_\psi$ can represent any function of its input. If $(\theta^*, \phi^*, \psi^*)$ is a Nash equilibrium of the two-player game for IW-AAE, then $T_{\psi^*}(z) = \log q_{\phi^*}(z) - \log p(z)$ and $(\theta^*, \phi^*)$ is the global optimum of the following objective,

$\mathbb{E}_{p_{\mathcal{D}}(x)}\,\mathbb{E}_{z_1,\dots,z_k\sim q_\phi(z|x)}\Big[\log\frac{1}{k}\sum_{i=1}^{k}\frac{p_\theta(x|z_i)\,p(z_i)}{q_\phi(z_i)}\Big] \qquad (10)$

where $q_\phi(z) = \mathbb{E}_{p_{\mathcal{D}}(x)}[q_\phi(z|x)]$ is the aggregated posterior.

The steps of the proof are the same as for Proposition 1.

In the next section, we provide theoretical insights into the relationship between the optima of Equations 9 and 10 and the log-likelihood.

4 Relationship of IW-AVB and IW-AAE to other objectives

Bousquet et al. Bousquet2017 derived adversarial objectives whose solutions are equivalent to those of the VAE and AVB objectives. In a similar manner, we show that an adversarial objective with solutions equivalent to the IW-AVB objective is

(11)

where $\mathrm{GAN}(p_{\mathcal{D}}, p_\theta)$ is the generative adversarial network objective Goodfellow2014 with discriminative network $T_\psi$, and $p_{\mathcal{D}}$ and $p_\theta$ are the data and model distributions. $\mathrm{GAN}(p_{\mathcal{D}}, p_\theta)$ can be viewed as a (pseudo-)divergence between the data and model distributions, with $\mathrm{GAN}(p_{\mathcal{D}}, p_\theta) \ge 0$ for all $p_{\mathcal{D}}, p_\theta$.

Similarly, the adversarial objective for IW-AAE becomes

(12)

Bousquet et al. also show that minimizing the AAE objective is a special case of minimizing a penalized optimal transport (POT) objective with the 2-Wasserstein distance.

These adversarial objectives become tighter bounds as the number of samples $k$ increases:

Proposition 3.

For any distributions $p_\theta(x|z)$ and $q_\phi(z|x)$, and for $k$ samples, $\mathcal{L}^{k}_{\mathrm{IW\text{-}AVB}} \le \mathcal{L}^{k+1}_{\mathrm{IW\text{-}AVB}}$ and $\mathcal{L}^{k}_{\mathrm{IW\text{-}AAE}} \le \mathcal{L}^{k+1}_{\mathrm{IW\text{-}AAE}}$.

The proof follows the steps from Theorem 1 in Burda2015 .

The relationships between these bounds and $\log p_\theta(x)$ are given by the following proposition.

Proposition 4.

For any distributions $p_\theta$ and $q_\phi(z|x)$:

The proof is shown in the Appendix. The importance weighted bound is tighter than the single-sample bound (Proposition 3), and the adversarial bound is tighter than the standard bound due to the tighter adversarial approximation (i.e., since the underlying divergence is convex). However, the relationship between the importance weighted bound and the adversarial bound is unknown, because the trade-off between importance weighting and the more flexible adversarial objective is unclear.

4.1 Relationship between Wasserstein Autoencoders and log-likelihood

We would like to understand the relationship between the AAE (and IW-AAE) and the log-likelihood. Previously, it was shown that the AAE objective converges to the Wasserstein autoencoder objective under certain circumstances Bousquet2017 . We observe that the IW-AAE objective converges to a new Wasserstein autoencoder objective which gives a tighter bound on the autoencoder log-likelihood. The autoencoder log-likelihood can be understood as the likelihood of data reconstructed through the probabilistic encoding model. Furthermore, in a Corollary in the Appendix, we relate the autoencoder log-likelihood to $\log p_\theta(x)$ in a special case.

The Wasserstein distance is a distance function defined between probability distributions on a metric space. Bousquet et al. Bousquet2017 showed that the penalized optimal transport objective is a relaxed version of the Wasserstein autoencoder objective (given that the generative network $p_\theta(x|z)$ is a probabilistic function), where

$D_{\mathrm{POT}}(p_{\mathcal{D}}, p_\theta) = \inf_{q_\phi(z|x)}\; \mathbb{E}_{p_{\mathcal{D}}(x)}\,\mathbb{E}_{q_\phi(z|x)}\big[c\big(x, g_\theta(z)\big)\big] + \lambda\, D_Z\big(q_\phi(z), p(z)\big) \qquad (13)$

$c$ is a distance function, and $D_Z$ is a choice of convex divergence between the prior $p(z)$ and the aggregated posterior $q_\phi(z)$ Bousquet2017 . As $\lambda \to \infty$, the penalized objective converges to the Wasserstein autoencoder objective. It turns out that the AAE objective is a special case of this penalized optimal transport objective. This happens when the cost function $c$ is the squared Euclidean distance and $p_\theta(x|z)$ is Gaussian.
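To see why this special case arises, recall the standard Gaussian log-density identity (stated here for completeness; the isotropic variance $\sigma^2$ is an assumption about the decoder parameterization):

$\log p_\theta(x \mid z) = -\dfrac{1}{2\sigma^2}\,\lVert x - g_\theta(z)\rVert_2^2 + \mathrm{const}, \qquad \text{for } p_\theta(x\mid z) = \mathcal{N}\big(x;\, g_\theta(z),\, \sigma^2 I\big),$

so maximizing the Gaussian reconstruction term is, up to scaling and an additive constant, minimizing the squared Euclidean cost $c(x, g_\theta(z))$ appearing in the penalized optimal transport objective.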

We can also observe that the IW-AAE objective converges to a corresponding Wasserstein autoencoder objective:

Proposition 5.

Assume that the cost function is the squared Euclidean distance and that $p_\theta(x|z)$ is Gaussian. Then the AAE and IW-AAE objectives each lower bound the log-likelihood of an autoencoder, where the corresponding standard and importance weighted Wasserstein autoencoder objectives are stated in the Appendix. Moreover, the AAE objective converges to the standard Wasserstein autoencoder objective and the IW-AAE objective converges to the importance weighted one as $\lambda \to \infty$.

Here, the autoencoder log-likelihood is the likelihood of data reconstructed through the probabilistic encoder (we abuse notation slightly in calling this a log-likelihood). The bound is derived by applying Jensen's inequality (see the proof in the Appendix). We observe that the Wasserstein autoencoder objective is a lower bound of the autoencoder log-likelihood under this condition, and that the importance weighted objective achieves a tighter bound than the standard one. Lastly, we observe that the AAE objective approximates the standard Wasserstein autoencoder objective, and the IW-AAE objective approximates the importance weighted one.

The following theorem shows the relationship between the AAE objective and the log-likelihood.

Theorem 1.

Maximizing the AAE objective is equivalent to jointly maximizing the autoencoder log-likelihood, the mutual information between $x$ and $z$ with respect to $q_\phi$, and the negative of the KL divergence between two joint distributions over $(x, z)$,

(14)

The proof is in the Appendix. This illustrates the trade-off between the mutual information and the relative information between the two joint distributions. In order for the gap between the AAE objective and the log-likelihood to be small, the two joint distributions need to become close.

5 Experiments

We conducted our experiments with two main objectives: (i) to compare the performance of AVB, IW-AVB, AAE, and IW-AAE; and (ii) to check whether adversarial training objectives can benefit neural spike inference in general. We therefore measure performance in two experimental setups. First, we experiment on a generative modeling task on the MNIST dataset. Second, we apply adversarial training to a neural spike inference dataset in both amortized and non-amortized inference settings.

Figure 1: Samples from the generative models trained on the MNIST dataset. (a) AVB, (b) IW-AVB, (c) AAE, (d) IW-AAE.

5.1 Generative modeling

We follow the same experimental procedure as Mescheder2017 for learning generative models on the binarized MNIST dataset. We trained AVB, IW-AVB, AAE, and IW-AAE on 50,000 training examples with 10,000 validation examples, and measured the log-likelihood on 10,000 test examples. We applied the same architecture as Mescheder2017 , following the experiments and code from https://github.com/LMescheder/AdversarialVariationalBayes; see the details in the Supplementary Materials.

We considered the following three metrics. The log-likelihood was computed using Annealed Importance Sampling (AIS) Neal2001 ; Wu2016 with 1000 intermediate distributions and 5 parallel chains. We also applied the Fréchet Inception Distance (FID) Heusel2017 , which compares the mean and covariance of the Inception-based representation of samples generated by the model to the mean and covariance of the same representation for training samples:

$\mathrm{FID} = \big\lVert \mu_{\mathrm{data}} - \mu_{\mathrm{model}} \big\rVert_2^2 + \mathrm{Tr}\Big(\Sigma_{\mathrm{data}} + \Sigma_{\mathrm{model}} - 2\big(\Sigma_{\mathrm{data}}\Sigma_{\mathrm{model}}\big)^{1/2}\Big) \qquad (15)$

Lastly, we also considered the GAN metric proposed by Im2018 , which measures the quality of the generator by estimating a divergence between the true data distribution and the model distribution for different choices of divergence measure. In our setting, we considered the least-squares measure (LS).
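As a reference, a sketch of the FID computation from two sets of Inception features is given below; it follows the standard definition in Eq. 15 rather than the exact implementation used for Table 1.

# Sketch of FID (Eq. 15) from Inception features of shape (n_samples, feature_dim).
import numpy as np
from scipy.linalg import sqrtm

def fid(feats_real, feats_fake):
    mu_r, mu_f = feats_real.mean(0), feats_fake.mean(0)
    cov_r = np.cov(feats_real, rowvar=False)
    cov_f = np.cov(feats_fake, rowvar=False)
    covmean = sqrtm(cov_r @ cov_f)
    if np.iscomplexobj(covmean):          # discard small imaginary numerical noise
        covmean = covmean.real
    return float(np.sum((mu_r - mu_f) ** 2) + np.trace(cov_r + cov_f - 2 * covmean))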

Model     Log-Likelihood     FID       LS
VAE       -90.69 ± 0.88      259.87    3.8e-5
IWAE      -91.64 ± 0.71      255.513   3.6e-5
AVB       -90.42 ± 0.78      256.13    4.1e-5
IW-AVB    -85.12 ± 0.20      251.20    3.3e-5
AAE       -101.78 ± 0.62     266.76    3.8e-5
IW-AAE    -101.38 ± 0.19     249.12    3.2e-5

Table 1: Log-likelihood, FID, and LS metrics on binarized MNIST. IW-AVB performs best on log-likelihood and IW-AAE performs best under the FID and LS metrics.

Table 1 presents the results. We observe that IW-AVB achieves the best test log-likelihood for both the MNIST and FashionMNIST datasets (our results are slightly lower than those reported in Mescheder2017 , but we used the same codebase for all models; the results for FashionMNIST are shown in the Appendix). On the other hand, IW-AAE achieves the best FID and LS metrics. We speculate that this is because AVB and IW-AVB directly maximize a lower bound of the log-likelihood, whereas AAE and IW-AAE do not; AAE and IW-AAE instead directly minimize a divergence between the data and model distributions. The MNIST and FashionMNIST samples are shown in Figure 1 and Figure 7 in the Appendix.

Figure 2: Pairwise performance comparison using a paired t-test. (a) Non-amortized spike inference, (b) amortized spike inference, (c) inference time.

5.2 Neural activity inference from calcium imaging data

We consider a challenging and important problem in neuroscience: spike inference from calcium imaging. Here, the unobserved binary spike train is the latent variable, which is transformed into fluorescence measurements of the intracellular calcium concentration by a generative model whose functional form is derived from biophysics.

We use a publicly available spike inference dataset, cai-1 (available at https://crcns.org/data-sets/methods/cai-1/). We use data from five layer 2/3 pyramidal neurons in mouse visual cortex, excluding neurons with clear artifacts or mislabeled data. The neurons were imaged at 60 Hz using GCaMP6f, a genetically encoded calcium indicator chen2013ultrasensitive . The ground truth spikes were measured electrophysiologically using cell-attached recordings.

When training AVB, AAE, IW-AVB, and IW-AAE to model fluorescence data, we use a biophysical generative model and a convolutional neural network as the inference network. Thus, the process is to reconstruct the fluorescence traces from spikes inferred by the encoder. We ran five folds for every experiment on the neural spike inference dataset. The details of the architectures, biophysical model, and datasets can be found in the Appendix.
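For intuition, a minimal autoregressive sketch of a spikes-to-fluorescence generative process is shown below; the actual biophysical model used here is described in the Appendix, and the decay, gain, offset, and noise values below are placeholders.

# Minimal sketch of a spikes-to-fluorescence generative process (placeholder parameters;
# the biophysical model used in the paper, described in its Appendix, is richer than this).
import torch

def spikes_to_fluorescence(spikes, gamma=0.95, alpha=1.0, beta=0.0, sigma=0.1):
    T = spikes.shape[0]
    calcium = torch.zeros(T)
    c = torch.tensor(0.0)
    for t in range(T):
        c = gamma * c + spikes[t]        # calcium decays and jumps with each spike
        calcium[t] = c
    return alpha * calcium + beta + sigma * torch.randn(T)   # noisy fluorescence trace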

Neural Spike Modeling

We experimented under two settings: non-amortized spike inference and amortized spike inference. Non-amortized spike inference corresponds to training a new inference network for each neuron. This is expensive, but it provides an estimate of the best possible performance achievable. Amortized spike inference corresponds to the more practically useful setting in which a "training" dataset of neurons is used to train an inference network (without ground truth), and the trained inference network is then tested on a new "test" neuron. Once the inference network is trained, spike inference is extremely fast and only requires a forward pass through the inference network.

We use two variants of VIMCO Mnih2016 as baselines, VIMCO-FACT and VIMCO-CORR speiser2017fast . VIMCO-FACT uses a fast factorized posterior distribution which can be sampled in parallel over time, the same as the adversarially trained networks. VIMCO-CORR uses an autoregressive posterior that produces correlated samples which must be sampled sequentially in time (see the details in the Appendix, Neural Spike Modeling).

Following the neuroscience community, we evaluated the quality of our posterior inference networks by computing the correlation between predicted spikes and labels as the performance metric. We used a paired t-test goulden1956 to compare the improvement of all pairs of inference networks across five neurons (see Figure 2). The full table of correlation scores for all neurons and methods in both amortized and non-amortized settings is shown in Appendix Table 3. We observe that the AVB, AAE, IW-AAE, and IW-AVB performances lie between VIMCO-FACT and VIMCO-CORR. Overall, we observe that IW-AVB, AAE, and IW-AAE perform similarly across the given neuron datasets. Figure 3 illustrates the VIMCO-FACT and IW-AVB posterior approximations on the neuron 1 dataset. From the figure, we observe that VIMCO-FACT tends to have high false negatives while IW-AVB tends to have high false positives. The results are similar for the amortized experiments, as shown in Table 4. Interestingly, the performance of IW-AVB, AAE, and IW-AAE was better than in the non-amortized experiments. This suggests that neural spike inference can generalize over multiple neurons. Note that this is the first time that adversarial training has been applied to neural spike inference.

Figure 3: Trace reconstructions along with spike samples
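A sketch of the correlation metric used above is given below; any binning or smoothing of the spike trains before computing the correlation is an evaluation choice not specified here.

# Sketch of the evaluation metric: Pearson correlation between inferred spike
# probabilities (or rates) and the ground-truth spike train.
import numpy as np

def spike_correlation(pred_rate, true_spikes):
    pred = np.asarray(pred_rate, dtype=float)
    true = np.asarray(true_spikes, dtype=float)
    return float(np.corrcoef(pred, true)[0, 1])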

Moreover, VIMCO-CORR generates correlated posterior samples, whereas the samples from VIMCO-FACT are independent. Nevertheless, VIMCO-CORR's inference is slower at test time than VIMCO-FACT's, since spikes are inferred sequentially rather than in parallel. This is a major disadvantage for VIMCO-CORR, because spike recordings can be hours long. We emphasize adversarial training, such as IW-AVB and IW-AAE, because it generates correlated posterior samples in parallel. Figure 2 demonstrates the time advantage of adversarial training over VIMCO-CORR. The total fluorescence data duration was 1 hour at a 60 Hz sampling rate, and inference was run on an NVIDIA GeForce RTX 2080 Ti.

6 Conclusions

Motivated by two ways of improving the variational bound, importance weighting Burda2015 and better posterior approximation Rezende:2015vu ; Mescheder2017 ; Makhzani2016 , we propose importance weighted adversarial variational Bayes (IW-AVB) and the importance weighted adversarial autoencoder (IW-AAE). Our theoretical analysis provides a better understanding of adversarial autoencoder objectives and bridges the gap between the log-likelihood of an autoencoder and that of the generator.

Adversarially trained inference networks are particularly effective at learning correlated posterior distributions over discrete latent variables which can be sampled efficiently in parallel. We exploit this finding to apply both standard and importance weighted variants of AVB and AAE to the important yet challenging problem of inferring spiking neural activity from calcium imaging data. We have empirically shown that the adversarially trained correlated posteriors in general outperform existing VAEs with factorized posteriors. Moreover, we obtain a substantial speed gain during spike inference compared to existing VAE approaches with autoregressive correlated posteriors speiser2017fast .

References

Appendix

Proof of Things

Proposition 1.

Assume that $T_\psi$ can represent any function of two variables. If $(\theta^*, \phi^*, \psi^*)$ is a Nash equilibrium of the two-player game, then $T_{\psi^*}(x,z) = \log q_{\phi^*}(z|x) - \log p(z)$ and $(\theta^*, \phi^*)$ is a global optimum of the importance weighted lower bound.

Proof.

Suppose that $(\theta^*, \phi^*, \psi^*)$ is a Nash equilibrium. It was previously shown by [8] that the optimal discriminator satisfies $T_{\psi^*}(x, z) = \log q_{\phi^*}(z|x) - \log p(z)$.

Now, we substitute $T_{\psi^*}$ into Equation 7 and show that $(\theta^*, \phi^*)$ maximizes the resulting objective as a function of $\theta$ and $\phi$.

Define the implicit distribution $\tilde{q}_\phi(z|x)$ induced by reweighting $q_\phi(z|x)$ with

$w_i \propto \dfrac{p_\theta(x, z_i)}{q_\phi(z_i|x)},$

where $w_i$ is an importance weight.

Now, following the steps of rewriting the importance weighted bound as a variational lower bound with the implicit distribution $\tilde{q}_\phi$ [7], we have that the importance weighted bound equals the variational lower bound evaluated at $\tilde{q}_\phi$. Thus, $(\theta^*, \phi^*)$ maximizes this variational lower bound.

For a proof by contradiction, suppose that $(\theta^*, \phi^*)$ does not maximize the variational lower bound. Then there exists $(\theta', \phi')$ at which the bound is strictly larger. However, substituting $(\theta', \phi')$ into the objective yields a value greater than that at $(\theta^*, \phi^*)$, which contradicts the assumption that $(\theta^*, \phi^*, \psi^*)$ is a Nash equilibrium. Hence, $(\theta^*, \phi^*)$ is a global optimum of the variational lower bound.

Since we can express the importance weighted bound in terms of the variational lower bound with the implicit distribution $\tilde{q}_\phi$ [7], $(\theta^*, \phi^*)$ is also a global optimum of the importance weighted lower bound. ∎

Proposition 4.

For any distributions $p_\theta$ and $q_\phi(z|x)$:

Proof.

First, we show the first inequality.