Marie-Hélène Burle
April 16, 2024
Library for Python developed by Google
Key data structure: Array
Composition, transformation, and differentiation of numerical programs
Compilation for CPUs, GPUs, and TPUs
NumPy-like and lower-level APIs
Requires strict functional programming
Summarized from a blog post by Chris Rackauckas
Install from pip wheels:
python -m pip install jax --no-index
Windows: GPU support only via WSL
NumPy is easy-going:
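For instance (one possible illustration, since the original example isn't shown), NumPy happily accepts a Python list where JAX requires an array:

import numpy as np
import jax.numpy as jnp

np.sum([1, 2, 3])    # works: returns 6
jnp.sum([1, 2, 3])   # raises a TypeError: jax.numpy functions only accept arrays and scalars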
NumPy will error if you index out of bounds:
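A sketch of the contrast (reusing the imports above; the values are assumptions):

np.arange(5)[10]     # IndexError: index 10 is out of bounds for axis 0 with size 5
jnp.arange(5)[10]    # no error: JAX clamps the index and returns the last element (4)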
Traditional pseudorandom number generators are based on a nondeterministic state obtained from the OS
This is slow and problematic for parallel execution
JAX relies on an explicitly-set random state called a key:
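A sketch of creating such a key (the seed 18 is inferred from the output below):

from jax import random

initial_key = random.PRNGKey(18)
print(initial_key)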
[ 0 18]
Each key can only be used for one random function call, but it can be split into new keys (after the split, initial_key can't be used anymore):
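A possible split producing the two keys printed below (the names new_key1 and new_key2 are assumptions):

new_key1, new_key2 = random.split(initial_key)
print(new_key1)
print(new_key2)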
[4197003906 1654466292]
[1685972163 1654824463]
We need to keep one key to split whenever we need new keys, and we can use the other one
To make sure we don’t reuse a key by accident, it is best to overwrite the initial key with one of the new ones
Here are easier names:
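A sketch of the usual convention (starting from a fresh key; the seed is arbitrary):

key = random.PRNGKey(0)
key, subkey = random.split(key)   # key is overwritten (kept for future splits); subkey is used now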
We can now use subkey to generate a random array:
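For example (the shape is an assumption):

x = random.normal(subkey, (3, 3))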
JAX uses asynchronous dispatch
Instead of waiting for a computation to complete before control returns to Python, the computation is dispatched to an accelerator and a future is created
To get proper timings, we need to make sure the future is resolved by using the block_until_ready() method
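A timing sketch (array size and names are assumptions):

import time
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
x = random.normal(key, (3000, 3000))

start = time.perf_counter()
jnp.dot(x, x).block_until_ready()   # wait until the computation has actually completed
print(f"{time.perf_counter() - start:.4f} s")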
import jax.numpy as jnp
from jax import jit, random

key = random.PRNGKey(8)
key, subkey1, subkey2 = random.split(key, 3)
a = random.normal(subkey1, (500, 500))
b = random.normal(subkey2, (500, 500))

def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)
Our function could simply be used as:
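Presumably a direct call:

sum_squared_error(a, b)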
Our code will run faster if we create a JIT compiled version and use that instead:
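A sketch of the compiled version:

sum_squared_error_jit = jit(sum_squared_error)
sum_squared_error_jit(a, b)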
Alternatively, this can be written as:
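Most likely the one-line form:

jit(sum_squared_error)(a, b)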
Or with the @jit decorator:
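A sketch of the decorated version:

@jit
def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)

JIT compilation fails, however, when Python control flow depends on the value of a traced argument. A sketch reproducing the error shown below (using the cond_func example defined further down):

def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

jit(cond_func)(2.0)   # raises jax.errors.TracerBoolConversionError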
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]
JIT compilation uses tracing of the code based on shape and dtype so that the same compiled code can be reused for new values with the same characteristics
Tracer objects are not real values but abstract representations that are more general
Here, an abstract general value does not work as JAX wouldn't know which branch to take
One solution is to tell jit() to exclude the problematic arguments from tracing, specifying them by position:
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_jit = jit(cond_func, static_argnums=(0,))

print(cond_func_jit(2.0))
print(cond_func_jit(-2.0))
8.0
4.0
Alternatively, we can tell jit() to exclude the problematic arguments from tracing, specifying them by name:
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_jit_alt = jit(cond_func, static_argnames="x")

print(cond_func_jit_alt(2.0))
print(cond_func_jit_alt(-2.0))
8.0
4.0
Another solution is to use one of the structured control flow primitives such as lax.cond:
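A sketch consistent with the outputs below (the original code isn't shown; lax.cond selects a branch based on a traced predicate):

from jax import lax

def cond_func_lax(x):
    return lax.cond(x[0] < 0.0,
                    lambda v: v ** 2.0,   # branch used when the predicate is True
                    lambda v: v ** 3.0,   # branch used when the predicate is False
                    x)

cond_func_lax_jit = jit(cond_func_lax)
cond_func_lax_jit(jnp.array([2.0]))
cond_func_lax_jit(jnp.array([-2.0]))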
Array([8.], dtype=float32)
Array([4.], dtype=float32)
Other control flow primitives:
lax.while_loop
lax.fori_loop
lax.scan
Other pseudo-dynamic control flow functions:
lax.select (NumPy API: jnp.where and jnp.select)
lax.switch (NumPy API: jnp.piecewise)
Similarly, you can mark problematic operations as static so that they don't get traced during JIT compilation:
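A sketch reproducing the error shown below, following the example in the JAX documentation (the original code isn't shown):

import jax.numpy as jnp
from jax import jit

@jit
def f(x):
    return x.reshape(jnp.array(x.shape).prod())   # the result of prod() is traced here

x = jnp.ones((2, 3))
f(x)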
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>]
The problem here is that the result of prod() is a traced (abstract) value, unknown at compilation time, while shapes must be concrete integer values
One solution is to use the NumPy version of prod():
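A sketch of the fix (np.prod runs at trace time on the concrete, static shape):

import numpy as np

@jit
def f(x):
    return x.reshape((np.prod(x.shape),))   # the new shape is a concrete integer

print(f(jnp.ones((2, 3))))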
[1. 1. 1. 1. 1. 1.]
import jax
import jax.numpy as jnp

x = jnp.array([1., 4., 3.])
y = jnp.array([8., 1., 2.])

def f(x, y):
    return 2 * x**2 + y

jax.make_jaxpr(f)(x, y)   # inspect the jaxpr (intermediate representation) obtained by tracing f
{ lambda ; a:f32[3] b:f32[3]. let
c:f32[3] = integer_pow[y=2] a
d:f32[3] = mul 2.0 c
e:f32[3] = add d b
in (e,) }
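The next example uses an impure function. A sketch consistent with the discussion and the outputs below (the exact definition isn't shown):

a = jnp.ones(3)      # global variable

@jit
def f(x):
    return x + a     # f reads a from the global environment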
f uses the variable a from the global environment
The output does not solely depend on the inputs: not a pure function
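Presumably, running it a first time (this triggers tracing):

print(a)
print(f(jnp.ones(3)))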
[1. 1. 1.]
[2. 2. 2.]
Things seem ok here because this is the first run (tracing)
Now, let’s change the value of a to an array of zeros:
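For instance:

a = jnp.zeros(3)
print(a)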
[0. 0. 0.]
And rerun the same code:
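That is:

print(f(jnp.ones(3)))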
[2. 2. 2.]
Our cached compiled program is run and we get a wrong result
The new value for a will only take effect if we re-trigger tracing by changing the shape and/or dtype of x:
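A sketch consistent with the outputs below (a is also given the new shape so that it broadcasts with x):

a = jnp.zeros(4)
print(a)
print(f(jnp.ones(4)))   # the new shape of x forces retracing, which picks up the new value of a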
[0. 0. 0. 0.]
[1. 1. 1. 1.]
Passing to f() an argument of a different shape forced retracing
Side effects: anything besides the returned output
Examples: printing, modifying a global variable
The side effects will happen during tracing, but not on subsequent runs. You cannot rely on side effects in your code
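A sketch consistent with the outputs below (the function body is an assumption; the printed message matches the output):

@jit
def f(x):
    print("Calculating sum")
    return x + x

print(f(jnp.arange(3)))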
Calculating sum
[0 2 4]
Printing happened here because this is the first run
Let’s rerun the function:
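That is, calling it again with the same input:

print(f(jnp.arange(3)))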
[0 2 4]
This time, no printing
Considering the function f:
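The definition isn't shown; a hypothetical f consistent with the outputs below:

from jax import grad, jit

def f(x):
    return x ** 2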
We can create a new function dfdx that computes the gradient of f w.r.t. x:
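Presumably:

dfdx = grad(f)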
dfdx returns the derivatives:
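For instance, evaluated at 2.0 (the value is inferred from the output below, assuming the hypothetical f above):

print(dfdx(2.0))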
4.0
Transformations can be composed:
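For example, combining jit and grad (consistent with the two outputs below):

print(jit(grad(f))(2.0))
print(grad(jit(f))(2.0))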
4.0
4.0
Other autodiff methods:
jax.vjp
jax.jvp
With a single variable, the grad function calls can be nested:
d2fdx = grad(dfdx) # function to compute 2nd order derivatives
d3fdx = grad(d2fdx) # function to compute 3rd order derivatives
...
With several variables:
jax.jacfwd for forward mode
jax.jacrev for reverse mode
JAX has a nested container structure (the pytree), extremely useful for DNNs
Other transformations allow running computations in parallel across batches of arrays (see the sketch after this list):
jax.vmap
jax.pmap
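A minimal sketch of jax.vmap (names and shapes are assumptions):

import jax.numpy as jnp
from jax import vmap, random

def dot(v1, v2):
    return jnp.dot(v1, v2)         # works on a single pair of vectors

batched_dot = vmap(dot)            # vectorized over a leading batch dimension

key = random.PRNGKey(0)
k1, k2 = random.split(key)
a = random.normal(k1, (10, 3))     # batch of 10 vectors of length 3
b = random.normal(k2, (10, 3))
print(batched_dot(a, b).shape)     # (10,)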
jax.numpy is a high-level NumPy-like API wrapped around jax.lax
jax.lax is a more efficient lower-level API, itself wrapped around XLA