Skip to content

Commit

Permalink
feat: sshiftRight bitblasting (#4889)
Browse files Browse the repository at this point in the history
We follow the same strategy as
#4872,
#4571, and implement bitblasting
theorems for `sshiftRight`.

---------

Co-authored-by: Tobias Grosser <[email protected]>
  • Loading branch information
bollu and tobiasgrosser authored Aug 7, 2024
1 parent 1efd665 commit e106be1
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 1 deletion.
9 changes: 9 additions & 0 deletions src/Init/Data/BitVec/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,15 @@ def sshiftRight (a : BitVec n) (s : Nat) : BitVec n := .ofInt n (a.toInt >>> s)
instance {n} : HShiftLeft (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x <<< y.toNat⟩
instance {n} : HShiftRight (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x >>> y.toNat⟩

/--
Arithmetic right shift for bit vectors. The high bits are filled with the
most-significant bit.
As a numeric operation, this is equivalent to `a.toInt >>> s.toNat`.
SMT-Lib name: `bvashr`.
-/
def sshiftRight' (a : BitVec n) (s : BitVec m) : BitVec n := a.sshiftRight s.toNat

/-- Auxiliary function for `rotateLeft`, which does not take into account the case where
the rotation amount is greater than the bitvector width. -/
def rotateLeftAux (x : BitVec w) (n : Nat) : BitVec w :=
Expand Down
61 changes: 61 additions & 0 deletions src/Init/Data/BitVec/Bitblast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,67 @@ theorem shiftLeft_eq_shiftLeftRec (x : BitVec w₁) (y : BitVec w₂) :
· simp [of_length_zero]
· simp [shiftLeftRec_eq]

/- ### Arithmetic shift right (sshiftRight) recurrence -/

/--
`sshiftRightRec x y n` shifts `x` arithmetically/signed to the right by the first `n` bits of `y`.
The theorem `sshiftRight_eq_sshiftRightRec` proves the equivalence of `(x.sshiftRight y)` and `sshiftRightRec`.
Together with equations `sshiftRightRec_zero`, `sshiftRightRec_succ`,
this allows us to unfold `sshiftRight` into a circuit for bitblasting.
-/
def sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ :=
let shiftAmt := (y &&& (twoPow w₂ n))
match n with
| 0 => x.sshiftRight' shiftAmt
| n + 1 => (sshiftRightRec x y n).sshiftRight' shiftAmt

@[simp]
theorem sshiftRightRec_zero_eq (x : BitVec w₁) (y : BitVec w₂) :
sshiftRightRec x y 0 = x.sshiftRight' (y &&& 1#w₂) := by
simp only [sshiftRightRec, twoPow_zero]

@[simp]
theorem sshiftRightRec_succ_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
sshiftRightRec x y (n + 1) = (sshiftRightRec x y n).sshiftRight' (y &&& twoPow w₂ (n + 1)) := by
simp [sshiftRightRec]

/--
If `y &&& z = 0`, `x.sshiftRight (y ||| z) = (x.sshiftRight y).sshiftRight z`.
This follows as `y &&& z = 0` implies `y ||| z = y + z`,
and thus `x.sshiftRight (y ||| z) = x.sshiftRight (y + z) = (x.sshiftRight y).sshiftRight z`.
-/
theorem sshiftRight'_or_of_and_eq_zero {x : BitVec w₁} {y z : BitVec w₂}
(h : y &&& z = 0#w₂) :
x.sshiftRight' (y ||| z) = (x.sshiftRight' y).sshiftRight' z := by
simp [sshiftRight', ← add_eq_or_of_and_eq_zero _ _ h,
toNat_add_of_and_eq_zero h, sshiftRight_add]

theorem sshiftRightRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
sshiftRightRec x y n = x.sshiftRight' ((y.truncate (n + 1)).zeroExtend w₂) := by
induction n generalizing x y
case zero =>
ext i
simp [twoPow_zero, Nat.reduceAdd, and_one_eq_zeroExtend_ofBool_getLsb, truncate_one]
case succ n ih =>
simp only [sshiftRightRec_succ_eq, and_twoPow, ih]
by_cases h : y.getLsb (n + 1)
· rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true h,
sshiftRight'_or_of_and_eq_zero (by simp), h]
simp
· rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1)
(by simp [h])]
simp [h]

/--
Show that `x.sshiftRight y` can be written in terms of `sshiftRightRec`.
This can be unfolded in terms of `sshiftRightRec_zero_eq`, `sshiftRightRec_succ_eq` for bitblasting.
-/
theorem sshiftRight_eq_sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) :
(x.sshiftRight' y).getLsb i = (sshiftRightRec x y (w₂ - 1)).getLsb i := by
rcases w₂ with rfl | w₂
· simp [of_length_zero]
· simp [sshiftRightRec_eq]

/- ### Logical shift right (ushiftRight) recurrence for bitblasting -/

/--
Expand Down
37 changes: 36 additions & 1 deletion src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ theorem sshiftRight_eq_of_msb_true {x : BitVec w} {s : Nat} (h : x.msb = true) :
· rw [Nat.shiftRight_eq_div_pow]
apply Nat.lt_of_le_of_lt (Nat.div_le_self _ _) (by omega)

theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
@[simp] theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
getLsb (x.sshiftRight s) i =
(!decide (w ≤ i) && if s + i < w then x.getLsb (s + i) else x.msb) := by
rcases hmsb : x.msb with rfl | rfl
Expand All @@ -807,6 +807,41 @@ theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
Nat.not_lt, decide_eq_true_eq]
omega

/-- The msb after arithmetic shifting right equals the original msb. -/
theorem sshiftRight_msb_eq_msb {n : Nat} {x : BitVec w} :
(x.sshiftRight n).msb = x.msb := by
rw [msb_eq_getLsb_last, getLsb_sshiftRight, msb_eq_getLsb_last]
by_cases hw₀ : w = 0
· simp [hw₀]
· simp only [show ¬(w ≤ w - 1) by omega, decide_False, Bool.not_false, Bool.true_and,
ite_eq_right_iff]
intros h
simp [show n = 0 by omega]

@[simp] theorem sshiftRight_zero {x : BitVec w} : x.sshiftRight 0 = x := by
ext i
simp

theorem sshiftRight_add {x : BitVec w} {m n : Nat} :
x.sshiftRight (m + n) = (x.sshiftRight m).sshiftRight n := by
ext i
simp only [getLsb_sshiftRight, Nat.add_assoc]
by_cases h₁ : w ≤ (i : Nat)
· simp [h₁]
· simp only [h₁, decide_False, Bool.not_false, Bool.true_and]
by_cases h₂ : n + ↑i < w
· simp [h₂]
· simp only [h₂, ↓reduceIte]
by_cases h₃ : m + (n + ↑i) < w
· simp [h₃]
omega
· simp [h₃, sshiftRight_msb_eq_msb]

/-! ### sshiftRight reductions from BitVec to Nat -/

@[simp]
theorem sshiftRight_eq' (x : BitVec w) : x.sshiftRight' y = x.sshiftRight y.toNat := rfl

/-! ### signExtend -/

/-- Equation theorem for `Int.sub` when both arguments are `Int.ofNat` -/
Expand Down

0 comments on commit e106be1

Please sign in to comment.