Defining model architecture

Author

Marie-Hélène Burle

In this section, we define a model with Flax’s new API called NNX.

Context

[Workflow diagram: Load data (torchdata, tfds, or datasets) → Process data (grain and torchvision) → Define architecture (flax, possibly starting from a pre-trained model obtained with transformers) → Optimize (optax) → Checkpoint (orbax).]

The code below reproduces the data loading and processing steps covered in the previous sections:

from datasets import load_dataset
import numpy as np
from torchvision.transforms import v2 as T
import grain.python as grain

train_size = 5 * 750
val_size = 5 * 250

train_dataset = load_dataset("food101",
                             split=f"train[:{train_size}]")

val_dataset = load_dataset("food101",
                           split=f"validation[:{val_size}]")

labels_mapping = {}
index = 0
for i in range(0, len(val_dataset), 250):
    label = val_dataset[i]["label"]
    if label not in labels_mapping:
        labels_mapping[label] = index
        index += 1

inv_labels_mapping = {v: k for k, v in labels_mapping.items()}

img_size = 224

def to_np_array(pil_image):
  return np.asarray(pil_image.convert("RGB"))

def normalize(image):
    # Image preprocessing matches that of the pre-trained ViT
    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    image = image.astype(np.float32) / 255.0
    return (image - mean) / std

tv_train_transforms = T.Compose([
    T.RandomResizedCrop((img_size, img_size), scale=(0.7, 1.0)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.2, 0.2, 0.2),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])

tv_test_transforms = T.Compose([
    T.Resize((img_size, img_size)),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])

def get_transform(fn):
    def wrapper(batch):
        batch["image"] = [
            fn(pil_image) for pil_image in batch["image"]
        ]
        # remap the original label indices to 0 - 4
        batch["label"] = [
            labels_mapping[label] for label in batch["label"]
        ]
        return batch
    return wrapper

train_transforms = get_transform(tv_train_transforms)
val_transforms = get_transform(tv_test_transforms)

train_dataset = train_dataset.with_transform(train_transforms)
val_dataset = val_dataset.with_transform(val_transforms)

seed = 12
train_batch_size = 32
val_batch_size = 2 * train_batch_size

train_sampler = grain.IndexSampler(
    len(train_dataset),
    shuffle=True,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)

val_sampler = grain.IndexSampler(
    len(val_dataset),
    shuffle=False,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)

train_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,
    worker_count=4,
    worker_buffer_size=2,
    operations=[
        grain.Batch(train_batch_size, drop_remainder=True),
    ]
)

val_loader = grain.DataLoader(
    data_source=val_dataset,
    sampler=val_sampler,
    worker_count=4,
    worker_buffer_size=2,
    operations=[
        grain.Batch(val_batch_size),
    ]
)

Load packages

Package and module necessary for this section:

# to define the model architecture
from flax import nnx

# to create callables with some of a function's arguments pre-set
from functools import partial
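
As a quick, standalone illustration of partial (a toy example, not used elsewhere in this course), it returns a new callable with some of a function's arguments already filled in:

from functools import partial

def power(base, exponent):
    return base ** exponent

square = partial(power, exponent=2)  # new callable with exponent fixed to 2
square(3)                            # returns 9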

Flax API

Flax went through several APIs.

The initial nn API, now retired, was replaced in 2020 by the Linen API, which is still available in the Flax package. In 2024, the Flax team launched the NNX API.

Each iteration has moved further from JAX and closer to Python, with a syntax increasingly similar to PyTorch.

The old Linen API is a stateless model framework similar to the Julia package Lux.jl. It follows a strict functional programming approach in which the parameters are kept separate from the model and are passed to the forward pass along with the data. This stays close to JAX's functional style and optimizes well, but it is restrictive and has remained unpopular in the deep learning community and among Python users.
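
As a minimal sketch of this stateless style (a toy one-layer model, not part of this course's code), parameters are created by init and must be passed explicitly to every apply call:

import jax
import jax.numpy as jnp
from flax import linen as nn

class TinyLinen(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=4)(x)

linen_model = TinyLinen()
x = jnp.ones((1, 8))
variables = linen_model.init(jax.random.key(0), x)  # parameters live outside the model
y = linen_model.apply(variables, x)                 # and are passed in at every call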

By contrast, the new NNX API is a stateful model framework similar to PyTorch and the older Julia package Flux.jl: model parameters and optimizer state are stored within the model instance. Flax handles a lot of JAX’s constraints under the hood, making the code more familiar to Python/PyTorch users, simpler, and more forgiving.
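
For comparison, the same toy model sketched with NNX (again, illustration only): parameters are created in __init__ and stored on the instance, so calling the model requires no separate parameter argument:

import jax.numpy as jnp
from flax import nnx

class TinyNNX(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        self.dense = nnx.Linear(8, 4, rngs=rngs)  # parameters created and stored here

    def __call__(self, x):
        return self.dense(x)

nnx_model = TinyNNX(rngs=nnx.Rngs(0))
y = nnx_model(jnp.ones((1, 8)))  # stateful call: no parameters to pass around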

While the Linen API still exists, new users are advised to learn the new NNX API.

Simple CNN

We will use the LeNet-5 model [1], originally applied to the MNIST dataset by LeCun et al. [2]. We modify it to take three-channel (RGB colour) images instead of a single channel (black-and-white images, as in MNIST) and to output five categories.

The architecture of this model is explained in detail in this Kaggle post.

class CNN(nnx.Module):
  """An adapted LeNet-5 model."""

  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(3, 6, kernel_size=(5, 5), rngs=rngs)
    self.max_pool = partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2))
    self.conv2 = nnx.Conv(6, 16, kernel_size=(5, 5), rngs=rngs)
    # two SAME-padded convolutions and two 2x2 poolings reduce the 224x224 input
    # to 56x56x16 feature maps, i.e. 56 * 56 * 16 values once flattened
    self.linear1 = nnx.Linear(56 * 56 * 16, 120, rngs=rngs)
    self.linear2 = nnx.Linear(120, 84, rngs=rngs)
    self.linear3 = nnx.Linear(84, 5, rngs=rngs)

  def __call__(self, x):
    x = self.max_pool(nnx.relu(self.conv1(x)))
    x = self.max_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = nnx.relu(self.linear2(x))
    x = self.linear3(x)
    return x

# Instantiate the model.
model = CNN(rngs=nnx.Rngs(0))

# Visualize it.
nnx.display(model)
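
nnx.display renders the structure of the module: its layers and the shapes of their parameters.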
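
We can pass a dummy batch through the model to check that the forward pass runs and that the output has the expected shape: four samples, each with five logits.
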
import jax.numpy as jnp  # JAX NumPy

y = model(jnp.ones((4, 224, 224, 3)))
y
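
Finally, we create an optimizer and the metrics to track during training: AdamW from Optax wrapped in nnx.Optimizer, which keeps the optimizer state together with the model, and nnx.MultiMetric to compute the accuracy and the average loss.
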
import optax

learning_rate = 0.005
momentum = 0.9

optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)

nnx.display(optimizer)
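
To show how the metrics object defined above will be used during training, here is a minimal sketch of the per-batch bookkeeping; the batch values (loss, logits, labels) are made up purely for illustration:

import jax.numpy as jnp

# made-up per-batch values, just to illustrate the API
logits = jnp.array([[2.0, 0.1, 0.1, 0.1, 0.1],
                    [0.1, 2.0, 0.1, 0.1, 0.1]])
labels = jnp.array([0, 1])
loss = jnp.asarray(0.25)

metrics.update(loss=loss, logits=logits, labels=labels)  # accumulate after each batch
print(metrics.compute())                                 # accuracy and average loss so far
metrics.reset()                                          # clear, e.g. at the end of an epoch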

References

1.
LeCun Y, Bottou L, Bengio Y, Haffner P (1998) Gradient-based learning applied to document recognition. Proceedings of the IEEE 86(11):2278–2324
2.
LeCun Y, Cortes C, Burges C (2010) MNIST handwritten digit database. AT&T Labs [Online]. Available: http://yann.lecun.com/exdb/mnist 2