(**************************************************************************
* Proving functional programs through a deep embedding                    *
* Primitive functions of the embedded language                            *
***************************************************************************)

Set Implicit Arguments.
Require Export DeepTactics.

(** This file contains specifications for basic operations.
    It includes operations defined in the semantics of the
    embedded language (add, sub, mul, div, equ, leq), as
    well as common functions which are defined in this file
    (not, ||, &&, compare, neq, geq, lt, gt), and an abstract
    arithmetic shift right (asr). *)


(************************************************************)
(* ** Definition of types of comparable elements *)

(** A type is said to be [comparable] if the builtin 
    comparison [val_cmp] implements the logical equality
    on values of this type. This definition is used in
    the specification of the polymorphic equality.
    
    More precisely, [comparable _A] holds if for any two 
    values [V1] and [V2] of type [A], the application of 
    [val_cmp] to the encodings of these two values through 
    the encoder [_A] returns a boolean equal to the logical 
    comparison of [V1] with [V2]. *)

Definition comparable (A:obj) (_A:code A) :=
  forall V1 V2, (val_cmp (_A V1) (_A V2)) = (V1 == V2).

(* The type [A] is determined by the type of the encoder [_A],
   so it is made implicit: clients write [comparable _Int]
   without passing [A] explicitly. *)
Implicit Arguments comparable [A].

(** In the following section we show that the predicate
    [comparable] holds for basic types: int, unit, bool,
    option, list, and tuples of arity 2 to 4. The lemmas for
    the parametric types (option, list, tuples) assume that
    the encoders of their components are comparable. *)

Section Comparable.

(* Component types, their encoders, and the hypotheses that
   each encoder is comparable. When the section is closed,
   each lemma is abstracted over the variables and hypotheses
   it actually uses. *)
Variables (A1 A2 A3 A4 : obj).
Variables (_A1 : code A1) (_A2 : code A2)
          (_A3 : code A3) (_A4 : code A4).
Variables (H1 : comparable _A1) (H2 : comparable _A2)
          (H3 : comparable _A3) (H4 : comparable _A4).

(** Integers are comparable. *)
Lemma comparable_int : comparable _Int.
Proof. intros_all~. Qed.

(** Unit is comparable: any two unit values are equal
    (lemma [unit_unique]). *)
Lemma comparable_unit : comparable _Unit.
Proof. intros_all. rewrite~ (@unit_unique V1 V2). Qed.

(** Booleans are comparable, by case analysis on both values. *)
Lemma comparable_bool : comparable _Bool.
Proof. intros_all. destruct V1; destruct~ V2. Qed.

(** Options over a comparable type are comparable: case
    analysis on both options, then [H1] for the contents. *)
Lemma comparable_option : comparable (_Option _A1).
Proof.
  intros_all. destruct V1; destruct V2; simpls~.
  rewrite H1. calc_bool*.
Qed.

(** Lists over a comparable type are comparable: induction on
    the first list, generalized over the second, using [H1]
    on the heads and the induction hypothesis on the tails. *)
Lemma comparable_list : comparable (_List _A1).
Proof.
  intros_all. gen V2. induction V1; destruct V2; simpls~.
  rewrite H1. calc_bool. rewrite* IHV1.
Qed.

(** Pairs of comparable types are comparable; the boolean
    equality on pairs is unfolded with [beq_prod2]. *)
Lemma comparable_tup2 : comparable (#2 _A1 _A2).
Proof.
  intros_all. destruct V1; destruct V2. simpl.
  rewrite H1. rewrite H2. calc_bool; rewrite~ beq_prod2.
Qed.

(** Triples of comparable types are comparable; after unfolding
    with [beq_prod3], the conjunction is reassociated with
    [assoc_and]. *)
Lemma comparable_tup3 : comparable (#3 _A1 _A2 _A3).
Proof.
  intros_all. destruct V1 as [[X1 X2] X3].
  destruct V2 as [[Y1 Y2] Y3]. simpl.
  rewrite H1. rewrite H2. rewrite H3. 
  calc_bool. rewrite beq_prod3. rewrite* assoc_and.
Qed.

(** Quadruples of comparable types are comparable; same
    approach as for triples, reassociating twice. *)
Lemma comparable_tup4 : comparable (#4 _A1 _A2 _A3 _A4).
Proof.
  intros_all. destruct V1 as [[[X1 X2] X3] X4].
  destruct V2 as [[[Y1 Y2] Y3] Y4]. simpl.
  rewrite H1. rewrite H2. rewrite H3. rewrite H4. 
  calc_bool. rewrite beq_prod4. do 2 rewrite assoc_and. auto.
Qed.

End Comparable.


(************************************************************)
(* ** Notation for variables used in the definitions *)

(* ['x] and ['y] are short names for the terms [`1`0] and
   [`2`0], used as the two parameters of the curried functions
   below. They are only available where [builtin_vars_scope]
   is open. *)
Notation "'x" := (`1`0) (at level 0) : builtin_vars_scope.
Notation "'y" := (`2`0) (at level 0) : builtin_vars_scope.


(************************************************************)
(* ** Basic operations, part 1 *)

(** We define curried functions wrapping around primitive
    operations that take pairs of values as arguments. *)

Section DefBuiltin.
Open Scope builtin_vars_scope.

(* Each [prim_op] below has the same shape: a two-argument
   function that applies the builtin [builtin_op] to the
   tuple [#('x,'y)] built from its arguments. *)
Definition prim_add := ('fun 'x 'y '->
  (\ val_builtin builtin_add) ' #('x,'y))%val.
Definition prim_sub := ('fun 'x 'y '-> 
  (\ val_builtin builtin_sub) ' #('x,'y))%val. 
Definition prim_mul := ('fun 'x 'y '-> 
  (\ val_builtin builtin_mul) ' #('x,'y))%val. 
Definition prim_div := ('fun 'x 'y '-> 
  (\ val_builtin builtin_div) ' #('x,'y))%val. 
Definition prim_equ := ('fun 'x 'y '-> 
  (\ val_builtin builtin_equ) ' #('x,'y))%val. 
Definition prim_leq := ('fun 'x 'y '-> 
  (\ val_builtin builtin_leq) ' #('x,'y))%val. 


(** Specification of [prim_add]: the result is [n1 + n2]. *)
Lemma add_spec : spec prim_add [n1:_Int] [n2:_Int] is
  >> _Int st = n1 + n2.
Proof. xintros n1 n2. xreds. xreturns. Qed.
(** Specification of [prim_sub]: the result is [n1 - n2]. *)
Lemma sub_spec : spec prim_sub [n1:_Int] [n2:_Int] is
  >> _Int st = n1 - n2.
Proof. xintros n1 n2. xreds. xreturns. Qed.

(** Specification of [prim_mul]: the result is [n1 * n2]. *)
Lemma mul_spec : spec prim_mul [n1:_Int] [n2:_Int] is
  >> _Int st = n1 * n2.
Proof. xintros n1 n2. xreds. xreturns. Qed.

(** Full specification of [prim_div], covering both cases.
    When [n2] is zero the behavior is [>! 'Div_by_zero]
    (presumably: the [Div_by_zero] exception is raised);
    otherwise the result is [exact_div n1 n2]. *)
Lemma div_spec_general : spec prim_div [n1:_Int] [n2:_Int] is
  if n2 == 0 then >! 'Div_by_zero
            else >> _Int st = exact_div n1 n2.
Proof. xintros n1 n2. xreds. testsb: (n2==0); xreturns. Qed.
 
(** Weaker specification of [prim_div], derived from
    [div_spec_general]: it only describes the result under
    the hypothesis [n2 != 0]. This is the version registered
    as a hint below. *)
Lemma div_spec : spec prim_div [n1:_Int] [n2:_Int] = t is
  (n2 != 0 -> t >> _Int st = exact_div n1 n2).
Proof.   
  xweakens div_spec_general as n1 n2.
  (* Case split on [n2 == 0]; the first case conflicts with
     the hypothesis [Ne]. *)
  intros Ne. testsb: (n2 == 0).
    down Ne. auto. 
    auto.
Qed.

(** Specification of [prim_leq]: the result is the boolean
    [n1 <= n2]. *)
Lemma leq_spec : spec prim_leq [n1:_Int] [n2:_Int] is
  >> _Bool st = (n1 <= n2).
Proof. xintros n1 n2. xreds. xreturns. Qed.

(** Specification of [prim_equ] on raw values ([_Val]): the
    result is the builtin comparison [val_cmp x1 x2]. *)
Lemma equ_spec_base : 
  spec prim_equ [x1:_Val] [x2:_Val] is
    >> _Bool st = (val_cmp x1 x2).
Proof. xintros x1 x2. xreds. xreturns. Qed.

(** Polymorphic equality: for any comparable encoder [_A],
    [prim_equ] returns the boolean [x1 == x2]. It is derived
    from [equ_spec_base] instantiated at the encodings
    [_A x1] and [_A x2]. The hypothesis [comparable _A] is
    what relates [val_cmp] on encodings to [==]. *)
Lemma equ_spec : forall (A:obj) (_A:code A),
  comparable _A ->
  spec prim_equ [x1:_A] [x2:_A] is
    >> _Bool st = (x1 == x2).
Proof.  
  (* First argument: instantiate [equ_spec_base] at [_A x1]. *)
  intros. intros x1. lets_simpl K: (equ_spec_base (_A x1)).
  destruct (behave_returns_inv K) as [X [PX S2]].
  apply* behave_returns.
  (* Second argument: instantiate the remaining specification
     [S2] at [_A x2]. *)
  intros x2. lets_simpl K': (S2 (_A x2)). 
  destruct (behave_returns_inv K') as [Y [PY S3]].
  apply* behave_returns.
Qed.

(** Instance of [equ_spec] for integers. *)
Lemma equ_spec_int : spec prim_equ [n1:_Int] [n2:_Int] is
    >> _Bool st = (n1 == n2).
Proof. apply (equ_spec comparable_int). Qed.

(** Instance of [equ_spec] for booleans. *)
Lemma equ_spec_bool : spec prim_equ [b1:_Bool] [b2:_Bool] is
    >> _Bool st = (b1 == b2).
Proof. apply (equ_spec comparable_bool). Qed.

End DefBuiltin.

(* Infix notations in [trm_scope] applying the curried
   primitives to two terms. The multiplicative operators
   ['*] and ['/] are at level 40 and bind tighter than
   ['+], ['-], ['=] and ['<=], which are at level 50. *)
Notation "t1 ''+' t2" := (\prim_add ' t1 ' t2) 
  (at level 50) : trm_scope.
Notation "t1 ''-' t2" := (\prim_sub ' t1 ' t2) 
  (at level 50) : trm_scope.
Notation "t1 ''*' t2" := (\prim_mul ' t1 ' t2) 
  (at level 40) : trm_scope.
Notation "t1 ''/' t2" := (\prim_div ' t1 ' t2) 
  (at level 40) : trm_scope.
Notation "t1 ''=' t2" := (\prim_equ ' t1 ' t2) 
  (at level 50) : trm_scope. 
Notation "t1 ''<=' t2" := (\prim_leq ' t1 ' t2) 
  (at level 50) : trm_scope.

(* Register the specifications in the [specs] hint database.
   Only [div_spec] is registered; [div_spec_general] and
   [equ_spec_base] are not. *)
Hint Resolve add_spec sub_spec mul_spec div_spec leq_spec
  equ_spec equ_spec_int equ_spec_bool : specs.


(************************************************************)
(* ** Basic operations, part 2 *)

(** In this section, we define boolean operators using the
    if-then-else construct, which is actually syntactic
    sugar for pattern matching on boolean values. We also
    define [compare], which returns the integer 0 when its
    arguments are equal and 1 otherwise. *)

Section DefBooleans.
Open Scope builtin_vars_scope.

(* Negation: [false] if ['x] is true, [true] otherwise. *)
Definition prim_not := 
  ('fun 'x '-> 'if 'x 'then false 'else true)%val.
(* Conjunction: ['y] if ['x] is true, [false] otherwise. *)
Definition prim_and := 
  ('fun 'x 'y '-> 'if 'x 'then 'y 'else false)%val.
(* Disjunction: [true] if ['x] is true, ['y] otherwise. *)
Definition prim_or := 
  ('fun 'x 'y '-> 'if 'x 'then true 'else 'y)%val.
(* Integer-valued comparison: 0 when ['x '= 'y] holds, 1 otherwise. *)
Definition compare :=
  ('fun 'x 'y '-> 'if 'x '= 'y 'then _Int 0 'else _Int 1)%val.

(** Specification of [prim_not]: the result is [neg b]. *)
Lemma not_spec : spec prim_not [b:_Bool] is
  >> _Bool st = neg b.
Proof. xintros b. xred. xcase; #xreturns. Qed.

(** Specification of [prim_or]: the result is [b1 || b2]. *)
Lemma or_spec : spec prim_or [b1:_Bool] [b2:_Bool] is
  >> _Bool st = b1 || b2.
Proof. xintros b1 b2. xred. xcase; #xreturns. calc_bool*. Qed.

(** Specification of [prim_and]: the result is [b1 && b2]. *)
Lemma and_spec : spec prim_and [b1:_Bool] [b2:_Bool] is
  >> _Bool st = b1 && b2.
Proof. xintros b1 b2. xred. xcase; #xreturns. calc_bool*. Qed.

(** Specification of [compare] for a comparable encoder: the
    returned integer [n] is 0 exactly when [x1 == x2]. *)
Lemma compare_spec : forall (A:obj) (_A:code A),
  comparable _A ->
  spec compare [x1:_A] [x2:_A] is
    >> [n:_Int] st (n = 0 <-> x1 == x2).
Proof.  
  introv Cmp. xintros x1 x2. xred.
  xapplys. destruct (x1 == x2); xpat; xreturns;
  split; intros; try solve [ auto | false ].
Qed.

(** Testing [compare x1 x2 '= 0] returns the same boolean as
    testing [x1 == x2] directly. This is the lemma used by the
    [xcompare] tactic. *)
Lemma compare_zero : forall (A:obj) (_A:code A),
  comparable _A -> forall x1 x2,
  (compare ' _A x1 ' _A x2 '= 0) >> _Bool st = (x1 == x2).
Proof.
  intros. xapply compare_spec as r [E1 E2]. auto. 
  #xapplys. xreturns. testsb Eq: (r == 0); fold_bool.
    auto.
    up as EQ. substb x1. down Eq. rewrite~ E2.
Qed.

End DefBooleans.

(* Prefix and infix notations for the boolean operators in
   [trm_scope]. ['&&] (level 40) binds tighter than ['||]
   (level 50). *)
Notation "''not' t" := (\prim_not ' t)
  (at level 31, left associativity) : trm_scope.
Notation "t1 ''||' t2" := (\prim_or ' t1 ' t2)
  (at level 50, left associativity) : trm_scope.
Notation "t1 ''&&' t2" := (\prim_and ' t1 ' t2)
  (at level 40, left associativity) : trm_scope.

(* Register the boolean specifications in the [specs] hint database. *)
Hint Resolve not_spec or_spec and_spec : specs.
  
(** A tactic to reason on terms of the form [compare v1 v2 '= 0].
    It looks for the subterm
    [\compare ' \v1 ' \v2 '= \_Int (int_pos 0)] in the goal and
    wraps it in [ctx_marker], presumably so that [xin] focuses
    on it. [xin] yields two subgoals. The first is solved with
    [compare_zero] (which needs [eauto] to find the
    [comparable] hypothesis). The second, the rest of the goal,
    is handed to [xouts]. NOTE(review): the pattern matches the
    literal 0 in its elaborated form [int_pos 0]. *)

Ltac xcompare :=
  match goal with |- 
   context [ (\compare ' \?v1 ' \?v2 '= \_Int (int_pos 0)) ] =>
   change (\compare ' \v1 ' \v2 '= \_Int (int_pos 0))
     with (ctx_marker (\compare ' \v1 ' \v2 '= \_Int (int_pos 0)));
   xin; [ fapplys compare_zero; eauto | xouts ]
   end. 


(************************************************************)
(* ** Basic operations, part 3 *)

(** We then define the remaining comparison functions:
    [neq] as the negation of [equ], which works for any
    comparable type, and the integer comparisons [geq],
    [lt], [gt] in terms of [leq]. *)

Section DefMoreArith.
Open Scope builtin_vars_scope.

(* [x '<> y] is [not (x '= y)]. *)
Definition prim_neq := 
  ('fun 'x 'y '-> 'not ('x '= 'y))%val.
(* [x '>= y] is [y '<= x]. *)
Definition prim_geq :=
  ('fun 'x 'y '-> ('y '<= 'x))%val.
(* [x '> y] is [not (x '<= y)]. *)
Definition prim_gt := 
  ('fun 'x 'y '-> 'not ('x '<= 'y))%val.
(* [x '< y] is [not (y '<= x)]. *)
Definition prim_lt := 
  ('fun 'x 'y '-> 'not ('y '<= 'x))%val.

(** Disequality for a comparable encoder: the result is
    [x1 != x2]. *)
Lemma neq_spec : forall (A:obj) (_A:code A),
  comparable _A ->
  spec prim_neq [x1:_A] [x2:_A] is
    >> _Bool st = (x1 != x2).
Proof. intros. xintros x1 x2. xred. xapplys. xapplys. xreturns. Qed.

(** Specification of [prim_geq]: the result is [n1 >= n2]. *)
Lemma geq_spec : spec prim_geq [n1:_Int] [n2:_Int] is
  >> _Bool st = n1 >= n2.
Proof. xintros n1 n2. xred. xapplys. xreturns. Qed.

(** Specification of [prim_lt]: the result is [n1 < n2]. *)
Lemma lt_spec : spec prim_lt [n1:_Int] [n2:_Int] is
  >> _Bool st = n1 < n2.
Proof. xintros n1 n2. xred. xapplys. xapplys. xreturns. Qed.

(** Specification of [prim_gt]: the result is [n1 > n2]. *)
Lemma gt_spec : spec prim_gt [n1:_Int] [n2:_Int] is
  >> _Bool st = n1 > n2.
Proof. xintros n1 n2. xred. xapplys. xapplys. xreturns. Qed.


End DefMoreArith.

(* Infix notations for the derived comparisons in [trm_scope],
   all at level 50. *)
Notation "t1 ''<>' t2" := (\prim_neq ' t1 ' t2) 
  (at level 50) : trm_scope.
Notation "t1 ''>=' t2" := (\prim_geq ' t1 ' t2) 
  (at level 50) : trm_scope.
Notation "t1 ''>' t2" := (\prim_gt ' t1 ' t2) 
  (at level 50) : trm_scope.
Notation "t1 ''<' t2" := (\prim_lt ' t1 ' t2)
  (at level 50) : trm_scope.

(* Register the comparison specifications in the [specs] hint database. *)
Hint Resolve lt_spec gt_spec geq_spec neq_spec : specs.


(************************************************************)
(* ** Basic operations, part 4 *)

(** Arithmetic shift right ([asr]) is left abstract: no
    concrete definition is given for [prim_asr], and its
    behavior is only specified for shifts by one bit, where
    the result is [exact_div x 2].

    Both declarations live outside any section, so they are
    global assumptions of the development and are not checked.
    They are stated with [Parameter] and [Axiom] rather than
    [Variable]. [Variable] is meant for section-local
    hypotheses; at top level Coq treats it as an assumption
    anyway (with a warning in recent versions). Spelling the
    assumption out makes the trusted specification of
    [prim_asr] explicit. *)

Parameter prim_asr : val.
Axiom prim_asr_spec : 
  spec prim_asr [x:_Int] [y:_Int] = r is
    y = 1 -> r >> _Int st = (exact_div x 2).

(* Make the shift specification available to automation using
   the [specs] hint database. *)
Hint Resolve prim_asr_spec : specs.