Skip to content

Commit

Permalink
feat: isolate fixed prefix at well-founded recursion
Browse files Browse the repository at this point in the history
closes #1017
  • Loading branch information
leodemoura committed Feb 18, 2022
1 parent 75e771b commit e61d0be
Show file tree
Hide file tree
Showing 12 changed files with 114 additions and 30 deletions.
10 changes: 10 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,13 @@ example (a : α) (x : Fam α) : α :=
```

* We now use `PSum` (instead of `Sum`) when compiling mutually recursive definitions using well-founded recursion.

* Better support for parametric well-founded relations. See [issue #1017](https://github.com/leanprover/lean4/issues/1017). This change affects the low-level `termination_by'` hint because the fixed prefix of the function parameters in not "packed" anymore when constructing the well-founded relation type. For example, in the following definition, `as` is part of the fixed prefix, and is not packed anymore. In previous versions, the `termination_by'` term would be written as `measure fun ⟨as, i, _⟩ => as.size - i`
```lean
def sum (as : Array Nat) (i : Nat) (s : Nat) : Nat :=
if h : i < as.size then
sum as (i+1) (s + as.get ⟨i, h⟩)
else
s
termination_by' measure fun ⟨i, _⟩ => as.size - i
```
20 changes: 9 additions & 11 deletions src/Lean/Elab/PreDefinition/WF/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@ namespace Lean.Elab
open WF
open Meta

private def isOnlyOneUnaryDef (preDefs : Array PreDefinition) : MetaM Bool := do
private def isOnlyOneUnaryDef (preDefs : Array PreDefinition) (fixedPrefixSize : Nat) : MetaM Bool := do
if preDefs.size == 1 then
lambdaTelescope preDefs[0].value fun xs _ => return xs.size == 1
lambdaTelescope preDefs[0].value fun xs _ => return xs.size == fixedPrefixSize + 1
else
return false

private partial def addNonRecPreDefs (preDefs : Array PreDefinition) (preDefNonRec : PreDefinition) : TermElabM Unit := do
if (← isOnlyOneUnaryDef preDefs) then
private partial def addNonRecPreDefs (preDefs : Array PreDefinition) (preDefNonRec : PreDefinition) (fixedPrefixSize : Nat) : TermElabM Unit := do
if (← isOnlyOneUnaryDef preDefs fixedPrefixSize) then
return ()
let Expr.forallE _ domain _ _ := preDefNonRec.type | unreachable!
let us := preDefNonRec.levelParams.map mkLevelParam
for fidx in [:preDefs.size] do
let preDef := preDefs[fidx]
let value ← lambdaTelescope preDef.value fun xs _ => do
let packedArgs : Array Expr := xs[fixedPrefixSize:]
let mkProd (type : Expr) : MetaM Expr := do
mkUnaryArg type xs
mkUnaryArg type packedArgs
let rec mkSum (i : Nat) (type : Expr) : MetaM Expr := do
if i == preDefs.size - 1 then
mkProd type
Expand All @@ -42,8 +42,9 @@ private partial def addNonRecPreDefs (preDefs : Array PreDefinition) (preDefNonR
else
let r ← mkSum (i+1) args[1]
return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0] args[1] r
let Expr.forallE _ domain _ _ := (← instantiateForall preDefNonRec.type xs[:fixedPrefixSize]) | unreachable!
let arg ← mkSum 0 domain
mkLambdaFVars xs (mkApp (mkConst preDefNonRec.declName us) arg)
mkLambdaFVars xs (mkApp (mkAppN (mkConst preDefNonRec.declName us) xs[:fixedPrefixSize]) arg)
trace[Elab.definition.wf] "{preDef.declName} := {value}"
addNonRec { preDef with value } (applyAttrAfterCompilation := false)

Expand Down Expand Up @@ -80,14 +81,11 @@ def getFixedPrefix (preDefs : Array PreDefinition) : TermElabM Nat :=
def wfRecursion (preDefs : Array PreDefinition) (wf? : Option TerminationWF) (decrTactic? : Option Syntax) : TermElabM Unit := do
let fixedPrefixSize ← getFixedPrefix preDefs
trace[Elab.definition.wf] "fixed prefix: {fixedPrefixSize}"
let fixedPrefixSize := 0 -- TODO: remove after we convert all code in this module to use the fixedPrefix
let unaryPreDef ← withoutModifyingEnv do
for preDef in preDefs do
addAsAxiom preDef
let unaryPreDefs ← packDomain fixedPrefixSize preDefs
-- unaryPreDefs.forM fun d => do trace[Elab.definition.wf] "packDomain result {d.declName} := {d.value}"; check d.value
packMutual fixedPrefixSize unaryPreDefs
-- trace[Elab.definition.wf] "after packMutual {unaryPreDef.declName} := {unaryPreDef.value}"
let preDefNonRec ← forallBoundedTelescope unaryPreDef.type fixedPrefixSize fun prefixArgs type => do
let packedArgType := type.bindingDomain!
let wfRel ← elabWFRel preDefs unaryPreDef.declName fixedPrefixSize packedArgType wf?
Expand All @@ -100,7 +98,7 @@ def wfRecursion (preDefs : Array PreDefinition) (wf? : Option TerminationWF) (de
trace[Elab.definition.wf] ">> {preDefNonRec.declName} :=\n{preDefNonRec.value}"
let preDefs ← preDefs.mapM fun d => eraseRecAppSyntax d
addNonRec preDefNonRec
addNonRecPreDefs preDefs preDefNonRec
addNonRecPreDefs preDefs preDefNonRec fixedPrefixSize
registerEqnsInfo preDefs preDefNonRec.declName
for preDef in preDefs do
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
Expand Down
10 changes: 8 additions & 2 deletions src/Lean/Elab/PreDefinition/WF/Rel.lean
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,22 @@ private partial def unpackMutual (preDefs : Array PreDefinition) (mvarId : MVarI
private partial def unpackUnary (preDef : PreDefinition) (prefixSize : Nat) (mvarId : MVarId) (fvarId : FVarId) (element : TerminationByElement) : TermElabM MVarId := do
let varNames ← lambdaTelescope preDef.value fun xs body => do
let mut varNames ← xs.mapM fun x => return (← getLocalDecl x.fvarId!).userName
if element.vars.size > varNames.size - prefixSize then
if element.vars.size > varNames.size then
throwErrorAt element.vars[varNames.size] "too many variable names"
for i in [:element.vars.size] do
let varStx := element.vars[i]
if varStx.isIdent then
varNames := varNames.set! (varNames.size - element.vars.size + i) varStx.getId
return varNames
let mut mvarId := mvarId
for localDecl in (← Term.getMVarDecl mvarId).lctx, varName in varNames[:prefixSize] do
unless localDecl.userName == varName do
mvarId ← rename mvarId localDecl.fvarId varName
let numPackedArgs := varNames.size - prefixSize
let rec go (i : Nat) (mvarId : MVarId) (fvarId : FVarId) : TermElabM MVarId := do
trace[Elab.definition.wf] "i: {i}, varNames: {varNames}, goal: {mvarId}"
if i < numPackedArgs - 1 then
let #[s] ← cases mvarId fvarId #[{ varNames := [varNames[i]] }] | unreachable!
let #[s] ← cases mvarId fvarId #[{ varNames := [varNames[prefixSize + i]] }] | unreachable!
go (i+1) s.mvarId s.fields[1].fvarId!
else
rename mvarId fvarId varNames.back
Expand All @@ -52,6 +57,7 @@ def elabWFRel (preDefs : Array PreDefinition) (unaryPreDefName : Name) (fixedPre
let α := argType
let u ← getLevel α
let expectedType := mkApp (mkConst ``WellFoundedRelation [u]) α
trace[Elab.definition.wf] "elabWFRel start: {(← mkFreshTypeMVar).mvarId!}"
match wf? with
| some (TerminationWF.core wfStx) => withDeclName unaryPreDefName do
let wfRel ← instantiateMVars (← withSynthesize <| elabTermEnsuringType wfStx expectedType)
Expand Down
2 changes: 1 addition & 1 deletion tests/lean/letRecMissingAnnotation.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ def sum (as : Array Nat) : Nat :=
else
s
go 0 0
termination_by' measure (fun ⟨a, i, _⟩ => a.size - i)
termination_by' measure (fun ⟨i, _⟩ => as.size - i)
55 changes: 55 additions & 0 deletions tests/lean/run/1017.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
namespace Stream

variable [Stream ρ τ] (s : ρ)

def take (s : ρ) : Nat → List τ × ρ
| 0 => ([], s)
| n+1 =>
match next? s with
| none => ([], s)
| some (x,rest) =>
let (L,rest) := take rest n
(x::L, rest)

def isEmpty : Bool :=
Option.isNone (next? s)

def lengthBoundedBy (n : Nat) : Prop :=
isEmpty (take s n).2

def hasNext : ρ → ρ → Prop
:= λ s1 s2 => ∃ x, next? s1 = some ⟨x,s2⟩

def isFinite : Prop :=
∃ n, lengthBoundedBy s n

instance hasNextWF : WellFoundedRelation {s : ρ // isFinite s} where
rel := λ s1 s2 => hasNext s2.val s1.val
wf := ⟨λ ⟨s,h⟩ => ⟨⟨s,h⟩, by
simp
cases h; case intro w h =>
induction w generalizing s
case zero =>
intro ⟨s',h'⟩ h_next
simp [hasNext] at h_next
cases h_next; case intro x h_next =>
simp [lengthBoundedBy, isEmpty, Option.isNone, take, h_next] at h
case succ n ih =>
intro ⟨s',h'⟩ h_next
simp [hasNext] at h_next
cases h_next; case intro x h_next =>
simp [lengthBoundedBy, take, h_next] at h
have := ih s' h
exact Acc.intro (⟨s',h'⟩ : {s : ρ // isFinite s}) this
⟩⟩

def mwe [Stream ρ τ] (acc : α) : {l : ρ // isFinite l} → α
| ⟨l,h⟩ =>
match h:next? l with
| none => acc
| some (x,xs) =>
have h_next : hasNext l xs := by exists x; simp [hasNext, h]
mwe acc ⟨xs, by sorry
termination_by _ l => l

end Stream
6 changes: 3 additions & 3 deletions tests/lean/run/mutwf1.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ end
termination_by'
invImage
(fun
| PSum.inl ⟨_, n, _, _⟩ => (n, 2)
| PSum.inr <| PSum.inl ⟨_, _, n, _⟩ => (n, 1)
| PSum.inr <| PSum.inr ⟨_, _, _, n⟩ => (n, 0))
| PSum.inl ⟨n, _, _⟩ => (n, 2)
| PSum.inr <| PSum.inl ⟨_, n, _⟩ => (n, 1)
| PSum.inr <| PSum.inr ⟨_, _, n⟩ => (n, 0))
(Prod.lex sizeOfWFRel sizeOfWFRel)

#print f
Expand Down
12 changes: 6 additions & 6 deletions tests/lean/run/mutwf3.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ end
termination_by'
invImage
(fun
| PSum.inl ⟨_, n, _, _⟩ => (n, 2)
| PSum.inr <| PSum.inl ⟨_, _, n, _⟩ => (n, 1)
| PSum.inr <| PSum.inr ⟨_, _, _, n⟩ => (n, 0))
| PSum.inl ⟨n, _, _⟩ => (n, 2)
| PSum.inr <| PSum.inl ⟨_, n, _⟩ => (n, 1)
| PSum.inr <| PSum.inr ⟨_, _, n⟩ => (n, 0))
(Prod.lex sizeOfWFRel sizeOfWFRel)
decreasing_by
simp [invImage, InvImage, Prod.lex, sizeOfWFRel, measure, Nat.lt_wfRel, WellFoundedRelation.rel]
Expand Down Expand Up @@ -51,9 +51,9 @@ end
termination_by'
invImage
(fun
| PSum.inl ⟨_, n, _, _⟩ => (n, 2)
| PSum.inr <| PSum.inl ⟨_, _, n, _⟩ => (n, 1)
| PSum.inr <| PSum.inr ⟨_, _, _, n⟩ => (n, 0))
| PSum.inl ⟨n, _, _⟩ => (n, 2)
| PSum.inr <| PSum.inl ⟨_, n, _⟩ => (n, 1)
| PSum.inr <| PSum.inr ⟨_, _, n⟩ => (n, 0))
(Prod.lex sizeOfWFRel sizeOfWFRel)

#print f._unary._mutual
Expand Down
6 changes: 6 additions & 0 deletions tests/lean/run/renameFixedPrefix.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def f (as : Array Nat) (hsz : as.size > 0) (i : Nat) : Nat :=
if h : i < as.size then
as.get ⟨i, h⟩ + f as hsz (i + 1)
else
0
termination_by f a h i => a.size - i
4 changes: 2 additions & 2 deletions tests/lean/run/wfEqns2.lean
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ end
termination_by'
invImage
(fun
| PSum.inl ⟨_, n⟩ => (n, 0)
| PSum.inr ⟨_, n⟩ => (n, 1))
| PSum.inl n => (n, 0)
| PSum.inr n => (n, 1))
(Prod.lex sizeOfWFRel sizeOfWFRel)
decreasing_by
simp [invImage, InvImage, Prod.lex, sizeOfWFRel, measure, Nat.lt_wfRel, WellFoundedRelation.rel]
Expand Down
6 changes: 3 additions & 3 deletions tests/lean/run/wfEqns4.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ end
termination_by'
invImage
(fun
| PSum.inl ⟨_, n, _, _⟩ => (n, 2)
| PSum.inr <| PSum.inl ⟨_, _, n, _⟩ => (n, 1)
| PSum.inr <| PSum.inr ⟨_, _, _, n⟩ => (n, 0))
| PSum.inl ⟨n, _, _⟩ => (n, 2)
| PSum.inr <| PSum.inl ⟨_, n, _⟩ => (n, 1)
| PSum.inr <| PSum.inr ⟨_, _, n⟩ => (n, 0))
(Prod.lex sizeOfWFRel sizeOfWFRel)
decreasing_by
simp [invImage, InvImage, Prod.lex, sizeOfWFRel, measure, Nat.lt_wfRel, WellFoundedRelation.rel]
Expand Down
9 changes: 9 additions & 0 deletions tests/lean/run/wfSum.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
def sum (as : Array Nat) : Nat :=
go 0 0
where
go (i : Nat) (s : Nat) : Nat :=
if h : i < as.size then
go (i+1) (s + as.get ⟨i, h⟩)
else
s
termination_by' measure (fun ⟨i, s⟩ => as.size - i)
4 changes: 2 additions & 2 deletions tests/lean/substBadMotive.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace Ex3
def heapifyDown (a : Array α) (i : Fin a.size) : Array α :=
have : i < i := sorry
heapifyDown a ⟨i.1, a.size_swap i i ▸ i.2-- Error, failed to compute motive, `subst` is not applicable here
termination_by' measure fun ⟨_, a, i⟩ => i.1
termination_by' measure fun i => i.1
decreasing_by assumption
end Ex3

Expand All @@ -34,6 +34,6 @@ def heapifyDown (lt : α → α → Bool) (a : Array α) (i : Fin a.size) : Arra
heapifyDown lt a' ⟨j.1.1, a.size_swap i j ▸ j.1.2-- Error, failed to compute motive, `subst` is not applicable here
else
a
termination_by' measure fun_, _, a, i⟩ => a.size - i.1
termination_by' measure fun ⟨a, i⟩ => a.size - i.1
decreasing_by assumption
end Ex4

0 comments on commit e61d0be

Please sign in to comment.