diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 6d0cdb9fcb5b..12386875ca5d 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -164,7 +164,8 @@ theorem add_eq_or_of_and_eq_zero {w : Nat} (x y : BitVec w) (h : x &&& y = 0#w) : x + y = x ||| y := by rw [add_eq_adc, adc, iunfoldr_replace (fun _ => false) (x ||| y)] · rfl - · simp [adcb, atLeastTwo, h] + · simp only [adcb, atLeastTwo, Bool.and_false, Bool.or_false, bne_false, getLsb_or, + Prod.mk.injEq, and_eq_false_imp] intros i replace h : (x &&& y).getLsb i = (0#w).getLsb i := by rw [h] simp only [getLsb_and, getLsb_zero, and_eq_false_imp] at h @@ -251,6 +252,10 @@ theorem sle_eq_carry (x y : BitVec w) : /-! ### mul recurrence for bitblasting -/ +/-- +A recurrence that describes multiplication as repeated addition. +Is useful for bitblasting multiplication. +-/ def mulRec (l r : BitVec w) (s : Nat) : BitVec w := let cur := if r.getLsb s then (l <<< s) else 0 match s with @@ -264,6 +269,10 @@ theorem mulRec_zero_eq (l r : BitVec w) : theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) : mulRec l r (s + 1) = mulRec l r s + if r.getLsb (s + 1) then (l <<< (s + 1)) else 0 := rfl +/-- +When the `(i+1)`th bit of `x` is false, +keeping the lower `(i + 1)` bits of `x` equals keeping the lower `i` bits. +-/ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false {x : BitVec w} {i : Nat} (hx : x.getLsb i = false) : zeroExtend w (x.truncate (i + 1)) = @@ -275,6 +284,11 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false simp [hx] · by_cases hik' : k < i + 1 <;> simp [hik'] <;> omega +/-- +When the `(i+1)`th bit of `x` is true, +keeping the lower `(i + 1)` bits of `x` equalsk eeping the lower `i` bits +and then performing bitwise-or with `twoPow i = (1 << i)`, +-/ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true {x : BitVec w} {i : Nat} (hx : x.getLsb i = true) : zeroExtend w (x.truncate (i + 1)) = @@ -286,19 +300,20 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true simp [hx] · by_cases hik' : k < i + 1 <;> simp [hik, hik'] <;> omega -/-- Recurrence lemma: truncating to `i+1` bits and then zero extending to `w` -equals truncating upto `i` bits `[0..i-1]`, and then adding the `i`th bit of `x`. -/ +/-- +Recurrence lemma: truncating to `i+1` bits and then zero extending to `w` +equals truncating upto `i` bits `[0..i-1]`, and then adding the `i`th bit of `x`. +-/ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w) (i : Nat) : zeroExtend w (x.truncate (i + 1)) = zeroExtend w (x.truncate i) + (x &&& twoPow w i) := by rw [add_eq_or_of_and_eq_zero] · ext k simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, getLsb_or, getLsb_and] - by_cases hik:i = k + by_cases hik : i = k · subst hik simp - · simp [hik] - /- Really, 'omega' should be able to do this-/ + · simp only [getLsb_twoPow, hik, decide_False, Bool.and_false, Bool.or_false] by_cases hik' : k < (i + 1) · have hik'' : k < i := by omega simp [hik', hik''] @@ -308,16 +323,21 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow (x : BitVec w simp omega +/-- +Recurrence lemma: multiplying `l` with the first `s` bits of `r` is the +same as truncating `r` to `s` bits, then zero extending to the original length, +and performing the multplication. -/ theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : mulRec l r s = l * ((r.truncate (s + 1)).zeroExtend w) := by induction s case zero => - simp [mulRec_zero_eq] + simp only [mulRec_zero_eq, ofNat_eq_ofNat, Nat.reduceAdd] by_cases r.getLsb 0 case pos hr => simp only [hr, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero, hr, ofBool_true, ofNat_eq_ofNat] - rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)]; simp + rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)] + simp case neg hr => simp [hr, zeroExtend_one_eq_ofBool_getLsb_zero] case succ s' hs => @@ -330,11 +350,12 @@ theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : rw [heq, ← BitVec.mul_add, ← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_twoPow] /-- Zero extending by number of bits larger than the bitwidth has no effect. -/ -theorem zeroExtend_of_ge {x : BitVec w} {i j : Nat} (hi : i ≥ w) : +theorem zeroExtend_zeroExtend_of_ge {x : BitVec w} {i j : Nat} (hi : i ≥ w) : (x.zeroExtend i).zeroExtend j = x.zeroExtend j := by ext k - simp - intros hx; + simp only [getLsb_zeroExtend, Fin.is_lt, decide_True, Bool.true_and, Bool.and_iff_right_iff_imp, + decide_eq_true_eq] + intros hx have hi' : k < w := BitVec.lt_of_getLsb _ _ hx omega @@ -345,7 +366,7 @@ theorem zeroExtend_eq_self {x : BitVec w} : x.zeroExtend w = x := by theorem getLsb_mul (x y : BitVec w) (i : Nat) : (x * y).getLsb i = (mulRec x y w).getLsb i := by - simp [mulRec_eq_mul_signExtend_truncate] - rw [truncate, zeroExtend_of_ge (by omega), zeroExtend_eq_self] + simp only [mulRec_eq_mul_signExtend_truncate] + rw [truncate, zeroExtend_zeroExtend_of_ge (by omega), zeroExtend_eq_self] end BitVec diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 8c1fd59ffd42..2e6bdfadb9d3 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -167,7 +167,7 @@ private theorem lt_two_pow_of_le {x m n : Nat} (lt : x < 2 ^ m) (le : m ≤ n) : theorem getLsb_ofBool (b : Bool) (i : Nat) : (BitVec.ofBool b).getLsb i = ((i = 0) && b) := by rcases b with rfl | rfl · simp [ofBool] - · simp [ofBool, getLsb_ofNat] + · simp only [ofBool, ofNat_eq_ofNat, cond_true, getLsb_ofNat, Bool.and_true] by_cases hi : (i = 0) · simp [hi] · simp [hi] @@ -423,8 +423,8 @@ theorem zeroExtend_one_eq_ofBool_getLsb_zero (x : BitVec w) : /-- Zero extending `1#v` to `1#w` equals `1#w` when `v > 0`. -/ theorem zeroExtend_ofNat_one_eq_ofNat_one_of_lt {v w : Nat} (hv : 0 < v) : (BitVec.ofNat v 1).zeroExtend w = BitVec.ofNat w 1 := by - ext i - obtain ⟨i, hilt⟩ := i + ext ⟨i, hilt⟩ + obtain := i simp only [getLsb_zeroExtend, hilt, decide_True, getLsb_ofNat, Bool.true_and, Bool.and_iff_right_iff_imp, decide_eq_true_eq] intros hi1 @@ -1176,7 +1176,7 @@ theorem BitVec.mul_zero {x : BitVec w} : x * 0#w = 0#w := by theorem BitVec.mul_add {x y z : BitVec w} : x * (y + z) = x * y + x * z := by apply eq_of_toNat_eq - simp + simp only [toNat_mul, toNat_add, Nat.add_mod_mod, Nat.mod_add_mod] rw [Nat.mul_mod, Nat.mod_mod (y.toNat + z.toNat), ← Nat.mul_mod, Nat.mul_add]