Author

Marie-Hélène Burle

One of the transformations JAX can apply to array computations is automatic differentiation, i.e. the calculation of gradients. This is what backpropagation through deep neural networks relies on.

Consider the function `f`:

```python
f = lambda x: x**3 + 2*x**2 - 3*x + 8
```

We can create a new function `dfdx` that computes the gradient of `f` w.r.t. `x`:

```python
from jax import grad

dfdx = grad(f)
```

`dfdx` returns the derivative of `f` at the point where it is evaluated:

```python
print(dfdx(1.))
```
```
4.0
```
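As a sanity check, `dfdx` can be compared with the derivative worked out by hand, $f'(x) = 3x^2 + 4x - 3$, at a few points. The name `f_prime` is a hypothetical helper added for this comparison and is not part of JAX. The inputs are written as floats (`1.` rather than `1`) because `grad` only differentiates with respect to floating-point (or complex) inputs and raises a `TypeError` for integers.

```python
from jax import grad

# Same polynomial as above
f = lambda x: x**3 + 2*x**2 - 3*x + 8

# Hypothetical helper: derivative computed by hand, used only as a reference
f_prime = lambda x: 3*x**2 + 4*x - 3

dfdx = grad(f)

# Float inputs are required: grad rejects integer inputs
for x in [-2., 0., 1., 3.]:
    print(x, dfdx(x), f_prime(x))
```

The two columns should agree up to float32 rounding.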

## Composing transformations

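Transformations can be composed. For instance, `grad` can be combined with `jit` (just-in-time compilation) in either order, and both give the same result:

```python
from jax import jit

print(jit(grad(f))(1.))
```
```
4.0
```
```python
print(grad(jit(f))(1.))
```
```
4.0
```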
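In practice, a composed transformation is usually bound to a name so that it is reused. `jit` traces and compiles the function on its first call, then reuses the compiled version for later calls with the same input shapes and dtypes. This is a minimal sketch; the names `fast_dfdx` and `dfdx_of_jitted` are illustrative only.

```python
from jax import grad, jit

f = lambda x: x**3 + 2*x**2 - 3*x + 8

# Compile the gradient function once, reuse it afterwards
fast_dfdx = jit(grad(f))

# Differentiate a function that is itself jit-compiled
dfdx_of_jitted = grad(jit(f))

# Both orders of composition compute the same derivative
for x in [-2., 0., 1., 3.]:
    print(x, fast_dfdx(x), dfdx_of_jitted(x))
```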

## Forward and reverse modes

Besides `grad` (which uses reverse mode), JAX exposes lower-level autodiff functions:

• reverse-mode vector-Jacobian products: `jax.vjp`,
• forward-mode Jacobian-vector products: `jax.jvp`.
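For the scalar function `f` above, both functions recover $f(1) = 8$ and $f'(1) = 4$, each in its own way:

- `jax.jvp` pushes a tangent vector forward alongside the input.
- `jax.vjp` returns the output together with a function that pulls a cotangent back to the input.

This is a minimal sketch using a tangent and cotangent of 1, which reduces both products to the plain derivative.

```python
import jax
import jax.numpy as jnp

f = lambda x: x**3 + 2*x**2 - 3*x + 8

# Forward mode: primals and tangents are passed as tuples
y_fwd, dy_fwd = jax.jvp(f, (1.,), (1.,))
print(y_fwd, dy_fwd)  # f(1.) = 8.0, f'(1.) * 1. = 4.0

# Reverse mode: vjp returns f(x) and a pullback function
y_rev, f_vjp = jax.vjp(f, 1.)
# The pullback returns one cotangent per primal input, hence the tuple unpacking
(x_bar,) = f_vjp(jnp.ones_like(y_rev))
print(y_rev, x_bar)   # f(1.) = 8.0, 1. * f'(1.) = 4.0
```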

## Higher-order differentiation

With a single variable, calls to `grad` can be nested:

```python
d2fdx = grad(dfdx)   # function to compute 2nd order derivatives
d3fdx = grad(d2fdx)  # function to compute 3rd order derivatives
# and so on
```
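For the cubic `f`, the successive derivatives are $3x^2 + 4x - 3$, $6x + 4$, $6$ and finally $0$. At $x = 1$ this gives 4, 10, 6 and 0. The sketch below checks the nesting against those hand-computed values; the name `d4fdx` simply extends the naming pattern above.

```python
from jax import grad

f = lambda x: x**3 + 2*x**2 - 3*x + 8

dfdx = grad(f)        # f'(x)    = 3x^2 + 4x - 3
d2fdx = grad(dfdx)    # f''(x)   = 6x + 4
d3fdx = grad(d2fdx)   # f'''(x)  = 6
d4fdx = grad(d3fdx)   # f''''(x) = 0

print(dfdx(1.), d2fdx(1.), d3fdx(1.), d4fdx(1.))
```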

With several variables, the derivatives form Jacobian matrices (and Hessian matrices at second order), which are computed with:

• `jax.jacfwd` for forward-mode,
• `jax.jacrev` for reverse-mode.
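As an illustration, here is a hypothetical function of two variables packed into one array, $g(x, y) = x^2 y + 3y^2$. Because `g` returns a scalar, `jax.jacrev(g)` is its gradient. Applying `jax.jacfwd` on top of that ("forward-over-reverse") gives the Hessian. The JAX documentation describes `jax.hessian` as exactly this composition.

As a rule of thumb, `jacrev` suits functions with more inputs than outputs, and `jacfwd` suits the reverse. This is a sketch under those assumptions, checked against the hand-derived gradient $(2xy,\ x^2 + 6y)$ and Hessian $\begin{pmatrix} 2y & 2x \\ 2x & 6 \end{pmatrix}$.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: g(x, y) = x^2 * y + 3 * y^2, with v = (x, y)
def g(v):
    x, y = v[0], v[1]
    return x**2 * y + 3 * y**2

v = jnp.array([1., 2.])

# Scalar output, so the Jacobian is the gradient: [2xy, x^2 + 6y] -> [4., 13.]
grad_g = jax.jacrev(g)(v)

# Jacobian of the gradient, i.e. the Hessian: [[2y, 2x], [2x, 6]] -> [[4., 2.], [2., 6.]]
hess_g = jax.jacfwd(jax.jacrev(g))(v)

# Built-in shortcut for the same forward-over-reverse composition
hess_builtin = jax.hessian(g)(v)

print(grad_g)
print(hess_g)
```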