Pytorch學習筆記(深度之眼)(4)之模型容器

2021-10-13 05:09:32 字數 3824 閱讀 9842

1)# 展開,形狀變換

x = self.classifier(x)

return x

在__init__()模組中,採用sequential()對卷積層和池化層進行包裝,把sequential類屬性賦予feature,然後對三個全連線層進行sequential包裝,賦值為classifier類屬性,這就完成了模型構建的第一步。foward構建了前向傳播過程,只有三行,非常簡潔。

我們用sequential構建lenet,lenet中有乙個features,型別為sequential,sequential中有六個網路層,以序號(0)-(5)命名;還有乙個classifier,一樣是sequential。這裡存在乙個問題,這裡的網路層是沒有名字的,是通過序號索引的,如果在乙個上千層的網路中,很難採用序號去進行索引每乙個網路層。這時候可以對網路層進行命名,這就是第二種sequential的方法,對sequential輸入乙個頭有序的字典,以這種方式構建網路,**如下所示。

class

lenetsequentialorderdict

(nn.module)

:def

__init__

(self, classes)

:super

(lenetsequentialorderdict, self)

.__init__(

) self.features = nn.sequential(ordereddict())

self.classifier = nn.sequential(ordereddict())

defforward

(self, x)

: x = self.features(x)

x = x.view(x.size()[

0],-

1)x = self.classifier(x)

return x

nn.sequential是nn.module的容器,用於按順序包裝一組網路層:

順序性:各網路層之間嚴格按照順序構建;

自帶forward():自帶的forward裡,通過for迴圈依次執行前向傳播運算;

)])# 列表生成式,採用for迴圈

defforward

(self, x)

:for i, linear in

enumerate

(self.linears)

: x = linear(x)

return x

net = modulelist(

)通過單步除錯,進入了nn.modulelist的init()函式當中,通過**可以發現,如果mpdules不是空的話,則會不斷進行疊加,得到linear。

可以看到,通過nn.modulelist可以簡便地建立乙個二十層的全連線網路模型。

)nn.sequential:順序性:各網路層之間嚴格按順序執行,常用於block構建;

nn.modulelist:迭代性,常用於大量重複網的構建,通過for迴圈實現重複構建;

nn.moduledict:索引性,常用於可選擇的網路層;

這句話一般出現在model類的forward函式中,具體位置一般都是在呼叫分類器之前。分類器是乙個簡單的nn.linear()結構,輸入輸出都是維度為一的值,x = x.view(x.size(0), -1) 這句話的出現就是為了將前面多維度的tensor展平成一維。

iew()函式的功能根reshape類似,用來轉換size大小。x = x.view(batchsize, -1)中batchsize指轉換後有幾行,而-1指在不告訴函式有多少列的情況下,根據原tensor資料和batchsize自動分配列數。

PyTorch 深度學習 筆記

方差 偏差 線性回歸來確定兩種或兩種以上變數間相互依賴的定量關係。線性回歸對於輸入x和輸出y有乙個對映 類似ax b 而我們是訓練a b這兩個引數。以下 使用pytorch建立乙個線性的模型來對其進行擬合,即訓練過程。def linear ex x np.random.rand 256 noise ...

深度學習 Pytorch學習筆記(一)

pytorch中需要自己定義網路模型,該模型需要封裝到乙個自定義的類中,該類只是乙個子類,其繼承的父類為torch.nn.module,可見如下樹形結構圖 module實際又是繼承了object類,關於為什麼要繼承object類有興趣的可以看這篇部落格mro演算法 也就是說,自定義的模型除了要有 i...

深度學習 Pytorch學習筆記(五)

pytorch實現卷積神經網路 執行效率相對於keras慢太多 import torch import warnings import torchvision from torchvision.datasets import mnist from torch.utils.data import da...