How to fold over a constructor with special cases?


So I have a tree that I want to collapse where the nodes are of type

data Node = Node1 Node | Node2 Node Node | ... deriving Data

except for a few special cases. I want to do something along the lines of

collapse SPECIALCASE1 = ...
collapse SPECIALCASE2 = ...
...
collapse node = concat $ gmapQ collapse node

where each special case generates a list of results and the last case just recursively collapses the children. This doesn't work, though: the first argument of gmapQ has to be of type forall d. Data d => d -> u, not Node -> u, which as far as I know limits you to functions that work on every instance of the Data class.

Is there any way of coercing the values in the problem to be of the correct type, or another more lenient map function perhaps?

Extra info:

The actual function, described above as collapse, is named validate. It traverses an abstract syntax tree (for a very simple language) to find unbound variables, and its special cases are handled like this:

validate _ (Nr _) = []
validate env (Let var val expr) = validate env val ++ validate (var:env) expr
validate env (Var var) = if elem var env then [] else [var]

These are essentially the rules that literal numbers contain no variables, let expressions bind a variable, and variables need to be checked against the environment to see whether they are bound. Every other construct in this toy language is just a combination of numbers and variables (e.g. summation, multiplication, etc.), so when I check for unbound variables I just need to traverse their sub-trees and combine the results; thus the gmapQ.

Extra info 2:

The actual data type used instead of the Node example above is of the form

data Ast = Nr Int
         | Sum Ast Ast
         | Mul Ast Ast
         | Min Ast
         | If Ast Ast Ast
         | Let String Ast Ast
         | Var String
           deriving (Show, Eq, Data)

There are 2 answers

Answer by K. A. Buhr (best answer)

The direct way to do what you want is to write your default case for validate as:

validate env expr = concat $ gmapQ ([] `mkQ` (validate env)) expr

This uses mkQ from Data.Generics.Aliases. The whole point of mkQ is to create queries of type forall d. Data d => d -> u that can operate differently on different Data instances. By the way, there's no magic here. You could have defined it manually in terms of cast as:

validate env expr = concat $ gmapQ myQuery expr
  where myQuery :: Data d => d -> [String]
        myQuery d = case cast d of Just d -> validate env d
                                   _ -> []
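
For a version that needs nothing beyond base, here is a minimal sketch of the cast-based query as a complete program. The helper names fieldTypes and query are made up for illustration, and the second test expression is just an example input. fieldTypes shows that gmapQ visits every immediate field, including the String in Let, which is why the query has to accept any Data instance and not only Ast.

```haskell
{-# LANGUAGE DeriveDataTypeable #-}
module Main where

import Data.Data (Data, cast, gmapQ, typeOf)

data Ast = Nr Int
         | Sum Ast Ast
         | Mul Ast Ast
         | Min Ast
         | If Ast Ast Ast
         | Let String Ast Ast
         | Var String
           deriving (Show, Eq, Data)

-- Hypothetical helper: the type of each immediate field of a node.
-- gmapQ hands the query every field, whatever its type.
fieldTypes :: Ast -> [String]
fieldTypes = gmapQ (\d -> show (typeOf d))

validate :: [String] -> Ast -> [String]
validate _   (Nr _)             = []
validate env (Let var val expr) = validate env val ++ validate (var : env) expr
validate env (Var var)          = [var | var `notElem` env]
-- Default case: recurse into the fields that are Ast, ignore the rest.
validate env expr               = concat (gmapQ query expr)
  where
    query :: Data d => d -> [String]
    query d = case cast d of
      Just ast -> validate env ast  -- the field is an Ast
      Nothing  -> []                -- some other type (Int, String, ...)

main :: IO ()
main = do
  print (fieldTypes (Let "x" (Nr 1) (Var "x")))
  -- prints ["[Char]","Ast","Ast"]
  print (validate [] (Sum (Var "a") (Let "b" (Nr 1) (Mul (Var "b") (Var "c")))))
  -- prints ["a","c"]
```

The mkQ version behaves the same way; mkQ just packages up the cast and the fallback value.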

Still, I've generally found it clearer to use uniplate from the lens library. The idea is to create a default Plated instance:

instance Plated Ast where
  plate = uniplate   -- uniplate from Data.Data.Lens 

which magically defines children :: Ast -> [Ast] to return all direct descendants of a node. You can then write your default validate case as:

validate env expr = concatMap (validate env) (children expr)

The full code, with a test that prints ["z"]:

{-# LANGUAGE DeriveDataTypeable #-}

module SpecialCase where

import Control.Lens.Plated
import Data.Data
import Data.Data.Lens (uniplate)

data Ast = Nr Int
         | Sum Ast Ast
         | Mul Ast Ast
         | Min Ast
         | If Ast Ast Ast
         | Let String Ast Ast
         | Var String
           deriving (Show, Eq, Data)

instance Plated Ast where
  plate = uniplate

validate env (Let var val expr) = validate env val ++ validate (var:env) expr
validate env (Var var) = if elem var env then [] else [var]
-- either use this uniplate version:
validate env expr = concatMap (validate env) (children expr)
-- or use the alternative, lens-free version
-- (needs mkQ, e.g. via: import Data.Generics.Aliases (mkQ)):
-- validate env expr = concat $ gmapQ ([] `mkQ` (validate env)) expr

main = print $ validate [] (Let "x" (Nr 3) (Let "y" (Var "x") 
             (Sum (Mul (Var "x") (Var "z")) (Var "y"))))

Answer by dfeuer

I'm sorry I was too slow to get a Data-based answer written before K. A. Buhr jumped on it. Here's another approach, based on recursion-schemes.

First, the boilerplate:

{-# LANGUAGE TemplateHaskell, TypeFamilies
           , DeriveTraversable #-}

import Data.Functor.Foldable
import Data.Functor.Foldable.TH

data Ast = Nr Int
         | Sum Ast Ast
         | Mul Ast Ast
         | Min Ast
         | If Ast Ast Ast
         | Let String Ast Ast
         | Var String
         deriving (Show, Eq)

makeBaseFunctor ''Ast

This creates a type AstF that takes the recursion out of Ast. It looks like this:

data AstF ast = NrF Int
              | SumF ast ast
              | MulF ast ast
              ....
              deriving (Functor,Foldable,Traversable)

It then also creates several instances. We'll be using two of the auto-generated instances: the Recursive instance of Ast to recursively validate the tree, and the Foldable instance of AstF to concatenate the results from the children in the default case.
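
To get a feel for those two instances, here is a small sketch, assuming the recursion-schemes package is installed. The name immediateChildren is made up for this illustration. project (from the Recursive instance) peels one layer off an Ast, giving an AstF Ast, and toList (from the Foldable instance) collects the recursive positions of that layer.

```haskell
{-# LANGUAGE TemplateHaskell, TypeFamilies, DeriveTraversable #-}
module Main where

import Data.Foldable (toList)
import Data.Functor.Foldable (project)
import Data.Functor.Foldable.TH (makeBaseFunctor)

data Ast = Nr Int
         | Sum Ast Ast
         | Mul Ast Ast
         | Min Ast
         | If Ast Ast Ast
         | Let String Ast Ast
         | Var String
         deriving (Show, Eq)

makeBaseFunctor ''Ast

-- Hypothetical helper: project turns an Ast into an AstF Ast, and the
-- Foldable instance of AstF lists only the recursive (ast) fields.
immediateChildren :: Ast -> [Ast]
immediateChildren = toList . project

main :: IO ()
main = do
  print (immediateChildren (Sum (Nr 1) (Var "x")))      -- [Nr 1,Var "x"]
  print (immediateChildren (Let "x" (Nr 3) (Var "x")))  -- [Nr 3,Var "x"]
  print (immediateChildren (Nr 7))                      -- []
```

Note that the String in Let is not a recursive position, so it does not show up in the list.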

I found it helpful to create a separate type for environments; this is quite optional.

newtype Env = Env {getEnv :: [String]}

emptyEnv :: Env
emptyEnv = Env []

extendEnv :: String -> Env -> Env
extendEnv a (Env as) = Env (a : as)

isFree :: String -> Env -> Bool
isFree a (Env as) = not (elem a as)

Now we can get down to business, using the Recursive instance of Ast to get cata for free.

validate :: Env -> Ast -> [String]
validate env0 ast0 = cata go ast0 env0
  where
    go :: AstF (Env -> [String]) -> Env -> [String]
    go (LetF var val expr) env = val env ++ expr (extendEnv var env)
    go (VarF var) env = [var | isFree var env]
    go expr env = foldMap id expr env
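
The default case may look odd at first. It relies on the Monoid instance for functions in base, under which (f <> g) env = f env <> g env, so foldMap id hands the same environment to every child and concatenates the results. Here is a minimal sketch of just that step, using a plain list in place of AstF; childA, childB and combined are stand-ins invented for the example.

```haskell
module Main where

-- Stand-ins for two already-folded children: each one waits for an
-- environment and then reports its unbound variables.
childA, childB :: [String] -> [String]
childA env = [v | v <- ["x"], v `notElem` env]
childB env = [v | v <- ["y"], v `notElem` env]

-- foldMap id over functions uses the instance
--   Monoid b => Monoid (a -> b)
-- so the combined function passes env to each child and appends.
combined :: [String] -> [String]
combined = foldMap id [childA, childB]

main :: IO ()
main = do
  print (combined [])     -- ["x","y"]
  print (combined ["x"])  -- ["y"]
```

In the real go, the container is an AstF (Env -> [String]) instead of a list, but the folding works the same way.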