JAX lax way to roll arrays given a variable axis index


Is there any way to roll arrays within a jax compiled method such that the axis along which the array is rolled is a variable? For example, the following code:

import jax
import jax.numpy as jnp

A = jnp.ones((4, 4, 4, 4))

indList = jnp.asarray([0, 1, 2])

def roll(ind):
    return jnp.roll(A, -1, axis=ind)

result = jax.lax.map(roll, indList)

Yields the following error:

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]. expected a static index or sequence of indices.

As far as I can tell, adding the @partial(jax.jit, static_argnums=0) decorator to roll would not solve the problem.


There is 1 answer

Answered by jakevdp

The issue is that lax.map traces the mapped function, so ind reaches it as a dynamic (traced) value, and jnp.roll cannot accept a dynamic axis because it is built on primitives that require the axis to be static.

One way to work around this is to dynamically construct indices that will extract the desired result; for your case the approach might look like this:

import jax
import jax.numpy as jnp

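def roll_dynamic(ind):
  # The index trick below assumes every dimension of A has the same length.
  assert len(set(A.shape)) == 1, "roll_dynamic requires all dimensions to be equal."
  # One row of positions 0..n-1 per axis of A.
  indices = jax.lax.broadcasted_iota('int32', (A.ndim, A.shape[0]), dimension=1)
  # Roll only the row that belongs to the requested (possibly traced) axis.
  indices = jnp.where(jnp.arange(A.ndim)[:, None] == ind,
                      jnp.roll(indices, -1, axis=-1), indices)
  # Broadcast the per-axis index rows against each other and gather from A.
  return A[tuple(jnp.meshgrid(*indices, indexing='ij', sparse=True))]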

def roll(ind):
  return jnp.roll(A, -1, axis=ind)

A = jnp.arange(256).reshape(4, 4, 4, 4)
indList = jnp.asarray([0, 1, 2])

for ind in range(4):
  assert jnp.all(roll(ind) == roll_dynamic(ind))

result = jax.lax.map(roll_dynamic, indList)