1import
torch
2import
os,glob
3import
random,csv45
from torch.utils.data import
dataset,dataloader67
from torchvision import
transforms
8from pil import
image910
class
pokemon(dataset):
11'''
12@param
13root:儲存的根路徑
1415
mode:train或者test模式
16'''
17def
__init__
(self,root,resize,mode):
18 super(pokemon,self).__init__
()19
20 self.root =root
21 self.resize =resize
2223
#字典型別key:name value:label
24 self.name2label ={}25#
listdir返回順序不固定,用sorted將它固定,因為排序一次之後就固定了
26for name in
sorted(os.listdir(os.path.join(root))):
27if
notos.path.isdir(os.path.join(root,name)):
28continue
2930 self.name2label[name] =len(self.name2label.keys())
3132
#print(self.name2label)
3334
#image_path + image_label
35 self.images,self.labels = self.load_csv('
images.csv')
3637
if mode == '
train
': #
60%38 self.images = self.images[:int(0.6*len(self.images))]
39 self.labels = self.labels[:int(0.6*len(self.labels))]
40elif mode == '
val': #
20%41 self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
42 self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
43elif mode == '
test
': #
20% = 80% ->100%
44 self.images = self.images[int(0.8*len(self.images)):]
45 self.labels = self.labels[int(0.8*len(self.labels)):]
4647
defload_csv(self,filename):
4849
#如果不存在再寫入,存在的話直接讀取就可以了
保證images和labels一一對應,長度相等
84assert len(images) ==len(labels)
85return
images,labels
8687
def__len__
(self):
8889
return
len(self.images)
9091
defdenormalize(self,x_hat):
9293 mean=[0.485,0.456,0.406]
94 std=[0.229,0.224,0.225]
9596
#x_hat = (x-mean)/std97#
x = x_hat*std+mean98#
x: [c,h,w]99#
mean: [3] --> [3,1,1]
100 mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
101 std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
102103 x = x_hat*std +mean
104105
return
x106
107108
def__getitem__
(self,idx):
109#
idx~[0~len(images)]
110#
self.images,self.labels
111#
img: pokemon\\bulbasaur\\00000000.png'
112#
label: 0
113 img,label =self.images[idx],self.labels[idx]
114115 tf =transforms.compose([
116lambda x:image.open(x).convert('
rgb'), #
string path --> image data
117 transforms.resize((int(self.resize*1.25),int(self.resize*1.25))),
118 transforms.randomrotation(15),
119transforms.centercrop(self.resize),
120transforms.totensor(),
121 transforms.normalize(mean=[0.485,0.456,0.406],
122 std=[0.229,0.224,0.225])
123])
124125 img =tf(img)
126 label =torch.tensor(label)
127128
return img,label
Pytorch 學習筆記
本渣的pytorch 逐步學習鞏固經歷,希望各位大佬指正,寫這個部落格也為了鞏固下記憶。a a b c 表示從 a 取到 b 步長為 c c 2 則表示每2個數取1個 若為a 1 1 1 即表示從倒數最後乙個到正數第 1 個 最後乙個 1 表示倒著取 如下 陣列的第乙個為 0 啊,第 0 個!彆扭 ...
Pytorch學習筆記
資料集 penn fudan資料集 在學習pytorch官網教程時,作者對penn fudan資料集進行了定義,並且在自定義的資料集上實現了對r cnn模型的微調。此篇筆記簡單總結一下pytorch如何實現定義自己的資料集 資料集必須繼承torch.utils.data.dataset類,並且實現 ...
Pytorch學習筆記
lesson 1.張量 tensor 的建立和常用方法 一 張量 tensor 的基本建立及其型別 import torch 匯入pytorch包 import numpy as np torch.version 檢視版本號1.張量 tensor 函式建立方法 張量 tensor 函式建立方法 t ...