JAX: a framework for high-performance array computing and differentiation

An introductory course on JAX

JAX is an open source Python library for high-performance array computing and flexible automatic differentiation.
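As a first taste (a minimal sketch, not part of the course material itself), JAX exposes a NumPy-like API through jax.numpy:

    import jax.numpy as jnp

    # Create and manipulate arrays much like with NumPy
    x = jnp.arange(10.0)
    y = jnp.sin(x) ** 2 + jnp.cos(x) ** 2   # all ones, up to floating-point error
    print(y)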

High performance is achieved through asynchronous dispatch, just-in-time compilation backed by the XLA (Accelerated Linear Algebra) compiler, and full support for accelerators (GPUs and TPUs).
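For example, a function can be compiled with jax.jit and runs unchanged on CPU, GPU, or TPU. The sketch below assumes a standard JAX installation:

    import jax
    import jax.numpy as jnp

    @jax.jit                        # compiled by XLA on first call
    def normalize(x):
        return (x - x.mean()) / x.std()

    x = jnp.arange(1_000_000.0)
    y = normalize(x)                # dispatched asynchronously
    y.block_until_ready()           # wait for the computation to finish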

Automatic differentiation is based on Autograd and handles complex control flow (conditionals, recursion), second- and third-order derivatives, and both forward and reverse modes. This makes JAX well suited to machine learning, and neural network libraries such as Flax are built on top of it.
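For instance, jax.grad differentiates plain Python functions, including ones with control flow, and can be composed with itself for higher-order derivatives (a minimal sketch; f is a made-up example function):

    import jax

    def f(x):
        # Python control flow on the input value works under grad
        return x ** 3 if x > 0 else -x ** 3

    df = jax.grad(f)               # reverse-mode first derivative
    d2f = jax.grad(jax.grad(f))    # second derivative by composing grad
    print(df(2.0), d2f(2.0))       # 12.0 and 12.0
    # forward mode is available via jax.jvp / jax.jacfwd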

This course will teach you the basics of JAX and is a prerequisite for the machine learning course on Flax that follows it.

You do not need to install anything on your machine for this course as we will provide access to a temporary remote cluster.


Start course ➤