Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adjustments to the latest master #284

Merged
merged 3 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/Diffractor.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
module Diffractor

export ∂⃖, gradient

using StructArrays
using PrecompileTools

export ∂⃖, gradient

const CC = Core.Compiler
using Core.IR

@static if VERSION ≥ v"1.11.0-DEV.1498"
import .CC: get_inference_world
Expand Down Expand Up @@ -33,7 +34,6 @@ end
include("stage2/tfuncs.jl")
include("stage2/forward.jl")

include("codegen/forward.jl")
include("analysis/forward.jl")
include("codegen/forward_demand.jl")
include("codegen/reverse.jl")
Expand All @@ -48,4 +48,4 @@ end
include("AbstractDifferentiation.jl")
end

end
end # module Diffractor
117 changes: 0 additions & 117 deletions src/codegen/forward.jl

This file was deleted.

51 changes: 37 additions & 14 deletions src/codegen/reverse.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
# Codegen shared by both stage1 and stage2

function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, cis, revs...)
function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, ci, revs...)
if interp !== nothing
cis.inferred = true
ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
typ, Union{}, cis.rettype, @__MODULE__, cis, lno.line, lno.file, meth_nargs, isva, ()).source
return Expr(:new_opaque_closure, typ, Union{}, Any,
ocm, revs...)
@static if VERSION ≥ v"1.12.0-DEV.15"

Check warning on line 5 in src/codegen/reverse.jl

View check run for this annotation

Codecov / codecov/patch

src/codegen/reverse.jl#L5

Added line #L5 was not covered by tests
rettype = Any # ci.rettype # TODO revisit
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? is inference broken?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ci no longer has rettype.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably need further tweak on this code, but maybe later.

else
ci.inferred = true
rettype = ci.rettype

Check warning on line 9 in src/codegen/reverse.jl

View check run for this annotation

Codecov / codecov/patch

src/codegen/reverse.jl#L8-L9

Added lines #L8 - L9 were not covered by tests
end
@static if VERSION ≥ v"1.12.0-DEV.15"

Check warning on line 11 in src/codegen/reverse.jl

View check run for this annotation

Codecov / codecov/patch

src/codegen/reverse.jl#L11

Added line #L11 was not covered by tests
ocm = Core.OpaqueClosure(ci; rettype, nargs=meth_nargs, isva, sig=typ).source
else
ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),

Check warning on line 14 in src/codegen/reverse.jl

View check run for this annotation

Codecov / codecov/patch

src/codegen/reverse.jl#L14

Added line #L14 was not covered by tests
typ, Union{}, rettype, @__MODULE__, ci, lno.line, lno.file, meth_nargs, isva, ()).source
end
return Expr(:new_opaque_closure, typ, Union{}, Any, ocm, revs...)

Check warning on line 17 in src/codegen/reverse.jl

View check run for this annotation

Codecov / codecov/patch

src/codegen/reverse.jl#L17

Added line #L17 was not covered by tests
else
oc_nargs = Int64(meth_nargs)
Expr(:new_opaque_closure, typ, Union{}, Any,
Expr(:opaque_closure_method, name, oc_nargs, isva, lno, cis), revs...)
Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci), revs...)
end
end

Expand Down Expand Up @@ -107,8 +115,12 @@
opaque_ci.slotnames = [Symbol("#oc#"), ci.slotnames...]
opaque_ci.slotflags = UInt8[0, ci.slotflags...]
end
opaque_ci.linetable = Core.LineInfoNode[ci.linetable[1]]
opaque_ci.inferred = false
@static if VERSION ≥ v"1.12.0-DEV.173"

Check warning on line 118 in src/codegen/reverse.jl

View check run for this annotation

Codecov / codecov/patch

src/codegen/reverse.jl#L118

Added line #L118 was not covered by tests
opaque_ci.debuginfo = ci.debuginfo
else
opaque_ci.linetable = Core.LineInfoNode[ci.linetable[1]]
opaque_ci.inferred = false
end
opaque_ci
end

Expand Down Expand Up @@ -393,12 +405,17 @@
code = opaque_ci.code = expand_switch(code, bb_ranges, slot_map)
end

opaque_ci.codelocs = Int32[0 for i=1:length(code)]
@static if VERSION ≥ v"1.12.0-DEV.173"

Check warning on line 408 in src/codegen/reverse.jl

View check run for this annotation

Codecov / codecov/patch

src/codegen/reverse.jl#L408

Added line #L408 was not covered by tests
debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
debuginfo.def = :var"N/A"
opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code))
else
opaque_ci.codelocs = Int32[0 for i=1:length(code)]
end
opaque_ci.ssavaluetypes = length(code)
opaque_ci.ssaflags = UInt8[0 for i=1:length(code)]
opaque_ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(code)]
end


for nc = 2:2:n_closures
fwds = Any[nothing for i = 1:length(ir.stmts)]

Expand Down Expand Up @@ -475,9 +492,15 @@
end
end

opaque_ci.codelocs = Int32[0 for i=1:length(code)]
@static if VERSION ≥ v"1.12.0-DEV.173"

Check warning on line 495 in src/codegen/reverse.jl

View check run for this annotation

Codecov / codecov/patch

src/codegen/reverse.jl#L495

Added line #L495 was not covered by tests
debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
debuginfo.def = :var"N/A"
opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code))
else
opaque_ci.codelocs = Int32[0 for i=1:length(code)]
end
opaque_ci.ssavaluetypes = length(code)
opaque_ci.ssaflags = UInt8[0 for i=1:length(code)]
opaque_ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(code)]
end

# TODO: This is absolutely aweful, but the best we can do given the data structures we have
Expand Down
6 changes: 3 additions & 3 deletions src/higher_fwd_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@ using Base.Iterators

function njet(::Val{N}, ::typeof(sin), x₀) where {N}
(s, c) = sincos(x₀)
Jet(x₀, s, tuple(take(cycle((c, -s, -c, s)), N)...))
Jet(convert(typeof(s), x₀), s, tuple(take(cycle((c, -s, -c, s)), N)...))
end

function njet(::Val{N}, ::typeof(cos), x₀) where {N}
(s, c) = sincos(x₀)
Jet(x₀, s, tuple(take(cycle((-s, -c, s, c)), N)...))
Jet(convert(typeof(s), x₀), s, tuple(take(cycle((-s, -c, s, c)), N)...))
end

function njet(::Val{N}, ::typeof(exp), x₀) where {N}
exped = exp(x₀)
Jet(x₀, exped, tuple(take(repeated(exped), N)...))
Jet(convert(typeof(exped), x₀), exped, tuple(take(repeated(exped), N)...))
end

jeval(j, x) = j(x)
Expand Down
17 changes: 13 additions & 4 deletions src/jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
end

function domain_check(j::Jet, x)
if j.a !== x
if j.a !== convert(typeof(j.a), x)
throw(DomainError("Evaluation is only valid at a"))
end
end
Expand Down Expand Up @@ -153,11 +153,17 @@
end

function ChainRulesCore.rrule(::typeof(map), ::typeof(*), a, b)
map(*, a, b), Δ->(NoTangent(), NoTangent(), map(*, Δ, b), map(*, a, Δ))
map(*, a, b), Δ->let Δ=unthunk(Δ)
isa(Δ, NoTangent) && return (NoTangent(), NoTangent(), NoTangent(), NoTangent())
(NoTangent(), NoTangent(), map(*, Δ, b), map(*, a, Δ))
end
end

ChainRulesCore.rrule(::typeof(map), ::typeof(integrate), js::Array{<:Jet}) =
map(integrate, js), Δ->(NoTangent(), NoTangent(), map(deriv, Δ))
map(integrate, js), Δ->let Δ=unthunk(Δ)
isa(Δ, NoTangent) && return (NoTangent(), NoTangent(), NoTangent())
(NoTangent(), NoTangent(), map(deriv, Δ))

Check warning on line 165 in src/jet.jl

View check run for this annotation

Codecov / codecov/patch

src/jet.jl#L163-L165

Added lines #L163 - L165 were not covered by tests
end

struct derivBack
js
Expand All @@ -177,7 +183,10 @@

function ChainRulesCore.rrule(::typeof(mapev), js::Array{<:Jet}, xs::AbstractArray)
mapev(js, xs), let djs=map(deriv, js)
Δ->(NoTangent(), NoTangent(), map(*, unthunk(Δ), mapev(djs, xs)))
function (Δ)
isa(Δ, NoTangent) && return (NoTangent(), NoTangent(), NoTangent())
(NoTangent(), NoTangent(), map(*, unthunk(Δ), mapev(djs, xs)))
end
end
end

Expand Down
32 changes: 17 additions & 15 deletions src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Utilities that should probably go into Core.Compiler
using Core.Compiler: IRCode, CFG, BasicBlock, BBIdxIter
# Utilities that should probably go into CC
using .CC: IRCode, CFG, BasicBlock, BBIdxIter

function Base.push!(cfg::CFG, bb::BasicBlock)
@assert cfg.blocks[end].stmts.stop+1 == bb.stmts.start
Expand All @@ -8,38 +8,40 @@
end

if VERSION < v"1.11.0-DEV.258"
Base.getindex(ir::IRCode, ssa::SSAValue) = Core.Compiler.getindex(ir, ssa)
Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa)
end

Base.copy(ir::IRCode) = Core.Compiler.copy(ir)
Base.copy(ir::IRCode) = CC.copy(ir)

Core.Compiler.NewInstruction(@nospecialize node) =
CC.NewInstruction(@nospecialize node) =
NewInstruction(node, Any, CC.NoCallInfo(), nothing, CC.IR_FLAG_REFINED)

Base.setproperty!(x::Core.Compiler.Instruction, f::Symbol, v) =
Core.Compiler.setindex!(x, v, f)
Base.setproperty!(x::CC.Instruction, f::Symbol, v) = CC.setindex!(x, v, f)

Base.getproperty(x::Core.Compiler.Instruction, f::Symbol) =
Core.Compiler.getindex(x, f)
Base.getproperty(x::CC.Instruction, f::Symbol) = CC.getindex(x, f)

function Base.setindex!(ir::IRCode, ni::NewInstruction, i::Int)
stmt = ir.stmts[i]
stmt.inst = ni.stmt
stmt.type = ni.type
stmt.flag = something(ni.flag, 0) # fixes 1.9?
stmt.line = something(ni.line, 0)
@static if VERSION ≥ v"1.12.0-DEV.173"

Check warning on line 28 in src/stage1/compiler_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/stage1/compiler_utils.jl#L28

Added line #L28 was not covered by tests
stmt.line = something(ni.line, CC.NoLineUpdate)
else
stmt.line = something(ni.line, 0)
end
return ni
end

function Base.push!(ir::IRCode, ni::NewInstruction)
# TODO: This should be a check in insert_node!
@assert length(ir.new_nodes.stmts) == 0
@static if isdefined(Core.Compiler, :add!)
@static if isdefined(CC, :add!)

Check warning on line 39 in src/stage1/compiler_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/stage1/compiler_utils.jl#L39

Added line #L39 was not covered by tests
# Julia 1.7 & 1.8
ir[Core.Compiler.add!(ir.stmts)] = ni
ir[CC.add!(ir.stmts)] = ni
else
# Re-named in https://github.com/JuliaLang/julia/pull/47051
ir[Core.Compiler.add_new_idx!(ir.stmts)] = ni
ir[CC.add_new_idx!(ir.stmts)] = ni
end
ir
end
Expand All @@ -54,8 +56,8 @@
return (bb, idx - 1), (bb, idx - 1)
end

Base.lastindex(x::Core.Compiler.InstructionStream) =
Core.Compiler.length(x)
Base.lastindex(x::CC.InstructionStream) =
CC.length(x)

"""
find_end_of_phi_block(ir::IRCode, start_search_idx::Int)
Expand Down
2 changes: 1 addition & 1 deletion src/stage1/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ function _frule(::NTuple{<:Any, AbstractZero}, f, primal_args...)
end

function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
bundles = map(bundle, partials, args)
bundles = map(bundle, args, partials)
result = ∂☆internal{1}()(bundles...)
primal(result), first_partial(result)
end
Expand Down
Loading
Loading