Self-Attention GAN 中的 self-attention 机制
作者丨尹相楠
學(xué)校丨里昂中央理工博士在讀
研究方向丨人臉識(shí)別、對(duì)抗生成網(wǎng)絡(luò)
Self Attention GAN 用到了很多新的技術(shù)。最大的亮點(diǎn)當(dāng)然是 self-attention 機(jī)制,該機(jī)制是 Non-local Neural Networks?[1] 這篇文章提出的。其作用是能夠更好地學(xué)習(xí)到全局特征之間的依賴(lài)關(guān)系。因?yàn)閭鹘y(tǒng)的 GAN 模型很容易學(xué)習(xí)到紋理特征:如皮毛,天空,草地等,不容易學(xué)習(xí)到特定的結(jié)構(gòu)和幾何特征,例如狗有四條腿,既不能多也不能少。?
除此之外,文章還用到了 Spectral Normalization for GANs [2]?提出的譜歸一化。譜歸一化的解釋見(jiàn)本人這篇文章:詳解GAN的譜歸一化(Spectral Normalization)。
但是,該文代碼中的譜歸一化和原始的譜歸一化運(yùn)用方式略有差別:?
1. 原始的譜歸一化基于 W-GAN 的理論,只用在 Discriminator 中,用以約束 Discriminator 函數(shù)為 1-Lipschitz 連續(xù)。而在 Self-Attention GAN 中,Spectral Normalization 同時(shí)出現(xiàn)在了 Discriminator 和 Generator 中,用于使梯度更穩(wěn)定。除了生成器和判別器的最后一層外,每個(gè)卷積/反卷積單元都會(huì)上一個(gè) SpectralNorm。?
2. 當(dāng)把譜歸一化用在 Generator 上時(shí),同時(shí)還保留了 BatchNorm。Discriminator 上則沒(méi)有 BatchNorm,只有 SpectralNorm。?
3. 譜歸一化用在 Discriminator 上時(shí)最后一層不加 Spectral Norm。?
最后,self-attention GAN 還用到了 cGANs With Projection Discriminator 提出的 conditional normalization 和 projection in the discriminator。這兩個(gè)技術(shù)我還沒(méi)有來(lái)得及看,而且 PyTorch 版本的 self-attention GAN 代碼中也沒(méi)有實(shí)現(xiàn),就先不管它們了。
本文主要說(shuō)的是 self-attention 這部分內(nèi)容。
▲?圖1.?Self-Attention
Self-Attention
在卷積神經(jīng)網(wǎng)絡(luò)中,每個(gè)卷積核的尺寸都是很有限的(基本上不會(huì)大于 5),因此每次卷積操作只能覆蓋像素點(diǎn)周?chē)苄∫粔K鄰域。
對(duì)于距離較遠(yuǎn)的特征,例如狗有四條腿這類(lèi)特征,就不容易捕獲到了(也不是完全捕獲不到,因?yàn)槎鄬拥木矸e、池化操作會(huì)把 feature map 的高和寬變得越來(lái)越小,越靠后的層,其卷積核覆蓋的區(qū)域映射回原圖對(duì)應(yīng)的面積越大。但總而言之,畢竟還得需要經(jīng)過(guò)多層映射,不夠直接)。
Self-Attention 通過(guò)直接計(jì)算圖像中任意兩個(gè)像素點(diǎn)之間的關(guān)系,一步到位地獲取圖像的全局幾何特征。?
論文中的公式不夠直觀,我們直接看文章的 PyTorch 的代碼,核心部分為 sagan_models.py:
????"""?Self?attention?Layer"""
????def?__init__(self,in_dim,activation):
????????super(Self_Attn,self).__init__()
????????self.chanel_in?=?in_dim
????????self.activation?=?activation
????????self.query_conv?=?nn.Conv2d(in_channels?=?in_dim?,?out_channels?=?in_dim//8?,?kernel_size=?1)
????????self.key_conv?=?nn.Conv2d(in_channels?=?in_dim?,?out_channels?=?in_dim//8?,?kernel_size=?1)
????????self.value_conv?=?nn.Conv2d(in_channels?=?in_dim?,?out_channels?=?in_dim?,?kernel_size=?1)
????????self.gamma?=?nn.Parameter(torch.zeros(1))
????????self.softmax??=?nn.Softmax(dim=-1)?#
????def?forward(self,x):
????????"""
????????????inputs?:
????????????????x?:?input?feature?maps(?B?X?C?X?W?X?H)
????????????returns?:
????????????????out?:?self?attention?value?+?input?feature?
????????????????attention:?B?X?N?X?N?(N?is?Width*Height)
????????"""
????????m_batchsize,C,width?,height?=?x.size()
????????proj_query??=?self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1)?#?B?X?CX(N)
????????proj_key?=??self.key_conv(x).view(m_batchsize,-1,width*height)?#?B?X?C?x?(*W*H)
????????energy?=??torch.bmm(proj_query,proj_key)?#?transpose?check
????????attention?=?self.softmax(energy)?#?BX?(N)?X?(N)?
????????proj_value?=?self.value_conv(x).view(m_batchsize,-1,width*height)?#?B?X?C?X?N
????????out?=?torch.bmm(proj_value,attention.permute(0,2,1)?)
????????out?=?out.view(m_batchsize,C,width,height)
????????out?=?self.gamma*out?+?x
????????return?out,attention
構(gòu)造函數(shù)中定義了三個(gè) 1?× 1 的卷積核,分別被命名為 query_conv , key_conv 和 value_conv 。
為啥命名為這三個(gè)名字呢?這和作者給它們賦予的含義有關(guān)。query 意為查詢(xún),我們希望輸入一個(gè)像素點(diǎn),查詢(xún)(計(jì)算)到 feature map 上所有像素點(diǎn)對(duì)這一點(diǎn)的影響。而 key 代表字典中的鍵,相當(dāng)于所查詢(xún)的數(shù)據(jù)庫(kù)。query 和 key 都是輸入的 feature map,可以看成把 feature map 復(fù)制了兩份,一份作為 query 一份作為 key。?
需要用一個(gè)什么樣的函數(shù),才能針對(duì) query 的 feature map 中的某一個(gè)位置,計(jì)算出 key 的 feature map 中所有位置對(duì)它的影響呢?作者認(rèn)為這個(gè)函數(shù)應(yīng)該是可以通過(guò)“學(xué)習(xí)”得到的。那么,自然而然就想到要對(duì)這兩個(gè) feature map 分別做卷積核為 1?× 1 的卷積了,因?yàn)榫矸e核的權(quán)重是可以學(xué)習(xí)得到的。?
至于 value_conv ,可以看成對(duì)原 feature map 多加了一層卷積映射,這樣可以學(xué)習(xí)到的參數(shù)就更多了,否則 query_conv 和 key_conv 的參數(shù)太少,按代碼中只有 in_dims × in_dims//8 個(gè)。?
接下來(lái)逐行研究 forward 函數(shù):
這行代碼先對(duì)輸入的 feature map 卷積了一次,相當(dāng)于對(duì) query feature map 做了一次投影,所以叫做 proj_query。由于是 1?× 1 的卷積,所以不改變 feature map 的長(zhǎng)和寬。feature map 的每個(gè)通道為如 (1) 所示的矩陣,矩陣共有 N 個(gè)元素(像素)。
然后重新改變了輸出的維度,變成:
?(m_batchsize,-1,width*height)?
batch size 保持不變,width 和 height 融合到一起,把如 (1) 所示二維的 feature map 每個(gè) channel 拉成一個(gè)長(zhǎng)度為 N 的向量。
因此,如果 m_batchsize 取 1,即單獨(dú)觀察一個(gè)樣本,該操作的結(jié)果是得到一個(gè)矩陣,矩陣的的行數(shù)為 query_conv 卷積輸出的 channel 的數(shù)目 C( in_dim//8 ),列數(shù)為 feature map 像素?cái)?shù) N。
然后作者又通過(guò) .permute(0, 2, 1) 轉(zhuǎn)置了矩陣,矩陣的行數(shù)變成了 feature map 的像素?cái)?shù) N,列數(shù)變成了通道數(shù) C。因此矩陣維度為 N?× C 。該矩陣每行代表一個(gè)像素位置上所有通道的值,每列代表某個(gè)通道中所有的像素值。
▲?圖2.?proj_query 的維度
這行代碼和上一行類(lèi)似,只不過(guò)取消了轉(zhuǎn)置操作。得到的矩陣行數(shù)為通道數(shù) C,列數(shù)為像素?cái)?shù) N,即矩陣維度為 C?× N。該矩陣每行代表一個(gè)通道中所有的像素值,每列代表一個(gè)像素位置上所有通道的值。
▲?圖3. proj_key的維度
這行代碼中, torch.bmm 的意思是 batch matrix multiplication。就是說(shuō)把相同 batch size 的兩組 matrix 一一對(duì)應(yīng)地做矩陣乘法,最后得到同樣 batchsize 的新矩陣。
若 batch size=1,就是普通的矩陣乘法。已知 proj_query 維度是 N?× C, proj_key 的維度是 C?×?N,因此 energy 的維度是 N?× N:
▲?圖4. energy的維度
energy 是 attention 的核心,其中第 i 行 j 列的元素,是由 proj_query 第 i 行,和 proj_key 第 j 列通過(guò)向量點(diǎn)乘得到的。而 proj_query 第 i 行表示的是 feature map 上第 i 個(gè)像素位置上所有通道的值,也就是第 i 個(gè)像素位置的所有信息,而 proj_key 第 j 列表示的是 feature map 上第 j 個(gè)像素位置上的所有通道值,也就是第 j 個(gè)像素位置的所有信息。
這倆相乘,可以看成是第 j 個(gè)像素對(duì)第 i 個(gè)像素的影響。即,energy 中第 i 行 j 列的元素值,表示第 j 個(gè)像素點(diǎn)對(duì)第 i 個(gè)像素點(diǎn)的影響。
這里 sofmax 是構(gòu)造函數(shù)中定義的,為按“行”歸一化。這個(gè)操作之后的矩陣,各行元素之和為 1。這也比較好理解,因?yàn)?energy 中第 i 行元素,代表 feature map 中所有位置的像素對(duì)第 i 個(gè)像素的影響,而這個(gè)影響被解釋為權(quán)重,故加起來(lái)應(yīng)該是 1,故應(yīng)對(duì)其按行歸一化。attention 的維度也是 N?× N。
上面的代碼中,先對(duì)原 feature map 作一次卷積映射,然后把得到的新 feature map 改變形狀,維度變?yōu)?C?×?N ,其中 C 為通道數(shù)(注意和上面計(jì)算 proj_query???proj_key 的 C 不同,上面的 C 為 feature map 通道數(shù)的 1/8,這里的 C 與 feature map 通道數(shù)相同),N 為 feature map 的像素?cái)?shù)。
▲?圖5.?proj_value的維度
out?=?out.view(m_batchsize,C,width,height)
然后,再把 proj_value (C?× N)矩陣同? attention 矩陣的轉(zhuǎn)置(N?× N)相乘,得到 out (C?× N)。之所以轉(zhuǎn)置,是因?yàn)?/span> attention 中每行的和為 1,其意義是權(quán)重,需要轉(zhuǎn)置后變?yōu)槊苛械暮蜑?1,施加于 proj_value 的行上,作為該行的加權(quán)平均。 proj_value 第 i 行代表第 i 個(gè)通道所有的像素值, attention 第 j 列,代表所有像素施加到第 j 個(gè)像素的影響。
因此, out 中第 i 行包含了輸出的第 i 個(gè)通道中的所有像素,第 j 列表示所有像素中的第 j 個(gè)像素,合起來(lái)也就是: out 中的第 i 行第 j 列的元素,表示被 attention 加權(quán)之后的 feature map 的第 i 個(gè)通道的第 j 個(gè)像素的像素值。再改變一下形狀, out 就恢復(fù)了 channel×width×height 的結(jié)構(gòu)。
▲?圖6.?out的維度
最后一行代碼,借鑒了殘差神經(jīng)網(wǎng)絡(luò)(residual neural networks)的操作, gamma 是一個(gè)參數(shù),表示整體施加了 attention 之后的 feature map 的權(quán)重,需要通過(guò)反向傳播更新。而 x 就是輸入的 feature map。
在初始階段, gamma 為 0,該 attention 模塊直接返回輸入的 feature map,之后隨著學(xué)習(xí),該 attention 模塊逐漸學(xué)習(xí)到了將 attention 加權(quán)過(guò)的 feature map 加在原始的 feature map 上,從而強(qiáng)調(diào)了需要施加注意力的部分 feature map。
總結(jié)
可以把 self attention 看成是 feature map 和它自身的轉(zhuǎn)置相乘,讓任意兩個(gè)位置的像素直接發(fā)生關(guān)系,這樣就可以學(xué)習(xí)到任意兩個(gè)像素之間的依賴(lài)關(guān)系,從而得到全局特征了。看論文時(shí)會(huì)被它復(fù)雜的符號(hào)迷惑,但是一看代碼就發(fā)現(xiàn)其實(shí)是很 naive 的操作。
參考文獻(xiàn)
[1] Xiaolong Wang, Ross Girshick, Abhinav Gupta, Kaiming He, Non-local Neural Networks, CVPR 2018.
[2] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida, Spectral Normalization for Generative Adversarial Networks, ICLR 2018.
點(diǎn)擊以下標(biāo)題查看更多往期內(nèi)容:?
Airbnb實(shí)時(shí)搜索排序中的Embedding技巧
圖神經(jīng)網(wǎng)絡(luò)綜述:模型與應(yīng)用
近期值得讀的10篇GAN進(jìn)展論文
自然語(yǔ)言處理中的語(yǔ)言模型預(yù)訓(xùn)練方法
從傅里葉分析角度解讀深度學(xué)習(xí)的泛化能力
深度思考 | 從BERT看大規(guī)模數(shù)據(jù)的無(wú)監(jiān)督利用
AI Challenger 2018 機(jī)器翻譯參賽總結(jié)
小米拍照黑科技:基于NAS的圖像超分辨率算法
異構(gòu)信息網(wǎng)絡(luò)表示學(xué)習(xí)論文解讀
不懂Photoshop如何P圖?交給深度學(xué)習(xí)吧
#投 稿 通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達(dá)讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢??答案就是:你不認(rèn)識(shí)的人。
總有一些你不認(rèn)識(shí)的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學(xué)者和學(xué)術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵(lì)高校實(shí)驗(yàn)室或個(gè)人,在我們的平臺(tái)上分享各類(lèi)優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學(xué)習(xí)心得或技術(shù)干貨。我們的目的只有一個(gè),讓知識(shí)真正流動(dòng)起來(lái)。
??來(lái)稿標(biāo)準(zhǔn):
? 稿件確系個(gè)人原創(chuàng)作品,來(lái)稿需注明作者個(gè)人信息(姓名+學(xué)校/工作單位+學(xué)歷/職位+研究方向)?
? 如果文章并非首發(fā),請(qǐng)?jiān)谕陡鍟r(shí)提醒并附上所有已發(fā)布鏈接?
? PaperWeekly 默認(rèn)每篇文章都是首發(fā),均會(huì)添加“原創(chuàng)”標(biāo)志
? 投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請(qǐng)單獨(dú)在附件中發(fā)送?
? 請(qǐng)留下即時(shí)聯(lián)系方式(微信或手機(jī)),以便我們?cè)诰庉嫲l(fā)布時(shí)和作者溝通
?
現(xiàn)在,在「知乎」也能找到我們了
進(jìn)入知乎首頁(yè)搜索「PaperWeekly」
點(diǎn)擊「關(guān)注」訂閱我們的專(zhuān)欄吧
關(guān)于PaperWeekly
PaperWeekly 是一個(gè)推薦、解讀、討論、報(bào)道人工智能前沿論文成果的學(xué)術(shù)平臺(tái)。如果你研究或從事 AI 領(lǐng)域,歡迎在公眾號(hào)后臺(tái)點(diǎn)擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
▽ 點(diǎn)擊 |?閱讀原文?| 獲取最新論文推薦
總結(jié)
以上是生活随笔為你收集整理的Self-Attention GAN 中的 self-attention 机制的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 巧断梯度:单个loss实现GAN模型(附
- 下一篇: 你不是一个人在战斗!有人将吴恩达的视频教