引言
本篇介紹pytorch 的索引與切片
1234567
in[3]: a = torch.rand(4,3,28,28)in[4]: a[0].shape # 理解上相當於取第一張
out[4]: torch.size([3, 28, 28])
in[5]: a[0,0].shape # 第0張的第0個通道
out[5]: torch.size([28, 28])
in[6]: a[0,0,2,4] # 第0張第0個通道的第2行第4列的畫素點 標量
out[6]: tensor(0.4133) # 沒有用 包起來就是乙個標量 dim為0
1234567
8910
in[7]: a.shapeout[7]: torch.size([4, 3, 28, 28])
in[8]: a[:2].shape # 前面兩張的所有資料
out[8]: torch.size([2, 3, 28, 28])
in[9]: a[:2,:1,:,:].shape # 前面兩張的第0通道的資料
out[9]: torch.size([2, 1, 28, 28])
in[11]: a[:2,1:,:,:].shape # 前面兩張,第1,2通道的資料
out[11]: torch.size([2, 2, 28, 28])
in[10]: a[:2,-1:,:,:].shape # 前面兩張,最後乙個通道的資料 從-1到最末尾,就是它本身。
out[10]: torch.size([2, 1, 28, 28])
1234
a[:,:,0:28,0:28:2].shape # 隔點取樣out[12]: torch.size([4, 3, 28, 14])
a[:,:,::2,::2].shape
out[14]: torch.size([4, 3, 14, 14])
1234567
8910
in[17]: a.shapeout[17]: torch.size([4, 3, 28, 28])
in[19]: a.index_select(0, torch.tensor([0,2])).shape # 當前維度為0,取第0,2張
out[19]: torch.size([2, 3, 28, 28])
in[20]: a.index_select(1, torch.tensor([1,2])).shape # 當前維度為1,取第1,2個通道
out[20]: torch.size([4, 2, 28, 28])
in[21]: a.index_select(2,torch.arange(28)).shape # 第二個引數,只是告訴你取28行
out[21]: torch.size([4, 3, 28, 28])
in[22]: a.index_select(2, torch.arange(8)).shape # 取8行 [0,8)
out[22]: torch.size([4, 3, 8, 28])
1234567
8910
in[23]: a.shapeout[23]: torch.size([4, 3, 28, 28])
in[24]: a[...].shape # 所有維度
out[24]: torch.size([4, 3, 28, 28])
in[25]: a[0,...].shape # 後面都有,取第0個 = a[0]
out[25]: torch.size([3, 28, 28])
in[26]: a[:,1,...].shape
out[26]: torch.size([4, 28, 28])
in[27]: a[...,:2].shape # 當有...出現時,右邊的索引理解為最右邊,只取兩列
out[27]: torch.size([4, 3, 28, 2])
1234567
891011
1213
1415
16
in[31]: x = torch.randn(3,4)in[32]: x
out[32]:
tensor([[ 2.0373, 0.1586, 0.1093, -0.6493],
[ 0.0466, 0.0562, -0.7088, -0.9499],
[-1.2606, 0.6300, -1.6374, -1.6495]])
in[33]: mask = x.ge(0.5) # >= 0.5 的元素的位置上為1,其餘地方為0
in[34]: mask
out[34]:
tensor([[1, 0, 0, 0],
[0, 0, 0, 0],
[0, 1, 0, 0]], dtype=torch.uint8)
in[35]: torch.masked_select(x,mask)
out[35]: tensor([2.0373, 0.6300]) # 之所以打平是因為大於0.5的元素個數是根據內容才能確定的
in[36]: torch.masked_select(x,mask).shape
out[36]: torch.size([2])
1234567
in[39]: src = torch.tensor([[4,3,5],[6,7,8]]) # 先打平成1維的,共6列in[40]: src
out[40]:
tensor([[4, 3, 5],
[6, 7, 8]])
in[41]: torch.take(src, torch.tensor([0, 2, 5])) # 取打平後編碼,位置為0 2 5
out[41]: tensor([4, 5, 8])
pytorch索引與切片
torch會自動從左向右索引 例子 a torch.randn 4,3,28,28 表示類似乙個cnn 的的輸入資料,4表示這個batch一共有4張 而3表示的通道數為3 rgb 28,28 表示的大小 基本索引print a 0 shape torch.size 3,28,28 print a 0...
pytorch索引與切片
目錄 torch會自動從左向右索引 例子 a torch.randn 4,3,28,28 表示類似乙個cnn 的的輸入資料,4表示這個batch一共有4張 而3表示的通道數為3 rgb 28,28 表示的大小 基本索引print a 0 shape torch.size 3,28,28 print ...
Pytorch學習 張量的索引切片
張量的索引切片方式和numpy幾乎是一樣的。切片時支援預設引數和省略號。可以通過索引和切片對部分元素進行修改。此外,對於不規則的切片提取,可以使用torch.index select,torch.masked select,torch.take 如果要通過修改張量的某些元素得到新的張量,可以使用to...