So, linear regression isn't actually all that simple. When I say it's "a moron's machine learning" I more mean that it's very frequently used by people and companies who are pretending that a much more complex type of machine learning is happening.
I actually first encountered linear regression in this context, as a joke made by colleagues of mine who were making fun of a dubious "artificial intelligence" solution in the cyber security industry. I decided to take a dive into linear regression to see how simple it actually was, and whether or not it was something I could learn easily enough.
After all, if I can learn it, it can't be that complicated...
This is another installment of my series on learning mathematics and machine learning. You can view the last article in the series here.
Before we get into it, this blog is based on a tutorial by NeuralNine on YouTube, who did a great breakdown of the mathematical theory of linear regression before writing the code that I will use for the rest of this article. Seriously, go check out his video, it rocks.
What are the goals of linear regression?
In the picture above, you'll see a fairly simple scatter plot for some random data. You'll intuitively notice that the graph trends upwards: on average, the y-values of the data points increase as x increases. We can think of this trend as a line passing through the cloud of data points, as shown below.
Now, if we wanted to use this line for predictions or to describe the data, we would want something a bit better than a line just randomly drawn through the data. We also might want a way to reliably generate that line with any given dataset. In general, how would we construct that line most usefully?
Well, we would want to draw the line so that its total distance to the data points is as small as possible. There will obviously be outliers in our data where the distance between the outlier point and our line is relatively large, but in general we want the line's value y at each x to be as close as possible to the actual data value y_actual at that x. We can call this minimizing error.
Now, there is a fairly simple formula for calculating the error between a given point on a line and corresponding data at that point.
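To make the idea concrete, here is a minimal sketch of that error formula, the mean squared error, in Python. The function and variable names here are my own, not taken from the tutorial: for a candidate line y = m*x + b, we average the squared vertical distance between each data point and the line.

```python
def mean_squared_error(m, b, points):
    """Average squared vertical distance between each point and the line y = m*x + b."""
    total = 0.0
    for x, y_actual in points:
        y_predicted = m * x + b
        total += (y_actual - y_predicted) ** 2
    return total / len(points)

# Points that lie exactly on y = 2x + 1 give zero error with m=2, b=1.
data = [(0, 1), (1, 3), (2, 5)]
print(mean_squared_error(2, 1, data))  # 0.0
print(mean_squared_error(0, 0, data))  # much larger, since this line misses the data
```

Squaring the differences keeps positive and negative errors from cancelling out, and it penalizes large misses more heavily than small ones.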
A crash course in partial derivatives
First off, why do we care about the derivative of our mean squared error algorithm?
Well, our mean squared error function calculates the average error for a given m and b value. If we were to calculate the mean squared error across a range of m and b values, the plot of E, our mean squared error, would look like a parabola: E starts at a large number at the lower end of our range of m and b values, decreases toward a minimum, and then increases again to a much larger number at the other end of the range. What we are interested in is the region where E is as close to zero as possible.
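You can see this parabola numerically with a quick sketch (my own example, not from the tutorial): hold b fixed, sweep m across a range, and watch the mean squared error fall and then rise again.

```python
def mse(m, b, points):
    """Mean squared error of the line y = m*x + b against the data."""
    return sum((y - (m * x + b)) ** 2 for x, y in points) / len(points)

data = [(1, 3), (2, 5), (3, 7)]  # points that lie exactly on y = 2x + 1
errors = [mse(m, 1, data) for m in [0, 1, 2, 3, 4]]
print(errors)  # smallest at m = 2, growing on either side of it
```

The error bottoms out at the true slope m = 2 and climbs symmetrically on either side, which is exactly the parabola shape described above.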
The derivative of a given function gives the value of its slope at any given point. What we are looking for is a place where the slope is zero, which would be the bottom of the curve where the mean squared error is at its minimum.
The process of calculating a derivative of a single-variable function, something like y=7x+12, is fairly straightforward... The problem that we have in our use case here is that we have multiple variables, m and b. We will calculate partial derivatives of the function with respect to m and b separately, and use those functions to find our ideal m and b values in the process described below.
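For reference, these are the standard partial derivatives of the mean squared error for a line y = mx + b (with n data points), which match the update rules used later:

```latex
E = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - (m x_i + b)\right)^2

\frac{\partial E}{\partial m} = -\frac{2}{n}\sum_{i=1}^{n} x_i\left(y_i - (m x_i + b)\right)

\frac{\partial E}{\partial b} = -\frac{2}{n}\sum_{i=1}^{n}\left(y_i - (m x_i + b)\right)
```

Each one treats the other variable as a constant: the derivative with respect to m holds b fixed, and vice versa.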
The mathematics of linear regression
Alright, now let's talk even more math.
So now things have gotten a bit more complicated. We now have two separate equations to optimize on: the derivative of our mean squared error function with respect to m and the derivative of our mean squared error function with respect to b. We want to optimize on both to find the m and b pair that create the lowest, or at least an incredibly low, mean squared error with our input data.
Let's go back to our visualization of the mean squared error parabola.
We can go ahead and ignore half of the parabola: once the slope of the tangent line becomes positive, increasing m or b only makes the error worse, so we aren't interested in those m or b values anymore. We're only interested in the "left" half of the parabola, the part where increasing m and b decreases the error, up until the point that the tangent line's slope is zero.
What linear regression (via gradient descent) seeks to do is slowly adjust the values of m and b from some random starting point, nudging them in the direction that reduces the error, with the size of each nudge controlled by a parameter known as the learning rate.
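Putting it all together, here is a minimal gradient-descent sketch for fitting y = m*x + b. This is in the spirit of NeuralNine's code but not a verbatim copy: the function names, learning rate, and iteration count are my own choices.

```python
def gradient_step(m, b, points, learning_rate):
    """Take one step opposite the gradient of the mean squared error."""
    n = len(points)
    dm = 0.0
    db = 0.0
    for x, y in points:
        error = y - (m * x + b)
        dm += -(2 / n) * x * error  # partial derivative of E with respect to m
        db += -(2 / n) * error      # partial derivative of E with respect to b
    # Step against the gradient: downhill on the error parabola.
    return m - learning_rate * dm, b - learning_rate * db

# Fit a line to points that lie exactly on y = 2x + 1.
data = [(1, 3), (2, 5), (3, 7)]
m, b = 0.0, 0.0  # arbitrary starting point
for _ in range(10_000):
    m, b = gradient_step(m, b, data, learning_rate=0.01)
print(m, b)  # approaches m = 2, b = 1
```

Each step moves m and b a small distance downhill on the error surface; the learning rate trades off speed against stability. Too large and the updates overshoot the minimum and diverge; too small and convergence takes a very long time.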