反向傳播的全矩陣方法

2021-10-14 13:10:50 字數 2104 閱讀 3252

之前在神經網路隨機梯度下降計算梯度時,在反向傳播時每個樣本單獨計算梯度,然後再求小批量資料的梯度平均值;而現在全矩陣方法是將整個小批量作為乙個矩陣(乙個樣本作為一列)輸入整體利用矩陣運算一次計算梯度平均值,用計算出的梯度平均值去更新權重和偏置。結果表明,全矩陣方法能夠提公升效率平均5倍左右,由開始的平均10秒到2秒

廢話不多說,直接上**:

# ⼩批量資料上的反向傳播的全矩陣⽅法,並且最後更新權重

defbackprop_matrix

(self, x, y, m, eta)

:"""

⼩批量資料上的反向傳播的全矩陣⽅法

:param x: 小批量資料的輸入矩陣,一列代表乙個樣本

:param y: 期望輸出矩陣

:param m: 資料的規模

:param eta: 學習率

:return:

"""# 根據權重矩陣和偏置列向量的形狀生成梯度矩陣

nabla_b =

[np.zeros(b.shape)

for b in self.biases]

nabla_w =

[np.zeros(w.shape)

for w in self.weights]

# 第一步,設定輸入啟用值矩陣

activation = x

activations =

[x]# 儲存各層的啟用值矩陣

zs =

# 儲存各層的帶權輸入矩陣

# 第二步,前向傳播,計算各層的帶權輸入和啟用值

for w, b in

zip(self.weights, self.biases)

: z = np.dot(w, activation)

+ b activation = sigmoid(z)

# 第三步,計算輸出層誤差矩陣

delta = cost_derivative(activations[-1

], y)

* sigmoid_prime(zs[-1

])# 計算輸出層的偏置和權重的梯度

nabla_b[-1

]= np.array(

[np.mean(delta, axis=1)

]).transpose(

)# self.biases[-1] = self.biases[-1] - eta * nabla_b

nabla_w[-1

]= np.dot(delta, activations[-2

].transpose())

/ m # self.weights[-1] = self.weights[-1] - eta * nabla_w

# 第四步,反向傳播誤差,並且用誤差計算梯度

for l in

range(2

, self.num_layers)

: delta = np.dot(self.weights[

-l+1

].transpose(

), delta)

* sigmoid_prime(zs[

-l])

nabla_b[

-l]= np.array(

[np.mean(delta, axis=1)

]).transpose(

) nabla_w[

-l]= np.dot(delta, activations[

-l-1

].transpose())

/ m # 第五步,梯度下降更新引數

for l in

range(1

, self.num_layers)

: self.biases[

-l]= self.biases[

-l]- eta * nabla_b[

-l] self.weights[

-l]= self.weights[

-l]- eta * nabla_w[

-l]

矩陣求導與反向傳播個人理解

矩陣求導參考 看圖中畫紅圈的欄 反向傳播參考b站 向量x與向量y相乘 乙個數 xy 數 也就是x與y有乙個是行向量,乙個是列向量,習慣上一般我們認為y是列向量 則數對x的偏導 y 格式與x一致 數對y的偏導 x的轉置 格式與y一致 口訣 前不轉 後轉 口訣 鏈式法則向前乘 g x 導數 g 1 g ...

反向傳播演算法 反向傳播演算法的工作原理(2)

推薦圖書 資料準備和特徵工程 在第一部分 反向傳播演算法的工作原理 1 已經得到了如下結論,本文將在前述基礎上,做進一步的證明和解釋。其中 是乙個方陣,其對角線是 非對角線項為零。請注意,此矩陣通過矩陣乘法作用於 有上面糧食,可得 對於熟悉矩陣乘法的讀者來說,這個方程可能比 bp1 和 bp2 更容...

全連線層及卷積層的反向傳播公式

這裡主要是記錄一下全連線層及卷積層的兩組反向傳播 bp 公式,只要看懂和記住這兩個公式,反向傳播的理解和實現應該都是不難的了。z jl kw jkla kl 1 bjl 1.1 z l j sum k w l a k b l j tag zjl k wj kl a kl 1 bj l 1 1 l a...