# Implementation of batched Sinkhorn iterations for entropy-regularized Wasserstein loss

In this report, we review the calculation of entropy-regularised Wasserstein loss introduced by Cuturi and document a practical implementation in PyTorch.


## 1 A brief review of the Wasserstein distance and its entropy regularisation

We review the regularised Wasserstein distance and focus on probability distributions on finite sets. We largely follow .

For positive integers $d_1$ and $d_2$, consider two probability measures $\mu$ and $\nu$ on sets of $d_1$ and $d_2$ points, i.e. $\mu \in \mathbb{R}_{\geq 0}^{d_1}$ with $\sum_i \mu_i = 1$ and $\nu \in \mathbb{R}_{\geq 0}^{d_2}$ with $\sum_j \nu_j = 1$.

A coupling of $\mu$ and $\nu$ is a probability measure $P \in \mathbb{R}_{\geq 0}^{d_1 \times d_2}$ with marginals $\mu$ and $\nu$, i.e. $\sum_j p_{ij} = \mu_i$ and $\sum_i p_{ij} = \nu_j$. Intuitively, a coupling can be interpreted as a mapping of the probability mass of $\mu$ to that of $\nu$.

We introduce a cost on the set of couplings by means of a $d_1 \times d_2$ cost matrix $c$. A coupling $P$ between $\mu$ and $\nu$ is then optimal if it is a minimiser of $\langle P, c \rangle := \sum_{ij} p_{ij} c_{ij}$ over the set $U$ of couplings of $\mu$ and $\nu$. By compactness of the admissible set, such a minimiser always exists, but in general it is not unique.

Cuturi  proposed to consider the regularised functional

$$E_\lambda(P) = \sum_{ij} p_{ij} c_{ij} - \lambda h(P),$$

where $h(P) = -\sum_{ij} p_{ij} \log p_{ij}$ is the entropy of $P$.

As  notes, the minimisation problem is closely related to minimising the original functional on a restricted set of couplings with entropy bounded below: If the minimiser $P^\ast$ of $E_\lambda$ has entropy $h(P^\ast) = h_0$, then it is a minimiser of the original functional on the set $U_{h_0}$ of couplings with entropy at least $h_0$: Any other coupling $P \in U_{h_0}$ has $\sum_{ij} p_{ij} c_{ij} = E_\lambda(P) + \lambda h(P) \geq E_\lambda(P^\ast) + \lambda h_0 = \sum_{ij} p^\ast_{ij} c_{ij}$, where we use that $P^\ast$ is an $E_\lambda$-minimiser and the admissibility condition $h(P) \geq h_0$ for $P$.

Note that a probability distribution becomes more “regular” with increasing entropy, motivating the lower bound and negative sign for the entropy term.

As the entropy is strictly concave and the cost term is linear in $P$, the functional $E_\lambda$ is strictly convex and has a unique minimum on the admissible set of couplings.

To characterise the minimum we introduce the Lagrange multipliers $\alpha \in \mathbb{R}^{d_1}$ and $\beta \in \mathbb{R}^{d_2}$ to capture the equality constraints in the definition of the couplings and obtain the augmented functional

$$E_\lambda^{\mu,\nu}(P, \alpha, \beta) = E_\lambda(P) + \sum_i \alpha_i \Bigl(\sum_j p_{ij} - \mu_i\Bigr) + \sum_j \beta_j \Bigl(\sum_i p_{ij} - \nu_j\Bigr)$$

and the minimax problem

$$P^\ast = \operatorname{argmin}_{P \geq 0}\; \sup_{\alpha \in \mathbb{R}^{d_1},\, \beta \in \mathbb{R}^{d_2}} E_\lambda^{\mu,\nu}(P, \alpha, \beta).$$

We write the Euler-Lagrange equations

$$0 = \frac{\partial}{\partial p_{ij}} E_\lambda^{\mu,\nu}(P, \alpha, \beta) = c_{ij} + \lambda + \lambda \log p_{ij} + \alpha_i + \beta_j. \qquad (1)$$

Solving for $p_{ij}$ we get

$$p_{ij} = \exp\Bigl(-1 - \tfrac{1}{\lambda}\alpha_i - \tfrac{1}{\lambda}\beta_j - \tfrac{1}{\lambda}c_{ij}\Bigr) \qquad (2)$$

for appropriate $\alpha$ and $\beta$. We absorb $\alpha$, $\beta$ and the constant by introducing $u_i = \exp(-\tfrac{1}{2} - \tfrac{1}{\lambda}\alpha_i)$ and $v_j = \exp(-\tfrac{1}{2} - \tfrac{1}{\lambda}\beta_j)$. We also write $K = \exp(-\tfrac{1}{\lambda}c)$.

With this, equation (2) becomes

$$P = \operatorname{diag}(u)\, K \operatorname{diag}(v),$$

where $\exp$ is the elementwise exponential function and $\operatorname{diag}$ the diagonal embedding operator mapping vectors to diagonal matrices. Plugging this representation into the marginal constraints

$\sum_j p_{ij} = \mu_i$ and $\sum_i p_{ij} = \nu_j$, we get the coordinate-wise equations

$$\mu_i = u_i (Kv)_i, \qquad \nu_j = v_j (K^T u)_j,$$

and solving for $u$ and $v$, we have

$$u_i = \mu_i/(Kv)_i, \qquad v_j = \nu_j/(K^T u)_j.$$

This makes it natural to set up the celebrated Sinkhorn-Knopp iteration

$$v^{(k+1)}_j := \nu_j/(K^T u^{(k)})_j, \qquad u^{(k+1)}_i := \mu_i/(K v^{(k+1)})_i. \qquad (3)$$

This iteration alternatingly enforces each of the two marginal constraints. An important algorithmic question that we skip here is the convergence of this fixed-point iteration, see e.g. .
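In matrix-vector form, iteration (3) can be sketched in a few lines of PyTorch. This is a naive, unstabilised version for illustration only (function and variable names are ours, not the extension's API); for small $\lambda$ the entries of $u$ and $v$ can over- or underflow, which is what the log-space variant of Section 3 addresses.

```python
import torch

def sinkhorn(mu, nu, C, lam, num_iter=100):
    # Plain Sinkhorn-Knopp iteration (3); illustrative sketch only.
    K = torch.exp(-C / lam)          # kernel K = exp(-c_ij / lambda)
    u = torch.ones_like(mu)
    for _ in range(num_iter):
        v = nu / (K.t() @ u)         # enforce the nu-marginal
        u = mu / (K @ v)             # enforce the mu-marginal
    # transport plan P = diag(u) K diag(v)
    return u[:, None] * K * v[None, :]
```

At a fixed point, both marginal constraints hold, so the row sums of the returned plan approximate $\mu$ and the column sums approximate $\nu$.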

## 2 Derivative

Recall that the Lagrange multiplier gives the derivative of the optimal value with respect to the constraint. However, as $\mu$ and $\nu$ are themselves probability measures, they are constrained themselves: to preserve their summing to $1$, the allowed variations are only those of mean $0$.

Thus, to compute a meaningful gradient with respect to the input manifold, we have to project the full gradient by subtracting the mean and get

$$\nabla_\mu E_\lambda(P^\ast(\mu, \nu)) = \alpha - \frac{1}{d_1}\Bigl(\sum_i \alpha_i\Bigr)\mathbf{1}_{d_1}.$$

Similarly, $\beta$ projected onto the mean-zero vectors is the gradient with respect to $\nu$. This way of obtaining the gradient has been proposed by .
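Concretely, inverting the substitution $u_i = \exp(-\tfrac{1}{2} - \tfrac{1}{\lambda}\alpha_i)$ gives $\alpha_i = -\lambda(\log u_i + \tfrac{1}{2})$, and the additive constant drops out under mean subtraction, so the projected gradient can be computed from $\log u$ alone. A sketch (the helper name is our own):

```python
import torch

def wasserstein_grad_mu(u, lam):
    # alpha_i = -lam * (log u_i + 1/2); the constant -lam/2 cancels
    # after subtracting the mean, leaving a mean-zero gradient for mu.
    log_u = torch.log(u)
    return -lam * (log_u - log_u.mean())
```

By construction the result sums to zero, i.e. it is a valid variation of a probability vector.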

## 3 Batch stabilisation

For stabilisation we rewrite the iteration (3) in log-space as

$$\log v^{(k+1)}_j = \log \nu_j - \operatorname{logsumexp}_i\bigl(-\tfrac{1}{\lambda}c_{ij} + \log u^{(k)}_i\bigr), \qquad \log u^{(k+1)}_i = \log \mu_i - \operatorname{logsumexp}_j\bigl(-\tfrac{1}{\lambda}c_{ij} + \log v^{(k+1)}_j\bigr) \qquad (4)$$

with the log-sum-exp operator $\operatorname{logsumexp}_i(x_i) = \log \sum_i \exp(x_i)$, which can be stably implemented by extracting the maximum before exponentiation.

When implementing, we would set $\log \mu_i$ and $\log \nu_j$ to $-\infty$ if $\mu_i$ or $\nu_j$ is $0$, respectively.
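Iteration (4) maps directly onto `torch.logsumexp`. A non-batched sketch (names are ours, not the extension's API; the $-\infty$ handling for zero-mass entries is omitted for brevity):

```python
import torch

def sinkhorn_log(mu, nu, C, lam, num_iter=100):
    # Log-space Sinkhorn iteration (4), stabilised via torch.logsumexp.
    log_mu, log_nu = torch.log(mu), torch.log(nu)
    log_u = torch.zeros_like(mu)
    log_K = -C / lam                  # log of the kernel, shape (d1, d2)
    for _ in range(num_iter):
        # log v_j = log nu_j - logsumexp_i(-c_ij/lam + log u_i)
        log_v = log_nu - torch.logsumexp(log_K + log_u[:, None], dim=0)
        # log u_i = log mu_i - logsumexp_j(-c_ij/lam + log v_j)
        log_u = log_mu - torch.logsumexp(log_K + log_v[None, :], dim=1)
    # transport plan P = diag(u) K diag(v), assembled in log space
    return torch.exp(log_u[:, None] + log_K + log_v[None, :])
```

Because only logarithms are stored, the iteration remains stable even when individual entries of $u$, $v$ or $K$ would over- or underflow in the plain iteration (3).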

Schmitzer  proposes to avoid exponentiation and logarithms by splitting $u$ and $v$ into slowly varying parts that are absorbed into the kernel and remainders close to one, only occasionally performing the absorption. While this works well for a single pair $\mu$, $\nu$, it means that during the iteration a varying kernel $K$ is used. This does not lend itself to batch computation when one wants to avoid keeping multiple kernels $K$ around. In our experience, the speed of GPGPU computations of one step in the iteration tends to be limited by the memory accesses more than by the computation.

In , we provided a batch-stabilised version that extracted the maxima of $\log u$ and $\log v$, but this does not have the full stabilisation of the log-space iteration (4).

## 4 Implementation for GPGPU

We consider implementing a GPGPU kernel for batches of measures $\mu_b$, $\nu_b$ and a single distance matrix $c$.

When using a batch iteration, we need to implement

$$\log v_{bj} := \log \nu_{bj} - \operatorname{logsumexp}_i\bigl(-\tfrac{1}{\lambda}c_{ij} + \log u_{bi}\bigr).$$

This has two key properties that shape our implementation as an extension to the PyTorch deep learning framework.

• The overall reduction structure is akin to that of a matrix multiplication, i.e. memory accesses to $c$ and $\log u$ to compute the result $\log v$, with the additional input $\log \nu$ following the same access pattern as the result. We parallelize over the independent dimensions ($b$ and $j$) and split the reduction over $i$ amongst multiple threads, then combine their intermediate results. We have not employed tiling, which is commonly used to speed up the memory accesses for matrix multiplication.

• In our implementation, the stabilisation of the $\operatorname{logsumexp}$ calculation is carried out in an online fashion, i.e. computing the stabilisation (the running maximum) and the reduction result in a single pass, similar to the Welford algorithm for the variance.
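The online stabilisation can be illustrated in plain Python (a sketch of the idea only, not the CUDA kernel itself): the running maximum plays the role of the stabiliser, and the partial sum of exponentials is rescaled whenever the maximum changes.

```python
import math

def online_logsumexp(xs):
    # Single-pass logsumexp: maintain the running maximum m and the
    # running sum s of exp(x - m); rescale s whenever m increases,
    # analogous to the Welford algorithm for the variance.
    m, s = -math.inf, 0.0
    for x in xs:
        if x > m:
            s = s * math.exp(m - x) + 1.0  # rescale old sum to new maximum
            m = x
        else:
            s += math.exp(x - m)
    return m + math.log(s)
```

This computes the same value as the two-pass max-then-sum formulation, but each input element is read only once, which matches the memory-bound reduction described above.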

Incidentally, the logarithm of the energy of a minimiser $P = \operatorname{diag}(u)\,K\operatorname{diag}(v)$ has a very similar structure, with

$$E_0(P) = \sum_{ij} p_{ij} c_{ij} = u^T (K \odot c)\, v = \exp\Bigl(\operatorname{logsumexp}_j\Bigl(\log v_j + \operatorname{logsumexp}_i\bigl(-\tfrac{1}{\lambda}c_{ij} + \log c_{ij} + \log u_i\bigr)\Bigr)\Bigr), \qquad (5)$$

with $\odot$ being the elementwise product, so we can stably compute this loss function with the same GPU kernel as the iteration step.
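Assuming converged log-domain scalings, a sketch of the stable evaluation of (5) with nested `torch.logsumexp` calls might look as follows (function name and signature are our own illustration):

```python
import torch

def log_energy(log_u, log_v, C, lam):
    # log E_0(P) as in (5): inner logsumexp over i with the extra
    # log c_ij term, then an outer logsumexp over j. Entries with
    # c_ij = 0 contribute log(0) = -inf, i.e. nothing, as they should.
    inner = torch.logsumexp(-C / lam + torch.log(C) + log_u[:, None], dim=0)
    return torch.logsumexp(log_v + inner, dim=0)
```

Exponentiating the result recovers $E_0(P) = \sum_{ij} u_i e^{-c_{ij}/\lambda} v_j \, c_{ij}$, but all intermediate reductions stay in log space.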

Here we adopted the point of view that the entries of the cost matrix $c$ can be computed explicitly, but that we would prefer not to realise full batch-sized intermediate tensors in memory. For applications with the cost function based on e.g. Euclidean metrics, one might instead trade compute for memory and re-create entries of the distance matrix as they are needed.

## 5 Practical application in Stochastic Gradient Descent algorithms

Our goal is to enable the use of the (entropy-regularised) Wasserstein loss for deep learning applications.

As is commonly done, we return $E_0$ as calculated in (5) for the (approximate) minimiser $P^\ast$ of $E_\lambda$ as the value of our loss function. The gradient as computed in Section 2 is that of $E_\lambda$. Note that the gradient is off for two reasons: First, we use the Lagrange multiplier of $E_\lambda$ as the gradient of $E_0$, i.e. we optimize $E_\lambda$ but measure $E_0$. This seems to work reasonably well for many applications and small $\lambda$; for cases when it does not,  offers an improved gradient. The second source of error is that the iteration might not have fully converged. Empirically, however, it seems that if we iterate often enough, the gradient is sufficiently good to pass PyTorch's gradcheck tests.
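A minimal sketch of how the loss value and the Lagrange-multiplier gradient of Section 2 can be combined in a custom `torch.autograd.Function` (all names and the simplified, non-batched signature are our own illustration, not the actual extension's API):

```python
import torch

class SinkhornLoss(torch.autograd.Function):
    # Sketch: forward runs the log-space iteration and returns E_0;
    # backward returns the projected Lagrange multipliers instead of
    # backpropagating through the iteration.

    @staticmethod
    def forward(ctx, mu, nu, C, lam, num_iter):
        log_mu, log_nu = torch.log(mu), torch.log(nu)
        log_u = torch.zeros_like(mu)
        log_K = -C / lam
        for _ in range(num_iter):
            log_v = log_nu - torch.logsumexp(log_K + log_u[:, None], dim=0)
            log_u = log_mu - torch.logsumexp(log_K + log_v[None, :], dim=1)
        P = torch.exp(log_u[:, None] + log_K + log_v[None, :])
        ctx.save_for_backward(log_u, log_v)
        ctx.lam = lam
        return (P * C).sum()  # E_0 of the approximate minimiser

    @staticmethod
    def backward(ctx, grad_out):
        log_u, log_v = ctx.saved_tensors
        lam = ctx.lam
        # alpha_i = -lam * (log u_i + 1/2); the constant cancels after
        # projecting onto mean-zero vectors (Section 2), likewise for beta.
        g_mu = -lam * (log_u - log_u.mean())
        g_nu = -lam * (log_v - log_v.mean())
        return grad_out * g_mu, grad_out * g_nu, None, None, None
```

Because backward only rescales the already-computed scalings, no intermediate results of the iteration need to be stored, which is the source of the memory savings discussed below.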

Compared to existing code and libraries, our code combines a stable, memory-efficient log-space implementation that works for batches with the Lagrange-multiplier-based gradient. The authors of [4, 9.1.3] advocate the use of automatic differentiation, in their words: "In challenging situations in which the size and the quantity of histograms to be compared are large, the computational budget to compute a single Wasserstein distance is usually limited, therefore allowing only for a few Sinkhorn iterations." When the histograms do fit the GPU, however, the method of Section 4 seems to achieve a significant speedup over existing implementations, so that in many cases a few tens or even a few hundred iterations seem possible in reasonable time. Also, by not needing to store intermediate results, as relying on the autograd implementations of frameworks such as PyTorch does, the Lagrange-multiplier approach is much more memory-efficient. Finally, saving the computational cost of backpropagation, which is roughly equivalent to that of the forward pass, allows the number of iterations to be doubled within the same computational budget. As such we disagree with the assessment in [4], which is also cited in a recent blog entry  with implementation.

The memory aspect is particularly important because backpropagation through the iteration typically stores intermediate results for each step to facilitate the backward computation, and GPU memory is typically an even scarcer resource than computation time in deep learning applications. In our measurements, we achieve a total speedup in forward and backward of 6.5x over an existing implementation for distributions with mass at 100 points each, even though our choice not to use early stopping causes us to compute 3x as many iterations. A significant part of the advantage is that our backward comes at almost negligible computational cost; the remainder comes from the efficient computational implementation.