這些天看的東西,真的是比較多,相比以前來說,對我的學習方式起到顛覆性作用。我目前覺得,我們學到的東西,更多是孤立的,因此,在吸收一定知識後,需要在腦子裡形成知識體系。需要把自己以前學到的東西進行整理,形成乙個體系,這篇文章講解的是,深度學習中pytorch資料集的構造!!!
pytorch中有兩個自定義管理資料集的類,
torch.utils.data.dataset
torvchvision.datasets.imagefolder
這裡主要講解的第一種。
class dataset(object):
"""an abstract class representing a dataset.
all other datasets should subclass it. all subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""def __getitem__(self, index):
raise notimplementederror
def __len__(self):
raise notimplementederror
def __add__(self, other):
return concatdataset([self, other])
我們設計自己資料集類的時候, 只需要重寫__getitem__、__len__
兩個函式,分別的功能是,通過切片返回具樣例、返回樣本個數。
以下是voc2012資料集分割的例子:
return len(self.data_list)通過上面的操作,我們構建自己資料集類,接下來,構建乙個dataloader
類,這個作用是訓練過程中,返回 batch個樣例。
由於原始碼過於臃腫了,這裡知識摘出對應的建構函式:
def __init__(self, dataset, batch_size=1, shuffle=false, sampler=none,
batch_sampler=none, num_workers=0, collate_fn=default_collate,
pin_memory=false, drop_last=false, timeout=0,
worker_init_fn=none):
建構函式中,每個引數的意思就不一一介紹了,只著重的講解下,可呼叫函式collate_fn
。我們首先看乙個構建dataloader
的例項:
def build_dataset(cfg, transforms, is_train=true):
datasets = vocsegdataset(cfg, is_train, transforms)
return datasets
def make_data_loader(cfg, is_train=true):
if is_train:
batch_size = cfg.solver.ims_per_batch
shuffle = true
else:
batch_size = cfg.test.ims_per_batch
shuffle = false
transforms = build_transforms(cfg, is_train)
datasets = build_dataset(cfg, transforms, is_train)
num_workers = cfg.dataloader.num_workers
data_loader = data.dataloader(
datasets, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=true
)return data_loader
上面第乙個函式build_dataset
返回資料集例項,第二個函式返回dataloader
,關於dataloader,我們需要注意的是,有時我們需要根據dataset中的__getitem__
修改collate_fn
。
我們來看下原始碼:
def __next__(self):
if self.num_workers == 0: # same-process loading
indices = next(self.sample_iter) # may raise stopiteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
batch = _utils.pin_memory.pin_memory_batch(batch)
return batch
我們在原始碼中發現,collate_fn
的輸入是乙個list,裡面的每個元素是__getitem__
的輸出,由此,我們估計,default_collate
的作用是將這個list,**變換格式為[batch,c,h,w]**的tensor,我們在來看下原始碼:
if.......
.........
elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'): # namedtuple
return type(batch[0])(*(default_collate(samples) for samples in zip(*batch)))
由於原始碼均是對型別的判斷,因此,這裡我們知識摘出,與voc2012
分割相關的部分,這個語句的意思是, 對[(img1, label1), (img2, label2)],首先返回[img1,img2],[lable1,label2],在繼續返回兩個tensor,乙個是img,[batch,c,h,w],乙個是label:[batch,c,h,w]。
所以,通過上面分析,如果,我們__getitem__
不符合collat_fn
不符合預設函式的判斷時,需要修改該函式。
好了,先到這,接下來…慢慢聊程式,需要學的太多了
pytorch之建立資料集
import torch import torchvision from torchvision import datasets,transforms dataroot data celeba 資料集所在資料夾 建立資料集 dataset datasets.imagefolder root data...
pytorch 載入資料集
2 tensor 的 構造方式 import torch import numpy as np data np.array 1,2,3 print torch.tensor data 副本 print torch.tensor data 副本 print torch.as tensor data 檢...
pytorch批訓練資料構造
這是對莫凡python的學習筆記。1.建立資料 import torch import torch.utils.data as data batch size 8x torch.linspace 1,10,10 y torch.linspace 10,1,10 可以看到建立了兩個一維資料,x 1 1...