数据集给出的是小麦种子的特征数据及其类别。每一个样本由7个特征属性组成,即可以看作7维空间中的一个点。我们通过计算两个样本之间的距离来度量样本间的相似度。在分类时,采用一个简单的规则:对于一个新的样本,先在数据集中找到与它距离最近的点,然后将该样本归为与这个最近点相同的类别(即最近邻分类)。最后采用10折交叉验证来评估分类的准确率。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# __author__ : '小糖果'
from collections import Counter

import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:
    # plt is not used anywhere in this script, so a missing matplotlib
    # must not prevent the classifier from being imported and run.
    plt = None


class KnnRecomender(object):
    """k-nearest-neighbour classifier evaluated with cross-validation.

    The data file is whitespace separated: on every non-blank line all
    columns but the last are numeric features and the last column is the
    class label.
    """

    def __init__(self, fname, k=1):
        """Load features and labels from *fname*.

        Args:
            fname: path of the whitespace-separated data file.
            k: number of nearest neighbours that vote on a prediction.
        """
        data = []
        labels = []
        with open(fname) as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # A blank (e.g. trailing) line carries no sample; the
                    # unguarded fields[-1] would raise IndexError on it.
                    continue
                data.append([float(x) for x in fields[:-1]])
                labels.append(fields[-1])
        self.features = np.array(data)
        self.labels = np.array(labels)
        self.k = k
        self.acc = 0.0

    def plurality(self, results):
        """Return the most frequent label in *results*.

        Ties are resolved in favour of the label that appears first in
        *results*; applyModel passes labels ordered by increasing distance,
        so the closest neighbour wins a tie.

        Raises:
            ValueError: if *results* is empty.
        """
        results = list(results)
        if not results:
            raise ValueError("plurality() needs at least one label")
        counts = Counter(results)
        best = max(counts.values())
        for label in results:
            if counts[label] == best:
                return label

    def applyModel(self, testing_feats, model):
        """Predict a label for every row of *testing_feats*.

        Args:
            testing_feats: iterable of feature vectors to classify.
            model: tuple ``(training_feats, labels)`` of the training set.

        Returns:
            numpy array with one predicted label per test vector.

        Raises:
            ValueError: if the training set is empty.
        """
        training_feats, labels = model
        training_feats = np.asarray(training_feats)
        if len(training_feats) == 0:
            raise ValueError("the training set is empty")
        results = []
        for feat in testing_feats:
            # Euclidean distance from this sample to every training sample,
            # computed in one vectorised call.
            dists = np.linalg.norm(training_feats - feat, axis=1)
            # Sorting (distance, label) pairs breaks distance ties by label.
            nearest = sorted(zip(dists, labels))[:self.k]
            results.append(self.plurality([label for _, label in nearest]))
        return np.array(results)

    def accuracy(self, test_model, learn_model):
        """Return the fraction of *test_model* samples predicted correctly.

        Both arguments are ``(features, labels)`` tuples; *learn_model* is
        used as the training set.
        """
        preds = self.applyModel(test_model[0], learn_model)
        return np.mean(preds == test_model[1])

    def crossValidata(self, folds=10):
        """Run cross-validation and store the mean accuracy in ``self.acc``.

        Sample ``i`` is held out in fold ``i % folds``; every fold trains
        on the remaining samples.

        Args:
            folds: number of folds (10 by default).

        Raises:
            ValueError: if *folds* is smaller than 2 or larger than the
                number of samples (some test fold would be empty).
        """
        n_samples = self.features.shape[0]
        if folds < 2:
            raise ValueError("folds must be at least 2")
        if n_samples < folds:
            raise ValueError("need at least as many samples as folds")
        total = 0.0
        for fold in range(folds):
            testing = np.zeros(n_samples, dtype=bool)
            testing[fold::folds] = True
            training = ~testing
            learn_model = (self.features[training].copy(),
                           self.labels[training].copy())
            test_model = (self.features[testing].copy(),
                          self.labels[testing].copy())
            total += self.accuracy(test_model, learn_model)
        self.acc = total / folds

    def standard(self):
        """Standardise every feature column to zero mean and unit variance.

        A column with zero standard deviation is only centred (it becomes
        all zeros) instead of being divided by zero.
        """
        mean = self.features.mean(axis=0)
        std = self.features.std(axis=0)
        std = np.where(std == 0, 1.0, std)
        self.features = (self.features - mean) / std


def test(fpath=None):
    """Print the cross-validated accuracy before and after standardisation.

    Args:
        fpath: path of the data file; defaults to the author's local file.
    """
    if fpath is None:
        # NOTE(review): this default path contains no separators -- the
        # backslashes look lost; confirm the real location of seeds.tsv.
        fpath = r'C:UsersTDDesktopdataMachine Learning1400OS_02_Codesdataseeds.tsv'
    instance = KnnRecomender(fpath)
    instance.crossValidata()
    print("the accuracy is {:.2f}%".format(instance.acc * 100))
    # Standardise the data, then evaluate again.
    instance.standard()
    instance.crossValidata()
    print("the accuracy is {:.2f}%".format(instance.acc * 100))


# This is a demo entry point, not a unit test: keep pytest/nose from
# collecting it just because of its name.
test.__test__ = False


if __name__ == '__main__':
    test()
结果得到:
the accuracy is 89.52% (没有标准化)
the accuracy is 94.29% (标准化后)
可见,对特征进行标准化之后,分类准确率由 89.52% 提高到了 94.29%。