最高加速9倍!字節(jié)跳動開源8比特混合精度Transformer引擎(2)
量化技術(shù)
int8 量化的加速收益主要來自如下幾個方面:
GEMM 精度從 fp16 降低到 int8 后,計算時間縮短;
自定義算子采用 int8 輸入輸出后,數(shù)據(jù)讀寫時間縮短;
梯度采用 int8 存儲后,多機之間通信時間縮短。
以 Transformer 模型為例,經(jīng)過 LightSeq fp16 引擎加速后,自定義算子時間大大縮短,而 GEMM 時間占比提升到了 90% 左右,因此優(yōu)化的重點轉(zhuǎn)移到了 GEMM 提速。將 fp16 GEMM 替換為 int8 GEMM 不僅可以縮短 GEMM 時間,還可以減小前后算子的輸入輸出位寬,從而減小讀寫數(shù)據(jù)的時間。最后多機訓練的瓶頸主要在梯度的通信,將梯度量化為 int8 精度可以大大加快分布式訓練的速度。
量化原理
為了彌補量化帶來的精度損失,通常需要用量化感知訓練來模擬量化過程。如上圖所示,量化感知訓練就是將 float GEMM 的兩個 float 輸入分別做一遍量化和反量化(稱之為偽量化結(jié)點),離散化成分段的浮點數(shù)輸入,然后進行 float GEMM 運算。得到結(jié)果后再次進行量化與反量化,得到最終的浮點數(shù)結(jié)果。而量化的過程是不可導的,因此需要用 STE 方法來估計量化參數(shù)的梯度。之所以量化感知訓練中需要插入偽量化結(jié)點,然后用 float GEMM 去模擬量化過程,是因為 TensorFlow 和 PyTorch 等訓練框架不支持 int8 GEMM。
而 LightSeq 量化訓練直接采用 int8 GEMM 來真實還原量化過程,因此相比傳統(tǒng)的實現(xiàn)要更快,且更加節(jié)省顯存。在推理的時候,同樣采用離散化后的整數(shù)進行 int8 GEMM 運算,最后再反量化回浮點數(shù)結(jié)果。量化推理過程和量化訓練完全一致,并且和傳統(tǒng)的量化感知訓練是完全等價的。
量化位置
整個量化 Transformer 的網(wǎng)絡結(jié)構(gòu)如上圖所示,紅色箭頭表示需要加上量化和反量化結(jié)點的位置。
首先所有 int8 GEMM 的輸入和輸出都需要進行量化。由于 int8 GEMM 的 shape 限制,部分 GEMM(例如注意力分數(shù)的計算)仍然采用 float GEMM。此外第二層 FFN 的 GEMM 采用的是 int32 的輸出,因為它的 GEMM 輸入是 ReLU 激活函數(shù)的輸出結(jié)果,只包含正數(shù),非對稱,因此如果采用 int8 輸出的 GEMM,將無法反量化為正確的浮點數(shù)結(jié)果。
然后所有的模型權(quán)重 weight 都需要存儲為 int8 類型,因此需要對 weight 做量化。而權(quán)重 bias 參數(shù)量較小,無需量化,保留 float 精度反而可以提升模型效果。
最后需要對 decoder 端的 cache 進行量化。因為在推理時,decoder 端的 cache 需要頻繁進行讀寫,因此將 cache 量化為 int8 可以大大加快解碼的速度。
量化策略
將一個浮點數(shù)矩陣量化為 int8 整數(shù)矩陣有很多方法,LightSeq 采用的是對稱量化,即將正負數(shù)范圍對稱的浮點數(shù)區(qū)間等比例地映射到整數(shù)區(qū)間 [-127, 127] 上。
而實際上浮點數(shù)矩陣的數(shù)值范圍通常并不對稱,存在極少的離群值。如果直接按照離群值的范圍來量化矩陣,會影響到量化后的精度,所以需要先對矩陣進行數(shù)值截斷。
LightSeq 采用 PACT 方法進行截斷[6],將截斷的范圍當作模型可學習的參數(shù),然后利用 STE 算法去估計參數(shù)的梯度,并進行反向傳播優(yōu)化。根據(jù)實踐經(jīng)驗,權(quán)重 weight 的初始截斷范圍設為[-1, 1],中間結(jié)果的初始截斷范圍設為[-16, 16],可以在大部分任務上達到最好的效果。最后經(jīng)過截斷范圍和其他模型參數(shù)的聯(lián)合優(yōu)化,量化模型的效果可以達到基本無損。
梯度通信量化
針對分布式訓練場景,LightSeq 推出了梯度量化壓縮技術(shù)。即對浮點精度的梯度進行 int8 量化,以減少梯度通信的時間消耗,從而加速訓練,這就是梯度通信量化(GCQ)。
如上圖所示,梯度通信量化的主要流程如下:
計算每張卡上各自梯度的截斷范圍;
對截斷范圍執(zhí)行 all-reduce max 操作;
每張卡使用統(tǒng)一的截斷范圍對各自梯度進行 int8 量化;
對 int8 梯度執(zhí)行 all-reduce sum 操作;
每張卡對 all-reduce 后的梯度進行反量化,還原為浮點數(shù)梯度,并進行參數(shù)更新。
為了解決 int8 梯度在 all-reduce 過程中溢出的問題,LightSeq 首先將每張卡上的浮點數(shù)梯度除以卡數(shù),再使用除之前的截斷范圍進行量化,最后進行 all-reduce 操作。這樣每張卡上量化后的 int8 整數(shù) all-reduce 完就不會溢出,但是單卡實際用于量化的比特數(shù)也因此而減少,所以目前方案在 2 機 8 卡效果幾乎無損,但隨著卡數(shù)的上漲,訓練效果會有所下降。以 en2de 和 en2fr 翻譯任務為例,在 4 機 8 卡上進行分布式量化訓練,BLEU 值分別會下降 0.4 和 1.5 左右。未來 LightSeq 將會持續(xù)探索更好的方法來解決這一問題。
通用技術(shù)
除了上一章節(jié)中提到的量化技術(shù)以外,此次更新 LightSeq 還提出了幾種通用的優(yōu)化技術(shù),不僅可以應用在量化模型中,也適用于其它所有精度模型的訓練與推理。
算子融合
上圖是 encoder 模塊量化訓練的計算圖,LightSeq 將兩次 GEMM 運算之間的所有操作融合成一個算子[7],減少了 kernel 調(diào)用的次數(shù),因此減少了總的計算時間。
圖中黃色矩形表示 int8 GEMM,綠色矩形表示 float GEMM。這里采用 float GEMM 是由于 shape 的限制,不適合使用 int8 GEMM 加速。紅色箭頭表示流動數(shù)據(jù)的類型是 int8,綠色箭頭表示第二層 FFN 的 GEMM 輸出是 int32 數(shù)據(jù)類型。int8 GEMM 輸入輸出的量化與反量化操作都被融合到了前后 kernel 里,這不僅可以減少數(shù)據(jù)搬運,還可以減小顯存占用。
在推理時,LightSeq 還針對 decoder 做了優(yōu)化。如上圖所示,在計算 self-attention 時,注意力得分的維度是(batch size, 1, sequence length)。因此在計算 value 乘積時,可以不采用 GEMM 運算,而直接手寫加權(quán)求和的算子,從而將圖中虛線框中的計算融合成一個 kernel。
自動顯存管理
模型量化引入了更復雜的張量類型和張量依賴關(guān)系,這給顯存管理帶來新的挑戰(zhàn)。為此,LightSeq 設計了新的顯存管理機制。如上圖所示,主要包括以下過程:
訓練啟動前,根據(jù)每個算子的拓撲依賴關(guān)系,自動計算每個張量的生命周期及顯存空間大小。其中,包含動態(tài)維度的張量按照此維度的最大量進行計算,例如機器翻譯任務中的最大句長和最大 batch 句子數(shù)量。這些最大量在訓練前已被指定;
張量確定生命周期和大小后,分析顯存復用關(guān)系。其中,無生命周期重合的張量可以共用一片顯存空間,所有顯存空間都是無數(shù)據(jù)類型的,可以被分配到任意數(shù)據(jù)類型的張量上;
根據(jù)張量顯存復用關(guān)系,申請多段顯存空間,為每個張量分配實際的顯存起止地址。
張量顯存復用的分析,LightSeq 借鑒了論文 [3] 中提出的 Greedy by Size for Offset Calculation 方法,做了三個改進:
支持了整個訓練過程的顯存復用(forward/backward);
不同數(shù)據(jù)類型能做到顯存復用(int8/fp16/fp32);
在多段顯存空間上容納所有張量,而非一段非常大的顯存空間,這樣能有效提升顯存利用率。
自動 GEMM 調(diào)優(yōu)
LightSeq 的 int8 GEMM 采用了 NVIDIA 的 cuBLASLt 庫,這也是目前 NVIDIA 顯卡上最為高效的矩陣運算庫。但是輸入數(shù)據(jù)的 shape 或者顯卡不同的話,GEMM 所采用的最優(yōu)配置(例如數(shù)據(jù)排布、GEMM 算法等等)也可能不同,因此需要進行自動選取。LightSeq 采取的自動調(diào)優(yōu)方案如下:
在多種型號顯卡上(例如 T4 和 A100)進行不同 shape 的 GEMM 最優(yōu)配置搜索,并將結(jié)果保存到配置文件中,用戶只需要下載即可;
模型初始化時,加載對應型號顯卡的配置文件,解析并保存到鍵值對為 (shape, 最優(yōu)配置) 的字典中。如果沒有對應型號顯卡的配置文件,或者沒有需要的 GEMM shape,那么用戶可以選擇自己搜索并保存,或者直接使用默認配置;
模型前向或后向計算時,根據(jù)輸入的 shape 在字典中尋找最優(yōu)配置,然后進行 GEMM 計算。如果沒有找到對應的 shape,那么直接采用默認的配置。
未來工作
未來 LightSeq 還將繼續(xù)探索移動端的低精度量化、反向傳播中梯度的量化、大模型量化等方向。
引用
[1] Wang, Xiaohui, et al. "LightSeq2: Accelerated training for transformer-based models on gpus." arXiv preprint arXiv:2110.05722 (2021).
[2] Micikevicius, Paulius, et al. "Mixed precision training." arXiv preprint arXiv:1710.03740 (2017).
[3] Pisarchyk, Yury, and Juhyun Lee. "Efficient memory management for deep neural net inference." arXiv preprint arXiv:2001.03288 (2020).
[4] Jacob, Benoit, et al. "Quantization and training of neural networks for efficient integer-arithmetic-only inference." Proceedings of the IEEE conference on computer vision and pattern recognition. 2018.
[5] Alistarh, Dan, et al. "QSGD: Communication-efficient SGD via gradient quantization and encoding." Advances in neural information processing systems 30 (2017).
[6] Choi, Jungwook, et al. "Pact: Parameterized clipping activation for quantized neural networks." arXiv preprint arXiv:1805.06085 (2018).
[7] Wang, Xiaohui, et al. "LightSeq: A high performance inference library for transformers." arXiv preprint arXiv:2010.13887 (2020).
*博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點,如有侵權(quán)請聯(lián)系工作人員刪除。
聲控燈相關(guān)文章:聲控燈原理