import os
import polars as pl
import imageio.v3 as iio

# base_dir must be defined before running this (see the "Training set" section).
metadata = pl.read_parquet("metadata.parquet")
metadata_train = metadata.filter(pl.col("is_training_img") == 1)
cleaned_img_dir = os.path.join(base_dir, "cleaned_images")

class NABirdsDataset:
    """NABirds dataset class."""

    def __init__(self, metadata_file, data_dir):
        self.metadata_file = metadata_file
        self.data_dir = data_dir

    def __len__(self):
        return len(self.metadata_file)

    def __getitem__(self, idx):
        path = os.path.join(self.data_dir, self.metadata_file.get_column('path')[idx])
        img = iio.imread(path)
        species = self.metadata_file.get_column('species')[idx].replace('_', ' ')
        # The subcategory can be missing, so only clean it when present.
        subcategory = self.metadata_file.get_column('subcategory')[idx]
        if subcategory is not None:
            subcategory = subcategory.replace('_', ' ')
        photographer = self.metadata_file.get_column('photographer')[idx].replace('_', ' ')
        element = {
            'img': img,
            'species': species,
            'subcategory': subcategory,
            'photographer': photographer,
        }
        return element

nabirds_train = NABirdsDataset(metadata_train, cleaned_img_dir)

DataLoaders
A critical part of deep learning is feeding data to the model during the training loop.
DataLoaders decide which samples to load and in what order, parallelize the loading by managing workers, and set several hyperparameters such as the batch size and the number of epochs.
In this section we explore DataLoaders with the Grain library [1].
Training set
Let’s use our Dataset class (simplified slightly because we don’t need the bounding boxes anymore) and create one instance just with the training set:
base_dir = "<path-of-the-nabirds-dir>"
To be replaced by the proper path.
Goal of DataLoaders
We can access elements of our Dataset class (as we did in the previous section) with:
for i, element in enumerate(nabirds_train):
    print(element['img'].shape)
    if i == 3:
        break

(312, 688, 3)
(739, 1024, 3)
(722, 808, 3)
(753, 896, 3)
And we can return the next element by creating an iterator from the dataset:
next(iter(nabirds_train))

{'img': array([[[123, 144, 171],
[124, 145, 172],
[125, 146, 173],
...,
[130, 154, 180],
[128, 152, 178],
[127, 151, 177]],
[[124, 145, 172],
[125, 146, 173],
[126, 147, 174],
...,
[127, 151, 177],
[128, 152, 178],
[128, 152, 178]],
[[126, 147, 174],
[127, 148, 175],
[127, 148, 175],
...,
[124, 148, 174],
[126, 150, 176],
[127, 151, 177]],
...,
[[100, 105, 124],
[100, 106, 122],
[100, 105, 124],
...,
[114, 131, 159],
[114, 131, 159],
[114, 131, 159]],
[[101, 106, 126],
[101, 106, 125],
[101, 106, 126],
...,
[115, 132, 160],
[115, 132, 160],
[115, 132, 160]],
[[102, 107, 127],
[102, 107, 127],
[102, 107, 127],
...,
[116, 133, 161],
[116, 133, 161],
[116, 133, 161]]], shape=(312, 688, 3), dtype=uint8),
'species': 'Eared Grebe',
'subcategory': 'Nonbreeding/juvenile',
'photographer': 'Laura Erickson'}
But accessing elements one at a time like this is extremely limited.
DataLoaders feed data to the model during training. They handle batching, shuffling, sharding across machines, the number of epochs, etc.
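To make these responsibilities concrete, here is a minimal pure-Python sketch of what a loader does with record indices. It is not how Grain is implemented: `simple_loader` and all of its parameters are hypothetical names chosen for illustration, and the sharding rule (each shard keeps every `shard_count`-th index) is just one simple assumption.

```python
import random

def simple_loader(records, batch_size, num_epochs=1, shuffle=False, seed=0,
                  shard_index=0, shard_count=1):
    """Yield batches (lists of records), one pass over the data per epoch."""
    rng = random.Random(seed)
    for _ in range(num_epochs):
        indices = list(range(len(records)))
        if shuffle:
            # New global order at each epoch; all shards must share the seed.
            rng.shuffle(indices)
        # Sharding: this shard (e.g. this machine) keeps every shard_count-th index.
        indices = indices[shard_index::shard_count]
        # Batching: cut the remaining indices into chunks of batch_size.
        for start in range(0, len(indices), batch_size):
            yield [records[i] for i in indices[start:start + batch_size]]

records = list(range(10))
batches = list(simple_loader(records, batch_size=4, num_epochs=2))
print(batches)
```

With 10 records, a batch size of 4, and 2 epochs, each epoch produces two full batches and one partial batch of 2 records.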
Grain DataLoaders
The JAX AI stack includes the Grain library [1] to create DataLoaders, but you can also create them with PyTorch, TensorFlow Datasets, Hugging Face Datasets, or any method you are used to. That’s the advantage of the modular philosophy that the stack relies on. Grain is very efficient and, unlike PyTorch and TensorFlow, does not rely on a huge set of dependencies.
In Grain, a DataLoader requires 3 components:
- A data source
- Transforms
- A sampler
Data source
We already have that: it is the instance of our Dataset class that we called nabirds_train.
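Our class qualifies because, as we understand Grain's random-access data source protocol, a source only needs to report its length and return a record for an integer index. The sketch below shows the same two methods on a toy source; `SquaresSource` and `looks_like_random_access_source` are hypothetical names for illustration, not part of Grain.

```python
class SquaresSource:
    """Toy random-access source: record i holds the square of i."""

    def __init__(self, n):
        self._n = n

    def __len__(self):
        # Number of records available in the source.
        return self._n

    def __getitem__(self, idx):
        # Return one record for an integer index, like NABirdsDataset does.
        if not 0 <= idx < self._n:
            raise IndexError(idx)
        return {"value": idx * idx}

def looks_like_random_access_source(obj):
    """Duck-typing check: the object supports len() and indexing."""
    return hasattr(obj, "__len__") and hasattr(obj, "__getitem__")

source = SquaresSource(5)
print(len(source), source[3], looks_like_random_access_source(source))
```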
Transformations
We need to split the data into batches. Batches can be defined with grain.Batch, which is passed to the DataLoader as a transformation.
Let’s use batches of 32 with grain.Batch(batch_size=32).
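To illustrate what batching does to dictionary records like ours, here is a pure-Python sketch. It only mimics the idea: `collate` and `num_batches` are hypothetical helpers, and they gather values into plain lists, whereas Grain, as far as we know, stacks array fields into arrays (which assumes the fields in a batch share the same shape).

```python
import math

def collate(records):
    """Turn a list of dict records into one dict of lists (one list per key)."""
    keys = records[0].keys()
    return {key: [record[key] for record in records] for key in keys}

def num_batches(num_records, batch_size, drop_remainder=False):
    """Number of batches in one pass over the data."""
    if drop_remainder:
        # The last, incomplete batch is discarded.
        return num_records // batch_size
    return math.ceil(num_records / batch_size)

batch = collate([
    {"species": "Eared Grebe", "photographer": "Laura Erickson"},
    {"species": "Eared Grebe", "photographer": "Someone Else"},  # made-up record
])
print(batch["species"])
print(num_batches(200, 32), num_batches(200, 32, drop_remainder=True))
```

With 200 records and batches of 32, one pass yields 6 full batches plus a final batch of 8 records, unless the remainder is dropped.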
Samplers
Sequential sampler
Grain comes with a basic sequential sampler.
import grain.python as grain

nabirds_train_seqsampler = grain.SequentialSampler(
    num_records=4
)

for record_metadata in nabirds_train_seqsampler:
    print(record_metadata)

RecordMetadata(index=0, record_key=0, rng=None)
RecordMetadata(index=1, record_key=1, rng=None)
RecordMetadata(index=2, record_key=2, rng=None)
RecordMetadata(index=3, record_key=3, rng=None)
Index sampler
Grain’s index sampler is the one you should use: it allows for global shuffling of the dataset, setting the number of epochs, etc.
nabirds_train_isampler = grain.IndexSampler(
    num_records=200,
    shuffle=True,
    seed=0
)

for i, record_metadata in enumerate(nabirds_train_isampler):
    print(record_metadata)
    if i == 3:
        break

RecordMetadata(index=0, record_key=134, rng=Generator(Philox))
RecordMetadata(index=1, record_key=133, rng=Generator(Philox))
RecordMetadata(index=2, record_key=10, rng=Generator(Philox))
RecordMetadata(index=3, record_key=136, rng=Generator(Philox))
Building DataLoaders
We now have our source, samplers, and transformation (batching), so we can experiment with Grain DataLoaders.
With sequential sampler
nabirds_train_dl = grain.DataLoader(
    data_source=nabirds_train,
    sampler=nabirds_train_seqsampler,
    worker_count=0
)

import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 9))
for i, element in enumerate(nabirds_train_dl):
    ax = plt.subplot(2, 2, i + 1)
    plt.tight_layout()
    ax.set_title(
        f"""
        Species: {element['species']}
        Additional information: {element['subcategory']}
        Picture by {element['photographer']}
        """,
        fontsize=9,
        linespacing=1.5
    )
    ax.axis('off')
    plt.imshow(element['img'])
plt.show()
Notice that, unlike last time we displayed some images, we aren’t looping through our Dataset (nabirds_train) anymore, but through our DataLoader (nabirds_train_dl).
Because we set the number of records to 4 in the sampler, we don’t have to break the loop.
With index sampler
nabirds_train_dl = grain.DataLoader(
    data_source=nabirds_train,
    sampler=nabirds_train_isampler,
    worker_count=0
)

fig = plt.figure(figsize=(8, 9))
for i, element in enumerate(nabirds_train_dl):
    ax = plt.subplot(2, 2, i + 1)
    plt.tight_layout()
    ax.set_title(
        f"""
        Species: {element['species']}
        Additional information: {element['subcategory']}
        Picture by {element['photographer']}
        """,
        fontsize=9,
        linespacing=1.5
    )
    ax.axis('off')
    plt.imshow(element['img'])
    if i == 3:
        plt.show()
        break
Adding batch sizes
nabirds_train_dl = grain.DataLoader(
    data_source=nabirds_train,
    sampler=nabirds_train_isampler,
    worker_count=0,
    operations=[
        grain.Batch(batch_size=32)
    ]
)