# Parallel computing

JAX is designed for deep neural networks and linear algebra at scale, so processing vast amounts of data in parallel is central to its purpose. Two of JAX’s transformations make it easy to turn sequential code into parallel code.

## Vectorization

Recall that a number of transformations are applied to jaxprs. We already saw two of JAX’s main transformations: JIT compilation with `jax.jit` and automatic differentiation with `jax.grad`. Vectorization with `jax.vmap` is another one.

It automates the vectorization of complex functions. Operations on arrays are naturally executed in a vectorized fashion (as is the case in R, NumPy, etc.), but more complex functions are not.

Here is an example from JAX 101 commonly encountered in deep learning:

```
import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
    output = []
    for i in range(1, len(x)-1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

convolve(x, w)
```

`Array([11., 20., 29.], dtype=float32)`

See this great post for explanations of convolutions.

You will probably want to apply the function `convolve()` to a batch of weights `w` and vectors `x`.

```
xs = jnp.stack([x, x, x])
ws = jnp.stack([w, w, w])
```

We apply the `jax.vmap()` transformation to the `convolve()` function and pass the batches to it:

```
vconvolve = jax.vmap(convolve)
vconvolve(xs, ws)
```

```
Array([[11., 20., 29.],
       [11., 20., 29.],
       [11., 20., 29.]], dtype=float32)
```
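By default, `jax.vmap` maps over the leading axis of every argument. When only the inputs are batched and a single set of weights is shared, the `in_axes` argument can express that. The following is a minimal sketch, not part of the JAX 101 example: it redefines the objects above so that it runs on its own, and `shared` is just an illustrative name.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

xs = jnp.stack([x, x, x])

# in_axes=(0, None): map over axis 0 of xs, reuse the same w for every row
shared = jax.vmap(convolve, in_axes=(0, None))(xs, w)
print(shared.shape)
```

This avoids stacking copies of `w` when the weights are identical across the batch.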

As we already saw, transformations can be composed:

```
vconvolve_jit = jax.jit(vconvolve)
vconvolve_jit(xs, ws)
```

```
Array([[11., 20., 29.],
       [11., 20., 29.],
       [11., 20., 29.]], dtype=float32)
```
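Composition is not limited to `jax.jit`. A common pattern is to combine `jax.vmap` with `jax.grad` to get one gradient per batch element. The sketch below assumes a hypothetical scalar `loss` function (the sum of squared outputs), which is introduced here only for illustration and is not part of the original example.

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

# Hypothetical loss for a single example (assumption for illustration)
def loss(w, x):
    return jnp.sum(convolve(x, w) ** 2)

x = jnp.arange(5.)
w = jnp.array([2., 3., 4.])
xs = jnp.stack([x, x, x])

# grad differentiates with respect to w (the first argument),
# vmap maps over the rows of xs while sharing w,
# and jit compiles the composed function
per_example_grads = jax.jit(
    jax.vmap(jax.grad(loss), in_axes=(None, 0))
)(w, xs)
print(per_example_grads.shape)
```

Each row of `per_example_grads` is the gradient of `loss` with respect to `w` for the corresponding row of `xs`.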

## Parallel runs across devices

The `jax.pmap` transformation does the same thing, but each computation runs on a different device (e.g. a different GPU) on the same node, which makes it possible to scale up further:

`jax.pmap(convolve)(xs, ws)`

`jax.pmap` automatically JIT compiles the code, so it is unnecessary to pass the result to `jax.jit`.
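Note that `jax.pmap` needs one device per element of the mapped axis, so a batch of three only works on a machine with at least three devices. The sketch below sizes the batch from `jax.local_device_count()` so that it also runs on a single CPU. The names `n_dev`, `xs_dev`, and `ws_dev` are illustrative and not part of the original example.

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

# Build one batch element per available local device
n_dev = jax.local_device_count()
xs_dev = jnp.stack([x] * n_dev)
ws_dev = jnp.stack([w] * n_dev)

# Each row is computed on its own device
out = jax.pmap(convolve)(xs_dev, ws_dev)
print(out.shape)
```

The leading dimension of `out` equals the number of devices used, with one convolution result per device.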