矩陣乘法 之 strassen 演算法

2021-12-29 20:48:20 字數 2798 閱讀 5459

一般情況下矩陣乘法需要三個for迴圈,時間複雜度為o(n^3),現在我們將矩陣分塊如圖:( 來自mit演算法導論 )

一般演算法需要八次乘法

r = a * e + b * g ;

s = a * f  + b * h ;

t = c * e + d  * g; 

u = c * f + d * h;

strassen將其變成7次乘法,因為大家都知道乘法比加減法消耗更多,所有時間複雜更高!

strassen的處理是:

令:p1 = a * ( f - h )

p2 = ( a + b ) *  h

p3 = ( c +d ) * e

p4 = d *  ( g - e )

p5 = ( a + d ) * ( e + h )

p6 =  ( b - d ) * ( g + h ) 

p7 = ( a - c ) * ( e + f )

那麼我們可以知道:

r  = p5 + p4 + p6 - p2

s = p1 + p2

t = p3 + p4

u = p5 + p1 - p3 - p7

我們可以看到上面只有7次乘法和多次加減法,最終達到降低複雜度為o( n^lg7 ) ~= o( n^2.81 );

**實現如下:

[cpp]  

// strassen 演算法:將矩陣相乘的複雜度降到o(n^lg7) ~= o(n^2.81)  

// 原理是將8次乘法減少到7次的處理  

// 現在理論上的最好的演算法是o(n^2,367),僅僅是理論上的而已  

//  

//  

// 下面的**僅僅是簡單的例項而已,不必較真哦,呵呵~  

// 下面的空間可以優化的,此處就不麻煩了~  

#include  

#define  n  10  

//matrix + matrix  

void plus( int t[n/2][n/2], int r[n/2][n/2], int s[n/2][n/2] )  

}  }    

//matrix - matrix  

void minus( int t[n/2][n/2], int r[n/2][n/2], int s[n/2][n/2] )  

}  }    

//matrix * matrix  

void mul( int t[n/2][n/2], int r[n/2][n/2], int s[n/2][n/2]  )  

}  }  

}    

int main()  

}  printf("\ninput the second matrix...:\n");  

for( i = 0; i < n; i++ )  

}  // a b c d e f g h  

for( i = 0; i < n / 2; i++ )  

}  //p1  

minus( r, f, h );  

mul( p1, a, r );   

//p2  

plus( r, a, b );  

mul( p2, r, h );  

//p3  

plus( r, c, d );  

mul( p3, r, e );  

//p4  

minus( r, g, e );  

mul( p4, d, r );  

//p5  

plus( r, a, d );  

plus( s, e, f );  

mul( p5, r, s );  

//p6  

minus( r, b, d );  

plus( s, g, h );  

mul( p6, r, s );  

//p7  

minus( r, a, c );  

plus( s, e, f );  

mul( p7, r, s );  

//r = p5 + p4 - p2 + p6  

plus( t1, p5, p4 );  

minus( t2, t1, p2 );  

plus( r, t2, p6 );  

//s = p1 + p2  

plus( s, p1, p2 );  

//t = p3 + p4  

plus( t, p3, p4 );  

//u = p5 + p1 - p3 - p7 = p5 + p1 - ( p3 + p7 )  

plus( t1, p5, p1 );  

plus( t2, p3, p7 );  

minus( u, t1, t2 );  

for( i = 0; i < n / 2; i++ )  

}  printf("\n下面是strassen演算法處理結果:\n");  

for( i = 0; i < n; i++ )  

printf("\n");  

}  //下面是樸素演算法處理  

printf("\n下面是樸素演算法處理結果:\n");  

for( i = 0; i < n; i++ )  

}  }  

for( i = 0; i < n; i++ )  

printf("\n");  

}  return 0;  

}   

現在最好的計算矩陣乘法的複雜度是o( n^2.376 ),不過只是理論上的結果。此處僅僅做參考~

strassen矩陣乘法 Strassen矩陣乘法

求矩陣a,b相乘的結果c 直接根據矩陣乘法的定義來遍歷計算。c 語言 void matrixmul int a,int b,int c,int m,int b,int n void test3 int b 3 2 int c 2 2 matrixmul int a,int b,int c,2,3,2...

Strassen矩陣乘法

strassen矩陣乘法 strassen矩陣乘法是通過遞迴實現的,它將一般情況下二階矩陣乘法 可擴充套件到 階,但strassen矩陣乘法要求 是 的冪 所需的8次乘法降低為7次,將計算時間從o ne3 降低為o ne2.81 矩陣c ab,可寫為 c11 a11b11 a12b21 c12 a1...

Strassen矩陣乘法

矩陣乘法是線性代數中最常見的運算之一,它在數值計算中有廣泛的應用。若a和b是2個n n的矩陣,則它們的乘積c ab同樣是乙個n n的矩陣。a和b的乘積矩陣c中的元素c i,j 定義為 若依此定義來計算a和b的乘積矩陣c,則每計算c的乙個元素c i,j 需要做n個乘法和n 1次加法。因此,求出矩陣c的...