What is the best way to perform a batched slice in PyTorch/NumPy? When the slicing indices are the same for every sample in the batch, I can do this:
```python
import torch

batch_size = 2
data = torch.zeros((batch_size, 1, 256, 256))
x_start = 10
x_stop = 20
y_start = 10
y_stop = 20
data[torch.arange(batch_size), :, y_start:y_stop, x_start:x_stop] = 1
```
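With constant bounds, the `torch.arange(batch_size)` index is not strictly needed: a plain `:` on the batch dimension selects the same elements using basic slicing. The sketch below (assuming a recent PyTorch) checks that both forms write the same region.

```python
import torch

batch_size = 2
x_start, x_stop, y_start, y_stop = 10, 20, 10, 20

# Advanced index on the batch dimension, as in the snippet above
a = torch.zeros((batch_size, 1, 256, 256))
a[torch.arange(batch_size), :, y_start:y_stop, x_start:x_stop] = 1

# Equivalent: a plain slice over the batch dimension
b = torch.zeros((batch_size, 1, 256, 256))
b[:, :, y_start:y_stop, x_start:x_stop] = 1
```

This only works because every sample shares the same bounds; a Python slice accepts a single start and stop, not one per sample.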
But how do I handle the case where the start and stop values differ for each sample? E.g.:
```python
batch_size = 2
data = torch.zeros((batch_size, 1, 256, 256))
x_start = [10, 5]
x_stop = [20, 30]
y_start = [10, 5]
y_stop = [20, 30]
data[torch.arange(batch_size), :, y_start:y_stop, x_start:x_stop] = 1  # fails: slice start/stop must be scalars, not lists
```
I could do it with a for loop over the batch, but is there a more idiomatic, vectorized way to do it?
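One common vectorized approach (a sketch, not the only option) is to replace the per-sample slices with a boolean mask built by broadcasting: compare row and column coordinates against each sample's `[start, stop)` bounds, then fill where the mask is true. The helper names `fill_boxes_loop` and `fill_boxes_masked` are hypothetical, and the sketch assumes non-negative bounds with `start <= stop` (Python's negative-index slice semantics are not reproduced). The mask costs `B * H * W` booleans of extra memory.

```python
import torch

def fill_boxes_loop(data, y_start, y_stop, x_start, x_stop, value=1.0):
    # Baseline: one slice assignment per batch element.
    for b in range(data.shape[0]):
        data[b, :, y_start[b]:y_stop[b], x_start[b]:x_stop[b]] = value
    return data

def fill_boxes_masked(data, y_start, y_stop, x_start, x_stop, value=1.0):
    # Vectorized: build a (B, 1, H, W) boolean mask via broadcasting.
    # Assumes non-negative bounds; stop values past the edge are
    # effectively clamped because coordinates never exceed H-1 / W-1.
    B, _, H, W = data.shape
    dev = data.device
    ys = torch.as_tensor(y_start, device=dev).view(B, 1, 1, 1)
    ye = torch.as_tensor(y_stop, device=dev).view(B, 1, 1, 1)
    xs = torch.as_tensor(x_start, device=dev).view(B, 1, 1, 1)
    xe = torch.as_tensor(x_stop, device=dev).view(B, 1, 1, 1)
    rows = torch.arange(H, device=dev).view(1, 1, H, 1)
    cols = torch.arange(W, device=dev).view(1, 1, 1, W)
    # Each comparison broadcasts to (B, 1, H, W); True inside the box.
    mask = (rows >= ys) & (rows < ye) & (cols >= xs) & (cols < xe)
    # In-place fill; the mask broadcasts over the channel dimension.
    data.masked_fill_(mask, value)
    return data

# Usage with the per-sample bounds from the question
data = torch.zeros((2, 1, 256, 256))
fill_boxes_masked(data, [10, 5], [20, 30], [10, 5], [20, 30])
```

If a new tensor is preferred over in-place modification, `torch.where(mask, value, data)` builds the result from the same mask.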