Our deep learning example
This course is structured around one concrete deep learning example, which we introduce in this section.
The problem
We want to train a model for a fine-grained visual classification task, in which the features that distinguish one class from another are subtle.
The data
We will use the NABirds dataset from the Cornell Lab of Ornithology [1].
Our strategy
Before embarking on a deep learning project, it is crucial to think about the overall strategy to follow for optimum performance.
In particular, it is important to decide what should run on CPUs versus GPUs, and to separate the preliminary work (sometimes called "offline") from the actual training ("online").
Here is our plan. It is a hybrid approach, split into two phases:
Phase 1: preparation (CPU)
We will do the preliminary, deterministic cropping of images on CPUs because:
Doing it as a transformation during the training loop (i.e. as a Transform on our DataLoader) would be wasteful: the work would be repeated at every epoch even though its result never changes. We are much better off doing it once and writing the outputs to file.
Moreover, cropping on the fly would force the computer to load and decode large image files, only to throw away large portions of them, every time it reads these files (thousands of times over the course of training). Saving the crops to file at this step significantly reduces the I/O load and will make our training epochs run much faster.
Writing to file is a CPU task, so sending the data to the GPUs and back just to save it would be a very inefficient workflow.
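A minimal NumPy-only sketch of this one-off crop step follows. The helper names `crop_to_bbox` and `preprocess_once` are hypothetical, and we assume each bounding box is given as `(x, y, width, height)` in pixels; check this convention against the dataset's annotation files. In practice the images would be decoded from JPEG with an image library and the crops would likely be written back as compressed images; we stay with `.npy` files so the sketch runs on its own.

```python
import os
import tempfile

import numpy as np


def crop_to_bbox(image, bbox):
    """Return the part of an H x W x C image inside bbox = (x, y, width, height).

    The (x, y, width, height) pixel convention is an assumption for this sketch.
    """
    x, y, w, h = bbox
    return image[y:y + h, x:x + w]


def preprocess_once(images, bboxes, out_dir):
    """Crop every image exactly once (offline, on CPU) and write each crop to disk.

    Hypothetical helper: `images` maps image IDs to decoded NumPy arrays and
    `bboxes` maps the same IDs to bounding boxes.
    """
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for image_id, image in images.items():
        crop = crop_to_bbox(image, bboxes[image_id])
        path = os.path.join(out_dir, f"{image_id}.npy")
        np.save(path, crop)  # the training loop will only ever read this small file
        paths.append(path)
    return paths


# Usage with a synthetic "image" standing in for a decoded photo
rng = np.random.default_rng(0)
images = {"img_0": rng.integers(0, 256, size=(600, 800, 3), dtype=np.uint8)}
bboxes = {"img_0": (100, 50, 300, 200)}

with tempfile.TemporaryDirectory() as out_dir:
    paths = preprocess_once(images, bboxes, out_dir)
    crop = np.load(paths[0])

print(crop.shape)  # (200, 300, 3)
```

The training loop then only ever reads the much smaller cropped arrays, which is where the I/O savings come from.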
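For this first part, we don't use JAX. Instead, we use classic NumPy arrays. Using JAX, JIT compilation, or GPUs for this one-off step would be very inefficient: each image is processed only once, so there is no repeated work over which the cost of compilation could be amortized.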
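The sketch below illustrates one concrete reason JIT compilation does not pay off here. Crops taken from bounding boxes come in many different sizes, and `jax.jit` traces and compiles a separate version of a function for every new input shape. We count traces with a Python side effect, which only executes while JAX is tracing. The `normalize` function is a hypothetical stand-in for any per-image operation.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_count = 0


@jax.jit
def normalize(image):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing, not on cached calls
    return (image - jnp.mean(image)) / (jnp.std(image) + 1e-6)


# Crops of different sizes, as bounding boxes would produce (one shape repeats)
shapes = [(120, 90, 3), (200, 150, 3), (120, 90, 3), (64, 64, 3)]
for shape in shapes:
    normalize(np.ones(shape, dtype=np.float32))

print(trace_count)  # 3: one trace (and compilation) per distinct shape
```

With thousands of differently sized crops, almost every call would pay the compilation cost, which is exactly the situation JIT compilation is not designed for.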
Phase 2: training (GPU)
After that, we move to the GPUs for the training loop. This loop includes the data augmentation transformations: these are random and applied with some probability at each epoch, so they must be part of the loop. We don't save any of the transformed images to file at this point.
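Here is a minimal sketch of a random augmentation applied inside the loop, written in JAX. The horizontal flip, its probability `p`, and the name `random_flip` are illustrative assumptions, not the course's actual augmentation pipeline. The point is that a fresh PRNG key is split off at each epoch, so the same stored crop can be transformed differently each time it is seen, and the result is never written to disk.

```python
import jax
import jax.numpy as jnp


def random_flip(key, image, p=0.5):
    """Flip an H x W x C image left-right with probability p (hypothetical helper)."""
    do_flip = jax.random.uniform(key) < p
    return jnp.where(do_flip, image[:, ::-1, :], image)


# Stands in for one cropped image read back from disk
stored_crop = jnp.arange(6, dtype=jnp.float32).reshape(2, 3, 1)

key = jax.random.PRNGKey(0)
augmented_per_epoch = []
for epoch in range(5):
    # Split off a fresh key each epoch so the flip decision is re-drawn every time
    key, subkey = jax.random.split(key)
    augmented = random_flip(subkey, stored_crop)
    # Kept here only for inspection; in training it would go straight to the model
    augmented_per_epoch.append(augmented)
```

Because JAX's random numbers are driven by explicit keys, splitting the key is what makes each epoch see new randomness while keeping the run reproducible from the initial seed.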
For this part, we use JAX and JIT-compile our code. This is where the heavy machinery of compilation, JAX, and GPUs pays off, since the same computation is repeated at every training step.
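A toy sketch of what JIT-compiling the training step looks like follows. A linear classifier and plain SGD stand in for the real model and optimizer (both are assumptions for illustration, as are `loss_fn`, `train_step`, and `LEARNING_RATE`). Because every batch has the same shape, `train_step` is traced and compiled on its first call and the compiled version is reused on every later step, which is the opposite of the Phase 1 situation.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.01  # illustrative value


def loss_fn(params, images, labels):
    """Mean softmax cross-entropy of a linear classifier on flattened images."""
    logits = images.reshape(images.shape[0], -1) @ params["w"] + params["b"]
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=1))


@jax.jit
def train_step(params, images, labels):
    """One SGD step; compiled on the first call, then reused for same-shaped batches."""
    loss, grads = jax.value_and_grad(loss_fn)(params, images, labels)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


# Toy data: every batch has the same shape, so train_step compiles only once
batch_size, height, width, channels, num_classes = 8, 8, 8, 3, 5
k_w, k_x, k_y = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w": 0.01 * jax.random.normal(k_w, (height * width * channels, num_classes)),
    "b": jnp.zeros(num_classes),
}
images = jax.random.normal(k_x, (batch_size, height, width, channels))
labels = jax.random.randint(k_y, (batch_size,), 0, num_classes)

losses = []
for step in range(20):
    params, loss = train_step(params, images, labels)
    losses.append(float(loss))
```

In a real run, the random augmentations would be applied to each batch before (or inside) this step, and the batch size and image size would be fixed so that the compiled step is never retraced.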
