
Learning to solve inverse problems using Wasserstein loss

We propose using the Wasserstein loss for training in inverse problems. In particular, we consider a learned primal-dual reconstruction scheme for ill-posed inverse problems using the Wasserstein distance as loss function in the learning. This is motivated by misalignments in training data, which, when using the standard mean squared error loss, could severely degrade reconstruction quality. We prove that training with the Wasserstein loss gives a reconstruction operator that correctly compensates for misalignments in certain cases, whereas training with the mean squared error gives a smeared reconstruction. Moreover, we demonstrate these effects by training a reconstruction algorithm using both mean squared error and optimal transport loss for a problem in computerized tomography.



1 Introduction

In inverse problems the goal is to determine model parameters from indirect noisy observations. Examples of such problems arise in many different fields in science and engineering, e.g., in X-ray CT natterer2001mathematical , electron tomography oktem2015mathematics , and magnetic resonance imaging brown2014magnetic . Machine learning has recently also been applied in this area, especially in imaging applications. Using supervised machine learning to solve inverse problems in imaging requires training data where ground truth images are paired with corresponding noisy indirect observations. The learning provides a mapping that associates observations with corresponding images. However, in several applications the ground truth is difficult to obtain; in many cases, for instance, it may have undergone a distortion. For example, a recent study showed that MRI images may be distorted by up to 4 mm due to, e.g., inhomogeneities in the main magnetic field Walker2014 . If such images are used for training, the learned MRI reconstruction will suffer in quality. Similar geometric inaccuracies arise in several other imaging modalities, such as cone beam CT and full waveform inversion in seismic imaging.

This work provides a scheme for learning a reconstruction operator for an ill-posed inverse problem with a Wasserstein loss, leveraging recent advances in efficient solutions of optimal transport cuturi2013sinkhorn ; karlsson2016generalized and learned iterative schemes for inverse problems adler2017learned . The proposed method is demonstrated on a computed tomography example, where we show a significant improvement compared to training the same network using mean squared error loss. In particular, using the Wasserstein loss instead of the standard mean squared error gives a result that is more robust against potential misalignment in the training data.

2 Background

2.1 Inverse problems

In inverse problems the goal is to reconstruct an estimate of the signal $f_{\text{true}} \in X$ from noisy indirect measurements (data) $g \in Y$, assuming

$$g = \mathcal{A}(f_{\text{true}}) + e. \tag{1}$$

In the above, $X$ and $Y$ are referred to as the reconstruction and data space, respectively. Both are typically Hilbert or Banach spaces. Moreover, $\mathcal{A}\colon X \to Y$ denotes the forward operator, which models how a given signal gives rise to data in the absence of noise. Finally, $e \in Y$ is the noise component of the data. Many inverse problems of interest are ill-posed, meaning that there is no unique solution to (1) and hence no inverse to $\mathcal{A}$. Typically, reconstructions of $f_{\text{true}}$ are sensitive to the data, and small errors get amplified. One way to mitigate these effects is to use regularization engl2000regularization .

Variational regularization

In variational regularization one formulates the reconstruction problem as an optimization problem. To this end, one introduces a data discrepancy functional $\mathcal{D}(\,\cdot\,, g)\colon Y \to \mathbb{R}$, where $g \in Y$, that quantifies the misfit in data space, and a regularization functional $\mathcal{S}\colon X \to \mathbb{R}$ that encodes a priori information about $f_{\text{true}}$ by penalizing undesirable solutions. For a given $g \in Y$, this gives an optimization problem of the form

$$\min_{f \in X} \; \mathcal{D}\big(\mathcal{A}(f), g\big) + \lambda \mathcal{S}(f). \tag{2}$$

Here, $\lambda > 0$ acts as a trade-off parameter between the data discrepancy and the regularization functional. In many cases $\mathcal{D}$ is taken to be the negative data log-likelihood, e.g., $\mathcal{D}(\mathcal{A}(f), g) = \|\mathcal{A}(f) - g\|_2^2$ in the case of additive white Gaussian noise. Moreover, a typical choice of regularization functional is total variation (TV) regularization, $\mathcal{S}(f) = \|\nabla f\|_1$ rudin1992nonlinear . These regularizers typically give rise to large-scale and non-differentiable optimization problems, which require advanced optimization algorithms.
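As a minimal illustration of the variational formulation above, the sketch below swaps TV for the simpler quadratic (Tikhonov) regularizer $\|f\|_2^2$ and uses the Gaussian data discrepancy $\|Af - g\|_2^2$ with a matrix $A$. This particular problem has the closed-form minimizer $f = (A^T A + \lambda I)^{-1} A^T g$. The random forward matrix, the noise level, and $\lambda$ are hypothetical choices; TV has no such closed form and requires iterative non-smooth solvers.

```python
import numpy as np


def tikhonov_reconstruct(A: np.ndarray, g: np.ndarray, lam: float) -> np.ndarray:
    """Minimize ||A f - g||_2^2 + lam * ||f||_2^2 via the normal equations."""
    n = A.shape[1]
    # First-order optimality condition: (A^T A + lam I) f = A^T g
    return np.linalg.solve(A.T @ A + lam * np.eye(n), A.T @ g)


rng = np.random.default_rng(0)
A = rng.standard_normal((20, 50))  # hypothetical under-determined forward operator
f_true = np.zeros(50)
f_true[10:15] = 1.0
g = A @ f_true + 0.05 * rng.standard_normal(20)  # noisy data as in model (1)

f_rec = tikhonov_reconstruct(A, g, lam=0.1)
```

Because this objective is smooth and strictly convex, setting its gradient $2A^T(Af - g) + 2\lambda f$ to zero yields exactly the linear system solved above.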

Learning for inverse problems

In many applications, and so also in inverse problems, data-driven approaches have shown dramatic improvements over the state of the art LeBeHi15 . Using supervised learning to solve an inverse problem amounts to finding a parametrized operator $\mathcal{A}^\dagger_\theta\colon Y \to X$, where the parameters $\theta$ are selected so that $\mathcal{A}^\dagger_\theta(g) \approx f_{\text{true}}$.

For inverse problems in image processing, such as denoising and deblurring, we have $X = Y$, and it is possible to apply a wide range of well-studied machine learning techniques, such as deep neural networks with various architectures, including fully convolutional networks NIPS2008_3506 and denoising auto-encoders NIPS2012_4686 .

However, in more complicated inverse problems, such as tomography, the data and reconstruction spaces are very different; e.g., their dimensions after discretization may differ. For this reason, learning a mapping from $Y$ to $X$ becomes nontrivial, and classical architectures that map, e.g., images to images using convolutional networks cannot be applied as-is. One solution is to use fully connected layers, as in PaGiKaLoMa04 for very small-scale tomographic reconstruction problems. A major disadvantage of such a fully learned approach is that the parameter space has to be very high-dimensional in order to learn both the prior and the data model, which often renders it infeasible due to training time and lack of training data.

A more successful approach is to first apply some crude reconstruction operator $\tilde{\mathcal{A}}^\dagger\colon Y \to X$ and then use machine learning to post-process the result. This separates the learning from the complications of mapping between spaces, since the operator $\tilde{\mathcal{A}}^\dagger$ can be applied off-line, prior to training. Such an approach has been demonstrated for tomographic reconstruction in PeBa13 ; WuGhChMa16 . Its drawback for ill-posed inverse problems is that information is typically lost by applying $\tilde{\mathcal{A}}^\dagger$, and this information cannot be recovered by post-processing.

Finally, another approach is to incorporate the forward operator $\mathcal{A}$ and its adjoint $\mathcal{A}^*$ into the neural network. In these learned iterative schemes, classical neural networks are interlaced with applications of the forward and backward operators, thus allowing the learned reconstruction operator to work directly from data without having to learn the data model. For example, in YaSuLiXu16 an ADMM-like scheme for Fourier inversion is learned, and PuWe17 consider solving inverse problems typically arising in image restoration by a learned gradient-descent scheme. In adler2017solving this latter approach is shown to be applicable to large-scale tomographic inversion. Moreover, adler2017learned apply learning in both spaces $X$ and $Y$, yielding a Learned Primal-Dual scheme, and show that it outperforms learned post-processing for reconstruction of medical CT images.

Loss functions for learning

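Once the parametrization of $\mathcal{A}^\dagger_\theta$ is set, the parameters $\theta$ are typically chosen by minimizing some loss functional $L(\theta)$. By far the most common loss function is the mean squared error, also called $L^2$ loss, given by

$$L(\theta) = \mathbb{E}_{(g, f_{\text{true}})}\Big[\big\|\mathcal{A}^\dagger_\theta(g) - f_{\text{true}}\big\|_2^2\Big]. \tag{3}$$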
It has, however, been noted that the $L^2$ loss is sub-optimal for imaging, and a range of other loss functions have been investigated. These include the classical $L^p$ norms and the structural similarity index (SSIM) Zhao2017 , as well as more complex losses such as perceptual losses Johnson2016 and adversarial networks DBLP:journals/corr/MardaniGCVZATHD17 .

Recently, optimal mass transport has also been considered as a loss function for classification frogner2015learning and generative models arjovsky2017wasserstein . In this work we consider using optimal transport for training a reconstruction scheme for ill-posed inverse problems.

2.2 Optimal mass transport and Sinkhorn iterations

In optimal mass transport the aim is to transform one distribution into another by moving the mass in a way that minimizes the cost of the movement. For an introduction and overview of the topic, see, e.g., the monograph villani2008optimal . Lately, the area has attracted a lot of research cuturi2013sinkhorn ; cuturi2015asmoothed ; chizat2015unbalanced with applications to, e.g., signal processing haker2004optimal ; georgiou2009metrics ; jiang2012geometric ; engquist2014application and inverse problems benamou2015iterative ; karlsson2016generalized .

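The optimal mass transport problem can be formulated as follows: let $\Omega \subset \mathbb{R}^d$ be a compact set, and let $\mu_0$ and $\mu_1$ be two measures, defined on $\Omega$, with the same total mass. Given a cost $c\colon \Omega \times \Omega \to \mathbb{R}_+$ that describes the cost for transporting a unit mass from one point to another, find a (mass-preserving) transference plan $M$ that is as cheap as possible. Here, the transference plan characterizes how to move the mass of $\mu_0$ in order to deform it into $\mu_1$. Letting the transference plan be a nonnegative measure $M$ on the space $\Omega \times \Omega$ yields a linear programming problem in the space of measures:

$$\begin{aligned}
T(\mu_0, \mu_1) = \min_{M \geq 0} \;& \int_{\Omega \times \Omega} c(x_0, x_1)\, dM(x_0, x_1) \\
\text{subject to} \;& \int_{x_1 \in \Omega} dM(x_0, x_1) = d\mu_0(x_0), \quad x_0 \in \Omega, \\
& \int_{x_0 \in \Omega} dM(x_0, x_1) = d\mu_1(x_1), \quad x_1 \in \Omega.
\end{aligned} \tag{4}$$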
Although this formulation is only defined for measures $\mu_0$ and $\mu_1$ with the same total mass, it can also be extended to handle measures with unbalanced masses georgiou2009metrics ; chizat2015unbalanced . Moreover, under suitable conditions one can define the Wasserstein metrics $W_p$ using $T$, by taking $c(x_0, x_1) = d(x_0, x_1)^p$ for $p \geq 1$, where $d$ is a metric on $\Omega$, and setting $W_p(\mu_0, \mu_1) := T(\mu_0, \mu_1)^{1/p}$ (villani2008optimal, Definition 6.1). As the name indicates, $W_p$ is a metric on the set of nonnegative measures on $\Omega$ with fixed mass (villani2008optimal, Theorem 6.9), and it is weak$^*$ continuous on this set. One important property is that $T$ (and thus also $W_p$) does not only compare objects point by point, as standard metrics do, but instead quantifies how the mass is moved. This makes optimal transport natural for quantifying uncertainty and modelling deformations jiang2012geometric ; karlsson2013uncertainty .
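To illustrate the last point in the simplest setting, the sketch below compares the $L^2$ distance and the 1-Wasserstein distance between a point mass on a 1D grid and shifted copies of it. It relies on the standard fact that, in one dimension with cost $|x_0 - x_1|$, $W_1$ equals the $L^1$ distance between the two cumulative distribution functions. The grid size and the shifts are arbitrary choices.

```python
import numpy as np


def w1_1d(mu0: np.ndarray, mu1: np.ndarray, dx: float) -> float:
    """1-Wasserstein distance between equal-mass histograms on a uniform 1D grid."""
    cdf_diff = np.cumsum(mu0) - np.cumsum(mu1)
    return float(np.sum(np.abs(cdf_diff)) * dx)


n, dx = 100, 1.0
mu0 = np.zeros(n)
mu0[20] = 1.0  # point mass at grid point 20

for shift in (1, 5, 30):
    mu1 = np.roll(mu0, shift)  # same mass, moved `shift` grid points
    l2 = float(np.linalg.norm(mu0 - mu1))
    print(shift, l2, w1_1d(mu0, mu1, dx))
```

The $L^2$ distance is $\sqrt{2}$ for every shift, since the supports never overlap, whereas $W_1$ grows linearly with how far the mass is moved.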

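One way to solve the optimal transport problem in applications is to discretize $\Omega$ and solve the corresponding finite-dimensional linear programming problem. In this setting the two measures are represented by point masses on the discretization grid, i.e., by two vectors $\mu_0, \mu_1 \in \mathbb{R}^n_+$, where the element $\mu_k(i)$ corresponds to the mass in the point $x^{(i)} \in \Omega$ for $i = 1, \ldots, n$ and $k = 0, 1$. Moreover, a transference plan is represented by a matrix $M \in \mathbb{R}^{n \times n}_+$, where the value $m_{ij}$ denotes the amount of mass transported from point $x^{(i)}$ to $x^{(j)}$. The associated cost of a transference plan is $\sum_{i,j} c_{ij} m_{ij} = \operatorname{trace}(C^T M)$, where $c_{ij} = c(x^{(i)}, x^{(j)})$ is the transportation cost from $x^{(i)}$ to $x^{(j)}$, and by discretizing the constraints we get that $M$ is a feasible transference plan from $\mu_0$ to $\mu_1$ if the row sums of $M$ are $\mu_0$ and the column sums of $M$ are $\mu_1$. The discrete version of (4) thus takes the form

$$\begin{aligned}
T(\mu_0, \mu_1) = \min_{M \geq 0} \;& \operatorname{trace}(C^T M) \\
\text{subject to} \;& M \mathbf{1}_n = \mu_0, \quad M^T \mathbf{1}_n = \mu_1,
\end{aligned} \tag{5}$$

where $M \geq 0$ denotes element-wise non-negativity of the matrix and $\mathbf{1}_n$ is the vector of all ones. However, even though (5) is a linear programming problem, it is in many cases computationally infeasible due to the vast number of variables. The number of variables is $n^2$, so for two images with $N \times N$ pixels, i.e., $n = N^2$, the problem has $N^4$ variables; for two $512 \times 512$ images, the size used in Section 4, this amounts to more than $6 \cdot 10^{10}$ variables.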
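For very small $n$ the discrete problem (5) can be handed directly to a generic LP solver, which makes the $n^2$ growth of the variable count concrete. The sketch below assumes SciPy is available and uses `scipy.optimize.linprog` with the HiGHS backend; the flattening convention ($m_{ij}$ at index $i n + j$) and the toy marginals are our own choices.

```python
import numpy as np
from scipy.optimize import linprog


def discrete_ot(mu0: np.ndarray, mu1: np.ndarray, C: np.ndarray):
    """Solve the discrete optimal transport LP (5); practical only for small n."""
    n = len(mu0)
    # M is flattened row-major, so m_ij sits at index i * n + j (n^2 variables).
    row_sums = np.kron(np.eye(n), np.ones((1, n)))  # (M 1)_i = sum_j m_ij
    col_sums = np.kron(np.ones((1, n)), np.eye(n))  # (M^T 1)_j = sum_i m_ij
    A_eq = np.vstack([row_sums, col_sums])
    b_eq = np.concatenate([mu0, mu1])
    res = linprog(C.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None), method="highs")
    if not res.success:
        raise RuntimeError(res.message)
    return res.fun, res.x.reshape(n, n)


# Toy example on the grid 0, 1, ..., 4 with cost |x_i - x_j|.
x = np.arange(5.0)
C = np.abs(x[:, None] - x[None, :])
mu0 = np.array([1.0, 0.0, 0.0, 0.0, 0.0])
mu1 = np.array([0.0, 0.0, 0.5, 0.5, 0.0])
cost, M = discrete_ot(mu0, mu1, C)
```

Here all mass starts at $x = 0$, so the optimal plan sends half of it a distance 2 and half a distance 3, for a total cost of 2.5.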

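One approach for addressing this problem was proposed by Cuturi cuturi2013sinkhorn , who introduces an entropic regularizing term $D(M) = \sum_{i,j=1}^{n} \big(m_{ij}\log(m_{ij}) - m_{ij} + 1\big)$ for approximating the transference plan, so that the resulting perturbed optimal transport problem reads

$$\begin{aligned}
T_\epsilon(\mu_0, \mu_1) = \min_{M \geq 0} \;& \operatorname{trace}(C^T M) + \epsilon D(M) \\
\text{subject to} \;& M \mathbf{1}_n = \mu_0, \quad M^T \mathbf{1}_n = \mu_1.
\end{aligned} \tag{6}$$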

One can show that an optimal solution to (6) is of the form

$$M^* = \operatorname{diag}(u_0)\, K \operatorname{diag}(u_1), \tag{7}$$

where $K = \exp(-C/\epsilon)$ (point-wise exponential) is known, and $u_0, u_1 \in \mathbb{R}^n_+$ are unknown. This shows that the solution is parameterized by only $2n$ variables. Moreover, the two vectors can be computed iteratively by so-called Sinkhorn iterations, i.e., by alternately computing $u_0$ and $u_1$ so that the row sums of $M^*$ match $\mu_0$ and its column sums match $\mu_1$, respectively. This is summarized in Algorithm 1, where $\odot$ denotes elementwise multiplication and $\oslash$ elementwise division. The procedure has been shown to have a linear convergence rate; see cuturi2013sinkhorn and references therein.

Moreover, when the underlying cost is translation invariant, the discretized cost matrix $C$, and thus also the matrix $K$, has a Toeplitz-block-Toeplitz structure. This structure can be used to compute $K u_1$ and $K^T u_0$ efficiently using the fast Fourier transform in $O(n \log n)$ operations, instead of the $O(n^2)$ operations of naive matrix-vector multiplication karlsson2016generalized . This is crucial for applications in imaging, since for images of size $N \times N$ pixels one would otherwise have to explicitly store and multiply with matrices of size $N^2 \times N^2$.
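The Toeplitz trick can be checked in one dimension, where $K$ is plain Toeplitz: $(K v)_i = \sum_j k(x_i - x_j) v_j$ is a linear convolution, which a zero-padded FFT computes exactly. The sketch assumes a uniform grid and the squared-distance cost $c(x_0, x_1) = (x_0 - x_1)^2$ as an example; for 2D images the same idea applies with a 2D FFT to the Toeplitz-block-Toeplitz case. Since this $K$ is symmetric, $K^T u_0$ is computed the same way.

```python
import numpy as np


def apply_K_dense(v: np.ndarray, x: np.ndarray, eps: float) -> np.ndarray:
    """Reference O(n^2) product K v with K = exp(-C / eps), C_ij = (x_i - x_j)^2."""
    C = (x[:, None] - x[None, :]) ** 2
    return np.exp(-C / eps) @ v


def apply_K_fft(v: np.ndarray, dx: float, eps: float) -> np.ndarray:
    """Same product in O(n log n), using that K is Toeplitz on a uniform grid."""
    n = v.size
    offsets = np.arange(-(n - 1), n) * dx  # all possible differences x_i - x_j
    kernel = np.exp(-offsets ** 2 / eps)  # generator of the Toeplitz matrix K
    L = n + kernel.size - 1  # length of the full linear convolution
    full = np.fft.irfft(np.fft.rfft(v, L) * np.fft.rfft(kernel, L), L)
    return full[n - 1 : 2 * n - 1]  # entries whose offset i - j lies in range
```

Zero-padding to the full convolution length avoids the wrap-around that a plain circular FFT convolution would introduce.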

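Algorithm 1 Sinkhorn iterations for computing entropy-regularized optimal transport cuturi2013sinkhorn
1: Input: $\epsilon$, $C$, $\mu_0$, $\mu_1$, and the number of iterations $L$
2: Initialize $u_0 > 0$ and $u_1 > 0$, and let $K = \exp(-C/\epsilon)$
3: for $\ell = 1, \ldots, L$ do
4: $\quad u_0 \leftarrow \mu_0 \oslash (K u_1)$
5: $\quad u_1 \leftarrow \mu_1 \oslash (K^T u_0)$
6: end for
7: return $u_0$ and $u_1$, which define the transference plan $\operatorname{diag}(u_0) K \operatorname{diag}(u_1)$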
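A dense NumPy sketch of Algorithm 1 is given below. It forms $K$ explicitly, so it only suits small grids; it runs a fixed number of iterations without a convergence check, and it returns the plan together with its transport cost $\operatorname{trace}(C^T M)$ excluding the entropy term, which is a choice made here for illustration.

```python
import numpy as np


def sinkhorn(mu0: np.ndarray, mu1: np.ndarray, C: np.ndarray, eps: float, n_iter: int = 1000):
    """Entropy-regularized OT (6) via Sinkhorn iterations (dense K)."""
    K = np.exp(-C / eps)  # point-wise exponential, as in (7)
    u0 = np.ones_like(mu0)
    u1 = np.ones_like(mu1)
    for _ in range(n_iter):
        u0 = mu0 / (K @ u1)  # rescale so that the row sums match mu0
        u1 = mu1 / (K.T @ u0)  # rescale so that the column sums match mu1
    M = u0[:, None] * K * u1[None, :]  # diag(u0) K diag(u1)
    return M, float(np.sum(C * M))  # plan and its transport cost trace(C^T M)


# Example: two Gaussian-shaped histograms on [0, 1] with squared-distance cost.
x = np.linspace(0.0, 1.0, 20)
C = (x[:, None] - x[None, :]) ** 2
mu0 = np.exp(-(x - 0.3) ** 2 / 0.02)
mu0 /= mu0.sum()
mu1 = np.exp(-(x - 0.6) ** 2 / 0.02)
mu1 /= mu1.sum()
M, cost = sinkhorn(mu0, mu1, C, eps=0.1, n_iter=2000)
```

After each full iteration the column sums match $\mu_1$ exactly, while the row sums match $\mu_0$ only up to the convergence error. Every step consists of differentiable operations, which is what makes it possible to back-propagate through a fixed number of iterations during training.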

3 Learning a reconstruction operator using Wasserstein loss

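In this work we propose to use entropy-regularized optimal transport (6) to train a reconstruction operator, i.e., to select the parameters as

$$\hat{\theta} \in \arg\min_{\theta} \; \mathbb{E}_{(g, f_{\text{true}})}\Big[T_\epsilon\big(\mathcal{A}^\dagger_\theta(g), f_{\text{true}}\big)\Big]. \tag{8}$$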
This should give better results when the data $g$ is not aligned with the ground truth $f_{\text{true}}$. To see this, consider the case when $f_{\text{true}}$ is a point mass. In that case, training the network with the loss (3) will (in the ideal case) result in a perfect reconstruction composed with a convolution that “smears” the reconstruction over the area of possible misalignment. On the other hand, since optimal mass transport does not only compare objects point-wise, the network will (in the ideal case) learn a perfect reconstruction combined with a movement of the object to the corresponding barycenter (centroid) of the misalignment. These statements are made more precise in the following propositions. Formal definitions and proofs are deferred to the appendix.

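Proposition 1.

Let $f \in L^2(\mathbb{R}^n)$, let $\tau$ be an $\mathbb{R}^n$-valued random variable with probability measure $\mathbb{P}$, and let $f_\tau(x) := f(x - \tau)$. Then there exists a function $h \in L^2(\mathbb{R}^n)$ that minimizes $\mathbb{E}_\tau\big[\|h - f_\tau\|_2^2\big]$, and this $h$ has the form

$$h(x) = \mathbb{E}_\tau\big[f(x - \tau)\big] = \int_{\mathbb{R}^n} f(x - \tau)\, d\mathbb{P}(\tau).$$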

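Proposition 2.

Let $\delta_0$ be the Dirac delta function on $\mathbb{R}^n$, let $\tau$ be an $\mathbb{R}^n$-valued random variable with probability measure $\mathbb{P}$, and let $\delta_\tau := \delta_0(\cdot - \tau)$. Then

$$\mathbb{E}_\tau\big[T(h, \delta_\tau)\big] = \int_{\mathbb{R}^n} \Big(\int_{\mathbb{R}^n} c(x, \tau)\, d\mathbb{P}(\tau)\Big)\, dh(x)$$

if $h$ is a nonnegative measure with total mass 1, and $\mathbb{E}_\tau[T(h, \delta_\tau)] = +\infty$ otherwise. Furthermore, finding an $h$ that minimizes $\mathbb{E}_\tau[T(h, \delta_\tau)]$ is equivalent to finding the global minimizers of $x \mapsto \int_{\mathbb{R}^n} c(x, \tau)\, d\mathbb{P}(\tau)$. In particular, if (i) the probability measure $\mathbb{P}$ is symmetric around its mean, (ii) the underlying cost is of the form $c(x_0, x_1) = \psi(x_0 - x_1)$, where $\psi$ is convex and symmetric, and (iii) $\psi$ and $\mathbb{P}$ are such that

$$\int_{\mathbb{R}^n} \psi(x - \tau)\, d\mathbb{P}(\tau) \quad \text{and} \quad (\partial\psi * \mathbb{P})(x)$$

both exist and are finite for all $x \in \mathbb{R}^n$, then $h = \delta_{\mathbb{E}[\tau]}$ is an optimal solution. Furthermore, if $\psi$ is also strictly convex, then this is the unique minimizer.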

To illustrate Propositions 1 and 2 we consider the following example.

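Example 3.

Let $\tau$ be uniformly distributed on $[-a, a] \subset \mathbb{R}$ for some $a > 0$, and let $c(x_0, x_1) = |x_0 - x_1|^2$. This gives

$$\int_{\mathbb{R}} c(x, \tau)\, d\mathbb{P}(\tau) = \frac{1}{2a}\int_{-a}^{a} (x - \tau)^2\, d\tau = x^2 + \frac{a^2}{3},$$

which has its minimum at $x = 0$, and hence the (unique) minimizer of $\mathbb{E}_\tau[T(h, \delta_\tau)]$ is $h = \delta_0$. For the $L^2$ loss with the same uniform distribution, on the other hand, Proposition 1 shows that the minimizer of $\mathbb{E}_\tau[\|h - f_\tau\|_2^2]$ is the smoothed function $h(x) = \frac{1}{2a}\int_{-a}^{a} f(x - \tau)\, d\tau$.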
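The contrast in this example can be reproduced numerically on a 1D grid. The sketch assumes $\tau$ uniform on $[-1, 1]$ (sampled on a grid), the cost $|x_0 - x_1|^2$, and a narrow normalized Gaussian bump standing in for the point mass; these are illustrative choices. It finds the point minimizing $x \mapsto \mathbb{E}_\tau[c(x, \tau)]$ and forms the $L^2$-optimal reconstruction as the average of the shifted targets.

```python
import numpy as np

x = np.linspace(-3.0, 3.0, 601)  # spatial grid with spacing 0.01
dx = x[1] - x[0]
taus = np.linspace(-1.0, 1.0, 201)  # grid samples of tau ~ U[-1, 1]

# Optimal transport loss: the best h is a point mass at the minimizer of
# x -> E_tau[c(x, tau)] with c(x0, x1) = |x0 - x1|^2.
expected_cost = np.mean((x[:, None] - taus[None, :]) ** 2, axis=1)
x_star = x[np.argmin(expected_cost)]


def bump(center: float, width: float = 0.05) -> np.ndarray:
    """Narrow unit-mass bump approximating a point mass at `center`."""
    b = np.exp(-0.5 * ((x - center) / width) ** 2)
    return b / (b.sum() * dx)


# L2 loss: by Proposition 1 the minimizer is the average of the shifted targets.
h_l2 = np.mean([bump(t) for t in taus], axis=0)
```

The point-mass location `x_star` lands at the mean 0, while `h_l2` is spread almost uniformly over $[-1, 1]$ with height close to $1/2$, i.e., the “smeared” reconstruction.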

The most common choice of cost is the squared Euclidean norm, $c(x_0, x_1) = \|x_0 - x_1\|_2^2$, as in the previous example. In this case the result of Proposition 2 can be strengthened, as shown in the following example.

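Example 4.

Let $\tau$ be an $\mathbb{R}^n$-valued random variable with probability measure $\mathbb{P}$ with finite first and second moments, and let $c(x_0, x_1) = \|x_0 - x_1\|_2^2$. This gives

$$\int_{\mathbb{R}^n} \|x - \tau\|_2^2\, d\mathbb{P}(\tau) = \|x - \mathbb{E}[\tau]\|_2^2 + \mathbb{E}\big[\|\tau\|_2^2\big] - \|\mathbb{E}[\tau]\|_2^2,$$

which has a unique global minimum in $x = \mathbb{E}[\tau]$, and hence $h = \delta_{\mathbb{E}[\tau]}$, without any symmetry assumption on $\mathbb{P}$.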
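A quick Monte Carlo check of this example, assuming a skewed (exponential) distribution for $\tau \in \mathbb{R}^2$, for which the symmetry assumption of Proposition 2 fails. For the empirical distribution, the identity above holds exactly with the sample mean, so moving away from the mean by $d$ increases the expected cost by exactly $\|d\|_2^2$.

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical skewed, non-symmetric shift distribution in R^2.
tau = rng.exponential(scale=[1.0, 3.0], size=(10_000, 2))


def expected_sq_cost(x: np.ndarray) -> float:
    """Empirical E_tau ||x - tau||_2^2 for a point x in R^2."""
    return float(np.mean(np.sum((x - tau) ** 2, axis=1)))


mean = tau.mean(axis=0)
# Candidate points away from the mean; each should give a larger expected cost.
probes = [mean + d for d in ([0.1, 0.0], [0.0, -0.1], [0.3, 0.2])]
```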

4 Implementation and evaluation

We use the recently proposed learned primal-dual structure in adler2017learned for learning a reconstruction operator for solving the inverse problem in (1). In this algorithm, a sequence of small blocks works alternately in the data (dual) space $Y$ and the reconstruction (primal) space $X$, and the blocks are connected using the forward operator $\mathcal{A}$ and its adjoint $\mathcal{A}^*$. The algorithm works with any differentiable operator $\mathcal{A}$, but for simplicity we state the version for linear operators in Algorithm 2.

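Algorithm 2 Learned Primal-Dual reconstruction algorithm
1: Initialize $f_0 \in X^{N_{\text{primal}}}$ and $h_0 \in Y^{N_{\text{dual}}}$
2: for $i = 1, \ldots, I$ do
3: $\quad h_i = \Gamma_{\theta_i^d}\big(h_{i-1}, \mathcal{A}(f_{i-1}^{(2)}), g\big)$
4: $\quad f_i = \Lambda_{\theta_i^p}\big(f_{i-1}, \mathcal{A}^*(h_i^{(1)})\big)$
5: end for
6: return $\mathcal{A}^\dagger_\theta(g) = f_I^{(1)}$

Here $\Gamma_{\theta_i^d}$ and $\Lambda_{\theta_i^p}$ denote the learned dual and primal blocks, and the superscripts select channels of the primal and dual iterates.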
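To make the unrolled structure concrete, the sketch below implements the loop with hand-written stand-ins for the learned blocks, using a single primal and dual channel so that the channel superscripts disappear. The dual block `dual_block` (hypothetical) returns the data residual $Af - g$, and the primal block `primal_block` (hypothetical) takes a gradient step $f - s A^T h$. With these fixed choices the scheme reduces to Landweber iteration; in the learned method both blocks are instead small residual CNNs with trained parameters.

```python
from typing import Callable

import numpy as np


def learned_primal_dual(
    g: np.ndarray,
    A: np.ndarray,
    dual_block: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
    primal_block: Callable[[np.ndarray, np.ndarray], np.ndarray],
    n_iter: int,
) -> np.ndarray:
    """Unrolled primal-dual loop (single channel, linear forward operator A)."""
    f = np.zeros(A.shape[1])  # zero initial primal iterate
    h = np.zeros(A.shape[0])  # zero initial dual iterate
    for _ in range(n_iter):
        h = dual_block(h, A @ f, g)  # update in the data (dual) space
        f = primal_block(f, A.T @ h)  # update in the reconstruction (primal) space
    return f


STEP = 0.5  # hypothetical step size; Landweber needs STEP < 2 / ||A||_2^2


def dual_block(h: np.ndarray, Af: np.ndarray, g: np.ndarray) -> np.ndarray:
    """Hypothetical dual block: the data residual A f - g (ignores h)."""
    return Af - g


def primal_block(f: np.ndarray, ATh: np.ndarray) -> np.ndarray:
    """Hypothetical primal block: gradient step on 0.5 * ||A f - g||_2^2."""
    return f - STEP * ATh
```

Replacing these two functions with parametrized networks, and training their parameters end to end through the unrolled loop, is what turns this classical iteration into a learned scheme.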
Figure 1: Network architecture used to solve the inverse problem. Dual and primal iterates are in blue and red boxes, respectively. Several arrows pointing to the same box indicate concatenation. The initial values enter from the left, while the data is supplied to the dual iterates.

The method was implemented using ODL adler2017ODL , ASTRA vanaerle2016fast , and TensorFlow abadi2016tensorflow . We used the reference implementation with its default parameters for the number of blocks in the primal and dual space and for the number of primal and dual variables. Moreover, the blocks used a residual structure with three convolutional layers each, and PReLU nonlinearities were used. Thus, the unrolled network corresponds to a deep residual CNN, as shown in graphical format in fig. 1. We used zero initial values, $f_0 = 0$ and $h_0 = 0$.

We compare learned reconstruction operators of this form trained using the loss (3) and using the optimal transport loss (8). The evaluation is done on a problem similar to the evaluation problem in adler2017learned ; adler2017solving , i.e., on a problem in computed tomography. More specifically, training is done on data consisting of randomly generated circles on a 512 × 512 domain, and the forward operator $\mathcal{A}$ is the ray transform natterer2001mathematical . What makes this problem ill-posed is that the data acquisition is done from only 30 views with 727 parallel lines. Moreover, the noise in this set-up is two-fold: (i) the pairs of data sets and phantoms are not aligned, meaning that the data is computed from a phantom with a random change in position. This random change is independent for the different circles, and for each circle it is a shift that is uniformly distributed over a fixed range of pixels, in both the up-down and the left-right direction. (ii) 5% additive Gaussian noise was added to the data computed from the shifted phantom. For an example, see fig. 2.
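A minimal NumPy sketch of the misaligned phantom generation described above: random circles on a square grid, each shifted independently by an integer offset drawn uniformly in both directions. The image size, number of circles, radii, and maximum shift are hypothetical placeholders. The ray transform and the 5% noise step are omitted, since they require a tomography library such as ODL or ASTRA.

```python
import numpy as np


def random_circles(rng: np.random.Generator, n_pix: int = 128, n_circles: int = 5, max_shift: int = 0):
    """Return (phantom, shifted_phantom) with independent per-circle shifts."""
    yy, xx = np.mgrid[0:n_pix, 0:n_pix]
    phantom = np.zeros((n_pix, n_pix))
    shifted = np.zeros((n_pix, n_pix))
    for _ in range(n_circles):
        cy, cx = rng.uniform(0.2 * n_pix, 0.8 * n_pix, size=2)  # circle center
        r = rng.uniform(0.03 * n_pix, 0.1 * n_pix)  # circle radius
        dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)  # uniform integer shift
        # Overlapping circles add up in intensity.
        phantom += (yy - cy) ** 2 + (xx - cx) ** 2 <= r ** 2
        shifted += (yy - cy - dy) ** 2 + (xx - cx - dx) ** 2 <= r ** 2
    return phantom, shifted


# The training pair would be (data computed from `shifted`, `phantom`).
phantom, shifted = random_circles(np.random.default_rng(1), n_pix=64, n_circles=3, max_shift=4)
```

Because the shifts are integers and the circles stay inside the grid, each shifted circle covers exactly the same number of pixels as the original, so only the positions differ between the two images.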

(a) Phantom

(b) Translated phantom

(c) Data

Figure 2: Example of the data generation process used for training and validation, where (a) shows an example phantom, (b) is the phantom with a random translation, and (c) is the data (sinogram) corresponding to (b) with additive white noise on top. The pair $(g, f_{\text{true}})$ = ((c), (a)) is what is used in the training.

The optimal mass transport distance, computed with Sinkhorn iterations, was used as the loss function, where we used the transport cost

This cost was chosen since it heavily penalizes large movements while not diverging to infinity, which would cause numerical instabilities. Moreover, the cost is in fact a metric (see Lemma 6 in the appendix) and thus gives rise to a Wasserstein metric on the space of images, namely the optimal mass transport distance with this transport cost. Since this cost is translation invariant, the matrix-vector multiplications $K u_1$ and $K^T u_0$ can be done with the fast Fourier transform, as mentioned in Section 2.2, and this was implemented in TensorFlow. We used 10 Sinkhorn iterations with a fixed entropy regularization parameter $\epsilon$ to approximate the optimal mass transport. Automatic differentiation was used to back-propagate through the Sinkhorn iterations during training.

Since the optimal mass transport distance (6) is only finite for marginals $\mu_0$ and $\mu_1$ with the same total mass, in the training we normalize the output of the reconstruction $\mathcal{A}^\dagger_\theta(g)$ before computing the loss. This makes the loss invariant with respect to the total mass of the reconstruction, which is undesirable. To compensate for this, a small penalization of the error in total mass was added to the loss function.
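One way such a loss could be assembled is sketched below, assuming the reconstruction is rescaled to the ground truth's total mass before the Sinkhorn term and a squared penalty on the mass mismatch is added. The clipping of negative values, the penalty weight `mass_weight`, the value of `eps`, and the dense 1D cost are hypothetical choices. The sketch uses plain NumPy, whereas the loss described above was implemented in TensorFlow so that it could be differentiated automatically.

```python
import numpy as np


def sinkhorn_cost(mu0: np.ndarray, mu1: np.ndarray, C: np.ndarray, eps: float, n_iter: int = 10) -> float:
    """Transport cost of the Sinkhorn plan after a fixed number (>= 1) of iterations."""
    K = np.exp(-C / eps)
    u1 = np.ones_like(mu1)
    for _ in range(n_iter):
        u0 = mu0 / (K @ u1)
        u1 = mu1 / (K.T @ u0)
    M = u0[:, None] * K * u1[None, :]
    return float(np.sum(C * M))


def normalized_ot_loss(recon, target, C, eps=0.1, mass_weight=1e-3):
    """OT loss on the mass-normalized reconstruction plus a small total-mass penalty."""
    recon = np.maximum(recon, 0.0)  # hypothetical: OT needs nonnegative masses
    mass_r, mass_t = recon.sum(), target.sum()  # assumes mass_r > 0
    recon_normalized = recon * (mass_t / mass_r)  # same total mass as the target
    ot = sinkhorn_cost(recon_normalized, target, C, eps)
    return ot + mass_weight * (mass_r - mass_t) ** 2
```

Scaling the reconstruction by a constant leaves the transport term unchanged, so without the penalty the network would receive no signal about the total mass; the penalty restores it.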

The training also followed adler2017learned closely. In particular, we trained with mini-batches using the ADAM optimizer kingma2014adam with mostly default parameter values. The learning rate (step length) followed a cosine annealing schedule loshchilov2016sgdr . Moreover, to improve training stability, we performed gradient norm clipping pascanu2012understanding with norms limited to 1. The convolution parameters were initialized using Xavier initialization glorot2010understanding , and all biases were initialized to zero. The training took approximately 3 hours using a single Titan X GPU. The source code used to replicate these experiments is available online.

Results are presented in fig. 3. As can be seen, training with the $L^2$ loss “smears” the reconstruction to an extent where the shape is impossible to recover. On the other hand, the reconstruction using the Wasserstein loss retains the overall global shape of the object, although the relative and exact positions of the circles are not recovered.

(a) Phantom
(b) Translated phantom
(c) Mean squared error loss
(d) Optimal transport loss
Figure 3: In (a) we show the validation phantom, which was generated in the same way as the training data but not used in training; in (b) the translated phantom from which the validation data was computed; in (c) a reconstruction with the neural network trained using the mean squared error loss (3); and in (d) a reconstruction with the neural network trained using the optimal mass transport loss (8).

5 Conclusions and future work

In this work we have considered using Wasserstein loss to train a neural network for solving ill-posed inverse problems in imaging where data is not aligned with the ground truth. We give a theoretical motivation for why this should give better results compared to standard mean squared error loss, and demonstrate it on a problem in computed tomography. In the future, we hope that this method can be applied to other inverse problems and to other problems in imaging such as segmentation.

Appendix: Deferred definition and proofs

Proof of Proposition 1.

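To show that $h$ minimizes $\mathbb{E}_\tau\big[\|h - f_\tau\|_2^2\big]$, we expand the expression and use Fubini’s theorem to get

$$\mathbb{E}_\tau\big[\|h - f_\tau\|_2^2\big] = \int_{\mathbb{R}^n} \Big( h(x)^2 - 2 h(x) \int_{\mathbb{R}^n} f(x - \tau)\, d\mathbb{P}(\tau) + \int_{\mathbb{R}^n} f(x - \tau)^2\, d\mathbb{P}(\tau) \Big)\, dx.$$

Rearranging terms and using that $\int_{\mathbb{R}^n}\int_{\mathbb{R}^n} f(x - \tau)^2\, d\mathbb{P}(\tau)\, dx = \|f\|_2^2 \int_{\mathbb{R}^n} d\mathbb{P}(\tau) = \|f\|_2^2$, this can be written as

$$\mathbb{E}_\tau\big[\|h - f_\tau\|_2^2\big] = \int_{\mathbb{R}^n} \Big( h(x) - \int_{\mathbb{R}^n} f(x - \tau)\, d\mathbb{P}(\tau) \Big)^2 dx + \kappa,$$

where $\kappa$ is a constant. Using this it follows that the minimizing $h$ is of the form

$$h(x) = \int_{\mathbb{R}^n} f(x - \tau)\, d\mathbb{P}(\tau).$$

To see that $h \in L^2(\mathbb{R}^n)$ we note that, by using Fubini’s theorem, we have

$$\|h\|_2^2 = \int_{\mathbb{R}^n}\int_{\mathbb{R}^n}\int_{\mathbb{R}^n} f(x - \tau) f(x - \tau')\, d\mathbb{P}(\tau)\, d\mathbb{P}(\tau')\, dx \leq \int_{\mathbb{R}^n}\int_{\mathbb{R}^n}\int_{\mathbb{R}^n} \frac{f(x - \tau)^2 + f(x - \tau')^2}{2}\, dx\, d\mathbb{P}(\tau)\, d\mathbb{P}(\tau') = \|f\|_2^2,$$

where the inequality is the arithmetic-geometric mean inequality. This completes the proof. ∎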

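Definition 5.

Let $\psi\colon \mathbb{R}^n \to \mathbb{R}$. A subgradient to $\psi$ in a point $x \in \mathbb{R}^n$ is a vector $v \in \mathbb{R}^n$ so that

$$\psi(y) \geq \psi(x) + \langle v, y - x \rangle \quad \text{for all } y \in \mathbb{R}^n.$$

The set of all subgradients in a point $x$ is called the subdifferential of $\psi$ at $x$, and is denoted by $\partial\psi(x)$. This is a set-valued operator, and for any measure $\mathbb{P}$ on $\mathbb{R}^n$ we define $\partial\psi * \mathbb{P}$ to be the set-valued operator

$$(\partial\psi * \mathbb{P})(x) := \Big\{ \int_{\mathbb{R}^n} v(\tau)\, d\mathbb{P}(\tau) \;:\; v \text{ measurable},\ v(\tau) \in \partial\psi(x - \tau) \text{ for all } \tau \in \mathbb{R}^n \Big\}.$$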

Proof of Proposition 2.

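We consider finding the marginal $h$ that minimizes $\mathbb{E}_\tau[T(h, \delta_\tau)]$, where $\delta_\tau = \delta_0(\cdot - \tau)$. Without loss of generality we assume that $\tau$ is zero-mean, since otherwise we simply consider $\tau - \mathbb{E}[\tau]$, which is a zero-mean random variable. First we note that $T(h, \delta_\tau)$ is only finite for nonnegative measures $h$ with total mass 1, and hence $\mathbb{E}_\tau[T(h, \delta_\tau)]$ is only finite for such measures. Second, for such an $h$ we have

$$T(h, \delta_\tau) = \int_{\mathbb{R}^n} c(x, \tau)\, dh(x),$$

since one needs to transport all mass in $h$ into the point $\tau$ where $\delta_\tau$ has its mass. Using this and expanding the expression for the expectation gives that

$$\mathbb{E}_\tau\big[T(h, \delta_\tau)\big] = \int_{\mathbb{R}^n}\int_{\mathbb{R}^n} c(x, \tau)\, dh(x)\, d\mathbb{P}(\tau) = \int_{\mathbb{R}^n} \Big( \int_{\mathbb{R}^n} c(x, \tau)\, d\mathbb{P}(\tau) \Big)\, dh(x),$$

where we have used Fubini’s theorem in the last step. This completes the first half of the statement.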

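To prove the second half of the statement, note that the optimal $h$ has support only in the global minima of the function

$$x \mapsto \int_{\mathbb{R}^n} \psi(x - \tau)\, d\mathbb{P}(\tau) = (\psi * \mathbb{P})(x),$$

which by assumption exists and is finite. Now, since $\psi$ is convex, we have that

$$\psi(y - \tau) \geq \psi(x - \tau) + \langle v(\tau), y - x \rangle \quad \text{for all } v(\tau) \in \partial\psi(x - \tau), \tag{9}$$

and convolving this inequality with $\mathbb{P}$ gives the inequality

$$(\psi * \mathbb{P})(y) \geq (\psi * \mathbb{P})(x) + \Big\langle \int_{\mathbb{R}^n} v(\tau)\, d\mathbb{P}(\tau),\, y - x \Big\rangle, \tag{10}$$

where all terms exist and are bounded by assumption. This shows that

$$(\partial\psi * \mathbb{P})(x) \subseteq \partial(\psi * \mathbb{P})(x).$$

Now, since $\psi$ is symmetric, $\partial\psi$ is anti-symmetric, i.e., $\partial\psi(-x) = -\partial\psi(x)$, since

$$v \in \partial\psi(x) \iff \psi(y) \geq \psi(x) + \langle v, y - x \rangle \;\; \forall y \iff \psi(-y) \geq \psi(-x) + \langle -v, (-y) - (-x) \rangle \;\; \forall y \iff -v \in \partial\psi(-x).$$

Choosing a measurable selection $v(\tau) \in \partial\psi(-\tau)$ with $v(-\tau) = -v(\tau)$ gives

$$\partial(\psi * \mathbb{P})(0) \supseteq (\partial\psi * \mathbb{P})(0) \ni \int_{\mathbb{R}^n} v(\tau)\, d\mathbb{P}(\tau) = 0,$$

where the integral vanishes since $\mathbb{P}$ is symmetric and $v$ is anti-symmetric. Now, since $0 \in \partial(\psi * \mathbb{P})(0)$, we have that $0$ is a global minimizer of $\psi * \mathbb{P}$ (bauschke2011convex, Theorem 16.2), and thus one optimal solution to the problem is $h = \delta_0$. Now, if $\psi$ is strictly convex, the inequality (9) is strict for $y \neq x$, and thus (10) is also strict, which shows that the optimal solution is unique. ∎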

Lemma 6.

Let be a norm on . Then

is a metric on for .


Proof. It is easily seen that the function is symmetric, nonnegative, and equal to zero if and only if its two arguments coincide. Thus we only need to verify that the triangle inequality holds. To this end we note that if


for all , then by taking , , and using the triangle inequality for the norm we have that

Therefore we will show that (11) holds for all , and to do so we will

  1. show that if a function fulfills , , for all , then ,

  2. show that for the map fulfills the assumptions in (i) for any .

To show (i) we note that

where the inequality uses that for any since for all .

To show (ii), let and observe that . Differentiating twice gives

For we see that for all . Moreover, for we see that for all and for all if and only if . With the change of variable , we thus want to show that for all and all . To see this we note that and that

This shows (ii), and hence completes the proof. ∎
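The mechanism behind this proof, namely that a concave, nondecreasing profile $\varphi$ with $\varphi(0) = 0$ is subadditive and that composing it with a norm therefore preserves the triangle inequality, can be sanity-checked numerically. The sketch uses the hypothetical profile $\varphi(t) = 1 - e^{-t}$, chosen only because it has these properties; it is not necessarily the transport cost used in Section 4.

```python
import numpy as np


def phi(t: np.ndarray) -> np.ndarray:
    """Hypothetical concave, nondecreasing profile with phi(0) = 0."""
    return 1.0 - np.exp(-t)


def d(p: np.ndarray, q: np.ndarray) -> np.ndarray:
    """Candidate metric d(p, q) = phi(||p - q||_2), evaluated row-wise."""
    return phi(np.linalg.norm(p - q, axis=1))


rng = np.random.default_rng(0)
a, b = rng.uniform(0.0, 10.0, size=(2, 100_000))
# Subadditivity: phi(a + b) <= phi(a) + phi(b) for all a, b >= 0.
gap = phi(a) + phi(b) - phi(a + b)

# Triangle inequality d(x, z) <= d(x, y) + d(y, z) on random points in R^2.
x, y, z = rng.standard_normal((3, 100_000, 2))
tri_gap = d(x, y) + d(y, z) - d(x, z)
```

For this particular profile the subadditivity gap equals $(1 - e^{-a})(1 - e^{-b})$, which is nonnegative, so both gaps should be nonnegative up to rounding.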