Data Augmentation and Character Recognition

2021-10-06 12:52:54

Data augmentation

In PyTorch, data is wrapped in a `Dataset` and read in parallel through a `DataLoader`.
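Conceptually, a map-style `Dataset` only has to answer two questions: how many samples there are (`__len__`) and what sample number `index` is (`__getitem__`). The `DataLoader` then walks those indices in batches. The pure-Python sketch below imitates that contract with a hypothetical `ToyDataset` and `simple_batches`; it is a simplification, not PyTorch's code. The real `DataLoader` also collates samples into tensors, can shuffle, and only reads in parallel worker processes when `num_workers > 0`.

```python
class ToyDataset:
    """Hypothetical in-memory dataset exposing the two methods a map-style Dataset needs."""

    def __init__(self, samples, labels):
        self.samples = samples
        self.labels = labels

    def __getitem__(self, index):
        # Return one (sample, label) pair for the given index
        return self.samples[index], self.labels[index]

    def __len__(self):
        return len(self.samples)


def simple_batches(dataset, batch_size):
    """Yield lists of samples in index order, like a DataLoader with shuffle=False."""
    batch = []
    for i in range(len(dataset)):
        batch.append(dataset[i])
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # last, possibly smaller batch (DataLoader's default drop_last=False)
        yield batch


ds = ToyDataset(["a", "b", "c", "d", "e"], [0, 1, 2, 3, 4])
print([len(b) for b in simple_batches(ds, 2)])  # [2, 2, 1]
```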

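```python
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path    # list of image file paths
        self.img_label = img_label  # list of digit-label lists, one per image
        self.transform = transform  # optional torchvision transform pipeline

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # In the original SVHN release, class 10 denotes digit 0;
        # here 10 is used as the padding class for positions with no character
        lbl = np.array(self.img_label[index], dtype=np.int64)
        # Pad (or truncate) to 6 positions, one per classifier head of the model below
        lbl = list(lbl) + (6 - len(lbl)) * [10]
        return img, torch.from_numpy(np.array(lbl[:6], dtype=np.int64))

    def __len__(self):
        return len(self.img_path)
```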
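The label handling in `__getitem__` boils down to padding every label sequence to a fixed number of positions with class 10 and cutting off anything longer. Isolated as a hypothetical helper `pad_label` (the name and the `max_len` parameter are illustrative, not from the post), the step looks like this:

```python
import numpy as np

BLANK = 10  # padding class for "no character", as used in __getitem__


def pad_label(digits, max_len, blank=BLANK):
    """Pad (or truncate) a list of digit labels to exactly max_len positions."""
    # If digits is longer than max_len, the multiplier is negative and adds nothing
    padded = list(digits) + (max_len - len(digits)) * [blank]
    return np.array(padded[:max_len], dtype=np.int64)


print(pad_label([1, 9], 6).tolist())                 # [1, 9, 10, 10, 10, 10]
print(pad_label([1, 2, 3, 4, 5, 6, 7], 6).tolist())  # [1, 2, 3, 4, 5, 6]
```

The fixed length matters because the default `DataLoader` collation stacks the label tensors of a batch, which only works when they all have the same shape.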

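```python
import torch
from torchvision import transforms

# train_path / train_label: image paths and their digit labels, assumed to be prepared beforehand
train_loader = torch.utils.data.DataLoader(
    SVHNDataset(train_path, train_label,
                transforms.Compose([
                    transforms.Resize((64, 128)),           # (height, width)
                    transforms.ColorJitter(0.3, 0.3, 0.2),  # brightness, contrast, saturation
                    transforms.RandomRotation(5),           # rotate by up to 5 degrees either way
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                ])),
    batch_size=10,  # number of samples per batch
    shuffle=False,  # whether to shuffle the sample order (True is usual for training)
)
```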
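To make the augmentation parameters concrete: `Resize((64, 128))` produces 64x128 (height x width) images, `RandomRotation(5)` picks an angle uniformly from [-5, 5] degrees, and a float `b` passed to `ColorJitter` for brightness, contrast or saturation means a factor drawn from [max(0, 1 - b), 1 + b]. `ToTensor` scales pixel values from 0-255 to 0-1, and `Normalize` then computes `(x - mean) / std` per channel using the commonly used ImageNet statistics. The NumPy sketch below reproduces only that arithmetic (the helper names are hypothetical); it performs no random sampling.

```python
import numpy as np

# Channel statistics passed to transforms.Normalize in the pipeline above
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def jitter_range(strength):
    """Range ColorJitter samples a brightness/contrast/saturation factor from, for a float strength."""
    return (max(0.0, 1.0 - strength), 1.0 + strength)


def to_tensor_then_normalize(rgb_uint8):
    """Apply ToTensor's 0-255 -> 0-1 scaling, then Normalize, to one RGB pixel."""
    x = np.asarray(rgb_uint8, dtype=np.float64) / 255.0
    return (x - MEAN) / STD


print(jitter_range(0.3))  # brightness and contrast: (0.7, 1.3)
print(jitter_range(0.2))  # saturation: (0.8, 1.2)
print(to_tensor_then_normalize([255, 255, 255]))  # a pure white pixel after normalization
```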

Building a CNN model in PyTorch
```python
import torch.nn as nn


# Define the model
class SVHN_Model1(nn.Module):
    def __init__(self):
        super(SVHN_Model1, self).__init__()
        # CNN feature-extraction module
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2)),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2)),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # One classifier per character position; 32*3*7 features for a 64x128 input,
        # 11 classes = digits 0-9 plus padding class 10
        self.fc1 = nn.Linear(32 * 3 * 7, 11)
        self.fc2 = nn.Linear(32 * 3 * 7, 11)
        self.fc3 = nn.Linear(32 * 3 * 7, 11)
        self.fc4 = nn.Linear(32 * 3 * 7, 11)
        self.fc5 = nn.Linear(32 * 3 * 7, 11)
        self.fc6 = nn.Linear(32 * 3 * 7, 11)

    def forward(self, img):
        feat = self.cnn(img)
        feat = feat.view(feat.shape[0], -1)  # flatten to (batch, 32*3*7)
        c1 = self.fc1(feat)
        c2 = self.fc2(feat)
        c3 = self.fc3(feat)
        c4 = self.fc4(feat)
        c5 = self.fc5(feat)
        c6 = self.fc6(feat)
        return c1, c2, c3, c4, c5, c6
```
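The input size of the six `nn.Linear` layers, `32*3*7`, follows from the 64x128 input produced by `Resize((64, 128))`. Along each axis, both `Conv2d` and `MaxPool2d` (dilation 1) give out = floor((in + 2*padding - kernel) / stride) + 1, and `MaxPool2d(2)` uses stride 2 by default. A small pure-Python check of that arithmetic (the helper names are illustrative):

```python
def conv_out(size, kernel, stride, padding=0):
    """Output length of Conv2d/MaxPool2d along one axis (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def feature_shape(h, w):
    # conv 3x3 stride 2 -> maxpool 2 -> conv 3x3 stride 2 -> maxpool 2, as in self.cnn
    for kernel, stride in [(3, 2), (2, 2), (3, 2), (2, 2)]:
        h, w = conv_out(h, kernel, stride), conv_out(w, kernel, stride)
    return h, w


h, w = feature_shape(64, 128)
print(h, w, 32 * h * w)  # 3 7 672
```

Changing the `Resize` target changes this value, so the `Linear` input sizes would have to be updated to match.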

Model training

```python
import torch
import torch.nn as nn

model = SVHN_Model1()  # the model defined above

# Loss function
criterion = nn.CrossEntropyLoss()
# Optimizer (learning rate 0.005)
optimizer = torch.optim.Adam(model.parameters(), 0.005)

# Store the loss and the accuracy of the first character
loss_plot, c0_plot = [], []
# Train for 10 epochs
for epoch in range(10):
    for data in train_loader:
        c0, c1, c2, c3, c4, c5 = model(data[0])
        loss = criterion(c0, data[1][:, 0]) + \
            criterion(c1, data[1][:, 1]) + \
            criterion(c2, data[1][:, 2]) + \
            criterion(c3, data[1][:, 3]) + \
            criterion(c4, data[1][:, 4]) + \
            criterion(c5, data[1][:, 5])
        loss /= 6
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_plot.append(loss.item())
        c0_plot.append((c0.argmax(1) == data[1][:, 0]).sum().item() * 1.0 / c0.shape[0])

    print(epoch)
```
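The loss adds one cross-entropy term per character position and divides by 6, i.e. it averages the per-head losses, while `c0_plot` tracks how often the first head's argmax matches the first label. A NumPy sketch of both quantities, assuming logits shaped (batch, 11) per head and labels shaped (batch, 6); the function names are hypothetical, and `cross_entropy` mirrors the default mean reduction of `nn.CrossEntropyLoss`:

```python
import numpy as np


def cross_entropy(logits, targets):
    """Mean over the batch of -log softmax(logits)[target]."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()


def multi_head_loss(head_logits, labels):
    """Average of per-position losses: head i is scored against column i of labels."""
    losses = [cross_entropy(logits, labels[:, i]) for i, logits in enumerate(head_logits)]
    return sum(losses) / len(losses)


def position_accuracy(logits, targets):
    """Fraction of samples whose argmax prediction matches the target (like c0_plot)."""
    return float((logits.argmax(axis=1) == targets).mean())


batch, n_classes, n_heads = 4, 11, 6
labels = np.array([[1, 9, 10, 10, 10, 10]] * batch)
uniform = [np.zeros((batch, n_classes)) for _ in range(n_heads)]
print(multi_head_loss(uniform, labels))  # ln(11), about 2.3979
```

With all-zero logits every head assigns equal probability to the 11 classes, so each term, and therefore the average, equals ln 11. A training loss that stays around that level over many steps suggests the model is not learning.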

Problem during training: the model never actually trains; the loop just runs straight through to the end. Why?
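The post does not say what caused this. One possibility worth ruling out (an assumption, not a confirmed diagnosis) is an empty dataset: if `train_path` is empty, `for data in train_loader` yields no batches, so only `print(epoch)` runs and all 10 epochs finish immediately. A mismatch between the label length and the six heads would instead raise an `IndexError` at `data[1][:, 5]`, so it would not finish silently. The sketch below assumes, hypothetically, that the image paths come from `glob.glob`; `list_images` and `count_batches` are illustrative helpers, not part of the original code.

```python
import glob


def count_batches(loader):
    """Count how many batches one pass over a loader (or any iterable) yields."""
    return sum(1 for _ in loader)


def list_images(pattern):
    """Return sorted image paths, failing loudly instead of silently returning []."""
    paths = sorted(glob.glob(pattern))
    if not paths:
        raise FileNotFoundError(f"no files match {pattern!r}; the loader would yield 0 batches")
    return paths


# An empty loader turns the inner training loop into a no-op
print(count_batches([]))  # 0
```

In the actual script, printing `len(train_loader.dataset)` and `len(train_loader)` (number of samples and number of batches) before the epoch loop gives the same information.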
