Is there any way to roll arrays within a jax compiled method such that the axis along which the array is rolled is a variable? For example, the following code:
import jax
import jax.numpy as jnp

A = jnp.ones((4, 4, 4, 4))
indList = jnp.asarray([0, 1, 2])

def roll(ind):
    return jnp.roll(A, -1, axis=ind)

result = jax.lax.map(roll, indList)
Yields the following error:
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]. expected a static index or sequence of indices.
As far as I can tell, adding the @partial(jax.jit, static_argnums = 0) decorator to the roll function would not solve the problem.
The issue is that lax.map requires its arguments to be dynamic, and you cannot use jnp.roll with a dynamic axis because it is built on primitives that require the axis to be static.
One way to work around this is to dynamically construct indices that will extract the desired result; for your case the approach might look like this:
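What follows is a minimal sketch of the index-construction idea, not necessarily the exact code originally posted. The name `roll_dynamic_axis` and its `shift` parameter are illustrative, and `A` is filled with `jnp.arange` (an assumption, instead of ones) so that the effect of the roll is visible. The axis only enters through `jnp.where`, so it can stay a traced value:

```python
import jax
import jax.numpy as jnp

# Assumption: distinct values instead of ones, so a roll is observable.
A = jnp.arange(4 * 4 * 4 * 4).reshape(4, 4, 4, 4)
indList = jnp.asarray([0, 1, 2])

def roll_dynamic_axis(ind, shift=-1):
    # Hypothetical helper: emulates jnp.roll(A, shift, axis=ind) for a traced `ind`.
    index_arrays = []
    for k, n in enumerate(A.shape):
        # Open-grid index for axis k, e.g. shape (4, 1, 1, 1) for k == 0.
        shape = [n if j == k else 1 for j in range(A.ndim)]
        grid = jnp.arange(n).reshape(shape)
        # Offset the indices only along the selected axis; other axes get offset 0.
        offset = jnp.where(ind == k, shift, 0)
        index_arrays.append((grid - offset) % n)
    # The index arrays broadcast against each other to A's full shape.
    return A[tuple(index_arrays)]

result = jax.lax.map(roll_dynamic_axis, indList)
print(result.shape)
```

Because the shapes of all index arrays are fixed by `A.shape`, nothing here depends on the concrete value of `ind`, which is what lets it run under `lax.map`. The cost is a gather over the whole array rather than the slicing that `jnp.roll` uses.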