April 16, 2024

# Context

## What is JAX?

Library for Python developed by Google

Key data structure: Array

Composition, transformation, and differentiation of numerical programs

Compilation for CPUs, GPUs, and TPUs

NumPy-like and lower-level APIs

Requires strict functional programming

## Why JAX?

Summarized from a blog post by Chris Rackauckas

# Getting started

## Installation

Install from pip wheels:

• Personal computer: use the wheel installation commands from the official site
• Alliance clusters: `python -m pip install jax --no-index`

Windows: GPU support only via WSL
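A quick way to check an installation is to ask JAX which devices and backend it found. This is a minimal sketch. The printed device list depends on your machine and on which wheel was installed.

```python
import jax

# Version of the installed jax package
print(jax.__version__)
# Devices JAX can see, e.g. a single CPU device on a machine without GPU support
print(jax.devices())
# Backend used by default, e.g. "cpu", "gpu" or "tpu"
print(jax.default_backend())
```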

## The NumPy API

```python
import numpy as np

print(np.array([(1, 2, 3), (4, 5, 6)]))
```
```
[[1 2 3]
 [4 5 6]]
```
``print(np.arange(5))``
``[0 1 2 3 4]``
``print(np.zeros(2))``
``[0. 0.]``
``print(np.linspace(0, 2, 9))``
``[0.   0.25 0.5  0.75 1.   1.25 1.5  1.75 2.  ]``
```python
import jax.numpy as jnp

print(jnp.array([(1, 2, 3), (4, 5, 6)]))
```
```
[[1 2 3]
 [4 5 6]]
```
``print(jnp.arange(5))``
``[0 1 2 3 4]``
``print(jnp.zeros(2))``
``[0. 0.]``
``print(jnp.linspace(0, 2, 9))``
``[0.   0.25 0.5  0.75 1.   1.25 1.5  1.75 2.  ]``

# But JAX NumPy is not NumPy…

## Different types

``type(np.zeros((2, 3)))``
``numpy.ndarray``
``type(jnp.zeros((2, 3)))``
``jaxlib.xla_extension.ArrayImpl``
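Because the two types differ, conversion between them is explicit. The following minimal sketch uses `jnp.asarray` and `np.asarray`. It assumes default settings, where 64-bit mode is off.

```python
import jax
import jax.numpy as jnp
import numpy as np

x_np = np.zeros((2, 3))
# NumPy -> JAX: copies the data to the default device
# (float64 becomes float32 unless 64-bit mode is enabled)
x_jax = jnp.asarray(x_np)
# JAX -> NumPy: copies the data back to host memory
x_back = np.asarray(x_jax)

print(isinstance(x_jax, jax.Array))    # True
print(isinstance(x_back, np.ndarray))  # True
```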

## Different default data types

``np.zeros((2, 3)).dtype``
``dtype('float64')``
``jnp.zeros((2, 3)).dtype``
``dtype('float32')``

Float32 is the standard in deep learning and in libraries built for accelerators

Float64 is very slow on GPUs and not supported on TPUs
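If double precision is really needed (e.g. on CPUs), JAX's `jax_enable_x64` configuration flag turns on 64-bit types. This is a minimal sketch. We assume the flag is set at program startup, before any array is created, which is the safe place to set it.

```python
import jax

# Enable 64-bit types; set this before creating any JAX arrays
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

print(jnp.zeros((2, 3)).dtype)  # float64
```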

## Immutable arrays

```python
a = np.arange(5)
a[0] = 9
print(a)
```
``[9 1 2 3 4]``
```python
a = jnp.arange(5)
a[0] = 9
```
``TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable.``
```python
b = a.at[0].set(9)
print(b)
```
``[9 1 2 3 4]``
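The `.at[...]` property supports other out-of-place updates besides `set`. This sketch shows a few of them and checks that the original array stays unchanged.

```python
import jax.numpy as jnp

a = jnp.arange(5)
b = a.at[0].set(9)     # values: [9, 1, 2, 3, 4]
c = a.at[1:3].add(10)  # values: [0, 11, 12, 3, 4]
d = a.at[-1].mul(2)    # values: [0, 1, 2, 3, 8]

# The original array is never modified: every update returns a new array
print(a)  # [0 1 2 3 4]
```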

## Strict input control

NumPy is easy-going:

``np.sum([1.0, 2.0])  # argument is a list``
``3.0``
``np.sum((1.0, 2.0))  # argument is a tuple``
``3.0``

To avoid hidden inefficiencies, JAX only accepts arrays (or scalars):

``jnp.sum([1.0, 2.0])``
``TypeError: sum requires ndarray or scalar arguments, got <class 'list'>``
``jnp.sum((1.0, 2.0))``
``TypeError: sum requires ndarray or scalar arguments, got <class 'tuple'>``
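The fix is to convert the Python container to an array explicitly, once, before passing it to JAX. A minimal sketch:

```python
import jax.numpy as jnp

values = [1.0, 2.0]
# Explicit conversion from list to JAX array
total = jnp.sum(jnp.asarray(values))
print(total)  # 3.0
```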

## Out of bounds indexing

NumPy will error if you index out of bounds:

``print(np.arange(5)[10])``
``IndexError: index 10 is out of bounds for axis 0 with size 5``

JAX will silently clamp the index and return the element at the closest boundary:

``print(jnp.arange(5)[10])``
``4``
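When clamping is not what you want, the `.at[...].get()` method accepts a `mode` argument. Out-of-bounds updates, by default, are silently dropped. This is a short sketch of both behaviours.

```python
import jax.numpy as jnp

x = jnp.arange(5.0)

clamped = x[10]  # index clamped to the last element: 4.0
# Return a fill value instead of clamping
filled = x.at[10].get(mode="fill", fill_value=jnp.nan)
# Out-of-bounds updates are dropped: the result equals x
updated = x.at[10].set(99.0)

print(clamped, filled, updated)
```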

## PRNG

Traditional pseudorandom number generators rely on a hidden global state, typically seeded from the operating system

This is slow and problematic for parallel execution

JAX instead relies on an explicitly set random state called a key:

```python
from jax import random

initial_key = random.PRNGKey(18)
print(initial_key)
```
``[ 0 18]``

## PRNG

Each key should only be used once, but it can be split into new keys:

``new_key1, new_key2 = random.split(initial_key)``

`initial_key` should not be used anymore now

``print(new_key1)``
``[4197003906 1654466292]``
``print(new_key2)``
``[1685972163 1654824463]``

We keep one key to split whenever we need new keys and use the other one to generate random numbers

## PRNG

To make sure we don’t reuse a key by accident, it is best to overwrite the initial key with one of the new ones

Here is the same pattern with simpler names:

```python
key = random.PRNGKey(18)
key, subkey = random.split(key)
```

We can now use `subkey` to generate a random array:

``x = random.normal(subkey, (3, 2))``
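The following sketch shows why keys must not be reused. The same key always produces the same numbers, so each new draw needs a fresh subkey obtained by splitting.

```python
from jax import random

key = random.PRNGKey(18)
key, subkey = random.split(key)

x1 = random.normal(subkey, (3, 2))
x2 = random.normal(subkey, (3, 2))  # same key -> exactly the same numbers

key, subkey = random.split(key)     # fresh subkey for the next draw
x3 = random.normal(subkey, (3, 2))

print((x1 == x2).all())  # True
print((x1 == x3).all())  # False
```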

## Benchmarking

JAX uses asynchronous dispatch

Instead of waiting for a computation to complete before control returns to Python, the computation is dispatched to an accelerator and a future is created

To get proper timings, we need to make sure the future is resolved by using the `block_until_ready()` method
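This is a minimal timing sketch. Without `block_until_ready()`, the clock would only measure how long it takes to dispatch the work. The matrix size is an arbitrary choice, and the very first call may also include one-off setup costs.

```python
import time

import jax.numpy as jnp
from jax import random

x = random.normal(random.PRNGKey(0), (1000, 1000))

start = time.perf_counter()
y = jnp.dot(x, x)      # returns as soon as the work is dispatched
y.block_until_ready()  # wait for the result before stopping the clock
elapsed = time.perf_counter() - start

print(f"{elapsed:.4f} s")
```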

# JIT compilation

## JIT syntax

```python
import jax.numpy as jnp
from jax import jit, random

key = random.PRNGKey(8)
key, subkey1, subkey2 = random.split(key, 3)

a = random.normal(subkey1, (500, 500))
b = random.normal(subkey2, (500, 500))

def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)
```

Our function could simply be used as:

``sse = sum_squared_error(a, b)``

## JIT syntax

Our code will typically run faster (after the first call, which includes compilation) if we create a JIT-compiled version and use that instead:

```python
sum_squared_error_jit = jit(sum_squared_error)

sse = sum_squared_error_jit(a, b)
```

Alternatively, this can be written as:

``sse = jit(sum_squared_error)(a, b)``

Or with the `@jit` decorator:

```python
@jit
def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)

sse = sum_squared_error(a, b)
```
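Combining JIT with `block_until_ready()` shows where the compilation cost goes. This sketch times the first call of the compiled function, which traces and compiles, against a second, cached call. It also checks that the JIT and non-JIT versions agree up to float32 rounding. The helper `time_call` is hypothetical, and actual timings depend on hardware.

```python
import time

import jax.numpy as jnp
from jax import jit, random

def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)

key = random.PRNGKey(8)
key, subkey1, subkey2 = random.split(key, 3)
a = random.normal(subkey1, (500, 500))
b = random.normal(subkey2, (500, 500))

sum_squared_error_jit = jit(sum_squared_error)

def time_call(fn):
    # Hypothetical helper: wall-clock time of one fully completed call
    start = time.perf_counter()
    fn(a, b).block_until_ready()
    return time.perf_counter() - start

t_first = time_call(sum_squared_error_jit)   # includes tracing and compilation
t_cached = time_call(sum_squared_error_jit)  # reuses the compiled program
print(f"first call: {t_first:.4f} s, cached call: {t_cached:.4f} s")

same = jnp.allclose(sum_squared_error(a, b), sum_squared_error_jit(a, b), rtol=1e-4)
print(same)
```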

## Static vs traced variables

```python
@jit
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

print(cond_func(1.0))
```
``jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]``

JIT compilation traces the code using only the shape and dtype of the inputs, so that the same compiled code can be reused for new values with the same characteristics

Tracer objects are not real values but more general abstract representations (shape and dtype)

Here, an abstract value cannot be used because it does not determine which branch of the `if` to take

## Static vs traced variables

One solution is to tell `jit()` to exclude the problematic arguments from tracing

using argument positions:

```python
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_jit = jit(cond_func, static_argnums=(0,))

print(cond_func_jit(2.0))
print(cond_func_jit(-2.0))
```
```
8.0
4.0
```

## Static vs traced variables

One solution is to tell `jit()` to exclude the problematic arguments from tracing

using argument names:

```python
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_jit_alt = jit(cond_func, static_argnames="x")

print(cond_func_jit_alt(2.0))
print(cond_func_jit_alt(-2.0))
```
```
8.0
4.0
```
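Static arguments come at a cost. Each new static value triggers a new trace and compilation, so this approach suits arguments that take few distinct values. The sketch below counts traces with a Python counter. This works only because Python code in the function body runs while JAX traces it. The counter `trace_count` is purely illustrative.

```python
from jax import jit

trace_count = 0

def cond_func(x):
    global trace_count
    trace_count += 1  # Python side effect: only runs while jit traces
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_jit = jit(cond_func, static_argnums=(0,))

cond_func_jit(2.0)
cond_func_jit(2.0)   # same static value: cached program reused
cond_func_jit(-2.0)  # new static value: traced and compiled again
print(trace_count)   # 2
```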

## Control flow primitives

Another solution is to use one of the structured control flow primitives:

```python
from jax import lax

lax.cond(False, lambda x: x ** 2.0, lambda x: x ** 3.0, jnp.array([2.]))
```
``Array([8.], dtype=float32)``
``lax.cond(True, lambda x: x ** 2.0, lambda x: x ** 3.0, jnp.array([-2.]))``
``Array([4.], dtype=float32)``
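The benefit of `lax.cond` is that the predicate can itself be a traced value. This sketch rewrites the earlier `cond_func` so it compiles with `@jit` and no static arguments. Integer exponents are used here to keep the results exact in float32.

```python
from jax import jit, lax

@jit
def cond_func(x):
    # lax.cond(pred, true_fun, false_fun, operand): both branches are traced
    # and the traced predicate selects one of them at run time
    return lax.cond(x < 0.0, lambda x: x ** 2, lambda x: x ** 3, x)

print(cond_func(2.0))   # 8.0
print(cond_func(-2.0))  # 4.0
```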

## Control flow primitives

Other control flow primitives:

• `lax.while_loop`
• `lax.fori_loop`
• `lax.scan`

Other pseudo-dynamic control flow functions:

• `lax.select` (NumPy API `jnp.where` and `jnp.select`)
• `lax.switch` (NumPy API `jnp.piecewise`)
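The following sketch shows two of these primitives. `lax.fori_loop` stays a compiled loop even when its bound is traced. `jnp.where` computes both expressions for every element and keeps one per element. The function name `sum_of_squares` is illustrative.

```python
import jax.numpy as jnp
from jax import jit, lax

@jit
def sum_of_squares(n):
    # lax.fori_loop(lower, upper, body_fun, init_val), with body_fun(i, carry) -> carry
    return lax.fori_loop(0, n, lambda i, acc: acc + i ** 2, jnp.int32(0))

print(sum_of_squares(4))  # 0 + 1 + 4 + 9 = 14

x = jnp.array([-2.0, 1.0, 3.0])
# Elementwise choice between two precomputed expressions
y = jnp.where(x < 0.0, x ** 2, x ** 3)
print(y)  # values: 4, 1, 27
```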

## Static vs traced operations

Similarly, you can mark problematic operations as static so that they don’t get traced during JIT compilation:

```python
@jit
def f(x):
    return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
print(f(x))
```
``TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>]``

## Static vs traced operations

The problem here is that `jnp.array(x.shape).prod()` turns the (static) shape of `x` into a traced array, while `reshape()` needs a concrete integer at compilation time

One solution is to use the NumPy version of `prod()`:

```python
import numpy as np

@jit
def f(x):
    return x.reshape(np.prod(x.shape))

print(f(x))
```
``[1. 1. 1. 1. 1. 1.]``
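Under `jit`, `x.shape` and `x.size` are plain Python values, so any computation that stays in Python or NumPy remains static. Here are two equivalent alternatives as a sketch. The function names are illustrative.

```python
import math

import jax.numpy as jnp
from jax import jit

@jit
def flatten(x):
    # x.size is a static Python int even while tracing
    return x.reshape(x.size)

@jit
def flatten_alt(x):
    # math.prod of the static shape tuple is also a Python int
    return x.reshape(math.prod(x.shape))

x = jnp.ones((2, 3))
print(flatten(x))
print(flatten_alt(x))
```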

# Functionally pure functions

## Jaxprs

```python
import jax

x = jnp.array([1., 4., 3.])
y = jnp.array([8., 1., 2.])

def f(x, y):
    return 2 * x**2 + y

jax.make_jaxpr(f)(x, y)
```
```
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = integer_pow[y=2] a
    d:f32[3] = mul 2.0 c
    e:f32[3] = add d b
  in (e,) }
```

## Outputs only based on inputs

```python
def f(x):
    return a + x
```

`f` uses the variable `a` from the global environment

The output does not depend solely on the inputs, so `f` is not a pure function

## Outputs only based on inputs

```python
a = jnp.ones(3)
print(a)
```
``[1. 1. 1.]``
```python
def f(x):
    return a + x

print(jit(f)(jnp.ones(3)))
```
``[2. 2. 2.]``

The result looks correct here because this is the first run: `f` is traced and the current value of `a` is captured

## Outputs only based on inputs

Now, let’s change the value of `a` to an array of zeros:

```python
a = jnp.zeros(3)
print(a)
```
``[0. 0. 0.]``

And rerun the same code:

``print(jit(f)(jnp.ones(3)))``
``[2. 2. 2.]``

The cached compiled program, in which the old value of `a` is baked in, is run and we get a wrong result

## Outputs only based on inputs

The new value for `a` will only take effect if we re-trigger tracing by changing the shape and/or dtype of `x`:

```python
a = jnp.zeros(4)
print(a)
```
``[0. 0. 0. 0.]``
``print(jit(f)(jnp.ones(4)))``
``[1. 1. 1. 1.]``

Passing an argument of a different shape to `f()` forced retracing
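The robust fix is to make the function pure by passing `a` in as an argument. Then every call uses the current value of `a`, with no dependence on retracing. A minimal sketch:

```python
import jax.numpy as jnp
from jax import jit

@jit
def f(a, x):
    # a is now an explicit input, so each call uses its current value
    return a + x

print(f(jnp.ones(3), jnp.ones(3)))   # [2. 2. 2.]
print(f(jnp.zeros(3), jnp.ones(3)))  # [1. 1. 1.]
```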

## No side effects

Side effects: anything besides the returned output

Examples:

• Printing to standard output
• Reading from file/writing to file
• Modifying a global variable

## No side effects

Side effects happen during tracing, but not on subsequent runs of the compiled code. You cannot rely on side effects in your code

```python
def f(a, b):
    print("Calculating sum")
    return a + b

print(jit(f)(jnp.arange(3), jnp.arange(3)))
```
```
Calculating sum
[0 2 4]
```

Printing happened here because this is the first run

## No side effects

Let’s rerun the function:

``print(jit(f)(jnp.arange(3), jnp.arange(3)))``
``[0 2 4]``

This time, no printing
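For debugging output that should appear on every call, JAX provides `jax.debug.print`. It is staged into the compiled program rather than run only during tracing. A minimal sketch:

```python
import jax
import jax.numpy as jnp
from jax import jit

@jit
def f(a, b):
    # Part of the compiled program: prints on every call, with runtime values
    jax.debug.print("Calculating sum of {a} and {b}", a=a, b=b)
    return a + b

result = f(jnp.arange(3), jnp.arange(3))
result = f(jnp.arange(3), jnp.arange(3))  # prints again
```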

# Other transformations

## Automatic differentiation

Considering the function `f`:

``f = lambda x: x**3 + 2*x**2 - 3*x + 8``

We can create a new function `dfdx` that computes the gradient of `f` w.r.t. `x`:

```python
from jax import grad

dfdx = grad(f)
```

`dfdx` returns the derivative of `f` at a given point:

``print(dfdx(1.))``
``4.0``
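`grad` differentiates with respect to the first argument by default. The `argnums` parameter selects other arguments, and `value_and_grad` returns the function value together with the gradient. The following sketch uses an illustrative `loss` function.

```python
from jax import grad, value_and_grad

def loss(w, b):
    return (2.0 * w + b - 1.0) ** 2

# Gradient with respect to both arguments
dloss = grad(loss, argnums=(0, 1))
gw, gb = dloss(1.0, 0.0)  # 4.0, 2.0

# Value and gradient (w.r.t. the first argument) in one call
value, gw_again = value_and_grad(loss)(1.0, 0.0)  # 1.0, 4.0
print(gw, gb, value, gw_again)
```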

## Composing transformations

Transformations can be composed:

``print(jit(grad(f))(1.))``
``4.0``
``print(grad(jit(f))(1.))``
``4.0``

## Forward and reverse modes

Other autodiff methods:

• Reverse-mode vector-Jacobian products: `jax.vjp`
• Forward-mode Jacobian-vector products: `jax.jvp`
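For the scalar function `f` above, forward and reverse modes both recover f'(1) = 4. Here is a short sketch of the two calling conventions.

```python
import jax.numpy as jnp
from jax import jvp, vjp

f = lambda x: x**3 + 2*x**2 - 3*x + 8

# Forward mode: push a tangent through f, returns (f(x), f'(x) * tangent)
y, tangent_out = jvp(f, (1.0,), (1.0,))

# Reverse mode: returns f(x) and a function that pulls back cotangents
y2, f_vjp = vjp(f, 1.0)
(cotangent_out,) = f_vjp(jnp.ones_like(y2))

print(y, tangent_out, cotangent_out)
```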

## Higher-order differentiation

With a single variable, the `grad` function calls can be nested:

```python
d2fdx = grad(dfdx)   # function to compute 2nd order derivatives
d3fdx = grad(d2fdx)  # function to compute 3rd order derivatives
# and so on
```

With several variables:

• `jax.jacfwd` for forward-mode
• `jax.jacrev` for reverse-mode
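Below is a sketch of both cases: nested `grad` for a scalar function and `jacfwd`/`jacrev` for an illustrative function `g` from R² to R². Composing `jacfwd(jacrev(...))` (forward-over-reverse) is a common way to get a Hessian.

```python
import jax.numpy as jnp
from jax import grad, jacfwd, jacrev

f = lambda x: x**3 + 2*x**2 - 3*x + 8
d2fdx = grad(grad(f))
print(d2fdx(1.0))  # f''(x) = 6x + 4, so 10.0 at x = 1

def g(v):
    # Illustrative function from R^2 to R^2
    return jnp.array([v[0] ** 2 * v[1], 5.0 * v[0] + jnp.sin(v[1])])

v = jnp.array([1.0, 2.0])
J_fwd = jacfwd(g)(v)  # forward mode: cheaper when there are few inputs
J_rev = jacrev(g)(v)  # reverse mode: cheaper when there are few outputs

# Hessian of a scalar function via forward-over-reverse
H = jacfwd(jacrev(lambda v: jnp.sum(v ** 3)))(v)
print(J_fwd, J_rev, H, sep="\n")
```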

## Pytrees

JAX works with nested container structures called pytrees (e.g. nested lists, tuples, and dicts of arrays), which are extremely useful for DNNs
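For example, model parameters can be held in a dict. Tree utilities and transformations such as `grad` then work on the whole structure at once. This is a minimal sketch. The parameter names and the toy `loss` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# A pytree of parameters (dict of arrays)
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# Apply a function to every leaf, keeping the structure
scaled = jax.tree_util.tree_map(lambda p: 2 * p, params)
leaves = jax.tree_util.tree_leaves(params)

def loss(params, x):
    return jnp.sum(x @ params["w"] + params["b"])

# The gradient has the same structure as params
grads = jax.grad(loss)(params, jnp.ones((4, 2)))
print(grads["w"].shape, grads["b"].shape)
```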

## Vectorization and parallelization

Other transformations run computations in parallel across batches of arrays:

• Automatic vectorization with `jax.vmap`
• Parallelization across devices with `jax.pmap`
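`vmap` turns a function written for single examples into one that works on a batch. The following sketch uses an illustrative per-example `dot`. `in_axes` controls which arguments are batched, with `None` meaning "not batched".

```python
import jax.numpy as jnp
from jax import vmap

def dot(u, v):
    # Written for single vectors
    return jnp.dot(u, v)

us = jnp.arange(12.0).reshape(4, 3)
vs = jnp.ones((4, 3))

batched_dot = vmap(dot)  # maps over axis 0 of both arguments
print(batched_dot(us, vs))  # values: 3, 12, 21, 30

dot_with_fixed_v = vmap(dot, in_axes=(0, None))  # v is shared across the batch
print(dot_with_fixed_v(us, jnp.ones(3)))
```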

# Pushing optimizations further

## Lax API

`jax.numpy` is a high-level NumPy-like API wrapped around `jax.lax`

`jax.lax` is a stricter, lower-level, and more efficient API, itself wrapped around XLA