Clinically, disease diagnosis is usually based on critical biomarkers derived from image analysis. For example, on fundus images, the vertical Cup-to-Disc Ratio (vCDR) computed from the optic disc/cup (OD/OC) masks is one of the most important clinical parameters for glaucoma diagnosis. In melanoma diagnosis, an unusual shape of the skin lesion is a major biomarker indicating melanoma. To derive these biomarkers, an essential step is to identify the lesions or tissues in an image and segment these regions of interest from the rest of the image [18, 9, 34].
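As a concrete example of such a biomarker, the sketch below computes vCDR from binary OC and OD masks as the ratio of their vertical extents. It is a minimal illustration under our own assumptions (masks as nested 0/1 lists, diameter measured as the row span of the foreground); the function names are hypothetical and the code is not part of SeATrans.

```python
from typing import Sequence

# A binary mask is a 2-D grid of 0/1 values (rows run along the image height).
Mask = Sequence[Sequence[int]]


def vertical_diameter(mask: Mask) -> int:
    """Vertical extent of the foreground: last foreground row - first foreground row + 1."""
    rows = [i for i, row in enumerate(mask) if any(row)]
    if not rows:
        return 0
    return rows[-1] - rows[0] + 1


def vertical_cdr(cup_mask: Mask, disc_mask: Mask) -> float:
    """vCDR = vertical cup diameter / vertical disc diameter."""
    disc_height = vertical_diameter(disc_mask)
    if disc_height == 0:
        raise ValueError("optic disc mask is empty")
    return vertical_diameter(cup_mask) / disc_height


# Toy 10x10 masks: the disc spans rows 1-8, the cup spans rows 3-6.
disc = [[1 if 1 <= r <= 8 else 0 for _ in range(10)] for r in range(10)]
cup = [[1 if 3 <= r <= 6 and 3 <= c <= 6 else 0 for c in range(10)] for r in range(10)]
print(vertical_cdr(cup, disc))  # 0.5
```

In practice the masks would come from a segmentation network rather than being drawn by hand.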
Motivated by this observation, methods have been proposed that use segmentation information to facilitate automated disease diagnosis [10, 36, 31, 19, 4, 30, 33]. Common practices include region of interest (ROI) extraction [10, 4], input concatenation, channel attention [19, 36], and transfer learning. These methods have two main limitations. First, methods proposed for specific medical tasks are not general enough: they are often inapplicable to, or perform unsatisfactorily on, other medical tasks. Second, most methods simply assume that the segmentation and diagnosis features are regionally correlated, an assumption that does not hold in most cases. The techniques they apply, such as convolution layers and channel attention, are largely local and therefore struggle to model this non-regional feature interaction. With the rise of the vision transformer, this gap can potentially be addressed by its global and dynamic nature.
In this paper, we propose a novel transformer-based model that better captures the interaction between segmentation and diagnosis features. To address the scale-level discrepancy between segmentation and diagnosis features, we propose asymmetric multi-scale interaction, which correlates multi-scale segmentation features with each single low-level diagnosis feature. A one-to-one coarse interaction and a one-against-rest fine-grained interaction are fused to produce the final feature. To model the segmentation-diagnosis interaction itself, we propose SeA-block, which is built as an encoder-decoder pair. The encoder first embeds the diagnosis feature through the calculated segmentation affinity map; the decoder then maps the embeddings back to the diagnosis feature space through the calculated diagnosis affinity map. Through SeA-block, diagnosis features are vitalized by the correlated segmentation information.
In brief, we make three major contributions. First, we propose a general segmentation-assisted diagnosis model, named SeATrans, that integrates segmentation and diagnosis for medical images. Thanks to the global and dynamic nature of the transformer mechanism, SeATrans achieves superior and robust performance compared with state-of-the-art counterparts. Second, we propose asymmetric multi-scale interaction to correlate each low-level diagnosis feature with multi-scale segmentation features, so that the diagnosis feature is vitalized by both coarse and fine-grained segmentation information. Last but not least, we propose a new strategy, SeA-block, for the segmentation-diagnosis interaction: a transformer-based encoder-decoder architecture that learns across the segmentation and diagnosis feature spaces. Experimental results show that SeATrans outperforms the previous best method by at least 2% AUC on three different disease diagnosis tasks, while remaining competitively robust to domain shift in the segmentation model.
In this paper, we propose a general segmentation-assisted diagnosis framework. Given a raw image and the lesion/tissue segmentation features extracted from it by a segmentation network (jointly trained or pre-trained), our goal is to predict the disease label of the image (0 for benign, 1 for malignant). Our basic idea is to integrate the segmentation information into the diagnosis model at the feature level. The interaction module and the diagnosis model are jointly optimized to predict the correct diagnosis. An illustration of the overall workflow is shown in Fig. 1 (a). The raw fundus image is first fed into a UNet to obtain the deep segmentation embedding. The segmentation features in the UNet decoder then interact with the diagnosis features of a disease diagnosis network. In the diagnosis model, convolution layers and SeA-block-based interactions alternately abstract and vitalize the features. The final disease probability is supervised by the binary disease label through the binary cross-entropy (BCE) loss.
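For reference, the BCE supervision can be written out explicitly. The sketch below is a plain-Python illustration of the standard loss on predicted disease probabilities; the clamping constant `eps` is our own assumption for numerical safety, and in a PyTorch implementation one would normally use `torch.nn.BCELoss` (or `torch.nn.BCEWithLogitsLoss` on logits) instead.

```python
import math
from typing import Sequence


def binary_cross_entropy(probs: Sequence[float], labels: Sequence[int], eps: float = 1e-7) -> float:
    """Mean BCE between predicted disease probabilities and 0/1 labels (0 = benign, 1 = malignant)."""
    if len(probs) != len(labels):
        raise ValueError("probs and labels must have the same length")
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(probs)


# An uninformative prediction of 0.5 costs ln(2) per sample.
print(round(binary_cross_entropy([0.5, 0.5], [1, 0]), 4))  # 0.6931
```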
2.1 Asymmetric Multi-scale Interaction
Note that the diagnosis network abstracts low-level structural features into deep semantic features, while the segmentation model abstracts multi-scale structural features. To align the diagnosis and segmentation features, we correlate multi-scale segmentation features with each single low-level diagnosis feature. As shown in Fig. 1 (b), stacked multi-scale segmentation features are collected for a single low-level diagnosis feature. The segmentation feature with the largest scale first interacts with the target diagnosis feature. Since the large-scale feature contains more specific structural information but also more artifacts, this one-to-one interaction produces a coarse vitalized diagnosis feature. The other segmentation features, with smaller scales, are fused together for a second interaction with the diagnosis feature. Since these features are more fine-grained and abstract, this interaction produces a fine-grained vitalized diagnosis feature. The coarse and fine-grained features are fused by a convolution layer to produce the final result.
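The asymmetric split can be illustrated with simple shape bookkeeping. The sketch below is hypothetical (the function name and the example shapes are ours): it assumes features are given as (channels, height, width), sends the largest-scale segmentation feature to the one-to-one coarse interaction, and assumes every remaining feature is upsampled to the diagnosis feature's spatial size by pixel shuffle before channel-wise concatenation for the one-against-rest fine interaction.

```python
from typing import List, Tuple

Shape = Tuple[int, int, int]  # (channels, height, width)


def plan_asymmetric_interaction(
    seg_shapes: List[Shape], diag_shape: Shape
) -> Tuple[Shape, Shape]:
    """Split stacked segmentation features into a coarse input and a fused fine input.

    Returns (coarse_shape, fine_shape). The largest-scale feature is used one-to-one;
    every other feature is assumed to be upsampled to the diagnosis feature's spatial
    size by pixel shuffle, (C, H/r, W/r) -> (C / r^2, H, W), and then concatenated.
    """
    _, h, w = diag_shape
    # The largest spatial scale goes to the one-to-one (coarse) interaction.
    ordered = sorted(seg_shapes, key=lambda s: s[1] * s[2], reverse=True)
    coarse, rest = ordered[0], ordered[1:]
    fine_channels = 0
    for c, fh, fw in rest:
        r = h // fh  # upsampling factor needed to reach the diagnosis scale
        if r < 1 or r * fh != h or r * fw != w or c % (r * r) != 0:
            raise ValueError(f"cannot pixel-shuffle {(c, fh, fw)} to {(h, w)}")
        fine_channels += c // (r * r)
    # The one-against-rest (fine) interaction sees the channel-wise concatenation.
    return coarse, (fine_channels, h, w)


print(plan_asymmetric_interaction([(64, 64, 64), (128, 32, 32), (256, 16, 16)], (256, 64, 64)))
# ((64, 64, 64), (48, 64, 64))
```

Note that pixel shuffle trades channels for resolution, so the fused fine input has fewer channels than the raw smaller-scale features.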
In practice, the second and third layers of the diagnosis model interact with the multi-scale features of the UNet decoder. The diagnosis feature of each layer is characterized by its height, width, down-sampling rate and number of channels. To instill segmentation information into the diagnosis feature of a given layer, the stacked multi-scale segmentation features from the UNet decoder layers are collected for the interaction. First, the largest-scale segmentation feature interacts with the diagnosis feature through a SeA-block for coarse vitalization. The remaining segmentation features are then rearranged by pixel shuffle to the scale of the diagnosis feature and concatenated, and the result interacts with the diagnosis feature for fine-grained vitalization. The fine-grained and coarse features are integrated by a convolution kernel to obtain the final vitalized diagnosis feature. Finally, a residual convolution block with a pooling layer abstracts the next-layer diagnosis feature.
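The pixel-shuffle rearrangement and the subsequent concatenation can be sketched in plain Python. The index layout below follows the documented behaviour of `torch.nn.PixelShuffle` (channels are folded into r×r spatial blocks); the helper names are hypothetical, and a PyTorch implementation would simply use `torch.nn.PixelShuffle(r)` followed by `torch.cat(..., dim=1)`.

```python
from typing import List

FeatureMap = List[List[List[float]]]  # (channels, height, width)


def pixel_shuffle(x: FeatureMap, r: int) -> FeatureMap:
    """(C*r^2, H, W) -> (C, H*r, W*r) with the index layout of torch.nn.PixelShuffle:
    out[c][y*r + i][z*r + j] = x[c*r*r + i*r + j][y][z]."""
    c_in, h, w = len(x), len(x[0]), len(x[0][0])
    if c_in % (r * r) != 0:
        raise ValueError("channel count must be divisible by r^2")
    c_out = c_in // (r * r)
    out = [[[0.0] * (w * r) for _ in range(h * r)] for _ in range(c_out)]
    for c in range(c_out):
        for i in range(r):
            for j in range(r):
                src = x[c * r * r + i * r + j]
                for y in range(h):
                    for z in range(w):
                        out[c][y * r + i][z * r + j] = src[y][z]
    return out


def concat_channels(*features: FeatureMap) -> FeatureMap:
    """Channel-wise concatenation of feature maps with identical height and width."""
    if len({(len(f[0]), len(f[0][0])) for f in features}) != 1:
        raise ValueError("all features must share the same spatial size")
    return [channel for f in features for channel in f]


# Four 1x1 channels become one 2x2 channel.
print(pixel_shuffle([[[1.0]], [[2.0]], [[3.0]], [[4.0]]], 2))  # [[[1.0, 2.0], [3.0, 4.0]]]
```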
SeA-block is adopted for the segmentation-diagnosis feature interaction; its architecture is shown in Fig. 1 (c). SeA-block contains an encoder and a decoder. The encoder embeds the diagnosis feature according to its affinity with the segmentation feature, which is implemented with the multi-head dot-product attention mechanism (MHA). Formally, to encode a diagnosis feature with a segmentation feature, we use the segmentation feature as the query and the diagnosis feature as the key and value of the attention, with positional encodings added for the segmentation feature and the diagnosis feature, respectively. All features are reshaped into sequences of flattened patches following ViT. In this attention mechanism, normalized affinity weights are first calculated between the query and the key to reflect the global correlation between the diagnosis and segmentation features. The affinity weights are then used to select and reinforce the diagnosis feature through a weighted sum over the values. After the attention, Layer Normalization with a residual connection is applied before and after the MLP layer. The embedded diagnosis feature is output with the same shape as the inputs.
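The encoder attention can be sketched as single-head scaled dot-product attention. This is a simplified illustration, not the authors' code: it assumes the segmentation tokens act as queries and the diagnosis tokens as keys and values (our reading of the description above), adds positional encodings to the query and key only (a DETR-style assumption), and omits the multiple heads, Layer Normalization and MLP. All names are hypothetical.

```python
import math
from typing import List

Seq = List[List[float]]  # a sequence of tokens, each a d-dimensional vector


def add_tokens(a: Seq, b: Seq) -> Seq:
    """Element-wise sum of two token sequences (used to add positional encodings)."""
    return [[u + v for u, v in zip(ra, rb)] for ra, rb in zip(a, b)]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


def attention(q: Seq, k: Seq, v: Seq) -> Seq:
    """Single-head scaled dot-product attention: softmax(q k^T / sqrt(d)) v."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        weights = softmax(scores)  # normalized affinity of this query to every key
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


def sea_encoder_attention(seg: Seq, diag: Seq, pos_seg: Seq, pos_diag: Seq) -> Seq:
    """Segmentation tokens query the diagnosis tokens; the values are the diagnosis tokens."""
    return attention(add_tokens(seg, pos_seg), add_tokens(diag, pos_diag), diag)
```

Because the softmax weights of each query sum to one, every output token is a convex combination of diagnosis tokens, i.e. the diagnosis feature is re-weighted by its affinity to the segmentation feature.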
A decoder is connected after the encoder to map the embedding back to the diagnosis feature space. The decoder takes two inputs: the diagnosis embedding and the original diagnosis feature. Symmetrically to the encoder, the decoder is implemented by multi-head attention with the diagnosis feature as the query and the diagnosis embedding as the key and value, with positional encodings added for the diagnosis feature and the diagnosis embedding, respectively. The decoder transfers the embedding into a diagnosis feature by enhancing its affinity with the original diagnosis feature. A self-attention block is connected after the decoder to refine the representations, and the resulting sequence is reshaped back into a vitalized diagnosis feature with the same shape as the original diagnosis feature.
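The decoder step, including the flatten-and-reshape round trip, can be sketched as follows. This is a simplified, hypothetical illustration: tokens are taken per pixel (patch size 1) rather than ViT-style larger patches, a single attention head is used, positional encodings are added to query and key only, and the final self-attention refinement and normalization layers are omitted.

```python
import math
from typing import List

Seq = List[List[float]]                # token sequence, each token a C-dim vector
FeatureMap = List[List[List[float]]]   # (channels, height, width)


def to_tokens(x: FeatureMap) -> Seq:
    """Flatten (C, H, W) into H*W tokens of dimension C (patch size 1, row-major)."""
    c, h, w = len(x), len(x[0]), len(x[0][0])
    return [[x[k][i][j] for k in range(c)] for i in range(h) for j in range(w)]


def from_tokens(tokens: Seq, h: int, w: int) -> FeatureMap:
    """Inverse of to_tokens: reshape H*W tokens back to (C, H, W)."""
    c = len(tokens[0])
    return [[[tokens[i * w + j][k] for j in range(w)] for i in range(h)] for k in range(c)]


def attention(q: Seq, k: Seq, v: Seq) -> Seq:
    """Single-head scaled dot-product attention: softmax(q k^T / sqrt(d)) v."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        out.append([sum(e / z * vj[ch] for e, vj in zip(exps, v)) for ch in range(len(v[0]))])
    return out


def sea_decoder(diag: FeatureMap, emb: Seq, pos_diag: Seq, pos_emb: Seq) -> FeatureMap:
    """Diagnosis tokens query the embedding; the output is reshaped to the diagnosis shape."""
    h, w = len(diag[0]), len(diag[0][0])
    q = [[a + b for a, b in zip(t, p)] for t, p in zip(to_tokens(diag), pos_diag)]
    k = [[a + b for a, b in zip(t, p)] for t, p in zip(emb, pos_emb)]
    return from_tokens(attention(q, k, emb), h, w)
```

Since the queries come from the diagnosis feature, the output has one token per diagnosis position and can be reshaped back to the diagnosis feature's shape.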
3.1 Diagnosis Tasks
We evaluate SeATrans on three disease diagnosis tasks: glaucoma diagnosis, thyroid cancer diagnosis and melanoma diagnosis. Glaucoma is predicted from fundus images, assisted by OD/OC segmentation. Thyroid cancer is predicted from ultrasound images, assisted by thyroid nodule segmentation. Melanoma is predicted from dermoscopic images, assisted by skin lesion segmentation. The glaucoma, thyroid cancer and melanoma experiments are conducted on the REFUGE-2 dataset, the TNMIX dataset [13, 27] and the ISIC dataset, which contain 1200, 8046 and 1600 samples, respectively. All datasets are publicly available with both segmentation and diagnosis labels. Train/validation/test splits follow the default settings of each dataset.
3.2 Experimental Settings
as the diagnosis model. The segmentation network is pre-trained on a heterologous data distribution. All experiments are implemented on the PyTorch platform and trained/tested on four Tesla P40 GPUs with 24 GB of memory. All images are uniformly resized to 256 × 256 pixels. The networks are trained end-to-end using the Adam optimizer with a mini-batch size of 16 for 80 epochs. The learning rate is initially set to 1. The detailed configurations can be found in the code.
To verify the effectiveness of SeATrans, we compare it with several baselines. The vanilla baseline is a standard classification model implemented with ResNet50, with no segmentation mask provided. Three other baselines are implemented with commonly used segmentation-assisted diagnosis techniques and are denoted ’Base-cat’, ’Base-multi’ and ’Base-ROI’. ’Base-cat’ concatenates the estimated masks with the raw images as the input of the diagnosis model. ’Base-multi’ learns a single network for both segmentation and diagnosis. ’Base-ROI’ crops the region of interest (ROI) based on the estimated segmentation masks.
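The input construction of ’Base-cat’ and ’Base-ROI’ can be sketched as below. This is a hypothetical illustration of the two ideas only: images are (channels, height, width) nested lists, the crop uses the tight bounding box of the estimated mask, and the `margin` parameter and any resizing after cropping are our own assumptions rather than documented settings.

```python
from typing import List, Tuple

Image = List[List[List[float]]]  # (channels, height, width)
Mask = List[List[int]]           # (height, width), values 0/1


def base_cat_input(image: Image, mask: Mask) -> Image:
    """'Base-cat' idea: append the estimated mask as an extra input channel."""
    if (len(image[0]), len(image[0][0])) != (len(mask), len(mask[0])):
        raise ValueError("mask and image must have the same spatial size")
    return image + [[[float(v) for v in row] for row in mask]]


def mask_bbox(mask: Mask, margin: int = 0) -> Tuple[int, int, int, int]:
    """Inclusive (top, bottom, left, right) box around the foreground, padded by margin."""
    rows = [i for i, row in enumerate(mask) if any(row)]
    cols = [j for j in range(len(mask[0])) if any(row[j] for row in mask)]
    if not rows:
        raise ValueError("empty mask: no ROI to crop")
    h, w = len(mask), len(mask[0])
    return (max(rows[0] - margin, 0), min(rows[-1] + margin, h - 1),
            max(cols[0] - margin, 0), min(cols[-1] + margin, w - 1))


def base_roi_input(image: Image, mask: Mask, margin: int = 0) -> Image:
    """'Base-ROI' idea: crop every image channel to the mask's bounding box."""
    top, bottom, left, right = mask_bbox(mask, margin)
    return [[row[left:right + 1] for row in ch[top:bottom + 1]] for ch in image]
```

’Base-multi’ is omitted here because it changes the network (a shared encoder with two heads) rather than the input.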
To verify the generalization of the models, we train the segmentation network on homologous (-homo) and heterologous (-hetero) data, respectively. ’-homo’ means the segmentation and diagnosis networks are trained on the same data source; ’-hetero’ means the segmentation model is trained on an external dataset, namely RIGA, DDTI and PH2 for glaucoma, thyroid cancer and melanoma diagnosis, respectively.
3.3 Main Results
Comparing SeATrans with the baselines in Table 1, we see significant improvement on all three diagnosis tasks. Concretely, compared with the best baseline, SeATrans improves AUC by 6.56%, 6.78% and 8.14% on glaucoma, thyroid cancer and melanoma diagnosis, respectively, indicating that SeATrans brings general and considerable improvement over commonly used techniques. SeATrans also achieves the highest sensitivity with competitive accuracy and specificity, which makes it more applicable to real clinical scenarios, where sensitivity is usually of great concern.
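For readers reproducing such comparisons, the sketch below shows how AUC, sensitivity and specificity can be computed from predicted disease probabilities. AUC is computed via its rank interpretation (the probability that a random positive is scored above a random negative, ties counting one half); the 0.5 decision threshold is our own assumption, since the operating point is not stated here. Library implementations such as scikit-learn's `roc_auc_score` give the same AUC.

```python
from typing import Sequence, Tuple


def roc_auc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """AUC = P(random positive scored above random negative), ties counted as 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def sensitivity_specificity(
    scores: Sequence[float], labels: Sequence[int], threshold: float = 0.5
) -> Tuple[float, float]:
    """Sensitivity = TP / (TP + FN); specificity = TN / (TN + FP) at a fixed threshold."""
    tp = sum(1 for s, y in zip(scores, labels) if y == 1 and s >= threshold)
    fn = sum(1 for s, y in zip(scores, labels) if y == 1 and s < threshold)
    tn = sum(1 for s, y in zip(scores, labels) if y == 0 and s < threshold)
    fp = sum(1 for s, y in zip(scores, labels) if y == 0 and s >= threshold)
    return tp / (tp + fn), tn / (tn + fp)


scores = [0.9, 0.8, 0.4, 0.3, 0.6]
labels = [1, 1, 1, 0, 0]
print(roc_auc(scores, labels))  # 5 of 6 positive-negative pairs ranked correctly
print(sensitivity_specificity(scores, labels))
```

AUC is threshold-free, whereas sensitivity and specificity depend on the chosen operating point, which is why they are reported alongside it.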
Comparing the vanilla baseline with the other methods, we see that, except for ’Base-multi’, the segmentation information improves diagnosis performance to some degree. This demonstrates that segmentation information about lesions/tissues is useful for automated diagnosis models. However, the improvement it brings depends largely on how it is used. Multi-task learning appears ineffective in our experiments. This may be due to the large discrepancy between segmentation and diagnosis features: the segmentation encoder extracts low-level structural features while diagnosis needs high-level semantic features, so it is hard to learn universal features in one encoder. SeATrans instead fuses the multi-scale segmentation features into the first few layers of the diagnosis model. In this way, these structure-focused layers are enhanced by awareness of lesion/tissue structures, while the later layers can still abstract high-level diagnosis features. As a result, SeATrans outperforms the other segmentation-assisted diagnosis methods by a large margin.
To verify the generalization of the methods, we also conduct experiments on heterologous data, where the segmentation model is pre-trained on an external dataset. Due to the domain shift, the segmentation masks/features are inferior to those in the ’-homo’ setting and thus disturb the diagnosis models. Comparing ’-homo’ with ’-hetero’, we see an AUC drop for all methods. SeATrans, however, shows very competitive generalization ability, dropping only about 1% AUC on ’-hetero’.
3.4 Comparing with SOTA
To demonstrate the advantage of SeATrans, we compare it with SOTA segmentation-assisted diagnosis methods. Table 2 quantitatively compares SeATrans with nine such methods.
SeATrans vs Transformers. Two SOTA transformer-based diagnosis architectures, ConViT and Swin Transformer, are included in the comparison, with the segmentation masks concatenated to their inputs. SeATrans clearly outperforms both, improving AUC by about 5.60%, 5.82% and 7.10% on glaucoma, thyroid cancer and melanoma, respectively. This demonstrates that a large proportion of the improvement comes from the proposed feature fusion strategy rather than from the transformer architecture itself.
SeATrans vs ROI. We compare SeATrans with two ROI-based segmentation-assisted diagnosis methods, DualStage and DENet. One of them gains only marginal improvement over the vanilla baseline; the other achieves better performance but is only applicable to glaucoma diagnosis. SeATrans outperforms the ROI-based methods by 4% AUC on average across the tasks.
SeATrans vs Channel Attention. We also compare SeATrans with SOTA channel-attention-based segmentation-assisted diagnosis methods, AGCNN and ColNet, which use channel attention to enhance the diagnosis feature with the segmentation masks/features. SeATrans surpasses AGCNN and ColNet by 6.31% and 3.10% AUC on glaucoma, 3.99% and 2.40% on thyroid cancer, and 4.39% and 3.84% on melanoma diagnosis, respectively, indicating its superiority over regionally correlated channel attention.
SeATrans vs Multi-task. The multi-task learning methods MagNet and CMSNET are included in the comparison. SeATrans consistently outperforms both, especially on thyroid cancer diagnosis, where it outperforms MagNet and CMSVNET by 11.16% and 10.13% AUC, respectively.
SeATrans vs Transfer-learning. L2T-KT takes a distinctive approach, teacher-student transfer learning, and achieves competitive performance. In terms of AUC, SeATrans outperforms L2T-KT by 2.23%, 2.55% and 2.66% on glaucoma, thyroid cancer and melanoma diagnosis, respectively. SeATrans also achieves a better sensitivity-specificity trade-off than L2T-KT; for example, on glaucoma diagnosis SeATrans achieves a 79.66% F1 score, surpassing the 77.43% of L2T-KT.
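Since F1 is quoted as evidence of the sensitivity-specificity trade-off, it helps to see how the quantities relate. The sketch below (hypothetical helper names) computes F1 from confusion counts and, equivalently, from sensitivity, specificity and disease prevalence; the second form makes explicit that F1 also depends on prevalence, so it is not determined by sensitivity and specificity alone.

```python
def f1_from_counts(tp: int, fp: int, fn: int) -> float:
    """F1 = harmonic mean of precision and recall (recall is the same as sensitivity)."""
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def f1_from_rates(sensitivity: float, specificity: float, prevalence: float) -> float:
    """The same F1 expressed via sensitivity, specificity and disease prevalence."""
    tp_frac = sensitivity * prevalence                 # fraction of all samples that are TP
    fp_frac = (1.0 - specificity) * (1.0 - prevalence)  # fraction of all samples that are FP
    if tp_frac == 0:
        return 0.0
    precision = tp_frac / (tp_frac + fp_frac)
    return 2 * precision * sensitivity / (precision + sensitivity)


print(f1_from_counts(tp=6, fp=2, fn=2))  # 0.75
```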
Heterologous data generalization. Comparing ’-homo’ with ’-hetero’, we see that the ROI-based methods (DualStage, DENet) show the best generalization, since they use less segmentation information than the others. SeATrans and the transformer-based methods (ConViT, Swin) also show competitive generalization, dropping only about 1% AUC across the tasks. Channel-attention-based methods (AGCNN, ColNet) are more sensitive, since their regional-correlation assumption is vulnerable to domain shift. Thanks to its dynamic and global nature, SeATrans achieves high performance with very competitive generalization compared with the other methods.
3.5 Ablation study
Ablation studies are performed on each component of SeATrans, including multi-scale integration, asymmetric interaction and SeA-block, as listed in Table 3. The experiments are conducted on the glaucoma diagnosis task, and feature concatenation replaces SeA-block in the variants without it. In Table 3, as we sequentially add the proposed modules to the vanilla baseline, the model performance improves gradually. First, applying multi-scale segmentation-diagnosis integration increases AUC by 2% on homologous data but only 0.6% on heterologous data, indicating that multi-scale integration improves diagnosis performance but brings limited generalization. Then, asymmetric multi-scale interaction is applied to further focus the integration on low-level features, which boosts AUC by 3.53% and 3.42% on ’-homo’ and ’-hetero’, respectively. Finally, SeA-block is used for the segmentation-diagnosis interaction, which remarkably improves diagnosis performance, with AUC gains of 5.09% and 6.34% on ’-homo’ and ’-hetero’, respectively. This indicates that SeA-block brings significant and general improvement through its dynamic and global interaction.
In this work, we proposed SeATrans to overcome the shortcomings of existing segmentation-assisted diagnosis models. In SeATrans, asymmetric multi-scale interaction addresses the scale-level discrepancy between segmentation and diagnosis features, and SeA-block performs global and dynamic feature interaction between the segmentation and diagnosis spaces. Extensive experiments demonstrate the general and superior performance of SeATrans on a range of medical image diagnosis tasks.
- (2017) Agreement among ophthalmologists in marking the optic disc and optic cup in fundus images. International ophthalmology 37 (3), pp. 701–717. Cited by: §3.2.
- (2018) Medical image analysis using convolutional neural networks: a review. Journal of medical systems 42 (11), pp. 1–13. Cited by: §3.2.
- (2016) Layer normalization. arXiv preprint arXiv:1607.06450. Cited by: §2.2.
- (2019) Two-stage framework for optic disc localization and glaucoma classification in retinal fundus images using deep learning. BMC medical informatics and decision making 19 (1), pp. 1–16. Cited by: §1, §3.4, Table 2.
- (2020) End-to-end object detection with transformers. In European Conference on Computer Vision, pp. 213–229. Cited by: §2.2.
- (2021) ConViT: improving vision transformers with soft convolutional inductive biases. arXiv preprint arXiv:2103.10697. Cited by: §3.4, Table 2.
- (2020) An image is worth 16x16 words: transformers for image recognition at scale. arXiv preprint arXiv:2010.11929. Cited by: §1.
- (2022) REFUGE2 challenge: treasure for multi-domain learning in glaucoma assessment. arXiv preprint arXiv:2202.08994. Cited by: §3.1.
- (2018) Joint optic disc and cup segmentation based on multi-label deep network and polar transformation. IEEE transactions on medical imaging 37 (7), pp. 1597–1605. Cited by: §1.
- (2018) Disc-aware ensemble network for glaucoma screening from fundus image. IEEE transactions on medical imaging 37 (11), pp. 2493–2501. Cited by: §1, §3.4, Table 2.
- (2005) First prospective study of the recognition process of melanoma in dermatological practice. Archives of dermatology 141 (4), pp. 434–438. Cited by: §1.
- (1998) Vertical cup/disc ratio in relation to optic disc size: its value in the assessment of the glaucoma suspect. British Journal of Ophthalmology 82 (10), pp. 1118–1124. Cited by: §1.
- (2021) Multi-task learning for thyroid nodule segmentation with thyroid region prior. In 2021 IEEE 18th International Symposium on Biomedical Imaging (ISBI), pp. 257–261. Cited by: §3.1.
- (2021) MAG-net: multi-task attention guided network for brain tumor segmentation and classification. In International Conference on Big Data Analytics, pp. 3–15. Cited by: §3.4, Table 2.
- (2016) Skin lesion analysis toward melanoma detection: a challenge at the international symposium on biomedical imaging (isbi) 2016, hosted by the international skin imaging collaboration (isic). arXiv preprint arXiv:1605.01397. Cited by: §3.1.
- (2016) Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778. Cited by: §2.1, §3.2.
- (2019) Adversarial examples are not bugs, they are features. arXiv preprint arXiv:1905.02175. Cited by: §5.
- (2021) Learning calibrated medical image segmentation via multi-rater agreement modeling. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12341–12351. Cited by: §1.
- (2019) Attention based glaucoma detection: a large-scale database and cnn model. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 10571–10580. Cited by: §1, §3.4, Table 2.
- (2021) Swin transformer: hierarchical vision transformer using shifted windows. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 10012–10022. Cited by: §3.4, Table 2.
- (2013) PH2 - a dermoscopic image database for research and benchmarking. In 2013 35th annual international conference of the IEEE engineering in medicine and biology society (EMBC), pp. 5437–5440. Cited by: §3.2.
- (2021) Intriguing properties of vision transformers. Advances in Neural Information Processing Systems 34. Cited by: §1.
- (2015) An open access thyroid ultrasound image database. In 10th International Symposium on Medical Information Processing and Analysis, Vol. 9287, pp. 92870W. Cited by: §3.2.
- (2015) U-net: convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention, pp. 234–241. Cited by: §3.2.
- (2017) Grad-cam: visual explanations from deep networks via gradient-based localization. In Proceedings of the IEEE international conference on computer vision, pp. 618–626. Cited by: §5.
- (2016) Real-time single image and video super-resolution using an efficient sub-pixel convolutional neural network. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 1874–1883. Cited by: §2.1.
- (2020) Segmentation, classification, and registration of multi-modality medical imaging data. Cited by: §3.1.
- (2017) Attention is all you need. In Advances in neural information processing systems, pp. 5998–6008. Cited by: §2.2.
- (2017) The devil is in the decoder. In British Machine Vision Conference 2017, BMVC 2017, pp. 1–13. Cited by: §2.1.
- (2022) Gamma challenge: glaucoma grading from multi-modality images. arXiv preprint arXiv:2202.06511. Cited by: §1.
- (2020) Leveraging undiagnosed data for glaucoma classification with teacher-student learning. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 731–740. Cited by: §1, §3.4, Table 2.
- (2019) Generating adversarial examples in the harsh conditions. arXiv preprint arXiv:1908.11332. Cited by: §5.
- (2021) Robust collaborative learning of patch-level and image-level annotations for diabetic retinopathy grading from fundus image. IEEE Transactions on Cybernetics, pp. 1–11. Cited by: §1.
- (2017) Automatic skin lesion segmentation using deep fully convolutional networks with jaccard distance. IEEE transactions on medical imaging 36 (9), pp. 1876–1886. Cited by: §1.
- (2019) Theoretically principled trade-off between robustness and accuracy. In International conference on machine learning, pp. 7472–7482. Cited by: §5.
- (2019) Collaborative learning of semi-supervised segmentation and classification for medical images. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp. 2079–2088. Cited by: §1, §3.4, Table 2.
- (2021) Multi-task learning for segmentation and classification of tumors in 3d automated breast ultrasound images. Medical Image Analysis 70, pp. 101918. Cited by: §3.4, Table 2.
5 Supplementary material
To further analyze the interrelation between segmentation and diagnosis, we apply a network explanation technique to the models to visualize their discriminative features. Grad-CAM is a commonly used explanation tool that produces visual explanations for model decisions: it weights the feature maps of a convolutional layer by the spatially averaged gradients of the target class score and keeps the positive part of the weighted sum as a heat map. We compare the Grad-CAM visualizations of the different models on a glaucoma diagnosis example in Fig. 2.
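The Grad-CAM weighting can be sketched in a few lines. The example assumes the activations of the chosen convolutional layer and the gradients of the target class score with respect to them have already been obtained (in a framework this is typically done with forward/backward hooks); upsampling the heat map to the input resolution is omitted, and the function name is hypothetical.

```python
from typing import List

FeatureMaps = List[List[List[float]]]  # (K channels, height, width)


def grad_cam(activations: FeatureMaps, gradients: FeatureMaps) -> List[List[float]]:
    """Grad-CAM map: ReLU(sum_k alpha_k * A_k), alpha_k = spatial mean of d(score)/dA_k."""
    h, w = len(activations[0]), len(activations[0][0])
    # Global-average-pool the gradients to get one importance weight per channel.
    alphas = [sum(sum(row) for row in g) / (h * w) for g in gradients]
    cam = [[0.0] * w for _ in range(h)]
    for alpha, a in zip(alphas, activations):
        for i in range(h):
            for j in range(w):
                cam[i][j] += alpha * a[i][j]
    # ReLU keeps only regions with a positive influence on the target class.
    return [[max(v, 0.0) for v in row] for row in cam]


acts = [[[1.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 1.0]]]
grads = [[[1.0, 1.0], [1.0, 1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]
print(grad_cam(acts, grads))  # [[1.0, 0.0], [0.0, 0.0]]
```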
We can see that the ROI-based methods (DENet and DualStage) and the transformer-based methods (Swin and ConViT) pay less attention to clinically important regions such as the optic cup. This may be because these methods impose the segmentation enhancement on the model inputs rather than on the deep features. Although their explanations are less convincing, some of these models with sophisticated network structures, such as Swin, ConViT and DENet, still achieve good diagnosis performance. Recent literature also shows that sophisticated networks can have stronger capability but poorer explainability [35, 32, 17], since they are prone to learning features that are discriminative for the network but meaningless to humans. Multi-task methods (MagNet and CMSNET) and channel-attention-based methods (AGCNN and ColNet) mainly focus on the optic-cup region, which is important for clinical glaucoma diagnosis, but most of them lack sufficient learnable parameters, which leads to inferior diagnosis performance. The transfer-learning-based method (L2T-KT) and the proposed SeATrans pay more attention to the optic-cup region. Besides the optic cup, SeATrans also focuses on the gap between the OC and OD boundaries, another clinically important indicator of suspected glaucoma. These visualizations demonstrate that SeATrans reaches superior diagnosis performance with clear and reasonable explanations.