5
The matrix
I was having drinks with a friend who is a scientist at one of Boston’s biotech firms. I hadn’t seen him in a while and he mentioned that he started reading my blog post Tracing the Transformer in Diagrams but gave up when it got to matrix multiplication. That was a reminder to me—matrix multiplication is much more straightforward than the type of science he does. But the word matrix sounds complicated. Or people (including me) have bad memories from college math classes that moved aggressively from basic matrix operations to proofs.
Let’s take a step back. You’ve already seen, in gory (fowly?) detail, every variable, number, and calculation that went into our turkey feather prediction model. Our example had only two weights and three turkeys, each of which had a height and a length, and already it was getting unwieldy. Look at how messy figures 4.6 and 4.8 are. What if we had ten weights and twenty turkeys, and we also wanted to consider the wingspan of each turkey, and we had ideas for calculations beyond multiplying by weights and adding them together? We would get very confused very quickly. And even if you could keep it straight in your head, how would you ever communicate it to other people or to the computer?
Keep in mind that what I’m talking about here is a way to be more organized. I’m not looking for a magic bullet that will reduce the number of numbers to track or calculations to do. We really do need to do all the calculations we did in chapter 4. But like so much in life, if you can stay organized, and trust your organization, and give things appropriate names, then you can start working at a higher level of abstraction. Just like ChatGPT, our human minds can only keep track of so much at once, and so the secret is to learn (or come up with!) abstractions we trust.
Think first about the data from our turkeys. Turkey #1 has a height of 1 meter and a length of 1.5 meters. Turkey #2 has a height of 0.75 meters and a length of 1.25 meters. Turkey #3 has a height of 1.25 meters and a length of 1 meter. It would be nice if we could just stick all that information somewhere and give it a name, like “data.”
The people who figured out how to do modern AI were far from the first to hit the problem of keeping lots of numbers organized. The idea of sticking numbers into a matrix and specifying ways of then manipulating those numbers to do useful things dates back at least 2,000 years.
Let’s stick our turkey data into a matrix:
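X =
[ 1      1.5  ]
[ 0.75   1.25 ]
[ 1.25   1    ]

Each row is one turkey; the first column is height and the second is length, both in meters.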
There is absolutely nothing special going on beyond what you see. Sure, there’s all sorts of sophisticated math and physics around what you can do with matrices, but for the most part we’re not going to do any of that. The matrix above really is just six numbers, two for each turkey. It looks like the height and length columns of table 3.1 because it is. If you work with spreadsheets, the matrix is a spreadsheet. If you’re a software developer, the matrix is a two-dimensional array.
Now why did I call the matrix X? Seems like I’m asking for trouble since so many things are called X or x and X can also get confused with the times symbol. I wanted to stick with the convention in AI. We’re creating a model that takes in data and makes a prediction. Just as in school you normally used “x” as the standard variable name when talking about the input to a function, like f(x) = x + 3, in AI X or x is the input to a model.
Let’s keep organizing our closet. We know the actual number of feathers for each of our three turkeys. Let’s stick those in a matrix too:
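Y_actual =
[ 5000 ]
[ 3500 ]
[ 4500 ]

One row per turkey, holding the actual feather counts.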
By convention this matrix is called Y. However, to avoid confusion between the actual number of feathers from our data and the predicted number of feathers, I’m being extra verbose and calling it “Y_actual.”
What about the weights? It’s pretty annoying to keep talking about w1 and w2. Let’s stick those in a matrix too.
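W =
[ w1 ]
[ w2 ]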
Using these weights and the height and length of our three turkeys, let’s predict the number of feathers for each turkey. We already know what calculation we want to do—the same as at the beginning of chapter 3 where we multiplied height by w1 and length by w2 and added. Lucky for us we don’t need to describe this in detail. There’s something called matrix multiplication that means exactly that. So our prediction, which we’ll call Y_prediction, is XW meaning X matrix-multiplied by W.
There are a few reasons you don’t really need to think about the details of matrix multiplication. First—you already know it: it’s what we did in chapter 3. Second—we’re trying to move our brains up a level of abstraction and not have to think about every detail. You’re happy to accept that 3 × 4 is 12 without thinking each time, okay, that’s 3 plus 3 plus 3 plus 3. Third—the computer will calculate everything for us.
But if that’s not satisfying, go read about matrix multiplication, or piece the procedure together from this: 1 × 1000 + 1.5 × 3000 = 5500. One other tidbit that may be helpful. We talk about a matrix as having size m × n where m is the number of rows and n is the number of columns. In the example above, X is 3×2 and W is 2×1 and the result, Y_prediction, is 3×1. This will always be true. An m × n matrix can be multiplied by an n × p matrix and the result will be m × p, meaning m rows and p columns.
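With w1 = 1000 and w2 = 3000, the full multiplication works out like this, one row of X at a time:

Y_prediction = XW:
turkey #1: 1 × 1000 + 1.5 × 3000 = 5500
turkey #2: 0.75 × 1000 + 1.25 × 3000 = 4500
turkey #3: 1.25 × 1000 + 1 × 3000 = 4250

Three rows in, three predictions out, one per turkey.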
Our closet is becoming nice and clean. Except we’re not done yet. We calculated our predictions, but not our loss. Let’s start by calculating the error and in the process see another way in which matrix operations are intuitive and keep things organized. Error = Y_prediction - Y_actual:
To do Y_prediction minus Y_actual, we subtract the first element of Y_actual from the first element of Y_prediction, the second from the second, and the third from the third. This is an element-wise operation.
We’re getting close. Next we square the error. This again is an element-wise operation. There’s nothing crazy going on. We just square each element. There’s no interaction between elements.
The final step to computing our loss is to add the squared errors together. Up until this point we’ve kept our turkeys separate. In each matrix with three rows, the first was turkey #1, the second turkey #2, and the third turkey #3. Here we’re finally going to bring them together by adding the rows to get to a single loss number.
In this case it’s easy enough to wrap your mind around. But I do think it’s easy to get confused once we start mixing items of data. This is important, and is going to come up when we get back to working with text. To do a good job of predicting the next word in a sentence, we’ll want to take information into account from many words, and since we’ll be representing everything as matrices, we’ll need to do operations that mix information up among data items. Hold that thought for later.
Here we’ll sum the elements in the error² matrix.
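If it helps to see those three steps with concrete numbers, here’s a small sketch in plain Python, using the predictions from our starting weights (w1 = 1000 and w2 = 3000):

```python
Y_prediction = [5500.0, 4500.0, 4250.0]  # X matrix-multiplied by W, with w1=1000 and w2=3000
Y_actual = [5000.0, 3500.0, 4500.0]      # the actual feather counts for our three turkeys

# Element-wise subtraction: first element with first, second with second, third with third.
error = [p - a for p, a in zip(Y_prediction, Y_actual)]  # [500.0, 1000.0, -250.0]

# Element-wise squaring: each element on its own, no interaction between elements.
error_squared = [e ** 2 for e in error]                  # [250000.0, 1000000.0, 62500.0]

# Finally bring the turkeys together: sum the squared errors into a single loss number.
loss = sum(error_squared)                                # 1312500.0
```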
Now instead of writing out our loss equation as we did early in chapter 3 like this:
loss = (5000 - ((1)(w1) + (1.5)(w2)))² +
       (3500 - ((0.75)(w1) + (1.25)(w2)))² +
       (4500 - ((1.25)(w1) + (1)(w2)))²
We can write:
loss = sum((XW - Y_actual)²)
X is fixed (our turkey heights and lengths), Y_actual is fixed (the actual number of feathers of our turkeys), and we have some initial values for W (our weights). Our goal is to adjust our weights to minimize the loss.
The diagram for the forward calculation is so much neater too:
I’m using “@” to mean matrix multiplication. I could write (X)(input) but that’s not very clear.
Before we come up with the diagram for the backward calculation, let me show the inputs to each step when we start with our old friends, w1=1000 and w2=3000.
Now backward:
Hooray! In a much more organized way we calculated the partial derivatives of loss with respect to w1 and w2. I want to emphasize again that there’s no magic. We still had to do all the same calculations as our messy diagram above. We’re just now doing it in a way that’s easier to write down, and more important, easier to communicate to the computer.
So far I’ve been using the terms derivative and partial derivative. Now we’ve put w1 and w2 into a matrix, and the partial derivative of loss with respect to w1 and the partial derivative of loss with respect to w2 into a corresponding matrix. This matrix of partial derivatives is called a gradient. The gradient of loss with respect to W tells us how making a small change to an element of W will affect the loss.
I’ll put the forward and backward calculation on a single simplified diagram to emphasize that we start with W, do the forward pass, do the backward pass, and end up with the gradient of W.
Here’s something that could be bothering you at this point. I keep showing W as the input to the forward calculation, and you could be thinking, wait, shouldn’t X, the heights and lengths of each turkey, be the input? Yes, they are both inputs and if you look at the detailed forward calculation you’ll see where X comes in. However, we’re not interested in changing X. The heights and lengths are what they are. We want to change W. So it’s more intuitive at this point to think of W as the input, even though X, W, and Y_actual are all inputs to the loss calculation.
What do we now do with W’s gradient? The exact same thing we did above—use it to nudge W in the desired direction, the one that will cause the loss to go down. In chapter 3 I showed how we do this by multiplying the gradient by a learning rate and subtracting it from the current weights. This calculation is another one where it’s easy to stay organized in matrix form. Let’s keep our learning rate at 0.01, the same as in chapter 3.
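Using the gradient we just computed, 1875 for w1 and 3500 for w2, the update works out to:

new w1 = 1000 - 0.01 × 1875 = 981.25
new w2 = 3000 - 0.01 × 3500 = 2965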
Take a look back at table 3.8 and you’ll see that after the first update we ended up with 981 for w1 and 2965 for w2. The results match because the calculation is the same.
I’ll be showing almost no code in this book. I will here because I want to convey the beauty of our modern tools and how suited they are to the mess of calculations we just walked through. Also, as we get to the GPT model, and how we calculate loss, and what’s going on inside it, I want you to have a sense for how these things are actually expressed to the computer.
You’ll see that we don’t need much code at all. All of the head-spinning stuff like working out the backward calculation is done for you. This is the advantage of working at the right level of abstraction and sitting on the shoulders of decades of software development. We specify the matrices and the forward calculation in a language that is concise and clear. Then layer upon layer upon layer of software and hardware causes the correct forward and backward calculations to be executed, ultimately by transistors on chips.
So here it is. Seven lines of code (or two, depending on how you count) that will compute the gradient we’ve painstakingly calculated by hand many times starting in chapter 3. I wrote the code in the programming language Python using the PyTorch library, the most popular top-level framework for writing AI models.
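A minimal sketch of that code, using the matrices we defined above (the variable names are mine):

```python
import torch

X = torch.tensor([[1.0, 1.5], [0.75, 1.25], [1.25, 1.0]])   # heights and lengths (3x2)
Y_actual = torch.tensor([[5000.0], [3500.0], [4500.0]])     # actual feather counts (3x1)
W = torch.tensor([[1000.0], [3000.0]], requires_grad=True)  # our weights; track their gradient
Y_prediction = X @ W                                        # forward: matrix multiplication
loss = ((Y_prediction - Y_actual) ** 2).sum()               # forward: sum of squared errors
loss.backward()                                             # backward: compute the gradient
print(W.grad)                                               # the gradient of loss with respect to W
```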
What does it print out?
tensor([[1875.],
[3500.]])
We get the gradient of W just as in figure 5.13. Nice! I annotated the code in case you’re curious what each line does.
Take a moment to appreciate how clean and powerful this approach is. We define a bunch of matrices and tell PyTorch we’re interested in the gradient with respect to one of them, W. Now for whatever calculations we do involving W, PyTorch keeps track of the computation graph and all necessary intermediate calculations along the way to later do the backward calculations. This is everything I showed in figure 5.13. We then call backward() on loss. This tells PyTorch we’re interested in the gradient of loss, it does the backward calculations, and voilà, a matrix containing W’s gradient gets attached to W. In this code we print it out. In a training loop we would use it to adjust the weights.
Speaking of adjusting the weights, let’s do the full training loop in code for 1000 steps like in table 3.11.
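A minimal sketch of that loop, starting again from our initial weights and keeping the 0.01 learning rate from chapter 3:

```python
# Start over from our initial weights (X and Y_actual are the same tensors as above).
W = torch.tensor([[1000.0], [3000.0]], requires_grad=True)
learning_rate = 0.01

for step in range(1000):
    Y_prediction = X @ W                           # forward pass
    loss = ((Y_prediction - Y_actual) ** 2).sum()  # loss: sum of squared errors
    loss.backward()                                # backward pass: fills in W.grad
    with torch.no_grad():                          # pause PyTorch's gradient tracking
        W -= learning_rate * W.grad                # nudge the weights downhill
        W.grad.zero_()                             # reset the gradient for the next loop

print(W.data)
```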
What do we get?
tensor([[2310.7886],
[1633.1046]])
This tells us we should use w1=2311 and w2=1633 in our turkey feather prediction model, just like we calculated in table 3.11.
If you’re curious about the code, here’s the idea. Do a thousand loops. In each, do the forward and backward calculation. Then update the weights by subtracting the gradient times the learning rate, zero out the gradient, and start the next loop.
If you’re a coder but not familiar with PyTorch you may wonder what the “no_grad()” thing is about. PyTorch is very nicely keeping track of every calculation involving W so it can compute W’s gradient (W.grad). However, we’re about to adjust the weights in W and clear the gradient. These operations have nothing to do with actually calculating the gradient and so we want to tell PyTorch to temporarily stop its tracking.
You may also wonder why we zero out the gradient. Gradients accumulate. You got a taste for that way up in figure 4.9 where the final derivative was the addition of the numbers coming in on each arrow. We need to set it back to zero before we compute loss and do backpropagation again.
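You can see the accumulation directly: call backward() twice without zeroing, and the gradient doubles. A quick sketch, where W2 is just a throwaway copy of our starting weights for this demonstration (X and Y_actual are reused from above):

```python
W2 = torch.tensor([[1000.0], [3000.0]], requires_grad=True)
for _ in range(2):
    loss = ((X @ W2 - Y_actual) ** 2).sum()
    loss.backward()            # each call adds to W2.grad rather than replacing it
print(W2.grad)                 # tensor([[3750.], [7000.]]) -- double the gradient from before
```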
That’s mostly it for code. But know that behind all of the models we look at and train going forward, there will be similar code, often at even higher levels of abstraction.
You’ve probably heard that all the magical AI stuff from the past decade is somehow connected to GPUs. NVIDIA is the most valuable company in the world, at least as of the moment I’m writing this sentence. Let’s get into why.
Over the past few sections of this chapter we moved up a few layers of abstraction so we could think in terms of matrices without getting bogged down in the details of matrix operations. I’ll go back into the weeds here to make a point.
We want to multiply two 3×3 matrices:
Let’s do this the only way we can as single-threaded humans:
- 1 × 10 = 10
- 2 × 13 = 26
- 3 × 16 = 48
- 10 + 26 = 36
- 36 + 48 = 84 ← put that in the top left position
- 1 × 11 = 11
- 2 × 14 = 28
- 3 × 17 = 51
- 11 + 28 = 39
- 39 + 51 = 90 ← put that in the top middle position
Here’s what we’ve done so far:
We’re 2/9ths of the way done. This is going to get pretty tedious. Each number in the resulting matrix requires three multiplications and two additions. That’s a total of 27 multiplications and 18 additions. If we kept numbering the operations above, we would be done on the 45th operation.
Now notice something interesting about calculating the resulting matrix. Calculating 90 didn’t require anything from the calculation of 84. We had to wait only because as single-threaded humans our brains can only do so much math at once. Go a level deeper. None of the multiplications have any dependencies either. We calculated 1 × 10 and then 2 × 13, but we could just as easily have done it in the other order—nothing would change.
Matrix multiplication is sometimes called embarrassingly parallel. Yes, we may have an enormous amount of multiplying and adding to do (or in this small example 27 multiplications and 18 additions), but we can do much of it at the same time up to the constraints of our hardware. For example, the resulting matrix could be calculated in three parallel steps rather than 45 serial steps as shown here:
Here, in our tiny example, we needed just 45 operations. Starting in chapter 24 we’ll be using NVIDIA H100 chips to train our 32-layer version of Nanochat. Through massive parallelism and speed, the H100, which came out in October 2022, can compute nearly 2000 trillion 16-bit floating point operations per second. (In chapter 30 I explain what a 16-bit floating point number is and why it’s important to specify the number of bits when measuring calculation speed.)
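To see what that buys us, here’s a sketch in PyTorch with two large random matrices (the sizes are purely illustrative). The single @ below kicks off tens of billions of multiplications and additions, and the library and the chip decide how to run them in parallel; we never spell out the individual operations:

```python
import torch

# Two large random matrices, sized purely for illustration.
A = torch.randn(4096, 4096)
B = torch.randn(4096, 4096)

# Use a GPU if one is available; the same line of code runs either way.
device = "cuda" if torch.cuda.is_available() else "cpu"
A, B = A.to(device), B.to(device)

C = A @ B  # roughly 69 billion multiplications, scheduled in parallel for us
```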
There’s so much talk about GPUs and AI these days that you may forget that the G in GPU stands for graphics. What’s the connection? Over most of their histories, NVIDIA and their competitors created chips for gaming and video rendering. Coloring each pixel of an image is more or less the same problem as doing lots of matrix operations. Say you want to render an image at 4K resolution. That’s a rectangle of 3840 × 2160 or over 8 million pixels. The good news is, just like matrix multiplication, you can mostly compute the value (the amount of red, green, and blue) of each pixel independently. The bad news is you want a fast refresh rate on your game, so you might need to do these calculations 30 or 60 times a second.
In the early 2000s, NVIDIA realized that their chips could be used for general-purpose accelerated computing, and in 2007 officially released CUDA, a way for engineers and scientists to access the GPU without having to pretend they were manipulating graphics. A number of AI researchers began to experiment with training and running models on GPUs in the late 2000s. The real wake-up moment came in 2012 with AlexNet, as you’ll read about in chapter 12. But first we’ll get into the transformer and why you would even need so much processing power for a language model.