Apparent space leak in a variant of Brent's "teleporting turtle" algorithm


I have been implementing a variant of Brent's "teleporting turtle" algorithm, mapped over all depthward paths through an N-ary tree, in order to compare the values of two different data structures. It uses my own backtracking algorithm to roll back cycles without excluding non-cyclic paths that partially overlap cyclic ones. As far as I can tell the algorithm is correct (although I feel I should actually prove this, and I have no background in proving anything about code). However, today I tried running 1000000 iterations each of the equality and inequality tests (controlled by testCount) with 1-1024 nodes (controlled by maxNodeCount) and 2-5 branches per node (controlled by nodeSizeRange), and it very quickly ate all 8 GB of RAM on my system and started swapping heavily, forcing me to kill it. When I reduced the number of nodes to 1-512, memory use still grew rapidly, though less so, until it seemingly leveled off at about 6 GB (I am not sure how much it will ultimately use, since I left it running at home). At 1-256 nodes it seemed to use a few GB, but not enough that I took much note.
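For readers unfamiliar with the base algorithm, here is a minimal sketch of Brent's "teleporting turtle" cycle detection on the simpler case of a single iterated function, rather than a tree of paths. The name `brentCycleLength` is hypothetical and not part of the code in this question, and the sketch assumes the sequence eventually enters a cycle: the turtle stays put while the hare advances, and whenever the hare has taken `power` steps, the turtle teleports to the hare and `power` doubles.

```haskell
-- Brent's cycle detection on x0, f x0, f (f x0), ...: returns the
-- length of the cycle the sequence eventually enters.
-- Assumes the sequence does cycle; otherwise it never terminates.
brentCycleLength :: Eq a => (a -> a) -> a -> Int
brentCycleLength f x0 = go x0 (f x0) 1 1
  where
    -- turtle: stationary pointer; hare: advances one step per call;
    -- power: current step budget; lam: steps taken since the last teleport.
    go turtle hare power lam
      | turtle == hare = lam
      | power == lam   = go hare (f hare) (power * 2) 1  -- teleport turtle, double the budget
      | otherwise      = go turtle (f hare) power (lam + 1)
```

In GHCi, `brentCycleLength (\x -> if x == 6 then 3 else x + 1) (0 :: Int)` gives 4, the length of the cycle 3, 4, 5, 6. The turtle only ever jumps forward, which is why the algorithm needs constant extra space on a plain sequence.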

My question is: why is it using such enormous amounts of RAM, when its space requirements should be O(n), where n is a function of the depth of the deepest path through the tree before any cycle is caught, the size of the largest cycle in the tree, and the number of cycle starting points in the tree? I could not find any obvious place in the code where a space leak would occur. The only explanation I could think of is the nature of Brent's algorithm itself, combined with the fact that I keep a stack for each depthward path: because the distance between turtles doubles each time, very deep paths with cycles, or very large cycles, could run for a long time and accumulate a large stack before the cycles are caught. But as Haskell is notorious for space leaks, this might just be an ordinary space leak rather than something algorithmic whose cause I am missing.
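To illustrate the kind of "ordinary" space leak suspected here, the sketch below contrasts a lazy record field with a strict one. The `LazyCounter` and `StrictCounter` types are hypothetical and unrelated to `EqState`. Both functions return the same number, but without optimisation (for example in GHCi) the lazy version should retain a chain of n unevaluated `+ 1` thunks until the count is finally demanded, while the strict version keeps a single evaluated `Int`.

```haskell
import Data.List (foldl')

-- Hypothetical types, used only to contrast a lazy field with a strict one.
data LazyCounter = LazyCounter { lazyCount :: Int }
data StrictCounter = StrictCounter { strictCount :: !Int }

-- foldl' forces the record to weak head normal form, but the lazy field
-- is left as a thunk that refers to the previous record.
bumpLazy :: LazyCounter -> LazyCounter
bumpLazy c = c { lazyCount = lazyCount c + 1 }

-- The strictness annotation evaluates the new count whenever the
-- constructor itself is evaluated.
bumpStrict :: StrictCounter -> StrictCounter
bumpStrict c = c { strictCount = strictCount c + 1 }

countLazy :: Int -> Int
countLazy n = lazyCount (foldl' (\c _ -> bumpLazy c) (LazyCounter 0) [1 .. n])

countStrict :: Int -> Int
countStrict n = strictCount (foldl' (\c _ -> bumpStrict c) (StrictCounter 0) [1 .. n])
```

With optimisation, GHC's strictness analysis may hide the difference, so it is best observed by compiling with `-O0 -rtsopts` and comparing the maximum residency reported by `+RTS -s` for large n.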

(Edit: I realized this cannot be algorithmic, because turtle depths and turtle scales are related such that, for a turtle at depth d, the next turtle depth is ((d + 1) * 2) - 1; for instance, after depth 1023 the next turtle depth is 2047.)
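The turtle schedule described in the edit can be checked with a small model. This sketch assumes the behaviour of `newFrame` and `nextTurtleDepth` in the code that follows: the first turtle is placed at depth 1 with scale 2, and each later turtle is placed at depth `turtleDepth + turtleScale` with the scale doubled. The names `turtleSchedule`, `turtleDepths` and `closedForm` are hypothetical.

```haskell
-- Hypothetical model of the turtle schedule in newFrame/nextTurtleDepth,
-- as a list of (turtleDepth, turtleScale) pairs.
turtleSchedule :: [(Int, Int)]
turtleSchedule = iterate (\(d, s) -> (d + s, s * 2)) (1, 2)

turtleDepths :: [Int]
turtleDepths = map fst turtleSchedule

-- The closed form from the edit: next depth = ((d + 1) * 2) - 1.
closedForm :: [Int]
closedForm = iterate (\d -> (d + 1) * 2 - 1) 1
```

Both lists start 1, 3, 7, 15, ..., 1023, 2047: turtles sit at depths 2^k - 1, and the scale is always the turtle depth plus one, which is why the two formulations agree.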

Here is my code for the algorithm:

{-# LANGUAGE RecordWildCards, BangPatterns #-}

module EqualTree (Tree(..),
                  Node,  -- exported because NaiveEqualTree imports it
                  equal)
       where

import Data.Array.IO (IOArray)
import Data.Array.MArray (readArray,
                          getBounds)

data Tree a = Value a | Node (Node a)

type Node a = IOArray Int (Tree a)

data Frame a = Frame { frameNodes :: !(Node a, Node a),
                       frameSiblings :: !(Maybe (Siblings a)),
                       frameTurtle :: !(Turtle a) }

data Siblings a = Siblings { siblingNodes :: !(Node a, Node a),
                             siblingIndex :: !Int }

data Turtle a = Turtle { turtleDepth :: !Int,
                         turtleScale :: !Int,
                         turtleNodes :: !(Node a, Node a) }

data EqState a = EqState { stateFrames :: [Frame a],
                           stateCycles :: [(Node a, Node a)],
                           stateDepth :: !Int }

data Unrolled a = Unrolled { unrolledNodes :: !(Node a, Node a),
                             unrolledState :: !(EqState a),
                             unrolledSiblings :: !(Maybe (Siblings a)) }

data NodeComparison = EqualNodes | NotEqualNodes | HalfEqualNodes

equal :: Eq a => Tree a -> Tree a -> IO Bool
equal tree0 tree1 =
  let state = EqState { stateFrames = [], stateCycles = [], stateDepth = 0 }
  in ascend state tree0 tree1 Nothing

ascend :: Eq a => EqState a -> Tree a -> Tree a -> Maybe (Siblings a) -> IO Bool
ascend state (Value value0) (Value value1) siblings =
  if value0 == value1
  then descend state siblings
  else return False
ascend state (Node node0) (Node node1) siblings =
  case memberNodes (node0, node1) (stateCycles state) of
    EqualNodes -> descend state siblings
    HalfEqualNodes -> return False
    NotEqualNodes -> do
      (_, bound0) <- getBounds node0
      (_, bound1) <- getBounds node1
      if bound0 == bound1
        then
          let turtleNodes = currentTurtleNodes state
              state' = state { stateFrames =
                                  newFrame state node0 node1 siblings :
                                  stateFrames state,
                               stateDepth = (stateDepth state) + 1 }
          in case turtleNodes of
               Just turtleNodes' -> 
                 case equalNodes (node0, node1) turtleNodes' of
                   EqualNodes -> beginRecovery state node0 node1 siblings
                   HalfEqualNodes -> return False
                   NotEqualNodes -> ascendFirst state' node0 node1
               Nothing -> ascendFirst state' node0 node1
        else return False
ascend _ _ _ _ = return False

ascendFirst :: Eq a => EqState a -> Node a -> Node a -> IO Bool
ascendFirst state node0 node1 = do
  (_, bound) <- getBounds node0
  tree0 <- readArray node0 0
  tree1 <- readArray node1 0
  if bound > 0
    then let siblings = Siblings { siblingNodes = (node0, node1),
                                   siblingIndex = 1 }
         in ascend state tree0 tree1 (Just siblings)
    else ascend state tree0 tree1 Nothing

descend :: Eq a => EqState a -> Maybe (Siblings a) -> IO Bool
descend state Nothing =
  case stateFrames state of
    [] -> return True
    frame : rest ->
      let state' = state { stateFrames = rest,
                           stateDepth = stateDepth state - 1 }
      in descend state' (frameSiblings frame)
descend state (Just Siblings{..}) = do
  let (node0, node1) = siblingNodes
  (_, bound) <- getBounds node0
  tree0 <- readArray node0 siblingIndex
  tree1 <- readArray node1 siblingIndex
  if siblingIndex < bound
    then let siblings' = Siblings { siblingNodes = (node0, node1),
                                    siblingIndex = siblingIndex + 1 }
         in ascend state tree0 tree1 (Just siblings')
    else ascend state tree0 tree1 Nothing

beginRecovery :: Eq a => EqState a -> Node a -> Node a -> Maybe (Siblings a)
                 -> IO Bool
beginRecovery state node0 node1 siblings =
  let turtle = case stateFrames state of
                 [] -> error "must have first frame in stack"
                 frame : _ -> frameTurtle frame
      distance = (stateDepth state + 1) - turtleDepth turtle
      unrolledFrame = Unrolled { unrolledNodes = (node0, node1),
                                 unrolledState = state,
                                 unrolledSiblings = siblings }
  in unrolledFrame `seq` unrollCycle state [unrolledFrame] (distance - 1)

unrollCycle :: Eq a => EqState a -> [Unrolled a] -> Int -> IO Bool
unrollCycle state unrolled !count
  | count <= 0 = findCycleStart state unrolled
  | otherwise =
      case stateFrames state of
        [] -> error "frame must be found"
        frame : rest ->
          let state' = state { stateFrames = rest,
                               stateDepth = stateDepth state - 1 }
              unrolledFrame =
                Unrolled { unrolledNodes = frameNodes frame,
                           unrolledState = state',
                           unrolledSiblings = frameSiblings frame }
          in unrolledFrame `seq`
             unrollCycle state' (unrolledFrame : unrolled) (count - 1)

findCycleStart :: Eq a => EqState a -> [Unrolled a] -> IO Bool
findCycleStart state unrolled =
  case stateFrames state of
    [] ->
      return True
    frame : [] ->
      case memberUnrolled (frameNodes frame) unrolled of
        (NotEqualNodes, _) -> error "node not in nodes unrolled"
        (HalfEqualNodes, _) -> return False
        (EqualNodes, Just (state, siblings)) ->
          let state' =
                state { stateCycles = frameNodes frame : stateCycles state }
          in state' `seq` descend state' siblings
    frame : rest@(prevFrame : _) ->
      case memberUnrolled (frameNodes prevFrame) unrolled of
        (EqualNodes, _) ->
          let state' = state { stateFrames = rest,
                               stateDepth = stateDepth state - 1 }
              unrolledFrame =
                Unrolled { unrolledNodes = frameNodes frame,
                           unrolledState = state',
                           unrolledSiblings = frameSiblings frame }
              unrolled' = updateUnrolled unrolledFrame unrolled
          in unrolledFrame `seq` findCycleStart state' unrolled'
        (HalfEqualNodes, _) -> return False
        (NotEqualNodes, _) ->
          case memberUnrolled (frameNodes frame) unrolled of
            (NotEqualNodes, _) -> error "node not in nodes unrolled"
            (HalfEqualNodes, _) -> return False
            (EqualNodes, Just (state, siblings)) ->
              let state' =
                    state { stateCycles = frameNodes frame : stateCycles state }
              in state' `seq` descend state' siblings

updateUnrolled :: Unrolled a -> [Unrolled a] -> [Unrolled a]
updateUnrolled _ [] = []
updateUnrolled unrolled0 (unrolled1 : rest) =
  case equalNodes (unrolledNodes unrolled0) (unrolledNodes unrolled1) of
    EqualNodes -> unrolled0 : rest
    NotEqualNodes -> unrolled1 : updateUnrolled unrolled0 rest
    HalfEqualNodes -> error "this should not be possible"

memberUnrolled :: (Node a, Node a) -> [Unrolled a] ->
                  (NodeComparison, Maybe (EqState a, Maybe (Siblings a)))
memberUnrolled _ [] = (NotEqualNodes, Nothing)
memberUnrolled nodes (Unrolled{..} : rest) =
  case equalNodes nodes unrolledNodes of
    EqualNodes -> (EqualNodes, Just (unrolledState, unrolledSiblings))
    HalfEqualNodes -> (HalfEqualNodes, Nothing)
    NotEqualNodes -> memberUnrolled nodes rest

newFrame :: EqState a -> Node a -> Node a -> Maybe (Siblings a) -> Frame a
newFrame state node0 node1 siblings =
  let turtle =
        if (stateDepth state + 1) == nextTurtleDepth state
        then Turtle { turtleDepth = stateDepth state + 1,
                      turtleScale = currentTurtleScale state * 2, 
                      turtleNodes = (node0, node1) }
        else case stateFrames state of
               [] -> Turtle { turtleDepth = 1, turtleScale = 2,
                              turtleNodes = (node0, node1) }
               frame : _ -> frameTurtle frame
  in Frame { frameNodes = (node0, node1),
             frameSiblings = siblings,
             frameTurtle = turtle }

memberNodes :: (Node a, Node a) -> [(Node a, Node a)] -> NodeComparison
memberNodes _ [] = NotEqualNodes
memberNodes nodes0 (nodes1 : rest) =
  case equalNodes nodes0 nodes1 of
    NotEqualNodes -> memberNodes nodes0 rest
    HalfEqualNodes -> HalfEqualNodes
    EqualNodes -> EqualNodes

equalNodes :: (Node a, Node a) -> (Node a, Node a) -> NodeComparison
equalNodes (node0, node1) (node2, node3) =
  if node0 == node2
  then if node1 == node3
       then EqualNodes
       else HalfEqualNodes
  else if node1 == node3
       then HalfEqualNodes
       else NotEqualNodes

currentTurtleNodes :: EqState a -> Maybe (Node a, Node a)
currentTurtleNodes state =
  case stateFrames state of
    [] -> Nothing
    frame : _ -> Just . turtleNodes . frameTurtle $ frame

currentTurtleScale :: EqState a -> Int
currentTurtleScale state =
  case stateFrames state of
    [] -> 1
    frame : _ -> turtleScale $ frameTurtle frame

nextTurtleDepth :: EqState a -> Int
nextTurtleDepth state =
  case stateFrames state of
    [] -> 1
    frame : _ -> let turtle = frameTurtle frame
                 in turtleDepth turtle + turtleScale turtle
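The cycle checks in `equalNodes` and `memberNodes` depend on `(==)` for `IOArray` comparing array identity rather than contents. The following sketch, with hypothetical helper names `compareArrays` and `identityDemo`, shows that distinction in isolation: an array equals itself, while a second array with identical elements does not.

```haskell
import Data.Array.IO (IOArray)
import Data.Array.MArray (newListArray, getElems)

-- Returns (identity equality, element-wise equality) for two arrays.
compareArrays :: IOArray Int Int -> IOArray Int Int -> IO (Bool, Bool)
compareArrays a b = do
  sameContents <- (==) <$> getElems a <*> getElems b
  return (a == b, sameContents)

-- Two separately allocated arrays holding the same elements.
identityDemo :: IO [(Bool, Bool)]
identityDemo = do
  a <- newListArray (0, 2) [1, 2, 3]
  b <- newListArray (0, 2) [1, 2, 3]
  sequence [compareArrays a a, compareArrays a b]
```

Running `identityDemo` in GHCi should return `[(True,True),(False,True)]`. This identity semantics is what lets the algorithm treat a node pair it has already seen as a cycle, without ever comparing node contents.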

Here is a naive version of the algorithm, which the test program uses as a reference implementation:

{-# LANGUAGE RecordWildCards #-}

module NaiveEqualTree (Tree(..),
                       naiveEqual)
       where

import Data.Array.IO (IOArray)
import Data.Array.MArray (readArray,
                          getBounds)

import EqualTree (Tree(..),
                  Node)

data Frame a = Frame { frameNodes :: !(Node a, Node a),
                       frameSiblings :: !(Maybe (Siblings a)) }

data Siblings a = Siblings { siblingNodes :: !(Node a, Node a),
                             siblingIndex :: !Int }

data NodeComparison = EqualNodes | NotEqualNodes | HalfEqualNodes

naiveEqual :: Eq a => Tree a -> Tree a -> IO Bool
naiveEqual tree0 tree1 = ascend [] tree0 tree1 Nothing

ascend :: Eq a => [Frame a] -> Tree a -> Tree a -> Maybe (Siblings a) -> IO Bool
ascend state (Value value0) (Value value1) siblings =
  if value0 == value1
  then descend state siblings
  else return False
ascend state (Node node0) (Node node1) siblings =
  case testNodes (node0, node1) state of
    EqualNodes -> descend state siblings
    HalfEqualNodes -> return False
    NotEqualNodes -> do
      (_, bound0) <- getBounds node0
      (_, bound1) <- getBounds node1
      if bound0 == bound1
        then do
          let frame = Frame { frameNodes = (node0, node1),
                              frameSiblings = siblings }
              state' = frame : state
          tree0 <- readArray node0 0
          tree1 <- readArray node1 0
          if bound0 > 0
            then let siblings = Siblings { siblingNodes = (node0, node1),
                                           siblingIndex = 1 }
                 in frame `seq` ascend state' tree0 tree1 (Just siblings)
            else frame `seq` ascend state' tree0 tree1 Nothing
        else return False
ascend _ _ _ _ = return False

descend :: Eq a => [Frame a] -> Maybe (Siblings a) -> IO Bool
descend state Nothing =
  case state of
    [] -> return True
    frame : rest -> descend rest (frameSiblings frame)
descend state (Just Siblings{..}) = do
  let (node0, node1) = siblingNodes
  (_, bound) <- getBounds node0
  tree0 <- readArray node0 siblingIndex
  tree1 <- readArray node1 siblingIndex
  if siblingIndex < bound
    then let siblings' = Siblings { siblingNodes = (node0, node1),
                                    siblingIndex = siblingIndex + 1 }
         in ascend state tree0 tree1 (Just siblings')
    else ascend state tree0 tree1 Nothing

testNodes :: (Node a, Node a) -> [Frame a] -> NodeComparison
testNodes _ [] = NotEqualNodes
testNodes nodes (frame : rest) =
  case equalNodes nodes (frameNodes frame) of
    NotEqualNodes -> testNodes nodes rest
    HalfEqualNodes -> HalfEqualNodes
    EqualNodes -> EqualNodes

equalNodes :: (Node a, Node a) -> (Node a, Node a) -> NodeComparison
equalNodes (node0, node1) (node2, node3) =
  if node0 == node2
  then if node1 == node3
       then EqualNodes
       else HalfEqualNodes
  else if node1 == node3
       then HalfEqualNodes
       else NotEqualNodes
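To make the reference algorithm easier to experiment with, here is a pure model of the same ancestor-path check. It assumes a hypothetical representation in which each graph is a `Data.Map` from node ids to slots (`Left` for a value, `Right` for a child node id, as in the test program's mappings), and it uses recursion instead of an explicit frame stack. `classify` mirrors `testNodes` and `equalNodes`: a pair already on the current path counts as equal, and a pair matching on only one side fails the comparison.

```haskell
import qualified Data.Map as Map

-- Hypothetical pure representation: node id -> slots, where Left is a
-- leaf value and Right is the id of a child node.
type Graph = Map.Map Int [Either Int Int]

data Cmp = Same | Half | Different

-- Mirrors testNodes/equalNodes: scan the path (most recent pair first)
-- for a pair sharing either node with the current one.
classify :: (Int, Int) -> [(Int, Int)] -> Cmp
classify _ [] = Different
classify (a, b) ((c, d) : rest)
  | a == c && b == d = Same
  | a == c || b == d = Half
  | otherwise = classify (a, b) rest

-- Compare the structures rooted at r0 in g0 and r1 in g1.
naiveEqualGraph :: Graph -> Int -> Graph -> Int -> Bool
naiveEqualGraph g0 r0 g1 r1 = goNode [] (r0, r1)
  where
    goNode path pair@(n0, n1) =
      case classify pair path of
        Same -> True   -- pair already on the path: cycle closed in lockstep
        Half -> False  -- one node reappears paired with a different node
        Different ->
          let s0 = g0 Map.! n0
              s1 = g1 Map.! n1
          in length s0 == length s1
             && and (zipWith (goSlot (pair : path)) s0 s1)
    goSlot _ (Left v0) (Left v1) = v0 == v1
    goSlot path (Right n0) (Right n1) = goNode path (n0, n1)
    goSlot _ _ _ = False
```

A single self-looping node, `Map.fromList [(0, [Left 1, Right 0])]`, compares equal to itself, but not to the two-node loop `Map.fromList [(0, [Left 1, Right 1]), (1, [Left 1, Right 0])]`, because node 0 of the first graph ends up paired with two different nodes of the second.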

Here is the code of the test program. Note that the inequality test will occasionally generate two trees that are actually equal, because it is designed to generate sets of nodes with a significant degree of commonality, as controlled by commonPortionRange; such cases are detected with naiveEqual and counted as excluded.

{-# LANGUAGE TupleSections #-}

module Main where

import Data.Array (Array,
                   listArray,
                   bounds,
                   (!))
import Data.Array.IO (IOArray)
import Data.Array.MArray (writeArray,
                          newArray_)
import Control.Monad (forM_,
                      mapM,
                      mapM_,
                      liftM,
                      foldM)
import Control.Exception (SomeException,
                          catch)
import System.Random (StdGen,
                      newStdGen,
                      random,
                      randomR,
                      split)
import Prelude hiding (catch)

import EqualTree (Tree(..),
                  equal)
import NaiveEqualTree (naiveEqual)

leafChance :: Double
leafChance = 0.5

valueCount :: Int
valueCount = 1

maxNodeCount :: Int
maxNodeCount = 1024

commonPortionRange :: (Double, Double)
commonPortionRange = (0.8, 0.9)

commonRootChance :: Double
commonRootChance = 0.5

nodeSizeRange :: (Int, Int)
nodeSizeRange = (2, 5)

testCount :: Int
testCount = 1000

makeMapping :: Int -> (Int, Int) -> Int -> StdGen ->
               ([Either Int Int], StdGen)
makeMapping values range nodes gen =
  let (count, gen') = randomR range gen
  in makeMapping' 0 [] count gen'
  where makeMapping' index mapping count gen
          | index >= count = (mapping, gen)
          | otherwise =
            let (chance, gen0) = random gen
                (slot, gen2) =
                  if chance <= leafChance
                  then let (value, gen1) = randomR (0, values - 1) gen0
                       in (Left value, gen1)
                  else let (nodeIndex, gen1) = randomR (0, nodes - 1) gen0
                       in (Right nodeIndex, gen1)
            in makeMapping' (index + 1) (slot : mapping) count gen2

makeMappings :: Int -> Int -> (Int, Int) -> StdGen ->
                ([[Either Int Int]], StdGen)
makeMappings size values range gen =
  let (size', gen') = randomR (1, size) gen
  in makeMappings' 0 size' [] gen'
  where makeMappings' index size mappings gen
          | index >= size = (mappings, gen)
          | otherwise =
            let (mapping, gen') = makeMapping values range size gen
            in makeMappings' (index + 1) size (mapping : mappings) gen'

makeMappingsPair :: Int -> (Double, Double) -> Int -> (Int, Int) -> StdGen ->
                    ([[Either Int Int]], [[Either Int Int]], StdGen)
makeMappingsPair size commonPortionRange values range gen =
  let (size', gen0) = randomR (2, size) gen
      (commonPortion, gen1) = randomR commonPortionRange gen0
      size0 = 1 + (floor $ fromIntegral size' * commonPortion)
      size1 = size' - size0
      (mappings, gen2) = makeMappingsPair' 0 size0 size' [] gen1
      (mappings0, gen3) = makeMappingsPair' 0 size1 size' [] gen2
      (mappings1, gen4) = makeMappingsPair' 0 size1 size' [] gen3
      (commonRootValue, gen5) = random gen4
  in if commonRootValue < commonRootChance
     then (mappings ++ mappings0, mappings ++ mappings1, gen5)
     else (mappings0 ++ mappings, mappings1 ++ mappings, gen5)
  where makeMappingsPair' index size size' mappings gen
          | index >= size = (mappings, gen)
          | otherwise =
            let (mapping, gen') = makeMapping values range size' gen
            in makeMappingsPair' (index + 1) size size' (mapping : mappings)
               gen'

populateNode :: IOArray Int (Tree a) -> Array Int (IOArray Int (Tree a)) ->
                [Either a Int] -> IO ()
populateNode node nodes mapping =
  mapM_ (uncurry populateSlot) (zip [0..] mapping)
  where populateSlot index (Left value) =
          writeArray node index $ Value value
        populateSlot index (Right nodeIndex) =
          writeArray node index . Node $ nodes ! nodeIndex

makeTree :: [[Either a Int]] -> IO (Tree a)
makeTree mappings = do
  let size = length mappings
  nodes <- liftM (listArray (0, size - 1)) $ mapM makeNode mappings
  mapM_ (\(index, mapping) -> populateNode (nodes ! index) nodes mapping)
    (zip [0..] mappings)
  return . Node $ nodes ! 0
  where makeNode mapping = newArray_ (0, length mapping - 1)

testEqual :: StdGen -> IO (Bool, StdGen)
testEqual gen = do
  let (mappings, gen0) =
        makeMappings maxNodeCount valueCount nodeSizeRange gen
  tree0 <- makeTree mappings
  tree1 <- makeTree mappings
  catch (liftM (, gen0) $ equal tree0 tree1) $ \e -> do
    putStrLn $ show (e :: SomeException)
    return (False, gen0)

testNotEqual :: StdGen -> IO (Bool, Bool, StdGen)
testNotEqual gen = do
  let (mappings0, mappings1, gen0) =
        makeMappingsPair maxNodeCount commonPortionRange valueCount
        nodeSizeRange gen
  tree0 <- makeTree mappings0
  tree1 <- makeTree mappings1
  test <- naiveEqual tree0 tree1
  if not test
    then
      catch (testNotEqual' tree0 tree1 mappings0 mappings1 gen0) $ \e -> do
        putStrLn $ show (e :: SomeException)
        return (False, False, gen0)
    else return (True, True, gen0)
  where testNotEqual' tree0 tree1 mappings0 mappings1 gen0 = do
          test <- equal tree0 tree1
          if test
            then do
              putStrLn "Match failure: "
              putStrLn "Mappings 0: "
              mapM_ print $ zip [0..] mappings0
              putStrLn "Mappings 1: "
              mapM_ print $ zip [0..] mappings1
              return (False, False, gen0)
            else return (True, False, gen0)

doTestEqual :: (StdGen, Int) -> Int -> IO (StdGen, Int)
doTestEqual (gen, successCount) _ = do
  (success, gen') <- testEqual gen
  -- Force the counter so foldM does not build a chain of (+) thunks.
  let successCount' = successCount + (if success then 1 else 0)
  successCount' `seq` return (gen', successCount')

doTestNotEqual :: (StdGen, Int, Int) -> Int -> IO (StdGen, Int, Int)
doTestNotEqual (gen, successCount, excludeCount) _ = do
  (success, exclude, gen') <- testNotEqual gen
  let successCount' = successCount + (if success then 1 else 0)
      excludeCount' = excludeCount + (if exclude then 1 else 0)
  successCount' `seq` excludeCount' `seq`
    return (gen', successCount', excludeCount')

main :: IO ()
main = do
  gen <- newStdGen
  (gen0, equalSuccessCount) <- foldM doTestEqual (gen, 0) [1..testCount]
  putStrLn $ show equalSuccessCount ++ " out of " ++ show testCount ++
    " tests for equality passed"
  (_, notEqualSuccessCount, excludeCount) <-
    foldM doTestNotEqual (gen0, 0, 0) [1..testCount]
  putStrLn $ show notEqualSuccessCount ++ " out of " ++ show testCount ++
    " tests for inequality passed (with " ++ show excludeCount ++ " excluded)"
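Equal pairs showing up in the inequality test can be traced to the size arithmetic in `makeMappingsPair`. The hypothetical helper `splitSizes` below copies that arithmetic: out of `size'` nodes, `size0` are shared by both trees and `size1` are generated separately for each. When `size1` comes out as 0, both trees are built from the shared mappings alone and are therefore equal.

```haskell
-- Copies the arithmetic from makeMappingsPair: returns (size0, size1),
-- the number of shared nodes and of per-tree nodes out of size'.
splitSizes :: Int -> Double -> (Int, Int)
splitSizes size' commonPortion =
  let size0 = 1 + floor (fromIntegral size' * commonPortion)
  in (size0, size' - size0)
```

For example, `splitSizes 2 0.9` is `(2, 0)`, so for the smallest trees the "not equal" pair is always identical, while `splitSizes 1024 0.75` is `(769, 255)`.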

Answer by Travis Bemann:

It turns out that the problem was caused by a bug that kept the "unrolled" list from being updated properly, likely combined with live variables held by chains of thunks that were not always being forced. (That said, the problem went away as soon as I fixed the former, so a lack of strictness may not have been the main cause.)
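The second suspected factor, thunks retained through a rebuilt list, is a general pattern that can be ruled out independently of the bug. As a sketch only (the generic `replaceFirst` below is hypothetical and is not the fix that was actually made), a list update in the style of `updateUnrolled` can build its spine eagerly, so that no suspended recursive call keeps the old list and the replacement element alive.

```haskell
{-# LANGUAGE BangPatterns #-}

-- Replace the first element satisfying the predicate, as updateUnrolled
-- does for matching node pairs; if nothing matches, the list is unchanged.
-- The bang pattern forces each rebuilt tail before consing onto it, so
-- evaluating the result to weak head normal form constructs the whole
-- rebuilt prefix (the elements themselves are not forced).
replaceFirst :: (a -> Bool) -> a -> [a] -> [a]
replaceFirst matches new = go
  where
    go [] = []
    go (x : rest)
      | matches x = new : rest
      | otherwise = let !rest' = go rest in x : rest'
```

The trade-off is that the eager rebuild uses stack proportional to the position of the match, which is usually acceptable for lists the size of a cycle.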

The code in the original post has been updated to reflect the fixes made to it.