Skip to content

Commit

Permalink
fix: fixes #2042
Browse files Browse the repository at this point in the history
  • Loading branch information
leodemoura authored and kim-em committed Nov 4, 2023
1 parent 8cf9d13 commit 3f8c00b
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 14 deletions.
8 changes: 8 additions & 0 deletions src/Lean/Expr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1332,6 +1332,14 @@ lambda expression. See docstring for `betaRev` for examples.
def beta (f : Expr) (args : Array Expr) : Expr :=
betaRev f args.reverse

/--
Count the number of lambdas at the head of the given expression.
-/
def getNumHeadLambdas : Expr → Nat
| .lam _ _ b _ => getNumHeadLambdas b + 1
| .mdata _ b => getNumHeadLambdas b
| _ => 0

/--
Return true if the given expression is the function of an expression that is target for (head) beta reduction.
If `useZeta = true`, then `let`-expressions are visited. That is, it assumes
Expand Down
45 changes: 31 additions & 14 deletions src/Lean/Meta/Tactic/Simp/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,23 @@ private def unfold? (e : Expr) : SimpM (Option Expr) := do
if eNew == e then return none
trace[Meta.Tactic.simp.ground] "{e}\n---->\n{eNew}"
return some eNew
let rec unfoldDeclToUnfold? : SimpM (Option Expr) := do
let options ← getOptions
let cfg ← getConfig
-- Support for issue #2042
if cfg.unfoldPartialApp -- If we are unfolding partial applications, ignore issue #2042
-- When smart unfolding is enabled, and `f` supports it, we don't need to worry about issue #2042
|| (smartUnfolding.get options && (← getEnv).contains (mkSmartUnfoldingNameFor fName)) then
withDefault <| unfoldDefinition? e
else
-- `We are not unfolding partial applications, and `fName` does not have smart unfolding support.
-- Thus, we must check whether the arity of the function >= number of arguments.
let some cinfo := (← getEnv).find? fName | return none
let some value := cinfo.value? | return none
let arity := value.getNumHeadLambdas
-- Partially applied function, return `none`. See issue #2042
if arity > e.getAppNumArgs then return none
withDefault <| unfoldDefinition? e
if let some eNew ← unfoldGround? then
return some eNew
else if (← isProjectionFn fName) then
Expand All @@ -210,7 +227,7 @@ private def unfold? (e : Expr) : SimpM (Option Expr) := do
else
return none
else if ctx.isDeclToUnfold fName then
withDefault <| unfoldDefinition? e
unfoldDeclToUnfold?
else
return none

Expand Down Expand Up @@ -352,18 +369,18 @@ where

simpStep (e : Expr) : M Result := do
match e with
| Expr.mdata m e => let r ← simp e; return { r with expr := mkMData m r.expr }
| Expr.proj .. => simpProj e
| Expr.app .. => simpApp e
| Expr.lam .. => simpLambda e
| Expr.forallE .. => simpForall e
| Expr.letE .. => simpLet e
| Expr.const .. => simpConst e
| Expr.bvar .. => unreachable!
| Expr.sort .. => return { expr := e }
| Expr.lit .. => simpLit e
| Expr.mvar .. => return { expr := (← instantiateMVars e) }
| Expr.fvar .. => return { expr := (← reduceFVar (← getConfig) (← getSimpTheorems) e) }
| .mdata m e => let r ← simp e; return { r with expr := mkMData m r.expr }
| .proj .. => simpProj e
| .app .. => simpApp e
| .lam .. => simpLambda e
| .forallE .. => simpForall e
| .letE .. => simpLet e
| .const .. => simpConst e
| .bvar .. => unreachable!
| .sort .. => return { expr := e }
| .lit .. => simpLit e
| .mvar .. => return { expr := (← instantiateMVars e) }
| .fvar .. => return { expr := (← reduceFVar (← getConfig) (← getSimpTheorems) e) }

simpLit (e : Expr) : M Result := do
match e.natLit? with
Expand Down Expand Up @@ -766,7 +783,7 @@ where
return { expr := (← dsimp e) }

simpLet (e : Expr) : M Result := do
let Expr.letE n t v b _ := e | unreachable!
let .letE n t v b _ := e | unreachable!
if (← getConfig).zeta then
return { expr := b.instantiate1 v }
else
Expand Down
21 changes: 21 additions & 0 deletions tests/lean/run/2042.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
@[simp] def foo (a : Nat) : Nat :=
2 * a

example : foo = fun a => a + a :=
by
fail_if_success simp -- should not unfold `foo` into a lambda
funext x
simp -- unfolds `foo`
trace_state
simp_arith

@[simp] def boo : Nat → Nat
| a => 2 * a

example : boo = fun a => a + a :=
by
fail_if_success simp -- should not unfold `boo` into a lambda
funext x
simp -- unfolds `boo`
trace_state
simp_arith

0 comments on commit 3f8c00b

Please sign in to comment.