Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
kim-em committed Oct 23, 2023
1 parent d83bd03 commit 3d66565
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 41 deletions.
48 changes: 24 additions & 24 deletions Std/Lean/Meta/DiscrTree.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ namespace Key
Compare two `Key`s. The ordering is total but otherwise arbitrary. (It uses
`Name.quickCmp` internally.)
-/
protected def cmp : Key s → Key s → Ordering
protected def cmp : Key → Key → Ordering
| .lit v₁, .lit v₂ => compare v₁ v₂
| .fvar n₁ a₁, .fvar n₂ a₂ => n₁.name.quickCmp n₂.name |>.then <| compare a₁ a₂
| .const n₁ a₁, .const n₂ a₂ => n₁.quickCmp n₂ |>.then <| compare a₁ a₂
| .proj s₁ i₁ a₁, .proj s₂ i₂ a₂ =>
s₁.quickCmp s₂ |>.then <| compare i₁ i₂ |>.then <| compare a₁ a₂
| k₁, k₂ => compare k₁.ctorIdx k₂.ctorIdx

instance : Ord (Key s) :=
instance : Ord Key :=
⟨Key.cmp⟩

end Key
Expand All @@ -36,8 +36,8 @@ namespace Trie

-- This is just a partial function, but Lean doesn't realise that its type is
-- inhabited.
private unsafe def foldMUnsafe [Monad m] (initialKeys : Array (Key s))
(f : σ → Array (Key s) → α → m σ) (init : σ) : Trie α s → m σ
private unsafe def foldMUnsafe [Monad m] (initialKeys : Array Key)
(f : σ → Array Key → α → m σ) (init : σ) : Trie α → m σ
| Trie.node vs children => do
let s ← vs.foldlM (init := init) fun s v => f s initialKeys v
children.foldlM (init := s) fun s (k, t) =>
Expand All @@ -47,22 +47,22 @@ private unsafe def foldMUnsafe [Monad m] (initialKeys : Array (Key s))
Monadically fold the keys and values stored in a `Trie`.
-/
@[implemented_by foldMUnsafe]
opaque foldM [Monad m] (initalKeys : Array (Key s))
(f : σ → Array (Key s) → α → m σ) (init : σ) (t : Trie α s) : m σ :=
opaque foldM [Monad m] (initalKeys : Array Key)
(f : σ → Array Key → α → m σ) (init : σ) (t : Trie α) : m σ :=
pure init

/--
Fold the keys and values stored in a `Trie`.
-/
@[inline]
def fold (initialKeys : Array (Key s)) (f : σ → Array (Key s) → α → σ)
(init : σ) (t : Trie α s) : σ :=
def fold (initialKeys : Array Key) (f : σ → Array Key → α → σ)
(init : σ) (t : Trie α) : σ :=
Id.run <| t.foldM initialKeys (init := init) fun s k a => return f s k a

-- This is just a partial function, but Lean doesn't realise that its type is
-- inhabited.
private unsafe def foldValuesMUnsafe [Monad m] (f : σ → α → m σ) (init : σ) :
Trie α s → m σ
Trie α → m σ
| node vs children => do
let s ← vs.foldlM (init := init) f
children.foldlM (init := s) fun s (_, c) => c.foldValuesMUnsafe (init := s) f
Expand All @@ -71,32 +71,32 @@ private unsafe def foldValuesMUnsafe [Monad m] (f : σ → α → m σ) (init :
Monadically fold the values stored in a `Trie`.
-/
@[implemented_by foldValuesMUnsafe]
opaque foldValuesM [Monad m] (f : σ → α → m σ) (init : σ) (t : Trie α s) : m σ := pure init
opaque foldValuesM [Monad m] (f : σ → α → m σ) (init : σ) (t : Trie α) : m σ := pure init

/--
Fold the values stored in a `Trie`.
-/
@[inline]
def foldValues (f : σ → α → σ) (init : σ) (t : Trie α s) : σ :=
def foldValues (f : σ → α → σ) (init : σ) (t : Trie α) : σ :=
Id.run <| t.foldValuesM (init := init) f

/--
The number of values stored in a `Trie`.
-/
partial def size : Trie α s → Nat
partial def size : Trie α → Nat
| Trie.node vs children =>
children.foldl (init := vs.size) fun n (_, c) => n + size c

/--
Merge two `Trie`s. Duplicate values are preserved.
-/
partial def mergePreservingDuplicates : Trie α s → Trie α s → Trie α s
partial def mergePreservingDuplicates : Trie α → Trie α → Trie α
| node vs₁ cs₁, node vs₂ cs₂ =>
node (vs₁ ++ vs₂) (mergeChildren cs₁ cs₂)
where
/-- Auxiliary definition for `mergePreservingDuplicates`. -/
mergeChildren (cs₁ cs₂ : Array (Key s × Trie α s)) :
Array (Key s × Trie α s) :=
mergeChildren (cs₁ cs₂ : Array (Key × Trie α)) :
Array (Key × Trie α) :=
Array.mergeSortedMergingDuplicates
(ord := ⟨compareOn (·.fst)⟩) cs₁ cs₂
(fun (k₁, t₁) (_, t₂) => (k₁, mergePreservingDuplicates t₁ t₂))
Expand All @@ -108,57 +108,57 @@ end Trie
Monadically fold over the keys and values stored in a `DiscrTree`.
-/
@[inline]
def foldM [Monad m] (f : σ → Array (Key s) → α → m σ) (init : σ)
(t : DiscrTree α s) : m σ :=
def foldM [Monad m] (f : σ → Array Key → α → m σ) (init : σ)
(t : DiscrTree α) : m σ :=
t.root.foldlM (init := init) fun s k t => t.foldM #[k] (init := s) f

/--
Fold over the keys and values stored in a `DiscrTree`
-/
@[inline]
def fold (f : σ → Array (Key s) → α → σ) (init : σ) (t : DiscrTree α s) : σ :=
def fold (f : σ → Array Key → α → σ) (init : σ) (t : DiscrTree α) : σ :=
Id.run <| t.foldM (init := init) fun s keys a => return f s keys a

/--
Monadically fold over the values stored in a `DiscrTree`.
-/
@[inline]
def foldValuesM [Monad m] (f : σ → α → m σ) (init : σ) (t : DiscrTree α s) :
def foldValuesM [Monad m] (f : σ → α → m σ) (init : σ) (t : DiscrTree α) :
m σ :=
t.root.foldlM (init := init) fun s _ t => t.foldValuesM (init := s) f

/--
Fold over the values stored in a `DiscrTree`.
-/
@[inline]
def foldValues (f : σ → α → σ) (init : σ) (t : DiscrTree α s) : σ :=
def foldValues (f : σ → α → σ) (init : σ) (t : DiscrTree α) : σ :=
Id.run <| t.foldValuesM (init := init) f

/--
Extract the values stored in a `DiscrTree`.
-/
@[inline]
def values (t : DiscrTree α s) : Array α :=
def values (t : DiscrTree α) : Array α :=
t.foldValues (init := #[]) fun as a => as.push a

/--
Extract the keys and values stored in a `DiscrTree`.
-/
@[inline]
def toArray (t : DiscrTree α s) : Array (Array (Key s) × α) :=
def toArray (t : DiscrTree α) : Array (Array Key × α) :=
t.fold (init := #[]) fun as keys a => as.push (keys, a)

/--
Get the number of values stored in a `DiscrTree`. O(n) in the size of the tree.
-/
@[inline]
def size (t : DiscrTree α s) : Nat :=
def size (t : DiscrTree α) : Nat :=
t.root.foldl (init := 0) fun n _ t => n + t.size

/--
Merge two `DiscrTree`s. Duplicate values are preserved.
-/
@[inline]
def mergePreservingDuplicates (t u : DiscrTree α s) : DiscrTree α s :=
def mergePreservingDuplicates (t u : DiscrTree α) : DiscrTree α :=
⟨t.root.mergeWith u.root fun _ trie₁ trie₂ =>
trie₁.mergePreservingDuplicates trie₂⟩
14 changes: 9 additions & 5 deletions Std/Tactic/Ext/Attr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,35 @@ structure ExtTheorem where
/-- Priority of the extensionality theorem. -/
priority : Nat
/-- Key in the discrimination tree. -/
keys : Array (DiscrTree.Key true)
keys : Array DiscrTree.Key
deriving Inhabited, Repr, BEq, Hashable

/-- The state of the `ext` extension environment -/
structure ExtTheorems where
/-- The tree of `ext` extensions. -/
tree : DiscrTree ExtTheorem true := {}
tree : DiscrTree ExtTheorem := {}
/-- Erased `ext`s. -/
erased : PHashSet Name := {}
deriving Inhabited


/-- Discrimation tree settings for the `ext` extension. -/
def extExt.config : WhnfCoreConfig := {}

/-- The environment extension to track `@[ext]` lemmas. -/
initialize extExtension :
SimpleScopedEnvExtension ExtTheorem ExtTheorems ←
registerSimpleScopedEnvExtension {
addEntry := fun { tree, erased } thm =>
{ tree := tree.insertCore thm.keys thm, erased := erased.erase thm.declName }
{ tree := tree.insertCore thm.keys thm extExt.config, erased := erased.erase thm.declName }
initial := {}
}

/-- Get the list of `@[ext]` lemmas corresponding to the key `ty`,
ordered from high priority to low. -/
@[inline] def getExtLemmas (ty : Expr) : MetaM (Array ExtTheorem) := do
let extTheorems := extExtension.getState (← getEnv)
let arr ← extTheorems.tree.getMatch ty
let arr ← extTheorems.tree.getMatch ty extExt.config
let erasedArr := arr.filter fun thm => !extTheorems.erased.contains thm.declName
-- Using insertion sort because it is stable and the list of matches should be mostly sorted.
-- Most ext lemmas have default priority.
Expand Down Expand Up @@ -97,7 +101,7 @@ initialize registerBuiltinAttribute {
"@[ext] attribute only applies to structures or lemmas proving x = y, got {declTy}"
let some (ty, lhs, rhs) := declTy.eq? | failNotEq
unless lhs.isMVar && rhs.isMVar do failNotEq
let keys ← withReducible <| DiscrTree.mkPath ty
let keys ← withReducible <| DiscrTree.mkPath ty extExt.config
let priority ← liftCommandElabM do Elab.liftMacroM do
evalPrio (prio.getD (← `(prio| default)))
extExtension.add {declName, keys, priority} kind
Expand Down
2 changes: 1 addition & 1 deletion Std/Tactic/Instances.lean
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ elab (name := instancesCmd) tk:"#instances " stx:term : command => runTermElabM
let some className ← isClass? type
| throwErrorAt stx "type class instance expected{indentExpr type}"
let globalInstances ← getGlobalInstancesIndex
let result ← globalInstances.getUnify type
let result ← globalInstances.getUnify type tcDtConfig
let erasedInstances ← getErasedInstances
let mut msgs := #[]
for e in result.insertionSort fun e₁ e₂ => e₁.priority < e₂.priority do
Expand Down
5 changes: 3 additions & 2 deletions Std/Tactic/Lint/Simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def isSimpTheorem (declName : Name) : MetaM Bool := do

open Lean.Meta.DiscrTree in
/-- Returns the list of elements in the discrimination tree. -/
partial def _root_.Lean.Meta.DiscrTree.elements (d : DiscrTree α s) : Array α :=
partial def _root_.Lean.Meta.DiscrTree.elements (d : DiscrTree α) : Array α :=
d.root.foldl (init := #[]) fun arr _ => trieElements arr
where
/-- Returns the list of elements in the trie. -/
Expand Down Expand Up @@ -229,7 +229,8 @@ Some commutativity lemmas are simp lemmas:"
unless ← isDefEq rhs lhs' do return none
unless ← withNewMCtxDepth (isDefEq rhs lhs') do return none
-- make sure that the discrimination tree will actually find this match (see #69)
if (← (← DiscrTree.empty.insert (s := true) rhs ()).getMatch lhs').isEmpty then return none
if (← (← DiscrTree.empty.insert rhs () simpDtConfig).getMatch lhs' simpDtConfig).isEmpty then
return none
-- ensure that the second application makes progress:
if ← isDefEq lhs' rhs' then return none
pure m!"should not be marked simp"
11 changes: 7 additions & 4 deletions Std/Tactic/Relation/Rfl.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ namespace Std.Tactic

open Lean Meta

/-- Discrimation tree settings for the `refl` extension. -/
def reflExt.config : WhnfCoreConfig := {}

/-- Environment extensions for `refl` lemmas -/
initialize reflExt :
SimpleScopedEnvExtension (Name × Array (DiscrTree.Key true)) (DiscrTree Name true) ←
SimpleScopedEnvExtension (Name × Array (DiscrTree.Key)) (DiscrTree Name) ←
registerSimpleScopedEnvExtension {
addEntry := fun dt (n, ks) => dt.insertCore ks n
addEntry := fun dt (n, ks) => dt.insertCore ks n reflExt.config
initial := {}
}

Expand All @@ -35,7 +38,7 @@ initialize registerBuiltinAttribute {
"@[refl] attribute only applies to lemmas proving x ∼ x, got {declTy}"
let .app (.app rel lhs) rhs := targetTy | fail
unless ← withNewMCtxDepth <| isDefEq lhs rhs do fail
let key ← DiscrTree.mkPath rel
let key ← DiscrTree.mkPath rel reflExt.config
reflExt.add (decl, key) kind
}

Expand All @@ -52,7 +55,7 @@ def _root_.Lean.MVarId.applyRfl (goal : MVarId) : MetaM Unit := do
indentExpr (← goal.getType)}"
let s ← saveState
let mut ex? := none
for lem in ← (reflExt.getState (← getEnv)).getMatch rel do
for lem in ← (reflExt.getState (← getEnv)).getMatch rel reflExt.config do
try
let gs ← goal.apply (← mkConstWithFreshMVarLevels lem)
if gs.isEmpty then return () else
Expand Down
11 changes: 7 additions & 4 deletions Std/Tactic/Relation/Symm.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ open Lean Meta

namespace Std.Tactic

/-- Discrimation tree settings for the `symm` extension. -/
def symmExt.config : WhnfCoreConfig := {}

/-- Environment extensions for symm lemmas -/
initialize symmExt :
SimpleScopedEnvExtension (Name × Array (DiscrTree.Key true)) (DiscrTree Name true) ←
SimpleScopedEnvExtension (Name × Array (DiscrTree.Key)) (DiscrTree Name) ←
registerSimpleScopedEnvExtension {
addEntry := fun dt (n, ks) => dt.insertCore ks n
addEntry := fun dt (n, ks) => dt.insertCore ks n symmExt.config
initial := {}
}

Expand All @@ -37,7 +40,7 @@ initialize registerBuiltinAttribute {
let some _ := xs.back? | fail
let targetTy ← reduce targetTy
let .app (.app rel _) _ := targetTy | fail
let key ← withReducible <| DiscrTree.mkPath rel
let key ← withReducible <| DiscrTree.mkPath rel symmExt.config
symmExt.add (decl, key) kind
}

Expand Down Expand Up @@ -70,7 +73,7 @@ where
go (tgt : Expr) {α} (k : Expr → Array Expr → Expr → MetaM α) : MetaM α := do
let .app (.app rel _) _ := tgt
| throwError "symmetry lemmas only apply to binary relations, not{indentExpr tgt}"
for lem in ← (symmExt.getState (← getEnv)).getMatch rel do
for lem in ← (symmExt.getState (← getEnv)).getMatch rel symmExt.config do
try
let lem ← mkConstWithFreshMVarLevels lem
let (args, _, body) ← withReducible <| forallMetaTelescopeReducing (← inferType lem)
Expand Down
2 changes: 1 addition & 1 deletion lean-toolchain
Original file line number Diff line number Diff line change
@@ -1 +1 @@
leanprover/lean4:v4.2.0-rc4
leanprover/lean4-pr-releases:pr-release-2734

0 comments on commit 3d66565

Please sign in to comment.