diff --git a/src/Lean/Expr.lean b/src/Lean/Expr.lean index 44e1f7deb2cf..9b6ef5d9b66e 100644 --- a/src/Lean/Expr.lean +++ b/src/Lean/Expr.lean @@ -1332,6 +1332,14 @@ lambda expression. See docstring for `betaRev` for examples. def beta (f : Expr) (args : Array Expr) : Expr := betaRev f args.reverse +/-- +Count the number of lambdas at the head of the given expression. +-/ +def getNumHeadLambdas : Expr → Nat + | .lam _ _ b _ => getNumHeadLambdas b + 1 + | .mdata _ b => getNumHeadLambdas b + | _ => 0 + /-- Return true if the given expression is the function of an expression that is target for (head) beta reduction. If `useZeta = true`, then `let`-expressions are visited. That is, it assumes diff --git a/src/Lean/Meta/Tactic/Simp/Main.lean b/src/Lean/Meta/Tactic/Simp/Main.lean index 3a5498261960..b7aa54fe4a42 100644 --- a/src/Lean/Meta/Tactic/Simp/Main.lean +++ b/src/Lean/Meta/Tactic/Simp/Main.lean @@ -194,6 +194,23 @@ private def unfold? (e : Expr) : SimpM (Option Expr) := do if eNew == e then return none trace[Meta.Tactic.simp.ground] "{e}\n---->\n{eNew}" return some eNew + let rec unfoldDeclToUnfold? : SimpM (Option Expr) := do + let options ← getOptions + let cfg ← getConfig + -- Support for issue #2042 + if cfg.unfoldPartialApp -- If we are unfolding partial applications, ignore issue #2042 + -- When smart unfolding is enabled, and `f` supports it, we don't need to worry about issue #2042 + || (smartUnfolding.get options && (← getEnv).contains (mkSmartUnfoldingNameFor fName)) then + withDefault <| unfoldDefinition? e + else + -- `We are not unfolding partial applications, and `fName` does not have smart unfolding support. + -- Thus, we must check whether the arity of the function >= number of arguments. + let some cinfo := (← getEnv).find? fName | return none + let some value := cinfo.value? | return none + let arity := value.getNumHeadLambdas + -- Partially applied function, return `none`. See issue #2042 + if arity > e.getAppNumArgs then return none + withDefault <| unfoldDefinition? e if let some eNew ← unfoldGround? then return some eNew else if (← isProjectionFn fName) then @@ -210,7 +227,7 @@ private def unfold? (e : Expr) : SimpM (Option Expr) := do else return none else if ctx.isDeclToUnfold fName then - withDefault <| unfoldDefinition? e + unfoldDeclToUnfold? else return none @@ -352,18 +369,18 @@ where simpStep (e : Expr) : M Result := do match e with - | Expr.mdata m e => let r ← simp e; return { r with expr := mkMData m r.expr } - | Expr.proj .. => simpProj e - | Expr.app .. => simpApp e - | Expr.lam .. => simpLambda e - | Expr.forallE .. => simpForall e - | Expr.letE .. => simpLet e - | Expr.const .. => simpConst e - | Expr.bvar .. => unreachable! - | Expr.sort .. => return { expr := e } - | Expr.lit .. => simpLit e - | Expr.mvar .. => return { expr := (← instantiateMVars e) } - | Expr.fvar .. => return { expr := (← reduceFVar (← getConfig) (← getSimpTheorems) e) } + | .mdata m e => let r ← simp e; return { r with expr := mkMData m r.expr } + | .proj .. => simpProj e + | .app .. => simpApp e + | .lam .. => simpLambda e + | .forallE .. => simpForall e + | .letE .. => simpLet e + | .const .. => simpConst e + | .bvar .. => unreachable! + | .sort .. => return { expr := e } + | .lit .. => simpLit e + | .mvar .. => return { expr := (← instantiateMVars e) } + | .fvar .. => return { expr := (← reduceFVar (← getConfig) (← getSimpTheorems) e) } simpLit (e : Expr) : M Result := do match e.natLit? with @@ -768,7 +785,7 @@ where return { expr := (← dsimp e) } simpLet (e : Expr) : M Result := do - let Expr.letE n t v b _ := e | unreachable! + let .letE n t v b _ := e | unreachable! if (← getConfig).zeta then return { expr := b.instantiate1 v } else diff --git a/tests/lean/run/2042.lean b/tests/lean/run/2042.lean new file mode 100644 index 000000000000..01fa29d12e15 --- /dev/null +++ b/tests/lean/run/2042.lean @@ -0,0 +1,21 @@ +@[simp] def foo (a : Nat) : Nat := + 2 * a + +example : foo = fun a => a + a := +by + fail_if_success simp -- should not unfold `foo` into a lambda + funext x + simp -- unfolds `foo` + trace_state + simp_arith + +@[simp] def boo : Nat → Nat + | a => 2 * a + +example : boo = fun a => a + a := +by + fail_if_success simp -- should not unfold `boo` into a lambda + funext x + simp -- unfolds `boo` + trace_state + simp_arith