在對模型優化時,希望通過梯度下降法使得模型的損失函式降低。目前主要的梯度下降法有sgd、momentum、adagrad、rmsprop、adam幾種,接下來將詳細討論這幾種方法以及他們的優缺點。
隨機選取乙個樣本的損失來近似整體樣本的平均損失,sgd在進行引數更新時的計算方式為
θ t=
θt−1
−αgt
\theta_t = \theta_ - \alpha g_t
θt=θt
−1−
αgt
其優點在於:
收斂速度快。
其缺陷在於:
容易收斂到區域性最優,或被困在鞍點。
對初始學習率的選擇依賴度較高,因該演算法引數的更新幅度固定,無法主動隨著迭代次數更新。
各方向學習率相同,但在實際情形中這種方式並不合理。
對於鞍點或是區域性最優點,因各方向在該點的梯度均為0,因此sgd演算法沒有能力從該點逃離。為了解決這一缺陷,momentum演算法被提出。
momentum演算法模擬物理學中的動量這一概念,它模擬的是物體運動的慣性。即引數在更新時,在一定程度上保留之前更新的方向,對當前的更新方向進行微調。
記 g
tg_t
gt 為當前時刻的梯度,momentum在進行引數更新時的計算方式為
v t=
γvt−
1+αg
tv_t = \gamma v_ + \alpha g_t
vt=γv
t−1
+αgt
θ t=
θt−1
−v
t\theta_t = \theta_ - v_t
θt=θt
−1−
vt
對比sgd演算法,其優點在於:
引數的更新方向與上一時刻一致時,能增大引數的更新幅度,模型能學習的更快。
同時,該方法具有一定擺脫區域性最優的能力。
對初始學習率的選擇要求沒那麼高。
其缺陷在於:
沒解決各方向學習率相同的問題。
adagrad演算法是針對sgd在各方向學習率相同的缺點進行的改進,該演算法在進行引數更新時的計算方式為
θ t,
i=θt
−1,i
−αgt
,i+ϵ
gt,i
\theta_ = \theta_ - \frac+\epsilon}}g_
θt,i=
θt−1
,i−
gt,i
+ϵ
αgt
,i
g t,
i=∑k
=1tg
k,i2
g_ = \sum_^t^2}
gt,i=
k=1∑
tgk
,i2
其中 該演算法的缺陷在於:
缺乏擺脫區域性困境的能力。
隨著 g
tg_t
gt 的累積,訓練中後期分母將越來越大、從而梯度趨近於0,使得訓練提前結束。
rmsprop演算法是針對adagrad梯度下降過快的缺陷進行的改進,該演算法在進行引數更新時的計算方式為
θ t=
θt−1
−αgt
+ϵgt
,i
\theta_ = \theta_ - \frac+\epsilon}}g_
θt=θt
−1−
gt+
ϵα
gt,i
g t=
0.9et−
1[g2
]+
0.1gt2
g_t = 0.9e_[g^2] + 0.1g_t^2
gt=0.
9et−
1[g
2]+0
.1gt
2 et−
1[g2
]=1t
−1∑i
=1t−
1gi2
e_[g^2] = \frac \sum_^
et−1[
g2]=
t−11
i=1
∑t−1
gi2
adam將momentum和rmsprop兩種方式進行結合,使得引數更新時既有一定慣性沿著之前的方向,同時更新時可在各方向有不同的更新幅度。
該演算法在進行引數更新時的計算方式為
m t=
β1mt
−1+(
1−β1
)g
tm_t = \beta_1m_ + (1-\beta_1)g_t
mt=β1
mt−
1+(
1−β1
)gt
m ^t
=mt1
−β1=
β11−
β1mt
+g
t\hat_t = \frac = \fracm_t+g_t
m^t=1
−β1
mt
=1−β
1β1
mt
+gt
v t=
β2vt
−1+(
1−β2
)gt2
v_t = \beta_2v_ + (1-\beta_2)g_t^2
vt=β2
vt−
1+(
1−β2
)gt
2 v^t
=vt1
−β2=
β21−
β2vt
+gt2
\hat_t = \frac = \fracv_t+g_t^2
v^t=1
−β2
vt
=1−β
2β2
vt
+gt
2 θt=
θt−1
−αv^
t+ϵm
^t
\theta_t = \theta_ - \frac_t+\epsilon}} \hat_t
θt=θt
−1−
v^t
+ϵα
m^t
機器學習(三) 梯度下降法
本部落格大部分參考了這篇博文 在微積分裡面,對多元函式的引數求 偏導數,把求得的各個引數的偏導數以向量的形式寫出來,就是梯度。比如函式f x,y 分別對x,y求偏導數,求得的梯度向量就是 f x f y 簡稱gr adf x,y 或者 f x,y 如果是3個引數的向量梯度,就是 f x f y,f ...
機器學習(二) 梯度下降法
前言 在上篇博文機器學習 一 中,最後我們提到,通過計算代價函式j 是否收斂於最小值來確定假設函式的引數 進而訓練出機器學習中的線性回歸演算法,那麼如何來找到使得j 最小話的引數 呢,本篇博文將介紹一種常用的方法,梯度下降法來確定引數 值。一 對於單特徵線性回歸,梯度下降法的演算法如下 repeat...
機器學習一(梯度下降法)
最近偶觸python,感ctrl c和ctrl v無比順暢,故越發膨脹。怒拾起python資料分析一pdf讀之,不到百頁,內心惶恐,嘆 臥槽,這都tm是啥,甚是迷茫。遂感基礎知識薄弱,隨意搜了機器學習教程,小看一翻。此文給出課件中幾個演算法,自己都不知道對不對,感覺還可以吧。本文以線性回歸為例,在給...