Haskell avoiding stack overflow in folds without sacrificing performance


The following piece of code experiences a stack overflow for large inputs:

{-# LANGUAGE DeriveDataTypeable, OverloadedStrings #-}
import qualified Data.ByteString.Lazy.Char8 as L


genTweets :: L.ByteString -> L.ByteString
genTweets text | L.null text = ""
               | otherwise = L.intercalate "\n\n" $ genTweets' $ L.words text
  where genTweets' txt = foldr p [] txt
          where p word [] = [word]
                p word words@(w:ws) | L.length word + L.length w <= 139 =
                                        (word `L.append` " " `L.append` w):ws
                                    | otherwise = word:words

I assume my predicate is building a list of thunks, but I'm not sure why, or how to fix it.

The equivalent code using foldl' does not overflow the stack, but it is extremely slow and uses a lot of memory, because every step appends to the end of the accumulator list.

import Data.List (foldl')

genTweetsStrict :: L.ByteString -> L.ByteString
genTweetsStrict text | L.null text = "" 
                     | otherwise = L.intercalate "\n\n" $ genTweetsStrict' $ L.words text
  where genTweetsStrict' txt = foldl' p [] txt
          where p [] word = [word]
                p words word | L.length word + L.length (last words) <= 139 =
                                init words ++ [last words `L.append` " " `L.append` word]
                             | otherwise = words ++ [word]

What is causing the first snippet to build up thunks, and can it be avoided? Is it possible to write the second snippet so that it doesn't rely on (++)?


There are 2 answers

Answer by Mikhail Glushenkov (best answer)
L.length word + L.length (last words) <= 139

This is the problem. On each iteration, you're traversing the accumulator list, and then

init words ++ [last words `L.append` " " `L.append` word]

appending at the end. Each step therefore costs time proportional to the length of the accumulator list, so the whole fold is quadratic in the number of tweets. A better solution is to generate the output list lazily, interleaving processing with reading the input stream (you don't need to read the whole input to output the first 140-character tweet).
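
On the question's second part (avoiding (++) in the foldl' version): a common technique, not taken from either answer, is to keep the accumulator in reverse order so the tweet being filled sits at the head, cache its length next to it, and reverse once at the end. The sketch below assumes that approach; the name genTweetsRev is hypothetical, and the limit (at most 140 characters including the separating spaces) is chosen to match the lazy version in this answer.

```haskell
{-# LANGUAGE OverloadedStrings #-}
module Main where

import qualified Data.ByteString.Lazy.Char8 as L
import Data.Int (Int64)
import Data.List (foldl')

-- Hypothetical sketch: foldl' with the accumulator kept in reverse order.
-- The head of the list is the tweet currently being filled, paired with its
-- cached length, so each step is O(1) instead of walking the list with
-- last/init and appending with (++).
genTweetsRev :: L.ByteString -> L.ByteString
genTweetsRev text = L.intercalate "\n\n" (reverse (map snd acc))
  where
    acc = foldl' step [] (L.words text)

    step :: [(Int64, L.ByteString)] -> L.ByteString -> [(Int64, L.ByteString)]
    step [] word = [(L.length word, word)]
    step ((len, cur) : rest) word
      -- Extend the current tweet if "cur word" stays within 140 characters.
      | len + 1 + lw <= 140 = (len + 1 + lw, cur `L.append` " " `L.append` word) : rest
      -- Otherwise start a new tweet in front of the finished ones.
      | otherwise           = (lw, word) : (len, cur) : rest
      where
        lw = L.length word

main :: IO ()
main = L.putStrLn (genTweetsRev "hello world")
```

Each step is now constant time, but nothing can be written until the whole input has been folded, which is why the lazy version in this answer is still preferable for large files.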

The following version of your program processes a relatively large file (/usr/share/dict/words) in under 1 second while using O(1) space:

{-# LANGUAGE OverloadedStrings, BangPatterns #-}

module Main where

import qualified Data.ByteString.Lazy.Char8 as L
import Data.Int (Int64)

genTweets :: L.ByteString -> L.ByteString
genTweets text | L.null text = ""
               | otherwise   = L.intercalate "\n\n" $ toTweets $ L.words text
  where

    -- Concatenate words into tweets of at most 140 characters (counting the
    -- separating spaces); a single longer word becomes a tweet of its own.
    toTweets :: [L.ByteString] -> [L.ByteString]
    toTweets []     = []
    toTweets [w]    = [w]
    toTweets (w:ws) = go (L.length w, w) ws

    -- Main loop. Notice how the output tweet (cur_str) is generated as soon as
    -- possible, thus enabling L.writeFile to consume it before the whole
    -- input is processed.
    go :: (Int64, L.ByteString) -> [L.ByteString] -> [L.ByteString]
    go (_cur_len, !cur_str) []     = [cur_str]
    go (!cur_len, !cur_str) (w:ws)
      | lw + cur_len <= 139        = go (cur_len + lw + 1,
                                         cur_str `L.append` " " `L.append` w) ws
      | otherwise                  = cur_str : go (lw, w) ws
      where
        lw = L.length w

-- Notice the use of lazy I/O.
main :: IO ()
main = do dict <- L.readFile "/usr/share/dict/words"
          L.writeFile "tweets" (genTweets dict)

Answer by Sassa NF

p word words@(w:ws)

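This pattern match forces evaluation of the accumulator (the "tail"), which is the result of foldr p [] applied to the remaining words. Computing that calls p on the next word, which pattern-matches its own accumulator in turn, and so on. No call to p can return until the fold has reached the end of the list, so every pending call holds a stack frame, and a long enough input overflows the stack.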
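
What matters is whether the combining function can return a constructor without inspecting its second argument. The small illustration below uses hypothetical names (lazyStep, strictStep) and plain Int lists rather than the tweet code: lazyStep returns a cons cell immediately, so foldr can produce output incrementally even from an infinite list, while strictStep pattern-matches on its accumulator the way p does and therefore cannot return anything until the rest of the fold is finished.

```haskell
module Main where

-- Lazy in the accumulator: returns a cons cell without looking at acc,
-- so foldr can hand out the first element before touching the rest.
lazyStep :: Int -> [Int] -> [Int]
lazyStep x acc = x : acc

-- Strict in the accumulator, like p in the question: the case expression
-- forces acc (the fold of the whole remaining list) before returning, so
-- every call waits on the next one and stack use grows with the input.
strictStep :: Int -> [Int] -> [Int]
strictStep x acc = case acc of
  []       -> [x]
  (y : ys) -> x : y : ys

main :: IO ()
main = do
  -- Works on an infinite list: only three cells are ever demanded.
  print (take 3 (foldr lazyStep [] [1 ..]))
  -- Same result on a finite list, but only after the whole list is folded;
  -- with [1 ..] instead of [1 .. 10] this would never produce output.
  print (take 3 (foldr strictStep [] [1 .. 10]))
```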

Note that foldr and foldl' split the text differently: foldr packs words greedily starting from the end of the text, so the leftover (typically shortest) tweet appears first, while foldl' packs from the beginning, so the leftover tweet appears last.
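
To make the difference visible, here is a small sketch with hypothetical helpers packRight and packLeft that use plain Strings and a tiny limit of 8 characters instead of ByteStrings and 140. Both pack greedily, counting the joining spaces; only the direction of the fold differs.

```haskell
module Main where

import Data.List (foldl')

-- foldr-style: the rightmost words are packed first, so the leftover
-- partial chunk ends up at the front.
packRight :: Int -> [String] -> [String]
packRight limit = foldr step []
  where
    step w (c : cs) | length w + 1 + length c <= limit = (w ++ " " ++ c) : cs
    step w cs = w : cs

-- foldl'-style: the leftmost words are packed first, so the leftover
-- chunk ends up at the back. The accumulator is built in reverse.
packLeft :: Int -> [String] -> [String]
packLeft limit = reverse . foldl' step []
  where
    step (c : cs) w | length c + 1 + length w <= limit = (c ++ " " ++ w) : cs
    step cs w = w : cs

main :: IO ()
main = do
  let ws = words "aa bb cc dd ee"
  print (packRight 8 ws) -- ["aa bb","cc dd ee"]
  print (packLeft 8 ws)  -- ["aa bb cc","dd ee"]
```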


I would go about it like so:

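import Data.List (unfoldr)

-- Each unfoldr step emits one tweet: f takes the next word and g keeps
-- appending words while the tweet stays within 139 characters.
genTweets' :: [L.ByteString] -> [L.ByteString]
genTweets' = unfoldr f where
  f [] = Nothing
  f (w:ws) = Just $ g w ws $ L.length w
  g w [] _ = (w, [])
  g w ws@(w':_) len | len+1+(L.length w') > 139 = (w,ws)
  g w (w':ws') len = g (w `L.append` " " `L.append` w') ws' $ len+1+(L.length w')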
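
Because unfoldr emits each tweet as soon as g stops, this version streams as well. A small self-contained check, assuming OverloadedStrings (the " " literal needs it) and a main written only for illustration: take the first two tweets from an infinite list of five-letter words and print their lengths.

```haskell
{-# LANGUAGE OverloadedStrings #-}
module Main where

import qualified Data.ByteString.Lazy.Char8 as L
import Data.List (unfoldr)

-- Sassa NF's unfoldr version, with a type signature added.
genTweets' :: [L.ByteString] -> [L.ByteString]
genTweets' = unfoldr f
  where
    f [] = Nothing
    f (w : ws) = Just $ g w ws $ L.length w
    g w [] _ = (w, [])
    g w ws@(w' : _) len | len + 1 + L.length w' > 139 = (w, ws)
    g w (w' : ws') len = g (w `L.append` " " `L.append` w') ws' $ len + 1 + L.length w'

main :: IO ()
main = do
  -- Infinite input: unfoldr only demands as many words as the first two
  -- tweets need, so this terminates.
  let tweets = take 2 (genTweets' (cycle ["hello", "world"]))
  -- 23 five-letter words plus 22 spaces: prints 137 twice.
  mapM_ (print . L.length) tweets
```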