Scala: Best way to filter & map in one iteration


I'm new to Scala and trying to figure out the best way to filter & map a collection. Here's a toy example to explain my problem.

Approach 1: This is pretty bad since I'm iterating through the list twice and calculating the same value in each iteration.

val N = 5
val nums = 0 until 10
val sqNumsLargerThanN = nums filter { x: Int => (x * x) > N } map { x: Int => (x * x).toString }

Approach 2: This is slightly better but I still need to calculate (x * x) twice.

val N = 5
val nums = 0 until 10
val sqNumsLargerThanN = nums collect { case x: Int if (x * x) > N => (x * x).toString }

So, is it possible to do this without iterating through the collection twice and without repeating the same calculation?

There are 9 answers

3
Gavin Schulz On

You can use collect, which applies a partial function to every element of the collection for which it is defined. Your example could be rewritten as follows:

val sqNumsLargerThanN = nums collect {
    case (x: Int) if (x * x) > N => (x * x).toString
}
2
Rex Kerr On

The typical approach is to use an iterator (if possible) or view (if iterator won't work). This doesn't exactly avoid two traversals, but it does avoid creation of a full-sized intermediate collection. You then map first and filter afterwards and then map again if needed:

xs.iterator.map(x => x*x).filter(_ > N).map(_.toString)

The advantage of this approach is that it's really easy to read and, since there are no intermediate collections, it's reasonably efficient.
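As a minimal sketch of the iterator pipeline applied to the question's data (N and nums are copied from the question; the name sqViaIterator is just for illustration), forcing the result with toList at the end:

```scala
// Values taken from the question.
val N = 5
val nums = 0 until 10

// Each element flows through map -> filter -> map one at a time;
// nothing is stored until toList consumes the iterator.
val sqViaIterator: List[String] =
  nums.iterator.map(x => x * x).filter(_ > N).map(_.toString).toList
```

The square is computed once per element, and only the final list is materialised.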

If you are asking because this is a performance bottleneck, then the answer is usually to write a tail-recursive function or use an old-style while loop. For instance, in your case:

import scala.annotation.tailrec

def sumSqBigN(xs: Array[Int], N: Int): Array[String] = {
  val ysb = Array.newBuilder[String]
  // Walk the array by index; each square is computed exactly once.
  @tailrec
  def inner(start: Int): Array[String] = {
    if (start >= xs.length) ysb.result()
    else {
      val sq = xs(start) * xs(start)
      if (sq > N) ysb += sq.toString
      inner(start + 1)
    }
  }
  inner(0)
}

You can also pass a parameter forward in inner instead of using an external builder (especially useful for sums).
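A rough sketch of what passing the parameter forward might look like for a sum. The helper name sumSquaresAbove and the choice of Long for the total are assumptions for illustration, not part of the original answer:

```scala
import scala.annotation.tailrec

// Hypothetical helper: sums the squares that exceed n, threading the
// running total through the recursion instead of mutating a builder.
def sumSquaresAbove(xs: Array[Int], n: Int): Long = {
  @tailrec
  def inner(i: Int, acc: Long): Long =
    if (i >= xs.length) acc
    else {
      val sq = xs(i).toLong * xs(i) // computed once per element
      inner(i + 1, if (sq > n) acc + sq else acc)
    }
  inner(0, 0L)
}
```

For the question's data, sumSquaresAbove((0 until 10).toArray, 5) adds up 9, 16, 25, 36, 49, 64 and 81.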

1
adelbertc On

You could use a foldRight:

nums.foldRight(List.empty[Int]) {
  case (i, is) =>
    val s = i * i
    if (s > N) s :: is else is
}

A foldLeft would also achieve a similar goal, but the resulting list would be in reverse order (foldLeft traverses from the left, so prepending each match puts the last one first).
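For comparison, a small sketch of the foldLeft variant with one reverse at the end to restore the original order (N and nums are the question's values; sqViaFoldLeft is an illustrative name):

```scala
// Values taken from the question.
val N = 5
val nums = 0 until 10

// foldLeft visits elements left to right and prepends each match,
// so the accumulator ends up reversed; a final reverse fixes the order.
val sqViaFoldLeft: List[Int] =
  nums.foldLeft(List.empty[Int]) { (acc, i) =>
    val s = i * i
    if (s > N) s :: acc else acc
  }.reverse
```

The reverse is a second pass, but only over the (usually smaller) filtered result.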

Alternatively, if you'd like to play with Scalaz:

import scalaz.std.list._
import scalaz.syntax.foldable._

// These imports provide Foldable for List, not for Range, hence the toList.
nums.toList.foldMap { i =>
  val s = i * i
  if (s > N) List(s) else List()
}
2
marios On

A very simple approach that only does the multiplication operation once. It's also lazy, so it will be executing code only when needed.

nums.view.map(x => x * x).withFilter(x => x > N).map(_.toString)

Take a look here for differences between filter and withFilter.
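To see the laziness for yourself, here is a small sketch that counts how often the squaring function runs. It assumes Scala 2.13 or later (view behaviour differs in older versions), and the counter and value names are made up for the demonstration:

```scala
// Values taken from the question.
val N = 5
val nums = 0 until 10

// Illustrative counter: incremented every time a square is computed.
var squareCalls = 0

val lazyPipeline =
  nums.view
    .map { x => squareCalls += 1; x * x }
    .withFilter(_ > N)
    .map(_.toString)

// Nothing has been evaluated yet; building the view does no work.
val callsBeforeForcing = squareCalls

// Forcing the view runs the whole pipeline over the elements.
val sqViaView: List[String] = lazyPipeline.toList
```

Here callsBeforeForcing stays at 0, and the squares are only computed once toList forces the view.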

5
triggerNZ On

I have yet to confirm that this is truly a single pass, but:

  val sqNumsLargerThanN = nums flatMap { x =>
    val square = x * x
    if (square > N) Some(square.toString) else None
  }
0
Ramón J Romero y Vigil On

Using a for comprehension would work:

val sqNumsLargerThanN = for {x <- nums if x*x > N } yield (x*x).toString

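Also, the guard in a for comprehension desugars to withFilter followed by map, so no intermediate collection is built. As far as I know the compiler does not fuse an explicit filter and map into one pass for you, and x*x is still computed twice here.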

0
Paul Draper On

You can use flatMap.

val sqNumsLargerThanN = nums flatMap { x =>
  val square = x * x
  if (square > N) Some(square.toString) else None
}

Or with Scalaz,

import scalaz.Scalaz._

val sqNumsLargerThanN = nums flatMap { x =>
  val square = x * x
  (square > N).option(square.toString)
}

This solves the question as asked: how to do this in one iteration. This can be useful when streaming data, like with an Iterator.

However, if you instead want the absolute fastest implementation, this is not it. In fact, I suspect you would use a mutable buffer (an ArrayBuffer in Scala, or a Java ArrayList) and a while loop. But only after profiling would you know for sure. In any case, that's for another question.
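For what it's worth, a sketch of that while-loop version might look like the following. The function name squaresAboveLoop and the use of ArrayBuffer are assumptions on my part, and I have not benchmarked it:

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical helper: one index-based pass, one multiplication per element.
def squaresAboveLoop(xs: Array[Int], n: Int): ArrayBuffer[String] = {
  val out = ArrayBuffer.empty[String]
  var i = 0
  while (i < xs.length) {
    val sq = xs(i) * xs(i)
    if (sq > n) out += sq.toString
    i += 1
  }
  out
}
```

Calling squaresAboveLoop((0 until 10).toArray, 5) yields the same strings as the flatMap version, in the same order.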

0
elm On

Consider this for comprehension,

  for (x <- 0 until 10; v = x*x if v > N) yield v.toString

which desugars to a map over the range that pairs each x with its square v, a (lazy) withFilter on v, and a final map that yields the filtered results. Note that the square is calculated only once per element, although the first map does build an intermediate collection of pairs (in addition to creating the range).
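A sketch of roughly what the compiler generates for that comprehension, written out by hand (the exact generated code may differ in details; viaFor and desugared are illustrative names):

```scala
// Value taken from the question.
val N = 5

// The comprehension from the answer above.
val viaFor = for (x <- 0 until 10; v = x * x if v > N) yield v.toString

// Hand-written approximation of the desugared form:
// the value definition becomes a map to (x, v) pairs,
// the guard becomes withFilter, and the yield becomes a final map.
val desugared =
  (0 until 10)
    .map { x => val v = x * x; (x, v) }
    .withFilter { case (_, v) => v > N }
    .map { case (_, v) => v.toString }
```

Both expressions produce the same sequence of strings.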

0
gauri On

I am also a beginner; I did it as follows:

 for (y <- nums.map(x => x * x) if y > N) { println(y) }