Tail recursive fold on a binary tree in Scala


I am trying to write a tail-recursive fold function for a binary tree, given the following definitions:

// From the book "Functional Programming in Scala", page 45
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

Implementing a non-tail-recursive version is straightforward:

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B =
  t match {
    case Leaf(v)      => map(v)
    case Branch(l, r) => 
      red(fold(l)(map)(red), fold(r)(map)(red))
  }
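As an illustration (a sketch, not part of the original question), the same recursive fold can render a tree as a string. The result records both the grouping and the operand order used by `red`, which is why a general fold cannot combine leaf results in an arbitrary order. The helpers `render` and `renderSwapped` are hypothetical names introduced only for this example.

```scala
// Definitions repeated from the question so the sketch is self-contained.
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// The question's non-tail-recursive fold.
def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B =
  t match {
    case Leaf(v)      => map(v)
    case Branch(l, r) => red(fold(l)(map)(red), fold(r)(map)(red))
  }

// Hypothetical helper: parentheses show the grouping, operand order shows left/right.
def render(t: Tree[Int]): String =
  fold(t)(_.toString)((l, r) => s"($l $r)")

// Hypothetical helper: swapping the operands of `red` mirrors the tree in the output.
def renderSwapped(t: Tree[Int]): String =
  fold(t)(_.toString)((l, r) => s"($r $l)")

val sample: Tree[Int] = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
// render(sample)        == "((1 2) 3)"
// renderSwapped(sample) == "(3 (2 1))"
```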

What I am struggling with is a tail-recursive fold, i.e. one that compiles with the @annotation.tailrec annotation.

During my research I found several examples of tail-recursive functions on a tree that, for example, compute the sum of all leaves using their own stack, which is basically a List[Tree[Int]]. But as far as I understand, this only works for addition, because there it does not matter in which order or grouping the leaf values are combined. For a generalised fold, both are relevant. To show my intention, here are some example trees:

val leafs = Branch(Leaf(1), Leaf(2))
val left = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
val bal = Branch(Branch(Leaf(1), Leaf(2)), Branch(Leaf(3), Leaf(4)))
val cmb = Branch(right, Branch(bal, Branch(leafs, left)))
val trees = List(leafs, left, right, bal, cmb)
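For comparison, here is a minimal sketch of the kind of stack-based leaf sum described above (an assumption about what those examples look like, not code taken from them). The running `Int` accumulator throws away the tree's shape, which is only safe because `+` is associative and has an identity element (`0`). `Branch(_, _)` has neither property, so the same trick cannot produce a deep copy. The name `sumLeaves` is hypothetical.

```scala
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Hypothetical helper: tail-recursive leaf sum with a hand-managed stack of pending subtrees.
def sumLeaves(t: Tree[Int]): Int = {
  @annotation.tailrec
  def loop(stack: List[Tree[Int]], acc: Int): Int = stack match {
    case Nil                  => acc
    case Leaf(v) :: rest      => loop(rest, acc + v)            // add the leaf to the running total
    case Branch(l, r) :: rest => loop(l :: r :: rest, acc)      // replace a branch by its children
  }
  loop(List(t), 0)
}

// The question's example trees.
val leafs = Branch(Leaf(1), Leaf(2))
val left = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
val bal = Branch(Branch(Leaf(1), Leaf(2)), Branch(Leaf(3), Leaf(4)))
val cmb = Branch(right, Branch(bal, Branch(leafs, left)))
```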

Based on those trees, I want to create a deep copy of each one with the fold method, like this:

val oldNewPairs = 
  trees.map(t => (t, fold(t)(Leaf(_): Tree[Int])(Branch(_, _))))

And then prove that every copy is equal to its original:

val conditionHolds = oldNewPairs.forall(p => {
  if (p._1 == p._2) true
  else {
    println(s"Original:\n${p._1}\nNew:\n${p._2}")
    false
  }
})
println("Condition holds: " + conditionHolds)

Could someone give me some pointers, please?

You can find the code used in this question at ScalaFiddle: https://scalafiddle.io/sf/eSKJyp2/15

Accepted answer, by Pablo Francisco Pérez Hidalgo:

You can reach a tail-recursive solution by no longer relying on the function call stack and instead using a stack managed by your own code, together with an accumulator:

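def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {

  // Marker pushed after a branch's children: "combine the last two results".
  case object BranchStub extends Tree[Nothing]

  // Fully qualified, so no `import scala.annotation.tailrec` is needed.
  @annotation.tailrec
  def foldImp(toVisit: List[Tree[A]], acc: Vector[B]): Vector[B] =
    if (toVisit.isEmpty) acc
    else {
      toVisit.head match {
        case Leaf(v) =>
          // A leaf's result is appended to the right end of the accumulator.
          foldImp(toVisit.tail, acc :+ map(v))
        case Branch(l, r) =>
          // Visit the left subtree, then the right one, then combine them.
          foldImp(l :: r :: BranchStub :: toVisit.tail, acc)
        case BranchStub =>
          // The last two entries are this branch's left and right results.
          foldImp(toVisit.tail, acc.dropRight(2) :+ acc.takeRight(2).reduce(red))
      }
    }

  foldImp(t :: Nil, Vector.empty).head

}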

The idea is to accumulate leaf results from left to right, to keep track of the parent/child relation by pushing a stub node after each branch's two children, and, whenever a stub node is reached during the traversal, to replace the last two elements of the accumulator with the result of applying your red function to them.
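To see the stub mechanism at work, here is a sketch of the same Vector-based algorithm, rewritten with pattern matching on the list and instrumented to record the accumulator after every step that changes it. `foldTraced` and its `trace` parameter are hypothetical additions for illustration; the traversal and combination logic is the answer's.

```scala
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Hypothetical instrumented variant: returns the result plus every accumulator state.
def foldTraced[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): (B, List[Vector[B]]) = {
  case object BranchStub extends Tree[Nothing]

  @annotation.tailrec
  def foldImp(toVisit: List[Tree[A]], acc: Vector[B], trace: List[Vector[B]]): (Vector[B], List[Vector[B]]) =
    toVisit match {
      case Nil => (acc, trace.reverse)
      case Leaf(v) :: rest =>
        val next = acc :+ map(v)                  // append the leaf result
        foldImp(rest, next, next :: trace)
      case Branch(l, r) :: rest =>
        foldImp(l :: r :: BranchStub :: rest, acc, trace)
      case BranchStub :: rest =>
        val next = acc.dropRight(2) :+ acc.takeRight(2).reduce(red) // collapse last two
        foldImp(rest, next, next :: trace)
    }

  val (acc, trace) = foldImp(t :: Nil, Vector.empty, Nil)
  (acc.head, trace)
}

val left: Tree[Int] = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val (result, steps) = foldTraced(left)(_.toString)((a, b) => s"($a $b)")
// result == "((1 2) 3)"
// steps: Vector("1"), Vector("1", "2"), Vector("(1 2)"), Vector("(1 2)", "3"), Vector("((1 2) 3)")
```

Each leaf appends to the right end of the accumulator and each `BranchStub` collapses the last two entries into one, so at any point the accumulator holds one reduced value per finished subtree whose parent branch is still pending.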

This solution could still be optimized (for instance, the accumulator is only ever accessed at its end, so it does not need to be a Vector), but it is already a tail-recursive implementation.
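A quick way to check the tail-recursion claim in practice is to fold degenerate trees that are far deeper than the JVM call stack would typically allow for the plain recursive fold. This is a sketch: the depth of 100,000 is an arbitrary choice, and whether the non-tail-recursive version actually overflows at that depth depends on the thread's stack size. Only `fold` is applied to the deep trees, because the case classes' generated `==`, `hashCode` and `toString` are themselves recursive.

```scala
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// The answer's Vector-based fold, with the annotation fully qualified.
def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {
  case object BranchStub extends Tree[Nothing]

  @annotation.tailrec
  def foldImp(toVisit: List[Tree[A]], acc: Vector[B]): Vector[B] =
    if (toVisit.isEmpty) acc
    else toVisit.head match {
      case Leaf(v)      => foldImp(toVisit.tail, acc :+ map(v))
      case Branch(l, r) => foldImp(l :: r :: BranchStub :: toVisit.tail, acc)
      case BranchStub   => foldImp(toVisit.tail, acc.dropRight(2) :+ acc.takeRight(2).reduce(red))
    }

  foldImp(t :: Nil, Vector.empty).head
}

// Hypothetical test inputs: a left-leaning and a right-leaning chain of branches.
val depth = 100000
val deepLeft: Tree[Int] =
  (1 to depth).foldLeft(Leaf(0): Tree[Int])((t, i) => Branch(t, Leaf(i)))
val deepRight: Tree[Int] =
  (1 to depth).foldLeft(Leaf(0): Tree[Int])((t, i) => Branch(Leaf(i), t))

// Count leaves: every Leaf maps to 1 and `red` adds the counts (depth + 1 leaves each).
val leftCount = fold(deepLeft)(_ => 1)(_ + _)
val rightCount = fold(deepRight)(_ => 1)(_ + _)
```

The left-leaning tree keeps the accumulator short but grows `toVisit`; the right-leaning one grows the accumulator to about `depth` entries before the stubs collapse it, so both stacks are exercised.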

EDIT:

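It can be slightly simplified by changing the accumulator to a List used as a stack. Note that the right subtree is now pushed before the left one, so when a BranchStub is reached the left result sits on top of the stack and acc.take(2) is List(leftResult, rightResult), the order red expects: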

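def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {

  // Marker pushed after a branch's children: "combine the top two results".
  case object BranchStub extends Tree[Nothing]

  @annotation.tailrec
  def foldImp(toVisit: List[Tree[A]], acc: List[B]): List[B] =
    if (toVisit.isEmpty) acc
    else {
      toVisit.head match {
        case Leaf(v) =>
          // Push the leaf's result onto the result stack.
          foldImp(toVisit.tail, map(v) :: acc)
        case Branch(l, r) =>
          // Right subtree first, so the left result ends up on top of the stack.
          foldImp(r :: l :: BranchStub :: toVisit.tail, acc)
        case BranchStub =>
          // acc.take(2) is List(leftResult, rightResult).
          foldImp(toVisit.tail, acc.take(2).reduce(red) :: acc.drop(2))
      }
    }

  foldImp(t :: Nil, Nil).head

}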
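One further variation, not from the answer: keeping the work items in a separate type avoids injecting a stub node into `Tree`, and destructuring the result stack directly replaces `take(2).reduce(red)`. `Step`, `Visit` and `Combine` are hypothetical names. This sketch pushes the left subtree first, so `map` is applied to the leaves from left to right as in the first version, and the right result is therefore on top of the stack when `Combine` is reached. It is checked with the question's own deep-copy test.

```scala
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Hypothetical work items: visit a subtree, or combine the top two results.
sealed trait Step[+A]
final case class Visit[A](tree: Tree[A]) extends Step[A]
case object Combine extends Step[Nothing]

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {
  @annotation.tailrec
  def loop(todo: List[Step[A]], results: List[B]): B =
    (todo, results) match {
      case (Nil, result :: Nil) =>
        result
      case (Visit(Leaf(v)) :: rest, _) =>
        loop(rest, map(v) :: results)
      case (Visit(Branch(l, r)) :: rest, _) =>
        // Left first, then right, then combine: the right result ends up on top.
        loop(Visit(l) :: Visit(r) :: Combine :: rest, results)
      case (Combine :: rest, rRes :: lRes :: more) =>
        loop(rest, red(lRes, rRes) :: more)
      case _ =>
        // Unreachable for a finite tree; only makes the match exhaustive.
        throw new IllegalStateException("unbalanced work stack")
    }

  loop(List(Visit(t)), Nil)
}

// The question's example trees and deep-copy check.
val leafs = Branch(Leaf(1), Leaf(2))
val left = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
val bal = Branch(Branch(Leaf(1), Leaf(2)), Branch(Leaf(3), Leaf(4)))
val cmb = Branch(right, Branch(bal, Branch(leafs, left)))
val trees: List[Tree[Int]] = List(leafs, left, right, bal, cmb)

val allCopiesEqual =
  trees.forall(t => fold(t)(v => (Leaf(v): Tree[Int]))(Branch(_, _)) == t)
```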