Preprocessing data
This section covers an example of the second step of a classic workflow: preprocessing the data.
Context
There are many tools and options. In this example, we use TorchVision to transform and augment images and Grain to create data loaders.
Load packages
Packages necessary for this section:
# general array manipulation
import numpy as np
# for image transformation and augmentation
from torchvision.transforms import v2 as T
# to create data loaders
import grain.python as grain
Data normalization and augmentation
Let’s preprocess our images to match the preprocessing used for the vision transformer (ViT) introduced by Dosovitskiy et al. [1] and implemented in JAX. This will be useful when we fine-tune this model on the Food dataset in another section.
The preprocessing consists of normalization and random augmentation (to reduce overfitting), built with TorchVision transforms:
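img_size = 224

def to_np_array(pil_image):
    # Convert a PIL image to a (height, width, 3) uint8 NumPy array
    return np.asarray(pil_image.convert("RGB"))

def normalize(image):
    # Scale pixels to [0, 1], then to [-1, 1] using mean = std = 0.5 per channel
    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    image = image.astype(np.float32) / 255.0
    return (image - mean) / std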
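# Training: random crop, horizontal flip, and color jitter
# (brightness, contrast, saturation) for augmentation, then normalization
tv_train_transforms = T.Compose([
    T.RandomResizedCrop((img_size, img_size), scale=(0.7, 1.0)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.2, 0.2, 0.2),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])

# Evaluation: deterministic resize, then normalization
tv_test_transforms = T.Compose([
    T.Resize((img_size, img_size)),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])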
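With a per-channel mean and standard deviation of 0.5, `normalize` first scales pixel values from [0, 255] to [0, 1] and then maps them to [-1, 1]. The minimal sketch below repeats the same arithmetic on a few hand-picked pixels. The 2x2 test image is made up for illustration.

```python
import numpy as np

# Same constants as in `normalize`: 0.5 for each RGB channel
mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
std = np.array([0.5, 0.5, 0.5], dtype=np.float32)

def normalize(image):
    image = image.astype(np.float32) / 255.0  # [0, 255] -> [0, 1]
    return (image - mean) / std                # [0, 1] -> [-1, 1]

# Hypothetical 2x2 RGB image, laid out as (height, width, channels), uint8 pixels
image = np.array(
    [[[0, 0, 0], [255, 255, 255]],
     [[51, 102, 204], [255, 0, 255]]],
    dtype=np.uint8,
)

out = normalize(image)
print(out.shape, out.dtype)  # (2, 2, 3) float32
print(out[0, 0], out[0, 1])  # black pixel maps to -1, white pixel maps to 1
```

Because mean and std have shape `(3,)`, NumPy broadcasts them over the last (channel) axis of the channels-last arrays returned by `to_np_array`.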
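def get_transform(fn):
    # Wrap a per-image transform so it applies to a whole batch (a dict of lists).
    # `labels_mapping` is not defined in this section; it must exist beforehand.
    def wrapper(batch):
        batch["image"] = [
            fn(pil_image) for pil_image in batch["image"]
        ]
        batch["label"] = [
            labels_mapping[label] for label in batch["label"]
        ]
        return batch
    return wrapper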
train_transforms = get_transform(tv_train_transforms)
val_transforms = get_transform(tv_test_transforms)

train_dataset = train_dataset.with_transform(train_transforms)
val_dataset = val_dataset.with_transform(val_transforms)
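`with_transform` applies the wrapper on the fly, when items are accessed. It passes the wrapper a batch as a dictionary that maps column names to lists of values. The sketch below runs the same `get_transform` logic on a hand-made batch. The toy transform `double` and the two-entry `labels_mapping` are hypothetical stand-ins for the TorchVision pipeline and the real label mapping.

```python
def get_transform(fn):
    # Returns a function that transforms a whole batch (a dict of lists)
    def wrapper(batch):
        # Apply the per-item transform to every image in the batch
        batch["image"] = [fn(img) for img in batch["image"]]
        # Remap original label ids to new indices
        batch["label"] = [labels_mapping[label] for label in batch["label"]]
        return batch
    return wrapper

# Hypothetical stand-ins, for illustration only
labels_mapping = {7: 0, 42: 1}  # original label id -> new index

def double(x):  # stands in for tv_train_transforms
    return 2 * x

batch = {"image": [1, 2, 3], "label": [42, 7, 42]}
out = get_transform(double)(batch)
print(out)  # {'image': [2, 4, 6], 'label': [1, 0, 1]}
```

Note that the wrapper modifies the incoming batch in place and returns that same dictionary.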
Data loaders
We use Grain to create efficient data loaders:
seed = 12
train_batch_size = 32
val_batch_size = 2 * train_batch_size
train_sampler = grain.IndexSampler(
    len(train_dataset),
    shuffle=True,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)

val_sampler = grain.IndexSampler(
    len(val_dataset),
    shuffle=False,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)
train_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,
    worker_count=4,
    worker_buffer_size=2,
    operations=[
        grain.Batch(train_batch_size, drop_remainder=True),
    ],
)

val_loader = grain.DataLoader(
    data_source=val_dataset,
    sampler=val_sampler,
    worker_count=4,
    worker_buffer_size=2,
    operations=[
        grain.Batch(val_batch_size),
    ],
)
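The two loaders differ in shuffling and in how they handle the final partial batch. For training, `drop_remainder=True` discards it. The validation loader keeps it (`grain.Batch` defaults to `drop_remainder=False`), so every validation image is used. The plain-Python sketch below emulates one epoch of index sampling followed by batching to show the effect. It is a simplified model for illustration, not Grain's implementation, and the dataset size of 100 is made up.

```python
import random

def emulate_epoch(num_records, batch_size, shuffle, seed, drop_remainder):
    """Simplified model of one epoch of IndexSampler followed by Batch."""
    indices = list(range(num_records))
    if shuffle:
        # Grain has its own shuffling scheme; random.Random is only a stand-in
        random.Random(seed).shuffle(indices)
    batches = [indices[i:i + batch_size] for i in range(0, num_records, batch_size)]
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the incomplete final batch
    return batches

num_records = 100  # hypothetical dataset size
train = emulate_epoch(num_records, 32, shuffle=True, seed=12, drop_remainder=True)
val = emulate_epoch(num_records, 64, shuffle=False, seed=12, drop_remainder=False)

print([len(b) for b in train])  # [32, 32, 32]: 4 records are skipped this epoch
print([len(b) for b in val])    # [64, 36]: the partial batch is kept
```

Dropping the remainder during training keeps every batch the same shape. This is convenient for JIT-compiled training steps, which otherwise recompile for a new input shape.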
Inspect batches
train_batch = next(iter(train_loader))
val_batch = next(iter(val_loader))

print(
    "Training batch info:",
    train_batch["image"].shape,
    train_batch["image"].dtype,
    train_batch["label"].shape,
    train_batch["label"].dtype,
)
print(
    "Validation batch info:",
    val_batch["image"].shape,
    val_batch["image"].dtype,
    val_batch["label"].shape,
    val_batch["label"].dtype,
)
Training batch info: (32, 224, 224, 3) float32 (32,) int64
Validation batch info: (64, 224, 224, 3) float32 (64,) int64
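The printed shapes follow a channels-last (batch, height, width, channels) layout, with 32 training and 64 validation images per batch. A small check like the hypothetical `check_batch` below can guard these expectations. It also confirms that normalized pixels stay in [-1, 1]. Here it runs on a synthetic batch, not on real loader output.

```python
import numpy as np

def check_batch(batch, batch_size, img_size=224):
    """Hypothetical sanity check for a batch produced by the loaders above."""
    images, labels = batch["image"], batch["label"]
    assert images.shape == (batch_size, img_size, img_size, 3), images.shape
    assert images.dtype == np.float32, images.dtype
    assert labels.shape == (batch_size,), labels.shape
    # normalize() maps uint8 pixels into [-1, 1]
    assert images.min() >= -1.0 and images.max() <= 1.0
    return True

# Synthetic stand-in for `train_batch`
rng = np.random.default_rng(0)
fake_batch = {
    "image": rng.uniform(-1.0, 1.0, size=(32, 224, 224, 3)).astype(np.float32),
    "label": rng.integers(0, 10, size=(32,)),
}
print(check_batch(fake_batch, batch_size=32))  # True
```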
Display the first three training and validation items:
# `display_datapoints` and `inv_labels_mapping` are not defined in this section
display_datapoints(
    *[(train_batch["image"][i], train_batch["label"][i]) for i in range(3)],
    tag="(Training) ",
    names_map={
        k: train_dataset.features["label"].names[v]
        for k, v in inv_labels_mapping.items()
    },
)

display_datapoints(
    *[(val_batch["image"][i], val_batch["label"][i]) for i in range(3)],
    tag="(Validation) ",
    names_map={
        k: val_dataset.features["label"].names[v]
        for k, v in inv_labels_mapping.items()
    },
)
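`names_map` chains two lookups. First, `inv_labels_mapping` (as its name suggests) takes a remapped index back to the dataset's original label id. Then `features["label"].names` turns that id into a class name. The sketch below reproduces the chain with hypothetical mappings and placeholder class names. In the real code, `names` comes from the dataset's `ClassLabel` feature.

```python
# Placeholder class names, standing in for train_dataset.features["label"].names
names = ["class_0", "class_1", "class_2", "class_3"]

# Hypothetical mapping used by get_transform: original label id -> new index
labels_mapping = {1: 0, 3: 1}
# Its inverse: new index -> original label id
inv_labels_mapping = {v: k for k, v in labels_mapping.items()}

# Same construction as in the display code: new index -> class name
names_map = {k: names[v] for k, v in inv_labels_mapping.items()}
print(names_map)  # {0: 'class_1', 1: 'class_3'}
```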
References
[1] Dosovitskiy A, Beyer L, Kolesnikov A, et al. (2021) An image is worth 16x16 words: Transformers for image recognition at scale.