PyTorch設定與更新可訓練引數

2021-10-14 02:59:57 字數 1236 閱讀 2609

1、設定可訓練引數

pytorch中可以使用torch.nn.parameter()來設定可訓練引數。parameter類是tensor類的子類,當它與module類一起使用時,也就是將乙個parameter物件作為module類的乙個屬性時,它們會自動新增到module的引數列表中,例如在該module類的parameters()迭代器中。

import torch

import torch.nn as nn

class mymodule(nn.module):

def __init__(self, num):

super(mymodule, self).__init__()

params = torch.ones(num, requires_grad=true)

self.params = nn.parameter(params)

def forward(self, x):

y = self.params * x

my_module = mymodule(10)

inputs = torch.ones(10)

outputs = my_module(inputs)

print(my_module.state_dict())

print(list(my_module.parameters()))

print(dict(my_module.named_parameters()))

此時params自動加入了my_module的parameters,列印結果:

ordereddict([('params', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]))])

[parameter containing:

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=true)]

2、更新自定義的可訓練引數

假如有乙個網路net,訓練時需要同時更新這個網路的引數和上面my_module的引數,可以這樣定義優化器:

optimizer = optim.sgd([, 

], lr=base_lr, momentum=momentum, weight_decay=weight_decay)

PyTorch 版本更新追蹤

backto pytorch index 功能公升級 主要的 torchvision,torchtext 和 torchaudio 庫,並推出將模型從 python api 轉換為 c api 等功能.工業級部署 facebook 還和 amazon 合作,推出了兩個重磅的工具 torchserve...

Pytorch學習率更新

自己在嘗試了官方的 後就想提高訓練的精度就想到了調整學習率,但固定的學習率肯定不適合訓練就嘗試了幾個更改學習率的方法,但沒想到居然更差!可能有幾個學習率沒怎麼嘗試吧 import torch import matplotlib.pyplot as plt matplotlib inline from...

Pytorch學習率更新

自己在嘗試了官方的 後就想提高訓練的精度就想到了調整學習率,但固定的學習率肯定不適合訓練就嘗試了幾個更改學習率的方法,但沒想到居然更差!可能有幾個學習率沒怎麼嘗試吧 import torch import matplotlib.pyplot as plt matplotlib inline from...