先給出官方文件的解釋,我覺得官方的文件寫的已經很清楚了,四個引數分別是input,dim,index,out,輸出的tensor是以index為大小的tensor。
其中,這就是最關鍵的定義
out[i]
[j][k]
= tensor[index[i]
[j][k]
][j]
[k]# dim=0
out[i]
[j][k]
= tensor[i]
[index[i]
[j][k]
][k]
# dim=1
out[i]
[j][k]
= tensor[i]
[j][index[i]
[j][k]
]# dim=3
主要解釋一下dim,dim=0的時候,把index的元素放入行進行索引,有一點需要注意的是,引數index的tensor格式是除了第1維也就是行那一維之外,其他維的格式需與input保持一致!下面給個例子
import torch
a = torch.arange(0,
16).view(4,
4)index = torch.longtensor([[
0,1,
2,3]
])b = a.gather(
0, index)
print
(a)print
(index)
print
(b)#形象的理解就是在每一列的第index上進行索引
for j in
range(4
):print
(a[index[0]
[j]]
[j].item())
----
----
----
----
----
----
----
----
----
----
----
----
----
----
----
----
----
tensor([[
0,1,
2,3]
,[4,
5,6,
7],[
8,9,
10,11]
,[12,
13,14,
15]])
tensor([[
0,1,
2,3]
])tensor([[
0,5,
10,15]
])05
1015
dim = 1的時候,把index的元素放入列進行索引,有一點需要注意的是,引數index的tensor格式是除了第2維也就是列那一維之外,其他維的格式需與input保持一致!下面給個例子
import torch
a = torch.arange(0,
16).view(4,
4)index = torch.longtensor([[
0],[
1],[
2],[
3]])
b = a.gather(
1, index)
print
(a)print
(index)
print
(b)#形象的理解就是在每一行的第index列上進行索引
for j in
range(4
):print
(a[j]
[index[j][0
]].item())
----
----
----
----
----
----
----
----
----
----
----
----
----
----
----
----
----
tensor([[
0,1,
2,3]
,[4,
5,6,
7],[
8,9,
10,11]
,[12,
13,14,
15]])
tensor([[
0],[
1],[
2],[
3]])
tensor([[
0],[
5],[
10],[
15]])
051015
本人對矩陣的一些概念還有一些模糊不清,以上就是我的一些理解,希望有大佬可以一起交流一下,pytorch 的張量一開始很難處理清楚,還需慢慢來。 pytorch中的gather函式
from 今天剛開始接觸,讀了一下documentation,寫乙個一開始每太搞懂的函式gather b torch.tensor 1,2,3 4,5,6 print bindex 1 torch.longtensor 0,1 2,0 index 2 torch.longtensor 0,1,1 0...
pytorch的gather 方法詳解
首先,先將結果展示出來,後續根據結果來進行分析 t torch.tensor 1,2,3 4,5,6 index a torch.longtensor 0,0 0,1 index b torch.longtensor 0,1,1 1,0,0 print t print torch.gather t,...
我對pytorch中gather函式的一點理解
torch.gather input dim,index,out none tensor torch.gather input dim,index,out none tensor gathers values along an axis specified by dim.for a 3 d tens...