From dba4df7f2aac704eccda71ac64c525345dd3cbda Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Tue, 8 Mar 2022 20:30:11 +0900 Subject: [PATCH] `AbstractInterpreter`: enable selective pure/concrete eval for external `AbstractInterpreter` with overlayed method table Built on top of #44511 and #44561, and solves . This commit allows external `AbstractInterpreter` to selectively use pure/concrete evals even if it uses an overlayed method table. More specifically, such `AbstractInterpreter` can use pure/concrete evals as far as any callees used in a call in question doesn't come from the overlayed method table: ```julia @test Base.return_types((), MTOverlayInterp()) do isbitstype(Int) ? nothing : missing end == Any[Nothing] Base.@assume_effects :terminates_globally function issue41694(x) res = 1 1 < x < 20 || throw("bad") while x > 1 res *= x x -= 1 end return res end @test Base.return_types((), MTOverlayInterp()) do issue41694(3) == 6 ? nothing : missing end == Any[Nothing] ``` In order to check if a call is tainted by any overlayed call, our effect system now additionally tracks `overlayed::Bool` property. This effect property is required to prevents concrete-eval in the following kind of situation: ```julia strangesin(x) = sin(x) @overlay OverlayedMT strangesin(x::Float64) = iszero(x) ? nothing : cos(x) Base.@assume_effects :total totalcall(f, args...) = f(args...) @test Base.return_types(; interp=MTOverlayInterp()) do # we need to disable partial pure/concrete evaluation when tainted by any overlayed call if totalcall(strangesin, 1.0) == cos(1.0) return nothing else return missing end end |> only === Nothing ``` --- base/compiler/abstractinterpretation.jl | 117 +++++++++++++++--------- base/compiler/inferencestate.jl | 2 +- base/compiler/methodtable.jl | 60 +++++++----- base/compiler/ssair/show.jl | 1 + base/compiler/tfuncs.jl | 12 ++- base/compiler/typeinfer.jl | 6 +- base/compiler/types.jl | 36 +++++--- test/compiler/AbstractInterpreter.jl | 46 ++++++++-- 8 files changed, 182 insertions(+), 98 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 466ce93746173e..770e1b0d708761 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -50,12 +50,27 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), # At this point we are guaranteed to end up throwing on this path, # which is all that's required for :consistent-cy. Of course, we don't # know anything else about this statement. - tristate_merge!(sv, Effects(Effects(), consistent=ALWAYS_TRUE)) + overlayed = true + if isoverlayed(method_table(interp)) + if !sv.ipo_effects.overlayed + # as we may want to concrete-evaluate this frame, try an additional effort to + # assert that this call isn't overlayed rather than just handling it conservatively + matches = find_matching_methods(arginfo.argtypes, atype, method_table(interp), + InferenceParams(interp).MAX_UNION_SPLITTING, max_methods) + if !isa(matches, FailedMethodMatch) && matches.overlayed + overlayed = false + end + end + else + overlayed = false + end + tristate_merge!(sv, Effects(; consistent=ALWAYS_TRUE, overlayed)) return CallMeta(Any, false) end argtypes = arginfo.argtypes - matches = find_matching_methods(argtypes, atype, method_table(interp), InferenceParams(interp).MAX_UNION_SPLITTING, max_methods) + matches = find_matching_methods(argtypes, atype, method_table(interp), + InferenceParams(interp).MAX_UNION_SPLITTING, max_methods) if isa(matches, FailedMethodMatch) add_remark!(interp, sv, matches.reason) tristate_merge!(sv, Effects()) @@ -72,6 +87,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), any_const_result = false const_results = Union{InferenceResult,Nothing,ConstResult}[] multiple_matches = napplicable > 1 + if matches.overlayed + # currently we don't have a good way to execute the overlayed method definition, + # so we should give up pure/concrete eval when any of the matched methods is overlayed + f = nothing + tristate_merge!(sv, Effects(EFFECTS_TOTAL; overlayed=true)) + end val = pure_eval_call(interp, f, applicable, arginfo, sv) val !== nothing && return CallMeta(val, MethodResultPure(info)) # TODO: add some sort of edge(s) @@ -102,7 +123,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i] this_arginfo = ArgInfo(fargs, this_argtypes) - const_call_result = abstract_call_method_with_const_args(interp, result, f, this_arginfo, match, sv) + const_call_result = abstract_call_method_with_const_args(interp, result, + f, this_arginfo, match, sv) effects = result.edge_effects const_result = nothing if const_call_result !== nothing @@ -144,7 +166,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), # this is in preparation for inlining, or improving the return result this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i] this_arginfo = ArgInfo(fargs, this_argtypes) - const_call_result = abstract_call_method_with_const_args(interp, result, f, this_arginfo, match, sv) + const_call_result = abstract_call_method_with_const_args(interp, result, + f, this_arginfo, match, sv) effects = result.edge_effects const_result = nothing if const_call_result !== nothing @@ -189,11 +212,11 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end if seen != napplicable - tristate_merge!(sv, Effects()) + tristate_merge!(sv, Effects(; overlayed=false)) # already accounted for method overlay above elseif isa(matches, MethodMatches) ? (!matches.fullmatch || any_ambig(matches)) : (!_all(b->b, matches.fullmatches) || any_ambig(matches)) # Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature. - tristate_merge!(sv, Effects(EFFECTS_TOTAL, nothrow=TRISTATE_UNKNOWN)) + tristate_merge!(sv, Effects(EFFECTS_TOTAL; nothrow=TRISTATE_UNKNOWN)) end rettype = from_interprocedural!(rettype, sv, arginfo, conditionals) @@ -228,6 +251,7 @@ struct MethodMatches valid_worlds::WorldRange mt::Core.MethodTable fullmatch::Bool + overlayed::Bool end any_ambig(info::MethodMatchInfo) = info.results.ambig any_ambig(m::MethodMatches) = any_ambig(m.info) @@ -239,6 +263,7 @@ struct UnionSplitMethodMatches valid_worlds::WorldRange mts::Vector{Core.MethodTable} fullmatches::Vector{Bool} + overlayed::Bool end any_ambig(m::UnionSplitMethodMatches) = _any(any_ambig, m.info.matches) @@ -253,16 +278,19 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth valid_worlds = WorldRange() mts = Core.MethodTable[] fullmatches = Bool[] + overlayed = false for i in 1:length(split_argtypes) arg_n = split_argtypes[i]::Vector{Any} sig_n = argtypes_to_type(arg_n) mt = ccall(:jl_method_table_for, Any, (Any,), sig_n) mt === nothing && return FailedMethodMatch("Could not identify method table for call") mt = mt::Core.MethodTable - matches = findall(sig_n, method_table; limit = max_methods) - if matches === missing + result = findall(sig_n, method_table; limit = max_methods) + if result === missing return FailedMethodMatch("For one of the union split cases, too many methods matched") end + matches, overlayedᵢ = result + overlayed |= overlayedᵢ push!(infos, MethodMatchInfo(matches)) for m in matches push!(applicable, m) @@ -288,25 +316,28 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth UnionSplitInfo(infos), valid_worlds, mts, - fullmatches) + fullmatches, + overlayed) else mt = ccall(:jl_method_table_for, Any, (Any,), atype) if mt === nothing return FailedMethodMatch("Could not identify method table for call") end mt = mt::Core.MethodTable - matches = findall(atype, method_table; limit = max_methods) - if matches === missing + result = findall(atype, method_table; limit = max_methods) + if result === missing # this means too many methods matched # (assume this will always be true, so we don't compute / update valid age in this case) return FailedMethodMatch("Too many methods matched") end + matches, overlayed = result fullmatch = _any(match->(match::MethodMatch).fully_covers, matches) return MethodMatches(matches.matches, MethodMatchInfo(matches), matches.valid_worlds, mt, - fullmatch) + fullmatch, + overlayed) end end @@ -613,11 +644,11 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp edge_effects = Effects(edge_effects, terminates=ALWAYS_TRUE) elseif is_effect_overridden(method, :terminates_globally) # this edge is known to terminate - edge_effects = Effects(edge_effects, terminates=ALWAYS_TRUE) + edge_effects = Effects(edge_effects; terminates=ALWAYS_TRUE) elseif edgecycle # Some sort of recursion was detected. Even if we did not limit types, # we cannot guarantee that the call will terminate - edge_effects = Effects(edge_effects, terminates=TRISTATE_UNKNOWN) + edge_effects = Effects(edge_effects; terminates=TRISTATE_UNKNOWN) end return MethodCallResult(rt, edgecycle, edgelimited, edge, edge_effects) end @@ -640,8 +671,8 @@ end function pure_eval_eligible(interp::AbstractInterpreter, @nospecialize(f), applicable::Vector{Any}, arginfo::ArgInfo, sv::InferenceState) - return !isoverlayed(method_table(interp)) && - f !== nothing && + # XXX we need to check that this pure function doesn't call any overlayed method + return f !== nothing && length(applicable) == 1 && is_method_pure(applicable[1]::MethodMatch) && is_all_const_arg(arginfo) @@ -677,8 +708,8 @@ end function concrete_eval_eligible(interp::AbstractInterpreter, @nospecialize(f), result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState) - return !isoverlayed(method_table(interp)) && - f !== nothing && + isoverlayed(method_table(interp)) && result.edge_effects.overlayed && return false + return f !== nothing && result.edge !== nothing && is_total_or_error(result.edge_effects) && is_all_const_arg(arginfo) @@ -1477,7 +1508,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)::Type nargtype = Tuple{ft, nargtype.parameters...} argtype = Tuple{ft, argtype.parameters...} - match, valid_worlds = findsup(types, method_table(interp)) + match, valid_worlds, overlayed = findsup(types, method_table(interp)) match === nothing && return CallMeta(Any, false) update_valid_age!(sv, valid_worlds) method = match.method @@ -1495,7 +1526,8 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn # t, a = ti.parameters[i], argtypes′[i] # argtypes′[i] = t ⊑ a ? t : a # end - const_call_result = abstract_call_method_with_const_args(interp, result, singleton_type(ft′), arginfo, match, sv) + const_call_result = abstract_call_method_with_const_args(interp, result, + overlayed ? nothing : singleton_type(ft′), arginfo, match, sv) const_result = nothing if const_call_result !== nothing if const_call_result.rt ⊑ rt @@ -1526,7 +1558,7 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), call = abstract_invoke(interp, arginfo, sv) if call.info === false if call.rt === Bottom - tristate_merge!(sv, Effects(EFFECTS_TOTAL, nothrow=ALWAYS_FALSE)) + tristate_merge!(sv, Effects(EFFECTS_TOTAL; nothrow=ALWAYS_FALSE)) else tristate_merge!(sv, Effects()) end @@ -1553,12 +1585,12 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), end end end - tristate_merge!(sv, Effects()) # TODO + tristate_merge!(sv, Effects(; overlayed=false)) # TODO return CallMeta(Any, false) elseif f === TypeVar # Manually look through the definition of TypeVar to # make sure to be able to get `PartialTypeVar`s out. - tristate_merge!(sv, Effects()) # TODO + tristate_merge!(sv, Effects(; overlayed=false)) # TODO (la < 2 || la > 4) && return CallMeta(Union{}, false) n = argtypes[2] ub_var = Const(Any) @@ -1571,17 +1603,17 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), end return CallMeta(typevar_tfunc(n, lb_var, ub_var), false) elseif f === UnionAll - tristate_merge!(sv, Effects()) # TODO + tristate_merge!(sv, Effects(; overlayed=false)) # TODO return CallMeta(abstract_call_unionall(argtypes), false) elseif f === Tuple && la == 2 - tristate_merge!(sv, Effects()) # TODO + tristate_merge!(sv, Effects(; overlayed=false)) # TODO aty = argtypes[2] ty = isvarargtype(aty) ? unwrapva(aty) : widenconst(aty) if !isconcretetype(ty) return CallMeta(Tuple, false) end elseif is_return_type(f) - tristate_merge!(sv, Effects()) # TODO + tristate_merge!(sv, Effects(; overlayed=false)) # TODO return return_type_tfunc(interp, argtypes, sv) elseif la == 2 && istopfunction(f, :!) # handle Conditional propagation through !Bool @@ -1643,8 +1675,8 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::Part match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt)) const_result = nothing if !result.edgecycle - const_call_result = abstract_call_method_with_const_args(interp, result, nothing, - arginfo, match, sv) + const_call_result = abstract_call_method_with_const_args(interp, result, + nothing, arginfo, match, sv) if const_call_result !== nothing if const_call_result.rt ⊑ rt (; rt, const_result) = const_call_result @@ -1833,9 +1865,9 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), at = tmeet(at, ft) if at === Bottom t = Bottom - tristate_merge!(sv, Effects( - ALWAYS_TRUE, # N.B depends on !ismutabletype(t) above - ALWAYS_TRUE, ALWAYS_FALSE, ALWAYS_TRUE)) + tristate_merge!(sv, Effects(EFFECTS_TOTAL; + # consistent = ALWAYS_TRUE, # N.B depends on !ismutabletype(t) above + nothrow = ALWAYS_FALSE)) @goto t_computed elseif !isa(at, Const) allconst = false @@ -1863,7 +1895,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), else is_nothrow = false end - tristate_merge!(sv, Effects(EFFECTS_TOTAL, + tristate_merge!(sv, Effects(EFFECTS_TOTAL; consistent = !ismutabletype(t) ? ALWAYS_TRUE : ALWAYS_FALSE, nothrow = is_nothrow ? ALWAYS_TRUE : ALWAYS_FALSE)) elseif ehead === :splatnew @@ -1882,7 +1914,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), t = PartialStruct(t, at.fields::Vector{Any}) end end - tristate_merge!(sv, Effects(EFFECTS_TOTAL, + tristate_merge!(sv, Effects(EFFECTS_TOTAL; consistent = ismutabletype(t) ? ALWAYS_FALSE : ALWAYS_TRUE, nothrow = is_nothrow ? ALWAYS_TRUE : ALWAYS_FALSE)) elseif ehead === :new_opaque_closure @@ -1924,20 +1956,21 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), effects.effect_free ? ALWAYS_TRUE : TRISTATE_UNKNOWN, effects.nothrow ? ALWAYS_TRUE : TRISTATE_UNKNOWN, effects.terminates_globally ? ALWAYS_TRUE : TRISTATE_UNKNOWN, + #=overlayed=#false )) else - tristate_merge!(sv, Effects()) + tristate_merge!(sv, Effects(; overlayed=false)) end elseif ehead === :cfunction - tristate_merge!(sv, Effects()) + tristate_merge!(sv, Effects(; overlayed=false)) t = e.args[1] isa(t, Type) || (t = Any) abstract_eval_cfunction(interp, e, vtypes, sv) elseif ehead === :method - tristate_merge!(sv, Effects()) + tristate_merge!(sv, Effects(; overlayed=false)) t = (length(e.args) == 1) ? Any : Nothing elseif ehead === :copyast - tristate_merge!(sv, Effects()) + tristate_merge!(sv, Effects(; overlayed=false)) t = abstract_eval_value(interp, e.args[1], vtypes, sv) if t isa Const && t.val isa Expr # `copyast` makes copies of Exprs @@ -2005,9 +2038,9 @@ function abstract_eval_global(M::Module, s::Symbol, frame::InferenceState) ty = abstract_eval_global(M, s) isa(ty, Const) && return ty if isdefined(M,s) - tristate_merge!(frame, Effects(EFFECTS_TOTAL, consistent=ALWAYS_FALSE)) + tristate_merge!(frame, Effects(EFFECTS_TOTAL; consistent=ALWAYS_FALSE)) else - tristate_merge!(frame, Effects(EFFECTS_TOTAL, consistent=ALWAYS_FALSE, nothrow=ALWAYS_FALSE)) + tristate_merge!(frame, Effects(EFFECTS_TOTAL; consistent=ALWAYS_FALSE, nothrow=ALWAYS_FALSE)) end return ty end @@ -2104,7 +2137,7 @@ function handle_control_backedge!(frame::InferenceState, from::Int, to::Int) elseif is_effect_overridden(frame, :terminates_locally) # this backedge is known to terminate else - tristate_merge!(frame, Effects(EFFECTS_TOTAL, terminates=TRISTATE_UNKNOWN)) + tristate_merge!(frame, Effects(EFFECTS_TOTAL; terminates=TRISTATE_UNKNOWN)) end end return nothing @@ -2262,11 +2295,11 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) if isa(lhs, SlotNumber) changes = StateUpdate(lhs, VarState(t, false), changes, false) elseif isa(lhs, GlobalRef) - tristate_merge!(frame, Effects(EFFECTS_TOTAL, + tristate_merge!(frame, Effects(EFFECTS_TOTAL; effect_free=ALWAYS_FALSE, nothrow=TRISTATE_UNKNOWN)) elseif !isa(lhs, SSAValue) - tristate_merge!(frame, Effects()) + tristate_merge!(frame, Effects(; overlayed=false)) end elseif hd === :method stmt = stmt::Expr diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index 12de1b6705aa98..db6ab574e38592 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -134,7 +134,7 @@ mutable struct InferenceState #=parent=#nothing, #=cached=#cache === :global, #=inferred=#false, #=dont_work_on_me=#false, - #=ipo_effects=#Effects(consistent, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, inbounds_taints_consistency), + #=ipo_effects=#Effects(consistent, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, false, inbounds_taints_consistency), interp) result.result = frame cache !== :no && push!(get_inference_cache(interp), result) diff --git a/base/compiler/methodtable.jl b/base/compiler/methodtable.jl index f68cdd52d1b066..da493cf9a9ef57 100644 --- a/base/compiler/methodtable.jl +++ b/base/compiler/methodtable.jl @@ -40,15 +40,18 @@ end getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch """ - findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) -> MethodLookupResult or missing + findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) -> + (matches::MethodLookupResult, overlayed::Bool) or missing -Find all methods in the given method table `view` that are applicable to the -given signature `sig`. If no applicable methods are found, an empty result is -returned. If the number of applicable methods exceeded the specified limit, -`missing` is returned. +Find all methods in the given method table `view` that are applicable to the given signature `sig`. +If no applicable methods are found, an empty result is returned. +If the number of applicable methods exceeded the specified limit, `missing` is returned. +`overlayed` indicates if any matching method is defined in an overlayed method table. """ function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=Int(typemax(Int32))) - return _findall(sig, nothing, table.world, limit) + result = _findall(sig, nothing, table.world, limit) + result === missing && return missing + return result, false end function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=Int(typemax(Int32))) @@ -57,7 +60,7 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int nr = length(result) if nr ≥ 1 && result[nr].fully_covers # no need to fall back to the internal method table - return result + return result, true end # fall back to the internal method table fallback_result = _findall(sig, nothing, table.world, limit) @@ -68,7 +71,7 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int WorldRange( max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world), min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)), - result.ambig | fallback_result.ambig) + result.ambig | fallback_result.ambig), !isempty(result) end function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int) @@ -83,31 +86,38 @@ function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, end """ - findsup(sig::Type, view::MethodTableView) -> Tuple{MethodMatch, WorldRange} or nothing - -Find the (unique) method `m` such that `sig <: m.sig`, while being more -specific than any other method with the same property. In other words, find -the method which is the least upper bound (supremum) under the specificity/subtype -relation of the queried `signature`. If `sig` is concrete, this is equivalent to -asking for the method that will be called given arguments whose types match the -given signature. This query is also used to implement `invoke`. - -Such a method `m` need not exist. It is possible that no method is an -upper bound of `sig`, or it is possible that among the upper bounds, there -is no least element. In both cases `nothing` is returned. + findsup(sig::Type, view::MethodTableView) -> + (match::MethodMatch, valid_worlds::WorldRange, overlayed::Bool) or nothing + +Find the (unique) method such that `sig <: match.method.sig`, while being more +specific than any other method with the same property. In other words, find the method +which is the least upper bound (supremum) under the specificity/subtype relation of +the queried `sig`nature. If `sig` is concrete, this is equivalent to asking for the method +that will be called given arguments whose types match the given signature. +Note that this query is also used to implement `invoke`. + +Such a matching method `match` doesn't necessarily exist. +It is possible that no method is an upper bound of `sig`, or +it is possible that among the upper bounds, there is no least element. +In both cases `nothing` is returned. + +`overlayed` indicates if the matching method is defined in an overlayed method table. """ function findsup(@nospecialize(sig::Type), table::InternalMethodTable) - return _findsup(sig, nothing, table.world) + return (_findsup(sig, nothing, table.world)..., false) end function findsup(@nospecialize(sig::Type), table::OverlayMethodTable) match, valid_worlds = _findsup(sig, table.mt, table.world) - match !== nothing && return match, valid_worlds + match !== nothing && return match, valid_worlds, true # fall back to the internal method table fallback_match, fallback_valid_worlds = _findsup(sig, nothing, table.world) - return fallback_match, WorldRange( - max(valid_worlds.min_world, fallback_valid_worlds.min_world), - min(valid_worlds.max_world, fallback_valid_worlds.max_world)) + return ( + fallback_match, + WorldRange( + max(valid_worlds.min_world, fallback_valid_worlds.min_world), + min(valid_worlds.max_world, fallback_valid_worlds.max_world)), + false) end function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt) diff --git a/base/compiler/ssair/show.jl b/base/compiler/ssair/show.jl index 1e98dda0390402..76cbcbd4d5d7dd 100644 --- a/base/compiler/ssair/show.jl +++ b/base/compiler/ssair/show.jl @@ -803,6 +803,7 @@ function Base.show(io::IO, e::Core.Compiler.Effects) print(io, ',') printstyled(io, string(tristate_letter(e.terminates), 't'); color=tristate_color(e.terminates)) print(io, ')') + e.overlayed && printstyled(io, '′'; color=:red) end @specialize diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 452a2b554f307d..df1861e20c2065 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -1789,11 +1789,11 @@ function builtin_effects(f::Builtin, argtypes::Vector{Any}, rt) if (f === Core.getfield || f === Core.isdefined) && length(argtypes) >= 3 # consistent if the argtype is immutable if isvarargtype(argtypes[2]) - return Effects(Effects(), effect_free=ALWAYS_TRUE, terminates=ALWAYS_TRUE) + return Effects(; effect_free=ALWAYS_TRUE, terminates=ALWAYS_TRUE, overlayed=false) end s = widenconst(argtypes[2]) if isType(s) || !isa(s, DataType) || isabstracttype(s) - return Effects(Effects(), effect_free=ALWAYS_TRUE, terminates=ALWAYS_TRUE) + return Effects(; effect_free=ALWAYS_TRUE, terminates=ALWAYS_TRUE, overlayed=false) end s = s::DataType ipo_consistent = !ismutabletype(s) @@ -1826,7 +1826,9 @@ function builtin_effects(f::Builtin, argtypes::Vector{Any}, rt) ipo_consistent ? ALWAYS_TRUE : ALWAYS_FALSE, effect_free ? ALWAYS_TRUE : ALWAYS_FALSE, nothrow ? ALWAYS_TRUE : TRISTATE_UNKNOWN, - ALWAYS_TRUE) + #=terminates=#ALWAYS_TRUE, + #=overlayed=#false, + ) end function builtin_nothrow(@nospecialize(f), argtypes::Array{Any, 1}, @nospecialize(rt)) @@ -2007,7 +2009,9 @@ function intrinsic_effects(f::IntrinsicFunction, argtypes::Vector{Any}) ipo_consistent ? ALWAYS_TRUE : ALWAYS_FALSE, effect_free ? ALWAYS_TRUE : ALWAYS_FALSE, nothrow ? ALWAYS_TRUE : TRISTATE_UNKNOWN, - ALWAYS_TRUE) + #=terminates=#ALWAYS_TRUE, + #=overlayed=#false, + ) end # TODO: this function is a very buggy and poor model of the return_type function diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index a047222cbfee0d..1c54345b17de5a 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -431,7 +431,7 @@ function rt_adjust_effects(@nospecialize(rt), ipo_effects::Effects) # but we don't currently model idempontency using dataflow, so we don't notice. # Fix that up here to improve precision. if !ipo_effects.inbounds_taints_consistency && rt === Union{} - return Effects(ipo_effects, consistent=ALWAYS_TRUE) + return Effects(ipo_effects; consistent=ALWAYS_TRUE) end return ipo_effects end @@ -755,11 +755,11 @@ function merge_call_chain!(parent::InferenceState, ancestor::InferenceState, chi # and ensure that walking the parent list will get the same result (DAG) from everywhere # Also taint the termination effect, because we can no longer guarantee the absence # of recursion. - tristate_merge!(parent, Effects(EFFECTS_TOTAL, terminates=TRISTATE_UNKNOWN)) + tristate_merge!(parent, Effects(EFFECTS_TOTAL; terminates=TRISTATE_UNKNOWN)) while true add_cycle_backedge!(child, parent, parent.currpc) union_caller_cycle!(ancestor, child) - tristate_merge!(child, Effects(EFFECTS_TOTAL, terminates=TRISTATE_UNKNOWN)) + tristate_merge!(child, Effects(EFFECTS_TOTAL; terminates=TRISTATE_UNKNOWN)) child = parent child === ancestor && break parent = child.parent::InferenceState diff --git a/base/compiler/types.jl b/base/compiler/types.jl index 65ce341dd55e10..282582c016d970 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -38,6 +38,7 @@ struct Effects effect_free::TriState nothrow::TriState terminates::TriState + overlayed::Bool # This effect is currently only tracked in inference and modified # :consistent before caching. We may want to track it in the future. inbounds_taints_consistency::Bool @@ -46,27 +47,33 @@ function Effects( consistent::TriState, effect_free::TriState, nothrow::TriState, - terminates::TriState) + terminates::TriState, + overlayed::Bool) return Effects( consistent, effect_free, nothrow, terminates, + overlayed, false) end -Effects() = Effects(TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, TRISTATE_UNKNOWN) -function Effects(e::Effects; +const EFFECTS_TOTAL = Effects(ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, false) +const EFFECTS_UNKNOWN = Effects(TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, true) + +function Effects(e::Effects = EFFECTS_UNKNOWN; consistent::TriState = e.consistent, effect_free::TriState = e.effect_free, nothrow::TriState = e.nothrow, terminates::TriState = e.terminates, + overlayed::Bool = e.overlayed, inbounds_taints_consistency::Bool = e.inbounds_taints_consistency) return Effects( consistent, effect_free, nothrow, terminates, + overlayed, inbounds_taints_consistency) end @@ -82,20 +89,20 @@ is_removable_if_unused(effects::Effects) = effects.terminates === ALWAYS_TRUE && effects.nothrow === ALWAYS_TRUE -const EFFECTS_TOTAL = Effects(ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE) - function encode_effects(e::Effects) - return e.consistent.state | - (e.effect_free.state << 2) | - (e.nothrow.state << 4) | - (e.terminates.state << 6) + return (e.consistent.state << 1) | + (e.effect_free.state << 3) | + (e.nothrow.state << 5) | + (e.terminates.state << 7) | + (e.overlayed) end function decode_effects(e::UInt8) return Effects( - TriState(e & 0x3), - TriState((e >> 2) & 0x3), - TriState((e >> 4) & 0x3), - TriState((e >> 6) & 0x3), + TriState((e >> 1) & 0x03), + TriState((e >> 3) & 0x03), + TriState((e >> 5) & 0x03), + TriState((e >> 7) & 0x03), + e & 0x01 ≠ 0x00, false) end @@ -109,6 +116,7 @@ function tristate_merge(old::Effects, new::Effects) old.nothrow, new.nothrow), tristate_merge( old.terminates, new.terminates), + old.overlayed | new.overlayed, old.inbounds_taints_consistency | new.inbounds_taints_consistency) end @@ -158,7 +166,7 @@ mutable struct InferenceResult arginfo#=::Union{Nothing,Tuple{ArgInfo,InferenceState}}=# = nothing) argtypes, overridden_by_const = matching_cache_argtypes(linfo, arginfo) return new(linfo, argtypes, overridden_by_const, Any, nothing, - WorldRange(), Effects(), Effects(), nothing) + WorldRange(), Effects(; overlayed=false), Effects(; overlayed=false), nothing) end end diff --git a/test/compiler/AbstractInterpreter.jl b/test/compiler/AbstractInterpreter.jl index f1fe4b06dcb63c..213fd4b786dc26 100644 --- a/test/compiler/AbstractInterpreter.jl +++ b/test/compiler/AbstractInterpreter.jl @@ -41,25 +41,53 @@ import Base.Experimental: @MethodTable, @overlay @MethodTable(OverlayedMT) CC.method_table(interp::MTOverlayInterp) = CC.OverlayMethodTable(CC.get_world_counter(interp), OverlayedMT) -@overlay OverlayedMT sin(x::Float64) = 1 -@test Base.return_types((Int,), MTOverlayInterp()) do x - sin(x) -end == Any[Int] +strangesin(x) = sin(x) +@overlay OverlayedMT strangesin(x::Float64) = iszero(x) ? nothing : cos(x) +@test Base.return_types((Float64,), MTOverlayInterp()) do x + strangesin(x) +end |> only === Union{Float64,Nothing} @test Base.return_types((Any,), MTOverlayInterp()) do x - Base.@invoke sin(x::Float64) -end == Any[Int] + Base.@invoke strangesin(x::Float64) +end |> only === Union{Float64,Nothing} # fallback to the internal method table @test Base.return_types((Int,), MTOverlayInterp()) do x cos(x) -end == Any[Float64] +end |> only === Float64 @test Base.return_types((Any,), MTOverlayInterp()) do x Base.@invoke cos(x::Float64) -end == Any[Float64] +end |> only === Float64 # not fully covered overlay method match overlay_match(::Any) = nothing @overlay OverlayedMT overlay_match(::Int) = missing @test Base.return_types((Any,), MTOverlayInterp()) do x overlay_match(x) -end == Any[Union{Nothing,Missing}] +end |> only === Union{Nothing,Missing} + +# partial pure/concrete evaluation +@test Base.return_types((), MTOverlayInterp()) do + isbitstype(Int) ? nothing : missing +end |> only === Nothing +Base.@assume_effects :terminates_globally function issue41694(x) + res = 1 + 1 < x < 20 || throw("bad") + while x > 1 + res *= x + x -= 1 + end + return res +end +@test Base.return_types((), MTOverlayInterp()) do + issue41694(3) == 6 ? nothing : missing +end |> only === Nothing + +# disable partial pure/concrete evaluation when tainted by any overlayed call +Base.@assume_effects :total totalcall(f, args...) = f(args...) +@test Base.return_types((), MTOverlayInterp()) do + if totalcall(strangesin, 1.0) == cos(1.0) + return nothing + else + return missing + end +end |> only === Nothing