From the JAX docs:
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
selu(x)
"The code above is sending one operation at a time to the accelerator. This limits the ability of the XLA compiler to optimize our functions."
The docs then proceed to wrap selu in jit:

selu_jit = jax.jit(selu)
selu_jit(x)
And for some reason this improves performance significantly.
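For reference, here is a rough timing sketch I used to see the gap myself (my own code, not from the docs; numbers depend on hardware, and block_until_ready() is there because JAX dispatches work asynchronously):

import time

selu(x).block_until_ready()        # warm-up for the eager version
selu_jit(x).block_until_ready()    # the first jit call also compiles

t0 = time.perf_counter()
selu(x).block_until_ready()
print("eager:", time.perf_counter() - t0)

t0 = time.perf_counter()
selu_jit(x).block_until_ready()    # later calls reuse the compiled code
print("jit:", time.perf_counter() - t0)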
Why is jit even needed here? More specifically, why is the original code "sending one operation at a time to the accelerator"? I was under the impression that jax.numpy is meant for this exact purpose; otherwise we might as well be using plain old numpy. What was wrong with the original selu?
Thanks!
Edit: after a short discussion below, I realized a more concise answer to the original question: JAX uses eager computation by default; if you want lazy evaluation (what's sometimes called graph mode in other packages), you can specify this by wrapping your function in jax.jit.

Python is an interpreted language, which means that statements are executed one at a time. This is the sense in which the un-jitted code is sending one operation at a time to the accelerator: each statement must execute and return a value before the interpreter runs the next.
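As a minimal sketch of what eager execution means here (my own toy example, not from the docs): each jax.numpy call below runs immediately and hands back a concrete array, so JAX never gets to see the function as a whole.

import jax.numpy as jnp

x = jnp.arange(5.0)
a = jnp.exp(x)       # dispatched and computed right away
print(a)             # concrete values are already available here
b = 1.67 * a - 1.67  # a second, independent dispatch
print(b)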
Within a jit-compiled function, JAX replaces arrays with abstract tracers in order to determine the full sequence of operations in the function, and to send them all to XLA for compilation, where the operations may be rearranged or transformed by the compiler to make the overall execution more efficient.
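You can inspect that traced sequence of operations yourself with jax.make_jaxpr, which (roughly speaking) runs the same tracing step that jit uses and prints the captured primitives as one program:

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# Tracing with an abstract value records every primitive in the function;
# this whole program is what jit hands to XLA as a single unit.
print(jax.make_jaxpr(selu)(jnp.arange(5.0)))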
The reason we use jax.numpy rather than normal numpy is that jax.numpy operations work with the JIT tracer machinery, whereas normal numpy operations do not.

For a high-level intro to how JAX and its transforms work, a good place to start is How To Think In JAX.
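A small sketch of that difference (my example; the exact exception class may vary by JAX version):

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def f_jnp(x):
    return jnp.sin(x)   # jax.numpy knows how to record this op on a tracer

@jax.jit
def f_np(x):
    return np.sin(x)    # plain numpy tries to turn the tracer into a concrete array

x = jnp.arange(3.0)
print(f_jnp(x))         # works
try:
    f_np(x)
except Exception as e:  # typically a TracerArrayConversionError
    print(type(e).__name__)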