Skip to content

Commit

Permalink
feat: same equational lemmas for recusive and non-recursive functions
Browse files Browse the repository at this point in the history
This is part of #3983.

After #4154 introduced equational lemmas for non-recursive functions and #5055
unififed the lemmas for structural and wf recursive funcitons, this now
disables the special handling of recursive functions in
`findMatchToSplit?`, so that the equational lemmas should be the same no
matter how the function was defined.

The new option `eqns.deepRecursiveSplit` can be disabled to get the old
behavior.

This can break existing code, as there now can be extra equational
lemmas:

* Explicit uses of `f.eq_2` might have to be adjusted if the numbering
  changed.

* Uses of `rw [f]` or `simp [f]` may no longer apply if they previously
  matched (and introduced a `match` statement), when the equational
  lemmas got more fine-grained.

  In this case either case analysis on the parameters before rewriting
  helps, or setting the option `opt.deepRecursiveSplit false` while
  defining the function
  • Loading branch information
nomeata committed Aug 22, 2024
1 parent 6d646c6 commit 17ed410
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
19 changes: 10 additions & 9 deletions src/Lean/Elab/PreDefinition/Eqns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def simpIf? (mvarId : MVarId) : MetaM (Option MVarId) := do
let mvarId' ← simpIfTarget mvarId (useDecide := true)
if mvarId != mvarId' then return some mvarId' else return none

private def findMatchToSplit? (env : Environment) (e : Expr) (declNames : Array Name) (exceptionSet : ExprSet) : Option Expr :=
private def findMatchToSplit? (deepRecursiveSplit : Bool) (env : Environment) (e : Expr)
(declNames : Array Name) (exceptionSet : ExprSet) : Option Expr :=
e.findExt? fun e => Id.run do
if e.hasLooseBVars || exceptionSet.contains e then
return Expr.FindStep.visit
Expand All @@ -78,9 +79,11 @@ private def findMatchToSplit? (env : Environment) (e : Expr) (declNames : Array
-- For non-recursive functions (`declNames` empty), we split here
if declNames.isEmpty then
return Expr.FindStep.found
-- For recursive functions we only split when at least one alternatives contains a `declNames`
-- For recursive functions, the “new” behavior is to likewise split
if deepRecursiveSplit then
return Expr.FindStep.found
-- Else, the “old” behavior is split only when at least one alternative contains a `declNames`
-- application with loose bound variables.
-- (We plan to disable this by default and treat recursive and non-recursie functions the same)
for i in [info.getFirstAltPos : info.getFirstAltPos + info.numAlts] do
let alt := args[i]!
if Option.isSome <| alt.find? fun e => declNames.any e.isAppOf && e.hasLooseBVars then
Expand All @@ -97,7 +100,8 @@ private def findMatchToSplit? (env : Environment) (e : Expr) (declNames : Array
partial def splitMatch? (mvarId : MVarId) (declNames : Array Name) : MetaM (Option (List MVarId)) := commitWhenSome? do
let target ← mvarId.getType'
let rec go (badCases : ExprSet) : MetaM (Option (List MVarId)) := do
if let some e := findMatchToSplit? (← getEnv) target declNames badCases then
if let some e := findMatchToSplit? (eqns.deepRecursiveSplit.get (← getOptions)) (← getEnv)
target declNames badCases then
try
Meta.Split.splitMatch mvarId e
catch _ =>
Expand All @@ -107,9 +111,6 @@ partial def splitMatch? (mvarId : MVarId) (declNames : Array Name) : MetaM (Opti
return none
go {}

structure Context where
declNames : Array Name

private def lhsDependsOn (type : Expr) (fvarId : FVarId) : MetaM Bool :=
forallTelescope type fun _ type => do
if let some (_, lhs, _) ← matchEq? type then
Expand Down Expand Up @@ -234,10 +235,10 @@ private def shouldUseSimpMatch (e : Expr) : MetaM Bool := do
return (← (find e).run) matches .error _

partial def mkEqnTypes (declNames : Array Name) (mvarId : MVarId) : MetaM (Array Expr) := do
let (_, eqnTypes) ← go mvarId |>.run { declNames } |>.run #[]
let (_, eqnTypes) ← go mvarId |>.run #[]
return eqnTypes
where
go (mvarId : MVarId) : ReaderT Context (StateRefT (Array Expr) MetaM) Unit := do
go (mvarId : MVarId) : StateRefT (Array Expr) MetaM Unit := do
trace[Elab.definition.eqns] "mkEqnTypes step\n{MessageData.ofGoal mvarId}"

if let some mvarId ← expandRHS? mvarId then
Expand Down
11 changes: 10 additions & 1 deletion src/Lean/Meta/Eqns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ register_builtin_option eqns.nonrecursive : Bool := {
descr := "Create fine-grained equational lemmas even for non-recursive definitions."
}

register_builtin_option eqns.deepRecursiveSplit : Bool := {
defValue := true
descr := "Create equational lemmas for recursive functions like for non-recursive \
functions. If disabled, match statements in recursive function definitions \
that do not contain recursive calls do not cause further splits in the \
equational lemmas. This was the behavior before Lean 4.12, and the purpose of \
this option is to help migrating old code."
}


/--
These options affect the generation of equational theorems in a significant way. For these, their
Expand All @@ -26,7 +35,7 @@ This is implemented by
* eagerly realizing the equations when they are set to a non-default vaule
* when realizing them lazily, reset the options to their default
-/
def eqnAffectingOptions : Array (Lean.Option Bool) := #[eqns.nonrecursive]
def eqnAffectingOptions : Array (Lean.Option Bool) := #[eqns.nonrecursive, eqns.deepRecursiveSplit]

/--
Environment extension for storing which declarations are recursive.
Expand Down

0 comments on commit 17ed410

Please sign in to comment.