```python
import torch

B = spec_x.size(0)  # batch size
H = spec_x.size(1)  # hidden size
T = spec_x.size(2)  # number of frames
# Initialize the output tensor z with zeros
z = torch.zeros(B, 256, H).to(pitch.device)
# Iterate over each batch element
for b in range(B):
    # Iterate over each pitch index
    for i in range(256):
        # Select the entries of spec_x[b] (shape H x T) at frames where pitch equals i;
        # masked_select broadcasts the (T,) mask over H and returns a flat 1D tensor
        masked_spec_x = spec_x[b].masked_select(pitch[b] == i)
        # Average the selected values (over both H and the selected frames)
        mean_spec_x = torch.mean(masked_spec_x, dim=0)
        # Assign the mean to row i of z[b] (broadcast over H)
        z[b, i] = mean_spec_x
```
The code above uses two tensors, spec_x and pitch. spec_x has shape (B, H, T), and pitch has shape (B, T): it's a 2D tensor that gives, for each frame, an index from 0 to 255 corresponding to the pitch of the spectrogram at that frame.
The goal is to build a tensor z of shape (B, 256, H), where H is the hidden size of spec_x, with
z[b][i] = average of spec_x[b] where pitch[b] == i
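Read literally, the loop computes the following, writing x for spec_x, p for pitch, and N_{b,i} for the number of frames t with p_{b,t} = i. Because masked_select flattens the (H, T) selection into one vector, the mean runs over the hidden dimension too, so every entry of row z[b, i] holds the same value; if no frame has pitch i, the mean of the empty selection is nan.

```latex
z_{b,i,h} = \frac{1}{H \, N_{b,i}} \sum_{h'=1}^{H} \sum_{t \,:\, p_{b,t} = i} x_{b,h',t},
\qquad N_{b,i} = \bigl|\{\, t : p_{b,t} = i \,\}\bigr|, \qquad h = 1, \dots, H
```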
The code above works, but it's very slow because of the loops. I'm just not sure whether there's a way to remove them using PyTorch built-ins.
Thanks!
One solution I see is to use a reduce function to distribute the values from spec_x at the indices given by pitch. The torch.scatter function can seem complex to set up, but all you need to do is make sure that:
- all three tensors (z, src, and index) have the same number of dimensions;
- the indexing tensor (index) has values smaller than the size of the output tensor (z) along the scattering dimension (dim);
- index is no larger than src in any dimension, and no larger than z in any dimension other than dim.
To accommodate the difference in dimensions, we can unsqueeze and expand the inputs. The intermediate shape of the output tensor z is (B, C, H, T), with C = 256 pitch classes.
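A minimal sketch of that setup, assuming spec_x has shape (B, H, T) and pitch holds int64 indices of shape (B, T) as in the question (the sizes below are arbitrary toy values). Here index and src keep a singleton dimension at dim=1, which scatter allows because index is no larger than src anywhere and no larger than z outside dim=1.

```python
import torch

B, H, T, C = 2, 4, 6, 256  # toy sizes; C is the number of pitch classes

spec_x = torch.randn(B, H, T)        # (B, H, T)
pitch = torch.randint(0, C, (B, T))  # (B, T), integer pitch index per frame

# Output buffer with the intermediate shape (B, C, H, T)
z = torch.zeros(B, C, H, T, dtype=spec_x.dtype, device=spec_x.device)

# index: insert singleton dims for C and H, then expand over H -> (B, 1, H, T)
index = pitch[:, None, None, :].expand(-1, 1, H, -1)

# src: insert a singleton dim for C -> (B, 1, H, T)
src = spec_x.unsqueeze(1)

print(z.shape, index.shape, src.shape)
```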
The scattering operation is applied on dim=1, the dimension indexed by the pitch values, i.e. integers in [0, 255]. In pseudo-code, scatter along dim=1 does:
z[b][index[b][c][h][t]][h][t] = src[b][c][h][t]
The first step is to scatter the values:
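A sketch of the scatter step on toy inputs, under the same shape assumptions as the question (o is the name the answer uses for the scattered tensor). Since src has size 1 along dim=1, each (b, h, t) position writes exactly one value into o, so a plain out-of-place Tensor.scatter is enough and no reduce argument is needed.

```python
import torch

torch.manual_seed(0)
B, H, T, C = 2, 4, 6, 256
spec_x = torch.randn(B, H, T)
pitch = torch.randint(0, C, (B, T))

index = pitch[:, None, None, :].expand(-1, 1, H, -1)  # (B, 1, H, T)
src = spec_x.unsqueeze(1)                             # (B, 1, H, T)

# Out-of-place scatter along dim=1:
#   o[b][index[b][0][h][t]][h][t] = src[b][0][h][t]
# All other entries of o stay zero.
o = torch.zeros(B, C, H, T).scatter(1, index, src)
```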
A trick to compute the correct average is to apply the same operation to a tensor of ones with the same shape as src:
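Sketched below under the same assumptions: count holds a 1 wherever a value of spec_x landed in o, so summing it over (H, T) gives H entries for every frame whose pitch is i.

```python
import torch

torch.manual_seed(0)
B, H, T, C = 2, 4, 6, 256
spec_x = torch.randn(B, H, T)
pitch = torch.randint(0, C, (B, T))

index = pitch[:, None, None, :].expand(-1, 1, H, -1)  # (B, 1, H, T)
src = spec_x.unsqueeze(1)                             # (B, 1, H, T)

zeros = torch.zeros(B, C, H, T)
# scatter is out-of-place, so the same zero buffer can be reused
o = zeros.scatter(1, index, src)                       # scattered values
count = zeros.scatter(1, index, torch.ones_like(src))  # 1 wherever a value landed
```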
Then simply sum o and count over their last two dimensions and divide o by count. The result has shape (B, C) instead of the desired (B, C, H); you can fix that by repeating the values along the hidden dimension, since all values are equal row-wise:
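Putting the steps together, here is a hedged sketch that wraps the scatter approach in a function and checks it against the question's loop on random inputs; all three function names are hypothetical. The third function is a lighter alternative that is not part of the answer: since the loop averages over both H and T, it sums spec_x over H first and uses Tensor.scatter_add directly into a (B, C) buffer, avoiding the (B, C, H, T) intermediate, which holds B × 256 × H × T elements.

```python
import torch


def pitch_average_loop(spec_x, pitch, C=256):
    """Reference: the loop from the question (spec_x: (B, H, T), pitch: (B, T))."""
    B, H, _ = spec_x.shape
    z = spec_x.new_zeros((B, C, H))
    for b in range(B):
        for i in range(C):
            z[b, i] = spec_x[b].masked_select(pitch[b] == i).mean()
    return z


def pitch_average_scatter(spec_x, pitch, C=256):
    """Vectorized version using scatter on a (B, C, H, T) buffer."""
    B, H, T = spec_x.shape
    index = pitch[:, None, None, :].expand(-1, 1, H, -1)  # (B, 1, H, T)
    src = spec_x.unsqueeze(1)                             # (B, 1, H, T)
    zeros = spec_x.new_zeros((B, C, H, T))
    o = zeros.scatter(1, index, src)
    count = zeros.scatter(1, index, torch.ones_like(src))
    # Sum over (H, T) and divide -> (B, C); pitches that never occur give
    # 0/0 = nan, like the mean of an empty selection in the loop
    mean = o.sum((-1, -2)) / count.sum((-1, -2))
    # Every row is constant, so repeat the mean along the hidden dimension -> (B, C, H)
    return mean[:, :, None].expand(-1, -1, H)


def pitch_average_scatter_add(spec_x, pitch, C=256):
    """Lighter alternative (not from the answer): reduce over H, then scatter_add into (B, C)."""
    B, H, T = spec_x.shape
    frame_sums = spec_x.sum(1)                                        # (B, T)
    sums = spec_x.new_zeros((B, C)).scatter_add(1, pitch, frame_sums)  # (B, C)
    counts = spec_x.new_zeros((B, C)).scatter_add(1, pitch, torch.ones_like(frame_sums)) * H
    return (sums / counts)[:, :, None].expand(-1, -1, H)
```

expand returns a view without copying memory; use .repeat(1, 1, H) or .contiguous() on the result if you need an independent copy.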