在线观看www成人影院-在线观看www日本免费网站-在线观看www视频-在线观看操-欧美18在线-欧美1级

0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫(xiě)文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

如何計(jì)算transformer模型的參數(shù)量

jf_pmFSk4VX ? 來(lái)源:GiantPandaCV ? 2023-07-10 09:13 ? 次閱讀

1. 前言

最近,OpenAI推出的ChatGPT展現(xiàn)出了卓越的性能,引發(fā)了大規(guī)模語(yǔ)言模型(Large Language Model,LLM)的研究熱潮。大規(guī)模語(yǔ)言模型的“大”體現(xiàn)在兩個(gè)方面:模型參數(shù)規(guī)模大,訓(xùn)練數(shù)據(jù)規(guī)模大。以GPT3為例,GPT3的參數(shù)量為1750億,訓(xùn)練數(shù)據(jù)量達(dá)到了570GB。進(jìn)而,訓(xùn)練大規(guī)模語(yǔ)言模型面臨兩個(gè)主要挑戰(zhàn):顯存效率和計(jì)算效率。

現(xiàn)在業(yè)界的大語(yǔ)言模型都是基于transformer模型的,模型結(jié)構(gòu)主要有兩大類(lèi):encoder-decoder(代表模型是T5)和decoder-only,具體的,decoder-only結(jié)構(gòu)又可以分為Causal LM(代表模型是GPT系列)和PrefixLM(代表模型是GLM)。歸因于GPT系列取得的巨大成功,大多數(shù)的主流大語(yǔ)言模型都采用Causal LM結(jié)構(gòu)。因此,針對(duì)decoder-only框架,為了更好地理解訓(xùn)練訓(xùn)練大語(yǔ)言模型的顯存效率和計(jì)算效率,本文分析采用decoder-only框架transformer模型的模型參數(shù)量、計(jì)算量、中間激活值、KV cache。

853e25d8-1d86-11ee-962d-dac502259ad0.jpg

為了方便分析,先定義好一些數(shù)學(xué)符號(hào)。記transformer模型的層數(shù)為8568a3ee-1d86-11ee-962d-dac502259ad0.png?,隱藏層維度為85808590-1d86-11ee-962d-dac502259ad0.png?,注意力頭數(shù)為8597e2c6-1d86-11ee-962d-dac502259ad0.png?。詞表大小為85af3034-1d86-11ee-962d-dac502259ad0.png?,訓(xùn)練數(shù)據(jù)的批次大小為85c04e64-1d86-11ee-962d-dac502259ad0.png?,序列長(zhǎng)度為85cf59e0-1d86-11ee-962d-dac502259ad0.png?。

2. 模型參數(shù)量

transformer模型由8568a3ee-1d86-11ee-962d-dac502259ad0.png個(gè)相同的層組成,每個(gè)層分為兩部分:self-attention塊和MLP塊。

self-attention塊的模型參數(shù)有85f2dcbc-1d86-11ee-962d-dac502259ad0.png?的權(quán)重矩陣860b6cb4-1d86-11ee-962d-dac502259ad0.png和偏置,輸出權(quán)重矩陣?861f975c-1d86-11ee-962d-dac502259ad0.png?和偏置,4個(gè)權(quán)重矩陣的形狀為86324bae-1d86-11ee-962d-dac502259ad0.png?,4個(gè)偏置的形狀為8645eff6-1d86-11ee-962d-dac502259ad0.png?。self- attention塊的參數(shù)量為8657ee68-1d86-11ee-962d-dac502259ad0.png?。

MLP塊由2個(gè)線性層組成,一般地,第一個(gè)線性層是先將維度從85808590-1d86-11ee-962d-dac502259ad0.png?映射到867afe44-1d86-11ee-962d-dac502259ad0.png,第二個(gè)線性層再將維度從867afe44-1d86-11ee-962d-dac502259ad0.png映射到85808590-1d86-11ee-962d-dac502259ad0.png。第一個(gè)線性層的權(quán)重矩陣86ac99f4-1d86-11ee-962d-dac502259ad0.png?的形狀為86c0c348-1d86-11ee-962d-dac502259ad0.png?,偏置的形狀為86d6bf18-1d86-11ee-962d-dac502259ad0.png?。第二個(gè)線性層權(quán)重矩陣86e7af30-1d86-11ee-962d-dac502259ad0.png?的形狀為86fae0c8-1d86-11ee-962d-dac502259ad0.png?,偏置形狀為8645eff6-1d86-11ee-962d-dac502259ad0.png?。MLP塊的參數(shù)量為872bf654-1d86-11ee-962d-dac502259ad0.png?。

self-attention塊和MLP塊各有一個(gè)layer normalization,包含了2個(gè)可訓(xùn)練模型參數(shù):縮放參數(shù)873cc984-1d86-11ee-962d-dac502259ad0.png?和平移參數(shù)8753cc42-1d86-11ee-962d-dac502259ad0.png?,形狀都是8645eff6-1d86-11ee-962d-dac502259ad0.png?。2個(gè)layernormalization的參數(shù)量為?867afe44-1d86-11ee-962d-dac502259ad0.png?。

87817ce6-1d86-11ee-962d-dac502259ad0.jpg

總的,每個(gè)transformer層的參數(shù)量879c4148-1d86-11ee-962d-dac502259ad0.png?。

除此之外,詞嵌入矩陣的參數(shù)量也較多,詞向量維度通常等于隱藏層維度85808590-1d86-11ee-962d-dac502259ad0.png,詞嵌入矩陣的參數(shù)量為?87c03350-1d86-11ee-962d-dac502259ad0.png。最后的輸出層的權(quán)重矩陣通常與詞嵌入矩陣是參數(shù)共享的。

關(guān)于位置編碼,如果采用可訓(xùn)練式的位置編碼,會(huì)有一些可訓(xùn)練模型參數(shù),數(shù)量比較少。如果采用相對(duì)位置編碼,例如RoPE和ALiBi,則不包含可訓(xùn)練的模型參數(shù)。我們忽略這部分參數(shù)。

綜上,8568a3ee-1d86-11ee-962d-dac502259ad0.png層transformer模型的可訓(xùn)練模型參數(shù)量為87d9da76-1d86-11ee-962d-dac502259ad0.png。當(dāng)隱藏維度?85808590-1d86-11ee-962d-dac502259ad0.png?較大時(shí),可以忽略一次項(xiàng),?模型參數(shù)量近似為8803fd6a-1d86-11ee-962d-dac502259ad0.png?。

接下來(lái),我們估計(jì)不同版本LLaMA模型的參數(shù)量。

實(shí)際參數(shù)量 隱藏維度h 層數(shù)l 12lh^2
6.7B 4096 32 6,442,450,944
13.0B 5120 40 12,582,912,000
32.5B 6656 60 31,897,681,920
65.2B 8192 80 64,424,509,440

2.1 訓(xùn)練過(guò)程中的顯存占用分析

在訓(xùn)練神經(jīng)網(wǎng)絡(luò)的過(guò)程中,占用顯存的大頭主要分為四部分:模型參數(shù)、前向計(jì)算過(guò)程中產(chǎn)生的中間激活、后向傳遞計(jì)算得到的梯度、優(yōu)化器狀態(tài)。這里著重分析參數(shù)、梯度和優(yōu)化器狀態(tài)的顯存占用,中間激活的顯存占用后面會(huì)詳細(xì)介紹。訓(xùn)練大模型時(shí)通常會(huì)采用AdamW優(yōu)化器,并用混合精度訓(xùn)練來(lái)加速訓(xùn)練,基于這個(gè)前提分析顯存占用。

在一次訓(xùn)練迭代中,每個(gè)可訓(xùn)練模型參數(shù)都會(huì)對(duì)應(yīng)1個(gè)梯度,并對(duì)應(yīng)2個(gè)優(yōu)化器狀態(tài)(Adam優(yōu)化器梯度的一階動(dòng)量和二階動(dòng)量)。設(shè)模型參數(shù)量為881c3236-1d86-11ee-962d-dac502259ad0.png?,那么梯度的元素?cái)?shù)量為881c3236-1d86-11ee-962d-dac502259ad0.png?,AdamW優(yōu)化器的元素?cái)?shù)量為8841de5a-1d86-11ee-962d-dac502259ad0.png。float16數(shù)據(jù)類(lèi)型的元素占2個(gè)bytes,float32數(shù)據(jù)類(lèi)型的元素占4個(gè)bytes。在混合精度訓(xùn)練中,會(huì)使用float16的模型參數(shù)進(jìn)行前向傳遞和后向傳遞,計(jì)算得到float16的梯度;在優(yōu)化器更新模型參數(shù)時(shí),會(huì)使用float32的優(yōu)化器狀態(tài)、float32的梯度、float32的模型參數(shù)來(lái)更新模型參數(shù)。因此,對(duì)于每個(gè)可訓(xùn)練模型參數(shù),占用了88581f9e-1d86-11ee-962d-dac502259ad0.png。使用AdamW優(yōu)化器和混合精度訓(xùn)練來(lái)訓(xùn)練參數(shù)量為?881c3236-1d86-11ee-962d-dac502259ad0.png的大模型,?模型參數(shù)、梯度和優(yōu)化器狀態(tài)占用的顯存大小為887f5154-1d86-11ee-962d-dac502259ad0.png?。

8892c3c4-1d86-11ee-962d-dac502259ad0.jpg

2.2 推理過(guò)程中的顯存占用分析

在神經(jīng)網(wǎng)絡(luò)的推理階段,沒(méi)有優(yōu)化器狀態(tài)和梯度,也不需要保存中間激活。少了梯度、優(yōu)化器狀態(tài)、中間激活,模型推理階段占用的顯存要遠(yuǎn)小于訓(xùn)練階段。模型推理階段,占用顯存的大頭主要是模型參數(shù),如果使用float16來(lái)進(jìn)行推理,推理階段模型參數(shù)占用的顯存大概是88b124fe-1d86-11ee-962d-dac502259ad0.png?。如果使用KVcache來(lái)加速推理過(guò)程,?KV cache也需要占用顯存,KVcache占用的顯存下文會(huì)詳細(xì)介紹。此外,輸入數(shù)據(jù)也需要放到GPU上,還有一些中間結(jié)果(推理過(guò)程中的中間結(jié)果用完會(huì)盡快釋放掉),不過(guò)這部分占用的顯存是很小的,可以忽略。

3. 計(jì)算量FLOPs估計(jì)

FLOPs,floating point operations,表示浮點(diǎn)數(shù)運(yùn)算次數(shù),衡量了計(jì)算量的大小。

如何計(jì)算矩陣乘法的FLOPs呢?

對(duì)于88c2f724-1d86-11ee-962d-dac502259ad0.png?,計(jì)算?88d9729c-1d86-11ee-962d-dac502259ad0.png?需要進(jìn)行?88f02cc6-1d86-11ee-962d-dac502259ad0.png?次乘法運(yùn)算和?88f02cc6-1d86-11ee-962d-dac502259ad0.png?次加法運(yùn)算,共計(jì)?8913769a-1d86-11ee-962d-dac502259ad0.png?次浮點(diǎn)數(shù)運(yùn)算,需要?8913769a-1d86-11ee-962d-dac502259ad0.png?的FLOPs。對(duì)于?893c3c60-1d86-11ee-962d-dac502259ad0.png?,計(jì)算?88d9729c-1d86-11ee-962d-dac502259ad0.png?需要的浮點(diǎn)數(shù)運(yùn)算次數(shù)為?8962e2fc-1d86-11ee-962d-dac502259ad0.png?。

在一次訓(xùn)練迭代中,假設(shè)輸入數(shù)據(jù)的形狀為897c4b84-1d86-11ee-962d-dac502259ad0.png?。我們?先分析self-attention塊的計(jì)算,計(jì)算公式如下:

89962dd8-1d86-11ee-962d-dac502259ad0.png89a87cb8-1d86-11ee-962d-dac502259ad0.png

1. 計(jì)算89bccee8-1d86-11ee-962d-dac502259ad0.png?:矩陣乘法的輸入和輸出形狀為89d2b0f0-1d86-11ee-962d-dac502259ad0.png。計(jì)算量為89e69084-1d86-11ee-962d-dac502259ad0.png

2.89fdc10a-1d86-11ee-962d-dac502259ad0.png?矩陣乘法的輸入和輸出形狀為

8a0cb43a-1d86-11ee-962d-dac502259ad0.png。計(jì)算量為?8a280b4a-1d86-11ee-962d-dac502259ad0.png?。

3. 計(jì)算在85af3034-1d86-11ee-962d-dac502259ad0.png?上的加權(quán)?8a4c4500-1d86-11ee-962d-dac502259ad0.png?,矩陣乘法的輸入和輸出形狀為8a619a9a-1d86-11ee-962d-dac502259ad0.png。計(jì)算量為?8a280b4a-1d86-11ee-962d-dac502259ad0.png?。

4. attention后的線性映射,矩陣乘法的輸入和輸出形狀為89d2b0f0-1d86-11ee-962d-dac502259ad0.png。計(jì)算量為?8a93189a-1d86-11ee-962d-dac502259ad0.png?。

接下來(lái)分析MLP塊的計(jì)算,計(jì)算公式如下

8aaa25ee-1d86-11ee-962d-dac502259ad0.png

1. 第一個(gè)線性層,矩陣乘法的輸入和輸出形狀為8ac3882c-1d86-11ee-962d-dac502259ad0.png。計(jì)算量為?8adb754a-1d86-11ee-962d-dac502259ad0.png?。

2. 第二個(gè)線性層,矩陣乘法的輸入和輸出形狀為8af27f60-1d86-11ee-962d-dac502259ad0.png。計(jì)算量為?8adb754a-1d86-11ee-962d-dac502259ad0.png?。

將上述計(jì)算量相加,得到每個(gè)transformer層的計(jì)算量大約為8b1dfc44-1d86-11ee-962d-dac502259ad0.png?。

此外,另一個(gè)計(jì)算量的大頭是logits的計(jì)算,將隱藏向量映射為詞表大小。矩陣乘法的輸入和輸出形狀為8b35ff06-1d86-11ee-962d-dac502259ad0.png,計(jì)算量為?8b4c2218-1d86-11ee-962d-dac502259ad0.png?。

因此,對(duì)于一個(gè)8568a3ee-1d86-11ee-962d-dac502259ad0.png?層的transformer模型,輸入數(shù)據(jù)形狀為897c4b84-1d86-11ee-962d-dac502259ad0.png?的情況下,一次訓(xùn)練迭代的計(jì)算量為8b7f76ea-1d86-11ee-962d-dac502259ad0.png

3.1 計(jì)算量與參數(shù)量的關(guān)聯(lián)

當(dāng)隱藏維度85808590-1d86-11ee-962d-dac502259ad0.png?比較大,且遠(yuǎn)大于序列長(zhǎng)度85cf59e0-1d86-11ee-962d-dac502259ad0.png?時(shí),我們可以忽略一次項(xiàng),計(jì)算量可以近似為8bb25fce-1d86-11ee-962d-dac502259ad0.png?。前面提到當(dāng)模型參數(shù)量為8803fd6a-1d86-11ee-962d-dac502259ad0.png?,輸入的tokens數(shù)為8bd8c614-1d86-11ee-962d-dac502259ad0.png?,存在等式8bef6874-1d86-11ee-962d-dac502259ad0.png。我們可以近似認(rèn)為:?在一次前向傳遞中,對(duì)于每個(gè)token,每個(gè)模型參數(shù),需要進(jìn)行2次浮點(diǎn)數(shù)運(yùn)算,即一次乘法法運(yùn)算和一次加法運(yùn)算。

一次訓(xùn)練迭代包含了前向傳遞和后向傳遞,后向傳遞的計(jì)算量是前向傳遞的2倍。因此,前向傳遞 + 后向傳遞的系數(shù)8c064c1a-1d86-11ee-962d-dac502259ad0.png。一次訓(xùn)練迭代中,對(duì)于每個(gè)token,每個(gè)模型參數(shù),需要進(jìn)行8c185e50-1d86-11ee-962d-dac502259ad0.png?次浮點(diǎn)數(shù)運(yùn)算。

接下來(lái),我們可以估計(jì)訓(xùn)練GPT3-175B所需要的計(jì)算量。對(duì)于GPT3,每個(gè)token,每個(gè)參數(shù)進(jìn)行了6次浮點(diǎn)數(shù)運(yùn)算,再乘以參數(shù)量和總tokens數(shù)就得到了總的計(jì)算量。GPT3的模型參數(shù)量為8c29c1b8-1d86-11ee-962d-dac502259ad0.png?,訓(xùn)練數(shù)據(jù)量為?8c3c5f3a-1d86-11ee-962d-dac502259ad0.png?tokens。

8c4efc26-1d86-11ee-962d-dac502259ad0.png

8c661cb2-1d86-11ee-962d-dac502259ad0.jpg

3.2 訓(xùn)練時(shí)間估計(jì)

模型參數(shù)量和訓(xùn)練總tokens數(shù)決定了訓(xùn)練transformer模型需要的計(jì)算量。給定硬件GPU類(lèi)型的情況下,可以估計(jì)所需要的訓(xùn)練時(shí)間。給定計(jì)算量,訓(xùn)練時(shí)間(也就是GPU算完這么多flops的計(jì)算時(shí)間)不僅跟GPU類(lèi)型有關(guān),還與GPU利用率有關(guān)。計(jì)算端到端訓(xùn)練的GPU利用率時(shí),不僅要考慮前向傳遞和后向傳遞的計(jì)算時(shí)間,還要**考慮CPU加載數(shù)據(jù)、優(yōu)化器更新、多卡通信和記錄日志的時(shí)間。一般來(lái)講,GPU利用率一般在8c8a6fd6-1d86-11ee-962d-dac502259ad0.png之間

上文講到一次前向傳遞中,對(duì)于每個(gè)token,每個(gè)模型參數(shù),進(jìn)行2次浮點(diǎn)數(shù)計(jì)算。使用激活重計(jì)算技術(shù)來(lái)減少中間激活顯存(下文會(huì)詳細(xì)介紹)需要進(jìn)行一次額外的前向傳遞,因此前向傳遞+ 后向傳遞 + 激活重計(jì)算的系數(shù)=1+2+1=4。使用激活重計(jì)算的一次訓(xùn)練迭代中,對(duì)于每個(gè)token,每個(gè)模型參數(shù),需要進(jìn)行8c9e60b8-1d86-11ee-962d-dac502259ad0.png?次浮點(diǎn)數(shù)運(yùn)算。在給定訓(xùn)練tokens數(shù)、硬件環(huán)境配置的情況下,訓(xùn)練transformer模型的計(jì)算時(shí)間為

8cb12194-1d86-11ee-962d-dac502259ad0.png

8cc7905a-1d86-11ee-962d-dac502259ad0.jpg

以GPT3-175B為例,在1024張40GB顯存的A100上,在300Btokens的數(shù)據(jù)上訓(xùn)練175B參數(shù)量的GPT3。40GB顯存A100的峰值性能為312TFLOPS,設(shè)GPU利用率為0.45,則所需要的訓(xùn)練時(shí)間為34天,這與[7]中的訓(xùn)練時(shí)間是對(duì)得上的

8cee6784-1d86-11ee-962d-dac502259ad0.png

以LLaMA-65B為例,在2048張80GB顯存的A100上,在1.4TBtokens的數(shù)據(jù)上訓(xùn)練了65B參數(shù)量的模型。80GB顯存A100的峰值性能為624TFLOPS,設(shè)GPU利用率為0.3,則所需要的訓(xùn)練時(shí)間為21天,這與[4]中的實(shí)際訓(xùn)練時(shí)間是對(duì)得上的

8d05f390-1d86-11ee-962d-dac502259ad0.png

4. 中間激活值分析

除了模型參數(shù)、梯度、優(yōu)化器狀態(tài)外,占用顯存的大頭就是前向傳遞過(guò)程中計(jì)算得到的中間激活值了,需要保存中間激活以便在后向傳遞計(jì)算梯度時(shí)使用。這里的激活(activations)指的是:前向傳遞過(guò)程中計(jì)算得到的,并在后向傳遞過(guò)程中需要用到的所有張量。這里的激活不包含模型參數(shù)和優(yōu)化器狀態(tài),但包含了dropout操作需要用到的mask矩陣。

在分析中間激活的顯存占用時(shí),只考慮激活占用顯存的大頭,忽略掉一些小的buffers。比如,對(duì)于layernormalization,計(jì)算梯度時(shí)需要用到層的輸入、輸入的均值8d33dd50-1d86-11ee-962d-dac502259ad0.png?和方差8d45df46-1d86-11ee-962d-dac502259ad0.png?。輸入包含了8d5b7ad6-1d86-11ee-962d-dac502259ad0.png?個(gè)元素,而輸入的均值和方差分別包含了8bd8c614-1d86-11ee-962d-dac502259ad0.png?個(gè)元素。由于85808590-1d86-11ee-962d-dac502259ad0.png?通常是比較大的(千數(shù)量級(jí)),有?8d91f19c-1d86-11ee-962d-dac502259ad0.png?。因此,對(duì)于layernormalization,中間激活近似估計(jì)為?8d5b7ad6-1d86-11ee-962d-dac502259ad0.png?,而不是8db7052c-1d86-11ee-962d-dac502259ad0.png?。

大模型在訓(xùn)練過(guò)程中通常采用混合精度訓(xùn)練,中間激活值一般是float16或者bfloat16數(shù)據(jù)類(lèi)型的。在分析中間激活的顯存占用時(shí),假設(shè)中間激活值是以float16或bfloat16數(shù)據(jù)格式來(lái)保存的,每個(gè)元素占了2個(gè)bytes。唯一例外的是,dropout操作的mask矩陣,每個(gè)元素只占1個(gè)bytes。在下面的分析中,單位是bytes,而不是元素個(gè)數(shù)。

每個(gè)transformer層包含了一個(gè)self-attention塊和MLP塊,并分別對(duì)應(yīng)了一個(gè)layer normalization連接。

先分析self-attention塊的中間激活。self-attention塊的計(jì)算公式如下:

89962dd8-1d86-11ee-962d-dac502259ad0.png

89a87cb8-1d86-11ee-962d-dac502259ad0.png

1. 對(duì)于89bccee8-1d86-11ee-962d-dac502259ad0.png?,需要保存它們共同的輸入8df04ff8-1d86-11ee-962d-dac502259ad0.png?,這就是中間激活。輸入8df04ff8-1d86-11ee-962d-dac502259ad0.png?的形狀為8e13c38e-1d86-11ee-962d-dac502259ad0.png?,元素個(gè)數(shù)為8d5b7ad6-1d86-11ee-962d-dac502259ad0.png?,占用顯存大小為8e323ab2-1d86-11ee-962d-dac502259ad0.png?。

2. 對(duì)于89fdc10a-1d86-11ee-962d-dac502259ad0.png?矩陣乘法,需要保存中間激活8e548176-1d86-11ee-962d-dac502259ad0.png?,兩個(gè)張量的形狀都是8e13c38e-1d86-11ee-962d-dac502259ad0.png?,占用顯存大小合計(jì)為8e754262-1d86-11ee-962d-dac502259ad0.png?。

3. 對(duì)于8e8d8cc8-1d86-11ee-962d-dac502259ad0.png函數(shù),需要保存函數(shù)的輸入?89fdc10a-1d86-11ee-962d-dac502259ad0.png?,占用顯存大小為8eaf0420-1d86-11ee-962d-dac502259ad0.png?,這里的8597e2c6-1d86-11ee-962d-dac502259ad0.png?表示注意力頭數(shù)。

8ed2bb04-1d86-11ee-962d-dac502259ad0.png

8ee428ee-1d86-11ee-962d-dac502259ad0.png?的形狀為:?8ef662e8-1d86-11ee-962d-dac502259ad0.png

8f0ab716-1d86-11ee-962d-dac502259ad0.png?的形狀為:8f2396a0-1d86-11ee-962d-dac502259ad0.png

89fdc10a-1d86-11ee-962d-dac502259ad0.png?的形狀為:8f445c96-1d86-11ee-962d-dac502259ad0.png,元素個(gè)數(shù)為?8f5a3d36-1d86-11ee-962d-dac502259ad0.png?,占用顯存大小為8eaf0420-1d86-11ee-962d-dac502259ad0.png?。

4. 計(jì)算完8e8d8cc8-1d86-11ee-962d-dac502259ad0.png函數(shù)后,會(huì)進(jìn)行dropout操作。需要保存一個(gè)mask矩陣,mask矩陣的形狀與89fdc10a-1d86-11ee-962d-dac502259ad0.png?相同,占用顯存大小為8f5a3d36-1d86-11ee-962d-dac502259ad0.png?。

5. 計(jì)算在85af3034-1d86-11ee-962d-dac502259ad0.png?上的attention,即?8a4c4500-1d86-11ee-962d-dac502259ad0.png?,需要保存8fed2d6c-1d86-11ee-962d-dac502259ad0.png?,大小為8eaf0420-1d86-11ee-962d-dac502259ad0.png?;以及85af3034-1d86-11ee-962d-dac502259ad0.png?,大小為90275a96-1d86-11ee-962d-dac502259ad0.png?。二者占用顯存大小合計(jì)為90482d70-1d86-11ee-962d-dac502259ad0.png?。

6. 計(jì)算輸出映射以及一個(gè)dropout操作。輸入映射需要保存其輸入,大小為90275a96-1d86-11ee-962d-dac502259ad0.png?;dropout需要保存mask矩陣,大小為8d5b7ad6-1d86-11ee-962d-dac502259ad0.png?。二者占用顯存大小合計(jì)為90a0671a-1d86-11ee-962d-dac502259ad0.png?。

因此,將上述中間激活相加得到,self-attention塊的中間激活占用顯存大小為90b20fce-1d86-11ee-962d-dac502259ad0.png?。

接下來(lái)看MLP塊的中間激活。MLP塊的計(jì)算公式如下

8aaa25ee-1d86-11ee-962d-dac502259ad0.png

1. 第一個(gè)線性層需要保存其輸入,占用顯存大小為90275a96-1d86-11ee-962d-dac502259ad0.png?。

2. 激活函數(shù)需要保存其輸入,占用顯存大小為90dd5c56-1d86-11ee-962d-dac502259ad0.png?。

3. 第二個(gè)線性層需要保存其輸入,占用顯存大小為90dd5c56-1d86-11ee-962d-dac502259ad0.png?。

4. 最后有一個(gè)dropout操作,需要保存mask矩陣,占用顯存大小為8d5b7ad6-1d86-11ee-962d-dac502259ad0.png?。

對(duì)于MLP塊,需要保存的中間激活值為910fbe12-1d86-11ee-962d-dac502259ad0.png?。

另外,self-attention塊和MLP塊分別對(duì)應(yīng)了一個(gè)layer normalization。每個(gè)layer norm需要保存其輸入,大小為90275a96-1d86-11ee-962d-dac502259ad0.png?。2個(gè)layer norm需要保存的中間激活為912fd2e2-1d86-11ee-962d-dac502259ad0.png?。

綜上,每個(gè)transformer層需要保存的中間激活占用顯存大小為91429bf2-1d86-11ee-962d-dac502259ad0.png?。對(duì)于8568a3ee-1d86-11ee-962d-dac502259ad0.png層transformer模型,還有embedding層、最后的輸出層。embedding層不需要中間激活。總的而言,當(dāng)隱藏維度85808590-1d86-11ee-962d-dac502259ad0.png?比較大,層數(shù)8568a3ee-1d86-11ee-962d-dac502259ad0.png?較深時(shí),這部分的中間激活是很少的,可以忽略。因此,對(duì)于8568a3ee-1d86-11ee-962d-dac502259ad0.png?層transformer模型,中間激活占用的顯存大小可以近似為918e2cd4-1d86-11ee-962d-dac502259ad0.png

4.1 對(duì)比中間激活與模型參數(shù)的顯存大小

在一次訓(xùn)練迭代中,模型參數(shù)(或梯度)占用的顯存大小只與模型參數(shù)量和參數(shù)數(shù)據(jù)類(lèi)型有關(guān),與輸入數(shù)據(jù)的大小是沒(méi)有關(guān)系的。優(yōu)化器狀態(tài)占用的顯存大小也是一樣,與優(yōu)化器類(lèi)型有關(guān),與模型參數(shù)量有關(guān),但與輸入數(shù)據(jù)的大小無(wú)關(guān)。而中間激活值與輸入數(shù)據(jù)的大小(批次大小85c04e64-1d86-11ee-962d-dac502259ad0.png?和序列長(zhǎng)度85cf59e0-1d86-11ee-962d-dac502259ad0.png?)是成正相關(guān)的,隨著批次大小85c04e64-1d86-11ee-962d-dac502259ad0.png?和序列長(zhǎng)度85cf59e0-1d86-11ee-962d-dac502259ad0.png的增大,中間激活占用的顯存會(huì)同步增大。當(dāng)我們訓(xùn)練神經(jīng)網(wǎng)絡(luò)遇到顯存不足OOM(Out OfMemory)問(wèn)題時(shí),通常會(huì)嘗試減小批次大小來(lái)避免顯存不足的問(wèn)題,這種方式減少的其實(shí)是中間激活占用的顯存,而不是模型參數(shù)、梯度和優(yōu)化器的顯存。

以GPT3-175B為例,我們來(lái)直觀地對(duì)比下模型參數(shù)與中間激活的顯存大小。GPT3的模型配置如下。我們假設(shè)采用混合精度訓(xùn)練,模型參數(shù)和中間激活都采用float16數(shù)據(jù)類(lèi)型,每個(gè)元素占2個(gè)bytes。

模型名 參數(shù)量 層數(shù) 隱藏維度 注意力頭數(shù)
GPT3 175B 96 12288 96

GPT3的模型參數(shù)量為175B,占用的顯存大小為91e94efc-1d86-11ee-962d-dac502259ad0.png。GPT3模型需要占用350GB的顯存。

GPT3的序列長(zhǎng)度85cf59e0-1d86-11ee-962d-dac502259ad0.png?為920b2784-1d86-11ee-962d-dac502259ad0.png?。對(duì)比不同的批次大小85c04e64-1d86-11ee-962d-dac502259ad0.png?占用的中間激活:

當(dāng)922b29ee-1d86-11ee-962d-dac502259ad0.png?時(shí),中間激活占用顯存為92448182-1d86-11ee-962d-dac502259ad0.png,大約是模型參數(shù)顯存的0.79倍。

當(dāng)925af32c-1d86-11ee-962d-dac502259ad0.png?時(shí),中間激活占用顯存為9271fd88-1d86-11ee-962d-dac502259ad0.png,大約是模型參數(shù)顯存的50倍。

當(dāng)928c62ea-1d86-11ee-962d-dac502259ad0.png?時(shí),中間激活占用顯存為

929f7ba0-1d86-11ee-962d-dac502259ad0.png,大約是模型參數(shù)顯存的101倍。

可以看到隨著批次大小85c04e64-1d86-11ee-962d-dac502259ad0.png的增大,中間激活占用的顯存遠(yuǎn)遠(yuǎn)超過(guò)了模型參數(shù)顯存。通常會(huì)采用?激活重計(jì)算技術(shù)來(lái)減少中間激活,理論上可以將中間激活顯存從92c0fb7c-1d86-11ee-962d-dac502259ad0.png?減少到92d78c98-1d86-11ee-962d-dac502259ad0.png,代價(jià)是增加了一次額外前向計(jì)算的時(shí)間,本質(zhì)上是“時(shí)間換空間”。

5. KV cache

在推斷階段,transformer模型加速推斷的一個(gè)常用策略就是使用 KV cache。一個(gè)典型的大模型生成式推斷包含了兩個(gè)階段:

1.預(yù)填充階段:輸入一個(gè)prompt序列,為每個(gè)transformer層生成 key cache和value cache(KV cache)。

2.解碼階段:使用并更新KV cache,一個(gè)接一個(gè)地生成詞,當(dāng)前生成的詞依賴(lài)于之前已經(jīng)生成的詞。

92ed9d6c-1d86-11ee-962d-dac502259ad0.png?個(gè)transformer層的權(quán)重矩陣為9304c852-1d86-11ee-962d-dac502259ad0.png。其中,self-attention塊的4個(gè)權(quán)重矩陣?9319b14a-1d86-11ee-962d-dac502259ad0.png,并且MLP塊的2個(gè)權(quán)重矩陣?93302484-1d86-11ee-962d-dac502259ad0.png

預(yù)填充階段

假設(shè)第92ed9d6c-1d86-11ee-962d-dac502259ad0.png?個(gè)transformer層的輸入為93515a78-1d86-11ee-962d-dac502259ad0.png?,self-attention塊的key、value、query和output表示為93684648-1d86-11ee-962d-dac502259ad0.png,其中,?93822d4c-1d86-11ee-962d-dac502259ad0.png

key cache和value cache的計(jì)算過(guò)程為:

9398dd94-1d86-11ee-962d-dac502259ad0.png93af854e-1d86-11ee-962d-dac502259ad0.png

92ed9d6c-1d86-11ee-962d-dac502259ad0.png?個(gè)transformer層剩余的計(jì)算過(guò)程為:

93d233fa-1d86-11ee-962d-dac502259ad0.png93e27e72-1d86-11ee-962d-dac502259ad0.png93f8cbaa-1d86-11ee-962d-dac502259ad0.png

解碼階段

給定當(dāng)前生成詞在第92ed9d6c-1d86-11ee-962d-dac502259ad0.png?個(gè)transformer層的向量表示為9418a024-1d86-11ee-962d-dac502259ad0.png。推斷計(jì)算分兩部分:更新KV cache和計(jì)算第?92ed9d6c-1d86-11ee-962d-dac502259ad0.png個(gè)transformer層的輸出。

更新key cache和value cache的計(jì)算過(guò)程如下:

943ab0a6-1d86-11ee-962d-dac502259ad0.png

94514e74-1d86-11ee-962d-dac502259ad0.png

92ed9d6c-1d86-11ee-962d-dac502259ad0.png?個(gè)transformer層剩余的計(jì)算過(guò)程為:

946e96e6-1d86-11ee-962d-dac502259ad0.png

9480832e-1d86-11ee-962d-dac502259ad0.png9492f5e0-1d86-11ee-962d-dac502259ad0.png

5.1 KV cache的顯存占用分析

假設(shè)輸入序列的長(zhǎng)度為85cf59e0-1d86-11ee-962d-dac502259ad0.png?,輸出序列的長(zhǎng)度為88f02cc6-1d86-11ee-962d-dac502259ad0.png?,以float16來(lái)保存KV cache,那么?KVcache的峰值顯存占用大小為94c74c82-1d86-11ee-962d-dac502259ad0.png。這里第一個(gè)2表示K/V cache,第二個(gè)2表示float16占2個(gè)bytes。

以GPT3為例,對(duì)比KV cache與模型參數(shù)占用顯存的大小。GPT3模型占用顯存大小為350GB。假設(shè)批次大小925af32c-1d86-11ee-962d-dac502259ad0.png?,輸入序列長(zhǎng)度94ee1858-1d86-11ee-962d-dac502259ad0.png?,輸出序列長(zhǎng)度95050914-1d86-11ee-962d-dac502259ad0.png?,則KV cache占用顯存為951b99e0-1d86-11ee-962d-dac502259ad0.png,大約是模型參數(shù)顯存的0.5倍。

6. 總結(jié)

本文首先介紹了如何計(jì)算transformer模型的參數(shù)量,基于參數(shù)量可以進(jìn)一步估計(jì)模型參數(shù)、梯度和優(yōu)化器狀態(tài)占用的顯存大小。接著,本文估計(jì)了訓(xùn)練迭代中,在給定訓(xùn)練tokens數(shù)的情況下transformer模型的計(jì)算量,給予計(jì)算量和顯卡性能可以進(jìn)一步估計(jì)訓(xùn)練迭代的計(jì)算耗時(shí)。然后,本文分析了transformer模型前向計(jì)算過(guò)程中產(chǎn)生的中間激活值的顯存大小,中間激活的顯存大小與輸入數(shù)據(jù)大小正相關(guān),甚至?xí)h(yuǎn)超過(guò)模型參數(shù)占用的顯存。最后,本文介紹了transformer模型推理過(guò)程常用的加速策略:使用KVcache。總的來(lái)說(shuō),分析transformer模型的參數(shù)量、計(jì)算量、中間激活和KV cache,有助于理解大模型訓(xùn)練和推斷過(guò)程中的顯存效率和計(jì)算效率。

聲明:本文內(nèi)容及配圖由入駐作者撰寫(xiě)或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場(chǎng)。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問(wèn)題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • 模型
    +關(guān)注

    關(guān)注

    1

    文章

    3243

    瀏覽量

    48840
  • Transformer
    +關(guān)注

    關(guān)注

    0

    文章

    143

    瀏覽量

    6006
  • ChatGPT
    +關(guān)注

    關(guān)注

    29

    文章

    1561

    瀏覽量

    7671

原文標(biāo)題:分析transformer模型的參數(shù)量、計(jì)算量、中間激活、KV cache

文章出處:【微信號(hào):GiantPandaCV,微信公眾號(hào):GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。

收藏 人收藏

    評(píng)論

    相關(guān)推薦

    基于卷積的基礎(chǔ)模型InternImage網(wǎng)絡(luò)技術(shù)分析

    近年來(lái)大規(guī)模視覺(jué) Transformer 的蓬勃發(fā)展推動(dòng)了計(jì)算機(jī)視覺(jué)領(lǐng)域的性能邊界。視覺(jué) Transformer 模型通過(guò)擴(kuò)大模型
    發(fā)表于 11-18 10:49 ?702次閱讀
    基于卷積的基礎(chǔ)<b class='flag-5'>模型</b>InternImage網(wǎng)絡(luò)技術(shù)分析

    基于Transformer做大模型預(yù)訓(xùn)練基本的并行范式

    并行(TP)。 它的基本思想就是把模型參數(shù)縱向切開(kāi),放到不同的GPU上進(jìn)行獨(dú)立計(jì)算,然后再做聚合。 在寫(xiě)這篇文章的過(guò)程中,我發(fā)現(xiàn)要理解Megatron的大框架不難,但是涉及到細(xì)節(jié),特別是混合并行部分,要考慮的就很多了。 所以我
    的頭像 發(fā)表于 05-31 14:38 ?2692次閱讀
    基于<b class='flag-5'>Transformer</b>做大<b class='flag-5'>模型</b>預(yù)訓(xùn)練基本的并行范式

    大語(yǔ)言模型背后的Transformer,與CNN和RNN有何不同

    ? 電子發(fā)燒友網(wǎng)報(bào)道(文/李彎彎)近年來(lái),隨著大語(yǔ)言模型的不斷出圈,Transformer這一概念也走進(jìn)了大眾視野。Transformer是一種非常流行的深度學(xué)習(xí)模型,最早于2017年
    的頭像 發(fā)表于 12-25 08:36 ?4091次閱讀
    大語(yǔ)言<b class='flag-5'>模型</b>背后的<b class='flag-5'>Transformer</b>,與CNN和RNN有何不同

    你了解在單GPU上就可以運(yùn)行的Transformer模型

    上一步也跑不了,因?yàn)樗鼈兊膬?nèi)存需求太大了。例如,完整的GPT-2模型大約包含1.5B參數(shù)。最大配置的參數(shù)數(shù)量超過(guò)每層0.5B,而層數(shù)有64 層。圖2:標(biāo)準(zhǔn)Transformer
    發(fā)表于 11-02 15:19

    Google科學(xué)家設(shè)計(jì)簡(jiǎn)化稀疏架構(gòu)Switch Transformer,語(yǔ)言模型參數(shù)量可擴(kuò)展至 1.6 萬(wàn)億

    剛剛,Google Brain 高級(jí)研究科學(xué)家 Barret Zoph 發(fā)帖表示,他們?cè)O(shè)計(jì)了一個(gè)名叫「Switch Transformer」的簡(jiǎn)化稀疏架構(gòu),可以將語(yǔ)言模型參數(shù)量擴(kuò)展至 1.6 萬(wàn)億
    的頭像 發(fā)表于 01-13 16:50 ?2994次閱讀

    一個(gè)GPU訓(xùn)練一個(gè)130億參數(shù)模型

    現(xiàn)在的模型動(dòng)輒數(shù)百、數(shù)千億參數(shù),普通人訓(xùn)不動(dòng)怎么辦? 前不久,谷歌發(fā)布了參數(shù)量為 1.6 萬(wàn)億的語(yǔ)言模型Swith Transformer
    的頭像 發(fā)表于 02-11 09:04 ?2437次閱讀
    一個(gè)GPU訓(xùn)練一個(gè)130億<b class='flag-5'>參數(shù)</b>的<b class='flag-5'>模型</b>

    超大Transformer語(yǔ)言模型的分布式訓(xùn)練框架

    模型的預(yù)訓(xùn)練計(jì)算。 大模型是大勢(shì)所趨 近年來(lái),NLP 模型的發(fā)展十分迅速,模型的大小每年以1-2個(gè)數(shù)量
    的頭像 發(fā)表于 10-11 16:46 ?2690次閱讀
    超大<b class='flag-5'>Transformer</b>語(yǔ)言<b class='flag-5'>模型</b>的分布式訓(xùn)練框架

    Microsoft使用NVIDIA Triton加速AI Transformer模型應(yīng)用

    Microsoft 的目標(biāo)是,通過(guò)結(jié)合使用 Azure 與 NVIDIA GPU 和 Triton 推理軟件,率先將一系列強(qiáng)大的 AI Transformer 模型投入生產(chǎn)用途。
    的頭像 發(fā)表于 04-02 13:04 ?1767次閱讀

    在X3派上玩轉(zhuǎn)一億參數(shù)量超大Transformer,DIY專(zhuān)屬你的離線語(yǔ)音識(shí)別

    Transformer模型在自然語(yǔ)言領(lǐng)域被提出后,目前已經(jīng)擴(kuò)展到了計(jì)算機(jī)視覺(jué)、語(yǔ)音等諸多領(lǐng)域。然而,雖然Transformer模型在語(yǔ)音識(shí)別
    的頭像 發(fā)表于 02-21 16:08 ?829次閱讀
    在X3派上玩轉(zhuǎn)一億<b class='flag-5'>參數(shù)量</b>超大<b class='flag-5'>Transformer</b>,DIY專(zhuān)屬你的離線語(yǔ)音識(shí)別

    基于Transformer的大型語(yǔ)言模型(LLM)的內(nèi)部機(jī)制

    本文旨在更好地理解基于 Transformer 的大型語(yǔ)言模型(LLM)的內(nèi)部機(jī)制,以提高它們的可靠性和可解釋性。 隨著大型語(yǔ)言模型(LLM)在使用和部署方面的不斷增加,打開(kāi)黑箱并了解它們的內(nèi)部
    的頭像 發(fā)表于 06-25 15:08 ?1481次閱讀
    基于<b class='flag-5'>Transformer</b>的大型語(yǔ)言<b class='flag-5'>模型</b>(LLM)的內(nèi)部機(jī)制

    transformer模型詳解:Transformer 模型的壓縮方法

    ?動(dòng)機(jī)&背景 Transformer 模型在各種自然語(yǔ)言任務(wù)中取得了顯著的成果,但內(nèi)存和計(jì)算資源的瓶頸阻礙了其實(shí)用化部署。低秩近似和結(jié)構(gòu)化剪枝是緩解這一瓶頸的主流方法。然而,作者通過(guò)分析發(fā)現(xiàn),結(jié)構(gòu)化
    的頭像 發(fā)表于 07-17 10:50 ?2124次閱讀
    <b class='flag-5'>transformer</b><b class='flag-5'>模型</b>詳解:<b class='flag-5'>Transformer</b> <b class='flag-5'>模型</b>的壓縮方法

    盤(pán)古大模型參數(shù)量有多少

    盤(pán)古大模型參數(shù)量有多少 盤(pán)古大模型(PanGu-α)是由中國(guó)科學(xué)院計(jì)算技術(shù)研究所提供的一種語(yǔ)言生成預(yù)訓(xùn)練模型。該
    的頭像 發(fā)表于 08-17 11:28 ?2973次閱讀

    基于Transformer模型的壓縮方法

    基于Transformer架構(gòu)的大型模型在人工智能領(lǐng)域中發(fā)揮著日益重要的作用,特別是在自然語(yǔ)言處理(NLP)和計(jì)算機(jī)視覺(jué)(CV)領(lǐng)域。
    的頭像 發(fā)表于 02-22 16:27 ?656次閱讀
    基于<b class='flag-5'>Transformer</b><b class='flag-5'>模型</b>的壓縮方法

    使用PyTorch搭建Transformer模型

    Transformer模型自其問(wèn)世以來(lái),在自然語(yǔ)言處理(NLP)領(lǐng)域取得了巨大的成功,并成為了許多先進(jìn)模型(如BERT、GPT等)的基礎(chǔ)。本文將深入解讀如何使用PyTorch框架搭建Trans
    的頭像 發(fā)表于 07-02 11:41 ?1631次閱讀

    Transformer語(yǔ)言模型簡(jiǎn)介與實(shí)現(xiàn)過(guò)程

    在自然語(yǔ)言處理(NLP)領(lǐng)域,Transformer模型以其卓越的性能和廣泛的應(yīng)用前景,成為了近年來(lái)最引人注目的技術(shù)之一。Transformer模型由谷歌在2017年提出,并首次應(yīng)用于
    的頭像 發(fā)表于 07-10 11:48 ?1713次閱讀
    主站蜘蛛池模板: 国模娜娜扒开嫩木耳| 日本黄色大片在线观看| 欧美一级日韩在线观看| www.99色.com| 亚洲啪啪网站| 小说老卫陈红张敏陈法蓉| 久久久五月| 男女同床爽爽视频免费| 美女三级黄| 尻老逼| 欧美黑粗| 伊人网综合在线| 黄色18网站| 大黄网站色多多| 唯美久草| 12306影院午夜入口| 在线免费看黄视频| 欧美一区a| 日本一本在线视频| lsj老司机精品视频在线观看| 免费视频淫片aa毛片| 欧美adc影院| 一卡二卡卡四卡无人区中文 | 色天使色婷婷丁香久久综合| video另类蛇交| 曰本a| 蕾丝视频在线播放| 91精品福利久久久| 色香蕉网站| 失禁h啪肉尿出来高h受| www.色涩| 777国产精品永久免费观看| 性免费网站| 好看的一级毛片| 成人啪啪免费视频| 免费激情网站| 91成人免费福利网站在线| 人人射人人澡| 成人精品福利| 日本精品高清一区二区2021| 国产成人精品日本亚洲网站|