Graph Transformation Policy Network for Chemical Reaction Prediction

12/22/2018 ∙ by Kien Do, et al.

We address a fundamental problem in chemistry known as chemical reaction product prediction. Our main insight is that the input reactant and reagent molecules can be jointly represented as a graph, and the process of generating product molecules from reactant molecules can be formulated as a sequence of graph transformations. To this end, we propose Graph Transformation Policy Network (GTPN) -- a novel generic method that combines the strengths of graph neural networks and reinforcement learning to learn reactions directly from data with minimal chemical knowledge. Compared to previous methods, GTPN has some appealing properties such as: end-to-end learning, and making no assumption about the length or the order of graph transformations. In order to guide model search through the complex discrete space of sets of bond changes effectively, we extend the standard policy gradient loss by adding useful constraints. Evaluation results show that GTPN improves the top-1 accuracy over the current state-of-the-art method by about 3%. The model's performances and prediction errors are also analyzed carefully in the paper.




1 Introduction

Chemical reaction product prediction is a fundamental problem in organic chemistry. It paves the way for planning syntheses of new substances (Chen & Baldi, 2009). For decades, huge effort has been spent on solving this problem. However, most methods still depend on handcrafted reaction rules (Chen & Baldi, 2009; Kayala & Baldi, 2011; Wei et al., 2016) or heuristically extracted reaction templates (Segler & Waller, 2017; Coley et al., 2017), and thus do not generalize well to unseen reactions.

A reaction can be regarded as a set (or unordered sequence) of graph transformations in which reactants represented as molecular graphs are transformed into products by modifying the bonds between some atom pairs (Jochum et al., 1980; Ugi et al., 1979). See Fig. 1 for an illustration. We call an atom pair that changes its connectivity during a reaction, together with its new bond, a reaction triple. The reaction product prediction problem now becomes predicting a set of reaction triples given the input reactants and reagents. We argue that in order to solve this problem well, an intelligent system should have two key capabilities: (a) Understanding the molecular graph structure of the input reactants and reagents so that it can identify possible reactivity patterns (i.e., atom pairs with changing connectivity). (b) Knowing how to choose from these reactivity patterns a correct set of reaction triples to generate the desired products.

Recent state-of-the-art methods (Jin et al., 2017; Bradshaw et al., 2018) have built the first capability by leveraging graph neural networks (Duvenaud et al., 2015; Hamilton et al., 2017; Pham et al., 2017; Gilmer et al., 2017). However, these methods are either unaware of the valid sets of reaction triples (Jin et al., 2017) or limited to sequences of reaction triples with a predefined order (Bradshaw et al., 2018). The main challenge is that the space of all possible configurations of reaction triples is extremely large and non-differentiable. Moreover, a small change in the predicted set of reaction triples can lead to very different reaction products, and a small mistake can produce an invalid prediction.

Figure 1: A sample reaction represented as a set of graph transformations from reactants (leftmost) to products (rightmost). Atoms are labeled with their type (Carbon, Oxygen,…) and their index (1, 2,…) in the molecular graph. The atom pairs that change connectivity and their new bonds (if any) are highlighted in green. There are two bond changes in this case: 1) The double bond between O:1 and C:2 becomes single. 2) A new single bond between C:2 and C:10 is added.

In this paper, we propose a novel method called Graph Transformation Policy Network (GTPN) that addresses the aforementioned challenges. Our model consists of three main components: a graph neural network (GNN), a node pair prediction network (NPPN) and a policy network (PN). Starting from the initial graph of reactant and reagent molecules, our model iteratively alternates between modeling the input graph using the GNN and predicting a reaction triple using the NPPN and PN to generate a new intermediate graph as input for the next step, until it decides to stop. The final generated graph is considered the predicted products of the reaction. Importantly, GTPN does not assume any fixed number or order of bond changes but learns these properties itself. One can view GTPN as a reinforcement learning (RL) agent that operates on a complex and non-differentiable space of sets of reaction triples. To guide our model towards learning a diverse yet robust-to-small-changes policy, we customize our loss function by adding some useful constraints to the standard policy gradient loss (Mnih et al., 2016).

To the best of our knowledge, GTPN is the most generic approach to the reaction product prediction problem so far in the sense that: i) It combines graph neural networks and reinforcement learning into a unified framework and trains everything end-to-end; ii) It does not use any handcrafted or heuristically extracted reaction rules/templates to predict the products. Instead, it automatically learns various types of reactions from the training data and can generalize to unseen reactions; iii) It can interpret how the products are formed via the sequence of reaction triples it generates.

We evaluate GTPN on two large public datasets named USPTO-15k and USPTO. Our method significantly outperforms all baselines in top-1 accuracy, achieving new state-of-the-art results of 82.39% and 83.20% on USPTO-15k and USPTO, respectively. In addition, we provide comprehensive analyses of GTPN's performance and of the different types of errors our model can make.

2 Method

2.1 Chemical Reaction as Markov Decision Process of Graph Transformations

A reaction occurs when reactant molecules interact with each other in the presence (or absence) of reagent molecules to form new product molecules by breaking or adding some of their bonds. Our main insight is that reaction product prediction can be formulated as predicting a sequence of such bond changes given the reactant and reagent molecules as input. A bond change is characterized by the atom pair (where the change happens) and the new bond type (what is the change). We call this atom pair a reaction atom pair and call this atom pair with the new bond type a reaction triple.
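The two notions above can be made concrete with a toy data structure. The following sketch (ours, not the paper's implementation) represents a molecular graph as a dict mapping canonical atom-index pairs to bond types, and a reaction triple as an atom pair plus its new bond type:

```python
# A toy representation (illustrative, not the paper's code): a molecular
# graph as a dict mapping canonical atom-index pairs to bond types, and a
# reaction triple as (atom_pair, new_bond_type). new_bond_type = None
# means the bond is removed.

def apply_triple(bonds, triple):
    """Return a new bond dict with one reaction triple applied."""
    (i, j), new_type = triple
    key = (min(i, j), max(i, j))      # canonical pair order
    updated = dict(bonds)
    if new_type is None:
        updated.pop(key, None)        # the bond is broken
    else:
        updated[key] = new_type       # the bond is added or changed
    return updated

# The reaction of Fig. 1: O:1=C:2 becomes single; C:2-C:10 is added.
reactants = {(1, 2): "double", (2, 3): "single"}
product = apply_triple(reactants, ((1, 2), "single"))
product = apply_triple(product, ((2, 10), "single"))
```

Real systems would of course use a chemistry toolkit for bond edits; the dict form only illustrates the where/what decomposition of a bond change.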

More formally, we represent the entire system of input reactant and reagent molecules as a single labeled graph G with multiple connected components, each of which corresponds to a molecule. Nodes in G are atoms labeled with their atomic numbers and edges in G are bonds labeled with their bond types. Given G as input, we predict a sequence of reaction triples that transforms G into a graph of product molecules.

As reactions vary in their number of transformation steps, we represent the sequence of reaction triples as ((z_1, a_1, b_1), ..., (z_T, a_T, b_T)), or (z_t, a_t, b_t) for short. Here T is the maximum number of steps, a_t is a pair of nodes, b_t is the new edge type of a_t, and z_t is a binary signal that indicates the end of the sequence: if the sequence ends at step t, z_t will be 0 and the remaining sub-actions will be ignored. At every step t where z_t signals continuation, we apply the predicted edge change (a_t, b_t) on the current graph to create a new intermediate graph as input for the next step. This iterative process of graph transformation can be formulated as a Markov Decision Process (MDP) characterized by a tuple (S, A, P, R, γ), in which S is a set of states, A is a set of actions, P is a state transition function, R is a reward function, and γ is a discount factor. Since the process is finite and contains no loops, we set the discount factor to γ = 1. The rest of the MDP tuple is defined as follows:

  • State: A state is the intermediate graph G_t generated at step t. When t = 0, G_0 is the input graph G of reactants and reagents.

  • Action: An action performed at step t is the tuple (z_t, a_t, b_t), composed of three consecutive sub-actions: predicting the continuation signal z_t, the reaction atom pair a_t, and the new bond type b_t. If z_t = 0, our model ignores the next sub-actions a_t and b_t, and all future actions. Note that setting z_t to be the first sub-action is useful in case a reaction does not happen at all, i.e., the predicted sequence of reaction triples is empty.

  • State Transition: If z_t = 1, the current graph G_t is modified based on the reaction triple (a_t, b_t) to generate a new intermediate graph G_{t+1}. We do not incorporate chemical rules such as valency checks during state transition, because the current bond change may result in invalid intermediate molecules, but later bond changes may compensate for it to create valid final products.

  • Reward: We use both immediate and delayed rewards to encourage our model to learn the optimal policy faster. At every step t, if the model predicts z_t, a_t, or b_t correctly, it receives a positive reward for each correct sub-action; otherwise, a negative reward is given. After the prediction process has terminated, if the generated products exactly match the ground-truth products, we give the model a positive reward, otherwise a negative one. The concrete reward values are provided in Appendix A.3.
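The MDP described by these bullets can be sketched as a simple rollout loop. In this hedged sketch, `predict_action` and `apply_change` are placeholder callbacks standing in for the learned networks and the graph-edit routine; all names are illustrative:

```python
# Hedged sketch of the MDP rollout described above. predict_action and
# apply_change are stand-ins for the learned policy and the graph edit.

def rollout(graph, predict_action, apply_change, max_steps):
    """Iteratively transform `graph` until the policy signals termination."""
    triples = []
    for _ in range(max_steps):
        z, pair, bond = predict_action(graph)    # three sub-actions per step
        if not z:                                # z = 0: the episode ends
            break
        graph = apply_change(graph, pair, bond)  # state transition
        triples.append((pair, bond))
    return graph, triples
```

With γ = 1 and no loops, the rollout terminates after at most `max_steps` transformations, matching the finite-horizon setting above.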

Figure 2: Workflow of a Graph Transformation Policy Network (GTPN). At every step of the forward pass, our model performs 7 major functions: 1) Computing the atom representation vectors, 2) Computing the most probable reaction atom pairs, 3) Predicting the continuation signal, 4) Predicting the reaction atom pair, 5) Predicting a new bond for this atom pair, 6) Updating the atom representation vectors, and 7) Updating the recurrent state.

2.2 Graph Transformation Policy Network

In this section, we describe the architecture of our model, the Graph Transformation Policy Network (GTPN). GTPN has three main components, namely a Graph Neural Network (GNN), a Node Pair Prediction Network (NPPN), and a Policy Network (PN). Each component is responsible for one or several key functions shown in Fig. 2: the GNN performs functions 1 and 6; the NPPN performs function 2; and the PN performs functions 3, 4 and 5. Apart from these components, GTPN also has a Recurrent Neural Network (RNN) to keep track of the past transformations. The hidden state of this RNN is used by the NPPN and the PN to make accurate predictions.

2.2.1 Graph Neural Network

To model the intermediate graph at step t, we compute the state vector of every node in the graph using a variant of Message Passing Neural Networks (Gilmer et al., 2017):


where the update is repeated for a fixed number of message passing steps: each node's new state vector is computed from its feature vector, its state vector at the previous step, and messages aggregated from its set of neighbor nodes. When t = 0, the node states are initialized from the node feature vectors using a neural network. Details about the message function are provided in Appendix A.1.
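As a rough scalar illustration of this scheme (not the paper's exact MPNN variant or update function), each node can repeatedly refresh its state from its own previous state and the sum of its neighbors' states:

```python
# A rough scalar illustration of message passing (not the paper's exact
# MPNN update): each node's state is refreshed from its own state and the
# sum of its neighbours' states, for a fixed number of steps.

def message_passing(features, adj, num_steps):
    state = dict(features)                          # states start from features
    for _ in range(num_steps):
        new_state = {}
        for v, neighbours in adj.items():
            msg = sum(state[u] for u in neighbours)     # aggregate messages
            new_state[v] = 0.5 * state[v] + 0.5 * msg   # toy update rule
        state = new_state
    return state
```

In the real model, states are vectors and the update is a learned function; the loop structure over nodes and steps is the part this sketch shares with the paper.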

2.2.2 Node Pair Prediction Network

In order to predict how likely an atom pair of the intermediate graph will change its bond, we assign each atom pair a score. The higher the score, the more probably the pair is a reaction atom pair. Similar to (Jin et al., 2017), we use two different networks, called the “local” network and the “global” network, for this task. In the case of the “local” network, the score is computed as:


where the score is computed by a neural network with a nonlinear activation function (e.g., ReLU) and learnable parameters, applied to the concatenation of the two atoms' state vectors, the hidden state of the RNN at the previous step, and the representation vector of the bond between the two atoms. If there is no bond between an atom pair, we assume that its bond type is “NULL”. We consider this concatenated vector as the representation vector for the atom pair.

The “global” network leverages self-attention (Vaswani et al., 2017; Wang et al., 2018) to detect the compatibility between each atom and all other atoms before computing the scores:


where an attention score is computed from each node to every other node, and the resulting context vector of an atom summarizes the information from all other atoms.
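A one-dimensional toy version of this attention pooling might look as follows (illustrative only; the real model attends over learned vector representations):

```python
import math

# 1-D toy of the attention pooling (the real model uses learned vector
# projections): the context of atom i is a softmax-weighted sum of all
# other atoms' values.

def attention_context(vectors, i):
    others = [j for j in range(len(vectors)) if j != i]
    scores = [vectors[i] * vectors[j] for j in others]  # scalar "dot product"
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return sum(e / total * vectors[j] for e, j in zip(exps, others))
```

The softmax weighting is what makes the context vector a summary "from all other atoms" rather than a fixed-neighborhood aggregate.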

During experiments, we tried both options mentioned above and saw that the “global” network clearly outperforms the “local” network, so we set the “global” network as the default module in our model. In addition, since reagents never change their form during a reaction, we explicitly exclude all atom pairs that have either atom belonging to the reagents. This leads to better results than not using reagent information. Detailed analyses are provided in Appendix A.5.

Top-K atom pairs

Because the number of atom pairs that actually participate in a reaction is very small (usually smaller than 10) compared to the total number of atom pairs of the input molecules (usually hundreds or thousands), it is much more efficient to identify reaction triples from a small subset of highly probable reaction atom pairs. For that reason, we extract the K atom pairs with the highest scores. Later, we will predict reaction triples taken from these atom pairs only. We denote the set of top-K atom pairs together with their corresponding scores and representation vectors accordingly.
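Extracting the top-K pairs is a straightforward selection over the scores; a minimal sketch (names are ours):

```python
import heapq

# Keeping only the K highest-scoring atom pairs (illustrative names).

def top_k_pairs(scores, k):
    """scores: dict mapping atom pair -> score; returns pairs, best first."""
    return heapq.nlargest(k, scores, key=scores.get)
```

With hundreds of candidate pairs and K around ten, this selection shrinks the action space the policy must search over by more than an order of magnitude.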

2.2.3 Policy Network

Predicting continuation signal

To account for the varying number of transformation steps, the PN generates a continuation signal, drawn from a Bernoulli distribution, to indicate whether prediction should continue or terminate:


where the Bernoulli parameter is computed by a neural network from the previous RNN state and the set of representation vectors of the top-K atom pairs at the current step, aggregated by a function that maps an unordered set of inputs to a single output vector. For simplicity, we use the mean function.
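A toy sketch of this head with scalar "vectors" (the real model uses a neural network over concatenated states; `weight` and `bias` here are illustrative stand-ins):

```python
import math

# Toy sketch of the continuation head: mean-pool the top-K pair
# representations (an order-invariant set function) and squash the result
# to a Bernoulli parameter. weight and bias are illustrative stand-ins
# for the neural network.

def continue_prob(pair_vectors, weight, bias):
    pooled = sum(pair_vectors) / len(pair_vectors)   # mean over the set
    return 1.0 / (1.0 + math.exp(-(weight * pooled + bias)))
```

The mean is used precisely because it is invariant to the ordering of the top-K set, as the text requires.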

Predicting atom pair

At the next sub-step, the PN predicts which atom pair changes its bond during the reaction by sampling from the top-K atom pairs with probability:


where each atom pair's probability is derived from its score computed in Eq. (5). After predicting an atom pair, we mask it to ensure that it cannot be in the top-K again at future steps.
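The sampling distribution with masking can be sketched as a masked softmax over the top-K scores (names are ours):

```python
import math

# Masked softmax over the top-K pair scores (a sketch). Pairs already
# chosen at earlier steps get probability zero.

def pair_probs(scored_pairs, masked):
    """scored_pairs: list of (pair, score); masked: set of excluded pairs."""
    exps = [0.0 if p in masked else math.exp(s) for p, s in scored_pairs]
    total = sum(exps)
    return [e / total for e in exps]
```

Zeroing the exponentiated score before normalization is one simple way to guarantee an already-used pair can never be sampled again.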

Predicting bond type

Given an atom pair sampled at the previous sub-step, we predict a new bond type between its two atoms to obtain a complete reaction triple, using the probability:


where the distribution is defined over the total number of bond types and is computed by a neural network from the representation vector of the atom pair (computed in Eq. (4)) together with the embedding vectors of the candidate new bond type and of the pair's old bond type.

2.3 Updating States

After predicting a complete reaction triple, our model updates: i) the recurrent hidden state of the RNN, and ii) the node representation vectors of the new intermediate graph. These updates are presented in Appendix A.2.

2.4 Training

The loss function plays a central role in achieving fast training and high performance. We design the following combined loss:

where the first term is the Advantage Actor-Critic (A2C) loss (Mnih et al., 2016) that accounts for the correct sequence of reaction triples; the second is the loss for estimating the value function used in A2C; the third accounts for the binary change in the bond of an atom pair; the fourth penalizes long predicted sequences; and the last is the rank loss that forces a ground-truth reaction atom pair to appear in the top-K. The terms are combined with tunable coefficients. The component losses are explained in the following.

2.4.1 Reaction triple loss

The loss follows a policy gradient method known as Advantage Actor-Critic (A2C):


where the sum runs up to the first step at which the model decides to terminate, and the weighting terms for the three sub-actions are called advantages. To compute these advantages, we use unbiased estimates called Temporal Difference errors, defined as:


where the three terms are the immediate rewards for the sub-actions at step t; at the final step, the model receives additional delayed rewards; γ is the discount factor; and the value function is parametric. We train the value function using the following mean squared error loss:


where the regression target is the return at step t.
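For reference, with γ = 1 the return at each step is just the suffix sum of rewards; a generic sketch:

```python
# Generic return computation used as the value-function target: with
# discount gamma, the return at step t is r_t + gamma * G_{t+1}.

def returns(rewards, gamma=1.0):
    out = []
    acc = 0.0
    for r in reversed(rewards):
        acc = r + gamma * acc
        out.append(acc)
    return list(reversed(out))
```

Computing the returns backwards makes the whole target sequence available in a single pass over the episode.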

Episode termination during training

Although the loss defined in Eq. (9) is correct, it is not good to use in practice because: i) if our model selects a wrong sub-action at any sub-step of step t, the whole predicted sequence will be incorrect regardless of what is predicted afterwards, so computing the loss for the remaining actions is redundant; ii) more importantly, the incorrect updates of the graph structure at subsequent steps will lead to cumulative prediction errors, which make training our model much more difficult.

To resolve this issue, during training we use a binary vector to keep track of the first wrong sub-action: its entries are one before the sub-step at which our model chooses a wrong sub-action for the first time and zero from that sub-step onward. The actor-critic loss in Eq. (9) now becomes:


where T is the maximum number of steps. Similarly, we apply the same mask to the value loss.
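The tracking vector can be sketched as a simple 0/1 mask over the flattened sub-steps (three sub-actions per step; names are ours):

```python
# Sketch of the episode-termination mask: per-sub-step losses are kept
# before the first wrong sub-action and zeroed afterwards. With three
# sub-actions per step, a sequence of T steps has 3T sub-steps.

def loss_mask(first_wrong_substep, num_steps, subs_per_step=3):
    total = num_steps * subs_per_step
    return [1.0 if i < first_wrong_substep else 0.0 for i in range(total)]
```

Multiplying per-sub-step loss terms by this mask removes the redundant gradient contributions described in points i) and ii) above.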

2.4.2 Reaction atom pair loss

To train our model to assign higher scores to reaction atom pairs and lower scores to non-reaction atom pairs, we use the following cross-entropy loss:


where the loss is summed over steps and atom pairs; each atom pair at step t carries a mask and a binary label indicating whether it is a reaction atom pair or not, and its predicted probability is derived from the score in Eq. (5).

2.4.3 Constraint on the sequence length

One major difficulty of the chemical reaction prediction problem is knowing exactly when to stop prediction so that we can make accurate inference. By forcing the model to stop immediately when it makes a wrong prediction, we can prevent cumulative errors and significantly reduce variance during training. But this also comes at a cost: the model cannot learn (because it does not have to learn) when to stop. This phenomenon is easily visible as the model keeps predicting the continuation signal at every step during inference. In order to make the model aware of the correct sequence length during training, we define a loss that punishes the model if it produces a longer sequence than the ground-truth sequence:


where the ground-truth sequence ends at a known step. Note that the loss in Eq. (16) is not applied when the predicted sequence is shorter than the ground-truth one. The reason is that forcing the continuation signal to stay positive beyond the predicted end is not theoretically correct, because all signals after the sequence ends are assumed to indicate termination. The incentive to push a prematurely terminated sequence toward the correct length has already been included in the advantages in Eq. (14).
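A toy version of such a penalty (our illustration, not the paper's exact formula) sums the continuation probabilities the model still assigns after the ground-truth end step:

```python
# Toy version of the length penalty (illustrative, not the paper's exact
# formula): sum the continuation probabilities the model keeps assigning
# after the ground-truth sequence has ended.

def length_penalty(continue_probs, gt_end_step):
    """continue_probs[t] = model's probability of continuing at step t."""
    return sum(continue_probs[gt_end_step:])
```

Only overshooting is penalized, matching the asymmetry discussed above: steps before the ground-truth end contribute nothing to this term.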

2.4.4 Constraint on the top- atom pairs

Ideally, the loss from Eq. (15) pushes a reaction atom pair into the top-K atom pairs at each step t. However, this is not guaranteed. To encourage the ground-truth reaction atom pair with the highest score to appear in the top-K, we introduce an additional rank-based loss:

where the margin term inside the loss is computed as:


3 Experiments

3.1 Dataset

We evaluate our model on two standard datasets, USPTO-15k (15K reactions) and USPTO (480K reactions), which have been used in previous works (Jin et al., 2017; Schwaller et al., 2018; Bradshaw et al., 2018). Details about these datasets are given in Table 1. The USPTO dataset contains reactant, reagent and product molecules represented as SMILES strings. Using RDKit, we convert the SMILES strings into molecule objects and store them as graphs. For each reaction, every atom in the reactant and reagent molecules is identified with a unique “atom map number”, and this identity is the same in the products. Using this knowledge, we compare every atom pair in the input molecules with its counterpart in the product molecules to obtain a ground-truth set of reaction triples for training. In USPTO-15k, the ground-truth sets of reaction triples were precomputed by (Jin et al., 2017).
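The atom-map-based extraction of ground-truth triples amounts to diffing the bond dictionaries of reactants and products; a sketch with illustrative bond dicts keyed by atom map number:

```python
# Deriving ground-truth reaction triples by diffing bond dictionaries
# keyed by atom map number (a sketch with illustrative bond labels).

def diff_bonds(reactant_bonds, product_bonds):
    triples = []
    for pair in set(reactant_bonds) | set(product_bonds):
        before = reactant_bonds.get(pair)
        after = product_bonds.get(pair)
        if before != after:
            triples.append((pair, after))  # after = None: bond was removed
    return triples
```

Because atom map numbers are preserved from reactants to products, comparing bonds pairwise in this way recovers exactly the set of bond changes that defines the reaction.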

Dataset #reactions #changes #molecules #atoms #bonds
USPTO-15k train 10,500 1 | 11 | 2.3 1 | 20 | 3.6 4 | 100 | 34.9 3 | 110 | 34.7
valid 1,500 1 | 11 | 2.3 1 | 20 | 3.6 7 | 94 | 34.5 5 | 99 | 34.2
test 3,000 1 | 11 | 2.3 1 | 16 | 3.6 7 | 98 | 34.9 5 | 102 | 34.7
USPTO train 409,035 1 | 6 | 2.2 2 | 29 | 4.8 9 | 150 | 39.7 6 | 165 | 38.6
valid 30,000 1 | 6 | 2.2 2 | 25 | 4.8 9 | 150 | 39.6 7 | 158 | 38.5
test 40,000 1 | 6 | 2.2 2 | 22 | 4.8 9 | 150 | 39.8 7 | 162 | 38.7
Table 1: Statistics of USPTO-15k and USPTO datasets. “changes” means bond changes, “molecules” means reactants and reagents in a reaction; “atoms” and “bonds” are defined for a molecule. Apart from “#reactions”, other columns are presented in the format “min | max | mean”.

3.2 Reaction Atom Pair Prediction

In this section, we test our model’s ability to identify reaction atom pairs by formulating it as a ranking problem over the scores computed in Eq. (5). Similar to (Jin et al., 2017), we use Coverage@k as the evaluation metric, which is the proportion of reactions for which all ground-truth reaction atom pairs appear among the top-k predicted atom pairs.

We compare our proposed graph neural network (GNN) with the Weisfeiler-Lehman Network (WLN) (Jin et al., 2017) and the Column Network (CLN) (Pham et al., 2017). Since our GNN explicitly uses reagent information to compute the scores of atom pairs, we modify the implementations of WLN and CLN accordingly for fair comparison. From Table 2, we observe that our GNN clearly outperforms WLN and CLN in all cases. We attribute this improvement to the use of a separate node state vector (distinct from the node feature vector) for updating the structural information of a node (see Eq. (21)). The other two models, on the other hand, use a single vector to store both the node features and structure, so some information may be lost. In addition, using explicit reagent information boosts the prediction accuracy, improving the WLN by 1-7% depending on the metric. The presence of reagent information reduces the number of atom pairs to be searched over and contributes to the likelihood of reaction atom pairs. Further results are presented in Appendix A.5.

Model | USPTO-15k: C@6 C@8 C@10 | USPTO: C@6 C@8 C@10
WLN (Jin et al., 2017) 81.6 86.1 89.1 89.8 92.0 93.3
WLN* (re-implemented) 88.45 91.65 93.34 90.97 93.98 95.26
CLN (Pham et al., 2017) 88.68 91.63 93.07 90.72 93.57 94.80
Our GNN 88.92 92.00 93.57 91.24 94.17 95.33
Table 2: Results for reaction atom pair prediction. C@k is coverage at k. Best results are highlighted in bold. WLN is the original model from (Jin et al., 2017) while WLN* is our re-implemented version. Except for the original WLN, the other models explicitly use reagent information.

3.3 Top-K Atom Pair Extraction

Figure 3: Coverage@k and Recall@k with respect to k for the USPTO dataset.

The performance of our model depends on the number K of selected top atom pairs. The value of K presents a trade-off between coverage and efficiency. In addition to the metric Coverage@k from Sec. 3.2, we use Recall@k, the proportion of correct atom pairs that appear in the top-k, to find a good K. Fig. 3 shows Coverage@k and Recall@k for the USPTO dataset with respect to k. We see that both curves increase rapidly at small k and stabilize afterwards. We also ran experiments with several values of K and observed that their prediction results are quite similar. Hence, in what follows we select the smallest such K for efficiency.

3.4 Reaction Product Prediction

Model | USPTO-15k: P@1 P@3 P@5 | USPTO: P@1 P@3 P@5
WLDN (Jin et al., 2017) 76.7 85.6 86.8 79.6 87.7 89.2
Seq2Seq (Schwaller et al., 2018) - - - 80.3 86.2 87.5
GTPN 72.31 - - 71.26 - -
GTPN + beam search 74.56 82.62 84.23 73.25 80.56 83.53
GTPN + beam search + duplicate removal 74.56 83.19 84.97 73.25 84.31 85.76
GTPN + beam search + invalid removal 82.39 85.60 86.68 83.20 84.97 85.90
GTPN + beam search + invalid + duplicate removal 82.39 85.73 86.78 83.20 86.03 86.48
Table 3: Results for reaction product prediction. P@k is precision at k. State-of-the-art results from (Jin et al., 2017) are written in italic. Results from (Schwaller et al., 2018) are computed on a slightly different version of USPTO that contains only single-product reactions. Best results are highlighted in bold. Beam search uses beam width 20; “invalid removal” drops chemically invalid product candidates and “duplicate removal” drops duplicated ones.

This experiment validates GTPN on full reaction product prediction against the recent state-of-the-art methods (Jin et al., 2017; Schwaller et al., 2018) using the accuracy metric. The recent method ELECTRO (Bradshaw et al., 2018) is not comparable here because it was only evaluated on a subset of USPTO limited to reactions with a linear chain topology. A comparison against ELECTRO is reported separately in Appendix A.6. Table 3 shows the prediction results. We produce multiple reaction product candidates by using beam search decoding with beam width 20. Details about beam search and its behaviors are presented in Appendix A.4.

In brief, we compute the length-normalized log probabilities of the predicted sequences of reaction triples and sort these values in descending order to get a ranked list of possible reaction outcomes. Given a predicted sequence of reaction triples, we can generate the reaction products from the input reactants simply by replacing each old bond with its predicted new bond. However, these products are not guaranteed to be valid (e.g., the maximum valence constraint may be violated, or aromatic molecules may not be kekulizable), so we post-process the outputs by removing all invalid products. The removal increases the top-1 accuracy by about 8% and 10% on USPTO-15k and USPTO, respectively. Due to the permutation invariance of the predicted sequence of reaction triples, some product candidates are duplicates and are also removed. This does not change P@1 but slightly improves P@3 and P@5 by about 0.5-1% on the two datasets.
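The length-normalized ranking step can be sketched as follows (the candidate format is ours):

```python
import math

# Ranking candidates by length-normalized log-probability, as in the
# beam-search decoding described above (candidate format is ours).

def rank_candidates(candidates):
    """candidates: list of (outcome, per-step probabilities); best first."""
    def score(item):
        _, probs = item
        return sum(math.log(p) for p in probs) / len(probs)
    return sorted(candidates, key=score, reverse=True)
```

Normalizing by sequence length keeps long transformation sequences from being unfairly penalized relative to short ones during ranking.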

Overall, GTPN with beam search and post-processing outperforms both WLDN (Jin et al., 2017) and Seq2Seq (Schwaller et al., 2018) in top-1 accuracy. For top-3 and top-5, our model’s performance is comparable to WLDN’s on USPTO-15k and worse than WLDN’s on USPTO. This is not surprising, since our model is trained to accurately predict the top-1 outcome rather than to rank candidates directly like WLDN. It is important to emphasize that we did not tune the model hyper-parameters when training on USPTO but reused the optimal settings from USPTO-15k (which is 25 times smaller than USPTO), so the results may not be optimal (see Appendix A.3 for more training details).

4 Related Work

4.1 Learning to Predict Chemical Reaction

In chemical reaction prediction, machine learning has replaced rule-based methods (Chen & Baldi, 2009) for better generalizability and scalability. Existing machine learning-based techniques are either template-free (Kayala & Baldi, 2011; Jin et al., 2017; Fooshee et al., 2018) or template-based (Wei et al., 2016; Segler & Waller, 2017; Coley et al., 2017). Both groups share the same mechanism: running multiple stages with the aid of reaction templates or rules. For example, in (Wei et al., 2016) the authors proposed a two-stage model that first classifies reactions into different types based on the neural fingerprint vectors (Duvenaud et al., 2015) of the reactant and reagent molecules. Then, it applies a pre-designed SMARTS transformation on the reactants with respect to the most suitable predicted reaction type to generate the reaction products.

The work of (Jin et al., 2017) treats a reaction as a set of bond changes, so in the first step they predict which atom pairs are likely to be reactive using a variant of graph neural networks called Weisfeiler-Lehman Networks (WLNs). In the next step, they do almost the same as (Coley et al., 2017) by modifying the bond types between the selected atom pairs (with chemical rules satisfied) to create product candidates, and then rank those candidates (with the reactant molecules as additional input) using another kind of WLN called Weisfeiler-Lehman Difference Networks (WLDNs).

To the best of our knowledge, (Jin et al., 2017) is the first work that achieves remarkable results (with a Precision@1 of about 79.6%) on the large USPTO dataset containing more than 480 thousand reactions. The works of (Nam & Kim, 2016) and (Schwaller et al., 2018) avoid multi-stage prediction by building a seq2seq model that generates the (canonical) SMILES string of the single product from the concatenated SMILES strings of the reactants and reagents in an end-to-end manner. However, their methods cannot properly deal with sets of reactants/reagents/products, nor can they provide a concrete reaction mechanism for each reaction.

The most recent work on this topic is (Bradshaw et al., 2018), which solves the reaction prediction problem by predicting a sequence of bond changes given input reactants and reagents represented as graphs. To handle ordering, they only select reactions with a predefined topology. Our method, by contrast, is order-free and can be applied to almost any kind of reaction.

4.2 Graph Neural Networks for Modeling Molecules

In recent years, there has been a fast development of graph neural networks (GNNs) for modeling molecules. These models are proposed to solve different problems in chemistry including toxicity prediction (Duvenaud et al., 2015), drug activity classification (Shervashidze et al., 2011; Dai et al., 2016; Pham et al., 2018), protein interface prediction (Fout et al., 2017) and drug generation (Simonovsky & Komodakis, 2018; Jin et al., 2018). Most of them can be regarded as variants of message-passing graph neural networks (MPGNNs) (Gilmer et al., 2017).

4.3 Reinforcement Learning for Structural Reasoning

Reinforcement learning (RL) has become a standard approach to many structural reasoning problems (structural reasoning being the problem of inferring or generating a new structure, e.g., objects with relations) because it allows agents to perform discrete actions. A typical example of using RL for structural reasoning is drug generation (Li et al., 2018; You et al., 2018). Both (Li et al., 2018) and (You et al., 2018) learn the same kind of generative policy, whose action set includes: i) adding a new atom or a molecular scaffold to the intermediate graph, ii) connecting an existing pair of atoms with a bond, and iii) terminating generation. However, (You et al., 2018) uses an adversarial loss to enforce global chemical constraints on the generated molecules as a whole, instead of the common reconstruction loss used in (Li et al., 2018). Other examples are path-based relational reasoning in knowledge graphs (Das et al., 2018) and learning combinatorial optimization over graphs (Khalil et al., 2017).

5 Discussion

We have introduced a novel method named Graph Transformation Policy Network (GTPN) for predicting the products of a chemical reaction. GTPN uses graph neural networks to represent the input reactant and reagent molecules, and uses reinforcement learning to find an optimal sequence of bond changes that transforms the reactants into products. We train GTPN using the Advantage Actor-Critic (A2C) method with appropriate constraints that account for notable aspects of chemical reactions. Experiments on real datasets have demonstrated the competitiveness of our model.

Although GTPN was proposed to solve the chemical reaction prediction problem, it is in fact a generic solution to the graph transformation problem, which can be useful for reasoning about relations (e.g., see (Zambaldi et al., 2018)) and changes in relations. Open directions include addressing dynamic graphs over time, extending toward full chemical planning, and structural reasoning using RL.


  • Battaglia et al. (2016) Peter Battaglia, Razvan Pascanu, Matthew Lai, Danilo Jimenez Rezende, et al. Interaction networks for learning about objects, relations and physics. In Advances in neural information processing systems, pp. 4502–4510, 2016.
  • Bradshaw et al. (2018) John Bradshaw, Matt J Kusner, Brooks Paige, Marwin HS Segler, and José Miguel Hernández-Lobato. Predicting electron paths. arXiv preprint arXiv:1805.10970, 2018.
  • Chen & Baldi (2009) Jonathan H Chen and Pierre Baldi. No electron left behind: a rule-based expert system to predict chemical reactions and reaction mechanisms. Journal of chemical information and modeling, 49(9):2034–2043, 2009.
  • Cho et al. (2014) Kyunghyun Cho, Bart Van Merriënboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, and Yoshua Bengio. Learning phrase representations using RNN encoder-decoder for statistical machine translation. EMNLP, 2014.
  • Coley et al. (2017) Connor W Coley, Regina Barzilay, Tommi S Jaakkola, William H Green, and Klavs F Jensen. Prediction of organic reaction outcomes using machine learning. ACS central science, 3(5):434–443, 2017.
  • Dai et al. (2016) Hanjun Dai, Bo Dai, and Le Song. Discriminative embeddings of latent variable models for structured data. In International Conference on Machine Learning, pp. 2702–2711, 2016.
  • Das et al. (2018) Rajarshi Das, Shehzaad Dhuliawala, Manzil Zaheer, Luke Vilnis, Ishan Durugkar, Akshay Krishnamurthy, Alex Smola, and Andrew McCallum. Go for a walk and arrive at the answer: Reasoning over paths in knowledge bases using reinforcement learning. ICLR, 2018.
  • Duvenaud et al. (2015) David K Duvenaud, Dougal Maclaurin, Jorge Iparraguirre, Rafael Bombarell, Timothy Hirzel, Alán Aspuru-Guzik, and Ryan P Adams. Convolutional networks on graphs for learning molecular fingerprints. In Advances in Neural Information Processing Systems, pp. 2224–2232, 2015.
  • Fooshee et al. (2018) David Fooshee, Aaron Mood, Eugene Gutman, Mohammadamin Tavakoli, Gregor Urban, Frances Liu, Nancy Huynh, David Van Vranken, and Pierre Baldi. Deep learning for chemical reaction prediction. Molecular Systems Design & Engineering, 2018.
  • Fout et al. (2017) Alex Fout, Jonathon Byrd, Basir Shariat, and Asa Ben-Hur. Protein interface prediction using graph convolutional networks. In Advances in Neural Information Processing Systems, pp. 6530–6539, 2017.
  • Gilmer et al. (2017) Justin Gilmer, Samuel S Schoenholz, Patrick F Riley, Oriol Vinyals, and George E Dahl. Neural message passing for quantum chemistry. In Proceedings of the International Conference on Machine Learning, 2017.
  • Hamilton et al. (2017) Will Hamilton, Zhitao Ying, and Jure Leskovec. Inductive representation learning on large graphs. In Proceedings of Advances in Neural Information Processing Systems, pp. 1025–1035, 2017.
  • He et al. (2016) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778, 2016.
  • Jin et al. (2017) Wengong Jin, Connor Coley, Regina Barzilay, and Tommi Jaakkola. Predicting Organic Reaction Outcomes with Weisfeiler-Lehman Network. In Advances in Neural Information Processing Systems, pp. 2604–2613, 2017.
  • Jin et al. (2018) Wengong Jin, Regina Barzilay, and Tommi Jaakkola. Junction tree variational autoencoder for molecular graph generation. International Conference on Machine Learning (ICML), 2018.
  • Jochum et al. (1980) Clemens Jochum, Johann Gasteiger, and Ivar Ugi. The principle of minimum chemical distance (pmcd). Angewandte Chemie International Edition in English, 19(7):495–505, 1980.
  • Kayala & Baldi (2011) Matthew A Kayala and Pierre F Baldi. A machine learning approach to predict chemical reactions. In Advances in Neural Information Processing Systems, pp. 747–755, 2011.
  • Khalil et al. (2017) Elias Khalil, Hanjun Dai, Yuyu Zhang, Bistra Dilkina, and Le Song. Learning combinatorial optimization algorithms over graphs. In Advances in Neural Information Processing Systems, pp. 6348–6358, 2017.
  • Kingma & Ba (2015) Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. International Conference on Learning Representations (ICLR), 2015.
  • Li et al. (2018) Yibo Li, Liangren Zhang, and Zhenming Liu. Multi-objective de novo drug design with conditional graph generative model. Journal of Cheminformatics, 10, 2018.
  • Mnih et al. (2016) Volodymyr Mnih, Adria Puigdomenech Badia, Mehdi Mirza, Alex Graves, Timothy Lillicrap, Tim Harley, David Silver, and Koray Kavukcuoglu. Asynchronous methods for deep reinforcement learning. In International conference on machine learning, pp. 1928–1937, 2016.
  • Nam & Kim (2016) Juno Nam and Jurae Kim. Linking the neural machine translation and the prediction of organic chemistry reactions. arXiv preprint arXiv:1612.09529, 2016.
  • Pham et al. (2017) Trang Pham, Truyen Tran, Dinh Phung, and Svetha Venkatesh. Column networks for collective classification. In Proceedings of AAAI Conference on Artificial Intelligence, 2017.
  • Pham et al. (2018) Trang Pham, Truyen Tran, and Svetha Venkatesh. Graph memory networks for molecular activity prediction. ICPR, 2018.
  • Schlichtkrull et al. (2018) Michael Schlichtkrull, Thomas N Kipf, Peter Bloem, Rianne van den Berg, Ivan Titov, and Max Welling. Modeling relational data with graph convolutional networks. 15th European Semantic Web Conference (ESWC-18), 2018.
  • Schwaller et al. (2018) Philippe Schwaller, Theophile Gaudin, David Lanyi, Costas Bekas, and Teodoro Laino. “found in translation”: Predicting outcome of complex organic chemistry reactions using neural sequence-to-sequence models. Chemical Science, 9:6091–6098, 2018.
  • Segler & Waller (2017) Marwin HS Segler and Mark P Waller. Neural-symbolic machine learning for retrosynthesis and reaction prediction. Chemistry–A European Journal, 23(25):5966–5971, 2017.
  • Shervashidze et al. (2011) Nino Shervashidze, Pascal Schweitzer, Erik Jan van Leeuwen, Kurt Mehlhorn, and Karsten M Borgwardt. Weisfeiler-Lehman graph kernels. Journal of Machine Learning Research, 12(Sep):2539–2561, 2011.
  • Simonovsky & Komodakis (2018) Martin Simonovsky and Nikos Komodakis. GraphVAE: Towards Generation of Small Graphs Using Variational Autoencoders. arXiv preprint arXiv:1802.03480, 2018.
  • Srivastava et al. (2015) Rupesh K Srivastava, Klaus Greff, and Jürgen Schmidhuber. Training very deep networks. In Advances in neural information processing systems, pp. 2377–2385, 2015.
  • Ugi et al. (1979) Ivar Ugi, Johannes Bauer, Josef Brandt, Josef Friedrich, Johann Gasteiger, Clemens Jochum, and Wolfgang Schubert. New applications of computers in chemistry. Angewandte Chemie International Edition in English, 18(2):111–123, 1979.
  • Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems, pp. 5998–6008, 2017.
  • Wang et al. (2018) Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. Non-local neural networks. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018.
  • Wei et al. (2016) Jennifer N Wei, David Duvenaud, and Alán Aspuru-Guzik. Neural networks for the prediction of organic chemistry reactions. ACS Central Science, 2(10):725–732, 2016.
  • You et al. (2018) Jiaxuan You, Bowen Liu, Rex Ying, Vijay Pande, and Jure Leskovec. Graph convolutional policy network for goal-directed molecular graph generation. NIPS, 2018.
  • Zambaldi et al. (2018) Vinicius Zambaldi, David Raposo, Adam Santoro, Victor Bapst, Yujia Li, Igor Babuschkin, Karl Tuyls, David Reichert, Timothy Lillicrap, Edward Lockhart, et al. Relational deep reinforcement learning. arXiv preprint arXiv:1806.01830, 2018.

Appendix A Appendix

a.1 Graph Neural Network

In this section, we describe our graph neural network (GNN) in detail. Since our GNN does not use the recurrent hidden state, we drop the time step $t$ from our notation for clarity; instead, we use $k$ to denote a message passing step.

Graph notations

Input to our GNN is a graph in which each node $u$ is represented by a node feature vector $\mathbf{x}_u$ and each edge $(u, v)$ is represented by an edge feature vector $\mathbf{e}_{uv}$. In a molecular graph, for example, $\mathbf{x}_u$ may include chemical information about the atom such as its type, charge and degree, while $\mathbf{e}_{uv}$ captures the bond type between the two atoms $u$ and $v$. We denote by $\mathcal{E}(u)$ the set of all neighbor nodes of node $u$ together with their links to node $u$:

$$\mathcal{E}(u) = \left\{ (v, \mathbf{e}_{uv}) \mid v \in \mathcal{N}(u) \right\}$$

If we only care about the neighbor nodes of $u$, not their links, we use the notation $\mathcal{N}(u)$, defined as:

$$\mathcal{N}(u) = \left\{ v \mid (v, \mathbf{e}_{uv}) \in \mathcal{E}(u) \right\}$$

In addition to $\mathbf{x}_u$, node $u$ also has a state vector $\mathbf{h}_u$ that stores information about the node itself and its surrounding context. This state vector is updated recursively using the neural message passing method (Battaglia et al., 2016; Pham et al., 2017; Hamilton et al., 2017; Gilmer et al., 2017; Schlichtkrull et al., 2018). The initial state is a nonlinear mapping of the node features:

$$\mathbf{h}_u^0 = f_0(\mathbf{x}_u)$$


Computing neighbor messages

At message passing step $k$, we compute the message from every neighbor node $v$ to node $u$ as:

$$\mathbf{m}_{v \rightarrow u}^{k} = f_m\!\left( \left[ \mathbf{h}_u^{k-1} \,\|\, \mathbf{h}_v^{k-1} \,\|\, \mathbf{e}_{uv} \right] \right)$$

where $\|$ denotes concatenation and $f_m$ is a nonlinear function.

Aggregating neighbor messages

Then, we aggregate all the messages sent to node $u$ into a single message vector by averaging:

$$\bar{\mathbf{m}}_u^{k} = \frac{1}{|\mathcal{N}(u)|} \sum_{v \in \mathcal{N}(u)} \mathbf{m}_{v \rightarrow u}^{k}$$

where $|\mathcal{N}(u)|$ is the number of neighbor nodes of node $u$.

Updating node state

Finally, we update the state of node $u$ as follows:

$$\mathbf{h}_u^{k} = \mathrm{Highway}\!\left( \mathbf{h}_u^{k-1}, \bar{\mathbf{m}}_u^{k} \right)$$

where $\mathrm{Highway}(\cdot)$ is a Highway Network (Srivastava et al., 2015):

$$\mathrm{Highway}(\mathbf{h}, \mathbf{m}) = \mathbf{g} \odot \tilde{\mathbf{h}} + (1 - \mathbf{g}) \odot \mathbf{h}$$

where $\tilde{\mathbf{h}}$ is the nonlinear part, computed as:

$$\tilde{\mathbf{h}} = f_h\!\left( \left[ \mathbf{h} \,\|\, \mathbf{m} \right] \right)$$

and $\mathbf{g}$ is the gate controlling the flow of information:

$$\mathbf{g} = \sigma\!\left( f_g\!\left( \left[ \mathbf{h} \,\|\, \mathbf{m} \right] \right) \right)$$
By combining Eqs. (19, 20, 22), one step of message passing update for node $u$ can be written generically as:

$$\mathbf{h}_u^{k} = \mathrm{MessagePassing}\!\left( \mathbf{h}_u^{k-1}, \left\{ \left( \mathbf{h}_v^{k-1}, \mathbf{e}_{uv} \right) \right\}_{v \in \mathcal{N}(u)} \right)$$
a.2 Updating States

Updating RNN state

We keep the old representations of the edges that have been modified in the hidden memory of the RNN as follows:

$$\mathbf{h}_t = \mathrm{GRU}\!\left( \mathbf{h}_{t-1}, \mathbf{z}_{u_t v_t} \right)$$

where GRU stands for Gated Recurrent Unit (Cho et al., 2014) and $\mathbf{z}_{u_t v_t}$ is the representation vector of the atom pair including its old bond (see Eq. 4). Eq. (25) allows the model to keep track of all the changes made to the graph so far, so it can make more accurate predictions later.

Updating graph structure and node states

After predicting a reaction triple $(u_t, v_t, b_t)$ at step $t$, we update the graph structure and node states based on the new bond change. First, to update the graph structure, we simply update the neighbor sets of $u_t$ and $v_t$ with information about the other atom and the new bond type:

$$\mathcal{E}(u_t) \leftarrow \mathcal{E}(u_t) \cup \{(v_t, b_t)\}, \qquad \mathcal{E}(v_t) \leftarrow \mathcal{E}(v_t) \cup \{(u_t, b_t)\}$$
Next, to update the node states, our model performs one step of message passing for $u_t$ and $v_t$ with their new neighbor sets:

$$\mathbf{h}_{u_t} \leftarrow \mathrm{MessagePassing}\!\left( \mathbf{h}_{u_t}, \left\{ \left( \mathbf{h}_v, \mathbf{e}_{u_t v} \right) \right\}_{v \in \mathcal{N}(u_t)} \right)$$

where the $\mathrm{MessagePassing}(\cdot)$ function is defined in Eq. (24). For the other nodes in the graph to become aware of the new structures of $u_t$ and $v_t$, we would need to perform several message passing steps for all nodes after Eqs. (28, 29). However, running this for every prediction step $t$ is very costly, and it is often unnecessary since far-away bonds are less likely to be affected by the current bond change (unless the far-away bonds and the new bond lie in the same aromatic ring). Therefore, in our model, we limit the number of message passing updates for all nodes at step $t$ to a small constant.
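The structure update can be illustrated with a small sketch. Here `apply_bond_change` is a hypothetical helper, and the adjacency-set/bond-map representation is an assumption; after calling it, the model would run one localized message passing step for the two endpoints only.

```python
# After predicting a triple (u, v, new_bond), rewrite both neighbor
# sets and the bond-type map; None models a deleted (NULL) bond.

def apply_bond_change(neighbors, bonds, u, v, new_bond):
    """Update adjacency and bond-type maps for one predicted triple."""
    key = (min(u, v), max(u, v))
    if new_bond is None:          # bond deleted
        neighbors[u].discard(v)
        neighbors[v].discard(u)
        bonds.pop(key, None)
    else:                         # bond added or changed
        neighbors[u].add(v)
        neighbors[v].add(u)
        bonds[key] = new_bond
    return neighbors, bonds

# Toy 3-atom graph with a single existing bond.
neighbors = {0: {1}, 1: {0}, 2: set()}
bonds = {(0, 1): "SINGLE"}
neighbors, bonds = apply_bond_change(neighbors, bonds, 1, 2, "DOUBLE")
```

Keeping the update this local is what makes per-step prediction cheap; the periodic full-graph message passing then propagates the change outward.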

a.3 Model Configurations

We optimize our model’s hyper-parameters in two stages: First, we tune the hyper-parameters of the GNN and the NPPN for the reaction atom pair prediction task. Then, we fix the optimal settings of the first two components and optimize the hyper-parameters of the PN for the reaction product prediction task.

We provide details below about the settings that give good results on the USPTO-15k dataset. With these settings, we trained another model on the USPTO dataset from scratch. Because training on a large dataset such as USPTO is time-consuming, we did not tune hyper-parameters on USPTO, even though it might be possible to increase model sizes for better performance.

Unless explicitly stated, all neural networks in our model have 2 layers with the same number of hidden units, ReLU activation and residual connections (He et al., 2016).

Graph Neural Network (GNN)

There are 72 different atom types, determined by atomic number, and 5 different bond types: NULL, SINGLE, DOUBLE, TRIPLE and AROMATIC. The sizes of the embedding vectors for atoms and bonds are 51 and 21, respectively. Apart from its type, each atom has 5 more attributes, listed in Table 4. These attributes are normalized to the range [0, 1] and concatenated to the atom embedding vector to form a final atom feature vector of size 56. The state vector and the neighbor message vector for an atom both have size 99. The number of message passing steps is 6.

Atom attribute Data type
Degree numeric
Explicit valence numeric
Explicit number of Hs numeric
Charge numeric
Part of a ring boolean
Table 4: Data types of atom attributes.
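A sketch of how the final size-56 atom feature vector might be assembled from the embedding and the attributes in Table 4. The embedding table, the specific normalization constants, and the `atom_features` helper are illustrative assumptions; the paper only states that attributes are scaled to [0, 1].

```python
import numpy as np

rng = np.random.default_rng(2)

N_ATOM_TYPES, EMBED_SIZE = 72, 51  # as stated in the text
atom_embedding = rng.normal(0, 0.1, (N_ATOM_TYPES, EMBED_SIZE))

def atom_features(atom_type, degree, valence, num_hs, charge, in_ring):
    """Concatenate the type embedding with the 5 extra attributes
    (normalized to [0, 1]) -> final feature vector of size 51 + 5 = 56."""
    attrs = np.array([
        degree / 6.0,         # hypothetical normalizers standing in for
        valence / 8.0,        # whatever scaling the authors used
        num_hs / 4.0,
        (charge + 2) / 4.0,
        float(in_ring),       # boolean attribute maps to {0, 1}
    ])
    return np.concatenate([atom_embedding[atom_type], attrs])

x = atom_features(atom_type=5, degree=3, valence=4, num_hs=1,
                  charge=0, in_ring=True)
```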
Node Pair Prediction Network (NPPN)

This component consists of two parts. The first part computes the representation vector of an atom pair using a neural network with hidden size 71. The second part maps this vector to an unnormalized score (see Eqs. (3, 5)) using another neural network with hidden size 51.

Policy Network (PN)

The recurrent network is a GRU (Cho et al., 2014) with 101 hidden units. The value function is a neural network with 99 hidden units. The two functions that compute the continuation signal scores (see Eq. (6)) and the scores over bond types (see Eq. (8)) are neural networks with 81 hidden units.


At each step, we set the reward to 1.0 for a correct prediction of the signal/atom pair/bond type and -1.0 for an incorrect one. After the prediction sequence terminates (a zero signal was emitted), we check whether the entire set of predicted reaction triples is correct. If it is, we give the model a reward of 2.0, otherwise -2.0. From the rewards and the estimated values for signal, atom pair and bond type, we define the Advantage Actor-Critic (A2C) loss as in Eq. (14). The coefficients of the components in the final loss are set empirically.
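The reward scheme above can be written down directly. `step_reward` and `final_reward` are hypothetical helper names for the per-step and terminal rewards stated in the text.

```python
def step_reward(pred_correct: bool) -> float:
    """Per-step reward for a signal / atom pair / bond type prediction."""
    return 1.0 if pred_correct else -1.0

def final_reward(predicted_triples: set, true_triples: set) -> float:
    """Terminal reward once the zero signal ends the episode: the whole
    predicted set of reaction triples must match the ground truth."""
    return 2.0 if predicted_triples == true_triples else -2.0

# One correct step followed by a fully correct episode.
r = step_reward(True) + final_reward({("C1", "O2", "SINGLE")},
                                     {("C1", "O2", "SINGLE")})
```

Note that the terminal reward is set-based: the order in which the triples were predicted does not matter, only the final set.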

We trained our model using Adam (Kingma & Ba, 2015) with an initial learning rate of 0.001 for both USPTO-15k and USPTO. For USPTO-15k, the learning rate is halved whenever Precision@1 has not improved on the validation set for 1,000 steps, until it reaches a minimum value. For USPTO, the decay rate is 0.8 after every 500 steps without improvement, again until the minimum learning rate is reached. Training runs for a fixed maximum number of iterations with a batch size of 20.
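The USPTO-15k decay schedule can be sketched as a pure function of the number of stale validation steps. `decayed_lr` is a hypothetical helper, and the floor value of 1e-6 is an assumption; the paper elides the exact minimum learning rate.

```python
def decayed_lr(initial_lr, decay, patience, steps_without_improvement, min_lr):
    """Multiply the learning rate by `decay` once per `patience` steps
    with no validation improvement, flooring it at `min_lr`."""
    n_decays = steps_without_improvement // patience
    return max(initial_lr * decay ** n_decays, min_lr)

# USPTO-15k schedule: halve after every 1,000 stale steps.
lr = decayed_lr(0.001, 0.5, 1000, 2500, min_lr=1e-6)
```

The USPTO schedule is the same function with `decay=0.8` and `patience=500`.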

a.4 Decoding with Beam Search

For decoding, our model generates a sequence of reaction triples (including the stop signal) by greedily taking the best signal, atom pair and bond type at every step until it outputs a zero signal. In other words, it computes the argmax of the policy at every step $t$. However, this algorithm is not robust for sequence generation because a single error at one step may destroy the entire sequence. To overcome this issue, we employ beam search for decoding.

During beam search, we keep track of the $B$ best subsequences at every step $t$; $B$ is called the beam width. Instead of modeling the conditional distribution of generating an output at the current step $t$, we model the joint distribution of the whole subsequence that has been generated from step 0 to step $t$:

$$p\!\left( s_{0:t}, (u,v)_{0:t}, b_{0:t} \right) = p\!\left( s_t, (u_t, v_t), b_t \mid s_{0:t-1}, (u,v)_{0:t-1}, b_{0:t-1} \right)\, p\!\left( s_{0:t-1}, (u,v)_{0:t-1}, b_{0:t-1} \right)$$

Computing all configurations of this joint distribution at once is very memory demanding, however. Thus, we decompose the first term as follows:

$$p\!\left( s_t, (u_t, v_t), b_t \mid \cdot \right) = p\!\left( s_t \mid \cdot \right)\, p\!\left( (u_t, v_t) \mid s_t, \cdot \right)\, p\!\left( b_t \mid s_t, (u_t, v_t), \cdot \right)$$

At step $t$, we first do beam search for the signal $s_t$, then the atom pair $(u_t, v_t)$ and finally the bond type $b_t$. Algorithm 1 describes beam search in detail. Some notable technicalities are:

  • We only do beam search for the atom pair and the bond type if the prediction is ongoing, i.e., when the signal is nonzero. To keep track of this, we use a boolean vector of length $B$ which is initialized to all true.

  • To avoid beam search favoring short sequences, we normalize the log probability scores by sequence length. This is shown in lines 14, 22, 38 and 54 of Algorithm 1.
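A generic length-normalized beam-search step, illustrating the scoring used in Algorithm 1. `beam_step`, the toy `expand` function, and the tuple-based beam representation are assumptions for illustration, not the paper's three-stage (signal, atom pair, bond) implementation.

```python
import math

def beam_step(beams, expand, beam_width):
    """One beam-search expansion with length-normalized scores.

    `beams` is a list of (sequence, log_prob) pairs; `expand` maps a
    sequence to candidate (token, log_prob) continuations. Candidates
    are ranked by log_prob / length so longer sequences are not
    unfairly penalized.
    """
    candidates = []
    for seq, lp in beams:
        for token, token_lp in expand(seq):
            new_seq = seq + (token,)
            new_lp = lp + token_lp
            candidates.append((new_seq, new_lp, new_lp / len(new_seq)))
    candidates.sort(key=lambda c: c[2], reverse=True)
    return [(seq, lp) for seq, lp, _ in candidates[:beam_width]]

# Toy expansion: two continuations per sequence, regardless of history.
def expand(seq):
    return [("a", math.log(0.6)), ("b", math.log(0.4))]

beams = [((), 0.0)]
for _ in range(3):
    beams = beam_step(beams, expand, beam_width=2)
```

In the model, one prediction step runs this expansion three times in a row (for the signal, the atom pair, and the bond type), pruning back to $B$ beams after each sub-step.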

Algorithm 1: Reaction triple prediction using beam search.

1: Input: a multi-graph of reactant and reagent molecules, the number of bond types $K$, the max number of prediction steps $T$, the beam width $B$
2: Output: the $B$ best subsequences of reaction triples
3: Output: the length-normalized log joint probabilities of the $B$ best subsequences
4: Output: the continuation indicators of the $B$ best subsequences
6: Perform several steps of message passing for all nodes using Eq. (1)
7: Initialize the states of all nodes before decoding
8: Initialize the neighbor sets of all nodes before decoding
9: Initialize the RNN hidden state before decoding (loaded from the saved model)
11: for $t$ from $0$ to $T - 1$ do
12:     Find the top atom pairs using Eqs. (4, 5)
14:     Initialize the length-normalized scores (the superscript denotes sub-step 0)
17:     // Beam search for continuation signals
19:     Store the log joint probabilities for all possible signals
20:     for each beam $i$ from $1$ to $B$ do
21:         Compute the signal scores using Eq. (6)
22:         Add the length-normalized log probabilities of ongoing beams to the candidate list
23:     end for
24:     Sort the candidates in descending order
26:     Keep the output signals of the top $B$ beams
27:     Keep the parent indices of the top $B$ beams
33:     // Beam search for atom pairs
35:     Store the log joint probabilities for all possible atom pairs
36:     for each beam $i$ from $1$ to $B$ do
37:         Compute the atom pair scores using Eq. (7)
38:         Add the length-normalized log probabilities to the candidate list
39:     end for
40:     Sort the candidates in descending order
42:     Keep the output atom pairs of the top $B$ beams
43:     Keep the parent indices of the top $B$ beams
49:     // Beam search for bonds
51:     Store the log joint probabilities for all possible bonds
52:     for each beam $i$ from $1$ to $B$ do
53:         Compute the bond scores using Eq. (8)
54:         Add the length-normalized log probabilities to the candidate list
55:     end for
56:     Sort the candidates in descending order
58:     Keep the output bonds of the top $B$ beams
59:     Keep the parent indices of the top $B$ beams
66:     Update the beam states with the selected signals, atom pairs and bonds
67:     Append the new reaction triples to the beam subsequences
69:     for each beam $i$ from $1$ to $B$ do
70:         Update the node states and neighbor sets for the modified atoms