4
Backpropagation
One nice thing about the technique we used to figure out w1 and w2 for our turkey feather model is that, if you squint, you begin to see how we could use the same approach with more complicated models. A crucial part of the technique was calculating the partial derivatives of loss with respect to w1 and w2. We did it by writing out the equations and somewhat manually working out the derivatives.
Will this scale? When we create a model with a billion weights, which we’ll get to, we know we’ll do calculations with those billion weights to make predictions and compute loss. But how long will it take to then calculate the partial derivatives of loss with respect to each of those billion weights?
I don’t think the answer is obvious, but here’s an intuition. To calculate the partial derivative of w1, we treated w2 as a constant and did a whole bunch of calculations. If you count, the whole bunch was about the same number of calculations as it took to compute the loss in the first place. To calculate the partial derivative of w2, we treated w1 as a constant and did another whole bunch of calculations. So with a billion weights, you get the sense that calculating the partial derivative of each one could require a billion times as many calculations as it took to make the predictions and compute loss. That could be infeasible. And remember, we’re talking about calculations that need to be repeated for every single step (e.g. the rows in table 3.11) as we repeatedly adjust our weights.
Imagine I ask you to solve an intricate maze on a piece of paper. The maze has 100 entrances and a single finish, say a piece of cheese drawn in the middle. The final, optimal solution will be the same no matter where you start solving it from, but common sense will tell you to start from the center instead of arbitrarily picking one of the entrances.
There are many problems in computer science that work this way—approach them one way and the combinations get out of control, approach them backwards or by mixing forwards and backwards and they become tractable. As researchers in a variety of fields began to use computers to build and calculate models (e.g. economic models, early AI models) in the 1960s, they ran into the need to calculate many partial derivatives and recognized the efficiency of working backwards.
Amazingly, with backpropagation, the number of calculations to compute the partial derivatives of loss with respect to all of the weights will be about the same as to compute the loss itself. The key idea is that we can propagate the derivative of loss back through all of the original calculations rather than forward calculating the partial derivative of each weight. This will make more sense after I work through a few examples.
Forget about turkeys, feathers, and weights for a moment and go back to high school calculus. Let’s say you were asked to calculate the derivative of this function:
f(x) = (x²)³
You probably would have said: f(x) = x⁶, so the derivative of f(x) is 6x⁵. If x is 3, and we increase it a tiny bit, the output of the function will go up by 1458 times that tiny bit (talk about leverage!). As we did in chapter 3, let’s say our tiny bit is 0.01 and plug in real numbers:
f(3) = 729
f(3.01) should be around 729 + (1458)(.01) = 743.6
f(3.01) = (3.01)⁶ ≈ 743.7
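If you want to see this check run, here it is as a few lines of Python (nothing beyond ordinary arithmetic):

```python
# Numerical check of the derivative of f(x) = (x^2)^3 = x^6 at x = 3.
def f(x):
    return (x ** 2) ** 3

tiny = 0.01
slope = 6 * 3 ** 5            # the derivative 6x^5 evaluated at x = 3, i.e. 1458

print(f(3))                   # 729
print(f(3) + slope * tiny)    # about 743.58, the estimate using the derivative
print(f(3 + tiny))            # about 743.70, the true value
```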
Great. Now let’s think of this problem in a different way that at first will seem more convoluted. If you do not simplify f(x) = (x²)³ to f(x) = x⁶, the way you would do the calculation in your head or on paper would be to first square the input and then take the result and cube it. This is also how a computer would calculate it.
We now have a machine with two steps. (I’m using “step” here to mean something different from the steps in the table 3.11.) Let’s plug in 3 as our input:
The number 3 goes into step A and 9 comes out, 9 goes into step B and 729 comes out. The output 729 is what we calculated a few paragraphs ago and you probably feel like I’m making a big deal out of nothing. Here’s where it gets interesting. Pretend A and B have dials off of which you can read numbers.
The dial on B is set to 9 automatically by B’s input. That’s just the way this machine works and we have no choice in the matter. But what if we could stick our finger on that red needle and increase it a tiny bit? By how much would the 729 increase? The derivative of x³ is 3x². You can take my word if you don’t remember the rule from calculus. Evaluate the derivative at the input to box B (9) and we have 3(9²) = 3(81) = 243. So if we turn B’s dial a tiny bit, the 729 will increase by 243 times that tiny bit.
Now let’s look at A and forget about B. The dial on A is set to 3. If we turn it up by a tiny amount, by how much will A’s output increase? The derivative of x² is 2x. Plug in 3 and we get 2(3) = 6.
If we turn the dial up on A a tiny amount, A’s output, which is the input to B, goes up by 6 times the tiny amount. Each little increase in B’s dial causes B’s output to go up by 243 times as we calculated above. So overall, a tiny increase in A causes the overall output of 729 to increase by 6 × 243 = 1458 times the tiny amount. This matches our calculation at the start of the chapter.
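The two-dial reasoning above can be written out directly. This sketch builds the A-then-B machine and multiplies the two local slopes together:

```python
# Step A squares its input; step B cubes its input.
def step_a(x):
    return x ** 2

def step_b(x):
    return x ** 3

x = 3
a_out = step_a(x)                  # 9, the setting of B's dial
b_out = step_b(a_out)              # 729, the machine's final output

slope_a = 2 * x                    # derivative of x^2, evaluated at 3 -> 6
slope_b = 3 * a_out ** 2           # derivative of x^3, evaluated at 9 -> 243
overall_slope = slope_a * slope_b  # 6 * 243 = 1458
```

A tiny turn of A’s dial is amplified 6 times on the way into B and 243 more times on the way out, which is where the 1458 comes from.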
I’m now going to organize the calculation in a diagram. It may look a little strange at first. See if you can puzzle out why I arranged it this way.
In the forward calculation, the diagram where the arrows go to the right, we start with an initial input, do whatever calculations we want to do, and get a final output. In the backward calculation, we start on the right and work our way back to the initial input so we can see how changing the initial input will change the final output. In the forward calculation, we compute whatever final output the machine is supposed to compute. In the backward calculation we’re computing something different: the derivative of that final output with respect to the input.
Let’s look at the box “B backward.” It takes its input and multiplies it by 243. I italicized “input” to emphasize that I’m talking about the input from the arrow to the right, not the input to box B in the forward calculation. Think of “B backward” as saying—some other box to the right of me already figured out by how many times a tiny increase going into it in the forward calculation will the final output increase; I don’t need to know the details; I just know that whatever that number is, if I multiply it by 243, that will say by how many times a tiny increase coming into me in the forward calculation will increase the overall output.
Let me now say the same thing a little more elegantly using the term derivative. Some other box to the right of me already knows the derivative with respect to itself of the final output; I don’t need to know the details; I just know that if I multiply that number by 243, that will be the derivative of the final output with respect to me.
So why does the diagram show a 1 coming in from the far right? Because there is no box to the right of B. B is the last step in the machine. An increase of a tiny amount in the output is an increase of 1 times that tiny amount in the output. I know that sounds silly, but it explains why our starting number for the backward calculation is always 1. Or think about it this way—imagine the final output is itself a box with a dial from which we read the output. As with other steps, we want to ask this question: If we force the dial to increase by a tiny amount, how much does the final output increase? The dial is the final output, so the increase will be 1 times the tiny amount.
Let me show the forward and backward calculations on a single diagram:
The exact details are less important than the idea that we can come up with model and loss calculations that involve a huge number of weights and all sorts of operations—multiplying, adding, raising to a power, taking a logarithm, replacing negative numbers with zero, clipping numbers, and there will still be an efficient (and organized!) way to find the partial derivatives with respect to the weights.
You might be objecting at this point. Why do the calculation backwards? Why not start with 1 on the left, multiply by 6, multiply by 243, and we’ll still get 1458. Isn’t that less convoluted and doesn’t it involve the same number of calculations? Yes. But that’s only because in this example we’re calculating a single derivative. To give you this intuition, let’s say I asked you to multiply two sets of numbers:
a) 32 × 93 × 38
b) 14 × 93 × 38
Since we’re talking multiplication, order doesn’t matter. If you multiply from left to right, you’ll need to do two multiplications for each of a) and b) for a total of four multiplications. If you start on the right, you can reuse 38 × 93 from a) for b). This will be a total of three multiplications. When we’re calculating billions of partial derivatives of a single loss number we end up in a similar situation. The multiplications “closer to” the single loss number are repeated over and over for each partial derivative and we can avoid repeating them by working backward.
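A tiny script makes the bookkeeping concrete: going left to right costs four multiplications, while computing the shared right-hand product once costs three.

```python
# a) 32 * 93 * 38 and b) 14 * 93 * 38, computed two ways.

# Left to right: two multiplications for each of a) and b), four in total.
a_left = (32 * 93) * 38
b_left = (14 * 93) * 38

# Right to left: one multiplication for the shared suffix 93 * 38,
# then one more for each product, three in total.
shared = 93 * 38
a_right = 32 * shared
b_right = 14 * shared
```

The saving looks trivial with two products, but with a billion partial derivatives all sharing the calculations closest to the loss, reusing those shared products is the whole game.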
If some of this is sounding familiar, especially the part above figure 4.4 where I explained the intuition for why we multiply 243 by 6, here’s why. In high school you may have learned about the chain rule to compute the derivative for functions that are chained together. It was likely explained like this: If you have a function f(g(x)), then the derivative of f with respect to x is the derivative of f with respect to its input evaluated at g(x) times the derivative of g with respect to x. If f(x) is x³ and g(x) is x² then this is exactly the example we’ve been working through. The derivative of f with respect to x according to the chain rule will be 3(g(x))²(2x) = 3(x²)²(2x) = (3)(x⁴)(2x) = 6x⁵. Plug in 3 for x and you get 1458, just like above.
Let’s do the forward and backward calculations for our turkey example in diagram form. Our goal will be to calculate loss in the forward calculation, and then calculate the partial derivatives of loss with respect to w1 and w2 in the backward calculation. There will be a lot of boxes and numbers. Don’t worry. Soon we’ll get to a cleaner way to organize the information and calculations.
For convenience, here’s the table with our turkey data (same as table 3.1):
First let’s draw all the boxes for the forward calculation.
There are a lot of boxes and arrows but mostly because we have to repeat everything for each of our three turkeys. Look at turkey #1. We take weight 1 (w1) and multiply it by the height of the turkey (1 meter), we take weight 2 (w2) and multiply it by the length of the turkey (1.5 meters), we add those two numbers together (that’s our prediction for number of feathers), we subtract 5000 (the actual number of feathers), that’s our error. Then we square the error, repeat for turkey #2 and turkey #3, and then add the three squared errors together. That’s our loss.
You might wonder why inside the boxes I refer to the input as “input” instead of the much shorter “x” as in figure 4.1. I could have used “x,” but “x” typically means data in AI (e.g. data about turkeys) and once you get used to that convention in future chapters, I don’t want you to come back to this diagram and wonder if I mean “x” the input to the box or “x” some data.
Here’s the forward calculation for w1=1000 and w2=3000, the same starting weights we used in chapter 3.
The loss, 1,312,500, is the same as we computed in chapter 3 (see table 3.3).
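Here is the forward pass for turkey #1 in code. I’m only using turkey #1’s numbers from the walkthrough above (height 1 m, length 1.5 m, 5000 feathers); turkeys #2 and #3 from table 3.1 go through the same boxes, and summing all three squared errors gives the loss of 1,312,500:

```python
w1, w2 = 1000, 3000                      # starting weights from chapter 3

height, length, actual = 1.0, 1.5, 5000  # turkey #1's data

prediction = w1 * height + w2 * length   # 1000 + 4500 = 5500 feathers
error = prediction - actual              # 5500 - 5000 = 500
squared_error = error ** 2               # 250,000, turkey #1's share of the loss
```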
Now the fun and confusing part: the backward calculation. It took me a long time to work out the diagram. You may want to try on your own first. Even though I haven’t yet explained all of the rules of the road, trying will force you to think about what calculations must happen.
Let’s set up the calculation first, understand it, then we’ll plug 1 in on the right and see what we get. As in figure 4.4, I italicized “input” to remind you and me that the input into each box is from the right, as shown by the arrows. (This gets confusing because, as you know from figure 4.5, the backward calculation in boxes also depends on the input they got from the left during the forward calculation.)
As a reminder, our goal in this whole calculation is to figure out if we increase w1 by a tiny amount, by what multiple of that tiny amount will the loss change. And the same question for w2.
I want to explain how I came up with a few of the boxes. It might again be helpful to think about this yourself before reading my explanations.
- Box A - we take the input (from the right), multiply it by 1, and send it down the first arrow. We do the same with the second and third arrows. The reason this makes sense is that in the forward pass this box calculated “input 1 + input 2 + input 3.” The partial derivative of the output with respect to input 1 is 1, and the same for inputs 2 and 3. Since each box multiplies by its derivative(s) to go backwards, we multiply input × 1 for each arrow. Or for the intuition, look at the top arrow bringing 250,000 into the box in the forward calculation shown in figure 4.7. If that 250,000 goes up by a small amount, the output of the box will go up by one times that small amount, and this would be true even if that 250,000 was a different number, because all the box does in the forward calculation is add its three inputs.
- Box B - during the forward pass this box did the calculation input². The derivative of f(input) = input² is 2 × input. The input during the forward pass was 500 so the value of the derivative is 2 × 500 = 1000. We multiply as we go backwards, so the backward calculation in the box is to multiply the input from the right by 1000.
- Box C - during the forward pass the box did the calculation input - 5000. The derivative of f(input) = input - 5000 is 1. The input to the box during the forward pass was 5500. However, just as with box A, this doesn’t matter. The derivative will be 1 regardless and so we multiply the input from the right by 1.
- Box D - during the forward pass the box did the calculation input × 1.5. The derivative of f(input) = (1.5)(input) is 1.5. The input to the box during the forward pass was 3000, but this is another instance where this doesn’t matter. The derivative is 1.5 so we multiply the input from the right by 1.5.
I didn’t explain what happens when multiple arrows come into a single box like w1. The rule is we add. To see the intuition, think about a small increase in w1. This will cause the loss to increase by a certain amount according to the first arrow coming into the box, by a certain amount according to the second arrow, and a certain amount according to the third. The total increase will therefore be the small increase in w1 times the sum of the amounts indicated by each arrow.
Now let’s plug 1 in on the far right and see what we get:
Look! The partial derivatives of loss with respect to w1 and w2 come out to 1875 and 3500, exactly as we calculated and showed in table 3.7 in chapter 3. Increase w1 a tiny bit, and the loss will increase by 1875 times that tiny bit. Increase w2 a tiny bit, and the loss will go up by 1500 times that tiny bit due to turkey #1, 2500 times that tiny bit due to turkey #2, and down by 500 times that tiny bit due to turkey #3, so in total 3500 times that tiny bit.
As a sort of sanity check, look at the number 1000 that I labeled F. You might be thinking that the number feels large. If I could force the (forward pass) output of box E to go up by a little bit, could that really cause an increase of 1000 times that little bit in the loss? Yes! Because in the forward pass, the output of E is 500 and, also in the forward pass, G takes that and squares it. The number 500 squared is 250,000. The number 501 squared is 251,001. The difference is around 1000. And since that’s added into the loss, it means the loss will go up by around 1000.
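Turkey #1’s branch of the backward calculation can also be written out as code, starting from the 1 we feed in on the far right. Each line multiplies the derivative arriving from the right by that box’s local derivative; turkeys #2 and #3 have identical branches, and adding all three branches where they meet gives the totals 1875 and 3500:

```python
w1, w2 = 1000, 3000
height, length, actual = 1.0, 1.5, 5000  # turkey #1's data

# Forward pass, saving the values the backward pass needs.
prediction = w1 * height + w2 * length   # 5500
error = prediction - actual              # 500

# Backward pass.
d_loss = 1                               # the 1 on the far right
d_error = d_loss * 2 * error             # squaring box: local derivative 2 * input -> 1000
d_prediction = d_error * 1               # subtract-5000 box: local derivative 1
d_w1 = d_prediction * height             # multiply-by-height box -> 1000
d_w2 = d_prediction * length             # multiply-by-length box -> 1500
```

The 1000 here is the same number labeled F in the sanity check above, and the 1500 is turkey #1’s contribution to the partial derivative for w2.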
You’ve now seen a slightly more realistic example of backpropagation. We did a forward pass to calculate loss. And then in the backward pass we propagated the derivative (soon we’ll call this a gradient) back through the calculation to the weights.
You may also have heard of training a model. We’ve now seen the full idea of how a training loop works. Start with some weights. Do a forward calculation to compute loss. Do a backward calculation to get the partial derivatives of the loss for each weight. Update the weights a little. And repeat. It’s what we did with our two weights and three turkeys, and it’s also what’s used to train the models underlying ChatGPT with trillions of weights and tens of trillions of “turkeys.” Later I’ll get into what the “turkeys” are for these models.
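To close the chapter, here is that whole training recipe in miniature. The turkey numbers below are made up for illustration (only the first row matches turkey #1; table 3.1 has the real data) and the learning rate is an arbitrary small choice, but the loop itself is exactly the recipe just described: forward pass, backward pass, small weight update, repeat.

```python
# (height, length, feathers): the first row is turkey #1; the other
# two rows are stand-ins invented for this sketch.
data = [(1.0, 1.5, 5000), (2.0, 1.0, 7000), (1.5, 2.0, 8000)]

w1, w2 = 1000.0, 3000.0
learning_rate = 0.01
history = []                                         # loss at each step

for step in range(200):
    loss = 0.0
    d_w1 = d_w2 = 0.0
    for height, length, actual in data:
        error = w1 * height + w2 * length - actual   # forward pass
        loss += error ** 2
        d_w1 += 2 * error * height                   # backward pass, with the
        d_w2 += 2 * error * length                   # branches summed as they meet
    history.append(loss)
    w1 -= learning_rate * d_w1                       # nudge the weights downhill
    w2 -= learning_rate * d_w2
```

Each trip through the outer loop is one row of a table like 3.11: compute the loss, compute the partial derivatives, take a small step, and watch the loss in `history` shrink.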