import torch
import matplotlib as plt
n_data=torch.ones(100,2)
#代表形成了乙個100行2列為1的資料
x0=torch.normal(2*n_data,1)
#表示均值為2,標準差為為1
y0=torch.zeros(100)
#定義資料型別為0
x1=torch.normal(-2*n_data,1)
#表示均值為-2,標準差為1
y1=torch.ones(100)
#定義資料型別為1
x=torch.cat((x0,x1),0).type(torch.floattensor)
y=torch.cat((y0,y1),).type(torch.longtensor)
# 注意 x, y 資料的資料形式是一定要像下面一樣 (torch.cat 是在合併資料),type後面接的是資料型別
建立神經網路
import torch
import torch.nn.functional as f # 激勵函式都在這
class net(torch.nn.module): # 繼承 torch 的 module
def __init__(self, n_feature, n_hidden, n_output):
super(net, self).__init__() # 繼承 __init__ 功能
self.hidden = torch.nn.linear(n_feature, n_hidden) # 隱藏層線性輸出
self.out = torch.nn.linear(n_hidden, n_output) # 輸出層線性輸出
def forward(self, x):
# 正向傳播輸入值, 神經網路分析出輸出值
x = f.relu(self.hidden(x)) # 激勵函式(隱藏層的線性值)
x = self.out(x) # 輸出值, 但是這個不是**值, **值還需要再另外計算
return x
net = net(n_feature=2, n_hidden=10, n_output=2) # 幾個類別就幾個 output
(2)訓練網路
# optimizer 是訓練的工具
optimizer = torch.optim.sgd(net.parameters(), lr=0.02) # 傳入 net 的所有引數, 學習率
# 算誤差的時候, 注意真實值!不是! one-hot 形式的, 而是1d tensor, (batch,)
# 但是**值是2d tensor (batch, n_classes)
loss_func = torch.nn.crossentropyloss()
for t in range(100):
out = net(x) # 餵給 net 訓練資料 x, 輸出分析值
loss = loss_func(out, y) # 計算兩者的誤差
optimizer.zero_grad() # 清空上一步的殘餘更新引數值
loss.backward() # 誤差反向傳播, 計算引數更新值
optimizer.step() # 將引數更新值施加到 net 的 parameters 上
(四)視覺化訓練過程
import matplotlib.pyplot as plt
plt.ion() # 畫圖
plt.show()
for t in range(100):
loss.backward()
optimizer.step()
# 接著上面來
if t % 2 == 0:
plt.cla()
# 過了一道 softmax 的激勵函式後的最大概率才是**值
prediction = torch.max(f.softmax(out), 1)[1]
pred_y = prediction.data.numpy().squeeze()
target_y = y.data.numpy()
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='rdylgn')
accuracy = sum(pred_y == target_y)/200. # **中有多少和真實值一樣
plt.text(1.5, -4, 'accuracy=%.2f' % accuracy, fontdict=)
plt.pause(0.1)
plt.ioff() # 停止畫圖
plt.show()
【參考文獻】、
PyTorch分類神經網路
這次我們也是用最簡單的途徑來看看神經網路是怎麼進行事物的分類.我們建立一些假資料來模擬真實的情況.比如兩個二次分布的資料,不過他們的均值都不一樣.import torch import matplotlib.pyplot as plt 假資料 n data torch.ones 100,2 資料的基...
RNN 迴圈神經網路 分類 pytorch
import torch from torch import nn import torchvision.datasets as dsets import torchvision.transforms as transforms import matplotlib.pyplot as plt imp...
pytorch(八) RNN迴圈神經網路 分類
import torch import torch.nn as nn import torchvision.transforms as transforms from torch.autograd import variable import matplotlib.pyplot as plt imp...