Defining a model architecture

Author

Marie-Hélène Burle

Defining a model architecture in Flax is currently done with the flax.linen API and is quite straightforward.

The Linen API

Linen is Flax’s current¹ NN API. Let’s load it, along with JAX and the JAX NumPy API:

import jax
import jax.numpy as jnp
from flax import linen as nn

To define a model, we create a subclass of nn.Module, which inherits all the characteristics of that module and saves us from defining all the behaviours a neural network should have (exactly as in PyTorch).

What we do need to define, of course, is the architecture and the flow of data through it.

Linen contains all the classic elements needed to define a NN model (layers, activation functions, pooling operations, etc.).

Linen makes use of shape inference, so there is no need to specify “in” features (e.g. nn.Dense only requires one argument, the number of “out” features; this is in contrast with PyTorch’s nn.Linear, which requires both “in” and “out” features).
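
As a quick illustration (a sketch, not from the original text, using an arbitrary input size of 64), the “in” features of a Dense layer are inferred from the dummy input passed at initialization:

# Minimal sketch: the "in" features are inferred from the dummy input at initialization
layer = nn.Dense(features=256)
variables = layer.init(jax.random.key(0), jnp.ones((1, 64)))
print(jax.tree.map(jnp.shape, variables))
{'params': {'bias': (256,), 'kernel': (64, 256)}}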

Example

As we already saw, Linen provides two syntaxes:

  • a longer syntax using setup: more redundant, more PyTorch-like, and it allows defining multiple methods (a minimal sketch follows this list),
  • a compact one with the @nn.compact decorator that can only use a single method and avoids duplication between setup and __call__.
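
For comparison, here is a minimal sketch (an assumed example, not the model we build below) of a small network written with the setup syntax:

class MLP(nn.Module):
    def setup(self):
        # Submodules are declared once here...
        self.hidden = nn.Dense(features=256)
        self.out = nn.Dense(features=10)

    def __call__(self, x):
        # ...and used here, hence the duplication between setup and __call__
        x = nn.relu(self.hidden(x))
        return self.out(x)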

Let’s use the compact syntax here:

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        # First convolution block
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Second convolution block
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten all dimensions except the batch dimension
        x = x.reshape((x.shape[0], -1))
        # Fully-connected layers
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        # 10 classes, returned as log-probabilities
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x

Now we can create an instance of the model:

cnn = CNN()

Remember that in Linen the parameters are not part of the model: we need to initialize them with the model’s init method, passing it a random key and a dummy JAX array that sets their shapes:

def get_initial_params(key):
    # Dummy input with the shape of our data (batch, height, width, channels)
    init_shape = jnp.ones((1, 28, 28, 1))
    # init returns a variables dict; we keep its "params" collection
    initial_params = cnn.init(key, init_shape)["params"]
    return initial_params

key = jax.random.key(0)
key, model_key = jax.random.split(key)

params = get_initial_params(model_key)
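
Since the parameters live outside the model, a forward pass is done by passing them explicitly to the model’s apply method. As a quick sanity check (a sketch with a dummy batch, not part of the original code):

# Dummy batch of one 28x28 grayscale image
batch = jnp.ones((1, 28, 28, 1))
logits = cnn.apply({"params": params}, batch)
print(logits.shape)
(1, 10)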

Model inspection

Before using a model, it is a good idea to inspect it and make sure that everything looks as expected.

Inspect model layers

The tabulate method returns a summary table of each module (layer) in our model, which we can print:

print(cnn.tabulate(
    jax.random.key(0),
    jnp.ones((1, 28, 28, 1)),
    compute_flops=True,
    compute_vjp_flops=True,
))
                                                    CNN Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module ┃ inputs              ┃ outputs             ┃ flops   ┃ vjp_flops ┃ params                     ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         │ CNN    │ float32[1,28,28,1]  │ float32[1,10]       │ 8708144 │ 26957634  │                            │
├─────────┼────────┼─────────────────────┼─────────────────────┼─────────┼───────────┼────────────────────────────┤
│ Conv_0  │ Conv   │ float32[1,28,28,1]  │ float32[1,28,28,32] │ 455424  │ 1341472   │ bias: float32[32]          │
│         │        │                     │                     │         │           │ kernel: float32[3,3,1,32]  │
│         │        │                     │                     │         │           │                            │
│         │        │                     │                     │         │           │ 320 (1.3 KB)               │
├─────────┼────────┼─────────────────────┼─────────────────────┼─────────┼───────────┼────────────────────────────┤
│ Conv_1  │ Conv   │ float32[1,14,14,32] │ float32[1,14,14,64] │ 6566144 │ 19704320  │ bias: float32[64]          │
│         │        │                     │                     │         │           │ kernel: float32[3,3,32,64] │
│         │        │                     │                     │         │           │                            │
│         │        │                     │                     │         │           │ 18,496 (74.0 KB)           │
├─────────┼────────┼─────────────────────┼─────────────────────┼─────────┼───────────┼────────────────────────────┤
│ Dense_0 │ Dense  │ float32[1,3136]     │ float32[1,256]      │ 1605888 │ 5620224   │ bias: float32[256]         │
│         │        │                     │                     │         │           │ kernel: float32[3136,256]  │
│         │        │                     │                     │         │           │                            │
│         │        │                     │                     │         │           │ 803,072 (3.2 MB)           │
├─────────┼────────┼─────────────────────┼─────────────────────┼─────────┼───────────┼────────────────────────────┤
│ Dense_1 │ Dense  │ float32[1,256]      │ float32[1,10]       │ 5130    │ 17940     │ bias: float32[10]          │
│         │        │                     │                     │         │           │ kernel: float32[256,10]    │
│         │        │                     │                     │         │           │                            │
│         │        │                     │                     │         │           │ 2,570 (10.3 KB)            │
├─────────┼────────┼─────────────────────┼─────────────────────┼─────────┼───────────┼────────────────────────────┤
│         │        │                     │                     │         │     Total │ 824,458 (3.3 MB)           │
└─────────┴────────┴─────────────────────┴─────────────────────┴─────────┴───────────┴────────────────────────────┘

                                        Total Parameters: 824,458 (3.3 MB)

Inspect initial parameters

Your turn:

The summary table includes the shape of the parameters.
Based on what we already learned, what is another way to get this information?

Printing the pytree of the initial parameters we created directly would output all their values, making it fairly hard to get a sense of their structure; but we know that JAX can operate on pytrees thanks to the tree_util module, so let’s make use of it:

jax.tree.map(jnp.shape, params)
{'Conv_0': {'bias': (32,), 'kernel': (3, 3, 1, 32)},
 'Conv_1': {'bias': (64,), 'kernel': (3, 3, 32, 64)},
 'Dense_0': {'bias': (256,), 'kernel': (3136, 256)},
 'Dense_1': {'bias': (10,), 'kernel': (256, 10)}}
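
As another sanity check (a quick sketch), we can recover the total parameter count reported by tabulate by summing the sizes of all the leaves of this pytree:

# Total number of parameters: sum of the sizes of all parameter arrays
sum(x.size for x in jax.tree.leaves(params))
824458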

Footnotes

  1. Flax had an initial flax.nn API, now retired. Linen made things easier and more PyTorch-like.
    A new API called NNX is being developed and may replace Linen in the future. It makes Flax even more similar to PyTorch by breaking with JAX’s requirement of functionally pure functions and bringing the parameters back into the model.↩︎