关于知识蒸馏,你想知道的都在这里!
"蒸餾",一個化學用語,在不同的沸點下提取出不同的成分。知識蒸餾就是指一個很大很復雜的模型,有著非常好的效果和泛化能力,這是缺乏表達能力的小模型所不能擁有的。因此從大模型學到的知識用于指導小模型,使得小模型具有大模型的泛化能力,并且參數量顯著降低,壓縮了模型提升了性能,這就是知識蒸餾。<Distilling the Knowledge in a Neural Network>這篇論文首次提出了知識蒸餾的概念,核心思想就是訓練一個復雜模型,把這個復雜模型的輸出和有l(wèi)abel的數據一并喂給了小網絡,所以知識蒸餾一定會有個復雜的大模型(teacher model)和一個小模型(student model)。
為什么要蒸餾?
模型越來越深,網絡越來越大,參數越來越多,效果越來越好,但是計算復雜度呢?一并上升,蒸餾就是個特別好的方法,用于壓縮模型的大小。
- 提升模型準確率:如果你不滿意現有小模型的效果,可以訓練一個高度復雜效果極佳的大模型(teacher model),然后用它指導小模型達到你滿意的效果。
- 降低模型延遲,壓縮網絡參數:網絡延遲大?像是bert這種大模型,是否可以用一個一層,減少head數的簡單網絡去學習bert呢,這樣不僅提升了簡單網絡的準確率,也實現了延遲的降低。
- 遷移學習:比方說一個老師知道分辨貓狗,另一個老師知道分辨香蕉蘋果,那學生從這兩個老師學習就能同時分辨貓狗和香蕉蘋果。
順便回顧下之前探討過的模型壓縮5種方法:
- Model pruning
- Quantification
- Knowledge distillation
- Parameter sharing
- Parameter matrix approximation
理想情況下,我們是希望同樣一份訓練數據,無論是大模型還是小模型,他們收斂的空間重合度很高,但實際情況由于大模型搜索空間較大,小模型較小,他們收斂的重合度就較低,知識蒸餾能提升他們之間的重合度使得小模型有更好的泛化能力。
知識蒸餾最基礎的框架:
使用Teacher-Student model,用一個非常大而復雜的老師模型,輔助學生模型訓練。老師模型巨大復雜,因此不用于在線,學生模型部署在線上,靈活小巧易于部署。知識蒸餾可以簡單的分為兩個主要的方向:target-based蒸餾,feature-based蒸餾。
Target distillation-Logits method
上文提到的那篇最經典的論文就是該方法一個很好的例子。這篇論文解決的是一個分類問題,既然是分類問題模型就會有個softmax層,該層輸出值直接就是每個類別的概率,在知識蒸餾中,因為我們有個很好的老師模型,一個最直接的方法就是讓學生模型去擬合老師模型輸出的每個類別的概率,也就是我們常說的"Soft-target"。
Hard-target and Soft-target
模型要能訓練,必須定義loss函數,目標就是讓預測值更接近真實值,真實值就是Hard-target,loss函數會使得偏差越來越小。在知識蒸餾中,直接學習每個類別的概率(老師模型預估的)就是soft-target。
Hard-target:類似one-hot的label,比如二分類,正例是1,負例是0。
Soft-target:老師模型softmax層輸出的概率分布,概率最大的就是正類別。
知識蒸餾使得老師模型的soft-target去指導用hard-target學習的學生模型,為什么是有效的呢?因為老師模型輸出的softmax層攜帶的信息要遠多于hard-target,老師模型給學生模型不僅提供了正例的信息,也提供了負例的概率,所以學生模型可以學到更多hard-target學不到的東西。
知識蒸餾具體方法:
神經網絡用softmax層去計算各類的概率:
但是直接使用softmax的輸出作為soft-target會有其他問題,當softmax輸出的概率分布的熵相對較小時,負類別的label就接近0,對loss函數的共享就非常小,小到可以忽略。所以可以新增個變量"temperature",用下式去計算softmax函數:
當T是1,就是以前的softmax模型,當T非常大,那輸出的概率會變的非常平滑,會有很大的熵,模型就會更加關注負類別。
具體蒸餾流程如下:
1.訓練老師模型;
2.使用個較高的溫度去構建Soft-target;
3.同時使用較高溫度的Soft-target和T=1的Soft-target去訓練學生模型;
4.把T改為1在學生模型上做預估。
老師模型的訓練過程非常簡單。學生模型的目標函數可以同時使用兩個loss,一個是蒸餾loss,另一個是本身的loss,用權重控制,如下式所示:
老師和學生使用相同的溫度T,vi適合zi指softmax輸出的logits。L_hard用的就是溫度1。
L_hard的重要性不言而喻,老師也可能會教錯!使用L_hard能避免老師的錯誤傳遞給學生。L_soft和L_hard之前的權重也比較重要,實驗表明L_hard權重較小往往帶來更好的效果,因為L_soft的梯度貢獻大約是1/T^2,所以L_soft最好乘上一個T^2去確保兩個loss的梯度貢獻等同。
一種特殊形式的蒸餾方式:Direct Matching Logits
直接使用softmax層產出的logits作為soft-target,目標函數直接使用均方誤差,如下所示:
和傳統(tǒng)蒸餾方法相比,T趨向于無窮大時,直接擬合logits和擬合概率是等同的(證明略),所以這是一種特殊形式的蒸餾方式。
關于溫度:
一個較高的溫度,往往能蒸餾出更多知識,但是怎么去調節(jié)溫度呢?
- 最原始的softmax函數就是T=1,當T < 1,概率分布更"陡",當T->0,輸出值就變成了Hard-target,當T > 1,概率分布就會更平滑。
- 當T變大,概率分布熵會變大,當T趨于無窮,softamx結果就均勻分布了。
- 不管T是多少,Soft-target會攜帶更多具有傾向性的信息。
T的變化程度決定了學生模型有多少attention在負類別上,當溫度很低,模型就不太關注負類別,特別是那些小于均值的負類別,當溫度很高,模型就更多的關注負類別。事實上負類別攜帶更多信息,特別是大于均值的負類別。因此選對溫度很重要,需要更多實驗去選擇。T的選擇和學生模型的大小關系也很大,當學生模型相對較小,一個較小的T就足夠了,因為學生模型沒有能力學習老師模型全部的知識,一些負類別信息就可以忽略。
除此以外,還有很多特別的蒸餾思想,如intermediate based蒸餾,如下圖所示,蒸餾的不僅僅是softmax層,連中間層一并蒸餾。
關于"知識蒸餾",你想知道的都在這里!總結
以上是生活随笔為你收集整理的关于知识蒸馏,你想知道的都在这里!的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 时间序列里面最强特征之一
- 下一篇: 炼丹秘术:给Embedding插上翅膀