DL訓練中電腦記憶體問題

2021-08-15 13:42:32 字數 2315 閱讀 1046

本文主要譯介自graphcore在2023年1月的這篇部落格: why is so much memory needed for deep neural networks。介紹了深度學習中記憶體的開銷,以及降低記憶體需求的幾種解決方案。

考慮乙個單層線性網路,附帶乙個啟用函式: h=

w1x+

w2 y

=f(h

) 代價函式:

在訓練時,每乙個迭代要記錄以下資料:

- 當前模型引數w1

,w2- 前向運算各層響應:x,

h,y 這樣,可以在後向運算中用梯度下降更新引數:

很小,不做考量。

256*256的彩色影象:256*256*3*1 byte= 192kb

較大,和模型複雜度有關。

入門級的mnist識別網路有6.6 million引數,使用32-bit浮點精度,佔記憶體:6.6m * 32 bit = 25mb

50層的resnet有26 million引數,佔記憶體:26m * 32 bit = 99mb

當然,你可以設計精簡的網路來處理很複雜的問題。

較大,同樣和模型複雜度有關。

50層的resnet有16 million響應,佔記憶體:16m*32bit = 64mb

響應和模型引數的數量並沒有直接關係。卷積層可以有很大尺寸的響應,但只有很少的引數;啟用層甚至可以沒有引數。

– 這樣看起來也不大啊?幾百兆而已。

– 往下看。

為了有效利用gpu的simd機制,要把資料以mini-batch的形式輸入網路。

如果要用32 bit的浮點數填滿常見的1024 bit通路,需要32個樣本同時計算。

在使用mini-batch時,模型引數依然只儲存乙份,但各層響應需要按mini-batch大小翻倍。

50層的resnet,mini-batch=32,各層相應佔記憶體:64mb*32 = 2gb

設h×

w的輸入影象為x,

k×k的卷積核為

r,符合我們直覺的卷積是這樣計算的。

對每乙個輸出位置,計算小塊對位乘法結果之和。

其中,表示輸入影象中,以 h

,w為中心,尺寸為 k

×k的子影象。

但是,這種零碎運算很慢

在深度學習庫中,一般會採用lowering的方式,把卷積計算轉換成矩陣乘法

首先,把輸入影象分別平移不同距離,得到

的位移影象,串接成

。之後,把k×

k的卷積核按照同樣順序拉伸成

的矩陣r

卷積結果通過一次矩陣乘法獲得:

輸入輸出為多通道時,方法類似,詳情參見這篇部落格。

在計算此類卷積時,前層響應

x需要擴大

倍。

50層的resnet,考慮lowering效應時,各層響應佔記憶體7.5gb

為了有效利用simd,如果精度降低一倍,batch大小要擴大一倍。不能降低記憶體消耗。

不開闢新記憶體,直接重寫原有響應。

很多啟用函式都可以這樣操作。

複雜一些,通過分析整個網路圖,可以找出只需要用一次的響應,它可以和後續響應共享記憶體。例如mxnet的memory sharing機制。

綜合運用這種方法,mit在2023年的這篇**能夠把記憶體降低兩到三倍。

找出那些容易計算的響應結果(例如啟用函式層的輸出)不與儲存,在需要使用的時候臨時計算。

使用這種方法,mxnet的這個例子能夠把50層的resnet網路占用的記憶體減小四倍。

類似地,deepmind在2023年的這篇**用rnn處理長度為1000的序列,記憶體占用降低20倍,計算量增加30%。

當然,還有graphcore自家的ipu,也通過儲存和計算的平衡來節約資源。

graphcore本身是一家機器學習晶元初創公司,行文中難免夾帶私貨,請明辨。

本文**,

電腦記憶體常見問題處理方法

相信眾多朋友在使用電腦時,總會遇到這樣或那樣的各種問題。如啟動電腦卻無法正常啟動 無法進入作業系統或是執行應用軟體,無故經常宕機等故障時,這些問題的產生常會因為記憶體出現異常故障而導致操作失敗。這是因為記憶體做為電腦中三大件配件之一,主要擔負著資料的臨時訪問任務。而市場上記憶體條的質量又參差不齊,所...

電腦記憶體常見問題處理方法

相信眾多朋友在使用電腦時,總會遇到這樣或那樣的各種問題。如啟動電腦卻無法正常啟動 無法進入作業系統或是執行應用軟體,無故經常宕機等故障時,這些問題的產生常會因為記憶體出現異常故障而導致操作失敗。這是因為記憶體做為電腦中三大件配件之一,主要擔負著資料的臨時訪問任務。而市場上記憶體條的質量又參差不齊,所...

說說DLL中記憶體問題

今天除錯動態庫的時候,有個函式在返回的時候總是要報錯。在callstack視窗中看見是堆疊釋放出了問題。但是我一向也是堅持誰申請誰釋放不是就ok了嗎。空氣這函式裡面還沒有堆操作,全部是區域性變數啊,怎麼會出錯呢?經過仔細排查,最後把函式的呼叫方式全部改為 stdcall方式,當時還以為是不是引數出入...