對pytorch中優化器進行乙個簡單的例項進行比較說明:
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as f
import matplotlib.pyplot as plt
# 超引數
lr =
0.1batch_size =
32epoch =
12
# 生成訓練資料
# torch.unsqueeze()的作用是將一維變成二維,torch只能處理二維的資料
x = torch.unsqueeze(torch.linspace(-1
,1,1000
), dim=1)
# 0.1 * torch.normal(torch.zeros(*x.size())為增加噪點
y = x.
pow(2)
+0.1
* torch.normal(torch.zeros(
*x.size())
)# tensordataset是將樣本和標籤打包成dataset
torch_dataset = data.tensordataset(x, y)
# 得到乙個大批量的生成器
# dataloader組合資料集和取樣器
loader = data.dataloader(dataset=torch_dataset, batch_size=batch_size, shuffle=
true
)
class
net(nn.module)
:# 初始化
def__init__
(self)
:super
(net, self)
.__init__(
) self.hidden = nn.linear(1,
20)self.predict = nn.linear(20,
1)# 前向傳播
defforward
(self, x)
: x = f.relu(self.hidden(x)
) x = self.predict(x)
return x
# 使用多種優化器
net_sgd = net(
)net_momentum = net(
)net_rmsprop = net(
)net_adam = net(
)# 裝進乙個列表裡
nets =
[net_sgd, net_momentum, net_rmsprop, net_adam]
opt_sgd = torch.optim.sgd(net_sgd.parameters(
), lr=lr)
opt_momentum = torch.optim.sgd(net_momentum.parameters(
), lr=lr, momentum=
0.9)
opt_rmsprop = torch.optim.rmsprop(net_rmsprop.parameters(
), lr=lr, alpha=
0.9)
opt_adam = torch.optim.adam(net_adam.parameters(
), lr=lr, betas=
(0.9
,0.99))
optimizers =
[opt_sgd, opt_momentum, opt_rmsprop, opt_adam]
# 訓練模型
# 呼叫均方損失函式
loss_func = torch.nn.mseloss(
)loss_his =[[
],,,
]for epoch in
range
(epoch)
:for step,
(batch_x, batch_y)
inenumerate
(loader)
:for net, opt, l_his in
zip(nets, optimizers, loss_his)
:# 從每乙個網路裡獲取輸出
output = net(batch_x)
# 計算每乙個網路的損失
loss = loss_func(output, batch_y)
# 梯度清零
opt.zero_grad(
)# 反向傳播
loss.backward(
)print
(loss)
# 更新引數
opt.step())
)labels =
["sgd"
,"momentum"
,"rmsprop"
,"adam"
]
# 視覺化結果
for i, l_his in
enumerate
(loss_his)
: plt.plot(l_his,label=labels[i]
)# print(l_his)
plt.legend(loc=
"best"
)plt.xlabel(
"steps"
)plt.ylabel(
"loss"
)plt.ylim((0
,0.2))
plt.show(
)
果然還是adam比較好。 PyTorch常見的優化器
用法pytorch學習率調整策略通過torch.optim.lr scheduler介面實現。torch.optim是乙個實現了各種優化演算法的庫。大部分常用的方法得到支援,並且介面具備足夠的通用性,使得未來能夠整合更加複雜的方法。參考連線 首先需要構建乙個optimizer物件。這個物件能夠保持當...
Pytorch中adam優化器的引數問題
之前用的adam優化器一直是這樣的 alpha optim torch.optim.adam model.alphas config.alpha lr,betas 0.5,0.999 weight decay config.alpha weight decay 沒有細想內部引數的問題,但是最近的工作...
PyTorch自定義優化器
簡單粗暴的方法直接更新引數 def myopt pre pre儲存當前梯度與歷史梯度方向是否一致的資訊 lr lr儲存各層各引數學習率 vdw vdw儲存各層各引數動量 y pred net x loss loss func y pred,y net.zero grad loss.backward ...