準備資料
定義網路結構model
定義損失函式
定義優化演算法 optimizer
(有是還要定義更新學習率:scheduler=steplr())
訓練5.1 準備好tensor形式的輸入資料和標籤(可選)
5.2 前向傳播計算網路輸出output和 計算損失函式loss
5.3 反向傳播更新引數
以下三句話一句也不能少:
5.3.1optimizer.zero_grad()將上次迭代計算的梯度值清0
5.3.2loss.backward()反向傳播,計算梯度值
5.3.3optimizer.step()更新權值引數
(schedule.step(episode)更新學習率)
5.4 儲存訓練集上的loss和驗證集上的loss以及準確率以及列印訓練資訊。(可選
圖示訓練過程中loss和accuracy的變化情況(可選)
在測試集上測試
示例**:
import torch
import torch.nn.functional as f
import matplotlib.pyplot as plt
# 1.準備資料 generate data
x=torch.unsqueeze(torch.linspace(-1
,1,100
),dim=1)
print
(x.shape)
y=x*x+
0.2*torch.rand(x.size())
#顯示資料散點圖
plt.scatter(x.data.numpy(
),y.data.numpy())
# 2.定義網路結構 build net
class
net(torch.nn.module)
:#n_feature:輸入特徵個數 n_hidden:隱藏層個數 n_output:輸出層個數
def__init__
(self,n_feature,n_hidden,n_output)
:# super表示繼承net的父類,並同時初始化父類的引數
super
(net,self)
.__init__(
)# nn.linear代表線性層 代表y=w*x+b 其中w的shape為[n_hidden,n_feature] b的shape為[n_hidden]
# y=w^t*x+b 這裡w的維度是轉置前的維度 所以是反的
self.hidden =torch.nn.linear(n_feature,n_hidden)
self.predict =torch.nn.linear(n_hidden,n_output)
print
(self.hidden.weight)
print
(self.predict.weight)
#定義乙個前向傳播過程函式
defforward
(self, x)
:# n_feature n_hidden n_output
#舉例(2,5,1) 2 5 1
# - ** -
# ** - - - ** - -
# - ** - - - **
# ** - - - ** - -
# - ** -
# 輸入層 隱藏層 輸出層
x=f.relu(self.hidden(x)
) x=self.predict(x)
return x
# 例項化乙個網路為net
net = net(n_feature=
1,n_hidden=
10,n_output=1)
print
(net)
# 3.定義損失函式 這裡使用均方誤差(mean square error)
loss_func=torch.nn.mseloss(
)# 4.定義優化器 這裡使用隨機梯度下降
optimizer=torch.optim.sgd(net.parameters(
),lr=
0.2)
#定義300遍更新 每10遍顯示一次
plt.ion(
)# 5.訓練
for t in
range
(100):
prediction = net(x)
# input x and predict based on x
loss = loss_func(prediction, y)
# must be (1. nn output, 2. target)
# 5.3反向傳播三步不可少
optimizer.zero_grad(
)# clear gradients for next train
loss.backward(
)# backpropagation, compute gradients
optimizer.step(
)if t %
10==0:
# plot and show learning process
plt.cla(
) plt.scatter(x.data.numpy(
), y.data.numpy())
plt.plot(x.data.numpy(
), prediction.data.numpy(),
'r-'
, lw=5)
plt.text(
0.5,0,
'loss=%.4f'
% loss.data.numpy(
), fontdict=
) plt.show(
) plt.pause(
0.1)
plt.ioff(
)
參考:pytorch基礎-搭建網路 pytorch基礎 搭建網路
搭建網路的步驟大致為以下 1.準備資料 2.定義網路結構model 3.定義損失函式 4.定義優化演算法 optimizer 5.訓練 5.1 準備好tensor形式的輸入資料和標籤 可選 5.2 前向傳播計算網路輸出output和計算損失函式loss 5.3 反向傳播更新引數 以下三句話一句也不能...
R seau Donn e 搭建網路
reseu donnnee這門基本處於學一回忘一回的階段,這次,趁還沒忘利索之前,趕緊寫下來,為以後用著的時候存著。網路的組成 client1 communateur routeur1 routeaur2 communateur client2 配置ip 1.sudo ifconfig eth0 1...
搭建網路源
搭建本地源 1.mount o loop home centos 7 x86 64 everything 1708.iso mnt sr0 掛載檔案到mnt下的sr0,如果沒有sr0可以自己建乙個 2.lsblk可以檢視到掛載的資訊 3.vi etc yum.repos.d centos base....