pytorch訓練之資料載入步驟

2021-09-10 23:29:52 字數 2287 閱讀 2610

本文章以reid的資料載入為例。

from torch.utils.data import dataset, dataloader

from torchvision import transforms

一、建立自定義資料處理方法類:如隨機擦除,隨機裁剪等

**:class randomerasing(object):

def __init__(self,probability=0.5)

def __call__(self, img)

...return img

二、建立資料預處理組合類例項:如影象翻轉,歸一化,向量化,擦除等

**:train_transform = transforms.compose([

transforms.resize((384, 128), interpolation=3),

transforms.randomhorizontalflip(),

transforms.totensor(),

transforms.normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),

randomerasing(probability=0.5, mean=[0.0, 0.0, 0.0])

])三、建立資料讀取類:從本地路徑進行資料載入,形成列表等

**:from torchvision.datasets.folder import default_loader  //解釋見最後

class market(dataset.dataset):

def __init__(self, transform, dtype, data_path):

self.loader = default_loader

def __getitem__(self, index):

...//根據路徑生成影象與標籤列表

//載入影象

img = self.loader(path)

if self.transform is not none:

img = self.transform(img)

return img,target

def __len__(self):

return len(self.imgs)

四、生成torch資料流類例項:

self.train_loader = dataloader.dataloader(self.trainset,

sampler=randomsampler(self.trainset, batch_id=opt.batchid,

batch_image=opt.batchimage),

batch_size=opt.batchid * opt.batchimage, num_workers=8,

pin_memory=true)

self.test_loader = dataloader.dataloader(self.testset, batch_size=opt.batchtest, num_workers=8, pin_memory=true)

然後就可以在訓練階段使用迭代方法進行資料獲取了。

torchvision.datasets.folder中的default_loader函式:

該函式主要分兩種情況呼叫兩個函式,一般採用pil_loader函式。

def pil_loader(path):

with open(path, 'rb') as f:

with image.open(f) as img:

return img.convert('rgb')

def accimage_loader(path):

import accimage

try:

return accimage.image(path)

except ioerror:

# potentially a decoding problem, fall back to pil.image

return pil_loader(path)

def default_loader(path):

from torchvision import get_image_backend

if get_image_backend() == 'accimage':

return accimage_loader(path)

else: #get_image_backend() == 'pil'

return pil_loader(path)

pytorch 載入預訓練模型

pytorch的torchvision中給出了很多經典的預訓練模型,模型的引數和權重都是在imagenet資料集上訓練好的 載入模型 方法一 直接使用預訓練模型中的引數 import torchvision.models as models model models.resnet18 pretrai...

pytorch載入資料

參考 pytorch深度學習快速入門教程 絕對通俗易懂!小土堆 可看到說明,dataset是乙個抽象類,我們重寫dataset時要繼承這個類,所有的子類都應該重寫 getitem 方法,這個方法作用是獲取資料及對應的labe。同時我們可以選擇性地去重寫 len 方法,其作用是獲取資料集長度。這裡我使...

pytorch載入預訓練模型後,訓練指定層

1 有了已經訓練好的模型引數,對這個模型的某些層做了改變,如何利用這些訓練好的模型引數繼續訓練 pretrained params torch.load pretrained model model the new model model.load state dict pretrained par...