From 9883872b4dbe39224138f30cfeef739fd1477a2b Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 18 Mar 2024 23:01:07 +0900 Subject: [PATCH] wip: adjustments to the latest master --- src/stage2/interpreter.jl | 56 ++++++++++++++++++++++++++++++--------- test/runtests.jl | 2 +- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/src/stage2/interpreter.jl b/src/stage2/interpreter.jl index 8534b6cb..de09698f 100644 --- a/src/stage2/interpreter.jl +++ b/src/stage2/interpreter.jl @@ -260,6 +260,10 @@ CC.get_inference_cache(ei::ADInterpreter) = get_inference_cache(ei.native_interp CC.lock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing CC.unlock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing +@static if VERSION ≥ v"1.11.0-DEV.1552" +CC.cache_owner(ei::ADInterpreter) = ei.opt +end + function CC.code_cache(ei::ADInterpreter) while ei.current_level > lastindex(ei.opt) push!(ei.opt, Dict{MethodInstance, Any}()) @@ -291,21 +295,17 @@ function CC.finish(state::InferenceState, interp::ADInterpreter) return res end -function CC.transform_result_for_cache(interp::ADInterpreter, - linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult) - return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects) -end - -function CC.inlining_policy(interp::ADInterpreter, - @nospecialize(src), @nospecialize(info::CC.CallInfo), - stmt_flag::(@static VERSION ≥ v"1.11.0-DEV.377" ? UInt32 : UInt8), - mi::MethodInstance, argtypes::Vector{Any}) +const StmtFlag = @static VERSION ≥ v"1.11.0-DEV.377" ? UInt32 : UInt8 +function diffractor_inlining_policy(@nospecialize(src), @nospecialize(info::CC.CallInfo), + stmt_flag::StmtFlag) # Disallow inlining things away that have an frule if isa(info, FRuleCallInfo) return nothing end - if isa(src, CC.SemiConcreteResult) - return src + @static if VERSION < v"1.11.0-DEV.879" + if isa(src, CC.SemiConcreteResult) + return src + end end @assert isa(src, Cthulhu.OptimizedSource) || isnothing(src) if isa(src, Cthulhu.OptimizedSource) @@ -314,12 +314,44 @@ function CC.inlining_policy(interp::ADInterpreter, end return nothing end + return missing +end + +@static if VERSION ≥ v"1.12.0-DEV.45" +function CC.transform_result_for_cache(interp::ADInterpreter, + ::MethodInstance, ::WorldRange, result::InferenceResult, ::Bool) + return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects) +end +function CC.src_inlining_policy(interp::ADInterpreter, + @nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::StmtFlag) + ret = diffractor_inlining_policy(src, info, stmt_flag) + ret === nothing && return false + ret !== missing && return true + return CC.src_inlining_policy(interp::AbstractInterpreter, + src::Any, info::CC.CallInfo, stmt_flag::StmtFlag) +end +CC.retrieve_ir_for_inlining(cached_result::CodeInstance, src::Cthulhu.OptimizedSource) = + CC.retrieve_ir_for_inlining(cached_result.def, src.ir, true) +CC.retrieve_ir_for_inlining(mi::MethodInstance, src::Cthulhu.OptimizedSource, preserve_local_sources::Bool) = + CC.retrieve_ir_for_inlining(mi, src.ir, preserve_local_sources) +else +function CC.transform_result_for_cache(interp::ADInterpreter, + linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult) + return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects) +end +function CC.inlining_policy(interp::ADInterpreter, + @nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::StmtFlag, + mi::MethodInstance, argtypes::Vector{Any}) + ret = diffractor_inlining_policy(src, info, stmt_flag) + ret === nothing && return nothing + ret !== missing && return ret # the default inlining policy may try additional effor to find the source in a local cache return @invoke CC.inlining_policy(interp::AbstractInterpreter, nothing, info::CC.CallInfo, - stmt_flag::(@static VERSION ≥ v"1.11.0-DEV.377" ? UInt32 : UInt8), + stmt_flag::StmtFlag, mi::MethodInstance, argtypes::Vector{Any}) end +end #= function CC.optimize(interp::ADInterpreter, opt::OptimizationState, diff --git a/test/runtests.jl b/test/runtests.jl index 01cbc825..7df9f8a7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,7 +14,7 @@ const bwd = Diffractor.PrimeDerivativeBack @testset verbose=true "Diffractor.jl" begin # overall testset, ensures all tests run @testset "$file" for file in ( - "extra_rules.jl" + "extra_rules.jl", "stage2_fwd.jl", "tangent.jl", "forward_diff_no_inf.jl",