Simplifying Subformulas in Coq

480 views Asked by At

I'm trying to solve an equation of the form

A * B * C * D * E = F

where * is some complicated left-associative operation.

At the moment, everything is opaque (including * and A through F), and can be made transparent via autounfold with M_db.

The problem is that if I globally unfold the definition in the formula, simplification will take forever. Instead, I want to first unfold A * B, apply some tactics to reduce it to a normal form X, and then do the same with X * C and so forth.

Any idea how I could accomplish this? Here's my current approach, but neither `in A` nor `at A` works after autounfold. Also, it's not clear to me whether this is the right structure, or whether reduce_m ought to return something.

(* reduce_m M: intended to normalize the left-associated product M
   (built with [×]) one factor at a time, innermost product first.
   - For [A × B]: recursively reduce the left operand A (which may itself
     be a product), then B, then run [simpl] and [autorewrite with C_db].
   - Any other term is treated as a leaf and unfolded with the [M_db] hints.
   NOTE(review): [autounfold], [simpl] and [autorewrite] below have no
   [in] clause, so each one acts on the whole goal, not only on M.
   The commented-out [in A] is rejected because A is a term, not a
   hypothesis name.
   NOTE(review): with [match] (rather than [lazymatch]), a failure inside
   the first branch makes Ltac fall back to the [?A] branch, which would
   unfold the whole product at once. *)
Ltac reduce_m M :=
  match M with
  | ?A × ?B => reduce_m A;
              reduce_m B;
              simpl;
              autorewrite with C_db
  | ?A      => autounfold with M_db (* in A *);
              simpl; 
              autorewrite with C_db
  end.


(* simpl_m: apply reduce_m to the left-hand side of an equational goal.
   Fails with no matching clause if the goal is not an equation. *)
Ltac simpl_m :=
  match goal with
  |- ?lhs = _ => reduce_m lhs
  end.

A minimalish example:

Require Import Arith.

(* Pointwise sum of two nat-valued functions on nat; the [+] infix is then
   rebound to this operation. *)
Definition add_f (f g : nat -> nat) (x : nat) : nat :=
  f x + g x.

Infix "+" := add_f.

(* Three indicator-like step functions on nat used in the example. *)
Definition f (x : nat) : nat := if x =? 4 then 1 else 0.
Definition g (x : nat) : nat := if x <=? 4 then 3 else 0.
Definition h (x : nat) : nat := if x =? 2 then 2 else 0.

(* Target statement: the pointwise sum f + g + h equals the explicit table
   below (f x + g x + h x for x = 0..4, and 0 from 5 on). Both sides are
   functions, so proving it with [=] requires functional extensionality;
   no proof is given in the question. *)
Lemma ex : f + g + h = fun x => match x with
                             | 0 => 3
                             | 1 => 3
                             | 2 => 5
                             | 3 => 3
                             | 4 => 4
                             | _ => 0 
                             end.
1

There is 1 answer

8
Jason Gross — BEST ANSWER

You can put your term in a hypothesis (a local definition) and run autounfold in that hypothesis. That is, you can replace

autounfold with M_db (* in A *)

with

(* Bind A to a fresh local definition, unfold only inside that definition,
   then substitute the unfolded value back into the goal. *)
let Aterm := fresh in
set (Aterm := A);
autounfold with M_db in Aterm;
subst Aterm

If your A is very large, this will be slow, because set is somewhat complicated and performs some reduction. In that case, you can set up your goal so that you have:

HA     : A' = A
HB     : B' = B
HC     : C' = C
HD     : D' = D
HE     : E' = E
HAB    : AB = A' * B'
HABC   : ABC = AB * C'
HABCD  : ABCD = ABC * D'
HABCDE : ABCDE = ABCD * E'
------------------------
ABCDE = F

and then you can do something like

(* Normalize one hypothesis in place: unfold the M_db hints, simplify,
   then rewrite with the C_db lemmas. *)
Ltac reduce hyp :=
  autounfold with M_db in hyp;
  simpl in hyp;
  autorewrite with C_db in hyp.

(* First normalize each leaf equation HA .. HE on its own. *)
reduce HA; reduce HB; reduce HC; reduce HD; reduce HE;
(* Then fold left to right: substitute the already-reduced operands into
   the next product equation, and reduce that equation before moving on. *)
subst A' B'; reduce HAB;
subst AB C'; reduce HABC;
subst ABC D'; reduce HABCD;
subst ABCD E'; reduce HABCDE;
(* Finally replace ABCDE in the goal by its reduced value. *)
subst ABCDE.

Update to account for the example:

To do the reduction on your function, you do indeed need either function extensionality or a relation other than =. However, you don't need function extensionality for the modularization part:

Require Import Arith.

(* Pointwise sum of two nat-valued functions on nat; the [+] infix is then
   rebound to this operation. *)
Definition add_f (f g : nat -> nat) (x : nat) : nat :=
  f x + g x.

Infix "+" := add_f.

(* Three indicator-like step functions on nat used in the example. *)
Definition f (x : nat) : nat := if x =? 4 then 1 else 0.
Definition g (x : nat) : nat := if x <=? 4 then 3 else 0.
Definition h (x : nat) : nat := if x =? 2 then 2 else 0.

(* save t nm Heq: replace every occurrence of the term t, in the goal and in
   all hypotheses, by a new variable nm, and record the equation Heq : nm = t. *)
Ltac save t nm Heq :=
  remember t as nm eqn:Heq in *.

(* Function-level statement; only the modularization steps are shown, since
   finishing the simplification at this level needs functional
   extensionality. *)
Lemma ex : f + g + h = fun x => match x with
                                | 0 => 3
                                | 1 => 3
                                | 2 => 5
                                | 3 => 3
                                | 4 => 4
                                | _ => 0 
                                end.
Proof.
  (* Hide each function and each partial sum behind its own variable and
     equation (Hf, Hg, Hh, Hfg, Hfgh), so that the goal becomes
     fgh = the explicit table. *)
  save f f' Hf; save g g' Hg; save h h' Hh;
  save (f' + g') fg Hfg; save (fg + h') fgh Hfgh.
  (* Unfold f and g (now mentioned only in Hf and Hg), substitute them into
     Hfg, and unfold add_f there: only the f + g stage is exposed. *)
  cbv [f g] in *.
  subst f' g'.
  cbv [add_f] in Hfg.
  (* note: if you want to simplify [(if x =? 4 then 1 else 0) +
      (if x <=? 4 then 3 else 0)], then you need function
      extensionality.  However, you don't need it simply to
      modularize the simplification. *)

Alternatively, if you set up your goal a bit differently (stating the equation pointwise, for every x), you can avoid function extensionality:

Require Import Arith Coq.Classes.RelationClasses Coq.Setoids.Setoid Coq.Classes.Morphisms.

(* Pointwise sum of two nat-valued functions on nat; the [+] infix is then
   rebound to this operation. *)
Definition add_f (f g : nat -> nat) (x : nat) : nat :=
  f x + g x.

Infix "+" := add_f.

(* Three indicator-like step functions on nat used in the example. *)
Definition f (x : nat) : nat := if x =? 4 then 1 else 0.
Definition g (x : nat) : nat := if x <=? 4 then 3 else 0.
Definition h (x : nat) : nat := if x =? 2 then 2 else 0.

(* save t nm Heq: replace every occurrence of the term t, in the goal and in
   all hypotheses, by a new variable nm, and record the equation Heq : nm = t. *)
Ltac save t nm Heq :=
  remember t as nm eqn:Heq in *.
(* Dependent case analysis on a nat, packaged as a named constant so that
   tactics can match on [nat_case P o s x] instead of a raw [match].
   Returns [o] when [x] is 0 and [s k] when [x] is [S k]. *)
Definition nat_case (P : nat -> Type) (o : P 0) (s : forall n, P (S n))
  : forall x : nat, P x :=
  fun x =>
    match x as k return P k with
    | O => o
    | S k' => s k'
    end.
(* Adding two [nat_case]s over the same scrutinee is the [nat_case] whose
   branches are the sums of the corresponding branches. *)
Lemma nat_case_plus (a a' : nat) (b b' : nat -> nat) (x : nat)
  : (nat_case _ a b x + nat_case _ a' b' x)%nat = nat_case _ (a + a')%nat (fun x => b x + b' x)%nat x.
Proof.
  destruct x as [| n]; cbn; reflexivity.
Qed.

(* Adding a term [y] to a [nat_case] pushes [y] into both branches. *)
Lemma nat_case_plus_const (a : nat) (b : nat -> nat) (x : nat) (y : nat)
  : (nat_case _ a b x + y)%nat = nat_case _ (a + y)%nat (fun x => b x + y)%nat x.
Proof.
  destruct x as [| n]; cbn; reflexivity.
Qed.
(* Morphism (Proper) instances describing how [nat_case] respects equality
   of its arguments; [setoid_rewrite] uses instances like these to rewrite
   inside the arguments of [nat_case].

   Dependent version: equal zero-case values and pointwise-equal successor
   functions give pointwise-equal results. *)
Global Instance nat_case_Proper {P} : Proper (eq ==> forall_relation (fun _ => eq) ==> forall_relation (fun _ => eq)) (nat_case P).
Proof.
  unfold forall_relation; intros x x' ? f f' Hf [|a]; unfold nat_case; auto.
Qed.
(* Same property for a constant result type P, with the successor functions
   related by [pointwise_relation] instead of [forall_relation]. *)
Global Instance nat_case_Proper' {P} : Proper (eq ==> pointwise_relation _ eq ==> forall_relation (fun _ => eq)) (nat_case (fun _ => P)).
Proof.
  unfold forall_relation, pointwise_relation; intros x x' ? f f' Hf [|a]; unfold nat_case; auto.
Qed.
(* Zero-case value fixed; the implicit [x] here is that value, not the
   scrutinee. Pointwise-equal successor functions and equal scrutinees give
   equal results. *)
Global Instance nat_case_Proper'' {P} {x} : Proper (pointwise_relation _ eq ==> eq ==> eq) (nat_case (fun _ => P) x).
Proof.
  intros ??? a b ?; subst b; destruct a; simpl; auto.
Qed.
(* Same as the previous instance, stated with [forall_relation] for the
   successor functions. *)
Global Instance nat_case_Proper''' {P} {x} : Proper (forall_relation (fun _ => eq) ==> eq ==> eq) (nat_case (fun _ => P) x).
Proof.
  intros ??? a b ?; subst b; destruct a; simpl; auto.
Qed.
(* reduce: repeatedly rewrite the hypotheses into a normal form built from
   [nat_case] and nat addition, until no branch below applies:
   - the boolean tests [x =? 4], [x =? 2] and [x <=? 4] become explicit
     [match]es on x;
   - [match]es on a nat, including ones under a binder [fun v => ...], are
     folded into [nat_case];
   - a sum of two [nat_case]s, or of a [nat_case] and another term, is
     merged into one [nat_case] via [nat_case_plus] / [nat_case_plus_const]. *)
Ltac reduce :=
  (* Side-condition solver for [replace]: unfold [nat_case], case-split on
     every nat [match] scrutinee in the goal, then close by [reflexivity]. *)
  let solve_tac := unfold nat_case; repeat match goal with |- context[match ?x with O => _ | _ => _ end] => destruct x end; reflexivity in
  repeat match goal with
         (* Boolean equality test against 4 becomes a match on x. *)
         | [ H : context[if ?x =? 4 then ?a else ?b] |- _ ]
           => replace (if x =? 4 then a else b) with (match x with 4 => a | _ => b end) in H by solve_tac
         (* Boolean equality test against 2 becomes a match on x. *)
         | [ H : context[if ?x =? 2 then ?a else ?b] |- _ ]
           => replace (if x =? 2 then a else b) with (match x with 2 => a | _ => b end) in H by solve_tac
         (* x <=? 4 holds exactly for x = 0 .. 4. *)
         | [ H : context[if ?x <=? 4 then ?a else ?b] |- _ ]
           => replace (if x <=? 4 then a else b) with (match x with 0 | 1 | 2 | 3 | 4 => a | _ => b end) in H by solve_tac
         (* Fold a raw match on a nat into [nat_case]; [@?T] and [@?s] are
            second-order pattern variables for the return type and the
            successor branch. [change] applies because the two forms are
            convertible. *)
         | [ H : context G[match ?x as x' in nat return @?T x' with O => ?a | S n => @?s n end] |- _ ]
           => let G' := context G[@nat_case T a s x] in
              change G' in H
         (* Same, for a match whose scrutinee depends on a bound variable v;
            [cbv beta] then cleans up the resulting beta-redexes. *)
         | [ H : context G[fun v => match @?x v as x' in nat return @?T x' with O => ?a | S n => @?s n end] |- _ ]
           => let G' := context G[fun v => @nat_case T a s (x v)] in
              change G' in H; cbv beta in *
         (* Merge a sum of two nat_cases. [setoid_rewrite] can rewrite under
            binders; [progress] makes this branch fail when nothing was
            rewritten, so the outer [repeat] terminates. *)
         | [ H : context[(nat_case _ _ _ _ + nat_case _ _ _ _)%nat] |- _ ]
           => progress repeat setoid_rewrite nat_case_plus in H; simpl in H
         (* Push an added term into both branches of a nat_case. *)
         | [ H : context[(nat_case _ _ _ _ + _)%nat] |- _ ]
           => progress repeat setoid_rewrite nat_case_plus_const in H; simpl in H
         end.
(* Pointwise form of the goal: for every x, (f + g + h) x equals the table.
   Stating it per point avoids functional extensionality. *)
Lemma ex : forall x, (f + g + h) x = match x with
                                     | 0 => 3
                                     | 1 => 3
                                     | 2 => 5
                                     | 3 => 3
                                     | 4 => 4
                                     | _ => 0 
                                     end.
Proof.
  intro x; cbv [add_f].
  (* Name each leaf (f x, g x, h x) and each partial sum, so that every stage
     can be normalized separately in its own equation. *)
  save (f x) f' Hf; save (g x) g' Hg; save (h x) h' Hh; save (f' + g')%nat fg Hfg; save (fg + h')%nat fgh Hfgh.
  (* Stage 1: unfold f and g, substitute them into Hfg, normalize. *)
  cbv [f g] in *.
  subst f' g'; reduce.
  (* Stage 2: unfold h and normalize its equation. *)
  cbv [h] in *; reduce.
  (* Stage 3: substitute the reduced fg and h' into Hfgh, normalize. *)
  subst fg h'; reduce.
  (* Stage 4: put the reduced value into the goal, unfold nat_case, and
     close by reflexivity. *)
  subst fgh.
  unfold nat_case.
  reflexivity.
Qed.