1. 函式介紹
最近在**中看到這兩個語句
maxk = max(topk)
_, pred = output.topk(maxk, 1, true, true)
這個函式是用來求output中的最大值或最小值,返回兩個引數:其一返回output中的最大值(或最小值),其二返回該值的索引。
2. topk()函式原型:
(具體的用法參考pytorch官方中文文件:
torch.topk(input, k, dim=none, largest=true, sorted=true, out=none) -> (tensor, longtensor)
引數:
例如:k=1時,求的就是top-1的值
dim=1時,按行求最大或最小值
largest=true時,為求最大值,largest=false時,求最小值
3. 例項:
>>> import torch
>>> output=torch.randn((4, 2))
>>> output
tensor([[-0.3249, 1.0216],
[-0.6855, -0.0272],
[ 0.4624, 0.1392],
[ 1.8406, -0.3436]])
>>> _,pred=output.topk(1, 1, true, true) # 求top1
>>> _
tensor([[ 1.0216],
[-0.0272],
[ 0.4624],
[ 1.8406]])
>>> pred
tensor([[1],
[1],
[0],
[0]])
PyTorch中的topk函式詳解
聽名字就知道這個函式是用來求tensor中某個dim的前k大或者前k小的值以及對應的index。用法torch.topk input,k,dim none,largest true,sorted true,out none tensor,longtensor topk最常用的場合就是求乙個樣本被網路...
pytorch實現topk剪枝
這篇部落格,以mnist資料集為例,對lstm的權重矩陣實現top k剪枝 7,2 介紹了如何在pytorch框架下實現top k剪枝。可以使用如下 檢視模型都含有哪些權重矩陣 for name,in model.named parameters print name 矩陣每行含有28個引數,將其分...
pytorch中的gather函式
from 今天剛開始接觸,讀了一下documentation,寫乙個一開始每太搞懂的函式gather b torch.tensor 1,2,3 4,5,6 print bindex 1 torch.longtensor 0,1 2,0 index 2 torch.longtensor 0,1,1 0...