Defining model architecture
In this section, we define a model with Flax’s new API called NNX.
Context
Load packages
Packages and modules necessary for this section:
# to define the model architecture
from flax import nnx
# to create callables with some arguments pre-filled
from functools import partial
Flax API
Flax has gone through several APIs.
The initial nn API, now retired, was replaced in 2020 by the Linen API, which is still available in the Flax package. In 2024, the Flax team launched the NNX API.
Each iteration has moved further from JAX and closer to Python, with a syntax increasingly similar to PyTorch.
The old Linen API is a stateless model framework similar to the Julia package Lux.jl. It follows a strict functional programming approach in which the parameters are kept separate from the model and are passed, along with the data, as inputs to the forward pass. This stays much closer to JAX itself and lends itself well to JAX's optimizations, but it is restrictive and proved unpopular in the deep learning community and among Python users.
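As a minimal sketch of this stateless style (for illustration only; the rest of this section uses NNX), a Linen module keeps its parameters in a separate structure created by init and passed to every apply:

import jax
import jax.numpy as jnp
from flax import linen as nn

class TinyLinen(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=1)(x)

tiny = TinyLinen()
x = jnp.ones((4, 3))
# The parameters live outside the model instance...
variables = tiny.init(jax.random.key(0), x)
# ...and must be handed back in for every forward pass.
y = tiny.apply(variables, x)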
By contrast, the new NNX API is a stateful model framework similar to PyTorch and the older Julia package Flux.jl: model parameters and optimizer state are stored within the model instance. Flax handles a lot of JAX’s constraints under the hood, making the code more familiar to Python/PyTorch users, simpler, and more forgiving.
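The equivalent sketch in NNX (again, purely for illustration) stores the parameters directly on the instance:

import jax.numpy as jnp
from flax import nnx

class TinyNNX(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        self.dense = nnx.Linear(3, 1, rngs=rngs)

    def __call__(self, x):
        return self.dense(x)

tiny = TinyNNX(rngs=nnx.Rngs(0))
# The parameters are attributes of the instance, as in PyTorch.
y = tiny(jnp.ones((4, 3)))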
While the Linen API still exists, new users are advised to learn the new NNX API.
Simple CNN
We will use the LeNet-5 model [1], initially applied to the MNIST dataset by LeCun et al. [2]. We modify it to take three-channel images (RGB for colour images) instead of a single channel (black-and-white images, as in MNIST) and to output five categories.
The architecture of this model is explained in detail in this Kaggle post.
class CNN(nnx.Module):
    """An adapted LeNet-5 model."""

    def __init__(self, *, rngs: nnx.Rngs):
        self.conv1 = nnx.Conv(3, 6, kernel_size=(5, 5), rngs=rngs)
        self.max_pool = partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2))
        self.conv2 = nnx.Conv(6, 16, kernel_size=(5, 5), rngs=rngs)
        # With the default 'SAME' padding, two 2x2 poolings turn 224x224
        # inputs into 56x56 feature maps: 56 * 56 * 16 = 50176.
        self.linear1 = nnx.Linear(50176, 120, rngs=rngs)
        self.linear2 = nnx.Linear(120, 84, rngs=rngs)
        self.linear3 = nnx.Linear(84, 5, rngs=rngs)

    def __call__(self, x):
        x = self.max_pool(nnx.relu(self.conv1(x)))
        x = self.max_pool(nnx.relu(self.conv2(x)))
        x = x.reshape(x.shape[0], -1)  # flatten
        x = nnx.relu(self.linear1(x))
        x = nnx.relu(self.linear2(x))
        x = self.linear3(x)
        return x
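To see where that flatten dimension comes from, here is the spatial arithmetic (assuming 224x224 inputs, as in the test below, and the default 'SAME' padding of nnx.Conv, which preserves spatial size):

h = w = 224            # assumed input resolution
h, w = h // 2, w // 2  # first 2x2 max-pool: 112x112
h, w = h // 2, w // 2  # second 2x2 max-pool: 56x56
print(h * w * 16)      # 50176 features fed into linear1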
# Instantiate the model.
model = CNN(rngs=nnx.Rngs(0))

# Visualize it.
nnx.display(model)
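If you ever need the functional view that raw JAX transformations expect, NNX can also split a model into a static graph definition and a pytree of state, then reassemble it (a minimal sketch using nnx.split and nnx.merge):

# Split the model into a static definition and its mutable state.
graphdef, state = nnx.split(model)

# The state is a pytree that can flow through JAX transformations;
# merge reassembles an equivalent model from the two parts.
model_again = nnx.merge(graphdef, state)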
import jax.numpy as jnp  # JAX NumPy

# Run the model on a dummy batch of four 224x224 RGB images.
y = model(jnp.ones((4, 224, 224, 3)))
y
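The output has shape (4, 5): one vector of five raw scores (logits) per image in the batch.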
We now create an optimizer (AdamW, from the Optax library) wrapped around the model, along with metrics to track during training:

import optax

learning_rate = 0.005
momentum = 0.9

optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average('loss'),
)

nnx.display(optimizer)
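As a preview of how these objects fit together (a minimal sketch; the full training loop comes later, and 'batch' is a hypothetical dict with 'image' and 'label' arrays), gradients computed with nnx.value_and_grad are applied by the optimizer, which updates the model in place:

def train_step(model, optimizer, batch):
    def loss_fn(model):
        logits = model(batch['image'])
        return optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch['label']
        ).mean()

    # Differentiate with respect to the model's parameters.
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # apply AdamW; mutates the model in place
    return loss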