對於矩陣乘法c = a × b,通常的做法是將矩陣進行分塊相乘,如下圖所示:
從上圖可以看出這種分塊相乘總共用了8次乘法,當然對於子矩陣相乘(如a0×b0),還可以繼續遞迴使用分塊相乘。對於中小矩陣來說,很適合使用這種分塊乘法,但是對於大矩陣來說,遞迴的次數較多,如果能減少每次分塊乘法的次數,那麼效能將可以得到很好的提高。
strassen矩陣乘法就是採用了乙個簡單的運算技巧,將上面的8次矩陣相乘變成了7次乘法,看別小看這減少的1次乘法,因為每遞迴1次,效能就提高了1/8,比如對於1024*1024的矩陣,第1次先分解成7次512*512的矩陣相乘,對於512*512的矩陣,又可以繼續遞迴分解成256*256的矩陣相乘,…,一直遞迴下去,假設分解到64*64的矩陣大小後就不再遞迴,那麼所花的時間將是分塊矩陣乘法的(7/8) * (7/8) * (7/8) * (7/8) = 0.586倍,提高了快接近一倍。當然這是理論上的值,因為實際上strassen乘法增加了其他運算開銷,實際效能會略低一點。
下面就是strassen矩陣乘法的實現方法,
m1 = (a0 + a3) × (b0 + b3)
m2 = (a2 + a3) × b0
m3 = a0 × (b1 - b3)
m4 = a3 × (b2 - b0)
m5 = (a0 + a1) × b3
m6 = (a2 - a0) × (b0 + b1)
m7 = (a1 - a3) × (b2 + b3)
c0 = m1 + m4 - m5 + m7
c1 = m3 + m5
c2 = m2 + m4
c3 = m1 - m2 + m3 + m6
在求解m1,m2,m3,m4,m5,m6,m7時需要使用7次矩陣乘法,其他都是矩陣加法和減法。
下面看看strassen矩陣乘法的序列實現偽**:
serial_strassenmultiply(a, b, c)
由上可見,strassen矩陣乘法是通過遞迴實現的,它將一般情況下二階矩陣乘法(可擴充套件到n階,但strassen矩陣乘法要求n是2的冪)所需的8次乘法降低為7次,其c++實現**如下:
#include using namespace std;
const int n = 6; //define the size of the matrix
templatevoid strassen(int n, t a[n], t b[n], t c[n]);
templatevoid input(int n, t p[n]);
templatevoid output(int n, t c[n]);
int main()
}}/**the output fanction of matrix*/
templatevoid output(int n, t c[n])
}
}}/**matrix addition*/
template void matrix_add(int n, t x[n], t y[n], t z[n]) else
} //calculate m1 = (a0 + a3) × (b0 + b3)
matrix_add(n/2, a11, a22, aa);
matrix_add(n/2, b11, b22, bb);
strassen(n/2, aa, bb, m1);
//calculate m2 = (a2 + a3) × b0
matrix_add(n/2, a21, a22, aa);
strassen(n/2, aa, b11, m2);
//calculate m3 = a0 × (b1 - b3)
matrix_sub(n/2, b12, b22, bb);
strassen(n/2, a11, bb, m3);
//calculate m4 = a3 × (b2 - b0)
matrix_sub(n/2, b21, b11, bb);
strassen(n/2, a22, bb, m4);
//calculate m5 = (a0 + a1) × b3
matrix_add(n/2, a11, a12, aa);
strassen(n/2, aa, b22, m5);
//calculate m6 = (a2 - a0) × (b0 + b1)
matrix_sub(n/2, a21, a11, aa);
matrix_add(n/2, b11, b12, bb);
strassen(n/2, aa, bb, m6);
//calculate m7 = (a1 - a3) × (b2 + b3)
matrix_sub(n/2, a12, a22, aa);
matrix_add(n/2, b21, b22, bb);
strassen(n/2, aa, bb, m7);
//calculate c0 = m1 + m4 - m5 + m7
matrix_add(n/2, m1, m4, aa);
matrix_sub(n/2, m7, m5, bb);
matrix_add(n/2, aa, bb, c11);
//calculate c1 = m3 + m5
matrix_add(n/2, m3, m5, c12);
//calculate c2 = m2 + m4
matrix_add(n/2, m2, m4, c21);
//calculate c3 = m1 - m2 + m3 + m6
matrix_sub(n/2, m1, m2, aa);
matrix_add(n/2, m3, m6, bb);
matrix_add(n/2, aa, bb, c22);
//set the result to c[n]
for(int i=0; i}}
}
Strassen矩陣相乘演算法
strassen的矩陣相乘方法是一種典型的分治演算法。目前為止,我們已經見過一些分治策略的演算法了,例如歸併排序和karatsuba大數快速乘法。現在,讓我再來看看分治策略的背後是什麼。同動態規劃不同,在動態規劃中,為了得到最終的解決方案,我們經常需要把乙個大的問題 展開 為幾個子問題,但是這裡,我...
strassen矩陣乘法 Strassen矩陣乘法
求矩陣a,b相乘的結果c 直接根據矩陣乘法的定義來遍歷計算。c 語言 void matrixmul int a,int b,int c,int m,int b,int n void test3 int b 3 2 int c 2 2 matrixmul int a,int b,int c,2,3,2...
矩陣乘法 之 strassen 演算法
一般情況下矩陣乘法需要三個for迴圈,時間複雜度為o n 3 現在我們將矩陣分塊如圖 來自mit演算法導論 一般演算法需要八次乘法 r a e b g s a f b h t c e d g u c f d h strassen將其變成7次乘法,因為大家都知道乘法比加減法消耗更多,所有時間複雜更高!...