Loading data

Authors

Marie-Hélène Burle

Part of JAX tutorial

In this section, we will download the Food-101 dataset (Bossard, Guillaumin, and Van Gool 2014) that we will later use to train and fine-tune models.

Choosing a library

Data can be downloaded and processed manually, but many datasets are available via Hugging Face Datasets, torchvision, and TensorFlow Datasets. Remember that JAX does not implement domain-specific utilities and is not a deep learning library. Flax is a deep learning library, but, because there are already so many good options for loading and processing data, its developers did not implement a method of their own.

Choose the library you are most familiar with, the one for which you have found example code, or the one that seems easiest or provides the exact functionality you want for your project.

The Food-101 dataset, for instance, can be accessed with torchvision.datasets.Food101 since it is one of the TorchVision datasets, or with tfds.image_classification.Food101 since it is also one of the TFDS datasets.
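For reference, a minimal sketch of the TorchVision route might look like this (the root directory is a placeholder, and download=True fetches the full archive on first use):

import torchvision

# Sketch only: "data" is a placeholder directory;
# download=True fetches the full Food-101 archive on first use
train_data = torchvision.datasets.Food101(
    root="data", split="train", download=True)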

It is also available on the Hugging Face Hub, and that is the method we will use here.

Hugging Face Datasets

The Datasets library from Hugging Face is a lightweight, framework-agnostic, and easy-to-use API to download datasets from the Hugging Face Hub. It uses Apache Arrow's efficient caching system, allowing large datasets to be used on machines with limited memory (Lhoest et al. 2021).

Search dataset

Go to the Hugging Face Hub and search through thousands of open-source datasets provided by the community.

Inspect dataset

You can get information on a dataset before downloading it.

Load the dataset builder for the dataset you are interested in:

from datasets import load_dataset_builder
ds_builder = load_dataset_builder("food101")

Get a description of the dataset:

ds_builder.info.description

Get information on the features:

ds_builder.info.features
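You can also check the available splits and their sizes before downloading anything (for Food-101, a train and a validation split):

ds_builder.info.splits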

Download dataset

We will only use the first 3 classes of food (instead of 101) to test our code. To avoid each of us downloading the data separately (by default to ~/.cache/huggingface), we will use a shared cache directory at /project/60055/data.

from datasets import load_dataset

# Food-101 has 750 training and 250 validation images per class
train_size = 3 * 750
val_size = 3 * 250

train_dataset = load_dataset("food101",
                             split=f"train[:{train_size}]",
                             cache_dir="/project/60055/data")

val_dataset = load_dataset("food101",
                           split=f"validation[:{val_size}]",
                           cache_dir="/project/60055/data")
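As a quick sanity check (a minimal sketch), we can verify the sizes and map a label back to its class name:

print(len(train_dataset), len(val_dataset))

# Each example is a dictionary with an "image" and an integer "label"
example = train_dataset[0]

# Map the integer label back to the class name
print(train_dataset.features["label"].int2str(example["label"]))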

References

Bossard, Lukas, Matthieu Guillaumin, and Luc Van Gool. 2014. “Food-101 – Mining Discriminative Components with Random Forests.” In European Conference on Computer Vision.
Lhoest, Quentin, Albert Villanova del Moral, Yacine Jernite, Abhishek Thakur, Patrick von Platen, Suraj Patil, Julien Chaumond, et al. 2021. “Datasets: A Community Library for Natural Language Processing.” In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing: System Demonstrations, 175–84. Online; Punta Cana, Dominican Republic: Association for Computational Linguistics. https://aclanthology.org/2021.emnlp-demo.21.