Is there a fast (loop-free) implementation of dilated k-nearest neighbours?


I am implementing the dilated k-nearest-neighbours algorithm. My implementation unfortunately relies on nested loops, which severely hampers execution speed.

import torch

dilation = 3
nbd_size = 5
# knn_key[i, j, k] holds the indices of the 100 nearest neighbours of query k
knn_key = torch.randint(0, 30, (64, 12, 198, 100))

dilated_keys = torch.zeros(
    (knn_key.shape[0], knn_key.shape[1], knn_key.shape[2], nbd_size)
)

for i in range(knn_key.shape[0]):
    for j in range(knn_key.shape[1]):
        for k in range(knn_key.shape[2]):
            list_indices = []
            # rescan the neighbour list until nbd_size matches have been collected
            while len(list_indices) < nbd_size:
                for l in range(knn_key.shape[3]):
                    # keep neighbours congruent to the query index modulo dilation
                    if knn_key[i][j][k][l] % dilation == k % dilation:
                        list_indices.append(knn_key[i][j][k][l])
                        if len(list_indices) >= nbd_size:
                            break
            dilated_keys[i][j][k] = torch.tensor(list_indices)

The variable knn_key stores the 100 nearest neighbours among the 1000 data points originally available. dilated_keys stores the nbd_size=5 neighbours that remain after applying the dilation filter. Any help with a broadcasting solution that removes the three nested loops would be highly appreciated.
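For a single neighbour row, the filter I am applying boils down to the following (standalone toy example; the values in `key` are made up for illustration):

```python
import torch

dilation, nbd_size = 3, 5
k = 4  # query index; k % dilation == 1
key = torch.tensor([7, 2, 10, 4, 16, 9, 1, 22, 13, 5])

# keep only neighbours whose value is congruent to k modulo dilation
matches = key[key % dilation == k % dilation]
print(matches[:nbd_size].tolist())  # first nbd_size matches: [7, 10, 4, 16, 1]
```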

1 Answer

Answered by asymptote:
You can replace the while loop, the innermost loop, and the conditions with a vectorized selection per row, as follows:

import torch

dilation = 3
nbd_size = 5
knn_key = torch.randint(0, 30, (64, 12, 198, 100))

dilated_keys = torch.zeros((knn_key.shape[0], knn_key.shape[1], knn_key.shape[2], nbd_size), dtype=torch.int64)

for i in range(knn_key.shape[0]):
    for j in range(knn_key.shape[1]):
        for k in range(knn_key.shape[2]):
            key = knn_key[i, j, k]
            # positions of all neighbours that pass the dilation filter;
            # squeeze(-1) (not squeeze()) keeps a 1-D tensor even for a single match
            indices = torch.nonzero(key % dilation == k % dilation).squeeze(-1)
            # keep the first nbd_size matches (assumes at least nbd_size exist)
            dilated_keys[i, j, k] = key[indices[:nbd_size]]
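If you want to drop the three remaining loops as well, the whole selection can be expressed with broadcasting, an argsort, and a gather. This is a sketch of one possible construction (the `order` offset trick is my own, not part of the answer above), and like the loop version it assumes every row contains at least nbd_size matches:

```python
import torch

dilation, nbd_size = 3, 5
torch.manual_seed(0)  # for reproducibility of this sketch
knn_key = torch.randint(0, 30, (64, 12, 198, 100))

# mask[..., k, l] is True where neighbour l of query k passes the dilation filter
k_idx = torch.arange(knn_key.shape[2]).view(1, 1, -1, 1)   # (1, 1, 198, 1)
mask = (knn_key % dilation) == (k_idx % dilation)          # (64, 12, 198, 100)

# Rank positions so matching entries sort first while preserving their order:
# non-matching positions get an offset of L, pushing them after all matches.
L = knn_key.shape[3]
order = torch.arange(L).expand_as(mask) + (~mask) * L
first = order.argsort(dim=-1)[..., :nbd_size]   # positions of first nbd_size matches
dilated_keys = torch.gather(knn_key, -1, first)  # (64, 12, 198, nbd_size)
```

All per-position ranks in `order` are distinct (l for matches, l + L for non-matches), so the argsort is deterministic and the first nbd_size sorted positions are exactly the first nbd_size matching neighbours of each row.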