
Rapid Structural Pruning of Neural Networks with Set-based Task-Adaptive Meta-Pruning

As deep neural networks grow in size and are increasingly deployed to resource-limited devices, there has been a recent surge of interest in network pruning methods, which aim to remove less important weights or activations of a given network. A common limitation of most existing pruning techniques is that they require pretraining the network at least once before pruning, and thus the reductions in memory and computation benefit only inference time. However, reducing the training cost of neural networks with rapid structural pruning may be beneficial either to minimize the monetary cost of cloud computing or to enable on-device learning on a resource-limited device. Recently introduced random-weight pruning approaches can eliminate the need for pretraining, but they often obtain suboptimal performance compared with conventional pruning techniques and also do not allow for faster training, since they perform unstructured pruning. To overcome these limitations, we propose Set-based Task-Adaptive Meta Pruning (STAMP), which task-adaptively prunes a network pretrained on a large reference dataset by generating a pruning mask on it as a function of the target dataset. To ensure maximum performance improvements on the target task, we meta-learn the mask generator over different subsets of the reference dataset, such that it can generalize well to any unseen dataset within a few gradient steps of training. We validate STAMP against recent advanced pruning methods on benchmark datasets, on which it not only obtains significantly improved compression rates over the baselines at similar accuracy, but also orders of magnitude faster training speed.





1 Introduction

Deep learning has achieved remarkable progress in recent years on a variety of tasks, such as image classification Krizhevsky et al. (2012); Wang et al. (2017); Rawat and Wang (2017), object detection Lin et al. (2017b); Liu et al. (2020), and semantic segmentation Lin et al. (2017a); Huang et al. (2019). A key factor in the success of deep neural networks is their expressive power, which allows them to represent complex functions with high precision. Yet, such expressive power comes at the cost of increased memory and computational requirements. Moreover, there is an increasing demand to deploy deep neural networks to resource-limited devices, which may not have sufficient memory and computing power to run modern deep neural networks. Thus, many approaches have been proposed to reduce the size of deep neural networks, such as network pruning, training the model with sparsity-inducing regularizations or priors Han et al. (2015); Yoon and Hwang (2017); Lee et al. (2019a), network distillation Hinton et al. (2014); Hui et al. (2018), and network quantization Han et al. (2016); Jung et al. (2019). Arguably the most popular among them is network pruning, which aims to find an optimal subnetwork that is significantly smaller than the original network, either by removing its weights and activations (unstructured) or its filters and channels (structured). Structured pruning is often favored over unstructured pruning since GPUs can exploit its data locality to yield an actual reduction in inference time, while unstructured pruning sometimes leads to longer inference time than the full network Wen et al. (2016).
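The difference between the two pruning granularities can be illustrated with a toy sketch (ours, not from the paper; the matrix values and helper names are hypothetical). Unstructured pruning zeroes individual entries but leaves a matrix of the original shape, so a dense kernel still performs the same number of multiply-accumulates unless specialized sparse kernels are used; structured pruning removes whole rows (output channels) and yields a smaller dense matrix.

```python
# Toy comparison of unstructured vs. structured pruning for one linear layer
# y = W x, where W is stored as (out_features x in_features). Values are made up.

def dense_matvec_flops(weight):
    """Multiply-accumulates a dense kernel performs for W x, zeros included."""
    return len(weight) * len(weight[0])

def unstructured_prune(weight, threshold):
    """Zero individual entries with |w| < threshold; the matrix keeps its shape."""
    return [[w if abs(w) >= threshold else 0.0 for w in row] for row in weight]

def structured_prune(weight, keep_rows):
    """Remove whole output channels (rows); the result is a smaller dense matrix."""
    return [row for i, row in enumerate(weight) if i in keep_rows]

W = [
    [0.9, -0.1, 0.05, 0.7],
    [0.02, 0.01, -0.03, 0.04],
    [-0.8, 0.6, 0.0, 0.1],
]

W_unstructured = unstructured_prune(W, threshold=0.5)
W_structured = structured_prune(W, keep_rows={0, 2})
nonzeros = sum(1 for row in W_unstructured for w in row if w != 0.0)

print("unstructured:", nonzeros, "nonzeros, dense FLOPs =", dense_matvec_flops(W_unstructured))
print("structured:  ", dense_matvec_flops(W_structured), "dense FLOPs")
```

Here the unstructured version keeps only 4 of 12 weights but still costs 12 dense multiply-accumulates, whereas dropping one of the three output channels cuts the dense cost to 8.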

Yet, most conventional pruning techniques share a common limitation: they require a network pretrained on the target dataset. With such two-stage schemes, training inevitably takes more time than training the full network, and thus most works focus only on efficiency at inference time. However, in many real-world scenarios, it may be desirable to obtain training-time speedups with pruning. For instance, training a large network on a large dataset in the cloud may incur a large monetary cost (Figure 1 (a)). As another example, due to data-privacy concerns, we may need to train the network on resource-limited devices (Figure 1 (b)), but such a device may not have enough capacity even to load the original unpruned network into memory. Handling such diverse requirements efficiently for each end user is crucial for the success of a machine learning platform (Figure 1). How, then, can we perform pruning without pretraining on the target task?

A few recently introduced methods, such as SNIP Lee et al. (2019c) and Edge-Popup Ramanujan et al. (2019), allow pruning of randomly initialized neural networks such that, after fine-tuning, the pruned network obtains performance only marginally worse than that of the full network. This effectively eliminates the need for pretraining, and SNIP further reduces pruning time by pruning in one shot on a single mini-batch. However, these methods are limited in that they perform unstructured pruning, which does not result in meaningful speedups on GPUs at either inference or training time. Moreover, they underperform state-of-the-art structured pruning techniques that use pretraining. Thus, none of the existing works can obtain structurally pruned subnetworks that provide practical speedups at both training and inference time with minimal accuracy loss over the full network.

To achieve this challenging goal, we first note that in real-world scenarios, a network pretrained on a large reference dataset may already be available (Figure 1 (c)). If we could prune such a pretrained reference network so that it obtains good performance on an unseen target task, it would be highly efficient, since we would only need to train the model once and could use it for any given task. However, pruning a network trained on a different dataset may yield a suboptimal subnetwork for the target task. Thus, to ensure that the pruned network is a near-optimal subnetwork for an unseen task, we propose to meta-learn a task-adaptive pruning mask generator as a set function, such that, given a few samples of the target dataset, it instantly generates a task-optimal subnetwork of the pretrained reference network.

We validate our Set-based Task-Adaptive Meta Pruning (STAMP) on multiple benchmark datasets against recently proposed structural pruning and random-weight pruning baselines. The results show that our method can rapidly prune a network into one that is significantly more compact than networks of similar accuracy obtained with baseline pruning techniques. Further, this rapid structural pruning allows our model to significantly reduce the training cost in terms of memory, computation, and wall-clock time, with minimal accuracy loss. Such efficiency makes STAMP appealing as a cheap alternative to neural architecture search in machine learning platforms (see Figure 1). The contribution of our work is threefold:

Figure 1: Illustration of our Set-based Task-Adaptive Meta-Pruning (STAMP): STAMP meta-learns a general strategy to rapidly perform structural pruning of a reference network for unseen tasks. If a learner provides a small fraction of the data for their target task, STAMP almost instantly provides an optimally pruned network architecture, which trains faster than the full network with minimal accuracy loss.
  • We propose a novel set-based structured pruning model, which instantly generates a pruning mask for a given dataset to prune a target network.

  • We suggest a meta-learning framework to train our set-based pruning mask generator, which obtains an approximately optimal subnetwork within a few gradient steps on an unseen task.

  • We validate our meta-pruning model on benchmark datasets against structured and random-weight pruning baselines, and show that it obtains significantly more compact subnetworks, which require only a fraction of the wall-clock time to train to the target accuracy.

2 Related Work

Neural network pruning.

In recent years, there has been a surge of interest in weight pruning schemes for deep neural networks to promote memory- and computation-efficient models. Unstructured pruning prunes the weights of the network without considering its structure. Some unstructured pruning methods have been shown to obtain extremely sparse networks that match the accuracy of the full network, such as iterative magnitude pruning Han et al. (2015), which alternates between pruning and retraining to recover from the damage caused by pruning. The Lottery Ticket Hypothesis (LTH) Frankle and Carbin (2019) posits the existence of a subnetwork, referred to as the winning ticket, which matches the accuracy and training time of the full network, and shows that such subnetworks can be found with iterative magnitude pruning. SNIP Lee et al. (2019c) proposes a simple pruning method which can identify a similar subnetwork without pretraining, in a single shot. Though SNIP does not strictly find a winning ticket, it is highly efficient and opens the possibility of further research on rapid pruning without pretraining. Edge-Popup Ramanujan et al. (2019), which is also simple, finds optimal subsets of random weights without any pretraining. However, SNIP is faster than Edge-Popup at searching for a pruned network.
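A toy sketch can contrast the magnitude criterion with SNIP's connection sensitivity. SNIP scores each weight by |g_j w_j| (the gradient of the loss with respect to a multiplicative mask on that weight), computed on a single mini-batch and normalized, and keeps the top-scoring weights. The linear model, data, and helper names below are hypothetical; SNIP itself computes these gradients through a full, randomly initialized network.

```python
# Toy pruning criteria on a linear model y_hat = sum_j w_j * x_j with MSE loss.

def magnitude_scores(weights):
    """Magnitude pruning ranks each weight by |w_j|."""
    return [abs(w) for w in weights]

def snip_scores(weights, batch):
    """SNIP connection sensitivity |g_j * w_j|, normalized to sum to one.

    g_j is the gradient of the mini-batch loss w.r.t. w_j; |g_j * w_j| equals the
    gradient w.r.t. a multiplicative mask c_j evaluated at c_j = 1.
    """
    grads = [0.0] * len(weights)
    for x, y in batch:
        residual = sum(w * xi for w, xi in zip(weights, x)) - y
        for j, xj in enumerate(x):
            grads[j] += 2.0 * residual * xj / len(batch)
    raw = [abs(g * w) for g, w in zip(grads, weights)]
    total = sum(raw)
    return [r / total for r in raw]

def top_k_mask(scores, k):
    """Binary mask keeping the k highest-scoring weights."""
    keep = set(sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)[:k])
    return [1 if j in keep else 0 for j in range(len(scores))]

weights = [0.5, -0.4, 0.3, 0.1]
batch = [([1.0, 0.0, 0.0, 5.0], 3.0), ([0.0, 1.0, 0.0, 5.0], 2.0)]

print("magnitude mask:", top_k_mask(magnitude_scores(weights), k=2))  # [1, 1, 0, 0]
print("SNIP mask:     ", top_k_mask(snip_scores(weights, batch), k=2))  # [1, 0, 0, 1]
```

The smallest weight survives under SNIP because its input feature dominates the loss gradient on this mini-batch, which magnitude pruning cannot see.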

Although unstructured pruning methods find extremely sparse subnetworks and are simple, their poor data locality makes it difficult to reduce network inference time on general-purpose hardware. Due to this limitation, recent works Liu et al. (2017, 2019b); He et al. (2017); Guo et al. (2020); Luo et al. (2017); Zhuang et al. (2018) prune groups of weights (e.g., channels or neurons) to achieve an actual reduction in model size. Such structured pruning methods are useful in resource-limited environments, since the compressed architectures practically reduce memory requirements and running time at inference. SSL Wen et al. (2016) introduces a structured sparsity regularization method that prunes neurons using the (2,1)-norm during training. CGES Yoon and Hwang (2017) proposes to combine group sparsity with exclusive sparsity regularization. VIBNet Dai et al. (2018) utilizes the variational information bottleneck principle to compress neural networks; it compels the networks to minimize neuron redundancy across adjacent layers with binary mask vectors. Beta-Bernoulli Dropout (BBDropout) Lee et al. (2019a) learns a structured dropout function sampled from a Bernoulli distribution whose probability is drawn from a beta distribution with learnable parameters. Further, it introduces a data-dependent BBDropout, which generates a pruning mask as a function of the given data instance.
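Two of the structured mechanisms above can be sketched in a few lines: the (2,1)-norm used by SSL, which sums the L2 norms of weight groups (here, output channels) so that penalizing it drives whole groups to zero, and the generative side of a Beta-Bernoulli channel mask, where a keep probability is drawn from a beta distribution and a binary mask entry from a Bernoulli distribution. This is a simplified sketch with fixed, hypothetical parameters; BBDropout learns these distributions with a variational objective, which is not shown.

```python
import math
import random

def l21_norm(weight):
    """(2,1)-norm over output channels: sum of the L2 norms of the rows of W."""
    return sum(math.sqrt(sum(w * w for w in row)) for row in weight)

def zero_channels(weight, tol=1e-8):
    """Output channels whose whole row has (numerically) vanished."""
    return [i for i, row in enumerate(weight) if math.sqrt(sum(w * w for w in row)) < tol]

def sample_beta_bernoulli_mask(alphas, betas, rng):
    """Simplified Beta-Bernoulli channel mask: pi_k ~ Beta(a_k, b_k), z_k ~ Bernoulli(pi_k)."""
    mask = []
    for a, b in zip(alphas, betas):
        pi = rng.betavariate(a, b)
        mask.append(1 if rng.random() < pi else 0)
    return mask

W = [[3.0, 4.0], [0.0, 0.0], [1.0, 0.0]]
print(l21_norm(W))       # 6.0  (= 5 + 0 + 1)
print(zero_channels(W))  # [1]

rng = random.Random(0)
alphas, betas = [5.0, 0.1, 5.0], [0.1, 5.0, 0.1]
samples = [sample_beta_bernoulli_mask(alphas, betas, rng) for _ in range(2000)]
keep_rate = [sum(m[k] for m in samples) / len(samples) for k in range(3)]
print(keep_rate)  # channels 0 and 2 are kept far more often than channel 1
```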

Meta learning.

Meta-learning, which learns over a distribution of tasks, has shown its efficiency in handling unseen tasks in various settings, such as few-shot learning and sample-efficient reinforcement learning. The most popular meta-learning methods are gradient-based approaches such as MAML Finn et al. (2017) and REPTILE Nichol and Schulman (2018), which aim to find an initialization that can rapidly adapt to new tasks. BASE Shaw et al. (2019) learns through the MAML algorithm to rapidly search for optimal neural architectures, and thus significantly reduces the search cost compared with state-of-the-art neural architecture search (NAS) methods Liu et al. (2019a); Xie et al. (2018). BASE learns a general prior through meta-learning to perform fast adaptation to unseen tasks. In contrast, our method learns a good initialization as a function of a set, such that it can rapidly adapt to the given target task. MetaPruning Liu et al. (2019b) trains a hypernetwork that can generate sparse weights for any possible structure (i.e., the number of channels) of a network architecture. However, the hypernetwork does not generalize across tasks, and thus the method requires training one hypernetwork per task.
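The idea of learning an initialization that adapts quickly can be seen in a toy, batched version of Reptile: run a few SGD steps on each task, then move the shared initialization toward the average adapted weights. The tasks (noise-free 1-D regressions y = slope * x), learning rates, and function names are hypothetical; for these symmetric quadratic tasks the learned initialization converges to the mean slope.

```python
# Toy batched Reptile on 1-D regression tasks y = slope * x (noise-free).
XS = [0.5, 1.0, 1.5, 2.0]

def task_grad(w, slope):
    """Gradient w.r.t. w of the mean squared error of w * x against slope * x."""
    return sum(2.0 * (w - slope) * x * x for x in XS) / len(XS)

def reptile(slopes, w0=0.0, inner_lr=0.05, inner_steps=3, meta_lr=0.5, iterations=200):
    w = w0
    for _ in range(iterations):
        adapted = []
        for slope in slopes:
            w_task = w
            for _ in range(inner_steps):  # a few SGD steps on one task
                w_task -= inner_lr * task_grad(w_task, slope)
            adapted.append(w_task)
        # Move the shared initialization toward the average adapted weight.
        w += meta_lr * (sum(adapted) / len(adapted) - w)
    return w

w_init = reptile([1.0, 2.0, 3.0, 6.0])
print(round(w_init, 4))  # 3.0
```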

3 Rapid Structural Pruning of Neural Networks with Set-based Task-Adaptive Meta-Pruning

We introduce a novel structural pruning method for deep neural networks, Set-based Task-Adaptive Meta-Pruning (STAMP), which rapidly searches for and prunes uninformative units/filters of an initial neural network trained on other reference datasets. In Section 3.1, we define an optimization problem for deep neural networks with pruning masks. In Section 3.2, we describe our set-based structural pruning method, which efficiently reduces the model size in a few gradient steps while avoiding accuracy degradation. Finally, in Section 3.3, we describe our full meta-learning framework to train a pruning mask generator that generalizes to unseen tasks.

3.1 Problem Definition

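Suppose that we have a neural network $f(\mathbf{X}; \mathcal{W})$, which is a function of the dataset $\mathcal{D} = \{\mathbf{X}, \mathbf{Y}\}$ parameterized by a set of model weights $\mathcal{W} = \{\mathbf{W}^{(l)}\}_{l=1}^{L}$, where $\mathbf{W}^{(l)}$ denotes the weights of layer $l$. Further suppose that the network has a maximum desired cost $\kappa$ (e.g., FLOPs, memory, the number of parameters, or training/inference time), which depends on the hardware capability and the application. Denoting the total cost of the model as $C(\mathcal{W})$, we formulate the search for a network that minimizes the task loss while satisfying the cost budget as the following optimization problem:

$$\min_{\mathcal{W}} \; \mathcal{L}(\mathcal{W}; \mathcal{D}) + \mathcal{R}(\mathcal{W}) \quad \text{s.t.} \quad C(\mathcal{W}) \le \kappa,$$

where $\mathcal{R}$ is an arbitrary regularization term. To obtain an optimal model with the desired cost, we follow the popular pruning strategy of adopting sparsity-inducing masking parameters for deep neural networks. We reformulate the problem as obtaining compressed weights $\tilde{\mathbf{W}}^{(l)} = \mathbf{W}^{(l)} \odot \mathbf{M}^{(l)}$ with a corresponding binary mask $\mathbf{M}^{(l)}$, of the same shape as $\mathbf{W}^{(l)}$, at each layer $l$. This results in unstructured pruning, which prunes individual weight elements. To yield actual wall-clock speedups in training and inference, we instead compress the model by structured pruning: we generate structural pruning masks, so that the compressed weights are expressed as $\tilde{\mathbf{W}}^{(l)} = \mathbf{W}^{(l)} \odot \mathbf{m}^{(l)}$, where $\mathbf{m}^{(l)} \in \{0, 1\}^{O_l}$ is a binary mask over the $O_l$ output channels of layer $l$, broadcast over the remaining weight dimensions. The objective function is then defined as

$$\min_{\mathcal{W}, \mathcal{M}} \; \mathcal{L}(\tilde{\mathcal{W}}; \mathcal{D}) + \mathcal{R}(\tilde{\mathcal{W}}) \quad \text{s.t.} \quad C(\tilde{\mathcal{W}}) \le \kappa,$$

where $\tilde{\mathcal{W}} = \{\mathbf{W}^{(l)} \odot \mathbf{m}^{(l)}\}_{l=1}^{L}$ and $\mathcal{M} = \{\mathbf{m}^{(l)}\}_{l=1}^{L}$.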
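To make the structured masks of this section concrete, the following minimal sketch computes one possible cost, the parameter count, of a masked network and checks it against a budget. It assumes a plain stack of 3x3 convolutions without biases or skip connections; the shapes and helper names are hypothetical. It also shows why structured masks compound: removing output channels of one layer shrinks the input of the next.

```python
def conv_params(channels, masks=None, in_channels=3, kernel=3):
    """Weight count of a plain conv stack (no biases) after channel pruning.

    channels[l] is the original number of output channels of layer l and
    masks[l] its binary channel mask (None keeps every channel). Pruning an
    output channel of layer l also removes the matching input channel of l + 1.
    """
    if masks is None:
        masks = [[1] * c for c in channels]
    total, prev = 0, in_channels
    for out_ch, mask in zip(channels, masks):
        if len(mask) != out_ch:
            raise ValueError("mask length must match the layer's channel count")
        kept = sum(mask)
        total += kept * prev * kernel * kernel
        prev = kept
    return total

def within_budget(channels, masks, budget_ratio):
    """Check the cost constraint, using parameter count as the cost function."""
    return conv_params(channels, masks) <= budget_ratio * conv_params(channels)

channels = [4, 8]
masks = [[1, 0, 1, 0], [1, 1, 1, 1, 0, 0, 0, 0]]
print(conv_params(channels), conv_params(channels, masks))  # 396 126
print(within_budget(channels, masks, 0.35), within_budget(channels, masks, 0.30))  # True False
```

Keeping half of the channels in each layer leaves 126 of 396 weights (about 32%), less than the 50% a per-layer count would suggest, because the second layer loses both output and input channels.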

3.2 Rapid Structural Pruning with Set-encoded Representation

To obtain an optimal pruned structure for the target task, we need to exploit knowledge of the given task. Conventional pruning schemes search for the desired subnetworks through full mini-batch training, in which all instances are processed over numerous iterations, incurring excessive training cost as the dataset grows. To bypass this time-consuming search and rapidly obtain task-adaptive pruning masks, we adopt two learnable functions, each with its own parameters: a set encoding function that produces a set-encoded representation of the data, and a mask generative function that produces a binary mask vector m. That is, at each layer, the model generates the task-adaptive mask vector from the dataset-level representation produced by the set encoding function. To reduce the burden of encoding the entire dataset, we encode only a sampled subset (batch) of the target dataset. To this end, we formulate the objective of our set-based task-adaptive pruning as follows:


where is a batch dimension of the set representation. Throughout the paper, we use . Figure 2 illustrates the set-based task-adaptive pruning model.

Figure 2: Set-based Task-Adaptive Pruning: We sample a subset from the dataset and train the model while simultaneously optimizing set-based binary masks through a set encoding function and mask generative functions.
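The set-encoding idea can be sketched with a simpler, DeepSets-style encoder (Zaheer et al. (2017), which the paper lists as one possible plug-in) in place of STAMP's transformer module: each instance of the sampled subset is encoded independently, the encodings are mean-pooled into one dataset-level vector, and a head maps that vector to per-channel keep probabilities that are thresholded into a binary mask. The class name, layer sizes, random untrained weights, and the 0.5 threshold are hypothetical; STAMP's actual mask generator is based on Beta-Bernoulli dropout.

```python
import math
import random

class SetMaskGenerator:
    """Hypothetical stand-in: mean-pooling set encoder + per-channel mask head.

    Weights are random and untrained; STAMP itself uses a transformer-based set
    encoder and a Beta-Bernoulli dropout mask generator.
    """

    def __init__(self, in_dim, hidden_dim, num_channels, seed=0):
        rng = random.Random(seed)
        self.enc = [[rng.gauss(0.0, 1.0) for _ in range(in_dim)] for _ in range(hidden_dim)]
        self.head = [[rng.gauss(0.0, 1.0) for _ in range(hidden_dim)] for _ in range(num_channels)]

    def encode(self, subset):
        """Encode each instance with a ReLU layer, then mean-pool: order-invariant."""
        pooled = [0.0] * len(self.enc)
        for x in subset:
            for h, row in enumerate(self.enc):
                pooled[h] += max(0.0, sum(w * xi for w, xi in zip(row, x))) / len(subset)
        return pooled

    def keep_probs(self, subset):
        z = self.encode(subset)
        return [1.0 / (1.0 + math.exp(-sum(w * zi for w, zi in zip(row, z)))) for row in self.head]

    def mask(self, subset, threshold=0.5):
        """Deterministic binary channel mask obtained by thresholding keep probabilities."""
        return [1 if p > threshold else 0 for p in self.keep_probs(subset)]

data_rng = random.Random(1)
subset = [[data_rng.random() for _ in range(4)] for _ in range(16)]
generator = SetMaskGenerator(in_dim=4, hidden_dim=8, num_channels=6)
shuffled = subset[:]
data_rng.shuffle(shuffled)
print(generator.mask(subset), generator.mask(shuffled))  # identical masks
```

Mean pooling makes the mask a function of the subset as a set: reordering the sampled instances does not change it, which is the property a set-based mask generator needs.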

3.3 Meta-update for Unseen Target Tasks

We now describe how we learn the optimal parameters of the set encoder and the mask generator for set-based pruning. The simplest approach is to perform gradient descent through backpropagation, as our model is end-to-end differentiable. However, this only yields parameters optimized for a specific task, which do not carry over to an unseen task. As stated earlier, we want to apply our method across different task domains, such that we learn the pruning mask generator and the set encoder on the source (reference) dataset and apply them to the target dataset. To this end, we apply a gradient-based meta-learning method that obtains initialization parameters which rapidly adapt to given (unseen) tasks.

Specifically, we train the parameters on multiple tasks sampled from the source dataset by computing inner gradients in the inner loop and combining them to update the outer-loop parameters. The objective of the meta-training step of STAMP is thus to learn a good initialization of the set encoder and mask generator parameters in the outer loop. We sample tasks from the source dataset. From each task, a batch is sampled and divided into mini-batches, which are used for inner-loop gradient updates with respect to the loss function and the regularization terms described in Section 3.2. Note that the whole batch, excluding labels, is used to encode the set representation. To update the outer-loop parameters at each epoch, we use only the gradients of the last mini-batch, similarly to first-order MAML Finn et al. (2017), to accelerate learning, as follows:


After meta-learning the parameters, we can adapt them to various unseen tasks by performing a few gradient update steps, at most one epoch. Through this meta-learning procedure, we can speed up training on the target task by starting from the pruned network architecture at an early stage. We describe the whole process in Algorithm 1.

While various set encoding methods Edwards and Storkey (2016); Zaheer et al. (2017) or pruning methods can be plugged into the proposed framework, STAMP adopts a transformer module Lee et al. (2019b) as the set encoding function and a set-based pruning mask generator based on Beta-Bernoulli dropout Lee et al. (2019a). The details of the set encoder and the structural mask generation function are described in the Appendix (Section B).

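0:  Input: source dataset, target dataset, learnable parameters of the set encoder and mask generator
1:  function STAMP
2:     for each meta-training epoch do
3:         for each task in parallel do
4:            Sample a task from the source dataset
5:            Sample a batch from the task and split it into mini-batches
6:            Compute the inner-loop updates with Eq. 2
7:         end for
8:         Update the meta-parameters with the last mini-batch gradients of all tasks
9:     end for
10:  end function
11:  Meta-train with function STAMP
12:  Prune step: adapt the meta-learned parameters on the target dataset with Eq. 2
13:  Finetune the pruned architecture to minimize the task loss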
Algorithm 1 Set-based Task-Adaptive Meta-Pruning (STAMP)
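A toy, scalar version of the meta-training loop in Algorithm 1 may help, under our reading of Section 3.3: each task's batch is split into mini-batches, inner-loop updates use all but the last mini-batch, and the outer update averages, over tasks, the gradient of the last mini-batch evaluated at the adapted weights without backpropagating through the inner loop (first-order). The scalar weight stands in for the set encoder and mask generator parameters, plain SGD replaces the Adam optimizer used in the paper, and all names and tasks are hypothetical.

```python
import random

def mse_grad(w, minibatch):
    """d/dw of the mean of (w * x - y)^2 over a mini-batch."""
    return sum(2.0 * (w * x - y) * x for x, y in minibatch) / len(minibatch)

def sample_task_batch(slope, size, rng):
    """Hypothetical task: noise-free regression y = slope * x."""
    return [(x, slope * x) for x in (rng.uniform(0.5, 2.0) for _ in range(size))]

def meta_train(slopes, epochs=300, batch_size=40, num_minibatches=5,
               inner_lr=0.05, outer_lr=0.05, seed=0):
    rng = random.Random(seed)
    w = 0.0                                   # meta-learned initialization
    mb = batch_size // num_minibatches
    for _ in range(epochs):
        outer_grads = []
        for slope in slopes:                  # "for each task (in parallel)"
            data = sample_task_batch(slope, batch_size, rng)
            minibatches = [data[i * mb:(i + 1) * mb] for i in range(num_minibatches)]
            w_task = w
            for batch in minibatches[:-1]:    # inner-loop updates
                w_task -= inner_lr * mse_grad(w_task, batch)
            # First-order: gradient of the last mini-batch at the adapted weight,
            # with no backpropagation through the inner loop.
            outer_grads.append(mse_grad(w_task, minibatches[-1]))
        w -= outer_lr * sum(outer_grads) / len(outer_grads)
    return w

w_meta = meta_train([1.0, 2.0, 3.0])
print(round(w_meta, 3))
```

Dropping second-order terms keeps each outer step as cheap as one extra mini-batch gradient per task, which is the motivation the paper gives for the first-order variant.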

4 Experiments

We demonstrate the effectiveness of STAMP with two widely used network architectures, namely VGGNet-19 Zhuang et al. (2018) and ResNet-18 He et al. (2016), on two benchmark datasets (CIFAR-10 and SVHN). We implement all experiments in the PyTorch framework and use a Titan XP GPU to measure wall-clock time.


We validate STAMP against recent structured pruning methods as well as unstructured random-weight pruning methods. We also report results for a variant of STAMP that only searches for the structure and randomly reinitializes the weights (STAMP-Structure). The baselines used in our comparative study are as follows:
1) MetaPruning Liu et al. (2019b): a structured pruning method which learns hypernetworks to generate pruned weights at each layer and searches for the optimal pruned structure using an evolutionary algorithm.
2) BBDropout Lee et al. (2019a): Beta-Bernoulli Dropout, which performs structured pruning of channels by sampling sparse masks on them.
3) Random Pruning: random pruning of channels. We sample the random structure (i.e., the number of channels in each layer) under the given FLOP constraints in the same manner as in MetaPruning Liu et al. (2019b).
4) Edge-Popup Ramanujan et al. (2019): an unstructured pruning method that searches for the best-performing subnetwork of a network with random weights.
5) SNIP Lee et al. (2019c): one-shot unstructured pruning on random weights. We also report results for a variant of SNIP that starts from pretrained weights (SNIP (P)).
For finetuning, we follow the standard setting of Zhuang et al. (2018) and perform mini-batch SGD training for 200 epochs with a batch size of 128.
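One way to sample a random structure under a FLOP constraint, as the Random Pruning baseline does, is simple rejection sampling: draw each layer's width uniformly and keep the first structure that fits the budget. This is a sketch under that assumption, not a reproduction of MetaPruning's exact procedure; the cost model (a plain stack of 3x3 convolutions at a fixed 32x32 resolution, no pooling), the layer widths, and all names are hypothetical.

```python
import random

def conv_flops(channels, in_channels=3, kernel=3, spatial=32 * 32):
    """Multiply-accumulates of a plain conv stack at a fixed spatial resolution."""
    total, prev = 0, in_channels
    for c in channels:
        total += c * prev * kernel * kernel * spatial
        prev = c
    return total

def sample_random_structure(full_channels, flops_budget, rng, max_tries=10_000):
    """Rejection sampling: draw each layer's width uniformly and keep the first
    structure whose FLOPs fit within the budget."""
    for _ in range(max_tries):
        channels = [rng.randint(1, c) for c in full_channels]
        if conv_flops(channels) <= flops_budget:
            return channels
    raise RuntimeError("no sampled structure satisfied the FLOP budget")

full = [64, 64, 128, 128]
budget = conv_flops(full) / 3  # e.g., a 3x FLOP reduction target
structure = sample_random_structure(full, budget, random.Random(0))
print(structure, conv_flops(structure) <= budget)
```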

Networks and datasets

As for the base networks, we use a modified version of VGGNet-19 with 16 convolution layers and a single fully connected layer, and ResNet-18 with an additional convolution layer on the shortcut operation to resolve the dimensionality difference between pruned units/filters. We use VGG-19 and ResNet-18 trained on CIFAR-100 as the global reference network, and use CIFAR-10 and SVHN as the target tasks for evaluation of the pruning performance.


We meta-train our pruning mask generator on the CIFAR-100 dataset. During meta-training, we divide CIFAR-100 into 10 tasks (subsets), each of which contains 10 disjoint classes, and sample 64 instances per class. We use a total of 640 instances as the input to the set function to generate a set representation for each task. We also use the sampled instances for model training, dividing them into 5 mini-batches (128 instances each). We use first-order MAML with the Adam optimizer for both inner and outer parameter updates.
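The task construction above can be sketched as follows: shuffle the 100 class ids, split them into 10 disjoint groups of 10, draw 64 instances per class, and split each task's 640 instances into 5 mini-batches of 128. Instances are represented by hypothetical (class, index) pairs rather than images, and the shuffling and sampling details are assumptions; only the counts come from the text.

```python
import random

def make_meta_tasks(num_classes=100, classes_per_task=10, per_class=64,
                    minibatch_size=128, seed=0):
    """Split class ids into disjoint tasks and sample per-class instance ids."""
    rng = random.Random(seed)
    classes = list(range(num_classes))
    rng.shuffle(classes)
    tasks = []
    for t in range(num_classes // classes_per_task):
        task_classes = classes[t * classes_per_task:(t + 1) * classes_per_task]
        # (class, index) pairs stand in for images; CIFAR-100 has 500 training
        # images per class, from which per_class are drawn without replacement.
        instances = [(c, i) for c in task_classes for i in rng.sample(range(500), per_class)]
        rng.shuffle(instances)
        minibatches = [instances[i:i + minibatch_size]
                       for i in range(0, len(instances), minibatch_size)]
        tasks.append({"classes": task_classes, "set_input": instances,
                      "minibatches": minibatches})
    return tasks

tasks = make_meta_tasks()
print(len(tasks), len(tasks[0]["set_input"]), [len(b) for b in tasks[0]["minibatches"]])
# 10 640 [128, 128, 128, 128, 128]
```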

For more details on training of the baseline methods and meta-training for STAMP, such as learning rate scheduling, please see the Appendix (Section C).

4.1 Quantitative Evaluation

We report the results of pruning VGGNet-19 on CIFAR-10 and SVHN in Table 1, and ResNet-18 on CIFAR-10 in Table 2. We compare the accuracy as well as the wall-clock training and inference time of all models at similar compression rates (parameters used or FLOPs).

Dataset | Method | Accuracy (%) | P (%) | FLOPs (speedup) | Training time | Inference time | Expense
CIFAR-10 | Full Network | 93.72 ± 0.07 | 100 | x1.00 | 0.78 h | 0.85 sec | 1.13 $
CIFAR-10 | SNIP (P) Lee et al. (2019c) | 92.98 ± 0.22 | 4.17 | x1.00 | 0.83 h | 0.92 sec | 1.21 $
CIFAR-10 | SNIP Lee et al. (2019c) | 92.85 ± 0.24 | 4.17 | x1.00 | 0.83 h | 0.92 sec | 1.21 $
CIFAR-10 | Random Pruning | 92.01 ± 0.29 | 32.20 | x3.33 | 0.43 h | 0.42 sec | 0.62 $
CIFAR-10 | MetaPruning Liu et al. (2019b) | 92.12 ± 0.47 | 21.84 | x3.58 | 4.99 h | 0.41 sec | 7.28 $
CIFAR-10 | BBDropout Lee et al. (2019a) | 92.97 ± 0.10 | 3.99 | x3.42 | 2.07 h | 0.43 sec | 3.02 $
CIFAR-10 | STAMP-Structure | 92.69 ± 0.13 | 4.43 | x3.48 | 0.44 h | 0.36 sec | 0.64 $
CIFAR-10 | STAMP | 93.49 ± 0.04 | 4.16 | x3.56 | 0.44 h | 0.36 sec | 0.64 $
SVHN | Full Network | 95.99 ± 0.07 | 100 | x1.00 | 1.21 h | 2.42 sec | 1.76 $
SVHN | SNIP (P) Lee et al. (2019c) | 95.56 ± 0.09 | 3.08 | x1.00 | 1.22 h | 2.45 sec | 1.78 $
SVHN | SNIP Lee et al. (2019c) | 95.52 ± 0.10 | 3.08 | x1.00 | 1.22 h | 2.45 sec | 1.78 $
SVHN | Random Pruning | 95.56 ± 0.12 | 28.95 | x3.40 | 0.62 h | 1.27 sec | 0.90 $
SVHN | MetaPruning Liu et al. (2019b) | 95.50 ± 0.07 | 22.04 | x3.64 | 2.08 h | 1.44 sec | 3.03 $
SVHN | BBDropout Lee et al. (2019a) | 95.98 ± 0.19 | 2.15 | x9.67 | 3.05 h | 0.86 sec | 4.45 $
SVHN | STAMP-Structure | 95.39 ± 0.15 | 3.08 | x4.60 | 0.58 h | 0.91 sec | 0.84 $
SVHN | STAMP | 95.82 ± 0.16 | 2.87 | x5.10 | 0.58 h | 0.91 sec | 0.84 $

Table 1: Experiment results of CIFAR-10 and SVHN on VGGNet. Training time consists of the time to search for the pruned network and to finetune it (200 epochs). Expense is computed by multiplying the training time by 1.46 $ per hour, the cost of using a Tesla P100 GPU on Google Cloud. The methods are divided into the full network without pruning, unstructured pruning methods, structured pruning methods, and STAMP. P is the ratio of remaining parameters. We run each experiment 3 times and report the mean ± std.
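The Expense column can be checked with a short computation using the CIFAR-10 rows of Table 1. Recomputed values may differ from the reported ones by about 0.01 $, presumably because the listed training times are rounded. The same rows also give training-time ratios relative to STAMP; these come from rounded table entries and are not necessarily the exact speedups quoted in the text.

```python
GPU_PRICE_PER_HOUR = 1.46  # Tesla P100 on Google Cloud, per the Table 1 caption

# CIFAR-10 / VGGNet rows of Table 1: (training hours, reported expense in $).
table1_cifar10 = {
    "Full Network": (0.78, 1.13),
    "SNIP": (0.83, 1.21),
    "MetaPruning": (4.99, 7.28),
    "BBDropout": (2.07, 3.02),
    "STAMP": (0.44, 0.64),
}

stamp_hours = table1_cifar10["STAMP"][0]
for method, (hours, reported) in table1_cifar10.items():
    recomputed = hours * GPU_PRICE_PER_HOUR
    ratio = hours / stamp_hours
    print(f"{method:12s} recomputed {recomputed:.2f} $ (reported {reported:.2f} $), "
          f"{ratio:.2f}x STAMP's training time")
```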

Accuracy over memory efficiency and FLOPs.

We first compare accuracy against parameter usage and theoretical computation cost (FLOPs). In Table 1 and Table 2, SNIP applied to either random networks (SNIP) or the pretrained reference network (SNIP (P)) significantly reduces the number of active parameters with a marginal drop in accuracy on both CIFAR-10 and SVHN. However, as these methods perform unstructured pruning, they cannot reduce FLOPs, which remain equal to those of the original full networks. In contrast, structural pruning approaches achieve actual FLOPs reduction by pruning groups of weights (e.g., units/filters). Interestingly, MetaPruning, which applies a hypernetwork learned on a reference architecture and dataset to prune for the target dataset, obtains suboptimal architectures that sometimes even underperform randomly pruned networks. This shows that the learned hypernetwork does not generalize across task domains, which is expected since it is not trained on diverse tasks. BBDropout achieves superior performance over the other baselines with a high model compression rate, but it requires a large amount of time to train the pruning mask generator, and thus slows down the training process relative to training the full network. STAMP, on the other hand, either outperforms or achieves comparable performance to all baselines in terms of both accuracy and compression rate. We further report the accuracy-sparsity trade-off for SNIP, BBDropout, and STAMP (Ours) in Figure 3 (a). Our method achieves better accuracy at similar compression rates, and shows only marginal performance degradation even with of the parameters remaining. Such good performance on unseen datasets is made possible by meta-learning the pruning mask generator.

Accuracy over wall-clock time for training/inference.

As described earlier, our main focus in this work is to significantly reduce training time by obtaining near-optimal compact deep networks for unseen targets on the fly, which is not possible with any of the existing approaches. As shown in Table 1 and Table 2, unstructured random-weight pruning methods (SNIP and Edge-Popup) do not yield any speedup in training time, and sometimes increase the cost of training over the full network (see the CIFAR-10 results in Table 1). These results are consistent with the findings of Frankle and Carbin (2019), which showed that most of the subnetworks require a larger number of training iterations than the full network.

While structured pruning methods yield speedups in inference time over the full networks, MetaPruning and BBDropout need and more training time than the full networks, respectively, to search for pruned architectures. In contrast, STAMP instantly obtains a good subnetwork (single or less than iterations according to the pruned ratio), which trains faster than the full network. STAMP is remarkably efficient compared with the other structural pruning baselines, achieving and speedups over MetaPruning and BBDropout, respectively, with better or comparable performance. We further report accuracy over training time for SNIP, BBDropout, and STAMP (Ours) in Figure 3 (b) and (c). Since our philosophy is "train once, use everywhere", once the mask generator is meta-learned on a reference dataset, it can be applied to any number of tasks without additional cost. Thus, Table 1 and Table 2 exclude the meta-training time of STAMP (15 h on VGGNet and 30 h on ResNet) and of MetaPruning (1.2 h per task).

(a) Acc. over sparsity (CIFAR-10) (b) Acc. over time (CIFAR-10) (c) Acc. over time (SVHN)
Figure 3: (a): Accuracy over the ratio of used parameters for CIFAR-10 on VGGNet. Full denotes the accuracy of the VGGNet before pruning. (b-c): Accuracy over training time for CIFAR-10 and SVHN.

Method | Accuracy (%) | P (%) | FLOPs (speedup) | Training time | Inference time | Expense
Full Network | 94.37 ± 0.12 | 100 | x1.00 | 1.08 h | 1.02 sec | 1.57 $
Edge-Popup Ramanujan et al. (2019) | 89.50 ± 3.46 | 10.00 | x1.00 | 1.38 h | 2.50 sec | 2.01 $
SNIP (P) Lee et al. (2019c) | 93.17 ± 0.00 | 10.04 | x1.00 | 1.71 h | 1.90 sec | 2.49 $
SNIP Lee et al. (2019c) | 93.11 ± 0.00 | 10.04 | x1.00 | 1.71 h | 1.90 sec | 2.49 $
Random Pruned | 91.95 ± 0.65 | 69.77 | x3.65 | 0.58 h | 0.58 sec | 0.84 $
MetaPruning Liu et al. (2019b) | 91.01 ± 0.91 | 66.02 | x4.09 | 3.80 h | 0.58 sec | 5.54 $
BBDropout Lee et al. (2019a) | 93.47 ± 0.14 | 5.94 | x4.11 | 2.17 h | 0.54 sec | 3.16 $
 | 93.63 ± 0.08 | 9.07 | x4.08 | 0.57 h | 0.54 sec | 0.83 $
STAMP | 93.61 ± 0.27 | 9.22 | x4.29 | 0.57 h | 0.54 sec | 0.83 $

Table 2: Experiment results of CIFAR-10 on ResNet-18. Details are the same as in Table 1.
Accuracy (%) by training data size:
Data size | FULL | SNIP | BBD | STAMP
50K | 93.68 | 92.92 | 92.66 | 93.34
25K | 90.69 | 90.21 | 89.53 | 90.86
10K | 85.77 | 85.24 | 84.54 | 86.70
5K | 79.93 | 78.99 | 77.77 | 82.53
1K | 63.63 | 60.34 | 59.55 | 69.26
Figure 4: Left: Training time over the number of training instances. Middle: Accuracy over the number of training instances. Right: Accuracy over the number of instances. All experimental results are obtained on CIFAR-10 with VGG-19. We compare our STAMP against the full network (FULL), SNIP Lee et al. (2019c), and BBDropout Lee et al. (2019a) (BBD). The sparsity is matched to prune of the parameters for STAMP and SNIP.

Data size of the target tasks.

We further examine the accuracy and time efficiency of subnetworks obtained with different pruning methods across problem sizes. We previously observed that STAMP yields larger savings in training and inference time as the network gets larger (ResNet-18, Table 2). Another factor that defines the problem size is the number of instances in the unseen target dataset. We use subsets of CIFAR-10 to explore the effect of task size on training time and accuracy in Figure 4. The full dataset consists of 50K images, which corresponds to the results reported in Table 1. We observe that, as the number of training instances increases, STAMP obtains even larger savings in training time, while BBDropout takes increasingly longer to train. Further, as the number of training instances becomes smaller, STAMP obtains larger gains in accuracy, even outperforming the full network, since the network becomes relatively overparameterized as the amount of training data shrinks. As another comparison with a structural pruning method with learned masks, when using only 1K training instances, BBDropout finds a subnetwork attaining of parameters of the full network with FLOP speedup, while STAMP prunes out of the parameters, resulting in speedup in FLOPs. This is because BBDropout learns the pruning mask on the given target task, and thus overfits when the number of training instances is small. STAMP, on the other hand, does not overfit, since it mostly relies on meta-knowledge and takes only a few gradient steps on the given task.
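The trend can be read off the accuracy table in Figure 4 with a few subtractions: going from 50K to 1K training images, STAMP's margin over the full network grows from -0.34 to +5.63 points, and its margin over BBDropout from +0.68 to +9.71 points. The numbers below are copied from that table; only the differences are computed.

```python
# Accuracy (%) on CIFAR-10 subsets with VGG-19, copied from the Figure 4 table.
results = {
    "50K": {"FULL": 93.68, "SNIP": 92.92, "BBD": 92.66, "STAMP": 93.34},
    "25K": {"FULL": 90.69, "SNIP": 90.21, "BBD": 89.53, "STAMP": 90.86},
    "10K": {"FULL": 85.77, "SNIP": 85.24, "BBD": 84.54, "STAMP": 86.70},
    "5K": {"FULL": 79.93, "SNIP": 78.99, "BBD": 77.77, "STAMP": 82.53},
    "1K": {"FULL": 63.63, "SNIP": 60.34, "BBD": 59.55, "STAMP": 69.26},
}

sizes = list(results)
gap_full = [results[s]["STAMP"] - results[s]["FULL"] for s in sizes]
gap_bbd = [results[s]["STAMP"] - results[s]["BBD"] for s in sizes]
for s, gf, gb in zip(sizes, gap_full, gap_bbd):
    print(f"{s:>3s}: STAMP - FULL = {gf:+.2f}, STAMP - BBD = {gb:+.2f}")
```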

4.2 Qualitative Analysis

Pruned network structures.

(a) Layer 0-3 (b) Layer 4-15
Figure 5: Indices of the remaining channels at each convolution layer of VGGNet after pruning on two different datasets, CIFAR-10 and SVHN. Both start from the same meta-learned parameters.

We further show the effect of task-adaptive pruning: STAMP finds a different compressed subnetwork for each task. In other words, the pruning ratio and the pruned channels at each layer differ across datasets. We visualize the remaining channels of each convolution layer of VGGNet on CIFAR-10 and SVHN in Figure 5. Note that, unlike existing channel pruning methods such as MetaPruning, we do not enforce any constraints (hyperparameters) on how much to prune or which layers to prune, since these are determined automatically by STAMP.

5 Conclusion

We proposed a novel set-based task-adaptive structural pruning method which instantly generates a near-optimal compact network for a given task by performing rapid structural pruning of a global reference network trained on a large dataset. This is done by meta-learning a pruning mask generator, as a function of a dataset, over multiple subsets of the reference dataset, such that it can generate a pruning mask on the reference network for any unseen task. Our model, STAMP, obtains a compact network that not only achieves good performance with a large reduction in memory and computation cost at inference time, but also enables training-time speedups, which were not possible with previous methods. Further analysis showed that STAMP obtains larger performance gains when the target dataset is small, and prunes the channels of the same reference network differently for each dataset. We believe that both the proposal of a framework that can obtain an optimal compact network for unseen tasks and the achievement of training-time speedups are important contributions that enhance the efficiency and practicality of pruning methods.

6 Broader Impact

Every day, a tremendous amount of computing resources is spent on training deep neural networks and on searching for the optimal architecture for a given task, either manually or automatically with neural architecture search (NAS). Our method can significantly reduce the time and energy spent on both architecture search and training.

  • Significant reduction in architecture search cost. By instantly generating the optimal network architecture (a subnetwork of the reference network) for a given dataset, our method greatly reduces the architecture search time for each individual task.

  • Significant reduction in training cost. Our method substantially reduces the time required to train a network for a given task, since the obtained subnetwork yields actual savings in memory, computation, and wall-clock training time.

Such reductions in both architecture search and training time allow monetary savings and minimize energy consumption, making deep learning more affordable for service providers and end users who cannot afford the large cost of running models on the cloud or on GPU clusters.


  • [1] J. L. Ba, J. R. Kiros, and G. E. Hinton (2016) Layer normalization. arXiv preprint arXiv:1607.06450.
  • [2] B. Dai, C. Zhu, and D. Wipf (2018) Compressing neural networks using the variational information bottleneck.
  • [3] H. Edwards and A. Storkey (2016) Towards a neural statistician. arXiv preprint arXiv:1606.02185.
  • [4] C. Finn, P. Abbeel, and S. Levine (2017) Model-agnostic meta-learning for fast adaptation of deep networks. In Proceedings of the International Conference on Machine Learning (ICML).
  • [5] J. Frankle and M. Carbin (2019) The lottery ticket hypothesis: finding sparse, trainable neural networks. In Proceedings of the International Conference on Learning Representations (ICLR).
  • [6] J. Guo, W. Ouyang, and D. Xu (2020) Channel pruning guided by classification loss and feature importance. arXiv preprint arXiv:2003.06757.
  • [7] S. Han, H. Mao, and W. J. Dally (2016) Deep compression: compressing deep neural networks with pruning, trained quantization and Huffman coding. In Proceedings of the International Conference on Learning Representations (ICLR).
  • [8] S. Han, J. Pool, J. Tran, and W. Dally (2015) Learning both weights and connections for efficient neural network. In Advances in Neural Information Processing Systems (NIPS), pp. 1135–1143.
  • [9] K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 770–778.
  • [10] Y. He, X. Zhang, and J. Sun (2017) Channel pruning for accelerating very deep neural networks. In Proceedings of the IEEE International Conference on Computer Vision (ICCV), pp. 1389–1397.
  • [11] G. Hinton, O. Vinyals, and J. Dean (2014) Distilling the knowledge in a neural network. In NIPS Deep Learning Workshop.
  • [12] Z. Huang, X. Wang, L. Huang, C. Huang, Y. Wei, and W. Liu (2019) CCNet: criss-cross attention for semantic segmentation. In Proceedings of the International Conference on Computer Vision (ICCV).
  • [13] Z. Hui, X. Wang, and X. Gao (2018) Fast and accurate single image super-resolution via information distillation network. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR).
  • [14] S. Jung, C. Son, S. Lee, J. Son, J. Han, Y. Kwak, S. J. Hwang, and C. Choi (2019) Learning to quantize deep networks by optimizing quantization intervals with task loss. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR).
  • [15] A. Krizhevsky, I. Sutskever, and G. E. Hinton (2012) ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems (NIPS), pp. 1097–1105.
  • [16] J. Lee, S. Kim, J. Yoon, H. B. Lee, E. Yang, and S. J. Hwang (2019a) Adaptive network sparsification via dependent variational beta-Bernoulli dropout.
  • [17] J. Lee, Y. Lee, J. Kim, A. R. Kosiorek, S. Choi, and Y. W. Teh (2019b) Set transformer. In Proceedings of the International Conference on Machine Learning (ICML).
  • [18] N. Lee, T. Ajanthan, and P. H. Torr (2019c) SNIP: single-shot network pruning based on connection sensitivity. In Proceedings of the International Conference on Learning Representations (ICLR).
  • [19] G. Lin, A. Milan, C. Shen, and I. Reid (2017a) RefineNet: multi-path refinement networks for high-resolution semantic segmentation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR).
  • [20] T. Lin, P. Dollár, R. Girshick, K. He, B. Hariharan, and S. Belongie (2017b) Feature pyramid networks for object detection. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 2117–2125.
  • [21] H. Liu, K. Simonyan, and Y. Yang (2019a) DARTS: differentiable architecture search. In Proceedings of the International Conference on Learning Representations (ICLR).
  • [22] L. Liu, W. Ouyang, X. Wang, P. Fieguth, J. Chen, X. Liu, and M. Pietikäinen (2020) Deep learning for generic object detection: a survey. International Journal of Computer Vision 128 (2), pp. 261–318.
  • [23] Z. Liu, H. Mu, X. Zhang, Z. Guo, X. Yang, K. Cheng, and J. Sun (2019b) MetaPruning: meta learning for automatic neural network channel pruning. In Proceedings of the International Conference on Computer Vision (ICCV).
  • [24] Z. Liu, J. Li, Z. Shen, G. Huang, S. Yan, and C. Zhang (2017) Learning efficient convolutional networks through network slimming. In Proceedings of the International Conference on Computer Vision (ICCV).
  • [25] J. Luo, J. Wu, and W. Lin (2017) ThiNet: a filter level pruning method for deep neural network compression. In Proceedings of the IEEE International Conference on Computer Vision (ICCV), pp. 5058–5066.
  • [26] E. Nalisnick and P. Smyth (2017) Stick-breaking variational autoencoders. In Proceedings of the International Conference on Learning Representations (ICLR).
  • [27] A. Nichol and J. Schulman (2018) Reptile: a scalable metalearning algorithm. arXiv preprint arXiv:1803.02999.
  • [28] V. Ramanujan, M. Wortsman, A. Kembhavi, A. Farhadi, and M. Rastegari (2019) What's hidden in a randomly weighted neural network? arXiv preprint arXiv:1911.13299.
  • [29] W. Rawat and Z. Wang (2017) Deep convolutional neural networks for image classification: a comprehensive review. Neural Computation 29 (9), pp. 2352–2449.
  • [30] A. Shaw, W. Wei, W. Liu, L. Song, and B. Dai (2019) Meta architecture search. In Advances in Neural Information Processing Systems (NIPS).
  • [31] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin (2017) Attention is all you need. In Advances in Neural Information Processing Systems (NIPS).
  • [32] F. Wang, M. Jiang, C. Qian, S. Yang, C. Li, H. Zhang, X. Wang, and X. Tang (2017) Residual attention network for image classification. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR).
  • [33] W. Wen, C. Wu, Y. Wang, Y. Chen, and H. Li (2016) Learning structured sparsity in deep neural networks. In Advances in Neural Information Processing Systems (NIPS).
  • [34] S. Xie, H. Zheng, C. Liu, and L. Lin (2018) SNAS: stochastic neural architecture search. arXiv preprint arXiv:1812.09926.
  • [35] J. Yoon and S. J. Hwang (2017) Combined group and exclusive sparsity for deep neural networks. In Proceedings of the International Conference on Machine Learning (ICML).
  • [36] M. Zaheer, S. Kottur, S. Ravanbakhsh, B. Poczos, R. R. Salakhutdinov, and A. J. Smola (2017) Deep sets. In Advances in Neural Information Processing Systems (NIPS).
  • [37] Z. Zhuang, M. Tan, B. Zhuang, J. Liu, Y. Guo, Q. Wu, J. Huang, and J. Zhu (2018) Discrimination-aware channel pruning for deep neural networks. In Advances in Neural Information Processing Systems (NIPS), pp. 875–886.

Appendix A Appendix


The supplementary file is organized as follows: We first describe each component of our Set-based Task-Adaptive Meta-Pruning (STAMP) in detail, including the set encoding function and the set-based structural pruning method (mask generative function) in Section B. Then, in Section C, we provide the detailed experimental settings and additional results on SVHN using ResNet-18 as the backbone network.

Appendix B Structural Binary Mask Generation with a Set-encoded Representation

We now describe how we obtain the set representation with the set encoding function, and how we learn structural pruning masks with the set-based mask generative function introduced in Section 3.2.

B.1 Set Encoding Function

To obtain an optimally pruned neural architecture for the target task, we need to exploit knowledge of the given task. Conventional pruning schemes search for the desired subnetwork through full mini-batch training, which incurs excessive training cost when the dataset is large. In contrast, we rapidly obtain the desired pruned structure from an encoded representation of each dataset. The set representation is obtained with a parameterized set encoder as follows:


where the task set is sampled from the reference dataset, the sampled batch consists of instances of a given input dimensionality, and the set representation has its own batch dimension. We then define the set function as the composition of an encoder and a decoder, each an arbitrary function with its own learnable parameters. The encoder encodes the sampled task set, and the decoder maps the encoded vectors to a dataset-level representation. Throughout the paper, the set representation has a fixed dimension. We adopt a Set Transformer module [17] for set encoding, which is a learnable pooling network based on an attention mechanism, as shown below:


where rF is a fully connected layer, AE is an attention-based block [31], and PMA is Pooling by Multihead Attention with learnable seed vectors [17]. AE is a permutation-equivariant block built upon multi-head attention (MH), introduced in the Transformer [31]. AE blocks encode the set information in the encoder, and the decoder also includes an AE block to model correlations between vectors after pooling. AE is defined as follows:


where Norm is layer normalization [1]. The encoder encodes the given dataset using the above module, and the decoder aggregates the encoded vector. The full encoding-decoding process can be described as follows:


Here, pooling is performed by applying multi-head attention with a learnable seed vector. We use a single seed vector in the experiments to obtain a single set representation vector. By stacking these attention-based, permutation-equivariant functions, we obtain the set representation from the sampled task.
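To make the pooling step concrete, here is a minimal pure-Python sketch of attention pooling with a single seed vector, in the spirit of PMA from the Set Transformer [17]. It is not STAMP's encoder: it assumes a single head, identity query/key/value projections, and omits the rF layer and layer normalization. It only demonstrates the property the set encoder relies on, namely that the pooled vector does not depend on the order of the set elements.

```python
import math
import random


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def seed_attention_pool(seed, elements):
    """Pool a set of d-dimensional vectors into one vector using a seed query.

    Simplifications (assumptions, not STAMP's exact module): a single head,
    identity query/key/value projections, and no rF layer or layer norm.
    """
    d = len(seed)
    scores = [sum(q * k for q, k in zip(seed, x)) / math.sqrt(d) for x in elements]
    weights = softmax(scores)
    return [sum(w * x[j] for w, x in zip(weights, elements)) for j in range(d)]


rng = random.Random(0)
d = 4
seed_vector = [rng.gauss(0.0, 1.0) for _ in range(d)]  # learnable in a real model
task_set = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(6)]

pooled = seed_attention_pool(seed_vector, task_set)
shuffled = list(task_set)
rng.shuffle(shuffled)
pooled_shuffled = seed_attention_pool(seed_vector, shuffled)

# The pooled representation does not depend on the order of the set elements.
print(all(abs(a - b) < 1e-12 for a, b in zip(pooled, pooled_shuffled)))  # True
```

Because the output is a convex combination of the set elements weighted by a softmax, reordering the inputs only reorders the terms of each sum, which is why one seed vector yields a single, order-independent set representation.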

B.2 Mask Generative Function

We now describe the mask generative function at each layer, from which we obtain the pruned model parameters. As in Lee et al. [16], we use a sparsity-inducing beta-Bernoulli prior to generate a binary pruning mask at each layer: each mask entry follows a Bernoulli distribution whose probability is drawn from a parameterized beta distribution, as follows:


where the mask has one entry per channel of the layer. With a learnable parameter of the beta distribution, the model learns the optimal binary masks from values randomly sampled from the beta distribution, which determine which neurons/channels should be pruned. We extend this input-independent pruning method to sample binary masks based on the set representation of the target task. This set-dependent pruning in STAMP differs from the data-dependent BBDropout [16] in that the former generates one mask per dataset, whereas the latter generates one mask per instance, which makes it difficult to eliminate a channel globally. Furthermore, rather than searching for the compressed structure by mini-batch SGD training at each iteration, we use the set representation to rapidly obtain a near-optimal subnetwork within a few gradient steps. Given the set representation obtained from dataset X, we compute the activation o of each layer by applying the layer's function (e.g., convolution). We omit the layer index in the following equations for readability. We then sample a structural pruning mask vector m as follows:


where two learnable parameters act as scaling and shifting factors, and Pool denotes average pooling of o, which yields a representative value for each channel. The clamping function bounds its input using a small constant ε. With the clamping function, the network retains only the meaningful channels. We employ variational inference to approximate the sparsity-inducing posterior. The KL-divergence term for our set-based task-adaptive pruning is as follows:



where the variance of the shifting factor is fixed to a constant, to prevent the shifting factor from drifting away. The first term can be computed analytically in closed form [16, 26], and the second term, which is part of the objective function of STAMP (Equation 2), can be easily optimized with gradient-based methods.
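Because the mask equations are not preserved in this text, the following is only a rough sketch of set-dependent channel-mask sampling under stated assumptions. It assumes the finite beta-Bernoulli prior of BBDropout [16], π_k ~ Beta(α/K, 1) for K channels, and assumes, in the spirit of the dependent variant of [16], that channel k is kept with probability π_k times the clamped, scaled, and shifted average-pooled activation. The clamp form min(1 - ε, max(ε, x)), the absence of any normalization, and all function names are hypothetical.

```python
import random


def clamp(x, eps=0.01):
    """Assumed clamp: keep a probability inside [eps, 1 - eps]."""
    return min(1.0 - eps, max(eps, x))


def average_pool(activation):
    """activation: one list of spatial values per channel -> one value per channel."""
    return [sum(channel) / len(channel) for channel in activation]


def sample_channel_mask(activation, alpha, gamma, beta, rng, eps=0.01):
    """Sample one binary keep/prune decision per channel (hypothetical sketch).

    pi_k ~ Beta(alpha / K, 1) follows the finite beta-Bernoulli prior of BBDropout;
    the set-dependent keep probability pi_k * clamp(gamma_k * pool(o)_k + beta_k)
    is an assumption about how the set representation enters.
    """
    pooled = average_pool(activation)
    num_channels = len(pooled)
    mask = []
    for k in range(num_channels):
        pi_k = rng.betavariate(alpha / num_channels, 1.0)
        keep_prob = pi_k * clamp(gamma[k] * pooled[k] + beta[k], eps)
        mask.append(1 if rng.random() < keep_prob else 0)
    return mask


# Hypothetical activation o of a 4-channel layer (3 spatial positions each),
# computed from the set representation of one dataset.
o = [[0.9, 1.1, 1.0], [-2.0, -1.5, -2.5], [0.2, 0.4, 0.3], [3.0, 2.5, 3.5]]
K = len(o)
rng = random.Random(0)
mask = sample_channel_mask(o, alpha=4.0 * K, gamma=[1.0] * K, beta=[0.0] * K, rng=rng)
print(mask)  # one 0/1 entry per channel; channel 1 is almost always pruned
```

Since the activation comes from the set representation rather than from individual instances, a single mask is produced per dataset, which is what allows a channel to be removed globally.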

We can further approximate the expected prediction for a given dataset as follows:


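The exact approximation used by STAMP is not recoverable from this text. Two generic options for stochastic masks are Monte Carlo averaging of predictions over sampled masks and a single forward pass with the expected mask plugged in. The toy sketch below contrasts them on a hypothetical masked linear model; for such a linear model both agree in expectation, whereas for nonlinear networks they generally differ. The model, keep probabilities, and function names are illustrative assumptions.

```python
import random


def masked_linear(x, weights, mask):
    """Toy stand-in for a network: a masked weighted sum of input features."""
    return sum(w * m * xi for w, m, xi in zip(weights, mask, x))


def mc_expected_prediction(x, weights, keep_probs, num_samples, rng):
    """Monte Carlo estimate of E_m[f(x; m)] by averaging over sampled binary masks."""
    total = 0.0
    for _ in range(num_samples):
        mask = [1 if rng.random() < p else 0 for p in keep_probs]
        total += masked_linear(x, weights, mask)
    return total / num_samples


def expected_mask_prediction(x, weights, keep_probs):
    """Single forward pass with the expected mask E[m] plugged in."""
    return masked_linear(x, weights, keep_probs)


x = [1.0, 2.0, 3.0]
weights = [0.5, -1.0, 2.0]
keep_probs = [0.9, 0.1, 0.5]  # hypothetical per-channel keep probabilities
print(expected_mask_prediction(x, weights, keep_probs))  # 3.25
print(mc_expected_prediction(x, weights, keep_probs, 20000, random.Random(0)))  # close to 3.25
```

The plug-in variant costs one forward pass, while the Monte Carlo variant costs one pass per sample but remains unbiased for nonlinear models.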
Appendix C Experiments

C.1 Experimental Settings

We first describe how we meta-train STAMP and configure the baselines, SNIP [18] and MetaPruning [23], for the experiments in the main paper (VGGNet and ResNet-18 on the CIFAR-10 and SVHN benchmarks).

For STAMP, in the function STAMP of Algorithm 1, we update the corresponding parameters with learning rates of 0.001, 0.01, and 0.001, respectively, using the Adam optimizer, and decay the learning rate by a factor of 0.1 at 50% and 80% of the total epochs, following the settings of BBDropout [16]. The hyperparameters of Algorithm 1 are selected separately for VGGNet and ResNet-18. We sample the same number of instances per class, and we also fix the minibatch size. When pruning with STAMP, we use the same learning rates as in the meta-training stage for VGGNet, whereas for ResNet-18 we adjust the learning rates to control the pruning rate.
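The step schedule described above (multiplying the learning rate by 0.1 at 50% and 80% of the total epochs) can be written as a small helper. The function name, its signature, and the 200-epoch example are illustrative assumptions, not taken from the STAMP code.

```python
def step_decay_lr(base_lr, epoch, total_epochs, milestones=(0.5, 0.8), factor=0.1):
    """Learning rate at `epoch`: multiplied by `factor` at each fractional milestone."""
    lr = base_lr
    for fraction in milestones:
        if epoch >= int(fraction * total_epochs):
            lr *= factor
    return lr


# The three base learning rates quoted above, e.g. over a hypothetical 200 epochs.
for base_lr in (0.001, 0.01, 0.001):
    print([step_decay_lr(base_lr, e, 200) for e in (0, 99, 100, 159, 160)])
```

In PyTorch, the same schedule corresponds to torch.optim.lr_scheduler.MultiStepLR with milestones=[int(0.5 * total_epochs), int(0.8 * total_epochs)] and gamma=0.1, attached to each Adam optimizer.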

For SNIP [18], in the ResNet-18 experiment, we do not prune the convolution layer, to match the settings of the STAMP experiments. Additionally, we change the learning rate to 0.01, since with the original learning rate SNIP (P) and SNIP obtained lower accuracies (88.51% and 85.26%, respectively). For VGGNet, we prune the weights of 16 convolution layers. For SNIP (P), we load weights pretrained on CIFAR-100 before pruning.

For MetaPruning [23], we use the same settings for the ResNet-18 and ResNet-50 experiments. For VGGNet, we prune the filters of 16 convolution layers, as in STAMP. In the search phase, we search for the architecture under given FLOP constraints. We set the pruning ratio of each layer between 20% and 60% to satisfy the FLOP constraints, which range from 40% to 80% in this setting. For the remaining experimental settings, we follow [23].
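A back-of-the-envelope check shows why per-layer pruning ratios of 20% to 60% map to a much wider FLOP range: pruning a fraction r of both the input and output channels of a convolution scales its cost by roughly (1 - r)^2, so 20% to 60% removes roughly 36% to 84% of the layer's FLOPs, which brackets the 40% to 80% range if that range is read as a FLOP reduction. The sketch assumes stride-1 convolutions, uniform pruning of consecutive layers, and a hypothetical 64-channel, 32x32 layer; it ignores that a network's first-layer input channels cannot be pruned.

```python
def conv_macs(c_in, c_out, kernel, height, width):
    """Multiply-accumulates of a stride-1, 'same'-padded convolution (bias ignored)."""
    return c_in * c_out * kernel * kernel * height * width


def flop_reduction(prune_ratio, channels=64, kernel=3, height=32, width=32):
    """FLOP reduction of one conv layer when the same fraction of its input and
    output channels is pruned (assumes the previous layer is pruned equally)."""
    kept = round(channels * (1 - prune_ratio))
    full = conv_macs(channels, channels, kernel, height, width)
    pruned = conv_macs(kept, kept, kernel, height, width)
    return 1 - pruned / full


for ratio in (0.2, 0.4, 0.6):
    # analytic value 1 - (1 - r)^2 vs. the integer channel-count version
    print(ratio, round(1 - (1 - ratio) ** 2, 3), round(flop_reduction(ratio), 3))
```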

C.2 Experimental Results

We report the experimental results on SVHN with ResNet-18 in Table 3; they were omitted from the main paper due to the page limit. We follow the settings of Liu et al. [24] and train on SVHN for 20 epochs. All other settings are the same as in Section C.1. The results show that STAMP achieves the best trade-off between accuracy and efficiency.

Methods             Accuracy (%)    P (%)    FLOPs    Training Time    Inference Time    Expense

Full Network        94.57 ± 0.01    100      x1.00    0.16 h           3.30 sec          $0.24

Edge-Popup [28]     92.61 ± 0.01    5.00     x1.00    0.20 h           6.15 sec          $0.29
SNIP (P) [18]       95.38 ± 0.01    6.06     x1.00    0.35 h           6.64 sec          $0.51
SNIP [18]           94.88 ± 0.01    6.06     x1.00    0.35 h           6.64 sec          $0.51

Random Pruned       94.39 ± 0.23    72.17    x2.99    0.08 h           1.66 sec          $0.12
MetaPruning [23]    94.49 ± 0.19    70.99    x2.83    2.41 h           1.68 sec          $3.51
BBDropout [16]      94.32 ± 0.02    4.90     x5.25    0.31 h           1.52 sec          $0.46

STAMP-Structure     95.17 ± 0.01    4.81     x5.47    0.11 h           1.51 sec          $0.16
STAMP               95.41 ± 0.01    4.81     x5.47    0.11 h           1.51 sec          $0.16

Table 3: Experimental results on SVHN with ResNet-18. Training Time consists of the time to search for the pruned network and the time to finetune it (200 epochs). Expense is computed by multiplying the training time by $1.46 per hour, the cost of a Tesla P100 GPU on Google Cloud. FLOPs gives the FLOP reduction factor relative to the full network, and P is the ratio of remaining parameters. The methods are grouped into the full network without pruning, unstructured pruning methods, structured pruning methods, and STAMP (STAMP-Structure is a variant of STAMP that re-initializes the pruned architecture). We run each experiment 3 times and report mean ± std values.
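As a quick check of the Expense column, the snippet below multiplies each training time in Table 3 by the quoted $1.46 per GPU-hour. Most products match the table after rounding to cents; a few differ by one cent (0.16 h gives $0.23 versus $0.24 reported, 0.31 h gives $0.45 versus $0.46, and 2.41 h gives $3.52 versus $3.51), which is consistent with the reported training times themselves being rounded. The dictionary only transcribes the table.

```python
GPU_USD_PER_HOUR = 1.46  # Tesla P100 on Google Cloud, as stated in the caption

# Training times (hours) transcribed from Table 3.
training_hours = {
    "Full Network": 0.16,
    "Edge-Popup": 0.20,
    "SNIP (P) / SNIP": 0.35,
    "Random Pruned": 0.08,
    "MetaPruning": 2.41,
    "BBDropout": 0.31,
    "STAMP": 0.11,
}


def expense(hours, rate=GPU_USD_PER_HOUR):
    """Dollar cost of a run: wall-clock training hours times the hourly GPU price."""
    return hours * rate


for method, hours in training_hours.items():
    print(f"{method:16s} {hours:5.2f} h  ${expense(hours):.2f}")
```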
Methods                               Accuracy (%)    P (%)    FLOPs    Training Time
Full Network                          94.57           100      x1.00    0.16 h
BBD (reg. scale = 15)                 94.70           11.09    x2.86    0.31 h
BBD (reg. scale = 20)                 94.30           4.86     x5.14    0.31 h
BBD (reg. scale = 25)                 94.12           2.71     x8.62    0.31 h
STAMP (1 epoch, reg. scale = 15)      95.44           7.83     x4.26    0.11 h
STAMP (5 epochs)                      95.73           3.84     x6.39    0.19 h
STAMP (10 epochs)                     95.77           1.81     x8.13    0.31 h
Figure 6: Left: Accuracy versus the ratio of used parameters for SVHN on ResNet-18. Full denotes the accuracy of ResNet-18 before pruning. Right: STAMP with different numbers of pruning-stage epochs, compared with BBDropout [16] (BBD) at different scale factors of the regularization term.

STAMP obtains higher accuracy than BBDropout at the same compression rate, as shown in Figure 6 (left). Further, when trained for more pruning epochs, STAMP obtains even higher accuracy and a larger compression rate than BBD, as shown in Figure 6 (right), outperforming all baselines in Table 3. Although training STAMP for more epochs makes its training time longer than that of the full network (0.19 h and 0.31 h for 5 and 10 epochs, versus 0.16 h), it never exceeds the training time of BBDropout (0.31 h).