Vectorized batched cross-correlation in PyTorch


I have a tensor (a signal) f with shape (B, T, 1) and another signal g with the same shape. I want to compute the “pairwise” cross-correlation between samples that share the same batch index. That is, if I iterated over the samples in the batch, I would do something like this:

import torch
import torch.nn.functional as F

all_xcorrs = []
for b in range(B):  # B is the batch size, i.e. f.shape[0] == g.shape[0]
    f_b = f[[b]].permute(0, 2, 1)  # shape (1, 1, T): one sample, one channel
    g_b = g[[b]].permute(0, 2, 1)  # shape (1, 1, T): used as the conv1d kernel
    # conv1d does not flip the kernel, so this is a cross-correlation
    xcorr_b = F.conv1d(f_b, g_b, padding='same')  # shape (1, 1, T)
    all_xcorrs.append(xcorr_b)
all_xcorrs = torch.cat(all_xcorrs, dim=0)  # shape (B, 1, T)
all_xcorrs = all_xcorrs.permute(0, 2, 1)  # shape (B, T, 1)
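For reference, here is what the loop computes. With padding='same' and a kernel as long as the input (length T), PyTorch zero-pads the input with L = floor((T-1)/2) samples on the left and the remaining T-1-L on the right, and F.conv1d does not flip the kernel. So for each batch index b (writing f_b and g_b for the length-T sequences f[b, :, 0] and g[b, :, 0]), any vectorized version has to reproduce exactly this indexing:

```latex
y_b[t] = \sum_{k=0}^{T-1} f_b[t + k - L]\, g_b[k],
\qquad t = 0, \dots, T-1,
\qquad L = \left\lfloor \frac{T-1}{2} \right\rfloor,
\qquad f_b[i] = 0 \text{ for } i < 0 \text{ or } i \ge T.
```

In other words, y_b[t] is the cross-correlation of f_b with g_b at lag t - L; for odd T the middle entry (t = L) is the zero-lag value.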

How can I vectorize this, i.e. compute all B cross-correlations without the Python loop over the batch?
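One way to do it, sketched under the assumption that f and g are both real tensors of shape (B, T, 1) as in the question: fold the batch into the channel axis and use a grouped convolution with groups=B. With groups=B, output channel b is computed only from input channel b and filter b, which is exactly the per-sample loop, but in a single F.conv1d call. The names batched_xcorr_grouped and batched_xcorr_loop are hypothetical helpers; the loop version is just the question's code wrapped in a function so the two can be compared.

```python
import torch
import torch.nn.functional as F


def batched_xcorr_grouped(f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    """'same'-mode cross-correlation of f[b] with g[b] for every b, in one call.

    f, g: real tensors of shape (B, T, 1). Returns a tensor of shape (B, T, 1).
    """
    B = f.shape[0]
    # Treat the whole batch as ONE sample with B channels: (1, B, T).
    x = f.permute(2, 0, 1)
    # Weight layout is (out_channels, in_channels // groups, kernel_size) = (B, 1, T):
    # filter b is g[b].
    w = g.permute(0, 2, 1)
    # groups=B: output channel b only sees input channel b and filter b.
    y = F.conv1d(x, w, padding='same', groups=B)  # (1, B, T)
    return y.permute(1, 2, 0)  # back to (B, T, 1)


def batched_xcorr_loop(f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    """Reference: the per-sample loop from the question."""
    out = []
    for b in range(f.shape[0]):
        f_b = f[[b]].permute(0, 2, 1)  # (1, 1, T)
        g_b = g[[b]].permute(0, 2, 1)  # (1, 1, T)
        out.append(F.conv1d(f_b, g_b, padding='same'))
    return torch.cat(out, dim=0).permute(0, 2, 1)  # (B, T, 1)


if __name__ == "__main__":
    torch.manual_seed(0)
    f = torch.randn(4, 100, 1, dtype=torch.float64)
    g = torch.randn(4, 100, 1, dtype=torch.float64)
    # The two results should agree up to floating-point rounding.
    print(torch.allclose(batched_xcorr_grouped(f, g), batched_xcorr_loop(f, g)))
```

Whether this is actually faster than the loop depends on B, T and the backend (grouped convolutions with very long kernels are not always the best-optimized path), so it is worth timing on your own hardware and dtype.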

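For long signals, a different technique (swapped in here, not the question's conv1d approach) is the correlation theorem: zero-pad both signals to length 2T-1 so circular correlation equals linear correlation, multiply the spectrum of f by the complex conjugate of the spectrum of g, and invert. torch.fft.rfft and torch.fft.irfft act on the last dimension, so all B pairs are handled at once. This sketch assumes real-valued inputs of shape (B, T, 1); batched_xcorr_fft is a hypothetical name, and the final indexing step picks out lags -L to T-1-L (with L = (T-1)//2) so the output lines up with padding='same'.

```python
import torch


def batched_xcorr_fft(f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    """FFT-based 'same'-mode cross-correlation of f[b] with g[b] for every b.

    f, g: real tensors of shape (B, T, 1). Returns a tensor of shape (B, T, 1).
    """
    T = f.shape[1]
    n = 2 * T - 1  # long enough that circular correlation has no wrap-around
    Fw = torch.fft.rfft(f[..., 0], n=n)  # (B, n // 2 + 1), FFT along time
    Gw = torch.fft.rfft(g[..., 0], n=n)
    # Correlation theorem for real inputs: r[k] = sum_m f[m + k] * g[m],
    # with lag k stored at index k mod n (negative lags wrap to the end).
    r = torch.fft.irfft(Fw * Gw.conj(), n=n)  # (B, n)
    # padding='same' output t corresponds to lag t - (T - 1) // 2.
    lags = torch.arange(T, device=f.device) - (T - 1) // 2
    return r[:, lags % n].unsqueeze(-1)  # (B, T, 1)


if __name__ == "__main__":
    torch.manual_seed(0)
    f = torch.randn(4, 100, 1, dtype=torch.float64)
    g = torch.randn(4, 100, 1, dtype=torch.float64)
    print(batched_xcorr_fft(f, g).shape)
```

Per pair this costs O(T log T) instead of the O(T^2) of a direct convolution, so it may pay off for large T; for short signals the grouped convolution is likely simpler and just as fast. Results match the conv1d version up to floating-point rounding, which is larger in float32 than in float64.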