pytorch 網路模型的設計 與繼承

2021-10-05 13:49:20 字數 1354 閱讀 8868

參考:來自 

深度特徵提取入門 

1. 網路模型 senet   參照 

#  model implement

# method1 從models庫直接匯出18層

from torchvision import models

resnet18 = models.resnet18( pretrained=0)

#print(resnet18)

# method2 # 按需設計層數

from torchvision.models import resnet

from resnet import basicblock

# resnet18 = resnet( basicblock, [2, 2, 2, 2], num_classes=10)

#print(resnet18)

# modify it as senet 按需設計結構

from senet.se_resnet import sebasicblock

from senet.se_module import selayer

senet9 = resnet(sebasicblock, [1, 1, 1, 1], num_classes=10)

print( senet9 )

# 重新設計 基礎塊 sebasicblock

在resnet網路結構的構建中有很多重複的子結構, resnet18和resnet34用的是基礎版block, 此時網路還不深,不太需要考慮模型的效率,而當網路加深到52,101,152層時則有必要引入bottleneck結構。

再進一進改進子塊,設計 sebasicblock, 整體結構採用 resnet, 只是塊嵌入 se , 因此,重設計基礎塊 sebasicblock,層selayer 。

1) 構建resnet:  繼承pytorch中網路的基類:torch.nn.module, 重寫初始化__init__和 forward方   法;

2)構建 基礎塊 sebasicblock:  也是繼承pytorch中網路的基類:torch.nn.module, 重寫初始化__init__和 forward方 法;

在初始化__init__中主要是定義一些層的引數。

forward方法中主要是定義資料在層之間的流動順序,也就是層的連線順序。

然後  生成網路  res18=resnet( basicblock, [2, 2, 2, 2], num_classes=10)

senet9 = resnet(sebasicblock, [1, 1, 1, 1], num_classes=10)

2. 網路模型 視覺化

pytorch設計模型

1.nn.modulelist使對於加入其中的子模組,不必在forward中依次呼叫 nn.sequentialt使對於加入其中的子模組在forward中可以通過迴圈實現呼叫 2.pytorch中nn.modulelist和nn.sequential的用法和區別 nn.sequential定義的網路...

pytorch 模型設計專題丨

複製層的初始化問題,使用deepcopyimport torch import torch.nn as nn import copy class mymodule nn.module def init self super mymodule,self init self.layer nn.linea...

pytorch載入模型與凍結

weights torch.load path with open a.pkl wb as f pickle.dump score dict,f weights pickle.load f 直接載入 model.load state dict weights 字典生成式載入 self.load st...