Deep neural networks have played a significant role in medical image analysis. From UNet [ronneberger2015u] to UNETR [hatamizadeh2022unetr], the performance of neural networks on tasks such as classification, segmentation and restoration has improved considerably. Deeper and broader convolutional neural networks generally improve performance at the cost of an increase in the number of learnable parameters, model size and total floating-point operations performed during a single forward pass of the data through the network. Moreover, these models require specialised high-performance hardware even during inference. This reliance on larger models and high-performance hardware hinders the last-mile delivery of AI solutions to improve existing healthcare systems, especially in resource-constrained developing and under-developed countries.
Challenges: The performance and trustability of deep neural network-based methods are of utmost importance, especially in the medical domain. The performance of these methods decreases as we reduce the number of learnable parameters in the model; in image classification, for example, deeper networks have been shown to be superior to shallow networks with fewer parameters [he2016deep, huang2017densely]. Despite good performance in terms of quantitative evaluation metrics, deep neural networks (DNNs) are known to make the right decision for the wrong reasons [chakravartyradiologist]. This limits the trustability of DNN-based frameworks in practical applications. Additionally, the black-box nature of convolutional neural networks makes them unreliable for clinical applications. Developing a method that relies on fewer parameters and is clinically verifiable is a challenging task. An efficient model is also expected to replicate its performance during inference at a reasonable execution speed even in the absence of GPUs.
Attention-based networks were proposed to augment DNNs with explainability in the case of natural images. However, due to inherent differences in the nature of the images, we cannot assume equivalent performance on medical images. For example, in detecting objects in natural images, the objects of interest often have a well-defined shape and structure, which are absent in medical images, where the biomarkers are usually unstructured pathologies with variable appearance. In this work, we verify the effectiveness of replacing convolutions with attention in neural networks for medical images.
Related works: Transformers [vaswani2017attention], based solely on attention mechanisms, have revolutionised the way models are designed for natural language tasks. Motivated by their success, [zhu2020deformable], [ye2019cross], [ramachandran2019stand] and [yang2020learning] explored the possibility of using self-attention to solve various vision tasks. Among these, the stand-alone self-attention proposed by [ramachandran2019stand] established that self-attention could potentially replace convolutional layers altogether. Even though such models are efficient compared to other DNNs, they can be further improved by quantising the weights and activations of the networks [paupamah2020quantisation]. The quantisation of deep neural networks has shown significant progress in recent years [xu2021mixed, aji2020compressing]. The ability to quantise a neural network trained in high precision without substantial loss in performance during inference simplifies the process.
Our Approach: Inspired by the success of [ramachandran2019stand] on natural image classification tasks, we propose a new class of networks for medical image classification and segmentation in which the convolution layers are replaced with self-attention layers. Furthermore, we optimise the networks for inference by quantising the parameters, thereby decreasing energy consumption. To the best of our knowledge, a quantised fully self-attentive network for the classification and segmentation of medical images, compared against its convolutional counterparts, has not been attempted so far. A schematic overview of the proposed method is illustrated in Fig. 1.
2.1 Stand-alone self-attention
Attention was introduced by [bahdanau2014neural] for a neural machine translation model. Attention modules can learn to focus on essential regions within a context, making them an important component of neural networks. Self-attention [vaswani2017attention] is attention applied within a single context instead of across multiple contexts; that is, the keys, queries and values are derived from the same context. [ramachandran2019stand] introduced the stand-alone self-attention layer, which can replace convolutions to construct a fully attentional model. Motivated by the initial success of [ramachandran2019stand] on natural images, we explore the feasibility of using such modules in the proposed class of networks for medical image analysis.
To compute attention for a pixel $x_{ij}$ in an image or an activation map, a local region with spatial extent $k$ around $x_{ij}$ is used to derive the keys and values. Learned linear transformations are performed on $x_{ij}$ and its local region to obtain the query ($q_{ij}$), keys ($k_{ab}$) and values ($v_{ab}$) as
$$q_{ij} = W_Q\, x_{ij}, \qquad k_{ab} = W_K\, x_{ab}, \qquad v_{ab} = W_V\, x_{ab},$$
where $W_Q$, $W_K$ and $W_V$ are learnable transformation matrices and $(a, b) \in \mathcal{N}_k(i, j)$, the local region centred at $(i, j)$.
Self-attention on its own does not encode any positional information, which makes it permutation equivariant. Relative positional embeddings [shaw2018self], as used in [ramachandran2019stand], are incorporated into the attention module. Each key $k_{ab}$ is split depth-wise into two halves; the row-offset embedding $r_{a-i}$ and the column-offset embedding $r_{b-j}$ are added to these separately. The halves are then concatenated to obtain a new key that carries the relative spatial information of the pixels in the local region of size $k \times k$. Thus, the relative spatial attention for a pixel is defined as
$$y_{ij} = \sum_{(a,b) \in \mathcal{N}_k(i,j)} \operatorname{softmax}_{ab}\!\left(q_{ij}^{\top} k_{ab} + q_{ij}^{\top} r_{a-i,\,b-j}\right) v_{ab},$$
where $\mathcal{N}_k(i, j)$ is the neighbourhood of size $k \times k$ centred at $(i, j)$; the computation is graphically illustrated in Figure 2.
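As an illustration, a simplified single-head version of such a layer can be sketched in PyTorch as follows. This is a hypothetical sketch under stated assumptions (single head, learnable row/column offset embeddings split across channel halves), not the exact layer used in our networks; all class and variable names are our own:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LocalSelfAttention2d(nn.Module):
    """Simplified single-head stand-alone self-attention over k x k windows."""

    def __init__(self, in_ch, out_ch, k=7):
        super().__init__()
        assert k % 2 == 1 and out_ch % 2 == 0
        self.k = k
        # 1x1 convolutions implement the learned linear maps W_Q, W_K, W_V
        self.q_proj = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.k_proj = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.v_proj = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        # relative positional embeddings: half the channels carry row offsets,
        # half carry column offsets; concatenated they form r_{a-i, b-j}
        self.rel_rows = nn.Parameter(torch.randn(out_ch // 2, k, 1))
        self.rel_cols = nn.Parameter(torch.randn(out_ch // 2, 1, k))

    def forward(self, x):
        B, _, H, W = x.shape
        q = self.q_proj(x)                                  # (B, C, H, W)
        keys, vals = self.k_proj(x), self.v_proj(x)
        pad = self.k // 2
        # unfold k x k neighbourhoods around every pixel
        k_unf = F.unfold(keys, self.k, padding=pad).view(B, -1, self.k**2, H * W)
        v_unf = F.unfold(vals, self.k, padding=pad).view(B, -1, self.k**2, H * W)
        rel = torch.cat([self.rel_rows.expand(-1, self.k, self.k),
                         self.rel_cols.expand(-1, self.k, self.k)], dim=0)
        k_unf = k_unf + rel.reshape(1, -1, self.k**2, 1)    # add r_{a-i, b-j}
        q = q.view(B, -1, 1, H * W)
        # softmax over the k^2 window positions, then the weighted sum of values
        attn = (q * k_unf).sum(dim=1, keepdim=True).softmax(dim=2)
        out = (attn * v_unf).sum(dim=2)                     # (B, C, H*W)
        return out.view(B, -1, H, W)
```

The layer is a drop-in replacement for a same-padding 2D convolution: the spatial size is preserved and only the channel count changes.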
We use these attention blocks instead of 2D convolutional blocks in our networks. During training, all the weights and activations are represented and stored with a precision of FP32. The parameters are quantised to INT8 precision for inference.
2.2 Quantisation of network parameters
We perform quantisation using the FBGEMM (FaceBook GEneral Matrix Multiplication) [fbgemm] backend of PyTorch for x86 CPUs, which is based on the quantisation scheme proposed by [jacob2018quantization]. In order to perform all arithmetic on quantised values using integer-only operations, we require the quantisation scheme to be an affine mapping of integers $q$ to real numbers $r$ as
$$r = S(q - Z),$$
where $S$ (scale) and $Z$ (zero-point) are the quantisation parameters. We have employed post-training 8-bit quantisation of all the weights and operations for our proposed models.
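The affine mapping can be illustrated with a minimal NumPy sketch. This is our own toy implementation of the scheme for unsigned 8-bit values, not the FBGEMM kernels themselves; the scale/zero-point choice below is one common convention:

```python
import numpy as np

def quantise(r, scale, zero_point):
    """Map real values r to uint8 codes: q = clamp(round(r/S) + Z, 0, 255)."""
    q = np.round(r / scale) + zero_point
    return np.clip(q, 0, 255).astype(np.uint8)

def dequantise(q, scale, zero_point):
    """Recover approximate reals: r ~= S * (q - Z)."""
    return scale * (q.astype(np.float32) - zero_point)

# choose S and Z so the 256 integer codes span the observed real range
r = np.array([-1.0, -0.5, 0.0, 0.7, 1.0], dtype=np.float32)
scale = float(r.max() - r.min()) / 255.0
zero_point = np.round(-r.min() / scale)
r_hat = dequantise(quantise(r, scale, zero_point), scale, zero_point)
```

The round trip loses at most half a quantisation step per value, which is why a network trained in FP32 can usually be quantised post hoc with little loss in accuracy.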
2.3 Network architecture
Classification: The network consists of a series of alternating attention blocks and attention down blocks followed by fully-connected linear layers. The feature maps are downsampled using the max-pooling operation. The size of the output linear layer is equal to the number of target classes. The network is trained to perform multi-label classification using a binary cross-entropy loss.
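The multi-label objective can be realised with one independent sigmoid per class, so that several disease labels may be active at once. A generic sketch (the class count of 14 matches the thorax-disease labels used later; the tensor shapes are illustrative):

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()              # one independent sigmoid per class

logits = torch.randn(4, 14)                     # batch of 4 images, 14 classes
targets = torch.randint(0, 2, (4, 14)).float()  # labels may co-occur
loss = criterion(logits, targets)               # scalar training loss
```

Unlike a softmax cross-entropy, this loss does not force the class probabilities to sum to one, which is what permits co-occurring labels.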
Segmentation: The proposed segmentation network has a fully attention-based encoder-decoder architecture as shown in Fig. 3(b)
The encoder unit consists of stand-alone self-attention blocks with ReLU activation and max-pooling operations, with the number of feature maps increasing progressively with each attention block. The decoder consists of attention blocks and max-unpooling operations. The size of the activation maps in the decoder matches that of the corresponding layers in the encoder. The unpooling operations are performed using the indices transferred from the pooling layers in the encoder. To prevent the loss of subtle information, we employ activation concatenation in the decoder, similar to UNet [ronneberger2015u]. The network is trained using a soft Dice loss.
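For the binary case, the soft Dice loss can be written roughly as follows. This is a minimal sketch of the standard formulation; the smoothing constant `eps` is our own choice:

```python
import torch

def soft_dice_loss(logits, target, eps=1e-6):
    """Soft Dice loss for binary segmentation.

    logits: raw network output of shape (B, 1, H, W);
    target: binary ground-truth mask of the same shape.
    """
    probs = torch.sigmoid(logits)
    inter = (probs * target).sum(dim=(1, 2, 3))
    denom = probs.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3))
    # 1 - DSC, averaged over the batch; eps avoids division by zero
    return (1.0 - (2.0 * inter + eps) / (denom + eps)).mean()
```

Because the loss is computed on sigmoid probabilities rather than hard masks, it stays differentiable and directly optimises the overlap metric (DSC) used for evaluation.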
Classification: To evaluate the performance of the fully self-attentive network (SaDNN-cls) on classification tasks, we have used the NIH Chest X-ray dataset of Common Thorax Disease [wang2017chestx]. The dataset comprises frontal-view X-ray images of patients with fourteen disease labels. These disease classes can co-occur in an image; therefore, the classification problem is formulated as multi-label classification. The train, validation and test split provided in the dataset was used for the experiments.
Segmentation: A subset of the Medical Segmentation Decathlon dataset [antonelli2021medical] is used to evaluate the performance of the proposed fully-attentive network (SaDNN-seg) for liver segmentation. Of the 3D CT volume and ground-truth pairs available in the dataset, a randomly chosen portion was used for training and the remaining portion for testing.
3.2 Implementation Details
Training: The proposed models were trained using the Adam optimiser [kingma2014adam] with a learning rate of . The models for the classification task were trained for epochs and the models for segmentation were trained for epochs.
Baselines: The performance of the proposed quantised self-attention network on the classification task is compared with ResNet-18, ResNet-50 and their 8-bit quantised versions, q-ResNet-18 and q-ResNet-50. To assess the performance of the segmentation network, we chose a modified UNet [ronneberger2015u] (UNet-small) and the SUMNet [nandamuri2019sumnet] architecture trained on the same dataset split, along with their quantised versions q-UNet-small and q-SUMNet, as baselines.
System specifications: All networks were trained on a high-performance server with an NVIDIA GPU, Intel(R) Xeon(R) Silver CPU @ , GB RAM and TB HDD running Ubuntu LTS. The inference of the quantised models was also performed on the same class of CPUs.
4 Results and Discussions
4.1 Qualitative Analysis
Visualisations of the predictions of the proposed q-SaDNN-seg network and its unquantised version SaDNN-seg are presented in Fig. 4. Over-segmented regions in the predicted segmentation maps are marked in green, under-segmented regions in red and correctly segmented regions in white. We observe that the tendency of the original unquantised network SaDNN-seg to over-segment is significantly reduced post quantisation. However, the quantisation of network parameters causes q-SaDNN-seg to under-segment the target organ. This is reflected in the slightly lower Dice similarity coefficient (DSC) of the proposed model, as seen in Table 6.
4.2 Quantitative Analysis
The performance of the proposed quantised fully self-attentive network and the baselines on the multi-label classification task is reported in terms of accuracy in Table 6. It can be observed that the proposed network achieves slightly better performance than the existing deep residual convolutional neural networks. Table 6 also compares the proposed segmentation network with the baselines in terms of DSC; the proposed quantised network performs nearly as well as the quantised versions of the baseline convolutional neural networks.
4.3 Computational Analysis
The DNNs used for the experiments exhibit superior classification and segmentation performance in terms of quantitative metrics, but they require a considerable number of computation and memory-access operations. Deploying a framework that needs excessive computation results in large energy consumption, which is not feasible in diverse resource-constrained scenarios. Therefore, it is key to have an energy-efficient model without degradation in performance. A rough estimate of the energy cost per operation in IC design can be obtained from Table 1, compiled from [horowitz20141, wu2018training, park2020cenna].
|Precision||Multiply||Add|
|8-bit INT||0.2 pJ||0.03 pJ|
|16-bit FP||1.1 pJ||0.40 pJ|
|32-bit FP||3.7 pJ||0.90 pJ|
The number of multiplication and addition operations in a stand-alone self-attention layer [vaswani2021scaling] can be estimated as
$$\#\text{mult-adds} \approx H \cdot W \left(3C^2 + 2k^2C\right),$$
where $H \times W$ is the spatial size of the feature map, $k$ is the block (local region) size and $C$ is the number of channels.
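Combining such an operation count with the per-operation energies in Table 1 gives a back-of-the-envelope energy estimate. The helper below is our own illustration, assuming a per-pixel count of 3C² multiplies for the query/key/value projections plus 2k²C for the attention logits and the weighted sum of values:

```python
# Energy per (multiply, add) in picojoules, taken from Table 1
ENERGY_PJ = {"int8": (0.2, 0.03), "fp16": (1.1, 0.40), "fp32": (3.7, 0.90)}

def attention_mult_adds(h, w, c, k):
    """Approximate multiply-accumulate count of one single-head stand-alone
    self-attention layer on an h x w x c feature map with a k x k region."""
    return h * w * (3 * c * c + 2 * k * k * c)

def energy_joules(macs, precision):
    """Energy of `macs` multiply-add pairs at the given precision."""
    mult_pj, add_pj = ENERGY_PJ[precision]
    return macs * (mult_pj + add_pj) * 1e-12
```

For example, moving from FP32 to INT8 reduces the estimated per-MAC energy from 4.6 pJ to 0.23 pJ, a twenty-fold saving, which is consistent with the gap between the full-precision and quantised rows of Tables 2 and 3.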
The total number of parameters, MACs, energy consumed during a forward pass and model size of the proposed q-SaDNN-cls and q-SaDNN-seg networks are reported in Table 2 and Table 3, with graphical comparisons in Fig. 7. Models with the least area in the radar charts are the most efficient. The proposed q-SaDNN-cls network is roughly 2.4× smaller than quantised ResNet-18 and 5.2× smaller than quantised ResNet-50 in terms of model size. In terms of total MAC units, the proposed networks have about 66 per cent fewer MACs than ResNet-18 and 85 per cent fewer than ResNet-50. Similarly, in terms of total trainable parameters, the proposed networks have about 59 per cent fewer parameters than ResNet-18 and 81 per cent fewer than ResNet-50.
|Model||Parameters||MACs||Model size||Energy|
|ResNet-18||11.17 M||9.10 G||44.79 MB||20.93 J|
|q-ResNet-18||11.17 M||9.10 G||11.40 MB||1.04 J|
|ResNet-50||23.53 M||21.11 G||94.45 MB||48.53 J|
|q-ResNet-50||23.53 M||21.11 G||24.52 MB||2.41 J|
|SaDNN-cls||4.56 M||3.10 G||18.30 MB||7.13 J|
|q-SaDNN-cls||4.56 M||3.10 G||4.72 MB||0.35 J|
Similar improvements in computational efficiency can be observed in the case of segmentation as well. The segmentation network q-SaDNN-seg is roughly 3.7× smaller than q-UNet-small and 2.9× smaller than q-SUMNet in terms of model size. In terms of total MAC units, q-SaDNN-seg has about 35 per cent fewer MACs than SUMNet. In terms of trainable parameters, q-SaDNN-seg has about 74 per cent fewer parameters than UNet-small and 66 per cent fewer than SUMNet. It is to be noted that the proposed model also consumes considerably less energy than SUMNet, both before and after quantisation.
|Model||Parameters||MACs||Model size||Energy|
|UNet-small||31.03 M||218.60 G||118.48 MB||502.78 J|
|q-UNet-small||31.03 M||218.60 G||29.77 MB||25.13 J|
|SUMNet||23.53 M||425.98 G||91.07 MB||979.75 J|
|q-SUMNet||23.53 M||425.98 G||22.88 MB||48.97 J|
|SaDNN-seg||7.95 M||277.15 G||30.47 MB||637 J|
|q-SaDNN-seg||7.95 M||277.15 G||8.02 MB||31.87 J|
4.4 Analysis of clinical relevance
Validating the results of the model against clinically relevant information, so as to provide some explanation for the decisions made, is an important factor in determining trustability. The clinically relevant region provided in the NIH Chest X-ray dataset, as marked by a radiologist, and the saliency-map-based explanation generated using RISE [petsiuk2018rise] for the proposed quantised self-attention network for classification are shown in Fig. 8. It can be observed that the proposed model focuses on the clinically relevant region while making its decision.
We proposed a class of quantised self-attentive neural networks for medical image classification and segmentation. In these networks, convolutional layers are replaced with attention layers, which have fewer learnable parameters. Computing attention over a small local region surrounding each pixel prevents degradation of performance despite the absence of the local feature extraction typically performed in a CNN. We show that our energy-efficient method achieves performance on par with commonly used CNNs, with fewer parameters and smaller model size. These attributes make our proposed models affordable and easy to adopt in resource-constrained settings.