Stateless models in Flax

Author

Marie-Hélène Burle

Deep learning models can be split into two broad categories depending on the framework used to train them: stateful and stateless models.

What does this mean and where does Flax stand?

Stateful vs stateless models

Stateful models

In frameworks such as PyTorch or the Julia package Flux, model parameters and optimizer state are stored within the model instance. Instantiating a PyTorch model allocates memory for the model parameters. The model can then be described as stateful.

During the forward pass, only the inputs are passed through the model. The outputs depend on the inputs and on the state of the model.

During training, you will see code such as:

loss.backward()

to calculate the gradients, or:

optimizer.step()

to update the parameters.
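Put together, a typical training step looks like this (a minimal sketch; the model, loss function, and data are illustrative, not from our course):

import torch
import torch.nn as nn

model = nn.Linear(4, 1)    # the parameters live inside the model instance
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(8, 4), torch.randn(8, 1)

optimizer.zero_grad()                       # clear the gradients held by the model
loss = nn.functional.mse_loss(model(x), y)  # forward pass: only inputs are passed in
loss.backward()                             # gradients accumulate inside the parameters
optimizer.step()                            # parameters are updated in place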

For more information, you can have a look at our PyTorch course.

Stateless models

JAX JIT compilation requires that functions be without side effects since side effects are only executed once, during tracing.
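A quick way to see this (a small illustration, with a print statement as the side effect):

import jax

@jax.jit
def f(x):
    print("Tracing!")  # side effect: runs only while JAX traces the function
    return x + 1

print(f(1))   # prints "Tracing!" followed by 2
print(f(2))   # no "Tracing!" this time: the compiled version is reused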

Updating model parameters and optimizer state thus cannot be done as a side effect. The state cannot be part of the model instance: it needs to be explicit, that is, separated from the model. During instantiation, no memory is allocated for the parameters. During the forward pass, the state is passed, along with the inputs, through the model. The model is thus stateless and the constraints of pure functional programming are met (inputs lead to outputs without external influence or side effects).

Here is a very basic example from the JAX documentation.

  • Stateful approach (and why it doesn’t work with JAX)

We will define a new class called Counter with three methods to initialize, use, and reset a counter:

import jax
import jax.numpy as jnp

class Counter:
    """A simple counter."""

    def __init__(self):
        self.n = 0

    def count(self) -> int:
        """Increments the counter and returns the new value."""
        self.n += 1
        return self.n

    def reset(self):
        """Resets the counter to zero."""
        self.n = 0

Now we can create an instance of the Counter class called counter:

counter = Counter()

We can use the counter:

for _ in range(3):
    print(counter.count())
1
2
3

We can use it further:

for _ in range(3):
    print(counter.count())
4
5
6

Within a neural network, you usually want to reinitialize the parameters for each run, so let's reset our counter:

counter.reset()

for _ in range(3):
    print(counter.count())
1
2
3

Now, what happens when we apply a transformation to the count() function? Let's reset our counter and create a JIT compiled version of count():

counter.reset()
fast_count = jax.jit(counter.count)

And now let’s use the JIT compiled version:

for _ in range(3):
    print(fast_count())
1
1
1

The counter is broken: the side effect of incrementing self.n ran only once, during tracing, so the compiled function keeps returning the value recorded at that point.

  • Stateless approach (how to make it work with JAX)

For counting to work under JIT compilation, we need to initialize an explicit state and pass it as an argument to our count() function:

CounterState = int

class CounterV2:

    def count(self, n: CounterState) -> tuple[int, CounterState]:
        # You could just return n + 1, but here we separate its role as
        # the output and as the counter state for didactic purposes.
        return n + 1, n + 1

    def reset(self) -> CounterState:
        return 0

counter = CounterV2()
state = counter.reset()

for _ in range(3):
    value, state = counter.count(state)
    print(value)
1
2
3
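Since count() is now a pure function of its explicit inputs, JIT compilation no longer breaks it (a quick check, continuing the example above):

state = counter.reset()
fast_count = jax.jit(counter.count)

for _ in range(3):
    value, state = fast_count(state)
    print(value)
1
2
3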

Frameworks based on JAX, such as Flax, as well as the Julia package Lux (a modern rewrite of Flux with explicit model parameters and a philosophy similar to JAX's), follow this functional programming approach.

The Linen API

In Flax, the base class for neural networks is flax.linen.Module. Linen is a new API that replaces the initial flax.nn API, taking full advantage of JAX transformations while automating the initialization and handling of the parameters.

Linen is imported this way:

from flax import linen as nn

To define a model, you create a subclass of nn.Module. The syntax closely resembles that of PyTorch's torch.nn:

class Net(nn.Module):
    ...

But unlike in PyTorch, the parameters are passed through the model in the form of pytrees (nested containers such as dictionaries, lists, and tuples).
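To make this concrete, here is a minimal sketch (the Dense layer, input shape, and RNG seed are illustrative choices):

import jax
import jax.numpy as jnp
from flax import linen as nn

class Net(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=3)(x)

model = Net()
x = jnp.ones((1, 4))

# init() returns the parameters as a pytree (a nested dictionary);
# nothing is stored in the model instance itself
params = model.init(jax.random.PRNGKey(0), x)

# the parameters are passed explicitly at each forward pass
y = model.apply(params, x)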