The JAX AI stack
Content from the course slides for easier browsing.
What is JAX?
JAX is a high-performance accelerator-oriented array computing library for Python developed by Google. It allows composition, JIT-compilation, transformation, and automatic differentiation of numerical programs.
It provides NumPy-like and lower-level APIs.
It also requires a strict functional programming style (pure functions operating on immutable arrays).
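For instance, a minimal sketch of the NumPy-like API and of the pure, immutable array model (the arrays below are arbitrary examples):

import jax.numpy as jnp

x = jnp.arange(3.0)           # NumPy-like construction: [0., 1., 2.]
y = jnp.sin(x) + x ** 2       # familiar NumPy-style operations

# JAX arrays are immutable: an in-place update like x[0] = 10.0 raises an error;
# functional updates return a new array instead.
x2 = x.at[0].set(10.0)

print(y)
print(x2)                     # [10.  1.  2.]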
Why JAX?
Fast
Default data type suited for deep learning
Like PyTorch, JAX uses float32 as its default data type (by contrast, NumPy defaults to float64). This level of precision is suitable for deep learning and increases efficiency; see the sketch below.
The same code can run on CPUs or on accelerators (GPUs and TPUs)
XLA (Accelerated Linear Algebra) optimization
Asynchronous dispatch
Vectorization, data parallelism, and sharding
All levels of shared and distributed memory parallelism are supported.
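A minimal sketch illustrating a few of these points (the default float32 dtype, jit compilation through XLA, asynchronous dispatch, and vectorization with vmap); the computation itself is an arbitrary example:

import jax
import jax.numpy as jnp

x = jnp.ones((1000, 1000))
print(x.dtype)                      # float32 by default (NumPy would give float64)

# jit compiles the function with XLA on its first call.
compute = jax.jit(lambda a: jnp.tanh(a @ a))

# Dispatch is asynchronous: compute(x) returns before the result is ready;
# block_until_ready() waits for it (important when timing code).
y = compute(x).block_until_ready()

# vmap vectorizes a per-example function over a leading batch dimension.
row_norms = jax.vmap(jnp.linalg.norm)
print(row_norms(jnp.arange(6.0).reshape(2, 3)))   # one norm per row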
Great AD
(summarized from a blog post by Chris Rackauckas)
Close to the math
Consider the function f:

f = lambda x: x**3 + 2*x**2 - 3*x + 8

We can create a new function dfdx that computes the gradient of f w.r.t. x:

from jax import grad
dfdx = grad(f)

Calling dfdx returns the derivative of f at the given point:

print(dfdx(1.))
4.0
Forward and reverse modes
- reverse-mode vector-Jacobian products: jax.vjp
- forward-mode Jacobian-vector products: jax.jvp
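Both can be tried directly; a minimal sketch on the function f from above (redefined here so the snippet is self-contained), where for a scalar function both products reduce to the ordinary derivative:

import jax

f = lambda x: x**3 + 2*x**2 - 3*x + 8

# Forward mode: Jacobian-vector product (directional derivative).
y, tangent_out = jax.jvp(f, (1.0,), (1.0,))
print(y, tangent_out)       # 8.0 4.0

# Reverse mode: vector-Jacobian product.
y, f_vjp = jax.vjp(f, 1.0)
(dfdx_val,) = f_vjp(1.0)
print(dfdx_val)             # 4.0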
Higher-order differentiation
With a single variable, the grad function calls can be nested:
d2fdx = grad(dfdx) # function to compute 2nd order derivatives
d3fdx = grad(d2fdx) # function to compute 3rd order derivatives
...

With several variables, you have to use the functions:
- jax.jacfwd for forward mode,
- jax.jacrev for reverse mode.
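A minimal sketch with functions of several variables (g and h below are arbitrary examples):

import jax
import jax.numpy as jnp

def g(x):
    # R^2 -> R^2, so the Jacobian is a 2x2 matrix
    return jnp.stack([x[0] ** 2 * x[1], 5 * x[0] + jnp.sin(x[1])])

x = jnp.array([1.0, 2.0])
print(jax.jacfwd(g)(x))   # forward-mode Jacobian
print(jax.jacrev(g)(x))   # reverse-mode Jacobian (same matrix)

# Composing them gives higher-order derivatives, e.g. the Hessian of a scalar function.
h = lambda v: jnp.sum(v ** 3)
print(jax.jacfwd(jax.jacrev(h))(x))   # Hessian of h at x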
How does it work?
JAX for AI
Not itself a DL library
A sublanguage ideal for DL
The JAX AI stack
A modular approach:
