27


Supervised fine-tuning

We completed base training and mid-training, and now it’s time for the third training phase: supervised fine-tuning, or SFT. Base and mid-training used identical mechanics with different data. Now we’re going to see new mechanics.

In base and mid-training, we tokenized our text and pushed the tokens into batches. As soon as we finished filling one sequence in a batch, we started filling the next sequence. Think of this as filling every seat on a train car and immediately starting to fill the next car, even if that splits up traveling parties. This process resulted in sequences that ended like this:

Table 27.1. The last handful of tokens of three sequences in a training batch.

And others that started like this:

Table 27.2. The first handful of tokens of three sequences in a training batch.

Does this bother you? It bothered me when I first learned how tokens get formed into batches. We don’t want the model to learn that it’s okay to start text with a closing parenthesis or end a sentence with “system’s,” leaving the writer’s point unresolved. We also don’t want it to learn that <|assistant_start|> can appear without a preceding user message, or that it’s okay to leave an assistant message unfinished instead of completing it and marking it done with <|assistant_end|>. The whole key to the transformer architecture is paying attention to information in prior tokens, not to mention the rotary embedding mechanism (see chapter 15), which lets the model exploit relative position. How can the model learn if we start sequences at arbitrary points in documents and sentences?

Perhaps because I’m used to writing traditional software, the sloppy alignment of text with batches looked like a bug to me. It’s not. It’s a deliberate tradeoff. Yes, the model will learn some suboptimal stuff, but that’s true anyway: we’re training on text scraped from the internet, so it will see misspellings, bad grammar, wrong information, and screwy, half-formed comments on Reddit posts. The model works in the end because the average of all that messiness tends to be correct. Words may be misspelled in many ways, but the correct spelling appears far more frequently than any misspelling. The same idea applies to higher-level concepts. So from this perspective it’s no big deal to have sequences that get cut off or start in odd places. Yes, the model will sometimes see a closing parenthesis at the start of a sequence, but this will be overwhelmed by the number of times it sees balanced opening and closing parentheses. (We humans are also good at learning from incomplete, imperfect information. And just like models, if we see too much misinformation, we’re susceptible to believing it.)

Remember also that we trained on sequences of 2048 tokens, which equates to around 10,000 characters, or roughly 1,500 words. For mid-training, our average conversation (from the first <|user_start|> to the final <|assistant_end|>) was around 500 tokens. With about four conversations per sequence, we can be sure that the model sees more complete conversations than partial ones. For base and mid-training, in the tradeoff between aligning text cleanly and packing every batch so we train as efficiently as possible on as many tokens as possible, packing wins.
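To make the packing mechanics concrete, here is a minimal sketch in Python. It is not the book’s actual data loader: token ids are plain integers, the `BOS` id is hypothetical, and for simplicity it drops leftover tokens that don’t fill a whole row, where a streaming loader would carry them into the next batch instead.

```python
from typing import List

BOS = 0  # hypothetical id standing in for <|bos|>

def pack_documents(docs: List[List[int]], seq_len: int) -> List[List[int]]:
    """Concatenate documents into one token stream and cut it into fixed-length rows.

    Each row is filled to the brim before the next row starts, so a document can
    be split across two rows ("fill every seat, then start the next car").
    """
    stream: List[int] = []
    for doc in docs:
        stream.append(BOS)  # mark where each document begins
        stream.extend(doc)
    n_rows = len(stream) // seq_len  # leftover tokens are simply dropped in this sketch
    return [stream[i * seq_len:(i + 1) * seq_len] for i in range(n_rows)]

# Three toy documents packed into rows of 8 tokens.
docs = [[11, 12, 13, 14, 15], [21, 22, 23, 24, 25, 26, 27], [31, 32, 33, 34]]
rows = pack_documents(docs, seq_len=8)
for row in rows:
    print(row)
```

The second row begins in the middle of the second document, the same kind of ragged start the tables above illustrate.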

In this SFT phase, the tradeoff changes. The model has a strong foundation, and now we want to fine-tune its parameters to sand away some of the rough edges, making it a more capable chat assistant. So in this phase, every sequence will be a complete conversation. For example, here are four conversations we’ll be training on:

Table 27.3. Four of the conversations we’ll be training on during SFT training.

I’ll plot each one so you can appreciate the vast differences in their lengths:

Figure 27.1. The user and assistant tokens in each of the four example conversations.

How are we going to get these into a batch? In base or mid-training, we would pack them into sequences of 2048 tokens as follows:

Figure 27.2. How we would have packed these example conversations into a batch during mid-training. The plot shows the first two of eight sequences in the batch.

In SFT we don’t want to pack the conversations as shown in figure 27.2 because each sequence needs to be a complete conversation. We also can’t simply treat figure 27.1 as a batch: a batch can’t have one sequence 1038 tokens long and another 330 tokens long. Our entire architecture is built on massively parallel calculations with tensors, and every row in a tensor must have the same number of columns. The easiest solution, and the one that fits with everything else we’ve been doing, is to make every sequence in the batch as long as the longest one. We’ll then pad the shorter conversations with some token:

Figure 27.3. How we form a batch with the four example conversations in SFT training.
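Here is a minimal sketch of that padding step in plain Python. The padding id is hypothetical, and the four conversation lengths are inferred from numbers in this chapter: the longest is 1277 tokens, two others are 1038 and 330 tokens, and the multiple choice conversation is 48 tokens, which together add up to the 2693 real tokens counted in the loss discussion.

```python
from typing import List, Tuple

PAD = 0  # hypothetical padding id; which token we pick doesn't matter

def pad_batch(convs: List[List[int]], pad_id: int = PAD) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad every conversation to the length of the longest one.

    Returns the padded rows plus a 0/1 mask marking which positions are real tokens.
    """
    max_len = max(len(c) for c in convs)
    rows = [c + [pad_id] * (max_len - len(c)) for c in convs]
    real = [[1] * len(c) + [0] * (max_len - len(c)) for c in convs]
    return rows, real

# Stand-in conversations with the four example lengths (token values are dummies).
convs = [[7] * n for n in (1277, 1038, 330, 48)]
rows, real = pad_batch(convs)
n_real = sum(map(sum, real))
print(len(rows), "x", len(rows[0]))
print(n_real, "real tokens,", len(rows) * len(rows[0]) - n_real, "padding tokens")
```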

Now we have a tensor of size 4×1277 filled with tokens. We can feed it into our model, and out will come a tensor of size 4×1277×65,536 representing, as usual, a prediction of the next token for each input token. (See chapter 8 for a reminder of why each prediction has 65,536 numbers.) What token should we use as the yellow padding token? We do have to use one of our 65,536 tokens, and we know the model will make useless predictions at every padding position. It turns out it doesn’t matter which token we use. What matters is that we don’t let these predictions influence our learning. We need to ignore them in our loss calculation.

Look back at chapter 8 to see how we calculate loss. Since we take the mean of the loss of each individual token prediction, ignoring certain tokens is straightforward: leave them out of the mean. For example, of the 4 × 1277 = 5108 tokens in the batch above, we can include the losses for the 2693 tokens that are actually part of conversations and leave out the losses for the other 2415 tokens.
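In code, ignoring those positions just means dividing by the number of real tokens instead of by the size of the whole tensor. Here is a minimal sketch with made-up per-token losses, where the padding positions deliberately carry huge losses to show that they have no effect:

```python
from typing import List

def masked_mean_loss(token_losses: List[List[float]], mask: List[List[int]]) -> float:
    """Mean of the per-token losses at positions where mask == 1; all others are skipped."""
    total, count = 0.0, 0
    for loss_row, mask_row in zip(token_losses, mask):
        for loss, keep in zip(loss_row, mask_row):
            if keep:
                total += loss
                count += 1
    return total / count

# Two rows of four positions; the last two positions of row 2 are padding.
losses = [[1.0, 2.0, 3.0, 4.0],
          [2.0, 2.0, 99.0, 99.0]]   # the 99.0s are garbage predictions at padding
mask = [[1, 1, 1, 1],
        [1, 1, 0, 0]]
print(masked_mean_loss(losses, mask))   # (1+2+3+4+2+2) / 6 real positions
```

In PyTorch the same effect is usually achieved by setting the targets at ignored positions to a sentinel value and passing that value as the `ignore_index` argument of `torch.nn.functional.cross_entropy` (its default is -100).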

Is it crazy to do so much calculating and then throw so much of it away? If someone gave you twenty math problems to work out on a piece of paper and said they only needed the first ten, would you first do all twenty and then hand back the first ten? Of course not. With the example batch above we’ll probably do on the order of 10 teraflops of wasted calculations. That’s 10,000,000,000,000 floating point operations, over 200 million years of me calculating on paper.
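Where does a number like 10 trillion come from? Here is a rough sanity check, assuming the common rule of thumb that a forward pass costs about 2 floating point operations per parameter per token (and forward plus backward about 6), and ignoring the attention computation itself. These are estimates, not measurements:

```python
N_PARAMS = 1.9e9     # approximate parameter count of the 32-layer model (chapter 24)
PAD_TOKENS = 2415    # padding positions in the example batch

# Rule-of-thumb estimates (assumptions, not measurements).
forward_flops = 2 * N_PARAMS * PAD_TOKENS
train_flops = 6 * N_PARAMS * PAD_TOKENS

print(f"forward only:       {forward_flops:.1e} FLOPs")
print(f"forward + backward: {train_flops:.1e} FLOPs")
```

Both estimates land on the order of 10^13 operations.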

But in the logic of GPUs with their incredible ability to perform calculations in parallel, throwing away 10 teraflops of calculations isn’t ideal but also isn’t crazy. Don’t think of it like a human or even a traditional computer program doing calculations. Instead think of this. There are 13,000 people at Beijing South station waiting to travel to Shanghai. A high speed train with 1300 seats pulls into the station every minute. Each train will get to Shanghai in the same amount of time regardless of how many passengers board. It’s more efficient and all the passengers will get to Shanghai sooner if we pack every car of every train, but it’s not horrible if half the seats are off limits. We’ll need twice as many trains, we wasted electricity and money, and it takes longer for all the passengers to get to Shanghai, but they’ll all get there within around six hours. If they walked it would take weeks.

Also, as you’ll see below, the total number of tokens we’ll train on in this phase is tiny compared to the earlier phases, so in the grand scheme of things it doesn’t matter that we’re less efficient. One of the golden rules of software engineering is that if you’re going to optimize something, you start with whatever takes up the most time. If a process takes 100 minutes to run and you find a suboptimal subprocess that takes one of those 100 minutes, speeding it up by an impressive ten times barely achieves anything. On the other hand, if you double the speed of a different subprocess that takes up 80 of the minutes, you’ve reduced your total time from 100 minutes to 60 minutes. Our SFT phase takes well under 1% of our total training time. (The carpenter measures twice before cutting. The software engineer is supposed to measure at least once before optimizing!)
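The arithmetic behind that rule, sometimes formalized as Amdahl’s law, fits in a few lines. This sketch assumes the remaining 19 minutes of the 100-minute process are left untouched:

```python
from typing import List

def total_time(parts: List[float], speedups: List[float]) -> float:
    """Total runtime after dividing each part's time by its speedup factor (1 = unchanged)."""
    return sum(t / s for t, s in zip(parts, speedups))

parts = [1, 80, 19]   # minutes: small subprocess, big subprocess, everything else

small_10x = total_time(parts, [10, 1, 1])   # 1-minute part made 10x faster
big_2x = total_time(parts, [1, 2, 1])       # 80-minute part made 2x faster
print(f"{small_10x:.1f} minutes vs {big_2x:.1f} minutes")
```

Ten times faster on 1% of the work saves under a minute; twice as fast on 80% of the work saves 40 minutes.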

So now we’ve moved to this more refined (but less efficient) approach of one conversation per sequence. Does anything else bother you about how we did things in mid-training? Take a look at the 48-token conversation in our example batch:

<|bos|><|user_start|>Multiple Choice question: Which animal eats only plants?

- Cat=A

- Dog=B

- Lion=C

- Rabbit=D

Respond only with the letter of the correct answer.<|user_end|>

<|assistant_start|>D<|assistant_end|>

Our goal is for the model to give useful responses to user messages. In this case, that amounts to correctly answering a multiple choice question. Now think about the loss calculation for just this part of the batch. There are 48 tokens, so 1/48th of the loss will be based on how well “<|user_start|>” gets predicted from “<|bos|>,” 1/48th will be based on how well “Multiple” gets predicted from “<|bos|><|user_start|>,” and so on, all the way up to 1/48th based on how well “D” gets predicted from everything before it. When backprop does its magic, each of these predictions carries the same weight: how well the model predicts “Multiple” matters as much as how well it predicts “Dog,” which matters as much as how well it predicts “D.”

Doesn’t this seem crazy? One of those predictions is much more important than the others. In fact, “D” is the only prediction we care about and the only one that will count towards a similar question being judged right or wrong in the ChatCORE evaluation. Do we want the model learning to mimic a typical user question or do we want the model learning to respond to the user? In fact, why should the model be penalized at all for not predicting some perhaps poorly written user question? Is this just a big bug in our loss calculation?

I guess it’s a matter of perspective. During base and mid-training it wasn’t wrong because we were building a foundation of language and concepts. Now that we’re fine-tuning, we have a more intentional goal and we want the gradient signal coming from backprop to be aligned with the goal. So in SFT we do in fact make a shift. We’ll only count the tokens we expect the assistant to generate towards the loss and ignore the user parts. This more or less equates to ignoring the loss on the blue parts in figure 27.3.
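One way to implement this is to render each conversation into token ids together with a parallel 0/1 mask. The sketch below is an illustration, not the book’s code: the special-token ids are hypothetical, the user message is 42 stand-in tokens so the total matches the 48-token example, and it assumes <|assistant_start|> is unsupervised (because the prompt already ends with it at inference time) while the reply and its <|assistant_end|> are supervised.

```python
from typing import List, Tuple

# Hypothetical special-token ids, chosen only for illustration.
BOS, USER_START, USER_END, ASSISTANT_START, ASSISTANT_END = 1, 2, 3, 4, 5

def render_with_mask(messages: List[Tuple[str, List[int]]]) -> Tuple[List[int], List[int]]:
    """Flatten a conversation into token ids plus a 0/1 mask (1 = counts toward the loss)."""
    ids, mask = [BOS], [0]
    for role, tokens in messages:
        if role == "user":
            chunk = [USER_START] + tokens + [USER_END]
            ids += chunk
            mask += [0] * len(chunk)                 # reading the user: no loss
        elif role == "assistant":
            ids += [ASSISTANT_START] + tokens + [ASSISTANT_END]
            mask += [0] + [1] * (len(tokens) + 1)    # supervise the reply and its end marker
        else:
            raise ValueError(f"unknown role: {role}")
    return ids, mask

question = list(range(100, 142))  # 42 stand-in tokens for the multiple choice question
ids, mask = render_with_mask([("user", question), ("assistant", [200])])  # 200 stands for "D"

# Each token is predicted from the ones before it, so shift the mask along with the targets.
inputs, targets, target_mask = ids[:-1], ids[1:], mask[1:]
print(len(ids), "tokens,", sum(target_mask), "supervised:",
      [t for t, m in zip(targets, target_mask) if m])
```

Only two predictions count toward the loss here: “D” and <|assistant_end|>.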

I say more or less because this same idea also applies to tool use. We want the model to learn how to use a tool correctly, but we don’t want to waste our learning on the model figuring out how to simulate that tool. For example, if the model is solving a word problem, we want it to get good at thinking along the lines of “I should use a calculator to multiply, say, 12.3 and 4.56, and here’s the expression to hand to the calculator,” but we don’t especially want it to get good at doing the calculation itself.

Take a look back at chapter 26 for a reminder of how the python and output special tokens are used. What tokens should we ignore when calculating loss?

Figure 27.4. We ignore calculator output tokens in the loss calculation.

By ignoring the output tokens, if the model gives very low probability to token “56” (maybe it thinks “48” is higher probability), we won’t waste precious resources nudging parameters trying to get it to correct itself.

You can think of what we’re doing here as the difference between reading and writing. We’re tuning the model to be good at reading the user message and at reading and writing the assistant message. The same is true for tool use. We don’t want the model to learn how to write “<|output_start|>56.088<|output_end|>.” But it does need to be able to read it, because it needs the result either for another step in solving the word problem or to communicate it to the user.
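Extending the same idea to tool use, here is a sketch of rendering one assistant turn that calls the calculator. As before, the token ids are hypothetical. The sketch supervises the python block the model writes, including its markers, and leaves the whole output block unsupervised, assuming the tool harness (not the model) inserts <|output_start|>, the result, and <|output_end|>.

```python
from typing import List, Tuple

# Hypothetical special-token ids, chosen only for illustration.
ASSISTANT_START, ASSISTANT_END = 4, 5
PYTHON_START, PYTHON_END, OUTPUT_START, OUTPUT_END = 6, 7, 8, 9

def render_assistant(parts: List[Tuple[str, List[int]]]) -> Tuple[List[int], List[int]]:
    """Render one assistant turn that may call a tool; return (ids, mask)."""
    ids, mask = [ASSISTANT_START], [0]
    for kind, tokens in parts:
        if kind == "text":
            chunk, supervised = tokens, 1
        elif kind == "python":
            # The model writes the tool call, so it counts toward the loss.
            chunk, supervised = [PYTHON_START] + tokens + [PYTHON_END], 1
        elif kind == "output":
            # The tool writes the result; the model only has to read it.
            chunk, supervised = [OUTPUT_START] + tokens + [OUTPUT_END], 0
        else:
            raise ValueError(f"unknown part: {kind}")
        ids += chunk
        mask += [supervised] * len(chunk)
    ids.append(ASSISTANT_END)
    mask.append(1)
    return ids, mask

ids, mask = render_assistant([
    ("text", [10, 11]),        # e.g. "Let me calculate"
    ("python", [20, 21, 22]),  # e.g. "12.3*4.56"
    ("output", [30]),          # e.g. "56.088", produced by the calculator
    ("text", [40]),            # the answer, restated for the user
])
print(list(zip(ids, mask)))
```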

This also hints at an important difference between ignoring the loss on padding tokens and ignoring the loss on user and output tokens. For the padding tokens shown in figure 27.3, we’re talking about teraflops of calculations that we throw away. For the user and output tokens, all the calculations are still needed (except for some in the final transformer block and the final linear projection to the vocabulary). Take a look at how information from earlier positions flows forward in the causal self-attention block (chapter 14) to see why.

One more thing to point out is that multiple choice questions are an extreme example in that only a single token counts towards loss. (To be precise, it’s two tokens because we also care about the prediction of <|assistant_end|>.) You can see that there are many more assistant tokens in the other three conversations in the example batch plotted in figure 27.3.

The tokens that count towards loss are called supervised tokens. In base and mid-training, we’re learning from large quantities of unlabeled data. Here we’re supervising the training according to desired responses.


We’re just about ready to train. As with mid-training, the overwhelming majority of our training conversations will come from the SmolTalk dataset. The overall scale is much smaller. In mid-training we trained on around 850,000 conversations with a total length of around 425 million tokens. Here we’ll be training on around 22,000 conversations with a total length of ten million tokens. We’ll pause to measure validation loss from time to time using SmolTalk conversations excluded from the training data.

Since we’ll be fine-tuning with these very concentrated signals from the supervised tokens, we want to monitor things closely in case we start fitting to something that doesn’t actually improve accuracy. Every 200 steps we’ll compute an abbreviated version of our ChatCORE metric using approximately 1000 conversations from ARC Easy and 1000 from MMLU.

How long will this all take? Mid-training took 22 minutes. For SFT we’ll be training on around 2% of the tokens but our batches will be much smaller (because 2048 tokens would be an incredibly long conversation) and our batches won’t be packed with tokens. Let’s say we’re five times less efficient. It should still only be a matter of minutes.
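Here is that back-of-the-envelope estimate spelled out. The fivefold slowdown is the guess from the paragraph above, not a measurement:

```python
MID_TOKENS = 425e6    # mid-training tokens (from the text)
SFT_TOKENS = 10e6     # SFT tokens (from the text)
MID_MINUTES = 22      # mid-training wall-clock time (from the text)
SLOWDOWN = 5          # the guess above: SFT processes tokens ~5x less efficiently

ratio = SFT_TOKENS / MID_TOKENS
estimate = MID_MINUTES * ratio * SLOWDOWN
print(f"SFT tokens as a share of mid-training: {ratio:.1%}")
print(f"estimated SFT training time: {estimate:.1f} minutes")
```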

Let’s start.


Well, that felt fast. Training took under five minutes, and running the ChatCORE evaluation took another ten. Unfortunately, the picture from this run is far murkier than it was for base or mid-training. But let’s start with what worked as expected. This chart shows the number of supervised tokens per step:

Figure 27.5. Number of supervised tokens for each step across the full SFT training run.

In base and mid-training, we processed and calculated loss on a consistent 524,288 tokens per step. In SFT, for the reasons explained above, the number of supervised tokens is much smaller and varies per step based on the size of the training conversations and number of assistant tokens. It looks like we ended up with steps ranging from 4,000 supervised tokens to 18,000 supervised tokens. We intentionally go through our training conversations in random order and the graph confirms this worked as expected.
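A small sketch of why that count moves around: shuffle the conversations, group them into fixed-size batches, and sum each batch’s mask. The masks below are toy stand-ins with made-up lengths, and the batch size and seeds are arbitrary choices for illustration.

```python
import random
from typing import List

def supervised_tokens_per_step(masks: List[List[int]], batch_size: int, seed: int = 0) -> List[int]:
    """Visit conversations in a shuffled order and count supervised tokens in each batch."""
    order = list(range(len(masks)))
    random.Random(seed).shuffle(order)
    return [sum(sum(masks[i]) for i in order[start:start + batch_size])
            for start in range(0, len(order), batch_size)]

# Toy conversations: some unsupervised user tokens followed by 2 to 400 supervised ones.
rng = random.Random(1)
masks = [[0] * rng.randint(20, 300) + [1] * rng.randint(2, 400) for _ in range(32)]
counts = supervised_tokens_per_step(masks, batch_size=4)
print(counts)
```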

Here is GPU memory allocated:

Figure 27.6. GPU memory allocated over the full training run.

Our high-water mark for memory was around 80%, versus 100% in base and mid-training. This is also reasonable: with batch shapes that vary from step to step, it’s harder to use all of the memory without risking running out of memory, which would crash the entire training process.

Now for the concerning part. Here’s validation loss:

Figure 27.7. Validation loss.

We measure validation loss using SmolTalk conversations. Just before training started, we measured the validation loss as 0.8333. At step 100 it jumped to 0.8390 and stayed around there until we measured 0.8385 at the end. Now, is an increase of 0.006 really a jump, or is it better to think of this as essentially no movement in validation loss? Let’s see if training loss tells us anything:

Figure 27.8. Training loss.

Training loss is always much noisier because it’s the loss at every single step, which varies based on how well the model does with the particular random training conversations in that batch. Over the whole course of training, though, it should decline. That is certainly what happened in base and mid-training, even though I didn’t show those graphs. Here there doesn’t appear to be a dramatic increase or decrease. To see better, I’ll smooth out the graph using an exponential moving average. (See table 20.6 for a reminder of how EMA works.)

Figure 27.9. Smoothed training loss.

The training loss is essentially flat. The picture remains murky and concerning.
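For reference, here is one common way to compute that kind of smoothing. This sketch adds a bias correction (dividing by 1 - beta**(t + 1), as Adam does) so the early values aren’t dragged toward the starting value of zero; that correction and the choice of beta are assumptions made here, not necessarily what produced figure 27.9.

```python
from typing import List

def ema_smooth(values: List[float], beta: float = 0.9) -> List[float]:
    """Exponential moving average of a noisy series, with bias correction."""
    smoothed, avg = [], 0.0
    for t, x in enumerate(values):
        avg = beta * avg + (1 - beta) * x
        smoothed.append(avg / (1 - beta ** (t + 1)))  # undo the pull toward the 0.0 start
    return smoothed

noisy = [1.0, 0.6] * 50       # a loss that bounces between 1.0 and 0.6
smooth = ema_smooth(noisy)
print(f"raw: {noisy[-2]} {noisy[-1]}  smoothed: {smooth[-2]:.3f} {smooth[-1]:.3f}")
```

The smoothed series hovers near 0.8, the level the raw values bounce around, which is exactly what makes a flat trend easy to spot.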

Let’s look at the abbreviated ChatCORE evaluations we conducted during training using the ARC Easy and MMLU datasets.

Figure 27.10. Abbreviated ARC Easy evaluations.
Figure 27.11. Abbreviated MMLU evaluations.

This is the first evidence that we’re achieving something. ARC Easy goes from 63% accurate to 67% accurate and improves consistently over the course of training, or at least from the first time we conduct the evaluation at step 200. There are 1,024 ARC Easy conversations in our abbreviated evaluation, so we go from getting 645 right to getting 686 right, which is meaningful. Our MMLU subset, also of 1,024 conversations, goes from 363 correct to 416 correct.

I’d like to repeat the training and measure validation loss more often so I can get a clearer picture of what’s happening between step 0 and step 100. I also want to run the abbreviated ChatCORE evaluation at step 0 in case we’re getting fooled by these two nice charts. Then, based on what I learn, I’d like to run a series of experiments: What happens if we train for longer with more data? What happens if we adjust the learning rate multiplier schedule?

This gets at the concept of checkpoints, which I haven’t mentioned until now. Where does the actual model live when it’s not in GPU memory? Now that I’ve done a round of SFT training, can I get back to the version that came out of mid-training so I can conduct these experiments, or do I need to start all over from the beginning of base training? And what if I’m four days into a long training run and something goes wrong? Can I get back to a version from before things went wrong?

At the start of chapter 24 we calculated that our 32-layer model has around 1.9 billion parameters. Of those parameters, the roughly 134 million in the embedding module take up two bytes each and all the rest require four bytes. (I explain why some parameters use two bytes and others four bytes in chapter 30.)

Table 27.4. Size of a checkpoint for our 32-layer model.

A checkpoint contains the model parameters saved as a file at some point in the training process. For our model, this file will be 6.75 gigabytes. So far, following Karpathy, I’ve been saving a checkpoint at the end of each training phase. When I ran the CORE evaluation, the ChatCORE evaluations, and when I tried a few of my own prompts with the model, in each case I first loaded the model from the appropriate checkpoint. This also means I could run the experiments I mentioned above starting with the mid-training checkpoint. I wouldn’t need to repeat base or mid-training. Phew.
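Here is the size arithmetic behind that number, as a sketch. The architectural details are assumptions not stated in this chapter: a model width of 2048 (32 layers × 64), 12 × width² weights per transformer block (attention plus a 4× MLP), a separate, untied output projection the same size as the embedding, and no biases. Under those assumptions the parameter count comes to about 1.88 billion and the file size to exactly 6.75 GiB (units of 2^30 bytes), or about 7.25 billion bytes.

```python
VOCAB = 65_536
DEPTH = 32
WIDTH = DEPTH * 64                  # assumed model width: 2048

embedding = VOCAB * WIDTH           # token embedding, stored at 2 bytes per parameter
blocks = DEPTH * 12 * WIDTH ** 2    # assumed 4*d^2 attention + 8*d^2 MLP weights per block
lm_head = VOCAB * WIDTH             # assumed untied output projection

total_params = embedding + blocks + lm_head
size_bytes = 2 * embedding + 4 * (blocks + lm_head)   # everything else at 4 bytes

print(f"parameters: {total_params / 1e9:.2f} billion")
print(f"checkpoint: {size_bytes:,} bytes = {size_bytes / 2**30:.2f} GiB")
```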

I could also save checkpoints in the midst of training. For example, I could add code to save a checkpoint every 2000 steps during base training. You can be sure that a lot of checkpoints are being saved along the way when millions or hundreds of millions of dollars are being spent on a training run. But you also don’t want to save too often because it takes time to write all of the parameters to disk and you need space for all the files.

(If I really do want to save checkpoints in the middle of base training, I’ll also need to include the optimizer state in the checkpoint; otherwise it won’t be possible to resume training from that exact point. This will more than double the space required because I’ll need to store the moving averages described in chapters 20 and 21.)
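Here is a minimal sketch of what a resumable checkpoint bundles together, using Python’s `pickle` on toy dictionaries; the file name and dictionary keys are illustrative. With PyTorch the usual pattern has the same shape: `torch.save` a dict holding `model.state_dict()`, `optimizer.state_dict()`, and the step number, then restore each piece with `load_state_dict`.

```python
import os
import pickle
import tempfile

def save_checkpoint(path, step, model_state, optimizer_state):
    """Bundle everything needed to resume training into a single file."""
    with open(path, "wb") as f:
        pickle.dump({"step": step, "model": model_state, "optimizer": optimizer_state}, f)

def load_checkpoint(path):
    with open(path, "rb") as f:
        return pickle.load(f)

# Toy state: two parameters, plus Adam-style moving averages for each of them.
model_state = {"w": [0.5, -1.2]}
optimizer_state = {"exp_avg": {"w": [0.01, 0.02]},      # moving average of gradients
                   "exp_avg_sq": {"w": [1e-4, 4e-4]}}   # moving average of squared gradients

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "step_002000.pkl")
    save_checkpoint(path, 2000, model_state, optimizer_state)
    restored = load_checkpoint(path)

print(restored["step"], restored["model"], sorted(restored["optimizer"]))
```

In this toy example the optimizer state holds two numbers for every parameter, which is why including it more than doubles the size of the checkpoint.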

You might wonder why the saved model is called a checkpoint and not a saved model. The idea is that you’re saving along the way so you can get back to a particular state. The concept is common in databases and games. It originated in the 1960s as companies like IBM figured out how to make robust computer systems that could recover by getting back to a good state after failure.

Speaking of loading checkpoints and conducting evaluations, let’s look at the ChatCORE results. I’ll show per-task accuracy and the overall score for both our mid-training checkpoint and SFT training checkpoint.

Table 27.5. ChatCORE measured after SFT training.

ChatCORE went up by a tiny amount. MMLU, ARC Easy, and GSM8K all went up slightly. HumanEval and ARC Challenge went down slightly. I’m glad to see that the 40% MMLU accuracy on the full 14,042-conversation evaluation is very close to what we saw with the abbreviated 1,024-conversation evaluation during training. The same goes for ARC Easy. This means that going forward we can trust these smaller samples as a guide and move faster with experiments.

If I were training a model on my own, I would now compare these results with external benchmarks for similar-sized models using the same tasks, consider what evaluations are most important for my intended use, and conduct the experiments I listed above. However, since I’m not in fact trying to cover new ground but am instead following the exact path that Karpathy laid out, I’m going to continue to the next and last phase of training. Yes, it’s much less nerve-racking to follow in the footsteps of others. Is this where we separate the scientists from the engineers?