Duplicate EDSL code generated with RecursiveDo in Haskell


Using GHC 8.0.1, compiled with the -O2 flag.

I'm having a problem with RecursiveDo (mdo). I have two slightly different functions that should produce the same code (apart from the shift count), but they don't.

The following function produces the correct output:

proc2 :: Assembler ()
proc2 = mdo
    set (R 0) (I 0x5a5a)
    let r = (R 0)
    let bits = (I 2)
    let count = (R 70)
    set count bits
    _loop <- label
    cmp count (I 0)
    je _end
    add r r
    sub count (I 1)
    jmp _loop
    _end <- label
    end

The correct output is

0000:> SET (R 0) (I 23130)
0001:  SET (R 70) (I 2)
0002:  CMP (R 70) (I 0)
0003:  JE (A 7)
0004:  ADD (R 0) (R 0)
0005:  SUB (R 70) (I 1)
0006:  JMP (A 2)
0007:  END

The following function produces the incorrect output:

proc1 :: Assembler ()
proc1 = mdo
    set (R 0) (I 0x5a5a)
    shl (R 0) (I 1)
    end

shl :: (MonadFix m, Instructions m) => Operand -> Operand -> m ()
shl r@(R _) bits = mdo
    let count = (R 70)
    set count bits
    repeatN count $ mdo
        add r r     -- shift left by one
shl _ _ = undefined

repeatN :: (MonadFix m, Instructions m) => Operand -> m a -> m a
repeatN n@(R _) body = mdo
    _loop <- label
    cmp n (I 0)
    je _end
    retval <- body
    sub n (I 1)
    jmp _loop
    _end <- label
    return retval
repeatN _ _ = undefined

The incorrect output is

0000:> SET (R 0) (I 23130)
0001:  SET (R 70) (I 1)
0002:  CMP (R 70) (I 0)

0003:  JE (A 7)
0004:  ADD (R 0) (R 0)
0005:  SUB (R 70) (I 1)
0006:  JMP (A 2)

0007:  JE (A 7)
0008:  ADD (R 0) (R 0)
0009:  SUB (R 70) (I 1)
000A:  JMP (A 2)

000B:  END

Lines 0007 to 000A duplicate lines 0003 to 0006, and in this particular case the end result is an infinite loop at 0007.

The code in question implements an EDSL (an assembler for the Ting Pen) in Haskell. The output of the program is machine code for the Ting Pen.

I'm using MonadFix to capture forward labels in the assembly language. When I use some code combinators, I get incorrect output: some of the generated code is duplicated. I have included some tracing code and can trace the code generation. At some point the RecursiveDo mechanism does something that generates duplicate code (see also the output of the program provided below).

{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE MagicHash #-}

module TingBugChase1 where

import Data.Word (Word16)

import Control.Applicative (Applicative(..))
import Control.Monad (liftM, ap, return)
import Control.Monad.Fix (MonadFix(..))
import Text.Printf (printf)


i# :: (Integral a, Num b) => a -> b
i# = fromIntegral

-- ================================================================= Assembler

data Instruction = END
    | CLEARVER
    | SET Operand Operand
    | CMP Operand Operand
    | AND Operand Operand
    | OR Operand Operand
    | NOT Operand
    | JMP Operand
    | JE Operand
    | JNE Operand
    | JG Operand
    | JGE Operand
    | JB Operand
    | JBE Operand
    | ADD Operand Operand
    | SUB Operand Operand
    | RETURN
    | CALLID Operand
    | PLAYOID Operand
    | PAUSE Operand
    {- … -}
    deriving Show


data AsmState = AsmState
    { _code :: [Instruction]
    , _location :: Location
    , _codeHistory :: [([Instruction],[Instruction])]
    }

disasmCode :: [Instruction] -> Int -> [String]
disasmCode [] _ = ["[]"]
disasmCode code pc = map disasm1 $ zip [0..] code
    where
        disasm1 :: (Int, Instruction) -> String
        disasm1 (addr, instr) = printf "%04X:%s %s" addr (pointer addr) (show instr)
        pointer :: Int -> String
        pointer addr = if addr == i# pc then ">" else " "

instance Show AsmState where
    show (AsmState {..}) = "AsmState {" ++
        unlines
        [ "Code:\n" ++ unlines (disasmCode _code 0)
        , "Location: " ++ (show _location)
        , "History:\n" ++ unlines (map disasmHistory _codeHistory)
        ] ++ "}"
        where
            disasmHistory (a,b) = 
                unlines $
                    disasmCode a 0
                    ++ ["++"] ++
                    disasmCode b 0


data Assembler a = Assembler { runAsm :: AsmState -> (a, AsmState) }

-- https://wiki.haskell.org/Functor-Applicative-Monad_Proposal
-- Monad (Assembler w)
instance Functor Assembler where
    fmap = liftM

instance Applicative Assembler where
    {- move the definition of `return` from the `Monad` instance here -}
    pure a = Assembler $ \s -> (a,s)
    (<*>) = ap

instance Monad Assembler where
    return = pure -- redundant since GHC 7.10 due to default impl
    x >>= fy = Assembler $ \s -> 
            let 
                (a, sA) = runAsm x s
                (b, sB) = runAsm (fy a) sA
            in (b, 
                sB 
                { _code = _code sA ++ _code sB
                , _location = _location sB
                , _codeHistory = _codeHistory sB ++ [(_code sA, _code sB)]
                })

instance MonadFix Assembler where
    mfix f = Assembler $ \s -> 
        let (a, sA) = runAsm (f a) s 
        in (a, sA)

{- Append the list of instructions to the code stream. -}
append :: [Instruction] -> Assembler ()
append xs = Assembler $ \s -> 
    ((), s { _code = xs, _location = newLoc $ _location s })
    where
        newLoc (A loc) = A $ loc + (i# . length $ xs)
        newLoc _ = undefined

-- ========================================================= Instructions

data Operand = 
    R Word16    -- registers
    | I Word16  -- immediate value (integer)
    | A Word16  -- address (location)
    deriving (Eq, Show)

type Location = Operand

-- Instructions
class Instructions m where
    end :: m ()
    clearver :: m ()
    set :: Operand -> Operand -> m ()
    cmp :: Operand -> Operand -> m ()
    and :: Operand -> Operand -> m ()
    or :: Operand -> Operand -> m ()
    not :: Operand -> m ()
    jmp :: Location -> m ()
    je :: Location -> m ()
    jne :: Location -> m ()
    jg :: Location -> m ()
    jge :: Location -> m ()
    jb :: Location -> m ()
    jbe :: Location -> m ()
    add :: Operand -> Operand -> m ()
    sub :: Operand -> Operand -> m ()
    ret :: m ()
    callid :: Operand -> m ()
    playoid :: Operand -> m ()
    pause :: Operand -> m ()

    label :: m Location


{- Code combinators -}
repeatN :: (MonadFix m, Instructions m) => Operand -> m a -> m a
repeatN n@(R _) body = mdo
    _loop <- label
    cmp n (I 0)
    je _end
    retval <- body
    sub n (I 1)
    jmp _loop
    _end <- label
    return retval
repeatN _ _ = undefined

{- 
    Derived (non-native) instructions, aka macros 
    Scratch registers r70..r79
-}
shl :: (MonadFix m, Instructions m) => Operand -> Operand -> m ()
shl r@(R _) bits = mdo
    -- allocate registers
    let count = (R 70)

    set count bits
    repeatN count $ mdo
        add r r     -- shift left by one
shl _ _ = undefined


instance Instructions Assembler where 
    end = append [END]
    clearver = append [CLEARVER]
    set op1 op2 = append [SET op1 op2]
    cmp op1 op2 = append [CMP op1 op2]
    and op1 op2 = append [AND op1 op2]
    or op1 op2 = append [OR op1 op2]
    not op1 = append [NOT op1]

    jmp op1 = append [JMP op1]
    je op1 = append [JE op1]
    jne op1 = append [JNE op1]
    jg op1 = append [JG op1]
    jge op1 = append [JGE op1]
    jb op1 = append [JB op1]
    jbe op1 = append [JBE op1]

    add op1 op2 = append [ADD op1 op2]
    sub op1 op2 = append [SUB op1 op2]

    ret = append [RETURN]
    callid op1 = append [CALLID op1]
    playoid op1 = append [PLAYOID op1]
    pause op1 = append [PAUSE op1]

    {- The label function returns the current index of the output stream. -}
    label = Assembler $ \s -> (_location s, s { _code = [] })

-- ========================================================= Tests

asm :: Assembler () -> AsmState
asm proc = snd . runAsm proc $ AsmState 
            { _code = []
            , _location = A 0
            , _codeHistory = [] 
            }

doTest :: Assembler () -> String -> IO ()
doTest proc testName = do
    let ass = asm proc
    putStrLn testName
    putStrLn $ show ass

proc1 :: Assembler ()
proc1 = mdo
    set (R 0) (I 0x5a5a)
    shl (R 0) (I 1)
    end

proc2 :: Assembler ()
proc2 = mdo
    set (R 0) (I 0x5a5a)
    -- allocate registers
    let r = (R 0)
    let bits = (I 2)
    let count = (R 70)

    set count bits
    _loop <- label
    cmp count (I 0)
    je _end
    add r r
    sub count (I 1)
    jmp _loop
    _end <- label
    end

-- ========================================================= Main

main :: IO ()
main = do
    doTest proc1 "Incorrect Output"
    doTest proc2 "Correct Output"

The output of the program follows.

The Incorrect Output from proc1:

AsmState {Code:
0000:> SET (R 0) (I 23130)
0001:  SET (R 70) (I 1)
0002:  CMP (R 70) (I 0)
0003:  JE (A 7)

0004:  ADD (R 0) (R 0)
0005:  SUB (R 70) (I 1)
0006:  JMP (A 2)
0007:  JE (A 7)
0008:  ADD (R 0) (R 0)
0009:  SUB (R 70) (I 1)
000A:  JMP (A 2)
000B:  END

Location: A 8
History:
[]
++
[]

0000:> JMP (A 2)
++
[]

0000:> SUB (R 70) (I 1)
++
0000:> JMP (A 2)

0000:> ADD (R 0) (R 0)
++
0000:> SUB (R 70) (I 1)
0001:  JMP (A 2)

0000:> JE (A 7)
++
0000:> ADD (R 0) (R 0)
0001:  SUB (R 70) (I 1)
0002:  JMP (A 2)

This is where the code duplication happens:

0000:> JE (A 7)
0001:  ADD (R 0) (R 0)
0002:  SUB (R 70) (I 1)
0003:  JMP (A 2)
++
0000:> JE (A 7)
0001:  ADD (R 0) (R 0)
0002:  SUB (R 70) (I 1)
0003:  JMP (A 2)

0000:> CMP (R 70) (I 0)
++
0000:> JE (A 7)
0001:  ADD (R 0) (R 0)
0002:  SUB (R 70) (I 1)
0003:  JMP (A 2)
0004:  JE (A 7)
0005:  ADD (R 0) (R 0)
0006:  SUB (R 70) (I 1)
0007:  JMP (A 2)

[]
++
0000:> CMP (R 70) (I 0)
0001:  JE (A 7)
0002:  ADD (R 0) (R 0)
0003:  SUB (R 70) (I 1)
0004:  JMP (A 2)
0005:  JE (A 7)
0006:  ADD (R 0) (R 0)
0007:  SUB (R 70) (I 1)
0008:  JMP (A 2)

0000:> SET (R 70) (I 1)
++
0000:> CMP (R 70) (I 0)
0001:  JE (A 7)
0002:  ADD (R 0) (R 0)
0003:  SUB (R 70) (I 1)
0004:  JMP (A 2)
0005:  JE (A 7)
0006:  ADD (R 0) (R 0)
0007:  SUB (R 70) (I 1)
0008:  JMP (A 2)

0000:> SET (R 70) (I 1)
0001:  CMP (R 70) (I 0)
0002:  JE (A 7)
0003:  ADD (R 0) (R 0)
0004:  SUB (R 70) (I 1)
0005:  JMP (A 2)
0006:  JE (A 7)
0007:  ADD (R 0) (R 0)
0008:  SUB (R 70) (I 1)
0009:  JMP (A 2)
++
0000:> END

0000:> SET (R 0) (I 23130)
++
0000:> SET (R 70) (I 1)
0001:  CMP (R 70) (I 0)
0002:  JE (A 7)
0003:  ADD (R 0) (R 0)
0004:  SUB (R 70) (I 1)
0005:  JMP (A 2)
0006:  JE (A 7)
0007:  ADD (R 0) (R 0)
0008:  SUB (R 70) (I 1)
0009:  JMP (A 2)
000A:  END
}

The Correct Output from proc2:

AsmState {Code:
0000:> SET (R 0) (I 23130)
0001:  SET (R 70) (I 2)
0002:  CMP (R 70) (I 0)
0003:  JE (A 7)
0004:  ADD (R 0) (R 0)
0005:  SUB (R 70) (I 1)
0006:  JMP (A 2)
0007:  END

Location: A 8
History:
[]
++
[]

0000:> JMP (A 2)
++
[]

0000:> SUB (R 70) (I 1)
++
0000:> JMP (A 2)

0000:> ADD (R 0) (R 0)
++
0000:> SUB (R 70) (I 1)
0001:  JMP (A 2)

0000:> JE (A 7)
++
0000:> ADD (R 0) (R 0)
0001:  SUB (R 70) (I 1)
0002:  JMP (A 2)

0000:> JE (A 7)
0001:  ADD (R 0) (R 0)
0002:  SUB (R 70) (I 1)
0003:  JMP (A 2)
++
0000:> END

0000:> CMP (R 70) (I 0)
++
0000:> JE (A 7)
0001:  ADD (R 0) (R 0)
0002:  SUB (R 70) (I 1)
0003:  JMP (A 2)
0004:  END

[]
++
0000:> CMP (R 70) (I 0)
0001:  JE (A 7)
0002:  ADD (R 0) (R 0)
0003:  SUB (R 70) (I 1)
0004:  JMP (A 2)
0005:  END

0000:> SET (R 70) (I 2)
++
0000:> CMP (R 70) (I 0)
0001:  JE (A 7)
0002:  ADD (R 0) (R 0)
0003:  SUB (R 70) (I 1)
0004:  JMP (A 2)
0005:  END

0000:> SET (R 0) (I 23130)
++
0000:> SET (R 70) (I 2)
0001:  CMP (R 70) (I 0)
0002:  JE (A 7)
0003:  ADD (R 0) (R 0)
0004:  SUB (R 70) (I 1)
0005:  JMP (A 2)
0006:  END
}

Answer by Petr (best answer):

I believe the problem is that your monad instance is flawed. It looks as if it's supposed to be a State monad, but the definition of >>= then does some manipulations that look more like the Writer monad (using the [] and Last monoids). I'm quite sure that at least mfix is incompatible with >>=, and my guess is that >>= itself might fail the monad laws.

I haven't looked for a good counter-example, but I can offer you a version that I believe works and that is based on the standard tools available. I'll post only the changed parts; my whole testing code is available here:

import Control.Monad.RWS
import Control.Monad.Fix (MonadFix(..))
import qualified Data.Foldable as F
import qualified Data.Sequence as S
import Text.Printf (printf)

-- ...

The state is the current location and the writer monoid is a sequence of instructions (I used Seq rather than [], because of the possibly bad performance of ++ for long lists; you could also use DList if desired):

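-- Deriving these instances through the newtype needs
-- {-# LANGUAGE GeneralizedNewtypeDeriving #-} at the top of the module.
newtype Assembler a = Assembler (RWS () (S.Seq Instruction) Word16 a)
    deriving (Functor, Applicative, Monad, MonadFix)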

{- Append the list of instructions to the code stream. -}
append :: [Instruction] -> Assembler ()
append xs = Assembler . rws $ \_ loc -> ((), loc + (i# . length $ xs), S.fromList xs)

asm :: Assembler () -> AsmState
asm (Assembler proc) =
    let (location, code) = execRWS proc () 0
    in AsmState { _code = F.toList code
                , _location = A location
                , _codeHistory = [] -- I just ignored this field ...
                }

instance Instructions Assembler where 
    -- all same as before, except
    label = A <$> Assembler get
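
To see the state-plus-writer split in isolation, here is a minimal sketch that uses only base instead of RWS. It is not the code above: the names Asm, Instr, emit, label and assemble are made up for this illustration, and the instruction set is cut down to three constructors. The point it tries to show is that when pure emits nothing and >>= threads the location while concatenating the emitted code, a forward label inside mdo resolves without any duplication.

```haskell
{-# LANGUAGE RecursiveDo #-}
module Main where

import Control.Monad.Fix (MonadFix (..))
import Data.Word (Word16)

-- Hypothetical cut-down instruction set, just enough for one loop.
data Instr = JE Word16 | JMP Word16 | NOP
  deriving (Eq, Show)

-- Location in; (result, location out, code emitted by this action only).
newtype Asm a = Asm { runAsm :: Word16 -> (a, Word16, [Instr]) }

instance Functor Asm where
  fmap f (Asm g) = Asm $ \loc ->
    let (a, loc', w) = g loc in (f a, loc', w)

instance Applicative Asm where
  pure a = Asm $ \loc -> (a, loc, [])          -- emits nothing
  Asm f <*> Asm g = Asm $ \loc ->
    let (h, loc1, w1) = f loc
        (a, loc2, w2) = g loc1
    in (h a, loc2, w1 ++ w2)

instance Monad Asm where
  Asm g >>= k = Asm $ \loc ->
    let (a, loc1, w1) = g loc                  -- lazy patterns keep mfix productive
        (b, loc2, w2) = runAsm (k a) loc1
    in (b, loc2, w1 ++ w2)

instance MonadFix Asm where
  mfix f = Asm $ \loc ->
    let (a, loc', w) = runAsm (f a) loc in (a, loc', w)

-- Emit one instruction and advance the location by one.
emit :: Instr -> Asm ()
emit i = Asm $ \loc -> ((), loc + 1, [i])

-- Return the current location without emitting anything.
label :: Asm Word16
label = Asm $ \loc -> (loc, loc, [])

-- Run from location 0 and keep only the emitted code.
assemble :: Asm a -> [Instr]
assemble m = let (_, _, w) = runAsm m 0 in w

loop :: Asm ()
loop = mdo
  top <- label
  emit (JE end)      -- forward reference, resolved through mfix
  emit NOP
  emit (JMP top)
  end <- label
  return ()

main :: IO ()
main = print (assemble loop)   -- prints [JE 3,NOP,JMP 0]
```

The trailing return () after the recursive segment emits an empty list here, so nothing from the segment is appended a second time.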

Now it'd probably make more sense to rename AsmState to something like AsmResult.

Also, instead of using Location with partial functions, I'd suggest using just Word16 as I did, or defining a newtype that captures just the location and then using it in Operand. This makes the code safer and cleaner.
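
A minimal sketch of the newtype variant, assuming a hypothetical helper called advance in place of the partial newLoc from the question:

```haskell
module Main where

import Data.Word (Word16)

-- A location is its own type, so it cannot be confused with a register
-- or an immediate value.
newtype Location = Location Word16
  deriving (Eq, Show)

data Operand
  = R Word16     -- register
  | I Word16     -- immediate value
  | A Location   -- address
  deriving (Eq, Show)

-- Total function: no "newLoc _ = undefined" fallback is needed.
-- (advance is a name invented for this sketch.)
advance :: Int -> Location -> Location
advance n (Location w) = Location (w + fromIntegral n)

main :: IO ()
main = print (advance 3 (Location 4))   -- prints Location 7
```

With this shape, jump instructions and label can take and return Location directly, and the type checker rejects a register where an address is expected.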

(In any case it'd be good to have a test suite that'd test for issues like this.)
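
As a minimal sketch of what such a test could look like, the following strips the question's state down to a list of instruction names and a counter (the names St, Asm, emit, run and rightIdentityHolds are invented here) and keeps the question's definitions of pure and >>=. It then checks the right identity law, m >>= return == m, on a single emitted instruction:

```haskell
module Main where

import Control.Monad (ap, liftM)

-- Stripped-down stand-in for the question's AsmState.
data St = St { code :: [String], loc :: Int }
  deriving (Eq, Show)

newtype Asm a = Asm { runAsm :: St -> (a, St) }

instance Functor Asm where
  fmap = liftM

instance Applicative Asm where
  pure a = Asm $ \s -> (a, s)        -- as in the question: state passes through untouched
  (<*>) = ap

instance Monad Asm where
  x >>= f = Asm $ \s ->
    let (a, sA) = runAsm x s
        (b, sB) = runAsm (f a) sA
    in (b, sB { code = code sA ++ code sB })   -- as in the question

-- Same shape as the question's append: overwrite code, bump the location.
emit :: String -> Asm ()
emit i = Asm $ \s -> ((), s { code = [i], loc = loc s + 1 })

run :: Asm a -> [String]
run m = code (snd (runAsm m (St [] 0)))

-- Right identity: binding with return must not change the generated code.
rightIdentityHolds :: Asm () -> Bool
rightIdentityHolds m = run (m >>= return) == run m

main :: IO ()
main = do
  print (run (emit "JE"))                 -- prints ["JE"]
  print (run (emit "JE" >>= return))      -- prints ["JE","JE"]
  print (rightIdentityHolds (emit "JE"))  -- prints False
```

Because pure leaves the code field of the incoming state in place, the second argument of >>= reports the first argument's code again and it is appended twice. That is the same pattern as the return retval at the end of repeatN, which follows the recursive segment and repeats its four instructions.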