Tarjan's algorithm - Python to Scala


I'm trying to translate the recursive Python code for Tarjan's algorithm to Scala, especially this part:

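def tarjan_recursive(g):
    S = []          # DFS stack of vertices
    S_set = set()   # vertices currently on S, for O(1) membership tests
    index = {}      # discovery order of each vertex
    lowlink = {}    # lowest index on the stack reachable from the vertex's DFS subtree
    ret = []        # strongly connected components found so far

    def visit(v):
        index[v] = len(index)
        lowlink[v] = index[v]
        S.append(v)
        S_set.add(v)

        for w in g.get(v, ()):
            if w not in index:
                # tree edge: recurse, then propagate the child's lowlink
                visit(w)
                lowlink[v] = min(lowlink[w], lowlink[v])
            elif w in S_set:
                # w is still on the stack (its SCC is not finished yet)
                lowlink[v] = min(lowlink[v], index[w])

        # v is the root of an SCC: pop v and everything pushed after it
        if lowlink[v] == index[v]:
            scc = []
            w = None
            while v != w:
                w = S.pop()
                scc.append(w)
                S_set.remove(w)
            ret.append(scc)

    for v in g:
        if v not in index:
            visit(v)
    return ret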

I know there are Scala implementations of Tarjan's algorithm here or here, but they don't return the right results for me, and translating it from Python helps me understand it.

Here's what I have:

def tj_recursive(g: Map[Int,List[Int]])= {
  var s : mutable.ListBuffer[Int] = new mutable.ListBuffer()
  var s_set : mutable.Set[Int] = mutable.Set()
  var index : mutable.Map[Int,Int] =  mutable.Map()
  var lowlink : mutable.Map[Int,Int]=  mutable.Map()
  var ret : mutable.Map[Int,mutable.ListBuffer[Int]]= mutable.Map()

  def visit(v: Int):Int = {
    index(v) = index.size
    lowlink(v) = index(v)
    var zz :List[Int]= gg.get(v).toList(0)
    for( w <- zz) {
      if( !(index.contains(w)) ){
        visit(w)
        lowlink(v) = List(lowlink(w),lowlink(v)).min
      }else if(s_set.contains(w)){
        lowlink(v)=List(lowlink(v),index(w)).min
      }
    }

    if(lowlink(v)==index(v)){
      var scc:mutable.ListBuffer[Int] = new mutable.ListBuffer()
      var w:Int=null.asInstanceOf[Int]
      while(v!=w){
        w= s.last
        scc+=w
        s_set-=w
      }
      ret+=scc
    }
  }

  for( v <- g) {if( !(index.contains(v)) ){visit(v)}}
  ret
}

I know this isn't idiomatic Scala at all (and it isn't clean), but I plan to move it gradually to a more functional style once the first version works.

For now, I get this error:

type mismatch;
 found   : Unit
 required: Int

at this line:

if(lowlink(v)==index(v)){

I think it's coming from this line, but I'm not sure:

if( !(index.contains(w)) 

But it's really hard to debug, since the code doesn't compile and I can't just use println to find my mistakes.

Thanks!


Answer by Travis Brown (best answer)

Here's a fairly literal translation of the Python:

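import scala.collection.mutable

def tj_recursive(g: Map[Int, List[Int]]) = {
  val s = mutable.Buffer.empty[Int]                    // DFS stack of vertices
  val s_set = mutable.Set.empty[Int]                   // vertices currently on s
  val index = mutable.Map.empty[Int, Int]              // discovery order of each vertex
  val lowlink = mutable.Map.empty[Int, Int]            // lowest stack index reachable from v's subtree
  val ret = mutable.Buffer.empty[mutable.Buffer[Int]]  // SCCs found so far

  def visit(v: Int): Unit = {
    index(v) = index.size
    lowlink(v) = index(v)
    s += v
    s_set += v

    // Like Python's g.get(v, ()): a vertex with no entry has no successors
    for (w <- g.getOrElse(v, Nil)) {
      if (!index.contains(w)) {
        visit(w)
        lowlink(v) = math.min(lowlink(w), lowlink(v))
      } else if (s_set(w)) {
        lowlink(v) = math.min(lowlink(v), index(w))
      }
    }

    // v is the root of an SCC: pop v and everything pushed after it
    if (lowlink(v) == index(v)) {
      val scc = mutable.Buffer.empty[Int]
      var w = -1 // sentinel; assumes -1 is never a vertex

      while(v != w) {
        w = s.remove(s.size - 1)
        scc += w
        s_set -= w
      }

      ret += scc
    }
  }

  for (v <- g.keys) if (!index.contains(v)) visit(v)
  ret
}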

It produces the same output as the Python version on, for example:

tj_recursive(Map(
  1 -> List(2),    2 -> List(1, 5), 3 -> List(4),
  4 -> List(3, 5), 5 -> List(6),    6 -> List(7),
  7 -> List(8),    8 -> List(6, 9), 9 -> Nil
))
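Because the outer loop visits vertices in the map's iteration order (Python dicts preserve insertion order, while Scala's immutable Map does not guarantee any iteration order), the two versions can list the SCCs, and the vertices inside each SCC, in a different order. A minimal sketch for comparing results as partitions instead, using a hypothetical helper asPartition that is not part of either answer:

```scala
object SccCompare {
  // Hypothetical helper: forget the order of the SCCs and of the vertices
  // inside each SCC, keeping only which vertices are grouped together.
  def asPartition[A](sccs: Iterable[Iterable[A]]): Set[Set[A]] =
    sccs.map(scc => scc.toSet[A]).toSet
}
```

For the example graph above, the SCCs as sets are {1, 2}, {3, 4}, {5}, {6, 7, 8} and {9}, so asPartition applied to either version's result can be compared against that set of sets regardless of ordering.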

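The biggest problems with your implementation were the return type of visit (it should have been Unit, not Int) and the fact that the final for-comprehension iterated over the graph's items (key-value pairs) instead of its keys. I've also made a number of other edits for style and clarity while keeping the basic shape; for example, your version never pushed v onto s or s_set, and it read s.last without removing it, whereas this one pushes v at the start of visit and pops with s.remove(s.size - 1).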
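Both problems can be seen in isolation. In this small sketch (the ReturnTypeDemo object and its graph are made up for illustration), iterating a Map yields (key, value) tuples, so in the original loop v is a tuple rather than a vertex, while g.keys yields the vertices themselves; and because visit ends in an if without an else, its body has type Unit, which is why declaring it as returning Int fails:

```scala
import scala.collection.mutable

object ReturnTypeDemo {
  val g: Map[Int, List[Int]] = Map(1 -> List(2), 2 -> List(1, 3))

  // Iterating a Map directly yields (key, value) tuples, not vertices
  val entries: List[(Int, List[Int])] = g.toList

  // g.keys yields just the vertices, which is what the outer loop needs
  val vertices: List[Int] = g.keys.toList.sorted

  val seen = mutable.Set.empty[Int]

  // The body ends with an `if` that has no `else`, so its type is Unit.
  // Declaring `: Int` here instead would give
  // "type mismatch; found: Unit, required: Int".
  def visit(v: Int): Unit = {
    if (!seen.contains(v)) {
      seen += v
      for (w <- g.getOrElse(v, Nil)) visit(w)
    }
  }

  for (v <- g.keys) visit(v)
}
```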

Answer by SimonThr

I know this post is old, but I have recently been implementing Tarjan's algorithm in Scala. While doing so, I looked at this post, and it occurred to me that it could be done in a simpler way:

import scala.collection.mutable

case class Edge[A](from: A, to: Set[A])

class TarjanGraph[A](src: Iterable[Edge[A]]) {
  lazy val tarjan: mutable.Buffer[mutable.Buffer[A]] = {
    val s = mutable.Buffer.empty[A] // Stack of visited nodes whose SCC is not complete yet
    val index = mutable.Map.empty[A, Int] // Index (discovery order) of each node
    val lowLink = mutable.Map.empty[A, Int] // The smallest index reachable from the node
    val ret = mutable.Buffer.empty[mutable.Buffer[A]] // Keep track of the SCCs in the graph

    def visit(v: A): Unit = {
      // Set index and lowlink of the node on its first visit
      index(v) = index.size
      lowLink(v) = index(v)
      // Push onto the stack
      s += v
      // Nodes without an Edge entry have no successors
      for (edge <- src.find(_.from == v); w <- edge.to) {
        if (!index.contains(w)) { // Node is not explored yet
          // Perform a DFS from node w
          visit(w)
          // Update the lowlink of v so it is the lowest index reachable from v itself and from node w
          lowLink(v) = math.min(lowLink(w), lowLink(v))
        } else if (s.contains(w)) {
          // Node w is on the stack, which means there is a path from w to v,
          // and since w is a successor of v there is also a path from v to w
          lowLink(v) = math.min(lowLink(v), index(w))
        }
      }
      // The lowlink value hasn't been lowered, so v is the root of an SCC
      if (lowLink(v) == index(v)) {
        // The SCC consists of v and every element pushed onto the stack after it
        val n = s.length - s.indexOf(v)
        ret += s.takeRight(n)
        // Remove these elements from the stack
        s.dropRightInPlace(n)
      }
    }

    // Start a DFS from every node that hasn't been explored yet
    src.foreach(v => if (!index.contains(v.from)) visit(v.from))
    ret
  }

  // A cycle exists if there is an SCC with at least two nodes
  // (a self-loop is also a cycle, but on its own it only yields a one-node SCC)
  lazy val hasCycle: Boolean = tarjan.exists(_.size >= 2)
  lazy val tarjanCycle: Iterable[Seq[A]] = tarjan.filter(_.size >= 2).distinct.map(_.toSeq).toSeq
  lazy val topologicalSortedEdges: Seq[Edge[A]] =
    if (hasCycle) Seq[Edge[A]]()
    else tarjan.flatten.reverse.flatMap(x => src.find(_.from == x)).toSeq
}
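The simplification here is that visit removes a whole SCC with one slice instead of popping vertex by vertex in a while loop. Isolated as a sketch (the popThrough name is hypothetical, and dropRightInPlace needs Scala 2.13 or later), it behaves like this. Note that the slice keeps push order, whereas the loop in the other answers collects the SCC from the top of the stack down:

```scala
import scala.collection.mutable

object SlicePopDemo {
  // Remove v and every element pushed after it, returning them in push order.
  // Assumes v is on the stack exactly once, as in Tarjan's algorithm.
  def popThrough[A](stack: mutable.Buffer[A], v: A): List[A] = {
    val n = stack.length - stack.indexOf(v)
    val scc = stack.takeRight(n).toList
    stack.dropRightInPlace(n)
    scc
  }
}
```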
Answer by Eduardo

Here is an iterative version. It is a translation of the recursive version of the algorithm on Wikipedia, with the recursion replaced by an explicit stack of frames (StackFrame).

import scala.collection.mutable.{BitSet, Buffer, Stack}

case class Arc[A](from:A, to:A)

class SparseDG[A](src: Iterable[Arc[A]]) {

  // Number the vertices 0 until qVert so arrays can hold the bookkeeping
  val verts = (src.map(_.from) ++ src.map(_.to)).toSet.toIndexedSeq
  val qVert = verts.size
  val vertMap = verts.zipWithIndex.toMap
  val indexedSrc = src.map{ arc => Arc(vertMap(arc.from), vertMap(arc.to)) }

  // exit(v) = indices of the successors of v
  val exit  = (0 until qVert)
              .map(v => indexedSrc.filter(_.from == v).map(_.to).toIndexedSeq)


  lazy val tarjan_iterative: Seq[Seq[A]] = {
    // Each Step marks a point in the recursive strongconnect(v) where execution resumes
    trait Step
    case object SetDepth           extends Step
    case object ConsiderSuccessors extends Step
    case object CalcLowlink        extends Step
    case object PopIfRoot          extends Step
    case class StackFrame(v:Int, next:Step)

    val result = Buffer[Seq[A]]()
    val index = Array.fill(qVert)(-1)    // -1 = undefined
    val lowlink = Array.fill(qVert)(-1)  // -1 = undefined
    val wIndex = new Array[Int](qVert)   // position of the next successor w to consider for each v
    var _index = 0
    val s = Stack[Int]()
    val isRemoved = BitSet()             // vertices already assigned to an SCC
    val strongconnect = Stack[StackFrame]()

    (0 until qVert).foreach { v_idx =>
      if(index(v_idx) == -1) {
        strongconnect.push(StackFrame(v_idx, SetDepth))
        while(!strongconnect.isEmpty) {
          val StackFrame(v, step) = strongconnect.pop()
          step match {
            case SetDepth =>
              index(v) = _index
              lowlink(v) = _index
              _index += 1
              s.push(v)
              isRemoved.remove(v)
              strongconnect.push(StackFrame(v, ConsiderSuccessors))

            case ConsiderSuccessors =>
              if(wIndex(v) < exit(v).size){
                val w = exit(v)(wIndex(v))
                if(index(w) == -1){
                  // Tree edge: "recurse" into w, then resume v at CalcLowlink
                  strongconnect.push(StackFrame(v, CalcLowlink))
                  strongconnect.push(StackFrame(w, SetDepth))
                }
                else{
                  if(!isRemoved.contains(w)){
                    // w is visited but not yet in an SCC, i.e. still on the stack:
                    // lowlink(v) = min(lowlink(v), index(w))
                    if(lowlink(v) > index(w)) lowlink(v) = index(w)
                  }
                  wIndex(v) += 1
                  strongconnect.push(StackFrame(v, ConsiderSuccessors))
                }
              }
              else{
                strongconnect.push(StackFrame(v, PopIfRoot))
              }

            case CalcLowlink =>
              // Returned from w: lowlink(v) = min(lowlink(v), lowlink(w))
              val w = exit(v)(wIndex(v))
              if(lowlink(v) > lowlink(w)) lowlink(v) = lowlink(w)
              wIndex(v) += 1
              strongconnect.push(StackFrame(v, ConsiderSuccessors))

            case PopIfRoot =>
              if(index(v) == lowlink(v)){
                val buf = Buffer[A]()
                var w = 0
                do{
                  w = s.pop()
                  isRemoved += w
                  buf += verts(w)
                }
                while(w != v)
                result += buf.toSeq
              }
          }
        }
      }
    }
    result.toSeq
  }

  // Note: a self-loop on its own gives an SCC of size 1, so it is not detected here
  lazy val hasCycle = tarjan_iterative.exists(_.size >= 2)

  lazy val topologicalSort =
    if(hasCycle) None
    else         Some(tarjan_iterative.flatten.reverse)

}
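The StackFrame/Step machinery is the usual way to turn recursion into a loop: each frame records where the recursive call would resume. The same idea on a plain DFS post-order, as a simplified sketch (the ExplicitStackDfs object and the (vertex, nextSuccessor) frame encoding are illustrative, not taken from the answer), makes the correspondence easier to see: the per-frame successor position plays the role of wIndex, and the point where a frame resumes after a newly visited child finishes is where the CalcLowlink step does its work.

```scala
import scala.collection.mutable

object ExplicitStackDfs {
  // DFS post-order from `start` without recursion. A frame (v, i) means
  // "continue visiting v at its i-th successor", like the loop position
  // inside a recursive visit(v).
  def postOrder(adj: Map[Int, Vector[Int]], start: Int): List[Int] = {
    val visited = mutable.Set(start)
    val out = mutable.ListBuffer.empty[Int]
    val stack = mutable.Stack((start, 0))
    while (stack.nonEmpty) {
      val (v, i) = stack.pop()
      val succ = adj.getOrElse(v, Vector.empty)
      if (i < succ.size) {
        // Resume v at its next successor once the child (if any) is done;
        // after a newly visited child, this is where CalcLowlink would run
        stack.push((v, i + 1))
        val w = succ(i)
        // The "recursive call": only descend into unvisited vertices
        if (visited.add(w)) stack.push((w, 0))
      } else {
        // No successors left: the recursive visit(v) would return here
        out += v
      }
    }
    out.toList
  }
}
```

Advancing one successor per frame, rather than pushing all successors at once, keeps the visiting order identical to the recursive version, which matters for the lowlink updates.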

Running the example graph from the Wikipedia article:

val g = new SparseDG(Seq(
        Arc("1","2"),
        Arc("2","3"),
        Arc("3","1"),
        Arc("4","2"),
        Arc("4","3"),
        Arc("6","3"),
        Arc("6","7"),
        Arc("7","6"),
        Arc("4","5"),
        Arc("5","4"),
        Arc("5","6"),
        Arc("8","5"),
        Arc("8","8"),
        Arc("8","7")
      ))

g.tarjan_iterative

returns:

ArrayBuffer(ArrayBuffer(1, 3, 2), ArrayBuffer(7, 6), ArrayBuffer(4, 5), ArrayBuffer(8))
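Note that vertex 8 has a self-loop (Arc("8","8")), so it lies on a cycle even though its SCC, ArrayBuffer(8), has size 1. Checks of the form _.size >= 2, as in hasCycle in the last two answers, only catch such a cycle if some larger SCC also exists. A sketch of a check that also looks at self-loops, assuming the SCCs and arcs are passed in as plain collections with arcs as (from, to) pairs instead of Arc/Edge (the CycleCheck name is hypothetical):

```scala
object CycleCheck {
  // A directed graph has a cycle iff some SCC has at least two vertices
  // or some vertex has an arc to itself (a self-loop alone is a size-1 SCC).
  def hasCycle[A](sccs: Iterable[Iterable[A]], arcs: Iterable[(A, A)]): Boolean =
    sccs.exists(_.size >= 2) || arcs.exists { case (from, to) => from == to }
}
```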