1 KD-Tree
При реализации алгоритма kNN самый простой способ реализовать его — это линейное сканирование, как мы представили в предыдущей главе ->Алгоритм K-ближайшего соседа, необходимо рассчитать расстояние между входным экземпляром и каждой обучающей выборкой. Когда тренировочный набор большой, это может занять очень много времени.
Чтобы повысить эффективность поиска kNN, вы можете рассмотреть возможность использования специальной структуры для хранения обучающих данных, чтобы уменьшить количество вычислений расстояния, и KD-Tree является одним из них.
Дерево kd представляет собой бинарную древовидную структуру, которая эквивалентна непрерывному делению k-мерного пространства вертикальными линиями для формирования ряда k-мерных гиперпрямоугольных областей.
2 Как построить KD-дерево
2.1 Алгоритм KD-Tree выглядит следующим образом:
K-мерный набор пространственных данныхв
-
построить корневой узел выберитеявляется осью координат, и все экземпляры в T представленыКоординаты срединные, вертикальныеОсь разрезается на два прямоугольника, и корневой узел порождает два левых и правых дочерних узла с глубиной 1: координаты левого дочернего узла меньше точки разделения, а координаты правого дочернего узла больше координаты точки разделения.
-
Повторите: для узлов с глубиной j выберитеось деления,, используйте этот узел, чтобы снова разделить прямоугольную область на две подобласти.
-
Он останавливается до тех пор, пока два субрегиона не ослабеют, образуя региональное подразделение KD-Tree.
2.2 Пример построения KD-дерева
Сначала случайным образом сгенерируйте 13 точек в наборе данных в качестве нашего набора данных.
Сначала делим по x-координате, выбираем середину x-координаты, и получаем координаты самого корневого узла
И пространство делится по координате x точки, все данные, координата x которых меньше 6,27, используются для построения левой ветви, а точки, координата x которых больше 6,27, используются для построения левой ветви. правая ветвь.
на следующем шаге, В соответствии с осью y левая и правая стороны делятся в соответствии с порядком оси y, а срединная точка записывается в узлах левой и правой ветвей. Получите следующее дерево, x слева означает, что узлы этого слоя разделены по оси x.
Разделение пространства происходит следующим образом.
на следующем шаге, соответствующие оси x, поэтому следующие элементы будут отсортированы и разделены в соответствии с координатой x.
В конце в каждой части остается только одна точка, запишите их в нижнем узле. Поскольку незарегистрированных точек больше нет, дальнейшая сегментация не выполняется.
На этом построение дерева kd завершено.
2.3 Построение кода
class Node:
def __init__(self, data, depth=0, lchild=None, rchild=None):
self.data = data # 此结点
self.depth = depth # 树的深度
self.lchild = lchild # 左子结点
self.rchild = rchild # 右子节点
class KdTree:
def __init__(self):
self.KdTree = None
self.n = 0
self.nearest = None
def create(self, dataSet, depth=0):
"""KD-Tree创建过程"""
if len(dataSet) > 0:
m, n = np.shape(dataSet)
self.n = n - 1
# 按照哪个维度进行分割,比如0:x轴,1:y轴
axis = depth % self.n
# 中位数
mid = int(m / 2)
# 按照第几个维度(列)进行排序
dataSetcopy = sorted(dataSet, key=lambda x: x[axis])
# KD结点为中位数的结点,树深度为depth
node = Node(dataSetcopy[mid], depth)
if depth == 0:
self.KdTree = node
# 前mid行为左子结点,此时行数m改变,深度depth+1,axis会换个维度
node.lchild = self.create(dataSetcopy[:mid], depth+1)
node.rchild = self.create(dataSetcopy[mid+1:], depth+1)
return node
return None
3 Поиск KD-дерева
Вход: построенное дерево kd, целевая точка x Выход: набор из k ближайших соседей x ближайших[ ]
3.1 Поиск ближайшего соседа KD-дерева
-
Начиная с корневого узла, рекурсивно посещайте KD-дерево вниз.Если текущий размер целевой точки x меньше координаты точки разделения, перейдите к левому дочернему узлу, в противном случае - к правому дочернему узлу, пока дочерний узел является листовым узлом.
-
Вставьте этот листовой узел в качестве ближайшей соседней точки и вставьте его в ближайший[ ]
-
Рекурсивно создайте резервную копию, выполните следующие операции на этом узле:
- a Если узел находится ближе, чем точка в ближайшем [ ], замените точку с наибольшим расстоянием в ближайшем [ ].
- b Расстояние по вертикали между целевой точкой и разделительной линией этого узла равно d. Считается, что точка с наибольшим расстоянием в ближайшем [ ] сравнивается с d. Если она больше d, это означает, что существует вероятность того, что площадь по другую сторону от d больше, чем у ближайшего [ ].Расстояние мало, поэтому вам нужно посмотреть на расстояние между левым и правым дочерними узлами d. Если точка с наибольшим расстоянием в ближайшем [ ] меньше d, это означает, что расстояние между точками на другой стороне и целевой точкой больше d, поэтому искать не нужно, а продолжать идти назад вверх.
- При возврате к корневому узлу поиск завершается, и k точек в последнем ближайшем [ ] являются ближайшими соседями x.
3.2 Временная сложность
Средняя временная сложность KD-Tree составляет, а N — количество обучающих выборок.
KD-Tree используется для поиска k ближайших соседей, когда количество обучающих выборок намного превышает контрольную размерность. Когда пространственное измерение приближается к количеству обучающих выборок, его эффективность быстро падает, почти приближаясь к линейной развертке.
3.3 Пример описания
Предполагая, что точка, которую мы хотим запросить, равна p=(−1,−5), а функция расстояния является обычным расстоянием, мы хотим найти k=3 точек, ближайших к проблемной точке. следующее:
Сначала мы следуем построенному KD-дереву, начиная с корневого узла.
Сравните с осью x этого узла,
Ось x p меньше. Итак, ищем левую ветку:
На этот раз нам нужно сравнить ось Y
Значение y p меньше, поэтому поиск выполняется в левой ветви:
Этот узел имеет только одну подветвь, поэтому нет необходимости в сравнении. Это находит листовой узел (-4,6,-10,55).
Синие точки на 2D-карте
На этом этапе нам нужно выполнить второй шаг, вставить текущий узел в ближайший[ ] и записать L=[(−4,6,−10,55)]. Посещенные узлы отображаются зачеркнутыми на двоичном дереве.
Затем выполните третий шаг, а не самый верхний узел. Я отступаю. Выше (−6,88, −5,4).
Выполните 3a, потому что мы записали только одну точку, которая меньше k=3, поэтому текущий узел также записывается и вставляется в ближайший[ ] набор, есть L=[(−4.6,−10.55),( −6,88, − 5,4)]. Т.к. левая ветвь текущего узла пуста, то она сразу пропускается.Как дальше идти назад, а не наверх, и лезть вверх по другому участку.
Поскольку трех точек по-прежнему недостаточно, текущая точка также вставляется в ближайшее[ ] множество с L=[(−4,6,−10,55),(−6,88,−5,4),(1,24,−2,86)]. Конечно, текущий узел становится посещаемым.
В это время обнаруживается, что текущий узел имеет другие ветви, выполните 3b и вычислите, что расстояние между точкой p и тремя точками в L равно 6,62, 5,89, 3,10, но расстояние между p и разделительной линией текущий узел всего 2,14, Меньше максимального расстояния от L:
Поэтому на другом конце разделительной линии могут быть более близкие точки. Итак, мы выполняем шаг 1 с самого начала на другой ветке текущего узла. Хорошо, вот мы и у красной линии:
В этой точке ось x разделена, поэтому используйте p для сравнения x-координаты этого узла:
Координата x p больше, поэтому исследуйте правую ветвь (1.75, 12.26) и обнаружите, что правая ветвь уже является самым нижним узлом, выполните шаги 2 и 3a.
Расстояние между (1,75,12,26) и p рассчитано равным 17,48, что больше, чем расстояние между p и L, поэтому мы не указываем его в записи.
Затем вернитесь назад, определите, что это не верхний узел, и поднимитесь наверх.
Выполните 3а Расстояние между этим узлом и p равно 4,91, что меньше максимального расстояния между p и L, равного 6,62.
Поэтому мы заменяем тот в L, который находится дальше всего от p (−4,6, −10,55), этим новым узлом.
Затем 3b мы сравниваем расстояние между p и разделительной линией текущего узла
Это расстояние меньше максимального расстояния между L и p, поэтому мы хотим перейти на другую ветвь текущего узла и выполнить шаг 1. Конечно, эта ветвь имеет только одну точку.
Вычисление расстояния показывает, что эта точка находится дальше от p, чем L, поэтому замена не производится.
Потом возвращаемся, не корневой узел, лезем наверх
Этот уже был посещен, так что лезьте снова
подняться снова
это вершина. Так все кончено? Конечно нет, но и реализовать 3б. Теперь очередь шага 1.
Мы выполняем вычислительное сравнение и обнаруживаем, что верхний узел находится дальше от p, чем L, поэтому обновление не выполняется.
Затем вычислите расстояние между p и разделительной линией и найдите, что оно также дальше.
Так что нет необходимости проверять еще одну ветку.
Считается, что текущий узел является вершиной, поэтому расчет завершен! Три выборки, которые выводят самое близкое расстояние к p: L=[(−6,88, −5,4), (1,24, −2,86), (−2,96, −2,5)].
3.3 Код
def search(self, x, count=1):
"""KD-Tree的搜索"""
nearest = [] # 记录近邻点的集合
for i in range(count):
nearest.append([-1, None])
self.nearest = np.array(nearest)
def recurve(node):
"""内方法,负责查找count个近邻点"""
if node is not None:
# 步骤1:怎么找叶子节点
# 在哪个维度的分割线,0,1,0,1表示x,y,x,y
axis = node.depth % self.n
# 判断往左走or右走,递归,找到叶子结点
daxis = x[axis] - node.data[axis]
if daxis < 0:
recurve(node.lchild)
else:
recurve(node.rchild)
# 步骤2:满足的就插入到近邻点集合中
# 求test点与此点的距离
dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.data)))
# 遍历k个近邻点,如果不满k个,直接加入,如果距离比已有的近邻点距离小,替换掉,距离是从小到大排序的
for i, d in enumerate(self.nearest):
if d[0] < 0 or dist < d[0]:
self.nearest = np.insert(self.nearest, i, [dist, node], axis=0)
self.nearest = self.nearest[:-1]
break
# 步骤3:判断与垂线的距离,如果比这大,要查找垂线的另一侧
n = list(self.nearest[:, 0]).count(-1)
# -n-1表示不为-1的最后一行,就是记录最远的近邻点(也就是最大的距离)
# 如果大于到垂线之间的距离,表示垂线的另一侧可能还有比他离的近的点
if self.nearest[-n-1, 0] > abs(daxis):
# 如果axis < 0,表示测量点在垂线的左侧,因此要在垂线右侧寻找点
if daxis < 0:
recurve(node.rchild)
else:
recurve(node.lchild)
recurve(self.KdTree) # 调用根节点,开始查找
knn = self.nearest[:, 1] # knn为k个近邻结点
belong = [] # 记录k个近邻结点的分类
for i in knn:
belong.append(i.data[-1])
b = max(set(belong), key=belong.count) # 找到测试点所属的分类
return self.nearest, b
4 Общий код
import numpy as np
from math import sqrt
import pandas as pd
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
class Node:
def __init__(self, data, depth=0, lchild=None, rchild=None):
self.data = data # 此结点
self.depth = depth # 树的深度
self.lchild = lchild # 左子结点
self.rchild = rchild # 右子节点
class KdTree:
def __init__(self):
self.KdTree = None
self.n = 0
self.nearest = None
def create(self, dataSet, depth=0):
"""KD-Tree创建过程"""
if len(dataSet) > 0:
m, n = np.shape(dataSet)
self.n = n - 1
# 按照哪个维度进行分割,比如0:x轴,1:y轴
axis = depth % self.n
# 中位数
mid = int(m / 2)
# 按照第几个维度(列)进行排序
dataSetcopy = sorted(dataSet, key=lambda x: x[axis])
# KD结点为中位数的结点,树深度为depth
node = Node(dataSetcopy[mid], depth)
if depth == 0:
self.KdTree = node
# 前mid行为左子结点,此时行数m改变,深度depth+1,axis会换个维度
node.lchild = self.create(dataSetcopy[:mid], depth+1)
node.rchild = self.create(dataSetcopy[mid+1:], depth+1)
return node
return None
def preOrder(self, node):
"""遍历KD-Tree"""
if node is not None:
print(node.depth, node.data)
self.preOrder(node.lchild)
self.preOrder(node.rchild)
def search(self, x, count=1):
"""KD-Tree的搜索"""
nearest = [] # 记录近邻点的集合
for i in range(count):
nearest.append([-1, None])
self.nearest = np.array(nearest)
def recurve(node):
"""内方法,负责查找count个近邻点"""
if node is not None:
# 步骤1:怎么找叶子节点
# 在哪个维度的分割线,0,1,0,1表示x,y,x,y
axis = node.depth % self.n
# 判断往左走or右走,递归,找到叶子结点
daxis = x[axis] - node.data[axis]
if daxis < 0:
recurve(node.lchild)
else:
recurve(node.rchild)
# 步骤2:满足的就插入到近邻点集合中
# 求test点与此点的距离
dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.data)))
# 遍历k个近邻点,如果不满k个,直接加入,如果距离比已有的近邻点距离小,替换掉,距离是从小到大排序的
for i, d in enumerate(self.nearest):
if d[0] < 0 or dist < d[0]:
self.nearest = np.insert(self.nearest, i, [dist, node], axis=0)
self.nearest = self.nearest[:-1]
break
# 步骤3:判断与垂线的距离,如果比这大,要查找垂线的另一侧
n = list(self.nearest[:, 0]).count(-1)
# -n-1表示不为-1的最后一行,就是记录最远的近邻点(也就是最大的距离)
# 如果大于到垂线之间的距离,表示垂线的另一侧可能还有比他离的近的点
if self.nearest[-n-1, 0] > abs(daxis):
# 如果axis < 0,表示测量点在垂线的左侧,因此要在垂线右侧寻找点
if daxis < 0:
recurve(node.rchild)
else:
recurve(node.lchild)
recurve(self.KdTree) # 调用根节点,开始查找
knn = self.nearest[:, 1] # knn为k个近邻结点
belong = [] # 记录k个近邻结点的分类
for i in knn:
belong.append(i.data[-1])
b = max(set(belong), key=belong.count) # 找到测试点所属的分类
return self.nearest, b
def show_train():
plt.scatter(x0[:, 0], x0[:, 1], c='pink', label='[0]')
plt.scatter(x1[:, 0], x1[:, 1], c='orange', label='[1]')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
if __name__ == "__main__":
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
data = np.array(df.iloc[:100, [0, 1, -1]])
train, test = train_test_split(data, test_size=0.1)
x0 = np.array([x0 for i, x0 in enumerate(train) if train[i][-1] == 0])
x1 = np.array([x1 for i, x1 in enumerate(train) if train[i][-1] == 1])
kdt = KdTree()
kdt.create(train)
kdt.preOrder(kdt.KdTree)
score = 0
for x in test:
show_train()
plt.scatter(x[0], x[1], c='red', marker='x') # 测试点
near, belong = kdt.search(x[:-1], 5) # 设置临近点的个数
if belong == x[-1]:
score += 1
print(x, "predict:", belong)
print("nearest:")
for n in near:
print(n[1].data, "dist:", n[0])
plt.scatter(n[1].data[0], n[1].data[1], c='green', marker='+') # k个最近邻点
plt.legend()
plt.show()
score /= len(test)
print("score:", score)
Отказ от ответственности: эта статья является моими учебными заметками, см.:zhuanlan.zhihu.com/p/23966698
Если вы сочтете это полезным, пожалуйста, обратите внимание на мой официальный аккаунт. Время от времени я буду публиковать свои собственные учебные заметки, материалы по ИИ и идеи. Пожалуйста, оставьте сообщение и исследуйте путь ИИ вместе с вами.