torchvision中提供了很多訓練好的模型,這些模型是在1000類,224*224的imagenet中訓練得到的,很多時候不適合我們自己的資料,可以根據需要進行修改。
1、類別不同
#coding=utf-8
import
torchvision.models as models
#呼叫模型
model = models.resnet50(pretrained=true)
#提取fc層中固定的引數
fc_features =model.fc.in_features
#修改類別為9
model.fc = nn.linear(fc_features, 9)
2、新增層後,載入部分引數
model =...model_dict =model.state_dict()
#1. filter out unnecessary keys
pretrained_dict =
#2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
#3. load the new state dict
model.load_state_dict(model_dict)
參考:
pytorch 載入預訓練模型
pytorch的torchvision中給出了很多經典的預訓練模型,模型的引數和權重都是在imagenet資料集上訓練好的 載入模型 方法一 直接使用預訓練模型中的引數 import torchvision.models as models model models.resnet18 pretrai...
Pytorch 修改預訓練網路結構
我們以 inceptionv3 為例 pytorch裡我們如何使用設計好的網路結構,比如inceptionv3 import torchvision.models as models inception models.inception v3 pretrained true pytorch提供了個叫...
pytorch載入預訓練模型後,訓練指定層
1 有了已經訓練好的模型引數,對這個模型的某些層做了改變,如何利用這些訓練好的模型引數繼續訓練 pretrained params torch.load pretrained model model the new model model.load state dict pretrained par...