Dependent pattern matching on two values with the same type

We will use the standard definition of the finite sets:

Inductive fin : nat -> Set :=
| F1 : forall {n : nat}, fin (S n)
| FS : forall {n : nat}, fin n -> fin (S n).
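A quick check of what this definition means: fin n has exactly n inhabitants. For fin 3 (an arbitrary choice) those are F1, FS F1 and FS (FS F1). One more FS no longer typechecks, because it would need an inhabitant of fin 0:

```coq
Inductive fin : nat -> Set :=
| F1 : forall {n : nat}, fin (S n)
| FS : forall {n : nat}, fin n -> fin (S n).

(* The three inhabitants of fin 3; the implicit n is inferred from the annotation. *)
Check (F1 : fin 3).
Check (FS F1 : fin 3).
Check (FS (FS F1) : fin 3).

(* A fourth FS would require F1 : fin 0, which does not exist. *)
Fail Check (FS (FS (FS F1)) : fin 3).
```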

Let us assume that we have some P : forall {m : nat}, fin m -> fin m -> Set (the important thing is that both arguments of P have the same type). For demonstration purposes, let P be just:

Definition P {m : nat} (x y : fin m) := {x = y} + {x <> y}.

Now, we would like to write a custom function that compares two numbers:

Fixpoint feq_dec {m : nat} (x y : fin m) : P x y.

The idea is straightforward: we match on x and y. For x = F1, y = F1 we trivially return equality; for x = FS x', y = FS y' we recursively call the procedure on x' and y'; in the remaining cases we can trivially return inequality.
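For comparison, here is a minimal sketch of the same idea with a non-dependent result type; fin_eqb is a hypothetical helper, not part of the question. It is deliberately heterogeneous in the two bounds, because after the match x' and y' live in fin types whose bounds Coq does not know to be equal. With a plain bool result that costs nothing, which suggests the real obstacle is the dependent type P x y rather than the recursion itself:

```coq
Inductive fin : nat -> Set :=
| F1 : forall {n : nat}, fin (S n)
| FS : forall {n : nat}, fin n -> fin (S n).

(* Heterogeneous on purpose: x and y may have unrelated bounds m and n,
   so the recursive call on x' and y' is always well-typed. *)
Fixpoint fin_eqb {m n : nat} (x : fin m) (y : fin n) {struct x} : bool :=
  match x, y with
  | F1, F1 => true
  | FS x', FS y' => fin_eqb x' y'
  | _, _ => false
  end.
```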

A direct translation of this idea into Coq obviously fails:

refine (
  match x, y return P x y with
  | F1 _, F1 _ => _
  | FS _ x', F1 _ => _
  | F1 _, FS _ y' => _
  | FS _ x', FS _ y' => _
  end
).

(*
 * The term "y0" has type "fin n0" while it is expected to have type "fin n".
 *)

During the match on x and y we lose the information that both have the same type, so we cannot apply P to them. The standard trick of passing along a proof of the type equality does not help here:

refine (
  match x in fin mx, y in fin my return mx = my -> P x y with
  | F1 _, F1 _ => _
  | FS _ x', F1 _ => _
  | F1 _, FS _ y' => _
  | FS _ x', FS _ y' => _
  end eq_refl
).

(*
 * The term "y0" has type "fin my" while it is expected to have type "fin mx".
 *)

So, maybe we can use that equality proof to cast x to the same type as y?

Definition fcast {m1 m2 : nat} (Heq : m1 = m2) (x : fin m1) : fin m2.
Proof.
  rewrite <- Heq.
  apply x.
Defined.
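With the literal proof eq_refl, fcast reduces away by computation; with an abstract proof H : m = m it does not, since the transport is stuck on H. A small check, repeating the definitions so it runs on its own (the example names are made up):

```coq
Inductive fin : nat -> Set :=
| F1 : forall {n : nat}, fin (S n)
| FS : forall {n : nat}, fin n -> fin (S n).

Definition fcast {m1 m2 : nat} (Heq : m1 = m2) (x : fin m1) : fin m2.
Proof.
  rewrite <- Heq.
  apply x.
Defined.

(* Literal eq_refl: the transport reduces, so the equation holds by computation. *)
Example fcast_eq_refl {m : nat} (x : fin m) : fcast eq_refl x = x := eq_refl.

(* Abstract H : m = m: the transport is stuck on the variable H,
   so the same proof term is rejected. *)
Fail Example fcast_any {m : nat} (x : fin m) (H : m = m) : fcast H x = x := eq_refl.
```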

We also need to be able to get rid of the cast later on. However, I noticed that fcast eq_refl x = x is not sufficient, as we need it to work with an arbitrary equality proof. I have found something called UIP (uniqueness of identity proofs) that does exactly what I need.

Require Import Coq.Program.Program.

Lemma fcast_void {m : nat} : forall (x : fin m) (H : m = m),
  fcast H x = x.
Proof.
  intros.
  rewrite -> (UIP nat m m H eq_refl).
  trivial.
Defined.
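Since nat has decidable equality, this particular instance of UIP is actually provable without any axiom (Hedberg's theorem, available as UIP_dec in Coq.Logic.Eqdep_dec). Below is a sketch of the same lemma under that approach; fcast_void_dec is a made-up name. It is closed with Qed, so on its own it is not claimed to make feq_dec reduce to a value, only to remove the axiom:

```coq
Require Import Coq.Logic.Eqdep_dec Coq.Arith.PeanoNat.

Inductive fin : nat -> Set :=
| F1 : forall {n : nat}, fin (S n)
| FS : forall {n : nat}, fin n -> fin (S n).

Definition fcast {m1 m2 : nat} (Heq : m1 = m2) (x : fin m1) : fin m2.
Proof.
  rewrite <- Heq.
  apply x.
Defined.

(* UIP_dec: decidable equality on a type implies that any two proofs of the
   same equation on it are equal. Nat.eq_dec supplies the decision procedure. *)
Lemma fcast_void_dec {m : nat} : forall (x : fin m) (H : m = m), fcast H x = x.
Proof.
  intros x H.
  rewrite (@UIP_dec nat Nat.eq_dec m m H eq_refl).
  (* The goal is now fcast eq_refl x = x, which holds by computation. *)
  reflexivity.
Qed.
```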

Now we are ready to finish the whole definition:

refine (
  match x in fin mx, y in fin my
  return forall (Hmx : m = mx) (Hmy : mx = my), P (fcast Hmy x) y with
  | F1 _, F1 _ => fun Hmx Hmy => _
  | FS _ x', F1 _ => fun Hmx Hmy => _
  | F1 _, FS _ y' => fun Hmx Hmy => _
  | FS _ x', FS _ y' => fun Hmx Hmy => _
  end eq_refl eq_refl
); inversion Hmy; subst; rewrite fcast_void.
- left. reflexivity.
- right. intro Contra. inversion Contra.
- right. intro Contra. inversion Contra.
- destruct (feq_dec _ x' y') as [Heq | Hneq].
  + left. apply f_equal. apply Heq.
  + right. intro Contra. dependent destruction Contra. apply Hneq. reflexivity.
Defined.

It goes through! However, it doesn't evaluate to any useful value. For example, the following yields a term with five nested matches instead of a simple value (in_right or in_left). I suspect the problem is the UIP axiom that I used.

Compute (@feq_dec 5 (FS F1) (FS F1)).

So in the end, the definition that I came up with is pretty much useless. I have also tried nested matches using the convoy pattern instead of matching on both values at the same time, but I hit the same obstacles: as soon as I match on the second value, P stops being applicable to it. Can I do it some other way?

Answer by gallais (best answer)

You can write the terms by hand, but it's a nightmare. Here I describe the computational part and use tactics to deal with the proofs:

Fixpoint feq_dec {m : nat} (x y : fin m) : P x y.
refine (
match m return forall (x y : fin m), P x y with
  | O    => _
  | S m' => fun x y =>
  match (case x, case y) with
    | (inright eqx            , inright eqy)             => left _
    | (inleft (exist _ x' eqx), inright eqy)             => right _
    | (inright eqx            , inleft (exist _ y' eqy)) => right _
    | (inleft (exist _ x' eqx), inleft (exist _ y' eqy)) =>
    match feq_dec _ x' y' with
      | left eqx'y'   => left _
      | right neqx'y' => right _
    end
  end
end x y); simpl in *; subst.
- inversion 0.
- reflexivity.
- intro Heq; apply neqx'y'.
  assert (Heq' : Some x' = Some y') by exact (f_equal finpred Heq).
  inversion Heq'; reflexivity.
- inversion 1.
- inversion 1.
- reflexivity.
Defined.

The function defined this way works as expected:

Compute (@feq_dec 5 (FS F1) (FS F1)).
(* 
 = left eq_refl
 : P (FS F1) (FS F1)
*)
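Starting with the match on m also makes the O branch trivial, because fin 0 has no inhabitants. A minimal sketch of why, via a hypothetical helper fin_0_elim (not part of the answer) that uses small inversion, i.e. computing the return type from the index:

```coq
Inductive fin : nat -> Set :=
| F1 : forall {n : nat}, fin (S n)
| FS : forall {n : nat}, fin n -> fin (S n).

(* Both constructors land in fin (S _), where the computed return type is
   True; at index 0 the same return type computes to False. *)
Definition fin_0_elim (x : fin 0) : False :=
  match x in fin n return match n with O => False | S _ => True end with
  | F1 => I
  | FS _ => I
  end.
```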

This code relies on 3 tricks:

1. Start by inspecting the bound m.

Indeed, if you don't know anything about the bound m, you'll learn two different facts from the matches on x and y respectively, and you'll need to reconcile them (i.e. show that the two predecessors of m you're given are in fact equal). If, on the other hand, you know that m has the shape S m', then you can...

2. Use a case function inverting the term based on the bound's shape

If you know that the bound has a shape S m' then you know for each one of your fins that you are in one of two cases: either the fin is F1 or it is FS x' for some x'. case makes this formal:

Definition C {m : nat} (x : fin (S m)) :=
  { x' | x = FS x' } + { x = F1 }.

Definition case {m : nat} (x : fin (S m)) : C x :=
match x in fin (S n) return { x' | x = FS x' } + { x = F1 } with
  | F1    => inright eq_refl
  | FS x' => inleft (exist _ x' eq_refl)
end.
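On closed values, case computes to the expected branch. The helper case_pred below is hypothetical (it is not part of the answer); it simply returns the predecessor that case exposes. The definitions are repeated so the snippet is self-contained:

```coq
Inductive fin : nat -> Set :=
| F1 : forall {n : nat}, fin (S n)
| FS : forall {n : nat}, fin n -> fin (S n).

Definition C {m : nat} (x : fin (S m)) :=
  { x' | x = FS x' } + { x = F1 }.

Definition case {m : nat} (x : fin (S m)) : C x :=
match x in fin (S n) return { x' | x = FS x' } + { x = F1 } with
  | F1    => inright eq_refl
  | FS x' => inleft (exist _ x' eq_refl)
end.

(* Hypothetical helper: the predecessor exposed by case, if there is one. *)
Definition case_pred {m : nat} (x : fin (S m)) : option (fin m) :=
  match case x with
  | inleft (exist _ x' _) => Some x'
  | inright _             => None
  end.

(* Holds by computation alone: case reduces on closed values. *)
Example case_pred_FS : case_pred (FS (FS F1) : fin 3) = Some (FS F1) := eq_refl.
```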

Coq will be smart enough to detect that the values we are returning from case are direct subterms of the arguments it takes. So performing recursive calls when both x and y have the shape FS _ won't be a problem!

3. Use congruence with an ad-hoc function to peel off constructors

In the branch where we have performed a recursive call but got a negative answer in return, we need to prove FS x' <> FS y' knowing that x' <> y', which means turning Heq : FS x' = FS y' into x' = y'.

Because FS has a complicated return type, simply performing inversion on Heq won't yield a usable result (we get an equality between dependent pairs of a nat p and a fin p). This is where finpred comes into play: it's a total function which, when faced with FS _, simply peels off the FS constructor.

Definition finpred {m : nat} (x : fin m) : option (fin (pred m)) :=
match x with
  | F1    => None
  | FS x' => Some x'
end.

Combined with f_equal and Heq we get a proof that Some x' = Some y' on which we can use inversion and get the equality we wanted.
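Packaged as a standalone lemma, this trick amounts to injectivity of FS. A minimal sketch: FS_inj is a hypothetical name (the answer inlines the argument in its proof script), and it uses injection rather than inversion to peel off Some:

```coq
Inductive fin : nat -> Set :=
| F1 : forall {n : nat}, fin (S n)
| FS : forall {n : nat}, fin n -> fin (S n).

Definition finpred {m : nat} (x : fin m) : option (fin (pred m)) :=
match x with
  | F1    => None
  | FS x' => Some x'
end.

Lemma FS_inj {m : nat} (x y : fin m) : FS x = FS y -> x = y.
Proof.
  intro Heq.
  (* finpred (FS x) reduces to Some x, and pred (S m) to m,
     so f_equal finpred Heq already has the type Some x = Some y. *)
  assert (Heq' : Some x = Some y) by exact (f_equal finpred Heq).
  (* Some is an ordinary, non-dependent constructor: injection applies. *)
  injection Heq' as Hxy.
  exact Hxy.
Qed.
```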

Edit: I've put all the code in a self-contained gist.

Answer by ejgallego

This is a known problem, and in most cases you are going to fare better by using the equality on the underlying nat and then taking advantage of the fact that the to_nat function is injective:

From mathcomp Require Import all_ssreflect.

Set Implicit Arguments.
Unset Strict Implicit.
Unset Printing Implicit Defensive.

Require Import PeanoNat Fin.

Fixpoint to_nat m (x : t m) :=
  match x with
  | F1 _   => 0
  | FS _ x => (to_nat x).+1
  end.

Lemma to_nat_inj m : injective (@to_nat m).
Proof.
elim: m / => /= [|m t iht y].
  exact: (caseS (fun n (y : t n.+1) => _)).
move: m y t iht.
by apply: (caseS (fun n (y : t n.+1) => _)) => //= n p t iht [] /iht ->.
Qed.

Lemma feq_dec {m : nat} (x y : t m) : {x = y} + {x <> y}.
Proof.
have [heq | heqN] := Nat.eq_dec (to_nat x) (to_nat y).
  by left; apply: to_nat_inj.
by right=> H; apply: heqN; rewrite H.
Qed.

But even so, things are still cumbersome to work with. You could try to use the 'I_n type included in ssreflect, which separates the computational value from the bound; a bit of searching on SO should give you enough pointers.

If you turn the Qeds into Defined, the above will compute in your case, and in general it should be enough to give you either left ? or right ?, allowing proofs that depend on it to continue.

However, it will need a fair amount of tweaking if you want it to reduce all the way to a normal form in the non-equal case (mainly because the O_S lemma is opaque, which also affects Nat.eq_dec).