PyTorch訓練集的讀取

2021-09-27 04:36:41 字數 3300 閱讀 2045

pytorch讀取訓練資料是非常便捷的,只需要使用2個類:

(1)torch.utils.data.dataset

(2)torch.utils.data.dataloader

常用資料集的讀取

1、torchvision.datasets的使用

對於常用資料集,可以使用torchvision.datasets直接進行讀取。torchvision.dataset是torch.utils.data.dataset的實現,該包提供了以下資料集的讀取:

以cifar10為例,

import torch

import torchvision

from pil import image

cifarset=torchvision.datasets.cifar10(root="../data/cifar",train=true,download=true)

print(cifarset[0])

img,label=cifarset[0]

print(img)

print(label)

print(img.format,img.size,img.mode)

img.show()

2、例項化torch.utils.data.dataloader

mytransform=transforms.compose([

transforms.totensor()

])#torch.utils.data.dataloader

cifarset=torchvision.datasets.cifar10(root="../data/cifar/",train=true,download=true,transform=mytransform)

cifarloader=torch.utils.data.dataloader(cifarset,batch_size=10,shuffle=false,num_workers=2)

下面就可以進行讀取資料的顯示,以進行簡單測試是否讀取成功

for i, data in enumerate(cifarloader,0):

print(data[i][0])

#pil

img=transforms.topilimage()(data[i][0])

img.show()

break

自定義標籤資料集的讀取

1、實現torch.utils.data.dataset

假設我們有乙個標籤test_image.txt,內容如下:

對應的影象位於images目錄下,首先要繼承torch.utils.data.dataset類,完成影象及標籤的讀取。

import os

import torch

import torch.utils.data as data

from pil import image

def default_loader(path):

return image.open(path).convert('rgb')

class myimagefolder(data.dataset):

def __init__(self,root,label,transform=none,target_transform=none,loader=default_loader):

fh=open(label)

c=0imags=

class_names=

for line in fh.readlines():

if c==0:

class_names=[n.strip() for n in line.rstrip().split(' ')]

else:

cls=line.split()

fn=cls.pop(0)

if os.path.isfile(os.path.join(root,fn)):

c=c+1

self.root=root

self.imgs=imgs

self.classes=class_names

self.transform=transform

self.target_transform=target_transform

self.loader=loader

def __getitem__(self,index):

fn,label=self.imgs[index]

img=self.loader(os.path.join(self.root,fn))

if self.transform is not none:

img=self.transform(img)

return img, torch.tensor(label)

def __len__(self):

return len(self.imgs)

def getname(self):

return self.classes

2、例項化torch.utils.data.dataloader

mytransform=transforms.compose([

transforms.totensor()

])#torch.utils.data.dataloader

imgloader=torch.utils.data.dataloader(

myfloder.myimagefolder(root="../data/testimages/images",label="../data/test_images.txt",transform=mytransform), batch_size=2,shuffle=false,num_workers=2)

for i, data in enumerate(imgloader,0):

print(data[i][0])

#opencv

img2=data[i][0].numpy()*255

img2=img2.astype('uint8')

img2=np.transpose(img2,(1,2,0))

img2=img2[:,:,::-1]#rgb->bgr

cv2.imshow('img2',img2)

cv2.waitkey()

break

pytorch讀取coco資料集

yolov3 an incremental improvement 原理在該篇部落格就寫的很詳細了,這裡就不贅述了 bin bash credit clone coco api git clone cd coco mkdir images cd images download images wget...

Pytorch 讀取大資料集

記錄一下pytorch讀取大型資料集的要點 pytorch 讀取大資料集的一般方法 class mydataset data.dataset def init self,root filepath self.root root init 中讀取檔案路徑而非檔案本體 self.imgs list se...

pytorch訓練MNIST資料集1

本文採用全連線網路對mnist資料集進行訓練,訓練模型主要由五個線性單元和relu啟用函式組成 import torch from torchvision import transforms from torchvision import datasets from torch.utils.data...