Pallas
Extension to write GPU and TPU kernels
Author
Marie-Hélène Burle
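To show what writing a kernel with Pallas looks like, here is a minimal sketch of an element-wise vector addition, similar to the introductory example in the Pallas documentation. The kernel function receives references (`x_ref`, `y_ref`, `o_ref`) to blocks of the inputs and output, and `pl.pallas_call` turns it into a function that JAX can call and `jax.jit` can compile. The names `add_kernel` and `add` are hypothetical. `interpret=True` is an assumption added so the sketch also runs on a machine without a GPU or TPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Hypothetical kernel: read both input blocks, add them, write the result
    o_ref[...] = x_ref[...] + y_ref[...]


@jax.jit
def add(x, y):
    # Hypothetical wrapper: pallas_call needs the output shape and dtype up front
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # assumption: interpret on CPU; drop this on a GPU/TPU
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
print(add(x, x))
```

On actual accelerator hardware, removing `interpret=True` lets Pallas lower the kernel for the device (for example through Triton on GPUs or Mosaic on TPUs) instead of interpreting it.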
[Diagram: Pure Python functions → Tracing → Jaxprs (JAX expressions), the intermediate representation (IR) → Just-in-time (JIT) compilation → High-level optimized (HLO) program → Triton → GPU, or → Mosaic → TPU. Jaxprs also feed the transformations: vectorization, parallelization, differentiation.]
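The stages from tracing to compilation can be inspected directly from Python. A minimal sketch, assuming a recent JAX release: `jax.make_jaxpr` prints the jaxpr produced by tracing, `jax.grad` and `jax.vmap` are two of the transformations (differentiation and vectorization), and `jax.jit(...).lower(...)` exposes the program handed to XLA for compilation. The function `f` is a hypothetical example. Parallelization (e.g. `jax.pmap` or sharding) is left out because it needs several devices to be meaningful. In recent versions, `as_text()` prints StableHLO, an MLIR dialect closely related to HLO, rather than classic HLO text.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical pure Python function built from jax.numpy operations
    return jnp.sin(x) * x


x = jnp.float32(1.0)

# Tracing: JAX records the operations of f as a jaxpr (its IR)
print(jax.make_jaxpr(f)(x))

# Transformations are applied to the traced function
df = jax.grad(f)  # differentiation: d/dx [sin(x) * x] = cos(x) * x + sin(x)
vf = jax.vmap(f)  # vectorization over the leading axis
print(df(x))
print(vf(jnp.arange(3.0)))

# JIT: lower f to the program given to XLA, then compile and run it
lowered = jax.jit(f).lower(x)
print(lowered.as_text()[:300])  # StableHLO (MLIR) text in recent JAX versions
compiled = lowered.compile()
print(compiled(x))
```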