JIT compilation
JIT compilation is a key component of JAX's efficiency. For the most part, it is very easy to use, but there are subtleties to be aware of.
JIT
JAX functions are already compiled and optimized, but user-defined functions can also be optimized for XLA through JIT compilation, which fuses the computations they contain. (Recall the diagram of how JAX works: Python code is traced into a jaxpr, which XLA then compiles.) This is done with the jax.jit() function or the equivalent @jit decorator.
Let’s consider this code:
import jax.numpy as jnp
from jax import jit
from jax import random
key = random.PRNGKey(8)
key, subkey1, subkey2 = random.split(key, 3)
a = random.normal(subkey1, (500, 500))
b = random.normal(subkey2, (500, 500))
def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)
Our function can simply be used as:
print(sum_squared_error(a, b))
Our code will run faster if we create a JIT-compiled version and use that instead. (We will see how to benchmark JAX code later in the course; there are subtleties there too, so for now, just take it on faith that it is faster. You will be able to test it yourself later.)
sum_squared_error_jit = jit(sum_squared_error)
print(sum_squared_error_jit(a, b))
502084.75
Alternatively, this can be written as:
print(jit(sum_squared_error)(a, b))
502084.75
Or as:
@jit
def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)

print(sum_squared_error(a, b))
502084.75
Understanding jaxprs
Let’s have a look at the jaxpr of a jit-compiled function.
This is what the jaxpr of the non-jit-compiled function looks like:
import jax

# sample inputs; only their shape and dtype matter for tracing
x = jnp.arange(3.)
y = jnp.arange(3.)

def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)

jax.make_jaxpr(sum_squared_error)(x, y)
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = sub a b
    d:f32[3] = integer_pow[y=2] c
    e:f32[] = reduce_sum[axes=(0,)] d
  in (e,) }
The jaxpr of the jit-compiled function looks like this:
@jit
def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)

jax.make_jaxpr(sum_squared_error)(x, y)
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[] = pjit[
      name=sum_squared_error
      jaxpr={ lambda ; d:f32[3] e:f32[3]. let
          f:f32[3] = sub d e
          g:f32[3] = integer_pow[y=2] f
          h:f32[] = reduce_sum[axes=(0,)] g
        in (h,) }
    ] a b
  in (c,) }
JIT constraints
Using jit in the example above was very easy. There are situations, however, in which tracing will fail.
Control flow
One example can arise with control flow:
@jit
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

print(cond_func(1.0))
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function cond_func at jx_jit.qmd:85 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
JIT compilation traces the code based on shape and dtype, so that the same compiled code can be reused for new values with the same characteristics. The tracer objects are not real values but abstract representations that are more general. In control flow situations such as the one we have here, an abstract general value does not work, as it wouldn't know which branch to take.
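To see what a tracer looks like, you can print an argument from inside a jit-compiled function. A minimal sketch (the function name is just for illustration, and the exact tracer representation varies across JAX versions); note that the Python print only runs during tracing, so a second call with the same shape and dtype hits the compilation cache and prints nothing:

@jit
def show_tracer(x):
    # x is an abstract tracer here, not a concrete array
    print("x during tracing:", x)
    return x * 2

show_tracer(jnp.arange(3.))       # prints: x during tracing: Traced<ShapedArray(float32[3])>...
show_tracer(jnp.arange(3.) + 1.)  # same shape/dtype: cached, prints nothing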
Static variables
One solution is to tell jit() to exclude the problematic arguments (in our example, the argument x) from tracing, i.e. to treat them as static. Those arguments will not be optimized, of course, but the rest of the code will be, so this is a lot better than not JIT compiling the function at all. Note that the function gets recompiled whenever a static argument takes a value it hasn't seen before.
You can either use the static_argnums parameter, which takes an integer or a collection of integers specifying the positions of the arguments to treat as static:
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_jit = jit(cond_func, static_argnums=(0,))

print(cond_func_jit(2.0))
print(cond_func_jit(-2.0))
8.0
4.0
Or you can use static_argnames, which accepts argument names:
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_jit_alt = jit(cond_func, static_argnames='x')

print(cond_func_jit_alt(2.0))
print(cond_func_jit_alt(-2.0))
8.0
4.0
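Keep in mind that static arguments become part of the compilation cache key: the function is recompiled for every new static value it encounters. A minimal sketch illustrating this (the function name and print marker are just for illustration):

def cube(x):
    print("compiling...")  # only runs while JAX traces the function
    return x ** 3.0

cube_jit = jit(cube, static_argnums=(0,))

print(cube_jit(2.0))  # first call: compiles, prints "compiling..." then 8.0
print(cube_jit(2.0))  # same static value, cached: prints only 8.0
print(cube_jit(3.0))  # new static value, recompiles: "compiling..." then 27.0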
You cannot use the @jit decorator directly when you need to pass arguments to the jit function, but you can still use a decorator via functools.partial:
from functools import partial

@partial(jit, static_argnums=(0,))
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0
or:
@partial(jit, static_argnames=['x'])
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0
Control flow primitives
If you don’t want the code to recompile for each new value, another solution is to use one of the structured control flow primitives:
from jax import lax
lax.cond(False, lambda x: x ** 2.0, lambda x: x ** 3.0, jnp.array([2.]))
Array([8.], dtype=float32)
lax.cond(True, lambda x: x ** 2.0, lambda x: x ** 3.0, jnp.array([-2.]))
Array([4.], dtype=float32)
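Unlike a Python if, lax.cond accepts a traced predicate, so it also works inside a jit-compiled function without marking anything as static. A minimal sketch (the function name is just for illustration):

@jit
def cond_func_lax(x):
    # both branches are compiled; which one runs is decided at run time
    return lax.cond(x < 0.0, lambda x: x ** 2.0, lambda x: x ** 3.0, x)

print(cond_func_lax(2.0))   # 8.0
print(cond_func_lax(-2.0))  # 4.0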
There are other control flow primitives (a short sketch of two of them follows after this list):

lax.while_loop
lax.fori_loop
lax.scan

and other pseudo-dynamic control flow functions:

lax.select (NumPy API: jnp.where and jnp.select)
lax.switch (NumPy API: jnp.piecewise)
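As a minimal sketch of the loop primitives (variable and function names here are just for illustration): lax.fori_loop runs a fixed number of iterations while threading a carry value through the body, and lax.scan additionally stacks a per-step output:

# running total 0 + 1 + ... + 9 with lax.fori_loop;
# the body takes (loop index, carry) and returns the new carry
total = lax.fori_loop(0, 10, lambda i, acc: acc + i, 0)
print(total)  # 45

# cumulative sum with lax.scan; the step function takes (carry, x)
# and returns (new carry, per-step output)
def step(carry, x):
    carry = carry + x
    return carry, carry

final, cumsum = lax.scan(step, 0.0, jnp.arange(5.0))
print(final)   # 10.0
print(cumsum)  # [ 0.  1.  3.  6. 10.]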
Static operations
Similarly, you will need to mark problematic operations as static so that they don’t get traced during JIT compilation:
@jit
def f(x):
    return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
print(f(x))
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The problem here is that the new shape passed to reshape is computed with jnp, so it is traced, while shapes must be static.
One solution is to use the NumPy version of prod, which runs at trace time and will not create a traced result:
import numpy as np

@jit
def f(x):
    return x.reshape(np.prod(x.shape))

print(f(x))
[1. 1. 1. 1. 1. 1.]
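Another fix worth noting (a sketch with a hypothetical function name): reshape accepts -1 for a dimension to be inferred, and since -1 is a static literal and the inference happens at trace time from the array's static shape, no traced value is involved:

@jit
def f_alt(x):
    # -1 is a static literal; the flattened size is inferred at trace
    # time from x's (static) shape, so nothing traced reaches reshape
    return x.reshape(-1)

print(f_alt(jnp.ones((2, 3))))  # [1. 1. 1. 1. 1. 1.]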