
Iterative Patch Selection for High-Resolution Image Recognition

High-resolution images are prevalent in various applications, such as autonomous driving and computer-aided diagnosis. However, training neural networks on such images is computationally challenging and easily leads to out-of-memory errors even on modern GPUs. We propose a simple method, Iterative Patch Selection (IPS), which decouples the memory usage from the input size and thus enables the processing of arbitrarily large images under tight hardware constraints. IPS achieves this by selecting only the most salient patches, which are then aggregated into a global representation for image recognition. For both patch selection and aggregation, a cross-attention based transformer is introduced, which exhibits a close connection to Multiple Instance Learning. Our method demonstrates strong performance and has wide applicability across different domains, training regimes and image sizes while using minimal accelerator memory. For example, we are able to finetune our model on whole-slide images consisting of up to 250k patches (>16 gigapixels) with only 5 GB of GPU VRAM at a batch size of 16.



1 Introduction

Figure 1: IPS selects salient patches for high-resolution image recognition. In the top right inset, green contours represent ground truth cancerous cells and the white overlays indicate patch scores from our model. Example patches are shown in the bottom right.

Image recognition has made great strides in recent years, spawning landmark architectures such as AlexNet (Krizhevsky et al., 2012) or ResNet (He et al., 2016). These networks are typically designed and optimized for datasets like ImageNet (Russakovsky et al., 2015), which consist of natural images well below one megapixel (for instance, a 256×256 image corresponds to only 0.06 megapixels). In contrast, real-world applications often rely on high-resolution images that reveal detailed information about an object of interest. For example, in self-driving cars, megapixel images are beneficial to recognize distant traffic signs far in advance and react in time (Sahin, 2019). In medical imaging, a pathology diagnosis system has to process gigapixel microscope slides to recognize cancer cells, as illustrated in Fig. 1.

Training neural networks on high-resolution images is challenging and can lead to out-of-memory errors even on dedicated high-performance hardware. Although downsizing the image can fix this problem, details critical for recognition may be lost in the process (Sabottke and Spieler, 2020; Katharopoulos and Fleuret, 2019). Reducing the batch size is another common approach to decrease memory usage, but it does not scale to arbitrarily large inputs and may lead to instabilities in networks involving batch normalization (Lian and Liu, 2019). On the other hand, distributed learning across multiple devices increases resources but is more costly and incurs higher energy consumption (Strubell et al., 2019).

We propose Iterative Patch Selection (IPS), a simple patch-based approach that decouples the consumed memory from the input size and thus enables the efficient processing of high-resolution images without running out of memory. IPS works in two steps: First, the most salient patches of an image are identified in no-gradient mode. Then, only selected patches are aggregated to train the network. We find that the attention scores of a cross-attention based transformer link both of these steps, and have a close connection to Multiple Instance Learning (MIL).

In the experiments, we demonstrate strong performance across three very different domains and training regimes: traffic sign recognition on megapixel images, multi-task classification on synthetic megapixel MNIST digits, and using self-supervised pre-training together with our method for memory-efficient learning on the gigapixel CAMELYON16 benchmark. Furthermore, our method exhibits a significantly lower memory consumption compared to various baselines. For example, when scaling megapixel MNIST images from 1k to 10k pixels per side at a batch size of 16, we can keep peak memory usage at a constant 1.7 GB while maintaining high accuracy, in contrast to a comparable CNN, which already consumes 24.6 GB at a resolution of 2k×2k. In an ablation study, we further analyze and provide insights into the key factors driving computational efficiency in IPS. Finally, we visualize exemplary attention distributions and present an approach to obtain patch-level class probabilities in a weakly-supervised multi-label classification setting.

2 Methods

We regard an image as a set of $N$ patches. Each patch is embedded independently by a shared encoder network, resulting in $D$-dimensional representations, $X \in \mathbb{R}^{N \times D}$. Given the embeddings, we select the most salient patches and aggregate the information across these patches for the classification task. Thus our method, illustrated in Fig. 2, consists of two modules: a memory-efficient iterative patch selection module that selects a fixed number of $M$ salient patches and a transformer-based patch aggregation module that combines patch embeddings to compute a global image embedding that is passed on to a classification head. Crucially, the transformer-based patch aggregation module consists of a cross-attention layer that is used by the patch selection module in no-gradient mode in order to score patches. We discuss these in more detail next and provide pseudocode in Appendix A.

2.1 Iterative Patch Selection

Given infinite memory, one could use an attention module to score each patch and select the top $M$ patches for aggregation. However, due to limited GPU memory, one cannot compute and store all patch embeddings in memory at the same time. We instead propose to iterate over patches, $I$ at a time, and autoregressively maintain a set of top $M$ patch embeddings. In other words, say $P_M^t$ is a buffer of $M$ patch embeddings at time step $t$ and $P_I^{t+1}$ are the next $I$ patch embeddings in the autoregressive update step. We run the following for $T$ iterations:

$$P_M^{t+1} = \text{Top-M}\{P_M^t \cup P_I^{t+1} \mid a^{t+1}\}, \qquad (1)$$

where $T = \lceil (N - M)/I \rceil$ for an image with $N$ patches, $P_M^0$ is the initial buffer holding the embeddings of the first $M$ patches, and $a^{t+1} \in \mathbb{R}^{M+I}$ are attention scores of considered patches at iteration $t+1$, based on which the selection in Top-M is made. These attention scores are obtained from the cross-attention transformer as described in Sect. 2.2. The output of IPS after $T$ iterations is a set of $M$ patches corresponding to embeddings $P_M^T$. Note that both patch embedding and patch selection are executed in no-gradient and evaluation mode. The former entails that no gradients are computed and stored, which renders IPS runtime and memory-efficient. The latter ensures deterministic patch selection behavior when using BatchNorm and Dropout.
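The selection loop described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the pseudocode of Appendix A: `embed` and `score` are hypothetical stand-ins for the encoder and the cross-attention scoring (both of which would run in no-gradient and evaluation mode), and the argument names `M` (buffer size) and `I` (patches processed per iteration) are chosen here for illustration.

```python
import numpy as np

def iterative_patch_selection(patches, embed, score, M, I):
    """Return indices of the M highest-scoring patches, embedding at most I new patches per step."""
    N = len(patches)
    buf_idx = np.arange(min(M, N))            # initial buffer: the first M patches
    buf_emb = embed(patches[buf_idx])         # shape (M, D)
    for start in range(M, N, I):              # ceil((N - M) / I) iterations
        new_idx = np.arange(start, min(start + I, N))
        cand_idx = np.concatenate([buf_idx, new_idx])
        cand_emb = np.concatenate([buf_emb, embed(patches[new_idx])])
        scores = score(cand_emb)              # one score per candidate, at most M + I values
        keep = np.argsort(-scores, kind="stable")[:M]
        buf_idx, buf_emb = cand_idx[keep], cand_emb[keep]
    return buf_idx

# Toy usage: the "encoder" flattens a patch and the "score" is its mean intensity.
rng = np.random.default_rng(0)
patches = rng.random((100, 8, 8))
embed = lambda p: p.reshape(len(p), -1)       # hypothetical encoder
score = lambda e: e.mean(axis=1)              # hypothetical attention score
selected = iterative_patch_selection(patches, embed, score, M=5, I=16)
```

With a scoring function that rates each patch independently, as in this toy example, the buffer ends up holding the same patches as a global top-M over all patches, while only `M + I` embeddings are held at any time.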

Figure 2: Left: Illustration of Iterative Patch Selection running for iterations with , and . Right: Cross-attention transformer module that aggregates selected patches.

Data loading

We introduce several data loading strategies that trade off memory and runtime efficiency during IPS. In eager loading, a batch of images is loaded onto the GPU and IPS is applied to each image in parallel; this is the fastest variant. In eager sequential loading, individual images are loaded onto the GPU and thus patches are selected for one image at a time until a batch of $M$ patches per image is selected for training. This enables the processing of different sequence lengths without padding and reduces memory usage at the cost of a higher runtime. In contrast, lazy loading first loads a batch of images onto CPU memory. Then, only patches and corresponding embeddings pertinent to the current iteration are stored on the GPU; this decouples GPU memory usage from the image size, again at the cost of a higher runtime.

2.2 Transformer-Based Patch Aggregation

After selecting $M$ patches in the IPS stage, these patches are embedded again in gradient and training mode, and the patch aggregation module can aggregate the resulting embeddings $X^*_1, \dots, X^*_M \in \mathbb{R}^D$ using a convex combination:

$$z = \sum_{m=1}^{M} a_m X^*_m. \qquad (2)$$

Each attention score $a_m$ is the result of a function of the corresponding patch embedding $X^*_m$: $a_{1,\dots,M} = \mathrm{Softmax}(f^\theta(X^*_1), \dots, f^\theta(X^*_M))$ and can be learned by a neural network $f^\theta$ parameterized by $\theta$. We observe that the weighted average in Eq. 2 constitutes a cross-attention layer by defining $f^\theta$ as follows:

$$f^\theta(X^*_m) = \frac{Q K^\top}{\sqrt{D_k}}, \qquad (3)$$

where $Q = q^\phi W^q$, $K = X^*_m W^k$, with $W^q, W^k \in \mathbb{R}^{D \times D_k}$, and $D_k$ being the projected dimension. Here, $q^\phi \in \mathbb{R}^D$ is a learnable query that is independent of the input. Furthermore, instead of aggregating patch embeddings directly, we aggregate a linear projection called values in typical transformer notation: $V = X^*_m W^v$, with $W^v \in \mathbb{R}^{D \times D_v}$ and $D_v$ being the value dimension. In practice, we use multi-head cross-attention and place it in a standard transformer (Vaswani et al., 2017), optionally adding a sinusoidal position encoding (details in Appendix B).
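To make the aggregation step concrete, the following single-head sketch computes attention weights from one learnable query and uses them to form a convex combination of projected patch embeddings. It is a simplified illustration only: the variable names (`q`, `Wq`, `Wk`, `Wv`) and the random initialization are assumptions made here, and the multi-head attention, transformer block and position encoding used in practice are omitted.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())                   # subtract the max for numerical stability
    return e / e.sum()

def cross_attention_pool(X, q, Wq, Wk, Wv):
    """Pool M patch embeddings X (M, D) into one vector using a single learnable query q (D,)."""
    Q = q @ Wq                                # projected query, shape (D_k,)
    K = X @ Wk                                # keys, shape (M, D_k)
    V = X @ Wv                                # values, shape (M, D_v)
    a = softmax(K @ Q / np.sqrt(Wk.shape[1])) # attention scores, non-negative and summing to 1
    return a @ V, a                           # weighted average of values, and the scores

rng = np.random.default_rng(0)
M, D, Dk, Dv = 4, 8, 6, 5
X = rng.normal(size=(M, D))                   # embeddings of the selected patches
q = rng.normal(size=D)                        # query, independent of the input
Wq = rng.normal(size=(D, Dk))
Wk = rng.normal(size=(D, Dk))
Wv = rng.normal(size=(D, Dv))
z, a = cross_attention_pool(X, q, Wq, Wk, Wv)
```

Because the scores `a` depend only on the query and the keys, the same function can be called on candidate embeddings during patch selection and on re-embedded patches during aggregation.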

Note that there is a close connection between the patch selection and patch aggregation modules. They both share the cross-attention layer. Whereas IPS runs in no-gradient mode and cannot train its parameters, by sharing them with the patch aggregation module, one can still learn to select patches relevant for the downstream task.

3 Related Work

The setting described in Sect. 2 constitutes a MIL problem (Maron and Lozano-Pérez, 1997), where an image is seen as a bag that contains multiple patch-level features called instances. The goal is then to predict the label of unseen bags given their instances. However, training labels are only available for bags, not for instances. The model must thus learn to attend to patches that best represent the image for a given task. Eq. 2 provides such a mechanism and is known as MIL pooling function. In particular, our cross-attention layer follows the weighted collective MIL assumption stating that all instances contribute to the bag representation, albeit to varying degrees according to the attention scores (Foulds and Frank, 2010; Amores, 2013).

A MIL approach similar to ours is DeepMIL (Ilse et al., 2018), which uses a gated attention mechanism to compute attention scores for the aggregation of patches, scoring each patch embedding $X^*_m$ as $w\,(\tanh(X^*_m U) \odot \mathrm{sigm}(X^*_m V))^\top$, with $U \in \mathbb{R}^{D \times D_h}$, $V \in \mathbb{R}^{D \times D_h}$ and $w \in \mathbb{R}^{D_h}$ being parameters, and $D_h$ corresponds to the number of hidden nodes. In contrast, our multi-head cross-attention layer provides more capacity by computing multiple intermediate bag representations. More importantly, it is a MIL pooling function that can be naturally integrated into a standard and field-proven transformer module.
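For comparison with the cross-attention scores, gated attention in the style of Ilse et al. (2018) can be sketched as follows. The parameter names `U`, `V`, `w` and the hidden size are illustrative assumptions, and the weights are random rather than trained.

```python
import numpy as np

def gated_attention_weights(X, U, V, w):
    """Gated-attention weights for M patch embeddings X of shape (M, D)."""
    gate = 1.0 / (1.0 + np.exp(-(X @ V)))     # sigmoid gate, shape (M, D_h)
    logits = (np.tanh(X @ U) * gate) @ w      # one scalar per patch, shape (M,)
    e = np.exp(logits - logits.max())         # softmax over patches
    return e / e.sum()

rng = np.random.default_rng(0)
M, D, Dh = 6, 8, 4
X = rng.normal(size=(M, D))
U = rng.normal(size=(D, Dh))
V = rng.normal(size=(D, Dh))
w = rng.normal(size=Dh)
weights = gated_attention_weights(X, U, V, w)
bag = weights @ X                             # convex combination of patch embeddings
```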

Other MIL methods common for binary classification of large images classify each patch and use max-pooling or a variant thereof on the class scores for patch selection and training (Campanella et al., 2018; Zhou et al., 2018; Lerousseau et al., 2020). Our method is more general and can be used for multi-{task, label, class} classification while being able to attend to multiple patches. In early experiments, we found that max-pooling in the baseline leads to overfitting. Instead, we employ TopMIL (based on Campanella et al. (2019)) as a baseline, which selects the top $M$ patches in a single iteration according to the maximum class logit of each patch. Selected patches are then re-embedded, classified, class logits are averaged, and an activation function (e.g., softmax) is applied.
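The TopMIL baseline described above can be illustrated with a small NumPy sketch. It is deliberately simplified: the re-embedding of selected patches in training mode is skipped, so the same logits are used for selection and averaging, and the function name and example values are hypothetical.

```python
import numpy as np

def topmil_predict(patch_logits, M):
    """Select the M patches with the largest maximum class logit and average their logits."""
    top = np.argsort(-patch_logits.max(axis=1))[:M]   # indices of the selected patches
    mean_logits = patch_logits[top].mean(axis=0)      # average class logits over the selection
    e = np.exp(mean_logits - mean_logits.max())       # softmax activation
    return e / e.sum(), top

# Four patches, three classes: patches 1 and 2 carry the strongest evidence.
patch_logits = np.array([[0.0, 0.0, 0.0],
                         [5.0, 0.0, 0.0],
                         [0.0, 4.0, 0.0],
                         [0.1, 0.0, 0.0]])
probs, top = topmil_predict(patch_logits, M=2)
```

Since the logits of all selected patches are averaged with equal weight, choosing `M` larger than the number of informative patches dilutes the prediction with uninformative logits.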

Another line of work seeks to reduce the memory consumption for the processing of high-resolution images by first identifying regions of interest in low resolution and then sampling patches from these regions in high resolution (Katharopoulos and Fleuret, 2019; Cordonnier et al., 2021; Kong and Henao, 2022). However, especially in gigapixel images, the lower-resolution image may either still be too large to encode, or it may be too small to identify informative regions. In contrast, IPS relies exclusively on high-resolution patches to best estimate the attention scores.

For very large images, patch-wise self-supervised learning is a promising way to reduce memory usage (Dehaene et al., 2020; Chen et al., 2022; Li et al., 2021). After pre-training, low-dimensional features are extracted and used to train a downstream network. Although the input size decreases considerably in this way, it may still require too much memory for very long sequences. Pre-training can be used orthogonally to IPS to process sequences of any length, as demonstrated in Sect. 4.2.

Our method is related to Sparse Mixture of Experts, a conditional computation approach in which specialized subnetworks (experts) are responsible for the processing of specific input patterns (Jacobs et al., 1991). Specifically, Shazeer et al. (2017) and Riquelme et al. (2021) use a gating function to decide which expert is assigned to which part of the input. This gating function is non-differentiable but shares its output with a differentiable aggregation operation that combines the expert’s outputs. On a high level this is similar to how we share the cross-attention module between no-gradient mode IPS and the gradient mode patch aggregator. Attention Sampling (Katharopoulos and Fleuret, 2019) also reuses aggregation weights for patch sampling. IPS uses Top-M instead.

4 Experiments

We evaluate the performance and efficiency of our method on three challenging datasets from a variety of domains and training regimes: multi-class recognition of distant traffic signs in megapixel images, weakly-supervised classification in gigapixel whole-slide images (WSI) using self-supervised representations, and multi-task learning of inter-patch relations on a synthetic megapixel MNIST benchmark. All training and baseline hyperparameters are provided in Appendix D.



To show that IPS selects salient patches, two other patch selection methods, Random Patch Selection (RPS) and Differentiable Patch Selection (DPS) (Cordonnier et al., 2021), are used for comparison, both employing the same transformer as IPS. Next, to verify that the cross-attention transformer is a strong MIL pooling operator, it is compared to DeepMIL (Ilse et al., 2018) and TopMIL as introduced in Sect. 3. Finally, we compare to a standard CNN applied on the original image and a smaller sized version to assess how resolution affects performance.


Each of the main experiments is performed five times with random parameter initializations, and the average classification performance and standard deviation are reported on the test set using either accuracy or AUC score. We also report computational efficiency metrics, including maximum GPU memory usage (VRAM) and training runtime for a batch size of 16 and various values of $M$. The runtime is calculated for a single forward and backward pass and is averaged over all iterations of one training epoch excluding the first and last iterations. Both metrics are calculated on a single NVIDIA A100 GPU in all experiments.

4.1 Traffic Signs Recognition

Acc. [%] VRAM Time
IPS Transformer
RPS Transformer
DPS Transformer
Table 1: Traffic sign recognition: Our IPS transformer outperforms all baselines using 1.2 GB of VRAM while being faster than the CNN. DPS uses patches (see Appendix D.5). VRAM is reported in gigabytes and time in milliseconds.

We first evaluate our method on the Swedish traffic signs dataset, which consists of 747 training and 684 test images with 1.3 megapixel resolution, as in Katharopoulos and Fleuret (2019). Each image shows a speed limit sign of 50, 70, 80 km/h or no sign. This problem requires high-resolution images to distinguish visually similar traffic signs, some of which appear very small due to their distance from the camera (Fig. 5 left). We resize each image to 1200×1600 and extract 192 non-overlapping patches of size 100×100. Due to the small number of data points, a ResNet-18 with ImageNet-1k weights is used as encoder for all methods and then finetuned. For IPS, we use eager loading and set


Table 1 shows that the average accuracy of the IPS transformer is consistently high regardless of the value of $M$. For , memory usage is more than 10× lower than for the CNN, and for , training time is about half that of the CNN. As expected, RPS does not select salient patches due to the low signal-to-noise ratio in the data. DPS localizes the salient patches but performs worse than IPS in all metrics and is sensitive to the choice of $M$. The MIL baselines perform well in general, and TopMIL comes closest to IPS in accuracy. However, TopMIL requires knowledge about the number of informative patches, since the logits are averaged; already for , the performance is slightly degraded. In addition, TopMIL is faster but less memory-efficient than IPS because in TopMIL all patches are scored at once. DeepMIL, on the other hand, lacks a patch selection mechanism and hence includes all patches in the aggregation, which increases training time and VRAM. For , DeepMIL differs from the IPS transformer only in the aggregation function, yet the transformer performs 2.2 percentage points better on average. We also report results for DeepMIL+, which has a similar capacity to the transformer for a fairer comparison. It achieves an average accuracy of 97.7%, less than the 98.4% of IPS. A CNN that takes the entire image as input and has no attention mechanism performs much worse and is inefficient. Downscaling the image by a factor of 3 further degrades the CNN’s performance, as details of distant traffic signs are lost.

For a fair comparison against prior work, we used the same transformer and pre-trained model where applicable. This results in better numbers than reported by DPS, which achieves in Table 1 vs. in Cordonnier et al. (2021). Nonetheless, our IPS transformer improves on that by up to 1.4 percentage points while being more memory and runtime efficient.

4.2 Gigapixel WSI Classification

Next, we consider the CAMELYON16 dataset (Litjens et al., 2018), which consists of 270 training and 129 test WSIs of gigapixel resolution for the recognition of metastases in lymph node cells. The task is formulated as a weakly-supervised binary classification problem, i.e., only image-level labels about the presence of metastases are used. The WSIs come with multiple resolutions, but we attempt to solve this task with the highest available resolution (40× magnification). Since a large portion of the images may not contain any cells, Otsu’s method (Otsu, 1979) is first applied to filter out the background. Only patches with at least 1% of the area being foreground are included, a sensitive criterion that ensures that no abnormal cells are left out. Then, non-overlapping patches of size are extracted. This gives a total of 28.6 million patches (avg.: 71.8k, std.: 37.9k), with the largest WSI having 251.8k patches (about 20% of the size of ImageNet-1k).
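The background filter can be illustrated with a small NumPy sketch of Otsu thresholding followed by the 1% foreground criterion. This is a sketch under assumptions rather than the pre-processing pipeline used in the experiments: it operates on an 8-bit grayscale array, treats pixels at or below the threshold as foreground (assuming tissue is darker than the slide background), and the function names and the toy image are hypothetical.

```python
import numpy as np

def otsu_threshold(gray):
    """Otsu's threshold for a uint8 image: the level maximizing the between-class variance."""
    p = np.bincount(gray.ravel(), minlength=256) / gray.size
    omega = np.cumsum(p)                          # probability of the class "<= t"
    mu = np.cumsum(p * np.arange(256))            # cumulative mean up to level t
    with np.errstate(divide="ignore", invalid="ignore"):
        sigma_b = (mu[-1] * omega - mu) ** 2 / (omega * (1.0 - omega))
    return int(np.argmax(np.nan_to_num(sigma_b)))

def foreground_patch_mask(gray, patch, min_frac=0.01):
    """Boolean grid marking the patches whose foreground fraction is at least min_frac."""
    fg = gray <= otsu_threshold(gray)             # assumption: tissue is darker than background
    rows, cols = gray.shape[0] // patch, gray.shape[1] // patch
    tiles = fg[: rows * patch, : cols * patch].reshape(rows, patch, cols, patch)
    return tiles.mean(axis=(1, 3)) >= min_frac

# Toy slide: bright background with one dark 10x10 region in the top-left corner.
gray = np.full((100, 100), 220, dtype=np.uint8)
gray[:10, :10] = 40
mask = foreground_patch_mask(gray, patch=10)
```

Only the patches marked in `mask` would then be passed on to feature extraction.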

Due to the large volume of patches, it is impractical to learn directly from pixels. Instead, we train BYOL (Grill et al., 2020), a state-of-the-art self-supervised learning algorithm, with a ResNet-50 encoder and extract features for all patches and WSIs (see Appendix D.6 for details). Each patch is thus represented as a 2,048-dimensional vector, which is further projected down to 512 features and finally processed by the respective aggregator. Note that even when processing patches at the low-dimensional feature level, aggregating thousands of patch features can become prohibitively expensive. For example, with DeepMIL, we can aggregate around 70k patch features before running out of memory on an A100 at a batch size of 16. Therefore, in each of the 5 runs, a random sample of up to 70k patch features is drawn for the baseline MIL methods, similar to Dehaene et al. (2020). In contrast, IPS can easily process all patches. Due to the varying number of patches in each WSI, a batch cannot be built without expensive padding. We thus switch to eager sequential loading in IPS, i.e., slides are loaded sequentially and patch features per image are selected and cached for subsequent mini-batch training. In all IPS experiments, .

Total () Pretr. AUC [%] VRAM [GB] Time [ms] IPS Transformer (ours) all BYOL all BYOL all BYOL BYOL DeepMIL BYOL BYOL BYOL TopMIL BYOL BYOL BYOL DSMIL-LC (Li et al., 2021) SimCLR CLAM (Lu et al., 2021) CPC CLAM-SB (Wang et al., 2022) NR NR SRCL DeepMIL (Dehaene et al., 2020) MoCo v2 Challenge Winner (Wang et al., 2016) None Strong superv. (Bejnordi et al., 2017) None

Table 2: CAMELYON16: IPS achieves performance close to the strongly-supervised state-of-the-art. : Average number of patches is reported for these methods, NR: and are unknown.

We report AUC scores, as in the challenge (Bejnordi et al., 2017), in Table 2. Best performance is achieved by IPS using , with an improvement of more than 3 percentage points over the runner-up baseline. We also outperform the CAMELYON16 challenge winner (Wang et al., 2016), who leveraged patch-level labels, and approach the result of the strongly-supervised state of the art (Bejnordi et al., 2017). The memory usage in IPS is 4–5 GB depending on . However, IPS has a higher runtime because all patches are taken into account and eager sequential loading is used. For comparison, we also reduce the total number of patches before the IPS stage to 70k (using the same random seeds) and observe slightly better performance than DeepMIL. DPS was not considered since the downscaled image exceeded our hardware limits. We also tested CNNs trained on downscaled images up to 4k×4k pixels but they perform worse than a naive classifier.

Several recent works adopt a pre-trained feature extractor and report results for CAMELYON16 (see bottom of Table 2). Given their respective setups, most consider all patches for aggregation but make compromises that one can avoid with IPS: A batch size of 1 (Wang et al., 2022; Li et al., 2021), training with features from the third ResNet-50 block, which effectively halves the size of the embeddings (Lu et al., 2021), or using smaller resolutions (Dehaene et al., 2020; Li et al., 2021). The work by Dehaene et al. (2020) achieves a slightly better result than our IPS transformer (avg. AUC: 98.1% vs. 98.7%) by using DeepMIL on a random subset of up to patches. We do not assume that this is due to our transformer, as the DeepMIL baseline reproduces their aggregation function yet performs significantly worse. Instead, it could be due to a different pre-training algorithm (MoCo v2), pre-processing (closed-source U-Net to filter out the background), or lower magnification level (20×), which reduces the total number of patches (avg. 9.8k vs. 71.8k in ours).

4.3 Inter-Patch Reasoning

Total () Majority Max Top Multi-label VRAM [GB] Time [ms] IPS Transformer GB ms (ours) GB ms GB ms GB ms GB ms GB ms RPS Transformer GB ms GB ms DPS Transformer GB ms GB ms TopMIL GB ms GB ms GB ms DeepMIL GB ms CNN GB ms CNN 0.3x GB ms

Table 3: Results for megapixel MNIST. The IPS Transformer accurately solves all relational tasks with a minimal number of patches and can leverage overlapping patches with small memory usage.

A single informative patch was sufficient to solve the previous tasks. In contrast, megapixel MNIST introduced in Katharopoulos and Fleuret (2019) requires the recognition of multiple patches and their relations. The dataset consists of 5,000 training and 1,000 test images of size 1,500×1,500. We extract patches of size 50×50 without overlap ($N = 900$) or with 50% overlap ($N = 3{,}481$). In each image, 5 MNIST digits are placed, 3 of which belong to the same class and 2 of which to other classes. In addition, 50 noisy lines are added (see Fig. 5 right). The task is then to predict the majority class. We found that this problem can be solved well by most baselines, so we extend the setup with three more complex tasks: max: detect the maximum digit, top: identify the topmost digit, multi-label: the presence/absence of all classes needs to be recognized. We frame this as a multi-task learning problem and use 4 learnable queries in the cross-attention layer, i.e., 4 task representations are learned. Similarly, 4 gated-attention layers are used for DeepMIL. All methods also utilize multiple classification heads and add a sinusoidal position encoding to the patch features. For the CNNs, a sinusoidal channel is concatenated with the input. A simplified ResNet-18 consisting of 2 residual blocks is used as the encoder for all methods. For IPS, we use eager loading and set .
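The patch counts for both extraction settings follow from the image size, patch size and stride, which the following sketch verifies. The helper `extract_patches` is a hypothetical illustration for a single-channel image, assuming that 50% overlap corresponds to a stride of half the patch size.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def extract_patches(image, size, stride):
    """Cut a 2D image into size x size patches taken every `stride` pixels."""
    windows = sliding_window_view(image, (size, size))[::stride, ::stride]
    return windows.reshape(-1, size, size)        # (number of patches, size, size)

image = np.zeros((1500, 1500), dtype=np.uint8)    # placeholder for a megapixel MNIST image
n_no_overlap = len(extract_patches(image, size=50, stride=50))     # 30 x 30 grid
n_half_overlap = len(extract_patches(image, size=50, stride=25))   # 59 x 59 grid
```

Halving the stride roughly quadruples the number of patches, which is why the overlapping setting is considerably more demanding for methods that score all patches at once.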

The results in Table 3 show that almost all methods were able to obtain high accuracies for the tasks majority and max. However, only DeepMIL, DPS and IPS were able to make appropriate use of the positional encoding to solve task top. For the hardest task, multi-label, only the IPS transformer obtained accuracies above 90% by using overlapping patches. However, for both TopMIL and DeepMIL, we were not able to use overlapping patches due to out-of-memory errors (on a single A100). Interestingly, IPS achieves high performance throughout with using overlapping patches at a memory usage of only 2.2 GB while maintaining a competitive runtime. Furthermore, the IPS transformer also achieves high performance with $M = 5$ patches, which is the minimum number of patches required to solve all tasks.

4.4 Ablation Study

Figure 3: Peak GPU memory (left), training runtime (center) and accuracy (right) as a function of image size in megapixel MNIST: With IPS and lazy loading, memory usage can be kept constant. High performance is mostly maintained compared to the baseline CNN.
Figure 4: VRAM, training runtime and number of iterations for different combinations of $M$ and $I$.

Effect of image size

How does input size affect performance and efficiency in IPS? To address this question, we generate new megapixel MNIST datasets by scaling the image size from 1k to 10k pixels per side (400 to 40k patches) and attempt to solve the same tasks as in the main experiment. We also linearly scale the number of noise patterns from 33 to 333 to make the task increasingly difficult. Fig. 3 shows a comparison of memory usage, runtime and accuracy (task majority) for different data loading options of our IPS transformer (, , non-overlapping patches, default batch size of 16), as well as the baseline CNN.

The CNN exhibits a high memory footprint that quickly exceeds the hardware limits even when using a reduced batch size of 4. In IPS with eager loading, the VRAM is dominated by the batch of inputs and can be reduced by sequential loading. In contrast, with lazy loading the memory usage is kept constant at 1.7 GB regardless of the input size. The runtime in eager loading is less than that of the CNN and almost approaches a CNN using a batch size of 4. Lazy loading is faster than eager sequential loading, suggesting that the latter should only be used in situations where different sequence lengths need to be processed. In terms of performance, we note that the IPS transformer achieves high accuracy for up to 8k pixels and decreases only slightly starting from 9k pixels likely due to the declining signal-to-noise ratio (at 10k pixels, only 5 out of 40k patches are relevant). In comparison, the CNN performance drops rapidly starting from 3k pixels towards a naive classifier.

Effect of $M$ and $I$

Fig. 4 shows that VRAM and runtime tend to increase with $M$ as more patches need to be processed in training mode. With increasing values of $I$, a higher VRAM but lower runtime can be observed, since IPS runs for fewer iterations. However, there are more subtle cases. Along the diagonal in Fig. 4 (left), peak VRAM is observed in the non-IPS part and thus remains constant for different values of $I$. For example, for , VRAM remains constant for , so setting results in faster runtime without compromising memory usage. In contrast, in the upper left triangle of the graph, the VRAM is dominated by IPS in no-gradient mode. For example, uses 5 GB of VRAM for (i.e., ). Interestingly, when processing all remaining patches in a single iteration ($I = N - M$), VRAM can be reduced by increasing $M$ up to 200, because $I$ is reduced by $M$.
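The trade-off between the number of iterations and the number of patches embedded per iteration can be made explicit with a small calculation. The sketch assumes that the iteration count is $\lceil (N - M)/I \rceil$ for an image of `N` patches, a buffer of `M` patches and `I` new patches per iteration; the function name and the example values are illustrative.

```python
import math

def ips_iterations(N, M, I):
    """Number of selection iterations needed to visit the N - M patches outside the initial buffer."""
    return math.ceil((N - M) / I) if N > M else 0

N, M = 900, 100
# Larger I means fewer iterations, but more patches embedded at once in no-gradient mode.
schedule = {I: ips_iterations(N, M, I) for I in (50, 100, 400, N - M)}
```

For these values, `schedule` maps `I = 50` to 16 iterations, `I = 100` to 8, `I = 400` to 2, and `I = N - M = 800` to a single iteration.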

Effect of patch size

In Table 4, we investigate the effect of patch size (25–400 px per side) on the performance and efficiency of traffic sign recognition. One can notice that the accuracy is similar for all patch sizes and only slightly decreases for very small patch sizes of 25 px. As the patch size increases, the runtime for IPS tends to decrease since fewer iterations are run, while the time to embed and aggregate selected patches (non-IPS time) increases because more pixels need to be processed in training mode.

Size [px] Acc. [%] VRAM [GB] IPS Time [ms] Non-IPS Time [ms] Total Time [ms]

Table 4: Effect of patch size on traffic sign recognition and efficiency (, , single run).
Figure 5: Patch scores and crops of most salient patches. Left: Traffic signs obtain high scores. Center: Most cancer cells (black contour) are attended. Right: Patch-level class mappings in MNIST.

4.5 Interpretability

In Fig. 5, we superimpose attention scores over images, which shows that salient patches are attended to while the background is neglected. In low signal-to-noise settings, such visualizations can support humans to localize tiny regions of interest that might otherwise be missed (e.g., Fig. 1). However, attention scores alone may not be satisfactory when multiple objects of different classes are present. For example, attention scores in megapixel MNIST indicate relevant patches, but not to which classes these patches belong. We propose a simple fix to obtain patch-level class scores for $C$ classes. We add a new classification head: $\tilde{Y}_m = g^\psi(X^*_m)$, where $g^\psi: \mathbb{R}^D \to \mathbb{R}^C$ with parameters $\psi$. For training, these scores compute an image-level prediction: $\hat{y} = \sigma\left(\sum_{m=1}^{M} a_m \tilde{Y}_m\right)$, where $\sigma$ is an activation function. Crucially, we stop gradients from flowing back through $X^*_m$ and $a_m$ such that $g^\psi$ is trained as a separate read-out head jointly with the base model. We demonstrate this for megapixel MNIST (), where the patch-level classifier is learned from the multi-label task, using only image-level labels. In Fig. 5 (right), we visualize the inferred per-patch class label, $\arg\max_c \tilde{Y}_{m,c}$, using colors and $a_m$ using the alpha channel.
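The forward pass of such a read-out head can be sketched as follows. This is only an illustration under assumptions: the head is taken to be a single linear layer with hypothetical parameters `W` and `b`, the activation is a sigmoid as would suit a multi-label task, and the inputs are random. NumPy has no automatic differentiation, so the stop-gradient is only indicated in comments; in an autodiff framework, the embeddings and attention scores would be detached before being passed to the head.

```python
import numpy as np

def patch_level_readout(X, a, W, b):
    """Patch-level class scores and the image-level prediction derived from them.

    X: (M, D) patch embeddings and a: (M,) attention scores, both treated as constants
    (no gradient would flow back into them); W: (D, C) and b: (C,) are the head parameters.
    """
    Y = X @ W + b                                 # patch-level class scores, shape (M, C)
    y_hat = 1.0 / (1.0 + np.exp(-(a @ Y)))        # sigmoid of the attention-weighted scores, (C,)
    labels = Y.argmax(axis=1)                     # inferred class label per patch, shape (M,)
    return Y, y_hat, labels

rng = np.random.default_rng(0)
M, D, C = 5, 8, 10
X = rng.normal(size=(M, D))
a = np.full(M, 1.0 / M)                           # uniform attention, for illustration only
W = rng.normal(size=(D, C))
b = np.zeros(C)
Y, y_hat, labels = patch_level_readout(X, a, W, b)
```

For visualization, `labels` would determine the color of each patch and `a` its opacity.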

5 Conclusion

Reflecting on the experiments, one can observe that the IPS transformer performs well in various domains and training regimes and exhibits interesting capabilities: constant memory usage, learning with arbitrarily long sequences of varying lengths, multi-task learning, and modeling of inter-patch relations. However, there are also limitations. For example, while the GPU memory can be kept constant, runtime cannot. In Sect. 4.4, we gave some intuition about what influence the hyperparameters $M$ and $I$ have in this regard. Furthermore, performance can degrade if the signal-to-noise ratio becomes too small, as shown in our MNIST scaling experiment; in practice, this might be mitigated by pre-training. Nonetheless, we believe that our main contribution, maintaining a low memory footprint for high-resolution image processing, can foster interesting applications that would otherwise be prohibitively expensive. Examples of such applications are learning with multiple medical scans for diagnosis/prognosis, the processing of satellite images, and prediction tasks in videos, graphs, point clouds or whole-genome sequences. Technically, it would be interesting to extend IPS to dense prediction tasks, such as segmentation or depth estimation.

Ethics Statement


CAMELYON16 is a public benchmark dataset consisting of histopathology slides obtained from tissue of human subjects. The dataset has been approved by an ethics committee, see Litjens et al. (2018) for more information. For the Swedish traffic signs dataset, we searched the data for humans and excluded one image where a face of a pedestrian was clearly visible. Megapixel MNIST builds on the widely used MNIST dataset, which contains digits that are ethically uncritical.


In this paper, we exclusively conduct experiments and demonstrate potential applications for human benefit, such as early traffic sign detection, which could be useful for advanced driver assistance systems, or abnormality detection in medical images, which could assist physicians. However, more care is needed before any of these systems can be deployed in real-world applications, including an assessment of biases and fairness and monitoring by medical professionals where relevant. Furthermore, the attention scores are currently not usable as a detection method. Beyond the showcased applications, our method could be misused, just like any other visual object detection system, but now also on gigapixel images. We do not explore such applications in our paper and do not plan to do so in future work.

Computational efficiency

Our method enables the training of deep neural networks on arbitrarily large inputs even on low-cost consumer hardware with limited VRAM, which makes AI development more accessible. All training and fine-tuning experiments on downstream tasks can be efficiently run on a single A100 GPU, with the following total training runtimes: Traffic sign recognition: 45 min (, 150 epochs), megapixel MNIST: 5h (, , , 150 epochs), CAMELYON16: 2h (, , 50 epochs).

Reproducibility Statement

Sect. 2, as well as the pseudocode provided in Appendix A enable the reproduction of our method. Our metrics and baselines lean on existing work and are described in Sect. 4. Necessary steps to reproduce our experiments are described in respective subsections of Sect. 4. We provide further details about all network architectures, hyperparameters for our method and the baselines, as well as pre-processing steps in Appendix D. The tools used to measure the training runtime and peak GPU memory are listed in Appendix D.7.


We would like to thank Joan Puigcerver for suggesting relevant related work, Romeo Sommerfeld for fruitful discussions on interpretability, Noel Danz for useful discussions about applications, as well as Jonathan Kreidler, Konrad Lux, Selena Braune and Florian Sold for providing and managing the compute infrastructure. This project has received funding from the German Federal Ministry for Economic Affairs and Climate Action (BMWK) in the project DAKI-FWS (reference number: 01MK21009E).


  • J. Amores (2013) Multiple instance classification: review, taxonomy and comparative study. Artificial intelligence 201, pp. 81–105. Cited by: §3.
  • M. Arar, A. Shamir, and A. H. Bermano (2022) Learned queries for efficient local attention. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 10841–10852. Cited by: Appendix C.
  • B. E. Bejnordi, M. Veta, P. J. Van Diest, B. Van Ginneken, N. Karssemeijer, G. Litjens, J. A. Van Der Laak, M. Hermsen, Q. F. Manson, M. Balkenhol, et al. (2017) Diagnostic assessment of deep learning algorithms for detection of lymph node metastases in women with breast cancer. Jama 318 (22), pp. 2199–2210. Cited by: §4.2, Table 2.
  • G. Campanella, M. G. Hanna, L. Geneslaw, A. Miraflor, V. Werneck Krauss Silva, K. J. Busam, E. Brogi, V. E. Reuter, D. S. Klimstra, and T. J. Fuchs (2019) Clinical-grade computational pathology using weakly supervised deep learning on whole slide images. Nature medicine 25 (8), pp. 1301–1309. Cited by: §3.
  • G. Campanella, V. W. K. Silva, and T. J. Fuchs (2018) Terabyte-scale deep multiple instance learning for classification and localization in pathology. arXiv preprint arXiv:1805.06983. Cited by: §3.
  • N. Carion, F. Massa, G. Synnaeve, N. Usunier, A. Kirillov, and S. Zagoruyko (2020) End-to-end object detection with transformers. In European conference on computer vision, pp. 213–229. Cited by: Appendix C.
  • C. R. Chen, Q. Fan, and R. Panda (2021) Crossvit: cross-attention multi-scale vision transformer for image classification. In Proceedings of the IEEE/CVF international conference on computer vision, pp. 357–366. Cited by: Appendix C.
  • R. J. Chen, C. Chen, Y. Li, T. Y. Chen, A. D. Trister, R. G. Krishnan, and F. Mahmood (2022) Scaling vision transformers to gigapixel images via hierarchical self-supervised learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 16144–16155. Cited by: §3.
  • J. Cordonnier, A. Mahendran, A. Dosovitskiy, D. Weissenborn, J. Uszkoreit, and T. Unterthiner (2021) Differentiable patch selection for image recognition. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2351–2360. Cited by: §D.5, §3, §4, §4.1.
  • O. Dehaene, A. Camara, O. Moindrot, A. de Lavergne, and P. Courtiol (2020) Self-supervision closes the gap between weak and strong supervision in histology. arXiv preprint arXiv:2012.03583. Cited by: §D.5, §D.6, §3, §4.2, §4.2, Table 2.
  • J. Foulds and E. Frank (2010) A review of multi-instance learning assumptions. The knowledge engineering review 25 (1), pp. 1–25. Cited by: §3.
  • J. Grill, F. Strub, F. Altché, C. Tallec, P. Richemond, E. Buchatskaya, C. Doersch, B. Avila Pires, Z. Guo, M. Gheshlaghi Azar, et al. (2020) Bootstrap your own latent-a new approach to self-supervised learning. Advances in neural information processing systems 33, pp. 21271–21284. Cited by: §D.6, §4.2.
  • K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778. Cited by: §1.
  • M. Ilse, J. Tomczak, and M. Welling (2018) Attention-based deep multiple instance learning. In International conference on machine learning, pp. 2127–2136. Cited by: §D.5, Appendix E, §3, §4.
  • R. A. Jacobs, M. I. Jordan, S. J. Nowlan, and G. E. Hinton (1991) Adaptive mixtures of local experts. Neural computation 3 (1), pp. 79–87. Cited by: §3.
  • A. Jaegle, S. Borgeaud, J. Alayrac, C. Doersch, C. Ionescu, D. Ding, S. Koppula, D. Zoran, A. Brock, E. Shelhamer, et al. (2021) Perceiver io: a general architecture for structured inputs & outputs. arXiv preprint arXiv:2107.14795. Cited by: Appendix C.
  • A. Katharopoulos and F. Fleuret (2019) Processing megapixel images with deep attention-sampling models. In International Conference on Machine Learning, pp. 3282–3291. Cited by: §D.6, §1, §3, §3, §4.1, §4.3.
  • F. Kong and R. Henao (2022) Efficient classification of very large images with tiny objects. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2384–2394. Cited by: §3.
  • A. Krizhevsky, I. Sutskever, and G. E. Hinton (2012) Imagenet classification with deep convolutional neural networks. Advances in neural information processing systems 25. Cited by: §1.
  • M. Lerousseau, M. Vakalopoulou, M. Classe, J. Adam, E. Battistella, A. Carré, T. Estienne, T. Henry, E. Deutsch, and N. Paragios (2020) Weakly supervised multiple instance learning histopathological tumor segmentation. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 470–479. Cited by: §3.
  • B. Li, Y. Li, and K. W. Eliceiri (2021) Dual-stream multiple instance learning network for whole slide image classification with self-supervised contrastive learning. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp. 14318–14328. Cited by: §3, §4.2, Table 2.
  • X. Lian and J. Liu (2019) Revisit batch normalization: new understanding and refinement via composition optimization. In The 22nd International Conference on Artificial Intelligence and Statistics, pp. 3254–3263. Cited by: §1.
  • G. Litjens, P. Bandi, B. Ehteshami Bejnordi, O. Geessink, M. Balkenhol, P. Bult, A. Halilovic, M. Hermsen, R. van de Loo, R. Vogels, et al. (2018) 1399 h&e-stained sentinel lymph node sections of breast cancer patients: the camelyon dataset. GigaScience 7 (6), pp. giy065. Cited by: §4.2, Datasets.
  • I. Loshchilov and F. Hutter (2017) Decoupled weight decay regularization. arXiv preprint arXiv:1711.05101. Cited by: §D.1.
  • M. Y. Lu, D. F. Williamson, T. Y. Chen, R. J. Chen, M. Barbieri, and F. Mahmood (2021) Data-efficient and weakly supervised computational pathology on whole-slide images. Nature biomedical engineering 5 (6), pp. 555–570. Cited by: §4.2, Table 2.
  • O. Maron and T. Lozano-Pérez (1997) A framework for multiple-instance learning. Advances in neural information processing systems 10. Cited by: §3.
  • N. Otsu (1979) A threshold selection method from gray-level histograms. IEEE transactions on systems, man, and cybernetics 9 (1), pp. 62–66. Cited by: §4.2.
  • H. Pinckaers, B. Van Ginneken, and G. Litjens (2020) Streaming convolutional neural networks for end-to-end learning with multi-megapixel images. IEEE transactions on pattern analysis and machine intelligence. Cited by: Appendix C.
  • C. Riquelme, J. Puigcerver, B. Mustafa, M. Neumann, R. Jenatton, A. Susano Pinto, D. Keysers, and N. Houlsby (2021) Scaling vision with sparse mixture of experts. Advances in Neural Information Processing Systems 34, pp. 8583–8595. Cited by: §3.
  • O. Russakovsky, J. Deng, H. Su, J. Krause, S. Satheesh, S. Ma, Z. Huang, A. Karpathy, A. Khosla, M. Bernstein, A. C. Berg, and L. Fei-Fei (2015) ImageNet Large Scale Visual Recognition Challenge. International Journal of Computer Vision (IJCV) 115 (3), pp. 211–252. External Links: Document Cited by: §1.
  • M. Ryoo, A. Piergiovanni, A. Arnab, M. Dehghani, and A. Angelova (2021) Tokenlearner: adaptive space-time tokenization for videos. Advances in Neural Information Processing Systems 34, pp. 12786–12797. Cited by: Appendix C.
  • C. F. Sabottke and B. M. Spieler (2020) The effect of image resolution on deep learning in radiography. Radiology. Artificial intelligence 2 (1). Cited by: §1.
  • F. E. Sahin (2019) Long-range, high-resolution camera optical design for assisted and autonomous driving. In Photonics, Vol. 6, pp. 73. Cited by: §1.
  • N. Shazeer, A. Mirhoseini, K. Maziarz, A. Davis, Q. Le, G. Hinton, and J. Dean (2017) Outrageously large neural networks: the sparsely-gated mixture-of-experts layer. arXiv preprint arXiv:1701.06538. Cited by: §3.
  • E. Strubell, A. Ganesh, and A. McCallum (2019) Energy and policy considerations for deep learning in nlp. arXiv preprint arXiv:1906.02243. Cited by: §1.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin (2017) Attention is all you need. Advances in neural information processing systems 30. Cited by: Appendix C, §D.4, §2.2.
  • D. Wang, A. Khosla, R. Gargeya, H. Irshad, and A. H. Beck (2016) Deep learning for identifying metastatic breast cancer. arXiv preprint arXiv:1606.05718. Cited by: §4.2, Table 2.
  • X. Wang, S. Yang, J. Zhang, M. Wang, J. Zhang, W. Yang, J. Huang, and X. Han (2022) Transformer-based unsupervised contrastive learning for histopathological image classification. Medical Image Analysis 81, pp. 102559. Cited by: §4.2, Table 2.
  • L. Zhou, Y. Zhao, J. Yang, Q. Yu, and X. Xu (2018) Deep multiple instance learning for automatic detection of diabetic retinopathy in retinal images. IET Image Processing 12 (4), pp. 563–571. Cited by: §3.


Appendix A Pseudocode

# N: total num patches
# M: memory size
# I: iter size
# D: embedding dimension
# B: batch size
# a: attention scores

for patches, labels in loader:
    # eager loading
    patches, labels =, 

    # set network to no-gradient and evaluation mode
    net.no_grad(); net.eval();

    # initialize memory buffer
    mem_emb = net.encoder(patches[:,:M])  # B, M, D

    # get num iterations
    n_iter = ceil((N - M) / I)

    for i in range(n_iter):
        # get and embed next patches
        iter_patch = get_next(patches, i)
        iter_emb = net.encoder(iter_patch)  # B, I, D

        # concatenate with memory buffer
        all_emb = cat(mem_emb, iter_emb)

        # get attention scores
        a = net.transf.get_scores(all_emb)  # B, M + I

        # update memory buffer
        mem_emb, mem_idx = top_M(all_emb, a)

    # select patches
    mem_patch = gather(patches, mem_idx)

    # set network to gradient and training mode
    net.grad(); net.train();

    # embed and aggregate patches
    mem_emb = net.encoder(mem_patch)  # B, M, D
    image_emb = net.transf(mem_emb)  # B, D

    # predict based on image embedding
    pred = net.predict(image_emb)

    # training step
    loss = loss_fn(pred, labels)
    loss.backward()
    optimizer.step()
Algorithm 1 Pseudocode for Iterative Patch Selection and Patch Aggregation
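The selection loop of Algorithm 1 can be illustrated without a neural network. The following is a minimal sketch in plain Python, not the authors' implementation: `score_fn` is a hypothetical stand-in for embedding a patch and reading off its attention score, and the sketch assumes that a patch's score does not depend on which other patches are currently in the buffer. Under that assumption, the loop never holds more than M + I candidates and ends with the same indices as a global top-M selection.

```python
import math


def iterative_top_m(score_fn, N, M, I):
    """Select the indices of the M highest-scoring patches out of N.

    At most M + I candidates are held at any time, mirroring the memory
    buffer of Algorithm 1. `score_fn(idx)` is a hypothetical stand-in
    for embedding patch `idx` and computing its attention score.
    Returns the selected indices and the peak number of candidates held.
    """
    # initialize the memory buffer with the first M patches
    mem = [(score_fn(i), i) for i in range(min(M, N))]
    peak = len(mem)
    n_iter = math.ceil(max(N - M, 0) / I)
    for it in range(n_iter):
        start = M + it * I
        stop = min(start + I, N)
        # score the next I patches and concatenate them with the buffer
        candidates = mem + [(score_fn(i), i) for i in range(start, stop)]
        peak = max(peak, len(candidates))
        # keep only the M highest-scoring candidates
        candidates.sort(key=lambda pair: pair[0], reverse=True)
        mem = candidates[:M]
    return sorted(idx for _, idx in mem), peak
```

Calling `iterative_top_m(lambda i: scores[i], N=len(scores), M=10, I=50)` on a list of toy scores returns the selected indices together with the peak candidate count, which stays at or below M + I regardless of N.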

Appendix B Multi-head cross-attention and transformer

In practice, we use multi-head cross-attention (MCA), which computes multiple intermediate representations corresponding to heads and then aggregates them:


with , , and . Note that each representation is computed with head-specific attention scores. For patch selection, these scores are averaged over heads to obtain a scalar score per patch. The transformer module is composed of an MCA layer, a multi-layer perceptron (MLP) and Layernorm (LN) and is formulated as:


where is an optional sinusoidal position encoding and the MLP consists of two layers with ReLU non-linearity.
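The averaging of head-specific attention scores into one scalar score per patch can be sketched as follows. This is an illustration under assumptions, not the paper's formulation: it assumes standard scaled dot-product attention with a single query vector per head, and the function name and input layout (`queries`, `keys`) are hypothetical.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def head_averaged_scores(queries, keys):
    """Average per-head attention scores into one score per patch.

    queries: list of H query vectors, one per head (assumed layout).
    keys:    list of H lists, each holding N key vectors for that head.
    Returns a list of N scores that sums to 1.
    """
    num_heads = len(queries)
    num_patches = len(keys[0])
    averaged = [0.0] * num_patches
    for query, head_keys in zip(queries, keys):
        scale = math.sqrt(len(query))
        # scaled dot product between the head's query and every key
        logits = [sum(q * k for q, k in zip(query, key)) / scale
                  for key in head_keys]
        for i, score in enumerate(softmax(logits)):
            averaged[i] += score / num_heads
    return averaged
```

Because each head's scores sum to one, the averaged scores also sum to one, so they can be ranked directly for top-M selection.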

Appendix C Extended related work


Cross-attention is an essential component of our method because it links the patch selection and patch aggregation modules. Therefore, we want to shed more light on its use in the literature. Cross-attention is a concept in which the queries stem from another input than the one utilized by keys and values. It has previously been employed in various transformer-based architectures and applications. In the original transformer, it is used in an autoregressive decoder for machine translation (Vaswani et al., 2017), where keys and values are computed from the encoder output, whereas queries are computed from the already translated sequence. CrossViT applies cross-attention to efficiently fuse patch encodings corresponding to varying patch sizes (Chen et al., 2021). Cross-attention can also be applied to object detection. For example, DETR defines learnable object tokens that cross-attend to patch embeddings, where the number of tokens reflects the maximum number of objects that can be detected (Carion et al., 2020). More related to our work is Query-and-Attend (Arar et al., 2022), which involves learnable queries as a substitute for local convolutions and the more computationally expensive self-attention mechanism. Another interesting architecture is Perceiver IO, which applies cross-attention to the encoder for dimensionality reduction and to the decoder for learning task-specific outputs (Jaegle et al., 2021). TokenLearner (Ryoo et al., 2021) also reduces the number of tokens in the encoder for image and video recognition tasks but without using a cross-attention module. In our work, cross-attention serves two purposes. First, we treat it as a MIL pooling function that is used in the rear part of the network to combine information from multiple patches. Second, we take advantage of the attention scores to select a fixed number of most salient patches, which makes memory consumption independent of input size and thus allows us to process arbitrarily large images.


Another memory-efficient approach is the StreamingCNN (Pinckaers et al., 2020), which employs gradient checkpointing to process large images sequentially through a CNN and is demonstrated on images of up to 8k×8k pixels. Due to the use of checkpoints, multiple forward and backward passes are required. Furthermore, backprop is performed for all tiles, which further slows down the training runtime. In IPS, all patches only run once through the network in fast no-gradient mode. Only selected patches then require a second forward and backward pass.

Appendix D Experiment details

D.1 Training

All models are trained for 150 epochs (megapixel MNIST, traffic signs) or 50 epochs (CAMELYON16) on the respective training sets, and results are reported after the last epoch on the test set without early stopping. The batch size is 16, and AdamW with weight decay of 0.1 is used as optimizer (Loshchilov and Hutter, 2017). After a linear warm-up period of 10 epochs, the learning rate is set to 0.0003 when finetuning pre-trained networks and 0.001 when training from scratch. The learning rate is then decayed by a factor of 1,000 over the course of training using a cosine schedule.
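The learning rate schedule described above can be sketched in a few lines. This is an illustration under assumptions rather than the exact schedule used: it updates the rate once per epoch, starts the linear warm-up at one tenth of the peak rate, and reaches the decayed rate exactly at the last epoch. The function name `lr_at_epoch` and its parameter names are hypothetical.

```python
import math


def lr_at_epoch(epoch, total_epochs=150, warmup_epochs=10,
                peak_lr=3e-4, decay_factor=1000.0):
    """Learning rate for a 0-indexed epoch: linear warm-up, then cosine decay.

    Assumption: per-epoch updates, with the rate reaching
    peak_lr / decay_factor at the final epoch.
    """
    if epoch < warmup_epochs:
        # linear warm-up towards the peak learning rate
        return peak_lr * (epoch + 1) / warmup_epochs
    final_lr = peak_lr / decay_factor
    # progress runs from 0 (end of warm-up) to 1 (last epoch)
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs - 1)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return final_lr + (peak_lr - final_lr) * cosine
```

For training from scratch, `peak_lr=1e-3` would be passed instead, and `total_epochs=50` for CAMELYON16.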

D.2 Architecture components

Table 5 provides a high-level overview of the various components used by the baselines and our method. Note that all methods use the same encoder for a fair comparison.

# | CNN | DPS Transformer | TopMIL | DeepMIL | IPS Transformer (ours)
1 | Encoder | Scorer + Diff. Top-M | Patch selection | Encoder | IPS
2 | Class. head | Encoder | Encoder | Gated attention | Encoder
3 | Activation | Cross-attention transf. | Class. head | Class. head | Cross-attention transf.
4 | - | Class. head | Avg. pool | Activation | Class. head
5 | - | Activation | Activation | - | Activation

Table 5: Module pipeline of baselines and our method

D.3 Encoders

The encoders differ depending on the dataset and are listed in Table 6. After applying each encoder, global average pooling is used to obtain a feature vector per patch. The CNN baselines also use these encoders, but the encoders are applied to the image instead of the patches.

Dataset Encoder Pre-training # Features
Traffic signs ResNet-18 ImageNet-1k 512
CAMELYON16 ResNet-50 + projector BYOL 512
Megapixel MNIST ResNet-18 (2 blocks) None 128
Table 6: Encoder architectures and pre-training for each dataset.

D.4 Transformer hyperparameters

The hyperparameters of the cross-attention transformer follow default values (Vaswani et al., 2017) and are listed in Table 7. We slightly deviate from these settings for very small values of , in particular in traffic signs, and in megapixel MNIST, where attention dropout is set to 0.

Setting Value
Heads 8
# Features
attention dropout
Table 7: Transformer hyperparameters. is the hidden layer size in the feed-forward MLP of the transformer block.

D.5 Baseline hyperparameters

In DPS, we mostly follow the setup of Cordonnier et al. (2021). We downsample the original image by a factor of 3 to obtain a low-resolution image (MNIST: 500×500, traffic signs: 400×533). As the scorer, we use the first two blocks of a ResNet-18 without pre-training for megapixel MNIST and with ImageNet-1k pre-training for traffic signs, followed by a Conv layer (kernel: 3, padding: 0) with a single output channel and max pooling (kernel: 2, stride: 2). The resulting score maps have shape 30×30 for megapixel MNIST and 24×32 for traffic signs. Thus, we consider patches and patches for megapixel MNIST and traffic signs, respectively. We use noise samples for perturbed top-K, employ gradient norm clipping with a maximum cutoff value of and set perturbation noise , which linearly decays to over the course of training. For megapixel MNIST and , -decay was deactivated as it resulted in degraded performance towards the end of training.
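The stated score map shapes can be checked with the usual output-size formula for convolution and pooling layers. The sketch below is an arithmetic check under an assumption: "the first two blocks of a ResNet-18" is taken to mean the stem (7×7 convolution with stride 2, 3×3 max pooling with stride 2) followed by the first two residual stages with strides 1 and 2, as in common ResNet-18 implementations. The helper names are hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution or pooling layer along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def score_map_size(size):
    """Side length of the DPS score map for one side of the low-res input."""
    # Assumed ResNet-18 layout (strides and paddings of common implementations).
    size = conv_out(size, kernel=7, stride=2, padding=3)  # stem convolution
    size = conv_out(size, kernel=3, stride=2, padding=1)  # stem max pooling
    size = conv_out(size, kernel=3, stride=1, padding=1)  # block 1 (stride 1)
    size = conv_out(size, kernel=3, stride=2, padding=1)  # block 2 (stride 2)
    # Layers stated in the text.
    size = conv_out(size, kernel=3, stride=1, padding=0)  # Conv layer, no padding
    size = conv_out(size, kernel=2, stride=2, padding=0)  # max pooling
    return size
```

Under these assumptions, `score_map_size(500)` gives 30, and `score_map_size(400)` and `score_map_size(533)` give 24 and 32, which agrees with the shapes reported above.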

For DeepMIL, following Ilse et al. (2018), whereas for DeepMIL+. The classification heads of the MIL baselines differ slightly between datasets and are listed in Table 8. In traffic signs recognition, a single output layer is utilized following Ilse et al. (2018). More capacity is added for DeepMIL+: fc (2,048) + ReLU, fc (512) + ReLU, fc (# classes). For megapixel MNIST, we obtained better results by adding an additional layer, which could be due to the multi-task setup. For CAMELYON16, we followed the settings of Dehaene et al. (2020). In DPS and our method, we use the same setup for all datasets: fc (# classes). The patch-level classification head introduced in Sect. 4.5 uses the same structure as the MIL methods in megapixel MNIST (Table 8 middle column).

# | Traffic signs | Megapixel MNIST | CAMELYON16
1 | fc (# classes) | fc (128) + ReLU | fc (128) + ReLU
2 | - | fc (# classes) | fc (64) + ReLU
3 | - | - | fc (# classes)

Table 8: Classification heads for each dataset for DeepMIL and TopMIL

D.6 Pre-processing of CAMELYON16

Since a large fraction of a WSI may not contain tissue, we first read each image in chunks, create a histogram of pixel intensities, and apply Otsu’s method to the histogram to obtain a threshold per image. We then extract non-overlapping patches of size 256×256, and retain patches with at least 1% of the area being foreground. BYOL is then trained with a ResNet-50 for 500 epochs using the training settings of Grill et al. (2020) and the augmentations listed in Table 9, which mostly follow those used in Dehaene et al. (2020). In each pre-training epoch, 1,000 patches per slide are randomly sampled and shuffled. We train on 4 A100 GPUs with a batch size of 256 using mixed precision. After pre-training, each patch is center-cropped to 224 px per side and 2,048-dimensional features resulting from the ResNet-50 encoder are extracted. The features are then normalized and stored in an HDF5 file. For the downstream task, we train a projector (fc (512), BatchNorm, ReLU), and its output is passed on to the aggregation module. For the pre-processing of the megapixel MNIST and traffic sign datasets, we adapt the code of Katharopoulos and Fleuret (2019).
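The tissue filtering step can be sketched with a histogram-based Otsu threshold followed by a foreground-fraction test. This is a minimal illustration, not the pre-processing code used for the experiments: the function names are hypothetical, pixels are taken as integer grayscale intensities, and foreground (tissue) is assumed to be darker than the background, i.e. at or below the threshold.

```python
def otsu_threshold(hist):
    """Return the bin index t that maximizes the between-class variance.

    hist[i] is the number of pixels with intensity i; class 0 holds
    intensities <= t and class 1 holds intensities > t.
    """
    total = sum(hist)
    weighted_total = sum(i * h for i, h in enumerate(hist))
    w0 = 0          # number of pixels in class 0
    sum0 = 0        # intensity sum of class 0
    best_t, best_var = 0, -1.0
    for t, h in enumerate(hist):
        w0 += h
        sum0 += t * h
        w1 = total - w0
        if w0 == 0:
            continue
        if w1 == 0:
            break
        mu0 = sum0 / w0
        mu1 = (weighted_total - sum0) / w1
        between = w0 * w1 * (mu0 - mu1) ** 2
        if between > best_var:
            best_t, best_var = t, between
    return best_t


def keep_patch(pixels, threshold, min_fraction=0.01):
    """Keep a patch if at least min_fraction of its pixels are foreground.

    Assumption: foreground pixels are those at or below the threshold.
    """
    foreground = sum(1 for p in pixels if p <= threshold)
    return foreground / len(pixels) >= min_fraction
```

Because the histogram can be accumulated chunk by chunk, the threshold is obtained without holding the whole slide in memory, which matches the chunked reading described above.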

Type Parameters Probability
VerticalFlip -
HorizontalFlip -
Grayscale -
Table 9: BYOL augmentations. Square brackets indicate varying probabilities in each of the 2 views.

D.7 Computational efficiency

We measure the training runtime for a full forward and backward pass without data loading for a batch size of 16, and report the average of all but the first and last iterations of one epoch. For time recording, we utilize torch.cuda.Event, which can be used as follows to measure the elapsed time for a single iteration:

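import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
# forward and backward pass
end_event.record()
torch.cuda.synchronize()  # wait until both events have been recorded
time = start_event.elapsed_time(end_event)  # elapsed time in milliseconds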

After running an experiment for one epoch, the following code can be used to measure the peak GPU memory usage:

mem_stats = torch.cuda.memory_stats()
peak_gb = mem_stats["allocated_bytes.all.peak"] / 1024 ** 3

Appendix E IPS with DeepMIL

We want to explore how IPS performs with a different aggregation module than ours. Like our cross-attention module, the gated attention mechanism used in DeepMIL (Ilse et al., 2018) also assigns attention scores to each patch. We thus train IPS jointly with DeepMIL on megapixel MNIST (multi-task) and report results for each of 5 runs in Table 10, for the majority task with , and different values of .

M   Accuracy per run (sorted)
5 9.1 9.6 10.0 10.0 94.3
100 10.0 98.6 98.7 98.7 99.2
200 98.5 98.7 98.9 99.0 99.1
Table 10: Performance of DeepMIL when combined with IPS.

For , one can observe that the model was able to localize the digits in only one out of five runs. However, when increasing the memory buffer to , only one run failed, and for all runs were successful. For IPS + transformer, all runs achieved over 90% accuracy, even for . The average accuracy of IPS + DeepMIL for is , which is on par with IPS + transformer for (). This shows that IPS can be used with different attention-based aggregation modules; however, IPS + transformer is more robust for small values of .
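The per-row statistics discussed above can be recomputed directly from Table 10. The snippet below is a small worked example: the accuracies are copied from the table, the dictionary keys are the values in its first column, and the 90% cutoff for counting a run as successful is an assumption borrowed from the phrase "over 90% accuracy".

```python
# Accuracies copied from Table 10; keys are the values in its first column.
runs = {
    5: [9.1, 9.6, 10.0, 10.0, 94.3],
    100: [10.0, 98.6, 98.7, 98.7, 99.2],
    200: [98.5, 98.7, 98.9, 99.0, 99.1],
}


def summarize(accuracies, success_threshold=90.0):
    """Return the mean accuracy and the number of runs above the threshold."""
    mean = sum(accuracies) / len(accuracies)
    n_success = sum(acc > success_threshold for acc in accuracies)
    return mean, n_success


summary = {key: summarize(accs) for key, accs in runs.items()}
for key, (mean, n_success) in summary.items():
    print(f"{key}: mean accuracy {mean:.2f}, successful runs {n_success}/5")
```

The counts reproduce the pattern described in the text: one successful run in the first row, four in the second, and five in the third.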

Appendix F Additional visualizations

Figure 6: Evolution of attention scores in IPS for . The traffic sign (80 km/h) is localized starting from epoch 20. The entropy of the attention distribution appears to decrease during training.
Figure 7: Patches with the highest attention scores (decreasing from left-to-right then top-to-bottom) out of all images from the traffic sign recognition test set.
Figure 8: Attention scores and crops of the most salient patches for example test images in the traffic sign recognition dataset.
Figure 9: Patch-level classification scores and crops of patches that obtained highest attention scores (decreasing from left-to-right then top-to-bottom per image) in megapixel MNIST ().
Figure 10: Attention maps overlaid on the input image and crops of the most salient patches for two examples of the CAMELYON16 dataset.