pytorch 更新部分引數(凍結引數)注意事項

2021-10-09 07:38:47 字數 3611 閱讀 7975

實驗的pytorch版本1.2.0

在訓練過程中可能需要固定一部分模型的引數,只更新另一部分引數。有兩種思路實現這個目標,乙個是設定不要更新引數的網路層為false,另乙個就是在定義優化器時只傳入要更新的引數。當然最優的做法是,優化器中只傳入requires_grad=true的引數,這樣占用的記憶體會更小一點,效率也會更高。

import torch

import torch.nn as nn

import torch.optim as optim

# 定義乙個簡單的網路

class net(nn.module):

def __init__(self, num_class=10):

super(net, self).__init__()

self.fc1 = nn.linear(8, 4)

self.fc2 = nn.linear(4, num_class)

def forward(self, x):

return self.fc2(self.fc1(x))

model = net()

# 凍結fc1層的引數

for name, param in model.named_parameters():

if "fc1" in name:

param.requires_grad = false

loss_fn = nn.crossentropyloss()

optimizer = optim.sgd(model.parameters(), lr=1e-2) # 傳入的是所有的引數

由實驗的結果可以看出:只要設定requires_grad=false雖然傳入模型所有的引數,仍然只更新requires_grad=true的。

# 定義乙個簡單的網路

class net(nn.module):

def __init__(self, num_class=3):

super(net, self).__init__()

self.fc1 = nn.linear(8, 4)

self.fc2 = nn.linear(4, num_class)

def forward(self, x):

return self.fc2(self.fc1(x))

model = net()

# 凍結fc1層的引數

# for name, param in model.named_parameters():

# if "fc1" in name:

# param.requires_grad = false

loss_fn = nn.crossentropyloss()

optimizer = optim.sgd(model.fc2.parameters(), lr=1e-2) # 只傳入fc2的引數

可以看出:只會更新優化器傳入的引數,對於沒有傳入的引數雖然可以求導,但是仍然不會更新引數。

就是將上面兩種結合起來,不更新的引數設定為false同時不傳入。

# 定義乙個簡單的網路

class net(nn.module):

def __init__(self, num_class=3):

super(net, self).__init__()

self.fc1 = nn.linear(8, 4)

self.fc2 = nn.linear(4, num_class)

def forward(self, x):

return self.fc2(self.fc1(x))

model = net()

# 凍結fc1層的引數

for name, param in model.named_parameters():

if "fc1" in name:

param.requires_grad = false

loss_fn = nn.crossentropyloss()

optimizer = optim.sgd(model.fc2.parameters(), lr=1e-2)

print("model.fc1.weight", model.fc1.weight)

print("model.fc2.weight", model.fc2.weight)

for epoch in range(10):

x = torch.randn((3, 8))

label = torch.randint(0,3,[3]).long()

output = model(x)

loss = loss_fn(output, label)

optimizer.zero_grad()

loss.backward()

optimizer.step()

print("model.fc1.weight", model.fc1.weight)

print("model.fc2.weight", model.fc2.weight)

print()

Pytorch載入部分引數並凍結

pytorch 模型部分引數的載入 pytorch中,只匯入部分模型引數的做法 correct way to freeze layers pytorch自由載入部分模型引數並凍結 pytorch凍結部分引數訓練另一部分 pytorch更新部分網路,其他不更新 pytorch固定部分引數 只訓練部分層...

pytorch凍結部分引數訓練另一部分

凍結引數僅需兩行 for param in net.parameters param.requires grad false另外乙個小技巧就是在nn.module裡凍結引數,這樣前面的引數就是false,而後面的不變。class net nn.module def init self super n...

pytorch載入模型與凍結

weights torch.load path with open a.pkl wb as f pickle.dump score dict,f weights pickle.load f 直接載入 model.load state dict weights 字典生成式載入 self.load st...