Pallas

A JAX extension for writing GPU and TPU kernels

Author

Marie-Hélène Burle

Diagram: pure Python functions → tracing → jaxprs (JAX expressions), an intermediate representation (IR) → just-in-time (JIT) compilation → high-level optimized (HLO) program → Triton → GPU, or Mosaic → TPU. Jaxprs can also go through transformations: vectorization, parallelization, and differentiation.
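The stages in the diagram can be sketched in a few lines of JAX. This is a minimal sketch, not taken from the original: `f`, `df`, `vf`, `add_kernel` and `add` are hypothetical names, and `interpret=True` is an assumption made so the Pallas kernel also runs on a CPU-only machine. On a GPU or TPU, dropping it lets Pallas compile the kernel for the accelerator instead of interpreting it.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def f(x):
    # A pure Python function built from jax.numpy operations (hypothetical example)
    return jnp.sin(x) * 2.0


x = jnp.arange(8, dtype=jnp.float32)

# Tracing: turn the Python function into a jaxpr, JAX's intermediate representation
print(jax.make_jaxpr(f)(x))

# Transformations work on the traced program
df = jax.grad(lambda v: f(v).sum())  # differentiation
vf = jax.vmap(f)                     # vectorization

# JIT compilation: lower the traced program to a module that XLA compiles
f_jit = jax.jit(f)
print(f_jit.lower(x).as_text()[:200])  # first characters of the lowered module text


# A Pallas kernel: the refs point to blocks of the inputs and of the output
def add_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]


def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # assumption: interpreter mode so the sketch runs on CPU too
    )(x, y)


print(add(x, x))
```

Because no `grid` is passed to `pl.pallas_call` here, the kernel sees each input as a single block; kernels for larger arrays usually pass a `grid` and `BlockSpec`s to tile the data.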