SA vs SAA for population Wasserstein barycenter calculation

01/21/2020 ∙ by Darina Dvinskikh, et al. ∙ Weierstrass Institute

In the Machine Learning and Optimization community there are two main approaches to the convex risk minimization problem. The first is Stochastic Approximation (SA) (online) and the second is Sample Average Approximation (SAA) (Monte Carlo, Empirical Risk Minimization, offline), with proper regularization in the non-strongly convex case. It is known that the two approaches are on average equivalent (up to a logarithmic factor) in terms of oracle complexity (the required number of stochastic gradient evaluations). What is the situation with total complexity? The answer depends on the specific problem. However, starting from the work [Nemirovski et al. (2009)] it was generally accepted that SA is better than SAA. Nevertheless, for large-scale problems SA may run into memory problems, since storing all data on one machine and organizing online access to it can be impossible without communication with other machines. SAA, in contradistinction to SA, allows parallel/distributed calculations. In this paper we show that SAA may outperform SA in the problem of calculating an estimate of the population (μ-entropy regularized) Wasserstein barycenter, even in the non-parallel (non-decentralized) setup.


1 Introduction

We consider the problem of finding a mean of probability measures. Observing a sample of probability measures from some probabilistic family, we are interested in their population mean. Since the geometric structure of the problem is not Euclidean, we turn to the optimal transport (OT) problem, formulated long ago by G. Monge [Monge (1781)] and later refined by L. Kantorovich [Kantorovich (1942)]. Nowadays, OT is a popular framework used in clustering [Ho et al. (2017)], text classification [Kusner et al. (2015)], image retrieval [Rubner et al. (2000)], computer vision [Ni et al. (2009)], economics and finance [Beiglböck et al. (2013); Rachev et al. (2011)]. Inspired by the OT problem, a distance function was introduced which is usually called the Wasserstein distance, or the Monge–Kantorovich distance. This distance measures how much one object differs from another, even for non-linear objects such as probability measures or histograms. To define the notion of barycenter we consider the 2-Wasserstein distance between two probability measures \(p\) and \(q\) supported on a complete metric space \(\mathcal{X}\) with metric \(d\):

\[ W_2^2(p, q) \;=\; \inf_{\pi \in \Pi(p, q)} \int_{\mathcal{X} \times \mathcal{X}} d(x, y)^2 \, d\pi(x, y), \tag{1} \]

where \(\Pi(p, q)\) is the set of probability measures on the product space \(\mathcal{X} \times \mathcal{X}\) with respective marginals \(p\) and \(q\). In the Wasserstein space of all probability measures with metric \(W_2\) we define the Fréchet mean, a generalization of the usual mean to non-linear spaces [Fréchet (1948)]:

\[ \bar p \;=\; \arg\min_{p} \; \mathbb{E}_{q \sim \mathbb{P}} \, W_2^2(p, q), \tag{2} \]

where \(W_2\) is the 2-Wasserstein distance and \(\mathbb{P}\) is the distribution from which the measures are sampled. We refer to the minimizer in (2), which is also a probability measure, as the population (Fréchet) mean. In this work, we consider the specific case where the measures are discrete with finite support of size \(n\), which reduces problem (1) to a finite-dimensional linear program solvable in \(\widetilde{O}(n^3)\) arithmetic operations [Tarjan (1997)]. Approximation of a probability measure by a measure with finite support was studied in [Genevay et al. (2018); Mena and Weed (2019); Panaretos and Zemel (2019); Weed et al. (2019)]. To tackle the computational complexity of solving this linear program, entropic regularization was proposed [Cuturi (2013)]. It helps to reduce the computational complexity to \(\widetilde{O}(n^2/\varepsilon^2)\); the best theoretically known estimate for solving the OT problem is \(\widetilde{O}(n^2/\varepsilon)\) [Blanchet et al. (2018); Jambulapati et al. (2019); Lee and Sidford (2014); Quanrud (2018)]. Moreover, entropic regularization improves the statistical properties of the Wasserstein distance itself [Klatt et al. (2018); Bigot et al. (2019)]. This regularization shows good results in generative models [Genevay et al. (2017)], multi-label learning [Frogner et al. (2015)], dictionary learning [Rolet et al. (2016)], image processing [Cuturi and Peyré (2016); Rabin and Papadakis (2015)], and neural imaging [Gramfort et al. (2015)]. A nice survey of OT and Wasserstein barycenters is presented in [Peyré and Cuturi (2018)]. In this paper we aim at constructing an \(\varepsilon\)-approximation of the population Wasserstein barycenter. To do so, we estimate the number of sampled measures needed to get \(\varepsilon\)-precision in function value. We consider online and offline algorithms, providing a comparative study of their convergence rates. Both types of approaches (online and offline) have pros and cons, depending on the specifics of the problem covered by this paper. Generally, starting from the work [Nemirovski et al. (2009)], SA (Stochastic Approximation) was considered to be better than SAA (Sample Average Approximation). On the example of the population Wasserstein barycenter problem we demonstrate the superiority of SAA over SA for definite values of the regularization parameter. The main reason for this is the observation that the gradient of the dual function can be computed with complexity significantly smaller than that of the primal one. We emphasize that the transition to the dual function is possible only in the SAA approach. Furthermore, in our paper we study a

confidence interval for the population Wasserstein barycenter defined w.r.t. entropy-regularized OT. Our choice of entropic regularization is due to the fact that it ensures strong convexity of OT, which allows us to state convergence in argument.
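To make the linear-programming reduction mentioned above concrete, here is a minimal sketch (our own illustration, not code from the paper) that solves the discrete OT problem between two histograms with an off-the-shelf LP solver; all names (`W2_squared`, the random supports) are ours, and the dense LP is only practical for small \(n\).

```python
# Hypothetical illustration: discrete OT between two measures with support
# size n, solved as a linear program over the transportation polytope.
import numpy as np
from scipy.optimize import linprog

rng = np.random.default_rng(0)
n = 5
x = rng.normal(size=(n, 1))          # support points of p
y = rng.normal(size=(n, 1))          # support points of q
C = (x - y.T) ** 2                   # cost matrix, C[i, j] = d(x_i, y_j)^2

p = np.full(n, 1.0 / n)              # histogram of the first measure
q = np.full(n, 1.0 / n)              # histogram of the second measure

# Equality constraints: pi @ 1 = p (row sums) and pi.T @ 1 = q (column sums).
A_eq = np.zeros((2 * n, n * n))
for i in range(n):
    A_eq[i, i * n:(i + 1) * n] = 1.0     # row-sum constraints
    A_eq[n + i, i::n] = 1.0              # column-sum constraints
b_eq = np.concatenate([p, q])

res = linprog(C.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None), method="highs")
W2_squared = res.fun                 # optimal transport cost = W_2^2(p, q)
print(W2_squared)
```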

1.1 Related work

Consistency of the empirical barycenter, i.e., its convergence to the population counterpart as the number of measures tends to infinity, was considered in many papers, e.g., [Le Gouic and Loubes (2015); Panaretos and Zemel (2019); Le Gouic and Loubes (2017); Bigot and Klein (2012); Rios et al. (2018)]. In [Bigot et al. (2017)] convergence of the empirical barycenter to its population counterpart in the Wasserstein metric was also studied. However, these works do not provide any rates of convergence. A rate of convergence can be found in [Boissard et al. (2015)] for the problem of template estimation. The authors provide a confidence interval for the population barycenter, approximating it by the iterated barycenter in Wasserstein space. However, they only consider probability measures obtained by template deformations with definite properties, e.g., that the expectation of a deformation from the admissible family is the identity. Our aim is to dispense with these particular conditions. Without any assumptions on the generating process for the observed measures, one can find the rate of convergence of the empirical Wasserstein barycenter towards its population counterpart in [Bigot et al. (2018)]. However, that result is only valid when the measures have one-dimensional support. Our approaches for constructing a confidence interval rely significantly on the results of [Shalev-Shwartz et al. (2009)] for stochastic convex optimization.

1.2 Contribution

We summarize our contribution as follows. To the best of our knowledge, this is the first paper that provides a confidence interval for the population Wasserstein barycenter, together with the rates of convergence needed to calculate its approximation. Our first result is that SAA achieves better rates of convergence than SA for the problem of approximating the population Wasserstein barycenter defined w.r.t. entropy-regularized OT with a proper value of the regularization parameter. We comment on this value in Section 6. Our second result is a new regularization in the SAA approach for the problem of calculating the population Wasserstein barycenter (Section 5). To the best of our knowledge, this regularization improves the convergence bounds compared to the state-of-the-art regularization from [Shalev-Shwartz et al. (2009)] for general convex functions. Finally, we show that for the problem of calculating an approximation of the population Wasserstein barycenter, stochastic mirror descent and the risk minimization approach with our new regularization give better complexity bounds than SA and SAA.

1.3 Paper organization

The structure of the paper is as follows. Sections 3 and 4 present the SA and SAA approaches, respectively, to find the confidence interval for the barycenter w.r.t. regularized OT. In Section 5 we estimate the barycenter w.r.t. OT using other algorithms with a particular structure of the regularization constant, including a zero constant. In Section 6 we compare the rates of convergence. Finally, we present concluding remarks.

2 Wasserstein distance and Wasserstein barycenter: Notations and Properties

In this paper we consider discrete probability measures from the probability simplex
\[ \Delta_n = \Big\{ p \in \mathbb{R}^n_{+} : \sum_{i=1}^n p_i = 1 \Big\}. \]
Measures \(p\) and \(q\) are discrete if they can be presented in the form \(p = \sum_{i=1}^n p_i \delta(x_i)\) and \(q = \sum_{j=1}^n q_j \delta(y_j)\), where \(\delta(x)\) is the Dirac measure at point \(x\), and \((p_1, \dots, p_n), (q_1, \dots, q_n) \in \Delta_n\) are histograms. For \(p, q \in \Delta_n\) we define the transportation polytope
\[ U(p, q) = \big\{ \pi \in \mathbb{R}^{n \times n}_{+} : \pi \mathbf{1} = p, \ \pi^\top \mathbf{1} = q \big\}. \]
We define the optimal transport (OT) problem between discrete probability measures as follows:
\[ W(p, q) = \min_{\pi \in U(p, q)} \langle C, \pi \rangle. \]
Here \(C\) is the cost matrix: \(C_{ij}\) is the cost to move a unit of mass from point \(x_i\) to point \(y_j\). When \(C_{ij} = d(x_i, y_j)^2\), where \(d\) is the distance on the support points of the probability measures, \(W(p, q)\) is known as the squared 2-Wasserstein distance on \(\Delta_n\) (we skip the sub-index 2 for simplicity). We consider entropic OT, introduced in [Peyré and Cuturi (2018)]:

\[ W_\mu(p, q) = \min_{\pi \in U(p, q)} \Big\{ \langle C, \pi \rangle + \mu \sum_{i, j = 1}^n \pi_{ij} \ln \pi_{ij} \Big\}. \tag{3} \]

For a statistical explanation of such regularization see [Rigollet and Weed (2018)]. Further, using the notion of the Fréchet mean, we define the population barycenter of probability measures w.r.t. regularized OT:
\[ \bar p_\mu = \arg\min_{p \in \Delta_n} \mathbb{E}_{q \sim \mathbb{P}}\, W_\mu(p, q). \]

We refer to
\[ \hat p_\mu = \arg\min_{p \in \Delta_n} \frac{1}{m} \sum_{k=1}^m W_\mu(p, q_k) \]
as the empirical barycenter, the empirical counterpart of \(\bar p_\mu\), computed from sampled measures \(q_1, \dots, q_m\).

If \(\mu = 0\), then \(\bar p_0\) is the population Wasserstein barycenter (i.e., w.r.t. OT) and \(\hat p_0\) is its empirical counterpart.
For our convenience we also define the following notation

We write \(\widetilde{O}(\cdot)\) to indicate complexity up to constants and logarithmic factors.
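Before turning to the dual formulation, here is a minimal NumPy sketch of the Sinkhorn algorithm for the entropy-regularized problem (3). This is our own simplified illustration (fixed common support, strictly positive histograms, a fixed iteration budget), not the exact implementation analyzed in the paper; for very small \(\mu\) a log-domain implementation would be needed to avoid underflow.

```python
# Minimal Sinkhorn sketch for the entropic OT problem (3); our own illustration.
import numpy as np

def sinkhorn(C, p, q, mu, n_iter=1000):
    """Return the transport plan and dual potentials for W_mu(p, q)."""
    K = np.exp(-C / mu)                      # Gibbs kernel
    u = np.ones_like(p)
    v = np.ones_like(q)
    for _ in range(n_iter):
        u = p / (K @ v)                      # match row marginals to p
        v = q / (K.T @ u)                    # match column marginals to q
    pi = u[:, None] * K * v[None, :]         # primal transport plan
    lam = mu * np.log(u)                     # dual potential for the p-constraint
    xi = mu * np.log(v)                      # dual potential for the q-constraint
    return pi, lam, xi

# Usage: entropic OT cost between two random histograms on a common grid.
rng = np.random.default_rng(0)
n = 50
grid = np.linspace(0.0, 1.0, n)
C = (grid[:, None] - grid[None, :]) ** 2
p = rng.dirichlet(np.ones(n))
q = rng.dirichlet(np.ones(n))
mu = 0.01
pi, lam, xi = sinkhorn(C, p, q, mu)
W_mu = (C * pi).sum() + mu * (pi * np.log(pi + 1e-300)).sum()
print(W_mu)
```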

2.1 Dual formulation and properties

The structure of \(W_\mu(p, q)\) allows us to write the Lagrangian dual function with dual variables \(\lambda\) and \(\xi\) for the constraints \(\pi \mathbf{1} = p\) and \(\pi^\top \mathbf{1} = q\), respectively:

\[ W_\mu(p, q) = \max_{\lambda, \xi \in \mathbb{R}^n} \Big\{ \langle \lambda, p \rangle + \langle \xi, q \rangle - \mu \sum_{i, j = 1}^n \exp\Big( \frac{\lambda_i + \xi_j - C_{ij}}{\mu} - 1 \Big) \Big\}. \tag{4} \]

Let \((\lambda^*, \xi^*)\) be a solution of problem (4); since the dual function is strictly convex, the corresponding quadratic growth bound holds with a constant given by the smallest eigenvalue of a positive semi-definite matrix. Note that from the theoretical point of view we do not know any accurate lower bounds on this eigenvalue better than exponentially small ones. Moreover,

\[ \nabla_p W_\mu(p, q) = \lambda^*(p, q). \tag{5} \]

We use \(\nabla_p\) to denote the gradient w.r.t. \(p\). [Properties of \(W_\mu\)] The entropy-regularized Wasserstein distance \(W_\mu(p, q)\) is

  1. \(\mu\)-strongly convex in \(p\) w.r.t. the 2-norm: for any \(p, p' \in \Delta_n\)

    \[ W_\mu(p', q) \ge W_\mu(p, q) + \langle \nabla_p W_\mu(p, q),\, p' - p \rangle + \frac{\mu}{2} \| p' - p \|_2^2; \tag{6} \]

  2. \(M\)-Lipschitz in \(p\) w.r.t. the 2-norm: for any \(p, p' \in \Delta_n\)

    \[ | W_\mu(p, q) - W_\mu(p', q) | \le M \| p - p' \|_2, \tag{7} \]

    where \(M\) is the Lipschitz constant w.r.t. the 2-norm (it can be bounded through the Lipschitz constant w.r.t. the \(\infty\)-norm).

The first statement of the proposition follows from [Kakade et al. (2009); Nesterov (2005)]. Indeed, according to [Peyré and Cuturi (2018); Nesterov (2005)] the gradient of the dual function in (4) is Lipschitz continuous in the 2-norm, and from [Kakade et al. (2009)] we may conclude that in this case \(W_\mu(p, q)\) is strongly convex in \(p\) in the 2-norm [Nesterov (2005)]. The second statement follows from [Dvurechensky et al. (2018); Kroshnin et al. (2019b); Lin et al. (2019)]. Note that this result requires an additional assumption that the considered measures are separated from zero; but, without loss of generality, the general case can be reduced to this one [Dvurechensky et al. (2018); Kroshnin et al. (2019b)]. The next two sections present the construction of a confidence interval for the population barycenter defined w.r.t. the regularized Wasserstein distance.

3 Population barycenter with respect to regularized OT. Online (SA) approach

In this section we assume that probability measures arrive in an online regime. One of the benefits of the online approach is that there is no need to fix the number of measures in advance, which allows one to regulate the precision of the calculated barycenter. Moreover, the problem of storing a large number of measures on a computing node disappears if we have access to an online oracle, e.g., some measuring device. Using online-to-batch conversion [Shalev-Shwartz et al. (2009)] we build the confidence interval for the population barycenter defined w.r.t. regularized OT and provide the complexity bounds for doing so.
The following algorithm processes an online sequence of measures by stochastic gradient descent, calling at each iteration the Sinkhorn algorithm to compute an approximation of the gradient of the entropy-regularized Wasserstein distance \(W_\mu\).

1: Starting point \(p^1 \in \Delta_n\)
2: for \(k = 1, 2, \dots, N\) do
3:     receive \(q_k\), choose step size \(\gamma_k\)
4:     \(p^{k+1} = \Pi_{\Delta_n}\big(p^k - \gamma_k \nabla_p^{\varepsilon'} W_\mu(p^k, q_k)\big)\)
5: where \(\nabla_p^{\varepsilon'} W_\mu\) is an \(\varepsilon'\)-subgradient [Polyak (1987)] calculated by the Sinkhorn algorithm, and \(\Pi_{\Delta_n}\) is the Euclidean projection onto \(\Delta_n\).
Algorithm 1 Online Stochastic Gradient Descent (SGD)

In Algorithm 1, using (5), we define the \(\varepsilon'\)-subgradient as the approximate dual potential
\[ \nabla_p^{\varepsilon'} W_\mu(p, q) = \tilde\lambda, \tag{8} \]
where \(\tilde\lambda\) is an output of the Sinkhorn algorithm [Dvurechensky et al. (2018); Peyré and Cuturi (2018)] after a suitable number of iterations. To approximate the population barycenter by the outputs of Algorithm 1 we use online-to-batch conversion [Shalev-Shwartz et al. (2009)] and define \(\hat p^N\) as the average of the online outputs of Algorithm 1:

\[ \hat p^N = \frac{1}{N} \sum_{k=1}^N p^k. \tag{9} \]
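To make Algorithm 1 and the averaging (9) concrete, here is our own sketch in NumPy. It assumes the `sinkhorn` helper from the Section 2 sketch; the step-size rule and the smoothing that keeps the iterate separated from zero are illustrative assumptions, not the paper's exact constants.

```python
# Sketch of Algorithm 1: projected online SGD over the simplex, with the
# Sinkhorn dual potential as an approximate gradient of W_mu(., q).
import numpy as np

def project_simplex(z):
    """Euclidean projection of z onto the probability simplex."""
    u = np.sort(z)[::-1]
    css = np.cumsum(u)
    rho = np.nonzero(u * np.arange(1, z.size + 1) > (css - 1.0))[0][-1]
    theta = (css[rho] - 1.0) / (rho + 1.0)
    return np.maximum(z - theta, 0.0)

def online_sgd_barycenter(sample_measure, C, mu, N, gamma0=1.0):
    """sample_measure() returns one histogram from the online oracle."""
    n = C.shape[0]
    p = np.full(n, 1.0 / n)                    # starting point p^1
    p_avg = np.zeros(n)
    for k in range(1, N + 1):
        q = sample_measure()
        # sinkhorn() is the helper from the Section 2 sketch; its dual
        # potential approximates grad_p W_mu, cf. (5) and (8).
        _, lam, _ = sinkhorn(C, p, q, mu)
        step = gamma0 / (mu * k)               # ~1/(mu k): strongly convex step rule
        p = project_simplex(p - step * lam)
        p = (1.0 - 1e-9) * p + 1e-9 / n        # keep p separated from zero (cf. Section 2 remark)
        p_avg += (p - p_avg) / k               # running average: the estimator (9)
    return p_avg
```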

Next we formulate two theorems which indicate the precision of \(\hat p^N\) as an approximation of the population barycenter in function value and in argument. With high probability, for \(\hat p^N\) from (9) the following holds

The proof follows the proof in [Kakade and Tewari (2009)], accounting for the error accumulated due to the inexactness of the gradient. [Barycenter confidence interval] With high probability, for \(\hat p^N\) from (9) the following holds

The proof follows directly from Theorem 3 and the strong convexity of \(W_\mu\). Using the (not online) algorithm from [Juditsky and Nesterov (2014)] instead of the online algorithm allows one to avoid the accumulation of the inexactness \(\varepsilon'\): the corresponding term in Theorem 3 can be replaced by \(\varepsilon'\). From Theorem 3 we can immediately derive the number of probability measures to be taken as inputs of Algorithm 1 and the precision \(\varepsilon'\) for the Sinkhorn algorithm performed at each iteration of Algorithm 1. [Number of probability measures and auxiliary precision] To get the \(\varepsilon\)-confidence region:

(10)

it suffices to take the following number of probability measures (iterations)

(11)

and the following auxiliary \(\varepsilon'\)-precision

If we use restarted SGD from [Juditsky and Nesterov (2014)] instead of Algorithm 1, then

Now we are ready to provide the complexity bounds for calculating \(\hat p^N\) with guarantee (10). The total complexity of restarted SGD from [Juditsky and Nesterov (2014)] to guarantee (10) is

(12)

The proof follows from the complexity of the Sinkhorn algorithm. To state the complexity of Sinkhorn we first define the following quantity (see also (8)):

From (4) we can conclude that this quantity is proportional to the one defined above. Using this, we state the complexity of the Sinkhorn algorithm [Kroshnin et al. (2019b); Stonyakin et al. (2019)]:

The complexity can be improved by using Accelerated Sinkhorn [Guminov et al. (2019)]:

(13)

Multiplying both of these estimates by the number of measures (iterations), taking the minimum, and using Corollary 3, we get the statement of the theorem. Suppose that after a number of iterations of the Sinkhorn algorithm we get approximate dual solutions \(\tilde\lambda\) and \(\tilde\xi\). According to [Franklin and Lorenz (1989)], the calculated variables \(\tilde\lambda\) and \(\tilde\xi\) (see (4)) converge to the true variables \(\lambda^*\) and \(\xi^*\) in the Hilbert–Birkhoff metric:

(14)

Since all norms are equivalent in finite-dimensional spaces, we can obtain the required accuracy in (8) by a proper choice of the number of Sinkhorn iterations. This number of iterations is proportional to a logarithm of the required accuracy [Franklin and Lorenz (1989)]; however, in general the theoretical constant in front of the logarithm can be too large to yield good theoretical results. Accurate calculations, nevertheless, allow one to obtain here a result like (14) with an improved constant. This means that the corresponding factor in (12) may be considered smaller than the direct bound from the definition suggests.

4 Barycenter with respect to regularized OT: Offline (SAA) approach

In this section we suppose that the measures are sampled in advance. We construct the confidence interval for the population barycenter by calculating an approximation of the empirical barycenter. Moreover, we also provide the total complexity bounds for doing so. This offline setting is relevant when we are interested in parallelization or decentralization.
We say that a point is an approximate empirical barycenter if it satisfies the following inequality for some precision \(\varepsilon'\):

(15)
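One concrete way to produce such an approximate empirical barycenter on a fixed support is sketched below. This is our own illustrative stand-in, the iterative Bregman projections (IBP) algorithm of Benamou et al. (2015), not the accelerated method of [Kroshnin et al. (2019b)] analyzed in this section; choose \(\mu\) not too small to avoid underflow (log-domain variants are used in practice).

```python
# Our own sketch: fixed-support entropic barycenter of m sampled histograms
# via iterative Bregman projections (Benamou et al., 2015).
import numpy as np

def ibp_barycenter(C, Q, mu, n_iter=500):
    """C: (n, n) cost matrix; Q: (m, n) sampled histograms; returns barycenter."""
    m, n = Q.shape
    K = np.exp(-C / mu)                      # Gibbs kernel
    U = np.ones((m, n))
    V = np.ones((m, n))
    for _ in range(n_iter):
        V = Q / (K.T @ U.T).T                # enforce the marginals q_k
        KV = (K @ V.T).T                     # rows are K v_k
        p = np.exp(np.mean(np.log(U * KV), axis=0))   # geometric mean: barycenter iterate
        U = p[None, :] / KV                  # enforce the common marginal p
    return p / p.sum()
```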

Now suppose that we have somehow found such a point. The following theorem estimates its precision as an approximation of the population barycenter in function value. For the point from (15), with high probability, the following holds

Consider, for any \(p \in \Delta_n\), the following difference

(16)

From Theorem 6 of [Shalev-Shwartz et al. (2009)], with high probability, the following holds for the empirical minimizer

(17)

Then from (16) and (17) we have

(18)

From the Lipschitz continuity of \(W_\mu\) we have

(19)

From the strong convexity of \(W_\mu\) we get

(20)

By using (19) and (20) in (17) and choosing the parameters appropriately, we get the first inequality of the theorem.
Then, using the strong convexity of the objective, we formulate the result on convergence in argument. [Barycenter confidence interval] For the point from (15), with high probability,

The proof consists in applying the strong convexity of \(W_\mu\) to Theorem 4. From this theorem we estimate the number of measures and the auxiliary precision of the fidelity term (15) needed to get an \(\varepsilon\)-confidence interval with high probability. [Number of probability measures and auxiliary precision] To get the \(\varepsilon\)-confidence region:

(21)

it suffices to take the following number of probability measures

(22)

and the auxiliary \(\varepsilon'\)-precision in (15):

(23)

The next theorem estimates the complexity of calculating the approximation of the population barycenter. The total complexity per node of the offline algorithm from [Kroshnin et al. (2019b)] is

(24)

where the parameter of the architecture is given by

Moreover, for a one-node architecture (without parallelization) the complexity simplifies to

(25)

To calculate the total complexity we refer to Algorithm 6 of [Kroshnin et al. (2019b)]. For the reader's convenience we repeat the scheme of the proof. This algorithm belongs to the class of Fast Gradient Methods for Lipschitz-smooth functions and, consequently, has the standard accelerated complexity [Nesterov (2018)], determined by the Lipschitz constant of the dual function from (4) (see the proof of Proposition 5) and the radius of the dual solution (Lemma 8 of [Kroshnin et al. (2019b)]). Combining all of this we get the following number of iterations

(26)

where the architecture parameter is as above. Multiplying by the complexity of calculating the gradient of the dual function and using Corollary 4 for the definition of \(\varepsilon'\), we get the following complexity per node

Using Corollary 4 for the number of measures we get the first statement of the theorem. Specializing to the one-machine architecture we get the second statement and finish the proof. From the recent results of [Feldman and Vondrák (2019)] we may expect that the dependence on the confidence level in (22) and (25) is in fact much better (logarithmic). But, unfortunately, as far as we know, in the general case this is still a hypothesis. In the next section we construct a confidence interval for the population Wasserstein barycenter.

5 Population Wasserstein barycenter problem

In the previous sections we aimed at constructing the confidence interval for the population barycenter defined w.r.t. regularized OT. Now we drop the regularization of OT and seek the population barycenter w.r.t. OT itself. To do so, we first use the results of Sections 3 and 4 (the SA and SAA algorithms), and then we consider two other methods, one of which is based on our new regularization. Since OT is not strongly convex, in this section we construct a confidence interval in function value.
Throughout this section we use the following notation

5.1 SA and SAA

We start by applying the results obtained by SA in Section 3. We regularize OT by entropy with a definite regularization parameter \(\mu\) chosen according to the desired accuracy \(\varepsilon\) in function value [Dvurechensky et al. (2018); Peyré and Cuturi (2018); Weed (2018)]:

Then we use the following statement from [Gasnikov et al. (2015); Kroshnin et al. (2019b); Peyré and Cuturi (2018)]

Taking the regularization parameter in this way, we ensure the following:

The last inequality allows us to modify Theorem 3 and obtain the complexity of calculating an \(\varepsilon\)-approximation of the population Wasserstein barycenter in function value. To ensure this for \(\hat p^N\) from (9) with high probability, it suffices to take the following number of probability measures (iterations) in restarted SGD from [Juditsky and Nesterov (2014)]:

and the following auxiliary \(\varepsilon'\)-precision. The total complexity will be

Similarly, we use the results of SAA (Theorem 23) to state the complexity bounds for calculating an approximation of the population Wasserstein barycenter. For simplicity we provide results only for the one-machine architecture. To ensure \(\varepsilon\)-precision with high probability, it suffices to take the following number of probability measures

and the following auxiliary \(\varepsilon'\)-precision:

The total complexity of the algorithm from [Kroshnin et al. (2019b)] on one machine (without parallelization/decentralization) is

5.2 Stochastic Mirror Descent

Another approach to constructing a confidence interval in function value is to drop any regularization and use stochastic mirror descent³ with the 1-norm and KL-prox structure, see, e.g., [Hazan et al. (2016); Nemirovski et al. (2009); Orabona (2019)]; for the inexact case see, e.g., [Gasnikov et al. (2016); Juditsky and Nemirovski (2012)].

³ By using the Dual Averaging scheme [Nesterov (2009)] we can rewrite Algorithm 2 in the online regime without including the number of measures in the step-size policy. Note that the Mirror Descent and Dual Averaging schemes are very close to each other [Juditsky et al. (2019)].

1: Starting point \(p^1 = (1/n, \dots, 1/n)\); \(N\) – number of measures
2: choose the step size \(\eta\)
3: for \(k = 1, \dots, N\) do
4:     \(p^{k+1}_i = p^k_i \exp\!\big(-\eta [\nabla_p W(p^k, q_k)]_i\big) \Big/ \sum_{j=1}^n p^k_j \exp\!\big(-\eta [\nabla_p W(p^k, q_k)]_j\big)\)
   where the indices \(i\) (or \(j\)) denote the corresponding component of a vector, and \(\nabla_p W(p^k, q_k)\) is calculated with \(\varepsilon'\)-precision (e.g., by the Simplex Method or an Interior Point Method)
5: return \(\hat p^N = \frac{1}{N} \sum_{k=1}^N p^k\)
Algorithm 2 Stochastic Mirror Descent
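A minimal sketch of Algorithm 2 in our own notation follows. As a simplification (an assumption on our part), the \(\varepsilon'\)-accurate LP subgradient is replaced by the dual potential of a lightly entropy-regularized problem with small `mu_approx`; `sinkhorn` is the helper from the Section 2 sketch, and the constant step size `eta` is illustrative.

```python
# Sketch of Algorithm 2: stochastic mirror descent with KL prox, i.e.
# multiplicative updates on the simplex.
import numpy as np

def stochastic_mirror_descent(sample_measure, C, N, eta=0.1, mu_approx=1e-3):
    n = C.shape[0]
    p = np.full(n, 1.0 / n)
    p_avg = np.zeros(n)
    for k in range(1, N + 1):
        q = sample_measure()
        _, lam, _ = sinkhorn(C, p, q, mu_approx)   # approximate subgradient of W(., q)
        lam -= lam.mean()                # potentials are defined up to an additive constant
        w = p * np.exp(-eta * lam)       # entropic mirror (KL-prox) step
        p = w / w.sum()
        p_avg += (p - p_avg) / k         # averaging of the trajectory
    return p_avg
```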

From [Gasnikov et al. (2016); Juditsky and Nemirovski (2012)] we have, with high probability,

(27)

see [Dvurechensky et al. (2018); Kroshnin et al. (2019b); Lin et al. (2019)] for the constant involved. Bound (27) is better, by a factor growing with \(n\), than the bound for Stochastic Gradient Descent with the Euclidean setup [Nemirovski et al. (2009); Shalev-Shwartz et al. (2009)]. We also note that the smoothed complexity of finding the exact solution is polynomial, see [Dadush and Huiberts (2018)] and references therein. To ensure \(\varepsilon\)-precision with high probability, it suffices to make the following number of iterations of Algorithm 2:

The total complexity is

5.3 Empirical Risk Regularization

In the offline approach with the Euclidean setup one may use the regularization trick from [Shalev-Shwartz et al. (2009)].⁴ We introduce a composite term in the r.h.s. of equation (2), anchored at some initial vector from \(\Delta_n\).

⁴ In the same paper [Shalev-Shwartz et al. (2009)] one can find an explanation of why regularization is needed in the offline approach in the non-strongly convex case. The problems of the SAA approach in the non-strongly convex case are also discussed in [Guigues et al. (2017); Shapiro and Nemirovski (2005)]. For a more complete picture see [Shapiro et al. (2009); Sridharan (2012)].

Assume that we have found a point such that

(28)

Then the main result of Theorem 4 can be rewritten as follows [Shalev-Shwartz et al. (2009)]⁵: for the point from (28), with high probability,

⁵ Note that in [Shalev-Shwartz et al. (2009)] a simpler quantity was used instead; for the moment we do not know how to justify this replacement, which is why we keep our version. Fortunately, when the sample size is big enough, it does not matter.

Considering the sample size to be big enough, we choose the regularization parameter (as in [Shalev-Shwartz et al. (2009)]) appropriately and obtain, with high probability,

(29)

Recently it was shown [Feldman and Vondrák (2019)] that the dependence on the confidence level in (29) can be improved to logarithmic. Another type of regularization allows us to improve (29). We use the Bregman divergence ([Ben-Tal and Nemirovski (2015)]).
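Recall the standard definition: for a prox-generator \(d\),
\[ B_d(x, y) \;=\; d(x) - d(y) - \langle \nabla d(y),\, x - y \rangle. \]
In particular, for the entropy generator \(d(x) = \sum_{i=1}^n x_i \ln x_i\) and \(x, y \in \Delta_n\), this gives the Kullback–Leibler divergence \(B_d(x, y) = \sum_{i=1}^n x_i \ln(x_i / y_i) = \mathrm{KL}(x \,\|\, y)\).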

We notice that the chosen generator is 1-strongly convex in the 1-norm and Lipschitz continuous in the 1-norm on \(\Delta_n\). In [Bigot et al. (2017)] it was proposed to use entropy as a regularizer. For entropy we have the same strong convexity property in the 1-norm on \(\Delta_n\), but we lose the upper bound on the Lipschitz constant. Using this, we redefine the regularized empirical problem as follows:⁶

(30)

⁶ Note that to solve (30) we may use the same dual distributed tricks as in [Kroshnin et al. (2019a)] if we put the composite term in a separate node. But first we should regularize \(W\) with entropy. The complexity in terms of \(n\) will be the same as in Theorem 23.
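As an illustration of an objective in the spirit of (30), here is our own sketch of a KL-regularized empirical risk. The composite weight `lam_reg`, the anchor `p0`, and the entropic surrogate with small `mu` are assumptions on our part (with `sinkhorn` the helper from the Section 2 sketch); the iterate `p` is assumed strictly positive.

```python
# Sketch of a KL-regularized empirical objective: empirical risk over m
# sampled histograms plus a Bregman (KL) composite term anchored at p0.
import numpy as np

def entropic_ot_value(C, p, q, mu):
    pi, _, _ = sinkhorn(C, p, q, mu)          # helper from the Section 2 sketch
    return (C * pi).sum() + mu * (pi * np.log(pi + 1e-300)).sum()

def regularized_empirical_risk(p, Q, C, mu, p0, lam_reg):
    risk = np.mean([entropic_ot_value(C, p, q, mu) for q in Q])
    kl = np.sum(p * np.log(p / p0))           # Bregman divergence for the entropy generator
    return risk + lam_reg * kl
```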

Assume that we have found a point such that

(31)

Considering the sample size to be big enough, we choose the regularization parameter appropriately and obtain, for the point from (31), with high probability,

(32)

To prove this estimate we use the same arguments as in [Shalev-Shwartz et al. (2009)], but replace strong convexity and smoothness in the 2-norm by their 1-norm counterparts. From this we may conclude that (32) is better than (29) by a factor growing with \(n\). Note that, with a proper choice of the anchor point, (32) can be written as follows:

(33)

This fact can be easily extracted from [Shalev-Shwartz et al. (2009)], see formula (21) of that work in the proof of Claim 6.2. The same can be said about (29). We summarize the result in the next theorem. To ensure \(\varepsilon\)-precision with high probability, we need to take the following number of probability measures

and find a point satisfying (31) with

The total complexity of the properly corrected algorithm from [Kroshnin et al. (2019b)] on one machine (without parallelization/decentralization) is

From the recent results of [Feldman and Vondrák (2019)] we may expect that the dependence on the confidence level in Theorems 5.1–5.3 is in fact much better (logarithmic). For the moment we do not possess an accurate proof of this, but we suspect that the original ideas of [Feldman and Vondrák (2019)] allow one to prove it.

6 Comparison

In this section we compare the approaches from Sections 3, 4 and 5. For the reader's convenience we skip the details of the high-probability bounds. The first reason is that we can fix the confidence level and consider it a fixed parameter in all the bounds. The second reason is the intuition (going back to [Shalev-Shwartz et al. (2009)]) that all the bounds in this paper in fact have logarithmic dependence on the confidence level, so up to notation we can ignore this dependence. The main result is the possible superiority of SAA over SA for population (μ-entropy regularized) Wasserstein barycenter estimation, even in the non-parallel (and non-decentralized) case. For this purpose we provide Table 1, where we used (7), to compare the total complexity of the algorithms. Here the precision is measured in argument.

Complexity
SA
SAA
Table 1: Complexity bounds for the constructed confidence interval for the population barycenter w.r.t. regularized OT; the precision is measured in argument.

When \(n\) is not too large, SA has the complexity given by the second term under the minimum. In this case we have an obvious advantage of SAA, since its complexity is smaller than the SA complexity by a factor growing with \(n\). Next, we compare the complexities under proper regularization of OT (with a definite regularization parameter). Table 2 presents the results. Here the precision is measured in function value.

Complexity
SA with
SAA with
Stochastic MD
Regularized ERM
Table 2: Complexity bounds for the constructed confidence interval for the population barycenter w.r.t. OT; the precision is measured in function value.

We do not make any conclusions about the comparison of Stochastic MD and Regularized ERM, since it depends on the comparison of the corresponding constants. However, both of these methods definitely outperform (according to the complexity results) the SA and SAA approaches based on entropic regularization of OT. The conclusion about the advantage of the SAA approach over the SA approach can be reinforced by using parallel or distributed calculations; for that we can use estimate (24) instead of (25). The same can be said about the corresponding formulas used in Section 5.

Acknowledgements. We are grateful to Alexander Gasnikov and Vladimir Spokoiny, who initiated this research. We also thank Pavel Dvurechensky and Eduard Gorbunov for fruitful discussions, and Ohad Shamir for a useful reference. The work in the first part was funded by RFBR, project number 19-31-51001. In the second part (from Section 5) the work was funded by the Russian Science Foundation, project no. 18-71-10108.

References