In this paper we study the collaborative PAC learning problem recently proposed by Blum et al. [3]. In this problem we have an instance space $\mathcal{X}$, a label space $\mathcal{Y}$, and an unknown target function $g^*: \mathcal{X} \to \mathcal{Y}$ chosen from the hypothesis class $\mathcal{G}$. We have $k$ players with distributions $D_1, D_2, \ldots, D_k$ labeled by the target function $g^*$. Our goal is to learn the target function $g^*$ in the probably approximately correct (PAC) sense for every distribution $D_i$. That is, for any given parameters $\epsilon, \delta > 0$, we need to return a function $g$ such that, with probability $1 - \delta$, $g$ agrees with the target $g^*$ on instances of at least $1 - \epsilon$ probability mass in $D_i$ for every player $i$.
As a motivating example, consider a personalized medicine scenario in which a pharmaceutical company wants to obtain a prediction model for the dose-response relationship of a certain drug based on the genomic profiles of individual patients. Existing machine learning methods can efficiently learn a model with good accuracy for the whole population. For fairness reasons, however, it is also desirable to ensure good model accuracy within demographic subgroups, e.g., those defined by gender, ethnicity, age, or socioeconomic status, each of which is associated with its own labeled data distribution.
We are interested in the ratio between the sample complexity required by the best collaborative learning algorithm and that of the learning algorithm for a single distribution, which we call the overhead ratio. A naïve approach to collaborative learning is to allocate a uniform sample budget to each player distribution and learn the model using all collected samples. In this method the players barely collaborate with each other, and it leads to an $\Omega(k)$ overhead for many hypothesis classes (particularly for classes with fixed VC dimension, the ones we focus on in this paper). In this paper we aim to develop a collaborative learning algorithm with the optimal overhead ratio.
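We will focus on the hypothesis class $\mathcal{G} = \{g : \mathcal{X} \to \mathcal{Y}\}$ with VC dimension $d$. For every $\epsilon, \delta > 0$, let $m_{\epsilon,\delta}$ be the sample complexity needed to $(\epsilon,\delta)$-PAC learn the class $\mathcal{G}$. It is known that there exists an $(\epsilon,\delta)$-PAC learning algorithm $\mathcal{L}_{\epsilon,\delta,\mathcal{G}}$ with $m_{\epsilon,\delta} = O\left(\frac{1}{\epsilon}\left(d + \ln \delta^{-1}\right)\right)$ [10]. We remark that we use the algorithm $\mathcal{L}_{\epsilon,\delta,\mathcal{G}}$ as a black box. Our algorithms can therefore be easily extended to other hypothesis classes, given their single-distribution learning algorithms.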
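Given a function $g$ and a set of samples $T$, let $\mathrm{err}_T(g) = \Pr_{(x,y) \in T}[g(x) \neq y]$ be the error of $g$ on $T$. Given a distribution $D$ over $\mathcal{X} \times \mathcal{Y}$, define $\mathrm{err}_D(g) = \Pr_{(x,y) \sim D}[g(x) \neq y]$ to be the error of $g$ on $D$. The $(\epsilon,\delta)$-PAC $k$-player collaborative learning problem can be rephrased as follows. For player distributions $D_1, \ldots, D_k$ and a target function $g^* \in \mathcal{G}$, our goal is to learn a function $g : \mathcal{X} \to \mathcal{Y}$ such that $\Pr\left[\forall i \in \{1, \ldots, k\},\ \mathrm{err}_{D_i}(g) \le \epsilon\right] \ge 1 - \delta$. Here we allow the learning algorithm to be improper, that is, the learned function $g$ does not have to be a member of $\mathcal{G}$.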
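The two error notions can be made concrete with a short Python sketch. It assumes that hypotheses are plain callables and that samples are `(x, y)` pairs. The helper names are hypothetical. The true error on a distribution $D_i$ cannot be computed exactly, so the second helper checks the collaborative goal on held-out samples that stand in for each $D_i$.

```python
from typing import Callable, Hashable, Sequence, Tuple

Sample = Tuple[float, Hashable]           # one labeled example (x, y)
Hypothesis = Callable[[float], Hashable]  # a classifier g: X -> Y


def empirical_error(g: Hypothesis, samples: Sequence[Sample]) -> float:
    """err_T(g): the fraction of examples (x, y) in T with g(x) != y."""
    if not samples:
        raise ValueError("T must contain at least one example")
    return sum(1 for x, y in samples if g(x) != y) / len(samples)


def meets_collaborative_goal(g: Hypothesis,
                             held_out: Sequence[Sequence[Sample]],
                             eps: float) -> bool:
    """True iff err_{T_i}(g) <= eps for every player i, where T_i is a
    held-out sample used as a stand-in for the unknown D_i."""
    return all(empirical_error(g, T_i) <= eps for T_i in held_out)


def threshold_at_half(x: float) -> int:
    """Example hypothesis: predict 1 iff x >= 0.5."""
    return int(x >= 0.5)


if __name__ == "__main__":
    T1 = [(0.1, 0), (0.7, 1), (0.9, 1), (0.4, 0)]
    T2 = [(0.2, 0), (0.6, 0), (0.8, 1), (0.3, 0)]
    print(empirical_error(threshold_at_half, T1))  # 0.0
    print(empirical_error(threshold_at_half, T2))  # 0.25
    print(meets_collaborative_goal(threshold_at_half, [T1, T2], eps=0.3))  # True
```

A single hypothesis must satisfy every player's threshold simultaneously. This is what distinguishes the centralized collaborative goal from learning on the pooled data.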
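Blum et al. [3] showed an algorithm with sample complexity $O\left(\frac{\ln^2 k}{\epsilon}\left((d + k)\ln \epsilon^{-1} + k \ln \delta^{-1}\right)\right)$. When $k = O(d)$, this leads to an overhead ratio of $O(\ln^2 k)$ (assuming $\epsilon$ and $\delta$ are constants). In this paper we propose an algorithm with sample complexity $O\left(\frac{\ln(k/\delta)}{\epsilon}(d + k)\right)$ (Theorem 4). This gives an overhead ratio of $O(\ln k)$ when $k = O(d)$ and $\delta$ is constant, matching the $\Omega(\ln k)$ lower bound proved by Blum et al. [3].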
Similarly to the algorithm of Blum et al. [3], our algorithm runs in rounds and returns the plurality of the functions computed in the rounds as the learned function $g$. In each round, the algorithm adaptively decides the number of samples to take from each player distribution, and calls the single-distribution learning algorithm to learn a function. The algorithm of Blum et al. [3] uses a grouping idea and takes samples evenly from the distributions in each group. Our algorithm instead adopts the multiplicative weights method. Each player distribution is associated with a weight, and the weights direct how the algorithm distributes the sample budget among the player distributions. After each round, the weight of a player distribution increases if the function learned in that round is not accurate on that distribution, so that the algorithm pays more attention to it in future rounds. We first present a direct application of the multiplicative weights method, which leads to a slightly worse sample complexity bound (Theorem 3). We then prove Theorem 4 with more refined algorithmic ideas.
On the lower bound side, the lower bound result of Blum et al. [3] only covers the special case $k = d$. We extend their result to every $k$ and $d$. In particular, we show that the sample complexity of collaborative learning has to be $\Omega(\max\{d \ln k, k \ln d\}/\epsilon)$ for constant $\delta$ (Theorem 6). Therefore, the sample complexity of our algorithm is optimal when $k = d^{O(1)}$.¹

¹We note that this is a stronger statement than the earlier one on “the optimal overhead ratio of $O(\ln k)$ for $k = O(d)$” in two respects. First, showing the optimal overhead ratio only needs a minimax lower bound, whereas the latter statement claims the optimal sample complexity for every $d$ and $k$ in the range. Second, the latter statement holds for a much wider range of $d$ and $k$.
Finally, we have implemented our algorithms and compared them with the algorithm of Blum et al. [3] and with the naïve method on several real-world datasets. Our experimental results demonstrate the superiority of our algorithm in terms of sample complexity.
As mentioned, collaborative PAC learning was first studied by Blum et al. [3]. Besides the problem of learning one hypothesis that is good for all players’ distributions (called centralized collaborative learning in [3]), the authors also studied the case in which different hypotheses can be used for different distributions (called personalized collaborative learning). For the personalized version they obtained an $O(\ln k)$ overhead in sample complexity. Our results show that the same overhead can be obtained for the (more difficult) centralized version. In concurrent work [15], the authors obtained results similar to ours.
Both our algorithms and AdaBoost [7] use the multiplicative weights method. AdaBoost places weights on the samples in a fixed training set. Our algorithms instead place weights on the data distributions, and adaptively acquire new samples to achieve better accuracy. Another important feature of our improved algorithm is that it tolerates a few “failed rounds” in the multiplicative weights method. This requires more effort in the analysis, and it is crucial to shaving the extra $\ln k$ factor when $k = \Theta(d)$.
Balcan et al. [1] studied the problem of finding a hypothesis that approximates the target function well on the joint mixture of the players’ distributions. They focused on minimizing the communication between the players, and allowed players to exchange not only samples but also hypotheses and other information. Daumé III et al. [11, 12] studied the problem of computing linear separators in a similar distributed communication model. The communication complexity of distributed learning has also been studied for a number of other problems, including principal component analysis [13], clustering [2, 9], and multi-task learning [16].
Another related direction of research is the multi-source domain adaptation problem [14]. There we have $k$ distributions and, for each of them, a hypothesis with error at most $\epsilon$ on that distribution. The task is to combine these hypotheses into a single one that has error at most $\epsilon$ on any mixture of the distributions. This problem differs from our setting in that we want to learn the “global” hypothesis from scratch instead of combining existing ones.
2 The Basic Algorithm
In this section we propose an algorithm for collaborative learning using the multiplicative weights method. The algorithm is described in Algorithm 1, which uses Test (Algorithm 2) as a subroutine.
We briefly describe Algorithm 1 in words. We start by giving a unit weight to each player. The algorithm runs in rounds, and the players’ weights change in each round. At round $t$, we take a set of samples $S^{(t)}$ from the average distribution of the $k$ players, weighted by their weights. We then learn a classifier $g^{(t)}$ for the samples in $S^{(t)}$. For each player $i$, we test whether $g^{(t)}$ agrees with the target function on a sufficiently large probability mass of $D_i$. If so, we keep the weight of the $i$-th player. Otherwise we multiply its weight by a constant factor larger than 1, so that $D_i$ attracts more attention in the future learning process. Finally, we return a classifier $g$ that takes the plurality vote² of the classifiers $g^{(1)}, g^{(2)}, \ldots$ that we have constructed. We note that we make no effort to optimize the constants in the algorithms or in their theoretical analysis. In the experiment section, however, we tune the constants for better empirical performance.

²I.e., the most frequent value, with ties broken arbitrarily.
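The loop described above can be sketched in Python as follows. This is a minimal illustration under stated assumptions, not a transcription of Algorithm 1:

- Each player is modeled as a sampler function.
- `learner` and `test` are caller-supplied stand-ins for the single-distribution learner and for Test.
- The number of rounds `T`, the per-round sample size `m`, and the weight multiplier `factor` are hypothetical placeholders rather than the paper's constants.
- The toy setup at the bottom (1-D threshold classifiers with target $[x \ge 0.5]$) is invented for illustration.

```python
import random
from collections import Counter
from typing import Callable, Hashable, List, Sequence, Set, Tuple

Sample = Tuple[float, Hashable]
Hypothesis = Callable[[float], Hashable]
Sampler = Callable[[random.Random], Sample]  # draws one labeled example from D_i


def plurality(hyps: Sequence[Hypothesis]) -> Hypothesis:
    """Plurality vote of the per-round classifiers (ties broken arbitrarily)."""
    def g(x: float) -> Hashable:
        return Counter(h(x) for h in hyps).most_common(1)[0][0]
    return g


def mweights_basic(samplers: Sequence[Sampler],
                   learner: Callable[[List[Sample]], Hypothesis],
                   test: Callable[[Hypothesis], Set[int]],
                   T: int, m: int, factor: float = 2.0,
                   seed: int = 0) -> Hypothesis:
    """Multiplicative-weights loop; T, m and factor are illustrative."""
    rng = random.Random(seed)
    k = len(samplers)
    w = [1.0] * k                          # unit weight for every player
    round_hyps: List[Hypothesis] = []
    for _ in range(T):
        # Draw m examples from the weighted mixture sum_i (w_i / W) D_i:
        # choose a player with probability proportional to its weight,
        # then draw one example from that player's distribution.
        players = rng.choices(range(k), weights=w, k=m)
        S = [samplers[i](rng) for i in players]
        g_t = learner(S)
        round_hyps.append(g_t)
        accepted = test(g_t)               # players on which g_t looks accurate
        for i in range(k):
            if i not in accepted:
                w[i] *= factor             # pay more attention to player i
    return plurality(round_hyps)


# ---- Hypothetical toy setup: 1-D thresholds, target f*(x) = [x >= 0.5] ----
def uniform_sampler(lo: float, hi: float) -> Sampler:
    def draw(rng: random.Random) -> Sample:
        x = rng.uniform(lo, hi)
        return (x, int(x >= 0.5))
    return draw


def threshold_learner(S: List[Sample]) -> Hypothesis:
    """Consistent learner: put the threshold between the two classes."""
    lo = max((x for x, y in S if y == 0), default=0.0)
    hi = min((x for x, y in S if y == 1), default=1.0)
    t = (lo + hi) / 2

    def h(x: float) -> int:
        return int(x >= t)
    return h


def empirical_test(samplers: Sequence[Sampler], n: int, cutoff: float,
                   seed: int = 1) -> Callable[[Hypothesis], Set[int]]:
    """Accept player i iff g errs on at most a cutoff fraction of n fresh draws."""
    rng = random.Random(seed)

    def test(g: Hypothesis) -> Set[int]:
        accepted = set()
        for i, draw in enumerate(samplers):
            batch = [draw(rng) for _ in range(n)]
            if sum(g(x) != y for x, y in batch) / n <= cutoff:
                accepted.add(i)
        return accepted
    return test


if __name__ == "__main__":
    players = [uniform_sampler(0.0, 1.0), uniform_sampler(0.45, 0.55)]
    g = mweights_basic(players, threshold_learner,
                       empirical_test(players, n=200, cutoff=0.05), T=15, m=200)
    print(g(0.3), g(0.7))
```

Drawing each example by first choosing a player in proportion to its weight is one simple way to sample from the weighted mixture. The returned classifier is the plurality vote of the per-round classifiers, as in the text.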
The following lemma shows that, with high probability, Test returns the desired set of players, namely those for which $g$ is an accurate hypothesis on their own distribution. We say that a call to Test is successful if its returned set has the properties described in Lemma 1. The omitted proofs in this section can be found in Appendix B.
With probability at least , returns a set of players that includes 1) each such that , 2) none of the such that .
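A Test-style subroutine with the two-sided guarantee of the lemma above can be sketched as follows. The sketch assumes that the two error thresholds are given as `eps_lo` < `eps_hi`. These are hypothetical parameter names; the paper's thresholds are not reproduced here. For simplicity, the sketch swaps in the additive Hoeffding bound plus a union bound over the $k$ players. This needs $O(\ln(k/\delta)/\gamma^2)$ draws per player for a gap $\gamma$. The multiplicative Chernoff bound (Proposition 7) used in the paper gives a better dependence on $\epsilon$ when the thresholds are constant multiples of $\epsilon$.

```python
import math
import random
from typing import Callable, Hashable, Sequence, Set, Tuple

Sample = Tuple[float, Hashable]
Sampler = Callable[[random.Random], Sample]


def hoeffding_test_size(k: int, delta: float, eps_lo: float, eps_hi: float) -> int:
    """Per-player sample count n with 2k * exp(-2 n gamma^2) <= delta,
    where gamma = (eps_hi - eps_lo) / 2 (Hoeffding plus a union bound)."""
    gamma = (eps_hi - eps_lo) / 2
    return math.ceil(math.log(2 * k / delta) / (2 * gamma ** 2))


def screen_players(g: Callable[[float], Hashable], samplers: Sequence[Sampler],
                   delta: float, eps_lo: float, eps_hi: float,
                   rng: random.Random) -> Set[int]:
    """Accept player i iff g's empirical error on fresh draws from D_i is at
    most the midpoint of [eps_lo, eps_hi].  With probability >= 1 - delta,
    every i with err_{D_i}(g) <= eps_lo is accepted and every i with
    err_{D_i}(g) > eps_hi is rejected."""
    n = hoeffding_test_size(len(samplers), delta, eps_lo, eps_hi)
    cutoff = (eps_lo + eps_hi) / 2
    accepted = set()
    for i, draw in enumerate(samplers):
        batch = [draw(rng) for _ in range(n)]
        if sum(g(x) != y for x, y in batch) / n <= cutoff:
            accepted.add(i)
    return accepted


def labeled_by_threshold(t: float) -> Sampler:
    """Hypothetical player: x ~ U[0, 1], labeled y = [x >= t]."""
    def draw(rng: random.Random) -> Sample:
        x = rng.random()
        return (x, int(x >= t))
    return draw


if __name__ == "__main__":
    def g(x: float) -> int:
        return int(x >= 0.5)

    # g has error 0 on player 0 and error 0.3 on player 1.
    players = [labeled_by_threshold(0.5), labeled_by_threshold(0.8)]
    print(hoeffding_test_size(2, 0.01, 0.05, 0.15))                     # 1199
    print(screen_players(g, players, 0.01, 0.05, 0.15, random.Random(0)))  # {0}
```

Players whose error falls between the two thresholds may go either way. This is exactly the slack that the lemma above allows.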
Given a function $g$ and a distribution $D$, we say that $g$ is a good candidate for $D$ if . The following lemma shows that if we have a set of functions most of which are good candidates for $D$, then the plurality vote of these functions also has good accuracy on $D$.
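Let $g^{(1)}, \ldots, g^{(m)}$ be a set of functions such that more than 70% of them are good candidates for $D$. Let $g$ be the plurality vote of $g^{(1)}, \ldots, g^{(m)}$. Then $\mathrm{err}_D(g) \le \epsilon$.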
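The plurality vote used by the algorithm can be written in a few lines of Python. The function name and the toy hypotheses below are hypothetical, chosen for illustration. The sketch also shows the mechanism behind the lemma above: the vote can only be wrong at $x$ when many of the individual functions are wrong at $x$.

```python
from collections import Counter
from typing import Callable, Hashable, Sequence

Hypothesis = Callable[[int], Hashable]


def plurality_vote(hyps: Sequence[Hypothesis]) -> Hypothesis:
    """g(x) = the most frequent value among h(x), h in hyps.
    Counter.most_common keeps first-encountered order among equal counts,
    which is one valid way to break ties arbitrarily."""
    if not hyps:
        raise ValueError("need at least one hypothesis")

    def g(x: int) -> Hashable:
        return Counter(h(x) for h in hyps).most_common(1)[0][0]
    return g


def target(x: int) -> int:
    """Hypothetical 3-class target f*(x) = x mod 3."""
    return x % 3


def always_wrong(x: int) -> int:
    """A hypothetical bad hypothesis that errs on every x."""
    return (x + 1) % 3


if __name__ == "__main__":
    # 8 of 10 hypotheses are perfect and 2 always err: the vote recovers f*.
    g = plurality_vote([target] * 8 + [always_wrong] * 2)
    print(all(g(x) == target(x) for x in range(100)))  # True
```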
Let $\mathcal{E}$ be the event that every call of the learner and of Test is successful. It is straightforward to see that
$$\Pr[\mathcal{E}] \ge 1 - \delta. \qquad (1)$$
Now we are ready to prove the main theorem for alg:main-1.
Algorithm 1 has the following properties.
With probability at least $1 - \delta$, it returns a function $g$ such that $\mathrm{err}_{D_i}(g) \le \epsilon$ for all $i \in [k]$.
Its sample complexity is .
Proof. The sample complexity is easy to verify, so we focus on the proof of the first property. In particular, we show that when $\mathcal{E}$ happens (which is with probability at least $1 - \delta$ by (1)), we have $\mathrm{err}_{D_i}(g) \le \epsilon$ for all $i \in [k]$.
From now until the end of the proof, we assume that $\mathcal{E}$ happens.
For each round $t$, we have that . Therefore, by Markov’s inequality, we have that . In other words,
Now consider the total weight , we have
By Lemma 1 and $\mathcal{E}$, we have that
Now let us focus on an arbitrary player $i$. We will show that for at least 70% of the rounds $t$, $g^{(t)}$ is a good candidate for $D_i$. This concludes the proof of the theorem thanks to Lemma 2.
Suppose the contrary: for more than 30% of the rounds, $g^{(t)}$ is not a good candidate for $D_i$. At each such round $t$, we have because of Lemma 1 and $\mathcal{E}$, and therefore . Therefore, we have . Together with (4), we have , which is a contradiction for .
3 The Quest for Optimality via Robust Multiplicative Weights
In this section we improve the result in Theorem 3 to obtain an optimal algorithm when $k$ is polynomially bounded by $d$ (see Theorem 4; optimality is shown in Section 4). Our improved algorithm (Algorithm 3, which uses WeakTest (Algorithm 4) as a subroutine) is in fact almost the same as Algorithm 1 (which uses Test as a subroutine). We highlight the differences as follows.
The total number of iterations of the main loop of Algorithm 1 is changed to .
The failure probability of the single-distribution learning algorithm called in each round of Algorithm 1 is increased to a constant.
The number of times that each distribution is sampled in Test is reduced to .
Although these changes seem minor, establishing Theorem 4 requires substantial technical effort. We describe the challenge and sketch our solution as follows.
The 2nd and 3rd items lead to the key reduction in sample complexity. However, they make it impossible to use the union bound to claim that, with high probability, “every call of the learner and Test is successful” (see Inequality (1) in the analysis of Algorithm 1).
To address this problem, we make our multiplicative weights analysis robust against occasional failed rounds, so that it works when “most calls of the learner and WeakTest are successful”.
In more detail, we first work on the total weight at the $t$-th round. We show that, conditioned on the $(t-1)$-th round, is upper bounded by (by contrast, we had a stronger and deterministic statement in the analysis of the basic algorithm). Using Jensen’s inequality we are then able to derive that is upper bounded by . Then, using Azuma’s inequality for supermartingales, we show that with high probability , i.e., , which corresponds to in the basic proof.

On the other hand, recall what the basic proof had to show: if for more than 30% of the rounds the function $g^{(t)}$ is not a good candidate for a player distribution $D_i$, then we have . In the analysis of the improved algorithm, the WeakTest procedure fails with much higher probability. We therefore need concentration inequalities, and we derive a slightly weaker statement. Finally, we put everything together using the same proof by contradiction, and prove the following theorem.
Algorithm 3 has the following properties.
With probability at least $1 - \delta$, it returns a function $g$ such that $\mathrm{err}_{D_i}(g) \le \epsilon$ for all $i \in [k]$.
Its sample complexity is $O\left(\frac{\ln(k/\delta)}{\epsilon}(d + k)\right)$.
Now we prove Theorem 4.
Similarly to Lemma 1, applying Proposition 7 (but without the union bound), we have the following lemma for WeakTest.
For each player , with probability at least , the following hold, 1) if , then ; 2) if , then .
Let the indicator variable if the desired event described in Lemma 5 for player $i$ and time $t$ does not happen, and let otherwise. By Lemma 5, we have . By Proposition 7, for each player $i$, we have .
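The counting step can be illustrated numerically. The sketch below makes two illustrative assumptions: a fixed player's per-round failure indicators are independent with a hypothetical failure probability `p`, and the argument tolerates a hypothetical fraction `tol` of failed rounds. None of these constants are the paper's. In place of the Chernoff bound, the sketch computes the exact binomial tail. It then finds the smallest number of rounds $T$ for which a single player exceeds the tolerance with probability at most $\delta/k$, so that a union bound over the $k$ players costs at most $\delta$ in total.

```python
import math


def binom_upper_tail(T: int, p: float, threshold: int) -> float:
    """Exact Pr[Bin(T, p) > threshold]."""
    return sum(math.comb(T, j) * p ** j * (1 - p) ** (T - j)
               for j in range(threshold + 1, T + 1))


def rounds_needed(k: int, delta: float, p: float, tol: float,
                  T_max: int = 1000) -> int:
    """Smallest T such that one player sees more than tol*T failed rounds
    with probability <= delta / k, so that a union bound over all k
    players keeps the total failure probability <= delta."""
    for T in range(1, T_max + 1):
        if binom_upper_tail(T, p, math.floor(tol * T)) <= delta / k:
            return T
    raise ValueError("no T <= T_max works; increase T_max")


if __name__ == "__main__":
    for k in (10, 100, 1000, 10000):
        print(k, rounds_needed(k, delta=0.05, p=0.01, tol=0.1))
```

For fixed `p` and `tol`, the required $T$ grows slowly with $k$, roughly logarithmically in $k/\delta$. This is the kind of dependence that a Chernoff-plus-union-bound argument yields.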
Now let be the event that for every . Via a union bound, we have that
Let the indicator variable if the learner fails at time ; and let otherwise. We have
Let be the total weights at time . For each , similarly to (3), we have
For each $i$ such that , by Lemma 5, we know that . Therefore, taking the expectation over the randomness of WeakTest at time $t$, we have,
When , similarly to the proof of Theorem 3, we have , and
Together with (6), we have .
Let , and by Jensen’s inequality, we have . Therefore, we have .
Now let for all . We have that is a supermartingale and for all . By Proposition 9 and noticing that , we have . Let be the event that , we have that
Now we are ready to prove Theorem 4 for Algorithm 3.
Proof. [of Theorem 4] The sample complexity is easy to verify, so we focus on the proof of the first property. In particular, we show that when the event bounded in (12) happens (which is with probability at least $1 - \delta$ by (12)), we have $\mathrm{err}_{D_i}(g) \le \epsilon$ for all $i \in [k]$.
Let us consider an arbitrary player $i$. We will show that when this event happens, $g^{(t)}$ is a good candidate for $D_i$ for at least 70% of the times $t$. This concludes the proof of the theorem thanks to Lemma 2.
Suppose the contrary: for more than 30% of the times $t$, $g^{(t)}$ is not a good candidate for $D_i$. Because of , for more than of the times $t$, we have . Therefore, we have . On the other hand, by we have . Therefore, we reach , which is a contradiction to .
4 Lower Bound
We show the following lower bound result, which matches our upper bound (Theorem 4) when and .
In collaborative PAC learning with $k$ players and a hypothesis class of VC-dimension $d$, for any sufficiently small $\epsilon, \delta > 0$, there exists a hard input distribution on which any $(\epsilon, \delta)$-learning algorithm needs $\Omega(\max\{d \ln k, k \ln d\}/\epsilon)$ samples in expectation. Here the expectation is taken over the randomness used in obtaining the samples and the randomness used in drawing the input from the input distribution.
The proof of Theorem 6 is similar to that of the lower bound in [3]; however, we need to generalize the hard instance provided in [3] in two different cases. We briefly discuss the high-level ideas of our generalization here, and leave the full proof to Appendix C due to space constraints.
The lower bound proof in [3] (for $k = d$) performs a reduction from a single-player problem to a $k$-player problem. The reduction ensures that if we can $(\epsilon,\delta)$-PAC learn the $k$-player problem using $m$ samples in total, then we can PAC learn the single-player problem using correspondingly fewer samples. For the case $d > k$, we need to replace the single-player problem used in [3] with one whose hypothesis class has a larger VC-dimension. For the case $k > d$, we essentially duplicate the hard instance for a $d$-player problem $k/d$ times, which gives a hard instance for a $k$-player problem. We then perform the random embedding reduction from the single-player problem to the $k$-player problem. See Appendix C for details.
We present in this section a set of experimental results which demonstrate the effectiveness of our proposed algorithms.
Our algorithms assume that, given a hypothesis class, we are able to compute its VC dimension $d$ and to access an oracle that computes an $(\epsilon,\delta)$-classifier with sample complexity $m_{\epsilon,\delta}$. In practice, however, computing the exact VC dimension of a given hypothesis class is usually difficult. Moreover, the VC dimension often gives only a very loose upper bound on the sample complexity needed for an $(\epsilon,\delta)$-classifier.
To address these practical difficulties, in our experiments we treat the VC dimension $d$ as a parameter that controls the sample budget. More specifically, we first choose a concrete model as the oracle; in our implementation, we choose the decision tree. We then fix the error parameters and gradually increase $d$ to determine the sample budget. For each fixed sample budget (i.e., each fixed $d$), we run the algorithm multiple times and test whether the following holds:
$$\Pr\left[\forall i \in [k]:\ \mathrm{err}_{D_i}(g) \le \epsilon\right] \ge 1 - \delta. \qquad (13)$$
Here $\delta$ is a parameter we choose, and $g$ is the classifier returned by the collaborative learning algorithm being tested. The empirical probability in (13) is calculated over the runs. We finally report the minimum number of samples the algorithm consumes to achieve (13).
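The budget search described above can be sketched as follows. The interface is an assumption: `run_once(d, seed)` is a hypothetical callable that runs a collaborative learner with budget parameter `d` and returns the test error of the learned classifier on each player's distribution. The success criterion is the empirical version of (13) over `runs` seeds.

```python
from typing import Callable, Optional, Sequence

# run_once(d, seed) -> per-player test errors of the returned classifier
RunOnce = Callable[[int, int], Sequence[float]]


def smallest_sufficient_d(run_once: RunOnce, eps: float, delta: float,
                          runs: int, d_grid: Sequence[int]) -> Optional[int]:
    """Scan d in increasing order and return the first value for which the
    empirical frequency of {max_i err_i <= eps} over `runs` seeds is at
    least 1 - delta; return None if no value in the grid suffices."""
    for d in sorted(d_grid):
        successes = sum(max(run_once(d, seed)) <= eps for seed in range(runs))
        if successes / runs >= 1 - delta:
            return d
    return None


def fake_run(d: int, seed: int) -> Sequence[float]:
    """Hypothetical stand-in: errors shrink like 1/d on two players."""
    return [1.0 / d, 0.5 / d]


if __name__ == "__main__":
    print(smallest_sufficient_d(fake_run, eps=0.11, delta=0.05,
                                runs=20, d_grid=[5, 10, 20, 40]))  # 10
```

To report a sample count as in the experiments, one would record the number of samples the algorithm consumed at the returned value of `d`.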
Note that in our theoretical analysis we did not try to optimize the constants. In the experiments, we instead tune the constants of both CenLearn and MWeights for better performance. More implementation details can be found in the appendix.
We will test the collaborative learning algorithms using the following data sets.
Magic-Even [4]. This data set is generated to simulate the registration of high-energy gamma particles in an atmospheric Cherenkov telescope. There are 19,020 instances, each belonging to one of two classes (gamma and hadron), and each data point has 10 attributes. We randomly partition this data set into $k$ subsets (namely, $D_1, \ldots, D_k$).
Magic-1. The raw data set is the same as in Magic-Even. Instead of partitioning randomly, we partition the data set into $D_1$ and $D_2$ based on the two classes, and make $k-2$ more copies of $D_2$ so that $D_2, \ldots, D_k$ are identical. In our case we set .
Magic-2. This data set differs from Magic-1 in the way $D_1$ and $D_2$ are constructed. We partition the original data set into $D_1$ and $D_2$ based on the first dimension of the feature vectors, and then make $k-2$ duplicates of $D_2$. Here we again set .
Wine [5]. This data set contains physicochemical tests for white wine, and the wine scores range from 0 to 10. There are 4,898 instances, and the feature vectors have 11 attributes. We partition the data set into $D_1, \ldots, D_k$ based on the first two dimensions.
Eye. This data set consists of 14 EEG values and a value indicating the eye state. There are 14,980 instances in this data set. We partition it into $D_1, \ldots, D_k$ based on the first two dimensions.
Letter [8]. This data set has 20,000 instances, each described by 16 numerical attributes. There are 26 classes, each representing one of the 26 capital letters. We partition this data set into $k$ subsets based on the first dimensions of the feature vectors.
We compare our algorithm with the following two baseline algorithms:
Naive. In this algorithm we treat all distributions equally. That is, given a budget of $m$ samples, we draw $m$ training samples from the uniform mixture $\frac{1}{k}\sum_{i=1}^{k} D_i$. We then train a classifier (a decision tree) using those samples.
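A minimal sketch of how the Naive baseline's training set could be drawn, assuming that each player is modeled as a sampler function (a hypothetical interface). Every example comes from the uniform mixture $\frac{1}{k}\sum_i D_i$, which treats all distributions equally. Drawing an equal share of the budget from each player would be a close alternative. The pooled set would then be given to a single classifier (a decision tree in these experiments).

```python
import random
from typing import Callable, Hashable, List, Sequence, Tuple

Sample = Tuple[float, Hashable]
Sampler = Callable[[random.Random], Sample]


def naive_training_set(samplers: Sequence[Sampler], budget: int,
                       rng: random.Random) -> List[Sample]:
    """Draw `budget` examples from the uniform mixture (1/k) * sum_i D_i by
    picking a player uniformly at random for every example."""
    k = len(samplers)
    return [samplers[rng.randrange(k)](rng) for _ in range(budget)]


def tagged(i: int) -> Sampler:
    """Hypothetical player whose examples are labeled with its own index."""
    def draw(rng: random.Random) -> Sample:
        return (float(i), i)
    return draw


if __name__ == "__main__":
    pool = naive_training_set([tagged(i) for i in range(3)], 3000,
                              random.Random(0))
    print(len(pool), [sum(y == i for _, y in pool) for i in range(3)])
```

Unlike the multiplicative weights approach, this allocation never adapts to which players the current classifier handles poorly.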
CenLearn. This is our implementation of the algorithm proposed by Blum et al. [3].
Since Algorithm 1 and Algorithm 3 are very similar, and Algorithm 3 has a better theoretical guarantee, we only test Algorithm 3, denoted MWeights, in our experiments.
Experimental Results and Discussion.
The experimental results are presented in Figure 1. For each data set, we test the algorithms with multiple values of the error threshold $\epsilon$, and report the sample complexity of Naive, MWeights and CenLearn.
In the Magic-Even panel of Figure 1, Naive uses fewer samples than its competitors. This is expected because in Magic-Even, $D_1, \ldots, D_k$ are constructed via random partitioning, which is the easiest case for Naive. MWeights and CenLearn need to train multiple classifiers. Under the same total budget, each of their classifiers therefore gets fewer training samples than Naive's single classifier.
In the Magic-1 and Magic-2 panels, $D_2, \ldots, D_k$ are constructed to be identical, and $D_1$ is very different from the other distributions. Thus the overall distribution used to train Naive (i.e., $\frac{1}{k}\sum_{i=1}^{k} D_i$) is quite different from the original data set. One can observe from these two panels that MWeights still works quite well while Naive suffers.
In the panels from Magic-1 to Letter, MWeights uses fewer samples than its competitors in almost all cases, which shows the advantage of our proposed algorithm. CenLearn generally outperforms Naive. However, Naive uses slightly fewer samples than CenLearn in some cases (e.g., Wine). This may be because the distributions in those cases are not hard enough to reveal the advantage of CenLearn over Naive.
To summarize, our experimental results show that MWeights and CenLearn need fewer samples than Naive when the input distributions are sufficiently different. MWeights consistently outperforms CenLearn. This may be because MWeights has a better theoretical guarantee and is more straightforward to implement.
In this paper we consider the collaborative PAC learning problem. We have proved the optimal overhead ratio and sample complexity, and conducted experimental studies to show the superior performance of our proposed algorithms.
One open question is to consider the balance of the numbers of queries made to each player, which can be measured by the ratio between the largest number of queries made to a player and the average number of queries made to the players. The proposed algorithms in this paper may attain a balance ratio of in the worst case. It will be interesting to investigate:
Is there an algorithm with the same sample complexity but a better balance ratio?
What is the optimal trade-off between sample complexity and balance ratio?
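The balance ratio defined above is easy to compute from per-player query counts. The function name and the example counts below are hypothetical.

```python
from statistics import mean
from typing import Sequence


def balance_ratio(queries: Sequence[int]) -> float:
    """Largest per-player query count divided by the average count.
    It equals 1 for perfectly balanced queries and is at most k, the value
    reached when a single player answers every query."""
    avg = mean(queries)
    if avg <= 0:
        raise ValueError("at least one query must be made")
    return max(queries) / avg


if __name__ == "__main__":
    print(balance_ratio([25, 25, 25, 25]))  # 1.0
    print(balance_ratio([100, 0, 0, 0]))    # 4.0
```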
Jiecao Chen and Qin Zhang are supported in part by NSF CCF-1525024 and IIS-1633215. Part of the work was done when Yuan Zhou was visiting the Shanghai University of Finance and Economics.
- (1) M. Balcan, A. Blum, S. Fine, and Y. Mansour. Distributed learning, communication complexity and privacy. In COLT, pages 26.1–26.22, 2012.
- (2) M. Balcan, S. Ehrlich, and Y. Liang. Distributed $k$-means and $k$-median clustering on general communication topologies. In NIPS, pages 1995–2003, 2013.
- (3) A. Blum, N. Haghtalab, A. D. Procaccia, and M. Qiao. Collaborative PAC learning. In NIPS, pages 2389–2398, 2017.
- (4) R. Bock, A. Chilingarian, M. Gaug, F. Hakl, T. Hengstebeck, M. Jirina, J. Klaschka, E. Kotrc, P. Savickỳ, S. Towers, et al. Methods for multidimensional event classification: a case study. Internal Note, CERN, 2003.
- (5) P. Cortez, A. Cerdeira, F. Almeida, T. Matos, and J. Reis. Modeling wine preferences by data mining from physicochemical properties. Decision Support Systems, 47(4):547–553, 2009.
- (6) A. Ehrenfeucht, D. Haussler, M. J. Kearns, and L. G. Valiant. A general lower bound on the number of examples needed for learning. Inf. Comput., 82(3):247–261, 1989.
- (7) Y. Freund and R. E. Schapire. A decision-theoretic generalization of on-line learning and an application to boosting. Journal of Computer and System Sciences, 55(1):119–139, 1997.
- (8) P. W. Frey and D. J. Slate. Letter recognition using Holland-style adaptive classifiers. Machine Learning, 6:161–182, 1991.
- (9) S. Guha, Y. Li, and Q. Zhang. Distributed partial clustering. In SPAA, pages 143–152, 2017.
- (10) S. Hanneke. The optimal sample complexity of PAC learning. The Journal of Machine Learning Research, 17(1):1319–1333, 2016.
- (11) H. Daumé III, J. M. Phillips, A. Saha, and S. Venkatasubramanian. Efficient protocols for distributed classification and optimization. In ALT, pages 154–168, 2012.
- (12) H. Daumé III, J. M. Phillips, A. Saha, and S. Venkatasubramanian. Protocols for learning classifiers on distributed data. In AISTATS, pages 282–290, 2012.
- (13) Y. Liang, M. Balcan, V. Kanchanapally, and D. P. Woodruff. Improved distributed principal component analysis. In NIPS, pages 3113–3121, 2014.
- (14) Y. Mansour, M. Mohri, and A. Rostamizadeh. Domain adaptation with multiple sources. In NIPS, pages 1041–1048, 2008.
- (15) H. L. Nguyen and L. Zakynthinou. Improved Algorithms for Collaborative PAC Learning. arXiv preprint arXiv:1805.08356, 2018.
- (16) J. Wang, M. Kolar, and N. Srebro. Distributed multi-task learning. In AISTATS, pages 751–760, 2016.
Appendix A Concentration Bounds
Proposition 7 (Multiplicative Chernoff bound)
Let be independent random variables with values in . Let . For every , we have that
Definition 8 (Supermartingale Random Variables)
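A discrete-time supermartingale is a sequence of random variables $X_0, X_1, X_2, \ldots$ that satisfies, for any time $t$, $\mathbb{E}[|X_t|] < \infty$ and $\mathbb{E}[X_{t+1} \mid X_0, \ldots, X_t] \le X_t$.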
Proposition 9 (Azuma’s inequality for supermartingale random variables)
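Suppose $\{X_t\}$ is a supermartingale and $|X_t - X_{t-1}| \le c_t$ almost surely. Then for all positive integers $N$ and all positive reals $\lambda$,
$$\Pr[X_N - X_0 \ge \lambda] \le \exp\left(-\frac{\lambda^2}{2\sum_{t=1}^{N} c_t^2}\right).$$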
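To make the statement concrete, the following sketch evaluates the right-hand side of Azuma's inequality and compares it with a Monte Carlo estimate for a simple supermartingale: a ±1 random walk with non-positive drift, for which $c_t = 1$. The walk and the parameter values are illustrative choices, not taken from the paper.

```python
import math
import random
from typing import Sequence


def azuma_bound(lam: float, cs: Sequence[float]) -> float:
    """exp(-lam^2 / (2 * sum_t c_t^2)): the bound on Pr[X_N - X_0 >= lam]."""
    return math.exp(-lam ** 2 / (2 * sum(c * c for c in cs)))


def simulate_tail(N: int, drift: float, lam: float, trials: int,
                  seed: int = 0) -> float:
    """Empirical Pr[X_N - X_0 >= lam] for X_t = X_{t-1} + xi_t, where
    xi_t = +1 w.p. (1 - drift) / 2 and -1 otherwise.  For drift >= 0 this
    is a supermartingale with |X_t - X_{t-1}| <= 1 (so c_t = 1)."""
    rng = random.Random(seed)
    p_up = (1 - drift) / 2
    hits = 0
    for _ in range(trials):
        walk = sum(1 if rng.random() < p_up else -1 for _ in range(N))
        hits += walk >= lam
    return hits / trials


if __name__ == "__main__":
    N, lam = 100, 20
    print(azuma_bound(lam, [1.0] * N))   # exp(-2), about 0.135
    print(simulate_tail(N, drift=0.0, lam=lam, trials=2000))
    print(simulate_tail(N, drift=0.1, lam=lam, trials=2000))
```

The empirical tails sit well below the bound. The bound depends only on the step sizes $c_t$, not on the drift, which is what makes it usable for the weight potential in the robust analysis.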
Appendix B Omitted Proofs in Section 2
Proof. [of Lemma 1] For each $i$ such that , by Proposition 7, we have that
Therefore, with probability at least , $i$ is included in the output of Test.
Similarly, for each $i$ such that , by Proposition 7, we have that
Therefore, with probability at least , $i$ is not included in the output of Test.
The lemma is now proved by a union bound over at most $k$ players.
Proof. [of Lemma 2] Suppose for contradiction that $\mathrm{err}_D(g) > \epsilon$. Given a sample $(x, y) \sim D$, when $g(x) \neq y$, we know that $g^{(j)}(x) \neq y$ for at least half of the $g^{(j)}$’s. Therefore, we have
On the other hand, by considering separately whether each $g^{(j)}$ is a good candidate for $D$, we have
which contradicts (14).
Appendix C Proof of Theorem 6
Instance space .
Hypothesis class: is the collection of all binary functions on that map to .
Target function: is chosen uniformly at random from .
Player’s distribution: , and .
Lemma 10 ([6])
For any , any -learning algorithm on needs samples in expectation, where the expectation is taken over the randomness used in obtaining the samples and the randomness used in drawing the input from .
We prove Theorem 6 in two cases: and .
The case .
Let . We create the following hard input distribution, denoted by .
Instance space: .
Hypothesis class: is the collection of all binary functions on that map to .
Target function: is chosen uniformly at random from .
Player ’s distribution (for each ): Assigns weights to items in as follows: , and . For any other item , .
Note that the induced input distribution for the -th player is the same as for any .
If there exists an -learning algorithm that uses samples in expectation on input distribution , then there exists an -learning algorithm that uses samples in expectation on input distribution .
Proof. We construct for input distribution using for input distribution as follows.
draws an input instance from , and samples uniformly at random from .
simulates on instance