Alternative to scipy stats zmap function

Is there an alternative to the zmap function in the scipy.stats module? I'm currently using it to compute zmap scores for two very large arrays, and it's taking quite some time.

Are there any libraries or alternatives that could boost its performance? Or even another way of computing what the zmap function does?

Your ideas and comments would be appreciated!

Here's my minimal reproducible code below:

from scipy import stats
import numpy as np

FeatureData = np.random.rand(483, 1)
goodData = np.random.rand(4640, 483)
FeatureNorm = stats.zmap(FeatureData, goodData)
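Before optimizing, it may be worth checking the shapes involved. With axis=0 the statistics are computed per column of goodData, giving shape (1, 483), and a (483, 1) array minus a (1, 483) row broadcasts to a (483, 483) matrix: every feature value standardized against every column. If the intent was one z-score per feature (an assumption about the goal, not something stated in the question), a (1, 483) row is probably what's wanted, and the output becomes 483 times smaller. The sketch below writes out the same arithmetic zmap performs, directly in NumPy:

```python
import numpy as np

FeatureData = np.random.rand(483, 1)
goodData = np.random.rand(4640, 483)

# axis=0 reduces over rows, so the per-column statistics have shape (1, 483)
mns = goodData.mean(axis=0, keepdims=True)
sstd = goodData.std(axis=0, keepdims=True)

# A (483, 1) column minus a (1, 483) row broadcasts to a full (483, 483) matrix
full = (FeatureData - mns) / sstd

# Assumption: one z-score per feature is intended, so lay the features out as a row
row = FeatureData.reshape(1, -1)
per_feature = (row - mns) / sstd

print(mns.shape, full.shape, per_feature.shape)
# (1, 483) (483, 483) (1, 483)
```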

And here's what scipy's stats.zmap essentially does under the hood:

import numpy as np

def zmap(scores, compare, axis=0, ddof=0):
    scores, compare = map(np.asanyarray, [scores, compare])
    mns = compare.mean(axis=axis, keepdims=True)
    sstd = compare.std(axis=axis, ddof=ddof, keepdims=True)
    return (scores - mns) / sstd
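Written out for the default axis=0, with C standing for compare (n rows) and S for scores, the function computes a mean and a standard deviation for each column j and standardizes the scores against them. ddof=0 gives the population standard deviation, and S follows NumPy broadcasting, so a column-shaped scores array is repeated across every j:

```latex
\mu_j = \frac{1}{n}\sum_{i=1}^{n} C_{ij}, \qquad
\sigma_j = \sqrt{\frac{1}{n - \mathrm{ddof}}\sum_{i=1}^{n}\left(C_{ij} - \mu_j\right)^2}, \qquad
z_{kj} = \frac{S_{kj} - \mu_j}{\sigma_j}
```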

Any ideas on how to optimize it for my use case? Could I use libraries like numba or JAX to boost this further?

Accepted answer by jakevdp:

Fortunately, the zmap code is pretty straightforward. The overhead in NumPy, however, comes from the fact that each operation must instantiate an intermediate array. If you use a numerical compiler such as those available in numba or JAX, it can fuse these operations and do the computation with less overhead.
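Some of that overhead can also be trimmed in plain NumPy, offered here as a sketch rather than anything from the answer. np.std recomputes the mean internally, so computing the mean once and reusing it saves one reduction pass over compare, and doing the squaring and division in place avoids a couple of temporaries. np_zmap_reuse is a hypothetical name, it only handles axis=0, and any speedup should be measured on your own machine:

```python
import numpy as np


def np_zmap_reuse(scores, compare, ddof=0):
    """Hypothetical axis=0 variant of zmap that reuses the mean and limits temporaries."""
    scores = np.asarray(scores, dtype=float)
    compare = np.asarray(compare, dtype=float)
    n = compare.shape[0]
    # Single mean pass, reused below instead of letting np.std recompute it
    mns = compare.mean(axis=0, keepdims=True)
    # One temporary for the centred data, squared in place
    centred = compare - mns
    np.square(centred, out=centred)
    sstd = np.sqrt(centred.sum(axis=0, keepdims=True) / (n - ddof))
    # One temporary for the result, divided in place (it already has the broadcast shape)
    out = scores - mns
    out /= sstd
    return out
```

It returns the same values as the zmap code above for axis=0, so it can be timed against it with the same inputs.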

Unfortunately, numba doesn't support the optional arguments (such as axis and keepdims) to mean and std, so let's take a look at JAX. For reference, here are benchmarks for scipy and for the raw NumPy version of the function, computed on a Google Colab CPU runtime:

import numpy as np
from scipy import stats

FeatureData = np.random.rand(483, 1)
goodData = np.random.rand(4640, 483)

%timeit stats.zmap(FeatureData, goodData)
# 100 loops, best of 3: 13.9 ms per loop

def np_zmap(scores, compare, axis=0, ddof=0):
    scores, compare = map(np.asanyarray, [scores, compare])
    mns = compare.mean(axis=axis, keepdims=True)
    sstd = compare.std(axis=axis, ddof=ddof, keepdims=True)
    return (scores - mns) / sstd

%timeit np_zmap(FeatureData, goodData)
# 100 loops, best of 3: 13.8 ms per loop

Here is the equivalent code in JAX, run both in eager mode and JIT-compiled:

import jax.numpy as jnp
from jax import jit

def jnp_zmap(scores, compare, axis=0, ddof=0):
    scores, compare = map(jnp.asarray, [scores, compare])
    mns = compare.mean(axis=axis, keepdims=True)
    sstd = compare.std(axis=axis, ddof=ddof, keepdims=True)
    return (scores - mns) / sstd

jit_jnp_zmap = jit(jnp_zmap)

FeatureData = jnp.array(FeatureData)
goodData = jnp.array(goodData)
%timeit jnp_zmap(FeatureData, goodData).block_until_ready()
# 100 loops, best of 3: 8.59 ms per loop

jit_jnp_zmap(FeatureData, goodData)  # trigger compilation
%timeit jit_jnp_zmap(FeatureData, goodData).block_until_ready()
# 100 loops, best of 3: 2.78 ms per loop
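The jit(jnp_zmap) call above compiles with the default axis=0 and ddof=0. To pass those arguments explicitly, they generally need to be marked static, since jit otherwise traces each argument as an array value while axis in particular must be a concrete Python int. A sketch using jax.jit's static_argnames parameter (jit_zmap is a hypothetical name); JAX recompiles once for each distinct combination of static values:

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


# axis and ddof are static: they are baked into the compiled computation
# rather than traced as array arguments.
@partial(jax.jit, static_argnames=("axis", "ddof"))
def jit_zmap(scores, compare, axis=0, ddof=0):
    mns = jnp.mean(compare, axis=axis, keepdims=True)
    sstd = jnp.std(compare, axis=axis, ddof=ddof, keepdims=True)
    return (scores - mns) / sstd


rng = np.random.default_rng(0)
scores = rng.random((1, 483))
compare = rng.random((4640, 483))

z0 = jit_zmap(scores, compare)                  # defaults: axis=0, ddof=0
z1 = jit_zmap(scores, compare, axis=0, ddof=1)  # explicit values are fine once static
```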

The JIT-compiled version is about a factor of 5 faster than the scipy or NumPy code. On a Colab T4 GPU runtime, the compiled version gains roughly another factor of 10:

%timeit jit_jnp_zmap(FeatureData, goodData).block_until_ready()
# 1000 loops, best of 3: 286 µs per loop
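When reproducing these numbers, it may help to check which backend JAX picked and what precision the arrays ended up in. JAX defaults to 32-bit floats, so converting a float64 NumPy array normally yields float32; the JAX timings above were therefore likely measured in float32, whereas NumPy and scipy worked in float64. If float64 is needed, calling jax.config.update("jax_enable_x64", True) at startup enables it. A small sketch of the checks:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Backend that jit-compiled functions run on by default, and the visible devices
print(jax.default_backend())
print(jax.devices())

# Without jax_enable_x64, a float64 NumPy array becomes float32 on conversion
x = jnp.asarray(np.random.rand(4640, 483))
print(x.dtype)
```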

If this kind of operation is a bottleneck in your analysis, a compiler like JAX might be a good option.