19
Now is it time to train? First: optimizers
I’ve been yapping for so long you may have forgotten that we still haven’t trained our 32-layer model. The reason I introduced the CORE metric first is that we’re going to want to use it during training. We’re about to embark on a long and expensive undertaking. For our model we’re talking about a few days and around $1,000 in GPU costs, assuming Karpathy’s estimates are correct. For more serious models this could be weeks and tens of millions of dollars, or much more. It would be irresponsible, or at least anxiety-producing, to train for so long with loss as our only way to judge whether the training is working.
What we want to do is track a handful of indicators along the way. One is validation loss, of course, and another is our CORE metric. Every so often during training we’ll pause, measure CORE (on a subset of the 90,000+ items to save time), and keep track. If we’re 20% of the way into training and the model isn’t getting smarter, we know something is wrong. And oh, so many things can go wrong.
Which brings us to how we actually update the parameters during each training step. In chapter 3 I made this sound simple: calculate loss, do backprop, multiply gradient by learning rate, subtract the result from the parameters. This is the fundamental, beautiful idea of gradient descent.
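Here is the update step from that recipe as a minimal Python sketch. It assumes the parameters and their gradients are parallel lists of plain floats (a stand-in for the tensors a real framework uses), and the name `sgd_step` is made up for illustration.

```python
def sgd_step(params, grads, lr):
    """Return updated parameters: each one moves against its own gradient.

    params and grads are parallel lists of floats; lr is the learning rate.
    """
    return [p - lr * g for p, g in zip(params, grads)]


# Two parameters, their gradients, and a learning rate of 0.1.
print(sgd_step([0.5, -1.0], [2.0, -0.5], lr=0.1))
```

Each parameter moves opposite to its gradient, by an amount proportional to the learning rate. A positive gradient pushes the parameter down and a negative gradient pushes it up.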
Let me show an example using a model even simpler than our turkey feather and hedgehog quill predictors. Now we’re going to predict chicken feathers.
I have 1000 chickens. Here is a plot of length vs number of feathers for each chicken:
To keep the numbers simple, I subtracted the average length of 55 cm from each length so that the lengths are centered at zero. I did the same with the number of feathers. For example, the shortest chicken is around 52 cm long and has around 8,400 feathers. None of that actually matters for the point I’m going to make. Just keep it in mind if, like me, you find it helpful to think of data as mapping to real things. I’ll also tell you right now (it’s not a secret between us, only between us and the model) that I faked the number of feathers as 0.17 times the (centered) chicken length plus some noise.
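If you’d like to play along, here is one way to fake data in the same spirit. This is a sketch, not the script behind the plot above. The 55 cm average and the 0.17 slope come from the text. The spread of the lengths (a standard deviation of 1 cm), the noise level (0.05), the seed, and the name `make_chickens` are assumptions chosen for illustration. The function produces the centered feather column directly rather than modeling raw feather counts.

```python
import random


def make_chickens(n=1000, slope=0.17, length_sd=1.0, noise_sd=0.05, seed=0):
    """Hypothetical stand-in for the chicken data: centered lengths and
    centered feather counts related by `slope` plus Gaussian noise."""
    rng = random.Random(seed)
    lengths = [55.0 + rng.gauss(0.0, length_sd) for _ in range(n)]  # raw cm
    mean_len = sum(lengths) / n
    xs = [length - mean_len for length in lengths]  # center lengths at zero
    ys = [slope * x + rng.gauss(0.0, noise_sd) for x in xs]
    mean_y = sum(ys) / n
    ys = [y - mean_y for y in ys]  # center the feather column at zero too
    return xs, ys


xs, ys = make_chickens()
# The least-squares slope through the origin should land close to 0.17.
print(sum(x * y for x, y in zip(xs, ys)) / sum(x * x for x in xs))
```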
For the turkey prediction we built a model with two inputs and two weights. For the alien hedgehog prediction we had one input and many weights, which we needed to match the odd relationship between length and number of quills. For chickens we’re going even simpler: a single input and a single weight.
Our goal during training is to get that weight to end up at 0.17. We’ll calculate loss the same way as we did in the hedgehog example: mean squared error. Look at chapter 10 if you want a reminder.
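For reference, here is that loss written out for the one-weight model. The prediction for chicken $i$ is $w x_i$, where $x_i$ is its centered length and $y_i$ its centered feather count, and $N = 1000$. This assumes chapter 10’s mean squared error is the plain average of squared errors with no extra factor of 1/2. The derivative is the quantity gradient descent follows, and the factor of 2 comes from differentiating the square.

```latex
L(w) = \frac{1}{N}\sum_{i=1}^{N} \left(w x_i - y_i\right)^2,
\qquad
\frac{dL}{dw} = \frac{2}{N}\sum_{i=1}^{N} x_i \left(w x_i - y_i\right)
```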
One nice thing about having a single weight in the model is we can easily come up with a range of weights and calculate the loss for each. For example, with our chickens, a weight of -2.0 will result in a loss of 4.8. A weight of -1.9 will result in a loss of 4.5. I’ll keep going, in steps of 0.1, up to a weight of 2.0 and plot the results:
You can see that the loss is lowest when the weight is around 0.17, which is expected since that’s the number I used to fake the data.
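The loss sweep is easy to reproduce. The sketch below generates its own stand-in data with the same recipe (lengths with a standard deviation of 1, noise with a standard deviation of 0.05; both are assumptions). Its loss values won’t exactly match the ones quoted above, but the bowl shape and the location of the minimum should.

```python
import random

# Hypothetical stand-in data (not the author's): y = 0.17 * x + noise.
rng = random.Random(0)
xs = [rng.gauss(0.0, 1.0) for _ in range(1000)]
ys = [0.17 * x + rng.gauss(0.0, 0.05) for x in xs]


def mse(w, xs, ys):
    """Mean squared error of the one-weight model y_hat = w * x."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


# Sweep the weight from -2.0 to 2.0 in steps of 0.1.
weights = [round(-2.0 + 0.1 * i, 1) for i in range(41)]
losses = [mse(w, xs, ys) for w in weights]
best = weights[losses.index(min(losses))]
print(best)
```

With a grid spacing of 0.1, the best grid point is the one nearest 0.17 rather than 0.17 itself. A finer grid would get closer.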
During training, we start with our initial parameters (a single weight in this case), calculate the gradient, update the parameters, calculate the gradient again, update the parameters again, and so on. Let’s say we start with an initial weight of 1. The loss works out to 0.7. That’s a point we could plot on the graph above. An advantage of keeping things to a single weight is that we can watch gradient descent in action. Our first point might be at (1, 0.7), our next point could be (0.8, 0.4), and so on, until we end up where we want with a weight of 0.17.
Let’s try it. We’ll take a few steps with a learning rate of 1.0, starting with a weight of 1. You’ll have to trust that I calculated the loss and gradient correctly, because I haven’t given you the chicken-by-chicken numbers behind figure 19.2.
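That loop of computing the gradient and updating the weight can be sketched in a few lines of Python. The stand-in data and the names `gradient` and `gradient_descent` are illustrative assumptions; the gradient is the derivative of mean squared error for the model `y_hat = w * x`.

```python
import random

# Hypothetical stand-in data (not the author's): y = 0.17 * x + noise.
rng = random.Random(0)
xs = [rng.gauss(0.0, 1.0) for _ in range(1000)]
ys = [0.17 * x + rng.gauss(0.0, 0.05) for x in xs]


def gradient(w, xs, ys):
    """dL/dw for the mean squared error of y_hat = w * x."""
    return 2.0 * sum(x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


def gradient_descent(xs, ys, w0, lr, steps):
    """Run plain gradient descent and return the weight after every step."""
    history = [w0]
    w = w0
    for _ in range(steps):
        w = w - lr * gradient(w, xs, ys)  # step against the gradient
        history.append(w)
    return history


history = gradient_descent(xs, ys, w0=1.0, lr=0.1, steps=30)
print([round(w, 3) for w in history[:5]], round(history[-1], 3))
```

Keeping the whole history, not just the final weight, is what makes the step-by-step plots possible.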
Recall from chapter 3 that the bigger the learning rate, the bigger step we take, because we update the weight by subtracting the learning rate times the gradient. With a learning rate of 1.0, we stepped from a weight of 1.0 all the way to a weight of -0.70. Let me plot the weights at these four steps:
This doesn’t look promising. Our weights jump from 1 to -0.7 to 1.1 to -0.83. It feels like we’re moving away from our goal. Maybe it will just take more steps to correct itself. After all, the gradient should always be telling us which way to move. Let’s run 20 steps:
Now it’s clear that each step is taking us further away from our goal and that we’re never going to converge on 0.17. Let me also plot the weight and gradient at each step, as in table 19.1 above but for all 20 steps:
At first the gradient is around 1.7, which makes sense since the slope at red point 1 is positive. Then we jump to a weight of -0.7, shown as red point 2. The gradient here is -1.8, which also makes sense because the slope at red point 2 is negative. Now we multiply our learning rate of 1 by -1.8 and subtract the result from -0.7, which brings us to red point 3 with a weight of 1.11, where the gradient is an even larger positive number than at point 1. This is a vicious cycle that will not stop.
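The first two updates can be written out with learning rate $\eta$ and gradient $g_t$ at weight $w_t$, using the rounded gradients quoted above. That rounding is why the second result is 1.1 rather than 1.11. For mean squared error with one weight, the gradient also has a simple closed form. Here $x_i$ and $y_i$ are the centered lengths and feather counts, $v$ is the mean squared length, and $w^*$ is the loss-minimizing weight. Substituting it into the update shows why the cycle feeds on itself:

```latex
\begin{aligned}
w_1 &= w_0 - \eta\, g_0 = 1.0 - 1.0 \times 1.7 = -0.7 \\
w_2 &= w_1 - \eta\, g_1 = -0.7 - 1.0 \times (-1.8) = 1.1 \\[4pt]
g_t &= \frac{2}{N}\sum_{i=1}^{N} x_i \left(w_t x_i - y_i\right) = 2v\,(w_t - w^*),
  \quad v = \frac{1}{N}\sum_{i=1}^{N} x_i^2,
  \quad w^* = \frac{\sum_i x_i y_i}{\sum_i x_i^2} \\[4pt]
w_{t+1} - w^* &= (w_t - w^*) - 2\eta v\,(w_t - w^*) = (1 - 2\eta v)\,(w_t - w^*)
\end{aligned}
```

Each step multiplies the distance to $w^*$ by $1 - 2\eta v$, so plain gradient descent converges only when $|1 - 2\eta v| < 1$, that is, $0 < \eta < 1/v$. A gradient of 1.7 at a distance of $1.0 - 0.17 = 0.83$ implies $2v \approx 1.7 / 0.83 \approx 2.05$. So $v$ is roughly 1, and a learning rate of 1.0 sits just past the limit. The factor is a little below $-1$, flipping the sign of the error every step while its size slowly grows.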
So the gradient is telling us the right direction to move in, and giving us a clue about how far to move, but we’re moving so far that we overshoot the weight that minimizes the loss. Let’s pick a different learning rate: say 0.001, which is the default for PyTorch’s built-in stochastic gradient descent optimizer, torch.optim.SGD. Let’s again do four steps.
The good news is we didn’t jump over the goal. The bad news is we barely moved at all. The points are so close that they look like a single point and the labels (1, 2, etc.) all fall on top of each other. Let’s do 100 steps.
We’re getting there, but very slowly. Let’s try a learning rate of 0.1 and do 30 steps.
This seems reasonable. It’s hard to see what’s happening with all the dots bunched up near the minimum, but if we plot the weight at each step we can see that by around step 22 we’re close to a weight of 0.17.
Are we wasting steps? Should we use an even higher learning rate, something greater than 0.1 but less than the 1 we started with? Let’s try 0.5.
Look! We got there in around three steps. However, you might guess that there was luck involved, and that sharp turn in the right plot is a red flag. Still, let’s try an even bigger learning rate. Maybe we can get there in a single leap. Here’s a learning rate of 0.9:
Hmm. It looks like it will converge, but even after 30 steps it’s not quite at 0.17. You can see what’s happening: we keep jumping over the minimum, but unlike when the learning rate was 1, it’s not fatal.
Let’s try a learning rate of 0.97 and do 100 steps.
This is really bad. If we explode out of control quickly, at least we’ll know. Here we might not even realize we’re in trouble.
I share all of this because if it’s this hard to get the simplest possible model, with a single parameter, to train, how are we possibly supposed to train a model with billions of parameters? One major advantage we have here is that we can see what’s going on. This would also be true with two weights, where we could visualize the loss as a surface in three dimensions. We can’t visualize a billion dimensions, and that only scratches the surface of our problems. What if we need different learning rates for different parts of the model? Why should we assume that gradients will have similar orders of magnitude everywhere? What if we want to start with one learning rate and decrease it later? How do we even pick a starting learning rate? The default rate of 0.001 wasn’t good even for this simplest of all possible models.
Researchers began bumping up against these problems in a big way around 2010. Computers became powerful enough to train deeper models with more parameters, and there was evidence that these models were useful and could beat simpler methods, but it also became clear that training them was hard. Stochastic gradient descent wasn’t quite the miracle worker I made it out to be in chapter 3. A little bit of hunting for the right starting learning rate, a little bit of extra babysitting of the rate along the way, and a little bit of saving the parameters every so many steps and restarting from the last known good point worked up to a point, but couldn’t scale. So researchers began searching for ways to make training more robust.
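All of these runs can be compared side by side without any data. For a one-weight mean squared error, the loss is a parabola in the weight, so its gradient is exactly `2 * v * (w - w_star)`. Here `v` is the mean of the squared centered lengths and `w_star` is the weight that minimizes the loss. The sketch assumes `v = 1.03` and `w_star = 0.17`. The value 1.03 is backed out from the gradients of about 1.7 and -1.8 quoted earlier; it is an inference, not a measurement of the actual chicken data. The step counts mirror the runs above.

```python
# Assumed summary of the chicken data: the one-weight MSE gradient is exactly
# 2 * V * (w - W_STAR). V = 1.03 is inferred from the quoted gradients.
V = 1.03
W_STAR = 0.17


def run(lr, steps, w0=1.0):
    """Plain gradient descent on the parabola; returns the final weight."""
    w = w0
    for _ in range(steps):
        w -= lr * 2.0 * V * (w - W_STAR)  # subtract learning rate times gradient
    return w


for lr, steps in [(0.001, 100), (0.1, 30), (0.5, 3), (0.9, 30), (0.97, 100), (1.0, 20)]:
    w = run(lr, steps)
    print(f"lr={lr:<5} steps={steps:<3} weight={w:+.4f} distance={abs(w - W_STAR):.4f}")
```

Under these assumptions, 0.001 barely moves and 0.1 and 0.5 settle near 0.17. A learning rate of 0.9 oscillates its way in, 0.97 is still far away after 100 steps, and 1.0 drifts further out with every step. The dividing line is a learning rate of 1/V, about 0.97 here, so a rate of 0.97 sits right on the edge.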
Let’s give a name to the part of the training process where we use the gradient to adjust the weights. PyTorch uses the term optimizer. I tell the optimizer: look, here are the model parameters, here are the corresponding gradients, now it’s your job to update the parameters. A basic optimizer could multiply each gradient by the learning rate and subtract the result from the corresponding parameter. A more sophisticated optimizer could do something else.
It’s beyond me, and beyond this book, to explain the various types of optimizers, the research that led to them, and what they’re good for. I will, though, zoom in on the two optimizers we’ll be using: Adam (or, technically, AdamW) and Muon.
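To make that division of labor concrete, here is a sketch of a basic optimizer in plain Python. The names `Param` and `PlainSGD` are invented for illustration. The sketch only mirrors the general shape of PyTorch’s interface, where an optimizer such as `torch.optim.SGD` is built from the parameters and a learning rate and exposes `step()` and `zero_grad()`. It is not PyTorch’s implementation.

```python
class Param:
    """Hypothetical stand-in for a tensor: a value plus its gradient."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0


class PlainSGD:
    """Basic optimizer: subtract learning rate times gradient from each parameter."""

    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr

    def step(self):
        for p in self.params:
            p.value -= self.lr * p.grad

    def zero_grad(self):
        for p in self.params:
            p.grad = 0.0


# One update of the chicken model: weight 1.0, gradient 1.7, learning rate 1.0.
w = Param(1.0)
opt = PlainSGD([w], lr=1.0)
w.grad = 1.7  # in a real framework, backprop would fill this in
opt.step()
opt.zero_grad()
print(w.value, w.grad)
```

A more sophisticated optimizer would keep the same interface and change only what happens inside `step()`.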