TransMorph: Transformer for Unsupervised Medical Image Registration
In the last decade, convolutional neural networks (ConvNets) have dominated the field of medical image analysis. However, the performance of ConvNets may still be limited by their inability to model long-range spatial relations between voxels in an image. Numerous vision Transformers have recently been proposed to address this shortcoming of ConvNets, demonstrating state-of-the-art performance in many medical imaging applications. Transformers may be a strong candidate for image registration because their self-attention mechanism enables a more precise comprehension of the spatial correspondence between moving and fixed images. In this paper, we present TransMorph, a hybrid Transformer-ConvNet model for volumetric medical image registration. We also introduce three variants of TransMorph: two diffeomorphic variants that ensure topology-preserving deformations and a Bayesian variant that produces well-calibrated registration uncertainty estimates. The proposed models are extensively validated against a variety of existing registration methods and Transformer architectures using volumetric medical images from two applications: inter-patient brain MRI registration and phantom-to-CT registration. Qualitative and quantitative results demonstrate that TransMorph and its variants lead to a substantial performance improvement over the baseline methods, confirming the effectiveness of Transformers for medical image registration.
Deformable image registration (DIR) is fundamental for many medical imaging analysis tasks. It functions by establishing spatial correspondence in order to minimize the differences between a pair of fixed and moving images. Traditional methods formulate image registration as a variational problem for estimating a smooth mapping between the points in one image and those in another (avants2008symmetric; beg2005computing; vercauteren2009diffeomorphic; heinrich2013mrf; modat2010fast). However, such methods are computationally expensive and usually slow in practice because the optimization problem needs to be solved from scratch for each pair of unseen images.
Recently, deep neural networks (DNNs), especially convolutional neural networks (ConvNets), have demonstrated state-of-the-art performance in many computer vision tasks, including object detection (redmon2016you), image classification (he2016deep), and segmentation (long2015fully). Ever since the success of U-Net in the ISBI cell tracking challenge of 2015 (ronneberger2015u), ConvNet-based methods have become a major focus of attention in medical image analysis fields, such as tumor segmentation (isensee2021nnu; zhou2019unet++), image reconstruction (zhu2018image), and disease diagnostics (lian2018hierarchical). In medical image registration, ConvNet-based methods can produce significantly improved registration performance while operating orders of magnitude faster (after training) compared to traditional methods. ConvNet-based methods replace the costly per-image optimization seen in traditional methods with a single global function optimization during a training phase. The ConvNets learn a common representation of image registration from training images, enabling rapid alignment of an unseen image pair after training. Initially, supervision with ground-truth deformation fields (which are usually generated using traditional registration methods) was needed for training the neural networks (onofrey2013semi; yang2017quicksilver; rohe2017svf). Recently, the focus has shifted towards developing unsupervised methods that do not depend on ground-truth deformation fields (balakrishnan2019voxelmorph; dalca2019unsupervised; kim2021cyclemorph; de2019deep; de2017end; lei20204d; chen2020generating; zhang2018inverse).
Nearly all of the existing deep-learning-based methods mentioned above use U-Net (ronneberger2015u) or simply modified versions of U-Net (e.g., tweaking the number of layers or changing down- and up-sampling schemes) as their ConvNet designs. ConvNet architectures generally have limitations in modeling explicit long-range spatial relations (i.e., relations between two voxels that are far away from each other) present in an image, due to the intrinsic locality (i.e., the limited effective receptive field) of convolution operations (luo2016understanding). U-Net (or V-Net (milletari2016v)) attempts to overcome this limitation by introducing down- and up-sampling operations into a ConvNet, which theoretically enlarges the receptive field and thus encourages the network to consider long-range relationships between points in images. However, several problems remain: first, the receptive fields of the first several layers are still restricted by the convolution-kernel size, and the global information of an image can only be viewed at the deeper layers of the network; second, it has been shown that as the convolutional layers deepen, the impact from far-away voxels decays quickly (li2021medical). Therefore, the effective receptive field of a U-Net is, in practice, much smaller than its theoretical receptive field and covers only a portion of a typical medical image. Yet, it is desirable to have a registration network capable of seeing long-range relationships between points in the moving and fixed images. Otherwise, the network is limited to producing relatively small local deformations, yielding suboptimal registration performance. Many works in other fields (e.g., image segmentation) have addressed this limitation of U-Net (zhou2019unet++; jha2019resunet++; devalla2018drunet; alom2018recurrent). To allow for a better flow of multi-scale contextual information throughout the network, Zhou et al. (zhou2019unet++) proposed a nested U-Net (i.e., U-Net++), which uses complex up- and down-sampling paths along with multiple skip connections. Devalla et al. (devalla2018drunet) introduced dilated convolution into the U-Net architecture to enlarge the network's effective receptive field. A similar idea was proposed by Alom et al. (alom2018recurrent), where the network's effective receptive field was increased by deploying recurrent convolutional operations. Jha et al. proposed ResUNet++ (jha2019resunet++), which incorporates attention mechanisms into U-Net for modeling long-range spatial information. Despite these methods' promising performance in other medical imaging fields, limited work has been done on using advanced network architectures for medical image registration.
The Transformer, which originated in natural language processing (vaswani2017attention), has shown its potential in computer vision tasks. A Transformer deploys self-attention mechanisms to determine which parts of the input sequence (e.g., an image) are essential based on contextual information. Unlike convolution operations, the self-attention mechanisms in a Transformer have an effectively unlimited receptive field, making a Transformer capable of capturing long-range spatial information (li2021medical). Dosovitskiy et al. (dosovitskiy2020image) proposed the Vision Transformer (ViT), which applies the Transformer encoder from NLP directly to images. It was the first purely self-attention-based network for computer vision and achieved state-of-the-art performance in image recognition. Subsequent to its success, Swin Transformer (liu2021swin) and its variants (dai2021dynamic; dong2021cswin) have demonstrated superior performance in object detection and semantic segmentation. Recently, Transformer-related methods have gained increased attention in the medical imaging community as well (chen2021transunet; xie2021cotr; wang2021transbts; li2021medical; wang2021ted; zhang2021transct), where the majority of the methods have been applied to the task of image segmentation. A Transformer can be a strong candidate for image registration because it can better comprehend the spatial correspondence between the moving and fixed images. Registration is the process of establishing such correspondence, intuitively by comparing different parts of the moving image to the fixed image. A ConvNet has a narrow field of view and performs convolution locally, limiting its ability to associate distant parts of the two images. For example, if the left part of the moving image matches the right part of the fixed image, a ConvNet will be unable to establish the proper spatial correspondence between the two parts if it cannot see both parts concurrently (i.e., when one of the parts falls outside of the ConvNet's field of view). A Transformer, however, is capable of handling such circumstances and rapidly focusing on the parts that need displacement, owing to its unlimited receptive field and self-attention mechanism.
Our group has previously shown preliminary results demonstrating that a simple bridging of ViT and V-Net can achieve promising performance in image registration (chen2021vit). In this work, we extend the preliminary work, investigate various Transformer models from other tasks (i.e., computer vision and medical imaging tasks), and present a hybrid Transformer-ConvNet framework, TransMorph, for volumetric medical image registration. In this method, the Swin Transformer (liu2021swin) was employed as the encoder to capture the spatial correspondence between the input moving and fixed images. A ConvNet decoder then processed the information provided by the Transformer encoder into a dense displacement field. Long skip connections were deployed to maintain the flow of local information between the encoder and decoder stages. We also introduced diffeomorphic variants of TransMorph to ensure a smooth and topology-preserving deformation. Additionally, we applied variational inference to the parameters of TransMorph, resulting in a Bayesian model that predicts registration uncertainty based on the given image pair. Qualitative and quantitative evaluations of the experimental results demonstrate the robustness of the proposed method and confirm the efficacy of Transformers for image registration.
The main contributions of this work are summarized as follows:
Transformer-based model: This paper presents the pioneering work on using Transformers for image registration. A novel Transformer-based neural network, TransMorph, was proposed for affine and deformable image registration.
Architecture analysis: Experiments in this paper demonstrate that positional embedding, which is a commonly used element in Transformer by convention, is not required for the proposed hybrid Transformer-ConvNet model. Additionally, this paper provides proof that Transformer-based models have larger effective receptive fields than ConvNets.
Diffeomorphic registration: Two diffeomorphic variants of TransMorph are presented that ensure a diffeomorphism in the deformation.
Uncertainty quantification: This paper also provides a Bayesian variant of TransMorph that yields well-calibrated uncertainty estimates.
State-of-the-art results: We extensively validate the proposed registration models on the brain MRI registration application and on a novel application of XCAT-to-CT registration, with the aim of creating a population of anatomically variable XCAT phantoms. The datasets used in this study include over 1000 image pairs for training and testing. The proposed models were compared with various registration methods and demonstrated state-of-the-art performance. Eight registration approaches were employed as baselines, including learning-based methods and commonly used traditional methods. The performances of four recently proposed Transformer architectures from other tasks (e.g., semantic segmentation, classification, etc.) were also evaluated on the task of image registration.
Open source: We provide the community with a fast and accurate tool for deformable registration. The code and the trained models are publicly available at https://bit.ly/37eJS6N.
The paper is organized as follows. Section 2 discusses the related works. Section 3 explains the proposed methodology. Section 4 discusses experimental setup, implementation details, and datasets used in this study. Section 5 presents experimental results. Section 6 discusses the findings based on the results, and Section 7 concludes the paper.
DIR establishes spatial correspondence between two images by optimizing an energy function:
$\mathcal{E}(I_f, I_m, \phi) = \mathcal{L}_{sim}(I_f, I_m \circ \phi) + \lambda \mathcal{R}(\phi) \qquad (1)$
where $I_m$ and $I_f$ denote, respectively, the moving and fixed images; $\phi$ denotes the deformation field that warps the moving image (i.e., $I_m \circ \phi$); $\mathcal{R}(\phi)$ imposes smoothness of the deformation field; and $\lambda$ is the regularization trade-off parameter. The optimal warping, $\hat{\phi}$, is given by minimizing this energy function:
$\hat{\phi} = \underset{\phi}{\arg\min}\ \mathcal{E}(I_f, I_m, \phi) \qquad (2)$
In the energy function, $\mathcal{L}_{sim}$ measures the level of alignment between the deformed moving image, $I_m \circ \phi$, and the fixed image, $I_f$. Some common choices for $\mathcal{L}_{sim}$ are mean squared error (MSE) (beg2005computing), sum of squared differences (SSD) (wolberg2000robust), normalized cross-correlation (NCC) (avants2008symmetric), structural similarity index (SSIM) (chen2020generating), and mutual information (MI) (viola1997alignment). The regularization term, $\mathcal{R}(\phi)$, imposes spatial smoothness on the deformation field. A common assumption in most applications is that similar structures exist in both the moving and fixed images. As a result, a continuous and invertible deformation field (i.e., a diffeomorphism) is needed to preserve topology, and the regularization, $\mathcal{R}(\phi)$, is meant to enforce or encourage this. Isotropic diffusion (equivalent to Gaussian smoothing) (balakrishnan2019voxelmorph), anisotropic diffusion (pace2013locally), total variation (vishnevskiy2016isotropic), and bending energy (johnson2002consistent) are popular options for $\mathcal{R}(\phi)$.
While traditional image registration methods iteratively minimize the energy function in (1) for each pair of moving and fixed images, DNN-based methods optimize the energy function for a training dataset, thereby learning a global representation of image registration that enables alignment of an unseen pair of volumes. DNN methods are often categorized as supervised or unsupervised, with the former requiring a ground truth deformation field for training and the latter relying only on the image datasets.
In supervised DNN methods, the ground-truth deformation fields are either produced synthetically or generated by traditional registration methods (yang2017quicksilver; sokooti2017nonrigid; cao2018deep). Yang et al. (yang2017quicksilver) proposed a supervised ConvNet that predicts the LDDMM (beg2005computing) momentum from image patches. Sokooti et al. (sokooti2017nonrigid) trained a registration ConvNet with synthetic displacement fields. The ground-truth deformation fields are often computationally expensive to generate, and the registration accuracy of these methods is highly dependent on the quality of the ground truth.
Due to the limitations of supervised methods, the focus of research has shifted to unsupervised DNN methods that do not need ground-truth deformation fields. Unsupervised DNNs optimize an energy function on the input images, similarly to traditional methods. However, DNN-based methods learn a common registration representation from a training set and then apply it to unseen images. Note that the term "unsupervised" refers to the absence of ground-truth deformation fields; the network still needs training (this setting is also known as "self-supervised"). Representative unsupervised DNN-based methods include (de2019deep; balakrishnan2018unsupervised; balakrishnan2019voxelmorph).
Later on, diffeomorphic deformation representation was developed to address the issue of non-smooth deformations in DNN-based methods. We briefly introduce its concepts in the next subsection.
Diffeomorphic deformable image registration is important in many medical imaging applications, owing to its special properties, including topology preservation and transformation invertibility. A diffeomorphism is a smooth and continuous one-to-one mapping with invertible derivatives (i.e., a non-zero Jacobian determinant). In some traditional methods (e.g., LDDMM (beg2005computing) and SyN (avants2008symmetric)), a diffeomorphic deformation field $\phi$ is obtained via the time-integration of a time-varying velocity field $v_t$, i.e., $\frac{d\phi^{(t)}}{dt} = v_t(\phi^{(t)})$, where $\phi^{(0)} = \mathrm{Id}$ is the identity transform. On the other hand, learning-based diffeomorphic models (dalca2019unsupervised; mok2020fast; shen2019networks; niethammer2019metric; krebs2019learning) primarily use a stationary velocity field (SVF) $v$ with an efficient scaling-and-squaring approach (arsigny2006log), because it can be implemented with relative ease. In the scaling-and-squaring approach, the stationary velocity field is treated as a member of the Lie algebra and is exponentiated to generate a time-1 deformation $\phi^{(1)}$, which is a member of the Lie group: $\phi^{(1)} = \exp(v)$. This means that the exponentiated flow field is compelled to be diffeomorphic and invertible using the same flow field. The integration starts from an initial deformation field:
$\phi^{(1/2^T)}(p) = p + \frac{v(p)}{2^T} \qquad (3)$
where $p$ denotes the spatial locations. The final deformation field $\phi^{(1)}$ can then be obtained using the recurrence:
$\phi^{(1/2^{t-1})} = \phi^{(1/2^t)} \circ \phi^{(1/2^t)} \qquad (4)$
Thus, $\phi^{(1)} = \phi^{(1/2)} \circ \phi^{(1/2)}$.
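The scaling-and-squaring integration above can be sketched in PyTorch. This is a minimal illustration, not the authors' implementation; it assumes the velocity/displacement field is already expressed in `grid_sample`'s normalized [-1, 1] coordinates with channels ordered (x, y, z):

```python
import torch
import torch.nn.functional as F

def scaling_and_squaring(v, T=7):
    """Integrate a stationary velocity field v of shape (B, 3, D, H, W).
    Scaling step: phi^(1/2^T)(p) = p + v(p)/2^T (Eqn. 3).
    Squaring step: compose the field with itself T times (Eqn. 4).
    Returns the displacement of the time-1 deformation phi^(1)."""
    B, _, D, H, W = v.shape
    disp = v / (2 ** T)  # scaling: initial small displacement
    # identity sampling grid in grid_sample's normalized coordinates
    gd, gh, gw = torch.meshgrid(
        torch.linspace(-1, 1, D), torch.linspace(-1, 1, H),
        torch.linspace(-1, 1, W), indexing="ij")
    identity = torch.stack((gw, gh, gd), dim=-1).unsqueeze(0).expand(B, -1, -1, -1, -1)
    for _ in range(T):
        # squaring: u <- u + u o (Id + u), i.e. warp the field by itself
        grid = identity + disp.permute(0, 2, 3, 4, 1)
        disp = disp + F.grid_sample(disp, grid, mode="bilinear",
                                    padding_mode="border", align_corners=True)
    return disp
```

For a spatially constant velocity field, each squaring exactly doubles the displacement, so the integration recovers the input field, which is a quick sanity check of the implementation.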
In practice, a neural network first generates a displacement field, which is scaled by $1/2^T$ to produce an initial deformation field $\phi^{(1/2^T)}$. Subsequently, the squaring technique (i.e., Eqn. 4) is applied recursively $T$ times via a spatial transformation function, resulting in the final diffeomorphic deformation field $\phi^{(1)}$. Despite the fact that diffeomorphisms are theoretically guaranteed to be invertible, interpolation errors can lead to invertibility errors that rise linearly with the number of interpolation steps (avants2008symmetric; mok2020fast).

A Transformer makes use of the self-attention mechanism, which estimates the relevance of one input sequence to another via the Query-Key-Value (QKV) model (vaswani2017attention; dosovitskiy2020image). The input sequences often originate from the flattened patches of an image. Let $x \in \mathbb{R}^{H \times W \times L}$ be an image volume defined over a 3D spatial domain. The image is first divided into $N$ flattened 3D patches $x_p \in \mathbb{R}^{N \times P^3}$, where $H \times W \times L$ is the size of the original image, $P \times P \times P$ is the size of each image patch, and $N = HWL/P^3$. Then, a learnable linear embedding $E$ is applied to $x_p$, which projects each patch into a $D$-dimensional vector representation:
$z = [x_p^1 E;\; x_p^2 E;\; \cdots;\; x_p^N E] \qquad (5)$
Then, a learnable positional embedding is added to $z$ so that the patches can retain their positional information, i.e., $z \leftarrow z + E_{pos}$, where $E_{pos} \in \mathbb{R}^{N \times D}$. These vector representations, often known as tokens, are subsequently used as inputs for self-attention computations.
To compute self-attention (SA), $z$ is encoded by a matrix $U_{qkv}$ into three matrix representations: queries $q$, keys $k$, and values $v$. The scaled dot-product attention is given by:
$\mathrm{SA}(z) = Av, \quad \text{where } A = \mathrm{softmax}\left(\frac{qk^\top}{\sqrt{D_h}}\right) \qquad (6)$
where $A$ is the attention weight matrix; each element of $A$ represents the pairwise similarity between two elements of the input sequence based on their respective query and key representations.
A Transformer employs multi-head self-attention (MSA) rather than a single attention function. MSA is an extension of self-attention in which $h$ self-attention operations (i.e., "heads") are processed simultaneously; their outputs are then concatenated and projected onto a $D$-dimensional representation:
$\mathrm{MSA}(z) = [\mathrm{SA}_1(z);\; \mathrm{SA}_2(z);\; \cdots;\; \mathrm{SA}_h(z)]\, U_{msa} \qquad (7)$
where $U_{msa} \in \mathbb{R}^{h \cdot D_h \times D}$ is a projection matrix, and $D_h$ is typically set to $D/h$ in order to keep the number of parameters constant before and after the MSA operation.
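These two equations can be sketched in PyTorch as follows. This is a minimal illustration of the generic MSA mechanism, not the paper's implementation; the module and attribute names are hypothetical:

```python
import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    """Minimal multi-head self-attention: tokens are linearly projected to
    queries, keys, and values (U_qkv); h heads attend in parallel; their
    concatenated outputs are projected back to D dimensions (U_msa)."""
    def __init__(self, dim, num_heads):
        super().__init__()
        assert dim % num_heads == 0  # D_h = D / h keeps the parameter count constant
        self.h, self.dh = num_heads, dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim)   # plays the role of U_qkv
        self.proj = nn.Linear(dim, dim)      # plays the role of U_msa

    def forward(self, z):  # z: (B, N, D)
        B, N, D = z.shape
        # split the joint projection into q, k, v, each of shape (B, h, N, D_h)
        q, k, v = self.qkv(z).reshape(B, N, 3, self.h, self.dh).permute(2, 0, 3, 1, 4)
        # A = softmax(q k^T / sqrt(D_h)), shape (B, h, N, N)
        A = torch.softmax(q @ k.transpose(-2, -1) / self.dh ** 0.5, dim=-1)
        out = (A @ v).transpose(1, 2).reshape(B, N, D)  # concatenate the h heads
        return self.proj(out)
```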
Uncertainty estimates help comprehend what a machine learning model does not know. They indicate the likelihood that a neural network may make an incorrect prediction. Because most deep neural networks are incapable of representing uncertainty, their predictions are frequently taken at face value and thought to be correct.
Bayesian deep learning estimates predictive uncertainty, providing a realistic paradigm for understanding uncertainty within deep neural networks (gal2016dropout). The uncertainty caused by the parameters of a neural network is known as epistemic uncertainty; it is modeled by placing a prior distribution (e.g., a Gaussian prior distribution) on the parameters of a network and then attempting to capture how much these weights vary given specific data. Recent efforts in this area include Bayes by Backprop (blundell2015weight), its closely related mean-field variational inference assuming a Gaussian prior distribution (Tolle2021), stochastic batch normalization (atanov2018uncertainty), and Monte-Carlo (MC) dropout (gal2016dropout; Kendall2017). The applications of Bayesian deep learning in medical imaging include image denoising (Tolle2021; Laves2020) and image segmentation (devries2018leveraging; baumgartner2019phiseg; mehrtash2020confidence). In deep-learning-based image registration, the majority of methods provide a single, deterministic solution of the unknown geometric transformation. Knowing the epistemic uncertainty helps determine if and to what degree the registration results can be trusted and whether the input data is appropriate for the neural network. The registration uncertainty estimates may be used for uncertainty-weighted registration (simpson2011longitudinal; kybic2009bootstrap) or surgical treatment planning, or directly visualized for qualitative evaluation (yang2017quicksilver). Cui et al. (cui2021bayesian) and Yang et al. (yang2017quicksilver) incorporated MC dropout layers in their registration network designs, which allows uncertainty estimates to be obtained by sampling multiple deformation field predictions from the network. The proposed image registration framework expands on these ideas. In particular, a new registration framework is presented, which uses a Transformer for the network design, diffeomorphisms for image registration, and Bayesian deep learning to estimate registration uncertainty.
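The MC-dropout sampling strategy can be sketched as follows. This is a generic, hypothetical helper (not code from any of the cited methods): the dropout layers are kept stochastic at inference time, and the voxel-wise variance over repeated predictions serves as an epistemic uncertainty map:

```python
import torch
import torch.nn as nn

def mc_dropout_predict(model, moving, fixed, n_samples=20):
    """Monte-Carlo dropout: sample several deformation field predictions
    with dropout active at inference; return their mean (the prediction)
    and voxel-wise variance (the epistemic uncertainty estimate)."""
    model.eval()
    for m in model.modules():          # re-enable only the dropout layers
        if isinstance(m, nn.Dropout):
            m.train()
    with torch.no_grad():
        flows = torch.stack([model(moving, fixed) for _ in range(n_samples)])
    return flows.mean(dim=0), flows.var(dim=0)
```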
The conventional paradigm of image registration is shown in Fig. 3. The moving and fixed images, denoted respectively as $I_m$ and $I_f$, are first affinely transformed into a single coordinate system. The resulting affine-aligned moving image is denoted as $\hat{I}_m$. Subsequently, $\hat{I}_m$ is warped towards $I_f$ using a deformation field, $\phi$, generated by a DIR algorithm (i.e., $\hat{I}_m \circ \phi$). Fig. 1 presents an overview of the proposed method. Here, both the affine transformation and the deformable registration are performed using Transformer-based neural networks. The affine Transformer takes $I_m$ and $I_f$ as inputs and computes a set of affine transformation parameters (e.g., rotation angle, translation, etc.). These parameters are used to affinely align $I_m$ with $I_f$ via an affine transformation function, yielding the aligned image $\hat{I}_m$. Then, a DIR network computes a deformation field $\phi$ given $\hat{I}_m$ and $I_f$, which warps $\hat{I}_m$ using a spatial transformation function (i.e., $\hat{I}_m \circ \phi$). During training, the DIR network may optionally leverage supplementary information (e.g., anatomical segmentations). The network architectures, the loss and regularization functions, and the variants of the method are described in detail in the following sections.
Affine transformation is often used as the initial stage in image registration because it facilitates the optimization of the following more complicated DIR processes (de2019deep). An affine network examines a pair of moving and fixed images globally and produces a set of transformation parameters that aligns the moving image with the fixed image. Here, the architecture of the proposed Transformer-based affine network is a modified Swin Transformer (liu2021swin) that takes two 3D volumes as the inputs (i.e., and ) and generates 12 affine parameters: three rotation angles, three translation parameters, three scaling parameters, and three shearing parameters. The details and a visualization of the architecture are shown in Fig. 18 in the Appendix. We reduced the number of parameters in the original Swin Transformer due to the relative simplicity of affine registration. The specifics of the Transformer’s architecture and parameter settings are covered in a subsequent section.
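For illustration, the 12 predicted parameters can be composed into a homogeneous affine matrix roughly as follows. The composition order (shear, then scale, then rotate, then translate) and the rotation/shear conventions here are illustrative assumptions; the paper does not specify them in this passage:

```python
import math
import torch

def params_to_affine(rot, trans, scale, shear):
    """Compose a 4x4 homogeneous affine matrix from three rotation angles
    (radians), three translations, three scales, and three shears."""
    rx, ry, rz = rot
    Rx = torch.tensor([[1, 0, 0],
                       [0, math.cos(rx), -math.sin(rx)],
                       [0, math.sin(rx),  math.cos(rx)]], dtype=torch.float32)
    Ry = torch.tensor([[ math.cos(ry), 0, math.sin(ry)],
                       [0, 1, 0],
                       [-math.sin(ry), 0, math.cos(ry)]], dtype=torch.float32)
    Rz = torch.tensor([[math.cos(rz), -math.sin(rz), 0],
                       [math.sin(rz),  math.cos(rz), 0],
                       [0, 0, 1]], dtype=torch.float32)
    S = torch.diag(torch.tensor(scale, dtype=torch.float32))   # scaling
    Sh = torch.tensor([[1, shear[0], shear[1]],                # shearing
                       [0, 1, shear[2]],
                       [0, 0, 1]], dtype=torch.float32)
    A = torch.eye(4)
    A[:3, :3] = Rz @ Ry @ Rx @ S @ Sh
    A[:3, 3] = torch.tensor(trans, dtype=torch.float32)
    return A
```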
Fig. 2 shows the network architecture of the proposed TransMorph. The encoder of the network first splits the input moving and fixed volumes into non-overlapping 3D patches, each of size $P \times P \times P$, where $P$ is typically set to 4 (dosovitskiy2020image; liu2021swin; dong2021cswin). We denote the $i$-th patch as $x_p^i$, where $i \in \{1, \ldots, N\}$ and $N = HWL/P^3$ is the total number of patches. Each patch is flattened and regarded as a "token", and then a linear projection layer is used to project each token to a feature representation of an arbitrary dimension (denoted as $C$):
$T = [x_p^1 E;\; x_p^2 E;\; \cdots;\; x_p^N E] \qquad (8)$
where $E$ denotes the linear projection, and the output $T$ has a dimension of $N \times C$.
Because the linear projection operates on image patches and does not keep a token's location relative to the image as a whole, existing Transformers often add a positional embedding to the linear projections in order to integrate the positional information into the tokens (vaswani2017attention; dosovitskiy2020image; liu2021swin; dong2021cswin). Such Transformers were primarily designed for image classification, where the output is often a vector describing the likelihood of an input image belonging to a certain class. Thus, if the positional embedding is not employed, the Transformer may lose the positional information. However, for pixel-level tasks such as image registration, the network often includes a decoder that generates a dense prediction with the same resolution as the input or target image. The exact voxel locations in the output image are enforced by comparing the output with the target image using a loss function. Any spatial mismatch between the output and the target contributes to the loss and is backpropagated into the Transformer encoder. The Transformer should thereby inherently capture the tokens' positional information. In this work, we observed that positional embedding is not necessary for image registration; it only adds extra parameters to the network without improving performance. The effects of positional embedding will be discussed in more detail later in this paper.
Following the linear projection layer, several consecutive stages of patch merging and Swin Transformer blocks (liu2021swin) are applied to the tokens $T$. The Swin Transformer blocks output the same number of tokens as the input, while the patch merging layers concatenate the features of each group of $2 \times 2 \times 2$ neighboring tokens, thereby reducing the number of tokens by a factor of 8 (i.e., a $2\times$ downsampling of resolution). Then, a linear layer is applied to the $8C$-dimensional concatenated features to produce features of dimension $2C$. After four stages of Swin Transformer blocks with three stages of patch merging in between (i.e., the orange boxes in Fig. 2), the output dimension at the last stage of the encoder is $\frac{H}{32} \times \frac{W}{32} \times \frac{L}{32} \times 8C$. The decoder consists of successive upsampling and convolutional layers. Each upsampled feature map in the decoding stage is concatenated with the corresponding feature map from the encoding path via a skip connection and is then followed by two consecutive convolutional layers. As shown in Fig. 2, the Transformer encoder can only provide feature maps down to a resolution of $\frac{H}{4} \times \frac{W}{4} \times \frac{L}{4}$ owing to the nature of the patch operation (denoted by the orange arrows). Hence, the Transformer may fall short of delivering high-resolution feature maps and aggregating local information at the lower layers (raghu2021vision). To address this shortcoming, we employed two convolutional layers, using the original and downsampled image pairs as inputs, to capture local information and generate high-resolution feature maps. The outputs of these layers were concatenated with the feature maps in the decoder to produce a deformation field. The final deformation field, $\phi$, was generated by applying sixteen convolutions. Except for the last convolutional layer, each convolutional layer is followed by a Leaky Rectified Linear Unit (maas2013rectifier) activation. In the next subsections, we discuss the Swin Transformer block, the spatial transformation function, and the loss functions in detail.
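The 3D patch-merging step can be sketched in PyTorch as follows. This is a minimal illustration assuming, as in the Swin Transformer design, that each 2×2×2 group of neighboring tokens is concatenated (8C channels) and linearly projected to 2C channels; it is not the authors' exact implementation:

```python
import torch
import torch.nn as nn

class PatchMerging3D(nn.Module):
    """3D patch merging: concatenate the features of each 2x2x2 group of
    neighboring tokens (giving 8C channels) and project them to 2C,
    halving the spatial resolution of the token map."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(8 * dim)
        self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)

    def forward(self, x):  # x: (B, D, H, W, C)
        # the 8 interleaved sub-grids of the 2x2x2 neighborhoods
        parts = [x[:, i::2, j::2, k::2, :]
                 for i in (0, 1) for j in (0, 1) for k in (0, 1)]
        x = torch.cat(parts, dim=-1)          # (B, D/2, H/2, W/2, 8C)
        return self.reduction(self.norm(x))   # (B, D/2, H/2, W/2, 2C)
```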
The Swin Transformer (liu2021swin) can generate hierarchical feature maps at various resolutions by using patch merging layers, making it ideal for use as a general-purpose backbone for pixel-level tasks like image registration and segmentation. The Swin Transformer's most significant component, apart from the patch merging layers, is the shifted window-based self-attention mechanism. Unlike ViT (dosovitskiy2020image), which computes the relationships between a token and all other tokens at each step of its self-attention modules, the Swin Transformer computes self-attention within evenly partitioned, non-overlapping local windows of the original and lower-resolution feature maps (as shown in Fig. 5 (a)). In contrast to the original Swin Transformer, this work uses rectangular windows to accommodate non-square images. At each resolution, the first Swin Transformer block employs a regular window partitioning scheme: beginning with the top-left voxel, the feature maps are evenly partitioned into non-overlapping windows. Self-attention is then calculated locally within each window. To introduce connections between neighboring windows, the Swin Transformer uses a shifted window design, i.e., in each successive Swin Transformer block, the windowing configuration shifts from that of the preceding block by displacing the windows by half of the window size along each axis. As illustrated by the example in Fig. 5 (b), the feature map is evenly partitioned into $2 \times 2 \times 2 = 8$ windows in the first Swin Transformer block ("Swin Block 1" in Fig. 5 (b)). Then, in the next block, the windows are shifted by half of the window size, and the number of windows becomes $3 \times 3 \times 3 = 27$. We extended the original 2D efficient batch computation (i.e., cyclic shift) (liu2021swin; liu2021video) to 3D and applied it to the 27 shifted windows, keeping the final number of windows for attention computation at 8.
With the windowing-based attention, two consecutive Swin Transformer blocks can be computed as:
$\hat{z}^l = \text{W-MSA}\big(\mathrm{LN}(z^{l-1})\big) + z^{l-1},$
$z^l = \mathrm{MLP}\big(\mathrm{LN}(\hat{z}^l)\big) + \hat{z}^l,$
$\hat{z}^{l+1} = \text{SW-MSA}\big(\mathrm{LN}(z^l)\big) + z^l,$
$z^{l+1} = \mathrm{MLP}\big(\mathrm{LN}(\hat{z}^{l+1})\big) + \hat{z}^{l+1} \qquad (9)$
where W-MSA and SW-MSA stand for, respectively, the window-based and shifted-window-based multi-head self-attention modules; MLP denotes the multi-layer perceptron module (vaswani2017attention); LN denotes layer normalization; and $\hat{z}^l$ and $z^l$ denote the output features of the (S)W-MSA module and the MLP module for block $l$, respectively. The self-attention is computed as:

$\mathrm{Attention}(Q, K, V) = \mathrm{Softmax}\left(\frac{QK^\top}{\sqrt{d}} + B\right)V \qquad (10)$
where $Q$, $K$, and $V$ are the query, key, and value matrices, $d$ denotes the dimension of the query and key features, $N_w$ is the number of tokens in a 3D window, and $B \in \mathbb{R}^{N_w \times N_w}$ represents the relative position bias of the tokens in each window. Since the relative position between tokens along each axis can only take values in $[-M+1,\, M-1]$ (for a window size of $M$ along that axis), the values in $B$ are taken from a smaller bias matrix $\hat{B} \in \mathbb{R}^{(2M-1) \times (2M-1) \times (2M-1)}$.
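The window partitioning and the 3D cyclic shift used for efficient batch computation can be sketched as follows. This is an illustrative sketch following the Swin design, not the authors' code; `ws` is an assumed per-axis window size:

```python
import torch

def window_partition_3d(x, ws):
    """Partition a feature map (B, D, H, W, C) into non-overlapping 3D
    windows of size ws = (wd, wh, ww); returns (num_windows * B,
    wd*wh*ww, C) token groups for local self-attention."""
    B, D, H, W, C = x.shape
    wd, wh, ww = ws
    x = x.view(B, D // wd, wd, H // wh, wh, W // ww, ww, C)
    return x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(-1, wd * wh * ww, C)

def cyclic_shift_3d(x, ws):
    """Cyclically shift the map by half a window along each spatial axis
    (torch.roll), so the shifted-window attention of the next block can
    reuse the same regular partitioning."""
    wd, wh, ww = ws
    return torch.roll(x, shifts=(-(wd // 2), -(wh // 2), -(ww // 2)),
                      dims=(1, 2, 3))
```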
The spatial transformation function proposed in (jaderberg2015spatial) is used to apply a nonlinear warp to the moving image with the deformation field $\phi$ (or, equivalently, the displacement field $u$, where $\phi = \mathrm{Id} + u$) provided by the network. The intensity at each voxel location $p$ in the output image is determined by:
$(I_m \circ \phi)(p) = I_m\big(p + u(p)\big) \qquad (11)$
Note that $p' = p + u(p)$ is not necessarily an integer location, and voxel intensities are only defined at integer locations in an image. Thus, the intensity at $p'$ is obtained by applying an interpolation method to the intensities of the neighboring voxels around $p'$:
$(I_m \circ \phi)(p) = \sum_{q \in \mathcal{Z}(p')} I_m(q) \prod_{d \in \{x, y, z\}} k(q_d - p'_d) \qquad (12)$
where $p' = p + u(p)$, $\mathcal{Z}(p')$ denotes the neighboring voxel locations of $p'$, $d$ indexes the three directions in 3D image space, and $k(\cdot)$ is a generic sampling kernel that defines the interpolation method (e.g., tri-linear, nearest-neighbor, etc.). Tri-linear interpolation is often used on images, while nearest-neighbor interpolation is often used on label maps. Due to the non-differentiable nature of the nearest-neighbor sampling kernel, it can only be used for inference and not for network training.
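A differentiable spatial transformation with tri-linear sampling can be sketched in PyTorch via `grid_sample`. This is a minimal sketch, not the authors' implementation; it assumes the displacement field is given in voxel units with channels ordered (x, y, z):

```python
import torch
import torch.nn.functional as F

def spatial_transform(moving, flow):
    """Warp a moving image (B, 1, D, H, W) with a displacement field
    flow (B, 3, D, H, W) in voxel units, via differentiable tri-linear
    sampling at the non-integer locations p' = p + u(p)."""
    B, _, D, H, W = moving.shape
    gd, gh, gw = torch.meshgrid(
        torch.arange(D), torch.arange(H), torch.arange(W), indexing="ij")
    identity = torch.stack((gw, gh, gd), dim=-1).float()       # (D, H, W, 3)
    loc = identity.unsqueeze(0) + flow.permute(0, 2, 3, 4, 1)  # p' = p + u(p)
    # normalize voxel coordinates to grid_sample's [-1, 1] range
    size = torch.tensor([W, H, D], dtype=torch.float32)
    grid = 2 * loc / (size - 1) - 1
    return F.grid_sample(moving, grid, mode="bilinear", align_corners=True)
```

With a zero displacement field, the function reproduces the input image exactly, which is a convenient correctness check.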
The overall loss function for network training derives from the energy function of traditional image registration algorithms (i.e., Eqn. (1)). The loss function consists of two parts: one computes the similarity between the deformed moving and the fixed images, and the other regularizes the deformation field so that it is smooth:

L(f, m, φ) = L_sim(f, m ∘ φ) + λ R(φ),    (13)

where L_sim denotes the image fidelity measure, R denotes the deformation field regularization, and λ is a weighting hyperparameter.
In this work, we experimented with two widely-used similarity metrics for L_sim. The first is the mean squared error (MSE), which is the mean of the squared difference in voxel values between f and m ∘ φ:

MSE(f, m ∘ φ) = (1/|Ω|) Σ_{p∈Ω} [f(p) − m ∘ φ(p)]²,    (14)

where p denotes the voxel location and Ω represents the image domain.
Another similarity metric is the local normalized cross-correlation (LNCC) between f and m ∘ φ:

LNCC(f, m ∘ φ) = Σ_{p∈Ω} [Σ_{p_i}(f(p_i) − f̄(p))(m ∘ φ(p_i) − m̄∘φ(p))]² / [Σ_{p_i}(f(p_i) − f̄(p))² Σ_{p_i}(m ∘ φ(p_i) − m̄∘φ(p))²],    (15)

where f̄(p) and m̄∘φ(p) denote the mean voxel values within a local window of size n³ centered at voxel p, over which p_i iterates, and n is the window size used in our experiments.
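A minimal NumPy sketch of the two similarity metrics, using box filters for the local means; the window size n = 9 below is an assumed value, and the small eps is added for numerical stability:

```python
import numpy as np
from scipy.ndimage import uniform_filter

def mse(f, wm):
    # Eqn (14): mean squared voxel-wise difference
    return np.mean((f - wm) ** 2)

def lncc(f, wm, n=9, eps=1e-5):
    """Eqn (15): local normalized cross-correlation over n^3 windows,
    computed with box filters (n = 9 is an assumed window size)."""
    mu_f, mu_w = uniform_filter(f, n), uniform_filter(wm, n)
    var_f = uniform_filter(f * f, n) - mu_f ** 2
    var_w = uniform_filter(wm * wm, n) - mu_w ** 2
    cross = uniform_filter(f * wm, n) - mu_f * mu_w
    return np.mean(cross ** 2 / (var_f * var_w + eps))

rng = np.random.default_rng(0)
f = rng.standard_normal((16, 16, 16))
```

Because LNCC normalizes by the local variances, it is invariant to affine intensity changes, which is why it is preferred for multi-contrast data such as XCAT-to-CT.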
Optimizing the similarity metric alone would encourage m ∘ φ to be visually as close as possible to f. The resulting deformation field φ, however, might not be smooth or realistic. To impose smoothness on the deformation field, a regularizer R(φ) is added to the loss function. R(φ) encourages the displacement value at a location to be similar to the values at its neighboring locations. Here, we experimented with two regularizers for R(φ). The first is the diffusion regularizer (balakrishnan2019voxelmorph):

R(φ) = Σ_{p∈Ω} ‖∇u(p)‖²,    (16)

where ∇u(p) denotes the spatial gradients of the displacement field u at voxel p. The spatial gradients are approximated using forward differences, that is, ∂u(p)/∂x ≈ u(p_x + 1, p_y, p_z) − u(p_x, p_y, p_z), and analogously for the y- and z-directions.
The second regularizer is the bending energy (rueckert1999nonrigid), which penalizes sharply curved deformations and thus may be helpful for abdominal organ registration. The bending energy operates on the second derivatives of the displacement field u:

R(φ) = Σ_{p∈Ω} [(∂²u(p)/∂x²)² + (∂²u(p)/∂y²)² + (∂²u(p)/∂z²)² + 2(∂²u(p)/∂x∂y)² + 2(∂²u(p)/∂x∂z)² + 2(∂²u(p)/∂y∂z)²],    (17)

where the derivatives are estimated using the same forward differences as before.
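Both regularizers reduce to sums of squared forward differences. A NumPy sketch follows; the replication padding at the far boundary is an assumed boundary handling:

```python
import numpy as np

def fdiff(u, axis):
    # forward differences; the last slice is padded by replication
    return np.diff(u, axis=axis, append=np.take(u, [-1], axis=axis))

def diffusion_reg(u):
    # Eqn (16): mean squared spatial gradient of the displacement field;
    # u has shape (3, D, H, W): one channel per displacement component
    return np.mean(sum(fdiff(u, ax) ** 2 for ax in (1, 2, 3)))

def bending_energy(u):
    # Eqn (17): second-order forward differences, penalizing curvature
    return np.mean(sum(fdiff(fdiff(u, a1), a2) ** 2
                       for a1 in (1, 2, 3) for a2 in (1, 2, 3)))

translation = np.ones((3, 4, 4, 4))   # constant displacement field
```

A spatially constant displacement (a pure translation) incurs zero penalty under both regularizers, as expected: only spatial variation of u is penalized.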
When the organ segmentations of m and f are available, TransMorph may leverage this auxiliary information during training to improve the anatomical mapping between m ∘ φ and f. A loss function that quantifies the segmentation overlap is added to the overall loss function (Eqn. 13):

L(f, m, s_f, s_m, φ) = L_sim(f, m ∘ φ) + λ R(φ) + γ L_seg(s_f, s_m ∘ φ),    (18)

where s_m and s_f represent, respectively, the organ segmentations of m and f, and γ is a weighting parameter that controls the strength of L_seg. In the field of image registration, it is common to use the Dice score (dice1945measures) as a figure of merit to quantify registration performance. Therefore, we directly minimized the Dice loss (milletari2016v) between s_m ∘ φ and s_f, where k indexes the K structures/organs:

L_seg(s_f, s_m ∘ φ) = 1 − (1/K) Σ_{k=1}^{K} 2|s_f^k ∩ (s_m^k ∘ φ)| / (|s_f^k| + |s_m^k ∘ φ|),    (19)

s_m ∘ φ should ideally be implemented via nearest-neighbor interpolation, since s_m is a binary mask. As mentioned in Section 3.2.2, nearest-neighbor interpolation is not differentiable and therefore cannot be used for network training. Here, we used a method similar to that described in (balakrishnan2019voxelmorph), in which we designed s_m and s_f as image volumes with K channels, each channel containing a binary mask defining the segmentation of a specific structure/organ. Then, s_m ∘ φ is computed by warping the K-channel s_m with φ using linear interpolation so that the gradients of L_seg can be backpropagated into the network.
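A NumPy sketch of this K-channel soft Dice loss; the shapes and the eps smoothing term are illustrative assumptions:

```python
import numpy as np

def dice_loss(warped_seg, fixed_seg, eps=1e-5):
    """Soft Dice loss (Eqn 19) over K organ channels.  Both inputs have
    shape (K, D, H, W); `warped_seg` may contain fractional values
    because the K-channel mask is warped with linear interpolation."""
    inter = (warped_seg * fixed_seg).sum(axis=(1, 2, 3))
    sizes = warped_seg.sum(axis=(1, 2, 3)) + fixed_seg.sum(axis=(1, 2, 3))
    return 1.0 - np.mean((2.0 * inter + eps) / (sizes + eps))

# two-organ toy masks: organ 0 in the top half, organ 1 in the bottom half
s = np.zeros((2, 4, 4, 4))
s[0, :2], s[1, 2:] = 1.0, 1.0
```

Perfect overlap yields a loss near 0, while swapping the two disjoint organ channels yields a loss near 1, matching the intended [0, 1] range of the Dice loss.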
In this section, we present two variants of TransMorph that guarantee a diffeomorphic deformation such that the resulting deformable mapping is continuous, differentiable, and preserves topology. We achieved diffeomorphic registration using the scaling-and-squaring approach (described in section 2.1.2) with a stationary velocity field representation (arsigny2006log). The two existing diffeomorphic models, VoxelMorph-diff (dalca2019unsupervised) and MIDIR (qiu2021learning), have been adopted as bases for the proposed TransMorph diffeomorphic variants, designated by TransMorph-diff (section 3.3.1) and TransMorph-bspl (section 3.3.2), respectively. The next subsections discuss these models in depth.
As shown in Fig. 6, we introduced a variational inference framework to the proposed TransMorph (which we denote as TransMorph-diff). A prior distribution

p(z) = N(z; 0, Σ_z)    (20)

was placed over the latent velocity field z that parameterizes the dense displacement field, where 0 and Σ_z are the mean and covariance of the multivariate normal distribution. We followed (dalca2019unsupervised) and defined Λ_z = Σ_z⁻¹ = λL, where Λ_z denotes the precision matrix, λ controls the scale of z, L = D − A is the Laplacian matrix of a neighborhood graph formed on the voxel grid, D is the graph degree matrix, and A is a voxel neighborhood adjacency matrix. The probability p(f | m) can be computed using the law of total probability:

p(f | m) = ∫_z p(f | z; m) p(z) dz.    (21)
The likelihood was also assumed to be Gaussian:

p(f | z; m) = N(f; m ∘ φ_z, σ²𝕀),    (22)

where σ² captures the variance of the image noise, and φ_z is the group exponential of the time-stationary velocity field z, i.e., φ_z = exp(z), which is computed using the scaling-and-squaring approach (section 2.1.2). Our goal is to estimate the posterior probability p(z | f; m). Due to the intractable nature of the integral over z in Eqn. 21, p(f | m) is usually approximated using just the z's that are most likely to have generated f (krebs2019learning). Since computing the posterior analytically is also intractable, we instead assumed a variational posterior q_ψ(z | f; m) learned by a network with parameters ψ. The Kullback–Leibler (KL) divergence is used to relate the variational posterior to the actual posterior, which results in the evidence lower bound (ELBO) (kingma2013auto):

KL[q_ψ(z | f; m) ‖ p(z | f; m)] = KL[q_ψ(z | f; m) ‖ p(z)] − E_q[log p(f | z; m)] + log p(f | m),    (23)

where the KL divergence on the left-hand side vanishes if the variational posterior is identical to the actual posterior. Therefore, maximizing log p(f | m) is equivalent to minimizing the negative of the ELBO on the right-hand side of Eqn. 23. Since the prior distribution was assumed to be a multivariate Gaussian, the variational posterior is likewise modeled as a multivariate Gaussian, defined as:
q_ψ(z | f; m) = N(z; μ_{z|f,m}, Σ_{z|f,m}),    (24)

where μ_{z|f,m} and Σ_{z|f,m} are the voxel-wise mean and variance generated by the network with parameters ψ. In each forward pass, the velocity field is sampled using the reparameterization trick, z = μ_{z|f,m} + σ_{z|f,m} ⊙ ε with ε ~ N(0, 𝕀). The variational parameters μ_{z|f,m} and Σ_{z|f,m} are learned by minimizing the loss (dalca2019unsupervised):

L(ψ; f, m) = (1/2σ²K) Σ_{k=1}^{K} ‖f − m ∘ φ_{z_k}‖² + (1/2)[tr(λDΣ_{z|f,m} − log Σ_{z|f,m}) + λ μ_{z|f,m}ᵀ L μ_{z|f,m}] + const,    (25)

where λμᵀLμ can be thought of as a diffusion regularization (Eqn. 16) placed over the mean displacement field μ_{z|f,m}, that is, λμᵀLμ = (λ/2) Σ_{p∈Ω} Σ_{q∈N(p)} ‖μ[p] − μ[q]‖², where N(p) represents the neighboring voxels of the p-th voxel.
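The sampling step can be sketched as follows; in the actual model this happens on PyTorch tensors inside the network, and the log-variance parameterization is an assumption made here for numerical convenience:

```python
import numpy as np

def sample_displacement(mu, log_var, rng):
    # Reparameterisation trick for Eqn (24): z = mu + sigma * eps,
    # eps ~ N(0, I), so gradients can flow through mu and sigma.
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

rng = np.random.default_rng(0)
mu = rng.standard_normal((3, 4, 4, 4))   # network-predicted mean field
# with a vanishingly small variance, the sample collapses onto the mean
z = sample_displacement(mu, np.full(mu.shape, -60.0), rng)
```

The randomness is isolated in ε, so μ and Σ remain deterministic functions of the network weights and can be optimized by backpropagation.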
As discussed in section 3.2.3, when auxiliary segmentation information is available (i.e., the label maps of m and f, denoted as s_m and s_f), a Dice loss can be used during training to further enhance registration performance. The Dice loss, however, does not preserve a Gaussian approximation. Instead, we follow (dalca2019unsupervised) and replace the KL divergence in Eqn. 23 with:

KL[q_ψ(z | f; m) ‖ p(z | f, s_f; m, s_m)],    (26)

which yields a loss function of the form:

L(ψ; f, m, s_f, s_m) = L(ψ; f, m) + γ L_seg(s_f, s_m ∘ φ_z),    (27)

In (dalca2019unsupervised), s_f and s_m represent anatomical surfaces derived from label maps. In contrast, we directly used the label maps as s_f and s_m in this work. They are image volumes with multiple channels, each channel containing a binary mask defining the segmentation of a certain structure/organ.
We incorporated a cubic B-spline model (qiu2021learning) into TransMorph (which we denote as TransMorph-bspl), in which the network produces a lattice of low-dimensional control points instead of a dense displacement field at the original resolution, which might be computationally costly. As shown in Fig. 7, we denote the displacements of the B-spline control points generated by the network as u_i and the spacing between the control points as δ. Then, a weighted combination of cubic B-spline basis functions β³(·) (rueckert1999nonrigid) is used to generate the dense displacement field (i.e., the B-spline tensor product in Fig. 7):

u(p) = Σ_{i∈C} β³((p − p_i)/δ) u_i,    (28)

where i is the index of the control points on the lattice C, p_i denotes the coordinates of the i-th control point in image space, and β³(·) is applied as a tensor product across the three spatial dimensions. The final displacement field is then obtained from the resulting time-stationary velocity field using the same scaling-and-squaring approach described in section 2.1.2.
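Scaling and squaring itself can be illustrated on a 1D toy example; the paper applies it to 3D fields with tri-linear resampling, and T = 7 integration steps is an assumed value:

```python
import numpy as np

def scaling_and_squaring(v, x, T=7):
    """Integrate a stationary velocity field v(x) by scaling and squaring
    (1D toy): u <- v / 2^T, then repeat T times u(x) <- u(x) + u(x + u(x))."""
    u = v / 2.0 ** T
    for _ in range(T):
        u = u + np.interp(x + u, x, u)  # compose the half-step with itself
    return u

x = np.arange(64, dtype=float)
v = np.full_like(x, 2.5)          # spatially constant velocity field
u = scaling_and_squaring(v, x)    # exp(v) is then a uniform shift by 2.5
```

Each squaring step doubles the integration time, so only T cheap compositions are needed instead of 2^T Euler steps, and the resulting map is (approximately) diffeomorphic.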
In this section, we extend the proposed TransMorph to a Bayesian neural network (BNN) using the variational inference framework with Monte Carlo dropout (gal2016dropout). We denote the resulting model as TransMorph-Bayes. In this model, dropout layers were inserted into the Transformer encoder of the TransMorph architecture but not into the ConvNet decoder, in order to avoid imposing excessive regularity on the network parameters and thus decreasing performance. We briefly review the fundamental ideas in the following paragraph, but we refer readers to (gal2016dropout) for more details.
Model | Conv. skip. | Trans. skip. | Parameters (M) |
---|---|---|---|
TransMorph | ✓ | ✓ | 46.77 |
TransMorph w/o conv. skip. | ✓ | - | 46.70 |
TransMorph w/o Trans. skip. | - | ✓ | 41.55 |
TransMorph w/ lrn. positional embedding | ✓ | ✓ | 63.63 |
TransMorph w/ sin. positional embedding | ✓ | ✓ | 46.77 |
Model | Embed. Dimension | Swin-T. block numbers | Head numbers | Parameters (M) |
---|---|---|---|---|
TransMorph | 96 | {2, 2, 4, 2} | {4, 4, 8, 8} | 46.77 |
TransMorph-tiny | 6 | {2, 2, 4, 2} | {4, 4, 8, 8} | 0.24 |
TransMorph-small | 48 | {2, 2, 4, 2} | {4, 4, 4, 4} | 11.76 |
TransMorph-large | 128 | {2, 2, 12, 2} | {4, 4, 8, 16} | 108.34 |
VoxelMorph-huge | - | - | - | 63.25 |
Given a training dataset of moving and fixed images {M, F} = {(m_i, f_i)}_{i=1}^{N}, where m_i represents the moving images, f_i denotes the corresponding fixed images, and N is the number of image pairs, we aim to find the predictive distribution of the form p(φ* | m*, f*, M, F) for a new image pair (m*, f*), which can be computed as:

p(φ* | m*, f*, M, F) = ∫ p(φ* | m*, f*, ω) p(ω | M, F) dω,    (29)

where ω denotes the weight matrices in the Transformer encoder. However, this integral over ω is computationally intractable, since one would need to sample a large number of networks with various configurations of the parameters to get an accurate estimate of the predictive distribution. Since the distribution p(ω | M, F) cannot be evaluated analytically, we define a variational distribution q(ω) that follows a Bernoulli distribution:

ω_i = M_i · diag(z_i),  z_{i,j} ~ Bernoulli(p_i),    (30)

where M_i is the weight matrix of the i-th layer, diag(z_i) is chosen so that the dimensions of M_i and ω_i match, and p_i is the probability that z_{i,j} = 0. It has been shown in (gal2016dropout) that ELBO maximization may be accomplished via training and inference with dropout (srivastava2014dropout). In the Transformer encoder of TransMorph, we added a dropout layer after each fully connected layer in the MLPs (Eqn. 9) and after each self-attention computation (Eqn. 10). Note that these are the locations where dropout is commonly used in Transformer training. We set the dropout probability to 0.15 to avoid the network imposing an excessive degree of regularity on the network weights. The predictive mean can be estimated using Monte Carlo integration (gal2016dropout):
E[φ] ≈ (1/K) Σ_{k=1}^{K} φ_k(m, f),    (31)

This is equivalent to averaging the outputs of K forward passes through the network during inference, where φ_k represents the deformation field produced by the k-th forward pass. The registration uncertainty can be estimated using the predictive variance as in (yang2017quicksilver; yang2016fast; yang2017fast; sentker2018gdl; sedghi2019probabilistic):

σ̂² = (1/K) Σ_{k=1}^{K} (m ∘ φ_k − m̄∘φ)²,  where m̄∘φ = (1/K) Σ_{k=1}^{K} m ∘ φ_k,    (32)
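The two Monte Carlo estimates can be sketched in NumPy with stand-in samples; in the real model the K outputs come from K dropout-enabled forward passes:

```python
import numpy as np

# K stochastic forward passes with dropout left ON at test time; here we
# fake the K sampled outputs with random arrays of the right shape.
rng = np.random.default_rng(1)
K = 50
samples = rng.standard_normal((K, 8, 8, 8))

pred_mean = samples.mean(axis=0)                          # Eqn (31)
pred_var = (samples ** 2).mean(axis=0) - pred_mean ** 2   # Eqn (32)
```

The second line uses the identity Var[x] = E[x²] − E[x]², so the variance map can be accumulated in a single pass over the K samples without storing them all.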
Estimating the uncertainty in image registration allows for evaluating the trustworthiness of the registered image, which enables the assessment of operative risks and leads to better-informed clinical decisions (luo2019applicability). An ideal uncertainty estimate should be properly correlated to the inaccuracy of the registration results; that is, a high uncertainty value should indicate a large registration error, and vice versa. Otherwise, doctors/surgeons may be misled by the erroneous estimate of registration uncertainty and place unwarranted confidence in the registration results, resulting in severe consequences (luo2019applicability; risholm2013bayesian; risholm2011estimation). The uncertainty given by Eqn. 32 is expressed as the variability from the mean model prediction. Such an uncertainty estimation does not account for the systematic errors (i.e., bias) between the mean registration prediction and the target image; therefore, a low uncertainty value given by Eqn. 32 does not always guarantee an accurate registration result.
When the predicted uncertainty values closely correspond to the expected model error, they are considered well-calibrated (laves2019well; levi2019evaluating). In an ideal scenario, the estimated registration uncertainty should completely reflect the actual registration error. For instance, if the predictive variance of a batch of registered images generated by the network is found to be 0.5, the expectation of the squared error should likewise be 0.5. Accordingly, if the expected model error is quantified by MSE, then perfect calibration of the predictive registration uncertainty may be established by extending the definitions presented in (guo2017calibration; levi2019evaluating; laves2020uncertainty):

E[(m ∘ φ − f)² | σ̂² = σ²] = σ²,  ∀ σ² ≥ 0.    (33)
In the conventional paradigm of Bayesian neural networks, the uncertainty estimate is derived from the predictive variance relative to the predictive mean, as in Eqn. 32. However, it can be shown that this predictive variance may be miscalibrated as a result of overfitting the training dataset (as shown in B). Therefore, the uncertainty values estimated based on Eqn. 32 may be biased. This bias often needs to be corrected in many applications, such as image denoising or classification (laves2019well; guo2017calibration; kuleshov2018accurate; phan2018calibrating; laves2020uncertainty; pmlr-v121-laves20a), so that the uncertainty values closely reflect the expected error. In image registration, however, the expected error may be computed even at test time, since the target image is always known. Therefore, a perfectly calibrated uncertainty quantification may be achieved without additional effort. Here, we propose to replace the predictive mean with the target image f in Eqn. 32. The estimated registration uncertainty is then equivalent to the expected error:

ê² = (1/K) Σ_{k=1}^{K} (m ∘ φ_k − f)².    (34)
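The relation between the two estimates follows the bias–variance decomposition, which can be verified numerically with stand-in data; all arrays below are fabricated for illustration:

```python
import numpy as np

rng = np.random.default_rng(2)
K = 100
f = rng.standard_normal((8, 8))                  # target (fixed) image
# K deformed-moving images with a systematic bias of +0.3 w.r.t. f
samples = f + 0.3 + 0.5 * rng.standard_normal((K, 8, 8))

mean = samples.mean(axis=0)
var_about_mean = ((samples - mean) ** 2).mean(axis=0)   # Eqn (32)-style
var_about_target = ((samples - f) ** 2).mean(axis=0)    # Eqn (34)-style
```

Because var_about_target = var_about_mean + (mean − f)², the target-referenced estimate captures the systematic error that the mean-referenced predictive variance misses.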
The comparison between the two uncertainty estimation methods (i.e., the mean-referenced and the target-referenced estimates) is shown later in this paper.
Two datasets comprising over 1000 image pairs were used to thoroughly validate the proposed method. The details of each dataset are described in the following sections.
For the inter-patient brain MR image registration task, we used a dataset of 260 T1-weighted brain MRI volumes acquired at Johns Hopkins University. The dataset was split into 182, 26, and 52 (7:1:2) volumes for the training, validation, and test sets. Each image volume was randomly matched to two other volumes in its set to form four registration pairs of f and m, resulting in 768, 104, and 208 image pairs. FreeSurfer (fischl2012freesurfer) was used to perform standard pre-processing procedures for structural brain MRI, including skull stripping, resampling, and affine transformation. The pre-processed image volumes were then cropped to an equal size of 160×192×224. Label maps of 29 anatomical structures were obtained using FreeSurfer for evaluating registration performance.
Computerized phantoms have been widely used in the medical imaging field for algorithm optimization and imaging system validation (Christoffersen2013; chen2019incorporating; zhang2017new). The four-dimensional extended cardiac-torso (XCAT) phantom (segars20104d) was developed based on anatomical images from the Visible Human Project data. While the current XCAT phantom (as of October 2021) can model anatomical variations through organ and phantom scaling, it cannot completely replicate the anatomical variations seen in humans. As a result, XCAT-to-CT registration (which can be thought of as atlas-to-image registration) has become a key method for creating anatomically variable phantoms (chen2020generating; fu2021iphantom; segars2013population). This research used a CT dataset from (segars2013population), which includes 50 non-contrast chest-abdomen-pelvis (CAP) CT scans collected from the Duke University imaging database. Selected organs and structures were manually segmented in each patient's CT scan: the body outline, bone structures, lungs, heart, liver, spleen, kidneys, stomach, pancreas, large intestine, prostate, bladder, gall bladder, and thyroid. The manual segmentation was done by several medical students, and the results were subsequently corrected by an experienced radiologist at Duke University. The CT volumes have varying voxel sizes, so we used tri-linear interpolation to resample all volumes to an identical voxel spacing. The volumes were then cropped and zero-padded to an equal size. The intensity values were first clipped to a fixed range of Hounsfield units and then normalized. The XCAT attenuation map was generated using the material compositions and attenuation coefficients of the constituents at 120 keV. It was then resampled, cropped, and padded so that the resulting volume matches the size of the CT volumes. The XCAT attenuation map's intensity values were also normalized to the same range. The XCAT and CT images were affinely aligned using the proposed affine network. The dataset was split into 35, 5, and 10 (7:1:2) volumes for training, validation, and testing. We conducted five-fold cross-validation on the fifty image volumes, resulting in 50 testing volumes in total.

To illustrate TransMorph's effectiveness, we compared it to various registration methods that have previously demonstrated state-of-the-art registration performance. We begin by comparing TransMorph with four non-deep-learning-based methods, whose hyperparameter settings are described below:
SyN (avants2008symmetric) (https://github.com/ANTsX/ANTsPy): For brain MR registration, we used the mean squared difference (MSQ) as the objective function, along with a Gaussian smoothing of 3 and three scales with 180, 80, and 40 iterations, respectively. For XCAT-to-CT registration, we used cross-correlation (CC) as the objective function, along with a Gaussian smoothing of 5 and three scales with 160, 100, and 40 iterations, respectively.
NiftyReg (modat2010fast) (https://www.ucl.ac.uk/medical-image-computing): We employed the sum of squared differences (SSD) as the objective function and bending energy as the regularizer for both tasks. For brain MR registration, we used a regularization weight of 0.0002 and three scales with 300 iterations each. For XCAT-to-CT registration, we used a regularization weight of 0.0005 and five scales with 500 iterations each.
deedsBCV (heinrich2015multi) (https://github.com/mattiaspaul/deedsBCV): The objective function was the self-similarity context (SSC) (heinrich2013towards) by default. For brain MR registration, we used the grid spacing, search radius, and quantization step suggested in (hoffmann2020learning) for neuroimaging. For XCAT-to-CT registration, we used the default values of these parameters suggested for abdominal CT registration.
LDDMM (beg2005computing) (https://github.com/brianlee324/torch-lddmm): MSE was used as the objective function by default. For brain MR registration, we used a smoothing kernel size of 5, a smoothing kernel power of 2, a matching term coefficient of 4, a regularization term coefficient of 10, and 500 iterations. For XCAT-to-CT registration, we used the same kernel size, kernel power, matching term coefficient, and number of iterations; however, the regularization term coefficient was empirically set to 3.
Next, we compared the proposed method with several existing deep-learning-based methods. For a fair comparison, unless otherwise indicated, the loss function (Eqn. 13) consisting of MSE (Eqn. 14) and the diffusion regularizer (Eqn. 16) was used for brain MR registration, while the loss function (Eqn. 18) consisting of LNCC (Eqn. 15), bending energy (Eqn. 17), and Dice (Eqn. 19) was used for XCAT-to-CT registration. Auxiliary data (organ segmentations) were used for XCAT-to-CT registration. Recall that the hyperparameters λ and γ define, respectively, the weights for the deformation field regularization and the Dice loss. The detailed parameter settings of each method are as follows:
VoxelMorph (balakrishnan2018unsupervised; balakrishnan2019voxelmorph) (http://voxelmorph.csail.mit.edu): We employed two variants of VoxelMorph, the second of which has twice the number of convolution filters of the first; they are designated VoxelMorph-1 and VoxelMorph-2, respectively. For brain MR registration, the regularization hyperparameter λ was set to 0.02, which was reported as an optimal value in the VoxelMorph paper. For XCAT-to-CT registration, we set .
VoxelMorph-diff (dalca2019unsupervised) (http://voxelmorph.csail.mit.edu): For the brain MR registration task, the loss function (Eqn. 25) was used, with its two weighting hyperparameters set to 0.01 and 20, respectively. For XCAT-to-CT registration, we used the loss function (Eqn. 27).
CycleMorph (kim2021cyclemorph) (https://github.com/boahK/MEDIA_CycleMorph): In CycleMorph, three hyperparameters correspond to the weights for the cycle loss, the identity loss, and the deformation field regularization. We chose the optimal cycle- and identity-loss weights reported in (kim2021cyclemorph), and the regularization weight was set empirically because the default value was found to be sub-optimal for our dataset. For XCAT-to-CT registration, we modified CycleMorph by adding a Dice loss with a weight of 1 to incorporate organ segmentations during training. We observed that the regularization value of 1 suggested in (kim2021cyclemorph) yielded over-smoothed deformation fields in our application; therefore, it was decreased to 0.1.
MIDIR (qiu2021learning) (https://github.com/qiuhuaqi/midir): The same loss function and regularization weight as VoxelMorph were used. In addition, the control point spacing for the B-spline transformation was set to 2 for both tasks, which was shown to be an optimal value in (qiu2021learning).
To show that the proposed Swin-Transformer-based network architecture outperforms other Transformer models, we compared its performance to other existing Transformer-based networks that achieved state-of-the-art performance in other applications (e.g., image segmentation, object detection, etc.). We customized these models to make them suitable for image registration. They were modified to produce 3-dimensional deformation fields that warp the given moving image. Note that the only change between the methods below and VoxelMorph is the network architecture, with the spatial transformation function, loss function, and network training procedures remaining the same. The first three models use the hybrid Transformer-ConvNet architecture (i.e., ViT-V-Net, PVT, and CoTr), while the last model uses a pure Transformer-based architecture (i.e., nnFormer). Their network hyperparameter settings are as follows:
ViT-V-Net (chen2021vit) (https://bit.ly/3bWDynR): This registration network was developed based on ViT (dosovitskiy2020image). For brain MR registration, we applied the default hyperparameter settings suggested in (chen2021vit).
PVT (wang2021pyramid) (https://github.com/whai362/PVT): The default settings were applied, except that the embedding dimensions, the number of heads, and the depth were adjusted for the registration task.
CoTr (xie2021cotr) (https://github.com/YtongXie/CoTr): We used the default network settings for both the brain MR and XCAT-to-CT registration tasks.
nnFormer (zhou2021nnformer) (https://github.com/282857341/nnFormer): Because nnFormer is also built on the Swin Transformer, we applied the same Transformer parameter values as in TransMorph to make a fair comparison.
The proposed TransMorph was implemented using PyTorch (paszke2019pytorch) on a PC with an NVIDIA TITAN RTX GPU and an NVIDIA RTX 3090 GPU. All models were trained for 500 epochs using the Adam optimization algorithm with a batch size of 1. The brain MR dataset was augmented with flipping in random directions during training, while no data augmentation was applied to the CT dataset. Restricted by the sizes of the image volumes, different Swin Transformer window sizes were used for MR brain registration and for XCAT-to-CT registration. The Transformer hyperparameter settings for TransMorph are listed in the first row of Table 2. Note that the variants of TransMorph (i.e., TransMorph-Bayes, TransMorph-bspl, and TransMorph-diff) share the same Transformer settings as TransMorph. The hyperparameter settings for each proposed variant are described as follows:
TransMorph: The identical loss function parameters as VoxelMorph were used for both the MR registration and XCAT-to-CT registration tasks.
TransMorph-Bayes: The identical loss function parameters as VoxelMorph were applied here for both tasks. The dropout probability (i.e., p_i in Eqn. 30) was set to 0.15.
TransMorph-bspl: Again, the loss function settings for both tasks were the same as those used in VoxelMorph. The control point spacing δ for the B-spline transformation was also set to 2, the same value used in MIDIR.
TransMorph-diff: We applied the same loss function parameters as those used in VoxelMorph-diff.
The affine model presented in this work comprises a compact Swin Transformer. The Transformer parameter settings were identical to TransMorph, except that the embedding dimension was set to 12 and smaller numbers of Swin Transformer blocks and attention heads were used, yielding a lightweight model. Because the MRI datasets were affinely aligned as part of the preprocessing, the affine model was only used in XCAT-to-CT registration.
In this section, we designed experiments to verify the effect of the Transformer modules in the TransMorph architecture. Specifically, we performed two additional studies on model complexity and network components, described in the following subsections.
We begin by examining the effects of several network components on registration performance. Table 1 lists the variants of TransMorph obtained by either keeping or removing the network's long skip connections or the positional embeddings in the Transformer encoder. In "TransMorph w/o conv. skip.", the long skip connections from the two convolutional layers were removed (the green arrows in Fig. 2). In "TransMorph w/o trans. skip.", the long skip connections coming from the Swin Transformer blocks were removed (the orange arrows in Fig. 2). We hypothesized in section 3.2 that the positional embedding (i.e., the term added to the token embeddings in Eqn. 8) is not a necessary element of TransMorph, because the positional information of tokens can be learned implicitly in the network via the consecutive up-sampling in the decoder and backpropagating the loss between output and target. Here, we conducted experiments to study the effectiveness of positional embeddings. In the second-to-last variant, "TransMorph w/ lrn. positional embedding", we used a learnable positional embedding in the Transformer encoder, the same one used in the Swin Transformer (liu2021swin). In the last variant, "TransMorph w/ sin. positional embedding", we substituted the learnable positional embedding with a sinusoidal positional embedding, the same embedding used in the original Transformer (vaswani2017attention), which hardcodes the positional information in the tokens.
Model | DSC | % of det(J_φ) ≤ 0
---|---|---
Affine | 0.572 ± 0.166 | -
SyN | 0.722 ± 0.127 | < 0.0001
NiftyReg | 0.716 ± 0.131 | 0.061 ± 0.093
LDDMM | 0.710 ± 0.131 | < 0.0001
deedsBCV | 0.719 ± 0.130 | 0.253 ± 0.110
VoxelMorph-1 | 0.711 ± 0.134 | 0.426 ± 0.231
VoxelMorph-2 | 0.716 ± 0.133 | 0.389 ± 0.222
VoxelMorph-diff | 0.706 ± 0.136 | < 0.0001
CycleMorph | 0.717 ± 0.134 | 0.231 ± 0.168
MIDIR | 0.704 ± 0.131 | < 0.0001
ViT-V-Net | 0.729 ± 0.128 | 0.402 ± 0.249
PVT | 0.722 ± 0.130 | 0.427 ± 0.254
CoTr | 0.718 ± 0.131 | 0.415 ± 0.258
nnFormer | 0.722 ± 0.128 | 0.399 ± 0.234
TransMorph-Bayes | 0.736 ± 0.125 | 0.389 ± 0.241
TransMorph-diff | 0.723 ± 0.128 | < 0.0001
TransMorph-bspl | 0.733 ± 0.124 | < 0.0001
TransMorph | 0.737 ± 0.125 | 0.396 ± 0.240
The impact of model complexity on registration performance was also investigated. Table 2 lists the parameter settings and the numbers of trainable parameters of four variants of the proposed TransMorph model. In the base model, TransMorph, the embedding dimension was set to 96, and the numbers of Swin Transformer blocks in the four stages of the encoder were set to 2, 2, 4, and 2, respectively. Additionally, we introduced TransMorph-tiny, TransMorph-small, and TransMorph-large, which scale the model size of TransMorph down or up. Finally, we compared our model to a customized VoxelMorph (denoted VoxelMorph-huge), which has a parameter count comparable to that of TransMorph w/ lrn. positional embedding. Specifically, we maintained the same number of layers in VoxelMorph-huge as in VoxelMorph but increased the number of convolution kernels in each layer. As a result, VoxelMorph-huge has 63.25 million trainable parameters.
Model | DSC | % of det(J_φ) ≤ 0 | SSIM
---|---|---|---
w/o registration | 0.220 ± 0.242 | - | 0.576 ± 0.071
Affine Transformer | 0.330 ± 0.291 | - | 0.751 ± 0.018
SyN | 0.494 ± 0.341 | < 0.0001 | 0.894 ± 0.021
NiftyReg | 0.482 ± 0.341 | < 0.0001 | 0.886 ± 0.027
LDDMM | 0.411 ± 0.341 | 0.006 ± 0.007 | 0.874 ± 0.031
deedsBCV | 0.568 ± 0.306 | 0.001 ± 0.001 | 0.863 ± 0.029
VoxelMorph-1 | 0.533 ± 0.312 | 0.023 ± 0.013 | 0.899 ± 0.027
VoxelMorph-2 | 0.549 ± 0.310 | 0.017 ± 0.009 | 0.910 ± 0.027
VoxelMorph-diff | 0.528 ± 0.328 | < 0.0001 | 0.911 ± 0.020
CycleMorph | 0.531 ± 0.321 | 0.033 ± 0.012 | 0.909 ± 0.024
MIDIR | 0.547 ± 0.306 | < 0.0001 | 0.896 ± 0.022
ViT-V-Net | 0.579 ± 0.315 | 0.014 ± 0.006 | 0.915 ± 0.020
PVT | 0.527 ± 0.317 | 0.029 ± 0.012 | 0.900 ± 0.027
CoTr | 0.559 ± 0.312 | 0.015 ± 0.011 | 0.905 ± 0.029
nnFormer | 0.547 ± 0.311 | 0.014 ± 0.006 | 0.902 ± 0.024
TransMorph-Bayes | 0.595 ± 0.313 | 0.015 ± 0.009 | 0.919 ± 0.024
TransMorph-diff | 0.541 ± 0.325 | < 0.0001 | 0.910 ± 0.025
TransMorph-bspl | 0.575 ± 0.308 | < 0.0001 | 0.908 ± 0.025
TransMorph | 0.598 ± 0.313 | 0.017 ± 0.008 | 0.918 ± 0.023
The left panel of Fig. 8 shows the qualitative registration results on an example brain MRI slice. The scores in blue, orange, green, and pink correspond to ventricles, third ventricle, thalami, and hippocampi, respectively. Additional qualitative comparisons across all methods are shown in Fig. 19 in Appendix. Among the proposed models, diffeomorphic variants (i.e., TransMorph-diff and TransMorph-bspl) generated smoother displacement fields, with TransMorph-bspl producing the smoothest deformations inside the brain area. On the other hand, TransMorph and TransMorph-Bayes showed better visual results with higher Dice scores for the highlighted structures.
The quantitative evaluations are shown in Table 3. Overall, the results show that the proposed method, TransMorph, achieved the highest Dice score of 0.737. Although the diffeomorphic variants produced slightly lower Dice scores than TransMorph, they still outperformed the existing registration methods and generated almost no foldings (i.e., voxels with a non-positive Jacobian determinant) in the deformation fields. By comparison, TransMorph improved the Dice score by about 0.02 relative to VoxelMorph and CycleMorph. We found that the Transformer-based models (i.e., TransMorph, ViT-V-Net, PVT, CoTr, and nnFormer) generally produced better Dice scores than the ConvNet-based models. Note that even though ViT-V-Net has almost twice the number of trainable parameters (as shown in Fig. 9), TransMorph still outperformed all the Transformer-based models (including ViT-V-Net) by at least 0.008 in Dice, demonstrating the Swin Transformer's superiority over other Transformer architectures. When we conducted hypothesis testing on the results using paired t-tests, the p-values for TransMorph over all other methods (i.e., non-TransMorph methods) indicated statistically significant improvements, with the exception of TransMorph over ViT-V-Net.
The right panel of Fig. 8 shows the qualitative results on a representative CT slice. The blue, orange, green, and pink lines denote the liver, heart, left lung, and right lung, respectively, while the bottom values show the corresponding Dice scores. Similar to the findings in the previous section, TransMorph and TransMorph-Bayes gave more precise registration results, while the diffeomorphic variants produced smoother deformations. Additional qualitative comparisons are shown in Fig. 22 in the Appendix. Certain artifacts are visible in the displacement fields created by nnFormer (as shown in Fig. 22); these are most likely caused by the patch operations of the Transformers used in its architecture. nnFormer is a near-convolution-free model (convolutional layers were employed only to form the displacement fields). In contrast to brain MRI registration, displacements in XCAT-to-CT registration may exceed the patch size. As a consequence, the lack of convolutional layers to refine the displacement field may have resulted in these artifacts. Four example coronal slices of the deformed XCAT phantoms generated by various registration methods are shown in Fig. 23 in the Appendix.
The quantitative evaluation results are presented in Table 4. They include the Dice scores for all organs and scans, the percentage of non-positive Jacobian determinants, and the structural similarity index (SSIM) (wang2004image) between the deformed XCAT phantom and the target CT scan. The window size used in SSIM was set to 7. Without registration or affine transformation, a poor Dice score of 0.22 and an SSIM of 0.576 demonstrate the vast dissimilarity between the original XCAT phantom and the patient CT scans. The Dice score and SSIM increased to 0.33 and 0.751, respectively, after aligning the XCAT and patient CT using the proposed affine Transformer. Among the traditional registration methods, deedsBCV, which was initially designed for abdominal CT registration-based segmentation (heinrich2015multi), achieved the highest Dice score of 0.568, higher even than most of the learning-based approaches. Among the learning-based approaches, Transformer-based models outperformed ConvNet-based models on average, consistent with the previous section's finding in brain MR registration. The p-values from the paired t-test between TransMorph and the other learning-based methods indicated statistically significant improvements, with the exception of a p-value of 0.158 for ViT-V-Net. The proposed TransMorph models yielded the highest Dice and SSIM scores of all methods on average, with the best Dice of 0.598 given by TransMorph and the best SSIM of 0.919 given by TransMorph-Bayes. The diffeomorphic variants produced lower Dice and SSIM scores as a consequence of not having any folded voxels in the deformation.

The left figure in Fig. 10 shows the Dice scores from the ablation study on inter-patient brain MR registration. When evaluating the effectiveness of skip connections, we found that the skip connections from both the convolutional and Transformer layers may improve registration performance. TransMorph scored a mean Dice of 0.735 after the skip connections from the convolutional layers were removed, but the score decreased to 0.719 when the skip connections from the Transformer blocks were removed. In comparison, the efficacy of the skip connections from the convolutional layers is less substantial, with a mean DSC improvement of 0.002. Note that TransMorph without positional embedding achieved the same mean DSC and a very comparable violin plot as TransMorph with it, suggesting that positional embedding may be an unnecessary component.
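SSIM with a window size of 7, as used in Table 4, can be sketched as follows. This is a simplified uniform-window version for illustration; library implementations such as skimage's `structural_similarity` additionally apply Gaussian weighting and boundary handling, so exact values will differ slightly:

```python
import numpy as np
from scipy.ndimage import uniform_filter

def ssim(x, y, win_size=7, data_range=1.0):
    """Mean SSIM over uniform win_size windows (simplified sketch)."""
    K1, K2 = 0.01, 0.03
    C1, C2 = (K1 * data_range) ** 2, (K2 * data_range) ** 2
    mu_x = uniform_filter(x, win_size)
    mu_y = uniform_filter(y, win_size)
    var_x = uniform_filter(x * x, win_size) - mu_x ** 2
    var_y = uniform_filter(y * y, win_size) - mu_y ** 2
    cov = uniform_filter(x * y, win_size) - mu_x * mu_y
    num = (2 * mu_x * mu_y + C1) * (2 * cov + C2)
    den = (mu_x ** 2 + mu_y ** 2 + C1) * (var_x + var_y + C2)
    return np.mean(num / den)

img = np.random.rand(32, 32)
print(ssim(img, img))  # identical images give an SSIM of 1.0
```

For 3-D volumes such as the deformed XCAT phantom and the target CT, the same code applies unchanged since `uniform_filter` operates over all axes.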
The figure on the right in Fig. 10 shows the Dice scores from the ablation study on XCAT-to-CT registration. Without the skip connections from the convolutional and Transformer layers, the Dice scores drop by 0.005 and 0.014, respectively, compared to TransMorph, further supporting that skip connections can improve performance. For XCAT-to-CT registration, TransMorph performed slightly worse with positional embeddings than without them, with Dice scores decreasing by 0.005 and 0.004 for the learnable and sinusoidal positional embeddings, respectively. This further suggests that positional embedding may be unnecessary for TransMorph. The effect of each component is addressed in depth in the Discussion section (section 6.1).
The circular barplot in the left panel of Fig. 11 compares the computational complexity of the deep-learning-based registration models. The plot was created using an input image with the same resolution as the brain MRI images. The numbers are expressed in giga multiply-accumulate operations (GMACs), with a higher value indicating a more computationally expensive model that may also be more memory intensive. The proposed model, TransMorph, and its Bayesian variant, TransMorph-Bayes, have a moderate computational complexity of 687 GMACs, which is much less than that of CoTr and CycleMorph. In practice, the GPU memory occupied during training was about 15 GiB with a batch size of 1. The diffeomorphic variants, TransMorph-diff and TransMorph-bspl, have 281 and 454 GMACs, respectively, comparable to the conventional ConvNet-based registration models VoxelMorph-1 and -2. In practice, they occupied approximately 11 GiB of GPU memory during training, a footprint that can be readily accommodated by the majority of modern GPUs.
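The GMAC count of a network follows from summing per-layer MACs, which depend only on the layer configuration and its output size. The sketch below counts MACs for a single hypothetical 3-D convolution; the channel counts, kernel size, and volume shape are illustrative, not TransMorph's actual configuration:

```python
# MACs of a 3-D convolution: every output voxel performs
# (C_in * k^3) multiply-accumulates for each of its C_out channels.
def conv3d_macs(c_in, c_out, k, out_shape):
    d, h, w = out_shape
    return c_in * c_out * k ** 3 * d * h * w

# hypothetical first layer: 2-channel input (moving + fixed images
# concatenated), 16 filters, 3x3x3 kernel, stride 1 on a 128^3 volume
macs = conv3d_macs(2, 16, 3, (128, 128, 128))
print(macs / 1e9)  # expressed in GMACs
```

In practice, tools that report GMACs simply walk the model graph and accumulate such per-layer counts, which is why reducing feature-map resolution early (as the patch-embedding stage of a Transformer does) lowers the total so sharply.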
Fig. 10 shows the quantitative results of the various TransMorph models and the customized ConvNet-based model VoxelMorph-huge on both inter-patient brain MRI and XCAT-to-CT registration. When parameter size is the only variable in the TransMorph models, there is a strong correlation between model complexity (as shown in the right panel of Fig. 11) and registration performance. TransMorph-tiny produced the lowest mean DSC of 0.689 and 0.497 for brain MRI and XCAT-to-CT registration, respectively, and the Dice scores steadily improve as the complexity of the model increases. Note that for brain MRI registration (left figure in Fig. 10), the improvement in mean DSC from TransMorph to TransMorph-large is just 0.002, yet the latter is almost twice as computationally costly (as shown in the right panel of Fig. 11). Furthermore, TransMorph-large performed very comparably to TransMorph on XCAT-to-CT registration, with a difference of 0.001 in Dice. The customized ConvNet-based model, VoxelMorph-huge, achieved a mean DSC close to that of TransMorph for brain MR registration but a suboptimal Dice score of 0.547 for XCAT-to-CT registration. A significant disadvantage of VoxelMorph-huge is its computational complexity: at 3656 GMACs (as seen in the right panel of Fig. 11), it is nearly five times as computationally expensive as TransMorph, making it memory-intensive (approximately 22 GiB with a batch size of 1 during training) and slow to train in practice.
As previously shown in section 5.3, skip connections may aid in enhancing registration accuracy. In this section, we give further insight into the skip connections’ functionality.
Fig. 12 shows some example feature maps in each skip connection (a full feature-map visualization is shown in Fig. 28 in the Appendix). Specifically, the left panel shows sample slices of the input volumes; the center panel illustrates selected feature maps in the skip connections of the convolutional layers; and the right panel illustrates selected feature maps in the skip connections of the Swin Transformer blocks. As seen from these feature maps, the Swin Transformer blocks provided more abstract information (right panel in Fig. 12) than the convolutional layers (middle panel in Fig. 12). Since a Transformer divides an input image volume into patches to create tokens for the self-attention operations (as described in section 3.2), it can only deliver information up to a certain resolution, which is a factor of the patch size lower than the original resolution. On the other hand, the convolutional layers produced higher-resolution feature maps with more detailed and human-readable information (e.g., edge and boundary information). Certain feature maps even revealed distinctions between the moving and fixed images (highlighted by the red boxes). Fig. 13 shows qualitative comparisons between the proposed model with and without each type of skip connection. As seen in the magnified areas, TransMorph with both skip-connection types provided a more detailed and accurate displacement field. Therefore, adding skip connections from the convolutional layers is still recommended, although the actual DSC improvements were subtle for both inter-patient brain MRI and XCAT-to-CT registration.
Transformers in computer vision were initially designed for image classification tasks (dosovitskiy2020image; liu2021swin; dong2021cswin; wang2021pyramid). Such a Transformer produces a condensed probability vector that is not in the image domain but is instead a description of the likelihood of belonging to a certain class. The loss calculated from this vector does not backpropagate any spatial information into the network. It is therefore critical to encode positional information into the patched tokens; otherwise, as the network gets deeper, the Transformer would lose track of the tokens' locations relative to the input image, resulting in unstable training and inferior predictions. However, for pixel-level tasks like image registration, the condensed features generated by Transformers are usually expanded by a decoder whose output is an image with the same resolution as the input and target images. Any spatial mismatch between the output and the target contributes to the loss, which is then backpropagated throughout the network. As a result, the Transformer implicitly learns the positional information of the tokens, obviating the need for positional embedding. In this work, we compared the registration performance of TransMorph with and without positional embedding on brain MRI and XCAT-to-CT registration. The results shown in section 5.3 indicate that positional embedding does not improve registration performance; rather, it introduces more parameters into the network. In this section, we discuss positional embeddings in further detail.
Two positional embeddings were studied in this paper: sinusoidal (vaswani2017attention) and learnable (liu2021swin) embeddings, which are the two major types of positional embedding. In sinusoidal positional embedding, the position of each patched token is represented by a value drawn from a predetermined sinusoidal signal according to the token's position relative to the input image. With learnable positional embedding, the network instead learns the representation of each token's location from the training dataset rather than being given a hardcoded value. To validate that the network learned the positional information, dosovitskiy2020image computed the cosine similarities between the learned embedding of a token and those of all other tokens. The obtained similarity values were then used to form an image. If positional information is learned, the image should show increased similarities at the token's own position and those of nearby tokens.
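The fixed sinusoidal embedding can be generated in a few lines; this is a sketch of the scheme from vaswani2017attention, with an arbitrary token count and embedding dimension:

```python
import numpy as np

def sinusoidal_embedding(n_tokens, dim):
    """Fixed sinusoidal positional embedding (Vaswani et al., 2017):
    PE[pos, 2i]   = sin(pos / 10000^(2i/dim))
    PE[pos, 2i+1] = cos(pos / 10000^(2i/dim))
    """
    pos = np.arange(n_tokens)[:, None]
    i = np.arange(0, dim, 2)[None, :]
    angles = pos / np.power(10000.0, i / dim)
    pe = np.zeros((n_tokens, dim))
    pe[:, 0::2] = np.sin(angles)  # even dimensions
    pe[:, 1::2] = np.cos(angles)  # odd dimensions
    return pe

pe = sinusoidal_embedding(96, 64)
print(pe.shape)  # (96, 64)
```

A learnable embedding, by contrast, would simply be a trainable tensor of the same shape, initialized randomly and updated by gradient descent alongside the rest of the network.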
Here, we computed the images of cosine similarities for both the sinusoidal and learnable positional embeddings used in this work; the left and right panels of Fig. 14 show the resulting images. These images were generated based on the input image and patch sizes used in this work, and each image shows the cosine similarities in the middle slice of the token grid. There should be one image per token in each panel; for better visualization, however, only a subset is shown, chosen with step sizes of 5 and 8 along the two in-plane directions. As seen in the left panel, the images of the sinusoidal embeddings exhibit a structured pattern, showing a high degree of correlation between the tokens' relative locations and the image intensity values. Note that the brightest pixel in each image represents the cosine similarity between a token's positional embedding and itself, which reflects the token's actual location relative to all other tokens; the similarity then gradually decreases with distance from the token. On the other hand, the images generated with learnable embeddings (right panel of Fig. 14) lack such structured patterns, implying that the network did not learn positional information in the learnable embeddings. However, as seen from the Dice scores in Fig. 10, regardless of which positional embedding is employed, the mean Dice scores and violin plots are quite comparable to those produced without positional embedding. This is evidence that the network learns the positional information of the tokens implicitly, and hence positional embedding is redundant and does not improve registration performance.
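This visualization takes only a few lines: normalize each token's embedding, take all pairwise dot products, and reshape each token's row of similarities back onto the token grid. A small 2-D sketch with a hypothetical grid and random "learned" embeddings:

```python
import numpy as np

def cosine_similarity_maps(pe, grid):
    """For each token, the cosine similarity between its embedding and
    those of all other tokens, reshaped to the token grid (the
    ViT-style visualization)."""
    norm = pe / np.linalg.norm(pe, axis=1, keepdims=True)
    sim = norm @ norm.T              # (n_tokens, n_tokens)
    return sim.reshape(-1, *grid)    # one grid-shaped map per token

# hypothetical 6x8 token grid, 32-dimensional random embeddings
rng = np.random.default_rng(0)
pe = rng.standard_normal((48, 32))
maps = cosine_similarity_maps(pe, (6, 8))
print(maps.shape)  # (48, 6, 8)
```

With random embeddings the maps look unstructured, which is exactly the behavior the right panel of Fig. 14 exhibits; a sinusoidal `pe` would instead produce the smooth, location-correlated patterns of the left panel.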
As previously mentioned in section 3.4, the registration uncertainty estimates produced by existing methods (e.g., (yang2017quicksilver; yang2017fast; sentker2018gdl; sedghi2019probabilistic)) were actually miscalibrated, meaning that the uncertainty values did not properly correlate with the predicted model errors, since the variance was computed using the predictive mean instead of the target image. We proposed to use the expected model error directly to express model uncertainty, since the target image is always available in image registration; the resulting uncertainty estimate is therefore perfectly calibrated. In this section, we examine how the proposed and existing methods differ in their uncertainty estimates.
To quantify the calibration error, we used the Uncertainty Calibration Error (UCE) introduced in (pmlr-v121-laves20a), which is calculated on the basis of the binned difference between the expected model error and the uncertainty estimate (Eqns. 32 and 34). We refer the interested reader to the corresponding references for further details about UCE. The plots in the left panel of Fig. 15 show the calibration plots and UCE values obtained on four representative test sets. All results are based on a sample size of 25 (the Monte Carlo sample size in Eqns. 31, 32, and 34) from 10 repeated runs. The blue lines show the results produced with the conventional variance-based estimate, and the shaded regions represent the standard deviation over the 10 runs, while the dashed black lines indicate the perfect calibration achieved with the proposed method. Notice that the variance-based uncertainty values do not match the expected model error well; in fact, they are consistently underestimated (for the reasons described in section 3.4.1). In comparison, the proposed method achieves perfect calibration with a UCE of zero, since its uncertainty estimate equals the expected model error. In the right panel of Fig. 15, we show visual comparisons of the uncertainty derived from the two estimates. Comparing either (e) to (f) or (k) to (l), the former (i.e., (e) and (k)) captured more registration failures than the latter (as highlighted by the yellow arrows), indicating a stronger correlation between deformation uncertainty and registration failures. This is further evidence that the proposed method provides perfect uncertainty calibration. More results on uncertainty estimation are shown in Fig. 26 in the Appendix.
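For illustration, UCE can be sketched as the occupancy-weighted, binned absolute difference between the mean error and the mean uncertainty. This follows the spirit of pmlr-v121-laves20a but is our simplified version, not the authors' implementation:

```python
import numpy as np

def uce(errors, uncertainties, n_bins=10):
    """Uncertainty Calibration Error (simplified sketch): bin samples by
    their uncertainty, then sum |mean error - mean uncertainty| per bin,
    weighted by the fraction of samples in that bin."""
    edges = np.linspace(uncertainties.min(), uncertainties.max(), n_bins + 1)
    total, n = 0.0, len(errors)
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (uncertainties >= lo) & (uncertainties <= hi)
        if mask.any():
            total += mask.sum() / n * abs(errors[mask].mean()
                                          - uncertainties[mask].mean())
    return total

err = np.random.rand(1000)
print(uce(err, err))   # perfectly calibrated estimates give 0.0
```

Passing the expected model error itself as the uncertainty, as the proposed method does, makes every bin's difference vanish, which is why the resulting UCE is exactly zero.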
Despite the promising results, there are some limitations to using the expected model error to estimate uncertainty. In this work, we model the expected error as the MSE of the Monte Carlo sampled registration outputs relative to the fixed image. MSE, however, is not necessarily the optimal metric for expressing the expected error. In multi-modal registration settings such as PET-to-CT or MRI-to-CT registration, the MSE is expected to be high given the vast difference in image appearance and voxel values across modalities. If MSE is employed to quantify uncertainty in these settings, the uncertainty values will be dominated by the squared bias term (in Eqn. 38), resulting in an ineffective uncertainty estimate. In such cases, the predictive variance may be a more appropriate choice for uncertainty quantification.
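The argument rests on the standard bias-variance identity for MSE: the squared-error term splits into a squared bias plus a variance, and cross-modality appearance differences inflate only the bias part. A quick numerical check of the identity, with arbitrary illustrative values:

```python
import numpy as np

# Bias-variance identity for the MSE of Monte Carlo samples y_i
# with respect to a fixed-image value f:
#   mean((y_i - f)^2) = (mean(y_i) - f)^2 + var(y_i)
rng = np.random.default_rng(1)
f = 0.5                                          # fixed-image voxel value
y = rng.normal(loc=0.8, scale=0.1, size=10_000)  # sampled predictions

mse = np.mean((y - f) ** 2)
bias_sq = (np.mean(y) - f) ** 2
var = np.var(y)
print(abs(mse - (bias_sq + var)))  # ~0: the identity holds
```

In a multi-modal setting the offset between modalities plays the role of the 0.3 gap between `f` and the sample mean here: `bias_sq` swamps `var`, so the MSE-based uncertainty stops reflecting the sampling spread.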
We demonstrate in this section that the effective receptive fields (ERFs) of Transformer-based models are larger than those of ConvNet-based models and span the whole spatial domain of an image. We used the definition of the ERF introduced in (luo2016understanding), which quantifies the amount of influence each input voxel has on the output of a neural network. In the next paragraph, we briefly discuss the computation of the ERF and refer interested readers to the reference for further information.
Assume the voxels in the input image x and the output displacement field y are indexed by (i, j, k). We employed the image size of the CT scans used in this work, with the center voxel located at (i_0, j_0, k_0). The ERF quantifies how much each input voxel x_{i,j,k} contributes to the center voxel of the displacement field, y_{i_0,j_0,k_0}. This is accomplished via the partial derivative ∂y_{i_0,j_0,k_0}/∂x_{i,j,k}, which indicates the relative relevance of x_{i,j,k} to y_{i_0,j_0,k_0}. To obtain this partial derivative, we set the error gradient to:

\[
\frac{\partial \ell}{\partial y_{i,j,k}} =
\begin{cases}
1, & (i,j,k) = (i_0, j_0, k_0) \\
0, & \text{otherwise,}
\end{cases}
\tag{35}
\]

where ℓ denotes an arbitrary loss function. This gradient is then backpropagated from y to the input x, and the resulting gradient of x_{i,j,k} is the desired partial derivative ∂y_{i_0,j_0,k_0}/∂x_{i,j,k}. This partial derivative is independent of the input and the loss function; it depends only on the network architecture and the index (i, j, k), which adequately describes the distribution of the effective receptive field.
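The procedure above can be sketched in one dimension with NumPy. Here we propagate a one-hot error gradient at the center voxel backward through a stack of linear averaging convolutions by hand (for a real network one would instead call autograd on the model); all sizes and kernels are illustrative:

```python
import numpy as np

def conv1d_same(x, k):
    """'same'-size correlation with zero padding."""
    pad = len(k) // 2
    xp = np.pad(x, pad)
    return np.array([np.dot(xp[i:i + len(k)], k) for i in range(len(x))])

def backprop_through_conv(grad, k):
    """Gradient of a 'same' correlation w.r.t. its input: correlate the
    incoming gradient with the flipped kernel."""
    return conv1d_same(grad, k[::-1])

n = 41
k = np.ones(3) / 3.0        # 3-tap averaging kernel, shared by all layers
grad = np.zeros(n)
grad[n // 2] = 1.0          # one-hot error gradient at the center voxel

erf = grad
for _ in range(5):          # backpropagate through 5 stacked conv layers
    erf = backprop_through_conv(erf, k)

support = np.flatnonzero(erf > 0)
print(support.min(), support.max())  # 15 25: an 11-voxel receptive field
```

Each 3-tap layer widens the support by only one voxel per side, which illustrates why a ConvNet's receptive field grows so slowly with depth, whereas a self-attention layer connects the center output to every input position in a single step.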
The comparison of the ERFs of VoxelMorph, CycleMorph, MIDIR, and TransMorph is shown in Fig. 16. Recall that the first three models are built on the ConvNet architecture. Due to the locality of convolution operations, their ERFs (top row in Fig. 16) are small, limited, and square-shaped. On the other hand, the ERF of the proposed TransMorph spans over the entire image. For better visualization, we then extracted the non-zero regions from the ERFs and interpolated their voxel sizes to match that of the original CT scan. The bottom row of Fig. 16 shows the delineation of the interpolated ERFs on top of a CT scan. These figures demonstrate that ConvNet-based architectures can only perceive a portion of the input image, implying that they cannot explicitly comprehend the spatial relationships between distant voxels. Thus, ConvNets may fall short of establishing accurate voxel correspondences between the moving and fixed images, which is essential for image registration. Additionally, since the receptive field of ConvNets only grows with the depth of the layers, the actual receptive fields of the first few layers of ConvNets are even smaller. The proposed TransMorph, on the other hand, sees the entire image at each level of its encoder thanks to the self-attention mechanism of the Transformer.
A full visual comparison of the ERFs of the ConvNet- and Transformer-based models is shown in Fig. 27 in Appendix.
Fig. 17 shows the validation Dice scores of the learning-based methods during training. In comparison to the other methods, the proposed TransMorph reaches a high Dice score within the first 20 epochs, showing that it learns the spatial correspondence between image pairs more quickly than the competing models. Notably, TransMorph consistently outperformed the other Transformer-based models while having a comparable number of parameters and computational complexity, implying that the Swin Transformer architecture is more effective than the other Transformers and yields a performance improvement for TransMorph. On average, Transformer-based models produce better validation scores than ConvNet-based models, with the exception of CoTr, whose validation results are volatile during training (as seen from the orange curve in Fig. 17). The performance of CoTr may be limited by its architectural design, which substitutes a Transformer for the skip connections and bottleneck of a U-shaped ConvNet. As a result, it lacks a direct flow of the features learned during the encoding stage to the layers creating the registration, making it difficult to converge.
Model | Training (min/epoch) | Inference (sec/image)
---|---|---
SyN | - | 192.140
NiftyReg | - | 30.723
LDDMM | - | 66.829
deedsBCV | - | 31.857
VoxelMorph-1 | 8.75 | 0.380
VoxelMorph-2 | 9.40 | 0.430
VoxelMorph-diff | 4.20 | 0.049
VoxelMorph-huge | 28.50 | 1.107
CycleMorph | 41.90 | 0.281
MIDIR | 4.05 | 1.627
ViT-V-Net | 9.20 | 0.197
PVT | 13.80 | 0.209
CoTr | 17.10 | 0.372
nnFormer | 6.35 | 0.105
TransMorph-Bayes | 22.60 | 7.739
TransMorph-diff | 7.35 | 0.099
TransMorph-bspl | 10.50 | 1.739
TransMorph | 14.40 | 0.329