From c86d43a50e2baf365279531f459516615031db78 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Fri, 4 Mar 2022 18:28:31 +0900 Subject: [PATCH] `AbstractInterpreter`: implement `findsup` for `OverlayMethodTable` --- base/compiler/abstractinterpretation.jl | 3 +- base/compiler/methodtable.jl | 61 +++++++++++++++---------- base/reflection.jl | 12 ++--- src/gf.c | 23 +++++----- stdlib/Test/src/Test.jl | 2 +- test/compiler/AbstractInterpreter.jl | 18 ++++++++ 6 files changed, 75 insertions(+), 44 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index a9f20ceaf569b..a191120e306b7 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -1484,7 +1484,8 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn argtype = Tuple{ft, argtype.parameters...} result = findsup(types, method_table(interp)) result === nothing && return CallMeta(Any, false) - method, valid_worlds = result + match, valid_worlds = result + method = match.method update_valid_age!(sv, valid_worlds) (ti, env::SimpleVector) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), nargtype, method.sig)::SimpleVector (; rt, edge) = result = abstract_call_method(interp, method, ti, env, false, sv) diff --git a/base/compiler/methodtable.jl b/base/compiler/methodtable.jl index 4086023a725f0..841a020b43a8d 100644 --- a/base/compiler/methodtable.jl +++ b/base/compiler/methodtable.jl @@ -40,7 +40,7 @@ end getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch """ - findall(sig::Type, view::MethodTableView; limit=typemax(Int)) + findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) -> MethodLookupResult or missing Find all methods in the given method table `view` that are applicable to the given signature `sig`. If no applicable methods are found, an empty result is @@ -48,37 +48,43 @@ returned. If the number of applicable methods exceeded the specified limit, `missing` is returned. """ function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=typemax(Int)) - _min_val = RefValue{UInt}(typemin(UInt)) - _max_val = RefValue{UInt}(typemax(UInt)) - _ambig = RefValue{Int32}(0) - ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig) - if ms === false - return missing - end - return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0) + return _findall(sig, nothing, table.world, limit) end function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=typemax(Int)) + result = _findall(sig, table.mt, table.world, limit) + result === missing && return missing + if !isempty(result) + if all(match->match.fully_covers, result) + # no need to fall back to the internal method table + return result + else + # merge the match results with the internal method table + fallback_result = _findall(sig, nothing, table.world, limit) + return MethodLookupResult( + vcat(result.matches, fallback_result.matches), + WorldRange(min(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world), + max(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)), + result.ambig | fallback_result.ambig) + end + end + # fall back to the internal method table + return _findall(sig, nothing, table.world, limit) +end + +function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int) _min_val = RefValue{UInt}(typemin(UInt)) _max_val = RefValue{UInt}(typemax(UInt)) _ambig = RefValue{Int32}(0) - ms = _methods_by_ftype(sig, table.mt, limit, table.world, false, _min_val, _max_val, _ambig) + ms = _methods_by_ftype(sig, mt, limit, world, false, _min_val, _max_val, _ambig) if ms === false return missing - elseif isempty(ms) - # fall back to the internal method table - _min_val[] = typemin(UInt) - _max_val[] = typemax(UInt) - ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig) - if ms === false - return missing - end end return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0) end """ - findsup(sig::Type, view::MethodTableView)::Union{Tuple{MethodMatch, WorldRange}, Nothing} + findsup(sig::Type, view::MethodTableView) -> Tuple{MethodMatch, WorldRange} or nothing Find the (unique) method `m` such that `sig <: m.sig`, while being more specific than any other method with the same property. In other words, find @@ -92,12 +98,21 @@ upper bound of `sig`, or it is possible that among the upper bounds, there is no least element. In both cases `nothing` is returned. """ function findsup(@nospecialize(sig::Type), table::InternalMethodTable) + return _findsup(sig, nothing, table.world) +end + +function findsup(@nospecialize(sig::Type), table::OverlayMethodTable) + result = _findsup(sig, table.mt, table.world) + result === nothing || return result + return _findsup(sig, nothing, table.world) # fall back to the internal method table +end + +function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt) min_valid = RefValue{UInt}(typemin(UInt)) max_valid = RefValue{UInt}(typemax(UInt)) - result = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}), - sig, table.world, min_valid, max_valid)::Union{MethodMatch, Nothing} - result === nothing && return nothing - (result.method, WorldRange(min_valid[], max_valid[])) + result = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}), + sig, mt, world, min_valid, max_valid)::Union{MethodMatch, Nothing} + return result === nothing ? result : (result, WorldRange(min_valid[], max_valid[])) end isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface") diff --git a/base/reflection.jl b/base/reflection.jl index 5490cae9511c8..0ae3b087fc022 100644 --- a/base/reflection.jl +++ b/base/reflection.jl @@ -1347,15 +1347,11 @@ end print_statement_costs(args...; kwargs...) = print_statement_costs(stdout, args...; kwargs...) function _which(@nospecialize(tt::Type), world=get_world_counter()) - min_valid = RefValue{UInt}(typemin(UInt)) - max_valid = RefValue{UInt}(typemax(UInt)) - match = ccall(:jl_gf_invoke_lookup_worlds, Any, - (Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}), - tt, world, min_valid, max_valid) - if match === nothing + result = Core.Compiler._findsup(tt, nothing, world) + if result === nothing error("no unique matching method found for the specified argument types") end - return match::Core.MethodMatch + return first(result)::Core.MethodMatch end """ @@ -1478,7 +1474,7 @@ true function hasmethod(@nospecialize(f), @nospecialize(t); world::UInt=get_world_counter()) t = to_tuple_type(t) t = signature_type(f, t) - return ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), t, world) !== nothing + return ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), t, nothing, world) !== nothing end function hasmethod(@nospecialize(f), @nospecialize(t), kwnames::Tuple{Vararg{Symbol}}; world::UInt=get_world_counter()) diff --git a/src/gf.c b/src/gf.c index c59a0587166c1..f6643e014d785 100644 --- a/src/gf.c +++ b/src/gf.c @@ -1217,7 +1217,7 @@ static jl_method_instance_t *cache_method( return newmeth; } -static jl_method_match_t *_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROOT, size_t world, size_t *min_valid, size_t *max_valid); +static jl_method_match_t *_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROOT, jl_value_t *mt, size_t world, size_t *min_valid, size_t *max_valid); static jl_method_instance_t *jl_mt_assoc_by_type(jl_methtable_t *mt JL_PROPAGATES_ROOT, jl_datatype_t *tt, size_t world) { @@ -1237,7 +1237,7 @@ static jl_method_instance_t *jl_mt_assoc_by_type(jl_methtable_t *mt JL_PROPAGATE size_t min_valid = 0; size_t max_valid = ~(size_t)0; - jl_method_match_t *matc = _gf_invoke_lookup((jl_value_t*)tt, world, &min_valid, &max_valid); + jl_method_match_t *matc = _gf_invoke_lookup((jl_value_t*)tt, jl_nothing, world, &min_valid, &max_valid); jl_method_instance_t *nf = NULL; if (matc) { JL_GC_PUSH1(&matc); @@ -2549,36 +2549,37 @@ JL_DLLEXPORT jl_value_t *jl_apply_generic(jl_value_t *F, jl_value_t **args, uint return _jl_invoke(F, args, nargs, mfunc, world); } -static jl_method_match_t *_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROOT, size_t world, size_t *min_valid, size_t *max_valid) +static jl_method_match_t *_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROOT, jl_value_t *mt, size_t world, size_t *min_valid, size_t *max_valid) { jl_value_t *unw = jl_unwrap_unionall((jl_value_t*)types); if (jl_is_tuple_type(unw) && jl_tparam0(unw) == jl_bottom_type) return NULL; - jl_methtable_t *mt = jl_method_table_for(unw); - if ((jl_value_t*)mt == jl_nothing) + if (mt == jl_nothing) + mt = (jl_value_t*)jl_method_table_for(unw); + if (mt == jl_nothing) mt = NULL; - jl_value_t *matches = ml_matches(mt, (jl_tupletype_t*)types, 1, 0, 0, world, 1, min_valid, max_valid, NULL); + jl_value_t *matches = ml_matches((jl_methtable_t*)mt, (jl_tupletype_t*)types, 1, 0, 0, world, 1, min_valid, max_valid, NULL); if (matches == jl_false || jl_array_len(matches) != 1) return NULL; jl_method_match_t *matc = (jl_method_match_t*)jl_array_ptr_ref(matches, 0); return matc; } -JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types, size_t world) +JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types, jl_value_t *mt, size_t world) { // Deprecated: Use jl_gf_invoke_lookup_worlds for future development size_t min_valid = 0; size_t max_valid = ~(size_t)0; - jl_method_match_t *matc = _gf_invoke_lookup(types, world, &min_valid, &max_valid); + jl_method_match_t *matc = _gf_invoke_lookup(types, mt, world, &min_valid, &max_valid); if (matc == NULL) return jl_nothing; return (jl_value_t*)matc->method; } -JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup_worlds(jl_value_t *types, size_t world, size_t *min_world, size_t *max_world) +JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup_worlds(jl_value_t *types, jl_value_t *mt, size_t world, size_t *min_world, size_t *max_world) { - jl_method_match_t *matc = _gf_invoke_lookup(types, world, min_world, max_world); + jl_method_match_t *matc = _gf_invoke_lookup(types, mt, world, min_world, max_world); if (matc == NULL) return jl_nothing; return (jl_value_t*)matc; @@ -2599,7 +2600,7 @@ jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t *gf, jl_value_t **args, jl_value_t *types = NULL; JL_GC_PUSH1(&types); types = jl_argtype_with_function(gf, types0); - jl_method_t *method = (jl_method_t*)jl_gf_invoke_lookup(types, world); + jl_method_t *method = (jl_method_t*)jl_gf_invoke_lookup(types, jl_nothing, world); JL_GC_PROMISE_ROOTED(method); if ((jl_value_t*)method == jl_nothing) { diff --git a/stdlib/Test/src/Test.jl b/stdlib/Test/src/Test.jl index 5693d65c7f913..95cd1ecccd9c3 100644 --- a/stdlib/Test/src/Test.jl +++ b/stdlib/Test/src/Test.jl @@ -1774,7 +1774,7 @@ function detect_unbound_args(mods...; params = tuple_sig.parameters[1:(end - 1)] tuple_sig = Base.rewrap_unionall(Tuple{params...}, m.sig) world = Base.get_world_counter() - mf = ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), tuple_sig, world) + mf = ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), tuple_sig, nothing, world) if mf !== nothing && mf !== m && mf.sig <: tuple_sig continue end diff --git a/test/compiler/AbstractInterpreter.jl b/test/compiler/AbstractInterpreter.jl index 6ef2adf7177fa..f1fe4b06dcb63 100644 --- a/test/compiler/AbstractInterpreter.jl +++ b/test/compiler/AbstractInterpreter.jl @@ -45,3 +45,21 @@ CC.method_table(interp::MTOverlayInterp) = CC.OverlayMethodTable(CC.get_world_co @test Base.return_types((Int,), MTOverlayInterp()) do x sin(x) end == Any[Int] +@test Base.return_types((Any,), MTOverlayInterp()) do x + Base.@invoke sin(x::Float64) +end == Any[Int] + +# fallback to the internal method table +@test Base.return_types((Int,), MTOverlayInterp()) do x + cos(x) +end == Any[Float64] +@test Base.return_types((Any,), MTOverlayInterp()) do x + Base.@invoke cos(x::Float64) +end == Any[Float64] + +# not fully covered overlay method match +overlay_match(::Any) = nothing +@overlay OverlayedMT overlay_match(::Int) = missing +@test Base.return_types((Any,), MTOverlayInterp()) do x + overlay_match(x) +end == Any[Union{Nothing,Missing}]