How to make a hashable datatype to use with jit?


What would be a workaround for this code inside a jitted function?

j = indices.index(list(neighbor))

where neighbor is, for example, (2, 3) and indices = [[1, 2], [4, 5], ...].
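A minimal reproduction of the problem, as a sketch: the concrete rows of indices past [4, 5] are borrowed from the answer below for illustration, and the function name find is hypothetical. Under jit, neighbor is a tracer, so list.index has to turn a traced comparison into a Python bool, which JAX refuses to do.

```python
import jax
import jax.numpy as jnp

# Illustrative data; only the first two rows come from the question.
indices = [[1, 2], [4, 5], [3, 6], [2, 3], [5, 7]]

@jax.jit
def find(neighbor):
  # Inside jit, list(neighbor) is a list of tracers. list.index compares each
  # row with ==, then needs a concrete Python bool, which a tracer cannot give.
  return indices.index(list(neighbor))

try:
  find(jnp.array([2, 3]))
  failed = False
except jax.errors.ConcretizationTypeError:
  # Recent JAX raises TracerBoolConversionError, a subclass of this error.
  failed = True

print(failed)
```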

I've tried several alternatives, such as functools.partial, but they didn't work. One issue with partial is that indices is not hashable, so I can't use a partial function.
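For the hashability issue specifically, one possible workaround (a sketch, not from the original post) is to store indices as a tuple of tuples. Unlike a list of lists, a tuple of tuples is hashable, so it can be marked static with jax.jit's static_argnames. The name get_index_static is hypothetical. neighbor remains a traced value, so the lookup itself still has to use array operations rather than tuple.index.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hashable version of the data: a tuple of tuples instead of a list of lists.
indices = ((1, 2), (4, 5), (3, 6), (2, 3), (5, 7))

@partial(jax.jit, static_argnames="indices")
def get_index_static(indices, neighbor):
  # indices is a static Python value here, so it becomes a constant array.
  table = jnp.asarray(indices)
  # neighbor is traced, so compare with array ops; argmax picks the first
  # matching row (and returns 0 if no row matches).
  return jnp.argmax((table == neighbor).all(-1))

print(get_index_static(indices, jnp.array([2, 3])))
```

Because indices is static, every distinct tuple triggers a fresh compilation, so this only makes sense when indices is fixed across calls.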


Best answer, by jakevdp:

list.index is a Python method that will not work inside JIT if the contents of the list are traced values. I would recommend converting your lists to arrays and doing something like this:

import jax
import jax.numpy as jnp

indices = jnp.array([[1, 2], [4, 5], [3, 6], [2, 3], [5, 7]])
neighbor = jnp.array([2, 3])

@jax.jit
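def get_index(indices, neighbor):
  # True for each row of indices that equals neighbor in every column.
  # size=1 gives a static output shape, which jit requires.
  return jnp.where((indices == neighbor).all(-1), size=1)[0]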

idx = get_index(indices, neighbor)
print(idx)
# [3]
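One caveat worth noting: with size=1, jnp.where pads missing results with 0 by default, so a neighbor that does not appear in indices also yields [0], which is indistinguishable from a match in the first row. A sketch that passes fill_value=-1 as a "not found" sentinel (the function name get_index_or_minus_one is hypothetical):

```python
import jax
import jax.numpy as jnp

indices = jnp.array([[1, 2], [4, 5], [3, 6], [2, 3], [5, 7]])

@jax.jit
def get_index_or_minus_one(indices, neighbor):
  # True for each row of indices that equals neighbor in every column.
  mask = (indices == neighbor).all(-1)
  # size=1 keeps the output shape static; fill_value=-1 replaces the default
  # padding value 0, which is also a valid row index.
  return jnp.where(mask, size=1, fill_value=-1)[0]

print(get_index_or_minus_one(indices, jnp.array([2, 3])))  # [3]
print(get_index_or_minus_one(indices, jnp.array([9, 9])))  # [-1]
```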