JIT subtleties

Author

Marie-Hélène Burle

JIT compilation is a key component of what makes JAX efficient. For the most part, it is very easy to use, but there are subtleties to be aware of. We will explore them in this section.

JIT

Instead of executing computations one at a time, JAX can combine and optimize them through JIT (just-in-time) compilation before passing them to XLA.

This is done with the jax.jit() function or the equivalent @jit decorator.

Let’s consider this code:

import jax.numpy as jnp
from jax import jit
from jax import random

key = random.PRNGKey(8)
key, subkey1, subkey2 = random.split(key, 3)

a = random.normal(subkey1, (500, 500))
b = random.normal(subkey2, (500, 500))

def sum_squared_error(a, b):
    return jnp.sum((a-b)**2)

Our function can simply be used as:

print(sum_squared_error(a, b))

The code will run faster, however, if we create a JIT-compiled version of the function and use it instead (we will see how to benchmark JAX code later in the course):

sum_squared_error_jit = jit(sum_squared_error)
print(sum_squared_error_jit(a, b))
502084.75

Alternatively, this can be written as:

print(jit(sum_squared_error)(a, b))
502084.75

Or as:

@jit
def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)

print(sum_squared_error(a, b))
502084.75

This was very easy. There are situations, however, in which tracing fails.

JIT constraints

Static vs traced variables

One such situation can arise with control flow that depends on the value of an argument:

@jit
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

print(cond_func(1.0))
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function cond_func at jx_jit.qmd:85 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

What is going on?

JIT compilation traces the code based on the shape and dtype of the inputs so that the same compiled code can be reused for new values with the same characteristics. The tracer objects are not real values but more general abstract representations. In control flow situations such as the one we have here, an abstract general value does not work because there is no way to know which branch to take.
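To see what these tracers look like, here is a minimal sketch (the function show_tracer is purely illustrative): a print() inside a jitted function executes while the function is being traced and displays the abstract value rather than actual data:

from jax import jit
import jax.numpy as jnp

@jit
def show_tracer(x):
    # At trace time, x is an abstract value (something like
    # Traced<ShapedArray(float32[3])>): only its shape and dtype
    # are known, not its data.
    print("x seen during tracing:", x)
    return 2 * x

show_tracer(jnp.arange(3.0))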

The solution is to tell jit() not to trace some arguments (to consider them as static) by using the static_argnums parameter. This parameter takes an integer or a collection of integers to specify the position of the arguments to treat as static.

Here, our function accepts a single argument, so we will use static_argnums=(0,):

def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_jit = jit(cond_func, static_argnums=(0,))

print(cond_func_jit(2.0))
print(cond_func_jit(-2.0))
8.0
4.0

Alternatively, the arguments to treat as static can be specified by name with the static_argnames parameter:

def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_jit_alt = jit(cond_func, static_argnames="x")

print(cond_func_jit_alt(2.0))
print(cond_func_jit_alt(-2.0))
8.0
4.0
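As a side note, jit() keeps a separate compiled version for each value of a static argument, so a new value triggers retracing and recompilation. The sketch below (the name cond_func_verbose is ours, not from the code above) uses functools.partial to apply the decorator form with static_argnums, and a print() call, which only runs at trace time, to make the recompilations visible:

from functools import partial
from jax import jit

@partial(jit, static_argnums=(0,))
def cond_func_verbose(x):
    # Python side effects such as print() only run while the function
    # is traced, so this message signals a (re)compilation.
    print(f"tracing for x = {x}")
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

print(cond_func_verbose(2.0))   # traced and compiled, then prints 8.0
print(cond_func_verbose(2.0))   # cached version reused: only 8.0 is printed
print(cond_func_verbose(-2.0))  # new static value: traced and compiled again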

Static vs traced operations

Similarly, you will need to mark certain operations as static so that they don’t get traced during JIT compilation:

@jit
def f(x):
    return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
print(f(x))
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

The solution here is to use NumPy instead of JAX’s NumPy: plain NumPy operations are evaluated at trace time, so the shape computation produces a concrete (static) value:

import numpy as np

@jit
def f(x):
    return x.reshape((np.prod(x.shape),))

print(f(x))
[1. 1. 1. 1. 1. 1.]
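This works because x.shape is a plain Python tuple, so np.prod(x.shape) is evaluated at trace time and yields a concrete value that reshape() accepts. Since the shape is static anyway, an equivalent sketch (the name f_flat is ours) can simply let reshape() infer the flattened size with -1:

@jit
def f_flat(x):
    # -1 lets reshape infer the single dimension from the static shape.
    return x.reshape(-1)

print(f_flat(x))  # [1. 1. 1. 1. 1. 1.]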