Defining model architecture
In this section, we define a model with Flax’s new API called NNX.
Context
Load packages
Packages and modules necessary for this section:
# to define the jax.Array type
import jax
# general JAX array manipulations
import jax.numpy as jnp
# to define the model architecture
from flax import nnx
# to get callables from functions with fewer arguments
from functools import partial
The Flax NNX API
Flax has gone through several APIs.
The initial nn API, now retired, was replaced in 2020 by the Linen API, which still ships with the Flax package. In 2024, the Flax team launched the NNX API.
Each iteration has moved further from JAX and closer to Python, with a syntax increasingly similar to PyTorch.
While the Linen API still exists, new users are advised to learn the new NNX API.
Stateful models
The old Linen API is a stateless model framework similar to the Julia package Lux.jl. It follows a strict functional programming approach in which the parameters are separate from the model and are passed as inputs to the forward pass along with the data. This stays much closer to JAX's functional style and lends itself well to optimization, but it is restrictive and has proven unpopular in the deep learning community and among Python users.
By contrast, the new NNX API is a stateful model framework similar to PyTorch and the older Julia package Flux.jl: model parameters and optimizer state are stored within the model instance. Flax handles a lot of JAX’s constraints under the hood, making the code more familiar to Python/PyTorch users, simpler, and more forgiving.
The dynamic state handled by NNX is stored in nnx.Param variables, while the static state (all types not handled by NNX) is stored directly as Python object attributes. This follows the classic Python object-oriented paradigm.
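As a minimal sketch of this distinction, here is a made-up Scaler module (not part of Flax) with one dynamic parameter and one static attribute:
class Scaler(nnx.Module):
    def __init__(self, dim: int):
        self.scale = nnx.Param(jnp.ones((dim,)))  # dynamic state, tracked by NNX
        self.dim = dim                            # static state, a plain Python attribute

scaler = Scaler(3)
scaler.scale.value += 1.0  # state can be read and mutated in place, PyTorch-style
print(scaler.scale.value)  # [2. 2. 2.]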
No shape inference
All model dimensions need to be explicitly stated.
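For example, a linear layer must be given both its input and output dimensions at construction time; nothing is inferred from the first batch of data:
layer = nnx.Linear(in_features=784, out_features=300, rngs=nnx.Rngs(0))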
Handling of PRNG
We saw that JAX has a complex way of handling pseudo-random number generation. While the Linen API required the user to handle PRNG explicitly in JAX, the new NNX API defines the random state as an object state, stored in a variable and carried by the model.
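As a minimal sketch, an nnx.Rngs object holds named key streams and returns a fresh key each time a stream is called:
rngs = nnx.Rngs(params=0, dropout=1)  # one named stream per purpose, each with its own seed
key1 = rngs.params()  # a fresh key from the "params" stream
key2 = rngs.params()  # another fresh key; no manual splitting required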
What this looks like
Define the model architecture:
class Linear(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        key = rngs.params()
        self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
        self.b = nnx.Param(jnp.zeros((dout,)))
        self.din, self.dout = din, dout

    def __call__(self, x: jax.Array):
        return x @ self.w + self.b
Instantiate the model:
model = Linear(2, 5, rngs=nnx.Rngs(params=0))
Display the model structure:
nnx.display(model)
If you have the penzai package installed, you will see an interactive display of the model.
Run the model on dummy data:
y = model(x=jnp.ones((1, 2)))
print("Predictions shape: ", y.shape)
Predictions shape: (1, 5)
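Because the model is stateful, its parameters can be inspected (and mutated) directly as attributes of the instance:
print(model.w.value.shape)  # (2, 5)
print(model.b.value.shape)  # (5,)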
Example MLP with Flax NNX
Multilayer perceptrons (MLPs) are fully-connected feed-forward neural networks.
Here is an example of an MLP with a single hidden layer for the MNIST dataset by LeCun et al. [1]:
And here is the implementation in Flax NNX:
class MLP(nnx.Module):
    def __init__(
        self,
        n_features: int = 784,  # 28x28 pixel images with 1 channel
        n_hidden: int = 300,
        n_targets: int = 10,    # 10 digits
        *,
        rngs: nnx.Rngs
    ):
        self.n_features = n_features
        self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)
        self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs)
        self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)

    def __call__(self, x):
        x = x.reshape(x.shape[0], self.n_features)  # flatten
        x = nnx.selu(self.layer1(x))
        x = nnx.selu(self.layer2(x))
        x = self.layer3(x)
        return x

# instantiate the model
model = MLP(rngs=nnx.Rngs(0))

# visualize it
nnx.display(model)
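As a quick sanity check with a dummy batch of MNIST-shaped inputs, the forward pass flattens each image to 784 features and returns one logit per digit:
x = jnp.ones((32, 28, 28, 1))  # dummy batch of 32 images
y = model(x)
print("Predictions shape: ", y.shape)  # (32, 10)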
NNX API references:
- flax.nnx.Linear layer class
- flax.nnx.selu SELU activation function
Example CNN with Flax NNX
Convolutional neural networks (CNNs) take advantage of the spatial correlations that exist in images and greatly reduce the number of neurons needed in vision networks.
The LeNet-5 model [2], initially used on the MNIST dataset by LeCun et al. [1], is an early and simple CNN. The architecture of this model is explained in detail in this Kaggle post and here is a schematic:
You can find the Keras code here and the PyTorch code (slightly modified) here for comparison.
class LeNet(nnx.Module):
    """An adapted LeNet-5 model."""

    def __init__(self, *, rngs: nnx.Rngs):
        # padding="VALID" (nnx.Conv defaults to "SAME") so that 28x28 inputs
        # yield 4x4 feature maps before the flatten below
        self.conv1 = nnx.Conv(1, 6, kernel_size=(5, 5), padding="VALID", rngs=rngs)
        self.max_pool = partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2))
        self.conv2 = nnx.Conv(6, 16, kernel_size=(5, 5), padding="VALID", rngs=rngs)
        self.linear1 = nnx.Linear(16 * 4 * 4, 120, rngs=rngs)
        self.linear2 = nnx.Linear(120, 84, rngs=rngs)
        self.linear3 = nnx.Linear(84, 10, rngs=rngs)

    def __call__(self, x):
        x = self.max_pool(nnx.relu(self.conv1(x)))
        x = self.max_pool(nnx.relu(self.conv2(x)))
        x = x.reshape(x.shape[0], -1)  # flatten
        x = nnx.relu(self.linear1(x))
        x = nnx.relu(self.linear2(x))
        x = self.linear3(x)
        return x
# instantiate the model
model = LeNet(rngs=nnx.Rngs(0))

# visualize it
nnx.display(model)
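As a sanity check on a dummy MNIST-shaped batch: with VALID padding, the feature maps shrink from 28x28 to 24x24, 12x12, 8x8, and finally 4x4, which matches the 16 * 4 * 4 input size of linear1:
x = jnp.ones((8, 28, 28, 1))  # dummy batch of 8 images
y = model(x)
print("Predictions shape: ", y.shape)  # (8, 10)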
NNX API references:
- flax.nnx.Conv convolution module
- flax.nnx.Linear layer class
- flax.nnx.max_pool pooling function (missing from the API documentation as of April 2025)
- flax.nnx.relu activation function
ViT with Flax NNX
LeNet (which went through various iterations until 1998) was followed by AlexNet in 2012 and many increasingly complex CNNs, until multi-head attention and transformers changed everything.
Transformers are a complex neural network architecture introduced by Google in 2017 in the seminal paper “Attention Is All You Need” [3], cited a staggering 175,083 times as of April 2025. They were initially used only in natural language processing (NLP), but have since been applied to vision.
To classify our food dataset, we will use the vision transformer (ViT) introduced by Dosovitskiy et al. [4], which we will fine-tune in a later section.
Here is a schematic of the model:
And here is the JAX implementation by Google Research:
class VisionTransformer(nnx.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        in_channels: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_dim: int = 3072,
        hidden_size: int = 768,
        dropout_rate: float = 0.1,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        # Patch and position embedding
        n_patches = (img_size // patch_size) ** 2
        self.patch_embeddings = nnx.Conv(
            in_channels,
            hidden_size,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding="VALID",
            use_bias=True,
            rngs=rngs,
        )

        initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
        self.position_embeddings = nnx.Param(
            initializer(
                rngs.params(),
                (1, n_patches + 1, hidden_size),
                jnp.float32,
            )
        )
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)

        self.cls_token = nnx.Param(jnp.zeros((1, 1, hidden_size)))

        # Transformer encoder blocks
        self.encoder = nnx.Sequential(*[
            TransformerEncoder(
                hidden_size,
                mlp_dim,
                num_heads,
                dropout_rate,
                rngs=rngs,
            )
            for _ in range(num_layers)
        ])
        self.final_norm = nnx.LayerNorm(hidden_size, rngs=rngs)

        # Classification head
        self.classifier = nnx.Linear(hidden_size, num_classes, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        # Patch and position embedding
        patches = self.patch_embeddings(x)
        batch_size = patches.shape[0]
        patches = patches.reshape(batch_size, -1, patches.shape[-1])

        cls_token = jnp.tile(self.cls_token, [batch_size, 1, 1])
        x = jnp.concat([cls_token, patches], axis=1)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)

        # Encoder blocks
        x = self.encoder(embeddings)
        x = self.final_norm(x)

        # fetch the first (class) token
        x = x[:, 0]

        # Classification
        return self.classifier(x)


class TransformerEncoder(nnx.Module):
    def __init__(
        self,
        hidden_size: int,
        mlp_dim: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ) -> None:
        self.norm1 = nnx.LayerNorm(hidden_size, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=hidden_size,
            dropout_rate=dropout_rate,
            broadcast_dropout=False,
            decode=False,
            deterministic=False,
            rngs=rngs,
        )
        self.norm2 = nnx.LayerNorm(hidden_size, rngs=rngs)

        self.mlp = nnx.Sequential(
            nnx.Linear(hidden_size, mlp_dim, rngs=rngs),
            nnx.gelu,
            nnx.Dropout(dropout_rate, rngs=rngs),
            nnx.Linear(mlp_dim, hidden_size, rngs=rngs),
            nnx.Dropout(dropout_rate, rngs=rngs),
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
x = jnp.ones((4, 224, 224, 3))
model = VisionTransformer(num_classes=1000)
y = model(x)
print("Predictions shape: ", y.shape)
Predictions shape: (4, 1000)
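To see where the sequence length inside the encoder comes from with the default hyperparameters, the patch arithmetic can be checked directly:
# 224 // 16 = 14 patches per side, so 14 * 14 = 196 patches;
# with the class token prepended, the encoder sees 197 tokens of size 768
patches = model.patch_embeddings(x)
print(patches.shape)  # (4, 14, 14, 768)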
NNX API references:
- flax.nnx.Conv convolution module
- flax.nnx.Dropout dropout class
- flax.nnx.LayerNorm layer normalization class
- flax.nnx.Linear linear layer class
- flax.nnx.MultiHeadAttention multi-head attention class
- flax.nnx.Param parameter class
- flax.nnx.Sequential helper class
- flax.nnx.gelu GELU activation function