搭建模型
定義計算步驟
輸出運算結果
本節主要針對mnist資料集的數字識別問題,寫出乙個解決回歸問題的方法。初步體會機器學習的工作流程
import torch
from torch import nn
from torch.nn import functional as f
from torch import optim
import torchvision
from matplotlib import pyplot as plt
#畫圖專用的檔案
from utils import plot_image, plot_curve, one_hot
batch_size =
512# step1. load dataset載入資料集
train_loader = torch.utils.data.dataloader(
torchvision.datasets.mnist(
'mnist_data'
, train=
true
, download=
true
, transform=torchvision.transforms.compose(
[ torchvision.transforms.totensor(),
torchvision.transforms.normalize(
(0.1307,)
,(0.3081,)
)]))
, batch_size=batch_size, shuffle=
true
)test_loader = torch.utils.data.dataloader(
torchvision.datasets.mnist(
'mnist_data/'
, train=
false
, download=
true
, transform=torchvision.transforms.compose(
[ torchvision.transforms.totensor(),
torchvision.transforms.normalize(
(0.1307,)
,(0.3081,)
)]))
, batch_size=batch_size, shuffle=
false
)
x, y =
next
(iter
(train_loader)
)print
(x.shape, y.shape, x.
min(
), x.
max())
plot_image(x, y,
'image sample'
)
class
net(nn.module)
:def
__init__
(self)
:super
(net, self)
.__init__(
)# xw+b
self.fc1 = nn.linear(28*
28,256)
self.fc2 = nn.linear(
256,64)
self.fc3 = nn.linear(64,
10)defforward
(self, x)
:# x: [b, 1, 28, 28]
# h1 = relu(xw1+b1) 公式
x = f.relu(self.fc1(x)
)# h2 = relu(h1w2+b2) 公式
x = f.relu(self.fc2(x)
)# h3 = h2w3+b3 公式
x = self.fc3(x)
return x
net = net(
)# [w1, b1, w2, b2, w3, b3]
#優化器
optimizer = optim.sgd(net.parameters(
), lr=
0.01
, momentum=
0.9)
#記錄loss
train_loss =
for epoch in
range(3
):for batch_idx,
(x, y)
inenumerate
(train_loader)
:# x: [b, 1, 28, 28], y: [512]
# [b, 1, 28, 28] => [b, 784] 從四維變換成二維
x = x.view(x.size(0)
,28*28
)# => [b, 10]
out = net(x)
# [b, 10]
y_onehot = one_hot(y)
# loss = mse(out, y_onehot)
loss = f.mse_loss(out, y_onehot)
# 清零梯度
optimizer.zero_grad(
) loss.backward(
)# w' = w - lr*grad 梯度更新
optimizer.step())
)# 輸出
if batch_idx %
10==0:
print
(epoch+
1, batch_idx, loss.item())
plot_curve(train_loss)
# we get optimal [w1, b1, w2, b2, w3, b3]
plot_curve(train_loss)
# we get optimal [w1, b1, w2, b2, w3, b3]
total_correct =
0for x,y in test_loader:
x = x.view(x.size(0)
,28*28
) out = net(x)
# out: [b, 10] => pred: [b]
pred = out.argmax(dim=1)
correct = pred.eq(y)
.sum()
.float()
.item(
) total_correct += correct
total_num =
len(test_loader.dataset)
acc = total_correct / total_num
print
('test acc:'
, acc)
pytorch學習筆記3 線性回歸
線性回歸線性回歸 是分析乙個變數與另外乙個 多 個變數之間關係的方法 因變數 y 自變數 x 關係 線性 y wx b 求解w,b 求解步驟 1.確定模型 module y wx b 2.選擇損失函式 mse 均方差等 3.求解梯度並更新w,b w w lr w.grad b b lr w.grad...
Pytorch 線性回歸問題
y 4 3x 高斯雜訊 我們利用線性回歸原理,假設y wx b,利用梯度下降法,去求解w,b。驗證w,b是否比較接近w 3,b 4 計算loss function函式 loss sum y w x b 2 定義loss functiondef computre error loss function...
pytorch碎碎念 回歸問題
根據b站莫煩python邊學邊打,只有自己打一遍才能發現容易發生好多錯誤啊 昨晚配置pytorch很順利!一遍就好了 環境 cuda10.0 python3.7 pytorch1.2.0 gpu 1660ti 現在已經有了pytorch1.4.0 似乎tensor和variable的用法有了改變 但...