Deep learning with JAX


Marie-Hélène Burle

JAX is well suited to developing deep learning models:

  • it deals with multi-dimensional arrays,
  • it is extremely fast,
  • it is optimized for accelerators,
  • and it is capable of flexible automatic differentiation.

JAX is, however, not a deep learning (DL) library. While it is possible to create neural networks directly in JAX, it makes more sense to use libraries built on JAX that provide the toolkit necessary to build and train neural networks.
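To illustrate what working directly in JAX looks like, here is a minimal sketch that fits a linear model with hand-written gradient descent. The synthetic data, the function names (`predict`, `loss_fn`, `sgd_step`), and the learning rate are all assumptions made for this example, not part of the course material. It relies only on `jax.numpy` arrays, `jax.grad`, and `jax.jit`.

```python
import jax
import jax.numpy as jnp

# Synthetic regression data following y = 3x + 1 (illustrative only)
x = jnp.linspace(-1.0, 1.0, 64).reshape(-1, 1)
y = 3.0 * x + 1.0

LEARNING_RATE = 0.1  # arbitrary value chosen for this sketch


def predict(params, x):
    """Linear model: params is a (weights, bias) tuple."""
    w, b = params
    return x @ w + b


def loss_fn(params, x, y):
    """Mean squared error between predictions and targets."""
    return jnp.mean((predict(params, x) - y) ** 2)


@jax.jit  # compile the whole update step
def sgd_step(params, x, y):
    # jax.grad differentiates loss_fn with respect to its first argument
    grads = jax.grad(loss_fn)(params, x, y)
    # Apply the same update rule to every array in the parameter tuple
    return jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )


params = (jnp.zeros((1, 1)), jnp.zeros((1,)))
initial_loss = loss_fn(params, x, y)

for _ in range(500):
    params = sgd_step(params, x, y)

final_loss = loss_fn(params, x, y)
print(f"loss before: {float(initial_loss):.4f}, after: {float(final_loss):.6f}")
```

Even for this tiny model, the parameter handling and the update rule have to be written by hand, which is the kind of boilerplate that libraries built on JAX take care of.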

Deep learning workflow

Training a neural network from scratch requires a number of steps:

Load dataset → Define architecture → Train → Test → Save model

Pretrained models can also be used for feature extraction or transfer learning.

Deep learning ecosystem for JAX

Here is a classic ecosystem of libraries for deep learning with JAX:

  • Load datasets

    Good tools for loading datasets already exist (e.g. PyTorch, TensorFlow, Hugging Face), so JAX does not provide its own implementation.

  • Define network architecture

    Neural networks can be built in JAX from scratch, but a number of packages built on JAX provide the necessary toolkit. Flax is a popular option and the one we will use in this course.

  • Train

    Training a model requires optimization functions. These can be implemented in JAX from scratch, but the Optax library provides the core components.

  • Test

    Testing a model is easy to do directly in JAX.

  • Save model

    Flax provides methods to save a model.

To sum up, here is a common ecosystem of libraries to use JAX for neural networks:

  • Load dataset: PyTorch, TensorFlow, or Hugging Face
  • Define architecture: Flax
  • Train: JAX + Optax
  • Test: JAX
  • Save model: Flax

When working from pretrained models, Hugging Face also provides an API for downloading any of thousands of pretrained models.