Tail recursion with the State and CPS monads?


I am in the process of creating a simple parsing library in Haskell that compiles a parser specification down to optimized code using Template Haskell.

However, I am trying to figure out what kind of generated code is the most efficient compilation target, so let us consider a simple hand-written parser snippet.

A parser can be seen as a function String -> (a, String), where a is whatever we want to recognize during parsing and the resulting string is the part of the input that has not been matched yet. When parsers are run in sequence, the next parser continues working on this remaining string. At the end, we can say the full string was successfully parsed if the end result has the shape (_, ""), i.e. no input is left over.
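To make this shape concrete, here is a minimal sketch of the parser type and of running two parsers in sequence. The names Parser, andThen, char and fullyParsed are hypothetical illustrations, not part of any existing library; like the rest of the question, the sketch just calls error instead of doing real error handling.

```haskell
-- Hypothetical synonym for the parser shape described above:
-- first component = recognized value, second = unconsumed input.
type Parser a = String -> (a, String)

-- Run p, then run q on whatever p left over (illustrative name).
andThen :: Parser a -> Parser b -> Parser (a, b)
andThen p q input =
  let (x, rest)  = p input
      (y, rest') = q rest
  in ((x, y), rest')

-- Recognize one expected character; no real error handling.
char :: Char -> Parser Char
char c (x : rest) | x == c = (c, rest)
char c _ = error ("expected " ++ show c)

-- The whole input was consumed iff the *second* component is empty.
fullyParsed :: (a, String) -> Bool
fullyParsed (_, rest) = null rest
```

For example, andThen (char 'a') (char 'b') "abc" gives (('a','b'), "c"), which is not a full parse because "c" is left over.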

In this example, we'll use a parser that parses any number of 'a' characters. (Error handling, i.e. reporting an error if any character other than an 'a' is seen, has been left out of the code snippets to keep them concise.)

Straightforward (naive?) implementation

If we write the code for a recursive parser directly, a straightforward hand-rolled implementation might be:

parser_raw :: String -> (String, String)
parser_raw str =
  case str of
    [] -> ([], [])
    ('a' : rest) ->
      let
        (res', rest') = parser_raw rest
      in
        (('a' : res'), rest')

This matches any number of 'a' characters; the parser recurses until the input string is empty.
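As a known-good reference to test the variants in this question against, one option (an assumption on my part, not something the question uses) is the Prelude's span. On all-'a' input it should return the same pair as parser_raw; on other input it simply stops at the first non-'a' character, whereas parser_raw fails with a pattern-match error once that part of the result is forced.

```haskell
-- The question's parser, unchanged (error handling deliberately omitted).
parser_raw :: String -> (String, String)
parser_raw str =
  case str of
    [] -> ([], [])
    ('a' : rest) ->
      let
        (res', rest') = parser_raw rest
      in
        ('a' : res', rest')

-- Reference implementation: longest prefix of 'a's, plus the remainder.
parserSpan :: String -> (String, String)
parserSpan = span (== 'a')
```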

Looking at the generated GHC Core, we see the following output:

Rec {
-- RHS size: {terms: 36, types: 45, coercions: 0, joins: 0/1}
$wparser_raw
  = \ w_s7BO ->
      case w_s7BO of {
        [] -> (# [], [] #);
        : ds_d736 rest_a5S6 ->
          case ds_d736 of { C# ds1_d737 ->
          case ds1_d737 of {
            __DEFAULT -> case lvl6_r7Fw of wild2_00 { };
            'a'# ->
              let {
                ds3_s7vv
                  = case $wparser_raw rest_a5S6 of { (# ww1_s7C3, ww2_s7C4 #) ->
                    (ww1_s7C3, ww2_s7C4)
                    } } in
              (# : lvl2_r7Fs
                   (case ds3_s7vv of { (res'_a65m, rest'_a65n) -> res'_a65m }),
                 case ds3_s7vv of { (res'_a65m, rest'_a65n) -> rest'_a65n } #)
          }
          }
      }
end Rec }

To my untrained eye, this does not look tail recursive: we first make the recursive call and then still need to combine its output into a tuple afterwards. I also find it odd/interesting that the let from the source code is still there in the compiled code, which makes it seem like the computation really needs to be done in multiple (i.e. non-tail-recursive) steps.
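One way to see what the surviving let means is to observe its laziness. Because the recursive call is let-bound rather than scrutinized by a case, the outer tuple (with its first 'a') is returned before the recursive call has run; this is usually called guarded recursion. The sketch below relies only on that behaviour and runs the question's parser on an infinite input; firstFive is a hypothetical name. The likely cost is that each level allocates a thunk for the recursive call (ds3_s7vv in the Core above) plus thunks that select its components.

```haskell
parser_raw :: String -> (String, String)
parser_raw str =
  case str of
    [] -> ([], [])
    ('a' : rest) ->
      -- Lazy binding: parser_raw rest is not evaluated here.
      let (res', rest') = parser_raw rest
      in ('a' : res', rest')

-- Terminates even though the input never ends, because each 'a'
-- is produced before the next recursive call is demanded.
firstFive :: String
firstFive = take 5 (fst (parser_raw (repeat 'a')))
```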

Two other ways to write this code come to mind (maybe there are more?).

Continuation-passing style

parser_cps :: String -> (String, String)
parser_cps str = parser_cps' str id
  where
    parser_cps' :: String -> ((String, String) -> (String, String)) -> (String, String)
    parser_cps' str fun =
      case str of
        [] -> fun ([], [])
        ('a' : rest) ->
          parser_cps' rest ((\(tl, final) -> (('a' : tl), final) ) . fun)
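A side note on the composition order: parser_cps' composes the new step as (prepend . fun), so the innermost step ends up outermost. With only 'a's this is invisible, but for a parser that accepts several characters it reverses the output. The conventional CPS order is (fun . prepend). The sketch below generalizes parser_cps' to a predicate purely to show the difference; prepend, spanCpsQuestion and spanCps are hypothetical names, and stopping at the first rejected character replaces the question's omitted error handling.

```haskell
-- Prepend one recognized character to the result pair.
prepend :: Char -> (String, String) -> (String, String)
prepend c (tl, final) = (c : tl, final)

-- Composition order as in the question: prepend c . fun
spanCpsQuestion :: (Char -> Bool) -> String -> (String, String)
spanCpsQuestion p str0 = go str0 id
  where
    go (c : rest) fun | p c = go rest (prepend c . fun)
    go str fun = fun ([], str)

-- Conventional CPS order: the outer continuation runs last.
spanCps :: (Char -> Bool) -> String -> (String, String)
spanCps p str0 = go str0 id
  where
    go (c : rest) k | p c = go rest (k . prepend c)
    go str k = k ([], str)
```

With p = (`elem` "abc") and input "abcx", spanCps gives ("abc","x") while spanCpsQuestion gives ("cba","x"); on "aaa" both give ("aaa","").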

This results in:

Rec {
-- RHS size: {terms: 28, types: 29, coercions: 0, joins: 0/0}
parser_cps_parser_cps'
  = \ str_a5S8 fun_a5S9 ->
      case str_a5S8 of {
        [] -> fun_a5S9 lvl9_r7FM;
        : ds_d72V rest_a5Sa ->
          case ds_d72V of { C# ds1_d72W ->
          case ds1_d72W of {
            __DEFAULT -> lvl8_r7FL;
            'a'# ->
              parser_cps_parser_cps'
                rest_a5Sa
                (\ x_i72P ->
                   case fun_a5S9 x_i72P of { (tl_a5Sb, final_a5Sc) ->
                   (: lvl2_r7FF tl_a5Sb, final_a5Sc)
                   })
          }
          }
      }
end Rec }

-- RHS size: {terms: 4, types: 4, coercions: 0, joins: 0/0}
parser_cps = \ str_a5S6 -> parser_cps_parser_cps' str_a5S6 id

This looks tail recursive, in the sense that the recursive call is a tail call. However, it does seem like a large chain of closures (the continuation, one lambda per matched character) is built up instead. Are we using a lot of extra heap space until the base case of the recursion is reached?
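One way to estimate what that chain costs is to defunctionalize it. In the Core above, each step allocates a lambda that captures the previous fun, so at the base case there are likely n closures on the heap, and applying the chain (case fun x of ...) is itself not a tail call. Since every closure here only prepends a character, the chain can instead be represented as a plain list of pending characters. The sketch below is that idea under this assumption; parserAcc is a hypothetical name, and the catch-all error stands in for the omitted error handling.

```haskell
-- The CPS continuation, defunctionalized into a list of pending
-- characters (most recent first). `go` is a genuine tail call and the
-- accumulator is one cons cell per 'a' rather than one closure.
parserAcc :: String -> (String, String)
parserAcc str0 = go str0 []
  where
    go :: String -> String -> (String, String)
    go [] acc = (reverse acc, [])
    go ('a' : rest) acc = go rest ('a' : acc)
    go (c : _) _ = error ("unexpected character " ++ show c)
```

This still uses heap proportional to the input (the accumulator and the final reverse), but nothing has to be unwound on the stack at the end.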

The State monad

-- Using mtl's strict State monad:
import Control.Monad.State.Strict

parser_state :: String -> (String, String)
parser_state = runState parser_state'
  where
    parser_state' :: State String String
    parser_state' = do
      str <- get
      case str of
        [] -> return []
        ('a' : rest) -> do
          put rest
          res <- parser_state'
          return ('a' : res)
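To see what this do-block desugars into, here is a minimal hand-rolled state monad that, as far as I understand it, mirrors the bind of transformers' Control.Monad.Trans.State.Strict (which mtl's Control.Monad.State.Strict re-exports): >>= pattern-matches on the pair returned by the first action before running the second. St, runSt, getSt, putSt and parserSt are hypothetical names chosen to avoid an mtl dependency.

```haskell
newtype St s a = St { runSt :: s -> (a, s) }

instance Functor (St s) where
  fmap f (St g) = St $ \s -> case g s of (a, s') -> (f a, s')

instance Applicative (St s) where
  pure a = St $ \s -> (a, s)
  St f <*> St g = St $ \s ->
    case f s of
      (h, s') -> case g s' of
        (a, s'') -> (h a, s'')

instance Monad (St s) where
  -- Strict bind: the pair from g is forced before k runs.
  St g >>= k = St $ \s -> case g s of (a, s') -> runSt (k a) s'

getSt :: St s s
getSt = St $ \s -> (s, s)

putSt :: s -> St s ()
putSt s = St $ \_ -> ((), s)

-- The question's parser_state', unchanged apart from the monad used.
parserSt :: String -> (String, String)
parserSt = runSt go
  where
    go :: St String String
    go = do
      str <- getSt
      case str of
        [] -> return []
        ('a' : rest) -> do
          putSt rest
          res <- go
          return ('a' : res)
```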

This results in:

Rec {
-- RHS size: {terms: 26, types: 36, coercions: 0, joins: 0/0}
$wparser_state'
  = \ w_s7Ca ->
      case w_s7Ca of {
        [] -> (# [], [] #);
        : ds_d72p rest_a5Vb ->
          case ds_d72p of { C# ds1_d72q ->
          case ds1_d72q of {
            __DEFAULT -> case lvl11_r7FO of wild2_00 { };
            'a'# ->
              case $wparser_state' rest_a5Vb of { (# ww1_s7Cl, ww2_s7Cm #) ->
              (# : lvl2_r7FF ww1_s7Cl, ww2_s7Cm #)
              }
          }
          }
      }
end Rec }

-- RHS size: {terms: 8, types: 14, coercions: 6, joins: 0/0}
parser_state1
  = \ w_s7Ca ->
      case $wparser_state' w_s7Ca of { (# ww1_s7Cl, ww2_s7Cm #) ->
      (ww1_s7Cl, ww2_s7Cm) `cast` <Co:6>
      }

-- RHS size: {terms: 1, types: 0, coercions: 7, joins: 0/0}
parser_state = parser_state1 `cast` <Co:7>

Here, I'm not sure how to interpret the Core: the let from the parser_raw code is gone, and the result of the recursive call is instead scrutinized immediately by a case. Afterwards we still have to cons an 'a' onto the result and put it into a tuple, so is getting rid of the let enough to please the recursion gods?
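For comparison, this is roughly the source-level shape of the $wparser_state' Core once the strict bind has been inlined (parserStateInlined is a hypothetical name, and this is my reading of the Core, not GHC output). The recursive call sits in a case scrutinee rather than in tail position: a cons and a new pair still have to be built after it returns, so each 'a' presumably keeps a case continuation on the stack until the end of the input is reached. In parser_raw, the let instead defers the call.

```haskell
-- The recursive call must finish before the pair can be built.
parserStateInlined :: String -> (String, String)
parserStateInlined str =
  case str of
    [] -> ([], [])
    ('a' : rest) ->
      case parserStateInlined rest of
        (res, rest') -> ('a' : res, rest')
```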


So, to summarize: these are three techniques for writing a simple parsing function. I would like to know which one is the most memory-efficient, and how to build some intuition for recognizing tail recursion when looking at GHC Core output.
