搭建網路的步驟大致為以下:
1.準備資料
2. 定義網路結構model
3. 定義損失函式
4. 定義優化演算法 optimizer
5. 訓練
5.1 準備好tensor形式的輸入資料和標籤(可選)
5.2 前向傳播計算網路輸出output和計算損失函式loss
5.3 反向傳播更新引數
以下三句話一句也不能少:
5.3.1 optimizer.zero_grad() 將上次迭代計算的梯度值清0
5.3.2 loss.backward() 反向傳播,計算梯度值
5.3.3 optimizer.step() 更新權值引數
5.4 儲存訓練集上的loss和驗證集上的loss以及準確率以及列印訓練資訊。(可選
6. 圖示訓練過程中loss和accuracy的變化情況(可選)
7. 在測試集上測試
**注釋都寫的很詳細
1import
torch
2import
torch.nn.functional as f
3import
matplotlib.pyplot as plt45
#1.準備資料 generate data
6 x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
7print
(x.shape)
8 y=x*x+0.2*torch.rand(x.size())9#
顯示資料散點圖
10plt.scatter(x.data.numpy(),y.data.numpy())
1112
#2.定義網路結構 build net
13class
net(torch.nn.module):14#
n_feature:輸入特徵個數 n_hidden:隱藏層個數 n_output:輸出層個數
15def
__init__
(self,n_feature,n_hidden,n_output):16#
super表示繼承net的父類,並同時初始化父類的引數
17 super(net,self).__init__
()18
#nn.linear代表線性層 代表y=w*x+b 其中w的shape為[n_hidden,n_feature] b的shape為[n_hidden]19#
y=w^t*x+b 這裡w的維度是轉置前的維度 所以是反的
20 self.hidden =torch.nn.linear(n_feature,n_hidden)
21 self.predict =torch.nn.linear(n_hidden,n_output)
22print
(self.hidden.weight)
23print
(self.predict.weight)24#
定義乙個前向傳播過程函式
25def
forward(self, x):26#
n_feature n_hidden n_output27#
舉例(2,5,1) 2 5 128#
- ** -29#
** - - - ** - -30#
- ** - - - **31#
** - - - ** - -32#
- ** -33#
輸入層 隱藏層 輸出層
34 x=f.relu(self.hidden(x))
35 x=self.predict(x)
36returnx37
#例項化乙個網路為net
38 net = net(n_feature=1,n_hidden=10,n_output=1)
39print
(net)40#
3.定義損失函式 這裡使用均方誤差(mean square error)
41 loss_func=torch.nn.mseloss()42#
4.定義優化器 這裡使用隨機梯度下降
43 optimizer=torch.optim.sgd(net.parameters(),lr=0.2)44#
定義300遍更新 每10遍顯示一次
45plt.ion()46#
5.訓練
47for t in range(100):
48 prediction = net(x) #
input x and predict based on x
49 loss = loss_func(prediction, y) #
must be (1. nn output, 2. target)50#
5.3反向傳播三步不可少
51 optimizer.zero_grad() #
clear gradients for next train
52 loss.backward() #
backpropagation, compute gradients
53 optimizer.step() #
5455
if t % 10 ==0:56#
plot and show learning process
57plt.cla()
58plt.scatter(x.data.numpy(), y.data.numpy())
59 plt.plot(x.data.numpy(), prediction.data.numpy(), '
r-', lw=5)
60 plt.text(0.5, 0, '
loss=%.4f
' % loss.data.numpy(), fontdict=)
61plt.show()
62 plt.pause(0.1)
6364 plt.ioff()
參考:莫煩python
pytorch 搭建網路步驟
準備資料 定義網路結構model 定義損失函式 定義優化演算法 optimizer 有是還要定義更新學習率 scheduler steplr 訓練5.1 準備好tensor形式的輸入資料和標籤 可選 5.2 前向傳播計算網路輸出output和 計算損失函式loss 5.3 反向傳播更新引數 以下三句...
R seau Donn e 搭建網路
reseu donnnee這門基本處於學一回忘一回的階段,這次,趁還沒忘利索之前,趕緊寫下來,為以後用著的時候存著。網路的組成 client1 communateur routeur1 routeaur2 communateur client2 配置ip 1.sudo ifconfig eth0 1...
搭建網路源
搭建本地源 1.mount o loop home centos 7 x86 64 everything 1708.iso mnt sr0 掛載檔案到mnt下的sr0,如果沒有sr0可以自己建乙個 2.lsblk可以檢視到掛載的資訊 3.vi etc yum.repos.d centos base....