##此處x,y為資料集的tensor
torch_dataset = data.tensordataset(data_tensor=x, target_tensor=y)
loader = data.dataloader(
dataset=torch_dataset,
batch_size=batch_size,
shuffle=true, ##是否打亂次序
num_workers=2 ##執行緒
)for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
## train .....
這裡是cnn層的**,要注意的地方是,在__init__()中定義各個層的功能,在forward()中連線起各層,要注意把conv後的tensor展平,要注意保留batch的維度,但是只在forward中考慮batch的維度,在定義的時候不需考慮,view函式的用法相當於reshape,在指定為-1的地方表示該處值自動推斷
class cnn(nn.module):
def __init__(self):
super(cnn, self).__init__()
self.conv1 = nn.sequential(#(1, 28, 28)
nn.conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),#(16, 28, 28)
nn.relu(),
nn.maxpool2d(kernel_size=2),#(16, 14, 14)
)self.conv2 = nn.sequential(
nn.conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),#(32, 14, 14)
nn.relu(),
nn.maxpool2d(kernel_size=2),#(32, 7, 7)
)self.out = nn.linear(32*7*7, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x) #(batch, 32, 7, 7)
x = x.view(x.size(0), -1) #(batch, 32*7*7)
output = self.out(x)
return output
torch.max()返回的有兩個陣列,是最大值陣列,是最大值索引陣列,所以這裡用
要從tensor取資料出來,應該用tensor.data.numpy()
要把邏輯陣列先astype為int 1/0,然後除的時候要轉化為float
for epoch in range(epoch):
for step, (x, y) in enumerate(train_loader):
output = cnn(x)
loss = loss_func(output, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % 50 == 0:
test_output = cnn(test_x)
pred_y = torch.max(test_output, 1)[1].data.numpy()
accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
print('epoch: ', epoch,
'| train loss: %.4f' % loss.data.numpy(),
'| test accuracy: %.4f' % accuracy)
莫煩pytorch學習筆記2
類似numpy,pytorch就是在神經網路領域代替numpy的模組 神經網路在做什麼?pytorch類似tensorflow使用tensor表示高維資訊 參考pytorch環境搭建 或者看pytorch官方文件 官網命令安裝了兩個東西 可以進行一些矩陣相關的運算 莫煩莫煩 激勵函式必須使可微分的,...
莫煩 pytorch筆記 variable是什麼
variable型別是什麼 variable tensor1 torch.floattensor 1,2 3,4 建立tensor variable variable tensor1,requires grad true 建立variable。其中requires grad是誤差反向傳播 計算梯度的...
莫煩pytorch批訓練
import torch import torch.utils.data as data 包裝資料類 tensordataset 包裝資料和目標張量的資料集,通過沿著第乙個維度索引兩個張量來 class torch.utils.data.tensordataset data tensor,targe...