Accelerators
One of the strengths of JAX is how efficiently it makes use of accelerators. In this section, we will see how easily this is done.
Auto-detection
One of the conveniences of XLA, which JAX (and TensorFlow) uses, is that the same code runs on any device without modification.
This is in contrast with PyTorch, where tensors are created on the CPU by default and must either be moved to the GPU with the .to method or be created explicitly on a device (e.g. for a GPU, x = torch.ones(2, 4, device='cuda')).
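For instance, moving a CPU tensor to the GPU with .to looks like this (a minimal sketch; it requires a CUDA-capable GPU):

import torch

x = torch.ones(2, 4)   # created on the CPU by default
x = x.to("cuda")       # moved to the GPU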
Alternatively, the code can be made more robust by creating a device handle, which allows it to run without modification on CPU or GPU:
import torch

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

x = torch.ones(2, 4, device=device)
MPS = Apple Metal Performance Shaders (GPU on macOS).
PyTorch can also run on TPUs with the torch_xla package, although scaling it up takes some additional tricks.
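In JAX, by contrast, none of this device handling is needed: the accelerators present on the machine are detected automatically. A quick way to see which backend was picked up (a minimal sketch, assuming a standard JAX installation):

import jax

jax.devices()           # devices JAX detected (GPU/TPU if available, otherwise CPU)
jax.default_backend()   # "gpu", "tpu", or "cpu"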
Interactive job with a GPU
The Alliance wiki documents how to use GPUs on Alliance clusters.
For now, let’s relinquish our current interactive job: it is important not to nest jobs by running salloc
inside a running job.
Kill the current job by running (from Bash, not from ipython):
exit
Then, start an interactive job with a GPU:
salloc --time=1:0:0 --gpus-per-node=1 --mem=22000M
Reload the ipython module:
module load ipython-kernel/3.11
Re-activate the virtual python environment:
source /project/60055/env/bin/activate
Finally, relaunch IPython:
ipython
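Once IPython is running inside the GPU job, a quick way to confirm that JAX sees the GPU (assuming a GPU-enabled build of JAX is installed in the environment) is:

import jax

jax.devices()   # should now list a GPU device rather than only the CPU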
Effect on timing
Here is an example of the difference that a GPU makes compared to CPUs for a simple computation.
The following times were measured on my laptop, which has one dedicated GPU.
First, let’s set things up:
import jax.numpy as jnp
from jax import random, device_put
import numpy as np

seed = 0
key = random.PRNGKey(seed)
key, subkey = random.split(key)

size = 3000
Now, let’s time a dot product of two arrays using NumPy (which only uses CPUs):
x_np = np.random.normal(size=(size, size)).astype(np.float32)

%timeit np.dot(x_np, x_np.T)
58.6 ms ± 2.67 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
We can use the NumPy ndarrays in a JAX dot product function:
%timeit jnp.dot(x_np, x_np.T).block_until_ready()
31.1 ms ± 1.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Remember that whenever you benchmark JAX computations, you need to call the block_until_ready
method: due to asynchronous dispatch, this ensures that you are timing the computation itself and not merely the creation of a future.
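To see the effect of asynchronous dispatch, you can compare a single run with and without blocking (a rough illustration; the exact behaviour depends on your hardware):

%time result = jnp.dot(x_np, x_np.T)                      # returns as soon as the work is dispatched
%time result = jnp.dot(x_np, x_np.T).block_until_ready()  # waits for the actual computation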
If you want to use NumPy ndarrays in JAX and you have accelerators available, a much better approach is to transfer them to the accelerator first with the device_put
function:
x_to_gpu = device_put(x_np)

%timeit jnp.dot(x_to_gpu, x_to_gpu.T).block_until_ready()
13.2 ms ± 27.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
This is much faster and comparable to what the equivalent full JAX code gives:
x_jx = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x_jx, x_jx.T).block_until_ready()
13.3 ms ± 33.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
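As an aside, device_put also accepts an explicit device if you want to control placement yourself, for instance the first detected device (a minimal sketch; the .devices() check assumes a recent JAX version):

import jax

x_dev = device_put(x_np, jax.devices()[0])   # place the array on the first detected device
x_dev.devices()                              # confirm where the array now lives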