
Probability Functional Descent: A Unifying Perspective on GANs, Variational Inference, and Reinforcement Learning

01/30/2019
by Casey Chu, et al.

The goal of this paper is to provide a unifying view of a wide range of problems of interest in machine learning by framing them as the minimization of functionals defined on the space of probability measures. In particular, we show that generative adversarial networks, variational inference, and actor-critic methods in reinforcement learning can all be seen through the lens of our framework. We then discuss a generic optimization algorithm for our formulation, called probability functional descent (PFD), and show how this algorithm recovers existing methods developed independently in the settings mentioned earlier.


1 Introduction

Domain | Distribution of interest | Functional | Functional derivative
Generative adversarial networks | Generator | Divergence | Discriminator
Variational inference | Approximate posterior | Negative ELBO |
Reinforcement learning | Policy | Expected reward | Advantage

Table 1: Framing a problem as the optimization of a probability functional unifies several domains.
Algorithm | Type of derivative estimator
Generative adversarial networks:
  Minimax GAN (Goodfellow et al., 2014) | Convex duality
  Non-saturating GAN (Goodfellow et al., 2014) | Binary classification
  Wasserstein GAN (Arjovsky et al., 2017) | Convex duality
Variational inference:
  Black-box variational inference (Ranganath et al., 2014) | Exact
  Adversarial variational Bayes (Mescheder et al., 2017) | Binary classification
  Adversarial posterior distillation (Wang et al., 2018) | Convex duality
Reinforcement learning:
  Policy iteration (Howard, 1960) | Exact
  Policy gradient (Williams, 1992) | Monte Carlo
  Actor-critic (Konda & Tsitsiklis, 2000; Sutton et al., 2000) | Least squares
  Dual actor-critic (Chen & Wang, 2016; Dai et al., 2017b) | Convex duality

Table 2: Different existing algorithms correspond to different ways of estimating the functional derivative.

Deep learning now plays an important role in many domains, for example, in generative modeling, deep reinforcement learning, and variational inference. In the process, dozens of new algorithms have been proposed for solving these problems with deep neural networks, each specific, of course, to the domain at hand.

In this paper, we introduce a conceptual framework which can be used to understand in a unified way a broad class of machine learning problems. Central to this framework is an abstract optimization problem in the space of probability measures, a formulation that stems from the observation that in many fields, the object of interest is a probability distribution; moreover, the learning process is guided by a probability functional to be minimized, a loss function that conceptually maps a probability distribution to a real number.

Table 1 lists these correspondences in the case of generative adversarial networks, variational inference, and reinforcement learning.

Because the optimization now takes place in the infinite-dimensional space of probability measures, standard finite-dimensional algorithms like gradient descent are initially unavailable; even the proper notion of the derivative of these functionals is unclear. We call upon a body of literature known as von Mises calculus (von Mises, 1947; Fernholz, 2012), originally developed in the field of asymptotic statistics, to make these functional derivatives precise. Remarkably, we find that once the connection is made, the resulting generalized descent algorithm, which we call probability functional descent, is intimately compatible with standard deep learning techniques such as stochastic gradient descent (Bottou, 2010), the reparameterization trick (Kingma & Welling, 2013), and adversarial training (Goodfellow et al., 2014).

When we apply probability functional descent to the aforementioned domains, we find that we recover a wide range of existing algorithms, and the essential distinction between them is simply the way that the functional derivative, the von Mises influence function in this context, is approximated. Table 2 lists these algorithms and their corresponding approximation methods. Probability functional descent therefore acts as a unifying framework for the analysis of existing algorithms as well as the systematic development of new ones.

1.1 Related work

The problem of optimizing functionals of probability measures is not new. For example, Gaivoronski (1986) and Molchanov & Zuyev (2001) study these types of problems and even propose Frank-Wolfe and steepest descent algorithms to solve them. However, their results do not attempt to unify the areas that we discuss, nor are their algorithms immediately practical for the high-dimensional machine learning settings described here. Owing to our unifying perspective, though, one may be able to adapt these works to the machine learning applications described here.

One part of our work uses convex duality to recast convex optimization problems as saddle-point problems, as a technique for estimating functional derivatives. This correspondence between convex optimization problems and saddle-point problems is an old and general concept (Rockafellar, 1968), and it underlies classical dual optimization techniques (Lucchetti, 2006; Luenberger & Ye, 2015). Nevertheless, the use of these min-max representations remains an active topic of research in machine learning. Most notably, the literature concerning generative adversarial networks has recognized that certain min-max problems are equivalent to certain convex problems (Goodfellow et al., 2014; Nowozin et al., 2016; Farnia & Tse, 2018). Outside of GANs, Dai et al. (2017a, 2018) have begun using these min-max representations to inspire learning algorithms. These min-max representations are an important tool for us, allowing for practical implementation of our machinery.

Finally, we mention the work of Dai (2018), who also casts many machine learning problems, including Bayesian inference and policy evaluation, as infinite-dimensional optimization problems. These optimization problems are over functions in Hilbert spaces with an objective in the form of an integral operator. Because probability distributions may often be represented as functions in an appropriate Hilbert space (via their densities, for example, when they exist), this approach bears similarity to ours. Their use of a Hilbert space as the optimization domain represents a tradeoff that allows access to powerful Hilbert space structure but also requires the use of specialized kernel-based (Dai et al., 2014) or particle-based methods (Dai et al., 2016; Liu & Wang, 2016) to practically apply functional gradients and obtain samples. In our work, we instead operate directly in the space of probability measures and associate a dual space without assuming any Hilbert (or even Banach) structure. This approach, as we shall illustrate, is compatible with standard deep learning techniques and indeed corresponds exactly to many existing deep learning-based algorithms.

2 Descent on a Probability Functional

We let $\mathcal{P}(X)$ be the space of Borel probability measures on a topological space $X$. Our abstract formulation takes the form of a minimization problem over probability distributions:

$$\min_{\mu \in \mathcal{P}(X)} J(\mu),$$

where $J : \mathcal{P}(X) \to \mathbb{R}$ is called a probability functional. In order to avoid technical digressions, we assume that $X$ is a metric space that is compact, complete, and separable (i.e. a compact Polish space). We endow $\mathcal{P}(X)$ with the topology of weak convergence, also known as the weak* topology.

We now draw upon elements of von Mises calculus (von Mises, 1947) to make precise the notion of derivatives of functionals such as $J$. See Fernholz (2012) for an in-depth discussion, or Santambrogio (2015) for another perspective.

Definition 1 (Gâteaux differential).

Let $J : \mathcal{P}(X) \to \mathbb{R}$ be a function. The Gâteaux differential $dJ_\mu(\chi)$ at $\mu \in \mathcal{P}(X)$ in the direction $\chi$ is defined by

$$dJ_\mu(\chi) = \lim_{\varepsilon \to 0^+} \frac{J(\mu + \varepsilon \chi) - J(\mu)}{\varepsilon},$$

where $\chi = \nu - \mu$ for some $\nu \in \mathcal{P}(X)$.
Intuitively, the Gâteaux differential is a generalization of the directional derivative, so that $dJ_\mu(\chi)$ describes the change in the value of $J(\mu)$ when the probability measure $\mu$ is infinitesimally perturbed in the direction of $\chi$, towards another measure $\nu$. Though powerful, the Gâteaux differential is a function of differences of probability measures, which can make it unwieldy to work with. In many cases, however, the Gâteaux differential has a concise representation as an influence function $J'_\mu : X \to \mathbb{R}$.

Definition 2 (Influence function).

We say that $J'_\mu : X \to \mathbb{R}$ is an influence function for $J$ at $\mu$ if the Gâteaux differential has the integral representation

$$dJ_\mu(\chi) = \int_X J'_\mu(x)\, d\chi(x)$$

for all $\chi = \nu - \mu$, where $\nu \in \mathcal{P}(X)$.

The influence function provides a convenient representation for the Gâteaux differential. We note that if $J'_\mu$ is an influence function, then so is $J'_\mu + c$ for a constant $c$, since $\int_X c\, d\chi = c\,(\nu(X) - \mu(X)) = 0$. We also note that under some regularity conditions, the influence function can be computed as $J'_\mu(x) = \frac{d}{d\varepsilon} J\big((1 - \varepsilon)\mu + \varepsilon \delta_x\big)\big|_{\varepsilon = 0}$, where $\delta_x$ is the distribution that gives the point $x$ with probability $1$.

The Gâteaux derivative and the influence function provide the proper notion of a functional derivative, which allows us to generalize first-order descent algorithms to apply to probability functionals such as $J$. In particular, they permit a linear approximation to $J$ around $\mu$, which we denote $\hat J_\mu$:

$$J(\nu) \approx J(\mu) + \int_X J'_\mu(x)\, d(\nu - \mu)(x) =: \hat J_\mu(\nu).$$

This representation, also known as a von Mises representation, yields additional intuition about the influence function. Concretely, note that $\hat J_\mu(\nu)$ decreases locally around $\mu$ if $\mathbb{E}_{x \sim \nu}[J'_\mu(x)]$ decreases. Therefore, $J'_\mu$ acts as a potential function defined on $X$ that dictates where samples should descend if the goal is to decrease $J$. Of course, $J'_\mu$ only carries this interpretation around the current value of $\mu$; once $\mu$ is changed, the influence function must be recomputed for the new distribution.
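As a concrete illustration (ours, not part of the original paper), the Dirac-perturbation formula above can be checked numerically for the simple functional $J(\mu) = \mathbb{E}_{x \sim \mu}[f(x)]$, whose influence function is $f$ up to an additive constant. The grid, test function, and distribution below are arbitrary choices made only for this sketch.

```python
import numpy as np

# A simple probability functional on discrete distributions over a grid:
# J(mu) = E_{x ~ mu}[f(x)], whose influence function is f(x) up to a constant.
xs = np.linspace(-2.0, 2.0, 5)               # support points (hypothetical grid)
f = lambda x: x ** 2                          # hypothetical test function

def J(mu):
    return np.sum(mu * f(xs))

mu = np.array([0.1, 0.2, 0.4, 0.2, 0.1])      # current distribution

# Influence function via the Dirac-perturbation formula:
# J'_mu(x) ~ d/d(eps) J((1 - eps) * mu + eps * delta_x) at eps = 0.
eps = 1e-6
influence = np.empty_like(xs)
for i in range(len(xs)):
    delta_x = np.zeros_like(mu)
    delta_x[i] = 1.0
    influence[i] = (J((1 - eps) * mu + eps * delta_x) - J(mu)) / eps

# The finite-difference values match f(x) - E_mu[f(x)], i.e. f up to a constant.
print(influence)
print(f(xs) - J(mu))
```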

Based on this intuition, we now present probability functional descent, a straightforward analogue of finite-dimensional first-order descent algorithms:

  Initialize $\mu$ to a distribution in $\mathcal{P}(X)$
  while $\mu$ has not converged do
     Set $\varphi \leftarrow J'_\mu$ (differentiation step)
     Update $\mu$ to decrease $\mathbb{E}_{x \sim \mu}[\varphi(x)]$ (descent step)
  end while
Algorithm 1 Probability functional descent on $J$

As we shall see, probability functional descent serves as a blueprint for many existing algorithms: in generative adversarial networks, the differentiation and descent steps correspond to the discriminator and generator updates respectively; in reinforcement learning, they correspond to policy evaluation and policy improvement.

In its abstract form, probability functional descent requires two design choices in order to convert it into a practical algorithm. In section 3, we discuss different ways to choose the update in the descent step; the chain rule presented there provides one generic way. In section 4, we discuss different ways to approximate the influence function in the differentiation step; the Fenchel–Moreau representation presented there provides one generic way and a surprising connection to adversarial training.

3 Applying the Descent Step

One straightforward way to apply the descent step of PFD is to adopt a parametrization $\mu_\theta$ and descend the stochastic gradient of $\mathbb{E}_{x \sim \mu_\theta}[\varphi(x)]$ with respect to $\theta$, where $\varphi$ is the influence function obtained in the differentiation step.¹

¹ Note that this gradient step is simply one possible choice of update rule for the descent step of PFD; see subsection 7.1 (policy iteration) for an instance of PFD where this gradient-based update rule is not adopted.

This gradient step is justified by the following analogue of the chain rule.

Proposition (Chain rule). Let $J : \mathcal{P}(X) \to \mathbb{R}$ be continuously differentiable, in the sense that the influence function $J'_\mu$ exists and is continuous. Let the parameterization $\theta \mapsto \mu_\theta$ be differentiable, in the sense that $\frac{1}{\epsilon}(\mu_{\theta + \epsilon v} - \mu_\theta)$ converges to a weak limit as $\epsilon \to 0$ for every direction $v$. Then

$$\nabla_\theta\, J(\mu_\theta) = \nabla_\theta\, \mathbb{E}_{x \sim \mu_\theta}[J'_\mu(x)],$$

where $J'_\mu$ is treated as a function $X \to \mathbb{R}$ that is not dependent on $\theta$.

Proof.

Without loss of generality, assume $\theta$ is one-dimensional, as the gradient is simply a vector of one-dimensional derivatives. Let $\chi_\epsilon = \frac{1}{\epsilon}(\mu_{\theta + \epsilon} - \mu_\theta)$, and let $\chi$ be its weak limit as $\epsilon \to 0$. Then

$$\frac{d}{d\theta} J(\mu_\theta) = \lim_{\epsilon \to 0} \frac{J(\mu_\theta + \epsilon \chi_\epsilon) - J(\mu_\theta)}{\epsilon}.$$

Assuming for now that

$$\lim_{\epsilon \to 0} \frac{J(\mu_\theta + \epsilon \chi_\epsilon) - J(\mu_\theta)}{\epsilon} = dJ_{\mu_\theta}(\chi) = \int_X J'_{\mu_\theta}(x)\, d\chi(x),$$

we have that

$$\frac{d}{d\theta} J(\mu_\theta) = \int_X J'_{\mu_\theta}(x)\, d\chi(x) = \lim_{\epsilon \to 0} \int_X J'_{\mu_\theta}(x)\, d\chi_\epsilon(x) = \frac{d}{d\theta}\, \mathbb{E}_{x \sim \mu_\theta}[J'_{\mu_\theta}(x)],$$

where the interchange of limits is by the definition of weak convergence (recall we assumed that $X$ is compact, so $J'_{\mu_\theta}$ is continuous and bounded by virtue of being continuous), and $J'_{\mu_\theta}$ is held fixed in the final expression.

The equality we assumed is the definition of a stronger notion of differentiability called Hadamard differentiability of $J$. Our conditions imply Hadamard differentiability via Proposition 2.33 of Penot (2012), noting that the map $\mu \mapsto J'_\mu$ is continuous by assumption. ∎

The chain rule converts the computation of $\nabla_\theta J(\mu_\theta)$, where $J$ may be a complicated nonlinear functional, into the computation of the gradient of an expectation, which is easily handled using standard methods (see e.g. Schulman et al. (2015)). For example, the reparameterization trick, also known as the pathwise derivative estimator (Kingma & Welling, 2013; Rezende et al., 2014), uses the identity

$$\nabla_\theta\, \mathbb{E}_{x \sim \mu_\theta}[\varphi(x)] = \mathbb{E}_z\big[\nabla_\theta\, \varphi(g_\theta(z))\big],$$

where $x = g_\theta(z)$ samples $x \sim \mu_\theta$ using a noise variable $z$. Alternatively, the log derivative trick, also known as the score function gradient estimator, likelihood ratio gradient estimator, or REINFORCE (Glynn, 1990; Williams, 1992; Kleijnen & Rubinstein, 1996), uses the identity

$$\nabla_\theta\, \mathbb{E}_{x \sim \mu_\theta}[\varphi(x)] = \mathbb{E}_{x \sim \mu_\theta}\big[\varphi(x)\, \nabla_\theta \log p_\theta(x)\big],$$

where $p_\theta$ is the probability density function of $\mu_\theta$. This gradient-based update rule for the descent step is therefore a natural, practical choice in the context of deep learning.
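As an illustrative sketch (ours, not from the paper), the two gradient estimators can be compared on a toy problem where $\mu_\theta$ is a unit-variance Gaussian with learnable mean and $\varphi$ is a fixed test function; the particular distribution and function are assumptions made only for this example.

```python
import torch

phi = lambda x: (x - 2.0) ** 2                       # stand-in influence function, held fixed
theta = torch.tensor(0.0, requires_grad=True)        # mean of the Gaussian mu_theta = N(theta, 1)

# Reparameterization trick (pathwise estimator): x = theta + z with z ~ N(0, 1).
z = torch.randn(100000)
grad_pathwise, = torch.autograd.grad(phi(theta + z).mean(), theta)

# Log derivative trick (score function / REINFORCE estimator).
with torch.no_grad():
    x = theta + torch.randn(100000)                  # samples from mu_theta, no pathwise dependence
log_prob = -0.5 * (x - theta) ** 2                   # log density of N(theta, 1) up to a constant
grad_score, = torch.autograd.grad((phi(x) * log_prob).mean(), theta)

# Both estimates approximate d/d(theta) E[phi(x)] = 2 * (theta - 2) = -4 at theta = 0.
print(grad_pathwise.item(), grad_score.item())
```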

4 Approximating the Influence Function

The approximation of the influence function in the differentiation step can in principle be accomplished in many different ways. Indeed, we shall see that the distinguishing factor between many existing algorithms is exactly which influence function estimator is used, as shown in Table 2. In some cases, the influence function can be evaluated exactly, bypassing the need for approximation. Otherwise, the influence function, being a function $X \to \mathbb{R}$, may be modeled as a neural network; the precise way in which this neural network needs to be trained will depend on the exact analytical form of the influence function.

Remarkably, a generic approximation technique is available if the functional $J$ is convex. In this case, the influence function possesses a variational characterization in terms of the convex conjugate of $J$. To apply this formalism, we now view $\mathcal{P}(X)$ as a convex subset of the vector space of finite signed Borel measures $\mathcal{M}(X)$, equipped with the topology of weak convergence. Crucial to the analysis will be its dual space $C(X)$, the space of continuous functions $X \to \mathbb{R}$. Finally, $\bar{\mathbb{R}}$ denotes the extended real line $\mathbb{R} \cup \{-\infty, \infty\}$. The convex conjugate is then defined as follows:

Definition 3.

Let $J : \mathcal{M}(X) \to \bar{\mathbb{R}}$ be a function. Its convex conjugate is the function $J^* : C(X) \to \bar{\mathbb{R}}$ defined by

$$J^*(\varphi) = \sup_{\mu \in \mathcal{M}(X)} \left[ \int_X \varphi(x)\, d\mu(x) - J(\mu) \right].$$

Note that $J$ must now be defined on all of $\mathcal{M}(X)$; it is always possible to simply define $J(\mu) = \infty$ if $\mu \notin \mathcal{P}(X)$, although sometimes a different extension may be more convenient. The convex conjugate forms the core of the following representation for the influence function:

Proposition (Fenchel–Moreau representation). Let $J : \mathcal{M}(X) \to \bar{\mathbb{R}}$ be proper, convex, and lower semicontinuous. Then the maximizer $\varphi^\star$ of $\varphi \mapsto \int_X \varphi\, d\mu - J^*(\varphi)$, if it exists, is an influence function for $J$ at $\mu$. With some abuse of notation, we have that

$$J'_\mu = \operatorname*{arg\,max}_{\varphi \in C(X)} \left[ \int_X \varphi(x)\, d\mu(x) - J^*(\varphi) \right].$$

Proof.

We will exploit the Fenchel–Moreau theorem, which applies in the setting of locally convex, Hausdorff topological vector spaces (see e.g. Zalinescu (2002)). The space we consider is $\mathcal{M}(X)$, the space of signed, finite measures equipped with the topology of weak convergence, of which $\mathcal{P}(X)$ is a convex subset. $\mathcal{M}(X)$ is indeed locally convex and Hausdorff, and its dual space is $C(X)$ (see e.g. Aliprantis & Border (2006), section 5.14).

We now show that a maximizer $\varphi^\star$ is an influence function. By the Fenchel–Moreau theorem, $J = J^{**}$, and

$$J(\mu) = \sup_{\varphi \in C(X)} \left[ \int_X \varphi(x)\, d\mu(x) - J^*(\varphi) \right].$$

Because $J$ is differentiable, $\varepsilon \mapsto J(\mu + \varepsilon \chi)$ is differentiable, so by the envelope theorem (Milgrom & Segal, 2002),

$$dJ_\mu(\chi) = \frac{d}{d\varepsilon} \left[ \int_X \varphi^\star\, d(\mu + \varepsilon \chi) - J^*(\varphi^\star) \right] \bigg|_{\varepsilon = 0} = \int_X \varphi^\star(x)\, d\chi(x),$$

so that $\varphi^\star$ is an influence function.

The abuse of notation stems from the fact that not all influence functions are maximizers. This is true, though, if $J'_\mu \in C(X)$: for any $\nu \in \mathcal{M}(X)$,

$$J(\nu) \ge J(\mu) + \int_X J'_\mu(x)\, d(\nu - \mu)(x),$$

since the convex function $J$ lies above its tangent line, so that

$$J^*(J'_\mu) = \sup_{\nu} \left[ \int_X J'_\mu\, d\nu - J(\nu) \right] \le \int_X J'_\mu\, d\mu - J(\mu).$$

Since the reverse inequality holds by the definition of the convex conjugate, we have that $J'_\mu$ attains the supremum in the Fenchel–Moreau representation. ∎

The Fenchel–Moreau representation motivates the following influence function approximation strategy: model $\varphi$ with a neural network and train it using stochastic gradient ascent on the objective $\mathbb{E}_{x \sim \mu}[\varphi(x)] - J^*(\varphi)$. The trained neural network is then an approximation to $J'_\mu$ suitable for use in the descent step of PFD. Under this approximation scheme, PFD can be concisely expressed as the saddle-point problem

$$\inf_{\mu \in \mathcal{P}(X)}\; \sup_{\varphi \in C(X)}\; \Big[ \mathbb{E}_{x \sim \mu}[\varphi(x)] - J^*(\varphi) \Big],$$

where the inner supremum solves for the influence function (the differentiation step of PFD), and the outer infimum descends the linear approximation $\hat J_\mu$ (the descent step of PFD), noting that $J^*(\varphi)$ is a constant w.r.t. $\mu$. This procedure is highly reminiscent of adversarial training (Goodfellow et al., 2014); for this reason, we call PFD with this approximation scheme based on convex duality adversarial PFD. PFD therefore explains the prevalence of adversarial training as a deep learning technique and extends its applicability to any convex probability functional.
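A minimal training-loop sketch of adversarial PFD (our illustration, not the paper's reference implementation): a generator network parameterizes $\mu_\theta$, a second network plays the role of $\varphi$, and the two steps of Algorithm 1 alternate. The conjugate term below is the Donsker–Varadhan conjugate of $\mathrm{KL}(\mu \,\|\, \text{data})$, chosen only as a concrete stand-in; the architectures and data are likewise placeholders.

```python
import torch
import torch.nn as nn

def J_star(phi, x_ref):
    # Placeholder convex conjugate; its form depends on the functional J being minimized.
    # Here: log E_data[exp(phi)], the Donsker-Varadhan conjugate of KL(mu || data).
    return phi(x_ref).exp().mean().log()

data = torch.randn(1024, 2) + torch.tensor([3.0, 0.0])                # hypothetical dataset
g = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 2))      # generator for mu_theta
phi = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))    # influence function model
opt_g = torch.optim.Adam(g.parameters(), lr=1e-4)
opt_phi = torch.optim.Adam(phi.parameters(), lr=1e-4)

for step in range(1000):
    z = torch.randn(128, 8)
    x_ref = data[torch.randint(len(data), (128,))]

    # Differentiation step: ascend E_{x~mu}[phi(x)] - J*(phi) over phi.
    x_fake = g(z).detach()
    phi_loss = -(phi(x_fake).mean() - J_star(phi, x_ref))
    opt_phi.zero_grad(); phi_loss.backward(); opt_phi.step()

    # Descent step: descend the linear approximation E_{x~mu_theta}[phi(x)] over theta.
    g_loss = phi(g(z)).mean()
    opt_g.zero_grad(); g_loss.backward(); opt_g.step()
```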

In the following sections, we demonstrate that PFD provides a broad conceptual framework that describes a wide range of existing machine learning algorithms.

The following lemma will come in handy in our computations.

Lemma. Suppose $J$ has a representation

$$J(\mu) = \sup_{\varphi \in C(X)} \left[ \int_X \varphi(x)\, d\mu(x) - G(\varphi) \right],$$

where $G : C(X) \to \bar{\mathbb{R}}$ is proper, convex, and lower semicontinuous. Then $J^* = G$.

Proof.

By definition of the convex conjugate, $J = G^*$. Then $J^* = G^{**} = G$, by the Fenchel–Moreau theorem. ∎

We note that when applying this lemma, we will often implicitly define the appropriate extension of $J$ to $\mathcal{M}(X)$ to be $G^*$. The exact choice of extension can certainly affect the exact form of the convex conjugate; see Ruderman et al. (2012) for one example of this phenomenon.

5 Generative Adversarial Networks

Generative adversarial networks (GANs) are a technique to train a parameterized probability measure $\mu_\theta$, the generator, to mimic a data distribution $\nu$. There are many variants of the GAN algorithm. They typically take the form of a saddle-point problem, and it is known that many of them correspond to the minimization of different divergences $J(\mu) = D(\mu, \nu)$. We complete the picture by showing that many GAN variants could have been derived as instances of PFD applied to different divergences.

5.1 Minimax GAN

Goodfellow et al. (2014) originally proposed the following saddle-point problem:

$$\min_{\mu}\; \max_{D}\; \mathbb{E}_{x \sim \nu}[\log D(x)] + \mathbb{E}_{x \sim \mu}[\log(1 - D(x))].$$

The interpretation of this minimax GAN problem is that the discriminator $D : X \to (0, 1)$ learns to classify between fake samples from $\mu$ and real samples from $\nu$ via a binary classification loss, while the generator is trained to produce counterfeit samples that fool the classifier. It was shown that the value of the inner optimization problem equals $2\,\mathrm{JSD}(\mu, \nu) - \log 4$, where $\mathrm{JSD}$ is the Jensen–Shannon divergence, and therefore the problem corresponds to training $\mu$ to minimize the divergence between $\mu$ and $\nu$. As a practical algorithm, simultaneous stochastic gradient descent steps are performed on the discriminator's parameters $w$ and the generator's parameters $\theta$ using the two loss functions

$$L_D(w) = -\mathbb{E}_{x \sim \nu}[\log D_w(x)] - \mathbb{E}_{x \sim \mu_\theta}[\log(1 - D_w(x))], \qquad L_G(\theta) = \mathbb{E}_{x \sim \mu_\theta}[\log(1 - D_w(x))], \tag{1}$$

where $D_w$ and $\mu_\theta$ are parameterized with neural networks.
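A minimal PyTorch sketch of the simultaneous updates in (1) (our illustration; the architectures and two-dimensional data are placeholder choices):

```python
import torch
import torch.nn as nn

G = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 2))               # generator mu_theta
D = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid())  # discriminator D_w
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4)
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4)
real_data = torch.randn(1024, 2) * 0.5 + 2.0                                     # hypothetical dataset
eps = 1e-7                                                                        # numerical safety

for step in range(1000):
    real = real_data[torch.randint(len(real_data), (64,))]
    fake = G(torch.randn(64, 16))

    # Discriminator step: descend L_D, the binary classification loss (differentiation step).
    d_loss = -(torch.log(D(real) + eps).mean() + torch.log(1 - D(fake.detach()) + eps).mean())
    opt_d.zero_grad(); d_loss.backward(); opt_d.step()

    # Generator step: descend L_G, the minimax (saturating) loss (descent step).
    g_loss = torch.log(1 - D(G(torch.randn(64, 16))) + eps).mean()
    opt_g.zero_grad(); g_loss.backward(); opt_g.step()
```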

Our unifying result is the following: adversarial PFD on the Jensen–Shannon divergence objective $J(\mu) = \mathrm{JSD}(\mu, \nu)$ yields the minimax GAN algorithm (1). That is, the minimax GAN could have been derived mechanically and from first principles as an instance of adversarial PFD. To build intuition, we note that the discriminator plays the role of the approximate influence function: suppose $\mu$ has density $p$ and $\nu$ has density $q$. Then the influence function for $J(\mu) = \mathrm{JSD}(\mu, \nu)$ is, up to an additive constant,

$$J'_\mu(x) = \frac{1}{2} \log \frac{p(x)}{p(x) + q(x)}.$$

Proof.

Write $\mathrm{JSD}(\mu, \nu) = \frac{1}{2} \int p \log \frac{2p}{p + q}\, dx + \frac{1}{2} \int q \log \frac{2q}{p + q}\, dx$ and differentiate the integrand with respect to $p(x)$: the terms $\frac{1}{2}\frac{q}{p+q}$ and $-\frac{1}{2}\frac{q}{p+q}$ cancel, leaving $\frac{1}{2} \log \frac{2p(x)}{p(x) + q(x)}$, which is the stated influence function up to an additive constant. ∎

Recall that in the minimax GAN, the optimal discriminator satisfies $D^\star(x) = \frac{q(x)}{p(x) + q(x)}$, so that $J'_\mu(x) = \frac{1}{2} \log(1 - D^\star(x))$ up to a constant, and the influence function is approximated using the learned discriminator.

Now, we rederive the minimax GAN problem (1) as a form of adversarial PFD. We compute:

Lemma. The convex conjugate of $J(\mu) = \mathrm{JSD}(\mu, \nu)$ is

$$J^*(\varphi) = -\frac{1}{2}\, \mathbb{E}_{x \sim \nu}\big[\log(2 - e^{2\varphi(x)})\big].$$

Proof.

Writing $J^*(\varphi) = \sup_p \int \big[\varphi\, p - \tfrac{1}{2} p \log\tfrac{2p}{p+q} - \tfrac{1}{2} q \log\tfrac{2q}{p+q}\big]\, dx$ and setting the integrand's derivative w.r.t. $p(x)$ to $0$, we find that pointwise, the optimal $p$ satisfies

$$\varphi(x) = \frac{1}{2} \log \frac{2p(x)}{p(x) + q(x)}.$$

We eliminate $p$ in the integrand. Notice that the first two terms in the integrand cancel after plugging in $\varphi$. Since

$$\frac{2q(x)}{p(x) + q(x)} = 2 - e^{2\varphi(x)},$$

we obtain that

$$J^*(\varphi) = -\frac{1}{2} \int q(x) \log\big(2 - e^{2\varphi(x)}\big)\, dx = -\frac{1}{2}\, \mathbb{E}_{x \sim \nu}\big[\log(2 - e^{2\varphi(x)})\big]. \;\; \blacksquare$$

The Fenchel–Moreau representation then yields

$$J'_\mu = \operatorname*{arg\,max}_{\varphi}\; \mathbb{E}_{x \sim \mu}[\varphi(x)] + \frac{1}{2}\, \mathbb{E}_{x \sim \nu}\big[\log(2 - e^{2\varphi(x)})\big],$$

an ascent step on which is the $w$-step in (1) with the substitution $\varphi(x) = \frac{1}{2} \log\big(2(1 - D_w(x))\big)$. The descent step corresponds to updating $\mu_\theta$ to decrease the linear approximation $\mathbb{E}_{x \sim \mu_\theta}[\varphi(x)]$, which corresponds to the $\theta$-step in (1). In fact, a similar argument can be applied to the $f$-GANs of Nowozin et al. (2016), which generalize the minimax GAN. The observation that $f$-GANs (and hence the minimax GAN) can be derived through convex duality was also noted by Farnia & Tse (2018).

5.2 Non-saturating GAN

Goodfellow et al. (2014) also proposed an alternative to (1) called the non-saturating GAN, which prescribes descent steps on

$$L_D(w) = -\mathbb{E}_{x \sim \nu}[\log D_w(x)] - \mathbb{E}_{x \sim \mu_\theta}[\log(1 - D_w(x))], \qquad L_G(\theta) = -\mathbb{E}_{x \sim \mu_\theta}[\log D_w(x)].$$

In the step on the generator's parameters $\theta$, the $\log(1 - D_w(x))$ in the minimax GAN has been replaced with $-\log D_w(x)$. This heuristic change prevents gradients to $\theta$ from converging to $0$ when the discriminator is too confident, and it is for this reason that the loss for $\theta$ is called the non-saturating loss.
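A short sketch contrasting the two generator losses at the level of discriminator logits (the discriminator update is unchanged); this is an illustration we add here, not code from the paper.

```python
import torch
import torch.nn.functional as F

def generator_loss(fake_logits, non_saturating=True):
    """fake_logits: discriminator pre-sigmoid outputs on generated samples."""
    if non_saturating:
        return F.softplus(-fake_logits).mean()   # -E[log D(G(z))]
    return -F.softplus(fake_logits).mean()       # E[log(1 - D(G(z)))]

# When the discriminator confidently rejects the fakes (very negative logits), the
# saturating loss has vanishing gradients while the non-saturating loss does not.
fake_logits = torch.full((4,), -8.0, requires_grad=True)
for ns in (False, True):
    grad, = torch.autograd.grad(generator_loss(fake_logits, ns), fake_logits)
    print("non-saturating" if ns else "saturating", grad.abs().mean().item())
```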

We consider a slightly modified problem, in which the original minimax loss and the non-saturating loss are summed (and scaled by a factor of $\frac{1}{2}$):

$$L_D(w) = -\mathbb{E}_{x \sim \nu}[\log D_w(x)] - \mathbb{E}_{x \sim \mu_\theta}[\log(1 - D_w(x))], \qquad L_G(\theta) = \frac{1}{2}\, \mathbb{E}_{x \sim \mu_\theta}\big[\log(1 - D_w(x)) - \log D_w(x)\big]. \tag{2}$$

This also prevents gradients to $\theta$ from saturating, achieving the same goal as the non-saturating GAN. Huszar (2016) and Arjovsky & Bottou (2017) recognize that this process minimizes the reverse Kullback–Leibler divergence $\mathrm{KL}(\mu_\theta \,\|\, \nu)$.² We remark that this result can be seen as a corollary of the chain rule (section 3) and the influence function computed below.

² The derivation of Huszar (2016) omits showing that the dependence of the discriminator on $\theta$ can be ignored, but the result is proved by Theorem 2.5 of Arjovsky & Bottou (2017).

We claim the following: PFD on the reverse Kullback–Leibler divergence objective $J(\mu) = \mathrm{KL}(\mu \,\|\, \nu)$, using the binary classification likelihood ratio estimator to approximate the influence function, yields the modified non-saturating GAN optimization problem (2).

Suppose $\mu$ has density $p$ and $\nu$ has density $q$. The influence function for $J(\mu) = \mathrm{KL}(\mu \,\|\, \nu)$ is, up to an additive constant,

$$J'_\mu(x) = \log \frac{p(x)}{q(x)}.$$

Proof.

Differentiating the integrand of $\mathrm{KL}(\mu \,\|\, \nu) = \int p \log \frac{p}{q}\, dx$ with respect to $p(x)$ gives $\log \frac{p(x)}{q(x)} + 1$, which is the stated influence function up to an additive constant. ∎

Now, because the binary classification loss

$$-\mathbb{E}_{x \sim \nu}[\log D_w(x)] - \mathbb{E}_{x \sim \mu_\theta}[\log(1 - D_w(x))] \tag{3}$$

is minimized by $D^\star(x) = \frac{q(x)}{p(x) + q(x)}$, one estimator for $\log \frac{p(x)}{q(x)}$ is simply

$$\log \frac{p(x)}{q(x)} \approx \log \frac{1 - D_w(x)}{D_w(x)},$$

where $D_w$ is updated as in the $w$-step in (2). With this approximation scheme, the differentiation step and the descent step in PFD correspond to the $w$-step and $\theta$-step respectively in (2). Once again, the discriminator serves to approximate the influence function.

5.3 Wasserstein GAN

Arjovsky et al. (2017) propose solving the following saddle-point problem

$$\min_{\mu}\; \sup_{\|f\|_L \le 1}\; \mathbb{E}_{x \sim \mu}[f(x)] - \mathbb{E}_{x \sim \nu}[f(x)],$$

where $\|f\|_L$ denotes the Lipschitz constant of $f$. The corresponding practical algorithm amounts to simultaneous descent steps on

$$L_f(w) = \mathbb{E}_{x \sim \nu}[f_w(x)] - \mathbb{E}_{x \sim \mu_\theta}[f_w(x)], \qquad L_G(\theta) = \mathbb{E}_{x \sim \mu_\theta}[f_w(x)], \tag{4}$$

where $f_w$ is reprojected back to the space of $1$-Lipschitz functions after each $w$-step. Here, $\mu_\theta$ is again the generator, and $f_w$ is the discriminator, sometimes called the critic. This algorithm is called the Wasserstein GAN algorithm, so named because this algorithm approximately minimizes the $1$-Wasserstein distance $W_1(\mu, \nu)$; the motivation for the $w$-step in (4) is so that the discriminator learns the Kantorovich potential that describes the optimal transport from $\mu$ to $\nu$. See e.g. Villani (2008) for the full optimal transport details.
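A sketch of the simultaneous updates in (4), written in the sign convention used above (so that the critic directly estimates the influence function); the networks, data, clipping constant, and number of critic steps are placeholder choices, with weight clipping standing in for the Lipschitz projection as in the original Wasserstein GAN.

```python
import torch
import torch.nn as nn

G = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 2))   # generator mu_theta
f = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))    # critic f_w (no sigmoid)
opt_f = torch.optim.RMSprop(f.parameters(), lr=5e-5)
opt_g = torch.optim.RMSprop(G.parameters(), lr=5e-5)
real_data = torch.randn(1024, 2) + 2.0                               # hypothetical dataset
clip = 0.01                                                          # weight-clipping range

for step in range(1000):
    for _ in range(5):                                               # several critic steps per generator step
        real = real_data[torch.randint(len(real_data), (64,))]
        fake = G(torch.randn(64, 16)).detach()
        # Critic step: descend L_f = E_nu[f] - E_mu[f], so f approximates the Kantorovich
        # potential, i.e. the influence function of W1(mu, nu).
        f_loss = f(real).mean() - f(fake).mean()
        opt_f.zero_grad(); f_loss.backward(); opt_f.step()
        with torch.no_grad():                                        # crude projection toward 1-Lipschitz
            for p in f.parameters():
                p.clamp_(-clip, clip)

    # Generator step: descend the linear approximation E_{mu_theta}[f].
    g_loss = f(G(torch.randn(64, 16))).mean()
    opt_g.zero_grad(); g_loss.backward(); opt_g.step()
```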

We claim that the Wasserstein GAN too is an instance of PFD, and once again, the discriminator plays the role of approximate influence function: adversarial PFD on the Wasserstein distance objective $J(\mu) = W_1(\mu, \nu)$ yields the Wasserstein GAN algorithm (4). The influence function for $J(\mu) = W_1(\mu, \nu)$ is the Kantorovich potential corresponding to the optimal transport from $\mu$ to $\nu$.

Proof.

See Santambrogio (2015), Proposition 7.17. ∎

We remark that the gradient computation in Theorem 3 of Arjovsky et al. (2017) is a corollary of the chain rule (section 3) and the influence function above. Now, we show that the Wasserstein GAN algorithm can be derived mechanically via convex duality. The connection between the Wasserstein GAN and convex duality was also observed by Farnia & Tse (2018).

Lemma. The convex conjugate of $J(\mu) = W_1(\mu, \nu)$ is

$$J^*(\varphi) = \mathbb{E}_{x \sim \nu}[\varphi(x)] + \iota\{\|\varphi\|_L \le 1\}.$$

We use the notation $\iota\{A\}$ to denote the convex indicator function, which is $0$ if $A$ is true and $\infty$ if $A$ is false.

Proof.

Using Kantorovich–Rubinstein duality, we have that

$$J(\mu) = W_1(\mu, \nu) = \sup_{\|\varphi\|_L \le 1} \left[ \int_X \varphi\, d\mu - \int_X \varphi\, d\nu \right] = \sup_{\varphi \in C(X)} \left[ \int_X \varphi\, d\mu - G(\varphi) \right],$$

where we use the notation

$$G(\varphi) = \int_X \varphi\, d\nu + \iota\{\|\varphi\|_L \le 1\}.$$

By the lemma of section 4, $J^* = G$. ∎

The Fenchel–Moreau representation yields

$$J'_\mu = \operatorname*{arg\,max}_{\|\varphi\|_L \le 1} \Big[ \mathbb{E}_{x \sim \mu}[\varphi(x)] - \mathbb{E}_{x \sim \nu}[\varphi(x)] \Big].$$

The adversarial PFD differentiation step therefore corresponds exactly to the $w$-step in (4), and the PFD descent step is exactly the $\theta$-step in (4).

6 Variational Inference

In Bayesian inference, the central object is the posterior distribution

$$p(z \mid x) = \frac{p(x \mid z)\, p(z)}{\int p(x \mid z)\, p(z)\, dz},$$

where $x$ is an observed datapoint, $p(x \mid z)$ is the likelihood, and $p(z)$ is the prior. Unfortunately, the posterior is difficult to compute due to the presence of the integral. Variational inference therefore reframes this computation as an optimization problem in which a variational posterior $q(z)$ approximates the true posterior by solving

$$\min_{q}\; \mathrm{KL}\big(q(z) \,\|\, p(z \mid x)\big).$$

6.1 Black-box variational inference

This objective is not directly optimizable, due to the presence of the intractable $p(z \mid x)$ term. The tool of choice for variational inference is the evidence lower bound (ELBO), which rewrites

$$\mathrm{KL}\big(q(z) \,\|\, p(z \mid x)\big) = \log p(x) - \mathbb{E}_{z \sim q}\big[\log p(x \mid z) + \log p(z) - \log q(z)\big],$$

where the expectation is the ELBO. Because $\log p(x)$ is fixed, we may maximize the ELBO to minimize the KL divergence. The advantage of doing so is that all the terms inside the expectation are now tractable to evaluate, and thus the expectation may be approximated through Monte Carlo sampling. This leads to the following practical algorithm, namely stochastic gradient descent on the objective

$$L(\theta) = -\mathbb{E}_{z \sim q_\theta}\big[\log p(x \mid z) + \log p(z) - \log q_\theta(z)\big], \tag{5}$$

where $q_\theta$ is the parameterized variational posterior.

This is called black-box variational inference (Ranganath et al., 2014). Roeder et al. (2017) later recognized that ignoring the $\theta$-dependence of the $\log q_\theta(z)$ term in the expectation yields the same gradients in expectation; it is this variant that we consider. Our unification result is the following: PFD on the variational inference objective $J(q) = \mathrm{KL}(q(z) \,\|\, p(z \mid x))$, using exact influence functions, yields the black-box variational inference algorithm (5).

In fact, the influence function turns out to be precisely the integrand of the negative ELBO: the influence function for $J(q) = \mathrm{KL}(q(z) \,\|\, p(z \mid x))$ is, up to an additive constant,

$$J'_q(z) = \log q(z) - \log p(x \mid z) - \log p(z).$$

Proof.

Differentiating the integrand of $\int q(z) \log \frac{q(z)}{p(z \mid x)}\, dz$ with respect to $q(z)$ gives $\log q(z) - \log p(z \mid x) + 1$, which equals the stated influence function up to an additive constant by Bayes' rule. ∎

In this context, the influence function can be evaluated exactly, so the differentiation step of PFD may be performed without approximation. The descent step of PFD becomes exactly the descent step on $\theta$ of (5), where the $\theta$-dependence of the $\log q_\theta(z)$ term in the expectation is ignored. We remark that the argument of Roeder et al. (2017) that this dependence is removable can be seen as a corollary of the chain rule (section 3) and the influence function above.
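The following sketch runs black-box variational inference as in (5) on a conjugate toy model where the true posterior is known in closed form; the model, variational family, and hyperparameters are assumptions made only for this example.

```python
import torch

# Hypothetical model: prior z ~ N(0, 1), likelihood x | z ~ N(z, 1), observation x = 2.0.
# Variational posterior q_theta = N(m, exp(s)^2) with parameters theta = (m, s).
x_obs = torch.tensor(2.0)
m = torch.tensor(0.0, requires_grad=True)
s = torch.tensor(0.0, requires_grad=True)
opt = torch.optim.Adam([m, s], lr=0.05)

def log_normal(value, mean, std):
    return -0.5 * ((value - mean) / std) ** 2 - torch.log(std) - 0.5 * torch.log(torch.tensor(2 * torch.pi))

for step in range(2000):
    std = s.exp()
    z = m + std * torch.randn(256)                   # reparameterized samples from q_theta
    log_q = log_normal(z, m, std)                    # log q_theta(z)
    log_p = log_normal(z, torch.zeros(()), torch.ones(())) \
          + log_normal(x_obs, z, torch.ones(()))     # log p(z) + log p(x | z)
    loss = (log_q - log_p).mean()                    # Monte Carlo estimate of the negative ELBO (5)
    opt.zero_grad(); loss.backward(); opt.step()

# The exact posterior here is N(1.0, 0.5); the fitted q_theta should approach it.
print(m.item(), s.exp().item() ** 2)
```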

6.2 Adversarial variational Bayes

When the density function of the prior or the variational posterior is not available, adversarial variational Bayes (Mescheder et al., 2017) may be employed. Here, the intractable quantity $\log \frac{q_\theta(z)}{p(z)}$ is approximated by a neural network $T_w(z)$ through a binary classification problem, much like (3). The resulting algorithm applies simultaneous descent steps on

$$L_T(w) = -\mathbb{E}_{z \sim q_\theta}[\log \sigma(T_w(z))] - \mathbb{E}_{z \sim p}[\log(1 - \sigma(T_w(z)))], \qquad L(\theta) = \mathbb{E}_{z \sim q_\theta}\big[T_w(z) - \log p(x \mid z)\big], \tag{6}$$

where $\sigma$ is the logistic sigmoid.

This algorithm is another instance of PFD: PFD on the variational inference objective $J(q) = \mathrm{KL}(q(z) \,\|\, p(z \mid x))$, using the binary classification likelihood ratio estimator to approximate the influence function, yields adversarial variational Bayes (6).

It is easily seen that, up to an additive constant,

$$J'_q(z) = \log \frac{q_\theta(z)}{p(z)} - \log p(x \mid z) \approx T_w(z) - \log p(x \mid z).$$

Therefore, the $w$-step of (6) is the differentiation step of PFD, and the $\theta$-step of (6) is the descent step. We remark that the gradient computation in Proposition 2 of Mescheder et al. (2017) is a corollary of the chain rule (section 3) and the influence function from subsection 6.1.
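A sketch of the binary-classification density-ratio estimator that underlies (6), isolated from the rest of the algorithm: a classifier is trained to distinguish samples of $q$ from samples of $p$, and its logit then estimates $\log q(z)/p(z)$. The two Gaussians below are hypothetical choices for which the true log ratio is known.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

q = torch.distributions.Normal(1.0, 1.0)     # stand-in for the variational posterior q_theta
p = torch.distributions.Normal(0.0, 1.0)     # stand-in for the prior
T = nn.Sequential(nn.Linear(1, 64), nn.Tanh(), nn.Linear(64, 1))
opt = torch.optim.Adam(T.parameters(), lr=1e-3)

for step in range(3000):
    z_q = q.sample((256, 1))
    z_p = p.sample((256, 1))
    # Binary classification loss as in (3)/(6): label q-samples 1 and p-samples 0.
    loss = F.binary_cross_entropy_with_logits(T(z_q), torch.ones(256, 1)) \
         + F.binary_cross_entropy_with_logits(T(z_p), torch.zeros(256, 1))
    opt.zero_grad(); loss.backward(); opt.step()

# At the optimum, T(z) ~ log q(z) - log p(z) = z - 0.5 for these two Gaussians.
z = torch.linspace(-2, 2, 5).unsqueeze(1)
print(T(z).squeeze().tolist())
print((z - 0.5).squeeze().tolist())
```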

7 Reinforcement Learning

In a Markov decision process, the distribution of states $s_t$, actions $a_t$, and rewards $r_t$ is governed by the distribution

$$p(s_0) \prod_{t=0}^{\infty} \pi(a_t \mid s_t)\, p(s_{t+1}, r_t \mid s_t, a_t),$$

where $p(s_0)$ is an initial distribution over states, $p(s_{t+1}, r_t \mid s_t, a_t)$ gives the transition probability of arriving at state $s_{t+1}$ with reward $r_t$ from a state $s_t$ taking an action $a_t$, and $\pi(a \mid s)$ is a policy that gives the distribution of actions taken when in state $s$. In reinforcement learning, we are interested in learning the policy $\pi$ that maximizes the expected discounted reward $\mathbb{E}\big[\sum_{t=0}^{\infty} \gamma^t r_t\big]$, where $\gamma \in [0, 1)$ is a discount factor, while assuming we only have access to samples from $p(s_0)$ and $p(s_{t+1}, r_t \mid s_t, a_t)$.
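For concreteness, the expected discounted reward of a fixed policy can be estimated by Monte Carlo rollouts; the tiny two-state MDP below (transitions, rewards, policy, and discount) is entirely made up for the illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
P = np.array([[[0.9, 0.1], [0.2, 0.8]],        # P[s, a, s'] transition probabilities
              [[0.6, 0.4], [0.1, 0.9]]])
R = np.array([[0.0, 1.0],                       # R[s, a] expected reward
              [0.5, 2.0]])
pi = np.array([[0.5, 0.5],                      # pi[s, a] action probabilities
               [0.2, 0.8]])
gamma, horizon, n_rollouts = 0.9, 200, 2000

returns = []
for _ in range(n_rollouts):
    s, g = 0, 0.0                               # start in state 0
    for t in range(horizon):
        a = rng.choice(2, p=pi[s])
        g += gamma ** t * R[s, a]
        s = rng.choice(2, p=P[s, a])
    returns.append(g)

print("estimated expected discounted reward:", np.mean(returns))
```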

7.1 Policy iteration

Policy iteration (Howard, 1960; Sutton & Barto, 1998) is one scheme that solves the reinforcement learning problem. It initializes $\pi$ arbitrarily and then cycles between two steps, policy evaluation and policy improvement. In the policy evaluation step, the state-action value function $Q^\pi(s, a)$ is computed. In the policy improvement step, the policy is updated to the greedy policy, the policy that at state $s$ takes the action $\arg\max_a Q^\pi(s, a)$ with probability $1$.
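A tabular sketch of the two alternating steps, policy evaluation by solving the linear Bellman system and policy improvement by acting greedily; the MDP is the same hypothetical two-state example used in the previous sketch.

```python
import numpy as np

P = np.array([[[0.9, 0.1], [0.2, 0.8]],
              [[0.6, 0.4], [0.1, 0.9]]])        # P[s, a, s']
R = np.array([[0.0, 1.0],
              [0.5, 2.0]])                       # R[s, a]
gamma = 0.9
n_states, n_actions = R.shape
policy = np.zeros(n_states, dtype=int)           # deterministic policy: state -> action

for _ in range(100):
    # Policy evaluation: solve V = R_pi + gamma * P_pi V for the current policy.
    P_pi = P[np.arange(n_states), policy]        # transition matrix under pi
    R_pi = R[np.arange(n_states), policy]
    V = np.linalg.solve(np.eye(n_states) - gamma * P_pi, R_pi)
    Q = R + gamma * P @ V                        # Q[s, a] = R[s, a] + gamma * sum_s' P[s, a, s'] V[s']
    # Policy improvement: switch to the greedy policy.
    new_policy = Q.argmax(axis=1)
    if np.array_equal(new_policy, policy):
        break
    policy = new_policy

print("greedy policy:", policy, "state values:", V)
```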

Before we present our unification result, we introduce an arbitrary distribution $\rho(s)$ over states and consider the joint distribution $\mu(s, a) = \rho(s)\, \pi(a \mid s)$, so that $\mu$ is one probability distribution rather than one for every state $s$. Now: PFD on the reinforcement learning objective

$$J(\mu) = -\mathbb{E}\Big[\sum_{t=0}^{\infty} \gamma^t r_t\Big],$$

the negative expected discounted reward under the policy $\pi(a \mid s) = \mu(a \mid s)$, using exact influence functions and global minimization of the linear approximation, yields the policy iteration algorithm.


The influence function for $J(\mu) = -\mathbb{E}\big[\sum_{t=0}^{\infty} \gamma^t r_t\big]$ is

$$J'_\mu(s, a) = -\frac{\sum_{t=0}^{\infty} \gamma^t\, p_t(s)}{\rho(s)} \big(Q^\pi(s, a) - V^\pi(s)\big),$$

where $Q^\pi(s, a)$ is the state-action value function, $V^\pi(s)$ is the state value function, and $p_t(s)$ is the marginal distribution of states after $t$ steps, all under the policy $\pi$.

Proof.

First, we note that $\pi(a \mid s) = \mu(s, a) / \mu(s)$, where we abuse notation to denote by $\mu(s) = \int \mu(s, a)\, da$ the state marginal of $\mu$, which equals $\rho(s)$.

We have

$$J(\mu) = -\mathbb{E}\Big[\sum_{t=0}^{\infty} \gamma^t r_t\Big],$$

or, plugging in the measure,

$$J(\mu) = -\sum_{t=0}^{\infty} \gamma^t \int r_t\; p(s_0) \prod_{k=0}^{\infty} \pi(a_k \mid s_k)\, p(s_{k+1}, r_k \mid s_k, a_k).$$

The integral is over all free variables; we omit them here and in the following derivation for conciseness.

In computing $dJ_\mu(\chi)$, the product rule dictates that a term appear for every $k$, in which $\pi(a_k \mid s_k)$ is replaced with its perturbation $u(a_k \mid s_k) = \frac{d}{d\varepsilon}\, \frac{(\mu + \varepsilon\chi)(s_k, a_k)}{(\mu + \varepsilon\chi)(s_k)} \big|_{\varepsilon = 0}$. Hence:

$$dJ_\mu(\chi) = -\sum_{k=0}^{\infty} \sum_{t=0}^{\infty} \gamma^t \int r_t\; p(s_0)\, u(a_k \mid s_k) \prod_{j \ne k} \pi(a_j \mid s_j) \prod_{j} p(s_{j+1}, r_j \mid s_j, a_j),$$

reordering the summations. Note that for $k > t$, the summand vanishes:

$$\int u(a_k \mid s_k)\, da_k = \frac{d}{d\varepsilon} \int \frac{(\mu + \varepsilon\chi)(s_k, a_k)}{(\mu + \varepsilon\chi)(s_k)}\, da_k \bigg|_{\varepsilon = 0} = \frac{d}{d\varepsilon}\, 1 = 0,$$

since all the variables with index greater than $t$ integrate away to $1$ except this factor, which integrates to $0$. This yields: