Training

Author

Marie-Hélène Burle

It is now time to train our model.

In this section, we will cover how to:

  • set the training hyperparameters,
  • define the training and evaluation steps,
  • run the actual training loop,
  • create checkpoints,
  • plot the training metrics, and finally
  • test the trained model on a few data samples.
base_dir = '<path-of-the-nabirds-dir>'

Replace this with the actual path. On our training cluster, base_dir is /project/def-sponsor00/nabirds:

base_dir = '/project/def-sponsor00/nabirds'
import os
import polars as pl
import imageio.v3 as iio
import grain.python as grain
from jax import random
import dm_pix as pix
import numpy as np
import jax
import jax.numpy as jnp
from flax import nnx
from transformers import FlaxViTForImageClassification


metadata = pl.read_parquet('metadata.parquet')
metadata_train = metadata.filter(pl.col('is_training_img') == 1)
metadata_val = metadata.filter(pl.col('is_training_img') == 0)
cleaned_img_dir = os.path.join(base_dir, 'cleaned_images')


class NABirdsDataset:
    """NABirds dataset class."""

    def __init__(self, metadata_file, data_dir):
        self.metadata_file = metadata_file
        self.data_dir = data_dir

    def __len__(self):
        return len(self.metadata_file)

    def __getitem__(self, idx):
        path = os.path.join(self.data_dir, self.metadata_file.get_column('path')[idx])
        img = iio.imread(path)
        species_name = self.metadata_file.get_column('species_name')[idx]
        species_id = self.metadata_file.get_column('species_id')[idx]
        photographer = self.metadata_file.get_column('photographer')[idx]

        return {
            'img': img,
            'species_name': species_name,
            'species_id': species_id,
            'photographer': photographer,
        }


nabirds_train = NABirdsDataset(metadata_train, cleaned_img_dir)
nabirds_val = NABirdsDataset(metadata_val, cleaned_img_dir)
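
As a quick sanity check, we can pull a single sample from the dataset (the exact image shape depends on your cleaned images):

sample = nabirds_train[0]
print(sample['species_name'], sample['img'].shape, sample['img'].dtype)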


# Scale to [0, 1] and z-score in a single step (used for the validation data)
class Normalize(grain.MapTransform):
    def map(self, element):
        img = element['img']
        mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
        std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
        img = img.astype(np.float32) / 255.0
        img_norm = (img - mean) / std
        element['img'] = img_norm
        return element


# For the training data, scaling and z-scoring are kept separate so that
# the augmentations below operate on [0, 1] floats
class ToFloat(grain.MapTransform):
    def map(self, element):
        element['img'] = element['img'].astype(np.float32) / 255.0
        return element


# One subkey per random transform below. Note that these subkeys stay fixed
# for the whole run, so each transform applies the same (deterministic)
# randomness to every image.
key = random.key(31)
key, subkey1, subkey2, subkey3, subkey4 = random.split(key, num=5)


class RandomCrop(grain.MapTransform):
    def map(self, element):
        element['img'] = pix.random_crop(
            key=subkey1,
            image=element['img'],
            crop_sizes=(224, 224, 3)
        )
        return element


class RandomFlip(grain.MapTransform):
    def map(self, element):
        element['img'] = pix.random_flip_left_right(
            key=subkey2,
            image=element['img']
        )
        return element


class RandomContrast(grain.MapTransform):
    def map(self, element):
        element['img'] = pix.random_contrast(
            key=subkey3,
            image=element['img'],
            lower=0.8,
            upper=1.2
        )
        return element


class RandomGamma(grain.MapTransform):
    def map(self, element):
        element['img'] = pix.random_gamma(
            key=subkey4,
            image=element['img'],
            min_gamma=0.6,
            max_gamma=1.2
        )
        return element


class ZScore(grain.MapTransform):
    def map(self, element):
        img = element['img']
        mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
        std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
        img = (img - mean) / std
        element['img'] = img
        return element


seed = 123
train_batch_size = 32
val_batch_size = 2 * train_batch_size  # evaluation stores no gradients, so larger batches fit in memory

train_sampler = grain.IndexSampler(
    num_records=len(nabirds_train),
    shuffle=True,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1
)

train_loader = grain.DataLoader(
    data_source=nabirds_train,
    sampler=train_sampler,
    operations=[
        ToFloat(),
        RandomCrop(),
        RandomFlip(),
        RandomContrast(),
        RandomGamma(),
        ZScore(),
        grain.Batch(train_batch_size, drop_remainder=True)
    ]
)

val_sampler = grain.IndexSampler(
    num_records=len(nabirds_val),
    shuffle=False,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1
)

val_loader = grain.DataLoader(
    data_source=nabirds_val,
    sampler=val_sampler,
    operations=[
        Normalize(),
        grain.Batch(val_batch_size)
    ]
)

class VisionTransformer(nnx.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        in_channels: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_dim: int = 3072,
        hidden_size: int = 768,
        dropout_rate: float = 0.1,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        n_patches = (img_size // patch_size) ** 2  # (224 // 16) ** 2 = 196 with the defaults
        self.patch_embeddings = nnx.Conv(
            in_channels,
            hidden_size,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding='VALID',
            use_bias=True,
            rngs=rngs,
        )

        initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
        self.position_embeddings = nnx.Param(
            initializer(rngs.params(), (1, n_patches + 1, hidden_size), jnp.float32)
        )
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)
        self.cls_token = nnx.Param(jnp.zeros((1, 1, hidden_size)))
        self.encoder = nnx.Sequential(*[
            TransformerEncoder(hidden_size, mlp_dim, num_heads, dropout_rate, rngs=rngs)
            for i in range(num_layers)
        ])
        self.final_norm = nnx.LayerNorm(hidden_size, rngs=rngs)
        self.classifier = nnx.Linear(hidden_size, num_classes, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        patches = self.patch_embeddings(x)
        batch_size = patches.shape[0]
        patches = patches.reshape(batch_size, -1, patches.shape[-1])
        cls_token = jnp.tile(self.cls_token, [batch_size, 1, 1])
        x = jnp.concat([cls_token, patches], axis=1)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        x = self.encoder(embeddings)
        x = self.final_norm(x)
        x = x[:, 0]

        return self.classifier(x)


class TransformerEncoder(nnx.Module):
    def __init__(
        self,
        hidden_size: int,
        mlp_dim: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ) -> None:
        self.norm1 = nnx.LayerNorm(hidden_size, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=hidden_size,
            dropout_rate=dropout_rate,
            broadcast_dropout=False,
            decode=False,
            deterministic=False,
            rngs=rngs,
        )
        self.norm2 = nnx.LayerNorm(hidden_size, rngs=rngs)

        self.mlp = nnx.Sequential(
            nnx.Linear(hidden_size, mlp_dim, rngs=rngs),
            nnx.gelu,
            nnx.Dropout(dropout_rate, rngs=rngs),
            nnx.Linear(mlp_dim, hidden_size, rngs=rngs),
            nnx.Dropout(dropout_rate, rngs=rngs),
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


model = VisionTransformer(num_classes=1000)
tf_model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
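
Before copying the pretrained weights, we can pass a dummy batch through the freshly initialized model as a quick shape check (a sketch; Flax convolutions expect NHWC inputs):

# Dummy forward pass with random weights
dummy_batch = jnp.ones((1, 224, 224, 3), dtype=jnp.float32)
print(model(dummy_batch).shape)  # should be (1, 1000)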


def vit_inplace_copy_weights(*, src_model, dst_model):
    assert isinstance(src_model, FlaxViTForImageClassification)
    assert isinstance(dst_model, VisionTransformer)

    tf_model_params = src_model.params
    tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)

    flax_model_params = nnx.state(dst_model, nnx.Param)
    flax_model_params_fstate = dict(flax_model_params.flat_state())

    params_name_mapping = {
        ('cls_token',): ('vit', 'embeddings', 'cls_token'),
        ('position_embeddings',): ('vit', 'embeddings', 'position_embeddings'),
        **{
            ('patch_embeddings', x): ('vit', 'embeddings', 'patch_embeddings', 'projection', x)
            for x in ['kernel', 'bias']
        },
        **{
            ('encoder', 'layers', i, 'attn', y, x): (
                'vit', 'encoder', 'layer', str(i), 'attention', 'attention', y, x
            )
            for x in ['kernel', 'bias']
            for y in ['key', 'value', 'query']
            for i in range(12)
        },
        **{
            ('encoder', 'layers', i, 'attn', 'out', x): (
                'vit', 'encoder', 'layer', str(i), 'attention', 'output', 'dense', x
            )
            for x in ['kernel', 'bias']
            for i in range(12)
        },
        **{
            ('encoder', 'layers', i, 'mlp', 'layers', y1, x): (
                'vit', 'encoder', 'layer', str(i), y2, 'dense', x
            )
            for x in ['kernel', 'bias']
            for y1, y2 in [(0, 'intermediate'), (3, 'output')]
            for i in range(12)
        },
        **{
            ('encoder', 'layers', i, y1, x): (
                'vit', 'encoder', 'layer', str(i), y2, x
            )
            for x in ['scale', 'bias']
            for y1, y2 in [('norm1', 'layernorm_before'), ('norm2', 'layernorm_after')]
            for i in range(12)
        },
        **{
            ('final_norm', x): ('vit', 'layernorm', x)
            for x in ['scale', 'bias']
        },
        **{
            ('classifier', x): ('classifier', x)
            for x in ['kernel', 'bias']
        }
    }

    nonvisited = set(flax_model_params_fstate.keys())

    for key1, key2 in params_name_mapping.items():
        assert key1 in flax_model_params_fstate, key1
        assert key2 in tf_model_params_fstate, (key1, key2)

        nonvisited.remove(key1)

        src_value = tf_model_params_fstate[key2]
        if key2[-1] == 'kernel' and key2[-2] in ('key', 'value', 'query'):
            shape = src_value.shape
            src_value = src_value.reshape((shape[0], 12, 64))

        if key2[-1] == 'bias' and key2[-2] in ('key', 'value', 'query'):
            src_value = src_value.reshape((12, 64))

        if key2[-4:] == ('attention', 'output', 'dense', 'kernel'):
            shape = src_value.shape
            src_value = src_value.reshape((12, 64, shape[-1]))

        dst_value = flax_model_params_fstate[key1]
        assert src_value.shape == dst_value.value.shape, (key2, src_value.shape, key1, dst_value.value.shape)
        dst_value.value = src_value.copy()
        assert dst_value.value.mean() == src_value.mean(), (dst_value.value, src_value.mean())

    assert len(nonvisited) == 0, nonvisited
    # Notice the use of `flax.nnx.update` and `flax.nnx.State`.
    nnx.update(dst_model, nnx.State.from_flat_path(flax_model_params_fstate))

vit_inplace_copy_weights(src_model=tf_model, dst_model=model)

model.classifier = nnx.Linear(model.classifier.in_features, 405, rngs=nnx.Rngs(0))
Running this prints a deprecation warning from Transformers (safe to ignore here):

TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We recommend migrating to PyTorch classes or pinning your version of Transformers.

Hyperparameters

We need to define the hyperparameters that will control the training process.

Epochs

First, we need to set the number of epochs. While fine-tuning a ViT model for classification, we can expect the following:

  • Epochs 1-5: Massive improvements. The model adapts from “identifying objects” to “identifying birds.”
  • Epochs 5-15: Small, incremental gains.
  • Epochs 15+: Often starts overfitting (unless you have a massive dataset or very strong augmentation).

Let’s do 3 epochs:

num_epochs = 3

Learning rate

The learning rate controls the size of the steps the optimizer takes in the direction of the gradient. You don’t want to overshoot the minimum, so it is a good idea to decrease the learning rate during training. We are doing this with an Optax scheduler, reducing it linearly from 0.001 to 0.

As we saw previously, if you reduce (or increase) the batch size dramatically, you should also reduce (or increase) the learning rate. Since we went from a batch size of 32 to 8, we shouldn’t start with a learning rate that is too high.
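
A common heuristic for this is the linear scaling rule: scale the learning rate proportionally to the batch size change. A minimal sketch (illustrative only; we keep 0.001 below):

# Linear scaling rule (a heuristic, not an Optax function):
base_lr, base_batch_size, new_batch_size = 0.001, 32, 8
scaled_lr = base_lr * new_batch_size / base_batch_size  # 0.00025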

import optax

learning_rate = 0.001
total_steps = len(nabirds_train) // train_batch_size  # steps per epoch
lr_schedule = optax.linear_schedule(learning_rate, 0.0, num_epochs * total_steps)
iterate_subsample = np.linspace(0, num_epochs * total_steps, 100)

We can plot the learning rate schedule:

import matplotlib.pyplot as plt

plt.plot(
    np.linspace(0, num_epochs, len(iterate_subsample)),
    [lr_schedule(i) for i in iterate_subsample],
    lw=3,
)
plt.title('Learning rate')
plt.xlabel('Epochs')
plt.ylabel('Learning rate')
plt.grid()
plt.xlim((0, num_epochs))
plt.show()

Momentum

Momentum controls the inertia of the optimizer: it helps accelerate the optimizer in the right direction and dampens oscillations. Instead of using only the current gradient to update the weights, momentum adds a fraction of the previous update to the current one. So if the gradient keeps pointing in the same direction, the updates speed up; if the gradient bounces back and forth, the updates slow down.
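
In code, the plain momentum update rule looks like this (an illustration only; Optax implements this for us, and with nesterov=True it uses a lookahead variant):

# Hand-written SGD with momentum (illustration, not used in training)
lr, m = 0.01, 0.9
w, v = 1.0, 0.0
for g in [0.5, 0.5, -0.5]:
    v = m * v + g   # consistent gradients build up speed
    w = w - lr * v  # oscillating gradients partly cancel out in v
    print(w, v)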

A momentum of 0.9 is pretty standard:

momentum = 0.9

Optimizer

Finally, we pass our learning rate schedule and momentum to a stochastic gradient descent function from Optax and create our optimizer:

optimizer = nnx.ModelAndOptimizer(model, optax.sgd(lr_schedule, momentum, nesterov=True))

Loss function

Our loss is the softmax cross-entropy between the logits and the integer species labels:

def compute_losses_and_logits(model: nnx.Module, imgs: jax.Array, species: jax.Array):
    logits = model(imgs)

    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=species
    ).mean()
    return loss, logits

Training and evaluation steps

@nnx.jit              # To JIT compile and automatically use GPU/TPU if available
def train_step(
    model: nnx.Module, optimizer: nnx.Optimizer, imgs: np.ndarray, species_id: np.ndarray
):
    # Convert np.ndarray to jax.Array on GPU
    imgs = jnp.array(imgs)
    species = jnp.array(species_id, dtype=jnp.int32)

    grad_fn = nnx.value_and_grad(compute_losses_and_logits, has_aux=True)
    (loss, logits), grads = grad_fn(model, imgs, species)

    optimizer.update(grads)  # In-place updates.

    return loss

@nnx.jit
def eval_step(
    model: nnx.Module, eval_metrics: nnx.MultiMetric, imgs: np.ndarray, species_id: np.ndarray
):
    # Convert np.ndarray to jax.Array on GPU
    imgs = jnp.array(imgs)
    species = jnp.array(species_id, dtype=jnp.int32)
    loss, logits = compute_losses_and_logits(model, imgs, species)

    eval_metrics.update(
        loss=loss,
        logits=logits,
        labels=species,
    )

Metrics

During evaluation we track the average loss and the accuracy, and we keep a history of both training and validation metrics for plotting:

eval_metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average('loss'),
    accuracy=nnx.metrics.Accuracy(),
)

train_metrics_history = {
    'train_loss': [],
}

eval_metrics_history = {
    'val_loss': [],
    'val_accuracy': [],
}

Training and evaluation functions

import tqdm

bar_format = '{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]'

def train_one_epoch(epoch):
    model.train()  # Set model to training mode: e.g. update batch statistics
    with tqdm.tqdm(
        desc=f"[train] epoch: {epoch + 1}/{num_epochs}, ",
        total=total_steps,
        bar_format=bar_format,
        leave=True,
    ) as pbar:
        for batch in train_loader:
            loss = train_step(model, optimizer, batch['img'], batch['species_id'])
            train_metrics_history['train_loss'].append(loss.item())
            pbar.set_postfix({'loss': loss.item()})
            pbar.update(1)

def evaluate_model(epoch):
    # Compute the metrics on the validation set after each training epoch.
    model.eval()  # Set model to evaluation mode: e.g. use stored batch statistics.

    eval_metrics.reset()  # Reset the eval metrics
    for val_batch in val_loader:
        eval_step(model, eval_metrics, val_batch['img'], val_batch['species_id'])

    for metric, value in eval_metrics.compute().items():
        eval_metrics_history[f'val_{metric}'].append(value)

    print(f"[val] epoch: {epoch + 1}/{num_epochs}")
    print(f"- total loss: {eval_metrics_history['val_loss'][-1]:0.4f}")
    print(f"- Accuracy: {eval_metrics_history['val_accuracy'][-1]:0.4f}")

Checkpointing

Checkpointing is essential if you train for a long time: it saves you from losing hours, days, or weeks of training if something goes wrong (cluster issue, power outage, computer failure, interrupted training…).

Orbax provides a checkpointing management API for JAX:

import orbax.checkpoint as ocp

First of all, it is important to create a directory for the checkpoints and to ensure that this directory is empty before you start training. Otherwise the checkpoints won’t be saved, and you will only get an error about it after the training has run, which means losing your entire training since the trained parameters were never saved.

Let’s create a path for a directory called “checkpoints” in the current directory and ensure that it is empty:

path = ocp.test_utils.erase_and_create_empty('/project/def-sponsor00/nabirds/checkpoints/')

Note that the path provided must be an absolute path.

Then you set the options you want for the checkpoint manager, such as how often to save checkpoints and how many to keep in total (as new checkpoints are saved, earlier ones can safely be deleted).

For instance, if we want to save a checkpoint every 2 steps and keep the last 3 checkpoints, we can set the options with:

options = ocp.CheckpointManagerOptions(save_interval_steps=2, max_to_keep=3)

Here, we will save a checkpoint at every step (the default) since we are only running 3 epochs:

options = ocp.CheckpointManagerOptions(max_to_keep=3)

Now we can define our checkpoint manager:

mngr = ocp.CheckpointManager(path, options=options)

Past this point, if you follow the documentation, you will get the following warning when you try to restore your model from checkpoints:

WARNING:absl:Item "default" was found in the checkpoint, but could not be restored. Please provide a `CheckpointHandlerRegistry`, or call `restore` with an appropriate `CheckpointArgs` subclass.

This is because Orbax does not yet handle the new JAX PRNG key format.

A way around this is to create a function to save checkpoints that will convert PRNG keys from the new to the old format (function taken from this tutorial):

def save_model(epoch):
    # Get all params, statistics, RNGs, etc. from model:
    state = nnx.state(model)
    # Convert PRNG keys to the old format:
    def get_key_data(x):
        if isinstance(x, jax._src.prng.PRNGKeyArray):
            if isinstance(x.dtype, jax._src.prng.KeyTy):
                return jax.random.key_data(x)
        return x

    serializable_state = jax.tree.map(get_key_data, state)
    mngr.save(epoch, args=ocp.args.StandardSave(serializable_state))
    # Block until all save operations have finished
    # (only needed when checkpoints are saved asynchronously)
    mngr.wait_until_finished()

Now, to create checkpoints following our management plan, we will add this function to the training loop.

Testing and debugging

It is a good idea to try your code on a very small problem to make sure it runs before launching your training loop: many scripts only fail at the very end, and if you start by running 25 epochs on your full dataset without prior testing, you may wait (and sit on resources) for days before realizing that your code has issues.

An easy way to test your code on a small subset of your data is to play with the samplers of the data loaders: you only need to run the code on a handful of records to debug problems such as the checkpointer not working (and the trained model not being saved!).

So let’s replace num_records=len(nabirds_train) and num_records=len(nabirds_val) by num_records=10 for both the training and validation samplers, then recreate the training and validation loaders (one caveat: with drop_remainder=True, num_records must cover at least one full batch, or the training loader will yield no batches at all):

train_sampler = grain.IndexSampler(
    num_records=10,
    shuffle=True,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1
)

train_loader = grain.DataLoader(
    data_source=nabirds_train,
    sampler=train_sampler,
    operations=[
        ToFloat(),
        RandomCrop(),
        RandomFlip(),
        RandomContrast(),
        RandomGamma(),
        ZScore(),
        grain.Batch(train_batch_size, drop_remainder=True)
    ]
)

val_sampler = grain.IndexSampler(
    num_records=10,
    shuffle=False,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1
)

val_loader = grain.DataLoader(
    data_source=nabirds_val,
    sampler=val_sampler,
    operations=[
        Normalize(),
        grain.Batch(val_batch_size)
    ]
)

Now we can run a test loop:

for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)
    save_model(epoch)

… and debug as needed.

Once you are sure that the code runs, your checkpoints are being created as planned and all looks good, you can revert the values of the samplers and loaders:

train_sampler = grain.IndexSampler(
    num_records=len(nabirds_train),
    shuffle=True,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1
)

train_loader = grain.DataLoader(
    data_source=nabirds_train,
    sampler=train_sampler,
    operations=[
        ToFloat(),
        RandomCrop(),
        RandomFlip(),
        RandomContrast(),
        RandomGamma(),
        ZScore(),
        grain.Batch(train_batch_size, drop_remainder=True)
    ]
)

val_sampler = grain.IndexSampler(
    num_records=len(nabirds_val),
    shuffle=False,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1
)

val_loader = grain.DataLoader(
    data_source=nabirds_val,
    sampler=val_sampler,
    operations=[
        Normalize(),
        grain.Batch(val_batch_size)
    ]
)

Training loop

Time to run the training loop:

%%time

for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)
    save_model(epoch)
[train] epoch: 1/3, [2991/2991], loss=0.392 [37:37<00:00]
[val] epoch: 1/3
- total loss: nan
- Accuracy: 0.8144
[train] epoch: 2/3, [2991/2991], loss=0.225 [39:31<00:00]
[val] epoch: 2/3
- total loss: nan
- Accuracy: 0.8545
[train] epoch: 3/3, [2991/2991], loss=0.224 [40:11<00:00]
[val] epoch: 3/3
- total loss: nan
- Accuracy: 0.8666

On my laptop with a dedicated GPU (Nvidia GeForce RTX 2060), each epoch takes about 40 min. On a desktop with a more powerful GPU, training takes about 10 min per epoch.

I am still investigating why my total loss on the validation set comes out as NaN (not a number).

Your turn:

Initially I ran into an out of memory (OOM) problem running this code on my machine:

ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 77463552 bytes.

What were my options at that point?

One option of course is to get more hardware and thus more VRAM (the GPU memory): you can move the problem to the Alliance clusters or to a commercial cloud service. But training a classification model is something that you should be able to do on a laptop with a dedicated GPU (which is what I have). How?

We mentioned the solution in an earlier section: reduce the batch size. Do so by dividing it by 2 until you stop running out of memory.

In my case, I tried to go from 32 to 16 and still ran out of memory, then I tried with a batch size of 8 and it worked.

So in the code above I replaced:

train_batch_size = 32

by:

train_batch_size = 8

and ran the rest as is.

Plot metrics

Let’s plot the training metrics.

Evolution of training loss:

plt.plot(train_metrics_history['train_loss'], label='Loss value during the training')
plt.legend()

The loss decreased in a very jagged fashion due to my small batch size. With more memory and a bigger batch size, the descent would have been smoother. To get a smoother descent with such small batches, I could have increased the momentum or accumulated the gradients over several batches before each optimizer update (gradient accumulation), as sketched below.
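
Optax supports gradient accumulation through optax.MultiSteps. A minimal sketch (assuming that averaging gradients over 4 batches of 8 roughly mimics a batch size of 32):

# Wrap the optimizer so that gradients are averaged over 4 batches
# before each parameter update:
tx = optax.MultiSteps(
    optax.sgd(lr_schedule, momentum, nesterov=True),
    every_k_schedule=4,
)
optimizer_accum = nnx.ModelAndOptimizer(model, tx)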

I also got some NaN (not a number) values for the loss at the start of training. This didn’t matter here: training kept going, the NaNs became less frequent, and they eventually disappeared. To avoid them, I could have started with a lower learning rate, for instance by adding a warmup phase, as sketched below.
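
One common way to do this is a warmup phase: start from a tiny learning rate and ramp it up before decaying. A sketch with an Optax schedule (the values are illustrative):

# Ramp up from 0 over the first 200 steps, then decay to 0:
lr_schedule_warmup = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=learning_rate,
    warmup_steps=200,
    decay_steps=num_epochs * total_steps,
)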

Evolution of validation loss and accuracy:

fig, axs = plt.subplots(1, 2, figsize=(8, 8))
axs[0].set_title('Loss value on validation set')
axs[0].plot(eval_metrics_history['val_loss'])
axs[1].set_title('Accuracy on validation set')
axs[1].plot(eval_metrics_history['val_accuracy'])

You can also use TensorBoard for this or—even better—experiment tracking tools such as MLflow that will allow you to compare various training experiments.
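
For instance, logging our metric histories to MLflow could look like this (a sketch, assuming MLflow is installed and configured):

import mlflow

with mlflow.start_run():
    mlflow.log_param('train_batch_size', train_batch_size)
    for epoch, acc in enumerate(eval_metrics_history['val_accuracy']):
        mlflow.log_metric('val_accuracy', float(acc), step=epoch)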

Test a few samples

We can now run inference on a few test images.

Select a subset of test images and their labels:

test_indices = [250, 500, 750, 1000]
test_images = jnp.array([nabirds_val[i]['img'] for i in test_indices])
expected_labels = [nabirds_val[i]['species_name'] for i in test_indices]
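
Note that the model was trained on normalized images while the dataset returns raw pixels, so we should apply the same preprocessing as the validation pipeline (a sketch reusing the mean and std of our Normalize transform; this assumes the cleaned images already have the size expected by the model):

# Same preprocessing as the Normalize transform (mean = std = 0.5):
test_images = (test_images.astype(jnp.float32) / 255.0 - 0.5) / 0.5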

Run the model to get the predictions for this subset:

model.eval()
preds = model(test_images)
probas = nnx.softmax(preds, axis=1)
pred_labels = probas.argmax(axis=1)

We need to define a translation function to get the species names from the species ids for the predictions:

def translator(df, species_id):
    species_name = df.unique(subset='species_id').filter(
        pl.col('species_id') == species_id
    ).select(pl.col('species_name')).item()

    return species_name

Let’s print the subset with their predicted vs expected labels:

num_samples = len(test_indices)

fig, axs = plt.subplots(1, num_samples, figsize=(7, 8))

for i in range(num_samples):
    img, expected_label = test_images[i], expected_labels[i]

    pred_label_id = pred_labels[i].item()
    pred_label_name = translator(metadata, pred_label_id)
    proba = probas[i, pred_label_id].item()
    if img.dtype == np.float32:
        # Rescale normalized float images back to [0, 255] for display
        img = ((img - img.min()) / (img.max() - img.min()) * 255.0).astype(np.uint8)

    plt.tight_layout()

    axs[i].set_title(
        f"""
        Expected: {expected_label}
        Predicted: {pred_label_name}
        p={proba:.2f}
        """,
        fontsize=6.5,
        linespacing=1.5
    )

    axs[i].axis('off')
    axs[i].imshow(img)

Restore from checkpoint

You can see how many checkpoints you have saved with the following:

# Steps of all checkpoints
mngr.all_steps()
[0, 1, 2]
# Step of the last checkpoint
mngr.latest_step()
2

To restore the last checkpoint:

restored_state = mngr.restore(mngr.latest_step())

Note that restore returns the saved state rather than a ready-to-use model.
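
The state still has to be merged back into an instance of our architecture. A minimal sketch, assuming the same model definition as above (the RNG keys were saved as raw key data by save_model, so they may need converting back with jax.random.wrap_key_data before the update):

# Rebuild the architecture, then overwrite its parameters with the
# checkpointed state (a sketch; nnx.update merges state into the model):
model = VisionTransformer(num_classes=1000)
model.classifier = nnx.Linear(model.classifier.in_features, 405, rngs=nnx.Rngs(0))
nnx.update(model, restored_state)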