How to handle odd resolutions in a U-Net architecture in PyTorch

I'm implementing a U-Net-based architecture in PyTorch. At train time, I have patches of size 256x256, which don't cause any problems. However, at test time I have full HD images (1920x1080), and these cause a problem in the skip connections.

Downsampling 1920x1080 three times gives 240x135. If I downsample one more time, the resolution becomes 120x68, which gives 240x136 when upsampled. Now I cannot concatenate these two feature maps (240x135 and 240x136). How can I solve this?
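To see where the off-by-one comes from, here is a small shape trace in plain Python. It assumes each downsampling step is a 3x3 convolution with stride 2 and padding 1 (the question does not say which layer is used); that assumption reproduces the 68 above, while 2x2 max pooling would floor 135 to 67 instead. Either way, doubling the deepest size does not give back 135. `conv_out` and `trace` are hypothetical helper names.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    # Output length along one axis for Conv2d/MaxPool2d (dilation=1), as in PyTorch's docs.
    return (size + 2 * padding - kernel) // stride + 1

def trace(size, steps, **kw):
    # Apply `steps` downsampling layers and record every intermediate size.
    sizes = [size]
    for _ in range(steps):
        sizes.append(conv_out(sizes[-1], **kw))
    return sizes

# Assumption: four 3x3 stride-2 convolutions with padding 1.
heights = trace(1080, 4)
widths = trace(1920, 4)
print(heights)  # [1080, 540, 270, 135, 68]
print(widths)   # [1920, 960, 480, 240, 120]

# A x2 upsample of the deepest map vs. the skip feature it is concatenated with.
print(heights[-1] * 2, "vs skip", heights[-2])  # 136 vs skip 135

# With 2x2 max pooling (kernel=2, stride=2, padding=0) the height floors instead.
print(trace(1080, 4, kernel=2, stride=2, padding=0))  # [1080, 540, 270, 135, 67]
```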

PS: I thought this was a fairly common problem, but I couldn't find any solution, or even a mention of this problem, anywhere on the web. Am I missing something?

Accepted answer by hkchengrex:

This is a very common problem in segmentation networks, where skip connections are often involved in the decoding process. Networks usually (depending on the actual architecture) require input sizes whose side lengths are integer multiples of the largest stride (8, 16, 32, etc.).
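As a rough check, assuming every downsampling step halves the resolution, the largest stride is 2 raised to the number of downsampling steps. The helpers below (hypothetical names) show why 256x256 patches work with four downsamplings while 1080x1920 does not: 1080 is a multiple of 8 but not of 16.

```python
def largest_stride(num_downsamplings):
    # Assumes each downsampling step halves H and W.
    return 2 ** num_downsamplings

def is_feasible(h, w, stride):
    # Both side lengths must be integer multiples of the largest stride.
    return h % stride == 0 and w % stride == 0

stride = largest_stride(4)
print(stride)                                      # 16
print(is_feasible(256, 256, stride))               # True
print(is_feasible(1080, 1920, stride))             # False: 1080 % 16 == 8
print(is_feasible(1080, 1920, largest_stride(3)))  # True: three steps are fine
```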

There are two main ways:

  1. Resize input to the nearest feasible size.
  2. Pad the input to the next larger feasible size.
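The two options differ in which target size they pick. A small sketch (hypothetical helper names) that computes both for the height of a full HD frame: padding needs the next multiple at or above the input size, while resizing can use the closest multiple, which may be smaller.

```python
def next_multiple(n, stride):
    # Smallest multiple of stride that is >= n (target size for padding).
    return -(-n // stride) * stride  # ceiling division

def nearest_multiple(n, stride):
    # Closest multiple of stride, rounding halves up (a possible target for resizing).
    return max(stride, (n + stride // 2) // stride * stride)

for stride in (16, 128):
    print(stride, next_multiple(1080, stride), nearest_multiple(1080, stride))
# 16 1088 1088
# 128 1152 1024
```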

I prefer (2) because (1) interpolates the image, slightly changing the value of every pixel and leading to unnecessary blurriness. Note that both methods usually require recovering the original shape afterward.
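To illustrate the point about (1), here is a sketch that round-trips a random image through each option without any network in between. It assumes bilinear resizing and a stride of 16 (so the target height is 1088). The resize round trip changes pixel values, whereas padding and then cropping returns the input exactly.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(1, 3, 1080, 1920)

# Option (1): resize to a feasible size and back again.
x_resized = F.interpolate(x, size=(1088, 1920), mode="bilinear", align_corners=False)
x_back = F.interpolate(x_resized, size=(1080, 1920), mode="bilinear", align_corners=False)

# Option (2): zero-pad 8 rows (4 top, 4 bottom), then crop them off again.
x_padded = F.pad(x, (0, 0, 4, 4))
x_cropped = x_padded[..., 4:-4, :]

print((x_back - x).abs().max().item() > 0)  # True: interpolation altered the pixels
print(torch.equal(x_cropped, x))            # True: pad + crop is lossless
```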

My favorite code snippet for this task (symmetric padding for height/width):

import torch
import torch.nn.functional as F

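def pad_to(x, stride):
    # Pad the last two dims (H, W) of x up to the next multiple of `stride`.
    h, w = x.shape[-2:]

    # Round each side up to a multiple of stride (unchanged if already divisible).
    if h % stride > 0:
        new_h = h + stride - h % stride
    else:
        new_h = h
    if w % stride > 0:
        new_w = w + stride - w % stride
    else:
        new_w = w
    # Split the extra rows/columns as evenly as possible; an odd pixel goes to the bottom/right.
    lh, uh = (new_h - h) // 2, (new_h - h) - (new_h - h) // 2
    lw, uw = (new_w - w) // 2, (new_w - w) - (new_w - w) // 2
    # F.pad lists the last dim first: (left, right, top, bottom).
    pads = (lw, uw, lh, uh)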

    # zero-padding by default.
    # See others at https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.pad
    out = F.pad(x, pads, "constant", 0)

    return out, pads

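def unpad(x, pad):
    # Crop away the padding added by pad_to; pad is (left, right, top, bottom).
    lw, uw, lh, uh = pad
    h, w = x.shape[-2:]
    # Explicit end indices stay correct even when a right/bottom pad is 0
    # (slicing with -0 would return an empty tensor).
    return x[..., lh:h - uh, lw:w - uw]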
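The comment in `pad_to` points to other padding modes of `F.pad`. As a small illustration on a single row of five pixels, here is how `"constant"` (zeros), `"reflect"` and `"replicate"` fill the same padding. Which mode works best for a given network is an empirical question; zero padding is simply what the snippet uses. Note that `"reflect"` requires each padding amount to be smaller than the corresponding input dimension.

```python
import torch
import torch.nn.functional as F

x = torch.arange(1.0, 6.0).view(1, 1, 1, 5)  # one row: [1, 2, 3, 4, 5]
pads = (1, 2, 0, 0)                           # 1 column left, 2 right, no rows

for mode in ("constant", "reflect", "replicate"):
    print(mode, F.pad(x, pads, mode=mode).flatten().tolist())
# constant [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0]
# reflect [2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0]
# replicate [1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0]
```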

A test snippet:

x = torch.zeros(4, 3, 1080, 1920) # Raw data
x_pad, pads = pad_to(x, 16) # Padded data, feed this to your network 
x_unpad = unpad(x_pad, pads) # Un-pad the network output to recover the original shape

print('Original: ', x.shape)
print('Padded: ', x_pad.shape)
print('Recovered: ', x_unpad.shape)

Output:

Original:  torch.Size([4, 3, 1080, 1920])
Padded:  torch.Size([4, 3, 1088, 1920])
Recovered:  torch.Size([4, 3, 1080, 1920])
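Putting the pieces together, here is a minimal sketch of wrapping a network so that arbitrary input sizes work. `TinyUNet` is a hypothetical two-level U-Net (largest stride 4), and `forward_padded` is a hypothetical wrapper that applies the same symmetric padding and cropping as `pad_to`/`unpad`, written inline so the block is self-contained. On a 30x50 input the bare model fails at the first skip concatenation, while the wrapped call returns an output of the original size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyUNet(nn.Module):
    # Hypothetical two-level U-Net; two stride-2 convs give a largest stride of 4.
    def __init__(self, ch=8):
        super().__init__()
        self.enc1 = nn.Conv2d(3, ch, 3, padding=1)
        self.enc2 = nn.Conv2d(ch, ch, 3, stride=2, padding=1)
        self.enc3 = nn.Conv2d(ch, ch, 3, stride=2, padding=1)
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.dec2 = nn.Conv2d(2 * ch, ch, 3, padding=1)
        self.dec1 = nn.Conv2d(2 * ch, 1, 3, padding=1)

    def forward(self, x):
        s1 = F.relu(self.enc1(x))   # full resolution (skip 1)
        s2 = F.relu(self.enc2(s1))  # 1/2 resolution (skip 2)
        b = F.relu(self.enc3(s2))   # 1/4 resolution (bottleneck)
        d2 = F.relu(self.dec2(torch.cat([self.up(b), s2], dim=1)))
        return self.dec1(torch.cat([self.up(d2), s1], dim=1))

def forward_padded(model, x, stride):
    # Pad H and W up to a multiple of stride, run the model, then crop back.
    h, w = x.shape[-2:]
    ph, pw = (-h) % stride, (-w) % stride
    top, left = ph // 2, pw // 2
    y = model(F.pad(x, (left, pw - left, top, ph - top)))
    return y[..., top:top + h, left:left + w]

x = torch.rand(1, 3, 30, 50)  # odd size: not a multiple of 4
model = TinyUNet().eval()
try:
    model(x)
except RuntimeError as e:
    print("direct call failed:", type(e).__name__)
with torch.no_grad():
    y = forward_padded(model, x, stride=4)
print(y.shape)  # torch.Size([1, 1, 30, 50])
```

In a real model the stride passed to the wrapper should match the network's largest downsampling factor, for example 16 for four 2x downsamplings.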

Reference: https://github.com/seoungwugoh/STM/blob/905f11492a6692dd0d0fa395881a8ec09b211a36/helpers.py#L33