The JAX AI stack

Author

Marie-Hélène Burle

JAX is a powerful Python library for array computing that runs the same code on any hardware (CPU, GPU, TPU) and uses just-in-time (JIT) compilation to run it efficiently, especially on accelerators. It also provides composable function transformations such as vectorization and automatic differentiation.
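The minimal sketch below shows these transformations composed on a toy squared-error loss: `jax.grad` differentiates it, `jax.vmap` vectorizes the gradient over a batch, and `jax.jit` compiles the result. The loss function and data are illustrative assumptions, not part of the original material.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a toy linear model for a single (x, y) pair
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# Gradient of the loss with respect to w (the first argument)
grad_loss = jax.grad(loss)

# Vectorize over a batch of (x, y) pairs while sharing the same w
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# JIT-compile the composed function for the available backend (CPU, GPU or TPU)
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

grads = fast_batched_grad(w, xs, ys)
print(grads)  # one gradient per example
```

The same code runs unchanged on whichever device JAX finds; only the compiled kernels differ.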

While it can be used in other fields (Bayesian statistics/probabilistic programming, physics simulations, solvers…), it has all the characteristics deep learning requires and is heavily used for it, together with AI-specific libraries built on top of it.

The JAX AI stack combines JAX with a selection of such libraries, also developed by Google, for easier installation and integration. Its design is modular, which means that any package can be replaced with whatever tool the user prefers (e.g. you can load data with a PyTorch DataLoader, apply TorchVision transformations, and train your model with JAX).
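As a minimal sketch of this modularity, the JIT-compiled training step below only needs batches of NumPy arrays, so any data pipeline that yields them (for instance a data loader configured to return NumPy arrays) could feed it. The `numpy_batches` generator is a hypothetical stand-in for such a loader, and the linear model with a plain SGD update is an illustrative assumption rather than part of the stack.

```python
import numpy as np
import jax
import jax.numpy as jnp

def numpy_batches(x, y, batch_size):
    # Hypothetical stand-in for any data pipeline that yields NumPy arrays
    for i in range(0, len(x), batch_size):
        yield x[i:i + batch_size], y[i:i + batch_size]

def loss_fn(w, xb, yb):
    # Mean squared error of a linear model
    return jnp.mean((xb @ w - yb) ** 2)

LR = 0.1  # assumed learning rate for this toy problem

@jax.jit
def sgd_step(w, xb, yb):
    # NumPy inputs are converted to JAX arrays automatically
    g = jax.grad(loss_fn)(w, xb, yb)
    return w - LR * g

# Synthetic, noiseless regression data
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 3)).astype(np.float32)
true_w = np.array([1.0, -2.0, 0.5], dtype=np.float32)
y = x @ true_w

w = jnp.zeros(3)
for epoch in range(50):
    for xb, yb in numpy_batches(x, y, batch_size=32):
        w = sgd_step(w, xb, yb)

print(w)  # should approach true_w
```

Because the step function only sees arrays, swapping the data source does not require touching the JAX code.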

Slides (Click and wait: this reveal.js presentation may take a little time to load.)

Slides content for easier browsing.