Loading datasets

Author: Marie-Hélène Burle

Neither JAX nor Flax implements methods to load datasets, since PyTorch, TensorFlow, and Hugging Face already provide excellent APIs for this.

Let’s use MNIST (LeCun, Cortes, and Burges 2010), one of the most classic deep learning datasets, to see how these APIs work.

Hugging Face Datasets

The Datasets library from 🤗 is a lightweight, framework-agnostic, and easy-to-use API to download datasets from the Hugging Face Hub. It relies on Apache Arrow’s efficient caching system, which allows large datasets to be used on machines with limited memory (Lhoest et al. 2021).

Search dataset

Go to the Hugging Face Hub and search through thousands of open source datasets provided by the community.

Inspect dataset

You can get information on a dataset before downloading it.

Load the dataset builder for the dataset you are interested in:

from datasets import load_dataset_builder
ds_builder = load_dataset_builder("mnist")

Get a description of the dataset:

ds_builder.info.description
'The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digits), with 7,000\nimages per class. There are 60,000 training images and 10,000 test images.\n'

Get information on the features:

ds_builder.info.features
{'image': Image(decode=True, id=None),
 'label': ClassLabel(names=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], id=None)}
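
The info object exposes more metadata as well. For instance, you can check the split sizes before downloading anything (a quick example; output omitted here):

ds_builder.info.splits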

Download dataset and load in session

from datasets import load_dataset

ds = load_dataset("mnist")
ds
DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 60000
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 10000
    })
})

You need to have the Pillow package installed for this to work since MNIST is an image dataset.

Let’s explore our dataset dictionary:

len(ds)
2
ds.keys()
dict_keys(['train', 'test'])
ds['train']
Dataset({
    features: ['image', 'label'],
    num_rows: 60000
})
len(ds['train'])
60000
ds['train'][0]
{'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>,
 'label': 5}
len(ds['train'][0])
2
ds['train'][0].keys()
dict_keys(['image', 'label'])
ds['train'][0]['image']
(displays the first training image: a 28x28 handwritten digit, in this case a 5)
ds['train'][0]['label']
5

Convert to JAX object

We need to set the format of our dataset so that its elements are returned as JAX Array objects:

dsj = ds.with_format("jax")
dsj

Printing dsj looks the same as before, but the individual elements are now JAX arrays:

dsj['train'][0]
{'image': Array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
       ...,
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0]], dtype=uint8),
 'label': Array(5, dtype=int32)}
dsj['train'][0]['image']
Array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       ...,
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0]], dtype=uint8)
dsj['train'][0]['label']
Array(5, dtype=int32)
dsj['train']['label'][:10]
Array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4], dtype=int32)
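
Since the elements now come out as JAX arrays, they can be fed straight into JAX functions. A minimal sketch using the dsj object above:

import jax.numpy as jnp

# mean pixel intensity of the first training image,
# after scaling the uint8 values to [0.0, 1.0]
jnp.mean(dsj['train'][0]['image'] / 255.0)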

We can shuffle the data:

ds_shuffled = dsj.shuffle(seed=123)
ds_shuffled['train']['label'][:10]
Array([4, 4, 4, 1, 7, 8, 5, 2, 8, 3], dtype=int32)

For normalization and more complex transformations, you can use the map or with_transform methods of the dataset.
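
As an illustration, here is a minimal sketch of a normalization step with map (the normalize function and its use of NumPy are one option among several):

import numpy as np

def normalize(batch):
    # scale pixel values from [0, 255] to [0.0, 1.0]
    batch['image'] = [np.asarray(img, dtype=np.float32) / 255.0
                      for img in batch['image']]
    return batch

# apply to every split of the DatasetDict, batch by batch
ds_normalized = ds.map(normalize, batched=True)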

PyTorch

If you are familiar with PyTorch DataLoaders, they are an equally great option:

import torch
from torchvision import datasets, transforms

# convert images to tensors and normalize them with the
# MNIST mean and standard deviation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_data = datasets.MNIST(
    '~/projects/def-sponsor00/data',
    train=True, download=True, transform=transform)

test_data = datasets.MNIST(
    '~/projects/def-sponsor00/data',
    train=False, transform=transform)

# create DataLoaders
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=20, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=20, shuffle=False)
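
JAX does not consume torch tensors directly, so batches coming out of these loaders need to be converted to NumPy arrays, which JAX accepts. A minimal sketch, assuming the loaders defined above:

import numpy as np

# grab one batch and convert the torch tensors to NumPy arrays
images, labels = next(iter(train_loader))
images, labels = np.asarray(images), np.asarray(labels)
# images has shape (20, 1, 28, 28) and labels has shape (20,)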

You can find more details in our PyTorch course.

TensorFlow Datasets

For those familiar with TensorFlow, here is an example from the Flax Quick start:

import tensorflow_datasets as tfds
import tensorflow as tf

def get_datasets(num_epochs, batch_size):
    """Load MNIST train and test datasets into memory."""

    train_ds = tfds.load('mnist', split='train')
    test_ds = tfds.load('mnist', split='test')
    
    # normalize train set
    train_ds = train_ds.map(
        lambda sample: {'image': tf.cast(sample['image'], tf.float32) / 255.0,
                        'label': sample['label']})
    # normalize test set
    test_ds = test_ds.map(
        lambda sample: {'image': tf.cast(sample['image'], tf.float32) / 255.0,
                        'label': sample['label']})
    
    # create shuffled dataset by allocating a buffer size of 1024
    # to randomly draw elements from
    train_ds = train_ds.repeat(num_epochs).shuffle(1024)

    # group into batches of batch_size and drop the incomplete
    # final batch; prefetch the next batch to improve latency
    train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)

    # create shuffled dataset by allocating a buffer size of 1024
    # to randomly draw elements from
    test_ds = test_ds.shuffle(1024)

    # group into batches of batch_size and drop the incomplete
    # final batch; prefetch the next batch to improve latency
    test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)

    return train_ds, test_ds
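
To use it, call the function and iterate over the resulting tf.data.Dataset objects as NumPy batches (a short sketch; the num_epochs and batch_size values below are arbitrary):

train_ds, test_ds = get_datasets(num_epochs=10, batch_size=32)

# tf.data hands the batches to JAX as dictionaries of NumPy arrays
for batch in train_ds.as_numpy_iterator():
    images, labels = batch['image'], batch['label']
    break  # each batch is ready for a JAX training step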

There is nothing wrong with using utilities from different libraries! Pick and choose the tools that serve your needs best.

References

LeCun, Yann, Corinna Cortes, and CJ Burges. 2010. “MNIST Handwritten Digit Database.” ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist 2.
Lhoest, Quentin, Albert Villanova del Moral, Yacine Jernite, Abhishek Thakur, Patrick von Platen, Suraj Patil, Julien Chaumond, et al. 2021. “Datasets: A Community Library for Natural Language Processing.” In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing: System Demonstrations, 175–84. Online; Punta Cana, Dominican Republic: Association for Computational Linguistics. https://aclanthology.org/2021.emnlp-demo.21.