Uncertainty-aware Multi-modal Learning via Cross-modal Random Network Prediction

07/22/2022
by Hu Wang, et al.

Multi-modal learning focuses on training models by equally combining multiple input data modalities during the prediction process. However, this equal combination can be detrimental to the prediction accuracy because different modalities are usually accompanied by varying levels of uncertainty. Using such uncertainty to combine modalities has been studied by a couple of approaches, but with limited success because these approaches are either designed to deal with specific classification or segmentation problems and cannot be easily translated into other tasks, or suffer from numerical instabilities. In this paper, we propose a new Uncertainty-aware Multi-modal Learner that estimates uncertainty by measuring feature density via Cross-modal Random Network Prediction (CRNP). CRNP is designed to require little adaptation to translate between different prediction tasks, while having a stable training process. From a technical point of view, CRNP is the first approach to explore random network prediction to estimate uncertainty and to combine multi-modal data. Experiments on two 3D multi-modal medical image segmentation tasks and three 2D multi-modal computer vision classification tasks show the effectiveness, adaptability and robustness of CRNP. Also, we provide an extensive discussion on different fusion functions and visualization to validate the proposed model.

1 Introduction

Multi-modal data analysis, where the input data come from a wide range of sources, is a relatively common task. For instance, autonomous vehicles may take actions based on the fusion of information provided by multiple sensors. In the medical domain, automated diagnosis often relies on data from multiple complementary modalities. Recently, we have seen the development of successful multi-modal techniques, such as vision-and-sound classification [5], sound source localization [4], vision-and-language navigation [35] and organ segmentation from multiple medical imaging modalities [9, 40, 38]. However, current multi-modal models typically rely on complex structures that neglect the uncertainty present in each modality. Although they can obtain promising results in specific scenarios, they are fragile when modalities contain high uncertainty due to noise in the data or the presence of abnormal information. Such issues can reduce their prediction accuracy and limit their applicability in safety-critical applications [14].

Uncertainty is a crucial issue in many machine learning tasks because of the inherent randomness of machine learning processes. For instance, the randomness of data collection, data labeling, model initialization and training are sources of uncertainty that can result in large disagreements between models trained under similar conditions. According to [13, 1, 18], total uncertainty comprises: 1) aleatoric uncertainty (also known as data uncertainty), representing inherent noise in the data due to issues in data acquisition or labeling; and 2) epistemic uncertainty (i.e., model or knowledge uncertainty), which relates to the model's estimate for the input data, an estimate that may be inaccurate due to insufficient training steps or data, poor convergence, etc. Total uncertainty is defined as:

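$$\underbrace{\mathcal{M}\big[P(y \mid x, \mathcal{D})\big]}_{\text{total}} = \underbrace{\mathbb{E}_{P(\theta \mid \mathcal{D})}\Big[\mathcal{M}\big[P(y \mid x, \theta)\big]\Big]}_{\text{aleatoric}} + \underbrace{\mathcal{I}\big[y, \theta \mid x, \mathcal{D}\big]}_{\text{epistemic}}, \tag{1}$$

where $\mathcal{D}$ indicates the given dataset, $x$ and $y$ are the inputs and outputs of the model, $\mathcal{M}$ represents the measurement of disagreement (e.g., entropy), and $\mathcal{I}$ is the mutual information between the output $y$ and the model parameters $\theta$ (the decomposition is exact when $\mathcal{M}$ is the entropy). The aleatoric uncertainty is estimated as the expected disagreement of the predictions of each model, parameterized by $\theta$ drawn from the posterior $P(\theta \mid \mathcal{D})$, on the data points; the epistemic uncertainty is given by the disagreement between different models parameterized by $\theta$ sampled from the posterior. In this paper, we focus on estimating total uncertainty.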
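For a classifier, with entropy as the measure of disagreement, the decomposition above can be estimated from a handful of models sampled from the posterior, for example ensemble members or Monte Carlo dropout passes. The following Python sketch is illustrative only: the function names are hypothetical, the member distributions are made-up inputs, and the posterior is assumed to be represented by equally weighted samples.

```python
import math


def entropy(p):
    """Shannon entropy (in nats) of a categorical distribution given as a list."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def decompose_uncertainty(member_probs):
    """Split total predictive entropy into aleatoric and epistemic parts.

    member_probs: one categorical distribution P(y | x, theta_m) per model
    theta_m sampled from the posterior (assumed equally weighted).
    """
    n_models = len(member_probs)
    n_classes = len(member_probs[0])
    # Predictive distribution P(y | x, D), approximated by the member average.
    mean_probs = [sum(p[c] for p in member_probs) / n_models for c in range(n_classes)]
    total = entropy(mean_probs)
    # Aleatoric part: expected entropy of the individual members.
    aleatoric = sum(entropy(p) for p in member_probs) / n_models
    # Epistemic part: the remainder, i.e. the mutual information between y and theta.
    epistemic = total - aleatoric
    return total, aleatoric, epistemic


print(decompose_uncertainty([[0.9, 0.1], [0.9, 0.1]]))  # members agree
print(decompose_uncertainty([[1.0, 0.0], [0.0, 1.0]]))  # confident members disagree
```

Members that agree yield no epistemic uncertainty, whereas confident members that disagree yield an epistemic term of $\ln 2$ and no aleatoric term.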

In multi-modal learning, existing methods typically assume that each modality contributes equally to the prediction outcome [27, 33, 9]. This strong assumption may not hold if one of the modalities leads to a highly uncertain prediction, which can damage model performance. In general, deep learning models that can estimate uncertainty [20, 19, 2] were not designed to deal with multi-modal data. These models are usually based either on Bayesian learning, which has slow inference and poor training convergence, or on abstention mechanisms [32], which may suffer from the low representational power of characterizing all types of uncertainty with a single abnormal class. Recently, a couple of methods have been designed to model multi-modal uncertainty [14, 26], but they are either limited to very specific classification or segmentation problems, or they show numerical instabilities.

In this paper, we propose a novel approach to estimate the total uncertainty present in multi-modal data by measuring feature density via Cross-modal Random Network Prediction (CRNP). CRNP measures uncertainty for multi-modal learning using random network prediction (RNP) [3]; the model is designed to be easily adaptable to disparate tasks (e.g., classification and segmentation), and its training relies on a stable optimization that mitigates numerical instabilities. To summarize, the main contributions of this paper are:

  • We propose a new uncertainty-aware multi-modal learning model, named Cross-modal Random Network Prediction (CRNP), which learns the feature distribution with an RNP. CRNP is designed to be easily adapted to disparate tasks (e.g., classification and segmentation) and to be robust to numerical instabilities during optimization.

  • This paper introduces a novel uncertainty estimation method based on fitting the output of an RNP, which, from a technical viewpoint, departs from the more common uncertainty estimation methods based on Bayesian learning or abstention mechanisms.

The adaptability of CRNP is shown by its application to two 3D multi-modal medical image segmentation tasks and three 2D multi-modal computer vision classification tasks, where the proposed model achieves state-of-the-art results on all problems. We perform a thorough analysis of multiple CRNP fusion strategies and present visualizations to validate the effectiveness of the proposed model.

2 Related Work

2.1 Multi-modal Learning

Multi-modal learning has attracted increasing attention in computer vision (CV) and medical image analysis (MIA). In MIA, Jia et al. [16] introduced a shared-and-specific feature representation learning for semi-supervised multi-view learning. Dou et al. [9] proposed a chilopod-shaped multi-modal learning architecture with separate feature normalization for each modality and a knowledge distillation loss function. In CV, Shen et al. [4] defined a trusted middle-ground for video-and-sound source localization. In video-and-sound classification, Chen et al. [5] proposed to distill multi-modal image and sound knowledge into a video backbone network through compositional contrastive learning. Also in video-and-sound classification, Patrick et al. [29, 30] brought self-supervised learning into multi-modal learning by training the networks on external data, which greatly boosted classification accuracy. Wang et al. [39] showed that exchanging channels between modalities leads to a better fusion of multi-modal features. Although existing multi-modal learning methods are successful on several tasks, they do not consider that some modalities may be more reliable than others when reaching a decision, which can damage the accuracy of the model.

2.2 Uncertainty-based Learning Models

Uncertainty has also been widely studied in deep learning. Corbiere et al. [7] proposed to predict a single uncertainty value with an external confidence network trained using the ground-truth class. Sensoy et al. [32] introduced the Dirichlet distribution for an overall, evidence-based measurement of classification uncertainty. Kohl et al. [20] proposed a probabilistic UNet segmentation architecture that optimizes a variant of the evidence lower bound (ELBO) objective. Building on the probabilistic UNet, Kohl et al. [19] and Baumgartner et al. [2] further extended the model hierarchically, in either the backbone network or the prior/posterior networks. Jungo et al. [17] used two medical datasets to compare several uncertainty measurement models, namely: softmax entropy [12], Monte Carlo dropout [12], aleatoric uncertainty [18], ensemble methods [21] and auxiliary networks [8, 31]. In MIA, multiple uncertainty measurements have been proposed as well [36, 37, 24, 22]. However, none of the methods above are designed for multi-modal tasks, and some of them contain long and complex pipelines that are not easily adaptable to new tasks. Bayesian or ensemble-based methods demand long training and inference times and converge slowly. Evidential methods have drawbacks too, mainly the limited representational power of the abstention class. In contrast, our proposed model, by introducing random network fitting for cross-modal uncertainty measurement, is not only technically novel but also simple and easily adaptable to many tasks, without any restrictive assumption about uncertainty representation.

2.3 Combining Uncertainty and Multi-modal Analysis

Some methods have studied the combination of uncertainty modeling and multi-modal learning. For example, a trusted multi-view classification model [14] models multi-view uncertainty through Dirichlet distributions and merges multi-modal features via Dempster’s rule. However, it is rigidly designed for classification problems and cannot be easily translated to other tasks, such as segmentation. Monteiro et al. [26] took pixel-wise coherence into account by optimizing low-rank covariance matrices, applied to lung nodule and brain tumor segmentation. Nevertheless, their method requires a time-consuming step to generate binary brain masks that remove blank areas, and it is numerically unstable when training in areas of infinite covariance, such as the air outside the segmentation target (as stated in the SSN implementation [26] at https://github.com/biomedia-mira/stochastic_segmentation_networks). From an implementation perspective, this method [26] is also memory intensive, because it indexes the identity matrix to create one-hot encodings. In contrast, our model measures uncertainty by modeling the overall distribution directly from features, without constructing any second-order relation matrix, which leads to a numerically more stable optimization and smaller memory consumption.

3 Cross-modal Random Network Prediction

Below, we first introduce Random Network Prediction (RNP), together with a theoretical justification for its use in measuring uncertainty. We then present the training and inference of the CRNP model, whose cross-modal uncertainty measuring mechanism uses the RNP uncertainty prediction from one modality to enhance or suppress the outputs of the other modalities when producing a classification or segmentation prediction.

3.1 Random Network Prediction

The uncertainty of a particular modality is estimated with the RNP depicted in Fig. 1. Specifically, for each RNP, we train a prediction network to fit the outputs of a randomly initialized network with fixed weights, which models the feature density. The intuition is that the prediction network fits the random network outputs better (i.e., low uncertainty) for samples that populate denser regions of the feature space, and worse (i.e., high uncertainty) for samples in sparser regions. This phenomenon is depicted in the graph inside Fig. 1.

Figure 1: The input data $\mathbf{x}_a$ and $\mathbf{x}_b$ are first processed by backbone models $f_{\theta_a}$ and $f_{\theta_b}$ that produce the features $\mathbf{f}_a$ and $\mathbf{f}_b$. Then the RNP modules have a fixed-weight random network $g_{\bar{\phi}}$ and a learnable prediction network $h_{\psi}$ that tries to fit the output of the random network. The prediction network fits better (i.e., with low predictive uncertainty) in more densely populated regions of the feature space, as shown in the graph. Hence, the difference between the outputs of $g_{\bar{\phi}}$ and $h_{\psi}$ can be used to estimate uncertainty when processing test input data.

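Formally, we consider input images from two modalities, $\mathbf{x}_a, \mathbf{x}_b \in \mathcal{X}$, where $a$ and $b$ denote the modalities. After an input image passes through the encoder $f_{\theta_a}: \mathcal{X} \to \mathcal{F}$ (similarly for $f_{\theta_b}$), the features of the two modalities are analyzed by the RNP module. The RNP module feeds $\mathbf{f}_a$ and $\mathbf{f}_b$ to a randomly initialized neural network $g_{\bar{\phi}}: \mathcal{F} \to \mathcal{R}$ with fixed weights $\bar{\phi}$, where $\mathbf{f}_a = f_{\theta_a}(\mathbf{x}_a)$ and $\mathbf{f}_b = f_{\theta_b}(\mathbf{x}_b)$. Meanwhile, $\mathbf{f}_a$ and $\mathbf{f}_b$ are fed to a learnable prediction network $h_{\psi}: \mathcal{F} \to \mathcal{R}$ with parameters $\psi$. The prediction network has the same output space as the random network but a different structure, and the capacity of $h_{\psi}$ is smaller than that of $g_{\bar{\phi}}$ to prevent potential trivial solutions. The cost function used to train the RNP module is the mean squared error (MSE) between the outputs of the prediction and random networks:

$$\ell_{RNP}(\psi) = \frac{1}{N}\sum_{i=1}^{N}\Big(\big\|h_{\psi}(\mathbf{f}_{a,i}) - g_{\bar{\phi}}(\mathbf{f}_{a,i})\big\|_2^2 + \big\|h_{\psi}(\mathbf{f}_{b,i}) - g_{\bar{\phi}}(\mathbf{f}_{b,i})\big\|_2^2\Big), \tag{2}$$

where $N$ denotes the number of training samples, $\mathbf{f}_{a,i} = f_{\theta_a}(\mathbf{x}_{a,i})$, and $\mathbf{f}_{b,i} = f_{\theta_b}(\mathbf{x}_{b,i})$. The cost function in (2) provides a simple yet powerful supervisory signal that enables the prediction network to learn the uncertainty measuring function.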
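The density intuition behind the RNP can be reproduced in a toy setting. The sketch below is an illustration under stated assumptions, not the authors' implementation (which uses depth-wise convolutional layers on backbone features): the fixed random network is a single tanh layer, the lower-capacity prediction network is a linear map, the "features" are synthetic 2-D points drawn from one dense cluster, and all names are hypothetical. The predictor is fitted to the frozen random network with full-batch gradient descent on the squared error, in the spirit of the RNP cost, and the squared error of a new point is then read as its uncertainty.

```python
import math
import random


def make_random_network(in_dim, out_dim, rng):
    """Fixed, randomly initialised target network g; its weights are never updated."""
    W = [[rng.gauss(0.0, 1.0) for _ in range(in_dim)] for _ in range(out_dim)]
    b = [rng.gauss(0.0, 1.0) for _ in range(out_dim)]
    return lambda f: [math.tanh(sum(w * x for w, x in zip(row, f)) + bk)
                      for row, bk in zip(W, b)]


def train_predictor(g, feats, out_dim, lr=0.4, steps=500):
    """Fit a lower-capacity linear predictor h(f) = A f + c to the frozen network g
    by full-batch gradient descent on the mean squared error."""
    in_dim, n = len(feats[0]), len(feats)
    A = [[0.0] * in_dim for _ in range(out_dim)]
    c = [0.0] * out_dim
    targets = [g(f) for f in feats]  # g is frozen, so its outputs can be cached
    for _ in range(steps):
        grad_A = [[0.0] * in_dim for _ in range(out_dim)]
        grad_c = [0.0] * out_dim
        for f, t in zip(feats, targets):
            for k in range(out_dim):
                r = sum(a * x for a, x in zip(A[k], f)) + c[k] - t[k]  # residual
                grad_c[k] += 2.0 * r / n
                for j in range(in_dim):
                    grad_A[k][j] += 2.0 * r * f[j] / n
        for k in range(out_dim):
            c[k] -= lr * grad_c[k]
            for j in range(in_dim):
                A[k][j] -= lr * grad_A[k][j]
    return lambda f: [sum(a * x for a, x in zip(row, f)) + ck for row, ck in zip(A, c)]


def rnp_error(h, g, f):
    """Squared prediction error ||h(f) - g(f)||^2, read as the uncertainty of f."""
    return sum((p - t) ** 2 for p, t in zip(h(f), g(f)))


rng = random.Random(0)
g = make_random_network(in_dim=2, out_dim=8, rng=rng)
# Training "features" come from a single dense cluster around the origin.
train_feats = [[rng.gauss(0.0, 0.3), rng.gauss(0.0, 0.3)] for _ in range(100)]
h = train_predictor(g, train_feats, out_dim=8)

dense = [[rng.gauss(0.0, 0.3), rng.gauss(0.0, 0.3)] for _ in range(50)]
sparse = [[3.0, 3.0], [-3.0, 3.0], [3.0, -3.0], [-3.0, -3.0]]
dense_err = sum(rnp_error(h, g, f) for f in dense) / len(dense)
sparse_err = sum(rnp_error(h, g, f) for f in sparse) / len(sparse)
print(f"dense region: {dense_err:.4f}  sparse region: {sparse_err:.4f}")
```

New points from the dense training cluster are fitted closely, while the corner points, which lie in an unpopulated region of the feature space, receive a larger error, mirroring the graph in Fig. 1.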

3.2 Theoretical Support for Uncertainty Measurement

The RNP is closely related to uncertainty measurement. Let us consider a regression process from a set of perturbed data $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^{N}$. In a Bayesian setting, the objective is to minimize the distance between the ground truth $y_i$ and a sum of a prior function $f_{\theta^*}$, randomly sampled from a Gaussian, and an additive posterior term $f_{\theta}$, with a regularization $\mathcal{R}(\theta)$. Formally, the optimization is as follows:

$$\theta = \arg\min_{\theta}\ \mathbb{E}_{(x_i, y_i) \sim \mathcal{D}}\big\|f_{\theta}(x_i) + f_{\theta^*}(x_i) - y_i\big\|^2 + \mathcal{R}(\theta), \tag{3}$$

where, according to Lemma 3 in [28], the sum $f_{\theta} + f_{\theta^*}$ approximates the true posterior. If we fix the targets $y_i$ to zero, the objective becomes equivalent to minimizing the distance between the posterior term $f_{\theta}$ and the randomly sampled prior $f_{\theta^*}$ (up to a sign), i.e., to fitting a function randomly drawn from the prior. Thus, each output element of the random function or the prediction function can be viewed as a member of a set of weight-shared ensemble functions [3]. The prediction error can therefore be viewed as an estimate of the variance of the ensemble (assuming the ensemble is unbiased) [3].
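The ensemble reading of the prediction error can be written out explicitly. As a sketch that follows the argument of [3], let $h$ denote the prediction network and $g$ the fixed random network, each with $K$ output coordinates, and let $e_k(x) = h_k(x) - g_k(x)$. With zero regression targets, each $e_k(x)$ plays the role of one ensemble member's prediction, so the spread of the members around the target is

$$\frac{1}{K}\sum_{k=1}^{K}\big(e_k(x) - 0\big)^2 = \frac{1}{K}\sum_{k=1}^{K}\big(h_k(x) - g_k(x)\big)^2 = \frac{1}{K}\big\|h(x) - g(x)\big\|_2^2 .$$

If the ensemble is unbiased, i.e., its mean prediction $\frac{1}{K}\sum_k e_k(x)$ matches the zero target in expectation, this quantity is an estimate of the ensemble's predictive variance at $x$. It is proportional to the per-sample squared error that the RNP is trained to minimize, which is why that error can be read as an uncertainty estimate.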

3.3 Training and Inference of CRNP

Figure 2: The overall framework of multi-modal fusion with our CRNP.

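This section introduces our proposed CRNP, which fuses the multiple modalities with their inferred uncertainties to produce the final predictions (e.g., classification or segmentation), as shown in Fig. 2. During the multi-modal fusion phase, the features of the two modalities, $\mathbf{f}_a$ and $\mathbf{f}_b$, are cross-attended by the uncertainty maps that the RNP module produces from both modalities. The uncertainty map for modality $a$ is computed from the RNP prediction error of the other modality, $b$:

$$\mathbf{u}_a = \big(h_{\psi}(\mathbf{f}_b) - g_{\bar{\phi}}(\mathbf{f}_b)\big)^{2}, \tag{4}$$

where the square is taken element-wise, and similarly for $\mathbf{u}_b$ for modality $b$. The feature of modality $a$ cross-attended by the uncertainty map is:

$$\hat{\mathbf{f}}_a = \Gamma\big(\mathbf{f}_a,\ \bar{\mathbf{u}}_a \odot \mathbf{f}_a\big), \tag{5}$$

where $\Gamma(\cdot,\cdot)$ represents the operator that fuses the original and cross-attended features, $\bar{\mathbf{u}}_a$ is the channel-wise normalized CRNP uncertainty map, and $\odot$ is the element-wise product operator. $\hat{\mathbf{f}}_b$ is defined similarly to (5). Different fusion operations are thoroughly discussed in Sec. 4.5.1.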
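The cross-modal weighting can be sketched for the two-modality case. Following the description that each modality's uncertainty weights the other modality, the sketch re-weights the features of modality $a$ with the normalized RNP error map of modality $b$. Several details are assumptions rather than the authors' choices: the map is the element-wise squared error between prediction and random network outputs, per-channel min-max scaling stands in for the channel-wise normalization, element-wise addition stands in for the fusion operator, tensors are nested Python lists shaped [channel][position], and all function names are hypothetical.

```python
def squared_error_map(pred, rand):
    """Element-wise squared difference between prediction and random network
    outputs, both shaped [channel][position]."""
    return [[(p - r) ** 2 for p, r in zip(pc, rc)] for pc, rc in zip(pred, rand)]


def channel_normalize(u, eps=1e-8):
    """Min-max normalise every channel to [0, 1] (assumed normalisation scheme)."""
    normed = []
    for channel in u:
        lo, hi = min(channel), max(channel)
        normed.append([(v - lo) / (hi - lo + eps) for v in channel])
    return normed


def uncertainty_weighted_features(f_a, u_b):
    """Fuse modality-a features with a copy re-weighted by modality b's normalised
    uncertainty; element-wise addition stands in for the fusion operator."""
    u_bar = channel_normalize(u_b)
    return [[fa + w * fa for fa, w in zip(f_ch, w_ch)] for f_ch, w_ch in zip(f_a, u_bar)]


# One channel, two positions: modality b is confident at position 0, uncertain at 1.
f_a = [[1.0, 2.0]]
u_b = squared_error_map(pred=[[0.0, 0.0]], rand=[[0.1, 1.0]])
fused_a = uncertainty_weighted_features(f_a, u_b)
print(fused_a)
```

Where modality $b$ is confident (small error), modality $a$ keeps its original feature; where modality $b$ is uncertain, the feature of modality $a$ is roughly doubled, so the fused representation leans on the more reliable modality.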

We utilize self-attention to further fuse the features $\hat{\mathbf{f}}_a$ and $\hat{\mathbf{f}}_b$, taking both uni-modal and cross-modal relations between feature elements into consideration. As shown in Fig. 2, we first concatenate $\hat{\mathbf{f}}_a$ and $\hat{\mathbf{f}}_b$ to form the query, key and value inputs of the self-attention module, $\mathbf{z} = [\hat{\mathbf{f}}_a; \hat{\mathbf{f}}_b]$. The output of the self-attention is then:

$$\mathbf{o} = \operatorname{softmax}\!\left(\frac{(\mathbf{z}\mathbf{W}_Q)(\mathbf{z}\mathbf{W}_K)^{\top}}{\sqrt{d}}\right)\mathbf{z}\mathbf{W}_V, \tag{6}$$

where $\mathbf{W}_Q$, $\mathbf{W}_K$ and $\mathbf{W}_V$ are the linear projection weights for queries, keys and values, respectively, and $d$ is the dimension of the queries, keys and values. The decoder after the multi-modal fusion is denoted by $p_{\gamma_a}: \mathcal{O} \to \Delta$ (similarly for $p_{\gamma_b}$), where $\mathcal{O}$ is the space of the output from the cross-modal RNP module and input to the decoder, and $\Delta$ is the classification simplex (the output of the softmax). For segmentation problems, the output of $p_{\gamma_a}$ lies in $\Delta$ for each pixel. Note that although the annotations of multi-modal data are similar, they can differ significantly, particularly in segmentation tasks. Hence, without loss of generality, we may need multiple separate decoders, one for each modality; multiple decoders are not needed when the multi-modal annotations are exactly the same. The training of CRNP alternates between training the RNP module with (2) and training the whole model. During RNP training, only the weights of the prediction network inside the RNP are updated by minimizing (2), and all other CRNP weights are kept fixed. During whole-model training, all CRNP weights are updated except for those of the RNP prediction network (the random network stays fixed throughout). The whole-model training minimizes the multi-class cross-entropy loss for classification, or the Dice and element-wise cross-entropy losses for segmentation.
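The alternating optimization can be summarized as a schedule over parameter groups. The sketch below only tracks which groups receive updates in each phase; the group names are hypothetical, and interleaving one RNP step with one whole-model step per iteration is an assumption, since the exact alternation ratio is not specified here.

```python
# Parameter groups of a two-modality CRNP (names are illustrative only).
GROUPS = ("encoder_a", "encoder_b", "random_net", "prediction_net",
          "fusion_attention", "decoders")


def trainable_groups(phase):
    """Groups whose weights are updated in a given phase of the alternating schedule."""
    if phase == "rnp":
        # Only the prediction network is fitted to the frozen random network.
        return {"prediction_net"}
    if phase == "task":
        # Task loss (cross-entropy, or Dice + cross-entropy): every group except the
        # prediction network; the random network is never trained.
        return set(GROUPS) - {"prediction_net", "random_net"}
    raise ValueError(f"unknown phase: {phase!r}")


def alternating_schedule(n_iterations):
    """Yield (iteration, phase) pairs: one RNP step, then one task step (assumed ratio)."""
    for it in range(n_iterations):
        yield it, "rnp"
        yield it, "task"


update_counts = {name: 0 for name in GROUPS}
for _, phase in alternating_schedule(3):
    for name in trainable_groups(phase):
        update_counts[name] += 1
print(update_counts)
```

After three iterations, the prediction network has been updated only in the RNP phases, every other trainable group only in the whole-model phases, and the random network never.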

During inference, CRNP receives multi-modal inputs: each modality branch estimates an uncertainty output that weights the other modality, and the results of both modalities are fused to produce the final prediction. In this way, CRNP assigns larger weights to the other modality when the current modality is uncertain. When both modalities have large uncertainties, the final prediction relies on a balanced analysis of both modalities. For more than two modalities, the uncertainty map for a particular modality, say $m \in \{1, \ldots, M\}$, in (4) is computed by summing the MSE results produced by all other modalities, i.e., $\mathbf{u}_m = \sum_{j \neq m}\big(h_{\psi}(\mathbf{f}_j) - g_{\bar{\phi}}(\mathbf{f}_j)\big)^{2}$. The decoders of the two modalities can be separate or share weights, depending on the corresponding output requirements.
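For $M > 2$ modalities, the weight map of modality $m$ is the sum of the RNP error maps of all other modalities. A minimal sketch, assuming the per-modality element-wise error maps have already been computed and flattened into equal-length Python lists; the function name is hypothetical.

```python
def cross_modal_uncertainty(errors):
    """errors[j][i]: element-wise RNP squared error of modality j at element i.

    Returns u with u[m][i] = sum of errors[j][i] over all modalities j != m.
    """
    n_mod = len(errors)
    return [
        [sum(errors[j][i] for j in range(n_mod) if j != m) for i in range(len(errors[m]))]
        for m in range(n_mod)
    ]


three = cross_modal_uncertainty([[1.0, 0.0], [0.0, 2.0], [4.0, 0.0]])
two = cross_modal_uncertainty([[0.3], [0.7]])
print(three)  # weight maps for modalities 0, 1 and 2
print(two)    # with two modalities the maps are simply swapped
```

With two modalities, the rule reduces to swapping the maps, so each modality is weighted by the uncertainty of the other one.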

4 Experiments

4.1 Datasets

Medical Image Segmentation Datasets. We conduct experiments on two publicly available multi-modal 3D segmentation datasets: the Multi-Modality Whole Heart Segmentation dataset (MMWHS) and the Multimodal Brain Tumor Segmentation Challenge 2020 dataset (BraTS2020). The MMWHS dataset contains 20 CTs and 20 MRs for training/validation and another 40 CTs and 40 MRs for testing [41]. Seven classes (background excluded) are considered for each pixel. The two modalities have individual ground truth (GT) for each CT or MR. The BraTS2020 dataset has 369 cases for training/validation and another 125 cases for evaluation, where each case (with four modalities, namely: Flair, T1, T1CE and T2) shares one segmentation GT. The evaluation is performed online (https://ipp.cbica.upenn.edu/categories/brats2020). Four classes (background included) are considered for each pixel.

Computer Vision Classification Datasets. We also validate our method on three computer vision classification datasets, namely: Handwritten (https://archive.ics.uci.edu/ml/datasets/Multiple+Features), CUB [34] and Scene15 [10]. The Handwritten dataset contains 2000 samples from six views and is a ten-class classification problem. CUB contains 11,788 bird images from 200 different categories; following Han et al. [14], we adopt the first ten classes and two modalities (image and text features) extracted by GoogleNet and doc2vec. Scene15 includes three modalities and contains 4,485 images from 15 indoor and outdoor classes.

4.2 Implementation Details

Medical Image Segmentation Tasks. To keep the comparison fair, all models evaluated on MMWHS and BraTS2020 are implemented with the 3D UNet (with 3D convolution and normalization) as the backbone network. On MMWHS, we adopt the official test set proposed by Zhuang et al. [41] (40 CTs and 40 MRs) for testing; on BraTS2020, we evaluate all models on the online validation set. For overall performance evaluation, the models were trained for 100,000 iterations on MMWHS and 180,000 iterations on BraTS2020 without model selection. Following Dou et al. [9], our hyper-parameter tuning and ablation are conducted on MMWHS with 16 CTs and 16 MRs for training and 4 CTs and 4 MRs for validation. The batch size is set to 2. A stochastic gradient descent optimizer with a momentum of 0.99 is used for model training. The same initial learning rate is used on both datasets, together with the cosine annealing [23] learning rate strategy. For the reproduction of Probability UNet [20], we use the prior/posterior mean instead of randomly sampling a latent variable for prediction. The methods are evaluated with the Dice score and Jaccard index on MMWHS, and with the Dice score and Hausdorff95 distance on BraTS2020. For the cross-modal RNP modules, the randomized network is made up of 3 depth-wise convolutional hidden layers and the prediction network has 2 depth-wise convolutional hidden layers. Between every two layers, both networks use Leaky-ReLU as the activation function. We set the RNP output dimension to 256 for both tasks. For performance evaluation, CRNP is placed at the bottleneck of our 3D UNet backbone. For the ensemble version of CRNP on both datasets, following Wang et al. [40], we average the logits of 3 CRNP models to reduce the prediction variance.
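Two of the training utilities above can be sketched briefly: a cosine annealing schedule, in its form without warm restarts, and the logit averaging used for the 3-model ensemble. The initial learning rate of 0.01 below is a hypothetical placeholder, not a value taken from this paper, the final learning rate of zero is an assumption, and the function names are hypothetical.

```python
import math


def cosine_annealing_lr(step, total_steps, lr_init, lr_min=0.0):
    """Cosine annealing without warm restarts: lr_init at step 0, lr_min at total_steps."""
    progress = min(step, total_steps) / total_steps
    return lr_min + 0.5 * (lr_init - lr_min) * (1.0 + math.cos(math.pi * progress))


def ensemble_logits(logits_per_model):
    """Element-wise mean of the logits produced by several models."""
    n_models = len(logits_per_model)
    return [sum(vals) / n_models for vals in zip(*logits_per_model)]


# 100,000 iterations as on MMWHS; lr_init = 0.01 is a hypothetical placeholder.
lrs = [cosine_annealing_lr(s, 100_000, lr_init=0.01) for s in (0, 25_000, 50_000, 100_000)]
avg = ensemble_logits([[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]])  # three models, two classes
print(lrs, avg)
```

The schedule decays smoothly from the initial value to the minimum, and averaging logits (rather than, say, hard votes) keeps the ensemble output in the same space as a single model's output.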

Computer Vision Classification Tasks. For the model evaluation on computer vision datasets, we follow [14] and split the data into 80% for training and 20% for testing. To keep the comparison fair, we uniformly trained all models for 500 epochs without model selection and then evaluated them on the test set. We adopt the Adam optimizer with weight decay and coefficients (0.9, 0.999). Following Han et al. [14], we use accuracy and multi-class AUROC as evaluation metrics. We use similar setups for the cross-modal RNP modules as on the medical data, with the following differences: the RNP output dimension is set to 32 for computer vision classification tasks, and CRNP is placed at the layer before the fully connected layer. The CRNP model is trained end-to-end without any pre-training or post-processing, and its hyper-parameters require little tuning effort.

4.3 Medical Image Segmentation Model Performance

Performance on MMWHS Dataset. We compare our approach with: Individual (single-modality CT or MR segmentation with a separate 3D UNet), 3D UNet (multi-modal fusion by concatenation), the multi-modal learning model Ummkd [9], and the uncertainty model Probability UNet [20], which proposes a prior net to approximate the posterior distribution, combining the knowledge of inputs and ground truth, in a latent space. (We also tried SSN [26], but it requires creating one-hot encodings that are memory intensive for the seven classes of the MMWHS dataset.) The evaluation is based on the Dice scores of the segmentation of the left ventricle blood cavity (LV), the myocardium of the left ventricle (Myo), the right ventricle blood cavity (RV), the left atrium blood cavity (LA), the right atrium blood cavity (RA), the ascending aorta (AA), the pulmonary artery (PA) and the whole heart (WH). All results on the MMWHS data are obtained with the official evaluation toolkit (http://www.sdspeople.fudan.edu.cn/zhuangxiahai/0/mmwhs/).
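The Dice score and Jaccard index of a single structure can be computed from flattened label maps as below. This is a minimal sketch, not the official MMWHS evaluation toolkit; the function name and the convention of returning 1.0 when both the prediction and the ground truth are empty are assumptions.

```python
def dice_and_jaccard(pred, target, label):
    """Dice = 2|P ∩ G| / (|P| + |G|) and Jaccard = |P ∩ G| / |P ∪ G| for one label,
    where P and G are the elements labelled `label` in the prediction and ground truth."""
    p = [x == label for x in pred]
    g = [x == label for x in target]
    inter = sum(a and b for a, b in zip(p, g))
    size_p, size_g = sum(p), sum(g)
    union = size_p + size_g - inter
    dice = 2 * inter / (size_p + size_g) if size_p + size_g else 1.0
    jaccard = inter / union if union else 1.0
    return dice, jaccard


pred = [0, 1, 1, 1, 0, 2]
target = [0, 1, 1, 0, 0, 2]
d, j = dice_and_jaccard(pred, target, label=1)
print(d, j)
```

For a single case, the two metrics are linked by $J = D/(2 - D)$ (here $0.8/1.2 = 2/3$); averages over cases, as reported in the tables, need not satisfy this relation.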

Modality Model LV Myo RV LA RA AA PA WH
CT Individual 0.9297 0.8943 0.8597 0.9254 0.8701 0.9335 0.7833 0.8989
CT 3D UNet 0.9138 0.8781 0.8822 0.9274 0.8680 0.9088 0.8239 0.8957
CT Ummkd 0.9145 0.9066 0.8410 0.9157 0.8853 0.8928 0.7579 0.8734
CT Prob-UNet 0.9071 0.8775 0.8978 0.9262 0.8657 0.9318 0.8425 0.8997
CT CRNP (Ours) 0.9369 0.9036 0.9076 0.9375 0.8885 0.9538 0.8628 0.9187
CT CRNP* (Ours) 0.9373 0.9060 0.9085 0.9366 0.8910 0.9503 0.8629 0.9193
MR Individual 0.8777 0.7923 0.6146 0.5686 0.7528 0.5854 0.3993 0.6729
MR 3D UNet 0.8850 0.7723 0.8559 0.8548 0.8676 0.8551 0.7964 0.8535
MR Ummkd 0.8721 0.7966 0.8086 0.8577 0.8278 0.7998 0.7224 0.8211
MR Prob-UNet 0.8742 0.7389 0.8332 0.8495 0.8531 0.8537 0.7895 0.8386
MR CRNP (Ours) 0.8962 0.7787 0.8605 0.8637 0.8748 0.8736 0.7969 0.8615
MR CRNP* (Ours) 0.8963 0.7811 0.8742 0.8850 0.8688 0.8692 0.8329 0.8758
Table 1: The performance (Dice score) of different models on CT/MR segmentation of the MMWHS dataset. The best results for each column within either the CT or MR section are in bold. * indicates the result of the ensemble model.
Model CT-Dice CT-Jaccard MR-Dice MR-Jaccard
GUT 0.9080 0.8320 0.8630 0.7620
KTH 0.8940 0.8100 0.8550 0.7530
CUHK1 0.8900 0.8050 0.7830 0.6530
CUHK2 0.8860 0.7980 0.8100 0.6870
UCF 0.8790 0.7920 0.8180 0.7010
SIAT 0.8490 0.7420 0.6740 0.5320
UT 0.8380 0.7420 0.8170 0.6950
UB1 0.8870 0.7980 0.8690 0.7730
UB2 - - 0.8740 0.7780
UOE 0.8060 0.6970 0.8320 0.7200
Ours (CRNP*) 0.9193 0.8486 0.8758 0.7814
Table 2: The whole heart segmentation performance of CRNP and different challenge models on both CT and MR of the MMWHS dataset. The best results for each column are in bold. Higher values are better for both metrics.

As shown in Tab. 1, our proposed CRNP and its ensemble version obtain 7 of the 8 best Dice results on both CT and MR. On CT, compared with the second-best models, CRNP raises the LV Dice score from 0.9297 to 0.9369 and the PA Dice score from 0.8425 to 0.8628. On whole heart segmentation, CRNP outperforms the second-best model by 1.9 percentage points in Dice (0.9187 vs. 0.8997), and the ensemble version of CRNP further improves segmentation accuracy. A similar result is observed on MR: compared with the second-best models, CRNP raises the LV Dice score from 0.8850 to 0.8962 and the AA Dice score from 0.8551 to 0.8736, and on whole heart segmentation it increases the Dice score from 0.8535 to 0.8615. Model ensembling further improves the performance.

Interestingly, the Individual model obtains accurate results on CT (0.8989 WH Dice), but its performance drops drastically on MR (0.6729 WH Dice), with particularly poor accuracy on RV, LA, AA and PA. When both modalities are considered (3D UNet model), the performance increases substantially. This shows the limits of relying on a single modality, especially for MR segmentation. The proposed CRNP outperforms the 3D UNet on every structure for both CT and MR (e.g., WH Dice of 0.9187 vs. 0.8957 on CT and 0.8615 vs. 0.8535 on MR). Ummkd [9] obtains the best Myo Dice scores on both CT and MR; we hypothesize that its domain-specific normalization and knowledge distillation loss benefit Myo segmentation more than the segmentation of other structures. Probability UNet models a posterior latent space rather than producing a deterministic prediction, which may explain its performance. In general, the CT segmentation results are better than the MR ones, which is consistent with the conclusion of [41].

In terms of the number of parameters, the randomized network has 3 convolutional hidden layers and the prediction network has 2, so the increase in parameters is small. Specifically, the competing methods have: 1) UNet: 41.05M, 2) Ummkd (with a UNet backbone for a fair comparison): 41.05M, and 3) ProbUNet: 57.44M parameters. Our CRNP has 42.18M parameters, of which the RNP module accounts for 0.29M and the attention module for 0.84M, i.e., 1.13M (about 2.8%) more than the UNet.

We also compare the proposed CRNP model with the state-of-the-art models reported in the official challenge report [41]; the results are shown in Tab. 2. On whole heart segmentation, CRNP obtains the best Dice score and Jaccard index for both CT and MR. Compared with the second-best models, CRNP increases the Dice score from 0.9080 to 0.9193 on CT and from 0.8740 to 0.8758 on MR. The Jaccard index shows similar gains (from 0.8320 to 0.8486 on CT and from 0.7780 to 0.7814 on MR).

Performance on BraTS2020 dataset. Developing automated segmentation models to delineate intrinsically heterogeneous brain tumors is the main goal of the BraTS2020 Challenge. Following [38], we compare the proposed CRNP model with many strong methods, including 3D UNet [6], Basic VNet [25], Deeper VNet [25], Residual 3D UNet, Modal-Pairing [40] and TransBTS [38], as well as the uncertainty-aware models ProbUNet [20] and SSN [26], the latter modeling aleatoric uncertainty by considering spatial coherence. We evaluate the Dice score and Hausdorff95 distance of all models on three tumor sub-regions: the enhancing tumor (ET); the tumor core (TC), which consists of the ET and the necrotic and non-enhancing tumor core; and the whole tumor (WT), which contains the TC and the peritumoral edema.

Model Dice-ET Dice-WT Dice-TC HD95-ET HD95-WT HD95-TC
3D UNet [6] 0.6876 0.8411 0.7906 50.9830 13.3660 13.6070
Basic VNet [25] 0.6179 0.8463 0.7526 47.7020 20.4070 12.1750
Deeper VNet [25] 0.6897 0.8611 0.7790 43.5180 14.4990 16.1530
Residual 3D UNet 0.7163 0.8246 0.7647 37.4220 12.3370 13.1050
ProbUNet [20] 0.7392 0.8782 0.7955 36.2458 6.9518 7.7183
SSN [26] 0.6795 0.8420 0.7866 43.6574 14.6945 19.5171
Modal-Pairing* [40] 0.7850 0.9070 0.8370 35.0100 4.7100 5.7000
TransBTS [38] 0.7873 0.9009 0.8173 17.9470 4.9640 9.7690
CRNP (Ours) 0.7887 0.9086 0.8372 26.5972 4.0490 6.0040
CRNP* (Ours) 0.7902 0.9109 0.8550 26.4682 4.1096 5.3337
Table 3: The performance of different models on the BraTS2020 online validation set (Dice: higher is better; Hausdorff95 (HD95): lower is better). The best results for each column are in bold. * indicates models with ensemble.

In Tab. 3, our models obtain 5 of the 6 best results. Compared with the second-best model, CRNP improves the ET Dice score from 0.7873 to 0.7887 and the WT Dice score from 0.9070 to 0.9086. On Hausdorff95, CRNP obtains the best WT distance (4.0490), whereas TransBTS remains the best on ET. Note that the Modal-Pairing model adopts an ensemble strategy. When the ensemble strategy is applied to CRNP, the results improve further: the WT Dice of CRNP* reaches 0.9109, the TC Dice reaches 0.8550 (1.8 percentage points above the 0.8370 of Modal-Pairing), and the TC Hausdorff95 improves to 5.3337. These improvements show the effectiveness of the proposed CRNP model.

4.4 Computer Vision Classification Model Performance

In this section, we show results that demonstrate the effectiveness of CRNP on multiple CV classification tasks. The evaluation metrics are accuracy and multi-class AUROC on the Handwritten, CUB and Scene15 datasets. Following Han et al. [14], the comparison models include multiple uncertainty-aware models: Monte Carlo dropout (MCDO) [11], which adopts dropout at inference as a Bayesian approximation; deep ensemble (DE) [21], which uses an ensemble strategy to reduce uncertainty; uncertainty-aware attention (UA) [15], which creates uncertainty attention maps from a learned Gaussian distribution; evidential deep learning (EDL) [32], which predicts an extra Dirichlet distribution for all logits based on evidence; and trusted multi-view classification (TMC) [14], which is a multi-view version of EDL.

Dataset Metric MCDO [11] DE [21] UA [15] EDL [32] TMC [14] CRNP
Handwritten Acc 0.9737 0.9830 0.9745 0.9767 0.9851 0.9925
Handwritten AUROC 0.9970 0.9979 0.9967 0.9983 0.9997 0.9996
CUB Acc 0.8978 0.9019 0.8975 0.8950 0.9100 0.9167
CUB AUROC 0.9929 0.9877 0.9869 0.9871 0.9906 0.9961
Scene15 Acc 0.5296 0.3912 0.4120 0.4641 0.6774 0.7057
Scene15 AUROC 0.9290 0.7464 0.8526 0.9141 0.9594 0.9734
Table 4: The performance of different models on computer vision classification datasets. The best results for each row are in bold.

As shown in Tab. 4, CRNP outperforms its counterparts on 5 of the 6 measures across datasets. CRNP performs particularly well on Scene15, increasing the accuracy from 0.6774 to 0.7057 (2.83 percentage points) and the AUROC from 0.9594 to 0.9734 (1.40 percentage points). CRNP also achieves promising results on Handwritten and CUB. On Handwritten AUROC, CRNP is slightly below, but comparable to, TMC (0.9996 vs. 0.9997).
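Tab. 4 reports multi-class AUROC; the exact averaging scheme is not stated in this section, so the sketch below assumes the common unweighted (macro) one-vs-rest average. It computes each binary AUROC through its rank interpretation, the probability that a randomly chosen positive receives a higher score than a randomly chosen negative, with ties counted as one half. The function names are hypothetical; in practice scikit-learn's `roc_auc_score` with `multi_class="ovr"` provides the same quantity.

```python
import numpy as np


def binary_auroc(scores, labels) -> float:
    """AUROC as P(positive score > negative score), with ties counted as 1/2."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=bool)
    pos, neg = scores[labels], scores[~labels]
    # Exhaustive pairwise comparison: O(P * N), but exact and easy to audit.
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return float((greater + 0.5 * ties) / (pos.size * neg.size))


def macro_ovr_auroc(probs, y) -> float:
    """Unweighted mean of one-vs-rest AUROCs over classes (assumed averaging scheme)."""
    probs, y = np.asarray(probs, dtype=float), np.asarray(y)
    return float(np.mean([binary_auroc(probs[:, c], y == c)
                          for c in range(probs.shape[1])]))


probs = np.array([[0.7, 0.2, 0.1],
                  [0.3, 0.4, 0.3],
                  [0.2, 0.5, 0.3],
                  [0.1, 0.2, 0.7]])
y = np.array([0, 1, 2, 2])
print(macro_ovr_auroc(probs, y))
```

Because AUROC depends only on the ranking of scores, it can stay high even when accuracy is modest, which is consistent with the Scene15 row where all methods have AUROC far above their accuracy.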

4.5 Ablation Study

4.5.1 Effectiveness of Each Component

In the ablation study, we examine each component of the proposed CRNP. The “Base” model is the plain multi-modal 3D UNet with dual branches; “CA” denotes cross-attention, where the query comes from one modality and the key and value come from the other; “SA” denotes the self-attention we propose. We conducted the ablation on the validation split of the MMWHS dataset and report, for each organ, the Dice score averaged over CT and MR. As shown in Tab. 5, compared with the Base 3D UNet, CRNP improves performance on most organs (the whole-heart Dice rises from 0.8706 to 0.8865), with the clearest gains on Myo, LA, RA, AA and WH. Adding either cross-attention or self-attention boosts performance further. Applying self-attention as described in Sec. 3.3 yields the best overall results (5 of the 8 best results, including WH). This is mainly because self-attention over the fused multi-modal features models not only cross-modal relations but also uni-modal ones.
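To make the difference between the “CA” and “SA” variants concrete, the sketch below contrasts the two attention patterns on toy CT and MR feature tokens. It assumes single-head scaled dot-product attention with identity query/key/value projections (learned projections, multiple heads and the exact placement described in Sec. 3.3 are omitted), and it assumes that “SA” attends over the concatenation of both modalities' tokens; all function and variable names are hypothetical.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(q, k, v):
    """Single-head scaled dot-product attention; returns (output, weights)."""
    weights = softmax(q @ k.T / np.sqrt(q.shape[-1]))
    return weights @ v, weights


def cross_attention(f_a, f_b):
    # "CA": queries from modality A; keys and values from modality B only.
    return attention(f_a, f_b, f_b)


def joint_self_attention(f_a, f_b):
    # "SA" (assumed form): every token of the concatenated sequence attends to
    # tokens of both modalities, i.e. uni-modal and cross-modal relations at once.
    tokens = np.concatenate([f_a, f_b], axis=0)
    return attention(tokens, tokens, tokens)


rng = np.random.default_rng(0)
f_ct = rng.normal(size=(4, 8))  # 4 hypothetical CT feature tokens, 8 channels
f_mr = rng.normal(size=(6, 8))  # 6 hypothetical MR feature tokens
out_ca, w_ca = cross_attention(f_ct, f_mr)
out_sa, w_sa = joint_self_attention(f_ct, f_mr)
print(w_ca.shape, w_sa.shape)  # (4, 6) (10, 10)
```

With cross-attention each CT token can only attend to MR tokens (a 4 × 6 weight matrix), whereas joint self-attention produces a 10 × 10 matrix whose diagonal blocks capture uni-modal relations and whose off-diagonal blocks capture cross-modal ones, in line with the explanation above.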

Models LV Myo RV LA RA AA PA WH
Base 0.9334 0.8596 0.8876 0.8932 0.8794 0.8239 0.8168 0.8706
CRNP 0.9324 0.8685 0.8644 0.9007 0.8957 0.9216 0.8225 0.8865
CRNP+CA 0.9323 0.8683 0.8802 0.9147 0.9116 0.9098 0.8194 0.8909
CRNP+SA 0.9356 0.8891 0.8814 0.9232 0.8987 0.9148 0.8277 0.8958
Table 5: Ablation study on the MMWHS dataset. The best result in each column is in bold.

4.5.2 Discussion of Different CRNP Fusion functions

Several fusion functions can be used to combine the CRNP outputs with the modality features (Sec. 3.3); we compare and discuss three of them in Tab. 6: (a) “Replace”, which naively replaces the original modality features with the uncertainty-map-attended features; (b) “Concat”, which concatenates the original modality features with the uncertainty-map-attended features; and (c) “Residual”, the default fusion strategy of the proposed CRNP, which adds the two feature tensors. This experiment is conducted on the MMWHS dataset, and the reported Dice scores are averaged over CT and MR. Each fusion function has strengths on different structures, but the “Residual” model performs best overall (4 of the 8 best results). Its advantage is most noticeable on RA, AA and PA, where it improves Dice by roughly one point or more over the other two functions.

Models LV Myo RV LA RA AA PA WH
Replace 0.9342 0.8688 0.8688 0.897 0.8812 0.9074 0.8128 0.8815
Concat 0.9327 0.8676 0.8798 0.9031 0.8781 0.9098 0.8042 0.8822
Residual 0.9324 0.8685 0.8644 0.9007 0.8957 0.9216 0.8225 0.8865
Table 6: Analysis of different fusion functions of CRNP on the MMWHS dataset. The best result in each column is in bold.
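The three fusion functions differ only in how the original features and the uncertainty-attended features are combined, which the hedged NumPy sketch below makes explicit. How CRNP derives attention weights from its uncertainty maps is defined in Sec. 3.3 and is not reproduced here; `uncertainty_attend` uses a sigmoid re-weighting purely as a placeholder assumption, and all names are hypothetical. Tensors use a channel-first (C, D, H, W) layout.

```python
import numpy as np


def uncertainty_attend(features: np.ndarray, uncertainty: np.ndarray) -> np.ndarray:
    # Placeholder attention (assumption): squash the uncertainty map to (0, 1)
    # with a sigmoid and re-weight the features voxel-wise, broadcasting over channels.
    weights = 1.0 / (1.0 + np.exp(-uncertainty))
    return features * weights


def fuse(features: np.ndarray, attended: np.ndarray, mode: str = "residual") -> np.ndarray:
    """Fusion functions compared in Tab. 6; inputs are (C, D, H, W)."""
    if mode == "replace":   # (a) discard the original features
        return attended
    if mode == "concat":    # (b) stack along the channel axis -> (2C, D, H, W)
        return np.concatenate([features, attended], axis=0)
    if mode == "residual":  # (c) CRNP default: element-wise addition
        return features + attended
    raise ValueError(f"unknown fusion mode: {mode}")


rng = np.random.default_rng(0)
feat = rng.normal(size=(16, 4, 8, 8))  # C=16 channels on a toy 4x8x8 volume
unc = rng.normal(size=(1, 4, 8, 8))    # one uncertainty map shared by all channels
att = uncertainty_attend(feat, unc)
for mode in ("replace", "concat", "residual"):
    print(mode, fuse(feat, att, mode).shape)
```

A practical consequence is visible in the shapes: “Concat” doubles the channel count, so the following layer must accept 2C channels, while “Replace” and “Residual” keep C channels; “Residual” additionally preserves an unmodified path for the original features.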

4.6 Visualization

Figure 3: Visualization experiments of CRNP. Sub Fig.(1) compares the segmentations of the proposed CRNP ((b) and (d)) with those of its Base model ((a) and (c)). Sub Fig.(2) shows the t-SNE plot of the in- and out-of-distribution data points produced by the cross-modal RNP module. Sub Fig.(3) shows the CRNP uncertainty heat-maps.

We also conduct a visualization experiment in Fig. 3, which shows the MMWHS segmentation results (Sub Fig.1), a t-SNE visualization of in- and out-of-distribution data points produced by the uncertainty maps of the RNP module on the MMWHS CT images (Sub Fig.2), and the CRNP uncertainty heat-maps for BraTS2020 images (Sub Fig.3). For the two validation cases shown in Sub Fig.(1), (a) and (c) are segmented by the Base model and (b) and (d) by CRNP. The color masks denote the segmentation results (e.g., pink) overlaid on the ground truth (e.g., purple), and the most obvious segmentation differences are highlighted by yellow boxes. Comparing the two models, CRNP produces better segmentations, especially along organ edges. This is mainly because organ edges contain more uncertain regions: CRNP can perceive uncertain regions within one modality and assign more weight to the other modality there, which alleviates segmentation uncertainty at organ edges. Moreover, Sub Fig.(2) visualizes the in- and out-of-distribution uncertainty maps with t-SNE. Following Han et al. [14], we treat the original features as in-distribution data and features corrupted by additive Gaussian noise as out-of-distribution data. These samples are fed into the cross-modal RNP modules to obtain uncertainty-map predictions, and t-SNE clearly separates these predictions into two clusters, which is further evidence that CRNP estimates uncertainty effectively. In Sub Fig.(3), we show the CRNP uncertainty heat-maps for a BraTS image; the maps are estimated in the feature space and mapped back to the original image space.
In this figure, (a), (c), (e) and (g) are the FLAIR, T1, T1ce and T2 modalities; (b), (d), (f) and (h) are the corresponding CRNP uncertainty maps (brighter pixel = higher uncertainty); and (i) and (j) are the ground-truth (GT) segmentation and the CRNP prediction. Note that the high-uncertainty regions are concentrated around the brain tumors. This is reasonable: tumors are sparsely represented in the feature space, so the outputs of RNP's random and prediction networks differ strongly there. Also note that the FLAIR image shows a stronger tumor signal than the other modalities, which produces larger uncertainty for the other modalities in the tumor areas. This larger uncertainty in turn signals the other modalities to pay more attention to these areas.
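The intuition that sparsely represented inputs yield large random-versus-prediction discrepancies can be checked with a small random network distillation experiment in the spirit of Burda et al. [3]. The sketch below is not CRNP's cross-modal module: it uses a fixed random tanh layer as the target, fits a linear predictor by least squares instead of training a predictor network by gradient descent, and mimics the out-of-distribution protocol above by adding Gaussian noise to the features. The dimensions, noise scale and function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 16, 32

# Fixed random target network: initialised once and never trained.
w_target = rng.normal(0.0, 1.0 / np.sqrt(d_in), size=(d_in, d_out))


def target_net(x: np.ndarray) -> np.ndarray:
    return np.tanh(x @ w_target)


def with_bias(x: np.ndarray) -> np.ndarray:
    return np.hstack([x, np.ones((len(x), 1))])


def fit_predictor(x: np.ndarray) -> np.ndarray:
    # Linear predictor fitted by least squares to imitate the target on training
    # data (a simplified stand-in for gradient training of a predictor network).
    coef, *_ = np.linalg.lstsq(with_bias(x), target_net(x), rcond=None)
    return coef


def rnp_score(coef: np.ndarray, x: np.ndarray) -> np.ndarray:
    # Per-sample squared prediction error, used as the uncertainty / novelty score.
    return ((with_bias(x) @ coef - target_net(x)) ** 2).mean(axis=1)


x_train = rng.normal(0.0, 0.5, size=(2000, d_in))  # dense "in-distribution" features
coef = fit_predictor(x_train)
x_in = rng.normal(0.0, 0.5, size=(500, d_in))
x_out = x_in + rng.normal(0.0, 3.0, size=x_in.shape)  # additive Gaussian noise -> OOD
err_in, err_out = rnp_score(coef, x_in), rnp_score(coef, x_out)
print(f"in-distribution: {err_in.mean():.4f}  noisy: {err_out.mean():.4f}")
```

Because the predictor only sees the dense, low-variance region of feature space, its error stays small there and grows sharply for the noise-corrupted features, a separation analogous to the one the t-SNE plot in Sub Fig.(2) visualizes for CRNP's uncertainty maps.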

5 Conclusions

In this paper, we proposed an Uncertainty-aware Multi-modal Learning model named Cross-modal Random Network Prediction (CRNP). CRNP measures the total uncertainty in the feature space of each modality to better guide multi-modal fusion. Technically, CRNP is the first approach to explore random network prediction for estimating uncertainty and fusing multi-modal data. Compared with a recent multi-modal approach that uses potentially unstable covariance measures to estimate uncertainty [26], CRNP has a stable training process, and it can be easily translated between different prediction tasks. Experiments on two medical image segmentation datasets and three computer vision classification datasets verify the effectiveness of the proposed CRNP model, and ablation and visualization studies further validate CRNP as an effective multi-modal analysis method.

References

  • [1] M. Abdar, F. Pourpanah, S. Hussain, D. Rezazadegan, L. Liu, M. Ghavamzadeh, P. Fieguth, X. Cao, A. Khosravi, U. R. Acharya, et al. (2021) A review of uncertainty quantification in deep learning: techniques, applications and challenges. Information Fusion 76, pp. 243–297.
  • [2] C. F. Baumgartner, K. C. Tezcan, K. Chaitanya, A. M. Hötker, U. J. Muehlematter, K. Schawkat, A. S. Becker, O. Donati, and E. Konukoglu (2019) PHiSeg: capturing uncertainty in medical image segmentation. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 119–127.
  • [3] Y. Burda, H. Edwards, A. Storkey, and O. Klimov (2018) Exploration by random network distillation. arXiv preprint arXiv:1810.12894.
  • [4] H. Chen, W. Xie, T. Afouras, A. Nagrani, A. Vedaldi, and A. Zisserman (2021) Localizing visual sounds the hard way. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 16867–16876.
  • [5] Y. Chen, Y. Xian, A. Koepke, Y. Shan, and Z. Akata (2021) Distilling audio-visual knowledge by compositional contrastive learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 7016–7025.
  • [6] Ö. Çiçek, A. Abdulkadir, S. S. Lienkamp, T. Brox, and O. Ronneberger (2016) 3D U-Net: learning dense volumetric segmentation from sparse annotation. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 424–432.
  • [7] C. Corbière, N. Thome, A. Bar-Hen, M. Cord, and P. Pérez (2019) Addressing failure prediction by learning model confidence. Advances in Neural Information Processing Systems 32.
  • [8] T. DeVries and G. W. Taylor (2018) Leveraging uncertainty estimates for predicting segmentation quality. arXiv preprint arXiv:1807.00502.
  • [9] Q. Dou, Q. Liu, P. A. Heng, and B. Glocker (2020) Unpaired multi-modal segmentation via knowledge distillation. IEEE Transactions on Medical Imaging.
  • [10] L. Fei-Fei and P. Perona (2005) A Bayesian hierarchical model for learning natural scene categories. In 2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR’05), Vol. 2, pp. 524–531.
  • [11] Y. Gal and Z. Ghahramani (2015) Bayesian convolutional neural networks with Bernoulli approximate variational inference. arXiv preprint arXiv:1506.02158.
  • [12] Y. Gal and Z. Ghahramani (2016) Dropout as a Bayesian approximation: representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059.
  • [13] J. Gawlikowski, C. R. N. Tassi, M. Ali, J. Lee, M. Humt, J. Feng, A. Kruspe, R. Triebel, P. Jung, R. Roscher, et al. (2021) A survey of uncertainty in deep neural networks. arXiv preprint arXiv:2107.03342.
  • [14] Z. Han, C. Zhang, H. Fu, and J. T. Zhou (2021) Trusted multi-view classification. arXiv preprint arXiv:2102.02051.
  • [15] J. Heo, H. B. Lee, S. Kim, J. Lee, K. J. Kim, E. Yang, and S. J. Hwang (2018) Uncertainty-aware attention for reliable interpretation and prediction. Advances in Neural Information Processing Systems 31.
  • [16] X. Jia, X. Jing, X. Zhu, S. Chen, B. Du, Z. Cai, Z. He, and D. Yue (2020) Semi-supervised multi-view deep discriminant representation learning. IEEE Transactions on Pattern Analysis and Machine Intelligence 43 (7), pp. 2496–2509.
  • [17] A. Jungo and M. Reyes (2019) Assessing reliability and challenges of uncertainty estimations for medical image segmentation. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 48–56.
  • [18] A. Kendall and Y. Gal (2017) What uncertainties do we need in Bayesian deep learning for computer vision? Advances in Neural Information Processing Systems 30.
  • [19] S. A. Kohl, B. Romera-Paredes, K. H. Maier-Hein, D. J. Rezende, S. Eslami, P. Kohli, A. Zisserman, and O. Ronneberger (2019) A hierarchical probabilistic U-Net for modeling multi-scale ambiguities. arXiv preprint arXiv:1905.13077.
  • [20] S. Kohl, B. Romera-Paredes, C. Meyer, J. De Fauw, J. R. Ledsam, K. Maier-Hein, S. Eslami, D. Jimenez Rezende, and O. Ronneberger (2018) A probabilistic U-Net for segmentation of ambiguous images. Advances in Neural Information Processing Systems 31.
  • [21] B. Lakshminarayanan, A. Pritzel, and C. Blundell (2017) Simple and scalable predictive uncertainty estimation using deep ensembles. Advances in Neural Information Processing Systems 30.
  • [22] Y. Li, L. Luo, H. Lin, H. Chen, and P. Heng (2021) Dual-consistency semi-supervised learning with uncertainty quantification for COVID-19 lesion segmentation from CT images. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 199–209.
  • [23] I. Loshchilov and F. Hutter (2016) SGDR: stochastic gradient descent with warm restarts. arXiv preprint arXiv:1608.03983.
  • [24] X. Luo, W. Liao, J. Chen, T. Song, Y. Chen, S. Zhang, N. Chen, G. Wang, and S. Zhang (2021) Efficient semi-supervised gross target volume of nasopharyngeal carcinoma segmentation via uncertainty rectified pyramid consistency. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 318–329.
  • [25] F. Milletari, N. Navab, and S. Ahmadi (2016) V-Net: fully convolutional neural networks for volumetric medical image segmentation. In 2016 Fourth International Conference on 3D Vision (3DV), pp. 565–571.
  • [26] M. Monteiro, L. Le Folgoc, D. Coelho de Castro, N. Pawlowski, B. Marques, K. Kamnitsas, M. van der Wilk, and B. Glocker (2020) Stochastic segmentation networks: modelling spatially correlated aleatoric uncertainty. Advances in Neural Information Processing Systems 33, pp. 12756–12767.
  • [27] D. Nie, L. Wang, Y. Gao, and D. Shen (2016) Fully convolutional networks for multi-modality isointense infant brain image segmentation. In 2016 IEEE 13th International Symposium on Biomedical Imaging (ISBI), pp. 1342–1345.
  • [28] I. Osband, J. Aslanides, and A. Cassirer (2018) Randomized prior functions for deep reinforcement learning. Advances in Neural Information Processing Systems 31.
  • [29] M. Patrick, Y. M. Asano, P. Kuznetsova, R. Fong, J. F. Henriques, G. Zweig, and A. Vedaldi (2020) Multi-modal self-supervision from generalized data transformations. arXiv preprint arXiv:2003.04298.
  • [30] M. Patrick, P. Huang, I. Misra, F. Metze, A. Vedaldi, Y. M. Asano, and J. F. Henriques (2021) Space-time crop & attend: improving cross-modal video representation learning. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 10560–10572.
  • [31] R. Robinson, O. Oktay, W. Bai, V. V. Valindria, M. M. Sanghvi, N. Aung, J. M. Paiva, F. Zemrak, K. Fung, E. Lukaschuk, et al. (2018) Real-time prediction of segmentation quality. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 578–585.
  • [32] M. Sensoy, L. Kaplan, and M. Kandemir (2018) Evidential deep learning to quantify classification uncertainty. Advances in Neural Information Processing Systems 31.
  • [33] V. V. Valindria, N. Pawlowski, M. Rajchl, I. Lavdas, E. O. Aboagye, A. G. Rockall, D. Rueckert, and B. Glocker (2018) Multi-modal learning from unpaired images: application to multi-organ segmentation in CT and MRI. In 2018 IEEE Winter Conference on Applications of Computer Vision (WACV), pp. 547–556.
  • [34] C. Wah, S. Branson, P. Welinder, P. Perona, and S. Belongie (2011) The Caltech-UCSD Birds-200-2011 dataset.
  • [35] H. Wang, Q. Wu, and C. Shen (2020) Soft expert reward learning for vision-and-language navigation. In European Conference on Computer Vision, pp. 126–141.
  • [36] K. Wang, B. Zhan, C. Zu, X. Wu, J. Zhou, L. Zhou, and Y. Wang (2021) Tripled-uncertainty guided mean teacher model for semi-supervised medical image segmentation. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 450–460.
  • [37] L. Wang, L. Ju, D. Zhang, X. Wang, W. He, Y. Huang, Z. Yang, X. Yao, X. Zhao, X. Ye, et al. (2021) Medical matting: a new perspective on medical segmentation with uncertainty. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 573–583.
  • [38] W. Wang, C. Chen, M. Ding, H. Yu, S. Zha, and J. Li (2021) TransBTS: multimodal brain tumor segmentation using transformer. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 109–119.
  • [39] Y. Wang, W. Huang, F. Sun, T. Xu, Y. Rong, and J. Huang (2020) Deep multimodal fusion by channel exchanging. Advances in Neural Information Processing Systems 33, pp. 4835–4845.
  • [40] Y. Wang, Y. Zhang, F. Hou, Y. Liu, J. Tian, C. Zhong, Y. Zhang, and Z. He (2020) Modality-pairing learning for brain tumor segmentation. In International MICCAI Brainlesion Workshop, pp. 230–240.
  • [41] X. Zhuang, L. Li, C. Payer, D. Štern, M. Urschler, M. P. Heinrich, J. Oster, C. Wang, Ö. Smedby, C. Bian, et al. (2019) Evaluation of algorithms for multi-modality whole heart segmentation: an open-access grand challenge. Medical Image Analysis 58, pp. 101537.