首先我們對數學符號做一些約定。
我們首先考慮簡單的情況:前饋神經網路,如圖
??? 所示。我們先假設任意兩層之間沒有權值共享。方向傳播演算法本質上就是梯度下降演算法,所以我們要對損失函式關於每個引數求導。設單個輸入例項
x 損失函式為 j(
x), 那麼我們要求解∂j
∂wi 和∂j
∂bi , i=
1,2,
…,n . 對
j 關於bi
求導是容易的,直接使用鏈式法則
∂bi=
∂xi∂
bi⋅∂
j∂xi
=i∂j
∂xi=
∂j∂x
i(1)
我們可以證明 ∂j
∂wi=
∂j∂x
ix(i
−1)¯
¯¯¯¯
¯¯¯t
(2)
這是因為 ==
∂j∂w
ipq=
∂xi∂
wipq
∂j∂x
i∑j∂
xij∂
wipq
∂j∂x
ijxi
−1q¯
¯¯¯¯
¯∂j∂
xip
所以關鍵是求解 ∂j
∂xi . 由鏈式法則(注意這裡是使用的標量函式對矩陣或向量求到的鏈式法則,求導順序不可交換),有 ∂j
∂xi=
∂xn∂
xi==
==∂x
n∂xi
∂j∂x
n∂xn
−1∂x
i∂xn
∂xn−
1∂xn
−2∂x
i∂xn
−1∂x
n−2∂
xn∂x
n−1⋯
∂xi+
1∂xi
∂xi+
2∂xi
+1⋯∂
xn∂x
n−1(3)
(4)
所以,由公式 (1), (3) 和 (4) 得 ∂j
∂bi=
∂j∂x
i=∂x
i+1∂
xi∂x
i+2∂
xi+1
⋯∂xn
∂xn−
1⋅∂j
∂xn(5)
由公式 (2)–(4) 得 ∂j
∂wi=
∂xi+
1∂xi
∂xi+
2∂xi
+1⋯∂
xn∂x
n−1
ai
⋅∂j∂
xn
δ
n⋅x(
i−1)
¯¯¯¯
¯¯¯¯
t(6)
其中 ∂x
j+1∂
xj=diag(f
˙j(x
j))w
(j+1
)t公式 (6) 中δn
的計算是容易的,因為 xn
靠近網路的輸出端,一般而言
j 是 xn
的乙個在形式上比較簡單的函式。對給定的損失函式,我們可以直接寫出其表示式。
公式 (6) 中ai
是一系列的jaccobian矩陣的乘積。結合圖1,我們可以從網路的輸出端到輸入端的方向依次計算這些jaccobian矩陣,並累乘之得到aj
,j=n
−1,n
−2,…
,1. 但是這樣做的計算量太大,因為它涉及一列的矩陣與矩陣的乘積,我們不會顯示的計算矩陣ai
, 而是依次計算δj
,j=n
−1,n
−2,…
,1 δ
j=∂x
j+1∂
xjδj
+1=diag(f
˙j(x
j))w
(j+1
)tδj
+1最後得到 ∂j
∂wi=
∂j∂b
i=δi
x(i−
1)¯¯
¯¯¯¯
¯¯tδ
i 另一點值得注意是,通常我們不會每觀察到乙個例項就更新權值,而是對每
m>
1 個例項計算一次梯度,更新一次權值。例如,我們從訓練集中每抽取
m 個樣本
更新一次引數。對這
m 個樣本,損失函式為l=
1m∑m
i=1j
(ix)
. 於是 ∂l
∂bi=
∂l∂w
i=1m
∑j=1
mai∂
j(jx
)∂xn
1m∑j
=1ma
i∂j(
jx)∂
xn⋅j
x(i−
1)¯¯
¯¯¯¯
¯¯¯¯
t(7)
(8)
在實踐中,我們一般都是按照 (7) 和 (8) 上面兩式求引數的梯度,進而更新網路的權值。注意矩陣按不同的順序求值,時間複雜度是不同的,千萬不要顯示的計算 ai
. 如果定義 xi
=δi=
(1xi
,2xi
,…,m
xi)(
1δi,
2δi,
…,mδ
i)那麼 ∂j
∂wi=
∂j∂b
i=δi
x(i−
1)tδ
i 下面我們考慮有權值共享的情況。我們可以證明,當有權值共享的時候,網路可以如同沒有權值共享一樣地更新。如圖 2 所示,不失一般性,不妨假設除了圖 2 所示的
z 層和
y 層共享權值矩陣
w 之外,沒有其他權值共享;設所有的啟用函式都是
f .
====
===∂
j∂wp
q=∂z
∂wpq
⋅∂j∂
z∑i∂
zi∂w
pq⋅∂
j∂zi
∑i∂(
wih¯
¯¯+b
i)∂w
pq⋅∂
j∂zi
∑i∂w
ih¯¯
¯∂wp
q⋅∂j
∂zi∑
i(∂w
i∂wp
q⋅∂w
ih¯¯
¯∂wi
+∂h¯
¯¯∂w
pq⋅∂
wih¯
¯¯∂h
¯¯¯)
⋅∂j∂
zihq
¯¯¯¯
∂j∂z
p+∂h
¯¯¯∂
wpq∑
iwti
∂j∂z
ihq¯
¯¯¯∂
j∂zp
+∂h¯
¯¯∂w
pqwt
∂j∂z
hq¯¯
¯¯∂j
∂zp+
∂y∂w
pq∂h
∂y∂h
¯¯¯∂
h∂z∂
h¯¯¯
∂j∂z
hq¯¯
¯¯∂j
∂zp+
∂y∂w
pq∂j
∂yhq
¯¯¯¯
∂j∂z
p+xq
¯¯¯¯
∂j∂y
p所以 ∂j
∂w=∂
j∂zh
¯¯¯t
+∂j∂
yx¯¯
¯t同理,如果偏置共享的話,對偏置的導數也有類似的性質。 因為∂
j∂w 可以寫為沒有權值共享時,損失函式關於對應位置的權值矩陣的導數之和,所以對該網路更新權值,可以如同沒有權值共享一樣地更新。
未完待續……
反向傳播演算法
反向傳播演算法的工作機制為對 錯誤的神經元施以懲罰。從輸出層開始,向上層次查詢 錯誤的神經元,微調這些神經元輸入值的權重,以達到修復輸出錯誤的目的。神經元之所以給出錯誤的 原因在於它前面為其提供輸入的神經元,更確切地說是由兩個神經元之間的權重及輸入值決定的。我們可以嘗試對權重進行微調。每次調整的幅度...
反向傳播演算法
看了很多有關神經網路的資料,一直對於反向傳播演算法總是不理解,對於其過程也是覺得很複雜,讓人想放棄,寫一篇部落格來從頭到尾來擼一遍反向傳播,讓這個黑盒子變透明。主要涉及到的數學方法,就是求偏導數,鏈式法則,沒有其他的複雜的數學公式了 當然求偏導的原理,是利用梯度下降法 因為是要將誤差減小,那麼就需要...
反向傳播演算法
反向傳播演算法的原理很簡單,只涉及chain rule和求導,但是在實際程式設計中,需要考慮到向量化後,會涉及矩陣求導。矩陣的求導只是提供了理論支援,實際實現中又使用了額外的技巧。首先 李巨集毅老師的backpropagation課程,了解為什麼這個演算法取名為反向傳播,怎麼傳播的 掌握了反向傳播的...