Deep learning with JAX

Author

Marie-Hélène Burle

JAX is perfect for developing deep learning models:

  • it deals with multi-dimensional arrays,
  • it is extremely fast,
  • it is optimized for accelerators,
  • and it is capable of flexible automatic differentiation.

JAX is, however, not a DL library. While it is possible to create neural networks directly in JAX, it usually makes more sense to use libraries built on JAX that provide the toolkit necessary to build and train neural networks.
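To illustrate what "directly in JAX" means, here is a minimal sketch of a one-layer network and its gradients written in pure JAX. The function names, shapes, and data are made up for illustration:

```python
import jax
import jax.numpy as jnp

# A one-layer network written directly in JAX (no DL library).
def predict(params, x):
    w, b = params
    return jnp.tanh(x @ w + b)

def loss_fn(params, x, y):
    # Mean squared error between predictions and targets.
    return jnp.mean((predict(params, x) - y) ** 2)

key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (3, 1)) * 0.1   # weights
b = jnp.zeros((1,))                        # bias
x = jnp.ones((4, 3))                       # dummy batch of 4 inputs
y = jnp.zeros((4, 1))                      # dummy targets

# jax.grad gives the gradients of the loss w.r.t. the parameters.
grads = jax.grad(loss_fn)((w, b), x, y)
```

This works, but wiring up layers, initialization, and training loops by hand quickly becomes tedious, which is where the libraries below come in.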

Deep learning workflow

Training a neural network from scratch requires a number of steps:

Load dataset → Define architecture → Train → Test → Save model


Pretrained models can also be used for feature extraction or transfer learning.

Deep learning ecosystem for JAX

Here is a classic ecosystem of libraries for deep learning with JAX:

  • Load datasets

There are already good tools to load datasets (e.g. PyTorch, TensorFlow, Hugging Face), so the JAX developers did not create their own implementation.

  • Define network architecture

Neural networks can be built in JAX from scratch, but a number of packages built on JAX provide the necessary toolkit. Flax is the option recommended by the JAX developers and the one we will use in this course.

  • Train

    The package CLU (Common Loop Utils) is a set of helpers to write shorter training loops. Optax provides loss and optimization functions. Orbax brings checkpointing utilities.

  • Test

    Testing a model is easy to do directly in JAX.

  • Save model

    Flax provides methods to save a model.


To sum up, here is an ecosystem of libraries to use JAX for neural networks:

  • Load data: PyTorch, TensorFlow, or 🤗 Hugging Face
  • Define network: Flax (built on JAX)
  • Train: CLU, Optax, and Orbax (built on JAX)
  • Test: JAX
  • Save model: Flax


When working with pretrained models, Hugging Face also provides a convenient API giving access to thousands of them.

How to get started?

A common approach is to start from one of the example projects and use it as a template.