What is JAX?
Library for Python developed by Google.
Key data structure: Array.
Composition, transformation, and differentiation of numerical programs.
Compilation for CPUs, GPUs, and TPUs.
NumPy-like and lower-level APIs.
Requires strict functional programming.
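In practice, this means JAX arrays are immutable and the functions you transform must be pure (no side effects). A minimal illustration of the functional update syntax (the array values are arbitrary):
import jax.numpy as jnp

x = jnp.zeros(3)        # JAX arrays are immutable
y = x.at[0].set(1.0)    # functional update: returns a new array
print(x)                # [0. 0. 0.] (unchanged)
print(y)                # [1. 0. 0.]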
Why JAX?
Fast
Default data type suited for deep learning.
Like PyTorch, JAX uses float32 by default. This level of precision is suitable for deep learning and increases efficiency (by contrast, NumPy defaults to float64).
JIT compilation (see the sketch after this list).
The same code can run on CPUs or on accelerators (GPUs and TPUs).
XLA (Accelerated Linear Algebra) optimization.
Asynchronous dispatch.
Vectorization, data parallelism, and sharding.
All levels of shared and distributed memory parallelism are supported.
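A minimal sketch of some of these points (float32 by default, jax.jit compilation, jax.vmap vectorization); the function, values, and shapes are purely illustrative:
import jax
import jax.numpy as jnp

x = jnp.arange(3.0)
print(x.dtype)                             # float32 (NumPy would give float64)

def f(v):
    return 2.0 * jnp.dot(v, v)

f_jit = jax.jit(f)                         # compiled by XLA on first call
f_batched = jax.vmap(f)                    # vectorized over a leading batch axis

print(f_jit(x))                            # 10.0
print(f_batched(jnp.ones((8, 3))).shape)   # (8,)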
Great AD
Evolution of automatic differentiation (AD) approaches across frameworks:
Autodiff method        | Framework    | Advantage                          | Disadvantage
Static graph and XLA   | TensorFlow   | Mostly optimized AD                | Manual writing of IR
Dynamic graph          | PyTorch      | Convenient                         | Limited AD optimization
Dynamic graph and XLA  | TensorFlow 2 | Convenient                         | Disappointing speed
Pseudo-dynamic and XLA | JAX          | Convenient and mostly optimized AD | Pure functions
Close to the math
Consider the function f:
f = lambda x: x**3 + 2 * x**2 - 3 * x + 8
We can create a new function dfdx that computes the gradient of f w.r.t. x:
from jax import grad
dfdx = grad(f)
dfdx evaluates the derivative at a point; for example, dfdx(1.0) returns:
4.0
Forward and reverse modes
reverse-mode vector-Jacobian products: jax.vjp
forward-mode Jacobian-vector products: jax.jvp
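A minimal sketch of both primitives on the scalar function f defined above (the evaluation point and tangent/cotangent values are illustrative):
import jax

f = lambda x: x**3 + 2 * x**2 - 3 * x + 8

# Forward mode: push a tangent through f
primal, tangent = jax.jvp(f, (1.0,), (1.0,))
print(primal, tangent)     # 8.0 4.0 (f(1) and f'(1))

# Reverse mode: get a function that pulls a cotangent back through f
primal, f_vjp = jax.vjp(f, 1.0)
(grad_x,) = f_vjp(1.0)
print(primal, grad_x)      # 8.0 4.0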
Higher-order differentiation
With a single variable, the grad function calls can be nested:
d2fdx = grad(dfdx) # function to compute 2nd order derivatives
d3fdx = grad(d2fdx) # function to compute 3rd order derivatives
...
With several variables, you have to use the functions:
jax.jacfwd for forward-mode,
jax.jacrev for reverse-mode.
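A minimal sketch with a function of two variables (the function g and the evaluation point are illustrative); composing the two modes gives a Hessian:
import jax
import jax.numpy as jnp

def g(x):                                # g: R^2 -> R
    return x[0] ** 2 + 3.0 * x[0] * x[1]

x = jnp.array([1.0, 2.0])
print(jax.jacfwd(g)(x))                  # [8. 3.] (forward mode)
print(jax.jacrev(g)(x))                  # [8. 3.] (reverse mode)
print(jax.jacfwd(jax.jacrev(g))(x))      # [[2. 3.] [3. 0.]] (Hessian)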
How does it work?
Pure Python functions are traced into jaxprs (JAX expressions), JAX's intermediate representation (IR).
Transformations operate on the jaxprs: jax.grad (differentiation), jax.vmap (vectorization), jax.pmap (parallelization).
Just-in-time (JIT) compilation with jax.jit turns the jaxprs into a high-level optimized (HLO) program, which XLA (Accelerated Linear Algebra) compiles for CPUs, GPUs, or TPUs.
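You can inspect the jaxpr that tracing produces with jax.make_jaxpr; a minimal example using the function f from above:
import jax

f = lambda x: x**3 + 2 * x**2 - 3 * x + 8
print(jax.make_jaxpr(f)(1.0))   # prints the jaxpr traced for a scalar input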
Not a deep learning library
JAX provides the numerical core on which deep learning, optimizer, probabilistic programming, and probabilistic modeling libraries are built, and it is used for LLMs, solvers, and physics simulations.
Ideal for DL
JAX is a Python sublanguage ideal for deep learning.
JAX for deep learning
Deep learning libraries
Deep learning libraries built on JAX: Flax, Equinox, Keras.
Optimizer libraries built on JAX: Optax, Optimix.
This course
This course uses Flax for deep learning and Optax for optimization.
Modular approach
Data loaders
Load data: torchdata, tfds, or datasets.
Core deep learning library
Process data: grain or torchvision.
Define the architecture: flax.
Optimizer and loss functions
Hyperparameters (optimizer and loss function): optax.
Train
Train: flax and jax.
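A minimal sketch of how these pieces fit together on a toy regression task (the model, shapes, and hyperparameters are illustrative, not part of the course material):
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class MLP(nn.Module):                                   # flax: architecture
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(32)(x))
        return nn.Dense(1)(x)

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
tx = optax.adam(1e-3)                                   # optax: optimizer
opt_state = tx.init(params)

@jax.jit                                                # jax: compilation
def train_step(params, opt_state, x, y):
    def loss_fn(p):
        pred = model.apply(p, x)
        return jnp.mean((pred - y) ** 2)                # loss function
    loss, grads = jax.value_and_grad(loss_fn)(params)   # jax: differentiation
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# One training step on dummy data
params, opt_state, loss = train_step(params, opt_state, jnp.ones((8, 4)), jnp.zeros((8, 1)))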
Checkpointing
Checkpoint: orbax.
Transfer learning
Pre-trained model (fed into the architecture definition): transformers.
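As an illustration (the checkpoint name is an example and the call downloads weights from the Hugging Face Hub), transformers can load pre-trained models directly as Flax/JAX models:
from transformers import FlaxAutoModel

model = FlaxAutoModel.from_pretrained("bert-base-uncased")   # pre-trained weights
params = model.params                                        # a JAX pytree of parameters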
Installation
Installing JAX
Hardware    | Linux x86_64 | Linux aarch64 | Mac x86_64   | Mac aarch64  | Windows x86_64 | Windows WSL2
CPU         | yes          | yes           | yes          | yes          | yes            | yes
NVIDIA GPU  | yes          | yes           | no           | n/a          | no             | experimental
Google TPU  | yes          | n/a           | n/a          | n/a          | n/a            | n/a
AMD GPU     | yes          | no            | experimental | n/a          | no             | no
Apple GPU   | n/a          | no            | n/a          | experimental | n/a            | n/a
Intel GPU   | experimental | n/a           | n/a          | n/a          | no             | no
If you install packages which depend on JAX (e.g. Flax), they will by default install the CPU version of JAX. If you want to run JAX on GPUs, make sure to first install jax[cuda12].
You can install the CPU version on your machine for prototyping and use a GPU version on the clusters (pre-built wheels are available).
Complementary libraries
The modular approach has the downside that several libraries are required and conflicts between dependencies can be a problem.
The meta-library jax-ai-stack makes this easier to manage (install jax[cuda12] first for GPU).
Note that for now TensorFlow and packages which depend on it (e.g. TFDS, grain) are still stuck at Python 3.12, so you can’t use a newer Python version if you want to use some of them.
On your machine (and your machine only), a great tool to manage Python versions and packages is uv (see our webinar). On the clusters, you have to use module to load the Python version you want and pip to install packages.