Dynamic memory to alleviate catastrophic forgetting in continuous learning settings

07/06/2020 · Johannes Hofmanninger, et al. · MedUni Wien

In medical imaging, technical progress or changes in diagnostic procedures lead to a continuous change in image appearance. Scanner manufacturer, reconstruction kernel, dose, other protocol-specific settings, or the administration of contrast agents are examples that influence image content independently of the scanned biology. Such domain and task shifts limit the applicability of machine learning algorithms in the clinical routine by rendering models obsolete over time. Here, we address the problem of data shifts in a continuous learning scenario by adapting a model to unseen variations in the source domain while counteracting catastrophic forgetting effects. Our method uses a dynamic memory to facilitate rehearsal of a diverse training data subset to mitigate forgetting. We evaluated our approach on routine clinical CT data obtained with two different scanner protocols and synthetic classification tasks. Experiments show that dynamic memory counters catastrophic forgetting in a setting with multiple data shifts without the necessity for explicit knowledge about when these shifts occur.


1 Introduction

In clinical practice, medical images are produced with continuously changing policies, protocols, scanner hardware, or settings, resulting in a different visual appearance of scans despite the same underlying biology. In most cases, such a shift in visual appearance is intuitive for clinicians and does not lessen their ability to assess a scan. However, the performance of machine learning-based methods can deteriorate significantly when the data distribution changes. Continuous learning adapts a model to changing data and new tasks by sequentially updating the model on novel cases. The ground truth labels of these cases can, for instance, be acquired by corrective labelling. However, catastrophic forgetting is a major undesired phenomenon affecting continuous learning approaches [9]: when a model is continuously updated on a new task or a different data distribution, its performance deteriorates on the preceding tasks. Alleviating catastrophic forgetting is one of the major challenges in continuous learning.

Here, we propose an approach for a scenario where new data become sequentially available for model training. We aim to utilize such a data stream to frequently update an existing model without the requirement of keeping the entire data available. We assume a real-world clinical setting where shifts in the visual domain and in the appearance of classification targets can occur gradually over time, and where information about such a shift is not available to the continuous learning system. Figure 1 illustrates the scenario and the experimental setup. A base model has been trained conventionally to perform well on a certain classification task. Subsequently, a continuous data stream is used to update the model with the aim of learning variations of the initial task given new data. This model should become accurate on new data while staying accurate on data generated by previous technology. Accordingly, the final model is evaluated on all tasks to assess the effect of catastrophic forgetting and the classification performance. Note that we use the term task to denote the detection of the same target on shifted visual domains, not the addition of new target classes.

Related Work

There are various groups of methods dealing with the problem of continuously updating a machine learning model over time, such as continuous learning, continuous domain adaptation, or active learning. They operate on similar but distinct assumptions about the problem setting, the available data, and the level of supervision required. For example, continuous domain adaptation assumes novel data to be shifted but closely related to the previous data. Continuous learning makes no assumptions about domain shifts and is not limited to performing a single task on new data (e.g. incremental learning). Active learning is characterized by the task of automatically selecting examples for which supervision is beneficial. In this work, we propose a continuous learning technique which can also be categorized as a supervised continuous domain adaptation method.

Various methods for continuous learning have been proposed to alleviate catastrophic forgetting in scenarios where multiple tasks are learned sequentially.

A popular method to retain previous knowledge in a network is elastic weight consolidation (EWC) [7]. EWC is a regularization technique that constrains parameters of the model which are critical for previous tasks while new tasks are trained. Alternative methods attempt to overcome catastrophic forgetting by rehearsing past examples [8] or proxy information (pseudorehearsal) [10] when new information is added [11]. In the field of medical imaging, continuous learning has been demonstrated to reduce catastrophic forgetting on segmentation and classification tasks. Karani et al. proposed domain-specific batch norm layers to adapt to new domains (different MR protocols) while learning segmentations of various brain regions [5]. Baweja et al. applied EWC to sequentially learn normal brain structure and white matter lesion segmentation [1]. Ravishankar et al. proposed a pseudorehearsal technique and the training of task-specific dense layers for pneumothorax classification [10]. These current approaches expect that the information about the domain to which a training example belongs is available to the learning system. In real-world medical imaging data, such information may not be available at the image level (e.g. a change in treatment policies, or different hardware updates across departments). At the same time, changes in protocol or scanner manufacturer may not automatically lead to a loss of performance of the model, and considering each protocol as a novel domain may lead to adverse effects such as overfitting.

Contribution

We propose an approach for continuous learning under continuously or repeatedly shifting domains. This is in contrast to most previous methods, which treat updates as sequentially adding well-specified tasks. In contrast to existing rehearsal methods, we propose a technique that automatically infers data shifts without explicit knowledge about them. To this end, the method maintains a diverse memory of previously seen examples that is dynamically updated based on high-level representations in the network.

Figure 1: Experimental setup for a continuous learning scenario:

We assume a learning scenario in which a conventionally trained model (e.g. multi-epoch training) performing well on Task A is available. This model is continuously updated on a data stream containing a shift in overall image appearance caused by scanner parameters (modality-shift) and a shift in the appearance of the classification target (target-shift). The timing of these shifts is not known a priori. Finally, the model is evaluated on test sets of all three tasks.

2 Method

We continuously update the parameters of an already trained model with new training data. Our approach composes this training data to capture novel data characteristics while sustaining the diversity of the overall training corpus. It chooses examples from previously seen data (the dynamic memory $\mathcal{M}$) and new examples (the input mini-batch $\mathcal{B}$) to form the training data (the training mini-batch $\mathcal{T}$) for the model update.

Our approach is a rehearsal method to counter catastrophic forgetting in continuous learning. We adopt a dynamic memory (DM)

$\mathcal{M} = \{(m_1, n_1), \dots, (m_M, n_M)\}$   (1)

of fixed size $M$ holding image-label pairs $(m_i, n_i)$ that are stored and eventually replaced during continuous training. To alleviate catastrophic forgetting, a subset of cases of $\mathcal{M}$ is used for rehearsal during every update step. It is critical that the diversity of $\mathcal{M}$ is representative of the visual variation across all tasks, even without explicit knowledge about the task membership of training examples. As the size of $\mathcal{M}$ is fixed, the most critical step of such an approach is to decide which samples to keep in $\mathcal{M}$ and which to replace with a new sample. To this end, we define a memory update strategy based on the following rules: (1) every novel case is stored in the memory, (2) a novel case can only replace a case in memory of the same class, and (3) the case in memory that is replaced is close to the novel case according to a high-level metric. Rule 1 allows the memory to dynamically adapt to changing variation, rule 2 prevents class imbalance in the memory, and rule 3 prevents the replacement of previous cases that are visually distant. The metric used in rule 3 is critical as it ensures that cases of previous tasks are kept in memory and not fully replaced over time. We define a high-level metric based on the gram matrix $G^l \in \mathbb{R}^{N_l \times N_l}$, where $N_l$ is the number of feature maps in layer $l$.

$G^l_{ij}$ is defined as the inner product between the vectorized activations $f^l_i(x)$ and $f^l_j(x)$ of two feature maps $i$ and $j$ in a layer $l$, given a sample image $x$:

$G^l_{ij}(x) = \frac{1}{D_l} \, f^l_i(x)^\top f^l_j(x)$   (2)

where $D_l$ denotes the number of elements in the vectorized feature map (width $\times$ height). For a set of convolution layers $C$, we define a gram distance $\delta$ between two images $x$ and $y$ as:

$\delta(x, y) = \sum_{l \in C} \frac{1}{N_l^2} \sum_{i=1}^{N_l} \sum_{j=1}^{N_l} \left( G^l_{ij}(x) - G^l_{ij}(y) \right)^2$   (3)

The rationale behind using the gram matrix is that it encodes high-level style information. Here, we are interested in this style information to maintain a memory that is diverse not only with respect to content but also with respect to visual appearance. Similar gram distances have been used in computer vision in the area of neural style transfer as a way to compare the style of two natural images [2].
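To make Eqs. (2) and (3) concrete, a minimal PyTorch sketch is given below; the function names and the assumed tensor layout (one feature map tensor of shape $N_l \times H \times W$ per layer) are our own choices and not the authors' implementation.

```python
import torch

def gram_matrix(feature_map: torch.Tensor) -> torch.Tensor:
    """Eq. (2): gram matrix of one layer's activations, shape (N_l, H, W) -> (N_l, N_l)."""
    n_l = feature_map.shape[0]
    d_l = feature_map.shape[1] * feature_map.shape[2]   # width * height
    f = feature_map.reshape(n_l, d_l)                   # vectorized activations f^l_i
    return (f @ f.t()) / d_l                            # inner products, normalized by D_l

def gram_distance(feats_x: list, feats_y: list) -> float:
    """Eq. (3): sum over the chosen layers of the mean squared gram-matrix difference."""
    dist = 0.0
    for fx, fy in zip(feats_x, feats_y):                # one entry per layer l in C
        gx, gy = gram_matrix(fx), gram_matrix(fy)
        n_l = gx.shape[0]
        dist += ((gx - gy) ** 2).sum().item() / (n_l ** 2)
    return dist
```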

During continuous training, a memory update is performed after an input mini-batch $\mathcal{B}$ of $B$ sequential cases (images $x$ with labels $y$) is taken from the data stream. Sequentially, each element of $\mathcal{B}$ replaces an element of $\mathcal{M}$. More formally, given an input sample $(x, y)$, the sample replaces the element of $\mathcal{M}$ with index

$i = \operatorname{arg\,min}_{j \,:\, n_j = y} \; \delta(x, m_j)$   (4)

During the initial phase of continuous training, the memory is filled with elements of the data stream; only after the desired proportion of a class in the memory is reached is the replacement strategy applied to that class. After the memory is updated, a model update is done by assembling a training mini-batch $\mathcal{T}$ of size $T$. Each element of $\mathcal{B}$ for which the model predicted the wrong label is added to $\mathcal{T}$, and additional cases are randomly drawn from $\mathcal{M}$ until $|\mathcal{T}| = T$. Finally, using the training mini-batch $\mathcal{T}$, a forward and backward pass is performed to update the parameters of the model.
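The complete update step can be sketched as follows. This is a simplified illustration under our own assumptions: labels are 0/1 scalars, the model outputs a single logit, `gram_distance` is the function from the previous snippet, and `extract_features` is a hypothetical helper returning the per-layer feature maps used for the gram distance; it is not the authors' reference implementation.

```python
import random
import torch

class DynamicMemory:
    """Fixed-size memory M of (image, label, features) entries (Eq. 1), simplified sketch."""

    def __init__(self, size: int):
        self.size = size
        self.items = []  # list of (image, label, per-layer feature maps)

    def update(self, image, label, feats):
        # Initial phase (rule 1): store novel cases until the memory is full.
        if len(self.items) < self.size:
            self.items.append((image, label, feats))
            return
        # Rule 2: only stored cases of the same class may be replaced.
        candidates = [i for i, (_, n, _) in enumerate(self.items) if n == label]
        if not candidates:   # no case of this class stored yet; skipped in this sketch
            return
        # Rule 3 / Eq. (4): replace the closest case according to the gram distance.
        idx = min(candidates, key=lambda i: gram_distance(feats, self.items[i][2]))
        self.items[idx] = (image, label, feats)

    def draw(self, k: int):
        return random.sample(self.items, min(k, len(self.items)))


def continuous_step(model, optimizer, criterion, input_batch, memory, train_batch_size=8):
    """One DM update for an input mini-batch B taken from the data stream."""
    training_batch = []
    for image, label in input_batch:
        feats = extract_features(model, image)  # hypothetical: feature maps used for Eq. (3)
        memory.update(image, label, feats)
        with torch.no_grad():
            predicted = (model(image.unsqueeze(0)).squeeze() > 0).float().item()
        if predicted != label:  # misclassified cases enter the training mini-batch directly
            training_batch.append((image, label))
    # Fill the training mini-batch T with randomly drawn memory samples.
    training_batch += [(m, n) for m, n, _ in memory.draw(train_batch_size - len(training_batch))]
    images = torch.stack([m for m, _ in training_batch])
    labels = torch.tensor([n for _, n in training_batch], dtype=torch.float32)
    optimizer.zero_grad()
    loss = criterion(model(images).squeeze(1), labels)
    loss.backward()
    optimizer.step()
```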

3 Experiments and Results

We evaluated and studied the DM method in a realistic setting using medical images from clinical routine. To this end, we collected a representative dataset described in Section 3.1. Based on these data, we designed a learning scenario with three tasks (A, B, and C) in which a classifier pre-trained on task A is continuously updated with changing input data over time. Figure 1 illustrates the learning scenario, the data, and the experimental setup. Within the continuous dataset, we created two shifts: (1) a modality-shift between scanner protocols and (2) a target-shift by changing the target structure from low to high intensity.

            Task A  Task B  Task C  Total
Protocol    B3/3    B6/1    B6/1
Target      low     low     high
Base        1513    0       0       1513
Continuous  1513    1000    2398    4911
Validation  377     424     424     1225
Test        381     427     426     1234
Table 1: Data: splitting of the data into base, continuous, validation, and test sets. The number of cases in each split is shown.

3.1 Dataset

In total, we collected 8883 chest CT scans from individual studies and for each scan extracted an axial slice at the center of the segmented lung [4]. Each scan was performed on a Siemens scanner with either a B3 reconstruction kernel and 3mm slice thickness (B3/3) or a B6 reconstruction kernel and 1mm slice thickness (B6/1). We collected 3784 cases with the B3/3 protocol and 5099 cases with the B6/1 protocol. In 50% of the cases, we imprinted a synthetic target structure in the form of a cat at a random location, with random rotation and varying scale (see also Figure 1). The high-intensity target structures were engraved by randomly adding an offset between 200 and 400 Hounsfield units (HU) and the low-intensity target structures by subtracting between 200 and 400 HU. A synthetic target was chosen to facilitate dataset collection, as well as to create a dataset without label noise. Table 1 lists the data collected and shows the partitioning into base, continuous, validation, and test splits and the stratification into the three tasks.
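For illustration, a minimal numpy sketch of the imprinting step is given below; the binary cat-shaped mask (already rotated, scaled, and positioned) is assumed to exist, and only the ±200-400 HU offset follows the description above.

```python
import numpy as np

def imprint_target(ct_slice: np.ndarray, target_mask: np.ndarray,
                   high_intensity: bool, rng: np.random.Generator) -> np.ndarray:
    """Engrave a synthetic target by shifting HU values inside a binary mask.

    high_intensity=True adds a random 200-400 HU offset, False subtracts it.
    """
    offset = rng.uniform(200, 400)
    shifted = ct_slice.astype(np.float32).copy()
    shifted[target_mask > 0] += offset if high_intensity else -offset
    return shifted
```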

3.2 Experiments

We created the base model by fine-tuning a pre-trained ResNet50 [3] model, as provided by the pytorch-torchvision library (https://pytorch.org), on the base training set (Task A). Given this base model, we continuously updated the model parameters on the continuous training set using different strategies:

  • Naive: The naive approach serves as a baseline method by training sequentially on the data stream, without any specific strategy to counter catastrophic forgetting.

  • Elastic Weight Consolidation (EWC): As a reference method, we used EWC as presented in [7]. EWC regularizes weights that are critical for previous tasks based on the Fisher information. We calculated the Fisher information after training on the base set (Task A), so that in subsequent updates the weights that are important for this task are regularized.

  • EWC-fBN: In preliminary experiments, we found that EWC is not suitable for networks using batch norm layers in scenarios where a modality-shift occurs. The reason is that the regularization of EWC does not affect the batch norm parameters (running mean and variance) and that these parameters are constantly updated by novel data. Thus, to quantify this effect, we also report results where we fixed the batch norm layers once the base training was completed.

  • Dynamic Memory (DM): The method as described in this paper. The input mini-batch size $B$ and training mini-batch size $T$ were set to 8 for all experiments. If not stated otherwise, results were computed with a memory size of $M = 32$. The gram matrices were calculated on feature maps throughout the network, covering multiple scales: we used the feature maps of the last convolution layers before the number of channels is increased, that is, for ResNet50 we calculate the matrices on four maps with 256, 512, 1024, and 2048 feature channels (a sketch of this feature extraction follows the list).
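A possible way to obtain these four feature maps from a torchvision ResNet50 is to register forward hooks on the outputs of the four residual stages (layer1 through layer4, with 256, 512, 1024, and 2048 channels); this is our reading of the description above, not necessarily the authors' implementation, and `gram_matrix` refers to the snippet in Section 2.

```python
import torch
from torchvision.models import resnet50

def register_feature_hooks(model):
    """Store the outputs of layer1-layer4 on every forward pass."""
    feature_maps = {}
    def make_hook(name):
        def hook(module, inputs, output):
            feature_maps[name] = output.detach()
        return hook
    for name in ("layer1", "layer2", "layer3", "layer4"):
        getattr(model, name).register_forward_hook(make_hook(name))
    return feature_maps

model = resnet50(pretrained=True)
features = register_feature_hooks(model)
with torch.no_grad():
    _ = model(torch.randn(1, 3, 512, 512))  # dummy slice, replicated to three channels
# features now holds the four activation tensors; per sample, the gram matrices of
# Eq. (2) are computed on each of them, e.g. gram_matrix(features["layer1"][0]).
```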

In addition to the continuous scenario, we trained an upper-bound network in a conventional, epoch-based way using all training data at once (full training). All training processes were run five times to assess variability and robustness, and the results were averaged. All models were trained using the Adam optimizer [6] and binary cross entropy as the loss function for the classification task.
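A minimal sketch of the corresponding model and optimization setup is shown below; the single-logit head and the learning rate are our own placeholder choices, not values reported in the paper.

```python
import torch
from torch import nn, optim
from torchvision.models import resnet50

# Pre-trained ResNet50 with a single-logit head for the binary classification task.
model = resnet50(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 1)

optimizer = optim.Adam(model.parameters(), lr=1e-4)  # placeholder learning rate
criterion = nn.BCEWithLogitsLoss()                   # binary cross entropy on logits

def train_step(images, labels):
    """One forward/backward pass, used for base training and continuous updates alike."""
    optimizer.zero_grad()
    loss = criterion(model(images).squeeze(1), labels.float())
    loss.backward()
    optimizer.step()
    return loss.item()
```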


With the described methods, we:

  • studied the dynamics during training by calculating the classification accuracies on the validation set every 30 iterations.

  • calculated quantitative results after training on the test set for each task separately. Classification accuracy, backward transfer (BWT), and forward transfer (FWT), as defined in [8], were used (the definitions are restated after this list). BWT and FWT measure the influence that learning a new task has on the performance of previously learned and future tasks, respectively. Thus, larger BWT and FWT values are preferable. Specifically, negative BWT quantifies catastrophic forgetting.

  • studied the influence of memory size on the performance of our method during the training process and after training. We included experiments with $M$ set to 16, 32, 64, 80, 128, and 160.
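For reference, the definitions from [8] can be restated as follows, where $R_{i,j}$ denotes the test accuracy on task $j$ after training up to task $i$, $K$ is the number of tasks (here $K = 3$), and $\bar{b}_j$ is the accuracy on task $j$ of the model before it has been trained on any of the tasks (the random-initialization reference used in [8]):

$\mathrm{BWT} = \frac{1}{K-1} \sum_{i=1}^{K-1} \left( R_{K,i} - R_{i,i} \right)$

$\mathrm{FWT} = \frac{1}{K-1} \sum_{i=2}^{K} \left( R_{i-1,i} - \bar{b}_i \right)$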

3.3 Results

Figure 2: Validation accuracy during training: The curves show the change of validation accuracy for each of the tested approaches. Accuracies are computed on the validation sets for the three tasks during training. The last row shows how the composition of the training data stream changes over time.

Dynamics during training are shown in Figure 2. As expected, the naive approach shows catastrophic forgetting for task A (drop in accuracy from 0.91 to 0.51) and task B (drop from 0.93 to 0.70) after new tasks are introduced. Without any method to counteract forgetting, knowledge of previous tasks is lost after sequential updates to the model. EWC exhibits catastrophic forgetting for task A (0.92 to 0.54) after the introduction of task B and C data into the data stream. The shift from task A to B is a modality shift: the batch norm layers of the network adapt to the second modality and knowledge about the first modality is lost. Although EWC protected the weights that are important for task A, those weights were no longer useful after the batch norm layers had adapted. EWC-fBN avoids this problem by fixing the batch norm layers together with the weights that are relevant for task A. In this setting, a forgetting effect for task B (accuracy drop from 0.90 to 0.70) can be observed after the target-shift to task C. This effect is due to EWC's requirement to know when shifts occur; in our scenario, EWC only regularizes weights that are important for task A. As described previously, DM does not have this requirement, since it dynamically adapts to changing data. Using the DM approach, only mild forgetting occurs and all three tasks reach a comparable performance of about 0.85 accuracy after training.

ACC Task A ACC Task B ACC Task C BWT FWT
Naive
EWC
EWC-fBN
DM (Ours)
Full training - -
Table 2: Accuracy and BWT/FWT values for our dynamic memory (DM) method compared to the baseline methods. Results were calculated on the test set after continuous training; lower accuracy values indicate forgetting.

Quantitative results are shown in Table 2. The large negative BWT values for the naive approach (-0.32) and EWC (-0.20) indicate that these methods suffer from catastrophic forgetting. Using EWC-fBN mitigates the forgetting for task A, but the model forgets part of the knowledge for task B when task C is introduced (observable in Figure 2). Both DM and EWC-fBN show comparable backward and forward transfer capabilities. The accuracy values in Table 2 show that DM performs equally well on all tasks, while the other approaches show signs of forgetting on task A (naive and EWC) and task B (naive and EWC-fBN).

(a) [plot: validation accuracy during training for different memory sizes; see caption]
(b)
M    Task A  Task B  Task C  Avg
16   0.73    0.83    0.96    0.84
32   0.82    0.86    0.93    0.87
64   0.85    0.81    0.87    0.84
80   0.86    0.82    0.87    0.85
128  0.87    0.75    0.71    0.78
160  0.88    0.79    0.69    0.79
Figure 5: Memory size: (a) the effect of memory size during training with DM; small memory sizes tend towards catastrophic forgetting, while large memory sizes lead to slow training of later tasks. (b) Classification accuracy on the validation set after training for varying memory sizes M.

The influence of memory size during training is shown in Figure 5a. For a small memory size of $M = 16$, adapting to new tasks is fast. However, such a limited memory can only store a limited data variability, leading to catastrophic forgetting effects. Increasing the memory size decreases catastrophic forgetting but slows the adaptation to new tasks. For the two largest investigated sizes, 128 and 160, adaptation is slower, especially after the target-shift (task C). The reason is that more elements of task A are stored in the memory and more iterations are needed to fill the memory with samples of task C. This reduces the probability that elements from task C are drawn and, in turn, slows down training for task C. For the same reason, higher memory sizes yield an increase in task A accuracy, since examples from the first task are seen more often by the network. The results indicate that setting the memory size is a trade-off between faster adaptation to novel data and more catastrophic forgetting. In our setting, training worked comparably well with $M$ set to 32, 64, and 80. Considering memory sizes between 16 and 160, setting $M = 32$ resulted in the highest average accuracy of 0.87 on the validation set (Figure 5b). Therefore, we used this model for the comparison to the other continuous learning methods on the test set.

4 Conclusion

Here, we presented a continuous learning approach to deal with modality- and target-shifts induced by changes in protocols, parameter settings, or scanners in a clinical setting. We showed that maintaining a memory of diverse training samples mitigates catastrophic forgetting in a deep learning classification model. We proposed a memory update strategy that automatically handles shifts in the data distribution without explicit information about domain membership or about the moment such a shift occurs.

Acknowledgments

This work was supported by Austrian Science Fund (FWF) I 2714B31 and Novartis Pharmaceuticals Corporation.

References

  • [1] C. Baweja, B. Glocker, and K. Kamnitsas (2018) Towards continual learning in medical imaging. External Links: Link Cited by: §1.
  • [2] L. Gatys, A. Ecker, and M. Bethge (2016) A Neural Algorithm of Artistic Style. Journal of Vision 16 (12), pp. 326. External Links: Document, ISSN 1534-7362 Cited by: §2.
  • [3] K. He, X. Zhang, S. Ren, and J. Sun (2015) Deep Residual Learning for Image Recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. External Links: Link Cited by: §3.2.
  • [4] J. Hofmanninger, F. Prayer, J. Pan, S. Rohrich, H. Prosch, and G. Langs (2020) Automatic lung segmentation in routine imaging is a data diversity problem, not a methodology problem. External Links: Link Cited by: §3.1.
  • [5] N. Karani, K. Chaitanya, C. Baumgartner, and E. Konukoglu (2018) A lifelong learning approach to brain MR segmentation across scanners and protocols. In Medical Image Computing and Computer-Assisted Intervention (MICCAI), Vol. 11070 LNCS, pp. 476–484. External Links: ISBN 9783030009274, Document, ISSN 16113349 Cited by: §1.
  • [6] D. P. Kingma and J. L. Ba (2015) Adam: A method for stochastic optimization. In 3rd International Conference on Learning Representations, ICLR 2015, Cited by: §3.2.
  • [7] J. Kirkpatrick, R. Pascanu, N. Rabinowitz, J. Veness, G. Desjardins, A. A. Rusu, K. Milan, J. Quan, T. Ramalho, A. Grabska-Barwinska, D. Hassabis, C. Clopath, D. Kumaran, and R. Hadsell (2017) Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences of the United States of America 114 (13), pp. 3521–3526. External Links: Document, ISSN 10916490 Cited by: §1, 2nd item.
  • [8] D. Lopez-Paz and M. Ranzato (2017) Gradient episodic memory for continual learning. Advances in Neural Information Processing Systems, pp. 6468–6477. External Links: ISSN 10495258 Cited by: §1, 2nd item.
  • [9] M. McCloskey and N. J. Cohen (1989) Catastrophic Interference in Connectionist Networks: The Sequential Learning Problem. Psychology of Learning and Motivation - Advances in Research and Theory 24 (C), pp. 109–165. External Links: Document, ISSN 00797421 Cited by: §1.
  • [10] H. Ravishankar, R. Venkataramani, S. Anamandra, P. Sudhakar, and P. Annangi (2019-10) Feature Transformers: Privacy Preserving Lifelong Learners for Medical Imaging. In Medical Image Computing and Computer Assisted Intervention (MICCAI), pp. 347–355. External Links: Document Cited by: §1.
  • [11] A. Robins (1995) Catastrophic Forgetting, Rehearsal and Pseudorehearsal. Connection Science 7 (2), pp. 123–146. External Links: Document, ISSN 0954-0091 Cited by: §1.