warmup 預熱學習率

2022-02-20 01:47:02 字數 2033 閱讀 1707

學習率是神經網路訓練中最重要的超引數之一,針對學習率的優化方式很多,warmup是其中的一種。

(一)、什麼是warmup?warmup是在resnet**中提到的一種學習率預熱的方法,它在訓練開始的時候先選擇使用乙個較小的學習率,訓練了一些epoches或者steps(比如4個epoches,10000steps),再修改為預先設定的學習率來進行訓練。

(二)、為什麼使用warmup?

由於剛開始訓練時,模型的權重(weights)是隨機初始化的,此時若選擇乙個較大的學習率,可能帶來模型的不穩定(振盪),選擇warmup預熱學習率的方式,可以使得開始訓練的幾個epoches或者一些steps內學習率較小,在預熱的小學習率下,模型可以慢慢趨於穩定,等模型相對穩定後再選擇預先設定的學習率進行訓練,使得模型收斂速度變得更快,模型效果更佳。

exampleexampleexample:resnet**中使用乙個110層的resnet在cifar10上訓練時,先用0.01的學習率訓練直到訓練誤差低於80%(大概訓練了400個steps),然後使用0.1的學習率進行訓練。

(三)、warmup的改進

(二)所述的warmup是constant warmup,它的不足之處在於從乙個很小的學習率一下變為比較大的學習率可能會導致訓練誤差突然增大。於是18年facebook提出了gradual warmup來解決這個問題,即從最初的小學習率開始,每個step增大一點點,直到達到最初設定的比較大的學習率時,採用最初設定的學習率進行訓練。

1.gradual warmup的實現模擬**如下:

"""

implements gradual warmup, if train_steps < warmup_steps, the

learning rate will be `train_steps/warmup_steps * init_lr`.

args:

warmup_steps:warmup步長閾值,即train_steps"""

import

numpy as np

warmup_steps = 2500init_lr = 0.1

#模擬訓練15000步

max_steps = 15000

for train_steps in

range(max_steps):

if warmup_steps and train_steps

warmup_percent_done = train_steps /warmup_steps

warmup_learning_rate = init_lr * warmup_percent_done #

gradual warmup_lr

learning_rate =warmup_learning_rate

else

:

#learning_rate = np.sin(learning_rate) #預熱學習率結束後,學習率呈sin衰減

learning_rate = learning_rate**1.0001 #

預熱學習率結束後,學習率呈指數衰減(近似模擬指數衰減)

if (train_steps+1) % 100 ==0:

print("

train_steps:%.3f--warmup_steps:%.3f--learning_rate:%.3f

" %(

train_steps+1,warmup_steps,learning_rate))

2.上述**實現的warmup預熱學習率以及學習率預熱完成後衰減(sin or exp decay)的曲線圖如下:

(四)總結

使用warmup預熱學習率的方式,即先用最初的小學習率訓練,然後每個step增大一點點,直到達到最初設定的比較大的學習率時(注:此時預熱學習率完成),採用最初設定的學習率進行訓練(注:預熱學習率完成後的訓練過程,學習率是衰減的),有助於使模型收斂速度變快,效果更佳。

warmup 預熱學習率

目錄 一 什麼是warmup?二 為什麼使用warmup?三 warmup的改進 四 總結 學習率是神經網路訓練中最重要的超引數之一,針對學習率的優化方式很多,warmup是其中的一種。warmup是在resnet 中提到的一種學習率預熱的方法,它在訓練開始的時候先選擇使用乙個較小的學習率,訓練了一...

BackboneJS入門學習 01 預熱

今天這篇blog是我自學backbonejs的第一篇,一些學習心得,一些學習筆記。如有不妥之處,盡請批評指教。儘管backbonejs出來的好久,但是它在web前端mvc上得到廣泛應用。尤其是在單頁面應用程式上。比如豆瓣之前的阿爾法城等。它可以徹底地將html從js中分離出來,上顯得簡潔。首先,ba...

學習率的作用,學習率衰減,函式

目錄 1.學習率的作用 2.學習率衰減常用引數有哪些 3.常見衰減函式 3.1分段常數衰減 3.2指數衰減 3.3自然指數衰減 3.4多項式衰減 3.5余弦衰減 梯度下降法是乙個廣泛被用來最小化模型誤差的引數優化演算法。梯度下降法通過多次迭代,並在每一步中最小化成本函式 cost 來估計模型的引數。...