Deep learning with JAX
JAX is well suited to developing deep learning models:
- it deals with multi-dimensional arrays,
- it is extremely fast,
- it is optimized for accelerators,
- and it is capable of flexible automatic differentiation.
JAX is, however, not a deep learning library. While it is possible to create neural networks directly in JAX, it makes more sense to use libraries built on JAX that provide the toolkit necessary to build and train neural networks.
Deep learning workflow
Training a neural network from scratch requires a number of steps: loading datasets, defining the network architecture, training, testing, and saving the model.
Pretrained models can also be used for feature extraction or transfer learning.
Deep learning ecosystem for JAX
Here is a typical ecosystem of libraries for deep learning with JAX:
Load datasets
Good tools to load datasets already exist (e.g. PyTorch, TensorFlow, Hugging Face), so the JAX developers did not create their own implementation.
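Whatever loader you choose, all JAX needs is mini-batches of NumPy arrays. Here is a minimal sketch of a shuffled batching iterator over synthetic data standing in for a real dataset (the function name and array shapes are illustrative assumptions, not part of any library):

```python
import numpy as np

def batch_iterator(images, labels, batch_size, rng):
    """Yield shuffled (images, labels) mini-batches as NumPy arrays."""
    indices = rng.permutation(len(images))
    for start in range(0, len(images), batch_size):
        idx = indices[start:start + batch_size]
        yield images[idx], labels[idx]

# Synthetic stand-in for a real dataset (in practice this would come
# from a PyTorch DataLoader, TensorFlow Datasets, or Hugging Face).
rng = np.random.default_rng(0)
images = rng.normal(size=(100, 28 * 28)).astype(np.float32)
labels = rng.integers(0, 10, size=100)

batches = list(batch_iterator(images, labels, batch_size=32, rng=rng))
```

Loaders from PyTorch or TensorFlow plug in at exactly this point: you only have to convert their batches to NumPy arrays before feeding them to JAX.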
Define network architecture
Neural networks can be built in JAX from scratch, but a number of packages built on JAX provide the necessary toolkit. Flax is the option recommended by the JAX developers and the one we will use in this course.
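To illustrate what "from scratch" means, here is a minimal sketch of a multi-layer perceptron written in plain JAX (the helper names `init_mlp` and `mlp` and the layer sizes are our own choices); Flax packages exactly this kind of boilerplate into reusable modules:

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    """Initialize weights and biases for a fully connected network."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        w = jax.random.normal(subkey, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        params.append({"w": w, "b": jnp.zeros(n_out)})
    return params

def mlp(params, x):
    """Forward pass: ReLU hidden layers, linear output layer."""
    for layer in params[:-1]:
        x = jax.nn.relu(x @ layer["w"] + layer["b"])
    last = params[-1]
    return x @ last["w"] + last["b"]

# A 784 -> 128 -> 10 network applied to a batch of 32 inputs.
params = init_mlp(jax.random.PRNGKey(0), [784, 128, 10])
logits = mlp(params, jnp.ones((32, 784)))
```

With Flax, the same architecture is declared as a module and the parameter bookkeeping is handled for you.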
Train
The CLU package (Common Loop Utils) is a set of helpers for writing shorter training loops. Optax provides loss and optimization functions, and Orbax brings checkpointing utilities.
Test
Testing a model is easy to do directly in JAX.
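For instance, classification accuracy is a one-liner on JAX arrays (the `accuracy` helper and the tiny logits/labels below are illustrative, not from any library):

```python
import jax.numpy as jnp

def accuracy(logits, labels):
    """Fraction of examples whose highest-scoring class matches the label."""
    return jnp.mean(jnp.argmax(logits, axis=-1) == labels)

# Three examples, two classes; the first two predictions are correct.
logits = jnp.array([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2]])
labels = jnp.array([0, 1, 1])
acc = accuracy(logits, labels)
```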
Save model
Flax provides methods to save a model.
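Conceptually, saving a model means serializing its parameter arrays and restoring them later. Here is a generic sketch of that round trip using NumPy's `savez` on a flat dict of arrays (a simple stand-in for a real parameter pytree; Flax's serialization utilities and Orbax handle nested structures and large checkpoints properly):

```python
import numpy as np

# Model parameters as a flat dict of named arrays (a stand-in for a
# Flax parameter pytree).
params = {"dense_w": np.ones((4, 2)), "dense_b": np.zeros(2)}

# Save: each named array is stored in a single .npz archive.
np.savez("checkpoint.npz", **params)

# Restore: load the archive back into a dict of arrays.
with np.load("checkpoint.npz") as data:
    restored = {name: data[name] for name in data.files}
```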
To sum up, here is an ecosystem of libraries to use JAX for neural networks:
When working with pretrained models, Hugging Face also provides a convenient API to download any of thousands of pretrained models.
How to get started?
A common approach is to start from one of the example projects and use it as a template.