diff --git a/examples/clustering/example_knn_classifier.py b/examples/clustering/example_knn_classifier.py
index 239ff6c..1e81691 100644
--- a/examples/clustering/example_knn_classifier.py
+++ b/examples/clustering/example_knn_classifier.py
@@ -5,22 +5,27 @@ from __future__ import absolute_import
 from packtml.clustering import KNNClassifier
 from packtml.utils.plotting import add_decision_boundary_to_axis
 from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler
 from sklearn.metrics import accuracy_score
 from sklearn.datasets import load_iris
 from matplotlib import pyplot as plt
 from matplotlib.colors import ListedColormap
-import numpy as np
 import sys
 
 # #############################################################################
 # Create a classification sub-dataset using iris
 iris = load_iris()
-X = iris.data[:, :2]
+X = iris.data[:, :2]  # just use the first two dimensions
 y = iris.target
 
 # split data
 X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
 
+# scale the data
+scaler = StandardScaler()
+X_train = scaler.fit_transform(X_train)
+X_test = scaler.transform(X_test)
+
 # #############################################################################
 # Fit a k-nearest neighbor model and get predictions
 k=10
diff --git a/img/clustering/example_knn_classifier.png b/img/clustering/example_knn_classifier.png
index 6ce10cb..d59ba2f 100644
Binary files a/img/clustering/example_knn_classifier.png and b/img/clustering/example_knn_classifier.png differ
diff --git a/packtml/clustering/knn.py b/packtml/clustering/knn.py
index aa1801c..34e0719 100644
--- a/packtml/clustering/knn.py
+++ b/packtml/clustering/knn.py
@@ -80,7 +80,8 @@ class KNNClassifier(BaseSimpleEstimator):
         # Compute the pairwise distances between each observation in
         # the dataset and the training data. This can be relatively expensive
         # for very large datasets!!
-        dists = euclidean_distances(X, self.X)
+        train = self.X
+        dists = euclidean_distances(X, train)
 
         # Arg sort to find the shortest distance for each row. This sorts
         # elements in each row (independent of other rows) to determine the
@@ -93,7 +94,13 @@ class KNNClassifier(BaseSimpleEstimator):
         nearest = np.argsort(dists, axis=1)
 
         # We only care about the top K, really, so get sorted and then truncate
+        # I.e:
+        # array([[1, 2, 1],
+        #        ...
+        #        [0, 0, 0]])
         predicted_labels = self.y[nearest][:, :self.k]
 
         # We want the most common along the rows as the predictions
+        # I.e:
+        # array([1, ..., 0])
         return mode(predicted_labels, axis=-1)[0].ravel()
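
Note (not part of the patch): the new inline comments in predict() describe the intermediate arrays produced by the argsort / label-lookup / mode chain. Below is a minimal standalone sketch of that same chain on made-up toy data; X_train, y_train, X_test, and k = 3 are illustrative values, not taken from the repo. It uses the same euclidean_distances (sklearn.metrics.pairwise) and mode (scipy.stats) calls the method relies on.

    import numpy as np
    from scipy.stats import mode
    from sklearn.metrics.pairwise import euclidean_distances

    # Toy training set: 5 points, 2 features, labels in {0, 1, 2} (illustrative only)
    X_train = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.0, 0.9], [0.9, 1.1]])
    y_train = np.array([0, 0, 1, 1, 2])

    # Two query points, k = 3 neighbors
    X_test = np.array([[0.05, 0.05], [1.0, 1.0]])
    k = 3

    dists = euclidean_distances(X_test, X_train)        # pairwise distances, shape (2, 5)
    nearest = np.argsort(dists, axis=1)                 # train indices sorted by distance, per row
    predicted_labels = y_train[nearest][:, :k]          # labels of the k closest, e.g. [[0, 0, 1], [1, 1, 2]]
    preds = mode(predicted_labels, axis=-1)[0].ravel()  # majority vote per row -> array([0, 1])
    print(preds)

On the example-script side, the StandardScaler change follows the usual fit-on-train / transform-on-test pattern; for a distance-based model like k-NN this matters because features on larger scales would otherwise dominate the Euclidean distance.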