Installing JAX
In this section, we will log in to a virtual training cluster through SSH and make JAX available.
We will also cover how to install JAX on the Alliance production clusters.
If you plan to use accelerators, JAX relies on GPU/TPU dependencies determined by your OS and hardware (e.g. CUDA and cuDNN). Making sure that these dependencies are installed, compatible, and working with JAX can be finicky, so it is a lot easier to install JAX from pip wheels.
On your computer
On your personal computer, use the wheel installation command from the official JAX site corresponding to your system.
On Windows, GPUs are only supported via Windows Subsystem for Linux 2.
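As a sketch for Linux, macOS, or WSL2, the local installation can be wrapped in a small shell function. It assumes the `jax` (CPU-only) and `jax[cuda12]` (NVIDIA GPU with CUDA 12) requirement names listed on the official JAX site at the time of writing; these extras change between JAX releases, so check the site for the command matching your system. `install_jax_local` is a hypothetical name, not part of JAX.

```bash
# Hypothetical helper (not part of JAX): create a virtual environment and
# install JAX into it. Usage: install_jax_local <env_dir> [cpu|cuda12]
install_jax_local() {
  local env_dir="${1:-}"
  local target="${2:-cpu}"
  local requirement

  if [ -z "$env_dir" ]; then
    echo "Usage: install_jax_local <env_dir> [cpu|cuda12]" >&2
    return 1
  fi

  # Choose the pip requirement before creating anything on disk
  case "$target" in
    cpu)    requirement="jax" ;;          # CPU-only build
    cuda12) requirement="jax[cuda12]" ;;  # NVIDIA GPU on Linux or WSL2; CUDA libraries come from pip
    *)      echo "Unknown target '$target': use cpu or cuda12" >&2; return 1 ;;
  esac

  python3 -m venv "$env_dir" || return 1                          # isolated environment
  "$env_dir/bin/python" -m pip install --upgrade pip || return 1  # recent pip
  "$env_dir/bin/python" -m pip install --upgrade "$requirement"   # JAX itself
}
```

For example, run `install_jax_local ~/jax-env cuda12`, then `source ~/jax-env/bin/activate` to start using it.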
Because JAX is designed for large array computations and machine learning, you will most likely want to use it on supercomputers. In this course, we will thus use a virtual Alliance cluster.
On an Alliance cluster
Logging in through SSH
Open a terminal emulator
Windows users: Install the free version of MobaXterm and launch it.
MacOS users: Launch Terminal.
Linux users: Open the terminal emulator of your choice.
Access the cluster through secure shell
Windows users
Follow the first 18% of this demo.
For “Remote host”, use the hostname we gave you.
Select the box “Specify username” and provide your username.
Note that the password is entered through blind typing: nothing appears on the screen as you type it. This is normal on Linux systems. It can be a little disconcerting at first, but your typing is being registered. Type slowly to avoid typos, then press the “enter” key on your keyboard.
MacOS and Linux users
In the terminal, run:
ssh <username>@<hostname>
Replace <username> and <hostname> with the values we gave you.
For instance:
ssh user21@somecluster.c3.ca
The first time you connect, SSH will ask whether you want to continue connecting to this unknown host: answer “yes”.
When prompted, type the password.
Note that the password is entered through blind typing: nothing appears on the screen as you type it. This is normal on Linux systems. It can be a little disconcerting at first, but your typing is being registered. Type slowly to avoid typos, then press the “enter” key on your keyboard.
Troubleshooting
Problems logging in are almost always due to typos. If you cannot log in, retry slowly, entering your password carefully.
Install JAX
To save time and space, we already created a Python virtual environment under /project and installed JAX in it. Today's instructions thus differ from what you would normally do; the normal instructions follow for your future reference.
All you have to do today is activate the existing environment:
source /project/60055/env/bin/activate
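To check that the activation worked, you can run a quick test. The sketch below wraps it in a hypothetical shell function, `check_jax` (not a cluster command): it relies on the `VIRTUAL_ENV` variable that the activate script sets, then asks JAX for its version and for the devices it can see.

```bash
# Hypothetical helper: confirm that a virtual environment is active and that JAX imports.
check_jax() {
  # The venv activate script exports VIRTUAL_ENV
  if [ -z "${VIRTUAL_ENV:-}" ]; then
    echo "No virtual environment is active: run the source command first." >&2
    return 1
  fi
  echo "Active environment: $VIRTUAL_ENV"
  # Print the JAX version and the devices JAX can use
  # (on a node without GPUs, only CPU devices are listed)
  python -c 'import jax; print("JAX", jax.__version__); print(jax.devices())'
}
```

After activating the environment, simply run `check_jax`.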
Normal installation (for future reference)
Look for available Python modules:
module spider python
Load the version of your choice:
module load python/3.11.5
Create a Python virtual environment:
python -m venv ~/env
Activate it:
source ~/env/bin/activate
Update pip from wheel:
python -m pip install --upgrade pip --no-index
Whenever a Python wheel for a package is available on the Alliance clusters, you should use it instead of downloading the package from PyPI. To do this, simply add the --no-index flag to the install command.
You can see whether a wheel is available with avail_wheels <package> or look at the list of available wheels.
Advantages of wheels:
- compiled for the clusters' hardware,
- no missing or conflicting dependencies,
- much faster installation.
Install JAX from wheel:
python -m pip install jax --no-index
Don’t forget the --no-index flag here: the wheel will save you from having to deal with the CUDA and cuDNN dependencies, making your life a lot easier.
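For future setups, the steps above can be collected into one shell function. This is a sketch, not an official Alliance script: the function name `setup_jax_env` is hypothetical, and it assumes the `python/3.11.5` module and the `~/env` location used above as defaults. Because a shell function runs in your current shell, the environment it activates stays active afterwards.

```bash
# Hypothetical helper bundling the steps above. Define and call it in your
# interactive shell so that the activated environment persists.
# Usage: setup_jax_env [python_module] [env_dir]
setup_jax_env() {
  local py_module="${1:-python/3.11.5}"
  local env_dir="${2:-$HOME/env}"

  module load "$py_module" || return 1                        # Python from the module system
  python -m venv "$env_dir" || return 1                       # create the virtual environment
  source "$env_dir/bin/activate" || return 1                  # activate it in this shell
  python -m pip install --upgrade pip --no-index || return 1  # pip from the cluster wheel
  python -m pip install jax --no-index                        # JAX from the cluster wheel
}
```

For example: `setup_jax_env python/3.11.5 ~/env`.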
If you want to install a particular version of JAX, first check which wheels are available:
avail_wheels "jax*"
Then install the version of your choice:
python -m pip install jax==0.4.26 --no-index
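To confirm that you got the version you asked for, you can compare it with the version reported by pip show. In this sketch, `version_matches` and `check_jax_version` are hypothetical helpers. It assumes that cluster wheels may carry a local version label after a `+` (such as `+computecanada`), so it accepts `0.4.26+computecanada` as a match for `0.4.26` but rejects `0.4.260`.

```bash
# Hypothetical helper: succeed if version $2 is $1, optionally followed by a
# local label such as +computecanada
version_matches() {
  case "$2" in
    "$1" | "$1"+*) return 0 ;;
    *) return 1 ;;
  esac
}

# Hypothetical helper: compare the installed JAX version with the one requested.
# Usage: check_jax_version 0.4.26
check_jax_version() {
  local wanted="${1:-}" installed
  if [ -z "$wanted" ]; then
    echo "Usage: check_jax_version <version>" >&2
    return 1
  fi
  # pip show prints a "Version: ..." line for installed packages
  installed="$(python -m pip show jax 2>/dev/null | sed -n 's/^Version: //p')"
  if [ -z "$installed" ]; then
    echo "JAX is not installed in the active environment." >&2
    return 1
  fi
  if version_matches "$wanted" "$installed"; then
    echo "JAX $installed is installed."
  else
    echo "Expected JAX $wanted, found $installed." >&2
    return 1
  fi
}
```

For example: `check_jax_version 0.4.26`.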