动手做个DialoGPT:生成式多轮对话模型
文 | 蘇劍林
編 | 兔子醬
前段時(shí)間刷Arixv的時(shí)候,發(fā)現(xiàn)清華大學(xué)開源了一個(gè)大規(guī)模的中文閑聊語料庫LCCC,從開源的文件上來看,這可能是目前開源的數(shù)量最大、質(zhì)量最好的閑聊語料庫了,而且還包含了部分多輪對(duì)話聊天,總的來說可玩性還是蠻強(qiáng)的。筆者也被它吸引到了,嘗試著用它來訓(xùn)練了一個(gè)閑聊對(duì)話模型,結(jié)果看上去還是不錯(cuò)的,在此分享一下自己的經(jīng)驗(yàn)。
論文名稱:
《A Large-Scale Chinese Short-Text Conversation Dataset》
論文鏈接:
https://arxiv.org/abs/2008.03946
項(xiàng)目地址:
https://github.com/thu-coai/CDial-GPT
Arxiv訪問慢的小伙伴也可以在 【夕小瑤的賣萌屋】訂閱號(hào)后臺(tái)回復(fù)關(guān)鍵詞 【0917】 下載論文PDF~
語料簡(jiǎn)介
這里簡(jiǎn)單介紹一下LCCC這個(gè)數(shù)據(jù)集(Large-scale Cleaned Chinese Conversation),具體細(xì)節(jié)大家可以去Github上看,下載鏈接也在上面。LCCC分base和large兩個(gè)版本,base主要是來源于微博對(duì)話,large則是在base的基礎(chǔ)上融合了其他開源對(duì)話語料,按照作者的說法,LCCC經(jīng)過了嚴(yán)格的清洗過程,所以整體質(zhì)量看上去還是很不錯(cuò)的。
為了簡(jiǎn)化任務(wù),所有樣本都被處理成雙人對(duì)話。下面是一些樣本示例:
A: 等過年咱們回去買點(diǎn)兔頭好好吃頓火鍋
B: 太原就沒看見有好吃的兔頭
A: 我從虹橋給你帶個(gè)回去那天瞅到一正宗的
B: 最愛你了
A: 那是必須
A: 嗯嗯,我再等等!你現(xiàn)在在上海吧?上海風(fēng)好像比南京還大呢,少出門吧
B: 對(duì)啊,我在家,沒事兒。一定要小心啊!
A: 我去年也去轉(zhuǎn)了一圈,還碰見以前的體育老師了,合了個(gè)影
B: 哈哈我還去找高一時(shí)侯的英語老師沒找到她剛好有事情沒在學(xué)校~
A: 你也是真心找回憶了哦
B: 哈哈畢業(yè)了沒去過想去看看啊
模型設(shè)計(jì)
知道了數(shù)據(jù)長(zhǎng)什么樣之后,我們接下來就要去設(shè)計(jì)模型了。顯然,我們需要做的就是訓(xùn)練一個(gè)模型,預(yù)測(cè)下一個(gè)該回復(fù)什么。既然語料里包含了多輪對(duì)話,那么我們還要求這個(gè)模型支持多輪對(duì)話。考慮對(duì)話歷史的最簡(jiǎn)單的方式,就是把直到當(dāng)前句的所有歷史對(duì)話都拼接成單句文本,來作為模型的輸入信息。
給定一些輸入,預(yù)測(cè)一個(gè)輸出,從形式上來看我們應(yīng)該用Seq2Seq模型。直接用Seq2Seq其實(shí)問題也不大,但標(biāo)準(zhǔn)的Seq2Seq一般用于形式比較固定的輸入輸出,比如輸入的文本長(zhǎng)度應(yīng)該是集中在某個(gè)范圍內(nèi),不宜變化太大,但考慮多輪對(duì)話的話,理論上我們也不知道前面有多少輪對(duì)話,因此原則上輸入文本長(zhǎng)度是無限制的。用Seq2Seq的話,還有訓(xùn)練效率低的問題,就是我們每輪對(duì)話每次我們只能訓(xùn)練一句回復(fù),如果一個(gè)多輪對(duì)話有n句回復(fù),那么就要拆分為n個(gè)樣本來訓(xùn)練了。
因此,我們需要一個(gè)長(zhǎng)度能相當(dāng)自由地變化的、同時(shí)能預(yù)測(cè)整個(gè)多輪對(duì)話的模型,實(shí)現(xiàn)這個(gè)需求的比較適當(dāng)?shù)倪x擇就是單向語言模型(LM、GPT),做法如下圖:
如圖所示,我們選擇當(dāng)前主流的Transformer模型,按照BERT的常規(guī)輸入格式,將每句對(duì)話用[SEP]拼接起來,然后就訓(xùn)練一個(gè)從左往右的單向語言模型。為了區(qū)分不同的說話角色,我們對(duì)不同的說話者用不同的Segment Id區(qū)分。此外,考慮到BERT和GPT都是用了絕對(duì)位置編碼,可處理的文本長(zhǎng)度存在一個(gè)上限,而對(duì)話輪數(shù)理論上是無限的,所以這里我們采用了相對(duì)位置編碼的NEZHA作為基本結(jié)構(gòu),并使用NEZHA的預(yù)訓(xùn)練權(quán)重作為模型的初始化權(quán)重。
說白了,就是往NEZHA里邊加入了下三角形式的Attention Mask,使其變?yōu)橐粋€(gè)語言模型,相關(guān)介紹請(qǐng)參考《從語言模型到Seq2Seq:Transformer如戲,全靠Mask》[1]。
訓(xùn)練細(xì)節(jié)
首先,給出筆者的參考實(shí)現(xiàn)以及訓(xùn)練好的模型:
https://github.com/bojone/nezha_gpt_dialog代碼基于bert4keras實(shí)現(xiàn),基本結(jié)構(gòu)是NEZHA Base。為了使得訓(xùn)練更加“平民化”,筆者只用了一張卡(22G的RTX)配合梯度累積來訓(xùn)練,顯存不夠大的讀者可以縮小batch_size。開源的模型是用的學(xué)習(xí)率、0.01的權(quán)重衰減率的Adam優(yōu)化器進(jìn)行訓(xùn)練的,訓(xùn)練了1000萬步(20多天),單步batch_size為16,累積梯度步數(shù)為16,所以相當(dāng)于以256的batch_size訓(xùn)練了62.5萬步,最終的loss大概是2.83(平均每個(gè)token的交叉熵)。
我們來測(cè)試一下我們的訓(xùn)練結(jié)果。
>>> print(chatbot.response([u'我最近天天吃火鍋']))
火鍋好吃啊
>>> print(chatbot.response([u'我最近天天吃火鍋']))
我也是
>>> print(chatbot.response([u'我最近天天吃火鍋']))
哈哈哈哈哈哈我最近一周吃了兩頓火鍋
>>> print(chatbot.response([u'我最近天天吃火鍋']))
我也是,每次都吃的撐
>>> print(chatbot.response([u'我最近天天吃火鍋']))
吃火鍋不長(zhǎng)胖
>>> print(chatbot.response([u'我最近天天吃火鍋']))
火鍋好啊
>>> print(chatbot.response([u'我最近天天吃火鍋']))
你們那邊有這種火鍋?
>>> print(chatbot.response([u'我最近天天吃火鍋']))
火鍋好吃嗎?我也想吃
>>> print(chatbot.response([u'我最近天天吃火鍋']))
吃吧,吃了就開心了
>>> print(chatbot.response([u'我最近天天吃火鍋']))
我也想天天吃火鍋!
>>> print(chatbot.response([u'我最近天天吃火鍋']))
火鍋吃不了多久的。
>>> print(chatbot.response([u'你那天氣好嗎']))
還好,不熱,有點(diǎn)冷
>>> print(chatbot.response([u'你那天氣好嗎', u'還好,不熱,有點(diǎn)冷', u'那你得穿多厚的衣服']))
我穿了兩件
對(duì)比分析
CDial-GPT也開源了自己訓(xùn)練的預(yù)訓(xùn)練模型,筆者也將它轉(zhuǎn)換為bert4keras能加載的格式了,CDial-GPT-tf[2],讀者也可以測(cè)試比對(duì)一下。從訓(xùn)練上來看,CDial-GPT使用pytorch實(shí)現(xiàn)的模型,基本結(jié)構(gòu)是GPT Base,使用了4張2080Ti,總batch_size為32,累積梯度64步,論文說訓(xùn)練了30個(gè)epoch,總步數(shù)約2100萬步(筆者的兩倍),因此大概相當(dāng)于batch_size為2048訓(xùn)練了33萬步。
在輸入設(shè)計(jì)上,CDial-GPT也有所不同,如下圖:
如圖所示,CDial-GPT跟我們前述設(shè)計(jì)的主要不同是多輪對(duì)話之間的拼接方式,我們之前是直接用[SEP]連接,它是用[speaker1]、[speaker2](圖中簡(jiǎn)記為S1、S2)這樣的角色標(biāo)記來連接,最后才用一個(gè)[SEP]表示回復(fù)結(jié)束。這樣一來,由于預(yù)測(cè)部分的格式跟歷史的格式不一樣,因此每次只能訓(xùn)練一句回復(fù),多輪對(duì)話要拆分為多個(gè)樣本來訓(xùn)練,理論上是增加了訓(xùn)練復(fù)雜性的(要訓(xùn)練多步才能把一個(gè)多輪對(duì)話樣本訓(xùn)練完)。
至于效果上,個(gè)人測(cè)試的感覺是兩者沒什么明顯差別。有興趣的讀者也可以自行比較測(cè)試。
文章總結(jié)
本文主要分享了一次對(duì)話模型實(shí)踐,基于開源的LCCC閑聊語料庫,利用語言模型(GPT)對(duì)多輪對(duì)話進(jìn)行生成式建模,得到了一個(gè)相對(duì)通用的閑聊對(duì)話模型,最后將本文的思路與CDial-GPT本身開源的模型進(jìn)行了比較。
文末福利
后臺(tái)回復(fù)關(guān)鍵詞【入群】
加入賣萌屋NLP/IR/Rec與求職討論群
有頂會(huì)審稿人、大廠研究員、知乎大V和妹紙
等你來撩哦~
參考文獻(xiàn)
[1] 《從語言模型到Seq2Seq:Transformer如戲,全靠Mask》:
https://kexue.fm/archives/6933
[2] CDial-GPT-tf:
https://github.com/bojone/CDial-GPT-tf
總結(jié)
以上是生活随笔為你收集整理的动手做个DialoGPT:生成式多轮对话模型的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: ICML2020 | 一行代码就能实现的
- 下一篇: 天天说常识推理,究竟常识是什么?