diff --git a/Project.toml b/Project.toml index d3e2cdb..7d341c1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Salsa" uuid = "1fbf2c77-44e2-4d5d-8131-0fa618a5c278" authors = ["Nathan Daly ", "Todd J. Green "] -version = "2.1.0" +version = "2.1.1" [deps] ExceptionUnwrapping = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" diff --git a/src/inspect.jl b/src/inspect.jl index 007e638..8cd2251 100644 --- a/src/inspect.jl +++ b/src/inspect.jl @@ -57,8 +57,8 @@ function build_graph(rt::Salsa.Runtime; module_boxes=false) println(io, "{") println(io, "rank=sink;") end - for (input_key, name) in inputs - println(io, """$(vertex_name(input_key)) [label="$name"]""") + for (input_key_func, name) in inputs + println(io, """$(vertex_name(input_key_func)) [label="$name"]""") end if !module_boxes println(io, "}") @@ -76,17 +76,19 @@ function build_graph(rt::Salsa.Runtime; module_boxes=false) end function _build_graph(io::IO, st::DefaultStorage, seen::_IdSet, modules_map::Dict, out_edges::Dict, out_inputs::Dict) - for ((F,k),v) in st.inputs_map + for (input_key,v) in st.inputs_map + F = key_function(input_key) m = F.name.module out_inputs[F] = "@input $(nameof(F.instance))" push!(get!(modules_map, m, Set([])), F) end - for (derived_key,derived_map) in st.derived_function_maps - _build_graph(io, st, derived_key, derived_map, seen, modules_map, out_edges) + for (derived_key_t,derived_map) in st.derived_function_maps + _build_graph(io, st, derived_key_t, derived_map, seen, modules_map, out_edges) end end +key_function(::Union{Salsa.InputKey{F}, Salsa.DerivedKey{F}}) where F = F function vertex_name(::Salsa.InputKey{F}) where F return vertex_name(F) end @@ -94,17 +96,17 @@ function vertex_name(x::Any)::String return "v$(objectid(x))" end -function _build_graph(io, st::DefaultStorage, derived_key::Salsa.DerivedKey{F,TT}, derived_map::Dict, +function _build_graph(io, st::DefaultStorage, derived_key_t::Type{<:Salsa.DerivedKey{F,TT}}, derived_map::Dict, seen::_IdSet{Any}, modules_map::Dict{Module,Set}, edges::Dict{Pair, Int}) where {F,TT} - in(derived_key, seen) && return - push!(seen, derived_key) + in(derived_key_t, seen) && return + push!(seen, derived_key_t) m = methods(F.instance).mt.module - push!(get!(modules_map, m, Set([])), derived_key) - key_str = _derived_key_as_call_str(derived_key) - println(io, "$(vertex_name(derived_key)) [shape=rect,label=\"$key_str\"]") + push!(get!(modules_map, m, Set([])), F) + key_str = _derived_key_as_call_str(derived_key_t) + println(io, "$(vertex_name(F)) [shape=rect,label=\"$key_str\"]") for (k,v) in derived_map for d in v.dependencies - edge = (derived_key) => (d.key) + edge = (F) => (key_function(d)) count = get!(edges, edge, 0) + 1 edges[edge] = count end @@ -112,7 +114,7 @@ function _build_graph(io, st::DefaultStorage, derived_key::Salsa.DerivedKey{F,TT #_build_graph(io, s.leaf_node, seen) end -function _derived_key_as_call_str(key::Salsa.DerivedKey{F,TT})::String where {F,TT} +function _derived_key_as_call_str(::Type{<:Salsa.DerivedKey{F,TT}})::String where {F,TT} f = isdefined(F, :instance) ? nameof(F.instance) : nameof(F) argsexprs = [ :rt,