Lowest common ancestor for multiple vertices in networkx

88 views Asked by At

How can I compute the lowest common ancestor (LCA) for a directed graph in networkx for a subset of vertices?

For example, for the graph

# Build the example DAG directly from its edge list: 1 -> {2, 3}, 3 -> {4, 5}.
G = nx.DiGraph([(1, 2), (1, 3), (3, 4), (3, 5)])

Vertex 3 is the LCA of the vertices {4, 5}, and vertex 1 is the LCA of the vertices {3, 4, 5}. In case it matters: all vertices are leaves.

`nx.lowest_common_ancestor()` is not suitable because it accepts only a pair of vertices, not an arbitrary set of vertices.

Thanks!

1

There is 1 answer

0
HMH1013 On

After checking the source code of `nx.lowest_common_ancestor`, I found the function that computes the LCA of a pair of nodes.

So I modified that code so that it works for multiple nodes and returns node 1, not node 3, when the input nodes are {3, 4, 5}.

import networkx as nx
from functools import reduce

def update_ancestor_cache(node, ancestor_cache, G):
    """Memoise the inclusive ancestor set of *node* in *ancestor_cache*.

    On a cache miss the set returned by ``nx.ancestors(G, node)`` is
    extended with *node* itself and stored under ``ancestor_cache[node]``.
    A cache hit leaves the dict untouched and does not consult *G*.

    Returns the same dict object that was passed in, so callers may either
    rebind the result or rely on the in-place update.
    """
    if node in ancestor_cache:
        return ancestor_cache
    inclusive = nx.ancestors(G, node)
    inclusive.add(node)  # a node counts as its own ancestor
    ancestor_cache[node] = inclusive
    return ancestor_cache

def generate_lca_from_pairs(G, pairs):
    """Yield the lowest common ancestor of every group of nodes in *pairs*.

    This generalises ``nx.lowest_common_ancestor`` from exactly two nodes
    to groups of any size. Each node counts as its own ancestor, but the
    query nodes themselves are returned only when the group has no other
    common ancestor. In the graph 1->{2, 3}, 3->{4, 5} this gives
    ``(4, 5) -> 3``, ``(3, 4, 5) -> 1`` and ``(1, 2) -> 1``.

    Because this is a generator, all validation happens lazily, when it is
    iterated.

    Parameters
    ----------
    G : networkx.DiGraph
        Directed acyclic graph to search.
    pairs : iterable
        Each element is a collection (typically a tuple) of at least two
        nodes of *G*. Every element is processed, not only the first.

    Yields
    ------
    tuple
        ``(pair, lca)`` for each element of *pairs* whose nodes share a
        common ancestor; groups without one are skipped.

    Raises
    ------
    networkx.NetworkXError
        If *G* is not a directed acyclic graph, or an element of *pairs* is
        not a collection of at least two nodes. Errors raised by
        ``nx.ancestors`` (e.g. for a node missing from *G*) propagate.
    """
    # On a cyclic graph the "lowest" ancestor is ill-defined (and a downward
    # walk through common ancestors could loop forever), so reject it.
    if not nx.is_directed_acyclic_graph(G):
        raise nx.NetworkXError("LCA only defined on directed acyclic graphs.")

    # Shared by all groups so each node's ancestor set is computed once.
    ancestor_cache = {}
    for pair in pairs:
        try:
            nodes = tuple(pair)
        except TypeError:  # e.g. a bare node such as an int
            raise nx.NetworkXError("LCA needs at least two nodes") from None
        if len(nodes) < 2:
            raise nx.NetworkXError("LCA needs at least two nodes")

        for node in nodes:
            ancestor_cache = update_ancestor_cache(node, ancestor_cache, G)

        # Intersect only this group's ancestor sets; set.intersection returns
        # a new set, so the cached sets are not modified.
        common_ancestors = set.intersection(*(ancestor_cache[n] for n in nodes))
        if not common_ancestors:
            continue

        # Prefer strict ancestors; fall back to the query nodes only when
        # they are the sole common ancestors (e.g. (1, 2) -> 1).
        candidates = (common_ancestors - set(nodes)) or common_ancestors

        # The lowest candidate has no child that is also a candidate. In a
        # DAG every node on a path between two candidates is itself a
        # candidate, so checking direct successors is sufficient. Scanning G
        # in node order makes the pick deterministic if several exist.
        lca = next(
            node
            for node in G
            if node in candidates
            and not any(child in candidates for child in G.successors(node))
        )
        yield (pair, lca)

Here are the results of the code:

# Example graph from the question: 1 -> {2, 3} and 3 -> {4, 5}.
G = nx.DiGraph()
G.add_edges_from([(1, 2), (1, 3), (3, 4), (3, 5)])

# A bare expression statement only echoes its value in an interactive
# session, so print each result explicitly to make the demo work as a script.
pairs = [(4, 5)]
print(dict(generate_lca_from_pairs(G, pairs)))  # {(4, 5): 3}

pairs = [(3, 4, 5)]
print(dict(generate_lca_from_pairs(G, pairs)))  # {(3, 4, 5): 1}

pairs = [(1, 2)]
print(dict(generate_lca_from_pairs(G, pairs)))  # {(1, 2): 1}