A benchmark study on reliable molecular supervised learning via Bayesian learning

06/12/2020 ∙ by Doyeong Hwang, et al.

Virtual screening aims to find desirable compounds in a chemical library using computational methods. For this purpose, machine learning models whose outputs can be interpreted as predictive probabilities are beneficial, since a high prediction score then corresponds to a high probability of correctness. In this work, we present a study on the prediction performance and reliability of graph neural networks (GNNs) trained with recently proposed Bayesian learning algorithms. Our work shows that Bayesian learning algorithms yield well-calibrated predictions for various GNN architectures and classification tasks. We also show the implications of reliable predictions for virtual screening, where Bayesian learning may lead to higher success in finding hit compounds.


1 Introduction

Predicting molecular properties from structural input lies at the heart of computational chemistry. Recently, graph neural networks (GNNs), which operate on molecular structure graphs, have been widely used for molecular machine learning. However, due to the vast size of chemical space and the scarcity of labeled examples, prediction models often make incorrect predictions on out-of-distribution (OOD) samples. A simple way to improve model performance is to acquire more labeled examples, but this requires expensive and time-consuming assay experiments.

Since the true labels of OOD samples are uncertain, it is desirable for models to make predictions on them with low confidence (predictive probability). However, neural networks are prone to over-confident predictions (guo2017calibration), meaning that their predictive probability (confidence) is usually higher than their actual probability of being correct. For example, in virtual screening for COVID-19 medications, a compound with a predictive probability of 0.9 will be regarded as having a 90% probability of being active, and thus will be considered for experimental validation. Over-confident predictions, however, lead to an unexpectedly large number of false positive and false negative predictions, which undermines the reliability of neural networks. Thus, evaluating and improving prediction reliability is essential for successful virtual screening with ML models.

In that sense, Bayesian learning is an essential choice for virtual screening, since it yields reliable neural network models. Recent advances in Bayesian inference (welling2011bayesian; gal2016dropout; lakshminarayanan2017simple; maddox2019simple) allow practical approximations of the posterior and of Bayesian marginalization, which have been long-standing challenges in applying Bayesian learning to neural networks. In computer vision, well-established benchmark studies (thulasidasan2019mixup; snoek2019can) have shown that Bayesian learning improves generalization to OOD and corrupted samples. Nevertheless, to the best of our knowledge, there is no benchmark study of Bayesian learning for molecular property prediction that addresses prediction reliability.

In this work, we present a benchmark study on reliable molecular supervised learning with graph neural networks and Bayesian deep learning. We investigate the effectiveness of recent Bayesian learning methods on various GNNs and binary classification tasks. We observe that most of the methods are helpful; in particular, stochastic weight averaging (SWA) and its Gaussian variant (SWAG) (izmailov2018averaging; maddox2019simple) consistently give good prediction results. We also examine the models' prediction behavior through histograms of predictive probability, to clarify the implications of Bayesian inference for virtual screening. We have released our code at https://github.com/AITRICS/mol_reliable_gnn to support urgent needs in molecular machine learning, such as the discovery of COVID-19 medications.

2 Background and Methods

2.1 Evaluating prediction reliability

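Many problems in chemistry are (binary) classification tasks, e.g. whether a molecule is toxic or not, or whether it is biologically active or not. Our goal is to develop classification systems whose output can be interpreted as the probability (or confidence) that the prediction is correct. To this end, we evaluate prediction reliability with the expected calibration error (ECE) (guo2017calibration). The $n$ predictions are grouped into $M$ equal-width bins $B_1, \dots, B_M$ according to their confidence, and

$$\mathrm{ECE} = \sum_{m=1}^{M} \frac{|B_m|}{n} \big|\mathrm{acc}(B_m) - \mathrm{conf}(B_m)\big|, \qquad (1)$$

where $\mathrm{acc}(B_m)$ and $\mathrm{conf}(B_m)$ are the average accuracy and the average confidence of the predictions in bin $B_m$. A low ECE means that predictive probability values match the true probability of correctness. We refer to guo2017calibration for a more precise definition of ECE.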
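To make Eq. 1 concrete, here is a minimal Python sketch for a binary classifier. It assumes 10 equal-width bins and takes the confidence of a prediction to be max(p, 1 - p), the probability assigned to the predicted class; the function names are illustrative and not taken from the released code.

```python
import math


def expected_calibration_error(confidences, correct, n_bins=10):
    """Eq. 1: weighted average gap between accuracy and confidence over equal-width bins."""
    n = len(confidences)
    bins = [[] for _ in range(n_bins)]
    for conf, ok in zip(confidences, correct):
        # Bin m (0-based) covers confidences in (m / n_bins, (m + 1) / n_bins]
        idx = min(max(math.ceil(conf * n_bins) - 1, 0), n_bins - 1)
        bins[idx].append((conf, ok))
    ece = 0.0
    for members in bins:
        if not members:
            continue
        acc = sum(ok for _, ok in members) / len(members)
        avg_conf = sum(c for c, _ in members) / len(members)
        ece += len(members) / n * abs(acc - avg_conf)
    return ece


def binary_ece(probs, labels, n_bins=10):
    """ECE for a binary classifier that outputs p = p(y = 1 | x)."""
    # Confidence is the probability assigned to the predicted class
    confidences = [max(p, 1.0 - p) for p in probs]
    correct = [int((p >= 0.5) == bool(y)) for p, y in zip(probs, labels)]
    return expected_calibration_error(confidences, correct, n_bins)


# An over-confident toy classifier: 0.99 confidence, but only half of its predictions are right
print(binary_ece([0.99, 0.99, 0.99, 0.99], [1, 1, 0, 0]))
```

A perfectly calibrated model would report 0.99 confidence only when it is right 99% of the time, which drives each bin's gap, and hence the ECE, toward zero.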

2.2 Bayesian learning

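A primary goal of Bayesian learning is to infer the posterior distribution $p(\omega \mid \mathcal{D})$ of the model parameters $\omega$ given a dataset $\mathcal{D}$. Then, the predictive distribution of an output $y^*$ given a new input $x^*$ can be computed by Bayesian marginalization:

$$p(y^* \mid x^*, \mathcal{D}) = \int p(y^* \mid x^*, \omega)\, p(\omega \mid \mathcal{D})\, d\omega. \qquad (2)$$

On the other hand, maximum-a-posteriori (MAP) estimation takes the model parameters to be the mode of the posterior distribution:

$$\omega_{\mathrm{MAP}} = \arg\max_{\omega}\, p(\omega \mid \mathcal{D}) = \arg\max_{\omega} \big[\log p(\mathcal{D} \mid \omega) + \log p(\omega)\big]. \qquad (3)$$

By marginalizing over all plausible model parameters drawn from the posterior, Eq. 2 can mitigate the over-confident predictions frequently observed in MAP-estimated models. Predictive uncertainty can also be estimated from the variance of the predictive distribution. For binary classification problems, the predictive uncertainty is given by $\hat{p}(1 - \hat{p})$, where $\hat{p} = p(y^* = 1 \mid x^*, \mathcal{D})$.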
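To illustrate Eq. 2, the following minimal Python sketch approximates the integral by averaging over posterior samples. It assumes a toy one-dimensional logistic model and a hypothetical Gaussian posterior around the MAP weight (neither comes from the paper), and shows how marginalization pulls an extreme MAP prediction away from one.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bayesian_predictive(x, weight_samples):
    """Monte Carlo estimate of Eq. 2 for a toy 1-D logistic model p(y=1|x,w) = sigmoid(w * x)."""
    p_hat = sum(sigmoid(w * x) for w in weight_samples) / len(weight_samples)
    # Predictive uncertainty: variance of the Bernoulli predictive distribution
    return p_hat, p_hat * (1.0 - p_hat)


rng = random.Random(0)
w_map = 2.0
# Hypothetical Gaussian posterior around the MAP weight (std 1.5), for illustration only
posterior_samples = [rng.gauss(w_map, 1.5) for _ in range(1000)]

x_far = 3.0  # an input far from the decision boundary
p_map = sigmoid(w_map * x_far)
p_bayes, uncertainty = bayesian_predictive(x_far, posterior_samples)
print(f"MAP: {p_map:.3f}, Bayesian: {p_bayes:.3f}, uncertainty: {uncertainty:.3f}")
```

Because the sigmoid saturates, samples with a smaller weight lower the average much more than samples with a larger weight can raise it, so the marginalized probability ends up noticeably less extreme than the MAP output.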

Since exact computation of the posterior is intractable for most neural networks, a variety of approximate Bayesian inference methods have been proposed. We consider MAP estimation, Deep Ensemble (lakshminarayanan2017simple), Monte Carlo dropout (MC-DO) (gal2016dropout), Bayes by Backprop (BBB) (blundell2015weight), Stochastic Gradient Langevin Dynamics (SGLD) (welling2011bayesian), Stochastic Weight Averaging (SWA) (izmailov2018averaging), and Stochastic Weight Averaging Gaussian (SWAG) (maddox2019simple) as our baselines, to demonstrate the effectiveness of Bayesian learning for reliable molecular property prediction. Further details on our implementations of Bayesian learning are provided in Appendix A.

2.3 Graph neural networks

We use graph neural networks (GNNs) to handle molecular graph inputs. Various GNN architectures have been proposed, and the recent work by dwivedi2020benchmarking performed ablation studies on node, edge, and graph prediction tasks. We modified their released code (https://github.com/graphdeeplearning/benchmarking-gnns) and implemented the Bayesian learning algorithms on top of it. For our study, we use Graph Convolutional Network (GCN) (kipf2016semi), GraphSAGE (hamilton2017inductive), Graph Isomorphism Network (GIN) (xu2018powerful), Graph Attention Network (GAT) (velivckovic2017graph), and Gated Graph Convolutional Network (GatedGCN) (bresson2017residual). Further details on the implementation are provided in Appendix B.

3 Experiments

Method   BBBP                      BACE                      HIV                     Tox21
         Single      Ensemble      Single      Ensemble      Single     Ensemble     Single     Ensemble
None     17.9 ± 4.8  15.7 ± 4.8    19.3 ± 7.0  15.4 ± 5.7    1.5 ± 0.4  1.5 ± 0.3    9.6 ± 1.5  8.0 ± 1.3
MC-DO    14.9 ± 4.8  15.5 ± 3.9    13.5 ± 4.8  15.5 ± 5.3    0.9 ± 0.2  1.6 ± 0.3    9.7 ± 1.5  8.6 ± 1.4
BBB      14.3 ± 3.7  12.9 ± 2.8    12.9 ± 3.3  12.6 ± 3.6    3.0 ± 0.4  2.4 ± 0.4    9.4 ± 1.4  8.4 ± 1.3
SGLD     14.9 ± 4.3  14.2 ± 4.7    13.1 ± 3.6  12.3 ± 2.8    3.0 ± 0.4  2.5 ± 0.3    9.5 ± 1.3  8.4 ± 1.3
SWA       7.1 ± 2.8   7.0 ± 2.8     8.6 ± 1.4   8.5 ± 2.7    0.9 ± 0.2  1.2 ± 0.3    3.8 ± 1.1  3.7 ± 1.0
SWAG      6.9 ± 2.5   7.0 ± 3.1     8.2 ± 2.0   8.8 ± 2.7    1.0 ± 0.2  0.9 ± 0.3    3.7 ± 1.0  3.6 ± 1.0
Table 1: ECE (%, lower is better) of various Bayesian approaches with GIN on the BBBP, BACE, HIV, and Tox21 prediction tasks. "None" denotes MAP estimation. We report the mean ± standard deviation of results from eight different experiments with scaffold-splitting.

Method   BBBP                      BACE                      HIV                     Tox21
         Single      Ensemble      Single      Ensemble      Single     Ensemble     Single     Ensemble
None     82.7 ± 6.1  85.0 ± 5.5    79.3 ± 6.4  81.7 ± 5.2    75.2 ± 3.3  76.2 ± 3.0  73.5 ± 4.0  75.7 ± 3.9
MC-DO    83.6 ± 5.7  85.1 ± 5.5    80.4 ± 6.0  81.8 ± 5.0    74.6 ± 2.9  76.4 ± 2.9  74.0 ± 3.9  75.5 ± 4.0
BBB      86.9 ± 3.7  88.2 ± 3.7    81.1 ± 5.1  82.1 ± 4.7    72.9 ± 3.3  74.9 ± 2.8  74.0 ± 4.2  75.2 ± 4.0
SGLD     85.0 ± 4.7  86.7 ± 5.3    81.2 ± 5.1  81.2 ± 5.1    72.7 ± 3.4  75.0 ± 2.9  73.9 ± 3.8  75.3 ± 3.9
SWA      91.2 ± 4.1  91.5 ± 3.4    81.4 ± 3.6  81.2 ± 3.6    74.1 ± 3.0  76.2 ± 2.7  78.7 ± 3.6  79.1 ± 3.5
SWAG     91.2 ± 4.2  91.5 ± 3.4    81.5 ± 3.6  81.2 ± 3.6    73.7 ± 3.1  75.1 ± 3.0  78.8 ± 3.7  79.0 ± 3.6
Table 2: AUROC (%, higher is better) of various Bayesian approaches with GIN on the BBBP, BACE, HIV, and Tox21 prediction tasks. "None" denotes MAP estimation. We report the mean ± standard deviation of results from eight different experiments with scaffold-splitting.

In this section, we present experimental results of Bayesian learning algorithms applied to molecular property prediction with several GNN models. Details on the datasets (number of training examples and ratio of positive to negative examples) are given in Appendix C, and the training configurations in Appendix D. Since the molecular datasets in this work are highly sparse and imbalanced, we ran each experiment with eight different random seeds, using scaffold-splitting (wu2018moleculenet) to divide each dataset into training, validation, and test sets.

3.1 Comparison of Bayesian learning algorithms

Tables 1 and 2 show the prediction results for four tasks (BACE, BBBP, HIV, and Tox21), with GIN as the base architecture and various Bayesian learning methods (Ensemble, MC-DO, BBB, SGLD, SWA, and SWAG). Compared with MAP estimation, the Bayesian methods generally improve both prediction reliability (lower ECE) and performance (higher AUROC), although there are exceptions, e.g. BBB and SGLD on HIV. These results indicate that Bayesian learning approaches are beneficial for improving generalization in molecular property prediction tasks. In particular, SWA and SWAG outperform the other Bayesian approaches in most cases. Additional prediction results (accuracy, precision, recall, and F1-score) are given in Figure 3 (see Appendix E).

We also examined whether the Bayesian learning methods are effective for GNN architectures other than GIN. Figure 4 (see Appendix E) shows the prediction performance and reliability of five GNN models (GCN, GIN, GraphSAGE, GAT, and GatedGCN) trained with different Bayesian learning methods on the BACE prediction task. SWA and SWAG consistently give better performance and reliability than MAP. In contrast, the other methods give worse results than MAP for some GNN models; for example, Ensemble, MC-DO, and BBB show higher ECE than MAP. In particular, BBB gives poor results for GCN and GatedGCN, with significantly deteriorated accuracy, recall, and F1-score, even though we applied a scaling factor to the Kullback-Leibler divergence term in the BBB objective, which can be interpreted as a cold posterior, as described in wenzel2020good (see Appendix A for more details). We conjecture that these results stem from the sensitivity of BBB training to the choice of hyperparameters (e.g. the prior length scale). We leave a deeper investigation of BBB for molecular prediction as future work.

3.2 Using Deep Ensemble additionally improves Bayesian learning

wilson2020bayesian showed that an ensemble of SWAG models (Multi-SWAG), which models a multi-modal posterior with an ensemble of variational posteriors, can improve on a single SWAG model. Motivated by Multi-SWAG, Tables 1 and 2 compare the ECE and AUROC of single Bayesian models and of ensembles of Bayesian models on the four prediction tasks. Ensembling further improves both prediction reliability and performance for all Bayesian approaches in most cases, but the gain is smaller than that obtained by using SWA/SWAG. Thus, we conclude that an ensemble of SWA/SWAG models is the best choice for achieving both high prediction performance and reliability, provided that sufficient computing resources are available.

3.3 Prediction behavior of Bayesian models - implications on virtual screening

Figure 1: The histograms of true positive (TP), false positive (FP), true negative (TN), and false negative (FN) predictions from the GIN trained with MAP (top) and SWAG (bottom).
Figure 2: The histograms of predictive probability of BACE-activity from the GIN models trained with MAP (left) and SWAG (right), respectively.

Over-confident prediction is frequently observed in neural networks, especially in MAP-estimated models. Figure 1 confirms that most predictions of the MAP-estimated models lie near zero or one. In contrast, SWAG models effectively mitigate over-confident predictions, as quantified by the ECE values in Table 1: they place far fewer predictions near zero or one, and among those predictions the ratio of true positives/negatives to false positives/negatives is higher. Results obtained with the other Bayesian learning methods are shown in Figures 5, 6, and 7 (see Appendix E).
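Histograms like those in Figure 1 can be built by sorting each prediction into TP, FP, TN, or FN and binning its predictive probability. Below is a minimal Python sketch, assuming a 0.5 decision threshold and 10 equal-width bins; both choices are assumptions, since the text does not state them.

```python
def outcome_histograms(probs, labels, n_bins=10, threshold=0.5):
    """Count TP/FP/TN/FN predictions per equal-width bin of predictive probability."""
    counts = {key: [0] * n_bins for key in ("TP", "FP", "TN", "FN")}
    for p, y in zip(probs, labels):
        predicted_positive = p >= threshold
        outcome = "T" if predicted_positive == bool(y) else "F"
        outcome += "P" if predicted_positive else "N"
        idx = min(int(p * n_bins), n_bins - 1)  # p = 1.0 falls into the last bin
        counts[outcome][idx] += 1
    return counts


hist = outcome_histograms([0.95, 0.97, 0.03, 0.6, 0.4], [1, 0, 0, 1, 1])
print(hist["FP"])
```

An over-confident model concentrates its FP and FN counts in the outermost bins, which is the pattern the MAP rows of Figure 1 are meant to expose.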

In addition, we imitated a virtual screening experiment, using the model trained on the BACE dataset to screen for BACE-active compounds in the lead-like subset of the ZINC database (irwin2005zinc). Although true labels are not available for the ZINC compounds, most of them are likely out-of-distribution with respect to the BACE dataset. Figure 2 shows the histogram of predictive probabilities for the ZINC compounds inferred by the models trained on the BACE dataset. Most predictions of the MAP model lie near zero or one, which is implausible for predictions on OOD samples. In contrast, among the negative and positive predictions of the SWAG model, only a small number have probabilities close to zero or one, respectively. Specifically, only 238 compounds receive a predictive probability close to one. When experimental candidates are selected by their output probabilities, for example by keeping only compounds whose probability exceeds a high threshold, this demonstration implies that Bayesian approaches may achieve a higher success rate in virtual screening than non-Bayesian approaches.

4 Conclusion

In this work, we have presented a benchmark study of reliable molecular prediction models developed with Bayesian learning methods. Our experiments show that recent Bayesian learning methods are notably beneficial for obtaining well-calibrated predictions, which is essential for virtual screening based on the final output of neural networks. We expect our study to be useful for other applications that can benefit from Bayesian principles, such as active learning and continual learning in molecular tasks.

Acknowledgements

This work was supported by the National Research Foundation of Korea (NRF) grant funded by the project NRF-2019M3E5D4065965.

References

Appendix A Background and implementations of Bayesian learning

This section gives brief background on the Bayesian learning methods, together with notes on our implementations, in the following subsections. We use the following notation:

  • $p(\mathcal{D} \mid \omega)$: likelihood function

  • $p(\omega)$: prior on model weights

  • $p(\omega \mid \mathcal{D})$: (true) posterior distribution

  • $q_\theta(\omega)$: variational distribution, where $\theta$ is the variational parameter.

a.1 Deep Ensemble (lakshminarayanan2017simple)

Deep Ensemble combines the outputs of multiple models, each trained from a different random initialization seed. fort2019deep analyzed Deep Ensemble as a Bayesian approach, showing that the individual models in a Deep Ensemble correspond to different modes of a multi-modal posterior distribution, so that combining their predictive outputs can be interpreted as Bayesian marginalization. Furthermore, since a variational distribution approximates only a single mode of the posterior, wilson2020bayesian proposed combining Deep Ensemble with variational methods to increase the expressive power of the approximate posterior.

We used 10 different random initialization seeds for ensembling in practice.

a.2 Monte Carlo Dropout (MC-DO) (gal2016dropout)

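MC-DO enables an efficient, approximate form of Bayesian learning whose training and inference procedures require no significant modification of standard dropout models. The predictive probability of MC-DO for an input $x^*$ is obtained by Monte Carlo averaging of the outputs computed with different model weights generated by stochastic dropout masks:

$$p(y^* = 1 \mid x^*, \mathcal{D}) \approx \frac{1}{T} \sum_{t=1}^{T} p(y^* = 1 \mid x^*, \hat{\omega}_t), \qquad (4)$$

where $\hat{\omega}_t$ denotes the weights under the $t$-th sampled dropout mask and $T$ is the number of MC samples.

For our MC-DO implementation, we used residual dropout in each GNN layer:

$$h_i^{(l+1)} = h_i^{(l)} + \mathrm{Dropout}_p\Big(f^{(l)}\big(h_i^{(l)}, \{h_j^{(l)} : j \in \mathcal{N}(i)\}\big)\Big), \qquad (5)$$

where $f^{(l)}$ is the $l$-th node updating layer (described in Appendix B) and $p$ is the dropout rate. We used $T = 30$ MC samples in our experiments (see Appendix D).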
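The two MC-DO equations can be sketched in plain Python as follows. This is a minimal illustration, assuming inverted-dropout scaling and applying Eq. 5 to a single feature vector rather than to messages on a graph; the layers, readout, and default dropout rate p = 0.1 are hypothetical stand-ins, not values from this work.

```python
import math
import random


def residual_dropout(h, f, p, rng):
    """Eq. 5 for one feature vector: h + Dropout_p(f(h)), using inverted-dropout scaling."""
    out = f(h)
    keep = 1.0 - p
    return [hi + (oi / keep if rng.random() < keep else 0.0) for hi, oi in zip(h, out)]


def mc_dropout_predict(x, layers, readout, p=0.1, T=30, seed=0):
    """Eq. 4: average sigmoid outputs over T passes, each with freshly sampled dropout masks."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(T):
        h = list(x)
        for f in layers:
            h = residual_dropout(h, f, p, rng)
        total += 1.0 / (1.0 + math.exp(-readout(h)))
    return total / T


def update_layer(h):
    """Hypothetical stand-in for a node updating layer f^{(l)}."""
    return [math.tanh(0.5 * v) for v in h]


def readout(h):
    """Hypothetical stand-in for the readout and linear classifier (returns a logit)."""
    return sum(h) - 1.0


layers = [update_layer] * 4
print(mc_dropout_predict([0.2, -0.1, 0.4], layers, readout))
```

Keeping dropout active at test time is what turns a standard dropout network into a sampler over $\hat{\omega}_t$; with p = 0 every pass is identical and the estimate collapses to a single deterministic prediction.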

a.3 Bayes By Backprop (BBB) (blundell2015weight)

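Variational Bayes approximates the true but intractable posterior $p(\omega \mid \mathcal{D})$ with a variational posterior $q_\theta(\omega)$ by minimizing the Kullback-Leibler (KL) divergence between the two distributions:

$$\theta^* = \arg\min_{\theta} \mathrm{KL}\big(q_\theta(\omega) \,\Vert\, p(\omega \mid \mathcal{D})\big) = \arg\max_{\theta} \Big\{ \mathbb{E}_{q_\theta(\omega)}\big[\log p(\mathcal{D} \mid \omega)\big] - \mathrm{KL}\big(q_\theta(\omega) \,\Vert\, p(\omega)\big) \Big\}, \qquad (6)$$

where the objective on the right-hand side is referred to as the evidence lower bound (ELBO). blundell2015weight assumed a Gaussian variational posterior and derived an analytic expression of the KL-divergence term, which allows the variational distribution to be learned by backpropagation.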

We minimized the analytic expression of the negative ELBO given in blundell2015weight, but multiplied the KL divergence between the variational posterior and the prior by a scaling factor, which was significantly helpful for training various GNNs with BBB. We refer to blundell2015weight and wenzel2020good for more details on re-weighting the KL term. We used the open-source Python library ‘Blitz - Bayesian Layers in Torch Zoo’ (https://github.com/piEsposito/blitz-bayesian-deep-learning) to implement BBB. Also, we used a single Gaussian prior, whereas blundell2015weight proposed a Gaussian mixture prior.
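The per-weight pieces of such a re-weighted objective can be sketched as follows. This minimal Python sketch does not use the Blitz API. It assumes a factorized Gaussian posterior with a softplus-parameterized standard deviation and a single zero-mean Gaussian prior, for which the KL term has a closed form. The parameter name `kl_weight` is hypothetical, and the value 0.1 in the usage line is arbitrary; the scaling factor actually used is not given in this text.

```python
import math
import random


def softplus(x):
    """Numerically stable log(1 + exp(x)); maps an unconstrained rho to a positive std."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sample_weight(mu, rho, rng):
    """Reparameterization trick: w = mu + softplus(rho) * eps with eps ~ N(0, 1)."""
    return mu + softplus(rho) * rng.gauss(0.0, 1.0)


def gaussian_kl(mu, sigma, prior_sigma):
    """Closed-form KL(N(mu, sigma^2) || N(0, prior_sigma^2)) for one weight."""
    return (math.log(prior_sigma / sigma)
            + (sigma ** 2 + mu ** 2) / (2.0 * prior_sigma ** 2) - 0.5)


def scaled_negative_elbo(nll, mus, rhos, prior_sigma=1.0, kl_weight=1.0):
    """Negative ELBO of Eq. 6 with the KL term multiplied by kl_weight (hypothetical name)."""
    kl = sum(gaussian_kl(m, softplus(r), prior_sigma) for m, r in zip(mus, rhos))
    return nll + kl_weight * kl


rng = random.Random(0)
mus, rhos = [0.3, -0.2], [-3.0, -3.0]
weights = [sample_weight(m, r, rng) for m, r in zip(mus, rhos)]
print(weights, scaled_negative_elbo(nll=1.2, mus=mus, rhos=rhos, kl_weight=0.1))
```

Here `nll` stands for a Monte Carlo estimate of the expected negative log-likelihood under sampled weights. With a Gaussian mixture prior, as in blundell2015weight, the KL term would no longer have this closed form and would typically be estimated by sampling.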

a.4 Stochastic Gradient Langevin Dynamics (SGLD) (welling2011bayesian)

SGLD can be thought of as connecting Markov chain Monte Carlo (MCMC) with stochastic gradient descent: the weight transitions of MCMC are driven by stochastic gradients of the log posterior. At time step $t$, the update rule for weights drawn from the posterior is

$$\Delta\omega_t = \frac{\epsilon_t}{2}\Big(\nabla_\omega \log p(\omega_t) + \frac{N}{n} \sum_{(x, y) \in \mathcal{D}_t} \nabla_\omega \log p(y \mid x, \omega_t)\Big) + \eta_t, \qquad \eta_t \sim \mathcal{N}(0, \epsilon_t), \qquad (7)$$

where $\epsilon_t$ is the step size (learning rate), $N$ is the number of training examples, $n$ is the mini-batch size, $\mathcal{D}_t$ is the mini-batch of training examples drawn from the dataset $\mathcal{D}$, and $\eta_t$ is Gaussian noise.

For our implementation, we followed preconditioned SGLD (li2016preconditioned), which is known to stabilize SGLD training. We did not sample model weights during the first 100 epochs (burn-in), and then sampled weights every two epochs over the following 100-epoch sampling phase.
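The update of Eq. 7, together with burn-in and thinning, can be illustrated on a toy problem whose posterior is known in closed form. This is a minimal sketch assuming vanilla SGLD with a constant step size (not the preconditioned variant used in this work), applied to the mean of unit-variance Gaussian data under a Gaussian prior; all hyperparameter values are illustrative.

```python
import math
import random


def sgld_gaussian_mean(data, prior_std=10.0, step_size=1e-3, batch_size=10,
                       n_iters=5000, burn_in=1000, thin=10, seed=0):
    """Vanilla SGLD (Eq. 7) for the mean mu of N(mu, 1) data under a N(0, prior_std^2) prior."""
    rng = random.Random(seed)
    n_total = len(data)
    mu = 0.0
    samples = []
    for t in range(n_iters):
        batch = rng.sample(data, batch_size)
        grad_log_prior = -mu / prior_std ** 2
        # Rescale the mini-batch gradient by N / n to estimate the full-data gradient
        grad_log_lik = (n_total / batch_size) * sum(x - mu for x in batch)
        noise = rng.gauss(0.0, math.sqrt(step_size))  # eta_t ~ N(0, eps_t)
        mu += 0.5 * step_size * (grad_log_prior + grad_log_lik) + noise
        # Discard burn-in iterations, then keep every `thin`-th iterate
        if t >= burn_in and (t - burn_in) % thin == 0:
            samples.append(mu)
    return samples


data_rng = random.Random(42)
data = [data_rng.gauss(2.0, 1.0) for _ in range(100)]
samples = sgld_gaussian_mean(data)
# Closed-form posterior mean for this conjugate toy model, for comparison
posterior_mean = sum(data) / (len(data) + 1.0 / 10.0 ** 2)
print(sum(samples) / len(samples), posterior_mean)
```

The injected noise has variance equal to the step size, which is what makes the iterates wander around the posterior instead of collapsing to the MAP point as plain SGD would.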

a.5 Stochastic Weight Averaging (SWA) (izmailov2018averaging) and Stochastic Weight Averaging Gaussian (SWAG) (maddox2019simple)

SWA and SWAG perform Bayesian learning by sampling model weights along the trajectory of stochastic gradient descent (SGD); their theoretical foundation connects the dynamics of the weight parameters on the loss surface to the Bayesian posterior. Both algorithms consist of two steps: i) a preconditioning step that moves the model parameters to a (sub-)optimal point of the loss surface, and ii) a sampling step that collects the weight parameters generated by the SGD optimizer near that (sub-)optimal point.

While both approaches follow these two steps, SWA uses Polyak-Ruppert weight averaging (polyak1992acceleration) to obtain the final model weights, whereas SWAG approximates the posterior with a Gaussian distribution whose mean and covariance are estimated from the sampled weights.
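The difference between the two can be sketched with running moments of the collected weight snapshots. This minimal Python sketch implements only the diagonal-covariance variant (SWAG-Diagonal); full SWAG, as described in maddox2019simple, adds a low-rank covariance term. The class name is hypothetical, and `scale` plays the role of the covariance scaling factor mentioned in Appendix D.

```python
import math
import random


class SwagDiagonal:
    """Running SWA mean and diagonal SWAG covariance over collected weight snapshots."""

    def __init__(self, dim):
        self.n = 0
        self.mean = [0.0] * dim     # running first moment: the SWA weights
        self.sq_mean = [0.0] * dim  # running second moment, for the diagonal variance

    def collect(self, weights):
        self.n += 1
        for i, w in enumerate(weights):
            self.mean[i] += (w - self.mean[i]) / self.n
            self.sq_mean[i] += (w * w - self.sq_mean[i]) / self.n

    def swa_weights(self):
        return list(self.mean)

    def sample(self, rng, scale=1.0):
        """Draw one weight vector from N(mean, scale * diag(sq_mean - mean^2))."""
        out = []
        for m, s2 in zip(self.mean, self.sq_mean):
            var = max(s2 - m * m, 1e-12)  # clamp tiny negative values caused by rounding
            out.append(m + math.sqrt(scale * var) * rng.gauss(0.0, 1.0))
        return out


rng = random.Random(0)
swag = SwagDiagonal(dim=2)
for snapshot in ([1.0, 0.0], [3.0, 2.0]):  # hypothetical snapshots from the sampling step
    swag.collect(snapshot)
print(swag.swa_weights())  # [2.0, 1.0]
print(swag.sample(rng, scale=1.0))
```

SWA predicts with the single averaged weight vector, whereas SWAG draws several samples (30 in Appendix D) and averages their predictions, as in Eq. 2.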

Appendix B Background and implementations of graph neural networks

This section gives a brief introduction to graph neural networks (GNNs) and describes our implementations of the GNNs studied in this work. We consider a molecular graph $G = (V, E)$ whose node features are $v_i$ for $i \in V$ and whose edge features are $e_{ij}$ for $(i, j) \in E$.

The $l$-th GNN layer updates the $i$-th node features from $h_i^{(l)}$ to $h_i^{(l+1)}$, where $h_i^{(l)} \in \mathbb{R}^{d_l}$, and its update rule is

$$h_i^{(l+1)} = f^{(l)}\big(h_i^{(l)}, \{h_j^{(l)} : j \in \mathcal{N}(i)\}\big), \qquad (8)$$

where $\mathcal{N}(i)$ is the set of nodes adjacent to the $i$-th node and $f^{(l)}$ is the $l$-th node updating layer, whose form is described in the following subsections. In this work, we used the same dimension for all GNN layer outputs, i.e. $d_l = d$ for all $l$.

After applying a total of $L$ node updating layers, the readout layer aggregates the node features into the graph feature $z \in \mathbb{R}^{d_z}$:

$$z = \sum_{i \in V} W h_i^{(L)}, \qquad (9)$$

where $W \in \mathbb{R}^{d_z \times d}$ is a weight parameter for linear transformation.

Finally, the predictive label is given by

$$\hat{y} = \mathrm{sigmoid}\big(w^\top z + b\big), \qquad (10)$$

where $w$ and $b$ are the weight and bias parameters of the linear classifier.

b.1 Graph Convolutional Network (GCN) (kipf2016semi)

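GCN sums the features of a node and its neighbors and multiplies the result by a weight matrix to update the node features:

$$\begin{aligned} h_i^{(l+1)} &= \sigma\Big(W^{(l)}\big(h_i^{(l)} + \sum_{j \in \mathcal{N}(i)} h_j^{(l)}\big)\Big) \\ &= \sigma\Big(W^{(l)} \sum_{j \in \mathcal{N}(i) \cup \{i\}} h_j^{(l)}\Big), \end{aligned} \qquad (11)$$

where $W^{(l)}$ is a weight parameter and $\sigma$ is a nonlinear activation.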

b.2 Graph Isomorphism Network (GIN) (xu2018powerful)

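The original node update of GIN-$\epsilon$ is

$$h_i^{(l+1)} = \sigma\Big(W_2^{(l)}\, \sigma\Big(W_1^{(l)}\big((1 + \epsilon^{(l)})\, h_i^{(l)} + \sum_{j \in \mathcal{N}(i)} h_j^{(l)}\big)\Big)\Big), \qquad (12)$$

where $\epsilon^{(l)}$ is a learnable parameter or a fixed number, and $W_1^{(l)}$, $W_2^{(l)}$ are weight parameters. Setting $\epsilon^{(l)} = 0$, Eq. 12 reduces to

$$h_i^{(l+1)} = \sigma\Big(W_2^{(l)}\, \sigma\Big(W_1^{(l)} \sum_{j \in \mathcal{N}(i) \cup \{i\}} h_j^{(l)}\Big)\Big). \qquad (13)$$

Comparing the second line of Eq. 11 with Eq. 13, the only difference is whether a one-layer or a two-layer perceptron is applied to the aggregated features. Note that, in contrast to the implementation in dwivedi2020benchmarking, we did not use batch normalization (ioffe2015batch) for our GIN.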
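The one-layer versus two-layer distinction can be made concrete with a dependency-free Python sketch of the GCN update (second line of Eq. 11), the GIN update with ε = 0 (Eq. 13), and the sum readout (Eq. 9). It assumes ReLU for σ and dense list-of-rows weight matrices; the toy graph and weights are hypothetical.

```python
def matvec(W, v):
    """Dense matrix-vector product with W given as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def relu(v):
    return [max(0.0, x) for x in v]


def self_and_neighbor_sum(h, neighbors, i):
    """sum over j in N(i) ∪ {i} of h_j, the aggregate shared by Eq. 11 and Eq. 13."""
    total = list(h[i])
    for j in neighbors[i]:
        total = [a + b for a, b in zip(total, h[j])]
    return total


def gcn_layer(h, neighbors, W):
    """Second line of Eq. 11: a one-layer perceptron on the aggregate."""
    return [relu(matvec(W, self_and_neighbor_sum(h, neighbors, i))) for i in range(len(h))]


def gin_layer(h, neighbors, W1, W2):
    """Eq. 13 (GIN with eps = 0): a two-layer perceptron on the same aggregate."""
    return [relu(matvec(W2, relu(matvec(W1, self_and_neighbor_sum(h, neighbors, i)))))
            for i in range(len(h))]


def sum_readout(h, W):
    """Eq. 9: z = sum_i W h_i."""
    z = [0.0] * len(W)
    for hi in h:
        z = [a + b for a, b in zip(z, matvec(W, hi))]
    return z


# A hypothetical 3-atom chain 0-1-2 with 2-dimensional node features
neighbors = {0: [1], 1: [0, 2], 2: [1]}
h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W = [[0.5, -1.0], [1.0, 0.5]]
identity = [[1.0, 0.0], [0.0, 1.0]]
print(gcn_layer(h, neighbors, W))
print(gin_layer(h, neighbors, W, identity))  # equals the GCN output when W2 is the identity
```

Setting the second GIN weight matrix to the identity collapses Eq. 13 to the GCN update, which is exactly the difference the text points out.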

b.3 Graph Sample and Aggregate (GraphSAGE) (hamilton2017inductive)

Our implementation of GraphSAGE updates the node representations as

$$h_i^{(l+1)} = \sigma\Big(W^{(l)} \Big[\, h_i^{(l)} \,\Big\Vert\, \sum_{j \in \mathcal{N}(i)} h_j^{(l)} \Big]\Big), \qquad (14)$$

where $W^{(l)}$ is a weight parameter and $\Vert$ denotes concatenation. hamilton2017inductive proposed mean, sum, max, and LSTM aggregation, and we adopted sum aggregation among them. Also, we did not normalize the node features by their L2-norm, since doing so can reduce the model's ability to distinguish non-isomorphic graphs.

                      BACE        BBBP        HIV            Tox21
Task type             Binary classification
Number of samples     1,513       2,050       41,127         7,831
Negatives:Positives   822:691     483:1,567   39,684:1,443   -
Number of tasks       1           1           1              12
Table 3: Specifications of the datasets used in this work

b.4 Graph Attention Network (GAT) (velivckovic2017graph)

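The preceding GNNs, i.e. GCN, GIN, and GraphSAGE, can be categorized as isotropic node updating methods, in that neighbor nodes' features are aggregated with equal importance. In contrast, GAT adopts a multi-head attention mechanism (bahdanau2014neural; vaswani2017attention) for anisotropic node updating, in which neighbor nodes' features are aggregated with learned attention coefficients. The update rule of GAT is

$$h_i^{(l+1)} = \Big\Vert_{k=1}^{K} \sigma\Big(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{k} W^{k} h_j^{(l)}\Big), \qquad (15)$$

where $K$ is the number of attention heads, $W^{k} \in \mathbb{R}^{(d/K) \times d}$, and the attention coefficient $\alpha_{ij}^{k}$ is defined as

$$e_{ij}^{k} = \mathrm{LeakyReLU}\big(a^{k\top} \big[W^{k} h_i^{(l)} \,\Vert\, W^{k} h_j^{(l)}\big]\big), \qquad (16)$$
$$\alpha_{ij}^{k} = \frac{\exp(e_{ij}^{k})}{\sum_{j' \in \mathcal{N}(i)} \exp(e_{ij'}^{k})}, \qquad (17)$$

where $a^{k}$ is a weight parameter. We set the number of attention heads to 4.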
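A single attention head of GAT can be sketched in plain Python as below. This minimal sketch assumes a LeakyReLU slope of 0.2 (the value commonly used for GAT, not stated here) and a neighbor list that includes self-loops; the graph, W, and a are hypothetical. With K = 4 heads, four such outputs (after σ) would be concatenated as in Eq. 15.

```python
import math


def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x


def attention_coefficients(h, neighbors, i, W, a):
    """Eqs. 16-17 for one head: softmax over j in N(i) of LeakyReLU(a^T [W h_i || W h_j])."""
    wh_i = matvec(W, h[i])
    scores = []
    for j in neighbors[i]:
        concat = wh_i + matvec(W, h[j])  # list concatenation implements [W h_i || W h_j]
        scores.append(leaky_relu(sum(ak * ck for ak, ck in zip(a, concat))))
    m = max(scores)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return {j: e / total for j, e in zip(neighbors[i], exps)}


def gat_head(h, neighbors, i, W, a):
    """One head of Eq. 15 before the nonlinearity: sum_j alpha_ij W h_j."""
    alpha = attention_coefficients(h, neighbors, i, W, a)
    out = [0.0] * len(W)
    for j, coef in alpha.items():
        out = [o + coef * x for o, x in zip(out, matvec(W, h[j]))]
    return out


neighbors = {0: [0, 1, 2], 1: [1, 0], 2: [2, 0]}  # hypothetical graph with self-loops
h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W = [[1.0, 0.5], [-0.5, 1.0]]
a = [0.3, -0.2, 0.8, 0.1]
print(attention_coefficients(h, neighbors, 0, W, a))
print(gat_head(h, neighbors, 0, W, a))
```

When the attention vector is zero, every neighbor gets the same coefficient and the head reduces to an isotropic mean aggregation, which highlights what the learned coefficients add.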

b.5 Gated Graph Convolutional Network (GatedGCN) (bresson2017residual)

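GatedGCN is another anisotropic node updating method; it uses edge features to compute edge gates and multiplies them element-wise with the neighbor nodes' features:

$$h_i^{(l+1)} = \mathrm{ReLU}\Big(A h_i^{(l)} + \sum_{j \in \mathcal{N}(i)} \eta_{ij} \odot B h_j^{(l)}\Big), \qquad (18)$$

where $\odot$ denotes element-wise multiplication and $A$, $B$ are weight parameters. The gate coefficient $\eta_{ij}$ is given by

$$\eta_{ij} = \frac{\mathrm{sigmoid}(\hat{e}_{ij})}{\sum_{j' \in \mathcal{N}(i)} \mathrm{sigmoid}(\hat{e}_{ij'}) + \varepsilon}, \qquad (19)$$
$$\hat{e}_{ij} = C e_{ij} + D h_i^{(l)} + E h_j^{(l)}, \qquad (20)$$

where $C$, $D$, $E$ are weight parameters and $\varepsilon$ is a small constant for numerical stability. Note again that we did not use batch normalization in our implementation of GatedGCN.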

Appendix C Details on molecular datasets

Table 3 summarizes the specifications of the datasets used for training the models. We downloaded all four datasets from the MoleculeNet homepage (http://moleculenet.ai/datasets-1). Each dataset is scaffold-split into training, validation, and test sets with an 80:10:10 ratio. Note that the Tox21 dataset consists of 12 different binary classification tasks, each with a different number of positive and negative samples.
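Scaffold-splitting keeps all molecules that share a core scaffold in the same subset, so the test set probes generalization to unseen chemotypes. The sketch below is a minimal Python illustration, assuming the greedy "largest scaffold groups first" scheme common in MoleculeNet-style splitters; this text does not say which variant was used. It also assumes the scaffold keys are precomputed, for example as Bemis-Murcko scaffolds, which are typically computed with RDKit's MurckoScaffold module.

```python
from collections import defaultdict


def scaffold_split(scaffolds, frac_train=0.8, frac_valid=0.1):
    """Greedy scaffold split: each scaffold group goes entirely to one subset."""
    groups = defaultdict(list)
    for idx, scaffold in enumerate(scaffolds):
        groups[scaffold].append(idx)
    # Largest scaffold groups first; ties broken by first occurrence in the dataset
    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))
    n = len(scaffolds)
    train_cut, valid_cut = frac_train * n, (frac_train + frac_valid) * n
    train, valid, test = [], [], []
    for group in ordered:
        if len(train) + len(group) <= train_cut:
            train.extend(group)
        elif len(train) + len(valid) + len(group) <= valid_cut:
            valid.extend(group)
        else:
            test.extend(group)
    return train, valid, test


# Hypothetical precomputed scaffold keys for ten molecules
scaffolds = ["A"] * 5 + ["B"] * 3 + ["C", "D"]
print(scaffold_split(scaffolds))  # ([0, 1, 2, 3, 4, 5, 6, 7], [8], [9])
```

Because the rare scaffolds end up in the validation and test sets, this split is deliberately harder than a random split and exposes OOD behavior.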

Appendix D Details on model training

In this section, we describe the hyperparameter settings used for the implementation of GNNs and Bayesian learning methods.

For all GNN models, we set the dimension of the node features ($d$ in Eq. 8) to 128, the dimension of the graph feature ($d_z$ in Eq. 9) to 256, and the number of node updating layers ($L$) to 4.

Since SGLD, SWA, and SWAG use gradient-descent updates to sample weights from the posterior, we used different optimizers and learning rate schedules for different Bayesian learning methods. For MAP, Ensemble, MC-DO, and BBB, we used the Adam optimizer and trained models for 200 epochs with an initial learning rate that is decayed by a constant factor at the 80th and 160th epochs. For SWA and SWAG, we used the SGD optimizer and trained models for 250 epochs with an initial learning rate of 0.1. The preconditioning step is set to 150 epochs; the learning rate is decreased to 0.01 from epoch 75 to epoch 150, before the cyclic learning rate is applied. Once the sampling step starts, a cyclic learning rate between 0.01 and 0.001 is applied, following the method proposed in (garipov2018loss). Model weights were collected every 4 epochs during the sampling step. For SWAG, the scaling factor applied to the SWAG posterior covariance is set to 1.0.
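The SWA/SWAG schedule above can be written as a per-epoch function. This is a minimal Python sketch under several assumptions that are not stated in the text: epochs are 0-indexed, the rate decays linearly between epochs 75 and 150, the schedule is evaluated once per epoch rather than per iteration, and each cycle lasts 4 epochs so that weights are collected at the end of each cycle, where the rate is lowest. Within a cycle the rate moves linearly from 0.01 toward 0.001, in the spirit of garipov2018loss.

```python
def swa_learning_rate(epoch, lr_init=0.1, lr_mid=0.01, lr_low=0.001,
                      decay_start=75, sampling_start=150, cycle=4):
    """Per-epoch learning rate for the SWA/SWAG schedule (epochs are 0-indexed)."""
    if epoch < decay_start:
        return lr_init
    if epoch < sampling_start:
        # Assumed linear decay from lr_init to lr_mid during epochs 75-149
        frac = (epoch - decay_start) / (sampling_start - decay_start)
        return lr_init + frac * (lr_mid - lr_init)
    # Cyclic phase: within each cycle the rate moves linearly from lr_mid toward lr_low
    t = ((epoch - sampling_start) % cycle + 1) / cycle
    return (1.0 - t) * lr_mid + t * lr_low


def is_snapshot_epoch(epoch, sampling_start=150, cycle=4):
    """Collect weights at the end of each cycle, where the learning rate is lowest."""
    return epoch >= sampling_start and (epoch - sampling_start + 1) % cycle == 0


snapshots = [e for e in range(250) if is_snapshot_epoch(e)]
print(len(snapshots), snapshots[:3])  # 25 [153, 157, 161]
```

Under these assumptions the 100-epoch sampling step yields 25 weight snapshots, from which the SWA average or the SWAG moments would be computed.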

For the prior distribution, we adopted an explicit Gaussian prior for BBB and, for the other methods, weight decay, which corresponds to a Gaussian prior on the weights.

Lastly, for Bayesian marginalization in MC-Dropout and SWAG, we sampled 30 model weights and averaged their outputs. For BBB, we sampled 5 model weights during training and 100 during evaluation.

Appendix E Additional experimental results

In this section, we show the following additional results supporting the main text:

  • Figure 3 shows the prediction reliability and performance of the GIN model for the four molecular property prediction tasks: BACE, BBBP, HIV, and Tox21.

  • Figure 4 shows the prediction reliability and performance of various GNN models for the BACE prediction task.

  • Figure 5, 6, and 7 show the histogram of predictive probability categorized by true positive, false positive, true negative, and false negative predictions for the BACE, BBBP, and HIV prediction tasks.

Figure 3: The prediction reliability (ECE; lower is better) and performance (accuracy, AUROC, precision, recall, and F1-score; higher is better) of the GIN model for the BACE, BBBP, HIV, and Tox21 prediction tasks. We report the mean and standard deviation of results from eight different experiments with scaffold-splitting of the datasets.
Figure 4: The prediction reliability (ECE; lower is better) and performance (accuracy, AUROC, precision, recall, and F1-score; higher is better) of the five different GNN models on the BACE prediction task. We report the mean and standard deviation of results from eight different experiments with scaffold-splitting.
Figure 5: The histogram of true positive (TP), false positive (FP), true negative (TN), and false negative (FN) results for the BACE prediction task.
Figure 6: The histogram of true positive (TP), false positive (FP), true negative (TN), and false negative (FN) results for the BBBP prediction task.
Figure 7: The histogram of true positive (TP), false positive (FP), true negative (TN), and false negative (FN) results for the HIV prediction task.