PredProp: Bidirectional Stochastic Optimization with Precision Weighted Predictive Coding

11/16/2021
by André Ofner, et al.

We present PredProp, a method for bidirectional, parallel and local optimisation of weights, activities and precision in neural networks. PredProp jointly addresses inference and learning, scales learning rates dynamically and, by optimizing prediction error precision, weights gradients by the curvature of the loss function. PredProp optimizes network parameters with Stochastic Gradient Descent and error forward propagation based strictly on prediction errors and variables locally available to each layer. Neighboring layers optimise shared activity variables, so that prediction errors propagate forward in the network while predictions propagate backwards. This process minimises the variational Free Energy, i.e. maximises the evidence lower bound, of the entire network. We show that networks trained with PredProp resemble gradient-based predictive coding when the number of weights between neighboring activity variables is one. In contrast to related work, PredProp generalizes towards backward connections of arbitrary depth and optimizes precision for any deep network architecture. Due to the analogy between prediction error precision and the Fisher information of each layer, PredProp implements a form of Natural Gradient Descent. When optimizing DNN models, layer-wise PredProp renders the model a bidirectional predictive coding network. Alternatively, DNNs can parameterize the weights between two activity variables. We evaluate PredProp for dense DNNs on simple inference, learning and combined tasks. We show that, without an explicit sampling step in the network, PredProp implements a form of variational inference that allows learning disentangled embeddings from small amounts of data. We leave evaluation on more complex tasks and datasets to future work.



1 Introduction

In the context of machine learning, neural networks are often trained by updating a set of parameters $p$ in order to optimize an objective function $L(x)$. A popular method for this is the Stochastic Gradient Descent (SGD) algorithm, an iterative procedure that applies changes to the parameters based on randomly sampled batches from a dataset [robbins1951stochastic]. SGD optimizes the objective function by following the direction of steepest descent indicated by the negative of the gradient, a local estimate of the direction that maximally decreases the cost. Under the assumption that derivatives with respect to the parameters can be computed, this results in simple update rules for the parameters at discrete steps of the form:

$p_{t+1} = p_t - \alpha \, \nabla_p L(p_t)$   (1)

where $\nabla_p L(p_t)$ is the gradient with respect to the parameters at step $t$ and $\alpha$ is a learning rate that modulates the size of the step.
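As a minimal illustration (our own, not from the paper), the update in Equation (1) in NumPy on a toy quadratic loss:

```python
import numpy as np

def sgd_step(params, grad, lr=0.1):
    """Plain SGD update: p_{t+1} = p_t - lr * grad (Equation 1)."""
    return params - lr * grad

# Toy example: loss L(p) = ||p||^2 / 2, whose gradient is simply p.
p = np.array([1.0, -2.0])
p = sgd_step(p, grad=p)  # one step towards the minimum at the origin
```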

While gradient computation and parameter updating are usually simple, finding the correct hyperparameters, particularly when training deep neural networks, requires manual tuning. With respect to the ability to generalize to new data, complex regularization strategies have been developed; these, however, typically require manual tuning as well. Gradient Descent in combination with backpropagation of errors in deep neural networks (DNNs) is efficient, but often struggles when noise is present in the data or in the temporal evolution of the computed gradients. Furthermore, SGD-based training of NNs typically focuses on optimising only the weight parameters, thereby requiring successfully learned weights to perform inference. This renders SGD in combination with (deep) neural networks particularly suited for learning from entire datasets, but less so for inference on few trials.

Predictive coding is a theory that originates from cognitive neuroscience and aims at explaining brain function. It offers a description of neural operations that solve many of these issues using a relatively simple algorithmic motif based on prediction error minimisation and bidirectional processing [friston2009predictive].

While reviewed in more detail elsewhere in the literature, under the Laplace assumption, the updates in predictive coding networks depend on a Free Energy cost function F, also known as the negative evidence lower bound (ELBO) [millidge2021predictivereview]:

$F = \sum_{l} \tfrac{1}{2} \left( \epsilon_l^{\top} \Pi_l \, \epsilon_l - \ln \det \Pi_l \right)$   (2)

that depends on the prediction errors $\epsilon_l$ between the observed activity $x_l$ and the predicted mean activity at each layer $l$:

$\epsilon_l = x_l - f(W_{l+1}\, x_{l+1})$   (3)

In contrast to activations, neural activities here refer to inferred states in the general sense, which might be modified directly through the propagated error or through a backward activation. Using gradient descent on this Free Energy, a multi-layer NN gets optimised when the prediction errors $\epsilon_l$ at each layer, weighted by their precision (or inverse variance) $\Pi_l$, are minimized.

In a single-layer predictive coding network, the observed activity simply is the data. Layers in multi-layer predictive coding networks predict the activity $x_{l-1}$ of the next lower layer. The backward predictions of the next lower layer are computed using a (non-)linear activation function $f$ and learnable weights $W_l$.

In contrast to feedforward NNs, each layer of such a predictive coding network thus has an individual cost function. The resulting update at time $t$ for the backward weights $W_l$ at layer $l$ is

$W_l^{(t+1)} = W_l^{(t)} + \alpha_W \left( \bar{\epsilon}_{l-1} \odot f'(W_l\, x_l) \right) x_l^{\top}$   (4)

which depends on the precision-weighted prediction error $\bar{\epsilon}_{l-1} = \Pi_{l-1}\, \epsilon_{l-1}$ at the predicted layer $l-1$ as well as the activity $x_l$ at the layer that generates the prediction.

However, the Free Energy can be minimized not only by updating the backward weights, i.e. through learning, but also by updating the activity and precision units in each layer directly, i.e. through inference and uncertainty estimation.

The corresponding update rule for activities at layer l is:

$x_l^{(t+1)} = x_l^{(t)} + \alpha_x \left( -\bar{\epsilon}_l + W_l^{\top} \left( \bar{\epsilon}_{l-1} \odot f'(W_l\, x_l) \right) \right)$   (5)

Similarly, the update rule for each layer’s precision is:

$\Pi_l^{(t+1)} = \Pi_l^{(t)} + \tfrac{\alpha_\Pi}{2} \left( \Pi_l^{-1} - \epsilon_l\, \epsilon_l^{\top} \right)$   (6)

where $\bar{\epsilon}_l = \Pi_l\, \epsilon_l$ are the precision-weighted prediction errors at layer $l$.
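To make the layer-local updates concrete, here is a minimal NumPy sketch of Equations (3)-(6) for a single layer; the function and variable names, the learning rates and the linear activation are our illustrative assumptions, not the reference implementation:

```python
import numpy as np

def pc_layer_update(x_l, x_below, mu_l, W_l, Pi_l, Pi_below,
                    lr_x=0.1, lr_W=0.01, lr_Pi=0.01):
    """One local predictive coding iteration around layer l,
    with a linear activation f(a) = a for brevity (Equations 3-6)."""
    # Top-down prediction error at layer l and the error that layer l's
    # own prediction causes at the layer below (Equation 3)
    eps_l = x_l - mu_l                   # mu_l: prediction arriving from layer l+1
    eps_below = x_below - W_l @ x_l      # prediction of the layer below

    # Precision-weighted prediction errors
    eps_bar_l = Pi_l @ eps_l
    eps_bar_below = Pi_below @ eps_below

    # Weight update (Equation 4): error and precision of the predicted layer,
    # activity of the predicting layer
    W_new = W_l + lr_W * np.outer(eps_bar_below, x_l)

    # Activity update (Equation 5): balance the top-down error against the
    # bottom-up error propagated through the transposed weights
    x_new = x_l + lr_x * (-eps_bar_l + W_l.T @ eps_bar_below)

    # Precision update (Equation 6): large errors lower the estimated precision
    Pi_new = Pi_l + 0.5 * lr_Pi * (np.linalg.inv(Pi_l) - np.outer(eps_l, eps_l))

    return x_new, W_new, Pi_new
```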

By explicitly tracking the precision (or inverse covariance) of the activities through these simple updates, predictive coding networks constantly estimate the second-order structure of the partial objective functions, i.e. of the prediction errors, across layers. As explained in more detail in [millidge2021predictivereview], the precision matrix of the activities directly encodes the Fisher information matrix of the layer's error:

$\mathcal{F}_{x_l} = \mathbb{E}\!\left[ \bar{\epsilon}_l\, \bar{\epsilon}_l^{\top} \right] = \Pi_l$   (7)

The Fisher information is used, for example, in Natural Gradient Descent as an adaptive learning rate that rescales the computed gradient according to the curvature of the objective function. The Fisher information with respect to the weights can be expressed as the layer's estimated precision of prediction error, scaled by the expected value of the activities at the next higher layer:

$\mathcal{F}_{W_l} = \Pi_{l-1} \otimes \mathbb{E}\!\left[ x_l\, x_l^{\top} \right]$   (8)

In summary, predictive coding can be described as a gradient descent scheme that computes local second order estimates separately for each layer. The locally informed gradients from this update scheme (when exchanging data and target inputs, so-called discriminative predictive coding) have previously been shown to approximate those of the backpropagation algorithm, where partial derivatives are computed for entire NNs. This means that we now have an optimisation method that jointly addresses inference and learning, scales learning rates dynamically and weights gradients by the curvature of the loss function.
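To illustrate the effect that Equation (7) attributes to precision weighting, the following toy sketch (our own, with a diagonal precision for readability) shows how the same raw prediction error is rescaled dimension by dimension according to the estimated reliability of that error:

```python
import numpy as np

# Raw prediction error at some layer
eps = np.array([0.5, -1.0])

# Estimated diagonal precision: dimension 0 has low error variance (reliable),
# dimension 1 has high error variance (unreliable)
Pi = np.diag([4.0, 0.25])

# Precision-weighted error used in the PredProp updates (Equations 4 and 5):
# reliable error dimensions contribute larger steps, noisy ones smaller steps,
# which is the curvature-aware rescaling associated with natural gradients.
eps_bar = Pi @ eps
print(eps_bar)   # [ 2.0, -0.25]
```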

We argue that, in order to make this optimisation scheme accessible to a larger audience and to scale predictive coding to the complexity of state-of-the-art DNN systems, there is a need for an efficient predictive coding optimiser that computes analytical derivatives locally for each layer and makes it easy to implement complex neural networks with arbitrary connectivity and nonlinearities. The next sections describe the proposed "Predictive Coding Optimizer for Stochastic Gradient Descent with Local Error Propagation", called PredProp, and evaluate learning and inference in comparison to related approximate second-order methods.

Figure 1: The PredProp optimization method based on automatic differentiation and convex optimisation from stochastic samples. Variables are optimised in parallel using only the information that is locally available between layers. The gradients of the error backpropagation algorithm can be obtained by reversing the network (targets become inputs) and computing only the gradients with respect to the weights. While the backward prediction pass through the network's weights is entirely deterministic, the predictive coding updates of all optimised parameters implement variational inference.

2 Algorithm

See Algorithm 1 for pseudo-code of the proposed method. Our optimization method computes exact local gradients with respect to a single predictive coding layer $l$. When the deepest layer is provided with a label input instead of an inferred activity, the prediction error in this layer is a label prediction error. Similarly, the prediction error in the lowest layer turns into a data prediction error with respect to the observed data. The described update scheme operates bidirectionally, so that the predictive coding network constantly iterates between inferring the most likely cause (c.f. target) of observed data and the most likely data given the currently estimated cause. Each predictive coding layer has one or multiple stacked weight matrices parameterizing the backward connection $W_l$, followed by a (non-)linear activation function $f$. In some cases, these multi-layer backward connections might incorporate entire DNN structures, without any modifications to the algorithm being necessary.

The algorithm proceeds by iteratively computing a single backward prediction pass followed by forward passes of precision-weighted prediction errors. Compared to the backpropagation algorithm, errors and exact gradients are computed in the same form, but are propagated in the reverse direction, i.e. from the data (lower layers) towards higher layers near the target input.

In the backward prediction pass, each layer projects its currently inferred activity towards the corresponding lower layer. The backward connections project activities from the next higher predictive coding layer $l+1$ to the current layer $l$. Both the current and the next layer's activities $x_l$ and $x_{l+1}$ are optimised with respect to the precision-weighted prediction error $\bar{\epsilon}_l$. These parameters, jointly with the estimated inverse covariance, or precision, at the current layer, are optimised iteratively, using separate learning rates for activities, backward weights and precision. Setting the weights learning rate to zero effectively forces the optimised network to rely only on inference to minimise the objective function. This special case will be discussed in the experiments section in more detail. For all following experiments, we used the ReLU activation function in the hidden layers and linear activations in the output layer [nair2010rectified], with a fixed size of 256 hidden units and initial learning rates set to . For all unsupervised experiments, we set the number of cause units in the deepest predictive coding layer to 10.
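As a rough illustration of the procedure described in this section, the following PyTorch sketch optimises activities, backward weights and a diagonal precision of a small predictive coding stack with separate SGD optimisers and purely layer-local free energy terms. The toy architecture, learning rates, linear predictions and diagonal precision parameterisation are our assumptions rather than the exact PredProp implementation:

```python
import torch

torch.manual_seed(0)

# Toy stack: data layer (4 units) <- hidden layer (8) <- cause layer (10)
dims = [4, 8, 10]
x_data = torch.randn(dims[0])

# Inferred activities (the data layer is clamped to the observation),
# backward weights W[l] mapping layer l+1 to layer l, and diagonal log-precisions
x = [x_data] + [torch.zeros(d, requires_grad=True) for d in dims[1:]]
W = [(0.1 * torch.randn(dims[l], dims[l + 1])).requires_grad_()
     for l in range(len(dims) - 1)]
log_pi = [torch.zeros(dims[l], requires_grad=True) for l in range(len(dims) - 1)]

# Separate learning rates for activities, weights and precision
opt_x = torch.optim.SGD(x[1:], lr=0.1)
opt_W = torch.optim.SGD(W, lr=0.01)
opt_pi = torch.optim.SGD(log_pi, lr=0.01)

for step in range(200):
    opt_x.zero_grad()
    opt_W.zero_grad()
    opt_pi.zero_grad()
    F = torch.zeros(())
    for l in range(len(dims) - 1):
        pred = W[l] @ x[l + 1]      # backward prediction of layer l (linear for brevity)
        eps = x[l] - pred           # layer-local prediction error
        pi = torch.exp(log_pi[l])   # diagonal precision
        # Layer-local free energy: precision-weighted error minus log-precision term
        F = F + 0.5 * (pi * eps ** 2).sum() - 0.5 * log_pi[l].sum()
    F.backward()                    # exact local gradients via autograd
    opt_x.step()
    opt_W.step()
    opt_pi.step()
```

Because neighboring layers share the activity variables, the layer-local free energy terms couple automatically and errors effectively propagate forward while predictions propagate backward.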

3 Discussion

3.1 Relation to other optimizers

The Natural Gradient Descent algorithm computes the Fisher information in its analytical form [amari1997neural]. It is a highly attractive optimization method due to its speed and stability of convergence [rattray1998natural, kunstner2019limitations]. However, computing the partial derivatives for all combinations of network parameters is often not feasible in practice. PredProp circumvents the computation of the exact Fisher matrix with respect to all parameters in the model by resorting to significantly smaller per-layer estimates of the Fisher information. Other optimization methods for neural networks typically approximate second-order statistics with simpler and less computationally expensive approaches. Adam and RMSProp approximate the second-order derivatives by computing the variance of the gradient [kingma2014adam]. Other approaches compute approximate or exact second-order derivatives directly, e.g. [zeiler2012adadelta] and [yao2020adahessian]. All of these methods share the adaptive learning rate with PredProp but do not address bidirectional or distributed learning.
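For comparison, here is a minimal sketch (our own, with illustrative hyperparameters) of the diagonal second-moment estimate that RMSProp and Adam maintain; instead of this running variance of the gradient, PredProp tracks an explicit prediction error precision per layer:

```python
import numpy as np

def rmsprop_step(param, grad, v, lr=1e-3, beta2=0.999, eps=1e-8):
    """Diagonal second-moment estimate used by RMSProp/Adam:
    a running average of the squared gradient rescales the step size."""
    v = beta2 * v + (1.0 - beta2) * grad ** 2
    param = param - lr * grad / (np.sqrt(v) + eps)
    return param, v

p, v = np.zeros(3), np.zeros(3)
p, v = rmsprop_step(p, grad=np.array([0.1, -2.0, 0.5]), v=v)
```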

3.2 Relation to variational inference

As has been shown before, the presented optimization scheme underlying predictive coding can directly be interpreted as performing a particular form of variational inference when certain assumptions, such as the Laplace approximation, hold [buckley2017free]. In such cases, for Gaussian distributions, only updating the mean of the activities is required. This direct connection to variational inference becomes particularly obvious when one looks at the definition of the variational Free Energy that was introduced in Equation (2). Without any additional assumptions about the underlying probability distributions, the variational Free Energy has an additional entropy term:

$F = \mathbb{E}_{q_{\phi}(x)}\!\left[ -\ln p(y, x) \right] - H\!\left[ q_{\phi}(x) \right]$   (9)

with data $y$, latent state $x$ and a generative model $p(y, x)$ that generates the observed data. In variational inference, the true posterior over the latent state $x$ is inferred indirectly using an approximate posterior $q_{\phi}(x)$ with known parameters $\phi$. Fitting the approximate posterior to the true posterior is achieved by minimizing the divergence between the two with respect to the parameters $\phi$. As a result, when certain approximations, such as the Laplace approximation, hold, Gradient Descent on the Free Energy directly implements variational inference. While this is a core advantage of PredProp over other second-order optimisers, future work could explore variants of the employed objective function, such as KL divergence minimization between more complex probability distributions.
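For completeness, the textbook identity that connects Equation (9) to the posterior divergence described above, stated in the same notation (no assumptions beyond Equation (9) are needed):

$F = \mathbb{E}_{q_{\phi}(x)}\!\left[ \ln q_{\phi}(x) - \ln p(y, x) \right] = D_{\mathrm{KL}}\!\left[ q_{\phi}(x) \,\|\, p(x \mid y) \right] - \ln p(y)$

Since $\ln p(y)$ does not depend on $\phi$, minimising $F$ with respect to $\phi$ minimises the divergence between the approximate and true posterior.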

3.3 Relation to training deep generative models

Figure 2: Overview of selected generative model architectures, each implementing a variant of variational Bayesian inference. While clear distinctions are sometimes difficult to draw, arrows indicate that a generative model architecture is a specific variant of a more general model and its associated learning procedure.

The presented predictive coding updates have close connections to a variety of related update schemes used for (deep) generative models and NNs in general. In terms of inference, closely related algorithms are the wake-sleep, expectation propagation and expectation maximisation (EM) algorithms, all of which propagate information bidirectionally inside models. In contrast to the parallel optimisation used here, variants of the EM algorithm typically optimize one set of variables at a time while keeping the remaining variables fixed [dempster1977maximum]. Like PredProp, the wake-sleep algorithm was initially inspired by human brain function and shows a similar bidirectional optimization, used for a variety of (deep) neural network models such as the Restricted Boltzmann machine or the Helmholtz machine. Figure 2 shows connections between these related model classes, with examples that are by no means exhaustive. We leave it to future work to explore in detail where and how more generalised gradient-based predictive coding schemes like PredProp supersede these learning procedures.

4 Future work

Several studies have demonstrated the close connections between the gradient-based optimisation methods that play a key role in training (deep) neural networks and more general models, such as predictive coding networks, that implement variants of variational Bayesian inference. We suggest that future work should aim to deepen our understanding of bidirectional information propagation, as it is ubiquitous in the human brain. We suggest that treating model parameters and update schemes not as conceptually separate, but as mutually dependent and closely intertwined, will lead to models that learn and infer autonomously and need less supervision or manual tuning.