B = spec_x.size(0)
H = spec_x.size(1)
T = spec_x.size(2)
# Initialize z tensor with zeros
z = torch.zeros(B, 256, H, device=pitch.device)
# Iterate over each batch element
for b in range(B):
    # Iterate over each pitch index
    for i in range(256):
        # Select the values of spec_x where pitch equals i
        # (masked_select flattens the result over H and T)
        masked_spec_x = spec_x[b].masked_select(pitch[b] == i)
        # Compute the mean of the selected values (a scalar)
        mean_spec_x = torch.mean(masked_spec_x, dim=0)
        # Broadcast the scalar mean across the hidden dimension of z[b, i]
        z[b, i] = mean_spec_x
The above code has two tensors, spec_x and pitch. pitch is a 2D tensor of shape (B, T); it gives an index from 0 to 255 corresponding to the pitch of the spectrogram at each frame.
The goal is to build a tensor z of shape (B, 256, H), where H is the hidden size of spec_x:
z[b][i] = average of spec_x[b] where pitch == i
The above code works, but it's very slow because of the loops. I'm just not sure if there's a way to remove the loops using PyTorch built-ins.
Thanks!
One solution I see is to use a scatter operation to distribute the values from `spec_x` at the indices given by `pitch`. The `torch.scatter` function seems complex to set up, but all you need to do is make sure that:

- all three tensors (`z`, `src`, and `index`) have the same number of dimensions;
- the indexing tensor (`index`) has values smaller than the size of the output tensor (`z`) along the scattering dimension (`dim`).

To accommodate the dimension difference, we can unsqueeze and expand all three tensors. The output tensor `z` has an intermediate shape of `(B, C, H, T)`, and the scattering operation is applied on `dim=1` (the dimension indexed by integers in `[0, 256)`). In pseudo-code, that corresponds to:

    z[b][index[b][c][h][t]][h][t] = src[b][c][h][t]

The first step is to scatter the values:
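As a sketch (the example shapes and the names `src`, `index`, and `o` are illustrative; `C = 256` is the number of pitch classes), the value-scattering step looks like:

```python
import torch

B, C, H, T = 2, 256, 8, 100          # example sizes; C = number of pitch classes
spec_x = torch.randn(B, H, T)        # example spectrogram features
pitch = torch.randint(0, C, (B, T))  # pitch index per frame

# Bring all three tensors to the same number of dimensions (4):
src = spec_x.unsqueeze(1)                             # (B, 1, H, T)
index = pitch[:, None, None, :].expand(-1, 1, H, -1)  # (B, 1, H, T)

# Scatter along dim=1: o[b, pitch[b, t], h, t] = spec_x[b, h, t]
o = torch.zeros(B, C, H, T).scatter(1, index, src)
```

Note that because the `(h, t)` positions are preserved, every source element lands at a distinct target location, so plain `scatter` (no accumulation) is enough here.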
A trick to get the correct average computed is to apply the same operation to a tensor of ones of the same shape as `src`, giving a `count` tensor.

Then simply sum `o` and `count` over their last two dimensions and divide `o` by the counts.

You may notice that the output tensor is not of the desired shape; you can fix that by repeating the hidden-state dimension, since all values are equal row-wise.
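Putting all the steps together, a self-contained sketch (example shapes are mine) that reproduces the loop version's output could look like this:

```python
import torch

B, C, H, T = 2, 256, 8, 100          # example sizes; C = number of pitch classes
spec_x = torch.randn(B, H, T)
pitch = torch.randint(0, C, (B, T))

# Unsqueeze/expand so all tensors have 4 dimensions:
src = spec_x.unsqueeze(1)                             # (B, 1, H, T)
index = pitch[:, None, None, :].expand(-1, 1, H, -1)  # (B, 1, H, T)

# Scatter the values, and the same operation on a tensor of ones:
o = torch.zeros(B, C, H, T).scatter(1, index, src)
count = torch.zeros(B, C, H, T).scatter(1, index, torch.ones_like(src))

# Sum over the last two dimensions and divide to get the per-pitch means:
z = o.sum(dim=(-2, -1)) / count.sum(dim=(-2, -1))     # (B, C)

# Repeat along the hidden dimension to reach the desired (B, C, H) shape:
z = z.unsqueeze(-1).expand(-1, -1, H)                 # (B, C, H)
```

One caveat: pitch indices that never occur in a batch element produce a 0/0 division, i.e. NaN, which matches what the loop version does (the mean of an empty selection is NaN).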