We were assigned to implemented the following function for Strassen's algorithm for matrix multiplication in C++, using recursion on the base cases for n=1 and n=2
#include "strassen_mm.hpp"
#include <iomanip>
#include <iostream>
#include <stdexcept>
template <typename T>
static void printMatrix(const vector<vector<T>>& matrix, string header) {
cout << header << "\n";
for (const auto& row : matrix) {
for (const auto& elem : row) {
cout << setw(10) << left << elem;
}
cout << endl << flush;
}
}
template <typename T>
static std::vector<std::vector<T>> add_zero_row_col(
const std::vector<std::vector<T>>& matrix) {
// Create a new matrix with an extra row and column of zeros
std::vector<std::vector<T>> result(matrix.size() + 1,
std::vector<T>(matrix[0].size() + 1, 0));
// Copy the original matrix into the new matrix
for (size_t i = 0; i < matrix.size(); i++) {
for (size_t j = 0; j < matrix[0].size(); j++) {
result[i][j] = matrix[i][j];
}
}
return result;
}
template <typename T>
static void remove_last_row_and_column(std::vector<std::vector<T>>& matrix) {
if (matrix.empty() || matrix[0].empty()) {
// Matrix is already empty
return;
}
matrix.pop_back(); // Remove last row
for (auto& row : matrix) {
if (!row.empty()) {
row.pop_back(); // Remove last column from each row
}
}
}
template<typename T> static
vector<vector<T>> operator-(const vector<vector<T>>& a, const vector<vector<T>>& b) {
// Check if the matrices have the same dimensions
if (a.size() != b.size() || a[0].size() != b[0].size()) {
throw std::invalid_argument("Matrices must have the same dimensions.");
}
// Create a new matrix to hold the result
vector<vector<T>> result(a.size(), vector<T>(a[0].size()));
// Subtract each element of the matrices and store the result in the new matrix
for (int i = 0; i < a.size(); i++) {
for (int j = 0; j < a[0].size(); j++) {
result[i][j] = a[i][j] - b[i][j];
}
}
return result;
}
template <typename T>
static vector<vector<T>> operator*(const vector<vector<T>>& a,
const vector<vector<T>>& b) {
// Get the dimensions of the matrices
int rows_a = a.size();
int cols_a = a[0].size();
int rows_b = b.size();
int cols_b = b[0].size();
// Make sure the matrices can be multiplied
if (cols_a != rows_b) {
cout << "Error: cannot multiply matrices of these dimensions" << endl;
return {};
}
// Create the result matrix
vector<vector<T>> result(rows_a, vector<T>(cols_b));
// Multiply the matrices
for (int i = 0; i < rows_a; i++) {
for (int j = 0; j < cols_b; j++) {
T sum = 0;
for (int k = 0; k < cols_a; k++) {
sum += a[i][k] * b[k][j];
}
result[i][j] = sum;
}
}
return result;
}
// summing up 2 matrices
template <typename T>
static vector<vector<T>> operator+(const vector<vector<T>>& matrix1,
const vector<vector<T>>& matrix2) {
// check that the matrices have the same size
if (matrix1.size() != matrix2.size() ||
matrix1[0].size() != matrix2[0].size()) {
throw runtime_error("Error: matrices have different sizes");
}
// add the matrices element-wise
vector<vector<T>> result(matrix1.size(), vector<T>(matrix1[0].size(), 0));
for (int i = 0; i < matrix1.size(); i++) {
for (int j = 0; j < matrix1[0].size(); j++) {
result[i][j] = matrix1[i][j] + matrix2[i][j];
}
}
// return the result
return result;
}
template <typename T>
static vector<vector<T>> strassen_mm_internal(const vector<vector<T>>& A,
const vector<vector<T>>& B,
bool removelast = false) {
if (A.size() == 0 || B.size() == 0) return vector<vector<T>>();
if (A.size() != A[0].size() || B.size() != B[0].size() ||
A.size() != B.size()) {
throw runtime_error("strassen_mm size check failed ");
}
auto n = A.size();
// base case for n=1
if (n == 1) {
return vector<vector<T>>(1, vector<T>(1, A[0][0] * B[0][0]));
}
// base cae for n=2
if (n == 2) {
vector<vector<T>> C(2, vector<T>(2, 0));
T p1 = A[0][0] * (B[0][1] - B[1][1]);
T p2 = (A[0][0] + A[0][1]) * B[1][1];
T p3 = (A[1][0] + A[1][1]) * B[0][0];
T p4 = A[1][1] * (B[1][0] - B[0][0]);
T p5 = (A[0][0] + A[1][1]) * (B[0][0] + B[1][1]);
T p6 = (A[0][1] - A[1][1]) * (B[1][0] + B[1][1]);
T p7 = (A[0][0] - A[1][0]) * (B[0][0] + B[0][1]);
C[0][0] = p5 + p4 - p2 + p6;
C[0][1] = p1 + p2;
C[1][0] = p3 + p4;
C[1][1] = p5 + p1 - p3 - p7;
return C;
}
// padd with zero row and zero column if n is odd
if (n % 2 == 1) {
auto Anew = add_zero_row_col(A);
auto Bnew = add_zero_row_col(B);
return strassen_mm_internal(Anew, Bnew, true);
}
// Divide the matrices into four smaller submatrices
int half = n / 2;
vector<vector<T>> A11(half, vector<T>(half));
vector<vector<T>> A12(half, vector<T>(half));
vector<vector<T>> A21(half, vector<T>(half));
vector<vector<T>> A22(half, vector<T>(half));
vector<vector<T>> B11(half, vector<T>(half));
vector<vector<T>> B12(half, vector<T>(half));
vector<vector<T>> B21(half, vector<T>(half));
vector<vector<T>> B22(half, vector<T>(half));
for (int i = 0; i < half; i++) {
for (int j = 0; j < half; j++) {
A11[i][j] = A[i][j];
A12[i][j] = A[i][j + half];
A21[i][j] = A[i + half][j];
A22[i][j] = A[i + half][j + half];
B11[i][j] = B[i][j];
B12[i][j] = B[i][j + half];
B21[i][j] = B[i + half][j];
B22[i][j] = B[i + half][j + half];
}
}
auto P1 = strassen_mm_internal((A11 + A22),(B11 + B22));
auto P2 = strassen_mm_internal((A21 + A22),B11);
auto P3 = strassen_mm_internal(A11,(B12 - B22));
auto P4 = strassen_mm_internal(A22,(B21 - B11));
auto P5 = strassen_mm_internal((A11+A12),B22);
auto P6 = strassen_mm_internal((A21-A11),(B11+B12));
auto P7 = strassen_mm_internal((A12 -A22),(B21+B22));
auto C11= P1+P4-P5+P7;
auto C12 = P3 +P5;
auto C21 = P2+P4;
auto C22 = P1 + P3 -P2 +P6;
auto size = removelast ? n - 1 : n;
vector<vector<T>> C(size, vector<T>(size));
// Combine the four submatrices into the final result matrix
for (int i = 0; i < n / 2; i++) {
for (int j = 0; j < n / 2; j++) {
C[i][j] = C11[i][j];
if (j + n / 2 < size) C[i][j + n / 2] = C12[i][j];
if (i + n / 2 < size) C[i + n / 2][j] = C21[i][j];
if (i + n / 2 < size && j + n / 2 < size)
C[i + n / 2][j + n / 2] = C22[i][j];
}
}
return C;
}
template <typename T>
vector<vector<T>> strassen_mm(const vector<vector<T>>& A,
const vector<vector<T>>& B) {
return strassen_mm_internal(A, B);
}
template vector<vector<double>> strassen_mm(const vector<vector<double>>& A,
const vector<vector<double>>& B);
template vector<vector<float>> strassen_mm(const vector<vector<float>>& A,
const vector<vector<float>>& B);
I get the correct results but the code performance is very slow (even with -O3 optimization ) - maybe 50 times slower than the simplest O(n^3) implementation of matrix multiplication. It is probably because I'm copying data back and forth. Any idea how to improve the performance ?