Loading pre-trained weights

Authors

Marie-Hélène Burle

Code adapted from JAX’s Implement ViT from scratch

In this section, we transfer weights from a pre-trained model into our ViT model.

Context

Workflow: Load data → Process data → Define architecture → Hyperparameters → Train → Checkpoint. A pre-trained model also feeds into the "Define architecture" step (the topic of this section).

Packages used at each step: torchdata, tfds, and datasets → Load data; grain and torchvision → Process data; transformers → Pre-trained model; flax → Define architecture; optax → Hyperparameters; flax (built on JAX) → Train; orbax → Checkpoint.

import jax
import jax.numpy as jnp
from flax import nnx

class VisionTransformer(nnx.Module):
    """Vision Transformer (ViT) image classifier.

    The image is cut into non-overlapping square patches by a strided
    convolution. A learnable class token is prepended to the patch sequence,
    learnable position embeddings are added, and the result goes through a
    stack of ``TransformerEncoder`` blocks followed by a layer norm. The
    output at the class-token position is fed to a linear classification head.

    Args:
        num_classes: Number of output classes (size of the returned logits).
        in_channels: Number of channels of the input images.
        img_size: Side length of the (square) input images, in pixels.
        patch_size: Side length of each (square) patch, in pixels.
        num_layers: Number of Transformer encoder blocks.
        num_heads: Number of attention heads in each block.
        mlp_dim: Hidden dimension of the MLP in each block.
        hidden_size: Embedding dimension of each token.
        dropout_rate: Dropout rate applied to the embeddings and inside the
            encoder blocks.
        rngs: Random number generators used to initialize the parameters and
            to drive dropout. When omitted, a fresh ``nnx.Rngs(0)`` is created
            for this instance.
    """

    def __init__(
        self,
        num_classes: int = 1000,
        in_channels: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_dim: int = 3072,
        hidden_size: int = 768,
        dropout_rate: float = 0.1,
        *,
        rngs: "nnx.Rngs | None" = None,
    ):
        # `nnx.Rngs` is stateful. Using an instance as the default argument
        # would create it once, at definition time, and silently share it
        # between every model built without an explicit `rngs`.
        if rngs is None:
            rngs = nnx.Rngs(0)

        # Patch and position embedding
        n_patches = (img_size // patch_size) ** 2
        # Kernel size == stride: each patch is projected independently.
        self.patch_embeddings = nnx.Conv(
            in_channels,
            hidden_size,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding="VALID",
            use_bias=True,
            rngs=rngs,
        )

        initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
        # One position embedding per patch, plus one for the class token.
        self.position_embeddings = nnx.Param(
            initializer(
                rngs.params(),
                (1, n_patches + 1, hidden_size),
                jnp.float32
            )
        )
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)

        self.cls_token = nnx.Param(jnp.zeros((1, 1, hidden_size)))

        # Transformer Encoder blocks
        self.encoder = nnx.Sequential(*[
            TransformerEncoder(
                hidden_size,
                mlp_dim,
                num_heads,
                dropout_rate,
                rngs=rngs
            )
            for _ in range(num_layers)
        ])
        self.final_norm = nnx.LayerNorm(hidden_size, rngs=rngs)

        # Classification head
        self.classifier = nnx.Linear(hidden_size, num_classes, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Return the class logits for a batch of images.

        Args:
            x: Batch of images. The callers in this file pass channels-last
                arrays, i.e. ``(batch, height, width, channels)``.

        Returns:
            Array of logits with one row per image and ``num_classes``
            columns.
        """
        # Patch and position embedding
        patches = self.patch_embeddings(x)
        batch_size = patches.shape[0]
        # Flatten the spatial grid of patches into a single sequence axis.
        patches = patches.reshape(batch_size, -1, patches.shape[-1])

        # Prepend one copy of the class token to every sequence of the batch.
        cls_token = jnp.tile(self.cls_token, [batch_size, 1, 1])
        x = jnp.concat([cls_token, patches], axis=1)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)

        # Encoder blocks
        x = self.encoder(embeddings)
        x = self.final_norm(x)

        # fetch the first token (the class token)
        x = x[:, 0]

        # Classification
        return self.classifier(x)

class TransformerEncoder(nnx.Module):
    """Single pre-norm Transformer encoder block.

    Applies layer norm followed by multi-head self-attention, then layer norm
    followed by a two-layer MLP with GELU activation. Each of the two
    sub-layers is wrapped in a residual connection.

    Args:
        hidden_size: Embedding dimension of each token (input and output).
        mlp_dim: Hidden dimension of the MLP.
        num_heads: Number of attention heads.
        dropout_rate: Dropout rate used in the attention and in the MLP.
        rngs: Random number generators used to initialize the parameters and
            to drive dropout. When omitted, a fresh ``nnx.Rngs(0)`` is created
            for this instance.
    """

    def __init__(
        self,
        hidden_size: int,
        mlp_dim: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        *,
        rngs: "nnx.Rngs | None" = None,
    ) -> None:
        # `nnx.Rngs` is stateful. Using an instance as the default argument
        # would create it once, at definition time, and silently share it
        # between every block built without an explicit `rngs`.
        if rngs is None:
            rngs = nnx.Rngs(0)

        self.norm1 = nnx.LayerNorm(hidden_size, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=hidden_size,
            dropout_rate=dropout_rate,
            broadcast_dropout=False,
            decode=False,
            deterministic=False,
            rngs=rngs,
        )
        self.norm2 = nnx.LayerNorm(hidden_size, rngs=rngs)

        # The positions of the two Linear layers in this Sequential (0 and 3)
        # are part of the parameter paths of the model.
        self.mlp = nnx.Sequential(
            nnx.Linear(hidden_size, mlp_dim, rngs=rngs),
            nnx.gelu,
            nnx.Dropout(dropout_rate, rngs=rngs),
            nnx.Linear(mlp_dim, hidden_size, rngs=rngs),
            nnx.Dropout(dropout_rate, rngs=rngs),
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the block to ``x`` and return an array of the same shape."""
        # Self-attention sub-layer with residual connection.
        x = x + self.attn(self.norm1(x))
        # MLP sub-layer with residual connection.
        x = x + self.mlp(self.norm2(x))
        return x

# Instantiate our ViT with a 1000-class head; every other argument keeps its
# default value.
model = VisionTransformer(num_classes=1000)

Load packages

Packages and modules necessary for this section:

# Hugging Face ViT Model transformer with image classification head
from transformers import FlaxViTForImageClassification

# Packages to test our model after weight transfer
import matplotlib.pyplot as plt
from transformers import ViTImageProcessor
from PIL import Image
import requests

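FlaxViTForImageClassification is the Flax implementation, from Hugging Face's transformers library, of a ViT model with an image classification head. Its from_pretrained method instantiates the model from a pre-trained configuration and loads the corresponding weights.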

Load pre-trained weights

We want to load into our model the weights from Google’s ViT model introduced by Dosovitskiy et al. [1], pre-trained on ImageNet-21k at resolution 224x224 and fine-tuned on ImageNet 2012 at resolution 224x224.

For this, we use the from_pretrained method of FlaxViTForImageClassification and get the weights from Google’s model stored as google/vit-base-patch16-224 on the Hugging Face model Hub.

# Load the pre-trained ViT (architecture and weights) from the Hugging Face Hub.
tf_model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

Copy weights to our model

tf_model is a ViT model from Hugging Face's transformers library, loaded with the pre-trained weights. We want to copy those weights to our own Flax ViT model, called model:

def _vit_param_name_mapping(num_layers):
    """Map parameter paths of our ViT to parameter paths of the HF ViT.

    Keys are paths in the flattened state of ``VisionTransformer``; values
    are paths in the flattened ``params`` of ``FlaxViTForImageClassification``.
    """
    mapping = {
        ("cls_token",): ("vit", "embeddings", "cls_token"),
        ("position_embeddings",): (
            "vit",
            "embeddings",
            "position_embeddings",
        ),
    }
    for name in ("kernel", "bias"):
        mapping[("patch_embeddings", name)] = (
            "vit", "embeddings", "patch_embeddings", "projection", name
        )

    for i in range(num_layers):
        # Our model indexes layers with integers, the source with strings.
        dst_layer = ("encoder", "layers", i)
        src_layer = ("vit", "encoder", "layer", str(i))

        for name in ("kernel", "bias"):
            for proj in ("key", "value", "query"):
                mapping[dst_layer + ("attn", proj, name)] = (
                    src_layer + ("attention", "attention", proj, name)
                )
            mapping[dst_layer + ("attn", "out", name)] = (
                src_layer + ("attention", "output", "dense", name)
            )
            # 0 and 3 are the positions of the two Linear layers inside the
            # MLP `nnx.Sequential` of each encoder block.
            for position, sublayer in ((0, "intermediate"), (3, "output")):
                mapping[dst_layer + ("mlp", "layers", position, name)] = (
                    src_layer + (sublayer, "dense", name)
                )

        for name in ("scale", "bias"):
            for dst_norm, src_norm in (
                ("norm1", "layernorm_before"),
                ("norm2", "layernorm_after"),
            ):
                mapping[dst_layer + (dst_norm, name)] = (
                    src_layer + (src_norm, name)
                )

    for name in ("scale", "bias"):
        mapping[("final_norm", name)] = ("vit", "layernorm", name)
    for name in ("kernel", "bias"):
        mapping[("classifier", name)] = ("classifier", name)

    return mapping


def vit_inplace_copy_weights(*, src_model, dst_model):
    """Copy the weights of a Hugging Face Flax ViT into our ViT, in place.

    The number of layers, the number of attention heads and the head
    dimension are read from ``src_model.config`` rather than hard-coded, so
    the function is not limited to the "base" configuration.

    Args:
        src_model: ``FlaxViTForImageClassification`` holding the pre-trained
            weights.
        dst_model: ``VisionTransformer`` whose parameters are overwritten.

    Raises:
        AssertionError: If a model has the wrong type, if a mapped parameter
            is missing from either model, if shapes disagree after
            reshaping, or if some parameter of ``dst_model`` was not covered
            by the mapping.
    """
    assert isinstance(src_model, FlaxViTForImageClassification)
    assert isinstance(dst_model, VisionTransformer)

    config = src_model.config
    num_layers = config.num_hidden_layers
    num_heads = config.num_attention_heads
    head_dim = config.hidden_size // num_heads

    tf_model_params = src_model.params
    tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)

    flax_model_params = nnx.state(dst_model, nnx.Param)
    # Depending on the Flax release, `flat_state()` returns either a dict or
    # a sequence of (path, value) pairs; `dict(...)` accepts both, so the
    # membership tests and key lookups below work in either case.
    flax_model_params_fstate = dict(flax_model_params.flat_state())

    params_name_mapping = _vit_param_name_mapping(num_layers)
    qkv_names = ("key", "value", "query")

    # Used to check that every parameter of `dst_model` gets a value.
    nonvisited = set(flax_model_params_fstate.keys())

    for key1, key2 in params_name_mapping.items():
        assert key1 in flax_model_params_fstate, key1
        assert key2 in tf_model_params_fstate, (key1, key2)

        nonvisited.remove(key1)

        src_value = tf_model_params_fstate[key2]

        # The source stores the attention projections as flat dense layers,
        # whereas the destination shapes (enforced by the assertion below)
        # have separate head and head-dimension axes.
        if key2[-1] == "kernel" and key2[-2] in qkv_names:
            src_value = src_value.reshape(
                (src_value.shape[0], num_heads, head_dim)
            )

        if key2[-1] == "bias" and key2[-2] in qkv_names:
            src_value = src_value.reshape((num_heads, head_dim))

        if key2[-4:] == ("attention", "output", "dense", "kernel"):
            src_value = src_value.reshape(
                (num_heads, head_dim, src_value.shape[-1])
            )

        dst_value = flax_model_params_fstate[key1]
        assert src_value.shape == dst_value.value.shape, (
            key2, src_value.shape, key1, dst_value.value.shape
        )
        dst_value.value = src_value.copy()
        # Cheap sanity check that the copy landed in the destination.
        assert dst_value.value.mean() == src_value.mean(), (
            dst_value.value.mean(), src_value.mean()
        )

    assert len(nonvisited) == 0, nonvisited

    # Write the updated flat state back into the model.
    nnx.update(dst_model, nnx.State.from_flat_path(flax_model_params_fstate))

# Overwrite the parameters of `model` with the pre-trained weights of `tf_model`.
vit_inplace_copy_weights(src_model=tf_model, dst_model=model)

Test our model

Our model should now be able to classify objects if they belong to the 1000 classes of ImageNet-1K.

Let’s test it by passing the URL of the image of a Song Sparrow (Melospiza melodia):

# Download the test image. The timeout prevents the script from hanging on an
# unresponsive server, and the status check makes HTTP errors fail here with
# a clear message rather than as a cryptic decoding error inside PIL.
url = "https://www.allaboutbirds.org/guide/assets/photo/308771371-480px.jpg"
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

# Reference prediction from the Hugging Face model.
inputs = processor(images=image, return_tensors="np")
outputs = tf_model(**inputs)
logits = outputs.logits

# Prediction from our model, in evaluation mode. The axes are permuted
# because the processor output is channels-first, while our model is fed
# channels-last images.
model.eval()
x = jnp.transpose(inputs["pixel_values"], axes=(0, 2, 3, 1))
output = model(x)

# Model predicts one of the 1000 ImageNet classes.
ref_class_idx = logits.argmax(-1).item()
pred_class_idx = output.argmax(-1).item()
# Both models must produce (nearly) the same logits after the weight transfer.
assert jnp.abs(logits[0, :] - output[0, :]).max() < 0.1

# Show the image twice, titled with each model's top class and probability.
fig, axs = plt.subplots(1, 2, figsize=(12, 8))
axs[0].set_title(
    f"Reference model:\n{tf_model.config.id2label[ref_class_idx]}\nP={nnx.softmax(logits, axis=-1)[0, ref_class_idx]:.3f}"
)
axs[0].imshow(image)
axs[1].set_title(
    f"Our model:\n{tf_model.config.id2label[pred_class_idx]}\nP={nnx.softmax(output, axis=-1)[0, pred_class_idx]:.3f}"
)
axs[1].imshow(image)

The Song Sparrow is apparently not among the 1000 classes. The good news, however, is that our model with the transferred weights gives exactly the same result as the google/vit-base-patch16-224 model, with the same probability. Brambling—another songbird—is probably the class closest to a Song Sparrow, so all looks good.

Reduce number of classes

Our model now returns 1000 categories, but we want to fine-tune it on the Food-101 dataset [2], which we have reduced to only 5 classes. So we need to replace the model classifier with one returning 5 classes:

# Swap the 1000-class head for a freshly initialized 5-class linear layer with
# the same number of input features; the rest of the model is left untouched.
model.classifier = nnx.Linear(model.classifier.in_features, 5, rngs=nnx.Rngs(0))

# Sanity check on a dummy batch of 4 images: the output should have 5 columns.
x = jnp.ones((4, 224, 224, 3))
y = model(x)
print("Predictions shape: ", y.shape)
Predictions shape:  (4, 5)

References

1.
Dosovitskiy A, Beyer L, Kolesnikov A, et al (2021) An image is worth 16x16 words: Transformers for image recognition at scale
2.
Bossard L, Guillaumin M, Van Gool L (2014) Food-101 – mining discriminative components with random forests. In: European conference on computer vision