Loading pre-trained weights
In this section, we transfer weights from a pre-trained model into our ViT model.
Context
Load packages
Packages and modules necessary for this section:
# Hugging Face ViT Model transformer with image classification head
from transformers import FlaxViTForImageClassification
# Packages to test our model after weight transfer
import matplotlib.pyplot as plt
from transformers import ViTImageProcessor
from PIL import Image
import requests
`FlaxViTForImageClassification` is the Hugging Face Flax ViT model with an image-classification head; its `from_pretrained` method instantiates it from a pre-trained ViT checkpoint and configuration.
Load pre-trained weights
We want to load into our model the weights of Google's ViT model introduced by Dosovitskiy et al. [1], which was pre-trained on ImageNet-21k and fine-tuned on ImageNet 2012, both at a resolution of 224x224.
For this, we use the `from_pretrained` method of `FlaxViTForImageClassification` to fetch the weights of Google's model, stored as `google/vit-base-patch16-224` on the Hugging Face Model Hub.
# Hugging Face Flax ViT classifier holding the pre-trained weights of the
# google/vit-base-patch16-224 checkpoint; it is the source of the weight transfer.
tf_model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
Copy weights to our model
`tf_model` is a Hugging Face Transformers ViT model holding the pre-trained weights. We want to copy those weights into our Flax ViT model, called `model`:
def _vit_params_name_mapping(num_layers):
    """Return a dict mapping our flattened NNX param paths to Hugging Face ViT param paths."""
    return {
        ("cls_token",): ("vit", "embeddings", "cls_token"),
        ("position_embeddings",): ("vit", "embeddings", "position_embeddings"),
        **{
            ("patch_embeddings", x): ("vit", "embeddings", "patch_embeddings", "projection", x)
            for x in ["kernel", "bias"]
        },
        # Query/key/value projections of every encoder block's self-attention.
        **{
            ("encoder", "layers", i, "attn", y, x): (
                "vit", "encoder", "layer", str(i), "attention", "attention", y, x
            )
            for x in ["kernel", "bias"]
            for y in ["key", "value", "query"]
            for i in range(num_layers)
        },
        # Output projection of every encoder block's self-attention.
        **{
            ("encoder", "layers", i, "attn", "out", x): (
                "vit", "encoder", "layer", str(i), "attention", "output", "dense", x
            )
            for x in ["kernel", "bias"]
            for i in range(num_layers)
        },
        # Our MLP sub-layers 0 and 3 correspond to HF's "intermediate" and "output" dense layers.
        **{
            ("encoder", "layers", i, "mlp", "layers", y1, x): (
                "vit", "encoder", "layer", str(i), y2, "dense", x
            )
            for x in ["kernel", "bias"]
            for y1, y2 in [(0, "intermediate"), (3, "output")]
            for i in range(num_layers)
        },
        # Pre-attention and pre-MLP layer norms of every encoder block.
        **{
            ("encoder", "layers", i, y1, x): ("vit", "encoder", "layer", str(i), y2, x)
            for x in ["scale", "bias"]
            for y1, y2 in [("norm1", "layernorm_before"), ("norm2", "layernorm_after")]
            for i in range(num_layers)
        },
        **{("final_norm", x): ("vit", "layernorm", x) for x in ["scale", "bias"]},
        **{("classifier", x): ("classifier", x) for x in ["kernel", "bias"]},
    }


def _to_dst_layout(key, value, num_heads, head_dim):
    """Reshape an HF attention parameter so its head axis is split out; return others unchanged."""
    if key[-2] in ("key", "value", "query"):
        if key[-1] == "kernel":
            # Split the fused last axis: (in, num_heads * head_dim) -> (in, num_heads, head_dim).
            return value.reshape((value.shape[0], num_heads, head_dim))
        if key[-1] == "bias":
            return value.reshape((num_heads, head_dim))
    if key[-4:] == ("attention", "output", "dense", "kernel"):
        # Split the fused first axis: (num_heads * head_dim, out) -> (num_heads, head_dim, out).
        return value.reshape((num_heads, head_dim, value.shape[-1]))
    return value


def vit_inplace_copy_weights(*, src_model, dst_model, num_layers=12, num_heads=12, head_dim=64):
    """Copy the weights of a Hugging Face Flax ViT classifier into our ViT model, in place.

    Every ``nnx.Param`` of ``dst_model`` must be covered by the name mapping.
    Each mapped source array is reshaped where needed (attention kernels and
    biases get their head axis split out), shape-checked, and copied. The
    result is written back with ``nnx.update``.

    Args:
        src_model: ``FlaxViTForImageClassification`` providing the weights.
        dst_model: our ``VisionTransformer``, updated in place.
        num_layers: number of encoder blocks.
        num_heads: number of attention heads.
        head_dim: size of each attention head.
        The defaults (12, 12, 64) are the values previously hard-coded for
        google/vit-base-patch16-224.

    Raises:
        TypeError: if either model has an unexpected type.
        KeyError: if a mapped parameter is missing from either model.
        ValueError: if shapes disagree, the copy does not round-trip, or some
            destination parameters are not covered by the mapping.
    """
    # Explicit checks instead of ``assert``: asserts vanish under ``python -O``,
    # which would let mismatched weights be copied silently.
    if not isinstance(src_model, FlaxViTForImageClassification):
        raise TypeError(
            f"src_model must be FlaxViTForImageClassification, got {type(src_model).__name__}"
        )
    if not isinstance(dst_model, VisionTransformer):
        raise TypeError(f"dst_model must be VisionTransformer, got {type(dst_model).__name__}")

    tf_model_params_fstate = nnx.traversals.flatten_mapping(src_model.params)
    flax_model_params_fstate = nnx.state(dst_model, nnx.Param).flat_state()

    params_name_mapping = _vit_params_name_mapping(num_layers)
    nonvisited = set(flax_model_params_fstate.keys())

    for key1, key2 in params_name_mapping.items():
        if key1 not in flax_model_params_fstate:
            raise KeyError(f"destination model has no parameter {key1}")
        if key2 not in tf_model_params_fstate:
            raise KeyError(f"source model has no parameter {key2} (mapped from {key1})")

        nonvisited.remove(key1)

        src_value = _to_dst_layout(key2, tf_model_params_fstate[key2], num_heads, head_dim)
        dst_value = flax_model_params_fstate[key1]
        if src_value.shape != dst_value.value.shape:
            raise ValueError(
                f"shape mismatch: {key2} {src_value.shape} vs {key1} {dst_value.value.shape}"
            )
        dst_value.value = src_value.copy()
        # Sanity check that the stored value is what was copied (e.g. no dtype cast).
        if dst_value.value.mean() != src_value.mean():
            raise ValueError(f"copy of {key2} into {key1} did not round-trip")

    if nonvisited:
        raise ValueError(f"destination parameters not covered by the mapping: {nonvisited}")

    nnx.update(dst_model, nnx.State.from_flat_path(flax_model_params_fstate))
# Transfer the pre-trained weights from the Hugging Face model into our model.
vit_inplace_copy_weights(src_model=tf_model, dst_model=model)
Test our model
Our model should now be able to classify images of objects belonging to any of the 1000 ImageNet-1K classes.
Let's test it on an image of a Song Sparrow (Melospiza melodia), downloaded from the URL below:
# Download the test image. The timeout and status check make the cell fail
# fast on a stalled connection or HTTP error instead of hanging, or instead of
# surfacing a confusing image-decoding error later.
url = "https://www.allaboutbirds.org/guide/assets/photo/308771371-480px.jpg"
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    image = Image.open(response.raw)
    image.load()  # decode now, while the connection is still open

# Pre-process with the image processor published alongside the same checkpoint.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
inputs = processor(images=image, return_tensors="np")

# Reference prediction from the Hugging Face model.
outputs = tf_model(**inputs)
logits = outputs.logits

# Our model: evaluation mode, and the channel axis (axis 1 of the processor
# output) moved last, since our model takes channels-last input.
model.eval()
x = jnp.transpose(inputs["pixel_values"], axes=(0, 2, 3, 1))
output = model(x)

# Model predicts one of the 1000 ImageNet classes.
ref_class_idx = logits.argmax(-1).item()
pred_class_idx = output.argmax(-1).item()
# After the weight transfer both models' logits should agree closely.
assert jnp.abs(logits[0, :] - output[0, :]).max() < 0.1

fig, axs = plt.subplots(1, 2, figsize=(12, 8))
axs[0].set_title(
    f"Reference model:\n{tf_model.config.id2label[ref_class_idx]}\nP={nnx.softmax(logits, axis=-1)[0, ref_class_idx]:.3f}"
)
axs[0].imshow(image)
axs[1].set_title(
    f"Our model:\n{tf_model.config.id2label[pred_class_idx]}\nP={nnx.softmax(output, axis=-1)[0, pred_class_idx]:.3f}"
)
axs[1].imshow(image)
The Song Sparrow is apparently not among the 1000 classes. The good news, however, is that our model with the transferred weights gives the same result as the `google/vit-base-patch16-224` model, with the same probability.
Brambling (another songbird) is probably the class closest to a Song Sparrow, so everything looks good.
Reduce number of classes
Our model currently outputs scores for 1000 categories, but we want to fine-tune it on the Food-101 dataset [2], which we have reduced to only 5 classes. We therefore need to replace the model's classifier with one that outputs 5 classes:
# Swap the 1000-way ImageNet head for a freshly initialised 5-way head that
# keeps the input width of the original classifier.
model.classifier = nnx.Linear(model.classifier.in_features, 5, rngs=nnx.Rngs(0))

# Smoke test: a dummy batch of 4 channels-last 224x224 3-channel images
# should now yield one 5-class logit vector per image.
x = jnp.ones((4, 224, 224, 3))
y = model(x)
print("Predictions shape: ", y.shape)
Predictions shape: (4, 5)