Equinox: neural networks in JAX via callable PyTrees and filtered transformations

10/30/2021
by Patrick Kidger, et al.

JAX and PyTorch are two popular Python autodifferentiation frameworks. JAX is based around pure functions and functional programming. PyTorch has popularised the use of an object-oriented (OO) class-based syntax for defining parameterised functions, such as neural networks. This apparent fundamental difference means that current libraries for building parameterised functions in JAX have either rejected the OO approach entirely (Stax) or have introduced OO-to-functional transformations, multiple new abstractions, and been limited in the extent to which they integrate with JAX (Flax, Haiku, Objax). Either way, this OO/functional difference has been a source of tension. Here, we introduce Equinox, a small neural network library showing how a PyTorch-like class-based approach may be admitted without sacrificing JAX-like functional programming. We provide two main ideas. One: parameterised functions are themselves represented as PyTrees, which means that the parameterisation of a function is transparent to the JAX framework. Two: we filter a PyTree to isolate just those components that should be treated when transforming (`jit`, `grad` or `vmap`-ing) a higher-order function of a parameterised function, such as a loss function applied to a model. Overall, Equinox resolves the above tension without introducing any new programmatic abstractions: only PyTrees and transformations, just as with regular JAX. Equinox is available at <https://github.com/patrick-kidger/equinox>.
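Both ideas are easy to see in miniature. The sketch below is not taken from the paper itself; it assumes Equinox's public API, in which subclassing `equinox.Module` registers a class as a PyTree, and `equinox.filter_jit`/`equinox.filter_grad` apply `jit`/`grad` to the array-valued leaves of their inputs while treating everything else as static.

```python
import jax
import jax.numpy as jnp
import equinox as eqx

# Idea one: the model is a PyTree. Subclassing eqx.Module registers the
# class with JAX, so its parameters (the array-valued fields) are visible
# to jit/grad/vmap with no extra machinery.
class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

# Idea two: filtered transformations of a higher-order function of the
# model. filter_grad differentiates with respect to the floating-point
# array leaves of `model`; filter_jit traces array leaves and holds the
# remaining (non-array) components static.
@eqx.filter_jit
@eqx.filter_grad
def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)  # the callable PyTree is passed to vmap directly
    return jnp.mean((pred - y) ** 2)

key = jax.random.PRNGKey(0)
model = Linear(2, 1, key)
x = jnp.ones((8, 2))
y = jnp.zeros((8, 1))
grads = loss_fn(model, x, y)  # a PyTree with the same structure as `model`
```

Note how the loss function above is written as an ordinary JAX function of the model: no parameter dictionaries, no OO-to-functional transformation, just a PyTree passed through (filtered) transformations.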
