Parallel computing


Marie-Hélène Burle

JAX is designed for deep neural networks and linear algebra at scale, so processing vast amounts of data in parallel is central to its goals. Two of JAX’s transformations make it easy to turn sequential code into parallel code.


Recall that JAX applies a number of transformations to jaxprs. We already saw two of its main transformations: JIT compilation with jax.jit and automatic differentiation with jax.grad. Vectorization with jax.vmap is another one.

It automates the vectorization of complex functions (operations on arrays are naturally executed in a vectorized fashion—as is the case in R, in NumPy, etc.—but more complex functions are not).

Here is an example from JAX 101 commonly encountered in deep learning:

import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
    output = []
    for i in range(1, len(x)-1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

convolve(x, w)
Array([11., 20., 29.], dtype=float32)

See this great post for explanations of convolutions.

You will probably want to apply the function convolve() to a batch of weights w and vectors x.

xs = jnp.stack([x, x, x])
ws = jnp.stack([w, w, w])

We apply the jax.vmap() transformation to the convolve() function and pass the batches to it:

vconvolve = jax.vmap(convolve)
vconvolve(xs, ws)
Array([[11., 20., 29.],
       [11., 20., 29.],
       [11., 20., 29.]], dtype=float32)
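
In practice, the same weights are often shared across a batch of inputs. The following is a minimal sketch of that case, assuming the in_axes argument of jax.vmap is used to map over the first axis of the inputs while leaving w unbatched. The names xs_varied, shared_w_convolve, batched and looped are illustrative and not part of the original example.

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
    # Same 3-point sliding dot product as in the example above
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

# Three different input vectors stacked along axis 0 (illustrative data)
xs_varied = jnp.stack([x, x + 1, x + 2])

# in_axes=(0, None): map over axis 0 of the first argument,
# and pass the second argument (w) unchanged to every call
shared_w_convolve = jax.vmap(convolve, in_axes=(0, None))
batched = shared_w_convolve(xs_varied, w)

# Reference result from an explicit Python loop over the batch
looped = jnp.stack([convolve(row, w) for row in xs_varied])

print(batched.shape)
print(jnp.allclose(batched, looped))
```

Setting an entry of in_axes to None avoids having to stack copies of w just to match the batch dimension.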

As we already saw, transformations can be composed:

vconvolve_jit = jax.jit(vconvolve)
vconvolve_jit(xs, ws)
Array([[11., 20., 29.],
       [11., 20., 29.],
       [11., 20., 29.]], dtype=float32)
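
The order of composition is flexible. As a small sketch (the names jit_of_vmap and vmap_of_jit are illustrative), the block below applies jax.jit and jax.vmap in both orders and checks that the results agree:

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
    # Same 3-point sliding dot product as in the example above
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
xs = jnp.stack([x, x, x])
ws = jnp.stack([w, w, w])

# JIT-compile the vectorized function
jit_of_vmap = jax.jit(jax.vmap(convolve))

# Vectorize the JIT-compiled function
vmap_of_jit = jax.vmap(jax.jit(convolve))

out_a = jit_of_vmap(xs, ws)
out_b = vmap_of_jit(xs, ws)

print(jnp.allclose(out_a, out_b))
```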

Parallel runs across devices

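The jax.pmap transformation does the same thing, but each computation runs on a different device (e.g. a different GPU) on the same node, which makes it possible to scale up further. The size of the mapped axis cannot exceed the number of available devices, so the following call requires at least three devices: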

jax.pmap(convolve)(xs, ws)

jax.pmap automatically JIT-compiles the function, so there is no need to wrap it in jax.jit.
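
Because jax.pmap places one slice of the mapped axis on each device, a batch of three fails on a machine with fewer than three devices. Here is a hedged sketch that sizes the batch from jax.local_device_count() so that it also runs on a single CPU or GPU. The variable names n_dev, xs_dev, ws_dev and out are illustrative.

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
    # Same 3-point sliding dot product as in the example above
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

# Number of devices attached to this process (1 on a plain CPU machine)
n_dev = jax.local_device_count()

# Build one row per device so the mapped axis matches the device count
xs_dev = jnp.stack([x] * n_dev)
ws_dev = jnp.stack([w] * n_dev)

# Each row is processed on its own device
out = jax.pmap(convolve)(xs_dev, ws_dev)

print(n_dev, out.shape)
```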

JAX can also distribute arrays across multiple devices through sharding.
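
As a rough illustration of sharding, the sketch below builds a one-dimensional device mesh with jax.sharding, splits the rows of an array across it, and runs a jitted function on the sharded array. The axis name "batch", the helper row_sums and the array sizes are arbitrary choices made for this example. On a machine with a single device, the array simply lives on that one device.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrange all available devices along one mesh axis named "batch"
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("batch",))

# Split the first array axis across the "batch" mesh axis;
# the second array axis is left unsplit
sharding = NamedSharding(mesh, PartitionSpec("batch"))

# The number of rows is a multiple of the device count so it divides evenly
n_dev = jax.device_count()
data = jnp.arange(n_dev * 2 * 4, dtype=jnp.float32).reshape(n_dev * 2, 4)
sharded = jax.device_put(data, sharding)

@jax.jit
def row_sums(a):
    # jit-compiled computation that operates on the sharded input
    return a.sum(axis=1)

result = row_sums(sharded)

print(sharded.sharding)
print(result.shape)
```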

To scale beyond a single node to multi-node clusters, JAX offers multi-process support through the jax.distributed module, and the mpi4jax extension provides MPI-based multi-host communication for distributed parallelism.