Why is my Strassen Matrix multiplier so fast?


As an experiment, I implemented the Strassen matrix multiplication algorithm to see whether it truly leads to faster code for large n.

https://github.com/wcochran/strassen_multiplier/blob/master/mm.c

To my surprise it was way faster for large n. For example, the n=1024 case took 17.20 seconds using the conventional method whereas it only took 1.13 seconds using the Strassen method (2x2.66 GHz Xeon). What -- a 15x speedup!? It should only be marginally faster. In fact, it seemed to be as good for even small 32x32 matrices!?
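A rough count helps size that expectation. It assumes the Strassen recursion runs all the way down to 1x1 blocks and counts only scalar multiplications, ignoring the extra matrix additions Strassen needs:

```latex
\begin{aligned}
\text{conventional: } & n^3 = 1024^3 = 2^{30} = 1\,073\,741\,824 \\
\text{Strassen: } & 7^{\log_2 n} = 7^{10} = 282\,475\,249 \\
\text{ratio: } & \frac{2^{30}}{7^{10}} = \left(\frac{8}{7}\right)^{10} \approx 3.80
\end{aligned}
```

Even this idealized count gives under 4x fewer multiplications, far from 15x. Practical implementations stop the recursion at larger blocks and pay for the extra additions, so the gain from the algorithm alone is smaller still.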

The only way I can explain this much of a speed-up is that my algorithm is more cache-friendly -- i.e., it focuses on small pieces of the matrices and thus the data is more localized. Maybe I should be doing all my matrix arithmetic piecemeal when possible.

Any other theories on why this is so fast?


There are 3 answers

primfaktor

What is the loop order in your conventional multiplication? If you have

for (int i = 0; i < new_height; ++i)
{
    for (int j = 0; j < new_width; ++j)
    {
        double sum = 0.0;
        for (int k = 0; k < common; ++k)
        {
            sum += lhs[i * common + k] * rhs[k * new_width + j];
        }
        product[i * new_width + j] = sum;
    }
}

then you're not being very nice to the cache, because the inner loop accesses the right-hand-side matrix in a non-contiguous manner (jumping new_width elements between consecutive reads). After reordering to

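/* product must start zeroed, because the loop below accumulates with += */
for (int idx = 0; idx < new_height * new_width; ++idx)
{
    product[idx] = 0.0;
}

for (int i = 0; i < new_height; ++i)
{
    for (int k = 0; k < common; ++k)
    {
        double const fixed = lhs[i * common + k];
        for (int j = 0; j < new_width; ++j)
        {
            product[i * new_width + j] += fixed * rhs[k * new_width + j];
        }
    }
}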

the accesses to two of the matrices in the innermost loop are contiguous, and the third operand is even fixed. A good compiler would probably hoist the fixed value out of the loop automatically, but I chose to pull it out explicitly for demonstration.

You didn't specify the language, but for C++, some advanced compilers even recognize the cache-unfriendly loop order in certain configurations and reorder the loops.
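For anyone who wants to time the two loop orders against each other, here is a minimal C sketch that wraps both snippets as functions. It assumes the same row-major layout and variable names as above; the function names mm_ijk and mm_ikj are hypothetical. Both accumulate over k in the same order, so for integer-valued inputs they return identical results.

```c
#include <assert.h>

/* Row-major layout: lhs is new_height x common, rhs is common x new_width,
   product is new_height x new_width. */

/* i-j-k order: the innermost loop reads rhs down a column, jumping
   new_width doubles between consecutive accesses. */
void mm_ijk(int new_height, int new_width, int common,
            const double *lhs, const double *rhs, double *product)
{
    assert(product != lhs && product != rhs); /* output must not alias inputs */
    for (int i = 0; i < new_height; ++i) {
        for (int j = 0; j < new_width; ++j) {
            double sum = 0.0;
            for (int k = 0; k < common; ++k)
                sum += lhs[i * common + k] * rhs[k * new_width + j];
            product[i * new_width + j] = sum;
        }
    }
}

/* i-k-j order: the innermost loop walks rhs and product along a row
   (stride 1), while lhs[i][k] stays fixed. */
void mm_ikj(int new_height, int new_width, int common,
            const double *lhs, const double *rhs, double *product)
{
    assert(product != lhs && product != rhs);
    /* Zero the output first: the loop below accumulates with +=. */
    for (int idx = 0; idx < new_height * new_width; ++idx)
        product[idx] = 0.0;
    for (int i = 0; i < new_height; ++i) {
        for (int k = 0; k < common; ++k) {
            double const fixed = lhs[i * common + k];
            for (int j = 0; j < new_width; ++j)
                product[i * new_width + j] += fixed * rhs[k * new_width + j];
        }
    }
}
```

Calling both on the same inputs and timing them at a large size such as n=1024 isolates how much of the gap to Strassen comes from loop order alone.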

phkahler

The first question is "are the results correct?". If so, it's likely that your "conventional" method is not a good implementation.
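One way to check is to compare the Strassen output against the conventional output element by element. Here is a minimal sketch, assuming both results are n x n row-major arrays of double; max_abs_diff is a hypothetical helper, not something in mm.c. Strassen adds and subtracts in a different order, so its result is generally not bit-identical to the conventional one, and the comparison should use a tolerance.

```c
#include <assert.h>
#include <stddef.h>

/* Largest absolute element-wise difference between two n x n row-major
   matrices. A small value (relative to the magnitude of the entries)
   means the two products agree up to rounding. */
double max_abs_diff(size_t n, const double *a, const double *b)
{
    assert(a != NULL && b != NULL);
    double worst = 0.0;
    for (size_t i = 0; i < n * n; ++i) {
        double d = a[i] > b[i] ? a[i] - b[i] : b[i] - a[i];
        if (d > worst)
            worst = d;
    }
    return worst;
}
```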

The conventional method is not to use three nested for loops that scan the inputs in the order you learned in math class. One simple improvement is to transpose the right-hand matrix so that its columns, rather than its rows, are contiguous in memory. Modify the multiply loop to use this alternate layout, and it will run much faster on a large matrix.
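A minimal sketch of the transposed layout, assuming square n x n row-major matrices of double; mm_transposed is a hypothetical name. The transpose costs O(n^2) extra memory and time, which is small next to the O(n^3) multiply.

```c
#include <assert.h>
#include <stdlib.h>

/* c = a * b for n x n row-major matrices. b is first copied into
   bt = transpose(b), so the inner loop reads a row of a and a row of bt,
   both contiguous in memory. Returns 0 on success, -1 if malloc fails. */
int mm_transposed(size_t n, const double *a, const double *b, double *c)
{
    assert(c != a && c != b); /* output must not alias the inputs */
    double *bt = malloc(n * n * sizeof *bt);
    if (bt == NULL)
        return -1;
    for (size_t i = 0; i < n; ++i)
        for (size_t j = 0; j < n; ++j)
            bt[j * n + i] = b[i * n + j];

    for (size_t i = 0; i < n; ++i) {
        for (size_t j = 0; j < n; ++j) {
            double sum = 0.0;
            for (size_t k = 0; k < n; ++k)
                sum += a[i * n + k] * bt[j * n + k]; /* both stride 1 */
            c[i * n + j] = sum;
        }
    }
    free(bt);
    return 0;
}
```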

The standard matrix libraries implement much more cache-friendly methods that take the size of the data cache into account.
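The cache-size-aware idea can be illustrated with loop tiling. This is a simplified sketch, not how any particular library does it; the tile size bs is a hypothetical tuning knob you would pick by experiment so that a few bs x bs tiles of double fit in cache, and mm_tiled is a hypothetical name.

```c
#include <assert.h>
#include <stddef.h>

/* Blocked (tiled) product c = a * b for n x n row-major matrices.
   n need not be a multiple of bs; edge tiles are simply smaller. */
void mm_tiled(size_t n, size_t bs, const double *a, const double *b, double *c)
{
    assert(bs > 0);
    assert(c != a && c != b); /* output must not alias the inputs */
    for (size_t i = 0; i < n * n; ++i)
        c[i] = 0.0;

    for (size_t ii = 0; ii < n; ii += bs)
        for (size_t kk = 0; kk < n; kk += bs)
            for (size_t jj = 0; jj < n; jj += bs) {
                size_t i_end = ii + bs < n ? ii + bs : n;
                size_t k_end = kk + bs < n ? kk + bs : n;
                size_t j_end = jj + bs < n ? jj + bs : n;
                /* Multiply one tile of a by one tile of b and accumulate
                   into the matching tile of c (i-k-j order inside). */
                for (size_t i = ii; i < i_end; ++i)
                    for (size_t k = kk; k < k_end; ++k) {
                        double aik = a[i * n + k];
                        for (size_t j = jj; j < j_end; ++j)
                            c[i * n + j] += aik * b[k * n + j];
                    }
            }
}
```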

You might also implement a recursive version of the standard matrix product (subdivide each matrix into a 2x2 matrix of half-size matrices). This gives something closer to optimal cache performance, which Strassen gets from being recursive.
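A minimal sketch of that recursive conventional product, assuming n is a power of two and row-major storage; mm_recursive, mm_rec, and the LEAF cutoff of 32 are hypothetical choices. Each level still performs all 8 half-size products (Strassen performs 7), so comparing it against Strassen roughly separates the benefit of recursive locality from the benefit of the saved multiplication.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical cutoff: blocks of LEAF x LEAF or smaller use plain loops. */
#define LEAF 32

/* C += A * B for n x n sub-blocks of larger row-major arrays whose row
   strides are lda, ldb and ldc. n must be a power of two. */
static void mm_rec(size_t n,
                   const double *A, size_t lda,
                   const double *B, size_t ldb,
                   double *C, size_t ldc)
{
    if (n <= LEAF) {
        /* Base case: i-k-j loops, contiguous inner access to B and C. */
        for (size_t i = 0; i < n; ++i)
            for (size_t k = 0; k < n; ++k) {
                double aik = A[i * lda + k];
                for (size_t j = 0; j < n; ++j)
                    C[i * ldc + j] += aik * B[k * ldb + j];
            }
        return;
    }

    size_t h = n / 2;
    /* Quadrants: X11 top-left, X12 top-right, X21 bottom-left, X22 bottom-right. */
    const double *A11 = A, *A12 = A + h;
    const double *A21 = A + h * lda, *A22 = A + h * lda + h;
    const double *B11 = B, *B12 = B + h;
    const double *B21 = B + h * ldb, *B22 = B + h * ldb + h;
    double *C11 = C, *C12 = C + h;
    double *C21 = C + h * ldc, *C22 = C + h * ldc + h;

    /* All 8 half-size products of the conventional 2x2 block formula. */
    mm_rec(h, A11, lda, B11, ldb, C11, ldc); /* C11 += A11*B11 */
    mm_rec(h, A12, lda, B21, ldb, C11, ldc); /* C11 += A12*B21 */
    mm_rec(h, A11, lda, B12, ldb, C12, ldc); /* C12 += A11*B12 */
    mm_rec(h, A12, lda, B22, ldb, C12, ldc); /* C12 += A12*B22 */
    mm_rec(h, A21, lda, B11, ldb, C21, ldc); /* C21 += A21*B11 */
    mm_rec(h, A22, lda, B21, ldb, C21, ldc); /* C21 += A22*B21 */
    mm_rec(h, A21, lda, B12, ldb, C22, ldc); /* C22 += A21*B12 */
    mm_rec(h, A22, lda, B22, ldb, C22, ldc); /* C22 += A22*B22 */
}

/* c = a * b for n x n row-major matrices; n must be a power of two. */
void mm_recursive(size_t n, const double *a, const double *b, double *c)
{
    assert(n > 0 && (n & (n - 1)) == 0);
    assert(c != a && c != b); /* output must not alias the inputs */
    for (size_t i = 0; i < n * n; ++i)
        c[i] = 0.0;
    mm_rec(n, a, n, b, n, c, n);
}
```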

So either you're doing it wrong, or your conventional code is not optimized.

user1188672

The recursive nature of Strassen has better memory locality, so that may be a part of the picture. A recursive regular matrix multiplication is perhaps a reasonable thing to compare to.