torch.gather(
input
,dim,index,out=
none
) → tensor
torch.gather(
input
, dim, index, out=
none
) → tensor
gathers values along an axis specified by dim.
for a 3
-d tensor the output is specified by:
out[i]
[j][k]
=input
[index[i]
[j][k]
][j]
[k]# dim=0
out[i]
[j][k]
=input
[i][index[i]
[j][k]
][k]
# dim=1
out[i]
[j][k]
=input
[i][j]
[index[i]
[j][k]
]# dim=2
parameters:
input
(tensor) – the source tensor
dim (
int) – the axis along which to index
index (longtensor) – the indices of elements to gather
out (tensor, optional) – destination tensor
example:
>>
> t = torch.tensor([[
1,2]
,[3,
4]])
>>
> torch.gather(t,
1, torch.longtensor([[
0,0]
,[1,
0]])
)114
3[torch.floattensor of size 2x2]
import torch
a = torch.tensor([[
1,2]
,[3,
4]])
b = torch.gather(a,
1,torch.longtensor([[
0,0]
,[1,
0]])
)#1. 取各個元素行號:[(0,y)(0,y)][(1,y)(1,y)]
#2. 取各個元素值做行號:[(0,0)(0,0)][(1,1)(1,0)]
#3. 根據得到的索引在輸入中取值
#[1,1],[4,3]
c = torch.gather(a,
0,torch.longtensor([[
0,0]
,[1,
0]])
)#1. 取各個元素列號:[(x,0)(x,1)][(x,0)(x,1)]
#2. 取各個元素值做行號:[(0,0)(0,1)][(1,0)(0,1)]
#3. 根據得到的索引在輸入中取值
#[1,2],[3,2]
假設輸入與上同;index=b;輸出為c
b中每個元素分別為b(0,0)=0,b(0,1)=0
b(1,0)=1,b(1,1)=0
如果dim=0(列)
則取b中元素的列號,如:b(0,1)的1
b(0,1)=0,所以c中的c(0,1)=輸入的(0,1)處元素2
如果dim=1(行)
則取b中元素的列號,如:b(0,1)的0
b(0,1)=0,所以c中的c(0,1)=輸入的(0,0)處元素1
總結如下:
輸出 元素 在 輸入張量 中的位置為:
輸出元素位置取決與同位置的index元素
dim=1時,取同位置的index元素的行號做行號,該位置處index元素做列號
dim=0時,取同位置的index元素的列號做列號,該位置處index元素做行號。
最後根據得到的索引在輸入中取值
index型別必須為longtensor
gather最終的輸出變數與index同形。
對Pytorch 中的contiguous理解說明
最近遇到這個函式,但查的中文部落格裡的解釋貌似不是很到位,這裡翻譯一下stackoverflow上的回答並加上自己的理解。在pytorch中,只有很少幾個操作是不改變tensor的內容本身,而只是重新定義下標與元素的對應關係的。換句話說,這種操作不進行資料拷貝和資料的改變,變的是元資料。這些操作是 ...
對Pytorch中backward()函式的理解
寫在第一句 這個部落格解釋的也很好,參考了很多 pytorch中的自動求導函式backward 所需引數含義 所以切入正題 backward 函式中的引數應該怎麼理解?官方 如果需要計算導數,可以在tensor上呼叫.backward 1.如果tensor是乙個標量 即它包含乙個元素的資料 則不需要...
對PyTorch中inplace欄位的全面理解
torch.nn.relu inplace true inplace true 表示進行原地操作,對上一層傳遞下來的tensor直接進行修改,如x x 3 inplace false 表示新建乙個變數儲存操作結果,如y x 3,x y inplace true 可以節省運算記憶體,不用多儲存變數。補...