Data augmentation

Author

Marie-Hélène Burle

base_dir = "<path-of-the-nabirds-dir>"

Replace <path-of-the-nabirds-dir> with the actual path to your NABirds directory.

import os
import polars as pl
import imageio.v3 as iio
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from skimage.transform import resize
import numpy as np
import grain.python as grain
import jax.numpy as jnp

img_dir = os.path.join(base_dir, "images")

bb_file = os.path.join(base_dir, "bounding_boxes.txt")
classes_translation_file = os.path.join(base_dir, "classes_fixed.txt")
class_labels_file = os.path.join(base_dir, "image_class_labels.txt")
img_file = os.path.join(base_dir, "images.txt")
photographers_file = os.path.join(base_dir, "photographers_fixed.txt")
sizes_file = os.path.join(base_dir, "sizes.txt")
train_test_split_file = os.path.join(base_dir, "train_test_split.txt")

bb = pl.read_csv(
    bb_file,
    separator=" ",
    has_header=False,
    new_columns=["UUID", "bb_x", "bb_y", "bb_width", "bb_height"]
)

classes = pl.read_csv(
    class_labels_file,
    separator=" ",
    has_header=False,
    new_columns=["UUID", "class"]
)

classes_translation = pl.read_csv(
    classes_translation_file,
    separator=" ",
    has_header=False,
    new_columns=["class", "id"]
)

img_paths = pl.read_csv(
    img_file,
    separator=" ",
    has_header=False,
    new_columns=["UUID", "path"]
)

photographers = pl.read_csv(
    photographers_file,
    separator=" ",
    has_header=False,
    new_columns=["UUID", "photographer"]
)

sizes = pl.read_csv(
    sizes_file,
    separator=" ",
    has_header=False,
    new_columns=["UUID", "img_width", "img_height"]
)

train_test_split = pl.read_csv(
    train_test_split_file,
    separator=" ",
    has_header=False,
    new_columns=["UUID", "is_training_img"]
)

classes_metadata = (
    classes.join(classes_translation, on="class")
)

metadata = (
    bb.join(classes_metadata, on="UUID")
    .join(img_paths, on="UUID")
    .join(photographers, on="UUID")
    .join(sizes, on="UUID")
    .join(train_test_split, on="UUID")
)

metadata_train = metadata.filter(pl.col("is_training_img") == 1)
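
As a quick optional check, we can preview the joined training metadata and confirm how many training images it contains:

print(metadata_train.head())
print(f"Number of training images: {metadata_train.height}")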

class NABirdsDataset:
    """NABirds dataset class."""

    def __init__(self, metadata, data_dir):
        # metadata: polars DataFrame with one row of metadata per image
        # data_dir: directory containing the image files
        self.metadata = metadata
        self.data_dir = data_dir

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        # Load the image from disk
        img_path = os.path.join(
            self.data_dir,
            self.metadata.get_column('path')[idx]
        )
        img = iio.imread(img_path)
        # Replace underscores by spaces in the class identification and photographer name
        img_id = self.metadata.get_column('id')[idx].replace('_', ' ')
        img_photographer = self.metadata.get_column('photographer')[idx].replace('_', ' ')
        # Bounding box of the bird within the image
        img_bb_x = self.metadata.get_column('bb_x')[idx]
        img_bb_y = self.metadata.get_column('bb_y')[idx]
        img_bb_width = self.metadata.get_column('bb_width')[idx]
        img_bb_height = self.metadata.get_column('bb_height')[idx]
        sample = {
            'image': img,
            'id': img_id,
            'photographer': img_photographer,
            'bbx': img_bb_x,
            'bby': img_bb_y,
            'bbwidth': img_bb_width,
            'bbheight': img_bb_height
        }
        return sample

nabirds_train = NABirdsDataset(
    metadata_train,
    img_dir
)
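
Indexing the dataset returns a raw (untransformed) sample dictionary, which is an easy way to verify that images and metadata are matched up correctly:

sample = nabirds_train[0]
print(sample['id'], sample['photographer'], sample['image'].shape)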

class NormAndCast(grain.MapTransform):
    """Transform class to normalize and cast images to float32."""
    def map(self, element):
        element['image'] = jnp.array(element['image'], dtype=jnp.float32) / 255.0
        return element

class BbCrop(grain.MapTransform):
    """Transform class to crop images to their bounding boxes."""
    def map(self, element):
        img = element['image']
        bbx = element['bbx']
        bby = element['bby']
        bbwidth = element['bbwidth']
        bbheight = element['bbheight']
        img_cropped = img[bby:bby+bbheight, bbx:bbx+bbwidth]
        element['image'] = img_cropped
        return element

target = (224, 224)

class PaddingResize(grain.MapTransform):
    """Transform class to resize images to a given size with padding to avoid distortion."""
    def map(self, element):
        img = element['image']
        h, w, _ = img.shape
        target_h, target_w = target

        # Calculate the scaling factor to fit the image inside the box
        scale = min(target_h / h, target_w / w)

        # Calculate the new dimensions of the image
        new_h = int(h * scale)
        new_w = int(w * scale)

        # Resize the image to these new dimensions
        img_resized = resize(img, (new_h, new_w), anti_aliasing=True)

        # Create a black canvas (zeros) of the target size
        out_img = np.zeros((target_h, target_w, img.shape[2]), dtype=img_resized.dtype)

        # Place the resized image in the center of the canvas
        y_offset = (target_h - new_h) // 2
        x_offset = (target_w - new_w) // 2
        out_img[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = img_resized

        element['image'] = out_img
        return element

transformations = [NormAndCast(), BbCrop(), PaddingResize()]
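
Although Grain applies these transforms for us, we can also chain them by hand on a single sample to check the output, since each transform simply implements a map method:

sample = nabirds_train[0]
for t in transformations:
    sample = t.map(sample)

print(sample['image'].shape)  # (224, 224, 3) for an RGB image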

What is data augmentation?

Data augmentation artificially enlarges a training set by applying label-preserving transformations (flips, crops, rotations, brightness changes, etc.) to existing samples. This exposes the model to more variation than the raw data contains and helps reduce overfitting. The AlbumentationsX site has a good explanation of the concept.

cite (from bib tex file): - https://arxiv.org/abs/2205.01491

Tools

PIX

cite (from bib tex file): - paper on libraries

TorchVision transforms

skimage.transform from scikit-image (which we used earlier to create the Transform that resizes our images with padding).
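
As a quick illustration of PIX, here is a minimal sketch (using only deterministic functions from dm_pix; the random variants are covered below) applied to one of our samples:

import dm_pix as pix
import jax.numpy as jnp

sample = nabirds_train[0]
img = jnp.array(sample['image'], dtype=jnp.float32) / 255.0

img_flipped = pix.flip_left_right(img)          # mirror the image horizontally
img_brighter = pix.adjust_brightness(img, 0.2)  # add 0.2 to every pixel value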

Augmentation techniques

PIX list of augmentations

import dm_pix as pix
from jax import random

# class Augment(grain.MapTransform):
#     """Transform class to normalize and cast images to float32."""
#     def map(self, element):
#         img = element['image']
#         key = random.PRNGKey(0)
#         delta = 0.7
#         img_brightness = pix.random_brightness(
#             key=key,
#             image=img,
#             max_delta=delta
#         )
#         key = random.PRNGKey(1)
#         img_flip = pix.random_flip_left_right(
#             key=key,
#             image=img_brightness
#         )
#         element['image'] = img_flip
#         return element

class RandomFlip(grain.MapTransform):
    """Transform class to normalize and cast images to float32."""
    def map(self, element):
        img = element['image']
        key = random.PRNGKey(0)
        img_flip = pix.random_flip_left_right(
            key=key,
            image=img
        )
        element['image'] = img_flip
        return element
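
Because RandomFlip reuses the fixed key random.PRNGKey(0), every image receives the same flip decision. Grain also provides grain.RandomMapTransform, whose random_map method is given a per-record NumPy random generator; a minimal sketch of a per-record random flip (assuming we derive a JAX key from that generator) could look like this:

class RandomFlipPerRecord(grain.RandomMapTransform):
    """Flip each image left-right with an independent random decision per record."""
    def random_map(self, element, rng: np.random.Generator):
        # Derive a JAX PRNG key from the per-record NumPy generator
        key = random.PRNGKey(int(rng.integers(0, 2**31 - 1)))
        element['image'] = pix.random_flip_left_right(
            key=key,
            image=element['image']
        )
        return element

For the demonstration below we keep the simpler RandomFlip.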
transformations = [NormAndCast(), BbCrop(), PaddingResize(), RandomFlip()]

nabirds_train_seqsampler = grain.SequentialSampler(
    num_records=4,
    shard_options=grain.NoSharding()
)

nabirds_train_dl = grain.DataLoader(
    data_source=nabirds_train,
    operations=transformations,
    sampler=nabirds_train_seqsampler,
    worker_count=0
)

fig = plt.figure(figsize=(8, 8))

for i, element in enumerate(nabirds_train_dl):
    ax = plt.subplot(2, 2, i + 1)
    plt.tight_layout()
    ax.set_title(
        f"Element {i}\nIdentification: {element['id']}\nPicture by {element['photographer']}",
        fontsize=9
    )
    ax.axis('off')
    plt.imshow(element['image'])

plt.show()

nabirds_train_isampler = grain.IndexSampler(
    num_records=200,
    num_epochs=1,
    shard_options=grain.ShardOptions(shard_index=0, shard_count=1, drop_remainder=True),
    shuffle=True,
    seed=0
)

nabirds_train_dl = grain.DataLoader(
    data_source=nabirds_train,
    operations=transformations,
    sampler=nabirds_train_isampler,
    worker_count=0
)

fig = plt.figure(figsize=(8, 8))

for i, element in enumerate(nabirds_train_dl):
    ax = plt.subplot(2, 2, i + 1)
    plt.tight_layout()
    ax.set_title(
        f"Element {i}\nIdentification: {element['id']}\nPicture by {element['photographer']}",
        fontsize=9
    )
    ax.axis('off')
    plt.imshow(element['image'])
    if i == 3:
        plt.show()
        break