How can I refactor this without IORefs?


How could I refactor this so that eventually IORefs would not be necessary?

inc :: IORef Int -> IO ()
inc ref = modifyIORef ref (+1)

main = withSocketsDo $ do
        s <- socket AF_INET Datagram defaultProtocol
        c <- newIORef 0
        f <- newIORef 0
        hostAddr <- inet_addr host
        time $ forM [0 .. 10000] $ \i -> do
              sendAllTo s (B.pack  "ping") (SockAddrInet port hostAddr)
              (r, _) <- recvFrom s 1024 
              if (B.unpack r) == "PING" then (inc c) else (inc f)
        c' <- readIORef c
        print (c')
        sClose s
        return()

There are 2 answers

Dan Burton (best answer)

What's wrong with using IORefs here? You're in IO anyways with the networking operations. IORefs aren't always the cleanest solution, but they seem to do the job well in this case.

Regardless, for the sake of answering the question, let's remove the IORefs. These references serve as a way of keeping state, so we'll have to come up with an alternate way to keep the stateful information.

The pseudocode for what we want to do is this:

open the connection
10000 times: 
  send a message
  receive the response
  (keep track of how many responses are the message "PING")
print how many responses were the message "PING"

The chunk that is indented under "10000 times" can be abstracted into its own function. If we are to avoid IORefs, then this function will have to take in a previous state and produce a next state.

main = withSocketsDo $ do
  s <- socket AF_INET Datagram defaultProtocol
  hostAddr <- inet_addr host
  let sendMsg = sendAllTo s (B.pack  "ping") (SockAddrInet port hostAddr)
      recvMsg = fst `fmap` recvFrom s 1024
  (c,f) <- ???
  print c
  sClose s

So the question is this: what do we put at the ??? place? We need to define some way to "perform" an IO action, take its result, and modify state with that result somehow. We also need to know how many times to do it.

 performRepeatedlyWithState :: a             -- some state
                            -> IO b          -- some IO action that yields a value
                            -> (a -> b -> a) -- some way to produce a new state
                            -> Int           -- how many times to do it
                            -> IO a          -- the resultant state, as an IO action
 performRepeatedlyWithState s _ _ 0 = return s
 performRepeatedlyWithState someState someAction produceNewState timesToDoIt = do
   actionResult <- someAction
   let newState = produceNewState someState actionResult
   performRepeatedlyWithState newState someAction produceNewState (pred timesToDoIt)

All I did here was write down the type signature that matched what I said above, and then write the relatively obvious implementation. I gave everything a very verbose name to hopefully make it apparent exactly what this function means. Equipped with this simple function, we just need to use it.

let origState = (0,0)
    action = ???
    mkNewState = ???
    times = 10000
(c,f) <- performRepeatedlyWithState origState action mkNewState times

I've filled in the easy parameters here. The original state is (c,f) = (0,0), and we want to perform this 10000 times. (Strictly, the original loop over [0 .. 10000] runs 10001 times.) But what should action and mkNewState look like? The action should have type IO b; it's some IO action that produces something.

action = sendMsg >> recvMsg

I bound sendMsg and recvMsg to expressions from your code earlier. The action we want to perform is to send a message, and then receive a message. The value this action produces is the message received.

Now, what should mkNewState look like? It should have the type a -> b -> a, where a is the type of the State, and b is the type of the action result.

mkNewState (c,f) val = if (B.unpack val) == "PING"
                         then (succ c, f)
                         else (c, succ f)
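To see the pieces working together without a server, here is a minimal sketch that assembles them into one file. The name fakeAction is a hypothetical stand-in for `sendMsg >> recvMsg` (it always "receives" PING), and the helper is repeated in shortened form so the block compiles on its own.

```haskell
import qualified Data.ByteString.Char8 as B

-- Same helper as above, with shorter names.
performRepeatedlyWithState :: a -> IO b -> (a -> b -> a) -> Int -> IO a
performRepeatedlyWithState s _ _ 0 = return s
performRepeatedlyWithState s action next n = do
  result <- action
  performRepeatedlyWithState (next s result) action next (pred n)

-- (c, f): c counts "PING" replies, f counts everything else.
mkNewState :: (Int, Int) -> B.ByteString -> (Int, Int)
mkNewState (c, f) val =
  if B.unpack val == "PING"
    then (succ c, f)
    else (c, succ f)

-- Hypothetical stand-in for the network round trip (an assumption,
-- not part of the original code): always yields "PING".
fakeAction :: IO B.ByteString
fakeAction = return (B.pack "PING")

main :: IO ()
main = do
  (c, f) <- performRepeatedlyWithState (0, 0) fakeAction mkNewState 10000
  print (c, f)  -- prints (10000,0)
```

Swapping fakeAction for the real `sendMsg >> recvMsg` should give the socket version, since both have type IO ByteString.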

This isn't the cleanest solution, but do you get the general idea? You can replace IORefs by writing a function that recursively calls itself, passing extra parameters along in order to keep track of state. The exact same idea is embodied in the foldM solution suggested on the similar question.
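For comparison, a rough sketch of how the same state threading might look with foldM from Control.Monad. The function fakeRoundTrip is a hypothetical stand-in for the send/receive pair, and plain String replaces ByteString to keep the example short; neither comes from the original code.

```haskell
import Control.Monad (foldM)

-- Hypothetical stand-in for the network round trip: even-numbered
-- iterations "receive" PING, odd-numbered ones receive something else.
fakeRoundTrip :: Int -> IO String
fakeRoundTrip i = return (if even i then "PING" else "PONG")

-- One iteration: run the action, then return the updated counters.
step :: (Int, Int) -> Int -> IO (Int, Int)
step (c, f) i = do
  r <- fakeRoundTrip i
  return (if r == "PING" then (c + 1, f) else (c, f + 1))

main :: IO ()
main = do
  (c, f) <- foldM step (0, 0) [1 .. 10000]
  print (c, f)  -- prints (5000,5000)
```

Here foldM plays the role of performRepeatedlyWithState: it feeds the state returned by one step into the next.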

Bang patterns, as Nathan Howell suggests, would be wise, to avoid building up a large thunk of succ (succ (succ ...)) in your state:

mkNewState (!c, !f) val = ...
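One caveat: the bangs only take effect when mkNewState is actually evaluated, and the helper as written above leaves each new state as an unevaluated thunk until the end. A minimal sketch of a variant that also forces the state on every iteration follows; the primed name and the String-based stand-in action are assumptions made for illustration.

```haskell
{-# LANGUAGE BangPatterns #-}

-- Variant of the helper (hypothetical name) that evaluates each new
-- state to weak head normal form before recursing.
performRepeatedlyWithState' :: a -> IO b -> (a -> b -> a) -> Int -> IO a
performRepeatedlyWithState' s _ _ 0 = return s
performRepeatedlyWithState' s action next n = do
  result <- action
  let !s' = next s result  -- strict binding: runs mkNewState's pattern match now
  performRepeatedlyWithState' s' action next (pred n)

-- The bangs force both counters whenever the function is evaluated,
-- so neither accumulates a chain of succ thunks.
mkNewState :: (Int, Int) -> String -> (Int, Int)
mkNewState (!c, !f) val
  | val == "PING" = (succ c, f)
  | otherwise     = (c, succ f)

main :: IO ()
main = do
  r <- performRepeatedlyWithState' (0, 0) (return "PING") mkNewState 1000000
  print r  -- prints (1000000,0)
```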
Nathan Howell

This builds on the earlier comment regarding a stack overflow.

The accumulators 'f' and 'c' in either the IORef or foldM case need to be evaluated to prevent a long chain of thunks from being allocated while you're iterating. One way of forcing evaluation of the thunks is to use a bang pattern. This tells the compiler to evaluate the value, removing the thunk, even though its value is not demanded in the function.

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}

import Control.Concurrent
import Control.Monad
import Data.ByteString.Char8
import Data.Foldable (foldlM)
import Data.IORef
import Network.Socket hiding (recvFrom)
import Network.Socket.ByteString (recvFrom, sendAllTo)

main = withSocketsDo $ do
  let host = "127.0.0.1"
      port = 9898

  s <- socket AF_INET Datagram defaultProtocol
  hostAddr <- inet_addr host
  -- explicitly mark both accumulators as strict using bang patterns
  let step (!c, !f) i = do
        sendAllTo s "PING" (SockAddrInet port hostAddr)
        (r, _) <- recvFrom s 1024 
        return $ case r of
          -- without the bang patterns above, c and f would never be demanded
          -- inside the loop, so the additions below would build a thunk chain.
          -- that chain can overflow the stack when 'print c' finally evaluates it.
          "PING" -> (c+1, f)
          _      -> (c, f+1)
  (c, f) <- foldlM step (0, 0) [0..10000]
  print c
  sClose s
  return ()
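For the IORef case mentioned above, one possible fix (a sketch, not taken from the original code) is the strict modifyIORef' from Data.IORef, which evaluates the new value before storing it. The loop below only counts iterations and leaves out the socket code entirely.

```haskell
import Control.Monad (forM_)
import Data.IORef (modifyIORef', newIORef, readIORef)

main :: IO ()
main = do
  c <- newIORef (0 :: Int)
  -- modifyIORef' forces the result of (+ 1) on every iteration,
  -- so the reference never holds a chain of pending additions.
  forM_ [1 .. 10000 :: Int] $ \_ -> modifyIORef' c (+ 1)
  readIORef c >>= print  -- prints 10000
```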