Update comments and KNN example
parent
083f6b679f
commit
13fa08cf6f
|
@ -5,22 +5,27 @@ from __future__ import absolute_import
|
|||
from packtml.clustering import KNNClassifier
|
||||
from packtml.utils.plotting import add_decision_boundary_to_axis
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.metrics import accuracy_score
|
||||
from sklearn.datasets import load_iris
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.colors import ListedColormap
|
||||
import numpy as np
|
||||
import sys
|
||||
|
||||
# #############################################################################
|
||||
# Create a classification sub-dataset using iris
|
||||
iris = load_iris()
|
||||
X = iris.data[:, :2]
|
||||
X = iris.data[:, :2] # just use the first two dimensions
|
||||
y = iris.target
|
||||
|
||||
# split data
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
|
||||
|
||||
# scale the data
|
||||
scaler = StandardScaler()
|
||||
X_train = scaler.fit_transform(X_train)
|
||||
X_test = scaler.transform(X_test)
|
||||
|
||||
# #############################################################################
|
||||
# Fit a k-nearest neighbor model and get predictions
|
||||
k=10
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 22 KiB After Width: | Height: | Size: 21 KiB |
|
@ -80,7 +80,8 @@ class KNNClassifier(BaseSimpleEstimator):
|
|||
# Compute the pairwise distances between each observation in
|
||||
# the dataset and the training data. This can be relatively expensive
|
||||
# for very large datasets!!
|
||||
dists = euclidean_distances(X, self.X)
|
||||
train = self.X
|
||||
dists = euclidean_distances(X, train)
|
||||
|
||||
# Argsort to rank the training points by distance for each row. This sorts
|
||||
# elements in each row (independent of other rows) to determine the
|
||||
|
@ -93,7 +94,13 @@ class KNNClassifier(BaseSimpleEstimator):
|
|||
nearest = np.argsort(dists, axis=1)
|
||||
|
||||
# We only care about the K nearest neighbors, so order the training labels by distance and keep the first K columns
|
||||
# E.g.:
|
||||
# array([[1, 2, 1],
|
||||
# ...
|
||||
# [0, 0, 0]])
|
||||
predicted_labels = self.y[nearest][:, :self.k]
|
||||
|
||||
# We want the most common along the rows as the predictions
|
||||
# E.g.:
|
||||
# array([1, ..., 0])
|
||||
return mode(predicted_labels, axis=-1)[0].ravel()
|
||||
|
|
Loading…
Reference in New Issue