Adversarial Policy Gradient for Deep Learning Image Augmentation

by Kaiyang Cheng, et al.
UC San Francisco
berkeley college

The use of semantic segmentation for masking and cropping input images has proven to be a significant aid in medical imaging classification tasks by decreasing the noise and variance of the training dataset. However, implementing this approach with classical methods is challenging: the cost of obtaining a dense segmentation is high, and the precise input area that is most crucial to the classification task is difficult to determine a priori. We propose a novel joint-training deep reinforcement learning framework for image augmentation. A segmentation network, weakly supervised with policy gradient optimization, acts as an agent and outputs masks as actions given samples as states, with the goal of maximizing reward signals from the classification network. In this way, the segmentation network learns to mask unimportant imaging features. Our method, Adversarial Policy Gradient Augmentation (APGA), shows promising results on Stanford's MURA dataset and on a hip fracture classification task, with an increase in global accuracy of up to 7.33% and improved performance over baseline methods in 9/10 tasks evaluated. We discuss the broad applicability of our joint training strategy to a variety of medical imaging tasks.


1 Introduction

Convolutional neural networks (CNNs) have become an essential part of medical image acquisition, reconstruction, and post-processing pipelines, as this technology significantly improves our ability to detect, study, and predict diseases at scale. In computer vision, CNNs have achieved above-human performance in natural-object classification tasks [8]. However, in medical imaging, where datasets are limited in size and labels are often uncertain, there is still a significant need for methods that maximize information gain while preventing overfitting. Our work presents a novel reinforcement-learning (RL) based image augmentation framework for medical image classification.

Training data image augmentation imposes image-level regularization on CNNs in order to combat overfitting. Augmentation can include the addition of noise, image transformations such as zooming or cropping, image occlusion, and attention masking [10]. The application of the first three methods is limited, as they often rely on domain knowledge to define the appropriate characteristics and severity of the augmentation. The last method requires dense segmentation masks for the region of interest (ROI). However, ROIs relevant to a classification task may not be known a priori. For instance, when inspecting hip radiographs for bone fracture (Fx), the determination of Fx or no-Fx heavily depends on the location of abnormal signal within the bone and on abnormalities in nearby tissues. Our RL image augmentation framework leverages an adversarial reward to weakly supervise a segmentation network and create these ROIs.

Through trial and error, reinforcement learning algorithms discover how to maximize their objective or reward R. The careful design of reward functions has enabled the application of RL to multiple medical tasks, including landmark detection, automatic view planning, treatment planning, and MR/CT reconstruction. In this work, we present APGA, a joint reinforcement learning strategy for training a segmentation network and a classification network that improves accuracy in the classification task.

2 Methodology

2.1 Improving Classification with Segmentation for Image Masking

The framework has two parts: a classification model $f_c$ with parameters $\theta_c$ and a segmentation model $f_s$ with parameters (policy) $\theta_s$.

To mask out the image-level features that are less useful for the classification task, we use the segmentation model to produce the pixel-wise probability $p = f_s(x; \theta_s)$ of each pixel being useful. We zero out the pixels of the original image $x$ that $p$ marks as not useful, and use the masked image to train the classification model, updating $\theta_c$. With the end goal of improving classification performance, our method optimizes the segmentation model to evaluate the importance of each image pixel.
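The masking step can be illustrated with a minimal NumPy sketch. The function name `mask_unimportant` and the 0.5 cut-off are assumptions made for illustration; the text above does not fix how the probability map is discretized.

```python
import numpy as np


def mask_unimportant(image, prob_useful, threshold=0.5):
    """Zero out pixels whose predicted usefulness falls below `threshold`.

    `threshold=0.5` is an assumed cut-off, not a value taken from the paper.
    Returns the masked image and the boolean keep-mask.
    """
    image = np.asarray(image, dtype=np.float32)
    prob_useful = np.asarray(prob_useful, dtype=np.float32)
    if image.shape != prob_useful.shape:
        raise ValueError("image and prob_useful must have the same shape")
    keep = prob_useful >= threshold  # True where the pixel is kept
    return image * keep, keep


# Toy 2x2 "image" and a hypothetical probability map from the segmentation model.
image = np.array([[0.2, 0.9], [0.5, 0.7]], dtype=np.float32)
prob_useful = np.array([[0.1, 0.8], [0.6, 0.3]], dtype=np.float32)
masked, keep = mask_unimportant(image, prob_useful)
```

In this toy case the top-left and bottom-right pixels are zeroed, and the masked array is what would be passed to the classification model.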

2.2 Policy Gradient Training

Following the policy gradient context [11], our segmentation model is seen as a policy in which the image batch is treated as a state $s$, whereas the pixel-wise classification of useful image features is framed as an action $a$. In practice, at each $t$-th training step, the classification model receives the masked image as input and outputs a reward signal $R$. Our objective is to maximize the expected reward and thereby find the optimal segmentation policy:

$$J(\theta_s) = \mathbb{E}_{P(a \mid s; \theta_s)}\left[ R \right] \qquad (1)$$

In eq. 1, $J(\theta_s)$ is the expected reward with respect to the probability $P(a \mid s; \theta_s)$ of taking an action $a$ when the segmentation model is parameterized with $\theta_s$. The policy is learned through back-propagation, which requires the gradient of the expected reward with respect to the model parameters. Following the REINFORCE rule presented in [11], the gradient can be defined as

$$\nabla_{\theta_s} J(\theta_s) = \mathbb{E}_{P(a \mid s; \theta_s)}\left[ R \, \nabla_{\theta_s} \log P(a \mid s; \theta_s) \right] \qquad (2)$$

The expected reward cannot be computed exactly and requires approximation. As is common practice in the treatment of policy gradients, we approximate it using the negative log-likelihood loss, which is differentiable with respect to the model parameters and can be weighted by the reward signal to obtain the segmentation policy loss presented in eq. 3,

$$L_s(\theta_s) = R \cdot \ell(p, a) \qquad (3)$$

where $\ell$ is the binary cross-entropy loss between the pixel-wise probabilities $p$ output by the segmentation model and the actions $a$ taken, summed over pixels $i$,

$$\ell(p, a) = -\sum_{i} \left[ a_i \log p_i + (1 - a_i) \log (1 - p_i) \right] \qquad (4)$$

so that eq. 3 becomes eq. 5.

$$L_s(\theta_s) = -R \sum_{i} \left[ a_i \log p_i + (1 - a_i) \log (1 - p_i) \right] \qquad (5)$$

By using pixel-wise binary cross-entropy, we preserve spatial information about the deviation of $p$ from $a$. We then update $\theta_s$ by computing $\nabla_{\theta_s} L_s$. Subsequently, the classification model parameters $\theta_c$ are updated with gradient descent using the cross-entropy loss between the classification of the masked image samples and the original target labels. In our experiments, we perform a stochastic gradient update for both $\theta_c$ and $\theta_s$ at each batch step.
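The reward-weighted cross-entropy described above can be sketched in a few lines of NumPy. This is an illustration under stated assumptions rather than the authors' implementation: the helper names are hypothetical, actions are assumed to be sampled from a per-pixel Bernoulli distribution, and the loss is summed over pixels (using a mean instead would only rescale it).

```python
import numpy as np


def pixelwise_bce(prob, action, eps=1e-7):
    """Per-pixel binary cross-entropy between probabilities and binary actions."""
    prob = np.clip(np.asarray(prob, dtype=np.float64), eps, 1.0 - eps)
    action = np.asarray(action, dtype=np.float64)
    return -(action * np.log(prob) + (1.0 - action) * np.log(1.0 - prob))


def policy_loss(prob, action, reward):
    """Reward-weighted BCE: a surrogate loss whose gradient follows REINFORCE."""
    return float(reward * pixelwise_bce(prob, action).sum())


rng = np.random.default_rng(0)
prob = np.array([[0.9, 0.2], [0.6, 0.4]])  # hypothetical segmentation output
# Assumed sampling scheme: each pixel's action is drawn from Bernoulli(prob).
action = (rng.random(prob.shape) < prob).astype(np.float64)

loss_pos = policy_loss(prob, action, reward=1.0)   # minimizing makes `action` more likely
loss_neg = policy_loss(prob, action, reward=-1.0)  # minimizing makes `action` less likely
```

A positive reward therefore pulls the probabilities toward the sampled action, while a negative reward pushes them away, which is the behavior the policy gradient update relies on.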

Figure 1: (top) Adversarial training of segmentation agent and (bottom) training of classification model with original and augmented data

2.3 Adversarial Reward

The design of the reward is crucial to the convergence of the segmentation model. Using the change in training loss as a reward, as is done in Neural Architecture Search [12], results in a weak reward signal that is hardly discernible from the expected changes in loss during training. Similarly, approximating rewards with a critic network introduces unnecessary overhead and slows down convergence. We propose a stable adversarial reward $R$. Given the pixel-wise feature importance probability $p$, we zero out the pixels predicted to be of high importance. The original and masked image batches are then fed as inputs to the classification model, producing the losses $L_c$ and $L_{adv}$, respectively. The reward function is defined as:


To reduce the variance of training, a baseline $b$, the exponential moving average of the reward, is included, similar to [12]. Intuitively, by erasing the important features we invert the problem and seek to maximize the gain in loss. However, we do not want the segmentation policy to erase all pixels in favor of a gain in reward, so we penalize the masking of all pixels. Given a pixel-wise all-zero feature importance map and a regularization weight $\lambda$, the final loss is defined as:


The resulting reward signal is strongly related to mask quality, rather than reflecting the stochasticity in training of the classification network.
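A minimal sketch of one plausible reading of this section follows. It assumes the reward is the gain in classification loss minus the baseline, $R = (L_{adv} - L_c) - b$, and that the final loss adds a $\lambda$-weighted cross-entropy toward the all-zero importance map. Both forms, the helper names, and the numeric loss values are assumptions for illustration, not the authors' exact definitions; the decay of 0.5 and weight of 0.1 mirror the values quoted in the experiments section.

```python
import numpy as np


class EmaBaseline:
    """Exponential moving average of past reward signals (variance reduction)."""

    def __init__(self, decay=0.5):
        self.decay = decay
        self.value = 0.0

    def update(self, signal):
        self.value = self.decay * self.value + (1.0 - self.decay) * signal
        return self.value


def bce_sum(prob, target, eps=1e-7):
    """Binary cross-entropy summed over pixels."""
    prob = np.clip(np.asarray(prob, dtype=np.float64), eps, 1.0 - eps)
    target = np.asarray(target, dtype=np.float64)
    return float(-(target * np.log(prob) + (1.0 - target) * np.log(1.0 - prob)).sum())


def adversarial_reward(loss_adv, loss_clean, baseline_value):
    """Assumed form: loss gain after erasing 'important' pixels, minus the baseline."""
    return (loss_adv - loss_clean) - baseline_value


def apga_loss(prob, action, reward, lam=0.1):
    """Assumed form: reward-weighted BCE plus a penalty pulling `prob` toward all-zero."""
    policy_term = reward * bce_sum(prob, action)
    reg_term = lam * bce_sum(prob, np.zeros_like(prob))  # discourages erasing everything
    return policy_term + reg_term


baseline = EmaBaseline(decay=0.5)
prob = np.array([[0.9, 0.2], [0.6, 0.4]])    # hypothetical importance probabilities
action = np.array([[1.0, 0.0], [1.0, 0.0]])  # hypothetical sampled erase-mask
loss_adv, loss_clean = 1.3, 0.7              # hypothetical classifier losses
reward = adversarial_reward(loss_adv, loss_clean, baseline.value)
baseline.update(loss_adv - loss_clean)
loss = apga_loss(prob, action, reward)
```

With a zero reward only the regularizer remains, and it grows as the probabilities approach one, which is what keeps the policy from erasing every pixel.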

Input: Training steps ($T$); training samples and labels ($X$, $Y$); classification and segmentation loss functions ($L_c$, $L_s$); learning rates ($\alpha_c$, $\alpha_s$); classification and segmentation models ($f_c$, $f_s$) parameterized with ($\theta_c$, $\theta_s$).
Output: $\theta_c$ and $\theta_s$

1:  Initialize model parameters $\theta_c$ and $\theta_s$
2:  Train classification model $f_c$ on ($X$, $Y$) to convergence
3:  for t = 1,…, T do
4:     Sample a training batch ($x$, $y$) from the ($X$, $Y$) pool
5:     Get classification loss $L_c$
6:     Perform gradient descent update of $\theta_c$
7:     Get adversarial action probabilities $p$ from the segmentation network
8:     Calculate masks of important features
9:     Apply the masks from the segmentation network to erase the predicted features
10:     Get adversarial loss $L_{adv}$
11:     Calculate reward $R$
12:     Calculate distances of action probabilities from actions taken
13:     Calculate distances of action probabilities from the opposite of undesirable extreme actions (masking all pixels)
14:     Calculate adversarial policy gradient loss
15:     Update the segmentation network parameters $\theta_s$
16:     Produce aiding actions from segmentation network
17:     Update $\theta_c$ with the aiding masks
18:  end for
Algorithm 1 Adversarial Policy Gradient Augmentation (APGA)
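The order of operations in the loop body can be sketched with user-supplied callables standing in for the two networks. Everything named here (`apga_step`, the callable arguments, Bernoulli sampling of actions, the 0.5 threshold for aiding masks, and subtracting the baseline before updating it) is an assumption chosen for illustration; only the sequence of steps follows Algorithm 1.

```python
import numpy as np


def apga_step(x, y, clf_loss, clf_update, seg_probs, seg_update,
              baseline, rng, decay=0.5, threshold=0.5):
    """One pass through the loop body of Algorithm 1.

    Returns the reward and the updated moving-average baseline.
    """
    # Steps 5-6: classification loss on the clean batch, then classifier update.
    loss_c = clf_loss(x, y)
    clf_update(x, y)

    # Steps 7-9: sample adversarial actions and erase the pixels they select.
    p = seg_probs(x)
    action = (rng.random(p.shape) < p).astype(x.dtype)  # assumed Bernoulli sampling
    x_adv = x * (1.0 - action)

    # Steps 10-11: adversarial loss and baseline-corrected reward.
    loss_adv = clf_loss(x_adv, y)
    gain = loss_adv - loss_c
    reward = gain - baseline
    baseline = decay * baseline + (1.0 - decay) * gain

    # Steps 12-15: policy-gradient update of the segmentation network.
    seg_update(x, action, reward)

    # Steps 16-17: keep only pixels rated important, then update the classifier again.
    keep = (seg_probs(x) >= threshold).astype(x.dtype)
    clf_update(x * keep, y)
    return reward, baseline


# Stub callables that only record the call order; real models would replace them.
calls = []
rng = np.random.default_rng(0)
x = np.ones((2, 4, 4), dtype=np.float32)
y = np.array([0, 1])
reward, baseline = apga_step(
    x, y,
    clf_loss=lambda xb, yb: float(1.0 - xb.mean()),  # toy loss: rises as pixels are erased
    clf_update=lambda xb, yb: calls.append("clf"),
    seg_probs=lambda xb: np.full(xb.shape, 0.5, dtype=np.float32),
    seg_update=lambda xb, a, r: calls.append("seg"),
    baseline=0.0,
    rng=rng,
)
```

Passing the models in as callables keeps the sketch independent of any deep learning framework; in practice each callable would wrap a forward pass or an optimizer step.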

3 Experiments and Results

We evaluate our methodology on MURA [7] and an internal hip fracture dataset, using the same experimental setup, including network architectures, RL framework, and training hyperparameters. A DenseNet-169 [4] pretrained on ImageNet [2] serves as the base classification model. A TernausNet [5], pretrained on the Carvana Image Masking Challenge [1], serves as the segmentation model. Masked images are used as augmentation in a fixed ratio with original images to train the classification network. Images are resized to a common resolution, and the batch size is 25. Adam optimizers [6] with an initial learning rate of 0.0001 are used for the classification and the segmentation models. The exponential moving average baseline has a decay rate of 0.5. The regularization weight $\lambda$ is set to 0.1. Training of APGA converges within 30 minutes on a single Nvidia TitanX GPU. Source code available at

3.0.1 Baselines:

We benchmark APGA against a DenseNet-169 [4] classifier trained (1) without data augmentation, (2) with cutout [3] augmentation in a 1:1 ratio, and (3) with GradCam [9] derived mask augmentation, also in a 1:1 ratio. Cutout augmentation masks out randomly sized patches of the input image, while GradCam masks are produced by discretizing the saliency map from the DenseNet trained without data augmentation. For further comparison, a segmentation and a classification network are trained end-to-end by propagating the gradient from the classification loss function through the segmentation network and applying the discretized masks from the segmentation network in the same update step. Additionally, weighted regularization terms are added to the loss function to prevent all-or-none masking behavior. However, end-to-end training was unstable, and the segmentation network produced all-one or all-zero masks despite tuning of the regularization weights. At its best, the end-to-end network produced all-one masks and performed the same as the DenseNet trained without augmentation. Therefore, these results were omitted.
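The two masking baselines can be illustrated with a short NumPy sketch. It is not the code used in the experiments: `cutout` here erases a single rectangle whose maximum size (`max_frac`) is an assumed parameter, and `saliency_mask` takes an already computed saliency map as input and discretizes it with an assumed 0.5 threshold after min-max scaling.

```python
import numpy as np


def cutout(image, rng, max_frac=0.5):
    """Return a copy of `image` with one randomly sized, randomly placed patch zeroed."""
    image = np.array(image, dtype=np.float32, copy=True)
    h, w = image.shape[:2]
    # Patch height/width drawn from 1 .. floor(max_frac * side), at least 1.
    ph = int(rng.integers(1, max(2, int(h * max_frac) + 1)))
    pw = int(rng.integers(1, max(2, int(w * max_frac) + 1)))
    top = int(rng.integers(0, h - ph + 1))
    left = int(rng.integers(0, w - pw + 1))
    image[top:top + ph, left:left + pw] = 0.0
    return image


def saliency_mask(saliency, threshold=0.5):
    """Discretize a saliency map (min-max scaled to [0, 1]) into a boolean keep-mask."""
    s = np.asarray(saliency, dtype=np.float32)
    span = s.max() - s.min()
    scaled = (s - s.min()) / span if span > 0 else np.zeros_like(s)
    return scaled >= threshold


rng = np.random.default_rng(0)
image = np.ones((8, 8), dtype=np.float32)
augmented = cutout(image, rng)

# Hypothetical saliency map; a real one would come from GradCam on the trained DenseNet.
saliency = np.arange(16, dtype=np.float32).reshape(4, 4)
mask = saliency_mask(saliency)
```

Cutout ignores image content entirely, whereas the saliency mask inherits whatever the converged classifier already attends to; this is the contrast the comparison with APGA is meant to probe.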

3.1 Binary Classification: MURA

The MURA [7] dataset contains 14,863 musculoskeletal studies of the elbow, finger, forearm, hand, humerus, shoulder, and wrist, comprising 9,045 normal and 5,818 abnormal labeled cases. We train the methods on the public training set and evaluate on the validation set, with global accuracy as the metric. We train and evaluate separate models on each body part, and additionally train a single model on a random sample of 100 training images per class to test the performance of our method under extreme data constraints. The performance on the validation set is presented in Table 1 and Table 2 as the average and standard deviation over 5 random seeds.



Table 1: Classification results (validation accuracy) on MURA. Recoverable column headers: + cutout, + GradCam.

Table 2: 100-shot results (validation accuracy) on Elbow in MURA. Recoverable column headers: + cutout, + GradCam.
Figure 2: Example X-Rays (top) and masks created by APGA (middle) and DenseNet+GradCam (bottom) for hip, hand, and elbow.

3.2 Multi-class Classification: Hip Fracture

The Hip Fracture dataset contains 1118 studies with an average patient age of 74.6 years (standard deviation 17.3), and a female:male ratio. Each study includes a pelvic radiograph, labeled as 1 of 6 classes: No fracture, Intertrochanteric fracture, Displaced femoral neck fracture, Non-displaced femoral neck fracture, Arthroplasty, or ORIF (previous internal fixation). Bounding boxes are manually drawn on each study, resulting in 3034 bounded hips. The images are split by accession number into train:valid:test using a split, ensuring no overlap in patients between any of the sets. We train and evaluate separate models on the whole pelvic radiographs and the bounded hip radiographs. Per-image accuracy is used as the metric. The performance on the validation and test set is shown in Table 3.
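Splitting by a grouping key so that no group appears in more than one subset can be sketched with the standard library alone. The record fields (`accession`, `hip`), the function name, and the 60/20/20 fractions are hypothetical placeholders; the actual split proportions are not assumed here to match the study.

```python
import random
from collections import defaultdict


def grouped_split(records, key, fractions=(0.6, 0.2, 0.2), seed=0):
    """Split records into train/valid/test so that no group spans two subsets.

    `fractions` applies to the number of groups, not the number of records.
    """
    groups = defaultdict(list)
    for rec in records:
        groups[rec[key]].append(rec)

    ids = sorted(groups)                 # sort first so the shuffle is reproducible
    random.Random(seed).shuffle(ids)

    n_train = int(round(fractions[0] * len(ids)))
    n_valid = int(round(fractions[1] * len(ids)))
    parts = (ids[:n_train], ids[n_train:n_train + n_valid], ids[n_train + n_valid:])
    return tuple([rec for gid in part for rec in groups[gid]] for part in parts)


# Hypothetical data: 10 accession numbers with 3 bounded-hip crops each.
records = [{"accession": f"A{i // 3}", "hip": i} for i in range(30)]
train, valid, test = grouped_split(records, key="accession")
```

Assigning whole groups to a subset, rather than individual images, is what prevents crops from the same study from leaking between training and evaluation.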

Table 3: Classification results (validation and test accuracy) on Hip Fx Dataset. Recoverable column headers: DenseNet [4], + cutout, + GradCam. Rows: Whole Pelvis (val), Whole Pelvis (test), Bounded Hip (val), Bounded Hip (test).

3.2.1 Results:

Compared to the baseline methods, APGA achieved higher global accuracy in 9 out of 10 tasks, spanning binary (MURA, Table 1) and multi-class (hip Fx, Table 3) classification, and consistently provided higher testing results. On average, our method improved MURA validation accuracy by 1.56% and hip validation and testing accuracy by 0.78% and 1.72%, respectively. The largest improvement in accuracy over the baseline was 7.33%, achieved in the data-constrained condition reported in Table 2, in which the elbow training data was limited to 100 samples per class. Example segmentation masks from the weakly supervised network are shown in Fig. 2. APGA learns to ignore unimportant features in the radiographs, such as anatomy irrelevant to the classification task. APGA masking appears more exploratory in nature than saliency-based attention masking (DenseNet + GradCam), which contains biases from the converged model.

4 Discussions and Conclusions

We propose a framework, APGA, for producing segmentations to aid medical image classification in a reinforcement learning setting. This framework requires no manual segmentation, which has the benefit of scalability and generalizability. The system is trained online with the goal of improving the performance of the main task, classification. If no improvement is seen, this can serve as a check on the assumption that masking-based augmentation would aid classification, before pursuing more manual work. Marginal improvements should be evidence that APGA has the potential to add valuable information to the training process. The computational overhead in training is justified by those added benefits, and could be eliminated during inference, as the segmentation network can also be used as an inference augmentation technique. This general framework of reinforcement learning with an adversarial reward could easily be adopted for other medical imaging tasks involving regression and segmentation, with different aiding methods such as bounding box detection, image distortion, and image generation. Reinforcement-guided data augmentation is more generalizable than traditional data augmentation based on domain knowledge.


  • [1] Carvana Image Masking Challenge.
  • [2] J. Deng, W. Dong, R. Socher, L. Li, K. Li, and L. Fei-Fei. ImageNet: A Large-Scale Hierarchical Image Database.
  • [3] T. DeVries and G. W. Taylor (2017-08) Improved Regularization of Convolutional Neural Networks with Cutout. arXiv:1708.04552 [cs].
  • [4] G. Huang, Z. Liu, L. van der Maaten, and K. Q. Weinberger (2016-08) Densely Connected Convolutional Networks. arXiv:1608.06993 [cs]. CVPR 2017.
  • [5] V. Iglovikov and A. Shvets (2018-01) TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation. arXiv:1801.05746 [cs].
  • [6] D. P. Kingma and J. Ba (2014-12) Adam: A Method for Stochastic Optimization. arXiv:1412.6980 [cs]. Published as a conference paper at the 3rd International Conference for Learning Representations, San Diego, 2015.
  • [7] P. Rajpurkar, J. Irvin, A. Bagul, D. Ding, T. Duan, H. Mehta, B. Yang, K. Zhu, D. Laird, R. L. Ball, C. Langlotz, K. Shpanskaya, M. P. Lungren, and A. Y. Ng (2017-12) MURA: Large Dataset for Abnormality Detection in Musculoskeletal Radiographs. arXiv:1712.06957 [physics]. 1st Conference on Medical Imaging with Deep Learning (MIDL 2018).
  • [8] O. Russakovsky, J. Deng, H. Su, J. Krause, S. Satheesh, S. Ma, Z. Huang, A. Karpathy, A. Khosla, M. Bernstein, A. C. Berg, and L. Fei-Fei (2014-09) ImageNet Large Scale Visual Recognition Challenge. arXiv:1409.0575 [cs].
  • [9] R. R. Selvaraju, M. Cogswell, A. Das, R. Vedantam, D. Parikh, and D. Batra (2016-10) Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization. arXiv:1610.02391 [cs].
  • [10] M. Wallenberg and P. Forssén (2017-09) Attentional masking for pre-trained deep networks. In 2017 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS), pp. 6149–6154.
  • [11] R. J. Williams (1992-05) Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning 8 (3), pp. 229–256. ISSN 1573-0565.
  • [12] B. Zoph and Q. V. Le (2016-11) Neural Architecture Search with Reinforcement Learning. arXiv:1611.01578 [cs].