Traditionally, in supervised ML pipelines, the data used to train a model is divided into two sets, the training set and the validation set. The former is used to train various models, while the latter is used for ranking and selecting the best-performing one, i.e., the best architecture and hyper-parameters. Eventually, a final model with the selected configuration is trained on the entire data, including both the training and the validation sets. Test data is inaccessible to the training pipeline, especially for model selection. The training-validation split provides a means to estimate the models’ error rates, which can be used for ranking. However, it does not help in selecting among models that were trained on the entire data. Because the optimal hyper-parameters depend on the number of training samples, and because models sharing the exact same hyper-parameters may exhibit large variance in accuracy, this may eventually lead to selecting a sub-optimal model.
Recent advances in the quality of synthetic data generation pipelines (Karras2019stylegan2; karras2020training; peng2018visda) have reduced the synthetic-to-real domain gap enough to successfully apply the generated data for training deep models (besnier2020dataset). Other works have focused on analyzing and quantifying various characteristics of the domain gap (sajjadi2018assessing; kynkaanniemi2019improved). That said, to the best of our knowledge, the specific task of model selection with synthetic data was not addressed. When using synthetic data for training a model, one’s goal is to minimize the generalization gap w.r.t. the real domain (ben2010theory). Solving for generalization in presence of a synthetic-to-real domain gap is challenging. However, for model selection, one’s goal is to use synthetic data for ranking a set of pre-trained models, while requiring rank preservation in the real domain. In this work, we introduce a sufficient condition for cross-domain rank preservation and empirically validate its value for model selection.
2 Synthetic Data for Model Selection
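We follow notations similar to ben2010theory. A domain is defined as a pair consisting of a distribution $\mathcal{D} = \langle \mathcal{X}, p \rangle$, where $\mathcal{X}$ is the sample domain and $p$ is the probability density function, and a labeling function $f: \mathcal{X} \to \mathcal{Y}$, where $\mathcal{Y}$ represents possible classes. We consider a particular pair of domains, where one is the original real domain, denoted by $\langle \mathcal{D}_R, f_R \rangle$, and the second one is synthetic, denoted by $\langle \mathcal{D}_S, f_S \rangle$, specifically tuned to mimic the real one. A hypothesis (model) is a function $h: \mathcal{X} \to \mathcal{Y}$. The risk, or the probability that a hypothesis $h$ disagrees with a labeling function $f$ according to the distribution $\mathcal{D}$, is defined as: $\epsilon_{\mathcal{D}}(h, f) = \Pr_{x \sim \mathcal{D}}[h(x) \neq f(x)]$. We neglect the difference between $f_R$ and $f_S$ and use the shorthand notation $\epsilon_{\mathcal{D}}(h) = \epsilon_{\mathcal{D}}(h, f)$. Let $\Delta\epsilon_{\mathcal{D}}(h_1, h_2)$ denote the risk difference between two hypotheses, $h_1, h_2$, measured over a probability distribution $\mathcal{D}$, i.e., $\Delta\epsilon_{\mathcal{D}}(h_1, h_2) = \epsilon_{\mathcal{D}}(h_1) - \epsilon_{\mathcal{D}}(h_2)$.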
A common approach for model selection is the holdout method, where two datasets are sampled from the real distribution $\mathcal{D}_R$: the training set, $T$, and the validation set, $V$. A model (hypothesis) $h$ is trained using empirical risk minimization on $T$. Thereafter, the model’s risk is estimated using the validation set: $\hat{\epsilon}_V(h) = \frac{1}{|V|} \sum_{(x, y) \in V} \mathbb{1}[h(x) \neq y]$. This allows comparing models with different hyperparameters and selecting those that minimize $\hat{\epsilon}_V(h)$. Other approaches such as cross-validation and bootstrap also exist (kohavi1995study). Since increasing the number of samples in the training set almost always increases the accuracy of the model, a common final step is to re-train the model, using the hyperparameters found in the previous step, on the entire dataset, $T \cup V$. However, without a held-out dataset, it is no longer possible to compare models.
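The two-step holdout procedure described above can be sketched as follows. This is a minimal illustration rather than the pipeline used in this work: `holdout_select`, `train_fn`, `error_fn`, and the toy threshold "model" are hypothetical names introduced only for the example.

```python
import random

def holdout_select(data, configs, train_fn, error_fn, val_fraction=0.2, seed=0):
    """Pick a configuration on a train/validation split, then retrain on all data."""
    rng = random.Random(seed)
    shuffled = data[:]
    rng.shuffle(shuffled)
    n_val = max(1, int(len(shuffled) * val_fraction))
    val, train = shuffled[:n_val], shuffled[n_val:]

    # Step 1: rank configurations by their error on the validation set.
    val_errors = {c: error_fn(train_fn(train, c), val) for c in configs}
    best = min(val_errors, key=val_errors.get)

    # Step 2: retrain the winning configuration on the entire dataset.
    # No held-out data remains, so this final model cannot be compared
    # against other retrained models.
    final_model = train_fn(data, best)
    return best, final_model, val_errors


# Toy stand-ins (hypothetical): 1-D points labeled by x > 0.5, and a "model"
# that is a fixed decision threshold taken from the configuration. The toy
# train_fn ignores its training data on purpose.
def train_fn(train, threshold):
    return lambda x: int(x > threshold)

def error_fn(model, samples):
    return sum(model(x) != y for x, y in samples) / len(samples)

points = [(i / 100, int(i / 100 > 0.5)) for i in range(100)]
best, model, errs = holdout_select(points, [0.2, 0.5, 0.8], train_fn, error_fn)
```

The comment in step 2 marks the limitation discussed above: once the validation samples are folded back into training, nothing is left to rank the retrained models.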
We propose to replace this two-step approach with a single step where we train a model on the entire dataset; then, rather than estimating the risk on a validation set, $\hat{\epsilon}_V(h)$, we instead generate a new dataset via a generative model. Then we estimate the error on this dataset, $\hat{\epsilon}_S(h)$, where the synthetic domain $\mathcal{D}_S$ approximates the original domain $\mathcal{D}_R$. Although it may often be impossible to guarantee highly accurate error estimation due to the synthetic-real domain gap, below we present a corollary from Theorem A.2 (see Appendix A for statement and proof) outlining a sufficient condition for hypotheses’ error rank preservation across domains.
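Corollary 2.1. Given the definitions above, let $d_{TV}(\mathcal{D}_R, \mathcal{D}_S)$ denote the total variation between the two distributions, $\mathcal{D}_R$ and $\mathcal{D}_S$. Then, $\Delta\epsilon_S(h_1, h_2) \geq d_{TV}(\mathcal{D}_R, \mathcal{D}_S) \Rightarrow \Delta\epsilon_R(h_1, h_2) \geq 0$, where $\Delta\epsilon_R$ and $\Delta\epsilon_S$ are the risk differences measured over the real and the synthetic distributions, respectively.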
Informally, Corollary 2.1 indicates that if the total variation between the real and the synthetic distributions is not larger than the synthetic risk difference between a pair of hypotheses, then their error ranking is preserved across the domains.
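As a rough illustration of how this condition could be applied in practice, the sketch below lists the model pairs whose synthetic error gap is at least a given total variation estimate. The function name, the model names, and all numbers are hypothetical, and the comparison follows the informal reading above rather than the exact statement in Appendix A.

```python
from itertools import combinations

def rank_guaranteed_pairs(synthetic_errors, total_variation):
    """Return (better, worse) pairs whose synthetic error gap is >= the TV estimate."""
    guaranteed = []
    for (name_a, err_a), (name_b, err_b) in combinations(synthetic_errors.items(), 2):
        if abs(err_a - err_b) >= total_variation:
            # The model with the lower synthetic error is listed first.
            better, worse = (name_a, name_b) if err_a < err_b else (name_b, name_a)
            guaranteed.append((better, worse))
    return guaranteed

# Hypothetical synthetic error rates and a hypothetical TV estimate.
errors = {"model_a": 0.05, "model_b": 0.07, "model_c": 0.12}
pairs = rank_guaranteed_pairs(errors, total_variation=0.035)
```

With these numbers, the gap between `model_a` and `model_b` (0.02) is below the threshold, so only the two pairs involving `model_c` satisfy the condition.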
All our experiments are performed on the CIFAR10 dataset (krizhevsky2009learning). To evaluate the impact of the training set size, we use the following train-test splits: 10K-50K (Train10K), 30K-30K (Train30K), 50K-10K (Train50K). We emphasize that in all our experiments the following set of rules holds:
Only the training portion of the data is available for any training purposes.
The test portion of the images is never used for model selection and is treated as non-existent for any training purposes.
In each experiment, GANs for generating synthetic datasets are trained only on the training portion of the images, e.g., for experiments with the Train10K dataset, the GANs are trained only on the 10K images.
3.1 Rank Preservation Experiments
In our first experiment we focus on several commonly used deep model architectures. For each architecture, we select a number of variants. In total we experiment with 17 distinct architectures (see Appendix E for details). For each architecture, 10 models were trained on each of the three datasets. In Figure 1, we plot the empirical test errors, $\hat{\epsilon}_R$, vs. the empirical synthetic errors, $\hat{\epsilon}_S$, measured on datasets generated by four different GAN methods: (a) StyleGAN2-10, (b) StyleGAN2-Cond, (c) WGAN-GP-10, and (d) WGAN-GP-Cond (see details in Appendix F). It can be seen that, while in general the synthetic and test errors differ, for the StyleGAN2-based models we are able to produce datasets that preserve the error ranking of different classification models. We measure this using Spearman’s rank correlation coefficient. For the different GANs, we measured the following ranking coefficients: 0.97 (a), 0.98 (b), -0.19 (c) and 0.14 (d). In sajjadi2018assessing, the connection between total variation ($d_{TV}$) and precision and recall for distributions (PRD) was established, and an empirical method for estimating it was suggested. We use this method to empirically validate Corollary 2.1. For the GANs above, we measured the following total variation values: 3.5% (a), 8.7% (b), 43% (c) and 34% (d). Indeed, we see matching behaviors, where the two GANs with high Spearman correlation have low total variation, and vice versa. For example, it follows from Corollary 2.1 that for GAN (a), if two hypotheses have a synthetic risk difference of at least 3.5%, then their rank on real data will be preserved.
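For reference, Spearman’s rank correlation between synthetic and test errors can be computed as the Pearson correlation of the two rank vectors. The following is a small self-contained sketch; the error values are made up for illustration and are not results from this work.

```python
def average_ranks(values):
    """Rank values from 1..n, giving tied values the mean of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over a run of tied values.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks

def spearman(x, y):
    """Spearman's rho: Pearson correlation of the rank vectors of x and y."""
    rx, ry = average_ranks(x), average_ranks(y)
    n = len(x)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    var_x = sum((a - mx) ** 2 for a in rx)
    var_y = sum((b - my) ** 2 for b in ry)
    return cov / (var_x * var_y) ** 0.5

# Hypothetical per-model errors: the last two models swap places on the test set.
synthetic_err = [0.10, 0.14, 0.21, 0.30]
test_err = [0.06, 0.08, 0.15, 0.12]
rho = spearman(synthetic_err, test_err)
```

A coefficient close to 1 means that ordering models by synthetic error reproduces their ordering by test error, which is the property needed for model selection even when the absolute error values differ.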
3.2 Model Selection Experiments
We consider three model selection scenarios where synthetic data can be used:
Early stopping (ES): Given a training schedule of a single model, select an epoch from which to take the model weights.
Random seed selection (RSS): Given the same architecture and hyper-parameters, select a model instance out of $N$ trained models, where the only difference between the models is the randomness of the training process, e.g., weight initialization and dataset sampling order.
Hyper-parameter search (HPS): Select a model out of a set of models trained with different hyper-parameters. Possible hyper-parameters include training parameters (e.g., learning rate and batch size) and architecture parameters (e.g., number of layers and network depth).
3.2.1 Early stopping and random seed selection on standard architectures
We explore the impact of synthetic data on the ES and RSS model selection scenarios and their combination. We highlight that both of these scenarios require a held-out dataset. Therefore in the standard pipeline they cannot be used when training the model on the entire dataset. Using synthetic data enables these selection scenarios. For ES, the best synthetic epoch was selected for every training run. For RSS, per architecture, the model that performed the best on synthetic data at the last epoch was selected. For RSS + ES, per architecture, the model that performed the best on synthetic data at its best epoch was selected. We first experiment with the same standard model architectures as in Section 3.1. Figure 2 shows the results on Train10K, demonstrating that for nearly all architectures RSS improves accuracy. On the other hand, ES demonstrates only marginal impact on accuracy. This might be because the models’ accuracy hardly changes across the last epochs, where the model has already converged (see Appendix D for convergence plot examples). In Appendix C we show the results for all datasets, where RSS shows comparable or better performance.
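The three selection rules used in this experiment can be written compactly. The sketch below assumes that the synthetic error of every training run has been recorded at every epoch; the function names and the error values are hypothetical.

```python
def early_stopping(run_errors):
    """ES: index of the epoch with the lowest synthetic error within one run."""
    return min(range(len(run_errors)), key=lambda e: run_errors[e])

def random_seed_selection(runs):
    """RSS: index of the run with the lowest synthetic error at the last epoch."""
    return min(range(len(runs)), key=lambda r: runs[r][-1])

def rss_with_es(runs):
    """RSS + ES: (run, epoch) with the lowest synthetic error at each run's best epoch."""
    best_epochs = [early_stopping(run) for run in runs]
    run = min(range(len(runs)), key=lambda r: runs[r][best_epochs[r]])
    return run, best_epochs[run]

# Hypothetical synthetic errors: one row per random seed, one column per epoch.
synthetic_errors = [
    [0.40, 0.22, 0.19, 0.20],  # seed 0
    [0.38, 0.21, 0.17, 0.18],  # seed 1
    [0.41, 0.16, 0.20, 0.19],  # seed 2
]
es_epochs = [early_stopping(run) for run in synthetic_errors]
rss_run = random_seed_selection(synthetic_errors)
rss_es_choice = rss_with_es(synthetic_errors)
```

In this toy example RSS picks seed 1 based on the last epoch, while RSS + ES picks seed 2 at its second epoch, showing that the two rules can select different model instances.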
Table 1: Baseline | ES | RSS | ES + RSS
3.2.2 Early stopping and random seed selection on similar architectures
We next evaluate the impact of ES and RSS across multiple models with similar architectures. For each dataset we constructed 64 architectures using the randomly wired neural networks (RWNN) framework (Xie_2019_ICCV) with WS(2,0.25), resulting in 64 unique but similar architectures per dataset. Each architecture is trained 10 times on each dataset (a total of 1920 models were trained). Table 1 summarizes the experiment. Since the errors of the models are of the same scale, we report the average performance. RSS has a significant impact on model selection with an average improvement over the baseline of // (corresponding to: Train10K, Train30K, Train50K). Similarly to the previous experiment, ES has no significant impact.
3.2.3 Synthetic data for architecture hyper-parameter search
Next, we explore the contribution of using synthetic datasets for selecting a model out of multiple possible architectures and training instances (HPS). We consider three model selection protocols:
Selecting a random model: The naïve baseline of selecting a random model.
Standard protocol: Split the dataset into training and validation subsets. Then: (1) train each architecture $N$ times with the training subset; (2) select the architecture that on average performed the best on the validation subset; (3) train the selected architecture on the entire dataset. This method allows for selecting a “promising” architecture without the ability to select a specific trained model instance (architecture and weights).
Synthetic protocol: (1) Train each architecture $N$ times on the entire dataset; (2) evaluate the accuracy at each training epoch on the synthetic dataset; (3) select the model that performed the best in step 2. This method allows for selecting a “promising” model instance.
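The difference between the standard and the synthetic protocols can be illustrated with the following sketch. It only covers the selection step and assumes that the validation and synthetic errors have already been measured; the function names, architecture names, and numbers are hypothetical.

```python
from statistics import mean

def standard_protocol(val_errors):
    """val_errors[arch] -> list of validation errors, one per training run.

    Returns the architecture with the lowest mean validation error. It must
    then be retrained on the full dataset, with no way to choose among the
    retrained instances.
    """
    return min(val_errors, key=lambda arch: mean(val_errors[arch]))

def synthetic_protocol(syn_errors):
    """syn_errors[arch][run] -> per-epoch synthetic errors of a full-data model.

    Returns the (arch, run, epoch) instance with the lowest synthetic error.
    """
    candidates = (
        (err, arch, run, epoch)
        for arch, runs in syn_errors.items()
        for run, epochs in enumerate(runs)
        for epoch, err in enumerate(epochs)
    )
    _, arch, run, epoch = min(candidates)
    return arch, run, epoch

# Hypothetical measurements for two architectures, two runs, two epochs.
val_errors = {"rwnn_0": [0.12, 0.14], "rwnn_1": [0.11, 0.13]}
syn_errors = {
    "rwnn_0": [[0.20, 0.15], [0.18, 0.09]],
    "rwnn_1": [[0.19, 0.12], [0.17, 0.11]],
}
chosen_arch = standard_protocol(val_errors)
chosen_instance = synthetic_protocol(syn_errors)
```

The standard protocol returns only an architecture name, whereas the synthetic protocol returns a concrete trained instance (architecture, run, and epoch), which is the distinction drawn in the list above.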
We used the same 64 architectures per dataset, generated by the RWNN framework in Section 3.2.2. Table 2 shows that the synthetic protocol achieves, on average, better results compared to the standard protocol. Both protocols perform far better than selecting a model at random. It can also be observed that the impact of using synthetic data is greater when the training dataset is smaller. This result is aligned with Corollary 2.1, as the synthetic risk difference between models, $\Delta\epsilon_S$, is larger. Appendix B presents an extensive breakdown of the experiment. A key finding from the breakdown is that ranking models using synthetic data is comparable to ranking models using validation data in terms of Spearman correlation to a test set.
Table 2: Synthetic protocol | Standard protocol | Avg of all models
In this paper we presented a comprehensive empirical study on the CIFAR10 dataset for evaluating the impact of using synthetic data for model selection. The empirical evidence suggests that evaluating trained models on synthetic data can be beneficial and can outperform the standard methods for model selection that are based solely on the available real images.
Appendix A Theoretical Justification
In this section we derive the theoretical justification leading up to Corollary 2.1.
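Let $\Delta\epsilon_{\mathcal{D}}(h_1, h_2)$ denote the risk difference between two hypotheses, $h_1, h_2$, measured over a probability distribution $\mathcal{D}$, i.e., $\Delta\epsilon_{\mathcal{D}}(h_1, h_2) = \epsilon_{\mathcal{D}}(h_1) - \epsilon_{\mathcal{D}}(h_2)$. Let $f$ denote the labeling function. Let $A = \{x : h_1(x) \neq f(x),\ h_2(x) = f(x)\}$ and $B = \{x : h_1(x) = f(x),\ h_2(x) \neq f(x)\}$. Then, $\Delta\epsilon_{\mathcal{D}}(h_1, h_2) = \Pr_{\mathcal{D}}[A] - \Pr_{\mathcal{D}}[B]$.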
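Theorem A.2. Let $\Delta\epsilon_R(h_1, h_2)$ and $\Delta\epsilon_S(h_1, h_2)$ denote the risk difference between two hypotheses, $h_1, h_2$, measured over the real and the synthetic probability distributions $\mathcal{D}_R$ and $\mathcal{D}_S$, respectively, i.e., $\Delta\epsilon_R(h_1, h_2) = \epsilon_R(h_1) - \epsilon_R(h_2)$ and $\Delta\epsilon_S(h_1, h_2) = \epsilon_S(h_1) - \epsilon_S(h_2)$. Let $f$ denote the labeling function. Let $\delta$ denote the total variation between the two distributions in the sample subspace of $\mathcal{X}$ in which the two hypotheses disagree with $f$ and between themselves. If $\Delta\epsilon_S(h_1, h_2) \geq \delta$, then $\Delta\epsilon_R(h_1, h_2) \geq 0$.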
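Let $\mathcal{D}_R$ and $\mathcal{D}_S$ denote the real and the synthetic (generated) probability distributions, respectively. Let $\Delta\epsilon_R(h_1, h_2)$ and $\Delta\epsilon_S(h_1, h_2)$ denote the risk difference between any two hypotheses, $h_1, h_2$, measured over $\mathcal{D}_R$ and $\mathcal{D}_S$, respectively. Then, $\Delta\epsilon_S(h_1, h_2) \geq d_{TV}(\mathcal{D}_R, \mathcal{D}_S) \Rightarrow \Delta\epsilon_R(h_1, h_2) \geq 0$,
where $d_{TV}(\mathcal{D}_R, \mathcal{D}_S)$ denotes the total variation between the two distributions.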
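For intuition, the following sketches one way a condition of this kind can be derived. It is an illustration under assumed notation and not necessarily the proof intended here: $p_R$ and $p_S$ are assumed to be densities of the real and synthetic distributions over a common sample space $\mathcal{X}$, a single labeling function $f$ is assumed for both domains, and the sets $A$ and $B$ are introduced only for this argument. The last step uses the unnormalized $L_1$ distance between the densities, which equals twice the total variation under the common sup-over-events convention; which normalization the corollary intends is left open.

```latex
% A: only h_1 errs; B: only h_2 errs (points where both err cancel out).
A = \{x : h_1(x) \neq f(x),\; h_2(x) = f(x)\}, \qquad
B = \{x : h_1(x) = f(x),\; h_2(x) \neq f(x)\}

% Risk difference under a distribution D with density p:
\Delta\epsilon_{D}(h_1, h_2) = \epsilon_{D}(h_1) - \epsilon_{D}(h_2)
  = \int_{A} p(x)\,dx - \int_{B} p(x)\,dx

% Compare the real and synthetic risk differences:
\Delta\epsilon_{R} - \Delta\epsilon_{S}
  = \int_{A} (p_R - p_S)\,dx - \int_{B} (p_R - p_S)\,dx
  \;\geq\; -\int_{A \cup B} |p_R - p_S|\,dx
  \;\geq\; -\int_{\mathcal{X}} |p_R - p_S|\,dx

% Hence a sufficient condition for rank preservation:
\Delta\epsilon_{S} \geq \int_{\mathcal{X}} |p_R - p_S|\,dx
  \;\Longrightarrow\; \Delta\epsilon_{R} \geq 0
```

The intermediate bound over $A \cup B$ is the tighter, hypothesis-dependent quantity, while the bound over all of $\mathcal{X}$ depends only on the two distributions.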
Yellow bars represent the 10 architectures that performed the best (average of 10 training runs) on the validation set (ranked from the best to 10th best). Orange bars represent the same architectures, trained on the entire dataset, and their performance (average of 10 training runs) on the test set. Black lines represent the 95% confidence interval. (4th row) The points in the first column, “Syn”, correspond to the test errors of the 10 best performing models selected using the synthetic protocol. The rest of the columns (1st-10th) correspond to the test error results of all trained models out of the 10 best architectures selected by the standard protocol (same architectures as row 3). Horizontal lines represent average test error rates of: 10 best synthetic models (Avg synth), average of the 10 best models selected by the standard protocol (Avg standard), and the average of all 640 models (Avg all).
Appendix B Synthetic data for architecture hyper-parameter search
In this experiment we explore the contribution of synthetic data for model selection out of a pool of different architectures. Given a training set and a held-out test set (not available for model selection), we compared the three model selection protocols (selecting a random network, standard protocol, synthetic protocol). The standard protocol requires a validation set; to this end, we split each of the datasets (Train50K, Train30K, Train10K) into training and validation subsets. For the synthetic protocol, a GAN was trained on each dataset to produce a dataset of 100K synthetic images (see Appendix F):
Train10K: The train/val split is 7.5K/2.5K. For the synthetic data protocol, a single StyleGAN2-Cond model trained on the 10K available images was used.
Train30K: The train/val split is 22.5K/7.5K. For the synthetic data protocol, 10 StyleGAN2 models were trained, each on the 3K (per class) available images.
Train50K: The train/val split is 40K/10K. For the synthetic data protocol, 10 StyleGAN2 models were trained, each on the 5K (per class) available images.
In each experiment 64 architectures were evaluated using the different protocols. For generating similar architectures, we sampled RWNN architectures with the same parameters WS(2, 0.25) (same architectures as in 3.2.2). In both the standard and synthetic protocols we trained each architecture 10 times. For the standard protocol we train on the training subset (step 1) and for the synthetic protocol we use the entire dataset. Note that for the standard protocol, it is not possible to select a model instance out of the 10 trained instances of each architecture that was trained on the entire data (step 3). Therefore, we use the average test error of the 10 trained models of each architecture (on the entire dataset) as a data point for comparisons.
Figure 4 shows different analyses of the experimental results. From the first two rows of the figure we can infer that there is a strong correlation between the error on synthetic data, $\hat{\epsilon}_S$, and the error on the test set, $\hat{\epsilon}_R$. We evaluate this correlation using Spearman’s rank correlation coefficient, as it is appropriate for measuring rank preservation. From the Spearman correlation plot (second row) we learn that the ranking capability of the synthetic data is comparable to that of the real data validation set. This strengthens our premise that using synthetic data for model selection is appropriate. It can be seen that the correlation improves when the training set is smaller and the errors are larger. This result coincides with Corollary 2.1, where for larger gaps in synthetic error there is a lower chance for a flip in model ranking. From the third row we can infer that the ranking of architectures might change when moving from training on a smaller training set and evaluating on a validation set to training on the entire dataset and evaluating on the test set. This implies that a potential gain in accuracy could be achieved by selecting a model out of the models that were directly trained on the entire dataset. From the last row we can learn that, on average, model selection using synthetic data improves over the standard method. Again, the impact of synthetic data increases as the training dataset size decreases. Given that a synthetic dataset is available, training the models directly on the entire dataset is simpler than training on a subset and re-training on the entire dataset.
Appendix C Additional results for early stopping and random seed selection on standard architectures
In addition to the results reported in 3.2.1, Figure 5 shows results of ES, RSS and RSS+ES on all three datasets. RSS is beneficial for model selection in most cases; however, the benefits decrease as the dataset size increases.
Appendix D Standard architecture convergence in training
Figure 6 shows two examples of the train, test and synthetic data errors vs. epoch index during training on the Train50K dataset. It can be observed that although the synthetic data error does not match the test error exactly, it follows the same trend as the test error. In the last epochs of training, where the learning rate has decreased, there is very little change in the model’s error. This may explain why the early stopping experiments did not demonstrate any benefits.
Appendix E Standard Architectures Description
DenseNet: (huang2017densely; huang2019convolutional) with batch size 32, initial learning rate 0.05, depth 100, block type “bottleneck”, growth rate 12, compression rate 0.5.
PyramidNet 270: (DPRN) with depth 110, block type “basic”, $\alpha = 270$.
PyramidNet 84: (DPRN) with depth 110, block type “basic”, $\alpha = 84$.
SE-ResNet-preact: (hu2019squeezeandexcitation) with depth 110, se reduction=16.
ResNet-preact 110: (He2016) with depth 110, block type “basic”.
ResNet-preact 164: (He2016) with depth 164, block type “bottleneck”.
ResNext 4x64d: (Xie2016) with depth 29, cardinality 4, base channels 64, batch size 32 and initial learning rate 0.025.
ResNext 8x64d: (Xie2016) with depth 29, cardinality 8, base channels 64, batch size 64 and initial learning rate 0.05.
Shake-shake 32d: (Gastaldi17ShakeShake) with depth 26, base channels 32, S-S-I model.
Shake-shake 64d: (Gastaldi17ShakeShake) with depth 26, base channels 64, S-S-I model, batch size 64, base .
Shake-shake 64d + cutout: (Gastaldi17ShakeShake) with depth 26, base channels 64, S-S-I model, batch size 64, , cosine scheduler, cutout (devries2017improved) size 16.
Wide residual network + cutout: (Zagoruyko2016WRN) with depth 28, widening factor 10, base , batch size 64, cosine scheduler, cutout (devries2017improved) size 16.
Wide residual network: (Zagoruyko2016WRN) with depth 28, widening factor 10.
ResNet 32: (He2016DeepRL) with depth 32, block type “basic”.
ResNet 44: (He2016DeepRL) with depth 44, block type “basic”.
ResNet 56: (He2016DeepRL) with depth 56, block type “basic”.
ResNet 110: (He2016DeepRL) with depth 110, block type “basic”.
Appendix F Synthetic Data Generation Details
Our method for producing synthetic datasets is based on training GANs that in turn are used to generate the desired labeled data. We consider two GAN frameworks for generating our synthetic datasets:
StyleGAN2 (Karras2019stylegan2) with non-leaking augmentation (karras2020training). This framework is our best candidate for generating high quality synthetic datasets since it is the SOTA for generating CIFAR10 images.
WGAN-GP (NIPS2017_892c3b1c). This framework generates lower quality images than StyleGAN2. We consider it as a baseline to explore how the image quality impacts the dataset’s model selection capabilities.
For each GAN framework we consider two variants of training the GANs to generate labeled datasets:
Training 10 GANs (StyleGAN2-10/WGAN-GP-10): For each of the 10 CIFAR10 classes, a different GAN was trained with just one class at a time (e.g., 5K images for Train50K, 3K images for Train30K and 1K images for Train10K). The generator instance with the best FID (heusel2017gans) score out of all instances obtained during training was selected to generate 10K images of its corresponding class.
Training one Conditional GAN (StyleGAN2-Cond/WGAN-GP-Cond): A single Conditional-GAN was trained, and the best instance was selected by FID score. Thereafter, 10K images were generated per class.
Using the above methods we constructed 8 datasets (each with 100K labeled images): three “StyleGAN2-10” datasets and three “StyleGAN2-Cond” datasets (one per CIFAR10 subset), one “WGAN-GP-10” dataset and one “WGAN-GP-Cond” dataset (for the Train50K CIFAR10 subset).
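The per-class variant of this construction can be sketched as follows. Real GAN training and FID computation are outside the scope of the example, so the generators are replaced by toy callables and the FID scores are made-up numbers; `select_by_fid`, `build_labeled_dataset`, and `make_toy_generator` are hypothetical names.

```python
import random

def select_by_fid(checkpoints):
    """checkpoints: list of (fid_score, generator); lower FID is better."""
    return min(checkpoints, key=lambda item: item[0])[1]

def build_labeled_dataset(generators, images_per_class):
    """generators[class_id] is a callable that returns one synthetic sample."""
    dataset = []
    for class_id, generate in generators.items():
        dataset.extend((generate(), class_id) for _ in range(images_per_class))
    return dataset

def make_toy_generator(seed):
    """Stand-in for a trained GAN generator: returns a short random vector."""
    rng = random.Random(seed)
    return lambda: [rng.random() for _ in range(4)]  # placeholder "image"

# One generator per class, chosen among two hypothetical training checkpoints.
generators = {}
for class_id in range(10):
    checkpoints = [
        (35.0, make_toy_generator(class_id)),        # earlier checkpoint
        (12.5, make_toy_generator(100 + class_id)),  # later checkpoint
    ]
    generators[class_id] = select_by_fid(checkpoints)

# 10 classes x 10K samples per class = 100K labeled synthetic samples.
dataset = build_labeled_dataset(generators, images_per_class=10_000)
```

Because each generator is tied to a single class, the label of every generated sample is known by construction, which is what makes the resulting dataset usable for measuring classification error.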
Table 3 shows the FID scores breakdown for our synthetic datasets. As expected, as the training dataset size decreases the FID score increases.