This post trains a fully connected network on the MNIST dataset. The model consists of five linear layers with ReLU activations between them.
Result: after 7 epochs of training, accuracy on the test set reaches about 97%.

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import os
import sys
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),                        # scale pixel values from 0-255 down to 0-1
    transforms.Normalize((0.1307,), (0.3081,))    # normalize with the MNIST mean and std
])

train_dataset = datasets.MNIST(root='../dataset/mnist',
                               train=True,
                               download=True,
                               transform=transform)
train_loader = DataLoader(train_dataset,
                          shuffle=True,
                          batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist',
                              train=False,
                              download=True,
                              transform=transform)
test_loader = DataLoader(test_dataset,
                         shuffle=False,
                         batch_size=batch_size)
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.f1 = torch.nn.Linear(784, 512)
        self.f2 = torch.nn.Linear(512, 256)
        self.f3 = torch.nn.Linear(256, 128)
        self.f4 = torch.nn.Linear(128, 64)
        self.f5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784)        # flatten each 28x28 image into a 1x784 vector
        x = F.relu(self.f1(x))
        x = F.relu(self.f2(x))
        x = F.relu(self.f3(x))
        x = F.relu(self.f4(x))
        return self.f5(x)          # no activation here: CrossEntropyLoss applies log-softmax internally
model = Net()

# loss: cross-entropy
criterion = torch.nn.CrossEntropyLoss()
# SGD with momentum
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# training
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader):
        inputs, target = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 300 == 299:    # report the average loss every 300 batches
            print('[%d,%5d] loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0
def test():
    correct = 0
    total = 0
    with torch.no_grad():    # no gradients needed during evaluation
        for batch_idx, data in enumerate(test_loader):
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)    # index of the largest logit is the predicted class
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('accuracy on test set:%d %%' % (100 * correct / total))
if __name__ == '__main__':
    for epoch in range(7):
        train(epoch)
        test()
    # save the network parameters
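The original post stops at the comment above without showing the save step. A minimal sketch of how it could be done with torch.save and load_state_dict follows; the file name mnist_fc.pth is only an assumed placeholder.

    # assumed file name; any path works
    torch.save(model.state_dict(), 'mnist_fc.pth')

To reuse the trained weights later, rebuild the same network and load the saved state dict:

new_model = Net()
new_model.load_state_dict(torch.load('mnist_fc.pth'))
new_model.eval()    # switch to evaluation mode before inference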
