Flax’s handling of model states
Deep learning models can be split into two categories depending on the framework used to train them: stateful and stateless models. Flax—being built on top of JAX—falls in the latter category.
In this section, we will see what all of this means and how Flax handles model states.
Dealing with state in JAX
JAX JIT compilation requires that functions be without side effects since side effects are only executed once, during tracing.
Updating model parameters and optimizer state thus cannot be done as a side effect. The state cannot be part of the model instance—it needs to be explicit, that is, separated from the model. During instantiation, no memory is allocated for the parameters. During the forward pass, the parameters are part of the inputs, along with the data. The model is thus stateless and the constraints of pure functional programming are met (inputs lead to outputs without external influence or side effects).
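As a minimal sketch of this idea (the function and parameter names below are only illustrative, not Flax code), a linear layer written in this style takes its parameters as an explicit input:

import jax
import jax.numpy as jnp

# Parameters are an explicit input to a pure forward function
def forward(params, x):
    return x @ params["w"] + params["b"]

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
y = jax.jit(forward)(params, jnp.ones((4, 3)))  # safe to JIT: no hidden state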
Let’s see why a stateful approach doesn’t work with JAX¹: instead of defining a neural network class, we will define a very simple Counter class, following the PyTorch approach, that just adds 1. This allows us to see right away what is going on.
import jax
import jax.numpy as jnp

class Counter:
    def __init__(self):
        self.n = 0

    def count(self) -> int:
        """Adds one to the counter and returns the new value."""
        self.n += 1
        return self.n

    def reset(self):
        """Resets the counter to zero."""
        self.n = 0
Now we can create an instance called counter of the Counter class.

counter = Counter()
We can use the counter:
for _ in range(3):
    print(counter.count())
1
2
3
Now, let’s try with a JIT compiled version of count():

count_jit = jax.jit(counter.count)
counter.reset()
for _ in range(3):
    print(count_jit())
1
1
1
This is because count is not a functionally pure function. The tracing happens during the first run of the function (the first iteration of the loop). Thereafter, the compiled version is rerun without taking into account the modifications of the attributes of counter.
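To see this tracing behaviour in isolation, here is a small sketch (not part of the original example) in which the side effect is a print statement:

@jax.jit
def impure(x):
    print("tracing")  # side effect: only runs while JAX traces the function
    return x + 1

print(impure(jnp.array(0)))  # "tracing" is printed, then 1
print(impure(jnp.array(1)))  # no "tracing": the cached compiled version runs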
For this to work, we need to initialize an explicit state and pass it as an argument to the count function:
State = int

class Counter:
    def count(self, n: State) -> tuple[int, State]:
        return n + 1, n + 1

    def reset(self) -> State:
        return 0

counter = Counter()
state = counter.reset()

for _ in range(3):
    value, state = counter.count(state)
    print(value)
1
2
3
count_jit = jax.jit(counter.count)
state = counter.reset()

for _ in range(3):
    value, state = count_jit(state)
    print(value)
1
2
3
As explained in JAX’s documentation, we turned a function of the type:
class StatefulClass
    state: State
    def stateful_method(*args, **kwargs) -> Output:
Into:
class StatelessClass
    def stateless_method(state: State, *args, **kwargs) -> (Output, State):
Stateful vs stateless models
Stateful models
In frameworks such as PyTorch or the Julia package Flux, model parameters and optimizer state are stored within the model instance. Instantiating a PyTorch model allocates memory for the model parameters. The model can then be described as stateful.
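For instance (a minimal sketch, separate from the comparison below), a bare PyTorch layer already holds its parameters right after instantiation:

import torch.nn as nn

layer = nn.Linear(4, 2)
# The weight and bias tensors exist as soon as the module is instantiated
print(layer.weight.shape)  # torch.Size([2, 4])
print(layer.bias.shape)    # torch.Size([2])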
Stateless models
Frameworks based on JAX such as Flax, but also the Julia package Lux (a modern rewrite of Flux with explicit model parameters and a philosophy similar to JAX’s), are stateless: they follow a functional programming approach in which the parameters are separate from the model and are passed as inputs to the forward pass along with the data.
Example: PyTorch vs Flax
Flax, being built on JAX, requires functionally pure functions and thus stateless models.
Here is a comparison of the approach taken by PyTorch (stateful) vs Flax (stateless) to define and initialize a model (simplified model and workflow to show the principle):
This is how PyTorch works:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Create a subclass of torch.nn.Module
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.dense1 = nn.Linear(1, 12)
        self.dense2 = nn.Linear(12, 1)

    def forward(self, x):
        x = self.dense1(x)
        x = F.relu(x)
        x = self.dense2(x)
        return x
# Create model instance
model = Net()

# Random data and labels
data = torch.empty((4, 12, 12, 1))
labels = torch.randn((4, 12, 12, 1))
During the forward pass, only the inputs are passed through the model, but of course the outputs depend on the inputs and on the state of the model.
Here is the Flax approach:
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
import optax
Flax provides a setup syntax for model definition which will look more familiar to PyTorch users:
# Create a subclass of flax.linen.Module
class Net(nn.Module):
    def setup(self):
        self.dense1 = nn.Dense(12)
        self.dense2 = nn.Dense(1)

    def __call__(self, x):
        x = self.dense1(x)
        x = nn.relu(x)
        x = self.dense2(x)
        return x
Flax also comes with a compact syntax for model definition which is equivalent to the setup syntax in all respects except style:
# Create a subclass of flax.linen.Module
class Net(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(12, name="dense1")(x)
        x = nn.relu(x)
        x = nn.Dense(1, name="dense2")(x)
        return x
The parameters are not part of the model. You initialize them afterwards and create a parameter object:
# Create model instance
model = Net()

# Random data and labels
key, subkey1, subkey2 = random.split(random.key(13), 3)
data = jnp.empty((4, 12, 12, 1))
labels = random.normal(subkey1, (4, 12, 12, 1))

# Initialize model parameters
params = model.init(subkey2, data)
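If you want to see what this parameter object contains, you can map over it (an illustrative check, not part of the training workflow):

# params is a nested pytree of arrays,
# e.g. {'params': {'dense1': {'kernel': ..., 'bias': ...}, ...}}
print(jax.tree_util.tree_map(jnp.shape, params))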
Similarly, here are the stateful and stateless approaches to train the model:
# Forward pass
logits = model(data)
loss = F.cross_entropy(logits, labels)

# Calculate gradients
loss.backward()

# Optimize parameters
optimizer.step()
# Forward pass
def loss_func(params, data):
    logits = model.apply(params, data)
    loss = optax.softmax_cross_entropy(logits, labels).mean()
    return loss

# Calculate gradients
grads = jax.grad(loss_func)(params, data)

# Optimize parameters
state = state.apply_gradients(grads=grads)
The parameters are passed as inputs, along with the data, during the forward pass.
Flax training state
The demo above is stripped of any complexity to show the principle, but it is not realistic.
To handle everything that changes during training (the training step, the parameters, and the optimizer state), you can create a Flax training state.
Flax provides a dataclass that you can subclass to create a new training state class:
from flax.training import train_state
class TrainState(train_state.TrainState):
batch_stats: flax.core.FrozenDict
Then you can create the training state with TrainState.create:
variables = model.init(subkey2, data)

state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optax.sgd(0.01),
    # assumes the model contains layers with batch statistics (e.g. BatchNorm)
    batch_stats=variables['batch_stats'],
)
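For illustration, here is a minimal sketch of a jitted training step built on this state (it ignores batch_stats for simplicity, and the name train_step is ours, not a Flax API):

@jax.jit
def train_step(state, batch, labels):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch)
        return optax.softmax_cross_entropy(logits, labels).mean()
    grads = jax.grad(loss_fn)(state.params)
    # apply_gradients returns a new state with updated parameters and optimizer state
    return state.apply_gradients(grads=grads)

state = train_step(state, data, labels)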
NNX
A new Flax API, NNX, is under development and might replace Linen at some point.
It provides transformations similar to JAX’s, but they work on non-pure functions. This would bring Flax much closer to PyTorch and turn it into a stateful NN library by putting the parameters back inside the model.
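For illustration only (NNX is evolving, so names and signatures may have changed since this was written), a model definition in this stateful style looks roughly like this:

import jax
from flax import nnx

class Net(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.dense1 = nnx.Linear(1, 12, rngs=rngs)
        self.dense2 = nnx.Linear(12, 1, rngs=rngs)

    def __call__(self, x):
        x = self.dense1(x)
        x = jax.nn.relu(x)
        return self.dense2(x)

# Parameters are created here and live inside the model instance
model = Net(nnx.Rngs(0))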
Footnotes
1. Modified from JAX’s documentation.