Fine-tuning the model

Authors

Marie-Hélène Burle

Code adapted from the JAX tutorial “Implement ViT from scratch”.

In this section, we fine-tune our model on our sample (5 classes) of the Food-101 dataset [1].

Context

Workflow overview (each step, with the libraries used for it):

Load data (torchdata / tfds / datasets) → Process data (grain / torchvision) → Define architecture (flax), initialized from a pre-trained model (transformers) → Hyperparameters (optax) → Train (flax, built on JAX) → Checkpoint (orbax)

from datasets import load_dataset
import numpy as np
from torchvision.transforms import v2 as T
import grain.python as grain
import jax
import jax.numpy as jnp
from flax import nnx
from transformers import FlaxViTForImageClassification
import optax

# 750 training and 250 validation images for each of the 5 classes.
# NOTE(review): taking a prefix of each split only yields whole classes if the
# splits are ordered by label — confirm for the dataset version in use.
train_size, val_size = 5 * 750, 5 * 250

train_dataset = load_dataset("food101", split=f"train[:{train_size}]")
val_dataset = load_dataset("food101", split=f"validation[:{val_size}]")

# Map the original Food-101 label ids present in the validation subset to
# contiguous class indices 0, 1, 2, ... (in order of first appearance); these
# contiguous indices are the training targets.
# Reading the whole "label" column (instead of sampling every 250th row) does
# not depend on each class occupying exactly 250 consecutive rows, and it does
# not decode any image.
labels_mapping = {
    label: index
    for index, label in enumerate(dict.fromkeys(val_dataset["label"]))
}

# Reverse lookup: contiguous class index -> original Food-101 label id.
inv_labels_mapping = {v: k for k, v in labels_mapping.items()}

# Side length, in pixels, of the square images fed to the model.
img_size = 224

def to_np_array(pil_image):
    """Convert *pil_image* to RGB mode and return it as a NumPy array."""
    rgb_image = pil_image.convert("RGB")
    return np.asarray(rgb_image)

def normalize(image):
    """Return *image* as float32, scaled by 1/255 then standardized per channel.

    Every channel uses mean 0.5 and std 0.5, so pixel values in [0, 255]
    end up in [-1, 1].
    """
    channel_mean = np.full(3, 0.5, dtype=np.float32)
    channel_std = np.full(3, 0.5, dtype=np.float32)
    scaled = image.astype(np.float32) / 255.0
    return (scaled - channel_mean) / channel_std

def _array_and_normalize():
    """Return fresh transforms turning a PIL image into a normalized float32 array."""
    return [T.Lambda(to_np_array), T.Lambda(normalize)]


# Training pipeline: random crop / horizontal flip / colour jitter on the PIL
# image, then conversion to a normalized array.
tv_train_transforms = T.Compose(
    [
        T.RandomResizedCrop((img_size, img_size), scale=(0.7, 1.0)),
        T.RandomHorizontalFlip(),
        T.ColorJitter(0.2, 0.2, 0.2),
        *_array_and_normalize(),
    ]
)

# Evaluation pipeline: deterministic resize only, no augmentation.
tv_test_transforms = T.Compose(
    [T.Resize((img_size, img_size)), *_array_and_normalize()]
)

def get_transform(fn):
    """Turn a per-image transform *fn* into a batch transform for ``with_transform``.

    The returned callable modifies the batch dict in place and returns it:
    every entry of ``batch["image"]`` is replaced by ``fn(entry)`` and every
    original label in ``batch["label"]`` is remapped through ``labels_mapping``.
    """
    def apply_to_batch(batch):
        batch["image"] = list(map(fn, batch["image"]))
        # Contiguous class indices instead of the raw Food-101 label ids.
        batch["label"] = [labels_mapping[raw] for raw in batch["label"]]
        return batch

    return apply_to_batch

# Batch-level transforms; `datasets` applies them on the fly whenever rows are read.
train_transforms = get_transform(tv_train_transforms)
val_transforms = get_transform(tv_test_transforms)

train_dataset = train_dataset.with_transform(train_transforms)
val_dataset = val_dataset.with_transform(val_transforms)

seed = 12
train_batch_size = 32
val_batch_size = train_batch_size * 2

# Settings shared by both samplers: no sharding, a single epoch per pass.
_sampler_kwargs = dict(seed=seed, shard_options=grain.NoSharding(), num_epochs=1)

train_sampler = grain.IndexSampler(len(train_dataset), shuffle=True, **_sampler_kwargs)
val_sampler = grain.IndexSampler(len(val_dataset), shuffle=False, **_sampler_kwargs)

# Both loaders use 4 workers with a buffer size of 2.
_loader_kwargs = dict(worker_count=4, worker_buffer_size=2)

# drop_remainder=True: every training batch holds exactly train_batch_size items.
train_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,
    operations=[grain.Batch(train_batch_size, drop_remainder=True)],
    **_loader_kwargs,
)

val_loader = grain.DataLoader(
    data_source=val_dataset,
    sampler=val_sampler,
    operations=[grain.Batch(val_batch_size)],
    **_loader_kwargs,
)

class VisionTransformer(nnx.Module):
    """Vision Transformer (ViT) image classifier written with Flax NNX.

    A strided convolution cuts the image into non-overlapping
    ``patch_size`` x ``patch_size`` patches and embeds them. A learnable class
    token is prepended and learned position embeddings are added. The sequence
    then goes through ``num_layers`` ``TransformerEncoder`` blocks and a final
    LayerNorm, and the class-token output feeds a linear classification head.

    Args:
        num_classes: Number of logits produced by the classification head.
        in_channels: Number of channels of the input images.
        img_size: Height/width of the (assumed square) input images.
        patch_size: Height/width of each square patch.
        num_layers: Number of encoder blocks.
        num_heads: Attention heads per encoder block.
        mlp_dim: Hidden width of the MLP in each encoder block.
        hidden_size: Token embedding dimension.
        dropout_rate: Dropout rate for the embeddings and the encoder blocks.
        rngs: Random streams used for parameter init and dropout. When
            omitted, a fresh ``nnx.Rngs(0)`` is created for this instance.
    """

    def __init__(
        self,
        num_classes: int = 1000,
        in_channels: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_dim: int = 3072,
        hidden_size: int = 768,
        dropout_rate: float = 0.1,
        *,
        rngs: nnx.Rngs | None = None,
    ):
        # A default of `nnx.Rngs(0)` in the signature is evaluated only once,
        # so every instance built without `rngs` would share (and advance) the
        # same stateful generator. Create a fresh one per instance instead.
        if rngs is None:
            rngs = nnx.Rngs(0)

        # Patch and position embedding
        n_patches = (img_size // patch_size) ** 2
        self.patch_embeddings = nnx.Conv(
            in_channels,
            hidden_size,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding="VALID",
            use_bias=True,
            rngs=rngs,
        )

        initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
        # One position per patch plus one for the class token.
        self.position_embeddings = nnx.Param(
            initializer(
                rngs.params(),
                (1, n_patches + 1, hidden_size),
                jnp.float32
            )
        )
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)

        self.cls_token = nnx.Param(jnp.zeros((1, 1, hidden_size)))

        # Transformer Encoder blocks
        self.encoder = nnx.Sequential(*[
            TransformerEncoder(
                hidden_size,
                mlp_dim,
                num_heads,
                dropout_rate,
                rngs=rngs
            )
            for _ in range(num_layers)
        ])
        self.final_norm = nnx.LayerNorm(hidden_size, rngs=rngs)

        # Classification head
        self.classifier = nnx.Linear(hidden_size, num_classes, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Return logits of shape (batch, num_classes) for the images ``x``.

        ``x`` is assumed channels-last, i.e. (batch, img_size, img_size,
        in_channels), as consumed by ``nnx.Conv``.
        """
        # Patch embedding, flattened to (batch, n_patches, hidden_size).
        patches = self.patch_embeddings(x)
        batch_size = patches.shape[0]
        patches = patches.reshape(batch_size, -1, patches.shape[-1])

        # Prepend one copy of the class token per example, add positions.
        cls_token = jnp.tile(self.cls_token, [batch_size, 1, 1])
        x = jnp.concat([cls_token, patches], axis=1)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)

        # Encoder blocks
        x = self.encoder(embeddings)
        x = self.final_norm(x)

        # Classify from the class-token position only.
        x = x[:, 0]
        return self.classifier(x)

class TransformerEncoder(nnx.Module):
    """Pre-norm Transformer encoder block.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``, where
    ``mlp`` is Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        hidden_size: Token embedding dimension (input and output width).
        mlp_dim: Hidden width of the MLP.
        num_heads: Number of attention heads.
        dropout_rate: Dropout rate passed to the attention layer and used in
            the MLP.
        rngs: Random streams for parameter init and dropout. When omitted,
            a fresh ``nnx.Rngs(0)`` is created for this instance.
    """

    def __init__(
        self,
        hidden_size: int,
        mlp_dim: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs | None = None,
    ) -> None:
        # Avoid a shared, stateful `nnx.Rngs(0)` default evaluated once at
        # definition time; build a fresh one per instance instead.
        if rngs is None:
            rngs = nnx.Rngs(0)

        self.norm1 = nnx.LayerNorm(hidden_size, rngs=rngs)
        # NOTE(review): attention dropout is enabled here; presumably it is
        # toggled by model.train() / model.eval() — confirm for your Flax version.
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=hidden_size,
            dropout_rate=dropout_rate,
            broadcast_dropout=False,
            decode=False,
            deterministic=False,
            rngs=rngs,
        )
        self.norm2 = nnx.LayerNorm(hidden_size, rngs=rngs)

        # Layers 0 and 3 are the Linear layers (the weight-copy mapping
        # relies on these positions).
        self.mlp = nnx.Sequential(
            nnx.Linear(hidden_size, mlp_dim, rngs=rngs),
            nnx.gelu,
            nnx.Dropout(dropout_rate, rngs=rngs),
            nnx.Linear(mlp_dim, hidden_size, rngs=rngs),
            nnx.Dropout(dropout_rate, rngs=rngs),
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the block; the output has the same shape as ``x``."""
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

# Our NNX ViT, built with a 1000-class head so every parameter shape matches
# the pre-trained checkpoint's (checked when the weights are copied).
model = VisionTransformer(num_classes=1000)

# Pre-trained Hugging Face Flax ViT whose weights we copy into `model`.
tf_model = FlaxViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224"
)

def vit_inplace_copy_weights(*, src_model, dst_model):
    """Copy the weights of a Hugging Face Flax ViT into our NNX ViT, in place.

    Every ``nnx.Param`` of ``dst_model`` must receive exactly one tensor from
    ``src_model.params``; the name correspondence is spelled out in
    ``params_name_mapping``. ``nnx.MultiHeadAttention`` stores its projections
    per head, so the source's fused attention tensors are reshaped to the
    destination shapes.

    Args:
        src_model: Pre-trained ``FlaxViTForImageClassification``.
        dst_model: ``VisionTransformer`` with the same architecture (including
            a classifier with as many classes as the source).

    Raises:
        AssertionError: If a parameter is missing on either side, shapes do
            not match, or some destination parameter was left unset.
    """
    assert isinstance(src_model, FlaxViTForImageClassification)
    assert isinstance(dst_model, VisionTransformer)

    tf_model_params = src_model.params
    tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)

    flax_model_params = nnx.state(dst_model, nnx.Param)
    # Recent Flax releases return a `FlatState` (a sequence of (path, value)
    # pairs without `.keys()`), older ones a plain dict. `dict(...)` accepts
    # both and gives key lookup / membership tests. The values are the same
    # state objects, so assigning `.value` below still updates them.
    flax_model_params_fstate = dict(flax_model_params.flat_state())

    # Derive the depth from the destination model instead of hard-coding 12.
    num_layers = len(dst_model.encoder.layers)
    attn_projections = ("key", "value", "query")

    params_name_mapping = {
        ("cls_token",): ("vit", "embeddings", "cls_token"),
        ("position_embeddings",): (
            "vit",
            "embeddings",
            "position_embeddings"
        ),
        **{
            ("patch_embeddings", x): (
                "vit",
                "embeddings",
                "patch_embeddings",
                "projection",
                x
            )
            for x in ["kernel", "bias"]
        },
        **{
            ("encoder", "layers", i, "attn", y, x): (
                "vit",
                "encoder",
                "layer",
                str(i),
                "attention",
                "attention",
                y,
                x
            )
            for x in ["kernel", "bias"]
            for y in attn_projections
            for i in range(num_layers)
        },
        **{
            ("encoder", "layers", i, "attn", "out", x): (
                "vit",
                "encoder",
                "layer",
                str(i),
                "attention",
                "output",
                "dense",
                x
            )
            for x in ["kernel", "bias"]
            for i in range(num_layers)
        },
        # Indices 0 and 3 are the two Linear layers of each block's MLP.
        **{
            ("encoder", "layers", i, "mlp", "layers", y1, x): (
                "vit",
                "encoder",
                "layer",
                str(i),
                y2,
                "dense",
                x
            )
            for x in ["kernel", "bias"]
            for y1, y2 in [(0, "intermediate"), (3, "output")]
            for i in range(num_layers)
        },
        **{
            ("encoder", "layers", i, y1, x): (
                "vit", "encoder", "layer", str(i), y2, x
            )
            for x in ["scale", "bias"]
            for y1, y2 in [
                    ("norm1", "layernorm_before"),
                    ("norm2", "layernorm_after")
            ]
            for i in range(num_layers)
        },
        **{
            ("final_norm", x): ("vit", "layernorm", x)
            for x in ["scale", "bias"]
        },
        **{
            ("classifier", x): ("classifier", x)
            for x in ["kernel", "bias"]
        }
    }

    nonvisited = set(flax_model_params_fstate)

    for key1, key2 in params_name_mapping.items():
        assert key1 in flax_model_params_fstate, key1
        assert key2 in tf_model_params_fstate, (key1, key2)

        nonvisited.remove(key1)

        src_value = tf_model_params_fstate[key2]
        dst_value = flax_model_params_fstate[key1]

        # The source keeps Q/K/V kernels and biases, and the attention output
        # kernel, with all heads fused together; NNX splits the head axis out
        # (e.g. (768, 768) -> (768, 12, 64) for ViT-B). A row-major reshape to
        # the destination shape performs that split for any head count.
        is_qkv = key2[-2] in attn_projections and key2[-1] in ("kernel", "bias")
        is_attn_out_kernel = key2[-4:] == ("attention", "output", "dense", "kernel")
        if is_qkv or is_attn_out_kernel:
            src_value = src_value.reshape(dst_value.value.shape)

        assert src_value.shape == dst_value.value.shape, (
            key2, src_value.shape, key1, dst_value.value.shape
        )
        dst_value.value = src_value.copy()
        assert dst_value.value.mean() == src_value.mean(), (
            dst_value.value, src_value.mean()
        )

    assert len(nonvisited) == 0, nonvisited

    nnx.update(dst_model, nnx.State.from_flat_path(flax_model_params_fstate))

vit_inplace_copy_weights(src_model=tf_model, dst_model=model)

# Replace the pre-trained 1000-class head with a freshly initialized one that
# has one output per class in our subset. The count comes from the label
# mapping so it cannot drift from the training targets.
model.classifier = nnx.Linear(
    model.classifier.in_features, len(labels_mapping), rngs=nnx.Rngs(0)
)

num_epochs = 3
learning_rate = 0.001
momentum = 0.8
# Optimizer steps per epoch (the training loader drops the last partial batch).
total_steps = len(train_dataset) // train_batch_size

# Learning rate decays linearly from `learning_rate` to 0 over the whole run.
lr_schedule = optax.linear_schedule(
    init_value=learning_rate,
    end_value=0.0,
    transition_steps=num_epochs * total_steps,
)

# SGD with Nesterov momentum, applied to the model's parameters.
optimizer = nnx.Optimizer(
    model,
    optax.sgd(learning_rate=lr_schedule, momentum=momentum, nesterov=True),
)

def compute_losses_and_logits(model: nnx.Module, images: jax.Array, labels: jax.Array):
    """Run *model* on *images*; return (mean cross-entropy loss, logits).

    *labels* are integer class indices, one per image.
    """
    logits = model(images)
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    )
    return per_example_loss.mean(), logits

@nnx.jit
def train_step(
    model: nnx.Module, optimizer: nnx.Optimizer, batch: dict[str, np.ndarray]
):
    """Take one optimizer step on *batch* and return its mean loss.

    The optimizer updates the parameters in place.
    """
    images = jnp.array(batch["image"])
    labels = jnp.array(batch["label"], dtype=jnp.int32)

    # has_aux=True: differentiate the loss only; the logits ride along unused.
    loss_and_grads = nnx.value_and_grad(compute_losses_and_logits, has_aux=True)
    (loss, _logits), grads = loss_and_grads(model, images, labels)
    optimizer.update(grads)
    return loss

@nnx.jit
def eval_step(
    model: nnx.Module, batch: dict[str, np.ndarray], eval_metrics: nnx.MultiMetric
):
    """Compute loss and logits on *batch* and accumulate them into *eval_metrics*.

    No parameters are changed; only *eval_metrics* is updated, in place.
    """
    images = jnp.array(batch["image"])
    labels = jnp.array(batch["label"], dtype=jnp.int32)
    batch_loss, batch_logits = compute_losses_and_logits(model, images, labels)
    eval_metrics.update(loss=batch_loss, logits=batch_logits, labels=labels)

# Validation metrics: running mean of the `loss` passed to `update`, plus
# classification accuracy computed from `logits` and `labels`.
eval_metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average('loss'),
    accuracy=nnx.metrics.Accuracy(),
)

# Histories: one training loss per step; one loss/accuracy pair per
# validation pass.
train_metrics_history = {"train_loss": []}
eval_metrics_history = {"val_loss": [], "val_accuracy": []}
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[1], line 352
    348     assert len(nonvisited) == 0, nonvisited
    350     nnx.update(dst_model, nnx.State.from_flat_path(flax_model_params_fstate))
--> 352 vit_inplace_copy_weights(src_model=tf_model, dst_model=model)
    354 model.classifier = nnx.Linear(model.classifier.in_features, 5, rngs=nnx.Rngs(0))
    356 num_epochs = 3

Cell In[1], line 319, in vit_inplace_copy_weights(src_model, dst_model)
    236 flax_model_params_fstate = flax_model_params.flat_state()
    238 params_name_mapping = {
    239     ("cls_token",): ("vit", "embeddings", "cls_token"),
    240     ("position_embeddings",): (
   (...)    316     }
    317 }
--> 319 nonvisited = set(flax_model_params_fstate.keys())
    321 for key1, key2 in params_name_mapping.items():
    322     assert key1 in flax_model_params_fstate, key1

AttributeError: 'FlatState' object has no attribute 'keys'

Load packages

# to have a progress bar during training
import tqdm

# to visualize evolution of loss and sample data
import matplotlib.pyplot as plt

Training and evaluation functions

bar_format = "{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]"

def train_one_epoch(epoch):
    """Run one pass over ``train_loader``, recording each step's loss.

    Args:
        epoch: Zero-based epoch index. It is displayed 1-based, matching the
            epoch number printed by ``evaluate_model``.
    """
    model.train()  # training mode (e.g. dropout active)
    with tqdm.tqdm(
        desc=f"[train] epoch: {epoch + 1}/{num_epochs}, ",
        total=total_steps,
        bar_format=bar_format,
        leave=True,
    ) as pbar:
        for batch in train_loader:
            loss = train_step(model, optimizer, batch)
            # `.item()` copies the scalar to the host; do it once per step.
            loss_value = loss.item()
            train_metrics_history["train_loss"].append(loss_value)
            pbar.set_postfix({"loss": loss_value})
            pbar.update(1)

def evaluate_model(epoch):
    """Evaluate ``model`` on ``val_loader``, record and print the results.

    Appends each computed metric to ``eval_metrics_history`` under
    ``val_<name>`` and prints the epoch (shown 1-based), loss and accuracy.
    """
    model.eval()

    eval_metrics.reset()
    for batch in val_loader:
        eval_step(model, batch, eval_metrics)

    results = eval_metrics.compute()
    for name in results:
        eval_metrics_history[f'val_{name}'].append(results[name])

    latest_loss = eval_metrics_history['val_loss'][-1]
    latest_accuracy = eval_metrics_history['val_accuracy'][-1]
    print(f"[val] epoch: {epoch + 1}/{num_epochs}")
    print(f"- total loss: {latest_loss:0.4f}")
    print(f"- Accuracy: {latest_accuracy:0.4f}")

Train the model

%%time

# Alternate one training epoch with one validation pass.
for epoch_idx in range(num_epochs):
    train_one_epoch(epoch_idx)
    evaluate_model(epoch_idx)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
File <timed exec>:1

NameError: name 'num_epochs' is not defined

OOM issues

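I ran out of memory when running this code on my machine. (The output rendered above shows a NameError instead: it is a knock-on effect of the AttributeError raised earlier while copying the pre-trained weights, which stopped that cell before num_epochs was defined.)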

Out of memory (OOM) problems are common when trying to train a model with JAX on GPUs. See for instance this question on Stack Overflow and this issue in the JAX repo.

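According to the JAX documentation on GPU memory allocation, you can try the following. Note that these variables are read when JAX initializes its GPU backend, so they must be set before JAX is first used (ideally before importing it):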

import os

# XLA GPU memory settings. They are read when JAX initializes its GPU
# backend, so set them before that happens (ideally before importing jax).
os.environ.update(
    {
        # Do not preallocate most of the GPU memory at startup.
        'XLA_PYTHON_CLIENT_PREALLOCATE': 'false',
        # Allocate exactly what is needed, on demand, and free it afterwards.
        'XLA_PYTHON_CLIENT_ALLOCATOR': 'platform',
        # Fraction of GPU memory to preallocate (relevant when preallocating).
        'XLA_PYTHON_CLIENT_MEM_FRACTION': '0.5',
    }
)

or, if you use IPython (or Jupyter, which runs IPython), you can use the equivalent IPython built-in magic command %env to set environment variables:

%env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.5

None of these solutions worked for me, either on my machine or on Cedar, and I am starting to suspect that there is a problem with this particular version of jaxlib.

Without GPUs (as on our training cluster), training will take much longer, but you won’t run into this problem.

Metrics graphs

If we hadn’t run out of memory, we could graph our metrics.

Evolution of the loss during training:

# Training loss recorded at every optimizer step.
ax = plt.gca()
ax.plot(train_metrics_history["train_loss"], label="Loss value during the training")
ax.legend()

Loss and accuracy on the validation set:

# Validation loss (left) and accuracy (right), one point per epoch.
fig, axs = plt.subplots(1, 2, figsize=(10, 10))
for ax, title, history_key in zip(
    axs,
    ("Loss value on validation set", "Accuracy on validation set"),
    ("val_loss", "val_accuracy"),
):
    ax.set_title(title)
    ax.plot(eval_metrics_history[history_key])

Check sample data

And we could look at the model predictions for 5 items:

# Validation rows to inspect. NOTE(review): presumably one image per class,
# assuming 250 consecutive validation images per class — confirm.
test_indices = [1, 250, 500, 750, 1000]

# Each row already went through `val_transforms` (normalized image,
# contiguous label index).
samples = [val_dataset[i] for i in test_indices]
test_images = jnp.array([sample["image"] for sample in samples])
expected_labels = [sample["label"] for sample in samples]

model.eval()
preds = model(test_images)
num_samples = len(test_indices)
names_map = train_dataset.features["label"].names

probas = nnx.softmax(preds, axis=1)
pred_labels = probas.argmax(axis=1)


fig, axs = plt.subplots(1, num_samples, figsize=(20, 10))
for i, (ax, img, expected_label) in enumerate(zip(axs, test_images, expected_labels)):
    pred_label = pred_labels[i].item()
    proba = probas[i, pred_label].item()
    # Min-max rescale the normalized float image to uint8 pixels for display.
    if img.dtype in (np.float32, ):
        img = ((img - img.min()) / (img.max() - img.min()) * 255.0).astype(np.uint8)

    # Contiguous index -> original Food-101 id -> human-readable class name.
    expected_label_str = names_map[inv_labels_mapping[expected_label]]
    pred_label_str = names_map[inv_labels_mapping[pred_label]]
    ax.set_title(f"Expected: {expected_label_str} vs \nPredicted: {pred_label_str}, P={proba:.2f}")
    ax.imshow(img)
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[6], line 18
     15     img = ((img - img.min()) / (img.max() - img.min()) * 255.0).astype(np.uint8)
     17 expected_label_str = names_map[inv_labels_mapping[expected_label]]
---> 18 pred_label_str = names_map[inv_labels_mapping[pred_label]]
     19 axs[i].set_title(f"Expected: {expected_label_str} vs \nPredicted: {pred_label_str}, P={proba:.2f}")
     20 axs[i].imshow(img)

KeyError: 693

References

1.
Bossard L, Guillaumin M, Van Gool L (2014) Food-101 – mining discriminative components with random forests. In: European conference on computer vision