pytorch入門 資料的載入和預處理

2021-10-02 04:00:37 字數 3037 閱讀 4823

需要繼承dataset類,並且實現兩個成員方法:

getitem_() 該方法定義用索引(0 到 len(self))獲取一條資料或乙個樣本

len_() 該方法返回資料集的總長度

eg: 例項化乙個物件ds_demo,通過ds_demo[index]方法得到index對應的資料值,通過len(ds_demo)獲取資料總長度

#引用

from torch.utils.data import dataset

import pandas as pd

#定義乙個資料集

class

bulldozerdataset

(dataset)

:""" 資料集演示 """

def__init__

(self, csv_file)

:"""實現初始化方法,在初始化的時候將資料讀載入"""

self.df=pd.read_csv(csv_file)

def__len__

(self)

:'''

返回df的長度

'''return

len(self.df)

def__getitem__

(self, idx)

:'''

根據 idx 返回一行資料

'''return self.df.iloc[idx]

.saleprice

csv_file為讀入的檔案,eg: 『median_benchmark.csv』。

#可以例項話乙個物件訪問他

ds_demo= bulldozerdataset(

'median_benchmark.csv'

)#實現了 __len__ 方法所以可以直接使用len獲取資料總數

len(ds_demo)

# output: 11573

#用索引可以直接訪問對應的資料,對應 __getitem__ 方法

ds_demo[0]

# output: 24000.0

常用引數有:batch_size(每個batch的大小)、 shuffle(是否進行shuffle操作)、 num_workers(載入資料的時候使用幾個子程序)。

dl = torch.utils.data.dataloader(ds_demo, batch_size=

10, shuffle=

true

, num_workers=0)

# 返回乙個可迭代物件dl

for i, data in

enumerate

(dl)

:print

(i,data)

# 為了節約空間,這裡只迴圈一遍

break

輸出:

0 tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,

24000.], dtype=torch.float64)

1. torchvision.datasets

mnist

coco

captions

detection

lsun

imagefolder

imagenet-12

cifar

stl10

svhn

phototour 我們可以直接使用,示例如下:

import torchvision.datasets as datasets

trainset = datasets.mnist(root=

'./data'

,# 表示 mnist 資料的載入的目錄

train=

true

,# 表示是否載入資料庫的訓練集,false的時候載入測試集

download=

true

, transform=

none

)# 表示是否需要對資料進行預處理,none為不進行預處理

2. torchvision.models

torchvision不僅提供了常用資料集,還提供了訓練好的模型,可以載入之後,直接使用,或者在進行遷移學習 torchvision.models模組的子模組中包含以下模型結構。

alexnet

vggresnet

squeezenet

densenet

import torchvision.models as models

resnet18 = models.resnet18(pretrained=

true

)3.torchvision.transforms

transforms 模組提供了一般的影象轉換操作類,用作資料處理和資料增強,其實就是一些訓練模型之前,對資料的預處理

from torchvision import transforms as transforms

transform = transforms.compose(

[ transforms.randomcrop(

32, padding=4)

,#先四周填充0,在把影象隨機裁剪成32*32

transforms.randomhorizontalflip(),

#影象一半的概率翻轉,一半的概率不翻轉

transforms.randomrotation((-

45,45)

),#隨機旋轉

transforms.totensor(),

transforms.normalize(

(0.4914

,0.4822

,0.4465),

(0.229

,0.224

,0.225))

,#r,g,b每層的歸一化用到的均值和方差

])

注:本文參考

Pytorch的資料載入

pytorch將資料集的處理過程標準化,提供了dataset基本的資料 類,並在torchvision中提供了眾多資料變換函式,資料載入的具體過程 主要分為3步 1 繼承dataset類 對於資料集的處理,pytorch提供了torch.utils.data.dataset這個抽象 類,在使用時只需...

pytorch載入資料

參考 pytorch深度學習快速入門教程 絕對通俗易懂!小土堆 可看到說明,dataset是乙個抽象類,我們重寫dataset時要繼承這個類,所有的子類都應該重寫 getitem 方法,這個方法作用是獲取資料及對應的labe。同時我們可以選擇性地去重寫 len 方法,其作用是獲取資料集長度。這裡我使...

PyTorch 入門 自定義資料載入

之前學習tensorflow時也學習了它的資料載入,不過在網上看了很多教程後還是有很多小問題,不知道為什麼在別人電腦上可以執行但是我的就不行 把我頭搞暈了 很煩,這時想起之前聽導師說pytorch容易入門上手,所以果斷去學了pytorch,寫這篇博文的目的就是總結學到的,然後記錄下來,也希望以後學到...