對於矩陣乘法
c = a×b
,通常的做法是將矩陣進行分塊相乘,如下圖所示:
從上圖可以看出這種分塊相乘總共用了
8次乘法,當然對於子矩陣相乘(如a0×
b0),還可以繼續遞迴使用分塊相乘。對於中小矩陣來說,很適合使用這種分塊乘法,但是對於大矩陣來說,遞迴的次數較多,如果能減少每次分塊乘法的次數,那麼效能將可以得到很好的提高。
strassen
矩陣乘法就是採用了乙個簡單的運算技巧,將上面的
8次矩陣相乘變成了
7次乘法,看別小看這減少的
1次乘法,因為每遞迴
1次,效能就提高了
1/8,比如對於
1024*1024
的矩陣,第
1次先分解成7次
512*512
的矩陣相乘,對於
512*512
的矩陣,又可以繼續遞迴分解成
256*256
的矩陣相乘,
…,一直遞迴下去,假設分解到
64*64
的矩陣大小後就不再遞迴,那麼所花的時間將是分塊矩陣乘法的
(7/8) * (7/8) * (7/8) * (7/8) = 0.586
倍,提高了快接近一倍。當然這是理論上的值,因為實際上
strassen
乘法增加了其他運算開銷,實際效能會略低一點。
由上可見,strassen矩陣乘法是通過遞迴實現的,它將一般情況下二階矩陣乘法(可擴充套件到n階,但strassen矩陣乘法要求n是2的冪)所需的8次乘法降低為7次,其c++實現**如下:
下面就是
strassen
矩陣乘法的實現方法,
m1 = (a0 + a3) × (b0 + b3)
m2 = (a2 + a3) × b0
m3 = a0 × (b1 - b3)
m4 = a3 × (b2 - b0)
m5 = (a0 + a1) × b3
m6 = (a2 - a0) × (b0 + b1)
m7 = (a1 - a3) × (b2 + b3)
c0 = m1 + m4 - m5 + m7
c1 = m3 + m5
c2 = m2 + m4
c3 = m1 - m2 + m3 + m6
在求解m1,m2,m3,m4,m5,m6,m7
時需要使用
7次矩陣乘法,其他都是矩陣加法和減法。
下面看看
strassen
矩陣乘法的序列實現偽**:
serial_strassenmultiply(a, b, c)
#includeusing
namespace std;
const
int n = 6; //
define the size of the matrix
template
void strassen(int n, t a[n], t b[n], t c[n]);
template
void input(int n, t p[n]);
template
void output(int n, t c[n]);
int main()
template
void input(int n, t p[n])
}
}
template
void output(int n, t c[n])
}
}
}
template
void matrix_add(int n, t x[n], t y[n], t z[n]) else
}
//calculate m1 = (a0 + a3) × (b0 + b3)
matrix_add(n/2, a11, a22, aa);
matrix_add(n/2, b11, b22, bb);
strassen(n/2, aa, bb, m1);
//calculate m2 = (a2 + a3) × b0
matrix_add(n/2, a21, a22, aa);
strassen(n/2, aa, b11, m2);
//calculate m3 = a0 × (b1 - b3)
matrix_sub(n/2, b12, b22, bb);
strassen(n/2, a11, bb, m3);
//calculate m4 = a3 × (b2 - b0)
matrix_sub(n/2, b21, b11, bb);
strassen(n/2, a22, bb, m4);
//calculate m5 = (a0 + a1) × b3
matrix_add(n/2, a11, a12, aa);
strassen(n/2, aa, b22, m5);
//calculate m6 = (a2 - a0) × (b0 + b1)
matrix_sub(n/2, a21, a11, aa);
matrix_add(n/2, b11, b12, bb);
strassen(n/2, aa, bb, m6);
//calculate m7 = (a1 - a3) × (b2 + b3)
matrix_sub(n/2, a12, a22, aa);
matrix_add(n/2, b21, b22, bb);
strassen(n/2, aa, bb, m7);
//calculate c0 = m1 + m4 - m5 + m7
matrix_add(n/2, m1, m4, aa);
matrix_sub(n/2, m7, m5, bb);
matrix_add(n/2, aa, bb, c11);
//calculate c1 = m3 + m5
matrix_add(n/2, m3, m5, c12);
//calculate c2 = m2 + m4
matrix_add(n/2, m2, m4, c21);
//calculate c3 = m1 - m2 + m3 + m6
matrix_sub(n/2, m1, m2, aa);
matrix_add(n/2, m3, m6, bb);
matrix_add(n/2, aa, bb, c22);
//set the result to c[n]
for(int i=0; i2; i++)
}
}
}
演算法導論之演算法基礎(三)
插入排序 模擬 如果你會玩鬥地主,那麼摸牌後按從小到大插入,你這樣插入的過程就是插入排序 程式 在程式中的玩法就像有乙個人發牌,發齊了再拿牌,也就是一開始你就有17張牌,這17張牌對應17個元素的陣列。你從第二種牌開始進行調動,如果第二張牌比第一張牌小,那麼就把第二張牌抽出來,然後把第一張牌放入到第...
重溫演算法導論(三) 氣泡排序
氣泡排序原理簡單,從最後的元素與前面的元素比較,小於則交換,最後最小的在最左邊 偽 實現如下 for i 1 to length a 實際陣列的下標從0開始 do for j length a downto i 1 do if a j a j 1 then exchange a j a j 1 實際...
演算法導論 隨機演算法
一.概率分布 對於有些問題本身是屬於概率問題,如僱傭問題 對於此類問題,我們需要利用概率分析來得到演算法的執行時間,有時也用來分析其他的量。例如,僱傭問題中的費用問題也需要結合概率分析來計算得到。為了使用概率分析,我們必須使用或者假設已知關於輸入的概率分布,然後通過分析該演算法計算出平均情況下的執行...