Skip to content

Commit

Permalink
different fix
Browse files Browse the repository at this point in the history
  • Loading branch information
llllvvuu committed Jun 10, 2024
1 parent 8890d9b commit 80affa4
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 32 deletions.
53 changes: 30 additions & 23 deletions src/Lean/Meta/ExprDefEq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ Note that `fun xs => v` is the term `fun (x1 : Nat) (x2 : b = x1) => v` which ha
and may not even be type correct.
The issue here is that we are not capturing the `let`-declarations.
This method collects let-declarations `y` occurring before `xs.back` s.t.
This method collects let-declarations `y` occurring between `xs[0]` and `xs.back` s.t.
some `x` in `xs` depends on `y`.
`ys` is the `xs` with these extra let-declarations included.
Expand Down Expand Up @@ -444,9 +444,9 @@ where
return check (← getLCtx)

/-- Traverse `e` and stores in the state `NameHashSet` any let-declaration with index greater than `(← read)`.
The context `Nat` is the position of `ys[0]` in the local context. -/
collectLetDeclsFrom (e : Expr) : StateRefT (Nat × FVarIdHashSet) MetaM Unit := do
let rec visit (e : Expr) : MonadCacheT Expr Unit (StateRefT (Nat × FVarIdHashSet) MetaM) Unit :=
The context `Nat` is the position of `xs[0]` in the local context. -/
collectLetDeclsFrom (e : Expr) : ReaderT Nat (StateRefT FVarIdHashSet MetaM) Unit := do
let rec visit (e : Expr) : MonadCacheT Expr Unit (ReaderT Nat (StateRefT FVarIdHashSet MetaM)) Unit :=
checkCache e fun _ => do
match e with
| .forallE _ d b _ => visit d; visit b
Expand All @@ -456,50 +456,57 @@ where
| .mdata _ b => visit b
| .proj _ _ b => visit b
| .fvar fvarId =>
match ← fvarId.getDecl with
| .ldecl (index := index) (fvarId := fvarId) .. =>
modify fun s => (s.1.min index, s.2.insert fvarId)
| _ => pure ()
let localDecl ← fvarId.getDecl
if localDecl.isLet && localDecl.index > (← read) then
modify fun s => s.insert localDecl.fvarId
| _ => pure ()
visit (← instantiateMVars e) |>.run

/--
Traverse all declarations between `ys[0]` ... `xs.back` backwards.
Auxiliary definition for traversing all declarations between `xs[0]` ... `xs.back` backwards.
The `Nat` argument is the current position in the local context being visited, and it is less than
or equal to the position of `xs.back` in the local context.
The `Nat` context `(← get).1` is the position of `ys[0]` in the local context.
The `Nat` context `(← read)` is the position of `xs[0]` in the local context.
-/
collectLetDeps : Nat → StateRefT (Nat × FVarIdHashSet) MetaM Unit
collectLetDepsAux : Nat → ReaderT Nat (StateRefT FVarIdHashSet MetaM) Unit
| 0 => return ()
| i+1 => do
if i+2 == (← get).1 then
if i+1 == (← read) then
return ()
else
match (← getLCtx).getAt? (i+1) with
| none => collectLetDeps i
| none => collectLetDepsAux i
| some localDecl =>
if (← get).2.contains localDecl.fvarId then
if (← get).contains localDecl.fvarId then
collectLetDeclsFrom localDecl.type
match localDecl.value? with
| some val => collectLetDeclsFrom val
| _ => pure ()
collectLetDeps i
collectLetDepsAux i

/-- Computes the array `ys` containing let-decls that
some `x` in `xs` depends on, either directly or transitively. -/
addLetDeps : MetaM (Array Expr) := do
/-- Computes the set `ys`. It is a set of `FVarId`s, -/
collectLetDeps : MetaM FVarIdHashSet := do
let lctx ← getLCtx
let start := lctx.getFVar! xs[0]! |>.index
let stop := lctx.getFVar! xs.back |>.index
let s := (start, xs.foldl (init := {}) fun s x => s.insert x.fvarId!)
let (_, s) ← collectLetDeps stop |>.run s
let stop := lctx.getFVar! xs.back |>.index
let s := xs.foldl (init := {}) fun s x => s.insert x.fvarId!
let (_, s) ← collectLetDepsAux stop |>.run start |>.run s
return s

/-- Computes the array `ys` containing let-decls between `xs[0]` and `xs.back` that
some `x` in `xs` depends on. -/
addLetDeps : MetaM (Array Expr) := do
let lctx ← getLCtx
let s ← collectLetDeps
/- Convert `s` into the array `ys` -/
let start := lctx.getFVar! xs[0]! |>.index
let stop := lctx.getFVar! xs.back |>.index
let mut ys := #[]
for i in [s.1:stop+1] do
for i in [start:stop+1] do
match lctx.getAt? i with
| none => pure ()
| some localDecl =>
if s.2.contains localDecl.fvarId then
if s.contains localDecl.fvarId then
ys := ys.push localDecl.toExpr
return ys

Expand Down
8 changes: 6 additions & 2 deletions src/Lean/MetavarContext.lean
Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,12 @@ partial def instantiateExprMVars [Monad m] [MonadMCtx m] [STWorld ω m] [MonadLi
let wasMVar := f.isMVar
let f ← instantiateExprMVars f
if wasMVar && f.isLambda then
/- Some of the arguments in args are irrelevant after we beta reduce. -/
instantiateExprMVars (f.betaRev args.reverse)
/- Some of the arguments in `args` are irrelevant after we beta
reduce. Also, it may be a bug to not instantiate them, since they
may depend on free variables that are not in the context (see
issue #4375). So we pass `useZeta := true` to ensure that they are
instantiated. -/
instantiateExprMVars (f.betaRev args.reverse (useZeta := true))
else
instArgs f
match f with
Expand Down
12 changes: 5 additions & 7 deletions tests/lean/4375.lean
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
example: Nat :=
let n: Nat := Nat.zero
let a: Nat := n
have: True := trivial
let b: Nat := Nat.zero
have: a = b := Eq.refl a
(fun (_ : Nat) => 0) Nat.zero
example: True → Nat :=
let a : Nat := Nat.zero
fun (_ : True) =>
let b : Nat := Nat.zero
(fun (_ : a = b) => 0) (Eq.refl a)

0 comments on commit 80affa4

Please sign in to comment.