14
Attention
I’m going to read you a paragraph. I want you to think about the scene and understanding forming in your mind and how these change after each word.
“In…” That’s the first word. What is your mind thinking? Normally you already have context—is this in a novel, a friend talking with you, on a test, the start of a video? But here you have none, other than knowing it’s going to be something appropriate for this book. Okay, well, you’re going to read about something in something. Perhaps a place like “In France,” or a field like “In computer science,” or “In baseball,” or the expression “In fact, he did know the murderer,” or a year “In 2025,” or the start of a grand story or even the Bible, “In the beginning…” Anyway, your mind is primed for more and likely preparing for location, subject matter, or temporal context.
“In 1873…” Those are the first two words. Now you’re prepared for information or a story about something that happened a long time ago, but not in ancient or prehistoric times. You might already have 19th-century images forming in your mind—maybe someone in a top hat on a street with horse-drawn carriages, like in a period drama. 1873 isn’t a year that has significance to me, so my mind isn’t jumping to anything in particular as it would with 1776.
“In 1873, Tesla…” In isolation, Tesla is going to instantly make you think cars. And maybe even here you think of cars for the tiniest moment, but you correct yourself because you know the car company wasn’t around in 1873. You may know that there was a famous engineer named Nikola Tesla, and if so, whatever you know about him—inventor, electricity, motors—will be in your mind. Even if you didn't know about Nikola Tesla, the capital T will make you think it’s about a person, and you’ll in fact already have new information—ah, so that’s where the car company got its name!
“In 1873, Tesla returned to his birthtown, Smiljan.” Now you have a lot of information. Smiljan is a town. It’s where Tesla was born, almost certainly years before 1873 because it doesn’t say “His parents brought him back to his birthtown.” He obviously left Smiljan and you may be primed to learn why or to where. Also the way “Smiljan” is spelled and sounds is a clue to its location. As a mere mortal you may soon (like in seconds) forget the specific year 1873 and specific town, but you know this sentence contains that information. Okay, let’s keep going.
“In 1873, Tesla returned to his birthtown, Smiljan. Shortly after he arrived…” You know with high certainty that “he” refers to Tesla and you’re anticipating learning what happened next. You’ve also formed an estimate of “shortly”—days or weeks or possibly months, but not years or centuries or millennia (this isn’t geology) and not nanoseconds (this isn’t particle physics).
“In 1873, Tesla returned to his birthtown, Smiljan. Shortly after he arrived, Tesla contracted…” In isolation, “contracted” is likely to relate to a legal agreement, “He contracted with a vendor…,” but in the scene forming in our mind we probably don’t even think of that. Instead we think that he got sick and expect the next word to be the disease or a modifier of the disease, “a severe…” The disease is going to be something that was around in the 19th century, not Covid. Even if you’ve never heard of Tesla, your mind could already be linking up a potential chain: there was a person, and getting sick somehow resulted in him doing something worthy of a car company naming itself after him.
“In 1873, Tesla returned to his birthtown, Smiljan. Shortly after he arrived, Tesla contracted cholera; he was bedridden for nine months and was near death multiple times. Tesla's father, in a moment of despair, promised to send him to the best engineering school if he recovered from the illness (his father had originally wanted him to enter the priesthood). What was the year when Tesla went back to Smiljan?”
That’s the whole paragraph. It turns out to be a reading comprehension question that asks for the year, though it could also have asked for the name of the town, the disease, what Tesla was originally supposed to become, or for how long he was bedridden. To predict the next word (i.e. answer the question) you either remember that it’s 1873 or remember that this information is in the beginning of the paragraph and look back.
This paragraph is one of over ten thousand similar paragraphs in one of many datasets that we’ll be using later to evaluate whether the model is becoming competent, as we’ll cover in chapter 18. For now I’m using it to motivate the concept of attention.
So how freaking amazing are our minds? With every word we refine our understanding. We see the big picture and the details. These are informed by what we know about the world as a whole, and we know how to go back and retrieve specific pieces of information when they seem worth paying attention to.
How is a computer, a deterministic machine, possibly going to do any of that? Well, one way it’s definitely not is through hand-coded rules. As we talked about for translation in chapter 6, hand-created rules are only useful in restricted domains like reading weather reports or mechanical descriptions of changes in stock price.
Instead we need a computer model that comes closer to our mental model. At a localized level it needs to know that “Tesla” is a person not a company and that “he” refers to Tesla. On that final question mark, the model needs to see the bigger picture that a question is being asked, that the question is asking for a year that something happened, and be able to know or find the answer.
As we get into how this works, I don’t want to pretend that this specific approach is obvious, inevitable, or optimal. I also don’t want to pretend that I have a perfect intuition for it. But real-world results speak for themselves. This approach has brought us machines that think and that can understand and generate text, images, video, and sound. That’s not to say that there won’t be better approaches in the future or that we won’t refine our understanding of why it works.
As we talked about for embeddings in chapter 11, we need to come up with an architecture that plausibly fits the problem, try training it, and see if it works. In the hedgehog example we saw how too few layers and parameters could not match the odd relationship between length and number of quills, but as soon as we gave the model enough room to learn, it figured it out.
Here, the one thing we know for sure is we must mix information among positions. If the size-1280 output vector of the final transformer block for the final token “?” (as in “Smiljan?”) is to be projected on the full vocabulary and have “1873” come out on top, this vector must contain information from positions earlier in the sequence. (I’m simplifying a little bit because we want the final prediction to be “ 18” and then we want “ 18” to predict “73.”)
You likely have intuition that not all information is equal, and different information is important at different positions. When we get to “Smiljan,” “birthtown” is relevant. When we get to the final question mark, the first sentence is more relevant than the second, and the year 1873 is the most relevant token (or tokens) in the first sentence.
To get into the next level of detail we need to take a little detour into the terminology used for databases and information retrieval. Here’s an example of a database. (If you’re used to thinking of databases as having tables and columns, imagine this as a single table with two columns.)
A query of 1770 will match the row with key 1770 and return the value Beethoven. A query of 1810 will match two rows and return the values Schumann and Chopin.
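If it helps to see that lookup as code, here is a minimal sketch with just the rows mentioned above hard-coded as (key, value) pairs (the real table has more rows):

```python
# A minimal sketch of the exact-match lookup described above.
# Only the rows mentioned in the text are included; the real table has more.
rows = [
    (1770, "Beethoven"),
    (1810, "Schumann"),
    (1810, "Chopin"),
]

def lookup(query):
    """Return the value of every row whose key exactly matches the query."""
    return [value for key, value in rows if key == query]

print(lookup(1770))  # ['Beethoven']
print(lookup(1810))  # ['Schumann', 'Chopin']
```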
Keys and queries don’t have to be strict matches. Another type of database, especially popular in recent years, is a vector database:
The key is a long vector of numbers and the value is a document. The idea is that the key represents the document in much the same way an embedding represents a token as described in chapter 11. Let’s say our query is “articles about composers” which we’ll also turn into a vector. Instead of the database doing an exact match on key, it will score the similarity between the query and each key, using, for example, the same cosine similarity calculation I used to compare embeddings in chapter 11 (see table 11.3). Example:
We could tell the database to return all matches above a certain score or to return the top few matches. In this case, if we asked for the top two matches, we would correctly retrieve the documents on Mozart and Bieber but not Washington.
If you’ve used databases before but not with vector similarity, notice that the technique is not equivalent to a substring search. “Composer” does not appear in either the Bieber document or the Washington document but the Bieber document is correctly scored higher.
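If you’d like to see the scoring as code, here is a minimal sketch. The key and query vectors are made up for illustration (a real system would produce them with an embedding model), but the scoring and top-two retrieval work just like the example above:

```python
import numpy as np

# A sketch of vector-database scoring. The vectors are made up; a real system
# would produce them with an embedding model.
keys = {
    "Mozart document":     np.array([0.8, 0.1, 0.6]),
    "Bieber document":     np.array([0.6, 0.4, 0.3]),
    "Washington document": np.array([-0.2, 0.9, 0.1]),
}
query = np.array([0.7, 0.0, 0.7])  # stands in for "articles about composers"

def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

scores = {name: cosine_similarity(query, key) for name, key in keys.items()}
top_two = sorted(scores, key=scores.get, reverse=True)[:2]
print(top_two)  # the Mozart and Bieber documents score highest
```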
Now with that crash course in database concepts behind us we can get back to the transformer. What if, while calculating our output for a specific position in the sequence, we could issue a query for the type of information that would be helpful from among all the earlier positions in the sequence? And what if each of those earlier positions could advertise a key with the type of information it has? Then we could go and somehow get the values of the positions whose keys match our query and merge that information into the current position. The result of this is what we want: the output for this position, that vector of size 1280 that we want to represent the meaning of the sequence up to this point in high-dimensional space, will have mixed in information from earlier positions.
Let me give an example using the Tesla reading comprehension question. We want that final question mark position to issue a query—hey, here’s the type of information I’m looking for. We want the “1873” position from earlier in the paragraph to advertise a key saying—hi all, here’s the type of information I have. And we want there to be a strong match between the query and key which will result in information from the “1873” position being brought into the “?” position.
This meet-in-the-middle approach of “I generate a query and you generate a key” is common and essential to making information retrieval possible. It predates computers. For example, the Dewey Decimal System was introduced to organize libraries starting in the late 1800s and was a way to match queries with books. Of course Melvil Dewey had to design his system from scratch and hand-categorize books, and the musicians at Pandora in 1999 had to manually categorize songs. But with our deep learning way of thinking, we simply define the architecture and wish the system into existence. In this case the architecture is one where positions emit queries, keys, and values so that appropriate information can be retrieved and incorporated.
So how are we possibly going to get the model calculations to do something like emitting queries, keys, and values, then matching up queries and keys, and then pulling in values? Suppose our sequence so far is “The capital of France is” and, to keep things a little easier to imagine, let’s say we’re in the first transformer block and our embeddings only have two dimensions (which would never actually work). So here’s the input:
For each position (row), we want to come up with a key advertising what it has to offer, a value, and a query saying what it’s interested in from earlier positions. It’s a little tricky that unlike a person querying for a book in the library, every single position both has information to offer other positions and is itself querying for information.
Let’s turn to our old friend the linear transformation. Our hope is that under the pressure of backpropagation we’ll have three linear transformations that learn to produce useful keys, values, and queries. We stick with the same dimensions (2 in this tiny example) for each of our keys, values, and queries, so let’s say our three transformations applied to the input in table 14.4 output this:
I made the query numbers for the first four positions gray to emphasize that right now, in this example, we’re only thinking about our final position, position 5.
I should mention that I normalized the queries and keys after the linear transform so that they all have the same length. You’ll see this step later when we get to figure 14.7, and I’ll explain the exact calculation in chapter 22. This is why the vectors in the upcoming plots all have the same length, and it’s why we’ll be able to compute the query/key matching score with simple pairwise multiplication and addition. These details don’t matter for building your intuition.
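If you’d like to see this step as code, here is a minimal sketch. The input and the three weight matrices are made up (they are not the numbers behind tables 14.4 and 14.5, and here I simply scale each query and key to length 1 rather than using the exact normalization I’ll cover in chapter 22), but the operations are the ones I described:

```python
import numpy as np

# A sketch of producing queries, keys, and values with three linear
# transformations. The 5x2 input and the 2x2 weight matrices below are made
# up; in the real model they are learned during training.
x = np.array([[ 0.2, -0.4],   # "The"
              [ 0.9,  0.1],   # "capital"
              [-0.3,  0.7],   # "of"
              [ 0.5,  0.5],   # "France"
              [ 0.1, -0.8]])  # "is"

W_q = np.array([[ 0.6, -0.2], [ 0.3,  0.9]])
W_k = np.array([[-0.5,  0.4], [ 0.8,  0.1]])
W_v = np.array([[ 1.2,  0.0], [-0.7,  0.5]])

q = x @ W_q   # queries, one per position
k = x @ W_k   # keys, one per position
v = x @ W_v   # values, one per position

# Normalize queries and keys so they all have the same length (length 1 here;
# the exact normalization used in the model is covered in chapter 22).
q = q / np.linalg.norm(q, axis=1, keepdims=True)
k = k / np.linalg.norm(k, axis=1, keepdims=True)
```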
Getting back on track, we want to compare the query for that last position (-0.51, 1.32) to each key. This is similar to the vector database lookup in table 14.3. Since we’re working with two dimensions we can plot the vectors:
The black vectors are our keys. The red vector is the query for the last position. The query is close to “capital” indicating that the “is” position is very interested in the “capital” position and will want to pull information in from its value.
Since all the vectors have the same length, we can compute a score saying how similar the query is to each of the other vectors by doing pairwise multiplication and adding. If you play with this calculation you’ll see that the bigger the number, the smaller the angle between the query and key.
(If you try to repeat my calculation you’ll see that after summing the pairwise multiplications I then multiplied by a scale factor of around 0.7. I’ll explain this below.)
As you can see in the scores and confirm in the plot, the red query matches most closely with the key for position 2 “capital” followed by “of” followed by “France” followed by “is” itself. “The” is in last place. Unlike the database example where we were only interested in the top few results or the results above a certain threshold, here we’re going to use all of the values but mix them in according to the score. To do this we’ll want to resize our scores so they add up to 1 (or 100%). We’ll call this the attention weight because it’s how much attention we plan to give the values from each position.
Notice how “capital” had the highest score and now has the highest weight of 0.60. I did the conversion using the softmax function. The details are not important, but if you’re curious, take a look back at table 8.10 where I showed how to calculate it.
Now we’re on the home stretch. We want our output vector for this position 5 to consist of 60% of the value from position 2, 26% of the value from position 3, and so on. Multiply the weights by the values, add them together, and this mix of the values is more or less our output.
Coming out of this, position 5 contains the vector [-35.60, 60.21]. (Again, in reality the vector would have way more numbers.) This vector represents information about the entire sequence that is relevant at this position to predicting the next token: “Paris” we hope.
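Here is the whole single-position calculation as a minimal code sketch. The numbers are made up (so the result won’t match the tables), but the steps are exactly the ones we just walked through: score the query against every key, scale, softmax, then blend the values:

```python
import numpy as np

# A sketch of attention for a single position, with made-up numbers.
rng = np.random.default_rng(0)
K = rng.normal(size=(5, 2))   # keys for positions 1-5
V = rng.normal(size=(5, 2))   # values for positions 1-5
q = rng.normal(size=2)        # query for position 5

# Give the query and keys the same length, as discussed above.
K = K / np.linalg.norm(K, axis=1, keepdims=True)
q = q / np.linalg.norm(q)

scores = K @ q / np.sqrt(2)                       # dot product with every key, then scale
weights = np.exp(scores) / np.exp(scores).sum()   # softmax turns scores into weights
output = weights @ V                              # blend the values by their weights
print(weights.round(2), output.round(2))
```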
Yes, it’s complicated. It will take time to wrap your head around. But the good news for using it in a deep learning model is everything I showed can be done with a few matrix operations. In fact we’ll be able to calculate the outputs for all of the positions at the same time.
Now let’s look at it in diagram form. I’m going to work from the inside out. Assume we’re inside the blue causal self attention module shown in diagram 13.2, we’ve done the linear transformations and normalizations that I mentioned and will show on a diagram later, and we’re now ready with our query, key, and value tensors. The query, key, and value tensors, similar to what’s shown in table 14.5, are the inputs to scaled dot product attention:
The operations on matrices shown in this diagram will do everything I walked through above and more. As a reminder, I’m using “@” to indicate matrix multiplication.
Let’s drill into a few of the operations and then I’ll show the same query, key, and value I used above flowing through all of the operations. We’ll start with the mask. If you accidentally left out the mask the trained model would be useless. To explain, I want to zoom way out of the internals of the model and back to the big picture discussed in chapter 8.
During training, we feed in huge amounts of text and calculate loss based on how well the first position predicts token 3, the second position predicts token 11, and so on. The one thing we must absolutely, positively not do when the model is calculating the next token for a certain position is let it use information from tokens after that position.
For example, the next token prediction for position 3 can and should use information from positions 1, 2, and its own position. But it must not use information from positions 4 and 5. Otherwise all the model will learn is how to cheat by copying the next token.
The mask guarantees no cheating. It makes sure that for a given position, the attention weights will be zero for all later positions. Table 14.8 showed the attention weights for position 5, meaning how much attention position 5 wants to give positions 1–5. Here now is a table with the weights for all five positions before and after the mask:
Notice that in the masked version, for position 1, 100% of the weight falls on position 1. Think about that for a second. It’s the only option because position 1 is not allowed to see positions 2–5. For position 2, 84% of the weight falls on position 1, 16% on position 2, and none on 3–5. When the weights for position 2 get multiplied by the values for all positions, no information from position 3, 4, or 5 will get blended in.
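In code, a common way to get this effect (and the one I’ll sketch here) is to set the score a position gives to any later position to a huge negative number before the softmax, which drives that attention weight to zero:

```python
import numpy as np

# A sketch of causal masking with made-up scores.
T = 5
scores = np.random.randn(T, T)                     # stand-in for query/key scores
mask = np.triu(np.ones((T, T), dtype=bool), k=1)   # True above the diagonal
scores = np.where(mask, -1e9, scores)              # block attention to later positions

# Softmax each row to get the attention weights.
weights = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)
print(np.round(weights, 2))
# Row 1 puts 100% of its weight on position 1, and every row has
# zeros to the right of its own position.
```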
Now backing up to the start of figure 14.2, the first step is to matrix multiply the query by the transpose of the keys. The transpose of a matrix is the same matrix with the rows and columns switched. For example:
Our goal is to compute a score between every query and every key, and as a reminder, we compute the score by doing element-wise multiplication and then adding. This is called the dot product. Since we have 5 queries and 5 keys in our example, we want to end up with 25 scores. If you work through it, you’ll see that matrix multiplying the query by the transpose of the keys results in a matrix where each element is the dot product of one row in the query and one row in the keys. However, there’s no need to work through it; just know you’ll end up with a matrix like the one shown as the output of the matrix multiply (@) in figure 14.6 below. That number 0.49 at the top left is the score between the first query and the first key.
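If you want to see it in code, here is a sketch with made-up 5×2 query and key matrices; only the shapes matter:

```python
import numpy as np

# A sketch of scoring every query against every key with one matrix multiply.
Q = np.random.randn(5, 2)   # 5 queries, each of size 2
K = np.random.randn(5, 2)   # 5 keys, each of size 2

scores = Q @ K.T            # 5x5: element [i, j] is the dot product of query i and key j
print(scores.shape)         # (5, 5), i.e. 25 scores (the scale step comes next)
```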
One last detail is the box I labeled “scale.” This means divide the scores by the square root of the size of the vectors, 2 in this example but in reality it will be a much bigger number. We do this because with lots of numbers in our vectors, the scores (dot products) will get large and softmax doesn’t play well with large numbers. I’ll show this with an example:
The big scores are 100 times the non-big scores, and with these big numbers softmax assigns all of the weight to position 2. We do not want that! We’re trying to blend information from multiple positions. Look back at table 8.10 as a reminder for how the softmax calculation works. If you play with a few numbers you’ll see why it behaves poorly with big ones.
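You can see the problem in a couple of lines of code (made-up scores, same idea):

```python
import numpy as np

def softmax(x):
    x = x - x.max()   # subtract the max for numerical stability
    return np.exp(x) / np.exp(x).sum()

small = np.array([1.0, 2.0, 1.5])
big = small * 100     # same ordering, 100 times larger

print(np.round(softmax(small), 2))  # [0.19 0.51 0.31]: a blend
print(np.round(softmax(big), 2))    # [0. 1. 0.]: all weight on position 2
```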
Now you know everything you need to know to follow every calculation in scaled dot product attention. Here’s the same example we worked through in tables above but now using matrices and calculating the output for all positions, not just the last position.
Of course this is a crazy diagram and the point is not to look at the individual numbers. Notice that the output for the last position (-35.60, 60.21) matches the calculation from table 14.8. Note also that the output for the first position (13.89, -43.42) matches the first position in the values since the first position isn’t allowed to mix in values from any other position.
All the dimensions work out. We start with identically sized query, key, and value matrices, all 5×2. Computing the scores / weights takes us to 5×5, and at the end, multiplying by the values takes us back to 5×2. It’s elegant!
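To tie it together, here is a minimal sketch of scaled dot product attention as a single function. It follows the boxes in the diagram (score, scale, mask, softmax, blend the values); the inputs are made up, and the masking is done by setting scores to a huge negative number as in the sketch above:

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """A sketch of the diagram's steps: score, scale, mask, softmax, blend.
    Q, K, and V each have shape (T, D)."""
    T, D = Q.shape
    scores = Q @ K.T / np.sqrt(D)                         # (T, T) scaled scores
    mask = np.triu(np.ones((T, T), dtype=bool), k=1)      # True above the diagonal
    scores = np.where(mask, -1e9, scores)                 # no attention to later positions
    scores = scores - scores.max(axis=1, keepdims=True)   # for numerical stability
    weights = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)  # softmax each row
    return weights @ V                                    # (T, D) blended values

# Made-up 5x2 inputs (not the book's numbers):
Q, K, V = (np.random.randn(5, 2) for _ in range(3))
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)                   # (5, 2)
print(np.allclose(out[0], V[0]))   # True: position 1 can only use its own value
```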
I’ve mentioned a few times that in a real situation the number 2 will be much larger. It’s also interesting to think about 5, the number of tokens. If our sequence length during training is 2048, as I used when I trained the 20-layer model and showed in figure 8.14, we’ll end up with a weight matrix of size 2048×2048. You may have heard of context length. During generation, we may want a sequence of tokens, our context, to be much longer than even 2048 and so you can imagine how large the weight matrix becomes. Fortunately there are techniques to avoid computing the full weight matrix as you’ll learn in chapter 17.
It’s time to see where query, key, and value come from. I mentioned linear transformations and normalization above but I glossed over certain details and didn’t show a diagram. Well, here it is:
As we talk this through, let’s think of our input as being a tensor of size 100×1280 corresponding to a sequence of 100 positions each represented by a vector with 1280 dimensions. (For scaled dot product attention above it was helpful to think of each vector as having size two, but that’s too small to see things here.)
We take the input and do three different linear transformations. All three transformations preserve the dimensions (1280 → 1280 in our example). Each will learn its own weights during training in such a way that Q transforms its input into queries, K transforms its input into keys, and V transforms its input into values. So far this is what we covered above.
Here’s something you haven’t seen yet: split into heads. What it does is straightforward and doesn’t involve learning any parameters. Say we want to split into 10 heads:
All we’re doing is splitting up the vectors. The question is why. I’ll share a picture of why it feels intuitive to me that multiple heads could be helpful.
Let’s pretend each position in the sequence is represented by a color. Position 1 is blue, position 2 is green, and position 3 is red:
Now pretend I’m figuring out the output for position 3. We match queries and keys and the resulting attention weight tells us to pay 70% attention to position 1, 5% to position 2, and 25% to position 3. Let me mix the colors in that ratio:
However, if I break the vectors up into two, perhaps the queries and keys will tell me to mix the first half with weights of 70%, 5%, and 25% and the second half with weights of 5%, 70%, and 25%. My output for position 3 will then look like this:
The mixing that occurs within a position is already incredibly rich due to the linear transformations, especially the big ones in the MLP covered in chapter 13. However, mixing between positions seems a little too restrictive. Using multiple heads gives more room for the model to learn what to pay attention to.
The classic way to think about this is that one head might learn to pay attention to adjectives that modify the current position, and another to the nouns referenced by the pronoun in the current position. I think it’s the right concept but I’m not sure what actually happens is ever that obvious or interpretable.
If you’re following every detail you might wonder how the tensor dimensions change after we split into heads. Let me show that, and for completeness, I’ll include the batch dimension too.
For example, using our zero-based indexing notation mentioned above in chapter 8 and the example we’ve been following, output[0,1,2,:] would be a vector of length 128 corresponding to the third head in the second position of the first sequence in the batch. And if you’re wondering, yes, we need to pick a number of heads that divides evenly into D.
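Here is a minimal sketch of the split (and of that indexing example) in code. The batch size of 4 is made up; everything else follows the shapes above:

```python
import numpy as np

# A sketch of "split into heads": a reshape, with no learned parameters.
B, T, D, H = 4, 100, 1280, 10      # batch, positions, vector size, number of heads
x = np.random.randn(B, T, D)       # output of the Q (or K, or V) linear transformation

heads = x.reshape(B, T, H, D // H) # (4, 100, 10, 128): each vector cut into 10 pieces
print(heads[0, 1, 2, :].shape)     # (128,): third head, second position, first sequence

# Many implementations then move the head dimension next to the batch so each
# head can run scaled dot product attention independently.
per_head = heads.transpose(0, 2, 1, 3)   # (4, 10, 100, 128)
```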
The next and final piece of the attention puzzle is rotary embeddings. You see them in the purple box in figure 14.7 above. They deserve their own chapter!