I have a library that currently requires users to provide a helper function with type:

tEnum :: (KnownNat n) => MyType -> Finite n

so that the library implementation can use a very efficient sized vector representation of a function with type:

foo :: MyType -> a

(MyType is discrete and finite.)

Assuming that deriving a Generic instance for MyType is possible, is there a way to generate tEnum automatically, thus lifting that burden from my library's users?

I would also like to go the other way; that is, automatically derive:

tGen :: (KnownNat n) => Finite n -> MyType

Lucy Maya Menon On BEST ANSWER

I have something working for at least the tEnum side of things. Since you did not specify your representation of Finite, I used my own Finite and Nat.

I have included a full code snippet with an example at the bottom of the post, but will only discuss the generic programming parts, leaving out the reasonably standard construction of Peano arithmetic and various useful theorems about it.

A typeclass keeps track of the types that can be converted to and from these finite enums. The important parts are the default type signatures and the default definitions: if someone declares an EnumFin instance for a type that derives Generic, they don't have to write any code, because these defaults will be used. The defaults call methods from a second class, which is implemented for the various representation types that GHC.Generics can produce. Notice that both the normal and the default signatures use (n ~ ...) => ... n instead of writing the size of the Finite directly in the type. If the size were written directly, the default signatures would no longer match the regular ones, since an instance may define Size without defining fromFin or toFin, and GHC would reject them:

class EnumFin a where
  type Size a :: Nat
  type Size a = GSize (Rep a)

  toFin :: (n ~ Size a) => a -> Finite n
  default toFin :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
                => a -> Finite n
  toFin = gToFin . from

  fromFin :: (n ~ Size a) => Finite n -> a
  default fromFin :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
                  => Finite n -> a
  fromFin = to . gFromFin

The class also has two utility methods. The generic implementation uses them to get the minimum and maximum Finite n values produced by an instance (0 and n) without needing more typeclasses or propagating KnownNat-style constraints:

  zero :: (n ~ Size a) => Finite n
  default zero :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
               => Finite n
  zero = gzero @(Rep a)
  gt :: (n ~ Size a) => Finite n
  default gt :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
               => Finite n
  gt = ggt @(Rep a)
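
To see the default-signature mechanism in isolation, here is a minimal sketch that uses plain Int indices instead of Finite and only handles field-less enumerations. The names Index, GIndex, gCount, and Color are made up for this illustration and are not part of the code in this answer:

```haskell
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
import Data.Proxy (Proxy (..))
import GHC.Generics

-- Hypothetical class: position of a constructor in a plain enumeration.
class Index a where
  index :: a -> Int
  -- This signature applies only when an instance leaves `index` undefined.
  default index :: (Generic a, GIndex (Rep a)) => a -> Int
  index = gIndex . from

-- Generic worker, defined on the representation functors.
class GIndex f where
  gCount :: Proxy f -> Int   -- number of constructors under this node
  gIndex :: f p -> Int

-- A constructor without fields.
instance GIndex U1 where
  gCount _ = 1
  gIndex U1 = 0

-- A choice between constructors: the right side is offset by the left count.
instance (GIndex f, GIndex g) => GIndex (f :+: g) where
  gCount _ = gCount (Proxy :: Proxy f) + gCount (Proxy :: Proxy g)
  gIndex (L1 x) = gIndex x
  gIndex (R1 y) = gCount (Proxy :: Proxy f) + gIndex y

-- Metadata wrappers are skipped.
instance GIndex f => GIndex (M1 i c f) where
  gCount _ = gCount (Proxy :: Proxy f)
  gIndex (M1 x) = gIndex x

data Color = Red | Green | Blue deriving (Show, Generic)

-- Empty instance body: the generic default is picked up.
instance Index Color
```

With this loaded, `map index [Red, Green, Blue]` gives the constructors' positions in declaration order, and the user of the class only writes the one-line instance.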

The declaration of the generic class is fairly simple; note, however, that its parameter has kind * -> *, not *:

class GEnumFin f where
  type GSize f :: Nat
  gToFin :: f a -> Finite (GSize f)
  gFromFin :: Finite (GSize f) -> f a
  gzero :: Finite (GSize f)
  ggt :: Finite (GSize f)

This generic class must now be implemented for each of the relevant representation types. For example, U1 is a very simple one: it represents a constructor without fields, which is just encoded as the Finite number 0:

instance GEnumFin U1 where
  type GSize U1 = 'Z
  gToFin U1 = ZF ZS
  gFromFin (ZF ZS) = U1
  gzero = ZF ZS
  ggt = ZF ZS

:*: is used to combine individual fields, so both parts need to be encoded (it encodes lhs*(m+1)+rhs where m is the max value of the rhs):

instance forall a b. (GEnumFin a, GEnumFin b) => GEnumFin (a :*: b) where
  type GSize (a :*: b) = Plus (Times (GSize a) ('S (GSize b))) (GSize b)
  gToFin (a :*: b) = addFin (mulFin (gToFin a) (SF (ggt @b))) (gToFin b)
  gFromFin x = (gFromFin a :*: gFromFin b)
    where (a, b) = quotRemFin (toSN (ggt @a)) (toSN (ggt @b)) x
  gzero = addFin (mulFin (gzero @a) (SF (ggt @b))) (gzero @b)
  ggt = addFin (mulFin (ggt @a) (SF (ggt @b))) (ggt @b)
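
The arithmetic behind this instance may be easier to follow at the value level. The sketch below uses plain Int and two made-up helper names (packPair, unpackPair); it only illustrates the lhs*(m+1)+rhs formula and is not the type-indexed code above:

```haskell
-- Plain-Int sketch of the product encoding. `m` is the largest index the
-- right-hand component can take, so that component has m + 1 values.
packPair :: Int -> (Int, Int) -> Int
packPair m (l, r) = l * (m + 1) + r

-- Inverse of packPair, valid for l >= 0 and 0 <= r <= m.
unpackPair :: Int -> Int -> (Int, Int)
unpackPair m x = x `quotRem` (m + 1)
```

If the left index ranges over 0..n, the largest packed value is n*(m+1)+m, which is the bound that the GSize equation for :*: expresses with Plus and Times.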

:+:, on the other hand, is used to represent sums, so it must be able to encode either of its constituents (it encodes the left-hand side as 0..n and the right-hand side as n+1..n+1+m, where n and m are the max values of the left and right sides):

instance forall a b. (GEnumFin a, GEnumFin b) => GEnumFin (a :+: b) where
  type GSize (a :+: b) = 'S (Plus (GSize a) (GSize b))
  gToFin (L1 a) = case proofPlusComm (toSN (gzero @a)) (toSN (gzero @b)) of
                    Refl -> addFin (injFin (gzero @b)) (gToFin a)
  gToFin (R1 b) = addFin (SF (ggt @a)) (gToFin b)
  gFromFin x = case proofPlusComm (toSN (ggt @a)) (toSN (ggt @b)) of
                 Refl -> splitFin (toSN (ggt @b)) (toSN (ggt @a)) x
                                  (R1 . gFromFin @b) (L1 . gFromFin @a)
  gzero = addFin (injFin (gzero @a)) (gzero @b)
  ggt = addFin (SF (ggt @a)) (ggt @b)
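
As with products, the offset scheme can be sketched with plain Int values. The helper names packSum and unpackSum are made up for this illustration, and Either stands in for L1/R1:

```haskell
-- Plain-Int sketch of the sum encoding. `n` is the largest index on the
-- left, so the left occupies 0..n and the right starts at n + 1.
packSum :: Int -> Either Int Int -> Int
packSum _ (Left a)  = a
packSum n (Right b) = n + 1 + b

-- Inverse of packSum, valid for non-negative inputs.
unpackSum :: Int -> Int -> Either Int Int
unpackSum n x
  | x <= n    = Left x
  | otherwise = Right (x - n - 1)
```

If the right index ranges over 0..m, the largest packed value is n+1+m, which corresponds to the 'S (Plus (GSize a) (GSize b)) in the instance.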

There is also an important instance for a single constructor field, which requires that the contained type also implement EnumFin:

instance (EnumFin a) => GEnumFin (K1 i a) where
  type GSize (K1 i a) = Size a
  gToFin (K1 a) = toFin a
  gFromFin = K1 . fromFin
  gzero = zero @a
  ggt = gt @a

Finally, it is necessary to provide an instance for M1, which is used to attach metadata to the generic tree, and which we don't care about at all here:

instance forall i c a. (GEnumFin a) => GEnumFin (M1 i c a) where
  type GSize (M1 i c a) = GSize a
  gToFin (M1 a) = gToFin a
  gFromFin = M1 . gFromFin
  gzero = gzero @a
  ggt = ggt @a

For completeness, here is a full file that defines all of the Nat/Finite infrastructure used above and demonstrates the Generic implementation:

{-# LANGUAGE TypeInType #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveGeneric #-}
import GHC.Generics
import Data.Type.Equality

-- Fairly standard Peano naturals & various useful theorems about them:
data Nat = Z | S Nat
data SNat (n :: Nat) where
  ZS :: SNat 'Z
  SS :: SNat n -> SNat ('S n)
deriving instance Show (SNat n)

type family Plus (n :: Nat) (m :: Nat) where
  Plus 'Z m = m
  Plus ('S n) m = 'S (Plus n m)

plus :: SNat n -> SNat m -> SNat (Plus n m)
plus ZS m = m
plus (SS n) m = SS (plus n m)

proofPlusNZ :: SNat n -> Plus n 'Z :~: n
proofPlusNZ ZS = Refl
proofPlusNZ (SS n) = case proofPlusNZ n of Refl -> Refl

proofPlusNS :: SNat n -> SNat m -> Plus n ('S m) :~: 'S (Plus n m)
proofPlusNS ZS _ = Refl
proofPlusNS (SS n) m = case proofPlusNS n m of Refl -> Refl

proofPlusAssoc :: SNat n -> SNat m -> SNat o
               -> Plus n (Plus m o) :~: Plus (Plus n m) o
proofPlusAssoc ZS _ _ = Refl
proofPlusAssoc (SS n) ZS _ = case proofPlusNZ n of Refl -> Refl
proofPlusAssoc (SS n) (SS m) ZS =
  case proofPlusNZ m of
    Refl -> case proofPlusNZ (plus n (SS m)) of
      Refl -> Refl
proofPlusAssoc (SS n) (SS m) (SS o) =
  case proofPlusAssoc n (SS m) (SS o) of Refl -> Refl

proofPlusComm :: SNat n -> SNat m -> Plus n m :~: Plus m n
proofPlusComm ZS ZS = Refl
proofPlusComm ZS (SS m) = case proofPlusNZ m of Refl -> Refl
proofPlusComm (SS n) ZS = case proofPlusNZ n of Refl -> Refl
proofPlusComm (SS n) (SS m) =
  case proofPlusComm (SS n) m of
    Refl -> case proofPlusComm n (SS m) of
      Refl -> case proofPlusComm n m of
        Refl -> Refl

type family Times (n :: Nat) (m :: Nat) where
  Times 'Z m = 'Z
  Times ('S n) m = Plus m (Times n m)

times :: SNat n -> SNat m -> SNat (Times n m)
times ZS _ = ZS
times (SS n) m = plus m (times n m)

proofMultNZ :: SNat n -> Times n 'Z :~: 'Z
proofMultNZ ZS = Refl
proofMultNZ (SS n) = case proofMultNZ n of Refl -> Refl

proofMultNS :: SNat n -> SNat m -> Times n ('S m) :~: Plus n (Times n m)
proofMultNS ZS ZS = Refl
proofMultNS ZS (SS m) =
  case proofMultNZ (SS m) of
    Refl -> case proofMultNZ m of
      Refl -> Refl
proofMultNS (SS n) ZS =
  case proofMultNS n ZS of Refl -> Refl
proofMultNS (SS n) (SS m) =
  case proofMultNS (SS n) m of
    Refl -> case proofMultNS n (SS m) of
      Refl -> case proofMultNS n m of
        Refl -> case lemma1 n m (times n (SS m)) of
          Refl -> Refl
  where lemma1 :: SNat n -> SNat m -> SNat o -> Plus n ('S (Plus m o))
                                                :~:
                                                'S (Plus m (Plus n o))
        lemma1 n' m' o' =
          case proofPlusComm n' (SS (plus m' o')) of
            Refl -> case proofPlusComm m' (plus n' o') of
              Refl -> case proofPlusAssoc m' o' n' of
                Refl -> case proofPlusComm n' o' of
                  Refl -> Refl

proofMultSN :: SNat n -> SNat m -> Times ('S n) m :~: Plus (Times n m) m
proofMultSN ZS m = case proofPlusNZ m of Refl -> Refl
proofMultSN (SS n) m =
  case proofPlusNZ (times n m) of
    Refl -> case proofPlusComm m (plus m (plus (times n m) ZS)) of
      Refl -> Refl

proofMultComm :: SNat n -> SNat m -> Times n m :~: Times m n
proofMultComm ZS ZS = Refl
proofMultComm ZS (SS m) = case proofMultNZ (SS m) of
                            Refl -> case proofMultComm ZS m of
                              Refl -> Refl
proofMultComm (SS n) ZS = case proofMultComm n ZS of Refl -> Refl
proofMultComm (SS n) (SS m) =
  case proofMultNS n m of
    Refl -> case proofMultNS m n of
      Refl -> case proofPlusAssoc m n (times n m) of
        Refl -> case proofPlusAssoc n m (times m n) of
          Refl -> case proofPlusComm n m of
            Refl -> case proofMultComm n m of
              Refl -> Refl

-- `Finite n` represents a number in 0..n (inclusive).
--
-- Notice that the "zero" branch includes an `SNat`; this is useful to be
-- able to conveniently write `toSN` below (generally, to be able to
-- reflect the `n` component to the value level) without needing to use a
-- singleton typeclass & pass constraints around everywhere.
--
-- It should be possible to switch this out for other implementations of
-- `Finite` with different choices, but may require rewriting many of
-- the following functions.
data Finite (n :: Nat) where
  ZF :: SNat n -> Finite n
  SF :: Finite n -> Finite ('S n)
deriving instance Show (Finite n)

toSN :: Finite n -> SNat n
toSN (ZF sn) = sn
toSN (SF f) = SS (toSN f)

addFin :: forall n m. Finite n -> Finite m -> Finite (Plus n m)
addFin (ZF n) (ZF m) = ZF (plus n m)
addFin (ZF n) (SF b) =
  case proofPlusNS n (toSN b) of
    Refl -> SF (addFin (ZF n) b)
addFin (SF a) b = SF (addFin a b)

mulFin :: forall n m. Finite n -> Finite m -> Finite (Times n m)
mulFin (ZF n) (ZF m) = ZF (times n m)
mulFin (ZF n) (SF b) = case proofMultNS n (toSN b) of
                         Refl -> addFin (ZF n) (mulFin (ZF n) b)
mulFin (SF a) b = addFin b (mulFin a b)

quotRemFin :: SNat n -> SNat m -> Finite (Plus (Times n ('S m)) m)
        -> (Finite n, Finite m)
quotRemFin nn mm xx = go mm xx nn mm (ZF ZS) (ZF ZS)
  where go :: forall n m s p q r.
            (  Plus q s ~ n, Plus r p ~ m)
            => SNat m
            -> Finite (Plus (Times s ('S m)) p)
            -> SNat s
            -> SNat p
            -> Finite q
            -> Finite r
            -> (Finite n, Finite m)
        go _ (ZF _) s p q r = (addFin q (ZF s), addFin r (ZF p))
        go m (SF x) s (SS p) q r =
          case proofPlusComm (SS p) (times s m) of
            Refl -> case proofPlusNS (times s (SS m)) p of
              Refl -> case proofPlusNS (toSN r) p of
                Refl -> go m x s p q (SF r)
        go m (SF x) (SS s) ZS q _ =
          case proofPlusNS (toSN q) s of
            Refl -> case proofMultSN s (SS m) of
              Refl -> case proofPlusNS (times s (SS m)) m of
                Refl -> case proofPlusComm (times s (SS m)) (SS m) of
                  Refl -> case proofPlusNZ (times (SS s) (SS m)) of
                    Refl -> go m x s m (SF q) (ZF ZS)

splitFin :: forall n m a. SNat n -> SNat m -> Finite ('S (Plus n m))
         -> (Finite n -> a) -> (Finite m -> a) -> a
splitFin nn mm xx f g = go nn mm xx mm (ZF ZS)
  where go :: forall r s. (Plus r s ~ m)
           => SNat n -> SNat m -> Finite ('S (Plus n s))
           -> SNat s -> Finite r -> a
        go _ _ (ZF _) s r = g (addFin r (ZF s))
        go n m (SF x) (SS s) r =
          case proofPlusNS (toSN r) s of
            Refl -> case proofPlusNS n s of
              Refl -> go n m x s (SF r)
        go n _ (SF x) ZS _ = case proofPlusNZ n of Refl -> f x

injFin :: Finite n -> Finite ('S n)
injFin (ZF n) = ZF (SS n)
injFin (SF a) = SF (injFin a)

toNum :: (Num a) => Finite n -> a
toNum (ZF _) = 0
toNum (SF n) = 1 + toNum n

-- The actual classes & Generic stuff:
class EnumFin a where
  type Size a :: Nat
  type Size a = GSize (Rep a)

  toFin :: (n ~ Size a) => a -> Finite n
  default toFin :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
                => a -> Finite n
  toFin = gToFin . from

  fromFin :: (n ~ Size a) => Finite n -> a
  default fromFin :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
                  => Finite n -> a
  fromFin = to . gFromFin

  zero :: (n ~ Size a) => Finite n
  default zero :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
               => Finite n
  zero = gzero @(Rep a)
  gt :: (n ~ Size a) => Finite n
  default gt :: (Generic a, GEnumFin (Rep a), n ~ GSize (Rep a))
               => Finite n
  gt = ggt @(Rep a)
class GEnumFin f where
  type GSize f :: Nat
  gToFin :: f a -> Finite (GSize f)
  gFromFin :: Finite (GSize f) -> f a
  gzero :: Finite (GSize f)
  ggt :: Finite (GSize f)

instance GEnumFin U1 where
  type GSize U1 = 'Z
  gToFin U1 = ZF ZS
  gFromFin (ZF ZS) = U1
  gzero = ZF ZS
  ggt = ZF ZS

instance forall a b. (GEnumFin a, GEnumFin b) => GEnumFin (a :*: b) where
  type GSize (a :*: b) = Plus (Times (GSize a) ('S (GSize b))) (GSize b)
  gToFin (a :*: b) = addFin (mulFin (gToFin a) (SF (ggt @b))) (gToFin b)
  gFromFin x = (gFromFin a :*: gFromFin b)
    where (a, b) = quotRemFin (toSN (ggt @a)) (toSN (ggt @b)) x
  gzero = addFin (mulFin (gzero @a) (SF (ggt @b))) (gzero @b)
  ggt = addFin (mulFin (ggt @a) (SF (ggt @b))) (ggt @b)

instance forall a b. (GEnumFin a, GEnumFin b) => GEnumFin (a :+: b) where
  type GSize (a :+: b) = 'S (Plus (GSize a) (GSize b))
  gToFin (L1 a) = case proofPlusComm (toSN (gzero @a)) (toSN (gzero @b)) of
                    Refl -> addFin (injFin (gzero @b)) (gToFin a)
  gToFin (R1 b) = addFin (SF (ggt @a)) (gToFin b)
  gFromFin x = case proofPlusComm (toSN (ggt @a)) (toSN (ggt @b)) of
                 Refl -> splitFin (toSN (ggt @b)) (toSN (ggt @a)) x
                                  (R1 . gFromFin @b) (L1 . gFromFin @a)
  gzero = addFin (injFin (gzero @a)) (gzero @b)
  ggt = addFin (SF (ggt @a)) (ggt @b)

instance forall i c a. (GEnumFin a) => GEnumFin (M1 i c a) where
  type GSize (M1 i c a) = GSize a
  gToFin (M1 a) = gToFin a
  gFromFin = M1 . gFromFin
  gzero = gzero @a
  ggt = ggt @a

instance (EnumFin a) => GEnumFin (K1 i a) where
  type GSize (K1 i a) = Size a
  gToFin (K1 a) = toFin a
  gFromFin = K1 . fromFin
  gzero = zero @a
  ggt = gt @a

-- Demo:
data Foo = A | B deriving (Show, Generic)
data Bar = C | D deriving (Show, Generic)
data Baz = E Foo | F Bar | G Foo Bar deriving (Show, Generic)

instance EnumFin Foo
instance EnumFin Bar
instance EnumFin Baz

main :: IO ()
main = do
  putStrLn $ show $ toNum @Integer $ gt @Baz
  putStrLn $ show $ toNum @Integer $ toFin $ E A
  putStrLn $ show $ toNum @Integer $ toFin $ E B
  putStrLn $ show $ toNum @Integer $ toFin $ F C
  putStrLn $ show $ toNum @Integer $ toFin $ F D
  putStrLn $ show $ toNum @Integer $ toFin $ G A C
  putStrLn $ show $ toNum @Integer $ toFin $ G A D
  putStrLn $ show $ toNum @Integer $ toFin $ G B C
  putStrLn $ show $ toNum @Integer $ toFin $ G B D
  putStrLn $ show $ fromFin @Baz $ toFin $ E A
  putStrLn $ show $ fromFin @Baz $ toFin $ E B
  putStrLn $ show $ fromFin @Baz $ toFin $ F C
  putStrLn $ show $ fromFin @Baz $ toFin $ F D
  putStrLn $ show $ fromFin @Baz $ toFin $ G A C
  putStrLn $ show $ fromFin @Baz $ toFin $ G A D
  putStrLn $ show $ fromFin @Baz $ toFin $ G B C
  putStrLn $ show $ fromFin @Baz $ toFin $ G B D
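
As a cross-check on the numbering, and to show how such an enumeration can back a tabulated function like foo in the question, here is a value-level sketch of the same scheme. It is an assumption-laden analogue, not the code above: it uses Int indices instead of Finite, a Data.Array table instead of a sized vector, and made-up names (Index, GIndex, card, tabulate, apply).

```haskell
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
import Data.Array (Array, listArray, (!))
import Data.Proxy (Proxy (..))
import GHC.Generics

-- Hypothetical Int-indexed analogue of EnumFin. `card` is the number of
-- inhabitants, so valid indices are 0 .. card - 1.
class Index a where
  card :: Proxy a -> Int
  default card :: (GIndex (Rep a)) => Proxy a -> Int
  card _ = gCard (Proxy :: Proxy (Rep a))
  index :: a -> Int
  default index :: (Generic a, GIndex (Rep a)) => a -> Int
  index = gIndex . from
  unindex :: Int -> a
  default unindex :: (Generic a, GIndex (Rep a)) => Int -> a
  unindex = to . gUnindex

class GIndex f where
  gCard    :: Proxy f -> Int
  gIndex   :: f p -> Int
  gUnindex :: Int -> f p

instance GIndex U1 where
  gCard _ = 1
  gIndex U1 = 0
  gUnindex _ = U1

-- Sums: left indices first, right indices offset by the left cardinality.
instance (GIndex f, GIndex g) => GIndex (f :+: g) where
  gCard _ = gCard (Proxy :: Proxy f) + gCard (Proxy :: Proxy g)
  gIndex (L1 x) = gIndex x
  gIndex (R1 y) = gCard (Proxy :: Proxy f) + gIndex y
  gUnindex i
    | i < n     = L1 (gUnindex i)
    | otherwise = R1 (gUnindex (i - n))
    where n = gCard (Proxy :: Proxy f)

-- Products: mixed-radix encoding, with the right component as the low digit.
instance (GIndex f, GIndex g) => GIndex (f :*: g) where
  gCard _ = gCard (Proxy :: Proxy f) * gCard (Proxy :: Proxy g)
  gIndex (x :*: y) = gIndex x * gCard (Proxy :: Proxy g) + gIndex y
  gUnindex i = gUnindex q :*: gUnindex r
    where (q, r) = i `quotRem` gCard (Proxy :: Proxy g)

-- Fields defer to the Index instance of the field type.
instance Index c => GIndex (K1 i c) where
  gCard _ = card (Proxy :: Proxy c)
  gIndex (K1 x) = index x
  gUnindex = K1 . unindex

instance GIndex f => GIndex (M1 i t f) where
  gCard _ = gCard (Proxy :: Proxy f)
  gIndex (M1 x) = gIndex x
  gUnindex = M1 . gUnindex

-- Evaluate a function at every inhabitant and store the results.
tabulate :: forall a b. Index a => (a -> b) -> Array Int b
tabulate f = listArray (0, n - 1) [f (unindex i) | i <- [0 .. n - 1]]
  where n = card (Proxy :: Proxy a)

-- Look a value up in a table built by `tabulate`.
apply :: Index a => Array Int b -> a -> b
apply table x = table ! index x

data Foo = A | B deriving (Show, Generic)
data Bar = C | D deriving (Show, Generic)
data Baz = E Foo | F Bar | G Foo Bar deriving (Show, Generic)

instance Index Foo
instance Index Bar
instance Index Baz
```

In this sketch Baz has 8 inhabitants, `index` numbers them 0 through 7 in the order E A, E B, F C, F D, G A C, G A D, G B C, G B D, and `apply (tabulate f)` agrees with `f` on every value. The type-level version above tracks the maximum index (7) rather than the count (8), which is why its sizes are off by one relative to `card`.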