Skip to content

Commit

Permalink
Compiler support for optimizing PersistentDict (JuliaLang#51993)
Browse files Browse the repository at this point in the history
This is part of the work to address JuliaLang#51352 by attempting to allow the
compiler to perform SRAO on persistent data structures like
`PersistentDict` as if they were regular immutable data structures.
These sorts of data structures have very complicated internals (with
lots of mutation, memory sharing, etc.), but a relatively simple
interface. As such, it is unlikely that our compiler will have
sufficient power to optimize this interface by analyzing the
implementation.

We thus need to come up with some other mechanism that gives the
compiler license to perform the requisite optimization. One way would be
to just hardcode `PersistentDict` into the compiler, optimizing it like
any of the other builtin datatypes. However, this is of course very
unsatisfying. At the other end of the spectrum would be something like a
generic rewrite rule system (e-graphs anyone?) that would let the
PersistentDict implementation declare its interface to the compiler and
the compiler would use this for optimization (in a perfect world, the
actual rewrite would then be checked using some sort of formal methods).
I think that would be interesting, but we're very far from even being
able to design something like that (at least in Base - experiments with
external AbstractInterpreters in this direction are encouraged).

This PR tries to come up with a reasonable middle ground, where the
compiler gets some knowledge of the protocol hardcoded without having to
know about the implementation details of the data structure.

The basic ideas is that `Core` provides some magic generic functions
that implementations can extend. Semantically, they are not special.
They dispatch as usual, and implementations are expected to work
properly even in the absence of any compiler optimizations.

However, the compiler is semantically permitted to perform structural
optimization using these magic generic functions. In the concrete case,
this PR introduces the `KeyValue` interface which consists of two
generic functions, `get` and `set`. The core optimization is that the
compiler is allowed to rewrite any occurrence of `get(set(x, k, v), k)`
into `v` without additional legality checks. In particular, the compiler
performs no type checks, conversions, etc. The higher level
implementation code is expected to do all that.

This approach closely matches the general direction we've been taking in
external AbstractInterpreters for embedding additional semantics and
optimization opportunities into Julia code (although we generally use
methods there, rather than full generic functions), so I think we have
some evidence that this sort of approach works reasonably well.

Nevertheless, this is certainly an experiment and the interface is
explicitly declared unstable.

## Current Status

This is fully working and implemented, but the optimization currently
bails on anything but the simplest cases. Filling all those cases in is
not particularly hard, but should be done along with a more invasive
refactoring of SROA, so we should figure out the general direction here
first and then we can finish all that up in a follow-up cleanup.

## Obligatory benchmark
Before:
```
julia> using BenchmarkTools

julia> function foo()
           a = Base.PersistentDict(:a => 1)
           return a[:a]
       end
foo (generic function with 1 method)

julia> @benchmark foo()
BenchmarkTools.Trial: 10000 samples with 993 evaluations.
 Range (min … max):  32.940 ns …  28.754 μs  ┊ GC (min … max):  0.00% … 99.76%
 Time  (median):     49.647 ns               ┊ GC (median):     0.00%
 Time  (mean ± σ):   57.519 ns ± 333.275 ns  ┊ GC (mean ± σ):  10.81% ±  2.22%

        ▃█▅               ▁▃▅▅▃▁                ▁▃▂   ▂
  ▁▂▄▃▅▇███▇▃▁▂▁▁▁▁▁▁▁▁▂▂▅██████▅▂▁▁▁▁▁▁▁▁▁▁▂▃▃▇███▇▆███▆▄▃▃▂▂ ▃
  32.9 ns         Histogram: frequency by time         68.6 ns <

 Memory estimate: 128 bytes, allocs estimate: 4.

julia> @code_typed foo()
CodeInfo(
1 ─ %1  = invoke Vector{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}}(Base.HashArrayMappedTries.undef::UndefInitializer, 1::Int64)::Vector{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}}
│   %2  = %new(Base.HashArrayMappedTries.HAMT{Symbol, Int64}, %1, 0x00000000)::Base.HashArrayMappedTries.HAMT{Symbol, Int64}
│   %3  = %new(Base.HashArrayMappedTries.Leaf{Symbol, Int64}, :a, 1)::Base.HashArrayMappedTries.Leaf{Symbol, Int64}
│   %4  = Base.getfield(%2, :data)::Vector{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}}
│   %5  = $(Expr(:boundscheck, true))::Bool
└──       goto JuliaLang#5 if not %5
2 ─ %7  = Base.sub_int(1, 1)::Int64
│   %8  = Base.bitcast(UInt64, %7)::UInt64
│   %9  = Base.getfield(%4, :size)::Tuple{Int64}
│   %10 = $(Expr(:boundscheck, true))::Bool
│   %11 = Base.getfield(%9, 1, %10)::Int64
│   %12 = Base.bitcast(UInt64, %11)::UInt64
│   %13 = Base.ult_int(%8, %12)::Bool
└──       goto JuliaLang#4 if not %13
3 ─       goto JuliaLang#5
4 ─ %16 = Core.tuple(1)::Tuple{Int64}
│         invoke Base.throw_boundserror(%4::Vector{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}}, %16::Tuple{Int64})::Union{}
└──       unreachable
5 ┄ %19 = Base.getfield(%4, :ref)::MemoryRef{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}}
│   %20 = Base.memoryref(%19, 1, false)::MemoryRef{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}}
│         Base.memoryrefset!(%20, %3, :not_atomic, false)::MemoryRef{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}}
└──       goto JuliaLang#6
6 ─ %23 = Base.getfield(%2, :bitmap)::UInt32
│   %24 = Base.or_int(%23, 0x00010000)::UInt32
│         Base.setfield!(%2, :bitmap, %24)::UInt32
└──       goto JuliaLang#7
7 ─ %27 = %new(Base.PersistentDict{Symbol, Int64}, %2)::Base.PersistentDict{Symbol, Int64}
└──       goto JuliaLang#8
8 ─ %29 = invoke Base.getindex(%27::Base.PersistentDict{Symbol, Int64}, 🅰️:Symbol)::Int64
└──       return %29
```

After:
```
julia> using BenchmarkTools

julia> function foo()
           a = Base.PersistentDict(:a => 1)
           return a[:a]
       end
foo (generic function with 1 method)

julia> @benchmark foo()
BenchmarkTools.Trial: 10000 samples with 1000 evaluations.
 Range (min … max):  2.459 ns … 11.320 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     2.460 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   2.469 ns ±  0.183 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▂    █                                              ▁    █ ▂
  █▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁█ █
  2.46 ns      Histogram: log(frequency) by time     2.47 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @code_typed foo()
CodeInfo(
1 ─     return 1
```
  • Loading branch information
Keno authored and mkitti committed Dec 9, 2023
1 parent 7dad3ee commit ddcce82
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 65 deletions.
3 changes: 2 additions & 1 deletion base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,6 @@ function _hasmethod(@nospecialize(tt)) # this function has a special tfunc
return Intrinsics.not_int(ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), tt, nothing, world) === nothing)
end


# for backward compat
arrayref(inbounds::Bool, A::Array, i::Int...) = Main.Base.getindex(A, i...)
const_arrayref(inbounds::Bool, A::Array, i::Int...) = Main.Base.getindex(A, i...)
Expand All @@ -969,4 +968,6 @@ export arrayref, arrayset, arraysize, const_arrayref
# For convenience
EnterNode(old::EnterNode, new_dest::Int) = EnterNode(new_dest)

include(Core, "optimized_generics.jl")

ccall(:jl_set_istopmod, Cvoid, (Any, Bool), Core, true)
82 changes: 81 additions & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ function is_known_call(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,I
return singleton_type(ft) === func
end

function is_known_invoke_or_call(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,IncrementalCompact})
isinvoke = isexpr(x, :invoke)
(isinvoke || isexpr(x, :call)) || return false
ft = argextype(x.args[isinvoke ? 2 : 1], ir)
return singleton_type(ft) === func
end

struct SSAUse
kind::Symbol
idx::Int
Expand Down Expand Up @@ -819,6 +826,76 @@ function lift_svec_ref!(compact::IncrementalCompact, idx::Int, stmt::Expr)
return
end

function lift_leaves_keyvalue(compact::IncrementalCompact, @nospecialize(key),
leaves::Vector{Any}, 𝕃ₒ::AbstractLattice)
# For every leaf, the lifted value
lifted_leaves = LiftedLeaves()
for i = 1:length(leaves)
leaf = leaves[i]
cache_key = leaf
if isa(leaf, AnySSAValue)
(def, leaf) = walk_to_def(compact, leaf)
if is_known_invoke_or_call(def, Core.OptimizedGenerics.KeyValue.set, compact)
@assert isexpr(def, :invoke)
if length(def.args) in (5, 6)
collection = def.args[end-2]
set_key = def.args[end-1]
set_val_idx = length(def.args)
elseif length(def.args) == 4
collection = def.args[end-1]
# Key is deleted
# TODO: Model this
return nothing
elseif length(def.args) == 3
collection = def.args[end]
# The whole collection is deleted
# TODO: Model this
return nothing
else
return nothing
end
if set_key === key || (egal_tfunc(𝕃ₒ, argextype(key, compact), argextype(set_key, compact)) == Const(true))
lift_arg!(compact, leaf, cache_key, def, set_val_idx, lifted_leaves)
continue
end
# TODO: Continue walking the chain
return nothing
end
end
return nothing
end
return lifted_leaves
end

function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr, 𝕃ₒ::AbstractLattice)
collection = stmt.args[end-1]
key = stmt.args[end]

leaves, visited_philikes = collect_leaves(compact, collection, Any, 𝕃ₒ, phi_or_ifelse_predecessors)
isempty(leaves) && return

lifted_leaves = lift_leaves_keyvalue(compact, key, leaves, 𝕃ₒ)
lifted_leaves === nothing && return

result_t = Union{}
for v in values(lifted_leaves)
v === nothing && return
result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
end

lifted_val = perform_lifting!(compact,
visited_philikes, key, result_t, lifted_leaves, collection, nothing)

compact[idx] = lifted_val === nothing ? nothing : Expr(:call, Core.tuple, lifted_val.val)
if lifted_val !== nothing
if !(𝕃ₒ, compact[SSAValue(idx)][:type], result_t)
compact[SSAValue(idx)][:flag] |= IR_FLAG_REFINED
end
end

return
end

# TODO: We could do the whole lifing machinery here, but really all
# we want to do is clean this up when it got inserted by inlining,
# which always targets simple `svec` call or `_compute_sparams`,
Expand Down Expand Up @@ -1004,7 +1081,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
for ((_, idx), stmt) in compact
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
isa(stmt, Expr) || continue
is_setfield = is_isdefined = is_finalizer = false
is_setfield = is_isdefined = is_finalizer = is_keyvalue_get = false
field_ordering = :unspecified
if is_known_call(stmt, setfield!, compact)
4 <= length(stmt.args) <= 5 || continue
Expand Down Expand Up @@ -1094,6 +1171,9 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
lift_comparison!(isa, compact, idx, stmt, 𝕃ₒ)
elseif is_known_call(stmt, Core.ifelse, compact)
fold_ifelse!(compact, idx, stmt)
elseif is_known_invoke_or_call(stmt, Core.OptimizedGenerics.KeyValue.get, compact)
2 == (length(stmt.args) - (isexpr(stmt, :invoke) ? 2 : 1)) || continue
lift_keyvalue_get!(compact, idx, stmt, 𝕃ₒ)
elseif isexpr(stmt, :new)
refine_new_effects!(𝕃ₒ, compact, idx, stmt)
end
Expand Down
129 changes: 69 additions & 60 deletions base/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -887,10 +887,35 @@ _similar_for(c::AbstractDict, ::Type{T}, itr, isz, len) where {T} =

include("hamt.jl")
using .HashArrayMappedTries
using Core.OptimizedGenerics: KeyValue
const HAMT = HashArrayMappedTries

struct PersistentDict{K,V} <: AbstractDict{K,V}
trie::HAMT.HAMT{K,V}
# Serves as a marker for an empty initialization
@noinline function KeyValue.set(::Type{PersistentDict{K, V}}) where {K, V}
new{K, V}(HAMT.HAMT{K,V}())
end
@noinline function KeyValue.set(::Type{PersistentDict{K, V}}, ::Nothing, key, val) where {K, V}
new{K, V}(HAMT.HAMT{K, V}(key => val))
end
@noinline function KeyValue.set(dict::PersistentDict{K, V}, key, val) where {K, V}
trie = dict.trie
h = HAMT.HashState(key)
found, present, trie, i, bi, top, hs = HAMT.path(trie, key, h, #=persistent=# true)
HAMT.insert!(found, present, trie, i, bi, hs, val)
return new{K, V}(top)
end
@noinline function KeyValue.set(dict::PersistentDict{K, V}, key) where {K, V}
trie = dict.trie
h = HAMT.HashState(key)
found, present, trie, i, bi, top, _ = HAMT.path(trie, key, h, #=persistent=# true)
if found && present
deleteat!(trie.data, i)
HAMT.unset!(trie, bi)
end
return new{K, V}(top)
end
end

"""
Expand Down Expand Up @@ -925,19 +950,27 @@ Base.PersistentDict{Symbol, Int64} with 1 entry:
"""
PersistentDict

PersistentDict{K,V}() where {K,V} = PersistentDict(HAMT.HAMT{K,V}())
PersistentDict{K,V}(KV::Pair) where {K,V} = PersistentDict(HAMT.HAMT{K,V}(KV))
PersistentDict(KV::Pair{K,V}) where {K,V} = PersistentDict(HAMT.HAMT{K,V}(KV))
PersistentDict{K,V}() where {K, V} = KeyValue.set(PersistentDict{K,V})
function PersistentDict{K,V}(KV::Pair) where {K,V}
KeyValue.set(
PersistentDict{K, V},
nothing,
KV...)
end
function PersistentDict(KV::Pair{K,V}) where {K,V}
KeyValue.set(
PersistentDict{K, V},
nothing,
KV...)
end
PersistentDict(dict::PersistentDict, pair::Pair) = PersistentDict(dict, pair...)
PersistentDict{K,V}(dict::PersistentDict{K,V}, pair::Pair) where {K,V} = PersistentDict(dict, pair...)


function PersistentDict(dict::PersistentDict{K,V}, key, val) where {K,V}
key = convert(K, key)
val = convert(V, val)
trie = dict.trie
h = HAMT.HashState(key)
found, present, trie, i, bi, top, hs = HAMT.path(trie, key, h, #=persistent=# true)
HAMT.insert!(found, present, trie, i, bi, hs, val)
return PersistentDict(top)
return KeyValue.set(dict, key, val)
end

function PersistentDict{K,V}(KV::Pair, rest::Pair...) where {K,V}
Expand All @@ -959,84 +992,60 @@ end
eltype(::PersistentDict{K,V}) where {K,V} = Pair{K,V}

function in(key_val::Pair{K,V}, dict::PersistentDict{K,V}, valcmp=(==)) where {K,V}
trie = dict.trie
if HAMT.islevel_empty(trie)
return false
end

key, val = key_val

h = HAMT.HashState(key)
found, present, trie, i, _, _, _ = HAMT.path(trie, key, h)
if found && present
leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V}
return valcmp(val, leaf.val) && return true
end
return false
found = KeyValue.get(dict, key)
found === nothing && return false
return valcmp(val, only(found))
end

function haskey(dict::PersistentDict{K}, key::K) where K
trie = dict.trie
h = HAMT.HashState(key)
found, present, _, _, _, _, _ = HAMT.path(trie, key, h)
return found && present
return KeyValue.get(dict, key) !== nothing
end

function getindex(dict::PersistentDict{K,V}, key::K) where {K,V}
trie = dict.trie
if HAMT.islevel_empty(trie)
throw(KeyError(key))
end
h = HAMT.HashState(key)
found, present, trie, i, _, _, _ = HAMT.path(trie, key, h)
if found && present
leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V}
return leaf.val
end
throw(KeyError(key))
found = KeyValue.get(dict, key)
found === nothing && throw(KeyError(key))
return only(found)
end

function get(dict::PersistentDict{K,V}, key::K, default) where {K,V}
trie = dict.trie
if HAMT.islevel_empty(trie)
return default
end
h = HAMT.HashState(key)
found, present, trie, i, _, _, _ = HAMT.path(trie, key, h)
if found && present
leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V}
return leaf.val
end
return default
found = KeyValue.get(dict, key)
found === nothing && return default
return only(found)
end

function get(default::Callable, dict::PersistentDict{K,V}, key::K) where {K,V}
@noinline function KeyValue.get(dict::PersistentDict{K, V}, key) where {K, V}
trie = dict.trie
if HAMT.islevel_empty(trie)
return default
return nothing
end
h = HAMT.HashState(key)
found, present, trie, i, _, _, _ = HAMT.path(trie, key, h)
if found && present
leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V}
return leaf.val
return (leaf.val,)
end
return default()
return nothing
end

iterate(dict::PersistentDict, state=nothing) = HAMT.iterate(dict.trie, state)
@noinline function KeyValue.get(default, dict::PersistentDict, key)
found = KeyValue.get(dict, key)
found === nothing && return default()
return only(found)
end

function get(default::Callable, dict::PersistentDict{K,V}, key::K) where {K,V}
found = KeyValue.get(dict, key)
found === nothing && return default()
return only(found)
end

function delete(dict::PersistentDict{K}, key::K) where K
trie = dict.trie
h = HAMT.HashState(key)
found, present, trie, i, bi, top, _ = HAMT.path(trie, key, h, #=persistent=# true)
if found && present
deleteat!(trie.data, i)
HAMT.unset!(trie, bi)
end
return PersistentDict(top)
return KeyValue.set(dict, key)
end

iterate(dict::PersistentDict, state=nothing) = HAMT.iterate(dict.trie, state)

length(dict::PersistentDict) = HAMT.length(dict.trie)
isempty(dict::PersistentDict) = HAMT.isempty(dict.trie)
empty(::PersistentDict, ::Type{K}, ::Type{V}) where {K, V} = PersistentDict{K, V}()
12 changes: 9 additions & 3 deletions base/hamt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,18 @@ mutable struct HAMT{K, V}
HAMT{K,V}(data, bitmap) where {K,V} = new{K,V}(data, bitmap)
HAMT{K, V}() where {K, V} = new{K,V}(Vector{Union{Leaf{K, V}, HAMT{K, V}}}(undef, 0), zero(BITMAP))
end
function HAMT{K,V}((k,v)::Pair) where {K, V}
k = convert(K, k)
v = convert(V, v)

@Base.assume_effects :nothrow function init_hamt(K, V, k, v)
# For a single element we can't have a hash-collision
trie = HAMT{K,V}(Vector{Union{Leaf{K, V}, HAMT{K, V}}}(undef, 1), zero(BITMAP))
trie.data[1] = Leaf{K,V}(k,v)
return trie
end

function HAMT{K,V}((k,v)::Pair) where {K, V}
k = convert(K, k)
v = convert(V, v)
trie = init_hamt(K, V, k, v)
bi = BitmapIndex(HashState(k))
set!(trie, bi)
return trie
Expand Down
57 changes: 57 additions & 0 deletions base/optimized_generics.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module OptimizedGenerics

# This file defines interfaces that are recognized and optimized by the compiler
# They are intended to be used by data structure implementations that wish to
# opt into some level of compiler optimizations. These interfaces are
# EXPERIMENTAL and currently intended for use by Base only. They are subject
# to change or removal without notice. It is undefined behavior to add methods
# to these generics that do not conform to the specified interface.
#
# The intended way to use these generics is that data structures will provide
# appropriate implementations for a generic. In the absence of compiler
# optimizations, these behave like regular methods. However, the compiler is
# semantically allowed to perform certain structural optimizations on
# appropriate combinations of these intrinsics without proving correctness.

# Compiler-recognized generics for immutable key-value stores (dicts, etc.)
"""
module KeyValue
Implements a key-value like interface where the compiler has liberty to perform
the following transformations. The core optimization semantically allowed for
the compiler is:
get(set(x, key, val), key) -> (val,)
where the compiler will recursively look through `x`. Keys are compared by
egality.
Implementations must observe the following constraints:
1. It is undefined behavior for `get` not to return the exact (by egality) val
stored for a given `key`.
"""
module KeyValue
"""
set(collection, [key [, val]])
set(T, collection, key, val)
Set the `key` in `collection` to `val`. If `val` is omitted, deletes the
value from the collection. If `key` is omitted as well, deletes all elements
of the collection.
"""
function set end

"""
get(collection, key)
Retrieve the value corresponding to `key` in `collection` as a single
element tuple or `nothing` if no value corresponding to the key was found.
`key`s are compared by egal.
"""
function get end
end

end
Loading

0 comments on commit ddcce82

Please sign in to comment.