Numpy matmul and einsum 6 to 7 times slower than MATLAB


I am trying to port some code from MATLAB to Python and I am getting much slower performance from Python. I am not very good at Python coding, so any advice to speed these up will be much appreciated.

I tried an einsum one-liner (takes 7.5 seconds on my machine):

import numpy as np

n = 4
N = 200
M = 100
X = 0.1*np.random.rand(M, n, N)
w = 0.1*np.random.rand(M, N, 1)

G = np.einsum('ijk,iljm,lmn->il', w, np.exp(np.einsum('ijk,ljn->ilkn',X,X)), w)

I also tried a matmul implementation (takes 6 seconds on my machine):

G = np.zeros((M, M))
for i in range(M):
    G[:, i] = np.squeeze(w[i,...].T @ (np.exp(X[i, :, :].T @ X) @ w))

But my original MATLAB code is way faster (takes 1 second on my machine):

n = 4;
N = 200;
M = 100;
X = 0.1*rand(n, N, M);
w = 0.1*rand(N, 1, M);

G=zeros(M);
for i=1:M
    G(:,i) = squeeze(pagemtimes(pagemtimes(w(:,1,i).', exp(pagemtimes(X(:,:,i),'transpose',X,'none'))) ,w));
end

I was expecting both Python implementations to be comparable in speed, but they are not. Any ideas why the Python implementations are this slow, or any suggestions to speed them up?

There are 2 answers

Jérôme Richard (accepted answer)

First of all, np.einsum has a parameter optimize which is set to False by default (mainly because the optimization can be more expensive than the computation itself in some cases, and in general it is better to pre-compute the optimal path in a separate call first). You can use optimize=True to significantly speed up np.einsum (it provides the optimal path in this case, though the internal implementation is not optimal). Note that pagemtimes in Matlab is more specific than np.einsum, so there is no need for such a parameter there (i.e. it is fast by default in this case).
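If the expression is evaluated repeatedly, the path search can be done once up front with np.einsum_path and the result passed back to np.einsum. A minimal sketch using the arrays from the question:

import numpy as np

n, N, M = 4, 200, 100
X = 0.1*np.random.rand(M, n, N)
w = 0.1*np.random.rand(M, N, 1)
tmp = np.exp(np.einsum('ijk,ljn->ilkn', X, X, optimize=True))

# Compute the optimal contraction path once; einsum_path also returns a
# human-readable report of the expected speed-up and memory footprint.
path, report = np.einsum_path('ijk,iljm,lmn->il', w, tmp, w, optimize='optimal')

# Reuse the pre-computed path on every subsequent call.
G = np.einsum('ijk,iljm,lmn->il', w, tmp, w, optimize=path)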

Moreover, Numpy functions like np.exp create a new array by default. The thing is, computing arrays in-place is generally faster (and it also consumes less memory). This can be done thanks to the out parameter.
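For example, the exponential can overwrite its input instead of allocating a second large temporary array:

import numpy as np

n, N, M = 4, 200, 100
X = 0.1*np.random.rand(M, n, N)

tmp = np.einsum('ijk,ljn->ilkn', X, X, optimize=True)
np.exp(tmp, out=tmp)  # in-place: the result overwrites tmp, no new array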

The np.exp call is pretty expensive on most machines because it runs serially (like most Numpy functions) and it is often not very optimized internally either. Using a fast math library like Intel's helps. I suspect Matlab uses such a fast math library internally. Alternatively, one can use multiple threads to compute this faster. This is easy to do with the numexpr package.

Here is the resulting more optimized Numpy code:

import numpy as np
import numexpr as ne

# [...] Same initialization as in the question

tmp = np.einsum('ijk,ljn->ilkn',X,X, optimize=True)
ne.evaluate('exp(tmp)', out=tmp)
G = np.einsum('ijk,iljm,lmn->il', w, tmp, w, optimize=True)

Performance results

Here are results on my machine (with an i5-9600KF CPU, 32 GiB of RAM, on Windows):

Naive einsums:        6.62 s
CPython loops:        3.37 s
This answer:          1.27 s   <----

max9111 solution:     0.47 s   (using an unmodified Numba v0.57)
max9111 solution:     0.54 s   (using a modified Numba v0.57)

The optimized code is about 5.2 times faster than the initial code and 2.7 times faster than the fastest initial one!


Note about performance and possible optimizations

The first einsum takes a significant fraction of the runtime in the faster implementation on my machine. This is mainly because einsum performs many small matrix multiplications internally in a way that is not very efficient. Indeed, each matrix multiplication is done in parallel by a BLAS library (like OpenBLAS, the default one on most machines, including mine). The thing is, OpenBLAS is not efficient at computing small matrices in parallel. In fact, computing each small matrix in parallel is not efficient at all. A more efficient solution is to compute all the matrix multiplications in parallel, with each thread performing several serial matrix multiplications. This is certainly what Matlab does and why it can be a bit faster. This can be done using parallel Numba code (or with Cython) while disabling the parallel execution of BLAS routines (note that disabling it globally can have performance side effects on a larger script).
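One way to disable BLAS parallelism only where needed, rather than globally, is the threadpoolctl package. A sketch, assuming a parallel implementation named compute_G (a hypothetical placeholder, e.g. a Numba kernel like the one in the other answer):

from threadpoolctl import threadpool_limits

# Limit BLAS (e.g. OpenBLAS) to one thread inside this block so that the
# outer parallel loop (Numba/Cython threads) is not oversubscribed.
with threadpool_limits(limits=1, user_api='blas'):
    G = compute_G(X, w)  # hypothetical parallel implementation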

Another possible optimization is to do all the operations at once in Numba using multiple threads. This solution can certainly reduce the memory footprint even further and improve performance. However, writing an optimized implementation is far from easy and the resulting code will be significantly harder to maintain. This is what max9111's code does.

max9111

A Numba Implementation

As @Jérôme Richard already mentioned, you can also write a pure Numba implementation. I partially used this code generation function on both einsums, with some manual code editing.

Please be aware that from Numba version 0.53 to 0.56 there is a bug/feature which usually has a high performance impact. I would recommend changing that option in versions 0.53 up to 0.57 if the small benefit in compilation times doesn't matter. Beginning with 0.57, this option seems to be slower than the default.
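One way to experiment with the optimization level (my assumption for reproducing the opt=2 timings below; the exact workaround is discussed in the issue linked in the implementation) is the NUMBA_OPT environment variable, which must be set before Numba is imported:

import os
os.environ['NUMBA_OPT'] = '2'  # assumption: corresponds to the opt=2 runs below

import numba as nb  # must be imported after the variable is set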

Pros/Cons

  • Much faster than the accepted solution (and likely the Matlab solution)
  • Very small temporary arrays, if memory usage is a problem
  • Scales well with the number of cores you use (there may be problems with newer big/little Intel CPUs, but still around 600 ms on a new notebook)
  • The code is hard to understand quickly; comments are necessary to follow what's happening

Implementation

import numpy as np
import numba as nb

# Set cache=False to test the behaviour of
# https://github.com/numba/numba/issues/8172#issuecomment-1160474583
# (and of course restart the interpreter)
@nb.njit(fastmath=True, parallel=True, cache=False)
def einsum(X, w):
    # For loop unrolling
    assert X.shape[1] == 4
    assert w.shape[2] == 1

    # For safety
    assert X.shape[0] == w.shape[0]
    assert X.shape[2] == w.shape[1]

    i_s = X.shape[0]
    x_s = X.shape[1]
    j_s = X.shape[2]
    l_s = X.shape[0]
    m_s = X.shape[2]
    k_s = w.shape[2]
    n_s = w.shape[2]

    res = np.empty((i_s, l_s))

    for i in nb.prange(i_s):
        for l in range(l_s):
            # TMP_0 is thread-local; it will be optimized out of the loop
            # by Numba in parallel mode
            # np.einsum('xm,xj->jm', X, X) -> TMP_0
            TMP_0 = np.zeros((j_s, m_s))
            for x in range(x_s):
                for j in range(j_s):
                    for m in range(m_s):
                        TMP_0[j, m] += X[l, x, m] * X[i, x, j]

            # exp in-place
            for j in range(j_s):
                for m in range(m_s):
                    TMP_0[j, m] = np.exp(TMP_0[j, m])

            # TMP_1 is thread-local; it will be optimized out of the loop
            # by Numba in parallel mode
            # np.einsum('jm,jk->m', TMP_0, w[i]) -> TMP_1
            TMP_1 = np.zeros((m_s))
            for j in range(j_s):
                for m in range(m_s):
                    for k in range(k_s):
                        TMP_1[m] += TMP_0[j, m] * w[i, j, k]

            # np.einsum('m,mn->', TMP_1, w[l]) -> res
            acc = 0.0
            for m in range(m_s):
                for n in range(n_s):
                    acc += TMP_1[m] * w[l, m, n]
            res[i, l] = acc

    return res

Timings on Ryzen 5 5600G (6C/12T)

Original implementation (with unique subscript characters):

%timeit G3 = np.einsum('ijk,iljm,lmn->il', w, np.exp(np.einsum('ixj,lxm->iljm',X,X)), w)
4.45 s ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Jérôme Richard's implementation:

1.43 s ± 102 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

My implementation on unmodified Numba above v0.53; it has to be modified if performance is the main goal, which usually is the case if you use Numba :-(

665 ms ± 13.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

My implementation below v0.53, or with a modified newer Numba:

142 ms ± 3.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Updated timings

The previous timings were with Numba 0.55. Starting with 0.57, Numba seems to show different behaviour: the runtime is now faster with the default, but still a bit slower than version 0.56 with opt=2:

%timeit G2 = einsum(X,w)

#0.56, windows installed via pip (default)
#706 ms ± 13.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
#0.56, windows installed via pip (opt=2)
#153 ms ± 2.68 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

#0.57, windows installed via pip (default)
#173 ms ± 1.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
#0.57, windows installed via pip (opt=2)
#247 ms ± 1.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

For comparable timings, check whether SVML has been used

This should be the default on Anaconda Python, but may not be the case on standard Python.

def check_SVML(func):
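    # Search the generated LLVM IR for Intel SVML library calls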
    if 'intel_svmlcc' in func.inspect_llvm(func.signatures[0]):
        print("found")
    else:
        print("not found")

check_SVML(einsum)
#found