Adversarial Counterfactual Augmentation: Application in Alzheimer's Disease Classification

03/15/2022
by Tian Xia, et al.

Data augmentation has been widely used in deep learning to reduce over-fitting and improve the robustness of models. However, traditional data augmentation techniques, such as rotation, cropping, and flipping, do not consider semantic transformations, e.g., changing the age of a brain image. Previous works have tried to achieve semantic augmentation by generating counterfactuals, but they focused on how to train deep generative models and created counterfactuals randomly with those models, without considering which counterfactuals are most effective for improving downstream training. Unlike these approaches, in this work we propose a novel adversarial counterfactual augmentation scheme that aims to find the most effective counterfactuals for improving downstream tasks with a pre-trained generative model. Specifically, we construct an adversarial game in which we update the input conditional factor of the generator and the downstream classifier alternately and iteratively with gradient backpropagation. The key idea is to find conditional factors that result in hard counterfactuals for the classifier. This can be viewed as finding the 'weakness' of the classifier and purposely forcing it to overcome that weakness via the generative model. To demonstrate the effectiveness of the proposed approach, we validate the method on the classification of Alzheimer's Disease (AD) as the downstream task, based on a pre-trained brain ageing synthesis model. We show that the proposed approach improves test accuracy and can alleviate spurious correlations. Code will be released upon acceptance.
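The alternating game described above can be illustrated with a toy sketch. This is not the authors' implementation (their code is unreleased), and it rests on several assumptions. It uses a hypothetical frozen "generator" that shifts a feature vector along a fixed direction in proportion to a scalar conditional factor `a` (standing in for, e.g., an age offset). It replaces the deep AD classifier with logistic regression. It also assumes each counterfactual keeps the label of the original input. All function names (`generator`, `update_factor`, `update_classifier`, `adversarial_round`) are illustrative only.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def generator(x, a, direction):
    """Hypothetical frozen generator: shifts features x along a fixed
    `direction` in proportion to the conditional factor `a`."""
    return [xi + a * di for xi, di in zip(x, direction)]


def bce(p, y, eps=1e-12):
    """Binary cross-entropy for predicted probability p and label y in {0, 1}."""
    return -(y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps))


def predict(w, b, x):
    """Logistic-regression stand-in for the downstream classifier."""
    return sigmoid(dot(w, x) + b)


def update_factor(w, b, x, y, a, direction, lr_a, a_min, a_max):
    """Gradient ASCENT on the classifier loss w.r.t. the conditional factor a,
    clipped to [a_min, a_max] (an assumed way to keep `a` plausible)."""
    p = predict(w, b, generator(x, a, direction))
    # Chain rule: dL/da = dL/dz * dz/da = (p - y) * <w, direction>
    grad_a = (p - y) * dot(w, direction)
    return max(a_min, min(a_max, a + lr_a * grad_a))


def update_classifier(w, b, x, y, lr_w):
    """One gradient DESCENT step on the logistic-regression loss at input x."""
    g = predict(w, b, x) - y  # dL/dz for sigmoid + BCE
    new_w = [wi - lr_w * g * xi for wi, xi in zip(w, x)]
    return new_w, b - lr_w * g


def adversarial_round(w, b, x, y, a, direction, lr_a=1.0, lr_w=0.1,
                      a_min=-1.0, a_max=1.0):
    """One alternation: make the counterfactual harder, then train on it.
    Assumes the counterfactual inherits the original label y."""
    a = update_factor(w, b, x, y, a, direction, lr_a, a_min, a_max)
    x_cf = generator(x, a, direction)
    w, b = update_classifier(w, b, x_cf, y, lr_w)
    return w, b, a
```

For example, starting from `w, b, a = [1.0, 0.0], 0.0, 0.0` with `x = [0.0, 0.0]`, `y = 1`, and `direction = [1.0, 0.0]`, one call to `adversarial_round(w, b, x, y, a, direction)` first moves `a` to -0.5. This lowers the predicted probability of the true class. The call then nudges the classifier to fit that harder counterfactual. Calling it repeatedly gives the iterative alternation. Clipping `a` to a fixed interval is only one simple way to bound the factor; the paper may constrain it differently.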
