Implementing factorial and fibonacci using State monad (as a learning exercise)


I worked my way through Mike Vanier's monad tutorial (which is excellent) and I'm working on a few of the exercises in his post on how to use a "State" monad.

In particular, he suggests an exercise which consists of writing functions for factorial and fibonacci using a State monad. I gave it a shot and came up with the answers below. (I find do notation pretty confusing, hence my choice of syntax).

Neither of my implementations looks particularly "Haskell-y" and, in the interest of not internalizing bad practices, I thought I'd ask how others would have implemented these functions (still using the State monad). Is it possible to write this code far more simply (aside from switching to do notation)? I strongly suspect so.


I'm aware that it's a bit impractical to use a State monad for this purpose, but this is purely a learning exercise (pun most certainly intended).

That said, the performance is not that much worse: to calculate the factorial of 100000 (the answer is about 456k digits long), the product-based reference version took ~1.2 sec (in GHCi) vs. ~1.5 sec for the State monad version.
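If the State version is the slower one, a possible cause (an assumption, not something measured here) is that the fact_state in this post uses the lazy Control.Monad.State and stores the new accumulator f*n unevaluated, so a long chain of multiplications builds up in the state until the loop ends. Below is a minimal sketch of the same loop that uses Control.Monad.State.Strict and forces the accumulator with seq at every step; factStateStrict and factorialStrict are hypothetical names, not from the original post.

```haskell
import Control.Monad.State.Strict (State, get, put, evalState)

-- Same loop as the question's fact_state, with two changes:
--   * the strict State monad (Control.Monad.State.Strict) is used, and
--   * the new accumulator is forced with seq before it is stored,
--     so the state does not hold a growing chain of unevaluated (*) thunks.
factStateStrict :: State (Integer, Integer) Integer
factStateStrict = get >>= \(n, f) ->
  if n == 0
    then return f
    else
      let f' = f * n
      in f' `seq` (put (n - 1, f') >> factStateStrict)

-- Start with the counter at n and the accumulator at 1.
factorialStrict :: Integer -> Integer
factorialStrict n = evalState factStateStrict (n, 1)
```

Whether this closes the timing gap in GHCi is not claimed here; timings in GHCi (interpreted, unoptimized) can differ a lot from compiled code.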

import Control.Monad.State (State, get, put, evalState)
import Data.List (unfoldr)

fibonacci :: Integer -> Integer
fibonacci 0 = 0
fibonacci n = evalState fib_state (1,0,1,n)

fib_state :: State (Integer,Integer,Integer,Integer) Integer
fib_state = get >>=
            \s ->
              let (p1,p2,ctr,n) = s
              in case compare ctr n of
                   LT -> put (p1+p2, p1, ctr+1, n) >> fib_state
                   _  -> return p1

factorial :: Integer -> Integer
factorial n = evalState fact_state (n,1)

fact_state :: State (Integer,Integer) Integer
fact_state = get >>=
             \s -> 
               let (n,f) = s 
               in case n of
                      0 -> return f
                      _ -> put (n-1,f*n) >> fact_state

-------------------------------------------------------------------
--Functions below are used only to test output of functions above

factorial' :: Integer -> Integer
factorial' n = product [1..n]

fibonacci' :: Integer -> Integer
fibonacci' 0 = 0
fibonacci' 1 = 1
fibonacci' n =
  let getFst (a,_,_) = a
  in  getFst
    $ last 
    $ unfoldr (\(p1,p2,cnt) -> 
               if cnt == n
                  then Nothing
                  else Just ((p1,p2,cnt)
                            ,(p1+p2,p1,cnt+1))
              ) (1,1,1) 
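One way to make both functions shorter, sketched here as a swapped-in technique rather than anything from the answers below, is to let a library combinator drive the loop instead of writing the recursion and the counter by hand: mapM_ walks a list for the factorial, and replicateM_ repeats a single step for Fibonacci. This assumes mtl's modify' (a strict modify) is available from Control.Monad.State.Strict.

```haskell
import Control.Monad (replicateM_)
import Control.Monad.State.Strict (State, execState, modify')

-- Factorial: start the state at 1 and multiply it by each k in [1..n].
-- mapM_ drives the loop, so no explicit recursion or counter is needed;
-- execState returns the final state rather than a result value.
factorial :: Integer -> Integer
factorial n = execState (mapM_ (\k -> modify' (* k)) [1 .. n]) 1

-- Fibonacci: the state is the pair (F(i), F(i+1)); each step shifts it
-- along by one. seq forces the new sum when modify' forces the pair.
fibonacci :: Int -> Integer
fibonacci n = fst (execState (replicateM_ n step) (0, 1))
  where
    step :: State (Integer, Integer) ()
    step = modify' (\(a, b) -> let c = a + b in c `seq` (b, c))
```

Here fibonacci 0 is 0 and fibonacci 10 is 55, matching the question's fibonacci.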


Mokosha (best answer)

Your functions seem to be a bit more complicated than they need to be, but you have the right idea. For the factorial, all you need to keep track of is the current number you're multiplying by and the product you've accumulated so far. So we'll say that State Int Int is a computation that keeps the current number in its state and returns the product accumulated so far:

fact_state :: State Int Int
fact_state = get >>= \x -> if x <= 1
                           then return 1
                           else (put (x - 1) >> fmap (*x) fact_state)

factorial :: Int -> Int
factorial = evalState fact_state
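The version above uses Int, which on a typical 64-bit GHC wraps around for anything past 20!, while the question works with Integer (e.g. 100000!). A minimal sketch of the same structure at type Integer follows; factStateI and factorialI are hypothetical names chosen to avoid clashing with the code above.

```haskell
import Control.Monad.State (State, get, put, evalState, runState)

-- Same structure as fact_state above, but at type Integer, so results
-- beyond 20! do not wrap around as they would with a 64-bit Int.
-- The state is the number still to be multiplied in; the product is
-- built up in the result, since each level applies fmap (* x).
factStateI :: State Integer Integer
factStateI = get >>= \x ->
  if x <= 1
    then return 1
    else put (x - 1) >> fmap (* x) factStateI

factorialI :: Integer -> Integer
factorialI = evalState factStateI
```

Running it with runState shows where each part ends up: runState factStateI 5 gives (120,1), i.e. the product is the result and the state has counted down to 1.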

Prelude Control.Monad.State.Strict Control.Applicative> factorial <$> [1..10]
[1,2,6,24,120,720,5040,40320,362880,3628800]

The Fibonacci sequence is similar. You need to keep the last two numbers in order to know what you're going to be adding together, and how many steps are left:

fibs_state :: State (Int, Int, Int) Int
fibs_state = get >>= \(x1, x2, n) -> if n == 0
                                     then return x1
                                     else (put (x2, x1+x2, n-1) >> fibs_state)

fibonacci :: Int -> Int
fibonacci n = evalState fibs_state (0, 1, n)
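A variation on the same idea, offered as a sketch rather than as part of the answer: if each step returns its current x1, the counter can be dropped from the state entirely and replicateM (from Control.Monad) collects the first n values of the sequence. The names fibStep and fibsUpTo are hypothetical.

```haskell
import Control.Monad (replicateM)
import Control.Monad.State (State, get, put, evalState)

-- One step: return the current x1 and shift the pair along, as fibs_state
-- does, but without a counter in the state.
fibStep :: State (Integer, Integer) Integer
fibStep = get >>= \(x1, x2) -> put (x2, x1 + x2) >> return x1

-- replicateM runs fibStep n times and collects each returned value,
-- giving F(0) .. F(n-1).
fibsUpTo :: Int -> [Integer]
fibsUpTo n = evalState (replicateM n fibStep) (0, 1)
```

For example, fibsUpTo 10 gives [0,1,1,2,3,5,8,13,21,34].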

Prelude Control.Monad.State.Strict Control.Applicative> fibonacci <$> [1..10]
[1,1,2,3,5,8,13,21,34,55]
ErikR

Two stylistic suggestions:

        \s ->
          let (p1,p2,ctr,n) = s
          in ...

is equivalent to:

        \(p1,p2,ctr,n) -> ...

and the case expression in fib_state can be written as an if expression:

        if ctr < n
          then put (p1+p2, p1, ctr+1, n) >> fib_state
          else return p1
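Putting both suggestions together, the question's fib_state might look like the sketch below; it keeps the question's names and state layout and only applies the two rewrites described above.

```haskell
import Control.Monad.State (State, get, put, evalState)

-- The question's fib_state with both suggestions applied: the tuple is
-- pattern-matched directly in the lambda, and case/compare becomes if.
fib_state :: State (Integer, Integer, Integer, Integer) Integer
fib_state = get >>= \(p1, p2, ctr, n) ->
  if ctr < n
    then put (p1 + p2, p1, ctr + 1, n) >> fib_state
    else return p1

fibonacci :: Integer -> Integer
fibonacci 0 = 0
fibonacci n = evalState fib_state (1, 0, 1, n)
```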