diff --git a/src/Lean/Meta/Injective.lean b/src/Lean/Meta/Injective.lean index 7fefdd21234f..ad098581e410 100644 --- a/src/Lean/Meta/Injective.lean +++ b/src/Lean/Meta/Injective.lean @@ -52,13 +52,13 @@ def elimOptParam (type : Expr) : CoreM Expr := do Instead of checking the type of every subterm, we only need to check the type of free variables, since free variables introduced in the constructor may only appear in the type of other free variables introduced after them. -/ -def occursOrInType (e : Expr) (t : Expr) : MetaM Bool := do - let_fun f (s : Expr) := do - if !s.isFVar then - return s == e - let ty ← inferType s - return s == e || e.occurs ty - return (← t.findM? f).isSome +def occursOrInType (lctx : LocalContext) (e : Expr) (t : Expr) : Bool := + t.find? go |>.isSome +where + go s := Id.run do + let .fvar fvarId := s | s == e + let some decl := lctx.find? fvarId | s == e + return s == e || e.occurs decl.type private partial def mkInjectiveTheoremTypeCore? (ctorVal : ConstructorVal) (useEq : Bool) : MetaM (Option Expr) := do let us := ctorVal.levelParams.map mkLevelParam @@ -87,7 +87,7 @@ private partial def mkInjectiveTheoremTypeCore? (ctorVal : ConstructorVal) (useE match (← whnf type) with | Expr.forallE n d b _ => let arg1 := args1.get ⟨i, h⟩ - if ← occursOrInType arg1 resultType then + if occursOrInType (← getLCtx) arg1 resultType then mkArgs2 (i + 1) (b.instantiate1 arg1) (args2.push arg1) args2New else withLocalDecl n (if useEq then BinderInfo.default else BinderInfo.implicit) d fun arg2 => diff --git a/src/Lean/Util/FindExpr.lean b/src/Lean/Util/FindExpr.lean index e9788ea76bf1..a93765a3cb8c 100644 --- a/src/Lean/Util/FindExpr.lean +++ b/src/Lean/Util/FindExpr.lean @@ -11,17 +11,17 @@ namespace Lean namespace Expr namespace FindImpl -unsafe abbrev FindM (m) := StateT (PtrSet Expr) m +unsafe abbrev FindM := StateT (PtrSet Expr) Id -@[inline] unsafe def checkVisited [Monad m] (e : Expr) : OptionT (FindM m) Unit := do +@[inline] unsafe def checkVisited (e : Expr) : OptionT FindM Unit := do if (← get).contains e then failure modify fun s => s.insert e -unsafe def findM? [Monad m] (p : Expr → m Bool) (e : Expr) : OptionT (FindM m) Expr := +unsafe def findM? (p : Expr → Bool) (e : Expr) : OptionT FindM Expr := let rec visit (e : Expr) := do checkVisited e - if ← p e then + if p e then pure e else match e with | .forallE _ d b _ => visit d <|> visit b @@ -33,35 +33,29 @@ unsafe def findM? [Monad m] (p : Expr → m Bool) (e : Expr) : OptionT (FindM m) | _ => failure visit e -unsafe def findUnsafeM? {m} [Monad m] (p : Expr → m Bool) (e : Expr) : m (Option Expr) := - findM? p e |>.run' mkPtrSet - -@[inline] unsafe def findUnsafe? (p : Expr → Bool) (e : Expr) : Option Expr := findUnsafeM? (m := Id) p e +unsafe def findUnsafe? (p : Expr → Bool) (e : Expr) : Option Expr := + Id.run <| findM? p e |>.run' mkPtrSet end FindImpl -@[implemented_by FindImpl.findUnsafeM?] -/- This is a reference implementation for the unsafe one above -/ -def findM? [Monad m] (p : Expr → m Bool) (e : Expr) : m (Option Expr) := do - if ← p e then - return some e - else match e with - | .forallE _ d b _ => findM? p d <||> findM? p b - | .lam _ d b _ => findM? p d <||> findM? p b - | .mdata _ b => findM? p b - | .letE _ t v b _ => findM? p t <||> findM? p v <||> findM? p b - | .app f a => findM? p f <||> findM? p a - | .proj _ _ b => findM? p b - | _ => pure none - @[implemented_by FindImpl.findUnsafe?] -def find? (p : Expr → Bool) (e : Expr) : Option Expr := findM? (m := Id) p e +def find? (p : Expr → Bool) (e : Expr) : Option Expr := + /- This is a reference implementation for the unsafe one above -/ + if p e then + some e + else match e with + | .forallE _ d b _ => find? p d <|> find? p b + | .lam _ d b _ => find? p d <|> find? p b + | .mdata _ b => find? p b + | .letE _ t v b _ => find? p t <|> find? p v <|> find? p b + | .app f a => find? p f <|> find? p a + | .proj _ _ b => find? p b + | _ => none /-- Return true if `e` occurs in `t` -/ def occurs (e : Expr) (t : Expr) : Bool := (t.find? fun s => s == e).isSome - /-- Return type for `findExt?` function argument. -/ @@ -72,7 +66,7 @@ inductive FindStep where namespace FindExtImpl -unsafe def findM? (p : Expr → FindStep) (e : Expr) : OptionT (FindImpl.FindM Id) Expr := +unsafe def findM? (p : Expr → FindStep) (e : Expr) : OptionT FindImpl.FindM Expr := visit e where visitApp (e : Expr) :=