總結:按元素相乘用torch.mul,二維矩陣乘法用torch.mm,batch二維矩陣用torch.bmm,batch、廣播用torch.matmul
if __name__==
"__main__"
: a = torch.tensor([1
,2,3
])b = torch.arange(0,
12).reshape((4
,3))
c = torch.tensor([4
,5,6
,7])
d = torch.arange(0,
12).reshape((3
,4))
aa = torch.unsqueeze(a, dim=1)
cc = torch.unsqueeze(c, dim=0)
print
('a:,a.shape:'
.format
(a, a.shape)
)print
('b:,b.shape:'
.format
(b, b.shape)
)print
('c:,c.shape:'
.format
(c, c.shape)
)print
('d:,d.shape:'
.format
(d, d.shape)
)print
('aa:,aa.shape:'
.format
(aa, aa.shape)
)print
('cc:,cc.shape:'
.format
(cc, cc.shape)
)#output:
a:tensor([1
,2,3
]),a.shape:torch.size([3
])b:tensor([[
0,1,
2],[
3,4,
5],[
6,7,
8],[
9,10,
11]])
,b.shape:torch.size([4
,3])
c:tensor([4
,5,6
,7])
,c.shape:torch.size([4
])d:tensor([[
0,1,
2,3]
,[4,
5,6,
7],[
8,9,
10,11]
]),d.shape:torch.size([3
,4])
aa:tensor([[
1],[
2],[
3]])
,aa.shape:torch.size([3
,1])
cc:tensor([[
4,5,
6,7]
]),cc.shape:torch.size([1
,4])
** torch.mul()元素乘:能自動增加維度,並且沿著新維度進行廣播,或者之前有新維度且為1。**
print
(a.mul(b)
)#a.shape:torch.size([3])
print
(aa.mul(d)
)#aa.shape:torch.size([3, 1])
#output:
tensor([[
0,2,
6],[
3,8,
15],[
6,14,
24],[
9,20,
33]])
tensor([[
0,1,
2,3]
,[8,
10,12,
14],[
24,27,
30,33]
])
torch.mm():二維矩陣相乘,並且滿足對應的乘法規則,不能廣播
print
(b.mm(d)
)#output:
tensor([[
20,23,
26,29]
,[56,
68,80,
92],[
92,113,
134,
155],[
128,
158,
188,
218]
])
** torch.matmul():可進行廣播,以及batch乘**
情況1:向量✖向量:點乘
>>
>
# vector x vector
>>
> tensor1 = torch.randn(3)
>>
> tensor2 = torch.randn(3)
>>
> torch.matmul(tensor1, tensor2)
.size(
)torch.size(
)
情況2:矩陣與向量相乘:向量增加乙個新維度1,矩陣相乘後,再將此維度移除。
如 tensor1 @ tensor2 --> (3,4)(4,) --> (3,4)(4,1) --> (3,1) --> (3,)
>>
>
# matrix x vector
>>
> tensor1 = torch.randn(3,
4)>>
> tensor2 = torch.randn(4)
>>
> torch.matmul(tensor1, tensor2)
.size(
)torch.size([3
])
情況3:批量矩陣與向量乘,向量broadcast到dim=1,以及dim=0,之後後兩個維度作矩陣乘法,對於增加的維度1最後移除。
如: tensor1 @ tensor2 --> (10,3,4)(4,) --> (10,3,4)(10,4,1) --> (10,3,1) --> (10,3)
>>
>
# batched matrix x broadcasted vector
>>
> tensor1 = torch.randn(10,
3,4)
>>
> tensor2 = torch.randn(4)
>>
> torch.matmul(tensor1, tensor2)
.size(
)torch.size([10
,3])
情況4:批量矩陣與批量矩陣相乘,後兩維度矩陣乘。
>>
>
# batched matrix x batched matrix
>>
> tensor1 = torch.randn(10,
3,4)
>>
> tensor2 = torch.randn(10,
4,5)
>>
> torch.matmul(tensor1, tensor2)
.size(
)torch.size([10
,3,5
])
情況5:批量矩陣與矩陣乘,矩陣先broadcast到批量數,之後後兩個維度乘。
>>
>
# batched matrix x broadcasted matrix
>>
> tensor1 = torch.randn(10,
3,4)
>>
> tensor2 = torch.randn(4,
5)>>
> torch.matmul(tensor1, tensor2)
.size(
)torch.size([10
,3,5
])
PyTorch 矩陣乘法總結
torch.mm mat1,mat2,out none 其中mat1 n times m mat2 m times d 輸出out的維度是 n times d 該函式一般只用來計算兩個二維矩陣的矩陣乘法,並且不支援broadcast操作。由於神經網路訓練一般採用mini batch,經常輸入的時三維...
Pytorch 中 torchvision的錯誤
在學習pytorch的時候,使用 torchvision的時候發生了乙個小小的問題 安裝都成功了,並且import torch也沒問題,但是在import torchvision的時候,出現了如下所示的錯誤資訊 dll load failed 找不到指定模組。首先,我們得知道torchvision在...
pytorch中index select 的用法
a torch.linspace 1,12,steps 12 view 3,4 print a b torch.index select a,0,torch.tensor 0,2 print b print a.index select 0,torch.tensor 0,2 c torch.inde...