Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adjustments to the latest master #284

Merged
merged 3 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/Diffractor.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
module Diffractor

export ∂⃖, gradient

using StructArrays
using PrecompileTools

export ∂⃖, gradient

const CC = Core.Compiler
using Core.IR

@static if VERSION ≥ v"1.11.0-DEV.1498"
import .CC: get_inference_world
Expand Down Expand Up @@ -33,7 +34,6 @@ end
include("stage2/tfuncs.jl")
include("stage2/forward.jl")

include("codegen/forward.jl")
include("analysis/forward.jl")
include("codegen/forward_demand.jl")
include("codegen/reverse.jl")
Expand All @@ -48,4 +48,4 @@ end
include("AbstractDifferentiation.jl")
end

end
end # module Diffractor
117 changes: 0 additions & 117 deletions src/codegen/forward.jl

This file was deleted.

45 changes: 32 additions & 13 deletions src/codegen/reverse.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
# Codegen shared by both stage1 and stage2

function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, cis, revs...)
function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, ci, revs...)
if interp !== nothing
cis.inferred = true
@static if VERSION v"1.12.0-DEV.15"
rettype = Any # ci.rettype # TODO revisit
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? is inference broken?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ci no longer has rettype.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably need further tweak on this code, but maybe later.

else
ci.inferred = true
rettype = ci.rettype
end
ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The signature of this c function changed

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That said, this should ideally use the Core.OpaqueClosure constructor

typ, Union{}, cis.rettype, @__MODULE__, cis, lno.line, lno.file, meth_nargs, isva, ()).source
return Expr(:new_opaque_closure, typ, Union{}, Any,
ocm, revs...)
typ, Union{}, rettype, @__MODULE__, ci, lno.line, lno.file, meth_nargs, isva, ()).source
return Expr(:new_opaque_closure, typ, Union{}, Any, ocm, revs...)
else
oc_nargs = Int64(meth_nargs)
Expr(:new_opaque_closure, typ, Union{}, Any,
Expr(:opaque_closure_method, name, oc_nargs, isva, lno, cis), revs...)
Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci), revs...)
end
end

Expand Down Expand Up @@ -107,8 +111,12 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
opaque_ci.slotnames = [Symbol("#oc#"), ci.slotnames...]
opaque_ci.slotflags = UInt8[0, ci.slotflags...]
end
opaque_ci.linetable = Core.LineInfoNode[ci.linetable[1]]
opaque_ci.inferred = false
@static if VERSION v"1.12.0-DEV.173"
opaque_ci.debuginfo = ci.debuginfo
else
opaque_ci.linetable = Core.LineInfoNode[ci.linetable[1]]
opaque_ci.inferred = false
end
opaque_ci
end

Expand Down Expand Up @@ -393,12 +401,17 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
code = opaque_ci.code = expand_switch(code, bb_ranges, slot_map)
end

opaque_ci.codelocs = Int32[0 for i=1:length(code)]
@static if VERSION v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
debuginfo.def = :var"N/A"
opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code))
else
opaque_ci.codelocs = Int32[0 for i=1:length(code)]
end
opaque_ci.ssavaluetypes = length(code)
opaque_ci.ssaflags = UInt8[0 for i=1:length(code)]
opaque_ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(code)]
end


for nc = 2:2:n_closures
fwds = Any[nothing for i = 1:length(ir.stmts)]

Expand Down Expand Up @@ -475,9 +488,15 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
end
end

opaque_ci.codelocs = Int32[0 for i=1:length(code)]
@static if VERSION v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
debuginfo.def = :var"N/A"
opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code))
else
opaque_ci.codelocs = Int32[0 for i=1:length(code)]
end
opaque_ci.ssavaluetypes = length(code)
opaque_ci.ssaflags = UInt8[0 for i=1:length(code)]
opaque_ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(code)]
end

# TODO: This is absolutely aweful, but the best we can do given the data structures we have
Expand Down
32 changes: 17 additions & 15 deletions src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Utilities that should probably go into Core.Compiler
using Core.Compiler: IRCode, CFG, BasicBlock, BBIdxIter
# Utilities that should probably go into CC
using .CC: IRCode, CFG, BasicBlock, BBIdxIter

function Base.push!(cfg::CFG, bb::BasicBlock)
@assert cfg.blocks[end].stmts.stop+1 == bb.stmts.start
Expand All @@ -8,38 +8,40 @@ function Base.push!(cfg::CFG, bb::BasicBlock)
end

if VERSION < v"1.11.0-DEV.258"
Base.getindex(ir::IRCode, ssa::SSAValue) = Core.Compiler.getindex(ir, ssa)
Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa)
end

Base.copy(ir::IRCode) = Core.Compiler.copy(ir)
Base.copy(ir::IRCode) = CC.copy(ir)

Core.Compiler.NewInstruction(@nospecialize node) =
CC.NewInstruction(@nospecialize node) =
NewInstruction(node, Any, CC.NoCallInfo(), nothing, CC.IR_FLAG_REFINED)

Base.setproperty!(x::Core.Compiler.Instruction, f::Symbol, v) =
Core.Compiler.setindex!(x, v, f)
Base.setproperty!(x::CC.Instruction, f::Symbol, v) = CC.setindex!(x, v, f)

Base.getproperty(x::Core.Compiler.Instruction, f::Symbol) =
Core.Compiler.getindex(x, f)
Base.getproperty(x::CC.Instruction, f::Symbol) = CC.getindex(x, f)

function Base.setindex!(ir::IRCode, ni::NewInstruction, i::Int)
stmt = ir.stmts[i]
stmt.inst = ni.stmt
stmt.type = ni.type
stmt.flag = something(ni.flag, 0) # fixes 1.9?
stmt.line = something(ni.line, 0)
@static if VERSION v"1.12.0-DEV.173"
stmt.line = something(ni.line, CC.NoLineUpdate)
else
stmt.line = something(ni.line, 0)
end
return ni
end

function Base.push!(ir::IRCode, ni::NewInstruction)
# TODO: This should be a check in insert_node!
@assert length(ir.new_nodes.stmts) == 0
@static if isdefined(Core.Compiler, :add!)
@static if isdefined(CC, :add!)
# Julia 1.7 & 1.8
ir[Core.Compiler.add!(ir.stmts)] = ni
ir[CC.add!(ir.stmts)] = ni
else
# Re-named in https://github.com/JuliaLang/julia/pull/47051
ir[Core.Compiler.add_new_idx!(ir.stmts)] = ni
ir[CC.add_new_idx!(ir.stmts)] = ni
end
ir
end
Expand All @@ -54,8 +56,8 @@ function Base.iterate(it::Iterators.Reverse{BBIdxIter},
return (bb, idx - 1), (bb, idx - 1)
end

Base.lastindex(x::Core.Compiler.InstructionStream) =
Core.Compiler.length(x)
Base.lastindex(x::CC.InstructionStream) =
CC.length(x)

"""
find_end_of_phi_block(ir::IRCode, start_search_idx::Int)
Expand Down
5 changes: 3 additions & 2 deletions src/stage1/hacks.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Updated copy of the same code in Base, but with bugs fixed
using Core.Compiler: count_added_node!, NewSSAValue, add_pending!,
StmtRange, BasicBlock
using Core.Compiler:
NewSSAValue, OldSSAValue, StmtRange, BasicBlock,
count_added_node!, add_pending!

# Re-named in https://github.com/JuliaLang/julia/pull/47051
const add! = Core.Compiler.add_inst!
Expand Down
28 changes: 19 additions & 9 deletions src/stage1/recurse.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using Core.IR
using Core.Compiler:
BasicBlock, CallInfo, CFG, IRCode, IncrementalCompact, Instruction, NewInstruction,
NoCallInfo, OldSSAValue, StmtRange,
BasicBlock, CFG, IRCode, IncrementalCompact, Instruction, NewInstruction, NoCallInfo, StmtRange,
bbidxiter, cfg_delete_edge!, cfg_insert_edge!, compute_basic_blocks, complete,
construct_domtree, construct_ssa!, domsort_ssa!, finish, insert_node!,
insert_node_here!, non_dce_finish!, quoted, retrieve_code_info,
Expand Down Expand Up @@ -255,13 +254,15 @@ function sptypes(sparams)
VarState[Core.Compiler.VarState.(sparams, false)...]
end

function optic_transform(ci, args...)
function optic_transform(ci::CodeInfo, args...)
newci = copy(ci)
optic_transform!(newci, args...)
return newci
end

function optic_transform!(ci, mi, nargs, N)
const SSAFlagType = @static VERSION ≥ v"1.11.0-DEV.377" ? UInt32 : UInt8

function optic_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int)
code = ci.code
sparams = mi.sparam_vals

Expand All @@ -270,11 +271,20 @@ function optic_transform!(ci, mi, nargs, N)
ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...]
ci.slottypes = ci.slottypes === nothing ? Any[Any for _ in 1:length(ci.slotflags)] : Any[Any, Any, ci.slottypes...]

type = Any[]
info = CallInfo[NoCallInfo() for i = 1:length(code)]
flag = SSAFlagType[zero(SSAFlagType) for i = 1:length(code)]
argtypes = Any[Any for i = 1:2]
meta = Expr[]
ir = IRCode(Core.Compiler.InstructionStream(code, Any[],
CallInfo[NoCallInfo() for i = 1:length(code)],
ci.codelocs, UInt8[0 for i = 1:length(code)]), cfg, Core.LineInfoNode[ci.linetable...],
Any[Any for i = 1:2], meta, sptypes(sparams))
@static if VERSION ≥ v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(mi, ci.debuginfo, length(code))
stmts = Core.Compiler.InstructionStream(code, type, info, debuginfo.codelocs, flag)
ir = IRCode(stmts, cfg, debuginfo, argtypes, meta, sptypes(sparams))
else
linetable = Core.LineInfoNode[ci.linetable...]
stmts = Core.Compiler.InstructionStream(code, type, info, ci.codelocs, flag)
ir = IRCode(stmts, cfg, linetable, argtypes, meta, sptypes(sparams))
end

# SSA conversion
meth = mi.def::Method
Expand All @@ -300,7 +310,7 @@ function optic_transform!(ci, mi, nargs, N)
Core.Compiler.replace_code_newstyle!(ci, ir)

ci.ssavaluetypes = length(ci.code)
ci.ssaflags = UInt8[0x00 for i=1:length(ci.code)]
ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(ci.code)]
ci.method_for_inference_limit_heuristics = meth
ci.edges = MethodInstance[mi]

Expand Down
Loading
Loading