Equinox: neural networks in JAX via callable PyTrees and filtered transformations

10/30/2021
by Patrick Kidger, et al.

JAX and PyTorch are two popular Python autodifferentiation frameworks. JAX is based around pure functions and functional programming. PyTorch has popularised the use of an object-oriented (OO) class-based syntax for defining parameterised functions, such as neural networks. Because this looks like a fundamental difference, current libraries for building parameterised functions in JAX have either rejected the OO approach entirely (Stax) or have introduced OO-to-functional transformations and multiple new abstractions, and have been limited in the extent to which they integrate with JAX (Flax, Haiku, Objax). Either way, this OO/functional difference has been a source of tension. Here, we introduce 'Equinox', a small neural network library showing how a PyTorch-like class-based approach may be admitted without sacrificing JAX-like functional programming. We provide two main ideas. One: parameterised functions are themselves represented as 'PyTrees', which means that the parameterisation of a function is transparent to the JAX framework. Two: we filter a PyTree to isolate just those components that a transformation ('jit', 'grad' or 'vmap') should act on when it is applied to a higher-order function of a parameterised function, such as a loss function applied to a model. Overall, Equinox resolves the above tension without introducing any new programmatic abstractions: only PyTrees and transformations, just as with regular JAX. Equinox is available at <https://github.com/patrick-kidger/equinox>.
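
The two ideas in the abstract can be sketched in plain JAX. The block below is only an illustration under stated assumptions, not Equinox's implementation: the class `Linear` and the helpers `is_array`, `partition`, `combine` and `filtered_grad` are hypothetical names chosen for this example, and the library's own API may differ. It shows a callable class registered as a PyTree (idea one), and a filter that splits the PyTree into array leaves and everything else, so that 'grad' only sees the arrays (idea two).

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map


# Idea one (sketch): a parameterised function that is itself a PyTree.
# `Linear` is a hypothetical example class, not part of any library.
@register_pytree_node_class
class Linear:
    def __init__(self, weight, bias, activation):
        self.weight = weight          # array leaf
        self.bias = bias              # array leaf
        self.activation = activation  # non-array leaf (a Python callable)

    def __call__(self, x):
        return self.activation(self.weight @ x + self.bias)

    def tree_flatten(self):
        # Every field is exposed as a child; no auxiliary data is kept.
        return (self.weight, self.bias, self.activation), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


# Idea two (sketch): filter the PyTree before transforming.
def is_array(x):
    return isinstance(x, jax.Array)


def partition(tree):
    # Array leaves go into `params`; everything else goes into `static`.
    # The other tree holds None at the corresponding position.
    params = tree_map(lambda x: x if is_array(x) else None, tree)
    static = tree_map(lambda x: None if is_array(x) else x, tree)
    return params, static


def combine(params, static):
    # Treat None as a leaf so that both trees line up position by position.
    return tree_map(
        lambda p, s: s if p is None else p,
        params,
        static,
        is_leaf=lambda x: x is None,
    )


def filtered_grad(model, x, y):
    # Differentiate a loss with respect to the array leaves of `model` only.
    params, static = partition(model)

    def loss(params):
        rebuilt = combine(params, static)
        return jnp.mean((rebuilt(x) - y) ** 2)

    return jax.grad(loss)(params)


model = Linear(jnp.ones((2, 3)), jnp.zeros(2), jnp.tanh)
x = jnp.ones(3)
y = jnp.zeros(2)
grads = filtered_grad(model, x, y)  # a `Linear` holding gradients
```

In this sketch `grads` has the same structure as `model`: `grads.weight` and `grads.bias` hold gradients, while `grads.activation` is `None` because the activation function was filtered out before differentiation. No abstraction beyond PyTrees and ordinary JAX transformations is involved, which is the point the abstract makes.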
