Skip to content

Commit

Permalink
Demonstrate backtraces using juliajit
Browse files Browse the repository at this point in the history
  • Loading branch information
gbaraldi committed May 22, 2023
1 parent fcba321 commit 1a136e2
Showing 1 changed file with 51 additions and 49 deletions.
100 changes: 51 additions & 49 deletions src/compiler/orcv2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import ..Compiler: API, cpu_name, cpu_features
export get_trampoline

struct CompilerInstance
jit::LLVM.LLJIT
jit::LLVM.JuliaOJIT
lctm::Union{LLVM.LazyCallThroughManager, Nothing}
ism::Union{LLVM.IndirectStubsManager, Nothing}
end
Expand Down Expand Up @@ -60,57 +60,59 @@ function __init__()
optlevel = LLVM.API.LLVMCodeGenLevelAggressive
end

tempTM = LLVM.JITTargetMachine(LLVM.triple(), cpu_name(), cpu_features(); optlevel)
tempTM = LLVM.JITTargetMachine(LLVM.triple(), cpu_name(), cpu_features(); optlevel)
LLVM.asm_verbosity!(tempTM, true)
tm[] = tempTM

tempTM = LLVM.JITTargetMachine(LLVM.triple(), cpu_name(), cpu_features(); optlevel)
LLVM.asm_verbosity!(tempTM, true)

if haskey(ENV, "ENABLE_GDBLISTENER")
ollc = LLVM.ObjectLinkingLayerCreator() do es, triple
oll = ObjectLinkingLayer(es)
register!(oll, GDBRegistrationListener())
return oll
end

GC.@preserve ollc begin
builder = LLJITBuilder()
LLVM.linkinglayercreator!(builder, ollc)
tmb = TargetMachineBuilder(tempTM)
LLVM.targetmachinebuilder!(builder, tmb)
lljit = LLJIT(builder)
end
else
lljit = LLJIT(;tm=tempTM)
end

jd_main = JITDylib(lljit)

prefix = LLVM.get_prefix(lljit)
# if haskey(ENV, "ENABLE_GDBLISTENER")
# ollc = LLVM.ObjectLinkingLayerCreator() do es, triple
# oll = ObjectLinkingLayer(es)
# register!(oll, GDBRegistrationListener())
# return oll
# end

# GC.@preserve ollc begin
# builder = LLJITBuilder()
# LLVM.linkinglayercreator!(builder, ollc)
# tmb = TargetMachineBuilder(tempTM)
# LLVM.targetmachinebuilder!(builder, tmb)
# lljit = LLJIT(builder)
# end
# else
# lljit = LLJIT(;tm=tempTM)

# end
jljit = JuliaOJIT()

jd_main = JITDylib(jljit)

prefix = LLVM.get_prefix(jljit)
dg = LLVM.CreateDynamicLibrarySearchGeneratorForProcess(prefix)
LLVM.add!(jd_main, dg)

if Sys.iswindows() && Int === Int64
# TODO can we check isGNU?
define_absolute_symbol(jd_main, mangle(lljit, "___chkstk_ms"))
define_absolute_symbol(jd_main, mangle(jljit, "___chkstk_ms"))
end

es = ExecutionSession(lljit)
es = ExecutionSession(jljit)
try
lctm = LLVM.LocalLazyCallThroughManager(triple(lljit), es)
ism = LLVM.LocalIndirectStubsManager(triple(lljit))
jit[] = CompilerInstance(lljit, lctm, ism)
lctm = LLVM.LocalLazyCallThroughManager(triple(jljit), es)
ism = LLVM.LocalIndirectStubsManager(triple(jljit))
jit[] = CompilerInstance(jljit, lctm, ism)
catch err
@warn "OrcV2 initialization failed with" err
jit[] = CompilerInstance(lljit, nothing, nothing)
jit[] = CompilerInstance(jljit, nothing, nothing)
end

atexit() do
ci = jit[]
dispose(ci)
dispose(tm[])
end
# atexit() do
# ci = jit[]
# dispose(ci)
# dispose(tm[])
# end
end

function move_to_threadsafe(ir)
Expand All @@ -126,25 +128,25 @@ function move_to_threadsafe(ir)
end
end

function add_trampoline!(jd, (lljit, lctm, ism), entry, target)
function add_trampoline!(jd, (jljit, lctm, ism), entry, target)
flags = LLVM.API.LLVMJITSymbolFlags(
LLVM.API.LLVMJITSymbolGenericFlagsCallable |
LLVM.API.LLVMJITSymbolGenericFlagsExported, 0)

alias = LLVM.API.LLVMOrcCSymbolAliasMapPair(
mangle(lljit, entry),
mangle(jljit, entry),
LLVM.API.LLVMOrcCSymbolAliasMapEntry(
mangle(lljit, target), flags))
mangle(jljit, target), flags))

mu = LLVM.reexports(lctm, ism, jd, [alias])
LLVM.define(jd, mu)
LLVM.lookup(lljit, entry)

LLVM.lookup(jljit, entry)
end

function get_trampoline(job)
compiler = jit[]
lljit = compiler.jit
jljit = compiler.jit
lctm = compiler.lctm
ism = compiler.ism

Expand All @@ -156,17 +158,17 @@ function get_trampoline(job)
needs_augmented_primal = mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient

# We could also use one dylib per job
jd = JITDylib(lljit)
jd = JITDylib(jljit)

adjoint_sym = String(gensym(:adjoint))
_adjoint_sym = String(gensym(:adjoint))
adjoint_addr = add_trampoline!(jd, (lljit, lctm, ism),
adjoint_addr = add_trampoline!(jd, (jljit, lctm, ism),
_adjoint_sym, adjoint_sym)

if needs_augmented_primal
primal_sym = String(gensym(:augmented_primal))
_primal_sym = String(gensym(:augmented_primal))
primal_addr = add_trampoline!(jd, (lljit, lctm, ism),
primal_addr = add_trampoline!(jd, (jljit, lctm, ism),
_primal_sym, primal_sym)
else
primal_sym = nothing
Expand All @@ -193,7 +195,7 @@ function get_trampoline(job)
end

tsm = move_to_threadsafe(mod)
il = LLVM.IRTransformLayer(lljit)
il = LLVM.IRCompileLayer(jljit)
LLVM.emit(il, mr, tsm)

return nothing
Expand All @@ -207,11 +209,11 @@ function get_trampoline(job)

symbols = [
LLVM.API.LLVMOrcCSymbolFlagsMapPair(
mangle(lljit, adjoint_sym), flags),
mangle(jljit, adjoint_sym), flags),
]
if needs_augmented_primal
push!(symbols, LLVM.API.LLVMOrcCSymbolFlagsMapPair(
mangle(lljit, primal_sym), flags),)
mangle(jljit, primal_sym), flags),)
end

mu = LLVM.CustomMaterializationUnit(adjoint_sym, symbols,
Expand All @@ -221,10 +223,10 @@ function get_trampoline(job)
end

function add!(mod)
lljit = jit[].jit
jd = LLVM.JITDylib(lljit)
jljit = jit[].jit
jd = LLVM.JITDylib(jljit)
tsm = move_to_threadsafe(mod)
LLVM.add!(lljit, jd, tsm)
LLVM.add!(jljit, jd, tsm)
return nothing
end

Expand Down

0 comments on commit 1a136e2

Please sign in to comment.