AttentionDeepMIL
Implementation of Attention-based Deep Multiple Instance Learning in PyTorch
Multiple instance learning (MIL) is a variation of supervised learning where a single class label is assigned to a bag of instances. In this paper, we state the MIL problem as learning the Bernoulli distribution of the bag label where the bag label probability is fully parameterized by neural networks. Furthermore, we propose a neural network-based permutation-invariant aggregation operator that corresponds to the attention mechanism. Notably, an application of the proposed attention-based operator provides insight into the contribution of each instance to the bag label. We show empirically that our approach achieves comparable performance to the best MIL methods on benchmark MIL datasets and it outperforms other methods on a MNIST-based MIL dataset and two real-life histopathology datasets without sacrificing interpretability.
In typical machine learning problems like image classification it is assumed that an image clearly represents a category (a class). However, in many real-life applications multiple instances are observed and only a general statement of the category is given. This scenario is called multiple instance learning (MIL) (Dietterich et al., 1997; Maron & Lozano-Pérez, 1998) or learning from weakly annotated data (Oquab et al., 2014). The problem of weakly annotated data is especially apparent in medical imaging (Quellec et al., 2017) (e.g., computational pathology, mammography or CT lung screening) where an image is typically described by a single label (benign/malignant) or a Region Of Interest (ROI) is roughly given.

MIL deals with a bag of instances for which a single class label is assigned. Hence, the main goal of MIL is to learn a model that predicts a bag label, e.g., a medical diagnosis. An additional challenge is to discover key instances (Liu et al., 2012), i.e., the instances that trigger the bag label. In the medical domain the latter task is of great interest because of legal issues (footnote 1: according to the European Union General Data Protection Regulation, taking effect in 2018, a user should have the right to obtain an explanation of the decision reached) and its usefulness in clinical practice. In order to solve the primary task of bag classification different methods have been proposed, such as utilizing similarities among bags (Cheplygina et al., 2015b), embedding instances to a compact low-dimensional representation that is further fed to a bag-level classifier (Andrews et al., 2003; Chen et al., 2006), and combining responses of an instance-level classifier (Ramon & De Raedt, 2000; Raykar et al., 2008; Zhang et al., 2006). Only the last approach is capable of providing interpretable results. However, it was shown that the instance level accuracy of such methods is low (Kandemir & Hamprecht, 2015) and in general there is a disagreement among MIL methods at the instance level (Cheplygina et al., 2015a). These issues call into question the usability of current MIL models for interpreting the final decision.

In this paper, we propose a new method that aims at incorporating interpretability into the MIL approach and increasing its flexibility. We formulate the MIL model using the Bernoulli distribution for the bag label and train it by optimizing the log-likelihood function. We show that the application of the Fundamental Theorem of Symmetric Functions provides a general procedure for modeling the bag label probability (the bag score function) that consists of three steps: (i) a transformation of instances to a low-dimensional embedding, (ii) a permutation-invariant (symmetric) aggregation function, and (iii) a final transformation to the bag probability. We propose to parameterize all transformations using neural networks (i.e., a combination of convolutional and fully-connected layers), which increases the flexibility of the approach and allows the model to be trained in an end-to-end manner by optimizing an unconstrained objective function. Last but not least, we propose to replace widely-used permutation-invariant operators such as the maximum operator and the mean operator by a trainable weighted average where weights are given by a two-layered neural network. The two-layered neural network corresponds to the attention mechanism (Bahdanau et al., 2014; Raffel & Ellis, 2015). Notably, the attention weights allow us to find key instances, which could be further used to highlight possible ROIs.
In the experiments we show that our model is on a par with the best classical MIL methods on common benchmark MIL datasets, and that it outperforms other methods on an MNIST-based MIL problem as well as two real-life histopathology image datasets. Moreover, in the image datasets we provide empirical evidence that our model can indicate key instances.
Problem formulation. In the classical (binary) supervised learning problem one aims at finding a model that predicts a value of a target variable, $y \in \{0, 1\}$, for a given instance, $\mathbf{x} \in \mathbb{R}^D$. In the case of the MIL problem, however, instead of a single instance there is a bag of instances, $X = \{\mathbf{x}_1, \ldots, \mathbf{x}_K\}$, that exhibit neither dependency nor ordering among each other. We assume that $K$ could vary for different bags. There is also a single binary label $Y$ associated with the bag. Furthermore, we assume that individual labels exist for the instances within a bag, i.e., $y_1, \ldots, y_K$ and $y_k \in \{0, 1\}$, for $k = 1, \ldots, K$; however, there is no access to those labels and they remain unknown during training. We can re-write the assumptions of the MIL problem in the following form:

$$Y = \begin{cases} 0, & \text{iff } \sum_{k} y_k = 0, \\ 1, & \text{otherwise.} \end{cases} \tag{1}$$
These assumptions imply that a MIL model must be permutation-invariant. Further, the two statements could be re-formulated in a compact form using the maximum operator:

$$Y = \max_{k} \{ y_k \}. \tag{2}$$
Learning a model that tries to optimize an objective based on the maximum over instance labels would be problematic at least for two reasons. First, all gradient-based learning methods would encounter issues with vanishing gradients. Second, this formulation is suitable only when an instance-level classifier is used.
In order to make the learning problem easier, we propose to train a MIL model by optimizing the log-likelihood function where the bag label is distributed according to the Bernoulli distribution with the parameter $\theta(X) \in [0, 1]$, i.e., the probability of $Y = 1$ given the bag of instances $X$.
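To make the objective concrete, the following is a minimal sketch of the MIL label assumption and of the Bernoulli negative log-likelihood for a single bag. The function names `bag_label` and `bernoulli_nll` and the clamping constant `eps` are illustrative assumptions and are not taken from the paper.

```python
import math


def bag_label(instance_labels):
    """MIL assumption: a bag is positive iff at least one instance is positive."""
    return max(instance_labels)


def bernoulli_nll(theta, y, eps=1e-12):
    """Negative log-likelihood of a bag label y in {0, 1} under Bernoulli(theta).

    theta is the predicted probability that the bag label equals 1.
    eps is an assumed clamping constant that avoids log(0).
    """
    theta = min(max(theta, eps), 1.0 - eps)
    return -(y * math.log(theta) + (1 - y) * math.log(1.0 - theta))
```

Minimizing this quantity over all training bags is equivalent to maximizing the log-likelihood; unlike the maximum over the unknown instance labels, it depends only on the bag label.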
MIL approaches. In the MIL setting the bag probability $\theta(X)$ must be permutation-invariant since we assume neither ordering nor dependency of instances within a bag. Therefore, the MIL problem can be considered in terms of a specific form of the Fundamental Theorem of Symmetric Functions with monomials given by the following theorem (Zaheer et al., 2017):
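Theorem 1. A scoring function for a set of instances $X$, $S(X) \in \mathbb{R}$, is a symmetric function (i.e., permutation-invariant to the elements in $X$), if and only if it can be decomposed in the following form:

$$S(X) = g\Big( \sum_{\mathbf{x} \in X} f(\mathbf{x}) \Big), \tag{3}$$

where $f$ and $g$ are suitable transformations.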
This theorem provides a general strategy for modeling the bag probability using the decomposition given in (3). A similar decomposition with $\max$ instead of sum is given by the following theorem (Qi et al., 2017):

Theorem 2. For any $\varepsilon > 0$, a Hausdorff continuous symmetric function $S(X) \in \mathbb{R}$ can be arbitrarily approximated by a function in the form $g(\max_{\mathbf{x} \in X} f(\mathbf{x}))$, where $\max$ is the element-wise vector maximum operator and $f$ and $g$ are continuous functions, that is:

$$\Big| S(X) - g\Big( \max_{\mathbf{x} \in X} f(\mathbf{x}) \Big) \Big| < \varepsilon. \tag{4}$$
The difference between Theorems 1 and 2 is that the former is a universal decomposition while the latter provides an arbitrary approximation. Nonetheless, they both formulate a general three-step approach for classifying a bag of instances: (i) a transformation of instances using the function $f$, (ii) a combination of transformed instances using a symmetric (permutation-invariant) function $\sigma$, (iii) a transformation of combined instances transformed by $f$ using a function $g$. Finally, the expressiveness of the score function relies on the choice of classes of functions for $f$ and $g$.
In the MIL problem formulation the score function in both theorems is the probability $\theta(X)$ and the permutation-invariant function $\sigma$ is referred to as the MIL pooling. The choice of functions $f$, $g$ and $\sigma$ determines a specific approach to modeling the label probability. For a given MIL operator there are two main MIL approaches:
The instance-level approach: The transformation $f$ is an instance-level classifier that returns scores for each instance. Then individual scores are aggregated by MIL pooling to obtain $\theta(X)$. The function $g$ is the identity function.
The embedding-level approach: The function $f$ maps instances to a low-dimensional embedding. MIL pooling is used to obtain a bag representation that is independent of the number of instances in the bag. The bag representation is further processed by a bag-level classifier to provide $\theta(X)$.
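The two approaches can be contrasted in a small sketch of the three-step scheme (transform, pool, transform). The linear-plus-nonlinearity forms chosen here for the instance classifier, the embedding and the bag classifier are toy assumptions for illustration only, as are all function and parameter names; the paper parameterizes these transformations with neural networks.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def instance_level(bag, score_weights):
    """f scores each instance in [0, 1]; pooling is max; g is the identity."""
    scores = [sigmoid(sum(w * x for w, x in zip(score_weights, inst)))
              for inst in bag]
    return max(scores)


def embedding_level(bag, embed_matrix, clf_weights):
    """f embeds each instance; pooling is mean; g is a bag-level classifier."""
    embeddings = [
        [math.tanh(sum(w * x for w, x in zip(row, inst))) for row in embed_matrix]
        for inst in bag
    ]
    num_instances = len(embeddings)
    dim = len(embed_matrix)
    # The pooled vector has a fixed size regardless of the number of instances.
    z = [sum(h[m] for h in embeddings) / num_instances for m in range(dim)]
    return sigmoid(sum(w * zm for w, zm in zip(clf_weights, z)))
```

Both functions return a value that can be read as the bag probability, and both are unchanged when the instances of a bag are reordered, because the max and the mean do not depend on instance order.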
It is advocated in (Wang et al., 2016) that the latter approach is preferable in terms of the bag level classification performance. Since the individual labels are unknown, there is a threat that the instance-level classifier might be trained insufficiently and thus introduce additional error to the final prediction. The embedding-level approach determines a joint representation of a bag and therefore it does not introduce additional bias to the bag-level classifier. On the other hand, the instance-level approach provides a score that can be used to find key instances, i.e., the instances that trigger the bag label. Liu et al. (2012) were able to show that a model that successfully detects key instances is more likely to achieve better bag label predictions. We will show how to modify the embedding-level approach to be interpretable by using a new MIL pooling.
In classical MIL problems it is assumed that instances are represented by features that do not require further processing, i.e., $f$ is the identity. However, for some tasks like image or text analysis additional steps of feature extraction are necessary. Additionally, Theorems 1 and 2 indicate that for a flexible enough class of functions we can model any permutation-invariant score function. Therefore, we consider a class of transformations that are parameterized by neural networks $f_\psi(\cdot)$ with parameters $\psi$ that transform the $k$-th instance into a low-dimensional embedding, $\mathbf{h}_k = f_\psi(\mathbf{x}_k)$, where $\mathbf{h}_k \in \mathcal{H}$ such that $\mathcal{H} = [0, 1]$ for the instance-based approach and $\mathcal{H} = \mathbb{R}^M$ for the embedding-based approach.

Eventually, the parameter $\theta(X)$ is determined by a transformation $g_\phi : \mathcal{H}^K \to [0, 1]$. In the instance-based approach the transformation $g_\phi$ is simply the identity, while in the embedding-based approach it could also be parameterized by a neural network with parameters $\phi$. The former approach is depicted in Figure 6(a) and the latter in Figure 6(b) in the Appendix.
The idea of parameterizing all transformations using neural networks is very appealing because the whole approach can be arbitrarily flexible and it can be trained end-to-end by backpropagation. The only restriction is that the MIL pooling must be differentiable.
The formulation of the MIL problem requires the MIL pooling $\sigma$ to be permutation-invariant. As shown in Theorems 1 and 2, there are two MIL pooling operators that ensure the score function (i.e., the bag probability) to be a symmetric function, namely, the maximum operator:

$$\forall_{m = 1, \ldots, M} : \; z_m = \max_{k = 1, \ldots, K} \{ h_{km} \}, \tag{5}$$

and the mean operator (footnote 2: notice that the weight $\frac{1}{K}$ can be seen as a part of the $f$ function):

$$\mathbf{z} = \frac{1}{K} \sum_{k=1}^{K} \mathbf{h}_k. \tag{6}$$
In fact, other operators could be used, such as the convex maximum operator (i.e., log-sum-exp) (Ramon & De Raedt, 2000), Integrated Segmentation and Recognition (Keeler et al., 1991), noisy-or (Maron & Lozano-Pérez, 1998) and noisy-and (Kraus et al., 2016). These MIL pooling operators could replace $\max$ in Theorem 2 and proofs would follow in a similar manner (see Supplementary in (Qi et al., 2017) for a detailed proof for the maximum operator). All of these operators are differentiable, hence, they could be easily used as a MIL pooling layer in a deep neural network architecture.
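As an illustration of these fixed pooling operators, the sketch below implements element-wise max pooling, mean pooling, and one common form of log-sum-exp pooling over a bag of embeddings stored as a list of equal-length lists. The exact log-sum-exp form and the sharpness parameter `r` (assumed positive) are assumptions chosen for this example.

```python
import math


def max_pool(H):
    """Element-wise maximum over the K embeddings in H."""
    dim = len(H[0])
    return [max(h[m] for h in H) for m in range(dim)]


def mean_pool(H):
    """Element-wise mean over the K embeddings in H."""
    count, dim = len(H), len(H[0])
    return [sum(h[m] for h in H) / count for m in range(dim)]


def lse_pool(H, r=1.0):
    """Log-sum-exp pooling: z_m = (1/r) * log((1/K) * sum_k exp(r * h_km)).

    Assumes r > 0. Large r approaches max pooling, small r approaches mean
    pooling. The maximum is subtracted before exponentiation for stability.
    """
    count, dim = len(H), len(H[0])
    pooled = []
    for m in range(dim):
        top = max(h[m] for h in H)
        total = sum(math.exp(r * (h[m] - top)) for h in H)
        pooled.append(top + math.log(total / count) / r)
    return pooled
```

None of these operators has parameters that adapt to the data, which is the limitation the attention-based pooling is meant to address.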
All MIL pooling operators mentioned in the previous section have a clear disadvantage, namely, they are pre-defined and non-trainable. For instance, the $\max$-operator could be a good choice in the instance-based approach but it might be inappropriate for the embedding-based approach. Similarly, the mean operator is definitely a bad MIL pooling to aggregate instance scores, although it could succeed in calculating the bag representation. Therefore, a flexible and adaptive MIL pooling could potentially achieve better results by adjusting to a task and data. Ideally, such MIL pooling should also be interpretable, a trait that is missing in all operators mentioned in Section 2.3.
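Attention mechanism. We propose to use a weighted average of instances (low-dimensional embeddings) where weights are determined by a neural network. Additionally, the weights must sum to $1$ to be invariant to the size of a bag. The weighted average fulfills the requirements of Theorem 1 where the weights together with the embeddings are part of the $f$ function. Let $H = \{\mathbf{h}_1, \ldots, \mathbf{h}_K\}$ be a bag of $K$ embeddings, then we propose the following MIL pooling:

$$\mathbf{z} = \sum_{k=1}^{K} a_k \mathbf{h}_k, \tag{7}$$

where:

$$a_k = \frac{\exp\{ \mathbf{w}^\top \tanh( \mathbf{V} \mathbf{h}_k^\top ) \}}{\sum_{j=1}^{K} \exp\{ \mathbf{w}^\top \tanh( \mathbf{V} \mathbf{h}_j^\top ) \}}, \tag{8}$$

where $\mathbf{w} \in \mathbb{R}^{L \times 1}$ and $\mathbf{V} \in \mathbb{R}^{L \times M}$ are parameters. Moreover, we utilize the hyperbolic tangent $\tanh(\cdot)$ element-wise non-linearity to include both negative and positive values for proper gradient flow. The proposed construction allows us to discover (dis)similarities among instances.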
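The attention-based pooling can be sketched in plain Python as follows: each embedding is scored by a two-layer network (a matrix, a tanh, and a weight vector), the scores are normalized with a softmax so that they sum to one, and the bag representation is the weighted average of the embeddings. The function names, the list-of-lists layout and the max-subtraction inside the softmax are illustrative assumptions, not the authors' implementation.

```python
import math


def matvec(A, v):
    """Multiply an L x M matrix (list of rows) by a length-M vector."""
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def attention_weights(H, V, w):
    """Softmax over the scores w^T tanh(V h_k) for each embedding h_k in H."""
    scores = [sum(wl * math.tanh(t) for wl, t in zip(w, matvec(V, h)))
              for h in H]
    top = max(scores)  # subtracting the max leaves the softmax unchanged
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(H, V, w):
    """Return the weighted average of the embeddings and the weights used."""
    a = attention_weights(H, V, w)
    dim = len(H[0])
    z = [sum(a_k * h[m] for a_k, h in zip(a, H)) for m in range(dim)]
    return z, a
```

With `w` set to all zeros every score is zero, the weights become uniform, and the pooling reduces to the mean operator; a trained `w` and `V` can move away from that special case.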
Interestingly, the proposed MIL pooling corresponds to a version of the attention mechanism (Lin et al., 2017; Raffel & Ellis, 2015). The main difference is that typically in the attention mechanism all instances are sequentially dependent while here we assume that all instances are independent. Therefore, a naturally arising question is whether the attention mechanism could work without sequential dependencies among instances, and whether it will not simply learn the mean operator. We will address this issue in the experiments.
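Gated attention mechanism. Furthermore, we notice that the $\tanh(\cdot)$ non-linearity could be inefficient to learn complex relations. Our concern follows from the fact that $\tanh(x)$ is approximately linear for $x \in [-1, 1]$, which could limit the final expressiveness of learned relations among instances. Therefore, we propose to additionally use the gating mechanism (Dauphin et al., 2016) together with the $\tanh(\cdot)$ non-linearity, which yields:

$$a_k = \frac{\exp\{ \mathbf{w}^\top ( \tanh( \mathbf{V} \mathbf{h}_k^\top ) \odot \mathrm{sigm}( \mathbf{U} \mathbf{h}_k^\top ) ) \}}{\sum_{j=1}^{K} \exp\{ \mathbf{w}^\top ( \tanh( \mathbf{V} \mathbf{h}_j^\top ) \odot \mathrm{sigm}( \mathbf{U} \mathbf{h}_j^\top ) ) \}}, \tag{9}$$

where $\mathbf{U} \in \mathbb{R}^{L \times M}$ are parameters, $\odot$ is an element-wise multiplication and $\mathrm{sigm}(\cdot)$ is the sigmoid non-linearity. The gating mechanism introduces a learnable non-linearity that potentially removes the troublesome linearity in $\tanh(\cdot)$.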
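A minimal sketch of the gated variant is given below: the tanh branch is multiplied element-wise by a sigmoid gate computed from a second matrix before the weight vector is applied, and the resulting scores are normalized with a softmax. As before, the names and data layout are illustrative assumptions rather than the authors' code.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def matvec(A, v):
    """Multiply an L x M matrix (list of rows) by a length-M vector."""
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def gated_attention_weights(H, V, U, w):
    """Softmax over w^T (tanh(V h_k) * sigm(U h_k)) for each embedding h_k."""
    scores = []
    for h in H:
        tanh_branch = [math.tanh(t) for t in matvec(V, h)]
        gate_branch = [sigmoid(t) for t in matvec(U, h)]
        scores.append(sum(wl * t * g
                          for wl, t, g in zip(w, tanh_branch, gate_branch)))
    top = max(scores)  # subtracting the max leaves the softmax unchanged
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]
```

If `U` is all zeros the gate is constant at 0.5, so the gated scores are simply half of the ungated ones; the gate only adds expressiveness once `U` is learned.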
Flexibility. In principle, the proposed attention-based MIL pooling allows different weights to be assigned to instances within a bag and hence the final representation of the bag could be highly informative for the bag-level classifier. In other words, it should be able to find key instances. Moreover, application of the attention-based MIL pooling together with the transformations $f_\psi$ and $g_\phi$ parameterized by neural networks makes the whole model fully differentiable and adaptive. These two facts make the proposed MIL pooling a potentially very flexible operator that could model an arbitrary permutation-invariant score function. The proposed attention mechanism together with a deep MIL model is depicted in Figure 6(c) in the Appendix.
Interpretability. Ideally, in the case of a positive label ($Y = 1$), high attention weights should be assigned to instances that are likely to have label $y_k = 1$ (key instances). Namely, the attention mechanism makes it easy to interpret the provided decision in terms of instance-level labels. In fact, the attention network does not provide scores as the instance-based classifier does but it can be considered as a proxy to that. The attention-based MIL pooling bridges the instance-level approach and the embedding-level approach.
From a practical point of view, e.g., in computational pathology, it is desirable to provide ROIs together with the final diagnosis to a doctor. Therefore, the attention mechanism is potentially of great interest in practical applications.
MIL pooling. Typically, MIL approaches utilize either the mean pooling or the $\max$ pooling, while the latter is mostly used (Feng & Zhou, 2017; Pinheiro & Collobert, 2015; Zhu et al., 2017). Both operators are non-trainable, which potentially limits their applicability. There are MIL pooling operators that contain global adaptive parameters, such as noisy-and (Kraus et al., 2016); however, their flexibility is restricted. We propose a fully trainable MIL pooling that adapts to new instances.
MIL with neural networks. In the classical work on MIL it is assumed that instances are represented by precomputed features and there is very little need to apply additional feature extraction. Nevertheless, recent work on utilizing fully-connected neural networks in MIL shows that it could still be beneficial (Wang et al., 2016). Similarly, in computer vision the idea of MIL combined with deep learning significantly improves final accuracy (Oquab et al., 2014). In this paper, we follow this line of research since it allows us to apply a flexible class of transformations that can be trained end-to-end by backpropagation.

MIL and attention. The attention mechanism is widely used in deep learning for image captioning (Xu et al., 2015) or text analysis (Bahdanau et al., 2014; Lin et al., 2017). In the context of the MIL problem it has rarely been used and only in a very limited form. In (Pappas & Popescu-Belis, 2014) an attention-based MIL was proposed but attention weights were trained as parameters of an auxiliary linear regression model. This idea was further expanded and the linear regression model was replaced by a one-layer neural network with a single output (Pappas & Popescu-Belis, 2017). The attention-based MIL operator was used very recently in (Qi et al., 2017); however, the attention was calculated using the dot product and it performed worse than the $\max$ operator. Here, we propose to use a two-layered neural network to learn the MIL operator and we show that it outperforms commonly used MIL pooling operators.

MIL for medical imaging. MIL seems to perfectly fit medical imaging where processing a whole image consisting of billions of pixels is computationally infeasible. Moreover, in the medical domain it is very difficult to obtain pixel-level annotations, which drastically reduces the amount of available data. Therefore, it is tempting to divide a medical image into smaller patches that could be further considered as a bag with a single label (Quellec et al., 2017). This idea attracts great interest in computational histopathology where patches could correspond to cells that are believed to indicate malignant changes (Sirinukunwattana et al., 2016). Different MIL approaches were used for histopathology data, such as Gaussian processes (Kandemir et al., 2014, 2016) or a two-stage approach with neural networks and the EM algorithm to determine instance classes (Hou et al., 2016). Other applications of MIL methods in medical imaging are mammography (nodule) classification (Zhu et al., 2017) and microscopy cell detection (Kraus et al., 2016). In this paper, we show that the proposed attention-based deep MIL approach can be used not only to provide the final diagnosis but also to indicate ROIs in a histopathology slide.
In the experiments we aim at evaluating the proposed approach: a MIL model parameterized with neural networks and a (gated) attention-based pooling layer ('Attention' and 'Gated-Attention'). We evaluate our approach on a number of different MIL datasets: five MIL benchmark datasets (Musk1, Musk2, Fox, Tiger, Elephant), an MNIST-based image dataset (MNIST-bags) and two real-life histopathology datasets (Breast Cancer, Colon Cancer). We want to answer two research questions in the experiments: (i) whether our approach achieves the best performance or is comparable to the best performing method, (ii) whether our method can provide interpretable results by using the attention weights that indicate key instances or ROIs.
In order to obtain a fair comparison we use a common evaluation methodology, i.e., 10-fold-cross-validation, and five repetitions per experiment. In the case of MNIST-bags we use a fixed division into training and test set. In order to create test bags we solely sampled images from the MNIST test set. During training we only used images from the MNIST training set. For all experiments we use modified versions of models that have shown high classification performance on the individual datasets (Wang et al., 2016; LeCun et al., 1998; Sirinukunwattana et al., 2016). The MIL pooling layers are located either before the last layer of the model (the embedding-based approach) or after the last layer of the model (the instance-based approach). If an attention-based MIL pooling layer is used, the number of parameters in $\mathbf{V}$ is determined using a validation set. We tested the following dimensions ($L$): 64, 128 and 256. The different dimensions only resulted in minor changes of the model's performance. For layers using the gated attention mechanism $\mathbf{V}$ and $\mathbf{U}$ have the same number of parameters. Finally, all layers were initialized according to Glorot & Bengio (2010) and biases were set to zero.
We compare our approach to various MIL methods on MIL benchmark datasets. On the image datasets our method is compared with instance-level and embedding-level neural networks and commonly used MIL pooling layers ($\max$ and mean). In the following, we use 'Instance+max/mean' and 'Embedding+max/mean' to indicate networks that are built from convolutional layers and fully-connected layers, in contrast to networks built purely from fully-connected layers, which are referred to as 'mi-Net' and 'MI-Net' (Wang et al., 2016).
On MNIST-bags we include an SVM-based MIL model, called MI-SVM. We do not present results of MI-SVM on the histopathology datasets since we could not train the model (including hyperparameter search and five repetitions of the 10-fold-cross-validation procedure) in a reasonable amount of time (footnote 3: learning a single MI-SVM took approximately one week due to the large number of patches).

In order to compare the bag level performance we use the following metrics: the classification accuracy, precision, recall, F-score, and the area under the receiver operating characteristic curve (AUC).
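For reference, the bag-level metrics listed above can be computed as in the following sketch. The decision threshold of 0.5, the function name `bag_metrics`, and the pairwise (rank-based) computation of AUC with ties counted as one half are assumptions made for this example; the paper does not state how the metrics were implemented.

```python
def bag_metrics(y_true, y_prob, threshold=0.5):
    """Accuracy, precision, recall, F-score and AUC for binary bag labels."""
    y_pred = [1 if p >= threshold else 0 for p in y_prob]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)

    accuracy = (tp + tn) / len(y_true)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f_score = 2 * precision * recall / denom if denom else 0.0

    # AUC as the fraction of (positive, negative) pairs ranked correctly.
    pos = [p for t, p in zip(y_true, y_prob) if t == 1]
    neg = [p for t, p in zip(y_true, y_prob) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    auc = wins / (len(pos) * len(neg)) if pos and neg else float("nan")

    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f_score": f_score, "auc": auc}
```

Unlike the other four metrics, AUC does not depend on the threshold, which makes it convenient when comparing models whose predicted probabilities are calibrated differently.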
Details. In the first experiment we aim at verifying whether our approach can compete with the best MIL methods on historically important benchmark datasets. Since all five datasets contain precomputed features and only a small number of instances and bags, neural networks are most likely not well suited. First we predict drug activity (Musk1 and Musk2). A molecule has the desired drug effect if and only if one or more of its conformations bind to the target binding site. Since molecules can adopt multiple shapes, a bag is made up of shapes belonging to the same molecule (Dietterich et al., 1997). The three remaining datasets, Elephant, Fox and Tiger, contain features extracted from images. Each bag consists of a set of segments of an image. For each category, positive bags are images that contain the animal of interest, and negative bags are images that contain other animals (Andrews et al., 2003). For detailed information on the number of bags, instances and features in each dataset see Section 6.3 in the Appendix.
In our experiments we use the same architecture, optimizer and hyperparameters as in the MI-Net model (Wang et al., 2016).
Method | Musk1 | Musk2 | Fox | Tiger | Elephant |
---|---|---|---|---|---|
mi-SVM [1] | 0.874 ± N/A | 0.836 ± N/A | 0.582 ± N/A | 0.784 ± N/A | 0.822 ± N/A |
MI-SVM [1] | 0.779 ± N/A | 0.843 ± N/A | 0.578 ± N/A | 0.840 ± N/A | 0.843 ± N/A |
MI-Kernel [2] | 0.880 ± 0.031 | 0.893 ± 0.015 | 0.603 ± 0.028 | 0.842 ± 0.010 | 0.843 ± 0.016 |
EM-DD [3] | 0.849 ± 0.044 | 0.869 ± 0.048 | 0.609 ± 0.045 | 0.730 ± 0.043 | 0.771 ± 0.043 |
mi-Graph [4] | 0.889 ± 0.033 | 0.903 ± 0.039 | 0.620 ± 0.044 | 0.860 ± 0.037 | 0.869 ± 0.035 |
miVLAD [5] | 0.871 ± 0.043 | 0.872 ± 0.042 | 0.620 ± 0.044 | 0.811 ± 0.039 | 0.850 ± 0.036 |
miFV [5] | 0.909 ± 0.040 | 0.884 ± 0.042 | 0.621 ± 0.049 | 0.813 ± 0.037 | 0.852 ± 0.036 |
mi-Net [6] | 0.889 ± 0.039 | 0.858 ± 0.049 | 0.613 ± 0.035 | 0.824 ± 0.034 | 0.858 ± 0.037 |
MI-Net [6] | 0.887 ± 0.041 | 0.859 ± 0.046 | 0.622 ± 0.038 | 0.830 ± 0.032 | 0.862 ± 0.034 |
MI-Net with DS [6] | 0.894 ± 0.042 | 0.874 ± 0.043 | 0.630 ± 0.037 | 0.845 ± 0.039 | 0.872 ± 0.032 |
MI-Net with RC [6] | 0.898 ± 0.043 | 0.873 ± 0.044 | 0.619 ± 0.047 | 0.836 ± 0.037 | 0.857 ± 0.040 |
Attention | 0.892 ± 0.040 | 0.858 ± 0.048 | 0.615 ± 0.043 | 0.839 ± 0.022 | 0.868 ± 0.022 |
Gated-Attention | 0.900 ± 0.050 | 0.863 ± 0.042 | 0.603 ± 0.029 | 0.845 ± 0.018 | 0.857 ± 0.027 |

Table 1: Results on the classical MIL benchmark datasets. The average classification accuracy (± a standard error of the mean) is reported. [1] (Andrews et al., 2003), [2] (Gärtner et al., 2002), [3] (Zhang & Goldman, 2002), [4] (Zhou et al., 2009), [5] (Wei et al., 2017), [6] (Wang et al., 2016).

Results and discussion. The results of the experiment are presented in Table 1. Our approaches (Attention and Gated-Attention) are comparable with the best performing classical MIL methods (notice the standard error of the mean).
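Since the comparison relies on the standard error of the mean, a short sketch of how a "mean ± standard error" entry can be computed from repeated runs may be useful. The function name and the use of the sample standard deviation (with n - 1 in the denominator) are assumptions; the paper does not specify the estimator.

```python
import math
import statistics


def mean_and_sem(values):
    """Return the mean and the standard error of the mean of repeated runs."""
    n = len(values)
    mean = statistics.fmean(values)
    # Sample standard deviation (n - 1 denominator) divided by sqrt(n).
    sem = statistics.stdev(values) / math.sqrt(n)
    return mean, sem
```

Two methods whose means differ by less than roughly their standard errors, as is the case for many entries in the table, cannot be reliably ranked on these datasets.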
Details. The main disadvantage of the classical MIL benchmark datasets is that instances are represented by precomputed features. In order to consider a more challenging scenario, we propose to investigate a dataset that is created using the well-known MNIST image dataset. A bag is made up of a random number of $28 \times 28$ grayscale images taken from the MNIST dataset. The number of images in a bag is Gaussian-distributed and the closest integer value is taken. A bag is given a positive label if it contains one or more images with the label '9'. We chose '9' since it can be easily mistaken for '7' or '4'. We investigate the influence of the number of bags in the training set as well as the average number of instances per bag on the prediction performance. During evaluation we use a fixed number of $1000$ test bags. For all experiments a LeNet5 model is used (LeCun et al., 1998), see Tables 8 and 9 in the Appendix. The models are trained with the Adam optimization algorithm (Kingma & Ba, 2014). We keep the default parameters for $\beta_1$ and $\beta_2$, see Table 10 in the Appendix. In addition, we compare our method with an SVM-based MIL method (MI-SVM) (Andrews et al., 2003) that uses a Gaussian kernel on raw pixel features (footnote 4: we use code provided with (Doran & Ray, 2014): https://github.com/garydoranjr/misvm).

In the experiments we use different numbers of the mean bag size, namely, $10$, $50$ and $100$, and the variance $2$, $10$, $20$, respectively. Moreover, we use varying numbers of training bags, i.e., $50, 100, 150, 200, 300, 400, 500$. These different settings allow us to verify how different numbers of training bags and different numbers of instances influence MIL models. We compare instance-based and embedding-based approaches parameterized with a neural network (LeNet5) with $\max$ and mean MIL pooling. We use AUC as the evaluation metric.
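The bag construction described above can be sketched without the image data by sampling from a list of digit labels. Everything here is an illustrative assumption: the function name, the stand-in label list, sampling with replacement, and clamping the bag size to at least one instance are not specified in the text.

```python
import random


def sample_bag(labels, mean_size, variance, rng, target=9):
    """Sample one bag of indices and its label from a list of digit labels.

    The bag size is drawn from a Gaussian and rounded to the closest integer;
    it is clamped to at least 1 (an assumption). The bag is positive iff it
    contains at least one instance whose label equals `target`.
    """
    size = max(1, int(round(rng.gauss(mean_size, variance ** 0.5))))
    indices = [rng.randrange(len(labels)) for _ in range(size)]
    bag_label = int(any(labels[i] == target for i in indices))
    return indices, bag_label


# Usage with stand-in labels (in practice these would be the MNIST labels).
rng = random.Random(0)
toy_labels = [i % 10 for i in range(1000)]
bags = [sample_bag(toy_labels, mean_size=10, variance=2, rng=rng)
        for _ in range(5)]
```

Note that `random.gauss` expects a standard deviation, so the variance is converted with a square root before sampling.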
Results and discussion. The AUC results for the mean bag sizes equal to $10$, $50$ and $100$ are presented in Figures 1, 2 and 3, respectively, and detailed results are given in the Appendix. The findings of the experiment are the following: First, the proposed attention-based deep MIL approach performs much better than other methods in the small sample size regime. Moreover, when there is a small effective size of the training set that corresponds to 50-150 bags for around 10 instances per bag (see Figure 1) or 50-100 bags in the case of on average 50 instances in a bag (see Figure 2), our method still achieves significantly higher AUC than all other methods. Second, we notice that our approach is more flexible and obtained better results than the SVM-based approach in all cases except large effective sample sizes (see Figure 3). Third, the embedding-based models performed better than the instance-based models. However, for a sufficient number of training images (number of training bags and training instances per bag) all models achieve very similar results. Fourth, the mean operator performs significantly worse than the $\max$ operator. However, the embedding-based model with the mean operator eventually converged to the best value but always later than the one with $\max$. See Section 6.4 in the Appendix for details.
The results of this experiment indicate that for a small-sample size regime our approach is preferable to others. Since attention serves as a gradient update filter during backpropagation (Wang et al., 2017), instances with higher weights will contribute more to learning the encoder network of instances. This is especially important since medical imaging problems contain only a small number of cases. In general, the more instances are in a bag the easier the MIL task becomes, since the MIL assumption states that every instance in a negative bag is negative. For example, a negative bag of size 100 from the MNIST-bags dataset will include about 11 negative examples per class.
Finally, we present an exemplary result of the attention mechanism in Figure 4. In this example a bag consists of images. For each digit the corresponding attention weight is given by the trained network. The bag is properly predicted as positive and all nines are correctly highlighted. Hence, the attention mechanism works as expected. More examples are given in the Appendix.
Details. Automatic detection of cancerous regions in hematoxylin and eosin (H&E) stained whole-slide images is a task with high clinical relevance. Current supervised approaches utilize pixel-level annotations (Litjens et al., 2017). However, data preparation requires a large amount of time from pathologists, which highly interferes with their daily routines. Hence, a successful solution working with weak labels would hold great promise to reduce the workload of pathologists. In the following, we perform two experiments on classifying weakly-labeled real-life histopathology images of the breast cancer dataset (Breast cancer) (Gelasca et al., 2008) and the colon cancer dataset (Colon cancer) (Sirinukunwattana et al., 2016).
Breast cancer consists of 58 weakly labeled H&E images. An image is labeled malignant if it contains breast cancer cells, otherwise it is benign. We divide every image into $32 \times 32$ patches. This results in 672 patches per bag. A patch is discarded if it contains 75% or more white pixels.
Colon cancer comprises 100 H&E images. The images originate from a variety of tissue appearances from both normal and malignant regions. For every image the majority of nuclei of each cell were marked. In total there are 22,444 nuclei with an associated class label, i.e., epithelial, inflammatory, fibroblast, and miscellaneous. A bag is composed of $27 \times 27$ patches. Furthermore, a bag is given a positive label if it contains one or more nuclei from the epithelial class. Tagging epithelial cells is highly relevant from a clinical point of view, since colon cancer originates from epithelial cells (Ricci-Vitiani et al., 2007).
For both datasets we use the model proposed in (Sirinukunwattana et al., 2016) for the transformation $f_\psi$. All models are trained with the Adam optimization algorithm (Kingma & Ba, 2014). Due to the limited amount of data samples in both datasets we performed data augmentation to prevent overfitting. See the Appendix for further details.
Method | Accuracy | Precision | Recall | F-score | AUC |
---|---|---|---|---|---|
Instance+max | 0.614 ± 0.020 | 0.585 ± 0.03 | 0.477 ± 0.087 | 0.506 ± 0.054 | 0.612 ± 0.026 |
Instance+mean | 0.672 ± 0.026 | 0.672 ± 0.034 | 0.515 ± 0.056 | 0.577 ± 0.049 | 0.719 ± 0.019 |
Embedding+max | 0.607 ± 0.015 | 0.558 ± 0.013 | 0.546 ± 0.070 | 0.543 ± 0.042 | 0.650 ± 0.013 |
Embedding+mean | 0.741 ± 0.023 | 0.741 ± 0.023 | 0.654 ± 0.054 | 0.689 ± 0.034 | 0.796 ± 0.012 |
Attention | 0.745 ± 0.018 | 0.718 ± 0.021 | 0.715 ± 0.046 | 0.712 ± 0.025 | 0.775 ± 0.016 |
Gated-Attention | 0.755 ± 0.016 | 0.728 ± 0.016 | 0.731 ± 0.042 | 0.725 ± 0.023 | 0.799 ± 0.020 |

Table 2: Results on Breast cancer (average ± a standard error of the mean).

Table 3: Results on Colon cancer.
Method | Accuracy | Precision | Recall | F-score | AUC |
---|---|---|---|---|---|
Instance+max | 0.842 ± 0.021 | 0.866 ± 0.017 | 0.816 ± 0.031 | 0.839 ± 0.023 | 0.914 ± 0.010 |
Instance+mean | 0.772 ± 0.012 | 0.821 ± 0.011 | 0.710 ± 0.031 | 0.759 ± 0.017 | 0.866 ± 0.008 |
Embedding+max | 0.824 ± 0.015 | 0.884 ± 0.014 | 0.753 ± 0.020 | 0.813 ± 0.017 | 0.918 ± 0.010 |
Embedding+mean | 0.860 ± 0.014 | 0.911 ± 0.011 | 0.804 ± 0.027 | 0.853 ± 0.016 | 0.940 ± 0.010 |
Attention | 0.904 ± 0.011 | 0.953 ± 0.014 | 0.855 ± 0.017 | 0.901 ± 0.011 | 0.968 ± 0.009 |
Gated-Attention | 0.898 ± 0.020 | 0.944 ± 0.016 | 0.851 ± 0.035 | 0.893 ± 0.022 | 0.968 ± 0.010 |
Results and discussion. We present results in Tables 2 and 3 for Breast cancer and Colon cancer, respectively. First, the results confirm our findings from the MNIST-bags experiment that our approach outperforms the other methods, a trend that is especially visible in the small-sample regime of MNIST-bags. Surprisingly, the embedding-based method with max pooling failed almost completely on Breast cancer, but in general this dataset is difficult due to the high variability of slides and the small number of cases. The proposed method is not only the most accurate but also achieves the highest recall. High recall is especially important in the medical domain, since false negatives could lead to severe consequences, including patient fatality. We also notice that the gated-attention mechanism performs better than the plain attention mechanism on Breast cancer, while the two behave similarly on Colon cancer.
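For readers less familiar with the reported metrics, the sketch below shows how accuracy, precision, recall and F-score relate to the confusion counts for binary bag labels, and why a missed positive bag lowers recall. The function name and the toy labels are illustrative assumptions; this is not the evaluation code used for the tables.

```python
def binary_metrics(y_true, y_pred):
    """Accuracy, precision, recall and F-score for 0/1 bag labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f_score = 2 * precision * recall / denom if denom else 0.0
    return {
        "accuracy": (tp + tn) / len(pairs),
        "precision": precision,
        "recall": recall,
        "f_score": f_score,
    }


# Toy example: one malignant bag is missed (false negative),
# one benign bag is flagged (false positive).
y_true = [1, 1, 1, 0, 0, 0]
y_pred = [1, 1, 0, 0, 0, 1]
metrics = binary_metrics(y_true, y_pred)
print(metrics)
```

In this toy case every metric equals 2/3; each additional missed malignant bag would reduce recall while leaving precision unaffected or higher.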
Finally, we demonstrate the usefulness of the attention mechanism in providing ROIs. In Figure 5 we show a histopathology image divided into patches containing (mostly) single cells. We create a heatmap by multiplying each patch by its corresponding attention weight. Although only image-level annotations are used during training, the heatmap in Figure 5(d) matches the ground truth in Figure 5(c) to a substantial degree. Additionally, we notice that the instance-based classifier tends to select only a small subset of positive patches (see Figure 10(e) in the Appendix), which confirms the low instance accuracy of the instance-based approach discussed in (Kandemir & Hamprecht, 2015). For more examples, please see the Appendix.
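The heatmap construction described here, scaling each patch by its attention weight, can be sketched as below. Two points are assumptions made for the illustration: patches are small grayscale grids stored as nested lists, and the weights are min-max rescaled to [0, 1] before multiplication so that the most attended patch keeps its original intensity.

```python
def rescale(weights):
    """Min-max rescale attention weights to the range [0, 1]."""
    lo, hi = min(weights), max(weights)
    if hi == lo:
        return [1.0 for _ in weights]
    return [(w - lo) / (hi - lo) for w in weights]


def attention_heatmap(patches, weights):
    """Multiply every pixel of each patch by the patch's rescaled weight."""
    scaled = rescale(weights)
    return [[[value * s for value in row] for row in patch]
            for patch, s in zip(patches, scaled)]


# Three 2x2 grayscale patches and their (hypothetical) attention weights.
patches = [[[100, 100], [100, 100]] for _ in range(3)]
weights = [0.1, 0.7, 0.2]
heatmap = attention_heatmap(patches, weights)
print(heatmap[0][0][0], heatmap[1][0][0])  # 0.0 100.0
```

Patches with low attention fade toward black, while highly attended patches stay visible, which is what makes the weights usable as a ROI indicator.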
The obtained results again confirm that the proposed approach attains high predictive performance and properly highlights ROIs. Moreover, the attention weights can be used to create a reliable heatmap.
In this paper, we proposed a flexible and interpretable MIL approach that is fully parameterized by neural networks. We outlined the usefulness of deep learning for modeling a permutation-invariant bag score function in terms of the Fundamental Theorem of Symmetric Functions. Moreover, we presented a trainable MIL pooling based on the (gated) attention mechanism. We showed empirically on five MIL datasets, one image corpus and two real-life histopathology datasets that our method is on a par with the best-performing methods or performs the best in terms of different evaluation metrics. Additionally, we showed that our approach provides an interpretation of the decision by presenting ROIs, which is extremely important in many practical applications.
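To make the attention-based MIL pooling concrete, here is a minimal pure-Python sketch of the plain (non-gated) variant: each instance embedding receives a score through a tanh layer, the scores are normalized with a softmax, and the bag embedding is the weighted average of the instances. The parameter names `V` and `w`, the toy values, and the list-based representation are assumptions for illustration; the released PyTorch implementation should be consulted for the actual model.

```python
import math


def attention_pool(H, V, w):
    """Attention-weighted average of instance embeddings.

    H: list of K instance embeddings, each a list of M floats.
    V: L x M projection matrix given as a list of L rows.
    w: list of L floats.
    Returns (z, a): the bag embedding and the K attention weights.
    """
    scores = []
    for h in H:
        hidden = [math.tanh(sum(v * x for v, x in zip(row, h))) for row in V]
        scores.append(sum(wi * hi for wi, hi in zip(w, hidden)))
    top = max(scores)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    a = [e / total for e in exps]
    z = [sum(ak * h[m] for ak, h in zip(a, H)) for m in range(len(H[0]))]
    return z, a


H = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]  # three instances, M = 2
V = [[0.5, -0.5], [1.0, 1.0]]             # L = 2 hidden units
w = [1.0, 2.0]
z, a = attention_pool(H, V, w)
print(a)  # attention weights, summing to 1
```

Because the weights are computed per instance and then summed, reordering the instances in `H` leaves `z` unchanged, which is the permutation invariance the pooling operator needs. The weights `a` are also what is inspected to identify key instances.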
We strongly believe that the presented line of research is worth pursuing further. Here we focused on the binary MIL problem; however, multi-class MIL is more interesting and challenging (Feng & Zhou, 2017). Moreover, in some applications it is worth considering repulsion points (Scott et al., 2005), i.e., instances for which a bag is always negative, or assuming dependencies among instances within a bag (Zhou et al., 2009). We leave the investigation of these issues for future research.
The authors are very grateful to Rianne van den Berg for insightful remarks and discussions.
Maximilian Ilse was funded by the Nederlandse Organisatie voor Wetenschappelijk Onderzoek (Grant "DLMedIa: Deep Learning for Medical Image Analysis").
Jakub Tomczak was funded by the European Commission within the Marie Skłodowska-Curie Individual Fellowship (Grant No. 702666, "Deep learning and Bayesian inference for medical imaging").
In Figure 6 we present three deep MIL approaches discussed in the paper.
The implementation of our methods is available online at https://github.com/AMLab-Amsterdam/AttentionDeepMIL. All experiments were run on an NVIDIA TITAN X Pascal GPU with a batch size of 1 (i.e., one bag) for all datasets.
In Table 1 a general description of the five benchmark MIL datasets used in the experiments is given. In Tables 5 and 6 we present architectures of the embedding-based and the instance-based models, respectively. We denote a fully-connected layer by 'fc', and the number of output hidden units is provided after a dash. The ReLU non-linearity was used. In Table 7 the details of the optimization (learning) procedure are given. We provide the hyperparameter values determined by the model selection procedure, i.e., those for which the highest validation performance was achieved.
Dataset | # of bags | # of instances | # of features |
---|---|---|---|
Musk1 | 92 | 476 | 166 |
Musk2 | 102 | 6598 | 166 |
Tiger | 200 | 1220 | 230 |
Fox | 200 | 1302 | 230 |
Elephant | 200 | 1391 | 230 |
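Table 5: Architecture of the embedding-based model for the benchmark MIL datasets.
Layer | Type |
---|---|
1 | fc-256 + ReLU |
2 | dropout |
3 | fc-128 + ReLU |
4 | dropout |
5 | fc-64 + ReLU |
6 | dropout |
7 | mil-max/mil-mean/mil-attention-64 |
8 | fc-1 + sigmoid |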
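Table 6: Architecture of the instance-based model for the benchmark MIL datasets.
Layer | Type |
---|---|
1 | fc-256 + ReLU |
2 | dropout |
3 | fc-128 + ReLU |
4 | dropout |
5 | fc-64 + ReLU |
6 | dropout |
7 | fc-1 + sigmoid |
8 | mil-max/mil-mean |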
Table 7: Optimization details for the benchmark MIL datasets.
Experiment | Optimizer | Momentum | Learning rate | Weight decay | Epochs | Stopping criteria |
---|---|---|---|---|---|---|
Musk1 | SGD | 0.9 | 0.0005 | 0.005 | 100 | lowest validation error and loss |
Musk2 | SGD | 0.9 | 0.0005 | 0.03 | 100 | lowest validation error and loss |
Tiger | SGD | 0.9 | 0.0001 | 0.01 | 100 | lowest validation error and loss |
Fox | SGD | 0.9 | 0.0005 | 0.005 | 100 | lowest validation error and loss |
Elephant | SGD | 0.9 | 0.0001 | 0.005 | 100 | lowest validation error and loss |
In Tables 8 and 9 we present architectures of the embedding-based and the instance-based models for MNIST-bags, respectively. We denote a convolutional layer by 'conv'; in brackets we provide the kernel size, stride and padding, and the number of kernels is provided after a dash. The max-pooling layer is denoted by 'maxpool', and the pooling size is given in brackets. The ReLU non-linearity was used. In Table 10 the details of the optimization (learning) procedure for the deep MIL approach are given. The details of the SVM are given in Table 11. We provide the hyperparameter values determined by the model selection procedure, i.e., those for which the highest validation performance was achieved.
Table 8: Architecture of the embedding-based model for MNIST-bags.
Layer | Type |
---|---|
1 | conv(5,1,0)-20 + ReLU |
2 | maxpool(2,2) |
3 | conv(5,1,0)-50 + ReLU |
4 | maxpool(2,2) |
5 | fc-500 + ReLU |
6 | mil-max/mil-mean/mil-attention-128 |
7 | fc-1 + sigmoid |
Table 9: Architecture of the instance-based model for MNIST-bags.
Layer | Type |
---|---|
1 | conv(5,1,0)-20 + ReLU |
2 | maxpool(2,2) |
3 | conv(5,1,0)-50 + ReLU |
4 | maxpool(2,2) |
5 | fc-500 + ReLU |
6 | fc-1 + sigmoid |
7 | mil-max/mil-mean |
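As a quick check on the layer notation, the spatial size after each convolution and pooling layer in the MNIST-bags models can be worked out by hand or with the short sketch below. It assumes 28×28 MNIST inputs and a stride of 2 for the 2×2 max-pooling; neither value is stated in the tables, so treat the result as an illustration of the arithmetic rather than a specification.

```python
def out_size(size, kernel, stride, padding=0):
    """Output width/height of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(input_size=28, last_channels=50):
    """Number of inputs reaching the fc-500 layer in the conv stack above."""
    size = out_size(input_size, kernel=5, stride=1)  # conv(5,1,0)-20 -> 24
    size = out_size(size, kernel=2, stride=2)        # maxpool(2,2)   -> 12
    size = out_size(size, kernel=5, stride=1)        # conv(5,1,0)-50 -> 8
    size = out_size(size, kernel=2, stride=2)        # maxpool(2,2)   -> 4
    return last_channels * size * size


print(flattened_features())  # 800
```

Under these assumptions the fc-500 layer receives 50 × 4 × 4 = 800 features per instance.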
Table 10: Optimization details for the deep MIL models on MNIST-bags.
Experiment | Optimizer | β1, β2 | Learning rate | Weight decay | Epochs | Stopping criteria |
---|---|---|---|---|---|---|
All | Adam | 0.9, 0.999 | 0.0005 | 0.0001 | 200 | lowest validation error+loss |
Table 11: Details of the SVM.
Model | Features | Kernel | C | γ | Max iterations |
---|---|---|---|---|---|
MI-SVM | Raw pixel values | RBF | 5 | 0.0005 | 200 |
In Tables 12, 13 and 14 we present the test AUC for bags with 10, 50 and 100 instances on average, respectively.
In Figure 7 a negative bag is presented. In Figure 8 a positive bag with a single ’9’ is given. In Figure 9 a positive bag with multiple ’9’s is presented. In all figures attention weights are provided and in the case of positive bags a red rectangle highlights positive instances.
Table 12: Test AUC for bags with 10 instances on average.
# of training bags | 50 | 100 | 150 | 200 | 300 | 400 | 500 |
---|---|---|---|---|---|---|---|
Instance+max | 0.553 ± 0.053 | 0.745 ± 0.100 | 0.960 ± 0.004 | 0.979 ± 0.001 | 0.984 ± 0.001 | 0.986 ± 0.001 | 0.986 ± 0.001 |
Instance+mean | 0.663 ± 0.014 | 0.676 ± 0.012 | 0.694 ± 0.010 | 0.694 ± 0.017 | 0.709 ± 0.020 | 0.693 ± 0.023 | 0.712 ± 0.018 |
MI-SVM | 0.697 ± 0.054 | 0.851 ± 0.009 | 0.862 ± 0.008 | 0.898 ± 0.014 | 0.926 ± 0.004 | 0.942 ± 0.002 | 0.948 ± 0.002 |
Embedded+max | 0.713 ± 0.016 | 0.914 ± 0.011 | 0.954 ± 0.005 | 0.968 ± 0.001 | 0.980 ± 0.001 | 0.981 ± 0.003 | 0.986 ± 0.002 |
Embedded+mean | 0.695 ± 0.026 | 0.841 ± 0.027 | 0.926 ± 0.004 | 0.953 ± 0.004 | 0.974 ± 0.002 | 0.980 ± 0.001 | 0.984 ± 0.002 |
Attention | 0.768 ± 0.054 | 0.948 ± 0.007 | 0.949 ± 0.006 | 0.970 ± 0.003 | 0.980 ± 0.000 | 0.982 ± 0.001 | 0.986 ± 0.001 |
Gated Attention | 0.753 ± 0.054 | 0.916 ± 0.013 | 0.955 ± 0.003 | 0.974 ± 0.002 | 0.980 ± 0.004 | 0.983 ± 0.002 | 0.987 ± 0.001 |
Table 13: Test AUC for bags with 50 instances on average.
# of training bags | 50 | 100 | 150 | 200 | 300 | 400 | 500 |
---|---|---|---|---|---|---|---|
Instance+max | 0.576 ± 0.059 | 0.715 ± 0.096 | 0.937 ± 0.045 | 0.992 ± 0.002 | 0.994 ± 0.001 | 0.997 ± 0.001 | 0.997 ± 0.001 |
Instance+mean | 0.737 ± 0.014 | 0.744 ± 0.029 | 0.824 ± 0.012 | 0.813 ± 0.030 | 0.722 ± 0.021 | 0.728 ± 0.017 | 0.798 ± 0.011 |
MI-SVM | 0.824 ± 0.067 | 0.946 ± 0.004 | 0.959 ± 0.002 | 0.967 ± 0.002 | 0.975 ± 0.001 | 0.976 ± 0.001 | 0.979 ± 0.001 |
Embedded+max | 0.872 ± 0.039 | 0.984 ± 0.005 | 0.992 ± 0.001 | 0.996 ± 0.001 | 0.996 ± 0.001 | 0.997 ± 0.001 | 0.997 ± 0.001 |
Embedded+mean | 0.841 ± 0.013 | 0.906 ± 0.046 | 0.983 ± 0.005 | 0.992 ± 0.001 | 0.996 ± 0.001 | 0.997 ± 0.001 | 0.997 ± 0.001 |
Attention | 0.967 ± 0.010 | 0.982 ± 0.003 | 0.990 ± 0.002 | 0.993 ± 0.002 | 0.989 ± 0.003 | 0.994 ± 0.001 | 0.995 ± 0.001 |
Gated Attention | 0.920 ± 0.042 | 0.977 ± 0.006 | 0.993 ± 0.003 | 0.991 ± 0.002 | 0.994 ± 0.002 | 0.995 ± 0.001 | 0.996 ± 0.001 |
Table 14: Test AUC for bags with 100 instances on average.
# of training bags | 50 | 100 | 150 | 200 | 300 | 400 | 500 |
---|---|---|---|---|---|---|---|
Instance+max | 0.543 ± 0.054 | 0.804 ± 0.107 | 0.899 ± 0.086 | 0.999 ± 0.000 | 1.000 ± 0.000 | 1.000 ± 0.000 | 1.000 ± 0.000 |
Instance+mean | 0.842 ± 0.023 | 0.855 ± 0.025 | 0.824 ± 0.014 | 0.896 ± 0.037 | 0.859 ± 0.029 | 0.899 ± 0.012 | 0.868 ± 0.016 |
MI-SVM | 0.871 ± 0.060 | 0.991 ± 0.002 | 0.994 ± 0.002 | 0.996 ± 0.001 | 0.997 ± 0.001 | 0.998 ± 0.001 | 0.998 ± 0.001 |
Embedded+max | 0.977 ± 0.009 | 0.999 ± 0.001 | 1.000 ± 0.000 | 1.000 ± 0.000 | 1.000 ± 0.000 | 1.000 ± 0.000 | 1.000 ± 0.000 |
Embedded+mean | 0.959 ± 0.010 | 0.990 ± 0.003 | 0.998 ± 0.001 | 0.900 ± 0.089 | 1.000 ± 0.000 | 1.000 ± 0.000 | 1.000 ± 0.000 |
Attention | 0.996 ± 0.001 | 0.998 ± 0.001 | 0.999 ± 0.000 | 0.998 ± 0.001 | 1.000 ± 0.000 | 1.000 ± 0.000 | 1.000 ± 0.000 |
Gated Attention | 0.998 ± 0.001 | 0.999 ± 0.000 | 0.998 ± 0.001 | 0.998 ± 0.001 | 0.999 ± 0.000 | 1.000 ± 0.000 | 1.000 ± 0.000 |
We randomly adjust the amount of H&E by decomposing the RGB color of the tissue into the H&E color space (Ruifrok & Johnston, 2001), followed by multiplying the magnitude of H&E for a pixel by two i.i.d. Gaussian random variables with expectation equal to one. We randomly rotate and mirror every patch. Lastly, we perform color normalization on every patch.
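The geometric part of this augmentation (random rotation and mirroring) can be sketched as follows. The sketch assumes rotations by multiples of 90 degrees and a 50% chance of a horizontal flip; the text does not specify the rotation angles or flip probability, so both are illustrative choices, as are the function names. The H&E color perturbation and color normalization are not covered here.

```python
import random


def rotate90(patch):
    """Rotate a 2-D patch (list of rows) by 90 degrees clockwise."""
    return [list(row) for row in zip(*patch[::-1])]


def mirror(patch):
    """Flip a 2-D patch horizontally."""
    return [row[::-1] for row in patch]


def random_rotate_and_mirror(patch, rng):
    """Apply 0 to 3 quarter turns, then a horizontal flip with probability 0.5."""
    patch = [list(row) for row in patch]  # work on a copy
    for _ in range(rng.randrange(4)):
        patch = rotate90(patch)
    if rng.random() < 0.5:
        patch = mirror(patch)
    return patch


patch = [[1, 2], [3, 4]]
print(rotate90(patch))  # [[3, 1], [4, 2]]
print(mirror(patch))    # [[2, 1], [4, 3]]
augmented = random_rotate_and_mirror(patch, random.Random(0))
```

Passing an explicit `random.Random` instance keeps the augmentation reproducible across runs.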
In Tables 15 and 16 we present architectures of the embedding-based and the instance-based models for the histopathology datasets, respectively. In Table 17 the details of the optimization (learning) procedure for the deep MIL approach are given. We provide the hyperparameter values determined by the model selection procedure, i.e., those for which the highest validation performance was achieved.
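Table 15: Architecture of the embedding-based model for the histopathology datasets.
Layer | Type |
---|---|
1 | conv(4,1,0)-36 + ReLU |
2 | maxpool(2,2) |
3 | conv(3,1,0)-48 + ReLU |
4 | maxpool(2,2) |
5 | fc-512 + ReLU |
6 | dropout |
7 | fc-512 + ReLU |
8 | dropout |
9 | mil-max/mil-mean/mil-attention-128 |
10 | fc-1 + sigmoid |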
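Table 16: Architecture of the instance-based model for the histopathology datasets.
Layer | Type |
---|---|
1 | conv(4,1,0)-36 + ReLU |
2 | maxpool(2,2) |
3 | conv(3,1,0)-48 + ReLU |
4 | maxpool(2,2) |
5 | fc-512 + ReLU |
6 | dropout |
7 | fc-512 + ReLU |
8 | dropout |
9 | fc-1 + sigmoid |
10 | mil-max/mil-mean |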
Table 17: Optimization details for the deep MIL models on the histopathology datasets.
Experiment | Optimizer | β1, β2 | Learning rate | Weight decay | Epochs | Stopping criteria |
---|---|---|---|---|---|---|
All | Adam | 0.9, 0.999 | 0.0001 | 0.0005 | 100 | lowest validation error+loss |