Using the model for inference

Author

Marie-Hélène Burle

Unless you are testing a new method or participating in a deep learning competition, the ultimate reason you trained a classification model is probably that you want to use it for inference.

In this section, we cover how to save our model and how to use it for inference on any image.

How to save the code to a script

Throughout this course, we developed the code bit by bit. It is good practice to wrap it all up in a script. That will be useful for further training and if you move the script to a cluster to train on more powerful hardware.

We have to make a few changes to our code while we create the script:

  • Strip the code of anything unnecessary that you might have used during prototyping.

  • It doesn’t make sense to use tqdm anymore, so remove the corresponding code.

  • We can’t display the graphs anymore, so we save them to files with plt.savefig() (or remove all matplotlib code).

  • When we aren’t using IPython (directly or via Jupyter), we don’t have access to the built-in magic commands such as %%time to time the execution of a cell. Instead, after adding `from time import time` to our imports, we use the following snippet:

start = time()

<Code to time>

end = time()

print(f"Training took {round((end - start) / 60, 1)} minutes")

In this case, since it is the training that we want to time:

start = time()

for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)

end = time()

print(f"Training took {round((end - start) / 60, 1)} minutes")
  • Wrap the part of the code related to training in a main function to prevent training from starting automatically when importing the model definition in other scripts.

Our script

Following the above steps, here is what we get for our script:

nabirds_train.py
import os
from time import time

import dm_pix as pix
import grain.python as grain
import imageio.v3 as iio
import jax
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp
import polars as pl
import tqdm
from flax import nnx
from jax import random
from transformers import FlaxViTForImageClassification


class NABirdsDataset:
    """Random-access view over the NABirds metadata table.

    Implements ``__len__``/``__getitem__`` so it can serve as a
    data source for a ``grain.DataLoader``.
    """

    def __init__(self, metadata, data_dir):
        # `metadata` is a polars DataFrame; `data_dir` is the image root dir.
        self.metadata = metadata
        self.data_dir = data_dir

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        relative_path = self.metadata.get_column('path')[idx]
        image = iio.imread(os.path.join(self.data_dir, relative_path))
        return {
            'img': image,
            'species_name': self.metadata.get_column('species_name')[idx],
            'species_id': self.metadata.get_column('species_id')[idx],
            'photographer': self.metadata.get_column('photographer')[idx],
        }


class Normalize(grain.MapTransform):
    """Scale uint8 pixels to [0, 1] and z-score with per-channel mean=std=0.5.

    Used on the validation pipeline, where it combines what `ToFloat`
    and `ZScore` do on the training pipeline.
    """

    def map(self, element):
        mean = np.full(3, 0.5, dtype=np.float32)
        std = np.full(3, 0.5, dtype=np.float32)
        scaled = element['img'].astype(np.float32) / 255.0
        element['img'] = (scaled - mean) / std
        return element


class ToFloat(grain.MapTransform):
    """Convert uint8 pixel values to float32 in the [0, 1] range."""

    def map(self, element):
        img = element['img']
        element['img'] = img.astype(np.float32) / np.float32(255.0)
        return element


class RandomCrop(grain.RandomMapTransform):
    """Randomly crop each image to 224x224x3.

    Fix: the previous version used a fixed ``jax.random.key(0)``, so every
    image (and every epoch) received exactly the same crop — no real
    augmentation. ``grain.RandomMapTransform`` supplies a per-record
    ``np.random.Generator`` seeded deterministically by grain, which we fold
    into a JAX PRNG key so each record gets an independent crop.
    """

    def random_map(self, element, rng):
        key = jax.random.key(int(rng.integers(0, 2**31 - 1)))
        element['img'] = pix.random_crop(
            key=key,
            image=element['img'],
            crop_sizes=(224, 224, 3),
        )
        return element


class RandomFlip(grain.RandomMapTransform):
    """Randomly flip each image left-right with probability 0.5.

    Fix: the previous version used a fixed ``jax.random.key(1)``, so the
    flip decision was identical for every image. Using the per-record rng
    supplied by ``grain.RandomMapTransform`` makes each record's flip
    independent.
    """

    def random_map(self, element, rng):
        key = jax.random.key(int(rng.integers(0, 2**31 - 1)))
        element['img'] = pix.random_flip_left_right(
            key=key,
            image=element['img'],
        )
        return element


class RandomContrast(grain.RandomMapTransform):
    """Randomly scale image contrast by a factor in [0.8, 1.2].

    Fix: the previous version used a fixed ``jax.random.key(2)``, so the
    same contrast factor was applied to every image. The per-record rng
    from ``grain.RandomMapTransform`` restores genuine randomness.
    """

    def random_map(self, element, rng):
        key = jax.random.key(int(rng.integers(0, 2**31 - 1)))
        element['img'] = pix.random_contrast(
            key=key,
            image=element['img'],
            lower=0.8,
            upper=1.2,
        )
        return element


class RandomGamma(grain.RandomMapTransform):
    """Apply a random gamma correction with gamma in [0.6, 1.2].

    Fix: the previous version used a fixed ``jax.random.key(3)``, so the
    same gamma was applied to every image. The per-record rng from
    ``grain.RandomMapTransform`` gives each record an independent gamma.
    """

    def random_map(self, element, rng):
        key = jax.random.key(int(rng.integers(0, 2**31 - 1)))
        element['img'] = pix.random_gamma(
            key=key,
            image=element['img'],
            min_gamma=0.6,
            max_gamma=1.2,
        )
        return element


class ZScore(grain.MapTransform):
    """Standardize a [0, 1] image with per-channel mean 0.5 and std 0.5.

    Maps pixel values from [0, 1] into [-1, 1], matching the statistics
    the pretrained ViT expects.
    """

    def map(self, element):
        channel_mean = np.full(3, 0.5, dtype=np.float32)
        channel_std = np.full(3, 0.5, dtype=np.float32)
        element['img'] = (element['img'] - channel_mean) / channel_std
        return element


class TransformerEncoder(nnx.Module):
    """A single pre-norm Transformer encoder block (ViT style).

    Layout: LayerNorm -> multi-head self-attention -> residual add, then
    LayerNorm -> MLP (Linear, GELU, Dropout, Linear, Dropout) -> residual add.

    NOTE: attribute names (`norm1`, `attn`, `norm2`, `mlp`) are relied on by
    `vit_inplace_copy_weights` when mapping pretrained HF weights onto this
    module — do not rename them.
    """
    def __init__(
        self,
        hidden_size: int,
        mlp_dim: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ) -> None:
        self.norm1 = nnx.LayerNorm(hidden_size, rngs=rngs)
        # Self-attention; `deterministic=False` keeps attention dropout active
        # during training (controlled by `model.train()` / `model.eval()`).
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=hidden_size,
            dropout_rate=dropout_rate,
            broadcast_dropout=False,
            decode=False,
            deterministic=False,
            rngs=rngs,
        )
        self.norm2 = nnx.LayerNorm(hidden_size, rngs=rngs)
        # Position-wise feed-forward network with GELU and dropout.
        self.mlp = nnx.Sequential(
            nnx.Linear(hidden_size, mlp_dim, rngs=rngs),
            nnx.gelu,
            nnx.Dropout(dropout_rate, rngs=rngs),
            nnx.Linear(mlp_dim, hidden_size, rngs=rngs),
            nnx.Dropout(dropout_rate, rngs=rngs),
        )
    def __call__(self, x: jax.Array) -> jax.Array:
        # Residual connections around both sub-blocks (pre-norm formulation).
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class VisionTransformer(nnx.Module):
    """ViT-Base/16 image classifier.

    Patch-embeds the input with a strided convolution, prepends a learnable
    CLS token, adds learned position embeddings, runs the Transformer encoder
    stack, and classifies from the final CLS representation.

    Defaults match ViT-Base/16 (12 layers, 12 heads, hidden 768, MLP 3072)
    so pretrained 'google/vit-base-patch16-224' weights can be copied in.

    NOTE: attribute names (`patch_embeddings`, `position_embeddings`,
    `cls_token`, `encoder`, `final_norm`, `classifier`) are relied on by
    `vit_inplace_copy_weights` — do not rename them.
    """
    def __init__(
        self,
        num_classes: int = 1000,
        in_channels: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_dim: int = 3072,
        hidden_size: int = 768,
        dropout_rate: float = 0.1,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        # Number of non-overlapping patches per image (e.g. 14*14 = 196).
        n_patches = (img_size // patch_size) ** 2
        # Patch embedding as a stride=patch_size convolution.
        self.patch_embeddings = nnx.Conv(
            in_channels,
            hidden_size,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding='VALID',
            use_bias=True,
            rngs=rngs,
        )
        initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
        # Learned position embeddings for all patches plus the CLS token.
        self.position_embeddings = nnx.Param(
            initializer(rngs.params(), (1, n_patches + 1, hidden_size), jnp.float32)
        )
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)
        # Learnable classification token, prepended to the patch sequence.
        self.cls_token = nnx.Param(jnp.zeros((1, 1, hidden_size)))
        self.encoder = nnx.Sequential(*[
            TransformerEncoder(hidden_size, mlp_dim, num_heads, dropout_rate, rngs=rngs)
            for i in range(num_layers)
        ])
        self.final_norm = nnx.LayerNorm(hidden_size, rngs=rngs)
        self.classifier = nnx.Linear(hidden_size, num_classes, rngs=rngs)
    def __call__(self, x: jax.Array) -> jax.Array:
        # (B, H, W, C) -> (B, H/ps, W/ps, hidden) -> (B, n_patches, hidden)
        patches = self.patch_embeddings(x)
        batch_size = patches.shape[0]
        patches = patches.reshape(batch_size, -1, patches.shape[-1])
        # Broadcast the CLS token across the batch and prepend it.
        cls_token = jnp.tile(self.cls_token, [batch_size, 1, 1])
        x = jnp.concat([cls_token, patches], axis=1)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        x = self.encoder(embeddings)
        x = self.final_norm(x)
        # Classify from the CLS token only.
        x = x[:, 0]
        return self.classifier(x)


def compute_losses_and_logits(model: nnx.Module, imgs: jax.Array, species: jax.Array):
    """Forward pass returning (mean cross-entropy loss, raw logits).

    `species` holds integer class labels; the loss is averaged over the batch.
    """
    predictions = model(imgs)
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=predictions, labels=species
    )
    return per_example_loss.mean(), predictions


def vit_inplace_copy_weights(*, src_model, dst_model):
    """Copy pretrained HF ViT weights into our `VisionTransformer` in place.

    Builds an explicit mapping from our parameter paths to the HF
    `FlaxViTForImageClassification` parameter paths, reshapes the attention
    tensors to our per-head layout, and asserts that every destination
    parameter was visited exactly once.
    """
    assert isinstance(src_model, FlaxViTForImageClassification)
    assert isinstance(dst_model, VisionTransformer)
    tf_model_params = src_model.params
    tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)
    flax_model_params = nnx.state(dst_model, nnx.Param)
    flax_model_params_fstate = dict(flax_model_params.flat_state())
    # Keys: our parameter path -> HF parameter path. The integer ranges assume
    # the ViT-Base configuration (12 layers, 12 heads, 64 dims per head).
    params_name_mapping = {
        ('cls_token',): ('vit', 'embeddings', 'cls_token'),
        ('position_embeddings',): ('vit', 'embeddings', 'position_embeddings'),
        **{
            ('patch_embeddings', x): ('vit', 'embeddings', 'patch_embeddings', 'projection', x)
            for x in ['kernel', 'bias']
        },
        **{
            ('encoder', 'layers', i, 'attn', y, x): (
                'vit', 'encoder', 'layer', str(i), 'attention', 'attention', y, x
            )
            for x in ['kernel', 'bias']
            for y in ['key', 'value', 'query']
            for i in range(12)
        },
        **{
            ('encoder', 'layers', i, 'attn', 'out', x): (
                'vit', 'encoder', 'layer', str(i), 'attention', 'output', 'dense', x
            )
            for x in ['kernel', 'bias']
            for i in range(12)
        },
        **{
            ('encoder', 'layers', i, 'mlp', 'layers', y1, x): (
                'vit', 'encoder', 'layer', str(i), y2, 'dense', x
            )
            for x in ['kernel', 'bias']
            for y1, y2 in [(0, 'intermediate'), (3, 'output')]
            for i in range(12)
        },
        **{
            ('encoder', 'layers', i, y1, x): (
                'vit', 'encoder', 'layer', str(i), y2, x
            )
            for x in ['scale', 'bias']
            for y1, y2 in [('norm1', 'layernorm_before'), ('norm2', 'layernorm_after')]
            for i in range(12)
        },
        **{
            ('final_norm', x): ('vit', 'layernorm', x)
            for x in ['scale', 'bias']
        },
        **{
            ('classifier', x): ('classifier', x)
            for x in ['kernel', 'bias']
        }
    }
    # Track destination keys so we can assert full coverage at the end.
    nonvisited = set(flax_model_params_fstate.keys())
    for key1, key2 in params_name_mapping.items():
        assert key1 in flax_model_params_fstate, key1
        assert key2 in tf_model_params_fstate, (key1, key2)
        nonvisited.remove(key1)
        src_value = tf_model_params_fstate[key2]
        # HF stores fused QKV projections as 2-D matrices; our
        # MultiHeadAttention expects a (.., num_heads, head_dim) layout.
        if key2[-1] == 'kernel' and key2[-2] in ('key', 'value', 'query'):
            shape = src_value.shape
            src_value = src_value.reshape((shape[0], 12, 64))
        if key2[-1] == 'bias' and key2[-2] in ('key', 'value', 'query'):
            src_value = src_value.reshape((12, 64))
        # The attention output projection likewise needs the per-head split.
        if key2[-4:] == ('attention', 'output', 'dense', 'kernel'):
            shape = src_value.shape
            src_value = src_value.reshape((12, 64, shape[-1]))
        dst_value = flax_model_params_fstate[key1]
        assert src_value.shape == dst_value.value.shape, (key2, src_value.shape, key1, dst_value.value.shape)
        dst_value.value = src_value.copy()
        # Cheap sanity check that the copy actually landed.
        assert dst_value.value.mean() == src_value.mean(), (dst_value.value, src_value.mean())
    assert len(nonvisited) == 0, nonvisited
    # Notice the use of `flax.nnx.update` and `flax.nnx.State`.
    nnx.update(dst_model, nnx.State.from_flat_path(flax_model_params_fstate))


@nnx.jit
def train_step(
    model: nnx.Module, optimizer: nnx.Optimizer, imgs: np.ndarray, species_id: np.ndarray
):
    """One optimization step: forward, backward, in-place parameter update."""
    # Move the host batch onto the accelerator.
    device_imgs = jnp.asarray(imgs)
    labels = jnp.asarray(species_id, dtype=jnp.int32)
    grad_fn = nnx.value_and_grad(compute_losses_and_logits, has_aux=True)
    (loss, _logits), grads = grad_fn(model, device_imgs, labels)
    # Mutates model parameters and optimizer state in place.
    optimizer.update(grads)
    return loss


@nnx.jit
def eval_step(
    model: nnx.Module, eval_metrics: nnx.MultiMetric, imgs: np.ndarray, species_id: np.ndarray
):
    """Accumulate loss and accuracy for one validation batch into `eval_metrics`."""
    # Move the host batch onto the accelerator.
    device_imgs = jnp.asarray(imgs)
    labels = jnp.asarray(species_id, dtype=jnp.int32)
    loss, logits = compute_losses_and_logits(model, device_imgs, labels)
    eval_metrics.update(
        loss=loss,
        logits=logits,
        labels=labels,
    )


def main():
    """Fine-tune a pretrained ViT on the NABirds dataset.

    Loads metadata, builds train/validation pipelines, copies pretrained
    Hugging Face ViT weights into our model, swaps in a 405-class head,
    then trains, evaluates, and checkpoints the model once per epoch.

    Fixes vs. the notebook version: tqdm progress bars are removed (as the
    accompanying text instructs for the script version) along with the
    unused `bar_format` and leftover `random.key(31)` from prototyping;
    timing uses `from time import time`.
    """
    base_dir = 'nabirds'
    cleaned_img_dir = os.path.join(base_dir, 'cleaned_images')

    metadata = pl.read_parquet('metadata.parquet')
    metadata_train = metadata.filter(pl.col('is_training_img') == 1)
    metadata_val = metadata.filter(pl.col('is_training_img') == 0)

    nabirds_train = NABirdsDataset(metadata_train, cleaned_img_dir)
    nabirds_val = NABirdsDataset(metadata_val, cleaned_img_dir)

    seed = 123
    train_batch_size = 8
    val_batch_size = 2 * train_batch_size

    train_sampler = grain.IndexSampler(
        num_records=len(nabirds_train),
        shuffle=True,
        seed=seed,
        shard_options=grain.NoSharding(),
        num_epochs=1
    )

    train_loader = grain.DataLoader(
        data_source=nabirds_train,
        sampler=train_sampler,
        operations=[
            ToFloat(),
            RandomCrop(),
            RandomFlip(),
            RandomContrast(),
            RandomGamma(),
            ZScore(),
            grain.Batch(train_batch_size, drop_remainder=True)
        ]
    )

    val_sampler = grain.IndexSampler(
        num_records=len(nabirds_val),
        shuffle=False,
        seed=seed,
        shard_options=grain.NoSharding(),
        num_epochs=1
    )

    val_loader = grain.DataLoader(
        data_source=nabirds_val,
        sampler=val_sampler,
        operations=[
            Normalize(),
            grain.Batch(val_batch_size)
        ]
    )

    # Build our ViT and copy in pretrained ImageNet weights from Hugging Face.
    model = VisionTransformer(num_classes=1000)
    tf_model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
    vit_inplace_copy_weights(src_model=tf_model, dst_model=model)

    # Replace the 1000-class ImageNet head with a fresh 405-class NABirds head.
    model.classifier = nnx.Linear(model.classifier.in_features, 405, rngs=nnx.Rngs(0))

    num_epochs = 3
    learning_rate = 0.001
    momentum = 0.9
    total_steps = len(nabirds_train) // train_batch_size

    # Linearly decay the learning rate to 0 over the whole run.
    lr_schedule = optax.linear_schedule(learning_rate, 0.0, num_epochs * total_steps)

    optimizer = nnx.ModelAndOptimizer(model, optax.sgd(lr_schedule, momentum, nesterov=True))

    train_metrics_history = {
        'train_loss': [],
    }

    eval_metrics_history = {
        'val_loss': [],
        'val_accuracy': [],
    }

    eval_metrics = nnx.MultiMetric(
        loss=nnx.metrics.Average('loss'),
        accuracy=nnx.metrics.Accuracy(),
    )

    def train_one_epoch(epoch):
        # One pass over the training data, recording the per-step loss.
        model.train()  # Training mode: e.g. enable dropout.
        for batch in train_loader:
            loss = train_step(model, optimizer, batch['img'], batch['species_id'])
            train_metrics_history['train_loss'].append(loss.item())
        print(f"[train] epoch: {epoch + 1}/{num_epochs}, "
              f"last loss: {train_metrics_history['train_loss'][-1]:0.4f}")

    def evaluate_model(epoch):
        # Computes the metrics on the validation set after each training epoch.
        model.eval()  # Evaluation mode: e.g. disable dropout.
        eval_metrics.reset()
        for val_batch in val_loader:
            eval_step(model, eval_metrics, val_batch['img'], val_batch['species_id'])
        for metric, value in eval_metrics.compute().items():
            eval_metrics_history[f'val_{metric}'].append(value)
        print(f"[val] epoch: {epoch + 1}/{num_epochs}")
        print(f"- total loss: {eval_metrics_history['val_loss'][-1]:0.4f}")
        print(f"- Accuracy: {eval_metrics_history['val_accuracy'][-1]:0.4f}")

    path = ocp.test_utils.erase_and_create_empty('/project/def-sponsor00/nabirds/checkpoints/')
    options = ocp.CheckpointManagerOptions(max_to_keep=3)
    mngr = ocp.CheckpointManager(path, options=options)

    def save_model(epoch):
        # Checkpoint the full model state (params, stats, RNGs) for this epoch.
        state = nnx.state(model)

        # Orbax cannot serialize new-style PRNG key arrays directly;
        # convert them back to raw key data first.
        def get_key_data(x):
            if isinstance(x, jax._src.prng.PRNGKeyArray):
                if isinstance(x.dtype, jax._src.prng.KeyTy):
                    return jax.random.key_data(x)
            return x

        serializable_state = jax.tree.map(get_key_data, state)
        mngr.save(epoch, args=ocp.args.StandardSave(serializable_state))
        # Block the manager until all operations have finished running
        # (only useful for asynchronous (distributed) training).
        mngr.wait_until_finished()

    start = time()

    for epoch in range(num_epochs):
        train_one_epoch(epoch)
        evaluate_model(epoch)
        save_model(epoch)

    end = time()

    print(f"Training took {round((end - start) / 60, 1)} minutes")

if __name__ == '__main__':
    main()

Inference script

Now we need an inference script that we can use with any bird image. It needs to process the image to make it consistent with our model:

nabirds_infer.py
import os
import sys
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
from flax import nnx
import imageio.v3 as iio
import polars as pl


metadata_path = 'metadata.parquet'
checkpoint_path = '/project/def-sponsor00/nabirds/checkpoints/'


def load_species_mapping(metadata_path=metadata_path):
    """Loads species ID to name mapping from metadata."""
    if not os.path.exists(metadata_path):
        print(f"Metadata file not found at {metadata_path}")
        return {}

    unique_pairs = (
        pl.read_parquet(metadata_path)
        .select(['species_id', 'species_name'])
        .unique()
    )
    # species_id -> species_name lookup table
    return dict(unique_pairs.iter_rows())


def preprocess_image(image_path):
    """Read an image and shape it the way the model expects.

    Steps: force a 3-channel RGB layout, scale to [0, 1], resize to
    224x224, z-score with mean=std=0.5 (matching the training ZScore
    transform), and prepend a batch axis.

    Raises FileNotFoundError if `image_path` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    raw = iio.imread(image_path)

    # Force a 3-channel RGB layout.
    if raw.ndim == 2:
        # Grayscale: replicate the single channel three times.
        raw = np.repeat(raw[..., None], 3, axis=-1)
    elif raw.shape[-1] == 4:
        # RGBA: drop the alpha channel.
        raw = raw[..., :3]

    # To JAX, float32, scaled into [0, 1].
    img = jnp.array(raw).astype(jnp.float32) / 255.0

    # Match the model's expected input resolution.
    img = jax.image.resize(img, (224, 224, 3), method='bilinear')

    # Z-score normalization, same statistics as during training.
    mean = jnp.array([0.5, 0.5, 0.5])
    std = jnp.array([0.5, 0.5, 0.5])
    img = (img - mean) / std

    # Leading batch dimension of size 1.
    return img[None, ...]


def predict(image_path):
    """Restores model from checkpoint and runs prediction on a single image."""
    # Restore model from checkpoint.
    # NOTE(review): `CheckpointManager.restore()` called without restore args
    # returns the raw saved pytree (the serialized `nnx.State` written by the
    # training script's `save_model`), not a callable nnx module — so the
    # subsequent `model.eval()` / `model(img)` calls look broken. The usual
    # pattern is to rebuild the architecture (the `VisionTransformer` from the
    # training script, with the 405-class head), `nnx.split` it, restore into
    # the abstract state, then `nnx.merge`. Confirm against `save_model`.
    options = ocp.CheckpointManagerOptions(max_to_keep=3)
    mngr = ocp.CheckpointManager(checkpoint_path, options=options)
    model = mngr.restore(mngr.latest_step())
    model.eval()

    # species_id -> species_name lookup (empty dict if metadata is missing).
    mapping = load_species_mapping()

    print(f"Processing image: {image_path}")
    img = preprocess_image(image_path)

    # Inference: logits -> probabilities.
    logits = model(img)
    probs = nnx.softmax(logits)

    # Top-1 prediction (argmax over the flattened (1, num_classes) array).
    predicted_id = int(jnp.argmax(probs))
    confidence = float(jnp.max(probs))

    predicted_name = mapping.get(predicted_id, f"Unknown ID {predicted_id}")

    print("-" * 30)
    print(f"Prediction: {predicted_name}")
    print(f"Species ID: {predicted_id}")
    print(f"Confidence: {confidence:.2%}")
    print("-" * 30)

    # Top 5 classes by descending probability.
    top_k = 5
    top_indices = jnp.argsort(probs, descending=True)[0, :top_k]
    print(f"Top {top_k} predictions:")
    for idx in top_indices:
        idx = int(idx)
        score = float(probs[0, idx])
        name = mapping.get(idx, f"ID {idx}")
        print(f"  {name}: {score:.2%}")


if __name__ == "__main__":
    # Expect the image path as the single command-line argument.
    if len(sys.argv) < 2:
        print("Usage: uv run python nabirds_infer.py <path_to_bird_image>")
    else:
        predict(sys.argv[1])

Usage

From the command line and using uv, the script can be run with:

uv run nabirds_infer.py <path_to_bird_image>