Loading data
In this section, we will download the Food-101 (Bossard, Guillaumin, and Van Gool 2014) dataset that we will later use to train and fine-tune models.
Choosing a library
Data can be downloaded and processed manually, but many datasets are available via Hugging Face datasets, torchvision, and TensorFlow datasets. Remember that JAX does not implement domain-specific utilities and is not a deep learning library. Flax is a deep learning library, but, because there are already so many good options for loading and processing data, it does not provide a method of its own.
Choose the library you are most familiar with, the one for which you have found example code, or the one that seems easiest or provides the exact functionality you need for your project.
The Food-101 dataset, for instance, can be accessed with torchvision.datasets.Food101, since it is one of the TorchVision datasets, or with tfds.image_classification.Food101, since it is also one of the TFDS datasets.
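As an illustration, loading the training split with each of those two libraries could look like this (a minimal sketch; the root and data_dir download locations are placeholders of our choosing):

# TorchVision: downloads the dataset under the directory given as root
from torchvision.datasets import Food101
train_data = Food101(root="data", split="train", download=True)

# TensorFlow Datasets: downloads under data_dir
import tensorflow_datasets as tfds
train_data = tfds.load("food101", split="train", data_dir="data")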
It is also available on the Hugging Face Hub, and that is the method we will use here.
Hugging Face datasets
The Datasets library from Hugging Face is a lightweight, framework-agnostic, and easy-to-use API for downloading datasets from the Hugging Face Hub. It uses Apache Arrow's efficient caching system, which allows large datasets to be used on machines with little memory (Lhoest et al. 2021).
Search dataset
Go to the Hugging Face Hub and search through thousands of open source datasets provided by the community.
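You can also search the Hub programmatically. A minimal sketch, assuming the huggingface_hub library is installed (it is a dependency of Datasets):

from huggingface_hub import list_datasets

# Print the identifiers of the first few datasets matching "food101"
for ds in list_datasets(search="food101", limit=5):
    print(ds.id)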
Inspect dataset
You can get information on a dataset before downloading it.
Load the dataset builder for the dataset you are interested in:
from datasets import load_dataset_builder

ds_builder = load_dataset_builder("food101")
Get a description of the dataset:
ds_builder.info.description
Get information on the features:
ds_builder.info.features
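The info object also reports the size of each split, which is where the numbers used below come from (Food-101 has 750 training and 250 validation images per class). A short sketch reusing the ds_builder created above:

# Number of examples in each split of Food-101
for name, split in ds_builder.info.splits.items():
    print(name, split.num_examples)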
Download dataset
We will only use the first 3 classes of food (instead of 101) to test our code. To avoid each of us downloading a copy of the data (by default into ~/.cache/huggingface), we will use a shared cache directory at /project/60055/data.
from datasets import load_dataset
train_size = 3 * 750
val_size = 3 * 250

train_dataset = load_dataset("food101",
                             split=f"train[:{train_size}]",
                             cache_dir="/project/60055/data")

val_dataset = load_dataset("food101",
                           split=f"validation[:{val_size}]",
                           cache_dir="/project/60055/data")