Why this triton kernel crashes?

34 views Asked by At

I can't find what is the error in this code, it suppose to merge p tensors of size (n, k) such that the merged tensor of size (n, k) has alternating element of these p tensors without doublon per row. (If alternatively you can suggest an other code (in triton or full python) that does the same and that is computationally and memory efficient it would be fine). Thank you for help.

def _merge_edge_index_kernel(
    edge_index_stacked_ptr,
    num_edges,
    num_edge_types,
    edge_index_merged_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    i = tl.program_id(axis=0)

    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_edges

    edge_index_merged = tl.full((BLOCK_SIZE,), value=-1, dtype=tl.int32)
    stride = num_edges * num_edge_types

    count = 0
    num_duplicates = 0
    while count < num_edges:
        j = tl.load(edge_index_stacked_ptr + i * stride + count + num_duplicates)
        not_already_in = tl.sum(edge_index_merged == j, axis=0) == 0
        if not_already_in:
            edge_index_merged = tl.where(offsets == count, j, edge_index_merged)
            count += 1
        else:
            num_duplicates += 1

    tl.store(edge_index_merged_ptr + i * num_edges + offsets, edge_index_merged, mask=mask)


def merge_edge_index(edge_index_stacked: torch.Tensor) -> torch.Tensor:
    """
    Inputs
    ------
        * edge_index_stacked: (sum_i seq_lens[i], num_edges, num_edge_types)

    Output
    ------
        * edge_index_merged: (sum_i seq_lens[i], num_edges)
    """

    assert edge_index_stacked.is_cuda, "edge_index_stacked is not on cuda"
    assert edge_index_stacked.is_contiguous(), "edge_index_stacked is not contiguous"

    total_size, num_edges, num_edge_types = edge_index_stacked.shape

    edge_index_merged = torch.empty_like(edge_index_stacked[..., 0])
    BLOCK_SIZE = triton.next_power_of_2(num_edges)

    grid = (total_size,)
    _merge_edge_index_kernel[grid](
        edge_index_stacked,
        num_edges,
        num_edge_types,
        edge_index_merged,
        BLOCK_SIZE=BLOCK_SIZE,  # type: ignore
    )

    return edge_index_merged
0

There are 0 answers