Pytorch實現CIFAR 10資料集

2021-10-03 07:33:39 字數 3422 閱讀 2198

練習pytorch,做個記錄。寫的有點亂

import torch

import torchvision

import torch.nn as nn

from torchvision import transforms

from torch.utils.data.dataloader import dataloader

from torchsummary import summary

device=torch.device('cuda')

batch_size=512

learning_rate=0.01

epochs=50

train_data=torchvision.datasets.cifar10(root='../pytorch_data',train=true,

transform=transforms.compose([transforms.totensor()]),

download=false

)train_loader=dataloader(dataset=train_data,batch_size=batch_size,shuffle=true)

test_data=torchvision.datasets.cifar10(root='../pytorch_data',train=false,

transform=transforms.compose([transforms.totensor()]),

download=false

)test_loader=dataloader(dataset=test_data,batch_size=batch_size,shuffle=true)

class mynet(nn.module):

def __init__(self):

super(mynet, self).__init__()

self.layer=nn.sequential(

nn.conv2d(in_channels=3,out_channels=64,kernel_size=3,stride=1,padding=0),

nn.batchnorm2d(64),

nn.relu(),

nn.conv2d(64,128,3,2),

nn.batchnorm2d(128),

nn.relu(),

nn.conv2d(128,128,3,2),

nn.batchnorm2d(128),

nn.relu()

)self.fc1=nn.linear(128*6*6,512)

self.fc2=nn.linear(512,10)

def forward(self,x):

x=self.layer(x)

x=x.view(x.size(0),-1)

x=self.fc1(x)

x=self.fc2(x)

return x

net=mynet().to(device)

summary(net,(3,32,32))

optimizer=torch.optim.sgd(net.parameters(),lr=learning_rate,momentum=0.9)

criterion=nn.crossentropyloss()

#訓練for epoch in range(epochs):

net.train()

for index,(x,label) in enumerate(train_loader):

x=x.to(device)

label=label.to(device)

output=net(x)

loss=criterion(output,label)

#梯度清零

optimizer.zero_grad()

#梯度計算

loss.backward()

#梯度更新

optimizer.step()

if index % 1024 ==0:

print('訓練的epoch為: {} [{}/{} (%)]\tloss: '.format(

epoch,index*len(x),len(train_loader.dataset),

100. * index / len(train_loader),loss.item()))

#測試print('開始測試')

net.eval()

#加上這個更加安全,不需要反向傳播

with torch.no_grad():

total_correct=0

total_num=0

for index,(x,label) in enumerate(test_loader):

x=x.to(device)

label=label.to(device)

#[batch_size,10]

output=net(x)

#返回的是最大的索引 [b]

pred=output.argmax(dim=1)

total_correct += torch.eq(pred,label).float().sum().item()

total_num += x.size(0)

acc=total_correct / total_num

print("準確率為: ",acc)

檢視資料的格式

for index,(x,label) in enumerate(train_loader):

x=x.to(device)

label=label.to(device)

print('x',x.data.size())

print('label',label.data.size())

break

50次訓練結果,不太好。改進地方還很多。 

在資料集讀取中加入normalize

test_data=torchvision.datasets.cifar10(root='../pytorch_data',train=false,

transform=transforms.compose([transforms.totensor(),

transforms.normalize(mean=[0.485,0.456,0.406],

std=[0.229,0.224,0.225])]),

有一點點提公升,看來還是要改變網路結構,增加訓練的次數。

利用pytorch對CIFAR 10資料集的分類

步驟如下 1.使用torchvision載入並預處理cifar 10資料集 2.定義網路 3.定義損失函式和優化器 4.訓練網路並更新網路引數 5.測試網路 執行環境 windows python3.6.3 pycharm pytorch0.3.0 import torchvision as tv ...

Pytorch 通過pytorch實現線性回歸

linear regression 線性回歸是分析乙個變數與另外乙個 多個 變數之間關係的方法 因變數 y 自變數 x 關係 線性 y wx b 分析 求解w,b 求解步驟 1.確定模型 2.選擇損失函式 3.求解梯度並更新w,b 此題 1.model y wx b 下為 實現 import tor...

Pytorch學習 1 pytorch簡介

pytorch簡介 1 pytorch簡介 1.1 pytorch的大概 pytorch不是簡單的封裝 lua torch 提供python介面,而是對當下tensor之上的模組進行重構,並增加了最先進的自動求導系統,成為當下最流行的動態圖框架。pytorch是乙個基於torch的python開源機...