Shape analysis of anatomical structures is of core importance for many tasks in medical imaging, not only as a regularization prior for segmentation tasks, but also as a powerful tool to assess differences between subjects and populations. A fundamental question when operating on shapes is to find a suitable numerical representation for a given task. Hence, many different types of parameterizations have been proposed in the past including point distribution models [Cootes1995], spectral signatures [Wachinger2015], spherical harmonics [gerardin2009multidimensional], medial representations [Gorczowski2007], and diffeomorphisms [miller2014diffeomorphometry]. Even though these representations have proven their utility for the analysis of shapes in the medical domain, they might not be optimal for a particular task.
In recent years, deep networks have been highly successful in many medical imaging tasks by learning complex, hierarchical feature representations from images. These representations have been shown to outperform hand-crafted features in a variety of medical imaging applications [Litjens2017]. One of the main reasons for the success of these methods is the use of convolutional layers, which take advantage of the shift-invariance properties of images [Bronstein2017]. However, the use of deep networks in medical shape analysis is still largely unexplored, mainly because typical shape representations such as point clouds and meshes do not possess an underlying Euclidean or grid-like structure.
In this work, we propose an alternative approach to perform supervised learning on medical shape data. Our method is based on PointNet [Qi2017], a deep neural network architecture that operates directly on a point cloud and predicts a label in an end-to-end fashion. Point clouds present a raw and simple parameterization that avoids the complexities involved with meshes and that is trivial to obtain given a segmented surface. The network does not require the alignment of point clouds, as a spatial transformer network maps the data to a canonical space before further processing. PointNet has been proposed for object classification, where the category of a single shape is predicted. For many medical applications, however, a single anatomical structure is not sufficient for the prediction; a simultaneous view of multiple structures is required for a more comprehensive analysis of a subject’s anatomy. Hence, we propose the Multi-Structure PointNet (MSPNet), which is able to predict a label given the shapes of multiple structures simultaneously. We evaluate MSPNet in two neuroimaging applications, neurodegenerative disease prediction and age regression.
1.1 Related Work
Several shape representations have previously been used for supervised learning tasks. Spherical harmonics for approximating the hippocampal shape have been proposed in [gerardin2009multidimensional]. Shape information has been derived from thickness measurements of the hippocampus from a medial representation [costafreda2011automated]. Statistical shape models to detect hippocampal shape changes were proposed by [shen2012detecting]. Multi-resolution shape features with non-Euclidean wavelets were employed for the analysis of cortical thickness [Kim2014107]. Medial axis shape representations were used to compare the brain morphology of autistic and normal populations [Gorczowski2007]. Recently, shape representations based on spectral signatures have been introduced to perform age regression and disease prediction [Wachinger2015, wachinger2016domain].
All of the approaches mentioned above rely on computing pre-defined shape features. Alternatively, a variational auto-encoder was proposed to automatically extract features from 3D surfaces, which can in turn be used in a classification task [Shakeri2016]. In contrast to our approach, however, this is not end-to-end learning, since the variational encoder is not directly linked to the classification task. Consequently, the learned features capture overall variation but are not directly optimized for the given task. In addition, this approach relies on computing point correspondences between meshes and focuses on a single structure, while we simultaneously model multiple structures.
We propose a method for multi-structure shape analysis that is divided into two main stages: the extraction of point clouds representing the anatomy of different structures from medical images (section 2.1), and a Multi-Structure PointNet (MSPNet) (section 2.2). Figure 1 illustrates the architecture of MSPNet, which is based on PointNet [Qi2017] and extends it to allow the simultaneous processing of multiple structures.
2.1 Point Cloud Extraction
We extract point clouds from T1-weighted MRI images of the brain. We process the images with the FreeSurfer pipeline [Fischl2012] and obtain segmentations of multiple neuroanatomical regions. From the resulting segmentations, point clouds are created by uniformly sampling the boundary of each brain structure. After this process, the anatomy of a subject is represented by a collection of point clouds $\mathcal{S} = \{P_1, \ldots, P_M\}$, where each point cloud $P_j$ represents a structure. A point cloud is defined as a set of $n$ points $P_j = \{p_1, \ldots, p_n\}$, where each point $p_i = [x_i, y_i, z_i]$ is a vector of Cartesian coordinates.
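The sampling step can be illustrated with a minimal sketch. It assumes the segmentation of one structure is available as a binary 3D NumPy array; the function names, the 6-neighborhood boundary definition, the `voxel_size` parameter, and the fixed random seed are illustrative assumptions and not details taken from the FreeSurfer pipeline or from the method above.

```python
import numpy as np

def boundary_voxels(mask):
    """Return (m, 3) coordinates of foreground voxels that touch the background.

    A voxel counts as boundary if at least one of its 6 face neighbors is
    background (an assumed definition, chosen for simplicity).
    """
    mask = np.asarray(mask, dtype=bool)
    # Pad with background so voxels at the volume border are handled correctly.
    padded = np.pad(mask, 1, mode="constant", constant_values=False)
    interior = np.ones_like(mask)
    for axis in range(3):
        for shift in (-1, 1):
            neighbor = np.roll(padded, shift, axis=axis)[1:-1, 1:-1, 1:-1]
            interior &= neighbor
    return np.argwhere(mask & ~interior)

def sample_point_cloud(mask, n_points=512, voxel_size=(1.0, 1.0, 1.0), seed=0):
    """Uniformly sample n_points boundary voxels and scale them to physical units."""
    rng = np.random.default_rng(seed)
    coords = boundary_voxels(mask)
    if len(coords) == 0:
        raise ValueError("mask contains no foreground voxels")
    # Sample with replacement only if the boundary has fewer voxels than requested.
    replace = len(coords) < n_points
    idx = rng.choice(len(coords), size=n_points, replace=replace)
    return coords[idx] * np.asarray(voxel_size, dtype=float)

# Toy usage: a cube standing in for a segmented structure.
toy_mask = np.zeros((16, 16, 16), dtype=bool)
toy_mask[3:13, 3:13, 3:13] = True
toy_cloud = sample_point_cloud(toy_mask, n_points=512)
```

Applying such a function once per segmented structure would yield one point cloud per structure, i.e. one collection per subject.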
2.2 MSPNet Architecture
We aim to find a network architecture corresponding to a function $f$ that maps a collection of shapes described by $\mathcal{S}$ to a prediction $y$. An overview of the network is shown in figure 1. MSPNet consists of multiple branches, where each branch processes the point cloud of one structure independently. This allows a dedicated feature representation to be learned for each structure. At the end, the features of all branches are merged to perform a joint prediction. Each branch can be divided into the following stages: 1) point cloud alignment using a transformation network, 2) feature extraction, 3) feature alignment with a second transformation network, 4) dropout and 5) prediction. The first three stages of each branch resemble those of a single PointNet architecture. The last two stages are particular to MSPNet.
Point Transformation Network: In contrast to previous approaches in deep medical shape analysis [Shakeri2016], MSPNet does not require point correspondences across shapes, i.e., the $i$-th points $p_i$ and $p'_i$ of two shapes $P$ and $P'$, respectively, do not need to represent the same anatomical position. We obtain invariance to rigid transformations in MSPNet by (i) augmenting the training dataset by applying a random rigid transformation to each shape during training and by (ii) introducing a transformation network (T-Net). This network estimates a $3 \times 3$ transformation matrix, which is applied to the input as a first step. One can think of the T-Net as a transformation into a canonical space that roughly aligns the point clouds before any further processing is done. The T-Net is shown in figure 2 and is composed of a multilayer perceptron (MLP), a max pooling operator and two fully connected layers.
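The augmentation in (i) can be sketched as follows. The ranges of the random rotation and translation are not specified above, so the uniformly random rotation and the `max_shift` parameter are assumptions made only for illustration.

```python
import numpy as np

def random_rotation(rng):
    """Draw a random 3x3 rotation matrix via QR decomposition of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    # Resolve the sign ambiguity of the QR factorization.
    q = q * np.sign(np.diag(r))
    # Flip one axis if needed so that the result is a proper rotation (det = +1).
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q

def random_rigid_transform(points, rng, max_shift=5.0):
    """Apply a random rotation and translation to an (n, 3) point cloud.

    max_shift is a hypothetical parameter bounding the translation per axis.
    """
    rotation = random_rotation(rng)
    shift = rng.uniform(-max_shift, max_shift, size=3)
    return points @ rotation.T + shift

# Toy usage: augment one point cloud at training time.
rng = np.random.default_rng(0)
cloud = rng.normal(size=(512, 3))
augmented = random_rigid_transform(cloud, rng)
```

A rigid transformation preserves all pairwise distances between points, so the shape itself is unchanged and only its pose varies between training iterations.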
Feature Extraction: The transformed points are fed into an MLP with weights shared among points. This MLP can be thought of as the feature extraction stage of the network. At this stage, each point has access to the position of all the remaining points of the point cloud, and therefore, as the output of this stage, we obtain a $k$-dimensional feature vector for each point (in our case $k = 64$). Although each point is assigned a single feature vector, in practice each per-point feature vector contains a global signature of the input point cloud.
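Weight sharing among points can be illustrated by applying the same dense layers to every row of the point matrix. The sketch below is a simplification under stated assumptions: the layer sizes (3 to 64 to 64) are chosen for illustration, the weights are random rather than learned, and batch normalization is omitted.

```python
import numpy as np

def shared_mlp(points, weights, biases):
    """Apply the same dense layers with ReLU to every point of an (n, d) array."""
    h = points
    for w, b in zip(weights, biases):
        h = np.maximum(h @ w + b, 0.0)  # identical weights for all n points
    return h

# Toy usage with random (untrained) weights.
rng = np.random.default_rng(0)
weights = [rng.normal(size=(3, 64)), rng.normal(size=(64, 64))]
biases = [np.zeros(64), np.zeros(64)]
cloud = rng.normal(size=(512, 3))
features = shared_mlp(cloud, weights, biases)  # one 64-dimensional vector per point
```

Because the weights are shared, reordering the input points only reorders the rows of the output, which is why no fixed point ordering is needed at this stage.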
Feature Transformation: A second T-Net is applied to the computed features. This network has the same properties as the first transformation network, but its output corresponds to a $64 \times 64$ transformation matrix. This matrix has a much higher dimension than the previous spatial transformation, which makes the optimization more challenging. To facilitate the optimization of this larger feature transformation matrix, we constrain it to be close to an orthogonal matrix, similar to [Qi2017]. This regularization term leads to a more stable convergence of the network. After the points are transformed, they are fed to an MLP layer.
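In PointNet [Qi2017], the orthogonality constraint takes the form of the penalty $L_{reg} = \| I - A A^T \|_F^2$, where $A$ is the predicted feature transformation matrix. The following sketch computes this term with NumPy; the weight with which the penalty enters the total loss is not stated above and is therefore left out.

```python
import numpy as np

def orthogonality_penalty(a):
    """Squared Frobenius norm of I - A A^T for a square matrix A."""
    k = a.shape[0]
    residual = np.eye(k) - a @ a.T
    return float(np.sum(residual ** 2))

# The penalty vanishes for an orthogonal matrix and grows as A departs from one.
identity_penalty = orthogonality_penalty(np.eye(64))
scaled_penalty = orthogonality_penalty(2.0 * np.eye(64))
```

The penalty is zero exactly when the rows of $A$ are orthonormal, so minimizing it keeps the feature transformation close to a rotation or reflection of the feature space.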
Dropout and Prediction: Up to this point, the architecture of each branch mirrors that of PointNet. However, the final dropout and prediction stage is particular to MSPNet. In PointNet, the last stage corresponds to a max-pooling layer performed across the $n$ points, so that the output is a vector whose size corresponds to the feature dimensionality. Instead of performing max-pooling, which leads to a strong shrinkage in feature space, we propose to keep the localized information per point. This leads to an increase in the network capacity, which may lend itself to overfitting. Hence, we introduce a dropout layer (keep probability = 0.3) for regularization. The main advantage of the new design is that more localized information is retained in the network, which we hypothesize may boost the predictive power of our network. Finally, the individual features from each branch are concatenated and fed into a last MLP to perform the prediction. Batch normalization and ReLU activations are used for all MLP layers. The last MLP includes intermediate dropout layers with a keep probability of 0.7, as in PointNet. To facilitate the exposition, we assumed that the structure in each branch is described by the same number of points $n$, but in practice each structure can be represented by a point cloud of a different size.
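The difference between the two readouts can be sketched in terms of array shapes. This is one possible reading of the description above, not the authors' implementation: flattening the per-point features, the feature dimension of 16, and the inverted-dropout scaling are assumptions introduced for illustration.

```python
import numpy as np

def pointnet_readout(features):
    """Max-pool across points: (n, k) -> (k,)."""
    return features.max(axis=0)

def mspnet_readout(features, keep_prob=0.3, rng=None):
    """Keep per-point features: (n, k) -> (n * k,), with dropout during training.

    Passing rng=None disables dropout (inference mode).
    """
    flat = features.reshape(-1)
    if rng is None:
        return flat
    mask = rng.random(flat.shape) < keep_prob
    return flat * mask / keep_prob  # inverted dropout keeps the expected value

def concatenate_branches(branch_features):
    """Merge the feature vectors of all structure branches into one vector."""
    return np.concatenate(branch_features)

# Toy usage: 4 structures, 512 points each, 16 features per point (assumed).
rng = np.random.default_rng(0)
branches = [rng.random((512, 16)) for _ in range(4)]
pooled = pointnet_readout(branches[0])
merged = concatenate_branches([mspnet_readout(f, rng=rng) for f in branches])
```

In this sketch, max-pooling reduces a branch to 16 values, while the per-point readout keeps 512 x 16 values per branch, which makes the increase in capacity and the need for dropout concrete.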
We evaluate the performance of MSPNet in two supervised learning tasks, classification and regression. For the classification task, we aim at using shape descriptors to discriminate between healthy controls (HC), and patients diagnosed with mild cognitive impairment (MCI) or Alzheimer’s disease (AD). For the regression task, we perform age estimation of a subject based on shape information. In all our experiments, we compare to the standard PointNet architecture and spectral shape descriptors in BrainPrint [Wachinger2015], which achieved high performance in a competition for Alzheimer’s disease classification [wachinger2016domain]. For PointNet, the multi-structure input corresponds to a concatenation of the point clouds of all structures. We use image data from the Alzheimer’s Disease Neuroimaging Initiative (ADNI) database (adni.loni.usc.edu) [Jack2008]. We work with a total of 7,974 images (2,423 HC, 978 AD, and 4,625 MCI).
3.1 AD and MCI Classification on Shape Data
For this experiment, we perform classification based on the shape of the left and right hippocampus and the left and right lateral ventricles, due to their key importance in Alzheimer’s disease [thompson2004mapping]. Each structure is represented by a point cloud of 512 points. For our experiments, the dataset is split into a training, validation and test set (75%, 15%, 15%). Splitting is done on a per-subject basis to guarantee that the same subject does not appear in different sets. Table 1 reports the results of the classification experiment in terms of average classification precision, recall and F1-score. In both classification scenarios, PointNet shows a higher accuracy than BrainPrint, illustrating the potential of learning feature representations. Further, MSPNet shows the best performance, highlighting the benefit of individual feature learning in each branch of the network.
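Per-subject splitting can be sketched as follows: the subjects, rather than the images, are shuffled and assigned to a set, and every image inherits the set of its subject. The function name, the seed, and the fractions used here are illustrative placeholders and not the values of the experiment.

```python
import random

def split_by_subject(subject_ids, train_frac=0.7, val_frac=0.15, seed=0):
    """Return image indices per set so that no subject appears in two sets.

    subject_ids holds one subject identifier per image. Subjects not assigned
    to the training or validation set go to the test set.
    """
    subjects = sorted(set(subject_ids))
    random.Random(seed).shuffle(subjects)
    n_train = int(round(train_frac * len(subjects)))
    n_val = int(round(val_frac * len(subjects)))
    assignment = {}
    for position, subject in enumerate(subjects):
        if position < n_train:
            assignment[subject] = "train"
        elif position < n_train + n_val:
            assignment[subject] = "val"
        else:
            assignment[subject] = "test"
    splits = {"train": [], "val": [], "test": []}
    for index, subject in enumerate(subject_ids):
        splits[assignment[subject]].append(index)
    return splits

# Toy usage: 20 hypothetical subjects with 3 images each.
toy_ids = ["subject_%02d" % (i // 3) for i in range(60)]
toy_splits = split_by_subject(toy_ids)
```

Splitting at the image level instead would let scans of the same subject leak from the training set into the test set, which would inflate the reported performance.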
3.2 Age Prediction on Shape Data
For the age estimation task, we perform two different evaluations. In the first one, we perform age estimation only on the healthy controls of the ADNI database. For the second evaluation, we also include patients diagnosed with MCI and AD. The evaluations are again done on the same brain structures used for the classification task. The results of these two experiments are summarized in the mean absolute error plots of figure 3. For the experiment on HC, MSPNet significantly outperformed BrainPrint (p-value ) and PointNet (p-value 0.03). In the experiment on all subjects, PointNet and MSPNet showed comparable performance, both outperforming BrainPrint (p-value 0.01).
3.3 Visualizing Point Importance
Of key importance for making predictions with shapes is the ability to visualize the part of the anatomy that is driving the decision. This holds in particular in the clinical context. In MSPNet, we introduce a simple yet effective method to visualize the importance that each point has in the prediction. Our visualization is inspired by the commonly used occlusion method [Grun2016], which consists of occluding parts of a test image and observing differences in the network response. We apply a similar concept to visualize the response of MSPNet. In our case, we assess the importance of each point in the classification task by occluding this point (setting its coordinates to 0) together with its nearest neighbors. The occluded point cloud is then passed through the network, and the response of the output ReLU is compared to that obtained when the full point cloud is evaluated. The difference between these responses is assigned as the importance of this particular point. Figure 4 shows the result of this visualization technique for one of the AD test subjects in the HC-AD classification experiment. If a point tends towards the red side of the scale, occluding this particular point increases the activation of the AD class in the network. This means that the region around this point is used by the network to predict AD. The exact opposite is true for points on the blue side of the scale. White points indicate that the network response was not largely affected by occluding this point. In the particular case of the example in figure 4, the decision of the network to give this subject an AD label is mainly driven by the left hippocampus.
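The occlusion procedure described above can be sketched as a loop over points. Here `predict` is a hypothetical callable that returns the network response for one class as a scalar, and `n_neighbors` is an assumed parameter, since the number of occluded neighbors is not stated above.

```python
import numpy as np

def point_importance(points, predict, n_neighbors=8):
    """Importance of each point as the change in response when it is occluded.

    points: (n, 3) array; predict: callable mapping an (n, 3) array to a scalar.
    """
    baseline = predict(points)
    importance = np.zeros(len(points))
    for i in range(len(points)):
        distances = np.linalg.norm(points - points[i], axis=1)
        # The point itself (distance 0) plus its nearest neighbors.
        occluded_idx = np.argsort(distances)[: n_neighbors + 1]
        occluded = points.copy()
        occluded[occluded_idx] = 0.0  # occlusion: coordinates set to 0
        importance[i] = predict(occluded) - baseline
    return importance

# Toy usage with a stand-in for the network response.
def toy_predict(cloud):
    return float(cloud[:, 0].sum())

toy_cloud = np.column_stack([np.ones(20), np.arange(20.0), np.zeros(20)])
toy_importance = point_importance(toy_cloud, toy_predict, n_neighbors=3)
```

Positive values correspond to points whose occlusion increases the response (the red side of the scale), and negative values to points whose occlusion decreases it (the blue side). The loop requires one forward pass per point, so for large point clouds it may be worth batching the occluded inputs.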
We introduced MSPNet, a deep neural network for shape analysis on multiple brain structures.
To the best of our knowledge, this is the first time that a neural network for shape analysis on point clouds is proposed in medical applications. We have shown that our method is able to achieve high accuracy in both classification and regression tasks, when compared to shape descriptors based on spectral signatures. This performance is achieved without relying on point correspondences or meshes. MSPNet learns feature representations from multiple structures simultaneously. Finally, we illustrated point-wise importance for the prediction by adapting the occlusion method.
Acknowledgments. This work was supported in part by SAP SE and the Bavarian State Ministry of Education, Science and the Arts in the framework of the Centre Digitisation.Bavaria (ZD.B).