Model and training strategy

Author

Marie-Hélène Burle

base_dir = "<path-of-the-nabirds-dir>"

Replace this with the actual path. On our training cluster, the NABirds data is at /project/def-sponsor00/nabirds:

base_dir = '/project/def-sponsor00/nabirds'
import os
import polars as pl
import imageio.v3 as iio
import grain.python as grain

metadata = pl.read_parquet("metadata.parquet")
metadata_train = metadata.filter(pl.col("is_training_img") == 1)
cleaned_img_dir = os.path.join(base_dir, "cleaned_images")

class NABirdsDataset:
    """NABirds dataset class."""

    def __init__(self, metadata_file, data_dir):
        self.metadata_file = metadata_file
        self.data_dir = data_dir

    def __len__(self):
        return len(self.metadata_file)

    def __getitem__(self, idx):
        path = os.path.join(self.data_dir, self.metadata_file.get_column("path")[idx])
        img = iio.imread(path)
        species = self.metadata_file.get_column("species")[idx].replace("_", " ")
        subcategory = self.metadata_file.get_column("subcategory")[idx]
        if subcategory is not None:
            subcategory = subcategory.replace("_", " ")
        photographer = self.metadata_file.get_column("photographer")[idx].replace("_", " ")
        element = {
            "img": img,
            "species": species,
            "subcategory": subcategory,
            "photographer": photographer,
        }

        return element

nabirds_train = NABirdsDataset(metadata_train, cleaned_img_dir)
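Grain (imported above) can consume this class directly because it implements __len__ and __getitem__. Here is a minimal sketch, with illustrative settings, of how the training split could be iterated: an IndexSampler decides the order of the indices and a DataLoader calls __getitem__ for us. No batching operation is added because the raw images still have different shapes; batching only makes sense after resizing or cropping.

import grain.python as grain  # already imported above

sampler = grain.IndexSampler(
    num_records=len(nabirds_train),
    shard_options=grain.NoSharding(),  # single process: no sharding across workers
    shuffle=True,
    num_epochs=1,
    seed=0,
)

loader = grain.DataLoader(
    data_source=nabirds_train,
    sampler=sampler,
    worker_count=0,  # increase to load images in parallel processes
)

element = next(iter(loader))
print(element["species"], element["img"].shape)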

Our strategy

Technique

Architecture

Options that would make sense for our example include ResNet and EfficientNet.

ResNet, or residual network, is a type of architecture in which the layers are reformulated as learning residual functions with reference to the layer inputs. This allows for deeper (and thus more performant) networks [@he2015deepresiduallearningimage].

Figure from @he2015deepresiduallearningimage.

ResNet-50 has become a classic CNN for image classification.
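To make the idea of a residual function concrete, here is a minimal sketch (not part of our model) of a residual block written with Flax NNX: the convolutions learn f(x) and the block returns x + f(x), so the skip connection lets the signal and the gradients bypass the layers.

import jax.numpy as jnp
from flax import nnx

class ResidualBlock(nnx.Module):
    """Minimal residual block: the layers learn f(x), the block outputs x + f(x)."""

    def __init__(self, channels: int, *, rngs: nnx.Rngs):
        self.conv1 = nnx.Conv(channels, channels, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(channels, channels, kernel_size=(3, 3), rngs=rngs)

    def __call__(self, x):
        y = nnx.relu(self.conv1(x))
        y = self.conv2(y)
        return nnx.relu(y + x)  # skip connection: add the input back in

block = ResidualBlock(64, rngs=nnx.Rngs(0))
out = block(jnp.ones((1, 56, 56, 64)))  # same shape as the input: (1, 56, 56, 64)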

EfficientNet is a family of newer computer vision CNNs from Google that uses a compound coefficient to uniformly scale the depth, width, and resolution of networks, achieving better accuracy with fewer parameters than other CNNs [@tan2020efficientnetrethinkingmodelscaling]. This makes them easier to train with fewer resources and can lead to better results. They are, however, harder to tune than the more robust ResNet family.

Figure from @tan2020efficientnetrethinkingmodelscaling.

There are variations for different image sizes. For instance, EfficientNet-B0 expects 224×224 pixel inputs, EfficientNet-B2 expects 260×260, and EfficientNet-B7 expects 600×600.
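The compound scaling rule itself is simple arithmetic. In the paper, the B0 baseline is grown by a compound coefficient φ, with depth, width, and resolution multiplied by α^φ, β^φ, and γ^φ, where α ≈ 1.2, β ≈ 1.1, and γ ≈ 1.15 were found by grid search. The sketch below is only illustrative; the published B1 to B7 models round these numbers and adjust other details.

# Illustrative only: compound scaling from the EfficientNet paper
alpha, beta, gamma = 1.2, 1.1, 1.15

def compound_scaling(phi):
    """Multipliers applied to the B0 baseline for a compound coefficient phi."""
    return {
        "depth": alpha**phi,       # number of layers
        "width": beta**phi,        # number of channels
        "resolution": gamma**phi,  # input image size
    }

print(compound_scaling(1))  # roughly EfficientNet-B1
print(compound_scaling(2))  # roughly EfficientNet-B2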

Pre-trained weights

Choice of library

Strategy summary

Category              Answer
Technique             Transfer learning
Architecture          EfficientNet-B2 (EfficientNet-B0 or ResNet-50 are other reasonable options)
Pre-trained weights   ImageNet
Library               Flax (NNX API)

Implementation

import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import bridge
from transformers import FlaxEfficientNetModel

class NABirdsEfficientNet(nnx.Module):
    def __init__(self, num_classes=555, *, rngs: nnx.Rngs):
        # 1. Load the pre-trained Linen model structure & weights from Hugging Face.
        #    (This assumes transformers exposes a Flax/Linen EfficientNet implementation
        #    and checkpoint; if it does not, the same pattern applies to other Flax vision
        #    models, e.g. FlaxResNetModel with "microsoft/resnet-50".)
        hf_model = FlaxEfficientNetModel.from_pretrained("google/efficientnet-b0")

        # 2. Extract the underlying Linen module and its variables
        linen_module = hf_model.module
        linen_variables = hf_model.params
        # HF stores batch_stats in 'batch_stats' if they exist, or inside params.
        # EfficientNet usually has 'batch_stats'. We merge them for the bridge.
        if hasattr(hf_model, 'batch_stats'):
            linen_variables = {**linen_variables, **hf_model.batch_stats}

        # 3. Create the NNX Bridge
        # We wrap the Linen module. ToNNX creates the structure.
        self.backbone = bridge.ToNNX(linen_module, rngs=rngs)

        # 4. WEIGHT SURGERY (The Critical Step)
        # We must initialize the bridge to create the NNX variable structure,
        # then replace those random variables with the pre-trained ones.
        dummy_input = jnp.ones((1, 224, 224, 3))
        bridge.lazy_init(self.backbone, dummy_input)

        # Transfer weights: Linen dict -> NNX State
        # The bridge maps Linen collections to NNX variable types automatically.
        # 'params' -> nnx.Param, 'batch_stats' -> nnx.BatchStat
        _, backbone_state = nnx.split(self.backbone)

        # This function recursively matches keys and updates the state
        def copy_weights(target_state, source_dict):
            for key, value in source_dict.items():
                if isinstance(value, dict) or hasattr(value, 'items'):
                    # Traverse deeper if it's a dict/FrozenDict
                    copy_weights(target_state[key], value)
                else:
                    # We found a leaf (array). Update the NNX Variable's value.
                    # Note: target_state[key] is a Variable (Param/BatchStat) wrapper
                    target_state[key].value = value

        copy_weights(backbone_state, linen_variables)
        nnx.update(self.backbone, backbone_state)

        # 5. Define your new Custom Head (Pure NNX)
        # EfficientNet-B0 output is 1280 dim
        self.head = nnx.Linear(1280, num_classes, rngs=rngs)

    def __call__(self, x):
        # Run backbone (bridge)
        # HF models output a generic object; we need 'last_hidden_state'
        # shape: [batch, 7, 7, 1280] for B0
        outputs = self.backbone(x)
        features = outputs.last_hidden_state

        # Global Average Pooling (standard for EfficientNet)
        features = jnp.mean(features, axis=(1, 2))

        # Classification
        return self.head(features)

# --- Usage Example ---

# 1. Initialize
rngs = nnx.Rngs(params=0, dropout=1)
model = NABirdsEfficientNet(num_classes=555, rngs=rngs)

# 2. Forward Pass
x = jax.random.normal(jax.random.key(0), (1, 224, 224, 3))
logits = model(x)

print(f"Logits shape: {logits.shape}") # (1, 555)
print("Model initialized and pre-trained weights loaded via NNX bridge.")
Note: running this code requires the transformers package (for instance, pip install transformers in your virtual environment); without it, the import above fails with ModuleNotFoundError: No module named 'transformers'.
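With the pre-trained backbone in place, transfer learning means updating only the new head while the backbone stays frozen. The following is a minimal sketch of that idea, under several assumptions: the model defined above (whose classification layer is the attribute head), a dummy batch, and a hand-written SGD step. A real training loop would instead use an optax optimizer and batches produced by our Grain data loader. The nnx.DiffState and nnx.PathContains filters restrict differentiation to the head parameters.

import jax
import jax.numpy as jnp
from flax import nnx

# Dummy batch (illustrative only)
imgs = jax.random.normal(jax.random.key(1), (2, 224, 224, 3))
labels = jnp.array([3, 42])

# Only the parameters living under the "head" attribute will receive gradients
head_only = nnx.All(nnx.Param, nnx.PathContains("head"))

def loss_fn(model, imgs, labels):
    logits = model(imgs)
    log_probs = jax.nn.log_softmax(logits)
    # Mean negative log-likelihood of the true classes
    return -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=-1))

grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, head_only))(model, imgs, labels)

# One plain SGD step on the head only; the backbone parameters are untouched
head_params = nnx.state(model, head_only)
head_params = jax.tree.map(lambda p, g: p - 1e-3 * g, head_params, grads)
nnx.update(model, head_params)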