Defining model architecture

Author

Marie-Hélène Burle

In this section, we define a model with Flax’s new API called NNX.

Context

[Workflow diagram: load data (torchdata, tfds, datasets) → process data (torchvision, grain) → define architecture (flax, plus pre-trained models via transformers) → optimize (optax) → checkpoint (orbax)]

from datasets import load_dataset
import numpy as np
from torchvision.transforms import v2 as T
import grain.python as grain

train_size = 5 * 750
val_size = 5 * 250

train_dataset = load_dataset("food101",
                             split=f"train[:{train_size}]")

val_dataset = load_dataset("food101",
                           split=f"validation[:{val_size}]")

labels_mapping = {}
index = 0
for i in range(0, len(val_dataset), 250):
    label = val_dataset[i]["label"]
    if label not in labels_mapping:
        labels_mapping[label] = index
        index += 1

inv_labels_mapping = {v: k for k, v in labels_mapping.items()}

img_size = 224

def to_np_array(pil_image):
  return np.asarray(pil_image.convert("RGB"))

def normalize(image):
    # image preprocessing matches that of the pretrained ViT
    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    image = image.astype(np.float32) / 255.0
    return (image - mean) / std

tv_train_transforms = T.Compose([
    T.RandomResizedCrop((img_size, img_size), scale=(0.7, 1.0)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.2, 0.2, 0.2),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])

tv_test_transforms = T.Compose([
    T.Resize((img_size, img_size)),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])

def get_transform(fn):
    def wrapper(batch):
        batch["image"] = [
            fn(pil_image) for pil_image in batch["image"]
        ]
        # remap dataset labels to contiguous indices (0 to 4 here)
        batch["label"] = [
            labels_mapping[label] for label in batch["label"]
        ]
        return batch
    return wrapper

train_transforms = get_transform(tv_train_transforms)
val_transforms = get_transform(tv_test_transforms)

train_dataset = train_dataset.with_transform(train_transforms)
val_dataset = val_dataset.with_transform(val_transforms)

seed = 12
train_batch_size = 32
val_batch_size = 2 * train_batch_size

train_sampler = grain.IndexSampler(
    len(train_dataset),
    shuffle=True,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)

val_sampler = grain.IndexSampler(
    len(val_dataset),
    shuffle=False,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)

train_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,
    worker_count=4,
    worker_buffer_size=2,
    operations=[
        grain.Batch(train_batch_size, drop_remainder=True),
    ]
)

val_loader = grain.DataLoader(
    data_source=val_dataset,
    sampler=val_sampler,
    worker_count=4,
    worker_buffer_size=2,
    operations=[
        grain.Batch(val_batch_size),
    ]
)
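
To check that the pipeline works, we can pull a single batch from the training loader and inspect its shape (a quick sanity check; the shapes in the comments assume the img_size and train_batch_size values set above):

# grab one batch from the training loader and look at its shape
batch = next(iter(train_loader))
print(batch["image"].shape)   # (32, 224, 224, 3)
print(batch["label"].shape)   # (32,)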

Load packages

Packages and modules necessary for this section:

# to use the jax.Array type, PRNG keys, and initializers
import jax

# general JAX array manipulations
import jax.numpy as jnp

# to define the model architecture
from flax import nnx

# to create callables with some arguments pre-filled
from functools import partial

The Flax NNX API

Flax went through several APIs.

The initial nn API (now retired) was replaced in 2020 by the Linen API, which is still available in the Flax package. In 2024, the Flax team launched the NNX API.

Each iteration has moved further from JAX and closer to Python, with a syntax increasingly similar to PyTorch.

While the Linen API still exists, new users are advised to learn the new NNX API.

Stateful models

The old Linen API is a stateless model framework similar to the Julia package Lux.jl: it follows a strict functional programming approach in which the parameters are separate from the model and are passed to the forward pass along with the data. This stays much closer to the JAX sublanguage and is easier to optimize, but it is restrictive and has proven unpopular in the deep learning community and among Python users.
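
For contrast, here is a minimal sketch of the Linen pattern (using flax.linen and a made-up one-layer module): the parameters live outside the model instance and must be passed explicitly to every forward call:

import flax.linen as linen
import jax

class Dense1(linen.Module):
    @linen.compact
    def __call__(self, x):
        return linen.Dense(features=5)(x)

linen_model = Dense1()
# parameters are created outside the model instance...
params = linen_model.init(jax.random.key(0), jnp.ones((1, 2)))
# ...and passed along with the data at every call
y = linen_model.apply(params, jnp.ones((1, 2)))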

By contrast, the new NNX API is a stateful model framework similar to PyTorch and the older Julia package Flux.jl: model parameters and optimizer state are stored within the model instance. Flax handles a lot of JAX’s constraints under the hood, making the code more familiar to Python/PyTorch users, simpler, and more forgiving.

The dynamic state handled by NNX is stored in nnx.Param variables, while the static state (all types not handled by NNX) is stored directly as Python object attributes. This follows the classic Python object-oriented paradigm.
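
As a minimal sketch (with a made-up Scaler module), anything wrapped in nnx.Param is dynamic state, while plain Python attributes are static, and both live on the instance:

class Scaler(nnx.Module):
    def __init__(self):
        self.scale = nnx.Param(jnp.ones((3,)))  # dynamic state (a parameter)
        self.name = "scaler"                    # static Python attribute

scaler = Scaler()
scaler.scale.value += 1.0  # state can be mutated in place, PyTorch-style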

No shape inference

All model dimensions need to be explicitly stated.
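
For instance, nnx.Linear requires both the input and output feature sizes at construction time; there is no lazy initialization from the first batch:

# both in_features and out_features must be given explicitly
layer = nnx.Linear(in_features=784, out_features=300, rngs=nnx.Rngs(0))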

Handling of PRNG

We saw that JAX handles pseudo-random number generation explicitly, which can feel complex. While the Linen API required the user to manage PRNG keys explicitly in JAX, the NNX API defines the random state as an object state, stored in a variable and carried by the model.
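
As a small sketch, an nnx.Rngs object holds named key streams (names such as params and dropout are conventional) and hands out fresh keys on demand:

rngs = nnx.Rngs(params=0, dropout=1)  # one seed per named stream
key = rngs.params()                   # a fresh key from the "params" stream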

What this looks like

Define the model architecture:

class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.din, self.dout = din, dout

  def __call__(self, x: jax.Array):
    return x @ self.w + self.b

Instantiate the model:

model = Linear(2, 5, rngs=nnx.Rngs(params=0))

Display the model structure:

nnx.display(model)

If you have the penzai package installed, you will see an interactive display of the model.

Run the model on some dummy data:

y = model(x=jnp.ones((1, 2)))
print("Predictions shape: ", y.shape)
Predictions shape:  (1, 5)
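
Under the hood, NNX can still expose the purely functional view that JAX transformations need: nnx.split separates the model structure from its state and nnx.merge puts them back together:

graphdef, state = nnx.split(model)   # static structure vs dynamic state
model2 = nnx.merge(graphdef, state)  # rebuild an equivalent stateful module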

Example MLP with Flax NNX

Multilayer perceptrons (MLPs) are fully-connected feed-forward neural networks.

Here is an implementation in Flax NNX of an MLP for the MNIST dataset by LeCun et al. [1]:

class MLP(nnx.Module):

  def __init__(
          self,
          # 28x28 pixel images with 1 channel
          n_features: int = 784,
          n_hidden: int = 300,
          # 10 digits
          n_targets: int = 10,
          *,
          rngs: nnx.Rngs
  ):
    self.n_features = n_features
    self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)
    self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs)
    self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)

  def __call__(self, x):
    x = x.reshape(x.shape[0], self.n_features) # flatten
    x = nnx.selu(self.layer1(x))
    x = nnx.selu(self.layer2(x))
    x = self.layer3(x)
    return x

# instantiate the model
model = MLP(rngs=nnx.Rngs(0))

# visualize it
nnx.display(model)
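
As with the toy Linear module above, we can check the output shape on a dummy batch (a sketch with made-up MNIST-shaped inputs):

y = model(jnp.ones((8, 28, 28, 1)))  # dummy batch of 8 MNIST-sized images
print(y.shape)  # (8, 10)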

NNX API references: nnx.Module, nnx.Rngs, nnx.Linear, nnx.selu, nnx.display

Example CNN with Flax NNX

Convolutional neural networks (CNNs) take advantage of the spatial correlations that exist in images, which greatly reduces the number of neurons needed in vision networks.

The LeNet-5 model [2], initially used on the MNIST dataset by LeCun et al. [1], is an early and simple CNN. The architecture of this model is explained in detail in this Kaggle post and here is a schematic:

Image source: LeCun et al. [2]

You can find the Keras code here and the PyTorch code (slightly modified) here for comparison.

class LeNet(nnx.Module):
  """An adapted LeNet-5 model."""

  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 6, kernel_size=(5, 5), padding="VALID", rngs=rngs)
    self.max_pool = partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2))
    self.conv2 = nnx.Conv(6, 16, kernel_size=(5, 5), padding="VALID", rngs=rngs)
    # with VALID padding, 28x28 inputs give 4x4x16 feature maps after the
    # second pooling step, hence 16 * 4 * 4 input features here
    self.linear1 = nnx.Linear(16 * 4 * 4, 120, rngs=rngs)
    self.linear2 = nnx.Linear(120, 84, rngs=rngs)
    self.linear3 = nnx.Linear(84, 10, rngs=rngs)

  def __call__(self, x):
    x = self.max_pool(nnx.relu(self.conv1(x)))
    x = self.max_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = nnx.relu(self.linear2(x))
    x = self.linear3(x)
    return x

# instantiate the model
model = LeNet(rngs=nnx.Rngs(0))

# visualize it
nnx.display(model)
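
Same quick shape check for LeNet (a sketch; with the VALID-padded convolutions above, the model expects 28x28 single-channel inputs):

y = model(jnp.ones((8, 28, 28, 1)))  # dummy batch of MNIST-sized images
print(y.shape)  # (8, 10)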

NNX API references: nnx.Conv, nnx.max_pool, nnx.relu, nnx.Linear

ViT with Flax NNX

LeNet (various iterations until 1998) was followed by AlexNet in 2012 and many increasingly complex CNNs, until multi-head attention and transformers changed everything.

Transformers are a complex neural network architecture introduced by Google in 2017 in the seminal paper “Attention Is All You Need” [3], cited 175,083 times as of April 2025 (!!!). They were initially used only in natural language processing (NLP), but have since been applied to vision.

To classify our food dataset, we will use the vision transformer (ViT) introduced by Dosovitskiy et al. [4] (which we will fine-tune in a later section).

Here is a schematic of the model:

Image source: Dosovitskiy et al. [4]

And here is the JAX implementation by Google Research:

class VisionTransformer(nnx.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        in_channels: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_dim: int = 3072,
        hidden_size: int = 768,
        dropout_rate: float = 0.1,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        # Patch and position embedding
        n_patches = (img_size // patch_size) ** 2
        self.patch_embeddings = nnx.Conv(
            in_channels,
            hidden_size,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding="VALID",
            use_bias=True,
            rngs=rngs,
        )

        initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
        self.position_embeddings = nnx.Param(
            initializer(
                rngs.params(),
                (1, n_patches + 1, hidden_size),
                jnp.float32
            )
        )
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)

        self.cls_token = nnx.Param(jnp.zeros((1, 1, hidden_size)))

        # Transformer Encoder blocks
        self.encoder = nnx.Sequential(*[
            TransformerEncoder(
                hidden_size,
                mlp_dim,
                num_heads,
                dropout_rate,
                rngs=rngs
            )
            for i in range(num_layers)
        ])
        self.final_norm = nnx.LayerNorm(hidden_size, rngs=rngs)

        # Classification head
        self.classifier = nnx.Linear(hidden_size, num_classes, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        # Patch and position embedding
        patches = self.patch_embeddings(x)
        batch_size = patches.shape[0]
        patches = patches.reshape(batch_size, -1, patches.shape[-1])

        cls_token = jnp.tile(self.cls_token, [batch_size, 1, 1])
        x = jnp.concat([cls_token, patches], axis=1)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)

        # Encoder blocks
        x = self.encoder(embeddings)
        x = self.final_norm(x)

        # fetch the classification ([CLS]) token
        x = x[:, 0]

        # Classification
        return self.classifier(x)

class TransformerEncoder(nnx.Module):
    def __init__(
        self,
        hidden_size: int,
        mlp_dim: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ) -> None:

        self.norm1 = nnx.LayerNorm(hidden_size, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=hidden_size,
            dropout_rate=dropout_rate,
            broadcast_dropout=False,
            decode=False,
            deterministic=False,
            rngs=rngs,
        )
        self.norm2 = nnx.LayerNorm(hidden_size, rngs=rngs)

        self.mlp = nnx.Sequential(
            nnx.Linear(hidden_size, mlp_dim, rngs=rngs),
            nnx.gelu,
            nnx.Dropout(dropout_rate, rngs=rngs),
            nnx.Linear(mlp_dim, hidden_size, rngs=rngs),
            nnx.Dropout(dropout_rate, rngs=rngs),
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

Test the model on dummy data:

x = jnp.ones((4, 224, 224, 3))
model = VisionTransformer(num_classes=1000)
y = model(x)
print("Predictions shape: ", y.shape)
Predictions shape:  (4, 1000)

NNX API references: nnx.MultiHeadAttention, nnx.LayerNorm, nnx.Dropout, nnx.Sequential, nnx.gelu, nnx.Param

References

1. LeCun Y, Cortes C, Burges C (2010) MNIST handwritten digit database. ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist
2. LeCun Y, Bottou L, Bengio Y, Haffner P (1998) Gradient-based learning applied to document recognition. Proceedings of the IEEE 86(11):2278–2324
3. Vaswani A, Shazeer N, Parmar N, et al (2017) Attention is all you need. Advances in Neural Information Processing Systems 30
4. Dosovitskiy A, Beyer L, Kolesnikov A, et al (2021) An image is worth 16x16 words: Transformers for image recognition at scale