實驗的pytorch版本1.2.0
在訓練過程中可能需要固定一部分模型的引數,只更新另一部分引數。有兩種思路實現這個目標,乙個是設定不要更新引數的網路層為false,另乙個就是在定義優化器時只傳入要更新的引數。當然最優的做法是,優化器中只傳入requires_grad=true的引數,這樣占用的記憶體會更小一點,效率也會更高。
import torch
import torch.nn as nn
import torch.optim as optim
# 定義乙個簡單的網路
class net(nn.module):
def __init__(self, num_class=10):
super(net, self).__init__()
self.fc1 = nn.linear(8, 4)
self.fc2 = nn.linear(4, num_class)
def forward(self, x):
return self.fc2(self.fc1(x))
model = net()
# 凍結fc1層的引數
for name, param in model.named_parameters():
if "fc1" in name:
param.requires_grad = false
loss_fn = nn.crossentropyloss()
optimizer = optim.sgd(model.parameters(), lr=1e-2) # 傳入的是所有的引數
由實驗的結果可以看出:只要設定requires_grad=false雖然傳入模型所有的引數,仍然只更新requires_grad=true的。
# 定義乙個簡單的網路
class net(nn.module):
def __init__(self, num_class=3):
super(net, self).__init__()
self.fc1 = nn.linear(8, 4)
self.fc2 = nn.linear(4, num_class)
def forward(self, x):
return self.fc2(self.fc1(x))
model = net()
# 凍結fc1層的引數
# for name, param in model.named_parameters():
# if "fc1" in name:
# param.requires_grad = false
loss_fn = nn.crossentropyloss()
optimizer = optim.sgd(model.fc2.parameters(), lr=1e-2) # 只傳入fc2的引數
可以看出:只會更新優化器傳入的引數,對於沒有傳入的引數雖然可以求導,但是仍然不會更新引數。
就是將上面兩種結合起來,不更新的引數設定為false同時不傳入。
# 定義乙個簡單的網路
class net(nn.module):
def __init__(self, num_class=3):
super(net, self).__init__()
self.fc1 = nn.linear(8, 4)
self.fc2 = nn.linear(4, num_class)
def forward(self, x):
return self.fc2(self.fc1(x))
model = net()
# 凍結fc1層的引數
for name, param in model.named_parameters():
if "fc1" in name:
param.requires_grad = false
loss_fn = nn.crossentropyloss()
optimizer = optim.sgd(model.fc2.parameters(), lr=1e-2)
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
for epoch in range(10):
x = torch.randn((3, 8))
label = torch.randint(0,3,[3]).long()
output = model(x)
loss = loss_fn(output, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
print()
Pytorch載入部分引數並凍結
pytorch 模型部分引數的載入 pytorch中,只匯入部分模型引數的做法 correct way to freeze layers pytorch自由載入部分模型引數並凍結 pytorch凍結部分引數訓練另一部分 pytorch更新部分網路,其他不更新 pytorch固定部分引數 只訓練部分層...
pytorch凍結部分引數訓練另一部分
凍結引數僅需兩行 for param in net.parameters param.requires grad false另外乙個小技巧就是在nn.module裡凍結引數,這樣前面的引數就是false,而後面的不變。class net nn.module def init self super n...
pytorch載入模型與凍結
weights torch.load path with open a.pkl wb as f pickle.dump score dict,f weights pickle.load f 直接載入 model.load state dict weights 字典生成式載入 self.load st...