Recent advances in machine learning have come in the form of complex models that humans struggle to interpret. In response to the black-box nature of these models, a variety of recent work has focused on model interpretability [Guidotti:2018:SME:3271482.3236009]. One particular line of work that has gained much attention is that of feature attribution methods. Given a model and a specific prediction made by that model, these methods assign a numeric value to each input feature of the model, indicating how important that feature was for the given prediction. A variety of such methods have been proposed, and previous work has focused on how such methods can be used to gain insight into model behavior in applications where model trust is critical [zech2018variable, sayres2019using].
A natural question to ask is why a given feature was assigned a specific attribution value for a given prediction. In some settings it is easy for a human to evaluate the sensibility of attribution values; for example, in image classification problems we can overlay attribution values on the original image. However, in many other domains we cannot assess the validity of attribution values so easily. Recent work [ross2017right, AttributionPriors, rieger2019interpretations] has attempted to address this problem by regularizing model explanations to agree with prior human knowledge. For example, Erion et al. [AttributionPriors] propose to regularize image classification models by encouraging the attributions of neighboring pixels to be similar, thereby producing smoother attribution maps. While such regularization methods do lead to noticeably different feature attributions, they suffer from two shortcomings. First, they require a human expert with prior knowledge about the problem domain to construct a rule for the regularization. Second, after a set of feature attributions is produced, we still lack a method for obtaining insight into the reasons behind a given feature attribution beyond simply trusting the regularization procedure.
Moreover, the complexity of deep learning models makes them especially prone to overfitting when trained on small datasets. This limitation has stalled the adoption of deep learning methods in domains where data acquisition is difficult, such as many areas of medicine [chen2019deep].
In certain problem domains we have access to second-order meta-features that characterize the input features used in the prediction task. Such meta-features can potentially encode information on how important a given feature is to the task at hand. By using such meta-features to bias models to focus only on “important” features, we should be able to achieve both greater model accuracy and interpretability. However, we may not know a priori how the meta-features correspond to prediction-level feature importance. As a first step toward solving this problem, we propose the learned attribution prior framework, a novel method for model training that simultaneously learns a relationship between meta-features and feature attribution values, and biases the original model’s predictions to agree with this learned relationship. Empirically we find that models trained using the learned attribution prior framework achieve better performance on low-data tasks with both synthetic and real-world data. Furthermore, we demonstrate how such learned meta-feature-to-feature relationships can lead to more interpretable feature attributions.
2 Learned Attribution Priors
In this section we introduce a formal definition of a learned attribution prior by extending that of the attribution prior originally proposed in [AttributionPriors]. Let $X \in \mathbb{R}^{n \times p}$ denote a training dataset consisting of $n$ samples, each of which has $p$ features. Similarly, let $y \in \mathbb{R}^{n}$ denote the labels for the $n$ samples in our training dataset. In a standard model training setting, we wish to find the model $f_\theta$ in some model class $\mathcal{F}$ that minimizes the average prediction loss $\mathcal{L}(\theta; X, y)$ on our dataset, as determined by some loss function. In order to avoid overfitting on training data, a regularization function $\Omega'$ on the prediction model’s parameters is often included, giving the equation

$$\theta = \operatorname*{arg\,min}_{\theta} \; \mathcal{L}(\theta; X, y) + \lambda' \Omega'(\theta).$$

For a given feature attribution method $\Phi$, an attribution prior is defined as some function $\Omega$ that assigns scalar-valued penalties to $f_\theta$’s feature attributions $\Phi(\theta, X)$ for the input $X$. This leads us to the objective

$$\theta = \operatorname*{arg\,min}_{\theta} \; \mathcal{L}(\theta; X, y) + \lambda' \Omega'(\theta) + \lambda \Omega(\Phi(\theta, X)).$$
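As a concrete, hypothetical instance of $\Omega$, the sketch below uses a one-dimensional analogue of the neighboring-pixel smoothness prior mentioned in the introduction: the total variation between the attributions of adjacent features. The function names and the scalar inputs `pred_loss` and `param_penalty` are illustrative assumptions; in practice these terms come from a trained model and an attribution method.

```python
from typing import Callable, Sequence


def smoothness_prior(attributions: Sequence[float]) -> float:
    """Hypothetical Omega: total variation between attributions of adjacent features."""
    return sum(abs(b - a) for a, b in zip(attributions, attributions[1:]))


def attribution_prior_objective(
    pred_loss: float,
    param_penalty: float,
    attributions: Sequence[float],
    omega: Callable[[Sequence[float]], float],
    lam_param: float,
    lam_attr: float,
) -> float:
    """L(theta; X, y) + lambda' * Omega'(theta) + lambda * Omega(Phi(theta, X))."""
    return pred_loss + lam_param * param_penalty + lam_attr * omega(attributions)


# Smooth attributions are penalized less than jagged ones.
print(smoothness_prior([1.0, 1.0, 1.0, 1.0]))  # 0.0
print(smoothness_prior([0.0, 2.0, 0.0, 2.0]))  # 6.0
```

Any other scalar-valued function of the attributions can be passed as `omega`; the learned priors below replace this hand-written rule with one fitted from meta-features.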
Now suppose that each feature in $X$ is associated with some scalar-valued meta-feature (e.g., the feature’s number of neighbors in some graph). We can represent these meta-feature values as a vector $m \in \mathbb{R}^{p}$, where the $i$’th entry in $m$ represents the value of our meta-feature corresponding to the $i$’th feature in the prediction problem. If we believe that the values in $m$ should correspond with $f_\theta$’s feature attributions, a natural choice for $\Omega$ would be a penalty on the difference between our model’s feature attributions and the values in $m$, giving us

$$\theta = \operatorname*{arg\,min}_{\theta} \; \mathcal{L}(\theta; X, y) + \lambda' \Omega'(\theta) + \lambda \lVert \Phi(\theta, X) - m \rVert_2, \qquad (1)$$
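where $\lVert \cdot \rVert_2$ is the $\ell_2$ norm. Extending this line of reasoning, suppose we have some meta-feature matrix $M \in \mathbb{R}^{p \times k}$, in which each of the $p$ features in the prediction problem has values for $k$ meta-features. If we assume that there is a linear relationship between a feature’s meta-feature values and that feature’s attribution values, we can modify (1) to obtain

$$\theta = \operatorname*{arg\,min}_{\theta} \; \mathcal{L}(\theta; X, y) + \lambda' \Omega'(\theta) + \lambda \Big\lVert \Phi(\theta, X) - \sum_{j=1}^{k} w_j M_{:,j} \Big\rVert_2,$$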
where $M_{:,j}$ refers to the $j$’th column of $M$, and $w \in \mathbb{R}^{k}$ is a vector of hyperparameters that we can optimize. Finding an optimal set of hyperparameters $w$ gives us a model that captures the meta-feature-to-feature-attribution relationship for a given problem, which we denote as a learned attribution prior. More generally, we define a learned attribution prior to be a function $g_\phi$ in some model class $\mathcal{G}$ whose parameters $\phi$ are learned to predict an attribution value for a feature based on its meta-feature values. As either the number of meta-features $k$ increases (e.g., due to an increase in prior information) or $\mathcal{G}$ becomes richer (e.g., so that it can capture more complex, nonlinear relationships), the number of prior parameters grows, and learning $g_\phi$’s parameters by treating them as hyperparameters becomes computationally prohibitive. Instead, we jointly learn a prediction model and attribution prior pair $(f_\theta, g_\phi)$ by optimizing the following objective

$$\theta, \phi = \operatorname*{arg\,min}_{\theta, \phi} \; \mathcal{L}(\theta; X, y) + \lambda' \Omega'(\theta) + \lambda \lVert \Phi(\theta, X) - g_\phi(M) \rVert_2,$$
where we define $g_\phi(M)$ to be the vector in $\mathbb{R}^{p}$ resulting from feeding each row of $M$ into $g_\phi$ (i.e., $g_\phi(M)$ is the vector of predicted feature attribution values for all of our prediction features). In the case where $\mathcal{F}$ and $\mathcal{G}$ are classes of multilayer perceptrons, we can approximate a solution to this problem by alternating stochastic gradient descent updates of $\theta$ and $\phi$, taking a single step on one set of parameters for every fixed number of steps on the other. This procedure is presented formally in Algorithm 1. So long as $\theta$ and $\phi$ change slowly enough, we find that Algorithm 1 accomplishes the following two goals:
Learn, without the input of human domain experts, a relationship between meta-features and feature relevance to a given prediction task.
Bias prediction models to rely on the features deemed important by this learned meta-feature to feature importance relationship.
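The alternating procedure can be illustrated on a toy regression problem. This is a minimal sketch under several assumptions that are not taken from Algorithm 1: both $f_\theta$ and $g_\phi$ are linear rather than MLPs; full-batch gradient steps replace stochastic ones; the penalty uses the squared $\ell_2$ norm so its gradient is smooth; and we take one step on $\theta$ followed by `t` steps on $\phi$. As a stand-in for expected gradients we use the mean absolute value of $\theta_j (x_j - \bar{x}_j)$, which is what expected gradients assigns to feature $j$ of a linear model in expectation when the training data serve as references. All function names are hypothetical.

```python
import random
from typing import List, Tuple

Matrix = List[List[float]]


def make_toy_problem(n: int = 20, p: int = 5, seed: int = 0) -> Tuple[Matrix, List[float], Matrix]:
    """Toy data: only feature 0 drives y; one hypothetical meta-feature flags feature 0."""
    rng = random.Random(seed)
    X = [[rng.gauss(0.0, 1.0) for _ in range(p)] for _ in range(n)]
    y = [2.0 * row[0] + rng.gauss(0.0, 0.1) for row in X]
    M = [[1.0] if i == 0 else [0.0] for i in range(p)]
    return X, y, M


def sign(v: float) -> int:
    return (v > 0) - (v < 0)


def train_with_learned_prior(X, y, M, lam=0.1, lam_param=1e-3, lr=0.05, outer_steps=500, t=5):
    n, p, k = len(X), len(X[0]), len(M[0])
    means = [sum(row[j] for row in X) / n for j in range(p)]
    # For linear f, mean |theta_j (x_j - mean_j)| = |theta_j| * c_j (our stand-in attribution).
    c = [sum(abs(row[j] - means[j]) for row in X) / n for j in range(p)]
    theta = [0.0] * p            # prediction model f(x) = theta . x
    phi, bias = [0.0] * k, 0.0   # learned prior g(m) = phi . m + bias

    def attributions() -> List[float]:
        return [abs(theta[j]) * c[j] for j in range(p)]

    def prior_predictions() -> List[float]:
        return [sum(phi[l] * M[i][l] for l in range(k)) + bias for i in range(p)]

    for _ in range(outer_steps):
        # One gradient step on theta: MSE + lam_param * ||theta||^2 + lam * ||Phi - g(M)||^2.
        resid = [sum(theta[j] * row[j] for j in range(p)) - target for row, target in zip(X, y)]
        attr, g = attributions(), prior_predictions()
        grad = [
            2.0 / n * sum(r * row[j] for r, row in zip(resid, X))
            + 2.0 * lam_param * theta[j]
            + 2.0 * lam * (attr[j] - g[j]) * sign(theta[j]) * c[j]
            for j in range(p)
        ]
        theta = [theta[j] - lr * grad[j] for j in range(p)]
        # t gradient steps on the prior: only the penalty term depends on phi and bias.
        for _ in range(t):
            attr, g = attributions(), prior_predictions()
            diff = [attr[i] - g[i] for i in range(p)]
            phi = [phi[l] + lr * 2.0 * lam * sum(diff[i] * M[i][l] for i in range(p)) for l in range(k)]
            bias = bias + lr * 2.0 * lam * sum(diff)
    return theta, phi, bias


X, y, M = make_toy_problem()
theta, phi, bias = train_with_learned_prior(X, y, M)
print([round(v, 2) for v in theta])  # the weight on feature 0 should dominate
print(round(phi[0], 2))              # positive: the flagged feature is predicted to matter
```

Because only the penalty depends on $\phi$, the prior steps simply regress $g_\phi(M)$ onto the current attributions, while the $\theta$ step pulls the attributions toward the prior’s predictions.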
We show in Section 3 that introducing such bias in the training procedure improves the performance of neural network models in low-data settings. Furthermore, we find in Section 4 that we can exploit these learned meta-feature-to-feature-importance relationships to gain new insights into our prediction models.
3 Training with Learned Priors Improves Accuracy in Low-Data Settings
3.1 Two Moons Classification With Nuisance Features
We first evaluate the performance of multilayer perceptrons (MLPs) with two hidden layers trained using the learned attribution prior framework on the two-moons-with-nuisance-features task introduced in [StochasticGates]. In particular, we seek to understand whether the learned prior framework biases MLPs towards using informative features when making a prediction. In this experiment we construct a dataset based on two classes shaped like two-dimensional moons, concatenated with a varying number of noisy features. For a given data point, the first two coordinates are drawn by adding Gaussian noise to a point from one of two nested half circles, as presented in Figure 1. The remaining coordinates are drawn from a Gaussian distribution and have no relation to the underlying classification task.
For each of our experiments, we use datasets consisting of 1000 samples. We use 20% of the dataset for model training and divide the remaining points evenly into testing and validation sets, giving a final 20%-40%-40% train-test-validation split. We vary the number of nuisance features from 50 to 1000 in increments of 50. For a given number of features $p$ we construct a meta-feature matrix $M \in \mathbb{R}^{p \times 2}$, where the $i$’th row contains the $i$’th feature’s mean and standard deviation. For our learned attribution prior we train a linear model that attempts to predict a feature’s importance based on that feature’s mean and standard deviation. The widths of the two hidden layers of our MLPs are set as functions of the number of features $p$.
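The data setup can be sketched as follows. This is a minimal illustration under stated assumptions: the moons follow the common scikit-learn-style construction (an upper half circle and a shifted, flipped lower one), the informative coordinates get Gaussian noise with a hypothetical scale of 0.1, nuisance coordinates are drawn from $\mathcal{N}(0, 1)$, and the mean/standard-deviation meta-features are computed on the training split. All function names are hypothetical.

```python
import math
import random
import statistics
from typing import List, Tuple


def two_moons_with_nuisance(n: int = 1000, d: int = 50, noise: float = 0.1, seed: int = 0) -> Tuple[List[List[float]], List[int]]:
    """Two noisy half circles (coordinates 0 and 1) followed by d nuisance coordinates."""
    rng = random.Random(seed)
    X, y = [], []
    for s in range(n):
        label = s % 2  # alternate labels so the classes are balanced
        t = rng.uniform(0.0, math.pi)
        if label == 0:
            x1, x2 = math.cos(t), math.sin(t)
        else:
            x1, x2 = 1.0 - math.cos(t), 0.5 - math.sin(t)
        informative = [x1 + rng.gauss(0.0, noise), x2 + rng.gauss(0.0, noise)]
        nuisance = [rng.gauss(0.0, 1.0) for _ in range(d)]  # unrelated to the label
        X.append(informative + nuisance)
        y.append(label)
    return X, y


def split_20_40_40(n: int, seed: int = 0) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle indices into 20% train, then split the rest evenly into test and validation."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_train = n // 5
    n_test = (n - n_train) // 2
    return idx[:n_train], idx[n_train:n_train + n_test], idx[n_train + n_test:]


def mean_std_meta_features(X_rows: List[List[float]]) -> List[List[float]]:
    """Meta-feature matrix M (p x 2): row i holds feature i's mean and standard deviation."""
    p = len(X_rows[0])
    columns = [[row[j] for row in X_rows] for j in range(p)]
    return [[statistics.mean(col), statistics.stdev(col)] for col in columns]


X, y = two_moons_with_nuisance(d=50)
train_idx, test_idx, val_idx = split_20_40_40(len(X))
M = mean_std_meta_features([X[i] for i in train_idx])
print(len(train_idx), len(test_idx), len(val_idx), len(M))  # 200 400 400 52
```

With these meta-features, the nuisance columns all share roughly the same mean and standard deviation, which gives a linear prior something simple to latch onto.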
For each number of nuisance features we generate five datasets. We use expected gradients [AttributionPriors] as our feature attribution method. We optimize our models using Adam [kingma2014adam] with early stopping, and all hyperparameters are chosen based on performance on our validation sets. We report our results in Figure 1. We find that MLPs trained jointly with learned attribution priors using $M$ are far more robust to the addition of noisy features than MLPs that do not take advantage of meta-features. This result continues to hold even as the number of features in the dataset exceeds the number of training points. This indicates that training an MLP with an appropriate learned prior model biases the MLP towards learning meaningful relationships rather than overfitting to noise in the data.
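Expected gradients, as defined in [AttributionPriors], averages path attributions over reference points drawn from the data and over an interpolation coefficient drawn uniformly from $[0, 1]$. The following is a minimal Monte Carlo sketch for a single input, assuming a user-supplied gradient function `grad_f` (in practice this would come from automatic differentiation); the function name and the default sampling budget are illustrative.

```python
import random
from typing import Callable, List, Sequence

Vector = List[float]


def expected_gradients(
    x: Sequence[float],
    references: Sequence[Sequence[float]],
    grad_f: Callable[[Vector], Vector],
    n_samples: int = 200,
    seed: int = 0,
) -> Vector:
    """Monte Carlo estimate of
    phi_i(x) = E_{x' ~ references, a ~ U(0, 1)} [(x_i - x'_i) * df/dx_i (x' + a (x - x'))]."""
    rng = random.Random(seed)
    phi = [0.0] * len(x)
    for _ in range(n_samples):
        ref = rng.choice(references)  # sample a reference point x'
        alpha = rng.random()          # sample a point on the straight path from x' to x
        point = [r + alpha * (xi - r) for xi, r in zip(x, ref)]
        grad = grad_f(point)
        for i, (xi, r) in enumerate(zip(x, ref)):
            phi[i] += (xi - r) * grad[i]
    return [v / n_samples for v in phi]


# For a linear model f(x) = w . x the gradient is constant, so with a single
# all-zeros reference each attribution is exactly w_i * x_i.
w = [1.0, -2.0, 0.5]
print(expected_gradients([1.0, 2.0, 3.0], [[0.0, 0.0, 0.0]], lambda v: list(w)))  # [1.0, -4.0, 1.5]
```

In expectation the attributions sum to $f(x)$ minus the average model output over the references, which is a handy sanity check on any implementation.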
3.2 AML Drug Response Prediction
Acute myeloid leukemia (AML) is a form of blood cancer characterized by the rapid buildup of abnormal cells in the blood and bone marrow that interfere with the function of healthy blood cells. The disease has a very poor prognosis, leading to death in approximately 80% of patients, and it is the leading cause of leukemia-related deaths in the United States [AMLStats]. While the number of drugs approved for cancer treatment continues to expand, with an average of more than 10 new drugs introduced each year in the current decade alone [kinch2014analysis], advances in AML treatment have remained elusive. This is likely due in large part to the heterogeneity of AML. Although AML genomes have fewer mutations than those of most other cancers, the relationships between mutations and phenotypes remain unclear [amlHeterogeneity]. As such, different AML patients respond disparately to the same treatment regimens, and methods for better matching patients to drugs are in high demand.
While deep learning methods, with their ability to discover complex nonlinear relationships, could potentially make progress towards this goal, such progress has been elusive. Due to the high cost of data acquisition in biological contexts, the size of datasets in, e.g., the NIH’s Gene Expression Omnibus (https://www.ncbi.nlm.nih.gov/geo/) has remained static, with the median dataset size remaining under 20 samples [dutil2018towards]. Because of the high dimensionality of genetic data, the lack of large datasets has prevented deep learning models from achieving the kind of widespread success they have had in other fields.
In this experiment we investigate how training an MLP using the learned prior framework affects performance on an AML drug response prediction task. Our drug response data come from the Beat AML dataset, comprising RNA-seq gene expression data and ex vivo drug sensitivity measures for tumors from 572 patients [tyner2018functional]. In our analysis we focus specifically on responses to the drug Dasatinib, for which data were available from the largest number of patients. For prior information we use the publicly available MERGE driver features introduced in [MERGE]. The MERGE (Mutation, Expression hubness, Regulator, Genomic copy number variation, and mEthylation) features are a set of multi-omic prior information believed to reflect how a given gene affects disease processes. This allows us to construct a meta-feature matrix $M \in \mathbb{R}^{p \times 5}$, where $p$ is the number of genes for which we have both expression levels and MERGE feature values. We refer the reader to the Supplement for additional details on our RNA-seq data preprocessing and on the MERGE prior features.
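Building $M$ requires aligning the MERGE features with the expression features so that row $i$ of $M$ describes the gene in the $i$’th kept column of $X$. A minimal sketch, assuming the MERGE features have already been loaded into a dictionary keyed by gene symbol; the column names, gene names, and values below are hypothetical and do not reflect the real MERGE files.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical column order for the five MERGE meta-features.
MERGE_COLUMNS = ("mutation", "expression_hubness", "regulator", "copy_number", "methylation")


def align_meta_features(
    expression_genes: Sequence[str],
    merge_table: Dict[str, Sequence[float]],
) -> Tuple[List[int], List[List[float]]]:
    """Return the expression columns to keep and the p x 5 meta-feature matrix M.

    Row i of M describes the i'th kept gene, so feature order in X and M agrees.
    """
    keep, M = [], []
    for col, gene in enumerate(expression_genes):
        values = merge_table.get(gene)
        if values is None:
            continue  # no MERGE features: drop this gene from the prediction problem
        if len(values) != len(MERGE_COLUMNS):
            raise ValueError(f"{gene}: expected {len(MERGE_COLUMNS)} MERGE values")
        keep.append(col)
        M.append([float(v) for v in values])
    return keep, M


genes = ["GENE_A", "GENE_B", "GENE_C"]
table = {"GENE_A": [0.1, 0.9, 0.0, 0.2, 0.3], "GENE_C": [0.5, 0.1, 1.0, 0.0, 0.4]}
keep, M = align_meta_features(genes, table)
print(keep, len(M))  # [0, 2] 2
```

The returned `keep` indices would then be used to subset the columns of the expression matrix so that $X$ and $M$ refer to the same $p$ genes.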
To understand how training with the learned attribution prior can improve model performance, we compare against three baselines: LASSO regression [tibshirani1996regression], MLPs trained with no prior information, and MLPs trained using the learned prior framework but with prior information composed solely of Gaussian noise. We evaluate each model on five splits of the data into training, validation, and test sets, using 80%, 10%, and 10% of the data, respectively. As in Section 3.1, we train our MLPs using Adam with early stopping and use expected gradients as our feature attribution method. All prediction-model MLPs have two hidden layers with 512 and 256 hidden units, respectively. For our learned attribution priors we use MLPs with two hidden layers, of five and three units, to capture the relationship between meta-features and feature importance. Our results are reported in Figure 2.
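The prior network described above maps each gene’s five MERGE values to one predicted attribution. The sketch below shows only its shape (5 → 5 → 3 → 1) and forward pass; the ReLU activations, the Gaussian weight initialization, and all function names are assumptions, and training would follow the alternating procedure from Section 2.

```python
import random
from typing import List, Sequence, Tuple

Layer = Tuple[List[List[float]], List[float]]


def init_prior_mlp(sizes: Sequence[int] = (5, 5, 3, 1), seed: int = 0) -> List[Layer]:
    """Random weights for an MLP mapping MERGE meta-features to one predicted attribution."""
    rng = random.Random(seed)
    layers = []
    for fan_in, fan_out in zip(sizes, sizes[1:]):
        W = [[rng.gauss(0.0, (1.0 / fan_in) ** 0.5) for _ in range(fan_in)] for _ in range(fan_out)]
        layers.append((W, [0.0] * fan_out))
    return layers


def prior_forward(layers: List[Layer], meta_features: Sequence[float]) -> float:
    h = list(meta_features)
    for depth, (W, b) in enumerate(layers):
        z = [sum(w * v for w, v in zip(row, h)) + bias for row, bias in zip(W, b)]
        # ReLU on the hidden layers; the output (a predicted attribution) stays linear.
        h = z if depth == len(layers) - 1 else [max(0.0, v) for v in z]
    return h[0]


def predicted_attributions(layers: List[Layer], M: Sequence[Sequence[float]]) -> List[float]:
    """g_phi(M): one predicted attribution per row (gene) of M."""
    return [prior_forward(layers, row) for row in M]


layers = init_prior_mlp()
print(sum(len(W) * len(W[0]) + len(b) for W, b in layers))  # 52 parameters
```

With only 52 parameters, the prior is tiny next to the 512-256 prediction network, which is part of why it can be fitted alongside the prediction model without much overhead.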
We find that training MLP models with the learned prior framework provides a 15% reduction in mean squared error compared to a standard MLP approach. We also find that the standard error of the performance of our MLPs with the learned MERGE prior is 40% lower than that of standard MLPs. Taken together, these improvements suggest that the learned MERGE prior biases MLPs to learn informative representations of the data rather than fitting noise.
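The summary statistics above are relative comparisons across data splits. A minimal sketch of how such numbers could be computed, assuming the standard error is the sample standard deviation over splits divided by the square root of the number of splits; the function names are hypothetical and the example values are illustrative, not results.

```python
import math
import statistics
from typing import Sequence, Tuple


def mean_and_standard_error(values: Sequence[float]) -> Tuple[float, float]:
    """Mean and standard error (sample std / sqrt(number of splits)) across data splits."""
    return statistics.mean(values), statistics.stdev(values) / math.sqrt(len(values))


def relative_reduction(baseline: float, improved: float) -> float:
    """Fractional reduction of `improved` relative to `baseline`, e.g. 0.15 for 15%."""
    return (baseline - improved) / baseline


print(round(relative_reduction(1.0, 0.85), 2))  # 0.15
```

The same `relative_reduction` applies to both the mean squared errors and their standard errors.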
4 Learned Priors Admit New Insights into Deep Models
While feature attribution methods provide users with a sense of the “relevant” features for a given prediction, they lack a way to contextualize the attribution values in a human-interpretable way. This shortcoming can lead to a false sense of security when employing such methods, even though previous work has demonstrated that they can produce explanations based on features known from prior knowledge to be irrelevant to a problem [yang2019bim, kindermans2019reliability, goyal2019explaining]. However, with a learned prior model $g_\phi$, we gain a new set of tools for understanding the why behind feature attributions in terms of our prior information. In the case where $g_\phi$ is a linear model, we can easily use the model’s weight coefficients to get an explanation for a given predicted attribution value. When $g_\phi$ is a more complex model, such as a neural network, uncovering such explanations is not as straightforward. In this section, using the AML drug response prediction problem from Section 3.2 as a case study, we demonstrate multiple methods for probing more complex learned attribution priors to obtain insights into predicted feature attribution values. Using such methods we find that our learned MERGE prior from Section 3.2 independently learns meta-feature-to-gene-importance relationships that agree with prior biological knowledge.
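For a linear prior $g(m) = w^\top m + b$, an explanation of one feature’s predicted attribution follows directly from the weights. The sketch below, a hypothetical helper, measures each meta-feature’s contribution relative to the average meta-feature row (an assumption; a zero baseline would instead give $w_l M_{i,l}$), so the contributions sum to $g(M_{i,:})$ minus the average prediction.

```python
from typing import List, Sequence


def linear_prior(w: Sequence[float], b: float, m: Sequence[float]) -> float:
    return sum(wl * ml for wl, ml in zip(w, m)) + b


def linear_prior_explanation(w: Sequence[float], M: Sequence[Sequence[float]], i: int) -> List[float]:
    """Contribution of each meta-feature to g(M_i) relative to the average row of M.

    The contributions sum to g(M_i) minus the average of g over all rows of M.
    """
    k = len(w)
    col_means = [sum(row[l] for row in M) / len(M) for l in range(k)]
    return [w[l] * (M[i][l] - col_means[l]) for l in range(k)]


w, b = [2.0, -1.0], 0.5
M = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
print(linear_prior_explanation(w, M, 2))  # [2.0, -1.0]
```

For nonlinear priors no such closed form exists, which is what motivates the attribution-based and partial-dependence tools below.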
4.1 Attribution Explanations For Understanding Meta-Feature to Gene Relationships
Given a (potentially complex) learned attribution prior model $g_\phi$ and meta-feature values $M_{i,:}$ for the $i$’th feature in a prediction problem, we would like to understand how the values of $M_{i,:}$ relate to the predicted attribution value $g_\phi(M_{i,:})$. To do so we can apply a feature attribution method $\Phi'$ to $g_\phi$, thereby generating a second-order attribution explanation $\Phi'(g_\phi, M_{i,:})$ that explains the $i$’th feature’s predicted importance in terms of human-interpretable meta-features. We apply expected gradients as our $\Phi'$ to the learned MERGE prior from Section 3.2, and visualize our results in Figure 3.
Remarkably, we find that our learned MERGE prior, without any input from domain experts, rediscovers relationships in line with prior biological knowledge. We can see in Figure 3 that high predicted gene importance values are explained mostly by expression hubness. This trend is consistent with prior knowledge, as expression hubness has been suspected to drive events in cancer [logsdon2015sparse]. Furthermore, we observe that for a small number of genes, mutation appears as a major factor in their importance explanations. This also matches prior knowledge, as mutations in CEBPA, FLT3, and ELF4 are suspected to play a role in the heterogeneity of drug response in AML patients [lin2005characterization, fasan2014role, small2006flt3, daver2019targeting, suico2017roles].
4.2 Partial Dependence Plots Capture Nonlinear Meta-Feature to Importance Relationships
We can further explore the nature of these learned meta-feature-to-feature-importance relationships by constructing partial dependence plots (PDPs) [friedman2001greedy] to visualize the marginal effect of particular meta-features on predicted attribution values. We display the results for expression hubness and mutation in Figure 4. For both meta-features we find that our learned attribution prior captures a nonlinear relationship between the meta-feature and predicted attribution values. These relationships broadly agree with prior knowledge: beyond a certain point, both higher hubness and mutation values lead to higher predicted attribution values. The nonlinearity of the relationships captured by our PDPs suggests that more complex models may better capture the relationship between a gene’s MERGE features and its potential to drive events in AML than the original linear model used by Lee et al. [MERGE].
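A partial dependence curve fixes one meta-feature at each grid value for every row of $M$ and averages the prior’s predictions. A minimal sketch of that computation, with a hypothetical toy prior standing in for a trained $g_\phi$:

```python
from typing import Callable, List, Sequence


def partial_dependence(
    g: Callable[[List[float]], float],
    M: Sequence[Sequence[float]],
    feature: int,
    grid: Sequence[float],
) -> List[float]:
    """Average prediction of g when meta-feature `feature` is set to each grid value for every row."""
    curve = []
    for value in grid:
        total = 0.0
        for row in M:
            modified = list(row)
            modified[feature] = value  # overwrite one meta-feature, keep the others as observed
            total += g(modified)
        curve.append(total / len(M))
    return curve


def toy_prior(m: List[float]) -> float:
    """Hypothetical prior: linear in meta-feature 0, quadratic in meta-feature 1."""
    return 3.0 * m[0] + m[1] ** 2


M = [[0.0, 1.0], [1.0, 2.0]]
print(partial_dependence(toy_prior, M, 0, [0.0, 1.0, 2.0]))  # [2.5, 5.5, 8.5]
print(partial_dependence(toy_prior, M, 1, [0.0, 1.0, 2.0]))  # [1.5, 2.5, 5.5]
```

The first curve has constant increments (a linear effect) while the second does not, which is the kind of curvature a PDP makes visible for a trained prior.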
4.2.1 MERGE Prior Captures Relevant Gene Pathways
To further confirm that the relationships captured by our learned MERGE prior match biological intuition, we perform Gene Set Enrichment Analysis [GSEA] to test whether the top genes as ranked by our learned MERGE attribution prior are enriched for membership in any biological pathways. As a baseline for comparison, we use the number of pathways captured by a learned noise prior. In our analysis we use the top 200 genes as ranked by each prior model, and we use the Enrichr [Enrichr] library to check for membership in the Kyoto Encyclopedia of Genes and Genomes [KEGG] 2019 pathways list. While the top genes as ranked by our learned noise prior are not significantly enriched for membership in any pathway after FDR correction, those ranked by our MERGE prior are enriched for many pathways, as shown in Figure 5. Moreover, among the pathways captured by our MERGE prior are multiple pathways already known to be associated with AML [milella2001therapeutic, shao2016her2, park2010role, pomeroy2017targeting]. This result further indicates that our learned MERGE prior captures meaningful meta-feature-to-gene-importance relationships.
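Enrichment of a top-ranked gene list is commonly assessed with an over-representation test. The sketch below ranks genes by predicted attribution, keeps the top `k`, computes a one-sided hypergeometric p-value for each pathway, and applies Benjamini-Hochberg FDR correction. It is an illustrative stand-in rather than Enrichr’s implementation, and the gene names and pathway contents in the usage example are hypothetical.

```python
import math
from typing import Dict, List, Sequence, Set


def top_k_genes(genes: Sequence[str], scores: Sequence[float], k: int = 200) -> List[str]:
    """Genes with the k largest predicted attribution values."""
    ranked = sorted(zip(genes, scores), key=lambda pair: pair[1], reverse=True)
    return [gene for gene, _ in ranked[:k]]


def hypergeom_sf(overlap: int, population: int, pathway_size: int, n_selected: int) -> float:
    """P(X >= overlap) for X ~ Hypergeometric(population, pathway_size, n_selected)."""
    upper = min(pathway_size, n_selected)
    hits = sum(
        math.comb(pathway_size, x) * math.comb(population - pathway_size, n_selected - x)
        for x in range(overlap, upper + 1)
    )
    return hits / math.comb(population, n_selected)


def benjamini_hochberg(pvalues: Sequence[float]) -> List[float]:
    """Benjamini-Hochberg adjusted p-values (FDR), returned in the input order."""
    m = len(pvalues)
    order = sorted(range(m), key=lambda i: pvalues[i])
    adjusted = [0.0] * m
    running_min = 1.0
    for rank in range(m, 0, -1):  # walk from the largest p-value down
        i = order[rank - 1]
        running_min = min(running_min, pvalues[i] * m / rank)
        adjusted[i] = running_min
    return adjusted


def pathway_enrichment(
    selected: Sequence[str], background: Set[str], pathways: Dict[str, Set[str]]
) -> Dict[str, float]:
    """FDR-adjusted enrichment p-value of the selected genes for each pathway."""
    chosen = set(selected) & background
    names = sorted(pathways)
    pvalues = []
    for name in names:
        members = pathways[name] & background
        overlap = len(members & chosen)
        pvalues.append(hypergeom_sf(overlap, len(background), len(members), len(chosen)))
    return dict(zip(names, benjamini_hochberg(pvalues)))


background = {f"G{i}" for i in range(10)}
pathways = {"P1": {"G0", "G1", "G2"}, "P2": {"G7", "G8"}}
top = top_k_genes(["G0", "G1", "G5"], [0.9, 0.8, 0.1], k=2)
print(top)  # ['G0', 'G1']
print(pathway_enrichment(top, background, pathways))
```

Running the same test on genes ranked by a noise prior gives the baseline described above: a prior that has learned nothing should yield few or no pathways below the FDR threshold.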
In this work we introduce the learned attribution prior framework, a novel method for biasing the training of neural networks by incorporating prior information about the features used for a prediction task. Unlike other feature-attribution based penalty methods [AttributionPriors, ross2017right, rieger2019interpretations], our framework merely requires the presence of prior information in the form of meta-features, rather than rules hand-crafted by a domain expert. In our experiments we find that jointly learning prediction models and learned attribution priors leads to increased performance on prediction tasks in low data settings. Furthermore, we demonstrate that learned attribution priors admit new methods for model interpretability and for establishing trust in feature attribution methods. Using such methods we demonstrate that our learned attribution priors, without human intervention, independently learn meta-feature to feature relationships that agree with prior human knowledge. The learned prior framework provides a broadly applicable method for incorporating prior information into prediction tasks, and we believe that it is a valuable tool for learning in low-data, trust-critical domains.