better way to use lax.cond, partial, vmap

115 views Asked by At

I want to convert this code into JAX style.

mask is a 2D array of which values are either 1 or 0.
indices is coordinates of mask where the value is 1. For example, [[5,4], [10,11],...]
num_pixels is len(indices)

def get_neighbors(index):
    """
    Returns the indices of the neighbors of the pixel at the given index.
    """
    i, j = index
    # return jnp.array([(i-1, j), (i+1, j), (i, j-1), (i, j+1)]) # top, bottom, left, right
    return jnp.array(np.array([(i-1, j), (i+1, j), (i, j-1), (i, j+1)])) # this is faster than the above one I think
def is_inside_boundary(img, neighbor):
    """
    Returns True if the pixel at the given index is inside the mask and inside the boundary of the image mask.
    """
    h = img.shape[0]
    w = img.shape[1]
    i, j = neighbor
    # return i >= 0 and i < h and j >= 0 and j < w
    
    # def check_bool(i, j, h, w):
    #     return i >= 0 and i < h and j >= 0 and j < w
    return jax.lax.cond(jax.lax.bitwise_and( jax.lax.bitwise_and(i >= 0, i < h), jax.lax.bitwise_and(j>=0, j <w) ), 
                        lambda : True, 
                        lambda : False, 
                        )

mask = np.zeros(shape=(10,10))
mask[3:6, 3:6] = 1
indices = np.argwhere(mask).tolist()
num_pixels = len(indices)
A = sparse.lil_matrix((num_pixels, num_pixels))
# iterate over pixels inside mask
for i in range(num_pixels):
    # get neighbors of pixel
    neighbors = get_neighbors(indices[i])
    # iterate over neighbors
    for neighbor in neighbors:
        # if neighbor is inside the boundary of mask image and
        # is part of mask (i.e., pixel value is 1), add to matrix
        # if mask[neighbor]:
        if is_inside_boundary(mask, neighbor) and mask[neighbor[0], neighbor[1]]:
            j = indices.index(list(neighbor))
            A[i, j] = -1
    # add 4 to diagonal
    A[i, i] = 4  

This is what I've tried.
def f refers to the first loop (for i in range(num_pixels)) and def f2 the second loop (for neighbor in neighbors:) of the original code:

from functools import partial
@partial(vmap, in_axes=(0,0))
def f(indice, i):
    neighbors = get_neighbors(indice)
    f2(neighbors, i)
    A.at[i,i].set(4) # should be something like A = A.at[i,j].set(4) but gets an error


@partial(vmap, in_axes=(0,None))        
def f2(neighbor, i):
    # if is_inside_boundary(mask, neighbor) and mask[neighbor]:
    # if jax.lax.bitwise_and(jax.lax.cond(is_inside_boundary(mask, neighbor), lambda: True, lambda: False),
    #                        jax.lax.cond(mask[neighbor]==1, lambda: True, lambda: False)):
        
    if partial(jax.lax.cond, 
                (jax.lax.bitwise_and(is_inside_boundary(mask, neighbor), 
                                    mask.at[neighbor[0], neighbor[1]].get()==1),
                  lambda: True,
                  lambda: False), 
                  static_argnums=(0,)):
        j = jnp.where((indices == neighbor).all(-1), size=1)               
        A.at[i, j].set(-1) # should be something like A = A.at[i,j].set(-1) but gets an error


mask = np.zeros(shape=(10,10))
mask[3:6, 3:6] = 1
indices = np.argwhere(mask).tolist()

mask = jnp.array(np.array((mask)))
num_pixels = len(indices)
A = sparse.lil_matrix((num_pixels, num_pixels))        
A = jnp.array(A.todense())
i_s = jnp.arange(num_pixels)
indices = jnp.array(indices)
# print(indices.shape)
f(indices, i_s)

It works (meaning it 'runs'. not sure if it's correct). The problem for now is that it's too slow; all the partial, lax.cond are, I think, slowing down the process. Also, I'm aware that I should avoid using 'jax' with static_argnums if the parameter of static_argnums changes too often but couldn't think of any other workarounds.

0

There are 0 answers