I want to convert this code into JAX style.
mask
is a 2D array of which values are either 1 or 0.
indices
is coordinates of mask where the value is 1. For example, [[5,4], [10,11],...]
num_pixels
is len(indices)
def get_neighbors(index):
"""
Returns the indices of the neighbors of the pixel at the given index.
"""
i, j = index
# return jnp.array([(i-1, j), (i+1, j), (i, j-1), (i, j+1)]) # top, bottom, left, right
return jnp.array(np.array([(i-1, j), (i+1, j), (i, j-1), (i, j+1)])) # this is faster than the above one I think
def is_inside_boundary(img, neighbor):
"""
Returns True if the pixel at the given index is inside the mask and inside the boundary of the image mask.
"""
h = img.shape[0]
w = img.shape[1]
i, j = neighbor
# return i >= 0 and i < h and j >= 0 and j < w
# def check_bool(i, j, h, w):
# return i >= 0 and i < h and j >= 0 and j < w
return jax.lax.cond(jax.lax.bitwise_and( jax.lax.bitwise_and(i >= 0, i < h), jax.lax.bitwise_and(j>=0, j <w) ),
lambda : True,
lambda : False,
)
mask = np.zeros(shape=(10,10))
mask[3:6, 3:6] = 1
indices = np.argwhere(mask).tolist()
num_pixels = len(indices)
A = sparse.lil_matrix((num_pixels, num_pixels))
# iterate over pixels inside mask
for i in range(num_pixels):
# get neighbors of pixel
neighbors = get_neighbors(indices[i])
# iterate over neighbors
for neighbor in neighbors:
# if neighbor is inside the boundary of mask image and
# is part of mask (i.e., pixel value is 1), add to matrix
# if mask[neighbor]:
if is_inside_boundary(mask, neighbor) and mask[neighbor[0], neighbor[1]]:
j = indices.index(list(neighbor))
A[i, j] = -1
# add 4 to diagonal
A[i, i] = 4
This is what I've tried.
def f
refers to the first loop (for i in range(num_pixels)
) and def f2
the second loop (for neighbor in neighbors:
) of the original code:
from functools import partial
@partial(vmap, in_axes=(0,0))
def f(indice, i):
neighbors = get_neighbors(indice)
f2(neighbors, i)
A.at[i,i].set(4) # should be something like A = A.at[i,j].set(4) but gets an error
@partial(vmap, in_axes=(0,None))
def f2(neighbor, i):
# if is_inside_boundary(mask, neighbor) and mask[neighbor]:
# if jax.lax.bitwise_and(jax.lax.cond(is_inside_boundary(mask, neighbor), lambda: True, lambda: False),
# jax.lax.cond(mask[neighbor]==1, lambda: True, lambda: False)):
if partial(jax.lax.cond,
(jax.lax.bitwise_and(is_inside_boundary(mask, neighbor),
mask.at[neighbor[0], neighbor[1]].get()==1),
lambda: True,
lambda: False),
static_argnums=(0,)):
j = jnp.where((indices == neighbor).all(-1), size=1)
A.at[i, j].set(-1) # should be something like A = A.at[i,j].set(-1) but gets an error
mask = np.zeros(shape=(10,10))
mask[3:6, 3:6] = 1
indices = np.argwhere(mask).tolist()
mask = jnp.array(np.array((mask)))
num_pixels = len(indices)
A = sparse.lil_matrix((num_pixels, num_pixels))
A = jnp.array(A.todense())
i_s = jnp.arange(num_pixels)
indices = jnp.array(indices)
# print(indices.shape)
f(indices, i_s)
It works (meaning it 'runs'. not sure if it's correct). The problem for now is that it's too slow; all the partial
, lax.cond
are, I think, slowing down the process. Also, I'm aware that I should avoid using 'jax' with static_argnums
if the parameter of static_argnums
changes too often but couldn't think of any other workarounds.