What Can Neural Networks Reason About?

05/30/2019 ∙ by Keyulu Xu, et al. ∙ MIT

Neural networks have successfully been applied to solving reasoning tasks, ranging from learning simple concepts like "close to", to intricate questions whose reasoning procedures resemble algorithms. Empirically, not all network structures work equally well for reasoning. For example, Graph Neural Networks have achieved impressive empirical results, while less structured neural networks may fail to learn to reason. Theoretically, there is currently limited understanding of the interplay between reasoning tasks and network learning. In this paper, we develop a framework to characterize which tasks a neural network can learn well, by studying how well its structure aligns with the algorithmic structure of the relevant reasoning procedure. This suggests that Graph Neural Networks can learn dynamic programming, a powerful algorithmic strategy that solves a broad class of reasoning problems, such as relational question answering, sorting, intuitive physics, and shortest paths. Our perspective also implies strategies to design neural architectures for complex reasoning. On several abstract reasoning tasks, we see empirically that our theory aligns well with practice.


1 Introduction

Reasoning is about grasping the relations between objects in the world, and from that deriving sophisticated conclusions and predictions [47, 42, 31]. Recently, there has been a resurgence of interest in building neural networks that can learn to reason [43, 20, 42]. To that end, a broad class of reasoning tasks has been designed, including relational visual question answering [24, 22, 16, 2, 32], intuitive physics, i.e., predicting the time evolution of physical objects [7, 53, 54, 17, 12], more abstract mathematical reasoning [43, 11], and visual IQ tests [42, 59].

Curiously, neural networks that perform well in reasoning tasks usually possess specific, explicit structure in their computation graph [41, 8, 51], while less structured networks, e.g., fully connected architectures, often fail [42, 41]. Many empirically successful models for reasoning follow the Graph Neural Network (GNN) framework [8, 56, 45, 29]. These models consider pairwise relations, and recursively update each object’s representation by aggregating its relations with other objects [7, 34, 23, 37, 46, 40]. Other computational structures, e.g., symbolic program execution [25, 57], have been effective for other tasks.

However, there is currently limited understanding of how the reasoning ability and the structure of a neural network relate. What tasks can a neural network learn to reason about? When and why is one network structure more effective than others? Answering these questions can be crucial for building better neural networks for complex reasoning.

This paper is an initial work towards answering these fundamental questions. We develop a theoretical framework to characterize what tasks a neural network can reason about well. Our framework is motivated by a seemingly simple observation: many reasoning procedures resemble algorithms. Hence, we study how well a reasoning algorithm “aligns” with the computation graph of the network. Intuitively, if the structures align, the network can easily learn to simulate the reasoning procedure.

We formalize this intuition of alignment, and show initial support for our hypothesis that alignment facilitates learning. First, we provide an example of how alignment can affect a theoretical analysis: in a sequential setting, when the structures of the network and the reasoning algorithm align, the network can learn more sample-efficiently, since the sub-modules it needs to learn are simpler.

Next, we study what algorithms a few example neural architectures structurally align with. In particular, we highlight that GNNs structurally match the broadly applicable algorithmic paradigm of dynamic programming (DP) [10]. We illustrate that we can solve a broad range of reasoning problems with DP, and thus, GNNs. Our results offer an explanation for the effectiveness (and hence popularity) of GNNs for reasoning, and are reflected empirically.

Our algorithmic structural condition differs from structural assumptions common in learning theory [49, 6, 5, 36, 19, 15, 4, 1] and specifically aligns with reasoning. Our main contribution is to introduce and formalize this structural alignment perspective, along with implications. This new perspective also implies strategies for designing architectures for complex reasoning. Our main results are summarized as follows:

  1. We introduce the perspective of algorithmic alignment to analyze learning for reasoning.

  2. Our initial theoretical results suggest that structural alignment is desirable for generalization.

  3. We apply our perspective to analyze what we expect some popular networks to learn well, and show that Graph Neural Networks align with dynamic programming.

  4. On a test suite of reasoning tasks, our empirical results support our theoretical findings.

2 Preliminaries: Abstract Reasoning

We begin by summarizing our setting and common models for reasoning and, along the way, introduce our notation. Let S denote the universe, i.e., a configuration/set of objects to reason about. The object representation vectors X_s ∈ ℝ^d, for s ∈ S, could be state descriptions [57, 41] or high-level features learned from raw data [30, 41]. Information about the question can be included in the object representations. Given a set of universes {S_1, ..., S_M} and answer labels {y_1, ..., y_M}, we aim to learn a function g, potentially parameterized by a neural network, that can answer questions about unseen universes, y = g(S).

If we reason about a single-object universe, e.g., in classification, applying a multilayer perceptron (MLP) to the object representation usually works well [30]. But, for multiple-object universes, simply applying an MLP to the concatenated object representations often leads to poor generalization [41]. To better learn functions over a set of objects, Zaheer et al. [58] propose Deep Sets. Deep Sets' structure, h_S = MLP_2( ∑_{s∈S} MLP_1(X_s) ), encodes permutation-invariant functions.

Graph Neural Networks (GNNs). GNNs, too, are permutation invariant, but focus on pairwise relations. Their structure follows a message passing scheme [18, 55], where the representation h_s^(k) of each object s in layer k is recursively updated by aggregating its pairwise interactions with other objects:

h_s^(k) = ∑_{t∈S} MLP^(k)( h_s^(k−1), h_t^(k−1) ),    h_S = MLP^(K+1)( ∑_{s∈S} h_s^(K) ),    (1)

where h_S is the answer/output and K is the number of GNN layers. We initialize h_s^(0) = X_s. Instead of the sum, other aggregation functions are used, too.

Originally proposed for learning on graphs [45], GNNs have become a widely used model for reasoning too [8]. Relation Networks [41] and Interaction Networks [7] resemble one-layer GNNs, Recurrent Relational Networks [37] apply LSTMs [21] after aggregation, and Transformers [50] aggregate with an attention mechanism. While, in graphs, one aggregates over neighboring nodes, GNNs for reasoning typically use all pairs, corresponding to message passing on a complete graph.
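To make the structure in (1) concrete, below is a minimal NumPy sketch of one message-passing iteration on a complete graph, followed by the readout. It is an illustration of the computation pattern, not the authors' implementation; the random weight matrices stand in for learned MLP parameters.

```python
import numpy as np

rng = np.random.default_rng(0)

def mlp(x, W1, W2):
    # A two-layer MLP with ReLU, standing in for a learned module.
    return np.maximum(x @ W1, 0.0) @ W2

def gnn_layer(H, W1, W2):
    # One message-passing iteration on the complete graph, as in Eq. (1):
    # h_s <- sum_t MLP([h_s, h_t]).
    n = H.shape[0]
    pairs = np.concatenate(
        [np.repeat(H, n, axis=0), np.tile(H, (n, 1))], axis=1)  # all (s, t)
    messages = mlp(pairs, W1, W2).reshape(n, n, -1)
    return messages.sum(axis=1)  # aggregate over t, separately for each s

def gnn_readout(H, W1, W2):
    # h_S = MLP(sum_s h_s): pool the object states, then decode the answer.
    return mlp(H.sum(axis=0), W1, W2)

# Toy run: 5 objects with 4-dimensional representations, hidden width 16.
n, d, hid = 5, 4, 16
H = rng.normal(size=(n, d))                       # h^(0) = object features
W1_msg, W2_msg = rng.normal(size=(2 * d, hid)), rng.normal(size=(hid, d))
W1_out, W2_out = rng.normal(size=(d, hid)), rng.normal(size=(hid, 1))
for _ in range(3):                                # K = 3 GNN iterations
    H = gnn_layer(H, W1_msg, W2_msg)
answer = gnn_readout(H, W1_out, W2_out)
```

The key point is the shape of the computation: one learned module is applied to every ordered pair of objects, and a sum aggregates the messages for each object.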

3 Theoretical Framework: Algorithmic Structural Alignment

Figure 1: Algorithmic alignment. Our framework suggests that algorithmic alignment implies generalization. The algorithm (bottom left) and the neural network (bottom right) structurally align: the network can simulate the algorithm by filling in (learning) simple functions via the MLP modules (bottom purple blocks). An MLP, however, does not align well with the algorithm, because it needs to learn to simulate the entire for loop. The top row illustrates the structures: nodes are variables in an algorithm or representation vectors in a network; arrows refer to an algorithm step or an MLP taking the end nodes as input.

Curiously, many models for reasoning are structured neural networks. Next, we study how the network structure and task may interact. To do so, we observe that the answer to many reasoning tasks may be derived by following an algorithm; we further illustrate this in Section 4. For example, the answer to the shortest paths problem can be computed by the Bellman-Ford algorithm [9], shown in Fig. 1. Intuitively, if a network can learn the algorithm, it can learn to answer the task.
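To keep the running example self-contained, here is a minimal Bellman-Ford sketch (ours; variable names are illustrative, not the paper's code). Each of the outer relaxation rounds applies the same simple pairwise update to all objects, which is exactly the per-layer computation a GNN can learn.

```python
def bellman_ford(cost, source):
    # Shortest distances from `source` on a complete graph, where
    # cost[v][u] is the edge weight from v to u (float("inf") if absent).
    n = len(cost)
    d = [float("inf")] * n
    d[source] = 0.0
    for _ in range(n - 1):                # relaxation rounds k = 1 .. n-1
        d = [min(d[u], min(d[v] + cost[v][u] for v in range(n)))
             for u in range(n)]           # d[u] <- min_v d[v] + cost(v, u)
    return d

INF = float("inf")
cost = [[0, 1, 4, INF],
        [1, 0, 2, 6],
        [4, 2, 0, 3],
        [INF, 6, 3, 0]]
print(bellman_ford(cost, 0))              # [0.0, 1.0, 3.0, 6.0]
```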

In principle, many neural networks can represent algorithms. For example, DeepSets can universally represent permutation-invariant set functions [58, 52]. This also holds for GNNs and MLPs (our setting differs from [44, 56], who study functions on graphs):

Proposition 1. Let f be any continuous function over sets S of bounded cardinality |S| ≤ N. If f is permutation-invariant to the elements in S, and the elements are in a compact set in ℝ^d, then f can be approximated arbitrarily closely by a GNN (of any depth).

Proposition 2. For any GNN 𝒩, there is an MLP that can represent all functions 𝒩 can represent.

But, empirically, not all network structures work equally well when learning these algorithms. Intuitively, a network may learn well if it can represent a function "more easily". We formalize this idea via our notion of algorithmic alignment. Indeed, not only does the reasoning procedure have an algorithmic structure: the neural network's architecture also induces a computational structure on the function it computes. This corresponds to an algorithm that prescribes how the network combines computations from modules. Fig. 1 illustrates this idea for a GNN, where the modules are the MLPs applied to pairs of objects. The GNN "matches" the algorithmic structure of the Bellman-Ford algorithm: it can simulate the algorithm if each of its modules learns a simple function (sum or min).

We could simulate the Bellman-Ford algorithm with an MLP, too. But then, matching network modules to algorithmic steps would, e.g., require learning sparse MLPs that mimic the pairwise operations despite receiving all objects as input, or an MLP that simulates the entire for loop as a whole, which is a much more complex function and, hence, would presumably need many more samples to learn. In short, if it does not align well, the network may have to infer more of the algorithmic structure from data.

This perspective suggests that whether a neural network can reason about a task may depend on whether there exists an algorithmic solution that the network structurally aligns with.

3.1 Formalization

We formalize the above intuition about simple modules in a PAC learning framework [48]. PAC learnability formalizes simplicity as sample complexity, i.e., the number of samples needed to ensure low test error with high probability. It refers to a learning algorithm A that, given training samples {x_i, y_i}_{i=1}^M, outputs a function f = A({x_i, y_i}_{i=1}^M). The learning algorithm here is the training method for the neural network, which is different from the reasoning algorithm.

Definition (Learnability). Fix an error parameter ε > 0 and failure probability δ ∈ (0, 1). Suppose {x_i, y_i}_{i=1}^M are i.i.d. samples from some distribution D, and the data satisfy y_i = g(x_i) for some underlying function g. Let f = A({x_i, y_i}_{i=1}^M) be the function generated by a learning algorithm A. Then g is (M, ε, δ)-learnable with A if

P_{x∼D}[ ‖f(x) − g(x)‖ ≤ ε ] ≥ 1 − δ.    (2)

The sample complexity C_A(g, ε, δ) is the minimum M so that g is (M, ε, δ)-learnable with A.

We then say that a network architecture aligns well with an algorithm if it can simulate the algorithm via a limited number of (types of) modules, and each module is simple, i.e., efficiently learnable.

Definition (Algorithmic structural alignment). Let g be a reasoning function and 𝒩 a neural network with n modules 𝒩_i. The module functions f_1, ..., f_n generate g for 𝒩 if, by replacing each 𝒩_i with f_i, the network 𝒩 simulates g. Then 𝒩 (M, ε, δ)-aligns with g if (1) f_1, ..., f_n generate g and (2) there are learning algorithms A_i for the 𝒩_i such that n · max_i C_{A_i}(f_i, ε, δ) ≤ M. The structures align well if the sample complexity M is small. This implies all algorithm steps are easy to learn. The number of modules can be kept small via weight sharing.

Next, we show an initial result demonstrating that structural alignment is desirable for generalization. Theorem 3.1 states that, in a setting where we sequentially train the modules of a network with auxiliary labels, alignment implies generalization: if the network (M, ε, δ)-aligns with an algorithm, then the algorithm is (M, O(ε), O(δ))-learnable by the network. In Section 5 we will see that, empirically, the same pattern holds for end-to-end learning. We prove Theorem 3.1 in the Appendix.

Theorem 3.1 (Structural alignment implies learnability). Fix ε and δ. Suppose {S_i, y_i}_{i=1}^M ∼ D, where |S_i| < N, and y_i = g(S_i) for some g. Suppose 𝒩_1, ..., 𝒩_n are network 𝒩's MLP modules in a sequential order. Under the following assumptions, g is (M, O(ε), O(δ))-learnable by 𝒩.
a) Structural alignment. 𝒩 and g (M, ε, δ)-align via functions f_1, ..., f_n.
b) Algorithm stability. Let A be the learning algorithm for the 𝒩_i. Suppose f = A({x_i, y_i}_{i=1}^M) and f̂ = A({x̂_i, y_i}_{i=1}^M). For any x, ‖f(x) − f̂(x)‖ ≤ L_0 · max_i ‖x_i − x̂_i‖, for some L_0.
c) Sequential training. We train the 𝒩_j sequentially: 𝒩_1 has input samples {x̂_i^(1), f_1(x̂_i^(1))}_{i=1}^M, with x̂_i^(1) obtained from S_i. For j > 1, the inputs x̂_i^(j) for 𝒩_j are the outputs of the previously learned modules, but the labels are generated by the correct functions f_{j−1}, ..., f_1 applied to the original inputs x̂_i^(1).
d) Lipschitzness. The learned functions f̂_j satisfy ‖f̂_j(x) − f̂_j(x̂)‖ ≤ L_1 ‖x − x̂‖, for some L_1.

In our analysis, the Lipschitz constants, the universe size, and the number of MLP modules are treated as constants that are absorbed into the O(ε) and O(δ) terms. While a more fine-grained analysis is possible, we leave it for future work.

3.2 MLPs and Sample Efficiency Gap

The generalization bound via alignment in Theorem 3.1 depends on the sample complexity of the MLP modules. Hence, next, we study learnability with MLPs. Recent work shows sample complexity bounds for overparameterized two- or three-layer MLPs by analyzing their gradient descent trajectories [4, 1]. Theorem 3.2, proved in the Appendix, summarizes and extends a theorem of Arora et al. [4] to vector-valued functions.

Theorem 3.2 (Sample complexity for MLPs). Let A be an overparameterized and randomly initialized two-layer MLP trained with gradient descent for a sufficient number of iterations. Suppose g: ℝ^d → ℝ^m with components g(x)^(i) = ∑_j α_j^(i) ( β_j^(i)ᵀ x )^{p_j^(i)}, where α_j^(i) ∈ ℝ, β_j^(i) ∈ ℝ^d, and p_j^(i) = 1 or p_j^(i) = 2l (l ∈ ℕ₊). The sample complexity is

C_A(g, ε, δ) = O( ( max_i ∑_j p_j^(i) · |α_j^(i)| · ‖β_j^(i)‖_2^{p_j^(i)} + log(m/δ) ) / (ε/m)² ).    (3)

Theorem 3.2 says that if a function is "simple" when expressed as a polynomial, e.g., via a Taylor expansion, it is learnable by an MLP. By this notion, complex interactions that involve many objects may require many samples for MLPs to learn, since the number of polynomial terms or the magnitude of their coefficients may increase in (3). Although Theorem 3.2 only gives an upper bound on the sample complexity, it might still provide a plausible explanation for why MLPs fail to learn.

Corollary 3.2. Suppose the universe contains objects X_1, ..., X_n and we have g(S) = ∑_{i,j} (X_i − X_j)². In the sequential learning setting, the upper bound on sample complexity for the MLP is on the order of n² times higher than for the GNN. The example in Corollary 3.2 illustrates how neural networks with a matching structure, e.g., GNNs, may get a polynomial improvement in sample complexity over MLPs. In practice, this efficiency gap can be serious with many objects to reason about.
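A back-of-the-envelope computation illustrates the gap. The sketch below (ours; it assumes the pairwise form g(S) = ∑_{i,j} (X_i − X_j)² with scalar objects, as in Corollary 3.2) evaluates the quantity ∑_j p_j |α_j| ‖β_j‖₂^{p_j} that enters the bound (3) for the GNN's per-pair module versus a monolithic MLP:

```python
import math

def complexity_measure(terms):
    # sum_j p_j * |alpha_j| * ||beta_j||_2^(p_j), the quantity entering
    # the sample complexity bound in Eq. (3).
    return sum(p * abs(a) * math.sqrt(sum(b * b for b in beta)) ** p
               for a, beta, p in terms)

n = 25  # number of (scalar-valued) objects

# GNN module: the single pairwise term (x_i - x_j)^2, beta = (1, -1), p = 2.
gnn_module = [(1.0, [1.0, -1.0], 2)]

# Monolithic MLP: g(S) = sum_{i != j} (x_i - x_j)^2, one term per pair.
def pair_beta(i, j):
    beta = [0.0] * n
    beta[i], beta[j] = 1.0, -1.0
    return beta

mlp_terms = [(1.0, pair_beta(i, j), 2)
             for i in range(n) for j in range(n) if i != j]

print(complexity_measure(gnn_module))  # 4.0
print(complexity_measure(mlp_terms))   # 2400.0 = 4 * n * (n - 1)
```

The per-pair module's measure is a constant, while the monolithic MLP accumulates one term per ordered pair, so its bound grows by a factor of n(n − 1).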

Our framework is general and can work with other sample complexity bounds for MLPs too. Theorem 3.2 is an illustrative example. Next, we apply our algorithmic alignment perspective to foresee which tasks some popular networks can reason about.

4 Predicting What Neural Networks Can Reason About

To explore some implications of our perspective, we next study the alignment of popular architectures to reasoning tasks with increasing complexity. Simpler tasks may serve as modules for more complex tasks. If an architecture aligns well with a task, we expect it to learn the task well. We then introduce a neural architecture design strategy for complex reasoning. The experiments in Section 5 empirically validate the findings in this section.

Single-object feature extraction. Given "disentangled" object representations X = (x_1, ..., x_K), where each x_k is a feature like color or shape, by Theorem 3.2, an MLP can provably learn to extract, and answer questions about, relevant features, e.g., "What is the color of the cat?". Disentangled representations have also empirically shown good generalization [57, 41].

Corollary 4. Feature extraction functions g(X) = x_k are efficiently learnable by an MLP (proof in Appendix F).

Summary statistics. Deep Sets, i.e., h_S = MLP_2( ∑_{s∈S} MLP_1(X_s) ), align with summary statistics, because MLP_1 can extract object features. In particular, they can count, e.g., "How many white cats with blue eyes are there?", and learn max or min statistics by using smooth approximations of the maximum such as log-sum-exp, max_s x_s ≈ log ∑_{s∈S} exp(x_s). For maxima, max pooling likely works better [38], because it aligns with the max even better than the sum does.
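A minimal Deep Sets sketch (ours, in NumPy) makes this concrete: per-object features are extracted, summed, and decoded. Here phi = exp and rho = log stand in for learned MLPs and show how the sum aggregator can encode a smooth maximum.

```python
import numpy as np

def deep_sets(X, phi, rho):
    # h_S = rho( sum_s phi(X_s) ): permutation-invariant by construction.
    return rho(sum(phi(x) for x in X))

# Smooth max via log-sum-exp: the sum aggregator encodes an approximate
# maximum when the per-object module exponentiates its input.
X = np.array([1.0, 3.0, 2.0])
print(deep_sets(X, phi=np.exp, rho=np.log))  # ~3.41, close to max(X) = 3
```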

Pairwise relations. Deep Sets, however, do not align well with pairwise object relations, since those are not easily encoded by summing over individual objects, i.e., via ∑_{s∈S} MLP_1(X_s). Suppose g(X_1, X_2) = 1 if and only if X_1 and X_2 satisfy a given relation, e.g., ‖X_1 − X_2‖ ≤ c. Then there is no f such that g(X_1, X_2) = f(X_1) + f(X_2) (Claim 4, proved in Appendix G). Hence, the MLP modules in Deep Sets must learn the pairwise relations in a more convoluted way. Corollary 3.2 suggests this may indeed be more difficult for MLPs, as no network structure can be exploited. In contrast, GNNs model pairwise relations explicitly with MLP(h_s, h_t). By Corollary 4, this MLP can learn to extract the relevant features in (X_s, X_t). Hence, each GNN iteration can reason about simple functions, e.g., sum/max/min, over pairwise relations, e.g., "Which two cats are the farthest apart?". In Section 5, we will see that, indeed, Deep Sets fail to learn to answer such questions, but GNNs learn well.

Higher-order relations. For relations among more than two objects, one may use MLPs that take in multiple objects [8, 35], generalizing from graphs to (directed) hypergraphs. We refer to such networks as Hypergraph Relation Networks (HRN). A one-layer HRN with triplet input is defined as

h_S = MLP_2( ∑_{(s,t,u)∈S³} MLP_1( X_s, X_t, X_u ) ).    (4)

HRNs, however, are computationally expensive (e.g., O(|S|³) MLP evaluations for triplets), as the sketch below makes visible. Often, this expense is not needed: many complex relations can be recursively reduced to pairwise relations, as we see next.
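A minimal sketch of (4) (ours; `mlp1` and `mlp2` are placeholders for learned modules) shows where the cubic cost comes from: the inner sum ranges over all ordered triplets.

```python
import numpy as np
from itertools import product

def hrn_layer(X, mlp1, mlp2):
    # One-layer HRN (Eq. (4)): sum a learned function over all ordered
    # triplets of objects, then decode the pooled vector. |S|^3 calls.
    total = sum(mlp1(np.concatenate([X[s], X[t], X[u]]))
                for s, t, u in product(range(len(X)), repeat=3))
    return mlp2(total)

# Toy modules standing in for learned MLPs: 4 objects, 4^3 = 64 calls.
X = np.random.default_rng(0).normal(size=(4, 3))
out = hrn_layer(X, mlp1=lambda z: np.tanh(z), mlp2=lambda z: z.sum())
```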

4.1 GNNs Can Perform Complex Reasoning through Dynamic Programming

GNNs structurally align with the powerful algorithmic paradigm of dynamic programming (DP) [10, 13]. DP solves a complicated problem by recursively breaking it down into simpler sub-problems, i.e.,

Answer[k][i] = DP-Update( {Answer[k−1][j]}, j = 1, ..., n ),    (5)

where Answer[k][i] denotes the answer to the sub-problem indexed by iteration k and state i, and DP-Update is an algorithm-dependent update rule which obtains Answer[k][i] by reducing it to the Answer[k−1][j]. Answer[k][i] can also depend on answers from steps before k−1 by remembering answers from previous steps. Often, the updates DP-Update are fairly simple, e.g., min/max/sum.

The structure of GNNs naturally matches that of DP. The GNN representation vectors h_i^(k) correspond to Answer[k][i], and the message passing update corresponds to the DP update. Theorem 3.1 suggests that a GNN can learn to simulate the underlying DP algorithm if it can learn the DP update rule in each recursive step, and it has at least the same number of steps (depth) as the DP algorithm. Next, we show examples of reasoning tasks that can be solved by DP, and thus, by GNNs.

Relational question answering. A complex relational question can be answered through DP if we can recursively break it down into simpler relational questions. An example is "Which cat is the closest to the cat which is the closest to …?" [37]. A DP solution would reduce the question of finding the k-hop closest cat to finding the cat that is closest to the (k−1)-hop closest cat as follows:

closest[k][i] = argmin_{j} ‖X_j − X_{closest[k−1][i]}‖,    (6)

where closest[k][i] is cat i's k-hop closest cat. If a GNN maintains closest[k][i] in its representation vectors, then the DP update rules, i.e., finding the argmin of pairwise distances and conditionally copying another object's features, are learnable by one iteration of a GNN.

Shortest paths. Many known algorithms for the (single-source) shortest paths problem are DP in nature [14, 9]. An example is the Bellman-Ford algorithm [9], which recursively updates the minimum distance d[k][u] between each object u and the source within k steps:

d[k][u] = min_v d[k−1][v] + cost(v, u).    (7)

If a GNN maintains d[k][u] in its representation vectors (where we also keep the object features if the cost function needs to be learned, too), then the update d[k−1][v] + cost(v, u) is simple for an MLP module to learn. Moreover, since min pooling can be approximated by sum pooling as we have discussed, GNNs can provably learn the DP update, and thus, can reason about shortest paths.

Divide and conquer, greedy, sorting, and intuitive physics. We can formulate many more reasoning problems as (roughly) DP and solve them with GNNs. Divide and conquer and greedy algorithms [13] fall under the general framework of DP. Divide and conquer also breaks a problem into sub-problems, but it only combines answers to non-overlapping sub-problems. This simpler algorithmic paradigm already solves many challenging reasoning problems, such as the classic Tower of Hanoi puzzle. Greedy algorithms maintain a single current optimum, whereas DP/GNNs maintain answers to many sub-problems. GNNs can learn to sort [8], because we can count how many objects are "smaller than" a given object through pairwise comparisons (see the sketch below). Moreover, a GNN can predict the trajectories of physical objects if it learns a simple law of physics as the DP update rule [40].
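The sorting-by-counting observation fits in a few lines (ours; it assumes distinct values): the rank of each object is a sum of pairwise comparisons, i.e., one round of GNN-style aggregation.

```python
def ranks(xs):
    # rank[i] = #{j : xs[j] < xs[i]}, a sum of pairwise comparisons and
    # hence one round of GNN-style aggregation (assumes distinct values).
    return [sum(y < x for y in xs) for x in xs]

xs = [30, 10, 20]
r = ranks(xs)                              # [2, 0, 1]
print([x for _, x in sorted(zip(r, xs))])  # [10, 20, 30]
```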

Discrete mathematics and theory. In fact, we could even reason about many profound questions in discrete mathematics and theoretical computer science (TCS) via DP, and thus, GNNs. Graph minor theory, built on DP, gives some of the deepest results in discrete mathematics [27, 39]. In TCS, DP is one of the best tools for obtaining polynomial-time approximation algorithms for problems known to be NP-hard. For example, DP for the traveling salesman problem (TSP) in Euclidean space gives one of the best known approximation algorithms [3, 33]. Thus, our theory also provides an explanation for the effectiveness (and hence popularity) of applying GNNs to NP-hard combinatorial problems [46, 28].

4.2 Neural Algorithms: a Neural Architecture Design Strategy

To reason about tasks that involve more complicated operations, e.g., computing the maximum flow or solving a linear program [13], our framework suggests we should instantiate a similar structure in the network. We name neural networks that follow this architecture design strategy Neural Algorithms. As one example, we apply the neural algorithm strategy to an NP-hard reasoning problem.

Neural Exhaustive Search. Given a set of numbers, the subset sum problem asks whether there exists a subset that sums to 0. Subset sum is NP-hard [26]. Although there is a pseudo-polynomial time DP solution [13], it requires maintaining all possible partial sums so far at each step, so a GNN would need many more representation vectors to store these answers. This may be difficult if a GNN always retains the same number of representation vectors. Instead, we can consider a simple exhaustive search strategy, where we compute the sum of every possible subset and check whether any of them is 0. This leads to a neural algorithm we name Neural Exhaustive Search (NES): each subset is processed sequentially by an LSTM, followed by max pooling over all subsets and an MLP:

y = MLP_2( max_{S′⊆S} MLP_1( LSTM( X_1, ..., X_{|S′|} : X_i ∈ S′ ) ) ).    (8)

NES structurally aligns with the exhaustive search algorithm if the LSTM learns to sum its inputs, MLP_1 learns to check whether its input is zero, and the max pooling layer summarizes this test over all subsets. In Section 5, we will see that NES indeed solves subset sum, despite the exponential time complexity.
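The discrete skeleton that NES aligns with is plain exhaustive search (sketch below, ours): the subset enumeration, the per-subset reduction, and the final combination correspond to the LSTM, MLP_1, and max pooling in (8).

```python
from itertools import combinations

def subset_sum_exists(S, target=0):
    # The skeleton of Eq. (8): enumerate every non-empty subset, reduce
    # it to its sum (the LSTM's job in NES), test against the target
    # (MLP_1's job), and combine over all subsets (max pooling's job).
    return any(sum(subset) == target
               for r in range(1, len(S) + 1)
               for subset in combinations(S, r))

print(subset_sum_exists([2, -3, 1, 7]))   # True: 2 + (-3) + 1 == 0
print(subset_sum_exists([2, 5, 9]))       # False
```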

(a) Maximum value difference.
(b) Furthest pair.
(c) Dragon trainer.
(d) Subset sum; random guessing achieves about 50%.
Figure 2: Test accuracies on four abstract reasoning tasks with increasing complexity. GNNk denotes a GNN with k iterations. (a) Computing a summary statistic of the universe. All models except MLP generalize. Sorted MLP refers to an MLP whose input treasures are sorted by value. (b) Simple relational reasoning. Deep Sets and MLP fail. (c) Dynamic programming. Only GNNs with at least three iterations generalize. (d) An NP-hard problem. Even GNNs fail; only NES generalizes.

5 Experiments

We design four abstract reasoning tasks with increasing complexity. To separate reasoning from representation learning, we use disentangled object representations. Data and training details are in Appendix H. We test the following hypotheses derived from our framework:

  1. MLPs do not align well with reasoning tasks (Corollary 3.2) and will not generalize well.

  2. Deep Sets can reason about summary statistics, but not about relations.

  3. GNNs generalize well on many reasoning tasks and can learn dynamic programming (§4.1), but may fail on NP-hard problems.

  4. Hypergraph Relation Networks (HRN) (4) have generalization power similar to GNNs on many tasks, despite their extra time complexity.

  5. Neural Exhaustive Search (NES) (8) can solve NP-hard problems.

5.1 Fantastic Treasure: Relational Question Answering

The universe S is a set of fantastic treasures. We have the magic power to check the location, value, and color of each treasure s ∈ S. Treasures change over time, and we train networks on snapshots of the universe to answer two questions.

Maximum value difference. The first question asks for the difference in value between the most and the least valuable treasure. Formally, the answer is y(S) = max_{s∈S} v_s − min_{s∈S} v_s, where v_s is the value of treasure s.

This question asks for a summary statistic of the universe. We expect MLP to fail (Corollary 3.2) and the other models to succeed. Fig. 2(a) confirms our prediction. Interestingly, MLP achieves perfect test accuracy when the input treasures are sorted by value (Sorted MLP in Fig. 2(a)). This observation is in line with our theory: when the treasures are sorted, the answer reduces to a simple subtraction of the first and last input values, which MLP can learn (Theorem 3.2).

Our theory also explains why Deep Sets have lower test accuracy than GNNs. We can rewrite the answer as max_{s,t∈S} (v_s − v_t). GNNs align with this formulation and only need to learn two operations: max and subtraction. Deep Sets do not loop over pairs of objects and, therefore, must find the answer through more operations: max, min, and then subtraction. Finally, HRN is slightly behind GNNs, showing that ternary relations do not give an additional gain for this task.

Furthest pair. Our second question asks for the colors of the two treasures with the largest distance. The answer is a pair of colors, encoded as an integer category:

y(S) = ( h_3(X_{s_1}), h_3(X_{s_2}) )    s.t.    {X_{s_1}, X_{s_2}} = argmax_{s_1, s_2 ∈ S} ‖h_1(X_{s_1}) − h_1(X_{s_2})‖_{ℓ_1},

where h_1 extracts a treasure's location and h_3 its color.

Unlike values, locations are not totally ordered, so the answer is not just a summary statistic of the universe; it requires reasoning over pairwise relations. MLP and Deep Sets fail to generalize on this relational reasoning task (Fig. 2(b)), confirming Claim 4. HRN and GNNs work well, with similar accuracies, suggesting again that ternary relations do not give an additional gain for this task.

5.2 Dragon Trainer: Dynamic Programming

A dragon trainer lives in a world with dragons. Each dragon i has a location z_i and a unique combat level ℓ_i. In each game, the trainer starts at a random location with level zero, and receives a quest to defeat the level-ℓ* dragon. At each time step, the trainer can challenge any more powerful dragon i, with a cost equal to the product of the travel distance and the level difference. After defeating dragon i, the trainer's level upgrades to ℓ_i, and the trainer moves to z_i. We ask the minimum cost of completing the quest, i.e., defeating the level-ℓ* dragon.

We can solve the game with a DP algorithm similar to shortest paths (7), where the source is the trainer's starting location and the target is the quest dragon. Our game is more challenging than vanilla shortest paths because the model also needs to learn the cost function from the locations and levels.

We train models on sampled games to predict the minimum cost of the quest. To make games challenging, we sample games whose optimal solution involves defeating three to seven non-quest dragons.

Figure 3: Accuracy breakdown on dragon trainer. Each dot indicates the accuracy of a model (y-axis) on games categorized by the number of defeated dragons (x-axis) in the optimal strategy.
Figure 4: Accuracy breakdown on subset sum. Each dot indicates the accuracy of a model (y-axis) on questions where there exists a subset of (x-axis) elements which sum to zero.

Results. As expected, MLP, Deep Sets, and one/two-iteration GNNs fail on this complex game (Fig. 2(c)).

Surprisingly, a GNN with four iterations has almost the same test accuracy as a GNN with seven iterations. In our dataset, the optimal strategies for some games require defeating seven dragons, and the Bellman-Ford algorithm  (7) needs at least seven iterations to solve these games. Therefore, the four-iteration GNN must have discovered a solution to shortest-paths that requires fewer iterations.

One possible DP solution is the following. To compute a shortest path from a source object s to a target object t with at most seven stops, we run the following updates for four iterations:

d_1[k][u] = min_v d_1[k−1][v] + cost(v, u),    (9)
d_2[k][u] = min_v d_2[k−1][v] + cost(u, v).    (10)

Update (9) is identical to the Bellman-Ford algorithm (7), and d_1[k][u] is the shortest distance from s to u with at most k stops. Update (10) is a reverse Bellman-Ford algorithm, and d_2[k][u] is the shortest distance from u to t with at most k stops. After running (9) and (10) for k iterations, we can compute a shortest path with at most 2k stops by enumerating a mid-point u and aggregating the results of the two Bellman-Ford algorithms:

min_u d_1[k][u] + d_2[k][u].    (11)

Thus, this alternative algorithm needs only half the iterations of Bellman-Ford. Its structure also aligns with GNNs: (9) and (10) are similar to GNN updates, and the final enumeration step (11) can be learned by the GNN's last pooling layer. This interesting empirical finding aligns with our theory: a neural network can reason about a task if there exists a structurally matching algorithmic solution.
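The halving can be checked directly. The sketch below (ours; names are illustrative) runs (9) and (10) for k rounds each and combines them via (11); the result matches plain Bellman-Ford run for 2k rounds.

```python
INF = float("inf")

def bf(cost, src, rounds):
    # After `rounds` relaxations, d[u] is the shortest src -> u distance
    # using at most `rounds` edges, i.e., Eq. (9); Eq. (10) is the same
    # recursion on the reversed graph.
    n = len(cost)
    d = [INF] * n
    d[src] = 0.0
    for _ in range(rounds):
        d = [min(d[u], min(d[v] + cost[v][u] for v in range(n)))
             for u in range(n)]
    return d

def bidirectional(cost, src, tgt, k):
    n = len(cost)
    rev = [[cost[v][u] for v in range(n)] for u in range(n)]
    d1 = bf(cost, src, k)                 # source -> mid-point, <= k edges
    d2 = bf(rev, tgt, k)                  # mid-point -> target, <= k edges
    return min(f + b for f, b in zip(d1, d2))   # Eq. (11)

cost = [[0, 2, INF], [2, 0, 3], [INF, 3, 0]]
assert bidirectional(cost, 0, 2, k=1) == bf(cost, 0, rounds=2)[2] == 5.0
```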

In Fig. 3, we compare the performance of GNNs on games of different complexity. The accuracies of single/two-iteration GNNs drop dramatically as the games become more challenging, i.e., as the optimal strategy involves defeating more non-quest dragons. Indeed, as algorithmic alignment suggests, more than two GNN iterations are necessary to align with the iterations of a correct reasoning algorithm and find an optimal strategy for the game.

5.3 Subset Sum

Finally, we consider a classic NP-hard problem: our universe S has six integers X_1, ..., X_6. We ask whether there is a subset of S that sums up to 0.

Results. The neural algorithm design strategy from §4.2 proves useful on this challenging task: only Neural Exhaustive Search (NES) achieves nearly perfect test accuracy (Fig. 2(d)), confirming again that neural networks generalize better when the network structure matches the algorithmic structure of a correct reasoning procedure. MLP and Deep Sets barely outperform random guessing. GNNs and HRN generalize better, suggesting that they can learn to inspect a small number of subsets.

Fig. 4 shows the fine-grained accuracies for subset sum questions whose answers are yes, i.e., there exists a solution subset whose elements sum to zero. Indeed, if there exist solution subsets of two elements, single-iteration GNN always identifies them and can perfectly answer these questions (Fig. 4). In contrast, HRN fails to identify solution subsets of size two, but succeeds if the solution contains three elements instead (Fig. 4). This empirical finding can be explained via our theory: Single-iteration GNN considers pairwise relations, and HRN considers ternary relations. Thus, by algorithmic alignment, they can easily learn to inspect the sum of every pair (subsets of size two) and triplet (subsets of size three), respectively.

6 Conclusion

This paper is an initial step towards formally understanding how neural networks can learn abstract reasoning. We introduce an algorithmic alignment perspective that may inspire architecture design and that opens up theoretical avenues. An interesting future direction is to design, e.g., via algorithmic alignment, networks that can learn more general abstract reasoning than GNNs.

Acknowledgments

This research was supported by NSF CAREER award 1553284, DARPA DSO’s Lagrange program under grant FA86501827838 and a Chevron-MIT Energy Fellowship. This research was also supported by JST ERATO JPMJER1201 and JSPS Kakenhi JP18H05291. MZ was supported by DARPA award HR0011-15-C-0113 under subcontract to Raytheon BBN Technologies. The views, opinions, and/or findings contained in this article are those of the author and should not be interpreted as representing the official views or policies, either expressed or implied, of the Defense Advanced Research Projects Agency or the Department of Defense.

References

  • Allen-Zhu et al. [2018] Zeyuan Allen-Zhu, Yuanzhi Li, and Yingyu Liang. Learning and generalization in overparameterized neural networks, going beyond two layers. arXiv preprint arXiv:1811.04918, 2018.
  • Antol et al. [2015] Stanislaw Antol, Aishwarya Agrawal, Jiasen Lu, Margaret Mitchell, Dhruv Batra, C Lawrence Zitnick, and Devi Parikh. Vqa: Visual question answering. In Proceedings of the IEEE international conference on computer vision, pages 2425–2433, 2015.
  • Arora [1998] Sanjeev Arora. Polynomial time approximation schemes for euclidean traveling salesman and other geometric problems. Journal of the ACM (JACM), 45(5):753–782, 1998.
  • Arora et al. [2019] Sanjeev Arora, Simon S Du, Wei Hu, Zhiyuan Li, and Ruosong Wang. Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. In International Conference on Machine Learning, 2019.
  • Bartlett and Mendelson [2002] Peter L Bartlett and Shahar Mendelson. Rademacher and gaussian complexities: Risk bounds and structural results. Journal of Machine Learning Research, 3(Nov):463–482, 2002.
  • Bartlett et al. [2017] Peter L Bartlett, Dylan J Foster, and Matus J Telgarsky. Spectrally-normalized margin bounds for neural networks. In Advances in Neural Information Processing Systems, pages 6240–6249, 2017.
  • Battaglia et al. [2016] Peter Battaglia, Razvan Pascanu, Matthew Lai, Danilo Jimenez Rezende, et al. Interaction networks for learning about objects, relations and physics. In Advances in Neural Information Processing Systems, pages 4502–4510, 2016.
  • Battaglia et al. [2018] Peter W Battaglia, Jessica B Hamrick, Victor Bapst, Alvaro Sanchez-Gonzalez, Vinicius Zambaldi, Mateusz Malinowski, Andrea Tacchetti, David Raposo, Adam Santoro, Ryan Faulkner, et al. Relational inductive biases, deep learning, and graph networks. arXiv preprint arXiv:1806.01261, 2018.
  • Bellman [1958] Richard Bellman. On a routing problem. Quarterly of applied mathematics, 16(1):87–90, 1958.
  • Bellman [1966] Richard Bellman. Dynamic programming. Science, 153(3731):34–37, 1966.
  • Chang et al. [2019] Michael Chang, Abhishek Gupta, Sergey Levine, and Thomas L. Griffiths. Automatically composing representation transformations as a means for generalization. In International Conference on Learning Representations, 2019.
  • Chang et al. [2017] Michael B Chang, Tomer Ullman, Antonio Torralba, and Joshua B Tenenbaum. A compositional object-based approach to learning physical dynamics. In International Conference on Learning Representations, 2017.
  • Cormen et al. [2009] Thomas H Cormen, Charles E Leiserson, Ronald L Rivest, and Clifford Stein. Introduction to algorithms. MIT press, 2009.
  • Dijkstra [1959] Edsger W Dijkstra. A note on two problems in connexion with graphs. Numerische mathematik, 1(1):269–271, 1959.
  • Dziugaite and Roy [2017] Gintare Karolina Dziugaite and Daniel M Roy. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.
  • Fleuret et al. [2011] François Fleuret, Ting Li, Charles Dubout, Emma K Wampler, Steven Yantis, and Donald Geman. Comparing machines and humans on a visual categorization test. Proceedings of the National Academy of Sciences, 108(43):17621–17625, 2011.
  • Fragkiadaki et al. [2016] Katerina Fragkiadaki, Pulkit Agrawal, Sergey Levine, and Jitendra Malik. Learning visual predictive models of physics for playing billiards. In International Conference on Learning Representations, 2016.
  • Gilmer et al. [2017] Justin Gilmer, Samuel S Schoenholz, Patrick F Riley, Oriol Vinyals, and George E Dahl. Neural message passing for quantum chemistry. In International Conference on Machine Learning, pages 1263–1272, 2017.
  • Golowich et al. [2018] Noah Golowich, Alexander Rakhlin, and Ohad Shamir. Size-independent sample complexity of neural networks. In Conference On Learning Theory, pages 297–299, 2018.
  • Hill et al. [2019] Felix Hill, Adam Santoro, David Barrett, Ari Morcos, and Timothy Lillicrap. Learning to make analogies by contrasting abstract relational structure. In International Conference on Learning Representations, 2019.
  • Hochreiter and Schmidhuber [1997] Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural computation, 9(8):1735–1780, 1997.
  • Hu et al. [2017] Ronghang Hu, Jacob Andreas, Marcus Rohrbach, Trevor Darrell, and Kate Saenko. Learning to reason: End-to-end module networks for visual question answering. In Proceedings of the IEEE International Conference on Computer Vision, pages 804–813, 2017.
  • Janner et al. [2019] Michael Janner, Sergey Levine, William T. Freeman, Joshua B. Tenenbaum, Chelsea Finn, and Jiajun Wu. Reasoning about physical interactions with object-centric models. In International Conference on Learning Representations, 2019.
  • Johnson et al. [2017a] Justin Johnson, Bharath Hariharan, Laurens van der Maaten, Li Fei-Fei, C Lawrence Zitnick, and Ross Girshick. Clevr: A diagnostic dataset for compositional language and elementary visual reasoning. In

    Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition

    , pages 2901–2910, 2017a.
  • Johnson et al. [2017b] Justin Johnson, Bharath Hariharan, Laurens van der Maaten, Judy Hoffman, Li Fei-Fei, C Lawrence Zitnick, and Ross Girshick. Inferring and executing programs for visual reasoning. In Proceedings of the IEEE International Conference on Computer Vision, pages 2989–2998, 2017b.
  • Karp [1972] Richard M Karp. Reducibility among combinatorial problems. In Complexity of computer computations, pages 85–103. Springer, 1972.
  • Kawarabayashi et al. [2012] Ken-ichi Kawarabayashi, Yusuke Kobayashi, and Bruce Reed. The disjoint paths problem in quadratic time. Journal of Combinatorial Theory, Series B, 102(2):424–435, 2012.
  • Khalil et al. [2017] Elias Khalil, Hanjun Dai, Yuyu Zhang, Bistra Dilkina, and Le Song. Learning combinatorial optimization algorithms over graphs. In Advances in Neural Information Processing Systems, pages 6348–6358, 2017.
  • Kipf and Welling [2017] Thomas N Kipf and Max Welling. Semi-supervised classification with graph convolutional networks. In International Conference on Learning Representations, 2017.
  • Krizhevsky et al. [2012] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. Imagenet classification with deep convolutional neural networks. In Advances in neural information processing systems, pages 1097–1105, 2012.
  • Lake et al. [2017] Brenden M Lake, Tomer D Ullman, Joshua B Tenenbaum, and Samuel J Gershman. Building machines that learn and think like people. Behavioral and brain sciences, 40, 2017.
  • Mao et al. [2019] Jiayuan Mao, Chuang Gan, Pushmeet Kohli, Joshua B. Tenenbaum, and Jiajun Wu. The neuro-symbolic concept learner: Interpreting scenes, words, and sentences from natural supervision. In International Conference on Learning Representations, 2019.
  • Mitchell [1999] Joseph SB Mitchell. Guillotine subdivisions approximate polygonal subdivisions: A simple polynomial-time approximation scheme for geometric tsp, k-mst, and related problems. SIAM Journal on computing, 28(4):1298–1309, 1999.
  • Mrowca et al. [2018] Damian Mrowca, Chengxu Zhuang, Elias Wang, Nick Haber, Li F Fei-Fei, Josh Tenenbaum, and Daniel L Yamins. Flexible neural representation for physics prediction. In Advances in Neural Information Processing Systems, pages 8799–8810, 2018.
  • Murphy et al. [2019] Ryan L Murphy, Balasubramaniam Srinivasan, Vinayak Rao, and Bruno Ribeiro. Janossy pooling: Learning deep permutation-invariant functions for variable-size inputs. In International Conference on Learning Representations, 2019.
  • Neyshabur et al. [2015] Behnam Neyshabur, Ryota Tomioka, and Nathan Srebro. Norm-based capacity control in neural networks. In Conference on Learning Theory, pages 1376–1401, 2015.
  • Palm et al. [2018] Rasmus Palm, Ulrich Paquet, and Ole Winther. Recurrent relational networks. In Advances in Neural Information Processing Systems, pages 3368–3378, 2018.
  • Qi et al. [2017] Charles R Qi, Hao Su, Kaichun Mo, and Leonidas J Guibas. Pointnet: Deep learning on point sets for 3d classification and segmentation. Proc. Computer Vision and Pattern Recognition (CVPR), IEEE, 1(2):4, 2017.
  • Robertson and Seymour [1995] Neil Robertson and Paul D Seymour. Graph minors. xiii. the disjoint paths problem. Journal of combinatorial theory, Series B, 63(1):65–110, 1995.
  • Sanchez-Gonzalez et al. [2018] Alvaro Sanchez-Gonzalez, Nicolas Heess, Jost Tobias Springenberg, Josh Merel, Martin Riedmiller, Raia Hadsell, and Peter Battaglia. Graph networks as learnable physics engines for inference and control. In International Conference on Machine Learning, pages 4467–4476, 2018.
  • Santoro et al. [2017] Adam Santoro, David Raposo, David G Barrett, Mateusz Malinowski, Razvan Pascanu, Peter Battaglia, and Timothy Lillicrap. A simple neural network module for relational reasoning. In Advances in neural information processing systems, pages 4967–4976, 2017.
  • Santoro et al. [2018] Adam Santoro, Felix Hill, David Barrett, Ari Morcos, and Timothy Lillicrap. Measuring abstract reasoning in neural networks. In International Conference on Machine Learning, pages 4477–4486, 2018.
  • Saxton et al. [2019] David Saxton, Edward Grefenstette, Felix Hill, and Pushmeet Kohli. Analysing mathematical reasoning abilities of neural models. In International Conference on Learning Representations, 2019.
  • Scarselli et al. [2009a] Franco Scarselli, Marco Gori, Ah Chung Tsoi, Markus Hagenbuchner, and Gabriele Monfardini. Computational capabilities of graph neural networks. IEEE Transactions on Neural Networks, 20(1):81–102, 2009a.
  • Scarselli et al. [2009b] Franco Scarselli, Marco Gori, Ah Chung Tsoi, Markus Hagenbuchner, and Gabriele Monfardini. The graph neural network model. IEEE Transactions on Neural Networks, 20(1):61–80, 2009b.
  • Selsam et al. [2019] Daniel Selsam, Matthew Lamm, Benedikt Bunz, Percy Liang, Leonardo de Moura, and David L Dill. Learning a SAT solver from single-bit supervision. In International Conference on Learning Representations, 2019.
  • Spelke and Kinzler [2007] Elizabeth S Spelke and Katherine D Kinzler. Core knowledge. Developmental science, 10(1):89–96, 2007.
  • Valiant [1984] Leslie G Valiant. A theory of the learnable. In Proceedings of the sixteenth annual ACM symposium on Theory of computing, pages 436–445. ACM, 1984.
  • Vapnik [2013] Vladimir Vapnik. The nature of statistical learning theory. Springer science & business media, 2013.
  • Vaswani et al. [2017] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in neural information processing systems, pages 5998–6008, 2017.
  • Vinyals et al. [2015] Oriol Vinyals, Samy Bengio, and Manjunath Kudlur. Order matters: Sequence to sequence for sets. In International Conference on Learning Representations, 2015.
  • Wagstaff et al. [2019] Edward Wagstaff, Fabian B Fuchs, Martin Engelcke, Ingmar Posner, and Michael Osborne. On the limitations of representing functions on sets. In International Conference on Machine Learning, 2019.
  • Watters et al. [2017] Nicholas Watters, Daniel Zoran, Theophane Weber, Peter Battaglia, Razvan Pascanu, and Andrea Tacchetti. Visual interaction networks: Learning a physics simulator from video. In Advances in neural information processing systems, pages 4539–4547, 2017.
  • Wu et al. [2017] Jiajun Wu, Erika Lu, Pushmeet Kohli, Bill Freeman, and Josh Tenenbaum. Learning to see physics via visual de-animation. In Advances in Neural Information Processing Systems, pages 153–164, 2017.
  • Xu et al. [2018] Keyulu Xu, Chengtao Li, Yonglong Tian, Tomohiro Sonobe, Ken-ichi Kawarabayashi, and Stefanie Jegelka. Representation learning on graphs with jumping knowledge networks. In International Conference on Machine Learning, pages 5453–5462, 2018.
  • Xu et al. [2019] Keyulu Xu, Weihua Hu, Jure Leskovec, and Stefanie Jegelka. How powerful are graph neural networks? In International Conference on Learning Representations, 2019.
  • Yi et al. [2018] Kexin Yi, Jiajun Wu, Chuang Gan, Antonio Torralba, Pushmeet Kohli, and Josh Tenenbaum. Neural-symbolic vqa: Disentangling reasoning from vision and language understanding. In Advances in Neural Information Processing Systems, pages 1031–1042, 2018.
  • Zaheer et al. [2017] Manzil Zaheer, Satwik Kottur, Siamak Ravanbakhsh, Barnabas Poczos, Ruslan R Salakhutdinov, and Alexander J Smola. Deep sets. In Advances in Neural Information Processing Systems, pages 3391–3401, 2017.
  • Zhang et al. [2019] Chi Zhang, Feng Gao, Baoxiong Jia, Yixin Zhu, and Song-Chun Zhu. Raven: A dataset for relational and analogical visual reasoning. arXiv preprint arXiv:1903.02741, 2019.

Appendix A Proof for Proposition 1

We will prove the universal approximation of GNNs by showing that GNNs have at least the same expressive power as Deep Sets, and then apply the universal approximation of Deep Sets for permutation invariant continuous functions.

Zaheer et al. [58] prove the universal approximation of Deep Sets under the restriction that the set size is fixed and the hidden dimension is equal to the set size plus one. Wagstaff et al. [52] extend the universal approximation result for Deep Sets by showing that the set size does not have to be fixed and the hidden dimension is only required to be at least as large as the set size. The results for our purposes can be summarized as follows.

Universal approximation of Deep Sets. Assume the elements are from a compact set in ℝ^d. Any continuous function f on sets S of size bounded by N that is permutation invariant to the elements in S can be approximated arbitrarily closely by some Deep Sets model with sufficiently large width and output dimension for its MLPs.

Next, we show that any Deep Sets model can be expressed by some GNN with one message passing iteration. The computation structure of one-layer GNNs is shown below:

h_s = ∑_{t∈S} MLP_1( X_s, X_t ),    h_S = MLP_2( ∑_{s∈S} h_s ),    (12)

where MLP_1 and MLP_2 are the learned modules. If MLP_1 ignores its second argument X_t, so that MLP_1(X_s, X_t) = g(X_s) for some g, e.g., by letting the part of the weight matrices in MLP_1 that acts on X_t be 0, then h_s = |S| · g(X_s), and we essentially get a Deep Sets model of the following form (the constant factor |S| can be absorbed into MLP_2):

h_S = MLP_2( ∑_{s∈S} g(X_s) ).    (13)

For any such g, we can construct the corresponding MLP_1 via the construction above. Hence, any Deep Sets model can be expressed by a one-layer GNN. The same result applies to GNNs with multiple layers (message passing iterations), because we can express g as the composition of multiple functions, each of which a GNN layer can express via the construction above. It then follows that GNNs are universal approximators for permutation-invariant continuous functions.

Appendix B Proof for Proposition 2

For any GNN 𝒩, we construct an MLP that can do exactly the same computation as 𝒩. It will then follow that the MLP can represent any function 𝒩 can represent. Suppose the computation structure of 𝒩 is the following:

h_s^(k) = ∑_{t∈S} MLP_1^(k)( h_s^(k−1), h_t^(k−1) ),    h_S = MLP_2( ∑_{s∈S} h_s^(K) ),    (14)

where the MLP_1^(k) and MLP_2 are 𝒩's modules. Suppose the set size is bounded by N (the expressive power of GNNs also depends on N [52]). We first show the result for fixed-size input, i.e., MLPs can simulate GNNs if the input set has a fixed size, and then apply an ensemble approach to deal with variable-sized input.

Let the input to the MLP be a vector concatenating the X_s, in some arbitrary order. For each message passing iteration of 𝒩, any MLP_1^(k) can be represented by an MLP. Thus, for each pair (h_s^(k−1), h_t^(k−1)), we can set weights in the MLP so that the concatenation of all MLP_1^(k)(h_s^(k−1), h_t^(k−1)) becomes the hidden vector after some layers of the MLP. With this vector as input, in the next few layers we can construct weights so that we obtain the concatenation of the h_s^(k) as the hidden vector, because summation can be encoded with MLP weights. So far, we can simulate one iteration of the GNN with several layers of the MLP. We can repeat this process K times by stacking similar layers. Finally, with the concatenation of the h_s^(K) as the hidden vector, we can similarly simulate the sum pooling and MLP_2 with a few more layers. Stacking all layers together, we obtain an MLP that can simulate 𝒩.

To deal with variable-sized inputs, we construct an MLP that simulates the GNN for each possible input set size, as above. Then we construct a meta-layer whose weights represent (universally approximate) the summation of the outputs of these MLPs, each multiplied by an indicator of whether that MLP's size matches the input set size (the set size needs to be part of the input). The meta-layer can then essentially select the output of the MLP whose size matches the input set and thus exactly simulate the GNN. Note that the MLP we construct here imposes requirements on how the data and the set-size information are input. In practice, we can keep the MLPs separate and decide which MLP to use depending on the input set size.

Appendix C Proof for Theorem 3.1

We will show the learnability result by an inductive argument. Specifically, we will show that, under our setting and assumptions, the error between the learned function and the correct function on the test set does not blow up after the transform by another learned function, assuming learnability of the previous modules by induction. Thus, we can essentially provably learn the functions at all layers/iterations and eventually learn g.

Suppose we have performed the sequential learning. Let us consider what happens at test time. Let f_1, ..., f_n be the correct functions as defined in the match-of-structure assumption. Let f̂_1, ..., f̂_n be the functions learned by the algorithms A and MLP modules 𝒩_i. We have input S ∼ D, and our goal is to bound ‖ĝ(S) − g(S)‖ with high probability. To show this, we bound the error of the intermediate representation vectors, i.e., the outputs of the f_j and f̂_j, and thus, the inputs to f_{j+1} and f̂_{j+1}.

Let us first consider what happens for the first MLP 𝒩_1. f_1 and f̂_1 have the same input distribution, where the inputs are obtained from S, e.g., the pairwise object representations as in (1). Hence, by the learnability assumption (match-of-structure assumption), ‖f̂_1(x) − f_1(x)‖ ≤ ε with probability at least 1 − δ. The error for the input of 𝒩_2 is then O(ε) with failure probability O(δ), because a constant number of terms of f̂_1's outputs are aggregated, and we can apply a union bound to upper bound the failure probability.

Next, we proceed by induction. Fix a j. Let x denote the input for f_{j+1}, generated by the previous f_k, and let x̂ denote the input for f̂_{j+1}, generated by the previous f̂_k. Assume ‖x − x̂‖ ≤ O(ε) with failure probability at most O(δ). We aim to show that the same holds at layer j + 1. For simplicity of notation, let f denote the correct function f_{j+1} and let f̂ denote the learned function f̂_{j+1}. Since there are a constant number of terms in the aggregation, our goal is to bound ‖f̂(x̂) − f(x)‖. By the triangle inequality, we have

‖f̂(x̂) − f(x)‖ ≤ ‖f̂(x̂) − f̂(x)‖ + ‖f̂(x) − f(x)‖.    (15, 16)

We can bound the first term with the Lipschitzness assumption on f̂ as follows:

‖f̂(x̂) − f̂(x)‖ ≤ L_1 · ‖x̂ − x‖ ≤ L_1 · O(ε).    (17)

To bound the second term, our key insight is that f is a learnable correct function, so by the learnability assumption (match-of-structure assumption), it is close to the function learned by the MLP learning algorithm on the correct samples, i.e., f is close to A({x_i, f(x_i)}). Moreover, f̂ is generated by the MLP learning algorithm on the perturbed samples, i.e., f̂ = A({x̂_i, f(x_i)}). By the algorithm stability assumption, A({x_i, f(x_i)}) and A({x̂_i, f(x_i)}) should be close if the input samples are only slightly perturbed. It then follows that

‖f̂(x) − f(x)‖ ≤ ‖A({x̂_i, f(x_i)})(x) − A({x_i, f(x_i)})(x)‖ + ‖A({x_i, f(x_i)})(x) − f(x)‖    (18, 19)
≤ L_0 · max_i ‖x_i − x̂_i‖ + ε ≤ L_0 · O(ε) + ε,    (20)

where the x_i and x̂_i are the training samples at the same layer j + 1. Here, we apply the same induction hypothesis as we had for x and x̂: ‖x_i − x̂_i‖ ≤ O(ε) with failure probability at most O(δ). We can then apply a union bound to bound the probability of any bad event happening. Here, we have three bad events, each happening with probability at most O(δ). Thus, with probability at least 1 − O(δ), we have

‖f̂(x̂) − f(x)‖ ≤ L_1 · O(ε) + L_0 · O(ε) + ε = O(ε).    (21)

This completes the proof.

Appendix D Proof for Theorem 3.2

Theorem 3.2 is a generalization of a theorem in [4], which addresses the scalar case. See [4] for the complete list of assumptions.

Theorem D ([4]). Suppose we have y = g(x) with g(x) = ∑_j α_j ( β_jᵀ x )^{p_j}, where α_j ∈ ℝ, β_j ∈ ℝ^d, and p_j = 1 or p_j = 2l (l ∈ ℕ₊). Let A be an overparameterized two-layer MLP that is randomly initialized and trained with gradient descent for a sufficient number of iterations. The sample complexity C_A(g, ε, δ) is O( ( ∑_j p_j · |α_j| · ‖β_j‖_2^{p_j} + log(1/δ) ) / ε² ).

To extend the sample complexity bound to vector-valued functions, we view each entry/component of the output vector as an independent scalar-valued output. We can then apply a union bound to bound the error rate and failure probability for the output vector, and thus, bound the overall sample complexity.

Let ε and δ be the given error rate and failure probability for the vector-valued function g: ℝ^d → ℝ^m. Moreover, suppose we choose some error rate ε_0 and failure probability δ_0 for the output/function of each entry. Applying Theorem D to each component

g(x)^(i) = ∑_j α_j^(i) ( β_j^(i)ᵀ x )^{p_j^(i)}    (22)

yields a sample complexity bound of

C_A( g^(i), ε_0, δ_0 ) = O( ( ∑_j p_j^(i) · |α_j^(i)| · ‖β_j^(i)‖_2^{p_j^(i)} + log(1/δ_0) ) / ε_0² )    (23)

for each g^(i). Now let us bound the overall error rate and failure probability given ε_0 and δ_0 for each entry. The probability that we fail to learn a single g^(i) is at most δ_0. Hence, by a union bound, the probability that we fail to learn any of the m entries is at most m · δ_0. Thus, with probability at least 1 − m · δ_0, we successfully learn all g^(i) for i = 1, ..., m, so the error for every entry is bounded by ε_0. The error for the vector output is then at most m · ε_0.

Setting δ_0 = δ/m and ε_0 = ε/m gives overall failure probability δ and error ε. Thus, if we can successfully learn the function for each output entry independently with error ε/m and failure rate δ/m, we can successfully learn the entire vector-valued function with rates ε and δ. This yields the following overall sample complexity bound:

C_A(g, ε, δ) = O( ( max_i ∑_j p_j^(i) · |α_j^(i)| · ‖β_j^(i)‖_2^{p_j^(i)} + log(m/δ) ) / (ε/m)² ).    (24)

Regarding m as a constant, we can further simplify the sample complexity to

C_A(g, ε, δ) = O( ( max_i ∑_j p_j^(i) · |α_j^(i)| · ‖β_j^(i)‖_2^{p_j^(i)} + log(1/δ) ) / ε² ).    (25)

Appendix E Proof for Corollary 3.2

Our main insight is that a giant MLP must learn the same pairwise function n² times and encode all of these copies in its weights. This leads to the extra sample complexity through Theorem 3.2, because the number of polynomial terms is of order n².

First of all, the pairwise function h(X_i, X_j) = (X_i − X_j)² can be expressed as the following polynomial:

h(X_i, X_j) = ( βᵀ (X_i, X_j) )²,  with β = (1, −1).    (26)

We have p = 2, |α| = 1, and ‖β‖_2 = √2, so p · |α| · ‖β‖_2^p = 4. Hence, by Theorem 3.2, it takes C_GNN = O( (4 + log(1/δ)) / ε² ) samples for an MLP module to learn h. Under the sequential training setting, a one-layer GNN applies an MLP module to learn h, and then sums up the outcomes of h for all pairs (i, j). Here, we essentially get the aggregation error from O(n²) pairs. However, we will see that applying an MLP to learn g directly will also incur the same aggregation error. Hence, we do not need to consider the aggregation-error effect when we compare the sample complexities.

Now we consider using an MLP to learn the function g directly. No matter in what order the objects are concatenated into the input vector X = (X_1, ..., X_n), we can express g as a sum of polynomials as follows:

g(S) = ∑_{i,j} ( β_{ij}ᵀ X )²,    (27)

where β_{ij} has 1 at the i-th entry, −1 at the j-th entry, and 0 elsewhere. Hence ∑_{i,j} p · |α| · ‖β_{ij}‖_2^p = 4 n (n − 1). It then follows from Theorem 3.2 and a union bound that it takes C_MLP = O( (4 n (n − 1) + log(1/δ)) / ε² ) samples to learn g. Here, as we have discussed above, the same aggregation error occurs for both models, so we can simply ignore it in both. Thus, comparing C_MLP and C_GNN gives us the O(n²) difference.

Appendix F Proof for Corollary 4

The main proof idea is that any object feature x_k is embedded in a subspace of the representation, indexed by a set of coordinates I_k. Hence, the feature extraction function can be represented as a linear function of X, whose coefficients depend on I_k, i.e., on which coordinates of X encode the feature x_k.

We can obtain each output coordinate j of g(X) = x_k with the following function:

g(X)^(j) = β_jᵀ X,    (28)

where β_j has 1 at the coordinate of X that stores the j-th entry of x_k, and 0 otherwise. Thus, we have p_j = 1, |α_j| = 1, and ‖β_j‖_2 = 1 for any j. Suppose the length m of each object feature is some constant. It then follows from Theorem 3.2 that

C_A(g, ε, δ) = O( ( 1 + log(m/δ) ) / (ε/m)² ) = O( log(1/δ) / ε² ).    (29)

Appendix G Proof for Claim 4

We prove the claim by contradiction. Suppose there exists f such that g(X_1, X_2) = f(X_1) + f(X_2) for any X_1 and X_2. This implies that for any X, we have g(X, X) = 2 f(X). Since g(X, X) = 1 for all X (the relation always holds for two identical objects), it follows that f(X) = 1/2 for any X. Now consider some X_1 and X_2 so that g(X_1, X_2) = 0. We must have f(X_1) + f(X_2) = 0. However, f(X_1) + f(X_2) = 1 because f(X) = 1/2 for all X. Hence, there exist X_1 and X_2 so that f(X_1) + f(X_2) ≠ g(X_1, X_2). We have reached a contradiction.

Appendix H Experiments: Data and Training Details

H.1 Fantastic Treasure: Maximum Value Difference

Dataset generation. We sample training, validation, and test data. For each model, we report the test accuracy with the hyperparameter setting that achieves the best validation accuracy. In each training sample, the input universe consists of 25 treasures X_1, ..., X_25. For each treasure s, we have X_s = (ℓ_s, v_s, c_s), where the location ℓ_s, the value v_s, and the color c_s are each sampled uniformly from a fixed range. The task is to answer what the difference in value is between the most and the least valuable treasure. We generate the answer label y for a universe by computing the maximum difference in value among all treasures. Then we make the label y into a one-hot encoding, with one class per possible value difference.

Hyperparameter setting. We train all models with the Adam optimizer, tuning the learning rate, and decay the learning rate by a constant factor every fixed number of steps. We use cross-entropy loss. We train all models for the same number of epochs and tune the batch size. We apply weight decay to all models. For the MLP model, we tune the number of hidden layers. For models other than MLP, we fix the number of hidden layers of the last MLP module, MLP_2, and tune the number of hidden layers of the MLP modules prior to it, MLP_1. For all models, we tune the hidden dimension of all MLPs. Moreover, dropout is applied before the last two hidden layers of MLP_2, i.e., the last MLP module in all models. Dropout is also applied before the last two hidden layers of the MLP model.

H.2 Fantastic Treasure: Furthest Pair

Dataset generation. We sample training, validation, and test data. For each model, we report the test accuracy with the hyperparameter setting that achieves the best validation accuracy. In each training sample, the input universe consists of 25 treasures X_1, ..., X_25. For each treasure s, we have X_s = (ℓ_s, v_s, c_s), where the location ℓ_s, the value v_s, and the color c_s are each sampled uniformly from a fixed range. The task is to answer what the colors of the two most distant treasures are. We generate the answer label for a universe as follows: we find the pair of treasures that are the most distant from each other. Then we order the pair to obtain an ordered pair (s_1, s_2) with c_{s_1} ≤ c_{s_2}, where c_s denotes the color of s. Then we compute the label y from (c_{s_1}, c_{s_2}) by counting how many valid pairs of colors are smaller than (c_{s_1}, c_{s_2}); a pair (a, b) is smaller than (c, d) iff i) a < c, or ii) a = c and b < d. The label is a one-hot encoding of this pair index, with one class per valid color pair.

Hyperparameter setting. We train all models with the Adam optimizer, tuning the learning rate, and decay the learning rate by a constant factor every fixed number of steps. We use cross-entropy loss. We train all models for the same number of epochs and tune the batch size. We apply weight decay to all models. For the MLP model, we tune the number of hidden layers. For models other than MLP, we fix the number of hidden layers of the last MLP module, MLP_2, and tune the number of hidden layers of the MLP modules prior to it, MLP_1. For all models, we tune the hidden dimension of all MLPs. Moreover, dropout is applied before the last two hidden layers of MLP_2, i.e., the last MLP module in all models. Dropout is also applied before the last two hidden layers of the MLP model.

H.3 Dragon Trainer

Dataset generation. We sample training, validation, and test data. For each model, we report the test accuracy with the hyperparameter setting that achieves the best validation accuracy. In each training sample, the input universe consists of the trainer and the dragons X_1, ..., X_n, plus the requested level ℓ*, i.e., we need to challenge the level-ℓ* dragon. We have X_i = (ℓ_i, z_i), where ℓ_i indicates the combat level, and the location z_i is sampled uniformly from a fixed range. We generate the answer label for a universe as follows. We implement a shortest path algorithm to compute the minimum cost from the trainer to the level-ℓ* dragon, where the cost is defined as in Section 5. The label is then a one-hot encoding of the minimum cost, bucketed into classes. Moreover, when we sample the data, we apply rejection sampling to ensure that the shortest path achieving the minimum cost takes each admissible length with equal probability. That is, we eliminate the trivial questions.

Hyperparameter setting. We train all models with the Adam optimizer, tuning the learning rate, and decay the learning rate by a constant factor every fixed number of steps. We use cross-entropy loss. We train all models for the same number of epochs and tune the batch size. We apply weight decay to all models. For the MLP model, we tune the number of layers. For models other than MLP, we fix the number of hidden layers of the last MLP module, MLP_2, and tune the number of hidden layers of the MLP modules prior to it, MLP_1. For all models, we tune the MLP hidden dimensions. Moreover, dropout is applied before the last two hidden layers of MLP_2, i.e., the last MLP module in all models. Dropout is also applied before the last two hidden layers of the MLP model.

H.4 Subset Sum

Dataset generation. We sample training, validation, and test data. For each model, we report the test accuracy with the hyperparameter setting that achieves the best validation accuracy. In each training sample, the input universe consists of 6 numbers X_1, ..., X_6, where each X_i is uniformly sampled from [-200..200]. The goal is to decide whether there exists a subset that sums up to 0. In the data generation, we carefully decrease the number of questions that have trivial answers: 1) we control the number of samples that contain the number 0, and hence a trivial singleton solution, to be around 1% of the total training data; 2) we further control the number of samples where