最近邻分类器项目-分类小麦种子

   数据给出的是小麦的特征数据以及类型。每一个样本由7个特征属性组成,即可以看做7维空间的一个点。我们通过计算两个样本的距离来度量样品间的相似度。在分类时,采用一个简单的规则:对于一个新的样本,我们在数据集中找到最接近它的点,然后将该样本归为和它最近点的同一标签。并采用10折交叉验证。

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# __author__ : '小糖果'

import numpy as np
import matplotlib.pyplot as plt

class KnnRecomender(object):

	def __init__(self,fname,k = 1):
		data = []
		labels = []
		with open(fname) as f:
			for line in f:
				d = line.split()
				data.append([float(x.strip()) for x in d[:-1]])
				labels.append(d[-1].strip())
		self.features = np.array(data)
		self.labels = np.array(labels)
		self.k = k
		self.acc = 0.0

	def plurality(self, results):
		counts = {}
		for v in results:
			counts.setdefault(v,0)
			counts[v] += 1
		maxc = max(counts.values())
		for k,v in counts.items():
			if maxc == v:
				return k

	def applyModel(self,testing_feats, model):
		training_feats,labels = model
		results = []
		for f in testing_feats:
			d  = []
			for t,label in zip(training_feats,labels):
				dis = np.linalg.norm(f-t)
				d.append((dis,label))
			d.sort()
			d = d[:self.k]
			results.append(self.plurality([label for dis,label in d]))
		return np.array(results)

	def accuracy(self,test_model,learn_model):
		preds = self.applyModel(test_model[0],learn_model)
		acc = np.mean(preds == test_model[1])
		return acc

	def crossValidata(self):
		self.acc = 0
		for fold in range(10):
			# 采用10折交叉验证
			training = np.ones(self.features.shape[0],bool)
			training[fold::10] = 0
			testing = ~training
			learn_model = (self.features[training].copy(),
					 self.labels[training].copy())
			test_model = (self.features[testing].copy(),
					 self.labels[testing].copy())
			self.acc += self.accuracy(test_model,learn_model)
		self.acc /= 10

	def standard(self):
		m = self.features.mean(axis = 0)
		s = self.features.std(axis = 0)
		self.features = (self.features - m)/s

def test():
	fpath = r'C:UsersTDDesktopdataMachine Learning1400OS_02_Codesdataseeds.tsv'
	instance = KnnRecomender(fpath)
	instance.crossValidata()
	print "the accuracy is {:.2f}%".format(instance.acc * 100)

	# 将数据标准化后再测试
	instance.standard()
	instance.crossValidata()
	print "the accuracy is {:.2f}%".format(instance.acc*100)

if __name__ == '__main__':
	test()

 结果得到:

the accuracy is 89.52%     (没有标准化)
the accuracy is 94.29%     (标准化后)

原文地址:https://www.cnblogs.com/td15980891505/p/6019851.html