Installing JAX

Author

Marie-Hélène Burle

In this section, we will access a virtual training cluster through SSH and make JAX accessible.

We will also cover how to install JAX on the Alliance production clusters.

Unless you only plan to run JAX on CPUs, JAX relies on GPU or TPU dependencies that vary with your OS and hardware (e.g. CUDA and cuDNN for NVIDIA GPUs). Making sure that these dependencies are installed, compatible with each other, and working with JAX can be finicky, so it is a lot easier to install JAX from pip wheels.
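Once JAX is installed, you can check whether it actually found an accelerator. The minimal sketch below uses `jax.default_backend()` and `jax.devices()`; the helper name `describe_backend` is our own, and the script only needs the standard library to report that JAX is missing. On a machine without a GPU, the backend will be `cpu`. If it reports `cpu` on a machine that does have a GPU, the GPU dependencies are probably not set up correctly.

```python
import importlib.util


def describe_backend() -> str:
    """Report which backend JAX uses, or that JAX is missing."""
    # find_spec returns None when the package cannot be found
    if importlib.util.find_spec("jax") is None:
        return "JAX is not installed in this environment"
    import jax

    # default_backend() returns a platform name such as "cpu", "gpu" or "tpu"
    backend = jax.default_backend()
    # devices() lists the devices available on that default backend
    devices = jax.devices()
    return f"backend: {backend}, devices: {devices}"


if __name__ == "__main__":
    print(describe_backend())
```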

On your computer

On your personal computer, use the pip wheel installation command that corresponds to your system on the official JAX site.

On Windows, GPUs are only supported via Windows Subsystem for Linux 2.

Because JAX is designed for large array computations and machine learning, you will most likely want to use it on supercomputers. In this course, we will thus use a virtual Alliance cluster.

On an Alliance cluster

Logging in through SSH

Open a terminal emulator

  • Windows users: install the free version of MobaXterm and launch it.
  • macOS users: launch Terminal.
  • Linux users: open the terminal emulator of your choice.

Access the cluster through secure shell

Windows users

Follow the first 18% of this demo.

For “Remote host”, use the hostname we gave you.
Check the “Specify username” box and enter your username.

Note that the password is entered through blind typing: nothing appears on screen as you type it. This is a standard Linux security feature. While it is a little disconcerting at first, rest assured that your keystrokes are registered. Type slowly to avoid typos, then press the “Enter” key on your keyboard.

macOS and Linux users

In the terminal, run:

ssh <username>@<hostname>

Replace <username> and <hostname> with the values we gave you.
For instance:

ssh user21@somecluster.c3.ca

The first time you connect, SSH will ask whether you want to continue connecting to this host. Answer “yes”.

When prompted, type the password.

Note that the password is entered through blind typing: nothing appears on screen as you type it. This is a standard Linux security feature. While it is a little disconcerting at first, rest assured that your keystrokes are registered. Type slowly to avoid typos, then press the “Enter” key on your keyboard.

Troubleshooting

Problems logging in are almost always due to typos. If you cannot log in, retry slowly, entering your password carefully.

Install JAX

To save time, we already created a Python virtual environment and installed JAX in it. The instructions for today thus differ from what you would normally do; the normal instructions follow for your future reference.

On the training cluster

I created the virtual Python environment under /project to save time and space. All you have to do is activate it:

source /project/60055/env/bin/activate
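To confirm that the activation worked, you can check from Python which interpreter is running. The sketch below relies on the standard-library behaviour that, inside a virtual environment, `sys.prefix` differs from `sys.base_prefix`; the helper name `in_virtualenv` is illustrative. After activating the shared environment, the interpreter path should point inside /project/60055/env.

```python
import sys


def in_virtualenv() -> bool:
    # In a venv, sys.prefix is the environment directory,
    # while sys.base_prefix is the Python installation it was created from
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print(f"interpreter: {sys.executable}")
    print(f"inside a virtual environment: {in_virtualenv()}")
```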

On the Alliance production clusters

Look for available Python modules:

module spider python

Load the version of your choice:

module load python/3.11.5

Create a Python virtual environment:

python -m venv ~/env

Activate it:

source ~/env/bin/activate

Upgrade pip using the cluster wheel:

python -m pip install --upgrade pip --no-index

Whenever a Python wheel for a package is available on the Alliance clusters, you should use it instead of downloading the package from PyPI. To do this, simply add the --no-index flag to the install command.

You can see whether a wheel is available with avail_wheels <package> or look at the list of available wheels.

Advantages of wheels:

  • compiled for the clusters' hardware,
  • no missing or conflicting dependencies,
  • much faster installation.

Install JAX from wheel:

python -m pip install jax --no-index

Don’t forget the --no-index flag here: the wheel will save you from having to deal with the CUDA and cuDNN dependencies, making your life a lot easier.

If you want to install a particular version of JAX, first check which wheels are available:

avail_wheels "jax*"

Then install the version of your choice:

python -m pip install jax==0.4.26 --no-index
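To make sure pip installed the pinned release, you can read the installed versions from the package metadata with `importlib.metadata`. This sketch reports both jax and jaxlib (the compiled package that jax depends on); the function name `installed_versions` and the hard-coded `wanted` value are illustrative and should match whatever version you pinned.

```python
import importlib.metadata


def installed_versions(packages=("jax", "jaxlib")) -> dict:
    """Map each package name to its installed version (None if absent)."""
    versions = {}
    for name in packages:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    wanted = "0.4.26"  # the version pinned in the pip command above
    versions = installed_versions()
    print(versions)
    if versions["jax"] != wanted:
        print(f"warning: expected jax {wanted}, found {versions['jax']}")
```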