15
Rotary embeddings
Rotary embeddings fall into the category of techniques that give the model a way to exploit information you suspect will help it perform better. Word position, word order, and the distance between words are essential to how we understand language. This is also true at a higher level. If you’re reading a paragraph, the information in the paragraph just prior is likely to provide more relevant context than a paragraph from four pages earlier.
It’s interesting that the importance of word order varies by language. In English it’s important, and in Chinese it’s even more important. In English, the difference between “I love you” and “you love me” is word order plus the pronoun “I” turning into “me” when it’s the object of the sentence. You can feel how important the word order is, since “me love you” makes you think of either a wrong or a cute way of saying “I love you,” not “you love me.” In Chinese, “I love you” is “我爱你” and “you love me” is “你爱我.” Look at the characters. They are identical except that the characters for “I” and “you” have swapped positions. The word order alone tells us the meaning. In Russian, however, we can put the words for “I,” “love,” and “you” in any order and the meaning will be an unambiguous “I love you” or “you love me” based on how the words are declined. Any permutation of the three words “Я тебя люблю” (“Ya tebya lyublyu”) will mean “I love you.” Any permutation of the three words “Ты любишь меня” (“Ty lyubish' menya”) will mean “You love me.” The pronouns “I” and “you” change form depending on whether they are the subject or the object, and the verb “love” changes form to agree with the subject.
Back to our model. Think again about the first transformer block. With the attention mechanism as I described it in chapter 14, a specific token like “France” will emit the same query and key no matter where it falls in the sequence. For example, while calculating the attention weights for position 100, a “France” in position 98 will look exactly the same as a “France” in position 5. The situation is perhaps slightly improved in subsequent layers, because some information will have been rolled forward through position mixing, but the intuition is that the model should have more position information to work with.
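To make that concrete, here is a minimal sketch (toy sizes and made-up token IDs, not Nanochat’s real vocabulary or weights) showing that an embedding lookup followed by a key projection has no idea where in the sequence a token sits:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed = nn.Embedding(50_000, 1280)               # toy vocabulary, model width 1280
key_proj = nn.Linear(1280, 1280, bias=False)     # the K linear transform

tokens = torch.tensor([17, 123, 456, 17])        # pretend token 17 is "France", at positions 0 and 3
keys = key_proj(embed(tokens))                   # one key per position

# Without position information, the two "France" keys are identical:
print(torch.allclose(keys[0], keys[3]))          # True
```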
When Vaswani and his fellow researchers invented the transformer in 2017, they experimented with different approaches to injecting position information. They settled on one that required no learned parameters and worked for sequences of any length. Look back at figure 1 from their paper shown in chapter 6 and you’ll see a positional encoding being added just after the embedding layer. The reason the diagram shows what looks like a sine wave connected to a plus sign is that their approach generated vectors whose numbers cycled over the positions at various frequencies in the various dimensions, and then added these to the embeddings. (Don’t worry if that sentence doesn’t make sense.) This might sound crude, but it meant that a “the” in position five was no longer identical to a “the” in every other position, and therefore the model could learn to exploit this information. I can say this with confidence not because it’s obvious but because they ran experiments that showed the technique to be effective.
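If you’re curious what those cycling numbers look like, here is a minimal sketch of the sinusoidal encoding described in that 2017 paper (this is their original technique, not what Nanochat does, and the sizes are just for illustration):

```python
import numpy as np

def sinusoidal_positions(num_positions, dim, base=10000.0):
    """Additive positional encodings from the 2017 paper: even dimensions get a
    sine, odd dimensions a cosine, with frequencies that fall as the dimension grows."""
    positions = np.arange(num_positions)[:, None]      # (num_positions, 1)
    freqs = base ** (-np.arange(0, dim, 2) / dim)      # one frequency per dimension pair
    angles = positions * freqs                         # (num_positions, dim / 2)
    encoding = np.zeros((num_positions, dim))
    encoding[:, 0::2] = np.sin(angles)
    encoding[:, 1::2] = np.cos(angles)
    return encoding

# A "the" at position 5 now has something different added to its embedding
# than a "the" at position 50:
pe = sinusoidal_positions(num_positions=100, dim=1280)
print(pe[5, :4], pe[50, :4])
```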
Our model, Nanochat, uses an even more effective technique, invented by researchers in Shenzhen, China in 2021. As you can already guess from figures 14.7 and 11.3, we do not crudely add positional information right after the initial embedding layer. Instead, we inject it into the matching of queries with keys inside each transformer block, without ever touching the values.
Before we get into the details, think about how much sense it makes. The purpose of matching a query at a specific position with keys from all prior positions is to decide how to blend together the information from all the positions. We want this process to take the distance between positions into account, so that the model can learn to use position, and not just content, when deciding how much each prior position matters.
One more thing before we get into the details: you can skip this chapter and it won’t take away from your understanding of everything else. I was initially puzzled by the technique, and when I finally understood it I wanted to write about it. But the most important insight is the one I shared in the previous paragraph.
The easiest way to see how the technique works and why it’s called a rotary embedding is to look at an example. You may have encountered this piece of trivia at some point: “Buffalo buffalo Buffalo buffalo buffalo buffalo Buffalo buffalo” is a grammatically correct English sentence. Personally, I can’t hear it as correct or make much sense of it, probably because I’m not familiar with “buffalo” as a verb meaning to bully.
I’m going to tokenize that sentence (not worrying about capital versus lowercase B), which gives me eight identical tokens. I’m then going to convert these tokens into embeddings, normalize, go through my Q and K linear transforms, and then break the resulting size-D vectors into 10 heads. Put another way, I’m going to stick “buffalo buffalo…” into my full model and stop just before we reach the rotary embeds in the first transformer block. You can follow along in figures 11.3, 13.2, and 14.7, but that would be tedious, so let me pull it all together into a single diagram:
Let’s look at a few of the 128 dimensions in the first head for the keys for the 8 positions (“buffalo, buffalo, …”). This is what you would see if you intercepted the tensor at the location I labeled d in the diagram. (That’s just a label. It has nothing to do with “dimension” or “depth.”)
The columns in table 15.1 are similar to the key columns in table 14.5. There I pretended that a key had only two dimensions. Here we have the actual size from our 20-layer model: 128 dimensions. Let me say one more time where that 128 comes from, because I know this is getting confusing. In my 20-layer model, D is 1280, which means each row at position b in figure 15.2 has size 1280. We split these vectors into ten heads, so the rows of keys in the single head shown in table 15.1 have size 128.
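If it helps, here is a minimal sketch of that bookkeeping in code (the tensor names are mine, not Nanochat’s):

```python
import torch

T, D, n_heads = 8, 1280, 10            # eight "buffalo" tokens, model width 1280, ten heads
head_dim = D // n_heads                 # 128, the row size you see in table 15.1

k = torch.randn(T, D)                   # stand-in for the keys at the point labeled d
k_heads = k.view(T, n_heads, head_dim).transpose(0, 1)
print(k_heads.shape)                    # torch.Size([10, 8, 128]): 10 heads, 8 positions, 128 dims
```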
You’ll notice that all of the keys are identical. This is expected. We started with eight of the same token (“buffalo”), turned it into eight identical embeddings, and put each through an identical linear transformation. We could have a sequence of 5,000 “buffalo” and all the keys would still be identical. As you’ll see, once we go through the rotary embed, they no longer will be.
Finally, notice that I numbered the positions and dimensions starting with zero instead of one. This is how things actually work inside the computer and I wanted to stick to that convention so I wouldn’t get confused when making the plots below.
Let’s look at the queries for these eight positions. Do you expect them to also be identical to each other?
Let’s now pretend we skip rotary embedding and go straight to scaled dot product attention. Consider the query for the last position, position 7. We know from chapter 14 that we’ll be calculating a score saying how similar the query is to each of the eight keys. It would be nice to plot the query and keys so we can see how similar the query is to each key with our own eyes as we did in figure 14.1. We’re now dealing with 128 dimensions and unfortunately we humans can’t picture 128-dimensional data. Instead, let’s pick two dimensions to plot: 0 and 64. You’ll see later why I’m choosing those specific dimensions.
I plotted the keys in black. Read from the columns labeled 0 and 64 in table 15.2 and you can see that the keys are correctly plotted at (-24.6, -6.5). You can trust me that I plotted all eight keys and labeled them, but the lines and labels all fall on top of each other. I also plotted one query in red and you can compare with table 15.2 and see that it too is in the correct spot.
So you can see the problem. The scores (angles) between the query and every key are identical. (Now you might tell me, well, who cares because the values are also identical, so no matter what the attention weights are, we’re going to end up in the same place. This is true, but only because I’m using this contrived “buffalo buffalo…” example.)
You may also be surprised that the query and key vectors have different lengths, because in all of the plots in chapter 14 they were the same length. There are two reasons for this. The first is that we don’t normalize the lengths until after the rotary embed step, as you can see in figure 14.7. The second is that when we do normalize, it will make the vectors have the same length in 128 dimensions, not in the two dimensions I happen to be plotting in figure 15.3. (There is a subtle difference between my examples in this chapter and my examples in chapter 14. In chapter 14 I gave examples as if our vectors were of size two. In this chapter our vectors are of size 128, but I’m only plotting two dimensions at a time.)
Let’s look at the query for position 6:
No surprise, it’s the same. Now I’ll apply the rotary embeddings to both the keys and queries. We’ll get into exactly how this works later but know that it’s an efficient operation that does not involve any learned parameters.
Whoa, that looks different! Compare with figure 15.3. Other than the key for position 0, each key and the query was rotated. Look at the clockwise rotation from key 0 to key 1. Now look at the rotation from key 1 to key 2. You can see that the angle is the same. So key 1 is that angle away from key 0, key 2 is twice that angle away from key 0, and so on all the way through key 7, by which point we’ve started to circle back around.
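If you’d like to see that equal spacing in numbers rather than in a plot, here is a minimal sketch. The starting key is the one from table 15.2; the one-radian per-position angle is my assumption about this dimension pair, and the point is only that each position gets one more step of the same rotation:

```python
import numpy as np

def rotate(vec, angle):
    """Rotate a 2-D vector clockwise by `angle` radians."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([c * vec[0] + s * vec[1],
                     -s * vec[0] + c * vec[1]])

key = np.array([-24.6, -6.5])     # the identical pre-rotation key from table 15.2
theta = 1.0                        # assumed per-position angle for this dimension pair

for position in range(8):
    x, y = rotate(key, position * theta)
    print(position, round(np.degrees(np.arctan2(y, x)), 1))
# Each key sits one more fixed step clockwise than the one before it
# (the printed angles wrap around at plus or minus 180 degrees).
```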
Notice also that the query for position 7 is close to the key for position 5. If we were matching the query to keys based only on these two of the 128 dimensions, position 5 would get the most weight, followed by position 6. Think about that in relative terms for a moment. Position 7 matches most closely with the position two behind it. We would like this to remain the case even if there happened to be more or fewer tokens in the sequence before we get to this position.
Here’s the same plot showing the query for position 6:
Good news. It’s doing what we want. Position 6 will now match most closely with position 4, the position two behind it.
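Here is a minimal numeric check of that property. The vectors and the angle are made up; what matters is that the score between a rotated query and a rotated key depends only on how many positions apart they are:

```python
import numpy as np

def rotate(vec, angle):
    """Rotate a 2-D vector clockwise by `angle` radians."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([c * vec[0] + s * vec[1],
                     -s * vec[0] + c * vec[1]])

q = np.array([3.0, 1.0])      # made-up pre-rotation query for one dimension pair
k = np.array([-2.0, 4.0])     # made-up pre-rotation key for the same pair
theta = 0.3                    # made-up per-position angle

score_7_vs_5 = rotate(q, 7 * theta) @ rotate(k, 5 * theta)   # query at 7, key at 5
score_6_vs_4 = rotate(q, 6 * theta) @ rotate(k, 4 * theta)   # query at 6, key at 4
print(score_7_vs_5, score_6_vs_4)   # equal up to floating-point noise: only the offset of 2 matters
```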
Something that might be bothering you is how fast we cycle around. Sequences can be thousands of tokens long. Imagine the plot above with a thousand keys. It seems likely that the query would match most closely with keys hundreds of positions away just because they happened to have the smallest angle. We need the model to have room to learn about both close and distant relative positions, and to have a way to attach, on average, more weight to close positions. If we always rotate by the angle shown in figures 15.5 and 15.6, you can see we’ll be in trouble.
Remember, though, that so far we’ve only looked at two of the 128 dimensions in the head. Let’s plot a different pair of dimensions: 10 and 74. First before rotation:
Let’s rotate:
The idea is the same as we saw for dimensions 0 and 64 in figure 15.5, but the angle of rotation is smaller. For example, compare the angle between the keys for positions 0 and 1 in this plot with the same angle in figure 15.5. Now think about the matching score between the query for position 7 and, say, the keys for positions 0 and 1. In figure 15.8, the angle between the red query and the black position 0 key is not drastically different from the angle between the query and the position 1 key. In figure 15.5, for the pair of dimensions 0 and 64, the difference is much more pronounced.
Let’s look at the position 6 query:
Notice that the angle between the query for position 7 and the key for the position two behind it (position 5) in figure 15.8 is the same as the angle between the query for position 6 and the key for position 4 in figure 15.9. This is expected and desired, because we only care about relative position (in this contrived case where the tokens in each position are identical).
In higher dimension pairs the rotation angle becomes even smaller. Let’s look at dimension pair 20 and 84:
You can start to see how the angle of rotation for these keys is so small that, among them, the amount of attention paid to each by position 7 is going to be determined far more by meaning than by relative position. It’s just that in this case, with our contrived “buffalo buffalo…” sequence, the keys all started out identical before rotation, so the only thing the match has to go on is position.
Lots of small angles do add up. The model can learn to make sense of relative distance over tens, hundreds, or thousands of tokens in a sequence. Let’s repeat the plot but pretend we had 50 buffalos:
The difference between the query-to-key angle for position 0 and the one for position 49 is substantial, almost 180 degrees. However, in the highest dimension pair in the head, dimensions 63 and 127, the angle between positions 0 and 49 is less than half a degree:
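You can reproduce both of those numbers with the standard rotary frequency schedule. I’m assuming the common base of 10,000 from the original rotary paper here; I haven’t checked Nanochat’s exact constant, but the figures in this chapter are consistent with it:

```python
import numpy as np

head_dim = 128
base = 10000.0                            # assumed; the common default from the rotary paper
pairs = np.arange(head_dim // 2)          # pair 0 is dims (0, 64), pair 63 is dims (63, 127)
theta = base ** (-2 * pairs / head_dim)   # per-position rotation angle of each pair, in radians

for p in (20, 63):
    swept = np.degrees(49 * theta[p])     # rotation between position 0 and position 49
    print(f"pair {p}: {swept:.2f} degrees")
# pair 20: ~158 degrees ("almost 180")
# pair 63: ~0.32 degrees ("less than half a degree")
```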
You can imagine how the model exploits this. As each layer is learning to emit queries and keys from its input, the model figures out that more or less emphasis on certain dimension pairs within a head results in more or less sensitivity to relative position.
Every plot I showed above had the keys and queries identical prior to applying the rotary embed. For the sake of visualizing how this works when that’s not the case, here are the keys and a query for “The capital of France” before rotation:
I chose the dimension pair 2 and 66 instead of dimensions 0 and 64 so that the rotation angle would be a little smaller. This will make it a little (but only a little!) easier to see what happens. Here’s the plot after rotation:
Notice that the key for position 0 is unchanged. The key for position 1 rotates clockwise by a certain angle. The key for position 2 rotates clockwise by twice that angle. If this were the whole picture (which it isn’t, because we have another 126 dimensions in the head), the closest match for the query for position 3 (“France”) would be position 2 (“of”).
Now let’s plot the rotated keys and query for “The current capital of France.” I inserted the word “current” so the sequence now has one additional position, but it still ends with “capital of France.”
The key for position 0 (“The”) is unchanged from figure 15.14. The query for position 4 (“France”) is in a new place, but notice that the angle between it and position 3 (“of”) is identical to the angle between the query for position 3 (“France”) and the key for position 2 (“of”) in figure 15.14. We inserted an extra word, but the relative relationship between the words after the extra word stays the same. So, again, if these were the only dimensions, “of” would get the highest match score with “France” just as in figure 15.14.
If you look carefully you can also see that the key for position 1 (“current”) is new. However, the key for position 2 (“capital”) is the same as the key for position 1 in the prior plot, but rotated clockwise one more time.
So that’s the idea. Take pairs of dimensions within the head (0 and 64, 1 and 65, all the way up to 63 and 127), treat them as two-dimensional vectors, and rotate them clockwise a certain number of times. The number of times is the position in the sequence. The angle of rotation is bigger for the lower dimension pairs (e.g., 0 and 64) and smaller for the higher dimension pairs (e.g., 63 and 127). We don’t have to worry about numbers exploding out of range no matter how long our sequence is, since rotation is like a clock and we just keep going around. (We do, though, need to make sure the angle of rotation is small enough in the higher dimension pairs so that we can get meaningful results from this whole rotary embed operation.)
There’s nothing magical about pairing up 0 and 64, 1 and 65, etc. It’s just a convenient way to implement it. We could just as easily pair 0 and 1, 2 and 3, etc. Either way we’ll end up with the same size head before and after rotation.
I won’t show the exact calculation but it’s similar to what you might have imagined after seeing the sine wave in figure 15.1. We use sine and cosine waves that repeat more frequently along the sequence positions for the lower dimension pairs and less frequently for the higher dimension pairs. The rotation of a dimension pair at a particular position comes from the value of sine and cosine for that dimension pair at that position.
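That said, if you’d like to see the shape of the calculation, here is a minimal sketch of one common way to implement it. The pairing of dimension i with dimension i + 64 and the base of 10,000 are my assumptions about a typical implementation, not a line-by-line copy of Nanochat:

```python
import torch

def apply_rotary(x, base=10000.0):
    """Rotate every (i, i + head_dim/2) dimension pair of x clockwise by
    position * theta_i, where theta_i shrinks as the pair index i grows.
    x has shape (positions, head_dim). No learned parameters involved."""
    T, head_dim = x.shape
    half = head_dim // 2

    theta = base ** (-2 * torch.arange(half, dtype=torch.float32) / head_dim)  # (half,)
    angles = torch.arange(T, dtype=torch.float32)[:, None] * theta             # (T, half)
    cos, sin = torch.cos(angles), torch.sin(angles)

    x1, x2 = x[:, :half], x[:, half:]              # the two members of each pair
    return torch.cat([x1 * cos + x2 * sin,         # clockwise rotation of each pair
                      -x1 * sin + x2 * cos], dim=-1)

# Apply it to the queries and keys of one head; the values are left alone.
q = torch.randn(8, 128)
k = q.clone()                                       # eight identical "buffalo" keys
q_rot, k_rot = apply_rotary(q), apply_rotary(k)
print(torch.allclose(k_rot[0], k[0]))               # True: position 0 is unchanged
```

Either pairing convention from the previous paragraph gives equivalent results; this one just makes the slicing convenient.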
That’s it! You can now follow everything that goes on inside the model. (Technically, I haven’t explained the Norm module yet. That will come in chapter 22.) In the next chapter I’ll put all the diagrams together so you can see it all at once.