Defining model architecture
In this section, we define a model with Flax’s new API called NNX.
Context
Load packages
The package and module needed for this section:
# to define the model architecture
from flax import nnx
# to get callables from functions with fewer arguments
from functools import partial
Flax API
Flax has gone through several APIs.
The initial nn API, now retired, was replaced in 2020 by the Linen API, which is still available in the Flax package. In 2024, the NNX API was launched.
Each iteration has moved further from JAX and closer to Python, with a syntax increasingly similar to PyTorch.
The older Linen API is a stateless model framework similar to the Julia package Lux.jl. It follows a strict functional programming approach in which the parameters are kept separate from the model and are passed as inputs to the forward pass along with the data. This style stays much closer to JAX itself and lends itself well to optimization, but it is restrictive and has proved unpopular in the deep learning community and among Python users.
By contrast, the new NNX API is a stateful model framework similar to PyTorch and the older Julia package Flux.jl: model parameters and optimizer state are stored within the model instance. Flax handles a lot of JAX’s constraints under the hood, making the code more familiar to Python/PyTorch users, simpler, and more forgiving.
While the Linen API still exists, new users are advised to learn the new NNX API.
Simple CNN
We will use the LeNet-5 model [1], initially applied to the MNIST dataset by LeCun et al. [2]. We modify it to take three-channel images (RGB, for colour images) instead of a single channel (the black-and-white images of MNIST) and to output five categories.
The architecture of this model is explained in detail in this Kaggle post.
class CNN(nnx.Module):
    """An adapted LeNet-5 model."""

    def __init__(self, *, rngs: nnx.Rngs):
        self.conv1 = nnx.Conv(3, 6, kernel_size=(5, 5), rngs=rngs)
        self.max_pool = partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2))
        self.conv2 = nnx.Conv(6, 16, kernel_size=(5, 5), rngs=rngs)
        self.linear1 = nnx.Linear(3136, 120, rngs=rngs)
        self.linear2 = nnx.Linear(120, 84, rngs=rngs)
        self.linear3 = nnx.Linear(84, 5, rngs=rngs)

    def __call__(self, x):
        x = self.max_pool(nnx.relu(self.conv1(x)))
        x = self.max_pool(nnx.relu(self.conv2(x)))
        x = x.reshape(x.shape[0], -1)  # flatten
        x = nnx.relu(self.linear1(x))
        x = nnx.relu(self.linear2(x))
        x = self.linear3(x)
        return x
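The value 3136 passed to `linear1` is the size of the flattened feature map. Assuming 56×56 input images (an assumption; the input size is not stated here), and since `nnx.Conv` defaults to 'SAME' padding, only the two pooling layers shrink the spatial dimensions:

```python
# Spatial size after each stage, assuming 56x56 inputs and
# 'SAME'-padded convolutions (the nnx.Conv default):
size = 56
size = size // 2  # after conv1 + max_pool: 28
size = size // 2  # after conv2 + max_pool: 14

channels = 16  # output channels of conv2
flat_features = size * size * channels
print(flat_features)  # 3136, the input size of linear1
```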
# Instantiate the model.
model = CNN(rngs=nnx.Rngs(0))
# Visualize it.
nnx.display(model)