Image recognition has made great strides in recent years, spawning landmark architectures such as AlexNet (Krizhevsky et al., 2012) and ResNet (He et al., 2016). These networks are typically designed and optimized for datasets like ImageNet (Russakovsky et al., 2015), which consist of natural images well below one megapixel (for instance, a 256×256 image corresponds to only 0.06 megapixels). In contrast, real-world applications often rely on high-resolution images that reveal detailed information about an object of interest. For example, in self-driving cars, megapixel images are beneficial for recognizing distant traffic signs far in advance and reacting in time (Sahin, 2019). In medical imaging, a pathology diagnosis system has to process gigapixel microscope slides to recognize cancer cells, as illustrated in Fig. 1.
Training neural networks on high-resolution images is challenging and can lead to out-of-memory errors even on dedicated high-performance hardware. Although downsizing the image can fix this problem, details critical for recognition may be lost in the process (Sabottke and Spieler, 2020; Katharopoulos and Fleuret, 2019). Reducing the batch size is another common approach to decrease memory usage, but it does not scale to arbitrarily large inputs and may lead to instabilities in networks involving batch normalization (Lian and Liu, 2019). Distributed learning across multiple devices, on the other hand, increases available resources but is more costly and incurs higher energy consumption (Strubell et al., 2019).
We propose Iterative Patch Selection (IPS), a simple patch-based approach that decouples memory consumption from the input size and thus enables the efficient processing of high-resolution images without running out of memory. IPS works in two steps: First, the most salient patches of an image are identified in no-gradient mode. Then, only the selected patches are aggregated to train the network. We find that the attention scores of a cross-attention-based transformer link both of these steps and have a close connection to Multiple Instance Learning (MIL).
In the experiments, we demonstrate strong performance across three very different domains and training regimes: traffic sign recognition on megapixel images, multi-task classification on synthetic megapixel MNIST digits, and memory-efficient learning on the gigapixel CAMELYON16 benchmark using self-supervised pre-training together with our method. Furthermore, our method exhibits significantly lower memory consumption than various baselines. For example, when scaling megapixel MNIST images from 1k to 10k pixels per side at a batch size of 16, we can keep peak memory usage at a constant 1.7 GB while maintaining high accuracy, in contrast to a comparable CNN, which already consumes 24.6 GB at a resolution of 2k×2k. In an ablation study, we further analyze and provide insights into the key factors driving computational efficiency in IPS. Finally, we visualize exemplary attention distributions and present an approach to obtain patch-level class probabilities in a weakly-supervised multi-label classification setting.
We regard an image as a set of N patches. Each patch is embedded independently by a shared encoder network, resulting in D-dimensional representations. Given the embeddings, we select the M most salient patches and aggregate the information across these patches for the classification task. Thus our method, illustrated in Fig. 2, consists of two modules: a memory-efficient iterative patch selection module that selects a fixed number of salient patches, and a transformer-based patch aggregation module that combines patch embeddings into a global image embedding that is passed on to a classification head. Crucially, the transformer-based patch aggregation module contains a cross-attention layer that is used by the patch selection module in no-gradient mode to score patches. We discuss these modules in more detail next and provide pseudocode in Appendix A.
2.1 Iterative Patch Selection
Given infinite memory, one could use an attention module to score each patch and select the top M patches for aggregation. However, due to limited GPU memory, one cannot compute and store all patch embeddings in memory at the same time. We instead propose to iterate over the patches, I at a time, and autoregressively maintain a buffer of the top-M patch embeddings. In other words, let b_{t-1} ∈ R^{M×D} be the buffer of the M currently selected patch embeddings at time step t and e_t ∈ R^{I×D} the next I patch embeddings. We run the following for T iterations:

b_t = Top-M([b_{t-1}; e_t], a_t), (1)

where [b_{t-1}; e_t] ∈ R^{(M+I)×D} denotes concatenation and a_t ∈ R^{M+I} are the attention scores of the considered patches at iteration t, based on which the Top-M selection is made. These attention scores are obtained from the cross-attention transformer as described in Sect. 2.2. The output of IPS after T iterations is the set of M patches corresponding to the embeddings b_T. Note that both patch embedding and patch selection are executed in no-gradient and evaluation mode. The former entails that no gradients are computed or stored, which renders IPS runtime- and memory-efficient. The latter ensures deterministic patch selection behavior when using BatchNorm and Dropout.
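The buffer update above can be sketched in a few lines of PyTorch. This is an illustrative sketch under stated assumptions, not the paper's implementation: `embed` and `score` are placeholder callables (encoder and cross-attention scoring, respectively), and the buffer is initialized with the first M patches.

```python
import torch

def iterative_patch_selection(patches, embed, score, M, I):
    """Sketch of IPS: stream over an image's patches I at a time while
    maintaining a buffer of the top-M patch embeddings.
    `embed` maps (n, ...) -> (n, D); `score` maps (n, D) -> (n,) attention
    scores. Both names are placeholders, not the paper's API."""
    with torch.no_grad():                  # no-gradient mode: nothing stored for backprop
        buf_emb = embed(patches[:M])       # initialize buffer with the first M patches
        buf_idx = torch.arange(M)
        for start in range(M, len(patches), I):
            new = patches[start:start + I]
            cand_emb = torch.cat([buf_emb, embed(new)], dim=0)              # (M+I, D)
            cand_idx = torch.cat([buf_idx, torch.arange(start, start + len(new))])
            a = score(cand_emb)                                             # scores (M+I,)
            top = torch.topk(a, M).indices                                  # keep the top-M
            buf_emb, buf_idx = cand_emb[top], cand_idx[top]
    return buf_idx   # indices of selected patches; they are re-embedded later with gradients
```

Because only M + I embeddings exist at any time, memory is independent of the total patch count N.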
We introduce several data loading strategies that trade off memory and runtime efficiency during IPS. In eager loading, a batch of images is loaded onto the GPU and IPS is applied to each image in parallel; this is the fastest variant. In eager sequential loading, individual images are loaded onto the GPU, and patches are selected for one image at a time until a batch of M patches per image has been selected for training. This enables the processing of different sequence lengths without padding and reduces memory usage at the cost of a higher runtime. In contrast, lazy loading first loads a batch of images into CPU memory. Then, only the patches and corresponding embeddings pertinent to the current iteration are stored on the GPU; this decouples GPU memory usage from the image size, again at the cost of a higher runtime.
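The memory bound behind lazy loading can be made concrete with a small sketch. This is a simplified, hypothetical illustration (the scoring function is a placeholder): the full patch tensor stays on the CPU, and the device only ever holds the top-M buffer plus one chunk of I patches.

```python
import torch

def lazy_ips(cpu_patches, M, I, device="cpu"):
    """Sketch of lazy loading: patches live in CPU memory; only the current
    chunk and the top-M buffer are moved to `device`, so device memory is
    bounded by M + I embeddings regardless of image size. Illustrative only."""
    peak = 0
    with torch.no_grad():
        buf = cpu_patches[:M].to(device)
        for start in range(M, len(cpu_patches), I):
            chunk = cpu_patches[start:start + I].to(device)  # transfer only I patches
            cand = torch.cat([buf, chunk])
            peak = max(peak, len(cand))                      # at most M + I rows on device
            scores = cand.sum(dim=1)                         # placeholder scoring
            buf = cand[torch.topk(scores, M).indices]
    return buf, peak
```

The repeated host-to-device transfers are exactly the runtime cost the text mentions.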
2.2 Transformer-Based Patch Aggregation
After selecting M patches in the IPS stage, these patches are embedded again in gradient and training mode, and the patch aggregation module aggregates the resulting embeddings using a convex combination:

z = Σ_{m=1}^{M} a_m e_m. (2)

Each attention score a_m is the result of a function of the corresponding patch embedding e_m, a_m = f(e_m; θ), which can be learned by a neural network parameterized by θ. We observe that the weighted average in Eq. 2 constitutes a cross-attention layer by defining the scores a = (a_1, …, a_M) as

a = softmax(q K^T / √d),

where K = E W_K, with E ∈ R^{M×D} the matrix of patch embeddings, W_K ∈ R^{D×d}, and d being the projected dimension. Here, q ∈ R^{1×d} is a learnable query that is independent of the input. Furthermore, instead of aggregating the patch embeddings directly, we aggregate a linear projection of them, called values in typical transformer notation: V = E W_V, with W_V ∈ R^{D×d_v} and d_v being the value dimension. In practice, we use multi-head cross-attention and place it in a standard transformer (Vaswani et al., 2017), optionally adding a sinusoidal position encoding (details in Appendix B).
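The aggregation described here can be sketched as a minimal single-head module. This is an illustrative reading of the text, not the paper's code; dimension names D (embedding) and d (projection) are generic placeholders, and the value projection is given the same width as the keys for simplicity.

```python
import torch
import torch.nn as nn

class CrossAttentionPool(nn.Module):
    """Minimal single-head sketch: a learnable, input-independent query
    attends over M patch embeddings and returns one image embedding."""
    def __init__(self, D, d):
        super().__init__()
        self.q = nn.Parameter(torch.randn(1, d))   # learnable query
        self.Wk = nn.Linear(D, d, bias=False)      # key projection
        self.Wv = nn.Linear(D, d, bias=False)      # value projection
        self.scale = d ** -0.5

    def forward(self, E):                          # E: (M, D) selected patch embeddings
        K, V = self.Wk(E), self.Wv(E)              # (M, d)
        a = torch.softmax(self.q @ K.T * self.scale, dim=-1)  # (1, M) attention scores
        return a @ V, a                            # image embedding (1, d) and scores
```

The same scores `a` are what IPS reads out in no-gradient mode to rank patches.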
Note that there is a close connection between the patch selection and patch aggregation modules: they share the cross-attention layer. Although IPS runs in no-gradient mode and cannot train its parameters, by sharing them with the patch aggregation module the model can still learn to select patches that are relevant for the downstream task.
3 Related Work
The setting described in Sect. 2 constitutes a MIL problem (Maron and Lozano-Pérez, 1997), where an image is seen as a bag that contains multiple patch-level features called instances. The goal is then to predict the label of unseen bags given their instances. However, training labels are only available for bags, not for instances. The model must thus learn to attend to the patches that best represent the image for a given task. Eq. 2 provides such a mechanism and is known as a MIL pooling function. In particular, our cross-attention layer follows the weighted collective MIL assumption, which states that all instances contribute to the bag representation, albeit to varying degrees according to the attention scores (Foulds and Frank, 2010; Amores, 2013).
A MIL approach similar to ours is DeepMIL (Ilse et al., 2018), which uses a gated attention mechanism to compute attention scores for the aggregation of patches: a_m ∝ exp(w^T (tanh(V e_m) ⊙ sigm(U e_m))), where w ∈ R^L and V, U ∈ R^{L×D} are parameters and L corresponds to the number of hidden nodes. In contrast, our multi-head cross-attention layer provides more capacity by computing multiple intermediate bag representations. More importantly, it is a MIL pooling function that can be naturally integrated into a standard and field-proven transformer module.
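For comparison with the cross-attention pooling above, DeepMIL's gated attention can be sketched as follows. This is a minimal reading of Ilse et al. (2018) with placeholder dimensions, not their released implementation.

```python
import torch
import torch.nn as nn

class GatedAttention(nn.Module):
    """Sketch of DeepMIL-style gated attention: a tanh branch gated by a
    sigmoid branch produces one score per instance, normalized by softmax.
    L is the number of hidden nodes; dimensions are illustrative."""
    def __init__(self, D, L):
        super().__init__()
        self.V = nn.Linear(D, L, bias=False)
        self.U = nn.Linear(D, L, bias=False)
        self.w = nn.Linear(L, 1, bias=False)

    def forward(self, E):                                      # E: (M, D) instances
        h = torch.tanh(self.V(E)) * torch.sigmoid(self.U(E))   # gated hidden units (M, L)
        a = torch.softmax(self.w(h).T, dim=-1)                 # (1, M) attention scores
        return a @ E, a                                        # bag embedding (1, D)
```

Unlike the cross-attention layer, there is no query/key/value factorization, which is what limits it to a single bag representation per attention branch.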
TopMIL, a baseline we construct for comparison, selects the top M patches in a single iteration according to the maximum class logit of each patch. Selected patches are then re-embedded and classified, the class logits are averaged, and an activation function (e.g., softmax) is applied.
Another line of work seeks to reduce the memory consumption of processing high-resolution images by first identifying regions of interest at low resolution and then sampling patches from these regions at high resolution (Katharopoulos and Fleuret, 2019; Cordonnier et al., 2021; Kong and Henao, 2022). However, especially for gigapixel images, the lower-resolution image may either still be too large to encode or too small to identify informative regions. In contrast, IPS relies exclusively on high-resolution patches to best estimate the attention scores.
For very large images, patch-wise self-supervised learning is a promising way to reduce memory usage (Dehaene et al., 2020; Chen et al., 2022; Li et al., 2021). After pre-training, low-dimensional features are extracted and used to train a downstream network. Although the input size decreases considerably in this way, very long sequences may still require too much memory. Pre-training can be used orthogonally to IPS to process sequences of any length, as demonstrated in Sect. 4.2.
Our method is related to Sparse Mixtures of Experts, a conditional computation approach in which specialized subnetworks (experts) are responsible for processing specific input patterns (Jacobs et al., 1991). Specifically, Shazeer et al. (2017) and Riquelme et al. (2021) use a gating function to decide which expert is assigned to which part of the input. This gating function is non-differentiable but shares its output with a differentiable aggregation operation that combines the experts' outputs. At a high level, this is similar to how we share the cross-attention module between no-gradient-mode IPS and the gradient-mode patch aggregator. Attention Sampling (Katharopoulos and Fleuret, 2019) also reuses aggregation weights for patch sampling; IPS uses Top-M selection instead.
4 Experiments

We evaluate the performance and efficiency of our method on three challenging datasets from a variety of domains and training regimes: multi-class recognition of distant traffic signs in megapixel images, weakly-supervised classification of gigapixel whole-slide images (WSIs) using self-supervised representations, and multi-task learning of inter-patch relations on a synthetic megapixel MNIST benchmark. All training and baseline hyperparameters are provided in Appendix D.
To show that IPS selects salient patches, two other patch selection methods, Random Patch Selection (RPS) and Differentiable Patch Selection (DPS) (Cordonnier et al., 2021), are used for comparison, both employing the same transformer as IPS. Next, to verify that the cross-attention transformer is a strong MIL pooling operator, it is compared to DeepMIL (Ilse et al., 2018) and TopMIL as introduced in Sect. 3. Finally, we compare to a standard CNN applied to the original image and to a downscaled version to assess how resolution affects performance.
Each of the main experiments is performed five times with random parameter initializations, and the average classification performance and standard deviation on the test set are reported using either accuracy or AUC. We also report computational efficiency metrics, including maximum GPU memory usage (VRAM) and training runtime, for a batch size of 16 and various values of M. The runtime is calculated for a single forward and backward pass and is averaged over all iterations of one training epoch, excluding the first and last iterations. Both metrics are measured on a single NVIDIA A100 GPU in all experiments.
4.1 Traffic Sign Recognition
We first evaluate our method on the Swedish traffic signs dataset, which consists of 747 training and 684 test images at 1.3-megapixel resolution, as in Katharopoulos and Fleuret (2019). Each image shows a speed limit sign of 50, 70, or 80 km/h, or no sign. This problem requires high-resolution images to distinguish visually similar traffic signs, some of which appear very small due to their distance from the camera (Fig. 5, left). We resize each image to 1200×1600 and extract 192 non-overlapping patches of size 100×100. Due to the small number of data points, a ResNet-18 with ImageNet-1k weights is used as the encoder for all methods and then finetuned. For IPS, we use eager loading (hyperparameter settings in Appendix D).
Table 1 shows that the average accuracy of the IPS transformer is consistently high regardless of the value of M. Depending on M, memory usage is more than 10× lower and training time about half that of the CNN. As expected, RPS does not select salient patches due to the low signal-to-noise ratio in the data. DPS localizes the salient patches but performs worse than IPS on all metrics and is sensitive to the choice of the number of selected patches. The MIL baselines perform well in general, and TopMIL comes closest to IPS in accuracy. However, TopMIL requires knowledge about the number of informative patches, since the logits are averaged: already for moderately larger values of M, its performance is slightly degraded. In addition, TopMIL is faster but less memory-efficient than IPS because all patches are scored at once. DeepMIL, on the other hand, lacks a patch selection mechanism and hence includes all patches in the aggregation, which increases training time and VRAM. When all patches are aggregated, DeepMIL differs from the IPS transformer only in the aggregation function, yet the transformer performs 2.2 percentage points better on average. We also report results for DeepMIL+, which has a capacity similar to the transformer for a fairer comparison. It achieves an average accuracy of 97.7%, below the 98.4% of IPS. A CNN that takes the entire image as input and has no attention mechanism performs much worse and is inefficient. Downscaling the image by a factor of 3 further degrades the CNN's performance, as details of distant traffic signs are lost.
For a fair comparison against prior work, we use the same transformer and pre-trained model where applicable. This results in better numbers for DPS than originally reported (compare Table 1 with Cordonnier et al. (2021)). Nonetheless, our IPS transformer improves on that result by up to 1.4 percentage points while being more memory- and runtime-efficient.
4.2 Gigapixel WSI Classification
Next, we consider the CAMELYON16 dataset (Litjens et al., 2018), which consists of 270 training and 129 test WSIs of gigapixel resolution for the recognition of metastases in lymph node cells. The task is formulated as a weakly-supervised binary classification problem, i.e., only image-level labels about the presence of metastases are used. The WSIs come in multiple resolutions, but we attempt to solve this task at the highest available resolution (40× magnification). Since a large portion of each image may not contain any cells, Otsu's method (Otsu, 1979) is first applied to filter out the background. Only patches in which at least 1% of the area is foreground are included, a sensitive criterion that ensures that no abnormal cells are left out. Then, non-overlapping patches are extracted, giving a total of 28.6 million patches (avg.: 71.8k, std.: 37.9k), with the largest WSI having 251.8k patches (about 20% of the size of ImageNet-1k).
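The background-filtering step can be sketched without any imaging library. The following is a minimal, self-contained Otsu implementation plus the 1%-foreground rule described above; it assumes 8-bit grayscale input and that tissue is darker than the slide background, which is an assumption of this sketch rather than a detail given in the text.

```python
import numpy as np

def otsu_threshold(gray):
    """Minimal Otsu's method (Otsu, 1979): choose the threshold that
    maximizes the between-class variance of the grayscale histogram.
    Assumes 8-bit intensities in [0, 255]."""
    hist, _ = np.histogram(gray, bins=256, range=(0, 256))
    p = hist / hist.sum()
    best_t, best_var = 0, -1.0
    for t in range(1, 256):
        w0, w1 = p[:t].sum(), p[t:].sum()       # class probabilities
        if w0 == 0 or w1 == 0:
            continue
        mu0 = (np.arange(t) * p[:t]).sum() / w0
        mu1 = (np.arange(t, 256) * p[t:]).sum() / w1
        var = w0 * w1 * (mu0 - mu1) ** 2         # between-class variance
        if var > best_var:
            best_t, best_var = t, var
    return best_t

def keep_patch(patch_gray, thresh, min_fg=0.01):
    """Keep a patch if at least 1% of its pixels are foreground
    (here: darker than the Otsu threshold)."""
    return (patch_gray < thresh).mean() >= min_fg
```

In practice one would use an optimized routine (e.g., from an imaging library), but the acceptance rule is the same.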
Due to the large volume of patches, it is impractical to learn directly from pixels. Instead, we train BYOL (Grill et al., 2020), a state-of-the-art self-supervised learning algorithm, with a ResNet-50 encoder and extract features for all patches and WSIs (see Appendix D.6 for details). Each patch is thus represented as a 2,048-dimensional vector, which is further projected down to 512 features and finally processed by the respective aggregator. Note that even when processing patches at the low-dimensional feature level, aggregating thousands of patch features can become prohibitively expensive. For example, with DeepMIL, we can aggregate around 70k patch features before running out of memory on an A100 at a batch size of 16. Therefore, in each of the 5 runs, a random sample of up to 70k patch features is drawn for the baseline MIL methods, similar to Dehaene et al. (2020). In contrast, IPS can easily process all patches. Due to the varying number of patches per WSI, a batch cannot be built without expensive padding. We thus switch to eager sequential loading in IPS, i.e., slides are loaded sequentially and M patch features per image are selected and cached for subsequent mini-batch training. The remaining IPS settings are given in Appendix D.
We report AUC scores, as in the challenge (Bejnordi et al., 2017), in Table 2. The best performance is achieved by IPS, with an improvement of more than 3 percentage points over the runner-up baseline. We also outperform the CAMELYON16 challenge winner (Wang et al., 2016), who leveraged patch-level labels, and approach the result of the strongly-supervised state of the art (Bejnordi et al., 2017). Memory usage in IPS is 4–5 GB depending on M. However, IPS has a higher runtime because all patches are taken into account and eager sequential loading is used. For comparison, we also reduce the total number of patches before the IPS stage to 70k (using the same random seeds) and observe slightly better performance than DeepMIL. DPS was not considered since the downscaled image exceeded our hardware limits. We also tested CNNs trained on downscaled images of up to 4k×4k pixels, but they perform worse than a naive classifier.
Several recent works adopt a pre-trained feature extractor and report results for CAMELYON16 (see the bottom of Table 2). Given their respective setups, most consider all patches for aggregation but make compromises that can be avoided with IPS: a batch size of 1 (Wang et al., 2022; Li et al., 2021), training with features from the third ResNet-50 block, which effectively halves the size of the embeddings (Lu et al., 2021), or using smaller resolutions (Dehaene et al., 2020; Li et al., 2021). The work by Dehaene et al. (2020) achieves a slightly better result than our IPS transformer (avg. AUC: 98.1% vs. 98.7%) by using DeepMIL on a random subset of patches. We do not assume that this is due to our transformer, as the DeepMIL baseline reproduces their aggregation function yet performs significantly worse. Instead, it could be due to a different pre-training algorithm (MoCo v2), different pre-processing (a closed-source U-Net to filter out the background), or a lower magnification level (20×), which reduces the total number of patches (avg. 9.8k vs. 71.8k in ours).
4.3 Inter-Patch Reasoning
A single informative patch was sufficient to solve the previous tasks. In contrast, megapixel MNIST, introduced in Katharopoulos and Fleuret (2019), requires the recognition of multiple patches and their relations. The dataset consists of 5,000 training and 1,000 test images of size 1,500×1,500. We extract patches of size 50×50 without overlap (900 patches) or with 50% overlap (3,481 patches). In each image, 5 MNIST digits are placed, 3 of which belong to the same class and 2 to other classes. In addition, 50 noisy lines are added (see Fig. 5, right). The task is then to predict the majority class. We found that this problem can be solved well by most baselines, so we extend the setup with three more complex tasks: max (detect the maximum digit), top (identify the topmost digit), and multi-label (the presence/absence of all classes needs to be recognized). We frame this as a multi-task learning problem and use 4 learnable queries in the cross-attention layer, i.e., 4 task representations are learned. Similarly, 4 gated-attention layers are used for DeepMIL. All methods also utilize multiple classification heads and add a sinusoidal position encoding to the patch features. For the CNNs, a sinusoidal channel is concatenated with the input. A simplified ResNet-18 consisting of 2 residual blocks is used as the encoder for all methods. For IPS, we use eager loading (hyperparameter settings in Appendix D).
The results in Table 3 show that almost all methods obtain high accuracies on the majority and max tasks. However, only DeepMIL, DPS, and IPS are able to make appropriate use of the positional encoding to solve the top task. For the hardest task, multi-label, only the IPS transformer obtains accuracies above 90%, by using overlapping patches. For both TopMIL and DeepMIL, we were not able to use overlapping patches due to out-of-memory errors (on a single A100). Interestingly, IPS achieves high performance throughout when using overlapping patches, at a memory usage of only 2.2 GB, while maintaining a competitive runtime. Furthermore, the IPS transformer also achieves high performance with M = 5 patches, the minimum number of patches required to solve all tasks.
4.4 Ablation Study
Effect of image size
How does input size affect performance and efficiency in IPS? To address this question, we generate new megapixel MNIST datasets by scaling the image size from 1k to 10k pixels per side (400 to 40k patches) and attempt to solve the same tasks as in the main experiment. We also linearly scale the number of noise patterns from 33 to 333 to make the task increasingly difficult. Fig. 3 compares memory usage, runtime, and accuracy (task majority) for the different data loading options of our IPS transformer (non-overlapping patches, default batch size of 16), as well as the baseline CNN.
The CNN exhibits a high memory footprint that quickly exceeds the hardware limits, even with a reduced batch size of 4. In IPS with eager loading, the VRAM is dominated by the batch of inputs and can be reduced by sequential loading. In contrast, with lazy loading the memory usage is kept constant at 1.7 GB regardless of the input size. The runtime of eager loading is less than that of the CNN and almost approaches that of a CNN using a batch size of 4. Lazy loading is faster than eager sequential loading, suggesting that the latter should only be used when different sequence lengths need to be processed. In terms of performance, the IPS transformer achieves high accuracy for up to 8k pixels and decreases only slightly starting from 9k pixels, likely due to the declining signal-to-noise ratio (at 10k pixels, only 5 out of 40k patches are relevant). In comparison, the CNN's performance drops rapidly towards that of a naive classifier starting from 3k pixels.
Effect of M and I
Fig. 4 shows that VRAM and runtime tend to increase with M, as more patches need to be processed in training mode. With increasing values of I, higher VRAM but lower runtime can be observed, since IPS runs for fewer iterations. However, there are more subtle cases. Along the diagonal in Fig. 4 (left), peak VRAM is attained in the non-IPS part and thus remains constant across different values of I; in this regime, choosing a larger I results in faster runtime without compromising memory usage. In contrast, in the upper-left triangle of the graph, the VRAM is dominated by IPS in no-gradient mode, where the buffer of M + I embeddings determines memory usage. Interestingly, when processing all remaining patches in a single iteration, VRAM can be reduced by increasing M up to 200, because the number of patches processed in that single iteration shrinks accordingly.
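The interplay of M and I can be summarized with back-of-envelope arithmetic. This sketch assumes the buffer is initialized with the first M patches, so the remaining N − M patches are consumed I at a time; the function name and formula are our illustration, not notation from the paper.

```python
import math

def ips_cost(N, M, I):
    """Rough iteration count and per-step buffer size for IPS:
    N patches in total, top-M buffer, I new patches per iteration."""
    T = math.ceil((N - M) / I)   # iterations needed to see every patch once
    buffer = M + I               # embeddings held in no-gradient mode per step
    return T, buffer
```

For example, a 10k×10k MNIST image (40k patches) with M = 100 and I = 500 would need 80 iterations with a 600-embedding buffer: larger I means fewer iterations (faster) but a larger buffer (more VRAM), matching the trend in Fig. 4.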
Effect of patch size
In Table 4, we investigate the effect of patch size (25–400 px per side) on the performance and efficiency of traffic sign recognition. Accuracy is similar across all patch sizes and decreases only slightly for very small patches of 25 px. As the patch size increases, the runtime of IPS tends to decrease, since fewer iterations are run, while the time to embed and aggregate the selected patches (non-IPS time) increases, because more pixels need to be processed in training mode.
In Fig. 5, we superimpose attention scores over images, which shows that salient patches are attended to while the background is neglected. In low signal-to-noise settings, such visualizations can support humans in localizing tiny regions of interest that might otherwise be missed (e.g., Fig. 1). However, attention scores alone may not be satisfactory when multiple objects of different classes are present. For example, attention scores in megapixel MNIST indicate relevant patches, but not which classes those patches belong to. We propose a simple fix to obtain patch-level class scores for C classes. We add a new classification head g with parameters φ that maps each patch embedding e_m to class scores s_m = g(e_m). For training, these scores are combined into an image-level prediction using the attention scores, p = σ(Σ_m a_m s_m), where σ is an activation function. Crucially, we stop gradients from flowing back through a_m and e_m, such that g is trained as a separate read-out head jointly with the base model. We demonstrate this for megapixel MNIST (C = 10), where the patch-level classifier is learned from the multi-label task using only image-level labels. In Fig. 5 (right), we visualize the inferred per-patch class label, argmax_c s_{m,c}, using colors, and the attention score a_m using the alpha channel.
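The stop-gradient detail of the read-out head can be sketched as follows. This is our reading of the paragraph above, with `head` as a hypothetical classifier g (e.g., `nn.Linear(D, C)`) and sigmoid as the activation for the multi-label case; it is not the paper's code.

```python
import torch
import torch.nn as nn

def patch_class_probs(E, a, head):
    """Sketch of the read-out head: per-patch class scores s_m = g(e_m) are
    combined into an image-level prediction with the attention scores, while
    gradients are stopped through both E and a (detach), so only `head` is
    trained by this loss, jointly with but separately from the base model."""
    s = head(E.detach())               # (M, C) patch-level class scores; no grad to encoder
    p = torch.sigmoid(a.detach() @ s)  # (1, C) image-level multi-label prediction
    return s, p
```

Training `head` with only image-level labels yields patch-level class scores for free at inference time.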
Reflecting on the experiments, the IPS transformer performs well across various domains and training regimes and exhibits several useful capabilities: constant memory usage, learning with arbitrarily long sequences of varying lengths, multi-task learning, and modeling of inter-patch relations. However, there are also limitations. For example, while GPU memory can be kept constant, runtime cannot. In Sect. 4.4, we gave some intuition about the influence of the hyperparameters M and I in this regard. Furthermore, performance can degrade if the signal-to-noise ratio becomes too small, as shown in our MNIST scaling experiment; in practice, this might be mitigated by pre-training. Nonetheless, we believe that our main contribution, maintaining a low memory footprint for high-resolution image processing, can foster interesting applications that would otherwise be prohibitively expensive. Examples include learning from multiple medical scans for diagnosis/prognosis and prediction tasks on satellite images, videos, graphs, point clouds, or whole-genome sequences. Technically, it would be interesting to extend IPS to dense prediction tasks such as segmentation or depth estimation.
CAMELYON16 is a public benchmark dataset consisting of histopathology slides obtained from tissue of human subjects. The dataset has been approved by an ethics committee, see Litjens et al. (2018) for more information. For the Swedish traffic signs dataset, we searched the data for humans and excluded one image where a face of a pedestrian was clearly visible. Megapixel MNIST builds on the widely used MNIST dataset, which contains digits that are ethically uncritical.
In this paper, we exclusively conduct experiments and demonstrate potential applications for human benefit, such as early traffic sign detection, which could be useful for advanced driver assistance systems, or abnormality detection in medical images, which could assist physicians. However, more care is needed before any of these systems can be deployed in real-world applications, including an assessment of biases and fairness and monitoring by medical professionals where relevant. Furthermore, the attention scores are currently not usable as a detection method. Beyond the showcased applications, our method could be misused, just like any other visual object detection system, but now also on gigapixel images. We do not explore such applications in our paper and do not plan to do so in future work.
Our method enables the training of deep neural networks on arbitrarily large inputs, even on low-cost consumer hardware with limited VRAM, which makes AI development more accessible. All training and fine-tuning experiments on downstream tasks can be efficiently run on a single A100 GPU, with the following total training runtimes: traffic sign recognition: 45 min (150 epochs), megapixel MNIST: 5 h (150 epochs), CAMELYON16: 2 h (50 epochs).
Sect. 2, as well as the pseudocode provided in Appendix A enable the reproduction of our method. Our metrics and baselines lean on existing work and are described in Sect. 4. Necessary steps to reproduce our experiments are described in respective subsections of Sect. 4. We provide further details about all network architectures, hyperparameters for our method and the baselines, as well as pre-processing steps in Appendix D. The tools used to measure the training runtime and peak GPU memory are listed in Appendix D.7.
We would like to thank Joan Puigcerver for suggesting relevant related work, Romeo Sommerfeld for fruitful discussions on interpretability, Noel Danz for useful discussions about applications, as well as Jonathan Kreidler, Konrad Lux, Selena Braune and Florian Sold for providing and managing the compute infrastructure. This project has received funding from the German Federal Ministry for Economic Affairs and Climate Action (BMWK) in the project DAKI-FWS (reference number: 01MK21009E).
- Multiple instance classification: review, taxonomy and comparative study. Artificial intelligence 201, pp. 81–105. Cited by: §3.
- Learned queries for efficient local attention. In , pp. 10841–10852. Cited by: Appendix C.
Diagnostic assessment of deep learning algorithms for detection of lymph node metastases in women with breast cancer. Jama 318 (22), pp. 2199–2210. Cited by: §4.2, Table 2.
- Clinical-grade computational pathology using weakly supervised deep learning on whole slide images. Nature medicine 25 (8), pp. 1301–1309. Cited by: §3.
- Terabyte-scale deep multiple instance learning for classification and localization in pathology. arXiv preprint arXiv:1805.06983. Cited by: §3.
- End-to-end object detection with transformers. In European conference on computer vision, pp. 213–229. Cited by: Appendix C.
- Crossvit: cross-attention multi-scale vision transformer for image classification. In Proceedings of the IEEE/CVF international conference on computer vision, pp. 357–366. Cited by: Appendix C.
- Scaling vision transformers to gigapixel images via hierarchical self-supervised learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 16144–16155. Cited by: §3.
- Differentiable patch selection for image recognition. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2351–2360. Cited by: §D.5, §3, §4, §4.1.
- Self-supervision closes the gap between weak and strong supervision in histology. arXiv preprint arXiv:2012.03583. Cited by: §D.5, §D.6, §3, §4.2, §4.2, Table 2.
A review of multi-instance learning assumptions.
The knowledge engineering review25 (1), pp. 1–25. Cited by: §3.
- Bootstrap your own latent-a new approach to self-supervised learning. Advances in neural information processing systems 33, pp. 21271–21284. Cited by: §D.6, §4.2.
- Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778. Cited by: §1.
- Attention-based deep multiple instance learning. In International conference on machine learning, pp. 2127–2136. Cited by: §D.5, Appendix E, §3, §4.
- Adaptive mixtures of local experts. Neural computation 3 (1), pp. 79–87. Cited by: §3.
- Perceiver io: a general architecture for structured inputs & outputs. arXiv preprint arXiv:2107.14795. Cited by: Appendix C.
- Processing megapixel images with deep attention-sampling models. In International Conference on Machine Learning, pp. 3282–3291. Cited by: §D.6, §1, §3, §3, §4.1, §4.3.
- Efficient classification of very large images with tiny objects. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2384–2394. Cited by: §3.
- ImageNet classification with deep convolutional neural networks. Advances in Neural Information Processing Systems 25. Cited by: §1.
- Weakly supervised multiple instance learning histopathological tumor segmentation. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 470–479. Cited by: §3.
- Dual-stream multiple instance learning network for whole slide image classification with self-supervised contrastive learning. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp. 14318–14328. Cited by: §3, §4.2, Table 2.
- Revisit batch normalization: new understanding and refinement via composition optimization. In The 22nd International Conference on Artificial Intelligence and Statistics, pp. 3254–3263. Cited by: §1.
- 1399 H&E-stained sentinel lymph node sections of breast cancer patients: the CAMELYON dataset. GigaScience 7 (6), pp. giy065. Cited by: §4.2, Datasets.
- Decoupled weight decay regularization. arXiv preprint arXiv:1711.05101. Cited by: §D.1.
- Data-efficient and weakly supervised computational pathology on whole-slide images. Nature biomedical engineering 5 (6), pp. 555–570. Cited by: §4.2, Table 2.
- A framework for multiple-instance learning. Advances in neural information processing systems 10. Cited by: §3.
- A threshold selection method from gray-level histograms. IEEE transactions on systems, man, and cybernetics 9 (1), pp. 62–66. Cited by: §4.2.
- Streaming convolutional neural networks for end-to-end learning with multi-megapixel images. IEEE transactions on pattern analysis and machine intelligence. Cited by: Appendix C.
- Scaling vision with sparse mixture of experts. Advances in Neural Information Processing Systems 34, pp. 8583–8595. Cited by: §3.
- ImageNet Large Scale Visual Recognition Challenge. International Journal of Computer Vision (IJCV) 115 (3), pp. 211–252. Cited by: §1.
- Tokenlearner: adaptive space-time tokenization for videos. Advances in Neural Information Processing Systems 34, pp. 12786–12797. Cited by: Appendix C.
- The effect of image resolution on deep learning in radiography. Radiology. Artificial intelligence 2 (1). Cited by: §1.
- Long-range, high-resolution camera optical design for assisted and autonomous driving. In Photonics, Vol. 6, pp. 73. Cited by: §1.
- Outrageously large neural networks: the sparsely-gated mixture-of-experts layer. arXiv preprint arXiv:1701.06538. Cited by: §3.
- Energy and policy considerations for deep learning in nlp. arXiv preprint arXiv:1906.02243. Cited by: §1.
- Attention is all you need. Advances in neural information processing systems 30. Cited by: Appendix C, §D.4, §2.2.
- Deep learning for identifying metastatic breast cancer. arXiv preprint arXiv:1606.05718. Cited by: §4.2, Table 2.
- Transformer-based unsupervised contrastive learning for histopathological image classification. Medical Image Analysis 81, pp. 102559. Cited by: §4.2, Table 2.
- Deep multiple instance learning for automatic detection of diabetic retinopathy in retinal images. IET Image Processing 12 (4), pp. 563–571. Cited by: §3.
Appendix A Pseudocode
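As a high-level illustration of IPS, the selection loop can be sketched in NumPy. Here, `patch_scores_fn` is a hypothetical stand-in for the cross-attention scorer, `M` is the memory buffer size, and `B` is the iteration batch size; this is a sketch of the idea, not the paper's implementation:

```python
import numpy as np

def ips_select(patch_scores_fn, patches, M, B):
    # Keep only the M highest-scoring patches while holding at most
    # M + B patches in memory at any time (no-gradient selection stage).
    buf = patches[:M]
    for i in range(M, len(patches), B):
        cand = np.concatenate([buf, patches[i:i + B]])
        scores = patch_scores_fn(cand)       # one saliency score per patch
        buf = cand[np.argsort(scores)[-M:]]  # retain the current top-M
    return buf
```

In the full method, the selected buffer is then passed through the encoder and the cross-attention aggregator a second time with gradients enabled.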
Appendix B Multi-head cross-attention and transformer
In practice, we use multi-head cross-attention (MCA), which computes multiple intermediate representations corresponding to $H$ heads and then aggregates them:
$$\mathrm{MCA}(q, X) = \mathrm{concat}(o_1, \dots, o_H)\, W^O,$$
with $o_h = A_h V_h$, $V_h = X W^V_h$, and $A_h = \mathrm{softmax}\big(q W^Q_h (X W^K_h)^\top / \sqrt{d_h}\big)$. Note that each representation $o_h$ is computed with head-specific attention scores $A_h$. For patch selection, these scores are averaged over heads to obtain a scalar score per patch. The transformer module is composed of an MCA layer, a multi-layer perceptron (MLP) and Layernorm (LN) and is formulated as:
$$z' = \mathrm{MCA}\big(q, \mathrm{LN}(X + P)\big), \qquad z = z' + \mathrm{MLP}(\mathrm{LN}(z')),$$
where $P$ is an optional sinusoidal position encoding and the MLP consists of two layers with ReLU non-linearity.
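As an illustration, MCA with a single learnable query can be sketched in NumPy. The weight matrices and the function signature below are illustrative assumptions, not the paper's implementation:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mca(q, X, Wq, Wk, Wv, Wo, H):
    # q: (1, d) learnable query; X: (N, d) patch embeddings.
    d = q.shape[-1]
    dh = d // H
    Q = (q @ Wq).reshape(1, H, dh)
    K = (X @ Wk).reshape(-1, H, dh)
    V = (X @ Wv).reshape(-1, H, dh)
    outs, attn = [], []
    for h in range(H):
        A = softmax(Q[:, h] @ K[:, h].T / np.sqrt(dh))  # (1, N) head scores
        attn.append(A)
        outs.append(A @ V[:, h])                        # (1, dh) per head
    patch_scores = np.mean(attn, axis=0)                # averaged over heads
    return np.concatenate(outs, axis=-1) @ Wo, patch_scores
```

The averaged `patch_scores` are exactly the quantity used for patch selection: one scalar per patch, normalized over all patches.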
Appendix C Extended related work
Cross-attention is an essential component of our method because it links the patch selection and patch aggregation modules. Therefore, we want to shed more light on its use in the literature. Cross-attention is a mechanism in which the queries stem from a different input than the one used for the keys and values. It has previously been employed in various transformer-based architectures and applications. In the original transformer, it is used in an autoregressive decoder for machine translation (Vaswani et al., 2017): keys and values are computed from the encoder output, whereas queries are computed from the already translated sequence. CrossViT applies cross-attention to efficiently fuse patch encodings corresponding to varying patch sizes (Chen et al., 2021). Cross-attention can also be applied to object detection. For example, DETR defines learnable object tokens that cross-attend to patch embeddings, where the number of tokens reflects the maximum number of objects that can be detected (Carion et al., 2020). More closely related to our work is Query-and-Attend (Arar et al., 2022), which uses learnable queries as a substitute for local convolutions and the more computationally expensive self-attention mechanism. Another interesting architecture is Perceiver IO, which applies cross-attention in the encoder for dimensionality reduction and in the decoder for learning task-specific outputs (Jaegle et al., 2021). TokenLearner (Ryoo et al., 2021) also reduces the number of tokens in the encoder for image and video recognition tasks, but without using a cross-attention module. In our work, cross-attention serves two purposes. First, we treat it as a MIL pooling function that is used in the rear part of the network to combine information from multiple patches. Second, we take advantage of the attention scores to select a fixed number of the most salient patches, which makes memory consumption independent of the input size and thus allows us to process arbitrarily large images.
Another memory-efficient approach is the StreamingCNN (Pinckaers et al., 2020), which employs gradient checkpointing to process large images sequentially through a CNN and has been demonstrated on images of up to 8k×8k pixels. Due to the use of checkpoints, multiple forward and backward passes are required. Furthermore, backpropagation is performed for all tiles, which further slows down training. In IPS, all patches run through the network only once in fast no-gradient mode; only the selected patches then require a second forward and backward pass.
Appendix D Experiment details
d.1 Training settings
All models are trained for 150 epochs (megapixel MNIST, traffic signs) or 50 epochs (CAMELYON16) on the respective training sets, and results are reported on the test set after the last epoch, without early stopping. The batch size is 16, and AdamW with a weight decay of 0.1 is used as the optimizer (Loshchilov and Hutter, 2017). After a linear warm-up period of 10 epochs, the learning rate is set to 0.0003 when fine-tuning pre-trained networks and 0.001 when training from scratch. The learning rate is then decayed by a factor of 1,000 over the course of training using a cosine schedule.
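The schedule described above (linear warm-up, then cosine decay by a factor of 1,000) can be sketched as a small helper; the function name and per-epoch granularity are illustrative assumptions:

```python
import math

def lr_at(epoch, total=150, warmup=10, base_lr=3e-4, decay_factor=1000):
    # Linear warm-up over the first `warmup` epochs ...
    if epoch < warmup:
        return base_lr * (epoch + 1) / warmup
    # ... then cosine decay from base_lr down to base_lr / decay_factor.
    t = (epoch - warmup) / max(1, total - warmup)
    min_lr = base_lr / decay_factor
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * min(t, 1.0)))
```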
d.2 Architecture components
Table 5 provides a high-level overview of the various components used by the baselines and our method. Note that all methods use the same encoder for a fair comparison.
d.3 Encoders
The encoders differ depending on the dataset and are listed in Table 6. After applying each encoder, global average pooling is used to obtain a feature vector per patch. The CNN baselines also use these encoders, but apply them to the full image instead of individual patches.
| Dataset | Encoder | Pre-training | Feature dim. |
| CAMELYON16 | ResNet-50 + projector | BYOL | 512 |
| Megapixel MNIST | ResNet-18 (2 blocks) | None | 128 |
d.4 Transformer hyperparameters
The hyperparameters of the cross-attention transformer follow default values (Vaswani et al., 2017) and are listed in Table 7. We slightly deviate from these settings for very small values of M, in particular in traffic signs and in megapixel MNIST, where attention dropout is set to 0.
d.5 Baseline hyperparameters
In DPS, we mostly follow the setup of Cordonnier et al. (2021). We downsample the original image by a factor of 3 to obtain a low-resolution image (MNIST: 500×500, traffic signs: 400×533). As the scorer, we use the first two blocks of a ResNet-18, without pre-training for megapixel MNIST and with ImageNet-1k pre-training for traffic signs, followed by a Conv layer (kernel: 3, padding: 0) with a single output channel and max pooling (kernel: 2, stride: 2). The resulting score maps have shape 30×30 for megapixel MNIST and 24×32 for traffic signs. Thus, we consider 900 patches for megapixel MNIST and 768 patches for traffic signs. We use multiple noise samples for perturbed top-K, employ gradient norm clipping, and use a perturbation noise σ that linearly decays over the course of training. For megapixel MNIST, σ-decay was deactivated in some settings as it resulted in degraded performance towards the end of training.
For DeepMIL, we follow the attention settings of Ilse et al. (2018), whereas DeepMIL+ uses a larger configuration. The classification heads of the MIL baselines differ slightly between datasets and are listed in Table 8. In traffic sign recognition, a single output layer is utilized following Ilse et al. (2018). More capacity is added for DeepMIL+: fc (2,048) + ReLU, fc (512) + ReLU, fc (# classes). For megapixel MNIST, we obtained better results by adding an additional layer, which could be due to the multi-task setup. For CAMELYON16, we followed the settings of Dehaene et al. (2020). In DPS and our method, we use the same setup for all datasets: fc (# classes). The patch-level classification head introduced in Sect. 4.5 uses the same structure as the MIL methods in megapixel MNIST (Table 8, middle column).
d.6 Pre-processing of CAMELYON16
Since a large fraction of a WSI may not contain tissue, we first read each image in chunks, create a histogram of pixel intensities, and apply Otsu's method to the histogram to obtain a threshold per image. We then extract non-overlapping patches of size 256×256 and retain patches in which at least 1% of the area is foreground. BYOL is then trained with a ResNet-50 for 500 epochs using the training settings of Grill et al. (2020) and the augmentations listed in Table 9, which mostly follow those used in Dehaene et al. (2020). In each pre-training epoch, 1,000 patches per slide are randomly sampled and shuffled. We train on 4 A100 GPUs with a batch size of 256 using mixed precision. After pre-training, each patch is center-cropped to 224×224 px, and the 2,048-dimensional features produced by the ResNet-50 encoder are extracted. The features are then normalized and stored in an HDF5 file. For the downstream task, we train a projector (fc (512), BatchNorm, ReLU) whose output is passed on to the aggregation module. For the pre-processing of the megapixel MNIST and traffic sign datasets, we adapt the code of Katharopoulos and Fleuret (2019).
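Otsu's method on a precomputed intensity histogram can be sketched as follows; this is a self-contained NumPy version for illustration, and the actual pipeline may use a library implementation instead:

```python
import numpy as np

def otsu_threshold(hist):
    # Pick the gray level that maximizes the between-class variance
    # of the background/foreground split.
    p = hist.astype(float) / hist.sum()
    bins = np.arange(len(hist))
    w0 = np.cumsum(p)           # class-0 probability up to each level
    w1 = 1.0 - w0
    m = np.cumsum(p * bins)     # cumulative first moment
    mu_total = m[-1]
    with np.errstate(divide="ignore", invalid="ignore"):
        var_between = (mu_total * w0 - m) ** 2 / (w0 * w1)
    var_between[~np.isfinite(var_between)] = 0.0
    return int(np.argmax(var_between))
```

Each patch can then be kept or discarded based on the fraction of its pixels falling on the tissue side of the threshold.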
d.7 Computational efficiency
We measure the training runtime for a full forward and backward pass without data loading for a batch size of 16, and report the average of all but the first and last iterations of one epoch. For time recording, we utilize torch.cuda.Event, which can be used as follows to measure the elapsed time for a single iteration:
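A minimal sketch of such a timing helper, assuming PyTorch with a CUDA device (the function name is illustrative):

```python
import torch

def time_iteration(step_fn):
    # Record CUDA events around one training step; elapsed_time returns
    # milliseconds once the device has been synchronized.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    step_fn()                 # full forward + backward pass
    end.record()
    torch.cuda.synchronize()  # wait until all queued kernels have finished
    return start.elapsed_time(end)
```

Event-based timing avoids the pitfall of measuring only kernel launch time, since CUDA calls are asynchronous from the host's perspective.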
After running an experiment for one epoch, the following code can be used to measure the peak GPU memory usage:
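Peak GPU memory can be read from PyTorch's allocator statistics; a sketch, with an illustrative helper name:

```python
import torch

def peak_memory_gib(device=None):
    # max_memory_allocated reports the allocator's high-water mark since
    # the last reset_peak_memory_stats() call (e.g., the start of the epoch).
    return torch.cuda.max_memory_allocated(device) / 2**30
```

Calling `torch.cuda.reset_peak_memory_stats()` before the epoch ensures the high-water mark covers exactly that epoch.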
Appendix E IPS with DeepMIL
We want to explore how IPS performs with an aggregation module other than ours. Like our cross-attention module, the gated attention mechanism used in DeepMIL (Ilse et al., 2018) also assigns an attention score to each patch. We thus train IPS jointly with DeepMIL on megapixel MNIST (multi-task) and report results for each of 5 runs on the majority task for different values of the memory buffer size M in Table 10.
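For reference, the gated attention pooling of Ilse et al. (2018) can be sketched in NumPy as follows; the weight matrices `V`, `U`, and `w` are placeholders for learned parameters:

```python
import numpy as np

def gated_attention_pool(feats, V, U, w):
    # feats: (N, d) patch features. A tanh/sigmoid gate produces one logit
    # per patch; softmax-normalized scores weight the feature average.
    gate = np.tanh(feats @ V) * (1.0 / (1.0 + np.exp(-(feats @ U))))
    logits = gate @ w
    a = np.exp(logits - logits.max())
    a /= a.sum()
    return a @ feats, a   # bag embedding and per-patch attention scores
```

The per-patch scores `a` play the same role as the averaged cross-attention scores in IPS, which is what makes the two aggregators interchangeable for patch selection.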
| Accuracy per run (sorted) |
For the smallest memory buffer M, one can observe that the model was able to localize the digits in only one out of five runs. When increasing the memory buffer, only one run failed, and for the largest M all runs were successful. For IPS + transformer, all runs achieved over 90% accuracy, even for the smallest M. The average accuracy of IPS + DeepMIL at the largest buffer size is on par with that of IPS + transformer. This shows that IPS can be used with different attention-based aggregation modules; however, IPS + transformer is more robust for small values of M.