Skip to content

Commit

Permalink
fix: cleanup type annotations in congruence theorems (#4185)
Browse files Browse the repository at this point in the history
  • Loading branch information
leodemoura authored May 15, 2024
1 parent f636168 commit 8204b79
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/Lean/Meta/CongrTheorems.lean
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ private def setBinderInfosD (ys : Array Expr) (lctx : LocalContext) : LocalConte

partial def mkHCongrWithArity (f : Expr) (numArgs : Nat) : MetaM CongrTheorem := do
let fType ← inferType f
forallBoundedTelescope fType numArgs fun xs _ =>
forallBoundedTelescope fType numArgs fun ys _ => do
forallBoundedTelescope fType numArgs (cleanupAnnotations := true) fun xs _ =>
forallBoundedTelescope fType numArgs (cleanupAnnotations := true) fun ys _ => do
if xs.size != numArgs then
throwError "failed to generate hcongr theorem, insufficient number of arguments"
else
Expand All @@ -80,8 +80,8 @@ where
if i < xs.size then
let x := xs[i]!
let y := ys[i]!
let xType := (← inferType x).consumeTypeAnnotations
let yType := (← inferType y).consumeTypeAnnotations
let xType := (← inferType x).cleanupAnnotations
let yType := (← inferType y).cleanupAnnotations
if xType == yType then
withLocalDeclD ((`e).appendIndexAfter (i+1)) (← mkEq x y) fun h =>
loop (i+1) (eqs.push h) (kinds.push CongrArgKind.eq)
Expand All @@ -98,9 +98,9 @@ where
else if let some (_, lhs, _, _) := type.heq? then
mkHEqRefl lhs
else
forallBoundedTelescope type (some 1) fun a type =>
forallBoundedTelescope type (some 1) (cleanupAnnotations := true) fun a type =>
let a := a[0]!
forallBoundedTelescope type (some 1) fun b motive =>
forallBoundedTelescope type (some 1) (cleanupAnnotations := true) fun b motive =>
let b := b[0]!
let type := type.bindingBody!.instantiate1 a
withLocalDeclD motive.bindingName! motive.bindingDomain! fun eqPr => do
Expand Down Expand Up @@ -159,7 +159,7 @@ private def hasCastLike (kinds : Array CongrArgKind) : Bool :=
kinds.any fun kind => kind matches CongrArgKind.cast || kind matches CongrArgKind.subsingletonInst

private def withNext (type : Expr) (k : Expr → Expr → MetaM α) : MetaM α := do
forallBoundedTelescope type (some 1) fun xs type => k xs[0]! type
forallBoundedTelescope type (some 1) (cleanupAnnotations := true) fun xs type => k xs[0]! type

/--
Test whether we should use `subsingletonInst` kind for instances which depend on `eq`.
Expand All @@ -182,7 +182,7 @@ private def getClassSubobjectMask? (f : Expr) : MetaM (Option (Array Bool)) := d
let .const declName _ := f | return none
let .ctorInfo val ← getConstInfo declName | return none
unless isClass (← getEnv) val.induct do return none
forallTelescopeReducing val.type fun xs _ => do
forallTelescopeReducing val.type (cleanupAnnotations := true) fun xs _ => do
let env ← getEnv
let mut mask := #[]
for i in [:xs.size] do
Expand Down Expand Up @@ -255,7 +255,7 @@ where
mk? (f : Expr) (info : FunInfo) (kinds : Array CongrArgKind) : MetaM (Option CongrTheorem) := do
try
let fType ← inferType f
forallBoundedTelescope fType kinds.size fun lhss _ => do
forallBoundedTelescope fType kinds.size (cleanupAnnotations := true) fun lhss _ => do
if lhss.size != kinds.size then return none
let rec go (i : Nat) (rhss : Array Expr) (eqs : Array (Option Expr)) (hyps : Array Expr) : MetaM CongrTheorem := do
if i == kinds.size then
Expand Down
51 changes: 51 additions & 0 deletions tests/lean/run/congrThm2.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import Lean

open Lean Meta

def genHCongr (declName : Name) : MetaM Unit := do
let info ← getConstInfo declName
let thm ← mkHCongr (mkConst declName <| info.levelParams.map Level.param)
IO.println (← ppExpr thm.type)

def genCongr (declName : Name) : MetaM Unit := do
let info ← getConstInfo declName
let some thm ← mkCongrSimp? (mkConst declName <| info.levelParams.map Level.param) | return ()
IO.println (← ppExpr thm.type)

/--
info: ∀ (coll coll' : Type u),
coll = coll' →
∀ (idx idx' : Type v),
idx = idx' →
∀ (elem elem' : Type w),
elem = elem' →
∀ (valid : coll → idx → Prop) (valid' : coll' → idx' → Prop),
HEq valid valid' →
∀ (self : GetElem coll idx elem valid) (self' : GetElem coll' idx' elem' valid'),
HEq self self' →
∀ (xs : coll) (xs' : coll'),
HEq xs xs' →
∀ (i : idx) (i' : idx'),
HEq i i' → ∀ (h : valid xs i) (h' : valid' xs' i'), HEq h h' → HEq xs[i] xs'[i']
-/
#guard_msgs in
#eval genHCongr ``GetElem.getElem

/--
info: ∀ {coll : Type u} {idx : Type v} {elem : Type w} {valid : coll → idx → Prop} [self : GetElem coll idx elem valid]
(xs xs_1 : coll) (e_xs : xs = xs_1) (i i_1 : idx) (e_i : i = i_1) (h : valid xs i), xs[i] = xs_1[i_1]
-/
#guard_msgs in
#eval genCongr ``GetElem.getElem

def f (x := 0) (_ : x = x := by rfl) := x + 1

/--
info: ∀ (x x' : Nat), x = x' → ∀ (x_1 : x = x) (x'_1 : x' = x'), HEq x_1 x'_1 → HEq (f x x_1) (f x' x'_1)
-/
#guard_msgs in
#eval genHCongr ``f

/-- info: ∀ (x x_1 : Nat) (e_x : x = x_1) (x_2 : x = x), f x x_2 = f x_1 ⋯ -/
#guard_msgs in
#eval genCongr ``f

0 comments on commit 8204b79

Please sign in to comment.