28


Reinforcement learning

Time for our final phase of training: reinforcement learning. Reinforcement learning is as different from supervised fine-tuning as supervised fine-tuning is from base training, or even more so. I think you’ll find it fascinating and intuitive.

Of all the different ways you’ve picked up skills in your life, what type of learning has been most effective? Think of skills broadly—anything from academic ones like math and writing, to physical ones like shooting a basketball or swimming, to social skills. For me, and I think for most people, it’s learning by doing. You get some idea from watching someone else or reading a book, but then you try on your own, make mistakes, and learn from them. As long as you can gauge success—the answer to a math problem is correct, the basketball goes into the hoop—and aren’t ridiculously far from the mark, you learn from your mistakes, you learn what works, and you improve.

Here’s a question from the GSM8K dataset. It’s the same one I showed back in chapter 26.

Frankie's parents let him have many pets. He has six more snakes than he has cats. He has one less parrot than cats. Six of his pets have four legs. He has 2 dogs. How many pets does he have in total?

If you were in third grade and trying to become good at this type of word problem, your best strategy would be to do a practice problem, be told whether you got it right, find your mistake if not, then do another practice problem, and another, until some high-level idea of how to do word problems got locked into your brain.

For our model, we’re evaluating it on how many word problems it gets right, just as we would a human taking a test. But in the training we’ve done so far, we never gave it a chance to practice and learn. We also gave it almost no credit if it got the right answer in a different way. The signal from backpropagation is based on how well the model predicts all of the tokens in the canonical answer. For example, for the word problem above, the model will try to learn to match this assistant message:

<|assistant_start|>He has 6 - 2 = <|python_start|>6-2<|python_end|>

<|output_start|>4<|output_end|>

4 cats.

He has 4 - 1 = <|python_start|>4-1<|python_end|>

<|output_start|>3<|output_end|>

3 parrots.

He has 4 + 6 = <|python_start|>4+6<|python_end|>

<|output_start|>10<|output_end|>

10 snakes.

He has a total of 2 + 4 + 3 + 10 = <|python_start|>2+4+3+10<|python_end|>

<|output_start|>19<|output_end|>

19 pets.

#### 19<|assistant_end|>

As of the end of SFT training, we’re getting 13% of these GSM8K problems right. Our goal now will be to see if we can bump this up by letting the model try these problems on its own during training and learn from when it gets them right.

During reinforcement learning, we’ll need to judge if the model has done something good or bad. This is called a reward. If a model is being trained to play a video game, for example, the reward could be coins earned in the game plus appropriate additional points for completing levels. For GSM8K, the only dataset we’ll be using for reinforcement learning with Nanochat, Karpathy kept it simple: a reward of 1 if the answer is correct and a reward of 0 if not.
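To make this concrete, here’s a minimal sketch of what such a reward function could look like. It’s my own illustration, not nanochat’s actual code, and it assumes the final answer appears after the “####” marker, as in the canonical answer above.

import re

def gsm8k_reward(generated_text, reference_answer):
    # Hypothetical helper: reward 1.0 if the number after the final "####"
    # matches the reference answer, 0.0 otherwise.
    matches = re.findall(r"####\s*(-?[\d,.]+)", generated_text)
    if not matches:
        return 0.0  # the model never produced a final "#### <number>"
    predicted = matches[-1].replace(",", "").rstrip(".")
    return 1.0 if predicted == reference_answer.strip() else 0.0

print(gsm8k_reward("... #### 19", "19"))  # 1.0
print(gsm8k_reward("... #### 21", "19"))  # 0.0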

When we trained in the earlier phases, we always started with a certain number of input tokens and a matching number of target tokens. We then compared input tokens to target tokens to compute loss, ignoring certain tokens for SFT training. Here’s a reminder:

Figure 28.1. In earlier training phases we pulled input and target tokens directly from our datasets.

For reinforcement learning, we will first generate these input and target tokens using the model, and then use the model a second time to calculate loss.

Figure 28.2. In reinforcement learning, we use the model to generate training data.

The engine takes an input prompt, feeds it to the model, takes the newly predicted token, and then feeds that back in over and over until a complete assistant response is generated or we hit a maximum number of tokens. The engine is also responsible for the KV Cache (see chapter 17) and for calling out to the Python calculator as needed along the way, as described in chapter 26. This is exactly what we’ve been doing all along during evaluations (and when I’ve been trying out the model), but it’s our first time using it during training.

Inside the blue engine box we operate the model in inference mode. We use it to predict (infer) tokens but none of the calculations will be tracked for backpropagation. You can think of this process as having two somewhat disconnected components. One generates data and the other uses the data to train.
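Here’s a rough sketch of that generation loop, assuming a PyTorch-style model that maps a sequence of token ids to next-token logits. The names are mine, and the real engine also handles batching, the KV Cache, and the calculator tool calls.

import torch

@torch.inference_mode()  # nothing here is tracked for backpropagation
def generate(model, prompt_tokens, stop_token, max_tokens=512, temperature=1.0):
    tokens = list(prompt_tokens)
    for _ in range(max_tokens):
        x = torch.tensor([tokens])                    # shape (1, length so far)
        logits = model(x)[0, -1]                      # logits for the next token
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).item()
        tokens.append(next_token)                     # feed the new token back in
        if next_token == stop_token:                  # e.g. the <|assistant_end|> id
            break
    return tokens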

You’ll notice that I show the usual “x” tokens going into the GPT model and target tokens going into the loss calculator. The “x” tokens are the full conversation generated by the engine. The target tokens, as in earlier training phases, are the same tokens except shifted over by one so we can calculate loss based on how well we predict the next token for each “x” token. We also use the same supervision logic as in SFT so that we consider only the tokens the assistant is responsible for in the loss.
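In code, turning one generated conversation into a training example might look something like this sketch (hypothetical names, not nanochat’s actual data preparation):

import torch

def build_training_example(conversation_tokens, assistant_mask):
    # conversation_tokens: every token of one generated conversation
    # assistant_mask: 1 where the token was produced by the assistant, else 0
    x = torch.tensor(conversation_tokens[:-1])     # input tokens
    y = torch.tensor(conversation_tokens[1:])      # target tokens, shifted by one
    mask = torch.tensor(assistant_mask[1:])        # supervise only assistant tokens
    return x, y, mask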

If we stopped there, all we would achieve is training our model on lower-quality, often wrong conversations. This is where the rewards come into play. We need the loss calculation to take into account the reward associated with each conversation.

In figure 28.2 I show a single generated conversation. This won’t be all that helpful because it means the model only gets one chance to generate a correct or incorrect solution to the word problem. What we want is to generate multiple responses at a high enough temperature (see chapter 9) that each one comes out different.

Let’s say the engine generates the following conversations for the “Frankie’s parents…” word problem prompt above. (Of course the responses are fairly long. The ellipses in the table below are covering up a lot of tokens.)

Table 28.1. The start and end of the assistant part of five example conversations generated by the model in response to a word problem.

How would you assign the rewards? The assistant response in the first conversation is wrong because the correct answer is 19. The second conversation is intended to be the one shown in full at the start of this chapter and is correct. In the third conversation, the model gets so off track that it doesn’t even finish with “####” and a number. In the fourth response the model takes a different approach but gets the right answer. And the final response is wrong.

Table 28.2. Reward assigned to each generated conversation.

We need to somehow spread these rewards over the supervised tokens such that the final loss will be brought down by tokens that are presumably good (the ones in conversations 2 and 4) and brought up by tokens that are presumably bad (the ones in conversations 1, 3, and 5). In other words, we assume that the model worked through the problem in a good way in the responses where it reached a correct answer, and so we want to reinforce that way of thinking.

We calculate an advantage score for each conversation. The advantage tells us how much better or worse the reward for each conversation is than the average across the conversations. We calculate it by subtracting the mean of the rewards from each reward, which causes all the advantages to sum to zero. Something you may find intuitive is that if all the conversations are correct, there’s nothing to learn. A little less intuitive is that if all of the conversations are wrong, there’s also nothing to learn. In both cases every reward equals the mean, so every advantage is zero.

Table 28.3. Advantage for each generated conversation.
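Using the rewards from table 28.2 (1 for conversations 2 and 4, 0 for the rest), the advantage calculation is a one-liner. Here’s a small sketch:

import torch

rewards = torch.tensor([0.0, 1.0, 0.0, 1.0, 0.0])    # conversations 1 through 5
advantages = rewards - rewards.mean()                 # the mean is 0.4
print(advantages)        # tensor([-0.4000,  0.6000, -0.4000,  0.6000, -0.4000])
print(advantages.sum())  # zero (up to floating point), by construction

# All correct (or all wrong) means every advantage is zero: no learning signal.
print(torch.ones(5) - torch.ones(5).mean())           # tensor([0., 0., 0., 0., 0.])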

We’re getting close. We need to now bring these advantages into the loss function. Let’s pretend that we ran each conversation through the model, compared the supervised tokens with the targets, and determined:

Table 28.4. Sum of per-supervised-token log probabilities for the generated conversations.

The sum of log probabilities means the logs of the per-token probabilities added together. This is the negative of negative log loss :), or think of it as the “natural log of this probability” row shown in table 8.7 before the multiplication by -1. Unlike with negative log loss, here a bigger number means the conversation is more likely according to our model. For example, the model, unfortunately, considers conversation #3 to be the most likely when we normalize for length.
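As a sketch of where those numbers come from (the names and shapes are my assumptions, not nanochat’s exact code): for each conversation we run the “x” tokens through the model, turn the logits into log probabilities, pick out the log probability of each target token, and sum over the supervised tokens.

import torch
import torch.nn.functional as F

def sum_log_probs(logits, targets, mask):
    # logits: (length, vocab_size) from the model for one conversation
    # targets: (length,) target token ids; mask: 1.0 for supervised tokens
    log_probs = F.log_softmax(logits, dim=-1)                      # per-token log probabilities
    picked = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # log prob of each target token
    return (picked * mask).sum()                                   # keep only assistant tokens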

Next we bring together the advantages and the sum of log probabilities. Reinforcement learning terminology is about maximizing rewards rather than minimizing loss. Everything is reversed. You learn a policy via gradient ascent to maximize rewards rather than learning a model via gradient descent to minimize loss. For example, if you’re using reinforcement learning to learn how to play a video game, the policy will specify what keyboard / mouse / joystick actions to take, and the policy can be learned in a way analogous to how we’ve been learning our model parameters. I share all of this only so the term policy gradient objective will make some sense in table 28.5. As you’ll see, we turn it right back into a loss to match up with the rest of our machinery, which is focused on minimizing loss.

Table 28.5. Loss calculation.
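Stripped down, a REINFORCE-style version of the calculation in table 28.5 might look like the sketch below. It’s an illustration of the idea, not nanochat’s exact implementation.

import torch

def policy_gradient_loss(sum_log_probs, advantages):
    # sum_log_probs: one summed log probability per conversation (as in table 28.4)
    # advantages: one advantage per conversation (as in table 28.3)
    objective = advantages * sum_log_probs   # reward-weighted log likelihood
    return -objective.mean()                 # negate: maximizing the objective = minimizing the loss

Because the advantages are plain numbers that no gradient flows through, this loss pushes up the log probabilities of conversations with positive advantage and pushes down those with negative advantage.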

To see how this works, let’s pretend that we successfully update parameters such that correct conversation #2 becomes more likely. This would show as an increase in the sum of log probabilities. Since this is desired, we would expect the loss to decrease.

Table 28.6. Loss calculation after conversation #2 becomes more likely.

And yes, the loss decreases.

I showed five conversations. When we train in a moment, in each step we’ll use 16 word problems from GSM8K, and for each we’ll generate 16 conversations. This is a total of 256 conversations per step (i.e. the loss calculation like that shown in table 28.5 will have 256 rows). We’ll spread the conversations across the GPUs, calculate loss as shown, and ask the optimizers to update the parameters once per step as usual. There are 7,473 word problems in the training portion of GSM8K, so one pass through them takes 7,473 ÷ 16 ≈ 467 training steps.

In all of our evaluations so far (CORE and ChatCORE) we only looked at the model’s highest probability predictions. For example, going back to the SQuAD example I showed in chapter 24, the model predicted “The Partridge Family,” which was wrong, so we gave the model a zero on that question. We didn’t look to see whether, given a few chances at a higher temperature, it could have generated “Duel,” and then give it partial credit.

Our presumption in this training phase is that given a few chances, the model will sometimes generate the right answer and sometimes generate the wrong answer and we’re going to nudge the parameters toward making it more likely that its highest probability output is a correct answer. It therefore makes sense to track how the model does with multiple chances during the evaluation.

This multiple-attempt metric is called pass@k. For example, a pass@3 of 25% would mean that the model correctly answered 25% of all of the GSM8K word problems in an evaluation set given three attempts to get the right answer per problem. We’ll give the model eight chances per problem which will let us measure pass@1, pass@2, and so on up to pass@8. We’ll want pass@8 to be higher than pass@1 in the beginning or there will be nothing to learn, and we’ll hope that all of these metrics increase over the course of training.
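Computed the straightforward way described here (count a problem as solved if any of its first k attempts is correct), pass@k looks like this sketch:

def pass_at_k(per_problem_attempts, k):
    # per_problem_attempts: for each problem, a list of booleans, one per attempt
    solved = sum(1 for attempts in per_problem_attempts if any(attempts[:k]))
    return solved / len(per_problem_attempts)

# Four problems, three attempts each
results = [[False, True, False], [False, False, False],
           [True, True, False], [False, False, True]]
print(pass_at_k(results, 1))   # 0.25 -- only the third problem is solved on the first attempt
print(pass_at_k(results, 3))   # 0.75 -- three of the four are solved within three attempts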

We’ll pause and conduct the pass@k evaluation every 60 training steps. We’ll use 400 word problems from the same testing GSM8K dataset used in ChatCORE. The model has never seen these problems during earlier training phases and will not be trained on them during this phase.

When I trained my smaller 20-layer model I saw something weird—the pass@k metrics improved and then began to decline. I suspect Karpathy encountered this too because he configured the code to save a checkpoint every 60 training steps and we’ll do the same.

I’m not sure why the metrics decline. Perhaps, as with too much interbreeding, there is a collapse from the model fitting too much to its own generated tokens. Or maybe under the incredible pressure from the gradient the model discovers a way to cheat—after all, it has seen these specific GSM8K problems before during mid-training and SFT—and so instead of getting better at actually solving the problems, it gets good at pulling out a memorized answer. Whatever the situation, if the same thing happens when we train our 32-layer model in a moment, my plan is to pick the checkpoint from whenever the metrics are at their peak as our final model version.

One other metric we’ll track along the way is average reward. If you look back at table 28.2 you can see why this will be helpful. If the model is getting smarter at solving these word problems, we should expect more conversations in a given step to be right, so the average reward will go up. It’s also a type of sanity check because if the average reward is nearly always zero it means the model is never generating correct conversations and there’s no way to learn.

Okay, let’s do it. I’ll run the reinforcement learning and then do another full ChatCORE evaluation.


That wasn’t so fast. The training took an hour and forty-five minutes and the chat evaluation took another ten minutes. Total spend was just under $50. The underlying GSM8K training conversations contained about 1.3 million tokens. If we assume the generated conversations were roughly the same size, that’s around 16 × 1.3 = 21 million tokens. This is less than 5% of the tokens we trained on in mid-training which took under half an hour in total. I’m not surprised it was slower because we had to generate conversations and work in small, inefficient batches, but I am surprised at how much slower.

Anyway, let’s get to the exciting parts. Here’s pass@8:

Figure 28.3. Pass@8 for our abbreviated GSM8K evaluation.

Before the start of training our pass@8 was 0.275. This means with eight tries we were getting 110 out of the 400 evaluation problems right. That alone is interesting since if you look back at table 27.5, our performance with only one chance was 13%. Our model is not that large, and word problems are tricky. It’s good to know that even when we failed before, given a few chances the model could sometimes come up with a correct solution. This is nothing to laugh at. A million monkeys typing at a million words per minute for a million years wouldn’t have a shot at typing out a correct solution. (Although they would occasionally type the correct numerical answer.)

Even more exciting is that we moved the ball forward to around 39% correct. That’s a net additional 46 problems that the model learned how to solve. Here’s pass@1:

Figure 28.4. Pass@1 for our abbreviated GSM8K evaluation.

This metric had even more dramatic improvement. We almost tripled from 6.75% to 19.5% at step 360. If you’re wondering why the starting point of 6.75% is so much lower than the ChatCORE score of 13%, it’s not an apples to apples comparison. In ChatCORE we take the single most probable output. In pass@k we set a higher temperature and so it’s less likely that the first generation will be the best. The dramatic improvement in pass@1 should indicate that even slightly less likely generations are correctly solving the word problem. I’ll dig into this later.

I mentioned above why we keep an eye on average rewards per step. Here’s the plot:

Figure 28.5. Average rewards per step.

The single lowest average was 0.015625 at step 35. This corresponds to four correct conversations out of 256. I take it as a reasonable sign that we never had a full batch of 256 conversations with none correct. The single highest average was 0.38281 at step 348, corresponding to 98 correct conversations. Overall the average rewards trended up, another sign that the model was learning. Let me smooth the plot to be sure:

Figure 28.6. Smoothed average rewards per step.

Looks good.

I’m concerned about the drop in pass@1 and pass@8 after step 360. Although I’m not showing the charts, the same trend is present in pass@2 through pass@7. For this reason, I used the checkpoint from step 360 for the ChatCORE evaluation (table 28.8 below) and will also use it as the model we actually chat with in the next chapter.

As usual, I’m itching to look at specific examples. I’m curious to see, for a word problem that the model gets correct in some generations and wrong in others, how exactly it gets it right or wrong. I wrote a script that goes through the first 200 GSM8K problems in the evaluation dataset and for each prints out the correct answer, the highest probability generation, and eight generations using a higher temperature. Here’s a word problem where the highest probability generation was correct:

Gunther, the gorilla, had 48 bananas hidden under a fern branch. When Gunther wasn't looking, Arnold, the chimpanzee, stole half of the bananas from the pile. The next day, Gunther added another 25 bananas to his pile, but later that evening, Arnold stole another 12 of the bananas. On the third day, Gunther added another 6 bananas to his pile and began counting bananas. How many bananas did Gunther find were in the pile?

Even reading the problem is a reminder of how much less lazy these models are than we humans. I have zero desire to read it carefully enough to solve it, but the model will happily solve it nine times. Here’s its highest probability generation:

<|assistant_start|>Gunther had 48 bananas hidden under the fern branch.

When Gunther wasn't looking, Arnold stole half of the bananas, which is 48 / 2 = <|python_start|>48/2<|python_end|><|output_start|>24.0<|output_end|>24 bananas.

On the third day, Gunther added another 25 bananas to his pile, so he now has 24 + 25 = <|python_start|>24+25<|python_end|><|output_start|>49<|output_end|>49 bananas.

Arnold later stole another 12 bananas, so Gunther now has 49 - 12 = <|python_start|>49-12<|python_end|><|output_start|>37<|output_end|>37 bananas.

On the third day, Gunther added another 6 bananas, so he now has 37 + 6 = <|python_start|>37+6<|python_end|><|output_start|>43<|output_end|>43 bananas in his pile.

#### 43<|assistant_end|>

That first “On the third day” should say “The next day,” but all the numbers and steps are correct and it gets the correct answer of 43. It’s incredible how even our relatively small model has learned when and how to use a tool (the Python calculator).

Here’s the first of the higher temperature responses. It’s wrong, but of course we humans are also entitled to make mistakes that we keep in our mind. (For now. Mind reading is becoming real and in the future we may no longer be able to keep our private thoughts private.)

<|assistant_start|>Gunther started with 48 bananas.

Half of the bananas were stolen by Arnold on the first day, so Gunther started with 48 / 2 = <|python_start|>48/2<|python_end|><|output_start|>24.0<|output_end|>24 bananas that were stolen.

Gunther added 25 more bananas on the second day, so now had 24 + 25 = <|python_start|>24+25<|python_end|><|output_start|>49<|output_end|>49 bananas in the pile.

On the third day, Arnold stole 12 more bananas, so now Gunther had 49 + 12 = <|python_start|>49+12<|python_end|><|output_start|>61<|output_end|>61 bananas in the pile.

Gunther added the 6 bananas he found on the third day, so now he had 61 + 6 = <|python_start|>61+6<|python_end|><|output_start|>67<|output_end|>67 bananas in his pile.

#### 67<|assistant_end|>

The text of this answer is mostly correct, but the model makes a sloppy error when it writes “49 + 12” instead of “49 - 12.” I can only imagine that as researchers witnessed GPT models writing out steps and solving word problems for the first time it motivated them to keep pushing—could a bigger model solve harder problems? The answer is yes, but I think it’s one of those things that’s only obvious in retrospect.

Here’s a problem the model did not solve in any of the generations from my script. In other words, it counted as a fail in pass@8:

Scarlett found an aquarium for $10.00 at a yard sale. At the pet store, she bought 2 bags of rocks for $2.50 each and 3 pieces of coral at $2.00 apiece. She bought 20 fish at $0.50 each and she needed fish food that cost $2.00. How much did she spend?

At first glance it doesn’t look especially harder than the problems the model solved. Looking at the nine generations, which I won’t paste here, I see that the model makes a different sloppy error each time, for example forgetting to add the $5 for the rocks even though it calculated the $5. I’m thinking there are too many calculations in this problem for our model to keep track of, but this is just speculation.

Finally, to end on something of a high note, here’s a problem that the model got right with the highest probability generation (although for slightly wrong reasons) and with seven of the eight higher temperature generations. I suppose you can judge how easy or hard a problem is by the percentage of multiple generations that are correct.

The price of a laptop is $1000. If you get a 20% discount, how much do you have to pay?

Here are all nine generations so you can see the different ways the model approached the problem. It used the calculator multiple times in every case but I removed those parts to make things less cluttered.

Table 28.7. Nine generations for a single problem.

Zooming out, I ran the full ChatCORE evaluation on the step 360 checkpoint. Here are the results:

Table 28.8. ChatCORE metric for the final version of our model.

Good news. We drove GSM8K accuracy from 13% to 19%. And overall, despite small declines in some of the other tasks, ChatCORE increased to 0.382. It’s time to chat!