Skip to content

Commit

Permalink
AbstractInterpreter: enable selective pure/concrete eval for extern…
Browse files Browse the repository at this point in the history
…al `AbstractInterpreter` with overlayed method table

Built on top of #44511 and #44561, and solves <JuliaGPU/GPUCompiler.jl#309>.
This commit allows external `AbstractInterpreter` to selectively use
pure/concrete evals even if it uses an overlayed method table.
More specifically, such `AbstractInterpreter` can use pure/concrete evals
as far as any callees used in a call in question doesn't come from the
overlayed method table:
```julia
@test Base.return_types((), MTOverlayInterp()) do
    isbitstype(Int) ? nothing : missing
end == Any[Nothing]
Base.@assume_effects :terminates_globally function issue41694(x)
    res = 1
    1 < x < 20 || throw("bad")
    while x > 1
        res *= x
        x -= 1
    end
    return res
end
@test Base.return_types((), MTOverlayInterp()) do
    issue41694(3) == 6 ? nothing : missing
end == Any[Nothing]
```

In order to check if a call is tainted by any overlayed call, our effect
system now additionally tracks `overlayed::Bool` property. This effect
property is required to prevents concrete-eval in the following kind of situation:
```julia
strangesin(x) = sin(x)
@overlay OverlayedMT strangesin(x::Float64) = iszero(x) ? nothing : cos(x)
Base.@assume_effects :total totalcall(f, args...) = f(args...)
@test Base.return_types(; interp=MTOverlayInterp()) do
    # we need to disable partial pure/concrete evaluation when tainted by any overlayed call
    if totalcall(strangesin, 1.0) == cos(1.0)
        return nothing
    else
        return missing
    end
end |> only === Nothing
```
  • Loading branch information
aviatesk committed Mar 15, 2022
1 parent b2890d5 commit dba4df7
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 98 deletions.
117 changes: 75 additions & 42 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ mutable struct InferenceState
#=parent=#nothing,
#=cached=#cache === :global,
#=inferred=#false, #=dont_work_on_me=#false,
#=ipo_effects=#Effects(consistent, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, inbounds_taints_consistency),
#=ipo_effects=#Effects(consistent, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, false, inbounds_taints_consistency),
interp)
result.result = frame
cache !== :no && push!(get_inference_cache(interp), result)
Expand Down
60 changes: 35 additions & 25 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,18 @@ end
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch

"""
findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) -> MethodLookupResult or missing
findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) ->
(matches::MethodLookupResult, overlayed::Bool) or missing
Find all methods in the given method table `view` that are applicable to the
given signature `sig`. If no applicable methods are found, an empty result is
returned. If the number of applicable methods exceeded the specified limit,
`missing` is returned.
Find all methods in the given method table `view` that are applicable to the given signature `sig`.
If no applicable methods are found, an empty result is returned.
If the number of applicable methods exceeded the specified limit, `missing` is returned.
`overlayed` indicates if any matching method is defined in an overlayed method table.
"""
function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=Int(typemax(Int32)))
return _findall(sig, nothing, table.world, limit)
result = _findall(sig, nothing, table.world, limit)
result === missing && return missing
return result, false
end

function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=Int(typemax(Int32)))
Expand All @@ -57,7 +60,7 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
nr = length(result)
if nr 1 && result[nr].fully_covers
# no need to fall back to the internal method table
return result
return result, true
end
# fall back to the internal method table
fallback_result = _findall(sig, nothing, table.world, limit)
Expand All @@ -68,7 +71,7 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
WorldRange(
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
result.ambig | fallback_result.ambig)
result.ambig | fallback_result.ambig), !isempty(result)
end

function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int)
Expand All @@ -83,31 +86,38 @@ function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable},
end

"""
findsup(sig::Type, view::MethodTableView) -> Tuple{MethodMatch, WorldRange} or nothing
Find the (unique) method `m` such that `sig <: m.sig`, while being more
specific than any other method with the same property. In other words, find
the method which is the least upper bound (supremum) under the specificity/subtype
relation of the queried `signature`. If `sig` is concrete, this is equivalent to
asking for the method that will be called given arguments whose types match the
given signature. This query is also used to implement `invoke`.
Such a method `m` need not exist. It is possible that no method is an
upper bound of `sig`, or it is possible that among the upper bounds, there
is no least element. In both cases `nothing` is returned.
findsup(sig::Type, view::MethodTableView) ->
(match::MethodMatch, valid_worlds::WorldRange, overlayed::Bool) or nothing
Find the (unique) method such that `sig <: match.method.sig`, while being more
specific than any other method with the same property. In other words, find the method
which is the least upper bound (supremum) under the specificity/subtype relation of
the queried `sig`nature. If `sig` is concrete, this is equivalent to asking for the method
that will be called given arguments whose types match the given signature.
Note that this query is also used to implement `invoke`.
Such a matching method `match` doesn't necessarily exist.
It is possible that no method is an upper bound of `sig`, or
it is possible that among the upper bounds, there is no least element.
In both cases `nothing` is returned.
`overlayed` indicates if the matching method is defined in an overlayed method table.
"""
function findsup(@nospecialize(sig::Type), table::InternalMethodTable)
return _findsup(sig, nothing, table.world)
return (_findsup(sig, nothing, table.world)..., false)
end

function findsup(@nospecialize(sig::Type), table::OverlayMethodTable)
match, valid_worlds = _findsup(sig, table.mt, table.world)
match !== nothing && return match, valid_worlds
match !== nothing && return match, valid_worlds, true
# fall back to the internal method table
fallback_match, fallback_valid_worlds = _findsup(sig, nothing, table.world)
return fallback_match, WorldRange(
max(valid_worlds.min_world, fallback_valid_worlds.min_world),
min(valid_worlds.max_world, fallback_valid_worlds.max_world))
return (
fallback_match,
WorldRange(
max(valid_worlds.min_world, fallback_valid_worlds.min_world),
min(valid_worlds.max_world, fallback_valid_worlds.max_world)),
false)
end

function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt)
Expand Down
1 change: 1 addition & 0 deletions base/compiler/ssair/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,7 @@ function Base.show(io::IO, e::Core.Compiler.Effects)
print(io, ',')
printstyled(io, string(tristate_letter(e.terminates), 't'); color=tristate_color(e.terminates))
print(io, ')')
e.overlayed && printstyled(io, ''; color=:red)
end

@specialize
12 changes: 8 additions & 4 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1789,11 +1789,11 @@ function builtin_effects(f::Builtin, argtypes::Vector{Any}, rt)
if (f === Core.getfield || f === Core.isdefined) && length(argtypes) >= 3
# consistent if the argtype is immutable
if isvarargtype(argtypes[2])
return Effects(Effects(), effect_free=ALWAYS_TRUE, terminates=ALWAYS_TRUE)
return Effects(; effect_free=ALWAYS_TRUE, terminates=ALWAYS_TRUE, overlayed=false)
end
s = widenconst(argtypes[2])
if isType(s) || !isa(s, DataType) || isabstracttype(s)
return Effects(Effects(), effect_free=ALWAYS_TRUE, terminates=ALWAYS_TRUE)
return Effects(; effect_free=ALWAYS_TRUE, terminates=ALWAYS_TRUE, overlayed=false)
end
s = s::DataType
ipo_consistent = !ismutabletype(s)
Expand Down Expand Up @@ -1826,7 +1826,9 @@ function builtin_effects(f::Builtin, argtypes::Vector{Any}, rt)
ipo_consistent ? ALWAYS_TRUE : ALWAYS_FALSE,
effect_free ? ALWAYS_TRUE : ALWAYS_FALSE,
nothrow ? ALWAYS_TRUE : TRISTATE_UNKNOWN,
ALWAYS_TRUE)
#=terminates=#ALWAYS_TRUE,
#=overlayed=#false,
)
end

function builtin_nothrow(@nospecialize(f), argtypes::Array{Any, 1}, @nospecialize(rt))
Expand Down Expand Up @@ -2007,7 +2009,9 @@ function intrinsic_effects(f::IntrinsicFunction, argtypes::Vector{Any})
ipo_consistent ? ALWAYS_TRUE : ALWAYS_FALSE,
effect_free ? ALWAYS_TRUE : ALWAYS_FALSE,
nothrow ? ALWAYS_TRUE : TRISTATE_UNKNOWN,
ALWAYS_TRUE)
#=terminates=#ALWAYS_TRUE,
#=overlayed=#false,
)
end

# TODO: this function is a very buggy and poor model of the return_type function
Expand Down
6 changes: 3 additions & 3 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ function rt_adjust_effects(@nospecialize(rt), ipo_effects::Effects)
# but we don't currently model idempontency using dataflow, so we don't notice.
# Fix that up here to improve precision.
if !ipo_effects.inbounds_taints_consistency && rt === Union{}
return Effects(ipo_effects, consistent=ALWAYS_TRUE)
return Effects(ipo_effects; consistent=ALWAYS_TRUE)
end
return ipo_effects
end
Expand Down Expand Up @@ -755,11 +755,11 @@ function merge_call_chain!(parent::InferenceState, ancestor::InferenceState, chi
# and ensure that walking the parent list will get the same result (DAG) from everywhere
# Also taint the termination effect, because we can no longer guarantee the absence
# of recursion.
tristate_merge!(parent, Effects(EFFECTS_TOTAL, terminates=TRISTATE_UNKNOWN))
tristate_merge!(parent, Effects(EFFECTS_TOTAL; terminates=TRISTATE_UNKNOWN))
while true
add_cycle_backedge!(child, parent, parent.currpc)
union_caller_cycle!(ancestor, child)
tristate_merge!(child, Effects(EFFECTS_TOTAL, terminates=TRISTATE_UNKNOWN))
tristate_merge!(child, Effects(EFFECTS_TOTAL; terminates=TRISTATE_UNKNOWN))
child = parent
child === ancestor && break
parent = child.parent::InferenceState
Expand Down
36 changes: 22 additions & 14 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct Effects
effect_free::TriState
nothrow::TriState
terminates::TriState
overlayed::Bool
# This effect is currently only tracked in inference and modified
# :consistent before caching. We may want to track it in the future.
inbounds_taints_consistency::Bool
Expand All @@ -46,27 +47,33 @@ function Effects(
consistent::TriState,
effect_free::TriState,
nothrow::TriState,
terminates::TriState)
terminates::TriState,
overlayed::Bool)
return Effects(
consistent,
effect_free,
nothrow,
terminates,
overlayed,
false)
end
Effects() = Effects(TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, TRISTATE_UNKNOWN)

function Effects(e::Effects;
const EFFECTS_TOTAL = Effects(ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, false)
const EFFECTS_UNKNOWN = Effects(TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, true)

function Effects(e::Effects = EFFECTS_UNKNOWN;
consistent::TriState = e.consistent,
effect_free::TriState = e.effect_free,
nothrow::TriState = e.nothrow,
terminates::TriState = e.terminates,
overlayed::Bool = e.overlayed,
inbounds_taints_consistency::Bool = e.inbounds_taints_consistency)
return Effects(
consistent,
effect_free,
nothrow,
terminates,
overlayed,
inbounds_taints_consistency)
end

Expand All @@ -82,20 +89,20 @@ is_removable_if_unused(effects::Effects) =
effects.terminates === ALWAYS_TRUE &&
effects.nothrow === ALWAYS_TRUE

const EFFECTS_TOTAL = Effects(ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE)

function encode_effects(e::Effects)
return e.consistent.state |
(e.effect_free.state << 2) |
(e.nothrow.state << 4) |
(e.terminates.state << 6)
return (e.consistent.state << 1) |
(e.effect_free.state << 3) |
(e.nothrow.state << 5) |
(e.terminates.state << 7) |
(e.overlayed)
end
function decode_effects(e::UInt8)
return Effects(
TriState(e & 0x3),
TriState((e >> 2) & 0x3),
TriState((e >> 4) & 0x3),
TriState((e >> 6) & 0x3),
TriState((e >> 1) & 0x03),
TriState((e >> 3) & 0x03),
TriState((e >> 5) & 0x03),
TriState((e >> 7) & 0x03),
e & 0x01 0x00,
false)
end

Expand All @@ -109,6 +116,7 @@ function tristate_merge(old::Effects, new::Effects)
old.nothrow, new.nothrow),
tristate_merge(
old.terminates, new.terminates),
old.overlayed | new.overlayed,
old.inbounds_taints_consistency | new.inbounds_taints_consistency)
end

Expand Down Expand Up @@ -158,7 +166,7 @@ mutable struct InferenceResult
arginfo#=::Union{Nothing,Tuple{ArgInfo,InferenceState}}=# = nothing)
argtypes, overridden_by_const = matching_cache_argtypes(linfo, arginfo)
return new(linfo, argtypes, overridden_by_const, Any, nothing,
WorldRange(), Effects(), Effects(), nothing)
WorldRange(), Effects(; overlayed=false), Effects(; overlayed=false), nothing)
end
end

Expand Down
46 changes: 37 additions & 9 deletions test/compiler/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,53 @@ import Base.Experimental: @MethodTable, @overlay
@MethodTable(OverlayedMT)
CC.method_table(interp::MTOverlayInterp) = CC.OverlayMethodTable(CC.get_world_counter(interp), OverlayedMT)

@overlay OverlayedMT sin(x::Float64) = 1
@test Base.return_types((Int,), MTOverlayInterp()) do x
sin(x)
end == Any[Int]
strangesin(x) = sin(x)
@overlay OverlayedMT strangesin(x::Float64) = iszero(x) ? nothing : cos(x)
@test Base.return_types((Float64,), MTOverlayInterp()) do x
strangesin(x)
end |> only === Union{Float64,Nothing}
@test Base.return_types((Any,), MTOverlayInterp()) do x
Base.@invoke sin(x::Float64)
end == Any[Int]
Base.@invoke strangesin(x::Float64)
end |> only === Union{Float64,Nothing}

# fallback to the internal method table
@test Base.return_types((Int,), MTOverlayInterp()) do x
cos(x)
end == Any[Float64]
end |> only === Float64
@test Base.return_types((Any,), MTOverlayInterp()) do x
Base.@invoke cos(x::Float64)
end == Any[Float64]
end |> only === Float64

# not fully covered overlay method match
overlay_match(::Any) = nothing
@overlay OverlayedMT overlay_match(::Int) = missing
@test Base.return_types((Any,), MTOverlayInterp()) do x
overlay_match(x)
end == Any[Union{Nothing,Missing}]
end |> only === Union{Nothing,Missing}

# partial pure/concrete evaluation
@test Base.return_types((), MTOverlayInterp()) do
isbitstype(Int) ? nothing : missing
end |> only === Nothing
Base.@assume_effects :terminates_globally function issue41694(x)
res = 1
1 < x < 20 || throw("bad")
while x > 1
res *= x
x -= 1
end
return res
end
@test Base.return_types((), MTOverlayInterp()) do
issue41694(3) == 6 ? nothing : missing
end |> only === Nothing

# disable partial pure/concrete evaluation when tainted by any overlayed call
Base.@assume_effects :total totalcall(f, args...) = f(args...)
@test Base.return_types((), MTOverlayInterp()) do
if totalcall(strangesin, 1.0) == cos(1.0)
return nothing
else
return missing
end
end |> only === Nothing

0 comments on commit dba4df7

Please sign in to comment.