Moving an ST computation to a sub-computation

22 views Asked by At

I have the following code running in ST using random numbers:

module Main

import Data.Vect
import Data.Fin

import Control.ST
import Control.ST.Random

%default total

-- Placeholder selection: hands back the first `n` elements together with the
-- remaining `k`, in order. The random resource `rnd` is not consulted yet.
choose : (Monad m) => (rnd : Var) -> (n : Nat) -> Vect (n + k) a -> ST m (Vect n a, Vect k a) [rnd ::: Random]
choose rnd count items = pure (splitAt count items)

-- Splits a vector according to a predicate while keeping track of both
-- result lengths. The result is a dependent pair holding the two lengths
-- (m, k), a proof that m + k = n, the elements for which `p` holds
-- (length m) and the elements for which it does not (length k).
partitionLen : (a -> Bool) -> Vect n a -> DPair (Nat, Nat) (\(m, k) => (m + k = n, Vect m a, Vect k a))
-- Empty input: both halves are empty and the length equation holds by Refl.
partitionLen p [] = ((0, 0) ** (Refl, [], []))
-- Non-empty input: partition the tail first, then put the head on one side.
partitionLen p (x :: xs) = case partitionLen p xs of
    ((m, k) ** (prf, lefts, rights)) =>
      if p x then
        -- Head satisfies `p`: it goes to the left half; `cong prf` lifts the
        -- tail's length proof under the successor.
        ((S m, k) ** (cong prf, x::lefts, rights))
      else
        -- Head fails `p`: it goes to the right half; `plusS` first moves the
        -- successor out of the right summand so that `cong prf` applies.
        ((m, S k) ** (trans (plusS m k) (cong prf), lefts, x::rights))
  where
    -- A successor on the right summand can be pulled out of the sum,
    -- proved by recursion on the left summand.
    plusS : (x : Nat) -> (y : Nat) -> x + (S y) = S (x + y)
    plusS Z y = Refl
    plusS (S x) y = cong (plusS x y)

-- Builds a vector of 25 indices (each a `Fin 25`) by repeatedly splitting
-- the index set with `choose` and concatenating selected pieces. Runs in ST
-- with the random resource `rnd`. The lengths noted below follow from the
-- type signatures of `choose` and `indices`.
generate : (Monad m) => (rnd : Var) -> ST m (Vect 25 (Fin 25)) [rnd ::: Random]
generate rnd = do
    -- 25 indices -> 4 shared + 21 nonshared.
    (shared, nonshared) <- choose rnd 4 indices

    -- 21 nonshared -> 6 agents1 + 15 nonagents1; 15 nonagents1 -> 6 agents2 + 9 nonagents2.
    (agents1, nonagents1) <- choose rnd 6 nonshared
    (agents2, nonagents2) <- choose rnd 6 nonagents1

    -- 2 assassins1 drawn from the same 15 nonagents1.
    -- NOTE(review): others1 is bound but never used in this function.
    (assassins1, others1) <- choose rnd 2 nonagents1

    -- Split the 9 nonagents2 into those that are also in assassins1 (xs, length n)
    -- and the rest (ys, length k); prf states n + k = 9.
    case partitionLen (`elem` assassins1) nonagents2 of
      ((n, k) ** (prf, xs, ys)) => do
        -- 2 assassins2 drawn from agents1 ++ xs; others2' has the remaining 4 + n.
        (assassins2, others2') <- choose rnd 2 (agents1 ++ xs)
        -- prf' re-associates the sum and applies prf to get (4 + n) + k = 13.
        let prf' = trans (sym $ plusAssociative 4 n k) $ cong {f = (+) 4} prf
        -- Use prf' to retype the concatenation as a vector of length 13.
        let others2 = the (Vect 13 (Fin 25)) $ replace {P = \n => Vect n (Fin 25)} prf' (others2' ++ ys)

        -- 4 + 6 + 2 + 13 = 25 elements in total.
        pure $ shared ++ agents2 ++ assassins2 ++ others2
  where
    -- The 25 indices 0 through 24, in ascending order.
    indices : Vect 25 (Fin 25)
    indices = fromList [0..24]

I'd like to refactor generate so that, instead of the whole tail of the computation being under a case, I compute (assassins2, others2) in a sub-computation, i.e. I would like to rewrite it as follows:

    (assassins2, others2) <- case partitionLen (`elem` assassins1) nonagents2 of
      ((n, k) ** (prf, xs, ys)) => do
        (assassins2, others2') <- choose rnd 2 (agents1 ++ xs)
        let prf' = trans (sym $ plusAssociative 4 n k) $ cong {f = (+) 4} prf
        let others2 = replace {P = \n => Vect n (Fin 25)} prf' (others2' ++ ys)
        pure (assassins2, others2)
    pure $ shared ++ agents2 ++ assassins2 ++ others2

I believe this should be an equivalent transformation. However, this second version fails type-checking with:

     When checking right hand side of Main.case block in case block in case block in case block in case block in generate with expected type
             STrans m
                    (Vect 25 (Fin 25))
                    (st2_fn (assassins2, others2))
                    (\result => [rnd ::: State Integer])

     When checking argument ys to function Data.Vect.++:
             Type mismatch between
                     B (Type of others2)
             and
                     Vect n (Fin 25) (Expected type)
0

There are 0 answers