Training models
We talked about how Flax handles state, about loading data, and about model architecture. It is now time to talk about training.
Training models is the crux of deep learning. This is the part that requires a lot of time and resources (and money if you use commercial cloud services). It is also where issues with convergence, underfitting or overfitting, and vanishing or exploding gradients arise.
Consequently, this is where optimizations and JAX’s performance tricks (e.g. JIT compilation) matter the most, and where an understanding of deep learning theory is important.
In this section, we will point to strategies and resources to navigate training. We will also see how to use the Alliance clusters to train your models.
Fundamentals
Calculate predictions
We can create some random inputs:
key, x_key = jax.random.split(key)
x = jax.random.normal(x_key, (1, 28, 28, 1))
The predictions of our model based on these inputs are obtained by:
y = cnn.apply(params, x)
print(y)
Update parameters
Optax, another library built on JAX, is a full toolkit for gradient processing and optimization. It contains all the classic optimizers and loss functions and makes it easy to create your own optimizers and optimizer schedulers. Flax initially used its own optimizers but has now fully adopted Optax.
Here is the most basic case:
import optax
learning_rate = 1e-1
optimiser = optax.sgd(learning_rate)
print(optimiser)
The optimizer is a gradient transformation: a tuple of an init method and an update method. Both are pure functions following the model of JAX and Flax. This means that they are stateless and that a state needs to be initialized and passed as input, exactly as we saw for Flax models.
Let’s initialize the optimizer state:
optimiser_state = optimiser.init(params)
The update method returns the transformed gradients (the updates, which we can later apply to the parameters) and an updated optimizer state.
The gradients are calculated by passing a loss function to jax.grad and calling the resulting function (the derivative) with the parameters, the inputs, and the predictions:
grads = jax.grad(<some-loss-function>)(params, x, y)
The loss function can be built from Optax’s large collection of loss functions.
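For instance, assuming the model outputs raw logits and the targets are integer class labels, a cross-entropy loss for our CNN could be sketched as follows (the name loss_fn and the labels argument are illustrative, not part of the code above):

def loss_fn(params, x, labels):
    # Forward pass through the model (assumed to return raw logits)
    logits = cnn.apply(params, x)
    # Mean cross-entropy between the logits and the integer class labels
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()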
Here is how to use optimiser.update:
updates, optimiser_state = optimiser.update(grads, optimiser_state, params)
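The updates can then be applied to the parameters with optax.apply_updates. Here is a minimal sketch of that last step, and of how the whole sequence is commonly bundled into a single jitted training step (loss_fn stands for whichever loss function you defined above):

# Apply the transformed gradients to the parameters
params = optax.apply_updates(params, updates)

# Common pattern: one full training step as a single jitted function
@jax.jit
def train_step(params, optimiser_state, x, labels):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, labels)
    updates, optimiser_state = optimiser.update(grads, optimiser_state, params)
    params = optax.apply_updates(params, updates)
    return params, optimiser_state, loss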
Key regularizations
Flax makes it easy to apply classic regularizations and optimization techniques.
Batch normalization improves convergence speed and has been a classic regularization technique since the publication of Sergey Ioffe and Christian Szegedy’s key paper in 2015. You can use it by adding a flax.linen.BatchNorm layer to your model.
Similarly, dropout is implemented with a flax.linen.Dropout layer.
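As a minimal sketch of where these layers fit in a model (the module name and layer sizes here are purely illustrative), note that both need to know whether the model is training, since they behave differently at inference time:

import flax.linen as nn

class RegularizedCNN(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool = True):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        # Uses batch statistics during training, running averages at inference
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        # Dropout is only active during training
        x = nn.Dropout(rate=0.5, deterministic=not train)(x)
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=10)(x)
        return x

When applying such a model, the batch statistics have to be passed as a mutable collection (mutable=['batch_stats']) and a 'dropout' RNG must be provided to apply; see the Flax documentation for the full pattern.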
Getting started
The best way to get started building your own model is to go over the examples provided as templates by Flax. They all follow the same format, making it easy to clone and modify them. Some of them can even be modified directly in Google Colab, so you can experiment without installing anything.
Note however that things are not as simple as the documentation makes them appear: some of the examples will not run for various reasons (dependency problems, errors in the code, etc.).
Let’s look at this shared structure and at a few of the models.
Then let’s run the ogbg-molpcba example together in Google Colab to have access to a free GPU.
Running Flax examples on the Alliance clusters
Instead of running these examples in Google Colab, you might want to run them on the Alliance clusters, particularly as you start developing your own model (rather than just running examples to learn techniques).
First, you need to get the model you are interested in to the cluster.
There are many ways you could go about it, but one option is to download the directory of that particular model to your machine as a zip file using one of the several sites that make this easy.
For the ogbg-molpcba example, you paste the link “https://github.com/google/flax/tree/main/examples/ogbg_molpcba” into the site.
You can then copy it to the cluster with:
scp <path-to-zip-file-on-your-machine> <user-name>@<hostname>:
It will look something like this (make sure to rename the zip file to remove any spaces, or quote the path):
scp examples-ogbg_molpcba.zip userxx@xxx.c3.ca:
Then you could run it using JupyterLab, but a more efficient method is to use sbatch.
Create a script:
<your_job>.sh
#!/bin/bash
#SBATCH --account=def-<your_account>
#SBATCH --time=xxx
#SBATCH --mem-per-cpu=xxx
#SBATCH --cpus-per-task=xxx
#SBATCH --job-name="<your_job>"
# Setup
module load python/3.11.5
source ~/env/bin/activate
python -m pip install --upgrade pip --no-index
python -m pip install -r requirements.txt --no-index
# Run example
python main.py --workdir=./ogbg_molpcba --config=configs/default.py
And run the script:
sbatch <your_job>.sh