Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: apply tmerge limit elementwise to the Union #50927

Merged
merged 3 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3123,11 +3123,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
ssavaluetypes[currpc] = Any
continue
end
if !isempty(frame.ssavalue_uses[currpc])
record_ssa_assign!(𝕃ᵢ, currpc, type, frame)
else
ssavaluetypes[currpc] = type
end
record_ssa_assign!(𝕃ᵢ, currpc, type, frame)
end # for currpc in bbstart:bbend

# Case 1: Fallthrough termination
Expand Down
7 changes: 2 additions & 5 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -584,11 +584,8 @@ _topmod(sv::InferenceState) = _topmod(frame_module(sv))
function record_ssa_assign!(𝕃ᵢ::AbstractLattice, ssa_id::Int, @nospecialize(new), frame::InferenceState)
ssavaluetypes = frame.ssavaluetypes
old = ssavaluetypes[ssa_id]
if old === NOT_FOUND || !⊑(𝕃ᵢ, new, old)
# typically, we expect that old ⊑ new (that output information only
# gets less precise with worse input information), but to actually
# guarantee convergence we need to use tmerge here to ensure that is true
ssavaluetypes[ssa_id] = old === NOT_FOUND ? new : tmerge(𝕃ᵢ, old, new)
if old === NOT_FOUND || !is_lattice_equal(𝕃ᵢ, new, old)
ssavaluetypes[ssa_id] = new
W = frame.ip
for r in frame.ssavalue_uses[ssa_id]
if was_reached(frame, r)
Expand Down
112 changes: 87 additions & 25 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,17 @@ union_count_abstract(x::Union) = union_count_abstract(x.a) + union_count_abstrac
union_count_abstract(@nospecialize(x)) = !isdispatchelem(x)

function issimpleenoughtype(@nospecialize t)
ut = unwrap_unionall(t)
ut isa DataType && ut.name.wrapper == t && return true
return unionlen(t) + union_count_abstract(t) <= MAX_TYPEUNION_LENGTH &&
unioncomplexity(t) <= MAX_TYPEUNION_COMPLEXITY
end

# We may want to apply a stricter limit than issimpleenoughtype to
# tupleelements individually, to try to keep the whole tuple under the limit,
# even after complicated recursion and other operations on it elsewhere
const issimpleenoughtupleelem = issimpleenoughtype

# A simplified type_more_complex query over the extended lattice
# (assumes typeb ⊑ typea)
@nospecializeinfer function issimplertype(𝕃::AbstractLattice, @nospecialize(typea), @nospecialize(typeb))
Expand Down Expand Up @@ -692,6 +699,33 @@ end
return tmerge_types_slow(typea, typeb)
end

@nospecializeinfer @noinline function tname_intersect(aname::Core.TypeName, bname::Core.TypeName)
aname === bname && return aname
if !isabstracttype(aname.wrapper) && !isabstracttype(bname.wrapper)
return nothing # fast path
end
Any.name === aname && return aname
a = unwrap_unionall(aname.wrapper)
heighta = 0
while a !== Any
heighta += 1
a = a.super
end
b = unwrap_unionall(bname.wrapper)
heightb = 0
while b !== Any
b.name === aname && return aname
heightb += 1
b = b.super
end
a = unwrap_unionall(aname.wrapper)
while heighta > heightb
a = a.super
heighta -= 1
end
return a.name === bname ? bname : nothing
end

@nospecializeinfer @noinline function tmerge_types_slow(@nospecialize(typea::Type), @nospecialize(typeb::Type))
# collect the list of types from past tmerge calls returning Union
# and then reduce over that list
Expand All @@ -715,74 +749,95 @@ end
# see if any of the union elements have the same TypeName
# in which case, simplify this tmerge by replacing it with
# the widest possible version of itself (the wrapper)
simplify = falses(length(types))
for i in 1:length(types)
typenames[i] === Any.name && continue
ti = types[i]
for j in (i + 1):length(types)
if typenames[i] === typenames[j]
typenames[j] === Any.name && continue
ijname = tname_intersect(typenames[i], typenames[j])
if !(ijname === nothing)
tj = types[j]
if ti <: tj
types[i] = Union{}
typenames[i] = Any.name
simplify[i] = false
simplify[j] = true
break
elseif tj <: ti
types[j] = Union{}
typenames[j] = Any.name
simplify[j] = false
simplify[i] = true
else
if typenames[i] === Tuple.name
if ijname === Tuple.name
# try to widen Tuple slower: make a single non-concrete Tuple containing both
# converge the Tuple element-wise if they are the same length
# see 4ee2b41552a6bc95465c12ca66146d69b354317b, be59686f7613a2ccfd63491c7b354d0b16a95c05,
widen = tuplemerge(unwrap_unionall(ti)::DataType, unwrap_unionall(tj)::DataType)
widen = rewrap_unionall(rewrap_unionall(widen, ti), tj)
simplify[j] = false
else
wr = typenames[i].wrapper
wr = ijname.wrapper
uw = unwrap_unionall(wr)::DataType
ui = unwrap_unionall(ti)::DataType
while ui.name !== ijname
ui = ui.super
end
uj = unwrap_unionall(tj)::DataType
merged = wr
while uj.name !== ijname
uj = uj.super
end
p = Vector{Any}(undef, length(uw.parameters))
usep = true
widen = wr
for k = 1:length(uw.parameters)
ui_k = ui.parameters[k]
if ui_k === uj.parameters[k] && !has_free_typevars(ui_k)
merged = merged{ui_k}
p[k] = ui_k
usep = true
else
merged = merged{uw.parameters[k]}
p[k] = uw.parameters[k]
end
end
widen = rewrap_unionall(merged, wr)
if usep
widen = rewrap_unionall(wr{p...}, wr)
end
simplify[j] = !usep
end
types[i] = Union{}
typenames[i] = Any.name
simplify[i] = false
types[j] = widen
break
end
end
end
end
u = Union{types...}
# don't let type unions get too big, if the above didn't reduce it enough
if issimpleenoughtype(u)
return u
end
# don't let the slow widening of Tuple cause the whole type to grow too fast
# don't let elements of the union get too big, if the above didn't reduce something enough
# Specifically widen Tuple{..., Union{lots of stuff}...} to Tuple{..., Any, ...}
# Don't let Val{<:Val{<:Val}} keep nesting abstract levels either
for i in 1:length(types)
simplify[i] || continue
ti = types[i]
issimpleenoughtype(ti) && continue
if typenames[i] === Tuple.name
ti = types[i]
tip = (unwrap_unionall(types[i])::DataType).parameters
# otherwise we need to do a simple version of tuplemerge for one element now
tip = (unwrap_unionall(ti)::DataType).parameters
lt = length(tip)
p = Vector{Any}(undef, lt)
for j = 1:lt
ui = tip[j]
p[j] = (unioncomplexity(ui)==0) ? ui : isvarargtype(ui) ? Vararg : Any
p[j] = issimpleenoughtupleelem(unwrapva(ui)) ? ui : isvarargtype(ui) ? Vararg : Any
end
types[i] = rewrap_unionall(Tuple{p...}, ti)
else
# this element is not simple enough yet, make it so now
types[i] = typenames[i].wrapper
end
end
u = Union{types...}
if issimpleenoughtype(u)
return u
end
return Any
return u
end

# the inverse of switchtupleunion, with limits on max element union size
Expand All @@ -804,7 +859,7 @@ function tuplemerge(a::DataType, b::DataType)
p = Vector{Any}(undef, lt + vt)
for i = 1:lt
ui = Union{ap[i], bp[i]}
p[i] = issimpleenoughtype(ui) ? ui : Any
p[i] = issimpleenoughtupleelem(ui) ? ui : Any
end
# merge the remaining tail into a single, simple Tuple{Vararg{T}} (#22120)
if vt
Expand All @@ -822,8 +877,10 @@ function tuplemerge(a::DataType, b::DataType)
# or (equivalently?) iteratively took super-types until reaching a common wrapper
# e.g. consider the results of `tuplemerge(Tuple{Complex}, Tuple{Number, Int})` and of
# `tuplemerge(Tuple{Int}, Tuple{String}, Tuple{Int, String})`
if !(ti <: tail)
if tail <: ti
# c.f. tname_intersect in the algorithm above
hasfree = has_free_typevars(ti)
if hasfree || !(ti <: tail)
if !hasfree && tail <: ti
tail = ti # widen to ti
else
uw = unwrap_unionall(tail)
Expand Down Expand Up @@ -851,11 +908,16 @@ function tuplemerge(a::DataType, b::DataType)
end
end
end
tail === Any && return Tuple # short-circuit loop
tail === Any && return Tuple # short-circuit loops
end
end
@assert !(tail === Union{})
p[lt + 1] = Vararg{tail}
if !issimpleenoughtupleelem(tail) || tail === Any
p[lt + 1] = Vararg
lt == 0 && return Tuple
else
p[lt + 1] = Vararg{tail}
end
end
return Tuple{p...}
end
4 changes: 2 additions & 2 deletions base/compiler/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ end
unioncomplexity(@nospecialize x) = _unioncomplexity(x)::Int
function _unioncomplexity(@nospecialize x)
if isa(x, DataType)
x.name === Tuple.name || isvarargtype(x) || return 0
x.name === Tuple.name || return 0
c = 0
for ti in x.parameters
c = max(c, unioncomplexity(ti))
Expand All @@ -302,7 +302,7 @@ function _unioncomplexity(@nospecialize x)
elseif isa(x, UnionAll)
return max(unioncomplexity(x.body), unioncomplexity(x.var.ub))
elseif isa(x, TypeofVararg)
return isdefined(x, :T) ? unioncomplexity(x.T) : 0
return isdefined(x, :T) ? unioncomplexity(x.T) + 1 : 1
else
return 0
end
Expand Down
84 changes: 58 additions & 26 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,48 +156,48 @@ Base.ndims(g::e43296) = ndims(typeof(g))
@test Core.Compiler.unioncomplexity(Tuple{Union{Int8, Int16, Int32, Int64}}) == 3
@test Core.Compiler.unioncomplexity(Union{Int8, Int16, Int32, T} where T) == 3
@test Core.Compiler.unioncomplexity(Tuple{Val{T}, Union{Int8, Int16}, Int8} where T<:Union{Int8, Int16, Int32, Int64}) == 3
@test Core.Compiler.unioncomplexity(Tuple{Vararg{Tuple{Union{Int8, Int16}}}}) == 1
@test Core.Compiler.unioncomplexity(Tuple{Vararg{Symbol}}) == 0
@test Core.Compiler.unioncomplexity(Tuple{Vararg{Union{Symbol, Tuple{Vararg{Symbol}}}}}) == 1
@test Core.Compiler.unioncomplexity(Tuple{Vararg{Union{Symbol, Tuple{Vararg{Union{Symbol, Tuple{Vararg{Symbol}}}}}}}}) == 2
@test Core.Compiler.unioncomplexity(Tuple{Vararg{Union{Symbol, Tuple{Vararg{Union{Symbol, Tuple{Vararg{Union{Symbol, Tuple{Vararg{Symbol}}}}}}}}}}}) == 3
@test Core.Compiler.unioncomplexity(Tuple{Vararg{Tuple{Union{Int8, Int16}}}}) == 2
@test Core.Compiler.unioncomplexity(Tuple{Vararg{Symbol}}) == 1
@test Core.Compiler.unioncomplexity(Tuple{Vararg{Union{Symbol, Tuple{Vararg{Symbol}}}}}) == 3
@test Core.Compiler.unioncomplexity(Tuple{Vararg{Union{Symbol, Tuple{Vararg{Union{Symbol, Tuple{Vararg{Symbol}}}}}}}}) == 5
@test Core.Compiler.unioncomplexity(Tuple{Vararg{Union{Symbol, Tuple{Vararg{Union{Symbol, Tuple{Vararg{Union{Symbol, Tuple{Vararg{Symbol}}}}}}}}}}}) == 7


# PR 22120
function tmerge_test(a, b, r, commutative=true)
function tuplemerge_test(a, b, r, commutative=true)
@test r == Core.Compiler.tuplemerge(a, b)
if commutative
@test r == Core.Compiler.tuplemerge(b, a)
else
@test_broken r == Core.Compiler.tuplemerge(b, a)
end
end
tmerge_test(Tuple{Int}, Tuple{String}, Tuple{Union{Int, String}})
tmerge_test(Tuple{Int}, Tuple{String, String}, Tuple)
tmerge_test(Tuple{Vararg{Int}}, Tuple{String}, Tuple)
tmerge_test(Tuple{Int}, Tuple{Int, Int},
tuplemerge_test(Tuple{Int}, Tuple{String}, Tuple{Union{Int, String}})
tuplemerge_test(Tuple{Int}, Tuple{String, String}, Tuple)
tuplemerge_test(Tuple{Vararg{Int}}, Tuple{String}, Tuple)
tuplemerge_test(Tuple{Int}, Tuple{Int, Int},
Tuple{Vararg{Int}})
tmerge_test(Tuple{Integer}, Tuple{Int, Int},
tuplemerge_test(Tuple{Integer}, Tuple{Int, Int},
Tuple{Vararg{Integer}})
tmerge_test(Tuple{}, Tuple{Int, Int},
tuplemerge_test(Tuple{}, Tuple{Int, Int},
Tuple{Vararg{Int}})
tmerge_test(Tuple{}, Tuple{Complex},
tuplemerge_test(Tuple{}, Tuple{Complex},
Tuple{Vararg{Complex}})
tmerge_test(Tuple{ComplexF32}, Tuple{ComplexF32, ComplexF64},
tuplemerge_test(Tuple{ComplexF32}, Tuple{ComplexF32, ComplexF64},
Tuple{Vararg{Complex}})
tmerge_test(Tuple{Vararg{ComplexF32}}, Tuple{Vararg{ComplexF64}},
tuplemerge_test(Tuple{Vararg{ComplexF32}}, Tuple{Vararg{ComplexF64}},
Tuple{Vararg{Complex}})
tmerge_test(Tuple{}, Tuple{ComplexF32, Vararg{Union{ComplexF32, ComplexF64}}},
tuplemerge_test(Tuple{}, Tuple{ComplexF32, Vararg{Union{ComplexF32, ComplexF64}}},
Tuple{Vararg{Union{ComplexF32, ComplexF64}}})
tmerge_test(Tuple{ComplexF32}, Tuple{ComplexF32, Vararg{Union{ComplexF32, ComplexF64}}},
tuplemerge_test(Tuple{ComplexF32}, Tuple{ComplexF32, Vararg{Union{ComplexF32, ComplexF64}}},
Tuple{Vararg{Union{ComplexF32, ComplexF64}}})
tmerge_test(Tuple{ComplexF32, ComplexF32, ComplexF32}, Tuple{ComplexF32, Vararg{Union{ComplexF32, ComplexF64}}},
tuplemerge_test(Tuple{ComplexF32, ComplexF32, ComplexF32}, Tuple{ComplexF32, Vararg{Union{ComplexF32, ComplexF64}}},
Tuple{Vararg{Union{ComplexF32, ComplexF64}}})
tmerge_test(Tuple{}, Tuple{Union{ComplexF64, ComplexF32}, Vararg{Union{ComplexF32, ComplexF64}}},
tuplemerge_test(Tuple{}, Tuple{Union{ComplexF64, ComplexF32}, Vararg{Union{ComplexF32, ComplexF64}}},
Tuple{Vararg{Union{ComplexF32, ComplexF64}}})
tmerge_test(Tuple{ComplexF64, ComplexF64, ComplexF32}, Tuple{Vararg{Union{ComplexF32, ComplexF64}}},
tuplemerge_test(Tuple{ComplexF64, ComplexF64, ComplexF32}, Tuple{Vararg{Union{ComplexF32, ComplexF64}}},
Tuple{Vararg{Complex}}, false)
tmerge_test(Tuple{}, Tuple{Complex, Vararg{Union{ComplexF32, ComplexF64}}},
tuplemerge_test(Tuple{}, Tuple{Complex, Vararg{Union{ComplexF32, ComplexF64}}},
Tuple{Vararg{Complex}})
@test Core.Compiler.tmerge(Tuple{}, Union{Nothing, Tuple{ComplexF32, ComplexF32}}) ==
Union{Nothing, Tuple{}, Tuple{ComplexF32, ComplexF32}}
Expand All @@ -214,9 +214,19 @@ tmerge_test(Tuple{}, Tuple{Complex, Vararg{Union{ComplexF32, ComplexF64}}},
@test Core.Compiler.tmerge(Union{}, Base.BitIntegerType) === Base.BitIntegerType
@test Core.Compiler.tmerge(Core.Compiler.fallback_ipo_lattice, Core.Compiler.InterConditional(1, Int, Union{}), Core.Compiler.InterConditional(2, String, Union{})) === Core.Compiler.Const(true)
# test issue behind https://github.com/JuliaLang/julia/issues/50458
@test Core.Compiler.tmerge(Nothing, Tuple{Base.BitInteger, Int}) == Union{Nothing, Tuple{Any, Int}}
@test Core.Compiler.tmerge(Nothing, Tuple{Union{Char, String, SubString{String}, Symbol}, Int}) == Union{Nothing, Tuple{Any, Int}}
@test Core.Compiler.tmerge(Nothing, Tuple{Base.BitInteger, Int}) == Union{Nothing, Tuple{Base.BitInteger, Int}}
@test Core.Compiler.tmerge(Union{Nothing, Tuple{Int, Int}}, Tuple{Base.BitInteger, Int}) == Union{Nothing, Tuple{Any, Int}}
@test Core.Compiler.tmerge(Nothing, Tuple{Union{Char, String, SubString{String}, Symbol}, Int}) == Union{Nothing, Tuple{Union{Char, String, SubString{String}, Symbol}, Int}}
@test Core.Compiler.tmerge(Union{Nothing, Tuple{Char, Int}}, Tuple{Union{Char, String, SubString{String}, Symbol}, Int}) == Union{Nothing, Tuple{Union{Char, String, SubString{String}, Symbol}, Int}}
@test Core.Compiler.tmerge(Nothing, Tuple{Integer, Int}) == Union{Nothing, Tuple{Integer, Int}}
@test Core.Compiler.tmerge(Union{Nothing, Tuple{Int, Int}}, Tuple{Integer, Int}) == Union{Nothing, Tuple{Integer, Int}}
@test Core.Compiler.tmerge(Union{Nothing, AbstractVector{Int}}, Vector) == Union{Nothing, AbstractVector}
@test Core.Compiler.tmerge(Union{Nothing, AbstractVector{Int}}, Matrix) == Union{Nothing, AbstractArray}
@test Core.Compiler.tmerge(Union{Nothing, AbstractVector{Int}}, Matrix{Int}) == Union{Nothing, AbstractArray{Int}}
@test Core.Compiler.tmerge(Union{Nothing, AbstractVector{Int}}, Array) == Union{Nothing, AbstractArray}
@test Core.Compiler.tmerge(Union{Nothing, AbstractArray{Int}}, Vector) == Union{Nothing, AbstractArray}
@test Core.Compiler.tmerge(Union{Nothing, AbstractVector}, Matrix{Int}) == Union{Nothing, AbstractArray}
@test Core.Compiler.tmerge(Union{Nothing, AbstractFloat}, Integer) == Union{Nothing, AbstractFloat, Integer}

# test that recursively more complicated types don't widen all the way to Any when there is a useful valid type upper bound
# Specificially test with base types of a trivial type, a simple union, a complicated union, and a tuple.
Expand Down Expand Up @@ -2886,7 +2896,7 @@ end
# issue #27316 - inference shouldn't hang on these
f27316(::Vector) = nothing
f27316(::Any) = f27316(Any[][1]), f27316(Any[][1])
let expected = NTuple{2, Union{Nothing, NTuple{2, Union{Nothing, Tuple{Any, Any}}}}}
let expected = NTuple{2, Union{Nothing, Tuple{Any, Any}}}
@test Tuple{Nothing, Nothing} <: only(Base.return_types(f27316, Tuple{Int})) == expected # we may be able to improve this bound in the future
end
function g27316()
Expand Down Expand Up @@ -3501,8 +3511,20 @@ function pickvarnames(x::Vector{Any})
end
@test pickvarnames(:a) === :a
@test pickvarnames(Any[:a, :b]) === (:a, :b)
@test only(Base.return_types(pickvarnames, (Vector{Any},))) == Tuple{Vararg{Union{Symbol, Tuple}}}
@test only(Base.code_typed(pickvarnames, (Vector{Any},), optimize=false))[2] == Tuple{Vararg{Union{Symbol, Tuple{Vararg{Union{Symbol, Tuple}}}}}}
@test only(Base.return_types(pickvarnames, (Vector{Any},))) == Tuple
@test only(Base.code_typed(pickvarnames, (Vector{Any},), optimize=false))[2] == Tuple{Vararg{Union{Symbol, Tuple}}}

# make sure this converges in a reasonable amount of time
function pickvarnames2(x::Vector{Any})
varnames = ()
for a in x
varnames = (varnames..., pickvarnames(a) )
end
return varnames
end
@test only(Base.return_types(pickvarnames2, (Vector{Any},))) == Tuple{Vararg{Union{Symbol, Tuple}}}
@test only(Base.code_typed(pickvarnames2, (Vector{Any},), optimize=false))[2] == Tuple{Vararg{Union{Symbol, Tuple}}}


@test map(>:, [Int], [Int]) == [true]

Expand Down Expand Up @@ -4597,6 +4619,16 @@ end
a = Core.Compiler.tmerge(Core.Compiler.JLTypeLattice(), Val{<:a}, a)
@test_broken a != Val{<:Val{Union{}}}
@test_broken a == Val{<:Val} || a == Val

a = Tuple{Vararg{Tuple{}}}
a = Core.Compiler.tmerge(Core.Compiler.JLTypeLattice(), Tuple{a}, a)
@test a == Tuple{Vararg{Tuple{Vararg{Tuple{}}}}}
a = Core.Compiler.tmerge(Core.Compiler.JLTypeLattice(), Tuple{a}, a)
@test a == Tuple{Vararg{Tuple{Vararg{Tuple{Vararg{Tuple{}}}}}}}
a = Core.Compiler.tmerge(Core.Compiler.JLTypeLattice(), Tuple{a}, a)
@test a == Tuple{Vararg{Tuple{Vararg{Tuple{Vararg{Tuple{Vararg{Tuple{}}}}}}}}}
a = Core.Compiler.tmerge(Core.Compiler.JLTypeLattice(), Tuple{a}, a)
@test a == Tuple
end

# Test that a function-wise `@max_methods` works as expected
Expand Down
Loading