Affine transformation using cv2 and torch

import cv2
import numpy as np


def opencv_downsampling(img, out_size, translation=(0.0, 0.0), rotation=0.0,
                        shear=(0.0, 0.0), scale=(1.0, 1.0), interpolation=cv2.INTER_LINEAR):
    h, w = img.shape[:2]
    # Rotation centre, using the pixel-centre convention.
    rotx = w / 2 - 0.5
    roty = h / 2 - 0.5
    # 2x3 affine matrix: rotation/scale about the centre, plus translation.
    tmat = np.zeros((2, 3), dtype=np.float32)
    a = scale[1] * np.cos(rotation)
    b = scale[0] * np.sin(rotation)
    tmat[0, 2] = translation[1] + (1-a) * rotx - b * roty
    tmat[1, 2] = translation[0] + b * rotx + (1-a) * roty
    tmat[0, 0] = a
    tmat[0, 1] = b
    tmat[1, 0] = b * (-1)
    tmat[1, 1] = a
    if interpolation == cv2.INTER_AREA:
        # warpAffine does not really support INTER_AREA, so warp with cubic
        # and keep INTER_AREA for the resize below.
        flags = cv2.INTER_CUBIC
    else:
        flags = interpolation
    img = cv2.warpAffine(img, tmat, (w, h), flags=flags, borderMode=cv2.BORDER_CONSTANT)
    # cv2.resize expects dsize as (width, height), hence the reversal of out_size (h, w).
    img = cv2.resize(img, dsize=out_size[::-1], interpolation=interpolation)
    return np.clip(img, 0, 1)

This is the affine transformation using OpenCV (cv2).
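As a sanity check, the hand-built matrix above should coincide with what cv2.getRotationMatrix2D produces for a rotation about the same pixel centre, at least for isotropic scale and zero translation/shear (the degree conversion and the isotropic-scale restriction below are my own additions for the comparison, not part of the original function):

import cv2
import numpy as np

h, w = 128, 128
rotation = 0.3                         # radians, as in opencv_downsampling
scale = 1.0                            # isotropic scale only, for this comparison
center = (w / 2 - 0.5, h / 2 - 0.5)    # same pixel-centre convention as rotx/roty

# cv2.getRotationMatrix2D expects the angle in degrees.
ref = cv2.getRotationMatrix2D(center, np.degrees(rotation), scale)

a = scale * np.cos(rotation)
b = scale * np.sin(rotation)
tmat = np.array([[a, b, (1 - a) * center[0] - b * center[1]],
                 [-b, a, b * center[0] + (1 - a) * center[1]]])

print(np.allclose(ref, tmat))  # expected: True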

import torch
import torch.nn as nn
import torch.nn.functional as F
# Note: conv, sequential and ResBlock are assumed to come from the project's own
# building-block module (e.g. models.basicblock in KAIR); they are not defined here.


class Proxy_SI(nn.Module):
    def __init__(self, in_nc=3, out_nc=3, nc=64, nb=5, nparams=1):
        """
        Proxy network that approximates the Integrator. It consists of two branches: (1) a naive double
        bilinear interpolation, shown to be an upper bound of the real behaviour in the Fourier domain,
        and (2) a residual branch that removes from the bilinear branch the excess quantity needed to
        obtain the correct result.

        in_nc: number of input channels in the image (e.g. 3 for RGB or 1 for grayscale).
        out_nc: number of output channels in the image (e.g. 3 for RGB or 1 for grayscale).
        nc: number of channels in the residual blocks.
        nb: number of residual blocks in each branch.
        nparams: number of parameters for the proxy integral (e.g. fill factor).
        """
        super(Proxy_SI, self).__init__()
        self.nparams = nparams

        self.head = conv(in_nc + nparams, nc, bias=False, mode='C')
        self.features_high = sequential(*[ResBlock(nc, nc, bias=False, mode='CRC') for _ in range(nb)])
        self.features_low = sequential(*[ResBlock(nc, nc, bias=False, mode='CRC') for _ in range(nb)])
        self.tail = conv(nc, out_nc, bias=False, mode='C')

    def forward(self, x, params, output_size, translation, rotation):
        """
        Inputs:
            x, torch.tensor (n,in_nc,h,w): the batch of input images to be warped and downsampled.
            params, torch.tensor (nparams): the additional parameters to be used.
            output_size, tuple (int, int): size of the output images (corresponds to (h/s, w/s)).
            translation, torch.tensor (n, 2): the translations on the x and y axes.
            rotation, torch.tensor (n, 1): the rotations (in radians).
        Output:
            y, torch.tensor (n,out_nc,*output_size): the batch of output images that are warped and downsampled.
        """
        n, _, h, w = x.shape  # 10, 3, 128, 128

        ## Pre-requisite: create warping operator.
        grid_T, grid_R = self.create_warp(x, translation, rotation)

        ## Branch 1: bilinear warp and downsampling in pixel space.
        x1 = F.grid_sample(x, grid_R)   # defaults: mode='bilinear', align_corners=False
        x1 = F.grid_sample(x1, grid_T)
        x1 = F.interpolate(x1, size=output_size, recompute_scale_factor=False)  # default: mode='nearest'

        ## Branch 2: bilinear warp and downsampling in the feature domain.
        p = torch.ones(n, self.nparams, h, w, device=x.device) * params[:, None, None]
        x2 = torch.cat([x, p], dim=1)  # concatenate the parameters and the images.
        x2 = self.head(x2)
        x2 = self.features_high(x2)
        x2 = F.grid_sample(x2, grid_R)
        x2 = F.grid_sample(x2, grid_T)
        x2 = F.interpolate(x2, size=output_size, recompute_scale_factor=False)
        x2 = self.features_low(x2)
        x2 = self.tail(x2)

        ## Fuse the two branches.
        return x1 - x2

    def create_warp(self, x, translation, rotation):
        """
        Inputs:
            x, torch.tensor (n,c,h,w): the batch of input images to be warped and downsampled (just for getting the shape and device).
            translation, torch.tensor (n, 2): the translations on the x and y axes.
            rotation, torch.tensor (n, 1): the rotations (in radians).
        Outputs:
            grid_t, grid_r, torch.tensors (n,h,w,2): the batch of translation and rotation warping grids.
        """
        output_size = x.shape
        h, w = output_size[2], output_size[3]
        A_t = torch.zeros(output_size[0], 2, 3, device=x.device)
        A_r = torch.zeros(output_size[0], 2, 3, device=x.device)
        a = torch.cos(rotation)
        b = torch.sin(rotation)
        rotx = w / 2 - 0.5
        roty = h / 2 - 0.5

        # Put A_t translation (affine_grid works in normalized [-1, 1] coordinates, hence the division).
        A_t[:, 0, 2] = -(translation[:, 1]) / (w / 2)
        A_t[:, 1, 2] = -(translation[:, 0]) / (h / 2)

        # Linear part of A_t is the identity (pure translation).
        A_t[:, 0, 0] = 1
        A_t[:, 0, 1] = 0
        A_t[:, 1, 0] = 0
        A_t[:, 1, 1] = 1

        # A_r has no translation component.
        A_r[:, 0, 2] = 0
        A_r[:, 1, 2] = 0

        # Put A_r rotation.
        A_r[:, 0, 0] = a
        A_r[:, 0, 1] = -b
        A_r[:, 1, 0] = b
        A_r[:, 1, 1] = a

        grid_t = F.affine_grid(A_t, output_size)
        grid_r = F.affine_grid(A_r, output_size)

        return grid_t, grid_r

The code snippet above is the PyTorch version.
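For context, here is a minimal usage sketch (the shapes follow the docstrings; the 1-D rotation tensor and the external building blocks conv, sequential and ResBlock, assumed to come from the project's basicblock module, are my own assumptions):

model = Proxy_SI(in_nc=3, out_nc=3, nc=64, nb=5, nparams=1)
x = torch.rand(10, 3, 128, 128)          # batch of input images
params = torch.tensor([0.5])             # one extra parameter (e.g. fill factor)
translation = torch.zeros(10, 2)         # per-image 2-D translation, in pixels
rotation = torch.zeros(10)               # per-image rotation in radians (1-D here)
y = model(x, params, output_size=(64, 64), translation=translation, rotation=rotation)
print(y.shape)                           # expected: torch.Size([10, 3, 64, 64])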

This is what the code generates.

I want the bottom two images to look the same.

It works well when there is only translation, but when rotation is added, the Proxy_SI output no longer matches (I can tell from its four corners).

How should I modify the Torch code to preserve the original shape? (I guess the problem occurs in the create_warp part.)
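One source of mismatch worth keeping in mind when comparing the two pipelines is that F.affine_grid / F.grid_sample and cv2.warpAffine use different conventions: cv2.warpAffine treats its 2x3 matrix as the forward (input to output) map and inverts it internally, while grid_sample expects the backward map expressed in normalized [-1, 1] coordinates, with align_corners=False as the current default. Below is a minimal sketch (the helper name cv2_matrix_as_grid and its details are my own illustration, not the original code) of turning a cv2-style matrix into a single grid_sample grid, which may help check where the two results diverge:

import cv2
import torch
import torch.nn.functional as F

def cv2_matrix_as_grid(tmat, h, w, device="cpu"):
    """Hypothetical helper: convert a 2x3 cv2.warpAffine matrix into a grid_sample grid."""
    # cv2.warpAffine inverts tmat internally, so do the same here to get the
    # output -> input mapping that grid_sample needs (still in pixel coordinates).
    inv = torch.as_tensor(cv2.invertAffineTransform(tmat), dtype=torch.float32, device=device)

    ys, xs = torch.meshgrid(torch.arange(h, dtype=torch.float32, device=device),
                            torch.arange(w, dtype=torch.float32, device=device),
                            indexing="ij")
    x_in = inv[0, 0] * xs + inv[0, 1] * ys + inv[0, 2]
    y_in = inv[1, 0] * xs + inv[1, 1] * ys + inv[1, 2]

    # Normalize pixel coordinates to [-1, 1] using the align_corners=False convention,
    # under which the image centre (w/2 - 0.5, h/2 - 0.5) maps to normalized 0.
    grid_x = (x_in + 0.5) / w * 2 - 1
    grid_y = (y_in + 0.5) / h * 2 - 1
    return torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)  # (1, h, w, 2)

# Example: warp one (1, c, h, w) tensor the way cv2.warpAffine(img, tmat, (w, h)) would.
# grid = cv2_matrix_as_grid(tmat, h, w)
# out = F.grid_sample(img_tensor, grid, mode="bilinear", align_corners=False)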
