3
A basic model: predicting feathers
When I was two weeks into playing with Nanochat, which we’re going to get to, we were out for a family dinner—Otto in Brookline, Massachusetts—and I wanted to explain to my high school sophomore daughter the idea of learning weights. We went back and forth over an example where we could keep the numbers in our heads while pizza was in our hands. I’m going to recreate the example here with a little more fake data since I can now write things down on this piece of screen.
In chapter 2 we talked about models for predicting words. Now we’re going to build a model to predict the number of feathers on a turkey. Don’t ask. That’s just what came to mind when we were chatting. There are a lot of wild turkeys wandering around Newton, where we live, and this wasn’t all that long before Thanksgiving.
But more seriously: you’re going to encounter models for predicting turkey feathers, hedgehog quills, chicken feathers, and home prices in the coming chapters. You’ll have to trust me that I’m not leading you down needless tangents. I need to build up to the concepts underlying GPT models, and we’ll both find it easier if I use specific, contrived examples instead of keeping everything abstract.
Think of this model as a machine. It takes two inputs: the height of the turkey and the length of the turkey. It gives one output: its prediction for the number of feathers.
Without thinking too hard, it seems like we could multiply the height by some number, multiply the length by some other number, add those together, and that will be our prediction. Let’s call these numbers weights, as in the weights of the model.
Prediction = height × weight_1 + length × weight_2
Going forward, to make it a little easier to read, let’s write weight_1 as w1 and weight_2 as w2, so Prediction = height × w1 + length × w2.
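Let’s also say we have data on three turkeys, meaning we know their height, length, and number of feathers:
- Turkey #1: height 1, length 1.5, 5000 feathers
- Turkey #2: height 0.75, length 1.25, 3500 feathers
- Turkey #3: height 1.25, length 1, 4500 feathers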
If our model is any good, then:
- 1 × w1 + 1.5 × w2 should be close to 5000
- 0.75 × w1 + 1.25 × w2 should be close to 3500
- 1.25 × w1 + 1 × w2 should be close to 4500
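The prediction machine is small enough to write down directly. Here is a minimal Python sketch; the names `turkeys` and `predict` are hypothetical, chosen for illustration, and the numbers are the heights, lengths, and feather counts from the three expressions above.

```python
# Hypothetical layout: one (height, length, actual feathers) tuple per turkey.
turkeys = [
    (1.0, 1.5, 5000),    # turkey #1
    (0.75, 1.25, 3500),  # turkey #2
    (1.25, 1.0, 4500),   # turkey #3
]

def predict(height, length, w1, w2):
    """The prediction machine: height × w1 + length × w2."""
    return height * w1 + length * w2

# Try a pair of weights and compare each prediction with what it should be close to.
w1, w2 = 1000, 1000  # arbitrary example weights
for number, (height, length, feathers) in enumerate(turkeys, start=1):
    print(f"turkey #{number}: prediction {predict(height, length, w1, w2):g}, "
          f"should be close to {feathers}")
```

Changing `w1` and `w2` changes every prediction at once, which is why we’ll soon want a single number that scores all three turkeys together.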
People figured out the math for picking the best possible w1 and w2 hundreds of years ago. However, we’re not going to worry about that. Instead we’ll purposely go about it in a way that will seem overcomplicated and expensive to compute because later on this technique will prove magical.
So what I said to my daughter over pizza was: forget about finding the ideal w1 and w2 for the moment; just pick some random numbers. Mia picked 1000 and 3000. Great, now calculate the prediction for turkey #1:
1 × 1000 + 1.5 × 3000 = 5500
Wow! Not too bad. Can we quantify how good or bad it is? Mia said to subtract the actual number of feathers, so 5500 - 5000 = 500. Let’s do that for all three turkeys.
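Here is Mia’s subtraction for all three turkeys as a short Python sketch (the variable names are mine, not from any library):

```python
# (height, length, actual feathers) for each turkey
turkeys = [(1.0, 1.5, 5000), (0.75, 1.25, 3500), (1.25, 1.0, 4500)]
w1, w2 = 1000, 3000  # Mia's random guesses

# prediction - actual for each turkey
errors = [height * w1 + length * w2 - feathers for height, length, feathers in turkeys]
print(errors)  # [500.0, 1000.0, -250.0]
```

A positive number means the prediction was too high; a negative one means it was too low.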
Now we have a number that tells us how good or bad our prediction is for turkey #1, another number for turkey #2, and a third for turkey #3. I pointed out that if we pick new weights that improve the prediction for turkey #1 but make it much worse for turkey #2, that wouldn’t be so great. I asked how we could come up with a single number that said how good our prediction was across all three turkeys.
She said to add the numbers together. If we nail all three predictions, that number will be zero. But there’s a problem: if one prediction is way off in one direction and another is way off in the other, we’ll add a big negative and a big positive and could get something close to zero, even though neither prediction is good. Take the absolute value, she said. Yes! That way we treat -250 and 250 as the same thing, both 250 feathers away from the correct prediction.
At this point I stepped in and said, you know what, let’s square instead of taking the absolute value. Like the absolute value, squaring turns negative numbers positive, and it also has two other advantages that aren’t that important right now.
Now we’re talking about prediction minus actual for each turkey, the square of that, and the sum of those numbers. It’s going to get hard to keep having this discussion unless we give names to some of these numbers. I’ll do that in a table first and then explain.
The column labeled “prediction - actual” from table 3.2 is now called error. That makes sense because it is the error in our prediction. For turkey #1 the prediction is 500 feathers too high, for turkey #2 it’s 1000 feathers too high, and for #3 we’re 250 short.
Squared error is what the name says, the square of the error. For example, 500² = 500 × 500 = 250,000.
Loss is the sum of the squared errors. Think of it as how much we’re losing in our predictions: the lower, the better. A loss of zero would mean we nail every prediction. Look at the table and see why that’s true: no squared error can be negative, so the only way the loss can be zero is if each prediction perfectly matches the actual number of feathers.
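Putting error, squared error, and loss together, a minimal Python sketch of the loss looks like this. The function name `loss` is my own label for illustration.

```python
# (height, length, actual feathers) for each turkey
turkeys = [(1.0, 1.5, 5000), (0.75, 1.25, 3500), (1.25, 1.0, 4500)]

def loss(w1, w2):
    """Sum of squared errors over all three turkeys."""
    total = 0.0
    for height, length, feathers in turkeys:
        error = (height * w1 + length * w2) - feathers  # prediction - actual
        total += error ** 2                             # squared error
    return total

print(loss(1000, 3000))  # 1312500.0
```

Two weights and three turkeys go in; one number comes out.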
This single loss number is a beautiful thing. Even though we’ve got these two weights we’re trying to figure out, and this set of data with three turkeys, each of which has a height and a length and a number of feathers, we’ve now boiled the whole thing down to a single loss number. If we tweak our weights and the loss goes down, that’s good. If it goes up we’re moving in the wrong direction.
Speaking of which, I told Mia to change one of the weights so we could see what happens to the loss. She increased w1 from 1000 to 2000:
And now calculate our predictions, errors, and loss with these new weights:
The loss went from around 1.3 million (1,312,500) to around 6.3 million (6,312,500), so for sure these new weights made our prediction machine worse.
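The same comparison in code, again as a sketch with a hypothetical `loss` helper: keep w2 fixed at 3000 and compare the loss before and after Mia’s change to w1.

```python
# (height, length, actual feathers) for each turkey
turkeys = [(1.0, 1.5, 5000), (0.75, 1.25, 3500), (1.25, 1.0, 4500)]

def loss(w1, w2):
    # Sum of squared errors, where error = prediction - actual.
    return sum((height * w1 + length * w2 - feathers) ** 2
               for height, length, feathers in turkeys)

before = loss(1000, 3000)
after = loss(2000, 3000)
print(before, after)                            # 1312500.0 6312500.0
print("worse" if after > before else "better")  # worse
```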
We could guess again, but it sure would be nice to have a clue about which direction to move each weight in. If we increase w1 a little bit, will the loss go up (bad) or down (good)? What about w2? That type of question might sound familiar from high school calculus. You were shown an equation like this:
f(x) = x²
And then asked a question like: if x is 3 and it increases by a tiny amount, what happens to f(x)? You figured that out by taking the derivative of x², which is 2x. Don’t worry if you don’t remember the rule. The important thing is to have an intuition for what it means. The derivative of f(x) is 2x, so at x = 3 the derivative is 6. This means if x increases by a tiny amount, f(x) will increase by about 6 times that amount. Let’s look at this with actual numbers:
f(3) = 3² = 9
Let’s make our tiny amount 0.01. If we increase 3 by this tiny amount, we expect f(x) to increase by 6 × 0.01 = 0.06, so f(3.01) should be around 9.06.
f(3.01) = (3.01)² = 9.0601
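You can also check this numerically. This Python sketch compares the derivative’s estimate with the actual value of f(3.01); the function names are just for illustration.

```python
def f(x):
    return x ** 2

def derivative_of_f(x):
    return 2 * x  # the derivative of x² is 2x

x, tiny = 3, 0.01
estimate = f(x) + derivative_of_f(x) * tiny  # 9 + 6 × 0.01
actual = f(x + tiny)
print(f"estimate {estimate:.4f}, actual {actual:.4f}")  # estimate 9.0600, actual 9.0601

# The gap between actual and estimate shrinks as the tiny amount shrinks
# (for x² it is exactly tiny², up to floating-point rounding).
for small in (0.1, 0.01, 0.001):
    print(small, f(x + small) - (f(x) + derivative_of_f(x) * small))
```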
To really jog your memory, in case this isn’t familiar, you can also think of the derivative as giving us the slope of the line tangent to the function at that point:
Okay, now forget about x and f(x) and let’s get back to figuring out how to tweak w1 and w2 to improve our turkey feather model. Since we want to take the derivative, let’s write out the loss in terms of w1 and w2 instead of plugging in the weights we guessed at above.
I left out the squared error to keep the table legible. Now let’s write out the whole equation.
loss = (turkey #1 error)² + (turkey #2 error)² + (turkey #3 error)²
loss = ((1)(w1) + (1.5)(w2) - 5000)² +
((0.75)(w1) + (1.25)(w2) - 3500)² +
((1.25)(w1) + (1)(w2) - 4500)²
You can multiply that all out by hand and combine like terms, or use an online calculator. Either way you’ll get this:
loss = 3.125(w1)² + 7.375(w1)(w2) - 26500(w1) + 4.8125(w2)² - 32750(w2) + 57500000
Want some convincing we didn’t mess anything up? Plug in our original weights (1000 and 3000) and it should come out to 1,312,500. (I checked and it does.)
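If you’d rather let the computer do the checking, this sketch evaluates the sum of squared errors and the expanded polynomial at a few weight pairs and confirms they agree. The function names and test points are arbitrary choices of mine.

```python
import math

def loss_from_errors(w1, w2):
    """Loss written as the sum of the three squared errors."""
    return ((1 * w1 + 1.5 * w2 - 5000) ** 2
            + (0.75 * w1 + 1.25 * w2 - 3500) ** 2
            + (1.25 * w1 + 1 * w2 - 4500) ** 2)

def loss_expanded(w1, w2):
    """Loss after multiplying out and combining like terms."""
    return (3.125 * w1 ** 2 + 7.375 * w1 * w2 - 26500 * w1
            + 4.8125 * w2 ** 2 - 32750 * w2 + 57500000)

# Both forms should give the same loss for any weights.
for w1, w2 in [(1000, 3000), (2000, 3000), (0, 0), (2347, 1604)]:
    assert math.isclose(loss_from_errors(w1, w2), loss_expanded(w1, w2), rel_tol=1e-9)

print(loss_expanded(1000, 3000))  # 1312500.0
```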
In our calculus example above, we calculated the derivative of f(x) with respect to x. That told us how a change in x affected f(x). Here we want the derivative of the loss, but with respect to what? w1? Or w2? We already know we care about both of them since we want to adjust both to get to the best model we can.
Here’s where we cross into territory that as far as I can remember we never learned in high school calculus, although when you get into it, you’ll see it’s less complicated than other high school calculus concepts.
We want to calculate the partial derivative of loss with respect to w1. (And we’ll also want to calculate the partial derivative of loss with respect to w2.) Think of it this way: if we hold w2 steady at, say, 3000 and increase w1 a tiny bit, how much will the loss increase or decrease? We can calculate this just like a regular derivative if we treat w1 as our variable and w2 as a constant.
Partial derivative of loss with respect to w1 = 6.25(w1) + 7.375(w2) - 26500
Partial derivative of loss with respect to w2 = 7.375(w1) + 9.625(w2) - 32750
(You can either get these using an online calculator or by remembering a few rules from calculus.)
Evaluate the partial derivatives at w1=1000 and w2=3000:
Partial derivative of loss with respect to w1 = 6.25(1000) + 7.375(3000) - 26500 = 1875
Partial derivative of loss with respect to w2 = 7.375(1000) + 9.625(3000) - 32750 = 3500
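Here’s a sketch that evaluates both partial derivatives at w1=1000 and w2=3000 and then sanity-checks them with a finite difference: nudge one weight by a tiny amount while holding the other fixed, and see how much the loss moves. The finite-difference check is my addition, not part of the calculation above, and the function names are hypothetical.

```python
def loss(w1, w2):
    return (3.125 * w1 ** 2 + 7.375 * w1 * w2 - 26500 * w1
            + 4.8125 * w2 ** 2 - 32750 * w2 + 57500000)

def dloss_dw1(w1, w2):
    return 6.25 * w1 + 7.375 * w2 - 26500   # w2 treated as a constant

def dloss_dw2(w1, w2):
    return 7.375 * w1 + 9.625 * w2 - 32750  # w1 treated as a constant

w1, w2, tiny = 1000, 3000, 0.01
print(dloss_dw1(w1, w2), dloss_dw2(w1, w2))  # 1875.0 3500.0

# Finite differences: change in loss divided by the tiny nudge.
print((loss(w1 + tiny, w2) - loss(w1, w2)) / tiny)  # close to 1875
print((loss(w1, w2 + tiny) - loss(w1, w2)) / tiny)  # close to 3500
```

The finite-difference numbers come out slightly above 1875 and 3500 because the nudge isn’t infinitely small, just as f(3.01) came out a little above 9.06.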
We’re starting to deal with a lot of numbers again. Let’s stick them in a table so we don’t get confused.
If we increase w1 a little bit, loss will go up by around 1875 times that little bit, so we should decrease w1. If we increase w2 a little bit, loss will go up by around 3500 times that little bit, so we should also decrease w2. Now we know in which direction to adjust w1 and w2, but not how big an adjustment to make. If we adjust by a very small amount, the loss will barely change. If we adjust by a huge amount, we might overshoot the optimal value. We’ll get into this more later, but for now, assume we have something called a learning rate, which we’ll set to 1% (0.01). We use the learning rate by adjusting each weight by its partial derivative times the learning rate. This way we’ll take larger jumps when w1 and w2 are far away from optimal, and smaller jumps when we’re closer. Again, we’ll come back to this.
Does adjust mean add or subtract? If the derivative is positive we need to decrease the weight, so we should adjust by subtracting the partial derivative times the learning rate. Let’s do this:
Updated w1 = 1000 - (0.01)(1875) = 981.25
Updated w2 = 3000 - (0.01)(3500) = 2965
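In code, one update step is just a couple of lines. This is a minimal sketch assuming the partial-derivative formulas above; the names `learning_rate`, `g1`, and `g2` are mine. Note that both partial derivatives are computed at the current weights before either weight changes.

```python
def dloss_dw1(w1, w2):
    return 6.25 * w1 + 7.375 * w2 - 26500

def dloss_dw2(w1, w2):
    return 7.375 * w1 + 9.625 * w2 - 32750

learning_rate = 0.01
w1, w2 = 1000, 3000

# Evaluate both partial derivatives at the current weights first...
g1, g2 = dloss_dw1(w1, w2), dloss_dw2(w1, w2)
# ...then subtract partial derivative × learning rate from each weight.
w1 = w1 - learning_rate * g1
w2 = w2 - learning_rate * g2
print(w1, w2)  # 981.25 2965.0
```

Feeding the new w1 and w2 back into the same lines gives the second step, and so on.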
Calculate the loss using those weights and add that all to the table:
Hooray! Loss came down. Plug the new w1 and w2 into our partial derivative equations above and we can add the partial derivatives at these new weights to the table too:
The partial derivatives are both still positive, so we want to decrease both w1 and w2 again. I also added a step column. Let’s update the weights again exactly as we did above:
Our loss is under a million! Let’s do a thousand steps. I’ll only show some of them in the table.
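Here is one way to run the thousand steps, as a sketch built on the loss and partial-derivative formulas above. The `gradient` helper (which returns both partial derivatives) and the choice of which steps to print are mine; the starting weights and learning rate are the ones from the text.

```python
def loss(w1, w2):
    return (3.125 * w1 ** 2 + 7.375 * w1 * w2 - 26500 * w1
            + 4.8125 * w2 ** 2 - 32750 * w2 + 57500000)

def gradient(w1, w2):
    """Both partial derivatives of the loss at (w1, w2)."""
    return (6.25 * w1 + 7.375 * w2 - 26500,
            7.375 * w1 + 9.625 * w2 - 32750)

learning_rate = 0.01
w1, w2 = 1000.0, 3000.0

for step in range(1, 1001):
    g1, g2 = gradient(w1, w2)
    w1, w2 = w1 - learning_rate * g1, w2 - learning_rate * g2
    if step in (1, 2, 3, 10, 100, 1000):
        print(f"step {step:4d}: w1={w1:7.1f}  w2={w2:7.1f}  loss={loss(w1, w2):,.0f}")
```

By my arithmetic, the last line should land close to the w1=2311, w2=1633, and loss of about 133,190 that follow. The loss falls quickly at first; by the end each step changes it only slightly, because the partial derivatives, and therefore the jumps, have shrunk.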
We might have hoped for a loss of zero, but with three turkeys and only two weights, that usually isn’t possible: we’d need one pair of weights that makes all three predictions exactly right. Still, 133,190 is a lot better than 1,312,500. Let’s make predictions with w1=2311 and w2=1633:
We’re only off by 22 feathers on turkey #3. For turkey #1 we’re short 240 feathers, and for turkey #2, we’re over by 275 feathers. Much better than where we started.
You may be itching to know how close we got to the optimal answer. In the future we’ll work with models where that’s tricky or impossible to figure out. Here, though, it’s easy to compute the w1 and w2 that minimize loss. As I mentioned above, people figured out how to solve these problems a long time ago, without doing 1000 steps of computation. Set both partial derivatives above to 0 and you’ll have two equations in two unknowns. Solve for w1 and w2 and you’ll get (rounded to the nearest whole number):
w1 = 2347
w2 = 1604
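For a system this small, Cramer’s rule is one standard way to solve it. That choice of method is mine; any way of solving two linear equations in two unknowns gives the same answer. The sketch sets both partial derivatives to zero and solves for w1 and w2.

```python
# Setting both partial derivatives to zero gives:
#   6.25  * w1 + 7.375 * w2 = 26500
#   7.375 * w1 + 9.625 * w2 = 32750
a, b, e = 6.25, 7.375, 26500
c, d, f = 7.375, 9.625, 32750

det = a * d - b * c          # determinant of the 2×2 system (must be nonzero)
w1 = (e * d - b * f) / det   # Cramer's rule
w2 = (a * f - e * c) / det
print(round(w1), round(w2))  # 2347 1604

# At the solution, both partial derivatives should be (essentially) zero.
print(6.25 * w1 + 7.375 * w2 - 26500, 7.375 * w1 + 9.625 * w2 - 32750)
```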
We got close. Let’s plug those weights in:
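You also may be wondering if we could have gotten even closer to the optimal weights with a smaller learning rate and/or more steps. The answer is yes, as long as we take enough steps. In this example, a smaller learning rate on its own, with the same 1000 steps, would actually leave us farther away, because each step would be shorter.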
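A quick way to explore this is to rerun the loop with different settings and measure how far the final weights land from the exact solution. This sketch is mine, with a hypothetical `run` helper, and the settings are arbitrary. It includes one run with a smaller learning rate but the same 1000 steps; in this example that run ends up farther away, so the smaller rate pays off only when paired with more steps.

```python
import math

def gradient(w1, w2):
    """Both partial derivatives of the turkey loss at (w1, w2)."""
    return (6.25 * w1 + 7.375 * w2 - 26500,
            7.375 * w1 + 9.625 * w2 - 32750)

# Exact minimizer from setting both partial derivatives to zero (Cramer's rule).
det = 6.25 * 9.625 - 7.375 * 7.375
W1_BEST = (26500 * 9.625 - 7.375 * 32750) / det
W2_BEST = (6.25 * 32750 - 26500 * 7.375) / det

def run(learning_rate, steps, w1=1000.0, w2=3000.0):
    """Gradient descent from the starting weights used in the text."""
    for _ in range(steps):
        g1, g2 = gradient(w1, w2)
        w1, w2 = w1 - learning_rate * g1, w2 - learning_rate * g2
    return w1, w2

distances = {}
for learning_rate, steps in [(0.01, 1000), (0.01, 5000), (0.005, 1000), (0.005, 10000)]:
    w1, w2 = run(learning_rate, steps)
    distances[(learning_rate, steps)] = math.hypot(w1 - W1_BEST, w2 - W2_BEST)
    print(f"learning rate {learning_rate}, {steps:5d} steps: w1={w1:.2f}, w2={w2:.2f}, "
          f"distance from best {distances[(learning_rate, steps)]:.4f}")
```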