Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AbstractInterpreter: implement findsup for OverlayMethodTable #44448

Merged
merged 1 commit into from
Mar 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1484,7 +1484,8 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
argtype = Tuple{ft, argtype.parameters...}
result = findsup(types, method_table(interp))
result === nothing && return CallMeta(Any, false)
method, valid_worlds = result
match, valid_worlds = result
method = match.method
update_valid_age!(sv, valid_worlds)
(ti, env::SimpleVector) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), nargtype, method.sig)::SimpleVector
(; rt, edge) = result = abstract_call_method(interp, method, ti, env, false, sv)
Expand Down
61 changes: 38 additions & 23 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,45 +40,51 @@ end
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch

"""
findall(sig::Type, view::MethodTableView; limit=typemax(Int))
findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) -> MethodLookupResult or missing

Find all methods in the given method table `view` that are applicable to the
given signature `sig`. If no applicable methods are found, an empty result is
returned. If the number of applicable methods exceeded the specified limit,
`missing` is returned.
"""
function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=typemax(Int))
_min_val = RefValue{UInt}(typemin(UInt))
_max_val = RefValue{UInt}(typemax(UInt))
_ambig = RefValue{Int32}(0)
ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig)
if ms === false
return missing
end
return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0)
return _findall(sig, nothing, table.world, limit)
end

function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=typemax(Int))
result = _findall(sig, table.mt, table.world, limit)
result === missing && return missing
if !isempty(result)
if all(match->match.fully_covers, result)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It only matters if result[end].fully_covers

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me check my understanding.
We can just look at result[end].fully_covers since:

  • jl_matching_methods returns matching methods in order of speciality
    (and so the last match is always the most general case)
  • if the last match fully_covers, it means this call is assured to be
    dispatched to the last match when it's not dispatched with the other matches

# no need to fall back to the internal method table
return result
else
# merge the match results with the internal method table
fallback_result = _findall(sig, nothing, table.world, limit)
return MethodLookupResult(
vcat(result.matches, fallback_result.matches),
WorldRange(min(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
max(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
Comment on lines +66 to +67
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you need the max of the min_worlds and the min of the max_worlds

result.ambig | fallback_result.ambig)
end
end
# fall back to the internal method table
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if the result is empty, you are still required to return the valid_worlds data from the (failed) lookup

return _findall(sig, nothing, table.world, limit)
end

function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int)
_min_val = RefValue{UInt}(typemin(UInt))
_max_val = RefValue{UInt}(typemax(UInt))
_ambig = RefValue{Int32}(0)
ms = _methods_by_ftype(sig, table.mt, limit, table.world, false, _min_val, _max_val, _ambig)
ms = _methods_by_ftype(sig, mt, limit, world, false, _min_val, _max_val, _ambig)
if ms === false
return missing
elseif isempty(ms)
# fall back to the internal method table
_min_val[] = typemin(UInt)
_max_val[] = typemax(UInt)
ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig)
if ms === false
return missing
end
end
return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0)
end

"""
findsup(sig::Type, view::MethodTableView)::Union{Tuple{MethodMatch, WorldRange}, Nothing}
findsup(sig::Type, view::MethodTableView) -> Tuple{MethodMatch, WorldRange} or nothing

Find the (unique) method `m` such that `sig <: m.sig`, while being more
specific than any other method with the same property. In other words, find
Expand All @@ -92,12 +98,21 @@ upper bound of `sig`, or it is possible that among the upper bounds, there
is no least element. In both cases `nothing` is returned.
"""
function findsup(@nospecialize(sig::Type), table::InternalMethodTable)
return _findsup(sig, nothing, table.world)
end

function findsup(@nospecialize(sig::Type), table::OverlayMethodTable)
result = _findsup(sig, table.mt, table.world)
result === nothing || return result
return _findsup(sig, nothing, table.world) # fall back to the internal method table
end

function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt)
min_valid = RefValue{UInt}(typemin(UInt))
max_valid = RefValue{UInt}(typemax(UInt))
result = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}),
sig, table.world, min_valid, max_valid)::Union{MethodMatch, Nothing}
result === nothing && return nothing
(result.method, WorldRange(min_valid[], max_valid[]))
result = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}),
sig, mt, world, min_valid, max_valid)::Union{MethodMatch, Nothing}
return result === nothing ? result : (result, WorldRange(min_valid[], max_valid[]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is conservatively correct currently (since we return Any), but it would be more correct to return the tuple unconditionally, since the WorldRange applies to the lookup, regardless of the result.

end

isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface")
Expand Down
12 changes: 4 additions & 8 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1347,15 +1347,11 @@ end
print_statement_costs(args...; kwargs...) = print_statement_costs(stdout, args...; kwargs...)

function _which(@nospecialize(tt::Type), world=get_world_counter())
min_valid = RefValue{UInt}(typemin(UInt))
max_valid = RefValue{UInt}(typemax(UInt))
match = ccall(:jl_gf_invoke_lookup_worlds, Any,
(Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}),
tt, world, min_valid, max_valid)
if match === nothing
result = Core.Compiler._findsup(tt, nothing, world)
if result === nothing
error("no unique matching method found for the specified argument types")
end
return match::Core.MethodMatch
return first(result)::Core.MethodMatch
end

"""
Expand Down Expand Up @@ -1478,7 +1474,7 @@ true
function hasmethod(@nospecialize(f), @nospecialize(t); world::UInt=get_world_counter())
t = to_tuple_type(t)
t = signature_type(f, t)
return ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), t, world) !== nothing
return ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), t, nothing, world) !== nothing
end

function hasmethod(@nospecialize(f), @nospecialize(t), kwnames::Tuple{Vararg{Symbol}}; world::UInt=get_world_counter())
Expand Down
23 changes: 12 additions & 11 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -1217,7 +1217,7 @@ static jl_method_instance_t *cache_method(
return newmeth;
}

static jl_method_match_t *_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROOT, size_t world, size_t *min_valid, size_t *max_valid);
static jl_method_match_t *_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROOT, jl_value_t *mt, size_t world, size_t *min_valid, size_t *max_valid);

static jl_method_instance_t *jl_mt_assoc_by_type(jl_methtable_t *mt JL_PROPAGATES_ROOT, jl_datatype_t *tt, size_t world)
{
Expand All @@ -1237,7 +1237,7 @@ static jl_method_instance_t *jl_mt_assoc_by_type(jl_methtable_t *mt JL_PROPAGATE

size_t min_valid = 0;
size_t max_valid = ~(size_t)0;
jl_method_match_t *matc = _gf_invoke_lookup((jl_value_t*)tt, world, &min_valid, &max_valid);
jl_method_match_t *matc = _gf_invoke_lookup((jl_value_t*)tt, jl_nothing, world, &min_valid, &max_valid);
jl_method_instance_t *nf = NULL;
if (matc) {
JL_GC_PUSH1(&matc);
Expand Down Expand Up @@ -2549,36 +2549,37 @@ JL_DLLEXPORT jl_value_t *jl_apply_generic(jl_value_t *F, jl_value_t **args, uint
return _jl_invoke(F, args, nargs, mfunc, world);
}

static jl_method_match_t *_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROOT, size_t world, size_t *min_valid, size_t *max_valid)
static jl_method_match_t *_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROOT, jl_value_t *mt, size_t world, size_t *min_valid, size_t *max_valid)
{
jl_value_t *unw = jl_unwrap_unionall((jl_value_t*)types);
if (jl_is_tuple_type(unw) && jl_tparam0(unw) == jl_bottom_type)
return NULL;
jl_methtable_t *mt = jl_method_table_for(unw);
if ((jl_value_t*)mt == jl_nothing)
if (mt == jl_nothing)
mt = (jl_value_t*)jl_method_table_for(unw);
if (mt == jl_nothing)
mt = NULL;
jl_value_t *matches = ml_matches(mt, (jl_tupletype_t*)types, 1, 0, 0, world, 1, min_valid, max_valid, NULL);
jl_value_t *matches = ml_matches((jl_methtable_t*)mt, (jl_tupletype_t*)types, 1, 0, 0, world, 1, min_valid, max_valid, NULL);
if (matches == jl_false || jl_array_len(matches) != 1)
return NULL;
jl_method_match_t *matc = (jl_method_match_t*)jl_array_ptr_ref(matches, 0);
return matc;
}

JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types, size_t world)
JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types, jl_value_t *mt, size_t world)
{
// Deprecated: Use jl_gf_invoke_lookup_worlds for future development
size_t min_valid = 0;
size_t max_valid = ~(size_t)0;
jl_method_match_t *matc = _gf_invoke_lookup(types, world, &min_valid, &max_valid);
jl_method_match_t *matc = _gf_invoke_lookup(types, mt, world, &min_valid, &max_valid);
if (matc == NULL)
return jl_nothing;
return (jl_value_t*)matc->method;
}


JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup_worlds(jl_value_t *types, size_t world, size_t *min_world, size_t *max_world)
JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup_worlds(jl_value_t *types, jl_value_t *mt, size_t world, size_t *min_world, size_t *max_world)
{
jl_method_match_t *matc = _gf_invoke_lookup(types, world, min_world, max_world);
jl_method_match_t *matc = _gf_invoke_lookup(types, mt, world, min_world, max_world);
if (matc == NULL)
return jl_nothing;
return (jl_value_t*)matc;
Expand All @@ -2599,7 +2600,7 @@ jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t *gf, jl_value_t **args,
jl_value_t *types = NULL;
JL_GC_PUSH1(&types);
types = jl_argtype_with_function(gf, types0);
jl_method_t *method = (jl_method_t*)jl_gf_invoke_lookup(types, world);
jl_method_t *method = (jl_method_t*)jl_gf_invoke_lookup(types, jl_nothing, world);
JL_GC_PROMISE_ROOTED(method);

if ((jl_value_t*)method == jl_nothing) {
Expand Down
2 changes: 1 addition & 1 deletion stdlib/Test/src/Test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1774,7 +1774,7 @@ function detect_unbound_args(mods...;
params = tuple_sig.parameters[1:(end - 1)]
tuple_sig = Base.rewrap_unionall(Tuple{params...}, m.sig)
world = Base.get_world_counter()
mf = ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), tuple_sig, world)
mf = ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), tuple_sig, nothing, world)
if mf !== nothing && mf !== m && mf.sig <: tuple_sig
continue
end
Expand Down
18 changes: 18 additions & 0 deletions test/compiler/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,21 @@ CC.method_table(interp::MTOverlayInterp) = CC.OverlayMethodTable(CC.get_world_co
@test Base.return_types((Int,), MTOverlayInterp()) do x
sin(x)
end == Any[Int]
@test Base.return_types((Any,), MTOverlayInterp()) do x
Base.@invoke sin(x::Float64)
end == Any[Int]

# fallback to the internal method table
@test Base.return_types((Int,), MTOverlayInterp()) do x
cos(x)
end == Any[Float64]
@test Base.return_types((Any,), MTOverlayInterp()) do x
Base.@invoke cos(x::Float64)
end == Any[Float64]

# not fully covered overlay method match
overlay_match(::Any) = nothing
@overlay OverlayedMT overlay_match(::Int) = missing
@test Base.return_types((Any,), MTOverlayInterp()) do x
overlay_match(x)
end == Any[Union{Nothing,Missing}]