import os
import polars as pl
import imageio.v3 as iio
import grain.python as grain
metadata = pl.read_parquet('metadata.parquet')
metadata_train = metadata.filter(pl.col('is_training_img') == 1)
metadata_val = metadata.filter(pl.col('is_training_img') == 0)
cleaned_img_dir = os.path.join(base_dir, 'cleaned_images')
class NABirdsDataset:
    """NABirds dataset class."""
    def __init__(self, metadata_file, data_dir):
        self.metadata_file = metadata_file
        self.data_dir = data_dir
    def __len__(self):
        return len(self.metadata_file)
    def __getitem__(self, idx):
        path = os.path.join(self.data_dir, self.metadata_file.get_column('path')[idx])
        img = iio.imread(path)
        species_name = self.metadata_file.get_column('species_name')[idx]
        species_id = self.metadata_file.get_column('species_id')[idx]
        photographer = self.metadata_file.get_column('photographer')[idx]
        return {
            'img': img,
            'species_name': species_name,
            'species_id': species_id,
            'photographer': photographer,
        }
nabirds_train = NABirdsDataset(metadata_train, cleaned_img_dir)
nabirds_val = NABirdsDataset(metadata_val, cleaned_img_dir)
DataLoaders
A critical part of deep learning is feeding data to the model during the training loop.
DataLoaders decide which samples to load and in what order, parallelize the loading by managing workers, and set several hyperparameters such as batch size and number of epochs.
In this section we explore DataLoaders with the Grain library [1].
Goal of DataLoaders
We can access elements of our Dataset class (as we did in the previous section) with:
for i, element in enumerate(nabirds_train):
    print(element['img'].shape)
    if i == 3:
        break
(269, 269, 3)
(269, 269, 3)
(269, 269, 3)
(269, 269, 3)
And we can return the next element by creating an iterator from the iterable dataset:
next(iter(nabirds_train))
{'img': array([[[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
...,
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]],
[[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
...,
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]],
[[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
...,
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]],
...,
[[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
...,
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]],
[[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
...,
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]],
[[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
...,
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]]], shape=(269, 269, 3), dtype=uint8),
'species_name': 'Eared Grebe',
'species_id': 145,
'photographer': 'Laura Erickson'}
We get all these zeros because we padded the images with black.
But accessing elements one at a time like this is extremely limited.
DataLoaders feed data to the model during training. They handle batching, shuffling, sharding across machines, the number of epochs, etc., all of which would be very inconvenient to implement from scratch.
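To give a sense of the bookkeeping involved, here is a minimal hand-rolled sketch of epochs, shuffling, sharding, and batching over record indices. It uses only the standard library; the function `manual_batches` and all its parameters are hypothetical names chosen for this illustration and are not part of Grain.

```python
import random

def manual_batches(num_records, batch_size, num_epochs, seed,
                   shard_index=0, shard_count=1):
    """Yield lists of record indices, one list per batch (hypothetical helper)."""
    rng = random.Random(seed)
    for _ in range(num_epochs):
        indices = list(range(num_records))
        rng.shuffle(indices)                         # new global order each epoch
        indices = indices[shard_index::shard_count]  # keep only this machine's share
        for start in range(0, len(indices), batch_size):
            yield indices[start:start + batch_size]

batches = list(manual_batches(num_records=10, batch_size=4, num_epochs=2, seed=0))
print(len(batches))  # 6
```

Even this toy version ignores parallel workers, checkpointing, and what to do with the last incomplete batch, which is why a dedicated library is worth using.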
Grain DataLoaders
The JAX AI stack includes the Grain library [1] to create DataLoaders, but you can also create them with PyTorch, TensorFlow Datasets, Hugging Face Datasets, or any method you are used to. That’s the advantage of the modular philosophy that the stack relies on. Grain is efficient and does not pull in a huge set of dependencies the way PyTorch and TensorFlow do.
In Grain, a DataLoader requires 3 components:
- A data source
- Transforms
- A sampler
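The way these three components fit together can be sketched in plain Python. This is a conceptual toy, not Grain's implementation: `ToySource`, `toy_sampler`, and `toy_loader` are hypothetical names, and the only assumption carried over from Grain is that a data source supports `len()` and indexing.

```python
class ToySource:
    """Random-access data source: supports len() and indexing."""
    def __init__(self, records):
        self._records = records
    def __len__(self):
        return len(self._records)
    def __getitem__(self, idx):
        return self._records[idx]

def toy_sampler(num_records):
    """Decide which record comes next (here: plain sequential order)."""
    yield from range(num_records)

def toy_loader(source, sampler, transforms):
    """Fetch each sampled record and apply the transforms in order."""
    for key in sampler:
        record = source[key]
        for transform in transforms:
            record = transform(record)
        yield record

source = ToySource(['grebe', 'eagle', 'wren'])
loaded = list(toy_loader(source, toy_sampler(len(source)), [str.upper]))
print(loaded)  # ['GREBE', 'EAGLE', 'WREN']
```

Keeping the three roles separate means the order of the records can change (by swapping the sampler) without touching how records are read or transformed.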
Data source
We already have that: our data sources are nabirds_train and nabirds_val, the instances of our Dataset class for training and validation respectively.
Transformations
We need to split the data into batches. Batches can be defined with the grain.Batch transformation, which is passed to the DataLoader.
Let’s use batches of 32 with grain.Batch(batch_size=32).
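Conceptually, batching turns a list of per-sample dictionaries into a single dictionary with one entry per field. The sketch below illustrates that idea with plain lists standing in for arrays; `collate` is a hypothetical helper, not a Grain function, and the two samples reuse species names and IDs that appear in the outputs of this section.

```python
def collate(records):
    """Turn a list of per-sample dicts into one dict of per-field lists."""
    return {key: [record[key] for record in records] for key in records[0]}

samples = [
    {'species_name': 'Eared Grebe', 'species_id': 145},
    {'species_name': 'Bald Eagle', 'species_id': 22},
]
batch = collate(samples)
print(batch['species_id'])  # [145, 22]
```

With real image arrays, stacking only works if every image in the batch has the same shape, which is one reason the images were padded to a common size beforehand.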
Samplers
There are 2 types of Grain samplers: sequential and index samplers.
Sequential sampler
Grain comes with a basic sequential sampler.
import grain.python as grain
nabirds_train_seqsampler = grain.SequentialSampler(
    num_records=4
)
for record_metadata in nabirds_train_seqsampler:
    print(record_metadata)
RecordMetadata(index=0, record_key=0, rng=None)
RecordMetadata(index=1, record_key=1, rng=None)
RecordMetadata(index=2, record_key=2, rng=None)
RecordMetadata(index=3, record_key=3, rng=None)
Index sampler
Grain’s index sampler is the one you should use, as it allows for global shuffling of the dataset, setting the number of epochs, etc.
nabirds_train_isampler = grain.IndexSampler(
    num_records=200,
    shuffle=True,
    seed=0
)
for i, record_metadata in enumerate(nabirds_train_isampler):
    print(record_metadata)
    if i == 3:
        break
RecordMetadata(index=0, record_key=134, rng=Generator(Philox))
RecordMetadata(index=1, record_key=133, rng=Generator(Philox))
RecordMetadata(index=2, record_key=10, rng=Generator(Philox))
RecordMetadata(index=3, record_key=136, rng=Generator(Philox))
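In the output above, index is the position in the stream produced by the sampler, while record_key is the position of the record in the data source. The following toy sampler illustrates that distinction and shows why a fixed seed makes the order reproducible. It is a sketch under simple assumptions (one fresh permutation per epoch, seeded with `seed + epoch`), not the algorithm Grain uses; `ToyRecordMetadata` and `toy_index_sampler` are hypothetical names.

```python
import random
from collections import namedtuple

ToyRecordMetadata = namedtuple('ToyRecordMetadata', ['index', 'record_key'])

def toy_index_sampler(num_records, num_epochs, shuffle, seed):
    """Yield (index, record_key) pairs, reshuffling the keys every epoch."""
    index = 0
    for epoch in range(num_epochs):
        keys = list(range(num_records))
        if shuffle:
            # Same seed and epoch always give the same permutation
            random.Random(seed + epoch).shuffle(keys)
        for key in keys:
            yield ToyRecordMetadata(index=index, record_key=key)
            index += 1

first = list(toy_index_sampler(num_records=5, num_epochs=2, shuffle=True, seed=0))
second = list(toy_index_sampler(num_records=5, num_epochs=2, shuffle=True, seed=0))
print(first == second)                           # True
print(sorted(m.record_key for m in first[:5]))   # [0, 1, 2, 3, 4]
```

Each epoch visits every record exactly once, and rerunning with the same seed replays the same order, which is what makes training runs comparable.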
Let’s experiment with DataLoaders
With sequential sampler
nabirds_train_dl = grain.DataLoader(
    data_source=nabirds_train,
    sampler=nabirds_train_seqsampler,
    worker_count=0
)
We can plot images in this sequential DataLoader:
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(8, 9))
for i, element in enumerate(nabirds_train_dl):
    ax = plt.subplot(2, 2, i + 1)
    plt.tight_layout()
    ax.set_title(
        f"""
        {element['species_name']}
        Picture by {element['photographer']}
        """,
        fontsize=9,
        linespacing=1.5
    )
    ax.axis('off')
    plt.imshow(element['img'])
plt.show()
Notice that, unlike last time we displayed some images, we aren’t looping through our Dataset (nabirds_train) anymore, but through our DataLoader (nabirds_train_dl).
Because we set the number of records to 4 in the sampler, we don’t have to break the loop.
With index sampler
nabirds_train_dl = grain.DataLoader(
    data_source=nabirds_train,
    sampler=nabirds_train_isampler,
    worker_count=0
)
Let’s plot these as well:
fig = plt.figure(figsize=(8, 9))
for i, element in enumerate(nabirds_train_dl):
    ax = plt.subplot(2, 2, i + 1)
    plt.tight_layout()
    ax.set_title(
        f"""
        {element['species_name']}
        Picture by {element['photographer']}
        """,
        fontsize=9,
        linespacing=1.5
    )
    ax.axis('off')
    plt.imshow(element['img'])
    if i == 3:
        plt.show()
        break
Adding batch sizes
You can add the batch size in the operations argument of grain.DataLoader:
nabirds_train_dl = grain.DataLoader(
    data_source=nabirds_train,
    sampler=nabirds_train_isampler,
    worker_count=0,
    operations=[
        grain.Batch(batch_size=32)
    ]
)
Now we can access the next batch with:
next(iter(nabirds_train_dl))
{'img': array([[[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
...,
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]]],
[[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
...,
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]]],
[[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
...,
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]]],
...,
[[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
...,
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]]],
[[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
...,
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
...,
[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]]],
[[[ 4, 0, 0],
[ 8, 2, 4],
[11, 7, 8],
...,
[ 0, 0, 7],
[ 1, 0, 8],
[ 2, 1, 9]],
[[ 4, 0, 0],
[ 6, 0, 0],
[ 5, 1, 0],
...,
[ 0, 0, 2],
[ 0, 0, 2],
[ 0, 0, 2]],
[[ 7, 2, 0],
[ 4, 0, 0],
[ 3, 0, 0],
...,
[ 1, 4, 0],
[ 1, 4, 0],
[ 1, 5, 0]],
...,
[[ 0, 2, 0],
[ 0, 2, 0],
[ 0, 2, 0],
...,
[ 0, 5, 0],
[ 0, 6, 0],
[ 0, 6, 0]],
[[ 0, 4, 3],
[ 0, 5, 1],
[ 0, 4, 3],
...,
[ 2, 2, 4],
[ 1, 3, 2],
[ 2, 2, 0]],
[[ 0, 1, 2],
[ 0, 1, 0],
[ 0, 0, 2],
...,
[ 1, 0, 9],
[ 0, 0, 7],
[ 1, 0, 7]]]], shape=(32, 269, 269, 3), dtype=uint8),
'photographer': array(['Jerry Ting', 'Jerry Ting', 'Robert Steffens', 'Dave Sanders',
'Jon Isacoff', 'Bob Gunderson', 'Joe Turner', 'Lois Manowitz',
'Ken Schneider', 'Phil Jeffrey', 'Ken Schneider', 'Laura Erickson',
'Tripp Davenport', 'Ned Harris',
'Stephen Ramirez www.birdsiview.org', 'Terry Gray', 'Jason Daly',
'Laura Erickson', 'Christopher Ciccone', 'Ned Harris',
'Ruth Sullivan', 'Allan Claybon', 'Ken Schneider', 'Nancy Landry',
'Muriel Neddermeyer', 'Conrad Tan', 'Tripp Davenport',
'Bill Schmoker', 'Chris Cochems', 'Terry Gray', 'Davor Desancic',
'Terry Gray'], dtype='<U34'),
'species_id': array([367, 22, 202, 105, 207, 134, 65, 359, 156, 112, 236, 269, 24,
384, 347, 276, 142, 145, 339, 231, 28, 228, 22, 17, 316, 350,
87, 370, 267, 252, 384, 295]),
'species_name': array(['Western Kingbird', 'Bald Eagle', 'House Finch', 'Carolina Wren',
'Indigo Bunting', 'Common Yellowthroat', 'Blue-winged Teal ',
'Verdin', 'Field Sparrow', 'Cedar Waxwing', 'Mottled Duck',
'Palm Warbler', 'Band-tailed Pigeon', 'White-winged Dove',
'Tennessee Warbler', 'Pine Grosbeak', 'Double-crested Cormorant',
'Eared Grebe', "Steller's Jay", 'Merlin', 'Barred Owl', 'Mallard',
'Bald Eagle', 'American Wigeon', 'Ruffed Grouse', 'Tree Swallow',
'Bufflehead', 'Western Screech-Owl', 'Pacific-slope Flycatcher',
'Northern Pygmy-Owl', 'White-winged Dove', 'Red-naped Sapsucker'],
dtype='<U24')}
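The batch above holds 32 images, as shown by the shape (32, 269, 269, 3). As a quick piece of arithmetic, the sketch below counts how many batches one pass over the 200 records of the index sampler would give with a batch size of 32. It assumes a single pass over the records; `batches_per_epoch` and its `drop_remainder` flag are hypothetical names for this illustration.

```python
import math

def batches_per_epoch(num_records, batch_size, drop_remainder=False):
    """Number of batches in one pass over the data (hypothetical helper)."""
    if drop_remainder:
        return num_records // batch_size       # discard the last partial batch
    return math.ceil(num_records / batch_size)  # keep the last partial batch

print(batches_per_epoch(200, 32))                       # 7
print(batches_per_epoch(200, 32, drop_remainder=True))  # 6
print(200 % 32)                                         # 8 records in the partial batch
```

Whether to keep or drop the smaller final batch is a design choice: keeping it uses all the data, while dropping it guarantees that every batch has the same shape.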