Think Locally, Act Globally: Federated Learning with Local and Global Representations

01/06/2020 · Paul Pu Liang et al., Carnegie Mellon University

Federated learning is an emerging research paradigm to train models on private data distributed over multiple devices. A key challenge involves keeping private all the data on each device and training a global model only by communicating parameters and updates. Overcoming this problem relies on the global model being sufficiently compact so that the parameters can be efficiently sent over communication channels such as wireless internet. Given the recent trend towards building deeper and larger neural networks, deploying such models in federated settings on real-world tasks is becoming increasingly difficult. To this end, we propose to augment federated learning with local representation learning on each device to learn useful and compact features from raw data. As a result, the global model can be smaller since it only operates on higher-level local representations. We show that our proposed method achieves superior or competitive results when compared to traditional federated approaches on a suite of publicly available real-world datasets spanning image recognition (MNIST, CIFAR) and multimodal learning (VQA). Our choice of local representation learning also reduces the number of parameters and updates that need to be communicated to and from the global model, thereby reducing the bottleneck in terms of communication cost. Finally, we show that our local models provide flexibility in dealing with online heterogeneous data and can be easily modified to learn fair representations that obfuscate protected attributes such as race, age, and gender, a feature crucial to preserving the privacy of on-device data.


1 Introduction

Federated learning is an emerging research paradigm to train machine learning models on private data distributed in a potentially non-i.i.d. setting over multiple devices (DBLP:journals/corr/McMahanMRA16). A key challenge in federated learning involves keeping private all the data on each device and training a global model only via communication of parameters and parameter updates to each device (Yang:2019:FML:3306498.3298981). This relies on the global model being sufficiently compact so that the parameters and updates can be sent efficiently over existing communication channels such as wireless networks (Nilsson:2018:PEF:3286490.3286559). However, the recent trend towards building deeper and larger machine learning models (DBLP:journals/corr/abs-1810-04805; radford2019language; DBLP:journals/corr/HuangLW16a; DBLP:journals/corr/HeZRS15) poses a challenge for deploying federated learning on real-world tasks and calls for new solutions beyond the traditional federated averaging framework (DBLP:journals/corr/McMahanMRA16; DBLP:journals/corr/KonecnyMRR16; DBLP:journals/corr/abs-1902-01046; DBLP:journals/corr/abs-1812-06127). In this paper, we propose to augment traditional federated learning with local representation learning on each device. Each device is augmented with a local model which learns useful and compact representations of raw data. The single global model on the central server is then trained using federated averaging over the local representations from these devices. We call the resulting method Local Global Federated Averaging (LG-FedAvg) and show that local representation learning is beneficial for the following reasons:

1) Efficiency: having local models extract useful, lower-dimensional semantic representations means that the global model requires fewer parameters. Our choice of local representation learning reduces the number of parameters and updates that need to be communicated to and from the global model, thereby reducing the communication bottleneck (§3.1). Our proposed method also maintains superior or competitive results on a suite of publicly available real-world datasets spanning image recognition (MNIST, CIFAR) and multimodal learning (VQA).

2) Heterogeneity: real-world data is often heterogeneous, i.e., it comes from different sources. A single mobile phone is likely to contain data across multiple modalities including images, text, videos, and audio files. In addition, a new device could contain sources of data that were never observed during training, such as text in another language, images of a different resolution, or audio in a different voice. Local representations allow us to process the data from new devices in different ways depending on their source modalities (DBLP:journals/corr/BaltrusaitisAM17), instead of using a single global model that might not generalize to previously unseen modalities and distributions (DBLP:journals/corr/abs-1902-00146; Ben-David:2006:ARD:2976456.2976474). In §3.2, we show that by training local models in an online setting (Shalev-Shwartz:2012:OLO:2185819.2185820), our model can better deal with heterogeneous data.

3) Fairness: real-world data often contains sensitive attributes. While federated learning imposes a strict constraint that the data on each local device must remain private DBLP:journals/corr/McMahanMRA16, recent work has shown that it is possible to recover biases and protected attributes from data representations without having access to the data itself (caliskan2017semantics; GargE3635; DBLP:journals/corr/abs-1812-08769; bolukbasi2016man; DBLP:journals/corr/abs-1808-06640). In light of this issue, we show that our local representations can be modified to learn fair representations that obfuscate protected attributes such as race, age, and gender, a feature crucial to preserving the privacy of on-device data (§3.3). We hope that our work will inspire future research on efficient and privacy-preserving federated learning.

Figure 1: (a) Our proposed Local Global Federated Averaging algorithm (LG-FedAvg) allows for efficient global parameter updates (a smaller number of global parameters $\theta_g$), flexibility in design across local and global models, the ability to handle heterogeneous data, and fair representation learning. (b) through (d) show various approaches for training local models, including unsupervised, supervised, and self-supervised learning (e.g., jigsaw solving (DBLP:journals/corr/NorooziF16)); (e) shows adversarial training against protected attributes $\mathbf{Z}$. Blue denotes the global server and purple the local devices. $\mathbf{X}_m$ represents data on device $m$, $\mathcal{H}_m$ are learned local representations produced by local models $\ell_m$ and auxiliary models $a_m$, and $g$ is the global model. Agg stands for an aggregation function over local updates to the global model (e.g., FedAvg).

2 Local Global Federated Averaging (LG-FedAvg)

The core idea of our method is to augment federated learning with local representation learning on each device; the global model then operates on the resulting higher-level representations (rather than raw data) from all devices. An overview of LG-FedAvg is shown in Figure 1(a). We begin by defining notation before describing how local (§2.1) and global (§2.2) representation learning is performed. In §2.3 we explain how the local models can be adapted to learn fair representations, and in §2.4 we show how to perform test-time inference over local models.

Notation: We use uppercase letters $X, Y$ to denote random variables and lowercase letters $x, y$ to denote their values. Uppercase boldface letters $\mathbf{X}, \mathbf{Y}$ denote datasets consisting of multiple vector data points, which we represent by lowercase boldface letters $\mathbf{x}, \mathbf{y}$. In the standard federated learning setting, we assume that we have data $\mathbf{X}_m$ and their corresponding labels $\mathbf{Y}_m$ across $M$ nodes. $n_m$ denotes the number of data points on device $m$ and $n = \sum_m n_m$ is the total number of data points. Intuitively, each source of data captures a different view of the global data distribution $p(\mathbf{X}, \mathbf{Y})$. We consider settings where the individual data points in $\mathbf{X}_m$ are sampled i.i.d. with respect to $p$ as well as settings in which sampling is non-i.i.d. (e.g., biased sampling with respect to the marginal $p(\mathbf{Y})$ implies that data is distributed unevenly with respect to their labels: one device may have, in expectation, many more cat images, and another many more dog images). During training, we use parenthesized superscripts (e.g., $\theta^{(t)}$) to represent the training iteration $t$.

Server executes:
   initialize global model weights $\theta_g$
   initialize local model weights $\theta_m^{\ell}$ and auxiliary model weights $\theta_m^{a}$ for every device $m$
   for each round $t = 1, 2, \dots$ do
      $S_t \leftarrow$ random set of $\max(C \cdot M, 1)$ clients
      for each client $m \in S_t$ in parallel do
         $\theta_{g,m}^{(t+1)} \leftarrow$ ClientUpdate$(m, \theta_g^{(t)})$
      end for
      $\theta_g^{(t+1)} \leftarrow \sum_{m \in S_t} \frac{n_m}{n} \, \theta_{g,m}^{(t+1)}$                // aggregate updates
   end for

ClientUpdate$(m, \theta_g)$:                // run on client $m$
   $\mathcal{B} \leftarrow$ split local data into batches of size $B$
   for each local epoch $i$ from $1$ to $E$ do
      for batch $b \in \mathcal{B}$ do
         $\mathbf{h} \leftarrow \ell_m(b; \theta_m^{\ell})$, $\hat{\mathbf{y}}^{a} \leftarrow a_m(\mathbf{h}; \theta_m^{a})$, $\hat{\mathbf{y}}^{g} \leftarrow g(\mathbf{h}; \theta_g)$                // inference steps
         $\theta_m^{\ell} \leftarrow \theta_m^{\ell} - \eta \nabla_{\theta_m^{\ell}} \mathcal{L}_m^{\ell}$                // update local model
         $\theta_m^{a} \leftarrow \theta_m^{a} - \eta \nabla_{\theta_m^{a}} \mathcal{L}_m^{\ell}$                // update auxiliary local model
         $\theta_m^{\ell} \leftarrow \theta_m^{\ell} - \eta \nabla_{\theta_m^{\ell}} \mathcal{L}_m^{g}$                // update local model
         $\theta_g \leftarrow \theta_g - \eta \nabla_{\theta_g} \mathcal{L}_m^{g}$                // update (local copy of) global model
      end for
   end for
   return global parameters $\theta_g$ to the server

Algorithm 1 LG-FedAvg: Local Global Federated Averaging. The clients are indexed by $m$; $C$ is the fraction of clients sampled per round, $B$ is the local minibatch size, $E$ is the number of local epochs, and $\eta$ is the learning rate.
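To make the update structure concrete, the following is a minimal PyTorch sketch of one LG-FedAvg communication round. The module and function names (LocalEncoder, Head, client_update, fed_avg, run_round) and the MLP sizes are our own illustrative assumptions, not the released implementation; the sketch simply mirrors the joint local/global updates and the FedAvg aggregation step of Algorithm 1.

```python
# Illustrative sketch of Algorithm 1 (not the authors' released code).
import copy
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

class LocalEncoder(nn.Module):
    """Local model l_m: maps raw on-device data to a compact representation h."""
    def __init__(self, in_dim=784, rep_dim=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, 256), nn.ReLU(),
                                 nn.Linear(256, rep_dim), nn.ReLU())
    def forward(self, x):
        return self.net(x)

class Head(nn.Module):
    """Small classifier used both as the auxiliary head a_m and as the global model g."""
    def __init__(self, rep_dim=64, n_classes=10):
        super().__init__()
        self.fc = nn.Linear(rep_dim, n_classes)
    def forward(self, h):
        return self.fc(h)

def client_update(encoder, aux, global_copy, loader, lr=0.05, epochs=1):
    """Jointly update the local encoder, auxiliary head, and the local copy of the global model."""
    opt = torch.optim.SGD(list(encoder.parameters()) + list(aux.parameters())
                          + list(global_copy.parameters()), lr=lr)
    for _ in range(epochs):
        for x, y in loader:
            h = encoder(x)                                     # inference step
            loss = F.cross_entropy(aux(h), y) + F.cross_entropy(global_copy(h), y)
            opt.zero_grad()
            loss.backward()                                    # gradients flow to local and global params
            opt.step()
    return global_copy.state_dict()

def fed_avg(state_dicts, weights):
    """Weighted average of global-model parameters (the Agg step)."""
    avg = copy.deepcopy(state_dicts[0])
    for key in avg:
        avg[key] = sum(w * sd[key] for sd, w in zip(state_dicts, weights))
    return avg

def run_round(global_model, clients, frac=0.1, lr=0.05):
    """One communication round: clients is a list of (encoder, aux, loader, n_m) tuples."""
    chosen = random.sample(clients, max(1, int(frac * len(clients))))
    updates, sizes = [], []
    for encoder, aux, loader, n_m in chosen:
        local_copy = copy.deepcopy(global_model)               # server sends global params to the device
        updates.append(client_update(encoder, aux, local_copy, loader, lr))
        sizes.append(n_m)
    weights = [n / sum(sizes) for n in sizes]
    global_model.load_state_dict(fed_avg(updates, weights))    # aggregate updates on the server
```

Note that in this sketch only the (much smaller) global head crosses the network; the local encoders never leave their devices.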

2.1 Local Representation Learning

For each source of data $\mathbf{X}_m$, we learn a high-level, compact representation $\mathcal{H}_m = \ell_m(\mathbf{X}_m; \theta_m^{\ell})$ of the data. This general framework gives the user flexibility in learning $\mathcal{H}_m$, but in general the local representation should have the following properties: 1) be low-dimensional compared to the high-dimensional raw data, 2) capture important features of $\mathbf{X}_m$ that are useful to the global model, and 3) not overfit to on-device data which may not align perfectly with the global data distribution.

To be more concrete, define $\mathbf{Y}_m^{a}$ as some important features that should be captured by a good representation $\mathcal{H}_m$. Some choices of $\mathbf{Y}_m^{a}$ are 1) the data itself, $\mathbf{Y}_m^{a} = \mathbf{X}_m$ (unsupervised autoencoder learning), 2) the labels, $\mathbf{Y}_m^{a} = \mathbf{Y}_m$ (supervised learning), or 3) some manually defined labels (self-supervised learning). Figure 1(b) through (d) summarizes the local representation learning methods from $\mathbf{X}_m$ to $\mathbf{Y}_m^{a}$, each resulting in a trained local model on the device. Given these features, each device consists of two components: the local model $\ell_m$ with parameters $\theta_m^{\ell}$, as well as the local auxiliary network $a_m$ with parameters $\theta_m^{a}$. These two networks allow us to infer features $\mathcal{H}_m = \ell_m(\mathbf{X}_m; \theta_m^{\ell})$ and auxiliary labels $\hat{\mathbf{Y}}_m^{a} = a_m(\mathcal{H}_m; \theta_m^{a})$ from local device data. Given a suitably chosen local loss function $\mathcal{L}_m^{\ell}$ over $(\hat{\mathbf{Y}}_m^{a}, \mathbf{Y}_m^{a})$, the local model can now be learned using (stochastic) gradient descent. The local training objective optimizes parameters $\theta_m^{\ell}$ and $\theta_m^{a}$ with respect to the local loss (for simplicity we choose supervised learning, hence $\mathbf{Y}_m^{a} = \mathbf{Y}_m$):

$$\theta_m^{\ell *}, \theta_m^{a *} = \arg\min_{\theta_m^{\ell}, \theta_m^{a}} \sum_{(\mathbf{x}, \mathbf{y}) \in (\mathbf{X}_m, \mathbf{Y}_m)} \mathcal{L}_m^{\ell}\big(a_m(\ell_m(\mathbf{x}; \theta_m^{\ell}); \theta_m^{a}), \mathbf{y}\big). \qquad (1)$$

In practice, we do not have to compute the summation over $(\mathbf{X}_m, \mathbf{Y}_m)$ exactly since we perform end-to-end training in a multitask fashion; $\mathcal{H}_m$ is simply a shared intermediate representation that will be trained to work well for the local task as well as the global model objective, as we discuss next.
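As a concrete illustration of the three choices of $\mathbf{Y}_m^{a}$ above, the sketch below shows one auxiliary loss per choice. The function names, dimensions, and the rotation-prediction pretext task (used here as a simple stand-in for the jigsaw task mentioned in Figure 1) are illustrative assumptions, not the authors' exact setup.

```python
# Sketch of local auxiliary objectives; not the authors' exact implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F

class LocalEncoder(nn.Module):
    """l_m: raw on-device data (flattened 28x28 images here) -> compact representation."""
    def __init__(self, in_dim=784, rep_dim=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, 256), nn.ReLU(), nn.Linear(256, rep_dim))
    def forward(self, x):
        return self.net(x)

def local_loss(mode, encoder, aux, x, y=None):
    """Auxiliary local loss for one minibatch under three choices of the target Y^a."""
    if mode == "autoencoder":       # Y^a = X: aux decodes the representation back to the input
        return F.mse_loss(aux(encoder(x)), x)
    if mode == "supervised":        # Y^a = Y: aux is a classifier over the true labels
        return F.cross_entropy(aux(encoder(x)), y)
    if mode == "self_supervised":   # Y^a = manually defined labels, e.g. which rotation was applied
        k = torch.randint(0, 4, (x.size(0),))
        x_rot = torch.stack([torch.rot90(xi.view(28, 28), int(ki)).reshape(-1)
                             for xi, ki in zip(x, k)])
        return F.cross_entropy(aux(encoder(x_rot)), k)
    raise ValueError(f"unknown mode: {mode}")
```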

2.2 Global Aggregation

The non-i.i.d. nature of federated data implies that simply learning the best possible local model on each device is still insufficient for learning a good prediction model over the true joint distribution $p(\mathbf{X}, \mathbf{Y})$. Therefore, it is important to learn a global model over the data from all devices. To this end, we define a global model $g$ with parameters $\theta_g$ which is updated using data from all devices. The key difference is that the global model now operates on the learned local representations $\mathcal{H}_m$, which are already representative of the features required for prediction. Therefore, $g$ can be a much smaller model, which we show empirically in our experiments (§3.1). Contrast this with traditional federated learning, where the global model takes as input raw device data $\mathbf{X}_m$ and makes a prediction $\hat{\mathbf{Y}}_m$. A model operating on raw data will usually require multiple layers of representation learning to achieve good performance, as shown by the recent trend of using large models for language understanding (e.g., BERT (DBLP:journals/corr/abs-1810-04805), GPT-2 (radford2019language)) and visual recognition (e.g., DenseNet (DBLP:journals/corr/HuangLW16a), ResNet (DBLP:journals/corr/HeZRS15)). This leads to very significant communication costs when transmitting global parameters to each local device and back.

In our approach, at each iteration of global model training, the server sends a copy of the global model parameters to each device, which we now label as $\theta_{g,m}$ to represent the asynchronous updates made to each local copy. Each device runs its local model and the global model to obtain predicted labels $\hat{\mathbf{Y}}_m = g(\ell_m(\mathbf{X}_m; \theta_m^{\ell}); \theta_{g,m})$. Given a suitable loss function $\mathcal{L}^{g}$ on the label space, we can compute the overall loss of the global model on device $m$:

$$\mathcal{L}_m^{g} = \sum_{(\mathbf{x}, \mathbf{y}) \in (\mathbf{X}_m, \mathbf{Y}_m)} \mathcal{L}^{g}\big(g(\ell_m(\mathbf{x}; \theta_m^{\ell}); \theta_{g,m}), \mathbf{y}\big). \qquad (2)$$

Again, we do not have to compute Equation (2) exactly: its gradient is a function of both the local and global model parameters, so both can be updated in an end-to-end manner. We argue that this synchronizes the local and global models: while local models can flexibly fit the data distribution on their device, the global model acts as a regularizer to synchronize the representations from all devices. Each local model cannot overfit to its local data because otherwise the global model would incur a high loss.

After the joint local and global updates, each device returns its updated global parameters $\theta_{g,m}$ to the server, which aggregates these updates using FedAvg, a weighted average over the fraction of data points on each device: $\theta_g \leftarrow \sum_{m} \frac{n_m}{n} \theta_{g,m}$. We also found that weighting the updates by the norm of the global gradient sped up convergence (Alain2015VarianceRI).

The overall training procedure for LG-FedAvg is shown in Algorithm 1. Communication only happens between the global server and local devices when training the global model, which as we will show in our experiments, can be much smaller given good local representations. We have shown the simple case where the local models and the global model are updated jointly during each client update, but it is also easy to modify our algorithm for settings where pretraining local models or the global model helps in convergence, as well as to define additional losses for each local model. We show an example of such a modification in the following section where we aim to learn fair and privacy-preserving local representations via an auxiliary adversarial loss in each local model.

2.3 Fair Representation Learning

Here we detail one example of local representation learning with the goal of removing information that might be indicative of protected attributes. In this setting, suppose the data on each device is now a triple $(\mathbf{X}_m, \mathbf{Y}_m, \mathbf{Z}_m)$ drawn non-i.i.d. from a joint distribution $p(\mathbf{X}, \mathbf{Y}, \mathbf{Z})$ (instead of $(\mathbf{X}_m, \mathbf{Y}_m)$ as we had previously considered), where $\mathbf{Z}_m$ are some protected attributes which the model should not pick up on when making a prediction from $\mathbf{X}$ to $\mathbf{Y}$. For example, although there exist correlations between race and income (10.2307/1054978) which could help in income prediction (NIPS2018_7613), it would be undesirable for our models to rely on these correlations since they would exacerbate racial biases, especially when these models are deployed in the real world.

To learn fair local representations, we follow a procedure similar to NIPS2017_6699, which uses adversarial training to remove protected attributes (Figure 1(e)). More formally, we aim to learn a local model $\ell_m$ such that the distribution of $\ell_m(\mathbf{X}_m; \theta_m^{\ell})$ conditional on $\mathbf{Z}_m$ is invariant with respect to parameters $\theta_m^{\ell}$:

$$p\big(\ell_m(\mathbf{x}; \theta_m^{\ell}) = \mathbf{h} \mid \mathbf{z}\big) = p\big(\ell_m(\mathbf{x}; \theta_m^{\ell}) = \mathbf{h} \mid \mathbf{z}'\big) \qquad (3)$$

for all $\mathbf{z}, \mathbf{z}'$ and all outputs $\mathbf{h}$ of $\ell_m$, thereby implying that $\ell_m(\mathbf{X}_m; \theta_m^{\ell})$ and $\mathbf{Z}_m$ are independent and $\ell_m$ is a pivotal quantity with respect to $\mathbf{Z}_m$. NIPS2017_6699 showed that we can use adversarial networks to constrain the model $\ell_m$ to satisfy Equation (3). $\ell_m$ is pitted against an adversarial model $d_m$ with parameters $\theta_m^{d}$ and loss $\mathcal{L}_m^{d}$. Intuitively, the adversarial network is trained to predict the distribution of $\mathbf{Z}_m$ as well as possible given the local representation from $\ell_m$. If $\ell_m(\mathbf{X}_m; \theta_m^{\ell})$ varies with $\mathbf{Z}_m$, then the corresponding correlation can be captured by the adversary $d_m$. On the other hand, if $\ell_m(\mathbf{X}_m; \theta_m^{\ell})$ is indeed invariant with respect to $\mathbf{Z}_m$, then the adversary should perform no better than random chance. Therefore, we train $\ell_m$ both to minimize its own local loss $\mathcal{L}_m^{\ell}$ and to maximize the adversarial loss $\mathcal{L}_m^{d}$. In practice, $\theta_m^{\ell}$, $\theta_m^{a}$, and $\theta_m^{d}$ are simultaneously updated by defining the following value function:

$$V_m(\theta_m^{\ell}, \theta_m^{a}, \theta_m^{d}) = \mathcal{L}_m^{\ell}(\theta_m^{\ell}, \theta_m^{a}) - \mathcal{L}_m^{d}(\theta_m^{\ell}, \theta_m^{d}) \qquad (4)$$

and solving for the minimax solution

$$\hat{\theta}_m^{\ell}, \hat{\theta}_m^{a}, \hat{\theta}_m^{d} = \arg\min_{\theta_m^{\ell}, \theta_m^{a}} \max_{\theta_m^{d}} V_m(\theta_m^{\ell}, \theta_m^{a}, \theta_m^{d}). \qquad (5)$$

$\mathcal{L}_m^{\ell}$ and $\mathcal{L}_m^{d}$ are computed using the expected value of the log likelihood through the inference networks $\ell_m$, $a_m$, and $d_m$. We can optimize Equation (5) by treating it as a coordinate descent problem (Wright:2015:CDA:2783158.2783189) and alternately solving for $\theta_m^{\ell}, \theta_m^{a}$ and $\theta_m^{d}$ using gradient-based methods (details in the appendix). Proposition 1 shows that this adversarial training procedure learns an optimal local model that is pivotal (invariant) with respect to $\mathbf{Z}_m$ under the local device data distribution $p_m$.

Proposition 1 (Optimality of $\ell_m$, adapted from Proposition 1 in NIPS2017_6699). Suppose we compute the losses $\mathcal{L}_m^{\ell}$ and $\mathcal{L}_m^{d}$ using the expected log likelihood through the inference networks $\ell_m$, $a_m$, and $d_m$:

$$\mathcal{L}_m^{\ell}(\theta_m^{\ell}, \theta_m^{a}) = \mathbb{E}_{\mathbf{x} \sim p_m(\mathbf{x})} \, \mathbb{E}_{\mathbf{y} \sim p_m(\mathbf{y} \mid \mathbf{x})} \big[ -\log p_{\theta_m^{\ell}, \theta_m^{a}}(\mathbf{y} \mid \mathbf{x}) \big] \qquad (6)$$
$$\mathcal{L}_m^{d}(\theta_m^{\ell}, \theta_m^{d}) = \mathbb{E}_{\mathbf{h} \sim p_{\theta_m^{\ell}}(\mathbf{h})} \, \mathbb{E}_{\mathbf{z} \sim p_m(\mathbf{z} \mid \mathbf{h})} \big[ -\log p_{\theta_m^{d}}(\mathbf{z} \mid \mathbf{h}) \big] \qquad (7)$$

Then, if there is a minimax solution $(\hat{\theta}_m^{\ell}, \hat{\theta}_m^{a}, \hat{\theta}_m^{d})$ for Equation (5) such that $V_m(\hat{\theta}_m^{\ell}, \hat{\theta}_m^{a}, \hat{\theta}_m^{d}) = H(\mathbf{Y} \mid \mathbf{X}) - H(\mathbf{Z})$, then $\ell_m$ (together with $a_m$) is both an optimal classifier and a pivotal quantity.

The proof is adapted from NIPS2017_6699 to account for local data distributions $p_m$ and intermediate representations $\mathcal{H}_m$. Details are in the appendix, where we also explain adversarial training for the global model.
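A minimal sketch of one alternating optimization step for Equation (5) is shown below, assuming a discrete protected attribute $z$ and treating the encoder, auxiliary classifier, and adversary as generic PyTorch modules. The function name, the single-step schedule, and the tradeoff weight `lam` (mirroring the $\lambda$ hyperparameters in the appendix) are illustrative assumptions rather than the authors' exact training loop.

```python
# Sketch of one alternating minimax step for fair local representations.
import torch
import torch.nn.functional as F

def fair_local_step(encoder, classifier, adversary, opt_model, opt_adv, x, y, z, lam=1.0):
    """opt_model optimizes encoder + classifier parameters; opt_adv optimizes adversary parameters."""
    # (1) adversary step: improve prediction of the protected attribute z from h
    #     (maximize V over theta_d, holding the encoder fixed)
    h = encoder(x).detach()
    adv_loss = F.cross_entropy(adversary(h), z)
    opt_adv.zero_grad()
    adv_loss.backward()
    opt_adv.step()

    # (2) model step: fit the local task while making z hard to recover
    #     (minimize V over theta_l and theta_a)
    opt_model.zero_grad()
    h = encoder(x)
    task_loss = F.cross_entropy(classifier(h), y)
    fool_loss = F.cross_entropy(adversary(h), z)
    (task_loss - lam * fool_loss).backward()
    opt_model.step()
    return task_loss.item(), fool_loss.item()
```

In LG-FedAvg, the same adversary also enters the joint local/global update (see Equations (24) and (25) in the appendix).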

2.4 Inference at Test Time

Given a new test input $\mathbf{x}$, FedAvg simply passes $\mathbf{x}$ to the trained global model for inference. However, LG-FedAvg requires inference through both local and global models, so which trained local model fits best? We consider two settings. (1) Local Test, where we assume we know which device the test data belongs to (e.g., training a personalized text completer from phone data); using that particular local model works best since the train and test data distributions match. (2) New Test, where we relax this assumption and allow an entirely new device at test time with new data sources and distributions. Here we view each local model as trained on a different view of the global data distribution: we pass $\mathbf{x}$ through all the trained local models and ensemble the outputs. Prior research on bagging (breiman1996bagging) and boosting (schapire1990strength) has shown that ensembling base classifiers, each trained on a different view of the data, works well in both theory (zhou2012ensemble; kaariainen2005comparison; freund2004generalization) and practice (zhou2012ensemble; machova2006bagging; kim2002support). Alternatively, we can train on the new device in an online setting (Shalev-Shwartz:2012:OLO:2185819.2185820; Anderson:2009:TPO:1550848): first train a new local model on the new device and then (optionally) fine-tune the global model. We now describe these settings and their experimental performance in detail.
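The two inference modes can be summarized in a few lines; in this sketch, `local_models` is the list of trained per-device encoders and `global_model` is the trained global head (the names are ours).

```python
# Sketch of test-time inference for LG-FedAvg (illustrative names).
import torch

@torch.no_grad()
def predict_local_test(x, local_models, global_model, device_id):
    """Local Test: the originating device is known, so use its local model."""
    return global_model(local_models[device_id](x)).argmax(dim=-1)

@torch.no_grad()
def predict_new_test(x, local_models, global_model):
    """New Test: device unknown, so average logits over all local models (bagging-style ensemble)."""
    logits = torch.stack([global_model(m(x)) for m in local_models]).mean(dim=0)
    return logits.argmax(dim=-1)
```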

3 Experiments

We provide experimental results that justify our motivations for incorporating local representation learning into federated learning. We begin by showing that using local representations can efficiently reduce the number of parameters required in the global model while retaining strong performance (§3.1). Secondly, we consider settings where data from heterogeneous sources is seen in an online manner where local models help to prevent catastrophic forgetting in the global model (§3.2). Finally, we demonstrate how to learn fair representations that obfuscate private attributes (§3.3). Anonymized code is included in the supplementary and implementation details can be found in the appendix.

3.1 Model Performance and Communication Efficiency

 

Table 1: Comparison of federated learning methods (FedAvg, Local only, LG-FedAvg) on MNIST (top three rows) and CIFAR-10 (bottom three rows) with non-i.i.d. splits. We report accuracy under the local test and new test settings, the number of FedAvg and LG rounds, and the total number of parameters communicated during training. Best results in bold. LG-FedAvg outperforms FedAvg under local test and achieves similar performance under new test while using only a fraction of the total communicated parameters. Mean and standard deviation are computed over 10 runs.

 

Table 2: Comparison of FedAvg and LG-FedAvg on Visual Question Answering under a non-i.i.d. device split. We report local test accuracy, the number of FedAvg and LG rounds required to reach our goal accuracy of 40%, and the number of parameters communicated. LG-FedAvg achieves strong performance using fewer communicated parameters.

Image Recognition on MNIST and CIFAR-10: We begin by studying properties of local and global models on the MNIST (lecun-mnisthandwrittendigit-2010) and CIFAR-10 (krizhevsky2009learning) image recognition datasets. In particular, we focus on a highly non-i.i.d. setting and follow the experimental design of DBLP:journals/corr/McMahanMRA16. We partition the training data by sorting the dataset by labels and dividing it into 200 shards of size 300 (MNIST) or 250 (CIFAR-10). We then randomly assign 2 shards to each of 100 devices so that each device has examples from at most two classes (highly unbalanced). Similarly, we divide the test set into 200 shards of size 50 and assign 2 shards to each device, so that each device has matching train and test distributions.
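The non-i.i.d. partition described above can be reproduced with a few lines of NumPy; the function name and the default arguments (shown for MNIST) are our own.

```python
# Sketch of the non-i.i.d. shard partition (MNIST numbers shown).
import numpy as np

def noniid_shards(labels, num_devices=100, shards_per_device=2, shard_size=300, seed=0):
    """Sort by label, cut into shards, and give each device `shards_per_device` shards,
    so every device sees examples from at most two classes."""
    rng = np.random.default_rng(seed)
    order = np.argsort(labels, kind="stable")               # indices sorted by label
    num_shards = num_devices * shards_per_device            # e.g. 200 shards of size 300
    shards = [order[i * shard_size:(i + 1) * shard_size] for i in range(num_shards)]
    shard_ids = rng.permutation(num_shards)
    return {d: np.concatenate([shards[s] for s in
                               shard_ids[d * shards_per_device:(d + 1) * shards_per_device]])
            for d in range(num_devices)}

# Example usage: device_idx = noniid_shards(train_labels)   # dict: device id -> array of example indices
```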

Figure 2: Test accuracy of FedAvg and LG-FedAvg on the VQA dataset across communication rounds (the dotted green line marks the goal accuracy of 40% used in Table 2). LG-FedAvg reaches a higher maximum accuracy than FedAvg while using only a fraction of the parameters.

We consider two settings during testing: 1) Local Test, where we know which device the data belongs to (i.e., new predictions on an existing device) and choose that particular trained local model. For this setting, we split each device's data into train, validation, and test sets, similar to DBLP:journals/corr/SmithCST17. 2) New Test, in which we do not know which device the data belongs to (i.e., new predictions on new devices) (DBLP:journals/corr/McMahanMRA16), so we use an ensemble approach that averages all trained local model logits before choosing the most likely class (breiman1996bagging). For ensembling, all local models have to be sent to the global server only once after training; we include this overhead when computing the total number of parameters communicated to and from the global server. For this setting, we evaluate on the CIFAR-10 test set of 10,000 examples. We choose LeNet-5 (Lecun98gradient-basedlearning) as our base model, which allows us to draw comparisons between LG-FedAvg and FedAvg. We set the number of devices, local epochs, and minibatch size as detailed in the appendix, and use the two convolutional layers as our global model, which make up only a small fraction of the original model's parameters. We train LG-FedAvg with global updates until we reach a set goal accuracy (97.5% for MNIST, 57% for CIFAR-10) before training for additional rounds to jointly update local and global models.

The results in Table 1 show that LG-FedAvg gives strong performance with low communication cost on both MNIST and CIFAR-10. For CIFAR-10 local test, LG-FedAvg significantly outperforms FedAvg since local models allow us to better fit the local device data distribution. For new test, LG-FedAvg achieves performance similar to FedAvg while communicating only a fraction of the total parameters during updates to the global model. Therefore, LG-FedAvg can learn good local representations that support strong global performance under both test settings.

Multimodal Learning on Visual Question Answering (VQA): We perform experiments on VQA (VQA), a large-scale multimodal benchmark with roughly 0.25M images, 0.76M questions, and 10M answers. We split the dataset in a non-i.i.d. manner and evaluate accuracy under the local test setting. We use LSTM (hochreiter1997long) and ResNet-18 (DBLP:journals/corr/HeZRS15) unimodal encoders as our local models and a global model which performs early fusion (srivastava2012multimodal) of text and image features for answer prediction (details in the appendix). In Table 2, we observe that LG-FedAvg reaches the goal accuracy of 40% while requiring lower communication costs. In Figure 2, we plot the convergence of test accuracy across communication rounds: LG-FedAvg eventually outperforms FedAvg while requiring only a fraction of the parameters of the large global model in FedAvg, and continues to improve.
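For reference, a hedged sketch of the multimodal setup is given below: two unimodal local encoders and a small global fusion model. The layer sizes, vocabulary size, and answer-vocabulary size are illustrative assumptions; only the overall structure (ResNet-18 image encoder, two-layer LSTM question encoder, element-wise multiplicative fusion followed by fully connected layers) follows the description above and in the appendix.

```python
# Illustrative sketch of local unimodal encoders and a global fusion model for VQA.
import torch
import torch.nn as nn
import torchvision

class ImageEncoder(nn.Module):
    """Local image model: ResNet-18 backbone projected to a fixed-size representation."""
    def __init__(self, out_dim=512):
        super().__init__()
        backbone = torchvision.models.resnet18(weights=None)
        self.features = nn.Sequential(*list(backbone.children())[:-1])  # drop the final fc layer
        self.proj = nn.Linear(512, out_dim)
    def forward(self, img):                       # img: (B, 3, 224, 224)
        return self.proj(self.features(img).flatten(1))

class QuestionEncoder(nn.Module):
    """Local question model: embedding + two-layer LSTM, projected to the same size."""
    def __init__(self, vocab_size=10000, emb_dim=300, hidden=512, out_dim=512):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden, num_layers=2, batch_first=True)
        self.proj = nn.Linear(hidden, out_dim)
    def forward(self, tokens):                    # tokens: (B, T) integer ids
        _, (h_n, _) = self.lstm(self.emb(tokens))
        return self.proj(h_n[-1])

class GlobalFusion(nn.Module):
    """Global model: fuse the two local representations and predict the answer."""
    def __init__(self, dim=512, num_answers=1000):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(),
                                 nn.Linear(dim, num_answers))
    def forward(self, h_img, h_q):
        return self.mlp(h_img * h_q)              # element-wise multiplicative (early) fusion
```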

 

Table 3: What happens when FedAvg trained on 100 devices of normal MNIST sees a device with rotated MNIST? Catastrophic forgetting, unless one fine-tunes again on the training devices and incurs high communication cost. LG-FedAvg relieves catastrophic forgetting by using local models to perform well on both the online rotated MNIST and regular MNIST, with and without fine-tuning. We compare FedAvg, FedProx, and LG-FedAvg on normal and rotated MNIST under both i.i.d. and non-i.i.d. device data. Mean and standard deviation are computed over 10 runs.

3.2 Heterogeneous Data in an Online Setting

For this experiment, we focus on an online setting to test whether LG-FedAvg can handle heterogeneous data from a new source introduced only during testing. We split the original MNIST dataset across 100 devices in both an i.i.d. and a non-i.i.d. setting. We then introduce a new device with 3,000 training and 500 test examples drawn independently from the MNIST dataset but rotated by 90 degrees. This simulates a drastic change in data distribution which may happen in federated learning settings.

We consider two methods: 1) FedAvg: train on the original devices using FedAvg, and when a new device arrives, update the global model using FedAvg. 2) LG-FedAvg: train on the original devices using FedAvg, and when a new device arrives, use LG-FedAvg to learn local representations before fine-tuning together with the global model (sketched below). We hypothesize that good local models can help to "unrotate" the images from the new device to better match the data distribution seen by the global model. In all our experiments, we first train on the original devices until we reach an average goal accuracy on the devices' test sets. We then train for additional rounds after the new device is introduced, using the new device in addition to a fraction $\rho$ of the original training devices for fine-tuning: $\rho = 0$ implies no fine-tuning and $\rho > 0$ implies some fine-tuning. Note that $\rho = 1$ implies completely retraining on all data each round, which is impractical.
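The following sketch outlines the LG-FedAvg adaptation protocol for a new device. It reuses the `client_update` and `fed_avg` helpers from the sketch after Algorithm 1; `rho`, the deterministic choice of old clients, and the tuple layout of each client are illustrative assumptions.

```python
# Sketch of online adaptation to a new (e.g. rotated-MNIST) device; builds on the
# client_update / fed_avg helpers sketched after Algorithm 1.
import copy

def adapt_new_device(global_model, new_client, old_clients, rho=0.0, lr=0.05, rounds=10):
    """new_client and each old client are (encoder, aux, loader, n_m) tuples; rho is the
    fraction of original training devices used for fine-tuning (rho = 0: no fine-tuning)."""
    for _ in range(rounds):
        participants = [new_client]                        # the new device always participates
        if rho > 0:
            k = max(1, int(rho * len(old_clients)))
            participants += old_clients[:k]                # optionally fine-tune on old devices too
        updates, sizes = [], []
        for encoder, aux, loader, n_m in participants:
            local_copy = copy.deepcopy(global_model)
            updates.append(client_update(encoder, aux, local_copy, loader, lr))
            sizes.append(n_m)
        weights = [n / sum(sizes) for n in sizes]
        global_model.load_state_dict(fed_avg(updates, weights))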

We report results in Table 3 and draw the following conclusions. 1) FedAvg suffers from catastrophic forgetting (pmlr-v80-serra18a; DBLP:journals/corr/KirkpatrickPRVD16; Robins95catastrophicforgetting) without fine-tuning ($\rho = 0$): the global model can perform well on the new device's rotated MNIST but completely forgets how to classify regular MNIST. Only after fine-tuning ($\rho > 0$) does the performance on both regular and rotated MNIST improve, but this requires more communication over the training devices. 2) LG-FedAvg with local models relieves catastrophic forgetting: adding local models improves online performance on rotated MNIST while allowing the global model to retain performance on regular MNIST. We believe LG-FedAvg achieves these results by learning a strong local representation which therefore requires fewer updates to the trained global model.

3.3 Learning Fair Representations

The purpose of this experiment is to examine whether local models can be trained adversarially to protect private attributes before local representations pass through the global model. We use the UCI Adult dataset (Kohavi96scalingup), where the goal is to predict whether an individual makes more than $50K per year based on personal attributes such as age, education, and marital status. However, we want our models to be invariant to the sensitive attributes of race and gender instead of picking up on correlations between {race, gender} and income that could potentially exacerbate biases. The dataset contains separate training and testing instances, of which we take an initial subset for easier splitting in a federated setting. We fix the number of devices and split the dataset in two ways: for the i.i.d. setting we uniformly sample a device for each train and test point, and for the non-i.i.d. setting we assign contiguous shards of data points to obtain imbalanced devices.

 

Table 4: Results on enforcing independence with respect to the protected attributes race and gender in income prediction, comparing FedAvg, LG-FedAvg−Adv, and LG-FedAvg+Adv under i.i.d. and non-i.i.d. device data. We report classifier (class) accuracy, classifier ROC AUC, and adversary (adv) ROC AUC. LG-FedAvg+Adv uses local models with adversarial training to remove information about protected attributes, at the expense of a small drop in classifier accuracy. Mean and standard deviation are computed over 10 runs.

We use the method in §2.3 (adapted from NIPS2017_6699), which uses adversarial learning to remove protected attributes. Specifically, we aim to learn local representations from which a fully trained adversarial network cannot predict the protected attributes. We report three methods: 1) FedAvg, with only a global model and a global adversary, both updated using FedAvg. The global model is not trained against the adversarial loss since that is simply not possible: once local device data passes through the global model, privacy is potentially violated. 2) LG-FedAvg−Adv, a local-global model without the adversarial penalty, and 3) LG-FedAvg+Adv, which implements the algorithm in §2.3 by jointly training local, global, and adversary models towards a minimax equilibrium.

We report results according to the following metrics: 1) classifier binary accuracy, 2) classifier ROC AUC, and 3) adversary ROC AUC. The classifier metrics should be as close to 1 as possible, while the adversary AUC should be as close to 0.5 (random chance) as possible. From the results in Table 4, we are able to enforce independence using LG-FedAvg+Adv (near-random adversary AUC) with only a small drop in accuracy for the global model. To ensure that the poor adversary AUC was indeed due to fair representations rather than a poorly trained adversary, we also fit a post-fit classifier from the local representations to the protected attributes and obtain similarly close-to-random results (sketched below).
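The post-fit check can be implemented by freezing the trained encoder and fitting a fresh probe from representations to the protected attribute; the scikit-learn probe below is our own minimal version, not the exact classifier used in the paper.

```python
# Sketch of the post-fit leakage probe (assumes a binary protected attribute z).
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

@torch.no_grad()
def postfit_adversary_auc(encoder, x_train, z_train, x_test, z_test):
    """Fit a fresh probe from frozen local representations h to z; AUC near 0.5 means little leakage."""
    h_train = encoder(x_train).cpu().numpy()
    h_test = encoder(x_test).cpu().numpy()
    probe = LogisticRegression(max_iter=1000).fit(h_train, z_train)
    return roc_auc_score(z_test, probe.predict_proba(h_test)[:, 1])
```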

4 Related Work

Federated Learning aims to train models in massively distributed networks (DBLP:journals/corr/McMahanMRA16; DBLP:journals/corr/KonecnyMRR16) at large scale (DBLP:journals/corr/abs-1902-01046), over multiple sources of heterogeneous data (DBLP:journals/corr/abs-1812-06127), and over multiple learning objectives (DBLP:journals/corr/SmithCST17). Recent methods aim to improve the efficiency of federated learning (DBLP:journals/corr/abs-1812-07210), perform learning in a one-shot setting (DBLP:journals/corr/abs-1902-11175), propose realistic benchmarks (DBLP:journals/corr/abs-1812-01097), and reduce the mismatch between local and global data distributions (DBLP:journals/corr/abs-1902-00146). Distributed Learning is a related field with similarities and key differences: while both study the theory and practice of partitioning data and aggregating model updates (NIPS2012_4687; DBLP:journals/corr/abs-1802-09941; Keuper:2016:DTD:3018874.3018877; pmlr-v70-suresh17a), federated learning is additionally concerned with data that is private and distributed in a non-i.i.d. fashion. Recent work has improved the communication efficiency of distributed learning by sparsifying the data and model (pmlr-v70-wang17f), developing efficient gradient-based methods (Wang2018CooperativeSA; NIPS2017_7218), and compressing the updates (DBLP:journals/corr/abs-1802-06058; DBLP:journals/corr/abs-1901-03040; 46622; DBLP:journals/corr/MahajanKSB13; Leng:2015:HDD:3045118.3045293). Representation Learning involves learning informative features from data that are useful for generative (srivastava2012multimodal; suzuki2016joint; sohn2014improved) and discriminative (liang-etal-2019-learning; DBLP:journals/corr/abs-1906-02125; DBLP:journals/corr/abs-1806-06176; DBLP:journals/corr/abs-1812-07809) tasks. A recent focus has been on learning fair representations (Zemel:2013:LFR:3042817.3042973), including using adversarial training (Goodfellow:2014:GAN:2969033.2969125) to learn representations that are not informative of predefined private attributes (DBLP:journals/corr/abs-1904-13341; DBLP:journals/corr/abs-1901-10443; NIPS2017_6699; DBLP:journals/corr/abs-1710-04394; Resheff2018PrivacyAdversarialUR) such as demographics (DBLP:journals/corr/abs-1808-06640) and gender (Wang2018AdversarialRO). A related line of research is differential privacy, which constrains statistical databases to limit the privacy impact on the individuals whose information they contain (Dwork:2006:DP:2097282.2097284; Dwork:2014:AFD:2693052.2693053). Differential privacy has also been integrated with deep learning (45428), distributed learning (NIPS2018_7984; NIPS2018_8069; DBLP:journals/corr/abs-1812-01484; DBLP:journals/corr/abs-1811-11124; NIPS2018_7871), and federated learning (geyer2019differentially; DBLP:journals/corr/abs-1812-03224; DBLP:journals/corr/abs-1712-07557).

5 Conclusion

To conclude, this paper proposed LG-FedAvg as a general method that augments FedAvg with local representation learning on each device to learn useful and compact features from raw data. On a suite of publicly available real-world datasets spanning image recognition (MNIST, CIFAR) and multimodal learning (VQA) in a federated setting, LG-FedAvg achieves strong performance while reducing communication costs, deals with heterogeneous data in an online setting, and can be easily modified to learn fair representations that obfuscate protected attributes such as race, age, and gender, a feature crucial to preserving the privacy of on-device data. We hope that our work will inspire future research on efficient and privacy-preserving federated learning.

References

6 Fair Representation Learning

6.1 Theoretical Results

In this section we derive the theoretical results on learning fair local representations on each device. The material and setting are adapted from [NIPS2017_6699]. First recall our dual objective across the local model $\ell_m$ (parameters $\theta_m^{\ell}$), auxiliary model $a_m$ (parameters $\theta_m^{a}$), and adversarial model $d_m$ (parameters $\theta_m^{d}$):

$$V_m(\theta_m^{\ell}, \theta_m^{a}, \theta_m^{d}) = \mathcal{L}_m^{\ell}(\theta_m^{\ell}, \theta_m^{a}) - \mathcal{L}_m^{d}(\theta_m^{\ell}, \theta_m^{d}). \qquad (8)$$

We would like to find the minimax solution

$$\hat{\theta}_m^{\ell}, \hat{\theta}_m^{a}, \hat{\theta}_m^{d} = \arg\min_{\theta_m^{\ell}, \theta_m^{a}} \max_{\theta_m^{d}} V_m(\theta_m^{\ell}, \theta_m^{a}, \theta_m^{d}). \qquad (9)$$

To do so, we can iteratively solve for the parameters in an alternating fashion. In other words, initialize $\theta_m^{\ell}, \theta_m^{a}, \theta_m^{d}$ and repeat until convergence:

$$\theta_m^{d} \leftarrow \arg\max_{\theta_m^{d}} V_m(\theta_m^{\ell}, \theta_m^{a}, \theta_m^{d}) \qquad (10)$$
$$\theta_m^{\ell} \leftarrow \arg\min_{\theta_m^{\ell}} V_m(\theta_m^{\ell}, \theta_m^{a}, \theta_m^{d}) \qquad (11)$$
$$\theta_m^{a} \leftarrow \arg\min_{\theta_m^{a}} V_m(\theta_m^{\ell}, \theta_m^{a}, \theta_m^{d}) \qquad (12)$$

$\mathcal{L}_m^{\ell}$ and $\mathcal{L}_m^{d}$ are computed using the expected value of the log likelihood through the inference networks $\ell_m$, $a_m$, and $d_m$, and each step is solved approximately using gradient descent, iterating until convergence. Define the local data distribution $p_m(\mathbf{x}, \mathbf{y}, \mathbf{z})$; then Proposition 1 (restated from main text §2.3) shows that this adversarial training procedure learns an optimal local model that is at the same time pivotal (invariant) with respect to $\mathbf{Z}$ under $p_m$.

Proposition 1 (Optimality of $\ell_m$, adapted from Proposition 1 in [NIPS2017_6699]). Suppose we compute the losses $\mathcal{L}_m^{\ell}$ and $\mathcal{L}_m^{d}$ using the expected log likelihood through the inference networks $\ell_m$, $a_m$, and $d_m$:

$$\mathcal{L}_m^{\ell}(\theta_m^{\ell}, \theta_m^{a}) = \mathbb{E}_{\mathbf{x} \sim p_m(\mathbf{x})} \, \mathbb{E}_{\mathbf{y} \sim p_m(\mathbf{y} \mid \mathbf{x})} \big[ -\log p_{\theta_m^{\ell}, \theta_m^{a}}(\mathbf{y} \mid \mathbf{x}) \big] \qquad (13)$$
$$\mathcal{L}_m^{d}(\theta_m^{\ell}, \theta_m^{d}) = \mathbb{E}_{\mathbf{h} \sim p_{\theta_m^{\ell}}(\mathbf{h})} \, \mathbb{E}_{\mathbf{z} \sim p_m(\mathbf{z} \mid \mathbf{h})} \big[ -\log p_{\theta_m^{d}}(\mathbf{z} \mid \mathbf{h}) \big] \qquad (14)$$

Then, if there is a minimax solution $(\hat{\theta}_m^{\ell}, \hat{\theta}_m^{a}, \hat{\theta}_m^{d})$ for Equation (9) such that $V_m(\hat{\theta}_m^{\ell}, \hat{\theta}_m^{a}, \hat{\theta}_m^{d}) = H(\mathbf{Y} \mid \mathbf{X}) - H(\mathbf{Z})$, then $\ell_m$ (together with $a_m$) is both an optimal classifier and a pivotal quantity.

Proof.

For fixed $\theta_m^{\ell}$ and $\theta_m^{a}$, the adversary $d_m$ is optimal at

$$\hat{\theta}_m^{d} = \arg\max_{\theta_m^{d}} V_m(\theta_m^{\ell}, \theta_m^{a}, \theta_m^{d}) = \arg\min_{\theta_m^{d}} \mathcal{L}_m^{d}(\theta_m^{\ell}, \theta_m^{d}), \qquad (15)$$

in which case $p_{\hat{\theta}_m^{d}}(\mathbf{z} \mid \mathbf{h}) = p_m(\mathbf{z} \mid \mathbf{h})$ for all $\mathbf{z}$ and all $\mathbf{h}$, and $\mathcal{L}_m^{d}$ reduces to the expected entropy $\mathbb{E}_{\mathbf{h} \sim p_{\theta_m^{\ell}}(\mathbf{h})}\big[ H(\mathbf{Z} \mid \mathbf{h}) \big]$ of the conditional distribution of the protected variables.

This expectation corresponds to the conditional entropy of the random variables $\mathbf{Z}$ and $\mathcal{H}$ and can be written as $H(\mathbf{Z} \mid \mathcal{H})$. Accordingly, the value function can be restated as a function depending only on $\theta_m^{\ell}$ and $\theta_m^{a}$:

$$V_m'(\theta_m^{\ell}, \theta_m^{a}) = \mathcal{L}_m^{\ell}(\theta_m^{\ell}, \theta_m^{a}) - H(\mathbf{Z} \mid \mathcal{H}). \qquad (16)$$

By our choice of the objective function we know that

$$\mathcal{L}_m^{\ell}(\theta_m^{\ell}, \theta_m^{a}) \geq H(\mathbf{Y} \mid \mathbf{X}) \quad \text{and} \quad H(\mathbf{Z} \mid \mathcal{H}) \leq H(\mathbf{Z}), \qquad (17)$$

which implies that we have the lower bound

$$V_m'(\theta_m^{\ell}, \theta_m^{a}) \geq H(\mathbf{Y} \mid \mathbf{X}) - H(\mathbf{Z}), \qquad (18)$$

where the equality holds at $V_m'(\hat{\theta}_m^{\ell}, \hat{\theta}_m^{a}) = H(\mathbf{Y} \mid \mathbf{X}) - H(\mathbf{Z})$ when:

  1. $\mathcal{L}_m^{\ell}(\hat{\theta}_m^{\ell}, \hat{\theta}_m^{a}) = H(\mathbf{Y} \mid \mathbf{X})$, which implies that $\hat{\theta}_m^{\ell}$ and $\hat{\theta}_m^{a}$ perfectly minimize the negative log-likelihood of $\mathbf{Y} \mid \mathbf{X}$ under $p_m$; this happens when they are the parameters of an optimal classifier from $\mathbf{X}$ to $\mathbf{Y}$ (through an intermediate representation $\mathcal{H}$). In this case, $\mathcal{L}_m^{\ell}$ reduces to its minimum value $H(\mathbf{Y} \mid \mathbf{X})$.

  2. $\hat{\theta}_m^{\ell}$ maximizes the conditional entropy $H(\mathbf{Z} \mid \mathcal{H})$, since $H(\mathbf{Z} \mid \mathcal{H}) \leq H(\mathbf{Z})$ from the properties of entropy.

By assumption, the lower bound is active, which implies that $H(\mathbf{Z} \mid \mathcal{H}) = H(\mathbf{Z})$ because of the second condition. This in turn implies that $\mathbf{Z}$ and $\mathcal{H}$ are independent variables by the properties of (conditional) entropy. Therefore, the optimal classifier $\ell_m$ (with $a_m$) is also a pivotal quantity with respect to the protected attributes $\mathbf{Z}$ under the local data distribution $p_m$. ∎

6.2 Adversarial Training of Global Model

Observe that in addition to the training of the local models, the global model should also be trained in an adversarial manner. This is because the path of inference when training the (local copy of the) global model also involves the local representation $\mathcal{H}_m$ and the protected attributes $\mathbf{Z}_m$ (refer to Figure 3).

Figure 3: A closer look at the inference paths involved in adversarial training. The local model $\ell_m$, auxiliary model $a_m$, and adversarial model $d_m$ are trained together when training the local model, while the local model $\ell_m$, global model $g$, and adversarial model $d_m$ are trained together when jointly training the local and (local copy of the) global model. Refer to Equation (9) and Equation (20) for the dual optimization objectives over the local and global model and adversary parameters, respectively.

This implies that when training the global model we should again optimize a dual objective, now across the local model $\ell_m$, global model $g$, and adversarial model $d_m$:

$$V_m^{g}(\theta_m^{\ell}, \theta_g, \theta_m^{d}) = \mathcal{L}_m^{g}(\theta_m^{\ell}, \theta_g) - \mathcal{L}_m^{d}(\theta_m^{\ell}, \theta_m^{d}). \qquad (19)$$

We would like to find the minimax solution

$$\hat{\theta}_m^{\ell}, \hat{\theta}_g, \hat{\theta}_m^{d} = \arg\min_{\theta_m^{\ell}, \theta_g} \max_{\theta_m^{d}} V_m^{g}(\theta_m^{\ell}, \theta_g, \theta_m^{d}). \qquad (20)$$

To do so, we can again iteratively solve for the parameters in an alternating fashion. In other words, initialize $\theta_m^{\ell}, \theta_g, \theta_m^{d}$ and repeat until convergence:

$$\theta_m^{d} \leftarrow \arg\max_{\theta_m^{d}} V_m^{g}(\theta_m^{\ell}, \theta_g, \theta_m^{d}) \qquad (21)$$
$$\theta_m^{\ell} \leftarrow \arg\min_{\theta_m^{\ell}} V_m^{g}(\theta_m^{\ell}, \theta_g, \theta_m^{d}) \qquad (22)$$
$$\theta_g \leftarrow \arg\min_{\theta_g} V_m^{g}(\theta_m^{\ell}, \theta_g, \theta_m^{d}) \qquad (23)$$

In practice, we optimize the following dual objectives over the local and global models, respectively:

$$\min_{\theta_m^{\ell}, \theta_m^{a}} \; \mathcal{L}_m^{\ell}(\theta_m^{\ell}, \theta_m^{a}) - \lambda^{\ell} \, \mathcal{L}_m^{d}(\theta_m^{\ell}, \theta_m^{d}) \qquad (24)$$
$$\min_{\theta_m^{\ell}, \theta_g} \; \mathcal{L}_m^{g}(\theta_m^{\ell}, \theta_g) - \lambda^{g} \, \mathcal{L}_m^{d}(\theta_m^{\ell}, \theta_m^{d}) \qquad (25)$$

where $\lambda^{\ell}$ and $\lambda^{g}$ are hyperparameters that control the tradeoff between the prediction models and the adversary model.

7 Experimental Details

Here we provide all the details regarding experimental setup, dataset preprocessing, model architectures, model training, and performance evaluation. Our anonymized code is attached in the supplementary material. All experiments are conducted on a single machine with 4 GeForce GTX TITAN X GPUs.

7.1 Model Performance and Communication Efficiency

7.1.1 MNIST

In all our experiments, we train with a fixed number of local epochs and a local minibatch size of 10. Images were normalized prior to training and testing. We take the last two layers of the network to form our global model, substantially reducing the number of parameters communicated per round. Table 5 lists the hyperparameters used. The dataset can be found at http://yann.lecun.com/exdb/mnist/. Our results are averaged over 10 runs; the reported numbers of FedAvg and LG rounds are rounded to the nearest multiple of 5, which we use to calculate the number of parameters communicated. Standard deviations are also reported.

 

Model Parameter Value

 

FedAvg Input dim 784
Layers [512, 256, 256, 128]
Output dim 10
Loss cross entropy
Batchsize 10
Activation ReLU
Optimizer SGD
Learning rate 0.05
Momentum 0.5
Global epochs 1500

 

Local Only Input dim 784
Layers [512, 256, 256, 128]
Output dim 10
Loss cross entropy
Batchsize 10
Activation ReLU
Optimizer SGD
Learning rate 0.05
Momentum 0.5
Global epochs 500

 

LG-FedAvg, Local Input dim 784
Layers [512, 256, 256, 128]
Output dim 10
Loss cross entropy
Batchsize 10
Activation ReLU
Optimizer SGD
Learning rate 0.05
Momentum 0.5
Global epochs 500

 

LG-FedAvg, Global Layers kept 2
Input dim 256
Layers [128]
Output dim 10
Loss cross entropy
Batchsize 10
Activation ReLU
Optimizer SGD
Learning rate 0.05
Momentum 0.5
Global epochs 500

 

Table 5: Table of hyperparameters for MNIST experiments.

7.1.2 CIFAR-10

In all our experiments, we train with a fixed number of local epochs and a local minibatch size of 50. Images are randomly cropped to size 32, randomly flipped horizontally, resized, and normalized. For our model architecture, we chose LeNet-5. We use the two convolutional layers for the global model in our LG-FedAvg method to minimize the number of parameters, substantially reducing the number of parameters communicated per round. Table 6 lists additional hyperparameters. The dataset can be found at https://www.cs.toronto.edu/~kriz/cifar.html. Our results are averaged over 10 runs; the reported numbers of FedAvg and LG rounds are rounded to the nearest multiple of 5, which we use to calculate the number of parameters communicated. Standard deviations are also reported.

 

Model Parameter Value

 

FedAvg Loss cross entropy
Batchsize 50
Optimizer SGD
Learning rate 0.1
Momentum 0.5
Learning rate decay 0.005
Global epochs 2000

 

Local Only Loss cross entropy
Batchsize 50
Optimizer SGD
Learning rate 0.1
Momentum 0.5
Learning rate decay 0.005
Global epochs 1200

 

LG-FedAvg, Local Loss cross entropy
Batchsize 50
Optimizer SGD
Learning rate 0.1
Momentum 0.5
Learning rate decay 0.005
Global epochs 1200

 

LG-FedAvg, Global Loss cross entropy
Batchsize 50
Optimizer SGD
Learning rate 0.1
Momentum 0.5
Learning rate decay 0.005
Global epochs 1200

 

Table 6: Table of hyperparameters for CIFAR-10 experiments.

7.1.3 VQA

We adapt the baseline model from [VQA] without the norm I image channel embeddings. We also substitute the VGGNet [Simonyan14c] used in the original baseline model with a pre-trained ResNet-18 [DBLP:journals/corr/HeZRS15]. Finally, we use the deep LSTM [hochreiter1997long] embedding, an LSTM consisting of two hidden layers. For the LG-FedAvg method, the global model uses the final two fully connected layers of the image and question channels, as well as the additional two fully connected layers following fusion via element-wise multiplication; this substantially reduces the number of parameters in the global model. We use 50 devices; additional hyperparameters, including a local minibatch size of 100, are listed in Table 7. To train and evaluate our models, we use the data from https://visualqa.org/download.html.

 

Model Parameter Value

 

FedAvg Loss cross entropy
Batchsize 100
Optimizer SGD
Learning rate 0.01
Momentum 0.9
Learning rate decay 0.0005
Global epochs 100

 

LG-FedAvg, Local Loss cross entropy
Batchsize 100
Optimizer SGD
Learning rate 0.01
Momentum 0.9
Learning rate decay 0.0005
Global epochs 100

 

LG-FedAvg, Global Loss cross entropy
Batchsize 100
Optimizer SGD
Learning rate 0.01
Momentum 0.9
Learning rate decay 0.0005
Global epochs 100

 

Table 7: Table of hyperparameters for VQA experiments.

7.2 Heterogeneous Data in an Online Setting

Our experiments for the rotated MNIST follow the same settings and hyperparameter selection as our normal MNIST experiments (section 7.1). However, we include an additional device, which randomly samples 3000 and 500 images from the train and test sets respectively and rotates them by a fixed 90 degrees. We show some samples of the rotated MNIST images we used in Figure 4, where the top row shows the normal MNIST images used during training and the bottom row shows the rotated MNIST images on the new test device.
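One way to construct the rotated-MNIST device with torchvision is sketched below; the normalization constants and the random seed are illustrative assumptions, while the fixed 90-degree rotation and the 3000/500 train/test sampling follow the description above.

```python
# Sketch of constructing the new rotated-MNIST device (illustrative constants).
import torch
from torchvision import datasets, transforms
from torch.utils.data import Subset

rot90 = transforms.Compose([
    transforms.RandomRotation(degrees=(90, 90)),   # fixed 90-degree rotation
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),    # common MNIST normalization constants
])

train_full = datasets.MNIST("data", train=True, download=True, transform=rot90)
test_full = datasets.MNIST("data", train=False, download=True, transform=rot90)
g = torch.Generator().manual_seed(0)
new_device_train = Subset(train_full, torch.randperm(len(train_full), generator=g)[:3000].tolist())
new_device_test = Subset(test_full, torch.randperm(len(test_full), generator=g)[:500].tolist())
```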

Figure 4: Sample MNIST images used for training (top) and their rotated counterparts used to test the impact of heterogeneous data on a trained federated model in an online setting (bottom).

7.3 Learning Fair Representations

For method 1 (FedAvg), we train the global model and global adversary for 50 outer epochs, within each of which devices perform local epochs of training. For methods 2 and 3 involving local models, we begin by pre-training the local models and local adversaries for several epochs before joint local and global training for 10 epochs. Table 8 lists all hyperparameters used. Experiments were run 10 times with the same hyperparameters but different random seeds. We aimed to keep the local, global, and adversary models as similar as possible across the three baselines for a fair comparison. Apart from the number of local and global epochs, all hyperparameters were kept the same as in the tutorial https://blog.godatadriven.com/fairness-in-ml and its associated code https://github.com/equialgo/fairness-in-ml. The data can be found at https://archive.ics.uci.edu/ml/datasets/Adult.

 

Model Parameter Value

 

FedAvg, Global model Input dim 93
Layers [32,32,32]
Output dim 1
Loss cross entropy
Dropout 0.2
Batchsize 32
Activation ReLU
Optimizer SGD
Learning rate 0.1
Momentum 0.5
Global epochs 50

 

FedAvg, Global Adversary Input dim 32
Layers [32,32,32]
Output dim 2
Loss cross entropy
Dropout 0.2
Batchsize 32
Activation ReLU
Optimizer SGD
Learning rate 0.1
Momentum 0.5
Global epochs 50

 

LG-FedAvg −Adv / +Adv, Local model
Input dim 93
Layers [32,32,32]
Output dim 1
Loss cross entropy
Dropout 0.2
Batchsize 32
Activation ReLU
Optimizer SGD
Learning rate 0.1
Momentum 0.5
Local epochs 10

 

LG-FedAvg −Adv / +Adv, Local adversary
Input dim 93
Layers [32,32,32]
Output dim 2
Loss cross entropy
Dropout 0.2
Batchsize 32
Activation ReLU
Optimizer SGD
Learning rate 0.1
Momentum 0.5
Local epochs 10

 

LG-FedAvg −Adv / +Adv, Global model
Input dim 93
Layers [32,32,32]
Output dim 2
Loss cross entropy
Dropout 0.2
Batchsize 32
Activation ReLU
Optimizer SGD
Learning rate 0.1
Momentum 0.5
Global epochs 10

 

Table 8: Table of hyperparameters for experiments on learning fair representations on the UCI adult dataset.