三次简化一张图:一招理解LSTM/GRU门控机制
機器之心專欄
作者:張皓
RNN 在處理時序數(shù)據(jù)時十分成功。但是,對 RNN 及其變種 LSTM 和 GRU 結(jié)構(gòu)的理解仍然是一個困難的任務。本文介紹一種理解 LSTM 和 GRU 的簡單通用的方法。通過對 LSTM 和 GRU 數(shù)學形式化的三次簡化,最后將數(shù)據(jù)流形式畫成一張圖,可以簡潔直觀地對其中的原理進行理解與分析。此外,本文介紹的三次簡化一張圖的分析方法具有普適性,可廣泛用于其他門控網(wǎng)絡的分析。
1. RNN、梯度爆炸與梯度消失
1.1 RNN
近些年,深度學習模型在處理有非常復雜內(nèi)部結(jié)構(gòu)的數(shù)據(jù)時十分有效。例如,圖像數(shù)據(jù)的像素之間的 2 維空間關(guān)系非常重要,CNN(convolution neural networks,卷積神經(jīng)網(wǎng)絡)處理這種空間關(guān)系十分有效。而時序數(shù)據(jù)(sequential data)的變長輸入序列之間時序關(guān)系非常重要,RNN(recurrent neural networks,循環(huán)神經(jīng)網(wǎng)絡,注意和 recursive neural networks,遞歸神經(jīng)網(wǎng)絡的區(qū)別)處理這種時序關(guān)系十分有效。
我們使用下標 t 表示輸入時序序列的不同位置,用 h_t 表示在時刻 t 的系統(tǒng)隱層狀態(tài)向量,用 x_t 表示時刻 t 的輸入。t 時刻的隱層狀態(tài)向量 h_t 依賴于當前詞 x_t 和前一時刻的隱層狀態(tài)向量 h_(t-1):
其中 f 是一個非線性映射函數(shù)。一種通常的做法是計算 x_t 和 h_(t-1) 的線性變換后經(jīng)過一個非線性激活函數(shù),例如
其中 W_(xh) 和 W_(hh) 是可學習的參數(shù)矩陣,激活函數(shù) tanh 獨立地應用到其輸入的每個元素。
為了對 RNN 的計算過程做一個可視化,我們可以畫出下圖:
圖中左邊是輸入 x_t 和 h_(t-1)、右邊是輸出 h_t。計算從左向右進行,整個運算包括三步:輸入 x_t 和 h_(t-1) 分別乘以 W_(xh) 和 W_(hh) 、相加、經(jīng)過 tanh 非線性變換。
我們可以認為 h_t 儲存了網(wǎng)絡中的記憶(memory),RNN 學習的目標是使得 h_t 記錄了在 t 時刻之前(含)的輸入信息 x_1, x_2,..., x_t。在新詞 x_t 輸入到網(wǎng)絡之后,之前的隱狀態(tài)向量 h_(t-1) 就轉(zhuǎn)換為和當前輸入 x_t 有關(guān)的 h_t。
1.2 梯度爆炸與梯度消失
雖然理論上 RNN 可以捕獲長距離依賴,但實際應用中,RNN 將會面臨兩個挑戰(zhàn):梯度爆炸(gradient explosion)和梯度消失(vanishing gradient)。
我們考慮一種簡單情況,即激活函數(shù)是恒等(identity)變換,此時
在進行誤差反向傳播(error backpropagation)時,當我們已知損失函數(shù)對 t 時刻隱狀態(tài)向量 h_t 的偏導數(shù)時,利用鏈式法則,我們計算損失函數(shù) 對 t 時刻隱狀態(tài)向量 h_0 的偏導數(shù)
我們可以利用 RNN 的依賴關(guān)系,沿時間維度展開,來計算
也就是說,在誤差反向傳播時我們需要反復乘以參數(shù)矩陣 W_(hh)。我們對矩陣 ?W_(hh) 進行奇異值分解(SVD)
其中 r 是矩陣 W_(hh) 的秩(rank)。因此,
那么我們最后要計算的目標
當 t 很大時,該偏導數(shù)取決于矩陣 W_(hh) 的最大的奇異值是大于 1 還是小于 1,要么結(jié)果太大,要么結(jié)果太小:
(1). 梯度爆炸。當 > 1,,那么
此時偏導數(shù)將會變得非常大,實際在訓練時將會遇到 NaN 錯誤,會影響訓練的收斂,甚至導致網(wǎng)絡不收斂。這好比要把本國的產(chǎn)品賣到別的國家,結(jié)果被加了層層關(guān)稅,等到了別國市場的時候,價格已經(jīng)變得非常高,老百姓根本買不起。在 RNN 中,梯度(偏導數(shù))就是價格,隨著向前推移,梯度越來越大。這種現(xiàn)象稱為梯度爆炸。
梯度爆炸相對比較好處理,可以用梯度裁剪(gradient clipping)來解決:
這好比是不管前面的關(guān)稅怎么加,設置一個最高市場價格,通過這個最高市場價格保證老百姓是買的起的。在 RNN 中,不管梯度回傳的時候大到什么程度,設置一個梯度的閾值,梯度最多是這么大。
(2). 梯度消失。當 < 1,,那么
此時偏導數(shù)將會變得十分接近 0,從而在梯度更新前后沒有什么區(qū)別,這會使得網(wǎng)絡捕獲長距離依賴(long-term dependency)的能力下降。這好比打仗的時候往前線送糧食,送糧食的隊伍自己也得吃糧食。當補給點離前線太遠時,還沒等送到,糧食在半路上就已經(jīng)被吃完了。在 RNN 中,梯度(偏導數(shù))就是糧食,隨著向前推移,梯度逐漸被消耗殆盡。這種現(xiàn)象稱為梯度消失。
梯度消失現(xiàn)象解決起來困難很多,如何緩解梯度消失是 RNN 及幾乎其他所有深度學習方法研究的關(guān)鍵所在。LSTM 和 GRU 通過門(gate)機制控制 RNN 中的信息流動,用來緩解梯度消失問題。其核心思想是有選擇性的處理輸入。比如我們在看到一個商品的評論時
Amazing! This box of cereal gave me a perfectly balanced breakfast, as all things should be. In only ate half of it but will definitely be buying again!
我們會重點關(guān)注其中的一些詞,對它們進行處理
Amazing! This box of cereal gave me a perfectly balanced breakfast, as all things should be. In only ate half of it but will definitely be buying again!
LSTM 和 GRU 的關(guān)鍵是會選擇性地忽略其中一些詞,不讓其參與到隱層狀態(tài)向量的
更新中,最后只保留相關(guān)的信息進行預測。
2. LSTM
2.1 LSTM 的數(shù)學形式
LSTM(Long Short-Term Memory)由 Hochreiter 和 Schmidhuber 提出,其數(shù)學上的形式化表示如下:
其中代表逐元素相乘,sigm 代表 sigmoid 函數(shù)
和 RNN 相比,LSTM 多了一個隱狀態(tài)變量 c_t,稱為細胞狀態(tài)(cell state),用來記錄信息。
這個公式看起來似乎十分復雜,為了更好的理解 LSTM 的機制,許多人用圖來描述 LSTM 的計算過程。比如下面這張圖:
似乎看完之后,對 LSTM 的理解仍然是一頭霧水?這是因為這些圖想把 LSTM 的所有細節(jié)一次性都展示出來,但是突然暴露這么多的細節(jié)會使你眼花繚亂,從而無處下手。
2.2 三次簡化一張圖
因此,本文提出的方法旨在簡化門控機制中不重要的部分,從而更關(guān)注在 LSTM 的核心思想。整個過程是三次簡化一張圖,具體流程如下:
(1). 第一次簡化:忽略門控單元 i_t 、f_t 、o_t 的來源。3 個門控單元的計算方法完全相同,都是由輸入經(jīng)過線性映射得到的,區(qū)別只是計算的參數(shù)不同:
使用相同計算方式的目的是它們都扮演了門控的角色,而使用不同參數(shù)的目的是為了誤差反向傳播時對三個門控單元獨立地進行更新。在理解 LSTM 運行機制的時候,為了對圖進行簡化,我們不在圖中標注三個門控單元的計算過程,并假定各門控單元是給定的。
(2). 第二次簡化:考慮一維門控單元 i_t 、 f_t 、 o_t。LSTM 中對各維是獨立進行門控的,所以為了表示和理解方便,我們只需要考慮一維情況,在理解 LSTM 原理之后,將一維推廣到多維是很直接的。經(jīng)過這兩次簡化,LSTM 的數(shù)學形式只有下面三行
由于門控單元變成了一維,所以向量和向量的逐元素相乘符號變成了數(shù)和向量相乘 · 。
(3). 第三次簡化:各門控單元二值輸出。門控單元 i_t 、f_t 、o_t 的由于經(jīng)過了 sigmoid 激活函數(shù),輸出是范圍是 [0, 1]。激活函數(shù)使用 sigmoid 的目的是為了近似 0/1 階躍函數(shù),這樣 sigmoid 實數(shù)值輸出單調(diào)可微,可以基于誤差反向傳播進行更新。
既然 sigmoid 激活函數(shù)是為了近似 0/1 階躍函數(shù),那么,在進行 LSTM 理解分析的時候,為了理解方便,我們認為各門控單元 {0, 1} 二值輸出,即門控單元扮演了電路中開關(guān)的角色,用于控制信息的通斷。
(4). 一張圖。將三次簡化的結(jié)果用電路圖表述出來,左邊是輸入,右邊是輸出。在 LSTM 中,有一點需要特別注意,LSTM 中的細胞狀態(tài) c_t 實質(zhì)上起到了 RNN 中隱層單元 h_t 的作用,這點在其他文獻資料中不常被提到,所以整個圖的輸入是 x_t 和 ?c_{t-1},而不是 x_t 和 h_(t-1)。為了方便畫圖,我們需要將公式做最后的調(diào)整
最終結(jié)果如下:
和 RNN 相同的是,網(wǎng)絡接受兩個輸入,得到一個輸出。其中使用了兩個參數(shù)矩陣 ?W_(xc) 和 W_(hc),以及 tanh 激活函數(shù)。不同之處在于,LSTM 中通過 3 個門控單元 i_t 、f_t 、o_t 來對的信息交互進行控制。當 i_t=1(開關(guān)閉合)、f_t=0(開關(guān)打開)、o_t=1(開關(guān)閉合)時,LSTM 退化為標準的 RNN。
2.3 LSTM 各單元作用分析
根據(jù)這張圖,我們可以對 LSTM 中各單元作用進行分析:
輸出門 o_(t-1):輸出門的目的是從細胞狀態(tài) c_(t-1) 產(chǎn)生隱層單元 h_(t-1)。并不是 c_(t-1) 中的全部信息都和隱層單元 h_(t-1) 有關(guān),c_(t-1)?可能包含了很多對 h_(t-1) 無用的信息。因此,o_t 的作用就是判斷?c_(t-1) 中哪些部分是對 h_(t-1) 有用的,哪些部分是無用的。
輸入門 i_t。i_t 控制當前詞 x_t 的信息融入細胞狀態(tài) c_t。在理解一句話時,當前詞 x_t 可能對整句話的意思很重要,也可能并不重要。輸入門的目的就是判斷當前詞 x_t 對全局的重要性。當 i_t 開關(guān)打開的時候,網(wǎng)絡將不考慮當前輸入 ?x_t。
遺忘門 f_t: f_t 控制上一時刻細胞狀態(tài) c_(t-1) 的信息融入細胞狀態(tài) c_t。在理解一句話時,當前詞 x_t 可能繼續(xù)延續(xù)上文的意思繼續(xù)描述,也可能從當前詞 x_t 開始描述新的內(nèi)容,與上文無關(guān)。和輸入門 i_t 相反,f_t 不對當前詞 x_t 的重要性作判斷,而判斷的是上一時刻的細胞狀態(tài)c_(t-1)對計算當前細胞狀態(tài) c_t 的重要性。當 f_t 開關(guān)打開的時候,網(wǎng)絡將不考慮上一時刻的細胞狀態(tài)?c_(t-1)。
細胞狀態(tài)?c_t :c_t 綜合了當前詞 x_t 和前一時刻細胞狀態(tài)?c_(t-1)?的信息。這和 ResNet 中的殘差逼近思想十分相似,通過從?c_(t-1)?到 c_t 的「短路連接」,梯度得已有效地反向傳播。當 f_t 處于閉合狀態(tài)時,c_t 的梯度可以直接沿著最下面這條短路線傳遞到c_(t-1),不受參數(shù) W_(xh) 和 W_(hh) 的影響,這是 LSTM 能有效地緩解梯度消失現(xiàn)象的關(guān)鍵所在。
3. GRU
3.1 GRU 的數(shù)學形式
GRU 是另一種十分主流的 RNN 衍生物。RNN 和 LSTM 都是在設計網(wǎng)絡結(jié)構(gòu)用于緩解梯度消失問題,只不過是網(wǎng)絡結(jié)構(gòu)有所不同。GRU 在數(shù)學上的形式化表示如下:
3.2 三次簡化一張圖
為了理解 GRU 的設計思想,我們再一次運用三次簡化一張圖的方法來進行分析:
(1). 第一次簡化:忽略門控單元 z_t 和 r_t 的來源。
(2). 考慮一維門控單元 z_t 和 r_t。經(jīng)過這兩次簡化,GRU 的數(shù)學形式是以下兩行
(3). 第三次簡化:各門控單元二值輸出。這里和 LSTM 略有不同的地方在于,當 z_t=1 時h_t = h_(t-1) ;而當 z_t = 0 時,h_t =。因此,z_t 扮演的角色是一個個單刀雙擲開關(guān)。
(4). 一張圖。將三次簡化的結(jié)果用電路圖表述出來,左邊是輸入,右邊是輸出。
與 LSTM 相比,GRU 將輸入門 i_t 和遺忘門 f_t 融合成單一的更新門 z_t,并且融合了細胞狀態(tài) c_t 和隱層單元 h_t。當 r_t=1(開關(guān)閉合)、 z_t=0(開關(guān)連通上面)GRU 退化為標準的 RNN。
3.3 GRU 各單元作用分析
根據(jù)這張圖, 我們可以對 GRU 的各單元作用進行分析:
重置門 r_t : r_t 用于控制前一時刻隱層單元 h_(t-1) 對當前詞 x_t 的影響。如果 h_(t-1)?對 x_t 不重要,即從當前詞 x_t 開始表述了新的意思,與上文無關(guān)。那么開關(guān) r_t 可以打開,使得 h_(t-1) 對 x_t 不產(chǎn)生影響。
更新門 z_t : z_t 用于決定是否忽略當前詞 x_t。類似于 LSTM 中的輸入門 i_t,z_t 可以判斷當前詞 x_t 對整體意思的表達是否重要。當 z_t 開關(guān)接通下面的支路時,我們將忽略當前詞 x_t,同時構(gòu)成了從 h_(t-1)?到 h_t 的短路連接,這使得梯度得已有效地反向傳播。和 LSTM 相同,這種短路機制有效地緩解了梯度消失現(xiàn)象,這個機制于 highway networks 十分相似。
4. 小結(jié)
盡管 RNN、LSTM、和 GRU 的網(wǎng)絡結(jié)構(gòu)差別很大,但是他們的基本計算單元是一致的,都是對 x_t 和 h_t 做一個線性映射加 tanh 激活函數(shù),見三個圖的紅色框部分。他們的區(qū)別在于如何設計額外的門控機制控制梯度信息傳播用以緩解梯度消失現(xiàn)象。LSTM 用了 3 個門、GRU 用了 2 個,那能不能再少呢?MGU(minimal gate unit)嘗試對這個問題做出回答,它只有一個門控單元。最后留個小練習,參考 LSTM 和 GRU 的例子,你能不能用三次簡化一張圖的方法來分析一下 MGU 呢?
參考文獻
Yoshua Bengio, Patrice Y. Simard, and Paolo Frasconi. Learning long-term dependencies with gradient descent is difficult. IEEE Transactions on Neural Networks 5(2): 157-166, 1994.
Kyunghyun Cho, Bart van Merrienboer, ?aglar Gül?ehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, and Yoshua Bengio. Learning phrase representations using RNN encoder-decoder for statistical machine translation. In EMNLP, pages 1724-1734, 2014.
Junyoung Chung, ?aglar Gül?ehre, KyungHyun Cho, and Yoshua Bengio. Empirical evaluation of gated recurrent neural networks on sequence modeling. In NIPS Workshop, pages 1-9, 2014.
Felix Gers. Long short-term memory in recurrent neural networks. PhD Dissertation, Ecole Polytechnique Fédérale de Lausanne, 2001.
Ian J. Goodfellow, Yoshua Bengio, and Aaron C. Courville. Deep learning. Adaptive Computation and Machine Learning, MIT Press, ISBN 978-0-262-03561-3, 2016.
Alex Graves. Supervised sequence labelling with recurrent neural networks. Studies in Computational Intelligence 385, Springer, ISBN 978-3-642-24796-5, 2012.
Klaus Greff, Rupesh Kumar Srivastava, Jan Koutník, Bas R. Steunebrink, and Jürgen Schmidhuber. LSTM: A search space odyssey. IEEE Transactions on Neural Networks and Learning Systems. 28(10): 2222-2232, 2017.
Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR, pages 770-778, 2016.
Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Identity mappings in deep residual networks. In ECCV, pages 630-645, 2016.
Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural Computation 9(8): 1735-1780, 1997.
Rafal Józefowicz, Wojciech Zaremba, and Ilya Sutskever. An empirical exploration of recurrent network architectures. In ICML, pages 2342-2350, 2015.
Zachary Chase Lipton. A critical review of recurrent neural networks for sequence learning. CoRR abs/1506.00019, 2015.
Razvan Pascanu, Tomas Mikolov, and Yoshua Bengio. On the difficulty of training recurrent neural networks. In ICML, pages 1310-1318, 2013.
Rupesh Kumar Srivastava, Klaus Greff, and Jürgen Schmidhuber. Highway networks. In ICML Workshop, pages 1-6, 2015.
Guo-Bing Zhou, Jianxin Wu, Chen-Lin Zhang, and Zhi-Hua Zhou. Minimal gated unit for recurrent neural networks. International Journal of Automation and Computing, 13(3): 226-234, 2016.
本文為機器之心專欄,轉(zhuǎn)載請聯(lián)系本公眾號獲得授權(quán)。
?------------------------------------------------
加入機器之心(全職記者 / 實習生):hr@jiqizhixin.com
投稿或?qū)で髨蟮?#xff1a;content@jiqizhixin.com
廣告 & 商務合作:bd@jiqizhixin.com
總結(jié)
以上是生活随笔為你收集整理的三次简化一张图:一招理解LSTM/GRU门控机制的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 斯坦福统计学习理论笔记:Percy Li
- 下一篇: 好嗨哟~谷歌量子神经网络新进展揭秘