TorchIO: a Python library for efficient loading, preprocessing, augmentation and patch-based sampling of medical images in deep learning

03/09/2020 · Fernando Pérez-García et al. · King's College London · UCL

We present TorchIO, an open-source Python library for efficient loading, preprocessing, augmentation and patch-based sampling of medical images for deep learning. It follows the design of PyTorch and relies on standard medical image processing libraries such as SimpleITK or NiBabel to efficiently process large 3D images during the training of convolutional neural networks. We provide multiple generic as well as magnetic-resonance-imaging-specific operations for preprocessing and augmentation of medical images. TorchIO is an open-source project with code, comprehensive examples and extensive documentation shared at https://github.com/fepegar/torchio.


1 Introduction

Due to increases in the computational power of CPUs and GPUs, greater availability of large datasets, and recent developments in learning algorithms, deep learning has become a ubiquitous research tool for solving problems related to image understanding and analysis. Convolutional neural networks (CNNs) have become the state of the art for many medical image tasks, including segmentation (Çiçek et al., 2016), classification (Lu et al., 2018), reconstruction (Chen et al., 2018) and registration (Shan et al., 2018).

Compared to the 2D RGB images typically used in computer vision, processing medical images such as magnetic resonance imaging (MRI), ultrasound (US) or computed tomography (CT) presents unique challenges. These include the lack of large, manually annotated datasets, higher computational costs due to the volumetric nature of the data, and the importance of metadata describing the physical size and position of the image.

Some groups have developed and open-sourced frameworks built on top of TensorFlow (Abadi et al., 2016) for training CNNs with medical images (Pawlowski et al., 2017; Gibson et al., 2018). As the popularity of PyTorch (Paszke et al., 2019) increases among researchers due to its improved usability compared to TensorFlow (He, 2019), tools compatible with this framework become more and more necessary. To reduce duplication of effort among research groups, improve experimental reproducibility and encourage open science practices, we have developed TorchIO: an open-source Python library for efficient loading, preprocessing, augmentation and patch-based sampling of medical images in deep learning.

2 Motivation

2.1 Challenges

In practice, multiple challenges need to be addressed when developing and applying deep learning algorithms to medical images.

2.1.1 Metadata

In computer vision, picture elements, or pixels, are assumed to be square, and their spatial relationships encode both proximity and depth according to the arrangement of objects in the scene and the placement of the camera. In comparison, many medical images are reconstructed such that voxel location encodes a meaningful 3D spatial relationship. In simple terms, in 2D photographs the proximity of two pixels does not necessarily correspond to the proximity of the imaged objects, while in medical imaging this correspondence can be assumed.

In 3D medical images, volume elements, or voxels, are cuboids whose size is defined by the distance between their centroids, i.e. their spacing. Medical images also have an offset, which is the position in mm of the first voxel in the file with respect to some origin in physical space; in MRI, for example, the origin is usually set at the magnet isocenter. Lastly, medical images have an orientation in the physical world. These three attributes are encoded in a 9-degree-of-freedom (DOF) affine matrix. A medical image can then be defined by a 3D tensor containing the voxel data and a 2D matrix representing the spatial information. These images are often stored in the DICOM or NIfTI formats, and are commonly read and processed by medical imaging frameworks such as SimpleITK (Lowekamp et al., 2013) or NiBabel (Brett et al., 2020).
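
For example, this spatial metadata can be inspected with NiBabel as follows (a minimal sketch; the file path is hypothetical):

    import numpy as np
    import nibabel as nib

    image = nib.load('t1.nii.gz')  # hypothetical path
    data = image.get_fdata()       # 3D array of voxel intensities
    affine = image.affine          # matrix mapping voxel indices to physical (mm) coordinates
    spacing = np.sqrt((affine[:3, :3] ** 2).sum(axis=0))  # voxel spacing in mm
    origin = affine[:3, 3]         # position of the first voxel in mm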

2.1.2 Limited training data

Deep learning methods typically require large amounts of annotated data, which are often scarce in clinical scenarios due to concerns over patient privacy, the financial and time burden of collecting data as part of a clinical trial, and the need for annotations from highly trained raters. Data augmentation techniques can be used to artificially increase the size of the training dataset by applying different transforms to each training instance while preserving its relationship with the annotations.

Traditional data augmentation operations applied in computer vision include geometric transforms such as random rotation or zoom, colorspace transforms such as random channel swapping, and kernel filters such as random Gaussian blur. Data augmentation is usually performed on the fly, i.e. every time an image is loaded from disk during training. Multiple libraries with support for computer vision data augmentation have appeared in recent years, such as Albumentations (Buslaev et al., 2020), Augmentor (Bloice et al., 2019), Kornia (Riba et al., 2019) and imgaug (Jung et al., 2020). PyTorch also includes some computer vision transforms, mostly implemented as Pillow wrappers (wiredfool et al., 2016). None of these libraries support reading or transforming 3D images. Furthermore, medical images are almost always grayscale, so colorspace transforms are not applicable. Similarly, cropping and scaling are common in computer vision but, if applied carelessly to medical images, may destroy important spatial relationships.

Some image variation is specific to medical images, or even to particular imaging modalities. For example, ghosting artifacts appear in MRI if the patient moves during acquisition, and metallic implants often produce streak artifacts in CT. Simulating these artifacts can be useful when augmenting medical image data.

2.1.3 Computational costs

The number of pixels in the 2D images used in deep learning is rarely larger than one million. In contrast, 3D medical images often contain tens of millions of voxels. In computer vision applications, the images used to train a network are grouped in batches whose size is often in the order of hundreds (Krizhevsky et al., 2012) or even thousands (Chen et al., 2020) of training instances, depending on the available GPU memory. In medical image applications, batches rarely contain more than one (Çiçek et al., 2016) or two (Milletari et al., 2016) samples, due to their larger size compared to natural images. This reduces the utility of techniques such as batch normalization, which rely on sufficiently large batches to appropriately model dataset variance (Ioffe and Szegedy, 2015). Moreover, larger image sizes and small batches both lead to longer training times, hindering the experimental cycle that researchers need to tune training hyperparameters. In cases where GPU memory is limited and the architecture is large, it is possible that not even a single volume can be processed in one training iteration. To overcome this challenge, it is common to train using image patches randomly extracted from the volumes.

Some groups extract 2D slices from 3D volumes and perform slice-by-slice prediction at test time (Lucena et al., 2019), aggregating the 2D inference results into a 3D volume. This can be seen as a special case of patch-based training in which the size of the patch along one dimension is one. Other groups extract volumetric patches for training, which are often cubes if the voxel spacing is isotropic (Kamnitsas et al., 2016; Li et al., 2017), or cuboids adapted to the anisotropic spacing of the training images (Nikolov et al., 2018).

Recent techniques such as gradient checkpointing (Chen et al., 2016), automatic mixed precision (Micikevicius et al., 2018) or reversible layers (Brügger et al., 2019) can reduce the memory burden when training with large 3D images.

2.1.4 Transfer learning and preprocessing

The literature reports mixed results when transfer learning is applied from natural images to medical images (Yang et al., 2018; Cheplygina, 2019; Raghu et al., 2019).

Images used in computer vision can generally be thought of as projections of a natural scene onto a 2D plane. Pixel values are usually encoded as an RGB triplet of bytes, i.e. they lie in $\{0, \ldots, 255\}^3$. One can often pre-train a network on a large dataset of natural images such as ImageNet (Deng et al., 2009), which contains more than 14 million labeled images, and fine-tune it on a custom dataset, using statistics from the original training data to preprocess each training instance.

Preprocessing of medical images is often performed on a per-image basis, as opposed to using statistics from the whole dataset, and the fact that these images often have a bimodal intensity distribution (a background and a foreground) needs to be taken into account. Medical image values can be encoded in different data types with different ranges, and the meaning of a specific value can vary between scanners or sequences. Therefore, intensity normalization methods for medical images are often more complex than those used for natural images (Nyúl and Udupa, 1999).

2.2 Deep learning frameworks

There are currently two major generic deep learning frameworks: TensorFlow (Abadi et al., 2016) and PyTorch (Paszke et al., 2019), primarily maintained by Google and Facebook, respectively. Although TensorFlow has traditionally been the primary choice for both research and industry, PyTorch has recently seen a strong increase in popularity, especially among the research community (He, 2019).

The main reason researchers prefer PyTorch is that it is Pythonic, i.e. its design, usage and API follow the conventions of plain Python. Moreover, its API for tensor operations closely mirrors that of NumPy multidimensional arrays (van der Walt et al., 2011). In contrast, to use TensorFlow, researchers need to become familiar with new design elements such as sessions, placeholders, feed dictionaries and static graphs. In PyTorch, objects are standard Python classes and variables, and a dynamic graph makes debugging intuitive and familiar. These differences have narrowed with the recent release of TensorFlow 2, whose eager mode makes usage more reminiscent of plain Python.

3 Related work

DLTK (Pawlowski et al., 2017) and NiftyNet (Gibson et al., 2018) are deep learning platforms designed explicitly for medical image processing and built on TensorFlow 1. They provide implementations of popular architectures such as U-Net (Çiçek et al., 2016) and can be used to train 3D CNNs for different tasks. The last substantial commits on the DLTK repository date from June 2018, which suggests that the code is no longer actively maintained. NiftyNet includes some preprocessing and augmentation operations specific to medical images, such as histogram standardization (Nyúl et al., 2000) and random bias field augmentation (Van Leemput et al., 1999; Sudre et al., 2017), both implemented using NumPy. It is designed to be used through a high-level configuration file, which makes it somewhat cumbersome for researchers who want lower-level access to alter or augment its features.

The medicaltorch library (Perone et al., 2018) closely follows the PyTorch design and provides some functionality for preprocessing, augmentation and training of 3D medical images. However, it does not leverage the power of specialized medical image processing libraries such as SimpleITK (Lowekamp et al., 2013). For example, its random 3D rotation is applied to the volume slice by slice along a specified axis using PyTorch; if rotations around additional axes are desired, computation time increases linearly with the number of rotations, and the interpolation applied at each resampling operation further degrades the image. In contrast, multiple rotations, translations, shearings, etc. may be composed into a single affine transform and applied once in 3D, reducing both resampling artifacts and computational cost. Like DLTK, this library has not seen much activity since 2018.

The batchgenerators library (Isensee et al., 2020) includes custom dataset and data loader classes for multithreaded loading of 3D medical images, which were implemented before data loaders were available in PyTorch. In the usage examples from GitHub, preprocessing is applied to the whole dataset before training; spatial data augmentation is then performed at the volume level, a single patch is extracted, and intensity augmentation is performed at the patch level. In this approach, only one patch is extracted from each volume. Transforms in batchgenerators are mostly implemented using NumPy (van der Walt et al., 2011) and SciPy (Virtanen et al., 2020).

4 Methods

We developed TorchIO, a Python library that focuses on medical image processing for deep learning. Following the PyTorch philosophy (Paszke et al., 2019), we designed TorchIO with an emphasis on simplicity and usability, reusing PyTorch classes and infrastructure wherever appropriate. This has the added advantages of a gentle learning curve for users and, we hope, a reduced barrier to adapting the framework to new use cases.

Volumes can be loaded and processed in parallel using torch.utils.data.DataLoader, and image datasets inherit from torch.utils.data.Dataset (torch is the name of the PyTorch Python package). Patch samplers, which yield image patches, inherit from torch.utils.data.IterableDataset. The transforms API closely resembles that of the torchvision.transforms module; for example, TorchIO transforms are composed using torchvision.transforms.Compose. Stochasticity in random transforms is always driven by PyTorch functions to ensure consistency under multiprocessing.
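
For illustration, composing TorchIO transforms follows the familiar torchvision pattern (a minimal sketch assuming the API at the time of writing; transform names and defaults may differ in later releases):

    import torchio
    from torchvision.transforms import Compose

    # TorchIO transforms are plain callables, so they can be composed
    # with torchvision.transforms.Compose as described above.
    transform = Compose([
        torchio.transforms.RandomAffine(),  # spatial augmentation
        torchio.transforms.RandomNoise(),   # intensity augmentation
    ])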

TorchIO uses computer vision augmentations such as random affine transformation or random blur where appropriate, but we implemented them using medical imaging libraries (Lowekamp et al., 2013; Brett et al., 2020) that take into account the specific nature of medical images (see Section 2.1.1).

Preprocessing and sampling features are similar to those available in NiftyNet, which ensures a smooth transition from the NiftyNet input/output system to the TorchIO environment.

In contrast with DLTK or NiftyNet, we do not implement architectures, loss functions or training applications, in order to limit the scope of the library for the sake of modularity.

5 Results

TorchIO code is available on GitHub (https://github.com/fepegar/torchio). Detailed API documentation is hosted on Read the Docs (https://torchio.readthedocs.io/), and Jupyter notebook tutorials are hosted on Google Colaboratory, where users can run examples online (https://colab.research.google.com/drive/112NTL8uJXzcMw4PQbUvMQN-WHlVwQS3i). The software can be installed with a single command on Windows, macOS or Linux using the pip package manager: pip install torchio.

5.1 Loading

The torchio.ImagesDataset class inherits from torch.utils.data.Dataset and can be used to load medical images in popular formats such as NIfTI, DICOM, MINC or NRRD using medical imaging libraries such as NiBabel or SimpleITK. torchio.ImagesDataset can be used as is, or customized by inheriting from the class. As an example, the IXI dataset (https://brain-development.org/ixi-dataset/) can be downloaded using torchio.datasets.IXI, which is a subclass of torchio.ImagesDataset. torchio.ImagesDataset can be used directly with PyTorch data loaders to leverage their multiprocessing capabilities. If a transform, or a composition of transforms, is passed to this class, it is applied to the images after loading. Figure 1 illustrates a typical TorchIO pipeline.
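
A minimal loading pipeline might look as follows (a sketch based on the API described in this paper; file paths are hypothetical and class names have changed in later releases):

    import torchio
    from torch.utils.data import DataLoader

    subject = torchio.Subject(
        torchio.Image('mri', 't1.nii.gz', torchio.INTENSITY),
        torchio.Image('brain', 'brain_seg.nii.gz', torchio.LABEL),
    )
    dataset = torchio.ImagesDataset([subject])
    loader = DataLoader(dataset, batch_size=1, num_workers=4)
    for batch in loader:
        inputs = batch['mri']['data']  # assumed sample layout: one dictionary per image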

(a) Training with whole volumes
(b) Training with patches
Figure 1: Diagram of data pipelines for training with whole volumes (top) and patches (bottom). Boxes with a red border are PyTorch classes or inherit from PyTorch classes.

5.2 Preprocessing

Preprocessing transforms include spatial operations such as resampling (e.g. to make voxel spacing isotropic across training samples) and reorientation (e.g. so that all training samples share the same orientation), implemented using NiBabel. Spatial preprocessing is important because CNNs do not generally take into account the meta-information associated with medical images (see Section 2.1.1).

TorchIO provides intensity normalization techniques, including min-max scaling and standardization, computed using pure PyTorch. In this context, standardization refers to rescaling voxel intensity values to have zero mean and unit variance. A binary image, such as a mask representing the foreground or structures of interest, can be used to define the set of voxels taken into account when computing the statistics for intensity normalization.

We also provide a method for MRI histogram standardization (Nyúl et al., 2000; Gibson et al., 2018), computed using NumPy, which may be used to overcome differences in intensity distributions between images acquired with different scanners or sequences.
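
A preprocessing list along these lines might be defined as follows (a hedged sketch: transform names are those used in this paper, argument values are arbitrary examples and signatures may differ between versions):

    import torchio

    preprocess = [
        torchio.transforms.ToCanonical(),     # reorient all images the same way
        torchio.transforms.Resample(1),       # 1 mm isotropic spacing (assumed signature)
        torchio.transforms.Rescale((0, 1)),   # min-max scaling to [0, 1]
        torchio.transforms.ZNormalization(),  # zero mean, unit variance
    ]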

5.3 Augmentation

TorchIO includes spatial augmentation transforms such as random flipping, implemented using PyTorch, and random affine and random elastic deformation transforms, implemented using SimpleITK.

Intensity augmentation transforms include random Gaussian blur using a SimpleITK filter and addition of random Gaussian noise using pure PyTorch.

TorchIO provides several MRI-specific augmentation transforms related to k-space, which are described below. An MR image is usually reconstructed as the magnitude of the inverse Fourier transform of the k-space signal, which is populated with the signals generated by the sample after a radio-frequency pulse and spatially encoded using magnetic field gradients produced by coils in the scanner.

Artifacts are simulated by perturbing the Fourier space and reconstructing the corresponding corrupted image; the forward and inverse Fourier transforms are computed using the FFT algorithm (Cooley and Tukey, 1965) implemented in NumPy. Figure 2 shows examples of the augmentation transforms implemented in TorchIO.
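
The MRI-specific transforms described below can be combined like any others (a hedged sketch with default parameters assumed; names follow the paper-era API):

    import torchio
    from torchvision.transforms import Compose

    augment = Compose([
        torchio.transforms.RandomSpike(),      # herringbone/corduroy stripe artifact
        torchio.transforms.RandomMotion(),     # simulated subject motion during acquisition
        torchio.transforms.RandomGhosting(),   # ghosting along the phase-encoding direction
        torchio.transforms.RandomBiasField(),  # low-frequency intensity inhomogeneity
    ])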

Random k-space spike artifact

Gradients applied at a very high duty cycle may produce bad data points, or spikes of noise, in k-space (Zhuo and Gullapalli, 2006). These points generate a spike artifact, also known as a herringbone, crisscross or corduroy artifact, which manifests as uniformly separated stripes in image space, as shown in Fig. 2(i). This type of data augmentation has recently been used to estimate uncertainty through a heteroscedastic noise model (Shaw et al., 2020).

Random k-space motion artifact

The k-space is often populated line by line, and the sample in the scanner is assumed to remain static during acquisition. If a patient moves during the MRI acquisition, motion artifacts appear in the reconstructed image. We implemented a method to simulate random motion artifacts (see Fig. 2(h)) that has been used for data augmentation to model uncertainty and improve segmentation (Shaw et al., 2019).

Random k-space ghosting artifact

Organ motion such as respiration or cardiac pulsation may generate ghosting artifacts along the phase-encoding direction (Zhuo and Gullapalli, 2006) (see Fig. 2(j)).

Random bias field artifact

Inhomogeneity of the static magnetic field in the MRI scanner produces intensity artifacts of very low spatial frequency across the entire image. These artifacts can be simulated using polynomial basis functions (Van Leemput et al., 1999; Sudre et al., 2017; Gibson et al., 2018), as shown in Fig. 2(g).

Although the domain-specific data augmentation transforms currently available in TorchIO are mostly related to MRI, we encourage users to contribute physics-based data augmentation techniques for US or CT (Omigbodun et al., 2019).

(a) Original image and segmentation
(b) Random blur
(c) Random flip
(d) Random noise
(e) Random affine transformation
(f) Random elastic transformation
(g) Random bias field artifact
(h) Random motion artifact
(i) Random spike artifact
(j) Random ghosting artifact
Figure 2: A selection of the data augmentation techniques currently available in TorchIO. Each example is presented as a pair of images composed of the transformed image and the corresponding transformed annotation. Note that all images are 2D coronal slices of the transformed 3D volumes. The MRI corresponds to the MNI Colin 27 average brain (Holmes et al., 1998). Annotations were generated using an automated brain parcellation algorithm (Cardoso et al., 2015).


5.4 Patch-based sampling

Memory limitations often require training and inference to be performed using image patches instead of whole volumes, as explained in Section 2.1.3. A training iteration performed on a GPU is usually faster than loading, preprocessing, augmenting and cropping a volume on a CPU. Therefore, it is beneficial to prepare (i.e. load, preprocess and augment) the volumes using CPU multiprocessing and then sample multiple patches from each volume. The sampled patches are added to a buffer or queue until the next training iteration, at which point they are loaded onto the GPU. Parallel processing can be performed by passing a torchio.ImagesDataset to a PyTorch data loader (torch.utils.data.DataLoader), as explained in Section 5.1.

Following the NiftyNet design, we implemented a queueing system in which samplers behave as generators that yield patches from random locations in the volumes of the subjects dataset. At the beginning of the training pipeline, the subjects dataset may be shuffled. A data loader queries the dataset, which loads and processes volumes in parallel, and a sampler fills the queue with patches extracted from these volumes; the queue itself may also be shuffled if specified. In the background, the data loader continues querying the dataset, and the queue is refilled with new patches whenever it becomes empty. Once patches have been extracted from all subjects, the subjects dataset is reshuffled. See Fig. 1(b) for a visual representation of patch-based training in TorchIO.

We implemented an aggregator for patch-wise inference on large volumes, following a similar design in NiftyNet. The aggregator receives (potentially overlapping) patches that have been sampled by a grid sampler and passed through a trained network, and assembles them into an output array with the same size as the original volume from which the patches were extracted.

5.5 Command-line interface

The provided CLI tool, torchio-transform, allows users to apply a transform to an image file without writing any Python. The tool can also be used in shell scripts to preprocess and augment datasets offline, in cases where large storage is available and faster on-the-fly loading is needed.
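
For example, a single transform might be applied from the shell along these lines (a hypothetical invocation; the exact argument order and options may differ between versions):

    torchio-transform t1.nii.gz RandomMotion t1_motion.nii.gz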

6 Examples

The documentation includes a comprehensive example of training a custom implementation of a 3D U-Net (Pérez-García, 2020) to perform brain segmentation from MRI, using TorchIO for image processing.

To create the dataset, we 1) segmented the brain in T1-weighted MR images from the IXI dataset using GIF (Cardoso et al., 2015), 2) registered the images to an MNI template using NiftyReg (Modat et al., 2014), and 3) resampled the images to a smaller size after applying a Gaussian filter (SimpleITK) to reduce aliasing. We resized the images so that the example can be run quickly and with few resources.

For each subject, an instance of torchio.Subject is created with the corresponding images. The images for each subject are two specific instances of torchio.Image: an intensity image (torchio.INTENSITY) representing the MRI, with a single channel, and a binary image (torchio.LABEL) representing the brain segmentation, with two classes (background and foreground). A Python list of subjects is then created.
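
The construction of the subjects list might look as follows (a sketch following the paper-era API; paths are hypothetical):

    import torchio

    image_paths = ['ixi_001_t1.nii.gz']     # hypothetical file lists
    label_paths = ['ixi_001_brain.nii.gz']
    subjects = []
    for image_path, label_path in zip(image_paths, label_paths):
        subject = torchio.Subject(
            torchio.Image('mri', image_path, torchio.INTENSITY),
            torchio.Image('brain', label_path, torchio.LABEL),
        )
        subjects.append(subject)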

The preprocessing and augmentation pipelines are defined as Python lists containing instances of torchio.transforms.Transform. In this example, the operations for the training set are Rescale, RandomMotion, HistogramStandardization, RandomBiasField, ZNormalization, RandomNoise, ToCanonical, Resample, CenterCropOrPad, RandomFlip, RandomAffine and RandomElasticDeformation, all imported from torchio.transforms. For the validation set, the following subset of preprocessing transforms is used: HistogramStandardization, ZNormalization, ToCanonical, Resample and CenterCropOrPad. TorchIO determines the image types to which each transform should be applied: intensity transforms are applied only to images of type torchio.INTENSITY, while spatial transforms are applied to all images.

By default, intensity transforms are therefore applied only to the MRI, whereas spatial transforms are applied to both the MRI and the segmentation. Each list is composed into a single transform using torchvision.transforms.Compose.

The list of subjects is split into a training list and a validation list, and two instances of torchio.ImagesDataset are created using the corresponding composed transforms.
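
Reusing the subjects list from the previous sketch, the composition and split might look as follows (transform parameters are omitted and the 80/20 split is an arbitrary example):

    import torchio
    from torchvision.transforms import Compose

    training_transform = Compose([
        torchio.transforms.ZNormalization(),
        torchio.transforms.RandomMotion(),
        torchio.transforms.RandomFlip(),
        torchio.transforms.RandomAffine(),
    ])
    validation_transform = Compose([
        torchio.transforms.ZNormalization(),
    ])
    training_set = torchio.ImagesDataset(subjects[:80], transform=training_transform)
    validation_set = torchio.ImagesDataset(subjects[80:], transform=validation_transform)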

The following pipeline descriptions apply to both training and validation.

6.1 Case 1: training with whole volumes

The rest of the pipeline is simple when training with whole volumes, as shown in Fig. 1(a). First, we instantiate a torch.utils.data.DataLoader from the training dataset, specifying the batch size and the number of CPU cores to be used to prepare the volumes.

Then, the training loop starts. The loader spawns processes that query the dataset for samples, where a sample is a dictionary containing the images corresponding to one subject. The dataset loads the images of the subject and applies the transform to them; transforms such as CenterCropOrPad and Resample may modify the number of voxels. At each iteration, the loader collates a batch of volumes into a Python dictionary. The tensors corresponding to the voxel intensity values are extracted from the batch dictionary, and the training iteration is performed. Figure 3 shows an example of a batch of volumes.
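
A whole-volume training loop along these lines might look as follows (a hedged sketch: model, optimizer and criterion are assumed to be defined elsewhere, and the batch layout follows the description above):

    from torch.utils.data import DataLoader

    loader = DataLoader(training_set, batch_size=2, num_workers=4)
    for batch in loader:
        inputs = batch['mri']['data']    # shape: (batch, channels, x, y, z)
        targets = batch['brain']['data']
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()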

Figure 3: Example of a training batch when training with whole volumes. Top: input MR images after preprocessing and augmentation; bottom: corresponding labels. Note that these images are 3D, but only one axial slice of each image is shown for visualization purposes.

Inference is trivial as batches can be generated by a data loader in the same manner.

Note that training with whole volumes is rare unless GPUs with large memory are available.

6.2 Case 2: training with patches

To train with patches, the torchio.ImagesDataset is passed to a torchio.Queue, which also inherits from torch.utils.data.Dataset (see Section 5.4). A torch.utils.data.DataLoader is connected to the queue to generate the batches. This loader does not use multiprocessing; patches are simply popped from the queue.

Internally, the queue uses a torch.utils.data.DataLoader connected to the torchio.ImagesDataset to load volumes with multiprocessing. Each volume is passed to a torchio.data.ImageSampler, which inherits from torch.utils.data.IterableDataset. The sampler extracts a given number of fixed-size patches from each volume, and the patches are added to the queue until it reaches its maximum length; the example sets specific values for the patch size, the number of samples per volume and the maximum queue length.
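
The patch-based pipeline might be set up as follows (a hedged sketch: the Queue signature follows the paper-era API, which has since changed, and all numeric values are arbitrary examples):

    import torchio
    from torchio.data import ImageSampler
    from torch.utils.data import DataLoader

    patches_queue = torchio.Queue(
        subjects_dataset=training_set,
        max_length=300,          # maximum number of patches stored in the queue
        samples_per_volume=10,   # patches extracted from each loaded volume
        patch_size=64,           # cubic patches of side 64
        sampler_class=ImageSampler,
        num_workers=4,           # processes used to load and transform volumes
    )
    patches_loader = DataLoader(patches_queue, batch_size=16)  # no multiprocessing here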

At each iteration, the loader collates a batch of patches into a Python dictionary. The tensors are extracted from the batch dictionary and the training iteration is executed. Figures 1(b) and 4 show a diagram of the training pipeline and an example of a batch of patches, respectively.

Figure 4: Example of a training batch when training with image patches. Top: input patches of MR images after preprocessing and augmentation; bottom: corresponding labels. Note that these are 3D patches, but only one axial slice is shown for visualization purposes.

To perform a dense prediction across a volume, the volume is loaded by the dataset instance and passed to a torchio.inference.GridSampler. The sampler creates a uniform grid of patches that may overlap to reduce the border artifacts introduced by padded convolutions. The sampler is passed to a torch.utils.data.DataLoader, which extracts batches of patches that are passed through the network for inference. The predicted batches and their original locations in the volume are passed to a torchio.inference.GridAggregator, which builds the output volume from the batch data and locations.
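
A dense inference loop might then look as follows (a hedged sketch: signatures and dictionary keys follow the paper-era API and have since changed, and volume and model are assumed to be defined elsewhere):

    import torch
    from torch.utils.data import DataLoader
    from torchio.inference import GridSampler, GridAggregator

    patch_size, patch_overlap = 64, 4
    sampler = GridSampler(volume, patch_size, patch_overlap)
    aggregator = GridAggregator(volume, patch_overlap)
    loader = DataLoader(sampler, batch_size=8)
    model.eval()
    with torch.no_grad():
        for batch in loader:
            inputs = batch['image']        # assumed batch keys
            locations = batch['location']
            aggregator.add_batch(model(inputs), locations)
    prediction = aggregator.get_output_tensor()  # retrieval method name assumed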

7 Conclusion

We have presented TorchIO, a new library to efficiently handle medical imaging data during the training of CNNs. It is designed in the style of the deep learning framework PyTorch and provides medical-imaging-specific features such as image reorientation and simulation of MRI artifacts for data augmentation.

In the future, we will work on extending the preprocessing and augmentation transforms to other medical imaging modalities such as CT or US. The source code, as well as examples and documentation, is publicly available on GitHub. We welcome feedback, feature requests and contributions to the library, either by creating issues on the GitHub repository or by emailing the authors.

Acknowledgments

We thank the NiftyNet team for their support. We would also like to thank Reuben Dorent and Romain Valabregue for their valuable insight.

This work is supported by the EPSRC-funded UCL Centre for Doctoral Training in Medical Imaging (EP/L016478/1). This publication represents in part independent research commissioned by the Wellcome Trust Health Innovation Challenge Fund (WT106882). The views expressed in this publication are those of the authors and not necessarily those of the Wellcome Trust.

References