Pruning provides ways to remove unnecessary structure in neural networks, thus beginning to address some of the concerns above. Different ways of identifying superfluous portions of a model result in different pruning techniques. These may vary along several axes, including: the nature of the entities to prune (connections, nodes, channels, layers, etc.), the choice of proxy for importance of each entity (weight, activation, gradient, etc.), when to compute the chosen quantity, the group of entities to pool for comparison (all units in the same layer, the whole network, etc.), when to prune (during, before, or after training), whether pruned entities are forever pruned or can be reinstated, whether to apply hard (binary masks) or soft pruning, iterative or one-shot pruning, and what to do with the network once it is pruned (finetuning, reinitializing, rewinding, etc.).
With the addition of the torch.nn.utils.prune module (available at https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/prune.py), PyTorch (Paszke et al., 2019) users can now scan over various pruning techniques as easily as over any other hyper-parameter or building block in their machine learning pipeline. At the same time, the module aims to empower researchers to contribute new pruning techniques and to express them in a common language.
The goal of this contribution is knowledge dissemination, among the relevant community, of available tools that can simplify and power both research and deployment in resource-constrained scenarios.
BasePruningMethod is an abstract base class that provides a skeleton for all pruning techniques and implements shared functionalities. It enables the creation of new pruning techniques by requiring the overriding of methods such as compute_mask. All pruning methods in Sec. 2.1 are derived classes that inherit from it.
The core logic for the application of pruning to a tensor within a module is contained in the class method apply. Specifically, pruning acts by removing the specified parameter from the parameters list and replacing it with a new parameter whose name equals the original one with the string "_orig" appended to it; this stores the unpruned version of the tensor. The pruning mask generated by the pruning technique is saved as a module buffer whose name equals the original parameter's name with the string "_mask" appended to it. Once the reparametrization is in place, an attribute with the original tensor's name, needed by the forward method, is created by the apply_mask method as the product of the original tensor and the mask. Finally, the pruning method is attached to the module via a forward_pre_hook so that the multiplication op (computed on the fly each time an instance of the pruning class is called via __call__) is registered in the forward and backward graphs whenever the module is used to compute an output from an input (Fig. 1). The function torch.nn.utils.prune.is_pruned returns whether any pruning hook is associated with a module. All relevant tensors, including the mask buffers and the original parameters used to compute the pruned tensors, are stored in the model's state_dict and can therefore easily be serialized and saved.
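The reparametrization described above can be inspected directly. The following minimal sketch prunes the weight of a small linear layer and verifies that "weight" has been replaced by the "weight_orig" parameter, the "weight_mask" buffer, and their product:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# A small linear layer to prune
module = nn.Linear(4, 2)
assert not prune.is_pruned(module)  # no pruning hooks attached yet

# Apply random unstructured pruning to 50% of the weight entries
prune.random_unstructured(module, name="weight", amount=0.5)

# The original tensor now lives in "weight_orig" (a parameter) and the
# binary mask in "weight_mask" (a buffer); "weight" itself is a plain
# attribute recomputed as weight_orig * weight_mask on each forward call
print(sorted(name for name, _ in module.named_parameters()))  # ['bias', 'weight_orig']
print(sorted(name for name, _ in module.named_buffers()))     # ['weight_mask']
assert prune.is_pruned(module)
assert torch.equal(module.weight, module.weight_orig * module.weight_mask)
```

Both weight_orig and weight_mask appear in module.state_dict(), so a pruned model serializes without any extra machinery.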
If the reparametrization or hook creation fails, the related exception is raised and the module is rolled back to its state prior to the failed pruning attempt, without compromising its usability.
To remove the reparametrization depicted in Fig. 1 and make the pruning permanent, the user can call the function torch.nn.utils.prune.remove, which removes buffers, hooks, and additional attributes, and assigns the pruned tensor to the parameter with the original parameter name. Note that pruning itself is not undone or reversed by this operation.
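A short sketch of making pruning permanent: after torch.nn.utils.prune.remove, the module carries a plain "weight" parameter again, but the zeroed entries stay zero.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

module = nn.Linear(3, 3)
prune.l1_unstructured(module, name="weight", amount=0.4)
pruned_weight = module.weight.detach().clone()

# Make the pruning permanent: drops weight_orig, weight_mask, and the
# forward_pre_hook, and reassigns the pruned tensor to "weight"
prune.remove(module, name="weight")

assert not prune.is_pruned(module)
# The zeroed entries remain zero: remove() does not undo pruning
assert torch.equal(module.weight, pruned_weight)
```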
A PruningContainer is used to store the history of pruning calls executed on a module in order to enable iterative pruning, i.e. the sequential application of pruning techniques. Each parameter in a module that is pruned more than once has an associated PruningContainer; that container has an instance attribute called _tensor_name that identifies which parameter in the module it relates to. Only pruning methods that act on the same tensor can be added to the container. The tuple attribute _pruning_methods stores the instances of pruning techniques in the order they are applied.
Each pruning method has a PRUNING_TYPE that dictates how iterative pruning masks are combined. At the moment, the following types are supported: unstructured, structured, and global. An unstructured pruning technique disregards individual entries in a tensor that have already been pruned; a structured pruning technique disregards a row or column only if all of its entries have already been pruned; a global pruning method, such as torch.nn.utils.prune.CustomFromMask in Sec. 2.1, applies the pruning technique to all entries, regardless of whether they have been pruned before. These code paths are defined in the _combine_masks inner utility function within the PruningContainer's compute_mask method. For most other intents and purposes, a container behaves like any other derived class of BasePruningMethod.
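As a sketch of iterative pruning in practice, applying two different techniques to the same parameter transparently creates a PruningContainer that records both (this inspects the private _forward_pre_hooks and _pruning_methods attributes, so it is illustrative rather than part of the public API):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

module = nn.Linear(8, 8)

# First round: unstructured magnitude pruning
prune.l1_unstructured(module, name="weight", amount=0.3)
# Second round on the same tensor: structured pruning along dim 0
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# Both steps are recorded, in application order, in a PruningContainer
for hook in module._forward_pre_hooks.values():
    if isinstance(hook, prune.PruningContainer):
        print(hook._tensor_name)  # 'weight'
        print([type(m).__name__ for m in hook._pruning_methods])
        # ['L1Unstructured', 'LnStructured']
```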
The torch.nn.utils.prune module also provides a simple and clean functional interface that allows users to interact with pruning techniques through intuitive function calls on modules’ parameters identified by name. See Sec. 3 for example usage.
2.1 Available Pruning Methods
The following child classes inherit from BasePruningMethod:
torch.nn.utils.prune.Identity: utility pruning method that does not prune any units but generates the pruning parametrization with a mask of ones;
torch.nn.utils.prune.RandomUnstructured: prune (currently unpruned) entries in a tensor at random;
torch.nn.utils.prune.L1Unstructured: prune (currently unpruned) entries in a tensor by zeroing out the ones with the lowest absolute magnitude;
torch.nn.utils.prune.RandomStructured: prune entire (currently unpruned) rows or columns in a tensor at random;
torch.nn.utils.prune.LnStructured: prune entire (currently unpruned) rows or columns in a tensor based on their Ln-norm (supported values of n correspond to the supported values for the p argument of torch.norm());
torch.nn.utils.prune.CustomFromMask: prune a tensor using a user-provided mask.
Their functional equivalents are torch.nn.utils.prune.identity, torch.nn.utils.prune.random_unstructured, torch.nn.utils.prune.l1_unstructured, torch.nn.utils.prune.random_structured, torch.nn.utils.prune.ln_structured, and torch.nn.utils.prune.custom_from_mask, respectively.
Global pruning, in which entries are compared across multiple tensors, is enabled through torch.nn.utils.prune.global_unstructured (see Sec. 3 for example usage).
2.2 Extending the Module
The torch.nn.utils.prune module can be extended to implement custom pruning functions. This requires sub-classing the BasePruningMethod base class, and implementing the __init__ and compute_mask methods, i.e., respectively, the constructor and the instructions to compute the mask for the given tensor according to the logic of the pruning technique. The nature of the pruning technique is specified through the assignment of a PRUNING_TYPE. If none of the currently supported types fits the new pruning technique, the user will also have to add support for a new PRUNING_TYPE in the way PruningContainer handles the iterative application of pruning masks.
For example, Listing 1 demonstrates how to implement a pruning technique that prunes every other currently unpruned entry in a tensor, and how to provide a convenient functional interface for the method in just a few lines of code.
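Along the lines of Listing 1, a custom technique might look like the following minimal sketch (the class name FooBarPruningMethod and the functional wrapper foobar_unstructured are illustrative choices, not part of the library):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

class FooBarPruningMethod(prune.BasePruningMethod):
    """Illustrative technique: prune every other entry in a tensor."""

    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        # Start from the mask left by any previous pruning rounds, then
        # zero out every other entry of the flattened tensor
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

def foobar_unstructured(module, name):
    """Convenient functional interface for the method above."""
    FooBarPruningMethod.apply(module, name)
    return module

module = foobar_unstructured(nn.Linear(4, 4), name="weight")
```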
3 Example Usage
Pruning parameters in a model is as simple as invoking the desired pruning function on each parameter (identified by name) within a given neural network module. In this example, the first convolutional layer in a VGG-11 architecture (Simonyan and Zisserman, 2014) is first pruned by removing 3 individual entries using unstructured magnitude-based pruning, then pruned again by removing the 50% of remaining channels along the 0th axis with the lowest L2-norm.
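Using a stand-in convolutional layer with the same shape as the first layer of VGG-11 (so the snippet runs without downloading the full model), the two steps might be sketched as:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# Stand-in for VGG-11's first convolutional layer (3 -> 64 channels)
conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)

# Step 1: zero out the 3 entries with the lowest absolute magnitude
prune.l1_unstructured(conv1, name="weight", amount=3)

# Step 2: zero out 50% of the remaining channels along dim 0 by L2-norm
prune.ln_structured(conv1, name="weight", amount=0.5, n=2, dim=0)
```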
This can be easily extended to apply pruning to all layers in a network. For instance, Listing 3 shows how to prune all 2D convolutional and linear layers in a network, with different pruning fractions that depend on the layer type.
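In the spirit of Listing 3, a loop over named_modules can dispatch a different pruning fraction per layer type (the toy model and the fractions 20% and 40% are illustrative choices):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(
    nn.Conv2d(1, 8, 3), nn.ReLU(),
    nn.Conv2d(8, 16, 3), nn.ReLU(),
    nn.Flatten(), nn.Linear(16 * 20 * 20, 10),
)

# Prune 20% of connections in conv layers, 40% in linear layers
for _, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module, name="weight", amount=0.2)
    elif isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=0.4)
```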
In the examples above, each candidate entity for pruning is compared in magnitude to other candidate entities within a single layer. Listing 4, instead, provides an example of how to enable the pooling together of all entities (i.e. single connections, entire units, or channels) across a network for a global magnitude comparison.
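In the spirit of Listing 4, global magnitude pruning pools the chosen parameters from several layers into one comparison (the two-layer toy model and the 30% fraction are illustrative choices):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))

# Pool the weights of both linear layers and prune the 30% of entries
# with the smallest absolute magnitude across the entire pooled set
parameters_to_prune = [
    (module, "weight") for module in model if isinstance(module, nn.Linear)
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.3,
)
```

Note that the 30% is enforced globally, so individual layers may end up with sparsities above or below 30%.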
The module also supports the pruning of individual tensors by calling the prune method of any pruning class, as shown in Listing 5.
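A minimal sketch of this direct-tensor usage: instantiating a pruning method and calling its prune method returns a pruned copy of a raw tensor, with no module involved.

```python
import torch
import torch.nn.utils.prune as prune

t = torch.tensor([[1.0, -2.0, 3.0], [-4.0, 5.0, -6.0]])

# Zero out the 2 entries with the smallest absolute magnitude,
# without attaching the method to any module
pruned_t = prune.L1Unstructured(amount=2).prune(t)
```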
- Deep compression: compressing deep neural networks with pruning, trained quantization and Huffman coding. Cited by: §1.
- Towards the systematic reporting of the energy and carbon footprints of machine learning. Cited by: §1.
- SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size. Cited by: §1.
- Compression of deep convolutional neural networks for fast and low power mobile applications. Cited by: §1.
- Gradient-based learning applied to document recognition. Proceedings of the IEEE 86 (11), pp. 2278–2324. Cited by: Listing 4.
- NullaNet: training deep neural networks for reduced-memory-access inference. CoRR abs/1807.08716. Cited by: §1.
- PyTorch: an imperative style, high-performance deep learning library. In Advances in Neural Information Processing Systems 32, H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett (Eds.), pp. 8024–8035. Cited by: §1.
- Histologie du système nerveux de l'homme et des vertébrés. Vol. 1, Maloine, Paris. Cited by: §1.
- Minerva: enabling low-power, highly-accurate deep neural network accelerators. In 2016 ACM/IEEE 43rd Annual International Symposium on Computer Architecture (ISCA), pp. 267–278. Cited by: §1.
- Green AI. Cited by: §1.
- Very deep convolutional networks for large-scale image recognition. Cited by: §3.
- Brain connectivity. Scholarpedia 2 (10), pp. 4695. Cited by: §1.
- Energy and policy considerations for deep learning in NLP. Cited by: §1.
- Embedded deep neural network processing: algorithmic and processor techniques bring deep learning to IoT and edge devices. IEEE Solid-State Circuits Magazine 9 (4), pp. 55–65. Cited by: §1.
- Designing energy-efficient convolutional neural networks using energy-aware pruning. Cited by: §1.