pymdp: A Python library for active inference in discrete state spaces

01/11/2022
by Conor Heins, et al.

Active inference is an account of cognition and behavior in complex systems which brings together action, perception, and learning under the theoretical mantle of Bayesian inference. Active inference has seen growing applications in academic research, especially in fields that seek to model human or animal behavior. While in recent years some of the code arising from the active inference literature has been written in open-source languages like Python and Julia, to date the most popular software for simulating active inference agents is the DEM toolbox of SPM, a MATLAB library originally developed for the statistical analysis and modelling of neuroimaging data. Increasing interest in active inference, manifested both in the sheer number of studies and in their diversifying applications across scientific disciplines, has thus created a need for generic, widely-available, and user-friendly code for simulating active inference in open-source scientific computing languages like Python. The Python package we present here, pymdp (see https://github.com/infer-actively/pymdp), represents a significant step in this direction: namely, we provide the first open-source package for simulating active inference with partially observable Markov Decision Processes (POMDPs). We review the package's structure and explain its advantages, such as modular design and customizability, while providing in-text code blocks along the way to demonstrate how it can be used to build and run active inference processes with ease. We developed pymdp to increase the accessibility and exposure of the active inference framework to researchers, engineers, and developers with diverse disciplinary backgrounds. In the spirit of open-source software, we also hope that it spurs new innovation, development, and collaboration in the growing active inference community.

Statement of need

Active inference is an account of cognition and behavior in complex systems which brings together action, perception, and learning under the theoretical mantle of Bayesian inference [friston_reinforcement_2009, friston_active_2012, friston2015active, friston2017process]. Active inference has seen growing applications in academic research, especially in fields that seek to model human or animal behavior [parr2020prefrontal, holmes2021active, adams2021everything]. The majority of applications have focused on cognitive neuroscience, with particular emphasis on modelling decision-making under uncertainty. Nonetheless, the framework has broad applicability and has recently been applied to diverse disciplines, ranging from computational models of psychopathology [montague2012computational, schwartenbeck2015dopaminergic, smith2020imprecise, smith2021greater], control theory [baltieri2019pid, millidge2020relationship, baioumy2021towards] and reinforcement learning [tschantz2020reinforcement, tschantz2020scaling, sajid2021active, fountas2020deep, millidge2020deep], through to social cognition [adams2021everything, hipolito2021enactive, wirkuttis2021leading, tison2021communication] and even real-world engineering problems [martinez2021probabilistic, moreno2021pid, fox2021active]. While in recent years some of the code arising from the active inference literature has been written in open-source languages like Python and Julia [ueltzhoffer2018deep, van2019simulating, tschantz2020learning, ccatal2020learning, millidge2020deep], to date the most popular software for simulating active inference agents is the DEM toolbox of SPM [friston2008variational, smith2022step]. SPM is a MATLAB library originally developed for the statistical analysis and modelling of neuroimaging data, such as data collected by functional magnetic resonance imaging (fMRI) or magneto- and electro-encephalographic (MEG/EEG) methods [penny2011statistical]. The DEM toolbox, a sub-library of SPM, was originally developed to simulate and perform Bayesian estimation of dynamical systems [friston2008variational], but in the last decade it has been augmented with a series of demonstrative scripts and simulation routines related to active inference and the Free Energy Principle more broadly [friston2010free, friston2013life, friston2019free].

Simulations of active inference are commonly performed in discrete time and space [da2020active, friston2015active]. This is partially motivated by the mathematical tractability of performing inference with discrete probability distributions, but also by the intuition of modelling choice behavior as a sequence of discrete, mutually-exclusive choices, in e.g. psychophysics or decision-making experiments. The most popular generative models used to realize active inference in this context are partially observable Markov Decision Processes or POMDPs [kaelbling1998planning]. POMDPs are state-space models that model the environment in terms of hidden states that stochastically change over time, as a function of both the current state of the environment as well as the behavioral output of an agent (control states or actions). Crucially, the environment is partially observable, i.e. the hidden states are not directly observed by the agent, but can only be inferred through observations that relate to hidden states in a probabilistic manner, such that observations are modelled as being generated stochastically from the current hidden state.

In most POMDP problems, an agent is tasked with both inferring the hidden states and selecting a sequence of control states or actions to change the hidden states in a way that leads to desired outcomes (maximizing reward, or occupancy within some preferred set of states). DEM contains a reliable, reproducible set of functions for simulating active inference agents equipped with such generative models: they include spm_MDP_VB_X.m, spm_MDP_game.m, and – recently introduced for simulating ‘sophisticated’ active inference – spm_MDP_VB_XX.m [friston2021sophisticated]. Despite its robustness and widespread use among active inference researchers, spm_MDP_VB_X.m is a single function, meaning that any active inference simulation using the function has to comply with the constraints implied by its structure and control flow. Although options can be specified that initiate particular sub-routines or variants of active inference, it is still not straightforward to construct a custom active inference process from scratch. In practice, this means that novel, bespoke applications require researchers to manually adapt parts of spm_MDP_VB_X.m for their own purposes, which limits the general reproducibility and adaptability of academic active inference research, especially for new practitioners. In addition, since all of DEM is written in MATLAB, using the toolbox can be prohibitive due to the cost of a MATLAB license, especially for researchers who are unaffiliated with institutions. Increasing interest in active inference, manifested both in the sheer number of studies and in their diversifying applications across scientific disciplines, has thus created a need for generic, widely-available, and user-friendly code for simulating active inference in open-source scientific computing languages like Python. The software we present here, pymdp, represents a significant step in this direction: namely, we provide the first open-source package for simulating active inference with discrete state-space generative models. The name pymdp derives from the fact that the package is written in the Python programming language and concerns discrete, Markovian generative models of decision-making, which take the form of Markov Decision Processes or MDPs.

We developed pymdp to increase the accessibility and exposure of the active inference framework to researchers, engineers, and developers with diverse disciplinary backgrounds. In the spirit of open-source software, we also hope that it spurs new innovation, development, and collaboration in the growing active inference community.

Summary

pymdp offers a suite of robust, tested, and modular routines for simulating active inference agents equipped with partially observable Markov Decision Process (POMDP) generative models. Mathematically, a POMDP comprises a joint distribution over observations $o$, hidden states $s$, control states $u$ and hyperparameters $\phi$: $P(o, s, u, \phi)$. This joint distribution further factorizes into a set of categorical and Dirichlet distributions: the likelihoods and priors of the generative model. With pymdp, one can build a generative model using a set of prior and likelihood distributions, initialize an agent, and then link it to an external environment to run active inference processes - all in a few lines of code. The agent and environment API is built according to the standardized framework of OpenAI Gym commonly used in reinforcement learning, where an Agent and Environment class recursively exchange observations and actions over time [brockman2016openai].

In order to enhance the user-friendliness of pymdp without sacrificing flexibility, we have built the library to be highly modular and customizable, such that agents in pymdp can be specified at a variety of levels of abstraction with desired parameterizations. In the next section, we provide an overview of the structure of the package.

Package structure

The Agent Class

The high-level API offered by pymdp is the Agent class. Instantiating an Agent allows the user to abstract away the various optimization routines and sub-operations that make up an active inference process, e.g. state estimation, action selection, and learning. The various sub-routines of active inference are themselves abstracted as user-friendly methods of Agent (such as self.infer_states(obs)), calls to which will run the corresponding function.

Modules

The methods of Agent themselves call functions from different sub-modules of pymdp. These submodules can be roughly divided into three sorts of operations: perception, action and learning. An attractive feature of active inference is that various cognitive processes naturally emerge as different variants of Bayesian inference. For instance, instantaneous inference about dynamically-changing hidden states is often analogized to perception (c.f. perception as inference [von1910treatise, gregory1980perceptions, hinton1983optimal, dayan1995helmholtz]), whereas inference about slower-changing variables (statistical regularities in the environment) is analogized to learning [friston2016active]. Moreover, action is treated as a process of inference, where agents select actions by inferring a distribution over control states or sequence of control states [friston2017process]. Each module of pymdp thus performs inference with respect to different components of a POMDP generative model – we summarize them briefly below.

The inference library of pymdp contains a set of functions for performing hidden state inference or state estimation. These are the core functions that allow agents to update their beliefs about the discrete hidden state of the environment, given observations. Functions from this library are called by the self.infer_states() method of Agent. Specific arguments can be passed into the Agent constructor to specify the type and parameterization of the algorithm used to perform hidden state inference.
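For instance, one might select the inference routine when constructing the agent, as in the following minimal sketch; the inference_algo keyword and its accepted values reflect our reading of the pymdp API and should be checked against the documentation.

import numpy as np
from pymdp import utils
from pymdp.agent import Agent

# minimal single-factor, single-modality model: 2 states, 2 observations, 1 (trivial) action
A = utils.obj_array(1)
A[0] = np.array([[0.9, 0.1],
                 [0.1, 0.9]])        # P(o | s)

B = utils.obj_array(1)
B[0] = np.eye(2)[:, :, None]         # P(s' | s, u) with a single 'stay' action

# the inference_algo keyword below is an assumption about the constructor's API;
# consult the pymdp documentation for the exact argument name and accepted values
my_agent = Agent(A=A, B=B, inference_algo="VANILLA")

qs = my_agent.infer_states([0])      # posterior over hidden states, given observation index 0
print(qs[0])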

The control library of pymdp contains functions for inferring policies and sampling actions from the posterior beliefs about control states. (In active inference, ‘control states’ refer to the random variables in the generative model and approximate posterior, and can be thought of as the agent’s representation of its actions; actions themselves are realizations of these random variables, sampled from the posterior over control states.) These functions are called internally by the self.infer_policies() method of Agent.

Finally, the learning library of pymdp contains the functions necessary for the agent to update hyperparameters of its generative model, i.e. Dirichlet parameters over the categorical prior and likelihood distributions. These functions are called internally by methods like self.update_A(), self.update_B(), and self.update_D() of Agent.
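To make this concrete, the following is a hedged sketch of likelihood learning; the pA keyword for the Dirichlet prior over A and the exact call signature of update_A() are assumptions on our part and should be verified against the pymdp documentation.

import numpy as np
from pymdp import utils
from pymdp.agent import Agent

# single-factor, single-modality model: 2 states, 2 observations, 1 (trivial) action
A = utils.obj_array(1)
A[0] = np.array([[0.7, 0.3],
                 [0.3, 0.7]])
B = utils.obj_array(1)
B[0] = np.eye(2)[:, :, None]

# Dirichlet prior over the observation likelihood (pA is an assumed keyword; see the docs)
pA = utils.obj_array(1)
pA[0] = np.ones_like(A[0])

my_agent = Agent(A=A, B=B, pA=pA)

obs = [0]
my_agent.infer_states(obs)   # state inference first, so that posterior beliefs are available
my_agent.update_A(obs)       # then update the Dirichlet parameters over the likelihood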

For a more detailed overview of the functionality offered by each of pymdp’s modules, please see Appendix B: Modules and Theory.

from pymdp.agent import Agent

# here you would set up your generative model
my_A = ...
my_B = ...
my_C = ...
my_D = ...

# instantiate your agent with a call to the Agent() constructor
my_agent = Agent(A=my_A, B=my_B, C=my_C, D=my_D)

# define an environment
my_env = Env()

# set an initial action
action = initial_action

for t in range(T):
    o_t = my_env.step(action)
    my_agent.infer_states(o_t)
    my_agent.infer_policies()
    action = my_agent.sample_action()
Example 1: Minimal example of running active inference with pymdp.

Usage

Specifying a generative model is central to active inference, and to Bayesian modelling in general. Intuitively, a generative model is a probabilistic specification of how data or sensory observations are generated. In the context of Bayesian agent-based models, the generative model represents an agent’s probabilistic internal model of its environment, comprising a set of structural assumptions about how the world generates observations and how action changes the world. In discrete-state and -time active inference models, we typically assume the generative model is a partially observable Markov Decision Process or POMDP, comprising observations the agent receives, hidden states of the world, and actions the agent can take to influence hidden states. Hidden states are called ‘hidden’ precisely because the agent can never directly access them, but can only infer them via observations. The POMDP structure assumes that at each timestep, the observation is generated by the current hidden state, while the hidden state itself changes over time as a function of its current setting and some control state (i.e. action). Mathematically, a generative model is usually expressed as a joint probability distribution $P(o_{1:T}, s_{1:T}, u_{1:T}, \phi)$ over observations $o_{1:T}$, hidden states $s_{1:T}$, control states $u_{1:T}$, and parameters $\phi$. This joint distribution is called a generative model because it can be used to sample (or generate) sequences of potential observations according to the probabilistic structure encoded in the model.

Specifying the generative model in terms of discrete probability distributions is the first step to building an active inference agent in pymdp. Below, we overview the steps involved in building a generative model to provide intuition and illustrate its simplicity in the special case of POMDPs.

The POMDP generative model

The POMDP generative model assumed by pymdp is a discrete-time and space generative model that, like any probability distribution, can be factorized into a product of conditional distributions (likelihoods) and marginals (priors). The most important of these distributions – when writing down a generative model in pymdp – are 1) the observation likelihood $P(o_t \mid s_t)$, which represents the agent’s beliefs about how hidden states generate observations and 2) the transition model $P(s_{t+1} \mid s_t, u_t)$, which represents the agent’s beliefs about how hidden states at some time $t$ cause hidden states at the next time $t+1$, conditioned on some control states (actions) $u_t$. The agent also has a prior over initial hidden states $P(s_1)$, which represents the agent’s baseline belief, before gathering any observations, about the probability of the different hidden states at the first timestep. For the sake of mathematical convenience, we describe POMDP generative models with a finite time horizon $T$, but note that in general pymdp does not require a finite time horizon, so active inference agents can be theoretically run indefinitely (e.g. in streaming applications). Finally, there is an additional prior distribution over observations, $\tilde{P}(o_t)$, which specifies an agent’s goals as a desired distribution over observations. In active inference, goals and desires are encoded as a prior in the generative model, such that the probability that the model assigns to some configuration of observations, hidden states, control states and prior parameters is only maximized when sampling preferred observations (e.g. if my model assigns high probability to observing body temperature to be 37 degrees, prior preferences are realized when these observations are sampled). As we will see in the following sections, active inference suggests that perception, action and learning all work to maximize the marginal likelihood of observations. Prior preferences only come into play during policy selection because only unobserved outcomes in the future are random variables (see the section on control.py in Appendix B for details).

The local dependence in time, captured by one-step conditional dependence between hidden states in the transition likelihood, is what renders POMDPs Markovian generative models. As such, a general expression for the joint distribution over $o_{1:T}$, $s_{1:T}$, $\pi$ and $\phi$ is as follows:

$$P(o_{1:T}, s_{1:T}, \pi, \phi) = P(\phi)\, P(\pi)\, P(s_1) \prod_{t=1}^{T} P(o_t \mid s_t) \prod_{t=2}^{T} P(s_t \mid s_{t-1}, \pi) \quad (1)$$

We have replaced $u_{1:T}$ here with $\pi$ to represent policies, or sequences of control states, i.e. $\pi = (u_1, u_2, \ldots, u_{T-1})$. Control states can be formally related to policies by writing down an additional likelihood, a ‘policy-to-control’ mapping $P(u_t \mid \pi)$, that links a given policy to the control state it entails at time $t$. Under active inference, agents also perform inference about policies, which naturally entails goal-directed and uncertainty-resolving behavior (see the section The Expected Free Energy in Appendix B for details on policy inference). Finally, we capture any additional parameters as a single vector of hyperparameters $\phi$, which might correspond to the parameters of Dirichlet priors over the likelihood distributions $P(o_t \mid s_t)$ and $P(s_t \mid s_{t-1}, \pi)$, for instance. In active inference, inference about these hyperparameters is often assumed to occur on a slower timescale than inference about hidden states and policies: therefore, this process is referred to in the active inference literature as learning [friston2016active, schwartenbeck2019computational] (see learning.py for more details on hyperparameter inference).

Equipped with a generative model, active inference (and Bayesian inference more generally) entails inference over the latent variables $s$, $\pi$ and $\phi$, given some observations gathered over time. For a description of the mathematical basis of this inference – and how it is implemented algorithmically in pymdp – please refer to Appendix B: Modules and Theory. Now that we have expressed the POMDP generative model formally, highlighting its important components for active inference, we can move on to the representation of these distributions in pymdp.

Building blocks

pymdp considers generative models of discrete states that evolve in discrete time. This means that there are an integer number of discrete levels of both the states of the environment and of the observations. Because of this fundamental discreteness, a natural way to represent the distributions of the generative model is by using categorical distributions, which assign a probability value between $0$ and $1$ to each discrete outcome level of the distribution’s sample space, with the usual constraint that the sum of the probabilities over levels is $1$. Mathematically, we refer to categorical distributions with the notation $X \sim \mathrm{Cat}(\boldsymbol{\theta})$. This means that the distribution over the random variable $X$ is described by a categorical distribution with a $K$-dimensional vector of parameters $\boldsymbol{\theta}$, where $K$ is the cardinality of the sample space of $X$. Numerically, these categorical distributions can be represented as multidimensional arrays (also known as NDarrays or tensors) that contain their parameters. These categorical distributions come in two flavors: vector-valued marginal distributions (e.g. $P(s_t)$), usually playing the role of priors, and conditional categorical distributions (e.g. $P(o_t \mid s_t)$) in the form of NDarrays (matrices and tensors), playing the role of likelihoods in the generative model.

Marginal categorical distributions are encoded in pymdp as simple 1-D vectors, which technically are instances of numpy.ndarrays, the core data structure for representing multi-dimensional arrays in the Python array programming library, numpy [harris2020array]. One can easily instantiate categorical distributions in numpy using calls to the array constructor, e.g. prior_over_states = np.array([0.5, 0.5]). Conditional categorical distributions are just collections of 1-D categorical vectors, with as many 1-D vectors as there are levels of the conditioning variable. In pymdp, we encode these collections as matrices (2-D NDarrays) and higher-order NDarrays. For instance, we would encode some discrete conditional distribution $P(Y \mid X)$ relating two categorical variables as a matrix of size $(K_Y \times K_X)$, where $K_Y$ is the number of levels of the support random variable $Y$ and $K_X$ is the number of levels of the conditioning random variable $X$. We represent such conditional categorical distributions mathematically with the notation $P(Y \mid X) = \mathrm{Cat}(\boldsymbol{\Theta})$, where now $\boldsymbol{\Theta}$ is a matrix or tensor of parameters, whose columns each have the properties of a single categorical distribution, i.e. $\sum_i \Theta_{ij} = 1$ for every $j$.

The core distributions of the generative model – that are encoded in this way – are the observation likelihood $P(o_t \mid s_t)$ (also known as the observation model or the sensory likelihood) and the transition likelihood $P(s_{t+1} \mid s_t, u_t)$ (also known as the transition model or the dynamics model). For all the following descriptions, we borrow notation from the SPM and DEM standards, which are summarized in papers like [friston2015active, friston2017process, smith2022step, friston2017active]. In pymdp notation, the observation likelihood is constructed as the A array. In simple generative models (see the section on Factorized representations for the more general form), A will be a $(K_o \times K_s)$ matrix, where $K_o$ is the number of outcomes or levels of the observations $o$, and $K_s$ is the number of levels of the hidden states $s$. The entry A[i,j] encodes the probability of seeing observation $o = i$ given state $s = j$. In other words, each column of this matrix A[:,j] stores a vector of categorical parameters that encodes the distribution $P(o_t \mid s_t = j)$. Similarly, the transition likelihood is represented by the B array, which is a $(K_s \times K_s \times K_u)$ NDarray or tensor, where $K_u$ is the number of levels of the control states, and entry B[i,j,k] encodes the probability of transitioning to state $s_{t+1} = i$ from state $s_t = j$, when control state or action $u_t = k$ is taken by the agent. The general structure of numerical representations of conditional distributions in pymdp can be expressed as follows: the first dimensions (or rows) of the matrix or NDarray represent the support of the conditional distribution, while the lagging dimensions (columns, slices, etc.) represent the random variables being conditioned on. Thus, for the observation likelihood $P(o_t \mid s_t)$, the first dimension of A represents the support of $o_t$ (and will have length $K_o$) while the second dimension represents the support of $s_t$ (and will have length $K_s$).
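The following standalone numpy snippet (not part of pymdp itself) illustrates these shape and normalization conventions for a hypothetical model with two observation levels, two hidden states and two actions:

import numpy as np

n_obs, n_states, n_actions = 2, 2, 2

# A[i, j] = P(o = i | s = j); each column is a categorical distribution over observations
A = np.array([[0.8, 0.2],
              [0.2, 0.8]])
assert np.allclose(A.sum(axis=0), 1.0)    # columns sum to 1

# B[i, j, k] = P(s_next = i | s = j, u = k); action 0 stays put, action 1 switches states
B = np.zeros((n_states, n_states, n_actions))
B[:, :, 0] = np.eye(n_states)
B[:, :, 1] = np.eye(n_states)[::-1]
assert np.allclose(B.sum(axis=0), 1.0)    # columns of every action 'slice' sum to 1

print(A[1, 0])        # probability of observing o = 1 given s = 0
print(B[:, 0, 1])     # distribution over next states, starting from s = 0 under action 1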

Beyond the A and B arrays, one can also specify an initial prior over states – in pymdp this is called the D vector (of length $K_s$) and represents the agent’s beliefs about the distribution over hidden states at the first timestep of the time horizon (when $t = 1$).

Finally, in order to achieve goal-directed behavior under active inference, it is necessary to build a representation of some desired state or goal into the generative model. In reinforcement learning this is handled using reward functions, but in active inference we instead specify a prior distribution over observations, also known as the ‘prior preferences’ or ‘goal distribution’ [friston_reinforcement_2009]. Inference over control states is then biased by this preference distribution, leading agents to choose actions that bring them to states that (they expect) will lead to preferred observations. In pymdp, this is represented by the C array of length $K_o$. The default C array is a vector that is time-independent (the same C is used for all timesteps), but it is also possible to specify a time-dependent array. This can be used to represent goals that change over time or the desire to reach a specific goal in a time-dependent manner.
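As a hedged sketch of how a preference might be encoded (whether pymdp interprets C values as a normalized distribution or as unnormalized log-preferences, and whether a time-dependent C is stored with one column per timestep, are assumptions to check against the documentation):

import numpy as np
from pymdp import utils

n_obs = 3

# a C vector that prefers the first observation level over the other two
C = utils.obj_array(1)
C[0] = np.array([1.0, 0.0, 0.0])

# a hypothetical time-dependent preference array with one column per timestep
T = 4
C_time = np.zeros((n_obs, T))
C_time[0, -1] = 1.0   # only 'care about' reaching the first observation at the final timestep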

After the generative model has been specified in terms of a set of likelihood and prior distributions, one can build an active inference agent in a single line using the Agent() constructor: e.g. my_agent = Agent(A=A, B=B, ...). The Agent() constructor requires A and B arrays as mandatory input, while C and D vectors can be optionally included (the defaults are uniform distributions for each). The various methods of the resulting Agent instance can then be used to perform active inference.

import numpy as np

import pymdp
from pymdp import utils, maths
from pymdp.agent import Agent

# create a simple model with one hidden state factor, and one observation modality

n_obs = 3
n_states = 3

A = utils.obj_array(1)
A[0] = np.array([[1.0, 0.0, 0.0],
                 [0.0, 1.0, 0.0],
                 [0.0, 0.0, 1.0]])

# introduce uncertainty into one of the hidden states
inv_temperature = 0.5
A[0][:,2] = maths.softmax(inv_temperature * A[0][:,2])

# create a simple transition model with two possible actions

B = utils.obj_array(1)
B[0] = np.zeros((3, 3, 2))

# first action leads to first two states with uncertainty
B[0][:,:,0] = np.array([[0.5, 0.5, 0.5],
                        [0.5, 0.5, 0.5],
                        [0.0, 0.0, 0.0]])

# second action leads to last state with certainty
B[0][:,:,1] = np.array([[0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0],
                        [1.0, 1.0, 1.0]])

# specify prior preferences (C vector)
C = utils.obj_array_uniform([n_obs])

# specify prior over hidden states (D vector)
D = utils.obj_array(1)
D[0] = utils.onehot(1, n_states)

# instantiate your agent with a call to the Agent() constructor
my_agent = Agent(A=A, B=B, C=C, D=D)

# write a simple environment class, where the state depends on the action probabilistically,
# and the observation is a deterministic function of the state, except for state 2,
# where it's randomly sampled

from pymdp.envs import Env

# sub-class it from the base Env class
class custom_env(Env):

    def __init__(self):
        self.state = 0

    def step(self, action):

        if action == 0:
            self.state = 0 if np.random.rand() > 0.5 else 1
        if action == 1:
            self.state = 2

        if self.state == 0:
            obs = 0
        elif self.state == 1:
            obs = 1
        elif self.state == 2:
            obs = np.random.randint(3)

        return obs

env = custom_env()

action = 0

T = 10 # length of active inference loop in time
for t in range(T):

    # sample an observation from the environment
    o_t = env.step(action)

    # do active inference
    qs = my_agent.infer_states([o_t]) # get posterior over hidden states
    my_agent.infer_policies()
    action = my_agent.sample_action()

    # convert action into int, for use with environment
    action = int(action.squeeze())
Example 2: Detailed example of building and running an active inference process in pymdp.

Specifying an environment

For most use-cases of active inference, the agent will need to interface with some kind of environment or external world. The minimal definition of an environment is just a class or function that takes actions of the agent as input, updates the true hidden state of the environment (but does not convey this information to the agent) and returns observations generated by the updated hidden state. In the Bayesian modelling literature, this environment is also abstractly referred to as the ‘generative process’ or ‘data-generating process’. What is important to note is that this generative process does not have to be identical to the generative model – i.e. there is no requirement that an active inference agent with a POMDP generative model is operating in a world with discrete, POMDP-like dynamics [baltieri2019generative, tschantz2020learning]. All that matters is that the environment accepts the agent’s actions and returns observations that are discrete and are compatible with the support of the likelihood of the agent’s generative model.

pymdp contains a library of pre-built environments which can be imported using from pymdp import envs. Following the convention of OpenAI Gym [brockman2016openai], users can also write their own environment class. This class is traditionally written to have a step() method which takes an action from the agent as input and returns observations that will be processed by the agent at the next timestep. In many reinforcement learning and control problem contexts, the environment has its own internal state that is updated by the agent’s action, and which determines (either stochastically or deterministically) the next observation.

Closing the action-perception loop

The typical ‘active inference loop’ consists of three main steps: 1) sampling an observation from the environment; 2) updating the agent’s beliefs about states and policies using the observation; and 3) choosing an action, based on the agent’s posterior over policies (see Appendix B: Modules and Theory for more details on state and policy inference). In pymdp, 1) is implemented by calling the environment class env.step(); 2) is implemented using the pymdp functions agent.infer_states() and agent.infer_policies() and 3) is implemented using the pymdp function agent.sample_action(). Wrapping these three steps into a loop over time entails the entire active inference process; see the full example using the Agent class in Example 2.

Factorized representations

Although many simple POMDPs can be constructed with simple 2-D A matrices and 3-D B arrays, most of the interesting applications of active inference require what are referred to as ‘factorized representations’. This requires building additional structure into the generative model, such that observations are divided into separate modalities and hidden states into separate factors. A multi-modality observation and multi-factor hidden state can be expressed as follows:

$$o_t = (o_t^1, o_t^2, \ldots, o_t^M), \qquad s_t = (s_t^1, s_t^2, \ldots, s_t^F)$$

where here the superscript $m$ or $f$ refers to the index of the $m$-th observation modality or $f$-th hidden state factor, respectively. This means that at any given time the agent receives a collection of discrete observations, where each observation within the collection belongs to a distinct ‘modality’. The name modality is used to emphasize the analogy to different sensory channels (e.g. vision, audition, somatosensation) in biology that relay different sorts of information. Likewise, in a factorized hidden state representation, the environment’s structure is represented through several hidden state factors, that may encode distinct features of the world, each of which may have its own dimensionality, dynamics, and relationship to observations.

Importantly, with such factorized representations, the likelihood arrays A and B become more complex. In both pymdp and SPM, we encode a multi-modality A array as a collection of sub-arrays A[m], with one for each observation modality. Each modality-specific A array then represents the conditional probability of observations for modality $m$, given the different configurations of hidden states, i.e., $P(o_t^m \mid s_t^1, s_t^2, \ldots, s_t^F)$. Note that the ‘first’ index of the larger A array selects a particular modality from the collection, e.g. A_modality = A[m], whereas the subsequent multi-index into the modality-specific A array selects conditional probabilities or arrays of such probabilities, e.g., A_modality[0, 2, 3, ...]. Each A[m] thus encodes all probabilistic dependencies between the different hidden state factors and the observations for the $m$-th modality: $P(o_t^m \mid s_t^1, s_t^2, \ldots, s_t^F)$. These complex conditional relationships are encoded by accordingly higher-dimensional NDarrays in NumPy, with the number of lagging dimensions encoding the number of hidden state factors that the observations depend on. Such factorized generative models require more involved belief updating algorithms to achieve posterior inference, usually invoking a factorized approximate posterior (e.g. a mean-field factorization), where the full posterior over all hidden state factors is factorized into a product of marginals $Q(s_t) = \prod_{f=1}^{F} Q(s_t^f)$, where each $Q(s_t^f)$ is the posterior for hidden state factor $f$. Fortunately, pymdp easily accommodates such higher-dimensional, factorized generative models and will automatically perform message passing with respect to such generative models with an arbitrary number of observation modalities and hidden state factors. See Appendix A: Factorized Generative Models for more details on multi-factor generative models.
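As a hedged illustration of how these factorized arrays are laid out in practice (the particular numbers of modalities, factors and levels below are invented for the example), one can build them as object arrays of randomly initialized, column-normalized sub-arrays:

import numpy as np
from pymdp import utils

num_obs = [3, 2]        # two observation modalities, with 3 and 2 levels
num_states = [4, 2]     # two hidden state factors, with 4 and 2 levels
num_controls = [4, 1]   # the second factor is uncontrollable (a 'trivial' control factor)

# one likelihood sub-array per modality: A[m] has shape (num_obs[m], *num_states)
A = utils.obj_array(len(num_obs))
for m, no in enumerate(num_obs):
    A_m = np.random.rand(no, *num_states)
    A[m] = A_m / A_m.sum(axis=0, keepdims=True)    # normalize over outcome levels

# one transition sub-array per factor: B[f] has shape (num_states[f], num_states[f], num_controls[f])
B = utils.obj_array(len(num_states))
for f, ns in enumerate(num_states):
    B_f = np.random.rand(ns, ns, num_controls[f])
    B[f] = B_f / B_f.sum(axis=0, keepdims=True)

print(A[0].shape, A[1].shape)   # (3, 4, 2) and (2, 4, 2)
print(B[0].shape, B[1].shape)   # (4, 4, 4) and (2, 2, 1)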

Pedagogical materials & code

For more example code detailing how to use pymdp to simulate active inference in discrete state-space environments, we refer the reader to the tutorials found in the official documentation for the repository: https://pymdp-rtd.readthedocs.io/.

Customizability

pymdp offers a high degree of customizability in designing bespoke active inference processes, such that the methods of the Agent class can be called in any particular order, depending on the application, and furthermore they can be specified with various keyword arguments that entail choices of implementation details at lower levels.

For instance, if one wanted to model a purely ‘perceptual’ task, i.e., where the agent has no ability to act, but is only concerned with hidden state estimation, then one could write an active inference loop where the Agent class only uses the infer_states() function. This offers an advantage over the main function used to perform active inference in SPM, spm_MDP_VB_X.m, where customization options are limited and in practice are implemented by modifying parts of the function by hand to suit one’s needs (e.g. commenting out certain sections or adding in bespoke computations).

Moreover, by retaining a modular structure throughout the package’s dependency hierarchy, pymdp also affords the ability to flexibly compose different low level functions. This allows users to customize and integrate their active inference loops with desired inference algorithms and policy selection routines. For instance, one could sub-class the Agent class and write a customized step() function, that combines whichever components of active inference one is interested in.
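For example, a perception-only agent of the kind described above might look like the following sketch; the step() method name and its body are our own illustrative choices rather than an interface prescribed by pymdp:

from pymdp.agent import Agent

class PerceptualAgent(Agent):
    """An agent that only performs hidden state estimation, with no action selection."""

    def step(self, obs):
        # run only the state-inference portion of the active inference loop
        qs = self.infer_states(obs)
        return qs

# usage, with A and B constructed as in the earlier examples:
# my_agent = PerceptualAgent(A=A, B=B)
# qs = my_agent.step([o_t])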

Related software packages

The DEM toolbox within SPM in MATLAB is the current gold-standard in active inference modelling. In particular, simulating an active inference process in DEM consists of defining the generative model in terms of a fixed set of matrices and vectors, and then calling the spm_MDP_VB_X.m function to simulate a sequence of trials. pymdp, by contrast, provides a user-friendly and modular development experience, with core functionality split up into different libraries that separately perform the computations of active inference in a standalone fashion. Moreover, pymdp provides the user the ability to write an active inference process at different levels of abstraction depending on the user’s level of expertise or skill with the package – ranging from the high level Agent functionality, which allows the user to define and simulate an active inference agent in just a few lines of code, all the way to specifying a particular variational inference algorithm (e.g. marginal-message passing) for the agent to use during state estimation. In SPM, this would require setting undocumented flags or else manually editing the routines in spm_MDP_VB_X.m to enable or disable bespoke functionality.

pymdp has extensive, organized documentation and illustrative examples. While the DEM toolbox is also replete with interesting examples that result in beautiful visualizations of simulated behavior and synthetic neural responses, the available usage information for each function remains limited to doc-strings in the source code. The closest to documentation or an instruction manual for spm_MDP_VB_X.m is the comprehensive tutorial by Smith et al. 2021 [smith2022step], which features a series of MATLAB tutorial scripts that walk through the different aspects of active inference, with a focus on applications to modelling (behavioral and neurophysiological) empirical data.

A recent related, but largely non-overlapping project is ForneyLab, which provides a set of Julia libraries for performing approximate Bayesian inference via message passing on Forney Factor Graphs [ForneyLab2019]. Notably, this package has also seen several applications in simulating active inference processes, using ForneyLab as the backend for the inference algorithms employed by an active inference agent [van2019simulating, vanderbroeck2019active, ergul2020learning, van2021chance]. While ForneyLab focuses on including a rigorous set of message passing routines that can be used to simulate active inference agents, pymdp is specifically designed to help users quickly build agents (regardless of their underlying inference routines) and plug them into arbitrary environments to run active inference in a few easy steps.

Funding Statement

CH and IDC acknowledge support from the Office of Naval Research grant (ONR, N00014- 64019-1-2556), with IDC further acknowledging support from the European Union’s Horizon 2020 research and innovation programme under the Marie Skłodowska-Curie grant agreement (ID: 860949), the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany’s Excellence Strategy-EXC 2117- 422037984, and the Max Planck Society. KF is supported by funding for the Wellcome Centre for Human Neuroimaging (Ref: 205103/Z/16/Z) and the Canada-UK Artificial Intelligence Initiative (Ref: ES/T01279X/1). CH, DD, and BK acknowledge the support of a grant from the John Templeton Foundation (61780). The opinions expressed in this publication are those of the author(s) and do not necessarily reflect the views of the John Templeton Foundation.

Acknowledgements

The authors would like to thank Dimitrije Markovic, Arun Niranjan, Sivan Altinakar, Mahault Albarracin, Alex Kiefer, Magnus Koudahl, Ryan Smith, Casper Hesp, and Maxwell Ramstead for discussions and feedback that contributed to development of pymdp. We would also like to thank Thomas Parr for pointing out a technical error in an earlier version of the paper. Finally, we are grateful to the many users of pymdp whose feedback and usage of the package have contributed to its continued improvement and development.


Appendix A: Factorized Generative Models

In this appendix, we explain observation and state factorization by using an in-depth example. Let’s imagine a scenario where you have to infer two simultaneous states of the world, given some sensory data. The two facts you need to estimate are 1) what time of day it is (morning, midday, or evening), and 2) whether it rained recently (yes or no). We can represent this in a generative model as an environment characterized by two discrete random variables or hidden state factors. The first variable or factor we can call the time-of-day state, which has three levels: Morning, Midday, or Evening; the second factor we can call the did-it-rain state, which has two levels: Rained and Did-not-rain. We use the following notation to denote these hidden state factors and their respective levels:

$$s^1 \in \{\text{Morning}, \text{Midday}, \text{Evening}\}, \qquad s^2 \in \{\text{Rained}, \text{Did-not-rain}\}$$

where we assign each state level a cardinal index, e.g. Morning $= 0$, Midday $= 1$, and Evening $= 2$ for the first factor, and Rained $= 0$ and Did-not-rain $= 1$ for the second. Let’s now augment our simple hidden state representation with observations, which a hypothetical agent would use to infer both the time of day and whether it rained recently, i.e. to obtain a posterior distribution over both hidden state factors: $P(s^1, s^2 \mid o)$. Our observations will also be factorized, into two different modalities or information channels. Each of these modalities is also a discrete-valued random variable. Let’s imagine our two modalities are: 1) the ambient light level (dark, cloudy, or sunny) and 2) humidity (dry or humid). We denote the observations as follows:

$$o^1 \in \{\text{Dark}, \text{Cloudy}, \text{Sunny}\}, \qquad o^2 \in \{\text{Dry}, \text{Humid}\}$$

Having specified a factorized representation of both states and observations, we can now consider how observations lend evidence for or against different states-of-affairs in the environment. For example, if you notice it’s dark outside (i.e., $o^1 = \text{Dark}$), that provides evidence to suggest that it’s evening, rather than morning or midday. At the same time, you might also notice that the air is humid through your humidity modality, i.e. $o^2 = \text{Humid}$. We can imagine that the humidity observation provides no evidence for the time of day it is, but it may suggest that it rained recently.

These probabilistic relationships between the observation modalities and the hidden state factors, which are used to perform inference, are encoded in the observation likelihood $P(o \mid s)$, represented in pymdp as a modality-specific sub-array A[m]. Recall that each likelihood array encodes the conditional dependencies between each setting of the hidden states and the observations within modality $m$, i.e. A[m] = $P(o^m \mid s^1, s^2, \ldots, s^F)$. Therefore the dimensionality of a given A[m] array is $(K_{o^m} \times K_{s^1} \times \ldots \times K_{s^F})$, where $K_{o^m}$ is the dimensionality of modality $m$ and $K_{s^f}$ is the dimensionality of factor $f$. Following our simple example of inferring the time of day and whether it just rained, the likelihood for the first modality would be a 3-D array of size $(3 \times 3 \times 2)$ that represents the likelihood distribution $P(o^1 \mid s^1, s^2)$. Specifically, entry A[i,j,k] of this modality-specific array encodes the probability of observing level $o^1 = i$, given level $s^1 = j$ and level $s^2 = k$. The second observation modality accordingly has its own likelihood NDarray, encoding the likelihood distribution $P(o^2 \mid s^1, s^2)$.

These higher-dimensional likelihood arrays enable complex, conjunctive relationships to be encoded in the generative model. For instance, we might imagine that the observation $o^1$ depends on both the time of day and whether it just rained. All else being equal, we might expect the ambient lighting to be sunny if the time of day is midday. However, if it just rained then the probability of the ambient lighting being dark or cloudy might be higher, even if the time of day is midday. This statement already requires a nonlinear, conjunctive relationship between $s^1$ and $s^2$. For example, the probability distribution over the 3 levels of $o^1$, given particular levels of $s^1$ and $s^2$, would be encoded by the corresponding likelihood $P(o^1 \mid s^1, s^2)$. This would then be easily encoded in the corresponding ‘slice’ of the high-dimensional A[light] array: A[light][:,Mo,Rain], where light is the index of the A array that encodes the $o^1$ likelihood, and Mo and Rain are the cardinal indices for the chosen factor-specific state levels.
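A hedged numpy sketch of how this modality-specific array could be filled in for the running example follows; the particular probability values are invented purely for illustration:

import numpy as np

# factor 1 (time-of-day): Morning, Midday, Evening; factor 2 (did-it-rain): Rained, Did-not-rain
# modality 1 (ambient light): Dark, Cloudy, Sunny
n_light, n_time, n_rain = 3, 3, 2
A_light = np.zeros((n_light, n_time, n_rain))

# each column A_light[:, j, k] is P(light | time = j, rain = k); values are illustrative only
A_light[:, 0, 0] = [0.50, 0.40, 0.10]   # Morning, Rained
A_light[:, 0, 1] = [0.30, 0.30, 0.40]   # Morning, Did-not-rain
A_light[:, 1, 0] = [0.20, 0.60, 0.20]   # Midday, Rained: clouds more likely after rain
A_light[:, 1, 1] = [0.05, 0.15, 0.80]   # Midday, Did-not-rain: probably sunny
A_light[:, 2, 0] = [0.70, 0.25, 0.05]   # Evening, Rained
A_light[:, 2, 1] = [0.60, 0.30, 0.10]   # Evening, Did-not-rain

assert np.allclose(A_light.sum(axis=0), 1.0)   # every (time, rain) column is a distribution

# the kind of 'slice' discussed in the text: the distribution over light levels
# for one particular combination of hidden state levels
print(A_light[:, 1, 0])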

In the same way that the observation likelihood for each modality is represented using a single likelihood NDarray A[m], the transition likelihood for each hidden state factor is represented using a single likelihood NDarray B[f], where the $f$-th transition array is of size $(K_{s^f} \times K_{s^f} \times K_{u^f})$. So analogously to having a collection of A[m] arrays, one for each observation modality, we also have a collection of B[f] arrays, one for each state factor. Two important things to note are that: 1) constructing B arrays with this shape assumes that hidden state factors cannot influence each other dynamically, i.e. the next state within a factor only depends on the past state for that factor and the corresponding control state $u^f$, and 2) that control states are factorized just like hidden states, such that for each hidden state factor $f$ there is a corresponding control factor whose dimensionality $K_{u^f}$ is equal to the number of control state levels or actions that can be taken upon hidden state factor $f$. It is of course allowable to have uncontrollable hidden state factors - in which case we simply set the dimensionality of the corresponding control factor to 1, i.e. $K_{u^f} = 1$. We often refer to these as ‘trivial’ control factors, since they don’t actually encode any kind of control.

The factorized representation described above offers several advantages. First of all, if a large hidden state space can be factorized into a collection of one-dimensional representations, then the memory cost of storing the relevant probability distributions over hidden states (e.g. the factor-specific priors, posteriors and transition arrays) can be greatly finessed. For example, if you can represent the identity of an object and its location independently, without having to enumerate all the combinations of both its identity and its location together, then the amount of memory used to store the factorized representation will be linear in the dimensionality of the two hidden state factors, whereas the ‘enumerated’ representation grows with the product of their dimensionalities. For example, if location is a 1000-dimensional vector, and identity is a 1000-dimensional vector, then storing two 1000-dimensional vectors is considerably cheaper than storing a single 1,000,000-dimensional vector.

Another advantage is the degree of interpretability and model transparency that factorized representations afford; a particular factorization is ideally explicitly designed, such that hidden state factors are directly mapped to intuitive features of the environment whose relationships are easy to reason about. If a multi-factor model of, for example, 3 hidden state factors (e.g. the location, identity, and time of some event) were fully enumerated into a single, 1-dimensional hidden state, then each level of the single hidden state would correspond to a unique combination of “what”, “where” and “when”. When it comes to encoding probabilistic relationships in the generative model (e.g. the observation and transition models), it becomes harder to visualize and reason about the relationships between such high-dimensional state combinations. Thus factorization also proves a useful tool when designing generative models based on prior domain or task knowledge.

Interestingly, when one optimizes the factorial structure of a generative model, using marginal likelihood or variational free energy, the best factorisation maximizes marginal likelihood (a.k.a., model evidence) by minimizing complexity: namely, the degrees of freedom used to provide an accurate account of observations. This is an important aspect of active inference; namely, that to provide the best account of observations – that precludes overfitting and ensures generalization – the (mean-field) factorisation should be as simple as possible but no simpler.

Finally, inference also may take advantage of the factorized structure of the generative model. In doing so, inference is not only more memory-efficient, but the belief-updating algorithms have features like functional specialization [mishkin1983object, zeki1991direct] and local message-passing that have been linked to features of computation in the brain [jardri2017experimental, leptourgos2020circular]. It has been argued that this affords factorized generative models a higher degree of biological plausibility [parr2019neuronal, parr2020modules].

Appendix B: Modules and Theory

inference.py

In this section we provide an overview of the inference module of pymdp and then briefly rehearse the mathematics of variational inference, both generally and as it is used in pymdp.

The inference.py file contains functions for performing variational inference about hidden states in discrete categorical generative models. Functions within this module are called by the infer_states() method of Agent. The core functions of this module are:

  • update_posterior_states(obs, A, prior=None, **kwargs): This function computes the variational (categorical) posterior over hidden states at the current timestep, $Q(s_t)$. This function by default calls the standard or ‘vanilla’ inference algorithm offered by pymdp, which estimates the marginal posteriors $Q(s_t^f)$ for each hidden state factor at the current timestep using fixed-point iteration. This function can be generically applied as the inference step in any discrete POMDP model, as all it requires are some observations obs, a likelihood array A and optionally a prior over hidden states prior, which will have the same structure as the resulting posterior (see the usage sketch below). The additional arguments **kwargs contain parameters that will be passed to the run_vanilla_fpi() function in algos/fpi.py (see the documentation for more details).

  • update_posterior_states_full(A, B, prev_obs, policies, prev_actions=None, prior=None, policy_sep_prior=True, **kwargs): This function computes the variational (categorical) posterior over hidden states under all policies: $Q(s_{1:T} \mid \pi)$. The notation $s_{1:T}$ represents a trajectory of hidden states over time. This is inspired by the ‘full construct’ active inference process as implemented in spm_MDP_VB_X.m in DEM, where the full posterior over both hidden states and policies is computed: $Q(s_{1:T}, \pi) = Q(s_{1:T} \mid \pi)\, Q(\pi)$. This function itself calls the run_mmp() function within the algos library, which estimates the marginal posteriors for hidden state factor $f$, at time point $\tau$ under policy $\pi$: $Q(s_\tau^f \mid \pi)$, using marginal message passing [parr2019neuronal] (for more details, see the algos.py summary below). This function calls run_mmp() once per policy, estimating the policy-conditioned posterior over all timepoints of the horizon and for all hidden state factors. The number of timepoints over which inference occurs is not necessarily identical to the total time horizon of the simulation: rather, this time horizon is a function of the number of previous observations (len(prev_obs)) and the temporal depth of the policy under consideration (len(policies[j])). Thus the hidden state beliefs indexed by $\tau$ refer to a finite time horizon that is relative to the current timestep $t$, where this horizon has a ‘lookback’ length (set by the number of past observations) and a ‘planning horizon’ (set by the temporal depth of the policies). As arguments, this function requires the observation (A) and transition (B) likelihoods of a generative model, a list of previous (including current) observations (prev_obs), a list of policies (policies), an optional list of actions taken up until the current timepoint (prev_actions), an optional prior over hidden states at the start of the time horizon (prior), and a keyword argument policy_sep_prior, which determines whether the prior is itself conditioned on policies ($P(s \mid \pi)$) or unconditioned on policies ($P(s)$). The additional **kwargs contain parameters that will be passed to the run_mmp() function in algos/mmp.py. For more details on run_mmp(), see the documentation.

Specific usage examples–in addition to descriptions of other specialized functions in inference.py–are more extensively covered in the official documentation.
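A brief usage sketch of the first of these functions is given below; the import path and the use of keyword arguments (which sidestep any ambiguity in positional order) are assumptions based on the module layout described above, and obs is assumed to be a list with one observation index per modality:

import numpy as np
from pymdp import utils
from pymdp.inference import update_posterior_states

# single modality (3 outcomes) and single hidden state factor (3 states)
A = utils.obj_array(1)
A[0] = np.array([[0.7, 0.2, 0.1],
                 [0.2, 0.7, 0.1],
                 [0.1, 0.1, 0.8]])   # P(o | s); columns sum to 1

prior = utils.obj_array(1)
prior[0] = np.ones(3) / 3.0          # flat prior over the single hidden state factor

obs = [0]                            # index of the observation received, one entry per modality

qs = update_posterior_states(obs=obs, A=A, prior=prior)
print(qs[0])                         # marginal posterior over the single hidden state factor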

Bayesian and Variational Inference

A central task in statistics is to perform inference, which can be mathematically represented as computing posterior distributions of one variable given another from a joint distribution. For example, suppose we are given some observation $o$, and we then want to infer the likely state $s$ underlying that observation. Critical to achieving this task is the possession of a generative model, or joint distribution $P(o, s)$, that tells us how observations and hidden states are related. We can formulate the problem of inferring $s$ as finding a posterior distribution over the states, given the observation: $P(s \mid o)$. We can compute this posterior distribution using our generative model and Bayes' rule:

$$P(s \mid o) = \frac{P(o \mid s)\, P(s)}{P(o)}$$

The generative model is often factorized into a likelihood $P(o \mid s)$ and a prior $P(s)$, while $P(o)$ is known as the marginal likelihood or model evidence and can be computed by solving the integral $P(o) = \int P(o \mid s)\, P(s)\, ds$. This expresses the idea that the marginal probability of observations $o$ is the sum (or integral) over all the ways that $o$ depends on $s$, for all possible settings of $s$.

While Bayes' rule provides a simple formula for relating the posterior distribution to the generative model, explicitly computing this distribution can often be difficult in practice due to the computational expense involved in performing the integral over all states necessary to compute the marginal likelihood $P(o)$. However, a number of approximate Bayesian inference methods have been developed which circumvent this computational difficulty at the expense of only returning approximately correct posterior distributions.

Variational inference [beal2003variational, wainwright2008graphical] is a widely used and well understood approach for performing approximate Bayesian inference. The central idea in variational inference is that instead of directly computing the posterior, we instead optimize the parameters of an arbitrary distribution $Q(s)$ so as to minimize the divergence between this distribution and the true posterior. This arbitrary distribution is often named the approximate posterior, because in the course of minimizing the divergence, the arbitrary distribution becomes an approximation to the true posterior, i.e. $Q(s) \approx P(s \mid o)$. In this way, variational inference converts a challenging inference problem (involving computing intractable integrals) into a relatively straightforward optimization problem, for which many powerful algorithms exist in the optimization literature.

Ideally, variational inference would directly minimize the Kullback-Leibler divergence between the approximate and true posteriors:

$$D_{KL}\big[Q(s)\,\|\,P(s \mid o)\big] = \mathbb{E}_{Q(s)}\big[\ln Q(s) - \ln P(s \mid o)\big]$$

In its present form, this objective is also intractable, since it depends on the true posterior $P(s \mid o)$, whose approximation is our goal. However, by supplementing the KL divergence with the (negative) log marginal likelihood, which does not depend upon $Q(s)$, we can convert the above KL divergence into an upper bound on the negative log marginal likelihood, called the variational free energy (VFE) $\mathcal{F}$. Crucially, this can be rearranged into a computable form:

$$\mathcal{F} = D_{KL}\big[Q(s)\,\|\,P(s \mid o)\big] - \ln P(o) = \mathbb{E}_{Q(s)}\big[\ln Q(s) - \ln P(o, s)\big]$$

Thus, by minimizing $\mathcal{F}$, we minimize the divergence between the approximate and true posterior, thus forcing the approximate posterior to more closely resemble the true one. Moreover, if this optimization finds the exact solution, such that $Q(s) = P(s \mid o)$, then the value of $\mathcal{F}$ is equal to the negative log marginal likelihood ($-\ln P(o)$), which can then be used for model selection and structure learning. More generally, $\mathcal{F}$ is known as an evidence bound, because the KL divergence can never be less than zero [mackay1992bayesian, penny2004modelling].

Variational Inference in pymdp

Active inference agents in pymdp perform inference over both hidden states and policies in a POMDP generative model. In this appendix, we only consider inference over states, while inference over policies is treated in the following section control.py.

Recall the POMDP generative model is a joint distribution over observations, states, policies, and parameters. For the purposes of hidden state inference, we will condition the whole generative model on some fixed policy $\pi$, so that we can re-write it as follows:

$$P(o_{1:T}, s_{1:T}, \phi \mid \pi) = P(\phi)\, P(s_1 \mid \phi) \prod_{t=1}^{T} P(o_t \mid s_t, \phi) \prod_{t=2}^{T} P(s_t \mid s_{t-1}, \pi, \phi) \quad (2)$$

For an active inference agent equipped with this generative model, instantaneous inference consists in optimizing an approximation to the posterior over the current hidden state $s_t$, given the past and future states and the observations $o_{1:t}$ collected up until the current timepoint $t$. Mathematically, this inference can be described as minimizing the following free energy over trajectories with respect to the variational parameters of the approximate posterior $Q$:

$$\mathcal{F}_{\pi} = \mathbb{E}_{Q(s_{1:T}, \phi \mid \pi)}\big[\ln Q(s_{1:T}, \phi \mid \pi) - \ln P(o_{1:t}, s_{1:T}, \phi \mid \pi)\big]$$

Thus the goal of hidden state inference is the optimization of the variational parameters which parameterize the approximate posterior over hidden states $Q(s_{1:T} \mid \pi)$. In the case of our discrete POMDP generative model, the variational parameters are the sufficient statistics of categorical distributions. Fortunately, these parameters are easy to interpret, since they are identical to the probabilities of sampling each outcome level in the distribution’s support: for example, the variational parameters of a categorical posterior over three levels are simply the three probabilities assigned to those levels. In the equations to follow, we therefore exclude the variational parameters when writing the variational posterior, referring to it hereafter as simply $Q$. Below we describe the variational inference methods currently offered by pymdp as of this document’s writing (November 2021).

The run_vanilla_fpi() function (within algos/fpi.py) implements an inference algorithm known as fixed-point iteration to optimize the posterior over hidden states $Q(\mathbf{s}_t)$ at a given timestep $t$. Central to this algorithm is the assumption of a factorized structure to the variational posterior, such that the posterior at time $t$ is independent of the posterior at any other timestep $t'$, where $t' \neq t$. In addition to this temporal factorization, we further assume the posterior at a given timestep is factorized across different hidden state factors: i.e. $Q(s_t^f)$ is independent of $Q(s_t^{f'})$ for $f \neq f'$ (see the section Factorized representations for more on multi-factor hidden states). This factorization is also known as a mean-field approximation in the statistics and physics literatures and can be expressed as follows:

$$Q(s_{1:T} \mid \pi) = \prod_{t=1}^{T} \prod_{f=1}^{F} Q(s_t^f \mid \pi)$$
Given this factorization, the full free energy over trajectories now also factorizes into a sum of free energies across time, which can be minimized independently of one another. Thus, for a given time $t$, we can write the time-dependent free energy as follows (for the remainder of the section we omit the hyperparameters $\theta$ from the generative model and approximate posterior; inference over these is treated in the section on learning.py):

$$F_t = \mathbb{E}_{Q(\mathbf{s}_t \mid \pi)\, Q(\mathbf{s}_{t-1} \mid \pi)}\big[\ln Q(\mathbf{s}_t \mid \pi) - \ln P(\mathbf{o}_t \mid \mathbf{s}_t) - \ln P(\mathbf{s}_t \mid \mathbf{s}_{t-1}, \pi)\big] \tag{3}$$

where we now use the bold notation $\mathbf{s}_t$ and $\mathbf{o}_t$ to express potentially multi-factor (or multi-modal) hidden states (and observations) in the generative model, in the same way that the variational posterior is factorized. Inference proceeds by optimizing $Q(\mathbf{s}_t \mid \pi)$ in order to minimize the timestep-specific free energy $F_t$.

Importantly, we can solve for the variational posterior analytically, for a given timestep $t$ and factor $f$, by setting the partial derivative of the free energy, $\frac{\partial F_t}{\partial Q(s_t^f)}$, to zero and solving for $Q(s_t^f)$. This yields the following fixed-point solution:

$$Q(s_t^f) = \sigma\Big(\mathbb{E}_{Q(\mathbf{s}_t^{\setminus f})}\big[\ln P(\mathbf{o}_t \mid \mathbf{s}_t)\big] + \mathbb{E}_{Q(s_{t-1}^f)}\big[\ln P(s_t^f \mid s_{t-1}^f, \pi)\big]\Big) \tag{4}$$

where the expectation $\mathbb{E}_{Q(\mathbf{s}_t^{\setminus f})}$ denotes an expectation with respect to all posterior marginals besides the marginal currently being optimized, and $\sigma$ is a normalized exponential or softmax function. The update equation for each marginal posterior offers an intuitive Bayesian interpretation, where the belief about the current state is the product of an observation likelihood term and a ‘prior’ term, $\mathbb{E}_{Q(s_{t-1}^f)}\big[\ln P(s_t^f \mid s_{t-1}^f, \pi)\big]$, where the prior is dynamically determined by the previous state, the previous action, and the transition likelihood. Note that in practice we set the prior at any given timestep equal to the posterior optimized at the previous timestep pushed through the transition model, i.e. $P(s_t^f) = \mathbb{E}_{Q(s_{t-1}^f)\,\delta(u_{t-1})}\big[P(s_t^f \mid s_{t-1}^f, u_{t-1})\big]$, where $\delta(u_{t-1})$ is a Dirac delta function over the action actually taken (the agent has perfect knowledge of the action it just took). This means that the prior term in Equation (4) can be rewritten as $\ln \mathbb{E}_{Q(s_{t-1}^f)}\big[P(s_t^f \mid s_{t-1}^f, u_{t-1})\big]$. In the run_vanilla_fpi() function of pymdp, the fixed point equation is solved iteratively for each marginal posterior $Q(s_t^f)$, using the latest solution for the other marginals to compute the expected log-likelihood term $\mathbb{E}_{Q(\mathbf{s}_t^{\setminus f})}\big[\ln P(\mathbf{o}_t \mid \mathbf{s}_t)\big]$ for the marginal currently being updated. Note that the prior only depends on the marginal currently being updated, because in pymdp the transition likelihoods are assumed to be independent across hidden state factors, i.e. hidden states from one factor $f$ do not determine the dynamics of hidden states of another factor $f'$ (see Appendix A: Factorized Generative Models for details). Given enough iterations (the default for run_vanilla_fpi() is num_iter=10, but in many generative models with precise likelihood arrays, i.e. low-entropy rows/columns, convergence is achieved in far fewer iterations, and is further controlled by a tolerance parameter that tracks the change in the free energy across iterations), the fixed point equations converge to a unique solution for the variational posterior $Q(\mathbf{s}_t)$.
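To make the iteration concrete, here is a schematic NumPy sketch of this fixed-point scheme for a two-factor model with a single observation modality. The array shapes and variable names are illustrative; this is not the internal code of run_vanilla_fpi().

    import numpy as np

    def softmax(x):
        e = np.exp(x - x.max())
        return e / e.sum()

    num_obs, n1, n2 = 3, 2, 4
    rng = np.random.default_rng(0)

    # A[o, s1, s2]: likelihood of each observation given both hidden state factors
    A = rng.dirichlet(np.ones(num_obs), size=(n1, n2)).transpose(2, 0, 1)
    log_A = np.log(A)

    o = 1                                                        # observed outcome index
    log_prior = [np.log(np.full(n1, 1 / n1)), np.log(np.full(n2, 1 / n2))]
    qs = [np.full(n1, 1 / n1), np.full(n2, 1 / n2)]              # initial marginals

    for _ in range(10):  # analogous to num_iter
        # update factor 0: expected log-likelihood under the factor-1 marginal
        qs[0] = softmax((log_A[o] * qs[1][None, :]).sum(axis=1) + log_prior[0])
        # update factor 1: expected log-likelihood under the new factor-0 marginal
        qs[1] = softmax((log_A[o] * qs[0][:, None]).sum(axis=0) + log_prior[1])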

In pymdp, the approximate posterior is represented as a collection of 1-D NumPy arrays (e.g. qs), where individual elements of the collection (e.g. qs[f]) store the marginal posterior for a particular hidden state factor. The likelihood distributions are represented by A and B arrays. If we take the limiting case of a generative model and variational posterior with a single hidden state factor, the update equation for the variational posterior at a given timestep using fixed point iteration reduces to a single line of NumPy code:

    qs_current = softmax(np.log(A[o, :]) + np.log(B[:, :, u_last].dot(qs_last)))

where utility functions like softmax are available from the utils.py and maths.py modules of pymdp. In the default initialization of the Agent class, the infer_states() method will call the update_posterior_states() function of the inference module; this in turn calls upon fixed point iteration (via run_vanilla_fpi()) to update the variational posterior over hidden states. Therefore all that is required for inference at a given timestep is to provide some observation obs to infer_states(). The prior over hidden states is automatically updated within the Agent class, where prior will either A) equal the initial belief about hidden states (the D vector) in the case that $t = 0$; or B) be the dot product of the transition likelihood conditioned on the last action, B[:,:,u_last], and the posterior at the last timestep, qs_last, in the case that $t > 0$: prior=B[:,:,u_last].dot(qs_last).
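In practice, state inference through the Agent interface then looks something like the following sketch. The random-array helpers and the exact return value of infer_states() are assumptions based on the package's utilities and may differ slightly across versions.

    import numpy as np
    from pymdp import utils
    from pymdp.agent import Agent

    # A minimal sketch of state inference with the Agent class.
    num_obs, num_states, num_controls = [3], [3], [3]    # one modality, one factor
    A = utils.random_A_matrix(num_obs, num_states)       # observation model P(o | s)
    B = utils.random_B_matrix(num_states, num_controls)  # transition model P(s' | s, u)

    my_agent = Agent(A=A, B=B)        # fixed-point iteration is the default algorithm
    obs = [0]                         # observation index for the single modality
    qs = my_agent.infer_states(obs)   # marginal posterior(s) over hidden states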

Another, more complex inference algorithm known as marginal message passing (MMP) is also implemented, in the run_mmp() function, also found within the algos module. Marginal message passing makes weaker assumptions about the factorization of the variational posterior, and incorporates the computational advantages of two well-known message passing algorithms: belief propagation [yedidia2000generalized] and variational message passing [winn2005variational]. In practice, using marginal message passing instead of standard fixed-point iteration enables more accurate inference, due to its less restrictive assumptions about the form of the variational posterior, at the expense of additional computational cost. For brevity, and since this algorithm has been discussed in detail elsewhere (see in particular Appendix C of [friston2017process] as well as the comprehensive treatment in [parr2019neuronal]), we do not describe the mathematics behind marginal message passing here.

It is worth noting that, for beginning users, a standard active inference simulation will not require marginal message passing to achieve the desired behavior; state inference achieved with instantaneous fixed-point iteration often suffices for practitioners interested in simulating a target behavior. However, cognitive neuroscientists are often interested in modelling neuronal responses based on estimated inferential dynamics. In this case, more sophisticated schemes like run_mmp() may be required, where the actual dynamics of belief updating can be used as a forward model of hypothesized electrophysiological processes (e.g. local field potentials or spiking activity). Finally, we also note that in order to achieve identical behavior to active inference agents simulated using spm_MDP_VB_X.m, it is necessary to use run_mmp().

In practice, a desired inference algorithm can be specified by passing the name of the algorithm into the Agent() constructor, e.g. my_agent = Agent(..., inference_algo='MMP'). For more detailed instructions on how to initialize an Agent with different customization options, please see the documentation.

control.py

The core functions implementing policy inference and action selection (i.e. control) in pymdp can be found in the control.py file. As with inference, these functions are called by methods of Agent like infer_policies() and sample_action(), but can also be directly imported from the control library and used for custom applications. Below we briefly summarize the core functions of the control module:

  • update_posterior_policies(qs, A, B, C, policies, use_utility=True, use_states_info_gain=True, use_param_info_gain=False, pA=None, pB=None, E=None, gamma=16.0): This function computes the posterior over policies using an initial posterior belief about hidden states at the current timestep, qs. Specifically, it computes the expected free energy of each policy (discussed further in the next section), summing the expected free energies over a future path in the case of multi-timestep or ‘temporally-deep’ policies. The function loops over all policies, computes the expected states and observations under each policy, and then sums the expected free energies calculated from those predicted future states and observations. The expected free energy of each policy is combined with its prior probability under the generative model, $P(\pi)$ (in pymdp represented by the E vector), and softmaxed to determine the posterior over policies (in pymdp represented by q_pi). Optional Boolean parameters like use_utility and use_states_info_gain can be toggled to selectively enable or disable computation of the individual components of the expected free energy (see the section on The Expected Free Energy for information on these different terms).

  • update_posterior_policies_full(qs_seq_pi, A, B, C, policies, use_utility=True, use_states_info_gain=True, use_param_info_gain=False, prior=None, pA=None, pB=None, F=None, E=None, gamma=16.0): This function computes the posterior over policies using posterior beliefs over hidden states across multiple timesteps, under all policies. It differs from the standard function, update_posterior_policies(), in that the expected hidden states over future timepoints, under different policies, have already been computed and are passed in as the input qs_seq_pi. This function for policy inference should thus be used in tandem with the ‘advanced’ inference schemes (like marginal message passing), where posterior beliefs over multiple timesteps, under all policies, are computed during the inference step. As a consequence, this function only computes the expected observations under each policy for all future timesteps, and then uses the expected states (already part of the inputs) and expected observations under all policies to calculate the expected free energy of each policy. This is integrated with the prior belief about policies E and the variational free energy of policies F (see the section on Policy Inference for more information on the variational free energy of policies) to determine the posterior over policies $Q(\pi)$, often represented in pymdp as q_pi. The remaining arguments (e.g. use_utility) are identical to those of the standard update_posterior_policies() function.

  • get_expected_states(qs, B, policy): This function computes a predictive distribution over future states given a current state distribution (qs), a transition model (B), and a policy (policy). Specifically, it projects the current beliefs about hidden states forward in time by iteratively taking the dot product of the action-conditioned B array with qs, where the actions are those entailed by the policy.

  • get_expected_obs(qs_pi, A): This function computes the observations expected under a (policy-conditioned) hidden state distribution qs_pi. In the case of a sequence of hidden states over time, qs_pi will be a list of hidden state distributions with one element per timestep, e.g. qs_pi[t]. This function only requires an expected state distribution qs_pi and an observation model A (a schematic NumPy sketch of these forward projections is given after this list).

  • calc_expected_utility(qo_pi, C): This function computes the extrinsic value or utility part of the expected free energy, using the prior preferences or ‘goal distribution’ encoded by the C vector. The C vector is encoded in terms of relative log probabilities and thus need not be a proper probability distribution.

  • calc_states_info_gain(A, qs_pi): This function computes the intrinsic value or information gain part of the expected free energy, in particular the information gain or epistemic value about hidden states.

  • calc_pA_info_gain(pA, qo_pi, qs_pi): This function computes the information gain about the Dirichlet prior parameters over the observation model (A), also known as the ‘novelty’ term of the expected free energy [friston2017active]. It requires a Dirichlet prior over the observation model pA, an expected observation distribution qo_pi, and an expected state distribution qs_pi. It is recommended to include this information gain term in the expected free energy calculation when also simultaneously performing A array learning (i.e. inference over Dirichlet hyperparameters), since it leads the agent to explore regions that yield the largest updates to the parameters of A.

  • calc_pB_info_gain(pB, qs_pi, qs_prev, policy): This function computes the information gain about the Dirichlet prior parameters over the transition model (B), also known as the ‘novelty’ term of the expected free energy. It requires a Dirichlet prior over the transition model pB, an expected state distribution under a policy qs_pi, an initial state distribution qs_prev, and a policy policy. It is recommended to include this information gain term in the expected free energy calculation, when also simultaneously performing B array learning (i.e. inference over Dirichlet hyperparameters).

  • construct_policies(num_states, num_controls=None, policy_len=1,
    control_fac_idx=None): This is a utility function which builds an array of policies by combinatorially enumerating them from a set of actions and a time horizon. It can be used to construct a full set of policies based on the time horizon and action space of the environment, if the policy set is not explicitly stated by the user.

  • sample_action(q_pi, policies, num_controls, action_selection="deterministic", alpha=16.0): This function samples an action, given the posterior distribution over policies and a desired sampling scheme. In particular, it computes the posterior over control states by marginalising the posterior over policies with respect to each control state, i.e. $Q(u_t) = \sum_\pi \delta(u_t, \pi(t))\, Q(\pi)$, where $\pi(t)$ is the mapping from a policy to the control state it entails at time $t$. An action is then obtained either A) deterministically (action_selection="deterministic"), by selecting the most probable control state, or B) by sampling from the control posterior (action_selection="stochastic"), using a Boltzmann distribution with inverse temperature given by alpha.
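To make the role of the forward projections concrete, the following fragment re-implements the logic of get_expected_states() and get_expected_obs() in plain NumPy, for a single hidden state factor and observation modality. It is a schematic sketch, not pymdp's own code.

    import numpy as np

    # Schematic re-implementation of the forward projections used during policy
    # evaluation (single factor, single modality). B has shape
    # (num_states, num_states, num_controls); A has shape (num_obs, num_states).
    def expected_states(qs, B, policy):
        """Push current state beliefs through the transition model along a policy."""
        qs_pi = []
        for u in policy:                    # policy: a sequence of control indices
            qs = B[:, :, u].dot(qs)
            qs_pi.append(qs)
        return qs_pi

    def expected_obs(qs_pi, A):
        """Expected observation distribution at each future timestep."""
        return [A.dot(qs_t) for qs_t in qs_pi]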

Control in Active Inference

Policy inference consists in computing the ‘goodness’ or ‘quality’ of each policy, given the ability to compute the expected consequences of each policy and the agent’s goals. In active inference this is done using a quasi-utility function known in the literature as the Expected Free Energy (EFE), often denoted $G$. Under active inference, agents are equipped with a particular prior over policies that assigns each policy a probability inversely related to the free energy expected under its pursuit, i.e.:

$$P(\pi) = \sigma\big({-}G_\pi\big) \tag{5}$$

Equipped with this policy prior in the generative model, active inference agents perform policy inference by optimizing $Q(\pi)$, the variational posterior over policies. As we shall see in the section Policy Inference, computing $Q(\pi)$ entails computing the expected free energy of each policy (the contribution from the prior) as well as the variational free energy of policies (analogous to the ‘evidence’ for each policy).

The Expected Free Energy

The expected free energy is the crucial component that determines the behavior of active inference agents. The EFE is designed to be similar to the VFE of standard variational inference, but with two major modifications that allow it to serve as an objective which, when minimized, produces goal-seeking behavior rather than simple inference. Firstly, since the EFE scores future performance, where future observations are not yet known, it contains an expectation over future observations. Secondly, to integrate the notion of goals or rewards into the inference procedure, the EFE uses a generative model that is ‘biased’ in such a way that it predicts the agent reaches rewarding or a priori preferred states [parr2019generalised]. Performing inference to maximize the likelihood of visiting these rewarding states thus naturally leads to policies that help the agent achieve its goals. Moreover, minimizing the EFE also entails an exploratory, inherently uncertainty-reducing component of behavior, endowing agents with an additional ‘epistemic drive’ that aids in computing optimal long-term policies [friston2015active]. For in-depth discussion of the nature of the EFE and the exploratory drive it induces, please see [friston2015active, friston2017process, millidge2021whence, millidge2021understanding].

The expected free energy is a function of observations, states, and policies, and is defined mathematically as:

$$G_\pi = \mathbb{E}_{Q(o_{1:T}, s_{1:T} \mid \pi)}\big[\ln Q(s_{1:T} \mid \pi) - \ln \tilde{P}(o_{1:T}, s_{1:T} \mid \pi)\big] \tag{6}$$

where $\tilde{P}$ represents a generative model ‘biased’ towards the preferences of the agent. We can write this predictive generative model at a single timestep, under a given policy, as $\tilde{P}(o_\tau, s_\tau \mid \pi) = P(s_\tau \mid o_\tau, \pi)\,\tilde{P}(o_\tau)$, where $\tilde{P}(o_\tau)$ represents a ‘predictive prior’ over observations, represented in pymdp with the C array. Given the factorization of the approximate posterior over time, the EFE for a single policy and timestep can also be defined as follows:

$$
\begin{aligned}
G_\pi(\tau) &= \mathbb{E}_{Q(o_\tau, s_\tau \mid \pi)}\big[\ln Q(s_\tau \mid \pi) - \ln \tilde{P}(o_\tau, s_\tau \mid \pi)\big] \\
&= -\,\mathbb{E}_{Q(o_\tau \mid \pi)}\Big[D_{KL}\big[Q(s_\tau \mid o_\tau, \pi) \,\|\, Q(s_\tau \mid \pi)\big]\Big] \;-\; \mathbb{E}_{Q(o_\tau \mid \pi)}\big[\ln \tilde{P}(o_\tau)\big] \;+\; \mathbb{E}_{Q(o_\tau \mid \pi)}\Big[D_{KL}\big[Q(s_\tau \mid o_\tau, \pi) \,\|\, P(s_\tau \mid o_\tau, \pi)\big]\Big] \\
&\geq -\underbrace{\mathbb{E}_{Q(o_\tau \mid \pi)}\Big[D_{KL}\big[Q(s_\tau \mid o_\tau, \pi) \,\|\, Q(s_\tau \mid \pi)\big]\Big]}_{\text{epistemic value}} \;-\; \underbrace{\mathbb{E}_{Q(o_\tau \mid \pi)}\big[\ln \tilde{P}(o_\tau)\big]}_{\text{utility}}
\end{aligned} \tag{7}
$$

where the first term, the epistemic value [friston2015active], encourages the pursuit of policies expected to yield high information gain about hidden states, expressed here as the divergence between the states predicted under a policy with and without conditioning on observations. The second term represents the degree to which expected outcomes under a policy will align with prior preferences over observations. Since the prior over policies is inversely proportional to the expected free energy, policies will thus be more likely if they visit states that resolve uncertainty (maximize epistemic value) and satisfy prior preferences (maximize utility). The epistemic value term gives active inference agents a degree of directed exploration that standard reinforcement learning agents lack. In pymdp, the EFE is computed using exactly this decomposition into epistemic value and utility, where the expected approximation error (the penultimate line of Equation (7)) is implicitly assumed to be 0, so the bound becomes an equality. The utility term is computed by the function calc_expected_utility(), while the epistemic value term (also known as the information gain) is computed by the function calc_states_info_gain(). Both of these functions are found within control.py. The computation of the utility term is particularly straightforward for categorical distributions, since it reduces to the dot product of the expected observations under a policy with the log of the prior preferences or ‘goal vector’ $\ln \tilde{P}(\mathbf{o})$, i.e. the C array.
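The following schematic NumPy functions illustrate how these two terms can be computed for categorical distributions (single factor and modality). They mirror the logic of calc_expected_utility() and calc_states_info_gain(), but are not the library's own implementations; in particular, the way C is normalized below is an assumption.

    import numpy as np

    def expected_utility(qo_pi, C):
        """Dot product of expected observations with log prior preferences."""
        # C holds relative log probabilities; normalize into a distribution first
        C_dist = np.exp(C - C.max())
        C_dist = C_dist / C_dist.sum()
        return qo_pi.dot(np.log(C_dist))

    def states_info_gain(A, qs_pi):
        """Expected information gain about hidden states (a mutual information)."""
        joint = A * qs_pi[None, :]             # predictive joint over (o, s)
        qo = joint.sum(axis=1)                 # predictive marginal over observations
        return np.sum(joint * (np.log(joint + 1e-16)
                               - np.log(qo[:, None] + 1e-16)
                               - np.log(qs_pi[None, :] + 1e-16)))

    # Up to the approximation discussed above, the negative expected free energy for
    # one timestep is the sum of these two terms:
    # -G_t = states_info_gain(A, qs_pi) + expected_utility(qo_pi, C)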

Parameter Information Gain

In the case where the agent also maintains a variational posterior over parameters $Q(\theta)$, the timestep- and policy-dependent EFE has an augmented form, since it needs to account for the expected information gain over both hidden states and parameters [friston2017active]:

$$G_\pi(\tau) = -\,\mathbb{E}_{Q(o_\tau \mid \pi)}\Big[D_{KL}\big[Q(s_\tau \mid o_\tau, \pi)\,\|\,Q(s_\tau \mid \pi)\big]\Big] \;-\; \mathbb{E}_{Q(o_\tau, s_\tau \mid \pi)}\Big[D_{KL}\big[Q(\theta \mid o_\tau, s_\tau, \pi)\,\|\,Q(\theta)\big]\Big] \;-\; \mathbb{E}_{Q(o_\tau \mid \pi)}\big[\ln \tilde{P}(o_\tau)\big] \tag{8}$$

So the EFE is now supplemented with an additional epistemic value term, the so-called ‘parameter’ epistemic value or ‘parameter information gain’. This additional term arises when the approximate posterior includes variational beliefs about model hyperparameters, $Q(\theta)$. The optimization of the posterior over model parameters is handled in the next section, on the learning.py module. The presence of this term in the expected free energy mediates what has also been referred to as ‘active learning’ or ‘model exploration’, i.e. the drive to resolve uncertainty about the parameters of one’s generative model [schwartenbeck2019computational].

In the discrete state space case implemented in pymdp, this parameter epistemic value is computed with respect to the Dirichlet parameters (conjugate priors over categorical distributions) that parameterise the prior and approximate posterior over the likelihoods and priors of the generative model, i.e. the A, B, C and D arrays. As of the time of writing (December 2021), this is implemented for information gain about the parameters of the A and B arrays, parameterised respectively by the Dirichlet conjugate priors pA and pB. The relevant functions for computing these information gains are calc_pA_info_gain() and calc_pB_info_gain().

Policy Inference

Given the definition of the expected free energy in Equations (6) and (7), we are now equipped to describe posterior inference over policies, i.e. how to obtain $Q(\pi)$.

We begin by expanding the variational free energy as defined in Equation (2), dropping parameters for simplicity:

$$F = \mathbb{E}_{Q(\pi)}\big[F_\pi\big] + D_{KL}\big[Q(\pi)\,\|\,P(\pi)\big] \tag{9}$$

where the variational free energy of a particular policy, $F_\pi$, is defined as follows:

$$F_\pi = \mathbb{E}_{Q(s_{1:T} \mid \pi)}\big[\ln Q(s_{1:T} \mid \pi) - \ln P(o_{1:t}, s_{1:T} \mid \pi)\big] \tag{10}$$

The optimal posterior that minimizes the full variational free energy is found by taking the derivative of $F$ with respect to $Q(\pi)$ and setting this gradient to zero, yielding the following free-energy-minimizing solution for $Q(\pi)$:

$$Q(\pi) = \sigma\big({-}G_\pi - F_\pi\big) \tag{11}$$

where the prior over policies is the softmax of the negative expected free energy, $P(\pi) = \sigma(-G_\pi)$. Note that in the case of "temporally deep" or multi-timestep policies, the expected free energy of a given policy is the sum of the timestep-specific expected free energies:

$$G_\pi = \sum_{\tau} G_\pi(\tau) \tag{12}$$

In pymdp and the DEM toolbox in MATLAB, one has the option of augmenting the prior over policies with a ‘baseline policy’ or ‘habit vector’ $\mathbf{E}$, also referred to as the E vector. This means the full expression for the optimal posterior can be written as (expanding $P(\pi)$ as $\sigma(\ln \mathbf{E} - G_\pi)$):

$$Q(\pi) = \sigma\big(\ln \mathbf{E} - G_\pi - F_\pi\big) \tag{13}$$

This means the inferred policy distribution combines influences from the expected free energy of each policy ($G_\pi$), a baseline prior probability assigned to each policy ($\mathbf{E}$), and the variational free energy of each policy ($F_\pi$). Numerically, policy inference is achieved by computing the expected and variational free energies of each policy and then combining them with the policy prior (the E vector) before softmaxing them. The expected free energy is computed per policy as the sum of the timestep-specific expected free energies defined in Equation (7). This is achieved by computing the ‘posterior predictive densities’ expected under each policy,

$$Q(\mathbf{s}_\tau \mid \pi), \qquad Q(\mathbf{o}_\tau \mid \pi) = \mathbb{E}_{Q(\mathbf{s}_\tau \mid \pi)}\big[P(\mathbf{o}_\tau \mid \mathbf{s}_\tau)\big],$$

and using those densities to compute and add together the epistemic value and utility for each policy. These posterior predictive densities are simply the posterior beliefs at the current timestep ‘multiplied through’ the transition and observation models (in code: the B and A arrays) over the temporal horizon of the policy. By doing this iteratively across policies, the G vector ends up storing a ‘cost’ for each policy, which is then integrated with the policy prior and the variational free energy of each policy to determine the posterior probability of each policy, stored in q_pi. This boils down to the following single line of NumPy code:

    q_pi = softmax(-G + np.log(E) - F)

In pymdp, the functions update_posterior_policies() and update_posterior_policies_full() of the control module perform the calculations needed for policy inference; they are themselves called by the infer_policies() method of Agent.
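Putting the pieces together, a typical simulation alternates these Agent methods with an environment. In the sketch below, DummyEnv is a stand-in environment (not part of pymdp) that returns random observation indices, and my_agent is assumed to have been constructed as in the state-inference example above; the return values of the inference methods are omitted for brevity.

    import numpy as np

    class DummyEnv:
        """A stand-in environment (not part of pymdp) with one 3-level observation."""
        def reset(self):
            return [0]
        def step(self, action):
            return [np.random.randint(3)]

    env = DummyEnv()
    obs = env.reset()
    for t in range(10):
        my_agent.infer_states(obs)          # hidden state inference (inference module)
        my_agent.infer_policies()           # policy inference (control module)
        action = my_agent.sample_action()   # action selection (control module)
        obs = env.step(action)              # act and receive the next observation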

learning.py

In this section we summarize the functions in the learning module and then derive the update equations for the model parameters of the likelihood and prior distributions that comprise POMDP generative models.

The functions used to implement inference over model parameters can be found in the learning.py file. These functions are called by methods of Agent like update_A(), update_B(), and update_D(). We survey the most important functions of the learning module below:

  • update_obs_likelihood_dirichlet(pA, A, obs, qs, lr=1.0, modalities="all"):
    This function computes the posterior Dirichlet parameters over the A array or observation model $P(o \mid s)$. As input arguments this function requires the current Dirichlet prior pA over the parameters of the A array, the current value of the categorical A array (which is also the expected value of the Dirichlet prior pA), an observation obs, the current posterior beliefs about hidden states qs, a learning rate lr and a list of which observation modalities to update, modalities. The default setting is to update the A arrays associated with all observation modalities (modalities = "all"), but this extra argument allows one to only update specific sub-arrays of a larger multi-modality A array. For example, modalities = [0, 1] would only update sub-arrays A[0] and A[1]. The learning rate parameter scales the size of the update to the posterior over the A array.

  • update_state_likelihood_dirichlet(pB, B, actions, qs, qs_prev, lr=1.0,
    factors="all")
    : This function computes the posterior Dirichlet parameters over the B array or transition model $P(s_t \mid s_{t-1}, u_{t-1})$. As input arguments this function requires the current Dirichlet prior pB over the parameters of the B array, the current value of the categorical B array, the posterior beliefs about hidden states at the current timestep qs, the posterior beliefs about hidden states at the previous timestep qs_prev, a learning rate lr and a list of which hidden state factors to update, factors. The default setting is to update the B arrays associated with all hidden state factors (factors = "all"), but this extra argument allows you to only update specific sub-arrays of a larger multi-factor B array. For example, factors = [0, 1] would only update sub-arrays B[0] and B[1]. The learning rate parameter scales the size of the update to the posterior over the B array.

  • update_state_prior_dirichlet(pD, qs, lr=1.0, factors="all"): This function computes the posterior Dirichlet parameters over the D array or prior over initial hidden states $P(s_1)$. As input this function requires the current Dirichlet prior pD over the parameters of the D array, the posterior beliefs about hidden states at the current timestep qs, a learning rate lr and a list of which hidden state factors to update, factors. The default setting is to update the D vectors associated with all hidden state factors (factors = "all"), but this extra argument allows you to only update specific sub-vectors of a larger multi-factor D array. For example, factors = [0, 1] would only update sub-arrays D[0] and D[1]. The learning rate parameter scales the size of the update to the posterior over the D array (a schematic NumPy illustration of these Dirichlet updates follows this list).
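The fragment below is a schematic illustration of the pseudo-count updates these functions perform, for a single modality and factor. Variable names are illustrative and this is not the library's own code.

    import numpy as np

    # Schematic Dirichlet (pseudo-count) updates for a single modality and factor.
    num_obs, num_states = 3, 2
    pA = np.ones((num_obs, num_states))   # Dirichlet prior over the A array
    pD = np.ones(num_states)              # Dirichlet prior over the D vector
    lr = 1.0                              # learning rate

    o = 1                                 # index of the observation received
    qs = np.array([0.8, 0.2])             # current posterior over hidden states

    one_hot_o = np.eye(num_obs)[o]
    pA_post = pA + lr * np.outer(one_hot_o, qs)   # Hebbian-like increment for A
    pD_post = pD + lr * qs                        # count-like increment for D

    # The expected (categorical) A array is the column-normalized Dirichlet posterior
    A_expected = pA_post / pA_post.sum(axis=0, keepdims=True)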

Inference of POMDP model parameters

Under active inference, learning is cast as inference about model parameters, and in the context of neuroscience is often analogized to slower-scale changes to inter-neuronal synaptic weights (e.g. Hebbian learning) [da2020active]. Parameter inference is referred to as ‘learning’ because it is often assumed to occur on a fundamentally slower timescale than hidden state and policy inference [friston2016active]. However, the update equations for model parameters follow the exact same principles as hidden state inference: namely, we optimize a variational posterior over model parameters by minimizing the variational free energy $F$.

For the POMDP generative models used in pymdp, learning manifests as posterior inference over hyperparameters of the (categorical) likelihood and priors of the generative model. We use Dirichlet distributions as conjugate priors for the categorical distributions (a prior is called conjugate to a likelihood when the resulting posterior belongs to the same distribution family as the prior), meaning that the hyperparameters become the parameters of Dirichlet distributions. This choice of parameterization results in remarkably simple and biologically-plausible updates for the posteriors over these parameters, wherein ‘fire-together-wire-together’-like Hebbian increments are used to learn the parameters as a function of observations. Below we derive the update rule for the Dirichlet hyperparameters over the A, B, and D arrays.

To begin, we augment the POMDP generative model in (1) with the parameters of the likelihood and prior categorical distributions, and with Dirichlet priors over each of them. In order to do this, we divide the hyperparameters $\theta$ of the generative model into subsets that correspond to the categorical and Dirichlet parameters of the A, B, and D arrays:

$$
\begin{aligned}
P(o_\tau \mid s_\tau, A) &= \mathrm{Cat}(A), &\quad P(A_{\bullet i}) &= \mathrm{Dir}(a_{\bullet i}) \\
P(s_\tau \mid s_{\tau-1}, u_{\tau-1}, B) &= \mathrm{Cat}(B), &\quad P(B_{\bullet i, u}) &= \mathrm{Dir}(b_{\bullet i, u}) \\
P(s_1 \mid D) &= \mathrm{Cat}(D), &\quad P(D) &= \mathrm{Dir}(d)
\end{aligned} \tag{14}
$$

where the notation $X_{\bullet i}$ denotes the $i$-th column of a matrix $X$. Under this parameterisation, $A$, $B$, and $D$ are arrays of categorical parameters (i.e. probabilities) that ‘fill out’ the entries of the A, B, and D arrays respectively. The Dirichlet parameters $a$, $b$, and $d$ are similarly the parameters of Dirichlet priors over these categorical distributions, and have identical dimensionality to the distributions they parameterise. The Dirichlet parameters are constrained to be positive real numbers ($a, b, d > 0$) that score the prior probability of each entry of the categorical distribution they parameterize. Dirichlet values, like the parameters of other common conjugate prior distributions, can be interpreted as ‘pseudo-counts’ measuring how often a particular outcome level is expected a priori (e.g. the prior probability assigned to a particular state-observation coincidence, in the case of the $a$ parameters). Note that for notational convenience we assume the generative model is not factorized into multiple hidden state factors and observation modalities, but for generality one could add additional indices to capture multiple hidden state factors and observation modalities. For instance, the most general form of a (potentially multi-modality and multi-factor) observation model would be:

$$P(o^m_\tau \mid s^1_\tau, \dots, s^F_\tau, A^m) = \mathrm{Cat}(A^m), \qquad m = 1, \dots, M$$
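As a brief numerical illustration of the pseudo-count interpretation above (the values here are arbitrary and purely illustrative), a Dirichlet prior over a two-level initial-state distribution with parameters

$$d = (2,\, 8) \quad\Rightarrow\quad \mathbb{E}[D] = \frac{d}{\sum_i d_i} = (0.2,\; 0.8)$$

expects the second state four times as often as the first a priori, and each new ‘count’ of a state simply increments the corresponding entry of $d$.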
Given the introduction of the new Dirichlet priors in Equation (14), we can now write down the augmented generative model, where the hyperparameters have been split into individual priors over $A$, $B$, and $D$:

$$P(o_{1:T}, s_{1:T}, A, B, D \mid \pi) = P(A)\,P(B)\,P(D)\,P(s_1 \mid D)\prod_{\tau=2}^{T} P(s_\tau \mid s_{\tau-1}, \pi, B)\prod_{\tau=1}^{T} P(o_\tau \mid s_\tau, A) \tag{15}$$

Given the new generative model with Dirichlet priors, we can now formulate learning as approximate inference about these parameters, i.e. optimizing variational posteriors over the likelihood and prior parameters. We begin by expanding our expression of the variational posterior to include beliefs over the values of the $A$, $B$, and $D$ distributions:

$$Q(s_{1:T}, A, B, D \mid \pi) = Q(A)\,Q(B)\,Q(D)\prod_{t=1}^{T} Q(s_t \mid \pi), \qquad Q(A_{\bullet i}) = \mathrm{Dir}(a'_{\bullet i}), \quad Q(B_{\bullet i, u}) = \mathrm{Dir}(b'_{\bullet i, u}), \quad Q(D) = \mathrm{Dir}(d')$$
where now the variational parameters $a'$, $b'$, and $d'$ are the Dirichlet parameters of the approximate posteriors $Q(A)$, $Q(B)$, and $Q(D)$, respectively. Performing inference with respect to $A$, $B$, and $D$ thus amounts to optimizing these variational Dirichlet parameters in order to minimize free energy. This is what is meant by ‘learning’ in active inference.

We will now step through the update rules for each of the Dirichlet posteriors over the $A$, $B$, and $D$ distributions. We begin by writing down the full variational free energy: