Selecting values from a 4d tensor with a 3d tensor


I recently ran into this problem in PyTorch: I have a 4D tensor that needs to be indexed with a 3D tensor.

Let's say we have this 4D tensor:

possible_values.size()
torch.Size([2, 5, 5, 4])

where:

dim 1 = batch
dim 2 = x_axis
dim 3 = y_axis
dim 4 = possible values of coordinate (x_i,y_j)

we then have a 3D "indexing" tensor, which should be used to select the values of dim 4, based on an x and y coordinate:

coordinates.size()
torch.Size([2, 5, 2])

where:

dim 1 = batch
dim 2 = sequences of (x,y) 
dim 3 = (x,y) coordinate

for example, coordinates would look like

[ [ [1,5] [3,3] [2,4] [1,3] [2,3] ]
  [ [1,5] [4,3] [2,1] [5,3] [5,3] ] ]

what we want to do is to select from a batch the possible values for the coordinates specified by coordinates. So from the first batch we want to select the 4 values at coordinates [1, 5], [3, 3] and so on.
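To pin down the desired result, here is a minimal reference sketch using plain Python loops. It assumes the coordinates are zero-based integer indices (the example values above would have to be shifted down by one to be valid for axes of size 5). The helper name `select_loop` and the toy data are made up for illustration only.

```python
import torch

def select_loop(possible_values, coordinates):
    # possible_values: [batch, x_axis, y_axis, n_values]
    # coordinates:     [batch, seq, 2], zero-based integer (x, y) pairs (assumption)
    batch, seq = coordinates.shape[:2]
    n_values = possible_values.shape[-1]
    out = torch.empty(batch, seq, n_values, dtype=possible_values.dtype)
    for b in range(batch):
        for s in range(seq):
            x, y = coordinates[b, s].tolist()
            out[b, s] = possible_values[b, x, y]  # the n_values entries at (x, y)
    return out

# Toy data with the shapes from the question.
possible_values = torch.arange(2 * 5 * 5 * 4).reshape(2, 5, 5, 4)
coordinates = torch.tensor([[[0, 4], [2, 2], [1, 3], [0, 2], [1, 2]],
                            [[0, 4], [3, 2], [1, 0], [4, 2], [4, 2]]])
expected = select_loop(possible_values, coordinates)  # shape [2, 5, 4]
```

This is slow for large tensors, but it is handy as a ground truth when checking a vectorized solution.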

I have looked some at index_select and gather, but can't get my head around it currently (or make it do roughly what I want).

Thanks.


There are 2 answers

DerekG (best answer)

OK, let's start by setting the batch dimension aside and looking at a single batch element i:

possible_values[i,coordinates[i,:,0],coordinates[i,:,1],:]  # output is of shape [5,4]

The above gives the correct values for a single batch element. Now we need a way to broadcast this operation for all values of i (i.e. across the batch dimension).

possible_values[:,coordinates[:,:,0],coordinates[:,:,1],:]  # output is of shape [2,2,5,4]

This is mostly correct, but it is "over-broadcasted": it returns the desired indices of every batch element for EVERY batch element. We now need to index just the main diagonal across the first 2 dimensions, so that each batch element is paired only with its own coordinates:

batch_size = possible_values.shape[0]
batch_idx = torch.arange(batch_size)
# index the diagonal with batch_idx (a tensor of indices), not with the integer batch_size
possible_values[:,coordinates[:,:,0],coordinates[:,:,1],:][batch_idx,batch_idx,:,:]   # output is of shape [2,5,4]
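As a possible shortcut, the intermediate [2,2,5,4] tensor can be avoided by passing the batch index as a third index tensor. The sketch below assumes zero-based integer coordinates and uses made-up toy data; `batch_idx[:, None]` has shape [2,1] and broadcasts against the [2,5] coordinate tensors.

```python
import torch

# Toy data with the shapes from the question (values are arbitrary).
possible_values = torch.arange(2 * 5 * 5 * 4).reshape(2, 5, 5, 4)
coordinates = torch.tensor([[[0, 4], [2, 2], [1, 3], [0, 2], [1, 2]],
                            [[0, 4], [3, 2], [1, 0], [4, 2], [4, 2]]])

batch_idx = torch.arange(possible_values.shape[0])

# Two-step version: over-broadcast, then keep the diagonal of the first two dims.
over = possible_values[:, coordinates[:, :, 0], coordinates[:, :, 1], :]  # [2, 2, 5, 4]
two_step = over[batch_idx, batch_idx, :, :]                               # [2, 5, 4]

# One-step version: the three index tensors broadcast to [2, 5],
# so each batch element is paired only with its own coordinates.
one_step = possible_values[batch_idx[:, None],
                           coordinates[:, :, 0],
                           coordinates[:, :, 1]]                          # [2, 5, 4]
```

Both versions should give the same result; the one-step form only materializes batch * seq * n_values elements instead of batch * batch * seq * n_values.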

This solution leaves something to be desired in that it doesn't extend to arbitrarily many dimensions without modification (i.e. if you added a z-axis, you'd have to add an additional coordinates[:,:,2] index to the block, and so on).
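One way around that limitation might be to build the index tuple programmatically. The sketch below is an assumption-laden illustration rather than part of the original answer: `select_nd` is a hypothetical helper, coordinates are taken to be zero-based integers, and the last dimension of `coords` is assumed to hold one index per spatial axis.

```python
import torch

def select_nd(values, coords):
    # values: [batch, d_1, ..., d_k, n_values]
    # coords: [batch, seq, k], zero-based integer indices (assumption)
    batch_idx = torch.arange(values.shape[0], device=coords.device)[:, None]  # [batch, 1]
    # unbind splits the last dimension into k index tensors of shape [batch, seq]
    return values[(batch_idx, *coords.unbind(dim=-1))]  # [batch, seq, n_values]

# Example with an extra z-axis: values is [2, 3, 4, 5, 6], coords is [2, 2, 3].
values = torch.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
coords = torch.tensor([[[0, 1, 2], [2, 3, 4]],
                       [[1, 0, 0], [2, 2, 2]]])
out = select_nd(values, coords)
```

The same function handles the 2D case from the question unchanged, because the number of index tensors follows the size of the last dimension of `coords`.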

Shai

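I think you are looking for torch.nn.functional.grid_sample.
You do need to modify your inputs (grid_sample works on floating-point tensors and expects the grid to be normalized to [-1, 1]), so the snippet below only shows the shape of the call, but I expect it to work: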

import torch.nn.functional as nnf

possible_values = possible_values.permute(0, 3, 1, 2)  # make the "channel" dimension the second one
out = nnf.grid_sample(input=possible_values, grid=coordinates[..., None, :], mode='nearest')
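For completeness, here is a fuller sketch of the grid_sample route. It rests on several assumptions not stated in the answer: the coordinates are zero-based integers, the values are cast to float32, `align_corners=True` is used so that index i maps to -1 + 2*i/(size-1), and the (x, y) pair is flipped because grid_sample reads the last grid dimension as (width, height). The toy data is made up.

```python
import torch
import torch.nn.functional as nnf

# Toy data with the shapes from the question; grid_sample needs floating point.
possible_values = torch.arange(2 * 5 * 5 * 4, dtype=torch.float32).reshape(2, 5, 5, 4)
coordinates = torch.tensor([[[0, 4], [2, 2], [1, 3], [0, 2], [1, 2]],
                            [[0, 4], [3, 2], [1, 0], [4, 2], [4, 2]]])

inp = possible_values.permute(0, 3, 1, 2)  # [N, C=4, H=x_axis, W=y_axis]

# Normalize integer indices to [-1, 1]; sizes holds (H, W).
sizes = torch.tensor(list(possible_values.shape[1:3]), dtype=torch.float32)
norm = 2.0 * coordinates.to(torch.float32) / (sizes - 1) - 1.0  # [N, seq, 2]

# grid[..., 0] addresses W and grid[..., 1] addresses H, hence the flip.
grid = norm.flip(-1)[:, :, None, :]  # [N, seq, 1, 2]

out = nnf.grid_sample(inp, grid, mode='nearest', align_corners=True)  # [N, C, seq, 1]
out = out.squeeze(-1).permute(0, 2, 1)                                # [N, seq, C]
```

For exact integer lookups, plain advanced indexing is probably simpler; grid_sample becomes more attractive when fractional coordinates or interpolation are needed.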