博客專(zhuān)欄

EEPW首頁(yè) > 博客 > 節(jié)省顯存新思路,在PyTorch里使用2 bit激活壓縮訓(xùn)練神經(jīng)網(wǎng)絡(luò)

節(jié)省顯存新思路,在PyTorch里使用2 bit激活壓縮訓(xùn)練神經(jīng)網(wǎng)絡(luò)

發(fā)布人:CV研究院 時(shí)間:2021-08-03 來(lái)源:工程師 發(fā)布文章

隨著超大規(guī)模深度學(xué)習(xí)模型逐漸成為 AI 的趨勢(shì),如何在有限的 GPU 內(nèi)存下訓(xùn)練這些模型成為了一個(gè)難題。

本篇文章轉(zhuǎn)自于“機(jī)器之心”

本文將介紹來(lái)自加州伯克利大學(xué)的 ActNN,一個(gè)基于 PyTorch 的激活壓縮訓(xùn)練框架。在同樣的內(nèi)存限制下,ActNN 通過(guò)使用 2 bit 激活壓縮,可以將 batch size 擴(kuò)大 6-14 倍,將模型尺寸或者輸入圖片擴(kuò)大 6-10 倍。ActNN 相關(guān)論文已被 ICML 2021 接收為 Long Talk,代碼開(kāi)源于 github。

論文 https://arxiv.org/abs/2104.14129

代碼 https://github.com/ucbrise/actnn

AI 訓(xùn)練撞上「內(nèi)存墻」

從 AlexNet,ResNet 到 GPT-3,深度學(xué)習(xí)性能的突破都離不開(kāi)模型規(guī)模的瘋狂增長(zhǎng)。大模型有更好的性能已經(jīng)成為業(yè)界的共識(shí)。過(guò)去幾年,不僅訓(xùn)練一個(gè)最先進(jìn)模型需要的算力在指數(shù)增長(zhǎng),訓(xùn)練一個(gè)最先進(jìn)模型需要的內(nèi)存也在指數(shù)增長(zhǎng)。如下圖所示,大型 Transformer 模型的參數(shù)量以每?jī)赡攴?240 倍的速度指數(shù)增長(zhǎng)。但是,單個(gè) GPU 的內(nèi)存卻只以每?jī)赡攴?2 倍的速度在緩慢增長(zhǎng)。另外,在訓(xùn)練模型時(shí),不光要存儲(chǔ)模型參數(shù),還要存儲(chǔ)中間結(jié)果激活值和優(yōu)化器狀態(tài),所需要的內(nèi)存更多。如何在有限的 GPU 內(nèi)存下訓(xùn)練這些大規(guī)模模型成為了挑戰(zhàn)。 

1.jpg

source:Gholami A, Yao Z, Kim S, Mahoney MW, Keutzer K. AI and Memory Wall. RiseLab Medium Blog Post, University of California Berkeley

節(jié)省訓(xùn)練內(nèi)存的方法

目前,節(jié)省訓(xùn)練內(nèi)存的方法主要有三類(lèi):1. 重計(jì)算(Gradient checkpointing/Rematerialization)  2. 使用 CPU 內(nèi)存進(jìn)行交換 (swapping)  和 3. 使用分布式訓(xùn)練將 Tensor 分散存儲(chǔ)在多個(gè) GPU 上。這三類(lèi)方法互相不沖突,可以結(jié)合使用。大部分機(jī)器學(xué)習(xí)框架對(duì)這些方法都提供了一些支持,也有不少相關(guān)的論文。但是,想要高效、自動(dòng)化地實(shí)現(xiàn)這些策略并不容易。與已有方法不同,我們提出了 ActNN,一個(gè)新的基于壓縮的內(nèi)存節(jié)省框架。在提供理論證明的同時(shí),我們基于 PyTorch 提供了一個(gè)高效易用的實(shí)現(xiàn)。Table.1 比較了 ActNN 和已有的一些內(nèi)存節(jié)省系統(tǒng)。ActNN 支持 PyTorch 的動(dòng)態(tài)圖執(zhí)行模式,并且不需要預(yù)先進(jìn)行復(fù)雜的策略搜索。ActNN 作為一個(gè)獨(dú)立的 Python 庫(kù),使用時(shí) import 即可,不需要修改或重新編譯 PyTorch。與已有的工作相比,ActNN 靈活且易于使用。同時(shí),ActNN 在理論上也可以和已有的技術(shù)相互疊加。

2.jpg

ActNN:2 bit 激活壓縮訓(xùn)練

在訓(xùn)練一個(gè)多層神經(jīng)網(wǎng)絡(luò)時(shí),在前向傳播中,每一層的中間結(jié)果都要被存下來(lái)用于計(jì)算反向傳播的梯度。這些中間結(jié)果,又被叫做「激活值」(activation),實(shí)際上占據(jù)了大部分的內(nèi)存消耗,尤其是在 batch size 較大或者輸入圖片較大的時(shí)候。ActNN 的原理是就是壓縮這些激活值來(lái)節(jié)省內(nèi)存。如下圖所示,左圖表示的是普通的前向傳播和反向傳播,前向傳播時(shí)會(huì)存下所有層的 fp32 激活值用于反向傳播,內(nèi)存使用在計(jì)算 loss 的時(shí)候達(dá)到峰值。右圖表示的是 ActNN 的訓(xùn)練方法:在前向傳播時(shí),通過(guò)一個(gè)壓縮操作 Q 將激活值壓縮后再存儲(chǔ);反向傳播時(shí),通過(guò)解壓縮操作 Q^-1 將激活值解壓再計(jì)算梯度。

3.jpg

如果只是為了節(jié)省內(nèi)存,這里可以使用各種壓縮算法,但是大部分現(xiàn)有的壓縮算法并不能高效地運(yùn)行在 GPU 上,會(huì)引入較大的開(kāi)銷(xiāo)。ActNN 選擇了使用 2-bit 量化作為這里的壓縮算法。量化操作的代價(jià)較小,而且有一些好的數(shù)學(xué)性質(zhì)允許我們使用有損壓縮達(dá)到較大的壓縮比。

把 fp32 浮點(diǎn)數(shù)量化為 2-bit 整數(shù)是一個(gè)有損壓縮,會(huì)引入一些誤差。論文從理論上分析了量化引入的誤差是如何影響訓(xùn)練的收斂性的。

第一,存在一個(gè)隨機(jī)化的量化策略,使得使用有損量化壓縮后,估計(jì)出的有損梯度是原梯度的一個(gè)無(wú)偏估計(jì)。

4.jpg

在這一條件下,我們套用已有的隨機(jī)梯度下降收斂性定理,得出最后收斂時(shí)的誤差會(huì)被梯度的方差所限制。

第二,我們推導(dǎo)出了使用量化壓縮之后,隨機(jī)梯度下降計(jì)算出的梯度的方差。

5.jpg

等號(hào)右邊的第一項(xiàng)是隨機(jī)梯度下降在 minibatch 采樣時(shí)產(chǎn)生的方差,等號(hào)右邊的第二項(xiàng)是有損壓縮額外引入的方差。這條公式顯示地刻畫(huà)了有損壓縮帶來(lái)的影響。注意到,當(dāng)有損量化壓縮帶來(lái)的方差遠(yuǎn)小于原來(lái)隨機(jī)梯度下降自帶的方差時(shí),ActNN 引入的有損壓縮就不會(huì)影響訓(xùn)練的收斂性。更多關(guān)于公式的推導(dǎo)和可視化參見(jiàn)文末的論文鏈接。論文對(duì)不同的算子(conv2d,batch norm,linear等)都提供了詳細(xì)的分析。

由上述公式啟發(fā),我們提出了一些新的量化技巧用于降低有損壓縮引入的額外方差。我們引入了新的量化技巧 ( Per-group Quantization,F(xiàn)ine-Grained Mixed-Precision,Runtime Adaptation) 來(lái)利用梯度在不同樣本,不同緯度,不同層之間的異構(gòu)特性。最后的壓縮算法會(huì)分配更多的 bit 給更重要的激活值。平均每個(gè)浮點(diǎn)數(shù)分配到 2 bit。

在具體實(shí)現(xiàn)壓縮算法時(shí),還有很多可以調(diào)節(jié)的參數(shù)。這里產(chǎn)生了一個(gè)內(nèi)存節(jié)省和訓(xùn)練速度的取舍。一般來(lái)說(shuō),使用更復(fù)雜的壓縮算法可以節(jié)省更多的內(nèi)存,但是也會(huì)引入更多額外的開(kāi)銷(xiāo),使訓(xùn)練速度變慢。為了給用戶(hù)較大的靈活性,ActNN 提供了 5 個(gè)優(yōu)化等級(jí) L1-L5 供用戶(hù)選擇。低的優(yōu)化等級(jí)節(jié)省的內(nèi)存較少,但是運(yùn)行速度快。高的優(yōu)化等級(jí)節(jié)省的內(nèi)存多,但是運(yùn)行也更慢。在最高優(yōu)化等級(jí) L5 下,ActNN 會(huì)結(jié)合一個(gè)簡(jiǎn)單的內(nèi)存交換策略,將壓縮后的激活值移到 CPU 內(nèi)存上,進(jìn)一步節(jié)省內(nèi)存。

實(shí)現(xiàn)

要在 PyTorch 實(shí)現(xiàn) ActNN 算法非常簡(jiǎn)單。對(duì)于一個(gè) PyTorch nn Module,我們只需要在其 forward 函數(shù)里加入量化壓縮,在其 backward 函數(shù)里加入解壓縮操作。所有的計(jì)算還是在 fp32 下進(jìn)行,與原來(lái)一樣,偽代碼如下圖所示。

ActNN 為大部分常用的 PyTorch nn.Module 實(shí)現(xiàn)了使用量化壓縮的版本。用戶(hù)只需將模型里的所有 PyTorch nn.Module 替換成 ActNN 對(duì)應(yīng)的 Module (如把 nn.Conv2d 替換成 actnn.Conv2d),即可節(jié)省內(nèi)存,不需要更改其他代碼。ActNN 同時(shí)也提供了一個(gè) wrapper 實(shí)現(xiàn)一行代碼自動(dòng)替換。

6.jpg

實(shí)驗(yàn)結(jié)果

因?yàn)?ActNN 進(jìn)行的是有損壓縮,所以最重要的一點(diǎn)是先驗(yàn)證 ActNN 是否會(huì)影響模型的精度。下圖是使用 ActNN 在 ImageNet 上訓(xùn)練 ResNet-50 的結(jié)果。FP 代表普通的 fp32 訓(xùn)練, BLPA 是來(lái)自 NeurIPS 2019 的一個(gè)相關(guān)工作??梢钥吹剑?ActNN 的 2-bit 壓縮模式下,模型幾乎沒(méi)有損失精度。在更極限的 1.25 bit 的情況下,ActNN 也能收斂,只不過(guò)會(huì)損失一些精度。而之前的工作 BLPA 在小于 4 bit 的情況就下無(wú)法收斂。

7.jpg

我們還在圖像分割,物體檢測(cè),以及自監(jiān)督學(xué)習(xí)等多個(gè)任務(wù)上進(jìn)行了實(shí)驗(yàn)。ActNN 都能在 2-bit 壓縮模式下達(dá)到和普通 fp32 幾乎一樣的結(jié)果。在部分任務(wù)上,因?yàn)?ActNN 可以使用更大的 batch size,甚至可以取得更好的測(cè)試結(jié)果。詳細(xì)的實(shí)驗(yàn)結(jié)果和訓(xùn)練記錄參見(jiàn)文末的論文與 github 鏈接。

之后,我們對(duì)比了 ActNN 與普通 fp32 訓(xùn)練的實(shí)際內(nèi)存使用情況。如下表所示,ActNN 可以將激活值占用的內(nèi)存壓縮 12 倍,將訓(xùn)練使用的總內(nèi)存壓縮 4 - 7 倍。這一實(shí)際內(nèi)存壓縮效果符合理論推導(dǎo)。為什么激活值壓縮倍率是 12 而不是 32 bit / 2 bit = 16?主要是因?yàn)?ActNN 不能使用 inplace 的 ReLU,以及需要存儲(chǔ)少量額外的 min 和 scale 用于解壓縮。

8.jpg

最后,我們測(cè)試了 ActNN 的訓(xùn)練速度。因?yàn)?ActNN 在訓(xùn)練過(guò)程中進(jìn)行了壓縮,這些壓縮在節(jié)省內(nèi)存的同時(shí)也會(huì)引入額外的計(jì)算開(kāi)銷(xiāo)。一般來(lái)說(shuō),省得內(nèi)存越多,進(jìn)入的額外開(kāi)銷(xiāo)就越多,訓(xùn)練也就越慢。我們?cè)?NVIDIA T4 (16 GB 內(nèi)存) 上對(duì)比了 ActNN 和已有內(nèi)存節(jié)省系統(tǒng)的訓(xùn)練速度。如下圖所示,DTR (ICLR 2020),BLPA (NeurIPS 2019)和 swap 分別是基于重計(jì)算,壓縮和內(nèi)存交換的三種方法,紅叉代表 Out-of-memory。y 軸是訓(xùn)練吞吐量 (images per second),越高越好。綠色的曲線是綜合 ActNN 在不同優(yōu)化等級(jí)下的最優(yōu)結(jié)果。可以看到,ActNN 不僅能開(kāi)到最大的 batch size(即最省內(nèi)存),同時(shí)在所有 batch size 下都比 baseline 的訓(xùn)練速度更快。

9.jpg

我們還對(duì)更多的網(wǎng)絡(luò)進(jìn)行了測(cè)試。在同樣的內(nèi)存限制下,ActNN 可以將 batch size 擴(kuò)大 6-14 倍,將模型尺寸或者輸入圖片擴(kuò)大 6-10 倍。詳細(xì)的實(shí)驗(yàn)設(shè)置和結(jié)果參見(jiàn)文末的論文鏈接。

兩行代碼即可在 PyTorch 中使用

import actnn
model = actnn.QModel(model)

ActNN 提供了一個(gè)自動(dòng)模型轉(zhuǎn)換封裝。只需在訓(xùn)練腳本里插入兩行代碼,即可將普通的 PyTorch  模型轉(zhuǎn)換為使用 ActNN 的模型。同時(shí),ActNN 也提供了更高級(jí)的 API 支持定制化的使用場(chǎng)景。

更多的例子參見(jiàn) github 鏈接。我們提供了在圖像識(shí)別、圖像分割、物體檢測(cè),以及自監(jiān)督學(xué)習(xí)等多個(gè)任務(wù)上使用 actnn 的完整例子和訓(xùn)練記錄,歡迎試用!

*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。

dc相關(guān)文章:dc是什么


脈沖點(diǎn)火器相關(guān)文章:脈沖點(diǎn)火器原理


關(guān)鍵詞: 深度學(xué)習(xí)

相關(guān)推薦

技術(shù)專(zhuān)區(qū)

關(guān)閉