Installing packages
Installing packages for deep learning can be an adventure, particularly since a deep learning workflow with JAX requires quite a few additional packages.
For this course, we have already installed the packages, but this section will matter when you want to install packages on your own machine or on the Alliance clusters.
On your machine
On your machine (but not on the Alliance clusters), I recommend that you use uv to create a Python project with your chosen Python version and all the necessary packages. uv installs packages much faster than pip and resolves dependencies very well. It also manages Python versions.
If you want more information on this, I will give a webinar on uv very soon.
Install Python 3.12 since Grain hasn’t been ported to Python 3.13 yet:
uv python install 3.12
Create a Python project (let’s call it jaxdl) and cd into it:
uv init --no-readme --no-description jaxdl
cd jaxdl
You should see a .python-version file set to version 3.12. You can edit this file if you want to use another Python version and uv will automatically install it.
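To confirm which interpreter the project actually picks up, a small check like the one below can be run from inside the project (for example with `uv run python`). It is only a sketch: `is_supported` is a hypothetical helper, and the 3.12 ceiling is an assumption taken from the Grain constraint mentioned above.

```python
import sys


def is_supported(version_info=sys.version_info, max_minor=12):
    """Return True for Python 3.x with x <= max_minor.

    The default ceiling of 3.12 is an assumption based on Grain
    not yet being available for Python 3.13.
    """
    major, minor = version_info[0], version_info[1]
    return major == 3 and minor <= max_minor


print(f"Running Python {sys.version_info.major}.{sys.version_info.minor}")
if is_supported():
    print("This version should work for the setup described here")
else:
    print("This version is outside the range assumed here (Python 3, up to 3.12)")
```

If the reported version is not the one you expect, check the content of the .python-version file.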
Install the packages:
uv add "jax[cuda12]" "jax-ai-stack[grain]" datasets matplotlib penzai torchvision tqdm transformers
Quick explanation of packages we are installing:
- jax[cuda12] ➔ only if you want to run JAX on GPUs
- jax-ai-stack[grain] ➔ installs JAX for the CPU (if not already installed for the GPU), as well as:
  - Flax: the main NN library
  - Optax: optimizers & loss functions
  - Orbax: for checkpointing
  - Grain: to build efficient dataloaders
  - ml_dtypes: NumPy dtype extensions for deep learning
- datasets ➔ from Hugging Face—to load data
- matplotlib ➔ to visualise samples
- penzai ➔ to have interactive model display
- torchvision ➔ to augment the data to prevent over-fitting
- tqdm ➔ progress bar
- transformers ➔ from Hugging Face—to load pretrained weights
You will see that the dependencies have automatically populated a pyproject.toml file and that a virtual environment called .venv was created.
As long as you are within the project, you don’t need to activate that virtual environment. You can just launch Python (or IPython, ptpython, Jupyter…) and the packages will be available.
Alternatively, if you need it for an advanced workflow involving other tools (e.g. Quarto), you can activate it as you would any other Python virtual environment:
source .venv/bin/activate
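Whichever way you launch Python, a quick sanity check that the packages are visible from the environment can save time later. The sketch below only looks the modules up without importing them. `missing_modules` is a hypothetical helper, and the import names in `MODULES` are my assumption of the top-level module each installed package provides.

```python
from importlib.util import find_spec

# Assumed top-level import names for the packages installed above
MODULES = [
    "jax", "flax", "optax", "orbax", "grain", "ml_dtypes",
    "datasets", "matplotlib", "penzai", "torchvision", "tqdm", "transformers",
]


def missing_modules(names):
    """Return the names that cannot be found in the current environment."""
    return [name for name in names if find_spec(name) is None]


missing = missing_modules(MODULES)
if missing:
    print("Missing:", ", ".join(missing))
else:
    print("All packages found")
```

An empty list means every module was located. It does not guarantee that each package imports without error, only that it is installed where Python can see it.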
On an Alliance cluster
I already installed all the necessary packages on the training cluster to save time and space. The instructions for today thus differ from what you would normally do, and the production cluster instructions in the second tab are for your future reference only.
Look for available Python modules:
module spider python
Load the version of your choice:
TensorFlow and all packages depending on it (including TensorFlow Datasets and Grain) have still not been ported to Python 3.13 (as of April 2025).
module load python/3.12.4
The Hugging Face Datasets package uses PyArrow for efficiency. In order to install it, we also need to load an Arrow module.
Let’s see what versions are available:
module spider arrow
Any version should be fine. Let’s load the latest (as of April 2025):
module load arrow/19.0.1
I created a Python virtual environment with all necessary packages under /project. All you have to do today is activate it with:
source /project/60055/env/bin/activate
Look for available Python modules:
module spider python
Load the version of your choice:
TensorFlow and all packages depending on it (including TensorFlow Datasets and Grain) have still not been ported to Python 3.13 (as of April 2025).
module load python/3.12.4
The Hugging Face Datasets package uses PyArrow for efficiency. In order to install it, we also need to load an Arrow module.
Let’s see what versions are available:
module spider arrow
Any version should be fine. Let’s load the latest (as of April 2025):
module load arrow/19.0.1
Create a Python virtual environment:
python -m venv ~/env
Activate it:
source ~/env/bin/activate
Update pip from wheel:
python -m pip install --upgrade pip --no-index
Whenever a Python wheel for a package is available on the Alliance clusters, you should use it instead of downloading the package from PyPI. To do this, simply add the --no-index flag to the install command.
You can see whether a wheel is available with avail_wheels <package> or look at the list of available wheels.
Advantages of wheels:
- compiled for the clusters’ hardware,
- no missing or conflicting dependencies,
- much faster installation.
Install libraries from wheels:
python -m pip install --no-index "jax[cuda12]" "jax-ai-stack[grain]" datasets matplotlib penzai torchvision tqdm transformers
Don’t forget --no-index to install from wheels.
Quick explanation of packages we are installing:
- jax[cuda12] ➔ only if you want to run JAX on GPUs
- jax-ai-stack[grain] ➔ installs JAX for the CPU (if not already installed for the GPU), as well as:
  - Flax: the main NN library
  - Optax: optimizers & loss functions
  - Orbax: for checkpointing
  - Grain: to build efficient dataloaders
  - ml_dtypes: NumPy dtype extensions for deep learning
- datasets ➔ from Hugging Face—to load data
- matplotlib ➔ to visualise samples
- penzai ➔ to have interactive model display
- torchvision ➔ to augment the data to prevent over-fitting
- tqdm ➔ progress bar
- transformers ➔ from Hugging Face—to load pretrained weights
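Once the installation finishes, a short check along these lines can confirm which backend JAX picks up in the activated environment. This is a minimal sketch rather than part of the course material: `describe_backend` is a hypothetical helper, and on a node without a GPU JAX would be expected to report the CPU backend.

```python
def describe_backend():
    """Return a one-line summary of the backend and devices JAX can see."""
    try:
        import jax
    except ImportError:
        return "JAX is not installed in this environment"
    # jax.default_backend() is e.g. "cpu" or "gpu"; jax.devices() lists the devices
    devices = [str(device) for device in jax.devices()]
    return f"backend={jax.default_backend()}, devices={devices}"


print(describe_backend())
```

To test the GPU installation, run this inside a job that has requested a GPU, since the result depends on the hardware visible to the process.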