Accelerated array computing and flexible differentiation with JAX
Content from the webinar slides for easier browsing.
Context
What is JAX?
- Library for Python developed by Google.
- Key data structure: Array.
- Composition, transformation, and differentiation of numerical programs.
- Compilation for CPUs, GPUs, and TPUs.
- NumPy-like and lower-level APIs.
- Requires strict functional programming.
Why JAX?
Summarized from a blog post by Chris Rackauckas.
Getting started
Installation
Install from pip wheels:
- Personal computer: use the wheel installation commands from the official site.
- Alliance clusters:
  python -m pip install jax --no-index
Windows: GPU support only via WSL.
The NumPy API
import numpy as np

print(np.array([(1, 2, 3), (4, 5, 6)]))
[[1 2 3]
 [4 5 6]]

print(np.arange(5))
[0 1 2 3 4]

print(np.zeros(2))
[0. 0.]

print(np.linspace(0, 2, 9))
[0.   0.25 0.5  0.75 1.   1.25 1.5  1.75 2.  ]
import jax.numpy as jnp

print(jnp.array([(1, 2, 3), (4, 5, 6)]))
[[1 2 3]
 [4 5 6]]

print(jnp.arange(5))
[0 1 2 3 4]

print(jnp.zeros(2))
[0. 0.]

print(jnp.linspace(0, 2, 9))
[0.   0.25 0.5  0.75 1.   1.25 1.5  1.75 2.  ]
But JAX NumPy is not NumPy…
Different types
type(np.zeros((2, 3)))
numpy.ndarray

type(jnp.zeros((2, 3)))
jaxlib.xla_extension.ArrayImpl
Different default data types
np.zeros((2, 3)).dtype
dtype('float64')

jnp.zeros((2, 3)).dtype
dtype('float32')
Float32 is the standard for deep learning and for libraries built for accelerators.
Float64 is very slow on GPUs and is not supported on TPUs.
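If double precision is genuinely needed (typically on CPU), 64-bit mode can be switched on through JAX's configuration system. A minimal sketch, assuming the standard jax_enable_x64 flag:

import jax
# Enable 64-bit types; must run before any arrays are created
# (assumes the standard "jax_enable_x64" configuration flag).
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
print(jnp.zeros(2).dtype)   # float64 once x64 mode is enabled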
Immutable arrays
a = np.arange(5)
a[0] = 9
print(a)
[9 1 2 3 4]
a = jnp.arange(5)
a[0] = 9
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable.

b = a.at[0].set(9)
print(b)
[9 1 2 3 4]
Strict input control
NumPy is easy-going:
np.sum([1.0, 2.0])   # argument is a list
np.float64(3.0)

np.sum((1.0, 2.0))   # argument is a tuple
np.float64(3.0)
To avoid inefficiencies, JAX will only accept arrays:
jnp.sum([1.0, 2.0])
TypeError: sum requires ndarray or scalar arguments, got <class 'list'>

jnp.sum((1.0, 2.0))
TypeError: sum requires ndarray or scalar arguments, got <class 'tuple'>
Out of bounds indexing
NumPy will error if you index out of bounds:
print(np.arange(5)[10])
IndexError: index 10 is out of bounds for axis 0 with size 5
JAX will silently return the closest boundary:
print(jnp.arange(5)[10])
4
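If you want explicit control over out-of-bounds reads, the indexed access syntax accepts options. A minimal sketch, assuming the mode and fill_value options of .at[...].get():

x = jnp.arange(5)
# Return a fill value instead of clamping to the boundary
# (assumes the "fill" mode of the .at[...].get() API).
print(x.at[10].get(mode="fill", fill_value=-1))
-1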
PRNG
Traditional pseudorandom number generators are based on the nondeterministic state of the OS.
This is slow and problematic for parallel execution.
JAX relies on explicitly-set random state called a key:
from jax import random
initial_key = random.PRNGKey(18)
print(initial_key)
[ 0 18]
Each key can only be used for one random function, but it can be split into new keys:
new_key1, new_key2 = random.split(initial_key)

initial_key can’t be used anymore now.

print(new_key1)
[4197003906 1654466292]

print(new_key2)
[1685972163 1654824463]
We keep one key to split again whenever we need more keys, and we use the other one.
To make sure we don’t reuse a key by accident, it is best to overwrite the initial key with one of the new ones.
Here are easier names:

key = random.PRNGKey(18)
key, subkey = random.split(key)

We can now use subkey to generate a random array:

x = random.normal(subkey, (3, 2))

Benchmarking
JAX uses asynchronous dispatch.
Instead of waiting for a computation to complete before control returns to Python, the computation is dispatched to an accelerator and a future is created.
To get proper timings, we need to make sure the future is resolved by using the block_until_ready() method.
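For example, a minimal timing sketch (the matrix product and sizes are arbitrary choices for illustration):

import time
from jax import random
import jax.numpy as jnp

key = random.PRNGKey(0)
x = random.normal(key, (3000, 3000))

start = time.perf_counter()
# block_until_ready() waits for the asynchronously dispatched
# computation to finish before we stop the timer.
jnp.dot(x, x).block_until_ready()
print(f"{time.perf_counter() - start:.4f} s")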
JAX functioning
JIT compilation
JIT syntax
from jax import jit
key = random.PRNGKey(8)
key, subkey1, subkey2 = random.split(key, 3)
a = random.normal(subkey1, (500, 500))
b = random.normal(subkey2, (500, 500))
def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)

Our function could simply be used as:

sse = sum_squared_error(a, b)

Our code will run faster if we create a JIT-compiled version and use that instead:

sum_squared_error_jit = jit(sum_squared_error)
sse = sum_squared_error_jit(a, b)

Alternatively, this can be written as:

sse = jit(sum_squared_error)(a, b)

Or with the @jit decorator:

@jit
def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)

sse = sum_squared_error(a, b)

Static vs traced variables
@jit
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

print(cond_func(1.0))
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]
JIT compilation uses tracing of the code based on shape and dtype so that the same compiled code can be reused for new values with the same characteristics.
Tracer objects are not real values but more general abstract representations.
Here, an abstract general value does not work because JAX wouldn’t know which branch to take.
One solution is to tell jit() to exclude the problematic arguments from tracing …
… by argument position:
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_jit = jit(cond_func, static_argnums=(0,))

print(cond_func_jit(2.0))
print(cond_func_jit(-2.0))
8.0
4.0
… by argument name:
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_jit_alt = jit(cond_func, static_argnames="x")

print(cond_func_jit_alt(2.0))
print(cond_func_jit_alt(-2.0))
8.0
4.0
Control flow primitives
Another solution is to use one of the structured control flow primitives:
from jax import lax
lax.cond(False, lambda x: x ** 2.0, lambda x: x ** 3.0, jnp.array([2.]))
Array([8.], dtype=float32)

lax.cond(True, lambda x: x ** 2.0, lambda x: x ** 3.0, jnp.array([-2.]))
Array([4.], dtype=float32)
Other control flow primitives:
- lax.while_loop
- lax.fori_loop (see the sketch after this list)
- lax.scan

Other pseudo-dynamic control flow functions:
- lax.select (NumPy API: jnp.where and jnp.select)
- lax.switch (NumPy API: jnp.piecewise)
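For instance, a minimal sketch of lax.fori_loop inside a jitted function (the cumulative sum is an arbitrary illustration):

from jax import jit, lax

@jit
def cumulative_sum(n):
    # lax.fori_loop(lower, upper, body_fun, init_val) runs body_fun
    # for i in [lower, upper) and stays traceable under jit.
    return lax.fori_loop(0, n, lambda i, total: total + i, 0)

print(cumulative_sum(10))
45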
Static vs traced operations
Similarly, you can mark problematic operations as static so that they don’t get traced during JIT compilation:
@jit
def f(x):
    return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
print(f(x))
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>]
The problem here is that the shape passed to reshape() is computed from a traced array (jnp.array(x.shape).prod()), and traced values are not known at compile time; reshape() needs concrete integers.
One solution is to use the NumPy version of prod():
import numpy as np

@jit
def f(x):
    return x.reshape(np.prod(x.shape))

print(f(x))
[1. 1. 1. 1. 1. 1.]
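An alternative not shown in the slides is to let reshape() infer the flattened size itself, which avoids the traced shape computation entirely:

@jit
def f(x):
    # -1 asks reshape to infer the flattened length, so no traced
    # value ends up in the shape argument.
    return x.reshape(-1)

print(f(x))
[1. 1. 1. 1. 1. 1.]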
Functionally pure functions
Jaxprs
import jax

x = jnp.array([1., 4., 3.])
y = jnp.array([8., 1., 2.])

def f(x, y):
    return 2 * x**2 + y

jax.make_jaxpr(f)(x, y)
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = integer_pow[y=2] a
    d:f32[3] = mul 2.0 c
    e:f32[3] = add d b
  in (e,) }
Outputs only based on inputs
def f(x):
    return a + x

f uses the variable a from the global environment.
The output does not solely depend on the inputs: not a pure function.
a = jnp.ones(3)
print(a)
[1. 1. 1.]

def f(x):
    return a + x

print(jit(f)(jnp.ones(3)))
[2. 2. 2.]
Things seem ok here because this is the first run (tracing).
Now, let’s change the value of a to an array of zeros:
a = jnp.zeros(3)
print(a)
[0. 0. 0.]
And rerun the same code:
print(jit(f)(jnp.ones(3)))
[2. 2. 2.]
Our cached compiled program is run and we get a wrong result.
The new value for a will only take effect if we re-trigger tracing by changing the shape and/or dtype of x:
a = jnp.zeros(4)
print(a)
[0. 0. 0. 0.]

print(jit(f)(jnp.ones(4)))
[1. 1. 1. 1.]
Passing an argument of a different shape to f() forced retracing.
No side effects
Side effects: anything a function does besides returning its output.
Examples:
- Printing to standard output
- Reading from file/writing to file
- Modifying a global variable
The side effects will happen during tracing, but not on subsequent runs. You cannot rely on side effects in your code.
def f(a, b):
    print("Calculating sum")
    return a + b

print(jit(f)(jnp.arange(3), jnp.arange(3)))
Calculating sum
[0 2 4]
Printing happened here because this is the first run.
Let’s rerun the function:
print(jit(f)(jnp.arange(3), jnp.arange(3)))
[0 2 4]
This time, no printing.
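If you do need to print traced values from inside a jitted function (e.g. for debugging), JAX provides jax.debug.print, which is designed to work under tracing. A minimal sketch:

import jax
from jax import jit

@jit
def f(a, b):
    # jax.debug.print handles traced values, unlike Python's print,
    # and runs on every call, not just during tracing.
    jax.debug.print("a = {a}", a=a)
    return a + b

f(jnp.arange(3), jnp.arange(3))
a = [0 1 2]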
Other transformations
Automatic differentiation
Consider the function f:

f = lambda x: x**3 + 2*x**2 - 3*x + 8

We can create a new function dfdx that computes the gradient of f w.r.t. x:

from jax import grad

dfdx = grad(f)

dfdx returns the derivative of f evaluated at its argument:

print(dfdx(1.))
4.0
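For a function of several variables, grad accepts an argnums argument specifying which arguments to differentiate with respect to (a sketch with an arbitrary two-variable function):

g = lambda x, y: x**2 + 3*y

# Differentiate with respect to both arguments.
dgdx, dgdy = grad(g, argnums=(0, 1))(2.0, 1.0)
print(dgdx, dgdy)
4.0 3.0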
Composing transformations
Transformations can be composed:
print(jit(grad(f))(1.))
4.0

print(grad(jit(f))(1.))
4.0
Forward and reverse modes
Other autodiff methods:
- Reverse-mode vector-Jacobian products: jax.vjp
- Forward-mode Jacobian-vector products: jax.jvp
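A minimal sketch of both, with an arbitrary scalar function and tangent/cotangent values:

import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) * x
x = jnp.array(1.0)
v = jnp.array(1.0)

# Forward mode: primal output and Jacobian-vector product.
y, jvp_out = jax.jvp(f, (x,), (v,))

# Reverse mode: primal output and a function computing vector-Jacobian products.
y, f_vjp = jax.vjp(f, x)
vjp_out, = f_vjp(v)

print(jvp_out, vjp_out)   # identical for a scalar-to-scalar function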
Higher-order differentiation
With a single variable, the grad function calls can be nested:
d2fdx = grad(dfdx)   # function to compute 2nd order derivatives
d3fdx = grad(d2fdx)  # function to compute 3rd order derivatives
...

With several variables:
- jax.jacfwd for forward mode
- jax.jacrev for reverse mode
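For example, a Hessian can be built by composing the two modes (a common pattern; the test function here is arbitrary):

from jax import jacfwd, jacrev

def h(x):
    return jnp.sum(x**3)

# Forward-over-reverse is a usual, efficient way to get a Hessian.
hessian = jacfwd(jacrev(h))
print(hessian(jnp.array([1.0, 2.0])))
[[ 6.  0.]
 [ 0. 12.]]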
Pytrees
JAX has a nested container structure, the pytree, which is extremely useful for deep neural networks.
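For instance, jax.tree_util.tree_map applies a function to every leaf of a nested structure (a sketch with a made-up parameter dictionary):

import jax
import jax.numpy as jnp

params = {"layer1": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)},
          "layer2": {"w": jnp.ones((2, 1)), "b": jnp.zeros(1)}}

# Apply the same operation to every leaf array of the pytree.
scaled = jax.tree_util.tree_map(lambda p: 0.1 * p, params)
print(scaled["layer1"]["w"])
[[0.1 0.1]
 [0.1 0.1]]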
Vectorization and parallelization
Other transformations run computations in parallel across batches of arrays:
- Automatic vectorization with jax.vmap
- Parallelization across devices with jax.pmap
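For example, jax.vmap turns a function written for single samples into one that works on a batch (a sketch; the dot product is arbitrary):

from jax import vmap, random
import jax.numpy as jnp

key = random.PRNGKey(0)
key1, key2 = random.split(key)
a = random.normal(key1, (10, 3))   # batch of 10 vectors
b = random.normal(key2, (10, 3))

# vmap maps the single-sample dot product over the leading axis.
batched_dot = vmap(jnp.dot)
print(batched_dot(a, b).shape)
(10,)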
Pushing optimizations further
Lax API
jax.numpy is a high-level NumPy-like API wrapped around jax.lax.
jax.lax is a more efficient lower-level API itself wrapped around XLA.
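The lower-level API is also stricter: jax.numpy promotes mixed types implicitly, while jax.lax requires explicit promotion (a sketch based on that documented difference):

import jax.numpy as jnp
from jax import lax

print(jnp.add(1, 1.0))               # implicit type promotion
2.0

# lax.add(1, 1.0) raises a type error: jax.lax expects matching
# dtypes, so we promote explicitly.
print(lax.add(jnp.float32(1), 1.0))
2.0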