Snowflake: Scaling GNNs to High-Dimensional Continuous Control via Parameter Freezing

03/01/2021 ∙ by Charlie Blake, et al.

Recent research has shown that Graph Neural Networks (GNNs) can learn policies for locomotion control that are as effective as a typical multi-layer perceptron (MLP), with superior transfer and multi-task performance (Wang et al., 2018; Huang et al., 2020). Results have so far been limited to training on small agents, with the performance of GNNs deteriorating rapidly as the number of sensors and actuators grows. A key motivation for the use of GNNs in the supervised learning setting is their applicability to large graphs, but this benefit has not yet been realised for locomotion control. We identify a weakness in a common GNN architecture that causes this poor scaling: overfitting in the MLPs within the network that encode, decode, and propagate messages. To combat this, we introduce Snowflake, a GNN training method for high-dimensional continuous control that freezes parameters in parts of the network that suffer from overfitting. Snowflake significantly boosts the performance of GNNs for locomotion control on large agents, which now match the performance of MLPs while offering superior transfer properties.




1 Introduction

Whereas many traditional machine learning models operate on sequential or Euclidean (grid-like) data representations, GNNs allow for graph-structured inputs. GNNs have yielded breakthroughs in a variety of complex domains, including drug discovery (Lim et al., 2019; Stokes et al., 2020), fraud detection (Wang et al., 2020), computer vision (Shen et al., 2018; Sarlin et al., 2020) and particle physics (Heintz et al., 2020).

GNNs have also been successfully applied to reinforcement learning (RL), with promising results on locomotion control tasks with small state and action spaces. In this domain, GNNs have shown comparable performance to standard MLPs and superior performance in multi-task and transfer settings (Wang et al., 2018; Huang et al., 2020; Kurin et al., 2020b). A key factor in the success of GNNs here is the capacity of a single GNN to operate over arbitrary graph topologies (patterns of connectivity between nodes) without modification.

Because GNNs only model local relationships between neighbouring nodes (the size of a neighbourhood is, in many graphs, independent of the size of the overall graph), they are expected to scale well to high-dimensional data comprising local structures, such as sparse graphs. However, so far GNNs in RL have only shown competitive performance on lower-dimensional locomotion control tasks, while underperforming MLPs in higher dimensions.

This paper investigates and addresses this phenomenon, identifying the factors underlying poor GNN scaling and introducing a method to combat them. We begin with an analysis of the NerveNet architecture (Wang et al., 2018), which we choose for its GNN-based policy with promising performance and zero-shot transfer to larger agents. We show that a key limitation for training NerveNet on larger tasks is the instability of policy updates as the size of the agent increases.

It is well known that for many continuous control tasks, great care must be taken to prevent destructive policy updates. Consequently, inspired by natural gradients (Amari, 1997; Kakade, 2001), all current state-of-the-art methods in deep RL for continuous control (Schulman et al., 2015, 2017; Abdolmaleki et al., 2018) employ trust region-like constraints that limit the change in policy for each update. We find that GNNs like NerveNet exhibit a much higher tendency to violate those constraints. Furthermore, tightening the constraint is not sufficient for good performance and also leads to slower learning.

Instead, in this paper, we investigate which structures in the GNN are responsible for this policy divergence. To do so, we apply different learning rates to different parts of the network. Surprisingly, the best performance is attained when training with a learning rate of zero in the parts of the GNN architecture that encode, decode, and propagate messages in the graph.

Based on this analysis, we derive Snowflake, a simple training technique that enables GNNs to scale to high-dimensional environments. Snowflake freezes the parameters of particular operations within the GNN to their initialised values, keeping them fixed throughout training while updating the non-frozen parameters as before.

Experimentally, we show that applying Snowflake to NerveNet dramatically improves asymptotic performance on larger tasks, and gives reduced sample complexity across tasks of all sizes. This enables GNN performance to match that of MLPs on the largest locomotion task available, indicating that GNN-based policies are an excellent choice for even the most challenging locomotion problems.

2 Background

2.1 Reinforcement Learning

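We formalise an RL problem as a Markov decision process (MDP). An MDP is a tuple $\langle \mathcal{S}, \mathcal{A}, \mathcal{R}, \mathcal{T}, \rho \rangle$. The first two elements define the state space $\mathcal{S}$ and the action space $\mathcal{A}$. At every time step $t$, the agent employs a policy $\pi_\theta(a_t \mid s_t)$ to output a distribution over actions, selects action $a_t \sim \pi_\theta(\cdot \mid s_t)$, and transitions from state $s_t$ to $s_{t+1}$, as specified by the transition function $\mathcal{T}(s_{t+1} \mid s_t, a_t)$, which defines a probability distribution over states. For the transition, the agent gets a reward $r_t = \mathcal{R}(s_t, a_t, s_{t+1})$. The last element of an MDP, $\rho$, specifies the initial distribution over states, i.e., the states an agent can be in at time step zero.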

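Solving an MDP means finding a policy that maximises an objective, in our case the expected discounted sum of rewards $J = \mathbb{E}_{\pi_\theta}\left[\sum_{t=0}^{\infty} \gamma^t r_t\right]$, where $\gamma \in [0, 1)$ is a discount factor. Policy gradient (PG) methods find an optimal policy by doing gradient ascent on the objective: $\theta_{k+1} = \theta_k + \alpha \nabla_\theta J \rvert_{\theta = \theta_k}$, with $\theta$ parameterising the policy and $\alpha$ a step size.

Often, to reduce the variance of the gradient estimate, one learns a value function $V(s)$ and uses it as a critic of the policy. In the resulting actor-critic method, the policy gradient takes the form $\nabla_\theta J = \mathbb{E}_{\pi_\theta}\left[\sum_t \hat{A}_t \nabla_\theta \log \pi_\theta(a_t \mid s_t)\right]$, where $\hat{A}_t$ is an estimate of the advantage function, e.g., the TD error $r_t + \gamma V(s_{t+1}) - V(s_t)$ (Schulman et al., 2016).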
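As an illustration of the two quantities just described, the following minimal sketch computes a discounted return and one-step TD-error advantages for a single short episode. The function names, the reward sequence, and the critic values are hypothetical and chosen only to make the arithmetic easy to follow; they are not taken from the paper.

```python
import numpy as np

def discounted_return(rewards, gamma):
    """Sum of gamma**t * r_t over one episode."""
    rewards = np.asarray(rewards, dtype=float)
    discounts = gamma ** np.arange(len(rewards))
    return float(np.sum(discounts * rewards))

def td_error_advantages(rewards, values, gamma):
    """One-step TD errors r_t + gamma * V(s_{t+1}) - V(s_t).

    `values` holds V(s_0), ..., V(s_T), i.e. one more entry than `rewards`.
    """
    rewards = np.asarray(rewards, dtype=float)
    values = np.asarray(values, dtype=float)
    return rewards + gamma * values[1:] - values[:-1]

# Hypothetical episode: three rewards and four critic outputs (terminal value 0).
rewards = [1.0, 1.0, 1.0]
values = [2.0, 1.5, 1.0, 0.0]
ret = discounted_return(rewards, gamma=0.5)          # 1 + 0.5 + 0.25
adv = td_error_advantages(rewards, values, gamma=0.5)
```

A negative TD error, as at the first step here, indicates that the critic had overestimated the value of that state relative to what was observed.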

2.2 Proximal Policy Optimisation

Proximal Policy Optimisation (PPO) (Schulman et al., 2016) is an actor-critic method that has proved effective for a variety of domains including locomotion control (Heess et al., 2017). PPO approximates the natural gradient using a first-order method, which has the effect of keeping policy updates within a “trust region”. This is done through the introduction of a surrogate objective to be optimised:

$$J^{\text{CLIP}}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t\right)\right], \qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)},$$

where $\epsilon$ is a clipping hyperparameter that effectively limits how much a state-action pair can cause the overall policy to change at each update. This objective is computed over a number of optimisation epochs, each of which gives an update to the new policy $\pi_\theta$. If during this process a state-action pair with a positive advantage reaches the upper clipping boundary, the objective no longer provides an incentive for the policy to be improved with respect to that data point. This similarly applies to state-action pairs with a negative advantage if the lower clipping limit is reached.
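The clipping behaviour described above can be sketched in a few lines of NumPy. This is a minimal illustration of the standard PPO clipped surrogate, not the authors' training code; the function name and the toy log-probabilities and advantages are assumptions made for the example.

```python
import numpy as np

def ppo_clip_objective(logp_new, logp_old, advantages, epsilon):
    """Mean clipped surrogate over a batch of state-action pairs."""
    ratio = np.exp(np.asarray(logp_new) - np.asarray(logp_old))
    advantages = np.asarray(advantages, dtype=float)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
    # Taking the minimum removes any incentive to move past the clip range.
    return float(np.mean(np.minimum(unclipped, clipped)))

# Toy batch: probability ratios of 1.5, 0.5 and 1.0 (old log-probs set to 0).
logp_old = np.zeros(3)
logp_new = np.log([1.5, 0.5, 1.0])
advantages = [1.0, -1.0, 2.0]
objective = ppo_clip_objective(logp_new, logp_old, advantages, epsilon=0.2)
```

In this toy batch the first pair (positive advantage, ratio above the upper limit) and the second pair (negative advantage, ratio below the lower limit) are both held at the clipping boundary, so neither contributes a gradient.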

2.3 Graph Neural Networks

GNNs are a class of neural architecture designed to operate over graph-structured data. We define a graph as a tuple $G = (\mathcal{V}, \mathcal{E})$ comprising a set of nodes $\mathcal{V}$ and edges $\mathcal{E} = \{(u, v) \mid u, v \in \mathcal{V}\}$. A labelled graph has corresponding feature vectors for each node and edge that form a pair of matrices $(X, E)$, where $X = \{x_v \mid v \in \mathcal{V}\}$ and $E = \{e_{(u,v)} \mid (u, v) \in \mathcal{E}\}$. For GNNs we often consider directed graphs, where the order of an edge $(u, v)$ defines $u$ as the sender and $v$ as the receiver.

A GNN takes a labelled graph $G$ and outputs a second graph $G'$ with new labels. Most GNN architectures retain the same topology for $G'$ as used in $G$, in which case a GNN can be viewed as a mapping from input labels $(X, E)$ to output labels $(X', E')$.

A common GNN framework is the message passing neural network (MPNN) (Gilmer et al., 2017), which generates this mapping using $T$ steps or ‘layers’ of computation. At each layer $\tau \in \{0, \dots, T-1\}$ in the network, a hidden state $h_v^{\tau+1}$ and message $m_v^{\tau+1}$ is computed for every node $v \in \mathcal{V}$ in the graph.

An MPNN implementation calculates these through its choice of message functions and update functions, denoted $M_\tau$ and $U_\tau$ respectively. A message function computes representations from hidden states and edge features, which are then aggregated and passed into an update function to compute new hidden states:

$$m_v^{\tau+1} = \sum_{u \in N(v)} M_\tau\left(h_u^\tau, h_v^\tau, e_{(u,v)}\right), \qquad h_v^{\tau+1} = U_\tau\left(h_v^\tau, m_v^{\tau+1}\right),$$

for all nodes $v \in \mathcal{V}$, where $N(v)$ is the neighbourhood of all sender nodes connected to receiver $v$ by a directed edge.

The node input labels $x_v$ are used as the initial hidden states $h_v^0$. MPNN assumes only node output labels are required, using each final hidden state $h_v^T$ as the output label (see footnote 1).

Footnote 1: Gilmer et al. (2017) also specify a readout function, which aggregates the final node representations into a global representation for the entire graph. This is not necessary for the tasks considered in this paper, which only require node-level outputs.
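To make the message-passing step concrete, the sketch below implements one generic layer with sum aggregation over incoming edges. It is an illustration under simplifying assumptions: edge features are omitted, the message and update functions are passed in as plain Python callables, and the names `mpnn_layer`, `message_fn` and `update_fn` are hypothetical.

```python
import numpy as np

def mpnn_layer(hidden, edges, message_fn, update_fn):
    """One message-passing layer.

    hidden: array of shape (num_nodes, d) holding the current hidden states.
    edges:  list of directed (sender, receiver) index pairs.
    """
    messages = np.zeros_like(hidden)
    for sender, receiver in edges:
        # Each receiver sums the messages computed from its senders.
        messages[receiver] += message_fn(hidden[sender], hidden[receiver])
    return np.stack([update_fn(h, m) for h, m in zip(hidden, messages)])

# Toy 3-node chain with edges in both directions and 1-d hidden states.
hidden = np.array([[1.0], [2.0], [4.0]])
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
new_hidden = mpnn_layer(
    hidden,
    edges,
    message_fn=lambda h_sender, h_receiver: h_sender,  # pass the sender state through
    update_fn=lambda h, m: h + m,                      # add the aggregated message
)
```

Because the same two functions are applied at every node and edge, the layer works unchanged on graphs of any size, which is the property the transfer results rely on.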

2.4 NerveNet

NerveNet is a type of MPNN designed for locomotion control, based on the gated GNN architecture (Li et al., 2016). NerveNet uses the morphology (physical structure) of the agent as the basis for the GNN’s input graph, with edges representing body parts and nodes representing the joints that connect them. Input labels encode positional information about body parts, which the GNN uses to compute output labels determining the force exerted at each joint. NerveNet does not use input or output edge labels, and assumes a given task can be represented purely at the node level.

NerveNet can be viewed in terms of the MPNN framework, using a message function $M$ consisting of a single MLP for all layers that takes as input only the state of the sender node $h_u^\tau$. The update function $U$ is a single gated recurrent unit (GRU) (Cho et al., 2014) that maintains an internal state equivalent to $h_v^\tau$ for node $v$, and takes as input the aggregated message $m_v^{\tau+1}$.

In addition, NerveNet uses an encoder $F_{\text{in}}$ to generate the initial hidden states from the input labels, and a decoder $F_{\text{out}}$ to turn final hidden states into scalar output labels. These output labels are used as means for normal distributions at each node from which actions are then sampled. The standard deviation of these distributions is a separate vector of parameters learned during training.

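This gives the following set of equations to describe the operation of NerveNet:

$$m_v^{\tau+1} = \sum_{u \in N(v)} M\left(h_u^\tau\right), \qquad h_v^{\tau+1} = U\left(h_v^\tau, m_v^{\tau+1}\right),$$

for layers $\tau \in \{0, \dots, T-1\}$, with

$$h_v^0 = F_{\text{in}}(x_v), \qquad y_v = F_{\text{out}}\left(h_v^T\right),$$

for all nodes $v \in \mathcal{V}$.

NerveNet uses PPO to train the policy, with parameter updates computed via the Adam optimisation algorithm (Kingma and Ba, 2015). With respect to the MDP of a given task, at timestep $t$ NerveNet assumes that the state $s_t$ can be factored into the set of feature vectors $\{x_v \mid v \in \mathcal{V}\}$. These features are then propagated through the layers of the GNN, and the set of output features $\{y_v \mid v \in \mathcal{V}\}$ is returned. These are required to be scalar values representing the per-joint actions, which are concatenated to form the action vector $a_t$. The transition dynamics are applied as specified by the MDP, returning a new state $s_{t+1}$ to be fed into the GNN, after which this process repeats.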
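The forward pass described in this section can be sketched end to end as follows. This is a simplified stand-in rather than the NerveNet implementation: the encoder, message function and decoder are single linear layers (with a tanh where noted) instead of full MLPs, aggregation is assumed to be a sum, the GRU is a plain textbook cell, and all names, sizes and the toy graph are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def init_params(d_in, d_hidden):
    """Small random weights; the parameter names are illustrative only."""
    def w(rows, cols):
        return rng.normal(scale=0.1, size=(rows, cols))
    gru = {k: w(d_hidden, d_hidden) for k in ("Wz", "Uz", "Wr", "Ur", "Wh", "Uh")}
    return {"enc": w(d_in, d_hidden), "msg": w(d_hidden, d_hidden),
            "gru": gru, "dec": w(d_hidden, 1)}

def gru_cell(h, m, p):
    z = sigmoid(m @ p["Wz"] + h @ p["Uz"])            # update gate
    r = sigmoid(m @ p["Wr"] + h @ p["Ur"])            # reset gate
    h_new = np.tanh(m @ p["Wh"] + (r * h) @ p["Uh"])  # candidate state
    return (1.0 - z) * h + z * h_new

def policy_means(x, edges, params, num_layers):
    h = np.tanh(x @ params["enc"])            # encoder: input labels -> initial states
    for _ in range(num_layers):
        msg = np.tanh(h @ params["msg"])      # message uses the sender state only
        agg = np.zeros_like(h)
        for sender, receiver in edges:
            agg[receiver] += msg[sender]      # sum over incoming messages
        h = gru_cell(h, agg, params["gru"])   # update function
    return (h @ params["dec"])[:, 0]          # decoder: one scalar mean per node

# Hypothetical 4-node chain with edges in both directions, 5 features per node.
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)]
x = rng.normal(size=(4, 5))
params = init_params(d_in=5, d_hidden=8)
log_std = np.zeros(4)   # a learned vector in the paper; held fixed in this sketch
means = policy_means(x, edges, params, num_layers=3)
action = means + np.exp(log_std) * rng.normal(size=4)  # sample per-node actions
```

Note that none of the parameter shapes depend on the number of nodes, so the same `params` could be applied to a longer chain without modification.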

3 Analysing GNN Scaling Challenges

In this section, we use NerveNet to analyse the challenges that limit the ability of GNNs to scale. We focus on NerveNet as its architecture is more closely aligned with the GNN framework than alternative approaches to structured locomotion control (see Section 4). We use mostly the same experimental setup as Wang et al. (2018), with details of any differences and our choice of hyperparameters outlined in the appendix.

3.1 Tasks

We focus on environments derived from the Gym (Brockman et al., 2016) suite, using the MuJoCo (Todorov et al., 2012) physics engine. The main set of tasks we use to assess scaling is the selection of Centipede-n agents (Wang et al., 2018), chosen because of their relatively complex structure and ability to be scaled up to high-dimensional input-action spaces.

(a) MuJoCo rendering of Centipede-20.
(b) Graph representation of Centipede-n’s morphology.
Figure 1: Overview of the scalable Centipede-n benchmark from Wang et al. (2018).

The morphology of a Centipede-n agent consists of a line of n/2 body segments, each with a left and right leg attached (see Figure 1). The graph used as the basis for the GNN corresponds to the physical structure of the agent’s body. At each timestep in the environment, the MuJoCo engine sends a feature vector containing positional information regarding body parts and joints, expecting a scalar value to be returned for each joint specifying the force to be applied. The agent is rewarded for forward movement along a fixed axis and receives a small ‘survival’ bonus for keeping its body within certain bounds; it is given negative rewards proportional to the size of its actions and the magnitude of the force it exerts on the ground.

Figure 2: Comparison of the scaling of NerveNet relative to an MLP-based policy. Performance is similar for the smaller agent sizes, but NerveNet scales poorly to the larger agents.

Existing work applying GNNs to locomotion control tasks avoids training directly on larger agents, i.e., those with many nodes in the underlying graph representation. For example, Wang et al. (2018) state that for NerveNet, “training a CentipedeEight from scratch is already very difficult”. Huang et al. (2020) also limit training their Shared Modular Policies (SMP) architecture to small agent types.

3.2 Scaling Performance

To demonstrate the poor scaling of the NerveNet architecture to larger agents, we compare its performance on a selection of Centipede-n tasks to that of an MLP-based policy. As in previous literature (e.g., Wang et al., 2018; Huang et al., 2020), our goal is not to outperform MLPs, but ideally to achieve similar performance so that the additional benefits of GNNs can be realised. These benefits include the applicability of a single GNN-based policy to agents and graphs of any size, yielding improved multi-task and transfer performance, even in the zero-shot setting, which is not feasible for MLP-based policies. In this paper we aim to improve the performance of training GNNs on larger agents so that these benefits can be gained in a wider range of settings. Consequently, the MLP-based policy is not a baseline to outperform, but an estimate of what should be achievable on a single task when training with GNNs.

Figure 2 shows that for the smaller Centipede-n agents both policies are similarly effective, but as the size of the agent increases the performance of NerveNet drops relative to the MLP. A visual inspection of the behaviour of these agents shows that for Centipede-20, NerveNet barely makes forward progress at all, whereas the MLP moves effectively. Thus when training large agents one must choose between the strong performance of an MLP and the strong generalisation of a GNN; neither method attains both.

These NerveNet results are also highly sensitive to the value of the clipping term $\epsilon$ in the PPO surrogate objective (see Section 2.2). Figure 3 demonstrates the effect of changing $\epsilon$ for the Centipede-20 agent, showing that small changes significantly impair performance. Either the step in policy space is too great (large $\epsilon$) or over-clipping reduces the extent to which the policy can learn from the experience sampled (small $\epsilon$). This suggests that the difficulty of limiting the policy divergence that results from updates is a factor in the poor performance of NerveNet.
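Since the policy outputs an independent normal distribution per action dimension, the divergence between the policy before and after an update can be measured in closed form. The sketch below computes the KL divergence from an old to a new diagonal Gaussian for a single state; the function name and the example values are hypothetical, and in practice the quantity would be averaged over the states in a batch.

```python
import numpy as np

def diag_gaussian_kl(mean_old, std_old, mean_new, std_new):
    """KL(old || new) between two Gaussians with diagonal covariance."""
    mean_old, std_old = np.asarray(mean_old, float), np.asarray(std_old, float)
    mean_new, std_new = np.asarray(mean_new, float), np.asarray(std_new, float)
    per_dim = (np.log(std_new / std_old)
               + (std_old ** 2 + (mean_old - mean_new) ** 2) / (2.0 * std_new ** 2)
               - 0.5)
    return float(np.sum(per_dim))

# Identical policies give zero divergence; shifting one mean by one
# standard deviation gives a divergence of 0.5.
kl_same = diag_gaussian_kl([0.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0])
kl_shift = diag_gaussian_kl([0.0, 0.0], [1.0, 1.0], [1.0, 0.0], [1.0, 1.0])
```

Because the per-dimension terms are summed, the same per-joint change in the mean produces a larger total divergence as the number of action dimensions grows.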

Figure 3: Final performance of NerveNet on Centipede-20 after ten million timesteps, across a range of clipping hyperparameter values. As $\epsilon$ increases (i.e. clipping is reduced) the KL divergence from the old to new policy (blue) increases. This improves performance (orange) up to a point, after which it begins to deteriorate. Performance is particularly sensitive to the value of $\epsilon$, indicating the importance of keeping the policy within the trust region for NerveNet.

We can mitigate this sensitivity to some degree by using much lower learning rates or larger batch sizes. However, this significantly degrades sample efficiency, making training large agents with NerveNet infeasible (see the appendix for supporting results). In the next section, we investigate a feature of the GNN design that contributes to this damaging instability across policy updates.

3.3 Overfitting in NerveNet

A known deficiency of MPNN architectures is that message functions implemented as MLPs are prone to overfitting (Hamilton, 2020, p.55). For on-policy RL algorithms like PPO, the policy being trained is used to generate a batch of training data for each update. Overfitting can occur if parameter updates for a given batch are not representative of those that would be learned given the whole policy distribution, or if updates do not generalise well to the new policy derived as a result of the update. In such cases, subsequent batch updates are liable to cause excessively large steps in policy space, harming overall performance.

We investigate the extent to which such overfitting is a factor in the policy instability exhibited by NerveNet. A standard approach to reducing overfitting is parameter regularisation, which discourages the learning of large parameter values within the model. Figure 4 demonstrates the effect of using L2 regularisation on the parameters $\theta_M$ of NerveNet’s message function MLP $M$. This adds an extra term $\lambda \lVert \theta_M \rVert_2^2$ to the optimisation objective to be minimised. The value of $\lambda$ determines the strength of the regularisation, and hence the trade-off between minimising the original objective and the regularisation term.
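As a small worked example of the penalty, the sketch below adds an L2 term over a set of weight matrices to a scalar loss. The coefficient name `lam`, the weight values and the stand-in loss value are assumptions made for illustration; only the message-function weights would be passed in to match the experiment described above.

```python
import numpy as np

def l2_penalty(weight_matrices, lam):
    """lam times the sum of squared entries over the given weight matrices."""
    return lam * sum(float(np.sum(w ** 2)) for w in weight_matrices)

# Hypothetical message-function weights and a stand-in policy loss.
message_weights = [np.array([[3.0, 4.0]])]
policy_loss = 1.0
total_loss = policy_loss + l2_penalty(message_weights, lam=0.1)  # 1.0 + 0.1 * 25
```

The gradient of the penalty with respect to each weight is proportional to the weight itself, so larger values of the coefficient pull the regularised weights more strongly towards zero.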

At the optimal value of $\lambda$ we see clear improvement in performance, indicating the presence of overfitting in the unregularised message MLPs. However, the L2-regularised NerveNet model is still substantially inferior to using an MLP to represent the entire policy, and does not sufficiently address the problem of scaling GNNs to high-dimensional tasks.

We also investigate lowering the learning rate in parts of the GNN architecture that overfit the training batches. If parts of the network are particularly prone to damaging overfitting, training them more slowly may reduce their contribution to policy instability across updates. Results for this experiment can be seen in Figure 5.

Surprisingly, not only does lowering the learning rate in parts of the model improve performance, but the best performance is obtained when the encoder $F_{\text{in}}$, message function $M$ and decoder $F_{\text{out}}$ each have their learning rate set to zero. Whereas the update function $U$ is implemented as a GRU, the other three functions are implemented as MLPs. This further supports our hypothesis that the MLP-based functions in the GNN are prone to overfitting.

Figure 4: L2 regularisation for NerveNet’s message function across a range of values for the L2 penalty $\lambda$, trained on Centipede-20. Increasing this penalty reduces the L2 norm of the weights learned (left). Improved performance for higher values of $\lambda$ (right) indicates the presence of overfitting for the message function.

3.4 Snowflake

Training with a learning rate of zero is equivalent to parameter freezing (e.g., Brock et al., 2017), where parameters are fixed to their initialised values throughout training. NerveNet can learn a policy with some of its functions frozen, as learning still takes place in other parts of the network. For instance, if we consider freezing the encoder, this results in an arbitrary mapping of input features to the initial hidden states. As the update function that processes this representation is still trained by the optimiser, so long as key information from the input features needed by the policy is not lost as a result of the encoding, the update function can still learn useful representations. Similarly, a frozen decoder may still result in an effective policy if the update function can learn to produce a final hidden state in a form that the decoder maps to the desired action.

Based on the effectiveness of parameter freezing within parts of the network in our analysis of NerveNet, we propose a simple technique for improving the training of GNNs via gradient-based optimisation, which we name Snowflake (a naturally-occurring frozen graph structure).

Snowflake assumes a GNN architecture made up internally of functions $F_\theta$, where $\theta$ denotes the parameters of a given function. The Snowflake algorithm selects a subset of these functions and places their parameters in the frozen set $\mathcal{Z}$.

In practice, we found optimal performance for $\mathcal{Z} = \{F_{\text{in}}, F_{\text{out}}, M\}$, i.e. when freezing the encoder, decoder and message function of the GNN. If not stated otherwise, this is the architecture we refer to as Snowflake in subsequent sections.

During training, Snowflake excludes parameters in $\mathcal{Z}$ from being updated by the optimiser, instead fixing them to whatever values the GNN architecture uses as an initialisation. Gradients still flow through these operations during backpropagation, but their parameters are not updated.
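The freezing rule can be illustrated with a toy optimiser step. The sketch below is an assumption-laden simplification: it uses plain gradient descent instead of Adam, groups the parameters under hypothetical names (`encoder`, `message`, `update`, `decoder`), and uses placeholder gradients. The point is only that gradients are computed for every group, while groups in the frozen set are left at their initial values.

```python
import numpy as np

def sgd_step(params, grads, lr, frozen=()):
    """Return updated parameters, leaving any group named in `frozen` untouched."""
    return {
        name: value if name in frozen else value - lr * grads[name]
        for name, value in params.items()
    }

rng = np.random.default_rng(0)
names = ("encoder", "message", "update", "decoder")
params = {name: rng.normal(size=(2, 2)) for name in names}
grads = {name: np.ones((2, 2)) for name in names}   # placeholder gradients

# Freeze the encoder, message function and decoder; train only the update function.
frozen = {"encoder", "message", "decoder"}
new_params = sgd_step(params, grads, lr=0.1, frozen=frozen)
```

In an automatic differentiation framework the same effect is usually obtained by not handing the frozen parameters to the optimiser, which is equivalent to giving them a learning rate of zero.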

Figure 5: Comparing the performance of NerveNet on Centipede-20 across a range of learning rates for different functions within the GNN. The colour of each cell represents the final performance attained by the agent over five million timesteps, where the specified part of the network has its learning rate set to the given value throughout training. The rest of the functions within the GNN have their learning rates held at a fixed base value.

For our experiments, we initialise the values in the GNN using the orthogonal initialisation (Saxe et al., 2014). This procedure first samples each entry of a matrix independently from a Gaussian distribution, then calculates the QR decomposition of that matrix, and finally takes the orthogonal matrix Q as the initialised values for a given parameter matrix. We found this to be slightly more effective for frozen and unfrozen training than Gaussian-based initialisations such as the Xavier initialisation (Glorot and Bengio, 2010).
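The three steps of this procedure (Gaussian sample, QR decomposition, keep Q) can be sketched with NumPy as follows. The function name and the handling of non-square shapes (sampling a tall matrix and transposing when more columns than rows are needed) are assumptions of this sketch, not details given in the paper.

```python
import numpy as np

def orthogonal_init(rows, cols, rng):
    """Weight matrix with orthonormal rows or columns, via QR of a Gaussian sample."""
    tall, short = max(rows, cols), min(rows, cols)
    gaussian = rng.normal(size=(tall, short))
    q, _ = np.linalg.qr(gaussian)        # q has shape (tall, short), orthonormal columns
    return q if rows >= cols else q.T

rng = np.random.default_rng(0)
w_square = orthogonal_init(8, 8, rng)    # e.g. a message function with equal in/out size
w_wide = orthogonal_init(4, 8, rng)      # e.g. a layer that changes dimensionality
```

For a square matrix the result is fully orthogonal, so it preserves the norm of its input, which is one reason a frozen layer initialised this way need not discard information.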

For our message function, which has input and output dimensions of the same size, we find that performance with the frozen orthogonal initialisation is similar to that of simply using the identity function instead of an MLP. However, in the general case where the input and output dimensions of functions in the network differ (such as in the encoder and decoder, or in GNN architectures where layers use representations of different dimensionality), this simplification is not possible and freezing is required.

4 Related Work

4.1 Structured Locomotion Control

Several different GNN-like architectures (Scarselli et al., 2009; Battaglia et al., 2018) have been proposed to learn policies for locomotion control. Wang et al. (2018) introduce NerveNet, which trains a GNN based on the agent’s morphology, along with a selection of scalable locomotion benchmarks. NerveNet achieves multi-task and transfer learning across morphologies, even in the zero-shot setting (i.e., without further training), which standard MLP-based policies fail to achieve.

Sanchez-Gonzalez et al. (2018) use a similar GNN-based architecture for learning a model of the agent’s environment, which can then be used for model-predictive control.

Huang et al. (2020) propose Shared Modular Policies (SMP), which focuses on multi-task training and shows strong generalisation to out-of-distribution agent morphologies using a single policy. The architecture of SMP has similarities with a GNN, but requires a tree-based description of the agent’s morphology rather than a graph, and replaces size- and permutation-invariant aggregation at each node with a fixed-cardinality MLP. SMP is based on prior work by Pathak et al. (2019), who propose Dynamic Graph Networks (DGN), where a GNN is used to learn a policy enabling multiple small agents to cooperate by combining their physical structures.

A related approach (Kurin et al., 2020b) uses an architecture based on transformers (Vaswani et al., 2017) to represent locomotion policies. Transformers can be seen as GNNs using attention for edge-to-vertex aggregation and operating on a fully connected graph, meaning their computational complexity scales quadratically with the graph size.

For all of these existing approaches to GNN-based locomotion control, training is restricted to small agents. In the case of NerveNet and DGN, emphasis is placed on the ability to perform zero-shot transfer to larger agents, but this still incurs a significant drop in performance.

4.2 Graph-Based Reinforcement Learning

Figure 6: Comparison of the performance of Snowflake training, regular NerveNet and the MLP-based policy. Snowflake enables effective scaling to the larger agents, significantly outperforming regular NerveNet and performing comparably to an MLP-based policy.

GNNs have recently gained traction in RL due to their ability to support variable-sized inputs and outputs, opening new avenues for RL applications as well as enhancing the capabilities of agents on existing benchmarks.

Khalil et al. (2017) apply DQN (Mnih et al., 2015) to combinatorial optimisation problems using Structure2Vec (Dai et al., 2016) for function approximation. Lederman et al. (2020) use policy gradient methods to learn heuristics for a quantified Boolean formula solver.

Kurin et al. (2020a) use DQN (Mnih et al., 2015) with graph networks (Battaglia et al., 2018) to learn the branching heuristic of a Boolean SAT solver. For the latter, variance increases with the size of the graph, making training on large graphs challenging.

Other approaches involve the construction of graphs based on factorisation of the environmental state into objects with associated attributes (Bapst et al., 2019; Loynd et al., 2020). In multi-agent RL, researchers have used a similar approach to model the relationship between agents, as well as environmental objects (Zambaldi et al., 2019; Iqbal et al., 2020; Li et al., 2020). In this setting, increasing the number of agents can result in additional problems, such as combinatorial explosion of the action space.

Our approach could potentially benefit the above work by improving scaling properties across a variety of domains.

4.3 Parameter Freezing

Research on parameter freezing has mostly focused on the transfer learning or “fine tuning” use case. Here a neural network is pre-trained on a source task and only the final layers are then trained on the target task, with earlier layers left frozen (e.g. Yosinski et al., 2014; Houlsby et al., 2019). Progressively freezing layers during training has also been shown in some cases to improve wall-clock training time, with minimal reduction in performance (Brock et al., 2017).

We are not aware of work that uses the same strategy as Snowflake of freezing parameters to their initialisation values without pre-training. However, techniques such as dropout (Srivastava et al., 2014) have been used widely to improve generalisation, suggesting that good representations can still be learned without training every parameter at each step.

5 Experiments

The design of rewards for the standard Gym tasks is similar to that outlined for Centipede-n: the agent is punished for taking large actions and rewarded for forward movement and for ‘survival’, defined as the agent’s body staying within certain spatial bounds. An episode terminates when these bounds are exceeded, or after a fixed number of timesteps.

We now present experiments that assess the performance of Snowflake when applied to NerveNet.

5.1 Experimental Setup

We evaluate each model on a selection of MuJoCo tasks, including the Centipede-n agents from Wang et al. (2018) and three standard tasks from the Gym (Brockman et al., 2016) suite.

All training statistics are calculated as the mean across six independent runs (unless specified otherwise), with the standard error across runs indicated by the shaded areas on each graph. The average reward per episode typically has high variance, so to smooth our results we plot the mean taken over a sliding window of 30 data points. Further experimental details are outlined in the appendix.
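The reporting procedure above can be sketched as follows: each run's reward curve is smoothed with a sliding-window mean, and the mean and standard error are then taken across runs. The synthetic curves, the function names, and the choice to smooth each run before aggregating are assumptions of this sketch; the paper does not specify the order of the two steps.

```python
import numpy as np

def smooth(curve, window=30):
    """Mean over a sliding window; output is shorter by window - 1 points."""
    kernel = np.ones(window) / window
    return np.convolve(curve, kernel, mode="valid")

def mean_and_stderr(runs):
    """Mean and standard error across runs, for an array of shape (num_runs, num_points)."""
    runs = np.asarray(runs, dtype=float)
    mean = runs.mean(axis=0)
    stderr = runs.std(axis=0, ddof=1) / np.sqrt(runs.shape[0])
    return mean, stderr

# Synthetic stand-in for six independent runs of 100 logged episode rewards.
rng = np.random.default_rng(0)
raw_runs = np.linspace(0.0, 1.0, 100) + rng.normal(scale=0.3, size=(6, 100))
smoothed_runs = np.stack([smooth(run, window=30) for run in raw_runs])
mean_curve, stderr_curve = mean_and_stderr(smoothed_runs)
```

The shaded region on each graph would then span `mean_curve - stderr_curve` to `mean_curve + stderr_curve`.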

5.2 Scaling to High-Dimensional Tasks

Figure 6 compares the scaling properties of the regular NerveNet model with Snowflake. As the size of the agent increases, Snowflake significantly outperforms NerveNet, with comparable asymptotic performance to the MLP. This makes GNNs a strong choice of policy representation, able to match MLPs even on large graphs, while retaining their ability to transfer to tasks with different state and action dimensions without retraining (Wang et al., 2018).

5.3 Policy Stability and Sample Efficiency

By reducing overfitting in parts of the GNN, Snowflake mitigates the effect of destructive policy updates seen with regular NerveNet. One consequence is that the policy can train effectively on smaller batch sizes. This is demonstrated in Figure 7, which shows the performance of NerveNet trained regularly versus using Snowflake as the batch size decreases.

A potential benefit of training with smaller batch sizes is improved sample efficiency, as fewer timesteps are taken in the environment per update. However, smaller batch sizes also lead to increased policy divergence: with less data to train on the policy may overfit, and gradient estimates have higher variance. When the policy divergence is too great, performance begins to decrease, limiting how small the batch can be. Our results show that because of the reduction in policy divergence, Snowflake can use a smaller batch size than standard training before performance begins to degrade, and until this point the use of smaller batch sizes improves sample efficiency. This provides a wider motivation for the use of Snowflake than just scaling to larger agents: it also improves sample efficiency across agents regardless of size.

The success of Snowflake in scaling to larger agents can also be understood in this context. Without Snowflake, for NerveNet to attain strong performance on large agents an infeasibly large batch size is required, leading to poor sample efficiency. The more stable policy updates enabled by Snowflake make solving these large tasks tractable.

(a) NerveNet
(b) Snowflake
Figure 7: Effectiveness of Snowflake across smaller batch sizes relative to standard NerveNet training. Snowflake is able to use smaller batch sizes, leading to improved sample efficiency. This is due to Snowflake reducing policy divergence across updates. Corresponding policy divergence plots can be found in the appendix.

5.4 PPO Clipping

Snowflake’s improved policy stability also reduces the amount of clipping performed by the PPO surrogate objective (see Section 2.2) across each training batch. Figure 8 shows the percentage of state-action pairs that are affected by clipping for regular NerveNet versus Snowflake, on the Centipede-20 agent.

When NerveNet is trained without using Snowflake, a larger percentage of state-action pairs are clipped during PPO updates, a consequence of the greater policy divergence caused by overfitting. Methods like PPO clipping are necessary to keep the policy within the trust region, particularly for standard NerveNet due to its tendency towards large policy changes; however, such approaches involve a trade-off. For PPO, if too many data points reach the clipping limit during optimisation, the algorithm is only able to learn from a small fraction of the experience collected, reducing the effectiveness of training.
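One way to measure the quantity discussed here is to count the state-action pairs whose probability ratio has moved past the clip boundary in the direction their advantage favours, since those are the pairs that no longer contribute a gradient. The sketch below assumes this definition; the paper does not state exactly how its clipping percentage is computed, and the function name and toy batch are hypothetical.

```python
import numpy as np

def clip_fraction(logp_new, logp_old, advantages, epsilon):
    """Fraction of state-action pairs for which the clipped term is active."""
    ratio = np.exp(np.asarray(logp_new) - np.asarray(logp_old))
    advantages = np.asarray(advantages, dtype=float)
    hit_upper = (advantages > 0) & (ratio > 1.0 + epsilon)
    hit_lower = (advantages < 0) & (ratio < 1.0 - epsilon)
    return float(np.mean(hit_upper | hit_lower))

# Toy batch with ratios 1.5, 0.5, 1.0 and 1.5 (old log-probs set to 0).
logp_old = np.zeros(4)
logp_new = np.log([1.5, 0.5, 1.0, 1.5])
advantages = [1.0, -1.0, 2.0, -1.0]
fraction = clip_fraction(logp_new, logp_old, advantages, epsilon=0.2)
```

The last pair has a ratio above the upper limit but a negative advantage, so the unclipped term is the smaller one and the pair still provides a learning signal; only the first two pairs count as clipped.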

One of Snowflake’s strengths is that because it reduces policy divergence it requires less severe restrictions to keep the policy within the trust region. The combination of this effect and the ability to train well on smaller batch sizes enables Snowflake’s strong performance on the largest agents.

6 Conclusion

We proposed Snowflake, a method that enables GNN-based policies to be trained effectively on much larger locomotive agents than was previously possible. We no longer observe a substantial difference in performance between using GNNs to represent a locomotion policy and the standard approach of using MLPs, even on the most challenging morphologies. Based on these results, combined with the strong multi-task and transfer properties shown in previous work, we believe that in many cases GNN-based policies are the most useful representation for locomotion control problems. We have also provided insight into why poor scaling occurs for certain GNN architectures, and why parameter freezing is effective in addressing the overfitting problem we identify. We hope that our analysis and method can facilitate future work in training GNNs on even larger graphs, for a wider range of learning problems.

Figure 8: The effect of Snowflake on policy divergence and PPO clipping on Centipede-20. By freezing parts of the network that overfit, Snowflake reduces the policy KL divergence, leading to less clipping during training.


  • A. Abdolmaleki, J. T. Springenberg, Y. Tassa, R. Munos, N. Heess, and M. A. Riedmiller (2018) Maximum a posteriori policy optimisation. CoRR abs/1806.06920. External Links: 1806.06920 Cited by: §1.
  • S. Amari (1997) Neural learning in structured parameter spaces-natural riemannian gradient. NIPS, pp. 127–133. Cited by: §1.
  • V. Bapst, A. Sanchez-Gonzalez, C. Doersch, K. Stachenfeld, P. Kohli, P. Battaglia, and J. Hamrick (2019) Structured agents for physical construction. In ICML, pp. 464–474. Cited by: §4.2.
  • P. Battaglia, J. B. Hamrick, V. Bapst, A. Sanchez-Gonzalez, V. Zambaldi, M. Malinowski, A. Tacchetti, D. Raposo, A. Santoro, R. Faulkner, et al. (2018) Relational inductive biases, deep learning, and graph networks. CoRR abs/1806.01261. External Links: 1806.01261 Cited by: §4.1, §4.2.
  • A. Brock, T. Lim, J. M. Ritchie, and N. Weston (2017) FreezeOut: accelerate training by progressively freezing layers. CoRR abs/1706.04983. External Links: 1706.04983 Cited by: §3.4, §4.3.
  • G. Brockman, V. Cheung, L. Pettersson, J. Schneider, J. Schulman, J. Tang, and W. Zaremba (2016) OpenAI gym. CoRR abs/1606.01540. External Links: 1606.01540 Cited by: §A.2, §3.1, §5.1.
  • K. Cho, B. van Merrienboer, Ç. Gülçehre, D. Bahdanau, F. Bougares, H. Schwenk, and Y. Bengio (2014) Learning phrase representations using RNN encoder-decoder for statistical machine translation. In EMNLP, pp. 1724–1734. External Links: Document Cited by: §2.4.
  • H. Dai, B. Dai, and L. Song (2016) Discriminative embeddings of latent variable models for structured data. In ICML, JMLR Workshop and Conference Proceedings, Vol. 48, pp. 2702–2711. Cited by: §4.2.
  • J. Gilmer, S. S. Schoenholz, P. F. Riley, O. Vinyals, and G. E. Dahl (2017) Neural message passing for quantum chemistry. In ICML, Vol. 70, pp. 1263–1272. Cited by: §2.3, footnote 1.
  • X. Glorot and Y. Bengio (2010) Understanding the difficulty of training deep feedforward neural networks. In AISTATS, JMLR Proceedings, Vol. 9, pp. 249–256. Cited by: §3.4.
  • W. L. Hamilton (2020) Graph representation learning. Synthesis Lectures on Artificial Intelligence and Machine Learning 14 (3), pp. p55. Cited by: §3.3.
  • N. Heess, D. TB, S. Sriram, J. Lemmon, J. Merel, G. Wayne, Y. Tassa, T. Erez, Z. Wang, S. Eslami, et al. (2017) Emergence of locomotion behaviours in rich environments. CoRR abs/1707.02286. External Links: 1707.02286 Cited by: §2.2.
  • A. Heintz, V. Razavimaleki, J. Duarte, G. DeZoort, I. Ojalvo, S. Thais, M. Atkinson, M. Neubauer, L. Gray, S. Jindariani, et al. (2020) Accelerated charged particle tracking with graph neural networks on fpgas. CoRR abs/2012.01563. External Links: 2012.01563 Cited by: §1.
  • N. Houlsby, A. Giurgiu, S. Jastrzebski, B. Morrone, Q. de Laroussilhe, A. Gesmundo, M. Attariyan, and S. Gelly (2019) Parameter-efficient transfer learning for NLP. In ICML, Proceedings of Machine Learning Research, Vol. 97, pp. 2790–2799. Cited by: §4.3.
  • W. Huang, I. Mordatch, and D. Pathak (2020) One policy to control them all: shared modular policies for agent-agnostic control. In International Conference on Machine Learning, pp. 4455–4464. Cited by: Snowflake: Scaling GNNs to High-Dimensional Continuous Control via Parameter Freezing, §1, §3.1, §3.2, §4.1.
  • S. Iqbal, C. A. S. de Witt, B. Peng, W. Böhmer, S. Whiteson, and F. Sha (2020) AI-QMIX: attention and imagination for dynamic multi-agent reinforcement learning. CoRR abs/2006.04222. External Links: 2006.04222 Cited by: §4.2.
  • S. M. Kakade (2001) A natural policy gradient. In NIPS, pp. 1531–1538. Cited by: §1.
  • E. B. Khalil, H. Dai, Y. Zhang, B. Dilkina, and L. Song (2017) Learning combinatorial optimization algorithms over graphs. In NIPS, pp. 6348–6358. Cited by: §4.2.
  • D. P. Kingma and J. Ba (2015) Adam: A method for stochastic optimization. In ICLR, Cited by: §2.4.
  • V. Kurin, S. Godil, S. Whiteson, and B. Catanzaro (2020a) Can q-learning with graph networks learn a generalizable branching heuristic for a SAT solver?. In NeurIPS, Cited by: §4.2.
  • V. Kurin, M. Igl, T. Rocktäschel, W. Boehmer, and S. Whiteson (2020b) My body is a cage: the role of morphology in graph-based incompatible control. CoRR abs/2010.01856. External Links: 2010.01856 Cited by: §1, §4.1.
  • G. Lederman, M. N. Rabe, S. Seshia, and E. A. Lee (2020) Learning heuristics for quantified boolean formulas through reinforcement learning. In ICLR, Cited by: §4.2.
  • S. Li, J. K. Gupta, P. Morales, R. E. Allen, and M. J. Kochenderfer (2020) Deep implicit coordination graphs for multi-agent reinforcement learning. CoRR abs/2006.11438. External Links: 2006.11438 Cited by: §4.2.
  • Y. Li, D. Tarlow, M. Brockschmidt, and R. S. Zemel (2016) Gated graph sequence neural networks. In ICLR, Cited by: §2.4.
  • J. Lim, S. Ryu, K. Park, Y. J. Choe, J. Ham, and W. Y. Kim (2019) Predicting drug-target interaction using a novel graph neural network with 3d structure-embedded graph representation. J. Chem. Inf. Model. 59 (9), pp. 3981–3988. External Links: Document Cited by: §1.
  • R. Loynd, R. Fernandez, A. Çelikyilmaz, A. Swaminathan, and M. J. Hausknecht (2020) Working memory graphs. In ICML, Proceedings of Machine Learning Research, Vol. 119, pp. 6404–6414. Cited by: §4.2.
  • V. Mnih, K. Kavukcuoglu, D. Silver, A. A. Rusu, J. Veness, M. G. Bellemare, A. Graves, M. Riedmiller, A. K. Fidjeland, G. Ostrovski, et al. (2015) Human-level control through deep reinforcement learning. Nat. 518 (7540), pp. 529–533. External Links: Document Cited by: §4.2.
  • D. Pathak, C. Lu, T. Darrell, P. Isola, and A. A. Efros (2019) Learning to control self-assembling morphologies: A study of generalization via modularity. In NeurIPS, pp. 2292–2302. Cited by: §4.1.
  • A. Sanchez-Gonzalez, N. Heess, J. T. Springenberg, J. Merel, M. A. Riedmiller, R. Hadsell, and P. W. Battaglia (2018) Graph networks as learnable physics engines for inference and control. In ICML, Proceedings of Machine Learning Research, Vol. 80, pp. 4467–4476. Cited by: §4.1.
  • P. Sarlin, D. DeTone, T. Malisiewicz, and A. Rabinovich (2020) SuperGlue: learning feature matching with graph neural networks. In IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4937–4946. External Links: Document Cited by: §1.
  • A. M. Saxe, J. L. McClelland, and S. Ganguli (2014) Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. In ICLR, Cited by: §3.4.
  • F. Scarselli, M. Gori, A. C. Tsoi, M. Hagenbuchner, and G. Monfardini (2009) The graph neural network model. IEEE Trans. Neural Networks 20 (1), pp. 61–80. External Links: Document Cited by: §4.1.
  • J. Schulman, S. Levine, P. Abbeel, M. I. Jordan, and P. Moritz (2015) Trust region policy optimization. In ICML, JMLR Workshop and Conference Proceedings, Vol. 37, pp. 1889–1897. Cited by: §1.
  • J. Schulman, P. Moritz, S. Levine, M. I. Jordan, and P. Abbeel (2016) High-dimensional continuous control using generalized advantage estimation. In ICLR, Cited by: §2.1, §2.2.
  • J. Schulman, F. Wolski, P. Dhariwal, A. Radford, and O. Klimov (2017) Proximal policy optimization algorithms. CoRR abs/1707.06347. External Links: 1707.06347 Cited by: §A.1, §1.
  • Y. Shen, H. Li, S. Yi, D. Chen, and X. Wang (2018) Person re-identification with deep similarity-guided graph neural network. In ECCV, Lecture Notes in Computer Science, Vol. 11219, pp. 508–526. External Links: Document Cited by: §1.
  • N. Srivastava, G. E. Hinton, A. Krizhevsky, I. Sutskever, and R. Salakhutdinov (2014) Dropout: a simple way to prevent neural networks from overfitting. J. Mach. Learn. Res. 15 (1), pp. 1929–1958. Cited by: §4.3.
  • J. M. Stokes, K. Yang, K. Swanson, W. Jin, A. Cubillos-Ruiz, N. M. Donghia, C. R. MacNair, S. French, L. A. Carfrae, Z. Bloom-Ackerman, et al. (2020) A deep learning approach to antibiotic discovery. Cell 180 (4), pp. 688–702. Cited by: §1.
  • E. Todorov, T. Erez, and Y. Tassa (2012) MuJoCo: A physics engine for model-based control. In IEEE/RSJ International Conference on Intelligent Robots and Systems, pp. 5026–5033. External Links: Document Cited by: §A.2, §3.1.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin (2017) Attention is all you need. In NIPS, Vol. 30, pp. 5998–6008. Cited by: §4.1.
  • T. Wang, R. Liao, J. Ba, and S. Fidler (2018) NerveNet: learning structured policy with graph neural networks. In ICLR, Cited by: Figure 12, §A.1, §A.1, Snowflake: Scaling GNNs to High-Dimensional Continuous Control via Parameter Freezing, §1, §1, Figure 1, §3.1, §3.1, §3.2, §3, §4.1, §5.1, §5.2.
  • X. Wang, D. Lyu, M. Li, Y. Xia, Q. Yang, X. Wang, X. Wang, P. Cui, Y. Yang, B. Sun, and Z. Guo (2020) APAN: asynchronous propagate attention network for real-time temporal graph embedding. CoRR abs/2011.11545. External Links: 2011.11545 Cited by: §1.
  • J. Yosinski, J. Clune, Y. Bengio, and H. Lipson (2014) How transferable are features in deep neural networks?. In NIPS, pp. 3320–3328. Cited by: §4.3.
  • V. F. Zambaldi, D. Raposo, A. Santoro, V. Bapst, Y. Li, I. Babuschkin, K. Tuyls, D. P. Reichert, T. P. Lillicrap, E. Lockhart, M. Shanahan, V. Langston, R. Pascanu, M. Botvinick, O. Vinyals, and P. W. Battaglia (2019) Deep reinforcement learning with relational inductive biases. In ICLR, Cited by: §4.2.

Appendix A Appendix

A.1 Experimental Details

Here we outline further details of our experimental approach to supplement those given in Section 5.1.

Data Generation

We train a policy by interleaving two processes. First, we perform repeated rollouts of the current policy in the environment to generate on-policy training data. Second, we optimise the policy with respect to the training data collected to generate a new policy, and then repeat.

To improve wall-clock training time, for larger agents we perform rollouts in parallel over multiple CPU threads, scaling from a single thread for Centipede-6 to five threads for Centipede-20. Rollouts terminate once the sum of timesteps experienced across all threads reaches the training batch size.

For optimisation, we shuffle the training data randomly and split the batch into eight minibatches. We perform ten optimisation epochs over these minibatches, in the manner defined by the PPO algorithm (Schulman et al., 2017) (see Section 2.2).
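The optimisation schedule described above can be sketched as follows. This is an illustration under stated assumptions rather than the authors' implementation: the helper name `minibatch_epochs` is hypothetical, the data is shuffled once per batch (the text does not say whether it is reshuffled every epoch), and any remainder from an uneven split is dropped.

```python
import random


def minibatch_epochs(batch, num_minibatches=8, num_epochs=10, seed=0):
    """Yield (epoch, minibatch) pairs for one batch of on-policy data."""
    rng = random.Random(seed)
    indices = list(range(len(batch)))
    rng.shuffle(indices)  # assumption: a single shuffle per collected batch
    size = len(indices) // num_minibatches  # remainder samples are dropped
    splits = [indices[i * size:(i + 1) * size] for i in range(num_minibatches)]
    for epoch in range(num_epochs):
        for split in splits:
            yield epoch, [batch[i] for i in split]


# With a batch of 2048 timesteps this gives 8 * 10 = 80 gradient updates,
# each computed on 256 samples.
updates = list(minibatch_epochs(list(range(2048))))
```

Each yielded minibatch would be passed to a single gradient step of the PPO objective before new rollouts are collected.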

Each experiment is performed six times and results are averaged across runs. The exception to this is Figure 5, where results are an average of three runs.

Hyperparameter Search

Our starting point for selecting hyperparameters is the hyperparameter search performed by Wang et al. (2018), whose codebase ours is derived from.

To ensure that we have the best set of hyperparameters for training on large agents, we ran our own hyperparameter search on Centipede-20 for Snowflake, as seen in Table 1.

Hyperparameter | Values
Batch size | 512, 1024, 2048, 4096
Learning rate | 1e-4, 3e-4, 1e-5
Learning rate scheduler | adaptive, constant
Clipping value | 0.02, 0.05, 0.1, 0.2
GNN layers | 2, 4, 10
GRU hidden state size | 64, 128
Learned action std | shared, separate
Table 1: Hyperparameter search for Snowflake on Centipede-20. Values in bold resulted in the best performance.
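To give a sense of the size of this search space, the sketch below writes the values from Table 1 as a dictionary and enumerates every combination. The key names are hypothetical, and the exhaustive grid is an assumption made purely for illustration: the text does not state whether the search covered all combinations or varied one hyperparameter at a time.

```python
from itertools import product

# Values copied from Table 1; the key names are illustrative only.
search_space = {
    "batch_size": [512, 1024, 2048, 4096],
    "learning_rate": [1e-4, 3e-4, 1e-5],
    "lr_scheduler": ["adaptive", "constant"],
    "clip_value": [0.02, 0.05, 0.1, 0.2],
    "gnn_layers": [2, 4, 10],
    "gru_hidden_size": [64, 128],
    "action_std": ["shared", "separate"],
}


def configurations(space):
    """Yield one dict per point of the full Cartesian grid."""
    keys = list(space)
    for values in product(*(space[key] for key in keys)):
        yield dict(zip(keys, values))


num_configs = sum(1 for _ in configurations(search_space))
```

A full grid over these values contains 1152 configurations, which illustrates why restricting the secondary search to a few hyperparameters is attractive when each run takes a day or more.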

Across the range of agents tested on, we conducted a secondary search over just the batch size, learning rate and clipping value for each model. For the latter two hyperparameters, we found that the values in Table 1 did not require adjusting.

For the batch size, we used the lowest value that did not cause training to deteriorate. Using NerveNet, a batch size of 2048 was required throughout, whereas using Snowflake a batch size of 1024 was best for Walker, 512 was best for Centipede-8 and Centipede-6, and 2048 for all other agents.

Wang et al. (2018) provide experimental results for the NerveNet model, which we use as a baseline for our experiments. Out of the Centipede-n models, they provide direct training results for Centipede-8 (see the non-pre-trained agents in their Figure 5). Our performance results are comparable, but taken over many more timesteps. Their final MLP results appear slightly different to ours at the same point (around 500 more reward), likely due to hyperparameter tuning for performance over a different time-frame.

They also provide performance metrics for trained Centipede-4 and Centipede-6 agents across the models compared (their Table 1). The results reported there are significantly lower than the best performance we attain for both MLP and NerveNet on Centipede-6. We suspect this discrepancy is due to running for fewer timesteps in their case, but precise stopping criteria are not provided.

Computing Infrastructure

Our experiments were run on four different machines during the project, depending on availability. These machines use variants of the Intel Xeon E5 processor (models 2630, 2699 and 2680), containing between 44 and 88 CPU cores. As running the agent in the MuJoCo environment is CPU-intensive, we observed little decrease in training time when using a GPU; hence the experiments reported here are only run on CPUs.

Runtimes for our results vary significantly depending on the number of threads allocated and batch size used. Our standard runtime for Centipede-6 (single thread) for ten million timesteps is around 24 hours, scaling up to 48 hours for our standard Centipede-20 configuration (five threads). Our experiments on the default MuJoCo agents also take approximately 24 hours for a single thread.

A.2 Sources

Our anonymised source code can be found in the provided zip archive, alongside documentation for building the software and its dependencies. We plan to open-source our code for publication.

Our code is an extension of the NerveNet codebase: This repository contains the original code/schema defining the Centipede-n agents.

The other standard agents are taken from the Gym (Brockman et al., 2016): The specific hopper, walker and humanoid versions used are Hopper-v2, Walker2d-v2 and Humanoid-v2.

For our MLP results on the Gym agents, as state-of-the-art performance baselines have been well established in this case, we use the OpenAI Baselines codebase to generate results, to ensure the most rigorous and fair comparison possible.

The MuJoCo (Todorov et al., 2012) simulator can be found at: Note that a paid license is required to use MuJoCo. The use of free alternatives was not viable in our case as our key benchmarks are all defined for MuJoCo.

A.3 Supplementary Figures

Figure 9: The effect of increasing the batch size on the influence of the clipping hyperparameter (see Figure 3) after ten million timesteps. Increasing the batch size reduces the sensitivity to high values of the clipping hyperparameter (i.e. low clipping), as the underlying policy divergence is lower with the larger batch size. However, doubling the batch size leads to a drop in sample efficiency, reducing the maximum reward attained within this time-frame.
Figure 10: Accompanying KL divergence plots for Figure 7. Panels: (a) NerveNet, (b) Snowflake. As Snowflake reduces the policy divergence between updates, smaller batch sizes can be used before the KL divergence becomes prohibitively large. This effect underlies the improved sample efficiency demonstrated.
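For readers who want to reproduce a divergence measurement of this kind, the following is a minimal sketch that assumes a diagonal Gaussian policy, which is a common choice for continuous control but is an assumption here. The function name `gaussian_kl` is hypothetical, and the direction of the divergence (old policy relative to new) is also an assumption; the exact quantity plotted in the figures may differ.

```python
import math


def gaussian_kl(mean_old, std_old, mean_new, std_new):
    """KL(old || new) between two diagonal Gaussians, summed over action dims."""
    kl = 0.0
    for m0, s0, m1, s1 in zip(mean_old, std_old, mean_new, std_new):
        kl += math.log(s1 / s0) + (s0 ** 2 + (m0 - m1) ** 2) / (2 * s1 ** 2) - 0.5
    return kl


# Two action dimensions, unit standard deviation, mean shifted by 1 in each:
# every dimension contributes 0.5, so the total divergence is 1.0.
kl_shifted = gaussian_kl([0.0, 0.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0])
```

Averaging this value over the states in a batch, before and after an update, gives one estimate of how far the policy has moved.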

Figure 11: Ablation demonstrating the effect of only training single parts of the network (freezing the rest). The configuration of Snowflake we use for our experiments is equivalent to only training the update function, which is the most effective approach here, and all approaches are superior to training the entire GNN. For this experiment, we train on Centipede-6 using the small batch size of 256 in all cases. This setting was chosen as it demonstrates the difference in performance for these approaches most clearly.
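The idea behind this ablation, training one part of the network while the rest keeps its initial values, can be sketched as below. This is an illustration only: the group names (`encoder`, `message`, `update`, `decoder`) and the helper `sgd_step` are hypothetical, and plain gradient descent stands in for whatever optimiser the experiments actually use.

```python
def sgd_step(params, grads, trainable, lr):
    """One gradient step that skips every parameter group not in `trainable`."""
    updated = {}
    for name, values in params.items():
        if name in trainable:
            updated[name] = [v - lr * g for v, g in zip(values, grads[name])]
        else:
            updated[name] = list(values)  # frozen: keeps its initial value
    return updated


# Toy parameter groups standing in for the parts of the GNN.
params = {
    "encoder": [0.5, -0.3],
    "message": [0.1, 0.2],
    "update": [0.4, -0.1],
    "decoder": [0.7, 0.0],
}
grads = {name: [1.0, 1.0] for name in params}

# Train only the update function, as in the configuration described above.
new_params = sgd_step(params, grads, trainable={"update"}, lr=0.1)
```

Changing the `trainable` set to a different single group reproduces the other arms of the ablation, and passing all four names corresponds to training the entire network.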

Figure 12: Zero-shot transfer performance for Snowflake, NerveNet, and MLP models trained on Centipede-20, and evaluated across a range of sizes. The MLPAA approach is that presented in Wang et al. (2018) (the most effective method tried for transfer with MLPs). Snowflake is able to attain almost as strong performance on Centipede-18 as Centipede-20, despite never having been trained on the former task. In contrast, there is a large drop in performance using MLPAA. This demonstrates the significant generalisation advantage gained by using the GNN-based policy, and confirms that the strong transfer performance shown in previous work for GNNs on smaller agents persists when using Snowflake on larger agents.