Scala: Why is this function not tail recursive?


I have this implementation of merge sort:

```scala
import scala.annotation.tailrec

object MergeSort {
  def sortBy[T]: ((T, T) => Int) => Seq[T] => Seq[T] = comparator => seqToSort => {
    @tailrec
    def merge(xs: Seq[T], ys: Seq[T], accum: Seq[T] = Seq()): Seq[T] = (xs, ys) match {
      case (Seq(), _) => ys ++ accum
      case (_, Seq()) => xs ++ accum
      case (x :: rx, y :: ry) =>
        if (comparator(x, y) < 0)
          merge(xs, ry, y +: accum)
        else
          merge(rx, ys, x +: accum)
    }

    @tailrec
    // Problem with this function
    def step: Seq[Seq[T]] => Seq[T] = {
      case Seq(xs) => xs
      case xss =>
        val afterStep = xss.grouped(2).map({
          case Seq(xs) => xs
          case Seq(xs, ys) => merge(xs, ys)
        }).toSeq
        // Error here
        step(afterStep)
    }

    step(seqToSort.map(Seq(_)))
  }
}
```

It does not compile: the compiler says that the recursive call in step is not in tail position. But it IS in tail position. Is there any way to fix this without using a trampoline?

Accepted answer by adamwy:

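The reason is that step is a parameterless method that returns a function of type Seq[Seq[T]] => Seq[T]. The expression step(afterStep) is therefore step.apply(afterStep): it first calls the method step to obtain a new function value and then calls that function's apply with the given argument. That call sits inside the body of the function literal, not at the end of the method step itself, so the compiler cannot turn it into a loop and the call is not tail recursive.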
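To see the difference in isolation, here is a minimal sketch with a hypothetical countdown written both ways (the object TailPosition and the methods countDownFn and countDown are made up for illustration). Only the second form can carry @tailrec; adding the annotation to the first should produce the same kind of "not in tail position" error as in the question.

```scala
import scala.annotation.tailrec

object TailPosition {
  // Hypothetical example shaped like `step` in the question: a parameterless
  // method whose body is a function literal. countDownFn(k - 1) really means
  // countDownFn.apply(k - 1): the method is called only to fetch a function
  // value, and that call sits inside the lambda's body, not at the tail of
  // the method. Putting @tailrec on this method is rejected by the compiler.
  def countDownFn: Int => Int = {
    case 0 => 0
    case k => countDownFn(k - 1)
  }

  // Same logic as a method that takes its argument directly: the recursive
  // call is a direct self-call and the last thing the method does, so
  // @tailrec accepts it and the compiler rewrites it as a loop.
  @tailrec
  def countDown(n: Int): Int = n match {
    case 0 => 0
    case k => countDown(k - 1)
  }
}
```

Because countDown is compiled to a loop, a deep call such as TailPosition.countDown(1000000) runs in constant stack space.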

To fix the error, declare step as a method that takes the sequence as a parameter:

```scala
@tailrec
def step(seq: Seq[Seq[T]]): Seq[T] = seq match {
  ...
}
```
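A complete version of the question's code with step rewritten this way might look like the sketch below (assuming Scala 2.13 or Scala 3). It keeps the curried, function-valued signature of sortBy, since only the method that recurses needs a parameter list. It also works on List internally and rewrites merge, because the original merge prepends the larger head to the accumulator, so merging Seq(1, 3) with Seq(2, 4) yields 1, 3, 4, 2; and once step is a method, the original cases would loop forever on an empty input. The names runs and merged are illustrative, not from the original.

```scala
import scala.annotation.tailrec

object MergeSort {
  def sortBy[T]: ((T, T) => Int) => Seq[T] => Seq[T] = comparator => seqToSort => {
    // Merge two sorted lists. The smaller head is prepended to `accum`
    // (cheap on a List), so the accumulator is reversed once at the end.
    // Taking x when comparator(x, y) <= 0 keeps the sort stable.
    @tailrec
    def merge(xs: List[T], ys: List[T], accum: List[T] = Nil): List[T] = (xs, ys) match {
      case (Nil, _) => accum.reverse ::: ys
      case (_, Nil) => accum.reverse ::: xs
      case (x :: rx, y :: ry) =>
        if (comparator(x, y) <= 0) merge(rx, ys, x :: accum)
        else merge(xs, ry, y :: accum)
    }

    // step is now a method with a parameter, so step(merged) is a direct
    // self-call in tail position and @tailrec accepts it.
    @tailrec
    def step(runs: List[List[T]]): List[T] = runs match {
      case Nil => Nil                 // empty input: nothing to merge
      case List(single) => single     // one sorted run left: done
      case _ =>
        // Merge neighbouring runs pairwise; an odd run out is kept as-is.
        val merged = runs.grouped(2).map {
          case List(xs, ys) => merge(xs, ys)
          case List(xs) => xs
        }.toList
        step(merged)
    }

    step(seqToSort.toList.map(List(_)))
  }
}
```

Usage: MergeSort.sortBy[Int]((a, b) => Integer.compare(a, b))(Seq(5, 3, 1, 4, 2)) returns List(1, 2, 3, 4, 5).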