JAX neural networks with Flax

A deep learning course with Flax

Domain-specific libraries have been built on top of JAX to take advantage of its high-performance computation on multidimensional arrays, its automatic differentiation (autodiff), or both.
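As a quick reminder of the two JAX features these libraries rely on, here is a minimal sketch. `jax.numpy` provides NumPy-style array operations, `jax.grad` differentiates a scalar function, and `jax.jit` compiles the result. The function `f` is only an illustration.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Illustrative scalar function of a vector: the sum of squares
    return jnp.sum(x ** 2)


# Differentiate f with autodiff, then JIT-compile the gradient function
grad_f = jax.jit(jax.grad(f))

x = jnp.array([1.0, 2.0, 3.0])
g = grad_f(x)  # the gradient of sum(x**2) is 2 * x
print(g)
```

Because the gradient of the sum of squares is `2 * x`, `g` holds `[2., 4., 6.]`.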

This course introduces deep learning with JAX, using the Flax neural network library together with Optax, a gradient processing and optimization library. It assumes a basic understanding of how JAX works.
