Conditional update in JAX?


In autograd/numpy I could do:

q[q<0] = 0.0

How can I do the same thing in JAX?

I tried importing numpy as onp and using it to create arrays, but that doesn't seem to work.


There are 2 answers

jakevdp (best answer)

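JAX arrays are immutable, so in-place index assignment statements cannot work. Instead, JAX arrays provide an .at property for indexed updates (x.at[idx].set(y), x.at[idx].add(y), and so on), which returns an updated copy of the array. Older JAX versions exposed this functionality through the jax.ops submodule (e.g. jax.ops.index_update), which has since been removed in favor of .at.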
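A quick check of the immutability point, as a minimal sketch assuming a recent JAX release (the flag name assignment_failed is just for illustration): NumPy-style item assignment on a JAX array raises a TypeError instead of modifying the array.

```python
import jax.numpy as jnp

q = jnp.arange(-5, 5)

# NumPy-style in-place assignment is rejected because JAX arrays are immutable.
try:
    q[q < 0] = 0
    assignment_failed = False
except TypeError:
    assignment_failed = True

print(assignment_failed)  # True
```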

Here is an example of a numpy index assignment and the equivalent JAX index update:

import numpy as np
q = np.arange(-5, 5)
q[q < 0] = 0
print(q)
# [0 0 0 0 0 0 1 2 3 4]

import jax.numpy as jnp
q = jnp.arange(-5, 5)
q = q.at[q < 0].set(0)  # NB: this does not modify the original array,
                        # but rather returns a modified copy.
print(q)
# [0 0 0 0 0 0 1 2 3 4]
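The same .at syntax offers other update methods besides set, and accepts slices as well as boolean masks. A short sketch, assuming a recent JAX version; the variable names are illustrative only. For this particular question, a plain elementwise jnp.maximum(q, 0) gives the same result as the masked set without any indexing.

```python
import jax.numpy as jnp

q = jnp.arange(-5, 5)

# Each .at[...] method returns a new array; q itself is never modified.
added = q.at[q < 0].add(10)   # add 10 to the negative entries
sliced = q.at[2:4].set(7)     # slices work as indices too
clamped = q.at[:].max(0)      # elementwise max(q, 0) over the whole array

# For this particular update, a plain elementwise op gives the same result.
same_as_clamped = jnp.maximum(q, 0)

print(added)    # [5 6 7 8 9 0 1 2 3 4]
print(sliced)   # [-5 -4  7  7 -1  0  1  2  3  4]
print(clamped)  # [0 0 0 0 0 0 1 2 3 4]
```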

Note that in op-by-op mode, the JAX version creates a new copy of the array for each update. However, when used inside a JIT-compiled function, XLA can often perform such updates in place and avoid copying the data.
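A sketch of indexed updates inside jax.jit, where the index itself can be a traced value (set_entry and accumulate are hypothetical helper names). Whether XLA actually reuses the input buffer is a compiler decision; jax.jit's donate_argnums option lets you offer the input buffer for reuse, but in-place execution is still not guaranteed.

```python
import jax
import jax.numpy as jnp

@jax.jit
def set_entry(q, i, value):
    # Functional update; i is a traced value under jit.
    return q.at[i].set(value)

@jax.jit
def accumulate(q, idx, vals):
    # Scatter-add: repeated indices are accumulated, not overwritten.
    return q.at[idx].add(vals)

q = jnp.zeros(5)
updated = set_entry(q, 2, 7.0)
summed = accumulate(q, jnp.array([0, 0, 3]), jnp.array([1.0, 2.0, 4.0]))

print(updated)  # [0. 0. 7. 0. 0.]
print(summed)   # [3. 0. 0. 4. 0.]
```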

Papples

To address @truth's comment, jnp.where generalises to arrays with any number of dimensions. Here's an example:

from jax import numpy as jnp
x = jnp.arange(4).reshape((2, 2)) - 2
print(x)
# [[-2 -1]
#  [ 0  1]]
print(jnp.where(x < 0, 0, x))
# [[0 0]
#  [0 1]]

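This is JIT-compatible too, as long as jnp.where is called with all three arguments. The one-argument form returns the indices of the True entries, and their number depends on the data.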
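A sketch of the three-argument form under jax.jit on a 3-D array (zero_negatives is a hypothetical helper name). Because the output always has the same shape as the input, tracing needs no information from the data values.

```python
import jax
import jax.numpy as jnp

@jax.jit
def zero_negatives(x):
    # Three-argument where: the output shape matches x, so it traces fine.
    return jnp.where(x < 0, 0, x)

x = jnp.arange(-4, 4).reshape((2, 2, 2))
y = zero_negatives(x)
print(y.shape)  # (2, 2, 2)
```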