資料集載入

2021-10-05 18:47:45 字數 3274 閱讀 3017

使用標準python包將資料載入成numpy陣列格式,然後轉換成torch.*tensor

自定義資料集製作與使用

可以載入資料並歸一化

#torchvision資料集的輸出是範圍在[0,1]之間的pilimage,我們將其轉化成歸一化範圍為[-1,1]之間的張量tensors.

import torch

import torchvision

import torchvision.transforms as transforms

transform = transforms.compose(

[transforms.totensor(),

#變成張量

transforms.normalize(

(0.5

,0.5

,0.5),

(0.5

,0.5

,0.5))

])#歸一化

trainset = torchvision.datasets.cifar10(root=

'./data'

, train=

true

,#匯入訓練集

download=

true

, transform=transform)

trainloader = torch.utils.data.dataloader(trainset, batch_size=4,

#dataloader封裝幾個為一批

shuffle=

true

, num_workers=2)

testset = torchvision.datasets.cifar10(root=

'./data'

, train=

false

, download=

true

, transform=transform)

testloader = torch.utils.data.dataloader(testset, batch_size=4,

shuffle=

false

, num_workers=2)

classes =

('plane'

,'car'

,'bird'

,'cat'

,'deer'

,'dog'

,'frog'

,'horse'

,'ship'

,'truck'

)

?載入資料——定義網路net—(見1、4)—定義loss和優化器——訓練分類器(迭代)?

for epoch in

range(2

):# 用訓練集對模型進行兩次訓練

running_loss =

0.0#初始化

for i, data in

enumerate

(trainloader,0)

:#對次數和樣本進行迭代 計算每個樣本的誤差

# get the inputs

inputs, labels = data #提取輸入及其標籤

# zero the parameter gradients

optimizer.zero_grad(

)# forward + backward + optimize

outputs = net(inputs)

loss = criterion(outputs, labels)

loss.backward(

) optimizer.step(

)# print statistics

running_loss += loss.item(

)if i %

2000

==1999

:# 每兩千個樣本 列印一次loss

print

('[%d, %5d] loss: %.3f'

%(epoch +

1, i +

1, running_loss /

2000))

running_loss =

0.0print

('finished training'

)

輸出

#[第幾次輸入訓練集,訓練了乙個訓練集中的多少個]  loss為多少[1

,2000

] loss:

2.187[1

,4000

] loss:

1.852[1

,6000

] loss:

1.672[1

,8000

] loss:

1.566[1

,10000

] loss:

1.490[1

,12000

] loss:

1.461[2

,2000

] loss:

1.389[2

,4000

] loss:

1.364[2

,6000

] loss:

1.343[2

,8000

] loss:

1.318[2

,10000

] loss:

1.282[2

,12000

] loss:

1.286

finished training

torchvision.transforms

資料集並不是同樣的尺寸。絕大多數神經網路都假定的尺寸相同。因此我們需要做一些預處理。

transforms = transforms.compose(

[transforms.resize(

[opt.img_size,opt.img_size]),

#尺寸#transforms.randomhorizontalflip(), #水平翻轉

#transforms.centercrop((size, size)), #以圖象中心為中心點裁剪

transforms.totensor(),

#做資料歸一化之前必須要把pil image轉成tensor

transforms.normalize(

[0.485

,0.456

,0.406],

[0.229

,0.224

,0.225])

])

pytorch 載入資料集

2 tensor 的 構造方式 import torch import numpy as np data np.array 1,2,3 print torch.tensor data 副本 print torch.tensor data 副本 print torch.as tensor data 檢...

載入 MNIST 資料集

使用 tensorflow 來讀取資料及標籤 from tensorflow.examples.tutorials.mnist import input data import tensorflow as tf 載入資料集 mnist input data.read data sets e soft...

PyTorch學習 載入資料集

需要定義diabetesdataset做為載入資料集diabetes的類,繼承自dataset,dataset是抽象類,需要實現其中的三個方法,init,getitem,len import torch from torch.utils.data import dataset 抽象類 from to...