Loading datasets

Author

Marie-Hélène Burle

Flax does not implement methods to load datasets since Hugging Face, PyTorch, and TensorFlow already provide great APIs for this.

You only need one of the methods below to load datasets, so install only one of the packages datasets, torchvision, or tensorflow-datasets. Installing all three often leads to version conflicts between some of their shared dependencies.

For this course, I installed tensorflow-datasets on our training cluster.

Hugging Face datasets

The Datasets library from 🤗 is a lightweight, framework-agnostic, and easy-to-use API for downloading datasets from the Hugging Face Hub. It relies on Apache Arrow’s efficient caching system, which allows large datasets to be used on machines with limited memory (Lhoest et al. 2021).

Search dataset

Go to the Hugging Face Hub and search through thousands of open source datasets provided by the community.

Inspect dataset

You can get information on a dataset before downloading it.

Load the dataset builder for the dataset you are interested in:

from datasets import load_dataset_builder
ds_builder = load_dataset_builder("mnist")

Get a description of the dataset:

ds_builder.info.description
'The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digits), with 7,000\nimages per class. There are 60,000 training images and 10,000 test images.\n'

Get information on the features:

ds_builder.info.features
{'image': Image(mode=None, decode=True, id=None),
 'label': ClassLabel(names=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], id=None)}

Download dataset and load in session

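import jax.numpy as jnp
import numpy as np
from datasets import load_dataset

def get_dataset_hf():
    # download MNIST from the Hugging Face Hub (cached after the first call)
    mnist = load_dataset("mnist")

    ds = {}

    for split in ['train', 'test']:
        # convert the PIL images and the labels to NumPy arrays
        ds[split] = {
            'image': np.array([np.array(im) for im in mnist[split]['image']]),
            'label': np.array(mnist[split]['label'])
        }

        # cast to jnp and rescale pixel values to [0, 1]
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

        # add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1)
        ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)

    return ds['train'], ds['test']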
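
The preprocessing steps in the function above (rescaling, casting, adding a channel axis) can be tried on synthetic data without downloading anything. The following is only a minimal sketch: it uses NumPy as a stand-in for jax.numpy (the functions used here have the same names in both), and the array names and the fake data are made up for illustration.

```python
import numpy as np

# Synthetic stand-in for MNIST: 4 grayscale images of 28x28 uint8 pixels
rng = np.random.default_rng(0)
raw_images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
raw_labels = np.array([3, 1, 4, 1])

# Rescale pixel values from [0, 255] to [0, 1]
images = np.float32(raw_images) / 255
# 16-bit integers are ample for 10 classes
labels = np.int16(raw_labels)
# Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1)
images = np.expand_dims(images, 3)

print(images.shape, images.dtype, labels.dtype)
# (4, 28, 28, 1) float32 int16
```

The trailing axis matters because convolutional layers in Flax expect inputs laid out as (batch, height, width, channels).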

PyTorch Torchvision datasets

Torchvision from PyTorch also provides an API to download and prepare many standard datasets as well as utilities to build your own.

import jax.numpy as jnp
from torchvision import datasets

def get_dataset_torch():
    # download MNIST to ./data (skipped if the files are already there)
    mnist = {
        'train': datasets.MNIST('./data', train=True, download=True),
        'test': datasets.MNIST('./data', train=False, download=True)
    }

    ds = {}

    for split in ['train', 'test']:
        # convert the PyTorch tensors to NumPy arrays
        ds[split] = {
            'image': mnist[split].data.numpy(),
            'label': mnist[split].targets.numpy()
        }

        # cast to jnp and rescale pixel values to [0, 1]
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

        # add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1)
        ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)

    return ds['train'], ds['test']
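
Whichever package you use, the loader returns dictionaries with the same layout, so a small sanity check can catch shape or dtype surprises before training. The sketch below is one possible check, not part of any of the libraries: check_split is a hypothetical helper, and it is run here on a fake split so that it works without any download.

```python
import numpy as np

def check_split(split, height=28, width=28, channels=1):
    """Hypothetical sanity check for one split returned by a loader."""
    images = np.asarray(split['image'])
    labels = np.asarray(split['label'])
    # Images: (N, height, width, channels), floating point, values in [0, 1]
    assert images.ndim == 4
    assert images.shape[1:] == (height, width, channels)
    assert np.issubdtype(images.dtype, np.floating)
    assert images.min() >= 0.0 and images.max() <= 1.0
    # Labels: one integer per image
    assert labels.shape == (images.shape[0],)
    assert np.issubdtype(labels.dtype, np.integer)
    return images.shape[0]

# Fake split with the expected layout (5 blank images)
fake_split = {
    'image': np.zeros((5, 28, 28, 1), dtype=np.float32),
    'label': np.zeros(5, dtype=np.int16),
}
n = check_split(fake_split)
print(n)
# 5
```

With a real loader, the same call would be applied to each of the two returned dictionaries.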

TensorFlow datasets

TensorFlow also has a dataset API which can be installed as a standalone package.

import jax.numpy as jnp
import tensorflow_datasets as tfds

# note: epochs and batch_size are not used in the body of this function
def get_dataset_tf(epochs, batch_size):
    mnist = tfds.builder('mnist')
    mnist.download_and_prepare()

    ds = {}

    for split in ['train', 'test']:
        # batch_size=-1 loads the whole split as a single batch
        ds[split] = tfds.as_numpy(mnist.as_dataset(split=split, batch_size=-1))

        # cast to jnp and rescale pixel values
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

    return ds['train'], ds['test']
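
The function above takes epochs and batch_size as arguments, but it returns each split as one large array, so batching has to happen in the training loop. Below is a minimal sketch of how that might look. The helper iterate_batches is hypothetical, the data is fake, and NumPy is used in place of jax.numpy so that the example runs on its own.

```python
import numpy as np

def iterate_batches(split, batch_size, rng):
    """Yield shuffled mini-batches from a dict of equally long arrays."""
    n = len(split['label'])
    order = rng.permutation(n)
    # Drop the last incomplete batch so that all batches have the same shape
    for start in range(0, n - batch_size + 1, batch_size):
        idx = order[start:start + batch_size]
        yield {key: value[idx] for key, value in split.items()}

# Fake training split: 10 blank images with labels 0 to 9
fake_train = {
    'image': np.zeros((10, 28, 28, 1), dtype=np.float32),
    'label': np.arange(10, dtype=np.int16),
}
rng = np.random.default_rng(0)
epochs, batch_size = 2, 4

for epoch in range(epochs):
    # A new permutation is drawn at each epoch
    batches = list(iterate_batches(fake_train, batch_size, rng))
    print(epoch, len(batches), batches[0]['image'].shape)
# 0 2 (4, 28, 28, 1)
# 1 2 (4, 28, 28, 1)
```

Keeping every batch the same shape avoids recompilation when the training step is wrapped in jax.jit, at the cost of skipping a few samples per epoch.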

References

Lhoest, Quentin, Albert Villanova del Moral, Yacine Jernite, Abhishek Thakur, Patrick von Platen, Suraj Patil, Julien Chaumond, et al. 2021. “Datasets: A Community Library for Natural Language Processing.” In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing: System Demonstrations, 175–84. Online; Punta Cana, Dominican Republic: Association for Computational Linguistics. https://aclanthology.org/2021.emnlp-demo.21.