Tail-recursive bounded stream of pairs of integers (Scala)?


I'm very new to Scala, so forgive my ignorance! I'm trying to iterate over pairs of integers that are bounded by a maximum. For example, if the maximum is 5, then the iteration should return:

(0, 0), (0, 1), ..., (0, 5), (1, 0), ..., (5, 5)

I've chosen to try and tail-recursively return this as a Stream:

    @tailrec
    def _pairs(i: Int, j: Int, maximum: Int): Stream[(Int, Int)] = {
        if (i == maximum && j == maximum) Stream.empty
        else if (j == maximum) (i, j) #:: _pairs(i + 1, 0, maximum)
        else (i, j) #:: _pairs(i, j + 1, maximum)
    }

Without the tailrec annotation the code works:

scala> _pairs(0, 0, 5).take(11)
res16: scala.collection.immutable.Stream[(Int, Int)] = Stream((0,0), ?)

scala> _pairs(0, 0, 5).take(11).toList
res17: List[(Int, Int)] = List((0,0), (0,1), (0,2), (0,3), (0,4), (0,5), (1,0), (1,1), (1,2), (1,3), (1,4))

But this isn't good enough for me. The compiler is correctly pointing out that the last line of _pairs is not returning _pairs:

could not optimize @tailrec annotated method _pairs: it contains a recursive call not in tail position
    else (i, j) #:: _pairs(i, j + 1, maximum)
                ^

So, I have several questions:

  • directly addressing the implementation above, how does one tail-recursively return Stream[(Int, Int)]?
  • taking a step back, what is the most memory-efficient way to iterate over bounded sequences of integers? I don't want to iterate over Range because Range extends IndexedSeq, and I don't want the sequence to exist entirely in memory. Or am I wrong? If I iterate over Range.view do I avoid it coming into memory?

In Python (!), all I want is:

In [6]: def _pairs(maximum):
   ...:     for i in xrange(maximum+1):
   ...:         for j in xrange(maximum+1):
   ...:             yield (i, j)
   ...:             

In [7]: p = _pairs(5)

In [8]: [p.next() for i in xrange(11)]
Out[8]: 
[(0, 0),
 (0, 1),
 (0, 2),
 (0, 3),
 (0, 4),
 (0, 5),
 (1, 0),
 (1, 1),
 (1, 2),
 (1, 3),
 (1, 4)]

Thanks for your help! If you think I need to read references / API docs / anything else please tell me, because I'm keen to learn.

There are 2 answers

Ken Bloom On BEST ANSWER

This is not tail-recursion

Let's suppose you were making a list instead of a stream: (let me use a simpler function to make my point)

def foo(n: Int): List[Int] =
  if (n == 0)
    0 :: Nil
  else
    n :: foo(n - 1)

In the general case of this recursion, after foo(n - 1) returns, the function still has to do something with the list it gets back: it has to prepend another item to it. So the function can't be tail recursive, because work remains to be done on the list after the recursive call.

Without tail recursion, for some large value of n, you run out of stack space.

The usual list solution

The usual solution would be to pass a ListBuffer as a second parameter, and fill that.

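import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

def foo(n: Int): List[Int] = {
  // The explicit return type is required because fooInternal is recursive.
  @tailrec
  def fooInternal(n: Int, list: ListBuffer[Int]): List[Int] = {
    list += n                      // append the current value, including the final 0
    if (n == 0) list.toList
    else fooInternal(n - 1, list)  // nothing is left to do afterwards: a true tail call
  }
  fooInternal(n, new ListBuffer[Int]())
}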

What you're doing is known as "tail recursion modulo cons". Some LISP and Prolog compilers optimize this pattern automatically when they see it, since it's so common. Scala's compiler does not.
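If a mutable buffer is not wanted, one possible alternative is an immutable accumulator. The sketch below assumes n >= 0 and uses the hypothetical names fooAcc and loop (neither is from the code above). Counting upward lets each value be prepended, so the result comes out in the same order as foo without a final reverse:

```scala
import scala.annotation.tailrec

// Hypothetical variant of foo: builds List(n, n - 1, ..., 0) for n >= 0.
def fooAcc(n: Int): List[Int] = {
  @tailrec
  def loop(k: Int, acc: List[Int]): List[Int] =
    if (k > n) acc               // every value from 0 to n has been prepended
    else loop(k + 1, k :: acc)   // the recursive call is the last thing that happens
  loop(0, Nil)
}
```

Because the prepend happens before the recursive call rather than after it, @tailrec accepts the method and large values of n do not exhaust the stack.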

Streams don't need tail recursion

Streams don't need tail recursion to avoid running out of stack space -- this is because they use a clever trick to keep from executing the recursive call to foo at the point where it appears in the code. The function call gets wrapped in a thunk, and is only called at the point where you actually try to get the value from the stream. Only one call to foo is active at a time -- it's never recursive.

I've written a previous answer explaining how the #:: operator works here on Stack Overflow. Here's what happens when you call the following recursive stream function. (It is recursive in the mathematical sense, but it doesn't make a function call from within a function call the way you usually expect.)

def foo(n: Int): Stream[Int] =
  if (n == 0)
    0 #:: Stream.empty
  else
    n #:: foo(n - 1)

You call foo(10), and it returns a stream with one element already computed and a tail that is a thunk which will call foo(9) the next time you need an element from the stream. foo(9) is not called right now -- rather, the call is bound to a lazy val inside the stream, and foo(10) returns immediately. When you finally do need the second value in the stream, foo(9) is called; it computes one element and sets the tail of the stream to be a thunk that will call foo(8). foo(9) returns immediately without calling foo(8). And so on...
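The walk-through above can be observed directly by recording each call. This is only a sketch: tracedFoo and calls are hypothetical names added for illustration, and the stream ends in Stream.empty. (On Scala 2.13 and later, Stream is deprecated in favour of LazyList, but it still compiles.)

```scala
import scala.collection.mutable.ArrayBuffer

// Records the argument of every call, in the order the calls happen.
val calls = ArrayBuffer.empty[Int]

// Same shape as foo, plus the bookkeeping side effect.
def tracedFoo(n: Int): Stream[Int] = {
  calls += n
  if (n == 0) 0 #:: Stream.empty
  else n #:: tracedFoo(n - 1)
}

val s = tracedFoo(10)
val callsAfterConstruction = calls.toList   // List(10): tracedFoo(9) has not run yet

val firstThree = s.take(3).toList           // List(10, 9, 8)
val callsAfterTake = calls.toList           // only the demanded elements triggered calls
```

Each call returns before the next one starts, so the stack depth stays constant no matter how many elements are eventually demanded.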

This allows you to create infinite streams without running out of memory, for example:

def countUp(start: Int): Stream[Int] = start #:: countUp(start + 1)

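(Be careful which operations you call on this stream. Anything that forces the whole stream, such as foreach, length or toList, will never finish on an infinite stream, and will fill up your whole heap if the forced elements stay reachable. Using take is a good way to work with an arbitrary prefix of the stream.)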
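For instance, a finite prefix of the infinite stream can be pulled out like this (a small sketch; firstFive and belowFour are just illustrative names):

```scala
def countUp(start: Int): Stream[Int] = start #:: countUp(start + 1)

// Only the demanded prefix is ever computed; the rest stays an unevaluated thunk.
val firstFive = countUp(1).take(5).toList            // List(1, 2, 3, 4, 5)
val belowFour = countUp(1).takeWhile(_ < 4).toList   // List(1, 2, 3)
```

Both take and takeWhile return another lazy stream, so it is the final toList on the bounded result that does the work.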

A simpler solution altogether

Instead of dealing with recursion and streams, why not just use Scala's for loop?

def pairs(maximum:Int) =
  for (i <- 0 to maximum;
       j <- 0 to maximum)
    yield (i, j)

This materializes the entire collection in memory, and returns an IndexedSeq[(Int, Int)].

If you need a Stream specifically, you can convert the first range into a Stream.

def pairs(maximum: Int) =
  for (i <- (0 to maximum).toStream;
       j <- 0 to maximum)
    yield (i, j)

This will return a Stream[(Int, Int)]. When you access a certain point in the sequence, it will be materialized into memory, and it will stick around as long as you still have a reference to any point in the stream before that element.
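The reason converting only the first range is enough: a for comprehension with yield is rewritten by the compiler into flatMap and map calls, so the collection type of the first generator decides the type of the result. Roughly, the Stream version corresponds to the following sketch (pairsDesugared is a hypothetical name used for illustration):

```scala
// Approximately what the for/yield over a Stream expands to.
def pairsDesugared(maximum: Int): Stream[(Int, Int)] =
  (0 to maximum).toStream.flatMap { i =>
    // The inner range is small (maximum + 1 elements) and is built one row at a time.
    (0 to maximum).map(j => (i, j))
  }

val firstEleven = pairsDesugared(5).take(11).toList
// List((0,0), (0,1), (0,2), (0,3), (0,4), (0,5), (1,0), (1,1), (1,2), (1,3), (1,4))
```

The outer flatMap is called on a Stream, so rows are produced lazily as the stream is traversed.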

You can get even better memory usage by converting both ranges into views.

def pairs(maximum: Int) =
  for (i <- (0 to maximum).view;
       j <- (0 to maximum).view)
    yield (i, j)

That returns a SeqView[(Int, Int),Seq[_]] that computes each element each time you need it, and doesn't store precomputed results.

You can also get an iterator (which you can only traverse once) the same way:

def pairs(maximum: Int) =
  for (i <- (0 to maximum).iterator;
       j <- (0 to maximum).iterator)
    yield (i, j)

That returns Iterator[(Int, Int)].
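The iterator version is probably the closest match to the Python generator in the question. A usage sketch, with the function renamed to the hypothetical pairsIterator so it does not clash with the definitions above:

```scala
def pairsIterator(maximum: Int): Iterator[(Int, Int)] =
  for (i <- (0 to maximum).iterator;
       j <- (0 to maximum).iterator)
    yield (i, j)

// Mirrors [p.next() for i in xrange(11)] from the question.
val firstEleven = pairsIterator(5).take(11).toList
// List((0,0), (0,1), (0,2), (0,3), (0,4), (0,5), (1,0), (1,1), (1,2), (1,3), (1,4))

// A fresh iterator is needed for each traversal.
val total = pairsIterator(5).size   // 36 pairs for maximum = 5
```

Only the current pair and the two range iterators are alive at any time, so memory use does not grow with maximum.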

user unknown On

Maybe an Iterator is better suited for you?

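class PairIterator(max: Int) extends Iterator[(Int, Int)] {
  private val side = max + 1   // each coordinate ranges over 0..max inclusive
  private var count = 0        // index of the next pair to return
  def hasNext: Boolean = count < side * side
  def next(): (Int, Int) = {
    val pair = (count / side, count % side)
    count += 1
    pair
  }
}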

val pi = new PairIterator (5)
pi.take (7).toList