JAX neural networks with Flax

A deep learning course with noshadow

JAX is a high-performance open source Python library for composable function transformations (including automatic differentiation) and array computation on accelerators (GPUs/TPUs). It provides a foundation on which domain-specific libraries can be built.
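As a minimal sketch of what "function transformations" means in practice, the example below takes an ordinary Python function and transforms it with `jax.grad` (differentiation), `jax.jit` (XLA compilation) and `jax.vmap` (automatic vectorization). The function `f` and the input values are illustrative choices, not material from the course.

```python
import jax
import jax.numpy as jnp

# An ordinary Python function on arrays: f(x) = sum of x_i squared
def f(x):
    return jnp.sum(x ** 2)

# jax.grad returns a new function that computes df/dx = 2x
grad_f = jax.grad(f)

# jax.jit returns a compiled version of grad_f (runs on CPU, GPU or TPU)
fast_grad_f = jax.jit(grad_f)

# jax.vmap maps f over a leading batch axis without writing a loop
batched_f = jax.vmap(f)

x = jnp.array([1.0, 2.0, 3.0])
print(fast_grad_f(x))                 # gradient of f at x
print(batched_f(jnp.ones((4, 3))))    # f applied to each of the 4 rows
```

Because each transformation returns a plain function, they compose freely: `jax.jit(jax.vmap(jax.grad(f)))` is just as valid.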

In deep learning, the most popular of these libraries is Flax. It takes full advantage of JAX and adds the tools needed to build and train neural networks.

This introduction to Flax covers building networks, handling model state, and techniques for optimizing training. It assumes a basic understanding of how JAX works.

