練習pytorch,做個記錄。寫的有點亂
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torch.utils.data.dataloader import dataloader
from torchsummary import summary
device=torch.device('cuda')
batch_size=512
learning_rate=0.01
epochs=50
train_data=torchvision.datasets.cifar10(root='../pytorch_data',train=true,
transform=transforms.compose([transforms.totensor()]),
download=false
)train_loader=dataloader(dataset=train_data,batch_size=batch_size,shuffle=true)
test_data=torchvision.datasets.cifar10(root='../pytorch_data',train=false,
transform=transforms.compose([transforms.totensor()]),
download=false
)test_loader=dataloader(dataset=test_data,batch_size=batch_size,shuffle=true)
class mynet(nn.module):
def __init__(self):
super(mynet, self).__init__()
self.layer=nn.sequential(
nn.conv2d(in_channels=3,out_channels=64,kernel_size=3,stride=1,padding=0),
nn.batchnorm2d(64),
nn.relu(),
nn.conv2d(64,128,3,2),
nn.batchnorm2d(128),
nn.relu(),
nn.conv2d(128,128,3,2),
nn.batchnorm2d(128),
nn.relu()
)self.fc1=nn.linear(128*6*6,512)
self.fc2=nn.linear(512,10)
def forward(self,x):
x=self.layer(x)
x=x.view(x.size(0),-1)
x=self.fc1(x)
x=self.fc2(x)
return x
net=mynet().to(device)
summary(net,(3,32,32))
optimizer=torch.optim.sgd(net.parameters(),lr=learning_rate,momentum=0.9)
criterion=nn.crossentropyloss()
#訓練for epoch in range(epochs):
net.train()
for index,(x,label) in enumerate(train_loader):
x=x.to(device)
label=label.to(device)
output=net(x)
loss=criterion(output,label)
#梯度清零
optimizer.zero_grad()
#梯度計算
loss.backward()
#梯度更新
optimizer.step()
if index % 1024 ==0:
print('訓練的epoch為: {} [{}/{} (%)]\tloss: '.format(
epoch,index*len(x),len(train_loader.dataset),
100. * index / len(train_loader),loss.item()))
#測試print('開始測試')
net.eval()
#加上這個更加安全,不需要反向傳播
with torch.no_grad():
total_correct=0
total_num=0
for index,(x,label) in enumerate(test_loader):
x=x.to(device)
label=label.to(device)
#[batch_size,10]
output=net(x)
#返回的是最大的索引 [b]
pred=output.argmax(dim=1)
total_correct += torch.eq(pred,label).float().sum().item()
total_num += x.size(0)
acc=total_correct / total_num
print("準確率為: ",acc)
檢視資料的格式
for index,(x,label) in enumerate(train_loader):
x=x.to(device)
label=label.to(device)
print('x',x.data.size())
print('label',label.data.size())
break
50次訓練結果,不太好。改進地方還很多。
在資料集讀取中加入normalize
test_data=torchvision.datasets.cifar10(root='../pytorch_data',train=false,
transform=transforms.compose([transforms.totensor(),
transforms.normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225])]),
有一點點提公升,看來還是要改變網路結構,增加訓練的次數。
利用pytorch對CIFAR 10資料集的分類
步驟如下 1.使用torchvision載入並預處理cifar 10資料集 2.定義網路 3.定義損失函式和優化器 4.訓練網路並更新網路引數 5.測試網路 執行環境 windows python3.6.3 pycharm pytorch0.3.0 import torchvision as tv ...
Pytorch 通過pytorch實現線性回歸
linear regression 線性回歸是分析乙個變數與另外乙個 多個 變數之間關係的方法 因變數 y 自變數 x 關係 線性 y wx b 分析 求解w,b 求解步驟 1.確定模型 2.選擇損失函式 3.求解梯度並更新w,b 此題 1.model y wx b 下為 實現 import tor...
Pytorch學習 1 pytorch簡介
pytorch簡介 1 pytorch簡介 1.1 pytorch的大概 pytorch不是簡單的封裝 lua torch 提供python介面,而是對當下tensor之上的模組進行重構,並增加了最先進的自動求導系統,成為當下最流行的動態圖框架。pytorch是乙個基於torch的python開源機...