Using the model for inference
Unless you were testing some new method or participating in a deep learning competition, the ultimate reason you trained a classification model is probably that you want to use it.
In this section, we cover how to save our model and how to use it for inference on any image.
How to save the code to a script
Throughout this course, we developed the code bits by bits. It is a good practice to wrap it all up in a script. That will be useful for further training and if you move the script to a cluster to train on more hardware.
We have to make a few changes to our code while we create the script:
Strip the code of anything unnecessary that you might have used during prototyping.
It doesn’t make sense to use tqdm anymore, so remove the corresponding code.
We can’t display the graphs anymore, so we save them to files with
plt.savefig() (or remove all matplotlib code).

When we aren't using IPython (directly or via Jupyter), we don't have access to the built-in magic commands such as %%time to time the execution of a cell. Instead, we use the following snippet:
start = time()
<Code to time>
end = time()
print(f"Training took {round((end - start) / 60, 1)} minutes")

In this case, since it is the training that we want to time:
start = time()
for epoch in range(num_epochs):
train_one_epoch(epoch)
evaluate_model(epoch)
end = time()
print(f"Training took {round((end - start) / 60, 1)} minutes")

- Wrap the part of the code related to training in a main function to prevent training from starting automatically when importing the model definition in other scripts.
Our script
Following the above steps, here is what we get for our script:
nabirds_train.py
# Standard library
import os
from time import time  # used by main() to time the training run

# Third-party
import dm_pix as pix
import grain.python as grain
import imageio.v3 as iio
import jax
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp
import polars as pl
import tqdm
from flax import nnx
from jax import random
from transformers import FlaxViTForImageClassification
class NABirdsDataset:
    """Map-style dataset over NABirds metadata rows.

    `metadata` is a polars-DataFrame-like object exposing `get_column`
    and `len()`; `data_dir` is the root folder holding the images.
    """

    def __init__(self, metadata, data_dir):
        self.metadata = metadata
        self.data_dir = data_dir

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        # Resolve the image path relative to the data directory and load it.
        rows = self.metadata
        img_path = os.path.join(self.data_dir, rows.get_column('path')[idx])
        return {
            'img': iio.imread(img_path),
            'species_name': rows.get_column('species_name')[idx],
            'species_id': rows.get_column('species_id')[idx],
            'photographer': rows.get_column('photographer')[idx],
        }
class Normalize(grain.MapTransform):
    """Scales uint8 pixels to [0, 1], then standardizes to roughly [-1, 1].

    Equivalent to applying ToFloat followed by ZScore in one step; used on
    the validation pipeline where no augmentation is needed.
    """

    # Per-channel statistics matching the training ZScore transform.
    _MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    _STD = np.array([0.5, 0.5, 0.5], dtype=np.float32)

    def map(self, element):
        scaled = element['img'].astype(np.float32) / 255.0
        element['img'] = (scaled - self._MEAN) / self._STD
        return element
class ToFloat(grain.MapTransform):
    """Converts uint8 pixel values to float32 in [0, 1]."""

    def map(self, element):
        img = element['img']
        element['img'] = img.astype(np.float32) / 255.0
        return element
class RandomCrop(grain.RandomMapTransform):
    """Randomly crops the image to 224x224x3.

    Fix: the original used a fixed `jax.random.key(0)` (marked as a
    placeholder), so the "random" crop was identical for every image and
    every epoch. `RandomMapTransform` gives grain a chance to supply a
    per-record `np.random.Generator`, from which a fresh JAX key is derived.
    """

    def random_map(self, element, rng: np.random.Generator):
        # Derive a fresh JAX PRNG key from grain's per-record generator.
        key = jax.random.key(int(rng.integers(0, 2**31 - 1)))
        element['img'] = pix.random_crop(
            key=key,
            image=element['img'],
            crop_sizes=(224, 224, 3),
        )
        return element
class RandomFlip(grain.RandomMapTransform):
    """Randomly flips the image left/right with probability 0.5.

    Fix: the original used a fixed `jax.random.key(1)`, so every image got
    the exact same flip decision. Using grain's per-record RNG restores
    actual randomness.
    """

    def random_map(self, element, rng: np.random.Generator):
        key = jax.random.key(int(rng.integers(0, 2**31 - 1)))
        element['img'] = pix.random_flip_left_right(
            key=key,
            image=element['img'],
        )
        return element
class RandomContrast(grain.RandomMapTransform):
    """Randomly scales contrast by a factor in [0.8, 1.2].

    Fix: the original used a fixed `jax.random.key(2)`, so every image got
    the exact same contrast factor. Using grain's per-record RNG restores
    actual randomness.
    """

    def random_map(self, element, rng: np.random.Generator):
        key = jax.random.key(int(rng.integers(0, 2**31 - 1)))
        element['img'] = pix.random_contrast(
            key=key,
            image=element['img'],
            lower=0.8,
            upper=1.2,
        )
        return element
class RandomGamma(grain.RandomMapTransform):
    """Applies a random gamma correction with gamma in [0.6, 1.2].

    Fix: the original used a fixed `jax.random.key(3)`, so every image got
    the exact same gamma. Using grain's per-record RNG restores actual
    randomness.
    """

    def random_map(self, element, rng: np.random.Generator):
        key = jax.random.key(int(rng.integers(0, 2**31 - 1)))
        element['img'] = pix.random_gamma(
            key=key,
            image=element['img'],
            min_gamma=0.6,
            max_gamma=1.2,
        )
        return element
class ZScore(grain.MapTransform):
    """Standardizes an already [0, 1]-scaled image to roughly [-1, 1].

    Applies (x - 0.5) / 0.5 per channel, matching the normalization the
    pretrained ViT expects.
    """

    def map(self, element):
        mean = np.full(3, 0.5, dtype=np.float32)
        std = np.full(3, 0.5, dtype=np.float32)
        element['img'] = (element['img'] - mean) / std
        return element
class TransformerEncoder(nnx.Module):
    """A single pre-norm ViT encoder block: self-attention + MLP, each residual.

    Args:
        hidden_size: Embedding dimension of the tokens.
        mlp_dim: Hidden width of the feed-forward sub-block.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability for attention and the MLP.
        rngs: RNG streams for parameter init and dropout. Defaults to a
            fresh ``nnx.Rngs(0)`` per instantiation.
    """

    def __init__(
        self,
        hidden_size: int,
        mlp_dim: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs | None = None,
    ) -> None:
        # Fix: the original default was `rngs: nnx.Rngs = nnx.Rngs(0)`.
        # nnx.Rngs is stateful (its stream counters advance as keys are
        # drawn), so a single instance created at `def` time would be shared
        # by every module built with the default — the classic
        # mutable-default-argument pitfall. Create a fresh one per call.
        if rngs is None:
            rngs = nnx.Rngs(0)
        self.norm1 = nnx.LayerNorm(hidden_size, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=hidden_size,
            dropout_rate=dropout_rate,
            broadcast_dropout=False,
            decode=False,
            deterministic=False,
            rngs=rngs,
        )
        self.norm2 = nnx.LayerNorm(hidden_size, rngs=rngs)
        self.mlp = nnx.Sequential(
            nnx.Linear(hidden_size, mlp_dim, rngs=rngs),
            nnx.gelu,
            nnx.Dropout(dropout_rate, rngs=rngs),
            nnx.Linear(mlp_dim, hidden_size, rngs=rngs),
            nnx.Dropout(dropout_rate, rngs=rngs),
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        # Pre-norm residual connections.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
class VisionTransformer(nnx.Module):
    """ViT-Base-style image classifier: patch embedding + CLS token + encoder.

    Args:
        num_classes: Size of the classification head.
        in_channels: Number of input image channels.
        img_size: Input image height/width in pixels.
        patch_size: Side length of the square patches.
        num_layers: Number of TransformerEncoder blocks.
        num_heads: Attention heads per block.
        mlp_dim: Feed-forward hidden width in each block.
        hidden_size: Token embedding dimension.
        dropout_rate: Dropout probability.
        rngs: RNG streams for parameter init and dropout. Defaults to a
            fresh ``nnx.Rngs(0)`` per instantiation.
    """

    def __init__(
        self,
        num_classes: int = 1000,
        in_channels: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_dim: int = 3072,
        hidden_size: int = 768,
        dropout_rate: float = 0.1,
        *,
        rngs: nnx.Rngs | None = None,
    ):
        # Fix: avoid the mutable-default pitfall — nnx.Rngs is stateful, so
        # a single default instance evaluated at `def` time would be shared
        # by every model built with the default.
        if rngs is None:
            rngs = nnx.Rngs(0)
        n_patches = (img_size // patch_size) ** 2
        # Patch embedding as a strided convolution (one patch per step).
        self.patch_embeddings = nnx.Conv(
            in_channels,
            hidden_size,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding='VALID',
            use_bias=True,
            rngs=rngs,
        )
        initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
        # Learned position embeddings for the CLS token plus every patch.
        self.position_embeddings = nnx.Param(
            initializer(rngs.params(), (1, n_patches + 1, hidden_size), jnp.float32)
        )
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)
        self.cls_token = nnx.Param(jnp.zeros((1, 1, hidden_size)))
        self.encoder = nnx.Sequential(*[
            TransformerEncoder(hidden_size, mlp_dim, num_heads, dropout_rate, rngs=rngs)
            for i in range(num_layers)
        ])
        self.final_norm = nnx.LayerNorm(hidden_size, rngs=rngs)
        self.classifier = nnx.Linear(hidden_size, num_classes, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        # (B, H, W, C) -> (B, H/P, W/P, hidden) -> (B, n_patches, hidden).
        patches = self.patch_embeddings(x)
        batch_size = patches.shape[0]
        patches = patches.reshape(batch_size, -1, patches.shape[-1])
        # Prepend the CLS token and add position embeddings.
        cls_token = jnp.tile(self.cls_token, [batch_size, 1, 1])
        x = jnp.concat([cls_token, patches], axis=1)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        x = self.encoder(embeddings)
        x = self.final_norm(x)
        # Classify from the CLS token only.
        x = x[:, 0]
        return self.classifier(x)
def compute_losses_and_logits(model: nnx.Module, imgs: jax.Array, species: jax.Array):
    """Runs the model on a batch and returns (mean cross-entropy loss, logits)."""
    logits = model(imgs)
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=species
    )
    return per_example_loss.mean(), logits
def vit_inplace_copy_weights(*, src_model, dst_model):
    """Copies pretrained HF Flax ViT weights into our nnx VisionTransformer.

    The two models store parameters under different path layouts (and with
    differently-shaped attention tensors), so every destination parameter
    path is explicitly mapped to its source path, reshaped where needed,
    and copied in place via `nnx.update`.
    """
    assert isinstance(src_model, FlaxViTForImageClassification)
    assert isinstance(dst_model, VisionTransformer)
    tf_model_params = src_model.params
    # Flatten both parameter trees into {path-tuple: value} dicts.
    tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)
    flax_model_params = nnx.state(dst_model, nnx.Param)
    flax_model_params_fstate = dict(flax_model_params.flat_state())
    # Destination path -> source path, covering every parameter.
    params_name_mapping = {
        ('cls_token',): ('vit', 'embeddings', 'cls_token'),
        ('position_embeddings',): ('vit', 'embeddings', 'position_embeddings'),
        **{
            ('patch_embeddings', x): ('vit', 'embeddings', 'patch_embeddings', 'projection', x)
            for x in ['kernel', 'bias']
        },
        # q/k/v projections for each of the 12 encoder layers.
        **{
            ('encoder', 'layers', i, 'attn', y, x): (
                'vit', 'encoder', 'layer', str(i), 'attention', 'attention', y, x
            )
            for x in ['kernel', 'bias']
            for y in ['key', 'value', 'query']
            for i in range(12)
        },
        # Attention output projections.
        **{
            ('encoder', 'layers', i, 'attn', 'out', x): (
                'vit', 'encoder', 'layer', str(i), 'attention', 'output', 'dense', x
            )
            for x in ['kernel', 'bias']
            for i in range(12)
        },
        # MLP linears: Sequential index 0 -> HF 'intermediate', 3 -> 'output'.
        **{
            ('encoder', 'layers', i, 'mlp', 'layers', y1, x): (
                'vit', 'encoder', 'layer', str(i), y2, 'dense', x
            )
            for x in ['kernel', 'bias']
            for y1, y2 in [(0, 'intermediate'), (3, 'output')]
            for i in range(12)
        },
        # Pre- and post-attention layer norms.
        **{
            ('encoder', 'layers', i, y1, x): (
                'vit', 'encoder', 'layer', str(i), y2, x
            )
            for x in ['scale', 'bias']
            for y1, y2 in [('norm1', 'layernorm_before'), ('norm2', 'layernorm_after')]
            for i in range(12)
        },
        **{
            ('final_norm', x): ('vit', 'layernorm', x)
            for x in ['scale', 'bias']
        },
        **{
            ('classifier', x): ('classifier', x)
            for x in ['kernel', 'bias']
        }
    }
    # Track destination parameters to guarantee full coverage at the end.
    nonvisited = set(flax_model_params_fstate.keys())
    for key1, key2 in params_name_mapping.items():
        assert key1 in flax_model_params_fstate, key1
        assert key2 in tf_model_params_fstate, (key1, key2)
        nonvisited.remove(key1)
        src_value = tf_model_params_fstate[key2]
        # HF fuses heads: q/k/v kernels are (hidden, hidden); nnx expects
        # (hidden, num_heads=12, head_dim=64). Same idea for the biases.
        if key2[-1] == 'kernel' and key2[-2] in ('key', 'value', 'query'):
            shape = src_value.shape
            src_value = src_value.reshape((shape[0], 12, 64))
        if key2[-1] == 'bias' and key2[-2] in ('key', 'value', 'query'):
            src_value = src_value.reshape((12, 64))
        # Attention output kernel: (hidden, hidden) -> (heads, head_dim, hidden).
        if key2[-4:] == ('attention', 'output', 'dense', 'kernel'):
            shape = src_value.shape
            src_value = src_value.reshape((12, 64, shape[-1]))
        dst_value = flax_model_params_fstate[key1]
        assert src_value.shape == dst_value.value.shape, (key2, src_value.shape, key1, dst_value.value.shape)
        dst_value.value = src_value.copy()
        # Cheap sanity check that the copy took effect.
        assert dst_value.value.mean() == src_value.mean(), (dst_value.value, src_value.mean())
    assert len(nonvisited) == 0, nonvisited
    # Notice the use of `flax.nnx.update` and `flax.nnx.State`.
    nnx.update(dst_model, nnx.State.from_flat_path(flax_model_params_fstate))
@nnx.jit
def train_step(
    model: nnx.Module, optimizer: nnx.Optimizer, imgs: np.ndarray, species_id: np.ndarray
):
    """One optimization step: forward pass, backward pass, in-place update."""
    # Move the host batch onto the accelerator.
    batch_imgs = jnp.array(imgs)
    batch_labels = jnp.array(species_id, dtype=jnp.int32)
    (loss, _logits), grads = nnx.value_and_grad(
        compute_losses_and_logits, has_aux=True
    )(model, batch_imgs, batch_labels)
    optimizer.update(grads)  # Mutates model parameters in place.
    return loss
@nnx.jit
def eval_step(
    model: nnx.Module, eval_metrics: nnx.MultiMetric, imgs: np.ndarray, species_id: np.ndarray
):
    """Accumulates loss and accuracy for one validation batch into `eval_metrics`."""
    # Move the host batch onto the accelerator.
    batch_imgs = jnp.array(imgs)
    batch_labels = jnp.array(species_id, dtype=jnp.int32)
    loss, logits = compute_losses_and_logits(model, batch_imgs, batch_labels)
    eval_metrics.update(loss=loss, logits=logits, labels=batch_labels)
def main():
    """Fine-tunes the ViT on NABirds: data pipelines, training, eval, checkpoints."""
    # ---- Data sources -------------------------------------------------
    base_dir = 'nabirds'
    cleaned_img_dir = os.path.join(base_dir, 'cleaned_images')
    metadata = pl.read_parquet('metadata.parquet')
    # 'is_training_img' encodes the dataset's official train/val split.
    metadata_train = metadata.filter(pl.col('is_training_img') == 1)
    metadata_val = metadata.filter(pl.col('is_training_img') == 0)
    nabirds_train = NABirdsDataset(metadata_train, cleaned_img_dir)
    nabirds_val = NABirdsDataset(metadata_val, cleaned_img_dir)
    key = random.key(31)  # NOTE(review): appears unused below — candidate for removal.
    seed = 123
    train_batch_size = 8
    val_batch_size = 2 * train_batch_size
    # ---- Grain samplers and loaders ----------------------------------
    train_sampler = grain.IndexSampler(
        num_records=len(nabirds_train),
        shuffle=True,
        seed=seed,
        shard_options=grain.NoSharding(),  # single-process loading
        num_epochs=1
    )
    train_loader = grain.DataLoader(
        data_source=nabirds_train,
        sampler=train_sampler,
        # Augmentations run per record, then records are batched.
        operations=[
            ToFloat(),
            RandomCrop(),
            RandomFlip(),
            RandomContrast(),
            RandomGamma(),
            ZScore(),
            grain.Batch(train_batch_size, drop_remainder=True)
        ]
    )
    val_sampler = grain.IndexSampler(
        num_records=len(nabirds_val),
        shuffle=False,
        seed=seed,
        shard_options=grain.NoSharding(),
        num_epochs=1
    )
    val_loader = grain.DataLoader(
        data_source=nabirds_val,
        sampler=val_sampler,
        # Validation uses only deterministic normalization, no augmentation.
        operations=[
            Normalize(),
            grain.Batch(val_batch_size)
        ]
    )
    # ---- Model: ImageNet-pretrained ViT, reheaded for 405 bird classes
    model = VisionTransformer(num_classes=1000)
    tf_model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
    vit_inplace_copy_weights(src_model=tf_model, dst_model=model)
    # Replace the 1000-class ImageNet head with a fresh 405-class head.
    model.classifier = nnx.Linear(model.classifier.in_features, 405, rngs=nnx.Rngs(0))
    # ---- Optimizer ----------------------------------------------------
    num_epochs = 3
    learning_rate = 0.001
    momentum = 0.9
    total_steps = len(nabirds_train) // train_batch_size
    # Decay the learning rate linearly to zero over the entire run.
    lr_schedule = optax.linear_schedule(learning_rate, 0.0, num_epochs * total_steps)
    optimizer = nnx.ModelAndOptimizer(model, optax.sgd(lr_schedule, momentum, nesterov=True))
    train_metrics_history = {
        'train_loss': [],
    }
    eval_metrics_history = {
        'val_loss': [],
        'val_accuracy': [],
    }
    eval_metrics = nnx.MultiMetric(
        loss=nnx.metrics.Average('loss'),
        accuracy=nnx.metrics.Accuracy(),
    )
    bar_format = '{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]'

    def train_one_epoch(epoch):
        # Runs one pass over the training loader, logging per-step loss.
        model.train()  # Set model to training mode: e.g. update batch statistics.
        with tqdm.tqdm(
            desc=f"[train] epoch: {epoch + 1}/{num_epochs}, ",
            total=total_steps,
            bar_format=bar_format,
            leave=True,
        ) as pbar:
            for batch in train_loader:
                loss = train_step(model, optimizer, batch['img'], batch['species_id'])
                train_metrics_history['train_loss'].append(loss.item())
                pbar.set_postfix({'loss': loss.item()})
                pbar.update(1)

    def evaluate_model(epoch):
        # Computes the metrics on the validation set after each training epoch.
        model.eval()  # Sets model to evaluation mode: e.g. use stored batch statistics.
        eval_metrics.reset()  # Reset the eval metrics.
        for val_batch in val_loader:
            eval_step(model, eval_metrics, val_batch['img'], val_batch['species_id'])
        for metric, value in eval_metrics.compute().items():
            eval_metrics_history[f'val_{metric}'].append(value)
        print(f"[val] epoch: {epoch + 1}/{num_epochs}")
        print(f"- total loss: {eval_metrics_history['val_loss'][-1]:0.4f}")
        print(f"- Accuracy: {eval_metrics_history['val_accuracy'][-1]:0.4f}")

    # ---- Checkpointing (Orbax) ---------------------------------------
    # NOTE(review): erases any existing checkpoints at this path on startup.
    path = ocp.test_utils.erase_and_create_empty('/project/def-sponsor00/nabirds/checkpoints/')
    options = ocp.CheckpointManagerOptions(max_to_keep=3)
    mngr = ocp.CheckpointManager(path, options=options)

    def save_model(epoch):
        # Get all params, statistics, RNGs, etc. from model:
        state = nnx.state(model)
        # Convert PRNG keys to the old (raw uint32 data) format:
        def get_key_data(x):
            if isinstance(x, jax._src.prng.PRNGKeyArray):
                if isinstance(x.dtype, jax._src.prng.KeyTy):
                    return jax.random.key_data(x)
            return x
        serializable_state = jax.tree.map(get_key_data, state)
        mngr.save(epoch, args=ocp.args.StandardSave(serializable_state))
        # Block the manager until all operations have finished running
        # (only useful for asynchronous (distributed) training).
        mngr.wait_until_finished()

    # ---- Training loop (`time` comes from `from time import time`) ----
    start = time()
    for epoch in range(num_epochs):
        train_one_epoch(epoch)
        evaluate_model(epoch)
        save_model(epoch)
    end = time()
    print(f"Training took {round((end - start) / 60, 1)} minutes")
if __name__ == '__main__':
    main()

Inference script
Now we need an inference script that we can use with any bird image. It needs to process the image to make it consistent with our model:
nabirds_infer.py
import os
import sys
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
from flax import nnx
import imageio.v3 as iio
import polars as pl
# Paths shared by the helpers below.
metadata_path = 'metadata.parquet'
checkpoint_path = '/project/def-sponsor00/nabirds/checkpoints/'


def load_species_mapping(metadata_path=metadata_path):
    """Loads the species-ID -> species-name mapping from the metadata file.

    Returns an empty dict (after printing a notice) when the metadata file
    does not exist.
    """
    if not os.path.exists(metadata_path):
        print(f"Metadata file not found at {metadata_path}")
        return {}
    # Creates specific id -> name mapping from the unique (id, name) pairs.
    pairs = pl.read_parquet(metadata_path).select(
        ['species_id', 'species_name']
    ).unique()
    return dict(pairs.iter_rows())
def preprocess_image(image_path):
    """Reads and preprocesses an image for the model.

    Returns a (1, 224, 224, 3) float32 array normalized the same way as the
    training pipeline (scale to [0, 1], then (x - 0.5) / 0.5).

    Raises:
        FileNotFoundError: if `image_path` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")
    img = iio.imread(image_path)
    # Ensure 3 channels (RGB).
    if img.ndim == 2:  # Grayscale (H, W): add a channel axis.
        img = img[..., None]
    if img.shape[-1] == 1:  # Grayscale (H, W, 1): replicate to RGB.
        # Fix: without this branch a 3-D single-channel image fell through,
        # and jax.image.resize below would *interpolate* the channel axis
        # from 1 to 3 instead of duplicating the gray channel.
        img = np.repeat(img, 3, axis=-1)
    elif img.shape[-1] == 4:  # RGBA: drop the alpha channel.
        img = img[..., :3]
    # Convert to jax array and normalize to [0, 1].
    img = jnp.array(img).astype(jnp.float32) / 255.0
    # Resize to the 224x224 input resolution the model was trained on.
    img = jax.image.resize(img, (224, 224, 3), method='bilinear')
    # Normalize with mean/std (matching training logic: ZScore transform).
    mean = jnp.array([0.5, 0.5, 0.5])
    std = jnp.array([0.5, 0.5, 0.5])
    img = (img - mean) / std
    # Add batch dimension.
    img = img[None, ...]
    return img
def predict(image_path):
    """Restores model from checkpoint and runs prediction on a single image."""
    # Restore model from checkpoint.
    # NOTE(review): the training script saved `nnx.state(model)` (a state
    # PyTree), so `mngr.restore(...)` presumably returns that raw state
    # rather than a callable module — calling `.eval()` / `model(img)` on it
    # may fail. Verify whether the VisionTransformer must be rebuilt and
    # populated via `nnx.update` before use — TODO confirm.
    options = ocp.CheckpointManagerOptions(max_to_keep=3)
    mngr = ocp.CheckpointManager(checkpoint_path, options=options)
    model = mngr.restore(mngr.latest_step())
    model.eval()
    mapping = load_species_mapping()
    print(f"Processing image: {image_path}")
    img = preprocess_image(image_path)
    # Inference: class logits -> probabilities.
    logits = model(img)
    probs = nnx.softmax(logits)
    # Top-1 prediction (index into the 405-way head; mapped to a name below).
    predicted_id = int(jnp.argmax(probs))
    confidence = float(jnp.max(probs))
    predicted_name = mapping.get(predicted_id, f"Unknown ID {predicted_id}")
    print("-" * 30)
    print(f"Prediction: {predicted_name}")
    print(f"Species ID: {predicted_id}")
    print(f"Confidence: {confidence:.2%}")
    print("-" * 30)
    # Top 5 predictions, highest probability first.
    top_k = 5
    top_indices = jnp.argsort(probs, descending=True)[0, :top_k]
    print(f"Top {top_k} predictions:")
    for idx in top_indices:
        idx = int(idx)
        score = float(probs[0, idx])
        name = mapping.get(idx, f"ID {idx}")
        print(f" {name}: {score:.2%}")
if __name__ == "__main__":
if len(sys.argv) > 1:
image_path = sys.argv[1]
predict(image_path)
else:
        print("Usage: uv run python nabirds_infer.py <path_to_bird_image>")

Usage
From the command line and using uv, the script can be run with:
uv run nabirds_infer.py <path_to_bird_image>