Efficiently fill an array from a function


I want to construct a 2D array from a function in such a way that I can utilize jax.jit.

The way I would normally do this using numpy is to create an empty array, and then fill that array in-place.

import numpy as np

xx = np.empty((num_a, num_b))
yy = np.empty((num_a, num_b))
zz = np.empty((num_a, num_b))

for ii_a in range(num_a):
    for ii_b in range(num_b):
        a = aa[ii_a, ii_b]
        b = bb[ii_a, ii_b]

        xyz = self.get_coord(a, b)

        xx[ii_a, ii_b] = xyz[0]
        yy[ii_a, ii_b] = xyz[1]
        zz[ii_a, ii_b] = xyz[2]

To make this work within JAX, I have tried the functional index-update syntax x.at[...].set(...), which replaced the deprecated jax.ops.index_update:

        xx = xx.at[ii_a, ii_b].set(xyz[0])
        yy = yy.at[ii_a, ii_b].set(xyz[1])
        zz = zz.at[ii_a, ii_b].set(xyz[2])

This runs without errors but is very slow when I try to use a @jax.jit decorator (at least an order of magnitude slower than the pure python/numpy version).
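
One likely contributor to the slowdown is that jax.jit traces Python for loops by unrolling them, so the compiled program contains a separate update for every (ii_a, ii_b) pair and its size (and compile time) grows with num_a * num_b. This can be checked with jax.make_jaxpr; fill_loop below is a hypothetical toy stand-in for the loop above, not the original code:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy version of the loop: fill xx one element at a time
def fill_loop(aa, bb):
    xx = jnp.zeros(aa.shape)
    for ii_a in range(aa.shape[0]):
        for ii_b in range(aa.shape[1]):
            xx = xx.at[ii_a, ii_b].set(aa[ii_a, ii_b] + bb[ii_a, ii_b])
    return xx

aa = jnp.ones((4, 5))
bb = jnp.ones((4, 5))

# The traced program holds at least one update per grid point,
# so the number of equations grows with the number of loop iterations.
jaxpr = jax.make_jaxpr(fill_loop)(aa, bb)
print(len(jaxpr.jaxpr.eqns))
```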

What is the best way to fill a multi-dimensional array from a function using jax?

Answer by jakevdp (best answer):

JAX has a vmap transform that is designed specifically for this kind of application.

As long as your get_coord function is compatible with JAX (i.e. it is a pure function with no side effects) and returns the three coordinates as a tuple, you can accomplish this in one line:

from jax import vmap
xx, yy, zz = vmap(vmap(get_coord))(aa, bb)
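
A self-contained sketch of this approach, assuming a hypothetical get_coord body and illustrative aa, bb grids built with jnp.meshgrid. The unpacking xx, yy, zz = ... works directly when get_coord returns a tuple; if it instead returns a single array of shape (3,), the vmapped result has shape (num_a, num_b, 3) and the coordinate axis has to be moved to the front before unpacking:

```python
import jax
import jax.numpy as jnp

# Hypothetical get_coord: any pure function of two scalars would do
def get_coord(a, b):
    return a, b, a + b  # a tuple of three scalars

# Illustrative 2D inputs of shape (num_a, num_b) = (3, 4)
aa, bb = jnp.meshgrid(jnp.arange(3.0), jnp.arange(4.0), indexing="ij")

# The outer vmap maps over axis 0 of aa/bb, the inner vmap over the remaining
# axis, so get_coord sees matching scalar pairs; jit compiles the whole thing.
fill = jax.jit(jax.vmap(jax.vmap(get_coord)))
xx, yy, zz = fill(aa, bb)  # each has shape (3, 4)

# Variant (hypothetical): get_coord returning one array of shape (3,)
def get_coord_array(a, b):
    return jnp.array([a, b, a + b])

xyz = jax.vmap(jax.vmap(get_coord_array))(aa, bb)  # shape (3, 4, 3)
xx2, yy2, zz2 = jnp.moveaxis(xyz, -1, 0)            # three (3, 4) arrays
```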

Answer by amicitas:

This can be achieved efficiently by using either the jax.vmap or the jax.numpy.vectorize functions.

An example using vectorize:

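import jax.numpy as jnp

def get_coord(a, b):
    return jnp.array([a, b, a+b])

# inner: maps over scalar pairs (a, b) -> length-3 vector
f0 = jnp.vectorize(get_coord, signature='(),()->(i)')
# outer: maps over a only; argument 1 (b) is passed through unchanged
f1 = jnp.vectorize(f0, excluded=(1,), signature='()->(i,j)')

a = jnp.arange(3.0)  # 1D inputs, e.g. the two axes of a grid
b = jnp.arange(4.0)
xyz = f1(a, b)       # shape (len(a), len(b), 3) = (3, 4, 3)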

The vectorize function uses vmap under the hood, so this should be exactly equivalent to:

import jax

f0 = jax.vmap(get_coord, (None, 0))  # map over b, hold a fixed
f1 = jax.vmap(f0, (0, None))         # map over a, hold b fixed
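
The claimed equivalence can be checked directly. This sketch builds both versions side by side from the answer's get_coord; the 1D inputs a and b are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

def get_coord(a, b):
    return jnp.array([a, b, a + b])

a = jnp.arange(3.0)  # shape (3,)
b = jnp.arange(4.0)  # shape (4,)

# jnp.vectorize version
f0_vec = jnp.vectorize(get_coord, signature='(),()->(i)')
f1_vec = jnp.vectorize(f0_vec, excluded=(1,), signature='()->(i,j)')

# Explicit vmap version: an in_axes entry of None means "do not map this argument"
f0_vmap = jax.vmap(get_coord, (None, 0))
f1_vmap = jax.vmap(f0_vmap, (0, None))

xyz_vec = f1_vec(a, b)    # shape (3, 4, 3)
xyz_vmap = f1_vmap(a, b)  # shape (3, 4, 3)
print(bool(jnp.array_equal(xyz_vec, xyz_vmap)))  # True
```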

The advantage of using vectorize is that the code can still be run with standard NumPy, whose np.vectorize accepts the same excluded and signature arguments. The disadvantage is less concise code and possibly a small amount of overhead from the wrapper.
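
A sketch of the same recipe in plain NumPy, assuming get_coord is rewritten to build a NumPy array and using the same illustrative 1D inputs. Note that np.vectorize is documented as a convenience wrapper that is essentially a Python for loop, so this version is portable but does not get the speedup that the JAX version does:

```python
import numpy as np

def get_coord(a, b):
    return np.array([a, b, a + b])

a = np.arange(3.0)
b = np.arange(4.0)

# Same signatures as the JAX version; argument 1 (b) is excluded from the outer loop
f0 = np.vectorize(get_coord, signature='(),()->(i)')
f1 = np.vectorize(f0, excluded={1}, signature='()->(i,j)')

xyz = f1(a, b)  # shape (3, 4, 3)
```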