The Dataset class

Author

Marie-Hélène Burle

base_dir = "<path-of-the-nabirds-dir>"

Replace this with the actual path to your nabirds directory.

import os
import polars as pl

img_dir = os.path.join(base_dir, "images")

bb_file = os.path.join(base_dir, "bounding_boxes.txt")
classes_translation_file = os.path.join(base_dir, "classes_fixed.txt")
class_labels_file = os.path.join(base_dir, "image_class_labels.txt")
img_file = os.path.join(base_dir, "images.txt")
photographers_file = os.path.join(base_dir, "photographers_fixed.txt")
sizes_file = os.path.join(base_dir, "sizes.txt")
train_test_split_file = os.path.join(base_dir, "train_test_split.txt")

bb = pl.read_csv(
    bb_file,
    separator=" ",
    has_header=False,
    new_columns=["UUID", "bb_x", "bb_y", "bb_width", "bb_height"]
)

classes = pl.read_csv(
    class_labels_file,
    separator=" ",
    has_header=False,
    new_columns=["UUID", "class"]
)

classes_translation = pl.read_csv(
    classes_translation_file,
    separator=" ",
    has_header=False,
    new_columns=["class", "id"]
)

img_paths = pl.read_csv(
    img_file,
    separator=" ",
    has_header=False,
    new_columns=["UUID", "path"]
)

photographers = pl.read_csv(
    photographers_file,
    separator=" ",
    has_header=False,
    new_columns=["UUID", "photographer"]
)

sizes = pl.read_csv(
    sizes_file,
    separator=" ",
    has_header=False,
    new_columns=["UUID", "img_width", "img_height"]
)

train_test_split = pl.read_csv(
    train_test_split_file,
    separator=" ",
    has_header=False,
    new_columns=["UUID", "is_training_img"]
)

classes_metadata = (
    classes.join(classes_translation, on="class")
)

metadata = (
    bb.join(classes_metadata, on="UUID")
    .join(img_paths, on="UUID")
    .join(photographers, on="UUID")
    .join(sizes, on="UUID")
    .join(train_test_split, on="UUID")
)

metadata_train = metadata.filter(pl.col("is_training_img") == 1)

Create a class for our dataset

There are many options to read in the images.

Here, we are using imread from imageio (through its v3 API, imported below as iio). It is an excellent option because it automatically creates a NumPy ndarray, choosing a dtype based on the image, and it is faster than many alternatives (scikit-image now uses it instead of its own implementation).

A dataset class like this one can be built with PyTorch, TensorFlow Datasets, Hugging Face Datasets, or any tool you are used to. Here, I show how to do it with PyTorch.

PyTorch provides torch.utils.data.Dataset, an abstract class representing a dataset. To use it with our data, we write a subclass of it (let’s call it NABirdsDataset) that inherits its behaviour but is tailored to the characteristics of our own dataset.
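To make this protocol concrete, here is a torch-free sketch in plain Python. ToyDataset is a hypothetical class, not part of PyTorch: it only mirrors the methods that a map-style torch.utils.data.Dataset subclass provides (__len__ and __getitem__), plus an optional transform hook like the one used in NABirdsDataset.

```python
class ToyDataset:
    """Hypothetical stand-in mirroring the map-style dataset protocol (no torch needed)."""

    def __init__(self, records, transform=None):
        # Store the data (real datasets usually store metadata and file paths)
        self.records = records
        self.transform = transform

    def __len__(self):
        # Number of samples in the dataset
        return len(self.records)

    def __getitem__(self, idx):
        # Build one sample; list indexing raises IndexError past the end
        sample = {"value": self.records[idx]}
        if self.transform:
            sample = self.transform(sample)
        return sample


toy = ToyDataset(["a", "b", "c"])
print(len(toy))                   # 3
print(toy[1])                     # {'value': 'b'}
# No __iter__ is defined, so a for loop falls back to calling
# __getitem__ with 0, 1, 2, ... until IndexError is raised
print([s["value"] for s in toy])  # ['a', 'b', 'c']
```

The same fallback applies to map-style PyTorch datasets, since torch.utils.data.Dataset does not define __iter__: this is what allows a plain for loop over a dataset object.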

Load the packages:

from torch.utils.data import Dataset, DataLoader
import imageio.v3 as iio
import matplotlib.pyplot as plt
import matplotlib.patches as patches  # used to draw the bounding boxes

A PyTorch custom Dataset class must implement three methods:

  • __init__: initializes a new instance (object) of the class,
  • __len__: returns the number of samples in the dataset, and
  • __getitem__: loads and returns a sample from the dataset at a given index idx:
class NABirdsDataset(Dataset):
    """NABirds dataset class."""

    def __init__(self, metadata, data_dir, transform=None):
        # Polars DataFrame with one row per image (path, label, bounding box...)
        self.metadata = metadata
        # Directory the image paths in the metadata are relative to
        self.data_dir = data_dir
        # Optional callable applied to each sample
        self.transform = transform

    def __len__(self):
        # One sample per row of the metadata
        return len(self.metadata)

    def __getitem__(self, idx):
        # Read the image as a NumPy ndarray
        img_path = os.path.join(
            self.data_dir,
            self.metadata.get_column("path")[idx]
        )
        img = iio.imread(img_path)
        # Species name and photographer, with underscores replaced by spaces
        img_id = self.metadata.get_column("id")[idx].replace("_", " ")
        img_photographer = self.metadata.get_column("photographer")[idx].replace("_", " ")
        # Bounding box: (x, y) of the top-left corner, then width and height
        img_bb_x = self.metadata.get_column("bb_x")[idx]
        img_bb_y = self.metadata.get_column("bb_y")[idx]
        img_bb_width = self.metadata.get_column("bb_width")[idx]
        img_bb_height = self.metadata.get_column("bb_height")[idx]
        sample = {
            "image": img,
            "id": img_id,
            "photographer": img_photographer,
            "bb": (img_bb_x, img_bb_y, img_bb_width, img_bb_height)
        }
        # Apply the optional transform to the whole sample
        if self.transform:
            sample = self.transform(sample)
        return sample
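The transform argument accepts any callable that takes a sample dictionary and returns a (possibly modified) one. As a minimal sketch, here is a hypothetical crop_to_bb function that crops the image to its bounding box, assuming bb holds (x, y, width, height) in pixels with (x, y) the top-left corner. A zero-filled NumPy array stands in for an image read by iio.imread.

```python
import numpy as np


def crop_to_bb(sample):
    """Hypothetical transform: crop the image to its bounding box.

    Assumes sample["bb"] is (x, y, width, height) in pixels,
    with (x, y) the top-left corner.
    """
    x, y, w, h = (int(v) for v in sample["bb"])
    # NumPy images are indexed [row, column], i.e. [y, x]
    cropped = sample["image"][y:y + h, x:x + w]
    # The box now covers the whole cropped image
    return {**sample, "image": cropped, "bb": (0, 0, w, h)}


# Fake 100 x 200 RGB image in place of a real NABirds picture
sample = {
    "image": np.zeros((100, 200, 3), dtype=np.uint8),
    "id": "Some species",
    "photographer": "Someone",
    "bb": (50, 10, 80, 60),
}
out = crop_to_bb(sample)
print(out["image"].shape)  # (60, 80, 3)
```

It would be passed as NABirdsDataset(metadata_train, img_dir, transform=crop_to_bb). Note that cropping alone does not make all images the same size, since the bounding boxes vary as well.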

Instantiate our dataset class

nabirds_train = NABirdsDataset(
    metadata_train,
    img_dir  # img_dir already includes base_dir: no need to join them again
)

Display a data sample

Let’s display the first 4 images and their bounding boxes:

fig = plt.figure(figsize=(16, 5))

for i in range(4):
    sample = nabirds_train[i]
    # Image shape: (height, width, channels)
    print(i, sample['image'].shape)
    ax = plt.subplot(1, 4, i + 1)
    ax.set_title(
        f"Sample {i}\nIdentification: {sample['id']}\nPicture by {sample['photographer']}",
        fontsize=8
    )
    ax.axis('off')
    ax.imshow(sample['image'])
    # Bounding box: (x, y) is the top-left corner, then width and height
    rect = patches.Rectangle(
        (sample['bb'][0], sample['bb'][1]),
        sample['bb'][2],
        sample['bb'][3],
        linewidth=2,
        edgecolor='r',
        facecolor='none'
    )
    ax.add_patch(rect)

plt.tight_layout()
plt.show()

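Notice how the images all have different sizes (see the printed shapes). This is a problem: batching samples together, which a DataLoader does by default, requires them to have the same shape. We are also not making use of the bounding boxes that come with this dataset, so we unnecessarily keep parts of the images that we know do not contain any bird.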
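To see why different sizes are a problem, here is a small NumPy sketch. Stacking arrays into one batch (essentially what PyTorch's default collate function does with torch.stack) only works when all shapes match. The two zero-filled arrays are fake images standing in for the output of iio.imread, and can_stack is a hypothetical helper.

```python
import numpy as np

# Two fake RGB images of different sizes: (height, width, channels)
img_a = np.zeros((480, 640, 3), dtype=np.uint8)
img_b = np.zeros((600, 800, 3), dtype=np.uint8)


def can_stack(images):
    """Return True if the arrays can be stacked into a single batch array."""
    try:
        np.stack(images)
    except ValueError:
        # np.stack requires all input arrays to have the same shape
        return False
    return True


print(can_stack([img_a, img_b]))        # False
print(np.stack([img_a, img_a]).shape)   # (2, 480, 640, 3)
```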

We will address these problems in the next section.