pytorch 中的topk函式

2021-10-01 17:41:51 字數 1036 閱讀 6942

1. 函式介紹

最近在**中看到這兩個語句

maxk = max(topk)

_, pred = output.topk(maxk, 1, true, true)

這個函式是用來求output中的最大值或最小值,返回兩個引數:其一返回output中的最大值(或最小值),其二返回該值的索引。

2. topk()函式原型:

(具體的用法參考pytorch官方中文文件:

torch.topk(input, k, dim=none, largest=true, sorted=true, out=none) -> (tensor, longtensor)
引數:

例如:k=1時,求的就是top-1的值

dim=1時,按行求最大或最小值

largest=true時,為求最大值,largest=false時,求最小值

3. 例項:

>>> import torch

>>> output=torch.randn((4, 2))

>>> output

tensor([[-0.3249, 1.0216],

[-0.6855, -0.0272],

[ 0.4624, 0.1392],

[ 1.8406, -0.3436]])

>>> _,pred=output.topk(1, 1, true, true) # 求top1

>>> _

tensor([[ 1.0216],

[-0.0272],

[ 0.4624],

[ 1.8406]])

>>> pred

tensor([[1],

[1],

[0],

[0]])

PyTorch中的topk函式詳解

聽名字就知道這個函式是用來求tensor中某個dim的前k大或者前k小的值以及對應的index。用法torch.topk input,k,dim none,largest true,sorted true,out none tensor,longtensor topk最常用的場合就是求乙個樣本被網路...

pytorch實現topk剪枝

這篇部落格,以mnist資料集為例,對lstm的權重矩陣實現top k剪枝 7,2 介紹了如何在pytorch框架下實現top k剪枝。可以使用如下 檢視模型都含有哪些權重矩陣 for name,in model.named parameters print name 矩陣每行含有28個引數,將其分...

pytorch中的gather函式

from 今天剛開始接觸,讀了一下documentation,寫乙個一開始每太搞懂的函式gather b torch.tensor 1,2,3 4,5,6 print bindex 1 torch.longtensor 0,1 2,0 index 2 torch.longtensor 0,1,1 0...