from datasets import load_dataset
import numpy as np
from torchvision.transforms import v2 as T
import grain.python as grain
import jax
import jax.numpy as jnp
from flax import nnx
from transformers import FlaxViTForImageClassification
import optax
import matplotlib.pyplot as plt
from time import time
train_size = 5 * 750
val_size = 5 * 250
train_dataset = load_dataset("food101",
split=f"train[:{train_size}]")
val_dataset = load_dataset("food101",
split=f"validation[:{val_size}]")
# Remap the five selected Food-101 label ids to contiguous indices 0..4.
# The validation subset keeps 250 consecutive images per class, so stepping
# through it in strides of 250 visits each class exactly once.
labels_mapping = {}
index = 0
for i in range(0, len(val_dataset), 250):
    label = val_dataset[i]["label"]
    if label not in labels_mapping:
        labels_mapping[label] = index
        index += 1
inv_labels_mapping = {v: k for k, v in labels_mapping.items()}
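# Optional sanity check (a minimal sketch, not required by the pipeline): the 5 * 250
# validation images cover 250 images per class, so exactly five Food-101 classes end
# up remapped to the contiguous indices 0..4.
assert len(labels_mapping) == 5, labels_mapping
print("label id remapping:", labels_mapping)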
img_size = 224
def to_np_array(pil_image):
return np.asarray(pil_image.convert("RGB"))
def normalize(image):
mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
image = image.astype(np.float32) / 255.0
return (image - mean) / std
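# Quick check (illustrative only): normalize() maps uint8 pixels in [0, 255] to roughly
# [-1, 1]; an all-white patch should come out as all ones.
_white = np.full((4, 4, 3), 255, dtype=np.uint8)
print(normalize(_white).min(), normalize(_white).max())  # expected: 1.0 1.0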
tv_train_transforms = T.Compose([
T.RandomResizedCrop((img_size, img_size), scale=(0.7, 1.0)),
T.RandomHorizontalFlip(),
T.ColorJitter(0.2, 0.2, 0.2),
T.Lambda(to_np_array),
T.Lambda(normalize),
])
tv_test_transforms = T.Compose([
T.Resize((img_size, img_size)),
T.Lambda(to_np_array),
T.Lambda(normalize),
])
def get_transform(fn):
def wrapper(batch):
batch["image"] = [
fn(pil_image) for pil_image in batch["image"]
]
batch["label"] = [
labels_mapping[label] for label in batch["label"]
]
return batch
return wrapper
train_transforms = get_transform(tv_train_transforms)
val_transforms = get_transform(tv_test_transforms)
train_dataset = train_dataset.with_transform(train_transforms)
val_dataset = val_dataset.with_transform(val_transforms)
seed = 12
train_batch_size = 32
val_batch_size = 2 * train_batch_size
train_sampler = grain.IndexSampler(
len(train_dataset),
shuffle=True,
seed=seed,
shard_options=grain.NoSharding(),
num_epochs=1,
)
val_sampler = grain.IndexSampler(
len(val_dataset),
shuffle=False,
seed=seed,
shard_options=grain.NoSharding(),
num_epochs=1,
)
train_loader = grain.DataLoader(
data_source=train_dataset,
sampler=train_sampler,
worker_count=4,
worker_buffer_size=2,
operations=[
grain.Batch(train_batch_size, drop_remainder=True),
]
)
val_loader = grain.DataLoader(
data_source=val_dataset,
sampler=val_sampler,
worker_count=4,
worker_buffer_size=2,
operations=[
grain.Batch(val_batch_size),
]
)
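# Optional shape check (a minimal sketch, assuming the transforms and loaders above
# behave as intended): grain stacks the per-sample dicts, so one training batch should
# yield channels-last images of shape (32, 224, 224, 3) and labels of shape (32,).
_first_batch = next(iter(train_loader))
print(_first_batch["image"].shape, _first_batch["label"].shape)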
class VisionTransformer(nnx.Module):
def __init__(
self,
num_classes: int = 1000,
in_channels: int = 3,
img_size: int = 224,
patch_size: int = 16,
num_layers: int = 12,
num_heads: int = 12,
mlp_dim: int = 3072,
hidden_size: int = 768,
dropout_rate: float = 0.1,
*,
rngs: nnx.Rngs = nnx.Rngs(0),
):
# Patch and position embedding
n_patches = (img_size // patch_size) ** 2
self.patch_embeddings = nnx.Conv(
in_channels,
hidden_size,
kernel_size=(patch_size, patch_size),
strides=(patch_size, patch_size),
padding="VALID",
use_bias=True,
rngs=rngs,
)
initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
self.position_embeddings = nnx.Param(
initializer(
rngs.params(),
(1, n_patches + 1, hidden_size),
jnp.float32
)
)
self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)
self.cls_token = nnx.Param(jnp.zeros((1, 1, hidden_size)))
# Transformer Encoder blocks
self.encoder = nnx.Sequential(*[
TransformerEncoder(
hidden_size,
mlp_dim,
num_heads,
dropout_rate,
rngs=rngs
)
for i in range(num_layers)
])
self.final_norm = nnx.LayerNorm(hidden_size, rngs=rngs)
# Classification head
self.classifier = nnx.Linear(hidden_size, num_classes, rngs=rngs)
def __call__(self, x: jax.Array) -> jax.Array:
# Patch and position embedding
patches = self.patch_embeddings(x)
batch_size = patches.shape[0]
patches = patches.reshape(batch_size, -1, patches.shape[-1])
cls_token = jnp.tile(self.cls_token, [batch_size, 1, 1])
x = jnp.concat([cls_token, patches], axis=1)
embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
# Encoder blocks
x = self.encoder(embeddings)
x = self.final_norm(x)
# fetch the first token
x = x[:, 0]
# Classification
return self.classifier(x)
class TransformerEncoder(nnx.Module):
def __init__(
self,
hidden_size: int,
mlp_dim: int,
num_heads: int,
dropout_rate: float = 0.0,
*,
rngs: nnx.Rngs = nnx.Rngs(0),
) -> None:
self.norm1 = nnx.LayerNorm(hidden_size, rngs=rngs)
self.attn = nnx.MultiHeadAttention(
num_heads=num_heads,
in_features=hidden_size,
dropout_rate=dropout_rate,
broadcast_dropout=False,
decode=False,
deterministic=False,
rngs=rngs,
)
self.norm2 = nnx.LayerNorm(hidden_size, rngs=rngs)
self.mlp = nnx.Sequential(
nnx.Linear(hidden_size, mlp_dim, rngs=rngs),
nnx.gelu,
nnx.Dropout(dropout_rate, rngs=rngs),
nnx.Linear(mlp_dim, hidden_size, rngs=rngs),
nnx.Dropout(dropout_rate, rngs=rngs),
)
def __call__(self, x: jax.Array) -> jax.Array:
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
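# Smoke test (a sketch; this instantiates a full randomly initialised ViT-B/16, so it
# is somewhat slow): a dummy batch of two 224x224 RGB images should produce logits of
# shape (2, 1000).
_dummy = jnp.ones((2, img_size, img_size, 3), dtype=jnp.float32)
print(VisionTransformer(num_classes=1000)(_dummy).shape)  # expected: (2, 1000)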
# Randomly initialised ViT-B/16 with the same architecture as the pretrained checkpoint.
model = VisionTransformer(num_classes=1000)
# Pretrained ViT-B/16 from Hugging Face; this is a Flax checkpoint despite the `tf_model` name.
tf_model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
def vit_inplace_copy_weights(*, src_model, dst_model):
    # Copy the pretrained HF Flax ViT parameters into the custom nnx VisionTransformer
    # in place, remapping parameter paths and reshaping attention weights to the
    # (num_heads, head_dim) layout expected by nnx.MultiHeadAttention.
    assert isinstance(src_model, FlaxViTForImageClassification)
    assert isinstance(dst_model, VisionTransformer)
tf_model_params = src_model.params
tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)
flax_model_params = nnx.state(dst_model, nnx.Param)
flax_model_params_fstate = flax_model_params.flat_state()
params_name_mapping = {
("cls_token",): ("vit", "embeddings", "cls_token"),
("position_embeddings",): (
"vit",
"embeddings",
"position_embeddings"
),
**{
("patch_embeddings", x): (
"vit",
"embeddings",
"patch_embeddings",
"projection",
x
)
for x in ["kernel", "bias"]
},
**{
("encoder", "layers", i, "attn", y, x): (
"vit",
"encoder",
"layer",
str(i),
"attention",
"attention",
y,
x
)
for x in ["kernel", "bias"]
for y in ["key", "value", "query"]
for i in range(12)
},
**{
("encoder", "layers", i, "attn", "out", x): (
"vit",
"encoder",
"layer",
str(i),
"attention",
"output",
"dense",
x
)
for x in ["kernel", "bias"]
for i in range(12)
},
**{
("encoder", "layers", i, "mlp", "layers", y1, x): (
"vit",
"encoder",
"layer",
str(i),
y2,
"dense",
x
)
for x in ["kernel", "bias"]
for y1, y2 in [(0, "intermediate"), (3, "output")]
for i in range(12)
},
**{
("encoder", "layers", i, y1, x): (
"vit", "encoder", "layer", str(i), y2, x
)
for x in ["scale", "bias"]
for y1, y2 in [
("norm1", "layernorm_before"),
("norm2", "layernorm_after")
]
for i in range(12)
},
**{
("final_norm", x): ("vit", "layernorm", x)
for x in ["scale", "bias"]
},
**{
("classifier", x): ("classifier", x)
for x in ["kernel", "bias"]
}
}
nonvisited = set(flax_model_params_fstate.keys())
for key1, key2 in params_name_mapping.items():
assert key1 in flax_model_params_fstate, key1
assert key2 in tf_model_params_fstate, (key1, key2)
nonvisited.remove(key1)
src_value = tf_model_params_fstate[key2]
if key2[-1] == "kernel" and key2[-2] in ("key", "value", "query"):
shape = src_value.shape
src_value = src_value.reshape((shape[0], 12, 64))
if key2[-1] == "bias" and key2[-2] in ("key", "value", "query"):
src_value = src_value.reshape((12, 64))
if key2[-4:] == ("attention", "output", "dense", "kernel"):
shape = src_value.shape
src_value = src_value.reshape((12, 64, shape[-1]))
dst_value = flax_model_params_fstate[key1]
assert src_value.shape == dst_value.value.shape, (
key2, src_value.shape, key1, dst_value.value.shape
)
dst_value.value = src_value.copy()
assert dst_value.value.mean() == src_value.mean(), (
dst_value.value, src_value.mean()
)
assert len(nonvisited) == 0, nonvisited
nnx.update(dst_model, nnx.State.from_flat_path(flax_model_params_fstate))
vit_inplace_copy_weights(src_model=tf_model, dst_model=model)
model.classifier = nnx.Linear(model.classifier.in_features, 5, rngs=nnx.Rngs(0))
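# Sanity check (illustrative): the backbone keeps the pretrained ViT-B/16 weights,
# while the freshly initialised head now emits 5 logits, one per selected class.
print(model.classifier.kernel.value.shape)  # expected: (768, 5)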
num_epochs = 3
learning_rate = 0.001
momentum = 0.8
total_steps = len(train_dataset) // train_batch_size
lr_schedule = optax.linear_schedule(learning_rate, 0.0, num_epochs * total_steps)
optimizer = nnx.Optimizer(model, optax.sgd(lr_schedule, momentum, nesterov=True))
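# Optional (purely illustrative): the linear schedule decays the learning rate from
# 0.001 at step 0 to 0.0 after num_epochs * total_steps optimizer updates.
print(lr_schedule(0), lr_schedule(num_epochs * total_steps))  # expected: ~0.001 and 0.0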
def compute_losses_and_logits(model: nnx.Module, images: jax.Array, labels: jax.Array):
logits = model(images)
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=labels
).mean()
return loss, logits
@nnx.jit
def train_step(
model: nnx.Module, optimizer: nnx.Optimizer, batch: dict[str, np.ndarray]
):
# Convert np.ndarray to jax.Array on GPU
images = jnp.array(batch["image"])
labels = jnp.array(batch["label"], dtype=jnp.int32)
grad_fn = nnx.value_and_grad(compute_losses_and_logits, has_aux=True)
(loss, logits), grads = grad_fn(model, images, labels)
optimizer.update(grads) # In-place updates.
return loss
@nnx.jit
def eval_step(
model: nnx.Module, batch: dict[str, np.ndarray], eval_metrics: nnx.MultiMetric
):
# Convert np.ndarray to jax.Array on GPU
images = jnp.array(batch["image"])
labels = jnp.array(batch["label"], dtype=jnp.int32)
loss, logits = compute_losses_and_logits(model, images, labels)
eval_metrics.update(
loss=loss,
logits=logits,
labels=labels,
)
eval_metrics = nnx.MultiMetric(
loss=nnx.metrics.Average('loss'),
accuracy=nnx.metrics.Accuracy(),
)
train_metrics_history = {
"train_loss": [],
}
eval_metrics_history = {
"val_loss": [],
"val_accuracy": [],
}
def train_one_epoch(epoch):
    model.train()  # Switch to training mode: e.g. make dropout stochastic
    for batch in train_loader:
        loss = train_step(model, optimizer, batch)
        train_metrics_history["train_loss"].append(loss.item())
    print(f"[train] epoch: {epoch + 1}/{num_epochs}, "
          f"last loss: {train_metrics_history['train_loss'][-1]:0.4f}")
def evaluate_model(epoch):
model.eval()
eval_metrics.reset()
for val_batch in val_loader:
eval_step(model, val_batch, eval_metrics)
for metric, value in eval_metrics.compute().items():
eval_metrics_history[f'val_{metric}'].append(value)
print(f"[val] epoch: {epoch + 1}/{num_epochs}")
print(f"- total loss: {eval_metrics_history['val_loss'][-1]:0.4f}")
print(f"- Accuracy: {eval_metrics_history['val_accuracy'][-1]:0.4f}")
start = time()
for epoch in range(num_epochs):
train_one_epoch(epoch)
evaluate_model(epoch)
end = time()
print(f"Training took {round((end - start) / 60, 1)} minutes")
plt.plot(train_metrics_history["train_loss"], label="Training loss per step")
plt.legend()
plt.savefig('loss.png')
fig, axs = plt.subplots(1, 2, figsize=(10, 10))
axs[0].set_title("Loss value on validation set")
axs[0].plot(eval_metrics_history["val_loss"])
axs[1].set_title("Accuracy on validation set")
axs[1].plot(eval_metrics_history["val_accuracy"])
plt.savefig('validation.png')
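# Optional inference example (a minimal sketch, assuming training finished): predict on
# one validation batch and map the predicted indices back to the original Food-101
# label ids via the otherwise unused inv_labels_mapping.
_val_batch = next(iter(val_loader))
_preds = jnp.argmax(model(jnp.array(_val_batch["image"])), axis=-1)
print("predicted Food-101 ids:  ", [inv_labels_mapping[int(p)] for p in _preds[:8]])
print("ground-truth Food-101 ids:", [inv_labels_mapping[int(l)] for l in _val_batch["label"][:8]])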