I can't find the error in this code. It is supposed to merge p tensors of size (n, k) so that the merged tensor of size (n, k) contains alternating elements of these p tensors, with no duplicates per row. (Alternatively, if you can suggest other code — in Triton or pure Python — that does the same thing and is computationally and memory efficient, that would be fine too.) Thank you for your help.
@triton.jit  # BUG FIX: without @triton.jit the launch `kernel[grid](...)` raises TypeError
def _merge_edge_index_kernel(
    edge_index_stacked_ptr,  # *int: (total_size, num_edges, num_edge_types), contiguous
    num_edges,
    num_edge_types,
    edge_index_merged_ptr,  # *int: (total_size, num_edges), contiguous
    BLOCK_SIZE: tl.constexpr,  # power of two, >= num_edges
):
    """Deduplicating merge of the `num_edge_types` interleaved edge lists of one row.

    One program instance handles row `i`.  Scanning the row's flattened
    (num_edges * num_edge_types) values in memory order visits
    (edge 0, type 0), (edge 0, type 1), ... — i.e. it alternates across the
    stacked source tensors — and the first `num_edges` distinct values are kept.
    Rows with fewer distinct values are left padded with -1.
    """
    i = tl.program_id(axis=0)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_edges
    # Accumulator of accepted values; -1 marks an empty slot, so -1 must not be
    # a valid edge index in the input — TODO confirm with callers.
    # NOTE(review): accumulator is int32 while the input may be int64
    # (torch edge indices are commonly long); values above 2^31-1 would be
    # truncated — confirm the input dtype.
    edge_index_merged = tl.full((BLOCK_SIZE,), value=-1, dtype=tl.int32)
    stride = num_edges * num_edge_types  # elements per row
    count = 0           # distinct values accepted so far
    num_duplicates = 0  # values skipped because already present
    # BUG FIX: bound the scan by the row length so a row with fewer than
    # `num_edges` distinct values terminates (leaving -1 padding) instead of
    # looping forever while reading past the end of the row.
    while count < num_edges and count + num_duplicates < stride:
        j = tl.load(edge_index_stacked_ptr + i * stride + count + num_duplicates)
        # j is already present iff any accepted slot equals it.
        not_already_in = tl.sum(edge_index_merged == j, axis=0) == 0
        if not_already_in:
            edge_index_merged = tl.where(offsets == count, j, edge_index_merged)
            count += 1
        else:
            num_duplicates += 1
    tl.store(edge_index_merged_ptr + i * num_edges + offsets, edge_index_merged, mask=mask)
def merge_edge_index(edge_index_stacked: torch.Tensor) -> torch.Tensor:
    """Merge the stacked edge lists of each row into one deduplicated list.

    Per row, the `num_edge_types` source lists are consumed in alternating
    (interleaved) order and the first `num_edges` distinct values are kept;
    rows with fewer distinct values are padded with -1 by the kernel.

    Inputs
    ------
    * edge_index_stacked: (total_size, num_edges, num_edge_types)
      integer tensor, CUDA, contiguous (asserted below).

    Output
    ------
    * edge_index_merged: (total_size, num_edges), same dtype/device as input.
    """
    assert edge_index_stacked.is_cuda, "edge_index_stacked is not on cuda"
    assert edge_index_stacked.is_contiguous(), "edge_index_stacked is not contiguous"
    total_size, num_edges, num_edge_types = edge_index_stacked.shape
    # Allocate the output explicitly: the previous
    # `torch.empty_like(edge_index_stacked[..., 0])` built it from a
    # non-contiguous slice and relied on empty_like falling back to a
    # contiguous layout, which the kernel's flat `i * num_edges + offsets`
    # indexing requires.  Being explicit removes that fragility.
    edge_index_merged = torch.empty(
        (total_size, num_edges),
        dtype=edge_index_stacked.dtype,
        device=edge_index_stacked.device,
    )
    # One program per row; the row fits in a single power-of-two block.
    BLOCK_SIZE = triton.next_power_of_2(num_edges)
    grid = (total_size,)
    _merge_edge_index_kernel[grid](
        edge_index_stacked,
        num_edges,
        num_edge_types,
        edge_index_merged,
        BLOCK_SIZE=BLOCK_SIZE,  # type: ignore
    )
    return edge_index_merged