Skip to content

Commit

Permalink
feat: add Simp.Config.index (#4202)
Browse files Browse the repository at this point in the history
The `simp` tactic uses a discrimination tree to select candidate
theorems that will be used to rewrite an expression. This indexing data
structure minimizes the number of theorems that need to be tried and
improves performance. However, indexing modulo reducibility is
challenging, and a theorem that could be applied, when taking reduction
into account, may be missed. For example, suppose we have a `simp`
theorem `foo : forall x y, f x (x, y).2 = y`, and we are trying to
simplify the expression `f a b <= b`. `foo` will not be tried by `simp`
because the second argument of `f a b` is not a projection of a pair.
However, `f a b` is definitionally equal to `f a (a, b).2` since we can
reduce `(a, b).2`.

In Lean 3, we had a much simpler indexing data structure where only the
head symbol was taken into account. For the theorem `foo`, the head
symbol is `f`. Thus, the theorem would be considered by `simp`.

This commit adds the option `Simp.Config.index`. When `simp (config := {
index := false })`, only the head symbol is considered when retrieving
theorems, as in Lean 3. Moreover, if `set_option diagnostics true` is set,
`simp` will check whether every applied theorem would also have been
applied if `index := true`, and report those that would not. This feature can
help users diagnose tricky issues in code from libraries that were
developed using Lean 3 and then ported to Lean 4. In the following
example, it will report that `foo` is a problematic theorem.

```lean
opaque f : Nat → Nat → Nat

@[simp] theorem foo : f x (x, y).2 = y := by sorry

example : f a b ≤ b := by
  set_option diagnostics true in
  simp (config := { index := false })
```

In the example above, the following diagnostic message is produced.
```lean
[simp] theorems with bad keys
    foo, key: [f, *, Prod.1, Prod.mk, Nat, Nat, *, *]
```

With the information above, users can annotate theorems such as `foo`
using `no_index` for problematic subterms.
Example:
```lean
opaque f : Nat → Nat → Nat

@[simp] theorem foo : f x (no_index (x, y).2) = y := by sorry

example : f a b ≤ b := by
  simp -- `foo` is still applied
```

cc @semorrison 
cc @PatrickMassot
  • Loading branch information
leodemoura authored May 17, 2024
1 parent 1382e9f commit ee0bcc8
Show file tree
Hide file tree
Showing 6 changed files with 938 additions and 29 deletions.
5 changes: 5 additions & 0 deletions src/Init/MetaTypes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ structure Config where
That is, given a local context containing entry `x : t := e`, the free variable `x` reduces to `e`.
-/
zetaDelta : Bool := false
/--
When `index` (default: `true`) is `false`, `simp` will only use the root symbol
to find candidate `simp` theorems. This approximates Lean 3 `simp` behavior.
-/
index : Bool := true
deriving Inhabited, BEq

-- Configuration object for `simp_all`
Expand Down
49 changes: 49 additions & 0 deletions src/Lean/Meta/DiscrTree.lean
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,55 @@ where
else
return result

/--
Compute the root discrimination-tree key of `e` together with its number of arguments.
Both are read off `e` after it has been normalized with `reduceDT` (as a root term).
-/
def getMatchKeyRootFor (e : Expr) (config : WhnfCoreConfig) : MetaM (Key × Nat) := do
  let e ← reduceDT e (root := true) config
  let numArgs := e.getAppNumArgs
  let key : Key :=
    match e.getAppFn with
    | .const c _ =>
      -- This method is used by the simplifier only, we do **not** support
      -- (← getConfig).isDefEqStuckEx
      .const c numArgs
    | .fvar fvarId => .fvar fvarId numArgs
    | .proj s i .. => .proj s i numArgs
    | .lit v => .lit v
    | .forallE .. => .arrow
    -- Metavariables and every remaining head shape are mapped to `.other`.
    | .mvar _ => .other
    | _ => .other
  return (key, numArgs)

/--
Collect every value stored below the root entry `k` of `d`, appending them to `result`.
If `d` has no root entry for `k`, `result` is returned unchanged.
-/
private partial def getAllValuesForKey (d : DiscrTree α) (k : Key) (result : Array α) : Array α :=
  if let some trie := d.root.find? k then
    go trie result
  else
    result
where
  /-- Append the values of the node itself, then those of each child subtree, in order. -/
  go : Trie α → Array α → Array α
    | .node vs cs, acc =>
      cs.foldl (init := acc ++ vs) fun acc entry => go entry.2 acc

/--
A permissive variant of `getMatch`: only the root symbol of `e` is used to select candidates.
This simulates Lean 3's indexing.
The second component of the result is the number of arguments of `e` after `reduceDT`.
-/
def getMatchLiberal (d : DiscrTree α) (e : Expr) (config : WhnfCoreConfig) : MetaM (Array α × Nat) :=
  withReducible do
    let starValues := getStarResult d
    let (rootKey, numArgs) ← getMatchKeyRootFor e config
    let values :=
      match rootKey with
      | .star => starValues
      -- Star results are always included; everything under the root key is added on top.
      | _ => getAllValuesForKey d rootKey starValues
    return (values, numArgs)

partial def getUnify (d : DiscrTree α) (e : Expr) (config : WhnfCoreConfig) : MetaM (Array α) :=
withReducible do
let (k, args) ← getUnifyKeyArgs e (root := true) config
Expand Down
33 changes: 24 additions & 9 deletions src/Lean/Meta/Tactic/Simp/Diagnostics.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@ import Lean.Meta.Tactic.Simp.Types

namespace Lean.Meta.Simp

/--
Render a theorem `Origin` as the message shown in `simp` diagnostic reports.
Declarations present in the environment are shown as constants, declarations that are
not in the environment are labelled as builtin simprocs, free variables are shown as
such, and any other origin falls back to its `key`.
-/
private def originToKey (thmId : Origin) : MetaM MessageData := do
  match thmId with
  | .fvar fvarId => return m!"{mkFVar fvarId}"
  | .decl declName .. =>
    let env ← getEnv
    unless env.contains declName do
      return m!"{declName} (builtin simproc)"
    let c ← mkConstWithLevelParams declName
    return m!"{MessageData.ofConst c}"
  | _ => return thmId.key

def mkSimpDiagSummary (counters : PHashMap Origin Nat) (usedCounters? : Option (PHashMap Origin Nat) := none) : MetaM DiagSummary := do
let threshold := diagnostics.threshold.get (← getOptions)
let entries := collectAboveThreshold counters threshold (fun _ => true) (lt := (· < ·))
Expand All @@ -17,31 +27,36 @@ def mkSimpDiagSummary (counters : PHashMap Origin Nat) (usedCounters? : Option (
else
let mut data := #[]
for (thmId, counter) in entries do
let key ← match thmId with
| .decl declName _ _ =>
if (← getEnv).contains declName then
pure m!"{MessageData.ofConst (← mkConstWithLevelParams declName)}"
else
pure m!"{declName} (builtin simproc)"
| .fvar fvarId => pure m!"{mkFVar fvarId}"
| _ => pure thmId.key
let key ← originToKey thmId
let usedMsg ← if let some usedCounters := usedCounters? then
if let some c := usedCounters.find? thmId then pure s!", succeeded: {c}" else pure s!" {crossEmoji}" -- not used
else
pure ""
data := data.push m!"{if data.isEmpty then " " else "\n"}{key} ↦ {counter}{usedMsg}"
return { data, max := entries[0]!.2 }

/--
Build the diagnostic summary listing each theorem in `thms` together with its
discrimination-tree keys. Returns the empty summary when `thms` is empty.
-/
private def mkTheoremsWithBadKeySummary (thms : PArray SimpTheorem) : MetaM DiagSummary := do
  if thms.isEmpty then
    return {}
  else
    let mut data := #[]
    for thm in thms do
      -- The first entry is prefixed with a space, every later entry starts on a new line.
      data := data.push m!"{if data.isEmpty then " " else "\n"}{← originToKey thm.origin}, key: {thm.keys.map (·.format)}"
    return { data }

/--
Log the `simp` diagnostics collected in `diag` when diagnostics are enabled.
Nothing is logged if every section (used theorems, tried theorems, tried congruence
theorems, theorems with bad keys) is empty.
-/
def reportDiag (diag : Simp.Diagnostics) : MetaM Unit := do
  if (← isDiagnosticsEnabled) then
    let used ← mkSimpDiagSummary diag.usedThmCounter
    let tried ← mkSimpDiagSummary diag.triedThmCounter diag.usedThmCounter
    let congr ← mkDiagSummary diag.congrThmCounter
    let thmsWithBadKeys ← mkTheoremsWithBadKeySummary diag.thmsWithBadKeys
    -- The bad-keys section must take part in the emptiness check; otherwise a report
    -- containing only theorems with bad keys would be silently dropped.
    unless used.isEmpty && tried.isEmpty && congr.isEmpty && thmsWithBadKeys.isEmpty do
      let m := MessageData.nil
      let m := appendSection m `simp "used theorems" used
      let m := appendSection m `simp "tried theorems" tried
      let m := appendSection m `simp "tried congruence theorems" congr
      let m := appendSection m `simp "theorems with bad keys" thmsWithBadKeys (resultSummary := false)
      let m := m ++ "use `set_option diagnostics.threshold <num>` to control threshold for reporting counters"
      logInfo m

Expand Down
65 changes: 54 additions & 11 deletions src/Lean/Meta/Tactic/Simp/Rewrite.lean
Original file line number Diff line number Diff line change
Expand Up @@ -189,19 +189,62 @@ def tryTheorem? (e : Expr) (thm : SimpTheorem) : SimpM (Option Result) := do
Remark: the parameter tag is used for creating trace messages. It is irrelevant otherwise.
-/
def rewrite? (e : Expr) (s : SimpTheoremTree) (erased : PHashSet Origin) (tag : String) (rflOnly : Bool) : SimpM (Option Result) := do
let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig))
if candidates.isEmpty then
trace[Debug.Meta.Tactic.simp] "no theorems found for {tag}-rewriting {e}"
return none
if (← getConfig).index then
rewriteUsingIndex?
else
let candidates := candidates.insertionSort fun e₁ e₂ => e₁.1.priority > e₂.1.priority
for (thm, numExtraArgs) in candidates do
unless inErasedSet thm || (rflOnly && !thm.rfl) do
if let some result ← tryTheoremWithExtraArgs? e thm numExtraArgs then
trace[Debug.Meta.Tactic.simp] "rewrite result {e} => {result.expr}"
return some result
return none
rewriteNoIndex?
where
/-- For `(← getConfig).index := true`, use discrimination tree structure when collecting `simp` theorem candidates. -/
rewriteUsingIndex? : SimpM (Option Result) := do
let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig))
if candidates.isEmpty then
trace[Debug.Meta.Tactic.simp] "no theorems found for {tag}-rewriting {e}"
return none
else
let candidates := candidates.insertionSort fun e₁ e₂ => e₁.1.priority > e₂.1.priority
for (thm, numExtraArgs) in candidates do
unless inErasedSet thm || (rflOnly && !thm.rfl) do
if let some result ← tryTheoremWithExtraArgs? e thm numExtraArgs then
trace[Debug.Meta.Tactic.simp] "rewrite result {e} => {result.expr}"
return some result
return none

/--
For `(← getConfig).index := false`, Lean 3 style `simp` theorem retrieval.
Only the root symbol is taken into account. Most of the structure of the discrimination tree is ignored.
-/
rewriteNoIndex? : SimpM (Option Result) := do
let (candidates, numArgs) ← s.getMatchLiberal e (getDtConfig (← getConfig))
if candidates.isEmpty then
trace[Debug.Meta.Tactic.simp] "no theorems found for {tag}-rewriting {e}"
return none
else
let candidates := candidates.insertionSort fun e₁ e₂ => e₁.priority > e₂.priority
for thm in candidates do
unless inErasedSet thm || (rflOnly && !thm.rfl) do
let result? ← withNewMCtxDepth do
let val ← thm.getValue
let type ← inferType val
let (xs, bis, type) ← forallMetaTelescopeReducing type
let type ← whnf (← instantiateMVars type)
let lhs := type.appFn!.appArg!
let lhsNumArgs := lhs.getAppNumArgs
tryTheoremCore lhs xs bis val type e thm (numArgs - lhsNumArgs)
if let some result := result? then
trace[Debug.Meta.Tactic.simp] "rewrite result {e} => {result.expr}"
diagnoseWhenNoIndex thm
return some result
return none

diagnoseWhenNoIndex (thm : SimpTheorem) : SimpM Unit := do
if (← isDiagnosticsEnabled) then
let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig))
for (candidate, _) in candidates do
if unsafe ptrEq thm candidate then
return ()
-- `thm` would not have been applied if `index := true`
recordTheoremWithBadKeys thm

inErasedSet (thm : SimpTheorem) : Bool :=
erased.contains thm.origin

Expand Down
33 changes: 24 additions & 9 deletions src/Lean/Meta/Tactic/Simp/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ structure Diagnostics where
triedThmCounter : PHashMap Origin Nat := {}
/-- Number of times each congr theorem has been tried. -/
congrThmCounter : PHashMap Name Nat := {}
/--
When using `Simp.Config.index := false` and `set_option diagnostics true`,
for every theorem used by `simp`, we check whether the theorem would
also be applied if `index := true`, and we store it here if it would not have
been tried.
-/
thmsWithBadKeys : PArray SimpTheorem := {}
deriving Inhabited

structure State where
Expand Down Expand Up @@ -325,14 +332,14 @@ Save current cache, reset it, execute `x`, and then restore original cache.
withReader (fun r => { MethodsRef.toMethods r with discharge?, wellBehavedDischarge }.toMethodsRef) x

/--
Increment the "tried" counter of `thmId` in the `simp` diagnostics.
The state is updated with `{ s with … }` so that all other diagnostic fields
(including `thmsWithBadKeys`) are preserved.
-/
def recordTriedSimpTheorem (thmId : Origin) : SimpM Unit := do
  modifyDiag fun s =>
    let cNew := if let some c := s.triedThmCounter.find? thmId then c + 1 else 1
    { s with triedThmCounter := s.triedThmCounter.insert thmId cNew }

def recordSimpTheorem (thmId : Origin) : SimpM Unit := do
modifyDiag fun { usedThmCounter, triedThmCounter, congrThmCounter } =>
let cNew := if let some c := usedThmCounter.find? thmId then c + 1 else 1
{ usedThmCounter := usedThmCounter.insert thmId cNew, triedThmCounter, congrThmCounter }
modifyDiag fun s =>
let cNew := if let some c := s.usedThmCounter.find? thmId then c + 1 else 1
{ s with usedThmCounter := s.usedThmCounter.insert thmId cNew }
/-
If `thmId` is an equational theorem (e.g., `foo.eq_1`), we should record `foo` instead.
See issue #3547.
Expand All @@ -353,9 +360,17 @@ def recordSimpTheorem (thmId : Origin) : SimpM Unit := do
{ s with usedTheorems := s.usedTheorems.insert thmId n }

/--
Increment the counter of the congruence theorem `declName` in the `simp` diagnostics.
The state is updated with `{ s with … }` so that all other diagnostic fields
(including `thmsWithBadKeys`) are preserved.
-/
def recordCongrTheorem (declName : Name) : SimpM Unit := do
  modifyDiag fun s =>
    let cNew := if let some c := s.congrThmCounter.find? declName then c + 1 else 1
    { s with congrThmCounter := s.congrThmCounter.insert declName cNew }

/--
Add `thm` to the list of theorems with bad keys in the `simp` diagnostics,
unless the very same object (compared by pointer) has already been recorded.
-/
def recordTheoremWithBadKeys (thm : SimpTheorem) : SimpM Unit := do
  modifyDiag fun diag =>
    -- Pointer equality is used to detect an entry that was recorded earlier.
    let alreadyRecorded : Bool := unsafe diag.thmsWithBadKeys.any (ptrEq thm ·)
    if alreadyRecorded then
      diag
    else
      { diag with thmsWithBadKeys := diag.thmsWithBadKeys.push thm }

def Result.getProof (r : Result) : MetaM Expr := do
match r.proof? with
Expand Down
Loading

0 comments on commit ee0bcc8

Please sign in to comment.