Skip to content

Commit

Permalink
chore: Fin.ofNat' uses NeZero (#5356)
Browse files Browse the repository at this point in the history
  • Loading branch information
kim-em authored Sep 16, 2024
1 parent 078e9b6 commit c25d206
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/Init/Data/BitVec/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ protected def ofNatLt {n : Nat} (i : Nat) (p : i < 2^n) : BitVec n where
/-- The `BitVec` with value `i mod 2^n`. -/
@[match_pattern]
protected def ofNat (n : Nat) (i : Nat) : BitVec n where
toFin := Fin.ofNat' i (Nat.two_pow_pos n)
toFin := Fin.ofNat' (2^n) i

instance instOfNat : OfNat (BitVec n) i where ofNat := .ofNat n i
instance natCastInst : NatCast (BitVec w) := ⟨BitVec.ofNat w⟩
Expand Down
6 changes: 3 additions & 3 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ theorem ofBool_eq_iff_eq : ∀ {b b' : Bool}, BitVec.ofBool b = BitVec.ofBool b'
@[simp, bv_toNat] theorem toNat_ofNat (x w : Nat) : (BitVec.ofNat w x).toNat = x % 2^w := by
simp [BitVec.toNat, BitVec.ofNat, Fin.ofNat']

@[simp] theorem toFin_ofNat (x : Nat) : toFin (BitVec.ofNat w x) = Fin.ofNat' x (Nat.two_pow_pos w) := rfl
@[simp] theorem toFin_ofNat (x : Nat) : toFin (BitVec.ofNat w x) = Fin.ofNat' (2^w) x := rfl

-- Remark: we don't use `[simp]` here because simproc` subsumes it for literals.
-- If `x` and `n` are not literals, applying this theorem eagerly may not be a good idea.
Expand Down Expand Up @@ -832,7 +832,7 @@ theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl
BitVec.toNat_ofNat _ _

@[simp] theorem toFin_shiftLeft {n : Nat} (x : BitVec w) :
BitVec.toFin (x <<< n) = Fin.ofNat' (x.toNat <<< n) (Nat.two_pow_pos w) := rfl
BitVec.toFin (x <<< n) = Fin.ofNat' (2^w) (x.toNat <<< n) := rfl

@[simp]
theorem shiftLeft_zero_eq (x : BitVec w) : x <<< 0 = x := by
Expand Down Expand Up @@ -1526,7 +1526,7 @@ theorem ofNat_sub_ofNat {n} (x y : Nat) : BitVec.ofNat n x - BitVec.ofNat n y =
simp [Neg.neg, BitVec.neg]

@[simp] theorem toFin_neg (x : BitVec n) :
(-x).toFin = Fin.ofNat' (2^n - x.toNat) (Nat.two_pow_pos _) :=
(-x).toFin = Fin.ofNat' (2^n) (2^n - x.toNat) :=
rfl

theorem sub_toAdd {n} (x y : BitVec n) : x - y = x + - y := by
Expand Down
20 changes: 12 additions & 8 deletions src/Init/Data/Fin/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,28 @@ This differs from addition, which wraps around:
(2 : Fin 3) + 1 = (0 : Fin 3)
```
-/
def succ : Fin n → Fin n.succ
def succ : Fin n → Fin (n + 1)
| ⟨i, h⟩ => ⟨i+1, Nat.succ_lt_succ h⟩

variable {n : Nat}

/--
Returns `a` modulo `n + 1` as a `Fin n.succ`.
-/
protected def ofNat {n : Nat} (a : Nat) : Fin n.succ :=
protected def ofNat {n : Nat} (a : Nat) : Fin (n + 1) :=
⟨a % (n+1), Nat.mod_lt _ (Nat.zero_lt_succ _)⟩

/--
Returns `a` modulo `n` as a `Fin n`.
The assumption `n > 0` ensures that `Fin n` is nonempty.
The assumption `NeZero n` ensures that `Fin n` is nonempty.
-/
protected def ofNat' {n : Nat} (a : Nat) (h : n > 0) : Fin n :=
⟨a % n, Nat.mod_lt _ h⟩
protected def ofNat' (n : Nat) [NeZero n] (a : Nat) : Fin n :=
⟨a % n, Nat.mod_lt _ (pos_of_neZero n)⟩

-- We intend to deprecate `Fin.ofNat` in favor of `Fin.ofNat'` (and later rename).
-- This is waiting on https://github.com/leanprover/lean4/pull/5323
-- attribute [deprecated Fin.ofNat' (since := "2024-09-16")] Fin.ofNat

private theorem mlt {b : Nat} : {a : Nat} → a < n → b % n < n
| 0, h => Nat.mod_lt _ h
Expand Down Expand Up @@ -141,10 +145,10 @@ instance : ShiftLeft (Fin n) where
instance : ShiftRight (Fin n) where
shiftRight := Fin.shiftRight

instance instOfNat {n : Nat} [NeZero n] {i : Nat} : OfNat (Fin (no_index n)) i where
ofNat := Fin.ofNat' i (pos_of_neZero _)
instance instOfNat {n : Nat} [NeZero n] {i : Nat} : OfNat (Fin n) i where
ofNat := Fin.ofNat' n i

instance : Inhabited (Fin (no_index (n+1))) where
instance instInhabited {n : Nat} [NeZero n] : Inhabited (Fin n) where
default := 0

@[simp] theorem zero_eta : (⟨0, Nat.zero_lt_succ _⟩ : Fin (n + 1)) = 0 := rfl
Expand Down
22 changes: 11 additions & 11 deletions src/Init/Data/Fin/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ theorem eq_mk_iff_val_eq {a : Fin n} {k : Nat} {hk : k < n} :

theorem mk_val (i : Fin n) : (⟨i, i.isLt⟩ : Fin n) = i := Fin.eta ..

@[simp] theorem val_ofNat' (a : Nat) (is_pos : n > 0) :
(Fin.ofNat' a is_pos).val = a % n := rfl
@[simp] theorem val_ofNat' (n : Nat) [NeZero n] (a : Nat) :
(Fin.ofNat' n a).val = a % n := rfl

@[simp] theorem ofNat'_val_eq_self (x : Fin n) (h) : (Fin.ofNat' x h) = x := by
@[simp] theorem ofNat'_val_eq_self [NeZero n](x : Fin n) : (Fin.ofNat' n x) = x := by
ext
rw [val_ofNat', Nat.mod_eq_of_lt]
exact x.2
Expand Down Expand Up @@ -750,13 +750,13 @@ theorem addCases_right {m n : Nat} {motive : Fin (m + n) → Sort _} {left right

/-! ### add -/

@[simp] theorem ofNat'_add (x : Nat) (lt : 0 < n) (y : Fin n) :
Fin.ofNat' x lt + y = Fin.ofNat' (x + y.val) lt := by
@[simp] theorem ofNat'_add [NeZero n] (x : Nat) (y : Fin n) :
Fin.ofNat' n x + y = Fin.ofNat' n (x + y.val) := by
apply Fin.eq_of_val_eq
simp [Fin.ofNat', Fin.add_def]

@[simp] theorem add_ofNat' (x : Fin n) (y : Nat) (lt : 0 < n) :
x + Fin.ofNat' y lt = Fin.ofNat' (x.val + y) lt := by
@[simp] theorem add_ofNat' [NeZero n] (x : Fin n) (y : Nat) :
x + Fin.ofNat' n y = Fin.ofNat' n (x.val + y) := by
apply Fin.eq_of_val_eq
simp [Fin.ofNat', Fin.add_def]

Expand All @@ -765,13 +765,13 @@ theorem addCases_right {m n : Nat} {motive : Fin (m + n) → Sort _} {left right
protected theorem coe_sub (a b : Fin n) : ((a - b : Fin n) : Nat) = ((n - b) + a) % n := by
cases a; cases b; rfl

@[simp] theorem ofNat'_sub (x : Nat) (lt : 0 < n) (y : Fin n) :
Fin.ofNat' x lt - y = Fin.ofNat' ((n - y.val) + x) lt := by
@[simp] theorem ofNat'_sub [NeZero n] (x : Nat) (y : Fin n) :
Fin.ofNat' n x - y = Fin.ofNat' n ((n - y.val) + x) := by
apply Fin.eq_of_val_eq
simp [Fin.ofNat', Fin.sub_def]

@[simp] theorem sub_ofNat' (x : Fin n) (y : Nat) (lt : 0 < n) :
x - Fin.ofNat' y lt = Fin.ofNat' ((n - y % n) + x.val) lt := by
@[simp] theorem sub_ofNat' [NeZero n] (x : Fin n) (y : Nat) :
x - Fin.ofNat' n y = Fin.ofNat' n ((n - y % n) + x.val) := by
apply Fin.eq_of_val_eq
simp [Fin.ofNat', Fin.sub_def]

Expand Down
2 changes: 1 addition & 1 deletion src/Init/Data/Nat/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ protected theorem zero_ne_one : 0 ≠ (1 : Nat) :=

theorem succ_ne_zero (n : Nat) : succ n ≠ 0 := by simp

instance {n : Nat} : NeZero (succ n) := ⟨succ_ne_zero n⟩
instance instNeZeroSucc {n : Nat} : NeZero (n + 1) := ⟨succ_ne_zero n⟩

/-! # mul + order -/

Expand Down
8 changes: 7 additions & 1 deletion src/Init/Data/UInt/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -290,11 +290,17 @@ instance (a b : UInt64) : Decidable (a ≤ b) := UInt64.decLe a b
instance : Max UInt64 := maxOfLe
instance : Min UInt64 := minOfLe

-- This instance would interfere with the global instance `NeZero (n + 1)`,
-- so we only enable it locally.
@[local instance]
private def instNeZeroUSizeSize : NeZero USize.size := ⟨add_one_ne_zero _⟩

@[deprecated (since := "2024-09-16")]
theorem usize_size_gt_zero : USize.size > 0 :=
Nat.zero_lt_succ ..

@[extern "lean_usize_of_nat"]
def USize.ofNat (n : @& Nat) : USize := ⟨Fin.ofNat' n usize_size_gt_zero
def USize.ofNat (n : @& Nat) : USize := ⟨Fin.ofNat' _ n
abbrev Nat.toUSize := USize.ofNat
@[extern "lean_usize_to_nat"]
def USize.toNat (n : USize) : Nat := n.val.val
Expand Down
5 changes: 3 additions & 2 deletions src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Fin.lean
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ builtin_dsimproc [simp, seval] reduceFinMk (Fin.mk _ _) := fun e => do
let_expr Fin.mk n v _ ← e | return .continue
let some n ← evalNat n |>.run | return .continue
let some v ← getNatValue? v | return .continue
if h : n > 0 then
return .done <| toExpr (Fin.ofNat' v h)
if h : n ≠ 0 then
have : NeZero n := ⟨h⟩
return .done <| toExpr (Fin.ofNat' n v)
else
return .continue

Expand Down
6 changes: 3 additions & 3 deletions tests/lean/run/generalizeMany.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ n : Nat
v : Fin n
n' : Nat
v' : Fin n'
h₁ : n.succ = n'
h₁ : n + 1 = n'
h₂ : HEq v.succ v'
⊢ p n' v'
-/
#guard_msgs in
example (p : (n : Nat) → Fin n → Prop)
(n : Nat)
(v : Fin n)
: p n.succ v.succ := by
generalize h₁ : n.succ = n', h₂ : v.succ = v'
: p (n + 1) v.succ := by
generalize h₁ : (n + 1) = n', h₂ : v.succ = v'
trace_state
admit

0 comments on commit c25d206

Please sign in to comment.