Where can I find an intuitive explanation of PyTorch's Tensor.unfold() being used to get image patches?


Recently I came across some code that extracted (sliding-window style) a number of square patches from an RGB image (or a batch of them) of shape N x C x H x W. They did this as follows:

patch_width = 3
patches = image.permute(0, 2, 3, 1) \
        .unfold(1, patch_width, patch_width) \
        .unfold(2, patch_width, patch_width)   # Tensor.unfold(dimension, size, step)

I understand that the unfold() method "returns all slices of size size from self tensor in the dimension dim," from reading the documentation, but try as I might, I just can't get a good intuition for why stacking two .unfold() calls produces square patches. I get what happens when you use unfold() once on a tensor. I don't get what happens when you call it twice in succession along two different dimensions.

I've seen this approach used multiple times, always without a good explanation as to why it works (1, 2), and it's driving me bonkers. Why are the spatial dimensions H and W permuted to be dims 1 and 2, while the channel dim is set to 3? Why does unfolding the same way on dim 1, then on dim 2 result in square patch_width by patch_width patches?

Any insight would be hugely appreciated, even if it's just a link to an article I missed. I've been Googling for well over an hour now and have met very little success. Thank you!

[1] PyTorch forum post

[2] Another forum post doing the same thing


There are 2 answers

Alexey Birukov (BEST ANSWER)

I suppose there are two distinct parts to your question: the first is why you need to permute, and the second is how two unfolds combined produce square image slices.

The first part is rather technical: unfold places the produced slices in a new dimension of the tensor, which is 'inserted at the end of the shape'. The permute is needed so that these new dimensions end up next to the channel (depth) dimension, which makes it natural to merge them with view later.
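
For example (a minimal sketch, assuming a hypothetical 2 x 3 x 6 x 6 batch and non-overlapping 3 x 3 patches), compare where the new dimensions land with and without the permute:

import torch

image = torch.randn(2, 3, 6, 6)   # hypothetical (N, C, H, W) batch

# without the permute, the two new patch dims are separated from the channel dim
a = image.unfold(2, 3, 3).unfold(3, 3, 3)
print(a.shape)   # torch.Size([2, 3, 2, 2, 3, 3]) -> (N, C, H_patches, W_patches, 3, 3)

# with the permute, the channel dim sits right next to the new patch dims,
# so each whole patch (channels included) can be flattened with a single view
b = image.permute(0, 2, 3, 1).unfold(1, 3, 3).unfold(2, 3, 3)
print(b.shape)                               # torch.Size([2, 2, 2, 3, 3, 3]) -> (N, H_patches, W_patches, C, 3, 3)
print(b.contiguous().view(2, 4, -1).shape)   # torch.Size([2, 4, 27]): 4 patches of 3*3*3 = 27 values each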

Now the second part. Consider a deck of imaginary cards, where each card is one image channel. Take a card, cut it into vertical slices, and stack the slices on top of each other. Take the second card and do the same, placing the result on top of the first; do this with all the cards. Now repeat the procedure, this time cutting the slices horizontally. At the end you have a much thinner but taller deck, where the former cards have become sub-decks of patches.
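
The same thing in code (a small sketch, assuming a single hypothetical 3-channel 6 x 6 image and non-overlapping 3 x 3 patches): after the permute and the two unfolds, each entry of the result is a complete channels x 3 x 3 crop of the original image.

import torch

image = torch.arange(3 * 6 * 6).reshape(1, 3, 6, 6)   # one image, three 6x6 channel 'cards'

patches = image.permute(0, 2, 3, 1).unfold(1, 3, 3).unfold(2, 3, 3)
print(patches.shape)   # torch.Size([1, 2, 2, 3, 3, 3]): a 2x2 grid of (C, 3, 3) patches

# the patch at grid position (1, 0) is exactly the bottom-left 3x3 crop of the image,
# with all three channel 'cards' stacked together
print(torch.equal(patches[0, 1, 0], image[0, :, 3:6, 0:3]))   # True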

iacob

Let's look at a simple 2d example to see why composing the two operations produces 'patches'.

Code:

import torch

x = torch.tensor([[ 1,  2,  3,  4,  5],
                  [ 6,  7,  8,  9, 10],
                  [11, 12, 13, 14, 15]])
>>> x.unfold(1,2,1)
tensor([[[ 1,  2], [ 2,  3], [ 3,  4], [ 4,  5]],
        [[ 6,  7], [ 7,  8], [ 8,  9], [ 9, 10]],
        [[11, 12], [12, 13], [13, 14], [14, 15]]])
>>> x.unfold(1,2,1).unfold(0,2,1)
tensor([[[[ 1,  6],
          [ 2,  7]],

         [[ 2,  7],
          [ 3,  8]],

         [[ 3,  8],
          [ 4,  9]],

         [[ 4,  9],
          [ 5, 10]]],


        [[[ 6, 11],
          [ 7, 12]],

         [[ 7, 12],
          [ 8, 13]],

         [[ 8, 13],
          [ 9, 14]],

         [[ 9, 14],
          [10, 15]]]])
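
Continuing from the x defined above, you can pull individual patches out of the result and compare them with a direct slice of x. Note that because the column unfold (dim 1) runs before the row unfold (dim 0) here, each 2x2 block comes out transposed relative to the original layout:

>>> patches = x.unfold(1, 2, 1).unfold(0, 2, 1)
>>> patches.shape
torch.Size([2, 4, 2, 2])
>>> patches[0, 0]   # the window covering x[0:2, 0:2], transposed
tensor([[1, 6],
        [2, 7]])
>>> x[0:2, 0:2].T
tensor([[1, 6],
        [2, 7]])

In the image code from the question the unfolds run over H first and then W, so there the two appended patch dimensions come out in the usual (row, column) order.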