PyTorch Geometric: RuntimeError: mat1 dim 1 must match mat2 dim 0

125 views Asked by At

I'm new to PyTorch Geometric, and when I run my model I get this error:

RuntimeError: mat1 dim 1 must match mat2 dim 0

The error occurs while running the code below; it is raised at the line `z = model.encode(x, train_pos_edge_index)`.

def train():
    """Run one optimisation step on the model and return the loss as a float.

    Relies on the module-level ``model``, ``optimizer``, ``x``,
    ``train_pos_edge_index`` and ``data`` objects.
    """
    model.train()
    optimizer.zero_grad()
    # Encode the node features; this is the call that raises the reported error.
    z = model.encode(x, train_pos_edge_index)
    loss = model.recon_loss(z, train_pos_edge_index)

    # Add the KL term, scaled by 1 / number of nodes.
    loss = loss + (1 / data.num_nodes) * model.kl_loss()  # new line
    loss.backward()
    optimizer.step()
    return float(loss)

The class through which I generate my model is the following:

class VariationalGCNEncoder(torch.nn.Module):
    """Two-layer GCN encoder that returns a (mu, logstd) pair.

    A shared first layer maps ``in_channels`` to ``2 * out_channels``; two
    separate layers then map that to ``out_channels`` each.
    """

    def __init__(self, in_channels, out_channels):
        """Build the three GCNConv layers.

        Args:
            in_channels: Input size of the first layer. ``forward`` passes
                ``x`` straight into that layer, so this presumably has to
                equal ``x.shape[1]`` -- confirm against the data.
            out_channels: Output size of the mu and logstd layers.
        """
        super(VariationalGCNEncoder, self).__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=True) # cached only for transductive learning
        self.conv_mu = GCNConv(2 * out_channels, out_channels, cached=True)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels, cached=True)

    def forward(self, x, edge_index):
        """Return ``(conv_mu(h), conv_logstd(h))`` with ``h = relu(conv1(x))``."""
        x = self.conv1(x, edge_index).relu()
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

Using VGAE from torch_geometric.nn, I create the model as follows:

out_channels = 2 #dimension of the embedding
# NOTE(review): num_features becomes the in_channels of the encoder's first
# GCNConv, which computes `x @ self.weight`. x is reported as [132, 132] in
# this post, so 128 here looks like the source of
# "mat1 dim 1 must match mat2 dim 0" -- in_channels probably needs to equal
# x's second dimension (132, e.g. x.size(1)); confirm against the data.
num_features = 128 


# Wrap the encoder in a VGAE model.
model = VGAE(VariationalGCNEncoder(num_features, out_channels))

The dimensionality of the variables is the following:

test_neg_edge_index=[2, 68],
test_pos_edge_index=[2, 68],
train_neg_adj_mask=[132, 132],
train_pos_edge_index=[2, 1166],
val_neg_edge_index=[2, 34],
val_pos_edge_index=[2, 34],
x=[132, 132]

And the dimensionality of the layers in the model is the following:

VGAE(
  (encoder): VariationalGCNEncoder(
    (conv1): GCNConv(128, 4)
    (conv_mu): GCNConv(4, 2)
    (conv_logstd): GCNConv(4, 2)
  )
  (decoder): InnerProductDecoder()
)

The full error traceback is the following:

Traceback (most recent call last)
<ipython-input-20-af87b3233297> in <module>
      2 
      3 for epoch in range(1, epochs + 1):
----> 4     loss = train()
      5     auc, ap = test(data.test_pos_edge_index, data.test_neg_edge_index)
      6     print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

<ipython-input-19-e42a4a22847f> in train()
      2     model.train()
      3     optimizer.zero_grad()
----> 4     z = model.encode(x, train_pos_edge_index)
      5     loss = model.recon_loss(z, train_pos_edge_index)
      6 

~\anaconda3\lib\site-packages\torch_geometric\nn\models\autoencoder.py in encode(self, *args, **kwargs)
    153     def encode(self, *args, **kwargs):
    154         """"""
--> 155         self.__mu__, self.__logstd__ = self.encoder(*args, **kwargs)
    156         self.__logstd__ = self.__logstd__.clamp(max=MAX_LOGSTD)
    157         z = self.reparametrize(self.__mu__, self.__logstd__)

~\anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

<ipython-input-16-d5c82de35dd4> in forward(self, x, edge_index)
      7 
      8     def forward(self, x, edge_index):
----> 9         x = self.conv1(x, edge_index).relu()
     10         return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

~\anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~\anaconda3\lib\site-packages\torch_geometric\nn\conv\gcn_conv.py in forward(self, x, edge_index, edge_weight)
    177                     edge_index = cache
    178 
--> 179         x = x @ self.weight
    180 
    181         # propagate_type: (x: Tensor, edge_weight: OptTensor)

RuntimeError: mat1 dim 1 must match mat2 dim 0
0

There are 0 answers