Why does the accuracy fluctuate widely after using batch normalization?

I'm training a model that includes a batch normalization layer, but I noticed that the accuracy can fluctuate widely (from 55% to 31% in just one epoch). Both the train accuracy and the test accuracy fluctuate, so I don't think it's caused by overfitting.

This is my accuracy over epochs: [accuracy-per-epoch plot omitted]

This is the joint graph: [plot omitted]

This is my model architecture:

    import torch
    from torch import nn
    from torch.nn import functional as F

    momentum = 0.1  # BatchNorm momentum (the hyperparameter I tried varying)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    class Residual(nn.Module):
        def __init__(self, input_channel, output_channel, use_1x1=False, stride=1):
            super().__init__()
            self.conv1 = nn.Conv2d(input_channel, output_channel, kernel_size=3, padding=1, stride=stride)
            self.conv2 = nn.Conv2d(output_channel, output_channel, kernel_size=3, padding=1)
            self.bn1 = nn.BatchNorm2d(output_channel, momentum=momentum)
            self.bn2 = nn.BatchNorm2d(output_channel, momentum=momentum)

            # 1x1 convolution on the shortcut when the shape changes
            if use_1x1:
                self.conv3 = nn.Conv2d(input_channel, output_channel, kernel_size=1, stride=stride)
            else:
                self.conv3 = None

        def forward(self, X):
            Y = F.relu(self.bn1(self.conv1(X)))
            Y = self.bn2(self.conv2(Y))
            if self.conv3 is not None:
                X = self.conv3(X)
            Y += X
            return F.relu(Y)

    def make_net():
        # ResNet-18-style network for 176 classes
        return nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64, momentum=momentum),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            Residual(64, 64),
            Residual(64, 64),
            Residual(64, 128, use_1x1=True, stride=2),
            Residual(128, 128),
            Residual(128, 256, use_1x1=True, stride=2),
            Residual(256, 256),
            Residual(256, 512, use_1x1=True, stride=2),
            Residual(512, 512),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, 176)
        ).to(device)
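For reference, I'm relying on PyTorch's convention for the momentum argument, where it weights the *new* batch statistic rather than the old running one. A minimal standalone sketch of that behavior (toy shapes, default settings assumed):

    import torch
    from torch import nn

    # PyTorch convention:
    #   running_stat = (1 - momentum) * running_stat + momentum * batch_stat
    bn = nn.BatchNorm2d(3, momentum=0.1)

    bn.train()                        # train mode: normalize with batch stats
    x = torch.randn(8, 3, 4, 4) * 5 + 2
    _ = bn(x)
    print(bn.running_mean)            # moved 10% of the way toward the batch mean

    bn.eval()                         # eval mode: normalize with running stats
    _ = bn(x)
    print(bn.running_mean)            # unchanged: eval() freezes the buffers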

But, magically, if I don't call model.eval() in the accuracy-evaluation function, which keeps running_mean and running_var updating during evaluation, the accuracy doesn't fluctuate. Furthermore, if I do an extra forward-only pass over the training set after each epoch, as shown in the code below,

    for epoch in range(epochs):
        net.train()
        for X, y in train_iter:
            optimizer.zero_grad()
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()

        # Extra forward-only pass: no backprop, but every BatchNorm layer
        # still updates its running_mean/running_var in train mode.
        for X, y in train_iter:
            net(X)

        eval_accuracy()

the accuracy doesn't fluctuate either.
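If it helps, that extra pass can be factored into a helper that only refreshes the BatchNorm buffers (the helper name is mine; as far as I know, momentum=None switches PyTorch's BatchNorm to a cumulative moving average):

    import torch
    from torch import nn

    @torch.no_grad()                   # forward only, no gradients needed
    def refresh_bn_stats(net, data_iter):
        """Re-estimate BatchNorm running stats over the training data."""
        for m in net.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.reset_running_stats()
                m.momentum = None      # None => cumulative moving average
        net.train()                    # buffers only update in train mode
        for X, _ in data_iter:
            net(X)

(I'd restore the original momentum afterwards if training continues.)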

I've tried changing the momentum, but that doesn't help. Now I'm totally confused: I have no idea why the accuracy fluctuates, or why the methods above make it stop.
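In case it's useful, this is the kind of check I could run to see how far the running statistics sit from the actual batch statistics (a hook-based sketch; every name in it is mine):

    import torch
    from torch import nn

    @torch.no_grad()
    def bn_stat_gaps(net, X):
        """Max per-channel gap between running_mean and the batch mean."""
        gaps = {}
        def make_pre_hook(name):
            def pre_hook(module, inputs):
                x = inputs[0]                       # (N, C, H, W)
                batch_mean = x.mean(dim=(0, 2, 3))  # per-channel batch mean
                gaps[name] = (batch_mean - module.running_mean).abs().max().item()
            return pre_hook
        handles = [m.register_forward_pre_hook(make_pre_hook(n))
                   for n, m in net.named_modules()
                   if isinstance(m, nn.BatchNorm2d)]
        net.eval()                                  # don't disturb the buffers
        net(X)
        for h in handles:
            h.remove()
        return gaps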
