PyTorch 矩陣乘法總結

2022-02-09 12:15:50 字數 1247 閱讀 9738

torch.mm(mat1, mat2, out=none),其中mat1(\(n\times m\)),mat2(\(m\times d\)),輸出out的維度是(\(n\times d\))。

該函式一般只用來計算兩個二維矩陣的矩陣乘法,並且不支援broadcast操作。

由於神經網路訓練一般採用mini-batch,經常輸入的時三維帶batch的矩陣,所以提供torch.bmm(bmat1, bmat2, out=none),其中bmat1(\(b\times n \times m\)),bmat2(\(b\times m \times d\)),輸出out的維度是(\(b \times n \times d\))。

該函式的兩個輸入必須是三維矩陣且第一維相同(表示batch維度),不支援broadcast操作。

torch.matmul(input, other, out=none)支援broadcast操作,使用起來比較複雜。

針對多維資料matmul()乘法,我們可以認為該matmul()乘法使用使用兩個引數的後兩個維度來計算,其他的維度都可以認為是batch維度。假設兩個輸入的維度分別是input(\(1000 \times 500 \times 99 \times 11\)),other(\(500 \times 11 \times 99\))那麼我們可以認為torch.matmul(input, other, out=none)乘法首先是進行後兩位矩陣乘法得到\((99 \times 11) \times (11 \times 99)\rightarrow(99 \times 99)\) ,然後分析兩個引數的batch size分別是 \(( 1000 \times 500)\) 和 \(500\) , 可以廣播成為 \((1000 \times 500)\), 因此最終輸出的維度是(\(1000 \times 500 \times 99 \times 99\))。

torch.mul(mat1, other, out=none),其中other乘數可以是標量,也可以是任意維度的矩陣,只要滿足最終相乘是可以broadcast的即可

矩陣乘法總結

矩陣的資料 http huixisheng.download.csdn.net 到今天為止終於差不多把矩陣乘法的題目寫的差不多了。剩下的幾個留著慢慢思考,做矩陣的題目,最主要的關鍵是遞推公式。1 fic的通項可以這樣表示 f k 1 sqrt 5 hdu有兩個題目喲公道了這個犀利的推導公式 2 求s...

pytorch中的乘法

總結 按元素相乘用torch.mul,二維矩陣乘法用torch.mm,batch二維矩陣用torch.bmm,batch 廣播用torch.matmul if name main a torch.tensor 1 2,3 b torch.arange 0,12 reshape 4 3 c torch...

pytorch的各種乘法操作,點乘和矩陣乘

點乘 相應點相乘,x.mul y 即點乘操作,點乘不求和操作,又可以叫作hadamard product 哈達瑪積 相同位置的相乘,形狀保持不變 import torch x torch.tensor 3,3 3,3 y x x x.dot x z torch.mul x,x x.mul x pri...