[Stove AI] Построение модели классификатора машинного обучения 030-KNN
(Библиотеки Python и номера версий, используемые в этой статье: Python 3.6, Numpy 1.14, scikit-learn 0.19, matplotlib 2.2)
KNN (K-ближайших соседей) — это алгоритм, который использует обучающий набор данных K ближайших соседей для поиска неизвестной классификации объектов. Его основная основная идея была представлена в моей последней статье.
1. Подготовьте набор данных
Моя подготовка набора данных здесь включает в себя загрузку данных и визуализацию данных.Эта часть относительно проста.Она использовалась много раз в предыдущих статьях, и я могу непосредственно посмотреть на диаграмму распределения данных.
2. Создайте модель классификатора KNN.
2.1 Построение и обучение модели классификатора KNN
Метод построения модели классификатора KNN аналогичен методу SVM и RandomForest, код выглядит следующим образом:
# 构建KNN分类模型
from sklearn.neighbors import KNeighborsClassifier
K=10 # 暂定10个最近样本
KNN=KNeighborsClassifier(K,weights='distance')
KNN.fit(dataset_X,dataset_y) # 使用该数据集训练模型
Модель KNN была обучена с использованием приведенного выше набора данных, но как мы узнаем эффект обучения модели? Классификационное влияние модели классификации на набор обучающих данных показано ниже, С граничной точки зрения классификатор четко различает этот набор данных.
2.1 Используйте обученный классификатор KNN для прогнозирования новых выборок
Перейдите непосредственно к коду:
# 用训练好的KNN模型预测新样本
new_sample=np.array([[4.5,3.6]])
predicted=KNN.predict(new_sample)[0]
print("KNN predicted:{}".format(predicted))
Полученный результат равен 2, что указывает на принадлежность нового образца к классу 2.
Ниже мы наносим этот новый образец на график, чтобы увидеть, где он находится на графике.
Чтобы отобразить положение нового образца и K образцов вокруг него, я изменил приведенную выше функцию plot_classifier следующим образом:
# 为了查看新样本在原数据集中的位置,也为了查看新样本周围最近的K个样本位置,
# 我修改了上面的plot_classifier函数,如下所示:
def plot_classifier2(KNN_classifier, X, y,new_sample,K):
x_min, x_max = min(X[:, 0]) - 1.0, max(X[:, 0]) + 1.0 # 计算图中坐标的范围
y_min, y_max = min(X[:, 1]) - 1.0, max(X[:, 1]) + 1.0
step_size = 0.01 # 设置step size
x_values, y_values = np.meshgrid(np.arange(x_min, x_max, step_size),
np.arange(y_min, y_max, step_size))
# 构建网格数据
mesh_output = KNN_classifier.predict(np.c_[x_values.ravel(), y_values.ravel()])
mesh_output = mesh_output.reshape(x_values.shape)
plt.figure()
plt.pcolormesh(x_values, y_values, mesh_output, cmap=plt.cm.gray)
plt.scatter(X[:, 0], X[:, 1], c=y, s=80, edgecolors='black',
linewidth=1, cmap=plt.cm.Paired)
# 绘制新样本所在的位置
plt.scatter(new_sample[:,0],new_sample[:,1],marker='*',color='red')
# 绘制新样本周围最近的K个样本,只适用于KNN
# Extract k nearest neighbors
dist, indices = KNN_classifier.kneighbors(new_sample)
plt.scatter(dataset_X[indices][0][:][:,0],dataset_X[indices][0][:][:,1],
marker='x',s=80,color='r')
# specify the boundaries of the figure
plt.xlim(x_values.min(), x_values.max())
plt.ylim(y_values.min(), y_values.max())
# specify the ticks on the X and Y axes
plt.xticks((np.arange(int(min(X[:, 0])), int(max(X[:, 0])), 1.0)))
plt.yticks((np.arange(int(min(X[:, 1])), int(max(X[:, 1])), 1.0)))
plt.show()
После подстановки непосредственно в операцию получается результирующий граф:
Как видно из рисунка, красная пятиконечная звезда — это наша новая выборка, а красный крест — ее ближайшие K соседей. Видно, что большинство этих соседей относятся ко второй категории, поэтому новая выборка также разбивается на вторую аналогию, и результат, полученный при прогнозировании, также равен 2.
########################резюме########################## ######
1. Создать и обучить классификатор KNN очень просто, просто импортируйте KNNClassifier с помощью sklearn, а затем используйте функцию fit().
2. Классификатор KNN хранит все доступные точки данных обучающего набора.Когда необходимо спрогнозировать новую точку данных, он сначала вычисляет сходство (то есть расстояние) между новой точкой данных и всеми точками данных, хранящимися внутри, и вычисляет Сортировка, получение K точек данных с ближайшим расстоянием, а затем оценка того, к какой категории относится большинство K точек данных, к какой категории принадлежит новая точка данных. Это также объясняет, почему K обычно принимает нечетное число.Если это четное число, количество точек данных в обеих категориях одинаково, что смущает.
3. Трудность классификатора KNN заключается в том, чтобы найти наиболее подходящее значение K. Это требует повторных попыток с перекрестной проверкой. K с наибольшей точностью или скоростью отзыва используется в качестве наилучшего значения K. Этот процесс также может быть выполнен с помощью GridSearch или RandomSearch.
#################################################################
Примечание. Эта часть кода была загружена в (мой гитхаб), добро пожаловать на скачивание.
Использованная литература:
1. Классические примеры машинного обучения Python, Пратик Джоши, перевод Тао Цзюньцзе и Чена Сяоли.