I have a tensor like this:
tensor([[[ 7.3478, -1.8058, -2.6140,  ..., -0.2719, -0.3171, -0.4737]],
        [[ 7.3606, -1.8269, -1.9825,  ..., -0.8680,  0.4894,  0.2708]]],
       grad_fn=<CatBackward>)
I want the top-k values across both rows combined. Currently I can only get the top-k of each row separately:
ipdb> stacked.topk(2)
torch.return_types.topk(
values=tensor([[[14.3902, 14.3039]],
               [[14.8927, 12.1973]]], grad_fn=<TopkBackward>),
indices=tensor([[[60, 12]],
                [[12, 23]]]))
As the output shows, the top 2 values were retrieved from each row separately. Instead, I want output like this:
14.8927 that maps to index 12
14.3902 that maps to index 60
Note that if both of the top 2 values are in the first row, the result should contain only values from that row and ignore the second row entirely, and vice versa.
Any help would be appreciated.
A very hacky way of doing what I am describing is the following, shown for a BEAM_WIDTH of 2:
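BEAM_WIDTH = 2
# Per-row top-k; each row's values come back sorted in descending order.
top_k = stacked.data.topk(BEAM_WIDTH, dim=2)
v1, i1 = top_k[0][0][0], top_k[1][0][0]  # values and indices for row 0
v2, i2 = top_k[0][1][0], top_k[1][1][0]  # values and indices for row 1
i = j = 0
final = []
# Merge the two sorted lists, keeping the BEAM_WIDTH largest entries.
for _ in range(BEAM_WIDTH):
    if v1[i] >= v2[j]:
        final.append((v1[i], i1[i]))
        i += 1
    else:
        final.append((v2[j], i2[j]))
        j += 1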
Repeated Indices
I believe this is what you want. First find the topk elements in the flattened tensor, then convert those flat indices back to row-relative indices.
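The flatten-then-convert idea can be sketched as follows. This is a minimal illustration rather than a definitive implementation: the function name `topk_flat` is hypothetical, and it assumes `stacked` has shape (rows, 1, cols) as in the question.

```python
import torch


def topk_flat(stacked: torch.Tensor, k: int):
    """Top-k over all rows at once; the same column index may appear twice.

    Hypothetical helper; assumes `stacked` has shape (rows, 1, cols).
    """
    n_cols = stacked.shape[-1]
    # Flatten every row into one long vector and take the k largest entries.
    values, flat_idx = stacked.reshape(-1).topk(k)
    # Convert flat positions back to (row, index within that row).
    rows = torch.div(flat_idx, n_cols, rounding_mode="floor")
    cols = flat_idx % n_cols
    return values, rows, cols
```

Here `cols` plays the role of the row-relative index (12, 60, ... in the question) and `rows` records which row each value came from.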
Unique Indices
The previous approach doesn't enforce unique indices. If unique indices are needed, you could take the element-wise max across rows first, then find the topk of that result.
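One possible sketch of the max-then-topk variant is shown below. The name `topk_unique` is hypothetical, and the same (rows, 1, cols) input shape is assumed.

```python
import torch


def topk_unique(stacked: torch.Tensor, k: int):
    """Top-k where each column index appears at most once.

    Hypothetical helper; assumes `stacked` has shape (rows, 1, cols).
    """
    n_cols = stacked.shape[-1]
    flat = stacked.reshape(-1, n_cols)  # shape (rows, cols)
    # For every column, keep the largest value and the row it came from.
    best_vals, best_rows = flat.max(dim=0)
    # Top-k over the per-column maxima, so indices cannot repeat.
    values, cols = best_vals.topk(k)
    rows = best_rows[cols]
    return values, rows, cols
```

Because the max collapses the rows before `topk` runs, a column that is large in both rows contributes only its larger value.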
Example
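To demonstrate the difference between these two approaches, suppose you have a tensor in which the same index holds a large value in both rows.
In the repeated indices case you would end up with that index appearing twice, once for each row.
In the unique indices case you would get each index at most once, paired with the larger of its two values.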
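A small made-up tensor makes the difference concrete. The values below are hypothetical and chosen so that index 1 is large in both rows; this is an illustrative sketch, not the original example.

```python
import torch

# Hypothetical input of shape (2, 1, 3): index 1 is large in both rows.
stacked = torch.tensor([[[1.0, 9.0, 3.0]],
                        [[2.0, 8.0, 4.0]]])
k = 2
n_cols = stacked.shape[-1]

# Repeated indices: top-k over the flattened tensor.
rep_vals, flat_idx = stacked.reshape(-1).topk(k)
rep_cols = flat_idx % n_cols

# Unique indices: per-column max across rows, then top-k.
col_max, _ = stacked.reshape(-1, n_cols).max(dim=0)
uniq_vals, uniq_cols = col_max.topk(k)

print(rep_vals.tolist(), rep_cols.tolist())    # [9.0, 8.0] [1, 1]
print(uniq_vals.tolist(), uniq_cols.tolist())  # [9.0, 4.0] [1, 2]
```

With repeated indices, both results point at index 1 (one from each row). With unique indices, the second slot goes to index 2 instead.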