元学习之《On First-Order Meta-Learning Algorithms》论文详细解读
元學習系列文章
文章目錄
- 引言
- On First-Order Meta-Learning Algorithms
- 偽算法
- 數學過程
- 訓練過程
- 實驗
- 核心代碼
- OpenAI Demo
- 幾點思考
- 參考資料
引言
上一篇博客對論文 MAML 做了詳細解讀,MAML 是元學習方向 optimization based 的開篇之作,還有一篇和 MAML 很像的論文 On First-Order Meta-Learning Algorithms,該論文是大名鼎鼎的 OpenAI 的杰作,OpenAI 對 MAML 做了簡化,但效果卻優于 MAML,具體做了什么簡化操作,請往下看😀。
On First-Order Meta-Learning Algorithms
這篇論文的標題就很針對 MAML,MAML 中有一個重要的特點,就是在求梯度時,為了加速放棄了二階求導,使用一階微分近似進行代替,雖然效果上相差不大,但總感覺少了點什么。這篇論文的標題上來就聲稱我們是一階的 metalearning 方法,而且剛好是在 MAML 發表的下一年(2018)發表在 ICML 會議的,從標題上也是賺慢了噱頭。
還有個有意思的事情,OpenAI 把論文中的算法稱之為 Reptile, 但是也沒有解釋為什么叫這個,論文中也沒看出來和 Reptile 有什么關聯,感興趣的讀者,可以去深究一下。
說了一堆廢話,下面開始進入正題。
偽算法
貼一張論文中的官方算法:
先來解釋一下:
1 首先初始化一個網絡模型的所有參數 ? \phi ?
2 迭代 N 次,進行訓練,每次迭代執行:
- 2.1 隨機抽樣一個任務 T,用網絡模型進行訓練,對應的loss 是 L t L_t Lt?,訓練結束后的參數是 ? ~ \widetilde{\phi} ? ?
- 2.2,在參數 ? \phi ?上使用 SGD 或 Adam 執行K次梯度下降更新,得到 ? ~ = U t k ( ? ) \widetilde{\phi}={U}^{k}_{t}(\phi) ? ?=Utk?(?)
- 2.3 用 ? ~ \widetilde{\phi} ? ?更新網絡模型模型參數, ? = ? + ? ( ? ~ ? ? ) \phi=\phi+\epsilon(\widetilde{\phi}-\phi) ?=?+?(? ???)
3 完成上述N次迭代訓練,則結束整個過程
從上面的算法中可以看出,Reptile 是在每個單獨的任務執行K次訓練后,就開始真正更新網絡模型的參數(Meta),更新方式不是梯度下降,但是和梯度下降公式長得很像,是用上一次的參數 ? \phi ?和K次后的參數 ? ~ \widetilde{\phi} ? ?的差來更新,更新的步長是 ? \epsilon ?。在這個過程中,只有一階求導的計算,就是在任務內部執行K次更新的過程中用到的隨機梯度下降,這也是為什么標題中叫 First-Order 的原因。
從這就可以看出和 MAML 算法的不同了:
這里說的meta參數,就是真正更新網絡模型參數的過程
數學過程
上面只是簡單介紹了 Reptile 的算法思想,下面從數學過程上來理解下它的更新過程,先來設定幾個符號:
? \phi ?代表網絡模型初始參數, ? , η \epsilon,\eta ?,η分別代表 meta 更新的學習率和 task 更新的學習率, N N N是meta訓練的 batch_size,即 meta 的一個bach有 N 個task,每個task內部執行K次訓練,N個任務都訓練完,再來更新meta參數。按照上面的算法過程,meta的一個batch訓練完之后,網絡模型的參數是:
? = ? + ? 1 N ∑ i = 1 N ( ? i ~ ? ? ) = ? + ? ( W ? ? ) \begin{aligned} \phi &= \phi +\epsilon \frac{1}{N}\sum_{i=1}^{N}\left ( \tilde{\phi_i } -\phi\right )\\ &= \phi +\epsilon \left ( W-\phi \right )\\ \end{aligned} ??=?+?N1?i=1∑N?(?i?~???)=?+?(W??)?
其中 W W W是每個任務最后參數的平均值,上述公式再進行展開就是這樣:
假設N=2,K=3,即meta每次訓練的一個batch 有2個task,每個task內部進行3此迭代,則 meta每次更新模型參數的公式為:
訓練過程
上面公式的最后一行,又變成了熟悉的梯度下降,只不過梯度方向是每個任務內部更新的幾次梯度方向的和。meta 模型的參數更新過程,在幾何上就是這樣的:
動圖看的更加清晰些,其中綠色代表第一個任務,三個綠色箭頭代表三次更新時的梯度方向,可以看到,Reptile的模型就是朝著每個任務的梯度和的方向上不斷地進行更新。
還記得 MAML 是怎樣更新的嗎?不記得的話,請翻看上一篇博客。還是同樣的設置,MAML 的更新過程如下:
即 MAML 是在每個任務最后一個梯度的方向上進行更新,而 Reptile 是在每個任務幾個梯度和的方向上進行更新。
實驗
實驗設置和 MAML 論文中的設置一樣,回歸任務以擬合正弦函數為例,分類任務以 MiniImagenet 數據和 omniglot 數據的圖片分類為例,詳細設置就不再贅述了,直接看實驗結果:
上半部分的圖是正弦函數的擬合結果,(b)是MAML的結果,C是Reptile的結果,橘黃色線是微調32次之后的樣子,綠色線是真實分布,可以看到 Reptile和MAML的結果相當,都能擬合到真實分布的樣子,硬要一較高下的話,那就是 Reptile稍好一些。
下半部分圖是在 MiniImagenet 分類數據上的結果,作者也對比了一階近似 MAML和二階MAML的結果,從圖中可以看出,Reptile的準確率至少要高出1個百分點。
在論文中作者還對比了一個有意思的實驗,Reptile 既然可以在 g 1 + g 2 + g 3 g_1+g_2+g_3 g1?+g2?+g3? 的梯度方向上更新,那么如果在其它梯度的組合方向上去更新,結果會怎樣呢?比如 g 1 + g 2 g_1+g_2 g1?+g2? 等方向,作者也針對不同梯度的組合進行了實驗,實驗結果如下:
橫軸是meta迭代次數,縱軸是準確率,不同顏色的曲線代表不同的梯度組合,可以明顯的看到最下面的藍色曲線準確率最低,藍色曲線代表在 g 1 g_1 g1? 第一個梯度方向上去更新,其實就是模型預訓練的過程,以所有訓練任務的 loss 為準進行更新。其他顏色的曲線都代表用若干次之后的 loss 來更新參數,最上面的那條曲線代表 Reptile,即用 g 1 + g 2 + g 3 + g 4 g_1+g_2+g_3+g_4 g1?+g2?+g3?+g4? 的梯度方向進行更新,只使用 g 4 g_4 g4? 的那條曲線代表 MAML。
核心代碼
Reptile 的論文代碼也是開源的,而且代碼很簡介規范,不愧是 OpenAI 出品。建議感興趣的讀者去看下論文源碼,不僅能更好的理解論文思想,對工程能力的提升也很有幫助,包括代碼風格、模塊化、組織架構、邏輯實現等都有很多值得借鑒的地方。關于源代碼有疑問的話,可以私信聯系我。這里只貼一點核心的訓練更新代碼,對應上面的數學過程:
代碼文件見 reptile.py
# 取出網絡模型的最新參數old_vars = self._model_state.export_variables()# 保存一個 meta batch 里,每個 task 更新 K 次后的參數new_vars = []for _ in range(meta_batch_size):# 抽樣出一個 taskmini_dataset = _sample_mini_dataset(dataset, num_classes, num_shots)for batch in _mini_batches(mini_dataset, inner_batch_size, inner_iters, replacement):# task 里面的訓練,更新 inner_iters 次,相當于公式中的Kinputs, labels = zip(*batch) # inner_iters 個 batch,每個 iter 使用一個 batch ,里面的一次訓練迭代if self._pre_step_op:self.session.run(self._pre_step_op)self.session.run(minimize_op, feed_dict={input_ph: inputs, label_ph: labels})# 一個 task 內部訓練完的參數new_vars.append(self._model_state.export_variables())self._model_state.import_variables(old_vars)# 對 meta_batch 個 task 的最終參數進行平均,相當于公式中的 Wnew_vars = average_vars(new_vars)# 所有的 meta_batch 個任務都訓練完, 更新一次 meta 參數,并且把更新后的參數更新到計算圖中,下次訓練從最新參數開始# 更新方式:old + scale*(new - old)self._model_state.import_variables(interpolate_vars(old_vars, new_vars, meta_step_size))OpenAI Demo
在 OpenAI 的官方博客 Reptile: A Scalable Meta-Learning Algorithm中,也有介紹這篇論文。該博客網頁中還有個有意思的 demo,大家可以試玩一下:
這個 demo 的意思是,openAI 已經用他們的 Reptile 算法訓練了一個用于少樣本場景的3分類網絡模型,并且嵌入到了網頁中,用戶可以通過 demo 中的交互制作一個新的三分類任務,并且這個任務只有三個訓練樣本,也就是每個類下只有一個樣本,學名叫3-Way 1-shot,讓他們的模型在這三個樣本上進行微調學習,然后在右邊畫一個新的三個類別下的測試樣本,Reptile 模型會自動給出它在三個類別下的概率。通過這個 demo 來證明他們的模型確實有奇效,在新任務的幾個樣本上微調一下,就可以在該任務的測試集上取得很好的準確率。
幾點思考
通過上面的 demo 可以得出一些結論:
參考資料
- https://arxiv.org/pdf/1803.02999.pdf
- https://github.com/openai/supervised-reptile
- https://www.bilibili.com/video/BV1Gb411n7dE?p=32
總結
以上是生活随笔為你收集整理的元学习之《On First-Order Meta-Learning Algorithms》论文详细解读的全部內容,希望文章能夠幫你解決所遇到的問題。