FRAug: Tackling Federated Learning with Non-IID Features via Representation Augmentation

05/30/2022
by Haokun Chen, et al.
Siemens AG

Federated Learning (FL) is a decentralized learning paradigm in which multiple clients collaboratively train deep learning models without centralizing their local data and hence preserve data privacy. Real-world applications usually involve a distribution shift across the datasets of the different clients, which hurts the generalization ability of the clients to unseen samples from their respective data distributions. In this work, we address the recently proposed feature shift problem where the clients have different feature distributions while the label distribution is the same. We propose Federated Representation Augmentation (FRAug) to tackle this practical and challenging problem. Our approach generates synthetic client-specific samples in the embedding space to augment the usually small client datasets. For that, we train a shared generative model to fuse the clients' knowledge, learned from different feature distributions, to synthesize client-agnostic embeddings, which are then locally transformed into client-specific embeddings by Representation Transformation Networks (RTNets). By transferring knowledge across the clients, the generated embeddings act as a regularizer for the client models and reduce overfitting to the local original datasets, hence improving generalization. Our empirical evaluation on multiple benchmark datasets demonstrates the effectiveness of the proposed method, which substantially outperforms the current state-of-the-art FL methods for non-IID features, including PartialFed and FedBN.



1 Introduction

Federated Learning (FL) is a machine learning paradigm in which a shared model is collaboratively trained using decentralized data sources. In the classic FL algorithm, FedAvg (McMahan et al., 2017), the central server obtains the model by iteratively aggregating and averaging the optimized parameters from the active clients. This does not require direct access to the clients' local data and therefore preserves data confidentiality.

However, traditional FL suffers from performance degradation when the data is heterogeneous across the clients (Zhao et al., 2018). Several methods (Li et al., 2020b; T Dinh et al., 2020) were developed to tackle settings in which the clients have different label distributions. We focus on the recently proposed and still underexplored problem of heterogeneity in the feature space (Li et al., 2021b), i.e., the distributions of the input features differ across clients, so the overall data is non-IID. This problem setting is of great practical relevance since it arises in multiple real-world scenarios where different entities hold local data following different feature distributions. For instance, in the healthcare sector, different clinical centers use different scanners and data acquisition protocols, resulting in a distribution shift in the input space of the collected medical images (Dou et al., 2019). Analogously, in industrial manufacturing, the data collected by machines from different manufacturers, possibly using different sensors and operating in different production plants, exhibits a feature distribution shift. Most importantly, although these commercial entities have the same application, e.g., diagnosing the same type of cancer or detecting the same types of anomalies, they may not be willing to share their original data to avoid a competitive disadvantage. Therefore, FL for non-IID features provides the benefit of training a machine learning model collaboratively in a decentralized, data-privacy-preserving manner.

In this paper, we propose Federated Representation Augmentation (FRAug) to address the feature shift problem in FL. In FRAug, we first aggregate the knowledge of the different clients in the feature space by learning a shared generator, which is optimized to produce client-agnostic feature embeddings. Then, a Representation Transformation Network (RTNet) is locally trained on each client to transform the client-agnostic representations into client-specific ones. Finally, each client trains its classification model on both its local dataset and the client-specific synthetic embeddings, which improves its generalization ability. We demonstrate the effectiveness of FRAug by empirically evaluating it on real-world FL benchmarks with non-IID features, where our method outperforms the classic FL baseline FedAvg (McMahan et al., 2017), as well as recently proposed methods for FL with feature shift.

2 Related Work

Federated Learning (FL): Federated Averaging (FedAvg) (McMahan et al., 2017) is one of the classic FL algorithms for training machine learning models using decentralized data sources. This simple paradigm suffers from performance degradation when the data is heterogeneous (Kairouz et al., 2021; Li et al., 2020a). Numerous studies have addressed label space heterogeneity, i.e., imbalanced class distributions across clients, by regularizing the local update with a proximal term (Li et al., 2020b), personalizing client models (Arivazhagan et al., 2019; Fallah et al., 2020; T Dinh et al., 2020; Li et al., 2021a), utilizing shared local data (Zhao et al., 2018; Liu et al., 2021; Gong et al., 2021), introducing additional proxy datasets (Li and Wang, 2019; Lin et al., 2020; Gong et al., 2022), or performing data-free knowledge distillation (Lopes et al., 2017) in the input space (Hao et al., 2021; Zhang and Yuan, 2021; Zhang et al., 2022) or the feature space (He et al., 2020; Zhu et al., 2021). However, only a few studies address heterogeneity in the feature space, i.e., non-IID features. Recently, Andreux et al. (2020) showed that Batch Normalization (BN) layers (Ioffe and Szegedy, 2015) with local statistics improve the robustness of the FL model to inter-center data variability and yield better out-of-domain generalization, while FedBN (Li et al., 2021b) provided a more theoretical analysis of the benefits of local BN layers for FL with feature shift. PartialFed (Sun et al., 2021) empirically found that partially initializing the client models can alleviate the effect of feature distribution shift. In this work, we tackle the problem of non-IID features in FL via a client-specific data augmentation approach performed in the embedding space. In particular, client-agnostic embeddings are first synthesized by a shared generator that captures knowledge from the different distributions, and are then personalized by separate client-specific models. Training the local models with the resulting client-specific embeddings yields better generalization to unseen data from the local distribution.

Cross-Domain Learning: The problem of learning on centralized data with non-IID features, i.e., cross-domain data, has been widely studied in the context of Unsupervised Domain Adaptation (UDA) (Wilson and Cook, 2020), where a model is trained using multiple source domains and finetuned using an unlabelled target domain, and Domain Generalization (DG) (Zhou et al., 2021a), where, unlike in UDA, the target domain data is not accessible during training. A variety of methods have been proposed for UDA and DG. CROSSGRAD (Shankar et al., 2018) used adversarial gradients obtained from a domain classifier to augment the training data. L2A-OT (Zhou et al., 2020) trained a generative model to map training samples into pseudo-novel domains. MixStyle (Zhou et al., 2021b) performed feature-level augmentation by interpolating the style statistics of the output features from different network layers. While these methods assume centralized access to all datasets from the different domains, we address the setting where the datasets are decentralized and cannot be shared due to privacy concerns.

3 Methodology

3.1 Problem Statement: Federated Learning (FL) with Non-IID Features

In this work, we address an FL problem setting with non-IID features, which we describe in the following. Let X be the input space, Z the feature (embedding) space, and Y the output space. Let θ denote the parameters of the classification model trained in an FL setting involving one central server and K clients. The model consists of two components: a feature extractor F: X → Z parameterized by θ_f, and a prediction head H: Z → Y parameterized by θ_h. We assume that a dataset D_k = {(x_i, y_i)}_{i=1}^{n_k}, containing private data, is available on each client k ∈ {1, ..., K}, where n_k denotes the number of samples in D_k and C denotes the number of classes. As discussed in (Kairouz et al., 2021), FL with non-IID data can be described by a distribution shift between the local datasets: P_k(x, y) ≠ P_j(x, y) for k ≠ j, where P_k(x, y) defines the joint distribution over the input space X and the label space Y on client k. The addressed problem setting, i.e., FL with non-IID features, covers (1) covariate shift: the marginal distribution P_k(x) varies across clients, while P_k(y|x) is the same, and (2) concept shift: the conditional distribution P_k(x|y) varies across clients, while P_k(y) is the same (Li et al., 2021b). From the perspective of the cross-domain learning literature (Wilson and Cook, 2020; Zhou et al., 2021a), the local data of every client can be viewed as a separate domain. In this work, we use cross-domain FL and FL with non-IID features interchangeably.

3.2 Motivational Case Study

To motivate our representation augmentation algorithm, we present an empirical analysis that addresses the following research question: In a data-scarce scenario, does finetuning only the prediction head using additional feature embeddings lead to a performance improvement? For this, we conduct local training experiments on each client, i.e., without federated learning. First, we optimize a classification model with only a small subset of the local dataset D_k. As a result, each client has ca. 100 to 1000 data samples available for training, which matches the experimental settings used in prior FL work (McMahan et al., 2017; Li et al., 2021b). Then, we freeze the parameters θ_f of the feature extractor F and finetune only the prediction head H with the embeddings of the remaining examples of D_k. Finally, we evaluate both classification models, i.e., the model fully trained on the small subset of the data, and its finetuned version whose prediction head is further trained with the embeddings of the additional examples. We note that both models use the same feature extractor, which was trained only with the small subset of the data.

Method          OfficeHome                                PACS
                A      C      P      R      Avg          A      C      P      S      Avg
w/o finetune    35.80  45.54  67.04  61.16  52.42        82.37  86.08  92.01  87.52  87.00
w.  finetune    69.96  72.31  86.04  80.50  77.20        88.46  88.61  97.08  92.21  91.59
Table 1: Performance comparison of the classification model with (w.) and without (w/o) prediction head finetuning using embeddings of additional examples.

Table 1 shows the classification performance on the OfficeHome (Venkateswara et al., 2017) and PACS (Li et al., 2017) datasets, where each client holds data from one domain. We find that optimizing only the prediction head with additional embeddings leads to an average improvement of 24.78 percentage points for OfficeHome and 4.59 percentage points for PACS, a substantial performance boost. These results show that the feature extractor, although trained with less data, still captures useful information when exposed to unseen image samples, and give evidence of the applicability and effectiveness of data augmentation methods in the embedding space in scarce-data settings. We leverage this observation when designing the proposed method.
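
The protocol behind Table 1 boils down to freezing the feature extractor and training only the prediction head on embeddings of additional examples. Below is a minimal PyTorch sketch of this head-only finetuning step; the module names, data loader, and hyperparameters are illustrative assumptions, not the exact setup used in the paper.

```python
import torch
import torch.nn as nn

def finetune_head_on_embeddings(feature_extractor: nn.Module,
                                head: nn.Module,
                                extra_loader,          # yields (x, y) from the additional split
                                epochs: int = 5,
                                lr: float = 1e-3,
                                device: str = "cpu") -> nn.Module:
    """Freeze the feature extractor and finetune only the prediction head
    on the embeddings of additional examples (head-only finetuning protocol
    of the case study; hyperparameters are placeholders)."""
    feature_extractor.to(device).eval()
    head.to(device).train()
    for p in feature_extractor.parameters():   # keep the extractor fixed
        p.requires_grad_(False)

    optimizer = torch.optim.SGD(head.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for _ in range(epochs):
        for x, y in extra_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():               # embeddings only, no extractor gradients
                emb = feature_extractor(x)
            loss = criterion(head(emb), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return head
```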

3.3 FRAug: Federated Representation Augmentation

To tackle FL with non-IID features, we propose Federated Representation Augmentation (FRAug). Our algorithm is built upon FedAvg (McMahan et al., 2017), which is the most widely used FL strategy. In FedAvg, the central server sends a copy of the global model θ to each client to initialize its local model θ_k. After training on the local dataset D_k, the updated client-specific models are sent back to the central server, where they are averaged and used as the new global model. Such communication rounds are repeated until predefined convergence criteria are met. Similarly, the training process of FRAug (Algorithm 1) can be divided into two stages: (1) the Server Update, where the central server aggregates the parameters uploaded by the clients and distributes the averaged parameters to each client, and (2) the Client Update, where each client receives the model parameters from the central server and performs local optimization. Unlike FedAvg, where only the local dataset of each client is used for training, FRAug generates additional feature embeddings to finetune the prediction head of the local classification model. Concretely, we train a shared generator G and a local Representation Transformation Network (RTNet) R_k for each client, which together produce domain-specific (since each client has data from a different domain, we use domain-specific and client-specific interchangeably) synthetic feature embeddings for each client to augment its local data in the embedding space. Hereby, the shared generator captures knowledge from all the clients to generate client-agnostic embeddings, which are then personalized by the local RTNet into client-specific embeddings. In the following, we provide a more detailed explanation of FRAug.

3.3.1 Server Update

At the beginning of the training, the server initializes the parameters θ of the classification model, as well as the parameters w_G of the shared generator G. In each communication round, all clients receive the aggregated model parameters and conduct the Client Update procedure in parallel. Subsequently, the server securely aggregates the optimized model parameters from all the clients into a single model that is used in the next communication round.
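
The server-side aggregation follows the FedAvg pattern, applied both to the classification model and to the shared generator. The sketch below assumes uniform (unweighted) averaging of the clients' state dicts; the original FedAvg weights clients by dataset size, and in FRAug the locally kept BN parameters (see Section 3.3.2) would be excluded from this aggregation.

```python
import copy
from typing import Dict, List, Tuple
import torch

StateDict = Dict[str, torch.Tensor]

def server_update(client_model_states: List[StateDict],
                  client_generator_states: List[StateDict]) -> Tuple[StateDict, StateDict]:
    """FedAvg-style aggregation of the classification model and the shared generator.
    Uniform averaging is an assumption; locally kept BN parameters (FedBN-style)
    would be skipped in a full FRAug implementation."""
    def average(states: List[StateDict]) -> StateDict:
        avg = copy.deepcopy(states[0])
        for key in avg:
            stacked = torch.stack([s[key].float() for s in states], dim=0)
            avg[key] = stacked.mean(dim=0)   # element-wise mean over clients
        return avg

    return average(client_model_states), average(client_generator_states)
```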

3.3.2 Client Update

An overview of the client learning procedure is illustrated in Figure 1. At the beginning of the first communication round, each client locally initializes a Representation Transformation Network (RTNet) R_k, parameterized by w_R^k. Subsequently, each client receives the classification model parameters θ and the generator parameters w_G from the server and conducts its local update steps. Each local update comprises two stages: (1) classification model optimization, and (2) generator and RTNet optimization.

Classification Model Optimization: In this stage, the generator G and the RTNet R_k are fixed, while the classification model is updated by minimizing the loss

L_cls = L_real + L_syn.    (1)

L_real is computed on real training samples (x, y) from D_k and is minimized to update all model parameters θ_k, while L_syn is computed on synthetically generated samples in the embedding space Z and is minimized to update only the prediction head H_k. We use the cross-entropy loss (L_CE) for both loss functions.

To generate domain-specific synthetic embeddings, the shared generator G and the local RTNet R_k are used to generate residuals that are added to the embeddings of real examples produced by the local feature extractor F_k. Hereby, we first generate client-agnostic embeddings e_g = G(z, ŷ) by feeding a batch of random vectors z and class labels ŷ into the generator. Subsequently, e_g is transformed by the local RTNet into client-specific residuals R_k(e_g), which are added to the embeddings of real datapoints. We distinguish two types of synthetic embeddings that we generate to train the local prediction head: domain-specific synthetic embeddings e_ds and class-prototypical domain-specific synthetic embeddings ē_ds. The domain-specific embeddings e_ds are generated by adding synthetic residuals to the embeddings of real examples from the current batch sampled from D_k. On the other hand, synthetic residuals are added to class-prototypes p_c, i.e., class-wise average embeddings of real examples, to produce ē_ds:

e_ds = F_k(x) + λ · R_k(G(z, y)),    (2)

ē_ds^c = p_c + λ · R_k(G(z, c)).    (3)

We note that, for the generation of e_ds, the original labels y of the sampled data batch are used for the residual generation, since the residuals are added to the embeddings of the examples corresponding to these labels. For ē_ds^c, we feed the label c that corresponds to the class of the average embedding p_c. The class-prototypical client-specific embeddings ē_ds are generated to stabilize the training and increase the variance of the generated embeddings.

While the residuals produced by the generator and the RTNet are random in early training iterations due to the random initialization of these models, they become more informative as training progresses. To reflect this in our algorithm, we employ the weighting coefficient λ in Eqs. 2 and 3, which controls the impact of the residuals, and increase it following an exponential schedule throughout training.
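
A possible implementation of the residual-based augmentation of Eqs. 2 and 3, including the ramped-up coefficient λ, is sketched below in PyTorch. The generator and RTNet call signatures (generator(z, labels), rtnet(e_g)), the latent dimension, and the exact shape of the ramp-up schedule are assumptions; only the overall construction (real embedding or class prototype plus a λ-scaled residual) follows the description above.

```python
import math
import torch

def rampup_weight(step: int, total_steps: int, max_value: float = 1.0) -> float:
    """Exponential ramp-up for the residual weight lambda (schedule shape follows
    Laine & Aila, 2016; the exact schedule used in the paper is an assumption)."""
    t = min(step / max(total_steps, 1), 1.0)
    return max_value * math.exp(-5.0 * (1.0 - t) ** 2)

def synthesize_embeddings(generator, rtnet, real_emb, labels, prototypes, lam,
                          latent_dim: int = 64):
    """Build the two kinds of client-specific synthetic embeddings (sketch of Eqs. 2-3):
    residuals from the shared generator + local RTNet are added to the real embeddings
    and to the class prototypes. Shapes and latent_dim are illustrative."""
    z = torch.randn(real_emb.size(0), latent_dim, device=real_emb.device)
    e_g = generator(z, labels)               # client-agnostic embeddings
    residual = rtnet(e_g)                    # client-specific residuals
    e_ds = real_emb + lam * residual         # domain-specific synthetic embeddings (Eq. 2)

    proto = prototypes[labels]               # class prototype matching each label
    z2 = torch.randn_like(z)
    e_ds_proto = proto + lam * rtnet(generator(z2, labels))   # class-prototypical (Eq. 3)
    return e_ds, e_ds_proto
```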

To compute the class-wise average embeddings p_c, we use an exponential moving average (EMA) scheme at each local iteration. In particular,

p_c ← α · p_c + (1 − α) · ( Σ_{i=1}^{B} 1[y_i = c] · F_k(x_i) ) / ( Σ_{i=1}^{B} 1[y_i = c] + ε ),    (4)

where 1[·] denotes the indicator function, B is the batch size of the real samples, the p_c on the right-hand side is the class-wise average embedding computed in the previous local iteration, and ε is a small number added for numerical stability. By using the average embeddings of previous iterations, we enable the examples of previously sampled batches to contribute to the computation of the current average embeddings. The ratio α follows an exponential ramp-up schedule, as proposed in (Laine and Aila, 2016).
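
The EMA prototype update of Eq. 4 can be implemented with a one-hot indicator matrix, as in the sketch below. Keeping prototypes of classes absent from the current batch unchanged, and the exact placement of the EMA ratio α, are implementation assumptions.

```python
import torch

def update_prototypes(prototypes: torch.Tensor,   # [C, d] class-wise average embeddings
                      real_emb: torch.Tensor,     # [B, d] embeddings of the current batch
                      labels: torch.Tensor,       # [B] integer class labels
                      alpha: float,               # EMA ratio (ramped up during training)
                      eps: float = 1e-8) -> torch.Tensor:
    """EMA update of the class prototypes (sketch of Eq. 4): mix the previous prototype
    with the mean embedding of the batch samples of that class."""
    num_classes, _ = prototypes.shape
    one_hot = torch.zeros(labels.size(0), num_classes, device=real_emb.device)
    one_hot.scatter_(1, labels.view(-1, 1), 1.0)               # indicator 1[y_i = c]
    counts = one_hot.sum(dim=0)                                # samples per class in the batch
    batch_mean = (one_hot.t() @ real_emb) / (counts.unsqueeze(1) + eps)   # [C, d]
    present = (counts > 0).float().unsqueeze(1)                # only update classes in the batch
    updated = alpha * prototypes + (1.0 - alpha) * batch_mean
    return present * updated + (1.0 - present) * prototypes
```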

To allow the different client-specific models to learn feature extractors tailored to their local data distributions, while still benefiting from the collaborative learning, we use local Batch Normalization (BN) layers (Ioffe and Szegedy, 2015), as introduced in FedBN (Li et al., 2021b).
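
Keeping BN layers local amounts to not overwriting their parameters and running statistics when the global model arrives. A sketch of this FedBN-style loading step is given below; the key matching by module type is an implementation choice, not the paper's exact code.

```python
import torch.nn as nn

def load_global_except_bn(local_model: nn.Module, global_state: dict) -> None:
    """FedBN-style personalization sketch: copy the received global parameters into
    the local model, but keep the client's own BatchNorm parameters and statistics."""
    bn_keys = set()
    for name, module in local_model.named_modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            for key, _ in list(module.named_parameters(recurse=False)) + \
                          list(module.named_buffers(recurse=False)):
                bn_keys.add(f"{name}.{key}")
    # drop BN entries from the incoming state dict, then load the rest
    filtered = {k: v for k, v in global_state.items() if k not in bn_keys}
    local_model.load_state_dict(filtered, strict=False)
```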

Generator and RTNet Optimization: In the second stage, the classification model is fixed while the generator and the RTNet are optimized. The class-conditional generator G takes a batch of random vectors z and class labels ŷ and produces client-agnostic feature embeddings e_g, which are then fed into the RTNet to be adapted to the feature distribution of the corresponding client k. The resulting residuals are added to the embeddings of real examples to produce the domain-specific synthetic embeddings e_ds and ē_ds. The generator is optimized by minimizing the loss

L_gen = L_CE(H_k(e_g), ŷ) − β · MMD(e_g, e_real).    (5)

The minimization of the cross-entropy loss incentivizes the shared generator to produce features that are recognized by the prediction heads of all the clients. By sharing and optimizing the generator across all clients, we ensure that the synthetic embeddings e_g capture client-agnostic semantic information. Additionally, we maximize the statistical distance (Wootters, 1981) between e_g and the real feature embeddings e_real. By doing so, we force e_g not to follow any client-specific distribution, and thus enhance the variance of the augmented feature space. Here, we adopt the Maximum Mean Discrepancy (MMD) (Gretton et al., 2012) as the distance metric. Subsequently, the client-agnostic embeddings e_g are fed into the RTNet, parameterized by w_R^k, to produce the domain-specific embeddings e_ds and ē_ds. The RTNet is optimized by minimizing the loss

L_RT = − Ent(H_k(e_ds)) + γ · ( MMD(e_ds, e_real) + Σ_c MMD(ē_ds^c, p_c) ).    (6)

Here, we maximize the entropy Ent(·) of the prediction head output on e_ds, to incentivize the generation of synthetic embeddings that are hard to classify for the prediction head H_k. To avoid generating outliers, we add a regularization term that minimizes the Maximum Mean Discrepancy between real and synthetic embeddings. In particular, we penalize high MMD distances between e_ds and e_real, as well as between ē_ds^c and p_c, for each class c. β and γ denote the weighting coefficients in Eq. 5 and Eq. 6, respectively.
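
The second-stage objectives can be sketched as follows. The RBF-kernel MMD estimator and its bandwidth are assumptions, the per-class MMD term of Eq. 6 is simplified here to a single MMD between the prototype-based synthetic embeddings and their corresponding prototypes, and the prediction head is assumed to be frozen by the caller.

```python
import torch
import torch.nn.functional as F

def rbf_mmd(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Biased MMD estimate with an RBF kernel (kernel choice and bandwidth are
    assumptions, not the paper's exact setup)."""
    def kernel(a, b):
        d2 = torch.cdist(a, b, p=2).pow(2)
        return torch.exp(-d2 / (2.0 * sigma ** 2))
    return kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean()

def generator_and_rtnet_losses(head, e_g, e_real, e_ds, e_ds_proto, prototypes, labels,
                               beta: float, gamma: float):
    """Sketch of the second-stage losses as described in the text: the generator produces
    embeddings the frozen head classifies correctly while staying away from the client's
    real embedding distribution (Eq. 5); the RTNet produces hard-to-classify yet
    in-distribution client-specific embeddings (Eq. 6)."""
    # Generator loss: cross-entropy minus weighted MMD to the real embeddings.
    loss_gen = F.cross_entropy(head(e_g), labels) - beta * rbf_mmd(e_g, e_real)

    # RTNet loss: maximize prediction entropy on e_ds, regularize with MMD terms.
    probs = F.softmax(head(e_ds), dim=1)
    entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=1).mean()
    loss_rt = -entropy + gamma * (rbf_mmd(e_ds, e_real)
                                  + rbf_mmd(e_ds_proto, prototypes[labels]))
    return loss_gen, loss_rt
```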

Figure 1: Overview of the FRAug local update at client k: a shared generator G is learned to aggregate knowledge from multiple clients and generate client-agnostic feature embeddings e_g, which are then fed into the local Representation Transformation Network (RTNet) R_k to produce the domain-specific feature embeddings e_ds and ē_ds. Finally, the real feature embeddings e_real, extracted from the local dataset D_k by the feature extractor F_k, are augmented with e_ds and ē_ds for the prediction head optimization, which leads to a classification model with better generalization ability.
1  ServerUpdate:
2      Randomly initialize the global model parameters θ and the generator parameters w_G
3      for round t = 1 to T do
4          for client k = 1 to K in parallel do
5              θ_k, w_G^k ← ClientUpdate(θ, w_G)
6          end for
7          θ ← average of {θ_k} over all clients
8          w_G ← average of {w_G^k} over all clients
9      end for
10
11 ClientUpdate(θ, w_G):
12     if this is the first communication round then
13         Randomly initialize the RTNet parameters w_R^k
14     end if
15     θ_k ← θ,  w_G^k ← w_G
16     for local training iteration i = 1 to I do
17         Sample a batch (x, y) from D_k
18         Sample random vectors z and class labels ŷ
19         Fix G and R_k, and optimize θ_k using Equation 1
20         Fix θ_k, and optimize w_G^k and w_R^k using Equations 5 and 6, respectively
21     end for
22     return θ_k, w_G^k
Algorithm 1: Training procedure of Federated Representation Augmentation (FRAug)

4 Experimental Results

4.1 Dataset Description

We perform an extensive empirical analysis (code will be made public upon paper acceptance) using three real-world image classification datasets with domain shift: (1) OfficeHome (Venkateswara et al., 2017), which contains 65 classes from four domains: Art (A), Clipart (C), Product (P) and Real-World (R); (2) PACS (Li et al., 2017), which includes images of 7 classes from four domains: Art-Painting (A), Cartoon (C), Photo (P) and Sketch (S); and (3) Digits, which comprises images of the 10 digits from the following four datasets: MNIST (MT) (LeCun et al., 1998), MNIST-M (MM) (Ganin and Lempitsky, 2015), SVHN (SV) (Netzer et al., 2011) and USPS (UP) (Hull, 1994). Each client holds data from one of the domains, i.e., there is a feature shift across the clients. To simulate the data scarcity described in Section 3.2, we assume that only a small percentage of the original data (an even smaller percentage for the Digits dataset) is available for each client, resulting in ca. 100 to 1000 data samples per client, following the experimental setup of previous work (McMahan et al., 2017; Li et al., 2021b).

4.2 Implementation Details

For the OfficeHome and PACS datasets, we use a ResNet18 (He et al., 2016) pretrained on ImageNet (Deng et al., 2009) as the initialization of the classification model. For Digits, we use a 6-layer Convolutional Neural Network (CNN) as the backbone. We adopt a 2-layer MLP as the architecture of both the generator and the RTNet for all datasets. Further details about the model architectures and training hyperparameters are provided in the appendix. All experiments are repeated with 3 different random seeds.

4.3 Results and Discussion

We compare our approach with several baseline methods, including Single, i.e., training an individual model on each client separately, All, i.e., training a single model at the central server using data aggregated from all clients, FedAvg (McMahan et al., 2017), pFedAvg, i.e., FedAvg with local model personalization, and FedProx (Li et al., 2020b). We note that All is an oracle baseline, as it requires centralizing the data from the different clients and hence violates the data-privacy requirements. Furthermore, we compare our method with the current state-of-the-art FL methods for non-IID features, i.e., FedBN (Li et al., 2021b) and PartialFed (Sun et al., 2021). We use the published code of FedBN and reimplement PartialFed, since its original implementation was not made public. We conduct the same hyperparameter search for all methods and report the best results. The detailed hyperparameter search spaces of the different methods are provided in the Appendix.

We report the accuracies achieved by the different methods on all three datasets in Table 2. We observe that FRAug outperforms all the baselines on all benchmark datasets. On OfficeHome, FRAug outperforms FedAvg and FedBN by 1.91 and 1.58 percentage points on average, respectively. On Digits, FRAug achieves a substantial average improvement of at least 2.30 percentage points over all the alternative methods. Likewise, FRAug yields the highest average accuracy on PACS. We find that the performance improvement over the best baseline is highest on the most challenging domains, i.e., those on which all methods yield lower results than on the other domains. These include MNIST-M and SVHN from Digits, as well as Clipart from OfficeHome, where FRAug achieves improvements of more than 3 percentage points over the best federated baseline. Interestingly, our approach even outperforms the centralized baseline All, demonstrating the effectiveness of the generated embeddings in aggregating the knowledge from different clients to enable a client-specific augmentation.


Benchmark        Single      All         FedAvg      FedProx     pFedAvg     PartialFed  FedBN       FRAug

OfficeHome  A    35.80±0.1   56.65±0.7   57.47±0.6   55.68±0.4   52.50±0.9   48.83±0.2   57.59±0.8   57.61±0.6
            C    45.54±0.8   58.81±1.6   56.74±0.9   56.88±0.5   52.09±1.1   49.96±0.2   56.52±0.3   60.03±0.5
            P    67.04±0.8   71.39±0.3   73.32±0.8   73.84±0.3   71.78±0.8   72.22±0.8   73.55±1.0   74.03±0.8
            R    61.16±0.7   72.63±1.3   71.25±0.3   72.15±0.9   66.28±0.4   65.82±0.6   72.40±0.9   74.58±0.4
            avg  52.42±0.4   64.87±0.9   64.69±0.6   64.63±0.6   60.67±0.7   59.20±0.5   65.02±0.7   66.60±0.3

Digits      MT   96.68±0.2   97.04±0.1   96.85±0.1   96.90±0.1   96.40±0.2   97.13±0.1   97.03±0.1   97.81±0.1
            MM   77.77±0.5   77.04±0.1   73.51±0.2   72.60±0.4   77.56±0.4   74.21±0.5   77.02±0.2   81.65±0.5
            SV   75.55±0.3   77.96±0.5   74.49±0.2   73.01±0.5   77.50±0.1   78.10±0.5   77.59±0.1   81.24±0.3
            UP   79.93±0.8   97.13±0.1   97.62±0.1   97.31±0.3   96.67±0.1   94.78±0.5   96.80±0.2   97.67±0.3
            avg  82.54±0.1   87.29±0.2   85.62±0.2   84.96±0.3   87.03±0.2   86.05±0.3   87.11±0.2   89.59±0.4

PACS        A    82.37±0.6   83.17±0.2   82.72±0.4   80.17±0.4   88.05±0.8   84.85±0.2   86.60±0.5   87.34±0.5
            C    86.08±0.9   86.92±0.8   84.04±1.3   82.04±0.8   86.20±0.7   87.92±0.5   87.76±1.0   88.47±0.9
            P    92.01±1.1   95.95±0.8   96.05±0.5   96.74±1.0   97.89±0.5   98.24±0.4   97.95±0.4   98.64±0.6
            S    87.52±0.8   88.70±0.7   89.50±0.7   88.50±1.0   88.89±0.9   90.10±0.8   90.75±0.3   90.95±0.4
            avg  87.00±0.5   88.68±0.6   88.08±0.9   86.86±0.9   90.26±0.6   90.28±0.7   90.76±0.3   91.34±0.1

Table 2: Accuracy results on different benchmark datasets with feature shift.

4.4 Ablation Study

To illustrate the importance of the different FRAug components, we conducted an ablation study on the OfficeHome dataset. The results are shown in Table 3. We observe that using the client-agnostic synthetic embeddings e_g instead of the personalized versions leads to a performance deterioration. This highlights the importance of the transformation into personalized, client-specific embeddings. Moreover, the results reveal that both types of synthetic embeddings, i.e., e_ds and ē_ds, yield a performance boost when used separately. Employing them together further improves the results, which demonstrates their complementarity. Finally, we find that increasing the impact of the generated residuals via the smoothing coefficient λ throughout training, i.e., as the generator and the RTNets continue improving, is beneficial for the overall performance. We conducted the same ablation study on the PACS and Digits datasets and observed consistent result patterns. We provide these results in the Appendix.

Components (Generator / RTNet / EMA / Smoothing)        OfficeHome: A      C      P      R      avg
56.24±0.5 59.65±0.3 72.90±0.3 73.09±0.5 65.47±0.8
56.75±0.8 58.95±0.8 72.88±0.2 73.93±0.3 65.63±0.8
57.34±0.9 59.10±0.4 73.65±0.7 74.25±0.9 66.09±0.2
56.93±0.9 58.50±0.8 73.35±0.9 74.39±0.4 65.79±0.9
56.65±0.2 59.50±0.7 73.50±0.5 74.31±0.9 65.99±0.3
58.16±0.8 57.89±0.3 73.12±0.3 74.15±0.6 65.83±0.3
57.61±0.9 60.03±0.8 74.03±0.3 74.69±0.2 66.60±0.4
Table 3: Ablation study for different components of FRAug.

4.5 Analysis of Local Dataset Size

We investigate the effectiveness of our personalized representation augmentation technique for different sizes of the client-specific local datasets. Hereby, we vary the fraction of each client's original local dataset that is available for training. Figure 2 depicts the results of this experiment. We compare FRAug with two baseline methods on OfficeHome and conduct the experiment with 3 different seeds. We note that FRAug consistently outperforms the baselines across all dataset sizes, except in one setting on Clipart. Compared to FedAvg, the improvement achieved by FRAug is stable across the different dataset sizes, highlighting the suitability of representation augmentation for non-IID feature scenarios with both scarce and large amounts of data. Compared to local training (Single) without collaboration, we observe that the performance improvement yielded by the federated learning methods increases as the dataset size decreases.

Figure 2: Performance analysis of FRAug with different local dataset sizes per client.

5 Conclusion and Outlook

In this work, we present a novel approach to tackle the underexplored feature shift problem in the federated learning (FL) setting, where the datasets of the different clients have non-IID features. Our method, Federated Representation Augmentation (FRAug), performs client-personalized augmentation in the embedding space to improve the generalization ability of the client-specific models without sharing data in the input space. For that, we optimize a shared generative model to merge the clients’ knowledge, learned from their different feature distributions, to synthesize client-agnostic embeddings, which are then transformed into client-specific embeddings by local Representation Transformation Networks (RTNets). Our empirical evaluation on three benchmark datasets involving feature distribution shift demonstrated the effectiveness of our method which outperformed the prior approaches, achieving new state-of-the-art results in this underexplored FL setting.

While our empirical evaluation focused on image datasets, it would be interesting to apply FRAug to other data types, such as time-series data (Liu et al., 2020) or textual data (Hard et al., 2018), since it is data-type-agnostic, i.e., it does not involve any image-specific operations. Another interesting avenue for future works would be to assess the effectiveness of FRAug against adversarial attack techniques such as gradient inversion (Yin et al., 2021; Jeon et al., 2021). Finally, it is worth studying the interaction between FRAug and more sophisticated optimization algorithms (Deng and Mahdavi, 2021; Li et al., 2020a) and communication methods (Chen et al., 2020, 2021).

References

  • M. Andreux, J. O. d. Terrail, C. Beguier, and E. W. Tramel (2020) Siloed federated learning for multi-centric histopathology datasets. In Domain Adaptation and Representation Transfer, and Distributed and Collaborative Learning, pp. 129–139. Cited by: §2.
  • M. G. Arivazhagan, V. Aggarwal, A. K. Singh, and S. Choudhary (2019) Federated learning with personalization layers. arXiv preprint arXiv:1912.00818. Cited by: §2.
  • M. Chen, H. V. Poor, W. Saad, and S. Cui (2020) Convergence time optimization for federated learning over wireless networks. IEEE Transactions on Wireless Communications 20 (4), pp. 2457–2471. Cited by: §5.
  • M. Chen, N. Shlezinger, H. V. Poor, Y. C. Eldar, and S. Cui (2021) Communication-efficient federated learning. Proceedings of the National Academy of Sciences 118 (17). Cited by: §5.
  • J. Deng, W. Dong, R. Socher, L. Li, K. Li, and L. Fei-Fei (2009) Imagenet: a large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pp. 248–255. Cited by: §4.2.
  • Y. Deng and M. Mahdavi (2021) Local stochastic gradient descent ascent: convergence analysis and communication efficiency. In International Conference on Artificial Intelligence and Statistics, pp. 1387–1395. Cited by: §5.
  • Q. Dou, D. Coelho de Castro, K. Kamnitsas, and B. Glocker (2019) Domain generalization via model-agnostic learning of semantic features. Advances in Neural Information Processing Systems 32. Cited by: §1.
  • A. Fallah, A. Mokhtari, and A. Ozdaglar (2020) Personalized federated learning with theoretical guarantees: a model-agnostic meta-learning approach. Advances in Neural Information Processing Systems 33, pp. 3557–3568. Cited by: §2.
  • Y. Ganin and V. Lempitsky (2015) Unsupervised domain adaptation by backpropagation. In International conference on machine learning, pp. 1180–1189. Cited by: §4.1.
  • X. Gong, A. Sharma, S. Karanam, Z. Wu, T. Chen, D. Doermann, and A. Innanje (2021) Ensemble attention distillation for privacy-preserving federated learning. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 15076–15086. Cited by: §2.
  • X. Gong, A. Sharma, S. Karanam, Z. Wu, T. Chen, D. Doermann, and A. Innanje (2022) Preserving privacy in federated learning with ensemble cross-domain knowledge distillation. Cited by: §2.
  • A. Gretton, K. M. Borgwardt, M. J. Rasch, B. Schölkopf, and A. Smola (2012) A kernel two-sample test. The Journal of Machine Learning Research 13 (1), pp. 723–773. Cited by: §3.3.2.
  • W. Hao, M. El-Khamy, J. Lee, J. Zhang, K. J. Liang, C. Chen, and L. C. Duke (2021) Towards fair federated learning with zero-shot data augmentation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 3310–3319. Cited by: §2.
  • A. Hard, K. Rao, R. Mathews, S. Ramaswamy, F. Beaufays, S. Augenstein, H. Eichner, C. Kiddon, and D. Ramage (2018) Federated learning for mobile keyboard prediction. arXiv preprint arXiv:1811.03604. Cited by: §5.
  • C. He, M. Annavaram, and S. Avestimehr (2020) Group knowledge transfer: federated learning of large cnns at the edge. Advances in Neural Information Processing Systems 33, pp. 14068–14080. Cited by: §2.
  • K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778. Cited by: §4.2.
  • J. J. Hull (1994) A database for handwritten text recognition research. IEEE Transactions on pattern analysis and machine intelligence 16 (5), pp. 550–554. Cited by: §4.1.
  • S. Ioffe and C. Szegedy (2015) Batch normalization: accelerating deep network training by reducing internal covariate shift. In International conference on machine learning, pp. 448–456. Cited by: §2, §3.3.2.
  • J. Jeon, K. Lee, S. Oh, J. Ok, et al. (2021) Gradient inversion with generative image prior. Advances in Neural Information Processing Systems 34, pp. 29898–29908. Cited by: §5.
  • P. Kairouz, H. B. McMahan, B. Avent, A. Bellet, M. Bennis, A. N. Bhagoji, K. Bonawitz, Z. Charles, G. Cormode, R. Cummings, et al. (2021) Advances and open problems in federated learning. Foundations and Trends® in Machine Learning 14 (1–2), pp. 1–210. Cited by: §2, §3.1.
  • S. Laine and T. Aila (2016) Temporal ensembling for semi-supervised learning. arXiv preprint arXiv:1610.02242. Cited by: §3.3.2.
  • Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner (1998) Gradient-based learning applied to document recognition. Proceedings of the IEEE 86 (11), pp. 2278–2324. Cited by: §4.1.
  • D. Li, Y. Yang, Y. Song, and T. M. Hospedales (2017) Deeper, broader and artier domain generalization. In Proceedings of the IEEE international conference on computer vision, pp. 5542–5550. Cited by: §3.2, §4.1.
  • D. Li and J. Wang (2019) Fedmd: heterogenous federated learning via model distillation. arXiv preprint arXiv:1910.03581. Cited by: §2.
  • T. Li, S. Hu, A. Beirami, and V. Smith (2021a) Ditto: fair and robust federated learning through personalization. In International Conference on Machine Learning, pp. 6357–6368. Cited by: §2.
  • T. Li, A. K. Sahu, A. Talwalkar, and V. Smith (2020a) Federated learning: challenges, methods, and future directions. IEEE Signal Processing Magazine 37 (3), pp. 50–60. Cited by: §2, §5.
  • T. Li, A. K. Sahu, M. Zaheer, M. Sanjabi, A. Talwalkar, and V. Smith (2020b) Federated optimization in heterogeneous networks. Proceedings of Machine Learning and Systems 2, pp. 429–450. Cited by: §1, §2, §4.3.
  • X. Li, M. Jiang, X. Zhang, M. Kamp, and Q. Dou (2021b) Fedbn: federated learning on non-iid features via local batch normalization. arXiv preprint arXiv:2102.07623. Cited by: §1, §2, §3.1, §3.2, §3.3.2, §4.1, §4.3.
  • T. Lin, L. Kong, S. U. Stich, and M. Jaggi (2020) Ensemble distillation for robust model fusion in federated learning. Advances in Neural Information Processing Systems 33, pp. 2351–2363. Cited by: §2.
  • Q. Liu, C. Chen, J. Qin, Q. Dou, and P. Heng (2021) Feddg: federated domain generalization on medical image segmentation via episodic learning in continuous frequency space. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 1013–1023. Cited by: §2.
  • Y. Liu, S. Garg, J. Nie, Y. Zhang, Z. Xiong, J. Kang, and M. S. Hossain (2020) Deep anomaly detection for time-series data in industrial iot: a communication-efficient on-device federated learning approach. IEEE Internet of Things Journal 8 (8), pp. 6348–6358. Cited by: §5.
  • R. G. Lopes, S. Fenu, and T. Starner (2017) Data-free knowledge distillation for deep neural networks. arXiv preprint arXiv:1710.07535. Cited by: §2.
  • B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas (2017) Communication-efficient learning of deep networks from decentralized data. In Artificial intelligence and statistics, pp. 1273–1282. Cited by: §1, §1, §2, §3.2, §3.3, §4.1, §4.3.
  • Y. Netzer, T. Wang, A. Coates, A. Bissacco, B. Wu, and A. Y. Ng (2011) Reading digits in natural images with unsupervised feature learning. Cited by: §4.1.
  • S. Shankar, V. Piratla, S. Chakrabarti, S. Chaudhuri, P. Jyothi, and S. Sarawagi (2018) Generalizing across domains via cross-gradient training. arXiv preprint arXiv:1804.10745. Cited by: §2.
  • B. Sun, H. Huo, Y. Yang, and B. Bai (2021) PartialFed: cross-domain personalized federated learning via partial initialization. Advances in Neural Information Processing Systems 34. Cited by: §2, §4.3.
  • C. T Dinh, N. Tran, and J. Nguyen (2020) Personalized federated learning with moreau envelopes. Advances in Neural Information Processing Systems 33, pp. 21394–21405. Cited by: §1, §2.
  • H. Venkateswara, J. Eusebio, S. Chakraborty, and S. Panchanathan (2017) Deep hashing network for unsupervised domain adaptation. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 5018–5027. Cited by: §3.2, §4.1.
  • G. Wilson and D. J. Cook (2020) A survey of unsupervised deep domain adaptation. ACM Transactions on Intelligent Systems and Technology (TIST) 11 (5), pp. 1–46. Cited by: §2, §3.1.
  • W. K. Wootters (1981) Statistical distance and hilbert space. Physical Review D 23 (2), pp. 357. Cited by: §3.3.2.
  • H. Yin, A. Mallya, A. Vahdat, J. M. Alvarez, J. Kautz, and P. Molchanov (2021) See through gradients: image batch recovery via gradinversion. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 16337–16346. Cited by: §5.
  • L. Zhang and X. Yuan (2021) FedZKT: zero-shot knowledge transfer towards heterogeneous on-device models in federated learning. arXiv preprint arXiv:2109.03775. Cited by: §2.
  • L. Zhang, L. Shen, L. Ding, D. Tao, and L. Duan (2022) Fine-tuning global model via data-free knowledge distillation for non-iid federated learning. arXiv preprint arXiv:2203.09249. Cited by: §2.
  • Y. Zhao, M. Li, L. Lai, N. Suda, D. Civin, and V. Chandra (2018) Federated learning with non-iid data. arXiv preprint arXiv:1806.00582. Cited by: §1, §2.
  • K. Zhou, Z. Liu, Y. Qiao, T. Xiang, and C. Change Loy (2021a) Domain generalization: a survey. arXiv e-prints, pp. arXiv–2103. Cited by: §2, §3.1.
  • K. Zhou, Y. Yang, T. Hospedales, and T. Xiang (2020) Learning to generate novel domains for domain generalization. In European conference on computer vision, pp. 561–578. Cited by: §2.
  • K. Zhou, Y. Yang, Y. Qiao, and T. Xiang (2021b) Domain generalization with mixstyle. arXiv preprint arXiv:2104.02008. Cited by: §2.
  • Z. Zhu, J. Hong, and J. Zhou (2021) Data-free knowledge distillation for heterogeneous federated learning. In International Conference on Machine Learning, pp. 12878–12889. Cited by: §2.