An introduction to JAX

High-performance array computing and differentiation with JAX

JAX is an open source Python library for high-performance array computing and flexible automatic differentiation.

JAX achieves high performance through asynchronous dispatch, just-in-time (JIT) compilation with the XLA (Accelerated Linear Algebra) compiler, and full compatibility with accelerators (GPUs and TPUs).
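As a rough illustration of just-in-time compilation and asynchronous dispatch, here is a minimal sketch. The function name scaled_tanh and the input values are hypothetical, chosen only for this example; the JAX calls used are jax.jit, jax.numpy and block_until_ready.

```python
import jax
import jax.numpy as jnp


def scaled_tanh(x):
    # Element-wise array math written with the NumPy-like jax.numpy API.
    return 2.0 * jnp.tanh(x) + 1.0


# jax.jit traces the function and compiles it with XLA on the first call;
# later calls with the same input shape and dtype reuse the compiled code.
fast_scaled_tanh = jax.jit(scaled_tanh)

x = jnp.arange(5.0)
y = fast_scaled_tanh(x)

# Dispatch is asynchronous: the call above may return before the result
# is computed. block_until_ready() waits for the computation to finish,
# which matters mostly when timing code.
y.block_until_ready()
print(y)
```

The same code runs unchanged on CPU, GPU or TPU; JAX uses whichever backend is available.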

Automatic differentiation builds on Autograd and handles complex control flow (conditionals, recursion), second- and third-order derivatives, and both forward and reverse modes. This makes JAX well suited to machine learning, and neural network libraries such as Flax are built on it.
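The sketch below shows what this can look like in practice. The function f and the evaluation point are hypothetical and exist only for illustration. It differentiates through an ordinary Python conditional, stacks jax.grad to get higher-order derivatives, and uses jax.jvp for a forward-mode derivative.

```python
import jax


def f(x):
    # Ordinary Python control flow: the branch depends on the value of x.
    if x > 0:
        return x ** 3
    return -x


# jax.grad uses reverse mode; applying it repeatedly gives
# higher-order derivatives.
df = jax.grad(f)     # first derivative:  3 * x**2 for x > 0
d2f = jax.grad(df)   # second derivative: 6 * x    for x > 0
d3f = jax.grad(d2f)  # third derivative:  6        for x > 0

x = 2.0
first, second, third = df(x), d2f(x), d3f(x)

# Forward mode: jax.jvp pushes a tangent vector (here 1.0) through f and
# returns both f(x) and the directional derivative at x.
primal_out, tangent_out = jax.jvp(f, (x,), (1.0,))

print(first, second, third, primal_out, tangent_out)
```

Note that this kind of value-dependent Python branching works under jax.grad, whereas combining it with jax.jit generally requires structured primitives such as jax.lax.cond.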

