Skip to content

Commit

Permalink
fix: fold raw Nat literals at dsimp (#3624)
Browse files Browse the repository at this point in the history
closes #2916

Remark: this PR also renames `Expr.natLit?` ==> `Expr.rawNatLit?`.
Motivation: consistent naming convention: `Expr.isRawNatLit`.
  • Loading branch information
leodemoura authored Mar 6, 2024
1 parent 46cc00d commit 5302b78
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/Lean/Expr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ def isRawNatLit : Expr → Bool
| lit (Literal.natVal _) => true
| _ => false

def natLit? : Expr → Option Nat
def rawNatLit? : Expr → Option Nat
| lit (Literal.natVal v) => v
| _ => none

Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/Reduce.lean
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ partial def reduce (e : Expr) (explicitOnly skipTypes skipProofs := true) : Meta
else
args ← args.modifyM i visit
if f.isConstOf ``Nat.succ && args.size == 1 && args[0]!.isRawNatLit then
return mkRawNatLit (args[0]!.natLit?.get! + 1)
return mkRawNatLit (args[0]!.rawNatLit?.get! + 1)
else
return mkAppN f args
| Expr.lam .. => lambdaTelescope e fun xs b => do mkLambdaFVars xs (← visit b)
Expand Down
48 changes: 33 additions & 15 deletions src/Lean/Meta/Tactic/Simp/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ def Config.updateArith (c : Config) : CoreM Config := do
def isOfNatNatLit (e : Expr) : Bool :=
e.isAppOfArity ``OfNat.ofNat 3 && e.appFn!.appArg!.isRawNatLit

/--
If `e` is a raw Nat literal and `OfNat.ofNat` is not in the list of declarations to unfold,
return an `OfNat.ofNat`-application.
-/
def foldRawNatLit (e : Expr) : SimpM Expr := do
match e.rawNatLit? with
| some n =>
/- If `OfNat.ofNat` is marked to be unfolded, we do not pack orphan nat literals as `OfNat.ofNat` applications
to avoid non-termination. See issue #788. -/
if (← readThe Simp.Context).isDeclToUnfold ``OfNat.ofNat then
return e
else
return toExpr n
| none => return e

private def reduceProjFn? (e : Expr) : SimpM (Option Expr) := do
matchConst e.getAppFn (fun _ => pure none) fun cinfo _ => do
match (← getProjectionFnInfo? cinfo.name) with
Expand Down Expand Up @@ -179,7 +194,7 @@ private def reduceStep (e : Expr) : SimpM Expr := do
trace[Meta.Tactic.simp.rewrite] "unfold {mkConst e.getAppFn.constName!}, {e} ==> {e'}"
recordSimpTheorem (.decl e.getAppFn.constName!)
return e'
| none => return e
| none => foldRawNatLit e

private partial def reduce (e : Expr) : SimpM Expr := withIncRecDepth do
let e' ← reduceStep e
Expand Down Expand Up @@ -233,17 +248,6 @@ def withNewLemmas {α} (xs : Array Expr) (f : SimpM α) : SimpM α := do
else
f

def simpLit (e : Expr) : SimpM Result := do
match e.natLit? with
| some n =>
/- If `OfNat.ofNat` is marked to be unfolded, we do not pack orphan nat literals as `OfNat.ofNat` applications
to avoid non-termination. See issue #788. -/
if (← readThe Simp.Context).isDeclToUnfold ``OfNat.ofNat then
return { expr := e }
else
return { expr := (← mkNumeral (mkConst ``Nat) n) }
| none => return { expr := e }

def simpProj (e : Expr) : SimpM Result := do
match (← reduceProj? e) with
| some e => return { expr := e }
Expand Down Expand Up @@ -406,13 +410,27 @@ private def dsimpReduce : DSimproc := fun e => do
eNew ← reduceFVar (← getConfig) (← getSimpTheorems) eNew
if eNew != e then return .visit eNew else return .done e

/--
Auliliary `dsimproc` for not visiting `OfNat.ofNat` application subterms.
This is the `dsimp` equivalent of the approach used at `visitApp`.
Recall that we fold orphan raw Nat literals.
-/
private def doNotVisitOfNat : DSimproc := fun e => do
if isOfNatNatLit e then
if (← readThe Simp.Context).isDeclToUnfold ``OfNat.ofNat then
return .continue e
else
return .done e
else
return .continue e

@[export lean_dsimp]
private partial def dsimpImpl (e : Expr) : SimpM Expr := do
let cfg ← getConfig
unless cfg.dsimp do
return e
let m ← getMethods
let pre := m.dpre
let pre := m.dpre >> doNotVisitOfNat
let post := m.dpost >> dsimpReduce
transform (usedLetOnly := cfg.zeta) e (pre := pre) (post := post)

Expand Down Expand Up @@ -533,7 +551,7 @@ def congr (e : Expr) : SimpM Result := do

def simpApp (e : Expr) : SimpM Result := do
if isOfNatNatLit e then
-- Recall that we expand "orphan" kernel nat literals `n` into `ofNat n`
-- Recall that we expand "orphan" kernel Nat literals `n` into `OfNat.ofNat n`
return { expr := e }
else
congr e
Expand All @@ -549,7 +567,7 @@ def simpStep (e : Expr) : SimpM Result := do
| .const .. => simpConst e
| .bvar .. => unreachable!
| .sort .. => return { expr := e }
| .lit .. => simpLit e
| .lit .. => return { expr := e }
| .mvar .. => return { expr := (← instantiateMVars e) }
| .fvar .. => return { expr := (← reduceFVar (← getConfig) (← getSimpTheorems) e) }

Expand Down
77 changes: 77 additions & 0 deletions tests/lean/run/2916.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
set_option pp.coercions false -- Show `OfNat.ofNat` when present for clarity

/--
warning: declaration uses 'sorry'
---
info: x : Nat
⊢ OfNat.ofNat 2 = x
-/
#guard_msgs in
example : nat_lit 2 = x := by
simp only
trace_state
sorry

/--
warning: declaration uses 'sorry'
---
info: x : Nat
⊢ OfNat.ofNat 2 = x
-/
#guard_msgs in
example : nat_lit 2 = x := by
dsimp only -- dsimp made no progress
trace_state
sorry

/--
warning: declaration uses 'sorry'
---
info: α : Nat → Type
f : (n : Nat) → α n
x : α (OfNat.ofNat 2)
⊢ f (OfNat.ofNat 2) = x
-/
#guard_msgs in
example (α : Nat → Type) (f : (n : Nat) → α n) (x : α 2) : f (nat_lit 2) = x := by
simp only
trace_state
sorry

/--
info: x : Nat
f : Nat → Nat
h : f (OfNat.ofNat 2) = x
⊢ f (OfNat.ofNat 2) = x
---
info: x : Nat
f : Nat → Nat
h : f (OfNat.ofNat 2) = x
⊢ f 2 = x
-/
#guard_msgs in
example (f : Nat → Nat) (h : f 2 = x) : f 2 = x := by
trace_state
simp [OfNat.ofNat]
trace_state
assumption

/--
warning: declaration uses 'sorry'
---
info: α : Nat → Type
f : (n : Nat) → α n
x : α (OfNat.ofNat 2)
⊢ f (OfNat.ofNat 2) = x
---
info: α : Nat → Type
f : (n : Nat) → α n
x : α (OfNat.ofNat 2)
⊢ f 2 = x
-/
#guard_msgs in
example (α : Nat → Type) (f : (n : Nat) → α n) (x : α 2) : f 2 = x := by
trace_state
simp [OfNat.ofNat]
trace_state
sorry
4 changes: 2 additions & 2 deletions tests/lean/run/maze.lean
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def extractXY : Lean.Expr → Lean.MetaM Coords
let sizeArgs := Lean.Expr.getAppArgs e'
let x ← Lean.Meta.whnf sizeArgs[0]!
let y ← Lean.Meta.whnf sizeArgs[1]!
let numCols := (Lean.Expr.natLit? x).get!
let numRows := (Lean.Expr.natLit? y).get!
let numCols := (Lean.Expr.rawNatLit? x).get!
let numRows := (Lean.Expr.rawNatLit? y).get!
return Coords.mk numCols numRows

partial def extractWallList : Lean.Expr → Lean.MetaM (List Coords)
Expand Down
2 changes: 1 addition & 1 deletion tests/lean/run/meta2.lean
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ check t;
(match t.arrayLit? with
| some (_, xs) => do
checkM $ pure $ xs.length == 2;
(match (xs.get! 0).natLit?, (xs.get! 1).natLit? with
(match (xs.get! 0).rawNatLit?, (xs.get! 1).rawNatLit? with
| some 1, some 2 => pure ()
| _, _ => throwError "nat lits expected")
| none => throwError "array lit expected")
Expand Down

0 comments on commit 5302b78

Please sign in to comment.