Libraries on top of JAX


Marie-Hélène Burle

JAX is an efficient and flexible framework for function transformations (including automatic differentiation) built to run on accelerators. Its goal is not to provide specialized applications, but to focus on this core functionality.
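As a minimal sketch of what "function transformations" means in practice, the snippet below applies three of JAX's core transformations (`jax.grad`, `jax.jit`, and `jax.vmap`) to a toy function. The function `f` and the variable names are made up for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy scalar function: f(x) = x^3 + 2x
    return x**3 + 2.0 * x

# Each transformation takes a function and returns a new function.
df = jax.grad(f)           # derivative of f: 3x^2 + 2
fast_df = jax.jit(df)      # same derivative, compiled with XLA
batched_df = jax.vmap(df)  # same derivative, mapped over a batch of inputs

d_at_2 = fast_df(2.0)      # 3 * 2^2 + 2 = 14
ds = batched_df(jnp.array([0.0, 1.0, 2.0]))  # [2, 5, 14]
```

Because the transformations compose, `jax.jit(jax.vmap(jax.grad(f)))` is also valid.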

While it is possible to use JAX directly in applications (e.g. to build a neural network from scratch), it often makes sense to use specialized libraries that are built on top of JAX, take advantage of its characteristics, and provide functions for specific applications.
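To give a sense of what "from scratch" involves, here is a rough sketch of a small multilayer perceptron written in plain JAX. The layer sizes, the synthetic data, the learning rate, and all function names are arbitrary choices for the example.

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes):
    """Create one (weights, biases) pair per layer."""
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        params.append((w, jnp.zeros(n_out)))
    return params

def forward(params, x):
    # ReLU on hidden layers, linear output layer
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y, lr=0.01):
    # Gradients have the same nested structure as params
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Synthetic regression problem: predict the sum of the inputs
k_params, k_x = jax.random.split(jax.random.PRNGKey(0))
params = init_params(k_params, [3, 16, 1])
x = jax.random.normal(k_x, (32, 3))
y = jnp.sum(x, axis=1, keepdims=True)

loss_before = loss_fn(params, x, y)
for _ in range(100):
    params = sgd_step(params, x, y)
loss_after = loss_fn(params, x, y)
```

Parameter initialization, the forward pass, and the update rule all have to be written by hand here, which is the boilerplate that higher-level libraries take care of.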

Here are a few of the most important ones.

Neural networks

Flax is an NN library initially developed by Google Brain and now by Google DeepMind. It is the deep learning library officially recommended by the JAX developers.

Haiku was the initial library developed by Google DeepMind. While it is still maintained, development has been stopped in favour of Flax.

Equinox is another DL library, which represents models as pytrees. Its syntax is a lot more user-friendly and will feel familiar to PyTorch users, but it has limitations.
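The idea of a "model as a pytree" can be illustrated in plain JAX, without Equinox. The sketch below is not Equinox's API: it registers a hypothetical `Linear` class as a pytree using `jax.tree_util.register_pytree_node_class`, so that the model object itself can be passed to `jax.grad`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Linear:
    """Hypothetical layer whose parameters are the leaves of a pytree."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return x @ self.weight + self.bias

    def tree_flatten(self):
        # Children (leaves) are the arrays; no auxiliary data is needed
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

def loss_fn(model, x, y):
    return jnp.mean((model(x) - y) ** 2)

model = Linear(jnp.zeros((3, 1)), jnp.zeros(1))
x = jnp.ones((4, 3))
y = jnp.ones((4, 1))

# The gradient is itself a Linear instance holding gradient arrays
grads = jax.grad(loss_fn)(model, x, y)
```

Since the model is an ordinary Python object that also behaves like a pytree, it can be called directly, as in PyTorch, while still working with JAX transformations.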

It is worth noting that PyTorch has incorporated some of JAX’s ideas through functorch, a library that has since been merged into PyTorch as torch.func.

Optax is a gradient manipulation and optimization library developed by Google DeepMind.

Probabilistic state space models

Dynamax provides state and parameter estimation for, among others:

  • hidden Markov models,
  • linear Gaussian state space models,
  • nonlinear Gaussian state space models,
  • generalized Gaussian state space models.
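To make "state estimation" concrete, here is a minimal sketch of the forward (filtering) algorithm for a discrete hidden Markov model, written in plain JAX with `jax.lax.scan`. This is not Dynamax's API: the function `hmm_filter` and the toy two-state model are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def hmm_filter(initial_probs, transition_matrix, emission_matrix, observations):
    """Return p(z_t | y_1, ..., y_t) for every t, plus the log-likelihood."""

    def step(carry, y):
        predicted, log_lik = carry
        # Update: weight predicted state probabilities by the
        # likelihood of the current observation, then normalize
        unnormalized = predicted * emission_matrix[:, y]
        norm = jnp.sum(unnormalized)
        filtered = unnormalized / norm
        # Predict: propagate the filtered belief one step forward
        next_predicted = filtered @ transition_matrix
        return (next_predicted, log_lik + jnp.log(norm)), filtered

    init = (initial_probs, jnp.zeros(()))
    (_, log_lik), filtered = jax.lax.scan(step, init, observations)
    return filtered, log_lik

# Toy model: 2 hidden states, 2 possible observation symbols
initial_probs = jnp.array([0.5, 0.5])
transition_matrix = jnp.array([[0.9, 0.1],
                               [0.1, 0.9]])  # rows: current state
emission_matrix = jnp.array([[0.8, 0.2],
                             [0.2, 0.8]])    # rows: state, columns: symbol
observations = jnp.array([0, 0, 1, 1, 1])

filtered, log_lik = hmm_filter(
    initial_probs, transition_matrix, emission_matrix, observations
)
```

Each row of `filtered` is a probability distribution over the hidden states given the observations seen so far.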