torchbearer: A model fitting library for PyTorch
We introduce torchbearer, a model fitting library for pytorch aimed at researchers working on deep learning or differentiable programming. The torchbearer library provides a high-level metric and callback API that can be used for a wide range of applications. We also include a series of built-in callbacks that can be used for model persistence, learning rate decay, logging, data visualization and more. The extensive documentation includes an example library for deep learning and dynamic programming problems and can be found at http://torchbearer.readthedocs.io. The code is licensed under the MIT License and available at https://github.com/ecs-vlc/torchbearer.
The meteoric rise of deep learning will leave behind a host of frameworks that support hardware accelerated tensor processing and automatic differentiation. We speculate that over time a more general characterization, differentiable programming, will take its place. This involves optimizing the parameters of some differentiable program through gradient descent in a process known as fitting. One library which has seen increasing use over recent years is pytorch (Paszke et al., 2017), which excels, in part, because of the ease with which one can construct models that perform non-standard tensor operations. This makes pytorch especially useful for research, where any aspect of a model definition may need to be altered or extended. However, as pytorch is specifically focused on tensor processing and automatic differentiation, it lacks the high-level model fitting API of other frameworks such as keras (Chollet et al., 2015), a library for training deep models with tensorflow. Furthermore, such libraries rarely support the more general case of differentiable programming.
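As a concrete instance of fitting, plain gradient descent updates the parameters θ of a differentiable program f_θ by stepping against the gradient of a loss L on data (x, y) with step size η. This form is an illustrative assumption; optimizers used in practice often add momentum or adaptive step sizes on top of it.

```latex
\theta_{t+1} = \theta_t - \eta \, \nabla_{\theta} \, \mathcal{L}\big(f_{\theta_t}(x),\, y\big)
```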
We introduce torchbearer, a pytorch library that aids research by streamlining the model fitting process whilst preserving transparency and generality. At the core of torchbearer is a model fitting API where every aspect of the fitting process can be easily modified. We also provide a powerful metric API which enables the gathering of rolling statistics and aggregates. Finally, torchbearer provides a host of callbacks that perform advanced functions, such as using tensorboardx to log visualizations to tensorboard (included with tensorflow (Abadi et al., 2016)) or visdom.
The torchbearer library is written in Python (van Rossum, 1995) on top of pytorch, torchvision and tqdm, and it depends on numpy (Oliphant, 2006), scikit-learn (Pedregosa et al., 2011) and tensorboardx for some features. The key abstractions in torchbearer are trials, callbacks and metrics. There are three key principles which have motivated the design:
Flexibility: The library supports both deep learning and differentiable programming.
Functionality: The library includes functions that make complex behavior available.
Transparency: The API is simple and clear.
By virtue of these design principles, torchbearer differs from other, similar libraries such as ignite or tnt. For example, neither library offers a wide range of built-in callbacks for complex functions. Furthermore, both ignite and tnt use an event-driven API for model fitting, which makes code less transparent and less human-readable.
The Trial class defines a pytorch model fitting interface built around Trial.run(…). There are also predict(…) and evaluate(…) methods which can be used for model inference or for evaluating saved models. The Trial class also provides a state_dict method which returns a dictionary containing the model parameters, the optimizer state and the callback states; this dictionary can be saved and later reloaded using load_state_dict.
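To make the shape of this interface concrete, the following pure-Python toy mimics run, evaluate, predict, state_dict and load_state_dict on a one-feature linear model fitted by gradient descent. The class ToyTrial, its lr parameter and the list-of-pairs data format are hypothetical and chosen so that the sketch runs without pytorch. It is not torchbearer's implementation, whose Trial wraps a pytorch model, optimizer and criterion.

```python
import copy


class ToyTrial:
    """Hypothetical stand-in: fits y = w * x + b by full-batch gradient descent."""

    def __init__(self, lr=0.05):
        self.params = {"w": 0.0, "b": 0.0}             # model parameters
        self.optimizer_state = {"lr": lr, "steps": 0}  # optimizer state
        self.callback_states = []                      # callback states (unused here)

    def _loss_and_grads(self, data):
        """Mean squared error and its gradients with respect to w and b."""
        n = len(data)
        w, b = self.params["w"], self.params["b"]
        residuals = [(w * x + b - y, x) for x, y in data]
        loss = sum(r * r for r, _ in residuals) / n
        dw = sum(2 * r * x for r, x in residuals) / n
        db = sum(2 * r for r, _ in residuals) / n
        return loss, dw, db

    def run(self, data, epochs=1):
        """Fit: one gradient step per epoch; returns the loss before each step."""
        history = []
        for _ in range(epochs):
            loss, dw, db = self._loss_and_grads(data)
            lr = self.optimizer_state["lr"]
            self.params["w"] -= lr * dw
            self.params["b"] -= lr * db
            self.optimizer_state["steps"] += 1
            history.append(loss)
        return history

    def evaluate(self, data):
        """Loss on (e.g. held-out) data, without updating the parameters."""
        return self._loss_and_grads(data)[0]

    def predict(self, xs):
        """Inference only."""
        return [self.params["w"] * x + self.params["b"] for x in xs]

    def state_dict(self):
        """Independent snapshot of model, optimizer and callback state."""
        return copy.deepcopy({"model": self.params,
                              "optimizer": self.optimizer_state,
                              "callbacks": self.callback_states})

    def load_state_dict(self, state):
        state = copy.deepcopy(state)
        self.params = state["model"]
        self.optimizer_state = state["optimizer"]
        self.callback_states = state["callbacks"]


# Usage: fit y = 2x + 1, snapshot the trial, and reload it into a fresh one.
train = [(x, 2 * x + 1) for x in range(-3, 4)]
trial = ToyTrial(lr=0.05)
history = trial.run(train, epochs=200)
restored = ToyTrial()
restored.load_state_dict(trial.state_dict())
prediction = restored.predict([10])[0]  # close to 2 * 10 + 1 = 21
```

Deep-copying in state_dict keeps a saved snapshot independent of later training, which is the property that makes saving and reloading safe.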
The callback API defines classes that can be used to execute functions at different points in the fitting process. The mutable state dictionary is an integral part of torchbearer that is given to each callback and contains intermediate variables required by the Trial. This allows callbacks to alter the nature of the fitting process dynamically. Callbacks can be implemented as a decorated function through the decorator API, as shown in Listing 1.
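A minimal sketch of the idea, assuming a hypothetical on_step decorator and fit loop (neither is torchbearer's API): each callback is a plain function that receives the same mutable state dictionary, so a learning rate decay callback can change how later steps behave simply by overwriting state["lr"].

```python
def on_step(func):
    """Hypothetical decorator: mark a plain function to run after every step."""
    func.hook = "on_step"
    return func


@on_step
def step_decay(state):
    """Halve the learning rate every 10 steps by mutating the shared state."""
    if state["step"] > 0 and state["step"] % 10 == 0:
        state["lr"] *= 0.5


@on_step
def record_lr(state):
    """Log the learning rate that will be used for the next step."""
    state.setdefault("lr_history", []).append(state["lr"])


def fit(param, grad_fn, steps, callbacks, lr=0.25):
    """Toy fitting loop: one gradient step, then every on_step callback."""
    state = {"param": param, "lr": lr, "step": 0}
    for step in range(1, steps + 1):
        state["step"] = step
        state["param"] -= state["lr"] * grad_fn(state["param"])
        for callback in callbacks:
            if getattr(callback, "hook", None) == "on_step":
                callback(state)
    return state


# Usage: minimize (p - 3)^2, whose gradient is 2 * (p - 3).
final = fit(0.0, lambda p: 2 * (p - 3), steps=25,
            callbacks=[step_decay, record_lr])
```

Because the loop reads state["lr"] afresh on every step, the decay takes effect without the loop knowing that any callback exists.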
The metric API uses a tree that enables data to flow from one metric to a set of children. This allows aggregates such as a running mean or standard deviation to be computed. Assembling these data structures can be difficult, so torchbearer includes a decorator API that simplifies their construction. For example, Listing 2 uses this API to produce the standard deviation and running mean of an accuracy metric. The default_for_key(…) decorator enables the metric to be referenced with a string in the Trial definition.
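The tree idea can be sketched in plain Python. Here a root Accuracy node computes a per-batch value and passes it to two children that aggregate it. The class names, the register_as decorator (standing in for the role the text gives default_for_key) and the choice of a cumulative mean as the "running mean" are assumptions for illustration, not torchbearer's classes.

```python
import math


class Metric:
    """A node in a metric tree: computes a value, then feeds it to its children."""

    def __init__(self, name):
        self.name = name
        self.children = []

    def add_child(self, child):
        self.children.append(child)
        return self  # allow chaining

    def compute(self, data):
        raise NotImplementedError

    def process(self, data):
        value = self.compute(data)
        results = {self.name: value}
        for child in self.children:
            results.update(child.process(value))
        return results


class Accuracy(Metric):
    def compute(self, batch):
        preds, targets = batch
        return sum(p == t for p, t in zip(preds, targets)) / len(targets)


class RunningMean(Metric):
    """Cumulative mean of every value received so far."""

    def __init__(self, name):
        super().__init__(name)
        self.total, self.count = 0.0, 0

    def compute(self, value):
        self.total += value
        self.count += 1
        return self.total / self.count


class RunningStd(Metric):
    """Population standard deviation of every value received so far."""

    def __init__(self, name):
        super().__init__(name)
        self.total, self.total_sq, self.count = 0.0, 0.0, 0

    def compute(self, value):
        self.total += value
        self.total_sq += value * value
        self.count += 1
        mean = self.total / self.count
        return math.sqrt(max(self.total_sq / self.count - mean * mean, 0.0))


METRICS = {}


def register_as(key):
    """Hypothetical decorator: make a metric tree available under a string key."""
    def decorator(factory):
        METRICS[key] = factory
        return factory
    return decorator


@register_as("acc")
def accuracy_tree():
    return (Accuracy("acc")
            .add_child(RunningMean("running_acc"))
            .add_child(RunningStd("acc_std")))


# Usage: look the tree up by its string key and feed it two (preds, targets) batches.
tree = METRICS["acc"]()
for batch in [([1, 0, 1, 1], [1, 0, 0, 1]), ([1, 1, 1, 1], [1, 1, 1, 1])]:
    results = tree.process(batch)
```

After the two batches (accuracies 0.75 and 1.0), results holds the latest accuracy together with its aggregates, so one root computation serves every child.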
A GAN (Goodfellow et al., 2014) is a type of model that aims to learn a high-quality estimate of an input data distribution. GANs consist of two networks, the ‘generator’ and the ‘discriminator’, which are trained simultaneously but with opposing goals. To encourage the generator to produce samples that appear genuine, we use an adversarial loss, which is minimized when the discriminator predicts that generated samples are genuine. For the discriminator, we use a loss that maximizes the probability of correctly identifying real and fake samples. Implementing this requires forward passes through the generator and discriminator and separate backward passes for the loss of each network. Because of this complexity, training GANs is often challenging in typical frameworks, since it requires a very flexible training loop such as that of the torchbearer Trial. In this section we train a standard GAN with torchbearer to demonstrate its effectiveness.
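Assuming the common binary cross-entropy formulation (the text does not pin down the exact losses), with D(x) the discriminator's probability that x is real and G(z) a sample generated from noise z, the two losses can be written as follows. The generator loss is minimized when D(G(z)) approaches 1, that is, when the discriminator judges generated samples to be genuine.

```latex
\mathcal{L}_D = -\,\mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
                -\,\mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big],
\qquad
\mathcal{L}_G = -\,\mathbb{E}_{z \sim p_z}\big[\log D(G(z))\big]
```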
To begin, the forward passes of both networks are combined in Listing 3, with intermediate tensors stored in state. When torchbearer is not passed a criterion, the base loss is automatically set to zero so that callbacks can add to it. We use the add_to_loss callback shown in Listing 4 to combine the losses. The state keys are generated using torchbearer.state_key(key) to prevent collisions. Note that the call to the underlying model does not automatically include the state dictionary; to achieve this, we set the pass_state flag in Trial(…). Having created a data loader and optimizer using pytorch, we train the model with just two lines in Listing 5.
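The state-passing pattern described here can be sketched without pytorch using scalar toy networks. The helper make_state_key, the functions gan_forward, add_gan_losses and training_step, and the parameter names are hypothetical; the sketch only shows how a zero base loss, collision-free state keys and a loss-adding callback fit together, not how torchbearer implements them. There is no autograd here, so gradient routing is out of scope; with pytorch, the discriminator loss is typically computed on detached generator outputs so that it does not update the generator.

```python
import math

_USED_KEYS = set()


def make_state_key(name):
    """Hypothetical helper: return `name`, or `name_1`, `name_2`, ... if taken."""
    key, i = name, 0
    while key in _USED_KEYS:
        i += 1
        key = f"{name}_{i}"
    _USED_KEYS.add(key)
    return key


LOSS = make_state_key("loss")
GEN_OUT = make_state_key("gen_out")
D_REAL = make_state_key("d_real")
D_FAKE = make_state_key("d_fake")


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def gan_forward(real, noise, params, state):
    """Combined forward pass of both toy networks; results are stored in state."""
    fake = params["a"] * noise + params["c"]                   # toy generator G(z)
    state[GEN_OUT] = fake
    state[D_REAL] = sigmoid(params["w"] * real + params["b"])  # D(x)
    state[D_FAKE] = sigmoid(params["w"] * fake + params["b"])  # D(G(z))


def add_gan_losses(state):
    """Callback: add discriminator and generator losses to the base loss."""
    d_loss = -math.log(state[D_REAL]) - math.log(1.0 - state[D_FAKE])
    g_loss = -math.log(state[D_FAKE])
    state[LOSS] += d_loss + g_loss


def training_step(real, noise, params, callbacks):
    state = {LOSS: 0.0}                      # no criterion: base loss starts at zero
    gan_forward(real, noise, params, state)  # the model is handed the state
    for callback in callbacks:
        callback(state)
    return state


# Usage: with w = b = 0 the discriminator outputs 0.5 for every input.
params = {"a": 1.0, "c": 0.0, "w": 0.0, "b": 0.0}
state = training_step(real=1.5, noise=0.3, params=params,
                      callbacks=[add_gan_losses])
```

With the discriminator at 0.5 everywhere, the combined loss is 2 log 2 for the discriminator plus log 2 for the generator.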
The torchbearer library is licensed under the MIT License, and the most recent release (0.2.1) is available on the Python Package Index (PyPI). The code is hosted on GitHub; see CHANGELOG.md for release information. The repository is actively monitored, and we encourage users to raise issues requesting fixes or new functionality and to open pull requests for anything they have implemented.
To support usability, the library must be as stable as possible. For this reason we use continuous integration with Travis CI which tests all pull requests before they can be merged with the master branch. We also use Codacy to perform automated code reviews which ensure that new code follows the PEP8 standard. In this way we ensure that the master copy of torchbearer is always correctly styled and passes the tests.
An effective way to improve usability is to provide examples of using the library for a range of problems. As such, we have added an example library which includes detailed examples showing how to use torchbearer for various deep learning and differentiable programming models including: GANs (Goodfellow et al., 2014), Variational Auto-Encoders (Kingma and Welling, 2013) and Support Vector Machines (Cortes and Vapnik, 1995).
To summarize, torchbearer is a library that simplifies the process of fitting deep learning and differentiable programming models in pytorch. It does this without reducing transparency, so it remains useful for research. Key features of torchbearer include a comprehensive set of built-in callbacks (such as logging, weight decay and model checkpointing) and a powerful metric API. The torchbearer library has a strong and growing community on GitHub, and we are committed to improving it wherever possible.
Corinna Cortes and Vladimir Vapnik. Support-vector networks. Machine Learning, 20(3):273–297, 1995.