from datasets import load_dataset
import numpy as np
from torchvision.transforms import v2 as T
import grain.python as grain
import jax
import jax.numpy as jnp
from flax import nnx
from transformers import FlaxViTForImageClassification
# Take the first 5 Food-101 classes: the splits are ordered by class,
# with 750 train and 250 validation images per class.
train_size = 5 * 750
val_size = 5 * 250

train_dataset = load_dataset("food101", split=f"train[:{train_size}]")
val_dataset = load_dataset("food101", split=f"validation[:{val_size}]")
# Remap the original Food-101 label ids of the selected classes to 0..4.
# The validation split holds 250 contiguous images per class, so striding
# by 250 visits each class exactly once.
labels_mapping = {}
index = 0
for i in range(0, len(val_dataset), 250):
    label = val_dataset[i]["label"]
    if label not in labels_mapping:
        labels_mapping[label] = index
        index += 1

inv_labels_mapping = {v: k for k, v in labels_mapping.items()}
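# Optional check (not part of the original pipeline): inspect which classes
# were selected. ClassLabel.int2str maps Food-101 label ids to class names.
int2str = val_dataset.features["label"].int2str
print({int2str(k): v for k, v in labels_mapping.items()})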
img_size = 224

def to_np_array(pil_image):
    return np.asarray(pil_image.convert("RGB"))

def normalize(image):
    # Match the preprocessing of the pretrained ViT checkpoint:
    # scale to [0, 1], then normalize with mean=std=0.5 per channel.
    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    image = image.astype(np.float32) / 255.0
    return (image - mean) / std
tv_train_transforms = T.Compose([
    T.RandomResizedCrop((img_size, img_size), scale=(0.7, 1.0)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.2, 0.2, 0.2),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])

tv_test_transforms = T.Compose([
    T.Resize((img_size, img_size)),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])
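# Optional check: apply the eval transform to one raw PIL image (the datasets
# are still untransformed at this point) to confirm the output shape/dtype.
out = tv_test_transforms(val_dataset[0]["image"])
print(out.shape, out.dtype)  # (224, 224, 3) float32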
def get_transform(fn):
    # Wrap a torchvision transform so it applies to a batch of Hugging Face
    # dataset examples, remapping the labels to 0..4 along the way.
    def wrapper(batch):
        batch["image"] = [fn(pil_image) for pil_image in batch["image"]]
        batch["label"] = [labels_mapping[label] for label in batch["label"]]
        return batch
    return wrapper

train_transforms = get_transform(tv_train_transforms)
val_transforms = get_transform(tv_test_transforms)

train_dataset = train_dataset.with_transform(train_transforms)
val_dataset = val_dataset.with_transform(val_transforms)
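# Optional sanity check: with_transform is applied lazily on access, so
# slicing two examples should yield normalized arrays and remapped labels.
checked = train_dataset[:2]
print(checked["image"][0].shape, checked["label"])  # (224, 224, 3), labels in 0..4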
seed = 12
train_batch_size = 32
val_batch_size = 2 * train_batch_size
train_sampler = grain.IndexSampler(
    len(train_dataset),
    shuffle=True,
    seed=seed,
    shard_options=grain.NoSharding(),  # single-process setup, no sharding
    num_epochs=1,                      # each pass over the loader is one epoch
)

val_sampler = grain.IndexSampler(
    len(val_dataset),
    shuffle=False,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)
train_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,
    worker_count=4,        # child processes fetching data in parallel
    worker_buffer_size=2,  # prefetched batches per worker
    operations=[
        grain.Batch(train_batch_size, drop_remainder=True),
    ],
)

val_loader = grain.DataLoader(
    data_source=val_dataset,
    sampler=val_sampler,
    worker_count=4,
    worker_buffer_size=2,
    operations=[
        grain.Batch(val_batch_size),
    ],
)
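# Optional sanity check (not part of the original pipeline): grain stacks the
# per-example dicts, so a batch should be NHWC images plus a label vector.
first_batch = next(iter(train_loader))
print(first_batch["image"].shape, first_batch["label"].shape)
# Expected: (32, 224, 224, 3) and (32,)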
class VisionTransformer(nnx.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        in_channels: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_dim: int = 3072,
        hidden_size: int = 768,
        dropout_rate: float = 0.1,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        # Patch and position embedding
        n_patches = (img_size // patch_size) ** 2
        self.patch_embeddings = nnx.Conv(
            in_channels,
            hidden_size,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding="VALID",
            use_bias=True,
            rngs=rngs,
        )

        initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
        self.position_embeddings = nnx.Param(
            initializer(rngs.params(), (1, n_patches + 1, hidden_size), jnp.float32)
        )
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)

        # Learnable class token, prepended to the patch sequence
        self.cls_token = nnx.Param(jnp.zeros((1, 1, hidden_size)))

        # Transformer Encoder blocks
        self.encoder = nnx.Sequential(*[
            TransformerEncoder(hidden_size, mlp_dim, num_heads, dropout_rate, rngs=rngs)
            for _ in range(num_layers)
        ])
        self.final_norm = nnx.LayerNorm(hidden_size, rngs=rngs)

        # Classification head
        self.classifier = nnx.Linear(hidden_size, num_classes, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        # Patch and position embedding (inputs are NHWC)
        patches = self.patch_embeddings(x)
        batch_size = patches.shape[0]
        patches = patches.reshape(batch_size, -1, patches.shape[-1])

        cls_token = jnp.tile(self.cls_token, [batch_size, 1, 1])
        x = jnp.concat([cls_token, patches], axis=1)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)

        # Encoder blocks
        x = self.encoder(embeddings)
        x = self.final_norm(x)

        # Fetch the first token (the class token)
        x = x[:, 0]

        # Classification
        return self.classifier(x)
class TransformerEncoder(nnx.Module):
    def __init__(
        self,
        hidden_size: int,
        mlp_dim: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ) -> None:
        self.norm1 = nnx.LayerNorm(hidden_size, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=hidden_size,
            dropout_rate=dropout_rate,
            broadcast_dropout=False,
            decode=False,
            deterministic=False,
            rngs=rngs,
        )
        self.norm2 = nnx.LayerNorm(hidden_size, rngs=rngs)
        self.mlp = nnx.Sequential(
            nnx.Linear(hidden_size, mlp_dim, rngs=rngs),
            nnx.gelu,
            nnx.Dropout(dropout_rate, rngs=rngs),
            nnx.Linear(mlp_dim, hidden_size, rngs=rngs),
            nnx.Dropout(dropout_rate, rngs=rngs),
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        # Pre-norm residual blocks: self-attention, then MLP
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
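# Optional shape check with a deliberately tiny configuration (these sizes are
# arbitrary, chosen only to keep the check fast); inputs are NHWC.
tiny = VisionTransformer(
    num_classes=5, num_layers=2, num_heads=2, mlp_dim=128, hidden_size=64
)
dummy = jnp.ones((2, img_size, img_size, 3), dtype=jnp.float32)
print(tiny(dummy).shape)  # (2, 5)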
# Build our model with a 1000-way head so its parameter shapes match the
# 1000-class ImageNet checkpoint, then load the pretrained ViT-Base weights.
model = VisionTransformer(num_classes=1000)
tf_model = FlaxViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
def vit_inplace_copy_weights(*, src_model, dst_model):
    assert isinstance(src_model, FlaxViTForImageClassification)
    assert isinstance(dst_model, VisionTransformer)

    tf_model_params = src_model.params
    tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)

    flax_model_params = nnx.state(dst_model, nnx.Param)
    flax_model_params_fstate = flax_model_params.flat_state()

    # Map each parameter path in our model to its path in the HF checkpoint.
    params_name_mapping = {
        ("cls_token",): ("vit", "embeddings", "cls_token"),
        ("position_embeddings",): ("vit", "embeddings", "position_embeddings"),
        **{
            ("patch_embeddings", x): (
                "vit", "embeddings", "patch_embeddings", "projection", x
            )
            for x in ["kernel", "bias"]
        },
        **{
            ("encoder", "layers", i, "attn", y, x): (
                "vit", "encoder", "layer", str(i), "attention", "attention", y, x
            )
            for x in ["kernel", "bias"]
            for y in ["key", "value", "query"]
            for i in range(12)
        },
        **{
            ("encoder", "layers", i, "attn", "out", x): (
                "vit", "encoder", "layer", str(i), "attention", "output", "dense", x
            )
            for x in ["kernel", "bias"]
            for i in range(12)
        },
        **{
            ("encoder", "layers", i, "mlp", "layers", y1, x): (
                "vit", "encoder", "layer", str(i), y2, "dense", x
            )
            for x in ["kernel", "bias"]
            for y1, y2 in [(0, "intermediate"), (3, "output")]
            for i in range(12)
        },
        **{
            ("encoder", "layers", i, y1, x): (
                "vit", "encoder", "layer", str(i), y2, x
            )
            for x in ["scale", "bias"]
            for y1, y2 in [
                ("norm1", "layernorm_before"),
                ("norm2", "layernorm_after"),
            ]
            for i in range(12)
        },
        **{
            ("final_norm", x): ("vit", "layernorm", x)
            for x in ["scale", "bias"]
        },
        **{
            ("classifier", x): ("classifier", x)
            for x in ["kernel", "bias"]
        },
    }

    nonvisited = set(flax_model_params_fstate.keys())

    for key1, key2 in params_name_mapping.items():
        assert key1 in flax_model_params_fstate, key1
        assert key2 in tf_model_params_fstate, (key1, key2)

        nonvisited.remove(key1)

        src_value = tf_model_params_fstate[key2]
        # HF stores query/key/value kernels as (768, 768) matrices, while
        # nnx.MultiHeadAttention keeps a per-head axis: split the output
        # dimension into 12 heads of size 64.
        if key2[-1] == "kernel" and key2[-2] in ("key", "value", "query"):
            shape = src_value.shape
            src_value = src_value.reshape((shape[0], 12, 64))

        if key2[-1] == "bias" and key2[-2] in ("key", "value", "query"):
            src_value = src_value.reshape((12, 64))

        # The attention output projection is transposed the same way:
        # (768, 768) -> (12 heads, 64, 768).
        if key2[-4:] == ("attention", "output", "dense", "kernel"):
            shape = src_value.shape
            src_value = src_value.reshape((12, 64, shape[-1]))

        dst_value = flax_model_params_fstate[key1]
        assert src_value.shape == dst_value.value.shape, (
            key2, src_value.shape, key1, dst_value.value.shape
        )
        dst_value.value = src_value.copy()
        assert dst_value.value.mean() == src_value.mean(), (
            dst_value.value, src_value.mean()
        )

    # Every parameter of dst_model must have been overwritten exactly once.
    assert len(nonvisited) == 0, nonvisited

    nnx.update(dst_model, nnx.State.from_flat_path(flax_model_params_fstate))
vit_inplace_copy_weights(src_model=tf_model, dst_model=model)

# Swap the 1000-way ImageNet head for a freshly initialized 5-way classifier.
model.classifier = nnx.Linear(model.classifier.in_features, 5, rngs=nnx.Rngs(0))
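# Optional smoke test (not part of the original pipeline): run one validation
# batch through the re-headed model and decode the predicted class indices
# back to the original Food-101 label ids via inv_labels_mapping. The head is
# untrained, so the predictions themselves are meaningless at this point.
model.eval()  # disable dropout for inference
val_batch = next(iter(val_loader))
logits = model(jnp.asarray(val_batch["image"], dtype=jnp.float32))
preds = jnp.argmax(logits, axis=-1)
print([inv_labels_mapping[int(p)] for p in preds[:8]])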