Accelerated array computing and flexible differentiation with JAX


Marie-Hélène Burle

JAX is an open-source Python library for high-performance array computing and flexible automatic differentiation.

JAX achieves high performance through asynchronous dispatch, just-in-time (JIT) compilation, the XLA (Accelerated Linear Algebra) compiler, and full compatibility with accelerators (GPUs and TPUs).
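As a minimal sketch of just-in-time compilation and asynchronous dispatch, the snippet below compiles a small elementwise function with `jax.jit`. The function `selu` and its constants are illustrative choices, not part of the webinar material. `jax.jit` traces the function for the given input shape and dtype and hands it to XLA for compilation. Because of asynchronous dispatch, the call returns before the computation has necessarily finished, so `block_until_ready()` is used to wait for the result (important when timing code).

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # Illustrative scaled exponential linear unit, applied elementwise.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# jax.jit traces selu once per input shape/dtype and compiles it with XLA.
selu_jit = jax.jit(selu)

x = jnp.arange(5.0)

# The call is dispatched asynchronously; block_until_ready() waits for the result.
y = selu_jit(x).block_until_ready()
print(y)
```

The compiled and uncompiled versions return the same values; the benefit of `jax.jit` shows up on repeated calls with larger arrays, where the compiled kernel is reused.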

Automatic differentiation uses an updated version of Autograd and works with complex control flow (conditionals, recursion), higher-order derivatives (second, third order and beyond), and both forward and reverse modes. This makes JAX ideal for machine learning, and neural network libraries such as Flax are built on it.
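The following is a small sketch of these autodiff features; the functions `power`, `f`, and `g` are hypothetical examples chosen for illustration. `jax.grad` (reverse mode) differentiates straight through a Python `if` and a recursive call, since `jax.grad` without `jax.jit` evaluates the branch condition on concrete values. Composing `jax.grad` gives higher-order derivatives, and `jax.jacfwd` (forward mode) and `jax.jacrev` (reverse mode) compute the same Jacobian by different strategies.

```python
import jax
import jax.numpy as jnp

def power(x, n):
    # Recursive x**n; n is a plain Python int, so the recursion unrolls at trace time.
    if n == 0:
        return 1.0
    return x * power(x, n - 1)

def f(x):
    # Python branch on the value of x (works under jax.grad, not under jax.jit).
    if x > 0:
        return power(x, 3)  # x**3
    return -x

# Reverse-mode derivative through the branch and the recursion.
df = jax.grad(f)
print(df(2.0))   # 3 * 2**2 = 12.0
print(df(-1.0))  # -1.0

# Higher-order derivatives by composing jax.grad.
d2f = jax.grad(df)
d3f = jax.grad(d2f)
print(d2f(2.0))  # 6 * 2 = 12.0
print(d3f(2.0))  # 6.0

def g(v):
    # A small vector-to-vector function R^2 -> R^2.
    return jnp.array([v[0] * v[1], jnp.sin(v[0])])

v = jnp.array([1.0, 2.0])
J_fwd = jax.jacfwd(g)(v)  # forward mode: one pass per input
J_rev = jax.jacrev(g)(v)  # reverse mode: one pass per output
```

Forward mode tends to be cheaper when a function has few inputs and many outputs, and reverse mode when it has many inputs and few outputs (the usual case for a scalar loss in machine learning).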

This webinar will give an overview of JAX’s principles and how it works.

Slides (Click and wait: this reveal.js presentation may take a little time to load.)