I have the following code running in ST
using random numbers:
module Main
import Data.Vect
import Data.Fin
import Control.ST
import Control.ST.Random
%default total
choose : (Monad m) => (rnd : Var) -> (n : Nat) -> Vect (n + k) a -> ST m (Vect n a, Vect k a) [rnd ::: Random]
choose rnd n xs = pure $ splitAt n xs -- Dummy implementation
partitionLen : (a -> Bool) -> Vect n a -> DPair (Nat, Nat) (\(m, k) => (m + k = n, Vect m a, Vect k a))
partitionLen p [] = ((0, 0) ** (Refl, [], []))
partitionLen p (x :: xs) = case partitionLen p xs of
((m, k) ** (prf, lefts, rights)) =>
if p x then
((S m, k) ** (cong prf, x::lefts, rights))
else
((m, S k) ** (trans (plusS m k) (cong prf), lefts, x::rights))
where
plusS : (x : Nat) -> (y : Nat) -> x + (S y) = S (x + y)
plusS Z y = Refl
plusS (S x) y = cong (plusS x y)
generate : (Monad m) => (rnd : Var) -> ST m (Vect 25 (Fin 25)) [rnd ::: Random]
generate rnd = do
(shared, nonshared) <- choose rnd 4 indices
(agents1, nonagents1) <- choose rnd 6 nonshared
(agents2, nonagents2) <- choose rnd 6 nonagents1
(assassins1, others1) <- choose rnd 2 nonagents1
case partitionLen (`elem` assassins1) nonagents2 of
((n, k) ** (prf, xs, ys)) => do
(assassins2, others2') <- choose rnd 2 (agents1 ++ xs)
let prf' = trans (sym $ plusAssociative 4 n k) $ cong {f = (+) 4} prf
let others2 = the (Vect 13 (Fin 25)) $ replace {P = \n => Vect n (Fin 25)} prf' (others2' ++ ys)
pure $ shared ++ agents2 ++ assassins2 ++ others2
where
indices : Vect 25 (Fin 25)
indices = fromList [0..24]
I'd like to refactor generate
so that instead of the whole tail of the computation is under a case
, I would compute (assassins2, others2)
in a sub-computation, i.e. I would like to rewrite it as such:
(assassins2, others2) <- case partitionLen (`elem` assassins1) nonagents2 of
((n, k) ** (prf, xs, ys)) => do
(assassins2, others2') <- choose rnd 2 (agents1 ++ xs)
let prf' = trans (sym $ plusAssociative 4 n k) $ cong {f = (+) 4} prf
let others2 = replace {P = \n => Vect n (Fin 25)} prf' (others2' ++ ys)
pure (assassins2, others2)
pure $ shared ++ agents2 ++ assassins2 ++ others2
I believe this should be an equivalent transformation. However, this second version fails type-checking with:
When checking right hand side of Main.case block in case block in case block in case block in case block in generate with expected type
STrans m
(Vect 25 (Fin 25))
(st2_fn (assassins2, others2))
(\result => [rnd ::: State Integer])
When checking argument ys to function Data.Vect.++:
Type mismatch between
B (Type of others2)
and
Vect n (Fin 25) (Expected type)