A brief intro to JAX and how to use it for deep learning

Content from the intro slides for easier browsing.

What is JAX?

  • Library for Python developed by Google.
  • Key data structure: Array.
  • Composition, transformation, and differentiation of numerical programs.
  • Compilation for CPUs, GPUs, and TPUs.
  • NumPy-like and lower-level APIs.
  • Requires strict functional programming (illustrated in the sketch after this list).
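
For instance, JAX arrays are immutable, so updates are written functionally rather than in place (a minimal sketch, not from the slides):

import jax.numpy as jnp

x = jnp.zeros(3)
y = x.at[0].set(1.0)   # returns a new array; x itself is unchanged
print(x, y)            # [0. 0. 0.] [1. 0. 0.]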

Why JAX?

Fast

  • Default data type suited for deep learning.

    Like PyTorch, JAX uses float32 as its default. This level of precision is suitable for deep learning and increases efficiency (by contrast, NumPy defaults to float64).

  • JIT compilation (see the sketch after this list).

  • The same code can run on CPUs or on accelerators (GPUs and TPUs).

  • XLA (Accelerated Linear Algebra) optimization.

  • Asynchronous dispatch.

  • Vectorization, data parallelism, and sharding.

    All levels of shared and distributed memory parallelism are supported.
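
As a minimal sketch of the JIT and vectorization points above (plain JAX, with made-up shapes), jax.jit compiles a function with XLA and jax.vmap vectorizes it over a batch dimension:

import jax
import jax.numpy as jnp

def predict(w, x):
    # toy computation on a single example
    return jnp.tanh(w @ x)

w = jnp.ones((4, 3))
x_batch = jnp.ones((8, 3))                       # batch of 8 examples

# vectorize over the batch axis of x, then JIT-compile the result
batched = jax.jit(jax.vmap(predict, in_axes=(None, 0)))
out = batched(w, x_batch)                        # shape (8, 4)
out.block_until_ready()                          # wait for the asynchronous dispatch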

Great AD

Autodiff method          Framework     Advantage                           Disadvantage
Static graph and XLA     TensorFlow    Mostly optimized AD                 Manual writing of IR
Dynamic graph            PyTorch       Convenient                          Limited AD optimization
Dynamic graph and XLA    TensorFlow 2  Convenient                          Disappointing speed
Pseudo-dynamic and XLA   JAX           Convenient and mostly optimized AD  Pure functions

Summarized from a blog post by Chris Rackauckas

Close to the math

Considering the function f:

f = lambda x: x**3 + 2*x**2 - 3*x + 8

We can create a new function dfdx that computes the gradient of f w.r.t. x:

from jax import grad

dfdx = grad(f)

dfdx returns the derivative of f evaluated at the point passed to it:

print(dfdx(1.))
4.0

Forward and reverse modes

  • reverse-mode vector-Jacobian products: jax.vjp
  • forward-mode Jacobian-vector products: jax.jvp
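
A minimal sketch of both modes, reusing the function f from above (the tangent and cotangent values here are illustrative):

import jax

f = lambda x: x**3 + 2*x**2 - 3*x + 8   # same f as above

# forward mode: push a tangent through f, getting f(x) and a directional derivative
y, tangent_out = jax.jvp(f, (1.0,), (1.0,))

# reverse mode: get f(x) and a function that pulls cotangents back through f
y, f_vjp = jax.vjp(f, 1.0)
(grad_x,) = f_vjp(1.0)

print(tangent_out, grad_x)   # both equal f'(1) = 4.0 for this scalar function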

Higher-order differentiation

With a single variable, the grad function calls can be nested:

d2fdx = grad(dfdx)   # function to compute 2nd order derivatives
d3fdx = grad(d2fdx)  # function to compute 3rd order derivatives
...

With several variables, use the Jacobian functions (see the sketch after this list):

  • jax.jacfwd for forward-mode,
  • jax.jacrev for reverse-mode.
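
A minimal sketch (the function g and the Hessian composition are illustrative, not from the slides):

import jax
import jax.numpy as jnp

def g(x):
    # vector-valued function of a vector
    return jnp.stack([x[0]**2 + x[1], jnp.sin(x[1])])

x = jnp.array([1.0, 2.0])

J_fwd = jax.jacfwd(g)(x)            # forward-mode Jacobian, shape (2, 2)
J_rev = jax.jacrev(g)(x)            # reverse-mode Jacobian, same values

# Hessian of a scalar function by composing the two modes
h = lambda x: jnp.sum(x**3)
H = jax.jacfwd(jax.jacrev(h))(x)    # shape (2, 2)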

How does it work?

[Diagram] Pure Python functions are traced into jaxprs (JAX expressions), JAX's intermediate representation (IR). From a jaxpr, the transformations (vectorization with jax.vmap, parallelization with jax.pmap, differentiation with jax.grad) derive new functions, while just-in-time (JIT) compilation with jax.jit produces a high-level optimized (HLO) program that XLA (Accelerated Linear Algebra) compiles for CPU, GPU, or TPU.
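
You can see the first and last steps of this pipeline directly (a minimal sketch):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2 + x

# tracing: print the jaxpr (intermediate representation) of f
print(jax.make_jaxpr(f)(1.0))

# JIT compilation: the first call traces f and compiles it with XLA
f_jit = jax.jit(f)
print(f_jit(1.0))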

Not a deep learning library

[Diagram] JAX underpins deep learning, optimizers, probabilistic programming, and probabilistic modeling; LLMs, solvers, and physics simulations build on it.

Ideal for DL

JAX is a Python sublanguage ideal for deep learning.


JAX for deep learning

Deep learning libraries

[Diagram] Deep learning libraries built on JAX: Flax, Equinox, and Keras; optimizer libraries: Optax and Optimix.

This course

[Same diagram as above.]

Modular approach

Data loaders

[Diagram] Load data: torchdata, tfds, or datasets.
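
For example, loading a dataset with tfds (a hedged sketch; torchdata or datasets work similarly):

import tensorflow_datasets as tfds

ds = tfds.load("mnist", split="train")          # a tf.data.Dataset
for example in tfds.as_numpy(ds.take(1)):
    print(example["image"].shape, example["label"])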

Data transformations

[Diagram] Adds a "process data" step after loading: grain or torchvision.

Core deep learning library

[Diagram] Adds "define architecture": flax.

Optimizer and loss functions

[Diagram] Adds "hyperparameters" (optimizer and loss): optax.

Train

[Diagram] Adds "train": flax, built on jax.

Checkpointing

[Diagram] Adds "checkpoint" after training: orbax.

Transfer learning

[Diagram] Adds a pre-trained model (transformers) feeding into the architecture definition.
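
As a hedged illustration of how the core pieces fit together (a sketch assuming flax.linen and optax; the model, names, and shapes are made up, and the course may use different APIs):

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class MLP(nn.Module):                        # define architecture (flax)
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(32)(x))
        return nn.Dense(1)(x)

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

optimizer = optax.adam(1e-3)                 # hyperparameters (optax)
opt_state = optimizer.init(params)

def loss_fn(params, x, y):
    pred = model.apply(params, x)
    return jnp.mean((pred - y) ** 2)

@jax.jit                                     # train step (jax + flax)
def train_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state, loss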

Installation

Installing JAX

             Linux x86_64  Linux aarch64  Mac x86_64    Mac aarch64   Windows x86_64  Windows WSL2 x86_64
CPU          yes           yes            yes           yes           yes             yes
NVIDIA GPU   yes           yes            no            n/a           no              experimental
Google TPU   yes           n/a            n/a           n/a           n/a             n/a
AMD GPU      yes           no             experimental  n/a           no              no
Apple GPU    n/a           no             n/a           experimental  n/a             n/a
Intel GPU    experimental  n/a            n/a           n/a           no              no

If you install packages which depend on JAX (e.g. Flax), they will by default install the CPU version of JAX. If you want to run JAX on GPUs, make sure to first install jax[cuda12].

You can install the CPU version on your machine to prototype and use a GPU version on the clusters (we have wheels).
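
To check which backend your installation actually uses (plain JAX calls):

import jax

print(jax.default_backend())   # 'cpu', 'gpu', or 'tpu'
print(jax.devices())           # the devices JAX can see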

Complementary libraries

The downside of the modular approach is that several libraries are required, and conflicts between their dependencies can be a problem.

The meta-library jax-ai-stack makes this easier to manage (install jax[cuda12] first for GPU).

Note that, for now, TensorFlow and the packages that depend on it (e.g. TFDS, grain) are still stuck at Python 3.12, so you can't use a newer Python version if you want to use some of them.

On your machine (and your machine only), a great tool to manage Python versions and packages is uv (see our webinar). On the clusters, you have to use module to load the Python version you want and pip to install packages.