Translating a Keras model to PyTorch (mat1 and mat2 shapes cannot be multiplied)

I built a model in Keras and now need to translate it into an equivalent PyTorch model.

from tensorflow import keras

# Keras uses channels-last input: (height, width, channels)
input_shape = (200, 100, 1)

def make_model():
    return keras.models.Sequential([
        keras.layers.Conv2D(32, kernel_size=(2, 2), activation='relu', input_shape=input_shape),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Dropout(0.2),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(3, activation='sigmoid'),
    ])

model = make_model()
model.summary()

Model: "sequential_164"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d_16 (Conv2D)          (None, 199, 99, 32)       160       
                                                                 
 max_pooling2d_16 (MaxPooli  (None, 99, 49, 32)        0         
 ng2D)                                                           
                                                                 
 dropout_66 (Dropout)        (None, 99, 49, 32)        0         
                                                                 
 flatten_23 (Flatten)        (None, 155232)            0         
                                                                 
 dense_322 (Dense)           (None, 128)               19869824  
                                                                 
 dense_323 (Dense)           (None, 51)                6579      
                                                                 
=================================================================
Total params: 19876563 (75.82 MB)
Trainable params: 19876563 (75.82 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
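The 155232 in the summary (and in fc1 below) follows from unpadded ("valid") 2x2 convolution and 2x2 max pooling with stride equal to the pool size, which is the default in both Keras and PyTorch. A small sketch of that arithmetic in plain Python; the helper names conv_out and pool_out are made up for illustration:

```python
def conv_out(size, kernel, stride=1):
    """Output length of an unpadded ('valid') convolution along one axis."""
    return (size - kernel) // stride + 1

def pool_out(size, pool):
    """Output length of max pooling with stride == pool size and no padding."""
    return (size - pool) // pool + 1

height, width, channels = 200, 100, 1
filters = 32

h = pool_out(conv_out(height, 2), 2)  # 200 -> 199 -> 99
w = pool_out(conv_out(width, 2), 2)   # 100 -> 99 -> 49
flat = h * w * filters                # 99 * 49 * 32

conv_params = 2 * 2 * channels * filters + filters  # weights + biases of Conv2D
dense_params = flat * 128 + 128                     # weights + biases of Dense(128)

print(h, w, flat)                 # 99 49 155232
print(conv_params, dense_params)  # 160 19869824
```

Both parameter counts agree with the conv2d and first dense rows of the summary, so the PyTorch fc1 = nn.Linear(155232, 128) is sized correctly for a single sample.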

This is my PyTorch version:

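import torch
import torch.nn as nn
import torch.nn.functional as F

class MyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # PyTorch expects channels-first input: (batch, 1, 200, 100)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(2, 2))  # -> (batch, 32, 199, 99)
        self.pool = nn.MaxPool2d(2)                         # -> (batch, 32, 99, 49)
        self.dropout = nn.Dropout(0.2)
        self.fc1 = nn.Linear(155232, 128)                   # 32 * 99 * 49 = 155232
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 3)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = self.dropout(x)
        x = torch.flatten(x, start_dim=1)  # keep the batch dimension: (batch, 155232)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))     # matches Keras Dense(3, activation='sigmoid')
        return x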

The Keras model trains fine, but in PyTorch I get this error:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x4851 and 155232x128)

This error confuses me, because I use the same layer settings as in the Keras model.
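One way the exact numbers 32x4851 can arise (an assumption, since the training loop is not shown): the model receives a single image shaped (1, 200, 100) with no batch dimension. Recent PyTorch versions let Conv2d and MaxPool2d treat a 3-D tensor as an unbatched (C, H, W) input, so the result is (32, 99, 49); flatten(start_dim=1) then keeps the 32 channels as rows and leaves 99 * 49 = 4851 features each. The sketch below rebuilds only the layers before fc1; the variable names are illustrative, and the shapes fed in are hypothetical examples:

```python
import torch
import torch.nn as nn

# Only the layers of MyCNN that run before fc1 (the Linear layers are omitted).
features = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=(2, 2)),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Dropout(0.2),
)

def shape_before_fc1(x):
    # Same flatten call as in MyCNN.forward
    return tuple(torch.flatten(features(x), start_dim=1).shape)

with torch.no_grad():
    unbatched = torch.zeros(1, 200, 100)           # (C, H, W): no batch dimension
    print(shape_before_fc1(unbatched))             # (32, 4851)

    batched = unbatched.unsqueeze(0)               # (N, C, H, W) with N = 1
    print(shape_before_fc1(batched))               # (1, 155232)

    keras_layout = torch.zeros(8, 200, 100, 1)     # Keras-style channels-last (N, H, W, C)
    channels_first = keras_layout.permute(0, 3, 1, 2)  # -> (N, C, H, W)
    print(shape_before_fc1(channels_first))        # (8, 155232)
```

With a batch dimension, and data permuted to channels-first if it was prepared for Keras, the flattened size matches fc1's 155232 inputs.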

Answer by Fauna Muirgen:

Use x = torch.flatten(x) instead of x = torch.flatten(x, start_dim=1).
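A sketch of why this helps and where it stops helping, assuming the pooled feature maps have the shapes shown in the comments (the tensors are placeholders, not real data). torch.flatten(x) with no start_dim collapses every dimension, so one unbatched sample becomes a 1-D vector of 155232 values, which nn.Linear(155232, 128) accepts. With a real batch, the same call would also merge the samples together:

```python
import torch

# Placeholder feature maps as they leave the pooling layer in MyCNN
single = torch.zeros(32, 99, 49)     # one unbatched sample: (C, H, W)
batch = torch.zeros(4, 32, 99, 49)   # a batch of 4 samples: (N, C, H, W)

print(torch.flatten(single).shape)              # torch.Size([155232]): fits fc1
print(torch.flatten(batch).shape)               # torch.Size([620928]): samples merged
print(torch.flatten(batch, start_dim=1).shape)  # torch.Size([4, 155232]): one row per sample
```

So flatten(x) works for one image at a time; for batched training, keeping start_dim=1 and feeding input shaped (N, 1, 200, 100) preserves one output row per sample.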