Libraries built on JAX

Author

Marie-Hélène Burle

JAX is an efficient and flexible framework for array operations and program transformations (including automatic differentiation) built to run on accelerators. Its goal is not to provide specialized applications, but to focus on these core tasks.
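As a minimal illustration of these program transformations, here is a sketch using only core JAX (the function `f` is an arbitrary example). `jax.grad`, `jax.jit` and `jax.vmap` each take a Python function and return a new, transformed function, and they can be composed:

```python
import jax
import jax.numpy as jnp

def f(x):
    # An ordinary Python function on arrays: sum of sin(x)^2
    return jnp.sum(jnp.sin(x) ** 2)

df = jax.grad(f)          # automatic differentiation: returns df/dx
f_fast = jax.jit(f)       # just-in-time compilation with XLA
f_batched = jax.vmap(f)   # vectorization over a leading batch axis

# Transformations compose: compiled, batched gradients
batched_grads = jax.jit(jax.vmap(jax.grad(f)))

x = jnp.array([0.0, 1.0, 2.0])
batch = jnp.stack([x, 2 * x])

print(f_fast(x))             # same value as f(x)
print(df(x))                 # analytically 2 sin(x) cos(x) = sin(2x)
print(f_batched(batch))      # one value per row
print(batched_grads(batch))  # one gradient per row
```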

While it is possible to use JAX directly in applications (e.g. to build an NN from scratch), it usually makes more sense to use libraries built on top of JAX that take advantage of its characteristics and provide convenience functions for specialized tasks.
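To give an idea of what building a model "from scratch" involves, here is a minimal sketch in plain JAX of a single dense layer (equivalent to linear regression) trained with gradient descent. The toy data, learning rate and number of steps are arbitrary assumptions, and everything a library would normally provide (parameter initialization, layers, optimizer) is written by hand:

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # arbitrary choice for this toy problem

# Hypothetical toy data: y = 3x - 1 plus a little noise
k_x, k_noise, k_w = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (100, 1))
y = 3.0 * x - 1.0 + 0.1 * jax.random.normal(k_noise, (100, 1))

# The parameters are a pytree (here a dict of arrays)
params = {"w": jax.random.normal(k_w, (1, 1)), "b": jnp.zeros((1,))}

def predict(params, x):
    # A single dense layer without activation
    return x @ params["w"] + params["b"]

def loss(params, x, y):
    # Mean squared error
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    grads = jax.grad(loss)(params, x, y)  # same pytree structure as params
    # Plain gradient descent applied to every leaf of the pytree
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

for _ in range(200):
    params = sgd_step(params, x, y)

print(params)  # w close to 3, b close to -1
```

Even for this tiny model, parameter handling and the update rule have to be written manually: this is the kind of boilerplate that the libraries below take care of.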

The list of libraries built on JAX keeps growing, but here are a few of the most important ones at the moment.

The whole JAX ecosystem is under active development, so you might want to refer to the JAX website and the Awesome JAX project for up-to-date lists.

Neural networks

Flax is an NN library initially developed by Google Brain and now by Google DeepMind. It is the deep learning library officially recommended by the JAX developers. This is the library that we will use in this course.

Equinox is another DL library, which represents models as pytrees. While its syntax is more user-friendly and will feel familiar to PyTorch users, it has limitations.

Keras can now use JAX as a backend.

It is worth noting that PyTorch has been incorporating some of JAX’s ideas, first through a separate library, functorch, which has since been merged into PyTorch as torch.func.

Haiku was the initial library developed by Google DeepMind. While it is still maintained, its development has been stopped in favour of Flax, so it is not advisable to start with it unless you are already using it.

Optax is a gradient manipulation and optimization library developed by Google DeepMind.

Bayesian statistics

NumPyro and PyMC are probabilistic programming languages.

BlackJAX is a library of samplers.
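BlackJAX provides ready-made, composable samplers (e.g. HMC and NUTS). To show what a sampler does under the hood, here is a hand-written random-walk Metropolis sampler in plain JAX. It does not use the BlackJAX API, and the target density and step size are arbitrary assumptions:

```python
import jax
import jax.numpy as jnp

STEP_SIZE = 1.0  # proposal scale, an arbitrary choice

def logdensity(x):
    # Hypothetical target: 2D standard normal, up to an additive constant
    return -0.5 * jnp.sum(x ** 2)

def rwm_step(state, key):
    """One random-walk Metropolis step."""
    x, logp = state
    key_prop, key_accept = jax.random.split(key)
    # Propose a move from a Gaussian centred on the current position
    proposal = x + STEP_SIZE * jax.random.normal(key_prop, x.shape)
    logp_prop = logdensity(proposal)
    # Accept with probability min(1, p(proposal) / p(x))
    accept = jnp.log(jax.random.uniform(key_accept)) < logp_prop - logp
    x = jnp.where(accept, proposal, x)
    logp = jnp.where(accept, logp_prop, logp)
    return (x, logp), x

def sample(key, x0, num_samples):
    keys = jax.random.split(key, num_samples)
    # lax.scan runs the chain sequentially inside compiled code
    _, samples = jax.lax.scan(rwm_step, (x0, logdensity(x0)), keys)
    return samples

samples = jax.jit(sample, static_argnums=2)(jax.random.PRNGKey(0), jnp.zeros(2), 10_000)
print(samples.mean(axis=0), samples.var(axis=0))  # roughly 0 and 1
```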

For a basic and high-level introduction, you can have a look at our webinar on Bayesian inference in JAX.

Probabilistic state space models

Dynamax provides state and parameter estimation for, among others:

  • hidden Markov models,
  • linear Gaussian state space models,
  • nonlinear Gaussian state space models,
  • generalized Gaussian state space models.
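To illustrate what state estimation means for the first of these models, here is a minimal sketch of the forward (filtering) algorithm for a hidden Markov model, written in plain JAX with `jax.lax.scan`. It does not use the Dynamax API; the two-state model, its parameters and the observations are hypothetical. It computes the filtered state probabilities p(z_t | y_1:t) and the log-likelihood of the whole sequence:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical 2-state HMM with 1D Gaussian emissions
initial_probs = jnp.array([0.5, 0.5])
transition_matrix = jnp.array([[0.95, 0.05],
                               [0.10, 0.90]])
emission_means = jnp.array([-1.0, 2.0])
emission_std = 0.5

def hmm_filter(emissions):
    """Forward algorithm: returns p(z_t | y_1:t) for each t and log p(y_1:T)."""
    def step(carry, y):
        predicted, log_lik = carry
        # Weight the predicted state probabilities by the emission likelihoods
        lik = norm.pdf(y, emission_means, emission_std)
        unnormalized = predicted * lik
        c = unnormalized.sum()  # p(y_t | y_1:t-1)
        filtered = unnormalized / c
        # Predict the state distribution at the next time step
        next_predicted = filtered @ transition_matrix
        return (next_predicted, log_lik + jnp.log(c)), filtered

    init = (initial_probs, jnp.array(0.0))
    (_, log_lik), filtered = jax.lax.scan(step, init, emissions)
    return filtered, log_lik

emissions = jnp.array([-1.1, -0.8, 2.1, 1.9, 2.2, -1.0])
filtered, log_lik = jax.jit(hmm_filter)(emissions)
print(jnp.argmax(filtered, axis=1))  # most probable state at each step
```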