```python
import os

import polars as pl
import imageio.v3 as iio
import grain.python as grain

# Load the metadata table and keep only the training split.
metadata = pl.read_parquet("metadata.parquet")
metadata_train = metadata.filter(pl.col("is_training_img") == 1)

cleaned_img_dir = os.path.join(base_dir, "cleaned_images")


class NABirdsDataset:
    """NABirds dataset class."""

    def __init__(self, metadata_file, data_dir):
        self.metadata_file = metadata_file
        self.data_dir = data_dir

    def __len__(self):
        return len(self.metadata_file)

    def __getitem__(self, idx):
        # Load the image for this record.
        path = os.path.join(self.data_dir, self.metadata_file.get_column("path")[idx])
        img = iio.imread(path)
        # Underscores in the metadata strings are replaced with spaces.
        species = self.metadata_file.get_column("species")[idx].replace("_", " ")
        subcategory = self.metadata_file.get_column("subcategory")[idx]
        if subcategory is not None:
            subcategory = subcategory.replace("_", " ")
        photographer = self.metadata_file.get_column("photographer")[idx].replace("_", " ")
        element = {
            "img": img,
            "species": species,
            "subcategory": subcategory,
            "photographer": photographer,
        }
        return element


nabirds_train = NABirdsDataset(metadata_train, cleaned_img_dir)
```
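To sanity-check the dataset, it can be consumed through a Grain `DataLoader`. The snippet below is only a minimal sketch, not the actual training input pipeline: the sampler settings (shuffling, seed, single epoch) and the in-process `worker_count=0` are illustrative choices.

```python
# Minimal sketch: iterate over the dataset with Grain. Sampler settings and
# worker_count are illustrative, not tuned values.
sampler = grain.IndexSampler(
    num_records=len(nabirds_train),
    num_epochs=1,
    shard_options=grain.NoSharding(),
    shuffle=True,
    seed=42,
)
loader = grain.DataLoader(
    data_source=nabirds_train,
    sampler=sampler,
    worker_count=0,  # load in-process; increase for parallel workers
)

element = next(iter(loader))
print(element["species"], element["img"].shape)
```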
## Model and training strategy

### Our strategy

#### Technique

#### Architecture
Options that would make sense for our example include ResNet and EfficientNet.
ResNet, or residual network, is a type of architecture in which the layers are reformulated as learning residual functions with reference to the layer inputs. This allows for deeper (and thus more performant) networks [@he2015deepresiduallearningimage].

ResNet-50 has become a classic CNN for image classification.
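To make the residual idea concrete, here is a small illustrative residual block written with Flax NNX (the library used in the implementation below). The block and its sizes are a simplification for exposition, not the actual ResNet-50 bottleneck block.

```python
import jax.numpy as jnp
from flax import nnx


class ResidualBlock(nnx.Module):
    """Toy residual block: the convolutions learn a residual F(x); the output is F(x) + x."""

    def __init__(self, features: int, *, rngs: nnx.Rngs):
        self.conv1 = nnx.Conv(features, features, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(features, features, kernel_size=(3, 3), rngs=rngs)

    def __call__(self, x):
        y = nnx.relu(self.conv1(x))
        y = self.conv2(y)
        # Skip connection: adding the input back means the convolutions only
        # have to learn the residual, which keeps very deep stacks trainable.
        return nnx.relu(y + x)


block = ResidualBlock(features=16, rngs=nnx.Rngs(0))
out = block(jnp.ones((1, 32, 32, 16)))  # same shape in and out: (1, 32, 32, 16)
```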
EfficientNet is a more recent family of computer vision CNNs from Google. It uses a compound coefficient to uniformly scale the depth, width, and resolution of the network, achieving better accuracy with fewer parameters than other CNNs [@tan2020efficientnetrethinkingmodelscaling]. This makes these models cheaper to train on limited resources and can lead to better results, although they are harder to tune than the more robust ResNet family.

There are variants tuned for different input image sizes; the compound scaling behind them is sketched after this list. For instance:
- EfficientNet-B0 for 224x224 images
- EfficientNet-B2 for 260x260 images
- EfficientNet-B3 for 300x300 images
- EfficientNet-B7 for 600x600 images
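The compound scaling behind these variants is simple arithmetic: a single coefficient phi scales depth, width, and input resolution together. The sketch below uses the constants reported in the EfficientNet paper (alpha = 1.2, beta = 1.1, gamma = 1.15); the helper function itself is only illustrative.

```python
# Compound scaling, as described in the EfficientNet paper: a single
# coefficient phi scales depth, width, and input resolution together.
ALPHA, BETA, GAMMA = 1.2, 1.1, 1.15


def compound_scale(phi: int) -> dict[str, float]:
    return {
        "depth": ALPHA**phi,       # multiplier on the number of layers
        "width": BETA**phi,        # multiplier on the number of channels
        "resolution": GAMMA**phi,  # multiplier on the input image size
    }


print(compound_scale(2))  # roughly the scaling used for EfficientNet-B2
```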
#### Pre-trained weights

#### Choice of library

#### Strategy summary
| Category | Answer |
|---|---|
| Technique | Transfer learning |
| Architecture | EfficientNet-B2 (EfficientNet-B0 or ResNet-50 are other reasonable options) |
| Pre-trained weights | ImageNet |
| Library | Flax (NNX) |
## Implementation
```python
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import bridge
from transformers import FlaxEfficientNetModel


class NABirdsEfficientNet(nnx.Module):
    def __init__(self, num_classes=555, *, rngs: nnx.Rngs):
        # 1. Load the pre-trained Linen model structure & weights from Hugging Face
        hf_model = FlaxEfficientNetModel.from_pretrained("google/efficientnet-b0")

        # 2. Extract the underlying Linen module and its variables
        linen_module = hf_model.module
        linen_variables = hf_model.params
        # HF stores batch_stats in 'batch_stats' if they exist, or inside params.
        # EfficientNet usually has 'batch_stats'. We merge them for the bridge.
        if hasattr(hf_model, 'batch_stats'):
            linen_variables = {**linen_variables, **hf_model.batch_stats}

        # 3. Create the NNX Bridge
        # We wrap the Linen module. ToNNX creates the structure.
        self.backbone = bridge.ToNNX(linen_module, rngs=rngs)

        # 4. WEIGHT SURGERY (The Critical Step)
        # We must initialize the bridge to create the NNX variable structure,
        # then replace those random variables with the pre-trained ones.
        dummy_input = jnp.ones((1, 224, 224, 3))
        self.backbone.lazy_init(dummy_input)

        # Transfer weights: Linen dict -> NNX State
        # The bridge maps Linen collections to NNX variable types automatically.
        # 'params' -> nnx.Param, 'batch_stats' -> nnx.BatchStat
        _, backbone_state = nnx.split(self.backbone)

        # This function recursively matches keys and updates the state
        def copy_weights(target_state, source_dict):
            for key, value in source_dict.items():
                if isinstance(value, dict) or hasattr(value, 'items'):
                    # Traverse deeper if it's a dict/FrozenDict
                    copy_weights(target_state[key], value)
                else:
                    # We found a leaf (array). Update the NNX Variable's value.
                    # Note: target_state[key] is a Variable (Param/BatchStat) wrapper
                    target_state[key].value = value

        copy_weights(backbone_state, linen_variables)
        nnx.update(self.backbone, backbone_state)

        # 5. Define your new Custom Head (Pure NNX)
        # EfficientNet-B0 output is 1280 dim
        self.head = nnx.Linear(1280, num_classes, rngs=rngs)

    def __call__(self, x):
        # Run backbone (bridge)
        # HF models output a generic object; we need 'last_hidden_state'
        # shape: [batch, 7, 7, 1280] for B0
        outputs = self.backbone(x)
        features = outputs.last_hidden_state
        # Global Average Pooling (standard for EfficientNet)
        features = jnp.mean(features, axis=(1, 2))
        # Classification
        return self.head(features)


# --- Usage Example ---
# 1. Initialize
rngs = nnx.Rngs(params=0, dropout=1)
model = NABirdsEfficientNet(num_classes=555, rngs=rngs)

# 2. Forward Pass
x = jax.random.normal(jax.random.key(0), (1, 224, 224, 3))
logits = model(x)
print(f"Logits shape: {logits.shape}")  # (1, 555)
print("Model initialized and pre-trained weights loaded via NNX bridge.")
```

```
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[3], line 5
      3 from flax import nnx
      4 from flax.nnx import bridge
----> 5 from transformers import FlaxEfficientNetModel
      7 class NABirdsEfficientNet(nnx.Module):
      8     def __init__(self, num_classes=555, *, rngs: nnx.Rngs):
      9         # 1. Load the pre-trained Linen model structure & weights from Hugging Face

ModuleNotFoundError: No module named 'transformers'
```