I'm trying to develop a GNN that can learn to solve the shortest-path problem on a graph. My idea is to learn the Bellman-Ford update with MLPs:
BF update: $$\forall u,\quad d_u^{k+1} = \min\left(d_u^k,\ \min_{v \in \mathcal{N}(u)} \left(d_v^k + c(u,v)\right)\right)$$
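GNN: $$\forall u,\quad h_u^{k+1} = \mathrm{MLP}_2\left(\bigoplus_{v \in \mathcal{N}(u)} \mathrm{MLP}_1\left(h_u^k, h_v^k, c(u,v)\right)\right)$$ where $\bigoplus$ is the neighborhood aggregation (a sum, or a min as in the code below).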
In more detail, I build a sequence of message-passing GNN layers that all share the same two MLPs, since they are supposed to learn the same operation at every iteration. Node features are initialized to 0 for the source node and 1 for all other nodes, and each layer updates them in turn. Edge features are the entries of the weight matrix.
Here is the code I've come up with so far. It runs, but the results are disappointing: 30% accuracy compared to Dijkstra, when trained and tested on random graphs with 10 nodes and integer edge weights.
My original input is a batch of adjacency matrices and edge-weight matrices, so my function matrix_to_graph converts them into the input format expected by torch_geometric.
Do you see anything wrong with my approach or with the GNN architecture? Also, do you have any advice on how to improve accuracy?
class MLP(nn.Module):
    """Feed-forward network: ``depth`` Linear+ReLU blocks, then a Linear head.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        depth: number of hidden Linear+ReLU blocks (any value below 1
            behaves like 1, since the first block is always built).
        breadth: width of every hidden layer.
    """

    def __init__(self, input_dim, output_dim, depth, breadth):
        super().__init__()
        hidden = [nn.Linear(input_dim, breadth), nn.ReLU()]
        for _ in range(1, depth):
            hidden += [nn.Linear(breadth, breadth), nn.ReLU()]
        # Attribute name ``mlp`` is kept so state_dict keys stay the same.
        self.mlp = nn.Sequential(*hidden, nn.Linear(breadth, output_dim))

    def forward(self, x):
        # Inputs may be integer tensors; Linear layers need floating point.
        return self.mlp(x.float())
class GNNLayer(MessagePassing):
    """One message-passing step meant to imitate a Bellman-Ford relaxation.

    Each edge j -> i produces ``message_mlp([x_j, w_ji])`` (or
    ``message_mlp([x_i, x_j, w_ji])`` when ``use_x_i``). The messages are
    combined with ``aggr``, and ``update_mlp`` is applied to the result.

    Args:
        use_x_i: also feed the receiving node's feature to the message MLP.
        aggr: aggregation name forwarded to ``MessagePassing`` (e.g. "min").
        message_mlp: module mapping the concatenated edge input to a message.
        update_mlp: module mapping the aggregated message to the new feature.
        self_loops: if True (default), a zero-weight self-loop is added to
            every node before propagation. This makes each node's current
            value part of its own aggregation, as in Bellman-Ford's
            ``min(d_u, min_v d_v + c)``. Without it, the source node's 0 is
            overwritten by incoming messages. Also, PyG fills nodes that
            have no incoming edge with 0 for "min", so they would look like
            the source.
    """

    def __init__(
        self,
        use_x_i,
        aggr,
        message_mlp,
        update_mlp,
        self_loops=True,
    ):
        super().__init__(aggr=aggr)
        self.use_x_i = use_x_i
        self.message_mlp = message_mlp
        self.update_mlp = update_mlp
        self.self_loops = self_loops

    def forward(self, batch):
        """Run one propagation step on ``batch`` and return new node features.

        ``batch`` itself is not modified.
        """
        edge_index, edge_weight = batch.edge_index, batch.edge_attr
        if self.self_loops:
            # Local import: torch_geometric is already a dependency of this file.
            from torch_geometric.utils import add_self_loops

            edge_index, edge_weight = add_self_loops(
                edge_index,
                edge_weight,
                fill_value=0.0,
                num_nodes=batch.x.size(0),
            )
        return self.propagate(edge_index=edge_index, x=batch.x, edge_weight=edge_weight)

    def message(self, x_i, x_j, edge_weight):
        """Build the per-edge message from the endpoint features and edge weight."""
        # Accept edge weights shaped (E,) or (E, 1); cast everything to float
        # so int node features and int weights concatenate consistently.
        weight = edge_weight.reshape(-1, 1).float()
        if self.use_x_i:
            message_input = torch.cat([x_i.float(), x_j.float(), weight], dim=-1)
        else:
            message_input = torch.cat([x_j.float(), weight], dim=-1)
        return self.message_mlp(message_input)

    def update(self, aggr_out):
        """Map the aggregated message of each node to its new feature."""
        return self.update_mlp(aggr_out)
class GNN(nn.Module):
    """Stack of weight-tied message-passing layers imitating Bellman-Ford.

    All ``num_iter`` layers reference the same message and update MLPs, so
    the model applies one learned relaxation step ``num_iter`` times.

    Args:
        num_nodes: nodes per graph; only used to size the optional output MLP.
        num_iter: number of message-passing iterations (layers).
        message_depth, message_breadth: depth/width of the shared message MLP.
        update_depth, update_breadth: depth/width of the shared update MLP.
        use_x_i: if True, messages also see the receiving node's feature.
        aggr: aggregation passed to each ``GNNLayer`` (default "min").
        out_dim: if truthy, a final MLP maps each graph's ``num_nodes``
            per-node outputs to ``out_dim`` values.
    """

    def __init__(
        self,
        num_nodes,
        num_iter,
        message_depth,
        message_breadth,
        update_depth,
        update_breadth,
        use_x_i=False,
        aggr="min",
        out_dim=None,
    ):
        super().__init__()
        self.num_iter = num_iter
        # Message input is [x_j, w] (2 values) or [x_i, x_j, w] (3 values).
        input_dim = 2 + int(use_x_i)
        self.message_mlp = MLP(input_dim, 1, message_depth, message_breadth)
        self.update_mlp = MLP(1, 1, update_depth, update_breadth)
        # Every layer references the same two MLPs (weight tying).
        self.layers = nn.ModuleList(
            [
                GNNLayer(use_x_i, aggr, self.message_mlp, self.update_mlp)
                for _ in range(num_iter)
            ]
        )
        # Always defined so forward() can test it even without an output MLP.
        self.out_dim = out_dim
        if out_dim:
            self.output_mlp = MLP(num_nodes, out_dim, 2, 2 * num_nodes)

    def forward(self, batch):
        """Run all iterations and return per-graph outputs.

        Returns a tensor of shape (num_graphs, num_nodes_per_graph) when the
        node features are scalar. If ``out_dim`` was set, the shape is
        (num_graphs, out_dim) instead. The batch dimension is kept even when
        there is a single graph.

        Raises:
            TypeError: if ``batch`` is not a torch_geometric ``Batch``.
        """
        if not isinstance(batch, Batch):
            raise TypeError("Input must be a PyTorch Geometric Batch object")
        # Layers read ``batch.x``, so it is advanced in place, but the caller's
        # features are always restored. Otherwise, reusing the same Batch
        # (e.g. on the next epoch) would start from the previous outputs and
        # their already-freed autograd graph.
        original_x = batch.x
        try:
            for layer in self.layers:
                batch.x = layer(batch)
            x = batch.x
        finally:
            batch.x = original_x
        # Assumes each graph in the batch has the same number of nodes.
        num_nodes_per_graph = batch.num_nodes // batch.num_graphs
        # squeeze(-1) only drops the feature axis; a bare squeeze() would also
        # drop the batch axis when num_graphs == 1.
        final_output = x.reshape(batch.num_graphs, num_nodes_per_graph, -1).squeeze(-1)
        if self.out_dim:
            final_output = self.output_mlp(final_output)
        return final_output
def matrix_to_graph(matrices, source=0, unreached_value=1.0):
    """
    Convert a batch of (adjacency, weight) matrix pairs into a Batch of graphs.

    Each item of ``matrices`` describes one graph with two channels. Channel 0
    is the adjacency matrix and channel 1 is the edge-weight matrix. Edge
    structure is read from the adjacency channel, so edges whose weight is 0
    are kept. Reading structure from the weights would silently drop them.

    Parameters:
        matrices (torch.Tensor): A batch of matrices with shape (B, c, n, n),
            where B is the batch size, c >= 2 is the number of channels
            (adjacency, then edge weights), and n is the number of nodes in
            each graph. Entry [b, 1, u, v] becomes the attribute of edge u -> v.
        source (int): Index of the source node, whose feature is set to 0.
            Default 0.
        unreached_value (float): Initial feature of every other node.
            Default 1.0 keeps the original encoding. Bellman-Ford starts from
            +inf, so a value larger than any possible path length is a
            closer analogue.

    Returns:
        Batch: A batch of graphs. Each has float node features ``x`` of shape
        (n, 1), ``edge_index`` of shape (2, E), and float ``edge_attr`` of
        shape (E,).
    """
    graph_list = []
    for adjacency, weights in zip(matrices[:, 0], matrices[:, 1]):
        num_nodes = weights.size(0)
        # (row, col) indices of every present edge -> edge_index of shape (2, E).
        edge_index = torch.nonzero(adjacency, as_tuple=False).t().contiguous()
        # Gather the corresponding edge weights.
        edge_weight = weights[edge_index[0], edge_index[1]].float()
        # Float features on the input's device. An int64 CPU tensor would
        # mismatch GPU inputs.
        x = torch.full((num_nodes, 1), float(unreached_value), device=weights.device)
        x[source, 0] = 0.0
        graph_list.append(Data(x=x, edge_index=edge_index, edge_attr=edge_weight))
    return Batch.from_data_list(graph_list)