Skip to content

Commit

Permalink
feat: assume function application arguments occurring in local simp t…
Browse files Browse the repository at this point in the history
…heorems have been annotated with `no_index` (#3406)

closes #2670
  • Loading branch information
leodemoura committed Feb 19, 2024
1 parent ca94124 commit 90b5a00
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 25 deletions.
15 changes: 15 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ v4.7.0 (development in progress)
rw [h]
```

* When adding new local theorems to `simp`, the system assumes that the function application arguments
have been annotated with `no_index`. This modification, which addresses issue [#2670](https://github.com/leanprover/lean4/issues/2670),
restores the Lean 3 behavior that users expect. With this modification, the following examples are now operational:
```
example {α β : Type} {f : α × β → β → β} (h : ∀ p : α × β, f p p.2 = p.2)
(a : α) (b : β) : f (a, b) b = b := by
simp [h]
example {α β : Type} {f : α × β → β → β}
(a : α) (b : β) (h : f (a,b) (a,b).2 = (a,b).2) : f (a, b) b = b := by
simp [h]
```
In both cases, `h` is applicable because `simp` does not index f-arguments anymore when adding `h` to the `simp`-set.
It's important to note, however, that global theorems continue to be indexed in the usual manner.

Breaking changes:
* `Lean.withTraceNode` and variants got a stronger `MonadAlwaysExcept` assumption to
fix trace trees not being built on elaboration runtime exceptions. Instances for most elaboration
Expand Down
47 changes: 35 additions & 12 deletions src/Lean/Meta/DiscrTree.lean
Original file line number Diff line number Diff line change
Expand Up @@ -326,15 +326,37 @@ def reduceDT (e : Expr) (root : Bool) (config : WhnfCoreConfig) : MetaM Expr :=

/- Remark: we use `shouldAddAsStar` only for nested terms, and `root == false` for nested terms -/

private def pushArgs (root : Bool) (todo : Array Expr) (e : Expr) (config : WhnfCoreConfig) : MetaM (Key × Array Expr) := do

/--
Append `n` wildcards to `todo`
-/
private def pushWildcards (n : Nat) (todo : Array Expr) : Array Expr :=
match n with
| 0 => todo
| n+1 => pushWildcards n (todo.push tmpStar)

/--
When `noIndexAtArgs := true`, `pushArgs` assumes function application arguments have a `no_index` annotation.
That is, `f a b` is indexed as it was `f (no_index a) (no_index b)`.
This feature is used when indexing local proofs in the simplifier. This is useful in examples like the one described on issue #2670.
In this issue, we have a local hypotheses `(h : ∀ p : α × β, f p p.2 = p.2)`, and users expect it to be applicable to
`f (a, b) b = b`. This worked in Lean 3 since no indexing was used. We can retrieve Lean 3 behavior by writing
`(h : ∀ p : α × β, f p (no_index p.2) = p.2)`, but this is very inconvenient when the hypotheses was not written by the user in first place.
For example, it was introduced by another tactic. Thus, when populating the discrimination tree explicit arguments provided to `simp` (e.g., `simp [h]`),
we use `noIndexAtArgs := true`. See comment: https://github.com/leanprover/lean4/issues/2670#issuecomment-1758889365
-/
private def pushArgs (root : Bool) (todo : Array Expr) (e : Expr) (config : WhnfCoreConfig) (noIndexAtArgs : Bool) : MetaM (Key × Array Expr) := do
if hasNoindexAnnotation e then
return (.star, todo)
else
let e ← reduceDT e root config
let fn := e.getAppFn
let push (k : Key) (nargs : Nat) (todo : Array Expr): MetaM (Key × Array Expr) := do
let info ← getFunInfoNArgs fn nargs
let todo ← pushArgsAux info.paramInfo (nargs-1) e todo
let todo ← if noIndexAtArgs then
pure <| pushWildcards nargs todo
else
pushArgsAux info.paramInfo (nargs-1) e todo
return (k, todo)
match fn with
| .lit v =>
Expand Down Expand Up @@ -378,22 +400,24 @@ private def pushArgs (root : Bool) (todo : Array Expr) (e : Expr) (config : Whnf
| _ =>
return (.other, todo)

partial def mkPathAux (root : Bool) (todo : Array Expr) (keys : Array Key) (config : WhnfCoreConfig) : MetaM (Array Key) := do
@[inherit_doc pushArgs]
partial def mkPathAux (root : Bool) (todo : Array Expr) (keys : Array Key) (config : WhnfCoreConfig) (noIndexAtArgs : Bool) : MetaM (Array Key) := do
if todo.isEmpty then
return keys
else
let e := todo.back
let todo := todo.pop
let (k, todo) ← pushArgs root todo e config
mkPathAux false todo (keys.push k) config
let (k, todo) ← pushArgs root todo e config noIndexAtArgs
mkPathAux false todo (keys.push k) config noIndexAtArgs

private def initCapacity := 8

def mkPath (e : Expr) (config : WhnfCoreConfig) : MetaM (Array Key) := do
@[inherit_doc pushArgs]
def mkPath (e : Expr) (config : WhnfCoreConfig) (noIndexAtArgs := false) : MetaM (Array Key) := do
withReducible do
let todo : Array Expr := .mkEmpty initCapacity
let keys : Array Key := .mkEmpty initCapacity
mkPathAux (root := true) (todo.push e) keys config
mkPathAux (root := true) (todo.push e) keys config noIndexAtArgs

private partial def createNodes (keys : Array Key) (v : α) (i : Nat) : Trie α :=
if h : i < keys.size then
Expand Down Expand Up @@ -447,22 +471,21 @@ def insertCore [BEq α] (d : DiscrTree α) (keys : Array Key) (v : α) : DiscrTr
let c := insertAux keys v 1 c
{ root := d.root.insert k c }

def insert [BEq α] (d : DiscrTree α) (e : Expr) (v : α) (config : WhnfCoreConfig) : MetaM (DiscrTree α) := do
let keys ← mkPath e config
def insert [BEq α] (d : DiscrTree α) (e : Expr) (v : α) (config : WhnfCoreConfig) (noIndexAtArgs := false) : MetaM (DiscrTree α) := do
let keys ← mkPath e config noIndexAtArgs
return d.insertCore keys v

/--
Inserts a value into a discrimination tree,
but only if its key is not of the form `#[*]` or `#[=, *, *, *]`.
-/
def insertIfSpecific [BEq α] (d : DiscrTree α) (e : Expr) (v : α) (config : WhnfCoreConfig) : MetaM (DiscrTree α) := do
let keys ← mkPath e config
def insertIfSpecific [BEq α] (d : DiscrTree α) (e : Expr) (v : α) (config : WhnfCoreConfig) (noIndexAtArgs := false) : MetaM (DiscrTree α) := do
let keys ← mkPath e config noIndexAtArgs
return if keys == #[Key.star] || keys == #[Key.const `Eq 3, Key.star, Key.star, Key.star] then
d
else
d.insertCore keys v


private def getKeyArgs (e : Expr) (isMatch root : Bool) (config : WhnfCoreConfig) : MetaM (Key × Array Expr) := do
let e ← reduceDT e root config
unless root do
Expand Down
10 changes: 5 additions & 5 deletions src/Lean/Meta/Tactic/Simp/SimpTheorems.lean
Original file line number Diff line number Diff line change
Expand Up @@ -319,15 +319,15 @@ private def checkTypeIsProp (type : Expr) : MetaM Unit :=
unless (← isProp type) do
throwError "invalid 'simp', proposition expected{indentExpr type}"

private def mkSimpTheoremCore (origin : Origin) (e : Expr) (levelParams : Array Name) (proof : Expr) (post : Bool) (prio : Nat) : MetaM SimpTheorem := do
private def mkSimpTheoremCore (origin : Origin) (e : Expr) (levelParams : Array Name) (proof : Expr) (post : Bool) (prio : Nat) (noIndexAtArgs : Bool) : MetaM SimpTheorem := do
assert! origin != .fvar ⟨.anonymous⟩
let type ← instantiateMVars (← inferType e)
withNewMCtxDepth do
let (_, _, type) ← withReducible <| forallMetaTelescopeReducing type
let type ← whnfR type
let (keys, perm) ←
match type.eq? with
| some (_, lhs, rhs) => pure (← DiscrTree.mkPath lhs simpDtConfig, ← isPerm lhs rhs)
| some (_, lhs, rhs) => pure (← DiscrTree.mkPath lhs simpDtConfig noIndexAtArgs, ← isPerm lhs rhs)
| none => throwError "unexpected kind of 'simp' theorem{indentExpr type}"
return { origin, keys, perm, post, levelParams, proof, priority := prio, rfl := (← isRflProof proof) }

Expand All @@ -341,10 +341,10 @@ private def mkSimpTheoremsFromConst (declName : Name) (post : Bool) (inv : Bool)
let mut r := #[]
for (val, type) in (← preprocess val type inv (isGlobal := true)) do
let auxName ← mkAuxLemma cinfo.levelParams type val
r := r.push <| (← mkSimpTheoremCore (.decl declName post inv) (mkConst auxName (cinfo.levelParams.map mkLevelParam)) #[] (mkConst auxName) post prio)
r := r.push <| (← mkSimpTheoremCore (.decl declName post inv) (mkConst auxName (cinfo.levelParams.map mkLevelParam)) #[] (mkConst auxName) post prio (noIndexAtArgs := false))
return r
else
return #[← mkSimpTheoremCore (.decl declName post inv) (mkConst declName (cinfo.levelParams.map mkLevelParam)) #[] (mkConst declName) post prio]
return #[← mkSimpTheoremCore (.decl declName post inv) (mkConst declName (cinfo.levelParams.map mkLevelParam)) #[] (mkConst declName) post prio (noIndexAtArgs := false)]

inductive SimpEntry where
| thm : SimpTheorem → SimpEntry
Expand Down Expand Up @@ -455,7 +455,7 @@ private def preprocessProof (val : Expr) (inv : Bool) : MetaM (Array Expr) := do
/-- Auxiliary method for creating simp theorems from a proof term `val`. -/
def mkSimpTheorems (id : Origin) (levelParams : Array Name) (proof : Expr) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM (Array SimpTheorem) :=
withReducible do
(← preprocessProof proof inv).mapM fun val => mkSimpTheoremCore id val levelParams val post prio
(← preprocessProof proof inv).mapM fun val => mkSimpTheoremCore id val levelParams val post prio (noIndexAtArgs := true)

/-- Auxiliary method for adding a local simp theorem to a `SimpTheorems` datastructure. -/
def SimpTheorems.add (s : SimpTheorems) (id : Origin) (levelParams : Array Name) (proof : Expr) (inv := false) (post := true) (prio : Nat := eval_prio default) : MetaM SimpTheorems := do
Expand Down
23 changes: 23 additions & 0 deletions tests/lean/run/2670.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
example {α β : Type} {f : α × β → β → β} (h : ∀ p : α × β, f p p.2 = p.2)
(a : α) (b : β) : f (a, b) b = b := by
simp [h] -- should not fail

example {α β : Type} {f : α × β → β → β}
(a : α) (b : β) (h : f (a,b) (a,b).2 = (a,b).2) : f (a, b) b = b := by
simp [h] -- should not fail

def enumFromTR' (n : Nat) (l : List α) : List (Nat × α) :=
let arr := l.toArray
(arr.foldr (fun a (n, acc) => (n-1, (n-1, a) :: acc)) (n + arr.size, [])).2

open List in
theorem enumFrom_eq_enumFromTR' : @enumFrom = @enumFromTR' := by
funext α n l; simp [enumFromTR', -Array.size_toArray]
let f := fun (a : α) (n, acc) => (n-1, (n-1, a) :: acc)
let rec go : ∀ l n, l.foldr f (n + l.length, []) = (n, enumFrom n l)
| [], n => rfl
| a::as, n => by
rw [← show _ + as.length = n + (a::as).length from Nat.succ_add .., List.foldr, go as]
simp [enumFrom, f]
rw [Array.foldr_eq_foldr_data]
simp [go] -- Should close the goal
4 changes: 2 additions & 2 deletions tests/lean/run/ac_rfl.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ instance : LawfulCommIdentity HMul.hMul 1 where right_id := Nat.mul_one
set_option linter.unusedVariables false in
theorem le_ext : ∀ {x y : Nat} (h : ∀ z, z ≤ x ↔ z ≤ y), x = y
| 0, 0, _ => rfl
| x+1, y+1, h => congrArg (· + 1) <| le_ext fun z => have := h (z + 1); by simp_all
| x+1, y+1, h => congrArg (· + 1) <| le_ext fun z => have := h (z + 1); by simp at this; assumption
| 0, y+1, h => have := h 1; by clear h; simp_all
| x+1, 0, h => have := h 1; by simp_all
| x+1, 0, h => have := h 1; by simp at this

theorem le_or_le : ∀ (a b : Nat), a ≤ b ∨ b ≤ a
| x+1, y+1 => by simp [le_or_le x y]
Expand Down
13 changes: 7 additions & 6 deletions tests/lean/run/meta3.lean
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def add := mkAppN (mkConst `Add.add [levelZero]) #[nat, mkConst `Nat.add]
def tst1 : MetaM Unit :=
do let d : DiscrTree Nat := {};
let mvar ← mkFreshExprMVar nat;
let d ← d.insert (mkAppN add #[mvar, mkNatLit 10]) 1 {};
let d ← d.insert (mkAppN add #[mkNatLit 0, mkNatLit 10]) 2 {};
let d ← d.insert (mkAppN (mkConst `Nat.add) #[mkNatLit 0, mkNatLit 20]) 3 {};
let d ← d.insert (mkAppN add #[mvar, mkNatLit 20]) 4 {};
let d ← d.insert mvar 5 {};
let d ← d.insert (mkAppN add #[mvar, mkNatLit 10]) 1 {}
let d ← d.insert (mkAppN add #[mkNatLit 0, mkNatLit 10]) 2 {}
let d ← d.insert (mkAppN (mkConst `Nat.add) #[mkNatLit 0, mkNatLit 20]) 3 {}
let d ← d.insert (mkAppN add #[mvar, mkNatLit 20]) 4 {}
let d ← d.insert mvar 5 {}
print (format d);
let vs ← d.getMatch (mkAppN add #[mkNatLit 1, mkNatLit 10]) {};
print (format vs);
Expand All @@ -56,4 +56,5 @@ do let d : DiscrTree Nat := {};
print (format vs);
pure ()

#eval run #[`Init.Data.Nat] tst1
set_option trace.Meta.debug true in
run_meta tst1

0 comments on commit 90b5a00

Please sign in to comment.