Deep learning with the JAX AI stack

Training a classification model on your own data

This example-based course takes you through the initial steps necessary to go from your own data to a trained model.

It uses a computer vision classification problem as the case study and a modern, efficiency-oriented stack of libraries as the tools (a short sketch of how a few of them fit together follows the list), including:

  • Polars for faster DataFrames,
  • ImageIO for reading images into ndarrays,
  • Grain or PyTorch for dataset classes and dataloaders,
  • PIX for data augmentation,
  • JAX for JIT compilation, accelerator use, automatic vectorization, and automatic differentiation,
  • Flax for building neural networks,
  • Orbax for checkpointing,
  • Optax for optimization.
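
To give a flavor of how a few of these libraries combine, here is a minimal sketch (not taken from the course) of a training step written with the Flax Linen API, Optax, and JAX's `jit` and `value_and_grad` transformations. The model, input shape, and hyperparameters are placeholders chosen for illustration only.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class TinyClassifier(nn.Module):
    """A deliberately small model: flatten, one hidden layer, class scores."""
    num_classes: int = 10

    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))        # flatten each image to a vector
        x = nn.relu(nn.Dense(128)(x))
        return nn.Dense(self.num_classes)(x)   # unnormalized class scores (logits)

model = TinyClassifier()
params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 28, 28, 3)))  # dummy batch
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

@jax.jit  # compile the whole update step for CPU/GPU/TPU
def train_step(params, opt_state, images, labels):
    def loss_fn(p):
        logits = model.apply(p, images)
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    loss, grads = jax.value_and_grad(loss_fn)(params)   # automatic differentiation
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```

The course builds this kind of loop up step by step, adding real data, augmentation, and checkpointing along the way.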

The main objective is to help you get started with deep learning by explaining concepts such as dataset classes, dataloaders, data augmentation, and training as we work through the case study.
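
As a preview of what a "dataset class" means in practice, here is a hypothetical example; both Grain and PyTorch dataloaders work with map-style sources that expose `__len__` and `__getitem__` like this one. The file layout and column names ("filename", "label") are assumptions, not part of the course data.

```python
from pathlib import Path
import imageio.v3 as iio
import polars as pl

class ImageFolderDataset:
    """Maps an integer index to an (image_array, label) pair."""

    def __init__(self, labels_csv: str, image_dir: str):
        self.table = pl.read_csv(labels_csv)   # assumed columns: "filename", "label"
        self.image_dir = Path(image_dir)

    def __len__(self) -> int:
        return self.table.height

    def __getitem__(self, idx: int):
        row = self.table.row(idx, named=True)
        image = iio.imread(self.image_dir / row["filename"])  # HWC uint8 ndarray
        return image, row["label"]
```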

While deep learning concepts are relatively simple, it is often a frustrating affair to get anything to work and to adapt documentation or online tutorials to your own data. By walking through this together, we will hopefully make this part less challenging when you try it on your own.

Along the way, I will introduce core JAX concepts and get you started with this powerful library.

This course has no prerequisites, although some knowledge of Python and NumPy, and a basic understanding of neural networks, will help.


Start course ➤