Learned Feature Attribution Priors

by   Ethan Weinberger, et al.

Deep learning models have achieved breakthrough successes in domains where data is plentiful. However, such models are prone to overfitting when trained on high-dimensional, low sample size datasets. Furthermore, the black-box nature of such models has limited their application in domains where model trust is critical. As a result, deep learning has struggled to make inroads in domains such as precision medicine, where small sample sizes are the norm and model trust is paramount. Oftentimes, even in low data settings we have some set of prior information on each input feature to our prediction task, which may be related to that feature's relevance to the prediction problem. In this work we propose the learned attribution prior framework to take advantage of such information and alleviate the issues mentioned previously. For a given prediction task, our framework jointly learns a relationship between prior information about a feature and that feature's importance to the task, while also biasing the prediction model to focus on the features with high predicted importance. We find that training models using our framework improves model accuracy in low-data settings. Furthermore, we find that the resulting learned meta-feature to feature relationships open up new avenues for model interpretation.



1 Introduction

Recent advances in machine learning have come in the form of complex models that humans struggle to interpret. In response to the black-box nature of these models, a variety of recent work has focused on model interpretability [Guidotti:2018:SME:3271482.3236009]. One particular line of work that has gained much attention is that of feature attribution methods. Given a model and a specific prediction made by that model, these methods assign a numeric value to each input feature of the model, indicating how important that feature was for the given prediction. A variety of such methods have been proposed, and previous work has focused on how such methods can be used to gain insight into model behavior in applications where model trust is critical [zech2018variable, sayres2019using].

A natural question to ask is why a given feature was assigned a specific attribution value for a given prediction. In some settings it is easy for a human to evaluate the sensibility of attribution values; for example, in image classification problems we can overlay attribution values on the original image. However, in many other domains we do not have the ability to assess the validity of attribution values so easily. Recent work [ross2017right, AttributionPriors, rieger2019interpretations] has attempted to address this problem by regularizing model explanations to agree with prior human knowledge. For example, Erion et al. [AttributionPriors] propose to regularize image classification models by encouraging the attributions of neighboring pixels to be similar, thereby leading to smoother attribution maps. While such regularization methods do indeed lead to noticeably different feature attributions, they suffer from two shortcomings. First, they require a human expert with prior knowledge about the problem domain to construct a rule for the regularization. Second, after a set of feature attributions is produced, we still lack a method for obtaining insight into the reasons behind a given feature attribution beyond simply trusting in the regularization procedure.

Moreover, the complexity of deep learning models makes them especially prone to overfitting when trained on small datasets. This limitation has stalled the adoption of deep learning methods in domains where data acquisition is difficult, such as many areas of medicine [chen2019deep].

In certain problem domains we have access to second-order meta-features that characterize the input features used in the prediction task. Such meta-features can potentially encode information on how important a given feature is to the task at hand. By using such meta-features to bias models to focus only on “important” features, we should be able to achieve both greater model accuracy and interpretability. However, we may not know a priori how the meta-features correspond to prediction-level feature importance. As a first step to solving this problem, we propose the learned attribution prior framework, a novel method for model training that simultaneously learns a relationship between meta-features and feature attribution values, and biases the original model’s predictions to agree with this learned relationship. Empirically we find that models trained using the learned attribution prior framework achieve better performance on low-data tasks with both synthetic and real-world data. Furthermore, we demonstrate how such learned meta-feature to feature relationships can lead to more interpretable feature attributions.

2 Learned Attribution Priors

In this section we introduce a formal definition of a learned attribution prior by extending that of the attribution prior originally proposed in [AttributionPriors]. Let $X \in \mathbb{R}^{n \times p}$ denote a training dataset consisting of $n$ samples, each of which has $p$ features. Similarly, let $y$ denote the labels for the $n$ samples in our training dataset. In a standard model training setting, we wish to find the model $f_\theta$ in some model class $\mathcal{F}$ that minimizes the average prediction loss on our dataset, as determined by some loss function $\mathcal{L}$. In order to avoid overfitting on training data, a regularization function $\Omega'$ on the prediction model’s parameters $\theta$ is often included, giving the equation

$$\hat{\theta} = \arg\min_{\theta} \; \mathcal{L}(f_\theta(X), y) + \lambda' \, \Omega'(\theta).$$

For a given feature attribution method $\Phi(\theta, X)$, an attribution prior is defined as some function $\Omega$ that assigns scalar-valued penalties to $\Phi$’s feature attributions for the input $X$. This leads us to the objective

$$\hat{\theta} = \arg\min_{\theta} \; \mathcal{L}(f_\theta(X), y) + \lambda \, \Omega(\Phi(\theta, X)).$$

Now suppose that each feature in $X$ is associated with some scalar-valued meta-feature (e.g., the feature’s number of neighbors in some graph). We can represent these meta-feature values as a vector $m \in \mathbb{R}^{p}$, where the $i$’th entry in $m$ represents the value of our meta-feature corresponding to the $i$’th feature in the prediction problem. If we believe that the values in $m$ should correspond with $\Phi$’s feature attributions, a natural choice for $\Omega$ would be a penalty on the difference between our model’s feature attributions and the values in $m$, giving us

$$\hat{\theta} = \arg\min_{\theta} \; \mathcal{L}(f_\theta(X), y) + \lambda \, \lVert \Phi(\theta, X) - m \rVert, \tag{1}$$

where $\lVert \cdot \rVert$ is a vector norm. Extending this line of reasoning, suppose we have some meta-feature matrix $M \in \mathbb{R}^{p \times k}$, corresponding to each of the $p$ features in the prediction problem having values for $k$ meta-features. If we were to assume that there is a linear relationship between a feature’s meta-feature values and that feature’s attribution values, we can modify (1) to achieve

$$\hat{\theta} = \arg\min_{\theta} \; \mathcal{L}(f_\theta(X), y) + \lambda \, \Big\lVert \Phi(\theta, X) - \sum_{j=1}^{k} \beta_j M_j \Big\rVert,$$

where $M_j$ refers to the $j$’th column of $M$, and $\beta \in \mathbb{R}^{k}$ is a vector of hyperparameters that we can optimize. Finding an optimal set of hyperparameters gives us a model that captures the meta-feature to feature-attribution relationship for a given problem, which we denote as a learned attribution prior. More generally, we define a learned attribution prior to be a function $g_\psi$ in some model class $\mathcal{G}$ whose parameters $\psi$ are learned to predict an attribution value for a feature based on its meta-feature values. As either the number of meta-features $k$ increases (e.g. due to an increase in prior information) or $\mathcal{G}$ becomes richer (e.g. so that it can capture more complex, nonlinear relationships), leading to more model parameters, learning an attribution prior’s parameters by treating them as hyperparameters becomes computationally prohibitive. Instead, we jointly learn a prediction model and attribution prior pair $(f_\theta, g_\psi)$ by optimizing the following objective

$$\hat{\theta}, \hat{\psi} = \arg\min_{\theta, \psi} \; \mathcal{L}(f_\theta(X), y) + \lambda \, \lVert \Phi(\theta, X) - g_\psi(M) \rVert,$$

where we define $g_\psi(M)$ to be the vector in $\mathbb{R}^{p}$ resulting from feeding each row of $M$ into $g_\psi$ (i.e., $g_\psi(M)$ is the vector of predicted feature attribution values for all of our prediction features). In the case where $\mathcal{F}$ and $\mathcal{G}$ are classes of multilayer perceptrons, we are able to approximate a solution to this problem by alternating between one step of optimizing $f_\theta$ and $s$ steps of optimizing $g_\psi$ via stochastic gradient descent. This procedure is presented formally in Algorithm 1. So long as $f_\theta$ and $g_\psi$ change slowly enough, we find that Algorithm 1 accomplishes the following two goals:

  • Learn, without the input of human domain experts, a relationship between meta-features and feature relevance to a given prediction task.

  • Bias prediction models to rely on the features deemed important by this learned meta-feature to feature importance relationship.

We show in Section 3 that introducing such bias in the training procedure improves the performance of neural network models in low-data settings. Furthermore, we find in Section 4 that we can exploit these learned meta-feature to feature importance relationships to achieve new insights into our prediction models.


  for epoch in training epochs do
     for minibatch, batchLabels in $(X, y)$ do
        predictionLoss = $\mathcal{L}(f_\theta(\text{minibatch}), \text{batchLabels})$
        attributionLoss = $\lambda \, \lVert \Phi(\theta, \text{minibatch}) - g_\psi(M) \rVert$
        totalLoss = predictionLoss + attributionLoss
        Update $\theta$ via SGD based on totalLoss
        for $s$ steps do
           attributionLoss = $\lambda \, \lVert \Phi(\theta, \text{minibatch}) - g_\psi(M) \rVert$
           Update $\psi$ via SGD based on attributionLoss
        end for
     end for
  end for
Algorithm 1 Gradient descent training of a prediction model $f_\theta$ with a learned attribution prior $g_\psi$. The number of prior optimization steps $s$ is a hyperparameter. We found in practice that the least expensive option worked well to achieve consistent learned prior models.
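As a rough illustration of the alternating updates in Algorithm 1, the following minimal sketch trains a linear prediction model together with a linear prior using plain gradient descent. Several simplifications here are assumptions and not part of the original method: full-batch updates replace minibatch SGD, the absolute weight $|w_i|$ stands in for an attribution method such as expected gradients, the squared difference is used as the penalty, and the names `train_with_learned_prior`, `lam`, and `prior_steps` are hypothetical.

```python
import numpy as np

def train_with_learned_prior(X, y, M, lam=1.0, lr=0.05, epochs=500,
                             prior_steps=1, seed=0):
    """Alternate one step on the prediction model with `prior_steps` steps on the prior.

    Assumptions for illustration only: the prediction model is linear (weights w),
    the attribution of feature i is |w_i|, and the prior is linear, g(M) = M @ psi.
    """
    rng = np.random.default_rng(seed)
    n, p = X.shape
    w = rng.normal(scale=0.01, size=p)   # prediction model parameters
    psi = np.zeros(M.shape[1])           # prior parameters

    for _ in range(epochs):
        # Prediction model step: mean squared error plus attribution penalty.
        resid = X @ w - y
        gap = np.abs(w) - M @ psi        # attribution minus predicted attribution
        grad_w = 2.0 * X.T @ resid / n + lam * 2.0 * gap * np.sign(w) / p
        w -= lr * grad_w

        # Prior step(s): only the attribution penalty depends on psi.
        for _ in range(prior_steps):
            gap = np.abs(w) - M @ psi
            grad_psi = -lam * 2.0 * M.T @ gap / p
            psi -= lr * grad_psi

    return w, psi
```

In this toy setting, `M @ psi` plays the role of the predicted attribution vector: features whose meta-features indicate relevance receive larger predicted importance, and the penalty pulls the model's own attributions toward those predictions.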

3 Training with Learned Priors Improves Accuracy in Low-Data Settings

3.1 Two Moons Classification With Nuisance Features

We first evaluate the performance of multi-layer perceptrons (MLPs) with two hidden layers trained using the learned prior framework on the two moons with nuisance features task as introduced in [StochasticGates]. In particular we seek to understand whether the learned prior framework biases MLPs towards using informative features when making a prediction. In this experiment we construct a dataset based on two two-dimensional moon shape classes, concatenated with noisy features. For a given data point the first two coordinates are drawn by adding random noise to a point from one of two nested half circles, as presented in Figure 1. The remaining coordinates are drawn from a noise distribution and have no relation to the underlying classification task.

Figure 1: (a) Points from the two moons shape classes. The first two features are the informative features for our task, while the remaining features are pure noise. (b) Classification accuracy (mean and standard error) of MLPs trained with and without learned priors vs. number of noise dimensions.

For each of our experiments, we use datasets consisting of 1000 samples. We use 20% of the dataset for model training, and divide the remaining points evenly into testing and validation sets, giving a final 20%-40%-40% train-test-validation split. We vary the number of nuisance features from 50 to 1000 in increments of 50. For a given number of features $p$ we construct a meta-feature matrix $M \in \mathbb{R}^{p \times 2}$, where the $i$’th row contains the $i$’th feature’s mean and standard deviation. For our learned attribution prior we train a linear model that attempts to predict feature importance based on that feature’s mean and standard deviation. The number of units in each of the two hidden layers of our MLPs is set as a function of the number of features $p$.
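The data construction described above can be sketched as follows. This is an illustrative reconstruction, not the authors' code: the half-circle geometry, the noise scale `moon_noise`, the standard normal nuisance features, and the function names are all assumptions chosen for the example.

```python
import numpy as np

def two_moons_with_nuisance(n_samples=1000, n_nuisance=50, moon_noise=0.1, seed=0):
    """Two informative 'moon' coordinates followed by pure-noise features."""
    rng = np.random.default_rng(seed)
    labels = rng.integers(0, 2, size=n_samples)
    t = rng.uniform(0.0, np.pi, size=n_samples)
    # Class 0 lies on an upper half circle, class 1 on a shifted lower half circle.
    x1 = np.where(labels == 0, np.cos(t), 1.0 - np.cos(t))
    x2 = np.where(labels == 0, np.sin(t), 0.5 - np.sin(t))
    moons = np.column_stack([x1, x2])
    moons = moons + rng.normal(scale=moon_noise, size=moons.shape)
    nuisance = rng.normal(size=(n_samples, n_nuisance))  # unrelated to the label
    return np.hstack([moons, nuisance]), labels

def mean_std_meta_features(X):
    """Meta-feature matrix with one row per feature: [mean, standard deviation]."""
    return np.column_stack([X.mean(axis=0), X.std(axis=0)])

def train_test_val_split(n_samples, seed=0):
    """Index arrays for a 20%-40%-40% train-test-validation split."""
    perm = np.random.default_rng(seed).permutation(n_samples)
    n_train = n_samples // 5
    n_test = (n_samples - n_train) // 2
    return perm[:n_train], perm[n_train:n_train + n_test], perm[n_train + n_test:]
```

With 1000 samples this yields 200 training points, so once the number of nuisance features exceeds 200 the model sees more features than training examples.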

For each number of nuisance features we generate five datasets. We use expected gradients [AttributionPriors] for our feature attribution method. We optimize our models using Adam [kingma2014adam] with early stopping, and all hyperparameters are chosen based on performance on our validation sets. We report our results in Figure 1. We find that our MLPs trained along with learned attribution priors are far more robust to the addition of noisy features than MLPs that do not take advantage of meta-features. This result continues to hold even as the number of features in the dataset eclipses the number of training points. This phenomenon indicates that training an MLP with an appropriate learned prior model is indeed able to bias the MLP towards learning meaningful relationships rather than overfitting on noise in the data.
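For readers unfamiliar with expected gradients, the attribution for feature $j$ of an input $x$ can be written as $\mathbb{E}_{x' \sim D,\, \alpha \sim U(0,1)}\big[(x_j - x'_j)\, \partial f / \partial x_j \,(x' + \alpha (x - x'))\big]$, where $D$ is a background dataset. The sketch below is a simple Monte Carlo estimate of that quantity. It assumes the caller supplies a gradient function `grad_f` (a hypothetical name); in practice the gradient would come from an automatic differentiation library.

```python
import numpy as np

def expected_gradients(grad_f, x, background, n_samples=200, seed=0):
    """Monte Carlo estimate of expected gradients for a single input x.

    grad_f(z) must return the gradient of the model output with respect to z.
    """
    rng = np.random.default_rng(seed)
    idx = rng.integers(0, len(background), size=n_samples)
    refs = background[idx]                      # sampled reference inputs x'
    alphas = rng.uniform(size=(n_samples, 1))   # interpolation coefficients
    points = refs + alphas * (x - refs)         # points on the path from x' to x
    grads = np.stack([grad_f(pt) for pt in points])
    return ((x - refs) * grads).mean(axis=0)
```

For a linear model the gradient is constant, so the estimate reduces to the weight vector times the difference between the input and the mean sampled reference.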

3.2 AML Drug Response Prediction

Acute myeloid leukemia (AML) is a form of blood cancer characterized by the rapid buildup of abnormal cells in the blood and bone marrow that interfere with the function of healthy blood cells. The disease has a very poor prognosis, leading to death in approximately 80% of patients, and it is the leading cause of leukemia-related deaths in the United States [AMLStats]. While the number of drugs approved for cancer treatment continues to expand, with an average of more than 10 new drugs introduced each year in the current decade alone [kinch2014analysis], advances in AML treatment have remained elusive. This phenomenon is likely due in large part to the heterogeneity of AML. Although AML genomes have fewer mutations than those of most other cancers, the relationships between mutations and phenotypes remain unclear [amlHeterogeneity]. As such, different AML patients have disparate responses to the same treatment regimens, and methods for better matching patients to drugs are in high demand.

While deep learning methods, and their ability to discover complex nonlinear relationships, could potentially make progress towards this goal, such progress has been elusive. Due to the high cost of data acquisition in biological contexts, the size of datasets in e.g. the NIH’s Gene Expression Omnibus (https://www.ncbi.nlm.nih.gov/geo/) has remained static, with median dataset size remaining under 20 [dutil2018towards]. Because of the high dimensionality of genetic data, the lack of large datasets has prevented deep learning models from achieving the same kinds of widespread success as they have in other fields.

In this experiment we investigate how training an MLP using the learned prior framework affects performance on an AML drug response prediction task. Our drug response data comes from the Beat AML dataset, which comprises RNA-seq gene expression data and ex vivo drug sensitivity measures for tumors from 572 patients [tyner2018functional]. In our analysis we focus specifically on responses to the drug Dasatinib, for which there was data from the most patients. For prior information we use the publicly available MERGE driver features introduced in [MERGE]. The MERGE (Mutation, Expression hubness, Regulator, Genomic copy number variation, and mEthylation) features are a set of multi-omic prior information believed to correspond with how a given gene affects disease processes. This allows us to construct a meta-feature matrix $M \in \mathbb{R}^{p \times 5}$, where $p$ is the number of genes for which we have both expression levels and MERGE feature values. We refer the reader to the Supplement for additional details on our RNA-seq data preprocessing and on the MERGE prior features.

To understand how training with the learned attribution prior can improve model performance, we compare against three baselines: LASSO regression [tibshirani1996regression], MLPs trained with no prior information, and an MLP trained using the learned prior framework but for which the prior information is composed solely of Gaussian noise. We evaluate each model on five splits of the data into training, validation, and test sets, for which we use 80%, 10%, and 10% of the data respectively. As in Section 3.1 we train our MLPs using Adam with early stopping, and we use expected gradients for our feature attribution method. Our results are reported in Figure 2. All prediction-model MLPs have two hidden layers with 512 and 256 hidden units respectively. For our learned attribution priors we use MLPs with two hidden layers, of five and three units, to capture the relationship between meta-features and feature importance.

Figure 2: Performance comparison between MLP with learned MERGE prior vs. baseline models.

We find that training MLP models with the learned prior framework provides a 15% reduction in mean squared error compared to a standard MLP approach. We also find that the standard error of the performance of our MLPs with learned MERGE prior is 40% less than that of standard MLPs. Taken together these improvements indicate that the learned MERGE prior is biasing MLPs to learn informative representations of the data, as opposed to just learning noise.

4 Learned Priors Admit New Insights into Deep Models

While feature attribution methods provide users with a sense of the “relevant” features for a given prediction, they lack a way to contextualize the attribution values in a human-interpretable way. This shortcoming can lead to a false sense of security when employing such methods, even though previous work has demonstrated their potential to produce explanations based on features known to be irrelevant to a problem based on prior knowledge [yang2019bim, kindermans2019reliability, goyal2019explaining]. However, with a learned prior model $g$, we gain a new set of tools for understanding the why behind feature attributions in terms of our prior information. In the case where $g$ is a linear model, we can easily use the model’s weight coefficients to get an explanation for a given predicted attribution value. When $g$ is a more complex model, such as a neural network, uncovering such explanations is not as straightforward. However, in this section, using the AML drug response prediction problem from Section 3.2 as a case study, we demonstrate multiple methods for probing more complex learned attribution priors to obtain insights into predicted feature attribution values. Using such methods we find that our learned MERGE prior from Section 3.2 independently learns meta-feature to gene importance relationships that agree with prior biological knowledge.

4.1 Attribution Explanations For Understanding Meta-Feature to Gene Relationships

Figure 3: Explanations of feature attributions for the top 10 most important genes as ranked by absolute predicted feature importance. Bar color is proportional to the absolute value of that meta-feature’s contribution to the predicted importance value, as determined by an attribution method.

Given a (potentially complex) learned attribution prior model $g$ and meta-feature values $M_{i,:}$ for the $i$’th feature in a prediction problem, we would like to understand how the values of $M_{i,:}$ relate to the predicted attribution value $g(M_{i,:})$. To do so we can apply a feature attribution method $\Phi'$ to $g$, thereby generating a second-order attribution explanation $\Phi'(g, M_{i,:})$, explaining the $i$’th feature’s predicted importance in terms of human-interpretable meta-features. We apply expected gradients as our $\Phi'$ to the learned MERGE prior from Section 3.2, and visualize our results in Figure 3.
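The simplest instance of such a second-order explanation is the linear case, where each meta-feature's contribution can be read directly from the prior's coefficients. The sketch below assumes a linear prior with coefficient vector `beta` and uses the mean meta-feature vector as the reference point; for a linear model this matches what expected gradients would give with the full meta-feature matrix as background. The function name is hypothetical.

```python
import numpy as np

def linear_prior_explanations(M, beta):
    """Per-feature, per-meta-feature contributions for a linear prior g(m) = m @ beta.

    Entry (i, j) is how much meta-feature j moves feature i's predicted
    importance away from the prediction at the mean meta-feature vector.
    """
    baseline = M.mean(axis=0)
    return (M - baseline) * beta
```

Each row of the result sums to the difference between that feature's predicted importance and the predicted importance at the baseline, so the contributions fully account for the prediction.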

Remarkably we find that our learned MERGE prior, without any input from domain experts, rediscovers relationships in line with prior biological knowledge. We can see in Figure 3 that high predicted gene importance values are explained mostly by expression hubness. This trend is consistent with prior knowledge, as expression hubness has been suspected to drive events in cancer [logsdon2015sparse]. Furthermore, we observe that for a small number of genes mutation appears as a major factor in their importance explanations. This phenomenon also matches up with prior knowledge, as mutations in CEBPA, FLT3, and ELF4 are suspected to play a role in the heterogeneity of drug response in AML patients [lin2005characterization, fasan2014role, small2006flt3, daver2019targeting, suico2017roles].

Figure 4: Learned attribution prior partial dependence plots for (a) expression hubness and (b) mutation. Both plots indicate that our learned prior is capturing a nonlinear relationship between meta-feature value and feature importance.

4.2 Partial Dependence Plots Capture Nonlinear Meta-Feature to Importance Relationships

We can further explore the nature of these learned meta-feature to feature importance relationships by constructing partial dependence plots (PDPs) [friedman2001greedy] to visualize the marginal effect of particular meta-features on predicted attribution values. We display the results of doing so for expression hubness and mutation in Figure 4. For both meta-features we find that our learned attribution prior captures a nonlinear relationship between the meta-feature's value and predicted attribution values. The relationship broadly agrees with prior knowledge; after a certain point both hubness and mutation lead to higher predicted attribution values. This nonlinearity indicates that more complex models may better capture the relationship between a gene’s MERGE features and its potential to drive events in AML than the original linear model used by Lee et al. in [MERGE].
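A partial dependence curve of this kind can be computed by fixing one meta-feature to each value on a grid, leaving the others as observed, and averaging the prior's predictions. The following is a minimal sketch; `g` is any callable mapping a meta-feature matrix to a vector of predicted attributions, and the function name is an assumption for illustration.

```python
import numpy as np

def partial_dependence(g, M, j, grid):
    """Average predicted attribution when meta-feature j is set to each grid value."""
    values = []
    for v in grid:
        M_mod = M.copy()      # leave the original meta-feature matrix untouched
        M_mod[:, j] = v       # overwrite column j for every feature
        values.append(g(M_mod).mean())
    return np.array(values)
```

A straight line in the resulting curve would indicate that a linear prior suffices for that meta-feature, while curvature or a threshold effect suggests a nonlinear relationship.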

4.2.1 MERGE Prior Captures Relevant Gene Pathways

To further confirm that the relationships captured by our learned MERGE prior match biological intuition, we perform Gene Set Enrichment Analysis [GSEA] to see if the top genes as ranked by our learned MERGE attribution prior were enriched for membership in any biological pathways. For comparison we use the number of pathways captured by a learned noise prior as a baseline. In our analysis we use the top 200 genes as ranked by both prior models, and we use the Enrichr [Enrichr] library to check for membership in the Kyoto Encyclopedia of Genes and Genomes [KEGG] 2019 pathways list. While the top genes as ranked by our learned noise prior are not significantly enriched for membership in any pathways after FDR correction, our MERGE prior captures many, as shown in Figure 5. Moreover, we find that among the pathways captured by our MERGE prior lie multiple pathways already known to be associated with AML [milella2001therapeutic, shao2016her2, park2010role, pomeroy2017targeting]. This result further indicates that our learned MERGE prior is capturing meaningful meta-feature to gene importance relationships.
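To make the enrichment step concrete, the sketch below shows a generic over-representation test: a hypergeometric upper-tail p-value for the overlap between a top-gene list and a pathway, followed by Benjamini-Hochberg FDR correction across pathways. This is a textbook procedure offered as an illustration under stated assumptions; it is not claimed to reproduce the exact scoring used by Enrichr, and the function names are hypothetical.

```python
from math import comb
import numpy as np

def hypergeom_pvalue(n_universe, n_pathway, n_selected, n_overlap):
    """P(overlap >= n_overlap) when n_selected genes are drawn from n_universe,
    of which n_pathway belong to the pathway."""
    total = comb(n_universe, n_selected)
    upper = min(n_pathway, n_selected)
    tail = sum(
        comb(n_pathway, i) * comb(n_universe - n_pathway, n_selected - i)
        for i in range(n_overlap, upper + 1)
    )
    return tail / total

def benjamini_hochberg(pvals):
    """Benjamini-Hochberg adjusted p-values (FDR q-values), in the input order."""
    p = np.asarray(pvals, dtype=float)
    order = np.argsort(p)
    ranked = p[order] * len(p) / np.arange(1, len(p) + 1)
    # Enforce monotonicity from the largest p-value downwards.
    ranked = np.minimum.accumulate(ranked[::-1])[::-1]
    q = np.empty_like(p)
    q[order] = np.minimum(ranked, 1.0)
    return q
```

A pathway would then be reported as enriched when its q-value falls below a chosen FDR threshold.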

Pathway (FDR q-values not shown)

  • Acute myeloid leukemia
  • Transcriptional misregulation in cancer
  • Pathways in cancer
  • MAPK signaling pathway
  • MicroRNAs in cancer
  • ErbB signaling pathway
  • PI3K-Akt signaling pathway
  • Ras signaling pathway

Figure 5: (a) Number of pathways captured by a learned MERGE prior vs. by a learned noise prior. (b) Sample of enriched pathways previously linked to AML phenotypes.

5 Discussion

In this work we introduce the learned attribution prior framework, a novel method for biasing the training of neural networks by incorporating prior information about the features used for a prediction task. Unlike other feature-attribution based penalty methods [AttributionPriors, ross2017right, rieger2019interpretations], our framework merely requires the presence of prior information in the form of meta-features, rather than rules hand-crafted by a domain expert. In our experiments we find that jointly learning prediction models and learned attribution priors leads to increased performance on prediction tasks in low data settings. Furthermore, we demonstrate that learned attribution priors admit new methods for model interpretability and for establishing trust in feature attribution methods. Using such methods we demonstrate that our learned attribution priors, without human intervention, independently learn meta-feature to feature relationships that agree with prior human knowledge. The learned prior framework provides a broadly applicable method for incorporating prior information into prediction tasks, and we believe that it is a valuable tool for learning in low-data, trust-critical domains.