I'm training a model that includes batch normalization layers, and I noticed that the accuracy can fluctuate wildly (from 55% to 31% in a single epoch). Both train accuracy and test accuracy fluctuate, so I don't think this is caused by overfitting.
This is my accuracy over epochs:

This is my model architecture:
# Builder excerpt; `momentum` and `device` come from the enclosing scope.
return nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64, momentum=momentum),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    Residual(64, 64),
    Residual(64, 64),
    Residual(64, 128, use_1x1=True, stride=2),
    Residual(128, 128),
    Residual(128, 256, use_1x1=True, stride=2),
    Residual(256, 256),
    Residual(256, 512, use_1x1=True, stride=2),
    Residual(512, 512),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(512, 176),  # 176 output classes
).to(device)
class Residual(nn.Module):
    # `momentum` is a module-level hyperparameter shared by all BN layers.
    def __init__(self, input_channel, output_channel, use_1x1=False, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channel, output_channel, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(output_channel, output_channel, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(output_channel, momentum=momentum)
        self.bn2 = nn.BatchNorm2d(output_channel, momentum=momentum)
        if use_1x1:
            # 1x1 convolution to match channels/resolution on the skip path
            self.conv3 = nn.Conv2d(input_channel, output_channel, kernel_size=1, stride=stride)
        else:
            self.conv3 = None

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3 is not None:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)
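(As a quick shape sanity check of the block, with an assumed value for the global `momentum`:)

import torch
import torch.nn as nn
import torch.nn.functional as F

momentum = 0.1  # assumed value for the global hyperparameter

# Downsampling block: channels 64 -> 128, spatial size halved by stride=2.
blk = Residual(64, 128, use_1x1=True, stride=2)
print(blk(torch.randn(1, 64, 32, 32)).shape)  # -> torch.Size([1, 128, 16, 16])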
But strangely, if I don't call model.eval() in the accuracy evaluation function, so that running_mean and running_var keep updating during evaluation, the accuracy doesn't fluctuate.
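To spell out the behavior I mean, here is a minimal sketch (layer size and numbers are illustrative only): in train mode a BatchNorm2d forward pass updates running_mean/running_var with an exponential moving average controlled by momentum, while in eval mode it reads those buffers without touching them.

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, momentum=0.1)
x = torch.randn(8, 3, 4, 4)

bn.train()
bn(x)  # running_mean <- (1 - momentum) * running_mean + momentum * batch_mean
print(bn.running_mean)  # moved toward this batch's mean

bn.eval()
bn(x)  # uses the stored running_mean/running_var as-is
print(bn.running_mean)  # unchanged by the eval-mode forward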
Furthermore, if I run through the training set once more after each epoch, as shown in the code below,
for epoch in range(epochs):
    net.train()
    for X, y in train_iter:
        optimizer.zero_grad()
        y_hat = net(X)
        l = loss(y_hat, y)
        l.backward()
        optimizer.step()
    # Extra pass: forward only, still in train mode, so the BN running
    # statistics keep updating against the post-epoch weights.
    for X, y in train_iter:
        net(X)
    eval_accuracy()
the accuracy doesn't fluctuate either.
I've tried changing the momentum, but it doesn't help. Now I'm totally confused: I have no idea why the accuracy fluctuates, or why the workaround above fixes it.
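Incidentally, the extra pass looks a lot like what PyTorch's built-in BN-refresh helper from the stochastic weight averaging utilities does; here is a sketch of using it instead of the manual loop, assuming train_iter and net from above:

# Sketch: refresh the BN buffers with torch.optim.swa_utils.update_bn instead
# of the manual extra pass; it resets running_mean/running_var and recomputes
# them over the whole loader before evaluation.
from torch.optim.swa_utils import update_bn

update_bn(train_iter, net)  # uses the first element of each (X, y) batch
net.eval()
eval_accuracy()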