Flax’s handling of model states

Author

Marie-Hélène Burle

Deep learning models can be split into two categories depending on the framework used to train them: stateful and stateless models. Flax—being built on top of JAX—falls in the latter category.

In this section, we will see what all of this means and how Flax handles model states.

First, let’s start an interactive job:

salloc --time=2:0:0 --mem-per-cpu=3500M

Nowadays, IPython (Interactive Python) is best known as the kernel used by Jupyter when running Python. It predates Jupyter, however: it was created as a better command shell than the default Python shell. For interactive Python sessions on the command line, it is nicer and faster to work with than plain Python, with no downside, so we will use it for this course:

module load ipython-kernel/3.11

Now, let’s activate the Python virtual environment:

source /project/60055/env/bin/activate

Finally, we can launch IPython:

ipython

Dealing with state in JAX

JAX JIT compilation requires that functions be without side effects since side effects are only executed once, during tracing.

Updating model parameters and optimizer state thus cannot be done as a side effect. The state cannot be part of the model instance: it needs to be explicit, that is, separated from the model. During instantiation, no memory is allocated for the parameters. During the forward pass, the parameters are part of the inputs, along with the data. The model is thus stateless and the constraints of pure functional programming are met (inputs lead to outputs without external influence or side effects).
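To make the "executed once, during tracing" point concrete, here is a minimal sketch. The names add_one and trace_log are hypothetical and only serve the illustration. A Python side effect placed inside a JIT-compiled function runs while JAX traces it, and is skipped when the cached compiled version is reused for inputs of the same shape and type:

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical list, used only to record the side effect


@jax.jit
def add_one(x):
    # Python side effect: executed while JAX traces the function,
    # not when the cached compiled version runs.
    trace_log.append("traced")
    return x + 1


# Three calls with the same shape and dtype, so a single trace is reused
results = [float(add_one(jnp.float32(i))) for i in range(3)]

print(results)         # the computation itself runs on every call
print(len(trace_log))  # the side effect only ran during tracing
```

The numerical results change with the input on every call, while the list only records the tracing.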

Let’s see why a stateful approach doesn’t work with JAX. Instead of defining a neural network class, we will define a very simple Counter class (modified from JAX’s documentation), following the PyTorch approach, that just adds 1. This allows us to see right away what is going on.

import jax
import jax.numpy as jnp

class Counter:
    def __init__(self):
        self.n = 0
      
    def count(self) -> int:
        """Adds one to the counter and returns the new value."""
        self.n += 1
        return self.n
  
    def reset(self):
        """Resets the counter to zero."""
        self.n = 0

Now we can create an instance called counter of the Counter class.

counter = Counter()

We can use the counter:

for _ in range(3):
    print(counter.count())
1
2
3

Now, let’s try with a JIT compiled version of count():

count_jit = jax.jit(counter.count)

counter.reset()

for _ in range(3):
    print(count_jit())
1
1
1

The output stays at 1 because count is not a functionally pure function. Tracing happens during the first run of the function (first iteration of the loop), and the value returned at that point is baked into the compiled code. Thereafter, the compiled version reruns without taking into account the modifications of the attributes of counter.
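One way to check this explanation is to look at the attribute itself after the loop. The sketch below repeats the stateful Counter (without reset, to keep it short) so that it runs on its own, then inspects counter.n:

```python
import jax


class Counter:
    def __init__(self):
        self.n = 0

    def count(self) -> int:
        """Adds one to the counter and returns the new value."""
        self.n += 1
        return self.n


counter = Counter()
count_jit = jax.jit(counter.count)

# Collect the values returned by three calls of the compiled function
values = [int(count_jit()) for _ in range(3)]

print(values)     # the value recorded during tracing is returned each time
print(counter.n)  # the attribute was incremented only once, during tracing
```

The attribute was updated a single time, while the function was being traced. The later calls never touch it.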

For this to work, we need to initialize an explicit state and pass it as an argument to the count function:

State = int

class Counter:
    def count(self, n: State) -> tuple[int, State]:
        return n+1, n+1
    
    def reset(self) -> State:
        return 0

counter = Counter()
state = counter.reset()

for _ in range(3):
    value, state = counter.count(state)
    print(value)
1
2
3

The JIT-compiled version now gives the same result:

count_jit = jax.jit(counter.count)

state = counter.reset()

for _ in range(3):
    value, state = count_jit(state)
    print(value)
1
2
3

As explained in JAX’s documentation, we turned a function of the type:

class StatefulClass
  state: State
  def stateful_method(*args, **kwargs) -> Output:

Into:

class StatelessClass
  def stateless_method(state: State, *args, **kwargs) -> (Output, State):
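The same pattern carries over to model parameters. Below is a minimal sketch in plain JAX, not taken from the JAX documentation: the names predict, loss_fn, update and LEARNING_RATE are hypothetical. The "state" is a dictionary of parameters that goes into the JIT-compiled update function and comes back out as a new dictionary:

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value chosen for this toy problem


def predict(params, x):
    # The parameters are an explicit input, not attributes of an object
    return params["w"] * x + params["b"]


def loss_fn(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)


@jax.jit
def update(params, x, y):
    # State in, new state out: nothing is modified in place
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )


params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = 2.0 * x + 1.0

initial_loss = float(loss_fn(params, x, y))
for _ in range(200):
    params = update(params, x, y)
final_loss = float(loss_fn(params, x, y))

print(initial_loss, final_loss)
```

Because update returns the new parameters instead of mutating anything, it is a pure function and works under jax.jit.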

Stateful vs stateless models

Stateful models

In frameworks such as PyTorch or the Julia package Flux, model parameters are stored within the model instance (and, in PyTorch, the optimizer state is stored within the optimizer object). Instantiating a PyTorch model allocates memory for the model parameters. Such models are described as stateful.

Stateless models

Frameworks based on JAX such as Flax but also the Julia package Lux (a modern rewrite of Flux with explicit model parameters and a philosophy similar to JAX’s) are stateless: they follow a functional programming approach in which the parameters are separate from the model and are passed as inputs to the forward pass along with the data.

Example: PyTorch vs Flax

Flax, being built on JAX, requires functionally pure functions and thus stateless models.

Here is a comparison of the approach taken by PyTorch (stateful) vs Flax (stateless) to define and initialize a model (simplified model and workflow to show the principle):

This is how PyTorch works:

import torch
import torch.nn as nn
import torch.nn.functional as F

# we create a subclass of torch.nn.Module
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Input and output sizes match the last dimension of the data below
        self.dense1 = nn.Linear(1, 12)
        self.dense2 = nn.Linear(12, 1)

    def forward(self, x):
        x = self.dense1(x)
        x = F.relu(x)
        x = self.dense2(x)
        return x

# Create model instance
model = Net()

# Placeholder data and random labels
data = torch.empty((4, 12, 12, 1))
labels = torch.randn((4, 12, 12, 1))

During the forward pass, only the inputs are passed through the model, but of course the outputs depend on the inputs and on the state of the model.

Here is the Flax approach:

import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
import optax

Flax provides a setup syntax of model definition which will look more familiar to PyTorch users:

# Create a subclass of flax.linen.Module
class Net(nn.Module):
  def setup(self):
    self.dense1 = nn.Dense(12)
    self.dense2 = nn.Dense(1)

  def __call__(self, x):
    x = self.dense1(x)
    x = nn.relu(x)
    x = self.dense2(x)
    return x

Flax also comes with a compact syntax of model definition, which is equivalent to the setup syntax in all respects except style:

# Create a subclass of flax.linen.Module
class Net(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(12, name="dense1")(x)
    x = nn.relu(x)
    x = nn.Dense(1, name="dense2")(x)
    return x
    return x

The parameters are not part of the model. You initialize them afterwards and create a parameter object:

# Create model instance
model = Net()

# Placeholder data and random labels
key, subkey1, subkey2 = random.split(random.key(13), 3)
data = jnp.empty((4, 12, 12, 1))
labels = random.normal(subkey1, (4, 12, 12, 1))

# Initialize model parameters
params = model.init(subkey2, data)

Similarly, here are the stateful and stateless approaches to train the model.

PyTorch:

# Forward pass
logits = model(data)

loss = nn.CrossEntropyLoss()(logits, labels)

# Calculate gradients
loss.backward()

# Optimize parameters (optimizer is a torch.optim optimizer created from model.parameters())
optimizer.step()

Flax:

# Forward pass
def loss_func(params, data):
    logits = model.apply(params, data)
    loss = optax.softmax_cross_entropy(logits, labels).mean()
    return loss

# Calculate gradients
grads = jax.grad(loss_func)(params, data)

# Optimize parameters (state is a Flax training state, see "Flax training state" below)
state = state.apply_gradients(grads=grads)

The parameters are passed as inputs, along with the data, during the forward pass.

Flax training state

The demo above is stripped of any complexity to show the principle, but it is not realistic.

To handle all the state that changes during training (training step, parameters, optimizer state), you can create a Flax training state.

Flax provides a dataclass that you can subclass to create a new training state class:

import flax
from flax.training import train_state

class TrainState(train_state.TrainState):
    batch_stats: flax.core.FrozenDict

Then you can create the training state with TrainState.create:

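# params is the collection of variables returned by model.init above
state = TrainState.create(
    apply_fn = model.apply,
    params = params['params'],
    tx = optax.sgd(0.01),
    # This collection only exists if the model has layers such as nn.BatchNorm
    batch_stats = params['batch_stats'],
)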

NNX

A new Flax API is under development and might replace Linen at some point.

It provides transformations similar to JAX’s, but they work on non-pure functions. This would bring Flax much closer to PyTorch and turn it into a stateful NN library by putting the parameters back inside the model.