Skip to content

Commit

Permalink
fix: remove PersistentHashMap.size
Browse files Browse the repository at this point in the history
It is buggy and was unnecessary overhead.

closes #3029
  • Loading branch information
leodemoura committed Jun 19, 2024
1 parent 0a1a855 commit 9096d6f
Show file tree
Hide file tree
Showing 11 changed files with 50 additions and 47 deletions.
12 changes: 6 additions & 6 deletions src/Lean/Attributes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ structure AttributeImpl extends AttributeImplCore where
erase (decl : Name) : AttrM Unit := throwError "attribute cannot be erased"
deriving Inhabited

builtin_initialize attributeMapRef : IO.Ref (PersistentHashMap Name AttributeImpl) ← IO.mkRef {}
builtin_initialize attributeMapRef : IO.Ref (HashMap Name AttributeImpl) ← IO.mkRef {}

/-- Low level attribute registration function. -/
def registerBuiltinAttribute (attr : AttributeImpl) : IO Unit := do
Expand Down Expand Up @@ -317,7 +317,7 @@ inductive AttributeExtensionOLeanEntry where

structure AttributeExtensionState where
newEntries : List AttributeExtensionOLeanEntry := []
map : PersistentHashMap Name AttributeImpl
map : HashMap Name AttributeImpl
deriving Inhabited

abbrev AttributeExtension := PersistentEnvExtension AttributeExtensionOLeanEntry (AttributeExtensionOLeanEntry × AttributeImpl) AttributeExtensionState
Expand Down Expand Up @@ -348,7 +348,7 @@ private def AttributeExtension.addImported (es : Array (Array AttributeExtension
let map ← es.foldlM
(fun map entries =>
entries.foldlM
(fun (map : PersistentHashMap Name AttributeImpl) entry => do
(fun (map : HashMap Name AttributeImpl) entry => do
let attrImpl ← mkAttributeImplOfEntry ctx.env ctx.opts entry
return map.insert attrImpl.name attrImpl)
map)
Expand All @@ -374,7 +374,7 @@ def isBuiltinAttribute (n : Name) : IO Bool := do

/-- Return the name of all registered attributes. -/
def getBuiltinAttributeNames : IO (List Name) :=
return (← attributeMapRef.get).foldl (init := []) fun r n _ => n::r
return (← attributeMapRef.get).fold (init := []) fun r n _ => n::r

def getBuiltinAttributeImpl (attrName : Name) : IO AttributeImpl := do
let m ← attributeMapRef.get
Expand All @@ -392,7 +392,7 @@ def isAttribute (env : Environment) (attrName : Name) : Bool :=

def getAttributeNames (env : Environment) : List Name :=
let m := (attributeExtension.getState env).map
m.foldl (fun r n _ => n::r) []
m.fold (fun r n _ => n::r) []

def getAttributeImpl (env : Environment) (attrName : Name) : Except String AttributeImpl :=
let m := (attributeExtension.getState env).map
Expand Down Expand Up @@ -427,7 +427,7 @@ def Attribute.erase (declName : Name) (attrName : Name) : AttrM Unit := do
def updateEnvAttributesImpl (env : Environment) : IO Environment := do
let map ← attributeMapRef.get
let s := attributeExtension.getState env
let s := map.foldl (init := s) fun s attrName attrImpl =>
let s := map.fold (init := s) fun s attrName attrImpl =>
if s.map.contains attrName then
s
else
Expand Down
41 changes: 23 additions & 18 deletions src/Lean/Data/PersistentHashMap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ inductive Node (α : Type u) (β : Type v) : Type (max u v) where
| entries (es : Array (Entry α β (Node α β))) : Node α β
| collision (ks : Array α) (vs : Array β) (h : ks.size = vs.size) : Node α β

partial def Node.isEmpty : Node α β → Bool
| .collision .. => false
| .entries es => es.all fun
| .entry .. => false
| .ref n => n.isEmpty
| .null => true

instance {α β} : Inhabited (Node α β) := ⟨Node.entries #[]⟩

abbrev shift : USize := 5
Expand All @@ -36,17 +43,16 @@ def mkEmptyEntriesArray {α β} : Array (Entry α β (Node α β)) :=
end PersistentHashMap

structure PersistentHashMap (α : Type u) (β : Type v) [BEq α] [Hashable α] where
root : PersistentHashMap.Node α β := PersistentHashMap.Node.entries PersistentHashMap.mkEmptyEntriesArray
size : Nat := 0
root : PersistentHashMap.Node α β := PersistentHashMap.Node.entries PersistentHashMap.mkEmptyEntriesArray

abbrev PHashMap (α : Type u) (β : Type v) [BEq α] [Hashable α] := PersistentHashMap α β

namespace PersistentHashMap

def empty [BEq α] [Hashable α] : PersistentHashMap α β := {}

def isEmpty [BEq α] [Hashable α] (m : PersistentHashMap α β) : Bool :=
m.size == 0
def isEmpty {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β Bool
| { root } => root.isEmpty

instance [BEq α] [Hashable α] : Inhabited (PersistentHashMap α β) := ⟨{}⟩

Expand Down Expand Up @@ -130,7 +136,7 @@ partial def insertAux [BEq α] [Hashable α] : Node α β → USize → USize
else Entry.ref $ mkCollisionNode k' v' k v

def insert {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → β → PersistentHashMap α β
| { root := n, size := sz }, k, v => { root := insertAux n (hash k |>.toUSize) 1 k v, size := sz + 1 }
| { root := n }, k, v => { root := insertAux n (hash k |>.toUSize) 1 k v }

partial def findAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Option β :=
if h : i < keys.size then
Expand Down Expand Up @@ -225,7 +231,7 @@ def isUnaryNode : Node α β → Option (α × β)
else
none

partial def eraseAux [BEq α] : Node α β → USize → α → Node α β × Bool
partial def eraseAux [BEq α] : Node α β → USize → α → Node α β
| n@(Node.collision keys vals heq), _, k =>
match keys.indexOf? k with
| some idx =>
Expand All @@ -234,28 +240,27 @@ partial def eraseAux [BEq α] : Node α β → USize → α → Node α β × Bo
let vals' := vals.feraseIdx (Eq.ndrec idx heq)
have veq := vals.size_feraseIdx (Eq.ndrec idx heq)
have : keys.size - 1 = vals.size - 1 := by rw [heq]
(Node.collision keys' vals' (keq.trans (this.trans veq.symm)), true)
| none => (n, false)
Node.collision keys' vals' (keq.trans (this.trans veq.symm))
| none => n
| n@(Node.entries entries), h, k =>
let j := (mod2Shift h shift).toNat
let entry := entries.get! j
match entry with
| Entry.null => (n, false)
| Entry.null => n
| Entry.entry k' _ =>
if k == k' then (Node.entries (entries.set! j Entry.null), true) else (n, false)
if k == k' then Node.entries (entries.set! j Entry.null) else n
| Entry.ref node =>
let entries := entries.set! j Entry.null
let (newNode, deleted) := eraseAux node (div2Shift h shift) k
if !deleted then (n, false)
else match isUnaryNode newNode with
| none => (Node.entries (entries.set! j (Entry.ref newNode)), true)
| some (k, v) => (Node.entries (entries.set! j (Entry.entry k v)), true)
let newNode := eraseAux node (div2Shift h shift) k
match isUnaryNode newNode with
| none => Node.entries (entries.set! j (Entry.ref newNode))
| some (k, v) => Node.entries (entries.set! j (Entry.entry k v))

def erase {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → PersistentHashMap α β
| { root := n, size := sz }, k =>
| { root := n }, k =>
let h := hash k |>.toUSize
let (n, del) := eraseAux n h k
{ root := n, size := if del then sz - 1 else sz }
let n := eraseAux n h k
{ root := n }

section
variable {m : Type w → Type w'} [Monad m]
Expand Down
3 changes: 0 additions & 3 deletions src/Lean/Data/PersistentHashSet.lean
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@ variable {_ : BEq α} {_ : Hashable α}
@[inline] def contains (s : PersistentHashSet α) (a : α) : Bool :=
s.set.contains a

@[inline] def size (s : PersistentHashSet α) : Nat :=
s.set.size

@[inline] def foldM {β : Type v} {m : Type v → Type v} [Monad m] (f : β → α → m β) (init : β) (s : PersistentHashSet α) : m β :=
s.set.foldlM (init := init) fun d a _ => f d a

Expand Down
6 changes: 0 additions & 6 deletions src/Lean/Data/SMap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,6 @@ def switch (m : SMap α β) : SMap α β :=
def fold {σ : Type w} (f : σ → α → β → σ) (init : σ) (m : SMap α β) : σ :=
m.map₂.foldl f $ m.map₁.fold f init

def size (m : SMap α β) : Nat :=
m.map₁.size + m.map₂.size

def stageSizes (m : SMap α β) : Nat × Nat :=
(m.map₁.size, m.map₂.size)

def numBuckets (m : SMap α β) : Nat :=
m.map₁.numBuckets

Expand Down
3 changes: 0 additions & 3 deletions src/Lean/Data/SSet.lean
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ abbrev switch (s : SSet α) : SSet α :=
abbrev fold (f : σ → α → σ) (init : σ) (s : SSet α) : σ :=
SMap.fold (fun d a _ => f d a) init s

abbrev size (s : SSet α) : Nat :=
SMap.size s

def toList (m : SSet α) : List α :=
m.fold (init := []) fun es a => a::es

Expand Down
2 changes: 2 additions & 0 deletions src/Lean/Elab/BuiltinCommand.lean
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def elabCheckCore (ignoreStuckTC : Bool) : CommandElab

@[builtin_command_elab Lean.Parser.Command.check] def elabCheck : CommandElab := elabCheckCore (ignoreStuckTC := true)

/-
@[builtin_command_elab Lean.reduceCmd] def elabReduce : CommandElab
| `(#reduce%$tk $term) => go tk term
| `(#reduce%$tk (proofs := true) $term) => go tk term (skipProofs := false)
Expand All @@ -278,6 +279,7 @@ where
withTheReader Core.Context (fun ctx => { ctx with options := ctx.options.setBool `smartUnfolding false }) do
let e ← withTransparency (mode := TransparencyMode.all) <| reduce e (skipProofs := skipProofs) (skipTypes := skipTypes)
logInfoAt tk e
-/

def hasNoErrorMessages : CommandElabM Bool := do
return !(← get).messages.hasErrors
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/Tactic/Cache.lean
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ private def dbg_cache (msg : String) : TacticM Unit := do
private def dbg_cache' (cacheRef : IO.Ref Cache) (pos : String.Pos) (mvarId : MVarId) (msg : String) : TacticM Unit := do
if tactic.dbg_cache.get (← getOptions) then
let {line, column} := (← getFileMap).toPosition pos
dbg_trace "{msg}, cache size: {(← cacheRef.get).pre.size}, line: {line}, column: {column}, contains entry: {(← cacheRef.get).pre.find? { mvarId, pos } |>.isSome}"
dbg_trace "{msg}, line: {line}, column: {column}, contains entry: {(← cacheRef.get).pre.find? { mvarId, pos } |>.isSome}"

private def findCache? (cacheRef : IO.Ref Cache) (mvarId : MVarId) (stx : Syntax) (pos : String.Pos) : TacticM (Option Snapshot) := do
let some s := (← cacheRef.get).pre.find? { mvarId, pos } | do dbg_cache "cache key not found"; return none
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/Tactic/Simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def mkSimpOnly (stx : Syntax) (usedSimps : Simp.UsedSimps) : MetaM Syntax := do
let mut localsOrStar := some #[]
let lctx ← getLCtx
let env ← getEnv
for (thm, _) in usedSimps.toArray.qsort (·.2 < ·.2) do
for thm in usedSimps.toArray do
match thm with
| .decl declName post inv => -- global definitions in the environment
if env.contains declName
Expand Down
3 changes: 0 additions & 3 deletions src/Lean/Environment.lean
Original file line number Diff line number Diff line change
Expand Up @@ -984,9 +984,6 @@ def displayStats (env : Environment) : IO Unit := do
IO.println ("direct imports: " ++ toString env.header.imports);
IO.println ("number of imported modules: " ++ toString env.header.regions.size);
IO.println ("number of memory-mapped modules: " ++ toString (env.header.regions.filter (·.isMemoryMapped) |>.size));
IO.println ("number of consts: " ++ toString env.constants.size);
IO.println ("number of imported consts: " ++ toString env.constants.stageSizes.1);
IO.println ("number of local consts: " ++ toString env.constants.stageSizes.2);
IO.println ("number of buckets for imported consts: " ++ toString env.constants.numBuckets);
IO.println ("trust level: " ++ toString env.header.trustLevel);
IO.println ("number of extensions: " ++ toString env.extensions.size);
Expand Down
21 changes: 16 additions & 5 deletions src/Lean/Meta/Tactic/Simp/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,21 @@ structure Context where
def Context.isDeclToUnfold (ctx : Context) (declName : Name) : Bool :=
ctx.simpTheorems.isDeclToUnfold declName

-- We should use `PHashMap` because we backtrack the contents of `UsedSimps`
abbrev UsedSimps := PHashMap Origin Nat
structure UsedSimps where
-- We should use `PHashMap` because we backtrack the contents of `UsedSimps`
-- The natural number tracks the insertion order
map : PHashMap Origin Nat := {}
size : Nat := 0
deriving Inhabited

def UsedSimps.insert (s : UsedSimps) (thmId : Origin) : UsedSimps :=
if s.map.contains thmId then
s
else match s with
| { map, size } => { map := map.insert thmId size, size := size + 1 }

def UsedSimps.toArray (s : UsedSimps) : Array Origin :=
s.map.toArray.qsort (·.2 < ·.2) |>.map (·.1)

structure Diagnostics where
/-- Number of times each simp theorem has been used/applied. -/
Expand Down Expand Up @@ -367,9 +380,7 @@ def recordSimpTheorem (thmId : Origin) : SimpM Unit := do
else
pure thmId
| _ => pure thmId
modify fun s => if s.usedTheorems.contains thmId then s else
let n := s.usedTheorems.size
{ s with usedTheorems := s.usedTheorems.insert thmId n }
modify fun s => { s with usedTheorems := s.usedTheorems.insert thmId }

def recordCongrTheorem (declName : Name) : SimpM Unit := do
modifyDiag fun s =>
Expand Down
2 changes: 1 addition & 1 deletion stage0/src/stdlib_flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ options get_default_options() {
// see https://lean-lang.org/lean4/doc/dev/bootstrap.html#further-bootstrapping-complications
#if LEAN_IS_STAGE0 == 1
// switch to `true` for ABI-breaking changes affecting meta code
opts = opts.update({"interpreter", "prefer_native"}, false);
opts = opts.update({"interpreter", "prefer_native"}, true);
// switch to `true` for changing built-in parsers used in quotations
opts = opts.update({"internal", "parseQuotWithCurrentStage"}, false);
// toggling `parseQuotWithCurrentStage` may also require toggling the following option if macros/syntax
Expand Down

0 comments on commit 9096d6f

Please sign in to comment.