Skip to content

Commit

Permalink
fix: PackMutual: Eta-Expand as needed
Browse files Browse the repository at this point in the history
The `packMutual` code ought to reliably replace all recursive calls to
the functions in `preDefs`, even when they are under- or over-applied.
Therefore eta-expand if need rsp. keep extra arguments around.

Needs a tweak to `Meta.transform` to avoid mistaking the `f` in
`f x1 x2` as a zero-arity application.

Includes a test case.

This fixes #2628 and #2883.
  • Loading branch information
nomeata committed Nov 17, 2023
1 parent cafff2a commit 8ecf2e1
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 23 deletions.
55 changes: 35 additions & 20 deletions src/Lean/Elab/PreDefinition/WF/PackMutual.lean
Original file line number Diff line number Diff line change
Expand Up @@ -90,36 +90,51 @@ private partial def packValues (x : Expr) (codomain : Expr) (preDefValues : Arra
go mvar.mvarId! x.fvarId! 0
instantiateMVars mvar

/--
Pass the first `n` arguments of `e` to the continuation, and apply the result to the
remaining arguments. If `e` does not have enough arguments, it is eta-expanded as needed.
-/
def withAppN (n : Nat) (e : Expr) (k : Array Expr → MetaM Expr) : MetaM Expr := do
let args := e.getAppArgs
if n ≤ args.size then
let e' ← k args[:n]
return mkAppN e' args[n:]
else
withDefault do -- TODO: Copied from `etaExpand`. Needed? Harmful?
let missing := n - args.size
forallBoundedTelescope (← inferType e) missing fun xs _ => do
if xs.size < missing then
throwError "Failed to eta-expand partial application"
let e' ← k (args ++ xs)
mkLambdaFVars xs e'

/--
Auxiliary function for replacing nested `preDefs` recursive calls in `e` with the new function `newFn`.
See: `packMutual`
-/
private partial def post (fixedPrefix : Nat) (preDefs : Array PreDefinition) (domain : Expr) (newFn : Name) (e : Expr) : MetaM TransformStep := do
if e.getAppNumArgs < fixedPrefix + 1 then
return TransformStep.done e
let f := e.getAppFn
if !f.isConst then
return TransformStep.done e
let declName := f.constName!
let us := f.constLevels!
if let some fidx := preDefs.findIdx? (·.declName == declName) then
let args := e.getAppArgs
let fixedArgs := args[:fixedPrefix]
let arg := args[fixedPrefix]!
let remaining := args[fixedPrefix+1:]
let rec mkNewArg (i : Nat) (type : Expr) : MetaM Expr := do
if i == preDefs.size - 1 then
return arg
else
(← whnfD type).withApp fun f args => do
assert! args.size == 2
if i == fidx then
return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0]! args[1]! arg
else
let r ← mkNewArg (i+1) args[1]!
return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0]! args[1]! r
return TransformStep.done <|
mkAppN (mkApp (mkAppN (mkConst newFn us) fixedArgs) (← mkNewArg 0 domain)) remaining
let e' ← withAppN (fixedPrefix + 1) e fun args => do
let fixedArgs := args[:fixedPrefix]
let arg := args[fixedPrefix]!
let rec mkNewArg (i : Nat) (type : Expr) : MetaM Expr := do
if i == preDefs.size - 1 then
return arg
else
(← whnfD type).withApp fun f args => do
assert! args.size == 2
if i == fidx then
return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0]! args[1]! arg
else
let r ← mkNewArg (i+1) args[1]!
return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0]! args[1]! r
return mkApp (mkAppN (mkConst newFn us) fixedArgs) (← mkNewArg 0 domain)
return TransformStep.done e'
return TransformStep.done e

partial def withFixedPrefix (fixedPrefix : Nat) (preDefs : Array PreDefinition) (k : Array Expr → Array Expr → Array Expr → MetaM α) : MetaM α :=
Expand Down Expand Up @@ -185,7 +200,7 @@ def packMutual (fixedPrefix : Nat) (preDefsOriginal : Array PreDefinition) (preD
let newFn := preDefs[0]!.declName ++ `_mutual
let preDefNew := { preDefs[0]! with declName := newFn, type, value }
addAsAxiom preDefNew
let value ← transform value (post := post fixedPrefix preDefs domain newFn)
let value ← transform value (constIsApp := true) (post := post fixedPrefix preDefs domain newFn)
let value ← mkLambdaFVars (ys.push x) value
return { preDefNew with value }

Expand Down
13 changes: 11 additions & 2 deletions src/Lean/Meta/Transform.lean
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,18 @@ namespace Meta

/--
Similar to `Core.transform`, but terms provided to `pre` and `post` do not contain loose bound variables.
So, it is safe to use any `MetaM` method at `pre` and `post`. -/
So, it is safe to use any `MetaM` method at `pre` and `post`.
If `constIsApp := true`, then for an expression `mkAppN (.const f) args`, the subexpression
`.const f` is not visited again. Put differently: every `.const f` is visited once, with its
arguments.
-/
partial def transform {m} [Monad m] [MonadLiftT MetaM m] [MonadControlT MetaM m] [MonadTrace m] [MonadRef m] [MonadOptions m] [AddMessageContext m]
(input : Expr)
(pre : Expr → m TransformStep := fun _ => return .continue)
(post : Expr → m TransformStep := fun e => return .done e)
(usedLetOnly := false)
(constIsApp := false)
: m Expr := do
let _ : STWorld IO.RealWorld m := ⟨⟩
let _ : MonadLiftT (ST IO.RealWorld) m := { monadLift := fun x => liftM (m := MetaM) (liftM (m := ST IO.RealWorld) x) }
Expand Down Expand Up @@ -109,7 +115,10 @@ partial def transform {m} [Monad m] [MonadLiftT MetaM m] [MonadControlT MetaM m]
| e => visitPost (← mkLetFVars (usedLetOnly := usedLetOnly) fvars (← visit (e.instantiateRev fvars)))
let visitApp (e : Expr) : MonadCacheT ExprStructEq Expr m Expr :=
e.withApp fun f args => do
visitPost (mkAppN (← visit f) (← args.mapM visit))
if constIsApp && f.isConst then
visitPost (mkAppN f (← args.mapM visit))
else
visitPost (mkAppN (← visit f) (← args.mapM visit))
match (← pre e) with
| .done e => pure e
| .visit e => visit e
Expand Down
72 changes: 72 additions & 0 deletions tests/lean/run/issue2628.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/-!
Test that PackMutual isn't confused when a recursive call has more arguments than is packed
into the unary argument, which can happen if the retturn type is a function type.
-/

namespace Ex1
mutual
def foo : Nat → Nat
| .zero => 0
| .succ n => (id bar) n
def bar : Nat → Nat
| .zero => 0
| .succ n => foo n
end
termination_by foo n => n; bar n => n
decreasing_by sorry

end Ex1

-- Same for n-ary functions

opaque id' : ∀ {α}, α → α := id

namespace Ex2

mutual
def foo : Nat → Nat → Nat
| .zero, _m => 0
| .succ n, .zero => (id' (bar n)) .zero
| .succ n, m => (id' bar) n m
def bar : Nat → Nat → Nat
| .zero, _m => 0
| .succ n, m => foo n m
end
termination_by foo n m => (n,m); bar n m => (n,m)
decreasing_by sorry

end Ex2

-- With extra arguments

namespace Ex3
mutual
def foo : Nat → Nat → Nat
| .zero => fun _ => 0
| .succ n => fun m => (id bar) n m
def bar : Nat → Nat → Nat
| .zero => fun _ => 0
| .succ n => fun m => foo n m
end
termination_by foo n => n; bar n => n
decreasing_by sorry

end Ex3

-- n-ary and with extra arguments

namespace Ex4

mutual
def foo : Nat → Nat → Nat → Nat
| .zero, _m => fun _ => 0
| .succ n, .zero => fun k => (id' (bar n)) .zero k
| .succ n, m => fun k => (id' bar) n m k
def bar : Nat → Nat → Nat → Nat
| .zero, _m => fun _ => 0
| .succ n, m => fun k => foo n m k
end
termination_by foo n m => (n,m); bar n m => (n,m)
decreasing_by sorry

end Ex4
2 changes: 1 addition & 1 deletion tests/lean/run/issue2883.lean
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/-!
Test that PackMutual isn't confused when a recursive call has more arguments than is packed
into the unary argument, which can happen if the retturn type is a function type.
into the unary argument, which can happen if the return type is a function type.
-/

mutual
Expand Down

0 comments on commit 8ecf2e1

Please sign in to comment.