PyTorch input size mismatch is still processed during the forward pass


I don't understand why an input whose size does not match a PyTorch linear layer can still be processed during the forward pass.

I tried my autoencoder (AE) model with PyTorch; the model is shown below. Why does a mismatched input size still go through the forward pass?

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self,input_dim,latent_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim,64)   # expects input_dim features (13 here)
        self.linear2 = nn.Linear(64,16)
        self.linear3 = nn.Linear(16,latent_dim)
        self.relu = nn.ReLU()
    def forward(self,x):
        out = self.linear1(x)
        out = self.relu(out)
        out = self.linear2(out)
        out = self.relu(out)
        latent = self.linear3(out)
        return latent

    
class Decoder(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super().__init__()
        self.linear1 = nn.Linear(latent_dim,16)
        self.linear2 = nn.Linear(16,64)
        self.linear3 = nn.Linear(64,output_dim)
        self.relu    = nn.ReLU()
    def forward(self,x):
        out = self.linear1(x)
        out = self.relu(out)
        out = self.linear2(out)
        out = self.relu(out)
        result = self.linear3(out)
        return result


class AutoEncoder(nn.Module):
    def __init__(self,input_dim, latent_dim, output_dim):
        super().__init__()
        self.encoder = Encoder(input_dim,latent_dim)
        self.decoder = Decoder(latent_dim,output_dim)
    
    def forward(self,x):
        enc = self.encoder(x)
        dec = self.decoder(enc)
        return dec


model_AE = AutoEncoder(input_dim = 13, latent_dim = 8, output_dim = 13).to(device)
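
As a quick sanity check (not part of my original script), the number of input features the model expects can be read directly off the first encoder layer:

# nn.Linear stores its weight as (out_features, in_features),
# so the first encoder layer should only accept 13-feature inputs.
print(model_AE.encoder.linear1.weight.shape)  # torch.Size([64, 13])
print(model_AE.encoder.linear1.in_features)   # 13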

Then I tried both of these inputs:

1) Test 1:

  • variable input name: dct
  • shape: torch.Size([254, 13])
  • code:
for _, (rxt, ryt, dot, dct, mmt) in tqdm(enumerate(dataloader_test)):
    out = model_AE(dct)

2) Test 2:

  • variable input name: x_cat
  • shape: torch.Size([254, 26])
  • code:
for _, (rxt, ryt, dot, dct, mmt) in tqdm(enumerate(dataloader_test)):
    x_cat = torch.cat(dim = 1, tensors = [dct, mmt])
    out = model_AE(x_cat)

Test 1 is supposed to work, because the matrix multiplications line up ([254, 13] × [13, 64], ...). But Test 2 also worked, and I wonder how that is possible, since Test 2 uses 26 features, which should not match the first linear layer (13, 64).
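
For context, here is a minimal standalone check of what I expected to happen with a mismatched input (an illustrative sketch, not code from my actual pipeline):

import torch
import torch.nn as nn

layer = nn.Linear(13, 64)
ok = layer(torch.randn(254, 13))      # (254, 13) @ (13, 64) -> (254, 64), fine

try:
    layer(torch.randn(254, 26))       # 26 features vs. in_features = 13
except RuntimeError as err:
    print("mismatch rejected:", err)  # this is what I expected for Test 2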

Thank you for your help! :)


Update

I found the source of the problem by running the following script in two different PyTorch environments:

# same Encoder, Decoder, and AutoEncoder definitions as above
model_AE = AutoEncoder(input_dim = 13, latent_dim = 8, output_dim = 13).to(device)

inputxx = torch.randn(254, 26).to(device)
out = model_AE(inputxx)

I ran this on:

  • torch version: 1.9.0+cu111 (runs without error --> incorrect behavior)
  • torch version: 2.1.2+cu121 (raises an error for the mismatched input --> correct behavior)

I think this is related to a bug in the PyTorch 1.9.0+cu111 build.
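
As a workaround until upgrading, an explicit shape check in the forward pass should catch this on any version (a sketch of one possible guard, using a hypothetical SafeEncoder subclass, not something PyTorch provides):

class SafeEncoder(Encoder):
    # Fail loudly if the feature dimension is wrong, even on builds where
    # the mismatched matmul silently goes through.
    def forward(self, x):
        expected = self.linear1.in_features
        if x.shape[-1] != expected:
            raise ValueError(f"expected {expected} input features, got {x.shape[-1]}")
        return super().forward(x)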
