Is there any alternative to the zmap function from scipy's stats module? I'm currently using it to obtain z-scores for two really large arrays, and it's taking quite some time.
Are there any libraries or alternatives that could boost its performance? Or even another way of obtaining what the zmap function does?
Your ideas and comments would be appreciated!
Here's my minimal reproducible example:

```python
from scipy import stats
import numpy as np

FeatureData = np.random.rand(483, 1)
goodData = np.random.rand(4640, 483)
FeatureNorm = stats.zmap(FeatureData, goodData)
```
And here's what scipy's stats.zmap does under the hood:

```python
def zmap(scores, compare, axis=0, ddof=0):
    scores, compare = map(np.asanyarray, [scores, compare])
    mns = compare.mean(axis=axis, keepdims=True)
    sstd = compare.std(axis=axis, ddof=ddof, keepdims=True)
    return (scores - mns) / sstd
```
Any ideas on how to optimize it for my use case? Could I use libraries like numba or JAX to boost this further?
Fortunately, the `zmap` code is pretty straightforward. The overhead in numpy, however, comes from the fact that it must instantiate intermediate arrays. If you use a numerical compiler such as those available in numba or JAX, it can fuse these operations and do the computation with less overhead.

Unfortunately, numba doesn't support the optional arguments to `mean` and `std`, so let's take a look at JAX. For reference, here are benchmarks for scipy and for the raw numpy version of the function, computed on a Google Colab CPU runtime. Here is the equivalent code executed in JAX, both eager mode and JIT compiled:
The JIT-compiled version is about a factor of 5 faster than the scipy or numpy code. On a Colab T4 GPU runtime, the compiled version gains another factor of 10.
If this kind of operation is a bottleneck in your analysis, a compiler like JAX might be a good option.