From the course: Data Science for Java Developers

K-nearest neighbor basics - Java Tutorial

- [Instructor] Okay, so we've spent quite a bit of time building up to this point, and now we're actually going to start building some models for our datasets. The first kind of model, the first classifier we're going to look at, is called the K-Nearest-Neighbor algorithm. K-Nearest-Neighbor is a way of classifying unknown points based on the known points they're closest to. In the graphic over here to the left, you can see an example where we have two categories, category A and category B, and we want to classify an unknown data point, that is, a data point whose category we don't know. What the K-Nearest-Neighbor algorithm asks is: which K points, K being just some constant like one, three, five, or seven (they're usually odd, and I'll talk about that in a second), is the new point closest to? In other words, if we have a new data point like the one in the picture here, we figure out which of the points we already know about are closest to it, and we label the new point as part of whichever group those closest points belong to. It's a fairly straightforward algorithm, and we're going to implement it in Java in just a minute.

First, just to formalize things a little bit, let's talk about the steps involved in performing a KNN classification. This is actually pretty simple because there are only two steps. Step one is to calculate the distance between whatever new data point we're looking at and the rest of the points in the dataset; in other words, how far away the new point is from each of the points we already know about. Step two is, based on that information, to pick the nearest points and let them vote on what the classification for the new point should be. So if we're trying to label a new point as either red or green, just to make up some labels we might be working with, and we're using K equal to three, we take the three closest points and see what the majority of them are. If two of them are green and one is red, we label the new point green; if all three are red, we label it red. The same idea applies for any other value of K. So again, K is just the number of closest neighbors we take into account in the final classification. That's simple enough.
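To make those two steps concrete, here's a rough sketch of what they might look like in Java. The Point record, the distance helper, and the classify method are illustrative names made up for this sketch, not anything from a particular library, and Euclidean distance is just one reasonable choice of distance measure:

```java
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class KnnSketch {

    // A labeled training point: a feature vector plus the category it belongs to
    // (a Java 16+ record, used here purely for brevity).
    record Point(double[] features, String label) {}

    // Step 1: Euclidean distance between two feature vectors.
    static double distance(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    // Step 2: find the k closest known points and let them vote by simple majority.
    static String classify(List<Point> known, double[] unknown, int k) {
        return known.stream()
                // sort the known points by how far they are from the unknown one
                .sorted(Comparator.comparingDouble((Point p) -> distance(p.features(), unknown)))
                // keep only the k nearest neighbors
                .limit(k)
                // count how many of those neighbors carry each label
                .collect(Collectors.groupingBy(Point::label, Collectors.counting()))
                // the label with the most votes wins
                .entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElseThrow();
    }
}
```

With K equal to three, for example, classify returns whichever label at least two of the three nearest neighbors share. Notice that the whole dataset gets scanned on every call; that's exactly the "all the work happens at classification time" behavior discussed next.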
Just a few things to keep in mind about KNN before we get started. The first is that there's very little actual up-front learning involved. We talked earlier about building models for data science, and those of you who are familiar with things like neural networks will know that some models involve quite a bit of up-front learning; it takes a lot of computational power just to build the model. With KNN there is very little, if any, up-front learning, because all of the computation happens at the moment we want to classify a new point: that new point determines which known points are closest, and therefore what the ultimate classification will be.

Another thing to keep in mind about K-Nearest-Neighbor is that K is usually an odd number. More often than not that's just to help break ties, although strictly speaking it doesn't have to be odd. What we really want is a value of K that doesn't divide evenly into the number of different classes. So if we have three groups, we generally won't want to pick K as a multiple of three, because the three nearest neighbors could, in theory, be one from the first group, one from the second group, and one from the third group, and then we'd have no idea which one to pick.

That said, there are some variations on KNN that can help with situations like the one I just mentioned. One of them is called weighted KNN. Instead of giving each of the K closest points an equal vote, you factor the distance itself into each point's weight. This is a good way to break ties among the nearest points: if there are two possible classifications and one of the candidate neighbors is much farther away than the other, you'll probably want to give the closer point's value more of a say. All right, so that's weighted KNN, and there are plenty of other variations as well.
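Sticking with the hypothetical Point record and distance helper from the sketch above, a distance-weighted vote might look something like this; the 1/distance weighting is one common choice, not the only one:

```java
// A possible weighted-KNN variant, sitting alongside classify in the KnnSketch class above:
// each of the k nearest neighbors votes with weight 1 / distance, so closer neighbors
// count for more and exact ties become far less likely.
static String classifyWeighted(List<Point> known, double[] unknown, int k) {
    return known.stream()
            .sorted(Comparator.comparingDouble((Point p) -> distance(p.features(), unknown)))
            .limit(k)
            .collect(Collectors.groupingBy(
                    Point::label,
                    // sum of 1/distance per label; the small constant avoids
                    // division by zero when a neighbor matches the new point exactly
                    Collectors.summingDouble((Point p) ->
                            1.0 / (distance(p.features(), unknown) + 1e-9))))
            .entrySet().stream()
            .max(Map.Entry.comparingByValue())
            .map(Map.Entry::getKey)
            .orElseThrow();
}
```

Everything up to limit(k) is identical to the unweighted version; only the voting step changes.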
One more thing to keep in mind about KNN: although there's very little up-front learning involved, the flip side is that KNN can be very slow on very large datasets in production, because, again, KNN pushes all of its computation into the classification step itself. With a neural network, for example, once you've trained it, classifying a given item goes quite fast, whereas with KNN all of that computation happens at classification time. So if you have several gigabytes of data or more, and you have to go through all of it and calculate a distance for every data point each time you classify something, that's obviously not ideal. It's just something to keep in mind when working with KNN.

And one last thing: we're going to build our model here in just a second on something called the Iris dataset. I'll go into a little more detail on exactly what that is, but as a general summary, we're going to build a classifier that can identify the species of a flower based on measurements of certain parts of that flower. To give you a graphical representation of what this will look like, the different points here, the red ones, the yellow ones, and the green ones, are the different species of flowers visualized on a scatter-plot; we've seen how to do that before. The little blue X in between the red and yellow groups might be some flower that we're trying to classify. What we'll do with KNN, and what you're going to see when we build our classifier in just a minute, is find which points are closest to that point and let those closest points vote on the final classification for the new point. Note also that this graph only displays two dimensions of the actual dataset. The actual dataset is four-dimensional, meaning each data point has four different variables that can change, so this isn't a 100% accurate visualization, but I think it gets the point across. Basically, what we're trying to do is see which data points are closest to a given new point, which, in our case, is represented by that blue X. So without further ado, let's get started building our classifier.
