Scatter matrix multiplication using torch and torch geometric


I would like to perform the following operation without using a for loop. Since I'm running on a GPU, I would like to vectorize (or parallelize) it.

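import torch

# 15 nodes with 8-dimensional predictions, spread over 3 graphs
node_predictions = torch.randn(15, 8)
# node2graph[i] is the graph that node i belongs to (sorted, as in a PyG batch)
node2graph = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2], dtype=torch.int64)

# number of nodes in each graph: tensor([6, 4, 5])
split_size = torch.bincount(node2graph)
# one (n_g, 8) tensor per graph
list_of_node_predictions = torch.split(node_predictions, split_size.tolist())
# (3, 8, 8): pred_m.T @ pred_m for each graph, computed in a Python loop
pred_m_T_times_pred_m = torch.stack([torch.matmul(pred_m.T, pred_m) for pred_m in list_of_node_predictions])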
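One possible loop-free approach, sketched below under the assumption that only plain PyTorch is available: for each graph g, pred_m.T @ pred_m equals the sum of the outer products x_i x_i^T over the nodes i in g. That sum can be computed for all nodes at once with `torch.einsum` and scattered into per-graph slots with `Tensor.index_add_`. A second variant pads every graph to the same node count and uses a single batched matmul; padded rows are zero, so they add nothing to the product. The function names `per_graph_gram` and `per_graph_gram_padded` are hypothetical, chosen for this sketch.

```python
import torch


def per_graph_gram(x: torch.Tensor, batch: torch.Tensor, num_graphs: int) -> torch.Tensor:
    """Return out[g] = x_g.T @ x_g for every graph g, without a Python loop.

    x:     (N, D) node features
    batch: (N,) int64 graph index of each node (any order)
    """
    # X_g^T X_g is the sum of outer products x_i x_i^T over the nodes i of graph g.
    outer = torch.einsum("ni,nj->nij", x, x)  # (N, D, D)
    out = x.new_zeros(num_graphs, x.size(1), x.size(1))
    # Scatter-add each node's outer product into the slot of its graph.
    out.index_add_(0, batch, outer)
    return out


def per_graph_gram_padded(x: torch.Tensor, batch: torch.Tensor, num_graphs: int) -> torch.Tensor:
    """Same result via zero padding plus one batched matmul.

    Assumes `batch` is sorted, so the nodes of each graph are contiguous.
    """
    counts = torch.bincount(batch, minlength=num_graphs)
    starts = torch.cumsum(counts, dim=0) - counts  # first row of each graph
    pos = torch.arange(x.size(0), device=x.device) - starts[batch]  # row within its graph
    dense = x.new_zeros(num_graphs, int(counts.max()), x.size(1))
    dense[batch, pos] = x  # padded rows stay zero and contribute nothing
    return dense.transpose(1, 2) @ dense  # (G, D, D)


# Data from the question.
node_predictions = torch.randn(15, 8)
node2graph = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])
num_graphs = int(node2graph.max()) + 1

# Loop-based reference, as in the question.
split_size = torch.bincount(node2graph)
reference = torch.stack(
    [p.T @ p for p in torch.split(node_predictions, split_size.tolist())]
)

scattered = per_graph_gram(node_predictions, node2graph, num_graphs)
padded = per_graph_gram_padded(node_predictions, node2graph, num_graphs)
print(scattered.shape)  # torch.Size([3, 8, 8])
```

The scatter version materializes an (N, D, D) tensor, so it suits small D; the padded version instead costs memory proportional to the largest graph. If the number of graphs is already known (for example from a PyG `Batch`), passing it in avoids the GPU-to-CPU sync caused by `int(node2graph.max())`. With torch_geometric installed, `torch_geometric.utils.to_dense_batch(x, batch)` should produce the same zero-padded tensor (plus a mask) as the manual padding above.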
