In many real-world scenarios, the process that generates a data stream is non-stationary. The performance of a neural network may degrade when the distribution of its input changes and, even worse, such changes may lead to catastrophic forgetting. An important requirement for neural networks is therefore the ability to preserve previously learned information while adapting quickly to changes in a dynamic environment. We aim to solve the problem of forgetting previously learned information under data drift. According to Bayesian decision theory [1], a classification problem can be defined as maximising the posterior probability $p(y \mid X)$, where $y$ represents the class of the data $X$. Drifts can be divided into two types [2, 3]: real drift, which refers to changes in $p(y \mid X)$, meaning that the data distribution remains the same but the class of the data changes; and virtual drift, which refers to changes in $p(X)$, meaning that the class of the data remains the same but the distribution of the data changes. Real drift requires replacement learning and virtual drift requires supplemental learning [4]. In the real world we cannot choose the learning environment, so a robust model is needed to adapt to these drifts. We propose a model with a Kalman Filter modifier that adjusts the learning parameters of a neural network. Our experiments show that the proposed solution adapts better and faster than conventional neural network models in environments with drift.
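The two drift types can be illustrated with a one-dimensional toy stream. The sketch below is only an illustration: the Gaussian inputs, the threshold rule and all function names are assumptions made for this example and are not taken from the experiments. Under virtual drift the inputs move to a different region while the labelling rule still holds; under real drift the inputs stay where they were but the rule itself changes.

```python
import random

def sample(n, x_mean, label_rule, seed=0):
    """Draw n inputs from Normal(x_mean, 1) and label each with label_rule."""
    rng = random.Random(seed)
    xs = [rng.gauss(x_mean, 1.0) for _ in range(n)]
    return [(x, label_rule(x)) for x in xs]

def positive_rule(x):
    # Original concept (hypothetical): class 1 when x is positive.
    return 1 if x > 0 else 0

def flipped_rule(x):
    # Changed concept: the same inputs now receive the opposite class.
    return 0 if x > 0 else 1

def accuracy(data, rule):
    """Fraction of samples whose label agrees with the given rule."""
    return sum(rule(x) == y for x, y in data) / len(data)

original = sample(1000, x_mean=0.0, label_rule=positive_rule)
# Virtual drift: the input distribution shifts, the labelling rule is unchanged.
virtual_drift = sample(1000, x_mean=2.0, label_rule=positive_rule)
# Real drift: the input distribution is unchanged, the labelling rule changes.
real_drift = sample(1000, x_mean=0.0, label_rule=flipped_rule)

old_rule_on_virtual = accuracy(virtual_drift, positive_rule)
old_rule_on_real = accuracy(real_drift, positive_rule)
```

In this toy setting the old rule remains valid on the virtually drifted stream and fails on the stream with real drift, which is why the latter calls for replacing what was learned rather than supplementing it.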
2 Kalman Filter Modifier
The performance of an online learning model decreases if the training data or its distribution changes. From the gradient point of view, the location of the optimum changes. For example, for a single parameter as shown in Figure 1, the parameters of the pre-trained model sit at a minimum of the loss. When the distribution of the data changes, the loss at those parameters becomes relatively high, and the parameters of the online learning model move towards the new optimum. However, this change can be a significant problem for training a consistent model when the data periodically returns to a previous state or distribution [5].
We address this problem by finding an optimal estimate between the data and the changes caused by the drifts. For this purpose we train a Kalman filter, which acts as an optimal estimator [6]. This method can infer parameters from uncertain and inaccurate observations.
In our approach, we use a mini-batch method to train the model. In Equation (1), the state is the model itself, and the initial state represents the pre-trained model. Each subsequent state corresponds to the model after training on a further batch of the new data, and the output equation gives the output model. From the gradient descent point of view, the learning rate and the gradient of the model on the current batch of data enter Equation (1) as the update term. Because we omit the process noise, the noise terms are 0. This is how a neural network trained with gradient descent behaves from a linear state perspective.
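The state view of mini-batch gradient descent can be sketched for a single parameter. This is a minimal illustration under stated assumptions: the quadratic loss, the names `w`, `lr`, `batch_gradient` and `sgd_states`, and the numeric values are all hypothetical and chosen only to make the recursion visible. Each state is obtained from the previous one by subtracting the learning rate times the batch gradient, with no noise term.

```python
def batch_gradient(w, batch):
    """Gradient of the assumed loss 0.5 * mean((w - x)^2) with respect to w."""
    return sum(w - x for x in batch) / len(batch)

def sgd_states(w0, batches, lr):
    """Return the sequence of states w_0, w_1, ..., one new state per mini-batch."""
    states = [w0]
    for batch in batches:
        w_prev = states[-1]
        # State transition: previous state minus learning rate times gradient.
        states.append(w_prev - lr * batch_gradient(w_prev, batch))
    return states

# Hypothetical scenario: the pre-trained optimum is 1.0, but the new data sits at 3.0.
pretrained_w = 1.0
new_batches = [[3.0, 3.0]] * 50
states = sgd_states(pretrained_w, new_batches, lr=0.5)
```

In this sketch the state moves steadily from the pre-trained value to the optimum of the new data, so nothing in the plain recursion retains the earlier solution.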
However, a model trained in this way is neither optimal nor accurate. From the gradient algorithm we can obtain the gradient of the pre-trained model, which indicates the measurement error of the current model. Using a Kalman Filter, we have the following:
We omit all process noise and assume that the system is stable (the dynamic matrix is 0). In Equation (2), the first formula is the state prediction step and the second is the state error prediction step. The first quantity is the predicted model parameters at the given state, which we assume to be stable when no additional information is given; the second is the state error, which we also assume to be stable.
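A scalar Kalman filter applied to one parameter gives a rough picture of how such a modifier could behave. The sketch below is not the formulation used in this work: it uses the textbook predict and update steps, and it assumes that the weight after a plain gradient step is the measurement and that the squared gradient is the measurement noise variance. The function name `kalman_modifier_step`, the `eps` term and the toy values are hypothetical.

```python
def kalman_modifier_step(w_est, p_est, w_measured, grad, eps=1e-8):
    """One predict/update cycle of a scalar Kalman filter for a single weight."""
    # Predict: stable system and no process noise, so the estimate and its
    # error variance are carried over unchanged.
    w_pred = w_est
    p_pred = p_est
    # Measurement noise variance: assumed here to be the squared gradient.
    r = grad * grad + eps
    # Kalman gain: small when the gradient (measurement error) is large.
    k = p_pred / (p_pred + r)
    # Update: blend the prediction with the measured weight.
    w_new = w_pred + k * (w_measured - w_pred)
    p_new = (1.0 - k) * p_pred
    return w_new, p_new

# Toy usage: pre-trained weight 1.0, new data centred at 3.0 (hypothetical values).
w, p = 1.0, 1.0
for _ in range(5):
    grad = w - 3.0            # gradient of the assumed loss 0.5 * (w - 3)^2
    w_sgd = w - 0.5 * grad    # plain gradient step, used as the measurement
    w, p = kalman_modifier_step(w, p, w_sgd, grad)
```

Because process noise is omitted in this sketch, the error variance can only shrink, so the filtered weight moves towards the new optimum more cautiously than the plain gradient step does.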
3 Experiment Results
We train a fully-connected multi-layer neural network [7] on three datasets sequentially. Within each task, the model is trained for a fixed number of epochs, after which the training data is no longer available to the model. We constructed a set of classification tasks based on the MNIST dataset [8] (Figure 2). The data in the first task is the original MNIST dataset. In the second task we permute all the pixels of the images, which requires a completely different solution. The final task addresses the real drift problem: we change every label by adding 1 to its value (e.g. an image of a 3 is relabelled as 4). The results show that, for every training dataset, the Kalman filter allows the online learning model to respond more efficiently to the changes and to maintain better overall performance than a conventional model without any modifier (see Figure 3).
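The construction of the three tasks can be sketched as follows, with images represented as flat lists of pixel values. Several details are assumptions made for this illustration and are not stated in the text: a single fixed permutation shared by all images, the use of the original (unpermuted) images in the third task, and wrapping the label 9 around to 0. All function names are hypothetical.

```python
import random

def make_permutation(num_pixels, seed):
    """Build one fixed random pixel order, shared by every image in the task."""
    order = list(range(num_pixels))
    random.Random(seed).shuffle(order)
    return order

def permute_pixels(image, order):
    """Virtual drift: same digit and label, pixels rearranged."""
    return [image[i] for i in order]

def shift_label(label, num_classes=10):
    """Real drift: label + 1. Wrapping 9 to 0 is an assumption of this sketch."""
    return (label + 1) % num_classes

def build_tasks(images, labels, seed=0):
    """Return (images, labels) pairs for the three sequential tasks."""
    order = make_permutation(len(images[0]), seed)
    task1 = (images, labels)                                           # original data
    task2 = ([permute_pixels(img, order) for img in images], labels)   # permuted pixels
    task3 = (images, [shift_label(y) for y in labels])                 # shifted labels
    return task1, task2, task3

# Tiny stand-in for MNIST: two "images" of four pixels each.
demo_images = [[0, 1, 2, 3], [4, 5, 6, 7]]
demo_labels = [3, 9]
task1, task2, task3 = build_tasks(demo_images, demo_labels)
```

With real MNIST data, each 28 x 28 image would first be flattened to 784 values, and the model would be trained on `task1`, `task2` and `task3` in turn without revisiting earlier data.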
We present a novel online learning method that responds to data drifts by using a Kalman Filter modifier. This addresses the forgetting problem for neural networks in non-stationary environments. Our proposed method does not require any changes to the architecture of the neural network. We use the Kalman filter to adjust the learning parameters in changing environments. The Kalman Filter modifier takes the weights as the measurement value and the gradient as the measurement error. We demonstrate our approach on both virtual and real drifts and show that the proposed model remembers previously learned information and uses it to adjust the online learning parameters. The method is intrinsically recursive, so it does not need access to previously seen data. Our evaluation results show that our approach responds better to changes and has a lower learning error than a conventional model. Our future work will focus on improving the Kalman Filter modifier and comparing it with advanced methods for overcoming catastrophic forgetting in non-stationary environments.
This work is partially supported by the EU H2020 IoTCrawler project under contract number: 779852.
- (1) R. O. Duda, P. E. Hart, D. G. Stork et al., “Pattern classification,” J. of Computational Intelligence and Applications, vol. 1, pp. 335–339, 2001.
- (2) A. Tsymbal, “The problem of concept drift: definitions and related work,” Computer Science Department, Trinity College Dublin, vol. 106, no. 2, 2004.
- (3) I. Žliobaitė, “Learning under concept drift: an overview,” arXiv preprint arXiv:1010.4784, 2010.
- (4) R. Elwell and R. Polikar, “Incremental learning of concept drift in nonstationary environments,” IEEE Transactions on Neural Networks, vol. 22, no. 10, pp. 1517–1531, 2011.
- (5) J. Kirkpatrick, R. Pascanu, N. Rabinowitz, J. Veness, G. Desjardins, A. A. Rusu, K. Milan, J. Quan, T. Ramalho, A. Grabska-Barwinska et al., “Overcoming catastrophic forgetting in neural networks,” Proceedings of the National Academy of Sciences, p. 201611835, 2017.
- (6) M. B. Rhudy, R. A. Salguero, and K. Holappa, “A Kalman filtering tutorial for undergraduate students,” International Journal of Computer Science & Engineering Survey (IJCSES), vol. 8, pp. 1–18, 2017.
- (7) J. Schmidhuber, “Deep learning in neural networks: An overview,” Neural Networks, vol. 61, pp. 85–117, 2015.
- (8) Y. LeCun, C. Cortes, and C. Burges, “MNIST handwritten digit database,” AT&T Labs [Online]. Available: http://yann.lecun.com/exdb/mnist, vol. 2, 2010.