How can I implement a Fisher-Yates shuffle in Scala without side effects?


I want to implement the Fisher-Yates algorithm (an in-place array shuffle) without side effects by using an STArray for the local mutation effects, and a functional random number generator

type RNG[A] = State[Seed,A]

to produce the random integers needed by the algorithm.

I have a method def intInRange(max: Int): RNG[Int] which I can use to produce a random Int in [0,max).
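The question doesn't show how `Seed` or `intInRange` are defined. For experimenting without Scalaz, here is a minimal plain-Scala sketch that models `RNG[A]` as an explicit function `Seed => (Seed, A)`. It returns the pair in the same `(newState, value)` order that Scalaz's `State#run` uses. The 48-bit linear congruential `Seed` (constants borrowed from `java.util.Random`) and the object name `RngSketch` are assumptions for illustration only.

```scala
object RngSketch {
  // Hypothetical Seed: a 48-bit linear congruential generator using the
  // constants from java.util.Random. The question does not define Seed.
  final case class Seed(value: Long) {
    def next: Seed = Seed((value * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL)
  }

  // Plain-function stand-in for scalaz's State[Seed, A]; the result pair is
  // (newSeed, value), the same order scalaz's State#run uses.
  type RNG[A] = Seed => (Seed, A)

  // A random Int in [0, max). Taking the top 31 of the 48 bits keeps the value
  // non-negative; the final % leaves a small modulo bias, fine for a sketch.
  def intInRange(max: Int): RNG[Int] = { seed =>
    require(max > 0, "max must be positive")
    val s2 = seed.next
    (s2, (s2.value >>> 17).toInt % max)
  }
}
```

Each call consumes one seed and hands back the next, so drawing several numbers means feeding the returned seed into the following call; that plumbing is exactly what `State` hides.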

From Wikipedia:

To shuffle an array a of n elements (indices 0..n-1):
    for i from n − 1 downto 1 do
        j ← random integer such that 0 ≤ j ≤ i
        exchange a[j] and a[i]

I suppose I need to stack State with ST somehow, but this is confusing to me. Do I need a [S]StateT[ST[S,?],Seed,A]? Do I have to rewrite RNG to use StateT as well?

(Edit) I don't want to involve IO, and I don't want to substitute Vector for STArray because the shuffle wouldn't be performed in-place.

I know there is a Haskell implementation here, but I'm not currently capable of understanding and porting this to Scalaz. But maybe you can? :)

Thanks in advance.


There are 3 answers

Apocalisp (best answer)

Here is a more or less direct translation from the Haskell version you linked that uses a mutable STArray. The Scalaz STArray doesn't have an exact equivalent of the listArray function, so I've made one up. Otherwise, it's a straightforward transliteration:

import scalaz._
import scalaz.effect.{ST, STArray}
import ST._
import State._
import syntax.traverse._
import std.list._

def shuffle[A:Manifest](xs: List[A]): RNG[List[A]] = {
  def newArray[S](n: Int, as: List[A]): ST[S, STArray[S, A]] =
    if (n <= 0) newArr(0, null.asInstanceOf[A])
    else for {
      r <- newArr[S,A](n, as.head)
      _ <- r.fill((_, a: A) => a, as.zipWithIndex.map(_.swap))
    } yield r
  for {
    seed <- get[Seed]
    n = xs.length
    r <- runST(new Forall[({type λ[σ] = ST[σ, RNG[List[A]]]})#λ] {
      def apply[S] = for {
        g <- newVar[S](seed)
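        // Draw an Int in [lo, hi) and store the updated seed back into g
        randomRST = (lo: Int, hi: Int) => for {
          p <- g.read.map(intInRange(hi - lo).apply)
          (sp, a) = p // scalaz's State#apply returns (newState, value)
          _ <- g.write(sp)
        } yield a + lo
        ar  <- newArray[S](n, xs)
        // For each i, pick j in [i, n), emit ar(j) and move ar(i) into slot j
        xsp <- Range(0, n).toList.traverseU { i => for {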
          j  <- randomRST(i, n)
          vi <- ar read i
          vj <- ar read j
          _  <- ar.write(j, vi)
        } yield vj }
        genp <- g.read
      } yield put(genp).map(_ => xsp)
    })
  } yield r
}

Although the asymptotics of using a mutable array might be good, do note that the constant factors of the ST monad in Scala are quite large. You may be better off just doing this in a monolithic block using regular mutable arrays. The overall shuffle function remains pure because all of your mutable state is local.
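As a rough sketch of that suggestion, here is a plain-Scala version of the Wikipedia loop that mutates a local `Array` and threads the seed by hand. The `Seed` case class and `intInRange` below are hypothetical stand-ins (a 48-bit LCG), not the asker's definitions; `LocalArrayShuffle` is likewise an illustrative name.

```scala
import scala.reflect.ClassTag

object LocalArrayShuffle {
  // Hypothetical seed type (48-bit LCG, java.util.Random constants);
  // substitute the Seed/intInRange from the question.
  final case class Seed(value: Long) {
    def next: Seed = Seed((value * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL)
  }

  // (newSeed, n) with n in [0, max); small modulo bias, acceptable for a sketch.
  def intInRange(max: Int)(seed: Seed): (Seed, Int) = {
    val s2 = seed.next
    (s2, (s2.value >>> 17).toInt % max)
  }

  // Referentially transparent: the Array is created, mutated and read back
  // entirely inside this method, and the seed is threaded explicitly, so the
  // same (xs, seed) always yields the same (newSeed, shuffled list).
  def shuffle[A: ClassTag](xs: List[A])(seed: Seed): (Seed, List[A]) = {
    val a = xs.toArray
    var s = seed
    var i = a.length - 1
    while (i >= 1) {
      val (s2, j) = intInRange(i + 1)(s) // 0 <= j <= i, as in the pseudocode
      s = s2
      val tmp = a(i)
      a(i) = a(j)
      a(j) = tmp
      i -= 1
    }
    (s, a.toList)
  }
}
```

Because the array never escapes `shuffle` and the seed is passed in and handed back, callers cannot observe the mutation. With the question's own `Seed`, a function of this shape could be lifted into `RNG` through Scalaz's `State(f)` constructor.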

Travis Brown

You have lots of options. One simple (but not very principled) one would be just to lift both the Rng and ST operations into IO and then work with them together there. Another would be to use both an STRef[Long] and an STArray in the same ST. Another would be to use a State[(Long, Vector[A]), ?].
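To make the `State[(Long, Vector[A]), ?]` option concrete without Scalaz, the sketch below treats the pair `(seed, vector)` as the whole state and folds one transition per index. The `nextInt` LCG step and the names `PairStateShuffle` and `step` are hypothetical; a real version would reuse whatever generator the codebase already has.

```scala
object PairStateShuffle {
  // Hypothetical RNG step on a raw Long seed (48-bit LCG, java.util.Random
  // constants): returns (nextSeed, n) with n in [0, max).
  def nextInt(max: Int)(seed: Long): (Long, Int) = {
    val s2 = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL
    (s2, (s2 >>> 17).toInt % max)
  }

  // One transition of the pair state (seed, vector): the work a single
  // State[(Long, Vector[A]), Unit] action would do for index i.
  def step[A](st: (Long, Vector[A]), i: Int): (Long, Vector[A]) = {
    val (seed, v) = st
    val (seed2, j) = nextInt(i + 1)(seed) // 0 <= j <= i
    (seed2, v.updated(i, v(j)).updated(j, v(i)))
  }

  // Running the steps for i = n-1 down to 1 is a left fold over the indices.
  def shuffle[A](xs: Vector[A])(seed: Long): (Long, Vector[A]) =
    (xs.size - 1 to 1 by -1).foldLeft((seed, xs))((st, i) => step(st, i))
}
```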

You could also use a StateT[State[Long, ?], Vector[A], ?] but that would be kind of pointless. You could probably use a StateT (for the RNG state) over an ST (for the array), but again, I don't really see the point.

It's possible to do this pretty cleanly without side effects with just Rng, though. For example, using NICTA's RNG library:

import com.nicta.rng._, scalaz._, Scalaz._

def shuffle[A](xs: Vector[A]): Rng[Vector[A]] =
  (xs.size - 1 to 1 by -1).toVector.traverseU(
    i => Rng.chooseint(0, i).map((i, _))
  ).map {
    _.foldLeft(xs) {
      // foldLeft passes (accumulator, element), so the vector comes first
      case (v, (i, j)) =>
        val tmp = v(i)
        v.updated(i, v(j)).updated(j, tmp)
    }
  }

Here you just pick all your swap operations in the Rng monad, and then fold over them with your collection as the accumulator, swapping as you go.
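The same two-phase idea can be sketched in plain Scala without NICTA's library: first thread a seed to draw every `(i, j)` pair, then apply the pairs with a pure `foldLeft`. The `Seed` LCG and `intInRange` stand in for `Rng.chooseint` and, like the object and method names, are assumptions for illustration.

```scala
object TwoPhaseShuffle {
  // Hypothetical seed type (48-bit LCG, java.util.Random constants).
  final case class Seed(value: Long) {
    def next: Seed = Seed((value * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL)
  }

  // (newSeed, n) with n in [0, max); small modulo bias.
  def intInRange(max: Int)(seed: Seed): (Seed, Int) = {
    val s2 = seed.next
    (s2, (s2.value >>> 17).toInt % max)
  }

  // Phase 1: draw every swap (i, j) with 0 <= j <= i for i = n-1 down to 1,
  // threading the seed. This is the part done in the Rng monad above.
  def chooseSwaps(n: Int)(seed: Seed): (Seed, List[(Int, Int)]) = {
    val (finalSeed, revSwaps) =
      (n - 1 to 1 by -1).foldLeft((seed, List.empty[(Int, Int)])) {
        case ((s, acc), i) =>
          val (s2, j) = intInRange(i + 1)(s)
          (s2, (i, j) :: acc)
      }
    (finalSeed, revSwaps.reverse) // restore the n-1 ... 1 order
  }

  // Phase 2: a deterministic fold that applies the swaps in order.
  def applySwaps[A](xs: Vector[A], swaps: List[(Int, Int)]): Vector[A] =
    swaps.foldLeft(xs) { case (v, (i, j)) => v.updated(i, v(j)).updated(j, v(i)) }

  def shuffle[A](xs: Vector[A])(seed: Seed): (Seed, Vector[A]) = {
    val (s2, swaps) = chooseSwaps(xs.size)(seed)
    (s2, applySwaps(xs, swaps))
  }
}
```

Separating the phases keeps all randomness in `chooseSwaps`; `applySwaps` is an ordinary function that can be tested on fixed inputs.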

miguel

This is almost the same as Travis's solution; the only difference is that it uses the State monad to apply the exchanges. I tried to find a minimal set of imports but eventually gave up:

import com.nicta.rng.Rng
import scalaz._
import Scalaz._

object FisherYatesShuffle {

  def randomJ(i: Int): Rng[Int] = Rng.chooseint(0,i)

  type Exchange = (Int,Int)

  def applyExchange[A](exchange: Exchange)(l: Vector[A]): Vector[A] = {
    val (i,j) = exchange
    val vi = l(i)
    l.updated(i,l(j)).updated(j,vi)
  }

  def stApplyExchange[A](exchange: Exchange): State[Vector[A], Unit] = State.modify(applyExchange(exchange))

  def shuffle[A](l: Vector[A]): Rng[Vector[A]] = {
    val rngExchanges: Rng[Vector[Exchange]] = (l.length - 1 to 1 by -1).toVector.traverseU { i =>
      for {
        j <- randomJ(i)
      } yield (i, j)
    }

    for {
      exchanges <- rngExchanges
    } yield exchanges.traverseU(stApplyExchange[A]).exec(l)
  }

}
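Since `State.modify` only transforms the state, `exchanges.traverseU(stApplyExchange[A]).exec(l)` amounts to applying each exchange to the vector in order, left to right. The sketch below makes that explicit by composing the swaps as ordinary `Vector[A] => Vector[A]` functions with `andThen`; `applyAll` and `ExchangeComposition` are hypothetical names, and `applyExchange` mirrors the one above.

```scala
object ExchangeComposition {
  type Exchange = (Int, Int)

  // Same swap as applyExchange in the answer above.
  def applyExchange[A](exchange: Exchange)(l: Vector[A]): Vector[A] = {
    val (i, j) = exchange
    l.updated(i, l(j)).updated(j, l(i))
  }

  // State.modify(f) only transforms the state, so traversing the exchanges with
  // modify and then calling exec(l) applies the functions left to right.
  // Without State, that is just composing them with andThen, starting from identity.
  def applyAll[A](exchanges: Seq[Exchange]): Vector[A] => Vector[A] =
    exchanges.foldLeft((l: Vector[A]) => l) { (acc, e) =>
      acc.andThen((l: Vector[A]) => applyExchange(e)(l))
    }
}
```

Order matters: with the exchanges `(2, 0)` then `(2, 1)`, `Vector('a', 'b', 'c')` becomes `Vector('c', 'a', 'b')`, whereas the reverse order gives `Vector('b', 'c', 'a')`.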