Diagnose like a Radiologist: Attention Guided Convolutional Neural Network for Thorax Disease Classification

01/30/2018, by Qingji Guan et al., Beijing Jiaotong University

This paper considers the task of thorax disease classification on chest X-ray images. Existing methods generally use the global image as input for network learning. Such a strategy is limited in two aspects. 1) A thorax disease usually happens in (small) localized areas which are disease specific. Training CNNs using global images may be affected by the (excessive) irrelevant noisy areas. 2) Due to the poor alignment of some CXR images, the existence of irregular borders hinders the network performance. In this paper, we address the above problems by proposing a three-branch attention guided convolutional neural network (AG-CNN). AG-CNN 1) learns from disease-specific regions to avoid noise and improve alignment, and 2) integrates a global branch to compensate for the discriminative cues lost by the local branch. Specifically, we first learn a global CNN branch using global images. Then, guided by the attention heat map generated from the global branch, we infer a mask to crop a discriminative region from the global image. The local region is used for training a local CNN branch. Lastly, we concatenate the last pooling layers of both the global and local branches for fine-tuning the fusion branch. Comprehensive experiments are conducted on the ChestX-ray14 dataset. We first report a strong global baseline producing an average AUC of 0.841 with ResNet-50 as backbone. After combining the local cues with the global information, AG-CNN improves the average AUC to 0.868. When DenseNet-121 is used, the average AUC reaches 0.871, a new state of the art in the community.


I Introduction

The chest X-ray (CXR) has been one of the most common radiological examinations in lung and heart disease diagnosis. Currently, reading CXRs mainly relies on professional knowledge and careful manual observation. Due to the complex pathologies and subtle texture changes of different lung lesions in images, radiologists may make mistakes even with long-term clinical training and professional guidance. Therefore, it is important to develop CXR image classification methods to support clinical practitioners. Noticeable progress in deep learning has benefited many trials in medical image analysis, such as lesion segmentation or detection

[1, 2, 3, 4, 5], disease classification [6, 7, 8, 9], noise induction [10], image annotation [11, 12], registration [13], regression [14] and so on. In this paper, we investigate the CXR classification task using deep learning.

Several existing works on CXR classification typically employ the global image for training. For example, Wang et al. [9] evaluate four classic CNN architectures, i.e., AlexNet [15], VGGNet [16], GoogLeNet [17] and ResNet [18], to identify the presence of multiple pathologies from a global CXR image. In addition, using the same network, the disease lesion areas are located in a weakly supervised manner. Viewing CXR classification as a multi-label recognition problem, Yao et al. [19] explore the correlation among the 14 pathology labels with global images in ChestX-ray14 [9]. Using a variant of DenseNet [20] as an image encoder, they adopt Long Short-Term Memory networks (LSTM) [21] to capture the dependencies. Kumar et al. [7] investigate which loss function is more suitable for training CNNs from scratch and present a boosted cascaded CNN for global image classification. A recent effective method is CheXNet [8], which fine-tunes a 121-layer DenseNet with a modified last fully connected layer on global chest X-ray images.

However, the global learning strategy can be compromised by two problems. On the one hand, as shown in Fig. 1 (first row), the lesion area can be very small (red bounding box) and its position unpredictable (e.g., "Atelectasis") compared with the global image, so using the global image for classification may include a considerable level of noise outside the lesion area. This problem is rather different from generic image classification [22, 23], where the object of interest is usually positioned in the image center. Considering this fact, it is beneficial to induce the network to focus on the lesion regions when making predictions. On the other hand, due to variations in capturing conditions, e.g., the posture of the patient and the small body size of children, CXR images may undergo distortion or misalignment. Fig. 1 (second row) presents a misalignment example. The irregular image borders may have a non-negligible effect on classification accuracy. Therefore, it is desirable to discover the salient lesion regions and thus alleviate the impact of such misalignment.

To address the problems caused by merely relying on the global CXR image, this paper introduces a three-branch attention guided convolutional neural network (AG-CNN) to classify lung and heart diseases. AG-CNN is featured in two aspects. First, it focuses on the local lesion regions, which are disease specific. Such a strategy is particularly effective for diseases like "Nodule" that have a small lesion region; in this manner, the impact of noise in non-disease regions and of misalignment can be alleviated. Second, AG-CNN has three branches, i.e., a global branch, a local branch and a fusion branch. While the local branch exhibits the attention mechanism, it may lead to information loss in cases where the lesion areas are distributed over the whole image, such as Pneumonia. Therefore, a global branch is needed to compensate for this loss. We show that the global and local branches are complementary to each other and, once fused, yield accuracy favorable to the state of the art.

Fig. 1: Two training images from the ChestX-ray14 dataset. (a) The global images. (b) Heat maps extracted from a specific convolutional layer. (c) The cropped images from (a) guided by (b). In this paper, we consider both the original global image and the cropped local image for classification, so that 1) the noise contained in non-lesion areas is less influential, and 2) the misalignment can be reduced. Note that there are some differences between the global images and their heat maps. The reason is that the global images are randomly cropped from 256×256 to 224×224 during training.

The working mechanism of AG-CNN is similar to that of a radiologist. We first learn a global branch that takes the global image as input: a radiologist may first browse the whole CXR image. Then, we discover and crop a local lesion region and train a local branch: a radiologist will concentrate on the local lesion area after the overall browse. Finally, the global and local branches are fused to fine-tune the whole network: a radiologist will comprehensively consider the global and local information before making decisions.

Our contributions are summarized as follows.

  • We propose an attention guided convolutional neural network (AG-CNN) which diagnoses thorax diseases by combining the global and local information. AG-CNN improves the recognition performance by correcting image alignment and reducing the impact of noise.

  • We introduce a CNN training baseline, which by itself produces results competitive with the state-of-the-art methods.

  • We present comprehensive experiments on the ChestX-ray14 dataset. The experimental results demonstrate that our method achieves superior performance over the state-of-the-art approaches.

II Related Work

Chest X-ray datasets. The problem of chest X-ray image classification has been extensively explored in the field of medical image analysis. Several datasets have been released in this context. For example, the JSRT dataset [24, 25] contains 247 chest X-ray images including 154 with lung nodules. It also provides masks of the lung area for segmentation performance evaluation. The Shenzhen chest X-ray set [26] has a total of 662 images belonging to two categories (normal and tuberculosis (TB)): 326 are normal cases and 336 are cases with TB. The Montgomery County chest X-ray set (MC) [26] collects 138 frontal chest X-ray images from Montgomery County's tuberculosis screening program, of which 80 are normal and 58 show manifestations of TB. These three datasets are generally too small for deep model training. In comparison, the Indiana University Chest X-ray Collection [27] contains 3,955 radiology reports and the corresponding 7,470 chest X-ray images. It is publicly available through Open-I [28]. However, this dataset does not provide explicit disease class labels, so we do not use it in this paper. Recently, Wang et al. [9] released ChestX-ray14, by far the largest chest X-ray dataset. ChestX-ray14 collects 112,120 frontal-view chest X-ray images of 30,805 unique patients. Each radiograph is labeled with one or more of 14 common thorax diseases. This dataset poses a multi-label classification problem and is large enough for deep learning, so we adopt it for performance evaluation in this paper.

Fig. 2: Overall framework of the attention guided convolutional neural network (AG-CNN), shown with ResNet-50 as backbone. AG-CNN consists of three branches. The global and local branches each consist of five convolutional blocks with batch normalization and ReLU, followed by a max pooling layer (Pool5), a fully connected (FC) layer, and a sigmoid layer. Different from the global branch, the input of the local branch is a local lesion patch cropped using the mask generated from the global branch. The Pool5 layers of these two branches are then concatenated into the fusion branch. "BCE" represents the binary cross-entropy loss. The input image is added to the heat map for visualization.

Deep learning for chest X-ray image analysis. Recent surveys [29, 30, 31, 32] have demonstrated that deep learning technologies have been extensively applied to chest X-ray image annotation [33], classification [6, 34, 8, 9], and detection (localization) [35, 36]. Islam et al. [34] explore different CNN architectures and find that a single CNN does not perform well across all abnormalities; they therefore leverage model ensembles to improve classification accuracy, at the cost of increased training and testing time. Yao et al. [19] and Kumar et al. [7] classify chest X-ray images by investigating the potential dependencies among the labels from the perspective of multi-label problems. Rajpurkar et al. [8] train a convolutional neural network to address the multi-label classification problem. This paper departs from these methods in that we make use of the attention mechanism and fuse local and global information to improve classification performance.

Attention models in medical image analysis. The CXR classification problem requires telling the relatively subtle differences between diseases. Usually, a disease is characterized by a lesion region, which contains critical cues for classification. Ypsilantis et al. [37] explore where to look in chest X-rays with a recurrent attention model (RAM) [38]. The RAM learns to sample the entire X-ray image sequentially and focus on informative areas; only one disease, enlarged heart, is considered in their work. Recently, Pesce et al. [39] explore a soft attention mechanism applied to the saliency map of CNN features to locate lung nodule positions in radiographs, and a localization loss is calculated by comparing the predicted position with the annotated position.

In this paper, AG-CNN locates the salient regions with an attention guided mask inference process and learns discriminative features for classification. Compared with methods that rely on bounding box annotations, our method only needs image-level labels without any extra information.

III The Proposed Approach

In this section, we describe the proposed attention guided convolutional neural network (AG-CNN) for thorax disease classification. We first illustrate the architecture of AG-CNN in Section III-A. Second, we describe the mask inference process for lesion region discovery in Section III-B. We then present the training process of AG-CNN in Section III-C. Finally, a brief discussion of AG-CNN is provided.

III-A Structure of AG-CNN

The architecture of AG-CNN is presented in Fig. 2. Basically, it has two major branches, i.e., the global and local branches, and a fusion branch. Both the global and local branches are classification networks that predict whether the pathologies are present or not in the image. Given an image, the global branch is first fine-tuned from a classification CNN using the global image. Then, we crop an attended region from the global image and train it for classification on the local branch. Finally, the last pooling layers of both the global and local branches are concatenated for fine-tuning the fusion branch.

Multi-label setup. We label each image with a 15-dimensional vector $L = [l_1, l_2, \ldots, l_{15}]$, in which $l_c \in \{0, 1\}$ represents whether the $c$-th pathology is present, i.e., 1 for presence and 0 for absence. The last element of $L$ represents the "No Finding" label.
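To make the label construction concrete, below is a minimal PyTorch sketch (PyTorch is the framework we implement AG-CNN with). The pathology ordering in the list is an illustrative assumption, not a canonical ordering prescribed by the dataset.

```python
import torch

# The 14 ChestX-ray14 pathologies; this ordering is assumed for illustration.
PATHOLOGIES = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass",
    "Nodule", "Pneumonia", "Pneumothorax", "Consolidation", "Edema",
    "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia",
]

def make_label_vector(findings):
    """Build L = [l_1, ..., l_15]; the last element encodes "No Finding"."""
    label = torch.zeros(15)
    for name in findings:
        label[PATHOLOGIES.index(name)] = 1.0
    if not findings:          # no pathology present in this image
        label[14] = 1.0       # set the "No Finding" element
    return label

print(make_label_vector(["Cardiomegaly", "Effusion"]))
```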

Global and local branches. The global branch makes a prediction from the global image. We use a variant of ResNet-50 [18] as the backbone model. It consists of five down-sampling blocks, followed by a global max pooling layer and a 15-dimensional fully connected (FC) layer for classification. At last, a sigmoid layer is added to normalize the output vector of the FC layer:

$\tilde{p}_g(c \mid I_g) = \dfrac{1}{1 + \exp\left(-p_g(c \mid I_g)\right)}$,   (1)

where $I_g$ is the global image and $p_g(c \mid I_g)$ represents the score of $I_g$ belonging to the $c$-th class, $c = 1, \ldots, 15$. We optimize the parameters $W_g$ of the global branch by minimizing the binary cross-entropy (BCE) loss:

$\mathcal{L}(W_g) = -\dfrac{1}{C} \sum_{c=1}^{C} \left[ l_c \log \tilde{p}_g(c \mid I_g) + (1 - l_c) \log\left(1 - \tilde{p}_g(c \mid I_g)\right) \right]$,   (2)

where $l_c$ is the groundtruth label of the $c$-th class and $C = 15$ is the number of classes.
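Eqs. (1) and (2) map directly onto standard PyTorch primitives. The following short sketch uses placeholder scores and labels; the fused `BCEWithLogitsLoss` variant is a numerically stabler equivalent we mention for practicality, not something the method prescribes.

```python
import torch
import torch.nn as nn

scores = torch.randn(8, 15)                    # p_g(c | I_g) for a batch of 8
labels = torch.randint(0, 2, (8, 15)).float()  # groundtruth vectors L

probs = torch.sigmoid(scores)                  # Eq. (1)
loss = nn.BCELoss()(probs, labels)             # Eq. (2), averaged over classes

# Equivalent fused form that folds the sigmoid into the loss:
loss_fused = nn.BCEWithLogitsLoss()(scores, labels)
```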

Fig. 3: The process of lesion area generation. (Top:) global CXR images of various thorax diseases for the global branch. The manually annotated lesion areas provided by [9] are marked with red bounding boxes. Note that we do not use the bounding boxes for training or testing. (Middle:) corresponding visual examples of the output of the mask inference process. The lesion areas are denoted by green bounding boxes. Higher responses are shown in red and lower in blue. Note that the heat maps are resized to the same size as the input images. (Bottom:) cropped and resized images from the green bounding boxes, which are fed to the local branch.

On the other hand, the local branch focuses on the lesion area and is expected to alleviate the drawbacks of only using the global image. In more detail, the local branch has the same convolutional network structure as the global branch. Note that these two branches do not share weights, since they have distinct purposes. We denote the normalized probability score of the local branch as $\tilde{p}_l(c \mid I_l)$ and its parameters as $W_l$, where $I_l$ is the input image of the local branch. We perform the same normalization and optimization as in the global branch.

Fusion branch. The fusion branch first concatenates the Pool5 outputs of the global and local branches. The concatenated layer is connected to a 15-dimensional FC layer for final classification, producing the probability score $\tilde{p}_f(c \mid [I_g, I_l])$. We denote $W_f$ as the parameters of the fusion branch and optimize them by Eq. 2.
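Putting the three branches together, a condensed architectural sketch might look as follows. The helper names (`Branch`, `AGCNN`) and the use of torchvision's ResNet-50 are our assumptions for illustration; the text above specifies only the overall layout (two non-shared ImageNet-pretrained trunks, Pool5 max pooling, 15-dim FC layers, and a fusion FC over concatenated Pool5 features).

```python
import torch
import torch.nn as nn
from torchvision import models

class Branch(nn.Module):
    """One classification branch: ResNet-50 trunk + Pool5 + 15-dim FC."""
    def __init__(self, num_classes=15):
        super().__init__()
        trunk = models.resnet50(weights="IMAGENET1K_V1")   # ImageNet init
        self.features = nn.Sequential(*list(trunk.children())[:-2])
        self.pool = nn.AdaptiveMaxPool2d(1)                # global max pool
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        pool5 = self.pool(self.features(x)).flatten(1)
        return self.fc(pool5), pool5       # raw logits and Pool5 features

class AGCNN(nn.Module):
    """Global branch, local branch (separate weights), and fusion FC."""
    def __init__(self, num_classes=15):
        super().__init__()
        self.global_branch = Branch(num_classes)
        self.local_branch = Branch(num_classes)
        self.fusion_fc = nn.Linear(2048 * 2, num_classes)

    def forward(self, img_global, img_local):
        logits_g, pool_g = self.global_branch(img_global)
        logits_l, pool_l = self.local_branch(img_local)
        logits_f = self.fusion_fc(torch.cat([pool_g, pool_l], dim=1))
        return logits_g, logits_l, logits_f
```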

Input: Input image $I_g$; label vector $L$; threshold $\tau$.
Output: Probability score $\tilde{p}_f$.
Initialization: the global and local branch weights.
Learning $W_g$ with $I_g$, computing $\tilde{p}_g$, optimizing by Eq. 2 (Stage I); computing mask $M$ and the bounding box coordinates $[x_{\min}, y_{\min}, x_{\max}, y_{\max}]$, cropping $I_l$ out from $I_g$; learning $W_l$ with $I_l$, computing $\tilde{p}_l$, optimizing by Eq. 2 (Stage II); concatenating $\mathrm{Pool5}_g$ and $\mathrm{Pool5}_l$, learning $W_f$, computing $\tilde{p}_f$, optimizing by Eq. 2 (Stage III).
Algorithm 1 Attention Guided CNN Procedure

III-B Attention Guided Mask Inference

In this paper, we construct a binary mask to locate the discriminative regions for classification in the global image. It is produced by performing thresholding operations on the feature maps, which can be regarded as an attention process. This process is described below.

Given a global image, let $f_g^k(x, y)$ represent the activation of spatial location $(x, y)$ in the $k$-th channel of the output of the last convolutional layer, where $k \in \{1, \ldots, K\}$, $K = 2048$ in ResNet-50, and $g$ denotes the global branch. We first take the absolute value of the activations at position $(x, y)$. Then the attention heat map $H_g$ is generated by taking the maximum value along the channel dimension,

$H_g(x, y) = \max_{k} \left| f_g^k(x, y) \right|$.   (3)

The values in $H_g$ directly indicate the importance of the activations for classification. In Fig. 1(b) and Fig. 3 (second row), some examples of the heat maps are shown. We observe that the discriminative regions (lesion areas) of the images are activated. The heat map can also be constructed from other statistics across the channel dimension, such as the L1 norm $\sum_k |f_g^k(x, y)|$ or the L2 norm $\sqrt{\sum_k f_g^k(x, y)^2}$. Different statistics result in subtle numerical differences in the heat map, but may not affect the classification significantly. Therefore, we compute the heat map with Eq. 3 in our experiments. A comparison of these statistics is presented in Section IV-C.
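A minimal sketch of Eq. (3) in PyTorch, along with the L1/L2 alternatives just mentioned. The 7×7 spatial size assumes a 224×224 input to ResNet-50; the tensor here is a random placeholder for the last convolutional layer's output.

```python
import torch

features = torch.randn(2048, 7, 7)            # f_g^k(x, y), K = 2048 channels
heat_map = features.abs().max(dim=0).values   # Eq. (3): H_g(x, y), shape (7, 7)

# The alternative channel statistics discussed above:
heat_map_l1 = features.abs().sum(dim=0)       # L1 norm per location
heat_map_l2 = features.pow(2).sum(dim=0).sqrt()  # L2 norm per location
```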

We design a binary mask $M$ to locate the regions with large activation values. If the value at a certain spatial position in the heat map is larger than a threshold $\tau$, the value at the corresponding position in the mask is assigned 1. Specifically,

$M(x, y) = \begin{cases} 1, & \text{if } H_g(x, y) > \tau, \\ 0, & \text{otherwise}, \end{cases}$   (4)

where $\tau$ is the threshold that controls the size of the attended region. A larger $\tau$ leads to a smaller region, and vice versa. With the mask $M$, we draw the maximum connected region that covers the discriminative points in $M$, denoted by its minimum and maximum coordinates along the horizontal and vertical axes, $[x_{\min}, y_{\min}, x_{\max}, y_{\max}]$. At last, the local discriminative region $I_l$ is cropped from the input image and resized to the same size as $I_g$. We visualize the bounding boxes and cropped patches with $\tau = 0.7$ in Fig. 3. The attention guided mask inference method is able to locate regions (green bounding boxes) that are reasonably close to the groundtruth (red bounding boxes).
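The following is a sketch of Eq. (4) plus the cropping step, under two stated assumptions: the heat map is min-max normalized so that $\tau \in (0, 1)$ is meaningful, and scipy's `ndimage.label` is used for the connected-region step (the text does not name a particular tool).

```python
import numpy as np
import torch
import torch.nn.functional as F
from scipy import ndimage   # assumed choice for connected components

def crop_local_region(image, heat_map, tau=0.7):
    """image: (3, H, W) tensor; heat_map: (h, w) tensor from Eq. (3)."""
    size = image.shape[1:]
    # Upsample the heat map to input resolution; normalize to [0, 1] (assumed).
    H = F.interpolate(heat_map[None, None], size=size, mode="bilinear",
                      align_corners=False)[0, 0]
    H = (H - H.min()) / (H.max() - H.min() + 1e-8)

    mask = (H > tau).numpy()                 # Eq. (4): binary mask M
    labeled, n = ndimage.label(mask)         # connected components
    if n == 0:                               # nothing above threshold
        return image
    sizes = ndimage.sum(mask, labeled, range(1, n + 1))
    largest = int(np.argmax(sizes)) + 1      # maximum connected region
    ys, xs = np.nonzero(labeled == largest)
    patch = image[:, ys.min():ys.max() + 1, xs.min():xs.max() + 1]
    # Resize the cropped region back to the global image size.
    return F.interpolate(patch[None].float(), size=size, mode="bilinear",
                         align_corners=False)[0]
```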

Fig. 4: Examples of 8 pathologies in ChestX-ray14. The lesion regions are annotated with the red bounding boxes provided by [9]. Note that these groundtruth bounding boxes are only used for demonstration: they are neither used in training nor testing.
Method CNN Atel Card Effu Infi Mass Nodu Pne1 Pne2 Cons Edem Emph Fibr PT Hern Mean
Wang et al. [9] R-50 0.716 0.807 0.784 0.609 0.706 0.671 0.633 0.806 0.708 0.835 0.815 0.769 0.708 0.767 0.738
Yao et al. [19] D (variant) 0.772 0.904 0.859 0.695 0.792 0.717 0.713 0.841 0.788 0.882 0.829 0.767 0.765 0.914 0.803
Rajpurkar et al. [8] D-121 0.821 0.905 0.883 0.720 0.862 0.777 0.763 0.893 0.794 0.893 0.926 0.804 0.814 0.939 0.842
Kumar et al. [7] D-161 0.762 0.913 0.864 0.692 0.750 0.666 0.715 0.859 0.784 0.888 0.898 0.756 0.774 0.802 0.795
Global branch (baseline) R-50 0.818 0.904 0.881 0.728 0.863 0.780 0.783 0.897 0.807 0.892 0.918 0.815 0.800 0.889 0.841
Local branch R-50 0.798 0.881 0.862 0.707 0.826 0.736 0.716 0.872 0.805 0.874 0.898 0.808 0.770 0.887 0.817
AG-CNN R-50 0.844 0.937 0.904 0.753 0.893 0.827 0.776 0.919 0.842 0.919 0.941 0.857 0.836 0.903 0.868
Global branch (baseline) D-121 0.832 0.906 0.887 0.717 0.870 0.791 0.732 0.891 0.808 0.905 0.912 0.823 0.802 0.883 0.840
Local branch D-121 0.797 0.865 0.851 0.704 0.829 0.733 0.710 0.850 0.802 0.882 0.874 0.801 0.769 0.872 0.810
AG-CNN D-121 0.853 0.939 0.903 0.754 0.902 0.828 0.774 0.921 0.842 0.924 0.932 0.864 0.837 0.921 0.871
  • We compute the AUC of each class and the average AUC across the 14 diseases. Some compared methods use a different train/test split: 80% for training and the remaining 20% for testing. All the other methods split the dataset with 70% for training, 10% for validation and 20% for testing. Each pathology is denoted by its first four characters, e.g., Atelectasis by Atel; Pneumonia and Pneumothorax are denoted as Pne1 and Pne2, respectively. PT represents Pleural Thickening. We report the performance with parameter τ = 0.7. ResNet-50 (R-50) and DenseNet-121 (D-121) are used as backbones in our approach. For each column, the best and second best results are highlighted in red and blue, respectively.

TABLE I: Comparison results of various methods on ChestX-ray14.
Fig. 5: ROC curves of the global, local and fusion branches (DenseNet-121 as backbone) over the 14 pathologies. The corresponding AUC values are given in Table I. We observe that fusing global and local information yields a clear improvement.

III-C Training Strategy of AG-CNN

This paper adopts a three-stage training scheme for AG-CNN.

Stage I. Using the global images, we fine-tune the global branch network pretrained on ImageNet. The probability score $\tilde{p}_g$ is normalized by Eq. 1.

Stage II. Once the local image $I_l$ is obtained by mask inference with threshold $\tau$, we feed it into the local branch for fine-tuning. $\tilde{p}_l$ is also normalized by Eq. 1. When we fine-tune the local branch, the weights in the global branch are fixed.

Stage III. Let $\mathrm{Pool5}_g$ and $\mathrm{Pool5}_l$ represent the Pool5 layer outputs of the global and local branches, respectively. We concatenate them for a final stage of fine-tuning and normalize the probability score by Eq. 1. Similarly, the weights of the previous two branches are fixed when we fine-tune the weights of the fusion branch.

In each stage, we use the model with the highest AUC on the validation set for testing. The overall AG-CNN training procedure is presented in Algorithm 1. Variants of the training strategy may influence the performance of AG-CNN; we discuss this in Section IV-C.
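A minimal sketch of the three-stage scheme, assuming the `AGCNN` module and `crop_local_region` helper sketched above; `loader` (the training data loader) is assumed to exist. Only the freezing and ordering logic of Algorithm 1 is shown, with the optimizer settings from Section IV-B.

```python
import torch

bce = torch.nn.BCEWithLogitsLoss()
model = AGCNN()

def set_trainable(module, flag):
    for p in module.parameters():
        p.requires_grad = flag

def run_stage(params, step_fn, epochs=50):
    # lr 0.01, divided by 10 after 20 epochs; momentum 0.9, weight decay 1e-4
    opt = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.1)
    for _ in range(epochs):
        for imgs, labels in loader:
            loss = step_fn(imgs, labels)
            opt.zero_grad(); loss.backward(); opt.step()
        sched.step()

def crop_batch(x):
    """Mask inference per image: heat map (Eq. 3) then crop (Eq. 4)."""
    with torch.no_grad():
        feats = model.global_branch.features(x)        # (B, 2048, 7, 7)
    return torch.stack([crop_local_region(img, f.abs().max(0).values)
                        for img, f in zip(x, feats)])

# Stage I: fine-tune the global branch alone.
run_stage(model.global_branch.parameters(),
          lambda x, y: bce(model.global_branch(x)[0], y))

# Stage II: freeze the global branch, train the local branch on patches.
set_trainable(model.global_branch, False)
run_stage(model.local_branch.parameters(),
          lambda x, y: bce(model.local_branch(crop_batch(x))[0], y))

# Stage III: freeze both branches, fine-tune only the fusion FC layer.
set_trainable(model.local_branch, False)
run_stage(model.fusion_fc.parameters(),
          lambda x, y: bce(model(x, crop_batch(x))[2], y))
```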

IV Experiment

This section evaluates the performance of the proposed AG-CNN. The experimental dataset, evaluation protocol and experimental settings are introduced first. Section IV-C demonstrates the performance of the global and local branches and the effectiveness of fusing them. Furthermore, a comparison of AG-CNN with the state of the art is presented in Table I. In Section IV-D, we analyze the impact of the threshold parameter in mask inference.

IV-A Dataset and Evaluation Protocol

Dataset. We evaluate the AG-CNN framework using the ChestX-ray14 dataset [9] (https://nihcc.app.box.com/v/ChestXray-NIHCC). ChestX-ray14 collects 112,120 frontal-view images of 30,805 unique patients. Of these, 51,708 images are labeled with up to 14 pathologies, while the others are labeled as "No Finding". Fig. 4 presents examples of 8 out of the 14 thorax diseases and the ground-truth bounding boxes of the lesion regions provided by [9]. We observe that the size of the lesion area varies a lot across pathologies.

Evaluation protocol. In our experiment, we randomly shuffle the dataset into three subsets: 70% for training, 10% for validation and 20% for testing. Each image is labeled with a 15-dimensional vector $L$ as defined in Section III-A, whose last element represents the "No Finding" label. We report the AUC of each pathology and the average AUC across the 14 diseases.
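A short sketch of this protocol's metric computation; the use of scikit-learn is our assumption, and the arrays are random placeholders standing in for test-set labels and predicted probabilities.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.random.randint(0, 2, size=(2000, 14))  # per-image disease labels
y_prob = np.random.rand(2000, 14)                  # predicted probabilities

# AUC per pathology, then the macro average over the 14 classes.
aucs = [roc_auc_score(y_true[:, c], y_prob[:, c]) for c in range(14)]
print("per-class AUC:", np.round(aucs, 3), "mean AUC:", np.mean(aucs))
```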

Fig. 6: ROC curves of AG-CNN on the 14 diseases (ResNet-50 and DenseNet-121 as backbones, respectively).
Fig. 7: Examples of classification results. We present the top-10 predicted categories and the corresponding probability scores. The ground-truth labels are highlighted in blue.
Fig. 8: Average AUC scores of AG-CNN with different settings of $\tau$ on the validation set (ResNet-50 as backbone).
Fig. 9: Average AUCs for different settings of $\tau$ on the test set (ResNet-50 as backbone). Note that the results from the global branch are our baseline.

IV-B Experimental Settings

For training (any of the three stages), we perform data augmentation by resizing the original images to 256×256, random resized cropping to 224×224, and random horizontal flipping. The ImageNet mean value is subtracted from the image. When using ResNet-50 as backbone, we optimize the network using SGD with mini-batch sizes of 128, 64 and 64 for the global, local and fusion branches, respectively. For DenseNet-121, the network is optimized with mini-batch sizes of 64, 32 and 32, respectively. We train each branch for 50 epochs. The learning rate starts from 0.01 and is divided by 10 after 20 epochs. We use a weight decay of 0.0001 and a momentum of 0.9. During validation and testing, we also resize the image to 256×256 and then perform center cropping to obtain an image of size 224×224. Except in Section IV-D, we set $\tau$ to 0.7, which yields the best performance on the validation set. We implement AG-CNN with the PyTorch framework [40].
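These settings translate into torchvision transforms and an SGD schedule roughly as follows. The normalization constants are the standard ImageNet values, which is an assumption beyond "the ImageNet mean value is subtracted"; the `model` below is a placeholder module.

```python
import torch
from torchvision import transforms

# Standard ImageNet statistics (assumed; the text mentions only the mean).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_tf = transforms.Compose([
    transforms.Resize((256, 256)),        # resize originals to 256x256
    transforms.RandomResizedCrop(224),    # random resized crop to 224x224
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

test_tf = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),           # center crop at validation/test time
    transforms.ToTensor(),
    normalize,
])

model = torch.nn.Linear(10, 15)           # placeholder module for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.01,
                            momentum=0.9, weight_decay=1e-4)
# lr starts at 0.01 and is divided by 10 after 20 of the 50 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
```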

IV-C Evaluation

We evaluate our method on the ChestX-ray14 dataset. ResNet-50 [18] is mostly used as backbone, but the AUC scores and ROC curves obtained with DenseNet-121 [20] are also presented.

Global branch (baseline) performance. We first report the performance of the baseline, i.e., the global branch. Results are summarized in Table I, Fig. 5 and Fig. 9.

The average AUC across the 14 thorax diseases reaches 0.841 and 0.840 using ResNet-50 and DenseNet-121, respectively. For both backbone networks, this is a competitive accuracy compared with the previous state of the art. Except for Hernia, the AUC scores of the other 13 pathologies are very close to or even higher than [8]. Moreover, we observe that Infiltration has the lowest recognition accuracy (0.728 and 0.717 for ResNet-50 and DenseNet-121, respectively). This is because the diagnosis of Infiltration mainly relies on texture changes within the lung area, which are challenging to recognize. Cardiomegaly achieves higher recognition accuracy (0.904 and 0.912 for ResNet-50 and DenseNet-121, respectively), as it is characterized by a relatively solid region (the heart).

Performance of the local branch. The local branch is trained on the cropped and resized lesion patches and is supposed to provide an attention mechanism complementary to the global branch. The performance of the local branch is also shown in Table I, Fig. 5 and Fig. 9.

Using ResNet-50 and DenseNet-121, the average AUC score is 0.817 and 0.810, respectively, which is higher than [9, 7]. Despite being competitive, the local branch yields lower accuracy than the global branch. For example, when using ResNet-50, the performance gap is 2.4% (0.841 to 0.817). The probable reason is that the lesion region estimation and cropping process may discard information that is critical for recognition, so the local branch may suffer from inaccurate estimation of the attention area.

Among the 14 classes, the largest performance drop is observed for "Pneumonia" (0.067). The inferior performance on "Pneumonia" is probably caused by loss of information: the inflamed lung area is generally relatively large and its attention heat map shows a scattered distribution, so with a high value of $\tau$ only a very small patch is cropped from the original image. For the classes "Hernia" and "Consolidation", the local and global branches yield very similar accuracy. We speculate that the cropped local patch is consistent with the lesion area in the global image.

Effectiveness of fusing the global and local branches. In Table I, Fig. 5, and Fig. 6, we illustrate the effectiveness of the fusion branch, which yields the final classification results of our model. Table I shows the AUC of AG-CNN over the 14 classes. The observations are consistent across categories and across the two backbones. Fig. 5 presents the ROC curves of the three branches for each pathology, illustrating that fusing the global and local branches clearly improves on both. Fig. 6 presents the ROC curves of the 14 pathologies with the two backbones; they are highly consistent, demonstrating that AG-CNN is not sensitive to the backbone architecture.

For both ResNet-50 and DenseNet-121, the fusion branch, i.e., AG-CNN, outperforms both the global branch and the local branch. For example, when using ResNet-50, the performance gaps from AG-CNN to the global and local branches are 0.027 and 0.051, respectively. Specifically, AG-CNN (with DenseNet-121 as backbone) surpasses the global and local branches on all 14 pathologies.

The advantage of AG-CNN is consistent across the categories. Taking ResNet-50 for example, the largest improvement (0.047) is observed for the class "Nodule", a disease featured by small lesion areas (see Fig. 4). Under such circumstances, the global branch can be largely affected by the noise within non-disease areas. By paying attention to the small yet focused lesion areas, our method effectively improves the classification performance on Nodule. On the other hand, we also notice that for the class Pneumonia, AG-CNN is inferior to the global branch, consistent with the observation that the local branch is least effective on this class. Some classification results are presented in Fig. 7.

Another experiment, in which the global image is input to both the global and local branches, is conducted to verify the effectiveness of fusing global and local cues. The same experimental settings as in Section IV-B are used, except that the mini-batch size is 64 in training. The three branches are trained together with ResNet-50 as backbone. The average AUCs of the global, local and fusion branches reach 0.845, 0.846 and 0.851, respectively. The fusion performance is 0.017 lower than when a local patch is input to the local branch. The results show that AG-CNN is superior to both the global and local branches, and in particular that the improvement comes from the local discriminative region rather than from an increased number of parameters.

Comparison with the state of the art. We compare our results with the state-of-the-art methods [9, 19, 7, 8] on the ChestX-ray14 dataset. Wang et al. [9] classify and localize the thorax diseases in a unified weakly supervised framework; this localization objective actually compromises the classification accuracy. The reported results from Yao et al. [19] are based on the model in which labels are considered independent. Kumar et al. [7] try different boosting methods and cascade the previous classification results for multi-label classification, so the accuracy of each step directly influences the results for the following pathologies.

Compared with these methods, this paper contributes a new state of the art to the community: average AUC = 0.871. AG-CNN exceeds the previous state of the art [8] by 2.9%. The AUC scores of pathologies such as Cardiomegaly and Infiltration are higher than [8] by about 0.03, and those of Mass, Fibrosis and Consolidation surpass [8] by about 0.05. Furthermore, we train AG-CNN with 70% of the dataset, whereas 80% is used in [7, 8]. In nearly all of the 14 classes, our method yields the best performance; only Rajpurkar et al. [8] report higher accuracy on Hernia. Overall, the classification accuracy reported in this paper compares favorably against previous art.

Variant of training strategy analysis. Training the three branches in different orders influences the performance of AG-CNN. We test four orders: 1) train the global branch first, and then the local and fusion branches together (G_LF); 2) train the global and local branches together, and then the fusion branch (GL_F); 3) train the three branches together (GLF); 4) train the global, local and fusion branches sequentially (G_L_F). Note that G_L_F is our three-stage training strategy; we also test a variant, denoted G_L_F*, in which the global and local branches are fine-tuned together with the fusion branch in the last stage. We limit the batch size to 64 when training two or three branches together, as in GL_F and GLF. If the global branch is trained first, the batch sizes of the three branches are set to 128, 64 and 64, respectively. The other experimental settings are the same as in Section IV-B. We present the classification performance of these training strategies in Table II.

AG-CNN yields better performance (0.868 and 0.854) with the strategies that train the three branches sequentially (G_L_F and G_L_F*). When the global branch is trained first, it is the same model as the baseline in Table I. Training with G_L_F, AG-CNN clearly improves the baseline from 0.841 to 0.868. AG-CNN (G_L_F*) performs an overall fine-tuning when training the fusion branch; it improves the global branch performance to 0.852, but not the local and fusion branches. Compared with G_L_F and G_L_F*, the performance of AG-CNN (G_LF) is much lower because of the inaccuracy of its local branch. When AG-CNN is trained with GL_F or GLF, it is inferior to G_L_F and G_L_F*. We infer that training the local branch separately is essential to enhance AG-CNN performance.

Strategy Batchsize Global Local Fusion
GL_F 64/64/64 0.831 0.800 0.833
GLF 64/64/64 0.847 0.815 0.849
G_LF 128/64/64 0.841 0.809 0.843
G_L_F* 128/64/64 0.852 0.819 0.854
G_L_F 128/64/64 0.841 0.817 0.868
  • * denotes that the parameters in the global and local branches are also fine-tuned when we train the fusion branch. ResNet-50 is used as backbone.

TABLE II: Results of different training strategies.
Statistic Global Local Fusion
Max 0.8412 0.8171 0.8680
L1 0.8412 0.8210 0.8681
L2 0.8412 0.8213 0.8672
  • ResNet-50 is used as backbone.

TABLE III: Results corresponding to different statistics.

Variant of heat map analysis. In Table III, we report the performance of different heat map computation methods. Based on the same global baseline, the local branch shows a gap of 0.0042 between Max and L2, while the gap in the fusion branch is only 0.0008. Max and L1 achieve very close performance on both the local and fusion branches. This illustrates that different statistics lead to subtle differences in the local branch, but do not affect the final classification performance significantly.

IV-D Parameter Analysis

We analyze the sensitivity of AG-CNN to parameter variations. The key parameter of AG-CNN is the threshold $\tau$ in Eq. 4, which defines the local regions and affects the classification accuracy. Fig. 8 shows the average AUC of AG-CNN over different values of $\tau$ on the validation set. AG-CNN achieves the best performance when $\tau$ is set to 0.7. Therefore, we report the results on the test set with $\tau = 0.7$.

Fig. 9 compares the average AUCs of the global, local and fusion branches on the test set when ResNet-50 is used as the backbone, with $\tau$ varying from 0.1 to 0.9. When $\tau$ is small (e.g., close to 0), the local region is close to the global image. For example, when $\tau = 0.1$, the average AUC of the local branch (0.828) is close to the result of the global branch (0.841). In such cases, most of the entries in the attention heat map are preserved, indicating that the cropped image patches are close to the original input. On the other hand, when $\tau$ approaches 1, e.g., $\tau = 0.9$, the local branch is inferior to the global branch by a large margin (0.9%). Under this circumstance, most of the information in the global image is discarded and only the top 10% largest values in the attention heat map are retained, so the cropped image patches cover very small regions.

Unlike the local branch, AG-CNN is relatively stable to changes of the threshold $\tau$. When the global and local branches are fused, AG-CNN outperforms both individual branches by at least 1.7%. AG-CNN exhibits the highest AUC (0.866) when $\tau$ ranges within [0.6, 0.8].

V Conclusion

In this paper, we propose an attention guided convolutional neural network (AG-CNN) for thorax disease classification. The proposed network is trained by considering both the global and local cues informed by the global and local branches, respectively. Departing from previous works that merely rely on global information, it uses attention heat maps to mask the important regions, which are then used to train the local branch. Extensive experiments demonstrate that combining global and local cues yields state-of-the-art accuracy on the ChestX-ray14 dataset. We also demonstrate that our method is relatively insensitive to parameter changes.

In future research, we will continue the study in two directions. First, we will investigate more accurate localization of the lesion areas. Second, to tackle the difficulties in sample collection and annotation, semi-supervised learning methods will be explored.

References