需要繼承dataset類,並且實現兩個成員方法:
getitem_() 該方法定義用索引(0 到 len(self))獲取一條資料或乙個樣本
len_() 該方法返回資料集的總長度
eg: 例項化乙個物件ds_demo,通過ds_demo[index]方法得到index對應的資料值,通過len(ds_demo)獲取資料總長度
#引用
from torch.utils.data import dataset
import pandas as pd
#定義乙個資料集
class
bulldozerdataset
(dataset)
:""" 資料集演示 """
def__init__
(self, csv_file)
:"""實現初始化方法,在初始化的時候將資料讀載入"""
self.df=pd.read_csv(csv_file)
def__len__
(self)
:'''
返回df的長度
'''return
len(self.df)
def__getitem__
(self, idx)
:'''
根據 idx 返回一行資料
'''return self.df.iloc[idx]
.saleprice
csv_file為讀入的檔案,eg: 『median_benchmark.csv』。
#可以例項話乙個物件訪問他
ds_demo= bulldozerdataset(
'median_benchmark.csv'
)#實現了 __len__ 方法所以可以直接使用len獲取資料總數
len(ds_demo)
# output: 11573
#用索引可以直接訪問對應的資料,對應 __getitem__ 方法
ds_demo[0]
# output: 24000.0
常用引數有:batch_size(每個batch的大小)、 shuffle(是否進行shuffle操作)、 num_workers(載入資料的時候使用幾個子程序)。
dl = torch.utils.data.dataloader(ds_demo, batch_size=
10, shuffle=
true
, num_workers=0)
# 返回乙個可迭代物件dl
for i, data in
enumerate
(dl)
:print
(i,data)
# 為了節約空間,這裡只迴圈一遍
break
輸出:
0 tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
24000.], dtype=torch.float64)
1. torchvision.datasets
mnist
coco
captions
detection
lsun
imagefolder
imagenet-12
cifar
stl10
svhn
phototour 我們可以直接使用,示例如下:
import torchvision.datasets as datasets
trainset = datasets.mnist(root=
'./data'
,# 表示 mnist 資料的載入的目錄
train=
true
,# 表示是否載入資料庫的訓練集,false的時候載入測試集
download=
true
, transform=
none
)# 表示是否需要對資料進行預處理,none為不進行預處理
2. torchvision.models
torchvision不僅提供了常用資料集,還提供了訓練好的模型,可以載入之後,直接使用,或者在進行遷移學習 torchvision.models模組的子模組中包含以下模型結構。
alexnet
vggresnet
squeezenet
densenet
import torchvision.models as models
resnet18 = models.resnet18(pretrained=
true
)3.torchvision.transforms
transforms 模組提供了一般的影象轉換操作類,用作資料處理和資料增強,其實就是一些訓練模型之前,對資料的預處理
from torchvision import transforms as transforms
transform = transforms.compose(
[ transforms.randomcrop(
32, padding=4)
,#先四周填充0,在把影象隨機裁剪成32*32
transforms.randomhorizontalflip(),
#影象一半的概率翻轉,一半的概率不翻轉
transforms.randomrotation((-
45,45)
),#隨機旋轉
transforms.totensor(),
transforms.normalize(
(0.4914
,0.4822
,0.4465),
(0.229
,0.224
,0.225))
,#r,g,b每層的歸一化用到的均值和方差
])
注:本文參考 Pytorch的資料載入
pytorch將資料集的處理過程標準化,提供了dataset基本的資料 類,並在torchvision中提供了眾多資料變換函式,資料載入的具體過程 主要分為3步 1 繼承dataset類 對於資料集的處理,pytorch提供了torch.utils.data.dataset這個抽象 類,在使用時只需...
pytorch載入資料
參考 pytorch深度學習快速入門教程 絕對通俗易懂!小土堆 可看到說明,dataset是乙個抽象類,我們重寫dataset時要繼承這個類,所有的子類都應該重寫 getitem 方法,這個方法作用是獲取資料及對應的labe。同時我們可以選擇性地去重寫 len 方法,其作用是獲取資料集長度。這裡我使...
PyTorch 入門 自定義資料載入
之前學習tensorflow時也學習了它的資料載入,不過在網上看了很多教程後還是有很多小問題,不知道為什麼在別人電腦上可以執行但是我的就不行 把我頭搞暈了 很煩,這時想起之前聽導師說pytorch容易入門上手,所以果斷去學了pytorch,寫這篇博文的目的就是總結學到的,然後記錄下來,也希望以後學到...