Marie-Hélène Burle

One of the strengths of JAX is its use of accelerators (GPUs and TPUs). This section shows how little effort this requires.


One of the conveniences of XLA, the compiler used by JAX (and TensorFlow), is that the same code runs on any device without modification.
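As a minimal sketch of what this looks like in practice, the snippet below asks JAX which backend it selected and then runs a function that contains no device-specific code. The function name scaled_sum is hypothetical and only serves as an illustration; the devices printed depend on the hardware JAX finds.

```python
import jax
import jax.numpy as jnp

# Platform name of the default backend (for example "cpu" or "gpu")
backend = jax.default_backend()
print(backend)

# Devices available for that backend
print(jax.devices())

def scaled_sum(x):
    # Hypothetical example function: nothing here refers to a device
    return jnp.sum(2.0 * x)

x = jnp.ones((2, 4))      # created on the default device
result = scaled_sum(x)    # computed on the default device
print(result)
```

The same script can be run unchanged on a machine with or without an accelerator; only the printed backend and device list should differ.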

This is in contrast with PyTorch, where tensors are created on the CPU by default and can be moved to the GPU using the .to method.

Tensors can also be created explicitly on a device (e.g. for GPU, x = torch.ones(2, 4, device='cuda')).
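For illustration, here is a small sketch of the explicit move with .to. It assumes PyTorch is installed and only moves the tensor when a CUDA GPU is actually available, so it also runs on a CPU-only machine.

```python
import torch

x = torch.ones(2, 4)        # created on the CPU by default
print(x.device)

# Move the tensor to the GPU only if one is available
if torch.cuda.is_available():
    x = x.to("cuda")

print(x.device)
```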

Alternatively, the code can be made more robust with the creation of a device handle which will allow it to run without modification on CPU or GPU:

import torch

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

x = torch.ones(2, 4, device=device)

MPS = Apple Metal Performance Shaders (GPU on macOS).

PyTorch can also run on TPUs through the torch_xla package, which comes with its own techniques for scaling.

Interactive job with a GPU

The Alliance wiki documents how to use GPUs on Alliance clusters.

For now, let’s relinquish our current interactive job. It is important not to nest jobs by running salloc inside a job that is already running.

Kill the current job by running (from Bash, not from IPython):


Then, start an interactive job with a GPU:

salloc --time=1:0:0 --gpus-per-node=1 --mem=22000M

Reload the ipython module:

module load ipython-kernel/3.11

Re-activate the Python virtual environment:

source /project/60055/env/bin/activate

Finally, relaunch IPython:

ipython

Effect on timing

Here is an example of the difference that a GPU makes compared to CPUs for a simple computation.

The following timings were measured on my laptop, which has one dedicated GPU.

First, let’s set things up:

import jax.numpy as jnp
from jax import random, device_put
import numpy as np

seed = 0
key = random.PRNGKey(seed)
key, subkey = random.split(key)

size = 3000

Now, let’s time a dot product of two arrays using NumPy (which only uses CPUs):

x_np = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x_np, x_np.T)
58.6 ms ± 2.67 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

We can use the NumPy ndarrays in a JAX dot product function:

%timeit jnp.dot(x_np, x_np.T).block_until_ready()
31.1 ms ± 1.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Remember that whenever you benchmark JAX computations, you need to call the block_until_ready method. Because JAX dispatches work asynchronously, you would otherwise time the creation of a future rather than the computation itself.
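Outside of IPython, where %timeit is not available, the same point can be illustrated with the standard time module. This is only a rough sketch, not a rigorous benchmark: the array size is arbitrary, and the actual timings depend on the hardware, so none are shown.

```python
import time

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
x = random.normal(key, (1000, 1000), dtype=jnp.float32)

# Warm-up call so that one-time setup costs are not included in the timings
jnp.dot(x, x.T).block_until_ready()

# Without blocking: this may only measure the time needed to dispatch the work
start = time.perf_counter()
y = jnp.dot(x, x.T)
dispatch_time = time.perf_counter() - start

# With blocking: this waits until the result has actually been computed
start = time.perf_counter()
y = jnp.dot(x, x.T).block_until_ready()
compute_time = time.perf_counter() - start

print(f"without block_until_ready: {dispatch_time:.6f} s")
print(f"with block_until_ready:    {compute_time:.6f} s")
```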

If you want to use NumPy ndarrays in JAX and you have accelerators available, a much better approach is to transfer them to the accelerators with the device_put function:

x_to_gpu = device_put(x_np)
%timeit jnp.dot(x_to_gpu, x_to_gpu.T).block_until_ready()
13.2 ms ± 27.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
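To check where an array actually lives after the transfer, one option is to look at its devices. The sketch below assumes a recent JAX release in which JAX arrays have a devices method; it passes the target device explicitly, although device_put also works without that argument and then uses the default device.

```python
import numpy as np

import jax
from jax import device_put

x_np = np.ones((2, 4), dtype=np.float32)

# First device of the default backend (a GPU if JAX found one, otherwise the CPU)
target = jax.devices()[0]

# Copy the NumPy array to that device; the result is a JAX array
x_dev = device_put(x_np, target)

print(type(x_np))       # NumPy ndarray, held in host memory
print(x_dev.devices())  # set containing the device that holds the JAX array
```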

This is much faster, and similar to what the full JAX code would give:

x_jx = random.normal(key, (size, size), dtype=jnp.float32)

%timeit jnp.dot(x_jx, x_jx.T).block_until_ready()
13.3 ms ± 33.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)