Mix and Match: An Optimistic Tree-Search Approach for Learning Models from Mixture Distributions

07/23/2019 ∙ by Matthew Faw, et al. ∙ The University of Texas at Austin, IBM

We consider a covariate shift problem where one has access to several marginally different training datasets for the same learning problem and a small validation set which possibly differs from all the individual training distributions. This covariate shift is caused, in part, by unobserved features in the datasets. The objective, then, is to find the best mixture distribution over the training datasets (with only observed features) such that training a learning algorithm using this mixture yields the best validation performance. For this task, our proposed algorithm, Mix&Match, combines stochastic gradient descent (SGD) with optimistic tree search and model re-use (evolving partially trained models with samples from different mixture distributions) over the space of mixtures. We prove simple regret guarantees for our algorithm with respect to recovering the optimal mixture, given a total budget of SGD evaluations. Finally, we validate our algorithm on two real-world datasets.




1 Introduction

The problem of covariate shift, where the distribution of the validation set has shifted and does not exactly match that of the training set, has long been recognized as an issue of central importance for real-world problems (e.g., shimodaira2000improving , gretton2009covariate and references therein). Covariate shift is often ascribed to a changing population or underlying dynamics, selection bias, or imperfect, noisy, or missing measurements. Across these settings, a number of approaches to mitigating covariate shift attempt to reweight the samples of the training set in order to match the target distribution shimodaira2000improving ; zadrozny2004learning ; huang2007correcting ; gretton2009covariate . For example, huang2007correcting ; gretton2009covariate use unlabelled data to compute a good kernel reweighting.

We consider a setting where covariate shift is due to unobserved variables in different populations (datasets). A motivating example is predictive health care in different regions of the world. Here, the unobserved variables may represent, for example, the prevalence and expression of different conditions, genes, etc. in the makeup of the population. Another key example, and one for which we have real-world data (see Section 6), is predicting which insurance plan a customer in a given state will purchase. The unobserved variables in this setting might include employment information (e.g., job security), driving risk level, or state-specific factors such as weather or other driving-related features.

Motivated by such applications, we consider the setting where the joint distribution (of observed, unobserved variables and labels) may differ across populations, but the conditional distribution of the label (conditioned on both observed and unobserved variables) remains invariant. This motivates our point of departure from other approaches in the literature: rather than searching over possible reweightings of the training samples, we instead search over different mixing weights of the training populations, in order to optimize performance on the validation set.

The main algorithmic contribution of this paper is Mix&Match – an algorithm that is built on SGD and a variant of optimistic tree-search (closely related to Monte Carlo Tree Search). Given a budget on the total number of SGD iterations, Mix&Match adaptively allocates this budget to different population reweightings through an iterative tree-search procedure. Importantly, Mix&Match expends a majority of the SGD iteration budget on reweightings that are "close" to the optimal reweighting mixture by using two important ideas:
(i) Parsimony in expending iterations: For a reweighting distribution that we have low confidence of being "good", Mix&Match expends only a small number of SGD iterations to train the model; doing so, however, results in biased and noisy evaluation of this model, due to early stopping in training.
(ii) Re-use of models: Rather than train a model from scratch, Mix&Match reuses and updates a partially trained model from past reweightings that are "close" to the currently chosen reweighting (effectively re-using SGD iterations from the past).

The analysis of Mix&Match requires a new concentration bound on the error of SGD subject to early stopping (which we believe is of independent interest). We accomplish this by first using a coarse bound on the martingale difference in the SGD evolution to bound the SGD iterate. The SGD evolution is then reanalyzed in light of this coarse bound to derive a much finer bound on the iterate, providing tighter concentration of the SGD evolution that is especially useful in the early-stopping regime. Combining this new bound with a bandit analysis of optimistic tree search, we provide regret guarantees with respect to the optimal mixture distribution. Finally, in Section 6, we empirically validate the benefits of Mix&Match with respect to a genie baseline.

2 Related Work

Transfer learning has assumed an increasingly important role, especially in settings where we are computationally limited or data-limited, yet have the opportunity to leverage significant computational and data resources on domains that differ slightly from the target domain raina2007self ; pan2009survey ; dai2009eigentransfer . This has become an important paradigm in neural networks and other areas yosinski2014transferable ; oquab2014learning ; bengio2011deep ; kornblith2018better .

An important related problem is that of covariate shift shimodaira2000improving ; zadrozny2004learning ; gretton2009covariate . Here, the target distribution may differ from the training distribution. A common technique for addressing this problem is to reweight the samples in the training set so that the training distribution better matches the target distribution. There are a number of techniques for doing this. An important recent thread does so using unlabelled data huang2007correcting ; gretton2009covariate . Other approaches consider the related problem of weighted log-likelihood maximization shimodaira2000improving , or use some form of importance sampling sugiyama2007covariate ; sugiyama2008direct or bias correction zadrozny2004learning . In mohri2019agnostic , the authors study a related problem of learning from different datasets, but provide minimax bounds in terms of an agnostically chosen test distribution.
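For contrast with the population-level reweighting studied in this paper, here is a minimal sketch of the sample-level reweighting idea used in prior covariate-shift work: each training sample is weighted by the density ratio between the target and training distributions. The Gaussian densities and their closed-form ratio are assumptions made purely for illustration; in practice the ratio is unknown and must be estimated (for example, with the kernel-based methods cited above).

```python
import math
import random

def importance_weighted_mean(samples, density_ratio, g):
    """Self-normalized importance-sampling estimate of E_target[g(x)] using samples
    drawn from the training distribution, each weighted by p_target(x) / p_train(x)."""
    weights = [density_ratio(x) for x in samples]
    return sum(w * g(x) for w, x in zip(weights, samples)) / sum(weights)

def ratio(x):
    # Assumed toy densities: training N(0, 1), target N(1, 1).
    # N(x; 1, 1) / N(x; 0, 1) simplifies to exp(x - 0.5).
    return math.exp(x - 0.5)

rng = random.Random(0)
train = [rng.gauss(0.0, 1.0) for _ in range(100_000)]
naive = sum(train) / len(train)                                   # estimates the training mean, 0
reweighted = importance_weighted_mean(train, ratio, lambda x: x)  # estimates the target mean, 1
print(f"unweighted: {naive:.3f}  reweighted: {reweighted:.3f}")
```

Note that this reweights individual samples; the approach in this paper instead searches over weights of entire populations.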

Our work is related to, but differs from, all of the above. As we explain in Section 3, we share the goal of transfer learning: we have access to enough data for training, but from a family of distributions that differ from the validation distribution (from which we have only enough data to validate). Under a model of covariate shift due to unobserved variables, we show that a natural goal is to find an optimal reweighting of populations rather than of data points. We use optimistic tree search to address precisely this problem, which, as far as we know, has not been done before.

A key part of our work is operating under a computational budget and designing an optimistic tree-search algorithm under uncertainty. We use a single SGD iteration as the unit of currency for our budget, i.e., our computational budget is expressed as the total number of SGD steps that our algorithm computes. Enabling MCTS requires a careful understanding of SGD dynamics and of the error bounds under early stopping. There have been important SGD results studying early stopping, e.g., hardt2015train ; bottou2018optimization , and more generally results studying error rates for various versions of SGD and recentered SGD sgdHogwild ; defazio2014saga ; roux2012stochastic . Our work requires a new high-probability bound, which we obtain in the Supplemental material, Section B. In sgdHogwild , the authors argue that a uniform norm bound on the stochastic gradients is not the best assumption; however, the results in that paper are in expectation. In this paper, we derive our high-probability SGD bounds under the mild assumption that the stochastic gradient norms are bounded only at the optimal weight.

Several papers harvey2018tight ; rakhlin2012making derive high-probability bounds on the suffix-averaged and final iterates returned by SGD for non-smooth strongly convex functions. However, both papers assume uniform bounds on the stochastic gradient. Although these papers do not directly report a dependence on the diameter of the space, since they both consider projected gradient descent, one could easily translate their constant dependence into the sum of a diameter-dependent term and a stochastic noise term (by using the bounded gradient assumption from sgdHogwild , for example). However, the set onto which the algorithm would project is unknown to our algorithm (i.e., it would require knowing the optimum), so we cannot use projected gradient descent in our analysis. As we will see in later sections, we need a high-probability SGD guarantee that characterizes the dependence on the diameter of the space and on the noise of the stochastic gradient. It is not immediately clear how the analysis in harvey2018tight ; rakhlin2012making could be extended to this setting under the bounded gradient assumption of sgdHogwild . In Section 5, we instead develop the high-probability bounds that are needed in our setting.

Optimistic tree search is the final important ingredient of our algorithm. These ideas have been used in a number of settings bubeck2011x ; grill2015black . Most relevant to us is a recent extension of these ideas to a setting with biased search sen2018multi ; sen2019noisy .

3 Problem Setting and Model

Covariate shift is a fundamental problem with diverse origins. We consider the setting where we have multiple populations from which we can learn, and a population to which we wish to generalize but for which we have only limited data, enough to validate. Our point of departure from prior work is our model for covariate shift. Specifically, we consider the setting where the cause of covariate shift is the presence of unobserved variables in the training and validation populations. As an example (see also the Introduction), suppose that a firm has insurance purchase data for several states (each with observed and unobserved features) and wishes to move into a new market (e.g., a new state). Small-scale market studies in the new state provide validation data for a model, but the dataset is not large enough to actually train a model. With this motivation, the formal setting for our problem and its relationship to covariate shift are as follows.

Data Model: We consider a setting where we are given $K$ datasets (aka populations) $\mathcal{D}_1, \dots, \mathcal{D}_K$ for training and a (smaller) test dataset $\mathcal{D}^{(te)}$ for validation, each consisting of samples $(x, y)$, where $x$ is the observed feature and $y$ the corresponding label. Traditionally, we would regard dataset $i$ as governed by the distribution $p_i(x, y)$. However, we consider the setting where these samples are projections of samples $(x, z, y)$ from corresponding datasets, where $z$ is the unobserved feature that, together with the observed feature $x$, influences the label $y$; the corresponding distribution describing the data is $p_i(x, z, y)$, and the conditional distribution of the label is $p_i(y \mid x, z)$.

In this paper, we are motivated by problems where the underlying conditional distribution $p(y \mid x, z)$ (which we cannot directly observe or estimate) is invariant across the training and validation datasets. We refer to this property as the conditional shift invariance property. Note, however, that other distributions, such as the marginal distribution of the observed features (aka the observed covariate distribution) or the relation between the observed and unobserved features, can vary (aka shift) across the training and validation sets.

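Our goal in this paper is to train a learning algorithm whose parameters are encoded as a high-dimensional vector $w \in \mathbb{R}^d$. Training is performed by minimizing a loss function $f(w; \xi)$, where $\xi = (x, y)$ is drawn from a mixture distribution over the training datasets. Specifically, let $\Delta = \{\alpha \in \mathbb{R}^K : \alpha_i \ge 0, \ \sum_{i=1}^{K} \alpha_i = 1\}$ be the probability simplex over the $K$ training datasets. For any $\alpha \in \Delta$, let $p^{(\alpha)} = \sum_{i=1}^{K} \alpha_i p_i$ be the mixture distribution with $p_1, \dots, p_K$ as components and $\alpha_i$ being the weight of $p_i$.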
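Sampling from a mixture $p^{(\alpha)}$ amounts to first picking a dataset index $i$ with probability $\alpha_i$ and then drawing a sample from that dataset. A minimal Python sketch, assuming each training dataset is held in memory as a list whose empirical distribution stands in for $p_i$ (the function name `sample_from_mixture` is ours, not from the paper's code):

```python
import random
from typing import Sequence, TypeVar

T = TypeVar("T")

def sample_from_mixture(datasets: Sequence[Sequence[T]], alpha: Sequence[float],
                        n: int, rng: random.Random) -> list[T]:
    """Draw n i.i.d. samples from p^(alpha) = sum_i alpha_i * p_i, where p_i is the
    empirical distribution of datasets[i]."""
    if any(a < 0 for a in alpha) or abs(sum(alpha) - 1.0) > 1e-9:
        raise ValueError("alpha must lie on the probability simplex")
    # Step 1: choose a component (dataset) for every draw, with probability alpha_i.
    components = rng.choices(range(len(datasets)), weights=alpha, k=n)
    # Step 2: draw one sample uniformly at random from each chosen dataset.
    return [rng.choice(datasets[i]) for i in components]

# Usage with two toy "populations" whose samples are tagged by their source.
florida = [("FL", j) for j in range(100)]
connecticut = [("CT", j) for j in range(30)]
draws = sample_from_mixture([florida, connecticut], [0.25, 0.75], 10_000, random.Random(1))
ct_fraction = sum(src == "CT" for src, _ in draws) / len(draws)
print(f"fraction of draws from CT: {ct_fraction:.3f}")  # close to 0.75
```

The mixture weights control how often each population is visited, not which individual samples are used, which is the distinction drawn in the Introduction.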

Why search over mixture distributions? Our approach of searching over mixture distributions over the training datasets (parameterized by the simplex $\Delta$) is motivated by the following. Suppose that the true joint distribution over the validation dataset, $p^{(te)}(x, z, y)$, lies in the convex hull of the corresponding distributions $p_1(x, z, y), \dots, p_K(x, z, y)$ over the training datasets. Then, from the conditional shift invariance property, it immediately follows¹ that $p^{(te)}(y \mid x) = p^{(\alpha^*)}(y \mid x)$ for some $\alpha^* \in \Delta$.

¹ Note that we cannot learn $\alpha^*$ simply by matching the mixture distribution over the training datasets to that of the validation set (both containing only the observed features and labels). This is because each $p_i(x, y)$ decomposes as $p_i(x) \int p(y \mid x, z)\, p_i(z \mid x)\, dz$, where $p_i(z \mid x)$ is unknown and potentially differs across datasets. Thus, in a setting with unobservable features, approaches that try to directly learn the mixture weights by comparing with the test dataset (e.g., using an MMD distance or moment matching) learn the wrong mixture weights.

Further, let $w^{(te)}$ parameterize the model that minimizes the test cross-entropy (CE) loss

$\mathbb{E}_{(x, y) \sim p^{(te)}}\left[-\log p_w(y \mid x)\right]$

between $p^{(te)}(y \mid x)$ and the corresponding model-induced label distribution $p_w(y \mid x)$. Further, let $p^{(\alpha)}(y \mid x)$ be the conditional distribution over labels induced by the mixture distribution $p^{(\alpha)}$ over the training datasets. The following proposition can be shown to hold using the conditional shift invariance property.

Proposition 1.

Thus, if the model class is rich enough to contain a global minimizer of the cross-entropy loss with respect to the validation distribution (i.e., one achieving zero KL divergence), searching over mixture distributions over the training sets recovers the true model.

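Loss Function Model: Let $F^{(\alpha)}(w) = \mathbb{E}_{\xi \sim p^{(\alpha)}}[f(w; \xi)]$ be the averaged loss function w.r.t. the distribution $p^{(\alpha)}$. In the course of this paper, we drop $\alpha$ from $F^{(\alpha)}$ whenever it is clear from context. Let $w^*(\alpha)$ denote the optimal choice of $w$ w.r.t. the distribution $p^{(\alpha)}$, i.e.,

$w^*(\alpha) = \arg\min_{w} F^{(\alpha)}(w)$.

Let $G(w)$ be the mean test error of the weight $w$. Finally, we define the test loss surface as the function $\alpha \mapsto G(w^*(\alpha))$ of the mixture weights.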

Assumptions on the Loss Function: For our theoretical results, we make the following standard assumptions on the loss function.

Assumption 1.

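We make the following assumptions, similar to those in sgdHogwild :

  • $f(w; \xi)$ is $L$-smooth w.r.t. $w$, i.e., $\|\nabla f(w; \xi) - \nabla f(w'; \xi)\| \le L \|w - w'\|$ for all $w, w'$ and $\xi$.

  • $f(w; \xi)$ is convex w.r.t. $w$, i.e., $f(w; \xi) - f(w'; \xi) \ge \langle \nabla f(w'; \xi), w - w' \rangle$ for all $w, w'$ and $\xi$.

  • $F^{(\alpha)}(w)$ is $\mu$-strongly convex w.r.t. $w$, i.e., $F^{(\alpha)}(w) - F^{(\alpha)}(w') \ge \langle \nabla F^{(\alpha)}(w'), w - w' \rangle + \frac{\mu}{2} \|w - w'\|^2$ for all $w, w'$ and $\alpha \in \Delta$.

  • $G(w)$ is $L_G$-Lipschitz, i.e., $|G(w) - G(w')| \le L_G \|w - w'\|$, for all $w$ and $w'$.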

Note that some of our theoretical results make more specialized assumptions on the loss function in addition to Assumption 1.

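We additionally assume the following bound on the gradient of $f$ at the optimum $w^*$ along every sample path:

Assumption 2.

There exists a constant $\mathcal{N} > 0$ such that $\|\nabla f(w^*; \xi)\|^2 \le \mathcal{N}$ for every sample $\xi$.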

We note that this assumption is weaker than the typical universal bound assumed in rakhlin2012making ; harvey2018tight , and is taken from sgdHogwild .

Evaluation and Feedback Model: We assume that we are given a budget on the total number of stochastic gradient descent (SGD) steps we may perform using the loss function $f(w; \xi)$, where $\xi$ is drawn from some mixture distribution over the training datasets. We are free to choose any mixture distribution $p^{(\alpha)}$ and to perform $t$ SGD steps using samples drawn from $p^{(\alpha)}$, starting at an initial $w_0$ (see Algorithm 2 for a description of the SGD algorithm). The resulting weight vector $w_t$ can then be evaluated on the validation set. We assume oracle access for the evaluation, i.e., given a weight vector $w$, we directly observe the exact mean test error $G(w)$. We can perform as many such evaluations, each with a different $\alpha$ and a different number of SGD iterations, as the cumulative budget allows.
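The evaluation model can be illustrated with a toy one-parameter regression. In this sketch (an assumption for illustration, not the paper's setup), dataset A follows $y = x$ and dataset B follows $y = 3x$; we run a fixed number of SGD steps on the mixture $p^{(\alpha)}$ from $w_0 = 0$ and then query the exact validation error, mimicking the oracle. The helper names are hypothetical.

```python
import random

def sgd_on_mixture(datasets, alpha, w0, steps, lr, rng):
    """Run `steps` SGD iterations on the squared loss 0.5*(w*x - y)**2, starting from w0,
    drawing each sample (x, y) from the mixture p^(alpha) of the training datasets."""
    w = w0
    for _ in range(steps):
        i = rng.choices(range(len(datasets)), weights=alpha)[0]  # dataset i w.p. alpha_i
        x, y = rng.choice(datasets[i])                           # one of its samples
        w -= lr * (w * x - y) * x                                # stochastic gradient step
    return w

def validation_error(w, validation):
    """Oracle evaluation: exact mean squared error of w on the whole validation set."""
    return sum(0.5 * (w * x - y) ** 2 for x, y in validation) / len(validation)

rng = random.Random(0)

def line(slope, n):
    """n noise-free samples (x, slope * x) with x ~ Uniform[-1, 1]."""
    return [(x, slope * x) for x in (rng.uniform(-1, 1) for _ in range(n))]

datasets = [line(1.0, 500), line(3.0, 500)]      # toy populations A (y = x) and B (y = 3x)
validation = line(1.0, 200) + line(3.0, 200)     # equal parts A and B

results = {}
for alpha in [(1.0, 0.0), (0.5, 0.5), (0.0, 1.0)]:
    w = sgd_on_mixture(datasets, alpha, w0=0.0, steps=2000, lr=0.05, rng=rng)
    results[alpha] = (w, validation_error(w, validation))
    print(alpha, f"w = {w:.3f}", f"validation error = {results[alpha][1]:.4f}")
```

Each single-source mixture converges to its own slope, while the balanced mixture, which matches the validation composition, gives the lowest validation error.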

Remark 1 (Oracle Access).

Note that the number of evaluations performed by our algorithm depends on the properties of the loss surface over the simplex of mixture weights. We assume that $K$ (the number of mixtures) is much smaller than the dimension of $w$. Due to this dimensionality reduction, it is justified to assume that the validation set is big enough to give an accurate estimate of the test error for the small number of evaluations we perform, but not big enough to directly train the learning algorithm and find an optimal $w$.

Hierarchical Partitions over the Simplex: We assume that one can construct a hierarchical partition over the simplex of mixture weights $\Delta$ in the form of a binary tree (potentially of infinite depth). Similar assumptions have been made in a long line of tree-search based approaches for black-box optimization munos2011optimistic ; bubeck2011x ; grill2015black ; sen2018multi ; sen2019noisy ; shang2017adaptive . The hierarchical partition $\mathcal{P} = \{\mathcal{P}_{h,i}\}$ is composed of cells $\mathcal{P}_{h,i}$ which partition the domain into nested subsets. Here, $h \ge 0$ is a depth/height parameter, while $i$ is an index. For any depth $h$, the cells $\{\mathcal{P}_{h,i}\}_{1 \le i \le 2^h}$ form a partition of the space $\Delta$. We use the notation $(h, i)$ as the index of a cell/node, but we overload the notation to also indicate the region (subset of $\Delta$) itself. To initialize the tree at depth $0$, we place a single node whose domain is the entire simplex, i.e., $\mathcal{P}_{0,1} = \Delta$. A cell $(h, i)$ is partitioned into a pair of nodes (the child nodes) at depth level $h + 1$, whose indices are $(h+1, 2i-1)$ and $(h+1, 2i)$, respectively (note that this is for ease of notation; we can implement other non-binary splits as well, as we do in our empirical evaluations). A cell $(h, i)$ is evaluated when: (i) a fixed mixture weight $\alpha_{h,i} \in (h, i)$ is chosen and some specified number of SGD steps, starting from a specified initial $w_0$ and using samples from the distribution $p^{(\alpha_{h,i})}$, are performed to obtain $w_{h,i}$; and (ii) $w_{h,i}$ is finally evaluated on the validation set, returning $G(w_{h,i})$. We also define $w^*_{h,i} = w^*(\alpha_{h,i})$ to be the optimal learning parameter for the weights representing the cell $(h, i)$.
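One concrete way to build a binary hierarchical partition of the simplex is longest-edge bisection: each cell is itself a simplex given by its vertices, and a cell is split at the midpoint of its longest edge (measured here in $\ell_1$ distance). This is a stand-in chosen for illustration; the experiments in Section 6 use a random coordinate-halving scheme and a Delaunay scheme whose exact details are not reproduced here. The cell centroid can serve as the representative mixture $\alpha_{h,i}$, and all helper names are hypothetical.

```python
from itertools import combinations

Point = tuple[float, ...]

def l1(a: Point, b: Point) -> float:
    return sum(abs(x - y) for x, y in zip(a, b))

def split_cell(cell: list[Point]) -> tuple[list[Point], list[Point]]:
    """Bisect a simplex cell (list of vertices) at the midpoint of its longest l1 edge."""
    i, j = max(combinations(range(len(cell)), 2), key=lambda e: l1(cell[e[0]], cell[e[1]]))
    mid = tuple((a + b) / 2 for a, b in zip(cell[i], cell[j]))
    left = [mid if k == j else v for k, v in enumerate(cell)]   # keeps vertex i
    right = [mid if k == i else v for k, v in enumerate(cell)]  # keeps vertex j
    return left, right

def centroid(cell: list[Point]) -> Point:
    """Representative mixture weights alpha_{h,i} for the cell."""
    return tuple(sum(coords) / len(cell) for coords in zip(*cell))

def diameter(cell: list[Point]) -> float:
    return max(l1(a, b) for a, b in combinations(cell, 2))

# Root cell: the whole simplex for K = 3, i.e., the three unit vectors.
root = [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)]
cell, diams = root, [diameter(root)]
for _depth in range(6):
    cell, _ = split_cell(cell)          # follow the leftmost path down the tree
    diams.append(diameter(cell))
print(diams)                            # [2.0, 2.0, 1.0, 1.0, 0.5, 0.5, 0.25]
print(centroid(cell))
```

On this path the $\ell_1$ diameter halves every two levels, i.e., cell sizes shrink geometrically with depth, and every vertex (hence every centroid) stays on the simplex.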

We require the following joint condition on the test loss surface (defined above) and the hierarchical partitions.

Condition 1.

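There exist $\nu > 0$ and $\rho \in (0, 1)$ such that for all cells $(h, i)$, we have $\|w^*(\alpha) - w^*(\alpha')\| \le \nu \rho^h$ for any $\alpha, \alpha' \in (h, i)$. Due to smoothness assumptions, this further implies that $|G(w^*(\alpha)) - G(w^*(\alpha'))| \le \nu' \rho^h$, where $\nu' = L_G \nu$.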

Note that weaker versions of Condition 1 are required by tree-search based black-box optimization methods grill2015black ; sen2018multi ; shang2017adaptive . However, our guarantees combine an SGD analysis over the space of model parameters with black-box optimization over the simplex of mixture weights, and therefore require the above condition. Finally, Theorem 1 provides a provable justification for this condition.

We also require the standard near-optimality dimension definition from the tree-search literature grill2015black ; sen2018multi .

Definition 1.

The near-optimality dimension of with respect to parameters is given by,


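where $N_h(\epsilon)$ is the number of cells $(h, i)$ at depth $h$ such that $G(w^*_{h,i}) \le G(w^*(\alpha^*)) + \epsilon$.

Here, we assume that there is a unique optimal $\alpha^*$ such that $\alpha^* = \arg\min_{\alpha \in \Delta} G(w^*(\alpha))$. The lower the near-optimality dimension, the easier the black-box optimization problem grill2015black .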

4 Algorithm

In this section, we present our main algorithm, Mix&Match (Algorithm 1). Our algorithm progressively descends the hierarchical partition tree to select promising mixture distribution nodes over the simplex. During each tree iteration, a leaf node $(h, i)$ is selected according to the optimistic criterion in Algorithm 1. Then its two children $(h+1, 2i-1)$ and $(h+1, 2i)$ are evaluated using Algorithm 2. The evaluation of a node involves training the learning weight $w$ using SGD with samples drawn from the node's mixture distribution. The algorithm allocates a budget of $\lambda(h)$ SGD steps for expanding any node at height $h$, where $\lambda(h)$ is a carefully chosen increasing function motivated by the theory in Corollary 1. Thus the tree search uses fewer SGD iterations at shallow depths and more budget on nodes that are deeper in the tree, where more accuracy is required. Moreover, a crucial feature of the algorithm is that the starting iterate used to evaluate a child node is set to the ending iterate $w_{h,i}$ of its parent. We show in Theorem 2 and Corollary 1 that this model re-use leads to a more judicious use of the total SGD iteration budget.

1:Real numbers , , hierarchical partition , SGD step budget
2:Expand the root node using Algorithm 2 and form two leaf nodes .
3:Cost (Number of SGD steps used):
4:while  do
5:     Select the leaf with minimum .
6:     Expand this node; add to the children of by querying them using Algorithm 2.
7:     .
8:end while
9:Let be the height of
10:Let . Return and .
Algorithm 1 Mix&Match : Tree-Search over the mixtures of training datasets
1:Node with end iterate , ,
2:for  do Iterate over new child node indices
3:     Let and .
4:     for  (see Corollary 1) do
5:          for .
6:     end for
7:     Obtain test error and set
8:end for
Algorithm 2 ExpandNode : Optimize over the current mixture and evaluate
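To make the interplay of optimistic selection, the per-height budget, and model re-use concrete, here is a compact Python sketch of the tree search for $K = 2$ on a toy regression problem (populations with slopes 1 and 3). Several pieces are assumptions, since the exact expressions are not reproduced in the listings above: cells are halves of the interval of $\alpha_1$, the representative mixture is the cell midpoint, the optimistic score is taken to be the observed validation error minus $2\nu\rho^h$, the per-node step count is constant, and the final answer is the best leaf at the deepest level. It illustrates the idea and is not the authors' implementation.

```python
import math
import random
from dataclasses import dataclass

@dataclass(eq=False)
class Node:
    """Cell {alpha_1 in [lo, hi]} of the K = 2 simplex (alpha_2 = 1 - alpha_1)."""
    lo: float
    hi: float
    height: int
    w: float = 0.0             # SGD end iterate for this node
    value: float = math.inf    # validation error G(w) of that iterate

    @property
    def alpha(self):
        a1 = (self.lo + self.hi) / 2       # representative mixture: cell midpoint
        return (a1, 1.0 - a1)

def run_sgd(datasets, alpha, w0, steps, lr, rng):
    """Plain SGD on 0.5*(w*x - y)**2 with samples drawn from the mixture p^(alpha)."""
    w = w0
    for _ in range(steps):
        i = rng.choices(range(len(datasets)), weights=alpha)[0]
        x, y = rng.choice(datasets[i])
        w -= lr * (w * x - y) * x
    return w

def val_error(w, validation):
    return sum(0.5 * (w * x - y) ** 2 for x, y in validation) / len(validation)

def mix_and_match(datasets, validation, budget, nu=1.0, rho=0.5,
                  steps_at=lambda h: 400, lr=0.02, seed=0):
    rng = random.Random(seed)
    used = 0

    def expand(node):
        nonlocal used
        mid = (node.lo + node.hi) / 2
        children = [Node(node.lo, mid, node.height + 1), Node(mid, node.hi, node.height + 1)]
        for child in children:
            steps = steps_at(child.height)
            # Model re-use: continue from the parent's end iterate instead of from scratch.
            child.w = run_sgd(datasets, child.alpha, node.w, steps, lr, rng)
            child.value = val_error(child.w, validation)
            used += steps
        return children

    leaves = expand(Node(0.0, 1.0, 0))      # the root starts from w = 0
    while used < budget:
        # Assumed optimistic score: observed error minus a width that shrinks with depth.
        best = min(leaves, key=lambda n: n.value - 2 * nu * rho ** n.height)
        leaves.remove(best)
        leaves.extend(expand(best))
    deepest = max(n.height for n in leaves)
    final = min((n for n in leaves if n.height == deepest), key=lambda n: n.value)
    return final.alpha, final.w

data_rng = random.Random(42)

def line(slope, n):
    """n noise-free samples (x, slope * x) with x ~ Uniform[-1, 1]."""
    return [(x, slope * x) for x in (data_rng.uniform(-1, 1) for _ in range(n))]

datasets = [line(1.0, 500), line(3.0, 500)]
validation = line(1.0, 100) + line(3.0, 300)    # composition (0.25, 0.75)
alpha, w = mix_and_match(datasets, validation, budget=32_000)
print(f"returned alpha = ({alpha[0]:.3f}, {alpha[1]:.3f}), w = {w:.3f}")
```

Because each child starts from its parent's end iterate, a node at depth $h$ has effectively been trained for about $h$ times the per-node step count, even though each expansion pays only a constant amount.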

5 Theoretical Results

In this section, we present our main theoretical results. Our first result shows that the optimal weights with respect to the two distributions $p^{(\alpha)}$ and $p^{(\alpha')}$ are close if the mixture weights $\alpha$ and $\alpha'$ are close. This justifies Condition 1.

Theorem 1.

If the loss function $f(w; \xi)$ is Lipschitz, strongly convex, and smooth in terms of $w$, for all $\xi$, then we have

The above theorem holds under stronger conditions than Assumption 1, and in fact follows assumptions and proof techniques similar to those of Theorem 3.9 in hardt2015train . Theorem 1 implies that if the partitions are such that, for any cell $(h, i)$ at height $h$, $\|\alpha - \alpha'\|_1 \le c \rho^h$ for all $\alpha, \alpha' \in (h, i)$, where $c > 0$ and $\rho \in (0, 1)$, then $\|w^*(\alpha) - w^*(\alpha')\| \le \nu \rho^h$ for some $\nu > 0$. It is reasonable to assume that the size of the hierarchical partitions over the simplex decreases geometrically in terms of $\ell_1$-norm diameter as we go down the tree.
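A toy case in which $w^*(\alpha)$ is available in closed form illustrates the kind of continuity that Theorem 1 asserts. If each population's loss is the quadratic $\frac{1}{2}\|w - c_i\|^2$ (an assumption for illustration), the mixture optimum is $w^*(\alpha) = \sum_i \alpha_i c_i$, so $\|w^*(\alpha) - w^*(\alpha')\| \le \max_i \|c_i\| \cdot \|\alpha - \alpha'\|_1$ by the triangle inequality. The sketch below checks this bound numerically on random pairs of mixtures; the constant here is specific to the toy and is not the constant of Theorem 1.

```python
import math
import random

def mixture_optimum(centers, alpha):
    """Minimizer of sum_i alpha_i * 0.5 * ||w - c_i||^2, i.e. the alpha-weighted mean of the centers."""
    return [sum(a * c[d] for a, c in zip(alpha, centers)) for d in range(len(centers[0]))]

def random_simplex_point(k, rng):
    """Uniform sample from the probability simplex via normalized Exp(1) draws."""
    e = [rng.expovariate(1.0) for _ in range(k)]
    total = sum(e)
    return [v / total for v in e]

rng = random.Random(0)
centers = [[rng.uniform(-5, 5) for _ in range(3)] for _ in range(4)]  # K = 4 toy populations in R^3
radius = max(math.hypot(*c) for c in centers)                        # max_i ||c_i||

worst_ratio = 0.0
for _ in range(1000):
    a, b = random_simplex_point(4, rng), random_simplex_point(4, rng)
    gap_w = math.dist(mixture_optimum(centers, a), mixture_optimum(centers, b))
    gap_alpha = sum(abs(x - y) for x, y in zip(a, b))
    worst_ratio = max(worst_ratio, gap_w / (radius * gap_alpha))
print(f"largest observed ||w*(a) - w*(b)|| / (R * ||a - b||_1): {worst_ratio:.3f}")
```

Since the optimum moves at most linearly in $\|\alpha - \alpha'\|_1$, geometrically shrinking cells give geometrically shrinking distances between optimal models, as in Condition 1.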

We next seek to understand how to allocate our SGD budget as we explore new nodes in the search tree. To begin, let us consider how the solution to the optimization problem on $F^{(\alpha)}$ at some node with mixture $\alpha$ could be used to solve the optimization problem at one of its children with mixture $\alpha'$. Under Condition 1, and assuming the distance between any $\alpha$ at node height $h$ and the corresponding child mixture $\alpha'$ decays geometrically as a function of $h$, the distance between the corresponding optimal models $w^*(\alpha)$ and $w^*(\alpha')$ also decays geometrically as a function of height in the tree. This leads us to hope that, if we obtain a good enough estimate for the problem at the parent node and use that final iterate as the starting point for the optimization problem at the child node, we might only have to pay a constant number of SGD steps to find a solution sufficiently close to $w^*(\alpha')$, instead of a geometrically increasing (with tree height) number of SGD steps.
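The warm-start intuition can be checked with a noise-free toy calculation. Assume gradient descent on $F(w) = \frac{\mu}{2}(w - w^*)^2$, a parent iterate that sits exactly at the parent optimum, a child optimum $\nu\rho^h$ away from it, and a required accuracy of $0.1\rho^h$ at height $h$ (all hypothetical choices). Warm-starting from the parent then needs a constant number of steps, while a cold start from $w = 0$ needs more steps as $h$ grows. Stochastic gradient noise, which this toy omits, is what makes the real budget schedule height-dependent.

```python
def gd_steps_to_accuracy(w0, w_star, mu, lr, eps):
    """Count gradient steps on F(w) = 0.5*mu*(w - w_star)**2 until |w - w_star| <= eps."""
    w, steps = w0, 0
    while abs(w - w_star) > eps:
        w -= lr * mu * (w - w_star)     # each step shrinks the error by a factor (1 - lr*mu)
        steps += 1
    return steps

mu, lr, nu, rho = 1.0, 0.5, 1.0, 0.5    # hypothetical constants
w_parent = 3.0                          # parent's end iterate, assumed to equal its optimum
warm, cold = [], []
for h in range(1, 9):
    w_child = w_parent + nu * rho ** h  # child optimum, nu * rho^h away (Condition 1-style)
    eps = 0.1 * rho ** h                # accuracy needed at height h
    warm.append(gd_steps_to_accuracy(w_parent, w_child, mu, lr, eps))
    cold.append(gd_steps_to_accuracy(0.0, w_child, mu, lr, eps))
print("warm start:", warm)  # [4, 4, 4, 4, 4, 4, 4, 4]
print("cold start:", cold)  # [7, 8, 8, 9, 10, 11, 12, 13]
```

The warm-start cost is constant because both the initial distance and the target accuracy scale as $\rho^h$, so their ratio does not depend on the height.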

To further investigate this intuition and develop a budget allocation strategy, we need to understand how the error of the final SGD iterate scales with the noise of the stochastic gradient and with the initial distance from the optimum. We thus derive a high-probability bound on the error of the last iterate of SGD that captures the dependence on the diameter of the space and the noise of the stochastic gradient. This result, and in particular its dependence on the initial distance from the optimal solution, is crucial in allowing Mix&Match to carefully allocate its total SGD budget as we roll out the search tree.

Theorem 2.

Consider a sequence of samples drawn from a distribution . Let . We are interested in analyzing the iterates of SGD starting at an initial point with and evolving as . Here, is the optimum of . If and satisfy Assumptions 1 and 2 with for all and taking , we have

provided we have,

Here, we have the diameter-dependent term

and the noise-dependent term

provided . The various parameters used are: ; ; and , and is a large enough constant which depends on the noise floor and a uniform bound on the distance between each iterate and the optimum.

Theorem 2 is a general high-probability bound on SGD iterates that does not assume a universal (i.e., over all $w$) global bound on the stochastic gradient noise, as is usually done in the literature bottou2018optimization ; bubeck15 ; duchi2010composite . The concentration results in Theorem 2 hold under assumptions similar to those of the recent work sgdHogwild , but that work obtains only bounds in expectation. The bound precisely captures the dependence on the initial diameter and on the noise floor. This dependence is key to the budget allocation in the next corollary. The proof of Theorem 2 is given in Appendix B. We use Theorem 2 to understand how to allocate the SGD budget at a particular node.

Corollary 1.

Consider a tree node with mixture weights and optimal learning parameter . Assuming we start at an initial point such that and take SGD steps using the child node distribution where,


with , for in the expressions in Theorem 2, then w.p. at least we have

The above corollary is a direct consequence of Theorem 2 and the fact that $\|w_0 - w^*(\alpha')\| \le \|w_0 - w^*(\alpha)\| + \|w^*(\alpha) - w^*(\alpha')\|$ by the triangle inequality, where $\alpha'$ denotes the child's mixture weights. Thus, if we start from a good ending weight of a parent node at height and proceed with SGD using the child's distribution (re-use of the parent model), we need iterations to converge to an accuracy of the order of , which is what is needed for our tree search to succeed. Thus, when the tree is not too deep, i.e., the dominates in the above expression, we only need for a node expansion at height , thus being parsimonious with our SGD iteration budget.

Returning to the discussion preceding Theorem 2, we note that the expression given in Corollary 1 does not yield a height-independent budget schedule, because of terms that arise from the noise of the stochastic gradients. Indeed, if only the first term, with its squared-diameter dependence, were present in the numerator, then the schedule would be independent of height. However, the noise of the stochastic gradient introduces the second term in the numerator as well as an additional noise-related term.

We are now in a position to present our final bound, Theorem 3, which characterizes the performance of Algorithm 1. In the deterministic black-box optimization literature munos2011optimistic ; sen2018multi , the quantity of interest is generally the simple regret, which is the difference in function value between the point returned by an algorithm (given a budget) and the optimal point in the search space. Theorem 3 provides a similar simple regret bound on $G(w^*(\hat{\alpha})) - G(w^*(\alpha^*))$, where $\hat{\alpha}$ is the mixture weight vector returned by the algorithm given a total budget of SGD steps and $\alpha^*$ is the optimal mixture.

Theorem 3.

Let be the smallest number such that

With probability at least , the tree in Algorithm 1 grows to a height of at least and returns a mixture weight s.t,


Theorem 3 shows that given a total budget of SGD steps, the tree-search recovers a mixture , such that we are only away from the optimal test error, if we perform optimization using that mixture. The parameter depends on the number of steps needed for a node expansion at different heights and crucially makes use of the fact that the starting iterate for each new node can be borrowed from the parent’s last iterate. The tree search also progressively allocates more samples to deeper nodes, as we get closer to the optimum. Similar simple regret scalings have been recently shown in the context of deterministic multi-fidelity black-box optimization sen2018multi .

6 Empirical Results

We evaluate the effectiveness of Algorithm 1 on three real-world datasets. All experiments were run in python:3.7.3 Docker containers (see https://hub.docker.com/_/python) managed by Google Kubernetes Engine running on Google Cloud Platform on n1-standard-4 instances. The code used to create the testing infrastructure can be found at https://github.com/matthewfaw/mixnmatch-infrastructure, and the code used to run experiments can be found at https://github.com/matthewfaw/mixnmatch.

For the simulations considered below, we divide the data into training, validation, and testing datasets. Hyperparameter tuning is performed using the Katib framework, with the validation error as the objective, and the results reported in the figures below are test errors. In all figures, each data point is the average of 10 experiments, and the error bars show one standard deviation. Error bars are displayed for all experiments, but some are too small to see in the plots.

6.1 Allstate Purchase Prediction Challenge


Figure 1: Test accuracy of Allstate experiment with mixture of two states


Figure 2: Test accuracy of Allstate experiment with mixture of three states
State Total Size % Train % Validate % Test % Discarded
FL 14605 100 0 0 0
CT 2836 70 7.5 22.5 0
OH 6664 0 0 2.25 97.75
Table 1: The proportions of data from each state used in training, validation, and testing for Figure 1
State Total Size % Train % Validate % Test % Discarded
FL 14605 99.34 0.16 0.5 0
CT 2836 70 7.5 22.5 0
OH 6664 2.25 0.75 2.25 94.75
Table 2: The proportions of data from each state used in training, validation, and testing for Figure 2

We begin our experimental evaluation by assessing the efficacy of Algorithm 1 under a covariate shift between the training and validation/test data. We consider an Allstate insurance dataset kaggle with entries from customers across different states in the US. The dataset provides features about the customers (e.g., car age, state of residence, whether the customer is a homeowner) and the insurance plan the customer decides to purchase. This insurance plan is divided into 7 coverage options, and each option has 2, 3, or 4 possible ordinal values. For simplicity, we consider predicting only one of these coverage options (given as G in the dataset, taking four ordinal values). Note that the original Kaggle dataset is provided in a time-series format, with entries corresponding to customers' intermediate coverage plan selections as well as their final selections. We collapse all entries corresponding to a customer into a single entry and add a few summary statistics about the intermediate selections to each collapsed entry.
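The collapsing step can be sketched in plain Python. The record layout (keys `customer_id`, `is_purchase`, `G`, plus static features) and the particular summary statistics below are hypothetical; the text does not list which statistics were added. Records for each customer are assumed to be in chronological order, with exactly one purchase record.

```python
from collections import defaultdict

def collapse_customers(rows):
    """Collapse each customer's quote history into one row.

    `rows` are dicts in chronological order with hypothetical keys 'customer_id',
    'is_purchase' (True for the final purchase record), 'G' (the coverage option),
    plus any static customer features. The purchase record supplies the features
    and the label; the earlier quotes contribute a few summary statistics."""
    history = defaultdict(list)
    for r in rows:
        history[r["customer_id"]].append(r)
    collapsed = []
    for records in history.values():
        quotes = [r for r in records if not r["is_purchase"]]
        purchase = next(r for r in records if r["is_purchase"])
        row = {k: v for k, v in purchase.items() if k != "is_purchase"}
        row["n_quotes"] = len(quotes)                          # number of intermediate quotes
        row["n_distinct_G_quoted"] = len({q["G"] for q in quotes})
        row["last_quoted_G"] = quotes[-1]["G"] if quotes else None
        collapsed.append(row)
    return collapsed

rows = [
    {"customer_id": 1, "is_purchase": False, "G": 2, "car_age": 5},
    {"customer_id": 1, "is_purchase": False, "G": 3, "car_age": 5},
    {"customer_id": 1, "is_purchase": True, "G": 3, "car_age": 5},
    {"customer_id": 2, "is_purchase": True, "G": 1, "car_age": 9},
]
out = collapse_customers(rows)
for row in out:
    print(row)
```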

We construct two experiments from this dataset. In the first (see Figure 1), we consider only customers from Florida (FL) and Connecticut (CT) as training data. Validation data consists solely of data from CT, and testing data consists mostly (approximately 80%) of data from CT, with approximately 150 additional randomly selected entries from Ohio (OH). In the second (see Figure 2), we again consider customers from FL, CT, and OH, except that now a small number of FL entries appear in the validation and testing sets, small proportions of OH data appear in the validation and testing sets, and a small number of OH entries are added to the training set. The proportions of each state's dataset used in these experiments are given in Tables 1 and 2, respectively.
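The per-state splits in Tables 1 and 2 can be produced by shuffling each state's records and carving out the listed percentages. A minimal sketch; the rounding rule and the function name are assumptions. With the Table 1 percentages, CT yields 1985/213/638 train/validation/test records, and the 2.25% OH test slice has 150 records, consistent with the approximately 150 OH entries mentioned above.

```python
import random

def split_by_proportions(records, pct_train, pct_val, pct_test, rng):
    """Shuffle one state's records and carve out train/validation/test subsets by
    percentage; any remainder is discarded. Counts are rounded to the nearest integer."""
    if min(pct_train, pct_val, pct_test) < 0 or pct_train + pct_val + pct_test > 100 + 1e-9:
        raise ValueError("percentages must be non-negative and sum to at most 100")
    shuffled = list(records)
    rng.shuffle(shuffled)
    n = len(shuffled)
    n_tr, n_va, n_te = (round(n * p / 100) for p in (pct_train, pct_val, pct_test))
    train = shuffled[:n_tr]
    val = shuffled[n_tr:n_tr + n_va]
    test = shuffled[n_tr + n_va:n_tr + n_va + n_te]
    return train, val, test

rng = random.Random(0)
train, val, test = split_by_proportions(range(2836), 70, 7.5, 22.5, rng)  # CT, Table 1
_, _, oh_test = split_by_proportions(range(6664), 0, 0, 2.25, rng)        # OH, Table 1
print(len(train), len(val), len(test), len(oh_test))                     # 1985 213 638 150
```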

In this section, we refer to $\alpha^{(val)}$ as the vector whose $i$th entry is the proportion of the validation set that comes from region $i$'s data.

We first consider the uniform sampling algorithm, which spends its entire SGD budget sampling from each available data source with equal probability, as a simple baseline.

At the other extreme, we consider a Genie algorithm that has access to $\alpha^{(val)}$. This genie algorithm draws the samples for each SGD iteration from a mixture distribution that matches the validation data distribution, and thus provides a best-case scenario to compare against.
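The Genie baseline's mixture is simply the composition of the validation set. As an illustration, assuming the validation counts are obtained by rounding the Table 2 percentages (the exact counts are not reported), the genie weights can be computed as follows; the function name is hypothetical.

```python
def genie_mixture(val_counts: dict[str, int]) -> dict[str, float]:
    """alpha^(val): the fraction of the validation set contributed by each source."""
    total = sum(val_counts.values())
    return {source: count / total for source, count in val_counts.items()}

sizes = {"FL": 14605, "CT": 2836, "OH": 6664}     # Table 2, "Total Size"
val_pct = {"FL": 0.16, "CT": 7.5, "OH": 0.75}      # Table 2, "% Validate"
val_counts = {s: round(sizes[s] * val_pct[s] / 100) for s in sizes}  # assumed rounding
alpha_val = genie_mixture(val_counts)
print(val_counts)                                   # {'FL': 23, 'CT': 213, 'OH': 50}
print({s: round(a, 3) for s, a in alpha_val.items()})
```

Under this assumption, CT dominates the genie mixture, with smaller shares for OH and FL.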

In addition, we consider algorithms which sample exclusively from individual datasets. These are labeled in the figures as OnlyX (where X is replaced by one of the states used in the experiment).

During hyperparameter tuning, we consider several variants of the Mix&Match algorithm. In particular, we consider several simple budget allocation functions: one that allocates the same budget to each node, one that allocates budget as a linear function of node height, and one that allocates budget as the square root of node height. We also consider running Mix&Match with one of the above budget functions for half of the total SGD budget, and then spending the second half of the SGD budget sampling according to the mixture returned by the best node from Mix&Match. Finally, we consider two different simplex partitioning strategies: a Delaunay partitioning strategy, which produces $K$ children at each split node (where $K$ here is the number of states in the training set), and a random coordinate-halving strategy, which produces two children at each split node. For both experiments in this section, Mix&Match uses a constant budget function with the random coordinate-halving partitioning strategy, and splits its budget so that the best mixture returned by the tree search is used for half of the total SGD budget.

We compare the performance of these algorithms using a fully connected neural network with one hidden layer, measured by classification accuracy (the number of correctly classified samples divided by the total number of classified samples). To train this network, we use a batch size of 25 samples to approximate the gradient at each SGD step in the first experiment and a batch size of 50 in the second experiment (both determined by hyperparameter tuning), and Mix&Match runs each evaluated node using a constant budget of 500 samples.

Recall, however, that as the tree grows, the effective number of SGD iterations increases because of model re-use: each child node initializes SGD with the weights of its parent, so the effective number of iterations for the model at a node is linear in the depth of that node in the tree.
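Under model re-use, the SGD steps behind a node's model accumulate along its root-to-node path, which is where the linear dependence on depth comes from. A small worked sketch, assuming every node runs the same number of steps and that the 500-sample node budget is spent in mini-batches of 25 (the first experiment's settings); the helper names are hypothetical:

```python
def effective_sgd_steps(depth, steps_per_node):
    """Total SGD steps behind the model at a node of the given depth (root = 0),
    assuming every node on the root-to-node path ran `steps_per_node` steps,
    each node starting from its parent's final weights."""
    return (depth + 1) * steps_per_node

def steps_from_samples(samples_per_node, batch_size):
    """Convert a per-node sample budget into mini-batch SGD steps
    (assumption: each step consumes exactly `batch_size` samples)."""
    return samples_per_node // batch_size
```

Under these assumptions, steps_from_samples(500, 25) gives 20 steps per node, so the model at a node of depth 4 has effectively been trained for effective_sgd_steps(4, 20) = 100 steps, even though that node itself ran only 20.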

In Figures 1 and 2, we plot the test accuracy of these algorithms over a range of total sample budgets (labeled SGD Iteration budget in the figures). Each data point is averaged over 10 separate runs, and error bars represent one standard deviation.

In both experiments, we observe that Mix&Match outperforms uniform sampling and approaches the performance of the Genie algorithm and the OnlyCT algorithm, despite having access to a large portion of the FL dataset, on which training alone yields poor test accuracy. Additionally, the first experiment shows that Mix&Match performs well relative to the baselines even though the test set contains a small amount of data from a state (OH) that is present in neither the training nor the validation set.

6.2 Wine Ratings


Figure 3: Predicting wine prices in Chile using data from France, Italy, and Spain


Figure 4: Predicting wine prices in Chile using data from US, France, Italy, and Spain
Country Total Size (entries) % Train % Validate % Test
France 17776 100 0 0
Italy 16914 100 0 0
Spain 6573 100 0 0
Chile 4416 0 50 50
Table 3: The proportions of data from each country used in training, validation, and testing for Figure 3
Country Total Size (entries) % Train % Validate % Test
US 54265 100 0 0
France 17776 100 0 0
Italy 16914 100 0 0
Spain 6573 100 0 0
Chile 4416 0 50 50
Table 4: The proportions of data from each country used in training, validation, and testing for Figure 4

In this section, we consider the effectiveness of using Algorithm 1 to make predictions on a new dataset by training on similar data from other regions. For this experiment, we use another Kaggle dataset [2], which provides, for a number of wine-producing countries, binary labels indicating the presence of particular tasting notes of each wine, as well as the wine's point score and price quartile. The objective is to predict the price of the wine, given these labels.

Similarly to [33], we use a four-layer neural network with sigmoid activations in the inner layers, and the mean squared error loss function.
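A forward pass and loss of the kind described can be sketched in NumPy. Assumptions, none of which come from the paper: "four-layer" is read as four weight layers (three sigmoid hidden layers plus a linear output), the layer widths and initialization are arbitrary, and training (for example, the mixture-driven SGD used by Mix&Match) is omitted.

```python
import numpy as np

def init_params(layer_sizes, rng):
    """One (weight, bias) pair per layer; widths are illustrative assumptions."""
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        W = rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_in, n_out))
        params.append((W, np.zeros(n_out)))
    return params

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def forward(params, X):
    """Sigmoid activations on the inner layers, linear output for regression."""
    h = X
    for W, b in params[:-1]:
        h = sigmoid(h @ W + b)
    W, b = params[-1]
    return (h @ W + b).ravel()

def mse_loss(pred, target):
    """Mean squared error between predicted and true prices."""
    return float(np.mean((pred - target) ** 2))
```

For instance, init_params([20, 32, 32, 32, 1], rng) builds four weight layers for 20 binary input features and one scalar price output.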

We use the same algorithms for comparison as in Section 6.1, except that Mix&Match uses the Delaunay partitioning strategy, the entire SGD budget is used for Mix&Match (the budget is not split in half, as was done in Section 6.1), and the Genie algorithm is not run, since the mixture of countries that makes up Chile is not known a priori. As before, OnlyX (where X now represents a country in the training set) is the algorithm that spends its entire SGD budget sampling only from the data of country X.

In both experiments (shown in Figures 3 and 4), we use a batch size of 100 to approximate each stochastic gradient, and Mix&Match runs 10 SGD steps at each evaluated node, independent of the height of the search tree. We plot each result averaged over 10 runs, and show error bars of one standard deviation.
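Each evaluated node therefore performs a fixed number of mini-batch SGD steps, continuing from the weights it inherits, with every example drawn from the node's mixture. A hedged NumPy sketch, assuming a linear least-squares model in place of the neural network, a fixed learning rate, and hypothetical helpers sample_batch and run_node; only the defaults of 10 steps and batch size 100 come from the text:

```python
import numpy as np

def sample_batch(datasets, alpha, batch_size, rng):
    """Draw a mini-batch: pick a source per example from the mixture `alpha`,
    then a uniformly random example from that source."""
    sources = rng.choice(len(datasets), size=batch_size, p=alpha)
    X_rows, y_vals = [], []
    for s in sources:
        X, y = datasets[s]
        i = rng.integers(len(y))
        X_rows.append(X[i])
        y_vals.append(y[i])
    return np.array(X_rows), np.array(y_vals)

def run_node(w_parent, alpha, datasets, rng, steps=10, batch_size=100, lr=0.05):
    """Continue SGD from the parent's weights (model re-use) under mixture `alpha`."""
    w = np.array(w_parent, dtype=float)  # copy, so the parent's weights are untouched
    for _ in range(steps):
        Xb, yb = sample_batch(datasets, alpha, batch_size, rng)
        grad = 2.0 * Xb.T @ (Xb @ w - yb) / batch_size  # gradient of the batch MSE
        w -= lr * grad
    return w
```

Passing a node's returned weights as w_parent for each of its children reproduces the model re-use described in Section 6.1: the parent's weights are copied, never overwritten, so sibling children all start from the same point.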

The objective of the first experiment (Figure 3) is to train a model to predict the price of wine from Chile, given only data from France, Italy, and Spain. The results suggest that Spanish and Chilean wines may share similar qualities: Mix&Match produces a model whose test error is comparable to that of a model trained only on data from Spain, while models trained only on Italian or French wine data perform worse.

The objective of the second experiment (Figure 4) is the same as that of the first, except that data from the US may now also be used. Interestingly, while the Spanish data still appears to be the best single dataset for training a model, adding the US data allows Mix&Match to achieve better average performance than models trained on any of the four datasets individually.

In both experiments, we observe that Mix&Match outperforms uniform sampling as well as the algorithms that sample only from France or Italy, and, in the second experiment, the algorithm that samples only from the US.

References
  • [1] Allstate. Allstate Purchase Prediction Challenge, 2014. https://www.kaggle.com/c/allstate-purchase-prediction-challenge.
  • [2] Dara Bahri. wine ratings, 2018. https://www.kaggle.com/dbahri/wine-ratings.
  • [3] Yoshua Bengio. Deep learning of representations for unsupervised and transfer learning. In Proceedings of the 2011 International Conference on Unsupervised and Transfer Learning workshop-Volume 27, pages 17–37. JMLR. org, 2011.
  • [4] Léon Bottou, Frank E Curtis, and Jorge Nocedal. Optimization methods for large-scale machine learning. SIAM Review, 60(2):223–311, 2018.
  • [5] Sébastien Bubeck. Convex optimization: Algorithms and complexity. Foundations and Trends® in Machine Learning, 8(3-4):231–357, 2015.
  • [6] Sébastien Bubeck, Rémi Munos, Gilles Stoltz, and Csaba Szepesvári. X-armed bandits. Journal of Machine Learning Research, 12(May):1655–1695, 2011.
  • [7] Wenyuan Dai, Ou Jin, Gui-Rong Xue, Qiang Yang, and Yong Yu. Eigentransfer: a unified framework for transfer learning. In Proceedings of the 26th Annual International Conference on Machine Learning, pages 193–200. ACM, 2009.
  • [8] Aaron Defazio, Francis Bach, and Simon Lacoste-Julien. SAGA: A fast incremental gradient method with support for non-strongly convex composite objectives. In Advances in neural information processing systems, pages 1646–1654, 2014.
  • [9] John C Duchi, Shai Shalev-Shwartz, Yoram Singer, and Ambuj Tewari. Composite objective mirror descent. In COLT, pages 14–26, 2010.
  • [10] Arthur Gretton, Alex Smola, Jiayuan Huang, Marcel Schmittfull, Karsten Borgwardt, and Bernhard Schölkopf. Covariate shift by kernel mean matching. Dataset shift in machine learning, 3(4):5, 2009.
  • [11] Jean-Bastien Grill, Michal Valko, and Rémi Munos. Black-box optimization of noisy functions with unknown smoothness. In Advances in Neural Information Processing Systems, pages 667–675, 2015.
  • [12] Moritz Hardt, Benjamin Recht, and Yoram Singer. Train faster, generalize better: Stability of stochastic gradient descent. arXiv preprint arXiv:1509.01240, 2015.
  • [13] Nicholas JA Harvey, Christopher Liaw, Yaniv Plan, and Sikander Randhawa. Tight analyses for non-smooth stochastic gradient descent. arXiv preprint arXiv:1812.05217, 2018.
  • [14] Jiayuan Huang, Arthur Gretton, Karsten Borgwardt, Bernhard Schölkopf, and Alex J Smola. Correcting sample selection bias by unlabeled data. In Advances in neural information processing systems, pages 601–608, 2007.
  • [15] Sham M Kakade and Ambuj Tewari. On the generalization ability of online strongly convex programming algorithms. In Advances in Neural Information Processing Systems, pages 801–808, 2009.
  • [16] Simon Kornblith, Jonathon Shlens, and Quoc V Le. Do better imagenet models transfer better? arXiv preprint arXiv:1805.08974, 2018.
  • [17] Mehryar Mohri, Gary Sivek, and Ananda Theertha Suresh. Agnostic federated learning. arXiv preprint arXiv:1902.00146, 2019.
  • [18] Rémi Munos. Optimistic optimization of a deterministic function without the knowledge of its smoothness. In Advances in neural information processing systems, pages 783–791, 2011.
  • [19] Lam M. Nguyen, Phuong Ha Nguyen, Marten van Dijk, Peter Richtárik, Katya Scheinberg, and Martin Takáč. SGD and Hogwild! Convergence without the bounded gradients assumption. arXiv preprint arXiv:1802.03801, 2018.
  • [20] Maxime Oquab, Leon Bottou, Ivan Laptev, and Josef Sivic. Learning and transferring mid-level image representations using convolutional neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 1717–1724, 2014.
  • [21] Sinno Jialin Pan and Qiang Yang. A survey on transfer learning. IEEE Transactions on knowledge and data engineering, 22(10):1345–1359, 2009.
  • [22] Rajat Raina, Alexis Battle, Honglak Lee, Benjamin Packer, and Andrew Y Ng. Self-taught learning: transfer learning from unlabeled data. In Proceedings of the 24th international conference on Machine learning, pages 759–766. ACM, 2007.
  • [23] Alexander Rakhlin, Ohad Shamir, and Karthik Sridharan. Making gradient descent optimal for strongly convex stochastic optimization. 2012.
  • [24] Nicolas L Roux, Mark Schmidt, and Francis R Bach. A stochastic gradient method with an exponential convergence rate for finite training sets. In Advances in neural information processing systems, pages 2663–2671, 2012.
  • [25] Rajat Sen, Kirthevasan Kandasamy, and Sanjay Shakkottai. Multi-fidelity black-box optimization with hierarchical partitions. In International Conference on Machine Learning, pages 4545–4554, 2018.
  • [26] Rajat Sen, Kirthevasan Kandasamy, and Sanjay Shakkottai. Noisy blackbox optimization using multi-fidelity queries: A tree search approach. In The 22nd International Conference on Artificial Intelligence and Statistics, pages 2096–2105, 2019.
  • [27] Xuedong Shang, Emilie Kaufmann, and Michal Valko. Adaptive black-box optimization got easier: HCT only needs local smoothness. In European Workshop on Reinforcement Learning, 2017.
  • [28] Hidetoshi Shimodaira. Improving predictive inference under covariate shift by weighting the log-likelihood function. Journal of statistical planning and inference, 90(2):227–244, 2000.
  • [29] Masashi Sugiyama, Matthias Krauledat, and Klaus-Robert Müller. Covariate shift adaptation by importance weighted cross validation. Journal of Machine Learning Research, 8(May):985–1005, 2007.
  • [30] Masashi Sugiyama, Taiji Suzuki, Shinichi Nakajima, Hisashi Kashima, Paul von Bünau, and Motoaki Kawanabe. Direct importance estimation for covariate shift adaptation. Annals of the Institute of Statistical Mathematics, 60(4):699–746, 2008.
  • [31] Jason Yosinski, Jeff Clune, Yoshua Bengio, and Hod Lipson. How transferable are features in deep neural networks? In Advances in neural information processing systems, pages 3320–3328, 2014.
  • [32] Bianca Zadrozny. Learning and evaluating classifiers under sample selection bias. In Proceedings of the twenty-first international conference on Machine learning, page 114. ACM, 2004.
  • [33] Sen Zhao, Mahdi Milani Fard, Harikrishna Narasimhan, and Maya Gupta. Metric-optimized example weights. arXiv preprint arXiv:1805.10582, 2018.

Appendix A Smoothness with Respect to the Mixture

In this section we prove Theorem 1. The analysis applies an idea from Theorem 3.9 in [12]. The key technique is to construct a total-variation coupling between the two mixture distributions. Using this coupling, we prove that the SGD iterates obtained under the two distributions cannot be too far apart in expectation. Because each sequence of iterates converges to its respective optimal solution, we can then conclude that the optimal weights for the two mixtures are close.
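The coupling used in this appendix acts on the choice of training dataset: whenever the coupled dataset indices agree, both SGD runs share one sample. A NumPy sketch of this construction and of two coupled SGD runs, assuming the mixtures are probability vectors over the same training sources; the quadratic per-sample loss f(w; z) = 0.5 * ||w - z||^2, the step size 1/(t+1), and the helper names are illustrative assumptions, not the setting of Theorem 1:

```python
import numpy as np

def maximal_coupling(a1, a2):
    """Joint law of (X, Y) with marginals a1, a2 that maximizes P(X == Y).
    Returns the joint matrix and P(X != Y), the total-variation distance."""
    a1, a2 = np.asarray(a1, dtype=float), np.asarray(a2, dtype=float)
    overlap = np.minimum(a1, a2)
    tv = 1.0 - overlap.sum()  # equals 0.5 * ||a1 - a2||_1
    joint = np.diag(overlap)
    if tv > 0:
        # Leftover masses have disjoint supports, so this adds nothing to the diagonal.
        joint += np.outer(a1 - overlap, a2 - overlap) / tv
    return joint, tv

def coupled_sgd(a1, a2, datasets, T, rng):
    """Two SGD runs on f(w; z) = 0.5 * ||w - z||^2, one per mixture, driven by
    coupled index draws. With step 1/(t+1), each iterate is the running mean of its samples."""
    joint, _ = maximal_coupling(a1, a2)
    k = joint.shape[0]
    w1 = np.zeros(datasets[0].shape[1])
    w2 = w1.copy()
    for t in range(T):
        i, j = divmod(rng.choice(k * k, p=joint.ravel()), k)
        z1 = datasets[i][rng.integers(len(datasets[i]))]
        # Shared sample when the indices agree, independent draw otherwise.
        z2 = z1 if i == j else datasets[j][rng.integers(len(datasets[j]))]
        eta = 1.0 / (t + 1)
        w1 -= eta * (w1 - z1)  # gradient of 0.5 * ||w - z||^2 is w - z
        w2 -= eta * (w2 - z2)
    return w1, w2
```

When the two mixtures coincide, the runs share every sample and stay identical; otherwise they can drift apart only on the fraction of steps (in expectation, the total-variation distance) where the coupled indices disagree, which mirrors the structure of the argument in this appendix.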

Lemma 1.

Under the conditions of Theorem 1, consider the two random variables representing the weights after performing the same number of steps of online SGD using the data distributions represented by the two mixtures, respectively, starting from the same initial weight. Then we have the following bound:


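We closely follow the proof of Theorem 3.9 in [12]. Consider the SGD operator that processes the t-th example drawn from the first mixture and the SGD operator that processes the t-th example drawn from the second mixture. Let a pair of random dataset indices have as their joint distribution the total-variation coupling of the two mixtures, so that the marginals of the two indices are the two mixtures, respectively, while the probability that the indices differ equals the total-variation distance between the mixtures. At each time step, a fresh pair of indices is drawn. If the two indices are equal, we draw a single data sample from the corresponding dataset and use it for both runs. Otherwise, we draw a sample from each of the two indicated datasets independently. Each run's SGD operator at that time step then uses its own sample.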

Therefore, following the analysis in [12], if , then . On the other hand, , for , where we will use the decreasing step-size .

Thus we have the following expression for :

By taking expectation on both sides, we obtain:

Assuming that , we get the following result from the recursion,

Let . Note that the above result also implies that,

Proof of Theorem 1.

First, note that, by definition, the optimal weight for a mixture is not a random variable, i.e., it is the optimal weight with respect to the distribution corresponding to that mixture. On the other hand, the SGD iterates for the two mixtures are random variables, whose randomness comes from the SGD sampling. By the triangle inequality, we have the following:

The expectation in the middle of the r.h.s. is bounded as in Eq. (5). We can use Corollary 6 in [15] to bound each of the two other terms on the r.h.s. as,


where is a global diameter bound. We can now choose the number of SGD steps large enough to satisfy the bound in Theorem 1. ∎
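In notation assumed here (not taken from the paper), write w^*(\alpha) for the optimal weight under mixture \alpha and w_T(\alpha) for the random SGD iterate after T steps. Because the left-hand side below is deterministic and the triangle inequality holds for every realization of the SGD randomness, taking expectations gives the three-term decomposition used in this proof; the middle term is the one bounded as in Eq. (5), and the two outer terms are the ones bounded via Corollary 6 in [15]:

```latex
\bigl\|w^*(\alpha_1) - w^*(\alpha_2)\bigr\|
  \le \mathbb{E}\bigl\|w^*(\alpha_1) - w_T(\alpha_1)\bigr\|
    + \mathbb{E}\bigl\|w_T(\alpha_1) - w_T(\alpha_2)\bigr\|
    + \mathbb{E}\bigl\|w_T(\alpha_2) - w^*(\alpha_2)\bigr\|
```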

Appendix B New High-Probability Bounds on SGD
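Lemma 2 in this appendix concerns the SGD iteration on a strongly convex objective whose per-sample losses are smooth. The toy NumPy experiment below only illustrates that setting and does not reproduce the lemma's bound. It assumes a one-dimensional per-sample loss f(w; a, b) = 0.5 * a * (w - b)^2 with a ~ Uniform(0.5, 2) and b ~ Normal(1, 1), a decreasing step size 1 / (mu * (t + t0)), and hypothetical helper names, and it estimates a high-probability error level as the 95th percentile across independent runs:

```python
import numpy as np

def sgd_final_errors(T, n_runs, rng, t0=2.0):
    """|w_T - w*| for n_runs independent SGD runs (vectorized over runs)."""
    mu = 1.25      # E[a]: strong-convexity modulus of F(w) = E[f(w; a, b)]
    w_star = 1.0   # a and b are independent, so F is minimized at E[b] = 1
    w = np.zeros(n_runs)
    for t in range(T):
        a = rng.uniform(0.5, 2.0, size=n_runs)  # each f(.; a, b) is a-smooth, a <= 2
        b = rng.normal(1.0, 1.0, size=n_runs)
        eta = 1.0 / (mu * (t + t0))             # decreasing step size (eta * a <= 0.8)
        w -= eta * a * (w - b)                  # stochastic gradient step
    return np.abs(w - w_star)

def high_prob_error(T, n_runs=2000, level=0.95, seed=0):
    """Empirical error level exceeded by only a (1 - level) fraction of the runs."""
    rng = np.random.default_rng(seed)
    return float(np.quantile(sgd_final_errors(T, n_runs, rng), level))
```

Comparing high_prob_error(50) with high_prob_error(500) shows the high-probability error level shrinking as the number of SGD steps grows, which is the kind of behavior the bounds in this appendix quantify.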

Lemma 2.

Suppose is strongly convex. Further, let be a smooth function for every . Consider the stochastic gradient iteration: where is sampled randomly from a distribution . Let . Let , where Then, iterates satisfy the following inequality:


where the last line follows from strong convexity (see equation 4.12 in [4]). Now, subtracting from both sides and rearranging, we find:


We now apply the sample version of a lemma from [19], which uses the fact that the objective is strongly convex while the loss is smooth for every random realization (Assumptions 1 and 2 as in [19]).

Since the objective is strongly convex and Lipschitz, and the loss is smooth for every realization, we have:


where the last line follows from the sample-path, centered version of Lemma 2 of [19], which we state and prove here for completeness:

Lemma 3 (Adapted version of Lemma 2 from [19]).

Under Assumptions 1 and 2, for any random realization of , the following bound holds almost surely:


Using the inequality derived in [19]:


we may derive the following bound: