matlab這幾年在人工智慧這塊兒也越做越好了,最近為了熟悉matlab如何搭建神經網路,自己做了乙個手寫體識別實驗,記錄一下。
實驗任務非常簡單,網路搭的也非常隨意,不合理的地方也懶得改,旨在走通matlab搭建神經網路的流程。
首先,資料集為mnist資料集
我已經把資料按類別分好,分為train和test,底下又都有十個子資料夾存放手寫體影象。
網路訓練**如下:
clear;close all;clc;
%% 資料讀取、增強
%讀取訓練集
path_train =
'd:\work\過期檔案\手寫體識別\mnist\train'
;%訓練集路徑
folders_train =
fullfile
(path_train,);
%讀取子目錄
%讀取所有影象路徑
[imdstrain,imdsvalidation]
=spliteachlabel
(imds_train,
0.9,
0.1)
;%拆分出驗證集
%讀取測試集
path_test =
'd:\work\過期檔案\手寫體識別\mnist\test'
%影象增強
pixelrange =[-
22];
%平移範圍
scalerange =
[0.9
1.1]
;%縮放範圍
imageaugmenter =
imagedataaugmenter(.
..'randxtranslation'
,pixelrange,..
.'randytranslation'
,pixelrange,..
.'randxscale'
,scalerange,..
.'randyscale'
,scalerange)
;%定義影象增強器
augimdstrain =
augmentedimagedatastore([
28,28]
,imds_train,..
.'dataaugmentation'
,imageaugmenter)
;%影象增強
%% 設計(或者讀取)網路
layers =
[imageinputlayer([
28281]
,"name"
,"imageinput"
)convolution2dlayer([
55],
32,"name"
,"conv_1"
,"padding"
,"same"
,"stride",[
22])
relulayer
("name"
,"relu_1"
)batchnormalizationlayer
("name"
,"batchnorm_1"
)convolution2dlayer([
33],
32,"name"
,"conv_2"
,"padding"
,"same"
)relulayer
("name"
,"relu_2"
)fullyconnectedlayer
(512
,"name"
,"fc_1"
)batchnormalizationlayer
("name"
,"batchnorm_2"
)relulayer
("name"
,"relu_3"
)fullyconnectedlayer(10
,"name"
,"fc_2"
)softmaxlayer
("name"
,"softmax"
)classificationlayer
("name"
,"classoutput")]
;%analyzenetwork
(layers)
%分析網路
%% 訓練網路
options =
trainingoptions
('sgdm',.
..'minibatchsize'
,512,.
..'maxepochs',1
,...
'initiallearnrate'
,1e-2,.
..'shuffle'
,'every-epoch',.
..'validationdata'
,imdsvalidation,..
.'validationfrequency',3
,...
'verbose',1
,...
'plots'
,'training-progress');
%設定訓練策略
trainednet =
trainnetwork
(augimdstrain,layers,options)
;%訓練
%% 測試模型
[ypred,probs]
=classify
(trainednet,imds_test)
; accuracy =
mean
(ypred == imds_test.labels)
這裡面,用到了一些函式,一些重要的用法我都寫在其他部落格裡了,這兒只大致說一下有什麼用
訓練結果:
如果需要處理好的資料集,可以留下郵箱~
最後在說明一下,網路是隨便搭的,不要用!!只是學習matlab用的
以上這些希望會對你有所幫助
深度學習之手寫數字識別
mnist是乙個入門級的計算機視覺資料集,它包含各種手寫數字 它也包含每一張對應的標籤,告訴我們這個是數字幾。比如,上面這四張的標籤分別是5,0,4,1。mnist資料集的官網是 yann lecun s website 這份 然後用下面的 匯入到你的專案裡面,也可以直接複製貼上到你的 檔案裡面。i...
Python基礎學習之手寫識別演算法
k 近鄰演算法 from numpy import python裡的計算包numpy import operator 運算子模組 import os 資料準備所需的函式 def createdataset group array 1.0,1.1 1.0,1.0 0,0 0,0.1 labels a ...
深度學習實踐
選擇合適的損失函式 mini batch 選擇不同的啟用函式 改變學習速度 momentum early stopping 正則化 dropout 改變網路架構 選擇合適的損失函式 mini batch當資料集很大時,訓練演算法是非常慢的,和 batch 梯度下降相比,使用 mini batch 梯...