Skip to content

Commit

Permalink
Merge pull request #24337 from JuliaLang/jn/inference-edges
Browse files Browse the repository at this point in the history
inference: improve edge computations
  • Loading branch information
vtjnash authored Nov 6, 2017
2 parents 6b3dd9d + 86aefe8 commit 28dc9d3
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 104 deletions.
183 changes: 88 additions & 95 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1859,19 +1859,17 @@ function abstract_call_gf_by_type(@nospecialize(f), @nospecialize(atype), sv::In
applicable = _methods_by_ftype(argtype, sv.params.MAX_METHODS, sv.params.world, min_valid, max_valid)
if applicable === false
# this means too many methods matched
# (assume this will always be true, so we don't compute / update valid age in this case)
return Any
end
end
update_valid_age!(min_valid[1], max_valid[1], sv)
applicable = applicable::Array{Any,1}
napplicable = length(applicable)
fullmatch = false
rettype = Bottom
for i in 1:napplicable
match = applicable[i]::SimpleVector
method = match[3]::Method
if !fullmatch && (argtype <: method.sig)
fullmatch = true
end
sig = match[1]
sigtuple = unwrap_unionall(sig)::DataType
splitunions = false
Expand All @@ -1891,116 +1889,106 @@ function abstract_call_gf_by_type(@nospecialize(f), @nospecialize(atype), sv::In
rettype === Any && break
end
end
if !(fullmatch || rettype === Any)
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
add_mt_backedge(ftname.mt, argtype, sv)
update_valid_age!(min_valid[1], max_valid[1], sv)
if !(rettype === Any)
fullmatch = false
for i in napplicable:-1:1
match = applicable[i]::SimpleVector
method = match[3]::Method
if atype <: method.sig
fullmatch = true
break
end
end
if !fullmatch
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
add_mt_backedge(ftname.mt, atype, sv)
end
end
#print("=> ", rettype, "\n")
return rettype
end

function abstract_call_method(method::Method, @nospecialize(f), @nospecialize(sig), sparams::SimpleVector, sv::InferenceState)
limited = sv.limited
# If we are operating without inference limits,
# see if we need to enable those.
# The limit will be imposed if we recur on the same method.
topmost = nothing
if !limited && !istopfunction(_topmod(sv), f, :promote_typeof)
# since promote_typeof signature may be used with many arguments, here we'll just assume it is defined non-recursively
# otherwise: limit argument type tuple growth of all other functions
# look through the parents list to find the topmost
# function call to the same method
cyclei = 0
infstate = sv
while infstate !== nothing
infstate = infstate::InferenceState
if method === infstate.linfo.def
if infstate.linfo.specTypes == sig
# avoid widening when detecting self-recursion
# TODO: merge call cycle and return right away
# TODO: this'll improve convergence speed and give better results,
# but is it correct and valid?
limited = false
break
# Limit argument type tuple growth of functions:
# look through the parents list to see if there's a call to the same method
# and from the same method.
# Returns the topmost occurrence of that repeated edge.
cyclei = 0
infstate = sv
while !(infstate === nothing)
infstate = infstate::InferenceState
if method === infstate.linfo.def
if infstate.linfo.specTypes == sig
# avoid widening when detecting self-recursion
# TODO: merge call cycle and return right away
topmost = nothing
break
end
if topmost === nothing
# inspect the parent of this edge,
# to see if they are the same Method as sv
# in which case we'll need to ensure it is convergent
# otherwise, we don't
for parent in infstate.callers_in_cycle
# check in the cycle list first
# all items in here are mutual parents of all others
if parent.linfo.def === sv.linfo.def
topmost = infstate
break
end
end
if topmost === nothing
# inspect the parent of this edge,
# to see if they are the same Method as sv
# in which case we'll need to ensure it is convergent
# otherwise, we don't
for parent in infstate.callers_in_cycle
# check in the cycle list first
# all items in here are mutual parents of all others
if parent.linfo.def === sv.linfo.def
limited = true
let parent = infstate.parent
# then check the parent link
if topmost === nothing && parent !== nothing
parent = parent::InferenceState
if parent.cached && parent.linfo.def === sv.linfo.def
topmost = infstate
break
end
end
let parent = infstate.parent
# then check the parent link
if topmost === nothing && parent !== nothing
parent = parent::InferenceState
if parent.cached && parent.linfo.def === sv.linfo.def
limited = true
topmost = infstate
end
end
end
end
end
# iterate through the cycle before walking to the parent
if cyclei < length(infstate.callers_in_cycle)
cyclei += 1
infstate = infstate.callers_in_cycle[cyclei]
else
cyclei = 0
infstate = infstate.parent
end
end
# iterate through the cycle before walking to the parent
if cyclei < length(infstate.callers_in_cycle)
cyclei += 1
infstate = infstate.callers_in_cycle[cyclei]
else
cyclei = 0
infstate = infstate.parent
end
end

if limited
if !(topmost === nothing)
topmost = topmost::InferenceState
sigtuple = unwrap_unionall(sig)::DataType
msig = unwrap_unionall(method.sig)::DataType
spec_len = length(msig.parameters) + 1
ls = length(sigtuple.parameters)
if method === sv.linfo.def
# direct self-recursion permits much greater use of reducers
# without using non-local state (just the total edge)
# Under direct self-recursion, permit much greater use of reducers.
# here we assume that complexity(specTypes) :>= complexity(sig)
comparison = sv.linfo.specTypes
l_comparison = length(unwrap_unionall(comparison).parameters)
spec_len = max(spec_len, l_comparison)
else
comparison = method.sig
end
# see if the type is too big, and limit it if required
# see if the type is actually too big (relative to the caller), and limit it if required
newsig = limit_type_size(sig, comparison, sv.linfo.specTypes, spec_len)

if newsig !== sig
if !sv.limited
# continue inference, but limit parameter complexity to ensure (quick) convergence
topmost = topmost::InferenceState
infstate = sv
while infstate !== topmost.parent
# TODO: avoid this non-local mutation
infstate.limited = true
if infstate.parent !== nothing
infstate.optimize = false
end
for infstate_cycle in infstate.callers_in_cycle
infstate_cycle.limited = true
if infstate_cycle.parent !== nothing
infstate_cycle.optimize = false
end
end
infstate = infstate.parent
infstate === nothing && break
# continue inference, but note that we've limited parameter complexity
# on this call (to ensure convergence), so that we don't cache this result
infstate = sv
topmost = topmost::InferenceState
while !(infstate.parent === topmost.parent)
infstate.limited = true
for infstate_cycle in infstate.callers_in_cycle
infstate_cycle.limited = true
end
# TODO: break here and restart from "topmost" call-site
infstate = infstate.parent
end
sig = newsig
sparams = svec()
Expand Down Expand Up @@ -3084,13 +3072,6 @@ function code_for_method(method::Method, @nospecialize(atypes), sparams::SimpleV
return ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any, UInt), method, atypes, sparams, world)
end

function typeinf_active(linfo::MethodInstance, sv::InferenceState)
for infstate in sv.callers_in_cycle
linfo === infstate.linfo && return infstate
end
return nothing
end

function add_backedge!(frame::InferenceState, caller::InferenceState, currpc::Int)
update_valid_age!(frame, caller)
backedge = (caller, currpc)
Expand Down Expand Up @@ -3184,8 +3165,15 @@ function typeinf_edge(method::Method, @nospecialize(atypes), sparams::SimpleVect
end
end
end
frame = resolve_call_cycle!(code, caller)
if !caller.cached && caller.parent === nothing
# this caller exists to return to the user
# (if we asked resolve_call_cyle, it might instead detect that there is a cycle that it can't merge)
frame = nothing
else
frame = resolve_call_cycle!(code, caller)
end
if frame === nothing
# completely new
code.inInference = true
frame = InferenceState(code, #=optimize=#true, #=cached=#true, caller.params) # always optimize and cache edge targets
if frame === nothing
Expand Down Expand Up @@ -3389,6 +3377,11 @@ function typeinf_work(frame::InferenceState)
elseif hd === :return
pc´ = n + 1
rt = abstract_eval(stmt.args[1], s[pc], frame)
if !isa(rt, Const) && !isa(rt, Type)
# only propagate information we know we can store
# and is valid inter-procedurally
rt = widenconst(rt)
end
if tchanged(rt, frame.bestguess)
# new (wider) return type for frame
frame.bestguess = tmerge(frame.bestguess, rt)
Expand Down Expand Up @@ -3565,7 +3558,11 @@ function optimize(me::InferenceState)

# run optimization passes on fulltree
force_noinline = true
if me.optimize
if me.limited && me.parent !== nothing
# a top parent will be cached still, but not this intermediate work
me.cached = false
me.linfo.inInference = false
elseif me.optimize
opt = OptimizationState(me)
# This pass is required for the AST to be valid in codegen
# if any `SSAValue` is created by type inference. Ref issue #6068
Expand Down Expand Up @@ -3593,10 +3590,6 @@ function optimize(me::InferenceState)
reindex_labels!(opt)
me.min_valid = opt.min_valid
me.max_valid = opt.max_valid
elseif me.cached && me.parent !== nothing
# top parent will be cached still, but not this intermediate work
me.cached = false
me.linfo.inInference = false
end

# convert all type information into the form consumed by the code-generator
Expand Down
13 changes: 6 additions & 7 deletions base/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,21 +200,20 @@ function promote end

function _promote(x::T, y::S) where {T,S}
@_inline_meta
(convert(promote_type(T,S),x), convert(promote_type(T,S),y))
R = promote_type(T, S)
return (convert(R, x), convert(R, y))
end
promote_typeof(x) = typeof(x)
promote_typeof(x, xs...) = (@_inline_meta; promote_type(typeof(x), promote_typeof(xs...)))
function _promote(x, y, z)
@_inline_meta
(convert(promote_typeof(x,y,z), x),
convert(promote_typeof(x,y,z), y),
convert(promote_typeof(x,y,z), z))
R = promote_typeof(x, y, z)
return (convert(R, x), convert(R, y), convert(R, z))
end
function _promote(x, y, zs...)
@_inline_meta
(convert(promote_typeof(x,y,zs...), x),
convert(promote_typeof(x,y,zs...), y),
convert(Tuple{Vararg{promote_typeof(x,y,zs...)}}, zs)...)
R = promote_typeof(x, y, zs...)
return (convert(R, x), convert(R, y), convert(Tuple{Vararg{R}}, zs)...)
end
# TODO: promote(x::T, ys::T...) where {T} here to catch all circularities?

Expand Down
6 changes: 4 additions & 2 deletions test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -885,8 +885,10 @@ f21771(::Val{U}) where {U} = Tuple{g21771(U)}

# issue #21653
# ensure that we don't try to resolve cycles using uncached edges
# but which also means we should still be storing the inference result from inferring the cycle
f21653() = f21653()
@test code_typed(f21653, Tuple{}, optimize=false)[1] isa Pair{CodeInfo, typeof(Union{})}
@test which(f21653, ()).specializations.func.rettype === Union{}

# ensure _apply can "see-through" SSAValue to infer precise container types
let f, m
Expand Down Expand Up @@ -993,13 +995,13 @@ copy_dims_out(out) = ()
copy_dims_out(out, dim::Int, tail...) = copy_dims_out((out..., dim), tail...)
copy_dims_out(out, dim::Colon, tail...) = copy_dims_out((out..., dim), tail...)
@test Base.return_types(copy_dims_out, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}]
@test all(m -> 10 < count_specializations(m) < 25, methods(copy_dims_out))
@test all(m -> 20 < count_specializations(m) < 45, methods(copy_dims_out))

copy_dims_pair(out) = ()
copy_dims_pair(out, dim::Int, tail...) = copy_dims_pair(out => dim, tail...)
copy_dims_pair(out, dim::Colon, tail...) = copy_dims_pair(out => dim, tail...)
@test Base.return_types(copy_dims_pair, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}]
@test all(m -> 5 < count_specializations(m) < 25, methods(copy_dims_pair))
@test all(m -> 10 < count_specializations(m) < 35, methods(copy_dims_pair))

@test isdefined_tfunc(typeof(NamedTuple()), Const(0)) === Const(false)
@test isdefined_tfunc(typeof(NamedTuple()), Const(1)) === Const(false)
Expand Down

0 comments on commit 28dc9d3

Please sign in to comment.