Differentiable Graph Module (DGM) for Graph Convolutional Networks

by Anees Kazi, et al.

Graph deep learning has recently emerged as a powerful ML concept that generalizes successful deep neural architectures to non-Euclidean structured data. Such methods have shown promising results on a broad spectrum of applications ranging from social science, biomedicine, and particle physics to computer vision, graphics, and chemistry. A limitation of most current graph neural network architectures is that they are often restricted to the transductive setting and rely on the assumption that the underlying graph is known and fixed. In many settings, such as those arising in medical and healthcare applications, this assumption does not necessarily hold, since the graph may be noisy, partially or even completely unknown, and one is thus interested in inferring it from the data. This is especially important in inductive settings, when dealing with nodes not present in the graph at training time. Furthermore, the graph itself may sometimes convey insights that are even more important than the downstream task. In this paper, we introduce the Differentiable Graph Module (DGM), a learnable function that predicts the edge probabilities of the graph relevant for the task, which can be combined with convolutional graph neural network layers and trained in an end-to-end fashion. We provide an extensive evaluation on applications from the domains of healthcare (disease prediction), brain imaging (gender and age prediction), computer graphics (3D point cloud segmentation), and computer vision (zero-shot learning). We show that our model provides a significant improvement over baselines in both transductive and inductive settings and achieves state-of-the-art results.



1 Introduction

Geometric deep learning (GDL) is a novel emerging branch of deep learning attempting to generalize deep neural networks to non-Euclidean structured data such as graphs and manifolds (Bronstein et al., 2017; Hamilton et al., 2017b; Battaglia et al., 2018). Graphs in particular, being very general abstract descriptions of relation and interaction systems, are ubiquitous in different branches of science. Graph-based learning models have been successfully applied to social networks, link prediction (Zhang & Chen, 2018), human-object interaction (Qi et al., 2018), computer vision (Qi et al., 2017) and graphics (Monti et al., 2017; Wang et al., 2019), particle physics (Choma et al., 2018), chemistry (Duvenaud et al., 2015; Gilmer et al., ; Li et al., 2018b), medicine (Parisot et al., 2018, 2017; Mellema et al., 2019; Kazi et al., 2019b), drug repositioning (Zitnik et al., 2018), and protein science (Gainza et al., 2019), to mention a few.

Early formulations of learning on graphs date back to the seminal work of Scarselli et al. (2008). Bruna et al. (2013) proposed formulating convolution-like operations in the spectral domain, defined by the eigenvectors of the graph Laplacian operator, which provide a non-Euclidean analogue of the Fourier transform. More efficient and generalizable spectral graph CNNs were developed using polynomial (Defferrard et al., 2016; Kipf & Welling, 2016) or rational (Levie et al., 2018; Bianchi et al., 2019) spectral filters. Monti et al. (2017) proposed an analogue of 'patches' on graphs based on local weighting. A similar model was employed in graph attention networks (Veličković et al., 2017). The graph attention mechanism was extended in follow-up papers (Kondor, 2018; Bruna & Li, 2017; Monti et al., 2018). (Gilmer et al., ) formulated graph neural networks in terms of message passing, and Hamilton et al. (2017a) developed efficient mechanisms for learning on large-scale graphs.

We note that most graph neural networks assume that the underlying graph is given and fixed, and graph convolution-like operations typically amount to modifying the node-wise features. Architectures like message passing neural networks (Gilmer et al., ) or primal-dual convolutions (Monti et al., 2018) also allow updating the edge features, but the graph topology is always kept the same. This is often a limiting assumption. In many problems the data can be assumed to have some underlying graph structure, but the graph itself might not be explicitly given (Liu et al., 2012). This is the case, for example, in medical and healthcare applications, where the graph may be noisy, partially or even completely unknown, and one is thus interested in inferring it from the data. This is especially important in inductive settings, where some nodes might be present in the graph at test time but not at training time. Furthermore, the graph may sometimes be even more important than the downstream task, as it provides some interpretability of the model.

Several geometric models that learn the graph have recently been studied. Kipf et al. (2018) proposed a variational autoencoder in which the latent code represents the interaction graph underlying a physical system, and the reconstruction is based on graph neural networks. Wang et al. (2019) proposed dynamic graph CNNs for the analysis of point clouds, where a kNN graph is constructed on the fly in the feature space of the neural network. Zhan et al. (2018) proposed constructing multiple Laplacians and learning to weight them during optimization. Similarly, Li et al. (2018a) proposed a spectral graph convolutional method in which a residual Laplacian is computed on the output features of each layer and the input Laplacian is updated after each layer. Both methods learn the graph through Laplacians but still need an initial graph. Huang et al. (2018) proposed another version of spectral filters that parametrize the Laplacian instead of the filter coefficients.

In this paper, we propose a generalized technique for learning the graph based on the output features of each layer and for optimizing these graphs along with the network parameters during training. The main obstacle to including graph construction as part of the deep learning pipeline is that the graph, being a discrete structure, is non-differentiable. Inspired by Plötz & Roth (2018), we propose a technique that enables backpropagation through the graph. The main idea is to use a continuous relaxation of neighborhood selection rules such as kNN, thus allowing the output to be differentiated with respect to the edges of the graph. To avoid the use of a pre-fixed graph, we build a kNN graph on the input feature representation of the nodes, separately for each layer (Wang et al., 2019).

In the subsequent sections, we describe our model and extensively evaluate it on applications from the domains of healthcare (disease prediction), brain imaging (gender and age prediction), computer graphics (3D point cloud segmentation), and computer vision (zero-shot learning). Our model shows significant improvement over baselines both in transductive and inductive settings and achieves state-of-the-art results.

Figure 2: Forward propagation rule using the proposed DGM. We show the model architecture of two consecutive layers and the flow from input features to the predicted edges.

2 Method

Given a set of input nodes $V = \{v_1, \dots, v_N\}$ and associated features $X \in \mathbb{R}^{N \times d}$, our goal is to discover the underlying latent graph structure $G = (V, E)$ in order to enable the use of graph convolutional operators for learning classification tasks. A graph is by construction a discrete structure, where an edge linking two nodes is either present or absent. This makes $G$ non-differentiable with respect to its edge set $E$, so it cannot be optimized directly with gradient-descent-based techniques.

To overcome this limitation, we replace the edge set $E$ with a weighted adjacency matrix $P$, where $p_{ij}$ is interpreted as the probability that $(i, j) \in E$. The probability is computed in a separate feature space $\hat X = f_\Theta(X)$, designated as the graph representation, where $f_\Theta$ is a learnable function. A graph constructed this way can then be sampled according to $P$ and used in any graph convolutional layer for node representation learning.

In the following subsections we introduce the Differentiable Graph Module (DGM), propose a general architecture that exploits DGM for node-wise classification, and show how $f_\Theta$ can be optimized in a task-driven fashion.

2.1 Differentiable Graph Module (DGM)

As shown in figure 1, DGM takes the node features $X$ and the set of edges $E$ (if available) as input, and outputs a new set of edges $E'$. We divide the operation of DGM into three parts.

Graph representation feature learning.

The learnable part of our DGM consists of a parametric function $f_\Theta$ that transforms the input features $X$ into the features $\hat X = f_\Theta(X)$ used for graph representation, as explained in the next paragraph. In principle, $f_\Theta$ can be any non-linear function, such as a small neural network (MLP), or a graph convolution operator if an input graph is provided.

Probabilistic graph generator.

The probabilistic graph generator (shown in figure 1) initially assumes a fully connected graph and computes the probability

$$p_{ij} = e^{-t\,\lVert \hat x_i - \hat x_j \rVert^2} \tag{1}$$

of the edge $(i, j)$. Here $t$ is an optimized temperature parameter and $\hat x_i$ is the $i$-th row of $\hat X = f_\Theta(X)$, the output of $f_\Theta$. Such a continuous modeling of $P$ allows backpropagation of the gradients through the neighborhood selection.

Our choice of a Euclidean embedding in eq. 1 for defining the edge probability reduces complexity compared with other architectural choices, for instance an MLP that takes the features of two nodes as input and predicts their edge probability (Jang et al., 2019). We note that other spaces, e.g. with hyperbolic geometry (Krioukov et al., 2010), could also be used.
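As an illustration of eq. 1, the sketch below computes the dense edge-probability matrix with NumPy. It assumes the graph features produced by $f_\Theta$ are already available as an array and treats the temperature $t$ as a fixed number, whereas in DGM it is learned; the function name `edge_probabilities` is hypothetical. It also makes the quadratic cost discussed in Section 4.2 visible, since the matrix has $N^2$ entries.

```python
import numpy as np


def edge_probabilities(x_hat: np.ndarray, t: float) -> np.ndarray:
    """Dense matrix with p_ij = exp(-t * ||x_hat_i - x_hat_j||^2) (eq. 1).

    x_hat: (N, D) graph-representation features, i.e. the output of f_Theta.
    t: temperature; a fixed scalar here, a learned parameter in DGM.
    """
    sq_norms = np.sum(x_hat ** 2, axis=1)
    # Squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b, for all N^2 pairs.
    d2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * x_hat @ x_hat.T
    d2 = np.maximum(d2, 0.0)  # clip tiny negative values caused by round-off
    return np.exp(-t * d2)


x_hat = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
P = edge_probabilities(x_hat, t=0.5)
print(np.round(P, 4))
```

Note that $p_{ii} = 1$ on the diagonal, so every node is its own most probable neighbour; whether such self-edges are kept is a design choice the text above does not specify.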

Graph sampling

From the estimated edge probability matrix $P$, we then sample a graph with fixed degree $k$. We use the Gumbel-Top-$k$ trick (Kool et al., 2019) to sample from the unnormalized probability distribution defined in equation 1, which makes the sampling a stochastic relaxation of the kNN rule.

Let $\mathbf{p}_i = (p_{i1}, \dots, p_{iN})$ be the unnormalized probability distribution of the ingoing edges of node $i$. We extract $k$ edges according to the indices of the $k$ largest elements of $\log \mathbf{p}_i - \log(-\log \mathbf{q})$, where the entries of $\mathbf{q} \in \mathbb{R}^N$ are drawn uniformly i.i.d. in the interval $[0, 1]$. Kool et al. (2019) prove that the edges extracted this way are a sample without replacement from the categorical distribution $p_{ij} / \sum_r p_{ir}$. We denote the new set of extracted edges by $E'$. Finally, the DGM outputs the unweighted graph $G' = (V, E')$.
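The Gumbel-Top-$k$ extraction can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the authors' code: the function name is hypothetical, a small constant guards against `log(0)`, and the diagonal of $P$ is not masked. Taking the $k$ largest perturbed scores of each row yields $k$ distinct neighbours per node.

```python
import numpy as np


def gumbel_top_k_edges(P: np.ndarray, k: int, rng: np.random.Generator) -> np.ndarray:
    """Sample k ingoing edges per node with the Gumbel-Top-k trick.

    P: (N, N) unnormalized edge probabilities; row i is the distribution p_i.
    Returns an (N, k) integer array whose row i lists the sampled neighbours j of node i.
    """
    eps = 1e-12  # guards against log(0)
    q = rng.uniform(eps, 1.0, size=P.shape)  # q ~ U(0, 1), i.i.d.
    # Perturbed scores log(p_ij) - log(-log(q_ij)): log-probability plus Gumbel noise.
    scores = np.log(P + eps) - np.log(-np.log(q))
    # Indices of the k largest scores in each row (a sample without replacement).
    return np.argsort(-scores, axis=1)[:, :k]


rng = np.random.default_rng(0)
P = np.array([[1.0, 0.6, 0.1], [0.6, 1.0, 0.08], [0.1, 0.08, 1.0]])
edges = gumbel_top_k_edges(P, k=2, rng=rng)
print(edges)  # shape (3, 2); the exact indices depend on the random seed
```

With $k = 1$, the selected index of a row follows $p_{ij} / \sum_r p_{ir}$, which is easy to confirm empirically by sampling many rows.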

2.2 Forward propagation rule

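As shown in figure 2, at each layer $l$, DGM is used as a block to learn the graph $G^{(l)}$, which is then passed as input to a separate 'GraphConv' operation that learns the node representations for the task at hand. We compute the output features at layer $l$ as

$$\hat X^{(l)} = f_{\Theta^{(l)}}\big(\tilde X^{(l)}, G^{(l-1)}\big), \qquad G^{(l)} = \big(V, E'^{(l)}\big), \tag{2}$$
$$X^{(l+1)} = g_{\Phi^{(l)}}\big(X^{(l)}, G^{(l)}\big), \tag{3}$$

where $\Theta^{(l)}$ and $\Phi^{(l)}$ are the learned parameters of the non-linear functions $f$ and $g$, respectively, and $E'^{(l)}$ is the edge set sampled, as described in Section 2.1, from the probabilities of eq. 1 computed on $\hat X^{(l)}$. $\tilde X^{(l)}$ are the features given as input to DGM for graph representation.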

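In its simplest implementation $\tilde X^{(l)} = X^{(l)}$, meaning that we use the same input features for both graph and node representation. We instead propose to use the concatenation of the previous graph and node representation features for $\tilde X^{(l)}$:

$$\tilde X^{(l)} = \big[\hat X^{(l-1)},\ X^{(l)}\big].$$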
The final node features of the last layer are then used to generate the predictions.
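To make the wiring of the forward rule and the concatenated DGM input concrete, here is a toy two-layer forward pass in NumPy. Several simplifications are assumptions of this sketch rather than the method as described: $f$ is the identity in both layers (the text uses the identity only in the first layer), the temperature is fixed, and `g_mean` is a hypothetical mean-aggregation update standing in for EdgeConv. Working directly with $\log p_{ij} = -t\lVert \hat x_i - \hat x_j\rVert^2$ avoids underflow in the exponential.

```python
import numpy as np

rng = np.random.default_rng(0)


def dgm(x_tilde, t, k):
    """Toy DGM block: x_hat = f(x_tilde) with f = identity, then Gumbel-Top-k on eq. 1."""
    x_hat = x_tilde
    d2 = ((x_hat[:, None, :] - x_hat[None, :, :]) ** 2).sum(axis=-1)
    log_p = -t * d2  # log of eq. 1
    q = rng.uniform(1e-12, 1.0, size=log_p.shape)
    nbrs = np.argsort(-(log_p - np.log(-np.log(q))), axis=1)[:, :k]
    return x_hat, nbrs


def g_mean(x, nbrs, W):
    """Hypothetical node update g: mean over sampled neighbours, linear map, ReLU."""
    return np.maximum(x[nbrs].mean(axis=1) @ W, 0.0)


N, d, h, k = 10, 4, 8, 3
X = rng.normal(size=(N, d))
W1, W2 = rng.normal(size=(d, h)), rng.normal(size=(h, h))

# Layer 1: no previous graph features, so the DGM input is just X.
x_hat1, nbrs1 = dgm(X, t=1.0, k=k)
X1 = g_mean(X, nbrs1, W1)

# Layer 2: the DGM input is the concatenation [x_hat^(l-1), X^(l)].
x_tilde2 = np.concatenate([x_hat1, X1], axis=1)
x_hat2, nbrs2 = dgm(x_tilde2, t=1.0, k=k)
X2 = g_mean(X1, nbrs2, W2)
print(x_tilde2.shape, X2.shape)  # (10, 12) (10, 8)
```

The graph of layer 2 is thus built from features that mix what the first DGM learned with the updated node features, which is the motivation given above for concatenating rather than reusing $X^{(l)}$ alone.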

Unless otherwise specified, we use 'EdgeConv' proposed by Wang et al. (2019) for both $f$ and $g$, since it is the natural choice for our fixed $k$-degree graph sampling:

$$x_i' = \mathop{\square}_{j:(i,j)\in E'} h_\Theta\big(x_i,\ x_j - x_i\big), \tag{4}$$

where $\square$ is a permutation-invariant aggregation operation and $h_\Theta$ is a non-linear function with learnable parameters $\Theta$. In the first DGM layer, where no graph is available, we simply set $f_\Theta$ as the identity function, letting the network learn only the temperature parameter $t$.
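A minimal NumPy sketch of the EdgeConv update on a fixed-degree sampled graph follows. Two choices here are assumptions: the permutation-invariant aggregation is taken to be `max` (the usual choice in DGCNN's EdgeConv), and the edge function is a single linear layer with ReLU applied to the concatenation $[x_i, x_j - x_i]$; the function name is hypothetical. Because every node has exactly $k$ sampled neighbours, the neighbour features form a dense `(N, k, d)` array, which is why EdgeConv pairs naturally with this sampling.

```python
import numpy as np


def edge_conv(x: np.ndarray, nbrs: np.ndarray, W: np.ndarray, b: np.ndarray) -> np.ndarray:
    """EdgeConv-style update x'_i = max_{j in N(i)} h([x_i, x_j - x_i]).

    x: (N, d) node features; nbrs: (N, k) sampled neighbour indices (fixed degree k).
    h is assumed to be one linear layer + ReLU with W of shape (2d, d_out) and b of shape (d_out,).
    """
    x_i = np.repeat(x[:, None, :], nbrs.shape[1], axis=1)  # (N, k, d) centre features
    x_j = x[nbrs]                                           # (N, k, d) neighbour features
    edge_feat = np.concatenate([x_i, x_j - x_i], axis=-1)   # (N, k, 2d)
    messages = np.maximum(edge_feat @ W + b, 0.0)           # h applied to every edge
    return messages.max(axis=1)                             # max over neighbours


x = np.array([[1.0], [3.0]])
nbrs = np.array([[1], [0]])
W = np.array([[1.0], [1.0]])  # maps [x_i, x_j - x_i] to x_i + (x_j - x_i) = x_j
b = np.zeros(1)
print(edge_conv(x, nbrs, W, b))  # [[3.] [1.]]: each node receives its neighbour's value
```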

2.3 Graph optimization loss

The sampling scheme we adopt does not allow the gradient of a classification loss computed on the output node features to flow through the graph prediction branch of our network, since the top-$k$ selection of edges is not differentiable with respect to $P$. To allow its optimization, we exploit tools from reinforcement learning, rewarding edges involved in a correct classification and penalizing edges that led to a misclassification.

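Suppose that, after a forward step, the network outputs the classification $y_i$ for the input features $x_i$ of node $i$, with sampled edges $E'^{(l)}$ at layer $l$. We define the following graph loss function:

$$L_{GL} = \sum_{i=1}^{N} \Big( \delta(y_i, \hat y_i) \sum_{l=1}^{L} \sum_{j:(i,j)\in E'^{(l)}} \log p^{(l)}_{ij} \Big), \tag{5}$$

where $L$ is the number of layers, $\delta(y_i, \hat y_i)$ is a function taking value $-1$ if $y_i = \hat y_i$ and $1$ otherwise, and $\hat y_i$ is the ground-truth label.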

The previous definition intrinsically weights positive and negative samples unevenly, especially in the early stages of training when the classification accuracy is low. This drives the network to favor a uniformly low probability estimate for all edges. To prevent this behavior, we weight positive and negative samples according to the current per-class accuracy:

$$\delta(y_i, \hat y_i) = \begin{cases} \mathrm{acc}_{\hat y_i} - 1 & \text{if } y_i = \hat y_i, \\ \mathrm{acc}_{\hat y_i} & \text{otherwise,} \end{cases} \tag{6}$$

with $\mathrm{acc}_{\hat y_i}$ being the accuracy of class $\hat y_i$ computed on the current predictions $y$. Using a per-class accuracy rather than a global accuracy helps in dealing with an uneven distribution of samples among the classes of the dataset.

The graph loss is then optimized jointly with the classification loss (e.g. categorical cross-entropy) by summing the two.
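The accuracy-weighted graph loss can be evaluated as below. This is a hedged illustration: NumPy has no autograd, so the sketch only computes the loss value; in a training framework the per-node weights would be treated as constants and gradients would flow only through the log-probabilities of the sampled edges. The array layout (one `(L, k)` block of log-probabilities per node) and the function names are assumptions.

```python
import numpy as np


def per_class_accuracy(y_pred: np.ndarray, y_true: np.ndarray, n_classes: int) -> np.ndarray:
    """Accuracy of the current predictions restricted to each ground-truth class."""
    acc = np.zeros(n_classes)
    for c in range(n_classes):
        mask = y_true == c
        if mask.any():
            acc[c] = np.mean(y_pred[mask] == c)
    return acc


def graph_loss(log_p_sampled: np.ndarray, y_pred: np.ndarray, y_true: np.ndarray, n_classes: int) -> float:
    """Accuracy-weighted graph loss.

    log_p_sampled: (N, L, k) values log p_ij of the k edges sampled for node i at each of L layers.
    y_pred, y_true: (N,) predicted and ground-truth labels.
    """
    acc = per_class_accuracy(y_pred, y_true, n_classes)[y_true]  # accuracy of each node's true class
    # Negative weight for correct nodes (minimizing raises their log p), positive for wrong ones.
    delta = np.where(y_pred == y_true, acc - 1.0, acc)
    return float(np.sum(delta * log_p_sampled.sum(axis=(1, 2))))


y_true = np.array([0, 0, 1, 1])
y_pred = np.array([0, 1, 1, 1])
log_p = np.array([[[-1.0, -1.0]], [[-2.0, -2.0]], [[-1.0, -1.0]], [[-1.0, -1.0]]])  # N=4, L=1, k=2
print(per_class_accuracy(y_pred, y_true, 2), graph_loss(log_p, y_pred, y_true, 2))
```

In this toy case class 1 is already perfectly classified, so its nodes get weight 0 and contribute nothing, while class 0 (50% accuracy) rewards and penalizes its two nodes with equal magnitude.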

2.4 Multi-modal setting

Multi-modal datasets consist of two (or more) sets of features coming from different modalities. The graph can then be learned from one modality and the node representation from another. To this end, we provide a variant of DGM named Multi-modal DGM (M-DGM). The only difference with respect to the forward propagation described above is that we train the graph learning part on a separate set of features dedicated to graph learning. This amounts to feeding only these dedicated features to DGM in eq. 2.

2.5 Out-of-sample extension

One of the major challenges for graph-based techniques is out-of-sample extension. Spectral convolution methods in particular need a pre-defined graph; in this setting it is difficult to change the graph, to add nodes, or to reuse the filters learned for the input graph. Spatial techniques also require the underlying graph to be defined beforehand, so for out-of-sample extension the whole graph has to be redefined to incorporate the test samples.

Different methods have been proposed to solve the out-of-sample extension problem, such as Kipf et al. (2018) and GraphSAGE (Hamilton et al., 2017a). In GraphSAGE, an embedding function is learned based on the node features and the local and global topology of the nodes, and unseen nodes are then projected into the learned node embedding. However, this method still requires a graph. In this section, we show that our proposed method can be extended to an inductive setting.

In our method, the graph is optimized together with the task at hand. From equation 3, what we learn during training are the functions $f$ and $g$, i.e. the parameters $\Theta$ and $\Phi$. Since the graph in our case is dynamic and generated at each layer, a new graph over a different number of nodes can be generated just as easily. In the inductive setting, the trained parameters $\Theta$ and $\Phi$ are reused to compute the representations of the new nodes with the previously trained filters.

3 Experiments and Results

In this section, we demonstrate the versatility and performance of our method on 4 different applications. We chose 4 datasets that cover a wide range of heterogeneity in the data.

3.1 Application to disease prediction

Given multi-modal features collected in the hospital, the task is to predict the disease of each patient. Here, we target Alzheimer's disease prediction given imaging (MRI, fMRI, PET) and non-imaging (demographics and genotypes) features per patient. We pose this problem as the classification of each patient into one of 3 classes: Normal, Alzheimer's disease, and Mild Cognitive Impairment (MCI).

GCNs are increasingly leveraged to exploit such rich multi-modal data. In such a setting, the graph is constructed on the entire population, where each patient is considered a node and the connectivity between patients is based on the similarity of their non-imaging features. Imaging features are assigned to each node. Finally, the node features are learned in this setting and used for the classification task. For this experiment we use the Tadpole dataset (Marinescu et al., 2018), a subset of the Alzheimer's Disease Neuroimaging Initiative (adni.loni.usc.edu), consisting of 557 patients with 354 multi-modal features per patient. Imaging features comprise Magnetic Resonance Imaging, Positron Emission Tomography, cognitive tests, and CSF, whereas non-imaging features comprise demographics (age, gender), genotype, and the average FDG PET value.

Method                                 Accuracy (%)
Linear classifier                      70.22 ± 6.32
Multi-GCN (Kazi et al., 2019a)         76.06 ± 0.72
Spectral-GCN (Parisot et al., 2017)    81.00 ± 6.40
InceptionGCN (Kazi et al., 2019b)      84.11 ± 4.50
DGCNN                                  84.59 ± 4.33
Proposed (MM)                          90.05 ± 3.70
Proposed (SM)                          91.05 ± 5.93
Table 1: Classification accuracy on the Tadpole dataset, comparing the proposed method with the state of the art. The results show that DGCNN is a strong baseline to compare with in the further experiments.
Node features   Graph features   DGCNN          M-DGM          DGM
M1              M1               82.98 ± 3.35   92.56 ± 2.57   89.88 ± 3.95
M1              M2               85.65 ± 5.90   90.05 ± 3.70   91.50 ± 5.93
M1              M1+M2            84.22 ± 5.82   90.96 ± 3.59   90.59 ± 2.40
M1+M2           M1+M2            84.59 ± 4.33   86.89 ± 4.91   90.42 ± 3.87
Mean                             84.36 ± 4.85   90.12 ± 3.69   90.60 ± 4.04
Table 2: Average classification accuracy (%) for 10-fold cross-validation in the transductive setting on the Tadpole dataset. The first two columns give the features used for node representation and for graph learning, chosen between modality 1 (M1) and modality 2 (M2).
Node features   Graph features   DGCNN          M-DGM          DGM
M1              M1               82.99 ± 4.91   87.94 ± 3.02   88.12 ± 3.65
M1              M2               81.06 ± 4.80   87.59 ± 3.05   88.48 ± 4.58
M1              M1+M2            81.95 ± 6.17   86.70 ± 4.43   89.54 ± 5.69
M1+M2           M1+M2            84.39 ± 4.57   88.64 ± 3.63   87.23 ± 3.53
Mean                             82.60 ± 5.11   87.72 ± 3.53   88.34 ± 4.36
Table 3: Average classification accuracy (%) for 10-fold cross-validation in the inductive setting on the Tadpole dataset. 10% of the data is kept completely unseen.

We show three sets of experiments for this dataset. In the first experiment (table 1), we compare the proposed method with four state-of-the-art methods. The linear classifier is a non-graph-based baseline whose results are obtained with a ridge classifier. Multi-GCN (Kazi et al., 2019a), Spectral-GCN (Parisot et al., 2017) and InceptionGCN (Kazi et al., 2019b) are spectral approaches targeting the classification task; these three methods require a pre-defined graph obtained from the non-imaging modality. We also include DGCNN as a baseline, since its dynamically built graph is similar in spirit to our approach.

The results in table 1 show that graph-based methods perform better than the linear non-graph-based method. DGCNN shows better or comparable results with respect to the spectral graph-based techniques, making it a strong baseline. Both our proposed models exceed the state-of-the-art results by 7.28% on the classification task. Further, the variance of the proposed M-DGM model (Proposed (MM)) is smaller than that of most state-of-the-art methods, indicating the robustness of the model.

In our second experiment (table 2), we vary the graph features while keeping the node features constant, to check the sensitivity of the model towards the graph features, and compare the classification performance with DGCNN. We also show results in the inductive setting in table 3. For this setting, we keep 10% of the data completely unseen and train our model on the remaining data in the regular fashion. In the inductive setting, we use the pre-trained model for the filters, while the graph in each layer is constructed over the whole population, including the 10% out-of-sample set.

In all settings, both our proposed models perform better than DGCNN, which suggests that the graph constructed by our method is better suited to the task than the kNN graph of DGCNN. Performance in the inductive setting drops by 2.14% on average across all settings. On average, the variance of both proposed models is lower than that of DGCNN, supporting the robustness of the model.
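The 2.14% figure can be reproduced from the Mean rows of tables 2 and 3, assuming it is the average, over DGCNN, M-DGM and DGM, of the drop from transductive to inductive mean accuracy:

```python
# Mean accuracies (%) copied from the "Mean" rows of table 2 (transductive) and table 3 (inductive).
transductive = {"DGCNN": 84.36, "M-DGM": 90.12, "DGM": 90.60}
inductive = {"DGCNN": 82.60, "M-DGM": 87.72, "DGM": 88.34}

# Per-model drop in percentage points, then the average over the three models.
drops = {model: transductive[model] - inductive[model] for model in transductive}
avg_drop = sum(drops.values()) / len(drops)
print({model: round(v, 2) for model, v in drops.items()}, round(avg_drop, 2))
```

The per-model drops are 1.76 (DGCNN), 2.40 (M-DGM) and 2.26 (DGM) points, whose mean is 2.14.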

Gender classification   DGCNN          M-DGM          DGM
Transductive            87.06 ± 2.89   90.00 ± 1.89   90.22 ± 2.03
Inductive               85.31 ± 6.37   88.71 ± 3.78   89.14 ± 5.92
Table 4: Classification accuracy (%) for the gender classification task in the transductive and inductive settings on the UK Biobank data.

Implementation details. As a pre-processing step, we apply standard normalization to all features and use recursive feature elimination to reduce the input feature dimension to 30 for all experiments. We use a learning rate of 0.01, reduced to 0.0001 at intervals of 100 epochs in a piecewise constant fashion. We train each model for 300 epochs, optimizing the loss with the Adam optimizer. All experiments are implemented in TensorFlow and run on a Titan Xp GPU (12 GB).

3.2 Application to Gender Classification and Age Prediction Task

As with the Tadpole dataset, we test our model on two different tasks using another dataset, UK Biobank. Given structural, volumetric and functional features of the brain, the first task is to predict the gender and the second to predict the age of each patient. For both tasks, we use a subsample of the UK Biobank data (Miller et al., 2016) consisting of 14,503 individuals with 440 features per individual, including age and gender. The features are mainly collected from brain MRI and fMRI imaging, providing structural and functional information for each patient, respectively.

Keeping the implementation details the same as for Tadpole, we cast gender prediction as a binary classification task and age prediction as a categorical classification task. For age prediction, we divide the individuals into four groups based on age, group 1 (50-59 years), group 2 (60-69 years), group 3 (70-79 years) and group 4 (80-89 years), making it a four-class classification problem. We report results for both transductive and inductive settings. For the transductive setting, we split the data into 90% training and 10% testing points, whereas for the inductive setting we divide the data into 10% unseen points, 80% training and 10% validation. As can be seen from tables 4 and 5, both our models perform better than DGCNN in all settings for both tasks.
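The age binning and the inductive split described above can be sketched as follows. The function names are hypothetical, and the split is assumed to be a uniformly random permutation; the text does not say how points were assigned to the three parts.

```python
import numpy as np


def age_group(age_years: np.ndarray) -> np.ndarray:
    """Map ages in [50, 90) to classes 0..3 (50-59, 60-69, 70-79, 80-89 years)."""
    if np.any((age_years < 50) | (age_years >= 90)):
        raise ValueError("age outside the 50-89 range used for the four groups")
    return ((age_years - 50) // 10).astype(int)


def inductive_split(n: int, rng: np.random.Generator):
    """Random 10% unseen / 80% training / 10% validation split of n sample indices."""
    perm = rng.permutation(n)
    n_unseen = n_val = n // 10
    unseen = perm[:n_unseen]
    val = perm[n_unseen:n_unseen + n_val]
    train = perm[n_unseen + n_val:]
    return unseen, train, val


print(age_group(np.array([50, 59.9, 60, 75, 89])))  # [0 0 1 2 3]
unseen, train, val = inductive_split(14503, np.random.default_rng(0))
print(len(unseen), len(train), len(val))  # 1450 11603 1450
```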

Age prediction   DGCNN          M-DGM          DGM
Transductive     58.35 ± 0.91   60.85 ± 0.91   61.59 ± 1.05
Inductive        51.84 ± 8.16   55.77 ± 6.01   53.37 ± 7.94
Table 5: Classification accuracy (%) for the age prediction task in the transductive and inductive settings on the UK Biobank data. The population is divided into 4 groups with 10-year bins from 50 to 89 years.
Category     # Shapes   DGCNN   DGM
Airplane     2690       84.0    84.1
Bag          76         83.4    82.5
Cap          55         86.7    84.6
Car          898        77.8    77.9
Chair        3758       90.6    91.3
Earphone     69         74.7    79.0
Guitar       787        91.2    92.5
Knife        392        87.5    87.7
Lamp         1547       82.8    83.7
Laptop       451        95.7    96.5
Motorbike    202        66.3    66.8
Mug          184        94.9    95.1
Pistol       283        81.1    83.1
Rocket       66         63.5    62.3
Skateboard   152        74.5    77.8
Table        5271       82.6    82.2
Mean                    85.2    85.6
Table 6: Comparison of mIoU (%) on the ShapeNet part segmentation task.
Figure 3: Segmentation examples on ShapeNet dataset. Points are colored according to their predicted (top) and ground-truth (bottom) part labels.

3.3 Application to Point cloud segmentation

Point cloud part segmentation is a more challenging task from the graph optimization perspective. We are given an object represented by a set of 3D points in space with an unknown connectivity. Each object is thus a completely new set of points and there is no intersection between training and testing points.

We directly compare with Wang et al. (2019) on the part segmentation task of the ShapeNet part dataset (Yi et al., 2016). The dataset is composed of 16881 point clouds representing 3D shapes from 16 different object categories. Each point of a shape is annotated with one of 50 part category labels, and most shapes are composed of fewer than 6 parts. Following the experimental setup of Wang et al. (2019), we sample 2048 points from each training shape, with 3-dimensional features given by the 3D position of each point. We follow the same train/validation/test split scheme as Chang et al. (2015).

We mimic the architecture used by Wang et al. (2019) for this task, replacing their kNN graph sampling scheme with our DGM with a feature depth of 16. We keep the rest of the network untouched, including the value of $k$ and the training parameters. During inference, given the stochastic nature of our graph, we repeat the classification of each point 8 times and then take the argmax of the cumulative soft predictions.
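The test-time averaging over 8 stochastic graph samples can be illustrated as follows. `predict_proba` is a hypothetical callable standing in for the full network, assumed to return per-point class probabilities and to draw a fresh graph sample on each call. Summing soft predictions before the argmax differs from a majority vote over hard labels, as the toy example shows.

```python
import numpy as np
from itertools import cycle


def ensemble_predict(predict_proba, x, n_runs=8, seed=0):
    """Sum the soft predictions of n_runs stochastic forward passes, then take the argmax.

    predict_proba: hypothetical callable (x, rng) -> (n_points, n_classes) probabilities,
    expected to sample a new graph on every call using rng.
    """
    rng = np.random.default_rng(seed)
    cumulative = sum(predict_proba(x, rng) for _ in range(n_runs))
    return np.argmax(cumulative, axis=1)


# Toy stand-in: alternates between two fixed outputs for a single point.
outputs = cycle([np.array([[0.6, 0.4]]), np.array([[0.1, 0.9]])])
labels = ensemble_predict(lambda x, rng: next(outputs), x=None)
print(labels)  # [1]: summed scores are about [2.8, 5.2], although hard votes would tie 4-4
```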

Figure 4: Comparison between our DGM (left) and the original kNN sampling in feature space (right) in the last two convolutional layers of the network. For DGM, the colormap encodes the probability of each point being connected to the red point. For the kNN sampling of DGCNN, we plot the exponential of the negative Euclidean distance in feature space.

In table 6 we report the mean Intersection-over-Union (mIoU) values, calculated by averaging the IoUs of all testing shapes. Our approach improves performance over the original kNN sampling scheme on 12 of the 16 shape classes and on the mean.
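For reference, a per-shape IoU and its average over test shapes can be computed as below. The handling of parts that appear in neither the prediction nor the ground truth (counted as IoU = 1) follows the convention commonly used in PointNet/DGCNN evaluation scripts; it is an assumption here, since the text only states that IoUs are averaged over all testing shapes. Function names are hypothetical.

```python
import numpy as np


def shape_iou(pred: np.ndarray, gt: np.ndarray, part_labels) -> float:
    """Mean IoU over the part labels of one shape's category.

    pred, gt: (n_points,) predicted and ground-truth part labels of a single shape.
    part_labels: the part labels that belong to this shape's object category.
    """
    ious = []
    for part in part_labels:
        inter = np.sum((pred == part) & (gt == part))
        union = np.sum((pred == part) | (gt == part))
        # Assumed convention: a part absent from both prediction and ground truth scores 1.
        ious.append(1.0 if union == 0 else inter / union)
    return float(np.mean(ious))


def mean_iou(shapes) -> float:
    """Average of per-shape IoUs over all test shapes; each item is (pred, gt, part_labels)."""
    return float(np.mean([shape_iou(p, g, parts) for p, g, parts in shapes]))


pred = np.array([0, 0, 1, 1])
gt = np.array([0, 1, 1, 1])
print(shape_iou(pred, gt, part_labels=[0, 1, 2]))  # (0.5 + 2/3 + 1) / 3, about 0.722
```

Per-category numbers such as those in table 6 would then be the average of `shape_iou` over the test shapes of each category.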

In figure 4 we show the sampling probabilities of some points (red dot) on different shapes at the last two layers of the network. Notably, the probability of connecting two points is not directly related to the point feature space used for part classification; rather, it retains some spatial information and appears to capture symmetries of the shape. Some segmentation examples are shown in figure 3.

3.4 Application to zero-shot learning task

We first define the problem of zero-shot learning and provide the details of the state-of-the-art GCN based model (Kampffmeyer et al., 2019) used for this task.

The problem of zero-shot learning consists in classifying samples belonging to classes that have never been seen during the training phase. The most popular approach is to train a network to predict a vector representation of a category starting from some implicit knowledge, i.e. a semantic embedding. The vector representation is then mapped directly to classifiers (Xian et al., 2018). Recent works showed that also using explicit relations between categories, in the form of knowledge graphs, can significantly improve classification accuracy (Kampffmeyer et al., 2019; Wang et al., 2018).

Figure 5: The first two panels show an example of the 2-ring neighborhood of the "sheep" category: the left panel corresponds to the knowledge graph, while the central panel is our graph, sampled considering the 5 most probable edges. The right panel plots the average predicted probability of edges belonging to the k-ring neighborhood (AwA2 test categories). Higher probabilities for nearer neighbors suggest that the predicted graph structure is loosely related to the knowledge graph.
Model ACC (%)
GCNZ 70.5
DGP 77.3
DGM (ours) 73.0
Table 7: Classification accuracy for unseen classes on AWA2 dataset. GCNZ and DGP results are reported from (Kampffmeyer et al., 2019).

Proposed Model for Zero-Shot Learning

We base our model on the SGCN architecture proposed in (Kampffmeyer et al., 2019), where the input knowledge graph is replaced by our DGM module.

Let $X = \{x_1, \dots, x_N\}$ be the set of input samples, each equipped with a $d$-dimensional feature vector $x_i \in \mathbb{R}^d$. In this case, each $x_i$ is the semantic embedding (i.e. word vector) associated with a category class. Each layer of the network consists of the following graph convolution:

$$H^{(l+1)} = \sigma\big(\hat A\, H^{(l)}\, W^{(l)}\big), \qquad H^{(0)} = X, \tag{7}$$

where $\sigma$ is a LeakyReLU non-linearity, $W^{(l)}$ are the learned weights, and $\hat A = D^{-1} A$, with $D_{ii} = \sum_j A_{ij}$, is the non-symmetric normalization of the adjacency matrix $A$ constructed from the graph sampled with our DGM.
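A NumPy sketch of this graph convolution, with the adjacency matrix built from DGM-sampled neighbour indices, is shown below. The LeakyReLU slope of 0.2 and the function names are assumptions. Because each node samples its own ingoing edges, the adjacency matrix is generally not symmetric, and the row normalization simply averages over each node's $k$ sampled neighbours.

```python
import numpy as np


def leaky_relu(z: np.ndarray, slope: float = 0.2) -> np.ndarray:
    return np.where(z > 0, z, slope * z)


def sgcn_layer(H: np.ndarray, nbrs: np.ndarray, W: np.ndarray, slope: float = 0.2) -> np.ndarray:
    """LeakyReLU(D^-1 A H W), with A built from DGM-sampled neighbour indices.

    H: (N, d) input features; nbrs: (N, k) sampled neighbours of each node; W: (d, d_out) weights.
    """
    N, k = nbrs.shape
    A = np.zeros((N, N))
    A[np.repeat(np.arange(N), k), nbrs.ravel()] = 1.0  # a_ij = 1 if edge (i, j) was sampled
    A_hat = A / A.sum(axis=1, keepdims=True)           # D^-1 A with D_ii = sum_j a_ij
    return leaky_relu(A_hat @ H @ W, slope)


H = np.eye(3)
nbrs = np.array([[0, 1], [1, 2], [2, 0]])
print(sgcn_layer(H, nbrs, np.eye(3)))  # each row averages the one-hot features of 2 neighbours
```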

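The zero-shot task loss is thus defined as

$$L_{ZS} = \frac{1}{2M} \sum_{i=1}^{M} \big\lVert \hat w_i - w_i \big\rVert^2, \tag{8}$$

where $M$ is the number of training classes, and $\hat w_i$ and $w_i$ are the predicted and ground-truth vector representations of category $i$.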
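The zero-shot regression loss can be evaluated as below. It follows the mean-squared-error form used by the DGP model this section builds on; the argument layout (predictions for all categories plus an index array selecting the training classes) and the function name are assumptions of the sketch.

```python
import numpy as np


def zero_shot_loss(W_pred: np.ndarray, W_true: np.ndarray, train_idx: np.ndarray) -> float:
    """L = 1/(2M) * sum_i ||w_hat_i - w_i||^2 over the M training classes.

    W_pred: (n_categories, P) predicted vectors for all categories (graph nodes).
    W_true: (M, P) ground-truth vectors of the training classes, e.g. last FC-layer weights of a CNN.
    train_idx: (M,) rows of W_pred that correspond to the training classes, in the order of W_true.
    """
    diff = W_pred[train_idx] - W_true
    return float(np.sum(diff ** 2) / (2 * len(train_idx)))


W_pred = np.array([[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])  # last row: an unseen category
W_true = np.zeros((2, 2))
print(zero_shot_loss(W_pred, W_true, np.array([0, 1])))  # (1 + 1) / (2 * 2) = 0.5
```

Unseen categories (here the last row) do not enter the loss directly; they only receive meaningful vectors through message passing over the graph.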

Note that, even though eq. 8 defines a regression problem, it is straightforward to combine it with our graph loss defined in equation 6, by deriving a predicted category for each sample $i$ from its predicted vector representation.

Dataset            Tadpole   UK Biobank   ShapeNet   AwA
Multi-modal        yes       no           no         no
Sample size        557       14.5K        2048       21K
Feature size       354       440          3          300
Number of graphs   1         1            16881      1
Table 8: The chosen datasets cover a wide variety of challenges.

Dataset and training details. As in Kampffmeyer et al. (2019), we use the weights of the last fully connected layer of a ResNet-50 (He et al., 2016) pre-trained on the ImageNet 2012 dataset (Deng et al., 2009) as our target vector representation. Input semantic features are extracted with the GloVe text model (Pennington et al., 2014) trained on Wikipedia. Our model consists of two graph convolution layers with hidden and output dimensions of 2048 and 2049, paired with two DGM layers with a 16-dimensional graph representation.

We train our model on the 21K ImageNet categories: the semantic embedding is available as input for all categories, but only the first 1K have a corresponding ground-truth vector representation. The model is trained for 5000 iterations on a randomly subsampled set of 7K categories that contains all 1K training categories.

Testing is performed on the AWA2 dataset, composed of 37,322 images belonging to 50 different animal categories. In table 7 we report top-1 accuracy for the test split proposed by Wang et al. (2018), composed of images from 10 classes not present in the first 1K ImageNet classes used for training. Note that, as opposed to both GCNZ (Wang et al., 2018) and DGP (Kampffmeyer et al., 2019), we do not make use of the knowledge graph. As figure 5 suggests, the knowledge graph is indeed a good graph representation for the zero-shot task; although our predicted graph shows some similarity to it, our sampling scheme fails to capture its hierarchical structure.

4 Discussion and conclusion

4.1 Conclusion

In this paper, we tackled the challenge of graph learning in convolutional graph neural networks. We have proposed a novel Differentiable Graph Module (DGM) that predicts a probabilistic graph, allowing a discrete graph to be sampled accordingly in order to be used in any graph convolutional operator. Further, we devised a weighted loss inspired by reinforcement learning which allows the optimization over edge probabilities.

Our DGM is generic and adaptable to any graph-convolution-based method. We demonstrate this by applying our method to a wide variety of tasks, ranging from healthcare (disease prediction), brain imaging (age and gender prediction) and computer graphics (3D point cloud segmentation) to computer vision (zero-shot learning), including multi-modal datasets and inductive settings. Table 8 shows the wide heterogeneity captured by our choice of datasets and tasks.

4.2 Discussion

There are some open problems with the proposed method. Computationally, our method, although more lightweight than a full pairwise MLP approach (Jang et al., 2019), still requires computing all pairwise distances, making it quadratic in the number of input nodes. Restricting the computation of probabilities to a neighborhood of each node and using tree-based algorithms could reduce the complexity to $O(n \log n)$. Further, our neighbor sampling scheme does not account for the heterogeneity of the graph in terms of the node degree distribution; other sampling schemes (e.g. threshold-based sampling (Jang et al., 2019)) could be investigated. It would also be interesting to take prior knowledge about the graph into account, for instance by imposing a node degree distribution, or even to start from an initial input graph to be optimized for a specific task.


This research has been conducted using the UK Biobank Resource (part of Application No. 12579). We thank Dr. Ben Glocker for helping us with the UK Biobank dataset. This study was supported in part by the TUM-Imperial College incentive fund, the Freunde und Förderer der Augenklinik, München, Germany, ERC Consolidator Grant No. 724228 (LEMAN), and Royal Society Wolfson Research Merit award.