pytorch 載入自己的資料集,需要寫乙個繼承自torch.utils.data中dataset類,並修改其中的__init__方法、__getitem__方法、__len__方法。預設載入的都是,__init__的目的是得到乙個包含資料和標籤的list,每個元素能找到位置和其對應標籤。然後用__getitem__方法得到每個元素的影象畫素矩陣和標籤,返回img和label。
以載入乙個影象放在某個資料夾下,並在當前目錄下生成了乙個.txt的檔案,大致如下train、test資料夾下放,test.txt和train.txt以如下方式存放路徑和標籤:
# 建構函式帶有預設引數
def __init__(self, txt, transform=none, target_transform=none, loader=default_loader):
fh = open(txt, 'r')
imgs =
for line in fh:
# 移除字串首尾的換行符
# 刪除末尾空
# 以空格為分隔符 將字串分成
line = line.strip('\n')
line = line.rstrip()
words = line.split()
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
fn, label = self.imgs[index]
#呼叫定義的loader方法
img = self.loader(fn)
if self.transform is not none:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.imgs)
train_data = mydataset(txt=root + 'train.txt', transform=transforms.totensor())
test_data = mydataset(txt=root + 'test.txt', transform=transforms.totensor())
#train_data 和test_data包含多有的訓練與測試資料,呼叫dataloader批量載入
train_loader = dataloader(dataset=train_data, batch_size=64, shuffle=true)
test_loader = dataloader(dataset=test_data, batch_size=64)
pytorch 載入資料集
2 tensor 的 構造方式 import torch import numpy as np data np.array 1,2,3 print torch.tensor data 副本 print torch.tensor data 副本 print torch.as tensor data 檢...
PyTorch學習 載入資料集
需要定義diabetesdataset做為載入資料集diabetes的類,繼承自dataset,dataset是抽象類,需要實現其中的三個方法,init,getitem,len import torch from torch.utils.data import dataset 抽象類 from to...
封裝自己的pytorch資料集
1.pytorch 對於資料的標籤要求是長整形,因此要對標籤進行轉換 train label train label.long 2.對於資料的特徵部分,要轉換為tensor 形式,可以通過torch.tensor 將資料從numpy 轉為tensor train fea torch.tensor t...