Implementation of batched Sinkhorn iterations for entropy-regularized Wasserstein loss

07/01/2019 ∙ by Thomas Viehmann

In this report, we review the calculation of entropy-regularised Wasserstein loss introduced by Cuturi and document a practical implementation in PyTorch.




1 A brief review of the Wasserstein distance and its entropy regularisation

We review the regularised Wasserstein distance and focus on probability distributions on finite sets. We largely follow [2].

For positive integers $M$ and $N$, consider two probability measures $\mu$ and $\nu$ on sets of $M$ and $N$ points, i.e. $\mu \in \mathbb{R}^M_{\geq 0}$ with $\sum_i \mu_i = 1$ and $\nu \in \mathbb{R}^N_{\geq 0}$ with $\sum_j \nu_j = 1$.

A coupling of $\mu$ and $\nu$ is a probability measure $\Gamma \in \mathbb{R}^{M \times N}_{\geq 0}$ with marginals $\mu$ and $\nu$, i.e. $\sum_j \Gamma_{ij} = \mu_i$ and $\sum_i \Gamma_{ij} = \nu_j$. Intuitively, a coupling can be interpreted as a mapping of the probability mass of $\mu$ to that of $\nu$.

We introduce a cost on the set of couplings by means of an $M \times N$-matrix $C$. A coupling $\Gamma$ between $\mu$ and $\nu$ is then optimal for $C$ if it is a minimiser of $\langle \Gamma, C \rangle = \sum_{ij} \Gamma_{ij} C_{ij}$ on the set $\mathcal{A}$ of admissible couplings of $\mu$ and $\nu$. By compactness of the admissible set, such a minimiser always exists, but in general it is not unique.

Cuturi [2] proposed to consider the regularised functional

$$W_\lambda(\Gamma) = \langle \Gamma, C \rangle - \lambda H(\Gamma), \tag{1}$$

where $H$ is the entropy $H(\Gamma) = -\sum_{ij} \Gamma_{ij} \log \Gamma_{ij}$.

As [2] notes, the minimisation of $W_\lambda$ is closely related to the problem of minimising the original functional $\langle \Gamma, C \rangle$ on the restricted set $\mathcal{A}_\alpha = \{\Gamma \in \mathcal{A} : H(\Gamma) \geq \alpha\}$ of couplings with entropy at least $\alpha$: If a minimiser $\Gamma^\ast$ of $W_\lambda$ has entropy $H(\Gamma^\ast) = \alpha$, then it is a minimiser of the original functional on $\mathcal{A}_\alpha$: Another coupling $\Gamma \in \mathcal{A}_\alpha$ has
$$\langle \Gamma, C \rangle = W_\lambda(\Gamma) + \lambda H(\Gamma) \geq W_\lambda(\Gamma^\ast) + \lambda \alpha = \langle \Gamma^\ast, C \rangle,$$
where we use that $\Gamma^\ast$ is a $W_\lambda$-minimiser and the admissibility condition $H(\Gamma) \geq \alpha$ for $\Gamma$.

Note that a probability distribution becomes more “regular” with increasing entropy, motivating the lower bound and negative sign for the entropy term.

As the entropy $H$ is strictly concave and the cost term is linear in $\Gamma$, the functional $W_\lambda$ is strictly convex and has a unique minimiser on the admissible set $\mathcal{A}$.

To characterise the minimum we introduce the Lagrange multipliers $\alpha \in \mathbb{R}^M$ and $\beta \in \mathbb{R}^N$ to capture the equality constraints in the definition of $\mathcal{A}$ and have the augmented functional

$$L(\Gamma, \alpha, \beta) = \langle \Gamma, C \rangle - \lambda H(\Gamma) + \langle \alpha, \mu - \Gamma \mathbb{1}_N \rangle + \langle \beta, \nu - \Gamma^T \mathbb{1}_M \rangle$$

and minimax problem

$$\min_{\Gamma \geq 0} \, \max_{\alpha, \beta} \, L(\Gamma, \alpha, \beta).$$

We write the Euler-Lagrange equations

$$0 = \frac{\partial L}{\partial \Gamma_{ij}} = C_{ij} + \lambda\,(1 + \log \Gamma_{ij}) - \alpha_i - \beta_j.$$

Solving for $\Gamma_{ij}$ we get

$$\Gamma_{ij} = \exp\!\left(\frac{\alpha_i + \beta_j - C_{ij}}{\lambda} - 1\right) \tag{2}$$

for the appropriate multipliers $\alpha$ and $\beta$. We absorb $\alpha$, $\beta$ and the constant $-1$ by introducing $u_i = \exp(\alpha_i/\lambda - \tfrac12)$ and $v_j = \exp(\beta_j/\lambda - \tfrac12)$. We also write $K = \exp(-C/\lambda)$.

With this, equation (2) becomes

$$\Gamma = \operatorname{diag}(u)\, K\, \operatorname{diag}(v),$$

where $\exp$ is the elementwise exponential function and $\operatorname{diag}$ the diagonal embedding operator mapping vectors to diagonal matrices. Plugging this representation into the marginal constraints $\sum_j \Gamma_{ij} = \mu_i$ and $\sum_i \Gamma_{ij} = \nu_j$, we get the coordinate-wise equations

$$u_i \sum_j K_{ij} v_j = \mu_i \qquad\text{and}\qquad v_j \sum_i K_{ij} u_i = \nu_j,$$

and solving for $u$ and $v$, we have

$$u_i = \frac{\mu_i}{(K v)_i} \qquad\text{and}\qquad v_j = \frac{\nu_j}{(K^T u)_j}.$$

This makes it natural to set up the celebrated Sinkhorn-Knopp iteration

$$u^{(n+1)} = \frac{\mu}{K v^{(n)}}, \qquad v^{(n+1)} = \frac{\nu}{K^T u^{(n+1)}}, \tag{3}$$

with elementwise division. This iteration alternately enforces each of the two marginal constraints. An important algorithmic question that we skip here is the convergence of this fixed-point iteration, see e.g. [5].
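For illustration, iteration (3) can be sketched in a few lines of NumPy. This is a simplified sketch, not the paper's GPU implementation; the function and variable names are our own:

```python
import numpy as np

def sinkhorn(mu, nu, C, lam, n_iter=100):
    """Plain Sinkhorn-Knopp iteration (3) for the entropy-regularised
    optimal transport problem. mu (M,) and nu (N,) are probability
    vectors, C (M, N) is the cost matrix, lam the regularisation."""
    K = np.exp(-C / lam)                 # kernel K = exp(-C / lambda)
    u = np.ones_like(mu)
    v = np.ones_like(nu)
    for _ in range(n_iter):
        u = mu / (K @ v)                 # enforce the row marginals
        v = nu / (K.T @ u)               # enforce the column marginals
    return u[:, None] * K * v[None, :]   # Gamma = diag(u) K diag(v)
```

After convergence, both marginal constraints hold, so the returned matrix is an admissible coupling.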

2 Derivative

Recall that the Lagrange multipliers give the derivative of the optimal value of $W_\lambda$ with respect to the constraints. However, as $\mu$ and $\nu$ are themselves probability measures, they are constrained themselves: To preserve their summing to $1$, the allowed variations are only those of mean $0$.

Thus, to compute a meaningful gradient with respect to the input manifold, we have to project the full gradient by subtracting the mean and get

$$\frac{\partial W_\lambda}{\partial \mu_i} = \alpha_i - \frac{1}{M} \sum_k \alpha_k.$$

Similarly, $\beta$ projected on the mean-zero vectors is the gradient with respect to $\nu$. This way of obtaining the gradient has been proposed by [6].
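A minimal sketch of this projection, assuming the absorption $u = \exp(\alpha/\lambda - \tfrac12)$ above so that $\alpha$ can be recovered from $\log u$ (the function name is illustrative):

```python
import numpy as np

def grad_mu(log_u, lam):
    """Gradient of the regularised loss with respect to mu: recover the
    Lagrange multiplier alpha from log u (assuming the absorption
    u = exp(alpha / lam - 1/2)) and subtract the mean, since admissible
    variations of a probability vector must sum to zero."""
    alpha = lam * (log_u + 0.5)
    return alpha - alpha.mean()   # projection onto mean-zero vectors
```

Note that the additive constant from the absorption cancels in the projection, so only $\lambda \log u$ matters up to a constant shift.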

3 Batch stabilisation

For stabilisation we rewrite the iteration (3) in log-space as

$$\log u^{(n+1)}_i = \log \mu_i - \operatorname{LSE}_j\!\left(-\frac{C_{ij}}{\lambda} + \log v^{(n)}_j\right), \qquad \log v^{(n+1)}_j = \log \nu_j - \operatorname{LSE}_i\!\left(-\frac{C_{ij}}{\lambda} + \log u^{(n+1)}_i\right), \tag{4}$$

with the log-sum-exp operator $\operatorname{LSE}_j(a_j) = \log \sum_j \exp(a_j)$, which can stably be implemented by extracting the maximum before exponentiation.

When implementing, we would set $\log \mu_i$ and $\log \nu_j$ to $-\infty$ if $\mu_i$ or $\nu_j$ is $0$, respectively.
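The log-space iteration (4) can be sketched as follows. This is a NumPy sketch for exposition (the paper's implementation is a custom GPU kernel), with a hand-rolled stable log-sum-exp:

```python
import numpy as np

def lse(a, axis):
    """Stable log-sum-exp: extract the maximum before exponentiating."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis) + np.log(np.exp(a - m).sum(axis=axis))

def log_sinkhorn(mu, nu, C, lam, n_iter=100):
    """Log-space Sinkhorn iteration (4); f = log u, g = log v.
    Zero-mass entries of mu / nu become -inf in log-space."""
    with np.errstate(divide="ignore"):   # log(0) = -inf is intended
        log_mu, log_nu = np.log(mu), np.log(nu)
    f = np.zeros_like(mu)
    g = np.zeros_like(nu)
    for _ in range(n_iter):
        f = log_mu - lse(-C / lam + g[None, :], axis=1)
        g = log_nu - lse(-C / lam + f[:, None], axis=0)
    return f, g
```

The coupling is recovered as $\Gamma_{ij} = \exp(f_i - C_{ij}/\lambda + g_j)$, without ever exponentiating unstabilised quantities during the iteration.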

Schmitzer [11] proposes to avoid exponentiation and logarithms by splitting the scaling factors $u$ and $v$ into stable parts and parts that are only occasionally absorbed into the kernel $K$. While this works well for a single pair $\mu$, $\nu$, it means that during the iteration a varying kernel $K$ is used. This does not lend itself to batch computation when one wants to avoid keeping multiple kernels $K$ around. In our experience, the speed of GPGPU computations of one step in the iteration tends to be limited by the memory accesses more than the computation.

In [12], we provided a batch-stabilised version that extracted the maxima of the summands, but does not have the full stabilisation of the log-space iteration (4).

4 Implementation for GPGPU

We consider implementing a GPGPU kernel for batches of measures $\mu_b$, $\nu_b$ and a single distance matrix $C$.

When using a batch iteration, we need to implement

$$\log u^{(n+1)}_{bi} = \log \mu_{bi} - \operatorname{LSE}_j\!\left(-\frac{C_{ij}}{\lambda} + \log v^{(n)}_{bj}\right)$$

and the analogous update for $\log v$.

This has two key properties that shape our implementation as an extension to the PyTorch deep learning framework.

  • The overall reduction structure is akin to a matrix multiplication, i.e. memory accesses to $C$ and $\log v$ to compute the result $\log u$, with the additional input $\log \mu$ following the same access pattern as the result. We parallelize in the independent dimensions (the batch index $b$ and the row index $i$) and split the reduction over $j$ amongst multiple threads, then combine their intermediate results. We have not employed tiling, which is commonly used to speed up the memory accesses for matrix multiplication.

  • In our implementation, the stabilisation of the $\operatorname{LSE}$-calculation is carried out in an online fashion, i.e. computing the stabilising maximum and the reduction result in a single pass, similar to the Welford algorithm for the variance.
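The online stabilisation can be illustrated by a scalar single-pass log-sum-exp. This is a Python sketch; the actual kernel performs this reduction across GPU threads:

```python
import numpy as np

def online_lse(values):
    """Single-pass streaming log-sum-exp: keep a running maximum m and a
    running sum s of exp(x - m); when a larger x arrives, rescale s to
    the new maximum, analogous to Welford's online variance update."""
    m, s = -np.inf, 0.0
    for x in values:
        if x > m:
            s = s * np.exp(m - x) + 1.0   # rescale old sum, add new term
            m = x
        else:
            s += np.exp(x - m)
    return m + np.log(s)
```

Because every `exp` argument is non-positive, the computation never overflows, even for inputs far outside the range of double-precision `exp`.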

Incidentally, the logarithm of the energy $\langle \Gamma, C \rangle = \sum_{ij} (\Gamma \odot C)_{ij}$ of a minimiser has a very similar structure, with

$$\log \langle \Gamma, C \rangle = \operatorname{LSE}_{ij}\!\left(\log u_i - \frac{C_{ij}}{\lambda} + \log v_j + \log C_{ij}\right), \tag{5}$$

$\odot$ being the elementwise product, so we can stably compute this loss function with the same GPU kernel as the iteration step.
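A sketch of this loss evaluation, assuming the convention $f = \log u$, $g = \log v$ from the log-space iteration; entries with $C_{ij} = 0$ give $\log C_{ij} = -\infty$ and so contribute nothing to the log-sum-exp, as intended:

```python
import numpy as np

def log_energy(f, g, C, lam):
    """log <Gamma, C> for Gamma_ij = exp(f_i - C_ij/lam + g_j), computed
    as one stabilised log-sum-exp over all (i, j) -- the same reduction
    pattern as a single iteration step."""
    with np.errstate(divide="ignore"):    # log(0) = -inf drops out
        a = f[:, None] - C / lam + g[None, :] + np.log(C)
    m = a.max()                           # stabilising maximum
    return m + np.log(np.exp(a - m).sum())
```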

Here we adopted the point of view that the $M \times N$-sized quantities $C$ and $K$ can be explicitly computed, but that we would prefer not to realise tensors of size $B \times M \times N$ for batch size $B$. For applications with the cost function based on e.g. Euclidean metrics, one might, instead, trade compute for memory and re-create entries of the distance matrix as they are needed.
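For reference, one batched half-step of the log-space update can be expressed directly with NumPy broadcasting. Note that this sketch materialises exactly the $(B, M, N)$ intermediate that the custom fused kernel is designed to avoid:

```python
import numpy as np

def batch_log_step(log_mu, g, C, lam):
    """One batched half-step: for each batch element b,
    log u_{bi} = log mu_{bi} - LSE_j(-C_ij/lam + g_{bj}).
    log_mu: (B, M), g = log v: (B, N), C: (M, N) shared by the batch."""
    a = -C[None, :, :] / lam + g[:, None, :]   # (B, M, N) intermediate
    m = a.max(axis=2, keepdims=True)           # stabilising maximum
    lse = np.squeeze(m, 2) + np.log(np.exp(a - m).sum(axis=2))
    return log_mu - lse                        # (B, M)
```

The fused kernel computes the same result while only ever reading $C$ and the $(B, N)$-sized $\log v$, never writing the $(B, M, N)$ tensor to memory.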

5 Practical application in Stochastic Gradient Descent algorithms

Our goal is to enable the use of the (entropy-regularised) Wasserstein loss for deep learning applications.

As is commonly done, we return $\langle \Gamma, C \rangle$ as calculated in (5) for the (approximate) minimiser $\Gamma$ of $W_\lambda$ as the value of our loss function. The gradient as computed in Section 2 is that of $W_\lambda$. Note that the gradient is off for two reasons: First, we use the Lagrange multiplier for $W_\lambda$ as the gradient, i.e. we optimize $W_\lambda$ but measure $\langle \Gamma, C \rangle$. This seems to work reasonably well for many applications and small $\lambda$, but for cases when it does not, [8] offers an improved gradient. The second source of error is that the iteration might not have fully converged. Empirically, however, it seems that if we iterate often enough, the gradient is sufficiently good to pass PyTorch’s gradcheck tests.

Compared to existing code and libraries, our code combines a stable, memory-efficient log-space implementation that works for batches and uses the Lagrange-multiplier-based gradient. The authors of [4, 9.1.3] advocate the use of automatic differentiation, in their words: In challenging situations in which the size and the quantity of histograms to be compared are large, the computational budget to compute a single Wasserstein distance is usually limited, therefore allowing only for a few Sinkhorn iterations. When the histograms do fit the GPU, however, the method of Section 4 seems to achieve a significant speedup over existing implementations, so that in many cases a few tens or even a few hundred iterations seem possible in reasonable time. Also, by not needing to store intermediate results, as autograd-based implementations in frameworks such as PyTorch do, the Lagrange-multiplier approach is much more memory-efficient. Moreover, saving the computational cost of backpropagation, which is roughly equivalent to that of the forward pass, allows the number of iterations to be doubled within the same computational budget. As such we disagree with the assessment in [4], which is also cited in a recent blog entry [3] with implementation.

The memory point is particularly important because backpropagation through the iteration typically stores intermediate results for each step to facilitate the backward computation, and GPU memory is typically an even scarcer resource than computation time in deep learning applications. In our measurement, we achieve a total speedup in forward and backward of 6.5x over [3]’s implementation for distributions with mass at 100 points each, even though our choice not to use early stopping causes us to compute 3x as many iterations. A significant part of the advantage is that our backward pass comes at almost negligible computational cost; the remainder comes from the efficient computational implementation.