本文章以reid的資料載入為例。
from torch.utils.data import dataset, dataloader
from torchvision import transforms
一、建立自定義資料處理方法類:如隨機擦除,隨機裁剪等
**:class randomerasing(object):
def __init__(self,probability=0.5)
def __call__(self, img)
...return img
二、建立資料預處理組合類例項:如影象翻轉,歸一化,向量化,擦除等
**:train_transform = transforms.compose([
transforms.resize((384, 128), interpolation=3),
transforms.randomhorizontalflip(),
transforms.totensor(),
transforms.normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
randomerasing(probability=0.5, mean=[0.0, 0.0, 0.0])
])三、建立資料讀取類:從本地路徑進行資料載入,形成列表等
**:from torchvision.datasets.folder import default_loader //解釋見最後
class market(dataset.dataset):
def __init__(self, transform, dtype, data_path):
self.loader = default_loader
def __getitem__(self, index):
...//根據路徑生成影象與標籤列表
//載入影象
img = self.loader(path)
if self.transform is not none:
img = self.transform(img)
return img,target
def __len__(self):
return len(self.imgs)
四、生成torch資料流類例項:
self.train_loader = dataloader.dataloader(self.trainset,
sampler=randomsampler(self.trainset, batch_id=opt.batchid,
batch_image=opt.batchimage),
batch_size=opt.batchid * opt.batchimage, num_workers=8,
pin_memory=true)
self.test_loader = dataloader.dataloader(self.testset, batch_size=opt.batchtest, num_workers=8, pin_memory=true)
然後就可以在訓練階段使用迭代方法進行資料獲取了。
torchvision.datasets.folder中的default_loader函式:
該函式主要分兩種情況呼叫兩個函式,一般採用pil_loader函式。
def pil_loader(path):
with open(path, 'rb') as f:
with image.open(f) as img:
return img.convert('rgb')
def accimage_loader(path):
import accimage
try:
return accimage.image(path)
except ioerror:
# potentially a decoding problem, fall back to pil.image
return pil_loader(path)
def default_loader(path):
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else: #get_image_backend() == 'pil'
return pil_loader(path)
pytorch 載入預訓練模型
pytorch的torchvision中給出了很多經典的預訓練模型,模型的引數和權重都是在imagenet資料集上訓練好的 載入模型 方法一 直接使用預訓練模型中的引數 import torchvision.models as models model models.resnet18 pretrai...
pytorch載入資料
參考 pytorch深度學習快速入門教程 絕對通俗易懂!小土堆 可看到說明,dataset是乙個抽象類,我們重寫dataset時要繼承這個類,所有的子類都應該重寫 getitem 方法,這個方法作用是獲取資料及對應的labe。同時我們可以選擇性地去重寫 len 方法,其作用是獲取資料集長度。這裡我使...
pytorch載入預訓練模型後,訓練指定層
1 有了已經訓練好的模型引數,對這個模型的某些層做了改變,如何利用這些訓練好的模型引數繼續訓練 pretrained params torch.load pretrained model model the new model model.load state dict pretrained par...