08.手写KNN算法测试

导入库

import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt

导入数据

# Load the iris data set through sklearn.datasets.
iris = datasets.load_iris()

数据准备

# Feature matrix (X) and label vector (y) taken from the loaded data set.
X, y = iris.data, iris.target
X.shape, y.shape
((150, 4), (150,))

数据分割(2:8 开:20% 作为测试集,80% 作为训练集)

# The feature matrix and the label vector are separate arrays, so shuffling
# only one of them would break the pairing between rows and labels.
# Join them into one table, shuffle whole rows, then split.

X_join_y = np.column_stack([X, y])

# Shuffling is random, so the split would differ on every run.
# Seeding the generator keeps the split identical between runs,
# which is what you want when debugging.

np.random.seed(1)
np.random.shuffle(X_join_y)
train, test = np.split(X_join_y, [int(0.8 * X_join_y.shape[0])], axis=0)
train.shape,test.shape
((120, 5), (30, 5))

准备data和target

# Separate each shuffled table into its first four columns (features) and
# its last column (label), giving X_train, y_train, X_test and y_test.

X_train, y_train = train[:, :4], train[:, -1]
X_test, y_test = test[:, :4], test[:, -1]

KNN手写算法

import numpy as np
from math import sqrt
from collections import Counter
class KNNClassifier:
    """Brute-force k-nearest-neighbours classifier using Euclidean distance.

    ``fit`` only stores the training data; every distance is computed at
    prediction time, and the label is chosen by majority vote among the
    ``k`` closest training samples.
    """

    def __init__(self, k):
        """Create a classifier that votes among the ``k`` nearest samples.

        Args:
            k: number of neighbours that take part in the vote.

        Raises:
            ValueError: if ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError("k must be at least 1")
        self.k = k
        self._X_train = None
        self._y_train = None

    def fit(self, X_train, y_train):
        """Store the training samples and their labels.

        Args:
            X_train: array-like whose first axis indexes the samples.
            y_train: array-like with one label per training sample.

        Returns:
            This classifier, so calls can be chained.

        Raises:
            ValueError: if there are no training samples, or if the number
                of labels differs from the number of samples.
        """
        X_train = np.asarray(X_train)
        y_train = np.asarray(y_train)
        if X_train.ndim == 0 or X_train.shape[0] == 0:
            raise ValueError("X_train must contain at least one sample")
        if y_train.ndim == 0 or y_train.shape[0] != X_train.shape[0]:
            raise ValueError(
                "X_train and y_train must have the same number of samples"
            )
        self._X_train = X_train
        self._y_train = y_train
        return self

    def predict(self, X_predict):
        """Predict a label for every sample in ``X_predict``.

        Args:
            X_predict: iterable of samples, each shaped like one training
                sample.

        Returns:
            A NumPy array holding one predicted label per input sample.

        Raises:
            RuntimeError: if ``fit`` has not been called yet.
            ValueError: if a sample's shape differs from the training
                samples' shape.
        """
        if self._X_train is None or self._y_train is None:
            raise RuntimeError("fit must be called before predict")
        y_predict = [self._predict(x) for x in X_predict]
        return np.array(y_predict)

    def _predict(self, x):
        """Return the predicted label for the single sample ``x``."""
        x = np.asarray(x)
        # Without this check a mismatched sample would silently broadcast
        # against the training rows and yield meaningless distances.
        if x.shape != self._X_train.shape[1:]:
            raise ValueError(
                "sample shape %s does not match training sample shape %s"
                % (x.shape, self._X_train.shape[1:])
            )
        # Flatten each per-sample difference so one axis-1 reduction yields
        # the Euclidean distance to every training sample at once.
        diffs = (self._X_train - x).reshape(self._X_train.shape[0], -1)
        distances = np.sqrt(np.sum(diffs ** 2, axis=1))
        # A stable sort keeps equidistant neighbours in training order, so
        # the selected neighbours are deterministic.
        nearest = np.argsort(distances, kind="stable")
        # Counter keeps insertion order, so on a tied vote the label that
        # was seen first (i.e. the closer neighbour's label) wins.
        votes = Counter(self._y_train[i] for i in nearest[:self.k])
        return votes.most_common(1)[0][0]

    def __repr__(self):
        return "KNN=(%d)" % self.k
from sklearn.model_selection import train_test_split


# Split X and y into train and test subsets with scikit-learn.
# The returned list is ordered: train features, test features,
# train labels, test labels.
result = train_test_split(X, y)
result
[array([[7.2, 3. , 5.8, 1.6],
        [5.4, 3.9, 1.3, 0.4],
        [6.5, 3.2, 5.1, 2. ],
        [6.1, 3. , 4.6, 1.4],
        [4.6, 3.2, 1.4, 0.2],
        [6.9, 3.2, 5.7, 2.3],
        [6.1, 2.8, 4. , 1.3],
        [5.7, 3. , 4.2, 1.2],
        [5.8, 2.7, 4.1, 1. ],
        [5.5, 2.5, 4. , 1.3],
        [5.7, 2.5, 5. , 2. ],
        [4.6, 3.4, 1.4, 0.3],
        [5.9, 3.2, 4.8, 1.8],
        [6.3, 2.9, 5.6, 1.8],
        [6.8, 3. , 5.5, 2.1],
        [6.4, 2.7, 5.3, 1.9],
        [6. , 2.9, 4.5, 1.5],
        [6. , 2.2, 4. , 1. ],
        [4.8, 3. , 1.4, 0.1],
        [5.6, 2.5, 3.9, 1.1],
        [7.1, 3. , 5.9, 2.1],
        [6.7, 3.3, 5.7, 2.1],
        [5.5, 2.6, 4.4, 1.2],
        [6.3, 3.3, 4.7, 1.6],
        [6.7, 3.1, 4.7, 1.5],
        [4.3, 3. , 1.1, 0.1],
        [4.8, 3.4, 1.9, 0.2],
        [6.7, 3.3, 5.7, 2.5],
        [6. , 2.7, 5.1, 1.6],
        [6.5, 3. , 5.5, 1.8],
        [4.9, 2.5, 4.5, 1.7],
        [5. , 3.5, 1.3, 0.3],
        [5.9, 3. , 4.2, 1.5],
        [5.5, 2.4, 3.8, 1.1],
        [6.2, 2.2, 4.5, 1.5],
        [6.3, 2.7, 4.9, 1.8],
        [4.4, 3. , 1.3, 0.2],
        [7.7, 3. , 6.1, 2.3],
        [7. , 3.2, 4.7, 1.4],
        [6.4, 2.8, 5.6, 2.2],
        [5.7, 2.8, 4.5, 1.3],
        [6.4, 2.9, 4.3, 1.3],
        [5.6, 3. , 4.1, 1.3],
        [6.3, 2.8, 5.1, 1.5],
        [4.9, 3.6, 1.4, 0.1],
        [6. , 3.4, 4.5, 1.6],
        [5.7, 4.4, 1.5, 0.4],
        [4.8, 3. , 1.4, 0.3],
        [5.4, 3.7, 1.5, 0.2],
        [5.4, 3.4, 1.5, 0.4],
        [5. , 2.3, 3.3, 1. ],
        [6.9, 3.1, 4.9, 1.5],
        [5.1, 3.8, 1.9, 0.4],
        [6.4, 2.8, 5.6, 2.1],
        [5.1, 3.8, 1.5, 0.3],
        [5. , 3.4, 1.5, 0.2],
        [5.1, 3.3, 1.7, 0.5],
        [5.2, 2.7, 3.9, 1.4],
        [6.1, 2.6, 5.6, 1.4],
        [7.7, 2.8, 6.7, 2. ],
        [5.8, 2.7, 5.1, 1.9],
        [6.8, 2.8, 4.8, 1.4],
        [4.4, 3.2, 1.3, 0.2],
        [5.3, 3.7, 1.5, 0.2],
        [6.9, 3.1, 5.4, 2.1],
        [5.1, 2.5, 3. , 1.1],
        [5.7, 2.8, 4.1, 1.3],
        [6.4, 3.1, 5.5, 1.8],
        [6.2, 3.4, 5.4, 2.3],
        [5.8, 2.7, 5.1, 1.9],
        [6.3, 2.5, 4.9, 1.5],
        [5.8, 2.6, 4. , 1.2],
        [4.6, 3.1, 1.5, 0.2],
        [4.9, 3.1, 1.5, 0.2],
        [5.6, 2.9, 3.6, 1.3],
        [5.1, 3.7, 1.5, 0.4],
        [5. , 3.2, 1.2, 0.2],
        [6.5, 3. , 5.8, 2.2],
        [7.3, 2.9, 6.3, 1.8],
        [5.2, 3.4, 1.4, 0.2],
        [4.5, 2.3, 1.3, 0.3],
        [5.5, 2.3, 4. , 1.3],
        [6.5, 3. , 5.2, 2. ],
        [5.5, 2.4, 3.7, 1. ],
        [7.6, 3. , 6.6, 2.1],
        [5. , 3.6, 1.4, 0.2],
        [5.9, 3. , 5.1, 1.8],
        [6.3, 2.5, 5. , 1.9],
        [6.1, 3. , 4.9, 1.8],
        [4.9, 3. , 1.4, 0.2],
        [6.7, 3. , 5.2, 2.3],
        [5.1, 3.5, 1.4, 0.3],
        [6.3, 2.3, 4.4, 1.3],
        [4.4, 2.9, 1.4, 0.2],
        [6.8, 3.2, 5.9, 2.3],
        [5.1, 3.8, 1.6, 0.2],
        [7.2, 3.6, 6.1, 2.5],
        [5.7, 3.8, 1.7, 0.3],
        [5. , 2. , 3.5, 1. ],
        [5. , 3. , 1.6, 0.2],
        [4.8, 3.4, 1.6, 0.2],
        [4.8, 3.1, 1.6, 0.2],
        [6.7, 3.1, 5.6, 2.4],
        [5.8, 2.8, 5.1, 2.4],
        [5.8, 4. , 1.2, 0.2],
        [6.1, 2.8, 4.7, 1.2],
        [5.4, 3.9, 1.7, 0.4],
        [6.5, 2.8, 4.6, 1.5],
        [4.9, 3.1, 1.5, 0.1],
        [5.4, 3.4, 1.7, 0.2],
        [4.9, 2.4, 3.3, 1. ],
        [5.1, 3.4, 1.5, 0.2]]),
 array([[6.2, 2.9, 4.3, 1.3],
        [6.7, 3. , 5. , 1.7],
        [5.2, 4.1, 1.5, 0.1],
        [5.7, 2.6, 3.5, 1. ],
        [7.4, 2.8, 6.1, 1.9],
        [5.6, 3. , 4.5, 1.5],
        [6.9, 3.1, 5.1, 2.3],
        [6. , 2.2, 5. , 1.5],
        [5.5, 3.5, 1.3, 0.2],
        [6.7, 2.5, 5.8, 1.8],
        [7.2, 3.2, 6. , 1.8],
        [6. , 3. , 4.8, 1.8],
        [5.2, 3.5, 1.5, 0.2],
        [5.1, 3.5, 1.4, 0.2],
        [5. , 3.3, 1.4, 0.2],
        [5.6, 2.8, 4.9, 2. ],
        [5.6, 2.7, 4.2, 1.3],
        [5. , 3.5, 1.6, 0.6],
        [7.9, 3.8, 6.4, 2. ],
        [6.3, 3.4, 5.6, 2.4],
        [5. , 3.4, 1.6, 0.4],
        [6.2, 2.8, 4.8, 1.8],
        [5.4, 3. , 4.5, 1.5],
        [5.5, 4.2, 1.4, 0.2],
        [4.6, 3.6, 1. , 0.2],
        [6.1, 2.9, 4.7, 1.4],
        [6.4, 3.2, 5.3, 2.3],
        [5.7, 2.9, 4.2, 1.3],
        [7.7, 2.6, 6.9, 2.3],
        [7.7, 3.8, 6.7, 2.2],
        [6.3, 3.3, 6. , 2.5],
        [5.8, 2.7, 3.9, 1.2],
        [6.6, 2.9, 4.6, 1.3],
        [4.7, 3.2, 1.6, 0.2],
        [6.7, 3.1, 4.4, 1.4],
        [6.4, 3.2, 4.5, 1.5],
        [4.7, 3.2, 1.3, 0.2],
        [6.6, 3. , 4.4, 1.4]]),
 array([2, 0, 2, 1, 0, 2, 1, 1, 1, 1, 2, 0, 1, 2, 2, 2, 1, 1, 0, 1, 2, 2,
        1, 1, 1, 0, 0, 2, 1, 2, 2, 0, 1, 1, 1, 2, 0, 2, 1, 2, 1, 1, 1, 2,
        0, 1, 0, 0, 0, 0, 1, 1, 0, 2, 0, 0, 0, 1, 2, 2, 2, 1, 0, 0, 2, 1,
        1, 2, 2, 2, 1, 1, 0, 0, 1, 0, 0, 2, 2, 0, 0, 1, 2, 1, 2, 0, 2, 2,
        2, 0, 2, 0, 1, 0, 2, 0, 2, 0, 1, 0, 0, 0, 2, 2, 0, 1, 0, 1, 0, 0,
        1, 0]),
 array([1, 1, 0, 1, 2, 1, 2, 2, 0, 2, 2, 2, 0, 0, 0, 2, 1, 0, 2, 2, 0, 2,
        1, 0, 0, 1, 2, 1, 2, 2, 2, 1, 1, 0, 1, 1, 0, 1])]
# Build a 3-nearest-neighbour classifier and fit it on the training part of
# the split (result[0] holds the features, result[2] the labels).
my_knn_clf = KNNClassifier(k=3)
my_knn_clf.fit(result[0], result[2])
KNN=(3)

# Predict labels for the test features.
y_predict = my_knn_clf.predict(result[1])
# Number of predictions that match the true test labels.
np.sum(y_predict == result[3])
# Accuracy: the fraction of predictions that match the true test labels.
np.mean(y_predict == result[3])
原文地址:https://www.cnblogs.com/waterr/p/14039173.html