PyTorch's grid_sample vs landmark transformation


I want to do a simple rotation transformation of the image. I go through all the steps with affine_grid and grid_sample. It properly transforms the image while centering the rotation at the center of the image. That's all fine.

Then I want to transform some landmarks the same way. I recenter the affine transformation matrix at the center of the image space, since I assume PyTorch does this internally.
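For reference, here is a minimal sketch of the normalized coordinate convention that grid_sample documents, assuming align_corners=True: -1 and 1 refer to the centers of the corner pixels, and coordinates are given in (x, y) order. The helper names `pixel_to_normalized` and `normalized_to_pixel` are hypothetical and are not part of PyTorch.

```python
def pixel_to_normalized(x, y, height, width):
    """Pixel (x, y) -> normalized [-1, 1] coordinates, assuming align_corners=True."""
    return 2.0 * x / (width - 1) - 1.0, 2.0 * y / (height - 1) - 1.0


def normalized_to_pixel(xn, yn, height, width):
    """Inverse of pixel_to_normalized (same align_corners=True assumption)."""
    return (xn + 1.0) * (width - 1) / 2.0, (yn + 1.0) * (height - 1) / 2.0


H, W = 1200, 1600  # image size from the question

# Pixel location that the normalized origin (0, 0) refers to.
center = normalized_to_pixel(0.0, 0.0, H, W)

# Normalized coordinates of the bottom-right pixel center.
corner = pixel_to_normalized(W - 1, H - 1, H, W)
```

Under this convention the normalized origin sits at ((W - 1) / 2, (H - 1) / 2) in (x, y) order rather than at (W / 2, H / 2), and x and y are scaled by different factors whenever the image is not square. A rotation matrix applied in normalized space is therefore not necessarily a pure rotation in pixel space.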

It all looks fine, but it also introduces some error. I tested it, and the error is caused by neither align_corners=True nor float32/float64 precision.

Here is the minimal working sample:

import PIL.Image
import torch.nn.functional as F
import torch
import numpy as np
import matplotlib.pyplot as plt

img = np.asarray(PIL.Image.open('frog.jpg'))
imgt = torch.from_numpy(img.copy()).permute(2,0,1).unsqueeze(0)
# imgt.shape == torch.Size([1, 3, 1200, 1600])

landmarks1 = np.array([[200.  , 1250.],
                       [600.  , 250. ],
                       [1000. , 500. ]], dtype=np.float32)

affine_matrix = torch.tensor([[[ 0.9659, -0.2588,  0.0000],
                               [ 0.2588,  0.9659,  0.0000],
                               [ 0.0000,  0.0000,  1.0000]]])

grid = F.affine_grid(affine_matrix[:, :-1,:], size=imgt.shape, align_corners=True)
imgt2 = F.grid_sample(imgt.float(), grid, align_corners=True)

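# Work on a NumPy copy of the 3x3 matrix for the landmark transform
aff_mat = affine_matrix.squeeze(0).numpy().copy()
# Translation to the assumed rotation center, in (row, col) order: (H / 2, W / 2)
recentering = np.eye(3, dtype=np.float32)
recentering[:2, 2] = [imgt.shape[-2] / 2, imgt.shape[-1] / 2]

# Landmarks in homogeneous coordinates: (row, col, 1)
ones = np.ones((landmarks1.shape[0], 1), dtype=np.float32)
landmarks = np.concatenate((landmarks1, ones), axis=1)

# Conjugate with the recentering so the matrix acts about the assumed center
aff_mat = recentering @ aff_mat @ np.linalg.inv(recentering)
landmarks2 = (aff_mat @ landmarks.T).T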

plt.subplot(1,2,1)
plt.imshow(imgt.squeeze(0).permute(1,2,0).int().numpy())
plt.scatter(landmarks1[:,1], landmarks1[:,0], c='yellow')
plt.subplot(1,2,2)
plt.imshow(imgt2.squeeze(0).permute(1,2,0).int().numpy())
plt.scatter(landmarks2[:,1], landmarks2[:,0], c='red')

Which gives this output:

[Image: the original image with yellow landmarks (left) and the transformed image with red landmarks (right)]

As you can see, the dots are slightly off from where they should be. I can't figure out why.
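One possible way to map the landmarks, sketched under stated assumptions rather than as a confirmed fix: the matrix passed to affine_grid acts on normalized (x, y) coordinates and produces sampling locations, so it maps output positions to input positions, and landmarks would need its inverse. The sketch also assumes align_corners=True and landmarks stored as (row, col). The helpers `pixel_to_norm_matrix` and `transform_landmarks` are hypothetical names, not PyTorch functions.

```python
import numpy as np


def pixel_to_norm_matrix(height, width):
    """3x3 matrix taking pixel (x, y, 1) to normalized (x, y, 1), align_corners=True."""
    return np.array([[2.0 / (width - 1), 0.0, -1.0],
                     [0.0, 2.0 / (height - 1), -1.0],
                     [0.0, 0.0, 1.0]])


def transform_landmarks(landmarks_rc, theta, height, width):
    """Map (row, col) landmarks on the input image to the resampled image.

    theta is the full 3x3 matrix whose top two rows are given to affine_grid.
    """
    norm = pixel_to_norm_matrix(height, width)
    # theta maps output -> input in normalized space, so landmarks use its inverse.
    forward = np.linalg.inv(norm) @ np.linalg.inv(theta) @ norm
    # Reorder (row, col) to homogeneous (x, y, 1) before applying the matrix.
    xy1 = np.column_stack([landmarks_rc[:, 1],
                           landmarks_rc[:, 0],
                           np.ones(len(landmarks_rc))])
    out_xy1 = (forward @ xy1.T).T
    # Back to (row, col) order.
    return out_xy1[:, [1, 0]]


# Values taken from the question.
theta = np.array([[0.9659, -0.2588, 0.0],
                  [0.2588, 0.9659, 0.0],
                  [0.0, 0.0, 1.0]])
landmarks1 = np.array([[200.0, 1250.0],
                       [600.0, 250.0],
                       [1000.0, 500.0]])
landmarks2 = transform_landmarks(landmarks1, theta, height=1200, width=1600)
```

The result can be plotted the same way as before, with `landmarks2[:, 1]` as x and `landmarks2[:, 0]` as y. Working in float64 here is a choice made to keep the matrix inversions well conditioned; it is not required by PyTorch.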
