Strassen algorithm for an n × n matrix when n is odd

499 views Asked by At

I need this for a project of mine. I have a basic understanding of the algorithm; I copied the code below from the internet, but it does not seem to work when the matrix dimension is odd (3, 5, 7, etc.). The code is part of my Matrix class.

private static double[][] multiply(double[][] matrixA, double[][] matrixB) //recursive multiplication function { int n = matrixA.length;

    double[][] MatrixRes = new double[n][n];
    if (n == 1) {
        MatrixRes[0][0] = matrixA[0][0] * matrixB[0][0];
        return MatrixRes;
    }

        double[][] A11 = new double[n / 2][n / 2];
        double[][] A12 = new double[n / 2][n / 2];
        double[][] A21 = new double[n / 2][n / 2];
        double[][] A22 = new double[n / 2][n / 2];

        double[][] B11 = new double[n / 2][n / 2];
        double[][] B12 = new double[n / 2][n / 2];
        double[][] B21 = new double[n / 2][n / 2];
        double[][] B22 = new double[n / 2][n / 2];

        split(matrixA, A11, 0, 0);
        split(matrixA, A12, 0, n / 2);
        split(matrixA, A21, n / 2, 0);
        split(matrixA, A22, n / 2, n / 2);

        split(matrixB, B11, 0, 0);
        split(matrixB, B12, 0, n / 2);
        split(matrixB, B21, n / 2, 0);
        split(matrixB, B22, n / 2, n / 2);

        double[][] M1 = multiply(add(A11, A22), add(B11, B22));

        double[][] M2 = multiply(add(A21, A22), B11);

        double[][] M3 = multiply(A11, sub(B12, B22));

        double[][] M4 = multiply(A22, sub(B21, B11));

        double[][] M5 = multiply(add(A11, A12), B22);

        double[][] M6 = multiply(sub(A21, A11), add(B11, B12));

        double[][] M7 = multiply(sub(A12, A22), add(B21, B22));

        double[][] C11 = add(sub(add(M1, M4), M5), M7);

        double[][] C12 = add(M3, M5);

        double[][] C21 = add(M2, M4);

        double[][] C22 = add(sub(add(M1, M3), M2), M6);

        join(C11, MatrixRes, 0, 0);
        join(C12, MatrixRes, 0, n / 2);
        join(C21, MatrixRes, n / 2, 0);
        join(C22, MatrixRes, n / 2, n / 2);

    return MatrixRes;

}

/**
 * Adds two matrices element by element.
 *
 * <p>The result is square; its order is the row count of {@code matrix2}, and {@code matrix1}
 * is read over the same index range.
 *
 * @param matrix1 left summand
 * @param matrix2 right summand; its row count fixes the order of the result
 * @return a new matrix holding {@code matrix1 + matrix2}
 */
private static double[][] add(double[][] matrix1, double[][] matrix2) {
    final int order = matrix2.length;
    final double[][] total = new double[order][order];
    for (int row = 0; row < order; row++) {
        final double[] out = total[row];
        for (int col = 0; col < order; col++) {
            out[col] = matrix1[row][col] + matrix2[row][col];
        }
    }
    return total;
}

/**
 * Subtracts {@code matrix2} from {@code matrix1} element by element.
 *
 * <p>The result is square; its order is the row count of {@code matrix2}, and {@code matrix1}
 * is read over the same index range.
 *
 * @param matrix1 minuend
 * @param matrix2 subtrahend; its row count fixes the order of the result
 * @return a new matrix holding {@code matrix1 - matrix2}
 */
private static double[][] sub(double[][] matrix1, double[][] matrix2) {
    final int size = matrix2.length;
    final double[][] difference = new double[size][size];
    int row = 0;
    while (row < size) {
        int col = 0;
        while (col < size) {
            difference[row][col] = matrix1[row][col] - matrix2[row][col];
            col++;
        }
        row++;
    }
    return difference;
}

/**
 * Copies a square block out of {@code P} into {@code C} (used when partitioning for
 * multiplication).
 *
 * <p>Every entry of {@code C} is overwritten: {@code C[r][c] = P[iB + r][jB + c]} for
 * {@code 0 <= r, c < C.length}.
 *
 * @param P  source matrix
 * @param C  destination block; its row count sets the block size
 * @param iB first source row
 * @param jB first source column
 */
private static void split(double[][] P, double[][] C, int iB, int jB) {
    final int size = C.length;
    for (int r = 0; r < size; r++) {
        for (int c = 0; c < size; c++) {
            C[r][c] = P[iB + r][jB + c];
        }
    }
}

/**
 * Writes the square block {@code C} into {@code P} at offset ({@code iB}, {@code jB}) (used
 * when reassembling a product).
 *
 * <p>Sets {@code P[iB + r][jB + c] = C[r][c]} for {@code 0 <= r, c < C.length}; other entries
 * of {@code P} are left untouched.
 *
 * @param C  source block; its row count sets the block size
 * @param P  destination matrix
 * @param iB first destination row
 * @param jB first destination column
 */
private static void join(double[][] C, double[][] P, int iB, int jB) {
    final int size = C.length;
    for (int r = 0; r < size; r++) {
        for (int c = 0; c < size; c++) {
            P[iB + r][jB + c] = C[r][c];
        }
    }
}
1

There is 1 answer

1
Muhteva On

Check the matrix multiplication algorithm below. I also added a time-complexity analysis, a naive implementation, and a main function to test the code. It handles the case where n is not a power of two (including odd n) by zero-padding the matrices up to the next power of two.

package Miscellaneous;

/*
 * Strassen's Matrix Multiplication
 * 
 * Due to Volker Strassen (1969).
 * ~ O(n^2.81)
 * 
 * Algorithm Analysis:
 * T(n) = 7 * T(n/2)  - basic operation: multiplication (called 7 times recursively)
 * 
 * T(n) = 7 * 7 * T(n/4)
 *      = 7 * 7 * 7 * T(n/8)
 *        ...
 *      = 7^k * T(n/n)   n = 2^k
 *      = 7^(log2(n)) * 1
 *      = n^(log7 / log2)
 *      = n^2.81
 *  
 * This simplified analysis ignores the additions/subtractions done at each level. Including them
 * gives T(n) = 7 * T(n/2) + O(n^2), which still solves to O(n^(log2 7)) ~ O(n^2.81).
 */
public class MatrixMultiplication {

    /**
     * Multiplies two matrices with the schoolbook triple loop; O(n^3) for n x n operands.
     *
     * <p>Works for any conformable shapes: {@code A} is p x q and {@code B} is q x r.
     *
     * @param A left operand (must have at least one row)
     * @param B right operand
     * @return a new p x r product, or {@code null} when the column count of {@code A} differs
     *     from the row count of {@code B}
     */
    public static double[][] naiveMultiplication(double[][] A, double[][] B) {
        if (A[0].length != B.length) {
            return null; // invalid multiplication: inner dimensions differ
        }

        double[][] C = new double[A.length][B[0].length];
        for (int i = 0; i < C.length; i++) {
            for (int j = 0; j < C[0].length; j++) {
                double sum = 0;
                for (int k = 0; k < A[i].length; k++) {
                    sum += A[i][k] * B[k][j];
                }
                C[i][j] = sum;
            }
        }
        return C;
    }

    /**
     * Multiplies the top-left {@code n x n} blocks of {@code A} and {@code B} with Strassen's
     * algorithm, for any {@code n >= 0}.
     *
     * <p>The recursion needs the order to halve cleanly at every level, i.e. a power of two. Any
     * other order (including odd ones such as 3, 5, 7) is handled by zero-padding both operands
     * up to the next power of two. The padding contributes only zeros, so the top-left
     * {@code n x n} block of the padded product is the true product; that block is returned.
     *
     * @param A left operand; must have at least {@code n} rows and columns
     * @param B right operand; must have at least {@code n} rows and columns
     * @param n order of the product to compute
     * @return a new {@code n x n} matrix equal to {@code A * B}
     * @throws IllegalArgumentException if {@code n} is negative
     */
    public static double[][] StrassenMatrixMultiplication(double[][] A, double[][] B, int n) {
        if (n < 0) {
            throw new IllegalArgumentException("matrix order must be non-negative: " + n);
        }
        if (n == 0) {
            return new double[0][0];
        }
        // Exact integer power-of-two test; comparing floating-point log2 values can be off by an
        // ulp. (Orders above 2^30 would overflow here, but such matrices cannot be allocated.)
        int size = Integer.bitCount(n) == 1 ? n : Integer.highestOneBit(n) << 1;
        if (size == n) {
            return strassen(A, B, n);
        }
        double[][] padded = strassen(padTo(A, n, size), padTo(B, n, size), size);
        // Trim the padding so the caller gets an n x n result, not size x size.
        return partition(padded, 0, n, 0, n, n);
    }

    /**
     * Strassen recursion proper.
     *
     * @param A left operand, at least {@code n x n}
     * @param B right operand, at least {@code n x n}
     * @param n a power of two
     * @return the {@code n x n} product of the top-left {@code n x n} blocks
     */
    private static double[][] strassen(double[][] A, double[][] B, int n) {
        if (n == 1) {
            return new double[][] {{A[0][0] * B[0][0]}};
        }
        int h = n / 2;
        double[][] A11 = partition(A, 0, h, 0, h, h);
        double[][] A12 = partition(A, 0, h, h, n, h);
        double[][] A21 = partition(A, h, n, 0, h, h);
        double[][] A22 = partition(A, h, n, h, n, h);
        double[][] B11 = partition(B, 0, h, 0, h, h);
        double[][] B12 = partition(B, 0, h, h, n, h);
        double[][] B21 = partition(B, h, n, 0, h, h);
        double[][] B22 = partition(B, h, n, h, n, h);

        double[][] I = strassen(matrixOpt(A12, A22, '-'), matrixOpt(B21, B22, '+'), h);   // (A12 - A22)(B21 + B22)
        double[][] II = strassen(matrixOpt(A11, A22, '+'), matrixOpt(B11, B22, '+'), h);  // (A11 + A22)(B11 + B22)
        double[][] III = strassen(matrixOpt(A11, A21, '-'), matrixOpt(B11, B12, '+'), h); // (A11 - A21)(B11 + B12)
        double[][] IV = strassen(matrixOpt(A11, A12, '+'), B22, h);                       // (A11 + A12)B22
        double[][] V = strassen(A11, matrixOpt(B12, B22, '-'), h);                        // A11(B12 - B22)
        double[][] VI = strassen(A22, matrixOpt(B21, B11, '-'), h);                       // A22(B21 - B11)
        double[][] VII = strassen(matrixOpt(A21, A22, '+'), B11, h);                      // (A21 + A22)B11

        double[][] C11 = matrixOpt(matrixOpt(I, II, '+'), matrixOpt(VI, IV, '-'), '+');   // I + II - IV + VI
        double[][] C12 = matrixOpt(IV, V, '+');                                           // IV + V
        double[][] C21 = matrixOpt(VI, VII, '+');                                         // VI + VII
        double[][] C22 = matrixOpt(matrixOpt(II, V, '+'), matrixOpt(III, VII, '+'), '-'); // II - III + V - VII

        return collapse(C11, C12, C21, C22, n);
    }

    /**
     * Copies the top-left {@code n x n} block of {@code src} into a new zero-filled
     * {@code size x size} matrix.
     */
    private static double[][] padTo(double[][] src, int n, int size) {
        double[][] padded = new double[size][size]; // Java arrays start zero-filled
        for (int i = 0; i < n; i++) {
            System.arraycopy(src[i], 0, padded[i], 0, n);
        }
        return padded;
    }

    /**
     * Assembles four {@code n/2 x n/2} quadrants into a new {@code n x n} matrix.
     *
     * @param C11 top-left quadrant
     * @param C12 top-right quadrant
     * @param C21 bottom-left quadrant
     * @param C22 bottom-right quadrant
     * @param n   order of the assembled matrix
     * @return the assembled {@code n x n} matrix
     */
    public static double[][] collapse(double[][] C11, double[][] C12, double[][] C21, double[][] C22, int n) {
        double[][] C = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (i < n / 2 && j < n / 2) C[i][j] = C11[i][j];
                if (i < n / 2 && j >= n / 2) C[i][j] = C12[i][j - n / 2];
                if (i >= n / 2 && j < n / 2) C[i][j] = C21[i - n / 2][j];
                if (i >= n / 2 && j >= n / 2) C[i][j] = C22[i - n / 2][j - n / 2];
            }
        }
        return C;
    }

    /**
     * Copies rows {@code [row1, row2)} and columns {@code [col1, col2)} of {@code src} into the
     * top-left corner of a new zero-filled {@code n x n} matrix.
     */
    public static double[][] partition(double[][] src, int row1, int row2, int col1, int col2, int n) {
        double[][] part = new double[n][n];
        for (int i = row1; i < row2; i++) {
            for (int j = col1; j < col2; j++) {
                part[i - row1][j - col1] = src[i][j];
            }
        }
        return part;
    }

    /**
     * Element-wise combination of two square matrices of {@code A}'s order: {@code A + B} when
     * {@code opr} is {@code '+'}, otherwise {@code A - B}.
     */
    public static double[][] matrixOpt(double[][] A, double[][] B, char opr) {
        double[][] C = new double[A.length][A.length];
        for (int i = 0; i < A.length; i++) {
            for (int j = 0; j < A.length; j++) {
                C[i][j] = opr == '+' ? A[i][j] + B[i][j] : A[i][j] - B[i][j];
            }
        }
        return C;
    }

    /** Demo: multiplies two 3x3 matrices with Strassen's algorithm and prints the product. */
    public static void main(String[] args) {
        double[][] A = {{1, 2, 3},
                        {4, 5, 6},
                        {7, 8, 9}};
        double[][] B = {{9, 8, 7},
                        {6, 5, 4},
                        {3, 2, 1}};

        double[][] res = StrassenMatrixMultiplication(A, B, A.length);

        for (int i = 0; i < res.length; i++) {
            for (int j = 0; j < res.length; j++) {
                System.out.print(res[i][j] + " ");
            }
            System.out.println();
        }
    }

}