二、使用函式統計模型參數量:
安裝:pip install torchstat
torchstat github 原始碼頁面
例子:
from torchstat import stat
model = model(
)stat(model,(3
,1280
,1280
))
輸出:會輸出模型各層網路的資訊,最後進行總結統計。
安裝:pip install ptflops
ptflops githhub原始碼頁面
例子:
import torchvision.models as models
import torch
from ptflops import get_model_complexity_info
with torch.cuda.device(0)
: net = models.densenet161(
) macs, params = get_model_complexity_info(net,(3
,224
,224
), as_strings=
true
, print_per_layer_stat=
true
, verbose=
true
)print
(' '
.format
('computational complexity: '
, macs)
)print
(' '
.format
('number of parameters: '
, params)
)
輸出:同樣會輸出模型各層的資訊,最後總結統計 參數量 和 flops。
注意:使用第三方工具時, 網路中有些層可能會不支援計算。
計算模型參數量 與 可訓練參數量:
def
get_parameter_number
(model)
: total_num =
sum(p.numel(
)for p in model.parameters())
trainable_num =
sum(p.numel(
)for p in model.parameters(
)if p.requires_grad)
return
result = get_parameter_number(model)
print
(result[
'total'
],result[
'trainable'])
#列印參數量
檢視模型各層引數(Pytorch
這個實驗用到的資料集是mnist資料集,維度是1 28 28 import torch.nn as nn class cnn nn.module def init self super cnn,self init 卷積層 self.conv1 nn.sequential in channels 1,...
pytorch 模型部分引數的載入
如果對預訓練模型的結構進行了一些改動,在訓練的開始前希望載入未改動部分的引數,如將resnet18的第一層卷積層conv1的輸入由3通道改為6通道的new conv1,將分類層fc的1000類輸出改為2類輸出的new fc,注意 要改一下名字與原來的不同。匯入模型 mynet resnet 然後就載...
pytorch 模型部分引數的載入
如果對預訓練模型的結構進行了一些改動,在訓練的開始前希望載入未改動部分的引數,如將resnet18的第一層卷積層conv1的輸入由3通道改為6通道的new conv1,將分類層fc的1000類輸出改為2類輸出的new fc,注意 要改一下名字與原來的不同。匯入模型 mynet resnet18 然後...