Automatic differentiation

Author

Marie-Hélène Burle

One of the transformations that JAX can apply to array computations is the calculation of gradients, which is crucial for backpropagation through deep neural networks.

Consider the function f:

f = lambda x: x**3 + 2*x**2 - 3*x + 8

We can create a new function dfdx that computes the gradient of f w.r.t. x:

from jax import grad

dfdx = grad(f)

Calling dfdx returns the derivative of f at the point where it is evaluated:

print(dfdx(1.))
4.0
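
Indeed, the derivative of f is 3x² + 4x - 3, which evaluates to 3 + 4 - 3 = 4 at x = 1.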

Composing transformations

Transformations can be composed:

from jax import jit

print(jit(grad(f))(1.))
4.0
print(grad(jit(f))(1.))
4.0

Forward and reverse modes

JAX offers other autodiff methods (a short example follows the list):

  • reverse-mode vector-Jacobian products: jax.vjp,
  • forward-mode Jacobian-vector products: jax.jvp.
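
As a minimal sketch reusing the function f defined above: jax.jvp evaluates f and pushes a tangent forward through it, while jax.vjp evaluates f and returns a function that pulls a cotangent back. For a scalar function of one variable, both reduce to the ordinary derivative:

from jax import jvp, vjp

# Forward mode: evaluate f at x = 1.0 and push the tangent 1.0 through it
y, tangent = jvp(f, (1.0,), (1.0,))
print(y, tangent)
8.0 4.0

# Reverse mode: evaluate f at x = 1.0 and get a vector-Jacobian product function
y, f_vjp = vjp(f, 1.0)
(cotangent,) = f_vjp(1.0)
print(y, cotangent)
8.0 4.0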

Higher-order differentiation

With a single variable, calls to grad can be nested:

d2fdx = grad(dfdx)   # function to compute 2nd order derivatives
d3fdx = grad(d2fdx)  # function to compute 3rd order derivatives
...
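
For example, evaluating these at x = 1 (the second derivative of f is 6x + 4 and the third is the constant 6):

print(d2fdx(1.))
10.0
print(d3fdx(1.))
6.0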

With several variables, you have to use the following functions (see the sketch after the list):

  • jax.jacfwd for forward-mode,
  • jax.jacrev for reverse-mode.
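
As a minimal sketch with a hypothetical two-variable function g (not part of the example above), either function computes the full gradient of a scalar function, and composing them gives the Hessian:

from jax import jacfwd, jacrev
import jax.numpy as jnp

g = lambda x: x[0]**2 + 3 * x[0] * x[1]   # scalar function of two variables

x = jnp.array([1.0, 2.0])

print(jacfwd(g)(x))           # gradient computed in forward mode: [8. 3.]
print(jacrev(g)(x))           # gradient computed in reverse mode: [8. 3.]
print(jacfwd(jacrev(g))(x))   # Hessian: [[2. 3.] [3. 0.]]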