K-Nearest Neighbor (K-NN)
Contents
- 1. The k-Nearest Neighbor Algorithm
- 2. The k-Nearest Neighbor Model
- 2.1 Model
- 2.2 Distance Metrics
- 2.2.1 Distance Computation in Python
- 2.3 Choosing the Value of k
- 2.4 Classification Decision Rule
- 3. Implementation: the kd-Tree
- 3.1 Building a kd-Tree
- Python code
- 3.2 Searching a kd-Tree
- Python code
- 4. KNN Classification of the Iris Dataset
- 4.1 KNN from Scratch
- 4.2 sklearn KNN
- 5. Complete Code
The k-nearest neighbor method (k-NN) is a basic classification and regression method.
- Input: the feature vector of an instance, corresponding to a point in feature space
- Output: the class of the instance, which may be one of several classes
- Assumption: a training data set is given in which the class of every instance is already determined.
- Classification: a new instance is predicted by majority vote (or a similar rule) over the classes of its k nearest training instances. Hence k-NN has no explicit learning step.
- In effect, k-NN uses the training data set to partition the feature space, and this partition serves as its classification "model".

The k-nearest neighbor method was proposed by Cover and Hart in 1968.
1. The k-Nearest Neighbor Algorithm
Input: a training data set with feature vectors $x_i$ and class labels $y_i$, and a query instance with feature vector $x$
Output: the class $y$ of instance $x$

$$y = \arg\max_{c_j} \sum_{x_i \in N_k(x)} I(y_i = c_j), \quad i = 1, 2, \dots, N,\ j = 1, 2, \dots, K$$

Here $N_k(x)$ is the neighborhood consisting of the $k$ training points nearest to $x$, and $I$ is the indicator function: $I = 1$ when $y_i = c_j$, and $I = 0$ otherwise.
The special case $k = 1$ is called the nearest-neighbor algorithm: the query is simply assigned the class of its single closest training point.
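The decision rule above can be sketched as a brute-force implementation (the function name and the toy data below are illustrative, not from the original post):

```python
from collections import Counter

def knn_predict(X_train, y_train, x, k=3):
    """Predict the class of x by majority vote among its k nearest training points."""
    # squared Euclidean distance from x to every training point
    dists = [sum((a - b) ** 2 for a, b in zip(xi, x)) for xi in X_train]
    # indices of the k smallest distances: the neighborhood N_k(x)
    neighbors = sorted(range(len(dists)), key=lambda i: dists[i])[:k]
    # majority vote: the class c_j maximizing the count of I(y_i = c_j)
    return Counter(y_train[i] for i in neighbors).most_common(1)[0][0]

X_train = [[1, 1], [1, 2], [5, 5], [6, 5], [5, 6]]
y_train = [0, 0, 1, 1, 1]
print(knn_predict(X_train, y_train, [1.5, 1.5], k=3))  # → 0
```

With k=3, the neighborhood of [1.5, 1.5] contains both class-0 points and one class-1 point, so the vote returns 0.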
2. The k-Nearest Neighbor Model
Three elements: the choice of k, the distance metric, and the classification decision rule.
2.1 Model
- Once the three elements are fixed, the class of any new input instance is uniquely determined.
- This amounts to partitioning the feature space into subregions according to these elements and fixing the class of every point in each subregion. The nearest-neighbor algorithm makes this fact especially clear.
2.2 Distance Metrics
The distance between two points in feature space reflects how similar the two instances are.
- $L_p$ distance: for $n$-dimensional features $x_i$,
$$L_p(x_i, x_j) = \left( \sum_{l=1}^{n} |x_i^{(l)} - x_j^{(l)}|^p \right)^{\frac{1}{p}}$$
- Euclidean distance ($p = 2$): $L_2(x_i, x_j) = \left( \sum_{l=1}^{n} |x_i^{(l)} - x_j^{(l)}|^2 \right)^{\frac{1}{2}}$
- Manhattan distance ($p = 1$): $L_1(x_i, x_j) = \sum_{l=1}^{n} |x_i^{(l)} - x_j^{(l)}|$
- Chebyshev distance ($p = \infty$): the maximum coordinate difference, $L_\infty(x_i, x_j) = \max_l |x_i^{(l)} - x_j^{(l)}|$
2.2.1 Distance Computation in Python

```python
import math

def L_p(xi, xj, p=2):
    # L_p distance between two equal-length vectors
    if len(xi) == len(xj) and len(xi) > 0:
        s = 0
        for i in range(len(xi)):
            s += math.pow(abs(xi[i] - xj[i]), p)
        return math.pow(s, 1 / p)
    else:
        return 0

x1 = [1, 1]
x2 = [5, 1]
x3 = [4, 4]
X = [x1, x2, x3]
for i in range(len(X)):
    for j in range(i + 1, len(X)):
        for p in range(1, 5):
            print("L%d distance between x%d and x%d: %.2f" % (p, i + 1, j + 1, L_p(X[i], X[j], p)))
```

Output:

```
L1 distance between x1 and x2: 4.00
L2 distance between x1 and x2: 4.00
L3 distance between x1 and x2: 4.00
L4 distance between x1 and x2: 4.00
L1 distance between x1 and x3: 6.00
L2 distance between x1 and x3: 4.24
L3 distance between x1 and x3: 3.78
L4 distance between x1 and x3: 3.57
L1 distance between x2 and x3: 4.00
L2 distance between x2 and x3: 3.16
L3 distance between x2 and x3: 3.04
L4 distance between x2 and x3: 3.01
```

2.3 Choosing the Value of k
- The choice of k has a major effect on the results of k-NN.
- Choosing a small k amounts to predicting with a small neighborhood of training instances. The approximation error of "learning" decreases, since only training instances close to (similar to) the input influence the prediction. The drawback is that the estimation error increases: the prediction becomes very sensitive to the particular neighboring points.
- If a neighboring point happens to be noise, the prediction will simply be wrong. In other words, decreasing k makes the overall model more complex and prone to overfitting.
- Choosing a large k amounts to predicting with a large neighborhood. The advantage is a smaller estimation error; the drawback is a larger approximation error, because training instances far from (dissimilar to) the input also influence the prediction and can make it wrong.
- Increasing k makes the overall model simpler.
- If k = N, then whatever the input, it is simply assigned the most frequent class in the training set. Such a model is far too simple and throws away a great deal of useful information; it is not advisable.
- In practice, k is usually set to a fairly small value, and the optimal k is typically selected by cross-validation.
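That selection procedure can be sketched with scikit-learn's cross-validation utilities (the fold count and the candidate range of k below are illustrative choices, not from the original):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

scores = {}
for k in range(1, 26):
    clf = KNeighborsClassifier(n_neighbors=k)
    # mean accuracy over 5 cross-validation folds
    scores[k] = cross_val_score(clf, X, y, cv=5).mean()

best_k = max(scores, key=scores.get)
print(best_k, scores[best_k])
```

The k with the best held-out accuracy is chosen; note that very small k tends to overfit and very large k underfits, exactly as described above.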
2.4 Classification Decision Rule
- Majority voting rule (majority voting rule)
Assume the 0-1 loss. If the neighborhood $N_k(x)$ of $x$ is classified as $c_j$, the misclassification rate is
$$\frac{1}{k} \sum_{x_i \in N_k(x)} I(y_i \neq c_j) = 1 - \frac{1}{k} \sum_{x_i \in N_k(x)} I(y_i = c_j)$$
To minimize the misclassification rate, maximize $\sum_{x_i \in N_k(x)} I(y_i = c_j)$, i.e., pick the majority class (this is empirical risk minimization).
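In code, the vote and the resulting misclassification rate over a neighborhood look like this (the labels are made up for illustration):

```python
from collections import Counter

neighbor_labels = ['a', 'b', 'a', 'a', 'b']  # labels y_i inside N_k(x), here k = 5

counts = Counter(neighbor_labels)         # count of I(y_i = c_j) for each class c_j
prediction = counts.most_common(1)[0][0]  # argmax over c_j
error_rate = 1 - counts[prediction] / len(neighbor_labels)

print(prediction, error_rate)  # → a 0.4
```

Any other choice of class would misclassify at least as many of the 5 neighbors, which is why the majority vote minimizes the empirical risk.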
3. Implementation: the kd-Tree
- A naive implementation must compute the distance between the query and a large number of training points — $O(n^2)$ comparisons over the data set — which is too slow when the training set is large.
- Storing the training data in a special structure, such as a kd-tree, reduces the number of distance computations.
3.1 Building a kd-Tree
A kd-tree is a tree data structure that stores points in k-dimensional space for fast retrieval.
- A kd-tree is a binary tree representing a partition of k-dimensional space.
- Building a kd-tree amounts to repeatedly splitting k-dimensional space with hyperplanes perpendicular to the coordinate axes, producing a series of k-dimensional hyperrectangular regions.
- Each node of the kd-tree corresponds to one k-dimensional hyperrectangle.

Construction procedure:
- Root: the root corresponds to the hyperrectangle containing all instance points; child nodes are generated by recursively splitting the space.
- At each node, pick a coordinate axis and a split point on that axis; the resulting hyperplane splits the current hyperrectangle into left and right subregions (child nodes).
- The instances fall into the two subregions. The recursion stops when a subregion contains no instances (the terminating nodes are leaves). Along the way, each instance is stored at its corresponding node.
Python code

```python
class KdNode():
    def __init__(self, dom_elt, split, left, right):
        self.dom_elt = dom_elt  # a k-dimensional sample point
        self.split = split      # index of the splitting dimension
        self.left = left        # kd-tree of the left half-space
        self.right = right      # kd-tree of the right half-space


class KdTree():
    def __init__(self, data):
        k = len(data[0])  # dimensionality of the samples

        def CreatNode(split, data_set):
            if not data_set:
                return None
            data_set.sort(key=lambda x: x[split])
            split_pos = len(data_set) // 2  # integer division: median position
            median = data_set[split_pos]
            split_next = (split + 1) % k
            return KdNode(median, split,
                          CreatNode(split_next, data_set[:split_pos]),
                          CreatNode(split_next, data_set[split_pos + 1:]))

        self.root = CreatNode(0, data)

    def preorder(self, root):
        if root:
            print(root.dom_elt)
            if root.left:
                self.preorder(root.left)
            if root.right:
                self.preorder(root.right)


data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
kd = KdTree(data)
kd.preorder(kd.root)
```

Output:

```
[7, 2]
[5, 4]
[2, 3]
[4, 7]
[9, 6]
[8, 1]
```

3.2 Searching a kd-Tree
給定目標(biāo)點(diǎn),搜索其最近鄰。
- 先找到包含目標(biāo)點(diǎn)的葉結(jié)點(diǎn)
- 從該葉結(jié)點(diǎn)出發(fā),依次回退到父結(jié)點(diǎn);不斷查找與目標(biāo)點(diǎn)最鄰近的結(jié)點(diǎn)
- 當(dāng)確定不可能存在更近的結(jié)點(diǎn)時(shí)終止。
- 這樣搜索就被限制在空間的局部區(qū)域上,效率大為提高。
- 目標(biāo)點(diǎn)的最近鄰一定在以目標(biāo)點(diǎn)為中心并通過(guò)當(dāng)前最近點(diǎn)的超球體的內(nèi)部。
- 然后返回當(dāng)前結(jié)點(diǎn)的父結(jié)點(diǎn),如果父結(jié)點(diǎn)的另一子結(jié)點(diǎn)的超矩形區(qū)域與超球體相交,那么在相交的區(qū)域內(nèi)尋找與目標(biāo)點(diǎn)更近的實(shí)例點(diǎn)。
- 如果存在這樣的點(diǎn),將此點(diǎn)作為新的當(dāng)前最近點(diǎn)。算法轉(zhuǎn)到更上一級(jí)的父結(jié)點(diǎn),繼續(xù)上述過(guò)程。
- 如果父結(jié)點(diǎn)的另一子結(jié)點(diǎn)的超矩形區(qū)域與超球體不相交,或不存在比當(dāng)前最近點(diǎn)更近的點(diǎn),則停止搜索。
Python code

```python
from collections import namedtuple
import numpy as np

# namedtuple holding the nearest point found, its distance, and the number of nodes visited
result = namedtuple("Result_tuple", "nearest_point nearest_dist nodes_visited")


def find_nearest(tree, point):
    k = len(point)  # dimensionality of the data

    def travel(kd_node, target, max_dist):
        if kd_node is None:
            # float("inf") / float("-inf") are Python's positive/negative infinity
            return result([0] * k, float("inf"), 0)

        nodes_visited = 1
        s = kd_node.split        # splitting dimension
        pivot = kd_node.dom_elt  # splitting point

        if target[s] <= pivot[s]:  # the target's s-th coordinate is left of the split: go left
            nearer_node = kd_node.left    # visit the left subtree next
            further_node = kd_node.right  # remember the right subtree
        else:                      # the target is closer to the right subtree
            nearer_node = kd_node.right
            further_node = kd_node.left

        temp1 = travel(nearer_node, target, max_dist)  # descend into the region containing the target
        nearest = temp1.nearest_point  # take that leaf as the "current nearest point"
        dist = temp1.nearest_dist      # current nearest distance
        nodes_visited += temp1.nodes_visited

        if dist < max_dist:
            max_dist = dist  # the nearest point lies inside a hypersphere of radius max_dist

        temp_dist = abs(pivot[s] - target[s])  # distance from the target to the splitting hyperplane
        if max_dist < temp_dist:  # the hypersphere does not intersect the hyperplane
            return result(nearest, dist, nodes_visited)  # no need to check the other side

        # Euclidean distance between the target and the splitting point
        temp_dist = np.linalg.norm(np.array(pivot) - np.array(target))
        if temp_dist < dist:  # the splitting point is closer
            nearest = pivot   # update the nearest point
            dist = temp_dist  # update the nearest distance
            max_dist = dist   # shrink the hypersphere radius

        # check whether the other child's region contains a closer point
        temp2 = travel(further_node, target, max_dist)
        nodes_visited += temp2.nodes_visited
        if temp2.nearest_dist < dist:  # the other subtree contains a closer point
            nearest = temp2.nearest_point
            dist = temp2.nearest_dist

        return result(nearest, dist, nodes_visited)

    return travel(tree.root, point, float("inf"))  # recurse from the root
```

```python
from time import time
from random import random


def random_point(k):
    return [random() for _ in range(k)]


def random_points(k, n):
    return [random_point(k) for _ in range(n)]


ret = find_nearest(kd, [3, 4.5])
print(ret)

N = 400000
t0 = time()
kd2 = KdTree(random_points(3, N))  # 400,000 3-dimensional points with coordinates in [0, 1)
ret2 = find_nearest(kd2, [0.1, 0.5, 0.8])
t1 = time()
print("time: ", t1 - t0, " s")
print(ret2)
```

Result: with 400,000 points, the search finds the nearest neighbor in about 4 s:

```
Result_tuple(nearest_point=[2, 3], nearest_dist=1.8027756377319946, nodes_visited=4)
time:  4.314465284347534  s
Result_tuple(nearest_point=[0.10186986970329936, 0.5007753108096316, 0.7998708312483109], nearest_dist=0.002028350099282986, nodes_visited=49)
```

4. KNN Classification of the Iris Dataset
4.1 KNN from Scratch
```python
# -*- coding:utf-8 -*-
# @Python Version: 3.7
# @Time: 2020/3/2 22:44
# @Author: Michael Ming
# @Website: https://michael.blog.csdn.net/
# @File: 3.KNearestNeighbors.py
# @Reference: https://github.com/fengdu78/lihang-code
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from collections import Counter


class KNearNeighbors():
    def __init__(self, X_train, y_train, neighbors=3, p=2):
        self.n = neighbors
        self.p = p
        self.X_train = X_train
        self.y_train = y_train

    def predict(self, X):
        knn_list = []
        # seed the candidate list with the first n training points
        for i in range(self.n):
            dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
            knn_list.append((dist, self.y_train[i]))
        # scan the remaining points; replace the current farthest candidate whenever a closer point appears
        for i in range(self.n, len(self.X_train)):
            max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))
            dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
            if knn_list[max_index][0] > dist:
                knn_list[max_index] = (dist, self.y_train[i])
        # labels of the n nearest neighbors
        knn = [k[-1] for k in knn_list]
        count_pairs = Counter(knn)
        # most frequent label: last (label, count) pair after sorting by count
        max_count = sorted(count_pairs.items(), key=lambda x: x[1])[-1][0]
        return max_count

    def score(self, X_test, y_test):
        right_count = 0
        for X, y in zip(X_test, y_test):  # zip iterates several sequences in parallel
            label = self.predict(X)
            if math.isclose(label, y, rel_tol=1e-5):  # float equality test
                right_count += 1
        print("Accuracy: %.4f" % (right_count / len(X_test)))
        return right_count / len(X_test)


if __name__ == '__main__':
    # --------- iris k-NN ----------------
    iris = load_iris()
    df = pd.DataFrame(iris.data, columns=iris.feature_names)
    df['label'] = iris.target
    plt.scatter(df[:50][iris.feature_names[0]], df[:50][iris.feature_names[1]],
                label=iris.target_names[0])
    plt.scatter(df[50:100][iris.feature_names[0]], df[50:100][iris.feature_names[1]],
                label=iris.target_names[1])
    plt.xlabel(iris.feature_names[0])
    plt.ylabel(iris.feature_names[1])
    data = np.array(df.iloc[:100, [0, 1, -1]])  # first two species, first two features
    X, y = data[:, :-1], data[:, -1]
    # hold out 20% of the data for testing
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    # KNN with 20 neighbors and the L2 distance
    clf = KNearNeighbors(X_train, y_train, 20, 2)
    # predict the test points and report the accuracy
    clf.score(X_test, y_test)
    # classify an arbitrary new point with KNN
    test_point = [4.75, 2.75]
    test_point_flower = 'test point ' + iris.target_names[int(clf.predict(test_point))]
    print("The class of the test point is: %s" % test_point_flower)
    plt.plot(test_point[0], test_point[1], 'bx', label=test_point_flower)
    plt.rcParams['font.sans-serif'] = 'SimHei'  # font able to render CJK labels
    plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly
    plt.legend()
    plt.show()
```

Output:

```
Accuracy: 1.0000
The class of the test point is: test point setosa
```

4.2 sklearn KNN
sklearn.neighbors.KNeighborsClassifier

```python
class sklearn.neighbors.KNeighborsClassifier(n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30, p=2, metric='minkowski', metric_params=None, n_jobs=None, **kwargs)
```

- n_neighbors: number of neighbors
- p: power parameter of the Minkowski distance metric
- algorithm: nearest-neighbor search algorithm, one of {'auto', 'ball_tree', 'kd_tree', 'brute'}
- weights: weight function applied to the neighbors
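A minimal usage sketch of this classifier on the iris data (the parameter values and `random_state` below are illustrative choices):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

clf = KNeighborsClassifier(n_neighbors=5, p=2, algorithm='kd_tree')
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))  # mean accuracy on the held-out 20%
print(clf.predict(X_test[:3]))    # predicted classes of the first three test samples
```

With `algorithm='kd_tree'`, sklearn builds the same kind of structure as Section 3 internally, so fitting is cheap and queries avoid a brute-force scan.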
5. Complete Code
```python
# -*- coding:utf-8 -*-
# @Python Version: 3.7
# @Time: 2020/3/2 22:44
# @Author: Michael Ming
# @Website: https://michael.blog.csdn.net/
# @File: 3.KNearestNeighbors.py
# @Reference: https://github.com/fengdu78/lihang-code
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from collections import Counter, namedtuple
import time


def L_p(xi, xj, p=2):
    # L_p distance between two equal-length vectors
    if len(xi) == len(xj) and len(xi) > 0:
        s = 0
        for i in range(len(xi)):
            s += math.pow(abs(xi[i] - xj[i]), p)
        return math.pow(s, 1 / p)
    else:
        return 0


class KNearNeighbors():
    def __init__(self, X_train, y_train, neighbors=3, p=2):
        self.n = neighbors
        self.p = p
        self.X_train = X_train
        self.y_train = y_train

    def predict(self, X):
        knn_list = []
        # seed the candidate list with the first n training points
        for i in range(self.n):
            dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
            knn_list.append((dist, self.y_train[i]))
        # scan the remaining points; replace the current farthest candidate whenever a closer point appears
        for i in range(self.n, len(self.X_train)):
            max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))
            dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
            if knn_list[max_index][0] > dist:
                knn_list[max_index] = (dist, self.y_train[i])
        # labels of the n nearest neighbors
        knn = [k[-1] for k in knn_list]
        count_pairs = Counter(knn)
        # most frequent label: last (label, count) pair after sorting by count
        max_count = sorted(count_pairs.items(), key=lambda x: x[1])[-1][0]
        return max_count

    def score(self, X_test, y_test):
        right_count = 0
        for X, y in zip(X_test, y_test):  # zip iterates several sequences in parallel
            label = self.predict(X)
            if math.isclose(label, y, rel_tol=1e-5):  # float equality test
                right_count += 1
        print("Accuracy: %.4f" % (right_count / len(X_test)))
        return right_count / len(X_test)


class KdNode():
    def __init__(self, dom_elt, split, left, right):
        self.dom_elt = dom_elt  # a k-dimensional sample point
        self.split = split      # index of the splitting dimension
        self.left = left        # kd-tree of the left half-space
        self.right = right      # kd-tree of the right half-space


class KdTree():
    def __init__(self, data):
        k = len(data[0])  # dimensionality of the samples

        def CreatNode(split, data_set):
            if not data_set:
                return None
            data_set.sort(key=lambda x: x[split])
            split_pos = len(data_set) // 2  # integer division: median position
            median = data_set[split_pos]
            split_next = (split + 1) % k
            return KdNode(median, split,
                          CreatNode(split_next, data_set[:split_pos]),
                          CreatNode(split_next, data_set[split_pos + 1:]))

        self.root = CreatNode(0, data)

    def preorder(self, root):
        if root:
            print(root.dom_elt)
            if root.left:
                self.preorder(root.left)
            if root.right:
                self.preorder(root.right)


# namedtuple holding the nearest point found, its distance, and the number of nodes visited
result = namedtuple("Result_tuple", "nearest_point nearest_dist nodes_visited")


def find_nearest(tree, point):
    k = len(point)  # dimensionality of the data

    def travel(kd_node, target, max_dist):
        if kd_node is None:
            # float("inf") / float("-inf") are Python's positive/negative infinity
            return result([0] * k, float("inf"), 0)

        nodes_visited = 1
        s = kd_node.split        # splitting dimension
        pivot = kd_node.dom_elt  # splitting point

        if target[s] <= pivot[s]:  # the target is closer to the left subtree
            nearer_node = kd_node.left
            further_node = kd_node.right
        else:                      # the target is closer to the right subtree
            nearer_node = kd_node.right
            further_node = kd_node.left

        temp1 = travel(nearer_node, target, max_dist)  # descend into the region containing the target
        nearest = temp1.nearest_point  # take that leaf as the "current nearest point"
        dist = temp1.nearest_dist      # current nearest distance
        nodes_visited += temp1.nodes_visited

        if dist < max_dist:
            max_dist = dist  # the nearest point lies inside a hypersphere of radius max_dist

        temp_dist = abs(pivot[s] - target[s])  # distance from the target to the splitting hyperplane
        if max_dist < temp_dist:  # the hypersphere does not intersect the hyperplane
            return result(nearest, dist, nodes_visited)  # no need to check the other side

        # Euclidean distance between the target and the splitting point
        temp_dist = np.linalg.norm(np.array(pivot) - np.array(target))
        if temp_dist < dist:  # the splitting point is closer
            nearest = pivot
            dist = temp_dist
            max_dist = dist   # shrink the hypersphere radius

        # check whether the other child's region contains a closer point
        temp2 = travel(further_node, target, max_dist)
        nodes_visited += temp2.nodes_visited
        if temp2.nearest_dist < dist:
            nearest = temp2.nearest_point
            dist = temp2.nearest_dist

        return result(nearest, dist, nodes_visited)

    return travel(tree.root, point, float("inf"))  # recurse from the root


if __name__ == '__main__':
    # --------- distance computations ----------------
    x1 = [1, 1]
    x2 = [5, 1]
    x3 = [4, 4]
    X = [x1, x2, x3]
    for i in range(len(X)):
        for j in range(i + 1, len(X)):
            for p in range(1, 5):
                print("L%d distance between x%d and x%d: %.2f" % (p, i + 1, j + 1, L_p(X[i], X[j], p)))
    # --------- iris k-NN ----------------
    iris = load_iris()
    df = pd.DataFrame(iris.data, columns=iris.feature_names)
    df['label'] = iris.target
    plt.scatter(df[:50][iris.feature_names[0]], df[:50][iris.feature_names[1]],
                label=iris.target_names[0])
    plt.scatter(df[50:100][iris.feature_names[0]], df[50:100][iris.feature_names[1]],
                label=iris.target_names[1])
    plt.xlabel(iris.feature_names[0])
    plt.ylabel(iris.feature_names[1])
    data = np.array(df.iloc[:100, [0, 1, -1]])  # first two species, first two features
    X, y = data[:, :-1], data[:, -1]
    # hold out 20% of the data for testing
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    # KNN with 20 neighbors and the L2 distance
    clf = KNearNeighbors(X_train, y_train, 20, 2)
    # predict the test points and report the accuracy
    clf.score(X_test, y_test)
    # classify an arbitrary new point with KNN
    test_point = [4.75, 2.75]
    test_point_flower = 'test point ' + iris.target_names[int(clf.predict(test_point))]
    print("The class of the test point is: %s" % test_point_flower)
    plt.plot(test_point[0], test_point[1], 'bx', label=test_point_flower)
    plt.rcParams['font.sans-serif'] = 'SimHei'  # font able to render CJK labels
    plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly
    plt.legend()
    plt.show()
    # --------- sklearn KNN ----------
    from sklearn.neighbors import KNeighborsClassifier
    clf_skl = KNeighborsClassifier(n_neighbors=50, p=4, algorithm='kd_tree')
    start = time.time()
    total = 0
    for i in range(100):
        clf_skl.fit(X_train, y_train)
        total += clf_skl.score(X_test, y_test)
    end = time.time()
    print("Mean accuracy: %.4f" % (total / 100))
    print("Time taken: %0.4f ms" % (1000 * (end - start) / 100))
    # ------ build the kd-tree --------------
    data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
    kd = KdTree(data)
    kd.preorder(kd.root)
    # ------ search the kd-tree -----------
    from time import time
    from random import random

    def random_point(k):
        return [random() for _ in range(k)]

    def random_points(k, n):
        return [random_point(k) for _ in range(n)]

    ret = find_nearest(kd, [3, 4.5])
    print(ret)
    N = 400000
    t0 = time()
    kd2 = KdTree(random_points(3, N))  # 400,000 3-dimensional points with coordinates in [0, 1)
    ret2 = find_nearest(kd2, [0.1, 0.5, 0.8])
    t1 = time()
    print("time: ", t1 - t0, " s")
    print(ret2)
```