简介
k近邻算法是一种非常简单、直观的算法:给定一个数据集,对于新的输入实例,在训练数据集中找到与该实例最近的k个实例,这k个实例的多数属于某个类,就把该新的输入实例分为这个类。
这里主要涉及到两个问题:一是实例间的距离度量,二是如何找到所有训练数据集中最靠近新的输入实例的k个实例。常用的距离度量有欧氏距离(Euclidean Distance)和Minkowski距离等,本文实现的k近邻算法就采用了欧氏距离。至于如何找到靠近输入实例最近的k个实例,一种比较朴素的实现方式是对整个训练数据集进行线性扫描,维护一个大小为k的优先队列,保存距离输入实例最近的k个实例点,也俗称暴力搜索,但是当训练集非常大时,计算非常耗时,这种方法就有点捉襟见肘了。为了提高搜索的效率,可以使用特殊的数据结构来存储训练数据,以减少搜索时间,常见的有ball-tree、kd-tree等,本文通过实现kd-tree来加快k近邻的搜索。
kd-tree的构建与搜索
假设训练数据中的样本属于k维空间,kd树是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。kd树是二叉树,表示对k维空间的一个划分,构造kd树相当于不断地用垂直于坐标轴的超平面将k维空间切分,构成一系列的k维超矩形区域。kd树的每一个节点对应于一个k维超矩形区域。
构造kd树的算法如下:
假设输入的训练集为T,训练集中每个实例的维度是n;
构造树节点:假设当前树节点的深度为depth, 以样本的第(depth % n)维为坐标轴,以T中所有实例的第(depth % n)维的中位数为切分点,将训练集分为t1和t2两部分。其中t1是切分点左边的数据集,t2是切分点右边的数据集。将切分点保存在当前树节点,用数据集t1递归构造当前节点的左节点,用数据集t2递归构造当前节点的右节点。当输入的数据集为空时递归终止。
kd树的k近邻搜索算法如下:
假设要搜索实例点x的k近邻,结果保存在L中;
- 在kd树中找到距离x最近的叶节点:从根节点出发,递归向下访问kd树,如果x当前维的坐标小于切分点的坐标,则向左子树搜索,反之向右子树搜索,直到到达叶节点为止;
以当前叶节点为起点,开始递归向上搜索,直到到达根节点。对每个节点进行以下操作:
(a) 如果L中不足k个实例点,或者当前树节点保存的实例与x的距离小于L中的最大距离,用当前树节点保存的实例点替换L中距离最大的点;
(b) 计算x与当前切分轴的距离,如果此距离小于L中的最大距离,则当前树节点的另一个子节点区域可能存在更近的点,对另一个子节点递归调用k近邻搜索算法。因为是向上回退搜索,如果上一步是从左节点退到父节点,就应该对右节点进行递归搜索,另一种情况同理。
代码实现
1 | class KNN: |
这里还是采用sklearn自带的digits数据集来测试我们的k近邻算法,分别用暴力搜索和kd树的方式来测试以上的实现,运行结果如下图。
可以看到kd树的实现方式略占优势,数据集更大的话,kd树应该能更快。完整代码在这里。
总结
k近邻算法虽然简单,但是在实现过程中还是学到了些知识和技巧,例如python中的heapq优先队列是最小堆,如果想要最大堆则将队列中元素的值乘以-1即可;再比如用cProfile来分析程序的性能,发现自己实现的欧式距离计算耗时很长,更换为numpy自带的实现后程序运行速度有了很大提升。