pytorch 模型參數量 FLOPs統計方法

2021-10-25 15:36:57 字數 1555 閱讀 6612

二、使用函式統計模型參數量:

安裝:pip install torchstat

torchstat github 原始碼頁面

例子:

from torchstat import stat

model = model(

)stat(model,(3

,1280

,1280

))

輸出:會輸出模型各層網路的資訊,最後進行總結統計。

安裝:pip install ptflops

ptflops githhub原始碼頁面

例子:

import torchvision.models as models

import torch

from ptflops import get_model_complexity_info

with torch.cuda.device(0)

: net = models.densenet161(

) macs, params = get_model_complexity_info(net,(3

,224

,224

), as_strings=

true

, print_per_layer_stat=

true

, verbose=

true

)print

(' '

.format

('computational complexity: '

, macs)

)print

(' '

.format

('number of parameters: '

, params)

)

輸出:同樣會輸出模型各層的資訊,最後總結統計 參數量 和 flops。

注意:使用第三方工具時, 網路中有些層可能會不支援計算。

計算模型參數量 與 可訓練參數量:

def

get_parameter_number

(model)

: total_num =

sum(p.numel(

)for p in model.parameters())

trainable_num =

sum(p.numel(

)for p in model.parameters(

)if p.requires_grad)

return

result = get_parameter_number(model)

print

(result[

'total'

],result[

'trainable'])

#列印參數量

檢視模型各層引數(Pytorch

這個實驗用到的資料集是mnist資料集,維度是1 28 28 import torch.nn as nn class cnn nn.module def init self super cnn,self init 卷積層 self.conv1 nn.sequential in channels 1,...

pytorch 模型部分引數的載入

如果對預訓練模型的結構進行了一些改動,在訓練的開始前希望載入未改動部分的引數,如將resnet18的第一層卷積層conv1的輸入由3通道改為6通道的new conv1,將分類層fc的1000類輸出改為2類輸出的new fc,注意 要改一下名字與原來的不同。匯入模型 mynet resnet 然後就載...

pytorch 模型部分引數的載入

如果對預訓練模型的結構進行了一些改動,在訓練的開始前希望載入未改動部分的引數,如將resnet18的第一層卷積層conv1的輸入由3通道改為6通道的new conv1,將分類層fc的1000類輸出改為2類輸出的new fc,注意 要改一下名字與原來的不同。匯入模型 mynet resnet18 然後...