Defining a model architecture
Defining a model architecture in Flax is currently done with the flax.linen
API and is quite straightforward.
The Linen API
Linen is Flax’s current NN API¹. Let’s load it, along with JAX and the JAX NumPy API:
import jax
import jax.numpy as jnp
from flax import linen as nn
To define a model, we create a subclass of nn.Module,
which inherits all the characteristics of that module, saving us from defining all the behaviours a neural network should have (exactly as in PyTorch).
What we do need to define, of course, is the architecture and the flow of data through it.
Linen contains all the classic elements to define a NN model:
- modules to define standard layers (e.g. fully connected layers with Dense, convolution layers with Conv, pooling layers with max_pool or avg_pool),
- a set of activation functions (e.g. relu, softmax, sigmoid),
- JAX transformations (e.g. jit, vmap, jvp, vjp).
Linen makes use of JAX’s shape inference, so there is no need to specify the “in” features (e.g. nn.Dense
only requires one argument: the “out” features; this is in contrast with PyTorch’s nn.Linear,
which requires both “in” and “out” features).
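As a quick illustration (a minimal sketch; the PyTorch line is shown only for comparison and is left commented out so the snippet stays runnable without torch):

# Flax: only the output size is given; the input size is inferred
# from the data the first time the layer is used
dense = nn.Dense(features=256)

# PyTorch equivalent (for comparison): both sizes must be given
# dense = torch.nn.Linear(784, 256)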
Example
As we already saw, Linen provides two syntaxes:
- a longer syntax using setup: more verbose, more PyTorch-like, and it allows defining multiple methods (a small sketch of this style follows the example below),
- a compact one with the @nn.compact decorator: it can only use a single method, but it avoids duplication between setup and __call__.
Let’s use the latter here:
class CNN(nn.Module):

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x
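For comparison, here is a minimal sketch of what the beginning of the same model could look like with the setup syntax (the class name CNNWithSetup is just for illustration and only a few layers are shown):

class CNNWithSetup(nn.Module):

    def setup(self):
        # layers are declared once here...
        self.conv1 = nn.Conv(features=32, kernel_size=(3, 3))
        self.dense1 = nn.Dense(features=256)

    def __call__(self, x):
        # ...and used here, so each name appears in two places
        x = self.conv1(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = self.dense1(x)
        return x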
Now we can create an instance of the model:
cnn = CNN()
Remember that in Linen the parameters are not part of the model: we need to initialize them separately with the model’s init method, passing a random key and a dummy JAX array that sets their shape:
def get_initial_params(key):
    init_shape = jnp.ones((1, 28, 28, 1))
    initial_params = cnn.init(key, init_shape)
    return initial_params
key = jax.random.key(0)
key, model_key = jax.random.split(key)

params = get_initial_params(model_key)
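With the parameters in hand, the model can be run with its apply method; here is a quick sketch using a dummy batch of ones (placeholder data, just for illustration):

# a dummy batch containing a single 28x28 grayscale image
dummy_batch = jnp.ones((1, 28, 28, 1))

# in Linen the parameters are passed explicitly at every call
logits = cnn.apply(params, dummy_batch)
print(logits.shape)  # (1, 10)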
Model inspection
Before using a model, it is a good idea to inspect it and make sure that everything is ok.
Inspect model layers
The tabulate method returns a summary table of each module (layer) in our model, which we can print:
print(cnn.tabulate(
    jax.random.key(0),
    jnp.ones((1, 28, 28, 1)),
    compute_flops=True,
    compute_vjp_flops=True,
))
CNN Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs ┃ outputs ┃ flops ┃ vjp_flops ┃ params ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ │ CNN │ float32[1,28,28,1] │ float32[1,10] │ 8708144 │ 26957634 │ │
├─────────┼────────┼─────────────────────┼─────────────────────┼─────────┼───────────┼────────────────────────────┤
│ Conv_0 │ Conv │ float32[1,28,28,1] │ float32[1,28,28,32] │ 455424 │ 1341472 │ bias: float32[32] │
│ │ │ │ │ │ │ kernel: float32[3,3,1,32] │
│ │ │ │ │ │ │ │
│ │ │ │ │ │ │ 320 (1.3 KB) │
├─────────┼────────┼─────────────────────┼─────────────────────┼─────────┼───────────┼────────────────────────────┤
│ Conv_1 │ Conv │ float32[1,14,14,32] │ float32[1,14,14,64] │ 6566144 │ 19704320 │ bias: float32[64] │
│ │ │ │ │ │ │ kernel: float32[3,3,32,64] │
│ │ │ │ │ │ │ │
│ │ │ │ │ │ │ 18,496 (74.0 KB) │
├─────────┼────────┼─────────────────────┼─────────────────────┼─────────┼───────────┼────────────────────────────┤
│ Dense_0 │ Dense │ float32[1,3136] │ float32[1,256] │ 1605888 │ 5620224 │ bias: float32[256] │
│ │ │ │ │ │ │ kernel: float32[3136,256] │
│ │ │ │ │ │ │ │
│ │ │ │ │ │ │ 803,072 (3.2 MB) │
├─────────┼────────┼─────────────────────┼─────────────────────┼─────────┼───────────┼────────────────────────────┤
│ Dense_1 │ Dense │ float32[1,256] │ float32[1,10] │ 5130 │ 17940 │ bias: float32[10] │
│ │ │ │ │ │ │ kernel: float32[256,10] │
│ │ │ │ │ │ │ │
│ │ │ │ │ │ │ 2,570 (10.3 KB) │
├─────────┼────────┼─────────────────────┼─────────────────────┼─────────┼───────────┼────────────────────────────┤
│ │ │ │ │ │ Total │ 824,458 (3.3 MB) │
└─────────┴────────┴─────────────────────┴─────────────────────┴─────────┴───────────┴────────────────────────────┘
Total Parameters: 824,458 (3.3 MB)
- flops: estimated cost of the forward pass in FLOPs (floating-point operations)
- vjp_flops: estimated cost of the backward pass (vjp stands for vector-Jacobian product)
Inspect initial parameters
Your turn:
The summary table includes the shape of the parameters.
Based on what we already learned, what is another way to get this information?
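One possible approach, sketched below (it reuses the params PyTree created earlier):

# map jnp.shape over the parameter PyTree to see every shape
print(jax.tree_util.tree_map(jnp.shape, params))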
Footnotes
1. Flax had an initial flax.nn API, now retired. Linen made things easier and more PyTorch-like. A new API called NNX is being developed and may replace Linen in the future. It makes Flax even more similar to PyTorch by breaking with JAX’s requirement of functionally pure functions and bringing the parameters back into the model.