diff --git a/src/Lean/Util/FindExpr.lean b/src/Lean/Util/FindExpr.lean index a93765a3cb8c..e9788ea76bf1 100644 --- a/src/Lean/Util/FindExpr.lean +++ b/src/Lean/Util/FindExpr.lean @@ -11,17 +11,17 @@ namespace Lean namespace Expr namespace FindImpl -unsafe abbrev FindM := StateT (PtrSet Expr) Id +unsafe abbrev FindM (m) := StateT (PtrSet Expr) m -@[inline] unsafe def checkVisited (e : Expr) : OptionT FindM Unit := do +@[inline] unsafe def checkVisited [Monad m] (e : Expr) : OptionT (FindM m) Unit := do if (← get).contains e then failure modify fun s => s.insert e -unsafe def findM? (p : Expr → Bool) (e : Expr) : OptionT FindM Expr := +unsafe def findM? [Monad m] (p : Expr → m Bool) (e : Expr) : OptionT (FindM m) Expr := let rec visit (e : Expr) := do checkVisited e - if p e then + if ← p e then pure e else match e with | .forallE _ d b _ => visit d <|> visit b @@ -33,29 +33,35 @@ unsafe def findM? (p : Expr → Bool) (e : Expr) : OptionT FindM Expr := | _ => failure visit e -unsafe def findUnsafe? (p : Expr → Bool) (e : Expr) : Option Expr := - Id.run <| findM? p e |>.run' mkPtrSet +unsafe def findUnsafeM? {m} [Monad m] (p : Expr → m Bool) (e : Expr) : m (Option Expr) := + findM? p e |>.run' mkPtrSet + +@[inline] unsafe def findUnsafe? (p : Expr → Bool) (e : Expr) : Option Expr := findUnsafeM? (m := Id) p e end FindImpl -@[implemented_by FindImpl.findUnsafe?] -def find? (p : Expr → Bool) (e : Expr) : Option Expr := - /- This is a reference implementation for the unsafe one above -/ - if p e then - some e +@[implemented_by FindImpl.findUnsafeM?] +/- This is a reference implementation for the unsafe one above -/ +def findM? [Monad m] (p : Expr → m Bool) (e : Expr) : m (Option Expr) := do + if ← p e then + return some e else match e with - | .forallE _ d b _ => find? p d <|> find? p b - | .lam _ d b _ => find? p d <|> find? p b - | .mdata _ b => find? p b - | .letE _ t v b _ => find? p t <|> find? p v <|> find? p b - | .app f a => find? p f <|> find? p a - | .proj _ _ b => find? p b - | _ => none + | .forallE _ d b _ => findM? p d <||> findM? p b + | .lam _ d b _ => findM? p d <||> findM? p b + | .mdata _ b => findM? p b + | .letE _ t v b _ => findM? p t <||> findM? p v <||> findM? p b + | .app f a => findM? p f <||> findM? p a + | .proj _ _ b => findM? p b + | _ => pure none + +@[implemented_by FindImpl.findUnsafe?] +def find? (p : Expr → Bool) (e : Expr) : Option Expr := findM? (m := Id) p e /-- Return true if `e` occurs in `t` -/ def occurs (e : Expr) (t : Expr) : Bool := (t.find? fun s => s == e).isSome + /-- Return type for `findExt?` function argument. -/ @@ -66,7 +72,7 @@ inductive FindStep where namespace FindExtImpl -unsafe def findM? (p : Expr → FindStep) (e : Expr) : OptionT FindImpl.FindM Expr := +unsafe def findM? (p : Expr → FindStep) (e : Expr) : OptionT (FindImpl.FindM Id) Expr := visit e where visitApp (e : Expr) :=