Issue with ArcFace (0 accuracy)

856 views Asked by At

Hello guys I've joined a university-level image recognition competition.

  • In the test, I am given two face images, and my model needs to decide whether the pair shows the same person or not.

  • My model is ResNet-18 with IR and SE blocks, and it is trained with the ArcFace loss.

  • I can use only the MS1M dataset with a total of 86876 classes

The problem is that loss is getting better, but accuracy is 0 and not changing.

Here's part of the code I'm working on.

Train

def train_model(model, net, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and the classification head ``net``, checkpointing every epoch.

    Batches are read from the module-level ``dataloader`` and moved to the
    module-level ``device``; progress is shown through ``notebook.tqdm``.

    Args:
        model: feature extractor; called as ``model(inputs)``.
        net: head called as ``net(features, labels)``; its output is passed
            to ``criterion`` and used for the accuracy estimate.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once per epoch.
        num_epochs: number of passes over ``dataloader``.

    Returns:
        ``model`` with the weights of the epoch that reached the highest
        training accuracy loaded. Those weights are also written to
        ``model_save.pt``.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train']:
            if phase == 'train':
                model.train()  # Set model to training mode

            running_loss = 0.0
            running_corrects = 0
            num_samples = 0

            # Iterate over data.
            for inputs, labels in notebook.tqdm(dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device).long()
                batch_size = inputs.size(0)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    features = model(inputs)
                    outputs = net(features, labels)
                    # NOTE(review): if `net` applies a margin to the target
                    # logit (as an ArcFace head does), this argmax is taken
                    # over penalized logits and under-reports accuracy.
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics: accumulate per-sample sums as plain Python numbers
                running_loss += loss.item() * batch_size
                running_corrects += torch.sum(preds == labels).item()
                num_samples += batch_size
            if phase == 'train':
                scheduler.step()

            # The sums above are per sample, so normalize by the number of
            # samples seen, not by len(dataloader) (the number of batches).
            # max(..., 1) keeps an empty loader from dividing by zero.
            denom = max(num_samples, 1)
            epoch_loss = running_loss / denom
            epoch_acc = running_corrects / denom

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'train' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        checkpoint_path = f'/content/drive/MyDrive/inha_data/training_saver/training_stat{epoch}.pth'
        # The key 'mode_state_dict' (sic) is kept unchanged so that existing
        # checkpoints and any loading code stay compatible.
        torch.save({'epoch': epoch,
                    'mode_state_dict': model.state_dict(),
                    'fc_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    }, checkpoint_path)
        # Report the file that was actually written.
        print(f'finished {epoch} and saved {checkpoint_path}')
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best train Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    torch.save(model.state_dict(), 'model_save.pt')
    return model

Parameters

# Training data: shuffled batches of 128 samples loaded by 4 worker processes.
train_dataset = MS1MDataset('train')
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True,num_workers=4)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # device setup: first CUDA GPU if available, otherwise CPU

num_classes = 86876 
# normal classifier
# net = nn.Sequential(nn.Linear(512, num_classes))
# Feature extractor backbone, input is 112x112 image output is 512 feature vector
model_ft = resnet18(True)
# Margin-based classification head: 512-dim features -> num_classes logits
metric_fc = metrics.ArcMarginProduct(512, num_classes, s = 30.0, m = 0.50, easy_margin = False)
metric_fc.to(device)
# net = net.to(device)
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()
# Both the backbone and the margin head are optimized, as separate parameter groups.
# NOTE(review): lr=0.1 is 100x Adam's documented default of 1e-3 and is a
# likely cause of unstable training - confirm or lower it.
optimizer_ft = torch.optim.Adam([{'params': model_ft.parameters()}, {'params': metric_fc.parameters()}],
                                     lr=0.1)
# Decay LR by a factor of 0.1 every 4 epochs (step_size=4, gamma=0.1)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=4, gamma=0.1)

Arcface

from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math


class ArcMarginProduct(nn.Module):
    r"""Implement of large margin arc distance (additive angular margin head).

    The cosine between the L2-normalized input and every L2-normalized row
    of ``weight`` is computed; for the target class of each sample the
    cosine is replaced by ``cos(theta + m)``, and all logits are scaled by
    ``s``.

        Args:
            in_features: size of each input sample
            out_features: size of each output sample (one logit per class)
            s: scale factor applied to the output logits
            m: margin added to the target-class angle
            easy_margin: if True, the margin is applied only where
                ``cosine > 0``; otherwise ``cosine - sin(pi - m) * m`` is
                used wherever ``cosine <= cos(pi - m)``
            cos(theta + m)
        """
    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        # One weight row per class; rows are normalized in forward().
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        # Constants for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Below this cosine, theta + m would exceed pi, so the fallback
        # `cosine - mm` is used instead of cos(theta + m).
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """Return scaled logits with the margin applied to each target class.

        Args:
            input: feature batch of shape ``(batch, in_features)``.
            label: integer class index per sample, ``batch`` elements.

        Returns:
            Tensor of shape ``(batch, out_features)`` on the same device as
            ``input``.
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Clamp guards against tiny negative values from rounding before sqrt.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        # Allocate on the same device as the logits instead of a hard-coded
        # 'cuda', so the layer also works on CPU and on non-default GPUs.
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # Use phi for the target class and the plain cosine everywhere else.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

dataset

# Preprocessing pipelines keyed by split name; only 'train' is defined here.
data_transforms = {
    'train': transforms.Compose([
        # Light augmentation: random mirroring and mild colour jitter.
        transforms.RandomHorizontalFlip(), 
        transforms.ColorJitter(brightness=0.125, contrast=0.125, saturation=0.125), 
        transforms.ToTensor(), 
        # Per-channel mean/std for a 3-channel image (these match the
        # widely used ImageNet statistics - confirm they suit this data).
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
#train_ms1_data = torchvision.datasets.ImageFolder('/content/drive/MyDrive/inha_data/train', transform = data_transforms)

class MS1MDataset(Dataset):
    """Image/label dataset built from a whitespace-separated ID list file.

    Each usable line of the list file holds ``<label> <relative image path>``.
    Image paths are resolved under ``/content/``.
    """

    def __init__(self, split):
        """Read the list file and collect image paths and integer labels.

        Args:
            split: accepted for interface compatibility; it is not used, and
                the 'train' transform is always applied.
        """
        self.file_list = '/content/drive/MyDrive/inha_data/ID_List.txt'
        self.images = []
        self.labels = []
        self.transformer = data_transforms['train']
        with open(self.file_list) as f:
            files = f.read().splitlines()
        for line in files:
            fields = line.split()
            if len(fields) < 2:
                # Skip blank or incomplete lines (e.g. a trailing empty
                # line) instead of failing with IndexError.
                continue
            self.images.append("/content/" + fields[1])
            self.labels.append(int(fields[0]))

    def __getitem__(self, index):
        """Return ``(transformed image, label)`` for the sample at ``index``."""
        # The context manager closes the file handle deterministically.
        # convert('RGB') guarantees three channels, which the 3-channel
        # normalization in the transform requires (grayscale or RGBA files
        # would otherwise fail).
        with Image.open(self.images[index]) as raw:
            img = raw.convert('RGB')
        img = self.transformer(img)

        label = self.labels[index]
        return img, label

    def __len__(self):
        """Return the number of samples listed in the ID file."""
        return len(self.images)
2

There are 2 answers

0
user251386 On BEST ANSWER

You can try using a smaller m in ArcFace, even a negative value.

0
H.asadi On

When I encountered the same problem, I resolved it by starting training with a small number of classes and batches and gradually increasing the dataset.