The Linen API

Author

Marie-Hélène Burle


In Flax, the base class for neural networks is flax.linen.Module. Linen is the API that replaced Flax's original flax.nn API. It is designed to take full advantage of JAX transformations while automating the initialization and handling of the parameters.

Linen is imported this way:

from flax import linen as nn

To define a model, you create a subclass of nn.Module. The syntax closely resembles that of PyTorch's torch.nn.Module:

class Net(nn.Module):
    ...

But unlike in PyTorch, the parameters are not stored inside the model object. They are kept outside of it and passed explicitly to the model in the form of pytrees (nested containers such as dictionaries, lists, and tuples).