Automatic differentiation
One of the transformations that can be applied to array computations is the calculation of gradients, which is crucial for backpropagation through deep neural networks.
Consider the function f:
f = lambda x: x**3 + 2*x**2 - 3*x + 8
We can create a new function dfdx that computes the gradient of f w.r.t. x:
from jax import grad
dfdx = grad(f)
dfdx returns the derivative of f; since f'(x) = 3*x**2 + 4*x - 3, evaluating it at x = 1 gives 4:
print(dfdx(1.))
4.0
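As a quick sanity check, we can compare grad(f) against the hand-computed derivative 3*x**2 + 4*x - 3 at a few points (a minimal sketch; the helper name analytic_dfdx is just for illustration):
from jax import grad

f = lambda x: x**3 + 2*x**2 - 3*x + 8
dfdx = grad(f)
analytic_dfdx = lambda x: 3*x**2 + 4*x - 3  # derivative of f computed by hand

for x in (0., 1., 2.):
    print(dfdx(x), analytic_dfdx(x))  # both columns should agree: -3.0, 4.0, 17.0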
Composing transformations
Transformations can be composed:
from jax import jit
print(jit(grad(f))(1.))
4.0
print(grad(jit(f))(1.))
4.0
Forward and reverse modes
JAX offers other autodiff methods:
- reverse-mode vector-Jacobian products: jax.vjp,
- forward-mode Jacobian-vector products: jax.jvp.
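A minimal sketch of both primitives, reusing the f defined above (the expected values in the comments come from f(1) = 8 and f'(1) = 4):
import jax

f = lambda x: x**3 + 2*x**2 - 3*x + 8

# forward mode: push a tangent through f at x = 1.
y, tangent_out = jax.jvp(f, (1.,), (1.,))
print(y, tangent_out)  # 8.0 and 4.0

# reverse mode: vjp returns f(1) and a function that pulls a cotangent back
y, f_vjp = jax.vjp(f, 1.)
print(y, f_vjp(1.))  # 8.0 and a 1-tuple containing 4.0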
Higher-order differentiation
With a single variable, calls to the grad function can be nested:
d2fdx = grad(dfdx)  # function to compute 2nd order derivatives
d3fdx = grad(d2fdx)  # function to compute 3rd order derivatives
...
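As a quick check of the nested calls above (a sketch reusing dfdx, d2fdx and d3fdx; the expected values come from differentiating f by hand):
print(d2fdx(1.))  # f''(x) = 6x + 4, so this prints 10.0
print(d3fdx(1.))  # f'''(x) = 6, so this prints 6.0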
With several variables, you have to use the functions:
- jax.jacfwd for forward mode,
- jax.jacrev for reverse mode.
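A minimal multi-variable sketch (the function g and the input vector v are made up for illustration):
import jax.numpy as jnp
from jax import jacfwd, jacrev

g = lambda v: jnp.stack([v[0]**2 * v[1], 3*v[0] + jnp.sin(v[1])])  # R^2 -> R^2
v = jnp.array([1., 2.])

print(jacfwd(g)(v))  # 2x2 Jacobian, built column by column (forward mode)
print(jacrev(g)(v))  # the same Jacobian, built row by row (reverse mode)
Composing the two, e.g. jacfwd(jacrev(f)), is also how JAX builds Hessians.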