Strassen's algorithm for matrix multiplication in C++

349 views Asked by At

We were assigned to implement the following function for Strassen's algorithm for matrix multiplication in C++, recursing down to the base cases n=1 and n=2:

#include "strassen_mm.hpp"

#include <iomanip>
#include <iostream>
#include <stdexcept>

// Writes `header` on its own line, then `matrix` row by row, to std::cout.
// Every element is printed left-aligned in a 10-character field and each row
// ends with a newline followed by a flush.
//
// matrix: rows to print; rows may have different lengths.
// header: title line written before the first row (taken by reference so the
//         string is not copied on every call).
//
// Side effect: std::left remains set on std::cout after the call.
template <typename T>
static void printMatrix(const std::vector<std::vector<T>>& matrix,
                        const std::string& header) {
    std::cout << header << "\n";
    for (const auto& row : matrix) {
        for (const auto& elem : row) {
            std::cout << std::setw(10) << std::left << elem;
        }
        // A single explicit flush per row: std::endl already flushes, so
        // following it with std::flush flushed the stream twice.
        std::cout << '\n' << std::flush;
    }
}

// Returns a copy of `matrix` grown by one trailing row and one trailing
// column, both filled with zeros.
//
// The width of the result is taken from the first row. An empty input yields
// a 1x1 zero matrix (previously it dereferenced matrix[0], which is undefined
// behaviour). Rows shorter than the first row are zero-filled instead of
// being read out of bounds; entries beyond the first row's width are ignored.
template <typename T>
static std::vector<std::vector<T>> add_zero_row_col(
    const std::vector<std::vector<T>>& matrix) {
    const std::size_t rows = matrix.size();
    const std::size_t cols = matrix.empty() ? 0 : matrix[0].size();

    // Start from an all-zero (rows + 1) x (cols + 1) matrix so that only the
    // original entries need to be written.
    std::vector<std::vector<T>> result(rows + 1,
                                       std::vector<T>(cols + 1, T(0)));

    for (std::size_t i = 0; i < rows; ++i) {
        // Never read past the end of a short row.
        const std::size_t width =
            matrix[i].size() < cols ? matrix[i].size() : cols;
        for (std::size_t j = 0; j < width; ++j) {
            result[i][j] = matrix[i][j];
        }
    }

    return result;
}

// Shrinks `matrix` in place by dropping its final row and then the final
// element of every row that is left. A matrix with no rows, or whose first
// row is empty, is left untouched. Remaining rows that are already empty are
// skipped rather than popped.
template <typename T>
static void remove_last_row_and_column(std::vector<std::vector<T>>& matrix) {
    const bool has_content = !matrix.empty() && !matrix.front().empty();
    if (has_content) {
        // Drop the bottom row first so it is not trimmed needlessly.
        matrix.pop_back();

        for (auto it = matrix.begin(); it != matrix.end(); ++it) {
            if (it->empty()) {
                continue;
            }
            it->pop_back();
        }
    }
}

// Element-wise difference a - b of two matrices stored as vectors of rows.
//
// Returns a matrix with the same shape as `a`; two empty matrices give an
// empty result (previously this case indexed a[0] on an empty vector).
// Throws std::invalid_argument when the row counts differ or when any pair of
// corresponding rows has different lengths. Every row is checked, not just
// the first, so ragged input can no longer be read out of bounds.
template <typename T>
static std::vector<std::vector<T>> operator-(
    const std::vector<std::vector<T>>& a,
    const std::vector<std::vector<T>>& b) {
    if (a.size() != b.size()) {
        throw std::invalid_argument("Matrices must have the same dimensions.");
    }

    std::vector<std::vector<T>> result(a.size());
    for (std::size_t i = 0; i < a.size(); ++i) {
        if (a[i].size() != b[i].size()) {
            throw std::invalid_argument(
                "Matrices must have the same dimensions.");
        }

        // Reserve once and append, so each entry is written exactly once.
        result[i].reserve(a[i].size());
        for (std::size_t j = 0; j < a[i].size(); ++j) {
            result[i].push_back(a[i][j] - b[i][j]);
        }
    }

    return result;
}

// Conventional triple-loop matrix product a * b.
//
// `a` is treated as rows_a x cols_a and `b` as rows_b x cols_b, with the
// column counts taken from each operand's first row (0 for an operand with
// no rows, which previously indexed element 0 of an empty vector).
//
// When the inner dimensions disagree, or when either operand has rows of
// differing length, an error line is written to std::cout and an empty matrix
// is returned. This matches the original error path; unlike operator+ and
// operator- in this file, nothing is thrown.
template <typename T>
static std::vector<std::vector<T>> operator*(
    const std::vector<std::vector<T>>& a,
    const std::vector<std::vector<T>>& b) {
    // Keep dimensions as size_t: narrowing to int could truncate and forced
    // signed/unsigned comparisons in the loops.
    const std::size_t rows_a = a.size();
    const std::size_t cols_a = a.empty() ? 0 : a[0].size();
    const std::size_t rows_b = b.size();
    const std::size_t cols_b = b.empty() ? 0 : b[0].size();

    // Both operands must be rectangular and have matching inner dimensions,
    // otherwise the loops below would index past the end of a row.
    bool compatible = (cols_a == rows_b);
    for (std::size_t i = 0; compatible && i < rows_a; ++i) {
        compatible = (a[i].size() == cols_a);
    }
    for (std::size_t k = 0; compatible && k < rows_b; ++k) {
        compatible = (b[k].size() == cols_b);
    }
    if (!compatible) {
        std::cout << "Error: cannot multiply matrices of these dimensions"
                  << std::endl;
        return {};
    }

    std::vector<std::vector<T>> result(rows_a, std::vector<T>(cols_b));

    for (std::size_t i = 0; i < rows_a; ++i) {
        for (std::size_t j = 0; j < cols_b; ++j) {
            T sum = T(0);
            for (std::size_t k = 0; k < cols_a; ++k) {
                sum += a[i][k] * b[k][j];
            }
            result[i][j] = sum;
        }
    }

    return result;
}

// Element-wise sum matrix1 + matrix2 of two matrices stored as vectors of
// rows.
//
// Returns a matrix with the same shape as `matrix1`; two empty matrices give
// an empty result (previously this case indexed matrix1[0] on an empty
// vector). Throws std::runtime_error when the row counts differ or when any
// pair of corresponding rows has different lengths. Every row is checked, not
// just the first, so ragged input can no longer be read out of bounds.
template <typename T>
static std::vector<std::vector<T>> operator+(
    const std::vector<std::vector<T>>& matrix1,
    const std::vector<std::vector<T>>& matrix2) {
    if (matrix1.size() != matrix2.size()) {
        throw std::runtime_error("Error: matrices have different sizes");
    }

    std::vector<std::vector<T>> result(matrix1.size());
    for (std::size_t i = 0; i < matrix1.size(); ++i) {
        if (matrix1[i].size() != matrix2[i].size()) {
            throw std::runtime_error("Error: matrices have different sizes");
        }

        // Reserve once and append, so each entry is written exactly once.
        result[i].reserve(matrix1[i].size());
        for (std::size_t j = 0; j < matrix1[i].size(); ++j) {
            result[i].push_back(matrix1[i][j] + matrix2[i][j]);
        }
    }

    return result;
}


// Recursive core of the Strassen product. Expects A and B to be non-empty
// n x n matrices (the caller validates this once, so the recursion does not
// re-check every sub-problem).
//
// trim_last: when n is even, drop the last row and column of the result.
//            Odd n always yields an n x n result.
//
// Compared with building explicit sub-matrices, this version
//  * pads odd sizes implicitly: reads outside the n x n input return zero, so
//    no enlarged copies of A and B are made;
//  * forms each Strassen operand (e.g. A11 + A22) straight from the parent
//    matrix instead of copying both quadrants first and adding afterwards;
//  * writes the four result quadrants directly into C without temporaries.
// The arithmetic performed on the elements, and its order, is unchanged.
//
// NOTE(review): the recursion still goes all the way down to the 1x1 / 2x2
// base cases, so the number of small heap allocations remains the main cost.
// Switching to the conventional product below some cutoff size is the usual
// remedy, if the stated base-case requirement allows it.
template <typename T>
static std::vector<std::vector<T>> strassen_mm_unchecked(
    const std::vector<std::vector<T>>& A,
    const std::vector<std::vector<T>>& B, bool trim_last) {
    using Matrix = std::vector<std::vector<T>>;
    const std::size_t n = A.size();

    // Base case for n = 1.
    if (n == 1) {
        return Matrix(1, std::vector<T>(1, A[0][0] * B[0][0]));
    }

    // Base case for n = 2: Strassen's seven scalar products.
    if (n == 2) {
        Matrix C(2, std::vector<T>(2, T(0)));
        const T p1 = A[0][0] * (B[0][1] - B[1][1]);
        const T p2 = (A[0][0] + A[0][1]) * B[1][1];
        const T p3 = (A[1][0] + A[1][1]) * B[0][0];
        const T p4 = A[1][1] * (B[1][0] - B[0][0]);
        const T p5 = (A[0][0] + A[1][1]) * (B[0][0] + B[1][1]);
        const T p6 = (A[0][1] - A[1][1]) * (B[1][0] + B[1][1]);
        const T p7 = (A[0][0] - A[1][0]) * (B[0][0] + B[0][1]);
        C[0][0] = p5 + p4 - p2 + p6;
        C[0][1] = p1 + p2;
        C[1][0] = p3 + p4;
        C[1][1] = p5 + p1 - p3 - p7;
        return C;
    }

    // Odd sizes are treated as if one zero row and column had been appended.
    const bool odd = (n % 2 == 1);
    const std::size_t padded = odd ? n + 1 : n;
    const std::size_t half = padded / 2;
    const std::size_t out_size = (odd || trim_last) ? padded - 1 : padded;

    // Quadrant ids and their top-left corners inside the padded matrix.
    enum Quadrant { Q11 = 0, Q12 = 1, Q21 = 2, Q22 = 3 };
    const std::size_t row0[4] = {0, 0, half, half};
    const std::size_t col0[4] = {0, half, 0, half};

    // Element (i, j) of quadrant q of m; zero where the padding would be.
    const auto at = [&](const Matrix& m, Quadrant q, std::size_t i,
                        std::size_t j) -> T {
        const std::size_t r = row0[q] + i;
        const std::size_t c = col0[q] + j;
        return (r < n && c < n) ? m[r][c] : T(0);
    };
    // Copy of a single quadrant.
    const auto quad = [&](const Matrix& m, Quadrant q) {
        Matrix out(half, std::vector<T>(half));
        for (std::size_t i = 0; i < half; ++i) {
            for (std::size_t j = 0; j < half; ++j) {
                out[i][j] = at(m, q, i, j);
            }
        }
        return out;
    };
    // Sum of two quadrants of the same matrix.
    const auto sum = [&](const Matrix& m, Quadrant p, Quadrant q) {
        Matrix out(half, std::vector<T>(half));
        for (std::size_t i = 0; i < half; ++i) {
            for (std::size_t j = 0; j < half; ++j) {
                out[i][j] = at(m, p, i, j) + at(m, q, i, j);
            }
        }
        return out;
    };
    // Difference (p - q) of two quadrants of the same matrix.
    const auto diff = [&](const Matrix& m, Quadrant p, Quadrant q) {
        Matrix out(half, std::vector<T>(half));
        for (std::size_t i = 0; i < half; ++i) {
            for (std::size_t j = 0; j < half; ++j) {
                out[i][j] = at(m, p, i, j) - at(m, q, i, j);
            }
        }
        return out;
    };

    // The seven recursive products.
    const Matrix P1 =
        strassen_mm_unchecked(sum(A, Q11, Q22), sum(B, Q11, Q22), false);
    const Matrix P2 =
        strassen_mm_unchecked(sum(A, Q21, Q22), quad(B, Q11), false);
    const Matrix P3 =
        strassen_mm_unchecked(quad(A, Q11), diff(B, Q12, Q22), false);
    const Matrix P4 =
        strassen_mm_unchecked(quad(A, Q22), diff(B, Q21, Q11), false);
    const Matrix P5 =
        strassen_mm_unchecked(sum(A, Q11, Q12), quad(B, Q22), false);
    const Matrix P6 =
        strassen_mm_unchecked(diff(A, Q21, Q11), sum(B, Q11, Q12), false);
    const Matrix P7 =
        strassen_mm_unchecked(diff(A, Q12, Q22), sum(B, Q21, Q22), false);

    // Assemble C quadrant by quadrant; positions that fall in the trimmed
    // last row/column are skipped.
    Matrix C(out_size, std::vector<T>(out_size));
    for (std::size_t i = 0; i < half; ++i) {
        const bool lower_row_kept = (i + half < out_size);
        for (std::size_t j = 0; j < half; ++j) {
            const bool right_col_kept = (j + half < out_size);

            C[i][j] = P1[i][j] + P4[i][j] - P5[i][j] + P7[i][j];
            if (right_col_kept) {
                C[i][j + half] = P3[i][j] + P5[i][j];
            }
            if (lower_row_kept) {
                C[i + half][j] = P2[i][j] + P4[i][j];
            }
            if (lower_row_kept && right_col_kept) {
                C[i + half][j + half] =
                    P1[i][j] + P3[i][j] - P2[i][j] + P6[i][j];
            }
        }
    }

    return C;
}

// Multiplies two square matrices of equal size with Strassen's algorithm.
//
// A, B:        n x n matrices stored as vectors of rows.
// removelast:  when true and n is even (n > 2), the result is returned
//              without its last row and column; it has no effect for odd n
//              or for the 1x1 / 2x2 base cases.
//
// Returns an empty matrix if either operand has no rows.
// Throws std::runtime_error if the operands are not square or differ in
// size. Every row is checked, so ragged input is rejected instead of being
// read out of bounds.
template <typename T>
static std::vector<std::vector<T>> strassen_mm_internal(
    const std::vector<std::vector<T>>& A,
    const std::vector<std::vector<T>>& B,
    bool removelast = false) {
    if (A.empty() || B.empty()) return std::vector<std::vector<T>>();

    const std::size_t n = A.size();
    bool square = (B.size() == n);
    for (std::size_t i = 0; square && i < n; ++i) {
        square = (A[i].size() == n && B[i].size() == n);
    }
    if (!square) {
        throw std::runtime_error("strassen_mm size check failed ");
    }

    return strassen_mm_unchecked(A, B, removelast);
}



// Public entry point for the Strassen product of A and B. The operands are
// handed to strassen_mm_internal unchanged (with its default for the third
// argument) and its result is returned as-is, so any size-check failure it
// throws propagates to the caller.
template <typename T>
vector<vector<T>> strassen_mm(const vector<vector<T>>& A,
                              const vector<vector<T>>& B) {
    vector<vector<T>> product = strassen_mm_internal(A, B);
    return product;
}

// Explicit instantiations: strassen_mm is defined in this translation unit,
// so the element types offered to other translation units are instantiated
// here. Only double and float are provided.
template vector<vector<double>> strassen_mm(const vector<vector<double>>& A,
                                            const vector<vector<double>>& B);

template vector<vector<float>> strassen_mm(const vector<vector<float>>& A,
                                           const vector<vector<float>>& B);

I get the correct results, but the code is very slow (even with -O3 optimization) - maybe 50 times slower than the simplest O(n^3) implementation of matrix multiplication. It is probably because I'm copying data back and forth. Any ideas on how to improve the performance?

0

There are 0 answers