pytorch讀取訓練資料是非常便捷的,只需要使用2個類:
(1)torch.utils.data.dataset
(2)torch.utils.data.dataloader
常用資料集的讀取
1、torchvision.datasets的使用
對於常用資料集,可以使用torchvision.datasets直接進行讀取。torchvision.dataset是torch.utils.data.dataset的實現,該包提供了以下資料集的讀取:
以cifar10為例,
import torch
import torchvision
from pil import image
cifarset=torchvision.datasets.cifar10(root="../data/cifar",train=true,download=true)
print(cifarset[0])
img,label=cifarset[0]
print(img)
print(label)
print(img.format,img.size,img.mode)
img.show()
2、例項化torch.utils.data.dataloader
mytransform=transforms.compose([
transforms.totensor()
])#torch.utils.data.dataloader
cifarset=torchvision.datasets.cifar10(root="../data/cifar/",train=true,download=true,transform=mytransform)
cifarloader=torch.utils.data.dataloader(cifarset,batch_size=10,shuffle=false,num_workers=2)
下面就可以進行讀取資料的顯示,以進行簡單測試是否讀取成功
for i, data in enumerate(cifarloader,0):
print(data[i][0])
#pil
img=transforms.topilimage()(data[i][0])
img.show()
break
自定義標籤資料集的讀取
1、實現torch.utils.data.dataset
假設我們有乙個標籤test_image.txt,內容如下:
對應的影象位於images目錄下,首先要繼承torch.utils.data.dataset類,完成影象及標籤的讀取。
import os
import torch
import torch.utils.data as data
from pil import image
def default_loader(path):
return image.open(path).convert('rgb')
class myimagefolder(data.dataset):
def __init__(self,root,label,transform=none,target_transform=none,loader=default_loader):
fh=open(label)
c=0imags=
class_names=
for line in fh.readlines():
if c==0:
class_names=[n.strip() for n in line.rstrip().split(' ')]
else:
cls=line.split()
fn=cls.pop(0)
if os.path.isfile(os.path.join(root,fn)):
c=c+1
self.root=root
self.imgs=imgs
self.classes=class_names
self.transform=transform
self.target_transform=target_transform
self.loader=loader
def __getitem__(self,index):
fn,label=self.imgs[index]
img=self.loader(os.path.join(self.root,fn))
if self.transform is not none:
img=self.transform(img)
return img, torch.tensor(label)
def __len__(self):
return len(self.imgs)
def getname(self):
return self.classes
2、例項化torch.utils.data.dataloader
mytransform=transforms.compose([
transforms.totensor()
])#torch.utils.data.dataloader
imgloader=torch.utils.data.dataloader(
myfloder.myimagefolder(root="../data/testimages/images",label="../data/test_images.txt",transform=mytransform), batch_size=2,shuffle=false,num_workers=2)
for i, data in enumerate(imgloader,0):
print(data[i][0])
#opencv
img2=data[i][0].numpy()*255
img2=img2.astype('uint8')
img2=np.transpose(img2,(1,2,0))
img2=img2[:,:,::-1]#rgb->bgr
cv2.imshow('img2',img2)
cv2.waitkey()
break
pytorch讀取coco資料集
yolov3 an incremental improvement 原理在該篇部落格就寫的很詳細了,這裡就不贅述了 bin bash credit clone coco api git clone cd coco mkdir images cd images download images wget...
Pytorch 讀取大資料集
記錄一下pytorch讀取大型資料集的要點 pytorch 讀取大資料集的一般方法 class mydataset data.dataset def init self,root filepath self.root root init 中讀取檔案路徑而非檔案本體 self.imgs list se...
pytorch訓練MNIST資料集1
本文採用全連線網路對mnist資料集進行訓練,訓練模型主要由五個線性單元和relu啟用函式組成 import torch from torchvision import transforms from torchvision import datasets from torch.utils.data...