Skip to content

Commit

Permalink
inference: accelerate type-limits under wide-recursion
Browse files Browse the repository at this point in the history
when we hit union-splitting, we need to ensure type limits are very aggressive
and preferably also independent of the height of the recursion chain

fix #31572
  • Loading branch information
vtjnash committed Apr 18, 2019
1 parent 1dc8236 commit 7f89529
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 26 deletions.
55 changes: 31 additions & 24 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nosp
nonbot = 0 # the index of the only non-Bottom inference result if > 0
seen = 0 # number of signatures actually inferred
istoplevel = sv.linfo.def isa Module
multiple_matches = napplicable > 1
for i in 1:napplicable
match = applicable[i]::SimpleVector
method = match[3]::Method
Expand All @@ -80,7 +81,7 @@ function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nosp
if splitunions
splitsigs = switchtupleunion(sig)
for sig_n in splitsigs
rt, edgecycle1, edge = abstract_call_method(method, sig_n, svec(), sv)
rt, edgecycle1, edge = abstract_call_method(method, sig_n, svec(), multiple_matches, sv)
if edge !== nothing
push!(edges, edge)
end
Expand All @@ -89,7 +90,7 @@ function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nosp
this_rt === Any && break
end
else
this_rt, edgecycle1, edge = abstract_call_method(method, sig, match[2]::SimpleVector, sv)
this_rt, edgecycle1, edge = abstract_call_method(method, sig, match[2]::SimpleVector, multiple_matches, sv)
edgecycle |= edgecycle1::Bool
if edge !== nothing
push!(edges, edge)
Expand Down Expand Up @@ -227,7 +228,7 @@ function abstract_call_method_with_const_args(@nospecialize(rettype), @nospecial
return result
end

function abstract_call_method(method::Method, @nospecialize(sig), sparams::SimpleVector, sv::InferenceState)
function abstract_call_method(method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
if method.name === :depwarn && isdefined(Main, :Base) && method.module === Main.Base
return Any, false, nothing
end
Expand Down Expand Up @@ -266,30 +267,36 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
inf_method2 = infstate.src.method_for_inference_limit_heuristics # limit only if user token match
inf_method2 isa Method || (inf_method2 = nothing) # Union{Method, Nothing}
if topmost === nothing && method2 === inf_method2
# inspect the parent of this edge,
# to see if they are the same Method as sv
# in which case we'll need to ensure it is convergent
# otherwise, we don't
for parent in infstate.callers_in_cycle
# check in the cycle list first
# all items in here are mutual parents of all others
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
if parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
topmost = infstate
edgecycle = true
break
end
end
let parent = infstate.parent
# then check the parent link
if topmost === nothing && parent !== nothing
parent = parent::InferenceState
if hardlimit
topmost = infstate
edgecycle = true
else
# if this is a soft limit,
# also inspect the parent of this edge,
# to see if they are the same Method as sv
# in which case we'll need to ensure it is convergent
# otherwise, we don't
for parent in infstate.callers_in_cycle
# check in the cycle list first
# all items in here are mutual parents of all others
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
if (parent.cached || parent.limited) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
if parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
topmost = infstate
edgecycle = true
break
end
end
let parent = infstate.parent
# then check the parent link
if topmost === nothing && parent !== nothing
parent = parent::InferenceState
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
if (parent.cached || parent.limited) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
topmost = infstate
edgecycle = true
end
end
end
end
Expand Down Expand Up @@ -321,7 +328,7 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
comparison = method.sig
end
# see if the type is actually too big (relative to the caller), and limit it if required
newsig = limit_type_size(sig, comparison, sv.linfo.specTypes, sv.params.TUPLE_COMPLEXITY_LIMIT_DEPTH, spec_len)
newsig = limit_type_size(sig, comparison, hardlimit ? comparison : sv.linfo.specTypes, sv.params.TUPLE_COMPLEXITY_LIMIT_DEPTH, spec_len)

if newsig !== sig
# continue inference, but note that we've limited parameter complexity
Expand Down
41 changes: 39 additions & 2 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1050,13 +1050,13 @@ copy_dims_out(out) = ()
copy_dims_out(out, dim::Int, tail...) = copy_dims_out((out..., dim), tail...)
copy_dims_out(out, dim::Colon, tail...) = copy_dims_out((out..., dim), tail...)
@test Base.return_types(copy_dims_out, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}]
@test all(m -> 20 < count_specializations(m) < 45, methods(copy_dims_out))
@test all(m -> 4 < count_specializations(m) < 15, methods(copy_dims_out)) # currently about 5

copy_dims_pair(out) = ()
copy_dims_pair(out, dim::Int, tail...) = copy_dims_pair(out => dim, tail...)
copy_dims_pair(out, dim::Colon, tail...) = copy_dims_pair(out => dim, tail...)
@test Base.return_types(copy_dims_pair, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}]
@test all(m -> 10 < count_specializations(m) < 35, methods(copy_dims_pair))
@test all(m -> 5 < count_specializations(m) < 15, methods(copy_dims_pair)) # currently about 7

@test isdefined_tfunc(typeof(NamedTuple()), Const(0)) === Const(false)
@test isdefined_tfunc(typeof(NamedTuple()), Const(1)) === Const(false)
Expand Down Expand Up @@ -2348,3 +2348,40 @@ function gen_nodes(qty::Integer) :: AbstractNode
end
end
@test count(==('}'), string(I31663.gen_nodes(50))) == 1275

# issue #31572
struct MixedKeyDict{T<:Tuple} #<: AbstractDict{Any,Any}
dicts::T
end
Base.merge(f::Function, d::MixedKeyDict, others::MixedKeyDict...) = _merge(f, (), d.dicts, (d->d.dicts).(others)...)
Base.merge(f, d::MixedKeyDict, others::MixedKeyDict...) = _merge(f, (), d.dicts, (d->d.dicts).(others)...)
function _merge(f, res, d, others...)
ofsametype, remaining = _alloftype(Base.heads(d), ((),), others...)
return _merge(f, (res..., merge(f, ofsametype...)), Base.tail(d), remaining...)
end
_merge(f, res, ::Tuple{}, others...) = _merge(f, res, others...)
_merge(f, res, d) = MixedKeyDict((res..., d...))
_merge(f, res, ::Tuple{}) = MixedKeyDict(res)
function _alloftype(ofdesiredtype::Tuple{Vararg{D}}, accumulated, d::Tuple{D,Vararg}, others...) where D
return _alloftype((ofdesiredtype..., first(d)),
(Base.front(accumulated)..., (last(accumulated)..., Base.tail(d)...), ()),
others...)
end
function _alloftype(ofdesiredtype, accumulated, d, others...)
return _alloftype(ofdesiredtype,
(Base.front(accumulated)..., (last(accumulated)..., first(d))),
Base.tail(d), others...)
end
function _alloftype(ofdesiredtype, accumulated, ::Tuple{}, others...)
return _alloftype(ofdesiredtype,
(accumulated..., ()),
others...)
end
_alloftype(ofdesiredtype, accumulated) = ofdesiredtype, Base.front(accumulated)
let
d = MixedKeyDict((Dict(1 => 3), Dict(4. => 2)))
e = MixedKeyDict((Dict(1 => 7), Dict(5. => 9)))
@test merge(+, d, e).dicts == (Dict(1 => 10), Dict(4.0 => 2, 5.0 => 9))
f = MixedKeyDict((Dict(2 => 7), Dict(5. => 11)))
@test merge(+, d, e, f).dicts == (Dict(1 => 10, 2 => 7), Dict(4.0 => 2, 5.0 => 20))
end

0 comments on commit 7f89529

Please sign in to comment.