pytorch中的廣播機制和numpy中的廣播機制一樣, 因為都是陣列的廣播機制
兩個維度不同的tensor可以相乘, 示例
a = torch.arange(0,
6).reshape((6
,))'''
tensor([0, 1, 2, 3, 4, 5])
shape: torch.size([6])
ndim: 1
'''b = torch.arange(0,
12).reshape((2
,6))
'''tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11]])
shape: torch.size([2, 6])
ndim: 2
'''# a和b的ndim不同, 但是可以element-wise相乘, 因為用到了廣播機制
res = torch.mul(a,b)
'''tensor([[ 0, 1, 4, 9, 16, 25],
[ 0, 7, 16, 27, 40, 55]])
shape: torch.size([2, 6])
ndim: 2
'''
如何理解陣列的廣播機制
以陣列a和陣列b的相加為例, 其餘數**算同理
核心:如果相加的兩個陣列的shape不同, 就會觸發廣播機制, 1)程式會自動執行操作使得a.shape==b.shape, 2)對應位置進行相加
運算結果的shape是:a.shape和b.shape對應位置的最大值,比如:a.shape=(1,9,4),b.shape=(15,1,4),那麼a+b的shape是(15,9,4)
有兩種情況能夠進行廣播
a.ndim > b.ndim, 並且a.shape最後幾個元素包含b.shape, 比如下面三種情況, 注意不要混淆ndim和shape這兩個基本概念
a.ndim == b.ndim, 並且a.shape和b.shape對應位置的元素要麼相同要麼其中乙個是1, 比如
下面分別進行舉例
a.ndim 大於 b.ndim
# a.shape=(2,2,3,4)
a = np.arange(1,
49).reshape((2
,2,3
,4))
# b.shape=(3,4)
b = np.arange(1,
13).reshape((3
,4))
# numpy會將b.shape調整至(2,2,3,4), 這一步相當於numpy自動實現np.tile(b,[2,2,1,1])
res = a + b
print
('***********************************'
)print
(a)print
(a.shape)
print
('***********************************'
)print
(b)print
(b.shape)
print
('***********************************'
)print
(res)
print
(res.shape)
print
('***********************************'
)print
(a+b == a + np.tile(b,[2
,2,1
,1])
)
a.ndim 等於 b.ndim#示例1
# a.shape=(4,3)
a = np.arange(12)
.reshape(4,
3)# b.shape=(4,1)
b = np.arange(4)
.reshape(4,
1)# numpy會將b.shape調整至(4,3), 這一步相當於numpy自動實現np.tile(b,[1,3])
res = a + b
print
('***********************************'
)print
(a)print
(a.shape)
print
('***********************************'
)print
(b)print
(b.shape)
print
('***********************************'
)print
(res)
print
(res.shape)
print
('***********************************'
)print
((a+b == a + np.tile(b,[1
,3])
))# 列印結果都是true
#示例2
# a.shape=(1,9,4)
a = np.arange(1,
37).reshape((1
,9,4
))# b.shape=(15,1,4)
b = np.arange(1,
61).reshape((15
,1,4
))res = a + b
print
('***********************************'
)# print(a)
print
(a.shape)
print
('***********************************'
)# print(b)
print
(b.shape)
print
('***********************************'
)# print(res)
print
(res.shape)
print
('***********************************'
)q = np.tile(a,[15
,1,1
])+ np.tile(b,[1
,9,1
])print
(q == res)
# 列印結果都是true
pytorch的廣播機制
廣播機制,就是將不同維度 不同長度的tensor,在滿足一定規則的前提下能夠自動進行長度和維度的擴充,從而使不同維度 不同長度的tensor之間正確的進行運算。自動廣播規則 兩個tensor能夠進行自動廣播需要滿足以下幾個規則 對應相等 其中乙個tensor的大小等於1 其中乙個tensor的某個維...
pytorch的廣播機制
廣播機制,就是將不同維度 不同長度的tensor,在滿足一定規則的前提下能夠自動進行長度和維度的擴充,從而使不同維度 不同長度的tensor之間正確的進行運算。自動廣播規則 兩個tensor能夠進行自動廣播需要滿足以下幾個規則 對應相等 其中乙個tensor的大小等於1 其中乙個tensor的某個維...
numpy中的廣播機制
numpy兩個陣列的相加 相減以及相乘都是對應元素之間的操作。import numpy as np x np.array 2,2,3 1,2,3 y np.array 1,1,3 2,2,4 print x y numpy當中的陣列相乘是對應元素的乘積,與線性代數當中的矩陣相乘不一樣 輸入結果如下 ...