13
Transformer block
We’re ready to get back into the transformer. As a reminder of where we are and the big picture, our tokenizer (not shown in the diagram) converted text, say “The capital of France is,” to a sequence of tokens: 449, 3422, 281, 4192, 309. The embedding module converted each of those into an embedding of 1280 numbers filled with meaning. We’re now going to feed this 5×1280 tensor into our first transformer block. We’ll then take the output, which will also be 5×1280, and feed it to the second block. We’ll repeat this over and over through a total of 20 blocks. We aim for that final output, once projected onto the vocabulary by the final linear layer, to be an appropriate prediction of the next token at each input position. After training, I sincerely hope that the top next-token prediction after “is” will be “Paris”!
Let’s look inside the transformer block.
Take a look at the diagram. Where do you think the learning is happening? The norm and add operations are critical but contain no parameters of their own. Therefore all the parameters we need to learn must be somewhere in the causal self-attention and/or MLP modules. Spoiler alert: they are in both, as you will see.
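To make the parameter-free parts concrete, here is a minimal sketch of how the norm and add steps might wrap the two modules. The ordering (norm, then module, then add the result back to the input) is an assumption, since the diagram is not reproduced in this text, and the RMS-style norm is likewise an assumption chosen because it needs no learned parameters. The names `rms_norm`, `add`, `transformer_block`, and `zeros_like` are made up for illustration.

```python
import math

def rms_norm(row, eps=1e-6):
    """Parameter-free norm (assumed RMS style): scale a row so its root-mean-square is about 1."""
    mean_sq = sum(v * v for v in row) / len(row)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale for v in row]

def add(x, y):
    """Element-wise sum of two T x D tensors stored as lists of rows."""
    return [[a + b for a, b in zip(rx, ry)] for rx, ry in zip(x, y)]

def transformer_block(x, attention, mlp):
    """Assumed layout: norm -> module -> add the module output back to its input."""
    normed = [rms_norm(row) for row in x]
    x = add(x, attention(normed))   # attention may mix information across rows
    normed = [rms_norm(row) for row in x]
    x = add(x, mlp(normed))         # the MLP works on each row independently
    return x

def zeros_like(x):
    """Placeholder module that outputs all zeros (stands in for attention or the MLP)."""
    return [[0.0] * len(row) for row in x]

x = [[1.0, 2.0], [3.0, -1.0], [1.0, 2.0]]   # T = 3 positions, D = 2
out = transformer_block(x, zeros_like, zeros_like)
print(out == x)   # True: with do-nothing modules the block passes x straight through
```

With placeholder modules that output zeros, the block returns its input unchanged, which shows that the norm and add steps only route information. Anything the block learns has to live inside the attention and MLP modules.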
To keep this concrete, here’s example output from the embed module that we can use as input to the first transformer block:
The high-level idea is that causal self-attention is where one position brings in information from earlier positions. For example, the fifth position (“is”) might pull in information from “capital” and “France” so that eventually the model can predict “Paris” as the next token. Think of this as each position paying a certain amount of attention to each prior position. We’ll get into the details below.
The MLP is where learning within a position happens. We’ll look at it first because you’ve already seen an MLP; I just avoided the term. The MLP operates only within a position. To be precise, it performs an identical operation on each row of its input and never mixes information between rows.
Keep in mind that the example tensor in figure 13.3 is only the input to the first block. For each successive block, the input is the output of the prior block. It may be useful to think of each position as containing a token embedding, as shown above, but after the first block this is no longer true. Instead, think of each position as containing a vector of size 1280 that somehow represents the sequence up to and including that position. In a sense, the successive layers take us from encoding the meaning of individual words (tokens) to encoding the meaning of the entire length of text up to a given position. (Casually, though, you will hear people talk about the embeddings out of, say, layer ten, and that’s fine.)
MLP stands for multilayer perceptron. I’ve avoided saying that up to now because the term makes it sound more complicated than it is. Hedgehog model #4 shown in figure 10.18, where we made a linear-ReLU-linear sandwich, was an MLP. Here’s the MLP inside our transformer:
In hedgehog model #4, the initial linear transformation took us from one number (length of hedgehog) to 10 numbers. In the MLP in our transformer, the first linear transformation goes from D to 4 × D, which in our example is 1280 to 5120. We then apply ReLU, square the result, and let the second linear transformation take us back to 1280. We don’t use any bias terms in our linear layers, so sticking with our 1280 example, there are 1280 × 5120 = 6,553,600 parameters in each of these two linear transforms. That’s over 13 million parameters together. Multiply that by the 20 layers and you’ll see that over 262 million parameters are in the MLPs. This is where a lot of learning will take place.
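The arithmetic above is easy to check in a few lines of Python. This sketch uses only the numbers stated in the text (D = 1280, a 4× expansion, no bias terms, 20 blocks); the variable names are made up for illustration.

```python
D = 1280            # embedding width from the text
HIDDEN = 4 * D      # the MLP expands each position to 5120 numbers
N_BLOCKS = 20       # number of transformer blocks

per_linear = D * HIDDEN          # weights only, since there are no bias terms
per_mlp = 2 * per_linear         # one linear layer up, one back down
all_mlps = N_BLOCKS * per_mlp    # every block has its own MLP

print(f"{per_linear:,}")  # 6,553,600
print(f"{per_mlp:,}")     # 13,107,200
print(f"{all_mlps:,}")    # 262,144,000
```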
And I want to emphasize again, this is not only, or even mostly, learning about how to use the embedding of a single token/word to predict the next token. If that were all that was happening, all of this high-dimensional (1280 in this case) information would be useless and we could just use the n-gram approach discussed in chapter 2. The vectors of size 1280 at each position represent and contain information from the entire sequence up to and including that position. You’ll soon see how.
In case things are getting confusing, let me work through a tiny example of the calculation where I assume D = 2 and T = 3.
You can’t verify the calculations because you don’t know the weights inside the linear transformations. What I want you to see is how the column dimension expands from two to eight and then contracts back to two. Notice also that in the input the first and third positions are both [1.00, 2.00]. Because the MLP performs the same operation on every row, they also match in the output: both are [0.35, -0.78].
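If you would like to run a tiny example of your own, here is a minimal sketch of the linear, ReLU, square, linear sandwich in plain Python with D = 2 and T = 3. The weights are random, not the ones behind the example above, so the output values will differ from [0.35, -0.78]. The second input row is made up, and the function names (`linear`, `relu_squared`, `mlp`) are illustrative only. What should carry over is the shape story (2 to 8 and back to 2) and the fact that identical input rows give identical output rows.

```python
import random

def linear(x, w):
    """Multiply a T x n_in tensor by an n_in x n_out weight matrix (no bias)."""
    n_out = len(w[0])
    return [[sum(v * w[i][j] for i, v in enumerate(row)) for j in range(n_out)]
            for row in x]

def relu_squared(x):
    """Apply ReLU, then square, element by element."""
    return [[max(0.0, v) ** 2 for v in row] for row in x]

def mlp(x, w_up, w_down):
    """Expand each row from D to 4D, apply ReLU squared, contract back to D."""
    return linear(relu_squared(linear(x, w_up)), w_down)

random.seed(0)      # random illustrative weights, not trained ones
D, T = 2, 3
w_up = [[random.uniform(-1, 1) for _ in range(4 * D)] for _ in range(D)]
w_down = [[random.uniform(-1, 1) for _ in range(D)] for _ in range(4 * D)]

# Rows 1 and 3 match the text; row 2 is a made-up filler value.
x = [[1.00, 2.00], [0.50, -1.00], [1.00, 2.00]]

hidden = linear(x, w_up)
out = mlp(x, w_up, w_down)

print(len(hidden), len(hidden[0]))  # 3 8  (T rows, expanded to 4 * D columns)
print(len(out), len(out[0]))        # 3 2  (T rows, back to D columns)
print(out[0] == out[2])             # True: same input row, same output row
```

Changing the middle row of `x` leaves the first and third output rows untouched, which is another way to see that the MLP never mixes information between positions.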
Now, finally, it’s time to get into the heart of the transformer, the head-spinning, beautiful concept that makes a transformer a transformer: causal self-attention, represented by the blue box above.