資料集:penn-fudan資料集
在學習pytorch官網教程時,作者對penn-fudan資料集進行了定義,並且在自定義的資料集上實現了對r-cnn模型的微調。
此篇筆記簡單總結一下pytorch如何實現定義自己的資料集
資料集必須繼承torch.utils.data.dataset類,並且實現__len__和__getitem__方法
其中__getitem__方法返回的是image和target(乙個包含影象相關資訊的字典型別)
資料集主要分為三個部分,其中的pngimages為行人的**的集合
pedmasks為的掩膜集合
通過掩膜,產生目標程式要實現的蒙版效果
官方例子,原圖:
mask處理後
#官方定義資料集**(含自己的注釋)
import os
import numpy as np
import torch
from pil import image
class pennfudandataset(object):
def __init__(self, root, transforms):
self.root = root
self.transforms = transforms
#分別讀取pngimages和pedmasks資料夾下面的所有檔案,並組成乙個list
self.imgs = list(sorted(os.listdir(os.path.join(root, "pngimages"))))
self.masks = list(sorted(os.listdir(os.path.join(root, "pedmasks"))))
def __getitem__(self, idx):
# 分別相應的載入每個list裡面的資訊
img_path = os.path.join(self.root, "pngimages", self.imgs[idx])
mask_path = os.path.join(self.root, "pedmasks", self.masks[idx])
img = image.open(img_path).convert("rgb")
#不需要convert("rgb")因為mask的背景全是0
mask = image.open(mask_path)
# 將mask的pil圖轉換為numpy陣列
mask = np.array(mask)
# 將mask簡化,此時的obj_ids為[0,1,2],有兩種型別的邊界框
obj_ids = np.unique(mask)
# first id is the background, so remove it
#0表示黑色的背景,進行去除
obj_ids = obj_ids[1:]
#split the color-encoded mask into a set of binary masks
#none就是newaxis,相當於多了乙個軸,維度
masks = mask == obj_ids[:, none, none]
#get bounding box coordinates for each mask
#定義邊界框的tensor
num_objs = len(obj_ids)
boxes =
for i in range(num_objs):
pos = np.where(masks[i])
xmin = np.min(pos[1])
xmax = np.max(pos[1])
ymin = np.min(pos[0])
ymax = np.max(pos[0])
#convert everything into a torch.tensor
boxes = torch.as_tensor(boxes, dtype=torch.float32)
# there is only one class
labels = torch.ones((num_objs,), dtype=torch.int64)
masks = torch.as_tensor(masks, dtype=torch.uint8)
image_id = torch.tensor([idx])
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
# suppose all instances are not crowd
iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
#返回的target字典賦予相應的值
target = {}
target["boxes"] = boxes
target["labels"] = labels
target["masks"] = masks
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd
if self.transforms is not none:
img, target = self.transforms(img, target)
return img, target
def __len__(self):
return len(self.imgs)
官方文件:torchvision物件檢測微調程式
定義自己的資料集及載入訓練
Pytorch 學習筆記
本渣的pytorch 逐步學習鞏固經歷,希望各位大佬指正,寫這個部落格也為了鞏固下記憶。a a b c 表示從 a 取到 b 步長為 c c 2 則表示每2個數取1個 若為a 1 1 1 即表示從倒數最後乙個到正數第 1 個 最後乙個 1 表示倒著取 如下 陣列的第乙個為 0 啊,第 0 個!彆扭 ...
Pytorch學習筆記
lesson 1.張量 tensor 的建立和常用方法 一 張量 tensor 的基本建立及其型別 import torch 匯入pytorch包 import numpy as np torch.version 檢視版本號1.張量 tensor 函式建立方法 張量 tensor 函式建立方法 t ...
Pytorch學習筆記
import torch import numpy as np 一維張量索引 t1 torch.arange 1,11 print t1 0 item 注 張量索引出來的結果還是零維張量,而不是單獨的數。要轉化成單獨的數,需要使用item 方法。t1 1 8 冒號分隔,表示對某個區域進行索引,也就是...