21


Muon

I’m not going to get into every detail of Muon, but I want to show the key idea and in the process get us thinking about more than three dimensions.

With our turkey model in chapter 3, we multiplied and added individual scalar numbers, such as multiplying the height of turkey #1 by the model parameter weight_1. In chapter 5 I introduced matrices. I explained them purely as a different means to the same end: performing the same operations with better organization, greater efficiency, and a streamlined way to express the desired calculations to us and to the computer. While that’s all true, there are times when it’s helpful to think of a matrix as more than an organized bag of numbers.

Here’s a linear transform from a pair of numbers to a pair of numbers with no bias term:

Figure 21.1. The linear transformation we’ll be using in the examples below.

I put in the pair of numbers (-0.10, 0.80) and get out (-0.50, -0.78). Our weight matrix has four numbers, as you would expect. I’m purposely calling it weight and not weights because from here to the end of this chapter I don’t want to think of it as a collection of four individual weights. It’s just a “thing” in four-dimensional space, the same way that the number “7” is a thing in one-dimensional space.
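Here is a minimal NumPy sketch of what a bias-free linear transform does to a pair of numbers. The entries of `weight` below are hypothetical placeholders, not the values in figure 21.1 (which aren’t reproduced in this text), so the output is not (-0.50, -0.78).

```python
import numpy as np

# Hypothetical 2x2 weight for illustration; NOT the matrix in figure 21.1.
weight = np.array([[1.2, -0.5],
                   [0.3,  0.9]])

# One input pair (x, y). With no bias term, the output is just weight @ point.
point = np.array([-0.10, 0.80])
output = weight @ point
print(output)

# Many points can be transformed in one multiply: store them as the rows
# of an (N, 2) array and multiply by the transpose, so each row p is
# mapped to weight @ p.
points = np.array([[-0.10, 0.80],
                   [ 0.50, 0.25],
                   [ 1.00, 0.00]])
outputs = points @ weight.T
print(outputs)
```

Storing points as rows is the usual layout when there are 1200 of them; the transpose just keeps each row’s result equal to `weight @ point`.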

We can’t directly visualize the weight matrix as a point in four-dimensional space. However, if we think of the input as an (x, y) coordinate and the output as an (x, y) coordinate, we can plot the input and output points. Visualizing the effect of the linear transformation is one way to get a feel for what it does.

Figure 21.2. Input point (-0.10, 0.80) goes to output point (-0.50, -0.78).

It’s hard to visualize what the transform is doing with just one input point and one output point. Fortunately, I made up 1200 points. Here are the first 10:

Table 21.1. The first ten of 1200 input points for our examples below.

Let me put all of them through the linear transformation and plot the original points and the transformed points.

Figure 21.3. The 1200 original points and 1200 transformed points.

Aha! The original is a chicken, and the output is a rotated (about the origin) and stretched version of the chicken. This isn’t a surprise, because multiplying each point in an image by a two-by-two matrix always rotates, scales, slants, or flips the image, or does some combination of those operations. In fact, when you use drawing software like Canva to resize, rotate, or move objects around, all of that is tracked with matrices describing how the original object is transformed. In our transformer blocks we might be dealing with a transformation from a point in 1280 dimensions to a point in 1280 dimensions, which we can’t visualize ourselves. But a creature who could see in 1280 dimensions might say, “Aha, the transform that produces the queries in the attention module is mostly doing a ‘rotation’ and a little bit of a ‘stretch.’”
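The rotate/stretch/flip description can be made exact with the singular value decomposition (SVD): any 2×2 matrix factors as `U @ diag(s) @ Vt`, where `U` and `Vt` each only rotate and/or flip, and `s` holds two stretch factors along perpendicular axes. Here is a small NumPy sketch with a hypothetical matrix (not the one in figure 21.1):

```python
import numpy as np

weight = np.array([[1.2, -0.5],
                   [0.3,  0.9]])   # hypothetical example matrix

# SVD: weight = U @ diag(s) @ Vt
U, s, Vt = np.linalg.svd(weight)

# U and Vt are orthogonal matrices: they only rotate and/or flip.
print(np.allclose(U.T @ U, np.eye(2)), np.allclose(Vt @ Vt.T, np.eye(2)))

# s holds the two stretch factors, largest first.
print("stretch factors:", s)

# A determinant of +1 means a pure rotation; -1 means a flip is included.
print(f"det(U) = {np.linalg.det(U):+.1f}, det(Vt) = {np.linalg.det(Vt):+.1f}")

# Rotate/flip with Vt, stretch by s, rotate/flip with U: the original matrix.
print(np.allclose(U @ np.diag(s) @ Vt, weight))
```

The same factorization exists for a 1280×1280 matrix; there are simply 1280 stretch factors instead of two.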

I called the orange chicken the target rather than the output in the plot above because I now want to build a model that is also a linear transform from 2D points to 2D points. It will start with a random weight matrix, and we’ll train it to match the target.

Figure 21.4. Model with a random initial weight. Our goal is to train it to match the linear transform shown above.

Let’s look at what the model does before any training:

Figure 21.5. Effect of the model before any training.

The initial model is shrinking and slightly rotating the chicken. To train, we’ll need a loss function. We’ll use the same one we used for the turkey feathers model, the hedgehog quills model, and the chicken feathers model: mean squared error. The only difference is that we’re dealing with pairs of numbers. I’ll illustrate the calculation for the first three points (which fall on the beak of the chicken):

Table 21.2. Calculating mean squared error with 2D points as our model output.

When the model output gets close to the target, the errors will come down and the loss will get close to zero.
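Table 21.2 isn’t reproduced in this text, so the averaging convention below is an assumption: square the x and y errors of every point and take the mean over all 2N squared errors. (Summing each point’s two squared errors and averaging over the N points would just double the result.) The three output/target pairs are made up for illustration.

```python
import numpy as np

def mse_2d(outputs, targets):
    """Mean squared error for (N, 2) arrays of 2D points.

    Assumption: the mean is taken over every coordinate error (2N of them),
    which is also what PyTorch's nn.MSELoss does with its default reduction.
    """
    errors = outputs - targets          # x error and y error for each point
    return np.mean(errors ** 2)

# Hypothetical model outputs and targets for three points.
outputs = np.array([[0.10,  0.20],
                    [0.30, -0.10],
                    [0.00,  0.50]])
targets = np.array([[0.20,  0.20],
                    [0.10,  0.10],
                    [0.00,  0.40]])
print(mse_2d(outputs, targets))
```

Here the six squared errors are 0.01, 0, 0.04, 0.04, 0, and 0.01, so the loss is 0.10 / 6. It reaches exactly zero only when every output lands on its target.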

I’ll now train for 500 steps with a learning rate of 0.2, using our classic way of updating the weight matrix: subtracting the gradient times the learning rate. Let me show you what the model does at three steps during training. (For a fun preview of what’s coming: when we train our transformer, we’ll periodically check how well it completes phrases like “The capital of France is _____.”)

Figure 21.6. Model after 10, 100, and 500 training steps.

It works! Let’s also look at the gradient at steps 0 (the start), 100, and 500. I mean the gradient, with respect to the weight, of the loss between the green and orange chickens; this is the gradient used to update the weight at that step. Besides showing the gradient matrix as numbers, I’ll also transform an image with the gradient. This may not be that helpful, but it provides another way to get a feel for the gradient. I’ll transform a friendly “G” centered at the origin.

Figure 21.7. Model and weight gradient at training steps 0, 100, and 500.

At every step, including the ones not shown, the gradient is reasonable. The weight gets pushed in the right direction to keep reducing the loss. By step 500, the gradient is small and the model’s weight matrix is close to the one I made up for the original linear transform shown in figure 21.1.
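The training loop is short enough to write out. This NumPy sketch uses hypothetical stand-ins: 1200 random points instead of the chicken, and a made-up target weight instead of the one in figure 21.1. It keeps the same learning rate (0.2), the same 500 steps, and the classic update. The gradient is worked out by hand with the chain rule rather than by an autograd library.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins: 1200 random points in place of the chicken and a
# made-up target weight in place of the one in figure 21.1.
X = rng.uniform(-1.0, 1.0, size=(1200, 2))
target_weight = np.array([[1.2, -0.5],
                          [0.3,  0.9]])
Y = X @ target_weight.T                       # the "orange chicken"

def loss_and_grad(weight):
    """MSE over all 2N coordinate errors, and its gradient w.r.t. weight."""
    errors = X @ weight.T - Y                 # (N, 2)
    loss = np.mean(errors ** 2)
    # Chain rule: d(loss)/d(errors) = 2 * errors / (2N) = errors / N,
    # and errors = X @ weight.T - Y, so d(loss)/d(weight) = errors.T @ X / N.
    grad = errors.T @ X / len(X)
    return loss, grad

weight = rng.normal(scale=0.3, size=(2, 2))   # random initial weight
learning_rate = 0.2

for step in range(501):
    loss, grad = loss_and_grad(weight)
    if step in (0, 100, 500):
        print(f"step {step:3d}  loss {loss:.2e}  gradient norm {np.linalg.norm(grad):.2e}")
    if step < 500:
        weight = weight - learning_rate * grad  # classic gradient descent update

print(np.round(weight, 4))
```

Because this loss is a smooth bowl in the four weight entries, a fixed learning rate steadily shrinks both the gradient and the distance to `target_weight`.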

In a bigger, deeper, more complicated model like a transformer, we won’t be so lucky. The loss surface might decline gently in some directions and drop off like a cliff in others. I’m going to change my loss function to get a feel for what it’s like to stand on this shakier ground. I’m going to do something horrible: raise one of the errors to the 20th power instead of squaring it. Here’s what I mean, calculating the loss for just the first three points as in table 21.2 above:

Table 21.3. Horrible loss function.

Think about the calculation for a moment. Even though it makes no sense to raise one of the errors to the 20th power, and even though it’s going to cause serious trouble, it is a valid loss function: the closer the model output gets to the target, the lower the loss. If I instead raised the errors to the 21st power, would it still be a valid loss function? Something to think about.

Here’s what happens when we try to train using the horrible loss function. My training is exactly the same as above in that I’m updating the weight by subtracting the gradient times the learning rate at each step.

Figure 21.8. Attempting to train via classic gradient descent with our horrible loss function.

It doesn’t work. “NaN” means “not a number.” The gradient is already out of control at the start, and within two steps it has exploded beyond the numbers the computer can store. This is similar to figure 19.6. There the explosion was due to too high a learning rate; here it’s due to the loss surface.
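Here is a sketch that reproduces the failure under stated assumptions. Table 21.3 isn’t reproduced here, so this assumes the x error is raised to the 20th power while the y error is still squared, with the mean taken over all 2N error terms. The data are hypothetical: points spread over [-3, 3] so that some initial errors are well above 1, which is what makes a 20th power explode (errors below 1 would shrink toward zero instead). The update is the same classic one with learning rate 0.2.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data: 1200 random points and a made-up target weight.
X = rng.uniform(-3.0, 3.0, size=(1200, 2))
target_weight = np.array([[1.5, -1.0],
                          [0.8,  1.2]])
Y = X @ target_weight.T

def horrible_loss_and_grad(weight):
    errors = X @ weight.T - Y
    ex, ey = errors[:, 0], errors[:, 1]
    n_terms = errors.size                      # 2N error terms
    # Assumption: x error to the 20th power, y error squared.
    loss = (np.sum(ex ** 20) + np.sum(ey ** 2)) / n_terms
    # d(loss)/d(errors), then the chain rule through errors = X @ weight.T - Y.
    d_errors = np.stack([20 * ex ** 19, 2 * ey], axis=1) / n_terms
    grad = d_errors.T @ X
    return loss, grad

weight = rng.normal(scale=0.3, size=(2, 2))    # random initial weight
learning_rate = 0.2

# Let overflow to inf and inf - inf = nan happen without warnings.
with np.errstate(over="ignore", invalid="ignore"):
    for step in range(5):
        loss, grad = horrible_loss_and_grad(weight)
        print(f"step {step}  loss {loss:.3e}  max|grad| {np.abs(grad).max():.3e}")
        weight = weight - learning_rate * grad   # classic gradient descent

print(weight)   # expected to contain inf and/or nan
```

An error of about 7 raised to the 19th power is already around 10^16, so the first step throws the weight absurdly far, and the next powers overflow float64.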

So, if taking leaps according to the gradient to explore a surface is hazardous, you might come up with this idea: I don’t trust the size of the gradient, so I’ll just use the gradient to tell me which direction to move in and then take safe, baby steps. Think about how that would work in the chicken feathers model from earlier, where we were dealing with a single weight. If the gradient is, say, positive 3, I know I need to make my weight smaller. Instead of subtracting 0.3 (a learning rate of 0.1 times the gradient of 3), I’ll subtract 0.1. If the gradient is negative, I’ll add 0.1 to my weight. I’ll use the sign of the gradient but ignore its magnitude.
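The sign-only rule (sometimes called sign gradient descent) is a one-liner. Here it is side by side with the classic update, for a hypothetical weight of 2.0 and the learning rate of 0.1 from the example:

```python
import numpy as np

learning_rate = 0.1
weight = 2.0   # hypothetical single weight, as in the chicken feathers model

for grad in (3.0, -3.0):
    classic = weight - learning_rate * grad              # step size grows with |grad|
    sign_only = weight - learning_rate * np.sign(grad)   # step size is always 0.1
    print(f"grad {grad:+.1f}: classic -> {classic:.2f}, sign-only -> {sign_only:.2f}")
```

With a gradient of +3 the classic step moves the weight to 1.70 while the sign-only step stops at 1.90; with -3 they give 2.30 and 2.10.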

Could we do something like this with our 2×2 weight and 2×2 gradient? We want to use the “direction” of the gradient matrix so we can update the weight matrix in the opposite “direction.” This should keep us from taking a huge leap off a cliff. Instead, we’ll take just a tiny, cautious step in that direction. For a scalar gradient like 3 or -3, we know how to convert it to a unit with the correct sign: divide by its absolute value, so 3 becomes 1 and -3 becomes -1.

Something analogous to this for a matrix is orthogonalizing it. If we think of a matrix as scaling, stretching, and rotating an image, the orthogonalized version rotates and flips by the same amount but doesn’t scale or stretch. Watch what happens when, during training, we update our weight by subtracting a small multiple of the orthogonalized gradient rather than the gradient itself. I’m using the same horrible loss function that threw us off within three steps above.

Figure 21.9. Training with our horrible loss function for 100 steps with weight updated by subtracting a small multiple of the orthogonalized gradient.

Every update is kept under control, and by step 100 we’ve hit the target. It’s incredible to me that this works. It doesn’t look like there’s enough information in the orthogonalized gradient, but clearly there is. The orthogonalized gradient is like someone guiding a truck to park in exactly the right place: “Back up, back some more, ease forward.” You can see that between steps 76 and 77 the guide reverses, the weight corrects, and at step 78 the guide reverses again.
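Here is the same hypothetical setup as the failing sketch, with the update swapped for a fixed-size step along the orthogonalized gradient. The orthogonalization is computed exactly with an SVD: keep `U @ Vt` and set every stretch factor to 1. The step size of 0.02 and the 300 steps are assumptions. The figure’s “small multiple” isn’t stated, and a longer run gives this rough stand-in data time to settle. The loss should fall by many orders of magnitude with no overflow.

```python
import numpy as np

def orthogonalize(g):
    """Return U @ Vt from the SVD g = U @ diag(s) @ Vt.

    Keeps the rotation/flip part of g and sets every stretch factor to 1.
    For a 1x1 matrix this reduces to the sign: [[3]] -> [[1]], [[-3]] -> [[-1]].
    """
    U, _, Vt = np.linalg.svd(g)
    return U @ Vt

rng = np.random.default_rng(0)

# Hypothetical data and assumed horrible loss (x error ** 20, y error ** 2).
X = rng.uniform(-3.0, 3.0, size=(1200, 2))
target_weight = np.array([[1.5, -1.0],
                          [0.8,  1.2]])
Y = X @ target_weight.T

def horrible_loss_and_grad(weight):
    errors = X @ weight.T - Y
    ex, ey = errors[:, 0], errors[:, 1]
    n_terms = errors.size
    loss = (np.sum(ex ** 20) + np.sum(ey ** 2)) / n_terms
    d_errors = np.stack([20 * ex ** 19, 2 * ey], axis=1) / n_terms
    return loss, d_errors.T @ X

weight = rng.normal(scale=0.3, size=(2, 2))   # random initial weight
step_size = 0.02                              # assumed "small multiple"

initial_loss, _ = horrible_loss_and_grad(weight)
for step in range(300):
    _, grad = horrible_loss_and_grad(weight)
    weight = weight - step_size * orthogonalize(grad)   # direction-only step
final_loss, _ = horrible_loss_and_grad(weight)

print(f"loss: {initial_loss:.3e} -> {final_loss:.3e}")
print(np.round(weight, 3))                    # should end up near target_weight
```

Every orthogonalized update has singular values of exactly 1, so no single step changes the transform by more than `step_size` in spectral norm (its largest stretch factor), however large the raw gradient is. That bound is what keeps the 20th-power loss from blowing up.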

I’m not claiming I understand it perfectly, and I’m sure I’m not explaining it perfectly. Using orthogonalized matrices, or approximations of them, is a powerful technique that appears to be robust to many of the problems Adam runs into, especially for the linear transformations in the transformer blocks. Muon, which combines an efficient way to calculate an approximate orthogonalization of the gradient with a few other techniques, has proven efficient and robust for transformer training over the last year or two.
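Muon computes its approximate orthogonalization with a Newton-Schulz iteration, which needs only matrix multiplications, the operation GPUs do fastest. Its published implementation uses a tuned fifth-degree (quintic) polynomial; the sketch below uses the simpler classic cubic iteration instead, on a hypothetical 2×2 gradient, and checks the result against the exact SVD answer.

```python
import numpy as np

def newton_schulz_orthogonalize(g, steps=30):
    """Approximate U @ Vt (the orthogonalized g) using only matrix multiplies.

    Classic cubic Newton-Schulz iteration: X <- 1.5 X - 0.5 X X^T X.
    Each singular value s of X maps to 1.5 s - 0.5 s**3, which pushes values
    in (0, sqrt(3)) toward 1 while leaving the rotations/flips unchanged.
    Dividing by the Frobenius norm first puts every singular value in (0, 1].
    """
    x = g / np.linalg.norm(g)      # Frobenius norm by default
    for _ in range(steps):
        x = 1.5 * x - 0.5 * x @ x.T @ x
    return x

g = np.array([[2.0, -1.0],
              [0.5,  1.5]])        # hypothetical gradient

approx = newton_schulz_orthogonalize(g)
U, _, Vt = np.linalg.svd(g)
exact = U @ Vt

print(np.abs(approx - exact).max())            # should be tiny
print(np.allclose(approx @ approx.T, np.eye(2)))
```

The cubic iteration multiplies a small singular value by only about 1.5 per step, so a gradient with a wide spread of singular values needs many iterations. A tuned polynomial with a steeper slope near zero speeds that up, at the cost of landing only approximately on 1.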