from:
今天剛開始接觸,讀了一下documentation,寫乙個一開始每太搞懂的函式gather
b = torch.tensor([[1,2,3],[4,5,6]]
)print
bindex_1 = torch.longtensor([[0,1],[2,0]]
)index_2 = torch.longtensor([[0,1,1],[0,0,0]]
)print
torch.gather(b, dim=1
, index=index_1)
print
torch.gather(b, dim=0
, index=index_2)
觀察它的輸出結果:
1
2 3
4 5
6[torch.floattensor of size 2
x3] 1
2 64[torch.floattensor of size 2
x2] 1
5 6
1 2
3[torch.floattensor of size 2
x3]
這裡是官方文件的解釋
torch.gather(input, dim, index, out=none) → tensor
gathers values along an axis specified by dim.
for a 3
-d tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2
parameters:
input (tensor) – the source tensor
dim (int) – the axis along which to index
index (longtensor) – the indices of elements to gather
out (tensor, optional) – destination tensor
example:
>>> t = torch.tensor([[1,2],[3,4]]
) >>> torch.gather(t, 1
, torch.longtensor([[0,0],[1,0]]
)) 1
1 4
3 [torch.floattensor of size 2
x2]
可以看出,gather的作用是這樣的,index實際上是索引,具體是行還是列的索引要看前面dim 的指定,比如對於我們的栗子,【1,2,3;4,5,6,】,指定dim=1,也就是橫向,那麼索引就是列號。index的大小就是輸出的大小,所以比如index是【1,0;0,0】,那麼看index第一行,1列指的是2, 0列指的是1,同理,第二行為4,4 。這樣就輸入為【2,1;4,4】,參考這樣的解釋看上面的輸出結果,即可理解gather的含義。
gather在one-hot為輸出的多分類問題中,可以把最大值座標作為index傳進去,然後提取到每一行的正確**結果,這也是gather可能的乙個作用。
2023年05月30日20:05:01
春去夏來,溫情演為慾望。 —— 作家, 安德烈莫羅阿
pytorch的gather 方法詳解
首先,先將結果展示出來,後續根據結果來進行分析 t torch.tensor 1,2,3 4,5,6 index a torch.longtensor 0,0 0,1 index b torch.longtensor 0,1,1 1,0,0 print t print torch.gather t,...
我對pytorch中gather函式的一點理解
torch.gather input dim,index,out none tensor torch.gather input dim,index,out none tensor gathers values along an axis specified by dim.for a 3 d tens...
pytorch的gather函式的一些粗略的理解
先給出官方文件的解釋,我覺得官方的文件寫的已經很清楚了,四個引數分別是input,dim,index,out,輸出的tensor是以index為大小的tensor。其中,這就是最關鍵的定義 out i j k tensor index i j k j k dim 0 out i j k tensor...