Streamlining Tensor and Network Pruning in PyTorch

by Michela Paganini, et al.
Brown University

To counter the explosion in size of state-of-the-art machine learning models, driven by the empirical advantages of over-parametrization, and to meet the need for fast, sustainable, and private models on resource-constrained devices, the community has focused on techniques such as pruning, quantization, and distillation as central strategies for model compression. Towards the goal of facilitating the adoption of a common interface for neural network pruning in PyTorch, this contribution describes the recent addition of the PyTorch torch.nn.utils.prune module, which provides shared, open-source pruning functionalities to lower the technical implementation barrier to reducing model size and capacity before, during, and/or after training. We present the module's user interface, elucidate implementation details, illustrate example usage, and suggest ways to extend the contributed functionalities to new pruning methods.



1 Introduction

State-of-the-art deep learning techniques rely on over-parametrized models that are hard to deploy. In contrast, biological neural networks are known to use efficient sparse connectivity (Sporns, 2007; Ramón y Cajal, 1909). Identifying optimal techniques to compress models by reducing the number of parameters is important in order to: reduce memory, battery, and hardware consumption without sacrificing accuracy; deploy lightweight models on device in mobile, IoT, and AR/VR systems, with an eye towards download bandwidth, data consumption, and heat dissipation; and guarantee privacy with private on-device computation (Nazemi et al., 2018; Yang et al., 2017; Iandola and others, 2016; Han et al., 2015; Reagen et al., 2016; Kim and others, 2015; Verhelst and Moons, 2017). Real-time applications that require reduced latency, meteorological and personalization models that require frequent retraining to capture the latest trends, and applications targeting deployment on custom ASICs may also benefit from model sparsification for training and inference speed. Furthermore, the growth in model size has made reproducing and building upon state-of-the-art techniques accessible only to a few, with severe inequalities at a geographical and socioeconomic level. Environmental concerns around the cost and carbon footprint of training large-scale models have been documented in Strubell et al. (2019); Schwartz and others (2019); Henderson and others (2020).

Pruning provides ways to remove unnecessary structure in neural networks, thus beginning to address some of the concerns above. Different ways of identifying superfluous portions of a model result in different pruning techniques. These may vary along several axes, including: the nature of the entities to prune (connections, nodes, channels, layers, etc.), the choice of proxy for importance of each entity (weight, activation, gradient, etc.), when to compute the chosen quantity, the group of entities to pool for comparison (all units in the same layer, the whole network, etc.), when to prune (during, before, or after training), whether pruned entities are forever pruned or can be reinstated, whether to apply hard (binary masks) or soft pruning, iterative or one-shot pruning, and what to do with the network once it is pruned (finetuning, reinitializing, rewinding, etc.).

With the addition of the torch.nn.utils.prune module, PyTorch (Paszke and others, 2019) users may now scan over various choices of pruning techniques as easily as over any other choice of hyper-parameter or building block in their machine learning pipeline. At the same time, the module aims to empower researchers to contribute new pruning techniques and to express them through a common language.

The goal of this contribution is knowledge dissemination, among the relevant community, of available tools that can simplify and power both research and deployment in resource-constrained scenarios.

2 Implementation

BasePruningMethod is an abstract base class that provides a skeleton for all pruning techniques and implements shared functionalities. It enables the creation of new pruning techniques by requiring the overriding of methods such as compute_mask. All pruning methods in Sec. 2.1 are derived classes that inherit from it.

The core logic for applying pruning to a tensor within a module is contained in the class method apply. Specifically, pruning acts by removing the specified parameter from the parameter list and replacing it with a new parameter whose name equals the original one with the string "_orig" appended to it; this stores the unpruned version of the tensor. The pruning mask generated by the pruning technique is saved as a module buffer whose name equals the original parameter's name with the string "_mask" appended to it. Once the reparametrization is in place, the apply_mask method creates an attribute with the original tensor's name, needed by the forward method, as the product of the original tensor and the mask. Finally, the pruning method is attached to the module via a forward_pre_hook to ensure that the multiplication op (computed on the fly each time an instance of the pruning class is called via __call__) is registered in the forward and backward graphs whenever the module is used to compute an output given an input (Fig. 1). The function torch.nn.utils.prune.is_pruned reports whether any pruning hook is associated with a module. All relevant tensors, including the mask buffers and the original parameters used to compute the pruned tensors, are stored in the model's state_dict and can therefore be easily serialized and saved.
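As a brief sketch of the mechanics described above (not a listing from the paper), the reparametrization can be inspected directly on a small module: after pruning, the original tensor lives in the parameter "weight_orig", the mask in the buffer "weight_mask", and the attribute "weight" is their product.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

m = nn.Linear(4, 3)
prune.random_unstructured(m, name="weight", amount=0.5)

# The unpruned tensor survives as the parameter "weight_orig",
# and the binary mask is registered as the buffer "weight_mask"
print(sorted(name for name, _ in m.named_parameters()))  # ['bias', 'weight_orig']
print(sorted(name for name, _ in m.named_buffers()))     # ['weight_mask']

# "weight" is now a plain attribute, recomputed as weight_orig * weight_mask
assert torch.equal(m.weight, m.weight_orig * m.weight_mask)
assert prune.is_pruned(m)
```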

If the reparametrization or hook creation fails, the related exception is raised and the module is rolled back to its state prior to the failed pruning attempt, leaving its usability uncompromised.

Figure 1: Reparametrization of a tensor in terms of its unpruned version and the computed mask, and its usage in the forward pass.

To remove the reparametrization depicted in Fig. 1 and make the pruning permanent, the user can call the function torch.nn.utils.prune.remove, which removes buffers, hooks, and additional attributes, and assigns the pruned tensor to the parameter with the original parameter name. Note that pruning itself is not undone or reversed by this operation.
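A self-contained sketch of this removal step (assuming a freshly constructed linear layer for illustration): the reparametrization is undone, but the zeroed entries remain zero.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

m = nn.Linear(4, 3)
prune.l1_unstructured(m, name="weight", amount=2)  # zero out the 2 smallest-magnitude entries
pruned_weight = m.weight.detach().clone()

# Make the pruning permanent: drops "weight_orig", "weight_mask", and the hook,
# and reassigns the pruned tensor to the parameter "weight"
prune.remove(m, "weight")

assert not prune.is_pruned(m)
assert "weight_mask" not in dict(m.named_buffers())
# The pruning itself is not undone: the zeros stay in place
assert torch.equal(m.weight, pruned_weight)
```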

A PruningContainer is used to store the history of pruning calls executed on a module in order to enable iterative pruning, i.e. the sequential application of pruning techniques. Each parameter in a module that is pruned more than once has an associated PruningContainer; that container has an instance attribute called _tensor_name that identifies which parameter in the module it relates to. Only pruning methods that act on the same tensor can be added to the container. The tuple attribute _pruning_methods stores the instances of pruning techniques in the order they are applied.

Each pruning method has a PRUNING_TYPE that dictates how iterative pruning masks are combined. At the moment, the following types are supported: unstructured, structured, and global. An unstructured pruning technique disregards individual entries in a tensor that have already been pruned; a structured pruning technique disregards a row or column only if all of its entries have already been pruned; a global pruning method, such as torch.nn.utils.prune.CustomFromMask in Sec. 2.1, applies the pruning technique to all entries, regardless of whether they have been pruned before. These code paths are defined in the _combine_masks inner utility function within the PruningContainer's compute_mask method. For most other intents and purposes, a container behaves like any other derived class of BasePruningMethod.
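As an illustrative sketch (not a listing from the paper), pruning the same parameter twice yields a PruningContainer that records both methods in application order, retrievable from the module's forward pre-hooks:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

m = nn.Linear(5, 5)
prune.l1_unstructured(m, name="weight", amount=0.3)      # first pruning pass
prune.random_unstructured(m, name="weight", amount=0.2)  # second pass on the same tensor

# The two calls are now tracked by a single PruningContainer attached as a hook
for hook in m._forward_pre_hooks.values():
    if isinstance(hook, prune.PruningContainer) and hook._tensor_name == "weight":
        print(len(hook._pruning_methods))  # 2: the methods applied so far, in order
```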

The torch.nn.utils.prune module also provides a simple and clean functional interface that allows users to interact with pruning techniques through intuitive function calls on modules’ parameters identified by name. See Sec. 3 for example usage.

2.1 Available Pruning Methods

The following child classes inherit from the BasePruningMethod:

  • torch.nn.utils.prune.Identity: utility pruning method that does not prune any units but generates the pruning parametrization with a mask of ones;

  • torch.nn.utils.prune.RandomUnstructured: prune (currently unpruned) entries in a tensor at random;

  • torch.nn.utils.prune.L1Unstructured: prune (currently unpruned) entries in a tensor by zeroing out the ones with the lowest absolute magnitude;

  • torch.nn.utils.prune.RandomStructured: prune entire (currently unpruned) rows or columns in a tensor at random;

  • torch.nn.utils.prune.LnStructured: prune entire (currently unpruned) rows or columns in a tensor based on their Ln-norm (supported values of n correspond to the supported values for the p argument of torch.norm());

  • torch.nn.utils.prune.CustomFromMask: prune a tensor using a user-provided mask.

Their functional equivalents are:

  • torch.nn.utils.prune.identity

  • torch.nn.utils.prune.random_unstructured

  • torch.nn.utils.prune.l1_unstructured

  • torch.nn.utils.prune.random_structured

  • torch.nn.utils.prune.ln_structured

  • torch.nn.utils.prune.custom_from_mask

Global pruning, in which entries are compared across multiple tensors, is enabled through torch.nn.utils.prune.global_unstructured (see Sec. 3 for example usage).

2.2 Extending the Module

The torch.nn.utils.prune module can be extended to implement custom pruning functions. This requires sub-classing the BasePruningMethod base class and implementing the __init__ and compute_mask methods, i.e., respectively, the constructor and the instructions to compute the mask for the given tensor according to the logic of the pruning technique. The nature of the pruning technique is specified through the assignment of a PRUNING_TYPE. If none of the currently supported types fits the new pruning technique, the user will also have to extend the way PruningContainer handles the iterative application of pruning masks to support the new PRUNING_TYPE.

For example, Listing 1 demonstrates how to implement a pruning technique that prunes every other currently unpruned entry in a tensor, and how to provide a convenient functional interface for the method in just a few lines of code.

import torch.nn.utils.prune as prune

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor."""
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

def foobar_unstructured(module, name):
    """Prunes the tensor corresponding to the parameter called `name` in `module`
    by removing every other entry in the tensor.
    Modifies the module in place (and also returns the modified module) by
    adding a named buffer called `name + '_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name + '_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within `module` on which pruning will act

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module
Listing 1: Simple extension to the torch.nn.utils.prune module that implements a custom pruning function to prune every other unpruned entry in a parameter.

3 Example Usage

Pruning parameters in a model is as simple as invoking the desired pruning function on each parameter (identified by name) within a given neural network module. In this example, the first convolutional layer in a VGG-11 architecture (Simonyan and Zisserman, 2014) is first pruned by removing 3 individual entries using unstructured magnitude-based pruning, then pruned again by removing the bottom 50% of the remaining channels along the 0th axis by L2-norm.

import torch
from torch import nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
from torchvision.models import vgg11

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = vgg11().to(device=device)

# Prune the first convolutional layer
prune.l1_unstructured(model.features[0], name="weight", amount=3)

# Iteratively prune by simply calling another pruning function on the same parameter
prune.ln_structured(model.features[0], name="weight", amount=0.5, n=2, dim=0)
Listing 2: Simple iterative pruning of a single parameter in a network.

This can be easily extended to apply pruning to all layers in a network. For instance, Listing 3 shows how to prune all 2D convolutional and linear layers in a network, with different pruning fractions that depend on the layer type.

new_model = vgg11().to(device=device)
for name, module in new_model.named_modules():
    # Prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # Prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)
Listing 3: Automated pruning of all weights in a network belonging to specific layer types.

In the examples above, each candidate entity for pruning is compared in magnitude to other candidate entities within a single layer. Listing 4, instead, provides an example of how to enable the pooling together of all entities (i.e. single connections, entire units, or channels) across a network for a global magnitude comparison.

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
Listing 4: Example of how to prune the bottom 20% of connections by absolute magnitude across an entire LeNet (LeCun and others, 1998) architecture.

The module also supports the pruning of individual tensors by calling the prune method of any pruning class, as shown in Listing 5.

t = torch.rand(2, 5)
p = prune.RandomUnstructured(amount=0.7)
pruned_tensor = p.prune(t)
Listing 5: Code to prune 70% of entries at random in a tensor not associated with any torch.nn.Module.
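The resulting sparsity can be verified directly. As a sketch (assuming, as in Listing 5, a 10-element tensor with no pre-existing zeros), amount=0.7 zeroes out round(0.7 * 10) = 7 entries:

```python
import torch
import torch.nn.utils.prune as prune

t = torch.rand(2, 5)
p = prune.RandomUnstructured(amount=0.7)
pruned_tensor = p.prune(t)

# Count the zeroed entries: 7 of the 10, assuming the random input had none
num_zeros = int((pruned_tensor == 0).sum())
print(num_zeros)
```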


  • S. Han, H. Mao, and W. J. Dally (2015) Deep compression: compressing deep neural networks with pruning, trained quantization and huffman coding. External Links: 1510.00149 Cited by: §1.
  • P. Henderson et al. (2020) Towards the systematic reporting of the energy and carbon footprints of machine learning. Cited by: §1.
  • F. N. Iandola et al. (2016) SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size. External Links: 1602.07360 Cited by: §1.
  • Y.-D. Kim et al. (2015) Compression of deep convolutional neural networks for fast and low power mobile applications. External Links: 1511.06530 Cited by: §1.
  • Y. LeCun et al. (1998) Gradient-based learning applied to document recognition. Proceedings of the IEEE 86 (11), pp. 2278–2324. Cited by: Listing 4.
  • M. Nazemi, G. Pasandi, and M. Pedram (2018) NullaNet: training deep neural networks for reduced-memory-access inference. CoRR abs/1807.08716. External Links: Link, 1807.08716 Cited by: §1.
  • A. Paszke et al. (2019) PyTorch: an imperative style, high-performance deep learning library. In Advances in Neural Information Processing Systems 32, H. Wallach, H. Larochelle, A. Beygelzimer, F. d’Alché-Buc, E. Fox, and R. Garnett (Eds.), pp. 8024–8035. External Links: Link Cited by: §1.
  • S. Ramón y Cajal (1909) Histologie du système nerveux de l'homme et des vertébrés. Vol. 1, Paris: Maloine. External Links: Link Cited by: §1.
  • B. Reagen, P. Whatmough, R. Adolf, S. Rama, H. Lee, S. K. Lee, J. M. Hernández-Lobato, G. Wei, and D. Brooks (2016) Minerva: enabling low-power, highly-accurate deep neural network accelerators. In 2016 ACM/IEEE 43rd Annual International Symposium on Computer Architecture (ISCA), Vol. , pp. 267–278. External Links: Document, ISSN 1063-6897 Cited by: §1.
  • R. Schwartz et al. (2019) Green AI. External Links: 1907.10597 Cited by: §1.
  • K. Simonyan and A. Zisserman (2014) Very deep convolutional networks for large-scale image recognition. External Links: 1409.1556 Cited by: §3.
  • O. Sporns (2007) Brain connectivity. Scholarpedia 2 (10), pp. 4695. Note: revision #91084 External Links: Document Cited by: §1.
  • E. Strubell, A. Ganesh, and A. McCallum (2019) Energy and policy considerations for deep learning in NLP. External Links: 1906.02243 Cited by: §1.
  • M. Verhelst and B. Moons (2017) Embedded deep neural network processing: algorithmic and processor techniques bring deep learning to iot and edge devices. IEEE Solid-State Circuits Magazine 9 (4), pp. 55–65. External Links: Document, ISSN 1943-0590 Cited by: §1.
  • T.-J. Yang, Y.-H. Chen, and V. Sze (2017) Designing energy-efficient convolutional neural networks using energy-aware pruning. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR). Cited by: §1.