
Granger-causal Attentive Mixtures of Experts

by Patrick Schwab, et al.
ETH Zurich

Several methods have recently been proposed to detect salient input features for outputs of neural networks. Those methods offer a qualitative glimpse at feature importance, but they fall short of providing quantifiable attributions that can be compared across decisions and measures of the expected quality of their explanations. To address these shortcomings, we present an attentive mixture of experts (AME) that couples attentive gating with a Granger-causal objective to jointly produce accurate predictions as well as measures of feature importance. We demonstrate the utility of AMEs by determining factors driving demand for medical prescriptions, comparing predictive features for Parkinson's disease and pinpointing discriminatory genes across cancer types.






Neural networks are often criticised for being black-box models [Castelvecchi2016]. Researchers have addressed this criticism by developing tools that provide visualisations and explanations for the decisions of neural networks [Baehrens et al.2010, Simonyan, Vedaldi, and Zisserman2014, Zeiler and Fergus2014, Xu et al.2015, Shrikumar et al.2017, Krause, Perer, and Ng2016, Montavon et al.2017, Koh and Liang2017]. These explanations are desirable for the many machine-learning use cases in which both predictive performance and interpretability are of paramount importance [Kindermans et al.2017, Smilkov et al.2017, Doshi-Velez and Kim2017]. They enable us to argue for the decisions of machine-learning models, show when algorithmic decisions might be biased or discriminating [Hardt et al.2016], help uncover the basis of decisions when there are legal or ethical circumstances that call for explanations [Goodman and Flaxman2016], and may facilitate the discovery of patterns that could advance our understanding of the underlying phenomena [Shrikumar et al.2017].

Figure 1: An overview of attentive mixtures of experts (AMEs). The attentive gating networks (red) attend to the combined hidden state (blue) of the AME. Each expert's attentive gating network assigns an attentive factor to opportunistically control that expert's contribution to the AME's final prediction.

Estimating the relative contribution of individual input features towards outputs of a deep neural network is hard, because the input features undergo multiple hierarchical, interdependent and non-linear transformations as they pass through the network [Montavon et al.2017]. We propose a new approach to feature importance estimation that optimises jointly for predictive performance and accurate assignment of feature importance in a single end-to-end trained neural network. Structurally, our approach builds on the idea of distributing the features of interest among experts in a mixture of experts [Jordan and Jacobs1994]. The mixture of experts uses attentive gating networks to assign importance weights to individual experts (Figure 1). However, when trained naïvely, this structure alone does not generally learn to accurately assign weights that correspond to the importance of the experts' input features. We therefore draw upon a previously unreported connection between Granger-causality and feature importance estimation to define a secondary Granger-causal objective. Using the Granger-causal objective, we ensure that the weights given to individual experts correlate strongly and measurably with their ability to contribute to the decision at hand. Our experiments demonstrate that this optimisation-based approach towards learning to estimate feature importance leads to improvements of several orders of magnitude in computational performance over state-of-the-art methods. In addition, we show that the Granger-causal objective correlates with the expected quality of importance estimates, that AMEs compare favourably to the best existing methods in terms of feature importance estimation accuracy, and that AMEs discover associations that are consistent with those reported by domain experts.¹

¹ Source code at:

Table 1: Comparison of AMEs to several representative methods for feature importance estimation.

Contributions. We present the following contributions:


  • We delineate an end-to-end trained AME that uses attentive gating to assign weights to individual experts.

  • We introduce a Granger-causal objective that measures the degree to which assigned feature importances correlate with the predictive value of features towards individual decisions.

  • We compare AMEs to state-of-the-art importance estimation methods on three datasets. The experiments show that AMEs are significantly faster than existing methods, that AMEs compare favourably to existing methods in terms of attribution accuracy, and that the associations discovered by AMEs are consistent with human experts.

Related Work

There are four main categories of approaches to assessing feature importance in neural networks:

Perturbation-based Approaches.

Perturbation-based approaches attempt to explain the sensitivity of a machine-learning model to changes in its inputs by modelling the impact of local perturbations [Ribeiro, Singh, and Guestrin2016b, Adler et al.2016, Fong and Vedaldi2017, Lundberg and Lee2017]. Examples of perturbation-based approaches are LIME [Ribeiro, Singh, and Guestrin2016b] and SHAP [Lundberg and Lee2017]. The main drawbacks of perturbation-based approaches are (a) that perturbed samples might not be part of the original data distribution [Ribeiro, Singh, and Guestrin2018], and (b) that they are computationally inefficient, as hundreds to thousands of model evaluations are required to sample the space of local perturbations. Perturbation-based approaches are applicable to any machine-learning model [Ribeiro, Singh, and Guestrin2016b, Lundberg and Lee2017].
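As a deliberately minimal illustration of the perturbation principle (not LIME or SHAP themselves, which use far more sophisticated sampling), one can occlude one feature at a time with a baseline value and record the change in the model output; the `occlusion_importance` helper and the toy linear `predict` below are hypothetical:

```python
def occlusion_importance(predict, x, baseline=0.0):
    """Estimate per-feature importance by replacing one feature at a
    time with a baseline value and measuring the absolute change in
    the model output (a crude one-feature-at-a-time perturbation)."""
    y = predict(x)
    scores = []
    for i in range(len(x)):
        perturbed = list(x)
        perturbed[i] = baseline          # occlude feature i
        scores.append(abs(y - predict(perturbed)))
    return scores

# Toy linear model: the feature with the larger weight matters more.
predict = lambda x: 3.0 * x[0] + 0.5 * x[1]
scores = occlusion_importance(predict, [1.0, 1.0])
```

Even this crude variant already needs one extra model evaluation per feature; sampling full neighbourhoods of perturbations, as LIME and SHAP do, multiplies that cost by hundreds to thousands.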

Gradient-based Methods.

Gradient-based approaches are built on the idea of following the gradient from the output nodes of a neural network to the input nodes to obtain the features that the output was most sensitive to [Baehrens et al.2010, Simonyan, Vedaldi, and Zisserman2014]. Gradient-based approaches are therefore only applicable to differentiable models. Several improvements to this technique have since been proposed [Zeiler and Fergus2014, Smilkov et al.2017, Sundararajan, Taly, and Yan2017]. In particular, [Shrikumar et al.2017] introduced the DeepLIFT method that addresses the issue of saturating gradients.

Attentive Models.

Attentive models have been used in various domains to improve both interpretability [Xu et al.2015, Choi et al.2016, Schwab et al.2017, Schwab and Karlen2019] and performance [Bahdanau, Cho, and Bengio2015, Yang et al.2016]. In computer vision, related works used attention in convolutional neural networks (CNNs) that selectively focus on input data [Ba, Mnih, and Kavukcuoglu2014] and internal convolutional filters [Stollenga et al.2014]. However, fundamentally, naïve soft attention mechanisms do not provide any incentive for a neural network to yield attention factors that correlate with feature importance. When used on top of recurrent neural networks (RNNs), attention mechanisms may propagate information across time steps through the recurrent state, and are therefore not guaranteed to accurately represent importance [Sundararajan, Taly, and Yan2017].

Mimic Models.

Mimic models are interpretable models trained to match the output of a black-box model. Rule-based models [Andrews, Diederich, and Tickle1995] and tree models [Schwab and Hlavacs2015, Che et al.2016] are examples of models that have been used as mimic models. Mimic models are not guaranteed to match the original model and may thus not be truthful to the original model.


Irrespective of the category of the approach, we would ideally want a feature importance estimation method to be model-agnostic, i.e. independent of the choice of predictive model [Ribeiro, Singh, and Guestrin2016a]. The main arguments for model-agnosticism are flexibility to choose the predictive model and feature representation as necessary [Ribeiro, Singh, and Guestrin2016a]. However, in practice, the generality of model-agnostic approaches comes at a considerable cost in computational performance and scalability (Table 1). With datasets continuously growing in size and neural networks becoming the preferred choice of model in many domains, model-specific feature importance estimation methods are often the only viable choice. This is evidenced by the recent surge in works applying model-specific approaches to analyse predictive relationships in large-scale datasets [Esteva et al.2017, Ilse, Tomczak, and Welling2018]. In addition, it would be desirable to have a measure of the expected quality of the provided importance estimates. Against the backdrop of estimates that are potentially not truthful to the underlying data, such a measure would enable us to assess the expected estimation accuracy and inform us when accurate estimates can not be expected. To the best of our knowledge, the presented Granger-causal objective is both the first tool that quantifies the expected quality of importance estimates, and the first objective that enables neural networks to learn to estimate feature importance.

Attentive Mixtures of Experts

We consider the setting in which we are given a dataset containing N training samples. Each sample x consists of p input features x_i, with i ∈ [1, p], where p is the number of input features per sample. A ground-truth label y is available for each training sample. Using the labelled training dataset, we wish to train a model that produces (1) accurate output predictions ŷ for new samples for which we do not have labels, and (2) feature importance scores a_i that correspond to the importance of each respective input feature x_i towards the output ŷ. To model this problem setting, we introduce AMEs, a mixture-of-experts model that consists of k experts E_i and their corresponding attentive gating networks G_i (Figure 1). At prediction time, the attentive gating networks output an attention factor a_i for each expert to control its respective contribution c_i to the AME's final prediction ŷ. All of the experts and the attentive gating networks are neural networks with their own parameters and architectures.² AMEs do not impose any restrictions on the experts other than that they need to expose their topmost feature representation h_i and their contribution c_i for a given x.

² The experts are, however, not separate models, because all parts of AMEs are connected, differentiable, and trained end-to-end. An AME is therefore a single model and not an ensemble of models.

As input to the gating networks, the hidden states h_i and local contributions c_i of each expert are concatenated to form the combined hidden state h of the whole AME:

h = (h_1, c_1, h_2, c_2, …, h_k, c_k)     (1)
We denote c_i as the output of the ith expert E_i for the given feature x_i of the input data x, and a_i as the output of the ith attentive gating network with respect to the combined hidden state h of the AME. The output of an AME is then given by:

ŷ = Σ_{i=1}^{k} a_i c_i     (2)
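As a minimal plain-Python sketch (assuming scalar expert contributions and precomputed attention factors), the attention-weighted combination of the experts might look like:

```python
def ame_output(attention_factors, contributions):
    """Combine scalar per-expert contributions c_i into the AME
    prediction y = sum_i a_i * c_i; the attention factors a_i come
    from a softmax, so they are expected to sum to one."""
    assert abs(sum(attention_factors) - 1.0) < 1e-6
    return sum(a * c for a, c in zip(attention_factors, contributions))

# Three experts: the first is weighted most heavily.
y = ame_output([0.7, 0.2, 0.1], [2.0, 1.0, 4.0])
```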


The attention factors a_i modulate the contribution of each expert to the final prediction ŷ based on the AME's combined hidden state h. The attention factors therefore represent the importance of each expert's contribution towards the output ŷ. The motivation behind structuring AMEs as a mixture of experts with input features distributed across experts is to ensure (1) that each expert's contribution c_i can only be based on its respective input feature x_i, and (2) that the importance of x_i towards the final prediction can only be increased by increasing its respective attention factor a_i. Splitting the features across experts guarantees that there cannot be any information leakage across features, and that the attention factors can in turn safely be interpreted as the importance of the input feature x_i towards ŷ upon model convergence. We calculate the attention factors using:




a_i = exp(e_i) / Σ_{j=1}^{k} exp(e_j)     (3)

e_i = v_i^T u_i,  with  u_i = f_i(h) = act(W_i h + b_i)     (4)

Here, f_i corresponds to a single-hidden-layer multi-layer perceptron (MLP) with a non-linear activation function act, a weight matrix W_i and a bias b_i. To compute the attention factors, we first feed the combined state h of the AME into the MLP to obtain u_i as a projected hidden representation of h [Xu et al.2015, Rocktäschel et al.2016, Yang et al.2016]. We then compute the similarity e_i of the projected hidden representation u_i to a per-expert context vector v_i. The context vector can be seen as a fixed high-level representation that answers the question: "What projected hidden representation would be the most informative for this expert?". The per-expert context vector is initialised randomly and learned together with the other network parameters during training. We obtain normalised importance scores from the similarities through a softmax function (Eq. 3). The attention factors are used to weight the contributions of each expert towards the final decision of the AME (Eq. 2). The soft attention mechanism formulated in Equations 3 and 4 closely follows the definitions used in related works [Xu et al.2015, Rocktäschel et al.2016, Yang et al.2016], with two notable exceptions: Firstly, we use the combined hidden state h rather than just the hidden representation of a single expert as input to the soft attention mechanism. This enables the AME to simultaneously take into account the information from all available experts when producing its attention factors and its prediction ŷ, despite not sharing features in the experts themselves. Secondly, we use a separate attentive gating network for each expert to produce the attention factors. This is in contrast to existing works that use a shared representation either over feature maps in a CNN for image data [Xu et al.2015] or over the hidden states of an RNN for sequence data [Xu et al.2015, Choi et al.2016, Yang et al.2016]. Using a shared or overlapping attention mechanism is problematic for importance estimation, as information could potentially leak across features. This is best exemplified by attention mechanisms on top of RNNs, where information can propagate across time steps through the recurrent state, and therefore influence the model output through means other than the attention factor [Sundararajan, Taly, and Yan2017].
The use of separate attention mechanisms prevents information leakage entirely, and at the same time enables us to apply soft attention to non-sequential and non-spatial input data.
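The per-expert gating computation can be sketched in plain Python as follows; the tanh activation and the randomly initialised gate parameters below stand in for the learned W_i, b_i and v_i and are assumptions for illustration:

```python
import math
import random

def attention_factors(h, gates):
    """One attention factor per expert: project the combined hidden
    state h through a single-hidden-layer tanh MLP (W, b), take the
    dot product with the per-expert context vector v as a similarity
    score, and normalise the scores with a softmax across experts."""
    scores = []
    for W, b, v in gates:
        u = [math.tanh(sum(w * x for w, x in zip(row, h)) + bi)
             for row, bi in zip(W, b)]              # projected hidden rep.
        scores.append(sum(ui * vi for ui, vi in zip(u, v)))
    m = max(scores)                                 # numerically stable softmax
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

random.seed(0)
def random_gate(d_in=3, d_hidden=2):
    rnd = lambda: random.uniform(-1.0, 1.0)
    W = [[rnd() for _ in range(d_in)] for _ in range(d_hidden)]
    b = [rnd() for _ in range(d_hidden)]
    v = [rnd() for _ in range(d_hidden)]
    return W, b, v

h = [0.1, -0.4, 0.7]                                # combined hidden state
a = attention_factors(h, [random_gate() for _ in range(3)])
```

Because each expert has its own (W, b, v) triple, no gate shares parameters or inputs beyond the combined state h, mirroring the separation argued for above.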

Granger-causal Objective

A fundamental issue of naïvely-trained soft attention mechanisms is that they provide no incentive to learn feature representations that yield accurate attributions [Sundararajan, Taly, and Yan2017]. Naïvely-trained attentive gating networks may therefore not accurately represent feature importance, or may even collapse towards attractive minima, such as assigning an attention weight of 1 to a single expert and 0 to all others [Bengio et al.2015, Shazeer et al.2017]. To ensure the assigned attention weights correspond to feature importance, we introduce a secondary objective function that measures the mean Granger-causal error (MGE). Granger-causality follows the Humean definition of causality that, under certain assumptions, declares a causal relationship between random variables X and Y if we are better able to predict Y using all available information than if the information apart from X had been used [Granger1969]. Given input sample x, we denote ε_{-i} as the AME's prediction error without including any information from the ith expert, and ε_all as the AME's prediction error when considering all available information. To estimate ε_{-i} and ε_all, we use differentiable auxiliary predictors that receive as input the concatenated hidden representations of all experts excluding the ith expert's hidden representation, and the concatenated hidden representations of all experts, respectively. The auxiliary predictors are trained jointly with the AME.


We then calculate ε_all and ε_{-i} by comparing the auxiliary predictions with the ground-truth labels y using an auxiliary loss function. We use the mean absolute error as the auxiliary loss for regression problems and the categorical cross-entropy for classification problems.


Following [Granger1969], we define the degree Δε_i to which the ith expert is able to contribute to the final output ŷ as the decrease in error associated with adding that expert's information to the set of available information sources:

Δε_i = ε_{-i} − ε_all
This definition of Δε_i naturally resolves cases where only combinations of features enable improvements in the prediction error: both experts would be attributed equally for the decrease. We normalise the desired attribution ω_i corresponding to the ith expert's attention weight for a given input x as:

ω_i = Δε_i / Σ_{j=1}^{k} Δε_j     (10)

Equation (10) normalises the attributions across all experts to ensure that they are on the same scale across decisions. We calculate the Granger-causal objective by computing the average probabilistic distance over the N samples between the target distribution Ω(x) = (ω_1, …, ω_k) and the actual distribution A(x) = (a_1, …, a_k) of attention values, using a distance measure D. The Kullback-Leibler divergence [Kullback1997] is a suitable differentiable D for attention distributions [Itti and Baldi2006].
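The target attributions and the KL-based distance can be sketched in plain Python; clipping negative error decreases at zero and falling back to a uniform target when no expert reduces the error are our assumptions for this sketch, not details from the paper:

```python
import math

def granger_attributions(eps_without, eps_all):
    """Desired attribution per expert: the decrease in error from
    adding that expert's information (clipped at zero, an assumption
    of this sketch), normalised across experts."""
    deltas = [max(e - eps_all, 0.0) for e in eps_without]
    total = sum(deltas)
    if total == 0.0:                        # no expert helps: uniform target
        return [1.0 / len(deltas)] * len(deltas)
    return [d / total for d in deltas]

def kl_divergence(target, actual, eps=1e-12):
    """KL(target || actual): a differentiable distance D between the
    Granger-causal target and the attention distribution."""
    return sum(t * math.log((t + eps) / (a + eps))
               for t, a in zip(target, actual))

# Expert 1 reduces the error the most, so it should get most attention.
omega = granger_attributions(eps_without=[0.9, 0.5, 0.3], eps_all=0.3)
loss = kl_divergence(omega, [0.6, 0.2, 0.2])
```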


Because the Granger-causal loss measures the average probabilistic distance of the actual attributions to the desired Granger-causal attributions, it is valid to use it as a proxy for the expected quality of explanations. A Granger-causal loss of 0 indicates a perfect match with the Granger-causal attributions. We can therefore apply the familiar framework of cross-validation and held-out test data to estimate the expected quality of the importance estimates on unseen data. Finally, the total loss is the sum of the main loss L_main and the Granger-causal loss L_G, weighted by a hyperparameter λ:

L_total = L_main + λ L_G

The core idea of the Granger-causal objective is to train predictors on distinct subsets of the input data to measure how much the exclusion of individual features reduces model performance. This approach to importance estimation is not new [Štrumbelj, Kononenko, and Šikonja2009] and is commonly practiced in ablation studies. In addition, a similar approach, called Shapley value analysis [Lipovetsky and Conklin2001] or Shapley regression values [Lundberg and Lee2017], has been proposed for regression using the game-theoretic concept of Shapley values [Shapley1953, Lundberg and Lee2017]. The main difference between Shapley values and Granger-causality is that feature importance in Shapley values is defined as the marginal contribution towards the model output whereas Granger-causality defines importance in terms of the marginal contribution towards the reduction in prediction error. This subtle change in definition improves computational and memory scalability from factorial to linear in the number of features as we only have to train one additional auxiliary model per feature rather than one for every possible subset of features [Lipovetsky and Conklin2001, Lundberg and Lee2017].

Figure 2: Determining important features on MNIST. (a) The attention map shows which pixels were assigned the most importance. We masked the most important pixels to change the prediction to the target digit (more samples in Appendix B.2). (b) AMEs were comparable to SHAP in the change in log odds at significantly lower runtime when masking over random images. (c) Lower MGEs correlated with better estimates when comparing AMEs with different levels of test-set MGE. (d) Runtime in CPU seconds: AME(=) 3, SHAP 982, LIME 2063.


Experiments

To compare AMEs to state-of-the-art methods for importance estimation, we performed experiments on an established benchmark for importance estimation and two real-world tasks. Our goal was to answer the following questions:


  • How do AMEs compare to state-of-the-art feature importance estimation methods for neural networks in terms of estimation accuracy and computational performance?

  • Does jointly optimising AMEs for predictive performance and accurate estimation of feature importance have an adverse impact on predictive performance?

  • Does a lower test-set MGE correlate with a better expected estimation accuracy on unseen data?

  • How do varying choices of λ impact predictive performance and attribution accuracy?

  • Are the associations identified by AMEs and other methods consistent with those reported by domain experts?

Important Features in Handwritten Digits

We performed the MNIST benchmark proposed by [Shrikumar et al.2017] to compare AMEs to LIME [Ribeiro, Singh, and Guestrin2016b] and SHAP [Lundberg and Lee2017], and to validate whether a lower MGE on test data indicates a better estimation accuracy. Because AMEs provide a single set of importance scores per decision and not one set of importance scores for each possible output class, we adapted the benchmark to use a binary classifier that was trained to distinguish between a source and a target digit class. We used LIME, SHAP and multiple AMEs to determine the most important pixels in an image of the source digit. The most important pixels in this setting corresponded to those pixels which distinguish the source digit from the target digit. We masked the most important pixels (Figure 2a) and calculated the change in log odds for classifying across samples (Figure 2b) to quantify to what degree the feature importance estimation methods were able to identify the important pixels for distinguishing the two digit classes [Shrikumar et al.2017, Lundberg and Lee2017]. We brought LIME and SHAP to the same scale as the AME's attention factors by applying the normalising transform (Eq. 13). We trained AME(=) and AME(=) until convergence (100 epochs, 6 epochs early stopping patience), and stopped the training of AME(=) and AME(=) prematurely after 10 epochs to obtain AMEs with higher test-set MGE values for comparison (Figure 2c). We applied LIME and SHAP to the AME(=) with samples. Appendix B.1 lists architectures and hyperparameters.
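The change-in-log-odds metric from this benchmark can be sketched as follows; the before/after source-class probabilities are hypothetical values for illustration:

```python
import math

def log_odds_change(p_before, p_after):
    """Change in log odds of the source class after masking the
    top-ranked pixels. A larger (more negative) drop means the
    importance estimator found pixels that truly distinguish the
    source digit from the target digit."""
    logit = lambda p: math.log(p / (1.0 - p))
    return logit(p_after) - logit(p_before)

# Masking drops the source-class probability from 0.99 to 0.30.
delta = log_odds_change(0.99, 0.30)
```

Averaging this quantity over many masked images gives the curves compared in Figure 2b.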

Drivers of Medical Prescription Demand

To gain a deeper understanding of what factors drive prescription demand, we trained machine-learning models to predict the next month’s demand for prescription items.


We used data related to prescription demand in England, United Kingdom during the time frame from January 2011 to December 2012. We used data streams split into six feature groups: (a) demand history, (b) online search interest, (c) regional weather, (d) regional demographics, (e) economic factors and (f) labor market data. Appendix C.1 contains a description of the dataset and the list of input features per expert (total number of features ). We applied a random split by practice to separate the data into training (, practices, million time series), validation (, practices, million time series) and test set (, practices, million time series). Because LIME and SHAP did not scale to the size of this test set, we used a subset of 3 practices ( time series) to perform the comparison on importance estimation speed.


We trained autoregressive integrated moving average (ARIMA) models, recurrent neural networks (RNN), feedforward neural networks (FNN), and AMEs trained with (λ>0) and without (λ=0) the Granger-causal objective. Each feature group was represented as an expert in the AMEs, for a total of six expert networks. The AMEs trained without the Granger-causal objective served as a baseline relying on neural attention only. ARIMA served as a baseline that did not make use of any information apart from the revenue history. We applied all feature importance estimation methods except DeepLIFT to the same AME(=). Because, to our knowledge, there currently exists no DeepLIFT propagation rule for attentive gating networks, we used the highest-performing FNN to produce the DeepLIFT explanations (architectures in Appendix C.2).

Method    SMAPE (%)  CPU (hr)
RNN       32.79      0.25 ± 0.07
FNN       32.87      0.06 ± 0.02
AME(=)    33.08      0.45 ± 0.14
AME(=)    33.85      0.21 ± 0.08
ARIMA     34.98      527.96

Table 2: Comparison of the symmetric mean absolute percentage error (SMAPE; in %) on the test set of 1891 practices ( million time series), and the average ± standard deviation of CPU hours used for training and evaluation across the 35 runs.

For all neural networks, we performed a hyperparameter search with hyperparameters chosen at random from predefined ranges (number of hidden layers, hidden units per layer, dropout) over 35 runs, and selected the models that achieved the best performance. We optimised the neural networks for mean squared error (MSE) with a batch size of 256, an early stopping patience of 12 epochs and a learning rate of . For ARIMA, we used the iterative parameter selection algorithm from [Hyndman, Khandakar, and others2007]. To better understand the impact of λ, we trained AMEs with λ chosen on a grid in steps of . We used a neighbourhood of perturbed samples for LIME. Despite our use of a small subset for the comparison on estimation speed, we were only able to apply SHAP with perturbed samples; the expected computation time for applying SHAP with perturbed samples was 9 months of CPU time.

Pre- and Postprocessing.

Prior to fitting the models, we standardised the prescription revenue history data for each time series to the range . We normalised all other features to have zero mean and unit variance.


We compared the predictive accuracy of the different models by computing their symmetric mean absolute percentage error (SMAPE) [Flores1986] on the test set of 1891 practices. We additionally compared the speed of the various feature importance estimation methods by measuring the computation time in CPU seconds used for evaluation and the time in CPU hours used for training.
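SMAPE can be sketched as below, using the common symmetric-denominator form; the exact variant used in the paper (per [Flores1986]) may differ slightly:

```python
def smape(actuals, forecasts):
    """Symmetric mean absolute percentage error (in %), using the
    |F - A| / ((|A| + |F|) / 2) form; terms with a zero denominator
    are counted as zero error."""
    terms = []
    for a, f in zip(actuals, forecasts):
        denom = (abs(a) + abs(f)) / 2.0
        terms.append(0.0 if denom == 0.0 else abs(f - a) / denom)
    return 100.0 * sum(terms) / len(terms)

# Two forecasts that are 10 and 20 units off.
err = smape([100.0, 200.0], [110.0, 180.0])
```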


For importance estimation, AMEs (2 CPU seconds) were faster than DeepLIFT (24 CPU seconds), LIME (10,464 CPU seconds) and SHAP (729,068 CPU seconds) by one, four and six orders of magnitude, respectively. In terms of predictive performance, the AME models performed slightly worse than the FNN and RNN (Table 2). Furthermore, the AME trained with the Granger-causal objective performed worse than the AME trained without it. This indicates that there was a small performance decrease associated with both (i) the use of attentive gating networks, and (ii) optimising jointly to maximise predictive performance and feature importance estimation accuracy. We hypothesise that (ii) is caused by adverse gradient interactions [Doersch and Zisserman2017, Schwab et al.2018] between the main task and the Granger-causal objective. We also found that AMEs indeed effectively learn to match the desired Granger-causal attributions (Eq. 10), with a Pearson correlation of measured on the test set. In contrast, the AME(λ=0) trained without the Granger-causal objective only reached a correlation of . The training time of the AMEs was comparable to that of RNNs.

Impact of λ.

Increasing values of λ led to an exponential improvement in MGE that was accompanied by a minor decrease in MSE (Figure 3). A good middle ground was at , where roughly 80% of the attribution accuracy gains were realised while maintaining most of the predictive performance. The relationship between MGE and MSE was constant for the remaining values of λ.

Figure 3: The mean value (solid lines) and the standard deviation (shaded area) of the MSE (purple) and the MGE (grey) of AMEs trained with varying choices of λ across 35 runs on the test set of 1891 practices.

Discriminatory Genes Across Cancer Types

To pinpoint the genes that differentiate between several types of cancer, we analysed the feature importances in machine-learning models trained to classify gene expression data as breast carcinoma (BRCA), kidney renal clear cell carcinoma (KIRC), colon adenocarcinoma (COAD), lung adenocarcinoma (LUAD) or prostate adenocarcinoma (PRAD).


We used gene expression data from multiple cancer types in 801 individuals from The Cancer Genome Atlas (TCGA) RNA-Seq dataset. To keep visualisations succinct, we used a subset of 100 genes as input data. We applied a stratified random split to separate the data into training (60%, 480 samples), validation (20%, 160 samples) and test (20%, 161 samples) sets.


We trained FNN, AME(=) and AME(=) models (architectures in Appendix D.1). LIME(=) and SHAP(=) were applied to the AME(=), and DeepLIFT to the best FNN, for the same reason as in experiment 2. We also trained five random forests (RF) [Breiman2001] with trees in a binary one-vs.-all classification setting for each cancer type. We used the Gini importance measure [Breiman2001, Genuer, Poggi, and Tuleau-Malot2010, Louppe et al.2013] derived from the RFs as a baseline that was independent of the neural networks.


For each of the 100 gene loci, we used an MLP with a single hidden layer with batch normalisation [Ioffe and Szegedy2015] and a single neuron as expert networks in the AME models. Each expert network received the gene expression at one gene locus as its input. For the FNN baseline, we chose the matching hyperparameters and architecture (100 neurons, 1 hidden layer). We optimised the neural networks with a learning rate of 0.0001, a batch size of 8 and an early stopping patience of 12 epochs. We trained each model with 35 random initialisations.

Method     Recall@10  CPU (s)
AME(=)     10         3
RF         10         12
SHAP(=)    8          6119
LIME(=)    8          80
DeepLIFT   7          6
AME(=)     2          3

Table 3: Comparison of the number of gene-cancer associations that were substantiated by literature evidence in the top 10 genes by average importance (Recall@10), and the number of CPU seconds used to compute them.

Pre- and Postprocessing.

We standardised the input gene expression levels to have zero mean and unit variance. We applied the normalising transform (eq. 13) to DeepLIFT, LIME, SHAP and RF.


We compared the error rates on the test set to assess predictive performance. In order to determine whether the associations identified by the various methods are consistent with those reported by domain experts, we counted the number of gene-cancer associations that were substantiated by literature evidence in the top 10 genes by average importance on the test set (Recall@10). We performed a literature search to determine which associations have previously been reported by domain experts. Appendix D.2 contains references and details of the literature search.
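Recall@10 reduces to counting substantiated associations among the top-ranked genes; the gene names below are placeholders, not loci from the dataset:

```python
def recall_at_k(ranked_genes, substantiated, k=10):
    """Count how many of the top-k genes, ranked by average importance,
    have literature-substantiated associations with the cancer type."""
    return sum(1 for g in ranked_genes[:k] if g in substantiated)

# Hypothetical gene symbols, for illustration only.
ranked = ["g1", "g2", "g3", "g4", "g5"]
hits = recall_at_k(ranked, {"g1", "g3", "g9"}, k=5)
```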


We found that the AME(=) based its decisions primarily on a small number of highly predictive genes for the different types of cancer (Figure 4), and that literature evidence substantiated all of the top 10 links between the respective cancer types and gene loci it reported (Table 3). AME(=) collapsed to assign an attention factor of 1 to one gene locus and 0 to all others for each cancer type, reporting only five non-zero importance scores. DeepLIFT, LIME and RF had difficulties discerning the important from the uninformative genes and assigned moderate levels of importance to many gene loci. DeepLIFT, LIME and SHAP were conflicted about which genes were relevant for which cancer, with several of their top genes having high importance scores for multiple cancers. In contrast, AME(=) clearly distinguished both between cancers and between important and uninformative genes. RF achieved a similar performance in terms of Recall@10 as AME(=). However, RFs can only produce an average set of importance scores for the whole training set, whereas AMEs learn to accurately assign feature importance for individual samples and can therefore explain every single prediction they make. On the test set, the mean ± standard deviation of error rates across 35 runs of FNN, AME(=) and AME(=) were , , and , respectively.


Conclusion

We presented a new approach to estimating feature importance that is based on the idea of distributing the feature groups of interest among experts in a mixture of experts model. The mixture of experts uses attentive gates to assign attention factors to individual experts. We introduced a secondary Granger-causal objective that defines feature importance as the marginal contribution towards prediction accuracy to ensure that the assigned attention factors correlate with the importance of the experts' input features. We showed that AMEs (i) compare favourably to several state-of-the-art methods in importance estimation accuracy, (ii) are significantly faster than existing methods, and (iii) discover associations that are consistent with those reported by domain experts. In addition, we found that there was a trade-off between predictive performance and accurate importance estimation when optimising jointly for both, that training with the Granger-causal objective was crucial to obtain accurate estimates, and that a lower Granger-causal error correlated with a better expected importance estimation accuracy. AMEs are a fast and accurate alternative that may be used when model-agnostic feature importance estimation methods are prohibitively expensive to compute. We believe AMEs could therefore be a first step towards translating the strong performance of neural networks in many domains into a deeper understanding of the underlying data.

(a) AME(=)
(b) AME(=)
(c) DeepLIFT
(d) LIME(=)
(e) SHAP(=)
(f) RF
Figure 4: The importance of specific genes (coloured bars) for distinguishing between multiple cancer types, as measured by the average assignment of attention factors. We report the average attention factors over All and over the per-cancer subsets. The grey bars spanning through the subsets highlight the 10 most discriminatory genes by average attention over All. We bolded the names of those genes whose associations are substantiated by literature evidence.


This work was funded by the Swiss National Science Foundation project No. 167302. We acknowledge the support of the NVIDIA Corporation. Contains public sector information licensed under the Open Government Licence, and data generated by the TCGA Research Network: