Pythonic way of slicing array by list indices


I was wondering what the best way is to perform a batched slice in PyTorch or NumPy. When the slicing indices are the same for every element of the batch, the operation works:

import torch

batch_size = 2
data = torch.zeros((batch_size, 1, 256, 256))
x_start = 10
x_stop = 20

y_start = 10
y_stop = 20
data[torch.arange(batch_size), :, y_start:y_stop, x_start:x_stop] = 1
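For constant bounds, the torch.arange(batch_size) index is not actually needed: a plain ':' on the batch dimension already selects every element, using only basic slicing. A minimal check, reusing the shapes from the snippet above:

```python
import torch

batch_size = 2
x_start, x_stop = 10, 20
y_start, y_stop = 10, 20

# Form used above: an index tensor on dim 0 plus basic slices on H and W.
with_index = torch.zeros((batch_size, 1, 256, 256))
with_index[torch.arange(batch_size), :, y_start:y_stop, x_start:x_stop] = 1

# Equivalent form: ':' on the batch dimension selects every element.
plain = torch.zeros((batch_size, 1, 256, 256))
plain[:, :, y_start:y_stop, x_start:x_stop] = 1

print(torch.equal(with_index, plain))  # True
```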

But how do I handle the case where the start and stop values differ for each element of the batch? E.g.:

batch_size = 2
data = torch.zeros((batch_size, 1, 256, 256))
x_start = [10, 5]
x_stop = [20, 30]

y_start = [10, 5]
y_stop = [20, 30]
data[torch.arange(batch_size), :, y_start:y_stop, x_start:x_stop] = 1  # TypeError: slice bounds must be integers, not lists

I could do it with a for loop, but I was wondering whether there is a more Pythonic way.
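One vectorized option, sketched here under stated assumptions: instead of slicing, build a boolean mask by comparing pixel coordinates from torch.arange against each element's bounds (broadcasting the bounds as (B, 1) columns), then write through the mask with Tensor.masked_fill_. The helper names fill_boxes and fill_boxes_loop are hypothetical. The sketch assumes data has shape (B, C, H, W) and that the bounds are non-negative integers describing half-open [start, stop) ranges; unlike Python slicing, negative bounds are not counted from the end. The loop version is only a reference to check the result against.

```python
import torch


# Hypothetical helper: sets data[b, :, y_start[b]:y_stop[b], x_start[b]:x_stop[b]] = value
# for every b, without a Python loop. Assumes non-negative, half-open bounds.
def fill_boxes(data, y_start, y_stop, x_start, x_stop, value=1.0):
    B, _, H, W = data.shape
    dev = data.device
    # Per-element bounds as (B, 1) columns so they broadcast against coordinates.
    y0 = torch.as_tensor(y_start, device=dev).reshape(B, 1)
    y1 = torch.as_tensor(y_stop, device=dev).reshape(B, 1)
    x0 = torch.as_tensor(x_start, device=dev).reshape(B, 1)
    x1 = torch.as_tensor(x_stop, device=dev).reshape(B, 1)
    ys = torch.arange(H, device=dev)  # (H,) row coordinates
    xs = torch.arange(W, device=dev)  # (W,) column coordinates
    rows = (ys >= y0) & (ys < y1)  # (B, H): row y lies inside box b
    cols = (xs >= x0) & (xs < x1)  # (B, W): column x lies inside box b
    mask = rows[:, :, None] & cols[:, None, :]  # (B, H, W)
    # Add a channel axis so the mask broadcasts over all C channels.
    data.masked_fill_(mask[:, None], value)
    return data


# Hypothetical reference: one ordinary slice assignment per batch element.
def fill_boxes_loop(data, y_start, y_stop, x_start, x_stop, value=1.0):
    for b in range(data.shape[0]):
        data[b, :, y_start[b]:y_stop[b], x_start[b]:x_stop[b]] = value
    return data


batch_size = 2
x_start, x_stop = [10, 5], [20, 30]
y_start, y_stop = [10, 5], [20, 30]

a = fill_boxes(torch.zeros((batch_size, 1, 256, 256)), y_start, y_stop, x_start, x_stop)
b = fill_boxes_loop(torch.zeros((batch_size, 1, 256, 256)), y_start, y_stop, x_start, x_stop)
print(torch.equal(a, b))  # True
```

The trade-off is memory: the mask holds B*H*W booleans, which the loop never allocates, but the number of tensor operations no longer grows with the batch size. The same broadcast comparisons carry over to NumPy using np.arange and assignment through a boolean mask.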
