I want to construct a 2D array from a function in such a way that I can use `jax.jit`.
The way I would normally do this with NumPy is to create an empty array and then fill it in place:
```python
import numpy as np

# NumPy arrays are mutable, so they can be preallocated and filled in place
xx = np.empty((num_a, num_b))
yy = np.empty((num_a, num_b))
zz = np.empty((num_a, num_b))
for ii_a in range(num_a):
    for ii_b in range(num_b):
        a = aa[ii_a, ii_b]
        b = bb[ii_a, ii_b]
        xyz = self.get_coord(a, b)
        xx[ii_a, ii_b] = xyz[0]
        yy[ii_a, ii_b] = xyz[1]
        zz[ii_a, ii_b] = xyz[2]
```
To make this work within JAX, I have tried functional index updates (`jax.ops.index_update`, which has since been replaced by the `.at[...].set()` syntax used here):
```python
# JAX arrays are immutable: .at[...].set() returns an updated copy
xx = xx.at[ii_a, ii_b].set(xyz[0])
yy = yy.at[ii_a, ii_b].set(xyz[1])
zz = zz.at[ii_a, ii_b].set(xyz[2])
```
This runs without errors, but it is very slow when I apply the `@jax.jit` decorator (at least an order of magnitude slower than the pure Python/NumPy version).
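A likely explanation for the slowdown, sketched below: `jax.jit` traces ordinary Python control flow, so the nested `for` loop is unrolled at trace time and the compiled program contains a separate copy of the loop body for every element. This sketch assumes a hypothetical `get_coord` (mapping two angles to a point on the unit sphere) and uses `jax.make_jaxpr` to count the traced operations; the count grows roughly with `num_a * num_b`, and compilation time grows with it.

```python
import jax
import jax.numpy as jnp


def get_coord(a, b):
    # Hypothetical stand-in for the question's get_coord:
    # maps two scalar angles to a point on the unit sphere.
    return jnp.array([jnp.cos(a) * jnp.sin(b),
                      jnp.sin(a) * jnp.sin(b),
                      jnp.cos(b)])


def fill_loop(aa, bb):
    num_a, num_b = aa.shape
    xx = jnp.empty((num_a, num_b))
    yy = jnp.empty((num_a, num_b))
    zz = jnp.empty((num_a, num_b))
    # Python loops are unrolled while tracing: every iteration
    # adds its own indexing, get_coord, and scatter operations.
    for ii_a in range(num_a):
        for ii_b in range(num_b):
            xyz = get_coord(aa[ii_a, ii_b], bb[ii_a, ii_b])
            xx = xx.at[ii_a, ii_b].set(xyz[0])
            yy = yy.at[ii_a, ii_b].set(xyz[1])
            zz = zz.at[ii_a, ii_b].set(xyz[2])
    return xx, yy, zz


def traced_op_count(n):
    # Number of top-level operations in the traced program for an n x n grid.
    aa = jnp.zeros((n, n))
    bb = jnp.zeros((n, n))
    return len(jax.make_jaxpr(fill_loop)(aa, bb).jaxpr.eqns)


small = traced_op_count(2)   # 4 loop iterations
large = traced_op_count(4)   # 16 loop iterations
print(small, large)
```

Going from a 2x2 to a 4x4 grid roughly quadruples the traced program, which is why a jitted Python loop over every element scales poorly.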
What is the best way to fill a multi-dimensional array from a function using JAX?
JAX has a `vmap` transform that is designed specifically for this kind of application. As long as your `get_coord` function is compatible with JAX (i.e. it is a pure function with no side effects), you can accomplish this in one line:
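For a self-contained illustration of the nested-`vmap` pattern, here is a sketch that assumes a hypothetical `get_coord` taking two scalars and returning a length-3 array; the grid sizes and angle ranges are illustrative only. The inner `vmap` maps over the last axis of the 2D inputs and the outer `vmap` over the first, so no Python loop is traced and the compiled program size does not depend on the grid size.

```python
import jax
import jax.numpy as jnp


def get_coord(a, b):
    # Hypothetical stand-in for the question's get_coord:
    # two scalars in, a length-3 array (a point on the unit sphere) out.
    return jnp.array([jnp.cos(a) * jnp.sin(b),
                      jnp.sin(a) * jnp.sin(b),
                      jnp.cos(b)])


# Outer vmap maps over axis 0 of the (num_a, num_b) inputs, inner vmap over axis 1.
# The batched output therefore has shape (num_a, num_b, 3).
get_coord_grid = jax.jit(jax.vmap(jax.vmap(get_coord)))

num_a, num_b = 4, 5
aa, bb = jnp.meshgrid(jnp.linspace(0.0, 2 * jnp.pi, num_a),
                      jnp.linspace(0.0, jnp.pi, num_b),
                      indexing="ij")

xyz = get_coord_grid(aa, bb)   # shape (num_a, num_b, 3)
xx, yy, zz = xyz[..., 0], xyz[..., 1], xyz[..., 2]
```

If `get_coord` returned a tuple `(x, y, z)` instead of a single array, `vmap` would return a tuple of three `(num_a, num_b)` arrays directly.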