I am trying to port some code from MATLAB to Python and I am getting much slower performance from Python. I am not very good at Python coding, so any advice on how to speed these up would be much appreciated.
I tried an einsum one-liner (takes 7.5 seconds on my machine):
import numpy as np
n = 4
N = 200
M = 100
X = 0.1*np.random.rand(M, n, N)
w = 0.1*np.random.rand(M, N, 1)
G = np.einsum('ijk,iljm,lmn->il', w, np.exp(np.einsum('ijk,ljn->ilkn',X,X)), w)
I also tried a matmul implementation (takes 6 seconds on my machine):
G = np.zeros((M, M))
for i in range(M):
    G[:, i] = np.squeeze(w[i, ...].T @ (np.exp(X[i, :, :].T @ X) @ w))
But my original MATLAB code is way faster (takes 1 second on my machine):
n = 4;
N = 200;
M = 100;
X = 0.1*rand(n, N, M);
w = 0.1*rand(N, 1, M);
G=zeros(M);
for i=1:M
    G(:,i) = squeeze(pagemtimes(pagemtimes(w(:,1,i).', exp(pagemtimes(X(:,:,i),'transpose',X,'none'))), w));
end
I was expecting both Python implementations to be comparable in speed, but they are not. Any ideas why the Python implementations are this slow, or any suggestions to speed them up?
First of all, np.einsum has a parameter optimize which is set to False by default (mainly because the optimization can be more expensive than the computation in some cases, and it is better in general to pre-compute the optimal path in a separate call first). You can use optimize=True to significantly speed up np.einsum (it provides the optimal path in this case, though the internal implementation is not optimal). Note that pagemtimes in MATLAB is more specific than np.einsum, so there is no need for such a parameter (i.e. it is fast by default in this case).
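For illustration, here is a small sketch of pre-computing the contraction path with np.einsum_path and reusing it (smaller sizes than in the question, just to keep the example light):

import numpy as np

# Illustrative sizes (the question uses M=100, n=4, N=200)
M, n, N = 10, 4, 20
X = 0.1*np.random.rand(M, n, N)
w = 0.1*np.random.rand(M, N, 1)
K = np.exp(np.einsum('ijk,ljn->ilkn', X, X, optimize=True))

# Pre-compute the optimal contraction path once...
path, _ = np.einsum_path('ijk,iljm,lmn->il', w, K, w, optimize='optimal')
# ...then reuse it so the path search is not repeated on every call
G = np.einsum('ijk,iljm,lmn->il', w, K, w, optimize=path)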
Moreover, Numpy functions like np.exp create a new array by default. The thing is, computing arrays in-place is generally faster (and it also consumes less memory). This can be done thanks to the out parameter.
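For instance, the exponential can be applied in-place instead of allocating a second large array:

import numpy as np

tmp = np.random.rand(2000, 2000)
# Writes the result back into tmp: no extra temporary array is allocated
np.exp(tmp, out=tmp)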
Besides, np.exp is pretty expensive on most machines because it runs serially (like most Numpy functions) and it is often not very optimized internally either. Using a fast math library like the one from Intel helps. I suspect MATLAB uses this kind of fast math library internally. Alternatively, one can use multiple threads to compute this faster. This is easy to do with the numexpr package. Here is the resulting, more optimized Numpy code (sketched below, assuming the same initialization of X and w as in the question):
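import numpy as np
import numexpr as ne

# Same initialization as in the question (assumed)
n = 4
N = 200
M = 100
X = 0.1*np.random.rand(M, n, N)
w = 0.1*np.random.rand(M, N, 1)

# First contraction with an optimized path
tmp = np.einsum('ijk,ljn->ilkn', X, X, optimize=True)
# In-place, multi-threaded exponential via numexpr
ne.evaluate('exp(tmp)', out=tmp)
# Final contraction, again with an optimized path
G = np.einsum('ijk,iljm,lmn->il', w, tmp, w, optimize=True)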
Performance results
Here are results on my machine (with an i5-9600KF CPU, 32 GiB of RAM, on Windows):
The optimized code is about 5.2 times faster than the initial code and 2.7 times faster than the fastest of the initial implementations!
Note about performance and possible optimizations
The first einsum takes a significant fraction of the runtime in the faster implementation on my machine. This is mainly because einsum performs many small matrix multiplications internally in a way that is not very efficient. Indeed, each matrix multiplication is done in parallel by a BLAS library (like the OpenBLAS library, which is the default one on most machines like mine). The thing is, OpenBLAS is not efficient at computing small matrices in parallel. In fact, computing each small matrix in parallel is not efficient. A more efficient solution is to compute all the matrix multiplications in parallel, with each thread performing several serial matrix multiplications. This is certainly what MATLAB does and why it can be a bit faster. This can be done using a parallel Numba code (or with Cython) and by disabling the parallel execution of BLAS routines (note this can have performance side effects on a larger script if done globally); a sketch of this approach is shown below.

Another possible optimization is to do all the operations at once in Numba using multiple threads. This solution can certainly reduce the memory footprint even more and further improve performance. However, it is far from easy to write an optimized implementation, and the resulting code will be significantly harder to maintain. This is what max9111's code does.
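For reference, here is a rough sketch of that first approach (assuming Numba is installed; BLAS threading should be disabled separately, e.g. by setting OPENBLAS_NUM_THREADS=1 before starting Python, so that each matrix product runs serially within its thread):

import numpy as np
import numba as nb

@nb.njit(parallel=True)
def compute_G(X, w):
    # X has shape (M, n, N) and w has shape (M, N, 1), as in the question
    M = X.shape[0]
    G = np.zeros((M, M))
    for i in nb.prange(M):        # each thread handles several columns of G
        Xi = X[i].T.copy()        # (N, n), contiguous copy for faster products
        for l in range(M):
            K = np.exp(Xi @ X[l])               # (N, N) block, serial per thread
            G[l, i] = (w[i].T @ (K @ w[l]))[0, 0]
    return G

This should return the same G as the loop-based version, with each thread performing several serial matrix multiplications rather than relying on BLAS-level parallelism for each small product.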