Benchmarking JAX code

Author

Marie-Hélène Burle

When benchmarking JAX code, take care to measure the computation time and not just the dispatch time.

Asynchronous dispatch

One source of JAX's efficiency is its use of asynchronous execution.

Let’s consider the code:

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(11)
key, subkey1, subkey2 = random.split(key, 3)

x = random.normal(subkey1, (1000, 1000))
y = random.normal(subkey2, (1000, 1000))

z = jnp.dot(x, y)

Python does not have to wait for this computation to complete before control returns to it: the computation is dispatched to an accelerator and a future is created. This future is a jax.Array and can be passed to further computations immediately.

Of course, if you print the result or convert it to a NumPy ndarray, JAX forces Python to wait for the result of the computation.
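A minimal sketch of this behaviour, assuming the same x and y as above: it records the elapsed time right after the call returns and again after converting the result to a NumPy array. The variable names dispatch_time and total_time are only illustrative, and the gap between the two depends on your hardware (on a CPU it may be negligible).

```python
import time

import jax.numpy as jnp
import numpy as np
from jax import random

key = random.PRNGKey(11)
key, subkey1, subkey2 = random.split(key, 3)

x = random.normal(subkey1, (1000, 1000))
y = random.normal(subkey2, (1000, 1000))

# Warm-up call so that one-time compilation is not part of the timings below.
jnp.dot(x, y).block_until_ready()

start = time.perf_counter()
z = jnp.dot(x, y)  # with asynchronous dispatch, this can return before the result exists
dispatch_time = time.perf_counter() - start

z_np = np.asarray(z)  # converting to a NumPy ndarray waits for the result
total_time = time.perf_counter() - start

print(f"after the call returns:  {dispatch_time * 1e6:.0f} µs")
print(f"after NumPy conversion:  {total_time * 1e6:.0f} µs")
```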

Consequence for benchmarking

Timing jnp.dot(x, y) therefore does not give us the time it takes for the computation to take place, but only the time it takes to dispatch it.

On my laptop, which has one dedicated GPU, I get:

%timeit jnp.dot(x, y)
496 µs ± 948 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

%timeit is an IPython built-in magic command. In a plain Python script (outside IPython or Jupyter), you would have to use the timeit module instead.
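As a rough sketch, the same measurement with the timeit module might look like the following. The repeat and number values are arbitrary choices, kept small here (%timeit picks its loop count automatically). Like the %timeit call above, this times the bare jnp.dot(x, y) call.

```python
import statistics
import timeit

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(11)
key, subkey1, subkey2 = random.split(key, 3)

x = random.normal(subkey1, (1000, 1000))
y = random.normal(subkey2, (1000, 1000))

number = 20  # calls per run (arbitrary)
repeat = 5   # number of runs (arbitrary)

# Each entry of `runs` is the total time, in seconds, for `number` calls.
runs = timeit.repeat(lambda: jnp.dot(x, y), repeat=repeat, number=number)

per_loop = [t / number for t in runs]
print(
    f"{statistics.mean(per_loop) * 1e6:.0f} µs ± "
    f"{statistics.stdev(per_loop) * 1e6:.0f} µs per loop "
    f"({repeat} runs, {number} loops each)"
)
```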

To get a proper timing, we need to make sure that the future is resolved by calling the block_until_ready method on the result.

On the same machine:

%timeit jnp.dot(x, y).block_until_ready()
598 µs ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

The difference here is not huge because the GPU executes the matrix multiplication rapidly. Nevertheless, this is the true timing. If you benchmark your JAX code, make sure to do it this way.

If you are running small computations such as this one without an accelerator, JAX may run the computation directly on the thread running the Python process, because the overhead of asynchronous execution would be larger than the speedup you would gain from it. This means that, if you are running the above code on a CPU, you may well get the same time with and without block_until_ready.

Nevertheless, because it is difficult to predict when the dispatch will be asynchronous, you should always use block_until_ready in your benchmarks.
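One way to make this a habit is a small timing helper that always waits for the result. The sketch below is only an illustration: bench is a hypothetical name, not part of JAX, and its repeat and number defaults are arbitrary. It uses jax.block_until_ready, which waits on every array in whatever the function returns, and makes one warm-up call first so that one-time compilation is not counted.

```python
import timeit

import jax
import jax.numpy as jnp
from jax import random


def bench(fn, *args, repeat=5, number=10):
    """Hypothetical helper: best per-call time in seconds for fn(*args)."""
    # Warm-up call: keeps one-time compilation out of the measurement.
    jax.block_until_ready(fn(*args))
    runs = timeit.repeat(
        # Wait for the result inside the timed call, so that the computation
        # itself is measured and not only its dispatch.
        lambda: jax.block_until_ready(fn(*args)),
        repeat=repeat,
        number=number,
    )
    return min(runs) / number


key = random.PRNGKey(11)
key, subkey1, subkey2 = random.split(key, 3)

x = random.normal(subkey1, (1000, 1000))
y = random.normal(subkey2, (1000, 1000))

t = bench(jnp.dot, x, y)
print(f"{t * 1e6:.0f} µs per call")
```

The helper reports the minimum over the runs, as the timeit documentation suggests, whereas %timeit reports a mean and standard deviation, so the numbers are not directly comparable.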