Automatic differentiation
One of the transformations that can be applied to array computations is the calculation of gradients, which is crucial for backpropagation through deep neural networks.
Consider the function f:
f = lambda x: x**3 + 2*x**2 - 3*x + 8
We can create a new function dfdx that computes the gradient of f w.r.t. x:
from jax import grad
dfdx = grad(f)
dfdx returns the derivative of f at the point where it is evaluated:
print(dfdx(1.))
4.0
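The result matches the hand-derived derivative f'(x) = 3*x**2 + 4*x - 3 (a quick check added here for illustration):
dfdx_check = lambda x: 3*x**2 + 4*x - 3  # analytic derivative of f, derived by hand
print(dfdx_check(1.))
4.0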
Composing transformations
Transformations can be composed:
from jax import jit
print(jit(grad(f))(1.))
4.0
print(grad(jit(f))(1.))
4.0
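The compiled gradient can also be stored and reused across calls; a small sketch, where dfdx_jit is our own name and f'(2.) = 17:
dfdx_jit = jit(grad(f))
print(dfdx_jit(2.))
17.0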
Forward and reverse modes
JAX offers other autodiff methods (both sketched below):
- reverse-mode vector-Jacobian products: jax.vjp,
- forward-mode Jacobian-vector products: jax.jvp.
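A rough sketch of how these primitives are called, reusing the same f as above (the printed values follow from f(1.) = 8 and f'(1.) = 4):
from jax import jvp, vjp
# forward mode: push a tangent through f at the primal point 1.
primal_out, tangent_out = jvp(f, (1.,), (1.,))
print(primal_out, tangent_out)  # 8.0 4.0
# reverse mode: get f(1.) and a function that pulls cotangents back
primal_out, f_vjp = vjp(f, 1.)
print(primal_out, f_vjp(1.))    # 8.0 and the cotangent tuple (4.0,)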
Higher-order differentiation
With a single variable, calls to grad can be nested:
d2fdx = grad(dfdx)   # function to compute 2nd order derivatives
d3fdx = grad(d2fdx)  # function to compute 3rd order derivatives
...
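Evaluating these at x = 1. gives the values expected from the analytic derivatives f''(x) = 6x + 4 and f'''(x) = 6:
print(d2fdx(1.))
10.0
print(d3fdx(1.))
6.0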
With several variables, you have to use dedicated functions (sketched below):
- jax.jacfwd for forward-mode,
- jax.jacrev for reverse-mode.
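A minimal sketch with a hypothetical two-variable function g (not from the original text); both compute the same Jacobian, with jacfwd generally better suited to "tall" Jacobians (more outputs than inputs) and jacrev to "wide" ones:
import jax.numpy as jnp
from jax import jacfwd, jacrev
# hypothetical vector-valued function of two variables
g = lambda x: jnp.array([x[0]**2 + x[1], x[0] * x[1]])
x = jnp.array([2., 3.])
print(jacfwd(g)(x))  # expected Jacobian: [[4. 1.], [3. 2.]]
print(jacrev(g)(x))  # same Jacobian, computed in reverse mode
Composing them, e.g. jacfwd(jacrev(f)), is a common way to compute Hessians.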