由於pytorch中,訓練產生的中間變數會在訓練結束後被釋放掉,因此想要將這些變數儲存下來,需要用到hook函式,hook可以理解為乙個外掛程式函式,掛載在原有函式上.
這個用於儲存反向傳播時候的梯度
flag =
1if flag:
#定義網路
w = torch.tensor([1
.], requires_grad=
true
) x = torch.tensor([2
.], requires_grad=
true
) a = torch.add(w, x)
b = torch.add(w,1)
y = torch.mul(a, b)
#定義乙個空列表,用於儲存hook捕捉的梯度
a_grad =
list()
#定義hook函式,
defgrad_hook
(grad)
:#將hook捕捉的梯度儲存到a_grad中
grad *=
2#return為tensor型別時候,會將tensor資料賦給被掛載的變數;return為none的時候則不操作
return grad*
3#掛載hook函式到tensor變數a上
handle = a.register_hook(grad_hook)
#執行反向傳播,這時候在執行反向傳播的過程中會執行a的hook函式
y.backward(
)# 檢視hook儲存的梯度
print
("w.grad: "
, w.grad)
handle.remove(
)
共有三種:
forward_pre_hook:記錄網路前向傳播前的特徵圖
forward_hook:記錄前向傳播後的特徵圖
backward_hook:記錄反向傳播後的梯度資料
flag =
1if flag:
#定義網路
class
net(nn.module)
:def
__init__
(self)
:super
(net, self)
.__init__(
) self.conv1 = nn.conv2d(1,
2,3)
self.pool1 = nn.maxpool2d(2,
2)defforward
(self, x)
: x = self.conv1(x)
x = self.pool1(x)
return x
defforward_hook
(module, data_input, data_output)
:def
forward_pre_hook
(module, data_input)
:print
("forward_pre_hook input:{}"
.format
(data_input)
)def
backward_hook
(module, grad_input, grad_output)
:print
("backward hook input:{}"
.format
(grad_input)
)print
("backward hook output:{}"
.format
(grad_output)
)# 初始化網路
net = net(
) net.conv1.weight[0]
.detach(
).fill_(1)
net.conv1.weight[1]
.detach(
).fill_(2)
net.conv1.bias.data.detach(
).zero_(
)# 註冊hook
fmap_block =
list()
input_block =
list()
net.conv1.register_forward_hook(forward_hook)
net.conv1.register_forward_pre_hook(forward_pre_hook)
net.conv1.register_backward_hook(backward_hook)
# inference
fake_img = torch.ones((1
,1,4
,4))
# batch size * channel * h * w
output = net(fake_img)
loss_fnc = nn.l1loss(
) target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward(
)
在執行
output = net(fake_img)
的時候,實際上是執行了
#---------------------這一段是判斷是否有forward_pre_hook,並執行-----------------
def_call_impl
(self,
*input
,**kwargs)
:for hook in itertools.chain(
_global_forward_pre_hooks.values(),
self._forward_pre_hooks.values())
: result = hook(self,
input
)if result is
notnone:if
notisinstance
(result,
tuple):
result =
(result,
)input
= result
#---------------------這一段是真正執行forward-----------------
if torch._c._get_tracing_state():
result = self._slow_forward(
*input
,**kwargs)
else
: result = self.forward(
*input
,**kwargs)
#---------------------這一段是判斷是否有forward_hook,並執行-----------------
for hook in itertools.chain(
_global_forward_hooks.values(),
self._forward_hooks.values())
: hook_result = hook(self,
input
, result)
if hook_result is
notnone
: result = hook_result
#---------------------這一段是判斷是否有backward_hook,並執行----------------- if(
len(self._backward_hooks)
>0)
or(len(_global_backward_hooks)
>0)
: var = result
while
notisinstance
(var, torch.tensor):if
isinstance
(var,
dict):
var =
next
((v for v in var.values()if
isinstance
(v, torch.tensor)))
else
: var = var[0]
grad_fn = var.grad_fn
if grad_fn is
notnone
:for hook in itertools.chain(
_global_backward_hooks.values(),
self._backward_hooks.values())
:return result
js中的鉤子機制 hook
什麼是鉤子機制?使用鉤子機制有什麼好處?鉤子機制也叫hook機制,或者你可以把它理解成一種匹配機制,就是我們在 中設定一些鉤子,然後程式執行時自動去匹配這些鉤子 這樣做的好處就是提高了程式的執行效率,減少了if else 的使用同事優化 結構。由於js是單執行緒的程式語言,所以程式的執行效率在前端開...
pytorch中的廣播機制
pytorch中的廣播機制和numpy中的廣播機制一樣,因為都是陣列的廣播機制 兩個維度不同的tensor可以相乘,示例a torch.arange 0,6 reshape 6 tensor 0,1,2,3,4,5 shape torch.size 6 ndim 1 b torch.arange 0...
鉤子 HOOK 機制的使用
wh mouse,gethookinfo,hinstance,getcurrentthreadid mymousehook.callbackfun callbackf mymousehook.isrun not mymousehook.isrun end end procedure uninstal...