Training#
Training is a familiar concept across machine learning methods. Training aims to teach the model about the data it has been shown so that it can predict outcomes it has not yet seen. Here, we will look at training a simple multilayer perceptron to perform the XOR operation.
Pavlov’s Dog
We can use the famous experiment of Pavlov’s dog to explain the training process. Every time Pavlov rang a bell, he would give his dog a treat. Over time, the dog began to associate the bell with the treat and would therefore salivate when it heard the bell. This training leads to changes in the processing within the cell bodies of the dog’s biological neurons and is directly analogous to the training we will perform on the artificial neural network.
Backpropagation#
A popular training algorithm for neural networks is known as backpropagation. Let’s look at using this to train the XOR multilayer perceptron above, which has three layers, i.e., \(M=3\), where the output layer is the \(M\)th layer and \(y_{\textrm{pred},i}\) is the final prediction. We define some loss function to use in the optimisation; here we use the mean-squared error (MSE),
\[ \textrm{MSE} = \frac{1}{N} \sum_{i=1}^{N} \left(y_i - y_{\textrm{pred},i}\right)^2, \]
where \(N\) is the number of outputs (for the XOR, this is 1), and \(y_i\) is the true value (from our truth table).
We implement this with jax, for reasons that will become clear when we come to compute gradients.
import jax.numpy as jnp


def mse(y, y_pred):
    """
    Mean Squared Error
    :param y: actual values
    :param y_pred: predicted values
    :return: mean squared error
    """
    return jnp.mean(jnp.square(y - y_pred))
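As a quick check (an addition to the original walkthrough, with toy values chosen purely for illustration), the loss is small when the predictions are close to the target values and grows as they diverge.

# Toy values for illustration only: close predictions give a small loss.
print(mse(jnp.array([1.0, 0.0]), jnp.array([0.9, 0.1])))  # ≈ 0.01
print(mse(jnp.array([1.0, 0.0]), jnp.array([0.5, 0.5])))  # ≈ 0.25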
For an activation function, instead of using a Heaviside step function, we will use the smoother (and, crucially, differentiable) logistic function.
def logistic(z):
    """
    Compute the logistic function
    :param z: input
    :return: output of the logistic function
    """
    return 1 / (1 + jnp.exp(-z))
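As a quick (added) check, the logistic function behaves like a smoothed Heaviside step: large negative inputs map close to 0, zero maps to exactly 0.5, and large positive inputs map close to 1.

# Evaluate the logistic function at a few representative points.
print(logistic(jnp.array([-10.0, 0.0, 10.0])))  # values close to 0, 0.5, and 1 respectively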
The backpropagation process uses the gradient descent algorithm we met previously, propagating the error in the loss function back through the layers. We can write the forward propagation as follows.
def forward_propagation(layer_one_weights, layer_one_biases,
                        layer_two_weights, layer_two_biases, input):
    """
    The forward pass of the neural network
    :param layer_one_weights: weights of the first layer
    :param layer_one_biases: biases of the first layer
    :param layer_two_weights: weights of the second layer
    :param layer_two_biases: biases of the second layer
    :param input: input data
    :return: predicted values
    """
    h = logistic(jnp.dot(input, layer_one_weights) + layer_one_biases)
    y_pred = logistic(jnp.dot(h, layer_two_weights) + layer_two_biases)
    return y_pred
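In matrix form, writing \(x\) for the input, \(\sigma\) for the logistic function above, and \(W_1, b_1, W_2, b_2\) for the layer weights and biases, this forward pass computes
\[ h = \sigma\left(x W_1 + b_1\right), \qquad y_{\textrm{pred}} = \sigma\left(h W_2 + b_2\right). \]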
We want to compute the gradient of the loss function, evaluated on the predictions of the forward pass, with respect to each of the weights and biases.
Four objects must be optimised, so we include these as argnums in the grad call.
from jax import grad


def to_optimise(layer_one_weights, layer_one_biases, layer_two_weights, layer_two_biases, input, y):
    """
    The function to be optimised.
    :param same as forward_propagation:
    :return: the mean squared error
    """
    y_pred = forward_propagation(layer_one_weights, layer_one_biases,
                                 layer_two_weights, layer_two_biases, input)
    return mse(y, y_pred)


grad_fn = grad(to_optimise, argnums=(0, 1, 2, 3))
We can start with randomly initialised weights and biases.
import jax.random as random

key = random.PRNGKey(0)
key1, key2 = random.split(key)
# Layer one: two inputs to two hidden nodes
W1 = random.normal(key1, (2, 2))
b1 = jnp.zeros((2,))
# Layer two: two hidden nodes to one output
W2 = random.normal(key2, (2, 1))
b2 = jnp.zeros((1,))
From the truth table for the XOR operation, we can get the following true inputs and outputs.
x = jnp.array([[0, 0],
               [0, 1],
               [1, 0],
               [1, 1]])
y = jnp.array([[0],
               [1],
               [1],
               [0]])
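Before training, we can optionally sanity-check the gradient function (this check is an addition to the original walkthrough): grad with argnums=(0, 1, 2, 3) returns a tuple of four gradients, one per selected argument, each with the same shape as the corresponding parameter.

# One evaluation of the gradient at the initial parameters.
grads = grad_fn(W1, b1, W2, b2, x, y)
print([g.shape for g in grads])  # [(2, 2), (2,), (2, 1), (1,)]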
Finally, we can run the gradient descent algorithm with a learning rate of 0.5 for 5000 epochs (in optimisation, these epochs would simply be called iterations).
learning_rate = 0.5

for epoch in range(5000):
    grads = grad_fn(W1, b1, W2, b2, x, y)
    W1 -= learning_rate * grads[0]
    b1 -= learning_rate * grads[1]
    W2 -= learning_rate * grads[2]
    b2 -= learning_rate * grads[3]
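Each pass through this loop applies the standard gradient descent update \(\theta \leftarrow \theta - \eta \, \nabla_{\theta}\,\textrm{MSE}\) to each of the four parameter arrays, with learning rate \(\eta = 0.5\).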
After 5000 epochs, does our trained network reproduce the XOR operation?
y_pred = forward_propagation(W1, b1, W2, b2, x)
print("Predictions:\n", y_pred)
Predictions:
[[0.04271777]
[0.95220965]
[0.9595658 ]
[0.03770101]]
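We can also round the raw outputs (an added step, using jnp.round) to compare them directly with the targets.

# Round the raw network outputs to 0 or 1 for comparison with the truth table.
print(jnp.round(y_pred))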
If we round these predictions, we can see that the results match the truth table. This indicates that we have successfully trained the neural network. Let’s have a look at the weights and biases that were trained.
W1, b1, W2, b2
(Array([[ 4.9126444,  5.6503496],
        [-5.1137004, -5.5375805]], dtype=float32),
 Array([-2.6823142,  2.8359053], dtype=float32),
 Array([[ 7.974698],
        [-7.495365]], dtype=float32),
 Array([3.4599495], dtype=float32))
It is important to highlight that this problem is degenerate, in that there is more than one set of weights and biases for this multilayer perceptron that will solve the XOR operation.
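As a sketch of this degeneracy (an addition to the original text; the *b-suffixed variable names are introduced here, and there is no guarantee that every seed converges within 5000 epochs), we could retrain from a different random initialisation and compare the resulting weights.

# Re-initialise with a different seed and repeat the same training loop.
key1b, key2b = random.split(random.PRNGKey(1))
W1b, b1b = random.normal(key1b, (2, 2)), jnp.zeros((2,))
W2b, b2b = random.normal(key2b, (2, 1)), jnp.zeros((1,))

for epoch in range(5000):
    grads = grad_fn(W1b, b1b, W2b, b2b, x, y)
    W1b -= learning_rate * grads[0]
    b1b -= learning_rate * grads[1]
    W2b -= learning_rate * grads[2]
    b2b -= learning_rate * grads[3]

# If training converged, the rounded predictions again match the truth table,
# even though the weights themselves differ from those found above.
print(jnp.round(forward_propagation(W1b, b1b, W2b, b2b, x)))
print(W1b, b1b, W2b, b2b)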