pytorch 模型部分引數的載入
pytorch中,只匯入部分模型引數的做法
correct way to freeze layers
pytorch自由載入部分模型引數並凍結
pytorch凍結部分引數訓練另一部分
pytorch更新部分網路,其他不更新
pytorch固定部分引數(只訓練部分層)
如果載入現有模型的所有引數,我們常使用的是**如下:
torch.load(model.state_dict())
在訓練過程中,我們常常會使用預訓練模型,有時我們是在自己的模型中加入別人的某些模組,或者對別人的模型進行區域性修改,這個時候再使用torch.load(model.state_dict())
,就會出現類似這些的錯誤:runtimeerror: error(s) in loading state_dict for net:missing key(s) in state_dict:***
。出現這個錯誤就是某些引數缺失或者不匹配。
現有模型中引入的那部分網路結構的網路層的名稱和結構保持不變,這時候載入引數的**很簡單。
# 載入引入的網路模型
model_path = "***"
checkpoint = torch.load(os.path.join(model_path, map_location=torch.device('cpu'))
pretrained_dict = checkpoint['net']
# 獲取現有模型的引數字典
model_dict = model.state_dict()
# 獲取兩個模型相同網路層的引數字典
state_dict =
# update必不可少,實現相同key的value同步
model_dict.update(state_dict)
# 載入模型部分引數
model.load_state_dict(model_dict)
這個時候再直接使用上面的載入方法,會導致部分key的value無法實現更新。
我就曾在這個位置犯過很嚴重的錯誤。首先我定義了attentionresnet
,這是乙個unet來實現影象分割,然後在另乙個模型中我使用了這個模型self.attention_map = attentionresnet(***)
。因為我在引用的過程中並沒有對attentionresnet
那部分**進行修改,所以本能的覺得這部分網路層的名稱是相同的,所以載入這部分引數時,我直接使用了上面的方法。這個錯誤隱藏了差不多乙個星期。直到我開始凍結這部分引數進行訓練時,發現情況不對。因為我在輸出attention_map
的特徵圖時,我發現它是一張全黑圖(畫素全為0),這表示載入的引數不對,然後我嘗試輸出pretrained_dict
時,它是乙個空字典。然後繼續輸出pretrained_dict.keys()
(未修改之前的pretrained_dict
)和model_dict.keys()
發現預期相同的那部分key中都多了一部分attention_map.
。問題主要出在self.attention_map = attentionresnet(***)
這一句,它使原有的網路層名稱都加了個字首attention_map.
,知道了錯誤,修改起來很簡單。
# 載入引入的網路模型
model_path = "***"
checkpoint = torch.load(os.path.join(model_path, map_location=torch.device('cpu'))
pretrained_dict = checkpoint['net']
# 獲取現有模型的引數字典
model_dict = model.state_dict()
# 獲取兩個模型相同網路層的引數字典
state_dict =
# update必不可少,實現相同key的value同步
model_dict.update(state_dict)
# 載入模型部分引數
model.load_state_dict(model_dict)
其實我這個位置的修改有點投機,更加常規的方法是:
引用自pytorch自由載入部分模型引數並凍結
我們看出只要構建乙個字典,使得字典的keys和我們自己建立的網路相同,我們在從各種預訓練網路把想要的引數對著新的keys填進去就可以有乙個新的state_dict了,這樣我們就可以load這個新的state_dict,這是最普適的方法適用於所有的網路變化。先輸出兩個模型的引數字典,觀察需要載入的那部分引數所處的位置,然後利用for迴圈構建新的字典。
將需要固定的那部分引數的requires_grad
置為false.
在優化器中加入filter根據requires_grad
進行過濾.
ps: 解決attributeerror: 『nonetype』 object has no attribute 『data』
問題的一種思路就是凍結引數,參考部落格
**如下:
# requires_grad置為false
for p in net.***.parameters():
p.requires_grad = false
# filter
optimizer.sgd(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
當需要凍結的那部分引數的網路層名稱不太明確時,可以採用pytorch凍結部分引數訓練另一部分的思路,列印出所有網路層,通過引數名稱進行凍結。 pytorch 模型部分引數的載入
如果對預訓練模型的結構進行了一些改動,在訓練的開始前希望載入未改動部分的引數,如將resnet18的第一層卷積層conv1的輸入由3通道改為6通道的new conv1,將分類層fc的1000類輸出改為2類輸出的new fc,注意 要改一下名字與原來的不同。匯入模型 mynet resnet 然後就載...
pytorch 模型部分引數的載入
如果對預訓練模型的結構進行了一些改動,在訓練的開始前希望載入未改動部分的引數,如將resnet18的第一層卷積層conv1的輸入由3通道改為6通道的new conv1,將分類層fc的1000類輸出改為2類輸出的new fc,注意 要改一下名字與原來的不同。匯入模型 mynet resnet18 然後...
pytorch 更新部分引數(凍結引數)注意事項
實驗的pytorch版本1.2.0 在訓練過程中可能需要固定一部分模型的引數,只更新另一部分引數。有兩種思路實現這個目標,乙個是設定不要更新引數的網路層為false,另乙個就是在定義優化器時只傳入要更新的引數。當然最優的做法是,優化器中只傳入requires grad true的引數,這樣占用的記憶...