Flax’s handling of model states
Deep learning models can be split into two categories depending on the framework used to train them: stateful and stateless models. Flax—being built on top of JAX—falls in the latter category.
In this section, we will see what all of this means and how Flax handles model states.
Dealing with state in JAX
JAX JIT compilation requires that functions be without side effects since side effects are only executed once, during tracing.
Updating model parameters and the optimizer state thus cannot be done as a side effect. The state cannot be part of the model instance; it needs to be explicit, that is, separated from the model. During instantiation, no memory is allocated for the parameters. During the forward pass, the parameters are part of the inputs, along with the data. The model is thus stateless and the constraints of pure functional programming are met (inputs lead to outputs without external influence or side effects).
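To see what this means in practice, here is a minimal sketch (separate from the counter example below): a Python print placed inside a jitted function runs only while the function is being traced, not on later calls.

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("tracing f")   # side effect: executed only during tracing
    return x + 1

f(jnp.array(1.0))   # prints "tracing f" and returns 2.0
f(jnp.array(2.0))   # prints nothing: the compiled version is simply reused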
Let’s see why a stateful approach doesn’t work with JAX¹: instead of defining a neural network class, we will define a very simple Counter class, following the PyTorch approach, that just adds 1 to an internal attribute. This lets us see right away what is going on.
import jax
import jax.numpy as jnp

class Counter:

    def __init__(self):
        self.n = 0

    def count(self) -> int:
        """Adds one to the counter and returns the new value."""
        self.n += 1
        return self.n

    def reset(self):
        """Resets the counter to zero."""
        self.n = 0

Now we can create an instance called counter of the Counter class:
counter = Counter()

We can use the counter:
for _ in range(3):
    print(counter.count())

1
2
3
Now, let’s try with a JIT compiled version of count():
count_jit = jax.jit(counter.count)
counter.reset()

for _ in range(3):
    print(count_jit())

1
1
1
This is because count is not a functionally pure function: it reads and modifies the attribute counter.n. Tracing happens during the first run of the function (the first iteration of the loop), so the update of self.n is executed once, at trace time, and the resulting value is baked into the compiled program. Thereafter, the compiled version reruns without taking the modifications of counter's attributes into account.
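We can check this by inspecting the program JAX actually recorded for the stateful count method (a quick look using jax.make_jaxpr):

counter.reset()
print(jax.make_jaxpr(counter.count)())
# The jaxpr contains no reference to counter.n: the increment ran in ordinary
# Python at trace time and only the resulting constant was recorded.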
For this to work, we need to initialize an explicit state and pass it as an argument to the count function:
State = int

class Counter:

    def count(self, n: State) -> tuple[int, State]:
        return n + 1, n + 1

    def reset(self) -> State:
        return 0

counter = Counter()
state = counter.reset()

for _ in range(3):
    value, state = counter.count(state)
    print(value)

1
2
3
count_jit = jax.jit(counter.count)
state = counter.reset()

for _ in range(3):
    value, state = count_jit(state)
    print(value)

1
2
3
As explained in JAX’s documentation, we turned a class of the form:

class StatefulClass
    state: State
    def stateful_method(*args, **kwargs) -> Output:

into:

class StatelessClass
    def stateless_method(state: State, *args, **kwargs) -> (Output, State):

Stateful vs stateless models
Stateful models
In frameworks such as PyTorch or the Julia package Flux, model parameters and optimizer state are stored within the model instance. Instantiating a PyTorch model allocates memory for the model parameters. The model can then be described as stateful.
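As a minimal PyTorch sketch of this (separate from the full comparison below): instantiating a layer immediately allocates its parameters, which live inside the object.

import torch.nn as nn

layer = nn.Linear(4, 3)     # weights and bias are allocated right here
print(layer.weight.shape)   # torch.Size([3, 4]): the parameters live inside `layer`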
Stateless models
Frameworks based on JAX, such as Flax, but also the Julia package Lux (a modern rewrite of Flux with explicit model parameters and a philosophy similar to JAX’s), are stateless: they follow a functional programming approach in which the parameters are separate from the model and are passed as inputs to the forward pass along with the data.
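Schematically, with a single Flax layer (a minimal sketch; the full workflow follows in the next section):

import jax
import jax.numpy as jnp
from flax import linen as nn

layer = nn.Dense(3)                          # no parameters are allocated here
x = jnp.ones((1, 4))
params = layer.init(jax.random.key(0), x)    # parameters are created as a separate object
out = layer.apply(params, x)                 # and passed explicitly at every forward pass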
Example: PyTorch vs Flax
Flax, being built on JAX, requires functionally pure functions and thus stateless models.
Here is a comparison of the approach taken by PyTorch (stateful) vs Flax (stateless) to define and initialize a model (simplified model and workflow to show the principle):
This is how PyTorch works:
import torch
import torch.nn as nn
import torch.nn.functional as F

# We create a subclass of torch.nn.Module
class Net(nn.Module):

    def __init__(self):
        super().__init__()
        # Memory for the parameters is allocated here, inside the instance
        self.dense1 = nn.Linear(1, 12)
        self.dense2 = nn.Linear(12, 1)

    def forward(self, x):
        x = self.dense1(x)
        x = F.relu(x)
        x = self.dense2(x)
        return x

# Create model instance
model = Net()

# Random data and labels
data = torch.empty((4, 12, 12, 1))
labels = torch.randn((4, 12, 12, 1))

During the forward pass, only the inputs are passed through the model, but of course the outputs depend on the inputs and on the state of the model.
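For instance, a forward pass is simply (a one-line sketch):

output = model(data)   # only the data is passed in; the parameters are read from `model` itself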
Here is the Flax approach:
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
import optax

Flax provides a setup syntax for model definition which will look familiar to PyTorch users:
# Create a subclass of flax.linen.Module
class Net(nn.Module):

    def setup(self):
        self.dense1 = nn.Dense(12)
        self.dense2 = nn.Dense(1)

    def __call__(self, x):
        x = self.dense1(x)
        x = nn.relu(x)
        x = self.dense2(x)
        return x

Flax also comes with a compact syntax for model definition, equivalent to the setup syntax in all respects except style:
# Create a subclass of flax.linen.Module
class Net(nn.Module):

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(12, name="dense1")(x)
        x = nn.relu(x)
        x = nn.Dense(1, name="dense2")(x)
        return x

The parameters are not part of the model. You initialize them afterwards, creating a separate parameter object:
# Create model instance
model = Net()
# Random data and labels
key, subkey1, subkey2 = random.split(random.key(13), 3)
data = jnp.empty((4, 12, 12, 1))
labels = random.normal(subkey1, (4, 12, 12, 1))
# Initialize model parameters
params = model.init(subkey2, data)
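model.init returns the parameters as a nested dictionary (a pytree) that is entirely separate from model. One quick way to inspect it, for instance:

# Show the shape of every parameter array in the pytree returned by model.init
print(jax.tree_util.tree_map(jnp.shape, params))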
Similarly, here are the stateful and stateless approaches to train the model. First, the stateful (PyTorch) approach:

# Forward pass
logits = model(data)
loss = nn.CrossEntropyLoss()(logits, labels)

# Calculate gradients
loss.backward()

# Optimize parameters (an optimizer such as torch.optim.SGD is assumed to have been created)
optimizer.step()
And the stateless (Flax) approach:

# Forward pass
def loss_func(params, data):
    logits = model.apply(params, data)
    loss = optax.softmax_cross_entropy(logits, labels).mean()
    return loss

# Calculate gradients
grads = jax.grad(loss_func)(params, data)

# Optimize parameters (state is a Flax training state, introduced below)
state = state.apply_gradients(grads=grads)

The parameters are passed as inputs, along with the data, during the forward pass.
Flax training state
The demo above is stripped of any complexity to show the principle, but it is not realistic.
To handle all the state that changes during training (training step, state of the parameters, state of the optimizer), you can create a Flax training state.
Flax provides a dataclass that you can subclass to create a new training state class:
import flax
from flax.training import train_state

class TrainState(train_state.TrainState):
    batch_stats: flax.core.FrozenDict

Then you can create the training state with TrainState.create:
# For a model with BatchNorm layers, model.init returns both a 'params'
# and a 'batch_stats' collection
variables = model.init(subkey2, data)

state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optax.sgd(0.01),
    batch_stats=variables['batch_stats'],
)
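With such a state, a training step typically becomes a pure function of the state and the data. Here is a minimal sketch, assuming a plain train_state.TrainState (without the batch_stats field) holding the model parameters:

@jax.jit
def train_step(state, data, labels):
    """One optimization step: compute the loss and its gradients, return an updated state."""
    def loss_func(params):
        logits = state.apply_fn({'params': params}, data)
        return optax.softmax_cross_entropy(logits, labels).mean()
    grads = jax.grad(loss_func)(state.params)
    return state.apply_gradients(grads=grads)   # returns a *new* state; nothing is mutated

state = train_step(state, data, labels)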
NNX
A new Flax API, NNX, is under development and might replace Linen at some point.
It provides transformations similar to JAX’s but which work on non-pure functions. This would bring Flax much closer to PyTorch and turn it into a stateful NN library by storing the parameters inside the model again.
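For illustration, here is roughly what a module looks like with the NNX API (a sketch based on the API at the time of writing; details may change):

from flax import nnx

class Net(nnx.Module):

    def __init__(self, rngs: nnx.Rngs):
        # The parameters are created here and stored inside the instance, as in PyTorch
        self.dense1 = nnx.Linear(1, 12, rngs=rngs)
        self.dense2 = nnx.Linear(12, 1, rngs=rngs)

    def __call__(self, x):
        return self.dense2(nnx.relu(self.dense1(x)))

model = Net(rngs=nnx.Rngs(0))   # the parameters now live in `model` itself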
Footnotes
1. Modified from JAX’s documentation.