The Linen API
In Flax, the base class for neural networks is flax.linen.Module. Linen is a newer API that replaces the original flax.nn API; it takes full advantage of JAX transformations while automating parameter initialization and handling.
Linen is imported this way:
from flax import linen as nn
To define a model, you create a subclass of nn.Module. The syntax closely resembles subclassing torch.nn.Module in PyTorch:
class Net(nn.Module):
    ...
Unlike in PyTorch, however, a Linen module does not store its parameters as attributes: they are passed through the model explicitly in the form of pytrees (nested containers such as dictionaries, lists, and tuples).