JAX and how to use it for deep learning
Library for Python developed by Google
Key data structure: Array
Composition, transformation, and differentiation of numerical programs
Compilation for CPUs, GPUs, and TPUs
NumPy-like and lower-level APIs
Requires a strictly functional programming style (pure functions, immutable arrays; see the sketch below)
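A minimal sketch of these points (the array values are arbitrary illustrations):

```python
import jax.numpy as jnp

x = jnp.arange(4.0)      # a jax.Array, the key data structure
y = jnp.sin(x) ** 2      # NumPy-like API

# JAX arrays are immutable: x[0] = 10.0 would raise an error.
# In-place updates are expressed functionally with .at[].set(),
# which returns a new array and leaves x unchanged.
x2 = x.at[0].set(10.0)
```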
Default data type suited for deep learning
Like PyTorch, JAX uses float32 by default. This precision is sufficient for deep learning and more efficient (by contrast, NumPy defaults to float64)
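A quick check of the defaults:

```python
import jax.numpy as jnp
import numpy as np

print(jnp.ones(3).dtype)  # float32: JAX's default, like PyTorch
print(np.ones(3).dtype)   # float64: NumPy's default
```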
The same code can run on CPUs or on accelerators (GPUs and TPUs)
XLA (Accelerated Linear Algebra) optimization
Asynchronous dispatch
Vectorization, data parallelism, and sharding
All levels of shared- and distributed-memory parallelism are supported (jit, vmap, and asynchronous dispatch are sketched below)
Summarized from a blog post by Chris Rackauckas
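A minimal sketch combining these features (the function and shapes are illustrative):

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    return jnp.tanh(x @ w)

# jit compiles the function with XLA for the available backend (CPU/GPU/TPU);
# the same code runs unchanged on any of them
fast_predict = jax.jit(predict)

# vmap vectorizes over a batch dimension without rewriting the function
batched = jax.vmap(fast_predict, in_axes=(None, 0))

w = jnp.ones((3, 2))
xs = jnp.ones((8, 3))

# Dispatch is asynchronous: block_until_ready() waits for the result,
# which matters when timing code
out = batched(w, xs).block_until_ready()
print(out.shape)  # (8, 2)
```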
Consider the function f:
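The definition of f is not shown in the source; a minimal function consistent with the output below is f(x) = x²:

```python
import jax

def f(x):
    return x ** 2
```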
We can create a new function dfdx that computes the gradient of f w.r.t. x:
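A sketch using jax.grad:

```python
dfdx = jax.grad(f)
```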
Calling dfdx returns the derivative:
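The evaluation point is not shown in the source either; assuming the f above, x = 2.0 reproduces the value below:

```python
print(dfdx(2.0))
```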
4.0
With a single variable, the grad function calls can be nested:
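For example, the second derivative (assuming f as above):

```python
d2fdx2 = jax.grad(jax.grad(f))
print(d2fdx2(2.0))  # 2.0, since f''(x) = 2
```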
With several variables, you have to use the functions:
jax.jacfwd for forward-mode,
jax.jacrev for reverse-mode.
(Both are built on JAX's lower-level autodiff primitives jax.jvp and jax.vjp.)
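A short sketch with a vector-valued function of several variables (g and its input are illustrative):

```python
import jax
import jax.numpy as jnp

def g(v):
    return jnp.array([v[0] ** 2, v[0] * v[1]])

v = jnp.array([1.0, 2.0])
print(jax.jacfwd(g)(v))  # forward-mode Jacobian
print(jax.jacrev(g)(v))  # reverse-mode Jacobian, same values
```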
Supported platforms:

| | Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 |
|---|---|---|---|---|---|---|
| CPU | yes | yes | yes | yes | yes | yes |
| NVIDIA GPU | yes | yes | no | n/a | no | experimental |
| Google TPU | yes | n/a | n/a | n/a | n/a | n/a |
| AMD GPU | yes | no | experimental | n/a | no | no |
| Apple GPU | n/a | no | n/a | experimental | n/a | n/a |
| Intel GPU | experimental | n/a | n/a | n/a | no | no |
From the JAX documentation
If you install packages which depend on JAX (e.g. Flax), they will by default install the CPU version of JAX. If you want to run JAX on GPUs, make sure to first install jax[cuda12] (e.g. `pip install -U "jax[cuda12]"`)
You can install the CPU version on your machine to prototype and use a GPU version on the clusters (we have wheels)
A downside of this modular approach is that several libraries are required and dependency conflicts can be a problem
The meta-library jax-ai-stack makes this easier to manage (install jax[cuda12] first for GPU)
Note that for now TensorFlow and packages which depend on it (e.g. TFDS, grain) are still stuck at Python 3.12, so you can’t use a newer Python version if you want to use some of them
On your machine (and your machine only), a great tool to manage Python versions and packages is uv (see our webinar). On the clusters, you have to use module to load the Python version you want and pip to install packages