theory BitSet
  imports Main "../lib/Base_MEC"
begin

(*General lemmas*)
(* Two naturals with the same quotient and the same remainder w.r.t. n are equal. *)
lemma eq_nat_mod_div: "i mod n = x mod n \<Longrightarrow> i div n = x div n \<Longrightarrow> (x::nat) = i" 
  by (smt (verit, ccfv_SIG) divide_less_cancel mod_eq_self_iff_div_eq_0 of_nat_0_eq_iff of_nat_eq_iff real_of_nat_div_aux)

(* If m does not divide n, decrementing n decrements its remainder modulo m
   (the remainder is nonzero, so no wrap-around to m - 1 happens). *)
lemma ndvd_pred_lmod: "0 < (m::nat) \<Longrightarrow> \<not>m dvd n \<Longrightarrow> (n - 1) mod m = (n mod m) - 1"
  by (smt (verit) Suc_diff_1 Suc_to_right dvd_0_right mod_Suc mod_eq_0_iff_dvd not_gr_zero)

(* If m divides a positive n, then n - 1 has the maximal remainder m - 1. *)
lemma dvd_pred_lmod: "(n::nat) > 0 \<Longrightarrow> m > 0 \<Longrightarrow> m dvd n \<Longrightarrow> (n - 1) mod m = m - 1"
  by (metis Suc_diff_1 lessI less_nat_zero_code mod_Suc mod_eq_0_iff_dvd not_less_eq not_less_less_Suc_eq)

(* If m does not divide n, decrementing n leaves the quotient n div m unchanged:
   the decrement is absorbed by the nonzero remainder. *)
lemma ndvd_pred_ldiv: "0 < (m::nat) \<Longrightarrow> \<not>m dvd n \<Longrightarrow> (n - 1) div m = n div m" 
  proof -
    assume A1: "0 < m"
    and  A2: "\<not>m dvd n"
    (* n > 0 because every m divides 0; m > 1 because 1 divides every n. *)
    then have B1:"n > 0" and B2:"m > 1" 
      apply(fold mod_greater_zero_iff_not_dvd)
      by (metis gr0I mod_less less_one mod_by_1 nat_neq_iff)+
    (* Both sides equal n, written via the quotient/remainder decomposition. *)
    hence "m * ((n - 1) div m) + (n - 1) mod m + 1 = m * (n div m) + n mod m"
      by force
    moreover have "((n - 1) mod m) + 1 = n mod m"
      using ndvd_pred_lmod[OF _ A2] A2 B2
      apply(fold mod_greater_zero_iff_not_dvd)
      by auto
    (* Cancel the equal remainder terms from both sides. *)
    ultimately have "m * ((n - 1) div m) = m * (n div m)" 
      by linarith
    then show ?thesis 
      by fastforce
  qed

(* If m divides a positive n, decrementing n lowers the quotient by exactly one. *)
lemma dvd_pred_ldiv: 
  assumes MN0: "0 < (m::nat)"
  and NN0: "0 < n" 
  and NDM: "m dvd n" 
  shows "(n - 1) div m = n div m - 1"
  proof -
    (* Both sides equal n: the left because m dvd n, the right by decomposing n - 1. *)
    from MN0 NN0 NDM have "m * (n div m) = m * ((n - 1) div m) + (n - 1) mod m + 1"
      by simp
    (* The remainder of n - 1 is m - 1, by dvd_pred_lmod. *)
    also have "... = m * ((n - 1) div m) + (m - 1) + 1"
      using dvd_pred_lmod[OF NN0 MN0 NDM] 
      by argo
    also have "... = m * ((n - 1) div m) + m"
      using MN0
      by auto
    also have "... = m * ((n - 1) div m + 1)"
      by algebra
    (* Cancel the positive factor m on both sides. *)
    finally have "n div m = (n - 1) div m + 1"
      using MN0 mult_left_cancel by blast
    thus ?thesis by auto
  qed
    

(* Division by c distributes over a sum when the remainders do not carry,
   i.e. when a mod c + b mod c < c. *)
lemma distrib_div_add: "((a::nat) mod c + b mod c < c) \<Longrightarrow> a div c + b div c = (a + b) div c"
  proof -
    assume ASM: "a mod c + b mod c < c"
    (* Without a carry, the remainder of the sum is the sum of the remainders. *)
    hence DIS: "(a + b) mod c = a mod c + b mod c" 
      by (metis mod_add_eq nat_mod_eq)

    (* c = 0 would contradict the strict bound in ASM. *)
    have CN0:"c > 0" using ASM by linarith

    (* Decompose a + b in two ways and compare the quotient parts. *)
    have A1: "a + b = c * ((a + b) div c) + (a + b) mod c" by presburger
    have "a + b = c * (a div c) + c * (b div c) + a mod c + b mod c" by simp
    also have "... = c * ((a div c) + (b div c)) + a mod c + b mod c" by algebra
    finally have "c * ((a + b) div c) + (a + b) mod c = c * ((a div c) + (b div c)) + a mod c + b mod c" using A1 by argo
    thus ?thesis using DIS CN0 by force
  qed


section \<open>div_round_up\<close>

(* Ceiling division n / m, defined via a divisibility test:
   exact quotient if m divides n, otherwise the truncated quotient plus one. *)
definition div_round_up :: "nat \<Rightarrow> nat \<Rightarrow> nat" where 
  "div_round_up n m = (if m dvd n then n div m else n div m + 1)"

(* Ceiling division n / m in the classic (n + m - 1) div m formulation. *)
definition div_round_up2 :: "nat \<Rightarrow> nat \<Rightarrow> nat" where 
  "div_round_up2 n m = (n + m - 1) div m"

(* Ceiling division n / m as (n - 1) div m + 1 for n > 0, and 0 for n = 0.
   This is the variant registered as op_div_round_up and implemented in LLVM.
   Presumably it is preferred over div_round_up2 because it avoids the sum
   n + m, which could overflow a fixed-width word (TODO confirm). *)
definition div_up :: "nat \<Rightarrow> nat \<Rightarrow> nat" where 
  "div_up n m = (if n = 0 then 0 else ((n - 1) div m + 1))"

(* For a positive divisor, the divisibility-based and the (n + m - 1) div m
   formulations of ceiling division coincide. *)
lemma div_round_up_eq_div_round_up2: "m > 0 \<Longrightarrow> div_round_up n m = div_round_up2 n m"
  proof -
    assume MN0: "m > 0"
    (* Case m dvd n: adding m - 1 does not reach the next multiple of m. *)
    {
      fix n m::nat
      assume MN0: "m > 0"
      assume DVD: "m dvd n"
      hence "n > 0 \<Longrightarrow> ((n + m) - 1) div m = (n + m) div m - 1" 
        using dvd_pred_ldiv[OF _ _ DVD[folded dvd_add_triv_right_iff[of m n]]] MN0
        by blast
      also have "... = n div m + m div m - 1"
        by (simp add: DVD)
      finally have "n > 0 \<Longrightarrow> n div m = ((n + m) - 1) div m" 
        using MN0
        by force
      hence "n div m = (n + m - 1) div m" 
        apply(cases "n = 0") 
        apply presburger
        apply auto
        done
    } note AUX1 = this

    (* Case not (m dvd n): adding m - 1 crosses exactly one multiple of m. *)
    {      
      fix n m::nat
      assume MN0: "m > 0"
      assume NDVD: "\<not>m dvd n"
      have "(n + m - 1) div m = n div m + 1" 
        using ndvd_pred_ldiv[OF MN0 NDVD[folded dvd_add_triv_right_iff[of m n]]]
        by (simp add: MN0)
    } note AUX2 = this

    from AUX1[OF MN0, of n] AUX2[OF MN0, of n] show ?thesis
      by(auto simp: div_round_up_def div_round_up2_def)
  qed


(* For a positive divisor, (n + m - 1) div m agrees with div_up. *)
lemma div_round_up2_eq_div_up: "m > 0 \<Longrightarrow> div_round_up2 n m = div_up n m"
  proof(cases "n = 0")
    case True
    moreover assume "m > 0"
    ultimately show ?thesis 
      unfolding div_round_up2_def div_up_def 
      by simp
  next
    case False
    assume A2: "m > 0"

    (* Write the "+ 1" as m div m and merge the two quotients with
       distrib_div_add; m mod m = 0 guarantees there is no carry. *)
    have "(n - 1) div m + 1 = (n - 1) div m + m div m"
      by (simp add: A2)
    also have "... = ((n - 1) + m) div m" 
      apply(rule distrib_div_add)
      by (auto simp: A2)
    also have "... = (n + m - 1) div m"
      using False by fastforce
    finally show ?thesis
      unfolding div_round_up2_def div_up_def
      by simp
  qed


(* For a positive divisor, the divisibility-based ceiling division agrees with
   div_up. Named so the equivalence can be reused by later proofs. *)
lemma div_round_up_eq_div_up: "m > 0 \<Longrightarrow> div_round_up n m = div_up n m"
  using div_round_up_eq_div_round_up2 div_round_up2_eq_div_up
  by simp
            
     
(* div_up rounds up: m * div_up n m covers n whenever m > 0.
   This bound is what makes bitset_empty large enough (bitset_empty_invar). *)
lemma round_up_ge: "m > 0 \<Longrightarrow> n \<le> m * div_up n m"
  unfolding div_up_def
  apply(simp split: if_split)
  apply(fold minus_mod_eq_mult_div[of "n - Suc 0" m])
  (* Split on divisibility to determine the remainder of n - 1. *)
  apply(cases "m dvd n")  
  using dvd_pred_lmod apply auto[]
  using ndvd_pred_lmod apply simp
  apply(fold mod_greater_zero_iff_not_dvd)
  apply simp 
  using pos_mod_bound[of m] 
  by (meson le_add_diff mod_le_divisor)



(* Declare div_up as the abstract operation div_round_up on naturals
   (giving op_div_round_up and mop_div_round_up), with precondition m > 0. *)
sepref_decl_op div_round_up: "div_up" :: "[\<lambda> (n, m). m > 0]\<^sub>f nat_rel \<times>\<^sub>r nat_rel \<rightarrow> nat_rel" .

(* Monadic, sepref-friendly implementation of div_up.
   ASSERT(m > 0) mirrors the precondition of mop_div_round_up; presumably sepref
   also needs it to discharge the side condition of the word-level division.
   The second ASSERT records that the result is at most n. Presumably this is
   so that the final "+ 1" cannot overflow the word holding n (TODO confirm). *)
definition "div_round_up_impl n m = 
    do{
      ASSERT(m > 0);
      if n = 0 then
        RETURN 0
      else
        do {
          ASSERT((n - 1) div m + 1 \<le> n);
          RETURN((n - 1) div m + 1)
        }
    }"



(* div_round_up_impl refines the monadic abstract operation mop_div_round_up
   (i.e. div_up guarded by m > 0). *)
lemma div_round_up2_refine: "(div_round_up_impl, mop_div_round_up) \<in> nat_rel \<rightarrow> nat_rel \<rightarrow> \<langle>nat_rel\<rangle>nres_rel"
  unfolding div_round_up_impl_def mop_div_round_up_def op_div_round_up_def div_up_def
  apply (clarsimp simp: )
  subgoal for n m
    apply(cases "0 < m")
    (* m > 0: presumably the remaining goals are the ASSERTed bound
       (n - 1) div m + 1 <= n, closed via div_le_dividend. *)
    subgoal
      apply(auto simp: refine_pw_simps pw_nres_rel_iff Let_def) 
      apply (metis Suc_le_mono Suc_pred div_le_dividend)
      by (metis Suc_le_mono Suc_pred div_le_dividend)
    (* m = 0: the abstract precondition fails, so refinement is trivial. *)
    subgoal
      by(auto simp: refine_pw_simps pw_nres_rel_iff Let_def)
    done
  done


  
(* LLVM implementation of div_round_up_impl. Both arguments and the result are
   naturals represented by snat_assn' words of the same length 'a; the numeric
   constants are annotated with that word length. *)
sepref_def div_round_up_ll is "uncurry div_round_up_impl" :: "(snat_assn' TYPE('a::len2))\<^sup>k *\<^sub>a (snat_assn' TYPE('a))\<^sup>k \<rightarrow>\<^sub>a (snat_assn' TYPE('a))"
  unfolding div_round_up_impl_def
  apply (annot_snat_const "TYPE('a)")
  apply sepref
  done

(* Compose the LLVM rule with div_round_up2_refine and register the result as
   the implementation of the declared operation div_round_up. The (ismop) flag
   presumably marks that the composed rule is stated for the mop variant. *)
sepref_decl_impl(ismop) div_round_up_ll.refine[FCOMP div_round_up2_refine] .


section \<open>Bitwise Operations Sepref\<close>


(* Test bit i of the word w: mask w with 1 << i and check that the result is nonzero. *)
definition get_bit_impl :: "size_t \<Rightarrow> nat \<Rightarrow> bool nres" where "get_bit_impl w i = 
  do { 
    let v = 1 << i; 
    let v = w AND v; 
    RETURN (v \<noteq> 0) 
  }"

(* get_bit_impl is correct w.r.t. the HOL bit test w !! i. *)
lemma get_bit_impl_refine: "get_bit_impl w i \<le> RETURN (w !! i)"
  unfolding get_bit_impl_def
  apply refine_vcg 
  using and_eq_0_is_nth by auto


(* LLVM implementation of get_bit_impl: the word is kept (id_assn, read-only),
   the index is a size_t, and the result is a 1-bit boolean. The precondition
   i < len_size_T keeps the shift amount within the word length. *)
sepref_def get_bit_ll is "uncurry get_bit_impl" :: "[\<lambda> (_, i). i < len_size_T]\<^sub>a id_assn\<^sup>k *\<^sub>a size_assn\<^sup>k \<rightarrow> bool1_assn"
  unfolding get_bit_impl_def 
  apply sepref
  done


(* Relational form of get_bit_impl_refine, suitable for FCOMP with get_bit_ll. *)
lemma get_bit_impl_rel_refine: "(uncurry get_bit_impl, uncurry (RETURN oo (!!))) \<in> word_rel \<times>\<^sub>r nat_rel \<rightarrow> \<langle>bool_rel\<rangle>nres_rel"
  using get_bit_impl_refine
  apply (auto intro: nres_relI)
  done

(* Register get_bit_ll as the sepref implementation of the bit test (!!). *)
lemmas [sepref_fr_rules] = get_bit_ll.refine[FCOMP get_bit_impl_rel_refine]


(* Set (b = True) or clear (b = False) bit i of the word w:
   setting ORs in the mask 1 << i, clearing ANDs with the mask's complement. *)
definition set_bit_impl :: "size_t \<Rightarrow> nat \<Rightarrow> bool \<Rightarrow> size_t nres" where "set_bit_impl w i b = 
  do{ 
    if b then do{ 
      let v = 1 << i; 
      let v = w OR v; 
      RETURN v} 
    else do {
      let v = 1 << i; 
      let v = NOT v; 
      let v = w AND v; 
      RETURN v
    }
  }"

(* set_bit_impl is correct w.r.t. the HOL-level set_bit on words. *)
lemma set_bit_impl_refine: "set_bit_impl w i b \<le> RETURN (set_bit w i b)"
  unfolding set_bit_impl_def
  apply(refine_vcg)
  by (auto simp add: set_bit_eq set_bit_eq_or unset_bit_eq_and_not)


(* LLVM implementation of set_bit_impl. The word argument is consumed
   (id_assn^d) and the updated word is returned; i < len_size_T keeps the
   shift amount within the word length. *)
sepref_def set_bit_ll is "uncurry2 set_bit_impl" :: "[\<lambda> ((_, i), _). i < len_size_T]\<^sub>a id_assn\<^sup>d *\<^sub>a size_assn\<^sup>k *\<^sub>a bool1_assn\<^sup>k \<rightarrow> id_assn"
  unfolding set_bit_impl_def 
  apply sepref
  done

(* Relational form of set_bit_impl_refine, suitable for FCOMP with set_bit_ll. *)
lemma set_bit_impl_rel_refine: "(uncurry2 set_bit_impl, uncurry2 (RETURN ooo (set_bit))) \<in> (word_rel \<times>\<^sub>r nat_rel) \<times>\<^sub>r bool_rel \<rightarrow> \<langle>word_rel\<rangle>nres_rel"
  using set_bit_impl_refine
  apply (auto intro: nres_relI)
  done

(* Register set_bit_ll as the sepref implementation of set_bit. *)
lemmas [sepref_fr_rules] = set_bit_ll.refine[FCOMP set_bit_impl_rel_refine]


section \<open>Bitset Operations\<close>
(* Bit test written directly with LLVM primitives: compute 1 << i, AND it
   with w, and test whether the result differs from 0. *)
definition "ll_get_bit w i = doM {
    v \<leftarrow> ll_shl 1 i;
    v' \<leftarrow> ll_and v w;
    (ll_icmp_ne v' 0)
  }"

(* Unfinished specification attempt for ll_shl (still contains sorry),
   kept commented out for reference. *)
(*
lemma "llvm_htriple
  (size_assn w wi \<and>* size_assn i ii \<and>* \<up>(i < len_size_T))
  (ll_shl wi ii)
  (\<lambda>ri. size_assn (w * 2 ^ i) ri)"
  unfolding ll_shl_def op_lift_arith2_def bitSHL'_def bitSHL_def
  apply(vcg_monadify)
  apply vcg'
  subgoal sorry
  apply vcg'
  apply auto
  sorry*)


(* Anonymous context for the bitset development, with the llvm_prim_setup
   interpretation in scope. It is closed by the "end" immediately before the
   theory's own "end". *)
context begin

interpretation llvm_prim_setup .


section \<open>Bitset\<close>



(* A bitset is a list of machine words. Element i is stored in word
   i div len_size_T at bit position i mod len_size_T. *)
type_synonym bitset = "size_t list"


(* Membership test: read bit (i mod len_size_T) of word (i div len_size_T).
   Only meaningful for i < len_size_T * length bi; there is no bounds check. *)
definition bitset_get :: "nat \<Rightarrow> bitset \<Rightarrow> bool" where "bitset_get i bi \<equiv> (bi ! (i div len_size_T) !! (i mod len_size_T))"

(* Set the bit of element i to b, updating only the word that contains it;
   the length of the list is unchanged. *)
definition bitset_set :: "nat \<Rightarrow> bitset \<Rightarrow> bool \<Rightarrow> bitset" where "bitset_set i bi b = (let j = i div len_size_T; k = i mod len_size_T in bi[j:=set_bit (bi ! j) k b])"

(* All-zero bitset with ceil(n / len_size_T) words, i.e. room for n elements. *)
definition bitset_empty :: "nat \<Rightarrow> bitset" where "bitset_empty n = replicate (op_div_round_up n len_size_T) 0"

(* Insert element i into the bitset bs by setting its bit. *)
definition bitset_insert :: "nat \<Rightarrow> bitset \<Rightarrow> bitset" where
  "bitset_insert i bs = bitset_set i bs True"

(* Delete element i from the bitset bs by clearing its bit. *)
definition bitset_delete :: "nat \<Rightarrow> bitset \<Rightarrow> bitset" where
  "bitset_delete i bs = bitset_set i bs False"

(* Abstraction function: the set of in-range positions whose bit is set. *)
definition bitset_\<alpha> :: "bitset \<Rightarrow> nat set" where "bitset_\<alpha> bs = Collect (\<lambda> i. bitset_get i bs \<and> i < len_size_T * length bs)"


(*TODO naming*)
(* The all-zero bitset represents the empty set. *)
lemma bitset_empty_\<alpha>[simp]: "bitset_\<alpha> (bitset_empty n) = {}"
  unfolding bitset_\<alpha>_def bitset_empty_def bitset_get_def
  apply auto
  done

(* For in-range positions, bitset_get is exactly membership in the abstraction. *)
lemma bitset_get_\<alpha>[simp]: "i < LENGTH(size_T) * length bs \<Longrightarrow> bitset_get i bs \<longleftrightarrow> i \<in> bitset_\<alpha> bs"
  unfolding bitset_\<alpha>_def 
  apply simp
  done

(* Inserting an in-range position adds it to the abstract set. *)
lemma bitset_insert_\<alpha>[simp]: "i < LENGTH(size_T) * length bs \<Longrightarrow> bitset_\<alpha> (bitset_insert i bs) = insert i (bitset_\<alpha> bs)"
  unfolding bitset_insert_def bitset_\<alpha>_def bitset_set_def bitset_get_def Let_def 
  apply (auto simp: nth_list_update' bit_set_bit_word_iff mult.commute td_gal_lt intro: eq_nat_mod_div)
  done
  
(* Deleting removes the element from the abstract set; no range premise is needed. *)
lemma bitset_delete_\<alpha>[simp]: "bitset_\<alpha> (bitset_delete i bs) = bitset_\<alpha> bs - {i}"
  unfolding bitset_delete_def bitset_\<alpha>_def bitset_set_def bitset_get_def Let_def 
  apply (auto simp: nth_list_update' bit_set_bit_word_iff mult.commute td_gal_lt intro: eq_nat_mod_div)
  done

(* Invariant: the bitset has at least n bit positions (word length times number of words). *)
definition bitset_invar :: "nat \<Rightarrow> bitset \<Rightarrow> bool" where "bitset_invar n bs = (n \<le> LENGTH(size_T) * length bs)"


(* The empty bitset built for capacity n has room for n elements (via round_up_ge). *)
lemma bitset_empty_invar[simp]: "bitset_invar n (bitset_empty n)"
  unfolding bitset_invar_def bitset_empty_def
  apply (auto simp: round_up_ge)
  done

(* bitset_set preserves the list length and hence the invariant. *)
lemma bitset_set_invar [simp]: "bitset_invar n bs \<Longrightarrow> bitset_invar n (bitset_set i bs b)"
  unfolding bitset_invar_def bitset_set_def Let_def 
  apply auto
  done

(* Insertion preserves the invariant (follows from bitset_set_invar). *)
lemma bitset_insert_invar [simp]: "bitset_invar n bs \<Longrightarrow> bitset_invar n (bitset_insert i bs)"
  unfolding bitset_insert_def 
  by simp

(* Deletion preserves the invariant (follows from bitset_set_invar). *)
lemma bitset_delete_invar [simp]: "bitset_invar n bs \<Longrightarrow> bitset_invar n (bitset_delete i bs)"
  unfolding bitset_delete_def 
  by simp 

(* Concrete bitset: a pointer to the array of size_t words. *)
type_synonym bitseti = "size_t ptr"

(* Relates a bitset bs satisfying bitset_invar n to the set bitset_\<alpha> bs. *)
definition "bitset_rel n = br bitset_\<alpha> (bitset_invar n)"
(* Heap assertion: a word list is stored as an array whose elements are the
   words themselves (id_assn). *)
definition 
  bitset_assn :: "size_t list \<Rightarrow> bitseti \<Rightarrow> llvm_amemory \<Rightarrow> bool" where
  "bitset_assn = IICF_Array.array_assn (id_assn::size_t \<Rightarrow> _)"

(* bitset_empty n implements the empty set under bitset_rel n. *)
lemma bitset_empty_refine: "(bitset_empty n, {}) \<in> bitset_rel n"
  unfolding bitset_rel_def 
  apply(auto simp: in_br_conv)
  done

(* Under the invariant, every index below n is a valid bit position. *)
lemma bitset_invar_i_bound: "i < n \<Longrightarrow> bitset_invar n bs \<Longrightarrow> i < len_size_T * length bs "
  unfolding bitset_invar_def
  by simp

(* bitset_get implements set membership for indices related by nbn_rel n
   (presumably the identity on naturals below n; TODO confirm). *)
lemma bitset_get_refine: "(bitset_get, op_set_member) \<in> (nbn_rel n) \<rightarrow> bitset_rel n \<rightarrow> bool_rel"
  apply (auto simp: bitset_rel_def in_br_conv bitset_invar_i_bound)
  done

(* bitset_insert implements set insertion for indices related by nbn_rel n. *)
lemma bitset_insert_refine: "(bitset_insert, op_set_insert) \<in> (nbn_rel n) \<rightarrow> bitset_rel n \<rightarrow> bitset_rel n"
  apply(auto simp: bitset_rel_def in_br_conv bitset_invar_i_bound) 
  done

(* bitset_delete implements set deletion. No index bound is needed, since
   bitset_delete_\<alpha> holds unconditionally. *)
lemma bitset_delete_refine: "(bitset_delete, op_set_delete) \<in> (nbn_rel n) \<rightarrow> bitset_rel n \<rightarrow> bitset_rel n"
  apply(auto simp: bitset_rel_def in_br_conv) 
  done

(* Implementation of bitset_empty for a fixed capacity n. Inside this context,
   n is refined by the concrete size_t value ni; this is imported as a sepref
   parameter (n_impl) and registered with sepref_register_adhoc. *)
context
  fixes n ni
  assumes n_impl[sepref_import_param]: "(ni, n) \<in> size_rel"
  notes [[sepref_register_adhoc n]]
begin


sepref_definition bitset_empty_ll is "uncurry0 (RETURN (bitset_empty n))" :: "unit_assn\<^sup>k \<rightarrow>\<^sub>a IICF_Array.array_assn id_assn"
  unfolding bitset_empty_def
  apply (fold array_replicate_init_def )
  apply (annot_snat_const "TYPE(size_T)")
  apply sepref
  done

(* The remaining bitset operations do not depend on n, so they are implemented
   outside this context. *)


end


(* LLVM implementation of bitset_get. The array is only read (bitset_assn^k),
   and the precondition keeps the index within the array's bit positions. *)
sepref_def bitset_get_ll is "uncurry (RETURN oo bitset_get)" :: "[\<lambda>(i,l). i<length l * len_size_T]\<^sub>a size_assn\<^sup>k *\<^sub>a bitset_assn\<^sup>k \<rightarrow> bool1_assn"
  unfolding bitset_get_def bitset_assn_def
  apply (annot_snat_const "TYPE(size_T)")
  apply sepref
  done


(* LLVM implementation of bitset_insert. The array is consumed (bitset_assn^d)
   and the updated array is returned; the index must be a valid bit position. *)
sepref_def bitset_insert_ll is "uncurry (RETURN oo bitset_insert)" :: "[\<lambda>(i,l). i<length l * len_size_T]\<^sub>a size_assn\<^sup>k *\<^sub>a bitset_assn\<^sup>d \<rightarrow> bitset_assn"
  unfolding bitset_insert_def bitset_set_def bitset_assn_def
  apply (annot_snat_const "TYPE(size_T)")
  apply sepref
  done

(* LLVM implementation of bitset_delete. The array is consumed (bitset_assn^d)
   and the updated array is returned; the index must be a valid bit position. *)
sepref_def bitset_delete_ll is "uncurry (RETURN oo bitset_delete)" :: "[\<lambda>(i,l). i<length l * len_size_T]\<^sub>a size_assn\<^sup>k *\<^sub>a bitset_assn\<^sup>d \<rightarrow> bitset_assn"
  unfolding bitset_delete_def bitset_set_def bitset_assn_def
  apply (annot_snat_const "TYPE(size_T)")
  apply sepref
  done


end
end