Reliable uncertainty estimate for antibiotic resistance classification with Stochastic Gradient Langevin Dynamics

11/27/2018 ∙ by Md-Nafiz Hamid, et al. ∙ Iowa State University of Science and Technology

Antibiotic resistance monitoring is of paramount importance in the face of this ongoing global epidemic. Deep learning models trained with traditional optimization algorithms (e.g. Adam, SGD) provide poor posterior estimates when tested against out-of-distribution (OoD) antibiotic resistant/non-resistant genes. In this paper, we introduce a deep learning model trained with Stochastic Gradient Langevin Dynamics (SGLD) to classify antibiotic resistant genes. The model provides better uncertainty estimates when tested against OoD data compared to traditional optimization methods such as Adam.

1 Introduction

Antibiotic resistance is a global scourge that is taking an increasing toll in mortality and morbidity, in both nosocomial and community-acquired infections Neu1992Crisis; who2104amr. A growing number of once easily treatable infectious diseases such as tuberculosis, gonorrhea, and pneumonia are becoming harder to treat as the scope of effective drugs is shrinking. The CDC estimates that 2,000,000 illnesses occur and 23,000 people die annually from antibiotic resistance in the US alone cdcamr. The overuse of antibiotics in health care and agriculture is exacerbating the problem to the point that the World Health Organization considers antibiotic resistance “one of the biggest threats to global health, food security and human development today”. Identifying genes associated with antibiotic resistance is an important first step towards dealing with the problem Brown2016Antibacterial and towards providing narrow-spectrum treatment, targeted solely against the types of resistance displayed. This is especially true when dealing with genes acquired from human or environmental metagenomic samples Perry2014Antibiotic. Rapid identification of the class of antibiotic resistance that may exist in a given environmental or clinical microbiome sample can provide immediate guidance for treatment and prevention.

In this study, we developed a deep neural network that predicts antibiotic resistance across 15 classes from protein sequences. It can be useful for identifying the resistance present in metagenomic samples for the purpose of providing a focused drug treatment. Traditional methods kleinheinz2014applying; davis2016antimicrobial; pal2016structure for identifying antibiotic-resistant genes usually take an alignment-based best-hit approach, which causes them to produce many false negatives arango2018deeparg. Recently, a deep learning based approach was developed that used normalized bit scores, acquired after aligning against known antibiotic resistant genes, as features arango2018deeparg. In contrast, our model only uses the raw protein sequence as its input. At the same time, neural networks are known to provide high confidence scores on inputs that come from a different probability distribution than the one the model was trained on palacci2018scalable; choi2018generative. This can have disastrous consequences in sensitive applications such as health care or self-driving systems. Here, we develop two deep learning models, one trained with ADAM and one with SGLD. Both models achieve high accuracy on the test set in predicting antibiotic resistance solely from the protein sequence, but we show that the model trained with SGLD is better equipped to handle OoD data, i.e., it assigns low probabilities to sequences from proteins that are not related to antibiotic resistance or that belong to classes not included in training.

2 Dataset and Model

Dataset

We used the dataset curated in the DeepArg study arango2018deeparg. Briefly, the dataset was created from the CARD jia2016card, ARDB liu2008ardb, and UNIPROT uniprot2018uniprot databases with a combination of computational and manual curation. The original dataset has 14,974 protein sequences that are resistant to 34 different antibiotics (our classes in the multi-class classification task). Nineteen classes had 11 or fewer training sequences; we discarded these classes and were left with 15 classes with a total of 14,907 protein sequences.
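
As a concrete illustration of this filtering step, the following is a minimal sketch assuming the sequences and their resistance-class labels are available in a tabular file; the file name and column names are hypothetical, not from the original study.

    # Minimal sketch of the class-filtering step: keep only antibiotic classes
    # with more than 11 sequences. File and column names are assumptions.
    import pandas as pd

    df = pd.read_csv("deeparg_sequences.csv")            # hypothetical input file
    counts = df["antibiotic_class"].value_counts()
    kept = counts[counts > 11].index                      # 15 classes remain
    df = df[df["antibiotic_class"].isin(kept)]            # ~14,907 sequences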

Model

We used the self-attention based sentence embedding model introduced in Lin_Feng_Santos_Yu_Xiang_Zhou_Bengio_2017. For input, we represented each amino acid in a protein sequence as a size-10 embedding that was randomly initialized and then trained end-to-end. We used a single LSTM layer with 64 units and a dropout value of 0.7. This is followed by the self-attention component, which can be thought of as a feed-forward neural network with one hidden layer of 600 units. This network takes the output of the LSTM layer as input and produces an output of size 100, which is passed through a softmax layer to produce the attention weights. We multiplied the outputs of the LSTM layer by these attention weights to get a weighted view of the LSTM hidden states. The result of this multiplication became the sentence embedding for that protein sequence.
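
For concreteness, below is a minimal PyTorch sketch of this architecture. The layer sizes follow the description above; the class name, the exact attention implementation, the vocabulary size, and the final classification head are our own assumptions for illustration, not the authors' code.

    # Sketch of the self-attentive sentence-embedding classifier (Lin et al. 2017 style).
    import torch
    import torch.nn as nn

    class SelfAttentiveClassifier(nn.Module):
        def __init__(self, n_tokens=25, emb_dim=10, lstm_units=64,
                     attn_hidden=600, attn_hops=100, n_classes=15, dropout=0.7):
            super().__init__()
            self.embed = nn.Embedding(n_tokens, emb_dim)      # amino-acid embeddings, trained end-to-end
            self.lstm = nn.LSTM(emb_dim, lstm_units, batch_first=True)
            self.drop = nn.Dropout(dropout)
            # attention MLP: hidden layer of 600 units, output of size 100
            self.attn = nn.Sequential(
                nn.Linear(lstm_units, attn_hidden),
                nn.Tanh(),
                nn.Linear(attn_hidden, attn_hops),
            )
            self.classify = nn.Linear(attn_hops * lstm_units, n_classes)

        def forward(self, x):                        # x: (batch, seq_len) amino-acid indices
            h, _ = self.lstm(self.embed(x))          # h: (batch, seq_len, lstm_units)
            h = self.drop(h)
            a = torch.softmax(self.attn(h), dim=1)   # attention over sequence positions
            m = torch.einsum("bsr,bsu->bru", a, h)   # weighted view of LSTM hidden states
            return self.classify(m.flatten(1))       # logits over the 15 resistance classes

The sketch shows the single-layer variant described here; the ADAM trained model in Section 3 differs only in using 3 bi-directional LSTM layers.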

Optimization

Typically, neural networks are trained with optimization methods such as Stochastic Gradient Descent (SGD) robbins1985stochastic or its variants such as Adam kingma2014adam, Adagrad duchi2011adaptive, and RMSprop tieleman2012lecture. In SGD, at each iteration a mini-batch from the dataset is used to update the parameters of the neural network. At iteration $t$, given a mini-batch of $n$ training items $x_{t1}, \dots, x_{tn}$ drawn from a dataset of $N$ items, parameters $\theta_t$, and step size $\epsilon_t$, the update is:

$$\Delta\theta_t = \frac{\epsilon_t}{2}\left(\nabla \log p(\theta_t) + \frac{N}{n}\sum_{i=1}^{n}\nabla \log p(x_{ti} \mid \theta_t)\right) \qquad (1)$$

At the same time, SGD and its variants do not capture parameter uncertainty. In contrast, Bayesian approaches such as Markov Chain Monte Carlo (MCMC) robert2004monte techniques do capture uncertainty estimates. One such class of techniques is Langevin dynamics roberts2002langevin, which injects Gaussian noise into the full-dataset version of Equation 1 so that the parameters do not collapse onto the Maximum a posteriori (MAP) solution:

$$\Delta\theta_t = \frac{\epsilon}{2}\left(\nabla \log p(\theta_t) + \sum_{i=1}^{N}\nabla \log p(x_i \mid \theta_t)\right) + \eta_t, \qquad \eta_t \sim \mathcal{N}(0, \epsilon) \qquad (2)$$

However, MCMC techniques require that the algorithm pass over the entire dataset before making each parameter update. This slows down model training and incurs a large computational cost. To remove this problem, Stochastic Gradient Langevin Dynamics (SGLD) was introduced welling2011bayesian, which combines the best of both worlds, i.e. it injects Gaussian noise into the update computed on each mini-batch of training data. In SGLD, during each SGD iteration, Gaussian noise whose variance equals the step size $\epsilon_t$ is injected:

$$\Delta\theta_t = \frac{\epsilon_t}{2}\left(\nabla \log p(\theta_t) + \frac{N}{n}\sum_{i=1}^{n}\nabla \log p(x_{ti} \mid \theta_t)\right) + \eta_t, \qquad \eta_t \sim \mathcal{N}(0, \epsilon_t) \qquad (3)$$

This injection of Gaussian noise has an advantageous side-effect, as it also provides a better calibration of confidence scores of predictions on OoD data. For example, palacci2018scalable showed that an SGLD trained neural network provides low confidence scores when trained on the MNIST lecun2010mnist dataset but tested on the NotMNIST dataset bulatov2011notmnist; whereas an SGD trained neural network still naively provides high confidence scores. We used SGLD to train a neural network to classify protein sequences into their antibiotic resistance classes. In the experiment section, we show that an SGLD trained network provides low confidence scores when predicting on OoD protein sequences while an ADAM trained model still provides high confidence scores.
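
As a rough illustration, one SGLD update (Eq. 3) might look as follows in PyTorch. This is a sketch, not the authors' implementation: it assumes the mini-batch loss has already been scaled to the full-dataset scale (the N/n factor) and includes the prior term (e.g. an L2 penalty), and the function and variable names are illustrative.

    # Sketch of one SGLD update: a gradient step plus Gaussian noise whose
    # variance equals the step size (Eq. 3). Assumes `loss` approximates the
    # negative log posterior (mini-batch term scaled by N/n, plus prior).
    import torch

    def sgld_step(model, loss, step_size):
        model.zero_grad()
        loss.backward()
        with torch.no_grad():
            for p in model.parameters():
                if p.grad is None:
                    continue
                noise = torch.randn_like(p) * step_size ** 0.5   # eta_t ~ N(0, epsilon_t)
                p.add_(-0.5 * step_size * p.grad + noise)        # Langevin update

In practice the step size is decayed over iterations, and predictions can be averaged over later parameter samples; the sketch only shows the update itself.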

3 Experiment

The model trained with ADAM has the same self-attention architecture as the model used for SGLD training, except that it has 3 bi-directional LSTM layers. We used a learning rate of 0.001 with a weight decay value of 0.0001.

We divided our dataset into a 70/20/10% training, validation, and test set split. We trained our model with SGLD on the training dataset, and tuned the hyper-parameters by checking the performance on the validation dataset. Testing on the test dataset was done only once.
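
One way to produce such a split is sketched below. The paper does not state whether the split was stratified by class, so the use of stratification here is an assumption, as is the dataframe carried over from the earlier sketch.

    # Sketch of a 70/20/10% train/validation/test split of the filtered data.
    # Stratification by class is an assumption (it keeps rare classes represented).
    from sklearn.model_selection import train_test_split

    train_df, rest_df = train_test_split(
        df, test_size=0.30, stratify=df["antibiotic_class"], random_state=0)
    val_df, test_df = train_test_split(
        rest_df, test_size=1/3, stratify=rest_df["antibiotic_class"], random_state=0)
    # train_df ~70%, val_df ~20%, test_df ~10% of the sequences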

Table 1 shows the performance of both the SGLD and ADAM trained models on the test set in terms of precision, recall, and F1 score for each class and overall. Overall, the ADAM trained model performs better than the SGLD trained model.

Antibiotics                             SGLD trained model          ADAM trained model          Number of data points
                                        Precision  Recall  F1      Precision  Recall  F1        in test set
Multidrug                               0.68       0.81    0.74    0.84       0.92    0.88      109
Beta Lactam                             0.97       0.93    0.95    0.99       0.96    0.98      519
Aminoglycoside                          0.82       0.82    0.82    0.90       0.97    0.93       87
Rifampin                                1.00       0.67    0.80    1.00       0.67    0.80        3
Tetracycline                            0.68       0.70    0.69    0.86       0.70    0.78       27
Quinolone                               0.75       0.92    0.83    0.80       0.92    0.86       13
Macrolide lincosamide streptogramin     0.93       0.85    0.89    0.95       0.94    0.95      111
Fosfomycin                              0.90       0.93    0.92    1.00       0.93    0.96       29
Polymyxin                               0.97       0.97    0.97    1.00       0.99    0.99       90
Chloramphenicol                         0.78       0.83    0.80    0.98       0.89    0.93       47
Bacitracin                              0.99       0.96    0.98    0.99       0.98    0.99      421
Kasugamycin                             1.00       1.00    1.00    1.00       1.00    1.00        3
Trimethoprim                            0.83       0.62    0.71    0.78       0.88    0.82        8
Sulfonamide                             1.00       1.00    1.00    1.00       1.00    1.00        2
Glycopeptide                            0.49       0.82    0.61    0.53       0.91    0.67       22
Overall                                 0.92       0.91    0.91    0.96       0.95    0.96
Table 1: Comparison between SGLD and ADAM trained models for 15 different classes of antibiotic resistance.

In the next step, we tested both models on OoD samples. For this, we used the 19 antibiotic resistance classes that we did not include in the dataset used for training and testing. These 19 classes contain a total of 67 protein sequences. We also used approximately 19,000 human genes that we can confidently assume are not antibiotic resistance genes.

Before testing on these sequences, our expectation was that an ideal model trained on a different set of classes (in our case, the 15 antibiotic classes used in training and testing) should assign low probabilities to its predictions on out-of-class sequences. The model should also assign low probabilities to human genes that are not antibiotic resistant. Figure 1 shows the distribution of probability scores of the predictions of both the SGLD and ADAM trained models on these sequences. We can see from the figure that in both cases the probability distribution for SGLD is centered around 0.5, whereas for ADAM the distribution is heavily skewed towards probabilities near 1. The ADAM trained model still predicts these OoD sequences to be one of the 15 classes it was trained on with high confidence. In contrast, the SGLD model conveys its uncertainty over its predictions.

(a) Probabilities of predictions on the 67 protein sequences from antibiotic resistance classes that the models were not trained on.
(b) Probabilities of predictions on approximately 19,000 human genes, which are not antibiotic resistance genes. The SGLD trained neural network predicts antibiotic resistance with a much lower probability.
Figure 1: Probabilities assigned to predictions by both the SGLD and ADAM trained models. The SGLD trained model predicts low probability for antibiotic resistance, both for classes not trained on (a) and for genes not associated with antibiotic resistance (b).
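
The OoD evaluation itself amounts to collecting, for each model, the maximum softmax probability of every prediction on the OoD sequences and comparing the resulting distributions. A minimal sketch, with assumed function and variable names:

    # Sketch of the OoD confidence check: collect the predicted-class probability
    # for every OoD sequence; a well-calibrated model should yield low values here.
    import torch

    @torch.no_grad()
    def prediction_confidences(model, loader):
        model.eval()
        confs = []
        for x in loader:                                  # batches of encoded OoD sequences
            probs = torch.softmax(model(x), dim=-1)       # probabilities over the 15 classes
            confs.append(probs.max(dim=-1).values)        # confidence of the predicted class
        return torch.cat(confs)

    # Comparing histograms of prediction_confidences(sgld_model, ood_loader) and
    # prediction_confidences(adam_model, ood_loader) corresponds to Figure 1.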

4 Discussion

In this study, we applied a training optimization method for neural networks that calibrates the prediction probability scores such that OoD samples are assigned low probabilities. We used this SGLD trained neural network for a multi-class classification task: classifying the type of antibiotic resistance from protein sequences. We trained our neural network on 15 classes of antibiotic resistant proteins, and we also trained an ADAM trained neural network on the same 15 classes. The overall F1 score for the ADAM trained model (96%) was higher than that of the SGLD trained model (91%). Yet, when we tested both neural networks on two datasets of protein sequences that we know either belong to antibiotic resistance classes that were not part of our training and testing or are not associated with antibiotic resistance at all, the ADAM trained model still predicted them to be one of the 15 classes with high probability. In contrast, the SGLD trained model provided predictions with much lower probabilities for these proteins. We hypothesize that the Gaussian noise introduced in the SGLD training scheme prevents the neural network from collapsing completely onto the Maximum Likelihood solution. That may also be the reason that training a neural network with SGLD to convergence is difficult compared with training with ADAM and weight decay. However, SGLD lets a discriminative model detect OoD data points and, consequently, assign lower probabilities to its predictions for them. This is an important property, especially when we consider the open world problem in biology, where for any classification task it is hard to collect negative training samples for the machine learning algorithm Dessimoz2013CAFA. One avenue of future research is to investigate how to increase the accuracy of SGLD-like training optimization methods. This might involve changing the structure of the noise being introduced.

5 Funding

The research is based upon work supported, in part, by the Office of the Director of National Intelligence (ODNI), Intelligence Advanced Research Projects Activity (IARPA), via the Army Research Office (ARO) under cooperative Agreement Number W911NF-17-2-0105, and by the National Science Foundation (NSF) grant ABI-1458359. The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies or endorsements, either expressed or implied, of the ODNI, IARPA, ARO, NSF, or the U.S. Government. The U.S. Government is authorized to reproduce and distribute reprints for Governmental purposes notwithstanding any copyright annotation thereon.

References