Winner-Take-All as Basic Probabilistic Inference Unit of Neuronal Circuits

08/02/2018 ∙ by Zhaofei Yu, et al. ∙ University of Leicester, Peking University

Experimental observations in neuroscience suggest that the brain works in a probabilistic way when computing information under uncertainty. This processing can be modeled as Bayesian inference. However, it remains unclear how Bayesian inference could be implemented at the level of neuronal circuits in the brain. In this study, we propose a novel general-purpose neural implementation of probabilistic inference based on a ubiquitous network motif of cortical microcircuits, the winner-take-all (WTA) circuit. We show that each WTA circuit can encode the distribution over the states of a variable. By connecting multiple WTA circuits together, the joint distribution of an arbitrary probabilistic graphical model can be represented. Moreover, we prove that the neural dynamics of WTA circuits can implement one of the most powerful inference methods for probabilistic graphical models, mean-field inference. We show that the synaptic drive of each spiking neuron in the WTA circuit encodes the marginal probability of the variable being in the corresponding state, and that the firing probability (or firing rate) of each neuron is proportional to this marginal probability. Theoretical analysis and simulation results demonstrate that WTA circuits achieve inference results comparable to those of mean-field approximation. Taken together, our results suggest that the WTA circuit could be seen as the minimal inference unit of neuronal circuits.

1 Introduction

Humans are able to process information in the face of uncertainty in the sensory, motor and cognitive domains [1]. Confidence in a decision-making task does not come from a one-time judgment of all of the uncertainties. Instead, we estimate the uncertainty and infer the solution by accumulating knowledge according to the rules of probabilistic inference. Such processes can be understood as Bayesian inference. There is a growing body of behavioral and physiological evidence that humans and monkeys (and other animals) can represent probabilities and implement probabilistic computations in a Bayesian fashion, using neuronal and circuit mechanisms in the brain [2, 3, 4]. Despite this, it remains unclear how probabilistic inference is implemented by our neuronal system. From a theoretical perspective, the question is how such probabilistic inference can be implemented with a network of spiking neurons.

A few models have been proposed to relate the dynamics of spiking neural networks to the inference equations of the belief propagation (BP) algorithm, an inference algorithm commonly used in probabilistic graphical models that is exact on acyclic graphical models and approximate on cyclic ones [5, 6, 7, 8, 9, 10]. All these studies try to show that the dynamics of spiking neural networks can implement the basic computations of the BP algorithm, so that the time course of the network dynamics can be understood as the inference process. A recent approach implements the tree-based reparameterization algorithm [11], of which BP is the special case that only considers the reparameterization over two neighboring nodes and their corresponding edge.

In summary, most of the previous studies require each neuron and synapse to conduct complicated computations, which makes them hard to generalize into a general framework. In the brain, however, a single neuron or a group of neurons should work in a relatively simple way, while complex functions emerge when they are combined. This could be achieved if there were a basic computational motif in neuronal circuits whose combinations yield complex functions. It is therefore worth asking what the basic inference motif of our neuronal system could be and, if such a motif exists, whether compositions of it can implement inference for any Bayesian model with multiple layers and arbitrary scale.

Among the possible motifs, one has been studied intensively from a theoretical viewpoint over the last decades: the winner-take-all (WTA) circuit, a microcircuit consisting of an ensemble of excitatory cells with lateral inhibition, as suggested by experimental observations [12, 13]. Within this motif, the inhibition-induced competition between excitatory cells makes the WTA circuit suitable for implementing many types of neuronal computation, such as feature selection, attention and decision making [13, 14, 15].

However, the functional role of large-scale neuronal circuits containing many WTA circuits remains unknown. In particular, from a computational perspective, it is unclear whether probabilistic inference can emerge from a combination of WTA circuits.

In this paper, we show that each WTA circuit can encode the state of a variable in a probabilistic graphical model (PGM), and that a combination of WTA circuits can represent the joint distribution defined on any PGM, with synaptic weights and input currents encoding the potential functions. Moreover, we prove that the neural dynamics of a network consisting of multiple WTA circuits is exactly equivalent to the mean-field inference algorithm of probabilistic graphical models. We show that the synaptic drive of each spiking neuron in a WTA circuit encodes the marginal probability of the variable being in the corresponding state, and that the firing probability (or firing rate) of each neuron is proportional to that marginal probability. Our results suggest that the WTA circuit can be seen as the minimal inference motif of the neuronal system.

2 Inference by Mean-Field Approximation

In this section, we first briefly review probabilistic graphical models and the variational inference algorithm with mean-field approximation, and then derive a differential equation that is equivalent to the mean-field inference algorithm.

2.1 Probabilistic Graphical Model

Probabilistic graphical models (PGMs) provide a powerful formalism for multivariate statistical modeling by combining graph theory and probability theory [16]. They have been widely used in computer vision, signal processing, and computational neuroscience, where PGMs are used to model the inference process of the human brain. In this study, we focus only on inference in undirected probabilistic graphical models, also known as Markov random fields (MRFs). The results generalize to directed probabilistic graphical models, because a directed model can be converted into an undirected one by moralization [16, 17]. As shown in Fig. 1, the joint distribution over the variables $\mathbf{x} = (x_1, \dots, x_n)$ defined on an MRF factorizes into a product of potential functions according to the graph structure, that is,

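$$p(\mathbf{x}) = \frac{1}{Z}\prod_{(i,j)\in E}\psi_{ij}(x_i,x_j)\prod_{i\in V}\psi_i(x_i) \qquad (1)$$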

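where $E$ is the set of edges and $V$ is the set of nodes, and $\psi_{ij}(x_i,x_j)$ and $\psi_i(x_i)$ are the potential functions of each edge and node, respectively. $Z$ is the normalization constant, $Z = \sum_{\mathbf{x}}\prod_{(i,j)\in E}\psi_{ij}(x_i,x_j)\prod_{i\in V}\psi_i(x_i)$. Defining $\theta_{ij}(x_i,x_j) = \log\psi_{ij}(x_i,x_j)$ and $\theta_i(x_i) = \log\psi_i(x_i)$, equation (1) can be rewritten as: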

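$$p(\mathbf{x}) = \frac{1}{Z}\exp\Big(\sum_{(i,j)\in E}\theta_{ij}(x_i,x_j) + \sum_{i\in V}\theta_i(x_i)\Big) \qquad (2)$$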
Figure 1: Illustration of an MRF and a WTA circuit. A variable with $m$ states in the MRF is represented by the $m$ output excitatory neurons of a WTA circuit. The competition mechanism of the WTA circuit is implemented by the excitatory neurons (blue) and one inhibitory neuron (red). Each excitatory neuron receives an input current from afferent neurons, which encodes the potential function defined on the corresponding node of the MRF.
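To make equations (1) and (2) concrete, the following minimal sketch in Python with NumPy builds the joint distribution of a tiny MRF by brute force, once from the potentials $\psi$ and once from the log-potentials $\theta = \log\psi$, and checks that the two agree. The 3-node chain, the 5 states per node, and the uniform interval [0.1, 1.0] for the potentials are hypothetical choices for illustration; enumeration is only feasible for very small models.

```python
import itertools
import numpy as np

def joint_eq1(n_states, edges, psi_node, psi_edge):
    # Equation (1): product of node and edge potentials, normalized by Z.
    p = np.zeros(n_states)
    for x in itertools.product(*(range(m) for m in n_states)):
        val = np.prod([psi_node[i][x[i]] for i in range(len(n_states))])
        for (i, j) in edges:
            val *= psi_edge[(i, j)][x[i], x[j]]
        p[x] = val
    return p / p.sum()

def joint_eq2(n_states, edges, theta_node, theta_edge):
    # Equation (2): exponential of the summed log-potentials, normalized by Z.
    log_p = np.zeros(n_states)
    for x in itertools.product(*(range(m) for m in n_states)):
        s = sum(theta_node[i][x[i]] for i in range(len(n_states)))
        s += sum(theta_edge[(i, j)][x[i], x[j]] for (i, j) in edges)
        log_p[x] = s
    p = np.exp(log_p - log_p.max())  # subtract the max for numerical stability
    return p / p.sum()

rng = np.random.default_rng(0)
n_states = [5, 5, 5]                 # hypothetical 3-node chain, 5 states per node
edges = [(0, 1), (1, 2)]
psi_node = [rng.uniform(0.1, 1.0, size=m) for m in n_states]
psi_edge = {e: rng.uniform(0.1, 1.0, size=(5, 5)) for e in edges}
theta_node = [np.log(v) for v in psi_node]
theta_edge = {e: np.log(v) for e, v in psi_edge.items()}

p1 = joint_eq1(n_states, edges, psi_node, psi_edge)
p2 = joint_eq2(n_states, edges, theta_node, theta_edge)
print(np.allclose(p1, p2))  # True
```

Working with $\theta$ instead of $\psi$ turns products into sums, which is the form that the synaptic weights and input currents will later encode.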

2.2 Variational Inference with Mean-Field Approximation

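When modeling the inference process of the human brain with graphical models, the inference problem has two parts: (1) computing the marginal distribution of each variable $x_i$, that is, $p(x_i) = \sum_{\mathbf{x}\setminus x_i} p(\mathbf{x})$, where $\mathbf{x}\setminus x_i$ denotes all variables in $\mathbf{x}$ except $x_i$; and (2) computing the posterior probability $p(\mathbf{x}_u \mid \mathbf{x}_o)$ of the unobserved variables $\mathbf{x}_u$ given the observed variables $\mathbf{x}_o$. These two problems are naturally coupled, since $p(\mathbf{x}_u \mid \mathbf{x}_o) \propto p(\mathbf{x}_u, \mathbf{x}_o)$: clamping the observed variables yields another MRF over $\mathbf{x}_u$ whose marginals are the posterior marginals. We therefore only need to consider marginal inference. As exact inference in MRFs is NP-hard, efficient variational approximate inference algorithms are often used instead; their idea is to convert the inference problem into the optimization problem $q^* = \arg\min_{q\in Q}\mathrm{KL}(q\,\|\,p)$. Here the original distribution $p$ is approximated by a distribution $q$ from a family $Q$ of tractable distributions, and $\mathrm{KL}(\cdot\,\|\,\cdot)$ denotes the Kullback-Leibler divergence. The mean-field approximation is obtained by setting $q$ to be a fully factorized distribution, that is, $q(\mathbf{x}) = \prod_{i\in V} b_i(x_i)$. By imposing the constraint $\sum_{x_i} b_i(x_i) = 1$ and differentiating $\mathrm{KL}(q\,\|\,p)$ with respect to $b_i(x_i)$, one obtains the following mean-field inference equations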

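$$b_i^{(t+1)}(x_i) = \frac{\exp\Big(\theta_i(x_i) + \sum_{j\in N(i)}\sum_{x_j}\theta_{ij}(x_i,x_j)\,b_j^{(t)}(x_j)\Big)}{\sum_{x_i'}\exp\Big(\theta_i(x_i') + \sum_{j\in N(i)}\sum_{x_j}\theta_{ij}(x_i',x_j)\,b_j^{(t)}(x_j)\Big)} \qquad (3)$$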

where $t$ denotes the iteration number, $N(i)$ is the set of neighbors of node $i$ in the MRF, and $b_i^{(t)}(x_i)$ represents the information received by node $i$ (the approximate marginal probability of variable $x_i$) in the $t$th iteration. When all messages converge to a fixed point, the marginal probability $p(x_i)$ can be approximated by the steady state $b_i^{(\infty)}(x_i)$. The following differential equation has the same fixed point as equation (3), since setting its time derivative to zero recovers equation (3):

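$$\tau\frac{d b_{i,k}}{dt} = -b_{i,k} + \frac{\exp\Big(\theta_{i,k} + \sum_{j\in N(i)}\sum_{l=1}^{m_j}\theta_{ij,kl}\,b_{j,l}\Big)}{\sum_{k'=1}^{m_i}\exp\Big(\theta_{i,k'} + \sum_{j\in N(i)}\sum_{l=1}^{m_j}\theta_{ij,k'l}\,b_{j,l}\Big)} \qquad (4)$$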

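where $\tau$ is a time constant, and $m_i$ and $m_j$ denote the numbers of possible states of variables $x_i$ and $x_j$, respectively. For notational convenience we have written $b_i(x_i = k)$ as $b_{i,k}$, $\theta_{ij}(x_i = k, x_j = l)$ as $\theta_{ij,kl}$, and $\theta_i(x_i = k)$ as $\theta_{i,k}$.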
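The claim that equation (4) shares its fixed point with equation (3) can be checked numerically. The sketch below, assuming a small hypothetical chain MRF whose potentials are drawn uniformly from [0.5, 1.0] (weak couplings, so the parallel iteration converges), runs the iteration (3) and a forward-Euler integration of (4) and compares the results. Function names such as `mean_field_iterate` and `mean_field_ode` are illustrative, not taken from the paper.

```python
import numpy as np

def softmax(a):
    e = np.exp(a - a.max())
    return e / e.sum()

def random_chain_mrf(n_nodes=3, m=5, seed=0):
    # Hypothetical test MRF: a chain whose potentials are drawn uniformly from [0.5, 1.0].
    rng = np.random.default_rng(seed)
    theta_node = [np.log(rng.uniform(0.5, 1.0, m)) for _ in range(n_nodes)]
    theta_edge = {}
    nbrs = {i: [] for i in range(n_nodes)}
    for i in range(n_nodes - 1):
        t = np.log(rng.uniform(0.5, 1.0, (m, m)))
        theta_edge[(i, i + 1)] = t      # theta_{ij,kl}: row k = state of x_i, column l = state of x_j
        theta_edge[(i + 1, i)] = t.T    # the same edge seen from the other node
        nbrs[i].append(i + 1)
        nbrs[i + 1].append(i)
    return theta_node, theta_edge, nbrs

def drive(i, b, theta_node, theta_edge, nbrs):
    # theta_{i,k} + sum_{j in N(i)} sum_l theta_{ij,kl} b_{j,l}, for all states k at once
    a = theta_node[i].copy()
    for j in nbrs[i]:
        a = a + theta_edge[(i, j)] @ b[j]
    return a

def mean_field_iterate(theta_node, theta_edge, nbrs, n_iter=500):
    # Equation (3): parallel update, every node uses the previous iterate b^(t).
    b = [np.full(len(t), 1.0 / len(t)) for t in theta_node]
    for _ in range(n_iter):
        b = [softmax(drive(i, b, theta_node, theta_edge, nbrs)) for i in range(len(b))]
    return b

def mean_field_ode(theta_node, theta_edge, nbrs, tau=1.0, dt=0.05, n_steps=4000):
    # Forward-Euler integration of equation (4): tau * db/dt = -b + softmax(drive).
    b = [np.full(len(t), 1.0 / len(t)) for t in theta_node]
    for _ in range(n_steps):
        target = [softmax(drive(i, b, theta_node, theta_edge, nbrs)) for i in range(len(b))]
        b = [bi + (dt / tau) * (ti - bi) for bi, ti in zip(b, target)]
    return b

theta_node, theta_edge, nbrs = random_chain_mrf()
b_iter = mean_field_iterate(theta_node, theta_edge, nbrs)
b_ode = mean_field_ode(theta_node, theta_edge, nbrs)
gap = max(np.abs(x - y).max() for x, y in zip(b_iter, b_ode))
print(gap)  # tiny: both schemes reach the same fixed point
```

The two schemes differ only in how they approach the fixed point; the continuous-time form is the one that will be matched to neural dynamics.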

3 Spiking Neural Network

3.1 Spiking Neuron Model

In this study, the spiking neuron is modeled by a standard stochastic variant of the spike response model [18], which generalizes the leaky integrate-and-fire neuron. Consider a network of $N$ spiking neurons $z_1, \dots, z_N$. The output spike train of neuron $z_i$ is denoted by $S_i(t)$ and is defined as a sum of Dirac delta pulses positioned at the spike times $t_i^{(1)}, t_i^{(2)}, \dots$, i.e., $S_i(t) = \sum_f \delta(t - t_i^{(f)})$. Thus $S_i(t)$ is nonzero if neuron $z_i$ fires at time $t$ and zero otherwise. In this model, the membrane potential of neuron $z_i$ at time $t$ is given by:

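$$u_i(t) = \sum_{j\in\mathrm{Pre}(i)} w_{ij}\int_0^{\infty}\kappa(s)\,S_j(t-s)\,ds + I_i(t) \qquad (5)$$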

where $\mathrm{Pre}(i)$ denotes the set of presynaptic neurons of neuron $z_i$, $I_i(t)$ represents the input current from an outside stimulus, $w_{ij}$ denotes the synaptic weight from neuron $z_j$ to neuron $z_i$, and $\kappa(s)$ describes the voltage response to a short current pulse. Here we use the standard exponential kernel as in [19]:

$$\kappa(s) = e^{-s/\tau}, \qquad s \ge 0 \qquad (6)$$

with membrane time constant $\tau$. In the standard stochastic variant of the spike response model, the strict firing threshold is replaced by a noisy threshold, so that a neuron can fire stochastically at any membrane potential [18]. Specifically, neuron $z_i$ fires a spike at time $t$ with an instantaneous probability (firing rate) $\rho_i(t)$, which is often modeled as an exponential function of the membrane potential:

$$\rho_i(t) = \rho_0\exp\big(u_i(t) - \vartheta\big) \qquad (7)$$

where $\vartheta$ sets the firing threshold and $\rho_0$ scales the firing rate of neuron $z_i$. The firing rate thus increases exponentially as the membrane potential approaches and then exceeds the threshold. This model has been shown to agree well with recordings from real neurons [20].
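A discrete-time sketch of the stochastic spike response neuron of equations (5)-(7) follows, assuming the kernel $\kappa(s) = e^{-s/\tau}$, a 1 ms time bin, and illustrative parameter values ($\tau$ = 20 ms, $\rho_0$ = 20 Hz, $\vartheta$ = 1) that are not taken from the paper. The kernel convolution is computed recursively as a decaying trace, and a spike is emitted in a bin with probability $1 - \exp(-\rho_i\,\Delta t)$. The function `simulate_srm` and its parameters are hypothetical.

```python
import numpy as np

def simulate_srm(input_spikes, weights, I_ext, tau=0.02, vartheta=1.0, rho0=20.0, dt=1e-3, seed=0):
    """Discrete-time sketch of equations (5)-(7) for a single neuron.

    input_spikes: (T, n_pre) array of 0/1 presynaptic spikes per time bin
    weights:      (n_pre,) synaptic weights w_ij
    I_ext:        (T,) external input current I_i(t)
    """
    rng = np.random.default_rng(seed)
    n_steps = input_spikes.shape[0]
    trace = np.zeros(input_spikes.shape[1])  # sum_f kappa(t - t_j^(f)) for each presynaptic j
    decay = np.exp(-dt / tau)                # kappa(s) = exp(-s / tau), equation (6)
    u = np.zeros(n_steps)
    out = np.zeros(n_steps, dtype=bool)
    for t in range(n_steps):
        trace = trace * decay + input_spikes[t]        # a new spike contributes kappa(0) = 1
        u[t] = weights @ trace + I_ext[t]              # membrane potential, equation (5)
        rho = rho0 * np.exp(u[t] - vartheta)           # escape rate, equation (7)
        out[t] = rng.random() < -np.expm1(-rho * dt)   # P(at least one spike in the bin)
    return u, out

T = 100_000                       # 100 s at dt = 1 ms
pre = np.zeros((T, 1))
pre[0, 0] = 1.0                   # one presynaptic spike at t = 0
u, out = simulate_srm(pre, np.array([1.0]), np.full(T, 1.0))
print(u[0], u[20])                # 2.0 and about 1 + exp(-1) ≈ 1.368
print(out.sum() / 100.0)          # empirical rate near rho0 = 20 Hz, since I = vartheta
```

With the input current held at the threshold, the escape rate equals $\rho_0$, which makes the empirical rate an easy sanity check of equation (7).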

3.2 Winner-Take-All Circuit

The winner-take-all (WTA) circuit has been suggested as a ubiquitous motif of cortical microcircuits [13]. We consider a WTA circuit of $K$ output spiking neurons $z_1, \dots, z_K$ and an inhibitory neuron, as in Fig. 1. The output spiking neurons mutually inhibit each other through the inhibitory neuron. Thus, all neurons in the output layer compete with each other, so that they cannot fire simultaneously.

In this study, we consider the WTA model used in [21, 22], in which all neurons are allowed to fire with nonzero probability. Since all neurons in a WTA circuit are subject to the same lateral inhibition, the firing probability of neuron $z_k$ in the WTA circuit at time $t$ is given by [21]:

$$\rho_k(t) = \lambda\exp\big(u_k(t) - I_{\mathrm{inh}}(t)\big) \qquad (8)$$

where $\lambda$ scales the firing rate of the neurons and $I_{\mathrm{inh}}(t)$ represents the divisive inhibition between the neurons in the WTA circuit, defined as $I_{\mathrm{inh}}(t) = \log\sum_{k'=1}^{K}\exp(u_{k'}(t))$. Thus, equation (8) can be rewritten as:

$$\rho_k(t) = \lambda\,\frac{\exp(u_k(t))}{\sum_{k'=1}^{K}\exp(u_{k'}(t))} \qquad (9)$$

This WTA circuit works like a softmax function: at each time, all neurons can fire with nonzero probability, but the neuron with the highest membrane potential has the highest firing probability, and the firing rates of the circuit always sum to $\lambda$.
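The soft WTA rule of equations (8)-(9) can be sketched directly, assuming $\lambda$ = 50 Hz and a hypothetical vector of membrane potentials. The inhibition term is computed in the numerically stable log-sum-exp form; subtracting it in the exponent is what makes the inhibition divisive on the rates.

```python
import numpy as np

def wta_rates(u, lam=50.0):
    """Firing rates of a soft WTA circuit, equations (8)-(9).

    u:   membrane potentials of the K output neurons
    lam: rate scale lambda (Hz); 50 Hz is an illustrative choice
    """
    u = np.asarray(u, dtype=float)
    inhibition = u.max() + np.log(np.exp(u - u.max()).sum())  # I_inh = log sum_k exp(u_k), stable form
    return lam * np.exp(u - inhibition)                         # equation (8)

u = np.array([0.2, 1.0, -0.5, 0.3])  # hypothetical membrane potentials
rho = wta_rates(u)
print(rho, rho.sum())  # every rate is positive; the rates sum to lam; neuron 1 fires most
```

Because the common inhibition cancels any shift applied to all potentials, adding a constant to every $u_k$ leaves the rates unchanged.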

4 Neural Implementation of Mean-Field Inference

Figure 2: MRF network represented by a combination of WTA circuits. (Left) MRF with multiple variables, where $x_i$ and $x_j$ are connected. (Right) Corresponding neural network. The network consists of multiple WTA circuits, where each neuron encodes one state of a variable and each WTA circuit encodes the distribution defined on that variable. When two nodes ($x_i$ and $x_j$) are connected in the MRF, all the spiking output neurons in the corresponding WTA circuits are fully connected (connections between the $i$th WTA circuit and the $j$th circuit). Here the synaptic weights and input currents encode the potential functions defined on the edges and nodes, respectively. The whole circuit is able to encode the joint distribution defined on the MRF.

In this section, we first show how a basic WTA circuit encodes the distribution of a variable defined on a node, and how a network consisting of WTA circuits represents the joint distribution defined on an MRF. We then prove that there exists an exact equivalence between the mean-field inference equation of an MRF and the dynamic equation of a spiking neural network consisting of WTA circuits.

4.1 Representation of Distributions with WTA Circuits

To enable a combination of WTA circuits to implement inference in arbitrary MRFs, one needs to specify how the assignment of values to the variables $x_1, \dots, x_n$ of an MRF is represented by the spiking activity of WTA circuits, where $m_1, \dots, m_n$ denote the numbers of states of the variables $x_1, \dots, x_n$, respectively. Each WTA circuit can represent the states of one variable of the MRF. Specifically, consider a WTA circuit with $m_i$ output spiking neurons $z_{i,1}, \dots, z_{i,m_i}$; we say that variable $x_i$ is represented by the firing activity of this WTA circuit at time $t$ if:

$$x_i(t) = k \iff z_{i,k}\ \text{fires at time}\ t, \qquad k = 1, \dots, m_i \qquad (10)$$

In this way, each neuron $z_{i,k}$ represents one of the $m_i$ possible values of variable $x_i$. If the firing probability of each output spiking neuron equals the probability of the variable being in the corresponding state, that is, $\rho_{i,k}(t)/\lambda = p(x_i = k)$, then the firing activity of the WTA circuit encodes the distribution defined on the variable. One can read out the distribution by counting the spikes of each neuron within a behaviorally relevant time window of a few hundred milliseconds, which is similar to experimental findings in monkey cortex [23, 24]. Similarly, a spiking neural network consisting of WTA circuits can encode the joint distribution over the variables $x_1, \dots, x_n$ defined on an MRF.
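The spike-count readout can be sketched as follows, under the simplifying assumption (not from the paper) that each neuron $z_{i,k}$ of a circuit fires as an independent Poisson process with rate $\lambda\,p(x_i = k)$ during the counting window. Normalizing the counts gives the estimate; a window of a few hundred milliseconds gives a noisy estimate and a long window a precise one. The helper names `simulate_counts` and `read_out` are hypothetical.

```python
import numpy as np

def simulate_counts(p, lam=50.0, window=0.3, seed=0):
    # Assumption: neuron k fires as a Poisson process with rate lam * p[k] during the window.
    rng = np.random.default_rng(seed)
    return rng.poisson(lam * np.asarray(p, dtype=float) * window)

def read_out(counts):
    # Normalized spike counts estimate p(x_i = k); fall back to uniform if no spike was seen.
    counts = np.asarray(counts, dtype=float)
    if counts.sum() == 0:
        return np.full(len(counts), 1.0 / len(counts))
    return counts / counts.sum()

p = np.array([0.1, 0.2, 0.4, 0.2, 0.1])             # hypothetical distribution of one variable
p_short = read_out(simulate_counts(p, window=0.3))   # about 15 spikes in total: coarse estimate
p_long = read_out(simulate_counts(p, window=300.0))  # about 15000 spikes: close to p
print(p_short, p_long)
```

The window length trades readout speed against accuracy, which is why the time window matters for a behaving animal.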

4.2 Network Architecture

Here we describe the architecture of a network of WTA circuits that performs inference in an arbitrary MRF. Consider an MRF and its corresponding spiking neural network in Fig. 2. The network is composed of several WTA circuits, each of which represents (or encodes) the variable defined on one node. If two nodes are connected in the MRF, the output spiking neurons of the corresponding WTA circuits are fully connected. The connection weights encode the potential functions defined on the edges between adjacent nodes in the MRF. In addition, each neuron receives an input current from an outside stimulus (not shown in Fig. 2), which encodes the potential function defined on the corresponding node of the MRF.
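The wiring rule above can be sketched as a neuron indexing scheme plus a boolean synapse mask, assuming one circuit per variable with neurons laid out consecutively. The function `build_wta_network` is hypothetical; it only encodes which synapses exist, while the weight values come from the edge potentials. Competition within a circuit is mediated by the shared inhibitory neuron of equation (8), so no within-circuit entries appear in the mask.

```python
import numpy as np

def build_wta_network(n_states, edges):
    """Neuron indexing and synapse mask for a network of WTA circuits (Fig. 2).

    n_states: list m_1..m_n of states per variable (one circuit per variable)
    edges:    list of (i, j) node pairs connected in the MRF
    Returns offsets (neuron z_{i,k} has index offsets[i] + k) and a boolean
    mask whose entry [a, b] is True if neuron b projects to neuron a.
    """
    offsets = np.concatenate([[0], np.cumsum(n_states)]).astype(int)
    n_neurons = int(offsets[-1])
    mask = np.zeros((n_neurons, n_neurons), dtype=bool)
    for i, j in edges:
        block_i = slice(offsets[i], offsets[i + 1])
        block_j = slice(offsets[j], offsets[j + 1])
        mask[block_i, block_j] = True  # all-to-all from circuit j to circuit i
        mask[block_j, block_i] = True  # and back, since MRF edges are undirected
    return offsets, mask

offsets, mask = build_wta_network([5, 5, 5], [(0, 1), (1, 2)])
print(offsets, mask.sum())  # [ 0  5 10 15] 100
```

For the 3-node chain, circuits 0 and 2 are not connected, so their off-diagonal block stays empty.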

4.3 Spike-based Mean-field Inference in WTA circuits

To show that a combination of WTA circuits can implement inference in arbitrary MRFs, we now prove that the dynamic equation of the WTA circuits is equivalent to the differential equation (4) of the MRF in Fig. 2.

Consider the spiking neural network in Fig. 2, which contains $n$ WTA circuits. The $i$th WTA circuit consists of $m_i$ output spiking neurons and encodes the variable $x_i$. The $k$th neuron in the $i$th WTA circuit is denoted $z_{i,k}$; it receives a stimulus current $I_{i,k}(t)$ (not shown in Fig. 2) and synaptic inputs from the neurons in the neighboring WTA circuits. We denote the output spike train of neuron $z_{i,k}$ by $S_{i,k}(t) = \sum_f \delta(t - t_{i,k}^{(f)})$, a sum of Dirac delta pulses positioned at its spike times $t_{i,k}^{(f)}$. We assume that $y_{i,k}(t) = \frac{1}{\lambda\tau}\int_0^{\infty}\kappa(s)\,\rho_{i,k}(t-s)\,ds$, where $y_{i,k}(t)$ is called the synaptic drive [25], $\rho_{i,k}(t)$ represents the firing probability of the $k$th neuron in the $i$th WTA circuit at time $t$, and $\lambda$, defined in (8), scales the firing rate of the neurons. Taking the derivative of $y_{i,k}(t)$ with respect to time $t$ and using equation (6), we obtain:

$$\tau\frac{d y_{i,k}(t)}{dt} = -y_{i,k}(t) + \frac{\rho_{i,k}(t)}{\lambda} \qquad (11)$$

As neuron $z_{i,k}$ belongs to the $i$th WTA circuit, by equation (9) its firing probability equals $\rho_{i,k}(t) = \lambda\exp(u_{i,k}(t))/\sum_{k'=1}^{m_i}\exp(u_{i,k'}(t))$. Thus equation (11) can be rewritten as:

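$$\tau\frac{d y_{i,k}(t)}{dt} = -y_{i,k}(t) + \frac{\exp(u_{i,k}(t))}{\sum_{k'=1}^{m_i}\exp(u_{i,k'}(t))} \qquad (12)$$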

According to equation (5), the membrane potential of neuron $z_{i,k}$ at time $t$ equals

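$$u_{i,k}(t) = \sum_{j\in N(i)}\sum_{l=1}^{m_j} w_{ik,jl}\int_0^{\infty}\kappa(s)\,S_{j,l}(t-s)\,ds + I_{i,k}(t) \approx \lambda\tau\sum_{j\in N(i)}\sum_{l=1}^{m_j} w_{ik,jl}\,y_{j,l}(t) + I_{i,k}(t) \qquad (13)$$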

where $N(i)$ represents all WTA circuits neighboring the $i$th WTA circuit and $w_{ik,jl}$ denotes the synaptic weight between neurons $z_{i,k}$ and $z_{j,l}$. Note that here the output spike train $S_{j,l}(t)$ is approximated by the firing probability function $\rho_{j,l}(t)$, an approximation also used in [25] when deriving the dynamic equations of recurrent neural networks. By substituting equation (13) into equation (12), one gets:

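$$\tau\frac{d y_{i,k}}{dt} = -y_{i,k} + \frac{\exp\Big(I_{i,k} + \lambda\tau\sum_{j\in N(i)}\sum_{l=1}^{m_j} w_{ik,jl}\,y_{j,l}\Big)}{\sum_{k'=1}^{m_i}\exp\Big(I_{i,k'} + \lambda\tau\sum_{j\in N(i)}\sum_{l=1}^{m_j} w_{ik',jl}\,y_{j,l}\Big)} \qquad (14)$$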

Comparing (14) with (4), a spiking neural network of WTA circuits governed by (14) implements equation (4) if the following equations hold:

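$$y_{i,k}(t) = b_{i,k}(t), \qquad \lambda\tau\,w_{ik,jl} = \theta_{ij,kl}, \qquad I_{i,k} = \theta_{i,k} \qquad (15)$$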

That is, if the synaptic weights and input currents encode the potential functions $\theta_{ij,kl}$ and $\theta_{i,k}$, respectively, then the synaptic drive of each spiking neuron in the WTA circuits equals the marginal probability of the variable being in the corresponding state. When equations (11) and (14) converge to the fixed point, we have $\rho_{i,k} = \lambda y_{i,k} = \lambda b_{i,k}$. Thus the firing probability (or firing rate) of each neuron is proportional to the marginal probability of the variable being in the corresponding state. Moreover, the time course of the neural firing rates implements marginal inference in MRFs, and one can read out the inference result by counting the spikes of each neuron within a behaviorally relevant time window of a few hundred milliseconds. In fact, in our simulations the computation of the WTA circuits converges to the inference result quickly (see Supplementary Fig. 1).
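The equivalence can be sketched at the rate level. Assuming the kernel $\kappa(s) = e^{-s/\tau}$ and the synaptic-drive normalization used in the derivation, equation (15) sets $I_{i,k} = \theta_{i,k}$ and $\lambda\tau\,w_{ik,jl} = \theta_{ij,kl}$. The sketch integrates the WTA dynamics (14) with forward Euler and compares the steady-state synaptic drive with the mean-field fixed point of (3). The chain MRF, the potential interval [0.5, 1.0] and $\tau$ = 20 ms are hypothetical; $\lambda$ = 50 Hz follows Section 5.

```python
import numpy as np

def softmax(a):
    e = np.exp(a - a.max())
    return e / e.sum()

rng = np.random.default_rng(1)
n, m = 3, 5                                         # hypothetical 3-node chain, 5 states per node
theta_node = [np.log(rng.uniform(0.5, 1.0, m)) for _ in range(n)]
theta_edge = {}
for i, j in [(0, 1), (1, 2)]:
    t = np.log(rng.uniform(0.5, 1.0, (m, m)))
    theta_edge[(i, j)], theta_edge[(j, i)] = t, t.T
nbrs = {i: [j for (a, j) in theta_edge if a == i] for i in range(n)}

lam, tau = 50.0, 0.02                                     # lambda from Section 5; tau assumed
W = {e: t / (lam * tau) for e, t in theta_edge.items()}   # equation (15): lam * tau * w = theta_ij
I = [t.copy() for t in theta_node]                        # equation (15): I_{i,k} = theta_{i,k}

def wta_step(y, dt):
    # Euler step of equation (14): u from (13), then tau * dy/dt = -y + softmax(u) per circuit.
    u = [I[i] + lam * tau * sum(W[(i, j)] @ y[j] for j in nbrs[i]) for i in range(n)]
    return [y[i] + (dt / tau) * (softmax(u[i]) - y[i]) for i in range(n)]

def mf_step(b):
    # One parallel iteration of the mean-field equation (3).
    return [softmax(theta_node[i] + sum(theta_edge[(i, j)] @ b[j] for j in nbrs[i]))
            for i in range(n)]

y = [np.full(m, 1.0 / m) for _ in range(n)]
for _ in range(4000):                               # 4000 steps of tau / 20 = 200 tau
    y = wta_step(y, dt=tau / 20)
b = [np.full(m, 1.0 / m) for _ in range(n)]
for _ in range(500):
    b = mf_step(b)
rates = [lam * yi for yi in y]                      # fixed point: rho_{i,k} = lam * y_{i,k}
print(max(np.abs(yi - bi).max() for yi, bi in zip(y, b)))  # tiny
```

This rate-level check ignores spiking noise; it only confirms that, with the weights chosen as in (15), the fixed points of (14) and (3) coincide.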

The proposed theory is formulated in continuous time. It can easily be converted into a discrete-time version of the WTA dynamics for numerical simulation. Based on the equivalence derived above, the change of the firing probabilities within one discrete time bin can be seen as one iteration of the mean-field inference equation (3) (see Supplementary Materials).
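One natural discretization, not necessarily the one used in the Supplementary Materials, is a forward-Euler step of equation (4) with step size $\Delta t$. When $\Delta t = \tau$, the $-b$ term cancels and the step reduces exactly to one parallel iteration of equation (3); smaller steps act as a damped version of it. The sketch below checks this on a hypothetical random 3-node chain; `euler_step` and `mean_field_step` are illustrative names.

```python
import numpy as np

def softmax(a):
    e = np.exp(a - a.max())
    return e / e.sum()

def euler_step(b, theta_node, theta_edge, nbrs, dt_over_tau):
    # One forward-Euler step of equation (4); dt_over_tau = Delta t / tau.
    new = []
    for i in range(len(b)):
        target = softmax(theta_node[i] + sum(theta_edge[(i, j)] @ b[j] for j in nbrs[i]))
        new.append(b[i] + dt_over_tau * (target - b[i]))
    return new

def mean_field_step(b, theta_node, theta_edge, nbrs):
    # One parallel iteration of equation (3).
    return [softmax(theta_node[i] + sum(theta_edge[(i, j)] @ b[j] for j in nbrs[i]))
            for i in range(len(b))]

rng = np.random.default_rng(3)
m = 4                                                # hypothetical number of states per node
theta_node = [np.log(rng.uniform(0.5, 1.0, m)) for _ in range(3)]
t01 = np.log(rng.uniform(0.5, 1.0, (m, m)))
t12 = np.log(rng.uniform(0.5, 1.0, (m, m)))
theta_edge = {(0, 1): t01, (1, 0): t01.T, (1, 2): t12, (2, 1): t12.T}
nbrs = {0: [1], 1: [0, 2], 2: [1]}
b = [v / v.sum() for v in rng.uniform(size=(3, m))]  # arbitrary current beliefs

full = euler_step(b, theta_node, theta_edge, nbrs, dt_over_tau=1.0)
mf = mean_field_step(b, theta_node, theta_edge, nbrs)
half = euler_step(b, theta_node, theta_edge, nbrs, dt_over_tau=0.5)
print(all(np.allclose(f, g) for f, g in zip(full, mf)))  # True: Delta t = tau gives equation (3)
```

With $\Delta t = 0.5\,\tau$ the update is the average of the old beliefs and the mean-field target, a damping that can help when the plain iteration oscillates.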

5 Simulations

Figure 3: Mean-field inference simulated with WTA circuits. (Left) A chain MRF with three nodes. (Middle) Spiking activity of 15 neurons in three WTA circuits, where each WTA circuit has 5 neurons. All neurons fire, subject to the competition of the WTA mechanism. (Right) Tight match of three different inference methods for the marginal probabilities: belief propagation (brown), mean-field inference (light blue), and the WTA spiking neural network (red).

To validate our computational framework, we evaluate the performance of WTA circuits in two simulations. First, we compare the WTA circuits with the mean-field algorithm and the belief propagation algorithm on a chain MRF with 3 nodes (shown in Fig. 3). Note that belief propagation performs exact inference for this MRF. We suppose that each node has 5 states, and the potential functions $\psi_{ij}$ and $\psi_i$ defined on each edge and node are created by drawing random numbers from a uniform distribution on a bounded interval. A spiking neural network composed of 3 WTA circuits, each with 5 neurons, is used to implement inference. The synaptic weights and input currents are set according to equation (15). In our simulation, the firing rate scaling factor $\lambda$ is set to 50 Hz, so a firing rate of $r$ Hz of a neuron maps to a probability of $r/50$ for the corresponding state of the variable. The firing activity of the 15 neurons is shown in Fig. 3, where one can see that all neurons fire.

The inference performance is shown in Fig. 3 as histograms of the firing rates of the 15 neurons, where we compare the inference results for the MRF obtained with different methods. The tight match between the three inference algorithms suggests that the WTA spiking neural network performs mean-field variational inference with high accuracy. We also use the relative error to evaluate the discrepancy between the marginal probabilities obtained by belief propagation and the mean firing probabilities of the neurons. The relative error decreases over time (Supplementary Fig. 1).
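The first simulation can be approximated with a small spiking sketch. It assumes, beyond what the text states, $\tau$ = 20 ms, 1 ms time bins, 50 s of simulated time, potentials drawn uniformly from [0.5, 1.0], and independent Bernoulli spiking of each neuron with probability $\rho_{i,k}\,\Delta t$ per bin as an approximation of the soft WTA. Weights and currents follow equation (15) with the kernel $\kappa(s) = e^{-s/\tau}$. Exact marginals are obtained by enumeration, which gives the same result as belief propagation on a chain; the L1 relative error used here is an assumed definition, since the paper does not spell out its formula.

```python
import itertools
import numpy as np

rng = np.random.default_rng(2)
n, m = 3, 5                              # 3-node chain, 5 states per node (as in Section 5)
lam, tau, dt, T = 50.0, 0.02, 1e-3, 50.0 # lam from the text; tau, dt, T are assumptions
edges = [(0, 1), (1, 2)]
theta_node = np.log(rng.uniform(0.5, 1.0, (n, m)))  # potential interval is an assumption
theta_edge = {e: np.log(rng.uniform(0.5, 1.0, (m, m))) for e in edges}

# Equation (15): I = theta_node and lam * tau * w = theta_edge; neuron z_{i,k} has index i*m + k.
I = theta_node.ravel()
W = np.zeros((n * m, n * m))
for (i, j), t in theta_edge.items():
    W[i*m:(i+1)*m, j*m:(j+1)*m] = t / (lam * tau)
    W[j*m:(j+1)*m, i*m:(i+1)*m] = t.T / (lam * tau)

def softmax_rows(a):
    e = np.exp(a - a.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

# Spiking simulation; trace holds sum_f exp(-(t - t_f) / tau), the kernel term of equation (5).
trace, counts = np.zeros(n * m), np.zeros(n * m)
decay, burn_in = np.exp(-dt / tau), int(1.0 / dt)
for step in range(int(T / dt)):
    u = I + W @ trace                                  # membrane potentials
    rho = lam * softmax_rows(u.reshape(n, m)).ravel()  # equation (9), one softmax per circuit
    spikes = rng.random(n * m) < rho * dt              # Bernoulli approximation per 1 ms bin
    trace = trace * decay + spikes
    if step >= burn_in:
        counts += spikes
p_wta = counts.reshape(n, m) / counts.reshape(n, m).sum(axis=1, keepdims=True)

# Mean-field reference: equation (3) with parallel updates.
b = np.full((n, m), 1.0 / m)
for _ in range(500):
    drive = theta_node.copy()
    for (i, j), t in theta_edge.items():
        drive[i] += t @ b[j]
        drive[j] += t.T @ b[i]
    b = softmax_rows(drive)

# Exact marginals by enumeration (belief propagation is exact on a chain).
p = np.zeros((m,) * n)
for x in itertools.product(range(m), repeat=n):
    s = theta_node[np.arange(n), list(x)].sum()
    s += sum(t[x[i], x[j]] for (i, j), t in theta_edge.items())
    p[x] = np.exp(s)
p /= p.sum()
exact = np.stack([p.sum(axis=tuple(ax for ax in range(n) if ax != i)) for i in range(n)])

rel_err = np.abs(p_wta - exact).sum() / exact.sum()  # L1 relative error (definition assumed)
print(np.abs(p_wta - b).max(), rel_err)
```

Because spikes are random, the count-based estimate fluctuates around the mean-field solution; longer runs shrink this sampling error.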

Figure 4: Inference performance of WTA circuits compared with the mean-field inference algorithm and belief propagation for different graph topologies. All graphs have 10 nodes, with chain (left), single-loop (middle), and fully connected (right) topologies. The performance of the WTA circuits (red) and of mean-field approximation (blue) is comparable to that of belief propagation.

Next, we investigate whether the inference framework scales up to more complex MRFs. We infer the marginal probabilities of MRFs with different graph topologies: chain, single loop, and fully connected graph. As shown in Fig. 4 for graphs with 10 nodes, the WTA circuits achieve results comparable to those of the belief propagation algorithm and the mean-field approximation.

Note that for the fully connected graph, the marginal probabilities obtained by mean-field approximation tend toward zero and one. This phenomenon has been observed before: the marginal probabilities obtained by mean-field approximation are more overconfident than those obtained by belief propagation [26]. Our WTA inference performs better in this respect. We further test this on a fully connected graph with 20 nodes (see Supplementary Fig. 2), and indeed our WTA inference is closer to the result of belief propagation than mean-field approximation is.
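The overconfidence of mean-field marginals on densely connected graphs can be illustrated with a deliberately simple toy model that is not one of the paper's experiments: a hypothetical fully connected MRF of 6 binary nodes whose edge log-potential rewards agreeing neighbors with strength $J = 1$ and whose node potentials are flat. By the 0/1 symmetry of this model, the exact marginal is 0.5, whereas the mean-field iteration (3), started slightly off the symmetric point, settles close to 1. The WTA network is not simulated here.

```python
import itertools
import numpy as np

def softmax(a):
    e = np.exp(a - a.max())
    return e / e.sum()

n, J = 6, 1.0                  # hypothetical: 6 binary nodes, all pairs coupled with strength J
theta_pair = J * np.eye(2)     # theta_ij(x_i, x_j) = J if x_i == x_j else 0; theta_i = 0

# Exact marginal p(x_0 = 0) by enumerating all 2**n states.
num = den = 0.0
for x in itertools.product(range(2), repeat=n):
    w = np.exp(sum(theta_pair[x[i], x[j]] for i in range(n) for j in range(i + 1, n)))
    den += w
    num += w if x[0] == 0 else 0.0
exact_p0 = num / den           # 0.5 by the 0 <-> 1 symmetry of the model

# Mean-field, equation (3) with parallel updates, started slightly off the symmetric point.
b = np.tile([0.6, 0.4], (n, 1))
for _ in range(200):
    b = np.array([softmax(sum(theta_pair @ b[j] for j in range(n) if j != i)) for i in range(n)])
mf_p0 = b[0, 0]
print(exact_p0, mf_p0)         # roughly 0.5 versus roughly 0.99
```

The symmetric toy exaggerates the effect on purpose: mean-field commits to one of two equally likely configurations, which is exactly the push toward zero and one described above.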

6 Discussion

In this study, we prove that there exists an exact equivalence between the neural dynamics of a network consisting of WTA circuits and the mean-field inference of probabilistic graphical models. We show that WTA circuits are able to represent distributions and implement inference for arbitrary probabilistic graphical models. Our study suggests that the WTA circuit can be seen as the basic neural network motif for probabilistic inference at the level of neuronal circuits. This may offer a functional explanation for the abundance of WTA-like connectivity in cortical microcircuits.

Unlike many previous neural circuits proposed for probabilistic inference [5, 6, 7, 8, 9, 10], where each population of neurons has a different network topology to implement different and complex computations, our model consists of a set of simple basic network motifs, each of which works in a simple way. This makes our neural implementation plausible, since many computations underlying cognitive behavior appear to be conserved: different parts of the brain and different sensory modalities seem to share common mechanisms of neural information processing [27, 28].

Different approaches can be used to approximate Bayesian inference, among which belief propagation and mean-field approximation are two typical ones. It has been shown that the marginal probabilities obtained by mean-field approximation are more overconfident than those obtained by belief propagation [26], meaning that the mean-field marginals are closer to zero and one than the true marginal probabilities. Since our WTA circuits produced less overconfident marginals than mean-field approximation in our simulations, this observation suggests that the proposed model with WTA circuits is well suited to probabilistic reasoning. After all, the brain has to shift attention, select actions, and make decisions in the face of uncertainty [4, 15].

It remains unclear how different network set-ups that include more components from neuroscience would affect the inference results of different methods. For instance, the computation of the proposed WTA circuits could be explored on graph topologies more similar to the neuronal networks of particular human brain areas, or to the neuronal networks of well-studied animals [4, 29]. This may reveal a more powerful role of WTA circuits in probabilistic reasoning and inference in the brain.

References

  • [1] F Meyniel, M Sigman, and ZF Mainen. Confidence as Bayesian probability: From neural origins to behavior. Neuron, 88(1):78–92, 2015.
  • [2] MO Ernst and MS Banks. Humans integrate visual and haptic information in a statistically optimal fashion. Nature, 415(6870):429–433, 2002.
  • [3] KP Körding and DM Wolpert. Bayesian integration in sensorimotor learning. Nature, 427(6971):244–247, 2004.
  • [4] A Pouget, JM Beck, WJ Ma, and PE Latham. Probabilistic brains: knowns and unknowns. Nat. Neurosci., 16(9):1170–8, 2013.
  • [5] RPN Rao. Bayesian computation in recurrent neural circuits. Neural Comput, 16(1):1–38, 2004.
  • [6] T Ott and R Stoop. The neurodynamics of belief propagation on binary Markov random fields. In Advances in Neural Information Processing Systems, pages 1057–1064, 2007.
  • [7] A Steimer, W Maass, and R Douglas. Belief propagation in networks of spiking neurons. Neural Comput, 21(9):2502–2523, 2009.
  • [8] S Litvak and S Ullman. Cortical circuitry implementing graphical models. Neural Comput, 21(11):3010, 2009.
  • [9] D George and J Hawkins. Towards a mathematical theory of cortical micro-circuits. PLoS Comput. Biol., 5(10):e1000532, 2009.
  • [10] K Friston, T Fitzgerald, F Rigoli, P Schwartenbeck, and G Pezzulo. Active inference: A process theory. Neural Comput, 29(1):1–49, 2017.
  • [11] RV Raju and X Pitkow. Inference by reparameterization in neural population codes. In Advances in Neural Information Processing Systems, pages 2029–2037, 2016.
  • [12] M Okun and I Lampl. Instantaneous correlation of excitation and inhibition during ongoing and sensory-evoked activities. Nat. Neurosci., 11(5):535, 2008.
  • [13] RJ Douglas and KAC Martin. Neuronal circuits of the neocortex. Annu. Rev. Neurosci., 27:419–451, 2004.
  • [14] M Carandini and DJ Heeger. Normalization as a canonical neural computation. Nat. Rev. Neurosci., 13(1):51–62, 2012.
  • [15] L Itti and C Koch. Computational modelling of visual attention. Nat. Rev. Neurosci., 2(3):194, 2001.
  • [16] D Koller and N Friedman. Probabilistic Graphical Models: Principles and Techniques. MIT press, 2009.
  • [17] MI Jordan, Z Ghahramani, TS Jaakkola, and LK Saul. An introduction to variational methods for graphical models. Machine Learning, 37(2):183–233, 1999.
  • [18] W Gerstner, WM Kistler, R Naud, and L Paninski. Neuronal Dynamics: From Single Neurons to Networks and Models of Cognition. Cambridge University Press, 2014.
  • [19] N Frémaux, H Sprekeler, and W Gerstner. Functional requirements for reward-modulated spike-timing-dependent plasticity. J. Neurosci., 30(40):13326–13337, 2010.
  • [20] R Jolivet, A Rauch, HR Lüscher, and W Gerstner. Predicting spike timing of neocortical pyramidal neurons by simple threshold models. J. Comput. Neurosci, 21(1):35–49, 2006.
  • [21] B Nessler, M Pfeiffer, L Buesing, and W Maass. Bayesian computation emerges in generic cortical microcircuits through spike-timing-dependent plasticity. PLoS Comput. Biol., 9(4):e1003037, 2013.
  • [22] D Kappel, B Nessler, and W Maass. STDP installs in winner-take-all circuits an online approximation to hidden Markov model learning. PLoS Comput. Biol., 10(3):e1003511, 2014.
  • [23] T Yang and MN Shadlen. Probabilistic reasoning by neurons. Nature, 447(7148):1075–1080, 2007.
  • [24] JI Gold and MN Shadlen. The neural basis of decision making. Annual Review of Neuroscience, 30:535–574, 2007.
  • [25] P Dayan and LF Abbott. Theoretical Neuroscience. Cambridge, MA: MIT Press, 2001.
  • [26] Y Weiss. Comparing the mean field method and belief propagation for approximate inference in MRFs. Advanced Mean Field Methods: Theory and Practice, pages 229–240, 2001.
  • [27] D Shimaoka, KD Harris, and M Carandini. Effects of arousal on mouse sensory cortex depend on modality. Cell Reports, 22(12):3160–3167, 2018.
  • [28] AB Saleem, A Ayaz, KJ Jeffery, KD Harris, and M Carandini. Integration of visual motion and locomotion in mouse visual cortex. Nat. Neurosci., 16(12):1864, 2013.
  • [29] E Bullmore and O Sporns. Complex brain networks: graph theoretical analysis of structural and functional systems. Nat. Rev. Neurosci., 10(3):186–198, 2009.