17


KV cache

While input and output tensor dimensions are top of mind from chapter 16, let me ask you a question: How much effort will it take to produce a sentence-long answer to the prompt “Please tell me about Paris”? You may want to look back at chapter 9 to remember how text generation works.

Prompt: Please tell me about Paris.

Completion: Paris is the capital of France and one of the world’s most influential cities—historically, culturally, and intellectually.

My question “how much effort” is a little vague. The best way to think about it is: how many low-level operations inside the computer will it take to generate the full sentence? More operations mean more time, more electricity, and more cost.

Look back at figure 5.7 to get a feel for the operations involved in multiplying one matrix by another. One multiplication of floating point numbers (see chapter C028) could be 0.1234567 × 0.9876543. Do this on a piece of paper. How long does it take you?

I just tried. I’m not sure I’ve ever multiplied out so many digits before.

Figure 17.1. Me multiplying two numbers.

Here’s how long it took me:

Figure 17.2. How long it took me to multiply the two numbers.

Call it 12 minutes. Keep that in mind. That’s for one floating point operation or FLOP.

Let’s take the first matrix multiplication we’ll encounter, the Q linear transformation in the first transformer block for my 20-layer model. If our prompt “Please tell me about Paris” is five tokens, the matrix multiplication that actually does the linear transformation will look like this:

Figure 17.3. The first matrix multiplication when computing the next token for the prompt “Please tell me about Paris.”

A rough estimate for the total floating point operations in a single matrix multiplication is 2 × rows of the first matrix × columns of the first matrix × columns of the second matrix. The factor of 2 is there because every term that goes into an output value takes one multiplication and one addition. (As a reminder, the columns of the first matrix will always be equal to the rows of the second matrix.) So in this example that’s 2 × 5 × 1280 × 1280 = 16,384,000 FLOPs.
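If you would rather check that arithmetic with code than by hand, here is a minimal sketch in Python. The function name `matmul_flops` and its parameter names are my own for illustration and don't come from any library.

```python
def matmul_flops(rows_a: int, cols_a: int, cols_b: int) -> int:
    """Rough FLOP count for a (rows_a x cols_a) matrix times a (cols_a x cols_b) matrix."""
    # Each of the rows_a * cols_b output values needs cols_a multiplications
    # and roughly cols_a additions, which is where the factor of 2 comes from.
    return 2 * rows_a * cols_a * cols_b


# Five token positions going through the 1280 x 1280 Q linear transformation.
q_flops = matmul_flops(rows_a=5, cols_a=1280, cols_b=1280)
print(f"{q_flops:,}")  # 16,384,000
```

Change `rows_a` to 6 and you get the cost of the same transformation once the first generated token is added to the input.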

We start by feeding “Please tell me about Paris” into the model in order to predict the next token. In order to estimate the total FLOPs required for this calculation but without going totally crazy, let’s only count the matrix multiplications. For the sake of the point I’m making here we won’t worry about norms, softmax, addition, etc.

To keep organized I made a spreadsheet and matched colors with my diagrams.

Table 17.1. Estimated floating point operations for all of the matrix multiplications to compute the next token for “Please tell me about Paris” in my 20-layer model.

So that’s around five billion FLOPs. But that’s just for predicting our sixth token, say “Paris” in a completion that starts: “Paris is the capital of…” Now we need to feed in six tokens: “Please tell me about Paris. Paris” to get back our seventh token. Using the same spreadsheet, I estimated the number of FLOPs to compute the seventh token at around 5.7 billion, so we’re already up to 11 billion to generate the two tokens. Let’s say the full completion is 20 tokens. This means we’ll keep repeating the process of predicting the next token and feeding the full sequence in as input until we input a sequence of 24 tokens to predict our 25th.

Table 17.2. Total FLOPs for the matrix multiplies to generate a 20-token completion to “Please tell me about Paris.”

That’s 277 billion floating point operations to tell me that “Paris is the capital of France and one of the world’s most influential cities—historically, culturally, and intellectually.” Not long ago that would have sounded insane. Of course, just because something sounds insane isn’t a reason to avoid it. Thinking of Paris, Marie and Pierre Curie started with a literal ton of pitchblende in the late 1890s and spent years making painstaking extractions to discover and then isolate a speck of radium. Closer to home, the history of computing is filled with solving problems in less efficient but more generalizable ways, knowing that advances in hardware (e.g. Moore’s Law) will eventually make the inefficiency moot.

But to frame the insanity, if I were to calculate those 277 billion operations manually on paper at 12 minutes per FLOP, and pass the torch on before death, and so on, it would take over six million years to produce those twenty tokens. Humans have only known how to do multiplication for something like four thousand years.

How about an Apple IIe from 1983? It belongs to the Apple II line, Apple’s original hit product and one of the world’s first widely available personal computers. I dug mine up and amazingly it still runs. It supported floating point multiplication, though in software, meaning that the multiplication of two floating point numbers would be translated into lower-level instructions for the chip to execute rather than the chip having the intrinsic capability. (See the further reading section for a link to the floating point assembly code written by Apple co-founder Steve Wozniak himself.)

I wrote a loop in BASIC that repeated the same multiplication I did on paper 1000 times in a row. I timed it. It took 58 seconds. I also ran the same loop without the multiplication and that took 3 seconds. So figure 55 seconds to do 1000 operations. That’s 18 floating point operations per second. At that speed it would take around 500 years to generate the twenty token completion to “Please tell me about Paris.”

Figure 17.4. A BASIC loop that multiplies 0.1234567 and 0.9876543 1000 times on an Apple IIe.

I have a MacBook Air with an Apple M1 chip from 2020. The M1 chip, like most modern chips, has built-in support for floating point operations. I wrote a loop in the C programming language and tried my best to make sure the chip was doing nothing other than the floating point multiplication of two numbers 10 billion times. This took around five seconds, so that’s two billion operations per second. At that speed the full calculation of the twenty token completion would require between two and three minutes. (I share this only to illustrate how big a deal 277 billion FLOPs is or isn’t for a modern computer. When I actually use the model on my Mac to generate tokens, it’s faster because even a Mac from 2020 has a GPU and a way to use parallelism to efficiently multiply matrices. Read about MPS if interested.)

And for fun I also timed my white plastic MacBook from 2010 with an Intel Core 2 Duo chip, which also has dedicated support for floating point operations.

Table 17.3. Time to complete 277 billion floating point operations.
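The conversions behind these times are simple enough to sketch in Python. The rates below are the ones I measured above (12 minutes per multiplication by hand, 55 seconds per 1,000 on the Apple IIe, about 5 seconds per 10 billion on the M1). The names `seconds_per_flop` and `seconds_to_finish` are just for illustration.

```python
TOTAL_FLOPS = 277e9  # matrix-multiply FLOPs for the full 20-token completion
SECONDS_PER_YEAR = 365 * 24 * 60 * 60

# Seconds per floating point operation, from the timings in the text.
seconds_per_flop = {
    "by hand": 12 * 60,                      # 12 minutes per multiplication
    "Apple IIe": 55 / 1000,                  # 55 seconds for 1,000 multiplications
    "MacBook Air (M1)": 5 / 10_000_000_000,  # about 5 seconds for 10 billion
}


def seconds_to_finish(total_flops: float, secs_per_flop: float) -> float:
    """Time to do total_flops operations one after another at a fixed rate."""
    return total_flops * secs_per_flop


for name, rate in seconds_per_flop.items():
    seconds = seconds_to_finish(TOTAL_FLOPS, rate)
    print(f"{name}: {seconds:,.1f} seconds ({seconds / SECONDS_PER_YEAR:,.6f} years)")
```

This works out to roughly 6.3 million years by hand, about 480 years on the Apple IIe, and a bit over two minutes on the M1.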

I share all of this to give you a feel for what it takes to perform on the order of 277 billion operations to generate around 20 words. It’s not an insurmountable number of operations by modern standards by any stretch of the imagination, but it’s not so trivial that we shouldn’t worry about it.

And something else that might be bothering you is—doesn’t it seem wasteful to put all those tokens in and predict a next token for each when all we care about is the prediction for the final token? It bothers me which is why I also mentioned it at the end of chapter 9. And yet, we do need information from all the tokens. Trace the operations in the scaled dot product attention module (figure 14.6) to remind yourself that we use queries, keys, and values to pull in information from earlier positions.

Let’s trace what happens when we input the six tokens to the model (the original prompt plus the first generated token). Every box in our diagram operates independently on each of the six tokens / positions until we get to the first matrix multiplication in scaled dot product attention. And after the second matrix multiplication, we go back to operating on each position independently.

We also know that we only allow information to get blended forward. The sixth position will end up with information from positions five, four, etc. but position five will never pull in information from position six. This holds true in every transformer block. So at the end of our 20 layers, the size 1280 vector for position five is identical to what it was in our original calculation before we added that sixth token to our input.

Let me repeat that because it’s so important and, at least for me, when I think about so many operations and so many numbers all over the place, it’s not obvious. We input a sequence of five tokens to generate a sixth. Then we input these original five tokens plus the new one. In this second time through the model, the size 1280 vectors corresponding to the fifth position coming out of every transformer block are identical to what they were for the fifth position the first time through the model.
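If you’d like to convince yourself with code, here is a toy sketch in plain Python. It is not my model: it is a single attention head with made-up random vectors of size 4 standing in for the queries, keys, and values, and the function names are my own. It runs causal attention over five positions, then over six, and compares the results for the first five.

```python
import math
import random


def causal_attention(q, k, v):
    """Single-head causal attention over lists of vectors.

    Position i only scores against keys 0..i, so it never sees later positions.
    """
    dim = len(q[0])
    outputs = []
    for i in range(len(q)):
        # Scaled dot product scores against the current and earlier positions.
        scores = [
            sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(dim)
            for j in range(i + 1)
        ]
        # Softmax turns the scores into attention weights.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Blend the values of the visible positions using those weights.
        outputs.append(
            [sum(w * v[j][d] for j, w in enumerate(weights)) for d in range(dim)]
        )
    return outputs


random.seed(0)
DIM = 4  # toy size; the heads in the 20-layer model have size 128


def random_vectors(count):
    return [[random.uniform(-1, 1) for _ in range(DIM)] for _ in range(count)]


q6, k6, v6 = random_vectors(6), random_vectors(6), random_vectors(6)

out_five = causal_attention(q6[:5], k6[:5], v6[:5])  # five tokens in
out_six = causal_attention(q6, k6, v6)               # same five plus a sixth

# The fifth position (index 4) is untouched by the arrival of the sixth.
print(out_five[4] == out_six[4])
```

Everything outside attention works on one position at a time, so if the attention output for a position doesn’t change, nothing downstream of it changes either.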

This observation hints at a much more efficient way to calculate. We only need to feed in the last generated token as long as the keys and values computed previously, and the new ones for this new token, can be made available to each transformer block. Let me show this in a diagram:

Figure 17.5. If we save the keys and values in each transformer block we don’t need to compute them again.

We save the old keys and values from before, and we generate the new ones by running the input through the K and V linear transformations. For queries the situation is even better because we only need to worry about the query for the new token. In other words, we need to know what this new position wants to pay attention to, but there’s no need to even worry about what the earlier positions paid attention to.

Figure 17.6. A single query for our new position can be scored against the keys for all positions to determine attention weights.

Here are the full causal self attention and scaled dot product attention diagrams with the key/value cache. “Cached” means saved previously. “Append” means to add on to the end.

Figure 17.7. Causal self attention with a KV cache.

The cache is shown in orange. The asterisk indicates a dimension swap as explained under figure 16.5.

Coming into causal self attention we have just a single position corresponding to the new token we’ve added to the sequence. This is why the overall input has shape 1×1280. Just before scaled dot product attention we reach into the cache and get the keys and values we computed earlier and the shape becomes 6×10×128. We also add the new key and value to the cache so they will be there for next time.

Here’s what happens inside scaled dot product attention:

Figure 17.8. The q input is only the single new position. The k and v inputs are the same as usual.

Wow, look at all that saved work. Yes, we have to keep a lot of memory around (if we can), but it’s no more than we would need to do the calculations from scratch, and we now only calculate outputs that we didn’t calculate before.

If you want an intuition for why this works without thinking about matrix multiplication, consider the description in chapter 14 of matching queries with keys, computing scores, and mixing values. We want to match our new query against all keys (saved and the new one) to get a score for each that we can use to blend in all of the values (saved and the new one). Nothing about this requires the queries from prior positions.
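That intuition translates almost line for line into code. Below is a minimal sketch in plain Python, assuming a single attention head and tiny made-up 2-dimensional vectors. The `KVCache` class and the `attend` function are hypothetical names for illustration, not the code from my model.

```python
import math


def attend(query, keys, values):
    """Score one query against all keys and return the weighted blend of values."""
    dim = len(query)
    scores = [sum(a * b for a, b in zip(query, key)) / math.sqrt(dim) for key in keys]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [
        sum((e / total) * value[d] for e, value in zip(exps, values))
        for d in range(dim)
    ]


class KVCache:
    """Hypothetical minimal cache for one attention head in one transformer block."""

    def __init__(self):
        self.keys = []
        self.values = []

    def append(self, key, value):
        self.keys.append(key)
        self.values.append(value)


# Made-up queries, keys, and values for a three-token sequence.
qs = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
ks = [[1.0, 0.2], [0.3, 0.9], [0.7, 0.7]]
vs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

cache = KVCache()
cached_outputs = []
for q, k, v in zip(qs, ks, vs):
    cache.append(k, v)  # save the new key and value so they are there next time
    # Only the newest query is needed; older queries are never looked at again.
    cached_outputs.append(attend(q, cache.keys, cache.values))

# Recomputing from scratch: position i attends to positions 0..i.
full_outputs = [attend(qs[i], ks[: i + 1], vs[: i + 1]) for i in range(len(qs))]

print(cached_outputs == full_outputs)
```

In the real model there would be one such cache per head per transformer block, and the new key and value would come from running the new token through the K and V linear transformations.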

Here’s my spreadsheet estimating floating point operations for all matrix multiplications for this 6th input token using the KV cache approach:

Table 17.4. Estimated floating point operations for all of the matrix multiplications to compute the seventh token.

And here’s the total for generating the entire completion:

Table 17.5. Total FLOPs for the matrix multiplies to generate a 20-token completion to “Please tell me about Paris” with the KV cache.

We drop from 277 billion to 23 billion. This is a whole order of magnitude fewer operations. A human could finish in only half a million years.
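One way to sanity check that ratio is to count how many token positions go through the model in each approach. This sketch assumes the cost per position is roughly constant, which is close enough for a sequence this short. The variable names are mine.

```python
PROMPT_TOKENS = 5
COMPLETION_TOKENS = 20

# Without a cache, every step feeds in the whole sequence so far: 5, 6, ..., 24 tokens.
positions_without_cache = sum(range(PROMPT_TOKENS, PROMPT_TOKENS + COMPLETION_TOKENS))

# With a cache, the prompt goes in once, then one new token for each remaining step.
positions_with_cache = PROMPT_TOKENS + (COMPLETION_TOKENS - 1)

print(positions_without_cache, positions_with_cache)  # 290 24
print(round(positions_without_cache / positions_with_cache, 1))  # 12.1
```

290 positions versus 24 is about a factor of 12, which lines up with 277 billion versus 23 billion.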

You’ll notice that the number of operations for each additional token is about the same. It is in fact increasing, but at a slow and constant rate. You can see why by looking at table 17.4. The 6 becomes a 7 becomes an 8 and so on. My intuition for this is that as the sequence gets longer, each new position needs to be compared against more prior positions.

You might be tempted to think that this part of the calculation barely matters. For six tokens, it’s only around 30,000 FLOPs for the 10 heads. For all 20 transformer blocks that’s 600,000 FLOPs. That’s nothing compared to the nearly one billion total FLOPs to compute the next token. Don’t get tricked by the fact that six or 24 is so much smaller than 1280. Don’t get fooled either by the fact that a generated sequence of a couple dozen tokens is child’s play compared to the size 2048 sequence I used during training. When you start chatting with ChatGPT, the entire back and forth goes into that sequence. As do any documents the model needs to read along the way, information it pulls in from prior chats, the system prompt with core instructions, instructions and output from tools, and hidden (from you) thinking. It adds up fast.

Let’s say the sequence grows to 20,000 tokens which is completely reasonable with a modern model. (Although not useful with our 20-layer example model, but that’s okay.) Here’s the updated spreadsheet:

Table 17.6. Estimated floating point operations to compute the next token starting with a 20,000-token sequence.

We’ve jumped from around one billion to three billion FLOPs. Sequence length is in fact a huge factor in memory and number of calculations. This perhaps gives an appreciation for why people were shocked when Google came out with a model with a million token context window in early 2024.
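Here is a small Python sketch of the part of the calculation that grows with sequence length: the two matrix multiplications inside scaled dot product attention when only one new query comes in. It uses the 10 heads of size 128 and the 20 transformer blocks of my model. The function name is mine, and `seq_len` counts every position whose key and value are available, including the new one.

```python
HEADS = 10
HEAD_DIM = 128
BLOCKS = 20


def attention_flops_for_new_token(seq_len: int) -> int:
    """FLOPs for the two attention matrix multiplications across all heads and blocks."""
    # One query (1 x 128) scored against seq_len keys (128 x seq_len).
    scores = 2 * 1 * HEAD_DIM * seq_len
    # Attention weights (1 x seq_len) blended with the values (seq_len x 128).
    blend = 2 * 1 * seq_len * HEAD_DIM
    return (scores + blend) * HEADS * BLOCKS


for seq_len in (6, 24, 20_000):
    print(f"{seq_len:>6} tokens: {attention_flops_for_new_token(seq_len):,} FLOPs")
```

At six tokens this comes to 614,400 FLOPs. At 20,000 tokens it is 2,048,000,000, which accounts for the jump of roughly two billion over the cost of everything else.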

Theory is great. Time to see the KV cache in action. I used my 20-layer model with the prompt we’ve been using this whole chapter (“Please tell me about Paris”) and forced it to generate 10,000 tokens. Without the KV cache this took 4 minutes and 29 seconds. With the cache it took 1 minute and 22 seconds.

Without the cache, the GPU was working at near 100% the whole time (doing a lot of useless work!). More and more memory was allocated and then freed. You can see that in this chart:

Figure 17.9. GPU utilization and memory use during generation of 10,000 tokens without a KV cache.

With the cache, the GPU did not need to allocate as much memory or work as hard, which would free it up to do other more useful work, say generating text for other users at the same time.

Figure 17.10. GPU utilization and memory use during generation of 10,000 tokens with a KV cache.

Enough inside baseball. Let’s assume we can get the model to generate text. How will we know if it’s any good?