Pushing optimizations further
JAX already feels lower level than other deep learning libraries: it imposes more constraints in exchange for more performance. You can go further down the stack for additional speedups, at the cost of more complex code.
The lax API
jax.numpy is a high-level, NumPy-like API built on top of jax.lax. jax.lax is a lower-level API that wraps XLA more directly. It gives finer control, but it is even stricter (for example, no implicit type promotion) and usually requires many more lines of code.
Pallas: extension to write GPU and TPU kernels
Following the success of Triton, JAX added Pallas, an extension that lets JAX users write their own GPU kernels.
It can also be used to write kernels for TPUs, where they are lowered through Mosaic.
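A minimal sketch of a Pallas kernel, assuming a recent JAX where Pallas lives in the experimental module jax.experimental.pallas (its API may still change between releases). The kernel adds two vectors block by block: the grid launches one program instance per block, and each BlockSpec tells Pallas which slice of the arrays an instance sees through its refs. BLOCK is a hypothetical block size, and interpret mode is used on CPU so the sketch runs without a GPU or TPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# On CPU, run the kernel through the Pallas interpreter; on GPU/TPU it is compiled.
INTERPRET = jax.default_backend() == "cpu"
BLOCK = 128  # hypothetical block size, chosen for illustration

def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Each program instance reads one BLOCK-sized slice of x and y
    # and writes the matching slice of the output.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add_vectors(x, y):
    n = x.shape[0]
    assert n % BLOCK == 0, "this sketch assumes n is a multiple of BLOCK"
    # Program i works on elements [i * BLOCK, (i + 1) * BLOCK).
    spec = pl.BlockSpec(block_shape=(BLOCK,), index_map=lambda i: (i,))
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(n // BLOCK,),
        in_specs=[spec, spec],
        out_specs=spec,
        interpret=INTERPRET,
    )(x, y)

x = jnp.arange(1024, dtype=jnp.float32)
y = jnp.ones(1024, dtype=jnp.float32)
out = add_vectors(x, y)
```

For a plain elementwise add, XLA already generates a good kernel, so this only illustrates the programming model. Pallas pays off for fused or memory-bound operations that XLA does not schedule well.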