Deep learning系列(八)引數初始化

2021-07-05 21:30:24 字數 1792 閱讀 6377

在主成分分析與白化一節中介紹了如何對輸入資料進行預處理,在這節中介紹與之類似的另乙個問題,引數初始化(weight initialization)。

在模型訓練之初,我們不知道引數的具體分布,然而如果資料經過了合理的歸一化(normalization)處理後,對於引數的合理猜測是其中一半是正的,另一半是負的。然後我們想是不是把引數都初始化為0會是比較好的初始化?這樣做其實會帶來乙個問題,經過正向傳播和反向傳播後,引數的不同維度之間經過相同的更新,迭代的結果是不同維度的引數是一樣的,嚴重地影響了模型的效能。

我們仍然想要引數接近於0,又不是絕對的0,一種可行的做法是將引數初始化為小的隨機數,這樣做可以打破對稱性(symmetry breaking)。python**如下:

nn_input_dim = 2

nn_hdim = 3

w = 0.001* np.random.randn(nn_input_dim,nn_hdim)

其中randn從均值為0,標準差是1的高斯分布中取樣,這樣,引數的每個維度來自乙個多維的高斯分布。需要注意的是引數初始值不能取得太小,因為小的引數在反向傳播時會導致小的梯度,對於深度網路來說,也會產生梯度瀰散問題,降低引數的收斂速度。

引數隨機初始化為乙個小的隨機數存在乙個問題:乙個神經元輸出的方差會隨著輸入神經元數量的增多而變大。對於有n個輸入單元的神經元來說,考慮χ2

分布,每個輸入的方差是1/

n 時,總的方差是1,因此,我們對每個輸入的標準差乘以1/

sqrt

(n) ,每個神經元的引數初始化**為:

w = np.random.randn(n) / sqrt(n)
其中n為這個神經元輸入的個數。這樣可以確保神經元的輸出有相同的分布,提高訓練的收斂速度。

將上面初始化方案推廣到網路的一層,對於神經網路的第一層可以這樣初始化:

nn_input_dim = 2

nn_hdim = 3

w = np.random.randn(nn_input_dim,nn_hdim) / sqrt(nn_input_dim)

understanding the difficulty of training deep feedforward neural networks文章中給出乙個類似的初始化方案:

nn_input_dim = 2

nn_hdim = 3

w = np.random.randn(nn_input_dim,nn_hdim) / sqrt(nn_input_dim+nn_hdim)

對於relu啟用神經元,delving deep into rectifiers: surpassing human-level performance on imagenet classification認為每個神經元的方差應該為:2/

n ,其初始化方案:

nn_input_dim = 2

nn_hdim = 3

w = np.random.randn(nn_input_dim,nn_hdim) / sqrt(2.0/nn_input_dim)

通常偏置項(bias)初始化為0:

nn_input_dim = 2

nn_hdim = 3

b1 = np.zeros((1, nn_hdim))

對於relu啟用神經元來說,可以將偏置項初始化為乙個小的常數,比如0.01,但不確定這樣做是否提高收斂的表現,在實際應用中,也常初始化為0。

參考內容:

1.

《DeepLearning》讀書筆記(八)

現代卷積神經網路動輒包含數百萬的神經單元,就像12.1節中討論的那樣,使用平行計算的高效實現是至關重要的。但是,有時候選用合適的演算法來加速計算也可以起到事半功倍的效果。卷積等價於使用傅利葉變換同時將輸入和核轉化為頻域 frequency domain 然後執行畫素層面的兩個訊號的乘法,最後使用反傅...

Deep learning系列(七)啟用函式

sigmoid將乙個實數輸入對映到 0,1 範圍內,如下圖 左 所示。使用sigmoid作為啟用函式存在以下幾個問題 因為上面兩個問題的存在,導致引數收斂速度很慢,嚴重影響了訓練的效率。因此在設計神經網路時,很少採用sigmoid啟用函式。tanh函式將乙個實數輸入對映到 1,1 範圍內,如上圖 右...

Deep learning系列(七)啟用函式

sigmoid將乙個實數輸入對映到 0,1 範圍內,如下圖 左 所示。使用sigmoid作為啟用函式存在以下幾個問題 因為上面兩個問題的存在,導致引數收斂速度很慢,嚴重影響了訓練的效率。因此在設計神經網路時,很少採用sigmoid啟用函式。tanh函式將乙個實數輸入對映到 1,1 範圍內,如上圖 右...