Installing JAX
In this section, we will access a virtual training cluster through SSH and make JAX accessible.
We will also cover how to install JAX on the Alliance production clusters.
Unless you don’t plan to use accelerators, JAX relies on GPU/TPU dependencies determined by your OS and hardware (e.g. CUDA and cuDNN). Making sure that these dependencies are installed, compatible, and working with JAX can be finicky, so it is a lot easier to install JAX from pip wheels.
On your computer
On your personal computer, use the wheel installation command from the official JAX site corresponding to your system.
On Windows, GPUs are only supported via Windows Subsystem for Linux 2.
Because JAX is designed for large array computations and machine learning, you will most likely want to use it on supercomputers. In this course, we will thus use a virtual Alliance cluster.
On an Alliance cluster
Logging in through SSH
Step 1: get the info
During the course, we will give you 3 pieces of information:
- a link to a list of usernames,
- the hostname for our temporary training cluster,
- the password to access that cluster.
Step 2: claim a username
Add your first name or a pseudonym next to a free username on the list to claim it.
Your username is the name that was already on the list, NOT what you wrote next to it (which doesn’t matter at all and only serves to signal that this username is now taken).
Your username will look like userxx—xx being 2 digits—with no space and no capital letter.
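The username format described above can be made precise with a small check. This is only a sketch: the pattern is inferred from the description (the literal user followed by exactly 2 digits), and is_valid_username is a hypothetical helper, not something provided by the cluster.

```python
import re

# Assumed pattern, inferred from the description above:
# the literal "user" followed by exactly two ASCII digits, nothing else.
USERNAME_PATTERN = re.compile(r"user[0-9]{2}")


def is_valid_username(name: str) -> bool:
    """Return True if `name` matches the userxx format (hypothetical helper)."""
    return USERNAME_PATTERN.fullmatch(name) is not None


# A valid username, then three common mistakes:
# a capital letter, a space, and a missing digit.
for candidate in ["user09", "User09", "user 09", "user9"]:
    print(f"{candidate!r}: {is_valid_username(candidate)}")
```

Only the first candidate passes; the other three illustrate the typos most likely to cause a failed login.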
Step 3: run the ssh command
Linux users: open the terminal emulator of your choice.
macOS users: open “Terminal”.
Then type:
ssh userxx@hostname
and press Enter.
- Replace userxx with your username (e.g. user09).
- Replace hostname with the hostname we will give you the day of the workshop.
When asked:
Are you sure you want to continue connecting (yes/no/[fingerprint])?
Answer: “yes”.
Windows users: we suggest using the free version of MobaXterm, a program that comes with a terminal emulator and a graphical interface for SSH sessions.
Here is how to install MobaXterm:
- download the “Installer edition” to your computer (green button to the right),
- unzip the file,
- double-click on the .msi file to launch the installation.
Here is how to log in with MobaXterm:
- open MobaXterm,
- click on Session (top left corner),
- click on SSH (top left corner),
- fill in the Remote host * box with the cluster hostname we gave you,
- tick the box Specify username,
- fill in the box with the username you selected (e.g. user09),
- press OK,
- when asked Are you sure you want to continue connecting (yes/no/[fingerprint])?, answer: “yes”.
Step 4: enter the password
When prompted, enter the password we gave you.
You will not see anything happen as you type the password. This is normal and it is working, so keep on typing the password.
This is called blind typing and is a Linux safety feature. It can be unsettling at first not to get any feedback while typing, as it really looks like it is not working. Type slowly and make sure not to make typos.
Then press Enter.
Am I logged in?
To know whether or not you are logged in, look at your prompt: it should look like the following (with your actual username):
[userxx@login1 ~]$
Troubleshooting
Problems logging in are almost always due to typos. If you cannot log in, retry slowly, entering your password carefully.
How do I log out?
You can log out by pressing Ctrl+D.
Install JAX
We already created a Python virtual environment and installed JAX to save time. The instructions for today thus differ from what you would normally do, but I include the normal instructions afterwards for your future reference.
For today: I already created a Python virtual environment under /project and installed JAX in it to save time and space. All you have to do is activate it:
source /project/60055/env/bin/activate
Normally (outside of this course), you would follow the steps below instead. Look for available Python modules:
module spider python
Load the version of your choice:
module load python/3.11.5
Create a Python virtual environment:
python -m venv ~/env
Activate it:
source ~/env/bin/activate
Update pip from wheel:
python -m pip install --upgrade pip --no-index
Whenever a Python wheel for a package is available on the Alliance clusters, you should use it instead of downloading the package from PyPI. To do this, simply add the --no-index flag to the install command.
You can see whether a wheel is available with avail_wheels <package> or look at the list of available wheels.
Advantages of wheels:
- compiled for the clusters’ hardware,
- no missing or conflicting dependencies,
- much faster installation.
Install JAX from wheel:
python -m pip install jax --no-index
Don’t forget the --no-index flag here: the wheel will save you from having to deal with the CUDA and cuDNN dependencies, making your life a lot easier.
If you want to install a particular version of JAX, you first need to see which wheels are available:
avail_wheels "jax*"
Then install the version of your choice:
python -m pip install jax==0.4.26 --no-index
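Once a virtual environment is activated and JAX is installed, it can be worth confirming that everything is in place. The following is a minimal sketch, to be run with the Python of your virtual environment. The name environment_report is a hypothetical helper, and the output depends on your environment, so none is shown here.

```python
import sys


def environment_report() -> dict:
    """Collect basic facts about the current Python setup (hypothetical helper)."""
    report = {
        # Inside a venv, sys.prefix points at the venv while
        # sys.base_prefix points at the Python it was created from.
        "in_virtualenv": sys.prefix != sys.base_prefix,
    }
    try:
        import jax
    except ImportError as exc:
        # JAX (or one of its dependencies) cannot be imported.
        report["jax_installed"] = False
        report["error"] = str(exc)
        return report

    report["jax_installed"] = True
    report["jax_version"] = jax.__version__
    # Name of the default backend JAX initialised, e.g. "cpu" or "gpu".
    report["backend"] = jax.default_backend()
    report["devices"] = [str(device) for device in jax.devices()]
    return report


for key, value in environment_report().items():
    print(f"{key}: {value}")
```

If in_virtualenv is False, the environment was probably not activated in this shell. On a node without a GPU, expect JAX to report the CPU backend even when the installation is fine.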