Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow apply_type_tfunc to handle argtypes with Union #56617

Merged
merged 3 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 42 additions & 22 deletions Compiler/src/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1350,14 +1350,14 @@ end
T = _fieldtype_tfunc(𝕃, o′, f, isconcretetype(o′))
T === Bottom && return Bottom
PT = Const(Pair)
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T, T), true)[1]
return instanceof_tfunc(apply_type_tfunc(𝕃, Any[PT, T, T]), true)[1]
end
@nospecs function replacefield!_tfunc(𝕃::AbstractLattice, o, f, x, v, success_order=Symbol, failure_order=Symbol)
o′ = widenconst(o)
T = _fieldtype_tfunc(𝕃, o′, f, isconcretetype(o′))
T === Bottom && return Bottom
PT = Const(ccall(:jl_apply_cmpswap_type, Any, (Any,), T) where T)
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T), true)[1]
return instanceof_tfunc(apply_type_tfunc(𝕃, Any[PT, T]), true)[1]
end
@nospecs function setfieldonce!_tfunc(𝕃::AbstractLattice, o, f, v, success_order=Symbol, failure_order=Symbol)
setfield!_tfunc(𝕃, o, f, v) === Bottom && return Bottom
Expand Down Expand Up @@ -1713,8 +1713,12 @@ end
const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K, :_L, :_M,
:_N, :_O, :_P, :_Q, :_R, :_S, :_T, :_U, :_V, :_W, :_X, :_Y, :_Z]

# TODO: handle e.g. apply_type(T, R::Union{Type{Int32},Type{Float64}})
@nospecs function apply_type_tfunc(𝕃::AbstractLattice, headtypetype, args...)
function apply_type_tfunc(𝕃::AbstractLattice, argtypes::Vector{Any};
max_union_splitting::Int=InferenceParams().max_union_splitting)
if isempty(argtypes)
return Bottom
end
headtypetype = argtypes[1]
headtypetype = widenslotwrapper(headtypetype)
if isa(headtypetype, Const)
headtype = headtypetype.val
Expand All @@ -1723,15 +1727,15 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
else
return Any
end
if !isempty(args) && isvarargtype(args[end])
largs = length(argtypes)
if largs > 1 && isvarargtype(argtypes[end])
return isvarargtype(headtype) ? TypeofVararg : Type
end
largs = length(args)
if headtype === Union
largs == 0 && return Const(Bottom)
largs == 1 && return Const(Bottom)
hasnonType = false
for i = 1:largs
ai = args[i]
for i = 2:largs
ai = argtypes[i]
if isa(ai, Const)
if !isa(ai.val, Type)
if isa(ai.val, TypeVar)
Expand All @@ -1750,14 +1754,14 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
end
end
end
if largs == 1 # Union{T} --> T
return tmeet(widenconst(args[1]), Union{Type,TypeVar})
if largs == 2 # Union{T} --> T
return tmeet(widenconst(argtypes[2]), Union{Type,TypeVar})
end
hasnonType && return Type
ty = Union{}
allconst = true
for i = 1:largs
ai = args[i]
for i = 2:largs
ai = argtypes[i]
if isType(ai)
aty = ai.parameters[1]
allconst &= hasuniquerep(aty)
Expand All @@ -1768,6 +1772,18 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
end
return allconst ? Const(ty) : Type{ty}
end
if 1 < unionsplitcost(𝕃, argtypes) ≤ max_union_splitting
rt = Bottom
for split_argtypes = switchtupleunion(𝕃, argtypes)
this_rt = widenconst(_apply_type_tfunc(𝕃, headtype, split_argtypes))
rt = Union{rt, this_rt}
end
return rt
end
return _apply_type_tfunc(𝕃, headtype, argtypes)
end
@nospecs function _apply_type_tfunc(𝕃::AbstractLattice, headtype, argtypes::Vector{Any})
largs = length(argtypes)
istuple = headtype === Tuple
if !istuple && !isa(headtype, UnionAll) && !isvarargtype(headtype)
return Union{}
Expand All @@ -1781,20 +1797,20 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
# first push the tailing vars from headtype into outervars
outer_start, ua = 0, headtype
while isa(ua, UnionAll)
if (outer_start += 1) > largs
if (outer_start += 1) > largs - 1
push!(outervars, ua.var)
end
ua = ua.body
end
if largs > outer_start && isa(headtype, UnionAll) # e.g. !isvarargtype(ua) && !istuple
if largs - 1 > outer_start && isa(headtype, UnionAll) # e.g. !isvarargtype(ua) && !istuple
return Bottom # too many arguments
end
outer_start = outer_start - largs + 1
outer_start = outer_start - largs + 2

varnamectr = 1
ua = headtype
for i = 1:largs
ai = widenslotwrapper(args[i])
for i = 2:largs
ai = widenslotwrapper(argtypes[i])
if isType(ai)
aip1 = ai.parameters[1]
canconst &= !has_free_typevars(aip1)
Expand Down Expand Up @@ -1868,7 +1884,7 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
# If the names are known, keep the upper bound, but otherwise widen to Tuple.
# This is a widening heuristic to avoid keeping type information
# that's unlikely to be useful.
if !(uw.parameters[1] isa Tuple || (i == 2 && tparams[1] isa Tuple))
if !(uw.parameters[1] isa Tuple || (i == 3 && tparams[1] isa Tuple))
ub = Any
end
else
Expand Down Expand Up @@ -1910,7 +1926,7 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
# throwing errors.
appl = headtype
if isa(appl, UnionAll)
for _ = 1:largs
for _ = 2:largs
appl = appl::UnionAll
push!(outervars, appl.var)
appl = appl.body
Expand All @@ -1930,6 +1946,8 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
end
return ans
end
@nospecs apply_type_tfunc(𝕃::AbstractLattice, headtypetype, args...) =
apply_type_tfunc(𝕃, Any[i == 0 ? headtypetype : args[i] for i in 0:length(args)])
add_tfunc(apply_type, 1, INT_INF, apply_type_tfunc, 10)

# convert the dispatch tuple type argtype to the real (concrete) type of
Expand Down Expand Up @@ -2016,15 +2034,15 @@ end
T = _memoryref_elemtype(mem)
T === Bottom && return Bottom
PT = Const(Pair)
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T, T), true)[1]
return instanceof_tfunc(apply_type_tfunc(𝕃, Any[PT, T, T]), true)[1]
end
@nospecs function memoryrefreplace!_tfunc(𝕃::AbstractLattice, mem, x, v, success_order, failure_order, boundscheck)
memoryrefset!_tfunc(𝕃, mem, v, success_order, boundscheck) === Bottom && return Bottom
hasintersect(widenconst(failure_order), Symbol) || return Bottom
T = _memoryref_elemtype(mem)
T === Bottom && return Bottom
PT = Const(ccall(:jl_apply_cmpswap_type, Any, (Any,), T) where T)
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T), true)[1]
return instanceof_tfunc(apply_type_tfunc(𝕃, Any[PT, T]), true)[1]
end
@nospecs function memoryrefsetonce!_tfunc(𝕃::AbstractLattice, mem, v, success_order, failure_order, boundscheck)
memoryrefset!_tfunc(𝕃, mem, v, success_order, boundscheck) === Bottom && return Bottom
Expand Down Expand Up @@ -2666,6 +2684,8 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
end
end
return current_scope_tfunc(interp, sv)
elseif f === Core.apply_type
return apply_type_tfunc(𝕃ᵢ, argtypes; max_union_splitting=InferenceParams(interp).max_union_splitting)
end
fidx = find_tfunc(f)
if fidx === nothing
Expand Down
43 changes: 34 additions & 9 deletions Compiler/test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
include("irutils.jl")

# tests for Compiler correctness and precision
import .Compiler: Const, Conditional, ⊑, ReturnNode, GotoIfNot
using .Compiler: Conditional, ⊑
isdispatchelem(@nospecialize x) = !isa(x, Type) || Compiler.isdispatchelem(x)

using Random, Core.IR
Expand Down Expand Up @@ -1721,7 +1721,7 @@ g_test_constant() = (f_constant(3) == 3 && f_constant(4) == 4 ? true : "BAD")
f_pure_add() = (1 + 1 == 2) ? true : "FAIL"
@test @inferred f_pure_add()

import Core: Const
using Core: Const
mutable struct ARef{T}
@atomic x::T
end
Expand Down Expand Up @@ -1762,7 +1762,7 @@ let getfield_tfunc(@nospecialize xs...) =
@test getfield_tfunc(ARef{Int},Const(:x),Bool,Bool) === Union{}
end

import .Compiler: Const
using Core: Const
mutable struct XY{X,Y}
x::X
y::Y
Expand Down Expand Up @@ -2765,10 +2765,10 @@ end |> only === Int

# `apply_type_tfunc` accuracy for constrained type construction
# https://github.com/JuliaLang/julia/issues/47089
import Core: Const
import .Compiler: apply_type_tfunc
struct Issue47089{A<:Number,B<:Number} end
let 𝕃 = Compiler.fallback_lattice
let apply_type_tfunc = Compiler.apply_type_tfunc
𝕃 = Compiler.fallback_lattice
Const = Core.Const
A = Type{<:Integer}
@test apply_type_tfunc(𝕃, Const(Issue47089), A, A) <: (Type{Issue47089{A,B}} where {A<:Integer, B<:Integer})
@test apply_type_tfunc(𝕃, Const(Issue47089), Const(Int), Const(Int), Const(Int)) === Union{}
Expand Down Expand Up @@ -4554,7 +4554,8 @@ end |> only == Tuple{Int,Int}
end |> only == Int

# form PartialStruct for mutables with `const` field
import .Compiler: Const, ⊑
using Core: Const
using .Compiler: ⊑
mutable struct PartialMutable{S,T}
const s::S
t::T
Expand Down Expand Up @@ -5700,7 +5701,8 @@ let x = 1, _Any = Any
end

# Issue #51927
let 𝕃 = Compiler.fallback_lattice
let apply_type_tfunc = Compiler.apply_type_tfunc
𝕃 = Compiler.fallback_lattice
@test apply_type_tfunc(𝕃, Const(Tuple{Vararg{Any,N}} where N), Int) == Type{NTuple{_A, Any}} where _A
end

Expand Down Expand Up @@ -6074,6 +6076,29 @@ function issue56387(nt::NamedTuple, field::Symbol=:a)
end
@test Base.infer_return_type(issue56387, (typeof((;a=1)),)) == Type{Int}

# `apply_type_tfunc` with `Union` in its arguments
let apply_type_tfunc = Compiler.apply_type_tfunc
𝕃 = Compiler.fallback_lattice
Const = Core.Const
@test apply_type_tfunc(𝕃, Any[Const(Vector), Union{Type{Int},Type{Nothing}}]) == Union{Type{Vector{Int}},Type{Vector{Nothing}}}
end

@test Base.infer_return_type((Bool,Int,)) do b, y
x = b ? 1 : missing
inner = y -> x + y
return inner(y)
end == Union{Int,Missing}

function issue31909(ys)
x = if @noinline rand(Bool)
1
else
missing
end
map(y -> x + y, ys)
end
@test Base.infer_return_type(issue31909, (Vector{Int},)) == Union{Vector{Int},Vector{Missing}}

global setglobal!_refine::Int
@test Base.infer_return_type((Integer,)) do x
setglobal!(@__MODULE__, :setglobal!_refine, x)
Expand All @@ -6098,4 +6123,4 @@ function func_swapglobal!_must_throw(x)
swapglobal!(@__MODULE__, :swapglobal!_must_throw, x)
end
@test Base.infer_return_type(func_swapglobal!_must_throw, (Int,); interp=SwapGlobalInterp()) === Union{}
@test !Base.Compiler.is_effect_free(Base.infer_effects(func_swapglobal!_must_throw, (Int,); interp=SwapGlobalInterp()) )
@test !Compiler.is_effect_free(Base.infer_effects(func_swapglobal!_must_throw, (Int,); interp=SwapGlobalInterp()) )