
Neural nets are a very powerful class of methods that have become popular in fields like computer vision and natural language processing, where coming up with good features can be challenging.

While there’s a rich mathematical foundation underlying neural nets, in this class we’ll focus on one of the big computational ideas at the heart of most neural net implementations: backpropagation and automatic differentiation. While these ideas were initially conceived for neural nets, they’re now used in many other settings too: for example, libraries like PyMC use automatic differentiation to do efficient Bayesian inference.

In general, automatic differentiation and backpropagation are useful for any problem where the solution involves computing gradients!

A feed-forward neural network

As we’ve already seen, linear regression is a simple but powerful model: to predict a value $y$ from a vector of features $x = (x_1, \ldots, x_k)$, linear regression uses the following:

$$y = Wx + b$$

Here, $W$ is a vector of coefficients, sometimes also called weights, and $b$ is a scalar that we call the intercept or bias. As we saw in the previous section, linear models can fail when the relationship between $x$ and $y$ is nonlinear. We also saw that if we want to model complex, nonlinear interactions while still using linear models, we need to define more complex features.

Motivated by this, what if we tried using another layer of linear regression that could compute features for us? It might look something like this:

$$y = W_2(\overbrace{W_1 x + b_1}^\text{features}) + b_2$$

Here, $W_1$ is now an $m \times k$ matrix of weights, and the result of the matrix-vector multiplication and addition $W_1 x + b_1$ is an $m$-dimensional vector of features. We then apply linear regression with those features, using the weights in the vector $W_2$ and intercept/bias in the scalar $b_2$, to obtain $y$.

Unfortunately, this doesn’t work because it reduces to a single layer of linear regression. Applying a bit of algebra, we can simplify the above equation to $y = (W_2 W_1)x + (W_2 b_1 + b_2)$, which is just linear regression written in an unnecessarily complicated way.
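We can also convince ourselves of this numerically. Here’s a quick sketch in NumPy (the shapes and random values below are made up purely for illustration):

import numpy as np

# Check numerically that two stacked linear layers collapse into one
rng = np.random.default_rng(0)
k, m = 3, 4                                    # made-up sizes: k input features, m "hidden" features
W1, b1 = rng.normal(size=(m, k)), rng.normal(size=m)
W2, b2 = rng.normal(size=m), rng.normal()
x = rng.normal(size=k)

two_layers = W2 @ (W1 @ x + b1) + b2           # "two layers" of linear regression
one_layer = (W2 @ W1) @ x + (W2 @ b1 + b2)     # the simplified single layer
print(np.isclose(two_layers, one_layer))       # True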

In order to prevent the simplification to linear regression, we could apply a nonlinear function $f$ as part of computing the features:

$$y = W_2 \overbrace{f(W_1 x + b_1)}^\text{features} + b_2$$

This is now the simplest possible neural network, which we call a feed-forward or fully connected network with one hidden layer (the so-called “hidden layer” is the result of the computation $f(W_1 x + b_1)$).

The nonlinear function $f$ can be anything from a sigmoid or logistic function to the ReLU (rectified linear unit) function, $f(z) = \max(0, z)$.
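To make the forward computation concrete, here’s a minimal sketch of a one-hidden-layer network in NumPy, using ReLU as the nonlinearity. The shapes and the random initialization are made up for illustration; in practice, the weights would be learned.

import numpy as np

def relu(z):
    # ReLU nonlinearity: f(z) = max(0, z), applied elementwise
    return np.maximum(0, z)

rng = np.random.default_rng(0)
k, m = 3, 4                                           # made-up sizes: k inputs, m hidden units
W1, b1 = rng.normal(size=(m, k)), rng.normal(size=m)  # first (hidden) layer parameters
W2, b2 = rng.normal(size=m), rng.normal()             # second (output) layer parameters

x = rng.normal(size=k)            # an input vector
features = relu(W1 @ x + b1)      # the hidden layer: f(W1 x + b1)
y_hat = W2 @ features + b2        # linear regression on the computed features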

In order to fit a linear regression model, we had to estimate good coefficients; in probabilistic terms, we did this using maximum likelihood estimation. In order to fit a neural network, we have to estimate good weights $W_1, W_2, \ldots$ and biases $b_1, b_2, \ldots$.

To make our notation a little simpler, we’ll use $\theta$ to denote all our parameters: $\theta = (W_1, W_2, b_1, b_2)$. In order to find the best values of $\theta$, we’ll define a loss function $\ell(\theta, y)$ and then use stochastic gradient descent to minimize it.


Empirical risk minimization

We start by choosing a loss function. In general, the choice of loss function depends on the problem we’re solving, but two common choices are the squared error loss (also known as the $\ell_2$ loss) and the binary cross-entropy (BCE) loss. Let’s consider the $\ell_2$ loss:

$$\begin{align*} \ell(\theta, y) &= (y - \hat{y})^2 \\ &= \left(y - \left[W_2 f(W_1 x + b_1) + b_2\right]\right)^2 \end{align*}$$

We’ll minimize the average loss:

$$R(\theta) = \frac{1}{n} \sum_{i=1}^n \left(y_i - \left[W_2 f(W_1 x_i + b_1) + b_2\right]\right)^2$$

Here, we’re averaging over the empirical distribution of the data in our training set, which makes this a frequentist risk. The process of minimizing this loss is often referred to as empirical risk minimization.
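As a small concrete example, here’s what the empirical risk looks like in code for the squared error loss; the labels and predictions below are made-up numbers, just to show the computation.

import numpy as np

def empirical_risk(y, y_hat):
    # Average squared-error loss over the training set: R(theta)
    return np.mean((y - y_hat) ** 2)

y = np.array([1.0, 2.0, 3.0])        # observed values (made up)
y_hat = np.array([1.5, 1.5, 2.0])    # model predictions (made up)
print(empirical_risk(y, y_hat))      # (0.25 + 0.25 + 1.0) / 3 = 0.5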

Review: Stochastic gradient descent

For more on stochastic gradient descent, you may find it helpful to review the chapter on gradient descent in the Data 100 textbook.

(Stochastic) gradient descent is a powerful tool that lets us find the minimum of any function, as long as we can compute its gradient. Recall that a gradient is a vector of partial derivatives with respect to each parameter. In the example above, our gradient would be

$$\nabla_\theta \ell(\theta, y) = \begin{pmatrix} \frac{\partial \ell}{\partial W_1}(\theta, y) \\ \frac{\partial \ell}{\partial W_2}(\theta, y) \\ \frac{\partial \ell}{\partial b_1}(\theta, y) \\ \frac{\partial \ell}{\partial b_2}(\theta, y) \end{pmatrix}$$

Gradient descent is an optimization procedure where we start with an initial estimate for our parameters, $\theta^{(0)}$. We then repeatedly apply the following update to get $\theta^{(1)}, \theta^{(2)}, \ldots$:

$$\theta^{(t+1)} = \theta^{(t)} - \alpha \nabla_\theta \ell(\theta^{(t)}, y)$$

Here, $\alpha$ is a learning rate (typically a small positive number, also sometimes called a step size), and $y$ is the data we observed. In stochastic gradient descent, instead of computing the gradient using all of our data, we divide our data into batches, and at each iteration we compute the gradient using just one batch, cycling through the batches in sequence.

This means that we must compute the gradient at every single iteration. So, anything we can do to compute gradients faster and more efficiently will make our entire optimization process faster and more efficient.
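To see the update rule in action, here’s a sketch of stochastic gradient descent on a simple stand-in problem: 1-D linear regression with the squared-error loss, where the gradient is easy to write by hand. The synthetic data, learning rate, and batch size below are arbitrary choices for illustration; for a neural network, only the gradient computation would change.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=100)
y = 2 * x + 1 + 0.1 * rng.normal(size=100)   # synthetic data: true w = 2, b = 1

theta = np.zeros(2)     # theta = (w, b), the initial estimate theta^(0)
alpha = 0.1             # learning rate
batch_size = 10

for t in range(200):
    # Compute the gradient using just one batch of data
    idx = rng.choice(len(x), size=batch_size, replace=False)
    xb, yb = x[idx], y[idx]
    y_hat = theta[0] * xb + theta[1]
    grad = np.array([
        np.mean(-2 * (yb - y_hat) * xb),     # d(loss)/dw
        np.mean(-2 * (yb - y_hat)),          # d(loss)/db
    ])
    # The gradient descent update
    theta = theta - alpha * grad

print(theta)   # should be close to (2, 1)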


Gradients and Backpropagation

Backpropagation is an algorithm for efficiently computing gradients by applying the chain rule in an order designed to avoid redundant computation. To see how it works, we’ll consider a very simple loss function of three variables. We’ll compute the gradient manually using the chain rule, and then we’ll see how backpropagation can do the same computation more efficiently.

Computing gradients with the chain rule

Consider a very simple loss function involving three variables, $a$, $b$, and $c$:

$$L(a, b, c) = (a + 3b)c^2$$

We can compute the partial derivatives with respect to $a$, $b$, and $c$. To make it a little clearer when and where we’re using the chain rule, let $q = a + 3b$ and $r = c^2$, so that $L = qr$. The partial derivatives are:

$$\begin{align*} \frac{\partial L}{\partial a} &= \frac{\partial L}{\partial q} \cdot \frac{\partial q}{\partial a} = c^2 \cdot 1 \\ \frac{\partial L}{\partial b} &= \frac{\partial L}{\partial q} \cdot \frac{\partial q}{\partial b} = c^2 \cdot 3 \\ \frac{\partial L}{\partial c} &= \frac{\partial L}{\partial r} \cdot \frac{\partial r}{\partial c} = (a + 3b) \cdot 2c \end{align*}$$

Even in this simple example, we can see that there was some redundant work involved: in doing this computation, we needed to compute $\frac{\partial L}{\partial q} = c^2$ twice. In a more complicated expression, especially one with many nested function calls, the redundant work would become much worse. Backpropagation gives us a way to compute these gradients more efficiently.
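If you want to double-check the partial derivatives above, one option is to verify them symbolically, for example with SymPy (assuming it’s available):

import sympy as sp

a, b, c = sp.symbols('a b c')
L = (a + 3*b) * c**2
print(sp.diff(L, a))   # dL/da = c**2
print(sp.diff(L, b))   # dL/db = 3*c**2
print(sp.diff(L, c))   # dL/dc = 2*c*(a + 3*b)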


Backpropagation: an example

Instead of representing the computation as an algebraic expression, let’s express it as a computation graph. This is a visual representation of the mathematical expression.

Given specific numerical values for $a$, $b$, and $c$, backpropagation is an efficient way to compute the loss and the gradient (i.e., all the partial derivatives), with no redundant computation.

We start by computing the loss itself. This involves just doing the computations specified by the graph, denoted by the blue numbers above the arrows. For concreteness, we’ll use $a = 2$, $b = 3$, and $c = 4$, which gives $3b = 9$, $q = a + 3b = 11$, $r = c^2 = 16$, and $L = qr = 176$.

Next, let’s notice that when we did the calculations in the previous section to find the gradient, most of our expressions started at the loss and then, using the chain rule, computed partial derivatives with respect to things like $q$ and $r$. Let’s try to write these partial derivatives on the graph, and see if we can use them to keep working backwards.

  1. First, we’ll start with the easiest derivative, the derivative of the loss with respect to itself: $\frac{\partial L}{\partial L}$. This is just 1!

  2. Next, we’ll compute the derivative of the loss with respect to $q$ (top right branch of the graph). How did we get from $q$ to $L$? We multiplied by 16 (that is, for these specific numbers, $L = 16q$). So, the partial derivative of $L$ with respect to $q$ is just 16.

  3. Continuing along the top part of the graph, now we can compute the derivative with respect to $a$. How did we get from $a$ to $q$? We added 9 (that is, for these specific numbers, $q = a + 9$). So, the partial derivative of $q$ with respect to $a$ is just 1. But we’re trying to compute $\frac{\partial L}{\partial a}$, not $\frac{\partial q}{\partial a}$. So, we’ll take advantage of the chain rule and multiply by the “derivative so far”: that’s just $\frac{\partial L}{\partial q} = 16$. So, our answer is $\frac{\partial L}{\partial a} = 1 \cdot 16 = 16$.

  4. Next, we’ll look at the $b$ branch of the graph. From similar reasoning to above, the derivative at the output of the “multiply by three” block is just 16. How do we use that to compute the derivative with respect to $b$? To get from $b$ to that value, we multiplied by 3. So, the corresponding term in the chain rule is 3. We multiply that with what we have so far, 16, to get 48.

  5. Finally, let’s look at the $c$ branch at the bottom of the graph. We’ll start by computing the derivative with respect to $r$. Similar to step 2 above, we multiplied $r$ by 11 to get $L$, so that means that the derivative is 11.

  6. All we have left is to go through the “square” block. The derivative of its output with respect to its input is two times the input (in other words, $\frac{\partial r}{\partial c} = 2c$). Since the input was 4, that means our new term is 8, and our overall derivative on this branch is $11 \cdot 8 = 88$.

Now we’re done! We’ve computed all three derivatives: $\frac{\partial L}{\partial a} = 16$, $\frac{\partial L}{\partial b} = 48$, and $\frac{\partial L}{\partial c} = 88$. In the completed graph, these intermediate and final backpropagation results appear in red below the arrows.
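Here is the same forward and backward pass written out by hand in plain Python (no autodiff library), just to make each step of the trace above explicit:

a, b, c = 2.0, 3.0, 4.0

# Forward pass: compute every intermediate value in the graph
q = a + 3 * b        # q = 11
r = c ** 2           # r = 16
L = q * r            # L = 176

# Backward pass: walk from L back toward the inputs, multiplying by
# each local derivative along the way (the chain rule)
dL_dL = 1.0
dL_dq = dL_dL * r        # L = q * r, so dL/dq = r = 16
dL_dr = dL_dL * q        # dL/dr = q = 11
dL_da = dL_dq * 1.0      # q = a + 3b, so dq/da = 1  ->  16
dL_db = dL_dq * 3.0      # dq/db = 3                 ->  48
dL_dc = dL_dr * 2 * c    # r = c^2, so dr/dc = 2c    ->  88

print(dL_da, dL_db, dL_dc)   # 16.0 48.0 88.0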


Backpropagation

In general, all we need to successfully run backpropagation is the ability to differentiate every mathematical building block of our loss (don’t forget, the loss depends on the prediction). For every building block, we need to know how to compute the forward pass (the mathematical operation, like addition, multiplication, squaring, etc.) and the backward pass (multiplying by the derivative).
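One way to organize this in code is to give every building block both a forward and a backward method. The sketch below is a toy illustration of that idea (the class and method names are ours, chosen for this example, not how any particular library does it):

class Multiply:
    # Building block for x * y
    def forward(self, x, y):
        self.x, self.y = x, y        # remember the inputs for the backward pass
        return x * y

    def backward(self, grad_out):
        # Chain rule: incoming gradient times each local derivative
        return grad_out * self.y, grad_out * self.x

class Square:
    # Building block for x ** 2
    def forward(self, x):
        self.x = x
        return x ** 2

    def backward(self, grad_out):
        return grad_out * 2 * self.x

# Recompute part of our example: r = c^2, then L = q * r
square, multiply = Square(), Multiply()
q, c = 11.0, 4.0
r = square.forward(c)                    # 16.0
L = multiply.forward(q, r)               # 176.0
dL_dq, dL_dr = multiply.backward(1.0)    # 16.0, 11.0
dL_dc = square.backward(dL_dr)           # 88.0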

(Optional) Backpropagation in PyTorch

Let’s see what this looks like in code using PyTorch. We start by defining tensors for a, b, and c: tensors are the basic data type of PyTorch, much like arrays in numpy.

import torch

# Torch tensors are like numpy arrays
a = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)
c = torch.tensor(4., requires_grad=True)

We then define tensors for q and r. Note that each one contains both the value computed as well as the necessary operation to compute the gradient in the backward pass:

q = a + 3 * b
r = c ** 2
print(q, r)
tensor(11., grad_fn=<AddBackward0>) tensor(16., grad_fn=<PowBackward0>)

Finally, we define our loss:

L = q * r
L
tensor(176., grad_fn=<MulBackward0>)

Now that we’ve computed the loss, we can have PyTorch run backpropagation and compute all the derivatives with the .backward() method:

L.backward()

Let’s look at the gradient for each input:

print(a.grad, b.grad, c.grad)
tensor(16.) tensor(48.) tensor(88.)

We can see that the results match up precisely with what we computed above manually!
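To connect this back to gradient descent: once .backward() has filled in the .grad attributes, a single update step just subtracts the learning rate times each gradient. Here’s a minimal sketch (the learning rate is an arbitrary choice for illustration):

alpha = 0.01   # learning rate, chosen arbitrarily for this example

with torch.no_grad():      # update the parameter values without tracking gradients
    a -= alpha * a.grad
    b -= alpha * b.grad
    c -= alpha * c.grad

# Gradients accumulate by default, so reset them before the next backward pass
a.grad.zero_()
b.grad.zero_()
c.grad.zero_()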
