Adaptive Gradient Balancing for Undersampled MRI Reconstruction and Image-to-Image Translation

Itzik Malkiel et al. · April 5, 2021

Recent accelerated MRI reconstruction models have used Deep Neural Networks (DNNs) to reconstruct relatively high-quality images from highly undersampled k-space data, enabling much faster MRI scanning. However, these techniques sometimes struggle to reconstruct sharp images that preserve fine detail while maintaining a natural appearance. In this work, we enhance the image quality by using a Conditional Wasserstein Generative Adversarial Network combined with a novel Adaptive Gradient Balancing (AGB) technique that automates the process of combining the adversarial and pixel-wise terms and streamlines hyperparameter tuning. In addition, we introduce a Densely Connected Iterative Network, which is an undersampled MRI reconstruction network that utilizes dense connections. In MRI, our method minimizes artifacts, while maintaining a high-quality reconstruction that produces sharper images than other techniques. To demonstrate the general nature of our method, it is further evaluated on a battery of image-to-image translation experiments, demonstrating an ability to recover from sub-optimal weighting in multi-term adversarial training.


1 Introduction

Magnetic resonance imaging (MRI) data acquisition is inherently slow, often exceeding 30 min. per exam. One way to accelerate MR scanning is to undersample k-space, i.e., to reduce the number of k-space traversals by a factor R, accelerating the scan proportionately. However, this violates the Nyquist criterion, resulting in aliasing artifacts in the zero-filled reconstructed image (Fig. 1). Over the last several decades, a number of methods have been used to overcome this problem, including Parallel Imaging (PI) [sodickson1997simultaneous, pruessmann1999sense, griswold2002generalized] and Compressed Sensing (CS) [lustig2007sparse].

Fig. 1: Fully sampled k-space multiplied by an acquisition sampling pattern with an acceleration factor of 4 results in highly undersampled k-space. Reconstructing the undersampled k-space by zero-filling generates a low-quality image with heavy artifacts that is completely non-diagnostic (right). A non-accelerated acquisition that uses fully sampled k-space results in a high-quality image (left). In this study, we focus on 2D data acquisition, which utilizes a 1D sampling pattern in the phase-encoding direction.

More recently, Deep Neural Networks (DNNs) have been used to push acceleration factors even higher [hammernik2018learning, schlemper2018deep, zhu2018image]. Among the most promising Deep Learning (DL) techniques, unrolled iterative networks (also called cascading networks) have emerged as a leading method [hammernik2018learning, schlemper2018deep]. Inspired by CS, this technique uses a DNN composed of a sequence of iterations that include data-consistency and regularization units. The data-consistency units utilize the acquired k-space lines as a prior that keeps the network from drifting away from the acquired data, and the regularization units are trained to regularize the reconstruction.

As with other image generation problems, using a naive pixel-wise distance to train DL-based undersampled MRI reconstruction models can result in image blurring and an unrealistic appearance. In a clinical setting, avoiding blurring can be crucial for proper diagnosis. Recently, Generative Adversarial Networks (GANs) have been used to promote the naturalness of MRI reconstructions [hammernik2018variational, mardani2019deep, yang2018dagan]. In our work, we harness the power of conditional Wasserstein GANs (cWGANs) to further improve image quality, and we alleviate the long training and experimentation process with a novel training technique for multi-term adversarial objectives.

The main contributions of this paper are as follows: (1) We adopt and evaluate a cWGAN method for undersampled MRI reconstruction, in which both the generator and the discriminator are conditioned on the acquired undersampled data. (2) We introduce a novel training algorithm, Adaptive Gradient Balancing (AGB), which balances the losses in multi-term adversarial objectives. (3) We provide an extensive comparison between different models and training techniques. In particular, we report results for six methods: an unrolled iterative network, a Variational Network [hammernik2018learning, hammernik2018variational], a CNN-Cascade [schlemper2018deep], a WGAN-based network, a cWGAN, and a cWGAN trained with our AGB. (4) We propose and evaluate a novel Densely Connected Iterative Network (DCI-Net) for undersampled MRI reconstruction, inspired by DenseNets [huang2017densely].

2 Related work

Recent methods in image-to-image translation adopt the idea of conditional GANs (cGANs) [mirza2014conditional], in which the generated data are conditioned on data fed to both the generator and the discriminator networks. Isola et al. [isola2017image] use a cGAN to learn a mapping from one image domain to another, such as converting a satellite photo into a map or a sketch into a photorealistic image. Wang et al. [wang2018high] extended this work to generating high-resolution images.

There has recently been an increasing number of DL-based accelerated MRI reconstruction models [schlemper2018deep, diamond2017unrolled, malkiel2018densely, malkiel2019leveraging, hammernik2018learning, hammernik2018variational, yang2018dagan, hardy2018residual, brada2019towards, rotman2021correcting, chen2018variable, chen2018improving]. Schlemper et al. [schlemper2018deep] used a cascade of convolutional neural networks (CNNs) employing data-consistency layers and optimized to minimize a pixel-wise distance. Diamond et al. [diamond2017unrolled] proposed a framework for integrating prior knowledge into DL architectures, called unrolled optimization with deep priors (ODP). They presented a general method for solving inverse imaging problems and demonstrated their approach on undersampled MRI reconstruction. Zhu et al. [zhu2018image] introduced automated transform by manifold approximation (AUTOMAP), a DNN that learns a mapping between the sensor and image domains. The AUTOMAP architecture is composed of fully connected layers followed by convolutional layers trained to optimize a pixel-wise objective.

Hammernik et al. [hammernik2018learning, hammernik2018variational] proposed a variational network (VN) for undersampled MRI reconstruction. In [hammernik2018learning] they presented a VN trained to minimize a pixel-wise loss, and in [hammernik2018variational] they proposed a GAN-based VN to reduce blurring and improve the perceptual appearance of the reconstructed images. Mardani et al. [mardani2019deep] proposed a GAN-based model that uses a deep residual network as a generator and a discriminator trained to optimize a mixture of pixel-wise and least squares GAN (LSGAN) losses [mao2017least]. In [yu2017deep, yang2018dagan], the deep de-aliasing GAN (DAGAN) was introduced, a GAN-based model trained to optimize a mixture of pixel-wise, perceptual, and GAN losses. In contrast to the original cGAN technique [mirza2014conditional], DAGAN uses an architecture that conditions only the generator input, not the discriminator, which in this context is similar to [mardani2019deep, hammernik2018variational]. The authors also reported [yang2018dagan] that a model using only pixel-wise and GAN (but not perceptual) losses generates unrealistic jagged artifacts. As motivation for our study, we experienced similar behavior when training our generator to minimize a weighted-sum objective with a weighting that appeared to favor the GAN term.

Our method differs from other DL based undersampled reconstruction studies in employing a conditional architecture that conditions both the generator and the discriminator, as we found that applying a conditional discriminator has a profound impact on model convergence and performance. In addition, our method is unique in its AGB training and DCI-Net generator architecture. The former improves performance, automates the process of multi-term adversarial loss combination, and streamlines hyperparameter tuning. The latter provides a simple yet effective technique for promoting feature propagation and reuse in iterative networks.

Very recently, the authors of [pnas2020instabilities] have reported that DL-based reconstruction methods can suffer from instabilities. Our study was motivated by similar concerns.

In [lucic2017gans], the authors show that most GAN techniques, including [arjovsky2017wasserstein, gulrajani2017improved, mao2017least, berthelot2017began, kodali2017convergence], can reach similar scores given sufficient exploration of hyperparameters and random initializations. They also show that most improvements do not arise from fundamental algorithmic changes to the underlying GAN technique. In our work, we propose a technique that can accelerate the process of hyperparameter tuning.

Fig. 2: The generator receives undersampled k-space data as input and generates a matched estimate of the fully sampled image. The discriminator learns to estimate the Wasserstein distance between "fake" pairs and "real" pairs.

3 Problem Formulation

Let $y \in \mathbb{C}^{W \times H}$ be the k-space signal acquired by an MRI scanner, where $W$ and $H$ are the width and height of the acquired signal. For a single-coil receiver, an image $x$ can be estimated by performing an inverse Fourier transform: $x = \mathcal{F}^{-1}(y)$.

In multi-coil MRI, an array of $M$ coils acquires different 2D k-space measurements of the same object,

$$y_i \in \mathbb{C}^{W \times H}, \quad i = 1, \dots, M. \tag{1}$$

Each coil $i$, positioned at a different location, is typically highly sensitive in one region of space. This position-dependent sensitivity can be represented by a complex-valued coil sensitivity map $S_i$ in real space,

$$y_i = \mathcal{F}(S_i \odot x). \tag{2}$$

During reconstruction, the images from each coil are combined into a fully sampled image

$$x = f(y_1, \dots, y_M), \tag{3}$$

where $f$ is a reconstruction function

$$f(y_1, \dots, y_M) = \sum_{i=1}^{M} S_i^* \odot \mathcal{F}^{-1}(y_i), \tag{4}$$

$S_i^*$ is the complex conjugate of the sensitivity map of coil $i$, and $\odot$ denotes the Hadamard product. To accelerate imaging, a binary sampling pattern $U$ is used to undersample each coil's k-space signal for each slice. The undersampled k-space signal, denoted by $\tilde{y}_i$, can be calculated by

$$\tilde{y}_i = U \odot y_i. \tag{5}$$

The undersampled zero-filled image can be written as

$$\tilde{x} = \sum_{i=1}^{M} S_i^* \odot \mathcal{F}^{-1}(\tilde{y}_i). \tag{6}$$

The learning task is to find a reconstruction function $G$ that minimizes an expected loss function $\mathcal{L}$ (Sec. 4.1) over a population of scans:

$$\min_{G} \; \mathbb{E}_{(x, \tilde{y})} \left[ \mathcal{L}\big(G(\tilde{y}_1, \dots, \tilde{y}_M), x\big) \right]. \tag{7}$$

For a given $G$ and $\tilde{y}_1, \dots, \tilde{y}_M$, we will denote the generated image by

$$\hat{x} = G(\tilde{y}_1, \dots, \tilde{y}_M). \tag{8}$$
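To make the notation concrete, the following is a minimal NumPy sketch of Eqs. (5) and (6), assuming precomputed sensitivity maps and a binary 1D phase-encode mask; the function and variable names are illustrative, not taken from the paper's released code.

```python
import numpy as np

def zero_filled_recon(y, S, U):
    """Zero-filled reconstruction (Eqs. 5-6).

    y: (M, W, H) complex fully sampled k-space, one slice per coil
    S: (M, W, H) complex coil sensitivity maps
    U: (W, H) binary sampling mask (1 = acquired line)
    Returns the zero-filled image x_tilde of shape (W, H).
    """
    y_u = U[None] * y                        # Eq. (5): undersample each coil
    imgs = np.fft.ifft2(y_u, norm="ortho")   # per-coil image-domain estimates
    return np.sum(np.conj(S) * imgs, axis=0) # Eq. (6): conjugate coil combination

# Toy usage with random data
M, W, H = 8, 64, 64
rng = np.random.default_rng(0)
y = rng.standard_normal((M, W, H)) + 1j * rng.standard_normal((M, W, H))
S = rng.standard_normal((M, W, H)) + 1j * rng.standard_normal((M, W, H))
U = (rng.random((1, H)) < 0.25).astype(float).repeat(W, axis=0)  # 1D phase-encode mask
x_tilde = zero_filled_recon(y, S, U)
```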

4 Methods

Our method learns a DL-based undersampled MRI reconstruction model from training samples, each consisting of fully sampled k-space data paired with matched undersampled k-space data. We propose a cGAN architecture that conditions the reconstruction on the zero-filled image. Specifically, our model is composed of a generator and a discriminator network. The generator reconstructs an image from undersampled k-space data. The discriminator receives a pair of input images: (i) either a ground-truth image or a generated ("fake") image reconstructed from undersampled k-space data, and (ii) a zero-filled image (see Fig. 2).

While it is possible to use a non-conditional GAN architecture in which the discriminator looks only at $x$ or $\hat{x}$, such a discriminator can only enforce general style properties learned from the distribution of fully sampled images, which might not necessarily match a specific image. For example, a discriminator that learns that fully sampled images possess a high sharpness level may encourage all generated images to maintain sharp features, although some features might be less sharp than others. During training, such a non-conditional discriminator can be fooled by fake images that are inconsistent with the prior information in $\tilde{x}$, and can thus generate gradients that do not necessarily correlate with the pixel-wise term. This property can hinder training convergence (see Sec. 6.3), degrade data fidelity, and may expose the generator to hallucinations. A conditional discriminator, by contrast, can enforce both a realistic appearance and spatial consistency by matching each generated image with its corresponding zero-filled image.

4.1 Objective

Following the success of the Wasserstein GAN (WGAN) [arjovsky2017wasserstein] and the framework proposed by Isola et al. [isola2017image], we adopt a conditional WGAN objective:

$$\mathcal{L}_{\mathrm{WGAN}} = \mathbb{E}_{(x, \tilde{x})}\left[ D(x, \tilde{x}) \right] - \mathbb{E}_{(\tilde{y}, \tilde{x})}\left[ D\big(G(\tilde{y}), \tilde{x}\big) \right], \tag{9}$$

where $G$ and $D$ are the generator and discriminator networks, respectively, $\tilde{y}$ is an undersampled k-space dataset, $x$ is a fully sampled image, and $\tilde{x}$ is a zero-filled image. In addition to the adversarial loss, we also add a pixel-wise Mean Square Error (MSE) loss

$$\mathcal{L}_{\mathrm{MSE}} = \big\| G(\tilde{y}) - x \big\|_2^2. \tag{10}$$

The full generator objective can be expressed as

$$\mathcal{L} = \mathcal{L}_{\mathrm{MSE}} + \frac{1}{\beta}\,\mathcal{L}_{\mathrm{WGAN}}, \tag{11}$$

where $\beta$ is an adaptive weight that changes during training (see next section).
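As a concrete reading of Eq. (11), a generator step under this objective might look like the following PyTorch-style sketch. The two-channel conditioning follows Sec. 4.3.2, `beta` is the adaptive weight described in the next section, and all tensor names are illustrative.

```python
import torch
import torch.nn.functional as F

def generator_loss(G, D, y_u, x_tilde, x, beta):
    """Bi-term objective of Eq. (11): MSE plus a (1/beta)-weighted WGAN term."""
    x_hat = G(y_u)                  # reconstruction from undersampled k-space
    mse = F.mse_loss(x_hat, x)      # Eq. (10): pixel-wise term
    # Conditional critic scores the (generated, zero-filled) pair; the generator
    # maximizes the critic score, hence the negative sign.
    adv = -D(torch.cat([x_hat, x_tilde], dim=1)).mean()
    return mse + adv / beta         # Eq. (11)
```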

4.2 Adaptive Gradient Balancing

In WGAN training, the discriminator network is used as a learned loss function, which dynamically changes during training, and thus may generate gradients with variable norm. To stabilize the WGAN training and to avoid drifting away from the ground-truth spatial information, we introduce the Adaptive Gradient Balancing (AGB) algorithm for continually balancing the gradients of the pixel-wise and WGAN loss functions.

In order to keep the gradients of both terms at the same level, and since the WGAN gradients tend to vary, we choose to adaptively upper-bound the WGAN gradients. Specifically, we define $\beta$ to be an adaptive weight used to bound the WGAN loss gradients. We calculate two moving-average variables, $v_{\mathrm{adv}}$ and $v_{\mathrm{mse}}$, corresponding to the WGAN loss and the pixel-wise loss, respectively. These moving averages capture the standard deviation (SD) of the gradients calculated at every backward step on the generated image, with respect to each of the losses separately. At every training step, if $v_{\mathrm{adv}} > T \cdot v_{\mathrm{mse}}$ for a predefined threshold $T$, we update $\beta \leftarrow \gamma \beta$ and $v_{\mathrm{adv}} \leftarrow v_{\mathrm{adv}} / \gamma$, where $\gamma$ is a predefined rate. During training, we divide the WGAN loss by $\beta$ to carefully decay the WGAN loss gradients to roughly the same order of magnitude as those of the pixel-wise loss. Moreover, in order to keep a reasonable ratio between the generator's WGAN loss gradients and the discriminator loss gradients, we also decay the discriminator loss by the same factor (see Alg. 1).

Fig. 3: DCI-Net (A) consists of N unrolled iterative blocks, each with dense skip-layer connections (curved arrows) to subsequent blocks. Each iterative block (B) consists of data-consistency (D) and regularization (C) units. The regularization unit operates on all G+1 connections, while the data-consistency unit operates only on the direct connection. Images at all stages are complex; in practice they are handled by creating real and imaginary channels (not shown).
Algorithm 1: AGB training of WGANs with a multi-term loss.

v_adv ← 0; v_mse ← 0; β ← β₀
for number of training iterations do
  for t = 1, …, n_critic do
    Sample a minibatch {(x, ỹ)}
    w ← w + Adam(w, ∇_w L_D / β)
    w ← clip(w, −c, c)
  end for
  Sample a minibatch {(x, ỹ)}
  θ ← θ − Adam(θ, ∇_θ (L_MSE + L_WGAN / β))
  Update the moving averages v_adv and v_mse
  if v_adv > T · v_mse then
    β ← γβ; v_adv ← v_adv / γ
  end if
end for

Here β₀, γ, T, the moving-average decay α, the number of critic steps n_critic, the clipping value c, and the learning rate are hyperparameters. D and G are the discriminator and generator networks with weights w and θ, respectively.

By extending WGAN training to adaptively balance a multi-term loss objective, our AGB algorithm ensures one invariant during the entire training: the SD of the WGAN loss gradients is upper-bounded by a factor of the SD of the pixel-wise gradients. This invariant maintains the effectiveness of both loss terms over the course of training.
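The balancing rule itself is compact. The sketch below is a schematic reading of Alg. 1, not the released implementation: it tracks moving estimates of each term's gradient SD on the generated image and enlarges β whenever the WGAN gradients exceed the allowed ratio. `thresh`, `gamma`, and `alpha` correspond to T, γ, and α above; the default values for `gamma` and `alpha` here are illustrative assumptions.

```python
import torch

class AGB:
    """Adaptive Gradient Balancing (schematic, per Alg. 1)."""
    def __init__(self, beta0=10.0, thresh=10.0, gamma=2.0, alpha=0.99):
        # gamma and alpha values are illustrative placeholders
        self.beta, self.thresh, self.gamma, self.alpha = beta0, thresh, gamma, alpha
        self.v_adv, self.v_mse = 0.0, 0.0

    def update(self, x_hat, loss_adv, loss_mse):
        # SD of each loss's gradient w.r.t. the generated image x_hat
        # (x_hat must still be attached to the autograd graph)
        g_adv, = torch.autograd.grad(loss_adv, x_hat, retain_graph=True)
        g_mse, = torch.autograd.grad(loss_mse, x_hat, retain_graph=True)
        self.v_adv = self.alpha * self.v_adv + (1 - self.alpha) * g_adv.std().item()
        self.v_mse = self.alpha * self.v_mse + (1 - self.alpha) * g_mse.std().item()
        if self.v_adv > self.thresh * self.v_mse:  # WGAN gradients too dominant
            self.beta *= self.gamma                # decay the WGAN term further
            self.v_adv /= self.gamma               # keep the average consistent
        return self.beta
```

A typical use would call `update` once per generator step and divide both the generator's and the discriminator's WGAN losses by the returned β, as Alg. 1 prescribes.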

4.2.1 Hyperparameters

AGB utilizes the following hyperparameters: $\beta_0$ (the initial value of the weight that decays the GAN gradients), $\alpha$ (the decay weight used to calculate the moving average of the gradients of each loss), $T$ (the maximal allowed ratio between the GAN loss gradients and the pixel-wise loss gradients), $\gamma$ (the rate used to increase the $\beta$ value), the learning rate, and the clipping value (for the WGAN training). The same default parameters are used in all experiments.

$T$ is the only parameter for which we applied a hyperparameter search during the development of the method (beyond this early research phase, we did not employ any hyperparameter search in any of the experiments). We explored the values 1, 10, 100, and 1000, and found that setting $T = 10$, which upper-bounds the GAN gradients by a factor of 10 relative to the pixel-wise loss gradients, yields the best performance. We therefore set $T = 10$ in all of our experiments.

In our experiments, the learning rate was chosen to be the same as the one used in the original Pix2pix work [isola2017image], or the value used to train the MRI reconstruction network when optimizing only the pixel-wise loss. The clipping value is the one recommended for WGAN [arjovsky2017wasserstein].

Fig. 4: Moving averages calculated on the norm of the GAN loss gradients versus training step, for Pix2pix and Pix2pix-AGB trained to generate (a) facade images from annotations, (b) urban street scenes from annotations, and (c) aerial photographs from map images. The gradients are computed with respect to the pixels of the generated image. The increasing norm may indicate the growing dominance of the GAN loss during training. The norm of the pixel-wise loss gradients (not shown) remains fixed, since Pix2pix utilizes an L1 loss. Both models were trained for 1000 epochs (instead of 200, as suggested in the Pix2pix work).

4.3 Network Architectures

4.3.1 Generator

We propose a new generator architecture (Fig. 3), called the Densely Connected Iterative Network (DCI-Net), based on the iterative CNN [hammernik2018learning, schlemper2018deep]. The key new developments are (1) dense connections [huang2017densely] across all iterations, which strengthen feature propagation and make the network more robust, and (2) a relatively deep architecture of over 60 convolutional layers, which brings increased capacity. Our generator receives M coils of undersampled k-space data and uses N = 20 iterations, each of which includes a data-consistency unit and a regularization unit (Fig. 3B). Dense skip-layer connections between the output of each iteration and the following G iterations, where typically G = 5, are represented as curved lines in Fig. 3A. The input to each block is thus composed of skip and direct connections concatenated to form a G+1-channel complex image.
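To make the dense wiring concrete, here is a minimal skeleton of the iteration loop, with hypothetical `reg_units` and `dc_units` modules. It shows only how the last G outputs are concatenated into a G+1-channel input; merging the two branches by summation is one plausible reading of the "in parallel" description in Sec. 7, and the complex images would in practice be split into real/imaginary channels.

```python
import torch
import torch.nn as nn

class DCINetSketch(nn.Module):
    """Skeleton of the densely connected iterative generator (illustrative)."""
    def __init__(self, reg_units, dc_units, growth=5):
        super().__init__()
        self.reg_units = nn.ModuleList(reg_units)  # one regularization unit per iteration
        self.dc_units = nn.ModuleList(dc_units)    # one data-consistency unit per iteration
        self.growth = growth                       # G: how many past outputs to reuse

    def forward(self, x0):
        outputs = [x0]
        for reg, dc in zip(self.reg_units, self.dc_units):
            # Direct connection plus up to G dense skip connections, as channels;
            # each reg unit is assumed built with the matching input width and to
            # map its input back to the image's channel count.
            skips = outputs[-(self.growth + 1):]
            z = torch.cat(skips, dim=1)
            x = reg(z) + dc(outputs[-1])  # regularize all connections, enforce
            outputs.append(x)             # data consistency on the direct path
        return outputs[-1]
```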

Data-consistency unit. Each data-consistency (DC) unit (Fig. 3D) shades the input image with each coil sensitivity map, transforms the resulting images to k-space, imposes the sampling mask, computes the difference relative to the acquired k-space, and returns the result to the image domain, multiplied by a learned weight (see Fig. 3D). These operations can be expressed as

$$x_{n+1} = x_n - \lambda_n \sum_{i=1}^{M} S_i^* \odot \mathcal{F}^{-1}\big( U \odot \mathcal{F}(S_i \odot x_n) - \tilde{y}_i \big), \tag{12}$$

where $x_n$ and $\lambda_n$ are the input image and the learned weight of the $n$th iteration, respectively. By utilizing the acquired k-space data as a prior, the data-consistency units, embedded as operations inside the network, keep the network from drifting away from the acquired data. For this purpose, the undersampled k-space data are also input directly into each iterative block of the network (Figs. 3A, B).
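The operation in Eq. (12) translates directly into array code. Below is a NumPy sketch with illustrative names; `lam` plays the role of the learned weight λ_n of the iteration.

```python
import numpy as np

def data_consistency(x, y_u, S, U, lam):
    """One DC unit (Eq. 12): pull the estimate back toward acquired k-space.

    x: (W, H) current image estimate; y_u: (M, W, H) acquired undersampled k-space
    S: (M, W, H) sensitivity maps; U: (W, H) binary mask; lam: learned scalar weight
    """
    k = np.fft.fft2(S * x[None], norm="ortho")  # shade by coil maps, go to k-space
    resid = U[None] * k - y_u                   # mismatch on acquired locations only
    corr = np.sum(np.conj(S) * np.fft.ifft2(resid, norm="ortho"), axis=0)
    return x - lam * corr                       # weighted correction, back in image domain
```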

Fig. 5: Representative samples from the Facade, Cityscapes, and Maps test sets. The same source image (center) was used by both models to generate the images on either side. Pix2pix with longer training of 1000 epochs (left) introduces visual artifacts. Pix2pix-AGB trained for the same number of epochs (right) yields higher image quality.

Fig. 6: A representative sample from the Facade test set. The same facade annotation image (center) was used by both models to generate the facade images on either side. Pix2pix-Unbalanced (left) introduces visual artifacts because, during training, the GAN loss term dominates the pixel-wise loss. Pix2pix-Unbalanced-AGB (right) mitigates these artifacts by adaptively updating the β term, which upper-bounds the GAN loss gradients by a factor of the pixel-wise loss gradients.

Regularization unit. Each unit (Fig. 3C) has three sequences consisting of 5x5 convolution, bias, and leaky ReLU [xu2015empirical] layers. The output of the final iteration (Fig. 3A) is (1) compared to the fully sampled reference image to compute a pixel-wise loss, using MSE, and (2) paired with its corresponding zero-filled image and fed into the discriminator network to evaluate the WGAN loss [arjovsky2017wasserstein].

4.3.2 Discriminator

For our discriminator, we use a convolutional "PatchGAN" [li2016precomputed]. The discriminator receives a pair consisting of (1) the zero-filled image $\tilde{x}$ and (2) either the fully sampled image $x$ or the generated image $\hat{x}$, concatenated as two channels, and is able to penalize structure at the scale of image patches in both channels. The architecture incorporates four convolutional layers with a stride of 2, each followed by batch normalization and LeakyReLU. The last convolutional layer is flattened and fed into a linear layer, in which each input value corresponds to a different patch in the input channels. The linear layer outputs a single value, used to calculate the discriminator's WGAN loss.

Importantly, our discriminator receives pairs of a zero-filled image and a generated/real image. Pairing these images allows the discriminator to spatially match features of the zero-filled image against those of the generated/real image. In other words, such a discriminator cannot be fooled by fake images that are inconsistent with their zero-filled image.
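The conditioning is just channel concatenation. The sketch below shows a PatchGAN-style critic over the two-channel pair; the layer widths are assumptions, while the four stride-2 convolutions with batch normalization and LeakyReLU, flattened into a single linear output, follow the description above.

```python
import torch
import torch.nn as nn

class PairCritic(nn.Module):
    """Conditional PatchGAN-style WGAN critic over (image, zero-filled) pairs."""
    def __init__(self, ch=64):  # ch is an illustrative base width
        super().__init__()
        layers, c_in = [], 2  # 2 channels: candidate image + zero-filled image
        for c_out in (ch, ch * 2, ch * 4, ch * 8):
            layers += [nn.Conv2d(c_in, c_out, 4, stride=2, padding=1),
                       nn.BatchNorm2d(c_out), nn.LeakyReLU(0.2)]
            c_in = c_out
        self.body = nn.Sequential(*layers)
        self.head = nn.LazyLinear(1)  # one score aggregated over all patches

    def forward(self, x_candidate, x_zero_filled):
        pair = torch.cat([x_candidate, x_zero_filled], dim=1)
        return self.head(self.body(pair).flatten(1))
```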

Fig. 7: Representative samples from the Cityscapes [cityscapes] test set. The same annotated images (center) were used by both models to generate the urban street images on either side. Pix2pix-Unbalanced (left) introduces visual artifacts and fails to generate realistic car instances. Pix2pix-Unbalanced-AGB (right) yields higher-quality images.
Fig. 8: Representative samples from the maps test set. The same map images (center) were used by both models to generate the aerial photographs on either side. Pix2pix-Unbalanced (left) introduces visual artifacts. Pix2pix-Unbalanced-AGB (right) yields higher-quality images.
Fig. 9: β value versus epoch, for the Pix2pix-Unbalanced-AGB model trained to generate (a) facade images from annotations, (b) urban street scenes from annotations, and (c) aerial photographs from map images.

Fig. 10: L1 and FID per epoch, reported on the validation set, for both unbalanced models trained on the (a) Facade, (b) Cityscapes, and (c) Maps datasets.

5 Pix2Pix Experiments

To show the generality of AGB beyond the MRI setting for which it was developed, we first consider the optimization of a multi-term adversarial loss function in the Pix2pix model [isola2017image]. This model transforms images from a source domain to a target domain. The Pix2pix objective incorporates a bi-term adversarial loss function comprising a pixel-wise loss (L1) and a GAN loss. Pix2pix balances the two loss terms by means of weights crafted through a hyperparameter search. AGB does not require this tuning; it streamlines the development process by eliminating the exhaustive search otherwise required to carefully weight multi-loss optimization terms.

The Pix2pix experiments investigating the balancing of the GAN and pixel-wise losses are conducted on three datasets: Facade [Tylecek13], aerial maps [isola2017image], and Cityscapes [cityscapes]. We trained two models for each of the three datasets: the first uses the original Pix2pix scheme, and the second employs the Pix2pix architecture with AGB training (denoted Pix2pix-AGB). The two Facade models were trained to generate facade images [Tylecek13] from facade annotations. The aerial-maps models were trained to generate aerial photographs from their matched map images. The Cityscapes models were trained to generate urban street scenes from their matched semantic annotations.

In Fig. 4, we show moving averages calculated on the norms of the GAN gradients for the Pix2pix and Pix2pix-AGB models in all three experiments. In these experiments, we do not employ the early stopping suggested in [isola2017image]. As can be seen, the gradients of Pix2pix increase constantly, while those of Pix2pix-AGB remain bounded and range over a much smaller interval.

These experiments indicate that, without early stopping, Pix2pix can become unstable. Indeed, we verified that longer training (1000 epochs instead of the originally suggested 200) introduces visual artifacts. The Pix2pix-AGB model trained for 1000 epochs, on the other hand, yields higher-quality images (see Fig. 5).

We further demonstrate the ability of AGB to improve convergence and mitigate artifacts in a Pix2pix model trained with non-optimal weighting between its loss terms. To this end, we trained two additional Pix2pix models for each of the three datasets above, using non-optimal loss weighting: we multiply the Pix2pix GAN loss term by 100 (the original factor was 1) and keep the original weight (100) for the L1 loss. The first model trained for each dataset is a Pix2pix with conventional training that minimizes the unbalanced objective above (denoted Pix2pix-Unbalanced). The second is a Pix2pix model trained with AGB (denoted Pix2pix-Unbalanced-AGB). Following the Pix2pix work [isola2017image], all models were trained for 200 epochs.

Figure 6 presents two representative images from the test set of the unbalanced facade models. It can be seen that Pix2pix-Unbalanced introduces visual artifacts, while Pix2pix-Unbalanced-AGB yields higher-quality images.

Figure 7 presents representative images from the test set of the unbalanced Cityscapes models. Pix2pix-Unbalanced yields urban street images with visual artifacts, while Pix2pix-Unbalanced-AGB yields higher-quality images. In particular, the Pix2pix-Unbalanced model seems to fail completely at generating car instances. For example, in the first row of the figure, the left side of the source image indicates a sequence of car instances; the corresponding image generated by the Pix2pix-Unbalanced model introduces visual artifacts at the cars' location, while the Pix2pix-Unbalanced-AGB model yields relatively realistic car instances.

Figure 8 presents representative samples from the test set of the unbalanced maps models. As can be seen, the Pix2pix-Unbalanced model introduces visual artifacts while the Pix2pix-Unbalanced-AGB yields substantially higher-quality images.

Figure 9 presents the β value versus epoch for the three Pix2pix-Unbalanced-AGB models above. In the Facade experiment, β first rises to a value of 55 and then stabilizes at 75. Perhaps unsurprisingly, the latter value is close to the non-optimal multiplicative weight of the GAN loss (100). In the Cityscapes experiment, β increases throughout training, up to a value of 190. In the Maps experiment, β first rises to 20 and then increases steadily up to a value of 165. Notably, the trends of all β values correlate well with the GAN gradients of the matched Pix2pix models (see Fig. 4). In the original Pix2pix work, the weighting was manually crafted after a hyperparameter search.

Figure 10 shows the FID and L1 on a validation set, versus epoch, for the two unbalanced models of each dataset. As can be seen, Pix2pix-Unbalanced-AGB converges faster than Pix2pix-Unbalanced, yielding an improved L1 score along with better or similar FID performance.

TABLE I: Comparison of our method with zero-filled images (ZF), reconstruction using wavelets or TV [lustig2007sparse], PI [ARC_beatty2007method], VN [hammernik2018learning], and CNN-Cascade [schlemper2018deep]. NMSE is w.r.t. the fully sampled image.

Images              | NMSE | FID
ZF                  | 115  | 173.0
Wavelets            | 18.7 | 138.4
TV                  | 14.1 | 117.0
PI                  | 18.9 | 109.0
Variational Network | 6.70 | 23.3
CNN-Cascade         | 7.22 | 22.7
cWGAN-AGB           | 3.39 | 18.7

TABLE II: Ablation analysis. First section: ablation for the DCI-Net generator, where all models were trained to minimize an MSE loss alone, with no GAN loss. Second section: ablation on the GAN technique. All WGAN variants use the best-performing architecture from the first section, i.e., a DCI-Net with 20 iterations ("20I"), growth rate G = 5 ("5G"), and 40 kernels per convolution ("40K").

Experiment                     | NMSE | FID
DCI-Net (5I-5G-160K)           | 3.67 | 20.2
DCI-Net (20I-1G-40K, no dense) | 3.46 | 19.3
DCI-Net (20I-5G-40K)           | 3.24 | 19.4
WGAN                           | 3.71 | 19.7
cWGAN                          | 3.61 | 19.9
cWGAN-AGB (proposed)           | 3.39 | 18.7

TABLE III: Mean sharpness, SNR, contrast, artifact, and overall IQ scores for our proposed cWGAN-AGB, a baseline DCI-Net (which optimizes an MSE loss alone, without any GAN), and the fully sampled images. Scores of 1 to 5 indicate poor to excellent.

Images               | Sharpness | SNR | Contrast | Artifacts | Overall IQ
Fully sampled        | 5.0       | 3.3 | 4.0      | 4.0       | 4.5
DCI-Net (20I-5G-40K) | 2.3       | 4.5 | 4.0      | 3.8       | 2.3
cWGAN-AGB (proposed) | 3.8       | 3.8 | 4.0      | 3.8       | 3.5

6 MRI Experiments

6.1 Dataset

Fully sampled brain MRI datasets (T1, T2, T1-FLAIR, and T2-FLAIR in axial, coronal, and sagittal orientations) were acquired with various k-space data sizes and various numbers of coils, along with sensitivity maps estimated from separate calibration scans. In total, 2267 slices were acquired, of which 1901 were used to train the networks, 151 for validation, and 215 for testing. During training, we also applied random horizontal flips and rotations (bounded to 20 degrees) to augment the training set. The data were retrospectively downsampled using 12 central lines of k-space and a 1D variable-density sampling pattern outside the central region, resulting in a net undersampling factor of 4. As evaluation metrics, we compute both the normalized mean square error (NMSE) and the Fréchet Inception Distance (FID) [heusel2017gans], a similarity measure between two datasets that correlates well with human judgment of visual quality and is most often used to evaluate the quality of images generated by GANs (see Sec. 6.4).
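The retrospective undersampling described above can be sketched as follows. The quadratic density profile and the helper name are illustrative assumptions, since the exact density function is not specified here.

```python
import numpy as np

def variable_density_mask(num_lines, center_lines=12, accel=4, seed=0):
    """1D phase-encode mask: fully sampled center plus variable-density outer lines."""
    rng = np.random.default_rng(seed)
    mask = np.zeros(num_lines, dtype=bool)
    c0 = num_lines // 2 - center_lines // 2
    mask[c0:c0 + center_lines] = True             # 12 central k-space lines
    # Higher sampling probability near the center (illustrative density)
    d = np.abs(np.arange(num_lines) - num_lines / 2) / (num_lines / 2)
    prob = (1 - d) ** 2
    n_target = num_lines // accel - center_lines  # remaining budget for R-fold acceleration
    outer = np.flatnonzero(~mask)
    p = prob[outer] / prob[outer].sum()
    mask[rng.choice(outer, size=max(n_target, 0), replace=False, p=p)] = True
    return mask
```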

6.2 Comparison with Baseline Methods

On the test set, we compare our cWGAN-AGB to CS methods that use wavelets or Total Variation (TV) [lustig2007sparse] and to an autocalibrated PI method [ARC_beatty2007method]. We also compare to the CNN-Cascade [schlemper2018deep] and the Variational Network (VN) [hammernik2018learning], both trained on our dataset with our sampling pattern. As can be seen in Table I, our proposed model produces significantly more accurate reconstructions than the other methods, as measured by both the NMSE and FID metrics.

For the sake of completeness, we provide a qualitative comparison of our proposed model to compressed sensing methods using wavelets or TV [lustig2007sparse] and to PI [ARC_beatty2007method], as shown in Fig. 11. Our proposed method produces higher-quality images than traditional CS and PI methods, both in terms of perceptual quality and reconstruction error.

Fig. 11: Comparison with CS and PI methods. Left to right: fully sampled, cWGAN-AGB, wavelets, Total Variation, PI, zero-filled.

6.3 Comparing GAN Convergence

To show the effectiveness of our method, we compared the convergence of our cWGAN-AGB model to that of cWGAN and WGAN trained without AGB. During the training phase, FID and NMSE were evaluated on a held-out validation set at each epoch. Although WGAN suffers from a slow start (Fig. 12), it eventually performs better on FID than cWGAN, but worse on NMSE (which may indicate a more realistic appearance at the cost of decreased fidelity). Our model converges better, with both scores decreasing substantially faster than for the other techniques. For more experiments on AGB training, see the Pix2pix experiments (Sec. 5), in which AGB is applied to the Pix2pix model and evaluated on three additional datasets.

6.4 Fréchet Inception Distance (FID)

Fréchet Inception Distance (FID) [heusel2017gans] is a similarity measure between two datasets that correlates well with human judgment of visual quality and is most often used to evaluate the quality of images generated by GANs. We use FID as a quality metric to evaluate the similarity between the set of our generated images and the corresponding fully sampled images. FID relies on the Fréchet distance between two Gaussians, each fitted to feature vectors taken from a pre-trained Inception network, one for the generated images and one for the fully sampled images:

$$\mathrm{FID} = \big\| \mu_r - \mu_g \big\|_2^2 + \mathrm{Tr}\Big( \Sigma_r + \Sigma_g - 2\big(\Sigma_r \Sigma_g\big)^{1/2} \Big), \tag{13}$$

where $X_r$ and $X_g$ are sets of feature vectors extracted from an Inception network for the fully sampled and the generated images, respectively, and $(\mu_r, \Sigma_r)$ and $(\mu_g, \Sigma_g)$ are the means and covariances of the Gaussians fitted to $X_r$ and $X_g$, respectively.
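Eq. (13) can be computed directly from the two feature sets. Below is a minimal sketch assuming the Inception features have already been extracted; it uses SciPy's matrix square root.

```python
import numpy as np
from scipy.linalg import sqrtm

def fid(feats_real, feats_gen):
    """Fréchet Inception Distance (Eq. 13) between two feature sets of shape (N, D)."""
    mu_r, mu_g = feats_real.mean(0), feats_gen.mean(0)
    cov_r = np.cov(feats_real, rowvar=False)
    cov_g = np.cov(feats_gen, rowvar=False)
    covmean = sqrtm(cov_r @ cov_g)
    if np.iscomplexobj(covmean):  # numerical noise can leave tiny imaginary parts
        covmean = covmean.real
    diff = mu_r - mu_g
    return float(diff @ diff + np.trace(cov_r + cov_g - 2 * covmean))
```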

Fig. 12: FID and NMSE during training, as evaluated on the validation set. Results are shown for WGAN, a vanilla cWGAN, and our cWGAN-AGB.

6.5 Ablation Analysis

We compare, in Table II, our cWGAN-AGB with three other models: (1) cWGAN, (2) WGAN, and (3) a baseline DCI-Net for undersampled MRI reconstruction that optimizes an MSE loss alone, without any GAN. All models were evaluated with NMSE and FID on the test set. We found that (a) cWGAN and cWGAN-AGB have better SNR than WGAN; (b) cWGAN-AGB converges much faster than cWGAN or WGAN (Fig. 12), has fewer artifacts, and performs better on both FID and NMSE (Table II); and (c) although cWGAN-AGB has higher NMSE than the baseline DCI-Net, it performs better on FID and yields sharper images with more fine detail while maintaining a natural image texture. A representative reconstruction can be seen in Fig. 13, where both the WGAN and cWGAN models suffer from local inconsistencies with the ground-truth image (red arrows). In the same area, our proposed method exhibits a more accurate reconstruction. In addition, Fig. 13 shows a representative reconstruction from the baseline DCI-Net (trained without a GAN loss), which exhibits some image blurring.

Fig. 13: A representative example with regions of interest showing the reconstruction of all models side-by-side (top row), along with the ground-truth fully-sampled image and the absolute difference images between ground-truth and each GAN based reconstruction (bottom row). cWGAN and cWGAN-AGB have better SNR than WGAN. cWGAN-AGB yields sharper images with fewer artifacts and more fine detail, while maintaining a more natural appearance. The baseline DCI-Net sometimes exhibits some blurring.

In Table II, we also compare to baseline architectures, demonstrating the effectiveness of our key architectural contributions: (1) dense connections across all iterations, which strengthen feature propagation and make the network more robust, and (2) a relatively deep architecture of 20 iterations, comprising more than 60 convolutional layers, which brings increased capacity. We compared our generator to (1) an unrolled iterative network similar to DCI-Net but without dense connections and (2) a 5-iteration DCI-Net with a similar number of learned parameters. Employing dense connections significantly improved accuracy, and the deeper network produced 12% lower mean NMSE than a shallower network with a similar number of learned parameters.

6.6 Visual Scoring

To assess the perceptual quality of the resulting images we report a visual scoring conducted by four experienced MRI scientists. The same test set was ranked for cWGAN-AGB, the baseline DCI-Net and for the fully sampled images. The scoring was performed blindly and the images were randomly shuffled. The studies were taken from a cohort of seven healthy volunteers. Each study contained a full brain scan comprising 25-43 slices. For each study, image sharpness, signal-to-noise ratio (SNR), contrast, artifacts and overall image quality (IQ) were reported.

Here images were rated on a scale of 1 to 5, where the numbers denote 1: not diagnostic, 2: limited, 3: diagnostic, 4: good, 5: excellent. Table III shows that cWGAN-AGB produced significantly sharper images than the baseline DCI-Net, at the cost of somewhat weaker denoising of the images.

6.7 Implementation Details

The Adam optimizer is used with the same learning rate for both the generator and discriminator networks, with the momentum parameter β₁ = 0.9. Training is performed with the TensorFlow interface on a GeForce GTX TITAN X GPU with 12 GB RAM. For the proposed model with AGB training, β is initialized to 10, without any hyperparameter exploration, and was found to increase in multiple steps during training to a value of 370 (see Fig. 14). For traditional GAN training, the weight is initialized to 100, after a hyperparameter search over the values 10, 100, and 1000. All models were trained for 600 epochs over two weeks, and the inference run time is 100 ms per slice on a single GPU. Our code can be found at https://github.com/ItzikMalkiel/AGB.

Fig. 14: β value per epoch for the cWGAN-AGB model.

7 Discussion of MRI Experiments

The DCI-Net used as our generator is similar to the VN [hammernik2018learning] and the CNN-Cascade [schlemper2018deep] in that it is an unrolled optimization. While the CNN-Cascade alternates data-consistency and CNN blocks, the DCI-Net applies them in parallel within each iteration, in similar fashion to the VN. However, while the VN learns, in addition to the convolutional filters, a set of nonlinear activation functions for each iteration, the DCI-Net employs leaky ReLU activations. Perhaps the biggest differences from these networks are our use of dense skip connections, combined with a relatively deep architecture.

One practical consideration is sensitivity to the k-space undersampling pattern; specifically, whether a network trained with one undersampling pattern and/or R value will face problems when running inference with another. While this could be addressed by training a separate network for each anticipated combination of undersampling pattern and R value, we have found that training a single network over a variety of patterns (including variable-density and uniform) and R values (e.g., ranging from 2 to 6) can yield average NMSE values on test data within 6% of those obtained from more specialized networks.

While the upper part of Table II shows that both the NMSE and FID scores improved for the deeper network, FID (unlike NMSE) did not likewise improve when dense connections were added. This can be attributed to (1) the ability of FID to capture the perceptual quality of images, including blur level, and (2) the tendency of dense connections to improve accuracy rather than mitigate blurring.

While the methods described here were developed in the context of Cartesian k-space sampling, they could be readily extended to non-Cartesian trajectories such as interleaved spiral or radial. This would involve changing the generator’s data-consistency block (Fig. 3D) to include an inverse gridding step (interpolating from the Cartesian grid onto the non-Cartesian trajectory) immediately following the 2DFFT into the k-space domain, and a second gridding step (interpolation onto the Cartesian grid) just prior to 2DIFFT into the image domain.

A number of investigators have searched for an optimal perceptual image-quality metric, but this has proven an elusive goal. Our metrics ranged from the relatively crude (NMSE) to a more sophisticated perceptual score (FID), supplemented by visual scoring using a 5-point Mean Opinion Score (MOS); other metrics have also been shown to have utility, including the Structural Similarity Index (SSIM) [wang2004image], peak SNR, and the Semantic Interpretability Score (SIS) [seitzer2018adversarial]. SSIM has been widely adopted in medical imaging as a means of assessing quality based on the degradation of structural information in the image. SIS has been used where expert-provided segmentation labels are available, with the Dice overlap between ground-truth and network segmentations serving as a measure of the visibility of segmented objects in the reconstructed images. In [schlemper2018stochastic], a stochastic approach was proposed to measure localized reconstruction uncertainty, by running inference with a stochastic subset of sub-networks and measuring the variance between the different reconstructions. See Sec. 7.1 for a discussion of mitigating reconstruction instabilities.

7.1 DL-based MRI Reconstruction Instabilities

Several components of our model are aimed at mitigating MRI reconstruction instabilities, similar to those discussed in [pnas2020instabilities], as well as potential hallucinations that may arise from the use of GANs. (1) We are the first to utilize a fully conditional GAN architecture in MRI reconstruction, which allows the discriminator to also verify fidelity, instead of solely reinforcing the realistic appearance of the images (as done in other papers that use GANs for MRI reconstruction). This property stems from the architectural choice of the discriminator, which receives both the generated image and the zero-filled image and can therefore match the spatial features of both. (2) The unrolled architecture, alternating between convolutions and data-consistency terms, pushes the generator to produce images that are very close to the measured k-space; in our generator architecture, we apply a data-consistency operation after every convolutional block. (3) The essence of Adaptive Gradient Balancing is to ensure that the GAN component does not dominate the MSE loss. This gives the MSE loss precedence over the GAN loss, since AGB decays the GAN gradients to stay below a threshold defined by the MSE loss. Hence, AGB encourages the network to produce images that are more faithful to the fully sampled images while maintaining a natural image texture.

8 Conclusions

We present a novel undersampled MRI reconstruction model that employs a cWGAN with a novel multi-loss GAN training procedure. Our AGB training adaptively balances the adversarial and pixel-wise terms, streamlines hyperparameter tuning (by eliminating the exhaustive search required to carefully weight multi-loss adversarial terms), accelerates convergence, and results in superior performance. By leveraging GANs to their fullest, the method generates sharper images with more fine detail and a more natural appearance than would otherwise be possible. In addition, dense connections are used to improve the performance of our unrolled iterative generator network. In the context of MRI reconstruction, a GAN-based model can raise concerns about hallucination, where image details that do not appear in the ground truth are generated; we found that our method produces significantly less hallucination than other GANs. AGB is demonstrated both as an MRI method and as a general computer vision method. AGB training could be beneficial for any model employing a multi-term adversarial objective, especially in the medical domain, where there is considerable variability in the quality of the input and less experience in balancing GAN loss terms.

References