Defining model architecture

Author

Marie-Hélène Burle

In this section, we define a model with Flax’s new API called NNX.

Context

[Diagram: workflow steps and the packages used at each step. Load data (torchdata, tfds, datasets) -> Process data (grain, torchvision) -> Define architecture (flax) -> Hyperparameters (optax) -> Checkpoint (orbax). Pre-trained model (transformers) -> Define architecture.]

from datasets import load_dataset
import numpy as np
from torchvision.transforms import v2 as T
import grain.python as grain

train_size = 5 * 750
val_size = 5 * 250

train_dataset = load_dataset("food101",
                             split=f"train[:{train_size}]")

val_dataset = load_dataset("food101",
                           split=f"validation[:{val_size}]")

labels_mapping = {}
index = 0
for i in range(0, len(val_dataset), 250):
    label = val_dataset[i]["label"]
    if label not in labels_mapping:
        labels_mapping[label] = index
        index += 1

inv_labels_mapping = {v: k for k, v in labels_mapping.items()}

img_size = 224

def to_np_array(pil_image):
  return np.asarray(pil_image.convert("RGB"))

def normalize(image):
    # Image preprocessing matches the one of pretrained ViT
    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    image = image.astype(np.float32) / 255.0
    return (image - mean) / std

tv_train_transforms = T.Compose([
    T.RandomResizedCrop((img_size, img_size), scale=(0.7, 1.0)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.2, 0.2, 0.2),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])

tv_test_transforms = T.Compose([
    T.Resize((img_size, img_size)),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])

def get_transform(fn):
    def wrapper(batch):
        batch["image"] = [
            fn(pil_image) for pil_image in batch["image"]
        ]
        # map each original label to an index in the range 0-4 (five classes)
        batch["label"] = [
            labels_mapping[label] for label in batch["label"]
        ]
        return batch
    return wrapper

train_transforms = get_transform(tv_train_transforms)
val_transforms = get_transform(tv_test_transforms)

train_dataset = train_dataset.with_transform(train_transforms)
val_dataset = val_dataset.with_transform(val_transforms)

seed = 12
train_batch_size = 32
val_batch_size = 2 * train_batch_size

train_sampler = grain.IndexSampler(
    len(train_dataset),
    shuffle=True,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)

val_sampler = grain.IndexSampler(
    len(val_dataset),
    shuffle=False,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)

train_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,
    worker_count=4,
    worker_buffer_size=2,
    operations=[
        grain.Batch(train_batch_size, drop_remainder=True),
    ]
)

val_loader = grain.DataLoader(
    data_source=val_dataset,
    sampler=val_sampler,
    worker_count=4,
    worker_buffer_size=2,
    operations=[
        grain.Batch(val_batch_size),
    ]
)

Load packages

Packages and modules needed for this section:

# to define the model architecture
from flax import nnx

# to get callables from functions with fewer arguments
from functools import partial
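As a quick illustration of what `partial` does, here is a minimal sketch using only the standard library. The `pool` function is a hypothetical toy stand-in for a pooling function, not a Flax function. The model defined later in this section applies the same pattern to fix the window shape and strides of `nnx.max_pool`.

```python
from functools import partial

def pool(values, window, stride):
    """Toy 1D max pooling (hypothetical helper, for illustration only)."""
    return [
        max(values[i:i + window])
        for i in range(0, len(values) - window + 1, stride)
    ]

# Pre-fill the keyword arguments once...
pool_2 = partial(pool, window=2, stride=2)

# ...then call the resulting callable with the data only.
data = [3, 1, 4, 1, 5, 9]
result = pool_2(data)

# Equivalent call without partial.
direct = pool(data, window=2, stride=2)

print(result)  # [3, 4, 9]
```

The callable returned by `partial` behaves like the original function with some arguments already supplied, which keeps the forward pass of a model short.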

Flax API

Flax has gone through several APIs.

The initial nn API (now retired) was replaced in 2020 by the Linen API, which is still available in the Flax package. In 2024, the Flax team launched the NNX API.

Each iteration has moved further from JAX and closer to Python, with a syntax increasingly similar to PyTorch.

The old Linen API is a stateless model framework similar to the Julia package Lux.jl. It follows a strict functional programming approach in which the parameters are separate from the model and are passed as inputs to the forward pass along with the data. This approach is much closer to the JAX sublanguage and more optimized, but it is restrictive and unpopular in the deep learning community and among Python users.

By contrast, the new NNX API is a stateful model framework similar to PyTorch and the older Julia package Flux.jl: model parameters and optimizer state are stored within the model instance. Flax handles a lot of JAX’s constraints under the hood, making the code more familiar to Python/PyTorch users, simpler, and more forgiving.

While the Linen API still exists, new users are advised to learn the new NNX API.
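The difference between the stateless and stateful styles described above can be sketched without any framework. The following is plain Python, not Flax code, and all names in it are hypothetical. It only shows where the parameters live in each style.

```python
# Framework-free illustration: none of these names come from Flax.

def dot(weights, x):
    """Weighted sum of the inputs."""
    return sum(w * xi for w, xi in zip(weights, x))

# Stateless (functional) style: the parameters live outside the model
# and are passed to the forward pass together with the data.
def stateless_forward(params, x):
    return dot(params["w"], x) + params["b"]

# Stateful (object-oriented) style: the parameters are attributes of
# the model instance, so the forward pass only needs the data.
class StatefulNeuron:
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        return dot(self.w, x) + self.b

x = [1.0, 2.0]

params = {"w": [0.5, -1.0], "b": 0.25}
y_stateless = stateless_forward(params, x)

neuron = StatefulNeuron(w=[0.5, -1.0], b=0.25)
y_stateful = neuron(x)

print(y_stateless, y_stateful)  # -1.25 -1.25
```

Both styles compute the same result. What changes is who owns the parameters: the caller in the stateless style, the model instance in the stateful one.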

Simple CNN

We will use the LeNet-5 [1] model, initially used on the MNIST dataset by LeCun et al. [2]. We modify it to take three-channel images (RGB colour images) instead of single-channel images (greyscale, as in MNIST) and to output five categories.

The architecture of this model is explained in detail in this Kaggle post.

class CNN(nnx.Module):
  """An adapted LeNet-5 model."""

  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(3, 6, kernel_size=(5, 5), rngs=rngs)
    self.max_pool = partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2))
    self.conv2 = nnx.Conv(6, 16, kernel_size=(5, 5), rngs=rngs)
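    # For 224x224 inputs, the two 'SAME' convolutions and the two 2x2 max pools
    # leave 56 * 56 * 16 = 50176 features after flattening.
    self.linear1 = nnx.Linear(50176, 120, rngs=rngs)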
    self.linear2 = nnx.Linear(120, 84, rngs=rngs)
    self.linear3 = nnx.Linear(84, 5, rngs=rngs)

  def __call__(self, x):
    x = self.max_pool(nnx.relu(self.conv1(x)))
    x = self.max_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = nnx.relu(self.linear2(x))
    x = self.linear3(x)
    return x

# Instantiate the model.
model = CNN(rngs=nnx.Rngs(0))

# Visualize it.
nnx.display(model)
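The number of input features of the first linear layer has to match the number of values left after flattening the output of the second pooling step. Below is a minimal sketch of that arithmetic in plain Python. It assumes that `nnx.Conv` uses 'SAME' padding and `nnx.max_pool` uses 'VALID' padding, which are the defaults to the best of our knowledge (check the documentation of your Flax version). The helper names are hypothetical.

```python
import math

def conv_same(size, stride=1):
    """Spatial output size of a convolution with 'SAME' padding."""
    return math.ceil(size / stride)

def pool_valid(size, window=2, stride=2):
    """Spatial output size of a pooling step with 'VALID' (no) padding."""
    return (size - window) // stride + 1

def flattened_size(img_size, out_channels=16):
    """Features left after conv -> pool -> conv -> pool -> flatten."""
    size = pool_valid(conv_same(img_size))  # first convolution + max pool
    size = pool_valid(conv_same(size))      # second convolution + max pool
    return size * size * out_channels

n_features = flattened_size(224)
print(n_features)  # 56 * 56 * 16 = 50176
```

Under these assumptions, 224 x 224 images give 50176 features, so the first argument of the first `nnx.Linear` layer has to be 50176 for the forward pass to work on images of that size.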