Accelerated array computing and flexible differentiation with JAX

Marie-Hélène Burle

April 16, 2024


Context

What is JAX?

Library for Python developed by Google

Key data structure: Array

Composition, transformation, and differentiation of numerical programs

Compilation for CPUs, GPUs, and TPUs

NumPy-like and lower-level APIs

Requires strict functional programming

Why JAX?

Autodiff method          Framework     Advantage                            Disadvantage
Static graph and XLA     TensorFlow    Mostly optimized AD                  Manual writing of IR
Dynamic graph            PyTorch       Convenient                           Limited AD optimization
Dynamic graph and XLA    TensorFlow2   Convenient                           Disappointing speed
Pseudo-dynamic and XLA   JAX           Convenient and mostly optimized AD   Pure functions

  Summarized from a blog post by Chris Rackauckas

Getting started

Installation


Install from pip wheels:

  • Personal computer: use the wheel installation commands from the official site
  • Alliance clusters: python -m pip install jax --no-index

Windows: GPU support only via WSL

The NumPy API

import numpy as np

print(np.array([(1, 2, 3), (4, 5, 6)]))
[[1 2 3]
 [4 5 6]]
print(np.arange(5))
[0 1 2 3 4]
print(np.zeros(2))
[0. 0.]
print(np.linspace(0, 2, 9))
[0.   0.25 0.5  0.75 1.   1.25 1.5  1.75 2.  ]

import jax.numpy as jnp

print(jnp.array([(1, 2, 3), (4, 5, 6)]))
[[1 2 3]
 [4 5 6]]
print(jnp.arange(5))
[0 1 2 3 4]
print(jnp.zeros(2))
[0. 0.]
print(jnp.linspace(0, 2, 9))
[0.   0.25 0.5  0.75 1.   1.25 1.5  1.75 2.  ]

But JAX NumPy is not NumPy…

Different types

type(np.zeros((2, 3)))
numpy.ndarray
type(jnp.zeros((2, 3)))
jaxlib.xla_extension.ArrayImpl

Different default data types

np.zeros((2, 3)).dtype
dtype('float64')
jnp.zeros((2, 3)).dtype
dtype('float32')

Float32 is the standard in deep learning and in libraries built for accelerators
Float64 is very slow on most GPUs and not supported on TPUs
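If double precision is really needed, 64-bit types can be turned on with JAX's jax_enable_x64 configuration flag. A minimal sketch, assuming the flag is set at startup before any arrays are created (it changes the default dtype for the whole program):

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types globally; best done at startup, before creating arrays
jax.config.update("jax_enable_x64", True)

x = jnp.zeros((2, 3))
print(x.dtype)  # float64 once x64 mode is on
```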

Immutable arrays

a = np.arange(5)
a[0] = 9
print(a)
[9 1 2 3 4]
a = jnp.arange(5)
a[0] = 9
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable.
b = a.at[0].set(9)
print(b)
[9 1 2 3 4]
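Besides set(), the .at property provides other out-of-place updates such as add(), multiply(), and min(); each returns a new array and leaves the original untouched. Inside jit-compiled code, XLA can often perform these updates in place when the input array is not reused afterwards. A short sketch:

```python
import jax.numpy as jnp

a = jnp.arange(5)

# Each .at[...] method returns a new array
b = a.at[0].set(9)        # -> [9, 1, 2, 3, 4]
c = a.at[1:3].add(10)     # -> [0, 11, 12, 3, 4]
d = a.at[-1].multiply(2)  # -> [0, 1, 2, 3, 8]
e = a.at[2].min(-1)       # element-wise min with -1 -> [0, 1, -1, 3, 4]

print(a)  # the original array is unchanged: [0 1 2 3 4]
```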

Strict input control

NumPy is easy-going:

np.sum([1.0, 2.0])  # argument is a list
np.float64(3.0)
np.sum((1.0, 2.0))  # argument is a tuple
np.float64(3.0)

To avoid silent inefficiencies (inside traced functions, each list element would become a separate input), JAX only accepts arrays:

jnp.sum([1.0, 2.0])
TypeError: sum requires ndarray or scalar arguments, got <class 'list'>
jnp.sum((1.0, 2.0))
TypeError: sum requires ndarray or scalar arguments, got <class 'tuple'>
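The fix is to convert the sequence to an array explicitly before calling the jnp function, so the conversion happens once and in plain sight:

```python
import jax.numpy as jnp

values = [1.0, 2.0]

# Convert the Python list to a JAX array first
total = jnp.sum(jnp.array(values))
print(total)  # 3.0
```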

Out of bounds indexing

NumPy will error if you index out of bounds:

print(np.arange(5)[10])
IndexError: index 10 is out of bounds for axis 0 with size 5

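JAX does not raise an error (raising errors from code running on an accelerator is hard); instead, out-of-bounds indices are silently clamped to the closest boundary: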

print(jnp.arange(5)[10])
4
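The mode argument of the .at methods makes out-of-bounds behaviour explicit: "fill" returns a chosen value on retrieval, "drop" skips out-of-bounds updates, and "clip" clamps the index to the array bounds. A short sketch (the fill value -1 is an arbitrary choice for illustration):

```python
import jax.numpy as jnp

x = jnp.arange(5)

# Default retrieval: the index is clamped to the last valid position
print(x[10])  # 4

# Retrieval with an explicit fill value instead of clamping
filled = x.at[10].get(mode="fill", fill_value=-1)   # -> -1

# Updates: "drop" skips the out-of-bounds update, "clip" clamps the index
dropped = x.at[10].set(9, mode="drop")   # -> [0, 1, 2, 3, 4]
clipped = x.at[10].set(9, mode="clip")   # -> [0, 1, 2, 3, 9]
```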

PRNG

Traditional pseudorandom number generators (such as NumPy's) rely on a hidden global state, usually seeded from the operating system

This hidden state is slow and problematic for parallel execution

JAX instead relies on an explicitly set random state called a key:

from jax import random

initial_key = random.PRNGKey(18)
print(initial_key)
[ 0 18]

Each key should only be used once (for a single random function call), but it can be split into new keys:

new_key1, new_key2 = random.split(initial_key)

initial_key should not be used anymore now

print(new_key1)
[4197003906 1654466292]
print(new_key2)
[1685972163 1654824463]

We keep one key to split further whenever we need new keys, and we use the other one to generate random numbers

To make sure we don’t reuse a key by accident, it is best to overwrite the initial key with one of the new ones

With shorter names:

key = random.PRNGKey(18)
key, subkey = random.split(key)

We can now use subkey to generate a random array:

x = random.normal(subkey, (3, 2))
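Reusing a key reproduces exactly the same numbers, which gives reproducibility but is also why keys must not be reused by accident. A short sketch:

```python
from jax import random

key = random.PRNGKey(18)
key, subkey = random.split(key)

# The same key always produces the same numbers
x1 = random.normal(subkey, (3, 2))
x2 = random.normal(subkey, (3, 2))
print((x1 == x2).all())  # True

# For fresh numbers, split again and use the new subkey
key, subkey = random.split(key)
x3 = random.normal(subkey, (3, 2))
print((x1 == x3).all())  # False (with overwhelming probability)
```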

Benchmarking

JAX uses asynchronous dispatch

Instead of waiting for a computation to complete before control returns to Python, the computation is dispatched to an accelerator and a future is created

To get proper timings, we need to make sure the future is resolved by using the block_until_ready() method
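A minimal timing sketch, assuming a 1000 × 1000 random matrix product as the workload: without block_until_ready(), the timer would mostly measure the dispatch rather than the computation. (In IPython or Jupyter, %timeit with .block_until_ready() appended to the call works the same way.)

```python
import time

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
x = random.normal(key, (1000, 1000))

start = time.perf_counter()
y = jnp.dot(x, x.T)      # returns quickly: the work is only dispatched
y.block_until_ready()    # wait until the result has actually been computed
elapsed = time.perf_counter() - start
print(f"{elapsed:.4f} s")
```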

JAX functioning

Figure: how JAX works
Pure Python functions → Tracing → Jaxprs (JAX expressions), an intermediate representation (IR)
Jaxprs → Transformations: vectorization, parallelization, differentiation
Jaxprs → Just-in-time (JIT) compilation → High-level optimized (HLO) program → Accelerated Linear Algebra (XLA) → CPU, GPU, TPU

JIT compilation

JIT syntax

from jax import jit

key = random.PRNGKey(8)
key, subkey1, subkey2 = random.split(key, 3)

a = random.normal(subkey1, (500, 500))
b = random.normal(subkey2, (500, 500))

def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)

Our function could simply be used as:

sse = sum_squared_error(a, b)

Repeated calls will run faster if we create a JIT-compiled version and use it instead (the first call also includes tracing and compilation):

sum_squared_error_jit = jit(sum_squared_error)

sse = sum_squared_error_jit(a, b)

Alternatively, this can be written as:

sse = jit(sum_squared_error)(a, b)

Or with the @jit decorator:

@jit
def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)

sse = sum_squared_error(a, b)
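The speedup shows up on repeated calls: the first call of a jitted function pays for tracing and compilation, while later calls with the same shapes and dtypes reuse the cached program. A small timing sketch (timed is a hypothetical helper; exact numbers depend on the hardware):

```python
import time

import jax.numpy as jnp
from jax import jit, random

@jit
def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)

key = random.PRNGKey(8)
key, subkey1, subkey2 = random.split(key, 3)
a = random.normal(subkey1, (500, 500))
b = random.normal(subkey2, (500, 500))

def timed(f, *args):
    # Hypothetical helper: wall-clock time of one call, waiting for the result
    start = time.perf_counter()
    f(*args).block_until_ready()
    return time.perf_counter() - start

first = timed(sum_squared_error, a, b)   # includes tracing + compilation
second = timed(sum_squared_error, a, b)  # reuses the cached compiled program
print(f"first call: {first:.4f} s, second call: {second:.4f} s")
```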

Static vs traced variables

@jit
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

print(cond_func(1.0))
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]

JIT compilation uses tracing of the code based on shape and dtype so that the same compiled code can be reused for new values with the same characteristics

Tracer objects are not concrete values but abstract representations that only carry shape and dtype information

Here, such an abstract value cannot be compared to 0.0 to give a concrete Python boolean, so Python cannot know which branch to take

One solution is to tell jit() to exclude the problematic arguments from tracing (treat them as static), using argument positions:

def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_jit = jit(cond_func, static_argnums=(0,))

print(cond_func_jit(2.0))
print(cond_func_jit(-2.0))
8.0
4.0

Alternatively, using argument names:

def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_jit_alt = jit(cond_func, static_argnames="x")

print(cond_func_jit_alt(2.0))
print(cond_func_jit_alt(-2.0))
8.0
4.0
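Static arguments have a cost: they must be hashable, and every new value triggers a new trace and compilation. A small sketch to observe this, using a hypothetical traces list as a counter (a Python side effect such as list.append only runs while the function is being traced):

```python
from jax import jit

traces = []  # hypothetical counter: one entry per trace

def cond_func(x):
    traces.append(x)  # Python side effect: runs only during tracing
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_jit = jit(cond_func, static_argnums=(0,))

cond_func_jit(2.0)   # new static value: trace and compile
cond_func_jit(-2.0)  # new static value: trace and compile again
cond_func_jit(2.0)   # value seen before: cached program reused
print(len(traces))   # 2
```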

Control flow primitives

Another solution is to use one of the structured control flow primitives:

from jax import lax

lax.cond(False, lambda x: x ** 2.0, lambda x: x ** 3.0, jnp.array([2.]))
Array([8.], dtype=float32)
lax.cond(True, lambda x: x ** 2.0, lambda x: x ** 3.0, jnp.array([-2.]))
Array([4.], dtype=float32)
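With lax.cond, the earlier cond_func can be jit-compiled with x traced rather than static: both branches are traced and the predicate is evaluated at run time. A minimal sketch:

```python
from jax import jit, lax

@jit
def cond_func(x):
    # The predicate x < 0.0 is a traced boolean: lax.cond picks the branch at run time
    return lax.cond(x < 0.0, lambda x: x ** 2.0, lambda x: x ** 3.0, x)

print(cond_func(2.0))   # 8.0
print(cond_func(-2.0))  # 4.0
```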

Other control flow primitives:

  • lax.while_loop
  • lax.fori_loop
  • lax.scan

Other pseudo-dynamic control flow functions:

  • lax.select (NumPy API jnp.where and jnp.select)
  • lax.switch (NumPy API jnp.piecewise)
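A short sketch of two of these alternatives: jnp.where, which evaluates both expressions and selects element-wise, and lax.fori_loop, whose body function receives the loop index and a carried value:

```python
import jax.numpy as jnp
from jax import jit, lax

@jit
def cond_func_where(x):
    # Both x ** 2.0 and x ** 3.0 are computed; the condition selects per element
    return jnp.where(x < 0.0, x ** 2.0, x ** 3.0)

print(cond_func_where(jnp.array([2.0, -2.0])))  # [8. 4.]

# fori_loop(lower, upper, body_fun, init_val) with body_fun(i, carry) -> new carry
sum_sq = lax.fori_loop(0, 4, lambda i, acc: acc + i ** 2, 0)
print(sum_sq)  # 0 + 1 + 4 + 9 = 14
```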

Static vs traced operations

Similarly, operations whose result must be known at compilation time (such as a new shape) must not be traced during JIT compilation:

@jit
def f(x):
    return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
print(f(x))
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>]

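The problem here is that jnp.array(x.shape).prod() is a JAX operation, so inside jit its result is a traced value, while reshape() needs a concrete shape at compilation time (x.shape itself is static: it is known during tracing)

One solution is to use the NumPy version of prod(), which computes the new shape as a static value during tracing: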

import numpy as np

@jit
def f(x):
    return x.reshape(np.prod(x.shape))  # np.prod runs on the static shape

print(f(x))
[1. 1. 1. 1. 1. 1.]

Functionally pure functions

Jaxprs

import jax

x = jnp.array([1., 4., 3.])
y = jnp.array([8., 1., 2.])

def f(x, y):
    return 2 * x**2 + y

jax.make_jaxpr(f)(x, y) 
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = integer_pow[y=2] a
    d:f32[3] = mul 2.0 c
    e:f32[3] = add d b
  in (e,) }

Outputs only based on inputs

def f(x):
    return a + x

f uses the variable a from the global environment

The output does not solely depend on the inputs: not a pure function

a = jnp.ones(3)
print(a)
[1. 1. 1.]
def f(x):
    return a + x

print(jit(f)(jnp.ones(3)))
[2. 2. 2.]

Things seem ok here because this is the first run (tracing)

Now, let’s change the value of a to an array of zeros:

a = jnp.zeros(3)
print(a)
[0. 0. 0.]

And rerun the same code:

print(jit(f)(jnp.ones(3)))
[2. 2. 2.]

Our cached compiled program, in which a was captured as a constant during tracing, is run and we get a wrong result

The new value for a will only take effect if we re-trigger tracing by changing the shape and/or dtype of x:

a = jnp.zeros(4)
print(a)
[0. 0. 0. 0.]
print(jit(f)(jnp.ones(4)))
[1. 1. 1. 1.]

Passing to f() an argument of a different shape forced retracing
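The robust fix is to make f pure by passing a explicitly as an argument, so every call sees its current value. A minimal sketch (f is redefined here with a two-argument signature for illustration):

```python
import jax.numpy as jnp
from jax import jit

@jit
def f(a, x):
    # a is now an input, not a global captured at tracing time
    return a + x

a = jnp.ones(3)
before = f(a, jnp.ones(3))
print(before)  # [2. 2. 2.]

a = jnp.zeros(3)
after = f(a, jnp.ones(3))  # same shapes: no retracing, yet the result is correct
print(after)   # [1. 1. 1.]
```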

No side effects

Side effects: anything besides the returned output

Examples:

  • Printing to standard output
  • Reading from file/writing to file
  • Modifying a global variable

Side effects happen during tracing, but not when the cached compiled code is run again: you cannot rely on side effects in your code

def f(a, b):
    print("Calculating sum")
    return a + b

print(jit(f)(jnp.arange(3), jnp.arange(3)))
Calculating sum
[0 2 4]

Printing happened here because this is the first run (tracing)

Let’s rerun the function:

print(jit(f)(jnp.arange(3), jnp.arange(3)))
[0 2 4]

This time, no printing
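To print from inside jit-compiled code on every call, JAX provides jax.debug.print, which is compiled into the program instead of running only during tracing. A minimal sketch of the same sum function:

```python
import jax
import jax.numpy as jnp
from jax import jit

@jit
def f(a, b):
    # Staged into the compiled program: runs on every call, with runtime values
    jax.debug.print("Calculating sum of {a} and {b}", a=a, b=b)
    return a + b

print(f(jnp.arange(3), jnp.arange(3)))
print(f(jnp.arange(3), jnp.arange(3)))  # the debug message is printed again
```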

Other transformations

Automatic differentiation

Considering the function f:

f = lambda x: x**3 + 2*x**2 - 3*x + 8

We can create a new function dfdx that computes the gradient of f w.r.t. x:

from jax import grad

dfdx = grad(f)

dfdx returns the derivative of f at the point given as input:

print(dfdx(1.))
4.0
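grad differentiates with respect to the first argument by default and requires floating-point inputs (grad(f)(1) with an integer raises a TypeError). A short sketch of two related tools, argnums and value_and_grad, using a hypothetical function g:

```python
from jax import grad, value_and_grad

def g(x, y):
    return x ** 2 * y

# argnums selects the positional arguments to differentiate with respect to
dg_dx, dg_dy = grad(g, argnums=(0, 1))(2.0, 3.0)
print(dg_dx, dg_dy)  # 12.0 4.0  (2xy and x**2)

# value_and_grad returns f(x) and its derivative together
f = lambda x: x**3 + 2*x**2 - 3*x + 8
value, slope = value_and_grad(f)(1.0)
print(value, slope)  # 8.0 4.0
```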

Composing transformations

Transformations can be composed:

print(jit(grad(f))(1.))
4.0
print(grad(jit(f))(1.))
4.0

Forward and reverse modes

Other autodiff methods:

  • Reverse-mode vector-Jacobian products: jax.vjp
  • Forward-mode Jacobian-vector products: jax.jvp
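A sketch of both modes on the same function f at x = 1.0; with a tangent and a cotangent of 1.0, both recover f'(1) = 4:

```python
import jax
import jax.numpy as jnp

f = lambda x: x**3 + 2*x**2 - 3*x + 8

# Forward mode: push a tangent through f alongside the primal value
primal_out, tangent_out = jax.jvp(f, (1.0,), (1.0,))
print(primal_out, tangent_out)  # 8.0 4.0

# Reverse mode: get f(x) and a function pulling cotangents back to the input
primal_out, f_vjp = jax.vjp(f, 1.0)
(cotangent_in,) = f_vjp(jnp.ones_like(primal_out))
print(primal_out, cotangent_in)  # 8.0 4.0
```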

Higher-order differentiation


With a single variable, the grad function calls can be nested:

d2fdx = grad(dfdx)   # function to compute 2nd order derivatives
d3fdx = grad(d2fdx)  # function to compute 3rd order derivatives
...


With several variables:

  • jax.jacfwd for forward-mode
  • jax.jacrev for reverse-mode
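A short sketch of these tools, using a hypothetical function h from R² to R² and a hypothetical scalar function k of two variables; jax.hessian (forward-over-reverse, i.e. jacfwd of jacrev) is also shown:

```python
import jax
import jax.numpy as jnp

def h(v):
    x, y = v
    return jnp.array([x * y, x ** 2 + y])

v = jnp.array([2.0, 3.0])
J_fwd = jax.jacfwd(h)(v)
J_rev = jax.jacrev(h)(v)
print(J_fwd)                       # rows: [3. 2.] and [4. 1.]
print(jnp.allclose(J_fwd, J_rev))  # True: both modes give the same Jacobian

k = lambda v: v[0] ** 2 * v[1]     # x**2 * y
H = jax.hessian(k)(v)              # rows: [6. 4.] and [4. 0.]

# Single variable: nested grad, as above
f = lambda x: x**3 + 2*x**2 - 3*x + 8
d2fdx = jax.grad(jax.grad(f))
print(d2fdx(1.0))  # 10.0  (6x + 4)
```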

Pytrees

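JAX uses a nested container structure called a pytree (nested lists, tuples, and dicts with arrays as leaves), which is extremely useful for DNNs, e.g. to store model parameters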
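A minimal sketch of a pytree holding the parameters of a hypothetical two-layer linear model (layer names and shapes are made up for illustration), manipulated with jax.tree_util and differentiated with grad:

```python
import jax
import jax.numpy as jnp

# Nested dicts with arrays as leaves form a pytree
params = {"layer1": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
          "layer2": {"w": jnp.ones((3, 1)), "b": jnp.zeros(1)}}

print(len(jax.tree_util.tree_leaves(params)))  # 4 arrays

# Apply a function to every leaf, keeping the structure
shapes = jax.tree_util.tree_map(lambda p: p.shape, params)
print(shapes)

def loss(params, x):
    h = x @ params["layer1"]["w"] + params["layer1"]["b"]
    out = h @ params["layer2"]["w"] + params["layer2"]["b"]
    return jnp.sum(out ** 2)

# The gradient is a pytree with the same structure as params
grads = jax.grad(loss)(params, jnp.ones((5, 2)))
print(jax.tree_util.tree_map(jnp.shape, grads))
```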

Vectorization and parallelization

Other transformations run computations in parallel across batches of arrays:

  • Automatic vectorization with jax.vmap
  • Parallelization across devices with jax.pmap
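A minimal sketch of jax.vmap with a hypothetical dot function written for single vectors: vmap turns it into a batched function without rewriting it. (jax.pmap is not shown since it needs several devices.)

```python
import jax
import jax.numpy as jnp

def dot(u, v):
    # Written for single vectors of shape (n,)
    return jnp.sum(u * v)

us = jnp.arange(6.0).reshape(3, 2)  # batch of 3 vectors
vs = jnp.ones((3, 2))

batched = jax.vmap(dot)(us, vs)     # maps over the leading axis of both inputs
print(batched)                      # [1. 5. 9.]

# in_axes=(0, None): map over us only, reuse the same vector for every row
one_v = jax.vmap(dot, in_axes=(0, None))(us, jnp.array([1.0, 0.0]))
print(one_v)                        # [0. 2. 4.]
```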

Pushing optimizations further

Lax API

jax.numpy is a high-level NumPy-like API wrapped around jax.lax

jax.lax is a lower-level, stricter, and potentially more efficient API, itself wrapped around XLA

Pallas: extension to write GPU and TPU kernels

Figure: how JAX works with Pallas
Pure Python functions → Tracing → Jaxprs (JAX expressions), an intermediate representation (IR)
Jaxprs → Transformations: vectorization, parallelization, differentiation
Jaxprs → Just-in-time (JIT) compilation → High-level optimized (HLO) program → Triton → GPU, or Mosaic → TPU