From 8c4c3485dc1861bd9478d1722ed7b530ef6ae31c Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Sat, 23 Mar 2024 23:06:10 +0900 Subject: [PATCH 1/3] wip: adjustments to the latest master --- src/Diffractor.jl | 6 +- src/codegen/forward.jl | 117 ----------------------------- src/codegen/reverse.jl | 45 +++++++---- src/stage1/compiler_utils.jl | 32 ++++---- src/stage1/hacks.jl | 5 +- src/stage1/recurse.jl | 28 ++++--- src/stage1/recurse_fwd.jl | 141 +++++++++++++++++++++++++++++++++++ src/stage2/interpreter.jl | 52 ++++++++++--- src/stage2/lattice.jl | 2 +- test/stage2_fwd.jl | 12 ++- 10 files changed, 261 insertions(+), 179 deletions(-) delete mode 100644 src/codegen/forward.jl diff --git a/src/Diffractor.jl b/src/Diffractor.jl index dae4e1e3..e4fef34b 100644 --- a/src/Diffractor.jl +++ b/src/Diffractor.jl @@ -1,11 +1,12 @@ module Diffractor +export ∂⃖, gradient + using StructArrays using PrecompileTools -export ∂⃖, gradient - const CC = Core.Compiler +using Core.IR @static if VERSION ≥ v"1.11.0-DEV.1498" import .CC: get_inference_world @@ -33,7 +34,6 @@ end include("stage2/tfuncs.jl") include("stage2/forward.jl") - include("codegen/forward.jl") include("analysis/forward.jl") include("codegen/forward_demand.jl") include("codegen/reverse.jl") diff --git a/src/codegen/forward.jl b/src/codegen/forward.jl deleted file mode 100644 index 9a310379..00000000 --- a/src/codegen/forward.jl +++ /dev/null @@ -1,117 +0,0 @@ -function fwd_transform(ci, args...) - newci = copy(ci) - fwd_transform!(newci, args...) - return newci -end - -function fwd_transform!(ci, mi, nargs, N) - new_code = Any[] - new_codelocs = Any[] - ssa_mapping = Int[] - loc_mapping = Int[] - - emit!(@nospecialize stmt) = stmt - function emit!(stmt::Expr) - stmt.head ∈ (:call, :(=), :new, :isdefined) || return stmt - push!(new_code, stmt) - push!(new_codelocs, isempty(new_codelocs) ? 0 : new_codelocs[end]) - return SSAValue(length(new_code)) - end - - function mapstmt!(@nospecialize stmt) - if isexpr(stmt, :(=)) - return Expr(stmt.head, emit!(mapstmt!(stmt.args[1])), emit!(mapstmt!(stmt.args[2]))) - elseif isexpr(stmt, :call) - args = map(stmt.args) do stmt - emit!(mapstmt!(stmt)) - end - return Expr(:call, ∂☆{N}(), args...) - elseif isexpr(stmt, :new) - args = map(stmt.args) do stmt - emit!(mapstmt!(stmt)) - end - return Expr(:call, ∂☆new{N}(), args...) - elseif isexpr(stmt, :splatnew) - args = map(stmt.args) do stmt - emit!(mapstmt!(stmt)) - end - return Expr(:call, Core._apply_iterate, FwdIterate(DNEBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...) - elseif isa(stmt, SSAValue) - return SSAValue(ssa_mapping[stmt.id]) - elseif isa(stmt, Core.SlotNumber) - return SlotNumber(2 + stmt.id) - elseif isa(stmt, Argument) - return SlotNumber(2 + stmt.n) - elseif isa(stmt, NewvarNode) - return NewvarNode(SlotNumber(2 + stmt.slot.id)) - elseif isa(stmt, ReturnNode) - return ReturnNode(emit!(mapstmt!(stmt.val))) - elseif isa(stmt, GotoNode) - return stmt - elseif isa(stmt, GotoIfNot) - return GotoIfNot(emit!(Expr(:call, primal, emit!(mapstmt!(stmt.cond)))), stmt.dest) - elseif isexpr(stmt, :static_parameter) - return ZeroBundle{N}(mi.sparam_vals[stmt.args[1]::Int]) - elseif isexpr(stmt, :foreigncall) - return Expr(:call, error, "Attempted to AD a foreigncall. Missing rule?") - elseif isexpr(stmt, :meta) || isexpr(stmt, :inbounds) || isexpr(stmt, :loopinfo) || - isexpr(stmt, :code_coverage_effect) - # Can't trust that meta annotations are still valid in the AD'd - # version. - return nothing - elseif isexpr(stmt, :isdefined) - return Expr(:call, zero_bundle{N}(), emit!(stmt)) - # Always disable `@inbounds`, as we don't actually know if the AD'd - # code is truly `@inbounds` or not. - elseif isexpr(stmt, :boundscheck) - return DNEBundle{N}(true) - else - # Fallback case, for literals. - # If it is an Expr, then it is not a literal - if isa(stmt, Expr) - error("Unexprected statement encountered. This is a bug in Diffractor. stmt=$stmt") - end - return Expr(:call, zero_bundle{N}(), stmt) - end - end - - meth = mi.def::Method - for i = 1:meth.nargs - if meth.isva && i == meth.nargs - args = map(i:(nargs+1)) do j::Int - emit!(Expr(:call, getfield, SlotNumber(2), j)) - end - emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, ∂vararg{N}(), args...))) - else - emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, getfield, SlotNumber(2), i))) - end - end - - for (stmt, codeloc) in zip(ci.code, ci.codelocs) - push!(loc_mapping, length(new_code)+1) - push!(new_codelocs, codeloc) - push!(new_code, mapstmt!(stmt)) - push!(ssa_mapping, length(new_code)) - end - - # Rewrite control flow - for (i, stmt) in enumerate(new_code) - if isa(stmt, GotoNode) - new_code[i] = GotoNode(loc_mapping[stmt.label]) - elseif isa(stmt, GotoIfNot) - new_code[i] = GotoIfNot(stmt.cond, loc_mapping[stmt.dest]) - end - end - - ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...] - ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...] - ci.slottypes = ci.slottypes === nothing ? nothing : Any[Any, Any, ci.slottypes...] - ci.code = new_code - ci.codelocs = new_codelocs - ci.ssavaluetypes = length(new_code) - ci.ssaflags = UInt8[0 for i=1:length(new_code)] - ci.method_for_inference_limit_heuristics = meth - ci.edges = MethodInstance[mi] - - return ci -end diff --git a/src/codegen/reverse.jl b/src/codegen/reverse.jl index f8925f9f..011db4cf 100644 --- a/src/codegen/reverse.jl +++ b/src/codegen/reverse.jl @@ -1,16 +1,20 @@ # Codegen shared by both stage1 and stage2 -function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, cis, revs...) +function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, ci, revs...) if interp !== nothing - cis.inferred = true + @static if VERSION ≥ v"1.12.0-DEV.15" + rettype = Any # ci.rettype # TODO revisit + else + ci.inferred = true + rettype = ci.rettype + end ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any), - typ, Union{}, cis.rettype, @__MODULE__, cis, lno.line, lno.file, meth_nargs, isva, ()).source - return Expr(:new_opaque_closure, typ, Union{}, Any, - ocm, revs...) + typ, Union{}, rettype, @__MODULE__, ci, lno.line, lno.file, meth_nargs, isva, ()).source + return Expr(:new_opaque_closure, typ, Union{}, Any, ocm, revs...) else oc_nargs = Int64(meth_nargs) Expr(:new_opaque_closure, typ, Union{}, Any, - Expr(:opaque_closure_method, name, oc_nargs, isva, lno, cis), revs...) + Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci), revs...) end end @@ -107,8 +111,12 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I opaque_ci.slotnames = [Symbol("#oc#"), ci.slotnames...] opaque_ci.slotflags = UInt8[0, ci.slotflags...] end - opaque_ci.linetable = Core.LineInfoNode[ci.linetable[1]] - opaque_ci.inferred = false + @static if VERSION ≥ v"1.12.0-DEV.173" + opaque_ci.debuginfo = ci.debuginfo + else + opaque_ci.linetable = Core.LineInfoNode[ci.linetable[1]] + opaque_ci.inferred = false + end opaque_ci end @@ -393,12 +401,17 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I code = opaque_ci.code = expand_switch(code, bb_ranges, slot_map) end - opaque_ci.codelocs = Int32[0 for i=1:length(code)] + @static if VERSION ≥ v"1.12.0-DEV.173" + debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code)) + debuginfo.def = :var"N/A" + opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code)) + else + opaque_ci.codelocs = Int32[0 for i=1:length(code)] + end opaque_ci.ssavaluetypes = length(code) - opaque_ci.ssaflags = UInt8[0 for i=1:length(code)] + opaque_ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(code)] end - for nc = 2:2:n_closures fwds = Any[nothing for i = 1:length(ir.stmts)] @@ -475,9 +488,15 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I end end - opaque_ci.codelocs = Int32[0 for i=1:length(code)] + @static if VERSION ≥ v"1.12.0-DEV.173" + debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code)) + debuginfo.def = :var"N/A" + opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code)) + else + opaque_ci.codelocs = Int32[0 for i=1:length(code)] + end opaque_ci.ssavaluetypes = length(code) - opaque_ci.ssaflags = UInt8[0 for i=1:length(code)] + opaque_ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(code)] end # TODO: This is absolutely aweful, but the best we can do given the data structures we have diff --git a/src/stage1/compiler_utils.jl b/src/stage1/compiler_utils.jl index dd3daf5c..52c8de3f 100644 --- a/src/stage1/compiler_utils.jl +++ b/src/stage1/compiler_utils.jl @@ -1,5 +1,5 @@ -# Utilities that should probably go into Core.Compiler -using Core.Compiler: IRCode, CFG, BasicBlock, BBIdxIter +# Utilities that should probably go into CC +using .CC: IRCode, CFG, BasicBlock, BBIdxIter function Base.push!(cfg::CFG, bb::BasicBlock) @assert cfg.blocks[end].stmts.stop+1 == bb.stmts.start @@ -8,38 +8,40 @@ function Base.push!(cfg::CFG, bb::BasicBlock) end if VERSION < v"1.11.0-DEV.258" - Base.getindex(ir::IRCode, ssa::SSAValue) = Core.Compiler.getindex(ir, ssa) + Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa) end -Base.copy(ir::IRCode) = Core.Compiler.copy(ir) +Base.copy(ir::IRCode) = CC.copy(ir) -Core.Compiler.NewInstruction(@nospecialize node) = +CC.NewInstruction(@nospecialize node) = NewInstruction(node, Any, CC.NoCallInfo(), nothing, CC.IR_FLAG_REFINED) -Base.setproperty!(x::Core.Compiler.Instruction, f::Symbol, v) = - Core.Compiler.setindex!(x, v, f) +Base.setproperty!(x::CC.Instruction, f::Symbol, v) = CC.setindex!(x, v, f) -Base.getproperty(x::Core.Compiler.Instruction, f::Symbol) = - Core.Compiler.getindex(x, f) +Base.getproperty(x::CC.Instruction, f::Symbol) = CC.getindex(x, f) function Base.setindex!(ir::IRCode, ni::NewInstruction, i::Int) stmt = ir.stmts[i] stmt.inst = ni.stmt stmt.type = ni.type stmt.flag = something(ni.flag, 0) # fixes 1.9? - stmt.line = something(ni.line, 0) + @static if VERSION ≥ v"1.12.0-DEV.173" + stmt.line = something(ni.line, CC.NoLineUpdate) + else + stmt.line = something(ni.line, 0) + end return ni end function Base.push!(ir::IRCode, ni::NewInstruction) # TODO: This should be a check in insert_node! @assert length(ir.new_nodes.stmts) == 0 - @static if isdefined(Core.Compiler, :add!) + @static if isdefined(CC, :add!) # Julia 1.7 & 1.8 - ir[Core.Compiler.add!(ir.stmts)] = ni + ir[CC.add!(ir.stmts)] = ni else # Re-named in https://github.com/JuliaLang/julia/pull/47051 - ir[Core.Compiler.add_new_idx!(ir.stmts)] = ni + ir[CC.add_new_idx!(ir.stmts)] = ni end ir end @@ -54,8 +56,8 @@ function Base.iterate(it::Iterators.Reverse{BBIdxIter}, return (bb, idx - 1), (bb, idx - 1) end -Base.lastindex(x::Core.Compiler.InstructionStream) = - Core.Compiler.length(x) +Base.lastindex(x::CC.InstructionStream) = + CC.length(x) """ find_end_of_phi_block(ir::IRCode, start_search_idx::Int) diff --git a/src/stage1/hacks.jl b/src/stage1/hacks.jl index 48ac232d..18fba64c 100644 --- a/src/stage1/hacks.jl +++ b/src/stage1/hacks.jl @@ -1,6 +1,7 @@ # Updated copy of the same code in Base, but with bugs fixed -using Core.Compiler: count_added_node!, NewSSAValue, add_pending!, - StmtRange, BasicBlock +using Core.Compiler: + NewSSAValue, OldSSAValue, StmtRange, BasicBlock, + count_added_node!, add_pending! # Re-named in https://github.com/JuliaLang/julia/pull/47051 const add! = Core.Compiler.add_inst! diff --git a/src/stage1/recurse.jl b/src/stage1/recurse.jl index 3824fe1a..9048376a 100644 --- a/src/stage1/recurse.jl +++ b/src/stage1/recurse.jl @@ -1,7 +1,6 @@ using Core.IR using Core.Compiler: - BasicBlock, CallInfo, CFG, IRCode, IncrementalCompact, Instruction, NewInstruction, - NoCallInfo, OldSSAValue, StmtRange, + BasicBlock, CFG, IRCode, IncrementalCompact, Instruction, NewInstruction, NoCallInfo, StmtRange, bbidxiter, cfg_delete_edge!, cfg_insert_edge!, compute_basic_blocks, complete, construct_domtree, construct_ssa!, domsort_ssa!, finish, insert_node!, insert_node_here!, non_dce_finish!, quoted, retrieve_code_info, @@ -255,13 +254,15 @@ function sptypes(sparams) VarState[Core.Compiler.VarState.(sparams, false)...] end -function optic_transform(ci, args...) +function optic_transform(ci::CodeInfo, args...) newci = copy(ci) optic_transform!(newci, args...) return newci end -function optic_transform!(ci, mi, nargs, N) +const SSAFlagType = @static VERSION ≥ v"1.11.0-DEV.377" ? UInt32 : UInt8 + +function optic_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int) code = ci.code sparams = mi.sparam_vals @@ -270,11 +271,20 @@ function optic_transform!(ci, mi, nargs, N) ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...] ci.slottypes = ci.slottypes === nothing ? Any[Any for _ in 1:length(ci.slotflags)] : Any[Any, Any, ci.slottypes...] + type = Any[] + info = CallInfo[NoCallInfo() for i = 1:length(code)] + flag = SSAFlagType[zero(SSAFlagType) for i = 1:length(code)] + argtypes = Any[Any for i = 1:2] meta = Expr[] - ir = IRCode(Core.Compiler.InstructionStream(code, Any[], - CallInfo[NoCallInfo() for i = 1:length(code)], - ci.codelocs, UInt8[0 for i = 1:length(code)]), cfg, Core.LineInfoNode[ci.linetable...], - Any[Any for i = 1:2], meta, sptypes(sparams)) + @static if VERSION ≥ v"1.12.0-DEV.173" + debuginfo = Core.Compiler.DebugInfoStream(mi, ci.debuginfo, length(code)) + stmts = Core.Compiler.InstructionStream(code, type, info, debuginfo.codelocs, flag) + ir = IRCode(stmts, cfg, debuginfo, argtypes, meta, sptypes(sparams)) + else + linetable = Core.LineInfoNode[ci.linetable...] + stmts = Core.Compiler.InstructionStream(code, type, info, ci.codelocs, flag) + ir = IRCode(stmts, cfg, debuginfo, argtypes, meta, sptypes(sparams)) + end # SSA conversion meth = mi.def::Method @@ -300,7 +310,7 @@ function optic_transform!(ci, mi, nargs, N) Core.Compiler.replace_code_newstyle!(ci, ir) ci.ssavaluetypes = length(ci.code) - ci.ssaflags = UInt8[0x00 for i=1:length(ci.code)] + ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(ci.code)] ci.method_for_inference_limit_heuristics = meth ci.edges = MethodInstance[mi] diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index fa8a99fe..68f3c1ff 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -73,6 +73,147 @@ function ∂☆builtin((f_bundle, args...)) throw(DomainError(f, "No `ChainRulesCore.frule` found for the built-in function `$sig`")) end +function fwd_transform(ci::CodeInfo, args...) + newci = copy(ci) + fwd_transform!(newci, args...) + return newci +end + +function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int) + new_code = Any[] + @static if VERSION ≥ v"1.12.0-DEV.173" + debuginfo = Core.Compiler.DebugInfoStream(mi, ci.debuginfo, length(ci.code)) + new_codelocs = Int32[] + else + new_codelocs = Any[] + end + ssa_mapping = Int[] + loc_mapping = Int[] + + emit!(@nospecialize stmt) = stmt + function emit!(stmt::Expr) + stmt.head ∈ (:call, :(=), :new, :isdefined) || return stmt + push!(new_code, stmt) + @static if VERSION ≥ v"1.12.0-DEV.173" + if isempty(new_codelocs) + push!(new_codelocs, 0, 0, 0) + else + append!(new_codelocs, new_codelocs[end-2:end]) + end + else + push!(new_codelocs, isempty(new_codelocs) ? 0 : new_codelocs[end]) + end + return SSAValue(length(new_code)) + end + + function mapstmt!(@nospecialize stmt) + if isexpr(stmt, :(=)) + return Expr(stmt.head, emit!(mapstmt!(stmt.args[1])), emit!(mapstmt!(stmt.args[2]))) + elseif isexpr(stmt, :call) + args = map(stmt.args) do stmt + emit!(mapstmt!(stmt)) + end + return Expr(:call, ∂☆{N}(), args...) + elseif isexpr(stmt, :new) + args = map(stmt.args) do stmt + emit!(mapstmt!(stmt)) + end + return Expr(:call, ∂☆new{N}(), args...) + elseif isexpr(stmt, :splatnew) + args = map(stmt.args) do stmt + emit!(mapstmt!(stmt)) + end + return Expr(:call, Core._apply_iterate, FwdIterate(DNEBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...) + elseif isa(stmt, SSAValue) + return SSAValue(ssa_mapping[stmt.id]) + elseif isa(stmt, Core.SlotNumber) + return SlotNumber(2 + stmt.id) + elseif isa(stmt, Argument) + return SlotNumber(2 + stmt.n) + elseif isa(stmt, NewvarNode) + return NewvarNode(SlotNumber(2 + stmt.slot.id)) + elseif isa(stmt, ReturnNode) + return ReturnNode(emit!(mapstmt!(stmt.val))) + elseif isa(stmt, GotoNode) + return stmt + elseif isa(stmt, GotoIfNot) + return GotoIfNot(emit!(Expr(:call, primal, emit!(mapstmt!(stmt.cond)))), stmt.dest) + elseif isexpr(stmt, :static_parameter) + return ZeroBundle{N}(mi.sparam_vals[stmt.args[1]::Int]) + elseif isexpr(stmt, :foreigncall) + return Expr(:call, error, "Attempted to AD a foreigncall. Missing rule?") + elseif isexpr(stmt, :meta) || isexpr(stmt, :inbounds) || isexpr(stmt, :loopinfo) || + isexpr(stmt, :code_coverage_effect) + # Can't trust that meta annotations are still valid in the AD'd + # version. + return nothing + elseif isexpr(stmt, :isdefined) + return Expr(:call, zero_bundle{N}(), emit!(stmt)) + # Always disable `@inbounds`, as we don't actually know if the AD'd + # code is truly `@inbounds` or not. + elseif isexpr(stmt, :boundscheck) + return DNEBundle{N}(true) + else + # Fallback case, for literals. + # If it is an Expr, then it is not a literal + if isa(stmt, Expr) + error("Unexprected statement encountered. This is a bug in Diffractor. stmt=$stmt") + end + return Expr(:call, zero_bundle{N}(), stmt) + end + end + + meth = mi.def::Method + for i = 1:meth.nargs + if meth.isva && i == meth.nargs + args = map(i:(nargs+1)) do j::Int + emit!(Expr(:call, getfield, SlotNumber(2), j)) + end + emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, ∂vararg{N}(), args...))) + else + emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, getfield, SlotNumber(2), i))) + end + end + + for (i, stmt) = enumerate(ci.code) + push!(loc_mapping, length(new_code)+1) + @static if VERSION ≥ v"1.12.0-DEV.173" + append!(new_codelocs, debuginfo.codelocs[3i-2:3i]) + else + push!(new_codelocs, ci.codelocs[i]) + end + push!(new_code, mapstmt!(stmt)) + push!(ssa_mapping, length(new_code)) + end + + # Rewrite control flow + for (i, stmt) in enumerate(new_code) + if isa(stmt, GotoNode) + new_code[i] = GotoNode(loc_mapping[stmt.label]) + elseif isa(stmt, GotoIfNot) + new_code[i] = GotoIfNot(stmt.cond, loc_mapping[stmt.dest]) + end + end + + ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...] + ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...] + ci.slottypes = ci.slottypes === nothing ? nothing : Any[Any, Any, ci.slottypes...] + ci.code = new_code + @static if VERSION ≥ v"1.12.0-DEV.173" + empty!(debuginfo.codelocs) + append!(debuginfo.codelocs, new_codelocs) + ci.debuginfo = Core.DebugInfo(debuginfo, length(new_code)) + else + ci.codelocs = new_codelocs + end + ci.ssavaluetypes = length(new_code) + ci.ssaflags = UInt8[0 for i=1:length(new_code)] + ci.method_for_inference_limit_heuristics = meth + ci.edges = MethodInstance[mi] + + return ci +end + function perform_fwd_transform(world::UInt, source::LineNumberNode, @nospecialize(ff::Type{∂☆recurse{N}}), @nospecialize(args)) where {N} if all(x->x <: ZeroBundle, args) diff --git a/src/stage2/interpreter.jl b/src/stage2/interpreter.jl index 8534b6cb..dd4a3711 100644 --- a/src/stage2/interpreter.jl +++ b/src/stage2/interpreter.jl @@ -260,6 +260,10 @@ CC.get_inference_cache(ei::ADInterpreter) = get_inference_cache(ei.native_interp CC.lock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing CC.unlock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing +@static if VERSION ≥ v"1.11.0-DEV.1552" +CC.cache_owner(ei::ADInterpreter) = ei.opt +end + function CC.code_cache(ei::ADInterpreter) while ei.current_level > lastindex(ei.opt) push!(ei.opt, Dict{MethodInstance, Any}()) @@ -291,21 +295,17 @@ function CC.finish(state::InferenceState, interp::ADInterpreter) return res end -function CC.transform_result_for_cache(interp::ADInterpreter, - linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult) - return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects) -end - -function CC.inlining_policy(interp::ADInterpreter, - @nospecialize(src), @nospecialize(info::CC.CallInfo), - stmt_flag::(@static VERSION ≥ v"1.11.0-DEV.377" ? UInt32 : UInt8), - mi::MethodInstance, argtypes::Vector{Any}) +const StmtFlag = @static VERSION ≥ v"1.11.0-DEV.377" ? UInt32 : UInt8 +function diffractor_inlining_policy(@nospecialize(src), @nospecialize(info::CC.CallInfo), + stmt_flag::StmtFlag) # Disallow inlining things away that have an frule if isa(info, FRuleCallInfo) return nothing end - if isa(src, CC.SemiConcreteResult) - return src + @static if VERSION < v"1.11.0-DEV.879" + if isa(src, CC.SemiConcreteResult) + return src + end end @assert isa(src, Cthulhu.OptimizedSource) || isnothing(src) if isa(src, Cthulhu.OptimizedSource) @@ -314,12 +314,40 @@ function CC.inlining_policy(interp::ADInterpreter, end return nothing end + return missing +end + +@static if VERSION ≥ v"1.12.0-DEV.45" +function CC.transform_result_for_cache(interp::ADInterpreter, + ::MethodInstance, ::WorldRange, result::InferenceResult, ::Bool) + return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects) +end +function CC.src_inlining_policy(interp::ADInterpreter, + @nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::StmtFlag) + ret = diffractor_inlining_policy(src, info, stmt_flag) + ret === nothing && return false + ret !== missing && return true + return @invoke CC.src_inlining_policy(interp::AbstractInterpreter, + src::Any, info::CC.CallInfo, stmt_flag::StmtFlag) +end +else +function CC.transform_result_for_cache(interp::ADInterpreter, + linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult) + return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects) +end +function CC.inlining_policy(interp::ADInterpreter, + @nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::StmtFlag, + mi::MethodInstance, argtypes::Vector{Any}) + ret = diffractor_inlining_policy(src, info, stmt_flag) + ret === nothing && return nothing + ret !== missing && return ret # the default inlining policy may try additional effor to find the source in a local cache return @invoke CC.inlining_policy(interp::AbstractInterpreter, nothing, info::CC.CallInfo, - stmt_flag::(@static VERSION ≥ v"1.11.0-DEV.377" ? UInt32 : UInt8), + stmt_flag::StmtFlag, mi::MethodInstance, argtypes::Vector{Any}) end +end #= function CC.optimize(interp::ADInterpreter, opt::OptimizationState, diff --git a/src/stage2/lattice.jl b/src/stage2/lattice.jl index 8683a795..663b5ffe 100644 --- a/src/stage2/lattice.jl +++ b/src/stage2/lattice.jl @@ -1,4 +1,4 @@ -using Core.Compiler: CodeInfo, CallInfo, CallMeta +using Core.Compiler: CallInfo, CallMeta import Core.Compiler: widenconst struct CompClosure; opaque; end # TODO: Is this a YAKC? diff --git a/test/stage2_fwd.jl b/test/stage2_fwd.jl index 6bc48d08..4514084e 100644 --- a/test/stage2_fwd.jl +++ b/test/stage2_fwd.jl @@ -6,8 +6,7 @@ module stage2_fwd @test sin′(1.0) == cos(1.0) end let sin′′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(mysin), Float64}, 2) - # This broke some time between 1.10 and 1.11-DEV.10001 - @test isa(sin′′, Core.OpaqueClosure{Tuple{Float64}, Float64}) broken=VERSION>=v"1.11-" + @test isa(sin′′, Core.OpaqueClosure{Tuple{Float64}, Float64}) @test sin′′(1.0) == -sin(1.0) end @@ -15,27 +14,26 @@ module stage2_fwd self_minus(a) = myminus(a, a) ChainRulesCore.@scalar_rule myminus(x, y) (true, -1) let self_minus′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}) - # This broke some time between 1.10 and 1.11-DEV.10001 - @test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64}) broken=VERSION>=v"1.11-" + @test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64}) @test self_minus′(1.0) == 0. end ChainRulesCore.@scalar_rule myminus(x, y) (true, true) # frule for `x - y` let self_minus′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}) # This broke some time between 1.10 and 1.11-DEV.10001 - @test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64}) broken=VERSION>=v"1.11-" + @test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64}) @test self_minus′(1.0) == 2. end myminus2(a, b) = a - b self_minus2(a) = myminus2(a, a) let self_minus2′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus2), Float64}) - @test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64}) broken=VERSION>=v"1.11-" + @test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64}) @test self_minus2′(1.0) == 0. end ChainRulesCore.@scalar_rule myminus2(x, y) (true, true) # frule for `x - y` let self_minus2′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus2), Float64}) # This broke some time between 1.10 and 1.11-DEV.10001 - @test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64}) broken=VERSION>=v"1.11-" + @test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64}) @test self_minus2′(1.0) == 2. end From 4d2fc4b7a3de211c52c73b8f420adff9458d3d2e Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Tue, 26 Mar 2024 19:53:18 +0900 Subject: [PATCH 2/3] fix --- src/Diffractor.jl | 2 +- src/stage1/recurse.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Diffractor.jl b/src/Diffractor.jl index e4fef34b..5cb37459 100644 --- a/src/Diffractor.jl +++ b/src/Diffractor.jl @@ -48,4 +48,4 @@ end include("AbstractDifferentiation.jl") end -end +end # module Diffractor diff --git a/src/stage1/recurse.jl b/src/stage1/recurse.jl index 9048376a..39131d80 100644 --- a/src/stage1/recurse.jl +++ b/src/stage1/recurse.jl @@ -283,7 +283,7 @@ function optic_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int) else linetable = Core.LineInfoNode[ci.linetable...] stmts = Core.Compiler.InstructionStream(code, type, info, ci.codelocs, flag) - ir = IRCode(stmts, cfg, debuginfo, argtypes, meta, sptypes(sparams)) + ir = IRCode(stmts, cfg, linetable, argtypes, meta, sptypes(sparams)) end # SSA conversion From 058aa5748c489cfb5e7f37005a12c8241140cbfb Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Wed, 27 Mar 2024 06:14:45 +0000 Subject: [PATCH 3/3] More test fixes --- src/codegen/reverse.jl | 8 ++++++-- src/higher_fwd_rules.jl | 6 +++--- src/jet.jl | 17 +++++++++++++---- src/stage1/forward.jl | 2 +- src/stage1/generated.jl | 15 +++++++++++++++ test/gradcheck.jl | 24 ++++++++++++------------ 6 files changed, 50 insertions(+), 22 deletions(-) diff --git a/src/codegen/reverse.jl b/src/codegen/reverse.jl index 011db4cf..fff17b77 100644 --- a/src/codegen/reverse.jl +++ b/src/codegen/reverse.jl @@ -8,8 +8,12 @@ function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, ci, ci.inferred = true rettype = ci.rettype end - ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any), - typ, Union{}, rettype, @__MODULE__, ci, lno.line, lno.file, meth_nargs, isva, ()).source + @static if VERSION ≥ v"1.12.0-DEV.15" + ocm = Core.OpaqueClosure(ci; rettype, nargs=meth_nargs, isva, sig=typ).source + else + ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any), + typ, Union{}, rettype, @__MODULE__, ci, lno.line, lno.file, meth_nargs, isva, ()).source + end return Expr(:new_opaque_closure, typ, Union{}, Any, ocm, revs...) else oc_nargs = Int64(meth_nargs) diff --git a/src/higher_fwd_rules.jl b/src/higher_fwd_rules.jl index 8486b8bd..9c924aeb 100644 --- a/src/higher_fwd_rules.jl +++ b/src/higher_fwd_rules.jl @@ -4,17 +4,17 @@ using Base.Iterators function njet(::Val{N}, ::typeof(sin), x₀) where {N} (s, c) = sincos(x₀) - Jet(x₀, s, tuple(take(cycle((c, -s, -c, s)), N)...)) + Jet(convert(typeof(s), x₀), s, tuple(take(cycle((c, -s, -c, s)), N)...)) end function njet(::Val{N}, ::typeof(cos), x₀) where {N} (s, c) = sincos(x₀) - Jet(x₀, s, tuple(take(cycle((-s, -c, s, c)), N)...)) + Jet(convert(typeof(s), x₀), s, tuple(take(cycle((-s, -c, s, c)), N)...)) end function njet(::Val{N}, ::typeof(exp), x₀) where {N} exped = exp(x₀) - Jet(x₀, exped, tuple(take(repeated(exped), N)...)) + Jet(convert(typeof(exped), x₀), exped, tuple(take(repeated(exped), N)...)) end jeval(j, x) = j(x) diff --git a/src/jet.jl b/src/jet.jl index 05aa3069..9e9e7fe5 100644 --- a/src/jet.jl +++ b/src/jet.jl @@ -104,7 +104,7 @@ function Base.show(io::IO, j::Jet) end function domain_check(j::Jet, x) - if j.a !== x + if j.a !== convert(typeof(j.a), x) throw(DomainError("Evaluation is only valid at a")) end end @@ -153,11 +153,17 @@ function ChainRulesCore.rrule(j::Jet, x) end function ChainRulesCore.rrule(::typeof(map), ::typeof(*), a, b) - map(*, a, b), Δ->(NoTangent(), NoTangent(), map(*, Δ, b), map(*, a, Δ)) + map(*, a, b), Δ->let Δ=unthunk(Δ) + isa(Δ, NoTangent) && return (NoTangent(), NoTangent(), NoTangent(), NoTangent()) + (NoTangent(), NoTangent(), map(*, Δ, b), map(*, a, Δ)) + end end ChainRulesCore.rrule(::typeof(map), ::typeof(integrate), js::Array{<:Jet}) = - map(integrate, js), Δ->(NoTangent(), NoTangent(), map(deriv, Δ)) + map(integrate, js), Δ->let Δ=unthunk(Δ) + isa(Δ, NoTangent) && return (NoTangent(), NoTangent(), NoTangent()) + (NoTangent(), NoTangent(), map(deriv, Δ)) + end struct derivBack js @@ -177,7 +183,10 @@ end function ChainRulesCore.rrule(::typeof(mapev), js::Array{<:Jet}, xs::AbstractArray) mapev(js, xs), let djs=map(deriv, js) - Δ->(NoTangent(), NoTangent(), map(*, unthunk(Δ), mapev(djs, xs))) + function (Δ) + isa(Δ, NoTangent) && return (NoTangent(), NoTangent(), NoTangent()) + (NoTangent(), NoTangent(), map(*, unthunk(Δ), mapev(djs, xs))) + end end end diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index f63cf51c..873adf67 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -126,7 +126,7 @@ function _frule(::NTuple{<:Any, AbstractZero}, f, primal_args...) end function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...) - bundles = map(bundle, partials, args) + bundles = map(bundle, args, partials) result = ∂☆internal{1}()(bundles...) primal(result), first_partial(result) end diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index c1624046..07f5d59b 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -12,6 +12,10 @@ function generate_lambda_ex(world::UInt, source::LineNumberNode, return stub(world, source, body) end +struct NonTransformableError + args +end + function perform_optic_transform(world::UInt, source::LineNumberNode, @nospecialize(ff::Type{∂⃖recurse{N}}), @nospecialize(args)) where {N} @assert N >= 1 @@ -28,6 +32,17 @@ function perform_optic_transform(world::UInt, source::LineNumberNode, mi = Core.Compiler.specialize_method(match) ci = Core.Compiler.retrieve_code_info(mi, world) + if ci === nothing + # Failed to retrieve source - likely a generated function that errors. + # To aid the user in debugging, run the original call in the forward pass and if that + # does not error, do our own error message afterwards. + return generate_lambda_ex(world, source, + Core.svec(:ff, :args), Core.svec(), + quote + args[1](args[2:end]...) + throw($(NonTransformableError)(args)) + end) + end return optic_transform(ci, mi, length(args)-1, N) end diff --git a/test/gradcheck.jl b/test/gradcheck.jl index d003c82d..bfada096 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -36,7 +36,7 @@ gradcheck(f, dims...) = gradcheck(f, rand.(Float64, dims)...) @test gradcheck(dot, randn(3), rand(3)) # given multiple vectors @test gradcheck(dot, 3, 3) # given multiple random vectors -jacobicheck(f, xs::AbstractArray...) = f(xs...) isa Number ? gradcheck(f, xs...) : +jacobicheck(f, xs::AbstractArray...) = f(xs...) isa Number ? gradcheck(f, xs...) : gradcheck((xs...) -> sum(sin, f(xs...)), xs...) jacobicheck(f, dims...) = jacobicheck(f, randn.(Float64, dims)...) @test jacobicheck(identity, [1,2,3]) # one given array @@ -104,26 +104,26 @@ end @test_broken jacobicheck(x -> sum(x, dims = (2, 3)), (3,4,5)) @test_broken jacobicheck(x -> sum(abs2, x; dims=1), randn(4, 3, 2)) @test_broken gradcheck(X -> sum(sum(x -> x^2, X; dims=1)), randn(10)) # issue #681 - + # Non-differentiable sum of booleans @test gradient(sum, [true, false, true]) == (NoTangent(),) @test gradient(x->sum(x .== 0.0), [1.2, 0.2, 0.0, -1.1, 100.0]) == (NoTangent(),) - + # https://github.com/FluxML/Zygote.jl/issues/314 @test gradient((x,y) -> sum(yi -> yi*x, y), 1, [1,1]) == (2, [1, 1]) @test gradient((x,y) -> prod(yi -> yi*x, y), 1, [1,1]) == (2, [1, 1]) - + # AssertionError: Base.issingletontype(typeof(f)) @test_broken gradient((x,y) -> sum(map(yi -> yi*x, y)), 1, [1,1]) == (2, [1, 1]) @test_broken gradient((x,y) -> prod(map(yi -> yi*x, y)), 1, [1,1]) == (2, [1, 1]) - + @test gradcheck(x -> prod(x), (3,4)) @test gradient(x -> prod(x), (1,2,3))[1] == (6,3,2) # MethodError: no method matching copy(::Nothing) @test_broken jacobicheck(x -> prod(x, dims = (2, 3)), (3,4,5)) end - + @testset "cumsum" begin @test jacobicheck(x -> cumsum(x), (4,)) @@ -263,7 +263,7 @@ end @testset "circshift" begin for D in 1:5 x0 = zeros(ntuple(d->5, D)) - g = gradient(x -> x[1], x0)[1] + g = gradient(x -> x[1], x0)[1] shift = ntuple(_ -> rand(-5:5), D) @test gradient(x -> circshift(x, shift)[1], x0)[1] == circshift(g, map(-, shift)) end @@ -374,12 +374,12 @@ end @test_broken gradient(x -> sum(map(first, x)), [(1,2), (3,4)]) == ([(1.0, nothing), (1.0, nothing)],) T = Tangent{Tuple{Int64, Int64}} @test gradient(x -> sum(first, x), [(1,2), (3,4)]) == (T[T(1.0, ZeroTangent()), T(1.0, ZeroTangent())],) - + @test gradient(x -> map(+, x, (1,2,3))[1], (4,5,6)) == (Tangent{Tuple{Int,Int,Int}}(1.0, ZeroTangent(), ZeroTangent()),) # MethodError: no method matching copy(::Nothing) @test_broken gradient(x -> map(+, x, [1,2,3])[1], (4,5,6)) == ((1.0, 0.0, 0.0),) @test_broken gradient(x -> map(+, x, (1,2,3))[1], [4,5,6]) == ([1,0,0],) - + # mismatched lengths, should zip # MethodError: no method matching copy(::Nothing) @test_broken gradient(x -> map(+, x, [1,2,3,99])[1], (4,5,6)) == ((1.0, 0.0, 0.0),) @@ -413,7 +413,7 @@ end end @test_broken gradient(x -> sum(map(*,x,(1,2,3))), rand(5)) == ([1,2,3,0,0],) @test_broken gradient(x -> sum(map(*,x,[1,2,3])), Tuple(rand(5))) == ((1.0, 2.0, 3.0, nothing, nothing),) - + # mixed shapes # MethodError: no method matching length(::InplaceableThunk{...}) @test_broken gradient((x,y) -> sum(map(*,x,y)), [1,2,3,4], [1 2; 3 4]) == ([1,3,2,4], [1 3; 2 4]) @@ -549,7 +549,7 @@ end catdim = (x...) -> cat(x..., dims = dim) @test_broken jacobicheck(catdim, rand(4,1)) @test_broken jacobicheck(catdim, rand(5), rand(5,1)) - @test_broken jacobicheck(catdim, rand(2,5), rand(2,5), rand(2,5)) + @test_broken jacobicheck(catdim, rand(2,5), rand(2,5), rand(2,5)) catdimval = (x...) -> cat(x...; dims = Val(dim)) @test_broken jacobicheck(catdimval, rand(5), rand(5)) @@ -620,7 +620,7 @@ end @test_broken jacobicheck(+, A, B, A) @test jacobicheck(-, A) # in typeassert, expected Int64, got a value of type Nothing - @test_broken jacobicheck(-, A, B) + @test jacobicheck(-, A, B) end end