Relational data organizes factual knowledge that is valuable for a wide range of applications Ji-2020, such as question answering Seyler-2015; He-2017 and information retrieval Dalton-2014; Xiong-2015a. As most real-world knowledge bases are incomplete, predicting missing information in knowledge bases, i.e., the statistical relational learning (SRL) task Getoor-2007, has attracted great attention from both academia and industry.
An effective way to solve SRL tasks is to use graph neural networks (GNNs) Scarselli-2009, such as the relational graph convolutional network (RGCN) Schlichtkrull-2017. However, data are usually scattered across different companies and institutions, especially in the financial domain, where data are sensitive by nature. Collecting data from these institutions is difficult or even forbidden by regulation. Privacy-preserving SRL methods that allow secure collaborative training among different participants are still understudied, which hinders wider application of graph modeling. A promising direction for such collaborative training is Federated Learning (FL). The first FL algorithm, FedAVG, was proposed by Google. It jointly learns a global model from multiple data sources by exchanging only gradients or model parameters while keeping the raw data local, thus limiting the possibility of information leakage McMahan-2017.
Although FL has achieved significant progress and has been widely applied, combining FL with graph learning methods remains underexplored. One possible reason is the inherent heterogeneity of graph datasets. Such heterogeneity exists either in a statistical sense, where the number of nodes and edges can vary widely across graph datasets even with similar types of nodes and edges, or in a structural sense, where the same entity can appear in separate graph datasets with different neighborhoods. Because current FL algorithms typically assume IID training data to perform well, this fundamental heterogeneity in federated graphs might seriously degrade the performance of the jointly trained model.
The recently proposed FedProx Li-2020 tackles systems and statistical heterogeneity in federated networks by adding a penalty term to the local loss function that encourages the local model to stay close to the global one. However, such a method cannot be applied directly to graph models, because the local data in graph modeling tasks have different properties from batched data, and hence the models diverge more significantly. One of the most important differences is that batched data can be separated into mini-batches during training, while graph data cannot be split in the same way. This difference leads to a high variance of local weights and makes it difficult to align the local and global models. Hence, a more sophisticated alignment method is needed.
In this paper, we first review the possible reasons why local models may differ significantly in the federated setting. We find that the non-separability of graph data and the complex graph model design may aggravate local model divergence. Based on these insights, we propose a simple yet effective solution, called FedAlign, by constraining the loss function to be -Lipschitz smooth and measuring the optimal transportation (OT) distance of hidden layers. Extensive experiments are conducted on several public datasets. Results show that the proposed FedAlign outperforms the state-of-art federated learning methods, such as FedAVG and FedProx on modeling relational data.
We summarize our main contributions as follows:
1) We first propose how to build a naive federated RGCN model based on relational data, by directly using the existing FL technique.
2) We then review the problems (objective divergence and unsmoothness) of the naive solution and propose an advanced method using basis alignment and weight constraint.
3) We finally conduct experiments on three benchmark datasets and the results demonstrate the effectiveness of our proposed method.
In this section, we present the preliminaries on federated learning and graph neural network, so as to propose a modified version of graph neural network for modeling the relational data over the federated networks in the later section.
2.1 Federated Learning
Federated Learning (FL) is an emerging technique that aims to preserve privacy and boost model performance on edge devices. A typical use case of FL is training a keyboard prediction model on a mobile phone, which predicts the next word according to the user's most recent input. The inputs on a mobile phone are highly private, and the user will be reluctant to share the data with the server. However, the limited data on a single phone are hardly enough for model training. FL was first proposed by McMahan-2017 to solve this problem by adopting a collaborative training paradigm without sharing local data.
Federated Averaging (FedAVG) is the most commonly used algorithm in FL McMahan-2017. In this algorithm, each client (i.e., a local device such as a user's mobile phone) updates its local model using data collected on the device itself. The server periodically chooses $K$ clients, collects their parameters, and aggregates them to compute the global parameters as follows:
$$w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n} w_{t+1}^{k} \qquad (1)$$
where $n_k$ is the data size in each client $k$, and $n$ is the size of all the data used in this update round. Finally, clients replace their local parameters with the global ones.
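As an illustration of the aggregation step, the following minimal Python sketch computes the data-size-weighted average of client parameters. The function name `fedavg_aggregate`, the dictionary layout of the parameters, and the toy values are assumptions made for this example; they are not taken from the original implementation.

```python
from typing import Dict, List


def fedavg_aggregate(client_weights: List[Dict[str, List[float]]],
                     client_sizes: List[int]) -> Dict[str, List[float]]:
    """Average client parameters, weighting each client by its local data size."""
    if not client_weights or len(client_weights) != len(client_sizes):
        raise ValueError("need exactly one data size per client")
    total = float(sum(client_sizes))  # size of all data used in this round
    aggregated: Dict[str, List[float]] = {}
    for name in client_weights[0]:
        length = len(client_weights[0][name])
        aggregated[name] = [
            sum(size / total * weights[name][i]
                for weights, size in zip(client_weights, client_sizes))
            for i in range(length)
        ]
    return aggregated


# Hypothetical toy round with two clients holding 1 and 3 samples.
clients = [{"w": [1.0, 2.0]}, {"w": [3.0, 6.0]}]
global_weights = fedavg_aggregate(clients, [1, 3])
```

In this toy round the client with three samples contributes three quarters of the average, so larger local datasets pull the global parameters towards their local solution.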
To prevent gradients from leaking sensitive information during federated optimization, FedAVG can easily be combined with privacy-preserving techniques, such as differential privacy and secure multiparty computation. However, in real-world federated applications, edge devices are not always online and local data typically vary from device to device, so system and statistical heterogeneity arise naturally. FedProx Li-2020 was recently proposed to tackle both system heterogeneity (caused by unreliable communication or stragglers) and statistical heterogeneity (caused by the nature of data collectors). It has achieved state-of-the-art performance on many federated benchmarks. The key idea is to introduce a proximal term that keeps local updates from diverging:
$$\min_{w} h_k(w; w^t) = F_k(w) + \frac{\mu}{2}\|w - w^t\|^2 \qquad (2)$$
$F_k(w)$ in (2) is the local counterpart of the objective $f(w)$ in the (ideally existing) global task. It measures the local empirical risk over possibly differing data distributions $\mathcal{D}_k$. $w^t$ denotes the global model and $w$ refers to the local one. FedProx reduces to FedAVG when $\mu = 0$.
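The effect of the proximal term can be seen in a short Python sketch. It assumes the FedProx form of Li-2020, i.e., the local loss plus $\frac{\mu}{2}$ times the squared distance to the global weights; `fedprox_objective` and `toy_loss` are hypothetical names introduced only for illustration.

```python
from typing import Callable, List


def fedprox_objective(local_loss: Callable[[List[float]], float],
                      w: List[float],
                      w_global: List[float],
                      mu: float) -> float:
    """Local loss plus a proximal penalty on the distance to the global weights."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(w, w_global))
    return local_loss(w) + 0.5 * mu * sq_dist


def toy_loss(w: List[float]) -> float:
    """Hypothetical local risk whose minimiser is the all-ones vector."""
    return sum((x - 1.0) ** 2 for x in w)


# At the local minimiser the local loss is zero, but the proximal term
# still charges for drifting away from the global model at the origin.
penalised = fedprox_objective(toy_loss, [1.0, 1.0], [0.0, 0.0], mu=0.5)
unpenalised = fedprox_objective(toy_loss, [1.0, 1.0], [0.0, 0.0], mu=0.0)
```

With `mu=0.0` the objective is just the local loss, which is the FedAVG local objective.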
2.2 Graph Neural Networks
Graph neural networks (GNNs) have been shown to be effective in modeling graph data and have achieved state-of-the-art performance on several tasks Wu-2020; Liu-2018a. A typical GNN model consists of two parts, i.e., an embedding layer that encodes the graph into learnable vectors and hidden layers that transform the embeddings into task-specific outputs. Specifically, a GNN usually defines a differentiable message-passing function on local neighborhoods Gilmer-2017, i.e.:
$$h_i^{(l+1)} = \sigma\Big(\sum_{m \in \mathcal{M}_i} g_m\big(h_i^{(l)}, h_j^{(l)}\big)\Big) \qquad (3)$$
where $h_i^{(l)} \in \mathbb{R}^{d^{(l)}}$ denotes the hidden state of node $i$ in the $l$-th layer, with $d^{(l)}$ being the dimensionality of this layer. A message-passing operator $g_m$ is chosen from the set of incoming messages $\mathcal{M}_i$ and is calculated using the neighborhood of node $i$. The results are then accumulated and passed through a non-linear function $\sigma(\cdot)$ such as ReLU.
Modeling with Relational Data.
However, the GNN model presented above can only handle homogeneous graphs, as it cannot distinguish one relation from another. The Relational Graph Convolutional Network (RGCN) extends the idea of GNNs to relational data by aggregating the weights of different relations in a knowledge graph Schlichtkrull-2017. The message-passing function of RGCN is defined as:
$$h_i^{(l+1)} = \sigma\Big(\sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_i^r} \frac{1}{c_{i,r}} W_r^{(l)} h_j^{(l)} + W_0^{(l)} h_i^{(l)}\Big) \qquad (4)$$
where $\mathcal{N}_i^r$ is the set of neighbor indices of node $i$ under relation $r \in \mathcal{R}$. $c_{i,r}$ is a normalization constant that can be learned or predefined.
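To make the relational message passing concrete, the sketch below implements one RGCN-style layer in plain Python, following the propagation rule of Schlichtkrull-2017 with a self-connection and ReLU. It is a minimal sketch under stated assumptions: the normalization constant is fixed to the number of neighbors under each relation, edges are (source, relation, destination) triples, and the names `rgcn_layer` and `matvec` are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Plain matrix-vector product."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def rgcn_layer(h: List[Vector],
               edges: List[Tuple[int, int, int]],
               rel_weights: Dict[int, Matrix],
               self_weight: Matrix) -> List[Vector]:
    """One relational message-passing step over (src, relation, dst) edges."""
    # neighbors[(i, r)] lists the nodes j that send a message to i under relation r.
    neighbors = defaultdict(list)
    for src, rel, dst in edges:
        neighbors[(dst, rel)].append(src)

    out = [matvec(self_weight, h_i) for h_i in h]  # self-connection term
    for (i, rel), srcs in neighbors.items():
        c = float(len(srcs))  # assumed normalisation: neighbor count under rel
        for j in srcs:
            msg = matvec(rel_weights[rel], h[j])
            out[i] = [o + m / c for o, m in zip(out[i], msg)]
    return [[max(0.0, x) for x in row] for row in out]  # ReLU


identity = [[1.0, 0.0], [0.0, 1.0]]
double = [[2.0, 0.0], [0.0, 2.0]]
states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
triples = [(1, 0, 0), (2, 0, 0), (2, 1, 1)]
new_states = rgcn_layer(states, triples, {0: identity, 1: double}, identity)
```

Node 0 receives two messages under relation 0, which are averaged, while node 1 receives a single message under relation 1 that is transformed by a different weight matrix; this per-relation transformation is what distinguishes the layer from a homogeneous GNN.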
A rapid growth in the number of relations might lead to overfitting on rare relations or to models with an enormous number of parameters. To prevent that from happening, weight regularization is enforced on top of RGCN. The authors of RGCN proposed several methods based on the principle of parameter sharing. Instead of learning separate parameters for each relation, RGCN learns a group of shared parameters which can then be composed into the weights of relations. Since the shared parameters are trained by all relations, they are less likely to overfit to a specific relation.
For example, one of the methods used for this purpose is basis decomposition, which defines each weight as follows:
$$W_r^{(l)} = \sum_{b=1}^{B} a_{rb}^{(l)} V_b^{(l)} \qquad (5)$$
where each $W_r^{(l)}$ is a linear combination of basis transformations $V_b^{(l)}$ with coefficients $a_{rb}^{(l)}$. In this way, only the coefficients depend on the relation $r$, which prevents the model from overfitting and from growing with the number of relations. This method has been shown to be effective in entity classification tasks Schlichtkrull-2017.
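The basis decomposition can be sketched in a few lines of Python. The example assumes square basis matrices stored as nested lists; `compose_relation_weight` and `parameter_counts` are hypothetical helper names, and the sizes in the usage lines are illustrative rather than taken from the experiments.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def compose_relation_weight(bases: List[Matrix], coeffs: List[float]) -> Matrix:
    """Build one relation weight as a linear combination of shared bases."""
    if len(bases) != len(coeffs):
        raise ValueError("need exactly one coefficient per basis")
    rows, cols = len(bases[0]), len(bases[0][0])
    return [[sum(a * basis[i][j] for a, basis in zip(coeffs, bases))
             for j in range(cols)]
            for i in range(rows)]


def parameter_counts(n_relations: int, n_bases: int, dim: int) -> Tuple[int, int]:
    """Parameters per layer: (one full matrix per relation, basis decomposition)."""
    full = n_relations * dim * dim
    shared = n_bases * dim * dim + n_relations * n_bases
    return full, shared


v1 = [[1.0, 0.0], [0.0, 1.0]]
v2 = [[0.0, 1.0], [1.0, 0.0]]
w_r = compose_relation_weight([v1, v2], [2.0, 0.5])
full, shared = parameter_counts(n_relations=50, n_bases=4, dim=16)
```

Only the short coefficient vector is specific to a relation, so adding a relation adds `n_bases` parameters instead of a full `dim` by `dim` matrix.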
3 Federated Relational Graph Modeling
As real-world graphs differ significantly, typical FL algorithms, such as FedAVG, cannot be directly used for GNN models without aligning the models in each knowledge base. Graph models adapted to the federated setting are underexplored, and most existing works can only handle homogeneous graphs Suzumura-2019; zheng2020asfgnn. To facilitate modeling heterogeneous graphs over federated networks, we propose a federated version of RGCN, i.e., Fed-RGCN, which to our knowledge is the first of its kind.
3.1 Proposed Architecture
As discussed in the last section, to keep the weights of hidden layers from overfitting and from growing rapidly with the number of relations, RGCN uses weight sharing to map the growing set of relation weights onto a predefined basis. Such a basis design makes it convenient to extend RGCN to the federated setting. As shown in Fig. 1, all participants build their RGCN models with a basis of the same shape, i.e., with the same dimensions and number of layers. In each iteration, the global server chooses $K$ devices, and each of them updates its model using its own data. The server then collects the gradients of their bases and aggregates them using a weighted average:
$$\nabla V = \sum_{k=1}^{K} \frac{n_k}{n} \nabla V_k \qquad (6)$$
where $\nabla V_k$ is the basis gradient from base $k$, $n_k$ is the amount of data in each base, which can be calculated in multiple ways, and $n = \sum_{k} n_k$. In this paper, we define $n_k$ as the number of nodes. After gradient aggregation, each participant updates the local model with the global one and continues with local training.
3.2 Problems of Fed-RGCN
Although the proposed Fed-RGCN can help tackle the model heterogeneity issue in federated relational data modeling, using the traditional FL optimization methods such as federated averaging on Fed-RGCN may lead to slow convergence and degraded performances. In this section, we analyze two important factors that may affect the federated model convergence.
We first introduce some definitions and an assumption for analyzing the convergence of federated learning algorithms Li-2020. Note that in the following equations $k$ denotes the local device and $t$ the current state.
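($\gamma$-inexact solution) Consider a function $h(w; w_0) = F(w) + \frac{\mu}{2}\|w - w_0\|^2$ and $\gamma \in [0, 1]$. If there exists a $w^*$ such that $\|\nabla h(w^*; w_0)\| \le \gamma \|\nabla h(w_0; w_0)\|$, we call $w^*$ a $\gamma$-inexact solution of $\min_w h(w; w_0)$.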
Since , is actually a subproblem of , whose inexactness of its solutions is bounded by if and only if is -Lipschitz continuous.
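($B$-local dissimilarity) Denote $f$ as the global objective and $F_k$ the local counterpart on the $k$-th device. The local functions $F_k$ are $B$-locally dissimilar at $w$ if $\mathbb{E}_k\big[\|\nabla F_k(w)\|^2\big] \le \|\nabla f(w)\|^2 B^2$. We further define $B(w) = \sqrt{\mathbb{E}_k\big[\|\nabla F_k(w)\|^2\big] / \|\nabla f(w)\|^2}$ for $\|\nabla f(w)\| \neq 0$.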
An optimal weight $w^*$ that minimizes the local objective $F_k$ can also minimize the global one $f$ if and only if $F_k$ and $f$ are close enough. $B$-local dissimilarity measures such similarity. Assume the following assumption holds:
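(Bounded dissimilarity) For some $\epsilon > 0$ and all the points $w \in \mathcal{S}_\epsilon^c = \{w \mid \|\nabla f(w)\|^2 > \epsilon\}$, there exists a $B_\epsilon$ such that $B(w) \le B_\epsilon$.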
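The dissimilarity measure can be evaluated numerically once local gradients are available. The sketch below assumes the definition from Li-2020, $B(w) = \sqrt{\mathbb{E}_k[\|\nabla F_k(w)\|^2] / \|\nabla f(w)\|^2}$, with the global gradient taken as the weighted mean of the local gradients; `local_dissimilarity` is a hypothetical helper name and the gradients are toy values.

```python
import math
from typing import List, Optional


def local_dissimilarity(local_grads: List[List[float]],
                        probs: Optional[List[float]] = None) -> float:
    """Estimate B(w) from per-device gradients evaluated at the same point w."""
    k = len(local_grads)
    p = probs if probs is not None else [1.0 / k] * k  # device sampling weights
    dim = len(local_grads[0])
    # Gradient of the global objective f = sum_k p_k F_k.
    global_grad = [sum(p_k * g[i] for p_k, g in zip(p, local_grads))
                   for i in range(dim)]
    expected_sq = sum(p_k * sum(x * x for x in g)
                      for p_k, g in zip(p, local_grads))
    global_sq = sum(x * x for x in global_grad)
    if global_sq == 0.0:
        raise ValueError("B(w) is undefined where the global gradient vanishes")
    return math.sqrt(expected_sq / global_sq)


aligned = local_dissimilarity([[1.0, 0.0], [1.0, 0.0]])      # identical devices
orthogonal = local_dissimilarity([[1.0, 0.0], [0.0, 1.0]])   # conflicting devices
```

Identical local gradients give $B(w) = 1$, the smallest possible value, while gradients pointing in different directions push $B(w)$ above 1; a large value signals that local updates disagree with the global descent direction.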
With these definitions and the assumption, a federated algorithm is guaranteed to converge in a finite number of iterations. We direct the reader to Li-2020 for a detailed proof. Given a local objective in the form of , since will be assigned to in every epoch, if Assumption 3.1 holds for , it holds for as well. Hence the dissimilarity between and is bounded, and the corresponding solutions are also bounded by , indicating that there exists a solution of which is close enough to the solution of .
The above assumptions and analysis most likely hold for batched samples, e.g., texts or images. However, they can hardly be satisfied in the context of graph data modeling, due to the potential divergence and non-smoothness of the objective functions.
Intuitively, the fundamental difference between batched samples and graph data is their separability. While samples can easily be divided into mini-batches in any combination, a graph can only be separated under strict conditions. Consider an ideal graph that contains all the private graphs in each base. By aggregating local weights trained on each , the expectation equals the stationary solution trained on if and only if each base is separated from the ideal graph by a cut vertex, which means the bases form a complete and exclusive set. Since the federated network is formed in a stochastic manner, this condition is unlikely to be met.
Therefore, in the same form as batched data modeling, the local objective is defined as , where is the set of all possible sub-graphs with nodes or relations. Different from the usual setting where , can be sampled only once in each device. Such restriction makes the empirical measurement of objective used in practice as just a surrogate of the expectation with significant variance, and thus the -dissimilarity measurement. Consequently, the bounded assumption can hardly be guaranteed, neither can the convergence of federated learning algorithm.
Another problem that affects the convergence of federated algorithms is the smoothness of the objective for graph data modeling. Consider the global loss function and , the -Lipschitz smoothness requires . However, as known from the Lipschitz extension problem Aronsson-1967, whether a -Lipschitz continuous function applied to two graphs fulfills the Lipschitz condition depends on the case. It has been proved that Lipschitz extensions of higher-dimensional functions on graphs do not always exist Raskhodnikova-2016. Therefore, the global objective for federated graph modeling might not be -Lipschitz continuous, and neither might its expectation on the local device . Reviewing the definition of -inexact solution, it follows that if and the corresponding are not -Lipschitz continuous, then there may not exist a solution that makes the federated algorithm converge.
4 Proposed Solutions
As analyzed in the last section, the challenges of applying federated algorithms to Fed-RGCN arise from the potential divergence and non-smoothness of the objective functions. In this section, we propose a federated learning algorithm, called FedAlign, that uses optimal transport to regularize the model divergence and a weight penalty to enforce the objective to be quasi-Lipschitz continuous.
4.1 Basis alignment
As mentioned in Section 3.2, divergence between the empirical local objective and its expectation violates the bounded dissimilarity assumption, so the convergence of the federated learning algorithm is no longer guaranteed. Using to replace as the local objective can alleviate this problem, since the impact of a biased can be balanced by penalizing the difference between the local weights and the global ones Li-2020. However, such a solution does not work for Fed-RGCN as expected, because we only extract parts of , i.e., the basis, for aggregation, which means is not guaranteed to converge towards zero and is thus a biased approximation of .
To alleviate this problem, we can view the weights of Fed-RGCN as a sample drawn from a distribution. Assuming there is a stationary solution of weights drawn from a certain distribution and the global weight is an unbiased estimate of it, we can expect the distance between the distributions of local and global weights to converge to .
Optimal Transportation (OT) distance Villani-2008 is a widely used measure for this purpose. Intuitively, the OT distance can be viewed as the minimum amount of work (mass times distance moved) needed to turn one pile, i.e., a distribution defined on a given metric space, into another. It is also known as the earth mover's distance (EMD) Rubner-1997 in computer vision, following the same analogy. Compared with other metrics, such as the Euclidean distance or the Kullback-Leibler (KL) divergence, the OT distance has some nice properties that make it more suitable for comparing distributions related to graph data. For example, it does not assume that the compared distributions are in the same probability space, and unlike the KL divergence, the OT distance is symmetric in its two arguments.
Different choices of cost function lead to different OT distances. In its simplest form, the cost of a move is the distance between the two points, in which case the OT distance coincides with the Wasserstein-1 distance, i.e., the EMD. Formally, the EMD can be defined as follows. Given two probability vectors $r$ and $c$ of dimension $n$ and $m$, respectively, let $U(r, c)$ be the set of nonnegative matrices whose rows sum to $r$ and whose columns sum to $c$:
$$U(r, c) = \{P \in \mathbb{R}_{+}^{n \times m} \mid P\mathbf{1}_m = r,\; P^{T}\mathbf{1}_n = c\} \qquad (7)$$
where $\mathbf{1}_d$ is the $d$-dimensional vector of ones.
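For two multinomial random variables $X$ and $Y$ taking values in $\{1, \dots, n\}$ and $\{1, \dots, m\}$, with distributions $r$ and $c$ respectively, any matrix $P \in U(r, c)$ can be identified with a joint probability for $(X, Y)$ such that $p(X = i, Y = j) = p_{ij}$. Given a cost matrix $M$, in which $m_{ij}$ is the cost of moving $i$ to $j$, the EMD is defined as:
$$d_M(r, c) = \min_{P \in U(r, c)} \langle P, M \rangle \qquad (8)$$
which can be solved via linear programming.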
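To lower the cost of calculating the OT distance, we use the Sinkhorn distance Cuturi-2013 in place of the Wasserstein distance. The Sinkhorn distance modifies the objective of the Wasserstein distance by adding an entropy constraint:
$$d_{M,\alpha}(r, c) = \min_{P \in U_{\alpha}(r, c)} \langle P, M \rangle \qquad (9)$$
where $U_{\alpha}(r, c) = \{P \in U(r, c) \mid \mathrm{KL}(P \,\|\, rc^{T}) \le \alpha\}$ and $\alpha \ge 0$. The Sinkhorn distance can be calculated by iteratively scaling the rows and columns of a kernel matrix derived from $M$. The cost of computing the Sinkhorn distance is $O(d^2)$, while the complexity of calculating the Wasserstein distance is at least $O(d^3 \log d)$ for histograms of dimension $d$. We direct the reader to Cuturi-2013 for further reading.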
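The row and column scaling can be illustrated with a small pure-Python version of Sinkhorn's iterations. This is a minimal sketch, not the GeomLoss routine used in the experiments: it assumes the Lagrangian form of Cuturi-2013 in which the kernel is $e^{-\lambda M}$, and the names `sinkhorn_distance`, `reg` (playing the role of $\lambda$), and `n_iters` are introduced here for illustration only.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def sinkhorn_distance(r: List[float], c: List[float], cost: Matrix,
                      reg: float = 10.0, n_iters: int = 200) -> Tuple[float, Matrix]:
    """Entropy-regularised transport cost between histograms r and c."""
    n, m = len(r), len(c)
    # Kernel matrix: cheap moves get large entries, expensive moves small ones.
    kernel = [[math.exp(-reg * cost[i][j]) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(n_iters):
        # Scale rows so that the row sums of the plan match r.
        u = [r[i] / sum(kernel[i][j] * v[j] for j in range(m)) for i in range(n)]
        # Scale columns so that the column sums of the plan match c.
        v = [c[j] / sum(kernel[i][j] * u[i] for i in range(n)) for j in range(m)]
    plan = [[u[i] * kernel[i][j] * v[j] for j in range(m)] for i in range(n)]
    distance = sum(plan[i][j] * cost[i][j] for i in range(n) for j in range(m))
    return distance, plan


cost = [[0.0, 1.0], [1.0, 0.0]]
same_dist, _ = sinkhorn_distance([0.5, 0.5], [0.5, 0.5], cost)
shifted_dist, shifted_plan = sinkhorn_distance([0.9, 0.1], [0.1, 0.9], cost)
```

Each iteration costs one pass over the kernel matrix, which is where the quadratic per-iteration cost comes from. A larger `reg` brings the result closer to the unregularized EMD but makes the kernel entries smaller and the iterations numerically more delicate.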
Although the graph models naturally differ due to the inherent heterogeneity of graph data, our proposed Fed-RGCN only needs to aggregate the basis of . Thus, we only need to calculate the OT distance of basis from different bases. The proximal term that measures the difference between local and global weights is then formulated as the average OT distances between the basis in each layer:
where is the number of selected devices and is a hyper-parameter.
4.2 Weight Penalty
To improve the algorithm convergence, we further add a weight penalty to make the objective function quasi-Lipschitz continuous. Following the previous work Gulrajani-2017, we add a weight penalty into loss function:
where is a hyper-parameter to be tuned.
Essentially, this term penalizes the -norm of gradients larger than . The idea originates from the Wasserstein generative adversarial network (WGAN) Arjovsky-2017, where it is necessary to constrain the critic function to be 1-Lipschitz. Further work by Gulrajani-2017 shows that applying a soft constraint, i.e., the gradient penalty (GP), is more effective than hard weight clipping. Note that while the original weight penalty is an expectation calculated using for batched samples, the term in (11) can only be computed on the whole graph, since we cannot split the graph in the current federated setting. This may bias the penalty towards local data, and a further improvement on this point would be desirable.
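For reference, the gradient penalty of Gulrajani-2017 can be sketched as follows. The two-sided form $\lambda(\|g\|_2 - 1)^2$ is the one from WGAN-GP; the `one_sided` switch, which penalizes only norms above the target, is an assumption added here to reflect the description above, and neither variant is claimed to reproduce the exact term in (11). The default `lam=10.0` is the value used in WGAN-GP, not necessarily the one used in this paper.

```python
import math
from typing import List


def gradient_penalty(grad: List[float], lam: float = 10.0,
                     target: float = 1.0, one_sided: bool = False) -> float:
    """Soft Lipschitz constraint: penalise gradient norms away from `target`."""
    norm = math.sqrt(sum(g * g for g in grad))  # Euclidean norm of the gradient
    excess = norm - target
    if one_sided:
        # Assumed variant: only norms larger than the target are penalised.
        excess = max(0.0, excess)
    return lam * excess ** 2


steep = gradient_penalty([3.0, 4.0])                 # norm 5, strongly penalised
flat_two_sided = gradient_penalty([0.3, 0.4])        # norm 0.5, still penalised
flat_one_sided = gradient_penalty([0.3, 0.4], one_sided=True)  # not penalised
```

In a training loop the gradient would come from automatic differentiation of the loss with respect to the model inputs or embeddings; here it is passed in directly to keep the sketch self-contained.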
Combining the basis alignment and weight penalty, the local loss function of federated RGCN is defined as:
In Algorithm 1, we present the optimization process for federated relational data modeling. The resulting algorithm is referred to as FedAlign. Here, and are the hyper-parameters that control the basis alignment and weight penalty. denotes the devices that participate in the federated training. denotes the number of epochs trained on each local device before it sends its gradients to the server, and denotes the number of epochs for the whole training process. Note that a stochastic gradient descent (SGD) optimizer with a fixed learning rate is used in our implementation; however, other optimizers such as Adam Kingma-2017 can also be used.
5 Empirical evaluation
We evaluate the proposed algorithm with Fed-RGCN on entity classification task to verify its performance. Six settings are studied in the experiments with three federated algorithms, i.e., FedAVG, FedProx, FedAlign, and their variants with weight penalty (denoted by -L).
5.1 Synthetic Datasets
We use three commonly used public datasets in the Resource Description Framework (RDF) format, AIFB, MUTAG, and BGS Ristoski-2016, to test the performance of the proposed algorithm. The datasets contain different types of entities and relations, as shown in Table 1.
A specific type of entity has been labeled to be used as the classification target. The dataset provider has split them into two sets for training and testing. The number of classes and size of both sets can be seen in Table 2.
|Classes||Train Set||Test Set|
To mimic the setting of federated knowledge bases, we split each dataset into parts in the following way. First, for nodes that are not labeled, we randomly select types (except MUTAG, for which ) and sample nodes from the complete dataset for each base. We then shuffle the labeled nodes in the training set and split them into parts, one for each client. The labeled nodes in the test set are duplicated and stored in each client, but are kept unused during training. Finally, we add an edge from the complete dataset to a base if the base contains both its source and destination nodes. The number of nodes and edges in each base is listed in Table 3.
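Table 3: Number of nodes and edges in each base (mean ± standard deviation).
|Dataset||Nodes||Edges|
|AIFB||2993.00 ± 1737.77||7923.60 ± 7092.13|
|MUTAG||6537.10 ± 2634.75||4578.20 ± 1899.34|
|BGS||6123.40 ± 5667.58||4671.00 ± 5930.07|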
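The splitting procedure can be sketched in Python as follows. This is a simplified illustration under stated assumptions: the per-type selection of unlabeled nodes is replaced by uniform sampling, training labels are dealt out round-robin after shuffling, and the function name `split_into_bases` and all of its parameters are hypothetical.

```python
import random
from typing import Dict, Hashable, List, Sequence, Set, Tuple

Edge = Tuple[Hashable, str, Hashable]  # (source, relation, destination)


def split_into_bases(unlabeled: Sequence[Hashable],
                     labeled_train: Sequence[Hashable],
                     labeled_test: Sequence[Hashable],
                     edges: Sequence[Edge],
                     n_bases: int,
                     nodes_per_base: int,
                     seed: int = 0) -> List[Dict[str, object]]:
    """Split one knowledge graph into `n_bases` private bases."""
    rng = random.Random(seed)
    shuffled = list(labeled_train)
    rng.shuffle(shuffled)
    bases = []
    for k in range(n_bases):
        # Sample unlabeled nodes independently for every base.
        sampled = rng.sample(list(unlabeled), min(nodes_per_base, len(unlabeled)))
        # Each base gets a disjoint share of the labeled training nodes.
        train_part = shuffled[k::n_bases]
        # Test nodes are duplicated into every base.
        nodes: Set[Hashable] = set(sampled) | set(train_part) | set(labeled_test)
        # Keep an edge only if both endpoints ended up in this base.
        kept = [e for e in edges if e[0] in nodes and e[2] in nodes]
        bases.append({"nodes": nodes, "train": train_part, "edges": kept})
    return bases


toy_edges = [(0, "r", 10), (10, "r", 11), (1, "r", 4), (12, "q", 19), (2, "q", 3)]
toy_bases = split_into_bases(list(range(10, 20)), [0, 1, 2, 3], [4, 5],
                             toy_edges, n_bases=2, nodes_per_base=5, seed=1)
```

Because an edge survives only when both endpoints fall into the same base, edges that cross bases are dropped, which is the source of the edge loss visible in Table 3.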
As we can see from Table 3, although we choose the same number of types and relations for each base, the number of entities and edges can still differ tremendously. This is caused by the unbalanced distribution of entities across types, and by the loss of edges whose source and destination are in different bases. It indicates that, in the federated setting of relational data modeling, even with a balanced setup, the statistical heterogeneity of the datasets can still be significant.
We implemented FedAlign, FedAVG, and FedProx on RGCN models to compare their performance. The RGCN model is constructed following the previous work Schlichtkrull-2017 with hidden layers and a constant number of bases . All three federated algorithms are optimized via an SGD optimizer Bottou-1991 whose learning rate is . Note that and are hyper-parameters that need to be tuned.
The algorithms are mostly implemented using PyTorch Paszke-2017 and the DGL library Wang-2019c. The Sinkhorn algorithm is implemented with GeomLoss Feydy-2019. We also use Tune Liaw-2018 to grid search the hyper-parameters.
Hyper-parameter settings have a significant impact on the performance of RGCN as well as the federated algorithms. We focus on tuning four parameters: the number of bases , the learning rate , the factor of the basis alignment term , and the factor of the weight penalty . RGCN and optimizer related parameters, i.e., , and , affect all three algorithms, while only affects FedProx and FedAlign, which constrain the divergence between global and local weights. Surprisingly, the optimal hyper-parameters for all six settings are the same, namely , , and . While the values of and are widely used in practice Gulrajani-2017, the values of and are very different from the existing literature (in which is and ) Schlichtkrull-2017; Li-2020. This difference might be caused by the difference between batched samples and relational data.
In addition, two parameters control the amount of computation: is the number of training passes over the local dataset of each client in each round, and denotes the global number of epochs in which all devices are aggregated. Due to limited computational resources, we set the and for all datasets.
For separated learning (SP, i.e., each participant trains only on its own device), the three FL algorithms (i.e., FedAVG, FedProx, FedAlign), and their -Lipschitz regularized variants (i.e., FedAVG-L, FedProx-L, and FedAlign-L), we run 10 federated training runs on the separated datasets, then aggregate the weights of each base into a global basis, which is then synchronized to the local models before evaluation. Performance is measured by classification accuracy and shown in Table 4.
As we can see from Table 4, FedAlign outperforms other federated algorithms on all three datasets. Comparing with FedAVG and FedProx, FedAlign improves the classification accuracy by on average.
We notice that separated training outperforms most of the federated algorithms on the MUTAG dataset. Interestingly, the original RGCN also performs worse than traditional methods on the MUTAG and BGS datasets. Schlichtkrull-2017 attribute the problem to the nature of the datasets. Since MUTAG is a dataset of molecular graphs and BGS a dataset of rock types with hierarchical feature descriptions, their relations either indicate atomic bonds or merely the presence of a certain feature. Therefore, the labeled entities in them can only be connected via high-degree hub nodes, such as the name of a molecule or rock that encodes a certain feature. In other words, the graph structure is most likely star-shaped, and its information is stored in attributes rather than in structure. Modeling these kinds of relations requires an understanding of the contents of node attributes or of the structure of the complete graph. Compared with methods such as RDF2Vec embeddings Ristoski-2016a and Weisfeiler-Lehman (WL) kernels Shervashidze-2011; deVries-2015, which capture such information, RGCN uses only randomized embeddings and messages from neighborhoods, which limits the performance of the model.
This problem could be even worse in the federated learning scheme. Compared with graphs connected in a more centralized way, the structure of a star-shaped network is more likely to be broken by the distributed setting, which causes tremendous information loss. As shown in Table 3, each base in federated MUTAG contains only edges of the complete dataset, and federated BGS only . Since the dataset in each base can be very small, overfitting to the local structure may occur.
We randomly select one training log, shown in Fig. 2. The performance is evaluated in each global epoch using the aggregated global model on the test set. It can be seen that, compared with models trained on AIFB, models trained on MUTAG and BGS suffer from overfitting more significantly. Since federated algorithms aggregate parameters collected from each participant, models that overfit to local datasets will probably undermine the performance of the global model.
The -Lipschitz weight penalty can be viewed as a regularization on the model that prevents it from overfitting to local data, as analyzed in WGAN-GP Gulrajani-2017. We observed similar results in our experiments. As shown in Fig. 1(b) and 1(c), compared with the original algorithms, those with the -Lipschitz penalty, i.e., FedAVG-L, FedProx-L, and FedAlign-L, perform better in general. Moreover, for the MUTAG and BGS datasets, FedProx-L and FedAlign-L continue to improve after the performance declines in the early stage, while FedProx and FedAlign stay stationary for most of the training. The performance of the -Lipschitz constrained algorithms improved . This phenomenon indicates that the models have been stuck in local optima.
Though the proposed algorithm with basis alignment and weight penalty outperforms FedAVG and FedProx on relational data modeling, it should be noted that all the models trained on federated bases still underperform the model trained on the complete graph, as reported by Schlichtkrull-2017. As analyzed in Section 3.2, the underlying problems in federated data modeling are the non-separability of graph data, which leads to a divergence between the local loss function and its global counterpart, and the incomparability, which leads to the non-Lipschitz condition. The proposed workarounds can alleviate but hardly eliminate them. Moreover, the information loss, such as edges connecting entities in separate bases, cannot be restored in the federated setting. Both problems imply that future work on federated relational data modeling might focus on addressing the non-separability and incomparability of graph data.
We analyzed the problems of existing federated modeling on relational data and proposed the FedAlign algorithm to handle them. By using the OT distance to measure the divergence of bases in different models and adding a -Lipschitz weight penalty to the training process, the accuracy of Fed-RGCN can be improved at acceptable extra computational cost. Our empirical evaluation has shown that the proposed algorithm outperforms state-of-the-art methods, such as FedAVG and FedProx, on the SRL task. To the best of our knowledge, this is one of the earliest attempts to handle knowledge-graph related tasks via federated learning. The study of applying privacy-preserving techniques to graph data remains largely untouched. There are no widely applied methods for some important problems, such as entity alignment and link prediction, that can be performed without leaking private information. This situation limits the usage of relational data and calls for change. We hope our work provides useful insight for the community and pushes the research forward.