From 89d0e56aa914443ebe1645b3b558cc398fa95fd3 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Thu, 23 Nov 2023 15:56:21 +0900 Subject: [PATCH] adjust to JuliaLang/julia#51754 --- src/analysis/forward.jl | 4 + src/stage1/compiler_utils.jl | 7 -- src/stage2/abstractinterpret.jl | 171 +++++++++++++++++++++++++------- src/stage2/interpreter.jl | 3 +- 4 files changed, 139 insertions(+), 46 deletions(-) diff --git a/src/analysis/forward.jl b/src/analysis/forward.jl index 8ca1a796..0e3ec12e 100644 --- a/src/analysis/forward.jl +++ b/src/analysis/forward.jl @@ -26,7 +26,11 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize frule_call = CC.abstract_call_gf_by_type(interp′, ChainRulesCore.frule, frule_arginfo, frule_si, frule_atype, sv, #=max_methods=#-1) if frule_call.rt !== Const(nothing) + @static if VERSION ≥ v"1.11.0-DEV.945" + return CallMeta(primal_call.rt, primal_call.exct, primal_call.effects, FRuleCallInfo(primal_call.info, frule_call)) + else return CallMeta(primal_call.rt, primal_call.effects, FRuleCallInfo(primal_call.info, frule_call)) + end else CC.add_mt_backedge!(sv, frule_mt, frule_atype) end diff --git a/src/stage1/compiler_utils.jl b/src/stage1/compiler_utils.jl index 651a063d..dd3daf5c 100644 --- a/src/stage1/compiler_utils.jl +++ b/src/stage1/compiler_utils.jl @@ -57,13 +57,6 @@ end Base.lastindex(x::Core.Compiler.InstructionStream) = Core.Compiler.length(x) -# Solves an error after https://github.com/JuliaLang/julia/pull/46961 -# as does https://github.com/FluxML/IRTools.jl/pull/101 -if isdefined(Core.Compiler, :CallInfo) - Base.convert(::Type{Core.Compiler.CallInfo}, ::Nothing) = Core.Compiler.NoCallInfo() -end - - """ find_end_of_phi_block(ir::IRCode, start_search_idx::Int) diff --git a/src/stage2/abstractinterpret.jl b/src/stage2/abstractinterpret.jl index c655a19b..5a709a7c 100644 --- a/src/stage2/abstractinterpret.jl +++ b/src/stage2/abstractinterpret.jl @@ -3,7 +3,7 @@ using Core.Compiler: Const, isconstType, argtypes_to_type, tuple_tfunc, Const, getfield_tfunc, _methods_by_ftype, VarTable, cache_lookup, nfields_tfunc, ArgInfo, singleton_type, CallMeta, MethodMatchInfo, specialize_method, PartialOpaque, UnionSplitApplyCallInfo, typeof_tfunc, apply_type_tfunc, instanceof_tfunc, - StmtInfo + StmtInfo, NoCallInfo using Core: PartialStruct using Base.Meta @@ -41,7 +41,11 @@ function Core.Compiler.abstract_call_gf_by_type(interp::ADInterpreter, @nospecia else rt2 = obtype end + @static if VERSION ≥ v"1.11.0-DEV.945" + return CallMeta(rt2, call.exct, call.effects, RecurseInfo(call.info)) + else return CallMeta(rt2, call.effects, RecurseInfo(call.info)) + end end # Check if there is a rrule for this function @@ -56,7 +60,12 @@ function Core.Compiler.abstract_call_gf_by_type(interp::ADInterpreter, @nospecia end call = abstract_call_gf_by_type(lower_level(interp), ChainRules.rrule, ArgInfo(nothing, rrule_argtypes), rrule_atype, sv, -1) if call.rt != Const(nothing) - return CallMeta(getfield_tfunc(call.rt, Const(1)), call.effects, RRuleInfo(call.rt, call.info)) + newrt = getfield_tfunc(call.rt, Const(1)) + @static if VERSION ≥ v"1.11.0-DEV.945" + return CallMeta(newrt, call.exct, call.effects, RRuleInfo(call.rt, call.info)) + else + return CallMeta(newrt, call.exct, call.effects, RRuleInfo(call.rt, call.info)) + end end end end @@ -74,26 +83,39 @@ function Core.Compiler.abstract_call_gf_by_type(interp::ADInterpreter, @nospecia return ret end -function abstract_accum(interp::AbstractInterpreter, args::Vector{Any}, sv::InferenceState) - args = filter(x->!(widenconst(x) <: Union{ZeroTangent, NoTangent}), args) +function abstract_accum(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState) + argtypes = filter(@nospecialize(x)->!(widenconst(x) <: Union{ZeroTangent, NoTangent}), argtypes) - if length(args) == 0 - return CallMeta(ZeroTangent, Effects(), nothing) + if length(argtypes) == 0 + @static if VERSION ≥ v"1.11.0-DEV.945" + return CallMeta(ZeroTangent, Any, Effects(), NoCallInfo()) + else + return CallMeta(ZeroTangent, Effects(), NoCallInfo()) + end end - if length(args) == 1 - return CallMeta(args[1], Effects(), nothing) + if length(argtypes) == 1 + @static if VERSION ≥ v"1.11.0-DEV.945" + return CallMeta(argtypes[1], Any, Effects(), NoCallInfo()) + else + return CallMeta(argtypes[1], Effects(), NoCallInfo()) + end end - rtype = reduce(tmerge, args) + rtype = reduce(tmerge, argtypes) if widenconst(rtype) <: Tuple targs = Any[] for i = 1:nfields_tfunc(rtype).val - push!(targs, abstract_accum(interp, Any[getfield_tfunc(arg, Const(i)) for arg in args], sv).rt) + push!(targs, abstract_accum(interp, Any[getfield_tfunc(arg, Const(i)) for arg in argtypes], sv).rt) + end + rt = tuple_tfunc(targs) + @static if VERSION ≥ v"1.11.0-DEV.945" + return CallMeta(rt, Any, Effects(), NoCallInfo()) + else + return CallMeta(rt, Effects(), NoCallInfo()) end - return CallMeta(tuple_tfunc(targs), nothing) end - call = abstract_call(change_level(interp, 0), nothing, Any[typeof(accum), args...], + call = abstract_call(change_level(interp, 0), nothing, Any[typeof(accum), argtypes...], sv::InferenceState) return call end @@ -249,7 +271,12 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp ft = argextype(inst.args[1], primal, primal.sptypes) f = singleton_type(ft) if isa(f, Core.Builtin) - call = CallMeta(backwards_tfunc(f, primal, inst, Δ), nothing) + rt = backwards_tfunc(f, primal, inst, Δ) + @static if VERSION ≥ v"1.11.0-DEV.945" + call = CallMeta(rt, Any, Effects(), NoCallInfo()) + else + call = CallMeta(rt, Effects(), NoCallInfo()) + end else bail!(inst) continue @@ -265,7 +292,12 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp arg = getfield_tfunc(Δ, Const(1)) call = abstract_call(interp, nothing, Any[clos, arg], sv) # No derivative wrt the functor - call = CallMeta(tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...]), ReifyInfo(call.info)) + rt = tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...]) + @static if VERSION ≥ v"1.11.0-DEV.945" + call = CallMeta(rt, Any, Effects(), ReifyInfo(call.info)) + else + call = CallMeta(rt, Effects(), ReifyInfo(call.info)) + end else (level, close) = derive_closure_type(call_info) call = abstract_call(change_level(interp, level), ArgInfo(nothing, Any[close, Δ]), sv) @@ -274,13 +306,23 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp if isa(info, UnionSplitApplyCallInfo) argts = Any[argextype(inst.args[i], primal, primal.sptypes) for i = 4:length(inst.args)] - call = CallMeta(repackage_apply_rt(info, call.rt, argts), - UnionSplitApplyCallInfo([ApplyCallInfo(call.info)])) + rt = repackage_apply_rt(info, call.rt, argts) + newinfo = UnionSplitApplyCallInfo([ApplyCallInfo(call.info)]) + @static if VERSION ≥ v"1.11.0-DEV.945" + call = CallMeta(rt, Any, Effects(), newinfo) + else + call = CallMeta(rt, Effects(), newinfo) + end end if isa(call_info, ReifyInfo) new_rt = tuple_tfunc(Any[derive_closure_type(call.info)[2]; call.rt]) - call = CallMeta(new_rt, RecurseInfo(call.info)) + newinfo = RecurseInfo(call.info) + @static if VERSION ≥ v"1.11.0-DEV.945" + call = CallMeta(new_rt, Any, Effects(), newinfo) + else + call = CallMeta(new_rt, Effects(), newinfo) + end end if call.rt === Union{} @@ -312,7 +354,11 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp accum_call = abstract_accum(interp, this_arg_typs, sv) if accum_call.rt == Union{} @show accum_call.rt - return CallMeta(Union{}, false) + @static if VERSION ≥ v"1.11.0-DEV.945" + return CallMeta(Union{}, Any, Effects(), NoCallInfo()) + else + return CallMeta(Union{}, Effects(), NoCallInfo()) + end end push!(arg_accums, accum_call) tup_push!(tup_elemns, accum_call.rt) @@ -320,7 +366,11 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp end rt = tuple_tfunc(Any[tup_elemns...]) + @static if VERSION ≥ v"1.11.0-DEV.945" + return CallMeta(rt, Any, Effects(), CompClosInfo(cc, ssa_infos)) + else return CallMeta(rt, Effects(), CompClosInfo(cc, ssa_infos)) + end end function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospecialize(cc_Δ), sv::InferenceState) @@ -389,7 +439,11 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe if isa(inst, ReturnNode) rt = accum_arg(inst.val) - return CallMeta(rt, CompClosInfo(cc, ssa_infos)) + @static if VERSION ≥ v"1.11.0-DEV.945" + return CallMeta(rt, Any, Effects(), CompClosInfo(cc, ssa_infos)) + else + return CallMeta(rt, Effects(), CompClosInfo(cc, ssa_infos)) + end end args = Any[] @@ -451,7 +505,12 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe arg = getfield_tfunc(Δ, Const(2)) call = abstract_call(interp, nothing, Any[clos, arg], sv) # No derivative wrt the functor - call = CallMeta(tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...]), ReifyInfo(call.info)) + newrt = tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...]) + @static if VERSION ≥ v"1.11.0-DEV.945" + call = CallMeta(newrt, Any, Effects(), ReifyInfo(call.info)) + else + call = CallMeta(newrt, Effects(), ReifyInfo(call.info)) + end #error() else (level, clos) = derive_closure_type(call_info) @@ -461,11 +520,20 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe if isa(call_info, ReifyInfo) new_rt = tuple_tfunc(Any[call.rt; derive_closure_type(call.info)[2]]) - call = CallMeta(new_rt, RecurseInfo()) + @static if VERSION ≥ v"1.11.0-DEV.945" + call = CallMeta(new_rt, Any, Effects(), RecurseInfo()) + else + call = CallMeta(new_rt, Effects(), RecurseInfo()) + end end if isa(info, UnionSplitApplyCallInfo) - call = CallMeta(call.rt, UnionSplitApplyCallInfo([ApplyCallInfo(call.info)])) + newinfo = UnionSplitApplyCallInfo([ApplyCallInfo(call.info)]) + @static if VERSION ≥ v"1.11.0-DEV.945" + call = CallMeta(call.rt, call.exct, Effects(), newinfo) + else + call = CallMeta(call.rt, Effects(), newinfo) + end end accums[i] = call.rt @@ -485,13 +553,16 @@ function infer_comp_closure(interp::ADInterpreter, cc::AbstractCompClosure, @nos end function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecialize(Δ), sv::InferenceState) - @show ("enter", pc) - if pc.seq == 1 call = abstract_call(change_level(interp, pc.order), nothing, Any[pc.dual, Δ], sv) rt = call.rt @show (pc, Δ, rt) - return CallMeta(call.rt, PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))) + newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)) + @static if VERSION ≥ v"1.11.0-DEV.945" + return CallMeta(call.rt, call.exct, Effects(), newinfo) + else + return CallMeta(call.rt, Effects(), newinfo) + end elseif pc.seq == 2 ni = change_level(interp, pc.order) mi′ = specialize_method(pc.info_below.results.matches[1], true) @@ -500,8 +571,12 @@ function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecializ call = infer_comp_closure(ni, cc, Δ, sv) rt = getfield_tfunc(call.rt, Const(2)) @show (pc, Δ, rt) - return CallMeta(rt, - PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried))) + newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried)) + @static if VERSION ≥ v"1.11.0-DEV.945" + return CallMeta(rt, Any, Effects(), newinfo) + else + return CallMeta(rt, Effects(), newinfo) + end elseif pc.seq == 3 ni = change_level(interp, pc.order) mi′ = specialize_method(pc.info_carried.info.results.matches[1], true) @@ -511,41 +586,62 @@ function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecializ Any[clos, tuple_tfunc(Any[Δ, pc.dual])], sv) rt = tuple_tfunc(Any[tuple_type_fields(call.rt)[2:end]...]) @show (pc, Δ, rt) - return CallMeta(rt, - PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))) + newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)) + @static if VERSION ≥ v"1.11.0-DEV.945" + return CallMeta(rt, Any, Effects(), newinfo) + else + return CallMeta(rt, Effects(), newinfo) + end elseif mod(pc.seq, 4) == 0 info = pc.info_below clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos) - # Add back gradient w.r.t. rrule Δ = tuple_tfunc(Any[NoTangent, tuple_type_fields(Δ)...]) call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, Δ], sv) rt = getfield_tfunc(call.rt, Const(1)) @show (pc, Δ, rt) - return CallMeta(rt, PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(2)), call.info, pc.info_carried))) + newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(2)), call.info, pc.info_carried)) + @static if VERSION ≥ v"1.11.0-DEV.945" + return CallMeta(rt, Any, Effects(), newinfo) + else + return CallMeta(rt, Effects(), newinfo) + end elseif mod(pc.seq, 4) == 1 info = pc.info_carried clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos) call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, tuple_tfunc(Any[pc.dual, Δ])], sv) rt = call.rt @show (pc, Δ, rt) - return CallMeta(call.rt, PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))) + newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)) + @static if VERSION ≥ v"1.11.0-DEV.945" + return CallMeta(rt, Any, Effects(), newinfo) + else + return CallMeta(rt, Effects(), newinfo) + end elseif mod(pc.seq, 4) == 2 info = pc.info_below clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos) call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, Δ], sv) rt = getfield_tfunc(call.rt, Const(2)) @show (pc, Δ, rt) - return CallMeta(rt, - PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried))) + newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried)) + @static if VERSION ≥ v"1.11.0-DEV.945" + return CallMeta(rt, Any, Effects(), newinfo) + else + return CallMeta(rt, Effects(), newinfo) + end elseif mod(pc.seq, 4) == 3 info = pc.info_carried clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos) call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, tuple_tfunc(Any[Δ, pc.dual])], sv) rt = tuple_tfunc(Any[tuple_type_fields(call.rt)[2:end]...]) @show (pc, Δ, rt) - return CallMeta(rt, - PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))) + newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)) + @static if VERSION ≥ v"1.11.0-DEV.945" + return CallMeta(rt, Any, Effects(), newinfo) + else + return CallMeta(rt, Effects(), newinfo) + end end error() end @@ -556,8 +652,7 @@ function Core.Compiler.abstract_call_opaque_closure(interp::ADInterpreter, if isa(closure.source, AbstractCompClosure) (;argtypes) = arginfo if length(argtypes) !== 2 - error() - return CallMeta(Union{}, false) + error("bad argtypes") end return infer_comp_closure(interp, closure.source, argtypes[2], sv) elseif isa(closure.source, PrimClosure) diff --git a/src/stage2/interpreter.jl b/src/stage2/interpreter.jl index e4fd5eda..42ed22e4 100644 --- a/src/stage2/interpreter.jl +++ b/src/stage2/interpreter.jl @@ -252,7 +252,8 @@ function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info::CC.Call end CC.InferenceParams(ei::ADInterpreter) = InferenceParams(ei.native_interpreter) -CC.OptimizationParams(ei::ADInterpreter) = OptimizationParams(ei.native_interpreter) +CC.OptimizationParams(ei::ADInterpreter) = OptimizationParams(ei.native_interpreter; + preserve_local_sources=true) CC.get_world_counter(ei::ADInterpreter) = get_world_counter(ei.native_interpreter) CC.get_inference_cache(ei::ADInterpreter) = get_inference_cache(ei.native_interpreter)