Modern deep neural networks have become highly capable in select applications of artificial intelligence. However, despite their effectiveness, their energy consumption is a limiting factor in many always-on scenarios, such as wearable intelligence devices, surveillance cameras and smartwatch applications [Roy2019-rv]. Standard efficiency approaches reduce the bit-precision of weights [courbariaux2015binaryconnect, gong2019differentiable] and activations [rastegari2016xnor, darabi2018bnn+], or scale and prune models [tan2019efficientnet, yang2017designing]. These methods, however, still adhere to the standard model of computation in artificial neural networks (ANNs), where activations are exchanged between neurons in a synchronous frame-based manner at every iterative processing step.
Taking inspiration from the extremely efficient brain, spiking neural networks (SNNs) [maass1997networks] combine binary valued activations (spikes) with asynchronous and sparse communication. Such SNNs are arguably also more hardware friendly [davies2018loihi] and energy-efficient [panda2019towards]. However, compared to ANNs, the development of SNNs is in its early phase. Training deep SNNs has remained a challenge as the spiking neurons’ activation function is typically non-differentiable and SNNs are thus not amenable to standard methods of error-backpropagation [bohte2011error, neftci2019surrogate]. In particular, as spiking neurons can be modeled as a class of self-recurrent neurons, learning algorithms have to account for the past, and the resultant statefulness of SNNs makes it hard to deal with simulating and training very large networks. Finally, many deep learning benchmarks are geared towards the synchronized and iterative processing paradigm of ANNs, exemplified by image classification tasks.
Recent work [neftci2019surrogate, bohte2000spikeprop] has demonstrated how the problem of a discontinuous gradient in spiking neural networks can be overcome effectively in a generic fashion through the use of surrogate gradients. This opens up new opportunities to leverage mature deep learning techniques in larger and more complex SNNs.
Here, we develop compact recurrent networks of spiking neurons (SRNNs), which we train using such surrogate gradients to directly apply back-propagation-through-time (BPTT) with auto-differentiation in a well-developed modern deep learning framework (PyTorch). Using this framework, we can also easily train the parameters of the spiking neurons themselves, including those of complex spiking neuron models with multiple dynamical timescales. As we show, this approach makes it feasible to adapt such spiking neural networks to the particular temporal dynamics of the task.
We focus on sequential and streaming classification benchmarks with limited input dimensionality, including the well known sequential and permuted-sequential MNIST datasets (S-MNIST, PS-MNIST), the QTDB waveform classification of ECG, and the audio Spiking Heidelberg Digits dataset (SHD), tasks exemplary for various always-on edge computing devices that require low power consumption. We demonstrate how our compact SRNNs can solve these complex tasks, exceeding the state of the art (SoTA) for SNNs, and approaching or even exceeding SoTA for classical ANNs. On these tasks, the SRNNs exhibit sparse to very sparse spiking activity, and we show that this results in a 100x improvement in theoretical energy use over high-performing ANNs.
2 Related work
In standard deep learning, convolutional neural networks (CNNs) are widely used on visual tasks such as image classification and object recognition, while recurrent neural networks (RNNs) are more generally applied to tasks that involve temporal patterns fed into the network as sequential input. In RNNs, recurrency in the network induces memory in the form of internal hidden states which are updated while time-stepped input feeds in. For learning, because of the induced memory, RNNs are typically unrolled in time, for example using Backpropagation-Through-Time (BPTT) [werbos1990backpropagation, mozer1995focused], to account for the relationship between past inputs and current state. BPTT, however, is both computationally expensive and tends to suffer from stability problems when computing the gradients.
Several alternative RNN variants have been developed to ease and improve learning in standard RNNs. The LSTM (Long Short-Term Memory) unit [hochreiter1997long] was designed specifically as an RNN for sequential machine learning tasks like speech recognition, language understanding and sequence to sequence translation [graves2013generating]. More recent innovations, like the IndRNN [li2019deep], borrow from the success of residual connections in CNNs to facilitate the gradient flow in the network and achieve state-of-the-art RNN performance. Alternatively, causal convolutional neural networks [oord2016wavenet] have also been applied successfully to sequential tasks [benidis2020neural] but have substantial network-size and data-history memory requirements.
Spiking neural networks [maass2001pulsed] comprise a class of event-based neural networks inspired by more detailed models of biological neurons. Biological neurons differ from the standard neurons in ANNs in that they have internal state and communicate via isomorphic electrical pulses, or spikes. The low average firing rate in the brain (1-5 Hz) [attwell2001energy] suggests that much effective and efficient computing can be done with stateful event-driven neurons that only sparingly exchange binary values [panda2019towards].
For SNNs, learning rules for both feedforward and recurrent spiking neural networks have been developed [bohte2011error, bohte2000spikeprop, bellec2018long, zenke2018superspike, Shrestha2018, huh2018gradient, ponulak2008analysis], applying different types of spike-coding paradigms and learning methodologies. Recent work has achieved high performance in tasks like image classification [zambrano2016fast, tavanaei2019deep]; still, it is unclear whether such SNNs are more efficient compared to conventional CNNs.
One direction where potentially a clear advantage for SNNs can be obtained is tasks that fit their inherently temporal mode of computation and that can be computed with relatively compact networks fitting low-power neuromorphic hardware. Recent work has shown substantial progress in LIDAR [Wang2020-uq] and speech recognition [bellec2018long, cramer2019heidelberg]. Still, in these tasks, a significant performance gap exists between SNNs and current deep learning solutions.
3 Spiking Recurrent Neural Networks
Here, we focus on SNNs that comprise one or more recurrent layers, Spiking Recurrent Neural Networks (SRNNs), illustrated in Fig. 0(a). Within these networks, we use one of two types of spiking neurons: Leaky-Integrate-and-Fire (LIF) neurons and adaptive spiking neurons. Spiking neurons are derived from models that capture the behavior of biological neurons [Gerstner2002-wd]. Complex models like the Hodgkin-Huxley model describe the detailed dynamics of biophysical quantities but are costly to compute; phenomenological models like the LIF neuron model or the Spike Response Model trade off levels of biological realism for interpretability and reduced computational cost.
The LIF spiking neuron integrates input current in a leaky fashion and fires an action potential when its membrane potential $u$ crosses a fixed threshold $\theta$ from below, at which time a spike $S_t$ is emitted, a process modeled as a nonlinear function $f_s(u, \theta)$, after which the membrane potential is reset to $u_r$:
$$\tau_m \frac{du}{dt} = -u + R_m I(t)$$
where $I(t) = \sum_i \delta(t - t_i)$ is the input signal expressed as a spike-train with spike times $t_i$, $u$ is the membrane potential decaying exponentially with time-constant $\tau_m$, $R_m$ is the membrane resistance, and the emission of spike $S_t$ is expressed as a nonlinear function of threshold and potential: $S_t = f_s(u_t, \theta)$. The LIF neuron is cheap to compute [izhikevich2003simple], but lacks much of the more complex behavior of real neurons, including responses that exhibit longer history dependencies.
Bellec et al. [bellec2018long] demonstrated that using more complex, adapting spiking neuron models improved performance in their SNNs. In the adaptive spiking neuron, the threshold is increased after each emitted spike and then decays exponentially with time-constant $\tau_{adp}$. Simulating the continuous neuron model in discrete time using the forward-Euler first-order exponential integrator method with time step $dt$, we compute:
$$\alpha = \exp(-dt/\tau_m), \qquad \rho = \exp(-dt/\tau_{adp})$$
$$\eta_t = \rho\,\eta_{t-1} + (1-\rho)\,S_{t-1}$$
$$\theta_t = b_0 + \beta\,\eta_t$$
$$u_t = \alpha\,u_{t-1} + (1-\alpha)\,R_m I_t - S_{t-1}\,\theta_t$$
$$S_t = f_s(u_t, \theta_t)$$
where $\theta_t$ is a dynamical threshold comprised of a fixed minimal threshold $b_0$ and an adaptive contribution $\beta\eta_t$; $\rho$ expresses the single-timestep decay of the threshold with time-constant $\tau_{adp}$. The parameter $\beta$ is a constant that controls the size of adaptation of the threshold (eq. (8)); we set $\beta$ to 1.8 for adaptive neurons as default. Similarly, $\alpha$ expresses the single-timestep decay of the membrane potential with time-constant $\tau_m$.
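As a minimal sketch of the discrete-time adaptive neuron described above, the following plain-Python snippet updates one neuron per timestep: the adaptive threshold contribution decays and is incremented by the previous spike, and the membrane potential leaks, integrates the input, and is reduced by the threshold after a spike. The function names, the time-constants (20 and 200 steps), the minimal threshold of 0.1, the unit membrane resistance and the hard `u > theta` comparison are illustrative assumptions; only the adaptation strength of 1.8 is taken from the text.

```python
import math


def adaptive_lif_step(u, eta, spike_prev, current, tau_m=20.0, tau_adp=200.0,
                      dt=1.0, b0=0.1, beta=1.8, r_m=1.0):
    """One discrete-time update of an adaptive spiking neuron (illustrative)."""
    alpha = math.exp(-dt / tau_m)    # single-timestep decay of the membrane potential
    rho = math.exp(-dt / tau_adp)    # single-timestep decay of the threshold
    # Adaptive contribution: decays, and grows after a spike in the previous step.
    eta = rho * eta + (1.0 - rho) * spike_prev
    # Dynamic threshold: fixed minimal threshold plus scaled adaptation.
    theta = b0 + beta * eta
    # Leaky integration of the input, with a soft reset after a previous spike.
    u = alpha * u + (1.0 - alpha) * r_m * current - spike_prev * theta
    spike = 1.0 if u > theta else 0.0
    return u, eta, spike, theta


def simulate(currents, **params):
    """Run the neuron over a list of input currents, starting from rest."""
    u, eta, spike = 0.0, 0.0, 0.0
    spikes, thresholds = [], []
    for current in currents:
        u, eta, spike, theta = adaptive_lif_step(u, eta, spike, current, **params)
        spikes.append(spike)
        thresholds.append(theta)
    return spikes, thresholds


spikes, thresholds = simulate([1.0] * 200)
print("spike count:", int(sum(spikes)))
print("highest threshold reached:", round(max(thresholds), 4))
```

Setting `beta=0.0` removes the adaptation and leaves a plain LIF neuron with a fixed threshold, which makes the two neuron types easy to compare on the same input.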
Fig. 1 illustrates the behaviour of the two spiking neuron models in terms of evolution of the membrane potential, threshold and spiking behavior. Inspecting these neuron models, we see that the evolution of the membrane potential is determined by the self-decay term $\alpha$ and similarly, for the adapting neuron, the threshold by the self-decay term $\rho$: effectively, the behavior of spiking neurons can be modeled as being self-recurrent with weights $\alpha$ and $\rho$ (Fig. 1). In our network implementation, we set the time-constants behind these self-recurrent parameters, $\tau_m$ and $\tau_{adp}$, as trainable, as they directly relate to the effective duration of “memory” in the neuron, and we hypothesise that optimizing these to the characteristics of the task will increase performance.
To determine the effectiveness of the SNN approach, we can also turn the SNN network into a corresponding RNN network with RELU activation function by communicating to other neurons the membrane potential at every timestep rather than the occasional spike, replacing (8) with:
To train an SRNN, we apply Backpropagation-Through-Time (BPTT) [werbos1990backpropagation, mozer1995focused]. With BPTT, the difference between the output predictions and output targets is propagated back from outputs to inputs, including past inputs, to optimize the weights and parameters by gradient descent. Conceptually, BPTT unrolls the network for all input timesteps.
To compute the gradient through the discontinuous spiking mechanism, we apply the surrogate gradient approach as generalized in [neftci2019surrogate] from earlier specific instances [bohte2011error, bohte2000spikeprop]. The surrogate gradient approach substitutes the non-existent gradient of the discontinuous spike-generator with a smooth derivative that connects the outgoing spike to the internal membrane potential. Multiple derivatives have been proposed [neftci2019surrogate]; here, we use a Gaussian, $\partial S_t / \partial u_t \approx \mathcal{N}(u_t \mid \theta, \sigma^2)$, where the mean of the distribution is the threshold $\theta$ and the standard deviation $\sigma$ (the tolerance) is used to scale the membrane potential for error-backpropagation. Unless stated otherwise, we use the same default value of $\sigma$ throughout.
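A minimal sketch of the Gaussian surrogate described above: the forward pass keeps the hard threshold, while the backward pass uses a Gaussian density centred on the threshold in place of the undefined derivative. The function names and the width `sigma=0.5` are illustrative assumptions, not values taken from the text; in a PyTorch implementation the same pair would typically live in the forward and backward methods of a custom autograd function.

```python
import math


def spike_forward(u, theta):
    """Hard threshold used in the forward pass: emit 1 when u exceeds theta."""
    return 1.0 if u > theta else 0.0


def gaussian_surrogate_grad(u, theta, sigma=0.5):
    """Gaussian density centred on the threshold, used in place of dS/du.

    sigma is an assumed illustrative width.
    """
    z = (u - theta) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


theta = 1.0
for u in (0.0, 0.5, 1.0, 1.5, 2.0):
    print(u, spike_forward(u, theta), round(gaussian_surrogate_grad(u, theta), 4))
```

The surrogate is largest when the membrane potential sits at the threshold and falls off symmetrically, so neurons close to firing (or having just fired) receive most of the error signal.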
To define the loss-function that BPTT minimizes, we need to take into account the kind of task that is being performed and the kind of label that we have: for a sequential classification task, we receive a sequence of inputs and can make a decision at the end of the sequence. In a streaming task, we need to generate an output at every time step. The loss-function is further defined by the error-metric, for which we can define a number of different approaches for interpreting the behavior of the output neurons.
Decoding the output of an SNN directly relates to the interpretation of the spiking neuron’s behavior. Both the membrane potential trace and spike history, either in spike-counts and/or spike-timing, of the output neurons can be used to represent the belief of each class. We define a number of such output decoding methods and their associated loss-function.
Spike-based classification. When a neuron emits a spike, this is caused by the membrane potential (the hidden state of the neuron) reaching threshold. In a classification task, a simple classification method is to count the number of spikes in a certain time-window. While straightforward, this method has some shortcomings that may result in misclassifications: (1) some output neurons may fire the same number of spikes; (2) the reset and refractory mechanism of the spiking neuron may reduce the firing rate for a strong stimulus; and (3) real-time readout from single neurons is not feasible. As an alternative, we can use direct measures of the output neurons available at each time-step.
Direct measures. The membrane potential of an output neuron can also be used for classification, as it represents a measure of output neuron’s stimulus history. We define several methods to decode the results from the membrane potential history:
Last-timestep membrane potential: we take the value of the output membrane potential at the last time step of a sample as the output. Using a softmax activation, we then scale the outputs similar to the softmax function in ANNs.
Max-over-time membrane potential: we take as the value of the output neurons the maximum membrane potential reached during presentation of the sample.
The readout integrator: while the membrane potential can be interpreted as a moving average of a neuron’s activation, the resets caused by spiking do not fit in this notion. We define a non-spiking readout layer where the membrane potential is computed without the neuron spiking and resetting. This avoids the effect of the reset mechanism of the spiking neuron on classification performance. The readout integrator is defined as $\hat{u}_t = \alpha_o \hat{u}_{t-1} + (1-\alpha_o) I_t$, where $\hat{u}_t$ is the output membrane potential and $I_t$ is the input spike train, and $\alpha_o = \exp(-dt/\tau_o)$, where $\tau_o$ is a trainable time constant. We use the average value over time for the non-spiking readout neurons.
Variants of all three approaches were previously used in [cramer2019heidelberg]; for streaming tasks, classifications are needed for every time-step, and when using only single neurons to represent outputs, we can only use the direct measurements.
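The three direct measures listed above can be sketched in a few lines of plain Python. The membrane traces are represented as a list over timesteps of per-class values; this data layout, the helper names and the readout time constant of 20 steps are assumptions made for illustration, and the readout integrator is written as a simple non-spiking leaky average without reset.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of class scores."""
    top = max(values)
    exps = [math.exp(v - top) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def last_timestep(potentials):
    """Output membrane potentials at the final timestep of the sample."""
    return list(potentials[-1])


def max_over_time(potentials):
    """Per-class maximum membrane potential reached during the sample."""
    n_classes = len(potentials[0])
    return [max(step[k] for step in potentials) for k in range(n_classes)]


def readout_integrator(inputs, tau=20.0, dt=1.0):
    """Time-averaged potential of non-spiking leaky readout neurons (no reset)."""
    alpha = math.exp(-dt / tau)
    u = [0.0] * len(inputs[0])
    total = [0.0] * len(inputs[0])
    for x in inputs:
        u = [alpha * ui + (1.0 - alpha) * xi for ui, xi in zip(u, x)]
        total = [t + ui for t, ui in zip(total, u)]
    return [t / len(inputs) for t in total]


def predict(scores):
    """Index of the class with the highest score."""
    return max(range(len(scores)), key=lambda k: scores[k])


trace = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.4]]   # 3 timesteps, 2 output neurons
print(softmax(last_timestep(trace)), predict(last_timestep(trace)))
print(max_over_time(trace), predict(max_over_time(trace)))
print(readout_integrator([[1.0, 0.0]] * 10))
```

For a streaming task, the same readout potentials would be scored at every timestep instead of being reduced over the whole sample.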
To train the network, as in [cramer2019heidelberg], we use the cross-entropy function as the error function. In a streaming task, the readout membrane potentials are used as output and compared to the corresponding targets at each timestep. In the classification tasks, the output after reading the whole sequence is compared with the correct label of the sequence. Note that in [cramer2019heidelberg], cross-entropy is computed either for the max-over-time or for the last-time-step decoding.
We implemented the various SRNNs in PyTorch, where the use of surrogate gradients allowed us to apply BPTT to minimize the loss efficiently, and also to leverage standard deep learning optimizations, including the training of spiking neuron parameters.
We apply the approach outlined above to a number of sequential classification and streaming classification tasks: waveform classification in ECG signals in the QTDB dataset, the sequential and permuted sequential MNIST problem (S-MNIST, PS-MNIST), and the Spiking Heidelberg Digits (SHD) dataset.
Encoding and decoding. SNNs, as event-driven neural networks, rely heavily on an encoding mechanism to convert external measures into spike-trains that feed into the network. Several approaches have been used to convert static or continuously changing values into spike-trains. In DVS-sensors [lichtsteiner2008128], a level-crossing scheme is used to encode a time-continuous signal into spikes; more generically, rate-based Poisson population encoding has been used [OConnor2013-un, ruckauer2019closing]. Empirically, we found that different decoding schemes were best for different tasks: for the ECG task, decoding directly at every timestep from the membrane potential worked best; for S-MNIST and PS-MNIST, it was spike-counting; and for the SHD task, the average readout integrator was most effective.
The analysis of electrocardiograms (ECGs) is widely used for monitoring the functionality of the cardiovascular system. As a kind of time-series data, ECG signals can be used to detect and recognize different waveforms and morphologies in heart-function. For a human, the recognition task is time-consuming and relies heavily on experience.
Waveform classification. In an ECG signal, there are three meaningful parts of the cardiac period: the P-wave, the T-wave and the QRS-complex [hurst1998naming]. In detail, the QRS part consists of a Q-wave, an R-peak and an S-wave. In a monitoring task, we aim to continuously recognize the present type of wave. The streaming task is thus to character-wise recognize all six patterns of the ECG wave: P, PQ, QR, RS, ST, and TP. An example of an ECG signal and the relative distribution of the 6 labels is shown in Fig. 0(b).
|Task||Method||Network||Accuracy||Fr|
|ECG QTDB||Adaptive SRNN||46||84.4%||.32|
|S-MNIST||LSNN (L2L) [bellec2018long]||80(I)+120+100||93.7%|| |
|S-MNIST||LSNN (L2L+DeepR) [bellec2018long]||80(I)+120+100||96.4%|| |
|S-MNIST||Dense IndRNN [li2019deep]|| ||99.48%|| |
|PS-MNIST||Dense IndRNN [li2019deep]|| ||97.2%|| |
|SHD||Adaptive SRNN (4ms)||128||79.4%||.071|
|SHD||Adaptive SRNN (4ms)||256||81.71%||.049|
|SHD||Adaptive SRNN (4ms)||128+128||84.4%||.103|
|SHD||LIF SRNN (4ms)||256||78.93%||.021|
|SHD||LIF RSNN [cramer2019heidelberg]||128*3||71.4%|| |
|SHD||RELU SRNN (4ms)||128+128||88.93%|| |
QTDB is one of the most widely used ECG datasets for wave segmentation, where the data is well labeled in the temporal dimension. Each sample has two channels, ’a’ and ’b’, which provide additional spatial information. The original data consist of float values for each timestep; to convert this signal into an event-based one, we applied level-crossing encoding to the normalized ECG signal, converting each channel into two spike channels that represent increasing and decreasing events respectively. A spike is generated when the amplitude increase or decrease is larger than a threshold; here, we use 0.3 as the threshold. The result is a compression of the ECG signal by about .
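The level-crossing encoding described above can be sketched as follows for a single normalized channel. The threshold of 0.3 comes from the text; the function name and the choice to move the reference level to the current sample whenever a spike is emitted are assumptions for illustration, since the exact update rule is not spelled out here.

```python
def level_crossing_encode(signal, threshold=0.3):
    """Convert one normalized channel into 'up' and 'down' spike channels."""
    up, down = [], []
    reference = signal[0]   # assumed: start from the first sample
    for value in signal:
        if value - reference > threshold:
            up.append(1)
            down.append(0)
            reference = value   # assumed: reference follows the signal after a spike
        elif reference - value > threshold:
            up.append(0)
            down.append(1)
            reference = value
        else:
            up.append(0)
            down.append(0)
    return up, down


up, down = level_crossing_encode([0.0, 0.1, 0.5, 0.6, 0.1, 0.1])
print(up)    # [0, 0, 1, 0, 0, 0]
print(down)  # [0, 0, 0, 0, 1, 0]
```

Applying this to both ECG channels yields four binary spike channels, and slowly varying stretches of the signal produce no events at all.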
We apply several RNNs and SRNNs for ECG waveform classification, see Table 1. We find that the SRNN with adaptive spiking neurons achieved the best performance of 84.4% with the smallest network of 46 neurons (4 input, 36 hidden and 6 output neurons, one per waveform label). A same-sized SRNN comprised of LIF neurons achieved only . The LSTM and vanilla RNN with the same network structure achieved , and accuracy respectively; a bidirectional LSTM with 290 units achieved . The best performance overall was obtained by turning the adaptive SRNN into an ANN, the RELU SRNN.
The accuracy results are obtained by evaluating the output for the input fed in at every time step, with the signal sampled at 250 Hz. No delay between input and output evaluation has been taken into account. In contrast to the spike input generated by the level-crossing encoding that the SRNN receives, the LSTM and RNN networks receive floating point values at their inputs. These values represent the ADC sample values (12 bits precision).
S-MNIST and PS-MNIST
The MNIST dataset is a seminal computer vision classification task. The Sequential MNIST (S-MNIST) and Permuted Sequential MNIST (PS-MNIST) benchmarks were introduced as corresponding problems for sequential data processing [le2015simple]: each MNIST image is presented pixel by pixel, resulting in a sequence of length 784 (28 x 28 pixels).
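To make the two benchmarks concrete, the sketch below flattens a 28 x 28 image into a pixel sequence (S-MNIST) and reorders it with one fixed permutation shared by all images (PS-MNIST). The row-by-row reading order, the helper names and the seed are assumptions for illustration.

```python
import random


def to_sequence(image):
    """Flatten a 2-D image (list of rows) into a pixel-by-pixel sequence."""
    return [pixel for row in image for pixel in row]


def make_permutation(length, seed=0):
    """One fixed random order, to be reused for every image (seed is assumed)."""
    order = list(range(length))
    random.Random(seed).shuffle(order)
    return order


def permute(sequence, order):
    """Reorder a sequence according to the fixed permutation."""
    return [sequence[i] for i in order]


# Stand-in image whose pixel values encode their own position.
image = [[r * 28 + c for c in range(28)] for r in range(28)]
sequence = to_sequence(image)
order = make_permutation(len(sequence))
permuted = permute(sequence, order)
print(len(sequence), sequence[:5], permuted[:5])
```

Because the same permutation is applied to every sample, the task remains learnable, but neighbouring pixels in the image are no longer neighbours in time.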
For S-MNIST, the state of the art accuracy of 99.48% is obtained with the Dense IndRNN [li2019deep]; the best reported performance of an LSTM is [arjovsky2016unitary]. For these RNNs, the analogue grey pixel value is fed directly as input into the network. With SNNs, Bellec et al. [bellec2018long] obtained on the S-MNIST task using eProp and population Poisson encoding to encode the grey pixel values, and when additionally using a learning-2-learn meta-learning framework. As can be seen in Table 1, our adaptive SRNNs using two recurrent layers obtained with the same encoding scheme; with the same network layout and size as [bellec2018long] the SRNN achieved . When using LIF neurons in various SRNN architectures, the networks failed to learn; turning the adaptive SRNN into an ANN, however, again increased performance to , approaching the Dense IndRNN accuracy.
PS-MNIST is a harder problem than S-MNIST, as a fixed permutation is first applied to all images before reading each image sequentially, pixel by pixel [le2015simple]. The permutation strongly distorts the temporal pattern in the input sequence. The Dense IndRNN [li2019deep] here obtained 97.2% accuracy; the LSTM [arjovsky2016unitary] achieved only . We are not aware of any SNN benchmark results on this task; our adaptive SRNN achieved accuracy on the test dataset; the LIF SRNN again failed to learn while the RELU SRNN obtained .
The Spiking Heidelberg Digits (SHD) dataset was developed specifically for benchmarking spiking neural networks [cramer2019heidelberg]. It was created based on the Heidelberg Digits (HD) audio dataset, which comprises 20 classes of spoken digits from 0 to 9 in English and German, spoken by 12 individuals. For training and evaluation, the dataset (10420 samples) is split into a training set (8156 samples) and test set (2264 samples). An LSTM with 128 units achieved [cramer2019heidelberg], where the continuous time stream is binned into 10ms segments and the spike-count in each bin was used as input for the LSTM. For comparison, treating each sample as an image to train a deep CNN with over 1 million neurons achieves [cramer2019heidelberg]. Using a three layer spiking recurrent network comprised of LIF neurons with 128 neurons in each layer, [cramer2019heidelberg] obtained 71.4%.
To apply our SRNNs, we converted all audio samples into 250-by-700 binary matrices. For this, all samples were fit within a 1 second window; shorter samples were padded with zeros and longer samples were cut by removing the tail (the latter applied to only 20 samples, with the longest sample being 1.17s; visual inspection showed no significant data in the tail, and an example is shown in Fig 1(a)). Spikes were then binned in time bins of both 10ms and 4ms; for the SRNNs, the presence or absence of any spikes in the time-bin is noted as a single binary event; for the LSTMs, the spike-count in a bin is used as the (binary) input value. During training, a subset of the original training dataset was used for validation, with Adam [kingma2014adam] as the default optimizer. The initial learning rate is set to 0.01 with a decay at epochs 10, 50, 120 and 200. For the non-spiking RELU SRNN, 50 training epochs sufficed. We trained SRNNs both with a single recurrent layer and with two recurrent layers, with LIF or adaptive spiking neurons. The membrane potential of each SRNN neuron was set to a random number between 0 and 1 at the start of each sample.
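The preprocessing described above can be sketched as a small binning routine that turns a list of spike events into a bins-by-channels matrix. The 1 second window, the 4ms bins and the 700 channels follow the text; representing a sample as `(time_ms, channel)` pairs and the function name are assumptions for illustration. Zero-padding of short samples is implicit, since bins without events stay at zero, and events past the window are dropped.

```python
def bin_spikes(events, n_channels=700, bin_ms=4.0, window_ms=1000.0, binary=True):
    """Bin (time_ms, channel) events into an n_bins x n_channels matrix."""
    n_bins = int(round(window_ms / bin_ms))
    matrix = [[0] * n_channels for _ in range(n_bins)]
    for time_ms, channel in events:
        if 0.0 <= time_ms < window_ms:       # cut the tail beyond the window
            b = int(time_ms // bin_ms)
            if binary:
                matrix[b][channel] = 1       # presence of any spike in the bin
            else:
                matrix[b][channel] += 1      # spike count in the bin
    return matrix


# Hypothetical events: two spikes in the same bin, one later, one past the window.
events = [(0.5, 3), (1.2, 3), (5.0, 10), (1100.0, 7)]
binary_matrix = bin_spikes(events)
count_matrix = bin_spikes(events, binary=False)
print(len(binary_matrix), len(binary_matrix[0]))   # 250 700
print(binary_matrix[0][3], count_matrix[0][3])     # 1 2
```

With `bin_ms=10.0` the same routine yields the coarser 100-bin representation used for the 10ms experiments.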
For an adaptive SRNN with two layers of 128 adaptive spiking neurons, trained on the 4ms binned data, we obtained 84.4% accuracy, approaching the 128 unit LSTM in [cramer2019heidelberg]. An adaptive SRNN with a single layer of 256 spiking neurons achieved 81.71%, demonstrating the utility of having multiple recurrent layers in an SRNN. Similarly, a LIF SRNN with a single recurrent layer of 256 neurons achieved only 78.93%. The non-spiking RELU SRNN substantially outperforms the spiking SRNNs, obtaining an accuracy of 88.93%.
In Table 1, we also note the sparsity (Fr) of the trained SRNNs, reported as the average fraction of neurons that are active at each timestep, so that a lower Fr means sparser activity. We find that for the ECG task, neurons fire on average once every 3 timesteps (Fr=0.32), where this relatively low sparsity is likely caused by the need to read out class labels at every timestep. For the MNIST tasks, activity is much sparser, with Fr varying between 0.07 and 0.1, and for the SHD task, Fr varies between 0.02 and 0.13, mostly as a function of accuracy.
For SHD, we investigated the relationship between network performance and sparsity for different size networks in more detail. In Fig 1(b), we see that the network performance can be increased and sparsity improved at the same time by increasing the size of the network. We also see that the performance advantage of adaptive neurons compared to LIF neuron comes at the expense of sparsity.
Complexity of Spiking Neurons
In general, we find that adaptive spiking neurons with time-constants adjusted during training substantially outperform LIF neurons, as illustrated in Fig 3(a). As can be seen in Fig 3(b), training these time-constants substantially improved performance, illustrated for the SHD dataset, as successive ablation of training these parameters reduces performance. As illustrated in Fig 1(c), training also substantially modifies these parameters: shown are both the initial histogram of the time-constants in the SHD task and the histogram after training.
A hotly debated topic is whether or not SNNs can achieve a meaningful power-reduction compared to ANNs [panda2019towards]. Here we derive theoretical energy values based on power numbers at the register transfer logic (RTL) level for 45nm CMOS technology from [panda2019towards].
We calculate the theoretical energy consumption of a recurrent network by counting the required operations per timestep. We count both multiply-and-accumulate (MAC) operations and accumulate (AC) operations [panda2019towards]. A standard artificial neuron requires a MAC for each input; in contrast, a spiking neuron only requires an AC for each input spike, while its internal state dynamics require some MACs.
In a network, we thus need to consider the number of fan-in connections into a neuron, the number of neurons in a layer, and the cost of internal calculations. For example, consider a recurrent operation at a layer with a given input size and output size: this requires two multiply operations and one accumulate operation. The energy required for the RNN is then computed from these operation counts for every timestep. In the SNN, however, the sparse spiking activity of the network (the average firing rate Fr) needs to be considered, which is below 1 in SRNNs with sparse activity.
We compute the theoretical energy cost of a recurrent network as the sum over all layers and all time steps; the MACs/ACs and energy use for various recurrent networks are listed in Table 2. For the network architectures used in this study, we then calculate the actual relative energy cost in Table 3, and we plot the accuracy versus energy ratio for the various networks in Fig 4. In Fig 4, we see that our SRNN solutions lie on the Pareto front of energy efficient and effective networks, with the spiking adaptive SRNN achieving close to the RELU SRNN performance while theoretically being 28x to 243x more energy efficient. Compared to more classical RNNs on the more complex SHD and (P)S-MNIST tasks, we calculate the SRNNs to be 100x more energy efficient, as in these larger networks both the fan-in factor and the sparsity increase.
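The operation-counting argument can be illustrated with a rough per-timestep estimate for one recurrent layer. This is a simplified sketch, not the accounting behind Tables 2 and 3: the per-operation energies are placeholder values, biases are ignored, and the number of internal MACs per spiking neuron is an assumed parameter. A dense layer pays one MAC per feedforward and per recurrent connection, whereas the spiking layer pays one AC per connection only when the presynaptic neuron spikes, plus a few MACs per neuron for its state update.

```python
E_MAC = 3.2e-12   # assumed energy per multiply-and-accumulate, in joules
E_AC = 0.1e-12    # assumed energy per accumulate, in joules


def rnn_layer_energy(n_in, n_out, e_mac=E_MAC):
    """Dense recurrent layer: one MAC per input and per recurrent connection."""
    return (n_in * n_out + n_out * n_out) * e_mac


def srnn_layer_energy(n_in, n_out, fr_in, fr, internal_macs_per_neuron=2,
                      e_mac=E_MAC, e_ac=E_AC):
    """Spiking recurrent layer: ACs scaled by firing rates, plus state-update MACs."""
    synaptic = (n_in * n_out * fr_in + n_out * n_out * fr) * e_ac
    internal = internal_macs_per_neuron * n_out * e_mac
    return synaptic + internal


# Hypothetical layer: 700 inputs, 128 recurrent neurons, sparse activity.
dense = rnn_layer_energy(700, 128)
spiking = srnn_layer_energy(700, 128, fr_in=0.02, fr=0.07)
print("dense   J/step:", dense)
print("spiking J/step:", spiking)
print("ratio:", round(dense / spiking, 1))
```

In this sketch the advantage grows with fan-in and with sparser activity, while the fixed per-neuron state-update cost sets a floor on the spiking layer's energy.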
|Task||Method||Accuracy||Energy/s||Energy ratio||Error ratio||Efficiency|
|SHD 256||Adapt. SRNN||81.71%||3,249||1x||1x||1x|
We demonstrated how competitive recurrent spiking neural networks can be trained using backpropagation-through-time (BPTT) and surrogate gradients on classical and novel time-continuous tasks, achieving a new state of the art for spiking neural networks and approaching or exceeding the state of the art for classical RNNs on these tasks. Calculating the theoretical energy cost, we find that our spiking SRNNs are up to 243x more efficient than the slightly better performing analog RELU SRNN, and up to 1900x more efficient than similarly performing classical RNNs like LSTMs.
We showed that using more complex adaptive spiking neurons was key to achieving these results, in particular by also training the individual time-constants of these spiking neurons, also using BPTT. Having two time-constants, the adaptive spiking neuron effectively maintains a multiple-timescale memory. We hypothesise that this approach is so effective because it allows the memory in the network to be adapted to the temporal dynamics of the task. Surprisingly, converting the SRNN to a non-spiking RELU RNN consistently increased performance, suggesting that the nested hierarchical recurrent network architecture is particularly effective.
Training these complex SRNNs including the various parameters was only feasible because, by using surrogate gradients, we were able to use a mature and advanced deep learning framework (PyTorch) and benefit from the automated differentiation to also train spiking neuron parameters (code is available at https://github.com/byin-cwi/SRNN-ICONs2020). We believe this approach opens up new opportunities for improving and scaling SNNs further.