A Tutorial on the Mathematical Model of Single Cell Variational Inference

01/03/2021 ∙ by Songting Shi, et al.

A large amount of sequencing data has accumulated over the past decades, and it is still accumulating, so we need to handle more and more sequencing data. With the fast development of computing technologies, we can now handle large amounts of data in a reasonable amount of time using neural-network-based models. This tutorial introduces the mathematical model of single cell variational inference (scVI), which uses a variational auto-encoder (built on neural networks) to learn the distribution of the data and gain insights. It is written for beginners in a simple and intuitive way, with many derivation details, to encourage more researchers to enter this field.



1 Auto-Encoder

To understand the scVI model, we should first understand how the variational auto-encoder works. To make the variational auto-encoder easier to understand, we first introduce the auto-encoder (Bengio (2009)), which is a similar but simpler model. Let us work through a simple example to get the idea of the auto-encoder. Suppose that the hidden code $z$ and the data $x$ are generated by

$$ z \sim N(0, 1), \qquad x = \begin{pmatrix} x_1 \\ x_2 \end{pmatrix} = \begin{pmatrix} z \\ 2z \end{pmatrix} = w z, \qquad w = \begin{pmatrix} 1 \\ 2 \end{pmatrix}. $$

Formally, we have

$$ x \sim N\left(0, \; w w^\top\right), \qquad w w^\top = \begin{pmatrix} 1 & 2 \\ 2 & 4 \end{pmatrix}, $$

where the covariance matrix $w w^\top$ of $x$ is not invertible, so $x$ follows a degenerate normal distribution. Suppose that we see a set of samples

$$ x_i = (x_{i1}, x_{i2})^\top, \qquad i = 1, \ldots, n, $$

from the above generation process. However, we only see these examples and do not know the underlying generation mechanism. We want to learn an encoding (contraction) function $f$ that represents $x$ compactly by its code $f(x)$. We also want a decoding function $g$ that recovers $x$ from its code, i.e., $g(f(x)) \approx x$. How can we do this?

By a simple linear regression, we easily get the relation $x_2 = 2 x_1$. This means that $x$ lies on a one-dimensional manifold (a line). We can therefore take the contraction function $f(x) = x_1$ and recover $x$ from the code $z = f(x)$ with the generation function $g(z) = (z, 2z)^\top$. For this simple example, a simple guess solves the problem. Can we turn this process into a general method that solves the same kind of problem for more complicated data? Yes! The auto-encoder is one such method: a framework to learn the generation function $g$ and the encoding (contraction) function $f$. We now apply the auto-encoder to this simple example to gain the basic ideas. Suppose that $f$ and $g$ are linear functions. Since $x$ lies on a one-dimensional manifold, we can parametrize them as

$$ f(x) = a^\top x = a_1 x_1 + a_2 x_2, \qquad g(z) = b z = \begin{pmatrix} b_1 z \\ b_2 z \end{pmatrix}. $$

The auto-encoder then outputs $\hat{x} = g(f(x)) = b a^\top x$. The objective function of the auto-encoder is given by

$$ L(a, b) = \sum_{i=1}^{n} \lVert x_i - b a^\top x_i \rVert^2. \qquad (5) $$

We then train the model on the training data with SGD or one of its variants, minimizing the objective function (5).
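To see this training procedure in action, here is a minimal sketch in Python (standard library only). It assumes the running example is $z \sim N(0,1)$, $x = (z, 2z)^\top$, with the linear parametrization $f(x) = a^\top x$, $g(z) = b z$. For determinism it uses full-batch gradient descent in place of SGD. The function names, learning rate, step count, and initialization are illustrative choices, not taken from the text.

```python
import random


def make_data(n, seed=0):
    """Samples of the assumed running example: z ~ N(0, 1), x = (z, 2z)."""
    rng = random.Random(seed)
    return [(z, 2.0 * z) for z in (rng.gauss(0.0, 1.0) for _ in range(n))]


def loss_and_grads(a, b, data):
    """Mean of ||x - b (a^T x)||^2 over the data, and its gradients w.r.t. a and b."""
    n = len(data)
    loss, ga, gb = 0.0, [0.0, 0.0], [0.0, 0.0]
    for x in data:
        code = a[0] * x[0] + a[1] * x[1]                 # encoder f(x) = a^T x
        r = [x[0] - b[0] * code, x[1] - b[1] * code]     # residual x - g(f(x))
        loss += r[0] ** 2 + r[1] ** 2
        b_dot_r = b[0] * r[0] + b[1] * r[1]
        for j in range(2):
            ga[j] -= 2.0 * b_dot_r * x[j]                # d||r||^2 / da_j
            gb[j] -= 2.0 * r[j] * code                   # d||r||^2 / db_j
    return loss / n, [g / n for g in ga], [g / n for g in gb]


def train(data, lr=0.02, steps=1000, seed=1):
    """Full-batch gradient descent from a small random initialization."""
    rng = random.Random(seed)
    a = [rng.uniform(0.1, 0.5) for _ in range(2)]
    b = [rng.uniform(0.1, 0.5) for _ in range(2)]
    for _ in range(steps):
        _, ga, gb = loss_and_grads(a, b, data)
        a = [a[j] - lr * ga[j] for j in range(2)]
        b = [b[j] - lr * gb[j] for j in range(2)]
    return a, b


data = make_data(64)
a, b = train(data)
c = a[0] + 2.0 * a[1]
print("a =", a, "b =", b, "(a1 + 2 a2) * b =", [c * b[0], c * b[1]])
```

The learned pair reproduces the data, so the product $(a_1 + 2a_2)\, b$ ends up close to $(1, 2)^\top$. The individual values of $a$ and $b$ depend on the initialization, which previews the non-uniqueness discussed next.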

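Question: can we learn the optimal solution $a = (1, 0)^\top$, $b = (1, 2)^\top$ (i.e., $f(x) = x_1$, $g(z) = (z, 2z)^\top$), or some other reasonable solution? Yes!

Since this objective is differentiable, we can set its first-order derivatives to zero to get the stationary conditions:

$$ \frac{\partial L}{\partial a} = 2 \sum_{i=1}^{n} x_i x_i^\top \left( \lVert b \rVert^2 a - b \right) = 0, \qquad \frac{\partial L}{\partial b} = 2 \sum_{i=1}^{n} \left[ (a^\top x_i)^2 \, b - (a^\top x_i) \, x_i \right] = 0. $$

When we fix $b$, we get the following linear system in $a$:

$$ \left( \sum_{i=1}^{n} x_i x_i^\top \right) \lVert b \rVert^2 a = \left( \sum_{i=1}^{n} x_i x_i^\top \right) b. $$

Note that $x_{i2} = 2 x_{i1}$, so $\sum_i x_i x_i^\top = s\, w w^\top$ with $s = \sum_{i=1}^{n} x_{i1}^2$ and $w = (1, 2)^\top$. We can therefore simplify the system to

$$ s \begin{pmatrix} 1 & 2 \\ 2 & 4 \end{pmatrix} \left( \lVert b \rVert^2 a - b \right) = 0. $$

In this system the second equation is exactly twice the first, so we simply get

$$ \lVert b \rVert^2 (a_1 + 2 a_2) = b_1 + 2 b_2. $$

When we fix $a$, we get the following solution for $b$:

$$ b = \frac{\sum_{i=1}^{n} (a^\top x_i) \, x_i}{\sum_{i=1}^{n} (a^\top x_i)^2}. $$

Note that $a^\top x_i = (a_1 + 2a_2) \, x_{i1}$ and $x_i = x_{i1} w$, so we can simplify it to

$$ b = \frac{s \, (a_1 + 2a_2) \, w}{s \, (a_1 + 2a_2)^2} = \frac{1}{a_1 + 2 a_2} \begin{pmatrix} 1 \\ 2 \end{pmatrix}. $$

Bringing them together, we get the following necessary condition for a stationary point:

$$ \lVert b \rVert^2 (a_1 + 2 a_2) = b_1 + 2 b_2, \qquad (a_1 + 2a_2) \, b = \begin{pmatrix} 1 \\ 2 \end{pmatrix}. $$

Substituting the second equation into the first turns both sides into $5 / (a_1 + 2a_2)$. The first equation therefore holds automatically, and the condition simplifies to

$$ (a_1 + 2a_2) \, b = \begin{pmatrix} 1 \\ 2 \end{pmatrix}. $$

When we restrict to $a_1 + 2a_2 \neq 0$, we get

$$ b_1 = \frac{1}{a_1 + 2a_2}, \qquad b_2 = \frac{2}{a_1 + 2a_2}. \qquad (14) $$

Obviously, $a = (1, 0)^\top$, $b = (1, 2)^\top$ satisfies the stationary condition (14). Equation (14) also has infinitely many other solutions, e.g., $a = (0, 1/2)^\top, b = (1, 2)^\top$; $a = (2, 0)^\top, b = (1/2, 1)^\top$; and so on. Which solution is reached depends on the algorithm used. The variational auto-encoder requires the code $f(x)$ to approach the standard normal distribution, i.e., to have zero mean and unit variance. Since $f(x) = (a_1 + 2a_2) z$, this requires $(a_1 + 2a_2)^2 = 1$, and taking $a_1 + 2a_2 = 1$ gives $b = (1, 2)^\top$. Even in this case there is still freedom in $a$ (any point on the line $a_1 + 2a_2 = 1$), but it does not influence the output $\hat{x} = b a^\top x = x$.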
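A quick numerical check of this non-uniqueness, as a sketch in Python. It assumes the running example $x = (z, 2z)^\top$. For several encoders $a$ it builds the decoder $b = (1, 2)^\top / (a_1 + 2a_2)$ and checks that each pair reconstructs points on the line $x_2 = 2x_1$ up to floating-point rounding. The helper names are hypothetical.

```python
def reconstruct(a, b, x):
    """Linear auto-encoder output x_hat = b * (a^T x)."""
    code = a[0] * x[0] + a[1] * x[1]
    return (b[0] * code, b[1] * code)


def stationary_decoder(a):
    """Decoder b = (1, 2) / (a1 + 2 a2) from the stationary condition; needs a1 + 2 a2 != 0."""
    c = a[0] + 2.0 * a[1]
    if c == 0.0:
        raise ValueError("a1 + 2*a2 must be nonzero")
    return (1.0 / c, 2.0 / c)


# Several encoders from the infinite family of stationary points.
encoders = [(1.0, 0.0), (0.0, 0.5), (2.0, 0.0), (-3.0, 2.0)]
points = [(z, 2.0 * z) for z in (-1.5, 0.0, 0.7, 3.0)]   # data on the line x2 = 2 x1

for a in encoders:
    b = stationary_decoder(a)
    worst = max(abs(u - v) for x in points for u, v in zip(x, reconstruct(a, b, x)))
    print("a =", a, "b =", b, "max reconstruction error:", worst)
```

Every pair gives (numerically) zero reconstruction error, so the loss alone cannot tell these solutions apart.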

When the optimal parameters of a function are not unique, the optimization algorithm is often unstable, since it can jump between the many optima. If the function has more parameters than are needed to fit the true solution, it can overfit the training data. It then learns the noise in the training data, and the learned function deviates from the true solution. A general remedy is to add a penalty to the objective function, such as the l2 or l1 norm of the function's parameters. We now add an l2-norm penalty on the parameters with multiplier $\lambda$, which gives the following loss function:

$$ L_\lambda(a, b) = \sum_{i=1}^{n} \lVert x_i - b a^\top x_i \rVert^2 + \lambda \left( \lVert a \rVert^2 + \lVert b \rVert^2 \right). $$


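We carry out the same analysis as above. First, setting the first-order derivatives to zero gives the following conditions, which the parameters must satisfy at a local minimum of the objective function:

$$ \frac{\partial L_\lambda}{\partial a} = 2 \sum_{i=1}^{n} x_i x_i^\top \left( \lVert b \rVert^2 a - b \right) + 2 \lambda a = 0, \qquad \frac{\partial L_\lambda}{\partial b} = 2 \sum_{i=1}^{n} \left[ (a^\top x_i)^2 \, b - (a^\top x_i) \, x_i \right] + 2 \lambda b = 0. $$

When we fix $b$, we get the following linear system in $a$:

$$ \left( \lVert b \rVert^2 \sum_{i=1}^{n} x_i x_i^\top + \lambda I \right) a = \left( \sum_{i=1}^{n} x_i x_i^\top \right) b. $$

Note that $x_{i2} = 2 x_{i1}$. Let $s = \sum_{i=1}^{n} x_{i1}^2$ and $w = (1, 2)^\top$, so that $\sum_i x_i x_i^\top = s\, w w^\top$. We can then simplify the system to

$$ \left( s \lVert b \rVert^2 w w^\top + \lambda I \right) a = s \, (w^\top b) \, w. $$

We now assume $\lambda > 0$, which is the usual case, so that the coefficient matrix is invertible. Substituting $a = c \, w$ and using $w^\top w = 5$ gives the solution

$$ a = \frac{s \, (b_1 + 2 b_2)}{\lambda + 5 s \lVert b \rVert^2} \begin{pmatrix} 1 \\ 2 \end{pmatrix}. $$

Note that $s/n$ is the sample estimate of $E[x_1^2] = E[z^2] = 1$, so $s$ grows with $n$. If $n \to \infty$ while $\lambda$ is held fixed, then $\lambda / s \to 0$ and we have

$$ a \to \frac{b_1 + 2 b_2}{5 \lVert b \rVert^2} \begin{pmatrix} 1 \\ 2 \end{pmatrix}. $$

This gives $\lVert b \rVert^2 (a_1 + 2a_2) = b_1 + 2b_2$, the same condition as without the penalty, which is what we need. Adding the l2 norm also shrinks the set of optimal $a$ (for fixed $b$) from a whole line to a single point, because the l2 norm adds local convexity to the loss landscape.

When we fix $a$, we get the following solution for $b$:

$$ b = \frac{\sum_{i=1}^{n} (a^\top x_i) \, x_i}{\sum_{i=1}^{n} (a^\top x_i)^2 + \lambda}. $$

Note that $a^\top x_i = (a_1 + 2a_2) \, x_{i1}$ and $x_i = x_{i1} w$, so we can simplify it to

$$ b = \frac{s \, (a_1 + 2a_2)}{s \, (a_1 + 2a_2)^2 + \lambda} \begin{pmatrix} 1 \\ 2 \end{pmatrix}. $$

If $n \to \infty$ while $\lambda$ is held fixed, then $\lambda / s \to 0$ and we have $b \to (1, 2)^\top / (a_1 + 2a_2)$, the unpenalized solution.

Bringing them together: both $a$ and $b$ are multiples of $w$, say $a = \alpha w$ and $b = \beta w$. Then $a_1 + 2a_2 = 5\alpha$, $b_1 + 2b_2 = 5\beta$ and $\lVert b \rVert^2 = 5 \beta^2$, and we get the following necessary condition for a stationary point:

$$ \alpha = \frac{5 s \beta}{\lambda + 25 s \beta^2}, \qquad \beta = \frac{5 s \alpha}{\lambda + 25 s \alpha^2}. $$

Clear the denominators, then subtract $\alpha$ times the second equation from $\beta$ times the first. This gives $5s(\beta^2 - \alpha^2)(5\alpha\beta - 1) = 0$. The branch $5\alpha\beta = 1$ would force $\lambda \alpha = 0$, and $\beta = -\alpha \neq 0$ would force $\lambda + 25 s \alpha^2 = -5s$. Both are impossible for $\lambda > 0$, so any nonzero stationary point has $\alpha = \beta$. Under the condition $0 < \lambda < 5s$, solving the above equations gives (besides the trivial $\alpha = \beta = 0$)

$$ \alpha = \beta = \pm \frac{1}{5} \sqrt{5 - \frac{\lambda}{s}}. $$

Now we focus on the positive solution. When $\lambda / s \to 0$, we have

$$ \alpha = \beta = \frac{1}{\sqrt{5}}, \qquad a = b = \frac{1}{\sqrt{5}} \begin{pmatrix} 1 \\ 2 \end{pmatrix}. $$

For this solution, $f(x) = a^\top x = \sqrt{5} \, z$ and $\hat{x} = b a^\top x = \frac{1}{5} w (w^\top w) z = w z = x$, which verifies the correctness of the solution. In this case, however, the code $f(x) = \sqrt{5} \, z$ has variance 5 rather than 1. More generally, if we only know $0 < \lambda < 5s$, then $a_1 + 2a_2 = 5\alpha = \sqrt{5 - \lambda/s}$ and $\hat{x} = b a^\top x = 5 \alpha \beta \, x = \left(1 - \frac{\lambda}{5s}\right) x$. So when $n \to \infty$ with $\lambda$ held fixed, we recover $x$ correctly. The l2 penalty reduces the infinitely many solutions of the original encoder and decoder to two nonzero solutions, which makes the algorithm more stable. If we choose a small $\lambda$, the optimum of the penalized objective approximates one of the optima of the original objective.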
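The fixed-point equations for the penalized problem can also be checked numerically. This is a minimal sketch assuming the running example ($w = (1, 2)^\top$) and the illustrative values $s = 100$, $\lambda = 50$, which are hypothetical numbers, not from the text. It alternates the two exact minimizations, over $a$ with $b$ fixed and then over $b$ with $a$ fixed. Each update lands on the span of $w$, so only the scalars $\alpha$ and $\beta$ need tracking. The limit is then compared with the closed-form positive solution.

```python
import math


def positive_solution(s, lam):
    """Closed form alpha = beta = sqrt(5 - lam/s) / 5, valid for 0 < lam < 5 s."""
    if not 0.0 < lam < 5.0 * s:
        raise ValueError("need 0 < lam < 5 s")
    return math.sqrt(5.0 - lam / s) / 5.0


def alternate(s, lam, beta=0.5, iters=500):
    """Alternate the exact a-update (b fixed) and b-update (a fixed),
    with a = alpha * w, b = beta * w and w = (1, 2)."""
    alpha = 0.0
    for _ in range(iters):
        alpha = 5.0 * s * beta / (lam + 25.0 * s * beta ** 2)   # minimize over a
        beta = 5.0 * s * alpha / (lam + 25.0 * s * alpha ** 2)  # minimize over b
    return alpha, beta


s, lam = 100.0, 50.0          # illustrative: s = sum_i x_{i1}^2, lam = l2 multiplier
alpha, beta = alternate(s, lam)
print(alpha, beta, positive_solution(s, lam))
print("shrinkage 5*alpha*beta =", 5 * alpha * beta, "vs 1 - lam/(5s) =", 1 - lam / (5 * s))
```

A moderate $\lambda / s$ keeps this demo fast. As $\lambda / s \to 0$ the alternation slows down, because the unpenalized problem has a flat direction that the penalty only weakly removes.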

What happens if we apply the l1 penalty to the original objective? We leave this exploration to you.

If we want to constrain the distribution of the code $z = f(x)$ to be a standard normal distribution, we meet an obstacle: the distribution of $z$ depends on the distribution of $x$. If we know that the density of $x$ is $p_x(x)$, then the density of $z = f(x)$ is $p_z(z) = p_x(f^{-1}(z)) \, \lvert \det J_{f^{-1}}(z) \rvert$. This holds when $f$ is invertible and the Jacobian determinant is nonzero almost surely. But we do not know the probability distribution of $x$. Even if we knew $p_x$, when $f^{-1}$ is hard to compute we could not use the KL divergence between the distribution of $z$ and the normal distribution as a penalty. This motivates the variational auto-encoder. Before telling its story, we first generalize the simple auto-encoder formulation above to the general auto-encoder.
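The change-of-variables formula is easy to illustrate in one dimension, where the Jacobian determinant is just the derivative of $f^{-1}$. This is a minimal sketch assuming $x \sim N(0, 1)$ and two invertible maps chosen for illustration: $f(x) = 2x + 1$, which should give $N(1, 2^2)$, and $f(x) = e^x$, which should give the standard log-normal density.

```python
import math


def normal_pdf(x, mean=0.0, std=1.0):
    """Density of N(mean, std^2) at x."""
    u = (x - mean) / std
    return math.exp(-0.5 * u * u) / (std * math.sqrt(2.0 * math.pi))


def pushforward_density(z, p_x, f_inv, f_inv_deriv):
    """Density of z = f(x): p_z(z) = p_x(f^{-1}(z)) * |d f^{-1}(z) / dz| (1-D case)."""
    return p_x(f_inv(z)) * abs(f_inv_deriv(z))


# f(x) = 2x + 1  =>  f^{-1}(z) = (z - 1) / 2 with derivative 1/2; expect N(1, 2^2).
affine_grid = (-3.0, 0.0, 1.0, 2.5)
affine_z = [pushforward_density(z, normal_pdf, lambda t: (t - 1.0) / 2.0, lambda t: 0.5)
            for z in affine_grid]

# f(x) = exp(x)  =>  f^{-1}(z) = log z with derivative 1/z; expect a log-normal density.
lognormal_grid = (0.2, 1.0, 3.0)
lognormal_z = [pushforward_density(z, normal_pdf, math.log, lambda t: 1.0 / t)
               for z in lognormal_grid]
print(affine_z, lognormal_z)
```

This only works because $p_x$ and $f^{-1}$ are known here, which is exactly what is missing in the auto-encoder setting.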

In its general form, such as the one used in image processing, an auto-encoder consists of an encoder function $f_\phi$ and a decoder function $g_\theta$. Both are represented by neural networks, with parameters $\phi$ and $\theta$ respectively. A general feed-forward neural network can be written in the form

$$ h(x) = \sigma\Big(W_L \, \sigma\big( W_{L-1} \cdots \sigma(W_1 x + c_1) \cdots + c_{L-1}\big) + c_L\Big), $$

where:

- $\sigma$ is an element-wise nonlinear activation function (e.g., sigmoid, ReLU);
- $W_l$ is the linear projection matrix and $c_l$ is the intercept term of layer $l$;
- the network has $L-1$ hidden layers, and the final $L$-th layer is the output layer;
- the parameters are $\{W_l, c_l\}_{l=1}^{L}$.

The loss function is given by

$$ L(\phi, \theta) = \sum_{i=1}^{n} \lVert x_i - g_\theta(f_\phi(x_i)) \rVert^2. $$

It is optimized by SGD or its variants. Essentially, these methods only need the gradient of the loss function on a mini-batch of samples.
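A minimal forward-pass sketch of this general auto-encoder in plain Python. It has no automatic differentiation; in practice the mini-batch gradient needed by SGD would come from an autodiff library. The sketch assumes ReLU activations, Gaussian weight initialization, and a linear output layer, which is a common choice for real-valued reconstructions. The form above, with $\sigma$ applied at every layer including the output, is reproduced by activate_output=True. All function names and layer sizes are hypothetical.

```python
import math
import random


def relu(v):
    """Element-wise ReLU activation."""
    return [max(0.0, t) for t in v]


def affine(W, c, v):
    """Compute W v + c, with W stored as a list of rows."""
    return [sum(w_jk * v_k for w_jk, v_k in zip(row, v)) + c_j for row, c_j in zip(W, c)]


def mlp(params, v, activation=relu, activate_output=False):
    """Apply layers (W_l, c_l) in order: hidden layers use activation(W_l h + c_l);
    the last layer is linear unless activate_output=True."""
    h = v
    for l, (W, c) in enumerate(params):
        h = affine(W, c, h)
        if l < len(params) - 1 or activate_output:
            h = activation(h)
    return h


def init_params(sizes, rng):
    """Random (W, c) pairs for layer sizes such as [4, 8, 2] (illustrative init scheme)."""
    return [([[rng.gauss(0.0, 1.0 / math.sqrt(m)) for _ in range(m)] for _ in range(k)], [0.0] * k)
            for m, k in zip(sizes[:-1], sizes[1:])]


def autoencoder_loss(enc, dec, batch):
    """Mean over the mini-batch of ||x - g(f(x))||^2."""
    total = 0.0
    for x in batch:
        x_hat = mlp(dec, mlp(enc, x))
        total += sum((xi - yi) ** 2 for xi, yi in zip(x, x_hat))
    return total / len(batch)


rng = random.Random(0)
enc = init_params([4, 8, 2], rng)   # encoder f_phi: R^4 -> R^2
dec = init_params([2, 8, 4], rng)   # decoder g_theta: R^2 -> R^4
batch = [[rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(16)]
print("mini-batch loss:", autoencoder_loss(enc, dec, batch))
```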

2 Variational Auto-Encoder

Now we tell the general story of the variational auto-encoder (Kingma and Welling (2014), Doersch (2016)) using general notation. After that we introduce the scVI model, a variational auto-encoder designed for scRNA-seq data.

We use the same notation as Kingma and Welling (2014) to make it easier to follow. To tackle the intractable probability distribution of the code, the variational auto-encoder assumes that each data point $x$ comes from a hidden continuous variable $z$. First $z$ is generated from the prior distribution $p_{\theta^*}(z)$, and then $x$ comes from the conditional distribution $p_{\theta^*}(x \mid z)$. This is represented by the solid arrows in Fig 1 (Kingma and Welling (2014)).

Figure 1: The type of directed graphical model under consideration. Solid lines denote the generative model $p_\theta(z) \, p_\theta(x \mid z)$, dashed lines denote the variational approximation $q_\phi(z \mid x)$ to the intractable posterior $p_\theta(z \mid x)$. The variational parameters $\phi$ are learned jointly with the generative model parameters $\theta$.

The probability distribution of $x$ is given by $p_\theta(x) = \int p_\theta(z) \, p_\theta(x \mid z) \, dz$. We hope to find a computable distribution that concisely represents the information in the data points $x^{(i)}$, so that we can use it for downstream analysis. As in the Bayesian approach, we can use a parametric family $p_\theta(z)$, $p_\theta(x \mid z)$ to represent the distribution of $x$. However, the marginal distribution $p_\theta(x)$ is hard to obtain in general, and so is the posterior $p_\theta(z \mid x) = p_\theta(x \mid z) \, p_\theta(z) / p_\theta(x)$. Variational inference tackles this problem by approximating the posterior $p_\theta(z \mid x)$ with a computable distribution $q_\phi(z \mid x)$ from a parametric family, represented by the dashed lines in Fig 1. To achieve this goal, we need a computable algorithm that extracts information from the sample points into the parametrized distribution $q_\phi(z \mid x)$. This can be done by taking the maximum-likelihood approach with an approximation, namely the variational lower bound. We now give the basic derivation of the variational lower bound. In the classical maximum-likelihood method, we seek the $\theta$ that maximizes the log-likelihood $\log p_\theta(x^{(1)}, \ldots, x^{(N)}) = \sum_{i=1}^{N} \log p_\theta(x^{(i)})$. The variational lower bound on the marginal likelihood of data point $i$ is defined by

$$ \mathcal{L}(\theta, \phi; x^{(i)}) = \log p_\theta(x^{(i)}) - D_{KL}\big(q_\phi(z \mid x^{(i)}) \,\|\, p_\theta(z \mid x^{(i)})\big). $$

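Here $D_{KL}(\cdot \,\|\, \cdot)$ is the KL divergence between two distributions, which is nonnegative. The second term on the right-hand side measures the divergence of the approximate posterior from the true posterior. Since this term is nonnegative, $\mathcal{L}$ is a lower bound on $\log p_\theta(x^{(i)})$. We can rewrite the variational lower bound in terms of known quantities:

$$ \begin{aligned} \mathcal{L}(\theta, \phi; x^{(i)}) &= E_{q_\phi(z \mid x^{(i)})}\big[\log p_\theta(x^{(i)}) + \log p_\theta(z \mid x^{(i)}) - \log q_\phi(z \mid x^{(i)})\big] \\ &= E_{q_\phi(z \mid x^{(i)})}\big[\log p_\theta(x^{(i)}, z) - \log q_\phi(z \mid x^{(i)})\big] \\ &= E_{q_\phi(z \mid x^{(i)})}\big[\log p_\theta(z) - \log q_\phi(z \mid x^{(i)})\big] + E_{q_\phi(z \mid x^{(i)})}\big[\log p_\theta(x^{(i)} \mid z)\big]. \end{aligned} $$

So we get the classical representation of the variational lower bound:

$$ \mathcal{L}(\theta, \phi; x^{(i)}) = -D_{KL}\big(q_\phi(z \mid x^{(i)}) \,\|\, p_\theta(z)\big) + E_{q_\phi(z \mid x^{(i)})}\big[\log p_\theta(x^{(i)} \mid z)\big]. \qquad (31) $$

The first term on the right-hand side is the negative KL divergence between the approximate posterior and the prior distribution of the hidden continuous variable $z$. When $q_\phi(z \mid x^{(i)}) = p_\theta(z \mid x^{(i)})$, the bound is tight:

$$ \mathcal{L}(\theta, \phi; x^{(i)}) = \log p_\theta(x^{(i)}). $$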

So if we fix the parameter $\theta$, the maximum of the variational lower bound over $\phi$ equals the log-likelihood $\log p_\theta(x^{(i)})$. This maximum is achieved when $q_\phi(z \mid x^{(i)}) = p_\theta(z \mid x^{(i)})$. Now suppose that we always reach such a state, i.e., the variational lower bound equals the marginal log-likelihood. By maximum-likelihood theory, with a large enough number of sample points, the maximum of the log-likelihood is achieved at the optimum $\theta^*$. These arguments are only heuristic, but they suggest that we can optimize the variational lower bound to find the optimum $\theta^*$, and that $p_{\theta^*}$ will capture the underlying data distribution.
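The identity $\log p_\theta(x) = \mathcal{L} + D_{KL}(q \,\|\, p_\theta(z \mid x))$, the form (31), and the tightness condition can all be checked exactly when the latent variable is discrete, because the integral over $z$ becomes a finite sum. This sketch assumes a three-state latent $z$ with made-up prior, likelihood, and $q$ values (illustrative numbers only); the function name is hypothetical.

```python
import math


def elbo_terms(prior, lik, q):
    """Discrete latent z in {0, ..., K-1}.
    prior[k] = p(z=k), lik[k] = p(x|z=k) for one observed x, q[k] = q(z=k|x)."""
    p_x = sum(p * l for p, l in zip(prior, lik))              # marginal p(x)
    post = [p * l / p_x for p, l in zip(prior, lik)]           # exact posterior p(z|x)
    kl_post = sum(qk * math.log(qk / pk) for qk, pk in zip(q, post) if qk > 0)
    kl_prior = sum(qk * math.log(qk / pk) for qk, pk in zip(q, prior) if qk > 0)
    recon = sum(qk * math.log(lk) for qk, lk in zip(q, lik) if qk > 0)
    elbo_def = math.log(p_x) - kl_post     # definition: log p(x) - KL(q || posterior)
    elbo_31 = -kl_prior + recon            # form (31): -KL(q || prior) + E_q[log p(x|z)]
    return math.log(p_x), elbo_def, elbo_31, kl_post, post


prior = [0.5, 0.3, 0.2]      # illustrative numbers only
lik = [0.1, 0.7, 0.4]
q = [0.2, 0.5, 0.3]
log_px, elbo_def, elbo_31, kl_post, post = elbo_terms(prior, lik, q)
print(log_px, elbo_def, elbo_31, kl_post)

# Setting q equal to the exact posterior makes the bound tight.
log_px2, _, elbo_tight, kl_zero, _ = elbo_terms(prior, lik, post)
print(elbo_tight, log_px2, kl_zero)
```

Both expressions for the bound agree and sit below $\log p(x)$. With $q$ equal to the posterior, the gap closes.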

Next we should select distribution families with high representational capacity for the distributions in the variational lower bound (31). They should approximate the true distributions well and make the lower bound easy and efficient to optimize.

Note that if $z$ follows a continuous distribution in a $d$-dimensional space (e.g., a normal distribution) and $x$ is a random vector in a $d$-dimensional space, then we can find a sufficiently complex function $g$ such that $x = g(z)$ almost surely (Kingma and Welling (2014)). We can conjecture that if the random vector $x$ lies on a manifold of essential dimension $d$, we can also find a function $g$ such that $x = g(z)$. Now let $x$ be a random vector representing gene expression. Since there are complicated regulatory networks between genes, the function $g$ should represent these complex regulatory networks. The distribution of $z$ can then be a simple normal distribution, a log-normal distribution, or another continuous distribution. To make the KL divergence $D_{KL}(q_\phi(z \mid x) \,\|\, p_\theta(z))$ small, we let the approximate posterior $q_\phi(z \mid x)$ belong to the same distribution family as $p_\theta(z)$. For single-cell RNA-seq data, scVI chooses the zero-inflated negative binomial distribution as the family for $p_\theta(x \mid z)$.
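The text does not write out the zero-inflated negative binomial (ZINB) likelihood. This is a minimal sketch assuming the common mean/inverse-dispersion parametrization, $NB(x; \mu, \theta) = \frac{\Gamma(x+\theta)}{\Gamma(\theta)\, x!} \left(\frac{\theta}{\theta+\mu}\right)^{\theta} \left(\frac{\mu}{\theta+\mu}\right)^{x}$. A zero-inflation probability $\pi$ adds extra mass at zero. How scVI's networks produce $\mu$, $\theta$ and $\pi$ is not shown. The parameter values below are illustrative.

```python
import math


def nb_log_pmf(x, mu, theta):
    """log NB(x; mu, theta): mean mu > 0, inverse dispersion theta > 0,
    variance mu + mu^2 / theta."""
    return (math.lgamma(x + theta) - math.lgamma(theta) - math.lgamma(x + 1)
            + theta * math.log(theta / (theta + mu))
            + x * math.log(mu / (theta + mu)))


def zinb_log_pmf(x, mu, theta, pi):
    """log ZINB(x; mu, theta, pi): with probability pi the count is an extra zero,
    otherwise it is drawn from NB(mu, theta)."""
    if x == 0:
        return math.log(pi + (1.0 - pi) * math.exp(nb_log_pmf(0, mu, theta)))
    return math.log(1.0 - pi) + nb_log_pmf(x, mu, theta)


mu, theta, pi = 5.0, 2.0, 0.3        # illustrative parameter values
probs = [math.exp(zinb_log_pmf(k, mu, theta, pi)) for k in range(500)]
print("P(x = 0) =", probs[0], " total mass =", sum(probs))
```

The probabilities sum to one, and the mean is $(1 - \pi)\mu$, because a fraction $\pi$ of the cells report a zero regardless of $\mu$.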

We call $q_\phi(z \mid x)$ the encoder: it encodes the data point $x$ into its "code" $z$. We refer to $p_\theta(x \mid z)$ as the decoder: it decodes the "code" $z$ back into the data point $x$.

We point out that the complex regulatory networks between genes are modeled mainly through the mean of the negative binomial distribution.

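Suppose the output $x$ is made of independent Gaussian variables whose mean and diagonal variance are functions of the random variable $z$. To see why such an output can still capture dependence structure in $x$, consider a simple example. Let $z$ be a standard normal variable, and let $p(x \mid z) = N(x; z w, I)$ be the conditional density of $x = (x_1, x_2)^\top$, with a fixed vector $w = (w_1, w_2)^\top$. Then $x_1$ and $x_2$ are independent given $z$. Equivalently, $x = z w + \epsilon$ with $\epsilon \sim N(0, I)$ independent of $z$, and we get $p(x)$ in closed form:

$$ p(x) = \int N(x; z w, I) \, N(z; 0, 1) \, dz = N\big(x; \, 0, \, w w^\top + I\big). $$

So we get

$$ \operatorname{Cov}(x_1, x_2) = w_1 w_2, $$

which is nonzero whenever $w_1 w_2 \neq 0$. This simple example shows that the model captures the dependence between $x_1$ and $x_2$. In the general form, the mean and diagonal variance are output by a nonlinear mapping such as a neural network. The density $p(x) = \int p(x \mid z) \, p(z) \, dz$ can then capture complex dependence networks. If $x$ is the gene expression, this can capture complex gene regulatory networks, which are encoded in the mapping from $z$ to the parameters of $p_\theta(x \mid z)$. This may be one reason for the success of the scVI model.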
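A Monte Carlo check of an example of this kind. It assumes $p(x \mid z) = N(x; z w, I)$ with the illustrative choice $w = (1, 2)^\top$. The sketch samples $x = z w + \epsilon$, whose two coordinates are independent given $z$. Their sample covariance nevertheless approaches $w_1 w_2 = 2$.

```python
import random


def sample_x(n, w, seed=0):
    """Draw x = z * w + eps with z ~ N(0, 1) and eps ~ N(0, I), so that x1 and x2
    are conditionally independent given z."""
    rng = random.Random(seed)
    xs = []
    for _ in range(n):
        z = rng.gauss(0.0, 1.0)
        xs.append((z * w[0] + rng.gauss(0.0, 1.0), z * w[1] + rng.gauss(0.0, 1.0)))
    return xs


def sample_cov(xs):
    """Unbiased sample covariance of the two coordinates."""
    n = len(xs)
    m1 = sum(x[0] for x in xs) / n
    m2 = sum(x[1] for x in xs) / n
    return sum((x[0] - m1) * (x[1] - m2) for x in xs) / (n - 1)


w = (1.0, 2.0)
xs = sample_x(200_000, w)
print("sample Cov(x1, x2):", sample_cov(xs), " theory w1*w2:", w[0] * w[1])
```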

The variational auto-encoder (VAE) models the probabilistic encoder $q_\phi(z \mid x)$ by computing the parameters of that distribution (i.e., the mean and the diagonal covariance matrix) with a nonlinear mapping, e.g., a neural network. $p_\theta(z)$ is the prior distribution. It is usually a basic distribution without parameters, e.g., $N(0, I)$ (standard Gaussian variables). Since $q_\phi(z \mid x)$ comes from the same family of distributions as $p_\theta(z)$, the KL divergence $D_{KL}(q_\phi(z \mid x) \,\|\, p_\theta(z))$ has a closed form. The distribution of the probabilistic decoder $p_\theta(x \mid z)$ should match the real distribution of $x$. For example, scVI chooses the zero-inflated negative binomial distribution for gene expression, while in image processing a Gaussian distribution with diagonal covariance is often chosen. The VAE uses a nonlinear mapping (a neural network) to model the parameters of $p_\theta(x \mid z)$.
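For the usual choice $q_\phi(z \mid x) = N(\mu, \operatorname{diag}(\sigma^2))$ and $p_\theta(z) = N(0, I)$, the closed form is $D_{KL} = \frac{1}{2} \sum_j (\mu_j^2 + \sigma_j^2 - 1 - \log \sigma_j^2)$. This is a minimal sketch that parametrizes the variance by $\log \sigma^2$, a common choice but an assumption here. It also computes a Monte Carlo estimate of $E_q[\log q(z) - \log p(z)]$ for comparison. The numbers are illustrative.

```python
import math
import random


def kl_diag_gaussian_to_std_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I))."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, log_var))


def kl_monte_carlo(mu, log_var, n=50_000, seed=0):
    """Monte Carlo estimate of E_q[log q(z) - log p(z)] for the same pair."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        for m, lv in zip(mu, log_var):
            sd = math.exp(0.5 * lv)
            z = m + sd * rng.gauss(0.0, 1.0)                  # sample z ~ q
            log_q = -0.5 * (math.log(2 * math.pi) + lv + ((z - m) / sd) ** 2)
            log_p = -0.5 * (math.log(2 * math.pi) + z * z)
            total += log_q - log_p
    return total / n


mu, log_var = [0.5, -1.0], [0.2, -0.5]
closed = kl_diag_gaussian_to_std_normal(mu, log_var)
mc = kl_monte_carlo(mu, log_var)
print("closed form:", closed, " Monte Carlo:", mc)
```

The closed form is exact, cheap, and differentiable in $\mu$ and $\log \sigma^2$. This is why the KL term of the lower bound causes no difficulty in training.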

To train the neural networks on a large dataset, the VAE uses stochastic optimization, which needs a low-variance estimate of the gradients of the objective function (the variational lower bound). In most cases, the parametric families of $q_\phi(z \mid x)$ and $p_\theta(z)$ give an analytical expression for $-D_{KL}(q_\phi(z \mid x) \,\|\, p_\theta(z))$ that is differentiable with respect to $\phi$. However, there is a problem with the reconstruction-error term $E_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)]$ of the variational lower bound. Suppose we estimate it with the Monte Carlo estimate

$$ E_{q_\phi(z \mid x^{(i)})}\big[\log p_\theta(x^{(i)} \mid z)\big] \approx \frac{1}{M} \sum_{m=1}^{M} \log p_\theta(x^{(i)} \mid z^{(m)}), \qquad z^{(m)} \sim q_\phi(z \mid x^{(i)}). $$

This causes two problems. First, the variance of this estimate is very high, so the stochastic optimization can fail. Second, we cannot differentiate it with respect to the parameters $\phi$, since the backward gradient cannot pass through a sample to the parameters of the distribution $q_\phi(z \mid x)$. To get around this problem, Kingma and Welling (2014) proposed the reparametrization trick. The trick uses the fact that we can express the random variable