Torchbearer: A Model Fitting Library for PyTorch

09/10/2018 ∙ by Ethan Harris, et al. ∙ University of Southampton 0

We introduce torchbearer, a model fitting library for pytorch aimed at researchers working on deep learning or differentiable programming. The torchbearer library provides a high-level metric and callback API that can be used for a wide range of applications. We also include a series of built-in callbacks that can be used for model persistence, learning rate decay, logging, data visualization and more. The extensive documentation, which includes an example library for deep learning and differentiable programming problems, is available online. The code is licensed under the MIT License and available on GitHub.




1 Introduction

The meteoric rise of deep learning will leave in its wake a host of frameworks that support hardware-accelerated tensor processing and automatic differentiation. We speculate that over time a more general characterization, differentiable programming, will take its place. This involves optimizing the parameters of some differentiable program through gradient descent in a process known as fitting. One library which has seen increasing use over recent years is pytorch (Paszke et al., 2017), which excels, in part, because of the ease with which one can construct models that perform non-standard tensor operations. This makes pytorch especially useful for research, where any aspect of a model definition may need to be altered or extended. However, as pytorch is specifically focused on tensor processing and automatic differentiation, it lacks the high-level model fitting API of other frameworks such as keras (Chollet et al., 2015), a library for training deep models with tensorflow. Furthermore, such libraries rarely support the more general case of differentiable programming.

We introduce torchbearer, a pytorch library that aids research by streamlining the model fitting process whilst preserving transparency and generality. At the core of torchbearer is a model fitting API where every aspect of the fitting process can be easily modified. We also provide a powerful metric API which enables the gathering of rolling statistics and aggregates. Finally, torchbearer provides a host of callbacks that perform advanced functions, such as using tensorboardx to log visualizations to tensorboard (included with tensorflow (Abadi et al., 2016)) or visdom.

2 Design

The torchbearer library is written in Python (van Rossum, 1995) using pytorch, torchvision and tqdm and depends on numpy (Oliphant, 2006), scikit-learn (Pedregosa et al., 2011) and tensorboardx for some features. The key abstractions in torchbearer are trials, callbacks and metrics. There are three key principles which have motivated the design:

  • Flexibility: The library supports both deep learning and differentiable programming.

  • Functionality: The library makes complex functionality, such as advanced callbacks and aggregate metrics, available through simple interfaces.

  • Transparency: The API is simple and clear.

By virtue of these design principles, torchbearer differs from other, similar libraries such as ignite or tnt. For example, neither library offers a wide range of built in callbacks for complex functions. Furthermore, both ignite and tnt use an event driven API for the model fitting which makes code less transparent and human readable.

@callbacks.on_criterion
def l2_weight_decay(state):
  for W in state[MODEL].parameters():
    state[LOSS] += W.norm(2)
Listing 1: Example callback
@metrics.default_for_key('acc')
@metrics.running_mean
@metrics.std
class CategoricalAccuracy(metrics.Metric):
  ...
Listing 2: Example metric definition

2.1 Trial API

The Trial class defines a pytorch model fitting interface built around the run(…) method. There are also predict(…) and evaluate(…) methods which can be used for model inference or for evaluating saved models. The Trial class also provides a state_dict method which returns a dictionary containing the model parameters, the optimizer state and the callback states; this can be saved and later reloaded using load_state_dict.
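As a rough, library-free sketch of what such a fitting interface looks like (toy code with hypothetical names, not the real torchbearer API), a minimal Trial-like class might expose the same run/evaluate/state_dict surface:

```python
# Toy sketch of a Trial-style fitting interface (hypothetical, stdlib-only).
# Fits a single scalar parameter w so that w * x approximates y.
class ToyTrial:
    def __init__(self, w, lr=0.1):
        self.w = w          # the single "model parameter"
        self.lr = lr        # learning rate for gradient descent

    def run(self, data, epochs=1):
        # one gradient-descent step per (x, y) pair, on squared error
        for _ in range(epochs):
            for x, y in data:
                grad = 2 * (self.w * x - y) * x
                self.w -= self.lr * grad
        return self

    def evaluate(self, data):
        # mean squared error over the given data
        return sum((self.w * x - y) ** 2 for x, y in data) / len(data)

    def state_dict(self):
        # everything needed to reload the fitted state later
        return {"w": self.w}

    def load_state_dict(self, d):
        self.w = d["w"]
```

The real Trial additionally carries the optimizer and callback states in its state_dict, but the round trip (fit, save, reload, evaluate) follows the same shape.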

2.2 Callback API

The callback API defines classes that can be used to execute functions at different points in the fitting process. The mutable state dictionary is an integral part of torchbearer that is given to each callback and contains intermediate variables required by the Trial. This allows callbacks to alter the nature of the fitting process dynamically. Callbacks can be implemented as a decorated function through the decorator API, as shown in Listing 1.
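The mechanism can be sketched without the library itself: every callback receives the same shared dictionary and mutates it in place (toy code; the real library uses generated StateKey objects rather than plain strings):

```python
# Toy sketch of the mutable-state callback mechanism (hypothetical code).
MODEL, LOSS = "model", "loss"   # stand-ins for torchbearer's generated keys

def l2_sketch(state):
    # mimics Listing 1: add an L2 penalty onto the running loss
    # (here the "model" is just a list of scalar weights)
    state[LOSS] += sum(w * w for w in state[MODEL])

def run_callbacks(state, callbacks):
    # callbacks mutate state in place rather than returning values,
    # so each one sees the effects of those that ran before it
    for cb in callbacks:
        cb(state)
    return state
```

Because the state dictionary is shared, a callback can change anything the Trial will use next, which is what lets callbacks alter the fitting process dynamically.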

2.3 Metric API

The metric API uses a tree structure that enables data to flow from one metric to a set of children. This allows aggregates such as a running mean or standard deviation to be computed. Assembling these data structures can be difficult, and as such, torchbearer includes a decorator API that simplifies their construction. For example, Listing 2 produces the standard deviation and running mean of an accuracy metric. The default_for_key(…) decorator enables the metric to be referenced with a string in the Trial definition.
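The tree idea can be illustrated with a small, self-contained sketch (toy code, not the torchbearer implementation): a root node feeds each incoming value to child aggregators such as a running mean and a standard deviation:

```python
import math

# Toy sketch of a metric tree (hypothetical code): the root passes each new
# value to every child aggregator and collects their outputs by name.
class RunningMean:
    def __init__(self):
        self.n, self.total = 0, 0.0

    def process(self, x):
        self.n += 1
        self.total += x
        return self.total / self.n

class Std:
    def __init__(self):
        self.vals = []

    def process(self, x):
        # population standard deviation over all values seen so far
        self.vals.append(x)
        m = sum(self.vals) / len(self.vals)
        return math.sqrt(sum((v - m) ** 2 for v in self.vals) / len(self.vals))

class MetricTree:
    def __init__(self, children):
        self.children = children   # name -> aggregator

    def process(self, x):
        return {name: c.process(x) for name, c in self.children.items()}
```

A decorator API then only has to wrap a base metric in the right tree nodes, which is what the stacked decorators in Listing 2 express.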

3 Example: Generative Adversarial Networks (GANs)

A GAN (Goodfellow et al., 2014) is a type of model that aims to learn a high quality estimate of an input data distribution. GANs comprise two networks which are trained simultaneously but with opposing goals: the ‘generator’ and the ‘discriminator’. To encourage the generator to produce samples that appear genuine, we use an adversarial loss, which is minimized when the discriminator predicts that generated samples are genuine. For the discriminator, we use a loss that maximizes the probability of correctly identifying real and fake samples. Implementing this requires forward passes through both the generator and the discriminator and separate backward passes for the loss of each network. Due to this complexity, training GANs is often challenging in typical frameworks, since it requires a very flexible training loop, such as that of the torchbearer Trial. In this section we train a standard GAN with torchbearer to demonstrate its effectiveness.
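The two opposing objectives can be made concrete with a framework-free toy sketch, using binary cross-entropy on discriminator scores in (0, 1) (illustrative only; the listings that follow use the library's own loss plumbing):

```python
import math

# Toy sketch of the adversarial objectives (hypothetical helper names).
def bce(pred, target):
    # binary cross-entropy for a single score pred in (0, 1)
    return -(target * math.log(pred) + (1 - target) * math.log(1 - pred))

def generator_loss(disc_on_fake):
    # the generator wins when the discriminator scores fakes as genuine (1.0)
    return bce(disc_on_fake, 1.0)

def discriminator_loss(disc_on_real, disc_on_fake):
    # the discriminator wins by scoring real as 1.0 and fake as 0.0
    return (bce(disc_on_real, 1.0) + bce(disc_on_fake, 0.0)) / 2
```

Note the opposition: the same discriminator score on a fake sample pushes the two losses in opposite directions, which is why each network needs its own backward pass.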

class GAN(torch.nn.Module):
  def forward(self, real_imgs, state):
    z = ...  # random sample from the prior noise distribution
    state[GEN_IMGS] = self.generator(z)
    state[DISC_GEN] = self.discriminator(state[GEN_IMGS])
    state[DISC_GEN_DET] = self.discriminator(state[GEN_IMGS].detach())
    state[DISC_REAL] = self.discriminator(real_imgs)
Listing 3: GAN forward pass
@callbacks.add_to_loss
def gan_loss(state):
    fake_loss = adversarial_loss(state[DISC_GEN_DET], fake)
    real_loss = adversarial_loss(state[DISC_REAL], valid)
    state[G_LOSS] = adversarial_loss(state[DISC_GEN], valid)
    state[D_LOSS] = (real_loss + fake_loss) / 2
    return state[G_LOSS] + state[D_LOSS]
Listing 4: Loss computation

To begin, the forward passes of both networks are combined in Listing 3, with intermediate tensors stored in state. When torchbearer is not passed a criterion, the base loss is automatically set to zero so that callbacks can add to it. We use the add_to_loss callback shown in Listing 4 to combine the losses. The state keys are generated using torchbearer.state_key(key) to prevent collisions. Note that the call to the underlying model does not automatically include the state dictionary; we set the pass_state flag in Trial(…) to achieve this. Having created a data loader and optimizer using pytorch, we train the model with just two lines in Listing 5.

trial = Trial(GAN(), optimizer, metrics=["loss"], callbacks=[gan_loss], pass_state=True)
trial.with_train_generator(dataloader).run(epochs=200)
Listing 5: Training

4 Project Management

The torchbearer library is licensed under the MIT License and the most recent release (0.2.1) is referenced on the Python Package Index (PyPI). The code is hosted on GitHub, where release information can be found. The repository is actively monitored and we encourage users to raise issues requesting fixes or new functionality and to open pull requests for anything they have implemented.

4.1 Continuous Integration

To support usability, the library must be as stable as possible. For this reason we use continuous integration with Travis CI, which tests all pull requests before they can be merged with the master branch. We also use Codacy to perform automated code reviews which ensure that new code follows the PEP 8 standard. In this way we ensure that the master copy of torchbearer is always correctly styled and passes the tests.

4.2 Example Library

An effective way to improve usability is to provide examples of using the library for a range of problems. As such, we have added an example library which includes detailed examples showing how to use torchbearer for various deep learning and differentiable programming models including: GANs (Goodfellow et al., 2014), Variational Auto-Encoders (Kingma and Welling, 2013) and Support Vector Machines (Cortes and Vapnik, 1995).

5 Conclusion

To summarize, torchbearer is a library that simplifies the process of fitting deep learning and differentiable programming models in pytorch. This is done without reducing transparency, so that it remains useful for research. Key features of torchbearer include a comprehensive set of built-in callbacks (such as logging, weight decay and model checkpointing) and a powerful metric API. The torchbearer library has a strong and growing community on GitHub and we are committed to improving it wherever possible.


  • Abadi et al. (2016) Martín Abadi, Paul Barham, Jianmin Chen, Zhifeng Chen, Andy Davis, Jeffrey Dean, Matthieu Devin, Sanjay Ghemawat, Geoffrey Irving, Michael Isard, et al. Tensorflow: a system for large-scale machine learning. In OSDI, volume 16, pages 265–283, 2016.
  • Chollet et al. (2015) François Chollet et al. Keras, 2015.
  • Cortes and Vapnik (1995) Corinna Cortes and Vladimir Vapnik. Support-vector networks. Machine learning, 20(3):273–297, 1995.
  • Goodfellow et al. (2014) Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In Advances in neural information processing systems, pages 2672–2680, 2014.
  • Kingma and Welling (2013) Diederik P Kingma and Max Welling. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114, 2013.
  • Oliphant (2006) Travis E Oliphant. A guide to NumPy, volume 1. Trelgol Publishing USA, 2006.
  • Paszke et al. (2017) Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. Automatic differentiation in pytorch, 2017.
  • Pedregosa et al. (2011) Fabian Pedregosa, Gaël Varoquaux, Alexandre Gramfort, Vincent Michel, Bertrand Thirion, Olivier Grisel, Mathieu Blondel, Peter Prettenhofer, Ron Weiss, Vincent Dubourg, et al. Scikit-learn: Machine learning in python. Journal of machine learning research, 12(Oct):2825–2830, 2011.
  • van Rossum (1995) Guido van Rossum. Python tutorial, Technical Report CS-R9526. Technical report, Centrum voor Wiskunde en Informatica (CWI), Amsterdam, May 1995.