Fast, Accurate, and Simple Models for Tabular Data via Augmented Distillation

06/25/2020 ∙ by Rasool Fakoor, et al. ∙ Amazon, University of Pennsylvania

Automated machine learning (AutoML) can produce complex model ensembles by stacking, bagging, and boosting many individual models like trees, deep networks, and nearest neighbor estimators. While highly accurate, the resulting predictors are large, slow, and opaque as compared to their constituents. To improve the deployment of AutoML on tabular data, we propose FAST-DAD to distill arbitrarily complex ensemble predictors into individual models like boosted trees, random forests, and deep networks. At the heart of our approach is a data augmentation strategy based on Gibbs sampling from a self-attention pseudolikelihood estimator. Across 30 datasets spanning regression and binary/multiclass classification tasks, FAST-DAD distillation produces significantly better individual models than one obtains through standard training on the original data. Our individual distilled models are over 10x faster and more accurate than ensemble predictors produced by AutoML tools like H2O/AutoSklearn.







1 Introduction

Figure 1: Normalized test accuracy vs. speed of individual models and AutoML ensembles, averaged over all 30 datasets. TEACHER denotes the performance of AutoGluon; H2O and autosklearn represent the respective AutoML tools. GIB-1 indicates the results of FAST-DAD after 1 round of Gibbs sampling. BASE denotes the student model fit on original data. GIB-1/BASE dots represent the model Selected (out of the 4 types) based on validation accuracy for each dataset.

Modern AutoML tools provide good out-of-the-box accuracy on diverse datasets. This is often achieved through extensive model ensembling (Erickson et al., 2020; Feurer et al., 2019; Cortes et al., 2017). While the resultant predictors may generalize well, they can be large, slow, opaque, and expensive to deploy. Fig. 1 shows that the most accurate predictors can be 10,000 times slower than their constituent models.

Model distillation (Bucilua et al., 2006; Hinton et al., 2015) offers a way to compress the knowledge learnt by these complex models into simpler predictors with reduced inference time and memory usage that are also less opaque and easier to modify and debug. In distillation, we train a simpler model (the student) to output similar predictions to those of a more complex model (the teacher). Here we use AutoML to create the most accurate possible teacher, typically an ensemble of many individual models via stacking, bagging, boosting, and weighted combinations (Dietterich, 2000). Unfortunately, distillation typically comes with a sharp drop in accuracy. Our paper mitigates this drop via FAST-DAD, a technique to produce Fast-and-accurate models via Distillation with Augmented Data. We apply FAST-DAD to large stack-ensemble predictors from AutoGluon (Erickson et al., 2020) to produce individual models that are over 10,000× faster than AutoGluon and over 10× faster, yet still more accurate, than ensemble predictors produced by H2O-AutoML (Pandey, 2019) and AutoSklearn (Feurer et al., 2019).

Motivation. A key issue in distillation is that the quality of the student is largely determined by the amount of available training data. While standard distillation confers smoothing benefits (the teacher may provide higher-quality prediction targets to the student (Hinton et al., 2015; Tang et al., 2020)), it incurs a student-teacher statistical approximation error of similar magnitude to that incurred when training directly on the original labeled dataset. By increasing the amount of data available for distillation, one can improve the student's approximation of the teacher and hence the student's accuracy on test data (assuming that the teacher achieves lower generalization error than a student fit directly to the original data). The extra data need not be labeled; one may use the teacher to label it. This enables the use of density estimation techniques to learn the distribution of the training data and draw samples of unlabeled data. In fact, we need not even learn the full joint distribution but simply learn how to draw approximate samples from it. We show that the statistical error in these new samples can be traded off against the reduction in variance from fitting the student to a larger dataset. Our resulting student models are almost as accurate as the teacher while being far more efficient and lightweight.

The contributions of this paper are as follows:

  1. We present model-agnostic distillation that works across many types of teachers and students and different supervised learning problems (binary and multiclass classification, regression). This is in contrast to problem- and architecture-specific distillation techniques (Bucilua et al., 2006; Hinton et al., 2015; Vidal et al., 2020; Cho and Hariharan, 2019).

  2. We introduce a maximum pseudolikelihood model for tabular data that uses self-attention across covariates to simultaneously learn all of their conditional distributions.

  3. We propose a corresponding Gibbs sampler that takes advantage of these conditional estimates to efficiently augment the dataset used in distillation. Our FAST-DAD approach avoids estimating the features' joint distribution, and enables control over the sample quality and diversity of the augmented dataset.

  4. We report the first comprehensive distillation benchmark for tabular data, which studies 5 distillation strategies with 4 different types of student models applied over 30 datasets involving 3 different types of prediction tasks.

Although our techniques can be adapted to other modalities, we focus on tabular data, which has been under-explored in distillation despite its ubiquity in practical applications. Compared to typical data tables, vision and language datasets have far larger sample sizes, and additional unlabeled data are easy to obtain; data augmentation is thus not as critical for distillation in those domains as it is in the tabular setting.

2 Related Work

While distillation and model compression are popular in deep learning, existing work focuses primarily on vision, language and speech applications. Unlike the tabular settings we consider here, this prior work studies situations where: (a) unlabeled data is plentiful; (b) there are many more training examples than in typical data tables; (c) both teacher and student are neural networks; (d) the task is multiclass classification (Ba and Caruana, 2014; Hinton et al., 2015; Urban et al., 2017; Cho and Hariharan, 2019; Mirzadeh et al., 2019; Yang et al., 2020).

For tabular data, Breiman and Shang (1996) considered distilling models into single decision trees, but this often unacceptably harms accuracy. Recently, Vidal et al. (2020) showed how to distill tree ensembles into a single tree without sacrificing accuracy, but their approach is restricted to tree student/teacher models. Like us, Bucilua et al. (2006) considered distillation of large ensembles of heterogeneous models, pioneering the use of data augmentation in this process. Their work only considered binary classification problems with a neural network student model; multiclass classification is handled in a one-vs-all fashion, which produces less efficient students that maintain a separate model for every class. Liu et al. (2020) suggest that generative adversarial networks can be used to produce better augmented data, but only conduct a small-scale distillation study with random forests.

3 From Function Approximation to Distillation

We first formalize distillation to quantify the role of the auxiliary data in this process. Consider a dataset $(X, Y)$ where $X = \{x_i\}_{i=1}^{n}$ are observations of some features sampled from distribution $p(x)$, and $Y = \{y_i\}_{i=1}^{n}$ are their labels sampled from distribution $p(y \mid x)$. The teacher is some function $f(x)$, learned e.g. via AutoML, that achieves good generalization error:

$$R[f] := \mathbb{E}_{(x, y) \sim p}\left[\ell(f(x), y)\right]$$

where the loss $\ell$ measures the error in individual predictions. Our goal is to find a model $g$ from a restricted class of functions $\mathcal{G}$ such that $R[g]$ is smaller than the generalization error of another model from this class produced via empirical risk minimization.

Approximation. Distillation seeks some student $g \in \mathcal{G}$ that is “close” to the teacher $f$. If $|f(x) - g(x)| \le \epsilon$ over the feature space $\mathcal{X}$, and if the loss function $\ell$ is Lipschitz continuous in $f(x)$, then $g$ will be nearly as accurate as the teacher ($R[g] \approx R[f]$). Finding such a $g$ may, however, be impossible. For instance, a Fourier approximation of a step function will never converge uniformly but only pointwise. This is known as the Gibbs phenomenon (Wilbraham, 1848). Fortunately, $L_\infty$-convergence is not required: we only require convergence with regard to some distance function $d(f(x), g(x))$ averaged over $p(x)$. Here $d$ is determined by the task-specific loss $\ell$. For instance, the $\ell_2$-loss can be used for regression, and the KL divergence between class-probability estimates from $f$ and $g$ may be used in classification. Our goal during distillation is thus to minimize

$$D(f, g, p) := \mathbb{E}_{x \sim p}\left[d(f(x), g(x))\right] \qquad (1)$$

This is traditionally handled by minimizing its empirical counterpart (Hinton et al., 2015):

$$D_{\text{emp}}(f, g, X) := \frac{1}{n} \sum_{i=1}^{n} d(f(x_i), g(x_i)) \qquad (2)$$
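The empirical distillation objective (the average of $d(f(x_i), g(x_i))$ over the available data) can be sketched in a few lines of NumPy. This is a minimal illustration assuming the two choices of $d$ named above: squared error between scalar predictions for regression, and KL divergence from the teacher's to the student's class-probability vector for classification. The function name `distillation_gap` is hypothetical.

```python
import numpy as np

def distillation_gap(teacher_out, student_out, task, eps=1e-12):
    """Average per-example distance d(f(x_i), g(x_i)) over the given rows.

    task="regression": d is the squared error between scalar predictions.
    task="classification": d is KL(f(x) || g(x)) between probability rows
    of shape (n, num_classes).
    """
    f = np.asarray(teacher_out, dtype=float)
    g = np.asarray(student_out, dtype=float)
    if task == "regression":
        per_example = (f - g) ** 2
    elif task == "classification":
        # eps guards against log(0); entries with f == 0 contribute 0 to the KL sum.
        per_example = np.sum(f * (np.log(f + eps) - np.log(g + eps)), axis=1)
    else:
        raise ValueError(f"unknown task: {task}")
    return float(np.mean(per_example))
```

Calling `distillation_gap([1.0, 2.0], [1.0, 4.0], "regression")` averages the squared errors 0 and 4. Minimizing this quantity over a student class is what standard distillation does when only the original $X$ is available.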


Rates of Convergence. Since it is only an empirical average, minimizing $D_{\text{emp}}(f, g, X)$ over $g \in \mathcal{G}$ will give rise to an approximation error that can be bounded, e.g. by uniform convergence bounds from statistical learning theory (Vapnik, 1998), as $O(\sqrt{V/n})$. Here $V$ denotes the complexity of the function class $\mathcal{G}$ and $n$ is the number of observations used for distillation. Note that we effectively pay twice for the statistical error due to sampling $(X, Y)$ from $p$: once to learn $f$, and again while distilling $g$ from $f$ using the same samples.

There are a number of mechanisms to reduce the second error. If we had access to more unlabeled data, say $X' = \{x'_i\}_{i=1}^{m}$ with $m \gg n$ drawn from $p$, we could reduce the statistical error due to distillation significantly (see Fig. S2). While we usually cannot draw from $p$ for tabular data due to a lack of additional unlabeled examples (unlike, say, for images or text), we might be able to draw from a related distribution $q$ which is sufficiently close. In this case we can obtain a uniform convergence bound:

Lemma 1 (Surrogate Approximation). Assume that the complexity of the function class $\mathcal{G}$ is bounded under $d$ and that $d(f(x), g(x)) \le C$ for all $x$ and $g \in \mathcal{G}$. Then there exists a constant $K$ such that with probability at least $1 - \delta$ we have

$$D(f, g, p) \le D_{\text{emp}}(f, g, X') + C\,\|p - q\|_1 + K \sqrt{\frac{V - \log \delta}{m}} \qquad (3)$$

Here $X' = \{x'_i\}_{i=1}^{m}$ are samples from $q$, and $g$ is chosen, e.g., to minimize $D_{\text{emp}}(f, g, X')$.

Proof. This follows directly from Hölder's inequality when applied to $|D(f, g, p) - D(f, g, q)| \le C\,\|p - q\|_1$. Next we apply uniform convergence bounds to the difference $D(f, g, q) - D_{\text{emp}}(f, g, X')$. Using VC bounds (Vapnik, 1998) proves the claim. ∎

The inequality in Eq. 3 suggests a number of strategies when designing algorithms for distillation. Whenever $p$ and $q$ are similar in the sense that the bias term $C\,\|p - q\|_1$ is small, we want to draw as much data as we can from $q$ to make the uniform convergence term vanish. However, if $q$ is some sort of estimate, a nontrivial difference between $p$ and $q$ will usually exist in practice. In this case, we may trade off the variance reduction offered by extra augmented samples against the corresponding bias by drawing these samples from an intermediate distribution that lies between the training data and the biased $q$.
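To make this trade-off concrete, the sketch below evaluates the two data-dependent slack terms of a surrogate bound, assuming it takes the form bias $C\,\|p - q\|_1$ plus a uniform-convergence term $K\sqrt{(V - \log\delta)/m}$. Every constant ($C$, $K$, $V$, $\delta$, the $\ell_1$ distances, the sample sizes) is a made-up illustrative value, and `bound_slack` is a hypothetical helper; the only point is that a slightly biased $q$ with many samples can yield a smaller slack than an unbiased source with few.

```python
import math

def bound_slack(l1_dist, m, C=1.0, K=1.0, V=10.0, delta=0.05):
    """Slack of an assumed surrogate bound: C*||p - q||_1 + K*sqrt((V - log delta)/m).
    All default constants are illustrative placeholders, not values from the paper."""
    bias = C * l1_dist
    variance = K * math.sqrt((V - math.log(delta)) / m)
    return bias + variance

# Unbiased source (q = p) but only n = 500 samples.
print(round(bound_slack(l1_dist=0.0, m=500), 3))
# Slightly biased surrogate q, but 10x more samples.
print(round(bound_slack(l1_dist=0.02, m=5000), 3))
```

Increasing `l1_dist` enough (a poor estimate of $p$) reverses the comparison, which is why samples from an intermediate distribution close to the training data can be preferable.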

4 FAST-DAD Distillation via Augmented Data

The augmentation distribution $q$ in Lemma 1 could be naively produced by applying density estimation to the data $X$ and then sampling from the learnt density. Unfortunately, multivariate density estimation and generative modeling are at least as difficult as the supervised learning problems AutoML aims to solve (Sugiyama et al., 2012). It is, however, much easier to estimate $p(x_j \mid x_{-j})$, the univariate conditional of feature $j$ given all the other features $x_{-j}$ in datum $x$. This suggests the following strategy, which forms the crux of FAST-DAD:

  1. For all features $j$: estimate the conditional distribution $p(x_j \mid x_{-j})$ using the training data.

  2. Use all training data $x \in X$ as initializations for a Gibbs sampler (Geman and Geman, 1984). That is, use each $x \in X$ to generate an MCMC chain via $\tilde{x}_j \sim \hat{p}(x_j \mid \tilde{x}_{-j})$, where $\hat{p}$ denotes the estimated conditionals.

  3. Use the samples from all chains as additional data for distillation.

We next describe these steps in detail, but first let us see why this strategy can generate good augmented data. If our conditional probability estimates $\hat{p}(x_j \mid x_{-j})$ are accurate, the Gibbs sampler is guaranteed to converge to samples drawn from $p(x)$ regardless of the initialization (Roberts and Smith, 1994). In particular, initializing the sampler with data $x \in X$ ensures that it doesn't need time to ‘burn in’; it starts immediately with samples from the correct distribution. Even if $\hat{p}$ is inaccurate (inevitable for small $n$), the sample $\tilde{x}$ will not deviate too far from $p(x)$ after a small number of Gibbs sampling steps (low bias), whereas sampling directly from an inaccurate joint density estimate $\hat{p}(x)$ would produce disparate samples.

4.1 Maximum Pseudolikelihood Estimation via Self-Attention

A cumbersome aspect of the strategy outlined above is the need to model many conditional distributions $p(x_j \mid x_{-j})$ for different $j$. This would traditionally require many separate models. Here we instead propose a single self-attention architecture (Vaswani et al., 2017) with parameters $\theta$ that is trained to simultaneously estimate all conditionals via a pseudolikelihood objective (Besag, 1977):

$$\hat{\theta} = \arg\max_{\theta} \; \frac{1}{n} \sum_{i=1}^{n} \sum_{j=1}^{d} \log p(x_{ij} \mid x_{i,-j}; \theta) \qquad (4)$$

where $d$ is the number of features.

For many models, maximum pseudolikelihood estimation produces asymptotically consistent parameter estimates and is often more computationally tractable than optimizing the likelihood (Besag, 1977). Our model takes $x$ as input and simultaneously estimates the conditional distributions $p(x_j \mid x_{-j}; \theta)$ for all features $j$ using a self-attention-based encoder. As in Transformers, each encoder layer consists of a multi-head self-attention mechanism and a feature-wise feedforward block (Vaswani et al., 2017). Self-attention helps this model gather relevant information from $x_{-j}$ needed for modeling $x_j$.

Each conditional is parametrized as a mixture of $M$ Gaussians, $p(x_j \mid x_{-j}; \theta) = \sum_{k=1}^{M} \lambda_k \, \mathcal{N}(x_j; \mu_k, \sigma_k^2)$, where $\lambda_k, \mu_k, \sigma_k$ depend on $x_{-j}$ and are output by the topmost layer of our encoder after processing $x_{-j}$. Categorical features are numerically represented using dequantization (Uria et al., 2013). To condition on $x_{-j}$ in a mini-batch (with $j$ randomly selected per mini-batch), we mask the values of $x_j$ to omit all information about the corresponding feature value (as in Devlin et al., 2019) and also mask all self-attention weights for input dimension $j$; this amounts to performing stochastic gradient descent on the objective in Eq. 4 across both samples $i$ and their individual features $j$. We thus have an efficient way to compute any of these conditional distributions with one forward pass of the model. While this work utilizes self-attention, our proposed method can work with any efficient estimator of $p(x_j \mid x_{-j})$ for $j = 1, \dots, d$.
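The quantity maximized in Eq. 4 can be illustrated with a small NumPy sketch: a numerically stable log-density for a one-dimensional Gaussian mixture, summed over every feature of every row. The callable `conditional_params(x_masked, j)` is a hypothetical stand-in for the self-attention network; setting $x_j$ to zero is only a crude placeholder for the value and attention masking described above. Unlike the SGD training procedure (one random $j$ per mini-batch), this sketch evaluates the full objective.

```python
import numpy as np

def mog_log_density(x, weights, means, stds):
    """log sum_k w_k N(x; mu_k, sigma_k^2), computed stably via log-sum-exp."""
    w = np.asarray(weights, dtype=float)
    mu = np.asarray(means, dtype=float)
    sd = np.asarray(stds, dtype=float)
    log_comp = np.log(w) - 0.5 * np.log(2.0 * np.pi * sd**2) - 0.5 * ((x - mu) / sd) ** 2
    top = np.max(log_comp)
    return float(top + np.log(np.sum(np.exp(log_comp - top))))

def pseudo_log_likelihood(X, conditional_params):
    """Average over rows of sum_j log p(x_j | x_{-j}).

    conditional_params(x_masked, j) is a hypothetical estimator returning the
    mixture (weights, means, stds) for feature j given the other features."""
    X = np.asarray(X, dtype=float)
    n, d = X.shape
    total = 0.0
    for x in X:
        for j in range(d):
            masked = x.copy()
            masked[j] = 0.0  # hide x_j so the estimate can only use x_{-j}
            w, mu, sd = conditional_params(masked, j)
            total += mog_log_density(x[j], w, mu, sd)
    return total / n
```

With a trivial estimator that always returns a single standard normal, each feature contributes $-\tfrac{1}{2}\log(2\pi)$, which makes the function easy to sanity-check before plugging in a learned model.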

Relation to other architectures. Our approach can be seen as an extension of the mixture density network (Bishop, 1994), which can model arbitrary conditional distributions, but not all conditionals simultaneously as enabled by our use of masked self-attention with the pseudolikelihood objective. It is also similar to TraDE (Fakoor et al., 2020); however, their auto-regressive model requires imposing an arbitrary ordering of the features. Since self-attention is permutation-invariant (Lee et al., 2018), our pseudolikelihood model is desirably insensitive to the order in which features happen to be recorded as table columns. Our use of masked self-attention shares many similarities with BERT (Devlin et al., 2019), where the goal is typically representation learning or text generation (Wang et al., 2019). In contrast, our method is designed for data that lives in tables. We need to estimate the conditionals very precisely, as they are used to sample continuous values; this is typically not necessary for text models.

4.2 Gibbs Sampling from the Learnt Conditionals

We adopt the following procedure to draw Gibbs samples to augment our training data. The sampler is initialized at some training example $x \in X$, and a random ordering of the features is selected (with different orderings used for Gibbs chains started from different training examples). We cycle through the features and in each step replace the value of one feature in $\tilde{x}$, say $\tilde{x}_j$, by sampling from its conditional distribution given all the other variables, i.e. $\tilde{x}_j \sim \hat{p}(x_j \mid \tilde{x}_{-j})$. After every feature has been resampled, we say one round of Gibbs sampling is complete, and proceed to the next round by randomly selecting a new feature order to follow in subsequent Gibbs sampling steps.
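The round-based procedure just described can be sketched as follows. This is a minimal NumPy illustration, not the authors' implementation: `sample_conditional(x, j, rng)` is a hypothetical callable standing in for drawing from the learnt conditionals, and to keep the example runnable we plug in the exact conditionals of a toy bivariate Gaussian (correlation `RHO`), not one of the paper's datasets. Each chain starts at a training row, every round visits the features in a fresh random order, and the state after each round is kept so that samples can be taken from a chosen round (as in GIB-1, GIB-5, ...).

```python
import numpy as np

def gibbs_augment(X, sample_conditional, num_rounds, seed=0):
    """Run one Gibbs chain per training row, initialized at that row.

    Returns a list whose entry k holds all chain states after k + 1 rounds."""
    rng = np.random.default_rng(seed)
    chains = np.array(X, dtype=float)  # copy: chains start at the data, X is untouched
    n, d = chains.shape
    per_round = []
    for _ in range(num_rounds):
        for i in range(n):
            # A new random feature order for every chain and every round.
            for j in rng.permutation(d):
                chains[i, j] = sample_conditional(chains[i], j, rng)
        per_round.append(chains.copy())
    return per_round

# Toy target: standard bivariate Gaussian with correlation RHO, whose exact
# conditionals are x_j | x_{-j} ~ N(RHO * x_{-j}, 1 - RHO^2).
RHO = 0.9

def gaussian_conditional(x, j, rng):
    other = x[1 - j]
    return rng.normal(RHO * other, np.sqrt(1.0 - RHO**2))

data_rng = np.random.default_rng(1)
X_train = data_rng.multivariate_normal([0.0, 0.0], [[1.0, RHO], [RHO, 1.0]], size=200)
rounds = gibbs_augment(X_train, gaussian_conditional, num_rounds=5)
X_gib1 = rounds[0]  # samples after one Gibbs round (GIB-1)
```

Because the chains start at draws from the target, even `rounds[0]` already follows the target distribution here; with an imperfect estimator, later rounds drift further from the data, which is the fidelity/diversity trade-off discussed below.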

A practical challenge in Gibbs sampling is that a poor choice of initialization may require many burn-in steps to produce reasonable samples. Suppose for the following discussion that our pseudolikelihood estimator and its learnt conditionals are accurate. We can use a strategy inspired by Contrastive Divergence (Hinton, 2002): initialize the sampler at $x \in X$ and take a few (often only one) Gibbs sampling steps. This strategy is effective; we need not wait for the sampler to burn in because it is initialized at (or close to) the true distribution itself. This is seen in Fig. 2, where we compare samples from the true distribution with Gibbs samples (taken with respect to conditional estimates from our self-attention network) starting from an arbitrary initialization vs. initialized at $x \in X$.

Figure 2: Initialization of the Gibbs sampler. From left to right: original training data, samples obtained from one round of Gibbs sampling with random initialization after fitting the self-attention network, samples obtained after multiple rounds of Gibbs sampling (10 for the spiral, 100 for the checkerboard density) with random initialization, and samples obtained from one Gibbs sampling round when initializing at the training data $x \in X$. The densities were generated from examples in Nash and Durkan (2019).

For distillation, we expect this sampling strategy to produce better augmented data. The number of Gibbs sampling steps provides fine-grained control over the sample fidelity and diversity of the resulting dataset used in distillation. Recall that the student will be trained over $X \cup X'$ in practice. When our estimates of $p(x_j \mid x_{-j})$ are accurate, it is desirable to produce $x'$ only after a large number of Gibbs steps, as the bias in Eq. 3 will remain low and we would like to ensure that the $x'$ are more statistically independent of $X$. With worse estimates of $p(x_j \mid x_{-j})$, it is better to produce $x'$ after only a few Gibbs steps to ensure lower bias in Eq. 3, but the lack of burn-in implies that the $x'$ are not independent of $X$ and may thus be less useful to the student during distillation. We dig deeper into this phenomenon (for a special case) in the following theorem.

Theorem (Refinement of Lemma 1). Under the assumptions of Lemma 1, suppose the student minimizes the empirical distillation objective on the augmented data, where $X'_k$ are samples drawn after $k$ steps of the Gibbs sampler initialized at $X$. Then there exist constants such that with probability at least $1 - \delta$:


Here $\|p - \pi_k\|_{\mathrm{TV}}$ is the total-variation norm between $p$ and $\pi_k$ (the distribution of Gibbs samples after $k$ steps), where $\pi$ denotes the steady-state distribution of the Gibbs sampler. The proof (in Appendix D) is based on multi-task generalization bounds (Baxter, 2000) and MCMC mixing rates (Wang et al., 2014). Since $\pi_k \to \pi$ as $k \to \infty$, we should use Gibbs samples from a smaller number of steps when $\pi$ is inaccurate (e.g. if our pseudolikelihood estimator is fit to limited data).

4.3 Training the Student with Augmented Data

While previous distillation works focused only on particular tasks (Bucilua et al., 2006; Hinton et al., 2015), we consider both regression and classification tasks. Our overall approach is the same for each problem type:

  1. Generate a set of augmented samples $X' = \{x'_i\}_{i=1}^{m}$.

  2. Feed the samples $X'$ as inputs to the teacher model to obtain predictions $Y' = \{f(x'_i)\}_{i=1}^{m}$, which are the predicted class probabilities in classification (rather than hard class labels) and predicted scalar values in regression.

  3. Train each student model on the augmented dataset $(X, Y) \cup (X', Y')$.

In the final step, our student model is fit to a combination of both true labels $y$ from the data and augmented labels $y'$ from the teacher, where $y'$ is of a different form than $y$ in classification (predicted probabilities rather than predicted classes). For binary classification tasks, we employ the Brier score (Brier, 1950) as our loss function for all students, treating both the probabilities assigned to the positive class by the teacher and the observed labels as continuous regression targets for the student model. The same strategy was employed by Bucilua et al. (2006), and it slightly outperformed our alternative multiclass strategy in our binary classification experiments. We handle multiclass classification in a manner specific to each type of model, avoiding cumbersome students that maintain a separate model for each class (cf. one-vs-all). Neural network students are trained using the cross-entropy loss, which can be applied to soft labels as well. Random forest students can utilize multi-output decision trees (Segal and Xiao, 2011) and thus be trained as native multi-output regressors against targets which are one-hot-encoded class labels in the real data and teacher-predicted probability vectors in the augmented data. Boosted tree models are similarly used to predict vectors with one dimension per class, which are then passed through a softmax transformation; the cross-entropy loss is minimized via gradient boosting in this case.
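The target construction described above can be sketched as follows: real rows keep their true labels, augmented rows get the teacher's probabilities, binary tasks reduce to a single regression target (the positive-class probability, fit under the Brier score), and multiclass tasks use one-hot versus probability vectors. `build_student_dataset` and `teacher_predict_proba` are hypothetical names; the latter is assumed to return an `(m, num_classes)` probability array, and the student-specific losses are not shown.

```python
import numpy as np

def build_student_dataset(X, y, X_aug, teacher_predict_proba, num_classes):
    """Stack real rows (true labels) with augmented rows (teacher soft labels).

    Binary: targets are 0/1 labels for real rows and P(positive class) for
    augmented rows, to be fit as a regression under the Brier score.
    Multiclass: one-hot vectors for real rows, probability vectors for
    augmented rows (multi-output regression or soft-label cross-entropy)."""
    y = np.asarray(y, dtype=int)
    probs = np.asarray(teacher_predict_proba(X_aug), dtype=float)
    X_all = np.vstack([X, X_aug])
    if num_classes == 2:
        targets = np.concatenate([y.astype(float), probs[:, 1]])
    else:
        one_hot = np.eye(num_classes)[y]
        targets = np.vstack([one_hot, probs])
    return X_all, targets

def brier_score(pred, target):
    """Mean squared error between predicted and target probabilities (binary case)."""
    return float(np.mean((np.asarray(pred, float) - np.asarray(target, float)) ** 2))
```

Treating the binary case as a single regression target keeps one model per student, which is the same motivation given above for avoiding one-vs-all students in the multiclass case.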

5 Experiments

Data. We evaluate various methods on 30 datasets (Table S2) spanning regression tasks from the UCI ML Repository and binary/multiclass classification tasks from OpenML, which are included in popular deep learning and AutoML benchmarks (Mukhoti et al., 2018; Lakshminarayanan et al., 2017; Jain et al., 2020; Gijsbers et al., 2019; Truong et al., 2019; Erickson et al., 2020). To facilitate comparisons on a meaningful scale across datasets, we evaluate methods on the provided test data based on either their accuracy in classification or the percentage of variation explained (= $R^2 \times 100$) in regression. The training data are split into training/validation folds (90-10), and only the training fold is used for augmentation (validation data keep their original labels for use in model/hyperparameter selection and early stopping).
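For reference, the regression metric and the 90-10 split can be sketched as below. The random (unstratified) split and the helper names are assumptions for illustration; the text does not specify how the folds are drawn.

```python
import numpy as np

def percent_variation_explained(y_true, y_pred):
    """100 * R^2 = 100 * (1 - SS_res / SS_tot), computed on the test set."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 100.0 * (1.0 - ss_res / ss_tot)

def train_val_split(n, val_frac=0.1, seed=0):
    """Random 90-10 split of row indices; only the training fold is augmented."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n)
    n_val = int(round(val_frac * n))
    return idx[n_val:], idx[:n_val]
```

A model that always predicts the test-set mean scores 0 on this scale and a perfect model scores 100, which puts regression results on a range comparable to classification accuracy.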


We adopt AutoGluon as our teacher, as this system has demonstrated higher accuracy than most other AutoML frameworks and human data science teams (Erickson et al., 2020). AutoGluon is fit to each training dataset for up to 4 hours with the auto_stack option, which boosts accuracy via extensive model ensembling (all other arguments left at their defaults). The most accurate ensembles produced by AutoGluon often contain over 100 individual models trained via a combination of multi-layer stacking with repeated 10-fold bagging and the use of multiple hyperparameter values (Erickson et al., 2020). Each model trained by AutoGluon is one of: (1) Neural Network (NN), (2) CatBoost, (3) LightGBM, (4) Random Forest (RF), (5) Extremely Randomized Trees, and (6) K-Nearest Neighbors.

We adopt the most accurate AutoGluon ensemble (on the validation data) as the teacher model. We use models of types (1)-(4) as students, since these are more efficient than the others and thus more appropriate for distillation. These are also some of the most popular types of models among today’s data scientists (Bansal, 2018). We consider how well each individual type of model performs under different training strategies, as well as the overall performance achieved with each strategy after a model selection step in which the best individual model on the validation data (among all 4 types) is used for prediction on the test data. This Selected model reflects how machine learning is operationalized in practice. All candidate student models (as well as the BASE models) of each type share the same hyper-parameters and are expected to have similar size and inference latency.

5.1 Distillation Strategies

Figure 3: (a) Normalized metrics evaluated on samples from various Gibbs rounds (averaged across 3 datasets). Sample Fidelity measures how well a random forest discriminator can distinguish between (held-out) real and Gibbs-sampled data. Diffusion is the average Euclidean distance between each Gibbs sample and the datum from which its Markov chain was initialized. Discrepancy is the Maximum Mean Discrepancy (Gretton et al., 2012) between the Gibbs samples and the training data; it measures both how well the samples approximate $p$ and how distinct they are from the data $X$. Distillation Performance is the test accuracy of student models trained on the augmented data (averaged over our 4 model types). The diversity of the overall dataset used for distillation grows with increased discrepancy/diffusion, while this overall dataset more closely resembles the underlying data-generating distribution with increased sample fidelity (lower bias). The discriminator's accuracy ranges between … for these datasets. (b) Percentage-point improvement over the BASE model produced by each distillation method for different model types (change in accuracy for classification, explained variation for regression). As the improvements contain outliers/skewness, we show the median change across all datasets (dots) and the corresponding interquartile range (lines).
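Two of the sampler diagnostics in the caption above can be sketched in NumPy: Diffusion as the mean Euclidean distance between each Gibbs sample and the row its chain started from, and a squared Maximum Mean Discrepancy estimate. The Gaussian RBF kernel, its bandwidth, and the biased (V-statistic) estimator are assumptions of this sketch; the caption does not state which kernel or estimator was used.

```python
import numpy as np

def diffusion(X_init, X_gibbs):
    """Mean Euclidean distance between each Gibbs sample and its chain's start row."""
    diff = np.asarray(X_gibbs, dtype=float) - np.asarray(X_init, dtype=float)
    return float(np.mean(np.linalg.norm(diff, axis=1)))

def rbf_mmd2(X, Y, bandwidth=1.0):
    """Biased (V-statistic) estimate of squared MMD with a Gaussian RBF kernel."""
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)

    def kernel(A, B):
        # Pairwise squared distances via ||a||^2 + ||b||^2 - 2 a.b
        sq = np.sum(A**2, axis=1)[:, None] + np.sum(B**2, axis=1)[None, :] - 2.0 * A @ B.T
        return np.exp(-sq / (2.0 * bandwidth**2))

    return float(kernel(X, X).mean() + kernel(Y, Y).mean() - 2.0 * kernel(X, Y).mean())
```

Diffusion requires the Gibbs samples to be row-aligned with their starting data (as returned by one chain per training row), whereas MMD compares the two sample sets as unordered distributions.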

We compare our FAST-DAD Gibbs-augmented distillation technique with the following methods.

TEACHER: The model ensemble produced by AutoGluon fit to the training data. This is adopted as the teacher in all distillation strategies we consider.

BASE: Individual base models fit to the original training data.

KNOW: Knowledge distillation proposed by Hinton et al. (2015), in which we train each student model on the original training data, but with labels replaced by predicted probabilities from the teacher which are smoothed and nudged toward the original training labels (no data augmentation).

MUNGE: The technique proposed by Bucilua et al. (2006) to produce augmented data for model distillation, where the augmented samples are intended to resemble the underlying feature distribution. MUNGE augmentation may be viewed as a few steps of Gibbs sampling, where $p(x_j \mid x_{-j})$ is estimated by first finding near neighbors of $x$ in the training data and subsequently sampling from the (smoothed) empirical distribution of their $j$-th feature (Owen, 1987).

HUNGE: To see how useful the teacher’s learned label-distributions are to the student, we apply a hard variant of MUNGE. Here we produce MUNGE-augmented samples that receive hard class predictions from the teacher as their labels rather than the teacher’s predicted probabilities that are otherwise the targets in all other distillation strategies (equivalent to MUNGE for regression).

GAN: The technique proposed by Xu et al. (2019) for augmenting tabular data using conditional deep generative adversarial networks (GANs); this performs better than other GANs (Liu et al., 2020). Like our model, this GAN is trained on the training set and then used to generate augmented samples for the student model, whose labels are the predicted probabilities output by the teacher. Unlike our Gibbs sampling strategy, it is difficult to control how similar samples from the GAN should be to the training data.
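To contrast the MUNGE baseline with our sampler, the sketch below shows a simplified MUNGE-style augmentation for continuous features only: each row is paired with its nearest other row, and each feature is, with probability `p_swap`, resampled from a normal centered at the neighbor's value with scale $|x_a - x^{nn}_a| / s$. This is an illustrative simplification under our own assumptions (one synthetic row per original row, no handling of categorical features), not the reference MUNGE implementation; `munge_like`, `p_swap`, and `s` are hypothetical names here.

```python
import numpy as np

def munge_like(X, p_swap=0.5, s=1.0, seed=0):
    """Simplified MUNGE-style augmentation for continuous features.

    Returns one synthetic row per original row."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    # Pairwise Euclidean distances; a row is never its own neighbor.
    dists = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2)
    np.fill_diagonal(dists, np.inf)
    nearest = np.argmin(dists, axis=1)
    X_new = X.copy()
    for i, nb in enumerate(nearest):
        swap = rng.random(X.shape[1]) < p_swap
        # Local smoothing scale shrinks as the row and its neighbor agree.
        scale = np.abs(X[i] - X[nb]) / s
        draws = rng.normal(X[nb], scale)
        X_new[i, swap] = draws[swap]
    return X_new
```

Here the local conditional is a nearest-neighbor heuristic with no learned model, whereas FAST-DAD draws each $x_j$ from conditionals estimated jointly by the self-attention network.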

We also run our Gibbs sampling data augmentation strategy, generating samples after various numbers of Gibbs sampling rounds (for example, GIB-5 indicates that 5 rounds were used to produce the augmented data). Under each augmentation-based strategy, we add $m$ synthetic datapoints to the training set for the student, where $m$ is set relative to $n$, the number of original training samples (up to a fixed maximum).

5.2 Analysis of the Gibbs Sampler

To study the behavior of our Gibbs sampling procedure, we evaluate it on a number of different criteria (see the Fig. 3(a) caption). Fig. 3(a) depicts how the distillation dataset's overall diversity increases with additional rounds of Gibbs sampling. Fortuitously, we do not require a large number of Gibbs sampling rounds to obtain the best distillation performance and can thus efficiently generate augmented data. Running the Gibbs sampler for longer is ill-advised, as its stationary distribution appears to approximate $p$ less closely than intermediate samples from a partially burned-in chain; this is likely because we have limited data to fit the self-attention network.

| Strategy | Regression Rank | Regression Perf. | Regression p-value | Binary Rank | Binary Accuracy | Binary p-value | Multiclass Rank | Multiclass Accuracy | Multiclass p-value |
|---|---|---|---|---|---|---|---|---|---|
| BASE | 6.888 | 88.63 | - | 5.791 | 82.85 | - | 7.777 | 80.80 | - |
| HUNGE | 5.0 | 88.99 | 0.092 | 5.541 | 83.57 | 0.108 | 7.666 | 81.04 | 0.350 |
| KNOW | 6.555 | 88.49 | 0.712 | 5.25 | 83.89 | 0.072 | 5.555 | 81.39 | 0.275 |
| GAN | 6.666 | 88.65 | 0.450 | 6.708 | 83.17 | 0.250 | 6.055 | 82.26 | 0.069 |
| MUNGE | 5.444 | 88.88 | 0.209 | 5.083 | 83.72 | 0.126 | 4.333 | 82.80 | 0.007 |
| GIB-1 | 3.777 | 89.35 | 0.025 | 3.708 | 84.21 | 0.051 | 3.277 | 82.88 | 0.005 |
| GIB-5 | 3.333 | 89.25 | 0.004 | 5.375 | 84.04 | 0.098 | 3.388 | 82.76 | 0.010 |
| GIB-10 | 4.777 | 89.09 | 0.044 | 4.958 | 83.74 | 0.087 | 4.222 | 82.64 | 0.010 |
| TEACHER | 2.555 | 90.10 | 0.036 | 2.583 | 84.40 | 0.019 | 2.722 | 83.84 | 0.018 |
Table 1: Average ranks/performance achieved by the Selected model under each training strategy across the datasets from each prediction task. Performance is test accuracy for classification or percentage of variation explained for regression, and we list $p$-values for the one-sided test of whether each strategy outperforms BASE.

5.3 Performance of Distilled Models

Table 1 and Fig. 3(b) demonstrate that our Gibbs augmentation strategy produces better models than any of the other distillation strategies. Table S3 shows that the only datasets where Gibbs augmentation fails to produce better models than the BASE training strategy are those where the teacher ensemble fails to outperform the best individual BASE model (so little can be gained from distillation in the first place). As expected according to Hinton et al. (2015), KNOW helps in classification but not regression, and HUNGE fares worse than MUNGE on multiclass problems, where its augmented hard class labels fail to provide students with the teacher's dark knowledge. As previously observed (Bucilua et al., 2006), MUNGE greatly improves the performance of neural networks but provides less benefit for the other model types than augmentation via our Gibbs sampler. Overparameterized deep networks tend to benefit from distillation more than the tree models in our experiments (although for numerous datasets, distilled tree models are still Selected as the best model to predict with). While neural nets trained in the standard fashion are usually less accurate than trees for tabular data, FAST-DAD can boost their performance above that of trees, a goal other research has struggled to reach (Biau et al., 2019; Saberian et al., 2019; Popov et al., 2019; Ke et al., 2019a, b).

Figs. 1, 4, and S1 depict the (normalized/raw) accuracy and inference latency of our distilled models (under the GIB-1 strategy, which is superior to the others), compared with both the teacher (AutoGluon ensemble) and ensembles produced by H2O-AutoML (Pandey, 2019) and AutoSklearn (Feurer et al., 2019), two popular AutoML frameworks that have been shown to outperform other AutoML tools (Truong et al., 2019; Guyon et al., 2019). On average, the Selected individual model under standard training (BASE) would be outperformed by these AutoML ensembles, but surprisingly, our distillation approach produces Selected individual models that are both more accurate and over 10× more efficient than H2O and AutoSklearn. In multiclass classification, our distillation approach also confers significant accuracy gains over standard training. The resulting individual Selected models come close to matching the accuracy of H2O/AutoSklearn while offering much lower latency, but gains may be limited since the AutoGluon teacher appears only marginally more accurate than H2O/AutoSklearn in these multiclass problems.

Figure 4: Raw test accuracy vs. speed of individual models and AutoML ensembles, averaged over binary classification datasets. TEACHER denotes the performance of AutoGluon; H2O and autosklearn represent the respective AutoML tools. GIB-1 indicates the results of FAST-DAD after 1 round of Gibbs sampling. BASE denotes the student model fit on the original data. GIB-1/BASE dots represent the Selected model.

6 Discussion

Our goal in this paper is to build small, fast models that can bootstrap off large, ensemble-based AutoML predictors via model distillation to perform better than they would if directly fit to the original data. The key challenge arises when the data to train the student are limited. We propose to estimate the conditional distributions of all features via maximum pseudolikelihood with masked self-attention, and Gibbs sampling techniques to sample from this model for augmenting the data available to the student. Our strategy neatly suits the application because it: (i) avoids multivariate density estimation (pseudolikelihood only involves univariate conditionals), (ii) does not require separate models for each conditional (the self-attention model simultaneously computes all conditionals), (iii) is far more efficient than usual MCMC methods (by initializing the Gibbs sampler at the training data), (iv) allows control over the quality and diversity of the resulting augmented dataset (we can select samples from specific Gibbs rounds unlike from, say, a GAN). We used the high-accuracy ensembles produced by AutoGluon’s AutoML to improve standard boosted trees, random forests, and neural networks via distillation.
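The augmentation loop described above can be sketched as follows. This is a minimal illustration under a loudly stated assumption: the self-attention pseudolikelihood model is swapped out for a full-covariance Gaussian (`fit_gaussian`, a hypothetical stand-in), whose univariate conditionals have a closed form. What the sketch preserves from the text is the sampling scheme: one Markov chain per training row, initialized at that row, with one round of Gibbs sampling resampling every feature once from its conditional.

```python
import numpy as np


def fit_gaussian(X):
    """Fit a full-covariance Gaussian: a stand-in for the self-attention model."""
    mu = X.mean(axis=0)
    cov = np.cov(X, rowvar=False) + 1e-6 * np.eye(X.shape[1])
    return mu, cov


def gaussian_conditional(mu, cov, x, j):
    """Mean and variance of feature j given the other features of row x."""
    rest = np.arange(len(mu)) != j
    s_rest = cov[np.ix_(rest, rest)]
    s_j_rest = cov[j, rest]
    w = np.linalg.solve(s_rest, s_j_rest)  # regression weights of x_j on x_rest
    mean = mu[j] + w @ (x[rest] - mu[rest])
    var = cov[j, j] - w @ s_j_rest
    return mean, max(var, 1e-12)


def gibbs_augment(X, model, n_rounds=1, seed=0):
    """Run n_rounds of systematic-scan Gibbs sampling, one chain per training row."""
    rng = np.random.default_rng(seed)
    mu, cov = model
    samples = np.array(X, dtype=float)  # chains are initialized at the training data
    n, d = samples.shape
    for _ in range(n_rounds):
        for i in range(n):
            for j in range(d):  # one round = one Gibbs step per conditional
                mean, var = gaussian_conditional(mu, cov, samples[i], j)
                samples[i, j] = rng.normal(mean, np.sqrt(var))
    return samples
```

In a distillation pipeline, the returned rows would then be labeled by the teacher and combined with the original training data to fit the student; `n_rounds` plays the role of the Gibbs round (e.g. GIB-1 corresponds to `n_rounds=1`).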


  • J. Ba and R. Caruana (2014) Do deep nets really need to be deep?. In Advances in Neural Information Processing Systems, Z. Ghahramani, M. Welling, C. Cortes, N. D. Lawrence, and K. Q. Weinberger (Eds.), pp. 2654–2662. Cited by: §2.
  • S. Bansal (2018) Data science trends on Kaggle. Cited by: §5.
  • J. Baxter (2000) A model of inductive bias learning. Journal of Artificial Intelligence Research 12, pp. 149–198. Cited by: §4.2.
  • J. Besag (1977) Efficiency of pseudolikelihood estimation for simple gaussian fields. Biometrika, pp. 616–618. Cited by: §4.1, §4.1.
  • G. Biau, E. Scornet, and J. Welbl (2019) Neural random forests. Sankhya A 81 (2), pp. 347–386. Cited by: §5.3.
  • C. M. Bishop (1994) Mixture density networks. Neural Computing Research Group Report, Aston University. Cited by: §4.1.
  • L. Breiman and N. Shang (1996) Born again trees. Cited by: §2.
  • G. W. Brier (1950) Verification of forecasts expressed in terms of probability. Monthly weather review 78 (1), pp. 1–3. Cited by: §4.3.
  • C. Bucilua, R. Caruana, and A. Niculescu-Mizil (2006) Model compression. In Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining, pp. 535–541. Cited by: item 1, §1, §2, §4.3, §5.1, §5.3.
  • J. H. Cho and B. Hariharan (2019) On the efficacy of knowledge distillation. In Proceedings of the IEEE International Conference on Computer Vision, pp. 4794–4802. Cited by: item 1, §2.
  • C. Cortes, X. Gonzalvo, V. Kuznetsov, M. Mohri, and S. Yang (2017) AdaNet: adaptive structural learning of artificial neural networks. In Proceedings of the 34th International Conference on Machine Learning, Vol. 70, pp. 874–883. Cited by: §1.
  • J. Devlin, M. Chang, K. Lee, and K. Toutanova (2019) BERT: pre-training of deep bidirectional transformers for language understanding. In NAACL HLT, Cited by: §4.1, §4.1.
  • T. G. Dietterich (2000) Ensemble methods in machine learning. In International Workshop on Multiple Classifier Systems, pp. 1–15. Cited by: §1.
  • N. Erickson, J. Mueller, A. Shirkov, H. Zhang, P. Larroy, M. Li, and A. Smola (2020) AutoGluon-Tabular: robust and accurate AutoML for structured data. arXiv preprint arXiv:2003.06505. Cited by: §1, §1, §5, §5.
  • R. Fakoor, P. Chaudhari, J. Mueller, and A. J. Smola (2020) TraDE: transformers for density estimation. arXiv preprint arXiv:2004.02441. Cited by: §4.1.
  • M. Feurer, A. Klein, K. Eggensperger, J. T. Springenberg, M. Blum, and F. Hutter (2019) Auto-sklearn: efficient and robust automated machine learning. In Automated Machine Learning, pp. 113–134. Cited by: §1, §1, §5.3.
  • S. Geman and D. Geman (1984) Stochastic relaxation, gibbs distributions, and the bayesian restoration of images. IEEE Transactions on pattern analysis and machine intelligence (6), pp. 721–741. Cited by: item 2.
  • P. Gijsbers, E. LeDell, J. Thomas, S. Poirier, B. Bischl, and J. Vanschoren (2019) An open source AutoML benchmark. In ICML Workshop on Automated Machine Learning, Cited by: Table S2, §5.
  • A. Gretton, K. M. Borgwardt, M. J. Rasch, B. Schölkopf, and A. Smola (2012) A kernel two-sample test. Journal of Machine Learning Research 13 (Mar), pp. 723–773. Cited by: Figure 3.
  • I. Guyon, L. Sun-Hosoya, M. Boullé, H. J. Escalante, S. Escalera, Z. Liu, D. Jajetic, B. Ray, M. Saeed, M. Sebag, et al. (2019) Analysis of the AutoML challenge series 2015–2018. In Automated Machine Learning, Springer series on Challenges in Machine Learning, pp. 177–219. Cited by: §5.3.
  • G. E. Hinton (2002) Training products of experts by minimizing contrastive divergence. Neural Computation 14 (8), pp. 1771–1800. Cited by: §4.2.
  • G. Hinton, O. Vinyals, and J. Dean (2015) Distilling the knowledge in a neural network. NIPS Deep Learning and Representation Learning Workshop. Cited by: item 1, §1, §1, §2, §3, §4.3, §5.1, §5.3.
  • S. Jain, G. Liu, J. Mueller, and D. Gifford (2020) Maximizing overall diversity for improved uncertainty estimates in deep ensembles. In AAAI, Cited by: §5.
  • G. Ke, Z. Xu, J. Zhang, J. Bian, and T. Liu (2019a) DeepGBM: a deep learning framework distilled by gbdt for online prediction tasks. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pp. 384–394. Cited by: §5.3.
  • G. Ke, J. Zhang, Z. Xu, J. Bian, and T. Liu (2019b) TabNN: a universal neural network solution for tabular data. Cited by: §5.3.
  • D. Kingma and J. Ba (2015) Adam: a method for stochastic optimization. In International Conference for Learning Representations, Cited by: §A.1.
  • B. Lakshminarayanan, A. Pritzel, and C. Blundell (2017) Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in neural information processing systems, pp. 6402–6413. Cited by: §5.
  • J. Lee, Y. Lee, J. Kim, A. R. Kosiorek, S. Choi, and Y. W. Teh (2018) Set transformer: a framework for attention-based permutation-invariant neural networks. arXiv preprint arXiv:1810.00825. Cited by: §4.1.
  • R. Liu, N. Fusi, and L. Mackey (2020) Teacher-student compression with generative adversarial networks. arXiv preprint arXiv:1812.02271. Cited by: §2, §5.1.
  • S. Mirzadeh, M. Farajtabar, A. Li, and H. Ghasemzadeh (2019) Improved knowledge distillation via teacher assistant: bridging the gap between student and teacher. arXiv preprint arXiv:1902.03393. Cited by: §2.
  • J. Mukhoti, P. Stenetorp, and Y. Gal (2018) On the importance of strong baselines in bayesian deep learning. In Bayesian Deep Learning NeurIPS 2019 Workshop, Cited by: §5.
  • C. Nash and C. Durkan (2019) Autoregressive energy machines. arXiv preprint arXiv:1904.05626. Cited by: Figure 2.
  • A. Owen (1987) Nonparametric conditional estimation. OSTI.GOV Technical Report. Cited by: §5.1.
  • P. Pandey (2019) A Deep Dive into H2O's AutoML. Cited by: §1, §5.3.
  • S. Popov, S. Morozov, and A. Babenko (2019) Neural oblivious decision ensembles for deep learning on tabular data. arXiv preprint arXiv:1909.06312. Cited by: §5.3.
  • G. O. Roberts and A. F. Smith (1994) Simple conditions for the convergence of the gibbs sampler and metropolis-hastings algorithms. Stochastic processes and their applications 49 (2), pp. 207–216. Cited by: §4.
  • M. Saberian, P. Delgado, and Y. Raimond (2019) Gradient boosted decision tree neural network. arXiv preprint arXiv:1910.09340. Cited by: §5.3.
  • M. Segal and Y. Xiao (2011) Multivariate random forests. Wiley Interdisciplinary Reviews: Data Mining and Knowledge Discovery 1 (1), pp. 80–87. Cited by: §4.3.
  • M. Sugiyama, T. Suzuki, and T. Kanamori (2012) Density ratio estimation in machine learning. Cambridge University Press. Cited by: §4.
  • J. Tang, R. Shivanna, Z. Zhao, D. Lin, A. Singh, E. H. Chi, and S. Jain (2020) Understanding and improving knowledge distillation. arXiv preprint arXiv:2002.03532. Cited by: §1.
  • A. Truong, A. Walters, J. Goodsitt, K. Hines, B. Bruss, and R. Farivar (2019) Towards automated machine learning: evaluation and comparison of automl approaches and tools. arXiv preprint arXiv:1908.05557. Cited by: §5.3, §5.
  • G. Urban, K. J. Geras, S. E. Kahou, O. Aslan, S. Wang, R. Caruana, A. Mohamed, M. Philipose, and M. Richardson (2017) Do deep convolutional nets really need to be deep and convolutional?. In International Conference on Learning Representations, Cited by: §2.
  • B. Uria, I. Murray, and H. Larochelle (2013) RNADE: the real-valued neural autoregressive density-estimator. In Advances in Neural Information Processing Systems, Cited by: §A.1, §4.1.
  • V. Vapnik (1998) Statistical learning theory. John Wiley & Sons. Cited by: §3, §3.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin (2017) Attention is all you need. In Advances in Neural Information Processing Systems, Cited by: §A.1, §A.1, §4.1, §4.1.
  • T. Vidal, T. Pacheco, and M. Schiffer (2020) Born-again tree ensembles. arXiv preprint arXiv:2003.11132. Cited by: item 1, §2.
  • A. Wang and K. Cho (2019) BERT has a mouth, and it must speak: BERT as a Markov random field language model. In NAACL HLT, Cited by: §4.1.
  • N. Wang, L. Wu, et al. (2014) Convergence rate and concentration inequalities for gibbs sampling in high dimension. Bernoulli 20 (4), pp. 1698–1716. Cited by: §4.2.
  • H. Wilbraham (1848) On a certain periodic function. The Cambridge and Dublin Mathematical Journal 3, pp. 198–201. Cited by: §3.
  • L. Xu, M. Skoularidou, A. Cuesta-Infante, and K. Veeramachaneni (2019) Modeling tabular data using conditional gan. In Advances in Neural Information Processing Systems, Cited by: §5.1.
  • J. Yang, B. Martinez, A. Bulat, and G. Tzimiropoulos (2020) Knowledge distillation via adaptive instance normalization. arXiv preprint arXiv:2003.04289. Cited by: §2.

Appendix A Methods Details

We not only adopt the AutoGluon predictor as our teacher for distillation, but our experiments also use the AutoGluon implementation of each individual model type (NN, RF, LightGBM, CatBoost) for our student/BASE predictors. (Although we selected AutoGluon as the AutoML tool for this paper's experiments, we emphasize that none of our distillation methodology is specific to AutoGluon teachers/students.) Here we use the same data preprocessing and hyperparameters that AutoGluon uses by default, which have been demonstrated to be highly performant (Erickson et al., 2020).

Unlike the RF/LightGBM/CatBoost models, which are implemented in popular third-party packages, the NN model is implemented directly in AutoGluon and offers numerous advantages for tabular data over standard feedforward architectures (Erickson et al., 2020). This network uses a separate embedding layer for each categorical feature, which helps the network learn about each of these variables separately before their representations are blended together by fully-connected layers (guo2016entity; fastai). The network employs skip connections for improved gradient flow, with both shallow and deep paths connecting the input to the output (cheng2016wide).

Note that all of our student classifiers produce valid predicted probabilities: our neural network student employs a sigmoid output layer to constrain its outputs to [0, 1] in binary classification, and the random forest multiclass students never output negative values (these models do not extrapolate), so we can simply re-normalize their output vectors to sum to one.

a.1 Architecture of our Pseudolikelihood Self-Attention Model

The input layer of our self-attention network applies a linear embedding operation followed by positional encoding. Each internal layer of the network is a Transformer block, which includes two sub-blocks: a multi-head self-attention mechanism and a position-wise fully connected feedforward block (Vaswani et al., 2017). Each of these sub-blocks is wrapped with layer normalization and a residual connection. Here different positions correspond to different features (columns of the table). The output layer of this network produces a mixture of multivariate Gaussians with diagonal covariance: for each feature, the final position-wise feedforward block outputs both the mean and variance of each Gaussian component and the mixing weights of the components. To ensure that all input features are on a similar scale, all features are rescaled to zero mean and unit variance before being fed into our network (and we apply the inverse transform after Gibbs sampling).
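The training objective implied by this output layer can be sketched as a per-feature Gaussian-mixture negative log-likelihood, summed over features (the negative log-pseudolikelihood). This is a NumPy sketch, not the authors' code: the array layout `(n, d, K)` for `n` rows, `d` features and `K` components, and the parameterization via `logits` and `log_vars`, are our assumptions. In the actual model, the outputs for feature j would be computed with feature j masked from the attention input.

```python
import numpy as np


def logsumexp(a, axis=-1):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def mixture_nll(x, logits, means, log_vars):
    """Mean negative log-pseudolikelihood under per-feature Gaussian mixtures.

    x: (n, d) standardized features.
    logits, means, log_vars: (n, d, K) network outputs for each feature's
    conditional p(x_j | x_-j), assumed computed with x_j masked out.
    """
    log_pi = logits - logsumexp(logits, axis=-1)[..., None]  # log mixing weights
    x = x[..., None]
    log_normal = -0.5 * (np.log(2 * np.pi) + log_vars + (x - means) ** 2 / np.exp(log_vars))
    log_cond = logsumexp(log_pi + log_normal, axis=-1)  # (n, d): log p(x_j | x_-j)
    return -log_cond.sum(axis=1).mean()
```

With one standard-normal component per feature and all-zero inputs, each of the d features contributes 0.5 log(2π), which is a convenient sanity check.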

Positional encoding is essential for the model to know which value was taken by which feature. Without it, self-attention treats its inputs as an unordered set, so, for example, two rows that differ only by swapping the values of their first two features would produce the same self-attention input when predicting the third feature. The representations of our model would then suffer, as would its estimated conditional distributions. Here we employ the same sin/cos positional encodings used by Vaswani et al. (2017), treating the table column index of each feature analogously to word positions in a sentence.
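The sin/cos encoding of Vaswani et al. (2017), applied to column indices, can be sketched as below. The helper `embed_row` is hypothetical: it assumes the linear embedding maps each scalar feature value with a single shared weight vector `W` and bias `b`, which is one plausible reading of "linear embedding" rather than a detail confirmed by the text. The test illustrates the point made above: without positional encoding, swapping two feature values leaves the set of token embeddings unchanged.

```python
import numpy as np


def sinusoidal_encoding(n_positions, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(same angle)."""
    if d_model % 2:
        raise ValueError("d_model must be even")
    pos = np.arange(n_positions)[:, None]
    two_i = np.arange(0, d_model, 2)[None, :]
    angles = pos / np.power(10000.0, two_i / d_model)
    pe = np.zeros((n_positions, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe


def embed_row(x, W, b, pe):
    """Hypothetical: embed each scalar feature x_j as x_j * W + b, then add PE[j]."""
    return x[:, None] * W[None, :] + b[None, :] + pe[: len(x)]
```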

Tabular data can contain both numerical and categorical features. To have a simple, unified model that handles both feature types, we represent categorical features numerically using dequantization (Uria et al., 2013). This involves adding uniform noise to the ordered integer encoding of the categories to make these features look numerical to our network. The noise can be inverted via rounding to ensure that discrete categories are produced by our Gibbs sampler (i.e. re-quantization). Dequantization has been successfully employed in a number of deep architectures that otherwise operate on continuous data (hoogeboom2020learning; theis2016note; ma2019macow), and allows us to avoid heterogeneous output layers and unwieldy one-hot encodings.
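A minimal sketch of dequantization and re-quantization follows. The noise interval is our assumption: the text says only "uniform noise", and we center it on [-0.5, 0.5) so that rounding recovers the original integer code exactly. The clipping step in `requantize` is also an assumption, added because Gibbs samples of a categorical column may fall outside the valid code range.

```python
import numpy as np


def dequantize(codes, rng):
    """Integer category codes -> continuous values via U[-0.5, 0.5) noise (assumed interval)."""
    return codes.astype(float) + rng.uniform(-0.5, 0.5, size=codes.shape)


def requantize(values, n_categories):
    """Invert dequantization by rounding to the nearest code, clipped to the valid range."""
    codes = np.floor(values + 0.5).astype(int)  # round half up, avoiding banker's rounding
    return np.clip(codes, 0, n_categories - 1)
```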

Table S1 shows the network hyperparameters used for the experiments in this paper. We did not conduct any hyperparameter search to find the best-performing architectures. Instead, we simply use two networks: Small and Large. The Small network is used whenever the training dataset has fewer than a fixed threshold number of examples, and the Large network is used otherwise. The two differ only in batch size and hidden-layer width; all other details, such as the training procedure, regularization, and evaluation protocol, are the same. We use two models to avoid overfitting small datasets, and the Small network can also be trained more efficiently. We use Adam to optimize the parameters of our network (Kingma and Ba, 2015).

Hyper-parameter Small Large
Gaussian mixture components 100 100
Number of layers 4 4
Multi-head attention heads 8 8
Hidden unit size 32 128
Mini-batch size 16 256
Dropout 0.1 0.1
Learning rate 3E-4 3E-4
Weight decay 1E-6 1E-6
Gradient clipping norm 5 5
Table S1: Hyper-parameters of our self-attention models.
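The two configurations in Table S1 can be written down as a small config object. The values are copied from the table; the class name, field names and the `select_config` helper are hypothetical, and `threshold` is a required argument precisely because the cutoff between Small and Large is not stated in this section.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttentionConfig:
    """Hyper-parameters from Table S1 (field names are our own)."""
    mixture_components: int = 100
    n_layers: int = 4
    n_heads: int = 8
    hidden_size: int = 32
    batch_size: int = 16
    dropout: float = 0.1
    learning_rate: float = 3e-4
    weight_decay: float = 1e-6
    grad_clip_norm: float = 5.0


SMALL = AttentionConfig()
# Large differs from Small only in hidden-layer width and batch size.
LARGE = AttentionConfig(hidden_size=128, batch_size=256)


def select_config(n_train: int, threshold: int) -> AttentionConfig:
    """Pick Small for small training sets; `threshold` is a hypothetical cutoff."""
    return SMALL if n_train < threshold else LARGE
```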
Dataset Type Sample Size # Columns # Classes
amazon binary 32769 9 -
australian binary 690 14 -
miniboone binary 130064 50 -
adult binary 48842 14 -
blood binary 748 4 -
credit-g binary 1000 20 -
higgs binary 98050 28 -
jasmine binary 2984 144 -
nomao binary 34465 118 -
numerai28.6 binary 96320 21 -
phoneme binary 5404 5 -
sylvine binary 5124 20 -
covertype multiclass 581012 54 7
helena multiclass 65196 27 100
jannis multiclass 83733 54 4
volkert multiclass 58310 180 10
connect-4 multiclass 67557 42 3
jungle-chess multiclass 44819 6 3
mfeat-factors multiclass 2000 216 10
segment multiclass 2310 19 7
vehicle multiclass 846 18 4
boston regression 506 13 -
concrete regression 1030 8 -
energy regression 768 8 -
kin8nm regression 8192 8 -
naval regression 11934 16 -
power regression 9568 4 -
protein regression 45730 9 -
wine regression 1599 11 -
yacht regression 308 6 -
Table S2: Summary of the 30 datasets considered in this work, listing the type of prediction problem, the size of the data table, and the number of classes for multiclass classification problems. The regression data (along with provided train/test splits) were downloaded from: The classification data (with provided train/test splits) were downloaded from: We initially considered additional classification datasets from Gijsbers et al. (2019), but decided not to include those for which: it was trivial to get near 100% accuracy for many model types (so a teacher is unnecessary), the data are dominated by missing values, the original data are extremely high-dimensional (), or the original data did not come from a table (e.g. Fashion-MNIST).

Appendix B Experiment Details

We implemented knowledge distillation (KNOW) with classification targets modified as suggested by Hinton et al. (2015). As suggested by Bucilua et al. (2006), the distance metric in MUNGE is taken to be the Euclidean distance between (rescaled) numerical features and the Hamming distance between categorical features. Over all datasets, we performed a grid search over MUNGE's user-specified parameters, the feature-resampling probability p and the local variance parameter s, choosing the values that maximize the validation accuracy of the student over the grid. For the conditional tabular GAN, we used the original implementation available at:
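For reference, the MUNGE procedure of Bucilua et al. (2006) can be sketched as follows. This simplified version assumes purely numeric features (so it omits the Hamming-distance and value-swap handling of categorical features mentioned above) and computes each row's nearest neighbour once on the original data; both simplifications are ours. For each row and its neighbour, each attribute is perturbed with probability `p`: each value is redrawn from a normal centered at the other's value, with standard deviation equal to their absolute difference divided by `s`.

```python
import numpy as np


def munge(X, p=0.5, s=1.0, k=1, seed=0):
    """Simplified MUNGE for numeric features; returns k * n synthetic rows."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    n, d = X.shape
    # Nearest neighbour of each row under Euclidean distance, excluding itself.
    dist = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
    np.fill_diagonal(dist, np.inf)
    nn = dist.argmin(axis=1)
    out = []
    for _ in range(k):
        T = X.copy()
        for i in range(n):
            j = nn[i]
            for a in range(d):
                if rng.random() < p:  # perturb attribute a with probability p
                    ei, ej = T[i, a], T[j, a]
                    sd = abs(ei - ej) / s
                    T[i, a] = rng.normal(ej, sd)
                    T[j, a] = rng.normal(ei, sd)
        out.append(T)
    return np.vstack(out)
```

As with the Gibbs samples, the synthetic rows would then be labeled by the teacher before being used to train the student.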

On each dataset, we trained AutoGluon for up to 4 hours and specified the same time limit for H2O-AutoML and AutoSklearn. When running H2O and AutoSklearn on the 30 datasets, each AutoML tool failed to produce predictions on 2 datasets; in these cases we simply recorded the accuracy/latency achieved by the other tool (such failures are common in AutoML benchmarking; cf. Gijsbers et al., 2019; Erickson et al., 2020). Each AutoML tool was run with all default arguments, except for AutoGluon, for which we additionally set the argument auto_stack = True, instructing the system to maximize accuracy at all costs via extensive stack ensembling. We used the same type of AWS EC2 instance (m5.2xlarge) for each predictor to ensure a fair comparison of inference times (each tool was run on a separate EC2 instance with no other running processes).

For evaluating our Gibbs samples, we computed the Maximum Mean Discrepancy with a mixture kernel (li2017mmd) over a set of bandwidths. Our procedure to measure sample fidelity involved the following steps. First, we trained our FAST-DAD-Net to maximize pseudolikelihood over data in the training fold. Next, we applied Gibbs sampling to generate synthetic samples (initializing the Markov chains at the training data as previously described). We then assembled a balanced dataset of real (held-out) data from our validation fold, labeled as real, and fake data comprising Gibbs samples, labeled as fake. A random forest was trained on this dataset, and its accuracy was then evaluated on another balanced dataset comprising real data from our test fold (again labeled as real) and fake data comprising a different set of Gibbs samples (labeled as fake). The resulting 'sample fidelity' was defined as the distance between this RF accuracy and 0.5.
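Both sample-quality measures can be sketched in a few lines. The MMD function uses a biased estimator with a sum of RBF kernels; the default bandwidths are placeholders, since the values used in the paper are not given here. The fidelity function follows the steps above using scikit-learn's `RandomForestClassifier`; the 1/0 label convention and the forest settings are our assumptions (the score |accuracy − 0.5| does not depend on which class is called 1).

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier


def mixture_rbf_mmd2(X, Y, bandwidths=(0.5, 1.0, 2.0)):
    """Biased MMD^2 estimate with a sum-of-RBF kernel (placeholder bandwidths)."""
    Z = np.vstack([X, Y])
    sq = ((Z[:, None, :] - Z[None, :, :]) ** 2).sum(axis=-1)
    K = sum(np.exp(-sq / (2.0 * b ** 2)) for b in bandwidths)
    n = len(X)
    return K[:n, :n].mean() + K[n:, n:].mean() - 2.0 * K[:n, n:].mean()


def sample_fidelity(real_val, fake_a, real_test, fake_b, seed=0):
    """|RF accuracy - 0.5| for separating real rows (label 1) from Gibbs samples (label 0)."""
    X_tr = np.vstack([real_val, fake_a])
    y_tr = np.r_[np.ones(len(real_val)), np.zeros(len(fake_a))]
    X_te = np.vstack([real_test, fake_b])
    y_te = np.r_[np.ones(len(real_test)), np.zeros(len(fake_b))]
    clf = RandomForestClassifier(n_estimators=100, random_state=seed).fit(X_tr, y_tr)
    return abs(clf.score(X_te, y_te) - 0.5)
```

A fidelity near 0 means the forest cannot tell Gibbs samples from held-out real data; a value near 0.5 means the samples are trivially distinguishable.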

Appendix C Additional Results

(A) Regression     (B) Binary Classification     (C) Multiclass Classification
Figure S1: Test accuracy vs. latency of individual models and AutoML ensembles, averaged over the: (A) regression datasets, (B) binary classification datasets, (C) multiclass classification datasets. The GIB-1 and BASE dots show the performance of the Selected model (out of the 4 types) on each dataset. Note that for binary classification, the Selected BASE models are actually worse than individual RF/LightGBM models, presumably due to overfitting of the validation set via the early-stopping criterion in NN/CatBoost. This issue appears to be mitigated by distillation with augmented data. In multiclass classification, the distilled LightGBM models exhibit worse latency than their BASE counterparts because distillation uses additional data and soft (probabilistic) labels as targets, so the underlying function to learn becomes more complex. As a result, its trees grow deeper, since LightGBM does not limit tree depth by default. The BASE/distilled latency could easily be matched by restrictively setting the LightGBM depth/leaf-size hyperparameters to ensure equal-sized trees in these two variants.
boston 91.84 90.11 90.25 91.54 92.02 92.38 93.21 92.62 92.09
concrete 92.20 92.66 92.07 92.31 92.33 92.39 92.83 92.56 92.82
energy 99.86 99.85 99.91 99.92 99.87 99.93 99.92 99.92 99.93
kin8nm 93.36 93.58 94.10 93.82 94.08 93.96 94.14 94.10 93.99
naval 99.74 99.75 99.81 99.78 99.49 99.70 99.68 99.71 99.97
power 96.62 96.97 96.60 96.86 96.07 96.61 96.62 96.51 97.15
protein 68.34 67.37 69.95 68.14 67.64 69.96 69.07 68.01 74.33
wine 56.44 56.53 57.38 58.61 56.42 59.27 57.80 58.49 60.74
yacht 99.24 99.55 99.88 99.92 99.93 99.94 99.94 99.90 99.87
amazon 94.84 94.81 94.72 94.87 94.69 94.90 94.81 94.81 94.96
australian 86.95 88.40 85.50 85.50 85.50 88.40 86.95 85.50 86.95
miniboone 94.50 94.71 94.77 94.40 94.44 94.86 94.44 94.64 94.88
adult 87.06 87.26 87.36 87.49 86.81 86.36 86.67 86.73 87.59
blood 73.33 77.33 77.33 77.33 76.0 76.0 78.66 77.33 76.0
credit-g 71.0 78.0 78.0 76.0 75.0 80.0 80.0 77.0 79.0
higgs 72.14 73.14 73.53 72.83 72.48 73.89 73.36 73.18 73.83
jasmine 82.94 81.93 80.26 81.93 81.93 82.27 81.27 81.93 82.60
nomao 97.30 96.98 96.72 97.15 96.98 96.77 96.83 96.86 98.20
numerai28.6 51.12 50.36 51.78 50.78 51.23 52.05 51.11 51.59 51.10
phoneme 89.46 90.38 90.57 90.20 90.57 91.49 90.38 90.75 92.42
sylvine 93.56 93.37 94.15 94.34 92.39 93.56 93.95 94.54 95.32
covertype 95.90 96.99 97.00 92.84 96.39 96.19 96.48 96.06 97.66
helena 38.29 40.70 40.26 39.44 39.43 40.50 39.84 40.50 40.75
jannis 70.69 72.23 72.13 70.69 71.75 72.43 72.28 71.91 73.07
volkert 69.62 71.34 72.18 69.28 70.60 71.42 70.70 70.45 74.46
connect-4 84.87 86.10 86.27 84.35 85.90 86.19 86.44 86.38 86.04
jungle-chess 87.59 91.78 92.92 89.53 96.07 93.37 94.42 93.81 99.55
mfeat-factors 98.0 97.0 97.5 97.5 98.0 98.5 98.5 98.5 98.0
segment 98.70 98.70 98.70 98.70 98.70 99.13 99.13 99.13 99.13
vehicle 83.52 77.64 88.23 87.05 83.52 88.23 87.05 87.05 85.88
Table S3: Raw test accuracy (or percent of variation explained, for regression) of the Selected best individual model (across all 4 model types, chosen based on validation performance) under various training/distillation strategies. The final column shows the performance of the ensemble predictor produced by AutoGluon (used as the teacher in distillation). Datasets are grouped by task: regression (first 9 rows), binary classification (next 12 rows), and multiclass classification (last 9 rows).
      (A) Augmentation w Real Data       (B) Augmentation w GIB-1
Figure S2: Distillation performance when augmented data are: (A) additional real data points from the true underlying distribution, (B) synthetic examples obtained from 1 round of our Gibbs sampling procedure. Here we report average normalized test accuracy ($R^2$) over the 3 largest regression datasets, with corresponding standard errors indicated by vertical lines (our normalization rescales each student's $R^2$ by the teacher's $R^2$ on each dataset). To obtain additional real data points for augmentation, we did the following: only 20% of the original training set was adopted as the training data (accuracies obtained from this training data are shown at 0 on the x-axis). The remaining 80% of the data was held out, treated as unlabeled, and used as augmented data for distillation (in increasing multiples of the training sample size), following the same distillation procedure described in the main text. The GIB-1 results are obtained by applying our FAST-DAD distillation procedure with only the 20% training data (the 80% held-out data are entirely ignored, so our self-attention pseudolikelihood model is fit to relatively little data). The AutoGluon teacher is also fit only to the same 20% of the training data. Panel (A) empirically validates Lemma 3, showing that distillation becomes much more powerful with additional unlabeled data from the true feature distribution. Distillation gains produced by augmenting with Gibbs samples do not match those obtained by augmenting with real data, suggesting that better generative models may further reduce this gap.
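The bookkeeping behind this experiment can be sketched as follows. The function names are hypothetical, and capping each augmentation size at the number of held-out rows is our own assumption: with a 20/80 split, real data only suffice for up to 4 multiples of the training size.

```python
import numpy as np


def split_labeled_unlabeled(n, frac_labeled=0.2, seed=0):
    """Randomly keep frac_labeled of the rows for training; the rest act as unlabeled data."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    n_lab = int(round(frac_labeled * n))
    return perm[:n_lab], perm[n_lab:]


def augmentation_sizes(n_lab, n_unlab, multiples):
    """Held-out rows used at each multiple of the training size (capped by availability)."""
    return [min(m * n_lab, n_unlab) for m in multiples]


def normalized_score(r2_student, r2_teacher):
    """Student R^2 rescaled by the teacher's R^2 on the same dataset."""
    return r2_student / r2_teacher
```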

Appendix D Proof of Fig. 2

Here we discuss our refinement of Lemma 3, which formally describes how the number of steps of Gibbs sampling affects the distillation of the student. Lemma 3 suggests that if we learn a probability distribution from the training data, we might be able to reduce the variance term in the VC bound at the cost of a bias. We now characterize the situation in which the Gibbs sampler, with steady-state distribution $\pi$, is initialized at samples from the true data distribution $p$, namely the original training dataset $D$, and is run for $k$ steps. Intuitively, if $k$ is large, the sampler provides data that differ substantially from $D$, which leads to stronger variance reduction. However, these samples are not drawn from $p$, so the teacher suffers a covariate shift on them, which leads to poor fitting of the student. This suggests there should be a sweet spot: the number of Gibbs sampling steps should be large enough to reduce variance, but not so large as to cause a large covariate shift/bias. We capture this phenomenon in the following theorem. For simplicity, we consider only the special case where the augmented dataset has the same size as the original training dataset. Our proof can be generalized to larger augmented datasets, but the details of the underlying symmetrization argument are more intricate (see comments in the proof). We stick to this special case to elucidate the main point. The full theorem statement is repeated here for completeness.

[Refinement of Lemma 1] Under the assumptions of Lemma 1, suppose that the student is chosen to minimize where are samples drawn after running the Gibbs sampler initialized at samples from for -steps. Then there exist constants and such that with probability at least we have


The bound depends on the total-variation distance $\| p - p_k \|_{TV}$ between the true data distribution $p$ and the distribution $p_k$ of the sampler's iterates after $k$ steps. The steady-state distribution of the Gibbs sampler is denoted by $\pi$.


Let $\pi$ be the steady-state distribution of the Gibbs sampler, and let $T$ be a linear operator that denotes the one-step transition kernel. Under general conditions (robert2013monte), the distribution of the iterates of the sampler converges to this steady-state distribution as $k \to \infty$, i.e.,

$$\lim_{k \to \infty} \| T^k p_0 - \pi \|_{TV} = 0$$

from any initial distribution $p_0$. Explicit rates are available for this convergence (Wang et al., 2014): there exist constants $C > 0$ and $\rho \in (0, 1)$ such that

$$\| T^k p_0 - \pi \|_{TV} \le C \rho^k, \qquad \| \mu - \nu \|_{TV} = \sup_{A \in \mathcal{B}} | \mu(A) - \nu(A) |,$$

where $\| \cdot \|_{TV}$ denotes the total-variation norm and the set $\mathcal{B}$ is the Borel σ-algebra of the domain $\mathcal{X}$. These rates are sharp for some parametric models (diaconis2010gibbs). We use the shorthand $p_k = T^k p$ to denote the density obtained after applying the one-step transition kernel $k$ times to the true data distribution $p$.
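The interplay between these quantities can be checked exactly on a toy example. In this sketch (our own illustration, not part of the proof), the "true" distribution `p_true` and the sampler's stationary distribution `q_model` are hand-picked 2x2 joint distributions over two binary features, and one systematic-scan Gibbs round resamples x1 given x2, then x2 given x1, under `q_model`. Starting from `p_true`, the distance to the stationary distribution never increases, while the distance from the true distribution starts at 0 and approaches ‖p − π‖_TV, which is the bias the bound trades off against variance reduction.

```python
import numpy as np


def gibbs_round_kernel(q):
    """One Gibbs round (x1 | x2, then x2 | x1) targeting the 2-D discrete distribution q."""
    cond1 = q / q.sum(axis=0, keepdims=True)  # q(x1 | x2): each column sums to 1
    cond2 = q / q.sum(axis=1, keepdims=True)  # q(x2 | x1): each row sums to 1

    def step(p):
        p = cond1 * p.sum(axis=0)[None, :]    # keep x2's marginal, resample x1
        return cond2 * p.sum(axis=1)[:, None]  # keep x1's marginal, resample x2

    return step


def tv(p, r):
    """Total-variation distance between two discrete distributions."""
    return 0.5 * np.abs(p - r).sum()


p_true = np.array([[0.30, 0.10], [0.15, 0.45]])   # hypothetical data distribution p
q_model = np.array([[0.25, 0.15], [0.15, 0.45]])  # hypothetical stationary distribution pi
step = gibbs_round_kernel(q_model)

dist = [p_true]  # p_0 = p, then p_k after k rounds
for _ in range(10):
    dist.append(step(dist[-1]))

delta = [tv(p_true, d) for d in dist]  # ||p - p_k||_TV
to_q = [tv(d, q_model) for d in dist]  # ||p_k - pi||_TV
```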

Suppose that the Gibbs sampler initialized at the training dataset $D$ (drawn from the true distribution $p$) runs for $k$ steps, and we then sample a dataset $D' = \{x'_1, \dots, x'_m\}$ of $m$ samples from the resultant distribution $p_k$, i.e., $x'_i \sim p_k$.

The samples in $D'$ are correlated with those already in $D$. The student is fit to the combined dataset, whose samples are neither independent ($D$ and $D'$ are not independent) nor identically distributed ($D \sim p$ while $D' \sim p_k$). Characterizing generalization performance is difficult in this scenario and requires strong assumptions (cf. dagan2019learning), but we can make the following helpful simplification.

The number of Gibbs steps $k$ is large enough for the original and augmented samples to be statistically independent. Note that this does not imply that the samples are identically distributed; they still come from the true distribution and from the distribution after $k$ steps, respectively. Since $k$ is the product of the number of rounds of Gibbs sampling and the dimensionality of the data, achieving approximate independence does not necessarily require a large number of Gibbs rounds.

We now employ a bound by Jonathan Baxter (Baxter, 2000) that studies the generalization performance of a model when it sees data from a mixture of two different, possibly correlated, distributions: the true distribution and the distribution after $k$ Gibbs steps. This is a uniform-convergence bound and follows via a two-step symmetrization argument, where the second step involves separate permutations of the samples in the original and augmented datasets. The same technique as that of Baxter (2000) also works if we draw more data from the new distribution than the original dataset contains. However, the details are intricate, and we stick to the equal-size special case to elucidate the main point.

For all functions , in particular for , the following holds with probability at least :


where $c$ is a constant. The quantity $\mathcal{N}$ is the ε-net covering number of the hypothesis class under a given metric (Baxter, 2000). According to Baxter's result, for our case with two tasks (the true distribution and the distribution after $k$ Gibbs steps), we are interested in computing the covering number of the product hypothesis class, with the metric between two functions in this class defined as

with the labels given by the teacher. Our hypothesis class on the two tasks is the Cartesian product of the hypothesis class with itself. Haussler's theorem (haussler1990probably) gives an upper bound on the covering number in terms of the VC-dimension:


where the bound is expressed through the VC-dimension (vapnik2015uniform) of the hypothesis class and a constant; the logarithm of the covering number is also called the metric entropy.

Observe that the left-hand side of the bound above can be written as


Let us define

We next analyze the metric where we note that again .

Similarly we also have

In other words, the distance on the Cartesian space can be upper bounded by the distance on the original space, up to an additive term that increases with the number of steps of Gibbs sampling.

Next observe that we have an upper bound on the metric entropy


if the two datasets are iid. If the datasets are not iid, then, using the calculation of the metric above, computing the size of the ε-net for the product class is effectively the same as changing the covering radius on the right-hand side of Eq. 11 to

Plugging the previous two expressions into Eq. 9 implies


The approximation above is valid if we additionally assume . We have thus shown that there exists a constant such that with probability at least :


The inequality follows because

since . Recall is the number of Gibbs sampling steps, is a constant, and . ∎

We provide some additional comments on this result. Note that the total-variation distance between the true distribution and the distribution after $k$ Gibbs steps is zero at $k = 0$ and approaches the distance between the true distribution and the sampler's stationary distribution as $k \to \infty$, so it grows with the number of Gibbs sampling steps $k$. We can draw a large number of samples from the sampler to reduce the second term in the bound. Using a large $k$ is both computationally inefficient and may also cause a bias, given by the additive total-variation term (the third term), if the stationary distribution of our Gibbs sampler poorly approximates the true distribution. As our pseudolikelihood model is fit to limited data in practice, it is thus better to draw a large number of samples from early steps, i.e. to use only a few steps of Gibbs sampling from each training datum instead of running a long chain. Among all values of $k$ that produce samples approximately independent of the original training data, we would like to use the smallest.

The experiments in our paper empirically show that, on average over many datasets, running the Gibbs sampler for 1–5 rounds (one round involves performing a Gibbs step for every conditional in the pseudolikelihood) works better than running it for longer. Note that if we employ fewer steps than even a single round of Gibbs sampling, the augmented data will be highly dependent on the training data, as some features will not have been resampled, thus diminishing the effective sample size of the student's distillation dataset. It is also readily seen from the above bound that if the Gibbs sampler is initialized at a distribution other than the true data distribution, we would need a large number of steps before the bias term is adequately small.
