這裡我們以pytorch自帶的預訓練模型為例來講解:
# load the pretrained model
alexnet = models.alexnet(pretrained=true).cuda()
print(alexnet)
alexnet (
(features): sequential (
(0): conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
(1): relu (inplace)
(2): maxpool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1))
(3): conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): relu (inplace)
(5): maxpool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1))
(6): conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): relu (inplace)
(8): conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): relu (inplace)
(10): conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): relu (inplace)
(12): maxpool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1))
) (classifier): sequential (
(0): dropout (p = 0.5)
(1): linear (9216 -> 4096)
(2): relu (inplace)
(3): dropout (p = 0.5)
(4): linear (4096 -> 4096)
(5): relu (inplace)
(6): linear (4096 -> 1000)
))
如果我們想檢視倒數第二層全連線層的結果,可以有以下三種方式實現:
class
alexnet
(torch.nn.module):
def__init__
(self):
super(alexnet, self).__init__()
self.model_name = 'alexnet'
self.features = nn.sequential(
nn.conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.relu(inplace=true),
nn.maxpool2d(kernel_size=3, stride=2),
nn.conv2d(64, 192, kernel_size=5, padding=2),
nn.relu(inplace=true),
nn.maxpool2d(kernel_size=3, stride=2),
nn.conv2d(192, 384, kernel_size=3, padding=1),
nn.relu(inplace=true),
nn.conv2d(384, 256, kernel_size=3, padding=1),
nn.relu(inplace=true),
nn.conv2d(256, 256, kernel_size=3, padding=1),
nn.relu(inplace=true),
nn.maxpool2d(kernel_size=3, stride=2),
)self.classifier = nn.sequential(
nn.dropout(),
nn.linear(256 * 6 * 6, 4096),
nn.relu(inplace=true),
nn.dropout(),
nn.linear(4096, 4096),
# nn.relu(inplace=true),
# nn.linear(4096, num_classes),
)def
forward
(self, x):
x = self.features(x)
x = x.view(x.size(0), 256 * 6 * 6)
x = self.classifier(x)
return x
new_model = alexnet().cuda()
pretrained_dict = orig_model.state_dict()
model_dict = new_model.state_dict()
# remove the key in pretrained_dict that do not belong the model_dict
pretrained_dicted =
# update the model_dict
model_dict.update(pretrained_dicted)
new_model.load_state_dict(model_dict)
這種方法的優點是可以進行多處刪減,實現起來比較清晰。缺點就是實現的過程比較複雜。當模型較大時,效率比較低。而且模型的引數和名稱必須與原來的模型對應。
# remove last fully-connected layer
# the number -2 indicate how many layers should be removed
alexnet.classifier = nn.sequential(*list(alexnet.classifier.children())[:-2])
print(alexnet)
alexnet (
(features): sequential (
(0): conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
(1): relu (inplace)
(2): maxpool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1))
(3): conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): relu (inplace)
(5): maxpool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1))
(6): conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): relu (inplace)
(8): conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): relu (inplace)
(10): conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): relu (inplace)
(12): maxpool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1))
) (classifier): sequential (
(0): dropout (p = 0.5)
(1): linear (9216 -> 4096)
(2): relu (inplace)
(3): dropout (p = 0.5)
(4): linear (4096 -> 4096)
)
這種方法的優點是速度快,方便,缺點是要刪除掉某些層,不易實現多處修改。適用於剔除最後幾層。
def
get_features_hook
(self, input, output):
print("hook",output.data.cpu().numpy().shape)
handle=alexnet.classifier[4].register_forward_hook(get_features_hook)
這種方法的有點是不用改變原有模型的結構,可以實現任意地點的精準檢視。缺點是每次呼叫可能會帶來時間上的消耗。
handle.remove()
最後可以通過handle控制代碼刪除hook,這裡的handle可以是任意名稱。 pytorch技巧 一 檢視模型結構
第一步 安裝graphviz,網上教程很多,也可以點這裡。注意記得配置環境變數。第二步 安裝torchviz,開啟終端輸入pip install torchviz 第三步 使用 import torch from torchviz import make dot class mlp torch.nn...
Git Git指南一 檢視建立刪除標籤
列出現有標籤,使用如下命令 xiaosi yoona code learningnotes git tag r 000000 000000 cm.cm v1.0.0 v1.0.1 我們可以用特定的搜尋模式列出符合條件的標籤。如果只對1.0系列的版本感興趣,可以執行如下命令 xiaosi yoona ...
Git Git指南一 檢視建立刪除標籤
列出現有標籤,使用如下命令 xiaosi yoona code learningnotes git tag r 000000 000000 cm.cm v1.0.0 v1.0.1 我們可以用特定的搜尋模式列出符合條件的標籤。如果只對1.0系列的版本感興趣,可以執行如下命令 xiaosi yoona ...