Diagnosis of Pneumonia-like diseases from X-ray scans has shown promising results (Rajpurkar et al., 2017), leading to radiologist-level performance. The potential of Deep Learning (DL) with Computer Vision (CV) lies not only in performance, but also in the speed with which scans can be diagnosed. The multi-task learning (MTL) framework has proven effective in different domains (Teichmann et al., 2018). MTL enables different, yet related, tasks to be performed using one neural architecture. Not only does this paradigm enable efficient use of convolutional layers, it also improves performance over the individual task models, thanks to the common features captured from many data sets. The challenge usually lies in obtaining a jointly labeled data set. This is usually overcome with carefully designed training pipelines that make use of alternating training.
Different diseases can be diagnosed at once from the same X-ray image, thanks to multi-label classification (Rajpurkar et al., 2017). This is, in a sense, a form of multi-task learning. While diagnosis of diseases from X-rays is mostly a classification problem, other tasks, like segmentation (SIIM-ACR, 2018) and localisation (RSNA, 2018), can give radiologists important guidance, both for interpretation and for better diagnosis of the disease. Also, deep learning is often said to be a black box, so model interpretability is an important factor in the adoption of these models. Providing segmentation and localization of the diseased region, in addition to the disease class, provides better evidence and enhances interpretability.
We propose MultiCheXNet, an MTL architecture composed of a common convolutional encoder and three task heads: 1) classification, 2) localisation and 3) segmentation. The common MTL encoder and the classifier head are pre-trained on a generic data set (Wang et al., 2017) to classify 14 diseases. Also, each of the localisation and segmentation heads is pre-trained before being integrated into the MTL architecture. The final MTL architecture, comprising the three heads and the encoder, is trained end-to-end.
Since we do not have joint labels for the three tasks, we perform alternating training using two data sets, with two heads trained at a time. In another setup, we create localization boxes from the segmentation masks and perform joint training for the three tasks at once. Moreover, the available datasets for segmentation and detection actually cover different disease classes. However, we rely on the fact that X-ray images are acquired in a fairly standardized way and look broadly similar. So even if the datasets cover different diseases, the learned features are likely to be similar, since disease patterns are similar across diseases. This justifies the importance of the common encoder in the MTL architecture.
Medical data sets usually suffer from class imbalance, because the nature of the problem yields more negative samples than positive ones. For tasks like segmentation or object detection, this can be harmful, since those models are not good at producing "no detections". We propose a multi-stage pipeline that integrates a first stage classifier to filter out negative cases before the other tasks are performed. During training, we employ teacher forcing: we feed ground-truth positive cases for a certain percentage of the training samples, and samples classified as positive by the first stage classifier for the remaining percentage.
Finally, we employ transfer learning from the MTL architecture to unseen Pneumonia-like diseases. Using the common convolutional encoder and the classification head, we transfer the model weights and fine-tune the top layers to classify new, unseen Pneumonia-like diseases. Our experimental use case is COVID-19 classification. Plugging X-rays of a new, unseen disease like COVID-19 into the MTL architecture can provide segmentation masks and localization boxes of disease findings at no extra annotation cost, without training on data for that particular disease. This provides a form of few-shot learning that can be helpful for new diseases where annotated data for segmentation or localisation is not yet available. This setup makes sense because the features and symptoms of Pneumonia-like diseases are very similar in X-ray scans, which favors the features learnt by the MTL encoder from data sets of other diseases.
To evaluate MultiCheXNet, we first develop baselines for each of the three individual heads. The classification head is trained on (Wang et al., 2017), the segmentation head on (SIIM-ACR, 2018) and the localisation head on (RSNA, 2018). In the localisation and segmentation models, teacher forcing is performed with the ground-truth positive cases, and with some percentage of samples taken from the pre-trained classifier output. Moreover, the classifier used with each head is further pre-trained on that head's own data set. After integrating and training the MTL model, we compare the performance of each head with its individual performance on the same data set it was pre-trained on, to show the effect of MTL. Finally, we develop a baseline for an individual classifier on (Cohen et al., 2020), and compare it to the performance of the MTL head on the same COVID-19 data set.
2 Related Work
CheXNet (Rajpurkar et al., 2017) was one of the pioneering works that applied deep learning and convolutional networks to medical imaging diagnosis (Lakhani and Sundaram, 2017), (Wang et al., 2017). The authors claim radiologist-level accuracy. This inspired many subsequent works to apply deep learning to medical image diagnosis, extending to tasks such as semantic segmentation and object detection and localization (Sirazitdinov et al., 2019), (Jaiswal et al., 2019).
Recently, many efforts have been made to apply the same idea to diagnose COVID-19 disease from X-rays and CT scans (Wang and Wong, 2020), (Basu and Mitra, 2020), (Maguolo and Nanni, 2020), (Tartaglione et al., 2020). These efforts were usually constrained by the limited amount of data, due to the recency of the disease outbreak, (Maguolo and Nanni, 2020), (Tartaglione et al., 2020).
In the context of computer vision for automated driving, MultiNet (Teichmann et al., 2018) is also a pioneering work, applying multi-task learning to real-time object detection and road segmentation. It uses a shared convolutional encoder to avoid the redundant computation that separate models for each task would incur. Further, it was shown that overall detection and segmentation performance improved over individually trained models.
3.1 Multi-Task Learning architecture
MultiCheXNet is a Multi-Task Learning (MTL) architecture made of four main neural network blocks: a common encoder and three decoders, as shown in Figure 2. We define each block as a function parametrized by its own set of weights. The input data are grey-scale X-ray images of a fixed size.
3.2 Encoder (CheXEnc)
The encoder network CheXEnc is a convolutional neural network with a set of trainable kernel weights. It encodes the input X-ray scan into a common representation: a stack of feature maps of fixed spatial dimension.
Three tasks are to be learned using three specialized convolutional decoders: CheXCls, CheXDet and CheXSeg. Each head is associated with its own loss function: a classification loss, a detection loss and a segmentation loss. The whole architecture is optimized end-to-end by minimizing the total loss. The common encoder is updated by all three losses; it therefore captures the common features and allows learned features to transfer between the three tasks, which improves the overall performance. Another advantage is that a single inference pass through the encoder is required, instead of three separate encoder passes in the case of individual networks, which saves inference time and reduces the final model size.
Following are the details of each decoder, the associated loss, and the overall loss.
3.3 Classification decoder head (CheXCls)
The classification head decoder is a convolutional neural network, parametrized by its own kernel weights, that operates on the encoded feature map from CheXEnc. It produces one output vector per input X-ray image; each output vector is the result of a set of output neurons with sigmoid activations, each associated with a class label. The dataset used for classification pairs each image with its ground truth labels, one binary label per class. The classification loss is a binary cross-entropy loss over each output class neuron.
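As an illustration, the multi-label binary cross-entropy described above can be sketched in plain Python. This is a minimal sketch rather than the authors' implementation: the function name `binary_cross_entropy`, the clamping constant `eps` and the averaging over all samples and classes are assumptions made for the example.

```python
import math

def binary_cross_entropy(probs, targets, eps=1e-7):
    """Mean binary cross-entropy over all samples and all class neurons.

    probs:   list of per-image lists of sigmoid outputs in [0, 1]
    targets: list of per-image lists of 0/1 ground truth labels
    """
    total, count = 0.0, 0
    for p_row, y_row in zip(probs, targets):
        for p, y in zip(p_row, y_row):
            p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
            total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
            count += 1
    return total / count

# Two images, three hypothetical disease labels each.
probs = [[0.9, 0.2, 0.1], [0.3, 0.8, 0.5]]
targets = [[1, 0, 0], [0, 1, 1]]
loss = binary_cross_entropy(probs, targets)
```

Because every class neuron has its own sigmoid, an image can be positive for several diseases at once, which is what distinguishes this multi-label loss from a softmax over mutually exclusive classes.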
3.4 Detection decoder head (CheXDet)
The detection head decoder is a convolutional neural network, parametrized by its own kernel weights, that operates on the encoded feature map from CheXEnc. For each input X-ray image it outputs a set of anchor boxes. The parameters of every box result from 4 output neurons with ReLU activations.
The dataset used for detection pairs each input image with its ground truth bounding boxes, up to a maximum number of boxes per image. Each bounding box (predicted or ground truth) is represented by four numbers: the two coordinates of the box center, and the length and width of the box. The detection loss is an L2 regression over the box parameters of the output anchors and the ground truth boxes, applied when an object exists in the region of overlap between both boxes.
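The following sketch shows one way such an overlap-gated L2 loss could be computed. It is an assumption-laden illustration, not the paper's exact formulation: boxes are taken as `(cx, cy, h, w)` tuples, "overlap" is interpreted as an intersection-over-union of at least `iou_threshold` (a hypothetical parameter), and the loss is averaged over matched pairs.

```python
def to_corners(box):
    """Convert (cx, cy, h, w) to corner form (x1, y1, x2, y2)."""
    cx, cy, h, w = box
    return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2

def iou(box_a, box_b):
    """Intersection over union of two (cx, cy, h, w) boxes."""
    ax1, ay1, ax2, ay2 = to_corners(box_a)
    bx1, by1, bx2, by2 = to_corners(box_b)
    iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    ih = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = iw * ih
    union = box_a[2] * box_a[3] + box_b[2] * box_b[3] - inter
    return inter / union if union > 0 else 0.0

def detection_l2_loss(anchors, gt_boxes, iou_threshold=0.5):
    """Mean squared error over box parameters of overlapping anchor/GT pairs."""
    total, matched = 0.0, 0
    for a in anchors:
        for g in gt_boxes:
            if iou(a, g) >= iou_threshold:  # assumed overlap criterion
                total += sum((pa - pg) ** 2 for pa, pg in zip(a, g))
                matched += 1
    return total / matched if matched else 0.0

gt = [(10.0, 10.0, 4.0, 4.0)]
shifted = detection_l2_loss([(11.0, 10.0, 4.0, 4.0)], gt)  # anchor off by 1 in cx
far = detection_l2_loss([(50.0, 50.0, 4.0, 4.0)], gt)      # no overlap, no penalty
```

Anchors that do not overlap any ground truth box contribute nothing here, which is consistent with the loss being conditioned on an object existing in the overlap region.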
3.5 Segmentation decoder head (CheXSeg)
The segmentation head decoder is a convolutional neural network, parametrized by its own kernel weights, that operates on the encoded feature map from CheXEnc. It outputs one segmentation mask per input X-ray image. Each output mask is a binary mask over the image pixels (the result of a sigmoid activation), segmenting the area of the disease.
The dataset used for segmentation pairs each input image with its ground truth mask, a binary mask over that image. The segmentation loss is a binary cross-entropy over each value of the mask outputs.
3.6 Overall loss
Given a jointly labelled dataset, in which every input image carries a class label, bounding boxes and a segmentation mask, the total loss combines the three task losses. The parameters of all four networks can then be optimized end-to-end by minimizing this total loss.
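One plausible way to write this objective is given below. The symbols are notation introduced only for this sketch, and the unweighted sum of the three task losses is an assumption; the text states only that a total loss is minimized.

```latex
\mathcal{L}_{total} = \mathcal{L}_{cls} + \mathcal{L}_{det} + \mathcal{L}_{seg},
\qquad
\theta^{*} = \arg\min_{\theta} \; \mathcal{L}_{total}(\theta),
\qquad
\theta = \{\theta_{enc}, \theta_{cls}, \theta_{det}, \theta_{seg}\}
```

Here $\theta_{enc}$ stands for the weights of CheXEnc and $\theta_{cls}$, $\theta_{det}$, $\theta_{seg}$ for those of CheXCls, CheXDet and CheXSeg. Since every task loss depends on $\theta_{enc}$, the encoder receives gradients from all three tasks.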
However, having a joint dataset is not always easy. In some cases, if we have segmentation masks, we can derive the corresponding boxes. The opposite direction, deriving masks from boxes, results in coarse masks spanning the whole box area. We discuss alternative training methods in the training protocol section.
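Deriving a box from a mask can be sketched as follows. This is a minimal illustration under stated assumptions: the mask is a list of rows of 0/1 values, a single box is fitted around all positive pixels (multiple separate regions are not split), and the helper name `mask_to_box` is hypothetical.

```python
def mask_to_box(mask):
    """Tightest (cx, cy, h, w) box around the nonzero pixels, or None if empty.

    Pixel (r, c) is assumed to cover the square [c, c+1) x [r, r+1).
    """
    rows = [r for r, row in enumerate(mask) if any(row)]
    if not rows:
        return None  # negative case: nothing to box
    cols = [c for c in range(len(mask[0])) if any(row[c] for row in mask)]
    top, bottom = rows[0], rows[-1]
    left, right = cols[0], cols[-1]
    h = bottom - top + 1
    w = right - left + 1
    return (left + w / 2, top + h / 2, h, w)

mask = [
    [0, 0, 0, 0],
    [0, 1, 1, 0],
    [0, 1, 1, 0],
    [0, 0, 0, 0],
]
box = mask_to_box(mask)
empty = mask_to_box([[0, 0], [0, 0]])
```

The reverse mapping would simply fill the box with ones, which illustrates why masks derived from boxes are coarse.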
3.7 Positive cases classifier
We noticed during training of the segmentation and detection heads that feeding both positive and negative samples hurts performance. This is mostly because of the high data imbalance: most of the input X-ray scans are negative cases, for which no bounding box or segmentation mask is available. This creates some difficulty for these decoders, especially the detection head, which by default generates boxes that are then filtered by post-processing.
To avoid this issue, we train the segmentation and detection heads on positive samples only (the ones with a mask or box in the label). The question then is: during inference, how do we ensure that only positive samples are fed?
For that we design a first stage classifier, trained as a positive/negative classifier, followed by the MTL encoder and decoders. The classifier filters out the negative cases and prevents them from going through the segmentation or detection heads. The classifier can be trained independently as a binary classifier with binary cross-entropy loss. During MTL training, this classifier is frozen (it appears as a grey box in the architecture diagrams).
Alternatively, we can use the same classification head that we train as part of the MTL architecture as the filter in that pipeline. In that case the input is first passed through the CheXEnc+CheXCls sub-network, so that negative cases are filtered out if no disease from the set of diseases is detected by the CheXCls decoder. Otherwise, it is a positive case and can be passed to the segmentation and detection heads.
The drawback of this approach is that we need to run the encoder twice; once for filtering, and then through the MTL. This can be optimized by caching the encoder output, and using it for both phases of the pipeline.
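The caching idea can be sketched with plain callables standing in for the networks. Everything named here (`run_pipeline`, the toy heads, the `threshold` of 0.5) is a hypothetical stand-in for illustration; the point is only that the encoder output is computed once and reused by the filter and by the task heads.

```python
def run_pipeline(image, encoder, cls_head, seg_head, det_head, threshold=0.5):
    """Filter with the classification head, then reuse the cached features."""
    features = encoder(image)          # single encoder pass, cached in a variable
    class_probs = cls_head(features)
    result = {"classes": class_probs, "mask": None, "boxes": None}
    if max(class_probs) < threshold:   # no disease detected: negative case
        return result
    result["mask"] = seg_head(features)   # reuses the cached features
    result["boxes"] = det_head(features)  # reuses the cached features
    return result

# Toy stand-ins so the sketch runs end to end.
calls = {"encoder": 0}

def toy_encoder(image):
    calls["encoder"] += 1
    return sum(image) / len(image)

def toy_cls(features):
    return [features, features / 2]

def toy_seg(features):
    return "mask"

def toy_det(features):
    return ["box"]

positive = run_pipeline([0.9, 0.7], toy_encoder, toy_cls, toy_seg, toy_det)
negative = run_pipeline([0.1, 0.3], toy_encoder, toy_cls, toy_seg, toy_det)
```

With this arrangement each image costs exactly one encoder pass, whether or not it reaches the segmentation and detection heads.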
3.8 Teacher forcing (TF)
Having only positive cases during training of the MTL decoders is a form of teacher forcing (TF). The first stage classifier needs to be pre-trained independently on negative cases before being trained only on positive cases as part of the MTL architecture. This is detailed in the experimental setup section, together with the training phases.
Another drawback is that the detection and segmentation heads are never trained on negative cases. This is an issue if the first stage classifier fails and passes a negative case by mistake. To address this, we mix the training data so that a fraction of it consists of samples as classified by the first stage classifier, which may include some negative cases. This is handled by tossing a coin for every batch or sample: with a given probability we force the positive samples from the ground truth, and with the complementary probability we take whatever is passed by the first stage classifier. In the architecture diagrams, the module is drawn "dotted" to denote the coin tossing in teacher forcing training.
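A per-batch version of this coin toss might look like the sketch below. The names `select_training_samples`, `p_force` and `FixedRng` are illustrative assumptions; the default of 0.9 mirrors the 90%/10% split reported for the detection head pre-training.

```python
import random

def select_training_samples(batch, first_stage_predicts_positive,
                            p_force=0.9, rng=random):
    """Pick the samples that reach the segmentation/detection heads.

    batch: list of (image, is_positive_ground_truth) pairs.
    """
    if rng.random() < p_force:
        # Teacher forcing: keep only the ground truth positives.
        return [img for img, is_pos in batch if is_pos]
    # Otherwise keep whatever the first stage classifier lets through,
    # which may include false positives (true negatives passed by mistake).
    return [img for img, _ in batch if first_stage_predicts_positive(img)]

class FixedRng:
    """Deterministic stand-in for the random module, used for illustration."""
    def __init__(self, value):
        self.value = value
    def random(self):
        return self.value

batch = [("a", True), ("b", False), ("c", True)]
classifier = lambda img: img in ("b", "c")  # wrong on "a" and "b"

forced = select_training_samples(batch, classifier, rng=FixedRng(0.5))
unforced = select_training_samples(batch, classifier, rng=FixedRng(0.95))
```

In the unforced branch the heads see the negative sample "b", which is how they get occasional exposure to the first stage classifier's mistakes.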
3.9 Transfer Learning to unseen diseases
The classification task is the one most needed for diagnosis, while segmentation and detection are needed more for interpretability and explainability of the diagnosis. There are many data sets covering one or more diseases. Such datasets are collected over a long time. For new diseases, only a limited amount of data may be available, especially during the initial outbreak, as in the case of COVID-19. Since the classification head has a multi-label loss, transfer learning can be employed to fine-tune the separate disease labels. Under the MTL architecture, we can further make use of the common features captured in the CheXEnc network. The classification path (CheXEnc+CheXCls) can be further fine-tuned on new data sets with new diseases.
4 Experimental Setup
4.1 Training protocol
As explained, the MultiCheXNet architecture can be trained end-to-end, given a jointly labeled dataset for all the three tasks. However, there are some aspects we consider:
Limited jointly labeled datasets: as explained before, it is possible to obtain bounding boxes from segmentation masks, but the opposite is inefficient. We have many data sets available, some with classification labels only, some with segmentation mask annotations and others with box annotations. We want to make efficient use of all of them.
Big-bang integration: we have four models, with four sets of weights. Instead of initializing them with random weights, we choose to pre-train each head alone, and then integrate and fine-tune them in the MTL architecture. This helps all the models start from the baseline performance of the individual networks and then improve thanks to the common features captured in the common encoder. The encoder is shared across all pre-training phases.
The first stage classifier needs to be trained alone first, before being used for teacher forcing training. In MTL training of either the segmentation or the detection head, teacher forcing mostly passes the positive cases, so the classifier will not be sufficiently trained on negative cases, although this is needed for its role as a filter. Hence, we need to pre-train this classifier alone at different stages. We pre-train it on a generic dataset (Wang et al., 2017), and we also pre-train it within the detection or segmentation pre-training, on all the positive and negative classes, before the detection and segmentation heads are trained on positive cases.
The phases are as shown in Figure 4:
1) Classifier pre-training: CheXEnc + CheXCls.
2) Detection head pre-training: CheXEnc + CheXDet.
3) Segmentation head pre-training: CheXEnc + CheXSeg.
4) MTL training: CheXEnc + CheXCls + CheXSeg + CheXDet.
5) Transfer learning to new data/diseases: CheXEnc + CheXCls.
The first 3 phases are considered a baseline for comparing individual performance against the integrated MTL architecture. In all the figures below, grey boxes mean not trainable, and dotted boxes mean teacher forcing with a given percentage.
4.1.1 Classification head pre-training
The classification sub-network CheXEnc + CheXCls is pre-trained on ChestX-ray-14 (Wang et al., 2017) for 14 classes. This sub-network also serves as the pipeline classifier that filters out the negative cases.
4.1.2 Detection head pre-training
In the detection head pre-training, the following networks are updated on the RSNA dataset (RSNA, 2018): CheXEnc + CheXDet. The encoder is taken from the pre-trained classifier of the first phase. During this phase, we need to run the pipeline classifier to filter out the negative cases. For that, we further pre-train the classifier on the specific positive/negative cases from the RSNA dataset.
Also, we employ teacher forcing with a forcing probability of 0.9, which means that we toss a coin for every batch: 10% of the time we take the output of the pipeline classifier, and 90% of the time we force the positive class samples.
The detection dataset (RSNA, 2018) includes the boxes as well as the disease classification of the overall image (e.g. Pneumonia). Before pre-training the detection head, we first pre-train the classification path on the classification targets of the dataset. The reason is that we focus only on the positive cases during detection pre-training, and hence the classifier would not be sufficiently trained on negative cases. For that we create a new classification dataset out of the detection dataset, pairing each image with its image-level class.
Following this pre-training, the encoder and the detection head are further trained on the detection labels.
4.1.3 Segmentation head pre-training
In the segmentation head pre-training, the following networks are updated on the SIIM-ACR dataset (SIIM-ACR, 2018): CheXEnc + CheXSeg. The encoder is taken from the pre-trained classifier of the first phase. During this phase, we need to run the pipeline classifier to filter out the negative cases. For that, we further pre-train the classifier on the specific positive/negative cases from the SIIM-ACR dataset.
Teacher forcing is also employed, in the same way as in the detection head pre-training. Also, similar to the detection case, we first pre-train the classifier on a new classification dataset created out of the segmentation dataset, pairing each image with its image-level class.
Following this pre-training, the encoder and the segmentation head are further trained on the segmentation labels.
4.2 MTL training
In MTL training, all four networks are updated: CheXEnc + CheXCls + CheXSeg + CheXDet. We have two experimental setups:
Joint training: we do not have one training dataset that includes both segmentation masks and bounding boxes for the same input. However, we can fit a box on the segmentation masks of the SIIM-ACR dataset and train the model jointly end-to-end. The downside is that this limits the training to only one dataset, as in equation 8.
Alternating training: to handle the lack of jointly labeled datasets for segmentation and detection, we train the segmentation pipeline on the SIIM-ACR dataset, alternating with the detection pipeline on the RSNA dataset, one batch from each data set at a time. While one head is trained, the other is not, as in Figure 3. The common encoder is trained in both steps; it is therefore improved by both datasets and can capture the common features. The full algorithm is detailed in Algorithm 1.
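The alternating schedule can be illustrated with a small sketch. It is not the authors' Algorithm 1: `seg_step` and `det_step` are hypothetical callbacks that would each perform one optimizer update (encoder plus the respective head), and stopping at the end of the shorter dataset is an assumption made to keep the example simple.

```python
def alternating_training(seg_batches, det_batches, seg_step, det_step, epochs=1):
    """Alternate one segmentation batch and one detection batch at a time."""
    schedule = []
    for _ in range(epochs):
        # zip stops at the shorter dataset (an assumption of this sketch).
        for seg_batch, det_batch in zip(seg_batches, det_batches):
            seg_step(seg_batch)   # would update encoder + CheXSeg only
            schedule.append("seg")
            det_step(det_batch)   # would update encoder + CheXDet only
            schedule.append("det")
    return schedule

# Toy usage: record which batches each step function receives.
seen_seg, seen_det = [], []
schedule = alternating_training(
    seg_batches=["s0", "s1", "s2"],
    det_batches=["d0", "d1"],
    seg_step=seen_seg.append,
    det_step=seen_det.append,
)
```

The encoder is touched by both step functions, so it accumulates updates from both datasets even though no single batch carries both kinds of labels.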
4.3 Transfer learning/ Few-shot learning
We experiment on the COVID-19 use case, where limited data is available for COVID X-rays. However, we can transfer the (CheXEnc+CheXCls) network after it has been trained, together with the three heads, on abundant data for different diseases. After the MTL architecture is trained, the classification sub-network CheXEnc + CheXCls is fine-tuned on the COVID-19 dataset (Cohen et al., 2020). The COVID-19 dataset is small, and hence we compare the performance before and after transfer learning. We also compare transfer to COVID-19 data from the pre-trained classifier of the first phase (only the classification sub-network, trained on the Chest X-ray 14 dataset) against transfer from the MTL architecture.
We fine-tune the classification path. For that, we factorize its parameters into two groups. The first group contains the early layers leading to the extracted features that will be used for classification; the second contains the final set of weights, which produce the output for the new classes through sigmoid activation neurons. The first group is transferred from the model pre-trained with the MTL algorithm, while the second is trained from scratch on the new class(es), e.g. COVID-19.
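The parameter split can be sketched with plain dictionaries. This is only a schematic illustration: the layer names, the key `"final"`, the helper `build_transfer_classifier` and the small Gaussian initialization are all assumptions, not details taken from the paper.

```python
import random

def build_transfer_classifier(pretrained, feature_dim, num_new_classes, seed=0):
    """Copy the early layers and re-initialize the final layer for new classes."""
    rng = random.Random(seed)
    # Early layers: transferred as-is from the pre-trained classification path.
    early = {name: list(w) for name, w in pretrained.items() if name != "final"}
    # Final layer: fresh weights, one row of feature_dim weights per new class.
    final = [[rng.gauss(0.0, 0.01) for _ in range(feature_dim)]
             for _ in range(num_new_classes)]
    return {"early": early, "final": final}

# Hypothetical pre-trained weights: two early layers and a 14-class final layer.
pretrained = {
    "conv1": [0.1, 0.2],
    "conv2": [0.3, 0.4],
    "final": [[0.0] * 4 for _ in range(14)],
}
model = build_transfer_classifier(pretrained, feature_dim=4, num_new_classes=1)
```

Only the single new output row (e.g. for COVID-19) starts from scratch; the feature extractor keeps what it learned from the larger datasets.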
4.4 Data sets
The Chest X-ray 14 dataset (Wang et al., 2017) has 112,120 images annotated with 14 different disease labels.
RSNA (RSNA, 2018) is a Kaggle competition for detecting and localizing Pneumonia in X-ray lung scans. The data consists of 26,700 scans for training and another 3000 scans for testing, labelled with the bounding box coordinates of the locations of Pneumonia in the images.
SIIM-ACR (SIIM-ACR, 2018) is a Kaggle competition for segmenting pneumothorax in X-ray lung scans. The data consists of 12,100 scans for training and another 3205 scans for testing, labelled with segmentation masks of the locations of the disease in the images.
COVID-19 (Cohen et al., 2020) is a public open dataset of chest X-ray and CT images of patients who are positive for or suspected of COVID-19 and 16 other diseases such as MERS, SARS, and ARDS. This dataset is very small (around 100 images).
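Effect of Multi-task learning against separately trained tasks - In this experiment, we test the effect of MTL training against the individually trained models on the three separate tasks. As shown in Table 1, MTL with pre-training improves on the individually trained baselines in classification F1 (0.73 vs. 0.31) and in segmentation (0.75 vs. 0.68) and matches the detection baseline (0.16), although its classification accuracy (0.73) is below the 0.87 of the classification baseline, which is evaluated on a different dataset.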
|Model||Dataset||Classification accuracy||Classification F1||Segmentation||Detection|
|Baseline Classification||Chest X-ray 14||0.87 (0.86, 0.88)||0.31 (0.27, 0.35)||-||-|
|Baseline Segmentation||SIIM-ACR||-||-||0.68 (0.66, 0.70)||-|
|Baseline Detection||RSNA||-||-||-||0.16 (0.15, 0.17)|
|MTL w/o pre-training||RSNA + SIIM-ACR||0.52 (0.44, 0.59)||0.53 (0.43, 0.61)||0.29 (0.28, 0.30)||0.21 (0.20, 0.22)|
|MTL w/ pre-training||RSNA + SIIM-ACR||0.73 (0.66, 0.79)||0.73 (0.65, 0.80)||0.75 (0.74, 0.76)||0.16 (0.15, 0.17)|
Effect of pre-training - MTL with pre-training, which is our MultiCheXNet model (listed as "MTL w/ pre-training" in Table 1), improves significantly over MTL trained from scratch with random initialization, which we call BigBang MTL (listed as "MTL w/o pre-training" in Table 1), on classification and segmentation. Since the individual heads are pre-trained on relevant datasets, they tend to perform better when integrated in the MTL model, and this helps reduce the tuning effort of the overall model, since
Effect of Teacher Forcing and Positive Cases classifier - With the inclusion of the positive cases classifier, which acts as a filter, some errors might occur, and some negative cases can pass to the segmentation and detection heads by mistake as false positives (those should have been filtered by the first stage classifier). Table 2 shows the effect of those false positives on the overall results. It is clear that the results drop on all metrics of the 3 heads, however, given the already good performance of the first stage classifier, the effect is not critical.
|Model||Dataset||Classification accuracy||Classification F1||Segmentation||Detection|
|MTL w/ 100% Teacher Forcing||RSNA + SIIM-ACR||0.77 (0.70, 0.82)||0.77 (0.70, 0.83)||0.78 (0.77, 0.79)||0.19 (0.18, 0.20)|
|MTL w/ pre-training||RSNA + SIIM-ACR||0.73 (0.66, 0.79)||0.73 (0.65, 0.80)||0.75 (0.74, 0.76)||0.16 (0.15, 0.17)|
Multi-task learning - Transfer learning (MTL-TL) scenario - The effect of transferring the learnt classifier (CheXCls) and fine-tuning it on new, unseen classes is demonstrated in Table 3 on the COVID-19 dataset, which has only around 100 images and, as expected, suffers from high class imbalance towards the negative class. After fine-tuning on the new dataset, the pre-trained classifier performs significantly better than the baseline trained from scratch on the small COVID-19 dataset alone.
|Model||Source Dataset||Target Dataset||Accuracy||F1|
|Baseline COVID||-||COVID||0.59 (0.52, 0.66)||0.72 (0.63, 0.76)|
|MTL - TL: MultiCheXNet (CheXCls)||RSNA + SIIM-ACR||COVID||0.70 (0.64, 0.76)||0.80 (0.75, 0.85)|
Comparison to CheXNet (Rajpurkar et al., 2017) - We treat CheXNet as a benchmark, since it is the closest work to ours. To do that, we take the trained classification head (CheXCls) of MultiCheXNet and test it on the Chest X-ray 14 dataset (Wang et al., 2017). Table 4 shows the improvement in F1 score of MultiCheXNet (CheXCls) over CheXNet. In our setup, we treat the CheXNet architecture as a multi-label architecture, with binary cross-entropy loss and sigmoid neurons on all 14 classes of the Chest X-ray 14 dataset. The reason is that we treat our classification head as a "universal" classifier that can be fine-tuned jointly on any dataset and produce joint disease classifications for all classes at once. This differs from the original CheXNet (Rajpurkar et al., 2017), which is based on a binary classification setup that handles one disease at a time, focuses on Pneumonia, and is tested in the same fashion on other diseases. Thus, CheXNet uses a one-versus-all setup, with binary cross-entropy loss and only one sigmoid, with a positive case (1) for the classified disease and a negative case (0) for all others. For a fair comparison, we therefore train our own version of CheXNet, called CheXNet-Multilabel in Table 4, in the multi-label setup, and compare it to the MultiCheXNet architecture on the same multi-label problem on the Chest X-ray 14 dataset.
|Model||Source Dataset||Target Dataset||F1|
|CheXNet-Multilabel||-||Chest X-ray 14||0.31 (0.27, 0.35)|
|MultiCheXNet (CheXCls)||RSNA + SIIM-ACR||Chest X-ray 14||0.36 (0.31, 0.39)|
As can be seen in Table 4, the fine-tuned MultiCheXNet classification pipeline (CheXCls) outperforms the baseline CheXNet architecture in F1 score (0.36 vs. 0.31), which again supports the added value of MTL training.
6 Conclusion and Future work
In this work, we presented MultiCheXNet, an MTL framework for the diagnosis of Pneumonia-like diseases in X-ray scans. Our model can perform three diagnosis tasks at a time using one model: classification, segmentation and detection of the disease area. Moreover, we carefully designed a training protocol to help the convergence of the overall MTL architecture, making use of all the available datasets even though they are disjointly labeled for the tasks. This opens the door to incorporating more datasets as they become needed and available. Finally, we demonstrated the ability to transfer the learned models to new disease classes that may suffer from data scarcity, as in the COVID-19 case. Future work shall include a careful ablation study of the effect of the positive cases classifier and of teacher forcing, per-disease-class evaluation, and a better statistical evaluation methodology.
This is a work in progress, representing a step towards employing MTL techniques in computer-vision-aided diagnosis from radiology images. However, it suffers from some issues in the evaluation methodology that the authors are aware of and that will be addressed in future work. For this reason, this work is not intended to be used as is, or deployed in real treatment protocols, diagnosis or hospitals, due to the lack of sufficient experimentation and statistically sound evaluation methodologies.
References
- Basu and Mitra (2020). Deep learning for screening covid-19 using chest x-ray images. arXiv preprint arXiv:2004.10507.
- Cohen et al. (2020). COVID-19 image data collection. arXiv 2003.11597.
- Jaiswal et al. (2019). Identifying pneumonia in chest x-rays: a deep learning approach. Measurement 145, pp. 511–518.
- Lakhani and Sundaram (2017). Deep learning at chest radiography: automated classification of pulmonary tuberculosis by using convolutional neural networks. Radiology 284 (2), pp. 574–582.
- Maguolo and Nanni (2020). A critic evaluation of methods for covid-19 automatic detection from x-ray images. arXiv preprint arXiv:2004.12823.
- Rajpurkar et al. (2017). Chexnet: radiologist-level pneumonia detection on chest x-rays with deep learning. arXiv preprint arXiv:1711.05225.
- RSNA (2018). Kaggle rsna pneumonia detection challenge.
- SIIM-ACR (2018). SIIM-acr pneumothorax segmentation.
- Sirazitdinov et al. (2019). Deep neural network ensemble for pneumonia localization from a large-scale chest x-ray database. Computers & Electrical Engineering 78, pp. 388–399.
- Tartaglione et al. (2020). Unveiling covid-19 from chest x-ray with deep learning: a hurdles race with small data. arXiv preprint arXiv:2004.05405.
- Teichmann et al. (2018). Multinet: real-time joint semantic reasoning for autonomous driving. In 2018 IEEE Intelligent Vehicles Symposium (IV), pp. 1013–1020.
- Wang and Wong (2020). COVID-net: a tailored deep convolutional neural network design for detection of covid-19 cases from chest x-ray images. arXiv preprint arXiv:2003.09871.
- Wang et al. (2017). Hospital-scale chest x-ray database and benchmarks on weakly-supervised classification and localization of common thorax diseases. In IEEE CVPR.