Personalized medicine holds promise to improve individual health by integrating a person’s genetics, environment, and lifestyle information to determine the best approach to preventing or treating disease (ashley2016towards). Precision medicine research has attracted considerable interest and investment during the past few decades (collins2015new). With the emergence of electronic health records (EHR) linked with biobank specimens, massive environmental data, and health surveys, we now have increasing opportunities to develop accurate personalized risk prediction models in a cost-effective way (li2020electronic).
Despite the availability of large-scale biomedical data, many demographic sub-populations are observed to be underrepresented in precision medicine research (west2017genomics; kraft2018beyond). For example, a disproportionate majority (75%) of participants in existing genomics studies are of European descent (martin2019clinical). The UK biobank, one of the largest biobanks, has more than 95% of European-ancestry (EA) participants (sudlow2015uk). It remains challenging to optimize prediction model performance for such underrepresented populations, when there is a substantial amount of heterogeneity in underlying distributions of data across populations (west2017genomics; landry2018lack; kraft2018beyond; duncan2019analysis). For some diseases, due to the differences in genetic architectures, linkage disequilibrium (LD) structures, and minor allele frequencies across ancestral populations, the performance of genetic risk prediction models in non-European populations has generally been found to be much poorer than in EA populations, most notably in African ancestry (AA) populations (duncan2019analysis)
. To advance precision medicine, it is crucial to improve the performance of statistical and machine learning models in underrepresented populations so as not to exacerbate health disparities.
We propose to address the lack of representation and the resulting disparities in model performance through two data integration strategies: (1) leveraging the shared knowledge from diverse populations, and (2) integrating larger bodies of data from multiple healthcare institutions. Data across multiple populations may share a certain amount of similarity that can be leveraged to improve model performance in an underrepresented population (cai2021unified). However, conventional methods, in which all data are combined and used indistinctly in training and testing, cannot tailor the prediction models to work well for a specific population (duncan2019analysis). To account for such heterogeneity and lack of representation, we propose to use transfer learning to transfer the shared knowledge learned from diverse populations to an underrepresented population, so that comparable model performance can be reached with much less training data (weiss2016survey). In addition, multi-institutional data integration can increase the sample size of the underrepresented populations and the diversity of the data (mccarty2011emerge). We propose to use federated learning to unlock multi-institutional EHR/biobank data, which overcomes two main barriers to institutional data integration. One is that individual-level information can be highly sensitive and cannot be shared across institutions (van2003data). The other is that the often-enormous size of EHR/biobank data makes it infeasible or inefficient to pool all data together due to challenges in data storage, management, and computation (kushida2012strategies). Therefore, as illustrated in Figure 1.1, our goal is to develop a federated transfer learning framework that incorporates data from diverse populations stored at multiple institutions to improve model performance in a target underrepresented population.
1.2 Related Work
Existing transfer learning methods primarily focus on settings where individual-level data can be shared. For example, CW19 studied minimax and adaptive methods for nonparametric classification in the transfer learning setting. Bastani18 studied estimation and prediction in high-dimensional linear models where the sample size of the auxiliary study is larger than the number of covariates. li2020transfer proposed a minimax optimal transfer learning algorithm in high-dimensional linear models and studied adaptation to the unknown similarity level. Li2021GGM studied transfer learning in high-dimensional Gaussian graphical models with false discovery rate control. LZCL21 studied transfer learning in high-dimensional generalized linear models (GLMs) and established minimax optimality. tian2021transfer studied adaptation to the unknown similarity level in transfer learning under high-dimensional GLMs. These individual data-based methods cannot be directly extended to the federated setting due to data sharing constraints and the potential heterogeneity across sites.
On the other hand, under data sharing constraints, most federated learning methods focus on settings where the true models are the same across studies. For example, many algorithms fit a common model to the data from each institution and then aggregate these local estimates through a weighted average (e.g., li2013statistical; chen2014split; lee2017communication; tian2016communication; lian2017divide; battey2018distributed; wang2019distributed). To improve efficiency, surrogate likelihood approaches have been adopted in recently proposed distributed algorithms (jordan2018communication; duan2019odal; duan2020learningb) to approximate the global likelihood. These methods cannot be easily extended to the federated transfer learning setting, where both data sharing constraints and heterogeneity are present.
Recently, liu2020integrative and Xia21 proposed distributed multi-task learning approaches that account for both study heterogeneity and data privacy. They allow site-specific regression parameters that are assumed to be similar across sites in magnitude and support and perform integrative analyses based on derived summary data. However, these methods require the sample sizes for different populations to be of the same order and can perform poorly when the underlying models of some source sites are significantly different from the target sites. Different from their work, we consider a more general setting where data from multiple populations are stored in multiple sites and a more challenging setting where the sample size from each population can be highly unbalanced. We focus on the model performance of an underrepresented target population without making assumptions that model parameters across populations share the support and magnitude. Instead, our methods are robust to the cases where the underlying models for some populations differ significantly from the target population.
1.3 Contributions and main results
We propose a methodological framework that incorporates heterogeneous data from diverse populations and multiple healthcare organizations to improve model fitting and prediction in an underrepresented population. Adopting transfer learning ideas, our methods tackle an important issue in precision medicine where sample sizes from different populations can be highly unbalanced. Our theoretical analysis and numerical experiments show that our methods are more accurate than existing methods and are robust to the level of heterogeneity. The federated learning methods we propose require only a small number of communications across participating sites, and can achieve performance comparable to the pooled analysis in which individual-level data are directly combined. To the best of our knowledge, this is the first work that tailors transfer learning and federated computing toward improving model performance in underrepresented populations. At a high level, our theoretical analysis shows that the proposed methods reduce the gap in estimation accuracy across populations, and reveals how estimation accuracy is influenced by communication budgets, privacy restrictions, and heterogeneity among populations. We demonstrate the feasibility and validity of our methods through numerical experiments and a real application to a multi-center study, in which we construct polygenic risk prediction models for Type II diabetes in the AA population.
2.1 Problem set-up and notation
We build our federated transfer learning methods based on sparse high-dimensional regression models (tibshirani1996regression; bickel2009simultaneous). These models have been widely applied to precision medicine research for both association studies and risk prediction models, due to the benefits of simultaneous model estimation and variable selection, and the desirable interpretability (qian2020fast).
We assume there are subjects in total from populations. We treat the underrepresented population of interest as the target population, indexed by , while the other populations are treated as source populations, indexed by . We assume data for the subjects are stored at different sites, where due to privacy constraints, no individual-level data are allowed to be shared across sites. We consider the case where is finite but is allowed to grow as the total sample size grows to infinity.
Let be the index sets of the data from the -th population in the -th site, and denote the corresponding sample size, for and . We assume the index sets are known and do not overlap with one another, i.e., for any , and . In precision medicine research, these index sets may be obtained from indicators of minority and disadvantaged groups, such as race/ethnicity, gender, and socioeconomic status. Denote , and . We are particularly interested in the challenging scenario , where the underrepresentation is severe. However, at certain sites, the relative sample compositions can be arbitrary. It is possible that some sites may not have data from certain populations, i.e., for some but not all for . We consider the high-dimensional setting where can be large and even much larger than and .
For the -th subject, we observe an outcome variable and a set of predictors including the intercept term. We assume that the target data on the -th site, , follow a GLM
with a canonical link function and a negative log-likelihood function
for some unknown parameter and uniquely determined by . Similarly, the data from the -th source population in the -th site are and they follow a GLM
with negative log-likelihood
for some unknown parameter .
Our goal is to estimate , using data from the populations stored at the sites. These data are heterogeneous at two levels. For data from different populations, differences may exist both in the regression coefficients, which characterize the conditional distribution , and in the underlying distribution of the covariates , a phenomenon known as covariate shift in related work (guo2020inference). For data from a given population, the distribution of covariates might also be heterogeneous across sites. In addition to the heterogeneity, we consider the setting where only summary-level data can be shared across sites. We assume the regression parameters to be distinct across populations and, within a given population, the regression parameter to be the same across sites.
Despite the presence of between-population heterogeneity, it is reasonable to believe that the population-specific models share some degree of similarity. For example, the genetic architectures, captured by regression coefficients, of many complex traits and diseases are found to be highly concordant across ancestral groups (lam2019comparative). It is important to characterize and leverage such similarities so that knowledge can be transferred from the source to the target population.
Under our proposed modeling framework, we characterize the similarities between the -th source population and the target based on the difference between their regression parameters, . We consider the following parameter space
where and are the upper bounds for the support size of and , respectively. Intuitively, smaller indicates a higher level of similarity, so that the source data can be more helpful for estimating in the target population. When is relatively large, incorporating data from source populations may be worse than using only data from the target population to fit the model, a phenomenon known as negative transfer in the machine learning literature (weiss2016survey). Since and are unknown in practice, we aim to devise an adaptive estimator that avoids negative transfer under unknown levels of heterogeneity across populations.
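A toy numerical illustration of negative transfer under an assumed linear model: when the source coefficients differ substantially from the target's, naively pooling all the data yields a worse target estimate than using the target data alone. The ISTA lasso solver and all settings below are ours, for illustration only.

```python
import numpy as np

def lasso_ista(X, y, lam, n_iter=400):
    """Minimize (1/2n)||y - Xb||^2 + lam*||b||_1 by proximal gradient (ISTA)."""
    n, p = X.shape
    step = n / np.linalg.eigvalsh(X.T @ X)[-1]   # 1 / Lipschitz constant
    b = np.zeros(p)
    for _ in range(n_iter):
        b = b - step * (X.T @ (X @ b - y) / n)
        b = np.sign(b) * np.maximum(np.abs(b) - step * lam, 0.0)
    return b

rng = np.random.default_rng(1)
p = 30
beta_tgt = np.zeros(p); beta_tgt[:5] = 1.0       # target model
beta_src = np.zeros(p); beta_src[10:15] = 1.0    # source with disjoint support
X_t = rng.standard_normal((100, p))
y_t = X_t @ beta_tgt + 0.5 * rng.standard_normal(100)
X_s = rng.standard_normal((1000, p))             # source dominates in size
y_s = X_s @ beta_src + 0.5 * rng.standard_normal(1000)

err_target_only = np.linalg.norm(lasso_ista(X_t, y_t, 0.05) - beta_tgt)
b_naive = lasso_ista(np.vstack([X_t, X_s]), np.concatenate([y_t, y_s]), 0.05)
err_naive_pool = np.linalg.norm(b_naive - beta_tgt)
# With a dissimilar, much larger source, the pooled fit is dragged toward
# the source model, so its error for the target exceeds the target-only fit.
```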
2.2 The proposed algorithm
Throughout, for real-valued sequences , we write if for some universal constant , and if for some universal constant . We say if and . We let
denote some universal constants. For a vector and an index set , we use to denote the subvector of corresponding to . For any vector , let be formed by setting all but the largest (in magnitude) elements of to zero. For a matrix , let and
denote the largest and smallest singular values of , and
. For a random variable and a random vector , define their sub-Gaussian norms as and .
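For concreteness, the truncation operator that keeps the largest-magnitude entries of a vector and zeroes out the rest can be written as a short helper (the function name `trunc` is ours):

```python
import numpy as np

def trunc(v, k):
    """Keep the k largest-magnitude entries of v and set the rest to zero
    (the hard-thresholding operator used in the notation above)."""
    out = np.zeros_like(v)
    if k > 0:
        keep = np.argsort(np.abs(v))[-k:]   # indices of the k largest |v_i|
        out[keep] = v[keep]
    return out
```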
To motivate our proposed federated transfer learning algorithm, we first consider the ideal case when site-level data can be shared. The transfer learning estimator of under the high-dimensional GLMs can be obtained via the following three-step procedure:
Step 1: Fit a regression model in each source population. For , we obtain
Step 2: Adjust for differences using target data. For , we obtain
Threshold via .
Step 3: Joint estimation using source and target data
where , , and are tuning parameters. Instead of learning directly from the target data, which have a limited sample size, we learn from the source populations and use them to “jumpstart” model fitting in the target population. More specifically, we learn the difference by offsetting each . In Step 3, we combine all the data together to jointly learn , where the estimated differences are adjusted for data from the -th source population. In contrast to existing GLM-based transfer learning methods, the above procedure has benefits in estimation accuracy and the flexibility to be implemented in the federated setting. Compared to recent work (tian2021transfer), the above procedure has a faster convergence rate, which is in fact minimax optimal under mild conditions. Moreover, our method learns independently in Steps 1 and 2, while other related methods (tian2021transfer; li2020transfer) conduct a pooled analysis with data from multiple populations. In a federated setting, finding a proper initialization is challenging for such a pooled estimator due to various levels of heterogeneity. In addition, compared to tian2021transfer, the above approach has fewer assumptions on the level of heterogeneity for data from different populations.
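As a concrete illustration, the three steps can be sketched for the linear model with an l1 penalty. The function names, the ISTA solver, and the tuning values below are ours and purely illustrative; in practice the tuning parameters would be chosen by cross-validation.

```python
import numpy as np

def lasso(X, y, lam, n_iter=400):
    """Minimize (1/2n)||y - Xb||^2 + lam*||b||_1 by proximal gradient (ISTA)."""
    n, p = X.shape
    step = n / np.linalg.eigvalsh(X.T @ X)[-1]   # 1 / Lipschitz constant
    b = np.zeros(p)
    for _ in range(n_iter):
        b = b - step * (X.T @ (X @ b - y) / n)
        b = np.sign(b) * np.maximum(np.abs(b) - step * lam, 0.0)
    return b

def transfer_lasso(X_tgt, y_tgt, sources, s_delta=5,
                   lam_w=0.03, lam_d=0.08, lam_joint=0.03):
    """Three-step transfer learning for sparse linear models (sketch).

    `sources` is a list of (X_k, y_k) arrays from the source populations."""
    p = X_tgt.shape[1]
    Xs, ys, deltas = [X_tgt], [y_tgt], []
    for X_k, y_k in sources:
        # Step 1: fit a sparse model within each source population.
        w_k = lasso(X_k, y_k, lam_w)
        # Step 2: regress the offset target residual on X_tgt to learn the
        # contrast delta_k = beta_target - w_k, then hard-threshold it.
        d_k = lasso(X_tgt, y_tgt - X_tgt @ w_k, lam_d)
        keep = np.argsort(np.abs(d_k))[-s_delta:]
        d_thr = np.zeros(p)
        d_thr[keep] = d_k[keep]
        deltas.append(d_thr)
        # Adjust source outcomes toward the target model:
        # y_k + X_k delta_k is approximately generated under beta_target.
        Xs.append(X_k)
        ys.append(y_k + X_k @ d_thr)
    # Step 3: joint estimation with target data and adjusted source data.
    beta = lasso(np.vstack(Xs), np.concatenate(ys), lam_joint)
    return beta, deltas
```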
The higher-order terms are omitted given that the initial value is sufficiently close to the true parameter. Using these surrogate losses, the sites only need to share three sets of summary statistics, , the score vector and the Hessian matrix . For , we define
The functions are the combined surrogate log-likelihood functions for the -th population, based on some previous estimate and the corresponding gradients obtained from the sites. We then follow strategies similar to (1)-(3), but replace the full likelihood with the surrogate losses, to construct a federated transfer learning estimator for , as detailed in Algorithm 1.
We discuss strategies for the initialization of and in Section 2.3. Algorithm 1 requires iterations, where within each iteration we collect the first- and second-order derivatives calculated at each site based on the current parameter values. In practice, when iterative communication across sites is not preferred, we can choose . We show in Section 3 that additional iterations can improve the estimation accuracy. Proper choices of tuning parameters are also discussed in the sequel. In practical implementation, they can be chosen by cross-validation.
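For intuition, one communication round of this scheme can be sketched in the linear-model case: each site returns the gradient and Hessian of its local loss at the current value, and the leading site minimizes the resulting l1-penalized quadratic surrogate. This is a simplified sketch; the function names and the proximal gradient solver are ours.

```python
import numpy as np

def local_summary(X, y, beta0):
    """Per-site summary statistics at beta0: gradient and Hessian of the
    local least-squares loss, plus the local sample size."""
    n = X.shape[0]
    return X.T @ (X @ beta0 - y) / n, X.T @ X / n, n

def federated_round(site_data, beta0, lam, n_iter=400):
    """One round: pool per-site gradients and Hessians at beta0, then
    minimize g'(b - beta0) + (1/2)(b - beta0)' H (b - beta0) + lam*||b||_1
    by proximal gradient descent at the leading site."""
    stats = [local_summary(X, y, beta0) for X, y in site_data]
    N = sum(n for _, _, n in stats)
    g = sum(n * gm for gm, _, n in stats) / N   # sample-size-weighted gradient
    H = sum(n * Hm for _, Hm, n in stats) / N   # sample-size-weighted Hessian
    step = 1.0 / np.linalg.eigvalsh(H)[-1]      # 1 / largest eigenvalue
    b = beta0.copy()
    for _ in range(n_iter):
        b = b - step * (g + H @ (b - beta0))    # gradient of the surrogate
        b = np.sign(b) * np.maximum(np.abs(b) - step * lam, 0.0)
    return b
```

For a quadratic (linear-model) loss, the surrogate coincides with the pooled loss, so a single round already recovers the pooled l1-penalized estimate.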
When the source models are substantially different from the target model, the learned estimator may not be better than a target-only estimator, which is obtained using only the target data. We thus propose to increase the robustness of transfer learning by optimally combining with a target-only estimator. This step guarantees that, loosely speaking, the aggregated estimator has prediction performance comparable to the best prediction performance among all the candidate estimators (RT11; Tsybakov14; Qagg). To this end, let denote a federated target-only estimator, whose construction is detailed in the Supplementary Material. This procedure can be aligned with Algorithm 1 in the implementation to reduce the number of communications. With and , we perform aggregation using some additional validation data from the target population in a leading site (denoted as the -th site), which should not overlap with the training data used for obtaining and . In the leading site, we denote the validation data by , with sample size for some , where is the sample size of the training data in the leading site from the target population. Define . We compute
The proposed estimator is then defined as . Based on our simulation study and real data example, the size of the validation data can be relatively small compared to the training data, and cross-fitting may be used to make full use of all the data. In practice, if there is strong prior knowledge indicating that the level of heterogeneity across populations is low, the aggregation step may be skipped.
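A minimal sketch of a validation-based aggregation step, using simple exponential weights as a stand-in for the Q-aggregation-type scheme referenced above (the function name and weighting constant are ours):

```python
import numpy as np

def aggregate(candidates, X_val, y_val):
    """Exponential-weight aggregation of candidate coefficient vectors on
    held-out target validation data; the weights concentrate on the
    candidate with the smallest validation loss."""
    losses = np.array([np.mean((y_val - X_val @ b) ** 2) for b in candidates])
    w = np.exp(-0.5 * len(y_val) * (losses - losses.min()))
    w = w / w.sum()
    beta_agg = sum(wi * b for wi, b in zip(w, candidates))
    return beta_agg, w
```

When one candidate (e.g., the transfer learning estimator) fits the target validation data clearly better, its weight approaches one, which is how the aggregated estimator avoids negative transfer.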
(Avoid sharing Hessian matrices) Algorithm 1 requires each site to transmit Hessian matrices to the leading site, which may not be a concern when is relatively small. When is large, we provide several options to reduce the communication cost of sharing Hessian matrices: (1) If the distributions of the covariates are homogeneous across sites for a given population, we propose to use Algorithm 2, which only requires the first-order gradients from each site. (2) When the distributions of the covariates are heterogeneous across sites, if it is possible to fit a density ratio model between each dataset and the leading target data, we can still use the leading target data to approximate the Hessian matrices of the other datasets, through the density ratio tilting technique proposed in duan2019heterogeneity. (3) We can leverage the sparsity structures of the population-level Hessian matrices, denoted by , to reduce the communication cost. For example, when constructing polygenic risk prediction models, existing knowledge of LD structure may imply similar block-diagonal structures in the Hessian matrices. In such cases, we can apply thresholding to the Hessian matrices and share only the resulting blocks. (4) As demonstrated in our simulation study and real data application, our algorithm with one round of iteration (T=1) already achieves performance comparable to the pooled analysis. Thus, if choosing , each site will only need to share Hessian matrices once. If more iterations are allowed, we propose an alternative algorithm where only the first-order gradients are needed in the remaining iterations. More details are included in the supplements.
2.3 Leveraging local Hessian under design homogeneity
When the distribution of in the -th population is the same across sites, we introduce a modified version of Algorithm 1 that requires each participating site to share only the first-order gradients. This method generalizes the surrogate likelihood approach proposed by wang2017efficient and jordan2018communication to the transfer learning framework and enjoys communication efficiency. The idea is to use local data to approximate the Hessian matrices across multiple sites. We require that the leading site (the -th site) has data from all the populations. We use the empirical Hessian matrix obtained at the leading site as an approximation of the global Hessian in each population. For , denote
is the empirical Hessian for the -th population at based on the samples in the leading site.
Without sharing the Hessian matrices, Algorithm 2 largely reduces the communication cost. However, one limitation is that it requires the distribution of in the -th population to be homogeneous across sites for any fixed . Second, its reliable performance requires the existence of a single site with relatively large samples from all populations. Otherwise, the local Hessian approximation can be inaccurate and lead to large estimation errors. In practice, however, such a desirable local site may not always exist. We provide a theoretical comparison in Section 3 showing that a larger might be needed in Algorithm 2 to achieve the same estimation accuracy as Algorithm 1.
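A sketch of the gradient-only update in the linear-model case, under the homogeneous-design assumption (the function names are ours): the leading site's empirical Hessian replaces the global one, with a first-order correction term so that the surrogate's gradient at the current value matches the pooled gradient.

```python
import numpy as np

def gradient_only_round(X_lead, y_lead, global_grad, beta0, lam, n_iter=400):
    """Surrogate update in the spirit of the surrogate-likelihood approach:
    only gradients cross site boundaries, and the leading site's Hessian
    stands in for the global Hessian."""
    n = X_lead.shape[0]
    H_loc = X_lead.T @ X_lead / n                     # local Hessian
    g_loc = X_lead.T @ (X_lead @ beta0 - y_lead) / n  # local gradient at beta0
    shift = global_grad - g_loc                       # first-order correction
    step = 1.0 / np.linalg.eigvalsh(H_loc)[-1]
    b = beta0.copy()
    for _ in range(n_iter):
        # Gradient of the surrogate: local quadratic plus correction, so at
        # beta0 it equals the pooled gradient.
        grad = g_loc + H_loc @ (b - beta0) + shift
        b = b - step * grad
        b = np.sign(b) * np.maximum(np.abs(b) - step * lam, 0.0)
    return b
```

When the covariate distributions are homogeneous, the local Hessian is close to the global one and a round or two of this update approaches the pooled penalized estimate; with heterogeneous designs the approximation degrades, matching the limitation discussed above.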
2.4 Initialization strategies
The initialization determines the sample size requirements as well as the number of iterations Algorithms 1 and 2 need to reach convergence. With data from more than one population, one needs to balance the sample sizes and similarities across populations.
Here we offer two initialization strategies, namely the single-site initialization and the multi-site initialization. The ideal scenario for initialization is that one site has relatively large sample sizes for all populations. In such a case, we initialize and using the single-site initialization. If we cannot find a site with enough data from all the populations, the multi-site initialization can be used.
Strategy 1: single-site initialization. Find , such that for all . In site , we initialize and by applying the global transfer learning approach introduced in equations (1)-(3). For example, the All of Us Precision Medicine Initiative aims to recruit 1 million Americans, and estimates from early recruitment show that up to 75% of participants are from underrepresented populations. Such a dataset can be treated as an initialization site or leading site.
Strategy 2: multi-site initialization. We first find , the site with the largest sample size from the -th population. In site , if the sample size of the -th population is much smaller than the total sample size, we initialize by treating the -th population as the target and the other populations as the source, and apply the transfer learning approach introduced in equations (1)-(3). If the -th population is the dominating population, we can simply initialize using only its own data. The same procedure applies to the initialization of . For example, when constructing polygenic prediction models, the UK biobank has around 500k EA samples but only 3k AA samples. Thus, if the UK biobank is selected for initialization of the EA population, the initialization can be done using only the EA samples. On the contrary, if the UK biobank is selected for initialization of the AA population, a transfer learning approach is needed to improve the accuracy of the initialization by incorporating both EA and AA samples.
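The choice between the two strategies can be sketched as a simple selection rule over a table of per-site, per-population sample sizes (a toy illustration; the eligibility threshold `min_per_pop` and the site names in the usage example are hypothetical):

```python
def choose_initialization(sizes, min_per_pop=100):
    """sizes: dict mapping site -> {population: sample size}.
    Return ('single-site', site) if some site has at least min_per_pop
    samples from every population (Strategy 1), otherwise
    ('multi-site', {population: best site}) (Strategy 2)."""
    pops = {p for by_pop in sizes.values() for p in by_pop}
    eligible = [s for s, by_pop in sizes.items()
                if all(by_pop.get(p, 0) >= min_per_pop for p in pops)]
    if eligible:
        # Strategy 1: pick the eligible site with the most total samples.
        best = max(eligible, key=lambda s: sum(sizes[s].values()))
        return 'single-site', best
    # Strategy 2: per population, pick the site with the most samples.
    return 'multi-site', {p: max(sizes, key=lambda s: sizes[s].get(p, 0))
                          for p in pops}
```

For example, with a UK-biobank-like site (many EA, few AA samples) and a more balanced second site, raising the threshold shifts the choice from the largest site to the balanced one, and an infeasible threshold falls back to the multi-site strategy.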
3 Theoretical guarantees
Let and denote the population Hessian matrices for the -th population at the -th site. Let , , denote the population Hessian matrices for the -th population across all sites. We assume the following condition for the theoretical analysis.
For , , are independent and uniformly bounded with mean zero and covariance with . The covariance matrices and the Hessian matrices are all positive definite for and .
Condition 3.2 (Lipschitz condition of ).
For , the random noises are independent sub-Gaussian with mean zero. The second-order derivative is uniformly bounded and for any .
Condition 3.1 assumes uniformly bounded designs with positive definite covariance matrices. The distribution of can be different for different . This assumption is more realistic in biomedical settings than the homogeneity assumptions in, say, jordan2018communication. Heterogeneous covariates are allowed because the Hessian matrices from different sites are transmitted in Algorithm 1. In contrast, Algorithm 2, which only requires transmitting the gradients across sites, requires a stricter version of Condition 3.1, stated as Condition 3.3. On the other hand, the positive definiteness assumption is only required for the Hessian matrices involved in the initialization and the pooled Hessian matrix . Moreover, with unbounded covariates, one may consider relaxing the uniformly bounded designs to sub-Gaussian designs. We comment that our theoretical analysis still carries through with sub-Gaussian designs, but the convergence rate will be inflated by some factors of . Condition 3.2 assumes standard Lipschitz conditions that hold for linear, logistic, and multinomial models.
For the tuning parameters, we take
for some constants and . We set at the magnitude of to simplify the theoretical analysis. In fact, depends on the sample size of the initialization site. With initialization strategy 1, a reasonable choice of is . With initialization strategy 2, a reasonable choice of is . In practice, the tuning parameters can be selected by cross-validation.
We first show the convergence rate of the pooled transfer learning estimator in the following lemma.
Lemma 3.1 (Convergence rate of the global transfer learning estimator).
The convergence rate of is minimax optimal in -norm in the parameter space given that according to LZCL21.
Lemma 3.1 demonstrates that the pooled estimator has optimal rates under mild conditions. Its convergence rate is faster than the target-only minimax rate when and . The sample size condition of Lemma 3.1 is relatively mild. First, is easily satisfied as our target population is underrepresented. The condition that suggests that it is beneficial to exclude too-small samples as source data. The condition that and requires that the similarity among different populations is sufficiently high. In practice, this assumption can be violated. In this case, Corollary 3.3 shows that the aggregation step discussed in Section 2.2 prevents negative transfer and guarantees that the estimation error is no worse than using only the target data.
Lemma 3.1 also shows the benefits of our two data integration strategies. When we integrate data across multiple sites, becomes larger, which relaxes the sparsity conditions and improves the convergence rate. On the other hand, when we incorporate data from diverse populations, the total sample size is increased which also improves the convergence rate.
3.1 Convergence rate of Algorithm 1
In this subsection, we first provide in Theorem 3.1 a general conclusion which describes how the convergence rate of Algorithm 1 relies on the initial values. We then provide the convergence rates of Algorithm 1 under initialization Strategies 1 and 2 in Corollaries 3.1 and 3.2, respectively.
Theorem 3.1 (Error contraction of Algorithm 1).
Theorem 3.1 establishes the convergence rate of under certain conditions on the initializations. The conditions in guarantee that, for all and , and converge to and in -norm, respectively. For large enough , the convergence rate of is , which is the minimax rate for estimating in . Hence, the proposed distributed estimators converge to the global minimax estimators. With proper initialization, the smallest satisfying may be very small. Detailed analysis based on the initialization strategies proposed in Section 2.4 is provided in the sequel.
Comparing Theorem 3.1 with Lemma 3.1, we see some important trade-offs in federated learning. First, the larger estimation error of with small , in comparison to the pooled version, is a consequence of leveraging summary information rather than the individual-level data. Second, while the accuracy of improves as increases, the communication cost also increases. A balance between communication efficiency and estimation accuracy needs to be determined based on the practical constraints.
To better understand the convergence rate, we investigate the initialization strategies proposed in Section 2.4. Under the single-site strategy, we have the following conclusion.
Corollary 3.1 (Convergence rate of Algorithm 1 with single-site initialization).
Corollary 3.1 uses the result that and under the current conditions. We see that after number of iterations, has the same convergence rate as the global estimator . If and for some finite and , then only a constant number of iterations is needed.
Next, we study the performance of Algorithm 1 when using the multi-site initialization strategy. For simplicity, we study the case where . In other words, is initialized based only on the data from the -th population in site , i.e.,
Corollary 3.2 (Convergence rate of Algorithm 1 with multi-site initialization).
Corollary 3.2 uses the result that and in the current setting. In this case, we see that after number of iterations, has the same convergence rate as the global estimator .
With the above analyses, we can evaluate the convergence rate for the proposed estimator obtained after the aggregation step.
Through aggregation, we achieve an estimator whose estimation performance is comparable to the better of the target-only estimator and the transfer learning estimator .
3.2 Convergence rate of Algorithm 2
In this section, we provide theoretical guarantees for Algorithm 2, which leverages local Hessians and only transmits first-order information across sites. As discussed before, it relies on the homogeneity assumption on the distribution of across sites for each given .
Condition 3.3 (Homogeneous covariates).
Assume that and are identically distributed for any and .
To simplify the theoretical result, we focus on the case . That is, only one source population is in use. The more general case, where can be any finite integer, can be analyzed similarly but the results are harder to interpret.
In the next theorem, we analyze the error contraction behavior of Algorithm 2.