Efficient and robust multi-task learning in the brain with modular task primitives

05/28/2021 ∙ by Christian David Marton, et al. ∙ Université de Montréal, Icahn School of Medicine at Mount Sinai

In a real-world setting biological agents do not have infinite resources to learn new things. It is thus useful to recycle previously acquired knowledge in a way that allows for faster, less resource-intensive acquisition of multiple new skills. Neural networks in the brain are likely not entirely re-trained with new tasks, but how they leverage existing computations to learn new tasks is not well understood. In this work, we study this question in artificial neural networks trained on commonly used neuroscience paradigms. Building on recent work from the multi-task learning literature, we propose two ingredients: (1) network modularity, and (2) learning task primitives. Together, these ingredients form inductive biases we call structural and functional, respectively. Using a corpus of nine different tasks, we show that a modular network endowed with task primitives allows for learning multiple tasks well while keeping parameter counts, and updates, low. We also show that the skills acquired with our approach are more robust to a broad range of perturbations compared to those acquired with other multi-task learning strategies. This work offers a new perspective on achieving efficient multi-task learning in the brain, and makes predictions for novel neuroscience experiments in which targeted perturbations are employed to explore solution spaces.


1 Introduction

Task knowledge can be stored in the collective dynamics of large populations of neurons in the brain marton_learning_2020, richards_deep_2019, yang_task_2018, remington_dynamical_2018, kell_task-optimized_2018, zeng_continuous_2018, chaisangmongkon_computing_2017, rajan_recurrent_2016, sussillo_neural_2015, mante_context-dependent_2013, sussillo_opening_2013, barak_fixed_2013, buonomano_state-dependent_2009, sussillo_generating_2009. This body of work has offered insight into the computational mechanisms underlying the execution of multiple different tasks across various domains. In these studies, however, animals and models are over-trained on specific tasks, at the expense of long training duration and high energy expenditure. Models are optimized for one specific task and, unlike networks of neurons in the brain, are not meant to be re-used across tasks.

In addition to being resource-efficient, it is useful that acquired task representations be resistant to failures and perturbations. Biological organisms are remarkably resistant to lesions affecting large extents of the neural circuit basile_preserved_2020, curtis_effects_2004, and to neural cell death with aging fu_aging-dependent_2015, westlye_life-span_2010. Current approaches do not take this into account; often, representations acquired naively are highly susceptible to perturbations vincent-lamarre_driving_2016.

In the multi-task setting, where multiple tasks are learnt simultaneously yang_task_2018 or sequentially duncker_organizing_2020, zenke_continual_2017, kirkpatrick_overcoming_2017, the number of weight updates grows quickly with the number of tasks, because the entire model is essentially re-trained with each new task. Training time and energy cost therefore increase rapidly and can become prohibitive.

To overcome this, it is conceivable that the brain has adopted particular inductive biases to speed up learning over an evolutionary timescale richards_deep_2019. Inductive biases can help guide and accelerate future learning by providing useful prior knowledge. This becomes particularly useful when resources (energy, time, neurons) to acquire new skills are limited, and the space of tasks that need to be acquired is not infinite. In fact, biological agents often inhabit particular ecological niches carscadden_niche_2020, which impose specific constraints on the kinds of tasks that are useful and need to be acquired over a lifetime. Under these circumstances, useful priors or inductive biases can help reduce the number of weight updates required.

One way in which such biases may be expressed in the brain is in terms of neural architecture or structure richards_deep_2019. Neural networks in the brain may come with a pre-configured wiring diagram, encoded in the genome zador_critique_2019. Such structural pre-configurations can put useful constraints on the learning problem. A prominent example of this in the brain is modularity. The brain exhibits specializations over various scales sporns_modular_2016, wang_brain_2016, goulas_strength_2015, murray_hierarchy_2014, meunier_modular_2010, with visual information being primarily represented in the visual cortices, auditory information in the auditory cortices, and decision-making tasks encoded in circuits in the prefrontal cortex marton_learning_2020, chaisangmongkon_computing_2017, mante_context-dependent_2013. Further subdivisions within those specialized modules may arise in the form of particular cell clusters, or groupings in a graph-theoretic sense, processing particular kinds of information (such as cells mapping particular visual features or aspects of a task). From a multi-task perspective, modularity can offer an efficient solution to the problem of task interference and catastrophic forgetting in continual learning hadsell_embracing_2020, parisi_continual_2019, ruder_overview_2017, while increasing robustness to perturbations.

Another way in which inductive biases may be expressed in the brain, beyond structure, is in terms of function: the types of neural dynamics instantiated in particular groups of neurons, or neural ensembles. Rather than learn separate models for separate tasks, as many of the studies cited above do, or re-train the same circuit many times on different tasks yang_task_2018, duncker_organizing_2020, core computational principles could be learned once and re-used across tasks. There is evidence that the brain implements stereotyped dynamics which can be used across tasks domenech_executive_2015, sakai_task_2008. Ensembles of neurons may be divided into distinct groups based on the nature of the computations they perform, forming functional modules perich_rethinking_2020, sporns_modular_2016, billeh_revealing_2014, power_functional_2011, meunier_modular_2010. In the motor domain, for instance, neural activity was found to exhibit shared structure across tasks gallego_cortical_2018. Different neural ensembles can be found to implement particular aspects of a task koay_sequential_2019, yang_task_2018, chaisangmongkon_computing_2017 and theoretical work suggests the number of different neural sub-populations influences the set of dynamics a neural network is able to perform beiran_shaping_2020. Thus, a network that implements a form of task primitives, or basis-computations, could offer dynamics useful for training multiple tasks.

Putting these different strands together, in this work we investigate the advantage of combining structural and functional inductive biases in a recurrent neural network (RNN) model of modular task primitives. Structural specialization in the form of modules can help avoid task interference and increase robustness to failure of particular neurons, while core functional specialization in the form of task primitives can help substantially lower the costs of training. We show how our model can be used to facilitate resource-efficient learning in a multi-task setting. We also compare our approach to other training paradigms, and demonstrate higher robustness to perturbations.

Altogether, our work shows that combining structural and functional inductive biases in a recurrent network allows an agent to learn many tasks well, at lower cost and higher robustness compared to other approaches. This draws a path towards cheaper multi-task solutions, and offers new hypotheses for how multi-task learning may be achieved in the brain.

2 Background

The brain is highly recurrently connected kveraga_top-down_2007, and recurrent neural networks have been successfully used to model dynamics in the brain marton_learning_2020, richards_deep_2019, yang_task_2018, remington_dynamical_2018, kell_task-optimized_2018, zeng_continuous_2018, chaisangmongkon_computing_2017, rajan_recurrent_2016, sussillo_neural_2015, mante_context-dependent_2013, sussillo_opening_2013, barak_fixed_2013, buonomano_state-dependent_2009, sussillo_generating_2009. Reservoir computing approaches, such as those proposed in the context of echo-state networks (ESNs) and liquid state machines (LSMs) pathak_model-free_2018, jaeger_harnessing_2004, maass_real-time_2002, avoid learning the recurrent weights of a network (training only the output weights) and thereby achieve reduced training time and lower computational cost compared to RNNs whose weights are all trained in an end-to-end fashion. This can be advantageous: when full state dynamics are available for training, reservoir networks exhibit higher predictive performance and lower generalisation error at significantly lower training cost, compared to recurrent networks fully trained with gradient descent using backpropagation through time (BPTT); in the face of partial observability, however, RNNs trained with BPTT perform better vlachas_backpropagation_2020. Also, naive implementations of reservoir networks show high susceptibility to local perturbations, even when implemented in a modular way vincent-lamarre_driving_2016.

Fig 1: Schematic of fully disconnected recurrent network with modular task primitives. Left: inputs for modules M1-M3. Middle: recurrent weight matrix of the modular network, with separate partitions for modules M1-M3 along the diagonal. Right: different tasks (T1-T9) trained into different output weights, leveraging the same modular network.

Meanwhile, previous work in continual learning has largely focused on feedforward architectures hadsell_embracing_2020, yu_gradient_2020, adel_continual_2020, parisi_continual_2019, golkar_continual_2019, yoon_lifelong_2018, serra_overcoming_2018, schwarz_progress_2018, zeng_continuous_2018, kirkpatrick_overcoming_2017, zenke_continual_2017, lopez-paz_gradient_2017, draelos_neurogenesis_2017, rusu_progressive_2016. The results obtained in feedforward systems, however, are not readily extensible to continual learning in dynamical systems, where recurrent interactions may lead to unexpected behavior in the evolution of neural responses over time sodhani_toward_2020, duncker_organizing_2020. Previously, modular approaches have also largely been explored in feedforward systems wilson_knowledge_2015, ellefsen_neural_2015, happel_design_1994.

Recently, a method for continual learning has been proposed for dynamical systems duncker_organizing_2020. The authors have shown that their method, dynamical non-interference, outperforms other gradient-based approaches zenke_continual_2017, kirkpatrick_overcoming_2017 in key applications. All of these approaches, however, require the entire network to be re-trained with each new task and thus come with a high training cost. A network with functionally specialized modules could provide dynamics that are useful across tasks and thus help mitigate this problem.

In the use of modules, our approach is similar to memory-based approaches to continual learning, which have also been applied in recurrently connected architectures graves_neural_2014, cossu_continual_2020. In these approaches, context-triggered information retrieval is central to computations: information can be stored in memory and re-used when it becomes relevant to new tasks. These approaches, however, rely on adding new modules or memory components as learning proceeds, which can grow without bound hadsell_embracing_2020.

Ongoing work in transfer, few-shot and one-shot learning has focused on leveraging previously trained networks to reduce training time on new tasks. This work, largely carried out in the computer vision domain, has focused on feedforward systems rezaei_zero-shot_2020. Recently, a new framework for achieving good time-series forecasts with limited data has been proposed orozco_zero-shot_2020, but it has not been applied in a multi-task setting. Initial training in this framework is also expensive, as it requires training on a multitude of different time-series datasets.

Building on this wide-ranging body of work, we make the following key contribution. Instead of pre-training an entire network on a full set of tasks, as in transfer learning, we hypothesize that pre-training functional modules on task primitives offers benefits for multi-task learning while keeping training costs low. This procedure strikes a middle ground between end-to-end trained RNNs and the reservoir computing paradigm, replacing randomly connected 'reservoir' networks with task-primitive modules. Modules are trained once at the beginning, and can be re-used to perform many different tasks. We also hypothesize that functional modularity guards against failure, as functionality in other modules is preserved even if dynamics in one of the modules are perturbed.

Fig 2: Task training overview. Left: the pre-configured network receives task-specific inputs. A multitude of different tasks (T1-T9) can be trained leveraging the dynamics in the pre-configured network. Separate readout weights are trained for tasks T1-T9 while recurrent weights remain frozen. Right: task targets (strong colors) are plotted together with trained outputs (pale colors).

3 Model details

3.1 Modularity

Our network is based on a simple RNN architecture with hidden state $\mathbf{h}_t$ evolving according to

$$\mathbf{h}_{t} = \phi\left(W^{\mathrm{rec}}\,\mathbf{h}_{t-1} + W^{\mathrm{in}}\,\mathbf{u}_{t} + \boldsymbol{\epsilon}_{t}\right) \qquad (1)$$

with activation function $\phi = \tanh$, external input $\mathbf{u}_t$, uncorrelated Gaussian noise $\boldsymbol{\epsilon}_t$, and linear readouts $\mathbf{o}_t = W^{\mathrm{out}}\,\mathbf{h}_t$. Starting from an initial state, with initial weights drawn from a Gaussian distribution, each module is trained to produce a module-specific target output time series by minimizing the squared loss using the Adam optimizer (see Appendix A for more details). Three modules (M1-M3) are embedded within the larger network. We consider two conditions: (1) modular disconnected, in which the modules are entirely disconnected (Figure 1), and (2) modular connected, in which the connection weights between modules are also trained (Figure S1). Inputs are fed in through shared weights. During training, modules M1-M3 receive one of three inputs each (Fixation, Input1 and Input2, respectively). Each module is trained separately: while a particular module is being trained, only the weights within that module are allowed to change, and all others remain frozen. The gradient of the error, however, is computed across the whole network.
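The modular structure and the module-restricted training step can be sketched in NumPy. This is a minimal illustration, not the authors' JAX implementation: it assumes the discrete-time update h_t = tanh(W h_{t−1} + W_in u_t + ε_t) for Equation 1 and three equal modules of 100 units (300 units in total, the network size used elsewhere in the text), and it uses a random array as a stand-in for the gradient. All helper names are hypothetical.

```python
import numpy as np

N_UNITS = 300                  # total network size, split evenly (assumption)
N_MODULES = 3                  # modules M1-M3
MOD_SIZE = N_UNITS // N_MODULES

def module_slice(m):
    """Index range of module m (0-based) along the diagonal of W."""
    return slice(m * MOD_SIZE, (m + 1) * MOD_SIZE)

def module_mask(modules):
    """1 for recurrent weights inside the listed modules' diagonal blocks, 0 elsewhere."""
    mask = np.zeros((N_UNITS, N_UNITS))
    for m in modules:
        s = module_slice(m)
        mask[s, s] = 1.0
    return mask

def rnn_step(W, W_in, h, u, noise_std, rng):
    """One step of the assumed discrete-time form of Eq. 1 with additive Gaussian noise."""
    noise = noise_std * rng.standard_normal(h.shape)
    return np.tanh(W @ h + W_in @ u + noise)

def masked_update(W, grad_W, m, lr=1e-3):
    """Gradient step that only changes the weights within module m.
    grad_W may be computed through the whole network; the mask freezes the rest.
    (Plain SGD for brevity; the text uses Adam.)"""
    return W - lr * grad_W * module_mask([m])

rng = np.random.default_rng(0)
# Modular disconnected condition: off-diagonal blocks are zero.
W = rng.normal(0.0, 1.0 / np.sqrt(MOD_SIZE), (N_UNITS, N_UNITS)) * module_mask(range(N_MODULES))
W_in = rng.normal(0.0, 1.0, (N_UNITS, 3))     # shared input weights: Fixation, Input1, Input2

h = np.zeros(N_UNITS)
h = rnn_step(W, W_in, h, np.array([1.0, 0.0, 0.0]), noise_std=0.1, rng=rng)

fake_grad = rng.standard_normal((N_UNITS, N_UNITS))  # stand-in for dLoss/dW
W_new = masked_update(W, fake_grad, m=1)             # train M2 only
changed = W_new != W
```

The mask is what keeps the other modules frozen: entries of W outside module m receive a zero update even though the gradient is dense. For the modular connected condition, one would drop the block-diagonal mask from W's initialization and also allow updates to the off-diagonal blocks.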

3.2 Task primitives

Neuroscience tasks yang_task_2018, like the ones we employ here (described in greater detail in Section 4), are structured as trials, analogous to separate training batches. In each trial, stimuli are randomly drawn, e.g. from a distribution of possible orientations. In a given trial, there is an initial period without a stimulus in which animals are required to maintain fixation. Then, in the stimulus period, a stimulus appears for a given duration and the animal is again required to maintain fixation without responding. Finally, in the response period, the animal is required to make a response. In models, a context signal indicates when to maintain fixation, separate inputs are fed in through different channels, and targets indicating the required choice are read out from the network.

Taking inspiration from this task structure, and from the core functionality shared by most neuroscience tasks maheswaranathan_universality_2019, gallego_cortical_2018, mante_context-dependent_2013, we derive simple task primitives that provide core functionality for solving multiple tasks. The goal is to explore how such primitives can be pre-learned and recombined by networks for multi-task learning. The chosen primitives are (i) a low-dimensional latent representation of inputs, and (ii) the memorization of transient input sequences. In our model, M1 is assigned to (i) and learns to autoencode the fixation signal. M2 and M3 are assigned to (ii) and are trained to hold Inputs 1 and 2 in memory, respectively. More specifically, M2 receives short positive input pulses and is required to hold the signal in memory for the remainder of the trial period, while M3 is required to do the same with negative input pulses. More details can be found in Appendix A.
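The three primitive targets described above can be sketched as a trial generator. This is an illustrative NumPy version under stated assumptions: the trial length, pulse width, onset range, response-period boundary, and pulse amplitude range (0.5 to 1.0) are all hypothetical, since the text does not give them here, and `primitive_trial` is a hypothetical name.

```python
import numpy as np

def primitive_trial(T=100, pulse_len=5, rng=None):
    """Hypothetical generator of inputs/targets for the three task primitives.
    Returns inputs (T, 3) = [fixation, input1, input2] and targets (T, 3) for M1-M3."""
    rng = np.random.default_rng() if rng is None else rng
    fixation = np.ones(T)
    fixation[int(0.8 * T):] = 0.0          # assumed response period: last 20% of the trial

    inputs = np.zeros((T, 3))
    targets = np.zeros((T, 3))
    inputs[:, 0] = fixation
    targets[:, 0] = fixation                # M1: autoencode the fixation signal

    # M2 holds positive pulses (input1), M3 holds negative pulses (input2).
    for ch, sign in ((1, +1.0), (2, -1.0)):
        onset = rng.integers(0, T // 2)
        amp = sign * rng.uniform(0.5, 1.0)  # assumed amplitude range
        inputs[onset:onset + pulse_len, ch] = amp
        targets[onset:, ch] = amp           # keep the pulse value until the end of the trial
    return inputs, targets

x, y = primitive_trial(rng=np.random.default_rng(1))
```

The memory targets stay at the pulse value after the input has returned to zero, which is exactly the "hold in memory" behaviour that M2 and M3 must learn.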

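After training, recurrent weights are frozen. Separate linear readouts can now be trained for different tasks (T1-T9) by minimizing the task-specific loss function and updating only the readout weights $W^{\mathrm{out}}_k$ of task $k$:

$$\mathcal{L}_k = \sum_{t}\left\lVert W^{\mathrm{out}}_k\,\mathbf{h}_t - \mathbf{y}^{(k)}_t\right\rVert^2 \qquad (2)$$

where $\mathbf{y}^{(k)}_t$ is the target output of task $k$ at time $t$.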
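With the recurrent weights frozen, the squared readout loss is linear in the readout weights, so for illustration it can be solved in closed form by least squares. The sketch below assumes the discrete-time form h_t = tanh(W h_{t−1} + W_in u_t) and uses `numpy.linalg.lstsq` in place of the Adam optimizer mentioned in the text; a gradient-based optimizer would approach the same minimizer when it is unique. The random network and the two-output target are hypothetical stand-ins for the trained modular network and a real task.

```python
import numpy as np

def collect_states(W, W_in, inputs, noise_std=0.0, rng=None):
    """Run the frozen network (assumed form of Eq. 1) and stack hidden states (T, N)."""
    rng = np.random.default_rng() if rng is None else rng
    h = np.zeros(W.shape[0])
    states = []
    for u in inputs:
        h = np.tanh(W @ h + W_in @ u + noise_std * rng.standard_normal(h.shape))
        states.append(h)
    return np.array(states)

def fit_readout(states, targets):
    """Minimize sum_t ||W_out h_t - y_t||^2 over W_out only (ordinary least squares)."""
    sol, *_ = np.linalg.lstsq(states, targets, rcond=None)   # shape (N, n_out)
    return sol.T                                              # W_out: (n_out, N)

def mse(W_out, states, targets):
    """Mean squared error of the linear readout."""
    return np.mean((states @ W_out.T - targets) ** 2)

rng = np.random.default_rng(0)
N, T = 50, 200
W = rng.normal(0.0, 1.0 / np.sqrt(N), (N, N))   # stand-in for the frozen modular network
W_in = rng.normal(0.0, 1.0, (N, 3))
u = rng.normal(0.0, 1.0, (T, 3))
y = np.stack([u[:, 0], np.roll(u[:, 1], 1)], axis=1)   # hypothetical 2-output task

H = collect_states(W, W_in, u, rng=rng)
W_out = fit_readout(H, y)                        # only these weights are "trained"
```

Because each task only adds its own `W_out`, training a new task never touches `W` or `W_in`, which is the property the parameter counts below rely on.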

4 Tasks

We show that we can leverage our pre-configured modular network to learn a set of nine different neuroscience tasks, previously used to study multi-task learning yang_task_2018, duncker_organizing_2020. Trained simultaneously into one and the same network, these tasks produce a rich and varied population structure and are mapped into different parts of state space yang_task_2018. Simultaneous training, however, does not guarantee good performance on all tasks; this can be improved by using a multi-task training paradigm whereby new tasks are projected into non-interfering subspaces duncker_organizing_2020. This approach still depends on task order, however: how well a new task is learnt often depends on its placement in the training curriculum.

We employ a set of nine different tasks: Delay Pro, Delay Anti, Mem Pro, Mem Anti, Mem Dm 1, Mem Dm 2, Context Mem Dm 1, Context Mem Dm 2 and Multi Mem (Figure 2). In Delay Pro and Delay Anti, the network receives three inputs (Fixation, cos(θ), sin(θ), with θ varying from 0 to 180 degrees). The angle inputs turn on at the Go-signal and remain on for the duration of the trial. The network is required to produce three outputs, yielding the correct location during the output period (when the fixation signal switches to zero). In the anti-versions, the network is required to respond in the direction opposite to the inputs. In Mem Pro and Mem Anti, the input stimulus turns off after a short, variable duration (Figure 2, inputs T3-T4). As before, the network is required to produce three outputs encoding the correct location during the output period.
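A single Delay/Mem Pro/Anti trial can be sketched as follows. This is a hypothetical NumPy generator: the epoch boundaries (in time steps), the choice of outputs as [fixation, cos, sin], and the encoding of the anti response as a 180-degree rotation are our assumptions, following the fixation, stimulus and response structure described in Section 3.2.

```python
import numpy as np

def delay_trial(theta_deg, anti=False, stim_dur=None, T=100, stim_on=20, go=70):
    """Hypothetical Delay/Mem Pro/Anti trial (all durations are assumptions).
    inputs:  (T, 3) = [fixation, cos(theta), sin(theta)]
    targets: (T, 3) = [fixation, cos(response), sin(response)]
    stim_dur=None keeps the stimulus on (Delay tasks); an integer switches it
    off after stim_dur steps (Mem tasks)."""
    theta = np.deg2rad(theta_deg)
    fixation = np.ones(T)
    fixation[go:] = 0.0                        # output period: fixation switches to zero

    stim_off = T if stim_dur is None else stim_on + stim_dur
    inputs = np.zeros((T, 3))
    inputs[:, 0] = fixation
    inputs[stim_on:stim_off, 1] = np.cos(theta)
    inputs[stim_on:stim_off, 2] = np.sin(theta)

    response = theta + np.pi if anti else theta   # anti: respond in the opposite direction
    targets = np.zeros((T, 3))
    targets[:, 0] = fixation                   # fixation output stays high until the go cue
    targets[go:, 1] = np.cos(response)
    targets[go:, 2] = np.sin(response)
    return inputs, targets

rng = np.random.default_rng(0)
x, y = delay_trial(rng.uniform(0, 180), anti=True)   # one Delay Anti trial
```

Encoding the angle as (cos θ, sin θ) makes the anti response a simple sign flip of both components, so Pro and Anti trials share identical inputs and differ only in their targets.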

In Mem Dm 1 and Mem Dm 2, the network receives two inputs (Fixation, and a set of stimulus pulses separated by 100 ms) (Figure 2, inputs T5-T6). The network is required to reproduce the higher-valued pulse during the output period. Depending on the task, the network receives one or the other input stimulus pulse set.

In Context Mem Dm 1 and Context Mem Dm 2, the network receives three inputs (Fixation, and two sets of stimulus pulses) (Figure 2, inputs T7-T9). The network is required to reproduce the higher-valued pulse during the output period, ignoring the other input stimulus set. The two tasks differ in which input stimulus set to attend to. In Multi Mem, the network again receives three inputs (Fixation, and both sets of stimulus pulses). The network needs to output the summed value of whichever stimulus train has the higher total magnitude across its two pulses.
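The decision rules of the pulse-based tasks can be written as small target functions. This plain-Python sketch is illustrative only: it assumes two pulses per channel, one-step pulses whose onsets are 100 ms apart on an assumed 20 ms time grid, and, for Multi Mem, that "total magnitude" means the sum of absolute pulse values. All function names are hypothetical.

```python
def pulse_train(values, T=40, onset=5, gap_ms=100, dt_ms=20):
    """Place one-step pulses with onsets gap_ms apart on a time grid of dt_ms (assumed)."""
    x = [0.0] * T
    step = gap_ms // dt_ms
    for i, v in enumerate(values):
        x[onset + i * step] = v
    return x

def mem_dm_target(pulses):
    """Mem Dm 1/2: the higher-valued pulse in the single presented channel."""
    return max(pulses)

def context_mem_dm_target(pulses_1, pulses_2, attend):
    """Context Mem Dm 1/2: like Mem Dm, but only the attended channel (1 or 2) counts."""
    return max(pulses_1 if attend == 1 else pulses_2)

def multi_mem_target(pulses_1, pulses_2):
    """Multi Mem: summed value of the pulse train with the larger total magnitude."""
    chosen = max(pulses_1, pulses_2, key=lambda p: sum(abs(v) for v in p))
    return sum(chosen)

a, b = [0.3, 0.8], [0.9, 0.1]
```

Context Mem Dm and Multi Mem see the same two channels; only the rule applied to them changes, which is why a shared set of memory primitives can serve all of them.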

Fig 3: Performance. Training loss across tasks, task performance after training, and exemplary targets and outputs (Delay Pro task) are shown for separate training (A-C), dynamical non-interference (E-G) and for our modular network approach (I-K). The modular approach achieves high performance on all tasks.

5 Results

All of the tasks are learned well by leveraging the dynamics in the network (Figure 2). Each task is trained into separate output weights, leaving the recurrent weights frozen. We also compared our approach to networks trained separately and to recent multi-task learning approaches (Figure 3). Networks trained separately reach perfect performance on each task, while networks trained with dynamical non-interference duncker_organizing_2020 reach performance on all tasks at the end of training. Our approach (Figure 3I-K) also reaches high performance on all tasks, more akin to separate training (Figure 3B) than to multi-task training (Figure 3F).

Importantly, in comparison to all other approaches considered, our modular network achieves high performance while using far fewer parameter updates (Table 1). Naive, separate network training requires a separate network for each task. Thus, assuming a network size of 300 units and two readout units, it requires 8.15e5 parameters in total (Table 1), and each network must be trained separately on its own task. In contrast, using dynamical non-interference (Figure 4B, duncker_organizing_2020), only one network is needed to perform all tasks; however, the whole network still needs to be re-trained on each task. It therefore requires 8.15e5 parameters to be trained in total across the nine tasks, just like the naive approach, despite using fewer parameters overall. Our approach aims to combine the best of both methods: (i) keeping the total number of parameters low, and (ii) re-training the same set of parameters as little as possible.

Indeed, like dynamical non-interference, our approach uses a single network to perform all tasks, so its overall parameter count is similar. However, the number of trained parameters in our approach is about an order of magnitude lower than with the other approaches (Table 1). This is because the network with modular task primitives serves as a reservoir of useful dynamics.

| Count | Separate training | Dynamical non-interference | Modular task primitives |
|---|---|---|---|
| Overall count | 8.15e5 | 9.06e4 | 9.54e4 |
| Number trained | 8.15e5 | 8.15e5 | 9.54e4 |
| Per task | 9.06e4 | 9.06e4 | 6e2 |

Table 1: Parameter count
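The entries of Table 1 are consistent with a simple count if one assumes N = 300 units, two readout units and nine tasks, and counts only recurrent (N²) and readout (2N) weights. The sketch below reproduces them under those assumptions; leaving out input weights and biases is our inference from the numbers, not something the text states, and `param_counts` is a hypothetical helper.

```python
def param_counts(n_tasks, n_units=300, n_out=2):
    """Table 1-style counts using recurrent (N^2) and readout (N * n_out) weights only."""
    per_net = n_units ** 2 + n_units * n_out   # one fully trained network
    readout = n_units * n_out                  # one task-specific readout
    return {
        # one network per task, each trained from scratch
        "separate": {"overall": n_tasks * per_net,
                     "trained": n_tasks * per_net,
                     "per_task": per_net},
        # one network, fully re-trained for every task
        "dni": {"overall": per_net,
                "trained": n_tasks * per_net,
                "per_task": per_net},
        # recurrent weights trained once, then one readout per task
        "modular": {"overall": n_units ** 2 + n_tasks * readout,
                    "trained": n_units ** 2 + n_tasks * readout,
                    "per_task": readout},
    }

counts = param_counts(9)
```

Under these assumptions each additional task costs 600 trained parameters in the modular approach versus 90,600 for the other two, which is the slower growth with the number of tasks shown in Figure 4A.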

The recurrent weights of the modular network only need to be trained once upfront; subsequently, each task can be learnt by training only its own readout weights. As a result, the number of trained parameters in our approach grows much more slowly with the number of tasks (Figure 4A) than in other approaches, which require the entire recurrent weight matrix to be re-trained on each task. The overall parameter count also grows only slowly with the number of tasks and stays low using our approach (Figure 4B).

Fig 4: Parameter count comparison. Total number of parameters trained (A), and total number of parameters needed overall (B) for networks trained separately, networks trained with dynamical non-interference, and those trained with our approach, modular task primitives.

5.1 Perturbations

We also compare the performance of the three approaches under different types of perturbations. We consider four types: lesioned weights, silenced activity, global noise and input noise (see Appendix A). Weights are lesioned by randomly picking an entry in the connectivity matrix (Equation 1) and setting it to zero. Activity is silenced by clamping the activity of a particular unit in the network at zero. In the case of global noise, Gaussian noise is added to the entire connectivity matrix, while in the case of input noise, uncorrelated Gaussian noise is added to all the input units.
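The four perturbation types can be sketched as simple array operations. This is a minimal NumPy illustration, not the authors' code: the function names are hypothetical, fractions are applied by sampling without replacement, the network update assumes the form h_t = tanh(W h_{t−1} + W_in u_t), and the optional `units` argument is our stand-in for the module-targeted perturbations (e.g. restricting to the units of M2).

```python
import numpy as np

def lesion_weights(W, frac, rng, units=None):
    """Set a random fraction of recurrent weights to zero (optionally only among `units`)."""
    W = W.copy()
    units = np.arange(W.shape[0]) if units is None else np.asarray(units)
    rr, cc = np.meshgrid(units, units, indexing="ij")
    rr, cc = rr.ravel(), cc.ravel()
    k = int(round(frac * rr.size))
    pick = rng.choice(rr.size, size=k, replace=False)
    W[rr[pick], cc[pick]] = 0.0
    return W

def silence_units(frac, n_units, rng, units=None):
    """0/1 mask over units; masked units have their activity clamped at zero."""
    units = np.arange(n_units) if units is None else np.asarray(units)
    k = int(round(frac * units.size))
    mask = np.ones(n_units)
    mask[rng.choice(units, size=k, replace=False)] = 0.0
    return mask

def global_noise(W, std, rng):
    """Add Gaussian noise to the entire connectivity matrix."""
    return W + std * rng.standard_normal(W.shape)

def input_noise(inputs, std, rng):
    """Add independent Gaussian noise to all input channels."""
    return inputs + std * rng.standard_normal(inputs.shape)

def run_silenced(W, W_in, inputs, mask):
    """Run the network with silenced units held at zero at every step."""
    h = np.zeros(W.shape[0])
    states = []
    for u in inputs:
        h = mask * np.tanh(W @ h + W_in @ u)
        states.append(h)
    return np.array(states)

rng = np.random.default_rng(0)
W = rng.normal(0.0, 1.0 / np.sqrt(300), (300, 300))
m2_units = np.arange(100, 200)                  # hypothetical index range of module M2
W_lesioned = lesion_weights(W, 0.2, rng, units=m2_units)
```

To mirror the protocol in the text, each perturbation level would be repeated for 25 random draws (seeds) and task performance averaged over a held-out test set.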

In Figure 5, we compare the performance of the different approaches after perturbations. Here, we pick connections and units randomly from across the entire network and repeat this procedure for 25 draws with different seeds. In addition to fully disconnected networks with block-diagonal modules (Figure 1), we also consider fully connected modular networks (Figure S1).

Modular networks show consistently high resilience to all types of perturbations, performing better than networks trained with a high-performing multi-task learning approach such as dynamical non-interference (Figures 5 & S2). Performance decays more slowly using our approach after weight and activity perturbations, as well as when adding global and input noise to the network. In the case of the Delay Anti task, for instance, the performance of networks trained with dynamical non-interference reaches chance-level at weights lesioned, while our approach still gets 50% of the decisions correct at that level. We found networks trained separately and those trained with dynamical non-interference to be similar in performance, with a few exceptions for particular tasks (such as DelayPro and DelayAnti, where separately trained networks showed higher robustness). We found no particular advantage in training off-diagonal connections (Figure 1, cyan vs dark green). Overall, our multi-task learning technique achieved the highest robustness to perturbations across tasks.

We also performed perturbations specifically targeting units within a module of our network. In Figures 6 & S3, we show performance after perturbing units within module M2 (see Appendix A). We observed the same pattern across approaches, with our technique achieving the highest robustness. Fully connected modular networks performed worse than fully disconnected ones (Figures 6 & S3, light blue vs. green, respectively). When perturbations were performed in a modular way, targeting sets of locally inter-connected units, networks trained with dynamical non-interference performed better than those trained separately. Overall, even when a module was targeted directly, our approach consistently achieved higher robustness to perturbations than the other approaches across tasks.

Fig 5: Performance after perturbations. Performance (% decisions correct on a 250-trial test set) for tasks T1-T4 (DelayPro, DelayAnti, MemPro, MemAnti) as a function of % weights lesioned (A, E, I, M), % units silenced (B, F, J, N), global Gaussian noise as a function of standard deviation added to the entire population (C, G, K, O), and independent Gaussian noise added to all inputs (D, H, L, P). Error bars show the standard deviation across repeated perturbations for 25 draws and 2 networks, all with different random seeds.

6 Discussion

We presented a new approach for multi-task learning leveraging modular task primitives. Inspired by the organization of brain activity and the need to minimize training cost, we developed a network made up of functional modules that provide useful dynamics for solving a number of different tasks.

We showed that our approach achieves high performance across all tasks without suffering from the effects of task order or task interference that can arise with other multi-task learning approaches duncker_organizing_2020, zenke_continual_2017. Moreover, our approach achieves this while training around an order of magnitude fewer parameters than other approaches, and it scales better with the number of tasks, thus reducing overall training cost.

Our results also show that networks composed of task modules are more robust to perturbations, even when perturbations are specifically targeted at a particular module. As readouts may draw on dynamics in different modules, task performance may not be drastically affected by the failure of any one particular module. The lower performance of the fully connected modular network may be due to the trained connection weights, which increase the inter-connectivity of different cell clusters and thus the susceptibility of local cell clusters to perturbations in distal parts of the network.

This generates useful hypotheses for multi-task learning in the brain. It suggests decision-making related circuits in the brain, such as in the prefrontal cortex, need not re-adjust all weights during learning. Rather, pre-existing dynamics may be leveraged to solve new tasks. This could be tested by using recurrent networks fitted to population activity recorded from the brain perich_inferring_2020, pandarinath_inferring_2018 as reservoirs of useful dynamics. We hypothesize that the fitted networks contain dynamics useful to new tasks; if so, it should be possible to train readout weights to perform new tasks. Moreover, targeted lesions may be performed to study the effect of perturbations on network dynamics in the brain; our work suggests networks with functional task modules are less prone to being disrupted by perturbations. Based on our work we would also expect functional modules to encode a set of basis-tasks - commonly occurring dynamics - for the purpose of solving future tasks.

In a recent study, flesch_rich_2021 examined the effects of 'rich' and 'lazy' learning on acquired task representations and compared their findings to network representations in the brains of macaques and humans. The lazy learning paradigm is similar to ours insofar as learning is confined to readout weights, but it lacks our modular approach and does not assess performance on multiple tasks. Conversely, our networks are not initialized with high-variance weights as in the lazy regime. Finally, we observe improved robustness with our approach, whereas lazy learning appears to confer lower robustness than rich learning.

More generally, the challenge of how to choose task primitives for a given distribution of tasks remains. In the future, we aim to develop a way to derive the initialization scheme for the modular network directly from the set of tasks to be learnt. This could be done by segregating common input features derived from a set of tasks into separate modules, for example by obtaining a factorised representation of task dynamics higgins_beta-vae_2017. Also, instead of keeping recurrent weights frozen, they could be allowed to change slowly over the course of learning multiple tasks botvinick_reinforcement_2019, thereby incorporating useful task information without disrupting previous dynamics. Leveraging meta-learning approaches to optimize the task primitives themselves is another possible avenue. Future work will also explore how the observed susceptibility to perturbations relates more precisely to the dynamics in the network. Higher robustness to perturbations might also be achieved by further optimizing the training of readout weights.

Insofar as this work enables multi-task learning at lower training, and thus energy, cost, it has the potential to displace other, more energy-intensive methods and thereby reduce the overall impact on the environment. This work focuses mainly on neuroscience tasks and is thus devoid of any directly harmful application. However, insofar as it facilitates efficient multi-task learning, it has the potential to find wider application in modeling sequential data. The dynamics created in the network may be useful for training a wide range of tasks; to preclude any harmful effects, task primitives can be chosen in such a way as to limit the performance of the network.

Fig 6: Performance after modular perturbations (see Appendix A for details). Performance (% decisions correct on a 250-trial test set) for tasks T1-T4 (DelayPro, DelayAnti, MemPro, MemAnti) as a function of % weights lesioned (A, D, G, J), % units silenced (B, E, H, K), and global Gaussian noise as a function of standard deviation (C, F, I, L). Error bars show the standard deviation across repeated perturbations for 25 draws with different random seeds, for 2 networks initialized with different random seeds.

7 Broader Impact

In this work we developed a novel approach to continual learning that achieves high performance and robustness at lower overall training cost than other approaches. As energy usage soars with increasing model size, lowering training cost without sacrificing performance is becoming increasingly important in fields such as artificial intelligence and robotics. Our approach may also find applications in the domain of low-power biological implants. Furthermore, our work explores new avenues for multi-task learning in the brain, and thus has the potential to catalyze new experiments and analyses in the field of neuroscience.

8 Acknowledgments & Disclosure of funding

We would like to thank Daniel Kepple for his insightful comments on this work. This work was funded by the NSF Foundations grant (1926800) and the McDonnell UHC Scholar Award, as well as by the Canada CIFAR AI Chair, the NSERC Discovery Grant (RGPIN-2018-04821), Samsung Research Support and FRQS (Research Scholar Award Junior 1 LAJGU0401-253188).

References

Appendix A Appendix

a.1 Training details

All networks were trained on a single laptop with a 2.7 GHz quad-core Intel Core i7 processor and 16 GB of RAM, using a custom implementation of recurrent neural networks based on JAX jax2018github, available at https://github.com/dashirn/FAST_RNN_JAX. Throughout, we used a recurrent neural network (RNN) model with the hyperbolic tangent as the activation function (Equation 1), in accordance with previous approaches to modeling neuroscience tasks marton_learning_2020, chaisangmongkon_computing_2017, mante_context-dependent_2013. Previous work suggests that the dynamics obtained with different activation functions and with different network models, such as RNNs, GRUs and LSTMs, are similar maheswaranathan_universality_2019. All networks were trained by minimizing the mean squared error with the Adam optimizer. We also compared our results with those obtained from the native implementation of dynamical non-interference duncker_organizing_2020, available at https://github.com/LDlabs/seqMultiTaskRNN under an MIT license, and obtained similar results.
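The training setup described above (a tanh RNN, a mean-squared-error objective, and Adam) can be sketched in JAX as follows. This is a minimal illustration, not the FAST_RNN_JAX implementation: Equation 1 is not reproduced in this section, so the discrete-time update h_t = tanh(W_rec h_{t-1} + W_in u_t + b) is an assumption, as are the parameter names and toy sizes. Adam is taken from `jax.example_libraries.optimizers`.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers


def rnn_forward(params, inputs, h0):
    """Vanilla tanh RNN over one trial: inputs (T, n_in) -> outputs (T, n_out).

    Assumed discrete-time update (a stand-in for Equation 1):
        h_t = tanh(W_rec h_{t-1} + W_in u_t + b),   y_t = W_out h_t
    """
    def step(h, u):
        h_new = jnp.tanh(params["W_rec"] @ h + params["W_in"] @ u + params["b"])
        return h_new, params["W_out"] @ h_new

    _, outputs = jax.lax.scan(step, h0, inputs)
    return outputs


def mse_loss(params, inputs, targets, h0):
    """Mean squared error over a batch: inputs (B, T, n_in), targets (B, T, n_out)."""
    outputs = jax.vmap(rnn_forward, in_axes=(None, 0, None))(params, inputs, h0)
    return jnp.mean((outputs - targets) ** 2)


def make_adam_trainer(learning_rate):
    """Adam from jax.example_libraries.optimizers, wrapped in a jitted step."""
    opt_init, opt_update, get_params = optimizers.adam(learning_rate)

    @jax.jit
    def train_step(i, opt_state, inputs, targets, h0):
        params = get_params(opt_state)
        loss, grads = jax.value_and_grad(mse_loss)(params, inputs, targets, h0)
        return opt_update(i, grads, opt_state), loss

    return opt_init, train_step, get_params


# Toy usage on random data (sizes and scales are illustrative only).
keys = jax.random.split(jax.random.PRNGKey(0), 5)
n_in, n_rec, n_out, T, B = 3, 20, 2, 10, 8
params = {
    "W_rec": jax.random.normal(keys[0], (n_rec, n_rec)) * 1.5 / jnp.sqrt(n_rec),
    "W_in": jax.random.normal(keys[1], (n_rec, n_in)) * 0.5,
    "b": jnp.zeros(n_rec),
    "W_out": jax.random.normal(keys[2], (n_out, n_rec)) * 0.1,
}
inputs = jax.random.normal(keys[3], (B, T, n_in))
targets = 0.5 * jax.random.normal(keys[4], (B, T, n_out))
h0 = jnp.zeros(n_rec)

opt_init, train_step, get_params = make_adam_trainer(3e-3)
opt_state = opt_init(params)
losses = []
for i in range(100):
    opt_state, loss = train_step(i, opt_state, inputs, targets, h0)
    losses.append(float(loss))
```

Scanning over time with `jax.lax.scan` and batching trials with `jax.vmap` keeps the whole step jit-compilable, which helps keep CPU-only training tractable.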

We performed a grid search over the parameter ranges listed in Table 2 and settled on the parameter combination that worked best across all training paradigms. We used the same network size for all tasks and training paradigms; this size ensured that individual modules trained well and yielded the highest task performance across all training paradigms. The performance we obtained agrees with previously reported results duncker_organizing_2020.

Parameter                                      Range considered
Network size (units)                           100 – 500
Learning rate                                  1e-4 – 3e-3
Batch size (trials)                            20 – 400
Scaling factor                                 1.1 – 1.8
Weight-norm regularization parameter           5e-6 – 5e-5
Activity regularization parameter              1e-7
Iterations before decreasing learning rate     100 – 300
Table 2: Hyperparameter ranges considered in the grid search.
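The grid search over Table 2 can be sketched as an exhaustive loop over all combinations. Everything here is hypothetical: the three values per range are illustrative picks, `evaluate` is a placeholder for training and testing a network under one paradigm, and averaging across paradigms is only one reading of "worked best across all training paradigms".

```python
import itertools
import math

# Hypothetical grid: a few illustrative values from each range in Table 2
# (the activity regularization parameter was fixed at 1e-7, so it is omitted).
GRID = {
    "n_units": [100, 300, 500],
    "learning_rate": [1e-4, 1e-3, 3e-3],
    "batch_size": [20, 200, 400],
    "scaling_factor": [1.1, 1.5, 1.8],
    "weight_reg": [5e-6, 1e-5, 5e-5],
    "decay_after_iters": [100, 200, 300],
}


def grid_search(grid, paradigms, evaluate):
    """Exhaustive search; returns the setting with the best mean score across paradigms.

    `evaluate(setting, paradigm)` is a hypothetical callback that would train a
    network with `setting` under `paradigm` and return its test performance.
    """
    keys = list(grid)
    best_setting, best_score = None, -math.inf
    for values in itertools.product(*(grid[k] for k in keys)):
        setting = dict(zip(keys, values))
        score = sum(evaluate(setting, p) for p in paradigms) / len(paradigms)
        if score > best_score:  # strict ">" keeps the first of any tied settings
            best_setting, best_score = setting, score
    return best_setting, best_score


def toy_evaluate(setting, paradigm):
    # Purely illustrative stand-in for training: peaks at lr = 1e-3, scaling = 1.5.
    return (-abs(math.log10(setting["learning_rate"]) + 3)
            - abs(setting["scaling_factor"] - 1.5))


best, score = grid_search(GRID, ["modular", "sequential", "non_interference"], toy_evaluate)
```

With six parameters at three values each, the loop visits 729 settings; a real run would replace `toy_evaluate` with full training, which is why the ranges were kept narrow.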

For all networks, we fixed the learning rate and the scaling factor to single values from the ranges in Table 2, and set the batch size to 200 trials, the weight-norm regularization parameter to 1e-5, the activity regularization parameter to 1e-7, and the number of iterations before decreasing the learning rate to 200. The weights in the recurrent weight matrix (Equation 1) were initially drawn from the standard normal distribution and scaled by a constant factor. The input weights were also initially drawn from the standard normal distribution and scaled by a task-dependent factor. Each input unit also received an independent zero-mean white noise input with a fixed standard deviation. The output weights were likewise drawn from the standard normal distribution and multiplied by a task-dependent factor.
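A sketch of the initialization and input noise described above, using NumPy. The exact scaling factors are not reproduced in this section, so the g/sqrt(N) recurrent scaling (a common convention for tanh RNNs, which places the eigenvalues of W_rec roughly within a disk of radius g) and the `in_scale`, `out_scale` and `sigma` arguments are assumptions. The step-decay schedule reflects the 200 iterations before the learning rate is decreased; its halving factor is hypothetical.

```python
import numpy as np


def init_weights(rng, n_in, n_rec, n_out, g, in_scale, out_scale):
    """Standard-normal draws rescaled by (assumed) factors.

    Assumption: W_rec is scaled by g / sqrt(n_rec); in_scale and out_scale
    stand in for the task-dependent input and output factors.
    """
    W_rec = rng.standard_normal((n_rec, n_rec)) * g / np.sqrt(n_rec)
    W_in = rng.standard_normal((n_rec, n_in)) * in_scale
    W_out = rng.standard_normal((n_out, n_rec)) * out_scale
    return W_rec, W_in, W_out


def add_input_noise(rng, inputs, sigma):
    """Independent zero-mean white noise on every input unit at every time step."""
    return inputs + sigma * rng.standard_normal(inputs.shape)


def step_decay_lr(base_lr, iteration, decay_every=200, factor=0.5):
    """Lower the learning rate every `decay_every` iterations (factor is hypothetical)."""
    return base_lr * factor ** (iteration // decay_every)


rng = np.random.default_rng(0)
W_rec, W_in, W_out = init_weights(rng, n_in=4, n_rec=300, n_out=2,
                                  g=1.5, in_scale=0.1, out_scale=0.1)
# A batch of (trials, time steps, input units) with only noise on the inputs.
noisy = add_input_noise(rng, np.zeros((200, 50, 4)), sigma=0.1)
```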

a.1.1 Modular network with task primitives

The modular network was initialized and trained with the parameters described above. Each module, M1-M3, was trained separately. While a given module was being trained, all other weights in the network remained frozen (they received no gradient updates). For the perturbation studies, we considered two structural variants: a modular, fully disconnected network (Figure 1) and a modular, fully connected network (Figure S1). After the modules were trained, tasks T1-T9 were trained using separate readout weights for each task.
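Freezing all weights except those of the module being trained can be implemented by masking gradients. The JAX sketch below assumes three equal diagonal blocks for M1-M3, a toy loss, and plain gradient steps; the block layout, the assumed module readout, and all names are hypothetical. In the fully disconnected variant, the off-diagonal blocks of W_rec would additionally be zero. The same masks can be applied before an Adam update, provided the optimizer state is re-initialized for each stage, since frozen entries then keep zero moment estimates and receive zero updates.

```python
import jax
import jax.numpy as jnp

# Hypothetical layout: modules M1-M3 occupy equal diagonal blocks of W_rec.
N_REC, N_IN, N_OUT = 30, 3, 2
MODULES = [slice(0, 10), slice(10, 20), slice(20, 30)]


def module_masks(params, active):
    """Masks equal to 1 for weights of the active module and 0 (frozen) elsewhere."""
    s = MODULES[active]
    masks = jax.tree_util.tree_map(jnp.zeros_like, params)
    masks["W_rec"] = masks["W_rec"].at[s, s].set(1.0)  # within-module recurrence
    masks["W_in"] = masks["W_in"].at[s, :].set(1.0)    # inputs onto the module
    masks["W_out"] = masks["W_out"].at[:, s].set(1.0)  # assumed primitive readout
    return masks


def readout_masks(params):
    """Second stage (tasks T1-T9): only the task-specific readout is trainable."""
    masks = jax.tree_util.tree_map(jnp.zeros_like, params)
    masks["W_out"] = jnp.ones_like(params["W_out"])
    return masks


def loss_fn(params, u, y):
    # Two tanh steps of a toy network, just enough to give nonzero gradients.
    h = jnp.tanh(params["W_in"] @ u)
    h = jnp.tanh(params["W_rec"] @ h + params["W_in"] @ u)
    return jnp.mean((params["W_out"] @ h - y) ** 2)


def masked_sgd_step(params, masks, u, y, lr=1e-2):
    grads = jax.grad(loss_fn)(params, u, y)
    # Multiplying by the mask zeroes the gradient of every frozen weight.
    return jax.tree_util.tree_map(lambda p, g, m: p - lr * g * m, params, grads, masks)


keys = jax.random.split(jax.random.PRNGKey(0), 5)
params = {
    "W_rec": jax.random.normal(keys[0], (N_REC, N_REC)) / jnp.sqrt(N_REC),
    "W_in": jax.random.normal(keys[1], (N_REC, N_IN)),
    "W_out": 0.1 * jax.random.normal(keys[2], (N_OUT, N_REC)),
}
u = jax.random.normal(keys[3], (N_IN,))
y = jax.random.normal(keys[4], (N_OUT,))
after_m2 = masked_sgd_step(params, module_masks(params, active=1), u, y)
after_readout = masked_sgd_step(params, readout_masks(params), u, y)
```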

a.2 Task performance

To measure task performance, we generated a test set of 250 trials for each task from a separate random seed. Performance was calculated as the percentage of correct decisions on this test set. We used decisions as the performance metric, rather than an error metric such as the mean squared error, because they reflect the ultimate goal of biological agents. Following previous approaches duncker_organizing_2020, a decision was considered correct if the mean output over the response period deviated from the target direction by no more than a fixed threshold. Performance was averaged across all output units.
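The decision criterion can be sketched as follows. The readout format is not specified in this section, so this NumPy version assumes two output units encoding the cosine and sine of the reported direction, decodes the mean over the response period with `arctan2`, and compares the wrapped angular error to a threshold. The 36° threshold in the usage lines is purely illustrative; the value used in the paper is not reproduced here.

```python
import numpy as np


def percent_correct(outputs, targets_deg, response_mask, threshold_deg):
    """Percentage of trials whose decoded response lies within threshold of target.

    Assumptions: outputs has shape (trials, T, 2) holding (cos, sin) of the
    reported direction; response_mask (T,) is True during the response period.
    """
    mean_out = outputs[:, response_mask, :].mean(axis=1)  # (trials, 2)
    decoded = np.degrees(np.arctan2(mean_out[:, 1], mean_out[:, 0]))
    # Wrap the difference into [-180, 180) so that 350 deg and 10 deg are 20 deg apart.
    err = (decoded - targets_deg + 180.0) % 360.0 - 180.0
    return 100.0 * np.mean(np.abs(err) <= threshold_deg)


# Toy usage: 4 trials, 5 time steps, response period = last 2 steps.
reported = np.radians([5.0, 90.0, 100.0, 10.0])
outputs = np.zeros((4, 5, 2))
outputs[:, :, 0] = np.cos(reported)[:, None]
outputs[:, :, 1] = np.sin(reported)[:, None]
targets = np.array([0.0, 90.0, 180.0, 350.0])
response = np.array([False, False, False, True, True])
pct = percent_correct(outputs, targets, response, threshold_deg=36.0)  # 3 of 4 correct
```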

Fig S1: Schematic of fully connected recurrent network with modular task primitives. Left: inputs for modules M1-M3. Middle: recurrent weight matrix of the modular network, with separate partitions for modules M1-M3 along the diagonal. Right: different tasks (T1-T9) trained into different output weights, leveraging the same modular network.

a.3 Perturbation studies

We considered four types of perturbations: lesioned weights, silenced activity, global Gaussian noise, and independent Gaussian input noise. We applied these perturbations in two conditions: (1) whole network (Figures 5 and S2) and (2) modular (Figures 6 and S3). In the whole-network condition (1), perturbations were applied across all connections or units in the network: connections to be lesioned were picked from the entire recurrent weight matrix, units to be silenced from among all units, and global Gaussian noise was added to the whole network. In the modular condition (2), perturbations were restricted to a single module. For the modular network with task primitives, perturbations in this condition were applied within module 2 (M2), whereas for the networks trained sequentially and with dynamical non-interference, results were averaged across 5 modules of the same size picked randomly from the entire connectivity matrix. Every perturbation was repeated for 25 draws with different random seeds on each of 2 networks initialized with different random seeds.
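The perturbations can be sketched as operations on the recurrent weights and the network state, each optionally restricted to a set of unit indices (the modular condition). This NumPy version is an illustration under assumptions: silencing is applied as a multiplicative mask on unit activity, a "module" in the comparison networks is taken to be a random subset of units, and independent input noise (Gaussian noise added to the inputs rather than to the state) is omitted.

```python
import numpy as np


def lesion_weights(W, frac, rng, idx=None):
    """Zero a random fraction of recurrent connections (a copy of W is returned).

    idx: optional unit indices of a module; if given, only connections within
    that module's block are candidates (modular condition).
    """
    W = W.copy()
    idx = np.arange(W.shape[0]) if idx is None else np.asarray(idx)
    rows, cols = np.meshgrid(idx, idx, indexing="ij")
    rows, cols = rows.ravel(), cols.ravel()
    pick = rng.choice(rows.size, size=int(round(frac * rows.size)), replace=False)
    W[rows[pick], cols[pick]] = 0.0
    return W


def silence_mask(n, frac, rng, idx=None):
    """Multiplicative mask (0 = silenced unit) to apply to activity at every step."""
    idx = np.arange(n) if idx is None else np.asarray(idx)
    mask = np.ones(n)
    mask[rng.choice(idx, size=int(round(frac * idx.size)), replace=False)] = 0.0
    return mask


def global_noise(h, sigma, rng, idx=None):
    """Add zero-mean Gaussian noise to the state of all units (or of one module)."""
    noise = np.zeros_like(h)
    sel = slice(None) if idx is None else np.asarray(idx)
    noise[..., sel] = sigma * rng.standard_normal(noise[..., sel].shape)
    return h + noise


def random_module(n, size, rng):
    """Comparison networks: a random set of units the size of one module."""
    return np.sort(rng.choice(n, size=size, replace=False))


# Usage sketch: 25 independent lesion draws restricted to units 5-9.
W = np.ones((20, 20))
module = np.arange(5, 10)
draws = [lesion_weights(W, 0.4, np.random.default_rng(seed), idx=module)
         for seed in range(25)]
```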

a.4 Supplementary results

Fig S2: Performance after whole-network perturbations. Performance (% decisions correct on a 250-trial test set) for tasks T5-T9 (MemDm1, MemDm2, ContextMemDm1, ContextMemDm2, MultiMem) as a function of the percentage of weights lesioned (A, E, I, M, Q), the percentage of units silenced (B, F, J, N, R), the standard deviation of global Gaussian noise (C, G, K, O, S), and uncorrelated Gaussian noise added to all inputs (D, H, L, P, T). Error bars show the standard deviation across repeated perturbations (25 draws with different random seeds, for each of 2 networks initialized with different random seeds).
Fig S3: Performance after modular perturbations (see Appendix A for details). Performance (% decisions correct on a 250-trial test set) for tasks T5-T9 (MemDm1, MemDm2, ContextMemDm1, ContextMemDm2, MultiMem) as a function of the percentage of weights lesioned (A, D, G, J, M), the percentage of units silenced (B, E, H, K, N), and the standard deviation of global Gaussian noise (C, F, I, L, O). Error bars show the standard deviation across repeated perturbations (25 draws with different random seeds, for each of 2 networks initialized with different random seeds).