scatter_(input, dim, index, src)將src中資料根據index中的索引按照dim的方向填進input中.
1) dim = 0,分別對每列填充:1 >>> x = torch.rand(2, 5)
2 >>> x
3 4 0.4319 0.6500 0.4080 0.8760 0.2355
5 0.2609 0.4711 0.8486 0.8573 0.1029
6 [torch.floattensor of size 2x5]
實現原理:>>> torch.zeros(3, 5).scatter_(0, torch.longtensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
0.4319 0.4711 0.8486 0.8760 0.2355
0.0000 0.6500 0.0000 0.8573 0.0000
0.2609 0.0000 0.4080 0.0000 0.1029
[torch.floattensor of size 3x5]
對於lonetensor內的矩陣,暫且稱為 tmp = [[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]];將最終的 3*5的矩陣,暫且稱為result。result初始為全0,需要經過scatter_處理。
舉例:對於tmp[0][0] = 0 -> 取x中x[0][0] = 0.4319,將其插入到result第0列的第0個位置,result[0][0] = 0.4319;
對於tmp[0][1] = 1 -> 取x中x[0][1] = 0.6500,將其插入到result第1列的第1個位置,result[1][1] = 0.6500;
對於tmp[0][2] = 2 -> 取x中x[0][1] = 0.4080,將其插入到result第2列的第2個位置,result[2][2] = 0.4080;
對於tmp[1][0] = 2 -> 取x中x[1][0] = 0.2609,將其插入到result第0列的第2個位置,result[2][0] = 0.2609;
對於tmp[1][1] = 0 -> 取x中x[1][1] = 0.4711,將其插入到result第1列的第0個位置,result[0][1] = 0.4711。
2) dim = 1,分別對每行填充
tmp = [[2], [3]]1 >>> z = torch.zeros(2, 4).scatter_(1, torch.longtensor([[2], [3]]), 1.23)
2 >>> z
3 4 0.0000 0.0000 1.2300 0.0000
5 0.0000 0.0000 0.0000 1.2300
6 [torch.floattensor of size 2x4]
tmp[0][0] = 2 -> 取x中x[0][0] = 0.4319,將其插入到result第0行的第2個位置,result[0][2] = 0.4319;
pytorch 中,一般函式加下劃線代表直接在原來的 tensor 上修改scatter(dim, index, src) 的引數有 3 個
這個 scatter 可以理解成放置元素或者修改元素簡單說就是通過乙個張量 src 來修改另乙個張量,哪個元素需要修改、用 src 中的哪個元素來修改由 dim 和 index 決定
官方文件給出了 3維張量 的具體操作說明,如下所示
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
exmaple:
x = torch.rand(2, 5)
#tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
# [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
#tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
# [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
# [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])
具體地說,我們的 index 是 torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]),乙個二維張量,下面用圖簡單說明
我們是 2維 張量,一開始進行 self[index[0][0]][0]self[index[0][0]][0],其中 index[0][0]index[0][0] 的值是0,所以執行 self[0][0]=x[0][0]=0.1940self[0][0]=x[0][0]=0.1940
src 除了可以是張量外,也可以是乙個標量
example:
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 7)
#tensor([[7., 7., 7., 7., 7.],
# [0., 7., 0., 7., 0.],
# [7., 0., 7., 0., 7.]]
scatter()一般可以用來對標籤進行 one-hot 編碼,這就是乙個典型的用標量來修改張量的乙個例子
example:
class_num = 10
batch_size = 4
label = torch.longtensor(batch_size, 1).random_() % class_num
#tensor([[6],
# [0],
# [3],
# [2]])
torch.zeros(batch_size, class_num).scatter_(1, label, 1)
#tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
# [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
# [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
# [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])
Pytorch 學習筆記
本渣的pytorch 逐步學習鞏固經歷,希望各位大佬指正,寫這個部落格也為了鞏固下記憶。a a b c 表示從 a 取到 b 步長為 c c 2 則表示每2個數取1個 若為a 1 1 1 即表示從倒數最後乙個到正數第 1 個 最後乙個 1 表示倒著取 如下 陣列的第乙個為 0 啊,第 0 個!彆扭 ...
Pytorch學習筆記
資料集 penn fudan資料集 在學習pytorch官網教程時,作者對penn fudan資料集進行了定義,並且在自定義的資料集上實現了對r cnn模型的微調。此篇筆記簡單總結一下pytorch如何實現定義自己的資料集 資料集必須繼承torch.utils.data.dataset類,並且實現 ...
PyTorch入門筆記
原教程 資料集csv 此處使用numpy來匯入,除此之外還可以使用csv和pandas匯入 資料集鏈結 import csv import numpy as np wine path data chapter3 winequality white.csv 路徑 wineq numpy np.load...