Differentiable Generalised Predictive Coding

12/02/2021
by   André Ofner, et al.

This paper deals with differentiable dynamical models congruent with neural process theories that cast brain function as the hierarchical refinement of an internal generative model explaining observations. Our work extends existing implementations of gradient-based predictive coding with automatic differentiation and allows deep neural networks to be integrated for non-linear state parameterization. Gradient-based predictive coding optimises inferred states and weights locally for each layer by minimising precision-weighted prediction errors that propagate from stimuli towards latent states. Predictions flow backwards, from latent states towards lower layers. The model suggested here optimises hierarchical and dynamical predictions of latent states. Hierarchical predictions encode expected content and hierarchical structure. Dynamical predictions capture changes in the encoded content along with higher-order derivatives. Hierarchical and dynamical predictions interact and address different aspects of the same latent states. We apply the model to various perception and planning tasks on sequential data and show their mutual dependence. In particular, we demonstrate how learning sampling distances in parallel allows the model to address meaningful locations in data sampled at discrete time steps. We discuss possibilities to relax the assumption of linear hierarchies in favor of more flexible graph structures with emergent properties. We compare the granular structure of the model with canonical microcircuits describing predictive coding in biological networks and review the connection to Markov blankets as a tool to characterize modularity. A final section sketches out ideas for efficient perception and planning in nested spatio-temporal hierarchies.



1 Introduction

1.1 Generalized Filtering and Dynamical Generative Models

The hierarchical dynamical generative models (HDMs) discussed by Friston and Kiebel [friston2008hierarchical] in the context of explaining brain function describe a powerful class of Bayesian generative models that are hierarchical in structure and dynamics and allow Bayesian inference to be performed on complex observations. HDMs capture latent states, encoding estimated causes for observations and their associated motion in generalised coordinates, i.e. the derivatives of each encoded state's trajectory. In the context of HDMs, perception refers to the process of inverting a particular model and refining its parameters in order to more accurately explain the observations. As we will see later in this article, planning and action are closely intertwined with perception when applying HDMs to sensory data prediction.

Interestingly, HDMs generalize most established statistical models; for example, (extended) Kalman filters can be interpreted as a particular first-order HDM. More generally, it is possible to characterize most existing Bayesian model optimisation schemes as variants of Generalized Filtering [friston2010generalised]. Generalised filtering describes a generic approach to computing posterior densities for latent states and model parameters (i.e. addressing inference and learning) based on a gradient descent on the variational Free Energy (also known as the Evidence Lower Bound, or ELBO, in the machine learning domain).

Without any additional assumptions about the underlying probability distributions, the variational Free Energy can be expressed as:

F = \mathbb{E}_{q(x)}\left[ \ln q(x) - \ln p(y, x) \right]    (1)

with data y, latent state x and a generative model p(y, x) that generates the observed data. In variational inference, the true posterior of the latent state x is inferred indirectly using an approximate posterior q(x; φ) with known parameters φ. Fitting the approximate posterior to the true posterior is achieved by minimizing the divergence between the two with respect to the parameters φ:

\phi^{*} = \arg\min_{\phi} D_{KL}\left[ q(x; \phi) \,\|\, p(x \mid y) \right]    (2)
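To make the relation between (1) and (2) explicit, the free energy can be decomposed in the standard way (a textbook identity, restated here for completeness):

F(\phi) = \mathbb{E}_{q(x;\phi)}\left[ \ln q(x;\phi) - \ln p(y, x) \right]
        = D_{KL}\left[ q(x;\phi) \,\|\, p(x \mid y) \right] - \ln p(y)
        \geq -\ln p(y)

so minimising F with respect to φ minimises the divergence in (2), while the negative free energy lower-bounds the log evidence, which is why -F is also called the ELBO.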

Existing models that implement generalized gradient descent on the Free Energy take different forms and, with that, different degrees of biological plausibility at different levels of analysis. Important variants implementing generalized filtering are variational filtering [friston2008variational], dynamic expectation maximization [friston2008variational] and generalized predictive coding [friston2009predictive]. Of these, generalized (Bayesian) predictive coding is a particularly interesting candidate, since it allows a relatively straightforward mapping to neural mechanisms. We will focus on generalized predictive coding in this article, but keep the similarities to related, more generic filtering schemes in mind. Finally, HDMs typically rest on functions operating in continuous time, which requires additional mechanisms for updates with respect to discrete time intervals, such as are present in many machine learning problems of the kind that we want to look at here.

Dynamical generative models describe the likelihood p(y | v) of observing data y given causes v, and priors on causes p(v). With non-linear functions g and f (with parameters θ), these models generate outputs (responses) characterised by the derivatives of their trajectory, the generalized coordinates of motion:

y   = g(x, v) + z
y'  = g_x x' + g_v v' + z'
y'' = g_x x'' + g_v v'' + z''    (3)

with stochastic observation noise z. Similarly, the generalized coordinates of the motion of the latent states are

x'   = f(x, v) + w
x''  = f_x x' + f_v v' + w'
x''' = f_x x'' + f_v v'' + w''    (4)

with stochastic transition noise w. Here, the latent states split into input (cause) states v(t) and hidden states x(t) that are part of a transition function f projecting input states to output states. From a machine learning perspective, we can directly associate the hidden states x with the hidden states or memory in Recurrent Neural Networks (RNNs) and f with the corresponding transition function in RNNs. Unlike the single-order transition function in classical RNNs, the function f couples multiple orders of dynamics.

Next to a hierarchy of the cause state dynamics, the models described in [friston2009predictive] also model a hierarchical structure of the cause states themselves:

v^{(i-1)} = g(x^{(i)}, v^{(i)}) + z^{(i)}
x'^{(i)}  = f(x^{(i)}, v^{(i)}) + w^{(i)}    (5)

where input and hidden states v^{(i)} and x^{(i)} at layer i link states and dynamics between layers. In such hierarchical models, the outputs of a layer are the inputs to the next lower layer and the fluctuations z^{(i)} and w^{(i)} influence the fluctuations of states at the next higher level.
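To make the generative direction of (3)-(5) concrete, the following sketch simulates a two-layer HDM forward in time (a hypothetical toy example in PyTorch; the particular choices of g, f, noise scales and state sizes are arbitrary and not taken from the paper):

import torch

def g(x, v):                 # observation / hierarchical mapping of a layer
    return torch.tanh(x + v)

def f(x, v):                 # equation of motion of the hidden states
    return -0.5 * x + v

dt, steps = 0.1, 100
x2, v2 = torch.zeros(4), torch.ones(4)     # hidden and cause states of the top layer
x1 = torch.zeros(4)                        # hidden states of the bottom layer
ys = []

for t in range(steps):
    v1 = g(x2, v2) + 0.01 * torch.randn(4)               # top layer generates the causes of the layer below (noise z2)
    x2 = x2 + dt * (f(x2, v2) + 0.01 * torch.randn(4))   # hidden state motion at the top layer (noise w2)
    x1 = x1 + dt * (f(x1, v1) + 0.01 * torch.randn(4))   # hidden state motion at the bottom layer (noise w1)
    ys.append(g(x1, v1) + 0.01 * torch.randn(4))         # observed data generated by the bottom layer (noise z1)

Perception, in the sense of the next subsection, is the inverse problem: recovering the states x and v at every layer from the observations alone.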

1.2 Inverting hierarchical dynamical models with Generalized Predictive Coding

The general structure of hierarchical dynamical models described above does not yet entail the specific process that actually learns parameters and states inside the model from observations. This process, called model inversion, can be described at different levels of analysis, each of which leads to different assumptions at the implementation level. A general approach to model inversion is based on variational Bayesian inference, where the conditional density over causes given model and data is approximated using a recognition density with respect to a lower bound on the evidence of the model. More detailed descriptions of hierarchical model inversion can be found, for example, in [friston2008hierarchical]. When this Bayesian inversion is done with respect to (explicitly represented) prediction errors, the process can be implemented with predictive coding. (Bayesian) predictive coding is an algorithmic motif that is plausibly implemented in biological neurons and thus provides an interesting candidate for improving machine learning systems.

Generalized predictive coding can be described as a gradient descent on precision-weighted prediction errors [friston2008hierarchical, bastos2012canonical]:

\dot{\mu}_v^{(l)} = D \mu_v^{(l)} - \partial_{\mu_v} \tfrac{1}{2}\, \varepsilon^{(l)\top} \Pi^{(l)} \varepsilon^{(l)}
\dot{\mu}_x^{(l)} = D \mu_x^{(l)} - \partial_{\mu_x} \tfrac{1}{2}\, \varepsilon^{(l)\top} \Pi^{(l)} \varepsilon^{(l)}    (6)

where l is the current layer, Π^{(l)} is the inverse variance (the precision) of the prediction errors ε^{(l)} in generalized coordinates and D denotes the derivative operator that shifts generalized coordinates. The precision-weighted prediction errors ε_v and ε_x are generated with respect to the difference between current expectations about hidden causes and states, v and x, and their predicted values. Predicted values of causes and states are computed with non-linear feedback functions g and f. Crucially, the dynamics of the hidden states in a particular layer are computed with respect to the states in the same layer. In contrast, cause states are updated with respect to predictions that map from states of the next higher layer. Effectively, this means that hidden states model the dynamics within layers, while cause states link different layers. [1]

[1] This restriction of cause and hidden states to roles in extrinsic and intrinsic connectivity comes from assumptions about the underlying graphical model, where the dynamics of different nodes can strictly be separated. In biological and artificial implementations, these boundaries might be much looser and under constant change. As we will see in the discussion of this article, the interpretation of neural (cortical) function in terms of canonical predictive coding microcircuits might be substantially less straightforward than it appears from the math [bastos2012canonical].

Existing models that implement a particular form of generalized filtering need to make specific assumptions, particularly about the nature of the modelled functions and the computation of their generalized coordinates. A popular approach is to resort to numerical differentiation with the finite difference method. Numerical differentiation, however, often struggles with rounding and discretisation problems and with the exact computation of higher-order derivatives in general. Automatic differentiation, which drives a substantial amount of state-of-the-art machine learning, solves these problems by operating on functions with known exact gradients. It is this technique that we adapt here in the context of HDMs.
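As a brief illustration of this point (a hypothetical example, not code from the paper), automatic differentiation yields exact higher-order derivatives of any function composed of known primitives, whereas finite differences accumulate discretisation error:

import torch

def traj(t):                              # a known, differentiable trajectory
    return torch.sin(3.0 * t)

t = torch.tensor(1.0, requires_grad=True)

d1 = torch.autograd.grad(traj(t), t, create_graph=True)[0]   # exact first derivative
d2 = torch.autograd.grad(d1, t)[0]                           # exact second derivative

h = 1e-3                                  # central finite-difference estimate of the second derivative
fd2 = (traj(t + h) - 2.0 * traj(t) + traj(t - h)) / h ** 2

print(d2.item(), fd2.item())              # exact vs. numerically approximated second derivative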

While several implementations of generalized filtering variants exist, they do not yet scale to complex applications such as perception of complex data or reinforcement learning problems of the kind dealt with in the deep learning domain. Here, we focus on implementing hierarchical dynamical models with update rules described by gradient-based predictive coding. Specifically, we look at Stochastic Gradient Descent (SGD) on prediction errors with exact gradients in differentiable deep neural network (DNN) models.

1.3 Differentiable predictive coding and deep neural networks

Much effort has recently been spent on developing scaled-up versions of gradient-based predictive coding networks and comparing their (relatively) biologically plausible updates with the exact gradient computation involved in the backpropagation algorithm that drives a substantial amount of current state-of-the-art machine learning [millidge2021predictivereview, rosenbaum2021relationship, bogacz2017tutorial, ofner2021predprop, rumelhart1986learning]. In particular, it has been shown that, under specific assumptions, [2] the weight updates in predictive coding networks approximate the exact gradients computed by backpropagation [millidge2020predictive, rosenbaum2021relationship, whittington2017approximation]. Similar experiments for the gradients in HDMs have so far not been conducted. Related efforts have resulted in optimization strategies that allow (deep) neural networks to be turned into gradient-based predictive coding networks [millidge2020predictive, rosenbaum2021relationship]. PredProp, a DNN optimiser proposed in [ofner2021predprop], jointly optimizes each layer's learnable parameters in parallel using exact gradients. This approach removes the need to wait for individual variables to converge (such as the fixed prediction assumption when training weight parameters) and allows for constant interaction of the estimated precision-weighted prediction errors between layers. Another aspect of PredProp is the possibility to include complex DNN architectures within the backward prediction weights of a single predictive layer, while layer-wise predictive coding [3] remains an option. The model presented here is an extension of the PredProp optimiser that focuses on adding a dynamical pathway in models with simple feedback and transition weights. Outside of gradient-based predictive coding, there are various models that include predictive coding inspired mechanisms in state-of-the-art deep neural architectures, although often with substantial differences to the mechanisms described by predictive coding theory [lotter2016deep, millidge2021predictivereview, rane2020prednet]. Detailed reviews of predictive coding variants and their connection to established methods in machine learning, such as variational autoencoders and normalizing flows, can be found in [millidge2021predictivereview, marino2020predictive].

[2] These assumptions include fixing predictions during weight updates or the inversion of data and target inputs to the model, so-called "discriminative" predictive coding.
[3] Which is arguably more biologically plausible due to the significantly reduced computational complexity in the connections between state variables.

2 Inferring content and dynamics with generalised predictive coding

We introduce an optimisation method based on Stochastic Gradient Descent, called GPC, that implements Generalised Predictive Coding with exact gradients along the hierarchical and dynamical dimensions, while still offering the possibility to estimate the true dynamics underlying raw sensory data. GPC augments an existing predictive coding based optimisation method called PredProp [ofner2021predprop] and employs exact gradients to optimize inferred states, weights and prediction error precision in a layer-wise fashion and in parallel (i.e. without having to wait for other variables to converge) for arbitrary neural network architectures. Figure 1 summarizes the main components of the proposed optimisation scheme in a predictive coding model with two layers.

Figure 1: a) Hierarchical and dynamical predictions in Differentiable GPC. Blue and green colors indicate dynamical and hierarchical pathways respectively. Dotted arrows indicate precision-weighted prediction errors that drive local updates of learnable parameters. Here, cause states and hidden states are summarized as states in layer l. b) Different levels of causes can be decoupled from their dynamics when higher layers predict inputs (green) but do not predict their dynamics (blue). For example, the motion of flying birds might be associated with the relatively static concept of nature as a cause.

For the sake of completeness, we summarize the core functionality of PredProp here; for a more detailed description, the reader is referred to [ofner2021predprop]. PredProp implements stochastic gradient descent on the variational Free Energy. Prediction errors at each layer, weighted by their precision (inverse covariance), are minimized for all layers individually.

More precisely, under the Laplace assumption, the variational Free Energy simplifies to a term based on precision-weighted prediction errors, which can be written as:

F = \sum_{l} \tfrac{1}{2} \left( \varepsilon_l^{\top} \Pi_l\, \varepsilon_l - \ln |\Pi_l| \right)    (7)

with prediction errors between the observed activity x_l and the predicted mean activity at layer l:

\varepsilon_l = x_l - \hat{x}_l = x_l - f(\theta_{l+1}\, x_{l+1})    (8)

More details on the variational Free Energy and the Laplace assumption in the context of predictive coding are discussed in [friston2006free, millidge2021predictivereview, bogacz2017tutorial].
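For concreteness, a single layer's contribution to (7) can be written in a few lines. The sketch below is illustrative only (arbitrary layer sizes, a tanh activation and a full precision matrix); PredProp itself is described in [ofner2021predprop]:

import torch

x_lower = torch.randn(8)                       # observed activity of the lower layer
x_upper = torch.zeros(16, requires_grad=True)  # inferred state of the current layer
W = torch.randn(8, 16, requires_grad=True)     # backward prediction weights
Pi = torch.eye(8)                              # prediction error precision (inverse covariance)

pred = torch.tanh(W @ x_upper)                 # backward prediction of the lower layer's activity
err = x_lower - pred                           # prediction error, cf. (8)
F_layer = 0.5 * (err @ Pi @ err - torch.logdet(Pi))   # layer-wise free energy term, cf. (7)
F_layer.backward()                             # exact gradients w.r.t. states and weights via autodiff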

Effectively, a cost function is optimized for each layer in parallel. As predictive coding layers are linked via shared cause states, the summed cost function of the entire network is optimized simultaneously. PredProp optimises networks bidirectionally, so propagated values constantly iterate between inferring the most likely cause (i.e. target) of observed data and the most likely data given the currently estimated cause. Each predictive coding layer has one or multiple weight matrices θ_l parameterizing the backwards connection, followed by a (non-)linear activation function f. Theoretically, these multi-layer backward connections can incorporate entire DNN structures, since exact gradients can be computed with automatic differentiation.

In a single-layer predictive coding network, the observed activity is simply the input data. Layers in multi-layer predictive coding networks predict the activity of the next lower layer. The backward predictions of the next lower layer are computed using a (non-)linear activation function f and learnable weights θ_l. The resulting update at time t for the backward weights at layer l is

\theta_{l,t+1} = \theta_{l,t} - \alpha_{\theta}\, \frac{\partial F}{\partial \theta_l} = \theta_{l,t} + \alpha_{\theta} \left( \Pi_{l-1}\, \varepsilon_{l-1} \odot f'(\theta_l x_l) \right) x_l^{\top}    (9)

Next to updating the weights (i.e. parameter learning), PredProp performs a simultaneous gradient descent on activity (i.e. state inference) and precision in each layer directly. The corresponding update rules for the activity x_l and the precision Π_l at layer l are:

x_{l,t+1} = x_{l,t} - \alpha_x\, \frac{\partial F}{\partial x_l}
\Pi_{l,t+1} = \Pi_{l,t} - \alpha_{\Pi}\, \frac{\partial F}{\partial \Pi_l} = \Pi_{l,t} - \frac{\alpha_{\Pi}}{2} \left( \varepsilon_l \varepsilon_l^{\top} - \Pi_{l,t}^{-1} \right)    (10)

As discussed in more detail in [millidge2021predictivereview, ofner2021predictive], the estimated precision of the prediction errors plays an important role in predictive coding, since it captures second-order changes in the partial objective functions and weights the parameter updates accordingly.
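A minimal sketch of these parallel updates is given below (assuming simple linear backward weights, a tanh activation and full precision matrices; this illustrates the update scheme under those assumptions and is not the reference PredProp implementation):

import torch

def gpc_step(x, W, Pi, lr=1e-2):
    """One parallel update of states, weights and precisions across all layers.
    x:  list of layer activities; x[0] holds the observed data and stays fixed
    W:  list of backward weights; W[l] predicts layer l from layer l + 1
    Pi: list of prediction error precisions, one per predicted layer
    """
    params = x[1:] + W + Pi
    for p in params:
        p.requires_grad_(True)

    F = 0.0                                    # summed Laplace free energy, cf. (7)
    for l in range(len(W)):
        err = x[l] - torch.tanh(W[l] @ x[l + 1])
        F = F + 0.5 * (err @ Pi[l] @ err - torch.logdet(Pi[l]))

    grads = torch.autograd.grad(F, params)     # exact gradients for every variable at once
    with torch.no_grad():
        for p, g in zip(params, grads):
            p -= lr * g                        # simultaneous SGD on states, weights and precisions
            # (a practical implementation additionally keeps Pi positive definite)
    return float(F)

A two-layer example would call gpc_step repeatedly with x = [data, torch.zeros(16), torch.zeros(4)], two weight matrices of matching shapes and two identity precisions, until the prediction errors settle.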

2.1 From static to dynamic predictive coding models

In contrast to PredProp, which has so far been applied to static perception tasks, GPC addresses changes in observed inputs. Like many deep neural architectures, GPC operates on discrete time steps. However, the transitions inside the model have adaptive interval sizes, which are modulated by each layer individually in order to improve its own predictions as well as its top-down predictability. Since deeper dynamical layers in GPC encode known analytic functions via their feedback weights, discrete changes over such adaptive intervals can be encoded efficiently. Importantly, at each hierarchical layer, multiple instances operate in parallel and randomly sample new intervals modulated by the precision of prediction errors. Figure 2 shows the learnable variables that are relevant to the transition function at an individual predictive coding layer.

Each layer estimates a state x_{l,t} at discrete time t for time increments Δt of arbitrary size. Additionally, each layer learns transition weights θ_{d,l} that parameterize the state transition function

x_{l, t+\Delta t} = f_d(\theta_{d,l}\, x_{l,t}), \qquad \Delta t \sim p(\Delta t \mid w_l)    (11)

where w_l is the stochastic noise driving the selection of sampling intervals (the step size or stride in the context of DNN models with discrete time steps).

Weights θ_{h,l} parameterize the hierarchical prediction function f_h:

\hat{x}_{l-1, t} = f_h(\theta_{h,l}\, x_{l,t})    (12)
Figure 2: Temporal transitions and associated dynamical prediction errors occurring in the transition function at each dynamical layer of the predictive coding network.

The functions f_d and f_h parameterize a single connection along the dynamical and hierarchical pathway respectively. Layers along the hierarchical dimension are connected in that each layer's hierarchical prediction is the expected state at the next lower layer. Neighboring layers along the dynamical dimension are connected in that each layer predicts the derivative of the transition function of the next lower layer. The discrete time increment Δt plays an important role, since it influences the estimated derivative of the observed activity by paying attention to changes across different temporal scales.

This estimated derivative is approximated by the transition weights θ_{d,l}. Since the learned transition function is known (in contrast to the function generating the observed data), so is its exact gradient. In the dynamical pathway, neighboring layers predict this gradient of the next lower layer. Alternatively to computing the gradient, it is possible to simply predict the change that the transition function applies to known state inputs:

\Delta x_{l,t} = f_d(\theta_{d,l}\, x_{l,t}) - x_{l,t}    (13)

Here, we use that Δx_{l,t}/Δt approximates the derivative x'_{l,t}, i.e. we require the transition weights to be estimated with respect to a known time increment Δt. Initially, Δt is selected randomly based on each layer's prior on the sampling interval fluctuation w_l. Since multiple instances of a layer can predict the same input, different sampling intervals are covered at any point in time. The interval is known in the local context of a layer, but is not propagated to distant layers. As each layer's states are predicted top-down, the optimal inferred sampling interval balances a low prediction error for the outgoing prediction with the incoming top-down prediction.
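Since the transition function is an explicit, differentiable mapping, both the applied state change of (13) and its exact Jacobian are available through autodiff. A small sketch (hypothetical transition weights and a tanh transition, chosen only for illustration):

import torch

theta_d = 0.1 * torch.randn(8, 8)          # hypothetical learned transition weights of one layer
x_t = torch.randn(8)                       # current state of the layer
dt = 0.5                                   # sampled time increment (stride)

def transition(x):
    return x + dt * torch.tanh(theta_d @ x)   # discrete transition over an interval of size dt

x_next = transition(x_t)
delta_x = x_next - x_t                     # change applied by the transition, cf. (13)
x_dot = delta_x / dt                       # derivative estimate that the next dynamical layer predicts

J = torch.autograd.functional.jacobian(transition, x_t)   # exact Jacobian of the known transition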

Dynamical and hierarchical pathways are mutually interacting:

(14)

At discrete time step t, the total prediction error in each layer is the sum of dynamical and hierarchical prediction errors:

\varepsilon_l = \varepsilon_{h,l} + \varepsilon_{d,l} = \left( x_{l,t} - \hat{x}_{l,t} \right) + \left( x_{l,t+\Delta t} - x_{l,t} - \hat{x}'_{l,t} \right)    (15)

where \hat{x}_{l,t}, x_{l,t+\Delta t} and \hat{x}'_{l,t} refer to the hierarchically predicted state, the transitioned state and the predicted state change of the last state transition respectively.

Given this objective function for each predictive coding layer, we can apply the optimised network to various learning and inference tasks. Optimising multiple sampling intervals in parallel is thought to reduce the need for more complex autoregressive modelling by trading computational complexity within single layers for increased parallelism. To test this hypothesis, a first step is to apply the network to the prediction of sequential data. It should be noted that, although we focus on sampling temporal inputs, the same method can be applied to spatial dimensions. The following sections evaluate dynamics learning from sequential data, with and without sensory noise and non-linear activation functions.
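The effect of parallel sampling intervals can be previewed with a simple numerical sketch (hypothetical; the constant-velocity predictor and the selection rule are deliberate simplifications of the precision-modulated sampling described above):

import torch

seq = torch.sin(0.1 * torch.arange(0, 500, dtype=torch.float32))   # observed 1-D sequence
strides = [1, 5, 10]                                                # candidate sampling intervals

errors = {}
for s in strides:
    sampled = seq[::s]
    pred = sampled[1:-1] + (sampled[1:-1] - sampled[:-2])   # extrapolate the last observed change
    errors[s] = torch.mean((sampled[2:] - pred) ** 2).item()

best = min(errors, key=errors.get)   # interval whose predictions explain the sequence best

In GPC, several instances of a layer take the role of the loop over strides and run concurrently, with the precision of their prediction errors deciding which instance dominates the bottom-up and top-down message passing.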

3 Learning dynamics from sequential data

3.1 Learning dynamics with linear and non-linear activation

With linear activations, GPC infers the observed state and its first two derivatives relatively exactly. When a ReLU non-linearity is used, the represented function has less complex generalised coordinates, but still matches the observed state well. Figure 3 shows examples of dynamical state estimation on sequential inputs with and without non-linear activation functions.

Figure 3: Approximated representation of a sine function in a GPC model with three hidden dynamical layers (black and green lines). The feedback predictions are linearly activated (top) or ReLU activated (bottom). Prediction errors from the transition functions, as well as their top-down predictions, are indicated in red and yellow respectively.

3.2 Learning dynamics with parallel sampling

With changing interval size (stride), the inferred generalised coordinates change significantly. In the example shown in Figure 4, a stride of 10 leads to a nearly stationary sampled input, so the second represented derivative remains almost constant. For strides other than 5 or 10, the prediction error is substantially higher.

Figure 4: Dynamical state estimation on a modulated sine wave sampled at three different strides (1,9,10) in the lowest predictive coding layer.

4 Relation to the canonical microcircuit for predictive coding in biological networks

Figure 5: Comparison of Differentiable Generalized Predictive Coding with the canonical microcircuit for predictive coding that is hypothesized to drive Bayesian inference in the human brain.

A lot of effort has been spent on describing canonical microcircuits that capture the core functionality of neuronal processing in human brains. In the context of interpreting neuronal computations as performing Bayesian inference, multiple similar models of canonical computations have been suggested that show close correspondence to the connectivity described by generalised predictive coding [douglas1991functional, douglas2004neuronal, douglas1989canonical, bastos2012canonical].

Figure 5 shows a comparison between GPC, the Differentiable Generalised Predictive Coding model suggested here, and the canonical microcircuit (CMC) described by Bastos et al. [bastos2012canonical]. When comparing the structure of a single predictive coding layer l, there is a clear correspondence between cause states, hidden states and associated prediction errors in biological and artificial networks. In the context of cortical function, a layer in the CMC sends forward prediction errors to higher cortical areas and backward predictions towards lower areas. While this structure is visible in GPC, it lacks the explicit differentiation of states into superficial and deep pyramidal cells. In the CMC, this differentiation can be connected to different sampling frequencies associated with incoming and outgoing information [bastos2012canonical].

The comparison with canonical circuitry and concrete neurophysiological evidence provides a useful starting point for a more in-depth analysis of the biological plausibility of computations in deep predictive coding networks (and in DNNs in general). One striking aspect of the CMC is that feedforward information is typically propagated via linear connections, while backward predictions feature non-linear connectivity. GPC and PredProp reflect this differentiation by including DNNs in the backward weights, while keeping the precision and transition pathways linear. [4] Finally, comparing with the biological interpretation in the CMC allows inhibitory and excitatory connectivity to be addressed explicitly, jointly with the associated sampling frequencies in deep predictive coding.

[4] PredProp, however, theoretically allows non-linearities and complex autoregressive DNN models to be included in the remaining connections between state variables.

The hypothesized canonical circuitry provides an important bridge between the functional and computational levels of analysis of predictive coding and its concrete biological implementation. Similarly, GPC provides a concrete implementation of generalised predictive coding in the context of deep neural networks.

One interesting line of research could address more exact modelling of the CMC and its implications for performance in machine learning tasks. Another, quite elegant way of characterising aspects of functional modularity in concrete implementations is to define functional modules in terms of statistical independence, i.e. by their Markov blankets. We will use this approach in the following section in order to analyse differences between the probabilistic graphical model describing generalized predictive coding and its concrete DNN-based implementation in GPC.

4.1 From neurons to graphical models via Markov blankets

Figure 6: Left: Neuronal Markov blankets describe the statistical boundaries that separate individual neurons from their surroundings. Regional Markov blankets characterize the effective connectivity between cortical layers based on canonical microcircuits for predictive coding. Right: Similarly, regional Markov blankets can be mapped to the connectivity between predictive coding layers in Differentiable Generalised Predictive Coding.

A relatively recent line of research addresses the formalisation of dynamical coupling in biological systems across (spatial) scales via Markov Blankets [friston2019free, ramstead2018answering, palacios2020markov, kirchhoff2018markov, pearl1998graphical]. Markov blankets describe a partitioning of complex dynamical systems in terms of statistical boundaries. When applied to characterising functional modules in biological brains, this entails partitions appearing at multiple scales, such as at the level of individual neurons, at the level of canonical micro-circuitry or at the level of brain regions and larger networks [hipolito2021markov]. Markov blankets are an effective way to characterise statistical modularity in complex systems and can directly be mapped to the effective connectivity described by the model suggested here. Figure 6 shows a simplified overview of Markov blankets at the neuronal and micro-circuit level, next to a mapping to the connectivity structure of GPC. While we characterize Markov Blankets at the regional level here and focus on relatively shallow GPC models, future work might characterize Markov Blankets appearing at larger spatio-temporal scales in more complex implementations.

Crucially, Markov blankets divide the states of a system into external, sensory, active and internal states [hipolito2021markov]. Of these, sensory and active states can be summarized as blanket states, which shield the internal and external partitions from each other. Internal states are updated based on internal and blanket states, but are independent of external states. Similarly, active states do not depend upon external states. As shown in Figure 6, at the level of canonical micro-circuits one can separate neighboring layers in GPC into internal and external states, where information is propagated exclusively through active and sensory states, i.e. the state predictions and prediction errors respectively. An important aspect of Markov blankets is that the described modularity is characterised exclusively by statistical dependencies and is not directly dependent on the physical structure of the system.

This differentiation between a system's structure in terms of Markov blankets and its implementation can be applied in a straightforward fashion to artificial systems that are modelled to mimic particular biological functions or are tailored to a particular graphical model. The computations in the model presented here are based on the graphical model described by Generalised Predictive Coding. As visible in Figure 1, the implementation of the graph described by Generalised Predictive Coding seems to require exact connectivity to maintain a valid structure. For example, interconnecting variables within the dynamical and hierarchical pathways intuitively seems to change the implemented graphical model. Future work, however, could relax the assumption of such a strict mapping between connectivity in the graphical model and the implementation of GPC by allowing more flexible (i.e. random up to a certain degree) connectivity between variables. We can then characterize particular nodes in the model via their Markov blankets instead of a one-to-one mapping from the functional to the implementational level.

5 Planning with Generalised Predictive Coding

The hierarchical-dynamical transition model learned with GPC can be used for planning tasks in a relatively straightforward manner. While arbitrarily complex algorithms could be employed on top of the learned predictive model, here we focus on planning similar to recent formulations of Active Inference agents that optimise the Expected Free Energy with respect to policies over discrete actions [ueltzhoffer2018deep, ccatal2020learning, van2020deep, fountas2020deep]. Policies π are sequences of actions a_τ at discrete time steps τ. In many reinforcement learning algorithms, one would optimise an expected return over sequences of actions a_τ and observed rewards r_τ:

R(\pi) = \mathbb{E}\left[ \sum_{\tau = t}^{T} r_\tau \right]

up to a planning horizon T. Typically, reward-based methods require additional regularization that enables exploration in order to improve generalization. Such exploration often refers to testing new policies that might not be connected to (immediately) optimal reward and is typically modelled by including a random sampling process or some form of expected information gain.

The approach in active inference is slightly different and does not directly operate on a known reward function. [5] Active inference models optimise the Expected Free Energy (EFE), i.e. the Free Energy of expected observations under planned policies. Instead of resorting to external rewards, we can define actions over adaptive time intervals and the associated Expected Free Energy:

G(\pi) = \sum_{\tau = t}^{T} \mathbb{E}_{q(o_\tau, x_\tau \mid \pi)}\left[ \ln q(x_\tau \mid \pi) - \ln p(o_\tau, x_\tau \mid \pi) \right]    (16)

for N parallel policies with discrete actions over planning horizons of adaptive size. The adaptive sampling intervals in GPC imply that the temporal [6] interval size for each action depends on the variational Free Energy of the previously inferred state.

[5] External reward, however, can be integrated by defining a prior preference for maximising observed reward.
[6] Alternatively, the spatial interval size, when planning is done with respect to spatial dimensions. In hidden layers, the adaptive sampling size refers to the predicted motion of causes represented at the respective lower layer.

We can interpret this EFE objective as a sort of performance measure and simply optimise it with Stochastic Gradient Descent, similar to existing work on active inference interpreted as a policy gradient method [millidge2020deep].
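A minimal sketch of this idea is given below (hypothetical throughout: a toy one-dimensional transition model, a softmax policy over three discrete actions, and only the risk term of (16), i.e. the divergence of predicted states from preferred states, without the information-gain term):

import torch

def efe_rollout(policy_logits, transition, preferred, x0, horizon=5):
    """Risk-only EFE estimate for a parameterized policy over discrete actions."""
    probs = torch.softmax(policy_logits, dim=-1)
    x, efe = x0, 0.0
    for tau in range(horizon):
        x = sum(p * transition(x, a) for a, p in enumerate(probs))   # expected next state under the policy
        efe = efe + torch.sum((x - preferred) ** 2)                  # divergence from preferred (prior) states
    return efe

transition = lambda x, a: x + 0.1 * (a - 1.0)        # toy dynamics: actions 0, 1, 2 move the state down, keep it, move it up
policy_logits = torch.zeros(3, requires_grad=True)
opt = torch.optim.SGD([policy_logits], lr=0.1)

for _ in range(100):                                  # optimise the EFE-like objective by SGD
    opt.zero_grad()
    loss = efe_rollout(policy_logits, transition,
                       preferred=torch.tensor(1.0), x0=torch.tensor(0.0))
    loss.backward()
    opt.step()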

6 Nested hierarchies and perceptual actions

(a) A predictive coding network with four local nodes and one distant node, with nested hierarchical-dynamical structure.
(b) Example trajectory with redundancy that can be reduced during perceptual model reduction in mental simulation.
Figure 7: Perception and planning in generalized predictive coding networks with learned connectivity.