10
Deep and nonlinear models: hedgehogs and quills
I lied. We’re going to talk about hedgehogs before we crack open the transformer. I want to cover a few deep neural network concepts in a more straightforward model so there will be fewer new concepts inside the transformer.
In chapter 3 we created a model that took the height and length of a turkey and predicted the number of feathers. Suppose you land on an alien planet that happens to have a lot of hedgehogs. You go out and measure the length and count the number of quills on a thousand of them.
Here’s the data on the first 10 alien hedgehogs:
I had to make the data up because if we have human or robot teams on alien planets, they’ve prioritized other types of research, or kept the hedgehog stuff out of the public domain. Look at the table. Does a relationship between length and number of quills jump out at you? You would think that the longer the animal, the more quills it has, but even at a glance this doesn’t seem to be true.
We humans are excellent at spotting patterns and trends when data is graphed in two dimensions. We’re also reasonably good at seeing trends in three dimensions. Once we go beyond three, we need to take special measures to view data, for example making many plots, each with a different combination of two dimensions, or using math to reduce the number of dimensions. Fortunately, this example has only two dimensions: length and number of quills. That’s one fewer than the turkeys, where we dealt with height, length, and number of feathers.
Let’s plot all 1000 hedgehogs.
We can immediately see that there is a relationship between length and number of quills. It looks like hedgehogs around 1.3 meters long have over 5000 quills, ones that are 1.5 meters drop down to closer to 1000 quills, and 1.6 meter hedgehogs have around 4000 quills. It all seems pretty weird, but definitely not random. If you told me the length of a hedgehog not in this dataset, I bet I could make a pretty good prediction of how many quills it has.
That’s fine for a mental model. But how about a computer model? Our goal is:
One thing I didn’t worry about in the turkey example was holding back validation data. Here we’re going to be more careful. The idea is that if I use all of my data to train, I could end up with a model that seems incredible as evidenced by very low loss. However, maybe what’s really happening is the model has become overfit to the training data and won’t do a great job with new data. The whole point of the model is to predict things we don’t know.
The trick to not fooling yourself is separating your data into two sets: training and validation. We’ll use the training data during our training loop when we’re calculating loss, doing backprop, and updating weights. We’ll also calculate loss on validation data to tell us how we’re doing, but we’ll never use validation data to update weights.
So now say we’ve trained. Loss calculated on training data looks great, but loss calculated on validation data is much worse. This is a bad sign. It means our model has failed to capture some fundamental underlying pattern or truth. Instead it has memorized something about the specific training data that’s not important and possibly detrimental to learning the true pattern.
I mentioned above that we collected data on 1000 hedgehogs. I’m going to split this full dataset into 900 hedgehogs for training and 100 for validation.
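Here is a minimal sketch of that 900/100 split in Python. It assumes the measurements sit in a list of (length, quills) pairs; the function name train_val_split, the fixed seed, and the stand-in data are illustrative choices, not anything from the hedgehog dataset itself. Shuffling before splitting guards against the data having been recorded in some meaningful order.

```python
import random

def train_val_split(data, n_val, seed=0):
    """Shuffle a copy of data, then hold back n_val items for validation."""
    rng = random.Random(seed)      # fixed seed so the split is reproducible
    shuffled = list(data)          # copy so the caller's list keeps its order
    rng.shuffle(shuffled)
    n_train = len(shuffled) - n_val
    return shuffled[:n_train], shuffled[n_train:]

# Hypothetical stand-in for 1000 (length in meters, quills in thousands) measurements
hedgehogs = [(1.0 + i / 1000, 3.0) for i in range(1000)]
train, val = train_val_split(hedgehogs, n_val=100)
print(len(train), len(val))  # 900 100
```

Every hedgehog lands in exactly one of the two sets, so the validation loss is always computed on animals the model never trained on.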
How do we calculate loss? We’ll use squared errors just as in the turkey example, but I’m going to switch to taking the mean (average) rather than the sum. This doesn’t really change anything in terms of getting to a good model, but it does have a bunch of advantages.
The first is we can directly compare loss numbers across different sized batches of data. For example, we’ll want to compare loss between our training dataset with its 900 hedgehogs and our validation dataset with 100.
Second, it’s easier to have an intuition about the number. With our turkey model, the first loss we calculated was 1.3 million, as shown in table 3.3. It’s hard to know what to make of that unless you know how many turkeys went into the sum. For our hedgehog model, let’s say the loss is 4. What does that tell you? Well, we’re measuring quills in thousands, and 4 is 2 squared, so a mean squared error loss of 4 makes me think about predictions that are over or under by about 2000 quills. From figure 10.2 you can tell that being off by 2000 quills would be nothing to write home about. You might as well predict the average, which we’ll get to in a moment.
Finally, mean squared error or MSE is just more standard. It’s probably the most common loss function for models that predict continuous values like number of quills.
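To make the first advantage concrete, here is a small Python sketch comparing summed and mean squared error on two hypothetical batches, one the size of our training set and one the size of our validation set. Every prediction is deliberately off by exactly 2 (thousand quills), so only the batch size differs.

```python
def sum_squared_error(preds, targets):
    """Total of the squared differences, as in the turkey example."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets))

def mean_squared_error(preds, targets):
    """Average squared difference (MSE)."""
    return sum_squared_error(preds, targets) / len(targets)

# Hypothetical batches: every prediction is 2 (thousand quills) too high
big_preds, big_targets = [5.0] * 900, [3.0] * 900
small_preds, small_targets = [5.0] * 100, [3.0] * 100

print(sum_squared_error(big_preds, big_targets))       # 3600.0
print(sum_squared_error(small_preds, small_targets))   # 400.0
print(mean_squared_error(big_preds, big_targets))      # 4.0
print(mean_squared_error(small_preds, small_targets))  # 4.0
```

The sums differ by a factor of 9 purely because of batch size, while the two means agree, and 4 is exactly the "off by 2000 quills" loss discussed above.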
Time to build a model. Let’s start simple, even simpler than we started with the turkey example. The average number of quills of the hedgehogs in the training data is 3.03 (thousand). Let’s just ignore the input and always predict 3.03.
We can calculate loss:
Since the square root of 1.050 is 1.025, we’ll be off by an average of around 1025 quills, judging by our validation set. Let’s try to beat this with a model like the one we used to predict the number of turkey feathers.
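Model #1 and its validation loss fit in a few lines of Python. The data below is a tiny hypothetical set in thousands of quills, not the hedgehog data, chosen so the arithmetic is easy to check by hand. Taking the square root of the MSE converts the loss back into quill units, which is the same step used above to go from 1.050 to about 1025 quills.

```python
import math

def mean_squared_error(preds, targets):
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)

def fit_constant_model(train_targets):
    """Model #1: ignore the input and always predict the training mean."""
    mean = sum(train_targets) / len(train_targets)
    return lambda length: mean

# Hypothetical lengths (meters) and quills (thousands)
train_lengths, train_quills = [1.0, 1.2, 1.4, 1.6], [2.0, 4.0, 3.0, 3.0]
val_lengths, val_quills = [1.1, 1.5], [1.0, 5.0]

model = fit_constant_model(train_quills)   # always predicts 3.0
val_loss = mean_squared_error([model(x) for x in val_lengths], val_quills)
print(val_loss, math.sqrt(val_loss))       # 4.0 2.0
```

Note that the mean comes only from the training targets; the validation set is used to measure the loss, never to fit anything.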
In a way this is more straightforward than the turkey model because we have a single input. In the turkey model we multiplied height by a weight, length by a different weight, and added those together. We started with random weights and adjusted them through training until they were close to optimal.
We’ll do the same here except with just a single weight. I do, though, want to add what’s called a bias term. This is a number that gets added after the multiplication. So, using x to mean our input, here’s what the model looks like:
As an example, if w is 1 and b is 1, we would predict that a hedgehog of length 2 meters would have 3 (thousand) quills.
Before we train, I want to tighten up our terminology. In the turkey example we adjusted weights. In this model, I have a weight and a bias and I’ll be adjusting both of these. I might casually call everything a weight, but to avoid confusion, I’m going to use the term parameter. Our model has two parameters and we’ll be adjusting both of them during training to minimize loss.
I won’t show the training steps because they’re the same as the turkey example. Here’s the result:
Model #2 beat model #1! Validation loss is also lower than training loss so it doesn’t look like we overfit. Here are the parameters we learned:
And best of all, and unlike when we get into modeling all of human knowledge and reasoning with our transformer :), we can easily visualize exactly what our model is doing by plotting the predictions.
The good news is we’re making better predictions than before. The bad news is the model doesn’t seem all that great compared to what we humans could do. This is not a surprise because multiplying by a single weight and adding a bias can’t give us anything other than a line. In fact you might remember learning the line equation y = mx + b in middle school.
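For reference, here is a minimal gradient-descent sketch of fitting model #2’s two parameters by lowering MSE. It is an illustration under stated assumptions, not the exact procedure used to train model #2: it starts from zero instead of random values, uses a fixed learning rate of 0.1, and runs on made-up points that lie exactly on y = 2x + 1. Whatever data you feed it, the result is still just a line.

```python
def train_linear(xs, ys, lr=0.1, steps=2000):
    """Fit y = w*x + b by gradient descent on mean squared error."""
    w, b = 0.0, 0.0  # assumption: zero start (a real run would use random values)
    n = len(xs)
    for _ in range(steps):
        errors = [(w * x + b) - y for x, y in zip(xs, ys)]
        # Partial derivatives of mean((w*x + b - y)^2) with respect to w and b
        grad_w = sum(2 * e * x for e, x in zip(errors, xs)) / n
        grad_b = sum(2 * e for e in errors) / n
        # Step both parameters downhill
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b

# Hypothetical noise-free data lying exactly on y = 2x + 1
xs = [0.0, 0.5, 1.0, 1.5, 2.0]
ys = [2 * x + 1 for x in xs]
w, b = train_linear(xs, ys)
print(round(w, 4), round(b, 4))  # 2.0 1.0
```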
Before we go on, I want to redraw the model diagram. We talked in chapter 5 about how helpful it is to work at higher levels of abstraction, for example with matrices. But in figure 10.5 I went all the way back to scalar variables (x, w, and b). There’s a term for multiplying by weights and adding a bias: a linear transformation. Since that’s the only thing going on in model #2, I’m going to redraw my diagram using a linear module which I’ll show as a green box. You’ll be seeing many of these when we look inside the transformer in the next few chapters.
I wrote “1 → 1” in the linear module to indicate that we’re transforming from a single dimension (length) to a single dimension (number of quills). As a reminder from the turkey example, we can show this as a matrix. I know at this point it’s overkill.
Two thoughts pop into my mind as I think about how to build a better model. The first is that there’s this up and down relationship between length and number of quills, so no matter what, a single line isn’t going to be that great. I want to give the model more parameters to work with. The second is that I’ve heard of deep learning and deep neural networks and so I should try making my model deeper. This too might give it more room to learn the patterns in the data. Let’s try:
In this model I take a single input (like before) and feed it to my first linear transform. This first linear transform turns it into 10 numbers. Those 10 numbers go into the second linear transform which turns it back into one number. I drew 10 arrows between the layers here to emphasize the 10 numbers but I won’t do this going forward.
One nice thing about thinking at this higher level of abstraction is you don’t need to keep track of every single calculation. But since it’s our first time seeing a linear layer go from a small number to a bigger number, I’ll describe exactly what’s happening.
The first linear layer has 10 weights and 10 biases. Let’s call the input to this layer x, the length of a hedgehog. The 10 numbers coming out of the layer are w1(x) + b1, w2(x) + b2, etc. In other words, each number coming out is the input multiplied by a different weight and added to a different bias.
It’s easier to keep track of all this with matrices:
The second linear layer is similar to our turkey model because we take multiple numbers and reduce them to a single number. Think of this as mixing the inputs—an extreme version could be to multiply the first number by 1 and all the others by 0 indicating that we only want the information in the first number. If we call our inputs x1, x2, etc. then our single output number will be w1(x1) + w2(x2) + … + w10(x10) + b. The weights and bias here are just for this second layer and are not the same as the first layer.
Here’s how this looks with matrices. The output of the first linear layer is the input to the second linear layer.
Look how beautiful that is. Since you may still be getting used to working with matrices, let me also show the two linear layer calculations with actual numbers. I’ll use linear transformations of sizes 1 → 3 and 3 → 1 to keep things legible. I’ll also use random small integers (unrelated to our hedgehog data) for the weights and biases both to keep it legible and so you can repeat the calculations.
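The same 1 → 3 → 1 calculation can be written in a few lines of Python, using the row-vector convention from above (input times weight matrix, plus bias). The weights and biases here are hypothetical small integers, not the values in figure 10.12, so you can check each number by hand.

```python
def linear(x, W, b):
    """Row vector x (length n_in) times W (n_in x n_out), plus bias b (length n_out)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]

# Hypothetical parameters
W1 = [[2, -1, 3]]        # 1 x 3 weight matrix
b1 = [1, 0, -2]          # 1 x 3 bias
W2 = [[1], [4], [-1]]    # 3 x 1 weight matrix
b2 = [5]                 # 1 x 1 bias

x = [2]                          # a single input
hidden = linear(x, W1, b1)       # 1 -> 3: [2*2+1, 2*(-1)+0, 2*3-2]
output = linear(hidden, W2, b2)  # 3 -> 1: 5*1 + (-2)*4 + 4*(-1) + 5
print(hidden, output)            # [5, -2, 4] [-2]
```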
Now back to our hedgehog model #3, let’s count the total parameters. We have 10 weights and 10 biases for the first linear layer plus 10 weights and 1 bias for the second linear layer for a total of 31 parameters. There’s much more room for the model to learn. Let’s go beat model #2. Again, you can assume I trained the model properly and I’m going to skip right to showing the loss with the resulting model.
Well, that’s not good. Model #3 is very close to model #2 and actually slightly worse. Let’s look at the plot:
Another line! What’s going on? Let me illustrate by pretending our first linear layer is 1 → 2 and our second linear layer is 2 → 1 to keep the numbers small. The outputs of the first layer will be w1x + b1 and w2x + b2. These outputs are the inputs to the second layer, which means the output of the second layer will be w3(w1x + b1) + w4(w2x + b2) + b3. Multiply out and combine terms:
w3w1x + w3b1 + w4w2x + w4b2 + b3
(w3w1 + w4w2)x + w3b1 + w4b2 + b3
Think of w3w1 + w4w2 as a weight and w3b1 + w4b2 + b3 as a bias…and…oh no! We’re back to our line equation with a single weight and a single bias.
wx + b
Another way to see this is to look at figure 10.12. All those nice matrix calculations really just amount to multiplying the input by 85 and adding 87. Don’t believe me? Plug in a few different input numbers.
So all that fancy work with a two-layer network and 29 extra parameters was pointless. No matter how many linear transforms we do, and no matter how many parameters those linear transforms have, if all we do is feed the output of one transform into the input of the next, the whole thing can be simplified to a single linear layer.
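The collapse is easy to check numerically. This sketch builds a 1 → 2 → 1 network with no nonlinearity, using the parameter names from the derivation above and arbitrary illustrative values, then confirms that it matches the single line (w3w1 + w4w2)x + (w3b1 + w4b2 + b3) at several inputs.

```python
import math

# Arbitrary illustrative parameters for a 1 -> 2 -> 1 network with no ReLU
w1, b1, w2, b2 = 3.0, -1.0, 0.5, 2.0   # first linear layer
w3, w4, b3 = 2.0, -4.0, 1.0            # second linear layer

def two_layers(x):
    h1 = w1 * x + b1
    h2 = w2 * x + b2
    return w3 * h1 + w4 * h2 + b3

# Single weight and bias from multiplying out and combining terms
w = w3 * w1 + w4 * w2
b = w3 * b1 + w4 * b2 + b3

for x in [-2.0, 0.0, 1.5, 10.0]:
    assert math.isclose(two_layers(x), w * x + b)  # same line either way
print(w, b)  # 4.0 -9.0
```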
We need something else. When you see what it is you’re going to think it’s too simple, almost a trick. Oddly it has a complicated name. To give you an intuition, suppose we wanted to create a model with two linear layers. The first layer goes from 1 → 2 and the second from 2 → 1. The purpose of the model is to predict y from x in this plot:
Even though it looks easy, you already know from the argument above that our model is going to be a line. This means at best we could get the right side correct or the left side correct, but not both. We need a way to tell our model to treat x < 0 differently than x > 0.
Let’s say we add one more operation that we can sandwich between the two linear layers: taking the max of 0 and our input. If our input is negative, it becomes 0. If it’s positive, no change. We’ll get to whether it’s useful in a moment, but you can clearly see it’s something other than a linear transformation. No choice of m and b makes y = mx + b equal to the maximum of x and 0. Try to think of values for m and b in y = mx + b that will always give you x or 0, whichever is higher. It’s impossible.
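Here is that operation in Python, along with the impossibility argument spelled out: if it were a line y = mx + b, the points at x = 0 and x = 1 would pin down m and b, and that line then gets x = -1 wrong.

```python
def relu(x):
    """Take the max of 0 and the input: negatives become 0, positives pass through."""
    return max(0.0, x)

print([relu(x) for x in [-3.0, -0.5, 0.0, 2.0]])  # [0.0, 0.0, 0.0, 2.0]

# If relu were a line m*x + b, two points would fix m and b...
b = relu(0.0)        # b = 0.0
m = relu(1.0) - b    # m = 1.0
# ...but that line predicts -1.0 at x = -1, while relu gives 0.0
print(m * -1.0 + b, relu(-1.0))  # -1.0 0.0
```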
Here’s a diagram with the model I just described and variable names for the parameters in each linear layer. Before reading past the diagram, see if you can come up with parameter values that will make this model match the “data” shown in figure 10.14. Or put another way, come up with the parameters that will make this model match the function y = |x| (y = absolute value of x). This is a fun one to play with so don’t peek ahead.
One answer is to set w2 to -1 and all the other weights to 1, and set all three biases to 0. Let’s input 5 and -5 as examples:
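A short Python sketch can run those two inputs, and many more, through the network. It assumes the naming from the earlier 1 → 2 → 1 derivation (w1, b1, w2, b2 in the first layer; w3, w4, b3 in the second); if a diagram labels the parameters differently, the values move but the idea is the same.

```python
def relu(x):
    return max(0.0, x)

def tiny_net(x, w1, b1, w2, b2, w3, w4, b3):
    """1 -> 2 linear layer, ReLU, then 2 -> 1 linear layer."""
    h1 = relu(w1 * x + b1)   # passes positive x, zeroes negative x
    h2 = relu(w2 * x + b2)   # with w2 = -1: passes negative x (flipped), zeroes positive x
    return w3 * h1 + w4 * h2 + b3

params = dict(w1=1.0, b1=0.0, w2=-1.0, b2=0.0, w3=1.0, w4=1.0, b3=0.0)
for x in [5.0, -5.0]:
    print(x, tiny_net(x, **params))  # prints "5.0 5.0" then "-5.0 5.0"
```

Each hidden unit handles one side of zero, and the second layer simply adds them back together.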
This function, which takes the max of its input and zero, has the fancy name Rectified Linear Unit, or ReLU, and I’ll be calling it that going forward. With this new tool, we give the model a way to introduce bends: points where the slope suddenly changes. If it can match the “V” shape above, perhaps it can match more complicated data with lines of different slopes and intercepts in different regions. Let’s try!
Unlike in figure 10.9, I’m now only showing a single arrow connecting the output of one layer to the input of the next. I did add the dimensions of the tensors in purple as a reminder of the shape of the numbers that will flow through the model. I trained it and:
Our loss came down. That’s a good sign. Let’s look at the plot.
Yes! For the first time our model is not just a line; in other words, it’s nonlinear. We’re on the right path, but we’re still not giving the model much room to learn the pattern. Let’s introduce one more layer and a lot more parameters.
Before reading on, count the number of parameters. The answer: 10 weights and 10 biases for the first linear layer, 100 weights and 10 biases for the second linear layer, and 10 weights and 1 bias for the third linear layer. That’s 20 + 110 + 11 = 141. (If you’re not sure why we need 100 weights and 10 biases for the second linear layer, think of it like this: each of the ten numbers in the output is its own linear combination of the 10 input numbers. Or you can think about the weight and bias matrices inside the layer. We multiply a 1×10 input by a 10×10 weight matrix to get a 1×10 result, and add that to a 1×10 bias matrix to get our final 1×10 output.)
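The same parameter count can be checked in code. This sketch builds a stack of linear layers with ReLU between them (but not after the last one) and counts weights plus biases. The uniform random initialization and the function names are illustrative assumptions, not how any particular framework initializes layers; the forward pass is untrained, so its output is just some number.

```python
import random

def init_mlp(sizes, seed=0):
    """Create (weights, biases) per linear layer, e.g. sizes = [1, 10, 10, 1]."""
    rng = random.Random(seed)
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        W = [[rng.uniform(-1, 1) for _ in range(n_out)] for _ in range(n_in)]  # n_in x n_out
        b = [rng.uniform(-1, 1) for _ in range(n_out)]                          # n_out biases
        layers.append((W, b))
    return layers

def count_params(layers):
    return sum(len(W) * len(W[0]) + len(b) for W, b in layers)

def forward(layers, x):
    """Linear -> ReLU -> ... -> Linear, with no ReLU after the final layer."""
    h = [x]
    for k, (W, b) in enumerate(layers):
        h = [sum(h[i] * W[i][j] for i in range(len(h))) + b[j] for j in range(len(b))]
        if k < len(layers) - 1:
            h = [max(0.0, v) for v in h]
    return h[0]

print(count_params(init_mlp([1, 10, 1])))        # 31, same count as model #3
print(count_params(init_mlp([1, 10, 10, 1])))    # 141
print(count_params(init_mlp([1, 100, 100, 1])))  # 10401
print(forward(init_mlp([1, 10, 10, 1]), 1.6))    # untrained prediction for a 1.6 m hedgehog
```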
Enough talking. Let’s train it.
Huge improvement! The square root of 0.153 is around 0.4 which means on average our predictions are off by around 400 quills. Here’s the plot.
What if we go from 1 to 100 instead of 1 to 10? This could give the model even more room to match all the turns in the data. We could also be entering dangerous territory where the model begins to miss the forest for the trees and starts memorizing every single data point. We’ll have 100×100 = 10,000 weights in the middle linear layer, far more than our 900 hedgehogs in the training dataset.
Here’s the loss:
And the plot:
It looks good. It appears to have picked up the trend. It’s modeling the big pattern without getting caught up in bits of what appear to be random variation or sloppy mistakes by our team of people in spacesuits measuring the hedgehogs with wooden meter sticks. One nice thing about our input being just a single number in a constrained range (hedgehog length) is that we don’t need to limit ourselves to looking at predictions for the actual hedgehogs. We can look at predictions for all inputs in the range. That’s worth doing here because it makes overfitting a little easier to spot.
It’s beautiful. We don’t know why there’s such an odd relationship between length and number of quills on these alien hedgehogs, but the computer can now make a good prediction. The hedgehog is 1.6 meters long? The model spits out 3.9 (thousand quills).
With our hedgehogs you’ve now seen a few concepts that will come up again as we get into the transformer. A deep neural network is one with multiple layers. Usually the layers between the input and output layers are called hidden layers. They aren’t hidden in the way things in our brain are hidden—it’s easy enough to tap into them and see what parameters were learned and watch numbers flow through them. They’re only hidden in the sense that they’re inside the model, and if you view the model as a black box, you don’t need to think about them.
You’ve also seen two types of functions aka layers aka modules: linear transformations and ReLU. These, or other similar functions, are what neural networks are built of and what lets them be so good at modeling patterns in data. We’ll see both in the transformer.
I covered splitting data into a training set and a validation set and showed how and why we calculate loss on both.
Another thing, a little more subtle, is the idea of finding a model architecture suitable for the task. A linear model like model #2 would have worked fine for our turkey example but was not especially suited for the alien hedgehogs, whereas models #5 and #6 fit well. The trick is to find an architecture that fits the problem and is trainable, and then let backpropagation work its magic.
Finally, this is for the reader who says, “Wait, there are no alien hedgehogs. You generated that data, so how does model #6 compare to the ideal?” You’re right. I made up a function and then added random noise. Let’s compare:
Pretty good! It’s not surprising that model #6 is worse in the tails. There aren’t many hedgehogs in our dataset that are less than 0.9 meters or more than 2.1 meters. There’s less data to learn from and being off in those regions does not affect the loss as much as in the regions where there’s lots of data.
Onward and upward. To the transformer!