# Parallel computing

With performance in mind, JAX is built for parallel computing at all levels. This section gives an overview of the various parallel implementations.

## Asynchronous dispatch

One of the efficiencies of JAX is its use of asynchronous execution.

### Advantage

Let’s consider the code:

```
import jax.numpy as jnp
from jax import random

x = random.normal(random.PRNGKey(0), (1000, 1000))
y = random.normal(random.PRNGKey(0), (1000, 1000))
z = jnp.dot(x, y)
```

Instead of having to wait for the computation to complete before control returns to Python, this computation is dispatched to an accelerator and a future is created. This future is a jax.Array and can be passed to further computations immediately.

Of course, if you print the result or convert it to a NumPy ndarray, then JAX forces Python to wait for the result of the computation.
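For example (a self-contained sketch; converting with `np.asarray()` is one way to force the wait):

```
import numpy as np
import jax.numpy as jnp
from jax import random

x = random.normal(random.PRNGKey(0), (1000, 1000))
z = jnp.dot(x, x)        # returns immediately; z is a future-like jax.Array
z2 = jnp.dot(z, x)       # z can be fed into further computations right away
result = np.asarray(z2)  # conversion to NumPy blocks until the value is ready
```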

### Consequence on benchmarking

Timing `jnp.dot(x, y)` would not give us the time it takes for the computation to take place, but the time it takes to dispatch the computation.

On my laptop, running the computation on one GPU, I get:

```
import timeit

timeit.timeit("jnp.dot(x, y)", number=1000, globals=globals())/1000
```

`0.0005148850770000308`

To get a proper timing, we need to make sure that the future is resolved by calling the `block_until_ready()` method: `jnp.dot(x, y).block_until_ready()`.

On the same machine:

```
timeit.timeit("jnp.dot(x, y).block_until_ready()", number=1000, globals=globals())/1000
```

`0.0005967016279998916`

The difference here is not huge because the GPU executes the matrix multiplication rapidly. Nevertheless, this is the true timing. If you benchmark your JAX code, make sure to do it this way.

If you are running small computations such as this one without an accelerator, the dispatch happens on the same thread, because the overhead of asynchronous execution would be larger than any speedup. Nevertheless, because it is difficult to predict when the dispatch will be asynchronous, you should always use `block_until_ready()` in your benchmarks.
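This pattern can be wrapped in a small helper (a sketch; `bench` is a hypothetical name, and the module-level `jax.block_until_ready()` is used so the helper also works for functions returning pytrees):

```
import timeit

import jax
import jax.numpy as jnp
from jax import random

x = random.normal(random.PRNGKey(0), (1000, 1000))
y = random.normal(random.PRNGKey(1), (1000, 1000))

def bench(fn, number=100):
    # Hypothetical helper: average wall time per call, waiting for the
    # result so we measure the computation, not just the dispatch.
    return timeit.timeit(lambda: jax.block_until_ready(fn()), number=number) / number

bench(lambda: jnp.dot(x, y))
```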

## Vectorization

Remember that JAX works by applying transformations to jaxprs. We already saw one of JAX’s main transformations: JIT compilation with `jax.jit()`. Vectorization with `jax.vmap()` is another one.

It automates the vectorization of complex functions (operations on arrays are naturally executed in a vectorized fashion—as is the case in R, in NumPy, etc.—but more complex functions are not).

Here is an example from JAX 101 commonly encountered in deep learning:

```
import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1:i + 2], w))
    return jnp.array(output)

convolve(x, w)
```

`Array([11., 20., 29.], dtype=float32)`

How can you apply the function `convolve()` to a batch of weights `w` and vectors `x`? First, we build the batches:

```
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])
```

We apply the `jax.vmap()` transformation to the `convolve()` function to create a new function. A tracing process is involved here too. We can then pass the batches to the new function:

```
auto_batch_convolve = jax.vmap(convolve)
auto_batch_convolve(xs, ws)
```

```
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)
```
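`jax.vmap()` also lets you choose which arguments are batched, via its `in_axes` parameter. Here is a sketch reusing `convolve()` from above, mapping over a batch of inputs while sharing a single weight vector:

```
import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
xs = jnp.stack([x, x])

def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1:i + 2], w))
    return jnp.array(output)

# in_axes=(0, None): map over the leading axis of xs, reuse the same w
batch_convolve_shared_w = jax.vmap(convolve, in_axes=(0, None))
batch_convolve_shared_w(xs, w)
```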

Transformations can be composed:

```
jitted_batch_convolve = jax.jit(auto_batch_convolve)
jitted_batch_convolve(xs, ws)
```

```
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)
```

## Data parallelism

The `jax.pmap()` transformation does the same thing as `jax.vmap()`, but each computation runs on a different device, allowing you to scale things up dramatically.
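A minimal sketch, reusing `convolve()` from above. Note that `jax.pmap()` requires the mapped axis to be no larger than the number of available devices, so here the batch is sized with `jax.local_device_count()` (on a machine with a single device, this is a batch of one):

```
import jax
import jax.numpy as jnp

def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1:i + 2], w))
    return jnp.array(output)

n = jax.local_device_count()  # one batch element per device
xs = jnp.stack([jnp.arange(5.0)] * n)
ws = jnp.stack([jnp.array([2., 3., 4.])] * n)

# Each row of xs/ws is processed on its own device
parallel_convolve = jax.pmap(convolve)
parallel_convolve(xs, ws)
```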

## Multi-host communication

JAX on its own does not scale computations up to the level of multi-node clusters, but the mpi4jax extension provides this capability.