Skip to content

Commit

Permalink
feat: monadic generalization of FindExpr (#3970)
Browse files Browse the repository at this point in the history
Not certain this is a good idea. Motivated by code duplication
introduced in #3398.
  • Loading branch information
kim-em authored Apr 24, 2024
1 parent 4fe0259 commit 706a4cf
Showing 1 changed file with 25 additions and 19 deletions.
44 changes: 25 additions & 19 deletions src/Lean/Util/FindExpr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ namespace Lean
namespace Expr
namespace FindImpl

unsafe abbrev FindM := StateT (PtrSet Expr) Id
unsafe abbrev FindM (m) := StateT (PtrSet Expr) m

@[inline] unsafe def checkVisited (e : Expr) : OptionT FindM Unit := do
@[inline] unsafe def checkVisited [Monad m] (e : Expr) : OptionT (FindM m) Unit := do
if (← get).contains e then
failure
modify fun s => s.insert e

unsafe def findM? (p : Expr → Bool) (e : Expr) : OptionT FindM Expr :=
unsafe def findM? [Monad m] (p : Expr → m Bool) (e : Expr) : OptionT (FindM m) Expr :=
let rec visit (e : Expr) := do
checkVisited e
if p e then
if p e then
pure e
else match e with
| .forallE _ d b _ => visit d <|> visit b
Expand All @@ -33,29 +33,35 @@ unsafe def findM? (p : Expr → Bool) (e : Expr) : OptionT FindM Expr :=
| _ => failure
visit e

unsafe def findUnsafe? (p : Expr → Bool) (e : Expr) : Option Expr :=
Id.run <| findM? p e |>.run' mkPtrSet
unsafe def findUnsafeM? {m} [Monad m] (p : Expr → m Bool) (e : Expr) : m (Option Expr) :=
findM? p e |>.run' mkPtrSet

@[inline] unsafe def findUnsafe? (p : Expr → Bool) (e : Expr) : Option Expr := findUnsafeM? (m := Id) p e

end FindImpl

@[implemented_by FindImpl.findUnsafe?]
def find? (p : Expr → Bool) (e : Expr) : Option Expr :=
/- This is a reference implementation for the unsafe one above -/
if p e then
some e
@[implemented_by FindImpl.findUnsafeM?]
/- This is a reference implementation for the unsafe one above -/
def findM? [Monad m] (p : Expr → m Bool) (e : Expr) : m (Option Expr) := do
if p e then
return some e
else match e with
| .forallE _ d b _ => find? p d <|> find? p b
| .lam _ d b _ => find? p d <|> find? p b
| .mdata _ b => find? p b
| .letE _ t v b _ => find? p t <|> find? p v <|> find? p b
| .app f a => find? p f <|> find? p a
| .proj _ _ b => find? p b
| _ => none
| .forallE _ d b _ => findM? p d <||> findM? p b
| .lam _ d b _ => findM? p d <||> findM? p b
| .mdata _ b => findM? p b
| .letE _ t v b _ => findM? p t <||> findM? p v <||> findM? p b
| .app f a => findM? p f <||> findM? p a
| .proj _ _ b => findM? p b
| _ => pure none

@[implemented_by FindImpl.findUnsafe?]
def find? (p : Expr → Bool) (e : Expr) : Option Expr := findM? (m := Id) p e

/-- Return true if `e` occurs in `t` -/
def occurs (e : Expr) (t : Expr) : Bool :=
(t.find? fun s => s == e).isSome


/--
Return type for `findExt?` function argument.
-/
Expand All @@ -66,7 +72,7 @@ inductive FindStep where

namespace FindExtImpl

unsafe def findM? (p : Expr → FindStep) (e : Expr) : OptionT FindImpl.FindM Expr :=
unsafe def findM? (p : Expr → FindStep) (e : Expr) : OptionT (FindImpl.FindM Id) Expr :=
visit e
where
visitApp (e : Expr) :=
Expand Down

0 comments on commit 706a4cf

Please sign in to comment.