Approximate Inference in Discrete Distributions with Monte Carlo Tree Search and Value Functions

Lars Buesing et al., 10/15/2019

A plethora of problems in AI, engineering and the sciences are naturally formalized as inference in discrete probabilistic models. Exact inference is often prohibitively expensive, as it may require evaluating the (unnormalized) target density on its entire domain. Here we consider the setting where only a limited budget of calls to the unnormalized density oracle is available, raising the challenge of where in the domain to allocate these function calls in order to construct a good approximate solution. We formulate this problem as an instance of sequential decision-making under uncertainty and leverage methods from reinforcement learning for probabilistic inference with budget constraints. In particular, we propose the TreeSample algorithm, an adaptation of Monte Carlo Tree Search to approximate inference. This algorithm caches all previous queries to the density oracle in an explicit search tree, and dynamically allocates new queries based on a "best-first" heuristic for exploration, using existing upper confidence bound methods. Our non-parametric inference method can be effectively combined with neural networks that compile approximate conditionals of the target, which are then used to guide the inference search and enable generalization across multiple target distributions. We show empirically that TreeSample outperforms standard approximate inference methods on synthetic factor graphs.


1 Introduction

Probabilistic (Bayesian) inference formalizes reasoning under uncertainty based on first principles

[cox1946probability, jaynes2003probability], with a wide range of applications in cryptography [turing1941applications], error-correcting codes [mceliece1998turbo], bio-statistics [robinson2010edger], particle physics [baydin2019etalumis], generative modelling [kingma2013auto], causal reasoning [pearl2000causality] and countless others. Inference problems are often easy to formulate, e.g. by multiplying non-negative functions that each reflect independent pieces of information, yielding an unnormalized target density (UTD). However, extracting, i.e. inferring, knowledge from this UTD representation, such as marginal distributions of variables, is notoriously difficult and essentially amounts to solving the SumProd problem [dechter1999bucket]:

$$Z = \sum_{x_1, \dots, x_N} \prod_{k=1}^{M} \psi_k(x_{S_k}),$$

where the UTD here is given by $\hat p(x) = \prod_{k=1}^{M} \psi_k(x_{S_k})$. For discrete distributions, inference is #P-complete [roth1996hardness], and thus at least as hard as (and suspected to be much harder than) NP-complete problems [stockmeyer1985approximation].

The hardness of exact inference, which often prevents its application in practice, has led to the development of numerous approximate methods, such as Markov Chain Monte Carlo (MCMC)

[hastings1970monte], Sequential Monte Carlo (SMC) methods [del2006sequential] and Variational Inference (VI) [jordan1999introduction]

. Whereas exact inference methods essentially need to evaluate and sum the UTD over its entire domain in the worst case, approximate methods attempt to reduce computation by concentrating evaluations of the UTD on regions of the domain that contribute most to the probability mass. The exact locations of high-probability regions are, however, often unknown a-priori, and different approaches use a variety of means to identify them efficiently. In continuous domains, Hamiltonian Monte Carlo and Langevin sampling, for instance, guide a set of particles towards high density regions by using gradients of the target density

[neal2011mcmc, roberts2002langevin]. In addition to a-priori knowledge about the target density (such as a gradient oracle), adaptive approximation methods use the outcome of previous evaluations of the UTD to dynamically allocate subsequent evaluations to promising parts of the domain [mansinghka2009exact, andrieu2008tutorial]. This can be formalized as an instance of decision-making under uncertainty, where acting corresponds to evaluating the UTD and the goal is to discover probability mass in the domain [lu2018exploration]. From this viewpoint, approximate inference methods attempt to explore the target domain based on a-priori information about the target density as well as on partial feedback from previous evaluations of the UTD.

In this work, we propose a new approximate inference method for discrete distributions, termed TreeSample, that is motivated by the correspondence between probabilistic inference and decision-making highlighted previously in the literature, e.g. [dayan1997using, rawlik2013stochastic, weber2015reinforced, wingate2013automated, schulman2015gradient, weber2019credit]. TreeSample approximates a joint distribution over multiple discrete variables by the following sequential decision-making approach: Variables are inferred / sampled one at a time, each conditioned on all previous ones in an arbitrary, pre-specified ordering. An explicit tree-structured cache of all previous UTD evaluations is maintained, and a heuristic inspired by Upper Confidence Bounds on Trees (UCT) [kocsis2006bandit] is applied to trade off exploitation of configurations that were previously found to yield high values of the UTD against exploration of regions that have not yet been visited. Algorithmically, TreeSample amounts to a variant of Monte Carlo Tree Search (MCTS) [browne2012survey], modified so that it performs integration rather than optimization. In contrast to other approximate methods, it leverages systematic, backtracking tree search with a "best-first" exploration heuristic.

Inspired by prior work on combining MCTS with function approximation [silver2016mastering], we proceed to augment TreeSample with neural networks that parametrically cache previously computed approximate solutions of inference sub-problems. These networks represent approximate conditional densities and correspond to state-action value functions in decision-making and reinforcement learning. This caching mechanism (under suitable assumptions) allows search knowledge to be generalized across branches of the search tree for a given target density, as well as across inference problems for different target distributions. In particular, we show experimentally that suitably structured neural networks such as Graph Neural Networks [battaglia2018relational] can efficiently guide the search even on new problem instances, thereby massively reducing the effective search space.

The paper is structured as follows. In sec. 2 we introduce notation and set up the basic inference problem. In sec. 3, this inference problem is cast into the language of sequential decision-making and the TreeSample algorithm is proposed. In sec. 4 we show empirically that TreeSample outperforms closely related standard approximate inference algorithms. We conclude with a discussion of related work in sec. 5.

2 Discrete Inference with Computational Budget Constraints

2.1 Notation

Let $X = (X_1, \dots, X_N)$ be a discrete random vector taking values $x = (x_1, \dots, x_N)$ in $\mathcal{X} = \{1, \dots, K\}^N$, let $x_{<n} = (x_1, \dots, x_{n-1})$ be its $n$-prefix and define $x_{\leq n}$ analogously. We assume the distribution is given by a factor graph. Denote with $p$ its density (probability mass function) and with $\hat p$ the corresponding unnormalized density:

$$p(x) = \frac{1}{Z}\, \hat p(x), \qquad \hat p(x) = \exp\Big( \sum_{k=1}^{M} f_k(x_{S_k}) \Big), \tag{1}$$

where $Z = \sum_{x \in \mathcal{X}} \hat p(x)$ is the normalization constant. We assume that all factors $f_k$, defined in the log-domain, take values in $\mathbb{R} \cup \{-\infty\}$. Furthermore, the scopes $S_k$ for all $k = 1, \dots, M$ are assumed known, where $S_k \subseteq \{1, \dots, N\}$ is the index set of the variables that $f_k$ takes as input. We denote the densities of the conditionals as $p(x_n \mid x_{<n})$.
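To make the setup concrete, the following is a minimal Python sketch (not from the paper) of a log-domain factor graph: factors with scopes, an oracle for $\log \hat p(x)$, and a brute-force normalizer for tiny problems. All class and function names are illustrative.

```python
import itertools
import math
from typing import Callable, List, Sequence, Tuple


class FactorGraph:
    """Unnormalized target density: hat-p(x) = exp(sum_k f_k(x_{S_k}))."""

    def __init__(self, num_vars: int, num_states: int,
                 factors: List[Tuple[Tuple[int, ...], Callable[..., float]]]):
        self.N = num_vars        # number of variables
        self.K = num_states      # states per variable, x_n in {0, ..., K-1}
        self.factors = factors   # list of (scope S_k, log-domain factor f_k)

    def log_phat(self, x: Sequence[int]) -> float:
        """Evaluate log hat-p(x); each factor call counts towards the budget B."""
        return sum(f(*(x[i] for i in scope)) for scope, f in self.factors)

    def log_Z(self) -> float:
        """Exact normalizer by brute force -- exponential in N, tiny problems only."""
        logs = [self.log_phat(x) for x in itertools.product(range(self.K), repeat=self.N)]
        m = max(logs)
        return m + math.log(sum(math.exp(v - m) for v in logs))


# A tiny 3-variable chain with one unary and two pairwise log-factors.
fg = FactorGraph(
    num_vars=3, num_states=2,
    factors=[((0,), lambda a: 0.5 * a),
             ((0, 1), lambda a, b: 1.0 if a == b else -1.0),
             ((1, 2), lambda b, c: 1.0 if b == c else -1.0)])
print(fg.log_phat((1, 1, 1)), fg.log_Z())
```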

2.2 Problem Setting and Motivation

Consider the problem of constructing a tractable approximation $q$ to $p$. In this context, we define tractable as being able to sample from $q$ (say in polynomial time in $N$). Such a $q$ then allows Monte Carlo estimates of $\mathbb{E}_p[g(X)] \approx \mathbb{E}_q[g(X)]$ for any function $g$ of interest in downstream tasks without having to touch the original $p$ again. This setup is an example of model compilation [darwiche2002logical]. We assume that the computational cost of inference in $p$ is dominated by evaluating the factors $f_k$. Therefore, we are interested in compiling a good approximation $q$ using a fixed computational budget $B$ of factor evaluations: A brute-force approach would exhaustively compute all conditionals $p(x_n \mid x_{<n})$ up to $n = N$ and resort to ancestral sampling. This entails explicitly evaluating the factors everywhere, likely including "wasteful" evaluations in regions of $\mathcal{X}$ with low density $p$, i.e. parts of $\mathcal{X}$ that do not significantly contribute to $Z$. Instead, it may be more efficient to construct an approximation that concentrates the computational budget on those parts of the domain where the density $p$, or equivalently $\hat p$, is suspected to be high. For small budgets $B$, determining the points at which to probe $\hat p$ should ideally be done sequentially: Having evaluated $\hat p$ on values $x^1, \dots, x^i$ with $i < B$, the choice of $x^{i+1}$ should be informed by the previous results $\hat p(x^1), \dots, \hat p(x^i)$. If e.g. the target density is assumed to be "smooth", a point "close" to points with large $\hat p$ might also have a high value under the target, making it a good candidate for future exploration (under appropriate definitions of "smooth" and "close"). In this view, inference presents itself as a structured exploration problem of the form studied in the literature on sequential decision-making under uncertainty and reinforcement learning, in which we decide where to evaluate $\hat p$ next in order to reduce uncertainty about its values. As presented in detail in the following, borrowing from the RL literature, we will use a form of tree search that preferentially explores points that share a common prefix with previously found points of high $\hat p$.

3 Approximate Inference with Monte Carlo Tree Search

In the following, we cast sampling from $p$ as a sequential decision-making problem in a suitable maximum-entropy Markov Decision Process (MDP). We show that the target distribution $p$ is equal to the solution, i.e. the optimal policy, of this MDP. This representation of $p$ as an optimal policy allows us to leverage standard methods from RL for approximating $p$. Our definition of the MDP will capture the following intuitive procedure: At each step $n$ we decide how to sample $x_n$ based on the realization of $x_{<n}$ that has already been sampled. The reward function of the MDP will be defined such that the return (sum of rewards) of an episode equals the log of the unnormalized target density $\log \hat p(x)$, therefore "rewarding" samples that have high probability under the target.

3.1 Sequential Decision-Making Representation

We first fix an arbitrary ordering of the variables $X_1, \dots, X_N$; for now any ordering will do, but see the discussion in sec. 6. We then construct an episodic, maximum-entropy MDP $\mathcal{M}$ consisting of episodes of length $N$. The state space at time step $n$ is the set of prefixes $x_{<n} \in \{1, \dots, K\}^{n-1}$ and the action space is $\{1, \dots, K\}$ for all $n$. State transitions are deterministic: Executing action $a$ in state $x_{<n}$ at step $n$ results in setting $x_n$ to $a$, or equivalently the action is appended to the current state, i.e. $x_{\leq n} = (x_{<n}, a)$. A stochastic policy $\pi$ in this MDP is defined by probability densities $\pi_n(x_n \mid x_{<n})$ over actions conditioned on states, for $n = 1, \dots, N$. It induces a joint distribution over $\mathcal{X}$ with the density $q_\pi(x) = \prod_{n=1}^{N} \pi_n(x_n \mid x_{<n})$. Therefore, the space of stochastic policies is equivalent to the space of distributions over $\mathcal{X}$.

We define the reward function of $\mathcal{M}$ based on the scopes $S_k$ of the factors as follows:

Definition 1 (Reward).

For $n = 1, \dots, N$, we define the reward function $r_n(x_{\leq n})$ as the sum over factors that can be computed from $x_{\leq n}$, but not already from $x_{<n}$, i.e. over $k \in \mathcal{K}_n$:

$$r_n(x_{\leq n}) = \sum_{k \in \mathcal{K}_n} f_k(x_{S_k}), \tag{2}$$

where $\mathcal{K}_n = \{ k : S_k \subseteq \{1, \dots, n\} \text{ and } n \in S_k \}$. We further define the maximum-entropy reward of a policy $\pi$:

$$r^{\mathrm{ent}}_n(x_{\leq n}) = r_n(x_{\leq n}) - \log \pi_n(x_n \mid x_{<n}). \tag{3}$$

To illustrate this definition, assume $f_k$ is only a function of $x_1$; then it will contribute to $r_1$. If, however, it has full scope $S_k = \{1, \dots, N\}$, then it will contribute to $r_N$. Evaluating $r_n$ at any input incurs a cost of $|\mathcal{K}_n|$ factor evaluations towards the budget $B$. This completes the definition of $\mathcal{M}$. From the reward definition it follows that we can write the logarithm of the unnormalized target density as the return, i.e. the sum of rewards (without entropy terms):

$$\log \hat p(x) = \sum_{n=1}^{N} r_n(x_{\leq n}). \tag{4}$$
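As a sketch of Definition 1 (names and 0-based indexing are mine, not the paper's), the step-$n$ reward collects exactly those factors whose scope becomes fully instantiated when $x_n$ is assigned, so the per-episode return telescopes to $\log \hat p(x)$; this reuses the FactorGraph sketch above.

```python
def reward(fg: FactorGraph, x_prefix: Sequence[int]) -> float:
    """r_n(x_{<=n}): sum of log-factors newly computable once x_n is assigned."""
    n = len(x_prefix) - 1  # 0-based index of the most recently assigned variable
    return sum(f(*(x_prefix[i] for i in scope))
               for scope, f in fg.factors
               if max(scope) == n)  # computable from x_{<=n} but not from x_{<n}


x = (1, 1, 1)
episode_return = sum(reward(fg, x[:n + 1]) for n in range(fg.N))
assert abs(episode_return - fg.log_phat(x)) < 1e-9  # return equals log hat-p(x), cf. eqn. 4
```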

We now establish that the MDP $\mathcal{M}$ is equivalent to the initial inference problem, using the standard definition of the value of a policy $\pi$ as its expected (entropy-regularized) return conditioned on a state, i.e. $V^{\pi}(x_{<n}) = \mathbb{E}\big[ \sum_{m=n}^{N} r^{\mathrm{ent}}_m(x_{\leq m}) \big]$, where the expectation is taken over $x_{n:N} \sim \prod_{m=n}^{N} \pi_m(\cdot \mid x_{<m})$. The following straightforward observation holds:

Observation 1 (Equivalence of inference and max-ent MDP).

The value of the initial state $\emptyset$ under a policy $\pi$ in the maximum-entropy MDP is given by the negative KL-divergence between $q_\pi$ and the target $p$, up to the normalization constant $\log Z$:

$$V^{\pi}(\emptyset) = \log Z - \mathrm{KL}\big( q_\pi \,\|\, p \big). \tag{5}$$

The optimal policy $\pi^*$ is equal to the target conditionals: $\pi^*_n(x_n \mid x_{<n}) = p(x_n \mid x_{<n})$.

Therefore, solving the maximum-entropy MDP is equivalent to finding all target conditionals $p(x_n \mid x_{<n})$, and running the optimal policy yields samples from $p$. In order to convert the above MDP into a representation that facilitates finding a solution, we use the standard definition of the state-action values as $Q^{\pi}(x_{<n}, x_n) = r_n(x_{\leq n}) + V^{\pi}(x_{\leq n})$. This definition, together with observation 1, directly results in (see appendix for proof):

Observation 2 (Target conditionals as optimal state-action values).

The target conditional $p(x_n \mid x_{<n})$ is proportional to the exponentiated optimal state-action value function, i.e. $p(x_n \mid x_{<n}) \propto \exp Q^*(x_{<n}, x_n)$, where the log-normalizer is given by the value $V^*(x_{<n}) = \log \sum_{a} \exp Q^*(x_{<n}, a)$. Furthermore, the optimal state-action values obey the soft Bellman equation:

$$Q^*(x_{<n}, x_n) = r_n(x_{\leq n}) + \log \sum_{a=1}^{K} \exp Q^*(x_{\leq n}, a) \quad \text{for } n < N, \qquad Q^*(x_{<N}, x_N) = r_N(x_{\leq N}). \tag{6}$$
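Eqn. 6 is a logsumexp backup: a node's value is the logsumexp of its children's $Q$-values, and the conditional in observation 2 is the corresponding softmax. A small numpy illustration (the numerical values are arbitrary):

```python
import numpy as np


def logsumexp(v: np.ndarray) -> float:
    m = float(np.max(v))
    return m + float(np.log(np.sum(np.exp(v - m))))


q_children = np.array([-1.2, 0.3, -0.5])   # Q*(x_{<=n}, a) for the K children of x_{<=n}
r_n = -0.7                                 # cached reward r_n(x_{<=n}) at this node

q_parent_entry = r_n + logsumexp(q_children)               # soft-Bellman backup, eqn. 6
conditional = np.exp(q_children - logsumexp(q_children))   # p(x_{n+1} | x_{<=n})
```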

3.2 TreeSample Algorithm

In principle, the soft-Bellman equation 6 can be solved by backwards dynamic programming in the following way. We can represent the problem as a $K$-ary tree over nodes corresponding to all partial configurations $x_{\leq n}$, with root $\emptyset$ and each node $x_{\leq n}$ being the parent of the children $(x_{\leq n}, 1)$ to $(x_{\leq n}, K)$. One can compute all $Q^*$-values by starting from the leaves, for which the state-action values $Q^*(x_{<N}, x_N) = r_N(x_{\leq N})$ can be computed directly, and solving eqn. 6 in reverse order. Furthermore, a simple softmax operation on each $Q^*(x_{<n}, \cdot)$ yields the target conditional $p(x_n \mid x_{<n})$. Unfortunately, this requires exhaustive evaluation of all factors.
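For very small problems this recursion can be run exactly; the sketch below (reusing the FactorGraph, reward and logsumexp helpers from the sketches above) returns $\log Z$ at the root, but it evaluates every factor over its entire domain, which is exactly the cost TreeSample tries to avoid.

```python
def exact_value(fg: FactorGraph, x_prefix: Tuple[int, ...] = ()) -> float:
    """V*(x_{<n}) via exhaustive backwards DP; at the root this equals log Z."""
    n = len(x_prefix)
    if n == fg.N:
        return 0.0
    # Q*(x_{<n}, a) = r_n((x_{<n}, a)) + V*((x_{<n}, a)) for each of the K children.
    q = np.array([reward(fg, x_prefix + (a,)) + exact_value(fg, x_prefix + (a,))
                  for a in range(fg.K)])
    return logsumexp(q)


print(exact_value(fg))   # matches fg.log_Z() up to floating point error
```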

1: procedure Sample(tree $\mathcal{T}$, default state-action values $\tilde Q_\theta$)
2:     $x \leftarrow \emptyset$
3:     for $n = 1, \dots, N$ do
4:         if $x \in \mathcal{T}$ then
5:             $x_n \sim \mathrm{softmax}\big( Q(x, \cdot) \big)$
6:         else
7:             $x_n \sim \mathrm{softmax}\big( \tilde Q_\theta(x, \cdot) \big)$
8:         end if
9:         $x \leftarrow (x, x_n)$
10:     end for
11:     return $x$
12: end procedure
Algorithm 1 TreeSample sampling procedure
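A minimal Python rendering of Algorithm 1 (a sketch, not the authors' code), assuming the tree is a dict mapping prefixes to $Q$-vectors and that a default value function covers prefixes outside the tree:

```python
from typing import Callable, Dict, Tuple
import numpy as np


def sample_from_tree(tree: Dict[Tuple[int, ...], np.ndarray],
                     default_q: Callable[[Tuple[int, ...]], np.ndarray],
                     num_vars: int,
                     rng: np.random.Generator) -> Tuple[int, ...]:
    """Algorithm 1: traverse from the root, sampling each x_n from softmax(Q(x_{<n}, .))."""
    x: Tuple[int, ...] = ()
    for _ in range(num_vars):
        q = tree[x] if x in tree else default_q(x)   # fall back to the prior values
        p = np.exp(q - q.max())
        p /= p.sum()                                 # softmax over the K children
        x = x + (int(rng.choice(len(q), p=p)),)
    return x
```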

As an alternative to exhaustive evaluation, we propose the TreeSample algorithm for approximate inference. The main idea is to construct an approximation $q$ consisting of a partial tree $\mathcal{T}$ and approximate state-action values $Q$ with support on $\mathcal{T}$. A node in $\mathcal{T}$ at depth $n$ corresponds to a prefix $x_{\leq n}$, with an attached vector of state-action values $Q(x_{\leq n}, a)$ for its children $(x_{\leq n}, 1)$ to $(x_{\leq n}, K)$ (which might not be in the tree themselves). Sampling from $q$ is defined in algorithm 1: The tree is traversed from the root, and at each node a child is sampled from the softmax distribution defined by the node's $Q$-values. If at any point a node is reached that is not in $\mathcal{T}$, the algorithm falls back to a distribution defined by a user-specified, default state-action value function $\tilde Q_\theta$; we will also refer to $\tilde Q_\theta$ as the prior state-action value function, as it assigns state-action values before / without any evaluation of the reward. Later, we will discuss using learned, parametric functions for $\tilde Q_\theta$. In the following we describe how the partial tree $\mathcal{T}$ is constructed using a given, limited budget $B$ of evaluations of the factors $f_k$.

3.2.1 Tree Construction with Soft-Bellman MCTS

TreeSample leverages the correspondence between approximate inference and decision-making discussed above. It consists of an MCTS-like algorithm that iteratively constructs the tree underlying the approximation $q$. Given a partially built tree $\mathcal{T}$, the tree is expanded (if budget is still available) using a heuristic inspired by Upper Confidence Bound (UCB) methods [auer2002finite]. It aims to expand the tree at branches expected to make large contributions to the probability mass, taking into account how important a branch currently seems, given by its current $Q$-value estimates, as well as a measure of uncertainty of these estimates. The latter is approximated by a computationally cheap heuristic based on the visit counts of the branch, i.e. how many reward evaluations have been made in this branch. The procedure prefers to explore branches with high $Q$-values and high uncertainty (low visit counts); it is given in full in algorithm 2 in the appendix and briefly summarized here.

Each node $x_{\leq n}$ in $\mathcal{T}$, in addition to $Q(x_{\leq n}, \cdot)$, also keeps track of its visit counts $N(x_{\leq n}, \cdot)$ and the cached reward evaluation $r_n(x_{\leq n})$. For a single tree expansion, $\mathcal{T}$ is traversed from the root by choosing at each intermediate node $x_{<n}$ the next action in the following way:

$$x_n = \arg\max_{a} \; Q(x_{<n}, a) \;+\; c \cdot \frac{\exp\big( \tilde Q_\theta(x_{<n}, a) / \tau \big)}{\sum_{a'} \exp\big( \tilde Q_\theta(x_{<n}, a') / \tau \big)} \cdot \frac{\sqrt{\sum_{a'} N(x_{<n}, a')}}{1 + N(x_{<n}, a)}. \tag{7}$$

Here, the hyperparameters $c$ and $\tau$ determine the influence of the second term, which can be seen as a form of exploration bonus and which is computed from the inverse visit count of the action relative to the visit counts of the parent. This rule is inspired by the PUCT variant employed in [silver2016mastering], but uses the default value function $\tilde Q_\theta$ for the exploration bonus. When a new node at depth $n$ is reached, the reward function $r_n$ is evaluated, decreasing the remaining budget. The result is cached and the node is added to $\mathcal{T}$, using $\tilde Q_\theta$ to initialize its $Q$-values. Then the $Q$-values are updated: On the path of the tree traversal that led to the new node, the values are backed up in reverse order using the soft-Bellman equation (eqn. 6). This constitutes the main difference to standard MCTS methods, which employ max- or averaging backups, and reflects the difference between sampling / integration and the usual application of MCTS to maximization / optimization problems. Once the entire budget $B$ is spent, the approximation $q$ with its tree-structured $Q$-values is returned.
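The per-traversal logic can be sketched as follows (illustrative Python, not the authors' implementation): descend with the UCT-like rule, expand one new node by evaluating its reward and initializing its $Q$-vector from the prior, then back up soft-Bellman values along the visited path. The bonus term follows the reconstruction of eqn. 7 above and is therefore an assumption, as are the helper names; budget accounting and the completeness bookkeeping of the appendix are omitted. The tree is assumed to be initialized with its root, i.e. tree[()] = prior_q(()), counts[()] = zeros, rewards[()] = 0.

```python
import numpy as np


def logsumexp(v):
    m = np.max(v)
    return m + np.log(np.sum(np.exp(v - m)))


def tree_traversal(tree, counts, rewards, reward_fn, prior_q, num_vars, c=1.0, tau=1.0):
    """One TreeSample expansion; `tree` maps prefixes to Q-vectors, `counts` to visit counts."""
    # Selection: descend while the current prefix is already in the tree.
    path, x = [], ()
    while x in tree and len(x) < num_vars:
        q, n = tree[x], counts[x]
        prior = np.exp(prior_q(x) / tau)
        prior = prior / prior.sum()
        bonus = c * prior * np.sqrt(n.sum() + 1e-8) / (1.0 + n)   # cf. eqn. 7
        a = int(np.argmax(q + bonus))
        path.append((x, a))
        x = x + (a,)
    # Expansion: evaluate the reward once, cache it, initialize Q from the prior.
    if x not in tree:
        rewards[x] = reward_fn(x)
        tree[x] = prior_q(x) if len(x) < num_vars else np.zeros(1)   # dummy vector at leaves
        counts[x] = np.zeros_like(tree[x])
    # Backup: soft-Bellman updates along the visited path, in reverse order.
    for prefix, a in reversed(path):
        child = prefix + (a,)
        v_child = logsumexp(tree[child]) if len(child) < num_vars else 0.0
        tree[prefix][a] = rewards[child] + v_child
        counts[prefix][a] += 1
```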

3.2.2 Consistency

As argued above, the exact conditionals can be computed by exhaustive search in exponential time. Therefore, a reasonable desideratum for any inference algorithm is that, given a large enough budget, the exact distribution is inferred. In the following we show that TreeSample passes this basic sanity check. The first important property of TreeSample is that a tree yields the exact conditionals below a node $x_{\leq n}$ if the unnormalized target density has been evaluated on all states with prefix $x_{\leq n}$ during tree construction. To make this statement precise, we define $\mathcal{T}_{x_{\leq n}}$ as the sub-tree of $\mathcal{T}$ consisting of the node $x_{\leq n}$ and all its descendants in $\mathcal{T}$. We call a sub-tree fully expanded, or complete, if all partial states with prefix $x_{\leq n}$ are in $\mathcal{T}$. With this definition, we have the following lemma (proof in the appendix):

Lemma 1.

Let $\mathcal{T}_{x_{\leq n}}$ be a fully expanded sub-tree of $\mathcal{T}$. Then, for all nodes in $\mathcal{T}_{x_{\leq n}}$, the state-action values are exact, i.e. $Q = Q^*$, and in particular the node $x_{\leq n}$ has the correct value $V(x_{\leq n}) = V^*(x_{\leq n})$.

Furthermore, constructing the full tree with TreeSample incurs a cost of at most $K^N \cdot M$ evaluations of the factors $f_k$, as there are $K^N$ leaf nodes in the full tree and constructing the path from the root to each leaf requires at most $M$ oracle evaluations. Therefore, TreeSample with $B \geq K^N \cdot M$ expands the entire tree and the following result holds:

Corollary 1 (Exhaustive budget consistency).

TreeSample outputs the correct target distribution for budgets $B \geq K^N \cdot M$.

3.3 Augmenting TreeSample with Learned Parametric Priors

TreeSample explicitly allows for a "prior" $\tilde Q_\theta$ over state-action values with parameters $\theta$. It functions as a parametric approximation to $Q^*$. In principle, an appropriate $\tilde Q_\theta$ can guide the search towards regions in $\mathcal{X}$ where probability mass is likely to be found a-priori, by the following two mechanisms: It scales the exploration bonus in the PUCT-like decision rule eqn. 7, and it is used to initialize the state-action values for a newly expanded node in the search tree. In the following we discuss scenarios and potential benefits of learning the parameters $\theta$.

In principle, if $\tilde Q_\theta$ comes from an appropriate function class, it can transfer knowledge within the inference problem at hand. Assume we spent some of the available search budget on TreeSample to build an approximation $q$. Due to the tree structure, search budget spent in one branch of the tree does not benefit any other sibling branch. For many problems there is, however, structure that would allow for generalizing knowledge across branches. This can be achieved via $\tilde Q_\theta$: e.g. one could train $\tilde Q_\theta$ to approximate the $Q$-values of the current tree, and (under the right inductive bias) knowledge would transfer to newly expanded branches. A similar argument can be made for parametric generalization across problem instances. Assume a given family of distributions $\{p^{(i)}\}_{i \in \mathcal{I}}$ for some index set $\mathcal{I}$. If the different distributions share structure, it is possible to leverage search computations performed on $p^{(i)}$ for inference in $p^{(j)}$ to some degree. A natural example for this is posterior inference in the same underlying model conditioned on different evidence / observations, similar e.g. to amortized inference in variational auto-encoders [kingma2013auto]. Besides transfer, there is a purely computational reason for learning a parametric $\tilde Q_\theta$: The memory footprint of TreeSample grows linearly with the search budget $B$. For large problems with large budgets, storing the entire search tree in memory might not be feasible. In this case, periodically compiling the current tree into $\tilde Q_\theta$, and rebuilding it from scratch under this prior with subsequent refinement by TreeSample, may be preferable.

Concretely, we propose to train $\tilde Q_\theta$ by regression on state-action values generated by TreeSample. For generalization across branches, $\tilde Q_\theta$ directly approximates the distribution of interest; for transfer across distributions, $\tilde Q_\theta$ approximates the source distribution, and we apply the trained $\tilde Q_\theta$ for inference search in a different target distribution. We match $\tilde Q_\theta$ to $Q$ by minimizing the expected squared difference of the values:

$$\mathcal{L}(\theta) = \mathbb{E}_{x \sim q}\Big[ \sum_{n=1}^{N} \big( \tilde Q_\theta(x_{<n}, x_n) - Q(x_{<n}, x_n) \big)^2 \Big].$$

In practice we optimize this loss by stochastic gradient descent in a distributed learner-worker architecture detailed in the experimental section.
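A sketch of the distillation step under the squared-error loss written above (the paper states only that $\tilde Q_\theta$ is matched to $Q$ by regression; the linear model and all names here are assumptions, kept dependency-free with numpy, whereas the paper uses MLPs and graph nets):

```python
import numpy as np


def sgd_step(theta: np.ndarray, features: np.ndarray, q_targets: np.ndarray,
             lr: float = 1e-2) -> np.ndarray:
    """One SGD step on the mean squared loss between features @ theta and q_targets."""
    pred = features @ theta                                   # Q_tilde predictions
    grad = 2.0 * features.T @ (pred - q_targets) / len(q_targets)
    return theta - lr * grad


rng = np.random.default_rng(0)
phi = rng.normal(size=(128, 16))       # minibatch of (x_{<n}, x_n) feature encodings
q_from_tree = rng.normal(size=128)     # Q-values read off the search tree (replay buffer)
theta = sgd_step(np.zeros(16), phi, q_from_tree)
```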

4 Experiments

Figure 1: Comparison of TreeSample to SMC on inference in 1000 randomly generated Markov chains. Left: Approximation error as a function of the budget $B$ in log-scale, showing that SMC needs more than 30 times the budget of TreeSample to generate comparable approximations. Right: Energy and entropy contributions to the KL divergence for all 1000 experiments at a fixed budget, showing that TreeSample finds approximations with both higher entropy and lower energy.

In the following, we empirically compare TreeSample to other baseline inference methods on different families of distributions. We quantify approximation error by the Kullback-Leibler divergence:

$$\mathrm{KL}\big( q \,\|\, p \big) = \log Z - \mathbb{E}_{q}\big[ \log \hat p(x) \big] - H[q], \tag{8}$$

where we refer to $\mathbb{E}_{q}[\log \hat p(x)]$, which enters eqn. 8 with a minus sign, as the negative expected energy, and the last term $H[q]$ is the entropy of the approximation. We can get unbiased estimates of both using samples from $q$. For intractable target distributions, we compare different inference methods using the unnormalized divergence $\mathrm{KL}(q \,\|\, p) - \log Z$, which is tractable to approximate and preserves the ranking of different approximation methods.
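Whenever $\log q(x)$ is available (it is for the tree-structured $q$, and for the weighted-atom approximations below), both the unnormalized KL and its energy / entropy decomposition can be estimated by plain Monte Carlo; a sketch with illustrative names:

```python
import numpy as np


def unnormalized_kl(samples, log_q, log_phat) -> float:
    """Estimate KL(q || p) - log Z = E_q[log q(x) - log phat(x)] from samples of q."""
    return float(np.mean([log_q(x) - log_phat(x) for x in samples]))


def energy_and_entropy(samples, log_q, log_phat):
    """Negative expected energy E_q[log phat(x)] and entropy estimate -E_q[log q(x)]."""
    neg_energy = float(np.mean([log_phat(x) for x in samples]))
    entropy = -float(np.mean([log_q(x) for x in samples]))
    return neg_energy, entropy
```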

As baselines we consider the following: Sequential Importance Sampling (SIS), Sequential Monte Carlo (SMC), and for a subset of the environments also Gibbs sampling (GIBBS) and sampling with loopy belief propagation (BP); details are given in the appendix. We use the baseline methods in the following way: We generate a set of particles $\{x^i\}_{i=1}^{P}$ of a size $P$ such that we exhaust the budget $B$, and then return the weighted sum of atoms $q(x) = \sum_{i=1}^{P} w_i \, \delta_{x^i}(x)$ as the approximation density; here $\delta_{x^i}$ is the Kronecker delta, and the weights $w_i$ are either set to $1/P$ for GIBBS and BP, or to the self-normalized importance weights for SIS and SMC. Hyperparameters for all methods were tuned individually for different families of distributions on an initial set of experiments and then kept constant across all reported experiments. For further details, see the appendix. For SIS and SMC, the proposal distribution plays a role comparable to the state-action prior $\tilde Q_\theta$ in TreeSample. Therefore, for all experiments we used the same parametric family for the proposal / prior for TreeSample, SIS and SMC.

For the sake of simplicity, in the experiments we measured and constrained the inference budget in terms of reward evaluations, i.e. each pointwise evaluation of a reward $r_n$ incurs a cost of one, instead of counting individual factor evaluations.

4.1 TreeSample without Parametric Value Function

We first investigated inference without a learned parametric $\tilde Q_\theta$. Instead, we used the simple heuristic of setting $\tilde Q_\theta$ to the state-action values obtained when all factors vanish everywhere, i.e. a uniform prior.

4.1.1 Chain Distributions

We initially tested the algorithms on inference in chain-structured factor graphs (CHAINS). These allow for exact inference in linear time, and therefore we can get unbiased estimates of the true Kullback-Leibler divergences. We report results averaged over 1000 different chains of length $N$ with randomly generated unary and binary potential functions; for details, see the appendix. The number of states per variable was set to a fixed value $K$, yielding $K^N$ states in total. The results, shown in fig. 1 as a function of the inference budget $B$, show that TreeSample outperforms the SMC baseline (see also tab. 1). In particular, TreeSample generates approximations of similar quality compared to SMC with a roughly 30 times smaller budget. We further investigated the energy and entropy contributions to the KL divergence separately. We define the energy as $-\mathbb{E}_{q}[\log \hat p(x)]$ (lower is better), and the entropy as $H[q]$ (higher is better). Fig. 1 shows that TreeSample finds approximations that have lower energy as well as higher entropy compared to SMC.

              CHAINS           PERMUTEDCHAINS    FACTORGRAPHS1     FACTORGRAPHS2
SIS           11.61 ± 1.74      9.23 ± 0.34      -21.97 ± 2.47     -31.70 ± 2.32
SMC            1.94 ± 0.48      7.08 ± 0.36      -24.09 ± 2.85     -35.90 ± 2.47
GIBBS             --                --           -18.67 ± 1.80     -25.12 ± 1.48
BP              exact             exact          -21.50 ± 0.18     -31.48 ± 0.48
TreeSample     0.53 ± 0.17      3.41 ± 0.41      -28.89 ± 1.94     -38.70 ± 2.29

Table 1: Approximation error (lower is better) for different inference methods on four distribution classes. Results are averages and standard deviations over 1000 randomly generated distributions for each class. The budget $B$ was set to the same value for all methods within each class.

A known limitation of tree search methods is that they tend to under-perform on shallow (here small $N$) decision-making problems with large action spaces (here large $K$). We performed experiments on chain distributions with varying $N$ and $K$ while keeping the state-space size $K^N$ approximately constant. We confirmed that for very shallow, bushy problems with small $N$ and large $K$, SMC outperforms TreeSample, whereas TreeSample dominates SMC in all other problem configurations; see fig. 3 in the appendix.

Next, we considered chain-structured distributions where the indices of the variables do not correspond to the ordering in the chain; we call these PermutedChains. These are in general more difficult to solve as they exhibit "delayed" rewards, i.e. binary chain potentials depend on non-consecutive variables. This can create "dead-end"-like situations that SMC, lacking the ability to backtrack, can easily get stuck in. Indeed, we find that SMC performs only somewhat better on this class of distributions than SIS, whereas TreeSample achieves better results by a wide margin. Results on both families of distributions are shown in tab. 1.

4.1.2 Factor Graphs

We also tested the inference algorithms on two classes of non-chain factor graphs, denoted as FactorGraphs1 and FactorGraphs2. Distributions in FactorGraphs1 are over $N$ variables with $K$ states each. Factors were randomly generated with a maximum degree of 4, and their values were drawn iid; see the appendix for details. Distributions in FactorGraphs2 are over binary variables, i.e. $K = 2$. These distributions are generated by two types of factors: NOT (degree 2) and MAJORITY (maximum degree 4), both defined in the log-domain.

Results are shown in tab. 1. For both families of distributions, TreeSample outperforms all considered baselines by a wide margin. We found that GIBBS generally failed to find low-energy (high-probability) configurations due to slow mixing. BP-based sampling was observed to generate samples with low energy but small entropy, yielding results comparable to SIS.

4.2 TreeSample with Parametric Value Functions

value func.                none     MLP      GNN      GNN      GNN      GNN      GNN
single graph               N/A      Yes      Yes      No       No       No       No
graph size N (training)    N/A      20       20       20       12       16       24
trained by SMC:                     +1.63    -0.19    -0.97    -1.00    -1.17    -0.64
                                    [-1.60,+2.87]  [-2.68,+1.49]  [-2.41,+0.40]  [-1.52,-0.58]  [-1.42,-0.46]  [-0.84,-0.32]
+ SMC:                     +2.72    +2.93    +1.64    +2.56    +2.00    +1.64    +2.10
                           [+1.02,+4.54]  [-1.54,+4.32]  [-1.61,+3.46]  [+0.58,+4.42]  [+1.72,+2.19]  [+1.38,+2.05]  [+1.50,+2.68]
trained by TreeSample:              -3.61    -3.86    -2.05    -2.12    -2.52    -1.83
                                    [-5.73,-0.60]  [-6.03,-0.85]  [-3.58,-0.55]  [-2.23,-1.99]  [-2.63,-2.40]  [-2.13,-1.76]
+ TreeSample:               0.00    -3.63    -3.87    -2.23    -2.22    -2.64    -2.35
                           [-1.47,+1.68]  [-5.72,-0.64]  [-6.05,-0.88]  [-3.75,-0.73]  [-2.30,-2.05]  [-2.79,-2.55]  [-2.46,-2.10]

Table 2: Approximation error for inference in factor graphs with TreeSample and SMC, for different types of value functions and training regimes (median over 20 runs, with 25% / 75% percentiles in brackets). Results are relative to TreeSample without a value function; lower is better. See main text for details.

Next, we investigated the performance of TreeSample, as well as SMC, with additional parametric state-action value functions (used as the proposal for SMC). We focused on inference problems from FactorGraphs2. We implemented the inference algorithm as a distributed architecture consisting of a worker and a learner process, both running simultaneously. The worker requests an inference problem instance and performs inference either with TreeSample or SMC with a small budget $B$ using the current parametric $\tilde Q_\theta$. After building the approximation $q$, 128 independent samples are drawn from it and the inferred $Q$-values for the visited states and actions are written into a replay buffer as training data; following this, the inference episode is terminated, the tree is flushed and a new episode starts. The learner process samples training data from the replay buffer for updating the parametric $\tilde Q_\theta$ with an SGD step on a minibatch of size 128; the updated model parameters are then sent to the worker. We tracked the error of the inference performed on the worker, using the unnormalized divergence $\mathrm{KL}(q \,\|\, p) - \log Z$, as a function of the number of completed inference episodes. We expect this error to decrease as $\tilde Q_\theta$ adapts to the inference problem and therefore becomes better at guiding the search. Separately, we also track the inference performance of using only the value function, without additional search around it; this is a purely parametric approximation to the inference problem, trained on samples from TreeSample and SMC respectively. We observed that both error measures stabilized after roughly 1500 inference episodes for all experiments. Results were then averaged over episodes 2000-4000 and are shown in tab. 2. To facilitate comparison, all results in tab. 2 are reported relative to the error of TreeSample without value functions. In general, experimental results with learned value functions exhibited higher variability, with some outliers. Results in tab. 2 therefore report median results over 20 runs as well as 25% and 75% percentiles.

We first performed a simple set of "sanity-check" experiments on TreeSample with parametric value functions in a non-transfer setting, where the worker repeatedly solves the same inference problem arising from a single factor graph. As value function, we used a simple MLP with 4 layers of 256 hidden units each. As shown in the second column of tab. 2, the approximation error decreases significantly compared to plain TreeSample without value functions. This corroborates that the value function can indeed cache part of the previous search trees and facilitate inference if training and testing factor graphs coincide. Furthermore, we observed that once $\tilde Q_\theta$ is fully trained, the inference error obtained using only $\tilde Q_\theta$ is only marginally worse than using $\tilde Q_\theta$ plus TreeSample search on top of it; see the last two rows of tab. 2 respectively. This indicates that the value function was powerful enough in this experiment to cache almost the entire search computation of TreeSample.

Next, we investigated graph neural networks (GNNs) [battaglia2018relational] as value functions $\tilde Q_\theta$. This function class can make explicit use of the structure of the factor graph instances. Details about the architecture can be found in [battaglia2018relational] and the appendix, but are briefly described in the following. GNNs consist of two types of networks, node blocks and edge blocks (we did not use global networks), that are connected according to the factor graph at hand and executed multiple times, mimicking a message-passing-like procedure. We used three node block networks, one for each type of graph node, i.e. variable nodes (corresponding to a variable $X_n$), NOT factors and MAJORITY factors. We used four edge block networks, namely one for each combination of {incoming, outgoing} × {NOT, MAJORITY}. Empirically, we found that GNNs slightly outperform MLPs in the non-transfer setting, see the third column of tab. 2.

The real advantage of GNNs comes into play in a transfer setting, where the worker performs inference in a new factor graph in each episode. We keep the number of variables fixed ($N = 20$) but vary the number and configuration of factors across problems. GNNs successfully generalize across graphs, see the fourth column of tab. 2. This is due to their ability to exploit the graph topology of a new factor graph instance by connecting their constituent node and edge networks accordingly. Furthermore, the node and edge networks evidently learned generic message-passing computations for variable nodes as well as for NOT / MAJORITY factor nodes. The results show that a suitable $\tilde Q_\theta$ generalizes knowledge across inference problems, leading to lower approximation error on new distributions. Furthermore, we investigated a transfer setting where the worker solves inference problems on factor graphs of sizes 12, 16 or 24, but performance is tested on graphs of size 20; see columns five to seven in tab. 2. Strikingly, we find that the value functions generalize as well across problems of different sizes as they do across problems of the same size. This demonstrates that prior knowledge can successfully guide the search and greatly facilitate inference.

Finally, we investigated the performance of SMC with trained value functions $\tilde Q_\theta$; see the first two rows in tab. 2. Overall, we found that performance was worse than for TreeSample: Value functions trained by SMC gave worse results than those trained by TreeSample, and the overall inference error was higher. Interestingly, we found that once $\tilde Q_\theta$ is fully trained, performing additional SMC on top of it made results worse. Although initially counter-intuitive, this result is sensible in our problem setup. The entropy of an SMC approximation is essentially limited by the number of particles that SMC produces; this number is in turn limited by the budget $B$, which must be spent computing importance weights. Once a parametric $\tilde Q_\theta$ is trained, it does not need to make any further calls to the factors and can therefore exhibit much higher entropy, making its approximation error smaller than that of $\tilde Q_\theta$ combined with SMC.

5 Related Work

TreeSample is based on the connection between probabilistic inference and maximum-entropy decision-making problems established by previous work. This connection has mostly been used to solve RL problems with inference methods, e.g. [dayan1997using, attias2003planning, hoffman2007trans, rawlik2013stochastic]. Closely related to our approach, this relationship has also been used in the reverse direction, i.e. to solve inference problems using tools from RL [mnih2014neural, weber2015reinforced, wingate2013automated, schulman2015gradient, weber2019credit], however without utilizing tree search or emphasizing the importance of exploration for inference. The latter has been recognized in [lu2018exploration], and applied to hierarchical partitioning for inference in continuous spaces, see also [rainforth2018inference]. In contrast to this, we focus on discrete domains with sequential decision-making utilizing MCTS and value functions. Soft-Bellman backups, as used here (also referred to as soft Q-learning), and their connection to entropy-regularized RL have been explored in e.g. [schulman2017equivalence, haarnoja2018soft].

For approximating general probabilistic inference problems, the class of Markov Chain Monte Carlo (MCMC) methods has proven very successful in practice. There, a transition operator is defined such that the target distribution is stationary under this operator. Concretely, MCMC methods operate on a fully specified, approximate sample which is then perturbed iteratively. Transition operators are usually designed specifically for families of distributions in order to leverage problem structure for achieving fast mixing. However, mixing times are difficult to analyze theoretically and hard to monitor in practice [cowles1996markov]. TreeSample circumvents the mixing problem by generating a new sample "from scratch" when returning to the root node and then iteratively stepping through the dimensions of the random vector. Furthermore, TreeSample can make use of powerful neural networks for approximating conditionals of the target, thus caching computations for related inference problems. Although adaptive MCMC methods exist, they usually only consider small sets of adaptive parameters [andrieu2008tutorial]. Recently, MCMC methods have been extended to transition operators generated by neural networks, which are trained either by adversarial training, meta-learning or mixing-time criteria [song2017nice, levy2017generalizing, neklyudov2018metropolis, wang2018meta]. However, these were formulated for continuous domains and rely on differentiability, and thus do not carry over straightforwardly to discrete domains.

Our proposed algorithm is closely related to Sequential Monte Carlo (SMC) methods [del2006sequential], another class of broadly applicable inference algorithms. Often, these methods are applied to generate approximate samples by sequentially sampling the dimensions of a random vector, e.g. in particle filtering for temporal inference problems [doucet2009tutorial]. Usually, these methods do not allow for backtracking, i.e. re-visiting previously discarded partial configurations, although a few variants with backtracking heuristics do exist [klepal2008backtracking, grassberger2004sequential]. This is in contrast to the TreeSample algorithm, which decides at every iteration where to expand the current tree based on a full tree traversal from the root, and therefore allows for backtracking an arbitrary number of steps. Furthermore, we propose to train value functions which approximately marginalize over the "future" (i.e. the variables following the one in question in the ordering), thus taking into account relevant downstream effects. [gu2015neural, kempinska2017adversarial] introduce adaptive NN proposals, i.e. value functions in our formulation, but these are trained to match the "filtering" distribution and thus do not marginalize over the future. In the decision-making formulation, this corresponds to learning proposals based on immediate rewards instead of total returns. Recent work in continuous domains has begun to address this [guarniero2017iterated, heng2017controlled, lawson2018twisted, piche2018probabilistic]; however, these methods do not make use of guided, systematic search.

Recently, distilling inference computations into parametric functions has been extended to discrete distributions based on the framework of variational inference. [mnih2014neural, mnih2016variational] highlight connections to the REINFORCE gradient estimator [williams1992simple] and propose various value-function-like control variates for reducing its variance. Multiple studies propose to utilize continuous relaxations of discrete variables in order to make use of so-called reparametrization gradients for learning inference computations, e.g. [maddison2016concrete].

In addition to the approximate inference methods discussed above, there are numerous algorithms for exact inference in discrete models. One class of methods, called Weighted Model Counting (WMC) algorithms, is based on representing the target probability distribution as Boolean formulas with associated weights, and converts inference into the problem of summing weights over satisfying assignments of the associated SAT problem [chavira2008probabilistic]. In particular, it has been shown that DPLL-style SAT solvers [davis1962machine] can be extended to exactly solve general discrete inference problems [sang2005solving, bacchus2009solving], often outperforming other standard methods such as the junction tree algorithm [lauritzen1988local]. Similar to TreeSample, this DPLL-based approach performs inference by search, i.e. it recursively instantiates variables of the SAT problem. Efficiency is gained by caching solved sub-problems [bacchus2003algorithms] and by heuristics for adaptively choosing the search order of variables [sang2005solving]. We expect that similar techniques could be integrated into the TreeSample algorithm, potentially greatly improving its efficiency. In contrast to WMC methods, TreeSample dynamically chooses the most promising sub-problems to spend compute on via the UCT-like selection rule, which is informed by all previous search tree expansions.

6 Discussion

Structured distributions $p$, such as factor graphs, Bayesian networks etc., allow for a very compact representation of an infinitely large set of beliefs: e.g. $p$ implies beliefs over $\mathbb{E}_p[g(X)]$ for every test function $g$, including marginals, moments etc. This immediately raises the question "What does it mean to 'know' a distribution?" (paraphrased from [diaconis1988bayesian]). Obviously, we need to perform probabilistic inference to "convert the implicit knowledge" of $p$ (given e.g. by its factors) into "explicit knowledge" in terms of the beliefs of interest (quoted from [gershman]). If the dimension $N$ of $X$ is anything but very small, this inference process cannot be assumed to be "automatic"; rather, it ranks among the most complex computational problems known, and large amounts of computational resources have to be used to even approximate the solution. In other challenging computational problems such as optimization, integration or solving ordinary differential equations, it has been argued that the results of computations that have not yet been executed should be treated as unobserved variables, and knowledge about them expressed as beliefs [movckus1975bayesian, diaconis1988bayesian]. For the inference setting considered in this paper, this would imply that we should introduce second-order, or meta-, beliefs over the yet-to-be-computed first-order beliefs implied by $p$. Approximate inference could then proceed analogously to Bayesian optimization: Evaluate the factors of $p$ at points that result in the largest reduction of second-order uncertainty over the beliefs of interest. However, it is unclear how such meta-beliefs can be treated in a tractable way. Instead of such a fully Bayesian numerical treatment involving second-order beliefs, we adopted cheaper Upper Confidence Bound heuristics for quantifying uncertainty.

For the sake of simplicity, we assumed in this paper that the computational cost of inference is dominated by evaluations of the factor oracles. This assumption is well justified, e.g., in applications where some factors represent large-scale scientific simulators [baydin2019etalumis], or in modern deep latent variable models, where a subset of factors is given by deep neural networks that take potentially high-dimensional observations as inputs. If this assumption is violated, i.e. all factors can be evaluated cheaply, the comparison of TreeSample to SMC and other inference methods becomes less favourable for the former. TreeSample incurs an overhead for traversing the search tree before expanding it, attempting to use the information of all previous oracle evaluations. If these are cheap, a less sequential and more parallel approach, such as SMC, might become more competitive.

We expect that TreeSample can be improved and extended in many ways. Currently, the topology of the factor graph is only partially used, for the reward definition and potentially for graph net value functions. One obvious way to better leverage it would be to check whether, after conditioning on a prefix $x_{\leq n}$ corresponding to a search depth $n$, the factor graph decomposes into independent components that can be solved independently. Furthermore, TreeSample uses a fixed ordering of the variables. However, a good variable ordering can potentially make the inference problem much easier. Leveraging existing, or developing new, heuristics for a problem-dependent and dynamic variable ordering could potentially increase the inference efficiency of TreeSample.

References

Appendix A Details for TreeSample  algorithm

We define a search tree $\mathcal{T}$ in the following way. Nodes in $\mathcal{T}$ at depth $n$ are indexed by the (partial) state $x_{\leq n}$, and the root is denoted by $\emptyset$. Each node $x_{\leq n}$ at depth $n$ keeps track of the corresponding reward evaluation $r_n(x_{\leq n})$ and the following quantities for all its children:

  1. visit counts $N(x_{\leq n}, a)$ over the children,

  2. state-action values $Q(x_{\leq n}, a)$,

  3. prior state-action values $\tilde Q_\theta(x_{\leq n}, a)$, and

  4. a boolean vector indicating whether its children are complete (i.e. fully expanded, see below).

Standard MCTS with (P)UCT-style tree traversals applied to the inference problem can in general visit any state-action pair multiple times; this is desirable behaviour in general MDPs with stochastic rewards, where reliable reward estimates require multiple samples. However, the reward in our MDP is deterministic, as defined in eqn. 2, and therefore there is no benefit in re-visiting fully expanded sub-trees. To prevent the TreeSample algorithm from doing so, we explicitly keep track at each node of whether the sub-tree rooted in it is fully expanded; such a node is called complete. Initially no internal node is complete; only leaf nodes at depth $N$ are tagged as complete. In the backup stage of the tree traversal, we tag a visited node as complete if it is a node of depth $N$ (corresponding to a completed sample) or if all its children are complete. We modify the action selection of eqn. 7 such that the argmax is only taken over actions not leading to complete sub-trees. The TreeSample algorithm is given in full in algorithm 2.

1: globals reward function $r$, prior state-action value function $\tilde Q_\theta$
2: procedure TreeSample(budget $B$)
3:     initialize empty tree $\mathcal{T} \leftarrow \{\emptyset\}$
4:     available budget $b \leftarrow B$
5:     while $b > 0$ do
6:         $\mathcal{T}, c \leftarrow$ TreeTraversal($\mathcal{T}$)
7:         $b \leftarrow b - c$
8:     end while
9:     return tree $\mathcal{T}$
10: end procedure
11: procedure TreeTraversal(tree $\mathcal{T}$)
12:     $x \leftarrow \emptyset$     // traversal
13:     $n \leftarrow 0$
14:     while $x \in \mathcal{T}$ and $n < N$ do
15:         $n \leftarrow n + 1$
16:         $x_n \leftarrow$ Q-UCT($Q(x, \cdot)$, $\tilde Q_\theta(x, \cdot)$, $N(x, \cdot)$, complete$(x, \cdot)$)
17:         $x \leftarrow (x, x_n)$
18:     end while
19:     $c \leftarrow 0$     // expansion
20:     if $x \notin \mathcal{T}$ then
21:         $\mathcal{T} \leftarrow \mathcal{T} \,\cup$ Expand($x$)
22:         $c \leftarrow$ used budget     // see def. 1
23:     end if
24:     // backup
25:     for $m = n, \dots, 1$ do
26:         $V(x_{\leq m}) \leftarrow \log \sum_a \exp Q(x_{\leq m}, a)$ (or $0$ if $m = N$)
27:         $Q(x_{<m}, x_m) \leftarrow r_m(x_{\leq m}) + V(x_{\leq m})$
28:         $N(x_{<m}, x_m) \leftarrow N(x_{<m}, x_m) + 1$
29:         tag $x_{\leq m}$ as complete if $m = N$ or all its children are complete
30:     end for
31:     return $\mathcal{T}$, $c$
32: end procedure
33: procedure Q-UCT($Q$, $\tilde Q$, $N$, complete)
34:     return argmax of eqn. 7 over incomplete children
35: end procedure
36: procedure Expand(state $x_{\leq n}$)
37:     evaluate reward function $r_n(x_{\leq n})$
38:     initialize $N(x_{\leq n}, a) \leftarrow 0$ for all $a$
39:     if $n = N$ then     // $x_{\leq n}$ is a leaf
40:         initialize $V(x_{\leq n}) \leftarrow 0$
41:         tag $x_{\leq n}$ as complete
42:     else
43:         evaluate prior $\tilde Q_\theta(x_{\leq n}, \cdot)$
44:         initialize $Q(x_{\leq n}, a) \leftarrow \tilde Q_\theta(x_{\leq n}, a)$ for all $a$
45:         initialize complete$(x_{\leq n}, a) \leftarrow$ false for all $a$
46:     end if
47:     return node with $r_n$, $N$, $Q$, $\tilde Q_\theta$, complete
48: end procedure
Algorithm 2 TreeSample procedures

Appendix B Proofs

B.1 Observation 2

Proof.

This observation has been proven previously in the literature, but we give a short proof here for completeness. We show the statement by determining the optimal policy and value function by backwards dynamic programming (DP). We anchor the DP induction by defining the optimal value function at complete configurations as zero, i.e. $V^*(x_{\leq N}) = 0$. Using the law of iterated expectations, we can decompose the value function in the following way for any $n$:

$$V^{\pi}(x_{<n}) = \mathbb{E}_{x_n \sim \pi_n(\cdot \mid x_{<n})}\Big[ r_n(x_{\leq n}) - \log \pi_n(x_n \mid x_{<n}) + V^{\pi}(x_{\leq n}) \Big].$$

Therefore, assuming by induction that $V^*(x_{\leq n})$ has been computed, we can find the optimal policy and value at step $n$ by solving:

$$\max_{\pi_n(\cdot \mid x_{<n})} \;\; \sum_{x_n} \pi_n(x_n \mid x_{<n}) \Big[ r_n(x_{\leq n}) + V^*(x_{\leq n}) - \log \pi_n(x_n \mid x_{<n}) \Big] \tag{9a}$$

$$\text{subject to} \quad \sum_{x_n} \pi_n(x_n \mid x_{<n}) = 1. \tag{9b}$$

The solution to this optimization problem can be found by the calculus of variations (omitted here) and is given by:

$$\pi^*_n(x_n \mid x_{<n}) = \frac{\exp Q^*(x_{<n}, x_n)}{\sum_{a} \exp Q^*(x_{<n}, a)},$$

where we used the definition of the optimal state-action value function $Q^*(x_{<n}, x_n) = r_n(x_{\leq n}) + V^*(x_{\leq n})$. Furthermore, at the optimum, the objective eqn. 9a assumes the value:

$$V^*(x_{<n}) = \log \sum_{a} \exp Q^*(x_{<n}, a).$$

This expression, together with the definition of $Q^*$, establishes the soft-Bellman equation. The optimal value $V^*(x_{<n})$ is also exactly the log-normalizer for $\pi^*_n$. Therefore, we can write:

$$p(x_n \mid x_{<n}) = \pi^*_n(x_n \mid x_{<n}) = \exp\big( Q^*(x_{<n}, x_n) - V^*(x_{<n}) \big). \qquad \square$$

B.2 Proof of Lemma 1

Proof.

We show this statement by induction on the depth of the sub-tree with root $x_{\leq n}$. For sub-trees of depth zero, i.e. leaf nodes $x_{\leq N}$, the state-action values are defined such that $Q(x_{<N}, x_N) = r_N(x_{\leq N}) = Q^*(x_{<N}, x_N)$, which is the correct value. Consider now the general case of a sub-tree of depth larger than zero. Let $\mathcal{T}'$ be the tree before the last tree traversal that expanded the last missing node of the sub-tree; for an illustration see fig. 2. The soft-Bellman backups of this last, completing tree traversal on the path leading to the expanded node are by construction all of the following form: For any node on the path, all children except for one correspond to already completed sub-trees (before the last traversal), and the sub-tree of the one remaining child is completed by the last traversal. All complete sub-trees on the backup path are of smaller depth, and therefore by induction their roots have the correct values. Hence evaluating the soft-Bellman backup eqn. 6 (with the true, noiseless reward) yields the correct value for $x_{\leq n}$. ∎

Appendix C Details for Experiments

C.1 Baseline Inference Methods

C.1.1 SIS and SMC

For each experiment we determined the number of SIS and SMC particles such that the entire budget $B$ was used. We implemented SMC with a resampling threshold, i.e. a resampling step was executed when the effective sample size (ESS) dropped below the threshold. The threshold was treated as a hyperparameter; SMC without resampling was used to produce the SIS results.

C.1.2 BP

We used the algorithm outlined on p. 301 of [mezard2009information]. For generating a single approximate sample from the target, the following procedure was executed. Messages from variable to factor nodes were initialized as uniform; then a fixed number of message-passing steps, each consisting of updating the factor-to-variable and variable-to-factor messages, were performed. The next variable in the ordering was then sampled from the resulting approximate marginal, and the messages from this variable to its neighbouring factors were set to the corresponding atom. This was repeated until all variables were sampled, generating one approximate sample from the joint $p$.

In total, we generated multiple samples with the above algorithm such that the budget $B$ was exhausted. The number of message-passing steps before sampling each variable was treated as a hyperparameter.

C.1.3 Gibbs

We implemented standard Gibbs sampling. All variables were initially drawn uniformly from $\{1, \dots, K\}$, and a fixed number of iterations, each consisting of updating all variables in the fixed order $x_1$ to $x_N$, were executed. This generated a single approximate sample. We repeated this procedure to generate multiple samples such that the budget was exhausted. We treated the number of iterations per sample as a hyperparameter.
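A generic single-site Gibbs sweep for the factor-graph representation, as a sketch (the paper specifies only the uniform initialization, the fixed sweep order and the budget-based stopping; everything else here, including names, is an assumption):

```python
import numpy as np


def gibbs_sample(factors, num_vars, num_states, num_sweeps, rng):
    """Return one approximate sample after `num_sweeps` full sweeps of Gibbs updates.

    `factors` is a list of (scope, f) pairs, with f a log-domain factor.
    """
    x = rng.integers(num_states, size=num_vars)                    # uniform initialization
    by_var = [[(s, f) for s, f in factors if n in s] for n in range(num_vars)]
    for _ in range(num_sweeps):
        for n in range(num_vars):                                  # fixed order x_1 ... x_N
            logits = np.zeros(num_states)
            for a in range(num_states):
                x[n] = a
                logits[a] = sum(f(*(x[i] for i in s)) for s, f in by_var[n])
            p = np.exp(logits - logits.max())
            x[n] = rng.choice(num_states, p=p / p.sum())
    return tuple(int(v) for v in x)


# e.g. gibbs_sample(fg.factors, fg.N, fg.K, num_sweeps=10, rng=np.random.default_rng(0)),
# assuming the FactorGraph sketch from sec. 2.
```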

C.2 Hyperparameter Optimization

For each inference method (except for SIS) we optimized one hyperparameter on an initial set of experiments. For TreeSample, we fixed one of the two constants in eqn. 7 and optimized the other. Different hyperparameter values were used for different families of distributions. Hyperparameters were chosen so as to yield the lowest approximation error.

C.3 Details for Synthetic Distributions

C.3.1 Chains

The unary potentials $f_n(x_n)$ for the chain factor graphs were randomly generated in the following way. The values of $f_n(x_n)$ for $n = 1, \dots, N$ and $x_n = 1, \dots, K$ were jointly drawn from a GP over the two-dimensional domain $(n, x_n)$ with an RBF kernel with bandwidth 1 and scale 0.5. Binary potentials were set to a decreasing function of $d(x_n, x_{n+1})$, where $d$ is the distance between $x_n$ and $x_{n+1}$ on the 1-d torus generated by constraining the states $1$ and $K$ to be neighbours.
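A numpy sketch of how such potentials could be generated (the GP kernel parameters follow the description; the exact functional form of the binary potential is not recoverable from the text, so a negative torus distance is used here purely as a stand-in):

```python
import numpy as np


def random_chain_potentials(N, K, rng, bandwidth=1.0, scale=0.5):
    """Unary log-potentials from a GP on the (n, x_n) grid; binary ones from torus distance."""
    grid = np.array([(n, k) for n in range(N) for k in range(K)], dtype=float)
    sq_dists = ((grid[:, None, :] - grid[None, :, :]) ** 2).sum(-1)
    cov = scale * np.exp(-0.5 * sq_dists / bandwidth ** 2)          # RBF kernel Gram matrix
    unary = rng.multivariate_normal(np.zeros(N * K), cov + 1e-6 * np.eye(N * K)).reshape(N, K)

    def binary(a, b):
        d = min(abs(a - b), K - abs(a - b))                         # distance on the 1-d torus
        return -float(d)                                            # stand-in log-potential

    return unary, binary


unary, binary = random_chain_potentials(N=8, K=5, rng=np.random.default_rng(0))
```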

C.3.2 PermutedChains

We first drew uniformly random permutations of the variable indices. We then randomly generated conditional probability tables for consecutive variables in the permuted order by draws from a symmetric Dirichlet distribution. These were then used as binary factors.

C.3.3 FactorGraphs1

We generated factor graphs for this family in the following way. First, we constructed Erdős-Rényi random graphs with $N$ nodes and a fixed edge probability; graphs with more than one connected component were rejected. For each clique in this graph we inserted a random factor and connected it to all nodes in the clique; graphs containing cliques of size larger than 4 were rejected.

For applying the sequential inference algorithms TreeSample, SIS and SMC, the variables in the graph were ordered by a simple heuristic: While iterating over factors in order of descending degree, all variables in the current factor were added to the ordering, until all variables were accounted for.

C.3.4 FactorGraphs2

We generated factor graphs for this family over binary random variables $x_1, \dots, x_N$ in the following way. Variables $x_n$ and $x_{n+1}$ for $n = 1, \dots, N-1$ were connected with a NOT factor, which carries out the computation $x_{n+1} = \mathrm{NOT}(x_n)$. We then constructed Erdős-Rényi random graphs of size