Solving Shortest Path Problem with Message Passing GNN

I'm trying to develop a GNN that can learn to solve the shortest-path problem on a graph. My idea is to learn the update rule of the Bellman-Ford algorithm via MLPs:

BF update: $$\forall u,\quad d_u^{k+1} = \min_v \big(d_v^k + c(u,v)\big)$$

GNN: $$\forall u,\quad h_u^{k+1} = MLP_2\Big(\sum_v MLP_1\big(h_u^k, h_v^k, c(u,v)\big)\Big)$$
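
For reference, the classical update I'm trying to imitate can be written as a dense PyTorch computation. This is a minimal sketch of my own (it assumes a weight matrix W with float('inf') marking absent edges):

import torch

def bellman_ford(W, source=0, num_iter=None):
    # W: (n, n) tensor where W[v, u] is the weight of edge (v, u),
    # with float('inf') for missing edges.
    n = W.size(0)
    num_iter = num_iter if num_iter is not None else n - 1
    d = torch.full((n,), float("inf"))
    d[source] = 0.0
    for _ in range(num_iter):
        # Relax all edges in parallel: d_u <- min(d_u, min_v (d_v + c(v, u)))
        d = torch.minimum(d, (d.unsqueeze(1) + W).min(dim=0).values)
    return d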

To detail a bit more: I build a sequence of MP-GNN layers that all share the same two MLPs, since they are supposed to learn the same operation at each iteration. The node features are initialized to 0 for the source node and 1 for the others, and are updated sequentially by each layer. The edge features are the edge weights.

Here is the code I've come up with so far. It runs, but produces disappointing results: about 30% accuracy against Dijkstra when trained and tested on random graphs with 10 nodes and integer edge weights.
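
(For reference, the ground-truth distances I compare against can be computed with networkx; this is a minimal sketch, and dijkstra_targets is just an illustrative helper name:)

import networkx as nx
import torch

def dijkstra_targets(weights, source=0):
    # weights: (n, n) tensor where entry 0 means "no edge";
    # returns the distance from `source` to every node (inf if unreachable).
    G = nx.from_numpy_array(weights.numpy(), create_using=nx.DiGraph)
    lengths = nx.single_source_dijkstra_path_length(G, source)
    d = torch.full((weights.size(0),), float("inf"))
    for node, dist in lengths.items():
        d[node] = float(dist)
    return d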

My original input is a batch of adjacency matrices and edge-weight matrices, so my function matrix_to_graph converts it into the type of input suited for torch_geometric.

Do you see anything wrong in my approach or in the architecture of the GNN? Also, do you have any advice on how to improve accuracy?

import torch
import torch.nn as nn
from torch_geometric.data import Batch, Data
from torch_geometric.nn import MessagePassing


class MLP(nn.Module):
    def __init__(self, input_dim, output_dim, depth, breadth):
        super(MLP, self).__init__()
        layers = [nn.Linear(input_dim, breadth), nn.ReLU()]
        for _ in range(depth - 1):
            layers.append(nn.Linear(breadth, breadth))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(breadth, output_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x.float())


class GNNLayer(MessagePassing):
    def __init__(
        self,
        use_x_i,
        aggr,
        message_mlp,
        update_mlp,
    ):
        super(GNNLayer, self).__init__(aggr=aggr)
        self.use_x_i = use_x_i
        self.message_mlp = message_mlp
        self.update_mlp = update_mlp

    def forward(self, batch):
        # Pass message to propagate
        propagation = self.propagate(
            edge_index=batch.edge_index, x=batch.x, edge_weight=batch.edge_attr
        )
        return propagation

    def message(self, x_i, x_j, edge_weight):
        # Use x_i if use_x_i is True, otherwise just use x_j and edge_weight
        if self.use_x_i:
            message_input = torch.cat([x_i, x_j, edge_weight.unsqueeze(-1)], dim=-1)
        else:
            message_input = torch.cat([x_j, edge_weight.unsqueeze(-1)], dim=-1)
        message_output = self.message_mlp(message_input)
        return message_output

    def update(self, aggr_out):
        # Update node features to the aggregated messages
        update_output = self.update_mlp(aggr_out)
        return update_output


class GNN(nn.Module):
    def __init__(
        self,
        num_nodes,
        num_iter,
        message_depth,
        message_breadth,
        update_depth,
        update_breadth,
        use_x_i=False,
        aggr="min",
        out_dim=None,
    ):
        super(GNN, self).__init__()
        self.num_iter = num_iter

        # Create common MLPs for message and update functions
        input_dim = 2 + use_x_i  # [x_j, edge_weight], plus x_i when use_x_i is True
        self.message_mlp = MLP(input_dim, 1, message_depth, message_breadth)
        self.update_mlp = MLP(1, 1, update_depth, update_breadth)

        # Create layers with references to the common MLPs
        self.layers = nn.ModuleList(
            [
                GNNLayer(use_x_i, aggr, self.message_mlp, self.update_mlp)
                for _ in range(num_iter)
            ]
        )
        if out_dim:
            self.output_mlp = MLP(num_nodes, out_dim, 2, 2 * num_nodes)
        self.out_dim = out_dim

    def forward(self, batch):
        # Ensure that batch is a Batch object
        if not isinstance(batch, Batch):
            raise TypeError("Input must be a PyTorch Geometric Batch object")

        for layer in self.layers:
            batch.x = layer(batch)

        # Each graph in the batch has the same number of nodes
        num_nodes_per_graph = batch.num_nodes // batch.num_graphs
        # Squeeze only the trailing feature dim so a batch of size 1 keeps its batch dim
        final_output = batch.x.view(batch.num_graphs, num_nodes_per_graph, -1).squeeze(-1)

        if self.out_dim:
            final_output = self.output_mlp(final_output)
        return final_output


def matrix_to_graph(matrices):
    """
    Convert a batch of matrices into a Batch containing graphs.

    Each matrix pair in the batch represents one graph: the first channel holds
    the adjacency matrix and the second channel holds the edge weights.

    Parameters:
    matrices (torch.Tensor): A batch of matrices with shape (B, c, n, n),
                             where B is the batch size, c is the number of channels
                             (c=2 for adjacency matrix and edge weights matrix),
                             and n is the number of nodes in each graph.

    Returns:
    Batch: A batch of graphs.
    """
    graph_list = []
    for matrix in matrices[:, 1]:  # take only the edge-weight channel
        num_nodes = matrix.size(0)

        # Non-zero entries of the weight matrix define the edges
        # (note: this silently drops any edge whose weight is exactly 0)
        edge_index = torch.nonzero(matrix, as_tuple=False).t().contiguous()

        # Gather the corresponding edge weights, cast to float so they can be
        # concatenated with the float node features in the message MLP
        edge_weight = matrix[edge_index[0], edge_index[1]].float()

        # Node features: 1.0 for every node except the source, which gets 0.0
        x = torch.ones((num_nodes, 1))
        x[0, 0] = 0.0  # First node is the source

        graph_data = Data(x=x, edge_index=edge_index, edge_attr=edge_weight)
        graph_list.append(graph_data)

    return Batch.from_data_list(graph_list)
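
And this is roughly how everything is wired together (an illustrative sketch; the hyperparameter values are just examples):

B, n = 32, 10
adj = (torch.rand(B, n, n) < 0.5).float() * (1 - torch.eye(n))  # random digraphs, no self-loops
weights = adj * torch.randint(1, 10, (B, n, n)).float()  # integer weights in [1, 9]
matrices = torch.stack([adj, weights], dim=1)  # shape (B, 2, n, n)

batch = matrix_to_graph(matrices)
model = GNN(
    num_nodes=n,
    num_iter=n - 1,  # n - 1 rounds of updates, as in Bellman-Ford
    message_depth=2,
    message_breadth=16,
    update_depth=2,
    update_breadth=16,
)
pred = model(batch)  # predicted distances from node 0, shape (B, n)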