
Reducing Catastrophic Forgetting in Modular Neural Networks by Dynamic Information Balancing

Lifelong learning is an important step toward realizing robust autonomous artificial agents. Neural networks are the main engine of deep learning, which is the current state-of-the-art technique for building adaptive artificial intelligence systems. However, neural networks suffer from catastrophic forgetting when faced with continual learning. We investigate how to exploit modular topology in neural networks to dynamically balance the information load between different modules, routing inputs based on the information content of each module so that information interference is minimized. Our dynamic information balancing (DIB) technique uses reinforcement learning to guide the routing of different inputs based on a reward signal derived from a measure of the information load in each module. Our empirical results show that DIB combined with elastic weight consolidation (EWC) regularization outperforms models with similar capacity and EWC regularization across different task formulations and datasets.



1 Introduction

Lifelong learning is a key trait that characterises humans and many other animal species and is considered to give a powerful evolutionary advantage in an ever-changing ecosystem, which constantly challenges autonomous learning agents with new survival situations. However, continual lifelong learning in machine learning systems remains an open problem, and deep learning systems are no exception.

Artificial Neural Networks (ANNs), the current workhorse of deep learning systems, are parameterised nonlinear models learned by optimizing an objective function with iterative techniques, mostly stochastic gradient descent (SGD) and its variants. In its most common formulation, lifelong learning in ANNs consists of learning the model on a sequence of tasks, each defined by its own dataset. Although ANNs are able to learn new tasks, their performance on previous tasks tends to degrade. The main underlying cause is that the weights drift away from the optimal point found by the learning algorithm on earlier tasks. The term catastrophic forgetting was coined to refer to this phenomenon.

Catastrophic forgetting was described very early in ANN history (McCloskey and Cohen, 1989; Ratcliff, 1990). The phenomenon is closely related to the synaptic stability-plasticity dilemma (Abraham and Robins, 2005), which remains an open problem in neuroscience. According to Hebb (1949), the learning process in neural circuits is mediated by a change in synaptic strength between neurons through potentiation and depression. This raises the problem of how new experiences are consolidated into synapses without erasing previous experiences, or, in other words, how the learning process in the brain strikes the delicate balance between being plastic enough to allow continual learning and stable enough that previously stored information is not damaged. Two main variants of catastrophic forgetting were identified in artificial models: spatial and temporal (Jacobs et al., 1991). Spatial interference happens when a given weight in the network receives conflicting error signals from its outputs, whereas temporal interference is the conflict in the error signals due to different data samples. As far as lifelong learning is concerned, the most impactful factor is the temporal interference caused by the different data distributions of successive tasks.

Five general approaches have been developed for reducing catastrophic forgetting in ANNs, namely, regularization, ensemble, replay, dual-memory and sparse-coding (Kemker et al., 2017). The first approach is model regularization, for example elastic weight consolidation (EWC) (Kirkpatrick et al., 2016) and synaptic intelligence (SI) (Zenke et al., 2017). Given the weights learned on the previous task, the main idea is to regularize each weight, according to its importance for the previous task, to stay close to its previous optimal value while allowing some flexibility to learn the new task. EWC uses Fisher information as a measure of weight importance, while SI uses the path integral of the gradient.
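The EWC penalty described above can be sketched in a few lines. This is a minimal NumPy illustration of the standard quadratic EWC term, $\frac{\lambda}{2}\sum_i F_i(\theta_i-\theta^*_i)^2$, which is added to the new task’s loss; the weights, Fisher values and $\lambda$ are hypothetical, and the function name `ewc_penalty` is ours, not from this paper.

```python
import numpy as np

def ewc_penalty(theta, theta_star, fisher_diag, lam):
    """Quadratic EWC term: (lam / 2) * sum_i F_i * (theta_i - theta*_i)^2.

    theta       -- current (flattened) weights while learning the new task
    theta_star  -- weights found at the end of the previous task
    fisher_diag -- diagonal Fisher information estimated on the previous task
    lam         -- regularization strength (hypothetical value below)
    """
    return 0.5 * lam * np.sum(fisher_diag * (theta - theta_star) ** 2)

theta_star = np.array([1.0, -2.0, 0.5])
fisher = np.array([10.0, 0.1, 0.0])    # weight 0 is important, weight 2 is irrelevant
theta = np.array([1.5, -1.0, 3.0])     # all three weights have moved

# Total objective on the new task would be: new_task_loss + penalty.
penalty = ewc_penalty(theta, theta_star, fisher, lam=1.0)
```

Because the penalty is weighted per parameter, moving the important first weight by 0.5 costs more than moving the third weight by 2.5, which costs nothing here.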

The second approach is ensemble methods (Polikar et al., 2001; Dai et al., 2007; Fernando et al., ). Ensemble methods train multiple classifiers, assigning a new classifier to each new task, and then integrate their predictions into a final output. In its naive implementation, the technique suffers from memory usage that grows with every learned task and from an inability to transfer knowledge between tasks, which motivated further research to mitigate these problems.

The third main approach is memory replay (Shin et al., ; van de Ven and Tolias, 2018; Isele and Cosgun, 2018). Memory replay is biologically plausible, as it resembles the hypothesis that the memory-consolidating effect of sleep in animals is mediated by the replay of spike sequences learned during wakefulness (Wei et al., 2018). While regularization directly restricts the flexibility of the weights, memory replay injects data samples from previous tasks to counteract the network’s tendency to shift its learned distribution entirely to the new task and thus forget previous experiences. Memory replay techniques differ in how they acquire previous data samples and how they introduce them into the learning process.

Dual-memory is the fourth approach, which is biologically inspired by memory models of the mammalian brain (McClelland et al., 1995; Kumaran et al., 2016). The model assumes that two different networks are used for new and old memories: newly formed memories are stored in one network and then slowly consolidated into the other (French, 1997; Kamra et al., 2017). The consolidation is usually done by simultaneous learning on mixed samples drawn from the new and old memory networks.

Finally, sparse-coding techniques are motivated by the idea that interference happens because of overlapping internal representations; hence, introducing carefully engineered sparsity should, in principle, reduce this destructive overlap (Kruschke, 1991; Coop et al., 2013; Murdock, 1983; Eich, 1982). The main limitation of the technique is that sparsity may hinder generalization and reduce the model’s capacity to learn new tasks.

van de Ven and Tolias (2018) identified three main scenarios of continual learning. Task-incremental learning (Task-IL) is a regime of learning sequential tasks in which the model is always provided with the task identity. In the second scenario, termed domain-incremental learning (Domain-IL), the task identity is unknown to the model, and the model is only required to solve the task at hand, without explicitly identifying which task it is. In the last scenario, class-incremental learning (Class-IL), the task identity is also unknown, but the model is required both to solve the task and to identify the task identity.

In this paper, we introduce a new technique, focused on Task-IL, for reducing catastrophic forgetting by exploiting modular ANNs. The main idea is to distribute the information load between different network modules in order to reduce the catastrophic forgetting caused by interference. To do that, we need three main components: a way to route information between different modules, a measure to guide the routing based on the information load accumulating in each module, and a way of memorizing the path distributions across multiple tasks. The memorization part is necessary because routing essentially associates an input with a route. The target route is trained on the routed input, so at inference time similar inputs must be directed to the same route to attain good performance. For the routing part, we rely on a class of deep models called routing ANNs (McGill and Perona, 2017; Rosenbaum et al., 2017; Cai et al., 2019), in which the input is routed along different paths of the network based on some criteria. To make the routing reduce information destruction, we guide it with an approximation of the empirical joint Fisher information using reinforcement learning (RL). The third component, i.e., the memory, is realised using a dedicated per-task network that is learned in a supervised manner guided by the routing component. Our main contributions in this paper are:

  • Introducing and investigating the idea of dynamic information balancing (DIB) between different modules in an ANN as a way of alleviating catastrophic forgetting.

  • Using joint Fisher information as a guiding measure for routing patterns by RL through modular ANNs in order to reduce information interference.

  • Approximating empirical joint Fisher information in order to make the routing RL reward tractable.

  • Introducing a component memory network as a way of preserving path distributions across tasks.

  • Achieving better generalization results than other models with similar capacity across different task formulations and datasets with diverse distributions.

2 Related Work

Because catastrophic forgetting limits our ability to implement continual learning in deep models, many research papers have tried to shed light on the phenomenon and develop techniques for overcoming it. The early work by Srivastava et al. (2013) showed that using local winner-take-all (LWTA) neurons improves the test error on sequential tasks, suggesting that LWTA suppresses catastrophic forgetting. Goodfellow et al. (2013) investigated the effect of dropout and of the choice of activation function on catastrophic forgetting. Their empirical results consistently suggest that training with dropout reduces catastrophic forgetting across different datasets. In contrast, different activation functions ranked differently under different conditions, which weakens the case for a general advantage of any particular activation function in reducing catastrophic forgetting. Dropout belongs to the regularization family of catastrophic forgetting reduction techniques, which is elaborated on below, whereas activation function selection is not a mainstream technique for catastrophic forgetting reduction.

Different approaches have been suggested for reducing catastrophic forgetting. They can be classified into five main categories: regularization, ensemble, replay, dual-memory and sparse-coding (Kemker et al., 2017; Parisi et al., 2019). One of the major contributions among regularization methods is the elastic weight consolidation (EWC) technique, introduced by Kirkpatrick et al. (2016). EWC restricts the flexibility of the weights so that they do not drift too far from the local minimum found in previous tasks. In EWC, the regularization strength of each weight is determined using Fisher information. Synaptic intelligence (SI) (Zenke et al., 2017) is another regularization method that uses the path integral of the gradient vector as a measure of how strongly to restrict each weight. Incremental moment matching (IMM) (Lee et al., 2017) takes a Bayesian approach and regularizes the moments of the new task’s posterior distribution based on the previous task’s posterior distribution. Kaplanis et al. (2018) regularize the model by implementing a more complex, biologically inspired synapse that takes into account previous modifications applied to the synaptic weight.
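To make “path integral of the gradient” concrete, the sketch below accumulates a synaptic-intelligence-style importance along a plain SGD trajectory: each step adds $-g_k\,\Delta\theta_k$ per parameter, and the sum is normalized by the squared total displacement plus a damping term $\xi$, following Zenke et al. (2017). The quadratic task loss, learning rate, step count and $\xi$ are hypothetical choices for illustration.

```python
import numpy as np

def si_importance(theta0, grad_fn, lr=0.1, steps=50, xi=1e-3):
    """SI-style per-parameter importance accumulated over one task trained with SGD."""
    theta = np.array(theta0, dtype=float)
    omega = np.zeros_like(theta)          # running path integral, one entry per parameter
    for _ in range(steps):
        g = grad_fn(theta)
        delta = -lr * g                   # plain SGD step
        omega += -g * delta               # accumulate -gradient * parameter change
        theta = theta + delta
    total_change = theta - theta0
    # Normalize by the squared total displacement; xi avoids division by zero.
    return omega / (total_change ** 2 + xi), theta

# Hypothetical task loss 0.5 * sum(a * (theta - c)^2): parameter 0 is much stiffer.
a = np.array([5.0, 0.5])
c = np.array([1.0, 1.0])
importance, theta_end = si_importance(np.zeros(2), lambda th: a * (th - c))
```

In this toy, the stiffer parameter ends up with the larger importance, so an SI penalty would restrict it more strongly on the next task.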

Ensemble methods assign an additional model to each new task. Learn++ (Polikar et al., 2001; Dai et al., 2007) uses algorithms similar to AdaBoost, where a sequence of classifiers is learned and their predictions are combined using weighted majority voting. PathNet (Fernando et al., ) is a form of implicit ensemble method, where genetic algorithms are used to select a pathway through a large network, which is then trained for the current task. After convergence, the pathway is frozen and another pathway is selected for the next task. A similar implicit, path-based technique is progressive networks (Rusu et al., 2016), where a new column is added to a multi-column network for each new task, while the columns of previous tasks are frozen. Each new column is connected to previous columns via lateral connections to promote information reuse. The main limitation of most ensemble techniques is that their memory complexity grows with the number of tasks, since a whole model (or a component module) needs to be stored for each task.
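The final combination step of Learn++-style ensembles can be illustrated with a generic weighted majority vote. This is a minimal sketch with hypothetical labels and classifier weights; Learn++ derives each classifier’s weight from its training error, which is not reproduced here.

```python
from collections import defaultdict

def weighted_majority_vote(predictions, weights):
    """Return the label with the largest total weight across the ensemble's classifiers."""
    scores = defaultdict(float)
    for label, w in zip(predictions, weights):
        scores[label] += w                # each classifier votes with its own weight
    return max(scores, key=scores.get)

# Two lower-weighted classifiers together outvote a single higher-weighted one.
winner = weighted_majority_vote(["cat", "dog", "dog"], [0.5, 0.3, 0.3])
```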

Replay methods mix samples from previous tasks into the training process to balance learning across tasks. Shin et al. use a generative model, accompanying the main network, that is learned on the data distributions of the previous tasks. The generative model is used to sample inputs from previous tasks, which are mixed with the current task’s samples during training. van de Ven and Tolias (2018) integrate the generative model into the main network by introducing feedback connections that are trained to reconstruct inputs from hidden states, removing the need for a separate generative model. Isele and Cosgun (2018) investigate which experiences to select for replay in order to best reduce catastrophic forgetting. They compare four selection strategies: surprise, which favors experiences the model finds surprising as measured by the prediction error; reward, which favors experiences with a strong assigned reward; global distribution matching, which aims to capture the global distribution of all tasks combined; and coverage maximization, which favors a distribution that covers as much of the input state space as possible.

A hybrid technique that combines regularization and replay is gradient episodic memory (GEM), proposed by Lopez-Paz and Ranzato (2017). As in replay, an episodic memory stores samples from previous tasks, but instead of injecting them as learning inputs, GEM uses them to constrain the learning of subsequent tasks so that the loss on previous tasks does not increase. Sodhani et al. (2019) combine GEM with a network expansion technique called Net2Net (Chen et al., 2015), so that the network is expanded when a target performance on a specific task is not achieved.
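For a single previous task, GEM’s constraint has a closed-form solution: if the proposed gradient $g$ conflicts with the gradient $g_{\text{ref}}$ of the loss on the stored samples (negative inner product), it is projected onto the half-space $\langle \tilde g, g_{\text{ref}}\rangle \ge 0$. The general method solves a quadratic program with one constraint per previous task; this NumPy sketch covers only the one-constraint case, with hypothetical gradients.

```python
import numpy as np

def project_gradient(g, g_ref):
    """Closest vector to g whose inner product with g_ref is non-negative.

    With a single constraint, GEM's quadratic program reduces to this projection,
    so a step along the returned direction does not increase the old task's loss
    to first order.
    """
    dot = g @ g_ref
    if dot >= 0:                              # no conflict: keep the proposed direction
        return g
    return g - (dot / (g_ref @ g_ref)) * g_ref

g_new = np.array([1.0, -1.0])   # gradient of the current task's loss
g_old = np.array([0.0, 1.0])    # gradient of the loss on stored samples of an old task
g_proj = project_gradient(g_new, g_old)
```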

The dual-memory approach is biologically inspired and separates the learning of new memories across two different networks: the first is responsible for short-term memories, which are then consolidated into the long-term memory represented by the second network. French (1997) proposes a network composed of two parts, an early-processing memory and a final-storage memory. The early-processing memory is trained using real data samples and pseudo-samples drawn from the final-storage memory by presenting random inputs. After convergence, the weights of the early-processing memory are transferred to the final-storage memory. The motivation is for the early-processing memory to learn new data samples mixed with samples drawn from the final-storage memory, which already has the old experience consolidated into its data distribution. The approach shares some similarity with replay methods. Kamra et al. (2017) use deep generative models to sample previous task distributions, similarly to replay methods. However, their method assigns a new deep generative model, regarded as a short-term memory (STM), to each new task encountered. After training on multiple tasks, the STMs trained so far are consolidated into a larger long-term memory (LTM) generative network by unsupervised training on samples from all of the STMs and samples from the LTM itself, which represent the old distributions consolidated into the LTM so far.

Sparse-coding methods are based on the assumption that catastrophic forgetting is mainly due to the interference of internal representations; hence, introducing carefully crafted sparsity should, in principle, reduce representational overlap. ALCOVE (Kruschke, 1991) relies on attention gates applied to hidden nodes, such that nodes are activated based on a similarity measure with the input, which serves as a proxy for task similarity. Another set of algorithms (Murdock, 1983; Eich, 1982) stores representations as a superposition of individual states, using convolution and correlation as the operators for storage and retrieval, respectively. A fixed expansion layer (FEL) (Coop et al., 2013) is a sparse hidden layer initialized in a special way, using a mix of fixed excitatory and inhibitory weights. The motivation is to activate different nodes of the layer for different inputs, hence reducing destructive overlap.

We rely on dynamic routing for input redirection in DIB. Routing is a way of dynamically composing modules conditioned on some criterion, most commonly the input or some representation of it. A routing ANN is a neural network composed of different modules and a routing subnetwork. Given the input, the network is trained to compose a set of modules suitable for handling the input based on the routing subnetwork’s decisions, along with the usual weight optimization associated with any ANN. There is neuroscientific evidence that dynamic routing occurs in the primate visual cortex: Goodale et al. (1994) argued that spatial information is processed separately from object identity in the primate visual cortex. McGill and Perona (2017) use a pyramid of descriptors (Ke et al., 2016) as the input to each routing layer, and a two-way routing subnetwork decides whether to carry on to the next layer or to stop the signal and produce the final output. The routing subnetwork is trained using different RL algorithms, and the training criteria include two penalties balancing accuracy, i.e., more processing, against efficiency, i.e., less depth. They regularize only the activated paths to avoid under-constraining frequently used paths and over-constraining infrequently used ones. Rosenbaum et al. (2017, 2019) use a global router that is provided with auxiliary information about the current depth. Training is done using a multi-agent RL (MARL) algorithm called the weighted policy learner (WPL), which controls the learning rate based on the agent’s confidence. They experiment with combining two reward signals: a global final reward based on network performance, and a local reward after each action, which encourages the agent to minimize the depth of the dynamic path.

The recurrent model defined by Hafner et al. (2017) uses routing in a different, implicit way. The model is inspired by the cortico-thalamo-cortical pathway and is composed of different modules, each connected to a central routing center. At each time step, the modules read from the routing center using a reading mechanism, and all module outputs are integrated back into the routing center. They argue that the routing is done at the information level in a hierarchical fashion. Cai et al. (2019) propose a routing approach related to neural architecture search (NAS). During training, the outputs of different paths are combined by a weighted sum using the Gumbel-softmax technique (Jang et al., 2016; Maddison et al., 2016), while only the top-k paths are selected at inference time.
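The training/inference asymmetry described for Cai et al. (2019) can be sketched as follows: path outputs are mixed with Gumbel-softmax weights during training, and only the top-k paths by router score are kept at inference. The path logits, path outputs, temperature and k are hypothetical, and how the selected paths are then combined is left out.

```python
import numpy as np

def gumbel_softmax(logits, tau, rng):
    """Relaxed one-hot sample: softmax((logits + Gumbel noise) / tau)."""
    u = rng.uniform(low=1e-12, high=1.0, size=logits.shape)   # avoid log(0)
    y = (logits - np.log(-np.log(u))) / tau                   # -log(-log u) is Gumbel noise
    y = np.exp(y - y.max())                                   # numerically stable softmax
    return y / y.sum()

rng = np.random.default_rng(0)
path_logits = np.array([2.0, 0.5, -1.0, 0.1])   # hypothetical router scores for 4 paths
path_outputs = rng.normal(size=(4, 3))          # hypothetical output of each path

# Training: a soft, differentiable mixture of all path outputs.
w = gumbel_softmax(path_logits, tau=0.5, rng=rng)
train_out = w @ path_outputs

# Inference: keep only the k highest-scoring paths.
k = 2
top_k = np.argsort(path_logits)[::-1][:k]
```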

Related routing mechanisms were introduced by Sabour et al. (2017) and Hinton et al. (2018) in the context of capsule networks, which generalize scalar CNN features to vector representations, termed capsules. Routing between capsules is done using either routing by agreement (Sabour et al., 2017) or expectation-maximization (EM) (Hinton et al., 2018).

In this work, we introduce dynamic information balancing (DIB) as a new method for reducing catastrophic forgetting in modular neural networks (MNNs). DIB combines modular routing ANNs with an approximation of the empirical joint Fisher information as a reward signal for routing in order to reduce information destruction. DIB does not fit neatly into any of the categories of catastrophic forgetting methods mentioned above, since the core algorithm does not use a form of regularization, combine different trained models (ensemble), inject past experience into the training process (replay), consolidate memories into a global pool (dual-memory) or enforce sparsity (sparse-coding). That said, its task-specific memory component relates it to explicit ensemble models, although its memory footprint is much smaller, since only a memory module comprising a small fraction of the model’s total number of parameters is stored per task. DIB can also be related to PathNet and implicit ensembling, since it exploits different paths through an MNN; however, it differs significantly in the routing mechanism and path selection. It is also loosely related to dual-memory methods, since a memory subnetwork is used for each task.

3 Dynamic Information Balancing

The aim of dynamic information balancing (DIB) is to reduce destructive information interference by balancing the information content across different modules in a modular neural network (MNN). The main components needed to realise a system like this are:

  • A modular neural network architecture.

  • A routing mechanism for routing inputs.

  • A measure to guide routing such that the information content is balanced.

These components interact as follows to achieve information balancing: given an input, the routing mechanism uses the information measure of each module to decide which modules carry the least information load and routes the input through them. After the modules’ weights are updated, the information measure of each module is updated to reflect the new information load.
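The intended interaction can be illustrated with an idealized, greedy router that has direct access to each module’s load. This is a hypothetical sketch: the load values and the constant per-update `increment` are stand-ins, and DIB itself replaces the exact `min` over loads with a router learned by RL, since measuring every path’s load exactly is too expensive.

```python
def route_and_update(loads, increment):
    """Idealized balancing step: pick the least-loaded module, then update its load."""
    chosen = min(range(len(loads)), key=loads.__getitem__)   # ties go to the lowest index
    loads[chosen] += increment   # stand-in for the load change after a weight update
    return chosen

loads = [0.0, 0.0, 0.0, 0.0]     # hypothetical information load of 4 modules
choices = [route_and_update(loads, increment=1.0) for _ in range(10)]
```

Ties are broken by module index, so the ten inputs cycle through the four modules and the loads never differ by more than one increment.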

We use a modular architecture composed of sequentially stacked DIB cells (Figure 1). Each DIB cell has three main components. The first is the set of modules used as a learning substrate. The second is the router, a subnetwork that routes the input to the different modules based on the information measure. The third is the memory network (MemNet), a subnetwork trained by supervised learning to shadow the router, using the router’s output as the target signal. At inference time, the router is discarded and MemNet is used instead to route inputs.

Routing by RL is needed because searching the space of every possible path and calculating an information measure for each of these paths has infeasible combinatorial complexity. Instead, we rely on a router learned by RL and guided by a reward derived from an information measure.

Figure 1: DIB cell architecture (MemNet is omitted). White circle: active connection. Black circle: inactive connection. The router activates one module at a time. At inference time, decisions are made by a MemNet receiving raw input.

MemNet is essential for the functioning of the system in the continual learning paradigm. Given a new task, we initialize a new set of MemNets (one per cell), which are trained throughout the task by shadowing the router decisions. At inference time, the routers are discarded, and the MemNets of the task at hand are loaded and used for making routing decisions. The need for MemNet arises from catastrophic interference in the router itself. During task training, the router makes reasonable routing decisions based on the modules’ information load, but as the input distribution (coming from the previous DIB cell) shifts between tasks, the router forgets the decisions it made for previous distributions. Because MemNet shadows the router decisions for each task, any task’s input can be routed correctly at inference time.

Although the MemNet approach is partly similar to ensemble methods, its memory requirement is much lower. Decoupling the architectures of the router and MemNet allows MemNet to be smaller than the router and much smaller than the total model, which considerably reduces memory complexity.

Figure 2: DIB cells wiring diagram (modules are omitted). As depicted, routers receive the output of the previous DIB cell as an input (except in the first cell). In contrast, MemNets always receive the raw input.

The router receives the previous DIB cell’s output as a conditioning input for making routing decisions. In contrast, MemNet receives the raw input rather than the previous DIB cell’s output (Figure 2). The main reason is that the previous DIB cell’s output changes as the model is trained on a sequence of tasks. While this change may carry useful information for the router’s decisions, it confuses MemNet, whose main task is to associate an input pattern with a routing decision at inference time; hence, MemNet would not function properly if its inputs changed.
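A minimal NumPy sketch of the two-cell wiring follows. Assumptions: random matrices stand in for trained layers (biases and the second hidden layer of each module are omitted), sizes are tiny, exploration is disabled so routing is a plain `argmax`, and MemNets are fitted to the router’s choices with a cross-entropy loss; the paper only states that MemNet is trained in a supervised way using the router’s output as the target. All names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
D_IN, D_HID, D_OUT, K = 8, 6, 3, 4   # hypothetical sizes; K modules per cell

def linear(d_in, d_out):
    """Random weight matrix standing in for a trained layer (biases omitted)."""
    return rng.normal(scale=0.1, size=(d_in, d_out))

hidden_modules = [linear(D_IN, D_HID) for _ in range(K)]
output_modules = [linear(D_HID, D_OUT) for _ in range(K)]
hidden_router = linear(D_IN, K)    # first cell: the router sees the raw input
output_router = linear(D_HID, K)   # second cell: the router sees the hidden cell's output
hidden_memnet = linear(D_IN, K)    # MemNets always see the raw input
output_memnet = linear(D_IN, K)

def forward(x, use_memnet):
    """Hard routing: each input passes through exactly one module per cell."""
    scores = x @ (hidden_memnet if use_memnet else hidden_router)
    a_hid = scores.argmax(axis=1)
    h = np.stack([np.maximum(x[i] @ hidden_modules[a], 0.0) for i, a in enumerate(a_hid)])
    # Output cell: the router conditions on h, but the MemNet still conditions on raw x.
    scores = (x @ output_memnet) if use_memnet else (h @ output_router)
    a_out = scores.argmax(axis=1)
    logits = np.stack([h[i] @ output_modules[a] for i, a in enumerate(a_out)])
    return logits, a_hid, a_out

def memnet_loss(memnet_scores, router_actions):
    """Cross-entropy between the MemNet's routing distribution and the router's choices."""
    z = memnet_scores - memnet_scores.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(router_actions)), router_actions].mean()

x = rng.normal(size=(5, D_IN))
logits, a_hid, a_out = forward(x, use_memnet=False)          # training: routers decide
loss = memnet_loss(x @ hidden_memnet, a_hid) + memnet_loss(x @ output_memnet, a_out)
```

At inference time, `forward(x, use_memnet=True)` routes with the MemNets only; both MemNets read the raw `x`, while the output-cell router reads the hidden activation `h`.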

The router is trained using deep Q-learning (DQN) (Mnih et al., 2013). Given the router’s nonnormalized output vector $q \in \mathbb{R}^K$, where $K$ is the number of possible actions in the action set (each action corresponds to choosing the module with the corresponding index), the router’s loss is calculated as:

$$\mathcal{L}_R = \left(r - q_a\right)^2$$

where $a$ is the chosen action, $r$ is the reward gained by taking the action and $q_a$ is the logit of the router’s output for the same action. The action is chosen according to an $\epsilon$-greedy policy, where $\epsilon \in [0, 1]$: a random action is chosen with probability $\epsilon$ and the router’s optimal action, $\arg\max_k q_k$, is chosen with probability $1 - \epsilon$.
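The action selection and router update can be sketched as below, assuming each routing decision is treated as a one-step episode, so the Q-learning target is simply the observed reward and the loss is the squared error between the reward and the chosen logit. The Q-values, reward and function names are hypothetical.

```python
import numpy as np

def epsilon_greedy(q, epsilon, rng):
    """With probability epsilon pick a random module, otherwise the argmax of q."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q)))
    return int(np.argmax(q))

def router_loss(q, action, reward):
    """One-step Q-learning loss: squared error between the reward and the chosen logit."""
    return (reward - q[action]) ** 2

rng = np.random.default_rng(0)
q = np.array([0.2, -0.1, 0.7])                # hypothetical router output, K = 3 modules
a = epsilon_greedy(q, epsilon=0.0, rng=rng)   # epsilon = 0: purely greedy, picks module 2
loss = router_loss(q, a, reward=-0.3)         # (-0.3 - 0.7)^2 = 1.0
```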

The essence of DIB is guiding the routing process to reduce information interference. With DQN-based routing, this translates into a reward signal that reflects the information load of each possible action or path. Intuitively, the information load of a given module refers to how much information is packed into the specific values of the module’s parameters. Fisher information is a measure of how much information a parameter holds about the distribution it is modeling. Hence, we use empirical Fisher information, with some modifications, as a proxy for information load. Given a dataset $\{(x_n, y_n)\}_{n=1}^{N}$ of size $N$, the usual empirical Fisher information is calculated as:

$$F(\theta_i) = \frac{1}{N}\sum_{n=1}^{N}\left(\frac{\partial \log p(y_n \mid x_n, \theta)}{\partial \theta_i}\right)^2$$

where $\theta$ is the set of model weights and $i$ is the index of the corresponding weight; parameterising by $\theta_i$ reflects that we are using only the diagonal of the Fisher matrix (Kirkpatrick et al., 2016). Because the square is applied to each per-sample gradient before summation, expressing the empirical Fisher as a matrix-matrix operation, and hence accelerating its calculation on GPUs, is infeasible. For regularization purposes like EWC (Kirkpatrick et al., 2016), this is not a serious problem, since the Fisher diagonal is only calculated between tasks. However, for our purpose of continuously guiding the routing process, we need a continuous, efficient way of calculating the empirical Fisher diagonal; otherwise, the router’s estimates of the information loads will quickly become inaccurate and the routing decisions will start to cause information destruction.
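To see why the usual empirical Fisher diagonal is awkward to batch, the sketch below computes it for a binary logistic-regression model, where the per-sample gradient of $\log p(y_n \mid x_n, \theta)$ is $(y_n - \sigma(\theta^\top x_n))\,x_n$. The model and data are hypothetical stand-ins for a DIB module; the point is that every sample’s gradient must be materialized and squared before averaging.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def per_sample_log_lik_grads(w, X, y):
    """d/dw log p(y_n | x_n, w) for binary logistic regression, one row per sample."""
    return (y - sigmoid(X @ w))[:, None] * X

def empirical_fisher_diag(w, X, y):
    """Diagonal empirical Fisher: mean over samples of the squared per-sample gradient."""
    g = per_sample_log_lik_grads(w, X, y)   # N separate gradients are needed...
    return (g ** 2).mean(axis=0)           # ...because squaring happens before averaging

X = np.array([[1.0, 2.0], [3.0, 0.0]])
y = np.array([1.0, 0.0])
w = np.zeros(2)
F = empirical_fisher_diag(w, X, y)   # per-sample grads [[0.5, 1], [-1.5, 0]] -> [1.25, 0.5]
```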

We make two approximations to allow the Fisher information to be used continuously as a routing signal. The first is approximating the empirical Fisher information by the joint empirical Fisher information over a minibatch. The term joint Fisher information refers to the fact that the Fisher information is calculated for the joint probability distribution of multiple samples, instead of the probability of a single sample. Given a minibatch $\mathcal{B} = \{(x_b, y_b)\}_{b=1}^{B}$, we calculate the joint Fisher information as:

$$F_J(\theta_i) = \left(\frac{\partial \sum_{b=1}^{B} \log p(y_b \mid x_b, \theta)}{\partial \theta_i}\right)^2$$

where $b$ is the batch index. Note that the sum of log probabilities in the numerator corresponds to the log of the joint probability of the samples in the minibatch under the i.i.d. assumption. Unlike the per-sample version, this quantity only requires the gradient of the summed log-likelihood, i.e., a single batched backward pass, followed by an elementwise square. For our purpose, this approximation is justified because our routing aims mainly at balancing the information content of the current probability distribution, which can be approximated by the joint probability over a random minibatch sample. The second approximation arises because we route over different paths, each containing many weights, so we average across all the parameters in the activated modules:

$$\bar{F}_J = \frac{1}{|\Theta_{\mathcal{A}}|}\sum_{i:\,\theta_i \in \Theta_{\mathcal{A}}} F_J(\theta_i)$$

where $\mathcal{A}$ is the set of activated modules, $\Theta_{\mathcal{A}}$ is the set of parameters in the activated modules, $|\Theta_{\mathcal{A}}|$ is the number of these parameters and the sum runs over $i$, the index of a parameter in $\Theta_{\mathcal{A}}$.
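For a hypothetical logistic-regression model, the joint Fisher needs only the gradient of the summed log-likelihood, i.e., one batched matrix product, followed by an elementwise square; averaging over the activated modules then reduces it to a single load value. The partition of parameters into two “modules” is hypothetical, and no normalization by the batch size is applied, which is an assumption.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def joint_fisher_diag(w, X, y):
    """Square of the gradient of sum_b log p(y_b | x_b, w), computed in one batched pass."""
    grad_joint = X.T @ (y - sigmoid(X @ w))   # gradient of the summed log-likelihood
    return grad_joint ** 2

def mean_fisher_over_modules(fisher, module_params, active):
    """Average the Fisher diagonal over every parameter of the activated modules."""
    idx = np.concatenate([module_params[m] for m in active])
    return float(fisher[idx].mean())

X = np.array([[1.0, 2.0, 0.0, 1.0], [3.0, 0.0, 1.0, 1.0]])
y = np.array([1.0, 0.0])
w = np.zeros(4)
F_joint = joint_fisher_diag(w, X, y)                          # [1.0, 1.0, 0.25, 0.0]
module_params = {0: np.array([0, 1]), 1: np.array([2, 3])}    # hypothetical partition
load_module1 = mean_fisher_over_modules(F_joint, module_params, active=[1])   # 0.125
```

On this data the per-sample empirical Fisher of the first two parameters would be [1.25, 0.5], whereas the joint version gives [1.0, 1.0]; for the last parameter the two samples’ gradients (+0.5 and -0.5) cancel, so its joint Fisher is 0. The joint quantity is therefore a different statistic, used here as a cheap proxy for load rather than as an exact Fisher estimate.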

Since the averaged joint Fisher information is intuitively a measure of the information load on each path, the reward is the negative of this quantity, which encourages the router to avoid paths with information congestion:

$$r = -\lambda \bar{F}_J$$

where $\lambda$ is a weighting hyperparameter that controls the strength of the reward. Since this reward is calculated in a minibatch setting, it is the reward used for every routing decision in the minibatch. Hence, given the set of a router’s actions for the minibatch inputs, $\{a_b\}_{b=1}^{B}$, the total minibatch routing loss for that router is calculated as:

$$\mathcal{L}_R = \sum_{b=1}^{B}\left(r - q^{(b)}_{a_b}\right)^2$$

where $q^{(b)}$ is the router’s output for the $b$-th input.
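The reward and minibatch routing loss can be sketched as follows: one scalar reward per minibatch, shared by every routing decision, and a squared error summed over the batch. The router outputs, actions, load value and $\lambda$ are hypothetical; in a framework with automatic differentiation, the reward would be treated as a constant target when differentiating with respect to the router’s parameters.

```python
import numpy as np

def routing_reward(mean_joint_fisher, lam):
    """Negative information load: congested paths earn a lower reward."""
    return -lam * mean_joint_fisher

def minibatch_routing_loss(q_batch, actions, reward):
    """Sum over the minibatch of (shared reward - Q-value of the chosen module)^2."""
    chosen_q = q_batch[np.arange(len(actions)), actions]
    return np.sum((reward - chosen_q) ** 2)

q_batch = np.array([[0.0, -0.5], [-0.2, 0.1], [0.3, -0.4]])   # hypothetical outputs, B = 3, K = 2
actions = np.array([1, 0, 1])                                 # one routing decision per input
r = routing_reward(mean_joint_fisher=0.5, lam=1.0)            # r = -0.5
loss = minibatch_routing_loss(q_batch, actions, r)            # 0 + 0.09 + 0.01 = 0.1
```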


Algorithm 1 details how the overall algorithm works. In summary, we loop over the tasks’ datasets. In each epoch of a task’s training loop, we calculate the DIB model’s outputs on a given minibatch, which include the classification outputs and each DIB cell’s router and MemNet outputs. We calculate the classification loss from the classification outputs and the targets, and the memory loss from the MemNets’ outputs and the corresponding routers’ outputs. Then, for the modules activated by the routers’ decisions, we calculate the joint Fisher information using the classification outputs and the targets. From the minibatch’s joint Fisher information, we calculate the reward, which is then used to reward the routers’ decisions. Finally, we backpropagate and update the relevant parameters for each loss. Note that there are three different losses: a classification loss related to the classification outputs, a memory loss related to the supervised training of the MemNets and a routing loss related to the RL of the routers. The classification loss should affect only the modules’ weights, the memory loss only the MemNets’ weights and the routing loss only the routers’ weights.

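Input: A set of datasets corresponding to different tasks, one dataset per task. A DIB model having a set of modules, with the associated routers and MemNets. A reward weight hyperparameter. The number of training epochs.
for each task’s dataset do
       initialize a new MemNet
       for each training epoch do
             foreach minibatch of the task’s dataset do
                   compute the classification outputs and each cell’s router and MemNet outputs
                   compute the classification loss and the memory loss
                   compute the joint Fisher information of the activated modules and the resulting reward
                   compute the routing loss from the reward
                   update the modules, MemNets and routers, each with its own loss
             end foreach
       end for
end for
Output: Trained DIB model and a MemNet for each task.
Algorithm 1 DIB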
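The control flow of Algorithm 1 can be sketched in Python as below. The callbacks `new_memnet` and `train_step` are hypothetical placeholders for the model-specific parts (forward pass, the three losses, the reward and the parameter updates); the sketch only fixes the loop structure and the rule that each loss updates a single parameter group.

```python
from typing import Callable, Dict, Iterable, List, Sequence

# Parameter group that each loss is allowed to update (per the text above).
LOSS_TO_PARAMS: Dict[str, str] = {
    "classification": "modules",
    "memory": "memnet",
    "routing": "router",
}


def train_dib(
    task_datasets: Sequence[Iterable],
    epochs: int,
    new_memnet: Callable[[int], object],
    train_step: Callable[[object, object], Dict[str, float]],
) -> List[object]:
    """Control flow of Algorithm 1.

    new_memnet(task_id) -> a fresh task-specific MemNet (hypothetical).
    train_step(memnet, minibatch) -> {loss_name: value}; it is assumed to run
    the forward pass, compute the three losses and the reward, and apply each
    loss only to the parameter group named in LOSS_TO_PARAMS (hypothetical).
    """
    memnets = []
    for task_id, dataset in enumerate(task_datasets):
        memnet = new_memnet(task_id)  # one new MemNet per task
        for _ in range(epochs):
            for minibatch in dataset:
                losses = train_step(memnet, minibatch)
                if set(losses) != set(LOSS_TO_PARAMS):
                    raise ValueError(f"unexpected losses: {sorted(losses)}")
        memnets.append(memnet)  # kept for routing at inference time
    return memnets


# Usage with stub callbacks: 2 tasks x 2 epochs x 3 minibatches = 12 steps.
calls = []


def fake_step(memnet, minibatch):
    calls.append((memnet, minibatch))
    return {"classification": 0.0, "memory": 0.0, "routing": 0.0}


memnets = train_dib([[1, 2, 3], [4, 5, 6]], epochs=2,
                    new_memnet=lambda t: f"memnet-{t}", train_step=fake_step)
```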

4 Experiments

We apply DIB to three different Task-IL datasets: two common lifelong learning benchmarks, PermutedMNIST and SplitMNIST, and a third, more complex camera trap dataset, iWildCam2019, which we preprocess and use in a way similar to SplitMNIST. For each dataset we compare the DIB model against an MLP with a similar capacity. The DIB model we use for all of our experiments is made of two stacked DIB cells, one referred to as the hidden cell and the other as the output cell. The hidden cell has 10 identically structured modules, each composed of 2 fully-connected (FC) layers with 445 neurons each. The output cell is composed of 10 modules, each of which is an FC layer with dimensionality matching the task output and a Softmax nonlinearity. The router associated with each cell is an MLP with two hidden layers of 256 nodes each, and an output layer with dimensionality matching the number of modules (i.e. 10 nodes) and a Softmax nonlinearity. The MemNet in each cell has two hidden layers of 128 nodes each and an output layer identical to the router’s. We use ReLU activations for all hidden nodes. We use an ε-greedy policy for the DQN training, with ε initialized to a fixed starting value and updated at each training step according to a decay formula in the training step index.
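A minimal sketch of ε-greedy module selection for one router follows. The decay schedule is an assumption: a standard exponential decay with illustrative constants (`eps_start`, `eps_end`, `decay`), not the formula or values used in the experiments. The function names are hypothetical.

```python
import math
import random
from typing import Sequence


def epsilon_at(step: int, eps_start: float = 1.0, eps_end: float = 0.05,
               decay: float = 1000.0) -> float:
    """ASSUMED exponential decay from eps_start toward eps_end."""
    return eps_end + (eps_start - eps_end) * math.exp(-step / decay)


def select_module(q_values: Sequence[float], eps: float, rng: random.Random) -> int:
    """Explore with probability eps (uniform module), else pick the argmax Q."""
    if rng.random() < eps:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda i: q_values[i])


rng = random.Random(0)
module = select_module([0.1, 0.7, 0.3], eps=epsilon_at(step=5000), rng=rng)
```

Early in training almost every routing decision is random, which lets each router sample all modules before the reward starts to discriminate between congested and uncongested paths.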

The MLP model we use for comparison under the different conditions is designed to have the same depth and almost the same total number of parameters as the DIB model. It is composed of 2 hidden layers, each with 2000 nodes with ReLU activations, and an output layer with a Softmax nonlinearity. Since the DIB model depends on a task-specific MemNet, to make a fair comparison with the MLP we add a comparison condition that involves a multi-head MLP (MHMLP). MHMLP has the same architecture as the MLP, except that it has a task-specific output layer. On each new task, we initialize a new output layer, which is trained together with the rest of the network. The output layer is stored after the given task’s training and reloaded at inference time according to the task identity.
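As a quick check that the two models have almost the same number of parameters, the following sketch counts weights and biases. It assumes 784 inputs (flattened 28x28 MNIST images), 10 outputs per task as in PermutedMNIST, and output-cell modules fed by the 445-unit hidden modules; it excludes the routers and MemNets, since the text does not state whether they are counted.

```python
def fc_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully-connected layer."""
    return n_in * n_out + n_out


def dib_module_params(n_in=784, width=445, n_out=10, n_modules=10) -> int:
    # Hidden cell: n_modules modules, each with two FC layers of `width` units.
    hidden = n_modules * (fc_params(n_in, width) + fc_params(width, width))
    # Output cell: n_modules single FC layers mapping `width` -> task outputs.
    output = n_modules * fc_params(width, n_out)
    return hidden + output


def mlp_params(n_in=784, width=2000, n_out=10) -> int:
    # Two hidden layers of `width` units plus one output layer.
    return fc_params(n_in, width) + fc_params(width, width) + fc_params(width, n_out)


print(dib_module_params(), mlp_params())  # 5522550 5592010
```

Under these assumptions the module parameters of the DIB model and the MLP differ by about 1%.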

We use the mean test error after the final task as a measure of model performance: after training on the final task, we evaluate the model on all of the tasks and take the mean. In one of the comparison conditions, we train a vanilla MLP on the whole set of tasks simultaneously. Since this model is trained on all tasks jointly, its error serves as an estimate of the lowest error attainable on the given tasks.
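The reported metric is the mean of the last row of a task-by-task error matrix. A small sketch, where the matrix layout (row i: after training on task i; column j: evaluated on task j) and the numbers in the usage example are hypothetical:

```python
from typing import Sequence


def final_mean_error(errors: Sequence[Sequence[float]]) -> float:
    """errors[i][j]: test error (%) on task j after training on task i.
    The metric is the mean over all tasks of the row after the final task."""
    final_row = errors[-1]
    return sum(final_row) / len(final_row)


# Hypothetical 3-task run; only the last row enters the metric.
errors = [
    [1.0, 90.0, 90.0],
    [3.0, 2.0, 90.0],
    [6.0, 4.0, 2.0],
]
print(final_mean_error(errors))  # 4.0
```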

We compare the following conditions for all of the three datasets:

  • MLP: a vanilla MLP with the above mentioned architecture.

  • MLP+EWC: an MLP with EWC applied.

  • MHMLP: a vanilla MHMLP with the above mentioned architecture.

  • MHMLP+EWC: an MHMLP with EWC applied.

  • DIB: the DIB model described above.

  • DIB+EWC: the DIB model with EWC applied to the modules (i.e. not applied to the router or MemNet).

  • lower-bound: an estimation of the test error lower bound, as mentioned above.

EWC depends on a weight hyperparameter that defines the strength of regularization. DIB also depends on a similar hyperparameter that scales the reward signal used for the DQN RL (eq. 5). Hence, any condition that involves EWC or DIB is run with 5 different hyperparameter values, and when the condition contains both DIB and EWC, the same value is used for both. All test results are averages over 3 trials, and for any condition involving DIB or EWC, the best test performance across all of the hyperparameter values is reported. Each task is trained for 20 epochs, and the lower-bound condition is trained for 200 epochs. The Adam optimizer is used for all experiments with default hyperparameter values and a training batch size of 128.
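For reference, the standard EWC penalty (Kirkpatrick et al.) adds a Fisher-weighted quadratic term that pulls each parameter toward its value after the previous task. The sketch below implements that standard form with NumPy; for DIB+EWC, only the module parameters would be passed in, leaving the router and MemNet unregularized. The parameter names and the `lam` argument are illustrative.

```python
from typing import Dict

import numpy as np


def ewc_penalty(params: Dict[str, np.ndarray],
                anchor: Dict[str, np.ndarray],
                fisher_diag: Dict[str, np.ndarray],
                lam: float) -> float:
    """(lam / 2) * sum_i F_i * (theta_i - theta*_i)^2 over the given params.

    anchor holds the parameter values after the previous task and
    fisher_diag the diagonal Fisher estimate computed at those values."""
    total = 0.0
    for name, theta in params.items():
        diff = theta - anchor[name]
        total += float(np.sum(fisher_diag[name] * diff ** 2))
    return 0.5 * lam * total


# For DIB+EWC only module weights go in; router/MemNet weights are left out.
modules = {"hidden/w": np.array([1.0, 2.0])}
anchor = {"hidden/w": np.array([0.0, 0.0])}
fisher = {"hidden/w": np.array([1.0, 1.0])}
print(ewc_penalty(modules, anchor, fisher, lam=2.0))  # 5.0
```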

One additional comparison condition, called random information routing (RIR), is run on the SplitMNIST dataset to shed more light on the dynamics of DIB. In RIR, we replace the RL router with uniform random routing of the inputs, i.e. each input is assigned to a module sampled uniformly at random from the available modules.
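RIR replaces the learned router with uniform random assignment. A minimal sketch, with `random_routing` as a hypothetical helper name:

```python
import numpy as np


def random_routing(batch_size: int, n_modules: int,
                   rng: np.random.Generator) -> np.ndarray:
    """RIR: each input goes to a module drawn uniformly at random,
    independently of the input and of any information measure."""
    return rng.integers(0, n_modules, size=batch_size)


rng = np.random.default_rng(0)
assignments = random_routing(batch_size=128, n_modules=10, rng=rng)
```

Because the assignment ignores the information load, RIR spreads inputs evenly across modules on average but cannot steer patterns away from congested paths, which isolates the contribution of the Fisher-based reward.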

To assess the information content of a specific module, we use conditional entropy (cond-entropy) as a measure. We calculate the cond-entropy of the final model over all samples, accumulating it separately according to which module was activated for each pattern, and then compute the module-mean cond-entropy as the average cond-entropy of that module over the samples for which it was activated.
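The module-mean cond-entropy can be sketched as a grouped average. This sketch assumes that the per-sample cond-entropy is the entropy of the model’s predicted class distribution for that sample, H(Y | x) = -Σ_y p(y | x) log p(y | x); the text does not spell out the exact estimator, and the function name is hypothetical.

```python
import numpy as np


def module_mean_entropy(probs: np.ndarray, module_ids: np.ndarray,
                        n_modules: int) -> np.ndarray:
    """probs: (N, C) predicted class probabilities; module_ids: (N,) module
    activated for each sample. Returns the mean per-sample entropy for each
    module (NaN for modules that were never activated)."""
    eps = 1e-12  # avoids log(0) for zero-probability classes
    ent = -np.sum(probs * np.log(probs + eps), axis=1)
    sums = np.bincount(module_ids, weights=ent, minlength=n_modules)
    counts = np.bincount(module_ids, minlength=n_modules)
    means = np.full(n_modules, np.nan)
    active = counts > 0
    means[active] = sums[active] / counts[active]
    return means


probs = np.array([[0.5, 0.5], [1.0, 0.0], [0.5, 0.5]])
ids = np.array([0, 1, 0])
m = module_mean_entropy(probs, ids, n_modules=3)  # approx [log 2, 0, nan]
```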

Dataset          Model      Test error (%)
PermutedMNIST    DIB+EWC    2.32
SplitMNIST       DIB+EWC    4.32
iWildCam2019     DIB+EWC    22.74
Table 1: Mean test errors across all tasks.
Figure 3: Test error on previous tasks after training on each task. Panels: (a) PermutedMNIST (all); (b) SplitMNIST; (c) PermutedMNIST (zoomed on the best performing); (d) iWildCam2019.
Figure 4: Mean conditional entropy per path.

4.1 Datasets

4.1.1 PermutedMNIST

PermutedMNIST is a classic benchmark for continual learning assessment. The MNIST dataset is used to generate k tasks by permuting the input pixels with a fixed permutation that is applied to all of the inputs in a given task and differs from task to task. The output labels are kept the same.

We generate 10 tasks, where the first task is just the original dataset, while the remaining 9 tasks are shuffled randomly. The random seed used for shuffling is kept the same for all experiments, to reduce the bias that may be introduced due to the shuffling order. From the training set, we use 90% for training and 10% for validation. The test set is used as it is.
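The task generation and split can be sketched as follows, assuming flattened 784-pixel images and an arbitrary fixed seed (the actual seed is not given). The function names are hypothetical.

```python
import numpy as np


def make_permuted_tasks(n_tasks: int = 10, n_pixels: int = 784, seed: int = 0):
    """Task 0 keeps the original pixel order; tasks 1..n_tasks-1 each get one
    random permutation, drawn from a fixed seed so every experiment sees the
    same tasks."""
    rng = np.random.default_rng(seed)
    perms = [np.arange(n_pixels)]
    perms += [rng.permutation(n_pixels) for _ in range(n_tasks - 1)]
    return perms


def apply_permutation(images: np.ndarray, perm: np.ndarray) -> np.ndarray:
    # images: (N, n_pixels); the same permutation is used for every image of a task.
    return images[:, perm]


def train_val_split(n: int, val_frac: float = 0.1, seed: int = 0):
    """Shuffle indices once and hold out `val_frac` of them for validation."""
    idx = np.random.default_rng(seed).permutation(n)
    n_val = int(round(n * val_frac))
    return idx[n_val:], idx[:n_val]


perms = make_permuted_tasks()
train_idx, val_idx = train_val_split(60000)  # 54000 / 6000 indices
```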

4.1.2 SplitMNIST

SplitMNIST is also based on the MNIST dataset; however, the tasks are generated by splitting the dataset into disjoint pairs of consecutive digits (0/1, 2/3, 4/5, 6/7 and 8/9), which yields 5 tasks. We again use 90% of each task’s training data for training and 10% for validation, and we use the correspondingly partitioned test set as is.
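A sketch of the SplitMNIST partition into five disjoint pairs of consecutive digits; the helper name is hypothetical, and any relabeling of the targets within a task is left out.

```python
import numpy as np

DIGIT_PAIRS = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]


def split_by_pairs(labels, pairs=DIGIT_PAIRS):
    """Return, for each task, the indices of the samples whose label is in
    that task's digit pair. The pairs are disjoint, so no sample is shared."""
    labels = np.asarray(labels)
    return [np.flatnonzero(np.isin(labels, pair)) for pair in pairs]


labels = np.repeat(np.arange(10), 2)  # toy labels: 0, 0, 1, 1, ..., 9, 9
tasks = split_by_pairs(labels)
```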

4.1.3 iWildCam2019

We generate tasks from the iWildCam2019 training dataset in a way similar to SplitMNIST. Before splitting, we selected 10 of the available classes to generate 5 tasks. We preprocessed the images by gray-scaling them and resizing them to 88x64. We then paired the classes two at a time and balanced each pair by discarding the excess samples of the larger class. Because the iWildCam2019 training and test sets do not overlap completely, we built our own test set by dividing each task into 70% training, 20% validation and 10% test. The paired classes are: .
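The balancing and 70/20/10 split can be sketched as below. It assumes that the excess samples of the larger class are discarded at random and that the splits are drawn by shuffling, neither of which is specified in the text; the image preprocessing (gray-scaling, resizing) is omitted, and the function names are hypothetical.

```python
import numpy as np


def balance_pair(idx_a: np.ndarray, idx_b: np.ndarray, rng: np.random.Generator):
    """Randomly downsample the larger class so both classes have equal counts."""
    n = min(len(idx_a), len(idx_b))
    a = rng.choice(idx_a, size=n, replace=False)
    b = rng.choice(idx_b, size=n, replace=False)
    return a, b


def split_70_20_10(indices: np.ndarray, rng: np.random.Generator):
    """Shuffle, then cut into 70% train, 20% validation, 10% test
    (integer arithmetic avoids floating-point rounding at the cut points)."""
    idx = rng.permutation(indices)
    n = len(idx)
    n_train = (7 * n) // 10
    n_val = (2 * n) // 10
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]


rng = np.random.default_rng(0)
a, b = balance_pair(np.arange(100), np.arange(100, 130), rng)  # 30 and 30
train, val, test = split_70_20_10(np.concatenate([a, b]), rng)  # 42 / 12 / 6
```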

4.2 Results

We benchmark the performance of three main models, namely MLP, MHMLP and DIB, with and without EWC regularization. Applying EWC regularization improves every model’s performance relative to its nonregularized counterpart across all of the datasets (fig. 3). DIB+EWC has the best performance across all of the datasets (table 1). In PermutedMNIST (fig. 3(a)), EWC improves the different models’ performances by a much larger margin than in the other two datasets. The EWC conditions have the same ordering across all of the datasets, from lower to higher performance: MLP+EWC, MHMLP+EWC, then DIB+EWC. The non-EWC conditions do not seem to have a consistent order; however, MLP is the worst-performing condition across all of the datasets (table 1). In the SplitMNIST dataset (fig. 3(b)), the RIR+EWC condition performs worse than both DIB and DIB+EWC. The conditional entropy per path (fig. 4) is higher in RIR+EWC than in DIB+EWC.

5 Discussion

The main idea behind DIB is to minimize information interference by routing different patterns through RL rewarded by the joint Fisher information. Since the router is guided by an information measure from the current task, the balancing mainly affects intra-task interference. This explains the performance gain from EWC regularization, which, despite being a more expensive operation, accounts for inter-task interference. It may also explain the relative performance of DIB+EWC compared to the other EWC conditions: since the other EWC-enhanced techniques lack a mechanism for minimizing intra-task interference, they can only address inter-task interference.

The RIR+EWC condition confirms the effectiveness of RL guided by an information measure. Besides performing worse than DIB and DIB+EWC, RIR+EWC has a higher average entropy per path than DIB+EWC. Since EWC is applied to both conditions, i.e. RIR+EWC and DIB+EWC, and since EWC regularizes inter-task interference, intra-task interference very likely contributes significantly to the degraded performance of RIR+EWC.

Further evidence for the contribution of inter-task interference, as compared to intra-task interference, is the large performance gain in PermutedMNIST relative to the other two datasets. The PermutedMNIST tasks have different input distributions but share the output distribution, whereas the SplitMNIST and iWildCam2019 tasks share neither the input nor the output distribution. The shared output distribution across PermutedMNIST tasks means that there is some overlap between tasks, which provides fertile ground for inter-task techniques like EWC to significantly reduce inter-task interference.

One important detail to explain is the choice of the DIB architecture. It may seem more reasonable to diversify the modules in a DIB cell, instead of choosing homogeneous modules, in order to allow for more differential learning that the routing algorithm can exploit. However, implementing heterogeneous modules is not straightforward, since they cannot be readily reduced to a single matrix-matrix operation that benefits from GPU acceleration on a minibatch. On the other hand, assuming homogeneity, routing different inputs to different paths can be reduced to a pooled matrix-matrix operation that can be GPU-accelerated using any deep learning library. Although we applied the homogeneity assumption only to fully-connected layers, extending it to other layers, such as convolutional layers, is straightforward.
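The homogeneity argument can be illustrated with NumPy: stacking the modules’ weights into one tensor lets a minibatch whose inputs are routed to different modules be processed with a single gather followed by a batched matrix product. This is one possible realization, not necessarily the authors’ implementation, and the names are hypothetical.

```python
import numpy as np


def routed_fc(x: np.ndarray, weights: np.ndarray, biases: np.ndarray,
              module_ids: np.ndarray) -> np.ndarray:
    """Apply a per-input module of a homogeneous FC layer in one call.

    x:          (N, d_in) minibatch inputs
    weights:    (M, d_in, d_out) stacked weights of M identically shaped modules
    biases:     (M, d_out) stacked biases
    module_ids: (N,) module chosen by the router for each input
    """
    w = weights[module_ids]  # gather: (N, d_in, d_out)
    b = biases[module_ids]   # gather: (N, d_out)
    return np.einsum("ni,nio->no", x, w) + b


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
weights = rng.normal(size=(10, 3, 5))
biases = rng.normal(size=(10, 5))
ids = np.array([2, 7, 2, 0])
y = routed_fc(x, weights, biases, ids)  # shape (4, 5)
```

Gathering one weight matrix per input trades memory for a single vectorized call; an alternative with the same result is to group the inputs by module and run one matrix product per group.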

6 Conclusion

We have introduced dynamic information balancing (DIB), a method for reducing catastrophic forgetting by dynamically balancing information content across different modules in a modular neural network through routing inputs based on an information-theoretic measure. DIB, combined with EWC, achieved better performance than models with similar capacity combined with EWC across different lifelong learning datasets and tasks. We used a computationally cheap approximation of the joint empirical Fisher information as a proxy for information load, which allowed for efficient continual updating of the reward needed for guiding the routing by reinforcement learning. MemNet was introduced as a task-specific component that learns to shadow the router’s decisions and takes over its routing role at inference time.

We believe there are several potential directions for improving the DIB methodology in the future. The homogeneity assumption, although it allowed for an efficient implementation, may limit the router’s capability to exploit diversity in the available modules. Finding an efficient, generic methodology for routing to modules with heterogeneous, arbitrary architectures may open the door to many potential enhancements. While relying on task-specific information, represented by MemNet in our DIB model, is a common practice in lifelong learning systems, finding more task-agnostic ways of reducing catastrophic forgetting is necessary for generalizing and extending lifelong learning. We consider extending the information balancing algorithm across task boundaries a natural generalization for enhancing DIB and reducing the cross-task component of information interference.

7 Acknowledgement

This work was partially supported by a grant from Microsoft’s AI for Earth program.