Parallel computing

Author

Marie-Hélène Burle

With performance in mind, JAX is built for parallel computing at all levels. This section gives an overview of the various parallel implementations.

Asynchronous dispatch

One of the ways JAX gains efficiency is through asynchronous dispatch of computations.

Advantage

Let’s consider the code:

import jax.numpy as jnp
from jax import random

x = random.normal(random.PRNGKey(0), (1000, 1000))
y = random.normal(random.PRNGKey(0), (1000, 1000))
z = jnp.dot(x, y)

Instead of making Python wait for the computation to complete, JAX dispatches the computation to the accelerator and creates a future. This future is a jax.Array and can be passed to further computations immediately.

Of course, if you print the result or convert it to a NumPy ndarray, then JAX forces Python to wait for the result of the computation.
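For instance, here is a minimal sketch reusing x and y from above:

import numpy as np

z = jnp.dot(x, y)   # control returns to Python immediately: z is a future
np.asarray(z)       # converting to NumPy forces Python to wait for the result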

Consequences for benchmarking

Timing jnp.dot(x, y) would not give us the time it takes for the computation to take place, but the time it takes to dispatch the computation.

On my laptop, running the computation on one GPU, I get:

import timeit

timeit.timeit("jnp.dot(x, y)",
              number=1000, globals=globals())/1000
0.0005148850770000308

To get a proper timing, we need to make sure that the future is resolved using the block_until_ready() method: jnp.dot(x, y).block_until_ready().

On the same machine:

timeit.timeit("jnp.dot(x, y).block_until_ready()",
              number=1000, globals=globals())/1000
0.0005967016279998916

The difference here is small because the GPU executes this matrix multiplication quickly. Nevertheless, only the second measurement gives the true timing. If you benchmark your JAX code, make sure to do it this way.

If you run small computations such as this one without an accelerator, the dispatch happens on the same thread because the overhead of asynchronous execution would be larger than the speedup. Nevertheless, because it is difficult to predict when dispatch will be asynchronous, you should always use block_until_ready() in your benchmarks.

Vectorization

Remember that JAX applies a number of transformations to jaxprs. We already saw one of JAX’s main transformations: JIT compilation with jax.jit(). Vectorization with jax.vmap() is another one.

It automates the vectorization of complex functions. Simple operations on arrays are naturally executed in a vectorized fashion (as is the case in R, NumPy, etc.), but more complex functions are not.

Here is an example from JAX 101 commonly encountered in deep learning:

import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
    output = []
    for i in range(1, len(x)-1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

convolve(x, w)
Array([11., 20., 29.], dtype=float32)

How can we apply the function convolve() to a batch of weights w and a batch of vectors x? Let’s first build the batches:

xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

We apply the jax.vmap() transformation to the convolve() function to create a new function. A tracing process is involved here too. We can then pass the batches to the new function:

auto_batch_convolve = jax.vmap(convolve)
auto_batch_convolve(xs, ws)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

Transformations can be composed:

jitted_batch_convolve = jax.jit(auto_batch_convolve)
jitted_batch_convolve(xs, ws)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

Data parallelism

The jax.pmap() transformation works in the same way, but each computation in the batch runs on a different device, which lets you scale things up dramatically.
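As a rough sketch, we can reuse convolve() from above. The leading axis of the batch must not exceed the number of available devices (on a single-device machine, this simply runs on that one device):

n = jax.local_device_count()

# Build batches whose leading axis matches the number of local devices
xs_devices = jnp.stack([x] * n)
ws_devices = jnp.stack([w] * n)

# Each element of the batch is processed on its own device
parallel_convolve = jax.pmap(convolve)
parallel_convolve(xs_devices, ws_devices)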

Sharding

JAX can also shard arrays across multiple devices: each device holds a piece of the array, and computations involving the array run on all those devices in parallel.
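Here is a minimal sketch using the jax.sharding module (the mesh axis name "i" and the array are arbitrary choices for illustration; with a single device, the whole array simply lives on that device):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over all available devices
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("i",))
sharding = NamedSharding(mesh, PartitionSpec("i"))

# Split the array along its first axis across the devices of the mesh
a = jnp.arange(len(devices) * 1000.)
a_sharded = jax.device_put(a, sharding)

# Operations on the sharded array run in parallel on the devices holding its pieces
jnp.sin(a_sharded)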

Multi-host communication

JAX alone does not scale up to multi-node clusters, but the mpi4jax extension provides this capability by adding MPI communication between hosts.