PLA and POCKET_PLA Implementations for Homework 1 of Machine Learning Foundations
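Premise: the data used in this post is data I downloaded and then preprocessed myself, by replacing every separator in the files with a single space. As a result, the load_data method can only load my own version of the data; to load the original data you need to write your own load_data method.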
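For readers who want to skip the preprocessing step, here is a minimal sketch of a loader for unprocessed data. It is not part of the original project: the name load_raw_data is hypothetical, and it assumes the raw file separates fields with some mix of spaces and tabs and puts the label in the last column.

```python
import numpy as np


def load_raw_data(file_name):
    """Hypothetical loader: whitespace-separated rows, features first, label last."""
    x, y = [], []
    with open(file_name, "r") as f:
        for line in f:
            fields = line.split()  # no argument: splits on spaces and tabs alike
            if not fields:         # skip blank lines
                continue
            # Prepend the constant feature x0 = 1, as load_data in util.py does.
            x.append([1.0] + [float(val) for val in fields[:-1]])
            y.append([int(fields[-1])])
    return np.array(x), np.array(y)
```

It returns arrays in the same layout as load_data in util.py, so it should be usable as a drop-in replacement if the assumption about the file format holds.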
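Both algorithms hinge on deciding whether the current w makes a mistake on a sample point x, and there are two ways to test for a mistake. The first treats y * (w·x) <= 0 as a mistake; the second treats sign(w·x) != y as a mistake. The choice has essentially no effect on the training result. The only difference I noticed is when w is initialized to 0: under the first version every sample counts as a mistake, whereas under the second version, with Professor Lin's convention that sign(0) is -1, only samples labeled +1 count as mistakes. I chose the second version because I think it matches Professor Lin's definition of sign more closely.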
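The difference between the two tests at w = 0 can be checked with a small sketch. The helper names mistake_by_product and mistake_by_sign and the sample values are made up for illustration; sign follows the convention used in this post, where sign(0) is -1.

```python
import numpy as np


def sign(value):
    """Sign convention used in this post: sign(0) is -1."""
    return 1 if value > 0 else -1


def mistake_by_product(w, x, y):
    """Version 1: a mistake whenever y * (w . x) <= 0."""
    return bool(y * np.dot(w, x) <= 0)


def mistake_by_sign(w, x, y):
    """Version 2: a mistake whenever sign(w . x) != y."""
    return sign(np.dot(w, x)) != y


w0 = np.zeros(3)                # all-zero initial weights
x = np.array([1.0, 2.0, -1.0])  # made-up sample; the leading 1 is the bias term

for y in (1, -1):
    print(y, mistake_by_product(w0, x, y), mistake_by_sign(w0, x, y))
# prints:
# 1 True True
# -1 True False
```

With zero weights, version 1 flags both labels, while version 2 flags only the sample labeled +1, so the first update of a run can differ between the two versions.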
Code:
Contents of util.py:
    # -*- coding:utf-8 -*-
    # Author: Evan Mi
    import numpy as np


    def load_data(file_name):
        x = []
        y = []
        # Read-only access is enough here (the original opened the file with 'r+').
        with open(file_name, 'r') as f:
            for line in f:
                line = line.rstrip("\n")
                temp = line.split(" ")
                temp.insert(0, '1')  # prepend the constant feature x0 = 1
                x_temp = [float(val) for val in temp[:-1]]
                y_tem = [int(val) for val in temp[-1:]]
                x.append(x_temp)
                y.append(y_tem)
        nx = np.array(x)
        ny = np.array(y)
        return nx, ny


    def sign(value):
        # Professor Lin's convention: sign(0) is -1
        if value > 0:
            return 1
        else:
            return -1

Contents of pla.py (the main block here only runs the random-shuffle test required by the homework):
    # -*- coding:utf-8 -*-
    # Author: Evan Mi
    import numpy as np
    from pla_and_pocket_pla import util


    def pla(nx, ny, rate=1):
        """
        :param nx: feature matrix, shaped like [[...], [...], [...]]
        :param ny: labels, shaped like [[.], [.], [.]]
        :param rate: learning rate, default 1
        :return: number of updates
        """
        total_update_nums = 0
        total_train_example_nums = np.size(nx, 0)
        # Number of consecutive samples classified correctly; the loop ends
        # when continue_right_nums == total_train_example_nums
        continue_right_nums = 0
        # Initialize the weights to 0 (the original hard-coded the width as 5)
        w = np.zeros((1, np.size(nx, 1)))
        loop_index = 0
        while True:
            this_x = nx[loop_index]
            result = util.sign(np.dot(this_x, w.T)[0])
            this_y = ny[loop_index, 0]
            if result == this_y:
                continue_right_nums += 1
            else:
                continue_right_nums = 0
                w = w + rate * (this_x * this_y)
                total_update_nums += 1
            loop_index = (loop_index + 1) % total_train_example_nums
            if continue_right_nums == total_train_example_nums:
                break
        return total_update_nums


    if __name__ == '__main__':
        # Randomly shuffle the samples and run PLA 1000 times
        # with a learning rate of 0.5
        out_nx, out_ny = util.load_data("data/data.txt")
        avg = 0
        for i in range(1000):
            shuffle_index = np.arange(0, np.size(out_nx, 0))
            np.random.shuffle(shuffle_index)
            shuffled_x = out_nx[shuffle_index]
            shuffled_y = out_ny[shuffle_index]
            result_out = pla(shuffled_x, shuffled_y, 0.5)
            print("Run %d: %d updates" % ((i + 1), result_out))
            avg = avg + (1.0 / (i + 1)) * (result_out - avg)  # running mean
        print("Average number of updates: %d" % avg)

Contents of pocket_pla.py (likewise, the main block here only runs part of the tests required by the homework):
    # -*- coding:utf-8 -*-
    # Author: Evan Mi
    import numpy as np
    from pla_and_pocket_pla import util


    def error_counter(x, y, w):
        # Fraction of samples in (x, y) that w misclassifies
        result = np.where(x.dot(w[0].T) > 0, 1, -1)
        compare_result = np.where(result == y.T[0], 0, 1)
        return (1.0 * np.sum(compare_result)) / np.size(y, 0)


    def pocket_pla(nx, ny, rate=1, max_iter=50):
        """
        :param nx: feature matrix, shaped like [[...], [...], [...]]
        :param ny: labels, shaped like [[.], [.], [.]]
        :param rate: learning rate, default 1
        :param max_iter: maximum number of updates, default 50
        :return: w_pocket and w
        """
        total_update_nums = 0
        total_train_example_nums = np.size(nx, 0)
        # w_pocket is the peach kept in the pocket: it watches w change and,
        # as soon as w is better than itself, w goes into the pocket
        w_pocket = np.zeros((1, np.size(nx, 1)))
        w = np.zeros((1, np.size(nx, 1)))  # initialize the weights to 0
        while True:
            rand_index = np.random.randint(0, total_train_example_nums)
            this_x = nx[rand_index]
            result = util.sign(np.dot(this_x, w.T)[0])
            this_y = ny[rand_index, 0]
            if int(result) != int(this_y):
                w = w + rate * (this_x * this_y)
                total_update_nums += 1
                if error_counter(nx, ny, w) < error_counter(nx, ny, w_pocket):
                    w_pocket = w
            if total_update_nums == max_iter:
                break
        return w_pocket, w


    if __name__ == '__main__':
        x_train, y_train = util.load_data("data/train.txt")
        x_test, y_test = util.load_data("data/test.txt")
        avg_pocket = 0
        avg = 0
        for index in range(2000):
            w_out_pocket, w_out = pocket_pla(x_train, y_train, max_iter=100)
            error_out_pocket = error_counter(x_test, y_test, w_out_pocket)
            error_out = error_counter(x_test, y_test, w_out)
            # Running means of the test error for w_pocket and for the final w
            avg_pocket = avg_pocket + (1.0 / (index + 1)) * (error_out_pocket - avg_pocket)
            avg = avg + (1.0 / (index + 1)) * (error_out - avg)
        print(avg_pocket)
        print(avg)

For the full project code and the data it uses, see: PLA and POCKET_PLA