iUNets: Fully invertible U-Nets with Learnable Up- and Downsampling

05/11/2020 · Christian Etmann et al., University of Cambridge

U-Nets have been established as a standard neural network architecture for image-to-image learning problems such as segmentation and inverse problems in imaging. For high-dimensional applications, as they for example appear in 3D medical imaging, U-Nets however have prohibitive memory requirements. Here, we present a new fully invertible U-Net-based architecture called the iUNet, which allows for the application of highly memory-efficient backpropagation procedures. For this, we introduce learnable and invertible up- and downsampling operations. An open source library in Pytorch for 1D, 2D and 3D data is made available.


1 Introduction

Invertible neural networks have been an active area of research in the neural networks community over the last few years. They are particularly interesting for memory-constrained applications, as invertible neural networks allow for memory-efficient backpropagation (Gomez et al., 2017). Let $f_i$ (parametrised by $\theta_i$) be the $i$-th layer of a neural network. Let the network's loss $L$ depend on the layer's activation $z_i = f_i(z_{i-1})$ (for some input $z_{i-1}$). Then the required weight-gradient for first-order optimisation is

$$\frac{\partial L}{\partial \theta_i} = \frac{\partial L}{\partial z_i} \cdot \frac{\partial z_i}{\partial \theta_i},$$

i.e. in general one needs to store the activation $z_{i-1}$ in memory. If, however, $f_i$ is invertible and $z_i$ is stored in memory, then one can simply reconstruct $z_{i-1} = f_i^{-1}(z_i)$ instead. By successively reconstructing activations from the output layer back to the input layer, one needs to store only one activation as well as one gradient for the whole backpropagation (apart from possible memory overhead for computing the derivatives and reconstructions). This means that the memory requirement for training invertible networks via this method is independent of the depth of the network.
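To make the idea concrete, the following is a minimal sketch (not the authors' implementation; in practice, libraries such as MemCNN provide this functionality) of how activations can be reconstructed instead of stored during backpropagation, here for a toy chain of elementwise invertible layers:

```python
import torch

class InvertibleScaleShift:
    """Toy invertible layer y = exp(s) * x + b (elementwise), for illustration only."""
    def __init__(self, n):
        self.s = torch.zeros(n, requires_grad=True)
        self.b = torch.zeros(n, requires_grad=True)

    def forward(self, x):
        return torch.exp(self.s) * x + self.b

    def inverse(self, y):
        return (y - self.b) * torch.exp(-self.s)

def memory_efficient_backward(layers, x, loss_fn):
    # Forward pass WITHOUT storing intermediate activations.
    with torch.no_grad():
        z = x
        for layer in layers:
            z = layer.forward(z)
    # Backward sweep: reconstruct each layer's input from its output,
    # then locally re-run the layer with autograd to obtain gradients.
    z.requires_grad_(True)
    loss = loss_fn(z)
    grad_z = torch.autograd.grad(loss, z)[0]
    for layer in reversed(layers):
        with torch.no_grad():
            z_in = layer.inverse(z)              # reconstruct the activation
        z_in.requires_grad_(True)
        z_out = layer.forward(z_in)              # local recomputation
        grad_z, gs, gb = torch.autograd.grad(
            z_out, [z_in, layer.s, layer.b], grad_outputs=grad_z)
        layer.s.grad = gs if layer.s.grad is None else layer.s.grad + gs
        layer.b.grad = gb if layer.b.grad is None else layer.b.grad + gb
        z = z_in
    return loss

layers = [InvertibleScaleShift(8) for _ in range(5)]
memory_efficient_backward(layers, torch.randn(8), lambda z: (z ** 2).sum())
```

At any point during this backward sweep, only one activation and one gradient are held in memory, which is precisely the depth-independence described above.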

For many tasks in which inputs are mapped to outputs of the same resolution, the U-Net (Ronneberger et al., 2015) has become a standard neural network design principle. In the U-Net, features are downsampled and later recombined with their upsampled counterparts through channel concatenation. This allows different parts of the U-Net to handle information at different scales. When dealing with very high-dimensional data (such as 3D medical imaging data, for which the 3D U-Net (Çiçek et al., 2016) was developed), the memory requirements may quickly pose a problem. One way out of this is the above memory-efficient implementation of backpropagation for invertible networks. Partially reversible U-Nets (Brügger et al., 2019) already use this principle for each resolution separately. There, however, the downsampling is performed with max pooling and the upsampling with trilinear upsampling (both of which are inherently non-invertible operations), so that the down- and upsampled activations still have to be stored. Moreover, for applications in which full invertibility is fundamentally needed (such as normalizing flows (Rezende and Mohamed, 2015)), such architectures cannot be used. In this work, we introduce novel learnable up- and downsampling operations, with which a fully invertible U-Net can be constructed.

2 Invertible Up- and Downsampling

In this section, we introduce novel learnable invertible up- and downsampling operations.

Purely spatially up- and downsampling operators for image data are inherently non-bijective, as they alter the dimensionality of their input. Classical methods from image processing include up- and downsampling with bilinear or bicubic interpolation as well as nearest-neighbour methods (Bredies and Lorenz, 2018). In neural networks and in particular in U-Net-like architectures, downsampling is usually performed either via max pooling or via strided convolutions. Upsampling, on the other hand, is typically done via a strided transposed convolution.


One way of invertibly downsampling image data in neural networks is known as pixel shuffle or squeezing (Dinh et al.), which rearranges the pixels of a $(C, H, W)$-image into a $(4C, H/2, W/2)$-image, where $C$, $H$ and $W$ denote the number of channels, height and width, respectively. Another classical example of such a transformation is the 2D Haar transform, which is a type of wavelet transform (Mallat, 1999). Here, a filter bank is used to decompose an image into approximation and detail coefficients. These invertible downsampling methods are depicted in Fig. 1.

Figure 1: A test image, the 'pixel shuffle'-transformed image and the Haar-transformed image. The resulting 4 channels are depicted as tiled images. Note that pixel shuffle extracts images that look similar to the input image, whereas the Haar transform also extracts edge information.
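As a concrete illustration, the following sketch (assuming 2D inputs in (B, C, H, W) layout and a downsampling factor of 2) implements the pixel shuffle/squeezing rearrangement and its exact inverse with plain reshape and permute operations:

```python
import torch

def squeeze2d(x, s=2):
    """'Pixel shuffle'/squeezing: (B, C, H, W) -> (B, C*s*s, H//s, W//s), invertible."""
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // s, s, W // s, s)
    x = x.permute(0, 1, 3, 5, 2, 4)          # move each s x s block into the channel axis
    return x.reshape(B, C * s * s, H // s, W // s)

def unsqueeze2d(x, s=2):
    """Exact inverse of squeeze2d."""
    B, C, H, W = x.shape
    x = x.reshape(B, C // (s * s), s, s, H, W)
    x = x.permute(0, 1, 4, 2, 5, 3)          # interleave the blocks back into the spatial axes
    return x.reshape(B, C // (s * s), H * s, W * s)

x = torch.randn(1, 3, 8, 8)
assert torch.allclose(unsqueeze2d(squeeze2d(x)), x)
```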

This general principle of increasing the number of channels while decreasing the spatial resolution of each channel guides the creation of our learnable invertible downsampling operators.
In the following, we call $d$ the spatial dimensionality. We say that $N \in \mathbb{N}^d$ is divisible by $s \in \mathbb{N}^d$ if $N_i$ is divisible by $s_i$ for all $i$. We denote by $N \oslash s$ the element-wise (Hadamard) division of $N$ by $s$.

Definition 1.

Let $d \in \mathbb{N}$ denote the spatial dimensionality and let $s \in \mathbb{N}^d$ be the stride, such that the spatial size $N \in \mathbb{N}^d$ is divisible by $s$. We call $\sigma := \prod_{i=1}^{d} s_i$ the channel multiplier. For $c \in \mathbb{N}$ and $N \in \mathbb{N}^d$, we call an operator

$$D : \mathbb{R}^{c \times N_1 \times \cdots \times N_d} \to \mathbb{R}^{\sigma c \times (N \oslash s)_1 \times \cdots \times (N \oslash s)_d}$$

an invertible downsampling operator if $D$ is bijective. If the function is parametrised by $\theta$, i.e. $D = D_\theta$, and $D_\theta$ is invertible for all $\theta \in \Theta$ (for some parameter space $\Theta$), then $D_\theta$ is called a learnable invertible downsampling operator.

Remark 2.

For the practically relevant case of stride $2$ in all spatial directions, one has $\sigma = 4$ for 2D data and $\sigma = 8$ for 3D data.

In contrast to the downsampling case, in the upsampling case the number of channels needs to be decreased as the spatial resolution of each channel is increased. Using the invertibility, we simply define invertible upsampling operators as the inverses of invertible downsampling operators.

Definition 3.

A bijective operator $U$ is called an invertible upsampling operator if its inverse $U^{-1}$ is an invertible downsampling operator. If the inverse of an operator $U_\theta$ is a learnable invertible downsampling operator (parametrised by $\theta$), then $U_\theta$ is called a learnable invertible upsampling operator.

Figure 2: Using the inverse of the pixel shuffling operation will result in artifacts.

In Figure 2, the inverse of pixel shuffling is exemplified on monochromatic input channels, which yields checkerboard artefacts (which would additionally result in Moiré patterns when applying the commonly used 3-by-3 convolutional kernels to such upsampled features). Thus, when using this type of invertible upsampling, the output will tend to exhibit artefacts unless the input features are very non-diverse. This highlights that extracted features and upsampling operators need to be tuned to one another in order to guarantee both feature diversity and outputs that are not inhibited by artefacts. In the following, we present a method to learn the appropriate up- and downsampling operations.

The general idea of our proposed learnable invertible downsampling is to construct a suitable strided convolution operator (resulting in a spatial downsampling) which is orthogonal (and hence, due to the finite dimensionality of the involved spaces, bijective). Its inverse is then simply the adjoint operator. Let $A^*$ denote the adjoint operator of a linear operator $A$, i.e. the unique linear operator such that

$$\langle A x, y \rangle = \langle x, A^* y \rangle$$

for all $x$, $y$ from the respective spaces, where $\langle \cdot, \cdot \rangle$ denotes the standard inner products. If $A$ is a convolution operator, then $A^*$ is the corresponding transposed convolution operator. Hence, once we know how to construct a learnable orthogonal (i.e. invertible) downsampling operator, we also know how to calculate its inverse, which is at the same time a learnable orthogonal upsampling operator.

2.1 Orthogonal Up- and Downsampling Operators as Convolutions

We will first develop learnable orthogonal downsampling operators for the case of a single input channel ($c = 1$), which is then generalised. The overall idea is to create an orthogonal matrix and reorder it into a convolutional kernel, with which a correctly strided convolution is an orthogonal operator.

Let $\mathrm{conv}_s(K, x)$ denote the convolution of $x$ with kernel $K$ and stride $s$, where $K \in \mathbb{R}^{\sigma \times k_1 \times \cdots \times k_d}$. Let further $O(n)$ and $SO(n)$ denote the orthogonal and special orthogonal group of real $n$-by-$n$ matrices, respectively.

Theorem 4.

Let $s \in \mathbb{N}^d$ be the stride and $k \in \mathbb{N}^d$ the kernel size, and let $s_i = k_i$ for all $i$. Let further $R$ be an operator that reorders the entries of a matrix $M \in \mathbb{R}^{\sigma \times \sigma}$ into a kernel $R(M) \in \mathbb{R}^{\sigma \times k_1 \times \cdots \times k_d}$, such that the entries of the $i$-th filter of $R(M)$ consist of the entries of the $i$-th row of $M$. Then for any $M \in O(\sigma)$, the strided convolution

$$x \mapsto \mathrm{conv}_s(R(M), x)$$

is an invertible downsampling operator. Its inverse is the corresponding transposed convolution.

Proof.

If $s_i = k_i$ for all $i$ (i.e. the kernel size matches the strides), then the computational windows of the discrete convolution are non-overlapping. This means that each spatial entry of $\mathrm{conv}_s(R(M), x)$ (taken over all $\sigma$ output channels) is the result of the multiplication of the $\sigma$-by-$\sigma$ matrix $M$ with a $\sigma$-dimensional column vector of the appropriate entries of $x$. This means that

$$\mathrm{vec}\big(\mathrm{conv}_s(R(M), x)\big) = \mathrm{diag}(M, \dots, M) \cdot \mathrm{vec}_s(x), \tag{1}$$

where $\mathrm{vec}_s$ denotes the appropriate reordering of $x$ into a column vector (and $\mathrm{vec}$ the analogous reordering of the output). We will now show that if the block diagonal matrix $\mathrm{diag}(M, \dots, M)$ is orthogonal, then the convolution is an orthogonal operator; note that $M$ being orthogonal implies $\mathrm{diag}(M, \dots, M)$ being orthogonal. For any $x$ it holds that

$$\big\| \mathrm{conv}_s(R(M), x) \big\|_2 = \big\| \mathrm{diag}(M, \dots, M) \cdot \mathrm{vec}_s(x) \big\|_2 = \big\| \mathrm{vec}_s(x) \big\|_2 = \| x \|_2, \tag{2}$$

where we used the fact that the reordering into column vectors as well as $\mathrm{diag}(M, \dots, M)$ are orthogonal operators. Hence, we proved that $\mathrm{conv}_s(R(M), \cdot)$ is an orthogonal operator (and in particular bijective).

Note that the assumption $s_i = k_i$ for all $i$ (the strides match the kernel size) will hold for all invertible up- and downsampling operators in the following.
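The following is a small numerical sketch of Theorem 4 (with assumed shapes: stride 2, hence $\sigma = 4$, and a single input channel): an arbitrary orthogonal 4-by-4 matrix, reordered into four 2-by-2 filters, yields a strided convolution that is exactly inverted by the corresponding transposed convolution.

```python
import torch
import torch.nn.functional as F

s, sigma = 2, 4
M, _ = torch.linalg.qr(torch.randn(sigma, sigma))   # some orthogonal sigma x sigma matrix
K = M.reshape(sigma, 1, s, s)                        # i-th filter = i-th row of M

x = torch.randn(1, 1, 8, 8)
y = F.conv2d(x, K, stride=s)                         # invertible downsampling
x_rec = F.conv_transpose2d(y, K, stride=s)           # its adjoint, here also its inverse
assert torch.allclose(x_rec, x, atol=1e-5)
```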

2.2 Designing Learnable Orthogonal Downsampling Operations

The above invertible downsampling operator is parametrised over the group of orthogonal matrices $O(\sigma)$. Note that since orthogonal matrices have determinant $\pm 1$ (i.e. there are two connected components of $O(\sigma)$), there is no way to smoothly parametrise the whole parameter set $O(\sigma)$. However, if $\det(M) = -1$, then switching two rows of $M$ yields a matrix $\tilde M$ with $\det(\tilde M) = 1$. Switching two rows of $M$ simply results in a different order of filters in the kernel $R(M)$. The resulting downsampling with kernel $R(\tilde M)$ is thus the same as with kernel $R(M)$, up to the ordering of feature maps. Hence, the inability to parametrise both connected components of $O(\sigma)$ poses no practical limitation, if one can parametrise $SO(\sigma)$. Any such parametrisation should be robust and straightforward to compute, as well as differentiable. One such parametrisation (others include sequences of Householder transformations or Givens rotations, as well as Cayley transforms) is the exponentiation of skew-symmetric matrices (i.e. matrices $A \in \mathbb{R}^{\sigma \times \sigma}$ such that $A^T = -A$).

From Lie theory (Sepanski, 2007), it is known that the matrix exponential

$$\exp(A) = \sum_{k=0}^{\infty} \frac{A^k}{k!} \tag{3}$$

from the Lie algebra of real skew-symmetric matrices to the Lie group $SO(\sigma)$ is a surjective map (which is true for all compact, connected Lie groups). This means that one can create any special orthogonal matrix by exponentiating a skew-symmetric matrix. The $\sigma$-by-$\sigma$ skew-symmetric matrices can in turn simply be parametrised by

$$A = \theta - \theta^T, \tag{4}$$

where $\theta$ is a $\sigma \times \sigma$ matrix. Note that this is an overparametrisation: any two matrices $\theta$ that differ by an additive symmetric matrix will yield the same skew-symmetric matrix.

Corollary 5.

Let the same setting as in Theorem 4 hold. Then the operator $D_\theta$, defined by

$$D_\theta(x) := \mathrm{conv}_s\big(R(\exp(\theta - \theta^T)),\, x\big),$$

is a learnable invertible downsampling operator, parametrised by $\theta$ over the parameter space $\Theta = \mathbb{R}^{\sigma \times \sigma}$.
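A minimal Pytorch sketch of Corollary 5 (for 2D data, a single input channel and stride 2) might look as follows. This is an illustration of the construction, not the released iunets implementation, and it uses torch.matrix_exp, which is only available in more recent Pytorch versions (cf. Section 2.3).

```python
import torch
import torch.nn.functional as F

class InvertibleDownsampling2d(torch.nn.Module):
    """Sketch: kernel = reordered exp(theta - theta^T), applied as a strided convolution."""
    def __init__(self, stride=2):
        super().__init__()
        self.s = stride
        sigma = stride * stride                      # channel multiplier in 2D
        self.theta = torch.nn.Parameter(torch.zeros(sigma, sigma))

    def kernel(self):
        A = self.theta - self.theta.T                # skew-symmetric matrix
        M = torch.matrix_exp(A)                      # special orthogonal matrix
        return M.reshape(-1, 1, self.s, self.s)      # rows -> filters

    def forward(self, x):                            # (B, 1, H, W) -> (B, sigma, H/s, W/s)
        return F.conv2d(x, self.kernel(), stride=self.s)

    def inverse(self, y):                            # transposed convolution = exact inverse
        return F.conv_transpose2d(y, self.kernel(), stride=self.s)

down = InvertibleDownsampling2d()
x = torch.randn(2, 1, 16, 16)
assert torch.allclose(down.inverse(down(x)), x, atol=1e-5)
```

At initialisation with $\theta = 0$, the kernel reduces to the pixel shuffle operation, which connects to the following observation.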

Note that both examples of invertible downsampling from Figure 1 can be reproduced with this parametrisation (up to the ordering of feature maps): any symmetric matrix $\theta$ yields the pixel shuffle operation, whereas a suitable non-symmetric choice of $\theta$ yields the 2D Haar transform. This demonstrates that the presented technique can learn both very similar-looking as well as very diverse feature maps. The whole concept is summarised in Fig. 3 and exemplified in Fig. 4.

Figure 3: Our concept for learnable, invertible downsampling. By exponentiating a skew-symmetric matrix $\theta - \theta^T$, a special orthogonal matrix can be created. When its rows are reordered into filters, convolving these with a stride that matches the kernel size results in an orthogonal convolution, which is a special case of a learnable invertible downsampling operator over the parameter space $\mathbb{R}^{\sigma \times \sigma}$. Because the computational windows of the convolution are non-overlapping, each pixel of the resulting channels is then just the standard inner product of the respective filter with the corresponding image patch in the original image.
Figure 4: A test image, a randomly initialised learnable invertible downsampling and a learnable downsampling trained to create a sparse representation (by minimizing the 1-norm of the representation). The latter looks very similar to a Haar decomposition.

The case $c = 1$ can now easily be generalised to an arbitrary number of channels by applying the learnable invertible downsampling operation to each input channel independently.

Corollary 6.

Let $s_i = k_i$ for all $i$. For $c \in \mathbb{N}$, the operator

$$D_\theta : \mathbb{R}^{c \times N_1 \times \cdots \times N_d} \to \mathbb{R}^{\sigma c \times (N \oslash s)_1 \times \cdots \times (N \oslash s)_d},$$

given by the channel-wise application of the operator defined in Corollary 5, is a learnable invertible downsampling operator, parametrised by $\theta$ over the parameter space $\Theta = \mathbb{R}^{\sigma \times \sigma}$.
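The channel-wise application of Corollary 6 can be sketched with grouped convolutions; here it is assumed (for illustration) that all channels share one parameter matrix $\theta$.

```python
import torch
import torch.nn.functional as F

def channelwise_invertible_downsampling(x, theta, s=2):
    c = x.shape[1]
    M = torch.matrix_exp(theta - theta.T)             # (s*s, s*s), special orthogonal
    K = M.reshape(s * s, 1, s, s).repeat(c, 1, 1, 1)  # same filters for every channel
    return F.conv2d(x, K, stride=s, groups=c)         # (B, c, H, W) -> (B, c*s*s, H/s, W/s)

def channelwise_invertible_upsampling(y, theta, s=2):
    c = y.shape[1] // (s * s)
    M = torch.matrix_exp(theta - theta.T)
    K = M.reshape(s * s, 1, s, s).repeat(c, 1, 1, 1)
    return F.conv_transpose2d(y, K, stride=s, groups=c)

theta = torch.zeros(4, 4, requires_grad=True)
x = torch.randn(1, 3, 8, 8)
y = channelwise_invertible_downsampling(x, theta)     # shape (1, 12, 4, 4)
assert torch.allclose(channelwise_invertible_upsampling(y, theta), x, atol=1e-5)
```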

2.3 Implementation

When implementing invertible up- and downsampling, one needs both an implementation of the matrix exponential (which Pytorch 1.4 does not provide) and a way of calculating gradients with respect to both the weight $\theta$ and the input $x$. For the matrix exponentiation, we simply truncate the series representation

$$\exp(A) = \sum_{k=0}^{\infty} \frac{A^k}{k!} \tag{5}$$

after a fixed number of terms. Since the involved matrices are typically small (see Remark 2), the computational overhead of calculating the matrix exponential this way is small compared to the convolutions. More computationally efficient implementations include Padé approximations and scaling-and-squaring methods (Al-Mohy and Higham, 2009).
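A sketch of such a truncated series evaluation (the function name and the number of terms are illustrative choices; the check against torch.matrix_exp assumes a recent Pytorch version):

```python
import torch

def expm_taylor(A, n_terms=20):
    """Truncated Taylor series exp(A) ~ sum_{k=0}^{n_terms} A^k / k!, cf. Eq. (5);
    adequate here because A is small (e.g. 4x4 or 8x8, see Remark 2)."""
    result = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
    term = result.clone()
    for k in range(1, n_terms + 1):
        term = term @ A / k              # builds up A^k / k! iteratively
        result = result + term
    return result

A0 = torch.randn(4, 4)
A = A0 - A0.T                            # skew-symmetric test matrix
assert torch.allclose(expm_taylor(A), torch.matrix_exp(A), atol=1e-4)
```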

Using the skew-symmetrisation $A(\theta) := \theta - \theta^T$ (which is a self-adjoint, linear operator), $M := \exp(A(\theta))$ and $K := R(M)$, employing the chain rule yields

$$\nabla_\theta L = \left( \frac{\partial\, \mathrm{conv}_s(K, x)}{\partial \theta} \right)^{*} \nabla_y L, \tag{6}$$

where $y := \mathrm{conv}_s(K, x) = D_\theta(x)$ denotes the output of the downsampling and $L$ the loss. The involved derivatives are linear operators (in the sense of Fréchet derivatives), and as such admit adjoints. Note that the adjoint of the derivative of the convolution with respect to its kernel is not the transposed convolution (which takes values in the input space). Instead, this is an adjoint with respect to the kernel variable (which exists, because the convolution is linear in its kernel), and it takes values in the space of kernels. In the following, we denote this operator by $\mathrm{conv}_s^\diamond(x)$ (cf. (Etmann, 2019)). Furthermore, denote by $D\exp(A)$ the Fréchet derivative of the matrix exponential in $A$. When incorporating the fact that $(D\exp(B))^* = D\exp(B^*)$ (Al-Mohy and Higham, 2009) and that $A(\theta)^T = -A(\theta)$, this leads to the expression

$$\nabla_\theta L = A\Big( D\exp\big({-A(\theta)}\big)\Big[ R^{*}\Big( \mathrm{conv}_s^\diamond(x)\big[\nabla_y L\big]\Big)\Big]\Big). \tag{7}$$

Analogously to the matrix exponential itself, we approximate its Fréchet derivative by a truncation of the series

$$D\exp(A)[E] = \sum_{k=1}^{\infty} \frac{M_k}{k!}, \qquad \text{where } M_k = A M_{k-1} + E A^{k-1} \text{ with } M_1 = E$$

(Al-Mohy and Higham, 2009). The gradients for learnable invertible upsampling follow analogously from these derivations. Both series have an infinite radius of convergence.
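The truncated series for the Fréchet derivative can be sketched as follows; the finite-difference comparison at the end is only an illustrative correctness check, not part of the method.

```python
import torch

def expm_frechet_taylor(A, E, n_terms=20):
    """Truncated series for the Frechet derivative of the matrix exponential,
    D exp(A)[E] = sum_{k>=1} M_k / k!  with  M_1 = E,  M_k = A M_{k-1} + E A^{k-1}."""
    M = E.clone()
    A_pow = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)   # holds A^{k-1}
    result = M.clone()                                # k = 1 term (1/1! = 1)
    factorial = 1.0
    for k in range(2, n_terms + 1):
        A_pow = A_pow @ A                             # now A^{k-1}
        M = A @ M + E @ A_pow
        factorial *= k
        result = result + M / factorial
    return result

# Compare against a central finite difference of the directional derivative.
A0 = torch.randn(4, 4, dtype=torch.float64)
E = torch.randn(4, 4, dtype=torch.float64)
A = 0.1 * (A0 - A0.T)
eps = 1e-4
fd = (torch.matrix_exp(A + eps * E) - torch.matrix_exp(A - eps * E)) / (2 * eps)
assert torch.allclose(expm_frechet_taylor(A, E), fd, atol=1e-5)
```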

3 Invertible U-Nets

As mentioned in section 1, the general principle of the classic U-Net is to calculate features on multiple scales by a sequence of convolutional layers and downsampling operations in conjunction with an increase in the number of feature maps. The downsampled feature maps tend to capture large-scale features, whereas the more highly resolved feature maps capture more fine-grained properties of the data. The low-resolution features are successively recombined with the prior, high-resolution features via feature map concatenation, until the original spatial resolution of the input data is reached again.

In order to construct a fully invertible U-Net (iUNet), we adopt these same principles. A depiction of the iUNet is found in Figure 5. Note that unlike in the case of non-invertible networks, the total data dimensionality may not change; in particular, the number of channels may not change if the spatial dimensions remain the same.
Unlike in the classic U-Net, not all feature maps of a certain resolution can be concatenated with the later upsampled branch, as this would violate the condition of constant dimensionality. Instead, we split the feature maps into two portions of $\lambda c$ and $(1-\lambda) c$ channels (for an appropriate split fraction $\lambda \in (0,1)$, such that $\lambda c \in \mathbb{N}$). The portion with $\lambda c$ channels gets processed further (cf. the gray blocks in Figure 5), whereas the other portion is later concatenated with the upsampling branch (cf. the green blocks in Figure 5). Splitting and concatenating feature maps are invertible operations.
While in the classic U-Net, increasing the number of channels and spatial downsampling via max-pooling are separate operations, these need to be inherently linked for invertibility. This is achieved by applying the learnable invertible downsampling (Cor. 5) to the (non-concatenated) split portion.

Mathematically, we define the left (downsampling) side of the invertible U-Net for each scale $i = 1, \dots, n$ via

$$(b_i, \tilde{x}_i) = S_i\big(F_i(x_{i-1})\big), \qquad x_i = D_i(\tilde{x}_i), \tag{8}$$

where $x_0$ is the network's input. Here, $F_i$, $S_i$ and $D_i$ denote appropriate invertible sub-networks, channel splitting and invertible downsampling operations. With $y_n := x_n$, the right (upsampling) side is defined via

$$y_{i-1} = G_i\Big(C_i\big(b_i,\, U_i(y_i)\big)\Big) \tag{9}$$

for $i = n, \dots, 1$. Again, $G_i$, $C_i$ and $U_i$ denote the respective invertible sub-networks, channel concatenation and invertible upsampling. The iUNet is then defined to be the function that maps $x_0$ to $y_0$.
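Structurally, Eqs. (8) and (9) can be sketched as follows, with illustrative stand-ins: identity sub-networks, a split fraction of 1/2 and pixel (un)shuffling in place of the learnable invertible down- and upsampling from Section 2.

```python
import torch
import torch.nn.functional as F

def iunet_sketch(x, n_scales=3):
    skips = []
    for _ in range(n_scales):                 # left side, Eq. (8)
        b, x = x.chunk(2, dim=1)              # channel split: skip part / processed part
        skips.append(b)
        x = F.pixel_unshuffle(x, 2)           # invertible downsampling (stand-in)
    for b in reversed(skips):                 # right side, Eq. (9)
        x = F.pixel_shuffle(x, 2)             # invertible upsampling (stand-in)
        x = torch.cat([b, x], dim=1)          # channel concatenation
    return x

x = torch.randn(1, 64, 32, 32)
y = iunet_sketch(x)
assert y.shape == x.shape                     # total dimensionality is preserved
```

Since channel splitting, concatenation and the (un)shuffling stand-ins are all invertible, the composed map is invertible as well; replacing the stand-ins by invertible sub-networks and the learnable operators of Section 2 preserves this property.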

Remark 7.

The number of channels increases exponentially as the spatial resolution decreases. The base of the exponentiation depends not only on the channel multiplier $\sigma$, but also on the channel split fraction $\lambda$, since only this fraction of channels gets invertibly downsampled. The number of channels thus increases by a factor of $\lambda \sigma$ between two resolutions. E.g. in 2D, $\lambda = 1/2$ leads to a doubling of channels for stride 2 (i.e. $\sigma = 4$), whereas in 3D the split fraction $\lambda = 1/4$ is required to achieve a doubling of channels for stride 2 (i.e. $\sigma = 8$).
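As a small worked example of this channel arithmetic (the function name and default values are illustrative):

```python
# Channel count across scales (Remark 7): c_{i+1} = lambda * sigma * c_i.
def channels_per_scale(c0, spatial_dim, split_fraction, stride=2, n_scales=4):
    sigma = stride ** spatial_dim                     # channel multiplier
    growth = split_fraction * sigma                   # factor between two resolutions
    return [int(c0 * growth ** i) for i in range(n_scales + 1)]

print(channels_per_scale(64, spatial_dim=2, split_fraction=0.5))    # [64, 128, 256, 512, 1024]
print(channels_per_scale(64, spatial_dim=3, split_fraction=0.25))   # [64, 128, 256, 512, 1024]
```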

3.1 Invertible U-Nets in Practice

The above discussion already pointed towards some commonalities and differences between the iUNet and non-invertible U-Nets. The restriction that the dimensionality may not change between each invertible sub-network's input and output imposes constraints both on the architecture as well as on the data. When invertibly downsampling, the spatial dimensions need to be exactly divisible by the strides of the downsampling. This is in contrast to non-invertible downsampling, where e.g. padding or cropping can be introduced. Furthermore, due to the application of channel splitting (or if one employs coupling layers), the number of channels needs to be at least 2. An alternative may be to exchange the order of invertible downsampling and channel splitting.

These restrictions may prove too strong in practice to reach a certain performance for tasks in which full invertibility is not strictly needed. One such use case is memory-efficient backpropagation in a segmentation task, in contrast to normalizing flows, where full invertibility is required. In this case, if one uses e.g. a (linear) convolutional layer before the iUNet to increase the number of channels, the memory-efficient backpropagation procedure can still be applied to the whole fully invertible sub-network, i.e. the whole iUNet. Similarly, a linear convolutional layer can be applied to the output of the iUNet in order to change the number of channels to some desired number, e.g. the number of classes in a semantic segmentation problem. Note that adding a (learnable) linear layer before as well as after the invertible network does not necessitate storing any additional activations (since the derivatives of linear maps are simply the linear maps themselves).
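A minimal sketch of this wrapping, using the channel numbers of the 2D example in Figure 5; here, iunet_core is a placeholder for any fully invertible iUNet.

```python
import torch

# A linear 1x1 convolution lifts the input to the iUNet's channel width, and a second
# one maps its output to the number of classes (RGB -> 64 channels -> 10 classes).
lift = torch.nn.Conv2d(3, 64, kernel_size=1, bias=False)
head = torch.nn.Conv2d(64, 10, kernel_size=1, bias=False)

def segment(x, iunet_core):
    # Memory-efficient backpropagation applies to the invertible core; the two linear
    # wrapper layers do not require storing additional activations.
    return head(iunet_core(lift(x)))
```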
A general issue in memory-efficient backpropagation is stability of the inversion (Behrmann et al., 2020), which in turn influences the stability of the training. We found that using layer normalisation (Ba et al., 2016) was an effective means of stabilising the iUNet in practice.

Figure 5: Example of a 2D iUNet used for memory-efficient backpropagation for segmenting RGB-images into 10 classes. Linear convolutions are used to increase the number of channels to a desired number (64), which then determines the input and output data dimensionality of the invertible U-Net. Invertible layers, invertible up- and downsampling and skip connections in conjunction with channel splitting and concatenation make up the invertible U-Net (contained in the light-blue box).

4 Experiments

In the following, results of an experiment on learned 3D post-processing of imperfect CT reconstructions, as well as of a 3D segmentation experiment, are presented. In all trained iUNets, the additive coupling layers from (Jacobsen et al., 2018) were used. The models were implemented in Pytorch, using the library MemCNN (Leemput et al., 2019) for memory-efficient backpropagation.

4.1 Learned Post-Processing of Imperfect 3D CT Reconstructions

The goal of this experiment is to test the invertible U-Net in a challenging, high-dimensional learned post-processing task on imperfect 3D CT reconstructions, where the induced undersampling artifacts appear on a large, three-dimensional scale. For this experiment, we created an artificial dataset of the 3D 'foam phantoms' from (Pelt et al., 2018), where the training set consisted of 180 volumes and the test set of 20 volumes. These are comprised of cylinders of varying size, filled with a large number of holes. The volumes were generated at a higher resolution and then trilinearly downsampled (to prevent aliasing artifacts). A reconstruction using filtered backprojection (FBP) of a strongly undersampled parallel-beam CT projection with Poisson noise was created, with a diagonal axis of the volume serving as the CT axis (perturbed by angular noise). We expect the varying size of the phantoms, the artifacts in the FBP reconstructions as well as the large-scale bubble structures to favour networks with a large, three-dimensional receptive field (i.e. many downsampling operations), which justifies the use of 3D iUNets and 3D U-Nets. The FBP reconstructions as well as the ground truth volumes were subsequently downsized. Both 3D U-Nets and 3D iUNets were then trained to retrieve the ground truth from the FBP reconstructions using the squared $\ell_2$-loss. The peak signal-to-noise ratios (PSNR) as well as the structural similarity indices (SSIM) of this experiment cohort are compiled in Table 1. Each line represents one classic 3D U-Net and one 3D iUNet of comparable size. While there is no way to construct perfectly comparable instances of both, the 3D U-Net uses 2 convolutional layers before each downsampling (respectively after each upsampling), whereas the 3D iUNet employs 4 additive coupling layers (each acting on half of the channels). In the case of the classic U-Net, 'channel blowup' indicates the number of channels before the first downsampling operation (identical to the number of output feature maps before reducing to one feature map again). In both architectures, layer normalisation was applied. The batch size was 1 in all cases, because for the larger models this was the maximum applicable batch size due to the large memory demand. Random flips and rotations were applied for data augmentation.
As indicated by the table, even the worst-performing iUNet performed considerably better than the best-performing classic U-Net, both in terms of PSNR and especially in terms of SSIM. While both model classes benefitted from an increased channel blowup, only the invertible U-Net benefitted from raising the number of scales from 4 to 8 (at which point the receptive field spans the whole volume). The fact that the classic U-Net drops in performance despite a higher model capacity may indicate that the optimisation is more problematic in this case. The invertible U-Net shows one of its advantages in this application: by initialising the layer normalisation as the zero-mapping in each coupling block, the whole iUNet was initialised as the identity function. At initialisation, each convolutional layer's input is thus a part of the whole model input (up to an orthogonal transform). Since the optimal function for learned post-processing can be expected to be close to the identity function, we assume that this initialisation is well-suited for this task. We further used the memory-efficiency of the invertible U-Net to double the channel blowup compared to the largest classic U-Net that we were able to fit into memory. This brought further performance improvements, showing that a higher model capacity can aid in such tasks.
In Figure 6, a test sample processed by the best-performing 3D iUNet as well as the best-performing classic 3D U-Net is shown, along with the ground truth and the FBP reconstruction. Apart from the overall lower noise level, the iUNet is able to discern neighbouring holes from one another much better than the classic U-Net. Moreover, in this example a hole that is occluded by noise in the FBP reconstruction does not seem to be recognised as such by the classic U-Net, but is well-differentiated by the iUNet.

Figure 6: Slices through a 3D volume from the test set (post-processing task). Apart from the lower noise level compared to the classic 3D U-Net, the 3D iUNet is also able to differentiate much better between neighbouring holes (red) and to discern holes from noise (green).
scales   channel blowup   3D U-Net            3D iUNet
                          SSIM     PSNR       SSIM     PSNR
                          0.302    13.29      0.568    14.00
                          0.416    13.89      0.780    14.99
                          0.236    12.42      0.768    15.10
                          0.425    13.92      0.829    15.82
                          -        -          0.854    16.11
Table 1: Results of learned post-processing experiments. Here, ’scales’ indicates the number of different resolutions, whereas ’channel blowup’ denotes the number of feature maps before reverting to one feature map again.

4.2 Brain Tumor Segmentation

In this experiment, we study the performance of the iUNet for volumetric brain tumor segmentation. The experiment is carried out on a benchmark brain tumor dataset from the BraTS 2018 challenge (Menze et al., 2014). The training dataset for BraTS 2018 includes 285 multi-parametric MRI scans with ground truth labels collected by expert neuroradiologists. We further split the 285 scans into 91% and 9% for training and validation respectively.

For this dataset, we consider three different sizes of invertible networks, with a channel blowup of 16, 32 and 64 channels, respectively. For comparison, we also train a baseline 3D U-Net (Çiçek et al., 2016). In our implementation, both the U-Net and the invertible networks use the same number of resolution levels, starting from cropped input volumes with 4 input channels (corresponding to the 4 MRI modalities). For the baseline U-Net, the number of input channels is first increased by a channel blowup and then doubled after each downsampling. The invertible networks use a channel split fraction of 1/4 after each invertible 3D downsampling, meaning that the channel numbers are doubled as well. After each channel split, two additive coupling layers are used. The results on the BraTS validation set (comprising 66 scans), measured by the Dice score and sensitivity (Bakas et al., 2018), are reported in Table 2. According to the table, increasing the channel numbers of the invertible networks leads to a gain in performance in terms of both the Dice score and the sensitivity. The largest invertible network outperforms the baseline U-Net, thanks to the memory savings, which allow for larger channel numbers under similar hardware configurations.

            Dice score                 Sensitivity
            ET      WT      TC         ET      WT      TC
U-Net       0.770   0.901   0.828      0.776   0.914   0.813
iUNet-16    0.767   0.900   0.809      0.779   0.916   0.798
iUNet-32    0.782   0.899   0.825      0.773   0.908   0.824
iUNet-64    0.801   0.898   0.850      0.796   0.918   0.829
Table 2: Results on the BraTS 2018 validation set (acting as the test set). ET, WT and TC denote the enhancing tumor, whole tumor and tumor core regions, respectively.

5 Conclusion and Future Work

In this work, we introduced a fully invertible U-Net (iUNet), which employs novel learnable invertible up- and downsampling operations. These are orthogonal convolution operators, whose kernels are created by exponentiating a skew-symmetric matrix and reordering its entries. We show the viability of the iUNet on two tasks: 3D learned post-processing of CT reconstructions and volumetric segmentation. On both the segmentation and the CT post-processing task, the iUNet benefitted from increased depth and width and outperformed its non-invertible counterparts, in the case of the post-processing task even substantially. We therefore conclude that the iUNet is particularly suited for high-dimensional tasks in which a classic U-Net is not feasible.

In future work, we would like to investigate the viability of the iUNet on a wider variety of tasks. Among these are normalizing flows, i.e. invertible generative neural networks that can be trained by maximizing the (tractable) likelihood under the model. Since these require full invertibility, partially invertible U-Net variants cannot be used in this case. We therefore hope that the structure of the iUNet is well-suited for generative tasks.

Acknowledgements

The authors thank Sil van de Leemput for his help in using and extending MemCNN. CE and CBS acknowledge support from the Wellcome Innovator Award RG98755. RK and CBS acknowledge support from the EPSRC grant EP/T003553/1. CBS additionally acknowledges support from the Leverhulme Trust project on ‘Breaking the non-convexity barrier’, the Philip Leverhulme Prize, the EPSRC grant EP/S026045/1, the EPSRC Centre Nr. EP/N014588/1, the RISE projects CHiPS and NoMADS, the Cantab Capital Institute for the Mathematics of Information and the Alan Turing Institute.

References