The JAX AI stack
JAX is a powerful Python library for array computing that runs the same code on any hardware (CPU, GPU, TPU) and supports efficient just-in-time (JIT) compilation on accelerators. It also provides composable function transformations such as vectorization and automatic differentiation.
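As a minimal sketch of what "composable" means here, the snippet below stacks three JAX transformations on one ordinary Python function: `jax.grad` (automatic differentiation), `jax.vmap` (vectorization over a batch) and `jax.jit` (compilation). The squared-error loss, the weights and the toy batch are illustrative assumptions, not part of any particular library or model.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a linear model on a single example (illustrative toy loss)
    return (jnp.dot(w, x) - y) ** 2

# Automatic differentiation: gradient with respect to the first argument, w
grad_loss = jax.grad(loss)

# Vectorization: apply grad_loss to every (x, y) pair in a batch, sharing w
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# JIT compilation of the whole composed transformation
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

grads = fast_batched_grad(w, xs, ys)  # one gradient per example
print(grads.shape)  # (3, 2)
```

The same `fast_batched_grad` call runs unchanged on CPU, GPU or TPU; JAX picks the default backend available at runtime.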
While JAX is also used in other fields (Bayesian statistics and probabilistic programming, physics simulations, solvers…), it has all the characteristics required for deep learning and is widely used for it, together with AI-specific libraries built on top of it.
The JAX AI stack combines JAX with a selection of such libraries, also developed by Google, for easier installation and integration. It follows a modular approach, which means that any package can be replaced by whatever tool the user prefers (e.g. you can load data with a PyTorch DataLoader, apply TorchVision transformations, and train your model with JAX).
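To illustrate the modularity point, here is a hedged sketch in which a JAX training step only ever sees plain NumPy arrays, so the data source can be swapped freely. The generator `numpy_batches` is a hypothetical stand-in for any loader (a PyTorch DataLoader could take its place, with tensors converted via `.numpy()`); the linear-regression model, synthetic data and hand-written SGD update are assumptions chosen to keep the example dependency-free, not the stack's own training API.

```python
import numpy as np
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative value

def numpy_batches(x, y, batch_size, seed=0):
    """Hypothetical stand-in for any data loader: yields shuffled
    (x, y) mini-batches as plain NumPy arrays."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(x))
    for start in range(0, len(x), batch_size):
        b = idx[start:start + batch_size]
        yield x[b], y[b]

def loss_fn(params, x, y):
    # Mean squared error of a linear model: pred = x @ w + b
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y):
    # Loss value and gradients w.r.t. every leaf of the params pytree
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain SGD update, applied leaf by leaf
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss

# Synthetic data from a known linear relation: y = 2*x0 - 3*x1 + 1
rng = np.random.default_rng(42)
x_data = rng.normal(size=(256, 2)).astype(np.float32)
true_w = np.array([2.0, -3.0], dtype=np.float32)
y_data = x_data @ true_w + 1.0

params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}
for epoch in range(50):
    # The loader yields NumPy arrays; JAX converts them on the fly
    for xb, yb in numpy_batches(x_data, y_data, batch_size=32, seed=epoch):
        params, loss = train_step(params, xb, yb)

print(params["w"], params["b"])
```

Because `train_step` only depends on array inputs, replacing `numpy_batches` with another loader does not require touching the model or training code.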
Slides (Click and wait: this reveal.js presentation may take a little time to load.)
Slides content for easier browsing.