K-近邻算法

🤔


K-近邻算法概述

  • 简单的说,K-近邻算法采用测量不同特征值之间的距离方法进行分类。
  • 优点:精度高,对异常值不敏感,无数据输入假定。
  • 缺点:计算复杂度高,空间复杂度高。
  • 适用范围:数值型和标称型。

K-近邻算法工作原理

存在一个样本数据集合,可以称之为训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最邻近数据的分类标签。一般来说我们只选择样本数据集中前k个最相似的数据,这就是K-近邻算法中k的由来,通常k是不大于20的整数。最后选择k个最相似数据中出现次数最多的分类,作为新数据的分类。


  • 下边是K-近邻算法的一个示例:
  1. 收集数据:提供文本文件。
  2. 准备数据:使用python解析文件。
  3. 分析数据:使用Matplotlib画二维扩散图
  4. 训练算法:
  5. 测试算法:以文件部分数据作为测试样本
  6. 使用算法:可以输入特征数据以判断是否正确
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
'''
在文件datingTest中存放着某约会网站的样本数据,每个样本占据一行,共1000行。
其中主要包括: 每年飞行里程数, 玩游戏所耗时间比, 每周消费的冰淇淋数。
'''

from numpy import *
import os
import operator
import matplotlib
import matplotlib.pyplot as plt
from os import listdir

def file2matrix(filename):
'''
处理输入格式
:param filename:
:return:
'''
fr = open(filename)
numberOfLines = len(fr.readlines()) #得到文件行数
returnMat = zeros((numberOfLines,3)) #创建返回的NumPy矩阵
classLabelVector = []
fr = open(filename)
index = 0
for line in fr.readlines(): #解析文件数据到列表
line = line.strip()
listFromLine = line.split('\t')
returnMat[index,:] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat,classLabelVector

def autoNorm(dataSet):
'''
将每列的最小值放在变量minVals中,最大值放在maxVals中,
其中dataSet.min(0)中的参数0使得函数可以从列中取最小值,而不是选取当前行最小值。
然后函数计算可能的取之范围,并创建新的矩阵返回
:param dataSet:
:return:
'''
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals, (m, 1))
normDataSet = normDataSet / tile(ranges, (m, 1)) #特征值相除
return normDataSet, ranges, minVals


def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
#距离计算
diffMat = tile(inX, (dataSetSize,1)) - dataSet
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort()
classCount={}
for i in range(k):
#选择距离最小的k个点
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
#排序
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]


def datingClassTest():
'''
使用file2matrix和autoNorm函数从文件中读取数据并转换为归一化特征值
接着计算测试向量的数量,决定normMat向量中哪些数据用于测试,哪些用于训练样本
然后将两部分数据输入到原始分类起函数classify0,
最后计算错误率并返回
:return:
'''
hoRatio = 0.10
datingDataMat, datingLabels = file2matrix('datingTestSet.txt')
normMat, range, minvals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m*hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i,:], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
print("the classifier came back with : %d, the real answer is : %d" %(classifierResult, datingLabels[i]))
if(classifierResult != datingLabels[i]):
errorCount += 1.0
print("the total error rate is : %f" %(errorCount/float(numTestVecs)))


datingClassTest()
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(datingDataMat[:,1], datingDataMat[:,2])
ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2], 15.0*array(datingLabels), 15.0*array(datingLabels))
plt.show()
宇 wechat
扫描二维码,订阅微信公众号