Automated medical image segmentation is essential for many clinical applications, such as discovering new biomarkers and monitoring disease progression. Recent developments in deep neural network architectures have substantially improved image segmentation performance. Manually designed networks, such as U-Net [ronneberger2015unet], have been widely used across different tasks. However, medical image segmentation tasks can be extremely diverse. Image characteristics and appearances can differ completely between modalities, and the presentation of diseases can vary considerably. As a result, directly applying even a successful network such as U-Net [ronneberger2015unet] to a new task is unlikely to be optimal.
Neural architecture search (NAS) algorithms [zoph2016nas] have been proposed to automatically discover the optimal architecture within a search space. The NAS search space for segmentation usually contains two levels: the network topology level and the cell level. The network topology controls the connections among cells and determines how feature maps flow across different spatial scales. The cell level determines the specific operations applied to the feature maps. A more flexible search space is more likely to contain better-performing architectures.
Among the search methods used to find the optimal architecture in the search space, evolutionary and reinforcement learning-based algorithms [zoph2016nas, real2019regularized] are usually time-consuming. C2FNAS [yu2020c2fnas] takes 333 GPU days to search for one 3D segmentation network using an evolutionary method, which is too computationally expensive for common use cases. Differentiable architecture search [liu2018darts] is much more efficient, and Auto-DeepLab [liu2019auto] is the first work to apply differentiable search to segmentation network topology. However, Auto-DeepLab's differentiable formulation restricts the searched network topology. As shown in Fig. 1, this formulation assumes that only one input edge is kept for each node. Its final searched model therefore has only a single path from input to output, which limits its complexity. Our first goal is to propose a new differentiable scheme that supports more complex topologies in order to find novel architectures with better performance.
Meanwhile, differentiable architecture search suffers from the “discretization gap” problem [chen2019progressive, tian2020discretization]: discretizing the searched optimal continuous model may produce a sub-optimal discrete final architecture and cause a large performance gap. As shown in Fig. 1, the gap has two sources: 1) the searched continuous model is not binary, so operations/edges with small but non-zero probabilities are discarded during discretization; 2) the discretization algorithm has topology constraints (e.g. single-path), so edges that would produce an infeasible topology are not allowed even if they have large probabilities in the continuous model. Alleviating the first problem by encouraging a binarized model during search has been explored [chu2019fair, tian2020discretization, nayman2019xnas]. Alleviating the second problem, however, requires the search to be aware of the discretization algorithm and its topology constraints. In this paper, we propose a topology loss for the search stage and a topology-guaranteed discretization algorithm to mitigate this problem.
In medical image analysis, especially in longitudinal analysis tasks, high input image resolution and large patch sizes are usually desired to capture minuscule longitudinal changes. Large GPU memory usage is therefore a major challenge when training on large, high-resolution 3D images. Most NAS algorithms with computational constraints focus on latency [cai2018proxylessnas, chen2019fasterseg, li2019partial, Shaw_2019_ICCV] for real-time applications. In 3D medical image analysis, however, real-time inference is often a minor concern compared to the problems caused by huge GPU memory usage. In this paper, we add GPU memory constraints to the search stage to limit the GPU memory needed for retraining the searched model.
We validate our method on the Medical Segmentation Decathlon (MSD) dataset [simpson2019large], which contains 10 representative 3D medical segmentation tasks covering different anatomies and imaging modalities. We achieve state-of-the-art results while the search takes only 5.8 GPU days (the recent C2FNAS [yu2020c2fnas] takes 333 GPU days on the same dataset). Our contributions can be summarized as follows:
We propose a novel Differentiable Network Topology Search scheme, DiNTS, which supports more flexible topologies and a joint two-level search.
We propose a topology-guaranteed discretization algorithm and a discretization-aware topology loss for the search stage to minimize the discretization gap.
We develop a memory-aware search method that can search 3D networks under different GPU memory requirements.
We achieve new state-of-the-art results and top ranking on the MSD challenge leaderboard while taking only 1.7% of the search time of the NAS-based C2FNAS [yu2020c2fnas].
2 Related Work
2.1 Medical Image Segmentation
Medical image segmentation faces some unique challenges, such as the scarcity of manual labels and the large memory usage of processing high-resolution 3D images. Compared to networks designed for natural images, such as DeepLab [chen2018encoder] and PSPNet [zhao2017psp], 2D/3D UNet [ronneberger2015unet, cciccek20163d] is better at preserving fine details and is more memory-friendly when applied to 3D images. VNet [milletari2016v] improves 3D UNet with residual blocks. UNet++ [zhou2019unet++] uses dense blocks [huang2017densely] to redesign the skip connections. H-DenseUNet [li2018h] combines 2D and 3D UNets to save memory. nnUNet [isensee2019nnunet] ensembles 2D, 3D, and cascaded 3D UNets and achieves state-of-the-art results on a variety of medical image segmentation benchmarks.
2.2 Neural Architecture Search
Neural architecture search (NAS) aims to design networks automatically. Work on NAS can be categorized along three dimensions: search space, search method, and performance estimation [elsken2018neural]. The search space defines which architectures can be searched and can be further divided into the network topology level and the cell level. For image classification, [liu2018darts, zoph2018learning, liu2018progressive, real2019regularized, pham2018efficient, gu2020dots] search for optimal cells and apply a pre-defined network topology, while [fang2020densely, xie2019exploring] perform search on the network topology. In segmentation, Auto-DeepLab [liu2019auto] uses a highly flexible search space, while FasterSeg [chen2019fasterseg] proposes a low-latency two-level search space; both perform a joint two-level search. In medical image segmentation, NAS-UNet [weng2019nasunet], V-NAS [zhu2019vnas] and Kim et al. [kim2019scalable] search for cells and apply them to a U-Net-like topology. C2FNAS [yu2020c2fnas] searches the 3D network topology in a U-shaped space and then searches the operation of each cell. MS-NAS [yan2020ms] applies PC-DARTS [xu2019pc] and Auto-DeepLab's formulation to 2D medical images.
Search methods and performance estimation focus on finding the optimal architecture in the search space. Evolutionary and reinforcement learning methods have been used [zoph2016nas, real2019regularized], but they require extremely long search times. Differentiable methods [liu2018darts, liu2019auto] relax the discrete architecture into a continuous representation that allows direct gradient-based search. This is orders of magnitude faster and has been adopted in various NAS works [liu2018darts, liu2019auto, xu2019pc, zhu2019vnas, yan2020ms]. However, converting the continuous representation back into a discrete architecture causes the “discretization gap”. To address this problem, FairDARTS [chu2019fair] and Tian et al. [tian2020discretization] proposed a zero-one loss and an entropy loss, respectively, to push the continuous representation toward binary values. Other works [nayman2019xnas, hu2020dsnas] use temperature annealing to achieve the same goal. Another problem of differentiable methods is their large memory usage during the search stage. PC-DARTS [xu2019pc] uses partial channel connections to reduce memory, while Auto-DeepLab [liu2019auto] reduces the number of filters during search. It is common practice to retrain the searched model with more filters, a larger batch size, or a larger patch size to gain better performance. For 3D medical image segmentation, however, such changes in the retraining scheme (e.g. transferring to a new task that requires a larger input size) can still cause out-of-memory problems. Most NAS work has focused on searching architectures under latency constraints [cai2018proxylessnas, chen2019fasterseg, li2019partial, Shaw_2019_ICCV], while only a few works have considered memory as a constraint. Mem-NAS [liu2020memnas] uses a growing-and-trimming framework to constrain inference GPU memory, but it cannot be integrated into a differentiable scheme.
3.1 Network Topology Search Space
Inspired by Auto-DeepLab [liu2019auto] and [li2020learning], we propose a search space with fully connected edges between adjacent resolutions (2× higher, 2× lower, or the same) in adjacent layers, as shown in Fig. 2. A stack of multi-resolution images is generated by down-sampling the input image by 1/2, 1/4 and 1/8 along each axis. Together with the original image, we use four 3D convolutions with stride 2 to generate the multi-resolution features (layer 0 in Fig. 2) that feed the search space. The search space has $L$ layers, and each layer consists of feature nodes (green nodes) from $D=4$ resolutions and $E=3D-2$ candidate input edges (dashed green edges). Each edge contains a cell operation, and an upsample/downsample operation (factor 2) is applied before the cell if the edge is an upsample/downsample edge. A feature node is the sum of the output features of its input edges. Compared to Auto-DeepLab [liu2019auto], our search space supports searching for input image scales and complex multi-path topologies, as shown in Fig. 3. Regarding multi-path topologies, MS-NAS [yan2020ms] discretizes and combines multiple single-path models searched with Auto-DeepLab's framework, but its search is still unaware of the discretization and therefore suffers from the gap. [li2020learning] also supports multi-path topologies, but it addresses feature routing in a “fully connected” network rather than being a NAS method.
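As a sanity check on $E = 3D - 2$, the following Python sketch enumerates the candidate input edges of one layer. It assumes resolution index 0 is the finest scale and that an edge exists only between resolutions that differ by at most one level; `candidate_edges` and its `(src, dst, kind)` tuples are illustrative names, not part of the paper.

```python
def candidate_edges(num_depths):
    """Return (src, dst, kind) triples for the edges entering one layer.

    Resolution index 0 is the finest scale. An edge links resolution `src`
    in layer l-1 to resolution `dst` in layer l when they differ by at most one.
    """
    edges = []
    for dst in range(num_depths):
        for src in (dst - 1, dst, dst + 1):
            if not 0 <= src < num_depths:
                continue
            if src < dst:
                kind = "down"   # finer -> coarser: 2x downsample before the cell
            elif src > dst:
                kind = "up"     # coarser -> finer: 2x upsample before the cell
            else:
                kind = "same"   # same resolution: cell operation only
            edges.append((src, dst, kind))
    return edges


D = 4
edges = candidate_edges(D)
print(len(edges), 3 * D - 2)  # 10 10
for edge in edges:
    print(edge)
```

For $D=4$ this gives 4 same-resolution edges, 3 upsample edges, and 3 downsample edges.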
3.2 Cell Level Search Space
We define a cell search space as a set of basic operations whose input and output feature maps have the same spatial resolution. The cell search space in DARTS [liu2018darts] and Auto-DeepLab [liu2019auto] contains multiple blocks, and the connections among those blocks can also be searched. However, the searched cell is then repeated for all cells at the network topology level. Similar to C2FNAS [yu2020c2fnas], our algorithm searches the operation of each cell independently, selecting one operation from the following: (1) skip connection; (2) 3×3×3 3D convolution; (3) P3D 3×3×1: a 3×3×1 convolution followed by a 1×1×3 convolution; (4) P3D 3×1×3: a 3×1×3 convolution followed by a 1×3×1 convolution; (5) P3D 1×3×3: a 1×3×3 convolution followed by a 3×1×1 convolution. P3D denotes pseudo-3D [qiu2017p3d] and has been used in V-NAS [zhu2019vnas]. A cell also includes ReLU activation and Instance Normalization [ulyanov2016instance], applied before and after the operation, respectively (except for the skip connection). The cell operations do not include multi-scale feature aggregation operations such as atrous convolution and pooling; changes in spatial resolution are performed by the upsample/downsample operations in the edges searched at the topology level.
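The appeal of the P3D candidates can be seen from a weight count. This sketch assumes every convolution maps C channels to C channels and ignores biases, ReLU, and Instance Normalization; the operation names are hypothetical labels for the five candidates listed above.

```python
def conv_weights(kernel, c_in, c_out):
    """Number of weights in a 3D convolution (bias ignored)."""
    kx, ky, kz = kernel
    return kx * ky * kz * c_in * c_out

# Kernel sequences of the five candidate operations (hypothetical names).
CANDIDATE_OPS = {
    "skip": [],
    "conv_3x3x3": [(3, 3, 3)],
    "p3d_3x3x1": [(3, 3, 1), (1, 1, 3)],
    "p3d_3x1x3": [(3, 1, 3), (1, 3, 1)],
    "p3d_1x3x3": [(1, 3, 3), (3, 1, 1)],
}

def op_weights(name, channels):
    """Total weights of an operation whose convolutions keep `channels` fixed."""
    return sum(conv_weights(k, channels, channels) for k in CANDIDATE_OPS[name])

for name in CANDIDATE_OPS:
    print(name, op_weights(name, 32))
# skip 0, conv_3x3x3 27648, and 12288 for each P3D variant
```

Under these assumptions, each P3D candidate uses 12/27 (about 44%) of the weights of a full 3×3×3 convolution.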
3.3 Continuous Relaxation and Discretization
We briefly recap the relaxation in DARTS [liu2018darts]. NAS tries to select one of $N$ candidate operations $O_1, O_2, \dots, O_N$ for each computational node. Each operation $O_i$ is paired with a trainable parameter $\alpha_i$, where $\sum_{i=1}^{N} \alpha_i = 1$ and $\alpha_i \ge 0$, and the output feature is $x_{out} = \sum_{i=1}^{N} \alpha_i O_i(x_{in})$, where $x_{in}$ is the input feature. The discrete operation choice is thus relaxed into the continuous representation $\alpha$, which can be optimized with gradient descent. After optimization, an operation $O_i$ with a larger $\alpha_i$ is considered more important and is selected. However, an operation with a small $\alpha_j$ (as long as $\alpha_j \neq 0$) can still make a significant difference to $x_{out}$ and the following layers. Therefore, directly discarding operations with non-zero weights leads to the discretization gap.
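A scalar toy example makes this gap tangible. In the sketch below, three hypothetical operations act on a single number; the softmax-weighted mixture and the argmax-discretized choice already give different outputs. The operations and logits are illustrative assumptions, not values from the paper.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical candidate operations on a scalar feature.
OPS = [
    lambda x: x,          # stand-in for the skip connection
    lambda x: 2.0 * x,    # stand-in for a convolution
    lambda x: -3.0 * x,   # stand-in for another convolution
]

def mixed_op(x, alpha):
    """Continuous relaxation: x_out = sum_i alpha_i * O_i(x)."""
    return sum(a * op(x) for a, op in zip(alpha, OPS))

def discrete_op(x, alpha):
    """Discretization: keep only the operation with the largest alpha_i."""
    best = max(range(len(alpha)), key=lambda i: alpha[i])
    return OPS[best](x)

alpha = softmax([2.0, 1.0, 0.5])   # alpha_i >= 0 and sums to one
x_in = 1.0
print(mixed_op(x_in, alpha), discrete_op(x_in, alpha))  # roughly 0.67 vs 1.0
```

The discarded operations carry weights of only about 0.23 and 0.14, yet removing them shifts the output by roughly a third.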
Auto-DeepLab [liu2019auto] extends this idea to edge selection at the network topology level. As illustrated in Fig. 1, every edge is paired with a trainable parameter $\beta$ ($0 \le \beta \le 1$), and the parameters of edges pointing to the same feature node sum to one. This rests on the assumption of “one input edge for each node”, since the input edges to a node compete with each other. After discretization, a single path is kept, while the other edges, even those with a large $\beta$, are discarded. Consequently, the feature flow in the searched continuous model differs significantly from the feature flow in the final discrete model. The single-path limitation stems from this assumption in the topology-level relaxation, while the gap stems from the search being unaware of the discretization: edges with large probabilities can be discarded because of topology constraints.
3.3.2 Sequential Model with Super Feature Node
We propose a network topology relaxation framework that converts the multi-scale search space into a sequential space using “super feature nodes”. For a search space with $L$ layers and $D$ resolution levels, the $D$ feature nodes in the same layer are combined into a super feature node $s_l$, and features flow sequentially through these super nodes, as shown in Fig. 4. There are $E=3D-2$ candidate input edges to each super node, and the topology search selects an optimal set of input edges for each super node. We define a connection pattern as a set of selected edges; there are $M = 2^E - 1$ feasible candidate connection patterns (every non-empty subset of the $E$ edges). The $j$-th connection pattern $\mathbf{p}_j$ is an indicator vector of length $E$, where $\mathbf{p}_j(e) = 1$ if the $e$-th edge is selected in the $j$-th pattern.
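The pattern count can be checked by enumerating the indicator vectors directly. This sketch orders the $E$ edges arbitrarily and represents each pattern as a tuple of 0/1 values; `connection_patterns` is a hypothetical helper.

```python
from itertools import product

def connection_patterns(num_edges):
    """All non-empty indicator vectors of length E, i.e. M = 2**E - 1 patterns."""
    return [p for p in product((0, 1), repeat=num_edges) if any(p)]

D = 4
E = 3 * D - 2
patterns = connection_patterns(E)
print(E, len(patterns))  # 10 1023
print(patterns[0])       # (0, 0, 0, 0, 0, 0, 0, 0, 0, 1): only the last edge selected
```

Enumerating all patterns is only practical for small $E$; it is shown here to make $M = 2^E - 1$ concrete.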
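We define the input connection operation to $s_l$ with connection pattern $\mathbf{p}_j$ as $cp_l^j$. $\mathbf{p}_j$ defines the topology of $cp_l^j$, while $cp_l^j$ also includes the cell operations on the edges selected in $\mathbf{p}_j$. Writing $cp_l^i$ and $cp_{l+1}^j$ means that the input and output connection patterns of $s_l$ are $\mathbf{p}_i$ and $\mathbf{p}_j$, respectively. Under this formulation, the topology search becomes selecting an input connection pattern for each super node, and the competition is among connection patterns rather than among edges. We associate a variable $\eta_l^j$ with the connection operation $cp_l^j$ for every super node $s_l$ and every pattern $\mathbf{p}_j$. Denoting the input features at layer 0 as $s_0$, we obtain a sequential feature flow equation:

$$s_l = \sum_{j=1}^{M} \eta_l^j \, cp_l^j(s_{l-1}), \quad \eta_l^j \ge 0, \quad \sum_{j=1}^{M} \eta_l^j = 1, \quad l = 1, \dots, L. \tag{1}$$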
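However, $M$ grows exponentially with $E$. To reduce the number of architecture parameters, we parameterize $\eta$ with a set of edge probability parameters $\beta \in \mathbb{R}^{L \times E}$ in Eq. 2: each edge $e$ of $s_l$ is selected with probability $\sigma(\beta_l^e)$, where $\sigma$ is the sigmoid function, and the pattern probabilities are renormalized over the $M$ non-empty patterns:

$$\eta_l^j = \frac{\prod_{e=1}^{E}\left[\mathbf{p}_j(e)\,\sigma(\beta_l^e) + \left(1-\mathbf{p}_j(e)\right)\left(1-\sigma(\beta_l^e)\right)\right]}{1 - \prod_{e=1}^{E}\left(1-\sigma(\beta_l^e)\right)}. \tag{2}$$

For a search space with $L=12$ layers and $D=4$ ($E=10$), the number of network topology parameters is reduced from $L \times M = 12 \times 1023 = 12276$ to $L \times E = 12 \times 10 = 120$. Under this formulation, the probabilities of connection patterns are highly correlated: if an input edge $e$ to $s_l$ has a low probability, all candidate patterns to $s_l$ that select $e$ also have lower probabilities.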
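A minimal sketch of the edge-based parameterization, assuming each edge is selected independently with probability sigmoid(beta) and that the pattern probabilities eta are renormalized after excluding the empty pattern (so they sum to one for each super node). `pattern_probs` is a hypothetical helper using only the Python standard library (`math.prod` requires Python 3.8+).

```python
import math
from itertools import product

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def pattern_probs(beta_l):
    """eta_l^j for every non-empty pattern p_j, given edge logits beta_l (length E).

    Each edge is treated as independently selected with probability
    sigmoid(beta_l[e]); the all-zero pattern is dropped and the remaining
    probabilities are divided by 1 - prod(1 - sigmoid(beta)) (assumption).
    """
    q = [sigmoid(b) for b in beta_l]
    norm = 1.0 - math.prod(1.0 - qe for qe in q)
    probs = {}
    for p in product((0, 1), repeat=len(beta_l)):
        if not any(p):
            continue  # the empty pattern is not a valid connection pattern
        probs[p] = math.prod(qe if pe else 1.0 - qe for pe, qe in zip(p, q)) / norm
    return probs

beta_l = [2.0, -3.0, 0.0]                  # hypothetical edge logits
eta_l = pattern_probs(beta_l)
print(round(sum(eta_l.values()), 6))       # 1.0
print(eta_l[(1, 0, 1)] > eta_l[(1, 1, 1)]) # True: edge 1 is unlikely

L, D = 12, 4
E = 3 * D - 2
print(L * (2 ** E - 1), L * E)             # 12276 120
```

Only $E$ logits per super node are learned, while the $2^E - 1$ pattern probabilities are derived from them, which is why a low-probability edge lowers every pattern that contains it.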
For cell operation relaxation, we use the method in Sec. 3.3.1. Each cell on input edge $e$ to $s_l$ has its own cell architecture parameters $\alpha_l^e$, which are optimized during search. Note that $cp_l^j$ in Eq. 1 contains the cell operations defined on the selected edges, and hence the relaxed cell architecture parameters $\alpha$. We can therefore perform gradient-based search for the topology and cell levels jointly.
3.3.3 Discretization with Topology Constraints
After training, the final discrete architecture is derived from the optimized continuous architecture representation $\eta$ (derived from $\beta$) and $\alpha$. $\eta_l^j$ represents the probability of using input connection pattern $\mathbf{p}_j$ for super node $s_l$. Since the network topology search space has been converted into a sequential space, the simplest way to discretize the topology is to select, for each $s_l$, the pattern $\mathbf{p}_j$ with the maximum $\eta_l^j$. However, the resulting topology may not be feasible. We define topology infeasibility as:
“a feature node has an input edge but no output edge or has an output edge but no input edge”.
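The gray feature nodes in Fig. 5 indicate an infeasible topology. Therefore, we cannot select $\mathbf{p}_i$ and $\mathbf{p}_j$ as the input and output connection patterns of $s_l$, even if they have the largest probabilities. For every connection pattern $\mathbf{p}_i$, we generate a feasible set $\mathcal{F}(i)$: if a super node with input pattern $\mathbf{p}_i$ and output pattern $\mathbf{p}_j$ is feasible (all feature nodes of the super node are topologically feasible), then $j \in \mathcal{F}(i)$. Denote the array of selected input connection pattern indexes for the $L$ super nodes as $I$. Topology discretization is then performed by choosing the most likely feasible $I$ under its distribution

$$p(I) = \prod_{l=1}^{L} \eta_l^{I(l)}, \tag{3}$$

i.e., by maximum likelihood (minimizing the negative log-likelihood):

$$\hat{I} = \arg\min_{I} \sum_{l=1}^{L} -\log\left(\eta_l^{I(l)}\right), \quad \text{s.t. } I(l+1) \in \mathcal{F}\left(I(l)\right), \ l = 1, \dots, L-1. \tag{4}$$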
We build a directed graph $G$ from $\eta$ and $\mathcal{F}$, as illustrated in Fig. 5. The nodes (yellow blocks) of $G$ are connection operations, and the cost of the input edge to node $cp_l^j$ in $G$ is $-\log(\eta_l^j)$. The minimum-cost path from the source node to the sink node (green nodes with gray contours) corresponds to Eq. 4, and we obtain the optimal $I$ using Dijkstra's algorithm [dijkstra1959note]; all edge costs are non-negative because $\eta_l^j \le 1$. For the cell operations on the edges selected by $I$, we simply use the operation with the largest $\alpha$.
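The discretization can be sketched as a shortest-path search over (layer, pattern) states, where `eta` holds the pattern probabilities of each super node. This is a minimal sketch under stated assumptions, not the authors' implementation: the layer-0 stem features impose no constraint, and the last super node is assumed (hypothetically) to be acceptable only if its active feature nodes are exactly `final_nodes`, by default the finest resolution that would feed the segmentation head. Edge costs are $-\log \eta \ge 0$, so Dijkstra's algorithm applies.

```python
import heapq
import math
from itertools import product

def candidate_edges(num_depths):
    """(src, dst) resolution pairs that differ by at most one level."""
    return [(src, dst) for dst in range(num_depths)
            for src in (dst - 1, dst, dst + 1) if 0 <= src < num_depths]

def feasible(p_in, p_out, edges, num_depths):
    """True if every feature node has an input edge in p_in exactly when it
    has an output edge in p_out."""
    for d in range(num_depths):
        has_in = any(sel and dst == d for sel, (_, dst) in zip(p_in, edges))
        has_out = any(sel and src == d for sel, (src, _) in zip(p_out, edges))
        if has_in != has_out:
            return False
    return True

def discretize(eta, num_depths, final_nodes=frozenset({0})):
    """Return one pattern per super node minimizing sum(-log eta) under feasibility.

    eta: list over super nodes of dicts {pattern tuple: probability > 0}.
    """
    edges = candidate_edges(num_depths)
    patterns = [p for p in product((0, 1), repeat=len(edges)) if any(p)]
    num_layers = len(eta)

    def active(p):  # resolutions receiving at least one selected edge
        return {dst for sel, (_, dst) in zip(p, edges) if sel}

    start, sink = ("source",), ("sink",)
    dist, prev = {start: 0.0}, {}
    heap = [(0.0, 0, start)]
    counter = 1  # tie-breaker so the heap never compares node tuples
    while heap:
        cost, _, node = heapq.heappop(heap)
        if cost > dist.get(node, math.inf):
            continue  # stale heap entry
        if node == sink:
            break
        if node == start:
            nbrs = [((0, p), -math.log(eta[0][p])) for p in patterns]
        else:
            layer, p_in = node
            if layer == num_layers - 1:
                nbrs = [(sink, 0.0)] if active(p_in) == set(final_nodes) else []
            else:
                nbrs = [((layer + 1, p), -math.log(eta[layer + 1][p]))
                        for p in patterns if feasible(p_in, p, edges, num_depths)]
        for nxt, w in nbrs:
            new_cost = cost + w
            if new_cost < dist.get(nxt, math.inf):
                dist[nxt], prev[nxt] = new_cost, node
                heapq.heappush(heap, (new_cost, counter, nxt))
                counter += 1
    if sink not in prev:
        raise ValueError("no feasible topology under these constraints")
    path, node = [], prev[sink]
    while node != start:  # walk back from the sink to the source
        path.append(node[1])
        node = prev[node]
    return path[::-1]
```

For example, with $D=2$ there are $E=4$ edges and 15 patterns per super node, so `eta` is a list of $L$ dicts with 15 entries each. Because the graph is layered, a dynamic program over layers would give the same result; Dijkstra is used here only to mirror the text.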
3.4 Bridging the Discretization Gap
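To minimize the gap between the continuous representation and the final discretized architecture, we add entropy losses that encourage binarization of $\eta$ and $\alpha$:

$$\mathcal{L}_{\eta} = -\frac{1}{L \cdot M}\sum_{l=1}^{L}\sum_{j=1}^{M} \eta_l^j \log\left(\eta_l^j\right), \qquad \mathcal{L}_{\alpha} = -\frac{1}{L \cdot E \cdot N}\sum_{l=1}^{L}\sum_{e=1}^{E}\sum_{n=1}^{N} \alpha_l^{e,n} \log\left(\alpha_l^{e,n}\right),$$

where $N$ is the number of candidate cell operations (Sec. 3.2).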
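Entropy regularization of this kind can be sketched in a few lines. The sketch averages $-p \log p$ over all entries; that normalization is an assumption, and a plain sum would only rescale the loss. `entropy_loss` is a hypothetical helper that accepts any list of probability vectors (pattern probabilities per super node, or softmaxed cell weights per edge).

```python
import math

def entropy_loss(prob_rows, eps=1e-12):
    """Mean of -p * log(p) over every probability in every row.

    The minimum (about 0) is reached when every row is one-hot, which is
    why minimizing this term pushes the architecture toward binary values.
    """
    terms = [-p * math.log(p + eps) for row in prob_rows for p in row]
    return sum(terms) / len(terms)

one_hot = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
uniform = [[1 / 3, 1 / 3, 1 / 3]]
print(entropy_loss(one_hot))   # close to 0: already binarized
print(entropy_loss(uniform))   # log(3) / 3, about 0.366
```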
However, even if the architecture parameters $\eta$ and $\alpha$ are almost binarized, a large gap may remain because of the topology constraints in the discretization algorithm. Recall the definition of topology feasibility in Sec. 3.3.3: an activated feature node (a node with at least one input edge) must have an output edge, while an inactivated feature node cannot have an output edge. Each super node has $D$ feature nodes, so there are $2^D-1$ node activation patterns (at least one node is active, since every connection pattern is non-empty). We define $\mathcal{A}$ as the set of all node activation patterns. Each element $a \in \mathcal{A}$ is an indicator function of length $D$, where $a(i)=1$ if the $i$-th node of the super node is activated. We further define two sets, $\mathcal{F}_{in}(a)$ and $\mathcal{F}_{out}(a)$, containing all feasible input and output connection pattern indexes for a super node with node activation $a$, as shown in Fig. 6. We propose the following topology loss:

$$\mathcal{L}_{tp} = -\sum_{l=1}^{L-1}\sum_{a \in \mathcal{A}} \left[ P_l^a \log\left(Q_l^a\right) + \left(1 - P_l^a\right)\log\left(1 - Q_l^a\right) \right], \quad P_l^a = \sum_{j \in \mathcal{F}_{in}(a)} \eta_l^j, \quad Q_l^a = \sum_{j \in \mathcal{F}_{out}(a)} \eta_{l+1}^j.$$

$P_l^a$ is the probability that the activation pattern of $s_l$ is $a$, and $Q_l^a$ is the probability that $s_l$ with activation pattern $a$ is feasible. By minimizing $\mathcal{L}_{tp}$, the search stage becomes aware of the topology constraints and encourages all super nodes to be topologically feasible, thus reducing the gap caused by topology constraints in the discretization step.
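The topology loss can be sketched with explicit input/output pattern sets per activation pattern. Here P for super node l and activation a is the summed pattern probability eta of node l over input patterns that activate exactly a, and Q is the summed eta of node l+1 over output patterns whose edges leave exactly a; the loss is their binary cross-entropy. Assumptions: activation patterns are represented as frozensets of active resolutions, and a small `eps` guards the logarithms; `activation_sets` and `topology_loss` are hypothetical names.

```python
import math
from itertools import product

def candidate_edges(num_depths):
    return [(src, dst) for dst in range(num_depths)
            for src in (dst - 1, dst, dst + 1) if 0 <= src < num_depths]

def activation_sets(num_depths):
    """Map each activation pattern a to (F_in(a), F_out(a)) as pattern lists.

    F_in(a):  input patterns that activate exactly the nodes in a.
    F_out(a): output patterns whose edges leave exactly the nodes in a.
    """
    edges = candidate_edges(num_depths)
    patterns = [p for p in product((0, 1), repeat=len(edges)) if any(p)]
    f_in, f_out = {}, {}
    for p in patterns:
        dsts = frozenset(dst for sel, (_, dst) in zip(p, edges) if sel)
        srcs = frozenset(src for sel, (src, _) in zip(p, edges) if sel)
        f_in.setdefault(dsts, []).append(p)
        f_out.setdefault(srcs, []).append(p)
    return f_in, f_out

def topology_loss(eta, num_depths, eps=1e-6):
    """Binary cross-entropy between P and Q, summed over super nodes and patterns.

    eta: list over super nodes of dicts {pattern tuple: probability}.
    """
    f_in, f_out = activation_sets(num_depths)
    loss = 0.0
    for l in range(len(eta) - 1):
        for a in f_in:
            p_act = sum(eta[l][p] for p in f_in[a])                 # P for (l, a)
            q_feas = sum(eta[l + 1][p] for p in f_out.get(a, []))  # Q for (l, a)
            loss -= (p_act * math.log(q_feas + eps)
                     + (1.0 - p_act) * math.log(1.0 - q_feas + eps))
    return loss
```

A feasible, fully binarized pair of super nodes gives a loss near zero, while pairing an activated node with a successor that never reads from it drives the loss up sharply.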
3.5 Memory Budget Constraints
The searched model is usually retrained under different settings, such as a different patch size, number of filters, or task. Auto-DeepLab [liu2019auto] used larger image patches and more filters in retraining than in the search stage. For 3D images, however, this can cause out-of-memory problems during retraining, so we take a memory budget into account in the architecture search. The expected memory usage of a cell is estimated by $M^{l,e} = \sum_{n=1}^{N} \alpha_l^{e,n} M_n$, where $M_n$ is the memory usage of operation $O_n$ defined in Sec. 3.2 (estimated by tensor size [gao2020estimating]). The expected memory usage of the searched model is:

$$M_e = \sum_{l=1}^{L}\sum_{j=1}^{M} \eta_l^j \left( \sum_{e=1}^{E} \mathbf{p}_j(e)\, M^{l,e} \right).$$
Similar to [li2020learning], we express the budget as a percentage $\sigma$ of the maximum memory usage $M_a$, i.e., the memory usage obtained when all $\alpha$ and $\eta$ are set to one.
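The memory estimate can be sketched as an expectation over connection patterns (probabilities eta) and cell operations (weights alpha). Several choices here are assumptions: the maximum usage is read as "every edge selected and every cell weight set to one", the budget penalty is taken as the absolute deviation of the expected-to-maximum ratio from sigma, and the per-operation costs are hypothetical relative numbers rather than measured tensor sizes.

```python
def expected_memory(eta, alpha, op_mem):
    """M_e = sum_l sum_j eta_l^j * sum_e p_j(e) * M^{l,e}.

    eta:    list over super nodes of dicts {pattern tuple p_j: eta_l^j}
    alpha:  alpha[l][e] holds the operation weights of the cell on edge e
    op_mem: op_mem[n] is the memory cost M_n of candidate operation n
    """
    total = 0.0
    for l, eta_l in enumerate(eta):
        # M^{l,e}: alpha-weighted memory of the cell on each edge
        cell_mem = [sum(a * m for a, m in zip(alpha[l][e], op_mem))
                    for e in range(len(alpha[l]))]
        for pattern, prob in eta_l.items():
            total += prob * sum(sel * cell_mem[e] for e, sel in enumerate(pattern))
    return total

def max_memory(num_layers, num_edges, op_mem):
    """M_a, read here (an assumption) as every edge selected with all alpha = 1."""
    return num_layers * num_edges * sum(op_mem)

def memory_budget_loss(m_e, m_a, sigma):
    """Assumed penalty: absolute deviation of M_e / M_a from the budget sigma."""
    return abs(m_e / m_a - sigma)

# Toy example: one super node, two edges, three operations with
# hypothetical relative memory costs.
eta = [{(1, 0): 0.5, (0, 1): 0.25, (1, 1): 0.25}]
alpha = [[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]
op_mem = [0.0, 2.0, 1.0]
m_e = expected_memory(eta, alpha, op_mem)
m_a = max_memory(1, 2, op_mem)
print(m_e, m_a, memory_budget_loss(m_e, m_a, 0.5))
```

Because the estimate is linear in eta and alpha, it is differentiable with respect to both and can be added directly to a gradient-based search objective.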
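We adopt the same optimization strategy as DARTS [liu2018darts] and Auto-DeepLab [liu2019auto]. We partition the training set into train1 and train2, and alternately optimize the network weights $w$ (e.g. convolution kernels) with $\mathcal{L}_{train}$ on train1 and the architecture weights $\beta$ and $\alpha$ with $\mathcal{L}_{search}$ on train2. The loss for $w$, $\mathcal{L}_{train} = \mathcal{L}_{seg}$, is the equally weighted sum of the Dice and cross-entropy losses [yu2020c2fnas] for segmentation, while

$$\mathcal{L}_{search} = \mathcal{L}_{seg} + \frac{t}{T}\left(\mathcal{L}_{\eta} + \mathcal{L}_{\alpha} + \left|\frac{M_e}{M_a} - \sigma\right| + \lambda \mathcal{L}_{tp}\right),$$

where $\mathcal{L}_{\eta}$ and $\mathcal{L}_{\alpha}$ are the entropy losses and $\mathcal{L}_{tp}$ is the topology loss of Sec. 3.4, and the absolute term penalizes deviation of the expected memory usage from the budget $\sigma$ (Sec. 3.5). $t$ and $T$ are the current and total iterations of architecture optimization, so that the search focuses more on $\mathcal{L}_{seg}$ at the start. We empirically scale $\mathcal{L}_{tp}$ to the same range as the other losses by setting $\lambda=0.001$.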
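How the search-stage terms can be combined is sketched below. It assumes the memory term is an absolute budget deviation and that the ramp t/T multiplies all regularizers while leaving the segmentation loss unscaled; `search_loss` and its argument names are hypothetical, and only the weight 0.001 on the topology loss comes from the text.

```python
def search_loss(seg, ent_eta, ent_alpha, mem, topo, t, total, lam=0.001):
    """Combine search-stage losses with a linear ramp t / T.

    seg:       segmentation loss (Dice + cross-entropy) on train2
    ent_eta:   entropy loss on the pattern probabilities
    ent_alpha: entropy loss on the cell operation weights
    mem:       memory-budget penalty
    topo:      topology loss, scaled by lam (0.001 in the text)
    """
    ramp = t / total  # 0 at the start of architecture optimization, 1 at the end
    return seg + ramp * (ent_eta + ent_alpha + mem + lam * topo)

for t in (0, 5000, 10000):
    print(t, search_loss(0.8, 0.5, 0.4, 0.1, 200.0, t, 10000))
```

At $t=0$ only the segmentation loss drives the architecture weights; the regularizers are phased in linearly so that early, noisy architecture estimates are not forced to binarize prematurely.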
We conduct experiments on the MSD dataset [simpson2019large], a comprehensive benchmark for medical image segmentation. It contains ten segmentation tasks covering different anatomies of interest, modalities, and imaging sources (institutions), and it is representative of real clinical problems. The recent C2FNAS [yu2020c2fnas] reached state-of-the-art results on the MSD dataset using a NAS-based method. For a direct comparison, we follow its experimental settings by searching on the MSD Pancreas dataset and deploying the searched model on all 10 MSD tasks. All images are resampled to a common voxel resolution.
4.1 Implementation Details
Our search space has $L=12$ layers and $D=4$ resolution levels, as shown in Fig. 2. The stem cell at scale 1 has 16 filters, and we double the number of filters each time the spatial size is halved along each axis. The search is conducted on the Pancreas dataset with the same 5-fold data split as C2FNAS [yu2020c2fnas] (4 folds for training and the last fold for validation). We use the SGD optimizer with momentum 0.9 and weight decay 4e-5 for the network weights $w$. We first train for 1k (one thousand) warm-up iterations and a further 10k iterations without updating the architecture. The architecture weights $\alpha$ and $\beta$ are initialized from Gaussian distributions and passed through a softmax and a sigmoid, respectively. In the following 10k iterations, we jointly optimize $w$ with SGD and $\alpha$, $\beta$ with the Adam optimizer [kingma2014adam] (learning rate 0.008, weight decay 0). The SGD learning rate increases linearly from 0.025 to 0.2 during the first 1k warm-up iterations and then decays by a factor of 0.5 at [8k, 16k] iterations after warm-up. The search is conducted on 8 GPUs with batch size 8 (one 96×96×96 patch per GPU). The patches are randomly augmented with 2D rotations by 90, 180, or 270 degrees in the x-y plane and flips along all three axes. The total number of training iterations, the SGD learning rate schedule, and the data pre-processing and augmentation are the same as in C2FNAS [yu2020c2fnas]. After searching, the discretized model is randomly initialized and retrained with twice the number of filters and twice the batch size to match the setting of C2FNAS [yu2020c2fnas]. We use the SGD optimizer with 1k warm-up and 40k training iterations and decay the learning rate by a factor of 0.5 at [8k, 16k, 24k, 32k] iterations after warm-up. The learning rate schedule is the same as in the search stage for the warm-up and the first 20k iterations; the remaining 20k iterations allow better convergence and match the 40k total retraining iterations used in C2FNAS [yu2020c2fnas]. For the Pancreas dataset, we use the same data augmentation as C2FNAS (also the same as in the search stage) for a direct comparison.
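The SGD learning-rate schedule described above can be written as a small function. We read the milestones as counted after the 1k warm-up (the retraining description states this explicitly); `learning_rate` is a hypothetical helper, not code from the paper.

```python
def learning_rate(it, warmup=1000, lr_start=0.025, lr_peak=0.2,
                  milestones=(8000, 16000), gamma=0.5):
    """Linear warm-up followed by step decay.

    Milestones are counted in iterations after the warm-up ends.
    """
    if it < warmup:
        return lr_start + (lr_peak - lr_start) * it / warmup
    after = it - warmup
    num_decays = sum(after >= m for m in milestones)
    return lr_peak * gamma ** num_decays

# Search stage: 1k warm-up, then decays at 8k and 16k iterations after warm-up.
for it in (0, 500, 1000, 9000, 17000):
    print(it, learning_rate(it))

# Retraining: same warm-up, decays at 8k, 16k, 24k and 32k after warm-up.
print(learning_rate(33000, milestones=(8000, 16000, 24000, 32000)))
```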
To test the generalizability of the searched model, we retrain it on each of the remaining nine tasks. Some tasks in the MSD dataset contain very little training data, so we use additional basic 2D data augmentations (random rotation, scaling, and gamma correction) for all nine tasks. We use the same patch size and stride for all ten tasks except Prostate and Hippocampus. Prostate volumes have very few slices (fewer than 40) along the z-axis, and Hippocampus volumes are very small, so we use task-specific patch sizes and strides for these two tasks. Post-processing with the largest connected component is also applied.
|3D UNet [cciccek20163d] (nn-UNet)||658||18||9176||-||-||-|
|Attention UNet [oktay2018attention]||1163||104||13465||-||-||-|
4.2 Pancreas Dataset Search Results
The search takes 5.8 GPU days, whereas C2FNAS takes 333 GPU days on the same dataset (both using eight 16 GB V100 GPUs). We vary the memory constraint and show the search results in Fig. 7. The searched models have highly flexible topologies, which are searched jointly with the cell level. The 5-fold cross-validation results on Pancreas are shown in Table 1. As $\sigma$ increases, the searched model becomes denser in connections and achieves better performance while requiring more GPU memory (estimated using PyTorch [paszke2019pytorch] functions during training as described in Sec. 4.1). The marginal performance drop when decreasing $\sigma$ shows that we can reduce memory usage without losing much accuracy. Although techniques such as mixed-precision training [micikevicius2017mixed] can further reduce memory usage, our memory-aware search addresses this problem from the NAS perspective. Compared to nnUNet [isensee2019nnunet] (represented by 3D UNet, because it ensembles 2D/3D/cascaded-3D U-Nets differently for each task) and C2FNAS in Table 1, our searched models have no advantage in FLOPs or number of parameters, which are important in mobile settings. We argue that for medical image analysis, lightweight models and low latency are less of a concern than GPU memory usage and accuracy. DiNTS can make full use of the available GPU memory and achieve better performance.
4.3 Segmentation Results on MSD
A model searched on Pancreas is retrained and tested on all ten tasks of the MSD dataset. Similar to the model ensembles used in nnUNet [isensee2019nnunet] and C2FNAS [yu2020c2fnas], we use 5-fold cross-validation for each task and ensemble the results by majority voting. The largest-connected-component post-processing of nnUNet [isensee2019nnunet] is also applied. The Dice-Sørensen coefficient (DSC) and Normalised Surface Distance (NSD), as used in the MSD challenge, are reported for the test set in Table 2. nnUNet [isensee2019nnunet] uses extensive data augmentation and different hyper-parameters, such as patch size and batch size, for each task, and it ensembles networks with different architectures. It focuses on hyper-parameter selection based on hand-crafted rules and has won multiple medical segmentation challenges, including MSD. Our method and C2FNAS [yu2020c2fnas] focus on architecture search and use consistent hyper-parameters and basic augmentations for all ten tasks. We achieve better results than C2FNAS [yu2020c2fnas] on all tasks with similar hyper-parameters while requiring only 1.7% of its search time. Compared to nnUNet [isensee2019nnunet], we achieve much better performance on challenging datasets such as Pancreas, Brain, and Colon, but worse performance on smaller datasets such as Heart (10 test cases), Prostate (16 test cases), and Spleen (20 test cases). Task-specific hyper-parameters, test-time augmentation, extensive data augmentation, and ensembling more models, as done in nnUNet [isensee2019nnunet], may be more effective on these small datasets than our unified DiNTS-searched architecture. Overall, we achieve the best average results and top ranking on the MSD challenge leaderboard, showing that a non-UNet topology can achieve superior performance in medical imaging.
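Majority voting across the five fold models can be sketched voxel by voxel. This sketch assumes each model's prediction is a flattened list of integer labels of equal length; ties are resolved in favor of the label that appears first in model order (the behavior of `Counter.most_common` for equal counts), a choice the text does not specify.

```python
from collections import Counter

def majority_vote(predictions):
    """Voxel-wise majority vote over label maps from several models.

    predictions: list of equally long label sequences (flattened volumes).
    """
    return [Counter(labels).most_common(1)[0][0] for labels in zip(*predictions)]

# Five hypothetical fold models, four voxels each.
folds = [
    [0, 1, 1, 2],
    [0, 1, 2, 2],
    [1, 1, 2, 0],
    [0, 0, 2, 2],
    [0, 1, 1, 2],
]
print(majority_vote(folds))  # [0, 1, 2, 2]
```

With an odd number of models and binary labels ties cannot occur, but multi-class tasks can produce them, so the tie-breaking rule is worth fixing explicitly in a real pipeline.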
|Kim et al [kim2019scalable]||67.40||45.75||68.26||60.47||86.65||72.03||90.28||82.99|
|Kim et al [kim2019scalable]||93.11||96.44||94.25||72.96||83.605||96.76||88.58||92.67|
|Kim et al [kim2019scalable]||63.10||62.51||90.11||88.72||89.42||97.77||97.73||97.75|
|Kim et al [kim2019scalable]||91.92||94.83||72.64||89.02||80.83||95.05||98.03||96.54|
|Kim et al [kim2019scalable]||49.32||62.21||62.34||68.63||65.485||83.22||78.43||80.825|
|Kim et al [kim2019scalable]||80.61||51.75||66.18||95.83||73.09||84.46||74.34||85.12|
4.4 Ablation Study
4.4.1 Search on Different Datasets
The models in Sec. 4.2 and Sec. 4.3 are searched on the Pancreas dataset (282 3D CT training images). To test the generalizability of DiNTS, we perform the same search as in Sec. 4.1 on Brain (484 MRI volumes), Liver (131 CT volumes), and Lung (64 CT volumes), covering large, medium, and small datasets. The results in Table 3 demonstrate that DiNTS generalizes well across datasets.
4.4.2 Necessity of Topology Loss
As illustrated in Sec. 1, the discretization algorithm discards topologically infeasible edges (even those with large probabilities), which causes a gap between the feature flow in the optimized continuous model (Eq. 1) and that in the discrete model. Our topology loss encourages connections with large probabilities to be feasible, so that they are not discarded and do not cause the gap. We denote by $T_a$ the topology decoded by selecting the connection pattern with the largest $\eta$ for each layer (which can be infeasible), and by $T_b$ the topology decoded by our discretization algorithm. $T_a$ and $T_b$ are indicator matrices of size $L \times E$ representing whether each edge is selected, and we measure their difference as $\lVert T_a - T_b \rVert_1$. A larger difference indicates a larger gap between the feature flow before and after discretization. Fig. 8 shows how this difference changes during search, with and without the topology loss, under different memory constraints. With the topology loss, the gap between $T_a$ and $T_b$ is reduced, and the loss is more crucial for smaller $\sigma$, where the searched architecture is sparser and more likely to be topologically infeasible.
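The comparison between the argmax-decoded topology $T_a$ and the discretized topology $T_b$ reduces to counting disagreeing edges. This sketch assumes both topologies are given as one pattern tuple per layer (the rows of the $L \times E$ indicator matrix), that `eta` holds per-layer pattern probabilities, and that the gap measure is the element-wise absolute difference, i.e. an L1 (Hamming) distance; the helper names are hypothetical.

```python
def argmax_topology(eta):
    """T_a: the pattern with the largest probability for each layer (may be infeasible)."""
    return [max(eta_l, key=eta_l.get) for eta_l in eta]

def topology_gap(topo_a, topo_b):
    """Number of edges on which two decoded topologies disagree (L1 distance)."""
    return sum(abs(x - y)
               for row_a, row_b in zip(topo_a, topo_b)
               for x, y in zip(row_a, row_b))

# Hypothetical two-layer example with three candidate edges per layer.
eta = [{(1, 0, 1): 0.6, (1, 1, 1): 0.3, (0, 1, 0): 0.1},
       {(0, 1, 0): 0.7, (1, 1, 0): 0.3}]
t_a = argmax_topology(eta)           # [(1, 0, 1), (0, 1, 0)]
t_b = [(1, 1, 1), (0, 1, 0)]         # e.g. output of a feasibility-aware decoder
print(topology_gap(t_a, t_b))        # 1
```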
In this paper, we present a novel differentiable network topology search framework (DiNTS) for 3D medical image segmentation. By converting feature nodes with varying spatial resolutions into super nodes, we focus on connection patterns rather than individual edges, which enables more flexible network topologies and a discretization-aware search framework. Medical image segmentation challenges have been dominated by U-Net-based architectures [isensee2019nnunet]; even the NAS-based C2FNAS is searched within a U-shaped space. DiNTS's topology search space is highly flexible, and DiNTS achieves the best performance on the MSD challenge benchmark using non-UNet architectures, while taking only 1.7% of the search time of C2FNAS. Since directly converting Auto-DeepLab [liu2019auto] to 3D would cause memory issues, we cannot compare with it fairly. In future work, we will test our algorithm on 2D natural image segmentation benchmarks and explore more complex cells.