Pushing optimizations further

Author

Marie-Hélène Burle

JAX feels lower level than other libraries: it imposes more constraints in exchange for more performance. This trade-off can be pushed further for additional speedups, at the cost of additional code complexity.

The lax API

jax.numpy is a high-level NumPy-like API wrapped around jax.lax. jax.lax is a lower-level API, itself wrapped around XLA, that can be more efficient. It is more powerful, but even stricter, and it often requires many more lines of code.

Pallas: extension to write GPU and TPU kernels

Following the success of Triton, the JAX team built the Pallas extension, which lets JAX users write their own GPU kernels.

It also allows writing kernels for the TPU with Mosaic.
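To give a feel for what a Pallas kernel looks like, here is a minimal sketch of an element-wise addition kernel. The names `add_kernel` and `add_vectors` are illustrative. The sketch assumes a JAX version in which `jax.experimental.pallas` is available, and it passes `interpret=True` so that the kernel runs in interpreter mode without a GPU or TPU. On an accelerator, that argument would normally be dropped.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Kernels operate on references to blocks of memory:
    # read the inputs, compute, then write into the output reference.
    o_ref[...] = x_ref[...] + y_ref[...]


def add_vectors(x, y):
    # pallas_call turns the kernel into a function callable from JAX.
    # interpret=True is an assumption made so this sketch runs on CPU.
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
result = add_vectors(x, y)
print(result)
```

Unlike regular JAX functions, the kernel does not return a value: it writes its result into the output reference, which is closer to how GPU and TPU kernels are written.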

[Diagram: pure Python functions are traced into jaxprs (JAX expressions), an intermediate representation (IR). Jaxprs feed the transformations (vectorization, parallelization, differentiation) and just-in-time (JIT) compilation, which produces a high-level optimized (HLO) program. The HLO program is then passed to Triton for the GPU or to Mosaic for the TPU.]
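The first stages of this pipeline can be inspected from Python. The sketch below uses a toy function `f` (an illustrative name) to print the jaxpr produced by tracing and the text of the program that `jax.jit` lowers it to. Depending on the JAX version, the lowered text may be in the StableHLO dialect rather than classic HLO.

```python
import jax
import jax.numpy as jnp


def f(x):
    # A pure Python function built from JAX operations.
    return jnp.sin(x) ** 2


x = jnp.arange(3.0)

# Tracing: turn the Python function into a jaxpr (the intermediate representation).
jaxpr = jax.make_jaxpr(f)(x)
print(jaxpr)

# JIT: lower the traced function to the program handed over to XLA.
lowered_text = jax.jit(f).lower(x).as_text()
print(lowered_text)
```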