Dataset Pruning: Reducing Training Data by Examining Generalization Influence

by Shuo Yang et al.

The great success of deep learning heavily relies on increasingly large training data, which comes at the price of huge computational and infrastructural costs. This raises crucial questions: do all training data contribute to the model's performance? How much does each individual training sample or sub-training-set affect the model's generalization, and how can we construct the smallest subset of the entire training data as a proxy training set without significantly sacrificing the model's performance? To answer these, we propose dataset pruning, an optimization-based sample selection method that can (1) examine the influence of removing a particular set of training samples on the model's generalization ability with a theoretical guarantee, and (2) construct the smallest subset of training data that yields a strictly constrained generalization gap. The empirically observed generalization gap of dataset pruning is substantially consistent with our theoretical expectations. Furthermore, the proposed method prunes 40% of the training examples in CIFAR-10, halves the convergence time, and incurs only a 1.3% test accuracy drop, which is superior to previous score-based sample selection methods.





1 Introduction

The great advances in deep learning over the past decades have been powered by ever-bigger models crunching ever-bigger amounts of data. However, this success comes at the price of huge computational and infrastructural costs for network training, network inference, hyper-parameter tuning, and model architecture search. While many research efforts seek to reduce the network inference cost by pruning redundant parameters networkpruningblalock2020state; networkpruningliu2018rethinking; networkpruningmolchanov2019importance, scant attention has been paid to the data redundancy problem, which is crucial for the efficiency of network training and parameter tuning.

Massive training data has become a standard paradigm for deep learning, but it also ties the success of deep models to specialized equipment and infrastructure. For example, a single training run of EfficientNet xie2020self on the JFT-300M dataset sun2017revisiting takes 12,300 TPU days and consumes an enormous amount of energy. Moreover, hyper-parameter and network architecture tuning further multiplies the computation and energy cost. An effective way to deal with this is to construct an informative subset of the original large-scale training data as a proxy training dataset.

Previous literature tries to sort and select a fraction of training data according to a scalar score computed based on some criterion, such as the distance to the class center welling2009herding; rebuffi2017icarl; castro2018end; belouadah2020scail, the distance to other selected examples wolf2011kcenter; sener2018active, the forgetting score toneva2018empirical, and the gradient norm paul2021deep. However, these methods are (a) heuristic and lack theoretically guaranteed generalization, and (b) discard the influence direction, i.e., the norm of the averaged gradient vector of two high-gradient-norm samples could be zero if the directions of these two samples' gradients are opposite.

To go beyond these limitations, a natural question to ask is: how much does a particular combination of training examples contribute to the model's generalization? However, simply evaluating the test performance drop caused by removing each possible subset is intractable, because it requires re-training the model $2^n$ times for a dataset of size $n$. Therefore, the key challenges of dataset pruning are: (1) how to efficiently estimate the generalization influence of all possible subsets without iteratively re-training the model, and (2) how to identify the smallest subset of the original training data with a strictly constrained generalization gap.

In this work, we present an optimization-based dataset pruning method that can prune the largest possible subset from the entire training dataset with (a) a theoretically guaranteed generalization gap and (b) consideration of the joint influence of all collected data. Specifically, we define the parameter influence of a training example as the model's parameter change caused by omitting the example from training. The parameter influence can be linearly approximated without re-training the model via the Influence Function influencefunction. Then, we formulate a constrained discrete optimization problem with the objective of maximizing the number of collected samples, and a constraint penalizing the network parameter change caused by removing the collected subset to stay within a given threshold $\epsilon$. We then conduct extensive theoretical and empirical analysis to show that the generalization gap of the proposed dataset pruning can be upper bounded by the pre-defined threshold $\epsilon$. Superior to previous score-based sample selection methods, our proposed method prunes 40% of the training examples in CIFAR-10 cifar10, halves the convergence time, and incurs only a 1.3% test accuracy drop. Before delving into details, we summarize our contributions below:


  • This paper proposes a dataset pruning method, which extends the sample selection problem from a cardinality constraint to a generalization constraint. Specifically, previous sample selection methods find a data subset with a fixed size budget and try to maximize its performance, while dataset pruning tries to identify the smallest subset that satisfies the expected generalization ability.

  • This paper proposes to leverage the influence function to approximate the network parameter change caused by omitting each individual training example. Then an optimization-based method is proposed to identify the largest subset that satisfies the expected parameter change.

  • We prove that the generalization gap caused by removing the identified subset can be upper-bounded by the network parameter change that was strictly constrained during the dataset pruning procedure. The observed empirical results are substantially consistent with our theoretical expectation.

  • The experimental results on dataset pruning and neural architecture search (NAS) demonstrate that the proposed dataset pruning method is extremely effective on improving the network training and architecture search efficiency.

The rest of this paper is organized as follows. In Section 2, we briefly review existing sample selection and dataset condensation research and discuss the major differences between our proposed dataset pruning and previous methods. In Section 3, we present the formal definition of the dataset pruning problem. In Section 4 and Section 5, our optimization-based dataset pruning method and its generalization bound are introduced. In Section 6, we conduct extensive experiments to verify the validity of the theoretical results and the effectiveness of dataset pruning. Finally, in Section 7, we conclude the paper.

2 Related Works

Dataset pruning is orthogonal to few-shot learning maml; matching; snell2017prototypical; closer; DeepEMD; yang2021free; yang2021bridging; he2021partimagenet; zhang2022graph; yang2021single. Few-shot learning aims at improving the performance given limited training data, while dataset pruning aims at reducing the training data without hurting the performance much.

Dataset pruning is closely related to data selection methods, which try to identify the most representative training samples. Classical data selection methods agarwal2004approximating; har2004coresets; feldman2013turning; yang2021noise; wang2022reliable focus on clustering problems. Recently, more and more data selection methods have been proposed in continual learning rebuffi2017icarl; toneva2019empirical; castro2018end; aljundi2019gradient; yang2021objects and active learning sener2017active to identify which examples need to be stored or labeled. These data selection methods typically rely on a pre-defined criterion to compute a scalar score for each training example, e.g. compactness rebuffi2017icarl; castro2018end; yang2019ada, diversity sener2018active; aljundi2019gradient, and forgetfulness toneva2019empirical, and then rank and select the training data according to the computed score. However, these methods are heuristic and lack a generalization guarantee, and they also discard the influence interactions between the collected samples. Our proposed dataset pruning method overcomes these shortcomings.

Another line of work on reducing training data is dataset distillation wang2018dataset; such2020generative; sucholutsky2019softlabel; bohdal2020flexible; nguyen2021dataset1; nguyen2021dataset2; cazenavette2022dataset or dataset condensation zhao2021DC; zhao2021DSA; zhao2021distribution; wang2022cafe. This series of works focuses on synthesizing a small but informative dataset as an alternative to the original large dataset. In particular, Dataset Distillation wang2018dataset optimizes the synthetic samples by directly minimizing, on the real training data, the classification loss of neural networks trained on the synthetic data. Later, Sucholutsky et al. sucholutsky2019softlabel proposed to simultaneously optimize the synthetic images and soft labels, Bohdal et al. bohdal2020flexible proposed to simplify dataset distillation by only learning the soft labels for randomly selected real images, Such et al. such2020generative proposed to leverage generative models to produce the synthetic data, Nguyen et al. nguyen2021dataset1; nguyen2021dataset2 proposed to formulate dataset distillation as kernel ridge regression and induce the synthetic images on infinitely wide neural networks, Cazenavette et al. cazenavette2022dataset proposed to learn the synthetic images by matching training trajectories, and Zhao et al. zhao2021DC; zhao2021DSA; zhao2021distribution proposed to learn the synthetic images by matching gradients and features. However, due to computational limitations, these methods usually only synthesize an extremely small number of examples (e.g. 50 images per class), and the resulting performance is far from satisfactory. Therefore, the performance of dataset distillation and dataset pruning is not directly comparable.

Our method is inspired by the Influence Function influencefunction in statistical machine learning. If removing a training example from the training dataset does not damage generalization, the example has a small influence on the expected test loss. Earlier works focused on studying the influence of removing training points from linear models, and later works extended this to more general models hampel2011robust; cook1986assessment; thomas1990assessing; chatterjee1986influential; wei1998generalized. Liu et al. liu2014efficient used influence functions to study model robustness and to do fast cross-validation in kernel methods. Kabra et al. kabra2015understanding defined a different notion of influence that is specialized to finite hypothesis classes. Koh et al. influencefunction studied the influence of up-weighting an example on the model's parameters. In our work, we start by analyzing the influence of removing a selected subset on the model's parameters, and then we show the generalization guarantee of the proposed method.

3 Problem Definition

Consider a large-scale dataset $\mathcal{D} = \{z_1, \dots, z_n\}$ containing $n$ training points $z_i = (x_i, y_i) \in \mathcal{X} \times \mathcal{Y}$, where $\mathcal{X}$ is the input space and $\mathcal{Y}$ is the label space. The goal of dataset pruning is to identify a set of redundant training samples from $\mathcal{D}$, as large as possible, and remove them to reduce the training cost. The identified redundant subset $\hat{\mathcal{D}} \subset \mathcal{D}$ should have a minimal impact on the learned model, i.e. the test performances of the models learned on the training sets before and after pruning should be very close, as described below:

$$\mathbb{E}_{z \sim P}\left[\ell(z, \hat{\theta})\right] \simeq \mathbb{E}_{z \sim P}\left[\ell\big(z, \hat{\theta}_{-\hat{\mathcal{D}}}\big)\right], \tag{1}$$

where $P$ is the data distribution, $\ell$ is the loss function, and $\hat{\theta}$ and $\hat{\theta}_{-\hat{\mathcal{D}}}$ are the empirical risk minimizers on the training set before and after pruning, respectively, i.e., $\hat{\theta} = \arg\min_{\theta} \frac{1}{n}\sum_{z_i \in \mathcal{D}} \ell(z_i, \theta)$ and $\hat{\theta}_{-\hat{\mathcal{D}}} = \arg\min_{\theta} \frac{1}{n - |\hat{\mathcal{D}}|}\sum_{z_i \in \mathcal{D} \setminus \hat{\mathcal{D}}} \ell(z_i, \theta)$. Considering that a neural network is a locally smooth function rifai2012generative; goodfellow2014explaining; zhao2021DC, similar weights ($\hat{\theta} \approx \hat{\theta}_{-\hat{\mathcal{D}}}$) imply similar mappings in a local neighborhood and thus similar generalization performance. Therefore, we can achieve Eq. (1) by obtaining a $\hat{\theta}_{-\hat{\mathcal{D}}}$ that is very close to $\hat{\theta}$ (the distance between $\hat{\theta}_{-\hat{\mathcal{D}}}$ and $\hat{\theta}$ is smaller than a given very small value $\epsilon$). To this end, we first define the dataset pruning problem from the perspective of model parameter change; we will later provide theoretical evidence in Section 5 that the generalization gap in Eq. (1) can be upper bounded by the parameter change.

Definition 1 ($\epsilon$-redundant subset). Given a dataset $\mathcal{D} = \{z_1, \dots, z_n\}$ containing $n$ training points $z_i = (x_i, y_i)$, consider a subset $\hat{\mathcal{D}} \subset \mathcal{D}$ and the minimizers $\hat{\theta} = \arg\min_{\theta} \frac{1}{n}\sum_{z_i \in \mathcal{D}} \ell(z_i, \theta)$ and $\hat{\theta}_{-\hat{\mathcal{D}}} = \arg\min_{\theta} \frac{1}{n - |\hat{\mathcal{D}}|}\sum_{z_i \in \mathcal{D} \setminus \hat{\mathcal{D}}} \ell(z_i, \theta)$. We say $\hat{\mathcal{D}}$ is an $\epsilon$-redundant subset of $\mathcal{D}$ if $\|\hat{\theta}_{-\hat{\mathcal{D}}} - \hat{\theta}\|_2 \le \epsilon$, and we then write $\hat{\mathcal{D}}$ as $\hat{\mathcal{D}}_{\epsilon}$.

Dataset Pruning: Given a dataset $\mathcal{D}$, dataset pruning aims at finding its largest $\epsilon$-redundant subset $\hat{\mathcal{D}}_{\epsilon}^{\max} = \arg\max_{\hat{\mathcal{D}}_{\epsilon}} |\hat{\mathcal{D}}_{\epsilon}|$, so that the pruned dataset can be constructed as $\mathcal{D} \setminus \hat{\mathcal{D}}_{\epsilon}^{\max}$.

4 Method

To achieve the goal of dataset pruning, we need to evaluate the model's parameter change caused by removing each possible subset of $\mathcal{D}$. However, it is intractable to re-train the model $2^n$ times to obtain all $\hat{\theta}_{-\hat{\mathcal{D}}}$ for a dataset of size $n$. In this section, we propose to efficiently approximate $\hat{\theta}_{-\hat{\mathcal{D}}} - \hat{\theta}$ without re-training the model.

4.1 Parameter Influence Estimation

We start by studying the model parameter change caused by removing a single training sample $z$ from the training set $\mathcal{D}$. The change can be formally written as $\hat{\theta}_{-z} - \hat{\theta}$, where $\hat{\theta}_{-z} = \arg\min_{\theta} \frac{1}{n-1}\sum_{z_i \ne z} \ell(z_i, \theta)$. Estimating the parameter change for each training example by re-training the model $n$ times is also unacceptably time-consuming, because $n$ is usually on the order of tens or even hundreds of thousands.

Alternatively, the research on Influence Functions cook1977detection; cook1980characterizations; cook1986assessment; cook1982residuals; influencefunction provides an accurate and fast estimation of the parameter change caused by re-weighting an example during training. Suppose a training example $z$ is up-weighted by a small $\delta$ during training, so the empirical risk minimizer becomes $\hat{\theta}_{\delta, z} = \arg\min_{\theta} \frac{1}{n}\sum_{z_i \in \mathcal{D}} \ell(z_i, \theta) + \delta\,\ell(z, \theta)$. Setting $\delta = -\frac{1}{n}$ is equivalent to removing the training example $z$. Then, the influence of up-weighting $z$ on the parameters is given by

$$\mathcal{I}_{\text{param}}(z) \triangleq \left.\frac{d\hat{\theta}_{\delta, z}}{d\delta}\right|_{\delta=0} = -H_{\hat{\theta}}^{-1}\nabla_{\theta}\,\ell(z, \hat{\theta}), \tag{2}$$

where $H_{\hat{\theta}} = \frac{1}{n}\sum_{z_i \in \mathcal{D}} \nabla^2_{\theta}\,\ell(z_i, \hat{\theta}) \in \mathbb{R}^{p \times p}$ is the Hessian and positive definite by assumption, and $p$ is the number of network parameters. The proof of Eq. (2) can be found in influencefunction. We can then linearly approximate the parameter change due to removing $z$, without re-training the model, by computing $\hat{\theta}_{-z} - \hat{\theta} \approx -\frac{1}{n}\mathcal{I}_{\text{param}}(z) = \frac{1}{n} H_{\hat{\theta}}^{-1}\nabla_{\theta}\,\ell(z, \hat{\theta})$. Similarly, we can approximate the parameter change caused by removing a subset $\hat{\mathcal{D}}$ by accumulating the parameter change of removing each example, $\hat{\theta}_{-\hat{\mathcal{D}}} - \hat{\theta} \approx \frac{1}{n}\sum_{z \in \hat{\mathcal{D}}} H_{\hat{\theta}}^{-1}\nabla_{\theta}\,\ell(z, \hat{\theta})$.
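To make the influence-function approximation concrete, the following toy sketch (our illustration, not the paper's code) compares the estimated leave-one-out parameter change with the exact value for one-dimensional least-squares regression, where the ERM solution, the Hessian, and the exact leave-one-out solution all have closed forms:

```python
# Toy model: f(x) = theta * x, per-example loss l(z, theta) = (theta*x - y)^2.
# For this model everything needed by the influence function is closed-form,
# so we can check the approximation against exact re-training.
xs = [1.0, 2.0, 3.0, 4.0, 5.0]
ys = [1.2, 1.9, 3.2, 4.1, 4.8]
n = len(xs)

s_xx = sum(x * x for x in xs)
s_xy = sum(x * y for x, y in zip(xs, ys))
theta_hat = s_xy / s_xx                      # ERM on the full dataset

def influence_estimate(x, y):
    """IF estimate of theta_{-z} - theta_hat: (1/n) * H^{-1} * grad l(z, theta_hat)."""
    hessian = (2.0 / n) * s_xx               # second derivative of the empirical risk
    grad = 2.0 * x * (theta_hat * x - y)     # gradient of l(z, .) at theta_hat
    return (1.0 / n) * grad / hessian

def exact_change(x, y):
    """Exact leave-one-out parameter change, obtained by re-solving without (x, y)."""
    theta_loo = (s_xy - x * y) / (s_xx - x * x)
    return theta_loo - theta_hat

z = (1.0, 1.2)
est, act = influence_estimate(*z), exact_change(*z)
print(est, act)   # the two values should agree closely
```

For a deep network the Hessian is never formed explicitly; this scalar case only illustrates why one forward computation of $H^{-1}\nabla_\theta \ell$ replaces a full re-training run.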

4.2 Dataset Pruning as Discrete Optimization

Combining Definition 1 and the parameter influence function (Eq. (2)), it is easy to derive that if the parameter influence of a subset $\hat{\mathcal{D}}$ satisfies $\big\|\frac{1}{n}\sum_{z \in \hat{\mathcal{D}}} H_{\hat{\theta}}^{-1}\nabla_{\theta}\,\ell(z, \hat{\theta})\big\|_2 \le \epsilon$, then it is an $\epsilon$-redundant subset of $\mathcal{D}$. Denoting $\mathcal{I}(z_i) = \frac{1}{n} H_{\hat{\theta}}^{-1}\nabla_{\theta}\,\ell(z_i, \hat{\theta})$, to find the largest $\epsilon$-redundant subset satisfying both conditions simultaneously, we formulate generalization-guaranteed dataset pruning as the discrete optimization problem (a) below:

(a) generalization-guaranteed pruning:

$$\max_{W} \sum_{i=1}^{n} w_i \quad \text{subject to} \quad \left\|\sum_{i=1}^{n} w_i\,\mathcal{I}(z_i)\right\|_2 \le \epsilon, \tag{3}$$

(b) cardinality-guaranteed pruning:

$$\min_{W} \left\|\sum_{i=1}^{n} w_i\,\mathcal{I}(z_i)\right\|_2 \quad \text{subject to} \quad \sum_{i=1}^{n} w_i = m, \tag{4}$$

where $W = (w_1, \dots, w_n) \in \{0, 1\}^n$ is a discrete variable to be optimized. For each dimension $i$, $w_i = 1$ indicates that the training sample $z_i$ is selected into $\hat{\mathcal{D}}$, while $w_i = 0$ indicates that $z_i$ is not pruned. After solving Eq. (3) for $W$, the largest $\epsilon$-redundant subset can be constructed as $\hat{\mathcal{D}}_{\epsilon}^{\max} = \{z_i \mid w_i = 1\}$. For scenarios where we need to specify the removed subset size $m$ and minimize its influence on the parameter change, we provide cardinality-guaranteed dataset pruning in Eq. (4).
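The discrete problems above are solved exactly in the paper with an integer-programming solver; as a purely illustrative sketch (our own, not the paper's algorithm), the greedy heuristic below grows a subset while keeping the norm of the accumulated influence under a budget `eps`. It shows why subset-level optimization differs from score-based ranking: two samples with large but opposite influence vectors can be pruned together because their influences cancel, which no per-sample score would reveal.

```python
def greedy_epsilon_redundant(influences, eps):
    """Greedily select samples whose accumulated influence vector stays
    within an L2 ball of radius eps (a heuristic stand-in for Eq. (3))."""
    dim = len(influences[0])
    running = [0.0] * dim
    remaining = set(range(len(influences)))
    selected = []
    while remaining:
        # Pick the candidate whose addition keeps the accumulated norm smallest.
        best, best_norm = None, None
        for i in remaining:
            cand = [running[d] + influences[i][d] for d in range(dim)]
            norm = sum(c * c for c in cand) ** 0.5
            if best_norm is None or norm < best_norm:
                best, best_norm = i, norm
        if best_norm > eps:
            break                      # no remaining candidate fits the constraint
        running = [running[d] + influences[best][d] for d in range(dim)]
        selected.append(best)
        remaining.remove(best)
    return selected

# Samples 0 and 1 have large but opposite influences, so together they are
# nearly free to prune; sample 3 has a genuinely large influence and is kept.
infl = [[0.5, 0.0], [-0.5, 0.0], [0.1, 0.0], [2.0, 0.0]]
picked = greedy_epsilon_redundant(infl, eps=0.45)
print(sorted(picked))
```

A score-based method ranking by individual influence norm would treat samples 0 and 1 as equally expensive as they look in isolation; the subset view prunes both.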

Require: Dataset $\mathcal{D} = \{z_1, \dots, z_n\}$; randomly initialized network parameters $\theta$; expected generalization drop $\epsilon$.
1:  $\hat{\theta} \leftarrow \arg\min_{\theta} \frac{1}{n}\sum_{z_i \in \mathcal{D}} \ell(z_i, \theta)$; // compute the ERM on $\mathcal{D}$
2:  Initialize $\mathcal{I} \leftarrow \emptyset$;
3:  for $i = 1, \dots, n$ do
4:     $\mathcal{I}(z_i) \leftarrow \frac{1}{n} H_{\hat{\theta}}^{-1}\nabla_{\theta}\,\ell(z_i, \hat{\theta})$; // store the parameter influence of each example
5:  end for
6:  Initialize $W \leftarrow \mathbf{0} \in \{0, 1\}^n$;
7:  Solve the following problem to get $W$: $\max_{W} \sum_{i=1}^{n} w_i$ subject to $\left\|\sum_{i=1}^{n} w_i\,\mathcal{I}(z_i)\right\|_2 \le \epsilon$; // guarantee the generalization drop
8:  Construct the largest $\epsilon$-redundant subset $\hat{\mathcal{D}}_{\epsilon}^{\max} = \{z_i \mid w_i = 1\}$;
9:  return Pruned dataset $\mathcal{D} \setminus \hat{\mathcal{D}}_{\epsilon}^{\max}$;
Algorithm 1: Generalization-guaranteed dataset pruning.

5 Generalization Analysis

In this section, we theoretically formulate the generalization guarantee of the minimizer given by the pruned dataset. Our theoretical analysis suggests that the proposed dataset pruning method has a relatively tight upper bound on the expected test loss, given a small enough $\epsilon$.

For simplicity, we first assume that we prune only one training sample $z$ (namely, the weight $\delta$ is one-dimensional). Following the classical results, we may write the influence function of the test loss at a test point $z_{\text{test}}$ as

$$\mathcal{I}_{\text{loss}}(z, z_{\text{test}}) \triangleq \left.\frac{d\,\ell(z_{\text{test}}, \hat{\theta}_{\delta, z})}{d\delta}\right|_{\delta=0} = -\nabla_{\theta}\,\ell(z_{\text{test}}, \hat{\theta})^{\top} H_{\hat{\theta}}^{-1}\nabla_{\theta}\,\ell(z, \hat{\theta}), \tag{5}$$

which is the first-order derivative of the test loss with respect to $\delta$.

We may easily generalize the theoretical analysis to the case of pruning $m$ training samples, if we let $\delta = (\delta_1, \dots, \delta_m)$ be an $m$-dimensional vector and $\hat{\theta}_{\delta} = \arg\min_{\theta} \frac{1}{n}\sum_{z_i \in \mathcal{D}} \ell(z_i, \theta) + \sum_{j=1}^{m} \delta_j\,\ell(z_j, \theta)$. Then we may write the multi-dimensional form of the influence function as

$$\left.\frac{d\hat{\theta}_{\delta}}{d\delta}\right|_{\delta=0} = -H_{\hat{\theta}}^{-1}\left[\nabla_{\theta}\,\ell(z_1, \hat{\theta}), \dots, \nabla_{\theta}\,\ell(z_m, \hat{\theta})\right], \tag{6}$$

which is a $p \times m$ matrix, where $p$ indicates the number of parameters.

We define the expected test loss over the data distribution $P$ as $L(\theta) = \mathbb{E}_{z \sim P}\left[\ell(z, \theta)\right]$ and define the generalization gap due to dataset pruning as $L(\hat{\theta}_{-\hat{\mathcal{D}}}) - L(\hat{\theta})$. By using the influence function of the test loss singh2021phenomenology, we obtain Theorem 1, which formulates the upper bound of the generalization gap.

Theorem 1 (Generalization Gap of Dataset Pruning).

Suppose that the original dataset is $\mathcal{D}$ and the pruned dataset is $\mathcal{D} \setminus \hat{\mathcal{D}}$. If $\|\hat{\theta}_{-\hat{\mathcal{D}}} - \hat{\theta}\|_2 \le \epsilon$, we have the upper bound of the generalization gap as

$$L(\hat{\theta}_{-\hat{\mathcal{D}}}) - L(\hat{\theta}) \le \left\|\nabla_{\theta} L(\hat{\theta})\right\|_2 \epsilon + O(\epsilon^2). \tag{7}$$

Proof. We express the test loss at $\hat{\theta}_{-\hat{\mathcal{D}}}$ using the first-order Taylor approximation as

$$L(\hat{\theta}_{-\hat{\mathcal{D}}}) = L(\hat{\theta}) + \nabla_{\theta} L(\hat{\theta})^{\top}\big(\hat{\theta}_{-\hat{\mathcal{D}}} - \hat{\theta}\big) + O\!\left(\big\|\hat{\theta}_{-\hat{\mathcal{D}}} - \hat{\theta}\big\|_2^2\right), \tag{8}$$

where the last term is usually ignorable in practice, because $\|\hat{\theta}_{-\hat{\mathcal{D}}} - \hat{\theta}\|_2$ is very small for popular benchmark datasets. The same approximation is also used in related papers that use influence functions for generalization analysis singh2021phenomenology. According to Equation (8), we have

$$L(\hat{\theta}_{-\hat{\mathcal{D}}}) - L(\hat{\theta}) \le \left\|\nabla_{\theta} L(\hat{\theta})\right\|_2 \left\|\hat{\theta}_{-\hat{\mathcal{D}}} - \hat{\theta}\right\|_2 + O(\epsilon^2) \le \left\|\nabla_{\theta} L(\hat{\theta})\right\|_2 \epsilon + O(\epsilon^2), \tag{9}$$

where the first inequality is the Cauchy–Schwarz inequality, the second inequality is based on the algorithmic guarantee that $\|\hat{\theta}_{-\hat{\mathcal{D}}} - \hat{\theta}\|_2 \le \epsilon$ in Eq. (3), and $\epsilon$ is a hyperparameter. Keeping the second-order Taylor term, we finally obtain the upper bound of the expected loss as

$$L(\hat{\theta}_{-\hat{\mathcal{D}}}) \le L(\hat{\theta}) + \left\|\nabla_{\theta} L(\hat{\theta})\right\|_2 \epsilon + O(\epsilon^2). \tag{10}$$

The proof is complete. ∎
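As a quick numerical sanity check of the first-order bound (our own illustration, not part of the paper), one can evaluate the gap on a tiny quadratic test loss, where the Taylor expansion is exact and the second-order term can be computed explicitly:

```python
# Tiny "test distribution": five held-out points for the toy model f(x) = theta * x.
xs = [0.5, 1.5, 2.5, 3.5, 4.5]
ys = [0.4, 1.6, 2.4, 3.7, 4.3]

def L(theta):
    """Expected (here: averaged) test loss."""
    return sum((theta * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def L_grad(theta):
    return sum(2 * x * (theta * x - y) for x, y in zip(xs, ys)) / len(xs)

def L_hess(theta):
    return sum(2 * x * x for x in xs) / len(xs)

theta_hat = 1.0
eps = 0.01
# Any pruned-dataset minimizer within distance eps of theta_hat ...
theta_pruned = theta_hat + eps
gap = L(theta_pruned) - L(theta_hat)
# ... obeys gap <= ||grad L|| * eps + (second-order term), matching the theorem.
bound = abs(L_grad(theta_hat)) * eps + 0.5 * L_hess(theta_hat) * eps ** 2
print(gap, bound)
```

Because the toy loss is quadratic in $\theta$, the half-Hessian term is the exact $O(\epsilon^2)$ remainder, so the bound holds with no slack beyond the sign of the gradient term.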

Theorem 1 demonstrates that, if we hope to effectively decrease the upper bound of the generalization gap due to dataset pruning, we should focus on constraining or even directly minimizing $\|\hat{\theta}_{-\hat{\mathcal{D}}} - \hat{\theta}\|_2$. The basic idea of the proposed optimization-based dataset pruning method is exactly to penalize this quantity, approximated by $\|\sum_i w_i\,\mathcal{I}(z_i)\|_2$, via discrete optimization, as shown in Eq. (3) and Eq. (4).

Moreover, our empirical results in the following section successfully verify the estimated generalization gap. The estimated generalization gap of the proposed optimization-based dataset pruning is small on CIFAR-10 and CIFAR-100, and it is smaller than the estimated generalization gap of random dataset pruning by more than one order of magnitude.

6 Experiment

In this section, we conduct experiments to verify the validity of the theoretical results and the effectiveness of the proposed dataset pruning method.

In Section 6.1, we introduce the experimental details. In Section 6.2, we empirically verify the validity of Theorem 1. In Sections 6.3 and 6.4, we compare our method with several baseline methods on dataset pruning performance and cross-architecture generalization. Finally, in Section 6.5, we show that the proposed dataset pruning method is extremely effective at improving training efficiency.

Figure 1: We compare our proposed optimization-based dataset pruning method with several sample-selection baselines on CIFAR-10 (left) and CIFAR-100 (right). Random selects training data randomly. Herding welling2009herding selects samples that are closest to the class center. Forgetting toneva2018empirical selects training samples that are easily forgotten during optimization.

6.1 Experiment Setup and Implementation Details

We evaluate the effectiveness of dataset pruning methods on the CIFAR10 and CIFAR100 datasets cifar10. To verify the cross-architecture generalization of the pruned dataset, we prune the dataset using ResNet50 he2016deep, then train GoogLeNet googlenet and DenseNet huang2017densely on the pruned dataset. All hyper-parameters and experimental settings of training before and after dataset pruning were controlled to be the same. Specifically, in all experiments, we train the model for 200 epochs with a batch size of 128, a learning rate of 0.01 with a cosine annealing learning-rate decay strategy, an SGD optimizer with momentum of 0.9 and weight decay of 5e-4, and data augmentation of random crop and random horizontal flip. In Eq. (2), we compute the inverse of the Hessian matrix using the second-order optimization trick agarwal2016second, which significantly reduces the estimation time of the Hessian inverse. To further improve efficiency, we only calculate the parameter influence in Eq. (2) for the last linear layer. In Eq. (3) and Eq. (4), we solve the discrete optimization problem using CVXPY diamond2016cvxpy with the CPLEX solver manual1987ibm. All experiments were run five times with different random seeds.
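The second-order trick of agarwal2016second avoids materializing $H^{-1}$ by estimating inverse-Hessian–vector products through a truncated Neumann series, $H^{-1}v = \sum_{j \ge 0}(I - H)^j v$, which only requires Hessian–vector products. A minimal sketch of the recursion on a toy diagonal Hessian (assuming the spectrum of $H$ lies in $(0,1)$, which in practice is arranged by scaling the loss) looks as follows:

```python
def inverse_hvp(hvp, v, iters=100):
    """Estimate H^{-1} v via the Neumann recursion
    est_{j+1} = v + (I - H) est_j, using only Hessian-vector products.
    Converges when the eigenvalues of H lie in (0, 1)."""
    est = list(v)
    for _ in range(iters):
        hv = hvp(est)
        est = [v_d + e_d - hv_d for v_d, e_d, hv_d in zip(v, est, hv)]
    return est

# Toy diagonal Hessian H = diag(0.5, 0.8); the exact answer is H^{-1} v = (2.0, 1.25).
def hvp(u):
    return [0.5 * u[0], 0.8 * u[1]]

approx = inverse_hvp(hvp, [1.0, 1.0])
print(approx)   # converges toward [2.0, 1.25]
```

In a real implementation `hvp` would be a stochastic Hessian–vector product computed by double backpropagation on mini-batches; the fixed point of the recursion satisfies $H\,\text{est} = v$ regardless of how the products are obtained.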

6.2 Theoretical Analysis Verification

Figure 2: The comparison of the empirically observed generalization gap and our theoretical expectation in Theorem 1. We ignore the $O(\epsilon^2)$ term since it has a much smaller magnitude than the first-order term $\|\nabla_{\theta} L(\hat{\theta})\|_2\,\epsilon$.

Our proposed optimization-based dataset pruning method tries to collect the smallest subset by constraining the parameter influence $\|\sum_i w_i\,\mathcal{I}(z_i)\|_2 \le \epsilon$. In Theorem 1, we demonstrate that the generalization gap of dataset pruning can be upper-bounded by $\|\nabla_{\theta} L(\hat{\theta})\|_2\,\epsilon + O(\epsilon^2)$. The $O(\epsilon^2)$ term can simply be ignored since it usually has a much smaller magnitude than the first-order term. To verify the validity of the generalization guarantee of dataset pruning, we compare the empirically observed test loss gap before and after dataset pruning with our theoretical expectation in Fig. 2. It can be clearly observed that the actual generalization gap is highly consistent with our theoretical prediction. We can also observe a strong correlation between the pruned-dataset generalization and $\epsilon$; therefore, Eq. (3) effectively guarantees the generalization of dataset pruning by constraining the parameter influence. Compared with random pruning, our proposed optimization-based dataset pruning exhibits a much smaller generalization gap.

6.3 Dataset Pruning

In the previous section, we motivated the optimization-based dataset pruning method by constraining or directly minimizing the network parameter influence of a selected data subset. The theoretical result and the empirical evidence show that constraining parameter influence can effectively bound the generalization gap of the pruned dataset. In this section, we evaluate the proposed optimization-based dataset pruning method empirically. We show that the test accuracy of a network trained on the pruned dataset is comparable to the test accuracy of the network trained on the whole dataset, and is competitive with other baseline methods.

We prune CIFAR10 and CIFAR100 cifar10 using a randomly initialized ResNet50 network he2016deep. We compare our proposed method with the following baselines: (a) Random pruning, which selects an expected number of training examples randomly. (b) Herding welling2009herding, which selects an expected number of training examples that are closest to the cluster center of each class. (c) Forgetting toneva2019empirical, which selects training examples that are easily forgotten. The forgetting score of a training example is defined as the number of times during training when its classification prediction switches from correct to incorrect. To make our method comparable to these cardinality-based pruning baselines, we prune datasets using the cardinality-guaranteed dataset pruning of Eq. (4). After dataset pruning and selecting a training subset, we obtain the test accuracy by retraining a new randomly initialized ResNet50 network on only the pruned dataset. For each experiment, we report the mean and standard deviation of the test accuracy over five individual runs with different random seeds. The experimental results are shown in Fig. 1, where our method consistently surpasses all baseline methods. The forgetting method achieves performance very close to ours when the pruning ratio is small, but the performance gap widens as the pruning ratio increases. This phenomenon also occurs with the other baselines, because all these baseline methods are score-based: they prune training examples by removing the lowest-score examples without considering the influence interactions between high-score and low-score examples. The influence of a combination of high-score examples may be minor, and vice versa. Our method overcomes this issue by considering the influence of a subset rather than each individual example; therefore, the performance of our method is superior, especially when the pruning ratio is high.

6.4 Unseen Architecture Generalization

Figure 3: To evaluate the unseen-architecture generalization of the pruned dataset, we prune the CIFAR10 dataset using ResNet50, then train a GoogLeNet and a DenseNet121 on the pruned dataset.

We conduct experiments to verify that the pruned dataset generalizes well to unknown network architectures that are inaccessible during dataset pruning. To this end, we use a ResNet50 to prune the dataset and then use the pruned dataset to train different network architectures. As shown in Fig. 3, we evaluate the CIFAR10 dataset pruned by ResNet50 on two unknown architectures, i.e., GoogLeNet googlenet and DenseNet121 huang2017densely. The experimental results show that the pruned dataset generalizes well to network architectures that are unknown during dataset pruning. This indicates that the pruned dataset can be used in a wide range of applications regardless of the specific network architecture.

6.5 Dataset Pruning Improves the Training Efficiency.

The pruned dataset can significantly improve training efficiency while maintaining performance, as shown in Fig. 4. Therefore, the proposed dataset pruning is beneficial when one needs to run many training trials on the same dataset. One such application is neural architecture search (NAS) zoph2018learning, which aims at searching for the network architecture that achieves the best performance on a specific dataset. A potentially powerful tool to accelerate NAS is to search architectures on the smaller pruned dataset, provided the pruned dataset has the same ability to identify the best network as the original dataset.

We construct a search space of 720 ConvNets with different depths, widths, pooling, activation, and normalization layers. We train all 720 models on the whole CIFAR10 training set and on four smaller proxy datasets constructed by random selection, herding welling2009herding, forgetting toneva2019empirical, and our proposed dataset pruning method. All four proxy datasets contain only 100 images per class. We train all models for 100 epochs. Each proxy dataset contains 1,000 images in total, which occupies only 2% of the storage cost of the whole dataset.

Figure 4: Dataset pruning significantly improves training efficiency with minor performance sacrifice. When pruning 40% of the training examples, the convergence time is nearly halved with only a 1.3% test accuracy drop. The pruned dataset can be used to tune hyper-parameters and network architectures to reduce the search time.

Table 1 reports (a) the average test performance of the best selected architectures trained on the whole dataset, (b) the Spearman's rank correlation coefficient between the validation accuracies obtained by training the selected top-10 models on the proxy dataset and on the whole dataset, (c) the time of training 720 architectures on a Tesla V100 GPU, and (d) the memory cost. Table 1 shows that searching 720 architectures on the whole dataset incurs a huge time cost. Randomly selecting 2% of the dataset for the architecture search decreases the search time from 3029 minutes to 113 minutes, but the searched architecture achieves much lower performance when trained on the whole dataset, making it far from the best architecture. Compared with the baselines, our proposed dataset pruning method achieves the best performance (the closest to that of the whole dataset) and significantly reduces the search time (to about 3% of that on the whole dataset).

                 Random   Herding   Forgetting   Ours    Whole Dataset
Performance (%)  79.4     80.1      82.5         85.7    85.9
Correlation      0.21     0.23      0.79         0.94    1.00
Time cost (min)  113      113       113          113     3029
Storage (imgs)   1,000    1,000     1,000        1,000   50,000
Table 1: Neural architecture search on proxy sets and the whole dataset. The search space is 720 ConvNets. We do experiments on CIFAR10 with a 100-images-per-class proxy dataset selected by random, herding, forgetting, and our proposed optimization-based dataset pruning. The network architecture selected by our pruned dataset achieves performance very close to the upper bound.

7 Conclusion

This paper studies the problem of dataset pruning, which aims at removing redundant training examples with minor impact on the model's performance. By theoretically examining the influence of removing a particular subset of training examples on the network's parameters, this paper proposes to model the sample selection procedure as a constrained discrete optimization problem. During sample selection, we constrain the network parameter change while maximizing the number of collected samples. The collected training examples can then be removed from the training set. Extensive theoretical and empirical studies demonstrate that the proposed optimization-based dataset pruning method is extremely effective at improving training efficiency while maintaining the model's generalization ability.