In autograd/numpy I could do:
q[q<0] = 0.0
How can I do the same thing in JAX?
I tried using import numpy as onp to create arrays, but that doesn't seem to work.
To address @truth's comment: jnp.where generalises to arrays with any number of dimensions. Here's an example:
from jax import numpy as jnp
x = jnp.arange(4).reshape((2, 2)) - 2
print(x)
# [[-2 -1]
#  [ 0  1]]
print(jnp.where(x < 0, 0, x))
# [[0 0]
#  [0 1]]
This is also jit-compatible (when jnp.where is used with three arguments).
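A minimal sketch of the clamp from the question under jax.jit, assuming q is a floating-point array (the function name clamp_negative is hypothetical). With three arguments the result always has the same shape as q, so the shape is known at trace time. The one-argument form jnp.where(q < 0) instead returns indices whose count depends on the values in q, which is why it does not behave the same way under jit.

```python
import jax
import jax.numpy as jnp


@jax.jit
def clamp_negative(q):
    # Three-argument jnp.where: elementwise select with the same shape as q,
    # so the output shape does not depend on the data and jit can trace it.
    return jnp.where(q < 0, 0.0, q)


q = jnp.array([[-1.5, 2.0], [0.5, -3.0]])
print(clamp_negative(q))
```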
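JAX arrays are immutable, so in-place index assignment statements cannot work. Instead, JAX provides functional index updates that create updated copies of arrays. In current JAX these are exposed through the .at property of arrays (for example x.at[idx].set(y)); older versions provided them as functions in the jax.ops submodule, such as jax.ops.index_update, which have since been removed. Here is an example of a NumPy index assignment and the equivalent JAX index update: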
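A minimal sketch of the contrast, assuming a JAX version that has the .at property; the arrays a, b, c and the values are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: index assignment mutates the array in place.
a = np.zeros(5)
a[1:3] = 7.0

# JAX: arrays are immutable, so .at[...].set(...) returns a new array
# and leaves the original one untouched.
b = jnp.zeros(5)
c = b.at[1:3].set(7.0)

print(a)  # [0. 7. 7. 0. 0.]
print(c)  # [0. 7. 7. 0. 0.]
print(b)  # [0. 0. 0. 0. 0.]
```

Besides .set, the .at property also offers methods such as .add, .min and .max for other kinds of functional updates.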
Note that in eager (op-by-op) mode, the JAX version does create multiple copies of the array. However, when used inside a jit-compiled function, XLA can often fuse such operations and avoid copying the data.
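A minimal sketch of several chained updates inside jax.jit (the function name update_several is hypothetical). Whether XLA actually performs the updates in place is up to the compiler, so this shows where the optimisation can happen rather than guaranteeing it.

```python
import jax
import jax.numpy as jnp


@jax.jit
def update_several(x):
    # Each .at[...] call is semantically a copy-and-update. Inside jit the
    # whole chain is compiled as a single XLA computation, so XLA can often
    # apply the updates in place instead of materialising every copy.
    x = x.at[0].set(1.0)
    x = x.at[2].add(5.0)
    x = x.at[-1].max(3.0)
    return x


x = jnp.zeros(4)
y = update_several(x)
print(y)  # [1. 0. 5. 3.]
```

The caller's array x is still all zeros afterwards; only the returned array carries the updates.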