Skip to content

Commit

Permalink
wip: adjustments to the latest master
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Mar 23, 2024
1 parent a444b7f commit 8c4c348
Show file tree
Hide file tree
Showing 10 changed files with 261 additions and 179 deletions.
6 changes: 3 additions & 3 deletions src/Diffractor.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
module Diffractor

export ∂⃖, gradient

using StructArrays
using PrecompileTools

export ∂⃖, gradient

const CC = Core.Compiler
using Core.IR

@static if VERSION v"1.11.0-DEV.1498"
import .CC: get_inference_world
Expand Down Expand Up @@ -33,7 +34,6 @@ end
include("stage2/tfuncs.jl")
include("stage2/forward.jl")

include("codegen/forward.jl")
include("analysis/forward.jl")
include("codegen/forward_demand.jl")
include("codegen/reverse.jl")
Expand Down
117 changes: 0 additions & 117 deletions src/codegen/forward.jl

This file was deleted.

45 changes: 32 additions & 13 deletions src/codegen/reverse.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
# Codegen shared by both stage1 and stage2

function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, cis, revs...)
function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, ci, revs...)
if interp !== nothing
cis.inferred = true
@static if VERSION v"1.12.0-DEV.15"
rettype = Any # ci.rettype # TODO revisit
else
ci.inferred = true
rettype = ci.rettype
end
ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
typ, Union{}, cis.rettype, @__MODULE__, cis, lno.line, lno.file, meth_nargs, isva, ()).source
return Expr(:new_opaque_closure, typ, Union{}, Any,
ocm, revs...)
typ, Union{}, rettype, @__MODULE__, ci, lno.line, lno.file, meth_nargs, isva, ()).source
return Expr(:new_opaque_closure, typ, Union{}, Any, ocm, revs...)
else
oc_nargs = Int64(meth_nargs)
Expr(:new_opaque_closure, typ, Union{}, Any,
Expr(:opaque_closure_method, name, oc_nargs, isva, lno, cis), revs...)
Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci), revs...)
end
end

Expand Down Expand Up @@ -107,8 +111,12 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
opaque_ci.slotnames = [Symbol("#oc#"), ci.slotnames...]
opaque_ci.slotflags = UInt8[0, ci.slotflags...]
end
opaque_ci.linetable = Core.LineInfoNode[ci.linetable[1]]
opaque_ci.inferred = false
@static if VERSION v"1.12.0-DEV.173"
opaque_ci.debuginfo = ci.debuginfo
else
opaque_ci.linetable = Core.LineInfoNode[ci.linetable[1]]
opaque_ci.inferred = false
end
opaque_ci
end

Expand Down Expand Up @@ -393,12 +401,17 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
code = opaque_ci.code = expand_switch(code, bb_ranges, slot_map)
end

opaque_ci.codelocs = Int32[0 for i=1:length(code)]
@static if VERSION v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
debuginfo.def = :var"N/A"
opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code))
else
opaque_ci.codelocs = Int32[0 for i=1:length(code)]
end
opaque_ci.ssavaluetypes = length(code)
opaque_ci.ssaflags = UInt8[0 for i=1:length(code)]
opaque_ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(code)]
end


for nc = 2:2:n_closures
fwds = Any[nothing for i = 1:length(ir.stmts)]

Expand Down Expand Up @@ -475,9 +488,15 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
end
end

opaque_ci.codelocs = Int32[0 for i=1:length(code)]
@static if VERSION v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
debuginfo.def = :var"N/A"
opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code))
else
opaque_ci.codelocs = Int32[0 for i=1:length(code)]
end
opaque_ci.ssavaluetypes = length(code)
opaque_ci.ssaflags = UInt8[0 for i=1:length(code)]
opaque_ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(code)]
end

# TODO: This is absolutely aweful, but the best we can do given the data structures we have
Expand Down
32 changes: 17 additions & 15 deletions src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Utilities that should probably go into Core.Compiler
using Core.Compiler: IRCode, CFG, BasicBlock, BBIdxIter
# Utilities that should probably go into CC
using .CC: IRCode, CFG, BasicBlock, BBIdxIter

function Base.push!(cfg::CFG, bb::BasicBlock)
@assert cfg.blocks[end].stmts.stop+1 == bb.stmts.start
Expand All @@ -8,38 +8,40 @@ function Base.push!(cfg::CFG, bb::BasicBlock)
end

if VERSION < v"1.11.0-DEV.258"
Base.getindex(ir::IRCode, ssa::SSAValue) = Core.Compiler.getindex(ir, ssa)
Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa)
end

Base.copy(ir::IRCode) = Core.Compiler.copy(ir)
Base.copy(ir::IRCode) = CC.copy(ir)

Core.Compiler.NewInstruction(@nospecialize node) =
CC.NewInstruction(@nospecialize node) =
NewInstruction(node, Any, CC.NoCallInfo(), nothing, CC.IR_FLAG_REFINED)

Base.setproperty!(x::Core.Compiler.Instruction, f::Symbol, v) =
Core.Compiler.setindex!(x, v, f)
Base.setproperty!(x::CC.Instruction, f::Symbol, v) = CC.setindex!(x, v, f)

Base.getproperty(x::Core.Compiler.Instruction, f::Symbol) =
Core.Compiler.getindex(x, f)
Base.getproperty(x::CC.Instruction, f::Symbol) = CC.getindex(x, f)

function Base.setindex!(ir::IRCode, ni::NewInstruction, i::Int)
stmt = ir.stmts[i]
stmt.inst = ni.stmt
stmt.type = ni.type
stmt.flag = something(ni.flag, 0) # fixes 1.9?
stmt.line = something(ni.line, 0)
@static if VERSION v"1.12.0-DEV.173"
stmt.line = something(ni.line, CC.NoLineUpdate)
else
stmt.line = something(ni.line, 0)
end
return ni
end

function Base.push!(ir::IRCode, ni::NewInstruction)
# TODO: This should be a check in insert_node!
@assert length(ir.new_nodes.stmts) == 0
@static if isdefined(Core.Compiler, :add!)
@static if isdefined(CC, :add!)
# Julia 1.7 & 1.8
ir[Core.Compiler.add!(ir.stmts)] = ni
ir[CC.add!(ir.stmts)] = ni
else
# Re-named in https://github.com/JuliaLang/julia/pull/47051
ir[Core.Compiler.add_new_idx!(ir.stmts)] = ni
ir[CC.add_new_idx!(ir.stmts)] = ni
end
ir
end
Expand All @@ -54,8 +56,8 @@ function Base.iterate(it::Iterators.Reverse{BBIdxIter},
return (bb, idx - 1), (bb, idx - 1)
end

Base.lastindex(x::Core.Compiler.InstructionStream) =
Core.Compiler.length(x)
Base.lastindex(x::CC.InstructionStream) =
CC.length(x)

"""
find_end_of_phi_block(ir::IRCode, start_search_idx::Int)
Expand Down
5 changes: 3 additions & 2 deletions src/stage1/hacks.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Updated copy of the same code in Base, but with bugs fixed
using Core.Compiler: count_added_node!, NewSSAValue, add_pending!,
StmtRange, BasicBlock
using Core.Compiler:
NewSSAValue, OldSSAValue, StmtRange, BasicBlock,
count_added_node!, add_pending!

# Re-named in https://github.com/JuliaLang/julia/pull/47051
const add! = Core.Compiler.add_inst!
Expand Down
28 changes: 19 additions & 9 deletions src/stage1/recurse.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using Core.IR
using Core.Compiler:
BasicBlock, CallInfo, CFG, IRCode, IncrementalCompact, Instruction, NewInstruction,
NoCallInfo, OldSSAValue, StmtRange,
BasicBlock, CFG, IRCode, IncrementalCompact, Instruction, NewInstruction, NoCallInfo, StmtRange,
bbidxiter, cfg_delete_edge!, cfg_insert_edge!, compute_basic_blocks, complete,
construct_domtree, construct_ssa!, domsort_ssa!, finish, insert_node!,
insert_node_here!, non_dce_finish!, quoted, retrieve_code_info,
Expand Down Expand Up @@ -255,13 +254,15 @@ function sptypes(sparams)
VarState[Core.Compiler.VarState.(sparams, false)...]
end

function optic_transform(ci, args...)
function optic_transform(ci::CodeInfo, args...)
newci = copy(ci)
optic_transform!(newci, args...)
return newci
end

function optic_transform!(ci, mi, nargs, N)
const SSAFlagType = @static VERSION v"1.11.0-DEV.377" ? UInt32 : UInt8

function optic_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int)
code = ci.code
sparams = mi.sparam_vals

Expand All @@ -270,11 +271,20 @@ function optic_transform!(ci, mi, nargs, N)
ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...]
ci.slottypes = ci.slottypes === nothing ? Any[Any for _ in 1:length(ci.slotflags)] : Any[Any, Any, ci.slottypes...]

type = Any[]
info = CallInfo[NoCallInfo() for i = 1:length(code)]
flag = SSAFlagType[zero(SSAFlagType) for i = 1:length(code)]
argtypes = Any[Any for i = 1:2]
meta = Expr[]
ir = IRCode(Core.Compiler.InstructionStream(code, Any[],
CallInfo[NoCallInfo() for i = 1:length(code)],
ci.codelocs, UInt8[0 for i = 1:length(code)]), cfg, Core.LineInfoNode[ci.linetable...],
Any[Any for i = 1:2], meta, sptypes(sparams))
@static if VERSION v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(mi, ci.debuginfo, length(code))
stmts = Core.Compiler.InstructionStream(code, type, info, debuginfo.codelocs, flag)
ir = IRCode(stmts, cfg, debuginfo, argtypes, meta, sptypes(sparams))
else
linetable = Core.LineInfoNode[ci.linetable...]
stmts = Core.Compiler.InstructionStream(code, type, info, ci.codelocs, flag)
ir = IRCode(stmts, cfg, debuginfo, argtypes, meta, sptypes(sparams))
end

# SSA conversion
meth = mi.def::Method
Expand All @@ -300,7 +310,7 @@ function optic_transform!(ci, mi, nargs, N)
Core.Compiler.replace_code_newstyle!(ci, ir)

ci.ssavaluetypes = length(ci.code)
ci.ssaflags = UInt8[0x00 for i=1:length(ci.code)]
ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(ci.code)]
ci.method_for_inference_limit_heuristics = meth
ci.edges = MethodInstance[mi]

Expand Down
Loading

0 comments on commit 8c4c348

Please sign in to comment.