Why JAX?

Author

Marie-Hélène Burle

There are many excellent and popular deep learning frameworks already (e.g. PyTorch). So why did Google—already behind the successful TensorFlow project—start developing JAX?

In this section, we will look at the advantages brought by JAX—namely speed and flexible automatic differentiation.

What is JAX?

JAX is a library for Python developed by Google. Its key data structure is the array. It can perform composition, transformation, and differentiation of numerical programs as well as compilation for CPUs, GPUs, and TPUs.

It comes with a NumPy-like API (jax.numpy) as well as a lower-level API called lax (jax.lax). While the NumPy-like API will look familiar to NumPy users, JAX requires strict functional programming: functions should depend only on their inputs, should only return outputs (no side effects), and cannot modify arrays in place since JAX arrays are immutable.
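To make this concrete, here is a minimal sketch contrasting the two APIs and showing what immutability means in practice. The function name `scaled_sum` is a hypothetical example. `.at[...].set(...)` is JAX's documented way of producing an updated copy of an array.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(5.0)  # NumPy-like API: [0., 1., 2., 3., 4.]

# The same addition through the NumPy-like API and through lax
y_np = jnp.add(x, 1.0)
y_lax = lax.add(x, jnp.ones_like(x))  # lax is stricter: operands should have matching types

# JAX arrays are immutable: in-place item assignment raises a TypeError
try:
    x[0] = 10.0
except TypeError as err:
    print("In-place update rejected:", type(err).__name__)

# .at[...].set(...) returns an updated copy and leaves x untouched
x_new = x.at[0].set(10.0)

# A pure function: its output depends only on its input and it has no side effects
def scaled_sum(a):
    return jnp.sum(2.0 * a)

print(y_np, y_lax, x, x_new, scaled_sum(x))
```

The NumPy-like API is usually the more convenient entry point; lax exposes the primitives that jax.numpy is built on.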

A relatively new project

JAX is not yet a widely adopted project.

JAX is fast

JAX was built with performance in mind. Its speed relies on design decisions at all levels.

  • Default data type

    Like PyTorch—a popular deep learning library—JAX uses float32 as its default data type. This level of precision is perfectly suitable for deep learning and increases efficiency (by contrast, NumPy defaults to float64).

  • JIT compilation

    Just-in-time (JIT) compilation fuses computations, avoids allocating memory for temporary objects, and more generally optimizes code for XLA.

  • Accelerators

    The same code can run on CPUs or on accelerators (GPUs and TPUs).

  • XLA optimization

    XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that takes JIT-compiled JAX programs and optimizes them for the available hardware (CPUs, GPUs, or TPUs).

  • Asynchronous dispatch

    Computations are dispatched to the accelerators asynchronously: control returns to Python before the results are ready, so the host can keep queuing work.

  • Vectorization, data parallelism, and sharding

    JAX supports vectorization, data parallelism, and sharding of arrays across devices, covering shared as well as distributed memory parallelism.
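Several of the points above can be observed in a few lines. The following is a sketch, not a benchmark: the `selu` function and its constants are illustrative, and the printed devices depend on the machine. It shows the default float32 dtype (unless the `jax_enable_x64` flag is set), `jax.jit` compilation, `block_until_ready()` to wait for an asynchronously dispatched result, and `jax.vmap` for automatic vectorization.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Default data types: float32 in JAX, float64 in NumPy
print(jnp.array([1.0]).dtype)  # float32 (with the default configuration)
print(np.array([1.0]).dtype)   # float64

# A function chaining several element-wise operations (illustrative constants)
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# jax.jit traces the function for a given input shape and dtype and compiles it
# with XLA, which can fuse the element-wise operations
selu_jit = jax.jit(selu)

x = jnp.linspace(-3.0, 3.0, 1_000_000)

# The call returns once the work is dispatched; block_until_ready() waits
# for the actual result, which matters when timing JAX code
out = selu_jit(x).block_until_ready()

# jax.vmap turns a function written for single vectors into a batched one
def dot(v, w):
    return jnp.dot(v, w)

batched_dot = jax.vmap(dot)  # maps over axis 0 of both arguments
vs = jnp.ones((8, 3))
ws = jnp.ones((8, 3))
print(batched_dot(vs, ws).shape)  # (8,)

# The same code runs on whichever backend is available (CPU, GPU, or TPU)
print(jax.devices())
```

Because of asynchronous dispatch, timing `selu_jit(x)` without `block_until_ready()` would mostly measure the dispatch, not the computation.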

Flexible differentiation

Automatic differentiation (autodiff or AD) is the evaluation, by a computer program, of the partial derivatives of the functions it implements. It is a key part of deep learning: training a model mostly consists of updating its weights and biases to decrease a loss function, and this is done with various gradient-based optimization methods.
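As an illustration of how AD drives training, here is a minimal sketch of gradient descent with `jax.grad`. The linear model, the helper names (`predict`, `loss`, `sgd_step`), the learning rate, and the synthetic data (y = 3x + 1) are all assumptions made for this example.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear model with a weight vector w and a scalar bias b
def predict(params, x):
    w, b = params
    return x @ w + b

# Mean squared error loss
def loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

# jax.grad returns a new function computing the gradient of the loss
# with respect to its first argument (params)
grad_loss = jax.grad(loss)

def sgd_step(params, x, y, lr=0.1):
    grads = grad_loss(params, x, y)
    # params and grads are pytrees (here tuples) with the same structure
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Synthetic data following y = 3x + 1
x = jnp.linspace(-1.0, 1.0, 50).reshape(-1, 1)
y = 3.0 * x[:, 0] + 1.0

params = (jnp.zeros(1), jnp.array(0.0))
for _ in range(200):
    params = sgd_step(params, x, y)

print(params, loss(params, x, y))
```

The gradient has the same structure as the parameters, which keeps the update step a one-liner.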

Several implementations have been developed by different teams over time. This post by Chris Rackauckas summarizes the trade-offs of the various strategies.

Leaving aside Julia (which, by the way, has a lot to offer in the field of AD) and PyTorch’s stale attempt at JIT compilation, Chris Rackauckas’ post can be summarized as follows:

| Autodiff method | Framework | Advantage | Disadvantage |
|---|---|---|---|
| Static graph and XLA | TensorFlow | Mostly optimized AD | Manual writing of IR |
| Dynamic graph | PyTorch | Convenient | Limited AD optimization |
| Dynamic graph and XLA | TensorFlow2 | Convenient | Disappointing speed |
| Pseudo-dynamic and XLA | JAX | Convenient and mostly optimized AD | Pure functions |


TensorFlow’s initial approach, static computational graphs in a domain-specific language, was efficient thanks to the intermediate representation (IR) and XLA, but it was inconvenient, limited, and hard to debug. Above all, users had to write the IR themselves.

PyTorch came with dynamic graphs, an approach so much more convenient that it marked the beginning of TensorFlow’s decline. Operations are recorded during the forward pass, which makes automatic differentiation easy. However, this “per value” AD leaves little room for optimization.

TensorFlow2 tried to bring dynamic graphs, but they were a poor match for XLA.

This leaves room for new strategies. Julia offers several promising approaches, but implementations are not straightforward and projects are not always mature. It is an exciting avenue for developers, but not necessarily an easy one for end users.

JAX is another attempt at bringing both optimization and flexibility to autodiff. With Google behind it, it is a new but fast-growing project.
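The sketch below hints at what "pseudo-dynamic and XLA" and "pure functions" mean in practice. JAX traces an ordinary Python function into its own IR (a jaxpr, printed with `jax.make_jaxpr`), so users never write the IR by hand, and transformations such as `grad`, `jit`, `jacfwd`, and `jacrev` compose freely. The functions `f`, `g`, and `impure` are hypothetical examples. The last part shows why purity matters: a Python side effect runs during tracing only, not on later calls that reuse the compiled version.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x ** 2

# JAX traces the Python function and records its IR (a jaxpr)
print(jax.make_jaxpr(f)(1.0))

# Transformations compose: first and second derivatives,
# and a JIT-compiled version of the first derivative
df = jax.grad(f)
d2f = jax.grad(jax.grad(f))
df_fast = jax.jit(df)

x = 1.5
print(df(x), d2f(x), df_fast(x))

# Forward-mode and reverse-mode Jacobians of a vector-valued function
def g(v):
    return jnp.stack([v[0] * v[1], jnp.sin(v[0])])

v = jnp.array([1.0, 2.0])
print(jax.jacfwd(g)(v))
print(jax.jacrev(g)(v))

# Why functions must be pure: side effects only happen while tracing
calls = []

def impure(x):
    calls.append("traced")  # Python side effect, executed during tracing only
    return 2.0 * x

impure_jit = jax.jit(impure)
impure_jit(jnp.array(1.0))
impure_jit(jnp.array(2.0))  # same shape and dtype: the compiled version is reused
print(len(calls))  # 1
```

Forward mode (`jacfwd`) tends to suit functions with few inputs and many outputs, reverse mode (`jacrev`, and `grad` for scalar outputs) the opposite case; both give the same Jacobian here.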