Parallel computing
JAX is designed for deep neural networks and linear algebra at scale. Processing vast amounts of data in parallel is central to this goal. Two of JAX’s transformations make it very easy to turn linear (unbatched) code into parallel code.
Vectorization
Remember that JAX transformations are applied to jaxprs. We have already seen two of JAX’s main transformations: JIT compilation with jax.jit and automatic differentiation with jax.grad. Vectorization with jax.vmap is another one. It automates the vectorization of complex functions: operations on arrays are naturally executed in a vectorized fashion (as in R, NumPy, etc.), but more complex functions are not.
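For instance, elementwise arithmetic needs no special treatment (a trivial illustration):

import jax.numpy as jnp

a = jnp.arange(5)
a * 2 + 1   # already applied to every element at once
# [1, 3, 5, 7, 9]

Functions built out of indexing and Python loops, like the convolution below, are the ones that need jax.vmap.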
Here is an example from JAX 101 of an operation commonly encountered in deep learning:
import jax
import jax.numpy as jnp
x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
    output = []
    for i in range(1, len(x)-1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)
convolve(x, w)
Array([11., 20., 29.], dtype=float32)
See this great post for explanations of convolutions.
You will probably want to apply the function convolve() to a batch of weights w and vectors x.
xs = jnp.stack([x, x, x])
ws = jnp.stack([w, w, w])
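One option, before reaching for jax.vmap, is to loop over the batch manually. Here is a minimal sketch (the helper manually_batched_convolve is our own, not part of the JAX API):

def manually_batched_convolve(xs, ws):
    # Loop over the leading (batch) dimension in Python and stack the results.
    output = []
    for i in range(xs.shape[0]):
        output.append(convolve(xs[i], ws[i]))
    return jnp.stack(output)

manually_batched_convolve(xs, ws)
# Array([[11., 20., 29.],
#        [11., 20., 29.],
#        [11., 20., 29.]], dtype=float32)

This works, but the loop runs in Python, one example at a time; jax.vmap gives the same result as a single vectorized computation.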
We apply the jax.vmap() transformation to the convolve() function and pass the batches to it:
vconvolve = jax.vmap(convolve)
vconvolve(xs, ws)
Array([[11., 20., 29.],
[11., 20., 29.],
[11., 20., 29.]], dtype=float32)
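By default, jax.vmap maps over the leading axis of every argument. If you instead want to batch over x while reusing a single set of weights w, you can pass in_axes (a quick sketch):

# Map over the first axis of xs only; the same w is used for every example.
jax.vmap(convolve, in_axes=(0, None))(xs, w)
# Array([[11., 20., 29.],
#        [11., 20., 29.],
#        [11., 20., 29.]], dtype=float32)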
As we already saw, transformations can be composed:
vconvolve_jit = jax.jit(vconvolve)
vconvolve_jit(xs, ws)
Array([[11., 20., 29.],
[11., 20., 29.],
[11., 20., 29.]], dtype=float32)
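The composition is not limited to jax.jit and jax.vmap: jax.grad can join the mix. As a sketch (the scalar wrapper sum_convolve is our own), here is the gradient of the summed convolution with respect to w, vectorized over the batch and JIT compiled:

def sum_convolve(x, w):
    # Reduce to a scalar so that jax.grad applies.
    return jnp.sum(convolve(x, w))

# Gradient with respect to the second argument (w), vectorized and compiled.
grad_w = jax.jit(jax.vmap(jax.grad(sum_convolve, argnums=1)))
grad_w(xs, ws)
# Each row should equal the sum of the sliding windows of x, i.e. [3., 6., 9.]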
Parallel runs across devices
The jax.pmap transformation does the same thing, but each computation runs on a different device (e.g. a different GPU) of the same node, letting you scale things up further:
jax.pmap(convolve)(xs, ws)
jax.pmap automatically JIT compiles the code, so there is no need to also pass it through jax.jit.
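Note that jax.pmap maps over the leading axis of its inputs, and that axis cannot be larger than the number of available devices. A common pattern (sketched here with our own variable names) is to size the batch with jax.local_device_count():

n_devices = jax.local_device_count()

# One example per device: the leading axis must not exceed the device count.
xs_dev = jnp.stack([x] * n_devices)
ws_dev = jnp.stack([w] * n_devices)

jax.pmap(convolve)(xs_dev, ws_dev)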