Installing packages

Author

Marie-Hélène Burle

Installing packages for deep learning can be an adventure, particularly since a deep learning workflow with JAX requires quite a few additional packages.

For this course, the packages are already installed, but this section will be useful when you want to install them on your own machine or on the Alliance clusters.

On your machine

On your machine (but not on the Alliance clusters), I recommend that you use uv to create a Python project with your chosen Python version and all the necessary packages. uv installs packages much faster than pip, resolves dependencies reliably, and also manages Python versions.

If you want more information on this, I will give a webinar on uv very soon.

Install Python 3.12, since Grain had not yet been ported to Python 3.13 (as of April 2025):

uv python install 3.12

Create a Python project (let’s call it jaxdl) and cd into it:

uv init --no-readme --no-description jaxdl
cd jaxdl

You should see a .python-version file set to version 3.12. If it shows another version, or if you want to use another Python version, edit this file and uv will automatically install the version it specifies.

Install the packages (the quotes keep shells such as zsh from interpreting the square brackets):

uv add "jax[cuda12]" "jax-ai-stack[grain]" datasets matplotlib penzai torchvision tqdm transformers

Quick explanation of the packages we are installing:

- jax[cuda12]           ➔ only if you want to run JAX on GPUs
- jax-ai-stack[grain]   ➔ installs JAX for the CPU (if not already installed for the GPU),
                          Flax: the main NN library,
                          Optax: optimizers & loss functions,
                          Orbax: for checkpointing,
                          Grain: to build efficient dataloaders,
                          ml_dtypes: NumPy dtype extensions for deep learning
- datasets              ➔ from Hugging Face, to load data
- matplotlib            ➔ to visualise samples
- penzai                ➔ for interactive model display
- torchvision           ➔ to augment the data to prevent overfitting
- tqdm                  ➔ progress bars
- transformers          ➔ from Hugging Face, to load pretrained weights
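To confirm that the environment has everything, a quick check like the following can help. It is a minimal sketch, assuming the top-level import names listed in IMPORT_NAMES (my own mapping of the packages above to the modules they provide; for example, Orbax is imported as orbax). It relies on importlib.util.find_spec, which locates a module without importing it, so it stays fast even with heavy packages such as transformers. The function name check_packages is hypothetical.

```python
"""Minimal sketch: check that the course packages can be found by the
current Python interpreter, without actually importing them."""
import importlib.util
import sys

# Assumed top-level import names for the packages installed above.
# jax-ai-stack[grain] pulls in flax, optax, orbax, grain and ml_dtypes.
IMPORT_NAMES = (
    "jax", "flax", "optax", "orbax", "grain", "ml_dtypes",
    "datasets", "matplotlib", "penzai", "torchvision", "tqdm", "transformers",
)


def check_packages(names=IMPORT_NAMES):
    """Return a dict mapping each import name to True if it can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    print(f"Python {sys.version_info.major}.{sys.version_info.minor} ({sys.executable})")
    for name, found in check_packages().items():
        print(f"{name:<13} {'ok' if found else 'MISSING'}")
```

If you save it as, say, check_packages.py (a hypothetical file name) inside jaxdl, run it with uv run python check_packages.py. Any name reported as MISSING is not installed in the environment of the interpreter that ran the script. Finding a module does not guarantee that it imports cleanly, but it quickly catches forgotten packages. The same script works on a cluster once the virtual environment is activated.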

You will see that uv has added the dependencies to the pyproject.toml file and created a virtual environment called .venv.

As long as you are within the project, you don’t need to activate that virtual environment: launch Python through uv with uv run python (or IPython, ptpython, Jupyter… by putting uv run in front of the tool’s command, provided the tool is installed in the project) and the packages will be available.

Alternatively, if you need it for advanced workflows involving other tools (e.g. Quarto), you can activate it as you would any other Python virtual environment:

source .venv/bin/activate
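When in doubt about which interpreter you are actually running (for example after activating .venv, or when another tool launches Python for you), the standard check is to compare sys.prefix with sys.base_prefix: inside a virtual environment created by uv or python -m venv, the two differ. Here is a minimal sketch (the helper name environment_info is hypothetical) that also reports the Python version, since this setup relies on Python 3.12:

```python
import sys


def environment_info():
    """Report which interpreter is running and whether it lives in a venv."""
    major, minor = sys.version_info[:2]
    return {
        "executable": sys.executable,
        "prefix": sys.prefix,
        # sys.prefix differs from sys.base_prefix only inside a virtual environment
        "in_venv": sys.prefix != sys.base_prefix,
        "python": f"{major}.{minor}",
    }


if __name__ == "__main__":
    for key, value in environment_info().items():
        print(f"{key}: {value}")
```

Inside the jaxdl project, in_venv should be True and prefix should point to the project’s .venv directory.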

On an Alliance cluster

I already installed all the necessary packages on the training cluster to save time and space. Today’s instructions therefore differ from what you would normally do; the production cluster instructions in the second tab are for your future reference only.

Training cluster (today)

Look for available Python modules:

module spider python

Load the version of your choice:

TensorFlow and all packages depending on it (including TensorFlow Datasets and Grain) have still not been ported to Python 3.13 (as of April 2025).

module load python/3.12.4

The Hugging Face Datasets package uses PyArrow for efficiency. In order to install it, we also need to load an Arrow module.

Let’s see what versions are available:

module spider arrow

Any version should be fine. Let’s load the latest (as of April 2025):

module load arrow/19.0.1

I created a Python virtual environment with all the necessary packages under /project. All you have to do today is activate it:

source /project/60055/env/bin/activate

Production cluster (for future reference)

Look for available Python modules:

module spider python

Load the version of your choice:

TensorFlow and all packages depending on it (including TensorFlow Datasets and Grain) have still not been ported to Python 3.13 (as of April 2025).

module load python/3.12.4

The Hugging Face Datasets package uses PyArrow for efficiency. In order to install it, we also need to load an Arrow module.

Let’s see what versions are available:

module spider arrow

Any version should be fine. Let’s load the latest (as of April 2025):

module load arrow/19.0.1

Create a Python virtual environment:

python -m venv ~/env

Activate it:

source ~/env/bin/activate

Upgrade pip using the wheel available on the cluster:

python -m pip install --upgrade pip --no-index

Whenever a Python wheel for a package is available on the Alliance clusters, you should use it instead of downloading the package from PyPI. To do this, add the --no-index flag to the install command: it stops pip from querying PyPI, so only the wheels provided on the clusters are used.

You can see whether a wheel is available with avail_wheels <package> or look at the list of available wheels.

Advantages of wheels:

  • compiled for the clusters’ hardware,
  • no missing or conflicting dependencies,
  • much faster installation.

Install the packages from wheels (the quotes keep the shell from interpreting the square brackets):

python -m pip install --no-index "jax[cuda12]" "jax-ai-stack[grain]" datasets matplotlib penzai torchvision tqdm transformers

Don’t forget --no-index to install from wheels.

Quick explanation of the packages we are installing:

- jax[cuda12]           ➔ only if you want to run JAX on GPUs
- jax-ai-stack[grain]   ➔ installs JAX for the CPU (if not already installed for the GPU),
                          Flax: the main NN library,
                          Optax: optimizers & loss functions,
                          Orbax: for checkpointing,
                          Grain: to build efficient dataloaders,
                          ml_dtypes: NumPy dtype extensions for deep learning
- datasets              ➔ from Hugging Face, to load data
- matplotlib            ➔ to visualise samples
- penzai                ➔ for interactive model display
- torchvision           ➔ to augment the data to prevent overfitting
- tqdm                  ➔ progress bars
- transformers          ➔ from Hugging Face, to load pretrained weights
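Once the environment is active, you can check which backend JAX picked up. Here is a minimal sketch using jax.default_backend() and jax.devices(); the helper name jax_report is hypothetical. On the clusters, GPUs are typically only available on compute nodes (inside a job that requested a GPU), so on a login node expect JAX to fall back to the CPU backend, possibly with a warning.

```python
def jax_report():
    """Return JAX's version, default backend and visible devices.

    Returns {"installed": False} if JAX cannot be imported.
    """
    try:
        import jax
    except ImportError:
        return {"installed": False}
    return {
        "installed": True,
        "version": jax.__version__,
        "backend": jax.default_backend(),  # e.g. "cpu" or "gpu"
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    print(jax_report())
```

If you installed jax[cuda12] and run this on a GPU node, the backend should be "gpu" and the device list should show one entry per GPU allocated to your job.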