05 pytorch 神經網路分類問題

2022-05-05 03:51:12 字數 2119 閱讀 5736

import torch 

from torch.autograd import variable

import torch.nn.functional as f

import numpy as np

n_data = torch.ones(100,2) # 列印[100,2]矩陣的1

# 第乙個資料集

x0 = torch.normal(2*n_data,1)

y0 = torch.zeros(100)

# 第二個資料集

x1 = torch.normal(-2*n_data,1)

y1 = torch.ones(100)

# 合併資料集 --> 合併 並改變格式

x = torch.cat((x0,x1),0).type(torch.floattensor) # 32位浮點數

y = torch.cat((y0,y1)).type(torch.longtensor) # 64 位整型

tensor([[-1.8586, -2.7746],

[-2.8297, -2.1551],

[-2.4832, -2.2842],

[-1.9556, -1.9917],

[-2.5398, -2.1877],

[-2.6220, -2.5604]])

定義乙個神經網路(用於分類)

class net(torch.nn.module):

def __init__(self,n_feature,n_hidden,n_output):

super(net,self).__init__()

self.hidden = torch.nn.linear(n_feature,n_hidden)

self.predict = torch.nn.linear(n_hidden,n_output)

pass

def forward(self,x):

x = f.relu(self.hidden(x))

x =self.predict(x)

return x

分類的時候使用 crossentropyloss() 概率誤差比較好

net = net(2,10,2)

print(net)

optimizer = torch.optim.sgd(net.parameters(),lr=0.1)

loss_func = torch.nn.crossentropyloss() # 標籤誤差

net(

(hidden): linear(in_features=2, out_features=10, bias=true)

(predict): linear(in_features=10, out_features=2, bias=true)

)

for i in range(100):

prediction = net(x)

loss = loss_func(prediction,y)

# 梯度歸零

optimizer.zero_grad()

# 計算梯度

loss.backward()

# 更新結點

optimizer.step()

if i % 20 == 0:

print(loss)

tensor(0.5676, grad_fn=)

tensor(0.0800, grad_fn=)

tensor(0.0339, grad_fn=)

tensor(0.0204, grad_fn=)

tensor(0.0143, grad_fn=)

x1 = torch.floattensor([2,2])

x1 = variable(x1)

# 這樣可以是實現**

np.argmax(net(x1).data.numpy)

PyTorch分類神經網路

這次我們也是用最簡單的途徑來看看神經網路是怎麼進行事物的分類.我們建立一些假資料來模擬真實的情況.比如兩個二次分布的資料,不過他們的均值都不一樣.import torch import matplotlib.pyplot as plt 假資料 n data torch.ones 100,2 資料的基...

pytorch動態神經網路(分類)

import torch import matplotlib as plt n data torch.ones 100,2 代表形成了乙個100行2列為1的資料 x0 torch.normal 2 n data,1 表示均值為2,標準差為為1 y0 torch.zeros 100 定義資料型別為0 ...

RNN 迴圈神經網路 分類 pytorch

import torch from torch import nn import torchvision.datasets as dsets import torchvision.transforms as transforms import matplotlib.pyplot as plt imp...