Template matching model for object orientation estimation converges fast with in-plane rotations only, but fails with full 3D orientations


Context

I'm experimenting with a model that should match a query image of a known object to the corresponding template image(s) showing the same orientation. (Because I will be dealing with symmetric objects and heavy occlusion, this relation is often one-to-many.)

The model takes an image pair as input (query image + candidate template image) and should output 0.0 if it thinks the objects do not have the same orientation and 1.0 if it thinks they do. (I train with L1 loss.)

I train the model on synthetic data in batches, where each query image contributes:

  • a positive case: the query image paired with its correct template image (expected classification 1.0),
  • a negative case: the same query image paired with a random template image (expected classification 0.0).

Problem

The strange thing is that the model trains and performs almost suspiciously well when the negative template is an in-plane rotation of the positive template (avg. classification of pos. cases ~0.99, neg. cases ~0.1). But when the negative template is a completely random template with an arbitrary 3D orientation, the model struggles a lot (avg. classification of pos. cases ~0.75, neg. cases ~0.5). This seems strange to me: the random negatives differ far more from the positives, so they should be easier to discriminate.

Code

Model:

import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights

class TemplateEvaluator(nn.Module):
    def __init__(self, q_encoder=resnet18(weights=ResNet18_Weights.IMAGENET1K_V1), t_encoder=resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)):
        super(TemplateEvaluator, self).__init__()
        self.q_encoder = q_encoder
        self.t_encoder = t_encoder
        
        self.fc = nn.Sequential(
            nn.Linear(2000, 1),
            nn.Sigmoid()
        )
    
    def forward(self, data):
        q = data[0]
        t = data[1]
        q = self.q_encoder(q)
        t = self.t_encoder(t)
        res = self.fc(torch.cat([q,t],-1))
        return res
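
For reference (not from the original post), a minimal smoke test of this interface, assuming 224×224 RGB inputs and a device variable defined as in the rest of the code: the forward pass expects the query and template batches stacked along a leading pair dimension.

model = TemplateEvaluator().to(device)
q_batch = torch.randn(8, 3, 224, 224, device=device)  # query images, NCHW
t_batch = torch.randn(8, 3, 224, 224, device=device)  # candidate templates, NCHW
pairs = torch.stack([q_batch, t_batch])                # shape [2, 8, 3, 224, 224]
scores = model(pairs)                                  # shape [8, 1], values in (0, 1)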

Training step:

  • cb_id contains the IDs of the associated correct templates (the template with the smallest angular difference for each query)
  • t_img_rand holds the negative templates
def template_eval_train_step(iteration, models, data, codebook, opts=None, show=False, metric_label=''):
    # Get query image, associated codebook template ID, and associated orientation
    q_img, cb_id, rot = data
    n = q_img.shape[0]
    t_eval = models[0]
    
    # Get random template IDs
    cb_id_rand = np.random.choice(codebook["size"],n)

    # Get associated and random template images
    t_img = torch.stack([cb_get_img(i,codebook) for i in cb_id]).to(device)

    # Uncomment to use random template as neg cases
    t_img_rand = torch.stack([cb_get_img(i,codebook) for i in cb_id_rand]).to(device)
    # Uncomment to use in-plane rotations of pos template as neg cases
    # t_img_rand = torch.stack([rotate_image_tensor(y,np.random.random()*360) for y in t_img])
    
    # Cases with similar template image ('Positive')
    p_cases = torch.stack([q_img.permute(0, 3, 1, 2),t_img.permute(0, 3, 1, 2)])

    # Cases with random template image ('Negative')
    n_cases = torch.stack([q_img.permute(0, 3, 1, 2),t_img_rand.permute(0, 3, 1, 2)])

    # Mix together for 50/50 distribution in batch
    mixed_cases = torch.concat([p_cases,n_cases], 1)

    # Run model
    c = t_eval(mixed_cases)

    # Get classification for pos and neg cases
    p_cls = c[:n]
    n_cls = c[n:]

    # Compute loss
    p_loss = F.l1_loss(p_cls, torch.ones_like(p_cls, requires_grad=True))
    n_loss = F.l1_loss(n_cls, torch.zeros_like(n_cls, requires_grad=True))
    loss = (p_loss + n_loss)/2
    
    # Visualise pos and neg case at i=0
    if show:
        i=0
        view([q_img[i].detach().cpu().numpy(), t_img[i].detach().cpu().numpy()])
        print("p_cls:",p_cls[i].detach().cpu().numpy())
        view([q_img[i].detach().cpu().numpy(), t_img_rand[i].detach().cpu().numpy()])
        print("n_cls:",n_cls[i].detach().cpu().numpy())

    # Run optimizer (if given)
    if opts is not None:
        opts[0].zero_grad()
        loss.backward()
        
        # Print gradient info
        if show:
            t_eval.cpu()
            plot_grad_flow(t_eval.named_parameters())
            t_eval.to(device)
        
        opts[0].step()

    # Compute eval metrics
    p_rate = p_cls.sum() / n
    n_rate = n_cls.sum() / n

    # Garbage collection 
    gc.collect()

    return [ {"label": metric_label, "name": "loss", "value":loss.cpu().item()},
             {"label": metric_label, "name": "p_rate", "value":p_rate.cpu().item()},
             {"label": metric_label, "name": "n_rate", "value":n_rate.cpu().item()}]

Train loop:

  • init_train, init_verify just put the models in train or eval mode
  • train_step is the previous function
def fit(epochs, models, init_train, init_verify, train_step, verify_step, opts, train_dl, verify_dl, eval_dl, codebook, vis_epoch_step=10):
    train_data = []
    verify_data = []
    eval_data = []

    for epoch in tqdm(range(epochs)):
        init_train(epoch, models)
        
        i = 0
        for data in train_dl:
            train_metrics = train_step(epoch, models, data, opts=opts, codebook=codebook, show=epoch % vis_epoch_step == 0 and i == 0)
            train_data.append(train_metrics)
            i = i + 1
            
            n = len(train_dl)
            p = round((i/n)*100)
            if p>0:
                sys.stdout.write('\r')
                bar_len = round(p/5)
                empty_len = round((100-p)/5)
                sys.stdout.write("Train batch %d/%d [%s%s] %d%%" % (i, n, '#'*bar_len, '_'*empty_len, p))
                sys.stdout.flush()
            
        # verification step
        init_verify(epoch, models)
        with torch.no_grad():
            
            i = 0
            for data in verify_dl:
                verify_metrics = verify_step(epoch, models, data, codebook=codebook, show=epoch % vis_epoch_step == 0 and i == 0)
                verify_data.append(verify_metrics)
                i = i + 1
            
                n = len(verify_dl)
                p = round((i/n)*100)
                if p>0:
                    sys.stdout.write('\r')
                    bar_len = round(p/5)
                    empty_len = round((100-p)/5)
                    sys.stdout.write("Verification batch %d/%d [%s%s] %d%%" % (i, n, '#'*bar_len, '_'*empty_len, p))
                    sys.stdout.flush()
...

Results with in-plane rotation of positive template as negative template

t_img_rand = torch.stack([rotate_image_tensor(y,np.random.random()*360) for y in t_img])

Training (p/n_rate is avg pos/neg case classification):

[training curves]

Example case:

[query image with its positive template]

p_cls: [0.998]

[query image with an in-plane rotated negative template]

n_cls: [0.000]

Results with random template as negative template (same model initialisation, optimizer and hyperparameters)

cb_id_rand = np.random.choice(codebook["size"],n)
t_img_rand = torch.stack([cb_get_img(i,codebook) for i in cb_id_rand]).to(device)

Training (p/n_rate is avg pos/neg case classification):

[training curves]

Example case:

[query image with its positive template]

p_cls: [0.001]

[query image with a random negative template]

n_cls: [0.998]

1 Answer


I think I found the problem: generating in-plane rotations of the positive templates creates templates that are not part of the initial discrete set of templates. The model apparently memorizes that set and simply outputs 0.0 whenever the given template does not seem to belong to it.

So, unfortunately, the 'good' performance was just a kind of overfitting; the 'bad' performance was the honest one.

Update: the real problem with my code was using a single-layer perceptron on top of the concatenated outputs of the two ResNet encoders. A single linear layer cannot compute equality of values (the XOR/XNOR problem). I solved this by feeding it the element-wise product of the ResNet outputs instead of their concatenation.
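
A minimal sketch of that change (an assumption of how it might look, reusing the pretrained ResNet-18 encoders from above; the 1000-dimensional element-wise product replaces the 2000-dimensional concatenation):

class TemplateEvaluator(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_encoder = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.t_encoder = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.fc = nn.Sequential(
            nn.Linear(1000, 1),  # ResNet-18 outputs 1000 features per image
            nn.Sigmoid()
        )

    def forward(self, data):
        q = self.q_encoder(data[0])
        t = self.t_encoder(data[1])
        # Element-wise product: a single linear layer can now measure
        # feature-wise agreement between query and template
        return self.fc(q * t)

Since the input and output shapes stay the same, the rest of the training code should not need to change for this variant.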