Autodifferentiation#
Autodifferentiation (automatic differentiation) aims to calculate the derivative of a function efficiently and accurately. It is a computational technique with extremely valuable applications in optimisation and machine learning. For example, in the context of neural networks, autodifferentiation enables backpropagation by computing gradients of loss functions (we will look at this in detail later).
Autodifferentiation is distinct from symbolic differentiation, which we saw earlier with sympy, and from numerical differentiation.
Both present difficulties computationally: the former requires a mathematical expression to be converted to a computer program, which can be inefficient, while the latter suffers from round-off errors.
These problems become worse as higher-order derivatives are computed.
Unlike symbolic differentiation, the only aim of autodifferentiation is to obtain a numerical value for the derivative of the function.
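To make the round-off problem concrete, here is a minimal sketch using NumPy (the helper name forward_difference is purely illustrative) of a forward finite-difference estimate of the derivative of \(x^2\) at \(x=1\); beyond a certain point, shrinking the step size makes the estimate worse, not better, because of floating-point cancellation.
import numpy as np
def forward_difference(f, x, h):
    """
    Forward finite-difference estimate of the derivative of f at x.

    :param f: function to differentiate.
    :param x: point at which to estimate the derivative.
    :param h: step size.
    :return: estimate of the derivative.
    """
    return (f(x + h) - f(x)) / h
# The exact derivative of x ** 2 at x = 1 is 2; very small steps suffer
# from round-off (cancellation) error in floating-point arithmetic.
for h in [1e-2, 1e-8, 1e-14]:
    print(h, forward_difference(np.square, 1.0, h))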
Further Reading
If you want to learn about the rather complex mathematics that enables autodifferentiation, this video is an interesting place to start.
Autodifferentiation is exact to machine precision and highly efficient, though complex to implement. However, a range of exciting Python packages have recently made it straightforward to use. For example, JAX is a high-performance array computing library (a bit like a fancier NumPy) that can perform autodifferentiation of functions.
JAX for Autodifferentiation#
The JAX library is constructed in a similar fashion to NumPy, so we can reproduce the same example as shown previously.
import jax.numpy as jnp
from jax import grad
# grad takes a function and returns a new function that evaluates its derivative
grad_square = grad(jnp.square)
grad_square(0.5)
Array(1., dtype=float32, weak_type=True)
We observed previously that the gradient of \(x^2\) is \(2x\), which at \(x=0.5\) gives \(1\); we can see that JAX reproduces that result.
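As a small illustrative check (the chosen points are arbitrary), we can compare grad_square against the analytic derivative \(2x\) at a few values:
# Compare the autodifferentiated gradient with the analytic result 2x.
for x in [0.5, 1.0, 2.0]:
    print(grad_square(x), 2 * x)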
You will notice that the input to the grad function is itself a function.
Therefore, to compute the gradient of a more complex expression, you need to create a function that performs the whole calculation and pass this to grad.
For example, to find the gradient of \(x^2 + \log(x)\) at \(x=0.1\),
def f(x):
    """
    An example function for JAX.

    :param x: input to the function.
    :return: output of the function.
    """
    return jnp.square(x) + jnp.log(x)

grad_complex = grad(f)
grad_complex(0.1)
Array(10.2, dtype=float32, weak_type=True)
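As a quick sanity check, the analytic derivative of \(x^2 + \log(x)\) is \(2x + 1/x\), which at \(x=0.1\) gives \(0.2 + 10 = 10.2\), in agreement with the value returned by JAX.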
It is also possible to calculate higher-order derivatives with JAX by nesting calls to the grad function.
# Applying grad twice gives the second derivative
grad_complex_two = grad(grad(f))
grad_complex_two(0.1)
Array(-97.99999, dtype=float32, weak_type=True)
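For reference, the analytic second derivative of \(x^2 + \log(x)\) is \(2 - 1/x^2\), which at \(x=0.1\) equals \(-98\); the small discrepancy in the printed value is due to single-precision (float32) arithmetic.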
Since JAX is an array computation library, it is possible to evaluate the derivative at many values of \(x\) with a single call.
However, grad itself requires a function that returns a single scalar, so we must use a vectorising map, vmap, which maps a function over its arguments.
from jax import vmap
import matplotlib.pyplot as plt

# Evaluate x ** 2 and its derivative over a grid of x values.
x = jnp.linspace(0, 1, 100)
y = jnp.square(x)
# vmap applies the scalar gradient function elementwise across the array.
y_dash = vmap(grad(jnp.square))(x)

fig, ax = plt.subplots(1, 2, figsize=(8, 4))
ax[0].plot(x, y)
ax[0].set_xlabel('x')
ax[0].set_ylabel('y')
ax[1].plot(x, y_dash)
ax[1].set_xlabel('x')
ax[1].set_ylabel('dy/dx')
plt.tight_layout()
plt.show()
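The same pattern works for the user-defined function above. As an illustrative extension (restricting \(x\) to positive values so that \(\log(x)\) is defined), the derivative of f can be evaluated over a whole grid in one call:
# Evaluate the derivative of f over a grid of positive x values.
x_pos = jnp.linspace(0.1, 1.0, 100)
f_dash = vmap(grad(f))(x_pos)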
