The Importance of Skip Connections in Biomedical Image Segmentation

08/14/2016 ∙ by Michal Drozdzal, et al. ∙ Corporation de l'ecole Polytechnique de Montreal ∙ Imagia Cybernetics Inc.

In this paper, we study the influence of both long and short skip connections on Fully Convolutional Networks (FCNs) for biomedical image segmentation. In standard FCNs, only long skip connections are used to skip features from the contracting path to the expanding path in order to recover spatial information lost during downsampling. We extend FCNs by adding short skip connections, similar to the ones introduced in residual networks, in order to build very deep FCNs (of hundreds of layers). An analysis of the gradient flow confirms that for a very deep FCN it is beneficial to have both long and short skip connections. Finally, we show that a very deep FCN can achieve near-state-of-the-art results on the EM dataset without any further post-processing.


1 Introduction

Semantic segmentation is an active area of research in medical image analysis. With the introduction of Convolutional Neural Networks (CNNs), significant improvements in performance have been achieved on many standard datasets. For example, for the EM ISBI 2012 dataset [2], BRATS [13] or MS lesions [19], the top entries are built on CNNs [16, 4, 7, 3].

All these methods are based on Fully Convolutional Networks (FCNs) [12]. While CNNs are typically realized by a contracting path built from convolutional, pooling and fully connected layers, FCNs add an expanding path built from deconvolutional or unpooling layers. The expanding path recovers spatial information by merging features skipped from the various resolution levels of the contracting path.

Variants of these skip connections have been proposed in the literature. In [12], upsampled feature maps are summed with feature maps skipped from the contracting path, while [16] concatenates them and adds convolutions and non-linearities between each upsampling step. These skip connections have been shown to help recover the full spatial resolution at the network output, making fully convolutional methods suitable for semantic segmentation. We refer to these skip connections as long skip connections.
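To make the two merge styles concrete, here is a minimal sketch (not the authors' code), written with tf.keras layers; the filter count and the extra convolution after concatenation are illustrative assumptions.

```python
# Hypothetical sketch: two ways of merging a feature map skipped from the
# contracting path with an upsampled feature map on the expanding path.
from tensorflow.keras import layers

def merge_by_summation(decoder_feats, skipped_feats):
    # FCN-style long skip [12]: element-wise sum (channel counts must match).
    return layers.Add()([decoder_feats, skipped_feats])

def merge_by_concatenation(decoder_feats, skipped_feats, filters=64):
    # U-Net-style long skip [16]: concatenate along channels, then apply a
    # convolution and a non-linearity before the next upsampling step.
    x = layers.Concatenate(axis=-1)([decoder_feats, skipped_feats])
    return layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
```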

Recently, significant network depth has been shown to be helpful for image classification [21, 8, 9, 15]. These recent results suggest that depth can act as a regularizer [8]. However, network depth is limited by the issue of vanishing gradients when backpropagating the signal across many layers. In [21], this problem is addressed with additional levels of supervision, while in [8, 9] skip connections are added around non-linearities, creating shortcuts through which the gradient can flow uninterrupted, allowing parameters to be updated deep in the network. Moreover, [20] showed that these skip connections allow for faster convergence during training. We refer to these skip connections as short skip connections.

In this paper, we explore deep, fully convolutional networks for semantic segmentation. We extend FCNs by adding short skip connections that allow us to build very deep FCNs. With this setup, we perform an analysis of short and long skip connections on a standard biomedical dataset (the EM ISBI 2012 challenge data). We observe that short skip connections speed up the convergence of the learning process; moreover, we show that a very deep architecture with a relatively small number of parameters can reach near-state-of-the-art performance on this dataset. The contributions of the paper can be summarized as follows:

  • We extend Residual Networks to fully convolutional networks for semantic image segmentation (see Section 2).

  • We show that a very deep network without any post-processing achieves performance comparable to the state of the art on EM data (see Section 3.1).

  • We show that long and short skip connections are beneficial for the convergence of very deep networks (see Section 3.2).

2 Residual network for semantic image segmentation

Our approach extends Residual Networks [8] to segmentation tasks by adding an expanding (upsampling) path (Figure 1). We perform spatial reduction along the contracting path (left) and expansion along the expanding path (right). As in [12] and [16], spatial information lost along the contracting path is recovered in the expanding path by skipping equal resolution features from the former to the latter. Similarly to the short skip connections in Residual Networks, we choose to sum the features on the expanding path with those skipped over the long skip connections.

We consider three types of blocks, each containing at least one convolution and activation function: the bottleneck block, the basic block and the simple block (Figure 1). Each block can perform batch normalization on its inputs, as well as spatial downsampling at the input (marked blue; used in the contracting path) and spatial upsampling at the output (marked yellow; used in the expanding path). The bottleneck and basic blocks are based on those introduced in [8]; they use short skip connections to pass the block input to its output with minimal modification, encouraging the path through the non-linearities to learn a residual representation of the input data. To minimize the modification of the input, we apply no transformations along the short skip connections, except when the number of filters or the spatial resolution needs to be adjusted to match the block output. We use convolutions to adjust the number of filters, but for spatial adjustment we rely on simple decimation or simple repetition of rows and columns of the input, so as not to increase the number of parameters. We add an optional dropout layer to all blocks along the residual path.
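The sketch below illustrates one plausible reading of such a block in tf.keras; the pre-activation ordering, the strided and transposed convolutions, and the filter handling on the identity path are assumptions rather than the authors' exact design.

```python
# Hypothetical residual block in the spirit of Figure 1 (b)-(d); batch
# normalization, dropout and spatial down/upsampling are all optional.
from tensorflow.keras import layers

def simple_block(x, filters, downsample=False, upsample=False,
                 batch_norm=True, dropout_rate=0.0):
    shortcut = x

    # Residual path: (optional) BN, non-linearity, (optional) dropout, conv.
    h = layers.BatchNormalization()(x) if batch_norm else x
    h = layers.Activation("relu")(h)
    if dropout_rate > 0.0:
        h = layers.Dropout(dropout_rate)(h)
    if downsample:
        h = layers.Conv2D(filters, 3, strides=2, padding="same")(h)
    elif upsample:
        h = layers.Conv2DTranspose(filters, 3, strides=2, padding="same")(h)
    else:
        h = layers.Conv2D(filters, 3, padding="same")(h)

    # Short skip (identity) path: modify the input as little as possible.
    if downsample:
        shortcut = shortcut[:, ::2, ::2, :]          # decimation of rows/columns
    elif upsample:
        shortcut = layers.UpSampling2D(2)(shortcut)  # repetition of rows/columns
    if shortcut.shape[-1] != filters:
        shortcut = layers.Conv2D(filters, 1)(shortcut)  # adjust the number of filters only

    return layers.Add()([h, shortcut])
```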

Figure 1: An example of a residual network for image segmentation. (a) Residual network with long skip connections built from bottleneck blocks, (b) bottleneck block, (c) basic block and (d) simple block. Blue indicates blocks where downsampling is optionally performed, yellow indicates the (optional) upsampling blocks, and the dashed arrows in (b), (c) and (d) indicate possible long skip connections. Note that all blocks (b), (c) and (d) can have a dropout layer (depicted with a dashed-line rectangle).

We experimented with both the binary cross-entropy and the Dice loss functions. Let $o_i$ be the $i$-th output of the last network layer passed through a sigmoid non-linearity and let $y_i$ be the corresponding label. The binary cross-entropy loss is then defined as follows:

$L_{bce} = -\sum_i \left( y_i \log o_i + (1 - y_i) \log(1 - o_i) \right)$    (1)

The Dice loss is:

$L_{dice} = -\frac{2 \sum_i o_i y_i}{\sum_i o_i + \sum_i y_i}$    (2)
We implemented the model in Keras [5] using the Theano backend [1] and trained it using RMSprop [22] with weight decay. We also experimented with various levels of dropout.
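A hypothetical tf.keras rendering of this configuration is sketched below; the toy model, the learning rate and the weight-decay coefficient are placeholders, since the exact values are not recoverable from this text.

```python
# Placeholder training configuration (not the authors' script).
import tensorflow as tf
from tensorflow.keras import layers, regularizers

wd = 1e-4  # placeholder weight-decay coefficient, applied as L2 on conv kernels
inputs = layers.Input(shape=(512, 512, 1))
x = layers.Conv2D(32, 3, padding="same", activation="relu",
                  kernel_regularizer=regularizers.l2(wd))(inputs)
outputs = layers.Conv2D(1, 1, activation="sigmoid")(x)  # toy stand-in model
model = tf.keras.Model(inputs, outputs)

model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-3),  # placeholder rate
              loss="binary_crossentropy")
```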

3 Experiments

In this section, we test the model on electron microscopy (EM) data [2] (Section 3.1) and perform an analysis on the importance of the long and short skip connections (Section 3.2).

3.1 Segmenting EM data

EM training data consist of 30 images (512 × 512 pixels) assembled from serial-section transmission electron microscopy of the Drosophila first instar larva ventral nerve cord. The test set is another set of 30 images for which labels are not provided. Throughout the experiments, we used 25 images for training, leaving 5 images for validation.

During training, we augmented the input data using random flipping, shearing, rotations, and spline warping. We used the same spline warping strategy as [16]. We used full-resolution (512 × 512) images as input without applying random cropping for data augmentation. For each training run, the model version with the best validation loss was stored and evaluated. A detailed description of the highest performing architecture used in the experiments is shown in Table 1.
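The sketch below shows one way such an augmentation could look for grayscale NumPy images; the random flips follow the description above, while the elastic deformation merely stands in for the spline warping of [16] and its parameters are illustrative.

```python
import numpy as np
from scipy.ndimage import gaussian_filter, map_coordinates

def augment(image, label, rng, alpha=30.0, sigma=4.0):
    # Random horizontal/vertical flips (shearing and rotations omitted here).
    if rng.random() < 0.5:
        image, label = image[:, ::-1], label[:, ::-1]
    if rng.random() < 0.5:
        image, label = image[::-1, :], label[::-1, :]

    # Generic elastic warp: smooth random displacement field applied to a grid.
    shape = image.shape
    dy = gaussian_filter(rng.uniform(-1, 1, shape), sigma) * alpha
    dx = gaussian_filter(rng.uniform(-1, 1, shape), sigma) * alpha
    yy, xx = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")
    coords = np.array([yy + dy, xx + dx])
    image = map_coordinates(image, coords, order=3, mode="reflect")
    label = map_coordinates(label, coords, order=0, mode="reflect")  # keep labels crisp
    return image, label

# Example: rng = np.random.default_rng(0); img_aug, lbl_aug = augment(img, lbl, rng)
```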

Layer name | block type | output width | repetition number
Down 1 | conv | 32 | 1
Down 2 | simple block | 32 | 1
Down 3 | bottleneck | 128 | 3
Down 4 | bottleneck | 256 | 8
Down 5 | bottleneck | 512 | 10
Across | bottleneck | 1024 | 3
Up 1 | bottleneck | 512 | 10
Up 2 | bottleneck | 256 | 8
Up 3 | bottleneck | 128 | 3
Up 4 | simple block | 32 | 1
Up 5 | conv 3x3 | 32 | 1
Classifier | conv | 1 | 1
Table 1: Detailed model architecture used in the experiments. The repetition number indicates the number of times each block is repeated.

Interestingly, we found that while the predictions from models trained with the cross-entropy loss were of high quality, those produced by models trained with the Dice loss appeared visually cleaner since they were almost binary (similar observations were reported in a parallel work [14]); borders that would appear fuzzy in the former (see Figure 2b) would be left as gaps in the latter (Figure 2c). However, we found that border continuity can be improved for models trained with the Dice loss by implicit model averaging over output samples drawn at test time, using dropout [10] (Figure 2d). This yields better performance on the validation and test metrics than the output of models trained with binary cross-entropy (see Table 2).
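A minimal sketch of this implicit averaging, assuming a tf.keras model containing dropout layers, is shown below; the number of stochastic samples is an arbitrary choice.

```python
import numpy as np

def predict_with_dropout_averaging(model, image_batch, n_samples=10):
    # Calling the model with training=True keeps Dropout active at test time,
    # so each forward pass is a different stochastic sample of the network.
    samples = [model(image_batch, training=True).numpy() for _ in range(n_samples)]
    return np.mean(samples, axis=0)  # averaged, smoother probability map
```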

Figure 2: Qualitative results on the test set. (a) Original image, (b) prediction of a model trained with binary cross-entropy, (c) prediction of a model trained with the Dice loss and (d) prediction of a model trained with the Dice loss, with dropout at test time.

The two metrics used for this dataset are the maximal foreground-restricted Rand score after thinning (V^Rand) and the maximal foreground-restricted information theoretic score after thinning (V^Info). For a detailed description of the metrics, please refer to [2].

Method | V^Rand | V^Info | FCN | post-processing | average over | parameters (M)
CUMedVision [4] | 0.977 | 0.989 | YES | YES | 6 | 8
Unet [16] | 0.973 | 0.987 | YES | NO | 7 | 33
IDSIA [6] | 0.970 | 0.985 | NO | - | - | -
motif [24] | 0.972 | 0.985 | NO | - | - | -
SCI [11] | 0.971 | 0.982 | NO | - | - | -
optree-idsia [23] | 0.970 | 0.985 | NO | - | - | -
PyraMiD-LSTM [18] | 0.968 | 0.983 | NO | - | - | -
Ours (test-time dropout averaging) | 0.969 | 0.986 | YES | NO | Dropout | 11
Ours (single output) | 0.957 | 0.980 | YES | NO | 1 | 11
Table 2: Comparison to published entries for the EM dataset. For a full ranking of all submitted methods, please refer to the challenge web page: http://brainiac2.mit.edu/isbi_challenge/leaders-board-new. We report the number of parameters, the use of post-processing, and the use of model averaging only for FCNs.

Our results are comparable to other published results that establish the state of the art for the EM dataset (Table 2). Note that we did not apply any post-processing to the resulting segmentations. We match the performance of Unet, for which predictions are averaged over seven rotations of the input images, while using fewer parameters and without sophisticated class weighting. Note that among the other FCNs on the leader board, CUMedVision uses post-processing to boost performance.

3.2 On the importance of skip connections

The focus of this paper is to evaluate the utility of long and short skip connections for training fully convolutional networks for image segmentation. In this section, we investigate the learning behavior of the model with short and with long skip connections, paying specific attention to the parameter updates at each layer of the network. We first explored variants of our best performing deep architecture (from Table 1), using the binary cross-entropy loss. Maintaining the same hyperparameters, we trained (Model 1) with long and short skip connections, (Model 2) with only short skip connections, and (Model 3) with only long skip connections. Training curves are presented in Figure 3 and the final loss and accuracy values on the training and the validation data are presented in Table 3.

We note that for our deep architecture, the variant with both long and short skip connections is not only the one that performs best but also the one that converges faster than the variant without short skip connections. This increase in convergence speed is consistent with the literature [20]. Not surprisingly, the combination of both long and short skip connections performed better than having only one type of skip connection, both in terms of performance and convergence speed. At this depth, the network could not be trained without any skip connections. Finally, short skip connections appear to stabilize the updates (note the smoothness of the validation loss plots in Figures 3a and 3b as compared to Figure 3c).

Figure 3: Training and validation losses and accuracies for different network setups: (a) Model 1: long and short skip connections enabled, (b) Model 2: only short skip connections enabled and (c) Model 3: only long skip connections enabled.
Method | training loss | validation loss
Long and short skip connections | 0.163 | 0.162
Only short skip connections | 0.188 | 0.202
Only long skip connections | 0.205 | 0.188
Table 3: Best validation loss and its corresponding training loss for each model.

We expect that layers closer to the center of the model cannot be effectively updated due to the vanishing gradient problem, which is alleviated by short skip connections. These identity shortcuts effectively introduce shorter paths through fewer non-linearities to the deep layers of our models. We validate this empirically on a range of models of varying depth by visualizing the mean model parameter updates at each layer for each epoch (see sample results in Figure 4). To simplify the analysis and visualization, we used simple blocks instead of bottleneck blocks.
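One way to obtain such per-layer statistics, sketched here for a tf.keras model (not the authors' tooling), is to snapshot the weights before an epoch and record the mean absolute change afterwards.

```python
import numpy as np

def mean_layer_updates(model, fit_one_epoch):
    # fit_one_epoch is a user-supplied callable, e.g.
    # lambda m: m.fit(x_train, y_train, epochs=1, verbose=0)
    before = {layer.name: [np.copy(w) for w in layer.get_weights()]
              for layer in model.layers if layer.get_weights()}
    fit_one_epoch(model)
    updates = {}
    for layer in model.layers:
        old = before.get(layer.name)
        if not old:
            continue
        deltas = [np.abs(new - prev).ravel()
                  for new, prev in zip(layer.get_weights(), old)]
        updates[layer.name] = float(np.mean(np.concatenate(deltas)))
    return updates  # layer name -> mean absolute parameter update this epoch
```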

Parameter updates appear to be well distributed when short skip connections are present (Figure 4a). When the short skip connections are removed, we find that for deep models, the deep parts of the network (at the center, Figure 4b) receive few updates, as expected. When long skip connections are retained, at least the shallow parts of the model can be updated (see both sides of Figure 4b), as these connections provide shortcuts for gradient flow. Interestingly, we observed that model performance actually drops when short skip connections are used in models that are shallow enough for all layers to be well updated (e.g. Figure 4c). Moreover, batch normalization was observed to increase the maximal updatable depth of the network. Networks without batch normalization had diminishing updates toward the center of the network, and those with long skip connections were less stable, requiring a lower learning rate (e.g. Figure 4d).

It is also interesting to observe that the bulk of the updates in all tested model variants (including those shown in Figure 4) was always initially near or at the classification layer. This is in line with the findings of [17], where it is shown that even randomly initialized weights can confer a surprisingly large portion of a model's performance after training only the classifier.

Figure 4: Weight updates in different network setups: (a) the best performing model with long and short skip connections enabled, (b) only long skip connections enabled with 9 repetitions of simple block, (c) only long skip connections enabled with 3 repetitions of simple block and (d) only long skip connections enabled with 7 repetitions of simple block, without batch normalization. Note that due to a reduction in the learning rate for Figure (d), the scale is different compared to Figures (a), (b) and (c).

4 Conclusions

In this paper, we studied the influence of skip connections on FCNs for biomedical image segmentation. We showed that a very deep network can achieve results near the state of the art on the EM dataset without any further post-processing. We also confirmed that, although long skip connections provide a shortcut for gradient flow to the shallow layers, they do not alleviate the vanishing gradient problem in the deeper parts of the network. Consequently, we applied short skip connections to FCNs and confirmed that this increases convergence speed and allows the training of very deep networks.

Acknowledgements

We would like to thank all the developers of Theano and Keras for providing such powerful frameworks. We gratefully acknowledge NVIDIA for GPU donation to our lab at École Polytechnique. The authors would like to thank Lisa di Jorio, Adriana Romero and Nicolas Chapados for insightful discussions. This work was partially funded by Imagia Inc., MITACS (grant number IT05356) and MEDTEQ.

References