Compute flat index using Numba

I need a generic, Numba-compatible function that takes multi-dimensional indices and the shape of the multi-dimensional array as input and returns the flat index.

Sample Requirement:

from numba import jit

@jit(nopython=True)
def flat_index(indices, shape):
    """
    Computes the flat index given multidimensional indices and the shape of the array.
    
    Parameters:
        indices (array-like): The multidimensional indices.
        shape (array-like): The shape of the array.
        
    Returns:
        int: The flat index.
    """
    flat_idx = 0
    
    # ---> do the computations here
    # For the example below, flat_idx should come out as 37, which is the
    # column-major (Fortran-order) result: 2 + 2*5 + 1*5*5


    return flat_idx

# Example usage with 3 dimensions only:
indices = (2, 2, 1)  # Example multidimensional indices
shape = (5, 5, 2)    # Example shape of the array

flat_idx = flat_index(indices, shape)
print("Flat index:", flat_idx)

NOTE: For simplicity, the code above uses a 3D index (2, 2, 1), but the function should handle any number of dimensions.
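
One possible way to fill in the computation is sketched below. It is not a confirmed answer, just a minimal version under a few assumptions: `indices` and `shape` are same-length tuples of integers (or 1-D integer arrays), no bounds checking is done, and the expected value 37 means column-major (Fortran) ordering, since row-major ordering would give 25 for the same input. The `column_major` flag and the `try`/`except` fallback are additions for illustration and are not part of the original stub.

```python
try:
    from numba import jit
except ImportError:
    # Assumption for illustration: fall back to plain Python when Numba is
    # not installed, so the sketch still runs (just without compilation).
    def jit(*args, **kwargs):
        return lambda func: func


@jit(nopython=True)
def flat_index(indices, shape, column_major=True):
    """Flat index for multidimensional indices; no bounds checking."""
    n = len(shape)
    flat_idx = 0
    stride = 1  # number of elements skipped per step along the current axis
    for k in range(n):
        # Column-major walks the axes first-to-last (first axis varies
        # fastest); row-major walks them last-to-first.
        axis = k if column_major else n - 1 - k
        flat_idx += indices[axis] * stride
        stride *= shape[axis]
    return flat_idx


indices = (2, 2, 1)
shape = (5, 5, 2)
print("Flat index (column-major):", flat_index(indices, shape))
print("Flat index (row-major):", flat_index(indices, shape, False))
```

The loop only relies on `len()` and integer indexing, so it does not depend on the number of dimensions. For the example input, the column-major result agrees with `np.ravel_multi_index(indices, shape, order='F')`.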

There are 0 answers