Loading pre-trained weights
In this section, we transfer weights from a pre-trained model into our ViT model.
Context
Load packages
Packages and modules necessary for this section:
# Hugging Face ViT Model transformer with image classification head
from transformers import FlaxViTForImageClassification
# Packages to test our model after weight transfer
import matplotlib.pyplot as plt
from transformers import ViTImageProcessor
from PIL import Image
import requests
FlaxViTForImageClassification instantiates a pretrained Flax model with an image classification head from a pre-trained ViT model configuration.
Load pre-trained weights
We want to load into our model the weights of Google's ViT model introduced by Dosovitskiy et al. [1], pre-trained on ImageNet-21k and fine-tuned on ImageNet 2012, both at resolution 224x224.
For this, we use the from_pretrained method of FlaxViTForImageClassification and fetch the weights of Google's model, stored as google/vit-base-patch16-224 on the Hugging Face Model Hub.
tf_model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
Copy weights to our model
tf_model is a transformer ViT model holding the pre-trained weights. We want to copy those weights into our ViT Flax model, called model, which was built earlier in this tutorial. The two models do not name their parameters the same way, so we need an explicit mapping between the two naming schemes.
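As a minimal sketch of how to compare the two layouts (assuming the same Flax NNX APIs used in the copy function below; the src_fstate and dst_fstate names are ours), we can flatten both parameter trees and print a few key paths:

src_fstate = nnx.traversals.flatten_mapping(tf_model.params)
dst_fstate = nnx.state(model, nnx.Param).flat_state()
# Compare the number of parameter leaves and a few key paths on each side.
print(len(src_fstate), len(dst_fstate))
print(list(src_fstate.keys())[:2])
print(list(dst_fstate.keys())[:2])

With the two key layouts in hand, the following function builds the mapping and copies each tensor in place: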
def vit_inplace_copy_weights(*, src_model, dst_model):
    assert isinstance(src_model, FlaxViTForImageClassification)
    assert isinstance(dst_model, VisionTransformer)

    tf_model_params = src_model.params
    tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)

    flax_model_params = nnx.state(dst_model, nnx.Param)
    flax_model_params_fstate = flax_model_params.flat_state()
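    # Mapping from our model's flattened parameter paths (keys) to the
    # corresponding paths in the Hugging Face model (values).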
    params_name_mapping = {
        ("cls_token",): ("vit", "embeddings", "cls_token"),
        ("position_embeddings",): ("vit", "embeddings", "position_embeddings"),
        **{
            ("patch_embeddings", x): (
                "vit", "embeddings", "patch_embeddings", "projection", x
            )
            for x in ["kernel", "bias"]
        },
        **{
            ("encoder", "layers", i, "attn", y, x): (
                "vit", "encoder", "layer", str(i), "attention", "attention", y, x
            )
            for x in ["kernel", "bias"]
            for y in ["key", "value", "query"]
            for i in range(12)
        },
        **{
            ("encoder", "layers", i, "attn", "out", x): (
                "vit", "encoder", "layer", str(i), "attention", "output", "dense", x
            )
            for x in ["kernel", "bias"]
            for i in range(12)
        },
        **{
            ("encoder", "layers", i, "mlp", "layers", y1, x): (
                "vit", "encoder", "layer", str(i), y2, "dense", x
            )
            for x in ["kernel", "bias"]
            for y1, y2 in [(0, "intermediate"), (3, "output")]
            for i in range(12)
        },
        **{
            ("encoder", "layers", i, y1, x): (
                "vit", "encoder", "layer", str(i), y2, x
            )
            for x in ["scale", "bias"]
            for y1, y2 in [
                ("norm1", "layernorm_before"),
                ("norm2", "layernorm_after")
            ]
            for i in range(12)
        },
        **{
            ("final_norm", x): ("vit", "layernorm", x)
            for x in ["scale", "bias"]
        },
        **{
            ("classifier", x): ("classifier", x)
            for x in ["kernel", "bias"]
        },
    }
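    # Track destination keys so we can verify at the end that every parameter
    # of our model received a value exactly once.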
    nonvisited = set(flax_model_params_fstate.keys())

    for key1, key2 in params_name_mapping.items():
        assert key1 in flax_model_params_fstate, key1
        assert key2 in tf_model_params_fstate, (key1, key2)

        nonvisited.remove(key1)

        src_value = tf_model_params_fstate[key2]
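        # The HF model stores each attention projection as a single
        # (hidden_size, hidden_size) = (768, 768) matrix, while our model
        # splits it per head into (hidden_size, num_heads, head_dim) = (768, 12, 64).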
        if key2[-1] == "kernel" and key2[-2] in ("key", "value", "query"):
            shape = src_value.shape
            src_value = src_value.reshape((shape[0], 12, 64))

        if key2[-1] == "bias" and key2[-2] in ("key", "value", "query"):
            src_value = src_value.reshape((12, 64))
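        # The attention output projection is reshaped the other way around:
        # (768, 768) -> (num_heads, head_dim, hidden_size) = (12, 64, 768).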
        if key2[-4:] == ("attention", "output", "dense", "kernel"):
            shape = src_value.shape
            src_value = src_value.reshape((12, 64, shape[-1]))

        dst_value = flax_model_params_fstate[key1]
        assert src_value.shape == dst_value.value.shape, (
            key2, src_value.shape, key1, dst_value.value.shape
        )
        dst_value.value = src_value.copy()
        assert dst_value.value.mean() == src_value.mean(), (
            dst_value.value, src_value.mean()
        )

    assert len(nonvisited) == 0, nonvisited

    nnx.update(dst_model, nnx.State.from_flat_path(flax_model_params_fstate))
vit_inplace_copy_weights(src_model=tf_model, dst_model=model)
Test our model
Our model should now be able to classify images into the 1000 classes of ImageNet-1K.
Let's test it by passing the URL of an image of a Song Sparrow (Melospiza melodia):
url = "https://www.allaboutbirds.org/guide/assets/photo/308771371-480px.jpg"
image = Image.open(requests.get(url, stream=True).raw)
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
inputs = processor(images=image, return_tensors="np")
outputs = tf_model(**inputs)
logits = outputs.logits
model.eval()
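# The HF processor returns NCHW pixel values; our Flax model expects NHWC.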
x = jnp.transpose(inputs["pixel_values"], axes=(0, 2, 3, 1))
output = model(x)
# Model predicts one of the 1000 ImageNet classes.
ref_class_idx = logits.argmax(-1).item()
pred_class_idx = output.argmax(-1).item()
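# The transferred weights should reproduce the reference logits
# up to small numerical differences.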
assert jnp.abs(logits[0, :] - output[0, :]).max() < 0.1
fig, axs = plt.subplots(1, 2, figsize=(12, 8))
axs[0].set_title(
f"Reference model:\n{tf_model.config.id2label[ref_class_idx]}\nP={nnx.softmax(logits, axis=-1)[0, ref_class_idx]:.3f}"
)
axs[0].imshow(image)
axs[1].set_title(
f"Our model:\n{tf_model.config.id2label[pred_class_idx]}\nP={nnx.softmax(output, axis=-1)[0, pred_class_idx]:.3f}"
)
axs[1].imshow(image)
The Song Sparrow is apparently not among the 1000 ImageNet classes. But the good news is that our model with the transferred weights gives exactly the same result as the google/vit-base-patch16-224 model, with the same probability. Brambling, another songbird, is probably the class closest to a Song Sparrow. So all looks good.
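To see which other classes the model considers likely, here is a small sketch (our addition, reusing the output and tf_model objects from above) that prints the top-5 predictions:

probs = nnx.softmax(output, axis=-1)[0]
top5 = jnp.argsort(probs)[-5:][::-1]  # indices of the 5 largest probabilities, descending
for idx in top5:
    print(f"{tf_model.config.id2label[int(idx)]}: {float(probs[int(idx)]):.3f}")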
Reduce number of classes
Our model now returns scores for 1000 categories, but we want to fine-tune it on the Food-101 dataset [2], which we have reduced to only 5 classes. So we need to replace the model's classifier with one returning 5 classes:
model.classifier = nnx.Linear(model.classifier.in_features, 5, rngs=nnx.Rngs(0))
x = jnp.ones((4, 224, 224, 3))
y = model(x)
print("Predictions shape: ", y.shape)
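As a quick sanity check (a sketch, relying on the nnx.Linear attributes already used above), we can confirm that the new classification head has the expected shape:

print(model.classifier.kernel.value.shape)  # (768, 5): freshly initialized 5-class head
print(model.classifier.bias.value.shape)    # (5,)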