How to implement BatchNorm2d in PyTorch myself?


I'm trying to implement a BatchNorm2d() layer with:

import torch
import torch.nn as nn
from torch.nn import Parameter

class BatchNorm2d(nn.Module):

    def __init__(self, num_features):
        super(BatchNorm2d, self).__init__()
        self.num_features = num_features
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.eps = 1e-5
        self.momentum = 0.1
        self.first_run = True

    def forward(self, input):
        # input: [batch_size, num_feature_map, height, width]
        device = input.device
        if self.training:
            mean = torch.mean(input, dim=0, keepdim=True).to(device)  # [1, num_feature, height, width]
            var = torch.var(input, dim=0, unbiased=False, keepdim=True).to(device)  # [1, num_feature, height, width]
            if self.first_run:
                self.weight = Parameter(torch.randn(input.shape, dtype=torch.float32, device=device), requires_grad=True)
                self.bias = Parameter(torch.randn(input.shape, dtype=torch.float32, device=device), requires_grad=True)
                self.register_buffer('running_mean', torch.zeros(input.shape).to(input.device))
                self.register_buffer('running_var', torch.ones(input.shape).to(input.device))
                self.first_run = False
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
            bn_init = (input - mean) / torch.sqrt(var + self.eps)
        else:
            bn_init = (input - self.running_mean) / torch.sqrt(self.running_var + self.eps)
        return self.weight * bn_init + self.bias

But after training and testing I found that the results from my layer are not comparable with the results from nn.BatchNorm2d(). There must be something wrong with it, and I guess the problem is related to initializing the parameters in forward(). I did that because I don't know how to get the shape of the input in __init__(); maybe there is a better way. I don't know how to fix it, please help. Thanks!!


There are 3 answers

Answered by lemonramen3 (best answer)

Got the answer from HERE!
So the shape of weight (and bias) should be (1, num_features, 1, 1), not (1, num_features, height, width).
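
For completeness, here is a minimal sketch of the layer with the parameters and running statistics created in __init__ using that (1, num_features, 1, 1) shape, and with the batch statistics reduced over the batch and spatial dimensions. This is just an illustration of the shape fix, not an exact reproduction of nn.BatchNorm2d:

import torch
import torch.nn as nn

class MyBatchNorm2d(nn.Module):

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        # learnable per-channel scale and shift
        self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        # running statistics, also one value per channel
        self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))
        self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))

    def forward(self, input):
        # input: [batch_size, num_features, height, width]
        if self.training:
            # reduce over batch and spatial dims -> shape (1, num_features, 1, 1)
            mean = input.mean(dim=(0, 2, 3), keepdim=True)
            var = input.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var
        x_hat = (input - mean) / torch.sqrt(var + self.eps)
        return self.weight * x_hat + self.bias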

Answered by Mircea

The video from Andrej Karpathy has a very intuitive explanation.

Here is a code snippet with the 1D implementation, from the notebook associated with the video:

import torch

class BatchNorm1d:
  
  def __init__(self, dim, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.momentum = momentum
    self.training = True
    # parameters (trained with backprop)
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)
    # buffers (trained with a running 'momentum update')
    self.running_mean = torch.zeros(dim)
    self.running_var = torch.ones(dim)
  
  def __call__(self, x):
    # calculate the forward pass
    if self.training:
      xmean = x.mean(0, keepdim=True) # batch mean
      xvar = x.var(0, keepdim=True) # batch variance
    else:
      xmean = self.running_mean
      xvar = self.running_var
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
    self.out = self.gamma * xhat + self.beta
    # update the buffers
    if self.training:
      with torch.no_grad():
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
    return self.out
  
  def parameters(self):
    return [self.gamma, self.beta]

The PyTorch implementation is in C++.

However, the implementation and explanation from the Dive into Deep Learning website, as mentioned in the accepted answer, might help you understand the difference between the 1D and 2D cases.
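
Roughly, the difference comes down to which dimensions the statistics are reduced over; a small sketch (the tensors here are just placeholders):

import torch

x1d = torch.randn(32, 100)          # 1D input: (batch, features)
x2d = torch.randn(32, 64, 28, 28)   # 2D input: (batch, channels, height, width)

# 1D: reduce over the batch dimension only -> per-feature stats of shape (1, features)
mean_1d = x1d.mean(dim=0, keepdim=True)

# 2D: reduce over batch and spatial dimensions -> per-channel stats of shape (1, channels, 1, 1)
mean_2d = x2d.mean(dim=(0, 2, 3), keepdim=True)

# the learnable gamma/beta follow the same shapes: (features,) in 1D, (1, channels, 1, 1) in 2D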

Answered by Harman Singh

In case someone stumbles on this: you don't actually have to set the device inside the model as done above. Outside the model, you can just do

device = torch.device('cuda:0')
model = model.to(device)

I'm not sure if this is better than manually setting devices for weights and biases inside the module, but it's definitely more standard I think.
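
The reason this is enough is that model.to(device) moves every parameter and buffer that was registered in __init__. A rough usage sketch (the model and shapes here are only an example):

import torch
import torch.nn as nn

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),  # or a custom BatchNorm2d whose parameters live in __init__
    nn.ReLU(),
)
model = model.to(device)  # moves all registered parameters and buffers

x = torch.randn(8, 3, 32, 32).to(device)  # inputs still have to be moved explicitly
out = model(x)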