DataLoaders

Author

Marie-Hélène Burle

A critical part of deep learning is feeding data to the model during the training loop.

DataLoaders handle which sample to load and in what order; they parallelize and optimize the process by managing pools of workers; and they control several hyperparameters such as the batch size and the number of epochs.

In this section we explore DataLoaders with the Grain library [1].

Training set

Let’s use our Dataset class (simplified slightly because we don’t need the bounding boxes anymore) and create one instance just with the training set:

base_dir = "<path-of-the-nabirds-dir>"

To be replaced by the proper path.

import os
import polars as pl
import imageio.v3 as iio

metadata = pl.read_parquet("metadata.parquet")
metadata_train = metadata.filter(pl.col("is_training_img") == 1)
cleaned_img_dir = os.path.join(base_dir, "cleaned_images")

class NABirdsDataset:
    """NABirds dataset class."""
    def __init__(self, metadata_file, data_dir):
        self.metadata_file = metadata_file
        self.data_dir = data_dir

    def __len__(self):
        return len(self.metadata_file)

    def __getitem__(self, idx):
        # Load the image from disk
        path = os.path.join(self.data_dir, self.metadata_file.get_column('path')[idx])
        img = iio.imread(path)
        # Replace underscores with spaces in the text fields for nicer display
        species = self.metadata_file.get_column('species')[idx].replace('_', ' ')
        subcategory = self.metadata_file.get_column('subcategory')[idx]
        if subcategory is not None:
            subcategory = subcategory.replace('_', ' ')
        photographer = self.metadata_file.get_column('photographer')[idx].replace('_', ' ')
        element = {
            'img': img,
            'species': species,
            'subcategory': subcategory,
            'photographer': photographer,
        }

        return element

nabirds_train = NABirdsDataset(metadata_train, cleaned_img_dir)
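
Since the class defines __len__, a quick sanity check is to make sure the instance reports the expected number of training records:

print(len(nabirds_train))   # should match the number of rows in metadata_train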

Goal of DataLoaders

We can access elements of our Dataset class (as we did in the previous section) with:

for i, element in enumerate(nabirds_train):
    print(element['img'].shape)
    if i == 3:
        break
(312, 688, 3)
(739, 1024, 3)
(722, 808, 3)
(753, 896, 3)

And we can return the next object by creating an iterator from our iterable dataset:

next(iter(nabirds_train))
{'img': array([[[123, 144, 171],
         [124, 145, 172],
         [125, 146, 173],
         ...,
         [130, 154, 180],
         [128, 152, 178],
         [127, 151, 177]],
 
        [[124, 145, 172],
         [125, 146, 173],
         [126, 147, 174],
         ...,
         [127, 151, 177],
         [128, 152, 178],
         [128, 152, 178]],
 
        [[126, 147, 174],
         [127, 148, 175],
         [127, 148, 175],
         ...,
         [124, 148, 174],
         [126, 150, 176],
         [127, 151, 177]],
 
        ...,
 
        [[100, 105, 124],
         [100, 106, 122],
         [100, 105, 124],
         ...,
         [114, 131, 159],
         [114, 131, 159],
         [114, 131, 159]],
 
        [[101, 106, 126],
         [101, 106, 125],
         [101, 106, 126],
         ...,
         [115, 132, 160],
         [115, 132, 160],
         [115, 132, 160]],
 
        [[102, 107, 127],
         [102, 107, 127],
         [102, 107, 127],
         ...,
         [116, 133, 161],
         [116, 133, 161],
         [116, 133, 161]]], shape=(312, 688, 3), dtype=uint8),
 'species': 'Eared Grebe',
 'subcategory': 'Nonbreeding/juvenile',
 'photographer': 'Laura Erickson'}

But all this is extremely limited.

DataLoaders feed data to the model during training. They handle batching, shuffling, sharding across machines, the number of epochs, etc.

Grain DataLoaders

The JAX AI stack includes the Grain library [1] to create DataLoaders, but DataLoaders can also be built with PyTorch, TensorFlow Datasets, Hugging Face Datasets, or any method you are used to. That’s the advantage of the modular philosophy the stack relies on. Grain is extremely efficient and does not pull in a huge set of dependencies the way PyTorch and TensorFlow do.

In Grain, a DataLoader requires three components:

  • A data source
  • Transformations
  • A sampler

Data source

We already have that: it is our instance of Dataset class that we called nabirds_train.

Transformations

We need to split the data into batches. Batching is defined with the grain.Batch transformation, which is passed to the DataLoader as an operation.

Let’s use batches of 32 with grain.Batch(batch_size=32).
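
For example, the transformation can be created ahead of time and reused when we build the DataLoader at the end of this section (where the operations argument is shown):

import grain.python as grain

# Batch records into groups of 32; passed later to the DataLoader's operations list
batch_op = grain.Batch(batch_size=32)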

The batch size is a crucial hyperparameter: it impacts your training speed, model stability, and final accuracy.

Default strategy

If you are unsure where to start, use a batch size of 32.

32 is small enough to provide a regularizing effect (helping the model generalize) but large enough to benefit from parallel processing on GPUs.

Standard values

Powers of 2 (32, 64, 128, 256) are the conventional choices: they tend to align well with GPU memory layout and tile sizes, so hardware utilization is usually best at these values.

|                | Small batch size | Large batch size |
|----------------|------------------|------------------|
| Training speed | Slower: less efficient use of the GPU | Faster: maximizes GPU throughput |
| Generalization | Better: the “noise” in the gradient helps the model escape sharp local minima | Worse: can lead to overfitting |
| Convergence    | Noisy training curve: the loss fluctuates | Smoother training curve: stable descent |
| Memory usage   | Low: good for limited VRAM | High: risk of OOM |

Tuning the batch size

Ceiling

Your maximum batch size is dictated by your GPU memory (VRAM).

If you hit an out of memory (OOM) error, back down to the previous successful power of 2 (this is your hardware maximum).

Performance

Just because you can fit a batch size of 4096 doesn’t mean you should.

If training is stable but slow, double the batch size to 64, then again to 128; you can keep increasing it up to your hardware maximum to speed up epochs.

If the model overfits or diverges, try reducing the batch size. The “noisy” updates act like regularization, preventing the model from memorizing the data too perfectly.

Advanced techniques

  • Gradient accumulation:

If you need a batch size of 64 for stability but your GPU only fits 16, you can use gradient accumulation: process 4 mini-batches of 16, accumulate the gradients, and update the weights once. This mathematically simulates a batch size of 64 (see the sketch after this list).

  • Dynamic batching:

Some advanced training regimes start with a small batch size to stabilize early training and increase it over time to speed up convergence (similar to learning rate decay).
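
In the JAX stack, gradient accumulation can be delegated to the optimizer. A minimal sketch with optax (the optimizer choice, learning rate, and the 16/64 numbers are just the hypothetical values from the gradient accumulation example above):

import optax

# Hypothetical setup: the GPU only fits mini-batches of 16, but we want the
# effective behaviour of updates computed on batches of 64.
base_optimizer = optax.adam(learning_rate=1e-3)

# optax.MultiSteps accumulates gradients over every_k_schedule consecutive
# update() calls and only applies them on the last one: 4 mini-batches of 16
# simulate one update with a batch size of 64.
optimizer = optax.MultiSteps(base_optimizer, every_k_schedule=4)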

Learning rate

If you change your batch size significantly, you should adjust your learning rate.

A rule of thumb that works well up to very large batch sizes (the linear scaling rule) is to scale the learning rate by the same factor as the batch size: double the batch size, double the learning rate.
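
For example, with hypothetical values for the baseline:

# Linear scaling rule: multiply the learning rate by the same factor as the batch size
base_batch_size, base_lr = 32, 1e-3
new_batch_size = 128
new_lr = base_lr * new_batch_size / base_batch_size   # 1e-3 * 4 = 4e-3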

Samplers

Sequential sampler

Grain comes with a basic sequential sampler.

import grain.python as grain

nabirds_train_seqsampler = grain.SequentialSampler(
    num_records=4
)

for record_metadata in nabirds_train_seqsampler:
    print(record_metadata)
RecordMetadata(index=0, record_key=0, rng=None)
RecordMetadata(index=1, record_key=1, rng=None)
RecordMetadata(index=2, record_key=2, rng=None)
RecordMetadata(index=3, record_key=3, rng=None)

Index sampler

Grain’s IndexSampler is the one you should normally use, as it allows global shuffling of the dataset, setting the number of epochs, sharding across processes, etc.

nabirds_train_isampler = grain.IndexSampler(
    num_records=200,
    shuffle=True,
    seed=0
)

for i, record_metadata in enumerate(nabirds_train_isampler):
    print(record_metadata)
    if i == 3:
        break
RecordMetadata(index=0, record_key=134, rng=Generator(Philox))
RecordMetadata(index=1, record_key=133, rng=Generator(Philox))
RecordMetadata(index=2, record_key=10, rng=Generator(Philox))
RecordMetadata(index=3, record_key=136, rng=Generator(Philox))
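
The same sampler can also be configured for several passes over the data and for explicit sharding. A sketch (the num_epochs and shard_options arguments follow Grain’s documented IndexSampler options; check the version you have installed):

nabirds_train_isampler_2epochs = grain.IndexSampler(
    num_records=200,
    shuffle=True,
    seed=0,
    num_epochs=2,                      # two full passes over the 200 records
    shard_options=grain.NoSharding()   # single process: no sharding needed
)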

Building DataLoaders

We now have our source, samplers, and transformation (batching), so we can experiment with Grain DataLoaders.

With sequential sampler

nabirds_train_dl = grain.DataLoader(
    data_source=nabirds_train,
    sampler=nabirds_train_seqsampler,
    worker_count=0
)
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 9))

for i, element in enumerate(nabirds_train_dl):
    ax = plt.subplot(2, 2, i + 1)
    plt.tight_layout()
    ax.set_title(
        f"""
        Species: {element['species']}
        Additional information: {element['subcategory']}
        Picture by {element['photographer']}
        """,
        fontsize=9,
        linespacing=1.5
    )
    ax.axis('off')
    plt.imshow(element['img'])

plt.show()

Notice that, unlike last time we displayed some images, we aren’t looping through our Dataset (nabirds_train) anymore, but through our DataLoader (nabirds_train_dl).

Because we set the number of records to 4 in the sampler, we don’t have to break the loop.

With index sampler

nabirds_train_dl = grain.DataLoader(
    data_source=nabirds_train,
    sampler=nabirds_train_isampler,
    worker_count=0
)
fig = plt.figure(figsize=(8, 9))

for i, element in enumerate(nabirds_train_dl):
    ax = plt.subplot(2, 2, i + 1)
    plt.tight_layout()
    ax.set_title(
        f"""
        Species: {element['species']}
        Additional information: {element['subcategory']}
        Picture by {element['photographer']}
        """,
        fontsize=9,
        linespacing=1.5
    )
    ax.axis('off')
    plt.imshow(element['img'])
    if i == 3:
        plt.show()
        break

Adding batching

nabirds_train_dl = grain.DataLoader(
    data_source=nabirds_train,
    sampler=nabirds_train_isampler,
    worker_count=0,
    operations=[
        grain.Batch(batch_size=32)
    ]
)
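
Note that stacking the images of a batch into a single array requires them to share a shape, which is not yet the case for our cleaned images. Here is a sketch of how this could be handled with an extra transformation before batching (the ResizeImage class, the 224x224 target size, and the use of jax.image.resize are assumptions for illustration, not part of the original pipeline):

import jax

class ResizeImage(grain.MapTransform):
    """Hypothetical transform: resize every image to a fixed 224x224x3 shape."""
    def map(self, element):
        element['img'] = jax.image.resize(
            element['img'], (224, 224, 3), method='bilinear'
        )
        return element

nabirds_train_dl = grain.DataLoader(
    data_source=nabirds_train,
    sampler=nabirds_train_isampler,
    worker_count=0,
    operations=[
        ResizeImage(),                 # make all images the same shape
        grain.Batch(batch_size=32)     # then stack them into batches of 32
    ]
)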

References

1. Ritter M, Indyk I, Singh A, et al (2023) Grain - feeding JAX models