A brief intro to JAX

and how to use it for deep learning

What is JAX?

Library for Python developed by Google

Key data structure: Array

Composition, transformation, and differentiation of numerical programs

Compilation for CPUs, GPUs, and TPUs

NumPy-like and lower-level APIs

Requires strict functional programming (pure functions, no in-place mutation; see the sketch below)
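
A minimal sketch (my own example, not from the slides) of the NumPy-like API and the functional constraints: JAX arrays are immutable, so updates are expressed functionally with the .at[...] syntax.

import jax.numpy as jnp

x = jnp.arange(4.0)     # a JAX Array, used much like a NumPy array
y = jnp.sin(x) * 2.0    # familiar NumPy-style operations

# x[0] = 10.0 would raise an error: JAX arrays are immutable.
# Functional updates return a new array instead:
z = x.at[0].set(10.0)
print(x)                # [0. 1. 2. 3.]  (unchanged)
print(z)                # [10.  1.  2.  3.]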

Why JAX?

Fast

  • Default data type suited for deep learning

    Like PyTorch, JAX uses float32 by default. This precision is sufficient for deep learning and more efficient (by contrast, NumPy defaults to float64)

  • JIT compilation (see the sketch after this list)

  • The same code can run on CPUs or on accelerators (GPUs and TPUs)

  • XLA (Accelerated Linear Algebra) optimization

  • Asynchronous dispatch

  • Vectorization, data parallelism, and sharding

    All levels of shared and distributed memory parallelism are supported
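
A minimal sketch (my own example, not from the slides) combining several of these points: the float32 default, jax.vmap vectorization, and jax.jit compilation via XLA.

import jax
import jax.numpy as jnp

def predict(w, x):
    # toy computation on a single example
    return jnp.tanh(w @ x)

w = jnp.ones((3, 3))
x = jnp.ones(3)
print(x.dtype)                    # float32 by default

# Vectorize over a batch without rewriting the function,
# then JIT-compile the whole thing with XLA
batched = jax.vmap(predict, in_axes=(None, 0))
fast_batched = jax.jit(batched)

xs = jnp.ones((8, 3))             # batch of 8 inputs
print(fast_batched(w, xs).shape)  # (8, 3)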

Great AD (automatic differentiation)

Framework     Autodiff method           Advantage                            Disadvantage
TensorFlow    Static graph and XLA      Mostly optimized AD                  Manual writing of IR
PyTorch       Dynamic graph             Convenient                           Limited AD optimization
TensorFlow 2  Dynamic graph and XLA     Convenient                           Disappointing speed
JAX           Pseudo-dynamic and XLA    Convenient and mostly optimized AD   Pure functions

  Summarized from a blog post by Chris Rackauckas

Close to the math

Consider the function f:

f = lambda x: x**3 + 2*x**2 - 3*x + 8

We can create a new function dfdx that computes the gradient of f w.r.t. x:

from jax import grad

dfdx = grad(f)

Calling dfdx returns the derivative at a given point:

print(dfdx(1.))
4.0

Forward and reverse modes

  • reverse-mode vector-Jacobian products: jax.vjp
  • forward-mode Jacobian-vector products: jax.jvp (see the sketch below)
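
A minimal sketch (my own example, not from the slides) of both modes on a function from R² to R²:

import jax
import jax.numpy as jnp

def g(x):
    return jnp.array([x[0] ** 2, x[0] * x[1]])

x = jnp.array([2.0, 3.0])
v = jnp.array([1.0, 0.0])

# Forward mode: push the tangent vector v through g (Jacobian-vector product)
y, jvp_out = jax.jvp(g, (x,), (v,))

# Reverse mode: get a function that pulls a cotangent vector back through g
y, g_vjp = jax.vjp(g, x)
vjp_out, = g_vjp(jnp.array([1.0, 0.0]))

print(jvp_out)  # [4. 3.]  (first column of the Jacobian at x)
print(vjp_out)  # [4. 0.]  (first row of the Jacobian at x)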

Higher-order differentiation

With a single variable, the grad function calls can be nested:

d2fdx = grad(dfdx)   # function to compute 2nd order derivatives
d3fdx = grad(d2fdx)  # function to compute 3rd order derivatives
...

With several variables, use the following functions (sketch after this list):

  • jax.jacfwd for forward-mode,
  • jax.jacrev for reverse-mode.
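
A minimal sketch (my own example, not from the slides): the Jacobian of a multivariate function in both modes, and a Hessian built by composing them.

import jax
import jax.numpy as jnp

def h(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[1])])

x = jnp.array([1.0, 2.0])

print(jax.jacrev(h)(x))  # 2x2 Jacobian, computed in reverse mode
print(jax.jacfwd(h)(x))  # same Jacobian, computed in forward mode

# Hessian of a scalar function: forward-over-reverse is the usual composition
s = lambda x: jnp.sum(x ** 3)
hessian = jax.jacfwd(jax.jacrev(s))
print(hessian(x))        # diag(6 * x) = [[6. 0.] [0. 12.]]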

How does it work?

Pure Python functions are traced into jaxprs (JAX expressions), JAX's intermediate representation (IR). Transformations operate on jaxprs: differentiation (jax.grad), vectorization (jax.vmap), and parallelization (jax.pmap). Just-in-time (JIT) compilation with jax.jit lowers jaxprs to a high-level optimized (HLO) program, which XLA (Accelerated Linear Algebra) compiles for CPUs, GPUs, or TPUs.
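
A minimal sketch (my own example, not from the slides) of peeking at this pipeline: jax.make_jaxpr shows the IR produced by tracing, and jax.jit hands compilation to XLA.

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2) + 3.0

x = jnp.arange(4.0)

# The jaxpr produced by tracing f on an input of this shape and dtype
print(jax.make_jaxpr(f)(x))

# JIT compilation: the first call traces and compiles via XLA,
# later calls with the same shapes reuse the compiled executable
f_jit = jax.jit(f)
print(f_jit(x))  # 17.0  (0 + 1 + 4 + 9 + 3)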

Not a deep learning library

JAX sits at the centre of an ecosystem: it underpins deep learning, optimizers, probabilistic programming, and probabilistic modeling, while LLMs, solvers, and physics simulations are built on top of it.

A Python sublanguage ideal for deep learning


JAX for deep learning

Deep learning libraries

For deep learning, JAX is used through libraries such as Flax, Equinox, and Keras; for optimization, through Optax and Optimix.

This course

This course uses Flax for the deep learning library and Optax for the optimizer, as the workflow on the following slides shows.

Modular approach

Data loaders

Load data: torchdata, tfds, or datasets
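
A minimal sketch (my own example, not from the course) of the tfds option, assuming TensorFlow Datasets is installed:

import tensorflow_datasets as tfds

# Load MNIST as (image, label) pairs and iterate over NumPy arrays
ds = tfds.load('mnist', split='train', as_supervised=True)
for image, label in tfds.as_numpy(ds.take(1)):
    print(image.shape, label)  # (28, 28, 1) and an integer label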

Data transformations

Load data (torchdata / tfds / datasets) → Process data: grain or torchvision

Core deep learning library

Load data → Process data → Define the architecture: flax

Optimizer and loss functions

Load data → Process data → Define architecture → Set the optimizer and loss (hyperparameters): optax

Train

Load data → Process data → Define architecture → Hyperparameters → Train: flax (running on jax)
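
A minimal sketch (my own example, not from the course) of how these pieces might fit together with Flax Linen and Optax; the MLP and train_step names are made up for illustration, and the course may use a different Flax API.

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class MLP(nn.Module):                       # define architecture (flax)
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(32)(x))
        return nn.Dense(1)(x)

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))

tx = optax.adam(1e-3)                       # optimizer (optax)
opt_state = tx.init(params)

@jax.jit
def train_step(params, opt_state, x, y):    # one training step (jax + flax + optax)
    def loss_fn(p):
        pred = model.apply(p, x)
        return jnp.mean((pred - y) ** 2)    # mean squared error loss
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

x = jnp.ones((16, 4))
y = jnp.zeros((16, 1))
params, opt_state, loss = train_step(params, opt_state, x, y)
print(loss)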

Checkpointing

Load data → Process data → Define architecture → Hyperparameters → Train → Checkpoint: orbax
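
A minimal sketch (my own example, not from the course) of checkpointing a parameter pytree with Orbax; the API surface changes between Orbax versions, so treat this as an assumption to check against the Orbax documentation (it expects a fresh absolute path).

import jax.numpy as jnp
import orbax.checkpoint as ocp

params = {'w': jnp.ones((3, 3)), 'b': jnp.zeros(3)}  # any pytree of arrays

ckptr = ocp.PyTreeCheckpointer()
ckptr.save('/tmp/jax-demo-ckpt', params)             # hypothetical path
restored = ckptr.restore('/tmp/jax-demo-ckpt')
print(restored['w'].shape)                           # (3, 3)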

Transfer learning

A pre-trained model (from transformers) can feed into the architecture: Load data (torchdata / tfds / datasets) → Process data (grain / torchvision) → Define architecture (flax, optionally from a pre-trained model) → Hyperparameters (optax) → Train (flax on jax) → Checkpoint (orbax)

Installation

Installing JAX

              Linux x86_64   Linux aarch64   Mac x86_64     Mac aarch64    Windows x86_64   Windows WSL2 x86_64
CPU           yes            yes             yes            yes            yes              yes
NVIDIA GPU    yes            yes             no             n/a            no               experimental
Google TPU    yes            n/a             n/a            n/a            n/a              n/a
AMD GPU       yes            no              experimental   n/a            no               no
Apple GPU     n/a            no              n/a            experimental   n/a              n/a
Intel GPU     experimental   n/a             n/a            n/a            no               no

From JAX documentation

Installing JAX

If you install packages which depend on JAX (e.g. Flax), they will by default install the CPU version of JAX. If you want to run JAX on GPUs, make sure to first install jax[cuda12]

You can install the CPU version on your machine to prototype and use a GPU version on the clusters (we have wheels)
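
A quick way to check which backend a given installation actually uses (my own sketch, not from the slides):

import jax

print(jax.default_backend())  # 'cpu', 'gpu', or 'tpu'
print(jax.devices())          # the devices JAX can see on this machine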

Installing complementary libraries

The modular approach has the downside that several libraries are required and conflicts between dependencies can be a problem

The meta-library jax-ai-stack makes this easier to manage (install jax[cuda12] first for GPU)

Note that for now TensorFlow and packages which depend on it (e.g. TFDS, grain) are still stuck at Python 3.12, so you can’t use a newer Python version if you want to use some of them

On your machine (and your machine only), a great tool to manage Python versions and packages is uv (webinar coming soon). On the clusters, use module to load the Python version you want and pip to install packages
