Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Nov 23, 2023
1 parent 146ae3d commit 18c705e
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 49 deletions.
4 changes: 4 additions & 0 deletions src/analysis/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize
frule_call = CC.abstract_call_gf_by_type(interp′,
ChainRulesCore.frule, frule_arginfo, frule_si, frule_atype, sv, #=max_methods=#-1)
if frule_call.rt !== Const(nothing)
@static if VERSION v"1.11.0-DEV.945"

Check warning on line 29 in src/analysis/forward.jl

View check run for this annotation

Codecov / codecov/patch

src/analysis/forward.jl#L29

Added line #L29 was not covered by tests
return CallMeta(primal_call.rt, primal_call.exct, primal_call.effects, FRuleCallInfo(primal_call.info, frule_call))
else
return CallMeta(primal_call.rt, primal_call.effects, FRuleCallInfo(primal_call.info, frule_call))
end
else
CC.add_mt_backedge!(sv, frule_mt, frule_atype)
end
Expand Down
7 changes: 0 additions & 7 deletions src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,6 @@ end
Base.lastindex(x::Core.Compiler.InstructionStream) =
Core.Compiler.length(x)

# Solves an error after https://github.com/JuliaLang/julia/pull/46961
# as does https://github.com/FluxML/IRTools.jl/pull/101
if isdefined(Core.Compiler, :CallInfo)
Base.convert(::Type{Core.Compiler.CallInfo}, ::Nothing) = Core.Compiler.NoCallInfo()
end


"""
find_end_of_phi_block(ir::IRCode, start_search_idx::Int)
Expand Down
8 changes: 4 additions & 4 deletions src/stage1/recurse.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Core.IR
using Core.Compiler:
Argument, BasicBlock, CFG, CodeInfo, GotoIfNot, GotoNode, IRCode, IncrementalCompact,
Instruction, MethodInstance, NewInstruction, NewvarNode, OldSSAValue, PhiNode,
ReturnNode, SSAValue, SlotNumber, StmtRange,
BasicBlock, CallInfo, CFG, IRCode, IncrementalCompact, Instruction, NewInstruction,
NoCallInfo, OldSSAValue, StmtRange,
bbidxiter, cfg_delete_edge!, cfg_insert_edge!, compute_basic_blocks, complete,
construct_domtree, construct_ssa!, domsort_ssa!, finish, insert_node!,
insert_node_here!, effect_free_and_nothrow, non_dce_finish!, quoted, retrieve_code_info,
Expand Down Expand Up @@ -266,7 +266,7 @@ function optic_transform!(ci, mi, nargs, N)

meta = Expr[]
ir = IRCode(Core.Compiler.InstructionStream(code, Any[],
Any[nothing for i = 1:length(code)],
CallInfo[NoCallInfo() for i = 1:length(code)],
ci.codelocs, UInt8[0 for i = 1:length(code)]), cfg, Core.LineInfoNode[ci.linetable...],
Any[Any for i = 1:2], meta, sptypes(sparams))

Expand Down
171 changes: 133 additions & 38 deletions src/stage2/abstractinterpret.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using .CC: Const, isconstType, argtypes_to_type, tuple_tfunc, Const,
getfield_tfunc, _methods_by_ftype, VarTable, nfields_tfunc,
ArgInfo, singleton_type, CallMeta, MethodMatchInfo, specialize_method,
PartialOpaque, UnionSplitApplyCallInfo, typeof_tfunc, apply_type_tfunc, instanceof_tfunc,
StmtInfo
StmtInfo, NoCallInfo
using Core: PartialStruct
using Base.Meta

Expand Down Expand Up @@ -41,7 +41,11 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
else
rt2 = obtype
end
@static if VERSION v"1.11.0-DEV.945"

Check warning on line 44 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L44

Added line #L44 was not covered by tests
return CallMeta(rt2, call.exct, call.effects, RecurseInfo(call.info))
else
return CallMeta(rt2, call.effects, RecurseInfo(call.info))
end
end

# Check if there is a rrule for this function
Expand All @@ -56,7 +60,12 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
end
call = abstract_call_gf_by_type(lower_level(interp), ChainRules.rrule, ArgInfo(nothing, rrule_argtypes), rrule_atype, sv, -1)
if call.rt != Const(nothing)
return CallMeta(getfield_tfunc(call.rt, Const(1)), call.effects, RRuleInfo(call.rt, call.info))
newrt = getfield_tfunc(call.rt, Const(1))
@static if VERSION v"1.11.0-DEV.945"

Check warning on line 64 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L63-L64

Added lines #L63 - L64 were not covered by tests
return CallMeta(newrt, call.exct, call.effects, RRuleInfo(call.rt, call.info))
else
return CallMeta(newrt, call.exct, call.effects, RRuleInfo(call.rt, call.info))

Check warning on line 67 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L67

Added line #L67 was not covered by tests
end
end
end
end
Expand All @@ -74,26 +83,39 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
return ret
end

function abstract_accum(interp::AbstractInterpreter, args::Vector{Any}, sv::InferenceState)
args = filter(x->!(widenconst(x) <: Union{ZeroTangent, NoTangent}), args)
function abstract_accum(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState)
argtypes = filter(@nospecialize(x)->!(widenconst(x) <: Union{ZeroTangent, NoTangent}), argtypes)

Check warning on line 87 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L86-L87

Added lines #L86 - L87 were not covered by tests

if length(args) == 0
return CallMeta(ZeroTangent, Effects(), nothing)
if length(argtypes) == 0
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(ZeroTangent, Any, Effects(), NoCallInfo())

Check warning on line 91 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L89-L91

Added lines #L89 - L91 were not covered by tests
else
return CallMeta(ZeroTangent, Effects(), NoCallInfo())

Check warning on line 93 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L93

Added line #L93 was not covered by tests
end
end

if length(args) == 1
return CallMeta(args[1], Effects(), nothing)
if length(argtypes) == 1
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(argtypes[1], Any, Effects(), NoCallInfo())

Check warning on line 99 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L97-L99

Added lines #L97 - L99 were not covered by tests
else
return CallMeta(argtypes[1], Effects(), NoCallInfo())

Check warning on line 101 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L101

Added line #L101 was not covered by tests
end
end

rtype = reduce(tmerge, args)
rtype = reduce(tmerge, argtypes)

Check warning on line 105 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L105

Added line #L105 was not covered by tests
if widenconst(rtype) <: Tuple
targs = Any[]
for i = 1:nfields_tfunc(rtype).val
push!(targs, abstract_accum(interp, Any[getfield_tfunc(arg, Const(i)) for arg in args], sv).rt)
push!(targs, abstract_accum(interp, Any[getfield_tfunc(arg, Const(i)) for arg in argtypes], sv).rt)
end
rt = tuple_tfunc(targs)
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), NoCallInfo())

Check warning on line 113 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L109-L113

Added lines #L109 - L113 were not covered by tests
else
return CallMeta(rt, Effects(), NoCallInfo())

Check warning on line 115 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L115

Added line #L115 was not covered by tests
end
return CallMeta(tuple_tfunc(targs), nothing)
end
call = abstract_call(change_level(interp, 0), nothing, Any[typeof(accum), args...],
call = abstract_call(change_level(interp, 0), nothing, Any[typeof(accum), argtypes...],

Check warning on line 118 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L118

Added line #L118 was not covered by tests
sv::InferenceState)
return call
end
Expand Down Expand Up @@ -249,7 +271,12 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
ft = argextype(inst.args[1], primal, primal.sptypes)
f = singleton_type(ft)
if isa(f, Core.Builtin)
call = CallMeta(backwards_tfunc(f, primal, inst, Δ), nothing)
rt = backwards_tfunc(f, primal, inst, Δ)
@static if VERSION v"1.11.0-DEV.945"
call = CallMeta(rt, Any, Effects(), NoCallInfo())

Check warning on line 276 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L274-L276

Added lines #L274 - L276 were not covered by tests
else
call = CallMeta(rt, Effects(), NoCallInfo())

Check warning on line 278 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L278

Added line #L278 was not covered by tests
end
else
bail!(inst)
continue
Expand All @@ -265,7 +292,12 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
arg = getfield_tfunc(Δ, Const(1))
call = abstract_call(interp, nothing, Any[clos, arg], sv)
# No derivative wrt the functor
call = CallMeta(tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...]), ReifyInfo(call.info))
rt = tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...])
@static if VERSION v"1.11.0-DEV.945"
call = CallMeta(rt, Any, Effects(), ReifyInfo(call.info))

Check warning on line 297 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L295-L297

Added lines #L295 - L297 were not covered by tests
else
call = CallMeta(rt, Effects(), ReifyInfo(call.info))

Check warning on line 299 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L299

Added line #L299 was not covered by tests
end
else
(level, close) = derive_closure_type(call_info)
call = abstract_call(change_level(interp, level), ArgInfo(nothing, Any[close, Δ]), sv)
Expand All @@ -274,13 +306,23 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp

if isa(info, UnionSplitApplyCallInfo)
argts = Any[argextype(inst.args[i], primal, primal.sptypes) for i = 4:length(inst.args)]
call = CallMeta(repackage_apply_rt(info, call.rt, argts),
UnionSplitApplyCallInfo([ApplyCallInfo(call.info)]))
rt = repackage_apply_rt(info, call.rt, argts)
newinfo = UnionSplitApplyCallInfo([ApplyCallInfo(call.info)])
@static if VERSION v"1.11.0-DEV.945"
call = CallMeta(rt, Any, Effects(), newinfo)

Check warning on line 312 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L309-L312

Added lines #L309 - L312 were not covered by tests
else
call = CallMeta(rt, Effects(), newinfo)

Check warning on line 314 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L314

Added line #L314 was not covered by tests
end
end

if isa(call_info, ReifyInfo)
new_rt = tuple_tfunc(Any[derive_closure_type(call.info)[2]; call.rt])
call = CallMeta(new_rt, RecurseInfo(call.info))
newinfo = RecurseInfo(call.info)
@static if VERSION v"1.11.0-DEV.945"
call = CallMeta(new_rt, Any, Effects(), newinfo)

Check warning on line 322 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L320-L322

Added lines #L320 - L322 were not covered by tests
else
call = CallMeta(new_rt, Effects(), newinfo)

Check warning on line 324 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L324

Added line #L324 was not covered by tests
end
end

if call.rt === Union{}
Expand Down Expand Up @@ -312,15 +354,23 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
accum_call = abstract_accum(interp, this_arg_typs, sv)
if accum_call.rt == Union{}
@show accum_call.rt
return CallMeta(Union{}, false)
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(Union{}, Any, Effects(), NoCallInfo())

Check warning on line 358 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L357-L358

Added lines #L357 - L358 were not covered by tests
else
return CallMeta(Union{}, Effects(), NoCallInfo())

Check warning on line 360 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L360

Added line #L360 was not covered by tests
end
end
push!(arg_accums, accum_call)
tup_push!(tup_elemns, accum_call.rt)
end
end

rt = tuple_tfunc(Any[tup_elemns...])
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), CompClosInfo(cc, ssa_infos))

Check warning on line 370 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L369-L370

Added lines #L369 - L370 were not covered by tests
else
return CallMeta(rt, Effects(), CompClosInfo(cc, ssa_infos))
end
end

function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospecialize(cc_Δ), sv::InferenceState)
Expand Down Expand Up @@ -389,7 +439,11 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe

if isa(inst, ReturnNode)
rt = accum_arg(inst.val)
return CallMeta(rt, CompClosInfo(cc, ssa_infos))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), CompClosInfo(cc, ssa_infos))

Check warning on line 443 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L442-L443

Added lines #L442 - L443 were not covered by tests
else
return CallMeta(rt, Effects(), CompClosInfo(cc, ssa_infos))

Check warning on line 445 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L445

Added line #L445 was not covered by tests
end
end

args = Any[]
Expand Down Expand Up @@ -451,7 +505,12 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe
arg = getfield_tfunc(Δ, Const(2))
call = abstract_call(interp, nothing, Any[clos, arg], sv)
# No derivative wrt the functor
call = CallMeta(tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...]), ReifyInfo(call.info))
newrt = tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...])
@static if VERSION v"1.11.0-DEV.945"
call = CallMeta(newrt, Any, Effects(), ReifyInfo(call.info))

Check warning on line 510 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L508-L510

Added lines #L508 - L510 were not covered by tests
else
call = CallMeta(newrt, Effects(), ReifyInfo(call.info))

Check warning on line 512 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L512

Added line #L512 was not covered by tests
end
#error()
else
(level, clos) = derive_closure_type(call_info)
Expand All @@ -461,11 +520,20 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe

if isa(call_info, ReifyInfo)
new_rt = tuple_tfunc(Any[call.rt; derive_closure_type(call.info)[2]])
call = CallMeta(new_rt, RecurseInfo())
@static if VERSION v"1.11.0-DEV.945"
call = CallMeta(new_rt, Any, Effects(), RecurseInfo())

Check warning on line 524 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L523-L524

Added lines #L523 - L524 were not covered by tests
else
call = CallMeta(new_rt, Effects(), RecurseInfo())

Check warning on line 526 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L526

Added line #L526 was not covered by tests
end
end

if isa(info, UnionSplitApplyCallInfo)
call = CallMeta(call.rt, UnionSplitApplyCallInfo([ApplyCallInfo(call.info)]))
newinfo = UnionSplitApplyCallInfo([ApplyCallInfo(call.info)])
@static if VERSION v"1.11.0-DEV.945"
call = CallMeta(call.rt, call.exct, Effects(), newinfo)

Check warning on line 533 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L531-L533

Added lines #L531 - L533 were not covered by tests
else
call = CallMeta(call.rt, Effects(), newinfo)

Check warning on line 535 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L535

Added line #L535 was not covered by tests
end
end

accums[i] = call.rt
Expand All @@ -485,13 +553,16 @@ function infer_comp_closure(interp::ADInterpreter, cc::AbstractCompClosure, @nos
end

function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecialize(Δ), sv::InferenceState)
@show ("enter", pc)

if pc.seq == 1
call = abstract_call(change_level(interp, pc.order), nothing, Any[pc.dual, Δ], sv)
rt = call.rt
@show (pc, Δ, rt)
return CallMeta(call.rt, PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(call.rt, call.exct, Effects(), newinfo)

Check warning on line 562 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L560-L562

Added lines #L560 - L562 were not covered by tests
else
return CallMeta(call.rt, Effects(), newinfo)

Check warning on line 564 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L564

Added line #L564 was not covered by tests
end
elseif pc.seq == 2
ni = change_level(interp, pc.order)
mi′ = specialize_method(pc.info_below.results.matches[1], true)
Expand All @@ -500,8 +571,12 @@ function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecializ
call = infer_comp_closure(ni, cc, Δ, sv)
rt = getfield_tfunc(call.rt, Const(2))
@show (pc, Δ, rt)
return CallMeta(rt,
PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)

Check warning on line 576 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L574-L576

Added lines #L574 - L576 were not covered by tests
else
return CallMeta(rt, Effects(), newinfo)

Check warning on line 578 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L578

Added line #L578 was not covered by tests
end
elseif pc.seq == 3
ni = change_level(interp, pc.order)
mi′ = specialize_method(pc.info_carried.info.results.matches[1], true)
Expand All @@ -511,41 +586,62 @@ function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecializ
Any[clos, tuple_tfunc(Any[Δ, pc.dual])], sv)
rt = tuple_tfunc(Any[tuple_type_fields(call.rt)[2:end]...])
@show (pc, Δ, rt)
return CallMeta(rt,
PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)

Check warning on line 591 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L589-L591

Added lines #L589 - L591 were not covered by tests
else
return CallMeta(rt, Effects(), newinfo)

Check warning on line 593 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L593

Added line #L593 was not covered by tests
end
elseif mod(pc.seq, 4) == 0
info = pc.info_below
clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos)

# Add back gradient w.r.t. rrule
Δ = tuple_tfunc(Any[NoTangent, tuple_type_fields(Δ)...])
call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, Δ], sv)
rt = getfield_tfunc(call.rt, Const(1))
@show (pc, Δ, rt)
return CallMeta(rt, PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(2)), call.info, pc.info_carried)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(2)), call.info, pc.info_carried))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)

Check warning on line 605 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L603-L605

Added lines #L603 - L605 were not covered by tests
else
return CallMeta(rt, Effects(), newinfo)

Check warning on line 607 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L607

Added line #L607 was not covered by tests
end
elseif mod(pc.seq, 4) == 1
info = pc.info_carried
clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos)
call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, tuple_tfunc(Any[pc.dual, Δ])], sv)
rt = call.rt
@show (pc, Δ, rt)
return CallMeta(call.rt, PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)

Check warning on line 617 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L615-L617

Added lines #L615 - L617 were not covered by tests
else
return CallMeta(rt, Effects(), newinfo)

Check warning on line 619 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L619

Added line #L619 was not covered by tests
end
elseif mod(pc.seq, 4) == 2
info = pc.info_below
clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos)
call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, Δ], sv)
rt = getfield_tfunc(call.rt, Const(2))
@show (pc, Δ, rt)
return CallMeta(rt,
PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)

Check warning on line 629 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L627-L629

Added lines #L627 - L629 were not covered by tests
else
return CallMeta(rt, Effects(), newinfo)

Check warning on line 631 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L631

Added line #L631 was not covered by tests
end
elseif mod(pc.seq, 4) == 3
info = pc.info_carried
clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos)
call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, tuple_tfunc(Any[Δ, pc.dual])], sv)
rt = tuple_tfunc(Any[tuple_type_fields(call.rt)[2:end]...])
@show (pc, Δ, rt)
return CallMeta(rt,
PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))
@static if VERSION v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)

Check warning on line 641 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L639-L641

Added lines #L639 - L641 were not covered by tests
else
return CallMeta(rt, Effects(), newinfo)

Check warning on line 643 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L643

Added line #L643 was not covered by tests
end
end
error()
end
Expand All @@ -556,8 +652,7 @@ function CC.abstract_call_opaque_closure(interp::ADInterpreter,
if isa(closure.source, AbstractCompClosure)
(;argtypes) = arginfo
if length(argtypes) !== 2
error()
return CallMeta(Union{}, false)
error("bad argtypes")

Check warning on line 655 in src/stage2/abstractinterpret.jl

View check run for this annotation

Codecov / codecov/patch

src/stage2/abstractinterpret.jl#L655

Added line #L655 was not covered by tests
end
return infer_comp_closure(interp, closure.source, argtypes[2], sv)
elseif isa(closure.source, PrimClosure)
Expand Down

0 comments on commit 18c705e

Please sign in to comment.