這篇部落格,以mnist資料集為例,對lstm的權重矩陣實現top-k剪枝(7,2),介紹了如何在pytorch框架下實現top-k剪枝。
可以使用如下**,檢視模型都含有哪些權重矩陣:
for name, _ in model.named_parameters():
print
(name)
矩陣每行含有28個引數,將其分為4組,每組7個元素,只保留最大的2個:
def
topk
(para, k)
:#對parameter,生成掩模矩陣,k表示保留前k個最大的
parameter = torch.
abs(para)
l =int(parameter.size()[
1]/7
) _, b = torch.topk(parameter[:,
0:7]
, k,
1, largest =
true
)for i in
range(1
,l):
_, b1 = torch.topk(parameter[
:,l*7:
(l+1)*
7], k,
1, largest =
true
)#該函式在dim=1上,保留前k個最大值,返回b1為前k個最大值的索引
b1 = b1 + i *
7 b = torch.cat(
(b,b1)
,dim =1)
c = torch.zeros(parameter.size()[
0], parameter.size()[
1],dtype = torch.
int)
#lstm權重矩陣為[4*28,28],所以這裡c也選這麼大
for i in
range
(c.size()[
0]):
for j in
range
(c.size()[
1]):
if j in b[i]
: c[i]
[j]=
1else
: c[i]
[j]=
0return c
c1,c2,c3,c4是根據四個權重矩陣生成的四個掩模矩陣(我定義的雙層lstm有四個權重矩陣),生成的掩模矩陣元素均為0或1
c1 = topk(rnn.lstm.weight_ih_l0.data,2)
c2 = topk(rnn.lstm.weight_hh_l0.data,2)
c3 = topk(rnn.lstm.weight_ih_l1.data,2)
c4 = topk(rnn.lstm.weight_hh_l1.data,
2)
生成的掩模矩陣如圖所示:
pytorch提供的自定義剪枝的模板,這裡分別將c1,c2,c3,c4作為掩模矩陣,這段**的意思就是,rnn模型中的lstm層的權重矩陣weight_ih_l0對應掩模矩陣c1, c1元素為1的位置,保留;c1為0的,weight_ih_l0對應的位置被剪枝掉,以此類推;
class
foobarpruningmethod1
(prune.basepruningmethod)
:"""prune every other entry in a tensor
"""pruning_type =
'unstructured'
defcompute_mask
(self, t, default_mask)
: mask = c1
return mask
class
foobarpruningmethod2
(prune.basepruningmethod)
:"""prune every other entry in a tensor
"""pruning_type =
'unstructured'
defcompute_mask
(self, t, default_mask)
: mask = c2
return mask
class
foobarpruningmethod3
(prune.basepruningmethod)
:"""prune every other entry in a tensor
"""pruning_type =
'unstructured'
defcompute_mask
(self, t, default_mask)
: mask = c3
return mask
class
foobarpruningmethod4
(prune.basepruningmethod)
:"""prune every other entry in a tensor
"""pruning_type =
'unstructured'
defcompute_mask
(self, t, default_mask)
: mask = c4
return mask
deffoobar_unstructured
(model)
: foobarpruningmethod1.
(model.lstm,
'weight_ih_l0'
) foobarpruningmethod2.
(model.lstm,
'weight_hh_l0'
) foobarpruningmethod3.
(model.lstm,
'weight_ih_l1'
) foobarpruningmethod3.
(model.lstm,
'weight_hh_l1'
)return model
rnn = foobar_unstructured(rnn)
#對預訓練完成的模型進行top-k剪枝
剪枝過後再訓練,會發現,剪枝後的訓練速度,明顯快於剪枝前。
剪枝後的矩陣如圖所示:
這篇部落格以mnist資料集為例,搭建了乙個含有雙層lstm,和fc層的模型,預訓練後對其進行top-k剪枝,詳細介紹了pytorch框架下的top-k剪枝過程;
pytorch 中的topk函式
1.函式介紹 最近在 中看到這兩個語句 maxk max topk pred output.topk maxk,1,true,true 這個函式是用來求output中的最大值或最小值,返回兩個引數 其一返回output中的最大值 或最小值 其二返回該值的索引。2.topk 函式原型 具體的用法參考p...
PyTorch中的topk函式詳解
聽名字就知道這個函式是用來求tensor中某個dim的前k大或者前k小的值以及對應的index。用法torch.topk input,k,dim none,largest true,sorted true,out none tensor,longtensor topk最常用的場合就是求乙個樣本被網路...
TOP K問題(c 實現)
top k問題 c 實現 給定乙個陣列,找出陣列中最大的k個數或者最小的k個數,稱為top k問題。這是面試的常考題,解法可以是基於最大堆 最大堆排序 基於快速排序實現等等,文字基於快速排序的思想實現。我們不會採用快速排序的演算法來實現top k問題,但我們可以利用快速排序的思想,在陣列中隨機找乙個...