Loading data
Transform the data
We use the torchvision transforms v2 API (`torchvision.transforms.v2`):
import numpy as np
from torchvision.transforms import v2 as T
# Side length, in pixels, of the square images produced by both the training
# (RandomResizedCrop) and evaluation (Resize) pipelines below.
img_size = 224
def to_np_array(pil_image):
    """Convert a PIL-style image to a ``numpy`` array in RGB mode.

    ``convert("RGB")`` is applied first so that grayscale, palette or RGBA
    inputs all come out with the same three-channel, channel-last layout that
    ``normalize`` broadcasts against.

    Args:
        pil_image: Object with a ``convert(mode)`` method (e.g. a PIL image).

    Returns:
        ``np.asarray`` of the RGB-converted image.
    """
    return np.asarray(pil_image.convert("RGB"))
def normalize(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Scale pixel values to [0, 1], then standardize each channel.

    Image preprocessing matches the one of the pretrained ViT: with the default
    ``mean = std = 0.5`` per channel, inputs in [0, 255] map to [-1, 1].

    Args:
        image: Array-like of pixel values, assumed to be in [0, 255] with
            channels on the last axis (e.g. the output of ``to_np_array``).
            ``mean``/``std`` broadcast against that last axis.
        mean: Per-channel mean subtracted after scaling to [0, 1].
        std: Per-channel standard deviation divided out after centering.

    Returns:
        A ``float32`` ``numpy`` array (same shape as ``image`` when its last
        axis has one entry per channel in ``mean``/``std``).
    """
    # Tuples as defaults avoid shared mutable state; convert once per call.
    mean = np.asarray(mean, dtype=np.float32)
    std = np.asarray(std, dtype=np.float32)
    image = np.asarray(image).astype(np.float32) / 255.0
    return (image - mean) / std
# Training pipeline: random augmentation on the PIL image, then conversion to a
# normalized float32 numpy array via ``to_np_array`` and ``normalize``.
tv_train_transforms = T.Compose([
    # Random crop covering 70%-100% of the image area, resized to img_size^2.
    T.RandomResizedCrop((img_size, img_size), scale=(0.7, 1.0)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])
# Evaluation pipeline: deterministic resize only (no augmentation), followed by
# the same numpy conversion and normalization used for training.
tv_test_transforms = T.Compose([
    T.Resize((img_size, img_size)),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])
def get_transform(fn, label_map=None):
    """Build a batch-level transform that applies ``fn`` to every image.

    The returned ``wrapper`` is what gets passed to ``with_transform`` below.
    It takes a batch dict holding ``"image"`` and ``"label"`` lists. It replaces
    each image with ``fn(image)`` and each label with ``label_map[label]``, and
    returns the same (mutated) dict.

    Args:
        fn: Per-image callable, e.g. ``tv_train_transforms``.
        label_map: Indexable mapping from original label to new label. When
            ``None`` (default), the module-level ``labels_mapping`` is used,
            looked up each time the wrapper runs.

    Returns:
        The ``wrapper(batch)`` callable.
    """
    def wrapper(batch):
        mapping = labels_mapping if label_map is None else label_map
        batch["image"] = [fn(pil_image) for pil_image in batch["image"]]
        # map label index between 0 - 19
        batch["label"] = [mapping[label] for label in batch["label"]]
        return batch

    return wrapper
# Wrap the per-image pipelines into batch-level transforms.
train_transforms = get_transform(tv_train_transforms)
val_transforms = get_transform(tv_test_transforms)

# Attach them to the datasets. ``train_dataset``/``val_dataset`` are defined
# outside this section. NOTE(review): these are presumably Hugging Face
# ``datasets`` objects, where ``with_transform`` applies the transform lazily
# on item access -- confirm.
train_dataset = train_dataset.with_transform(train_transforms)
val_dataset = val_dataset.with_transform(val_transforms)
Visualize a few samples
import matplotlib.pyplot as plt
def display_datapoints(*datapoints, tag="", names_map=None):
    """Plot datapoints side by side in one row, titled with their labels.

    Args:
        *datapoints: Each is either a dict with ``"image"`` and ``"label"``
            keys or an ``(image, label)`` pair. Images with a floating numpy
            dtype (e.g. the output of ``normalize``) are min-max rescaled to
            ``uint8`` before plotting.
        tag: Prefix prepended to every subplot title.
        names_map: Optional indexable ``label -> class name``; when given, the
            name is appended to each title.

    Returns None. Does nothing when called with no datapoints.
    """
    num_samples = len(datapoints)
    if num_samples == 0:
        # plt.subplots(1, 0) would raise; there is nothing to show anyway.
        return

    # squeeze=False keeps ``axs`` 2-D even for a single subplot, so
    # ``axs[0, i]`` is always valid (``axs[i]`` fails when num_samples == 1).
    _, axs = plt.subplots(1, num_samples, figsize=(20, 10), squeeze=False)
    for i, datapoint in enumerate(datapoints):
        if isinstance(datapoint, dict):
            img, label = datapoint["image"], datapoint["label"]
        else:
            img, label = datapoint

        # imshow expects floats in [0, 1] or ints in [0, 255]; normalized
        # images lie outside that range, so rescale them to uint8.
        dtype = getattr(img, "dtype", None)
        if isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.floating):
            lo, hi = img.min(), img.max()
            span = hi - lo
            if span > 0:
                img = ((img - lo) / span * 255.0).astype(np.uint8)
            else:
                # Constant image: avoid 0/0 -> NaN; render it as black.
                img = np.zeros(img.shape, dtype=np.uint8)

        # NOTE(review): if labels were remapped (see get_transform), names_map
        # must be indexed by the *remapped* label -- confirm they agree.
        label_str = f" ({names_map[label]})" if names_map is not None else ""
        axs[0, i].set_title(f"{tag}Label: {label}{label_str}")
        axs[0, i].imshow(img)
# Preview four training samples spread across the dataset.
display_datapoints(
    train_dataset[0],
    train_dataset[1000],
    train_dataset[2000],
    train_dataset[3000],
    tag="(Training) ",
    names_map=train_dataset.features["label"].names,
)
# Preview validation samples, including the last one in the split.
display_datapoints(
    val_dataset[0],
    val_dataset[1000],
    val_dataset[2000],
    val_dataset[-1],
    tag="(Validation) ",
    names_map=val_dataset.features["label"].names,
)