Update comments and KNN example

master
Taylor Smith 2018-07-19 11:48:06 -05:00
parent 083f6b679f
commit 13fa08cf6f
3 changed files with 15 additions and 3 deletions


@@ -5,22 +5,27 @@ from __future__ import absolute_import
from packtml.clustering import KNNClassifier
from packtml.utils.plotting import add_decision_boundary_to_axis
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_iris
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
import numpy as np
import sys
# #############################################################################
# Create a classification sub-dataset using iris
iris = load_iris()
X = iris.data[:, :2]
X = iris.data[:, :2] # just use the first two dimensions
y = iris.target
# split data
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
# scale the data
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
# #############################################################################
# Fit a k-nearest neighbor model and get predictions
k=10
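
The hunk ends at k=10; the rest of the example (not shown in this diff) fits the classifier on the scaled training split and evaluates it. As a rough, self-contained sketch of that downstream step, with scikit-learn's KNeighborsClassifier standing in because the KNNClassifier constructor signature does not appear in this hunk:

# Sketch only: scikit-learn's KNeighborsClassifier stands in for
# packtml's KNNClassifier, whose constructor is not shown in this diff.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

iris = load_iris()
X, y = iris.data[:, :2], iris.target  # first two features, as in the example

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Scale the same way the updated example does: fit on train, transform test
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

clf = KNeighborsClassifier(n_neighbors=10).fit(X_train, y_train)
print("Test accuracy: %.3f" % accuracy_score(y_test, clf.predict(X_test)))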

Binary file not shown (22 KiB before, 21 KiB after).


@@ -80,7 +80,8 @@ class KNNClassifier(BaseSimpleEstimator):
# Compute the pairwise distances between each observation in
# the dataset and the training data. This can be relatively expensive
# for very large datasets!!
dists = euclidean_distances(X, self.X)
train = self.X
dists = euclidean_distances(X, train)
# Arg sort to find the shortest distance for each row. This sorts
# elements in each row (independent of other rows) to determine the
@@ -93,7 +94,13 @@ class KNNClassifier(BaseSimpleEstimator):
nearest = np.argsort(dists, axis=1)
# We only care about the top K, really, so get sorted and then truncate
# I.e:
# array([[1, 2, 1],
#        ...
#        [0, 0, 0]])
predicted_labels = self.y[nearest][:, :self.k]
# We want the most common along the rows as the predictions
# I.e:
# array([1, ..., 0])
return mode(predicted_labels, axis=-1)[0].ravel()
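
Taken together, the comments describe the full prediction pipeline: pairwise distances, a per-row argsort, truncation to the k nearest neighbors, and a row-wise mode over their labels. A tiny standalone sketch of that same sequence on toy data (all names below are illustrative, not from the library):

import numpy as np
from scipy.stats import mode
from sklearn.metrics.pairwise import euclidean_distances

# Toy "training" set: four points, two classes
X_train = np.array([[0., 0.], [0., 1.], [5., 5.], [5., 6.]])
y_train = np.array([0, 0, 1, 1])
k = 3

# Two query points to classify
X_query = np.array([[0.2, 0.5], [5.1, 5.4]])

# Mirror the steps above: distances -> argsort -> top-k labels -> row-wise mode
dists = euclidean_distances(X_query, X_train)   # shape (n_query, n_train)
nearest = np.argsort(dists, axis=1)             # neighbor indices, closest first
top_k_labels = y_train[nearest][:, :k]          # labels of the k closest neighbors
preds = mode(top_k_labels, axis=-1)[0].ravel()  # most common label per row
print(preds)                                    # [0 1]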