pytorch載入預訓練模型後,訓練指定層

2021-09-29 21:15:22 字數 1976 閱讀 7029

1、有了已經訓練好的模型引數,對這個模型的某些層做了改變,如何利用這些訓練好的模型引數繼續訓練:

pretrained_params = torch.load('pretrained_model')

model = the_new_model(***)

model.load_state_dict(pretrained_params.state_dict(), strict=false)

strict=false 使得預訓練模型引數中和新模型對應上的引數會被載入,對應不上或沒有的引數被拋棄。

2、如果載入的這些引數中,有些引數不要求被更新,即固定不變,不參與訓練,需要手動設定這些引數的梯度屬性為fasle,並且在optimizer傳參時篩選掉這些引數:

# 載入預訓練模型引數後...

for name, value in model.named_parameters():

if name 滿足某些條件:

value.requires_grad = false

# setup optimizer

params = filter(lambda p: p.requires_grad, model.parameters())

optimizer = torch.optim.adam(params, lr=1e-4)

將滿足條件的引數的 requires_grad 屬性設定為false, 同時 filter 函式將模型中屬性 requires_grad = true 的引數帥選出來,傳到優化器(以adam為例)中,只有這些引數會被求導數和更新。

3、如果載入的這些引數中,所有引數都更新,但要求一些引數和另一些引數的更新速度(學習率learning rate)不一樣,最好知道這些引數的名稱都有什麼:

# 載入預訓練模型引數後...

for name, value in model.named_parameters():

print(name)

# 或print(model.state_dict().keys())

假設該模型中有encoder,viewer和decoder兩部分,引數名稱分別是:

'encoder.visual_emb.0.weight',

'encoder.visual_emb.0.bias',

'viewer.bd.wsi',

'viewer.bd.bias',

'decoder.core.layer_0.weight_ih',

'decoder.core.layer_0.weight_hh',

假設要求encode、viewer的學習率為1e-6, decoder的學習率為1e-4,那麼在將引數傳入優化器時:

ignored_params = list(map(id, model.decoder.parameters()))

base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())

optimizer = torch.optim.adam([,

],lr=1e-4, momentum=0.9)

**的結果是除decoder引數的learning_rate=1e-4 外,其他引數的額learning_rate=1e-6。 

在傳入optimizer時,和一般的傳參方法torch.optim.adam(model.parameters(), lr=***)不同,引數部分用了乙個list, list的每個元素有paramslr兩個鍵值。如果沒有lr則應用adam的lr屬性。adam的屬性除了lr, 其他都是引數所共有的(比如momentum)。

參考:pytorch官方文件

pytorch 載入預訓練模型

pytorch的torchvision中給出了很多經典的預訓練模型,模型的引數和權重都是在imagenet資料集上訓練好的 載入模型 方法一 直接使用預訓練模型中的引數 import torchvision.models as models model models.resnet18 pretrai...

pytorch載入預訓練模型後,訓練指定層

1 有了已經訓練好的模型引數,對這個模型的某些層做了改變,如何利用這些訓練好的模型引數繼續訓練 pretrained params torch.load pretrained model model the new model model.load state dict pretrained par...

pytorch載入預訓練模型後,訓練指定層

1 有了已經訓練好的模型引數,對這個模型的某些層做了改變,如何利用這些訓練好的模型引數繼續訓練 pretrained params torch.load pretrained model model the new model model.load state dict pretrained par...