Training models


Marie-Hélène Burle

We talked about how Flax handles state, about loading data, and about model architecture. It is now time to talk about training.

Training models is the crux of deep learning. This is the part that requires a lot of time and resources (and money if you use commercial cloud services). This is also where issues with convergence, underfitting or overfitting, and vanishing or exploding gradients can come in.

Consequently, this is where optimizations and JAX’s performance tricks (e.g. JIT compilation) matter the most. This is also where an understanding of deep learning theory is important.

In this section, we will point to strategies and resources to navigate training. We will also see how to use the Alliance clusters to train your models.

First, let’s start an interactive job:

salloc --time=30 --mem-per-cpu=3500M

Nowadays, IPython (Interactive Python) is best known as the kernel used by Jupyter when running Python. It was however created before Jupyter existed, as a better interactive shell than the default Python shell. For interactive Python sessions in the command line, it is nicer and faster to work with than plain Python, so we will use it for this course:

module load ipython-kernel/3.11

Now, let’s activate the Python virtual environment:

source /project/60055/env/bin/activate

Finally, we can launch IPython:

ipython


Then let’s rerun our model architecture and the initialization of the pytree of model parameters from the last course:

import jax
import jax.numpy as jnp
from flax import linen as nn

class CNN(nn.Module):
    # @nn.compact is required to define submodules inline in __call__
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x

cnn = CNN()

def get_initial_params(key):
    init_shape = jnp.ones((1, 28, 28, 1))
    initial_params = cnn.init(key, init_shape)
    return initial_params

key = jax.random.key(0)
key, model_key = jax.random.split(key)

params = get_initial_params(model_key)

Fundamental functioning

Calculate predictions

We can create some random inputs:

key, x_key = jax.random.split(key)

x = jax.random.normal(x_key, (1, 28, 28, 1))

The predictions of our model based on these inputs are obtained by:

y = cnn.apply(params, x)

Update parameters

Optax, another library built on JAX, is a full toolkit for gradient processing and optimization. It contains all the classic optimizers and loss functions and makes it easy to create your own optimizers and optimizer schedulers. Flax initially used its own optimizers but has now fully adopted Optax.

Here is the most basic case:

import optax

learning_rate = 1e-1
optimiser = optax.sgd(learning_rate)

The optimizer is a gradient transformation: a named tuple of two pure functions, init and update. Following the model of JAX and Flax, these functions are stateless, so a state needs to be initialized and passed as input, exactly as we saw for Flax models.

Let’s initialize the optimizer state:

optimiser_state = optimiser.init(params)

The update method returns the transformed gradients (the updates, which we can later apply to the parameters) and an updated optimizer state.

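The gradients are calculated by passing a loss function to jax.grad, then calling the transformed function (the derivative of the loss with respect to its first argument, here the parameters) on the parameters, the inputs, and the targets. Note that in actual training, y must hold the targets (the labels) rather than the model’s own predictions computed above: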

grads = jax.grad(<some-loss-function>)(params, x, y)

The loss function can be built from the large collection of loss functions provided by Optax.

Here is how to use optimiser.update:

updates, optimiser_state = optimiser.update(grads, optimiser_state, params)

Key regularizations

Flax makes it easy to apply classic regularizations and optimization techniques.

Batch normalization improves convergence speed and has been a classic regularization technique since the publication of Sergey Ioffe and Christian Szegedy’s key paper in 2015. You can use it by adding a flax.linen.BatchNorm layer to your model.

Similarly, dropout techniques are implemented with a flax.linen.Dropout layer.

Getting started

The best way to get started building your own model is to go over the examples provided as templates by Flax. They all follow the same format, making it easy to clone and modify them. Some of them can even be modified directly in Google Colab, making experimentation easy without having to install anything.

Note however that things are not as simple as the documentation makes them appear and some of the examples will not run for various reasons (dependency problems, errors in the code, etc.).

Let’s check this structure and look at a few models.

Then let’s run the ogbg-molpcba example together in Google Colab to have access to a free GPU.

Running Flax examples in the Alliance clusters

Instead of running these examples in Google Colab, you might want to run them on the Alliance clusters, particularly as you start developing your own model (rather than just running examples to learn techniques).

First, you need to get the model you are interested in to the cluster.

There are many ways you could go about it, but one option is to download the directory of that particular model to your machine as a zip file using one of the several websites that make this easy.

For the ogbg-molpcba example, you paste the link “” in the site.

You can then copy it to the cluster with:

scp <path-to-zip-file-on-your-machine> <user-name>@<hostname>:

It will look something like this (make sure to rename the zip file to remove the spaces or to quote the path):


Then you could run it using JupyterLab, but a more efficient method is to use sbatch.

Create a script:

#!/bin/bash
#SBATCH --account=def-<your_account>
#SBATCH --time=xxx
#SBATCH --mem-per-cpu=xxx
#SBATCH --cpus-per-task=xxx
#SBATCH --job-name="<your_job>"

# Setup
module load python/3.11.5
source ~/env/bin/activate
python -m pip install --upgrade pip --no-index
python -m pip install -r requirements.txt --no-index

# Run example
python --workdir=./ogbg_molpcba --config=configs/

And run the script:

sbatch <your_job>.sh