Loading datasets
Flax does not implement methods to load datasets, since Hugging Face, PyTorch, and TensorFlow already provide great APIs for this.
You only need one of the methods below to load datasets, so you should install only one of the packages datasets, torchvision, or tensorflow-datasets. Trying to install them all often results in version conflicts between some of their shared dependencies.
For this course, I installed tensorflow-datasets on our training cluster.
Hugging Face datasets
The Datasets library from 🤗 is a lightweight, framework-agnostic, and easy-to-use API to download datasets from the Hugging Face Hub. It uses an efficient caching system built on Apache Arrow, allowing large datasets to be used on machines with little memory (Lhoest et al. 2021).
Search dataset
Go to the Hugging Face Hub and search through thousands of open source datasets provided by the community.
Inspect dataset
You can get information on a dataset before downloading it.
Load the dataset builder for the dataset you are interested in:
from datasets import load_dataset_builder
ds_builder = load_dataset_builder("mnist")
Get a description of the dataset:
ds_builder.info.description
'The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digits), with 7,000\nimages per class. There are 60,000 training images and 10,000 test images.\n'
Get information on the features:
ds_builder.info.features
{'image': Image(mode=None, decode=True, id=None),
'label': ClassLabel(names=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], id=None)}
Download dataset and load in session
import numpy as np
import jax.numpy as jnp
from datasets import load_dataset

def get_dataset_hf():
    mnist = load_dataset("mnist")
    ds = {}
    for split in ['train', 'test']:
        # convert the images and the labels to NumPy arrays
        ds[split] = {
            'image': np.array([np.array(im) for im in mnist[split]['image']]),
            'label': np.array(mnist[split]['label'])
        }
        # cast to jnp and rescale pixel values
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])
        # add a trailing channel dimension
        ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)
    return ds['train'], ds['test']
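To make the preprocessing in get_dataset_hf easier to follow, here is a minimal sketch of the same three steps (rescaling, casting the labels, adding a channel axis) applied to synthetic data. It is an illustration only: it assumes raw images arrive as uint8 arrays of shape (N, 28, 28), and it uses NumPy in place of jax.numpy so that it runs without JAX or a dataset download. The names raw_images and raw_labels are made up for this example.

```python
import numpy as np

# Synthetic stand-in for raw MNIST data: 8 grayscale 28x28 images stored as uint8
rng = np.random.default_rng(0)
raw_images = rng.integers(0, 256, size=(8, 28, 28), dtype=np.uint8)
raw_labels = rng.integers(0, 10, size=8)

# Rescale pixel values from [0, 255] to [0, 1] as 32-bit floats
images = raw_images.astype(np.float32) / 255
# Store the labels as 16-bit integers
labels = raw_labels.astype(np.int16)
# Add a trailing channel axis: (8, 28, 28) -> (8, 28, 28, 1)
images = np.expand_dims(images, 3)

print(images.shape, images.dtype, labels.dtype)
# (8, 28, 28, 1) float32 int16
```

The trailing axis of size 1 is the channel dimension of a grayscale image.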
PyTorch Torchvision datasets
Torchvision from PyTorch also provides an API to download and prepare many standard datasets as well as utilities to build your own.
import jax.numpy as jnp
from torchvision import datasets

def get_dataset_torch():
    # download the train and test sets to ./data if they are not there yet
    mnist = {
        'train': datasets.MNIST('./data', train=True, download=True),
        'test': datasets.MNIST('./data', train=False, download=True)
    }
    ds = {}
    for split in ['train', 'test']:
        # convert the PyTorch tensors to NumPy arrays
        ds[split] = {
            'image': mnist[split].data.numpy(),
            'label': mnist[split].targets.numpy()
        }
        # cast to jnp and rescale pixel values
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])
        # add a trailing channel dimension
        ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)
    return ds['train'], ds['test']
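Because the loaders in this section are meant to be interchangeable, it can help to verify that a returned split has the expected layout. The helper below is a hypothetical sketch (check_split is not part of any of the libraries above). It assumes each split is a dict with an 'image' array of shape (N, 28, 28, 1) scaled to [0, 1] and a 'label' array of N digits, and it is demonstrated on synthetic data.

```python
import numpy as np

def check_split(split, n_examples):
    """Raise AssertionError if a split does not have the expected layout."""
    images = np.asarray(split['image'])
    labels = np.asarray(split['label'])
    # One 28x28 single-channel image and one label per example
    assert images.shape == (n_examples, 28, 28, 1), images.shape
    assert labels.shape == (n_examples,), labels.shape
    # Pixel values rescaled to [0, 1], labels are digits 0-9
    assert images.min() >= 0 and images.max() <= 1
    assert labels.min() >= 0 and labels.max() <= 9
    return True

# Synthetic split with 4 blank images
fake = {
    'image': np.zeros((4, 28, 28, 1), dtype=np.float32),
    'label': np.array([0, 3, 9, 1], dtype=np.int16),
}
check_split(fake, 4)
```

With the real data, the same check would be called as check_split(train, 60000) and check_split(test, 10000), following the split sizes given in the dataset description.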
TensorFlow datasets
TensorFlow also has a dataset API which can be installed as a standalone package.
import jax.numpy as jnp
import tensorflow_datasets as tfds

def get_dataset_tf(epochs, batch_size):
    # note: epochs and batch_size are not used in this function body
    mnist = tfds.builder('mnist')
    mnist.download_and_prepare()
    ds = {}
    for split in ['train', 'test']:
        # batch_size=-1 loads the full split as a single batch
        ds[split] = tfds.as_numpy(mnist.as_dataset(split=split, batch_size=-1))
        # cast to jnp and rescale pixel values
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])
    return ds['train'], ds['test']
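Since only one of the three packages is expected to be installed, code that has to run in several environments can check which one is importable before calling the matching loader. The sketch below uses only the standard library. The LOADER_MODULES mapping and the available_loaders function are hypothetical names introduced for this example, and the keys 'hf', 'torch', and 'tf' are arbitrary labels.

```python
from importlib.util import find_spec

# Import names of the three packages discussed above (pip names in comments)
LOADER_MODULES = {
    'hf': 'datasets',              # pip install datasets
    'torch': 'torchvision',        # pip install torchvision
    'tf': 'tensorflow_datasets',   # pip install tensorflow-datasets
}

def available_loaders(modules=LOADER_MODULES):
    """Return the keys whose module can be found in this environment."""
    return [key for key, module in modules.items()
            if find_spec(module) is not None]

print(available_loaders())
```

On the training cluster described above, where only tensorflow-datasets was installed, the list would be expected to contain just 'tf'.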