Training state-of-the-art Artificial Neural Network (ANN) models requires distributed computing on large mixed CPU-GPU clusters, typically over many days or weeks, at the expense of massive memory, time, and energy resources [DBLP:journals/corr/abs-1906-02243], and potentially of privacy violations. Alternative solutions for low-power machine learning on resource-constrained devices have recently been the focus of intense research, and this paper proposes and explores the convergence of two such lines of inquiry, namely Spiking Neural Networks (SNNs) for low-power on-device learning [loihi, Fouda2019SpikingNN, Nandakumar2019SupervisedLI, intro_snn], and Federated Learning (FL) [FL_paper, federated_learning] for collaborative inter-device training.
SNNs are biologically inspired neural networks in which neurons are dynamic elements that process and communicate via sparse spiking signals over time, rather than via real numbers. This enables the native processing of time-encoded data, e.g., from DVS cameras [dvs_camera]. SNNs can be implemented on dedicated hardware [loihi], with energy consumption as low as a few picojoules per spike [intro_snn].
With FL, distributed devices can carry out collaborative learning without exchanging local data. This makes it possible to train more effective machine learning models by benefiting from data at multiple devices with limited privacy concerns. FL requires devices to periodically exchange information about their local model parameters through a parameter server [federated_learning].
This paper explores for the first time the implementation of collaborative FL for on-device SNNs. As shown in Fig. 1, we consider a number of mobile devices, each training a local SNN on the basis of local data, which can communicate through a base station (BS) to implement FL. The proposed FL solution enables learning and inference on time-encoded data, and enables a flexible trade-off between communication load and accuracy. It generalizes the single-SNN learning rule introduced in [intro_snn] for probabilistic SNNs, which in turn recovers as special cases many known training algorithms for SNNs such as Spike-Timing-Dependent Plasticity (STDP) and INST/FILT [gardner16INST].
II System Model
As illustrated in Fig. 1, we consider a distributed edge computing architecture in which $K$ mobile devices communicate through a BS in order to perform the collaborative training of local SNN models via FL. Each device $k \in \{1, \dots, K\}$ holds a different local data set $\mathcal{D}^{(k)} = \{z_j\}_{j=1}^{N^{(k)}}$ that contains a collection of $N^{(k)}$ data points. Each sample $z_j$ is given by a pair $z_j = (x_j, y_j)$ of covariate vector $x_j$ and desired output $y_j$. The goal of FL is to train a common SNN-based model without direct exchange of the data from the local data sets. In the rest of this section, we first review conventional FL and then the separate online training of individual SNNs [intro_snn].
II-A Conventional Federated Learning
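Defining $f(\theta, z)$ as the loss function for any example $z$ when the model parameter vector is $\theta$, the local empirical loss at device $k$ is defined as

$$F^{(k)}(\theta) = \frac{1}{N^{(k)}} \sum_{z \in \mathcal{D}^{(k)}} f(\theta, z). \qquad (1)$$

The global learning objective of conventional FL is to solve the problem

$$\min_{\theta} \; F(\theta) = \frac{1}{\sum_{k=1}^{K} N^{(k)}} \sum_{k=1}^{K} N^{(k)} F^{(k)}(\theta), \qquad (2)$$

where $F(\theta)$ is the global empirical loss over the collective data set $\mathcal{D} = \cup_{k=1}^{K} \mathcal{D}^{(k)}$. FL proceeds in an iterative fashion across $I$ iterations, with one communication update through the BS every $\tau$ iterations (see, e.g., [FL_paper]). To elaborate, at each iteration $i = 1, \dots, I$, each device $k$ computes a local Stochastic Gradient Descent (SGD) update

$$\theta^{(k)}(i) = \hat{\theta}^{(k)}(i-1) - \alpha \nabla f\big(\hat{\theta}^{(k)}(i-1), z^{(k)}(i)\big), \qquad (3)$$

where $\alpha$ is the learning rate; $\hat{\theta}^{(k)}(i-1)$ represents the local value of the model parameter at the previous iteration $i-1$ for device $k$; and $z^{(k)}(i)$ is a randomly selected example (with replacement) from the local data set $\mathcal{D}^{(k)}$. If $i$ is not a multiple of $\tau$, we set $\hat{\theta}^{(k)}(i) = \theta^{(k)}(i)$ for all devices $k$. Otherwise, when $i = s\tau$ for some integer $s$, the devices communicate the updated parameter (3) to the BS, which computes the centralized averaged parameter

$$\theta(i) = \frac{1}{\sum_{k=1}^{K} N^{(k)}} \sum_{k=1}^{K} N^{(k)} \theta^{(k)}(i). \qquad (4)$$

The updated parameter (4) is sent back to all devices via multicast downlink transmission, and it is used in the next iteration as the initial value $\hat{\theta}^{(k)}(i) = \theta(i)$ for all $k$.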
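The interplay between the local SGD step (3) and the periodic weighted averaging (4) can be illustrated with a minimal plain-Python sketch. It is not the implementation used in this paper: the function name `fedavg`, the list-based parameter representation, and the toy quadratic loss in the usage example are assumptions made here for illustration.

```python
import random


def fedavg(local_data, grad, theta0, alpha, tau, num_iters, seed=0):
    """Local SGD (3) on every device, with the weighted average (4) every tau iterations.

    local_data: one list of examples z per device (the local data sets D^(k))
    grad(theta, z): gradient of the per-example loss f(theta, z), as a list
    Returns the local parameters theta_hat^(k) after num_iters iterations.
    """
    rng = random.Random(seed)
    sizes = [len(data) for data in local_data]  # N^(k)
    total = sum(sizes)
    # All devices start from the same initial parameter vector.
    theta_hat = [list(theta0) for _ in local_data]
    for i in range(1, num_iters + 1):
        theta = []
        for k, data in enumerate(local_data):
            z = rng.choice(data)  # example drawn uniformly, with replacement
            g = grad(theta_hat[k], z)
            theta.append([w - alpha * gw for w, gw in zip(theta_hat[k], g)])
        if i % tau == 0:
            # Communication round: the BS computes the N^(k)-weighted average (4)
            # and multicasts it back to every device.
            avg = [
                sum(sizes[k] * theta[k][p] for k in range(len(theta))) / total
                for p in range(len(theta0))
            ]
            theta_hat = [list(avg) for _ in local_data]
        else:
            # No communication: each device keeps its own updated parameter.
            theta_hat = theta
    return theta_hat


if __name__ == "__main__":
    # Toy loss f(theta, z) = 0.5 * (theta - z)^2, whose gradient is theta - z.
    # Device 0 only sees z = 0 and device 1 only sees z = 2; the global optimum is 1.
    data = [[0.0] * 10, [2.0] * 10]
    print(fedavg(data, lambda th, z: [th[0] - z], [0.0], alpha=0.1, tau=5, num_iters=200))
```

In this toy setting neither device alone can reach the global optimum $\theta = 1$, whereas periodic averaging drives both local parameters to it, mirroring the motivation for FL with non-identically distributed local data.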
II-B Probabilistic SNN Model and Training
An SNN is a network of spiking neurons connected via an arbitrary directed graph, possibly with cycles (see Fig. 2). Each neuron $n$ receives the signals emitted by the subset $\mathcal{P}_n$ of neurons connected to it through directed links, known as synapses. At any local algorithmic time $t = 1, 2, \dots$, each neuron $n$ outputs a binary signal $s_{n,t} \in \{0, 1\}$, with "1" representing a spike. As seen in Fig. 2, the set $\mathcal{V}$ of neurons in the SNN can be partitioned into three disjoint subsets as $\mathcal{V} = \mathcal{X} \cup \mathcal{Y} \cup \mathcal{H}$, with $\mathcal{X}$ being the subset of input neurons, $\mathcal{Y}$ the subset of output neurons, and $\mathcal{H}$ the subset of hidden, or latent, neurons.
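Following the probabilistic Generalized Linear Model (GLM), the spiking behavior of any neuron $n$ at local algorithmic time $t$ is determined by its membrane potential $u_{n,t}$: the instantaneous spiking probability of neuron $n$ at time $t$, conditioned on the value of the membrane potential, is

$$p_{\theta_n}(s_{n,t} = 1 \mid u_{n,t}) = \sigma(u_{n,t}), \qquad (5)$$

with $\theta_n$ denoting the learnable parameters of neuron $n$ and $\sigma(x) = (1 + e^{-x})^{-1}$ being the sigmoid function [intro_snn].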
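The membrane potential depends on the previous spikes of pre-synaptic neurons and on the past spiking behavior of the neuron itself, according to the formula

$$u_{n,t} = \sum_{m \in \mathcal{P}_n} \overrightarrow{s}_{m,n,t-1} + w_n \overleftarrow{s}_{n,t-1} + \gamma_n, \qquad (6)$$

where $\gamma_n$ is a bias parameter; the contribution of each pre-synaptic neuron $m \in \mathcal{P}_n$ is through the filtered synaptic trace $\overrightarrow{s}_{m,n,t} = a_{m,n,t} * s_{m,t}$, and that of the neuron itself through its feedback trace $\overleftarrow{s}_{n,t} = b_t * s_{n,t}$, where $*$ denotes the convolution operator; and $\{a_{m,n,t}\}$ and $\{b_t\}$ are the synaptic and feedback impulse responses. Each synaptic impulse response is given by the linear combination $a_{m,n,t} = \sum_{q=1}^{K_a} w_{m,n,q}\, a_{q,t}$ with $K_a$ fixed basis functions $\{a_{q,t}\}$ and learnable synaptic feedforward weights $\{w_{m,n,q}\}$, whereas the feedback filter $\{b_t\}$ is fixed with a learnable feedback weight $w_n$. The learnable parameters of neuron $n$ are thus $\theta_n = \{\gamma_n, w_n, \{w_{m,n,q}\}\}$.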
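To make the GLM neuron of (5) and (6) concrete, the following plain-Python sketch computes the membrane potential of one neuron from past spikes and turns it into a spiking probability. It rests on assumptions made here: the helper names are hypothetical, the raised-cosine bumps are spaced uniformly in linear time (the basis of [pillow] is warped in log-time), and the filtered trace at time $t-1$ is taken as $\sum_d a_d\, s_{t-1-d}$.

```python
import math


def raised_cosine_basis(num_basis, duration):
    """Hypothetical raised-cosine basis a_{q,t}, q = 0..num_basis-1, t = 0..duration-1.

    Bumps are spaced uniformly in linear time (a simplifying assumption).
    """
    width = duration / num_basis
    basis = []
    for q in range(num_basis):
        center = width * (q + 0.5)
        row = []
        for t in range(duration):
            x = (t - center) * math.pi / (2 * width)
            row.append(0.5 * (1 + math.cos(x)) if abs(x) <= math.pi else 0.0)
        basis.append(row)
    return basis


def filtered_trace(kernel, spikes, t):
    """Convolution (kernel * spikes) evaluated at time t; zero for t < 0."""
    return sum(kernel[d] * spikes[t - d] for d in range(len(kernel)) if t - d >= 0)


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def membrane_potential(t, pre_spikes, own_spikes, w_syn, basis, w_fb, b_kernel, gamma):
    """Membrane potential u_{n,t} of (6), using spikes up to time t - 1 only.

    pre_spikes[m]: spike train of pre-synaptic neuron m (list of 0/1)
    w_syn[m][q]:   weight w_{m,n,q} of basis function q on synapse m -> n
    """
    u = gamma
    for m, spikes in enumerate(pre_spikes):
        # By linearity, a_{m,n} * s_m = sum_q w_{m,n,q} (a_q * s_m).
        for q, a_q in enumerate(basis):
            u += w_syn[m][q] * filtered_trace(a_q, spikes, t - 1)
    # Fixed feedback filter b scaled by the learnable feedback weight w_n.
    u += w_fb * filtered_trace(b_kernel, own_spikes, t - 1)
    return u


def spike_probability(u):
    """GLM spiking probability (5): p(s_{n,t} = 1 | u_{n,t}) = sigmoid(u_{n,t})."""
    return sigmoid(u)
```

For example, with a single basis kernel `[1.0, 0.5]`, weight 2.0, bias 0.5, and one pre-synaptic spike at time 0, the potential is 2.5 at $t = 1$ and decays to 1.5 at $t = 2$ as the spike moves along the kernel.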
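During the (individual) training of an SNN, for each selected example $z = (x, y)$, the covariate vector $x$ is first encoded into input binary time-sequences $x_{\le T} = (x_1, \dots, x_T)$ and the target $y$ into desired output binary time-sequences $y_{\le T} = (y_1, \dots, y_T)$ of duration $T$ samples [intro_snn], where $x_t = (x_{n,t})_{n \in \mathcal{X}}$ and $y_t = (y_{n,t})_{n \in \mathcal{Y}}$. At each local algorithmic time $t$, the spike signal of each input neuron $n \in \mathcal{X}$ is assigned value $s_{n,t} = x_{n,t}$, while each output neuron $n \in \mathcal{Y}$ is assigned value $s_{n,t} = y_{n,t}$. The negative log-probability of obtaining the desired output sequence $y_{\le T}$ given input $x_{\le T}$ is computed by summing over all possible values of the latent spike signals $h_{\le T}$, with $h_t = (s_{n,t})_{n \in \mathcal{H}}$, as

$$L_{x,y}(\theta) = -\log p_{\theta}(y_{\le T} \mid x_{\le T}) = -\log \sum_{h_{\le T}} \prod_{t=1}^{T} \prod_{n \in \mathcal{Y} \cup \mathcal{H}} p_{\theta_n}(s_{n,t} \mid u_{n,t}), \qquad (7)$$

where the vector $\theta = \{\theta_n\}_{n \in \mathcal{Y} \cup \mathcal{H}}$ collects the model parameters of the SNN. When interpreted as a function of $\theta$, the quantity (7) is known as negative log-likelihood or log-loss. When summed over the available data set, the log-loss can be optimized as a learning criterion via SGD [intro_snn].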
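A useful sanity check on (7) is the special case without hidden neurons ($\mathcal{H} = \emptyset$): the sum over latent spikes disappears, and the log-loss reduces to a sum of Bernoulli log-probabilities of the clamped output spikes. With hidden neurons, the exact sum ranges over $2^{|\mathcal{H}| T}$ latent configurations, which motivates the online gradient approach of [intro_snn]. The sketch below assumes that the output membrane potentials have already been computed from (6) with the outputs clamped to the targets; all function names are hypothetical.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def spike_log_prob(s, u):
    """log p(s | u) under the GLM (5), where p(s = 1 | u) = sigmoid(u)."""
    # log sigmoid(u) = -softplus(-u) and log(1 - sigmoid(u)) = -softplus(u)
    return -softplus(-u) if s == 1 else -softplus(u)


def log_loss_no_hidden(output_spikes, output_potentials):
    """Log-loss (7) for an SNN without hidden neurons.

    output_spikes[n][t]:     desired spike y_{n,t} of output neuron n (clamped)
    output_potentials[n][t]: membrane potential u_{n,t} computed from (6)
    """
    return -sum(
        spike_log_prob(s, u)
        for s_row, u_row in zip(output_spikes, output_potentials)
        for s, u in zip(s_row, u_row)
    )
```

For instance, a single output neuron with potential 0 at two time steps assigns probability 1/2 to each desired sample, so the log-loss equals $2 \log 2$ whatever the target spikes are.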
III FL-SNN: FL with Distributed SNNs
In this section, we propose a novel online FL algorithm, termed FL-SNN, that jointly trains on-device SNNs. FL-SNN aims at minimizing the global loss (2), where the local loss function is given by the log-loss (7) as

$$F^{(k)}(\theta) = \frac{1}{N^{(k)}} \sum_{(x, y) \in \mathcal{D}^{(k)}} L_{x,y}(\theta). \qquad (8)$$

To start, each device $k$ selects a sequence of $D$ examples from the local training set $\mathcal{D}^{(k)}$, uniformly at random (with replacement). Concatenating the examples yields time sequences $x^{(k)}_{\le T'}$ and $y^{(k)}_{\le T'}$ of $T'$ binary samples, with $T' = DT$. We note that, between any two examples, one could include an interval of time, known in neuroscience as an inter-stimulus interval. FL-SNN tackles problem (2) via the minimization of the log-loss $L_{x^{(k)}_{\le T'}, y^{(k)}_{\le T'}}(\theta)$ through online gradient descent [intro_snn] and periodic averaging (4).
The resulting FL-SNN algorithm is summarized in Algorithm 1. As illustrated in Table I, each global algorithmic iteration $i$ corresponds to $\Delta s$ local SNN time steps, which we define as the interval $[(i-1)\Delta s + 1, \, i\Delta s]$. The total number $T'$ of SNN local algorithmic time steps and the number $I$ of global algorithmic time steps during the training procedure are hence related as $T' = I \Delta s$. Therefore, unlike in conventional FL, the number $D$ of examples is generally different from the number $I$ of local updates.
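The bookkeeping between examples, local SNN time steps, and global iterations can be sketched as follows. This is an illustrative sketch under stated assumptions: the helper names are hypothetical, the inter-stimulus interval is modeled as all-zero samples, and the schedule assumes no such interval (so $T' = DT$) and that $\Delta s$ divides $T'$.

```python
def concatenate(sequences, isi=0):
    """Concatenate the per-example binary sequences of one neuron.

    `isi` all-zero samples (inter-stimulus interval) are inserted between
    consecutive examples; isi = 0 gives a sequence of length T' = D * T.
    """
    out = []
    for idx, seq in enumerate(sequences):
        if idx > 0:
            out.extend([0] * isi)
        out.extend(seq)
    return out


def local_interval(i, delta_s):
    """Local SNN time steps covered by global iteration i (1-indexed)."""
    return range((i - 1) * delta_s + 1, i * delta_s + 1)


def schedule(num_examples, T, delta_s, tau):
    """Return T', I and the iterations with a communication round.

    Assumes no inter-stimulus interval (T' = D * T) and that delta_s divides T'.
    """
    t_prime = num_examples * T
    if t_prime % delta_s != 0:
        raise ValueError("this sketch assumes delta_s divides T'")
    num_global = t_prime // delta_s  # I, from T' = I * delta_s
    comm_rounds = [i for i in range(1, num_global + 1) if i % tau == 0]
    return t_prime, num_global, comm_rounds
```

With the illustrative values $D = 4$, $T = 10$, $\Delta s = 5$ and $\tau = 3$, `schedule` returns $T' = 40$, $I = 8$, and communication at iterations 3 and 6, so the number of examples (4) and the number of local updates (8) indeed differ.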
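At each global algorithmic iteration $i$, upon generation of spiking signals (5) by the hidden neurons, the local update rule of each neuron $n \in \mathcal{Y} \cup \mathcal{H}$ in the SNN of device $k$ is given by the online gradient steps for the loss $L_{x^{(k)}_{\le T'}, y^{(k)}_{\le T'}}(\theta^{(k)})$ [intro_snn]

$$\theta_n^{(k)}(i) = \hat{\theta}_n^{(k)}(i-1) + \alpha \cdot \begin{cases} e_n^{(k)}(i), & n \in \mathcal{Y}, \\ \ell^{(k)}(i)\, e_n^{(k)}(i), & n \in \mathcal{H}, \end{cases} \qquad (9)$$

with

$$\ell^{(k)}(i) = \kappa\, \ell^{(k)}(i-1) + (1-\kappa) \sum_{t=(i-1)\Delta s+1}^{i\Delta s} \sum_{n' \in \mathcal{Y}} \log p_{\hat{\theta}_{n'}^{(k)}(i-1)}\big(s^{(k)}_{n',t} \mid u^{(k)}_{n',t}\big),$$
$$e_n^{(k)}(i) = \kappa\, e_n^{(k)}(i-1) + (1-\kappa) \sum_{t=(i-1)\Delta s+1}^{i\Delta s} \nabla_{\theta_n} \log p_{\hat{\theta}_n^{(k)}(i-1)}\big(s^{(k)}_{n,t} \mid u^{(k)}_{n,t}\big), \qquad (10)$$

where $\kappa \in (0, 1)$ is a smoothing constant. The two quantities in (10) respectively correspond to the learning signal and the eligibility trace, i.e., running averages of the output log-probability and of the gradients of the log-probability (equivalently, of the negative gradients of the log-loss). The global update at the BS is then given by (4).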
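A simplified plain-Python sketch of one global iteration of the local rule (9) and (10) on a single device is given below. It is an illustration, not Algorithm 1 itself: function names are hypothetical, parameters are plain lists, and possible refinements such as baselines for the learning signal are omitted. The learning signal is shared by all hidden neurons of a device, so it is updated once per iteration before the per-neuron steps; the GLM gradient uses the fact that $u_{n,t}$ is linear in $\theta_n$, so that $\nabla_{\theta_n} \log p(s \mid u) = (s - \sigma(u))\, \partial u / \partial \theta_n$.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def glm_grad_log_prob(s, u, features):
    """Gradient of log p(s | u) w.r.t. the parameters of a GLM neuron.

    `features` is du/dtheta_n: filtered synaptic traces, feedback trace and 1 (bias).
    """
    return [(s - sigmoid(u)) * f for f in features]


def update_learning_signal(ell, out_log_probs, kappa):
    """Learning signal in (10): running average of the output log-probabilities
    summed over the delta_s time steps of the current global iteration."""
    return kappa * ell + (1 - kappa) * sum(out_log_probs)


def update_neuron(theta, elig, grads, ell, is_hidden, alpha, kappa):
    """Eligibility trace in (10) and parameter step (9) for one neuron.

    grads: gradient vectors of log p(s_{n,t} | u_{n,t}), one per local time step
    ell:   learning signal already updated for this global iteration
    """
    summed = [sum(g[p] for g in grads) for p in range(len(theta))]
    elig = [kappa * e + (1 - kappa) * g for e, g in zip(elig, summed)]
    # Hidden neurons are modulated by the learning signal; output neurons are not.
    scale = ell if is_hidden else 1.0
    theta = [w + alpha * scale * e for w, e in zip(theta, elig)]
    return theta, elig
```

For hidden neurons, the gradients are evaluated at the sampled spikes $s_{n,t}$, while for output neurons they are evaluated at the clamped target spikes, which is what lets the learning signal act as a reward for the hidden activity.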
As summarized in Algorithm 1, FL-SNN is based on local and global feedback signals, rather than on backpropagation. As in [intro_snn], the local learning signal $\ell^{(k)}(i)$, computed every $\Delta s$ local time steps, indicates to the hidden neurons within the SNN of each device how effective their current signaling is in maximizing the probability of the desired input-output behavior defined by the selected data $(x^{(k)}_{\le T'}, y^{(k)}_{\le T'})$. In contrast, the global feedback signal is given by the global averaged parameter (4), which aims at enabling cooperative training via FL.
We consider a handwritten digit classification task based on the MNIST-DVS dataset [mnist-dvs]. The latter was obtained by displaying slowly moving handwritten digits from the MNIST dataset on an LCD monitor and recording the output of a DVS (Dynamic Vision Sensor) camera [mnist, dvs_camera]. The camera uses send-on-delta encoding, whereby, for each pixel, a positive ($+1$) or negative ($-1$) event is recorded if the pixel's luminosity respectively increases or decreases by more than a given amount, and no event is recorded otherwise. Following [6933869, DBLP:journals/corr/HendersonGW15], each recording is cropped spatially to the region that captures the active part of the image, and temporally to a fixed duration. Uniform downsampling over time is then carried out to obtain $T$ samples per example. The training and test datasets contain a fixed number of examples per class.
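The send-on-delta encoding and the temporal downsampling described above can be sketched for a single pixel as follows. This is a hypothetical illustration, not the MNIST-DVS acquisition pipeline: a physical DVS reacts asynchronously to log-intensity changes, whereas here the luminosity is a sampled list, the reference level is reset at each event, and a bin is assumed to contain a spike of a given polarity whenever at least one event of that polarity falls in it.

```python
def send_on_delta(luminosity, threshold):
    """Events of one pixel: +1 (or -1) when the luminosity rises (or falls) by more
    than `threshold` with respect to the level at the last event, 0 otherwise."""
    ref = luminosity[0]
    events = [0]
    for x in luminosity[1:]:
        if x - ref > threshold:
            events.append(1)
            ref = x
        elif ref - x > threshold:
            events.append(-1)
            ref = x
        else:
            events.append(0)
    return events


def downsample(events, T):
    """Split the event stream uniformly into T bins and return one positive and one
    negative binary spike train (assumed rule: a spike if the polarity occurs in the bin)."""
    n = len(events)
    pos, neg = [], []
    for b in range(T):
        chunk = events[b * n // T:(b + 1) * n // T]
        pos.append(int(1 in chunk))
        neg.append(int(-1 in chunk))
    return pos, neg
```

For example, the luminosity trace `[0.0, 0.5, 1.2, 1.3, 0.1]` with threshold 0.5 yields the events `[0, 0, 1, 0, -1]`: the first change of exactly 0.5 does not exceed the threshold, the rise to 1.2 does, and the later drop to 0.1 triggers a negative event.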
We consider $K = 2$ devices, which have access to disjoint subsets of the training dataset. In order to validate the advantages of FL, we assume that the first device has only samples from one digit class and the second only from a different class. As seen in Fig. 2, each device is equipped with an SNN with directed synaptic links from all neurons in the input layer to all other neurons, namely the hidden and output neurons. Hidden and output neurons are fully connected. The remaining network parameters are chosen as follows: $K_a$ synaptic basis functions, selected as raised cosine functions with a synaptic duration of $\tau_a$ time steps [pillow]; and fixed values of the training parameters, including the learning rate $\alpha$ and the smoothing constant $\kappa$. We train over $D$ randomly selected examples from the local data sets, which results in $T' = DT$ local time steps.
As a baseline, we consider the test loss at convergence for the separate training of the two SNNs. In Fig. 3, we plot the local test loss, normalized by this baseline, as a function of the global algorithmic time $i$ for different values of the communication period $\tau$. A larger communication period is seen to impair the learning capabilities of the SNNs, yielding a larger final value of the loss. In fact, for the largest values of $\tau$, after a number of local iterations without communication, the individual devices are no longer able to make use of their data to improve performance.
As noted in recent literature [federated_learning, sparse_ternary, signsgd, Aji_2017], one of the major flaws of FL is the communication load incurred by the need to regularly transmit large model parameters. To partially explore this aspect, we now consider exchanging only a subset of synaptic weights during global iterations. Define the communication rate $r$ as the number of synaptic weights exchanged per global iteration, i.e., $r = N_w / \tau$, where $N_w$ is the number of weights transmitted by each device at every communication round. We assume that, for a given rate $r$ and period $\tau$, each device communicates only the $N_w = r\tau$ weights with the largest gradients (10). The BS averages the weights sent by both devices, sets each weight transmitted by only one device to the transmitted value, and sets unsent weights to zero. We note that the overhead required to communicate the positions of the transmitted weights is small with respect to the overall communication load. In Fig. 4, we plot the final test accuracy as a function of the communication period $\tau$ for fixed rates $r$. We observe that it is generally preferable to transmit a subset of weights more frequently in order to enhance cooperative training.
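The sparse exchange and the BS combination rule just described can be sketched as follows. Each device would send $N_w = r\tau$ weights per communication round; the sketch assumes that "largest gradients" means the largest magnitude of the per-weight gradient estimate (e.g., the eligibility trace in (10)), and the function names and dictionary message format are hypothetical.

```python
def top_k_message(weights, grads, k):
    """Indices and values of the k weights whose gradient estimates have the
    largest magnitude; ties are broken in favor of the lower index."""
    # sorted() is stable even with reverse=True, so equal keys keep index order.
    order = sorted(range(len(weights)), key=lambda p: abs(grads[p]), reverse=True)
    return {p: weights[p] for p in order[:k]}


def aggregate(messages, num_weights):
    """BS rule: average the values of a weight sent by several devices, keep the
    value of a weight sent by a single device, and set unsent weights to zero."""
    sums = [0.0] * num_weights
    counts = [0] * num_weights
    for msg in messages:
        for p, w in msg.items():
            sums[p] += w
            counts[p] += 1
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]
```

For example, if device 1 sends weights 1 and 3 and device 2 sends weights 1 and 2, the BS averages the two values of weight 1, copies the single values of weights 2 and 3, and zeroes weight 0. Sending indices alongside values is what creates the position overhead mentioned above.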
This paper introduced a novel protocol for communication and cooperative training of on-device Spiking Neural Networks (SNNs). We demonstrated significant advantages over separate training, and highlighted new opportunities to trade communication load and accuracy by exchanging subsets of synaptic weights.