How does JAX work?

Author

Marie-Hélène Burle

Before using JAX, it is critical to understand how it works: JAX's architecture is at the core of its efficiency and flexibility, but it is also the cause of a number of constraints.

Map

Here is a schematic of how JAX works:

[Diagram: Pure Python functions -> Tracing -> Jaxpr (JAX expression) intermediate representation (IR) -> Just-in-time (JIT) compilation -> High-level optimized (HLO) program -> Accelerated Linear Algebra (XLA) -> CPU, GPU, or TPU. Transformations are applied to the Jaxpr.]

Tracing

Tracing happens during the first call of a function that has been passed through a JAX transformation such as jax.jit(). Tracer objects are wrapped around each argument and record all operations performed on them, creating a Jaxpr (JAX expression). It is this intermediate representation, rather than the Python code, that JAX then uses.
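To make the idea of a Jaxpr concrete, here is a minimal sketch that uses jax.make_jaxpr() to print the intermediate representation of a small function. The function f and the input x are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function: sum of squares.
def f(x):
    return jnp.sum(x ** 2)

x = jnp.arange(3.0)

# make_jaxpr traces f with an abstract stand-in for x and returns
# the recorded operations instead of running the computation.
jaxpr = jax.make_jaxpr(f)(x)
print(jaxpr)

# The traced input only carries a shape and a dtype, not values.
print(jaxpr.in_avals)
```

The exact text printed depends on the JAX version, but it should list the primitive operations recorded during tracing along with the shape and dtype of the input.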

The tracer objects used to create the Jaxpr contain information about the shape and dtype of the initial Python arguments, but not their values. This means that new inputs with the same shape and dtype will use the cached compiled program directly, skipping the Python code entirely. Inputs with new shape and/or dtype will trigger tracing again (so the Python function gets executed again).
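The caching behaviour can be observed with a small experiment. The sketch below (the function double and the list trace_log are hypothetical names) deliberately uses a Python side effect to log each time the function body is executed, that is, each time tracing happens.

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled only while the Python body runs (i.e. during tracing)

@jax.jit
def double(x):
    # x is a tracer here: it exposes a shape and a dtype, but no values.
    trace_log.append((x.shape, x.dtype))
    return 2 * x

double(jnp.ones(3))                   # first call: traced and compiled
double(jnp.zeros(3))                  # same shape and dtype: cached program reused
n_after_same = len(trace_log)         # 1

double(jnp.ones(4))                   # new shape: traced again
n_after_shape = len(trace_log)        # 2

double(jnp.ones(3, dtype=jnp.int32))  # new dtype: traced again
n_after_dtype = len(trace_log)        # 3

print(trace_log)
```

Logging from inside a jitted function is done here only to expose when tracing occurs; it is not something to rely on in real code.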

Side effects of a function are not recorded by the tracers, which means that they are not part of the Jaxprs. They are executed once (during tracing), but are thereafter absent from the cached compiled program.
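As an illustration, the following sketch shows a side effect that stores its argument in an external list. The names add_one and seen are made up for this example.

```python
import jax
import jax.numpy as jnp

seen = []  # external state mutated by the function (a side effect)

@jax.jit
def add_one(x):
    seen.append(x)  # runs during tracing only; not part of the Jaxpr
    return x + 1

a = add_one(jnp.array(1.0))
b = add_one(jnp.array(2.0))
c = add_one(jnp.array(3.0))

# Three calls, but the side effect ran a single time...
print(len(seen))

# ...and what it stored is a tracer object, not a concrete value.
print(type(seen[0]))
```

The returned values a, b, and c are computed correctly on every call; only the side effect is lost after tracing.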

Functions which use values outside of their arguments (e.g. values from the global environment) will not be traced again if such values change: the cached compiled program keeps using the value seen during tracing.
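A minimal sketch of this pitfall, using a hypothetical global variable scale, and of the pure alternative in which the value is passed as an argument:

```python
import jax
import jax.numpy as jnp

scale = 2.0  # global value read by the impure function below

@jax.jit
def scaled_impure(x):
    return scale * x  # the value of scale is captured at tracing time

x = jnp.array(1.0)
first = scaled_impure(x)   # traced with scale = 2.0

scale = 10.0               # changing the global does not trigger retracing
second = scaled_impure(x)  # still computed with scale = 2.0

@jax.jit
def scaled_pure(x, scale):
    return scale * x       # scale is now an explicit, traced argument

third = scaled_pure(x, 10.0)

print(first, second, third)  # 2.0 2.0 10.0
```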

For these reasons, only functionally pure functions (functions without side effects and which do not rely on values outside their arguments) should be used with JAX.

Transformations

JAX is essentially a functional programming framework. Transformations are higher-order functions transforming Jaxprs.

Transformations are composable and include:

  • jax.grad(): creates a function that evaluates the gradient of the input function,
  • jax.vmap(): implementation of automatic vectorization,
  • jax.pmap(): implementation of data parallelism across processing units,

and finally, once other necessary transformations have been performed:

  • jax.jit(): just-in-time compilation with XLA.
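The composability of these transformations can be sketched as follows. The function loss and the arrays w and xs are hypothetical and chosen only so that the result is easy to check by hand (the gradient with respect to w is 2 * w * x ** 2).

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar-valued function of a weight vector w and one sample x.
def loss(w, x):
    return jnp.sum((w * x) ** 2)

# 1. Differentiate with respect to the first argument (w).
grad_loss = jax.grad(loss)

# 2. Vectorize over a batch of samples: w is shared, x is mapped over axis 0.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

# 3. Compile the resulting function once the other transformations are applied.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0],
                [0.0, 1.0],
                [1.0, 1.0]])

grads = fast_batched_grad(w, xs)
print(grads.shape)  # (3, 2): one gradient per sample
print(grads)
```

jax.pmap() is left out of this sketch because it needs several devices to be meaningful.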

XLA

The XLA (Accelerated Linear Algebra) compiler takes the HLO programs produced by JIT compilation, optimizes them, and compiles them for the available hardware (CPUs, GPUs, or TPUs).
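The stages between the Python function and the hardware can be inspected by hand. The sketch below relies on the ahead-of-time lowering interface of jitted functions (lower(), as_text(), compile()); the function f is hypothetical. The textual format of the lowered program depends on the JAX version (recent versions emit the StableHLO dialect).

```python
import jax
import jax.numpy as jnp

# Hypothetical example function.
def f(x):
    return jnp.tanh(x) * 2.0

x = jnp.arange(4.0)

# Devices XLA can target on this machine (CPU, GPU, or TPU).
print(jax.devices())

# Trace f for this input signature and lower it to the compiler's input format.
lowered = jax.jit(f).lower(x)
hlo_text = lowered.as_text()
print(hlo_text)

# Hand the lowered program to XLA, which optimizes and compiles it.
compiled = lowered.compile()

# The compiled program can be called with inputs of the same shape and dtype.
result = compiled(x)
print(result)
```

In everyday use, calling the jitted function directly performs these steps automatically; spelling them out is mostly useful for debugging and for seeing what XLA receives.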