From 4846fbbb499fa33fc1e0cee34cb558a23f323056 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Thu, 17 Feb 2022 20:08:04 +0900 Subject: [PATCH] `AbstractInterpreter`: make it easier to overload pure/concrete-eval These changes are necessary to fix #44174 nicely. --- base/compiler/abstractinterpretation.jl | 144 +++++++++++++----------- base/compiler/methodtable.jl | 11 +- 2 files changed, 88 insertions(+), 67 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 6a9837547834b..4fdaa6257686c 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -65,13 +65,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), const_results = Union{InferenceResult,Nothing,ConstResult}[] multiple_matches = napplicable > 1 - if f !== nothing && napplicable == 1 && is_method_pure(applicable[1]::MethodMatch) - val = pure_eval_call(f, argtypes) - if val !== nothing - # TODO: add some sort of edge(s) - return CallMeta(val, MethodResultPure(info)) - end - end + val = pure_eval_call(interp, f, applicable, arginfo, sv) + val !== nothing && return CallMeta(val, MethodResultPure(info)) # TODO: add some sort of edge(s) fargs = arginfo.fargs for i in 1:napplicable @@ -619,27 +614,85 @@ struct MethodCallResult end end +function pure_eval_eligible(interp::AbstractInterpreter, + @nospecialize(f), applicable::Vector{Any}, arginfo::ArgInfo, sv::InferenceState) + return !isoverlayed(method_table(interp, sv)) && + f !== nothing && + length(applicable) == 1 && + is_method_pure(applicable[1]::MethodMatch) && + is_all_const_arg(arginfo) +end + +function is_method_pure(method::Method, @nospecialize(sig), sparams::SimpleVector) + if isdefined(method, :generator) + method.generator.expand_early || return false + mi = specialize_method(method, sig, sparams) + isa(mi, MethodInstance) || return false + staged = get_staged(mi) + (staged isa CodeInfo && (staged::CodeInfo).pure) || return false + return true + end + return method.pure +end +is_method_pure(match::MethodMatch) = is_method_pure(match.method, match.spec_types, match.sparams) + +function pure_eval_call(interp::AbstractInterpreter, + @nospecialize(f), applicable::Vector{Any}, arginfo::ArgInfo, sv::InferenceState) + pure_eval_eligible(interp, f, applicable, arginfo, sv) || return nothing + return _pure_eval_call(f, arginfo) +end +function _pure_eval_call(@nospecialize(f), arginfo::ArgInfo) + args = collect_const_args(arginfo) + try + value = Core._apply_pure(f, args) + return Const(value) + catch + return nothing + end +end + +function concrete_eval_eligible(interp::AbstractInterpreter, + @nospecialize(f), result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState) + return !isoverlayed(method_table(interp, sv)) && + f !== nothing && + result.edge !== nothing && + is_total_or_error(result.edge_effects) && + is_all_const_arg(arginfo) +end + function is_all_const_arg((; argtypes)::ArgInfo) - for a in argtypes - if !isa(a, Const) && !isconstType(a) && !issingletontype(a) - return false - end + for i = 2:length(argtypes) + a = widenconditional(argtypes[i]) + isa(a, Const) || isconstType(a) || issingletontype(a) || return false end return true end -function concrete_eval_const_proven_total_or_error(interp::AbstractInterpreter, - @nospecialize(f), (; argtypes)::ArgInfo, _::InferenceState) - args = Any[ (a = widenconditional(argtypes[i]); - isa(a, Const) ? a.val : - isconstType(a) ? (a::DataType).parameters[1] : - (a::DataType).instance) for i in 2:length(argtypes) ] +function collect_const_args((; argtypes)::ArgInfo) + return Any[ let a = widenconditional(argtypes[i]) + isa(a, Const) ? a.val : + isconstType(a) ? (a::DataType).parameters[1] : + (a::DataType).instance + end for i in 2:length(argtypes) ] +end + +function concrete_eval_call(interp::AbstractInterpreter, + @nospecialize(f), result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState) + concrete_eval_eligible(interp, f, result, arginfo, sv) || return nothing + args = collect_const_args(arginfo) try value = Core._call_in_world_total(get_world_counter(interp), f, args...) - return Const(value) - catch e - return nothing + if is_inlineable_constant(value) || call_result_unused(sv) + # If the constant is not inlineable, still do the const-prop, since the + # code that led to the creation of the Const may be inlineable in the same + # circumstance and may be optimizable. + return ConstCallResults(Const(value), ConstResult(result.edge, value), EFFECTS_TOTAL) + end + catch + # The evaulation threw. By :consistent-cy, we're guaranteed this would have happened at runtime + return ConstCallResults(Union{}, ConstResult(result.edge), result.edge_effects) end + return nothing end function const_prop_enabled(interp::AbstractInterpreter, sv::InferenceState, match::MethodMatch) @@ -671,19 +724,10 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul if !const_prop_enabled(interp, sv, match) return nothing end - if f !== nothing && result.edge !== nothing && is_total_or_error(result.edge_effects) && is_all_const_arg(arginfo) - rt = concrete_eval_const_proven_total_or_error(interp, f, arginfo, sv) + val = concrete_eval_call(interp, f, result, arginfo, sv) + if val !== nothing add_backedge!(result.edge, sv) - if rt === nothing - # The evaulation threw. By :consistent-cy, we're guaranteed this would have happened at runtime - return ConstCallResults(Union{}, ConstResult(result.edge), result.edge_effects) - end - if is_inlineable_constant(rt.val) || call_result_unused(sv) - # If the constant is not inlineable, still do the const-prop, since the - # code that led to the creation of the Const may be inlineable in the same - # circumstance and may be optimizable. - return ConstCallResults(rt, ConstResult(result.edge, rt.val), EFFECTS_TOTAL) - end + return val end mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv) mi === nothing && return nothing @@ -1218,36 +1262,6 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, sv:: return CallMeta(res, retinfo) end -function is_method_pure(method::Method, @nospecialize(sig), sparams::SimpleVector) - if isdefined(method, :generator) - method.generator.expand_early || return false - mi = specialize_method(method, sig, sparams) - isa(mi, MethodInstance) || return false - staged = get_staged(mi) - (staged isa CodeInfo && (staged::CodeInfo).pure) || return false - return true - end - return method.pure -end -is_method_pure(match::MethodMatch) = is_method_pure(match.method, match.spec_types, match.sparams) - -function pure_eval_call(@nospecialize(f), argtypes::Vector{Any}) - for i = 2:length(argtypes) - a = widenconditional(argtypes[i]) - if !(isa(a, Const) || isconstType(a)) - return nothing - end - end - args = Any[ (a = widenconditional(argtypes[i]); - isa(a, Const) ? a.val : (a::DataType).parameters[1]) for i in 2:length(argtypes) ] - try - value = Core._apply_pure(f, args) - return Const(value) - catch - return nothing - end -end - function argtype_by_index(argtypes::Vector{Any}, i::Int) n = length(argtypes) na = argtypes[n] @@ -1586,8 +1600,10 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), elseif max_methods > 1 && istopfunction(f, :copyto!) max_methods = 1 elseif la == 3 && istopfunction(f, :typejoin) - val = pure_eval_call(f, argtypes) - return CallMeta(val === nothing ? Type : val, MethodResultPure()) + if is_all_const_arg(arginfo) + val = _pure_eval_call(f, arginfo) + return CallMeta(val === nothing ? Type : val, MethodResultPure()) + end end atype = argtypes_to_type(argtypes) return abstract_call_gf_by_type(interp, f, arginfo, atype, sv, max_methods) diff --git a/base/compiler/methodtable.jl b/base/compiler/methodtable.jl index 93020ae6a2639..70beb259cb6a5 100644 --- a/base/compiler/methodtable.jl +++ b/base/compiler/methodtable.jl @@ -84,9 +84,9 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int _min_val[] = typemin(UInt) _max_val[] = typemax(UInt) ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig) - end - if ms === false - return missing + if ms === false + return missing + end end return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0) end @@ -123,3 +123,8 @@ end # This query is not cached findsup(@nospecialize(sig::Type), table::CachedMethodTable) = findsup(sig, table.table) + +isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface") +isoverlayed(::InternalMethodTable) = false +isoverlayed(::OverlayMethodTable) = true +isoverlayed(mt::CachedMethodTable) = isoverlayed(mt.table)