Life's monolog

用python实现k近邻、kd树搜索

Word count: 1,713 / Reading time: 7 min
2018/04/22 Share

简介

k近邻算法是一种非常简单、直观的算法:给定一个数据集,对于新的输入实例,在训练数据集中找到与该实例最近的k个实例,这k个实例的多数属于某个类,就把该新的输入实例分为这个类。

这里主要涉及到两个问题:一是实例间的距离度量,二是如何找到所有训练数据集中最靠近新的输入实例的k个实例。常用的距离度量有欧氏距离(Euclidean Distance)和Minkowski距离等,本文实现的k近邻算法就采用了欧氏距离。至于如何找到靠近输入实例最近的k个实例,一种比较朴素的实现方式是对整个训练数据集进行线性扫描,维护一个大小为k的优先队列,保存距离输入实例最近的k个实例点,也俗称暴力搜索,但是当训练集非常大时,计算非常耗时,这种方法就有点捉襟见肘了。为了提高搜索的效率,可以使用特殊的数据结构来存储训练数据,以减少搜索时间,常见的有ball-tree、kd-tree等,本文通过实现kd-tree来加快k近邻的搜索。

kd-tree的构建与搜索

假设训练数据中的样本属于k维空间,kd树是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。kd树是二叉树,表示对k维空间的一个划分,构造kd树相当于不断地用垂直于坐标轴的超平面将k维空间切分,构成一系列的k维超矩形区域。kd树的每一个节点对应于一个k维超矩形区域。

构造kd树的算法如下:

假设输入的训练集为T,训练集中每个实例的维度是n;

构造树节点:假设当前树节点的深度为depth, 以样本的第(depth % n)维为坐标轴,以T中所有实例的第(depth % n)维的中位数为切分点,将训练集分为t1和t2两部分。其中t1是切分点左边的数据集,t2是切分点右边的数据集。将切分点保存在当前树节点,用数据集t1递归构造当前节点的左节点,用数据集t2递归构造当前节点的右节点。当输入的数据集为空时递归终止。

kd树的k近邻搜索算法如下:

假设要搜索实例点x的k近邻,结果保存在L中;

  1. 在kd树中找到距离x最近的叶节点:从根节点出发,递归向下访问kd树,如果x当前维的坐标小于切分点的坐标,则向左子树搜索,反之向右子树搜索,直到到达叶节点为止;
  2. 以当前叶节点为起点,开始递归向上搜索,直到到达根节点。对每个节点进行以下操作:

    (a) 如果L中不足k个实例点,或者当前树节点保存的实例与x的距离小于L中的最大距离,用当前树节点保存的实例点替换L中距离最大的点;

    (b) 计算x与当前切分轴的距离,如果此距离小于L中的最大距离,则当前树节点的另一个子节点区域可能存在更近的点,对另一个子节点递归调用k近邻搜索算法。因为是向上回退搜索,如果上一步是从左节点退到父节点,就应该对右节点进行递归搜索,另一种情况同理。

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class KNN:
def __init__(self, k=3):
self.k = k
self.neighbors = [] # 用于保存kd树搜索过程中k个最近邻居的标签

def _vote(self, neighbors):
"""投票算法,选取k个邻居中出现次数最多的类别"""
counts = np.bincount(neighbors.astype('int'))
return counts.argmax()

def _build_kd_tree(self, data, depth):
"""建立kd树
Args:
data: 需要训练的数据
depth: 当前建立的数的深度
Returns:
kd树的节点
"""
data = np.array(data)
n_samples = data.shape[0]
if n_samples == 0:
return None
else:
n_features = data.shape[1] - 1
current_node = dict()
current_node['split_axis'] = depth % n_features + 1

data = sorted(data, key=lambda x: x[current_node['split_axis']])
split_idx = n_samples // 2
current_node['data'] = data[split_idx]
current_node['left'] = self._build_kd_tree(data[:split_idx], depth + 1)
current_node['right'] = self._build_kd_tree(data[split_idx+1:], depth + 1)
return current_node

def _search_tree(self, root, x):
"""搜索kd树
Args:
root: 开始搜索的树节点
x: 需要判定类别的样本
"""
if root is None:
return

split_axis = root['split_axis']
if x[split_axis] < root['data'][split_axis]:
self._search_tree(root['left'], x)
else:
self._search_tree(root['right'], x)

heapq.heappush(self.neighbors, (-1 * euclidean_distance(x, root['data'][1:]), next(counter), root['data']))
if len(self.neighbors) > self.k:
heapq.heappop(self.neighbors)

split_dist = abs(x[split_axis] - root['data'][split_axis])
neighbor_max = -1 * heapq.nsmallest(1, self.neighbors)[0][0]
if split_dist > neighbor_max:
return

if x[split_axis] < root['data'][split_axis]:
self._search_tree(root['right'], x)
else:
self._search_tree(root['left'], x)

@run_time
def predict(self, X_train, y_train, X_test, kd_tree=False):
"""用训练集来预测测试集
Args:
X_train: 训练特征数据
y_train: 训练标签数据
X_test: 测试特征数据
kd_tree: True代表使用kd树搜索,False代表使用线性扫描
Returns
pred: 针对X_test的预测结果
"""
pred = np.empty(X_test.shape[0])
if kd_tree:
data = np.insert(X_train, 0, y_train, axis=1)
n_features = np.array(data).shape[1]
self.split_order = random.sample(range(1, n_features), n_features - 1)
root = self._build_kd_tree(data, 0)
for i, sample in enumerate(X_test):
# print(f'processing sample {i + 1} / {len(X_test)}')
self.neighbors.clear()
self._search_tree(root, sample)
neighbors = np.array([x[0] for d, c, x in self.neighbors])
pred[i] = self._vote(neighbors)
else:
for i, sample in enumerate(X_test):
# print(f'processing sample {i + 1} / {len(X_test)}')
idx = np.argsort([euclidean_distance(x, sample) for x in X_train])[:self.k]
neighbors = np.array([y_train[j] for j in idx])
pred[i] = self._vote(neighbors)
return pred

这里还是采用sklearn自带的digits数据集来测试我们的k近邻算法,分别用暴力搜索和kd树的方式来测试以上的实现,运行结果如下图。
kd_result

可以看到kd树的实现方式略占优势,数据集更大的话,kd树应该能更快。完整代码在这里

总结

k近邻算法虽然简单,但是在实现过程中还是学到了些知识和技巧,例如python中的heapq优先队列是最小堆,如果想要最大堆则将队列中元素的值乘以-1即可;再比如用cProfile来分析程序的性能,发现自己实现的欧式距离计算耗时很长,更换为numpy自带的实现后程序运行速度有了很大提升。

CATALOG
  1. 1. 简介
  2. 2. kd-tree的构建与搜索
  3. 3. 代码实现
  4. 4. 总结