JIT compilation

Author

Marie-Hélène Burle

JIT compilation is a key component of JAX's efficiency. For the most part, it is very easy to use, but there are subtleties to be aware of.

JIT

JAX functions are already compiled and optimized, but user-defined functions can also be optimized for XLA through JIT compilation, which fuses computations.

Remember the map of how JAX works:

[Diagram: pure Python functions → tracing → jaxpr (JAX expression) intermediate representation (IR) → JIT compilation → high-level optimized (HLO) program → Accelerated Linear Algebra (XLA) → CPU / GPU / TPU. Jaxprs are also the input to transformations.]

This is done with the jax.jit() function or the equivalent @jit decorator.

Let’s consider this code:

import jax.numpy as jnp
from jax import jit
from jax import random

key = random.PRNGKey(8)
key, subkey1, subkey2 = random.split(key, 3)

a = random.normal(subkey1, (500, 500))
b = random.normal(subkey2, (500, 500))

def sum_squared_error(a, b):
    return jnp.sum((a-b)**2)

Our function can simply be used as:

print(sum_squared_error(a, b))

Our code will run faster if we create a JIT-compiled version and use that instead. (We will see how to benchmark JAX code later in the course; there are subtleties there too, so for now, just believe that it is faster. You will be able to test it yourself later.)

sum_squared_error_jit = jit(sum_squared_error)
print(sum_squared_error_jit(a, b))
502084.75

Alternatively, this can be written as:

print(jit(sum_squared_error)(a, b))
502084.75

Or as:

@jit
def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)

print(sum_squared_error(a, b))
502084.75

Understanding jaxprs

Let’s compare the jaxpr of our function with that of its jit-compiled version.

This is what the jaxpr of the non-jit-compiled function looks like:

import jax

# Small example arrays of shape (3,) keep the jaxpr readable
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])

def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)

jax.make_jaxpr(sum_squared_error)(x, y)
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = sub a b
    d:f32[3] = integer_pow[y=2] c
    e:f32[] = reduce_sum[axes=(0,)] d
  in (e,) }

The jaxpr of the jit-compiled function looks like this:

@jit
def sum_squared_error(a, b):
    return jnp.sum((a - b) ** 2)

jax.make_jaxpr(sum_squared_error)(x, y)
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[] = pjit[
      name=sum_squared_error
      jaxpr={ lambda ; d:f32[3] e:f32[3]. let
          f:f32[3] = sub d e
          g:f32[3] = integer_pow[y=2] f
          h:f32[] = reduce_sum[axes=(0,)] g
        in (h,) }
    ] a b
  in (c,) }

JIT constraints

Using jit in the example above was very easy. There are situations, however, in which tracing will fail.

Control flow

One example arises with control flow:

@jit
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

print(cond_func(1.0))
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function cond_func at jx_jit.qmd:85 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

JIT compilation traces the code based on the shape and dtype of its inputs, so that the same compiled code can be reused for new values with the same characteristics. The tracer objects are not real values but more general abstract representations. In a control flow situation such as the one above, an abstract value does not work because JAX cannot know which branch to take.
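
To make this concrete, here is a minimal sketch (not part of the original example) using a Python print inside a jitted function; the print runs during tracing only and reveals the abstract tracer instead of actual numbers:

@jit
def show_tracer(x):
    # Runs during tracing only; shows something like
    # Traced<ShapedArray(float32[2])> rather than the actual values
    print("x during tracing:", x)
    return x * 2.0

show_tracer(jnp.array([1.0, 2.0]))  # triggers tracing: the print runs
show_tracer(jnp.array([3.0, 4.0]))  # same shape and dtype: cached, no print

The second call reuses the cached compiled code because the shape and dtype match, so the side effect never runs again.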

Static variables

One solution is to tell jit() to exclude the problematic arguments (in our example, the argument x) from tracing, i.e. to treat them as static. Of course, those elements will not be optimized, but the rest of the code will be, so this is a lot better than not JIT compiling the function at all.

You can either use the static_argnums parameter, which takes an integer or a collection of integers specifying the positions of the arguments to treat as static:

def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_jit = jit(cond_func, static_argnums=(0,))

print(cond_func_jit(2.0))
print(cond_func_jit(-2.0))
8.0
4.0

Or you can use static_argnames which accepts argument names:

def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_jit_alt = jit(cond_func, static_argnames='x')

print(cond_func_jit_alt(2.0))
print(cond_func_jit_alt(-2.0))
8.0
4.0
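
The trade-off is that a static argument becomes part of the compilation cache key: the function is retraced and recompiled for every new value of that argument. A minimal sketch (hypothetical names; the print side effect is only there to reveal when tracing happens):

def cond_func_verbose(x):
    print("tracing for x =", x)  # runs only while JAX traces the function
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

cond_func_verbose_jit = jit(cond_func_verbose, static_argnums=(0,))

cond_func_verbose_jit(2.0)   # prints "tracing for x = 2.0"
cond_func_verbose_jit(2.0)   # cached: nothing printed
cond_func_verbose_jit(-2.0)  # new static value: retraces and prints again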

You cannot pass arguments such as static_argnums to the bare @jit decorator, but you can still use a decorator thanks to functools.partial:

from functools import partial

@partial(jit, static_argnums=(0,))
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

or:

@partial(jit, static_argnames=['x'])
def cond_func(x):
    if x < 0.0:
        return x ** 2.0
    else:
        return x ** 3.0

Control flow primitives

If you don’t want the code to recompile for each new value, another solution is to use one of the structured control flow primitives:

from jax import lax

lax.cond(False, lambda x: x ** 2.0, lambda x: x ** 3.0, jnp.array([2.]))  # False selects the second function
Array([8.], dtype=float32)
lax.cond(True, lambda x: x ** 2.0, lambda x: x ** 3.0, jnp.array([-2.]))  # True selects the first function
Array([4.], dtype=float32)
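
Because lax.cond traces both branches, it works inside a jitted function without marking anything as static. A minimal sketch of our earlier example rewritten with it:

@jit
def cond_func_lax(x):
    # Both branches are traced; the selection happens at run time
    return lax.cond(x < 0.0, lambda x: x ** 2.0, lambda x: x ** 3.0, x)

print(cond_func_lax(2.0))   # 8.0 (x >= 0: second function runs)
print(cond_func_lax(-2.0))  # 4.0 (x < 0: first function runs)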

There are other control flow primitives:

  • lax.while_loop
  • lax.fori_loop
  • lax.scan

and other pseudo-dynamic control flow functions (a short sketch follows the list):

  • lax.select (NumPy API jnp.where and jnp.select)
  • lax.switch (NumPy API jnp.piecewise)
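
As an illustrative sketch (values chosen arbitrarily), lax.fori_loop expresses a bounded loop that can be compiled without unrolling, and jnp.where performs element-wise branch selection:

@jit
def sum_first_integers():
    # Accumulate 0 + 1 + ... + 9 without a Python-level loop
    return lax.fori_loop(0, 10, lambda i, acc: acc + i, 0)

print(sum_first_integers())  # 45

values = jnp.array([-2.0, 2.0])
print(jnp.where(values < 0.0, values ** 2.0, values ** 3.0))  # [4. 8.]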

Static operations

Similarly, you will need to mark problematic operations as static so that they don’t get traced during JIT compilation:

@jit
def f(x):
    return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
print(f(x))
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

The problem here is that jnp.array(x.shape).prod() produces a traced value, while the shape passed to reshape must be static.

One solution is to use the NumPy version of prod, which operates on the static shape tuple and does not create a traced result:

import numpy as np

@jit
def f(x):
    return x.reshape(np.prod(x.shape))

print(f(x))
[1. 1. 1. 1. 1. 1.]
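
Two other options worth sketching (not from the original text): since x.shape is a plain Python tuple even during tracing, any computation on it outside of JAX stays static, and reshape(-1) sidesteps the issue entirely:

import math

@jit
def f_prod(x):
    # math.prod works on the static shape tuple and returns a plain int
    return x.reshape(math.prod(x.shape))

@jit
def f_flat(x):
    # -1 lets reshape infer the (static) flattened size
    return x.reshape(-1)

print(f_prod(x))  # [1. 1. 1. 1. 1. 1.]
print(f_flat(x))  # [1. 1. 1. 1. 1. 1.]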