Transformer取代者登場!微軟、清華剛推出RetNet:成本低、速度快、性能強(2)
Retentive 網(wǎng)絡(luò)
RetNet 由 L 個相同的塊堆疊而成,其布局與 Transformer 類似(即殘差連接和 pre-LayerNorm)。每個 RetNet 塊包含兩個模塊:多尺度retention(MSR)和前饋網(wǎng)絡(luò)(FFN)。
給定輸入序列,RetNet 以自回歸方式對序列進行編碼。輸入向量
首先被封裝為
,其中
是隱藏維度。然后,計算上下文向量表征
。
Retention
RetNet 具有循環(huán)和并行雙重形式的 retention 機制,因此能夠并行地訓(xùn)練模型,同時循環(huán)地進行推理。
給定輸入,將其投影為一維函數(shù) v (n) = X_n - w_V。考慮一個序列建模問題,通過狀態(tài) s_n 映射 v (n) → o (n)。
為簡單起見,讓 v_n, o_n 表示 v (n),o (n)。此處以循環(huán)的方式對映射進行表述:
其中,將 v_n 映射到狀態(tài)向量 s_n,然后實現(xiàn)線性變換,對序列信息進行循環(huán)編碼。
接下來,使投影 Q_n, K_n 具有內(nèi)容感知能力:
其中是可學(xué)習(xí)矩陣。
將矩陣對角化,其中
。然后得到
。通過將 Λ 吸收到 W_Q 和 W_K 中,可以將方程(1)重寫為
其中,稱為 xPos,即為 Transformer 提出的相對位置嵌入。進一步將 γ 簡化為標(biāo)量,公式(3)則變?yōu)?/span>
其中?為共軛轉(zhuǎn)置。該公式很容易在訓(xùn)練實例中并行化。
總之,從公式 (1) 所示的循環(huán)建模開始,然后推導(dǎo)出公式 (4) 中的并行公式。將原始映射 v (n) →o (n) 視為向量,得到如下的 retention 機制:
1)Retention 的并行表征
如圖 3a 所示,Retention 層定義為:
與自注意力類似,并行表征使得能夠使用 GPU 高效地訓(xùn)練模型。
2)Retention 的循環(huán)表征
如圖 3b 所示,所提出機制也可以寫成循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN),這有利于推理。對于第 n 個時間步,循環(huán)得到的輸出為
這里的 Q, K, V, γ 和公式 5 相同。
3)Retention 分塊循環(huán)表征
并行表征和循環(huán)表征的混合形式可以加速訓(xùn)練,特別是對于長序列。此處將輸入序列劃分為若干小塊。在每個塊內(nèi),按照并行表征(公式(5))進行計算。相反,跨塊信息則按照循環(huán)表征(公式(6))進行傳遞。具體來說,讓 B 表示塊長度。通過以下方式計算第 i 個分塊的 retention 輸出:
其中 [i] 表示第 i 個數(shù)據(jù)塊,例如。
門控多尺度 Retention
在每個層中,研究者使用 h = d_model/d 個 retention 頭,其中 d 是頭的維度。這些頭使用不同的參數(shù)矩陣 W_Q、W_K、W_V ∈ R^(d×d)。此外,多尺度 retention(MSR)為每個頭分配不同的 γ。為了簡化,研究者將 γ 設(shè)置為在不同層之間相同并保持固定。另外,他們添加了一個 swish 門 [RZL17] 來增加層的非線性性。形式上,給定輸入 X,研究者將該層定義為:
其中,為可學(xué)習(xí)參數(shù),GroupNorm [WH18] 對每個頭的輸出進行歸一化,遵循 [SPP^+19] 中提出的 SubLN。注意,這些頭使用多個 γ 尺度,這會帶來不同的方差統(tǒng)計結(jié)果。所以研究者分別對頭的輸出進行歸一化。
retention 的偽代碼如圖 4 所示。
Retention Score 歸一化
研究者利用 GroupNorm 的尺度不變性來提高 retention 層的數(shù)值精度。具體而言,在 GroupNorm 中乘以一個標(biāo)量值不會影響輸出和反向梯度,即 GroupNorm (α ? head_i) = GroupNorm (head_i)。研究者在公式(5)中實現(xiàn)了三個歸一化因子。首先,他們將 QK^? 歸一化為 QK^? / √ d。其次,他們將 D 替換為。第三,他們用 R 表示 retention scores R = QK^? ⊙ D,將其歸一化為
。然后,retention 輸出變?yōu)?nbsp;
。由于尺度不變的特性,上述技巧不會影響最終的結(jié)果,同時穩(wěn)定了正向和反向傳遞的數(shù)值流動。
Retention 網(wǎng)絡(luò)總體結(jié)構(gòu)
對于一個 L 層的 retention 網(wǎng)絡(luò),研究者堆疊多尺度 retention (MSR) 和前饋網(wǎng)絡(luò)(FFN)來構(gòu)建模型。形式上,輸入序列通過一個詞嵌入層被轉(zhuǎn)換為向量。研究者使用打包后的嵌入
作為輸入,并計算模型的輸出 X^L:
其中,LN (?) 為 LayerNorm [BKH16]。FFN 部分計算為 FFN (X) = gelu (XW_1) W_2,其中 W_1、W_2 為參數(shù)矩陣。
訓(xùn)練:研究者在訓(xùn)練過程中使用了并行(公式 5)表示和塊循環(huán)(公式 7)表示。序列或塊內(nèi)的并行有效地利用了 GPU 來加速計算。更有利的是,塊循環(huán)對于長序列訓(xùn)練特別有用,這在 FLOPs 和內(nèi)存消耗方面都是有效的。
推理:在推理過程中,研究者采用了循環(huán)表示(公式 6),這非常適合自回歸解碼。O (1) 的復(fù)雜度減少了內(nèi)存占用和推理延遲,同時實現(xiàn)了相當(dāng)?shù)慕Y(jié)果。
*博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點,如有侵權(quán)請聯(lián)系工作人員刪除。