Lung cancer is one of the most dangerous diseases; the overall five-year survival rate for lung cancer (LC) is below 20%. Most lung cancers fall into two broad histological subtypes: non-small cell lung cancer (NSCLC) and small cell lung cancer (SCLC). Compared with SCLC, NSCLC accounts for the majority of diagnoses and is less aggressive. NSCLC grows and spreads more slowly than SCLC and causes few or no symptoms until it is advanced; as a result, it is usually not detected until a later stage. Lung cancer has caused millions of deaths among both women and men. Survival analysis, or prognostication, of lung cancer models the time from the beginning of follow-up until the occurrence of an event of interest (death). A survival model estimates how the cancer is likely to develop and can reveal the relationship between prognostic factors and the disease. With accurate prognostic models, doctors can determine the most likely course of a patient's cancer. To improve predictive accuracy, automate NSCLC survival analysis, and help medical experts develop precise treatment plans, we explore a novel method for NSCLC survival analysis.
Traditional statistical methods for survival analysis leverage structured data from comprehensive medical examinations of a patient. They mainly include time-to-event modelling tools such as the Cox proportional hazards (Cox-PH) model [cox1972regression], the accelerated failure time (AFT) model [2020ACCELERATED], and the Kaplan-Meier estimator [kaplan1958nonparametric]. In addition, machine learning-based survival analysis is a popular branch, including survival trees [TM2004Bagging, 2013randomSurvivalForest], Bayesian methods [2002Bayesian, 2011Bayesian], and support vector regression [2008A, 2008Support]. Many of these models, such as Cox-PH, assume that the hazard ratio between two subjects is constant over time [kvamme2019continuous], and they estimate either a risk score or the time-to-event distribution. However, in clinical practice the interactions between covariates are complex, and all of these methods rely only on structured data, overlooking the rich information in unstructured data such as CT scans. Moreover, medical experts usually have to devote extensive effort to designing hand-crafted features. Recently, some works have used pathological images and demonstrated the effectiveness of image features. However, pathological images require a lung biopsy, which is invasive and carries potential health risks such as pneumothorax, bleeding, infection, and systemic air embolism. Such a biopsy is typically performed as a consequence of abnormal results from non-invasive computed tomography (CT) screening for lung cancer.
Artificial intelligence has progressed rapidly in the last decade. With the development of deep learning techniques, it has achieved remarkable success in various research fields such as natural language processing and computer vision. As a cutting-edge technology, deep learning has great potential for medical diagnostics. Some of the most innovative deep learning methods have been successfully applied to lung cancer diagnosis from CT images, and their performance has even exceeded that of human experts [esteva2017dermatologist] [litjens2017survey] [xing2017deep] [ILSVRC15]. However, existing deep learning-based methods consider only structured data or only visual cues. In contrast, medical experts often consider clinical data and visual information such as CT images together to make a comprehensive decision. Consequently, the predictions of current methods are not reliable enough. To address this weakness, and inspired by the success of deep learning, we propose a novel multimodal paradigm for lung cancer survival analysis.
To successfully build such a multimodal network, the first challenge is to encode task-friendly features from the different modalities. Clinical data appear as discrete symbols, but these symbols are highly correlated with one another. Therefore, to capture correlations among different items, we propose a light transformer network for processing the textual clinical data. Its core building block is multi-head self-attention [DBLP:conf/nips/VaswaniSPUJGKP17], which can correlate different disease factors to capture more information.
CT slices contain rich spatial and temporal information. In our previous work [DBLP:conf/smc/WuMHLS21], we adopted a 3D ResNet as the backbone for feature extraction. However, we found that there is too much redundant information in both the spatial and temporal dimensions, which prevents the model from focusing on the most important components of the visual information. To alleviate this problem, we propose, for the first time, a 3D channel SE block and a 3D temporal SE block. Both blocks are integrated into the original residual module to form an architecture specifically for NSCLC prognosis, namely ProSENet. Moreover, we observe that pixels in adjacent CT slices are similar or identical in most cases. To complement ProSENet, we further propose a frame difference mechanism, which creates two additional CT inputs by subtracting adjacent slices in two directions; this proved to be an effective strategy in our practice.
Finally, considering the above, we develop the first multimodal network for NSCLC survival analysis, which takes deep learning-based NSCLC survival analysis one step forward by simultaneously considering textual clinical data and visual CT cues. As shown in Fig. 1, our network follows a two-tower paradigm, i.e., a clinical tower and a visual tower. The clinical tower encodes the clinical data, while the visual tower extracts the visual representation from the CT images. Finally, the prediction head fuses the cross-modal features and predicts the survival time.
In summary, the key contributions of Lite-ProSENet are as follows:
The first application of a two-tower DNN to survival analysis of NSCLC using structured data and CT images simultaneously.
The first application of a transformer and 3D SE blocks to multimodal clinical data for disease prognosis.
Results on benchmark and real-world clinical datasets demonstrate that Lite-ProSENet outperforms SOTA methods by a substantial margin.
The rest of this paper is organized as follows. In Section 2, we present related work on survival analysis of NSCLC, including traditional methods and deep learning-based approaches. In Section 3, we elaborate on the details of the proposed Lite-ProSENet. In Section 4, we present the experiments and ablation studies. In Section 5, we discuss several choices made when building our network and the associated hyper-parameters. Finally, Section 6 gives the conclusion and future work.
2 Related Works
In this section, we give an overview of traditional statistical methods and deep neural network-based methods, and then highlight their relation to our contributions.
2.1 Statistical Methods
Conventional statistical methods for NSCLC survival analysis use only the textual modality and model the time to an event. They can be divided into three types: non-parametric, semi-parametric, and parametric methods. Kaplan-Meier (KM) analysis [2016An] is a typical non-parametric approach to survival outcomes. KM analysis is suitable for small datasets and gives accurate estimates, but it cannot include multiple variables. The life table [1972Regression] is a simple statistical method appropriate for large datasets and has been successfully applied to European lung cancer patients [janssen1998variation]. The Nelson-Aalen (NA) estimator [2013The] is a non-parametric estimator of the cumulative hazard function (CHF) for censored data; it directly estimates the cumulative hazard rather than the survival probability. Semi-parametric methods do not require the distribution of survival times to be specified. For example, the Cox regression model is used in [port2003tumor] to identify the critical factors that most strongly affect lung cancer survival. The Cox proportional hazards model [1972Regression] is the most commonly used model in survival analysis, and its baseline hazard function is left unspecified. CoxBoost can be applied to high-dimensional data to fit sparse survival models; unlike the regular gradient boosting approach (RGBA), CoxBoost can update each step with a flexible set of candidate variables [2008Allowing]. Parametric methods are easy to interpret and can provide more efficient and accurate results when the survival time follows a known distribution, but they can give inconsistent and sub-optimal results when the distributional assumption is violated. The Tobit model [1956Estimation] and Buckley-James (BJ) regression [1979Linear, 2008Doubly] use least squares as the empirical loss function and can be applied to high-dimensional survival data; BJ regression is an accelerated failure time model. Bayesian survival analysis [2011Bayesian, 2015Bayesian, 2015Survival] encodes assumptions via prior distributions.
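To make the two non-parametric estimators above concrete, the following pure-Python sketch computes the Kaplan-Meier survival curve, S(t) = Π (1 - d_i / n_i), and the Nelson-Aalen cumulative hazard, H(t) = Σ d_i / n_i, where d_i is the number of deaths and n_i the number of patients at risk at each event time t_i ≤ t. It is an illustrative sketch on a toy cohort, not code from any of the cited works; the function name `km_nelson_aalen` is hypothetical.

```python
from collections import Counter

def km_nelson_aalen(times, events):
    """Return (t, S(t), H(t)) at each distinct observed event time.

    times  -- observed follow-up times (event or censoring)
    events -- 1 if death was observed at that time, 0 if censored
    """
    deaths = Counter(t for t, e in zip(times, events) if e == 1)
    surv, cum_hazard, curve = 1.0, 0.0, []
    for t in sorted(deaths):
        # Patients censored exactly at t are still counted as at risk at t.
        at_risk = sum(1 for u in times if u >= t)
        d = deaths[t]
        surv *= 1.0 - d / at_risk      # Kaplan-Meier: product of conditional survival
        cum_hazard += d / at_risk      # Nelson-Aalen: sum of hazard increments
        curve.append((t, surv, cum_hazard))
    return curve

# Toy cohort: five patients, two of them censored.
curve = km_nelson_aalen([2, 3, 3, 5, 8], [1, 1, 0, 1, 0])
for t, s, h in curve:
    print(f"t={t}: S={s:.3f}, H={h:.3f}")
# t=2: S=0.800, H=0.200
# t=3: S=0.600, H=0.450
# t=5: S=0.300, H=0.950
```

The patient censored at t = 3 still counts as at risk at t = 3 but contributes no death, which is how both estimators use censored records.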
2.2 DNN based Methods
Image-based techniques for survival analysis of lung cancer normally adopt histopathological images. The work of [DBLP:conf/miccai/YaoWZH16] was the first to use a deep learning approach to classify cell subtypes; it found that survival models built from clinical imaging biomarkers had better predictive power than methods using molecular profiling data and traditional imaging biomarkers. Using machine learning methods, H. Wang et al. proposed a framework [2014Novel] and found a set of diagnostic image markers highly correlated with NSCLC subtype classification. Kun-Hsing Yu et al. [DBLP:conf/amia/YuZBARRS17] extracted 9,879 quantitative image features and used regularised machine learning methods to distinguish short-term survivors from long-term survivors. Xinliang Zhu et al. [DBLP:conf/bibm/ZhuYH16] proposed, for the first time, a deep convolutional neural network for survival analysis (DeepConvSurv) with pathological images. The deep layers of this model can represent more abstract information, and no hand-crafted image features are required. However, these methods cannot learn discriminative patterns directly from whole slide histopathological images, and some of them predict the survival status of patients using hand-crafted features extracted from manually labelled small discriminative patches. The work of [2017WSISA] proposed, for the first time, an annotation-free method for survival prediction based on whole slide histopathological images.
In summary, traditional statistical methods tend to use textual data with limited information. In recent years, with the development of deep learning, more works have begun to explore methods that use histopathology images; however, obtaining these images is invasive. One work uses CT images, but relies on hand-crafted features that require guidance from medical experts. Moreover, all these methods use only a single modality and ignore the complementary information across modalities. Therefore, to capture the complex relationships between multimodal medical test results and NSCLC survival time, we propose a non-invasive, fully automated DNN method to improve the prediction accuracy of NSCLC prognosis.
In this work, we propose a multimodal deep learning framework built on a 3D ResNet for better individualised survival prediction of NSCLC, which exploits information gathered from multiple data sources, namely CT images and clinical information.
The proposed method is a two-tower model. In this section, we describe the details of the model for NSCLC prognosis.
3.1 The Architecture of Lite-ProSENet
Clinical data and visual CT images both contain rich information but lie in different spaces; as a result, information from the two modalities cannot be integrated directly into a comprehensive representation. To perform effective feature fusion and alignment, we design our model as a two-tower architecture, whose effectiveness has been well validated in cross-modal learning [2019Sampling, 2016Learning, 2016Smart, 2018Learning]. Figure 2 gives an overall illustration of our framework: the proposed Lite-ProSENet contains two towers, i.e., the Lite-Transformer and ProSENet. Each sample is composed of clinical data, CT images, and a survival time. The clinical data are first fed into an embedding layer to obtain dense representations, which then pass through the Lite-Transformer to produce effective clinical features.
The CT images are fed into ProSENet for feature extraction. The subsequent prediction module fuses the features from the two modalities and predicts the survival time from the multimodal feature. Finally, the parameters of the two towers are jointly optimized by minimizing the distance between the predicted and ground-truth survival times. In the following, we describe each component of our network in detail.
3.2 Light Transformer
As shown in Fig. 4, the raw items in the clinical data are first fed into an embedding layer to obtain dense representations, which are then fed into the multi-head self-attention layers to produce the clinical features.
Clinical Embedding. A piece of clinical data usually contains several items, $c = \{c_1, c_2, \dots, c_n\}$, where each $c_i$ is one kind of clinical item. To better represent the raw clinical data, we assign each clinical item a dense feature using the popular embedding technique. We first represent each item by its one-hot encoding, and then a matrix projects this initial representation to a dense feature:
$$\mathbf{e}_i = \mathbf{W}_e^{\top}\,\mathrm{one\_hot}(c_i),$$
where one_hot(·) is the function that projects an item to a one-hot vector, $\mathbf{W}_e \in \mathbb{R}^{V \times d}$ is the learnable mapping weight, $V$ is the item vocabulary size, and $d$ is the dimension of the dense representation. For notational simplicity, we hereafter use $c_i$ to denote the dense vector $\mathbf{e}_i$ of item $c_i$.
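Multiplying a one-hot vector by the embedding matrix simply selects one row of that matrix, which is why deep learning frameworks implement this projection as a table lookup. The pure-Python sketch below illustrates this equivalence; the item vocabulary, the embedding size d = 4, and the names `VOCAB` and `embed` are hypothetical and not taken from the dataset.

```python
import random

# Hypothetical vocabulary of clinical items ("attribute=value" tokens).
VOCAB = {"gender=male": 0, "gender=female": 1,
         "stage=I": 2, "stage=II": 3, "stage=IIIa": 4, "stage=IIIb": 5}

def one_hot(index, vocab_size):
    vec = [0.0] * vocab_size
    vec[index] = 1.0
    return vec

def embed(item, W):
    """Dense feature e = W^T one_hot(item), with W of shape V x d."""
    x = one_hot(VOCAB[item], len(W))
    d = len(W[0])
    return [sum(x[v] * W[v][j] for v in range(len(W))) for j in range(d)]

random.seed(0)
V, d = len(VOCAB), 4                        # toy sizes
W = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(V)]
patient = ["gender=male", "stage=IIIa"]
E = [embed(item, W) for item in patient]    # one d-dimensional row per clinical item
```

Each row of `E` equals the row of `W` indexed by the item, so in practice the explicit one-hot product can be replaced by indexing.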
Multi-Head Self-Attention. We adopt multi-head self-attention in our model, which allows the model to jointly attend to information from different representation subspaces at different positions. Multi-head attention extends self-attention by repeating the attention mechanism several times in parallel.
For each head, the Queries, Keys, and Values are generated by fully-connected layers. Fig. 3 illustrates the whole self-attention process. Let $\mathbf{E} \in \mathbb{R}^{n \times d}$ be the matrix formed by the $n$ item embeddings of the clinical data; the output of self-attention can then be expressed as
$$\mathrm{SA}(\mathbf{E}) = \mathrm{softmax}\!\left(\frac{(\mathbf{E}\mathbf{W}_Q)(\mathbf{E}\mathbf{W}_K)^{\top}}{\sqrt{d}}\right)\mathbf{E}\mathbf{W}_V,$$
where $\mathbf{W}_Q$, $\mathbf{W}_K$, $\mathbf{W}_V$ are learnable parameters and $d$ is the embedding dimension. Taking self-attention (SA) as the basic block, multi-head self-attention (MSA) repeats SA several times and concatenates the outputs of the different heads. Finally, our Lite-Transformer stacks $L$ MSA layers, each combined with layer normalization (LN), and the output of the last layer gives the final clinical features.
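The following pure-Python sketch spells out scaled dot-product self-attention and the head concatenation described above. It omits learned output projections and layer normalization, and scales scores by √d_k, the per-head key dimension, as in the original Transformer. Matrix sizes, the toy weights, and all function names are illustrative assumptions; a real implementation would use a tensor library (for example, PyTorch's `nn.MultiheadAttention`).

```python
import math

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(row):
    m = max(row)                                  # subtract max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(E, Wq, Wk, Wv):
    """SA(E) = softmax(Q K^T / sqrt(d_k)) V with Q = E Wq, K = E Wk, V = E Wv."""
    Q, K, V = matmul(E, Wq), matmul(E, Wk), matmul(E, Wv)
    d_k = len(Wq[0])
    scores = matmul(Q, [list(col) for col in zip(*K)])    # Q K^T, shape n x n
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, V)

def multi_head(E, heads):
    """Apply each head's SA and concatenate the outputs along the feature axis."""
    outs = [self_attention(E, *h) for h in heads]
    return [sum((o[i] for o in outs), []) for i in range(len(E))]

E = [[1.0, 0.0, 2.0, -1.0],
     [0.5, 1.0, 0.0, 0.0],
     [0.0, -1.0, 1.0, 1.0]]          # n = 3 clinical items, d = 4 (toy values)
FIRST = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]   # projects onto features 0-1
LAST = [[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]    # projects onto features 2-3
heads = [(FIRST, FIRST, FIRST), (LAST, LAST, LAST)]         # 2 heads, d_k = 2 each
Z = multi_head(E, heads)                                     # shape 3 x 4
```

Each attention row is a probability distribution over items, so every output row is a weighted mix of the value vectors of all clinical items; this is how self-attention lets one disease factor draw on the others.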
ProSENet learns representations of unstructured data through a 3D ResNet-based network with several repeated 3D SE-ResBlocks, each composed of a residual block followed by a channel SE block and a temporal SE block. This design improves the representational power of ProSENet. The key contributions of ProSENet lie in the 3D channel SE block and the temporal SE block; we elaborate on these two modules below.
Channel SE-block. The channel SE block aims to produce a compact feature via a squeeze-and-excitation operation along the channel dimension. Given an arbitrary feature map, the channel SE block first performs a "squeeze" operation, pooling the feature map over all dimensions except the channel dimension to obtain one descriptor per channel.
The excitation operation then applies two fully-connected layers to model interactions between channels, followed by a sigmoid that produces an information filter.
This filter serves as a gate for information selection: each channel of the feature map is rescaled by its corresponding gate value.
The common SE block retains only the channel information and squeezes all other dimensions; consequently, the important temporal information across the 3D slices is lost. To address this weakness, we augment the naïve SE block with a temporal excitation. First, the temporal dimension is retained when pooling the feature, so that pooling yields one channel descriptor per frame. We then produce a channel gate for each frame, computed by Eq. 7.
We share the weights of the two fully-connected layers when producing the gates in each channel SE block. Weight sharing serves two purposes: first, it propagates information across different views; second, it yields a lighter network with fewer parameters. In our practice, sharing parameters also improves performance.
Finally, we fuse the two types of gates to form our full channel SE block. The resulting final gate combines the local (per-frame) view and the global view to perform more reliable information filtering.
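The channel SE block with its global and per-frame (local) gates can be sketched as follows. This is a minimal pure-Python illustration under stated assumptions: the squeeze is global average pooling, the two fully-connected layers have a ReLU between them (as in the standard SE block), the same weights W1 and W2 are shared by the global and local gates, and the two gates are fused by a simple average. The average is a placeholder only, since the exact fusion equation is not reproduced here. Feature maps are nested lists indexed [channel][frame][spatial position], and all names are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def excite(z, W1, W2):
    """FC (C -> C/r), ReLU, FC (C/r -> C), sigmoid: one gate value per channel."""
    h = [max(0.0, sum(w * zi for w, zi in zip(row, z))) for row in W1]
    return [sigmoid(sum(w * hi for w, hi in zip(row, h))) for row in W2]

def channel_se(U, W1, W2, fuse=lambda g, loc: 0.5 * (g + loc)):
    """U is a feature map indexed [channel][frame][spatial position]."""
    C, T, S = len(U), len(U[0]), len(U[0][0])
    # Global view: squeeze frames and space, giving one descriptor per channel.
    z_global = [sum(sum(frame) for frame in U[c]) / (T * S) for c in range(C)]
    g_global = excite(z_global, W1, W2)
    # Local view: keep the frame axis; the shared W1, W2 give a channel gate per frame.
    g_local = [excite([sum(U[c][t]) / S for c in range(C)], W1, W2) for t in range(T)]
    # Joint gate (placeholder average fusion) rescales every channel of every frame.
    return [[[fuse(g_global[c], g_local[t][c]) * x for x in U[c][t]]
             for t in range(T)] for c in range(C)]

# Toy map with C = 2 channels, T = 3 frames, S = 2 spatial positions; r = 2.
U = [[[1.0, 2.0], [0.0, 0.0], [3.0, 1.0]],
     [[0.5, 0.5], [2.0, 2.0], [0.0, 1.0]]]
W1 = [[0.4, -0.2]]           # C -> C/r
W2 = [[1.0], [-1.0]]         # C/r -> C
U_filtered = channel_se(U, W1, W2)
```

Because every gate lies in (0, 1), the block can only attenuate features; the local gates let the attenuation differ from frame to frame, which a plain SE block cannot do.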
Temporal SE-block. Similar to the channel SE block, the temporal SE block aims to filter out redundant information, but along the temporal dimension. As shown in Figure 5(b), its computation is analogous to that of the channel SE block. Pooling is performed over the channel-spatial dimensions and over the spatial dimensions, which provide global and local frame information, respectively. Two types of gates are then produced following Eq. 10 and Eq. 7 and fused via Eq. 12 into a joint gate. Finally, the temporal SE block applies this joint gate to filter the feature map along the temporal dimension.
In our 3D SE block, the channel SE and temporal SE blocks are stacked to filter information along the channel and temporal dimensions; combined with the well-known 3D ResBlock, they form our 3D SE-ResBlock. The CT slices are fed into ProSENet to extract multi-dimensional features, which then pass through 3D global average pooling to obtain the final visual feature.
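The temporal SE block and the final 3D global average pooling can be sketched in the same style. For brevity the block is applied directly to a toy map rather than to the output of a channel SE block, as it would be in the stacked 3D SE-ResBlock. The same assumptions apply: average pooling, a ReLU between the two fully-connected layers (T -> T/r -> T), weights shared between the global and local gates, and a placeholder average fusion; all names are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def excite(z, W1, W2):
    """FC (T -> T/r), ReLU, FC (T/r -> T), sigmoid: one gate value per frame."""
    h = [max(0.0, sum(w * zi for w, zi in zip(row, z))) for row in W1]
    return [sigmoid(sum(w * hi for w, hi in zip(row, h))) for row in W2]

def temporal_se(U, W1, W2, fuse=lambda g, loc: 0.5 * (g + loc)):
    """U is a feature map indexed [channel][frame][spatial position]."""
    C, T, S = len(U), len(U[0]), len(U[0][0])
    # Global frame information: pool over channels and space.
    z_global = [sum(sum(U[c][t]) for c in range(C)) / (C * S) for t in range(T)]
    g_global = excite(z_global, W1, W2)
    # Local frame information: pool over space only, separately for each channel.
    g_local = [excite([sum(U[c][t]) / S for t in range(T)], W1, W2) for c in range(C)]
    return [[[fuse(g_global[t], g_local[c][t]) * x for x in U[c][t]]
             for t in range(T)] for c in range(C)]

def global_avg_pool(U):
    """3D global average pooling: one value per channel."""
    return [sum(sum(frame) for frame in ch) / (len(ch) * len(ch[0])) for ch in U]

# Toy map with C = 2, T = 4, S = 2; r = 2, so the hidden layer has T/r = 2 units.
U = [[[float(c + t + s) for s in range(2)] for t in range(4)] for c in range(2)]
W1 = [[0.1, 0.2, -0.1, 0.3], [0.0, -0.2, 0.1, 0.1]]       # T -> T/r
W2 = [[0.5, -0.5], [0.2, 0.1], [-0.3, 0.4], [0.1, 0.1]]   # T/r -> T
feature = global_avg_pool(temporal_se(U, W1, W2))         # one value per channel
```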
3.4 Multimodal feature fusion and prediction
Given the clinical feature $\mathbf{f}_c$ from the Lite-Transformer and the CT image feature $\mathbf{f}_v$ from ProSENet, the next task is to fuse the multimodal features and predict the survival time $\hat{t}$. Thanks to the powerful features from the Lite-Transformer and ProSENet, we simply concatenate the cross-modal features and predict the survival time with an MLP, which already yields encouraging performance in our practice:
$$\hat{t} = \mathrm{MLP}\big([\mathbf{f}_c;\,\mathbf{f}_v]\big),$$
where MLP(·) is a two-layer fully-connected network and $[\cdot\,;\cdot]$ concatenates the two input vectors.
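A minimal sketch of this prediction head, assuming a ReLU hidden layer and a linear scalar output (the activations are not stated in the text). The weights and the name `fusion_head` are hypothetical toy values chosen so the result is easy to check by hand.

```python
def fusion_head(f_clin, f_vis, W1, b1, w2, b2):
    """Predict survival time from concatenated clinical and visual features.

    Assumed two-layer MLP: hidden = ReLU(W1 x + b1), output = w2 . hidden + b2.
    """
    x = f_clin + f_vis                       # list concatenation = [f_clin; f_vis]
    hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b)
              for row, b in zip(W1, b1)]
    return sum(w * h for w, h in zip(w2, hidden)) + b2

f_clin = [1.0, 2.0]          # toy clinical feature
f_vis = [3.0]                # toy visual feature
t_hat = fusion_head(f_clin, f_vis,
                    W1=[[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]], b1=[0.0, 0.0],
                    w2=[2.0, 5.0], b2=0.5)
print(t_hat)                 # 2.5
```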
Enhance Prediction via Frame Difference. We observe that although the CT images contain rich information, many pixels are duplicated between consecutive CT slices, which hinders ProSENet from perceiving the key information among them. To remedy this issue, we propose a simple yet effective mechanism, frame difference, which subtracts two consecutive slices so that duplicated pixels are suppressed in the resulting slice. Following this idea, we perform frame difference along two directions, forward and backward. Our visual information therefore contains three types: the raw CT slices and the frame differences along the forward and backward directions. We feed each type of visual information, together with the clinical data, into Lite-ProSENet, which yields three survival-time predictions. Finally, we integrate the three predictions into the final survival-time prediction (Eq. 15) using a trade-off weight λ that balances the prediction from the raw CT slices against the predictions from the two frame-difference inputs: λ = 1 uses only the raw CT slices, and λ = 0 uses only the frame differences.
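The sketch below illustrates frame difference on a toy volume and one possible form of the final integration. Two details are assumptions, not confirmed by the text: the backward difference is taken as the forward difference of the reversed slice stack, and Eq. 15 is modelled as λ times the raw prediction plus (1 - λ) times the mean of the forward and backward predictions, which matches the stated behaviour at λ = 0 and λ = 1. The source does not say how the T - 1 difference slices are padded back to the original depth. All function names are hypothetical.

```python
def forward_difference(volume):
    """Slice-wise X[t+1] - X[t] for a volume indexed [slice][row][col]; T-1 slices."""
    return [[[b - a for a, b in zip(row0, row1)] for row0, row1 in zip(s0, s1)]
            for s0, s1 in zip(volume, volume[1:])]

def backward_difference(volume):
    """Assumed backward direction: the same difference on the reversed slice order."""
    return forward_difference(volume[::-1])

def integrate(t_raw, t_fwd, t_bwd, lam=0.4):
    """Hypothetical form of Eq. 15: lam * raw + (1 - lam) * mean of the two differences."""
    return lam * t_raw + (1.0 - lam) * 0.5 * (t_fwd + t_bwd)

# Toy volume with three 2 x 2 slices; most pixels repeat between slices.
vol = [[[1, 2], [3, 4]],
       [[1, 2], [3, 5]],
       [[0, 2], [3, 5]]]
print(forward_difference(vol))    # [[[0, 0], [0, 1]], [[-1, 0], [0, 0]]]
print(backward_difference(vol))   # [[[1, 0], [0, 0]], [[0, 0], [0, -1]]]
```

Pixels that do not change between slices become zero, so the difference volumes highlight only the regions where the tumour appearance changes along the slice axis.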
3.5 Network optimization
With the final prediction and the ground-truth survival time, the network parameters are learned by minimizing the distance between them (Eq. 16). The objective averages this distance, measured with a norm, over a training batch of size N, and adds a second term, a norm penalty on all network parameters Θ weighted by the trade-off hyper-parameter β, which is introduced to avoid overfitting.
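A minimal sketch of this objective, assuming the distance is the absolute error |t_hat - t| and the parameter penalty is the squared L2 norm; the text does not state which norms Eq. 16 uses, and the name `survival_loss` is hypothetical. In a deep learning framework the penalty is often realized as weight decay instead of an explicit term.

```python
def survival_loss(pred, target, params, beta=0.001):
    """Batch mean of |t_hat - t| plus beta * squared L2 norm of the parameters.

    Assumption: the source does not specify the norms; L1 data term and
    squared-L2 penalty are chosen here for illustration.
    """
    n = len(pred)
    data_term = sum(abs(p - t) for p, t in zip(pred, target)) / n
    penalty = sum(w * w for w in params)
    return data_term + beta * penalty

loss = survival_loss(pred=[1.0, 2.0], target=[1.5, 1.0], params=[1.0, -2.0])
```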
We conduct extensive experiments on NSCLC patients from TCIA to compare our method with several state-of-the-art methods in terms of the accuracy of the predicted survival time for each patient. We also evaluate the predictions by concordance. Afterwards, we perform several ablation experiments on different network structures to determine the best structure.
In this work, we consider 422 NSCLC patients from TCIA to assess the proposed framework. For these patients, pretreatment CT scans, manual delineation of the 3D gross tumor volume by a radiation oncologist, and clinical outcome data are available [clark2013cancer]. The corresponding clinical data are also available in the same collection. Patients with neither survival time nor event status were excluded from this work.
4.2 Data Preprocessing
For CT images, we resize the raw data, i.e., the 3D volume of the primary gross tumor, to a fixed size. We then linearly rescale the intensity range to [0, 1]. Next, to prevent overfitting, we perform data augmentation with three methods: rotation, swapping, and flipping. The resulting samples include both uncensored and censored samples.
Clinical data contain categorical and non-categorical data. Missing values are a common problem in medical data and may complicate data analysis and modeling. In our dataset, the 'age' attribute contains a few missing values. Since the ages of the patients in the dataset are close to each other, we impute the missing values with the mean age. Afterwards, to fit our model, we use a one-hot encoder to encode the categorical data as numbers, which makes the representation of categorical data more expressive.
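The two clinical preprocessing steps above, mean imputation and one-hot encoding, can be sketched in a few lines of Python. Missing values are represented as `None` and the example ages and stage labels are made-up toy values; the function names are hypothetical. In practice the imputation mean should be computed on the training split only, so that test data do not leak into preprocessing.

```python
def impute_mean(values):
    """Replace missing entries (None) with the mean of the observed entries."""
    observed = [v for v in values if v is not None]
    mean = sum(observed) / len(observed)
    return [mean if v is None else v for v in values]

def one_hot_column(values):
    """Encode a categorical column; returns the 0/1 rows and the category order."""
    categories = sorted(set(values))
    rows = [[1 if v == c else 0 for c in categories] for v in values]
    return rows, categories

ages = impute_mean([62.0, None, 70.0, 66.0])        # None becomes 66.0
stage_rows, stages = one_hot_column(["II", "IIIa", "I", "IIIa"])
```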
Then, we use the min-max feature scaling method and the standard score method to normalize data such as age and survival time. For an input $x$, the output of min-max feature scaling is
$$x' = \frac{x - \min(x)}{\max(x) - \min(x)},$$
and the output of the standard score method is
$$x' = \frac{x - \mu}{\sigma},$$
where $\mu$ is the mean and $\sigma$ is the standard deviation.
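Both normalizations can be sketched with the standard library. Two choices here are assumptions: the population standard deviation (`statistics.pstdev`) is used for σ, and constant columns are mapped to zeros to avoid division by zero. As with imputation, the statistics should come from the training split.

```python
import statistics

def min_max(values):
    lo, hi = min(values), max(values)
    if hi == lo:                              # constant column: avoid division by zero
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]

def z_score(values):
    mu = statistics.mean(values)
    sigma = statistics.pstdev(values)         # population std (assumption)
    if sigma == 0:
        return [0.0 for _ in values]
    return [(v - mu) / sigma for v in values]

print(min_max([2, 4, 6]))                     # [0.0, 0.5, 1.0]
z = z_score([2.0, 4.0, 6.0])                  # zero mean, unit population std
```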
For a patient with multiple tumors, we process only the primary gross tumor volume ('GTV-1'); other tumors that were occasionally present, such as secondary tumor volumes denoted 'GTV2' and 'GTV3', were not considered in our work.
4.3 Experiment Setup
We train and evaluate the framework on the NSCLC-Radiomics dataset using 5-fold cross-validation with a patient-level split, dividing the data into training, validation, and testing sets at a ratio of 6:2:2. In the Lite-Transformer, the number of heads in MSA is set to 3 and the number of layers to 5; in our practice, more layers and heads brought limited performance gains at the cost of many more parameters. In ProSENet, the reduction ratios of the channel and temporal SE blocks are both set to 2. Hyperparameters such as the penalty coefficient are tuned on the validation set. We set the trade-off weight λ in Eq. 15 and the penalty coefficient β in Eq. 16 to 0.4 and 0.001, respectively, to balance the different terms. We train for 800 epochs in total with Adam as the optimizer and a batch size of 64. The initial learning rate is 0.001 and is decayed by a factor of 0.5 every 40 epochs.
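The patient-level split and the learning-rate schedule can be sketched as below. The fold rotation (fold r for testing, fold r + 1 for validation, the remaining three for training, giving 6:2:2) is one assumed way to realize the stated setup; splitting by patient ID keeps all augmented samples of a patient in the same partition. The step schedule matches what PyTorch's `torch.optim.lr_scheduler.StepLR` computes with `step_size=40, gamma=0.5`. All names are hypothetical.

```python
import random

def patient_folds(patient_ids, k=5, seed=0):
    """Assign each patient (not each sample) to one of k folds."""
    unique = sorted(set(patient_ids))
    random.Random(seed).shuffle(unique)
    return [unique[i::k] for i in range(k)]

def split_for_round(folds, r):
    """Round r: fold r tests, fold r+1 validates, the other three train (6:2:2)."""
    k = len(folds)
    val_idx = (r + 1) % k
    test, val = set(folds[r]), set(folds[val_idx])
    train = {p for i, f in enumerate(folds) if i not in (r, val_idx) for p in f}
    return train, val, test

def step_lr(epoch, base_lr=1e-3, gamma=0.5, step=40):
    """Learning rate 0.001, halved every 40 epochs."""
    return base_lr * gamma ** (epoch // step)

# Toy data: 10 patients, each contributing two (e.g. augmented) samples.
ids = [f"P{i:03d}" for i in range(10) for _ in range(2)]
train, val, test = split_for_round(patient_folds(ids), r=0)
```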
Since we use the survival time as the label rather than the cumulative hazard, we use only uncensored data during training and validation, where the exact survival time is needed to compute the objective function. During testing, we use all data for concordance evaluation and only uncensored data for MAE evaluation.
Since this is the first work to use a multimodal framework for NSCLC survival analysis, we implement several state-of-the-art survival analysis methods as baselines for comparison. The baselines include Cox-Time [kvamme2019time], DeepHit [lee2018deephit], CoxCC [kvamme2019time], PC-Hazard [kvamme2019continuous], and regular Cox regression.
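The two test metrics can be sketched as follows: Harrell's concordance index over all test patients and the MAE over uncensored patients only. When predictions are survival times, a pair is concordant if the patient who died earlier was also predicted to survive a shorter time. This is a simplified sketch: pairs with tied observed times are skipped, and tied predictions count as half-concordant. The function names are hypothetical.

```python
def c_index(times, events, preds):
    """Harrell's C for predicted survival times (shorter prediction = higher risk)."""
    concordant, comparable = 0.0, 0
    for i in range(len(times)):
        if not events[i]:
            continue        # the earlier time in a comparable pair must be an observed death
        for j in range(len(times)):
            if times[i] < times[j]:
                comparable += 1
                if preds[i] < preds[j]:
                    concordant += 1.0
                elif preds[i] == preds[j]:
                    concordant += 0.5
    return concordant / comparable

def mae_uncensored(times, events, preds):
    """Mean absolute error restricted to patients whose death was observed."""
    errors = [abs(p - t) for t, e, p in zip(times, events, preds) if e]
    return sum(errors) / len(errors)

times, events = [2, 4, 6, 8], [1, 1, 0, 1]      # toy cohort; patient 3 is censored
print(c_index(times, events, [1, 3, 5, 7]))      # 1.0
print(mae_uncensored(times, events, [1, 3, 5, 7]))  # 1.0
```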
| Cox + SuperPC [bair2006prediction] | ✓ | - | - | ✓ | 0.556 | NA |
4.4 Quantitative Results
In this subsection, we make a thorough comparison with both traditional and recent deep learning-based methods. The quantitative results of C-index and MAE are compared in Table 1.
As shown in Table 1, all compared methods except our previous work DeepMMSA use only the clinical data or only the CT slices for prediction. For example, by building a survival function, Cox regression can provide the probability that a certain event (e.g., death) occurs at a given time, yet its C-index is only 0.601. In contrast, many deep learning-based methods use only visual information. Although deep convolutional neural networks (DCNNs) are very powerful at feature extraction, visual information alone is not reliable enough to accurately predict survival time; for example, the best C-index among deep learning-based methods that use only visual information is 0.703 [2017WSISA]. Our previous work, DeepMMSA [DBLP:conf/smc/WuMHLS21], made the first attempt to fuse multimodal data using a two-tower framework. Although we found that multimodal inputs could boost performance, its final results did not surpass deep learning-based methods that use only visual information, such as WSISA [2017WSISA]. This observation indicates that a straightforward network does not work well for multimodal fusion. We therefore developed Lite-ProSENet to build an effective multimodal network for survival analysis. Lite-ProSENet achieves a C-index of 0.893, outperforming all compared methods, which validates the superiority of our method.
4.5 Ablation Study
To build an effective cross-modal survival model, we design the Lite-Transformer for clinical data and propose the 3D SE-ResBlock to effectively model the visual CT slices. Furthermore, we propose a frame difference mechanism that raises performance to a new state of the art. In this subsection, we verify the effectiveness of these modules through extensive experiments.
The results are reported in Table 2, where we systematically examine the contribution of each component: the Lite-Transformer, the 3D SE-ResBlock in ProSENet, and the frame difference mechanism. In the baseline (no modules equipped), the Lite-Transformer is replaced by several MLP layers with a similar number of parameters. As shown in Table 2, the baseline achieves a C-index of only 0.796, and the C-index improves as each module is added. For example, the baseline with the Lite-Transformer achieves a C-index of 0.824, and the 3D SE-ResBlock improves the baseline from 0.796 to 0.841. Applying any two modules simultaneously improves performance further; with the 3D SE-ResBlock and frame difference, the C-index reaches 0.873, a substantial improvement. With all three modules, we obtain the best performance, a new state-of-the-art C-index of 0.893. The MAE results show a consistent trend.
Since multimodal modeling is one of our main motivations for the Lite-ProSENet design, verifying its effectiveness is also critical. These results are also reported in Table 2, where the visual-only and textual-only variants refer to Lite-ProSENet with only the visual tower and only the textual tower, respectively. We observe that neither tower alone achieves satisfactory performance: the visual tower alone achieves a C-index of only 0.712, and although the 3D SE block boosts it to 0.739, this is still unsatisfactory. The observations for the textual-only variant are consistent. The model with multimodal learning achieves a C-index of 0.796, which demonstrates the importance of fusing clinical data and visual CT images for survival time analysis.
| Lite-Transformer | 3D-SE Resblock | Frame Difference | C-index | MAE |
In this section, we discuss several choices made when building our network, including the effect of the joint gate in our 3D SE-ResBlock, the order of the two SE blocks, and the impact of bi-directional frame difference. Besides these mechanisms, the hyper-parameters λ in Eq. 15 and β in Eq. 16 are also discussed in this section.
| SE blocks | global SE | local SE | C-index | MAE |
5.1 Validate the joint gate in 3D SEResblock.
In the 3D SE-ResBlock, we augment the channel SE and the temporal SE with a joint gate to perform information filtering; more details can be found in subsection 3.3. In this subsection, we validate the effectiveness of the proposed joint gate.
The results are reported in Table 3. The baseline is the network whose visual tower is the naïve 3D ResNet. 'global SE' means that the gate is built only by the naïve SE block; for the channel SE, the output of the global SE is produced by Eq. 8. (When studying the channel SE block, we equip the full temporal SE block, and vice versa.) 'local SE' means that the gate is built only from channel-wise or frame-wise information; for the channel SE, the output of the local SE is produced by replacing the pooled descriptor in Eq. 8 with the per-frame descriptor defined in Eq. 10. The joint gate uses both the global and the local SE, i.e., our 3D SE-ResBlock. As Table 3 shows, the SE block is effective, and the system benefits from both types of gate. For the channel SE block, the global SE improves the C-index from 0.826 to 0.842, and the local SE boosts it from 0.826 to 0.867. With the joint gate, performance improves substantially, from 0.826 to 0.893. The observations for the temporal SE block are consistent, which validates the effectiveness of the 3D SE-ResBlock.
| Channel SE | Temporal SE | C-index | MAE |
5.2 Study the stacking order of two SE blocks.
In our 3D SE-ResBlock, the channel SE block is applied first, and the temporal SE block acts on its output, as shown in Eq. 12 and Eq. 13. In this subsection, we compare the performance of the two SE blocks under different stacking orders.
The comparison is given in Table 4. We study two stacking orders: channel SE first and temporal SE second, and vice versa. As shown in Table 4, applying the channel SE first performs better: its C-index reaches 0.893, whereas applying the temporal SE first yields 0.881. Consequently, we apply the channel SE first in our network. In addition to the stacking order, we also investigate the importance of each SE block. As shown in Table 4, using either the channel SE or the temporal SE alone performs worse than using both SE blocks in either stacking order, which validates the effectiveness of our channel and temporal SE blocks.
5.3 The effectiveness of the bi-directional frame difference.
When predicting the final survival time, we introduce frame difference to filter out the redundant information between different CT slices. To further boost the performance, we perform bidirectional frame difference among CT images. In this subsection, we discuss the effectiveness of our bidirectional frame difference.
To thoroughly validate the proposed frame difference, we study three cases: the frame difference along only the forward direction, along only the backward direction, and in both directions. The results are given in Table 5, where 'forward' and 'backward' denote the normal and the reverse slice order, respectively. From Table 5, we observe that both the forward and the backward frame difference improve performance: the forward frame difference raises the C-index from 0.854 to 0.881, and the backward frame difference raises it from 0.854 to 0.879. Applying the frame difference in both directions simultaneously yields the best C-index of 0.893. These observations show that the proposed frame difference is an effective mechanism.
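A minimal sketch of the forward and backward frame differences, assuming signed differences between adjacent slices of a (T, H, W) stack; the paper's exact formulation and any further processing of the difference slices are not reproduced here, and the function names are hypothetical. Content that is identical in neighbouring slices cancels to zero, which is how the difference suppresses redundant information.

```python
import numpy as np


def forward_difference(slices: np.ndarray) -> np.ndarray:
    """slices: (T, H, W) stack in the normal scan order; returns (T-1, H, W)."""
    return slices[1:] - slices[:-1]


def backward_difference(slices: np.ndarray) -> np.ndarray:
    """Same operation on the reversed stack (last slice first)."""
    reversed_slices = slices[::-1]
    return reversed_slices[1:] - reversed_slices[:-1]


def bidirectional_difference(slices: np.ndarray):
    # Under this signed definition, the backward stack equals the forward
    # stack with reversed order and flipped sign.
    return forward_difference(slices), backward_difference(slices)


ct = np.arange(4 * 2 * 2, dtype=float).reshape(4, 2, 2) ** 2  # toy 4-slice volume
fwd, bwd = bidirectional_difference(ct)
print(fwd.shape, bwd.shape)  # (3, 2, 2) (3, 2, 2)
```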
5.4 Discussion of the trade-off hyper-parameter
We introduce a trade-off parameter when integrating the predictions from the original CT slices and from the bi-directional frame difference, as shown in Eq. 15. This parameter is set to 0.4 by default. In this subsection, we study how the performance changes as this hyper-parameter varies.
Figure 6(a) shows the performance of our network as the trade-off parameter varies from 0 to 1 in steps of 0.2. We observe from Figure 6(a) that the C-index and the MSE loss show consistent trends. When the prediction relies only on the frame-difference slices and ignores the original CT slices, the performance is unsatisfactory, with a C-index of only 0.841; this may be because too much information is lost when the original CT slices are discarded completely. The other settings in Figure 6(a) support this conjecture. When the frame difference is not applied at all, the network also fails to achieve the best performance, revealing that the frame-difference slices are necessary. The setting of 0.4 achieves the best performance, indicating that the frame-difference slices play an important role in survival time prediction. Further enlarging the weight of the frame difference does not improve the performance. Therefore, we fix the trade-off parameter at 0.4 in our network.
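Eq. 15 is defined elsewhere in the paper. Assuming it is a convex combination in which the trade-off parameter weights the frame-difference prediction (so 0 disables the frame difference and 1 discards the original CT slices, matching the two extremes discussed above), the fusion and the sweep of Figure 6(a) can be sketched as follows. `fuse_predictions`, `weight_fd` and the toy predictions are hypothetical.

```python
from typing import List, Sequence


def fuse_predictions(pred_ct: Sequence[float], pred_fd: Sequence[float],
                     weight_fd: float) -> List[float]:
    """Assumed form of Eq. 15: (1 - w) * CT-slice prediction + w * frame-difference prediction."""
    if not 0.0 <= weight_fd <= 1.0:
        raise ValueError("weight_fd must lie in [0, 1]")
    return [(1.0 - weight_fd) * c + weight_fd * f for c, f in zip(pred_ct, pred_fd)]


# Grid of Figure 6(a): 0 to 1 in steps of 0.2 (0.4 is the default).
weight_grid = [round(0.2 * k, 1) for k in range(6)]
for w in weight_grid:
    # Toy survival-time predictions from the two branches for two patients.
    print(w, fuse_predictions([10.0, 20.0], [20.0, 40.0], w))
```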
5.5 Discussion of the loss-balancing hyper-parameter.
When training the network, we adopt the common parameter regularization strategy to avoid overfitting, i.e., the second term in Eq. 16, and introduce a hyper-parameter to balance the main loss and the regularization term. By default, this hyper-parameter is set to 0.001. In this section, we study its impact on the performance.
The changes in C-index and MSE loss are shown in Figure 6(b), where the x-axis represents the value of the hyper-parameter and the y-axis represents performance. Figure 6(b) shows a clear trend. When the hyper-parameter is 0, i.e., without parameter regularization, the network does not perform well. As the hyper-parameter increases, the performance improves, and the best C-index of 0.893 is achieved at the default value of 0.001. Increasing it further brings no additional gain. Hence, a value of 0.001 achieves a good trade-off between the main loss and the parameter regularization.
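Eq. 16 is defined elsewhere in the paper. A minimal sketch, assuming an MAE main loss and a squared L2 penalty over all parameters weighted by the balancing hyper-parameter (0.001 by default); the choice of main loss and the names `regularized_loss` and `reg_weight` are assumptions.

```python
from typing import Sequence

import numpy as np


def regularized_loss(pred: Sequence[float], target: Sequence[float],
                     params: Sequence[np.ndarray], reg_weight: float = 1e-3) -> float:
    """Main loss plus reg_weight times the squared L2 norm of all parameters."""
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    main_loss = float(np.mean(np.abs(pred - target)))      # assumed MAE main loss
    penalty = sum(float(np.sum(w ** 2)) for w in params)   # squared L2 norm
    return main_loss + reg_weight * penalty


weights = [np.array([1.0, 2.0]), np.array([[2.0]])]           # toy parameters, norm^2 = 9
print(regularized_loss([10.0, 20.0], [12.0, 18.0], weights))   # 2 + 0.001 * 9
print(regularized_loss([10.0, 20.0], [12.0, 18.0], weights, reg_weight=0.0))  # no penalty
```

Frameworks commonly expose a similar penalty as optimizer weight decay (for example, the `weight_decay` argument of PyTorch optimizers), although the scaling convention can differ by a factor of two.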
6 Conclusion and Future Work
This work contributes a multimodal network for more accurate prediction of NSCLC survival, with the aim of helping clinicians develop timely treatment plans and improve patients' quality of life. Our method achieves a new state-of-the-art C-index of 89.3%. To model the cross-modal data, we develop a two-tower network, in which the textual tower handles the clinical data and the visual tower handles the CT slices. Inspired by the success of the transformer in NLP, we propose a very light transformer built on the core self-attention mechanism. For the visual tower, we design ProSENet based on the 3D-SE Resblock, in which channel squeeze-and-excitation and temporal squeeze-and-excitation are proposed to suppress the redundant information among the CT slices. Besides, we introduce a frame difference mechanism that further raises our network to a new state of the art in terms of C-index and MAE. The comparisons, ablation studies and discussions in our experiments verify the superiority of our Lite-ProSENet. This work gives us confidence in deep learning-based survival analysis, and we believe that deep learning-based methods have great potential for survival time analysis. In the future, we will further investigate this problem from the following two aspects:
Effective fusion of cross-modal features. In this work, the fusion of multimodal features is simple: we concatenate the features from the Lite-transformer and ProSENet. In the future, we will explore more effective fusion strategies.
Borrowing information from large-scale pretrained models. Large-scale pretrained cross-modal models have shown great potential in many tasks, such as visual question answering, image captioning and cross-media retrieval. Trained on millions of samples, these models contain powerful knowledge, and adapting this knowledge to survival time analysis is a promising direction that we will explore in the future.
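The current concatenation fusion of the two towers can be sketched as follows, assuming each tower yields one feature vector per patient. The feature sizes, the random stand-in features and the linear survival head are hypothetical and only illustrate the data flow, not the trained model.

```python
import numpy as np


def fuse_by_concatenation(text_features: np.ndarray, visual_features: np.ndarray) -> np.ndarray:
    """Concatenate per-patient feature vectors from the two towers along the last axis."""
    return np.concatenate([text_features, visual_features], axis=-1)


def linear_survival_head(fused: np.ndarray, weights: np.ndarray, bias: float) -> np.ndarray:
    """Hypothetical linear head mapping fused features to one survival time per patient."""
    return fused @ weights + bias


rng = np.random.default_rng(0)
batch, d_text, d_visual = 4, 16, 32                # hypothetical sizes
text = rng.standard_normal((batch, d_text))        # stands in for Lite-transformer output
visual = rng.standard_normal((batch, d_visual))    # stands in for ProSENet output
fused = fuse_by_concatenation(text, visual)        # shape (4, 48)
pred = linear_survival_head(fused, rng.standard_normal(d_text + d_visual), 0.0)
```

Because concatenation keeps the two feature vectors side by side without any interaction term, cross-modal interactions must be learned entirely by the layers after fusion, which is one motivation for exploring richer fusion strategies.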