How does it work?
Before using JAX, it is important to understand how it works: JAX's architecture is the source of its efficiency and flexibility, but also of a number of constraints.
Map
[Figure: schematic of JAX's functioning]
Tracing
Tracing happens the first time a transformed function (for instance a function wrapped in jax.jit()) is called. Each argument is replaced by a tracer object which records all the operations performed on it, creating a Jaxpr (JAX expression). It is this intermediate representation, rather than the Python code, that JAX then uses.
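To make the idea of a Jaxpr concrete, here is a minimal sketch that uses jax.make_jaxpr() to print the intermediate representation recorded for a small function. The function f and its input are made up for illustration, and the exact text printed may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function, not taken from the text above.
def f(x):
    return jnp.sum(x ** 2) + 1.0

# make_jaxpr traces f with an example input and returns the resulting Jaxpr.
jaxpr = jax.make_jaxpr(f)(jnp.arange(3.0))

# The printout lists primitive operations, not Python source code.
print(jaxpr)
```

The printed Jaxpr describes the input by its shape and dtype only and lists the primitive operations recorded by the tracer.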
The tracer objects used to create the Jaxpr contain information about the shape and dtype of the initial Python arguments, but not their values. This means that new inputs with the same shape and dtype will use the cached compiled program directly, skipping the Python code entirely. Inputs with new shape and/or dtype will trigger tracing again (so the Python function gets executed again).
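The caching behaviour described above can be observed with a small experiment. This is a sketch under the assumption that the Python body of a jitted function only runs while tracing; the function double and the list seen are hypothetical names used for illustration.

```python
import jax
import jax.numpy as jnp

seen = []  # abstract signatures observed while tracing

@jax.jit
def double(x):
    # This line runs only during tracing: x is a tracer which
    # carries a shape and a dtype, but no concrete value.
    seen.append((x.shape, str(x.dtype)))
    return x * 2

double(jnp.ones(3))                    # first call: traced
double(jnp.zeros(3))                   # same shape and dtype: cached program reused
double(jnp.ones(4))                    # new shape: traced again
double(jnp.ones(3, dtype=jnp.int32))   # new dtype: traced again

print(seen)
```

Four calls lead to only three traces: the second call has the same shape and dtype as the first, so it never enters the Python function.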
Function side effects are not recorded by the tracers, which means that they are not part of the Jaxprs. They are executed each time the function is traced, but are absent from the cached compiled program.
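As an illustration, the sketch below jit-compiles a function with a side effect (appending to a Python list) and calls it several times with the same input. The names add_one and calls are hypothetical.

```python
import jax
import jax.numpy as jnp

calls = []  # mutated as a side effect of add_one

@jax.jit
def add_one(x):
    calls.append("called")  # side effect: not recorded in the Jaxpr
    return x + 1

x = jnp.array(1.0)
results = [add_one(x) for _ in range(3)]

# The function was called three times, but the side effect ran only
# during the single tracing pass.
print(len(calls))
```

The returned values are correct on every call, yet the list only grows during tracing: the compiled program contains the addition and nothing else.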
Functions which use values outside of their arguments (e.g. values from the global environment) will not be retraced if such values change: the value read during tracing stays baked into the cached compiled program.
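The following sketch shows the problem with a global variable, along with a pure alternative in which the value is passed as an argument. The names scale, scaled, and scaled_pure are hypothetical.

```python
import jax
import jax.numpy as jnp

scale = 2.0

@jax.jit
def scaled(x):
    # `scale` is read from the enclosing scope once, during tracing.
    return x * scale

x = jnp.array(3.0)
first = scaled(x)    # traced with scale == 2.0

scale = 10.0
second = scaled(x)   # cached program reused: the new value is ignored

@jax.jit
def scaled_pure(x, scale):
    # Pure version: everything the function needs is an argument.
    return x * scale

third = scaled_pure(x, scale)

print(first, second, third)
```

The first two results are identical even though scale changed in between; only the pure version reflects the new value.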
For these reasons, only functionally pure functions (functions without side effects and which do not rely on values outside their arguments) should be used with JAX.
Transformations
JAX is essentially a functional programming framework. Transformations are higher-order functions transforming Jaxprs.
Transformations are composable and include:
jax.grad(): creates a function that evaluates the gradient of the input function,
jax.vmap(): implementation of automatic vectorization,
jax.pmap(): implementation of data parallelism across processing units,

and finally, once other necessary transformations have been performed:

jax.jit(): just-in-time compilation for the XLA.
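Because transformations are composable, they can be stacked on top of one another. The sketch below chains jax.grad(), jax.vmap(), and jax.jit() on a made-up function loss; jax.pmap() is left out since it only makes sense with several devices.

```python
import jax
import jax.numpy as jnp

# Hypothetical function used for illustration only.
def loss(w, x):
    return jnp.sum((w * x) ** 2)

grad_loss = jax.grad(loss)                         # gradient with respect to w
batched = jax.vmap(grad_loss, in_axes=(None, 0))   # same w, one row of xs at a time
fast = jax.jit(batched)                            # compile the composed function

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 1.0],
                [2.0, 0.5]])

out = fast(w, xs)   # one gradient per row of xs
print(out)
```

Each transformation takes a function and returns a new function, which is why the result of one can be fed directly into the next.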
XLA
The XLA (Accelerated Linear Algebra) compiler takes the programs that jax.jit() hands to it, then compiles and optimizes them for the available hardware (CPUs, GPUs, or TPUs).
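The hand-off to XLA can be inspected by running the stages of jax.jit() one at a time. This is a sketch using JAX's ahead-of-time lowering interface; the function f is made up, and the format of the printed representation depends on the JAX version.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function.
def f(x):
    return jnp.tanh(x) * 2.0

x = jnp.arange(4.0)

lowered = jax.jit(f).lower(x)   # trace f and lower it to the input format of XLA
print(lowered.as_text())        # textual form of the program given to XLA

compiled = lowered.compile()    # XLA compiles and optimizes for the default device
print(jax.devices())            # hardware visible to JAX on this machine

y = compiled(x)                 # run the compiled program
print(y)
```

Calling jax.jit(f)(x) directly performs the same steps implicitly and caches the result.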