pytorch 載入自己的資料集

2021-09-18 06:36:29 字數 1936 閱讀 6672

pytorch 載入自己的資料集,需要寫乙個繼承自torch.utils.data中dataset類,並修改其中的__init__方法、__getitem__方法、__len__方法。預設載入的都是,__init__的目的是得到乙個包含資料和標籤的list,每個元素能找到位置和其對應標籤。然後用__getitem__方法得到每個元素的影象畫素矩陣和標籤,返回img和label。

以載入乙個影象放在某個資料夾下,並在當前目錄下生成了乙個.txt的檔案,大致如下train、test資料夾下放,test.txt和train.txt以如下方式存放路徑和標籤:

# 建構函式帶有預設引數

def __init__(self, txt, transform=none, target_transform=none, loader=default_loader):

fh = open(txt, 'r')

imgs =

for line in fh:

# 移除字串首尾的換行符

# 刪除末尾空

# 以空格為分隔符 將字串分成

line = line.strip('\n')

line = line.rstrip()

words = line.split()

self.imgs = imgs

self.transform = transform

self.target_transform = target_transform

self.loader = loader

def __getitem__(self, index):

fn, label = self.imgs[index]

#呼叫定義的loader方法

img = self.loader(fn)

if self.transform is not none:

img = self.transform(img)

return img, label

def __len__(self):

return len(self.imgs)

train_data = mydataset(txt=root + 'train.txt', transform=transforms.totensor())

test_data = mydataset(txt=root + 'test.txt', transform=transforms.totensor())

#train_data 和test_data包含多有的訓練與測試資料,呼叫dataloader批量載入

train_loader = dataloader(dataset=train_data, batch_size=64, shuffle=true)

test_loader = dataloader(dataset=test_data, batch_size=64)

pytorch 載入資料集

2 tensor 的 構造方式 import torch import numpy as np data np.array 1,2,3 print torch.tensor data 副本 print torch.tensor data 副本 print torch.as tensor data 檢...

PyTorch學習 載入資料集

需要定義diabetesdataset做為載入資料集diabetes的類,繼承自dataset,dataset是抽象類,需要實現其中的三個方法,init,getitem,len import torch from torch.utils.data import dataset 抽象類 from to...

封裝自己的pytorch資料集

1.pytorch 對於資料的標籤要求是長整形,因此要對標籤進行轉換 train label train label.long 2.對於資料的特徵部分,要轉換為tensor 形式,可以通過torch.tensor 將資料從numpy 轉為tensor train fea torch.tensor t...