K近邻算法(k-nearest neighbor, kNN)

2023-02-16,,,

K近邻算法(K-nearest neighbor, KNN

KNN是一种分类和回归方法。

KNN简介
KNN模型3要素
KNN优缺点
KNN应用
参考文献


KNN简介

KNN思想

给定一个训练集

T={(x1,y1),(x2,y2),...,(xN,yN)}

T

=

{

(

x

1

,

y

1

)

,

(

x

2

,

y

2

)

,

.

.

.

,

(

x

N

,

y

N

)

}

,对新输入的实例

x

x

,在训练集中找到与实例 xx 最近的k个实例,根据k个实例中大多数类所属的类作为实例

x

x

<script type="math/tex" id="MathJax-Element-4">x</script> 所属的类。

KNN算法

KNN模型3要素

K值得选择、距离度量方法选择、分类决策规则选择 

K值得选择

应用中,一般选择较小的k值,交叉验证可以选择最优的k值。

k值减小,模型变复杂,容易过拟合(原因:选择较小k值时,近似误差减小,估计误差增大)。

近似误差:即对现有训练集的训练误差,更关注“训练”。

估计误差:即对测试集的测试误差,更关注“测试”。

距离度量方法选择

欧氏距离

曼哈顿距离

切比雪夫距离
等等

分类决策规则选择

最常用的是,大多数原则:即由输入实例的k个近邻样本中大多数的类别确定输入实例的类。

KNN优缺点

优点

简单、精度高

缺点

计算时间、空间复杂度高

KNN应用

使用knn算法识别手写数字数据集,链接:https://pan.baidu.com/s/1rgiGBLTMiybCCSUnzR1lYw 密码:yse7

# -*-coding:utf-8-*-

from numpy import *
import operator
from os import listdir def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0] # shape[0]读取矩阵第一维的长度
diffMat = tile(inX, (dataSetSize, 1)) - dataSet # numpy.tile(A,B)函数重复A, B次
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
#print(type(distances))
sortedDistIndicies = distances.argsort() classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1 sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
return sortedClassCount[0][0] def img2vector(filename):
returnVect = zeros((1, 1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0, 32*i + j] = int(lineStr[j])
return returnVect def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('digits/trainingDigits') # 加载训练集
m = len(trainingFileList)
trainingMat = zeros((m,1024))
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0] # 提取文件名
classNumStr = int(fileStr.split('_')[0]) # 提取类别标签
hwLabels.append(classNumStr)
trainingMat[i,:] = img2vector('digits/trainingDigits/%s' % fileNameStr)
testFileList = listdir('digits/testDigits') # 加载测试集
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)
if (classifierResult != classNumStr): errorCount += 1.0
print ("\nthe total number of errors is: %d" % errorCount)
print ("\nthe total error rate is: %f" % (errorCount/float(mTest))) if __name__ == '__main__':
handwritingClassTest()

程序运行结果:

参考文献

[1]李航. 统计学习方法[M]. 清华大学出版社, 2012.
[2]Peter Harrington. 机器学习实战[M]. 人民邮电出版社, 2013.

K近邻算法(k-nearest neighbor, kNN)的相关教程结束。

《K近邻算法(k-nearest neighbor, kNN).doc》

下载本文的Word格式文档,以方便收藏与打印。