Preprocessing data

Authors

Marie-Hélène Burle

Code adapted from the JAX tutorial *Implement ViT from scratch*.

This section covers an example of the second step of a classic workflow: preprocessing the data.

Context

There are many tools and options. In this example, we use TorchVision to transform and augment images and Grain to create data loaders.

Workflow diagram (each step, with the libraries that can handle it, in order):

1. Load data — torchdata, tfds, datasets
2. Process data — grain, torchvision
3. Define architecture — flax (optionally starting from a pre-trained model from transformers)
4. Hyperparameters — optax
5. Train — flax (built on jax)
6. Checkpoint — orbax

from datasets import load_dataset
import matplotlib.pyplot as plt

# Work with a small slice of Food-101: the first 5 * 750 training examples and
# the first 5 * 250 validation examples.
# NOTE(review): this presumably selects five whole classes only if both splits
# are stored grouped by class — confirm against the dataset card.
train_size = 750 * 5
val_size = 250 * 5

train_dataset = load_dataset(
    "food101",
    split=f"train[:{train_size}]",
)
val_dataset = load_dataset(
    "food101",
    split=f"validation[:{val_size}]",
)

# Map each raw Food-101 label id found in the subset to a contiguous index
# 0..K-1, in order of first appearance (validation labels first, then
# training labels).
# Every example's label is scanned instead of one label every 250 rows. The
# stride only finds all classes if each class fills exactly 250 contiguous
# validation rows. Any label it missed would raise KeyError later, when batch
# labels are looked up in this mapping. Training labels are included so that
# a training class absent from the validation slice is still mapped.
labels_mapping = {
    label: index
    for index, label in enumerate(
        dict.fromkeys([*val_dataset["label"], *train_dataset["label"]])
    )
}

# Inverse lookup: contiguous index -> raw Food-101 label id.
inv_labels_mapping = {v: k for k, v in labels_mapping.items()}

def display_datapoints(*datapoints, tag="", names_map=None):
    """Show each datapoint side by side in a single row of subplots.

    Each datapoint is either a dict with "image" and "label" keys or an
    ``(image, label)`` pair. Floating-point array images (e.g. normalized
    model inputs) are min-max rescaled to uint8 so ``imshow`` renders them.

    Args:
        *datapoints: Items to display, one subplot per item. Nothing is drawn
            when no datapoints are given.
        tag: Text prefixed to every subplot title.
        names_map: Optional mapping from label to a human-readable name that
            is appended to the title in parentheses.
    """
    num_samples = len(datapoints)
    if num_samples == 0:
        # plt.subplots(1, 0) would raise; there is nothing to draw.
        return

    # squeeze=False keeps the axes array 2-D even for a single subplot
    # (otherwise plt.subplots returns a bare Axes that cannot be indexed).
    _, axs = plt.subplots(1, num_samples, figsize=(20, 10), squeeze=False)
    for ax, datapoint in zip(axs[0], datapoints):
        if isinstance(datapoint, dict):
            img, label = datapoint["image"], datapoint["label"]
        else:
            img, label = datapoint

        if hasattr(img, "dtype"):
            # Array-likes (NumPy, JAX, ...) become plain ndarrays; objects
            # without a dtype (e.g. PIL images) go to imshow unchanged.
            img = np.asarray(img)
            if np.issubdtype(img.dtype, np.floating):
                lo, hi = img.min(), img.max()
                span = hi - lo
                # A constant image would otherwise divide by zero (NaN pixels).
                scaled = (img - lo) / span if span > 0 else np.zeros_like(img)
                img = (scaled * 255.0).astype(np.uint8)

        label_str = f" ({names_map[label]})" if names_map is not None else ""
        ax.set_title(f"{tag} Label: {label}{label_str}")
        ax.imshow(img)

Load packages

Packages necessary for this section:

# general array manipulation
import numpy as np

# for image transformation and augmentation
from torchvision.transforms import v2 as T

# to create data loaders
import grain.python as grain

Data normalization and augmentation

Let’s preprocess our images to match the methods used for the vision transformer (ViT) introduced by Dosovitskiy et al. [1] and implemented in JAX. This will be useful when we fine-tune this model on the Food-101 dataset in another section.

With TorchVision, the training images are randomly augmented (to reduce overfitting), and all images are resized and normalized:

# Side length, in pixels, of the square images produced by the crop/resize
# transforms.
img_size: int = 224

def to_np_array(pil_image):
    """Return the pixels of a PIL image as a NumPy array.

    The image is converted to ``"RGB"`` first, so grayscale or RGBA inputs
    still produce three colour channels.
    """
    rgb_image = pil_image.convert("RGB")
    return np.asarray(rgb_image)

def normalize(image):
    """Scale 0-255 pixel values to float32 in [-1, 1].

    Pixels are first divided by 255. The per-channel mean (0.5) is then
    subtracted and the result divided by the per-channel std (0.5). The mean
    and std are length-3 arrays, so the last axis of ``image`` must broadcast
    against three channels.
    """
    channel_mean = np.full(3, 0.5, dtype=np.float32)
    channel_std = np.full(3, 0.5, dtype=np.float32)
    unit_range = image.astype(np.float32) / 255.0
    centered = unit_range - channel_mean
    return centered / channel_std

# Training pipeline: random resized crop, horizontal flip and colour jitter for
# augmentation, then conversion to a normalized float32 NumPy array.
tv_train_transforms = T.Compose(
    transforms=[
        T.RandomResizedCrop(size=(img_size, img_size), scale=(0.7, 1.0)),
        T.RandomHorizontalFlip(p=0.5),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        T.Lambda(to_np_array),
        T.Lambda(normalize),
    ]
)

# Evaluation pipeline: a deterministic resize (no augmentation), then the same
# array conversion and normalization as the training pipeline.
tv_test_transforms = T.Compose(
    transforms=[
        T.Resize(size=(img_size, img_size)),
        T.Lambda(to_np_array),
        T.Lambda(normalize),
    ]
)

def get_transform(fn):
    """Turn a per-image transform into a batch transform for ``with_transform``.

    The returned callable takes a batch dict whose "image" and "label" entries
    are sequences. It replaces every image with ``fn(image)`` and every raw
    label with its index in the module-level ``labels_mapping``. It mutates
    the dict in place and returns it.
    """
    def apply_to_batch(examples):
        examples["image"] = list(map(fn, examples["image"]))
        examples["label"] = [labels_mapping[raw] for raw in examples["label"]]
        return examples

    return apply_to_batch

# Wrap each TorchVision pipeline as a batch transform, then attach it to its
# split; with_transform applies the function when items are accessed.
train_transforms, val_transforms = (
    get_transform(tv_train_transforms),
    get_transform(tv_test_transforms),
)

train_dataset, val_dataset = (
    train_dataset.with_transform(train_transforms),
    val_dataset.with_transform(val_transforms),
)

Data loaders

We use Grain to create efficient data loaders:

seed = 12
train_batch_size = 32
val_batch_size = train_batch_size * 2

# Training indices: shuffled with a fixed seed, single epoch, no sharding.
train_sampler = grain.IndexSampler(
    len(train_dataset),
    shard_options=grain.NoSharding(),
    shuffle=True,
    num_epochs=1,
    seed=seed,
)

# Validation indices: kept in dataset order, single epoch, no sharding.
val_sampler = grain.IndexSampler(
    len(val_dataset),
    shard_options=grain.NoSharding(),
    shuffle=False,
    num_epochs=1,
    seed=seed,
)

# The training loader drops the final incomplete batch so every training
# batch holds exactly train_batch_size items. The validation loader keeps it
# so no validation example is skipped.
train_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,
    operations=[grain.Batch(train_batch_size, drop_remainder=True)],
    worker_count=4,
    worker_buffer_size=2,
)

val_loader = grain.DataLoader(
    data_source=val_dataset,
    sampler=val_sampler,
    operations=[grain.Batch(val_batch_size, drop_remainder=False)],
    worker_count=4,
    worker_buffer_size=2,
)

Inspect batches

train_batch = next(iter(train_loader))
val_batch = next(iter(val_loader))

# Report the shape and dtype of the stacked images and labels in each batch.
print("Training batch info:",
      train_batch["image"].shape, train_batch["image"].dtype,
      train_batch["label"].shape, train_batch["label"].dtype)

print("Validation batch info:",
      val_batch["image"].shape, val_batch["image"].dtype,
      val_batch["label"].shape, val_batch["label"].dtype)
Training batch info: (32, 224, 224, 3) float32 (32,) int64
Validation batch info: (64, 224, 224, 3) float32 (64,) int64

Display the first three training and validation items:

# Titles show the contiguous index and, via inv_labels_mapping, the original
# Food-101 class name.
display_datapoints(
    *((train_batch["image"][k], train_batch["label"][k]) for k in range(3)),
    tag="(Training) ",
    names_map={
        idx: train_dataset.features["label"].names[raw_label]
        for idx, raw_label in inv_labels_mapping.items()
    },
)

display_datapoints(
    *((val_batch["image"][k], val_batch["label"][k]) for k in range(3)),
    tag="(Validation) ",
    names_map={
        idx: val_dataset.features["label"].names[raw_label]
        for idx, raw_label in inv_labels_mapping.items()
    },
)
/tmp/ipykernel_24733/1098619187.py:34: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)
  img = ((img - img.min()) / (img.max() - img.min()) * 255.0).astype(np.uint8)

References

1. Dosovitskiy A, Beyer L, Kolesnikov A, et al. (2021) An image is worth 16x16 words: Transformers for image recognition at scale. International Conference on Learning Representations (ICLR). arXiv:2010.11929