Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose the Julia JIT with a C API #49858

Merged
merged 5 commits into from
May 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ CODEGEN_SRCS := codegen jitlayers aotcompile debuginfo disasm llvm-simdloop llvm
llvm-final-gc-lowering llvm-pass-helpers llvm-late-gc-lowering llvm-ptls \
llvm-lower-handlers llvm-gc-invariant-verifier llvm-propagate-addrspaces \
llvm-multiversioning llvm-alloc-opt llvm-alloc-helpers cgmemmgr llvm-remove-addrspaces \
llvm-remove-ni llvm-julia-licm llvm-demote-float16 llvm-cpufeatures pipeline
llvm-remove-ni llvm-julia-licm llvm-demote-float16 llvm-cpufeatures pipeline llvm_api
FLAGS += -I$(shell $(LLVM_CONFIG_HOST) --includedir)
CG_LLVM_LIBS := all
ifeq ($(USE_POLLY),1)
Expand Down
23 changes: 23 additions & 0 deletions src/codegen-stubs.c
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,26 @@ JL_DLLEXPORT void LLVMExtraAddGCInvariantVerifierPass_fallback(void *PM, bool_t
JL_DLLEXPORT void LLVMExtraAddDemoteFloat16Pass_fallback(void *PM) UNAVAILABLE

JL_DLLEXPORT void LLVMExtraAddCPUFeaturesPass_fallback(void *PM) UNAVAILABLE

//LLVM C api to the julia JIT
JL_DLLEXPORT void* JLJITGetLLVMOrcExecutionSession_fallback(void* JIT) UNAVAILABLE
vchuravy marked this conversation as resolved.
Show resolved Hide resolved

JL_DLLEXPORT void* JLJITGetJuliaOJIT_fallback(void) UNAVAILABLE

JL_DLLEXPORT void* JLJITGetExternalJITDylib_fallback(void* JIT) UNAVAILABLE

JL_DLLEXPORT void* JLJITAddObjectFile_fallback(void* JIT, void* JD, void* ObjBuffer) UNAVAILABLE

JL_DLLEXPORT void* JLJITAddLLVMIRModule_fallback(void* JIT, void* JD, void* TSM) UNAVAILABLE

JL_DLLEXPORT void* JLJITLookup_fallback(void* JIT, void* Result, const char *Name) UNAVAILABLE

JL_DLLEXPORT void* JLJITMangleAndIntern_fallback(void* JIT, const char *Name) UNAVAILABLE

JL_DLLEXPORT const char *JLJITGetTripleString_fallback(void* JIT) UNAVAILABLE

JL_DLLEXPORT const char JLJITGetGlobalPrefix_fallback(void* JIT) UNAVAILABLE

JL_DLLEXPORT const char *JLJITGetDataLayoutString_fallback(void* JIT) UNAVAILABLE

JL_DLLEXPORT void* JLJITGetIRCompileLayer_fallback(void* JIT) UNAVAILABLE
44 changes: 41 additions & 3 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,7 @@ JuliaOJIT::JuliaOJIT()
ES(cantFail(orc::SelfExecutorProcessControl::Create())),
GlobalJD(ES.createBareJITDylib("JuliaGlobals")),
JD(ES.createBareJITDylib("JuliaOJIT")),
ExternalJD(ES.createBareJITDylib("JuliaExternal")),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Long-term we might want many ExternalJDs and was considering a design where we had one JD per world in Julia.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can add as many JDs as we want with the ExecutionSession. The difference is that this one is linked in by default.

ContextPool([](){
auto ctx = std::make_unique<LLVMContext>();
return orc::ThreadSafeContext(std::move(ctx));
Expand All @@ -1313,7 +1314,9 @@ JuliaOJIT::JuliaOJIT()
std::make_unique<PipelineT>(LockLayer, *TM, 2, PrintLLVMTimers),
std::make_unique<PipelineT>(LockLayer, *TM, 3, PrintLLVMTimers),
},
OptSelLayer(Pipelines)
OptSelLayer(Pipelines),
ExternalCompileLayer(ES, LockLayer,
std::make_unique<CompilerT>(orc::irManglingOptionsFromTargetOptions(TM->Options), *TM, 2))
{
#ifdef JL_USE_JITLINK
# if defined(LLVM_SHLIB)
Expand Down Expand Up @@ -1385,6 +1388,9 @@ JuliaOJIT::JuliaOJIT()
}

JD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
JD.addToLinkOrder(ExternalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
ExternalJD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
ExternalJD.addToLinkOrder(JD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should not be any way for ExternalJD to access JD or vice versa. Neither of those should be accidentally breaking the content of the other.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this one is kind of wrong, but we do want to keep the ability to llvmcall a function from an external JD to avoid doing the trampoline/function pointer dance we do right now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems likely to go very badly for users. The llvmcall itself right now should not be able to access JD state either, but that is a different bug.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would it go bad for the users? (llvmcall currently aborts if it doesn't find a symbol but that's separate)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JD is an implementation detail and not one that is stable, or one that we think is good right now. We could fix the llvmcall crash by checking that it does not contain declarations (except of intrinsic things) and otherwise reject it when we parsed it. That would do the right thing.


#if JULIA_FLOAT16_ABI == 1
orc::SymbolAliasMap jl_crt = {
Expand Down Expand Up @@ -1484,10 +1490,34 @@ void JuliaOJIT::addModule(orc::ThreadSafeModule TSM)
}
}

Error JuliaOJIT::addExternalModule(orc::JITDylib &JD, orc::ThreadSafeModule TSM, bool ShouldOptimize)
{
if (auto Err = TSM.withModuleDo([&](Module &M) JL_NOTSAFEPOINT -> Error
{
if (M.getDataLayout().isDefault())
M.setDataLayout(DL);
if (M.getDataLayout() != DL)
return make_error<StringError>(
"Added modules have incompatible data layouts: " +
M.getDataLayout().getStringRepresentation() + " (module) vs " +
DL.getStringRepresentation() + " (jit)",
inconvertibleErrorCode());

return Error::success();
}))
return Err;
return ExternalCompileLayer.add(JD.getDefaultResourceTracker(), std::move(TSM));
}

Error JuliaOJIT::addObjectFile(orc::JITDylib &JD, std::unique_ptr<MemoryBuffer> Obj) {
assert(Obj && "Can not add null object");
return LockLayer.add(JD.getDefaultResourceTracker(), std::move(Obj));
}

JL_JITSymbol JuliaOJIT::findSymbol(StringRef Name, bool ExportedSymbolsOnly)
{
orc::JITDylib* SearchOrders[2] = {&JD, &GlobalJD};
ArrayRef<orc::JITDylib*> SearchOrder = makeArrayRef(&SearchOrders[0], ExportedSymbolsOnly ? 2 : 1);
orc::JITDylib* SearchOrders[3] = {&JD, &GlobalJD, &ExternalJD};
ArrayRef<orc::JITDylib*> SearchOrder = makeArrayRef(&SearchOrders[0], ExportedSymbolsOnly ? 3 : 1);
Comment on lines +1519 to +1520
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change should be reverted?

auto Sym = ES.lookup(SearchOrder, Name);
if (Sym)
return *Sym;
Expand All @@ -1499,6 +1529,14 @@ JL_JITSymbol JuliaOJIT::findUnmangledSymbol(StringRef Name)
return findSymbol(getMangledName(Name), true);
}

Expected<JITEvaluatedSymbol> JuliaOJIT::findExternalJDSymbol(StringRef Name, bool ExternalJDOnly)
{
orc::JITDylib* SearchOrders[3] = {&ExternalJD, &GlobalJD, &JD};
Comment on lines +1532 to +1534
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should take some sort of handle that was returned by from addExternalModule / addObjectFile to do the lookup inside of. No global variables (esp. GlobalJD and JD) should be accessed here since their content is unreliable.

ArrayRef<orc::JITDylib*> SearchOrder = makeArrayRef(&SearchOrders[0], ExternalJDOnly ? 1 : 3);
auto Sym = ES.lookup(SearchOrder, getMangledName(Name));
return Sym;
}

uint64_t JuliaOJIT::getGlobalValueAddress(StringRef Name)
{
auto addr = findSymbol(getMangledName(Name), false);
Expand Down
14 changes: 13 additions & 1 deletion src/jitlayers.h
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,16 @@ class JuliaOJIT {
void addGlobalMapping(StringRef Name, uint64_t Addr) JL_NOTSAFEPOINT;
void addModule(orc::ThreadSafeModule M) JL_NOTSAFEPOINT;

//Methods for the C API
Error addExternalModule(orc::JITDylib &JD, orc::ThreadSafeModule TSM,
bool ShouldOptimize = false) JL_NOTSAFEPOINT;
Error addObjectFile(orc::JITDylib &JD,
std::unique_ptr<MemoryBuffer> Obj) JL_NOTSAFEPOINT;
Expected<JITEvaluatedSymbol> findExternalJDSymbol(StringRef Name, bool ExternalJDOnly) JL_NOTSAFEPOINT;
orc::IRCompileLayer &getIRCompileLayer() JL_NOTSAFEPOINT { return ExternalCompileLayer; };
orc::ExecutionSession &getExecutionSession() JL_NOTSAFEPOINT { return ES; }
orc::JITDylib &getExternalJITDylib() JL_NOTSAFEPOINT { return ExternalJD; }

JL_JITSymbol findSymbol(StringRef Name, bool ExportedSymbolsOnly) JL_NOTSAFEPOINT;
JL_JITSymbol findUnmangledSymbol(StringRef Name) JL_NOTSAFEPOINT;
uint64_t getGlobalValueAddress(StringRef Name) JL_NOTSAFEPOINT;
Expand Down Expand Up @@ -523,7 +533,7 @@ class JuliaOJIT {
orc::ExecutionSession ES;
orc::JITDylib &GlobalJD;
orc::JITDylib &JD;

orc::JITDylib &ExternalJD;
//Map and inc are guarded by RLST_mutex
std::mutex RLST_mutex{};
int RLST_inc = 0;
Expand All @@ -548,6 +558,8 @@ class JuliaOJIT {
LockLayerT LockLayer;
const std::array<std::unique_ptr<PipelineT>, 4> Pipelines;
OptSelLayerT OptSelLayer;
CompileLayerT ExternalCompileLayer;

};
extern JuliaOJIT *jl_ExecutionEngine;
std::unique_ptr<Module> jl_create_llvm_module(StringRef name, LLVMContext &ctx, bool imaging_mode, const DataLayout &DL = jl_ExecutionEngine->getDataLayout(), const Triple &triple = jl_ExecutionEngine->getTargetTriple()) JL_NOTSAFEPOINT;
Expand Down
12 changes: 12 additions & 0 deletions src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -573,5 +573,17 @@
YY(LLVMExtraAddGCInvariantVerifierPass) \
YY(LLVMExtraAddDemoteFloat16Pass) \
YY(LLVMExtraAddCPUFeaturesPass) \
YY(JLJITGetLLVMOrcExecutionSession) \
YY(JLJITGetJuliaOJIT) \
YY(JLJITGetExternalJITDylib) \
YY(JLJITAddObjectFile) \
YY(JLJITAddLLVMIRModule) \
YY(JLJITLookup) \
YY(JLJITMangleAndIntern) \
YY(JLJITGetTripleString) \
YY(JLJITGetGlobalPrefix) \
YY(JLJITGetDataLayoutString) \
YY(JLJITGetIRCompileLayer) \


// end of file
1 change: 1 addition & 0 deletions src/julia.expmap
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
_Z22jl_coverage_alloc_lineN4llvm9StringRefEi;
_Z22jl_malloc_data_pointerN4llvm9StringRefEi;
LLVMExtra*;
JLJIT*;
llvmGetPassPluginInfo;

/* Make visible so that linker will merge duplicate definitions across DSO boundaries */
Expand Down
133 changes: 133 additions & 0 deletions src/llvm_api.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// This file is a part of Julia. License is MIT: https://julialang.org/license

#undef DEBUG
#include "llvm-version.h"
#include "platform.h"

#ifndef __STDC_LIMIT_MACROS
#define __STDC_LIMIT_MACROS
#define __STDC_CONSTANT_MACROS
#endif

#include <jitlayers.h>

#include <llvm-c/Core.h>
#include <llvm-c/Error.h>
#include <llvm-c/Orc.h>
#include <llvm-c/OrcEE.h>
#include <llvm-c/TargetMachine.h>
#include <llvm-c/Types.h>
#include <llvm/Support/CBindingWrapping.h>
#include <llvm/Support/MemoryBuffer.h>

namespace llvm {
namespace orc {
class OrcV2CAPIHelper {
public:
using PoolEntry = orc::SymbolStringPtr::PoolEntry;
using PoolEntryPtr = orc::SymbolStringPtr::PoolEntryPtr;

// Move from SymbolStringPtr to PoolEntryPtr (no change in ref count).
static PoolEntryPtr moveFromSymbolStringPtr(SymbolStringPtr S)
{
PoolEntryPtr Result = nullptr;
std::swap(Result, S.S);
return Result;
}
};
} // namespace orc
} // namespace llvm


typedef struct JLOpaqueJuliaOJIT *JuliaOJITRef;
typedef struct LLVMOrcOpaqueIRCompileLayer *LLVMOrcIRCompileLayerRef;

DEFINE_SIMPLE_CONVERSION_FUNCTIONS(JuliaOJIT, JuliaOJITRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::JITDylib, LLVMOrcJITDylibRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::ExecutionSession, LLVMOrcExecutionSessionRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::OrcV2CAPIHelper::PoolEntry,
LLVMOrcSymbolStringPoolEntryRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::IRCompileLayer, LLVMOrcIRCompileLayerRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::MaterializationResponsibility,
LLVMOrcMaterializationResponsibilityRef)
extern "C" {

JL_DLLEXPORT_CODEGEN JuliaOJITRef JLJITGetJuliaOJIT_impl(void)
{
return wrap(jl_ExecutionEngine);
}

JL_DLLEXPORT_CODEGEN LLVMOrcExecutionSessionRef
JLJITGetLLVMOrcExecutionSession_impl(JuliaOJITRef JIT)
{
return wrap(&unwrap(JIT)->getExecutionSession());
}

JL_DLLEXPORT_CODEGEN LLVMOrcJITDylibRef
JLJITGetExternalJITDylib_impl(JuliaOJITRef JIT)
{
return wrap(&unwrap(JIT)->getExternalJITDylib());
}

JL_DLLEXPORT_CODEGEN LLVMErrorRef JLJITAddObjectFile_impl(
JuliaOJITRef JIT, LLVMOrcJITDylibRef JD, LLVMMemoryBufferRef ObjBuffer)
{
return wrap(unwrap(JIT)->addObjectFile(
*unwrap(JD), std::unique_ptr<MemoryBuffer>(unwrap(ObjBuffer))));
}

JL_DLLEXPORT_CODEGEN LLVMErrorRef JLJITAddLLVMIRModule_impl(
JuliaOJITRef JIT, LLVMOrcJITDylibRef JD, LLVMOrcThreadSafeModuleRef TSM)
{
std::unique_ptr<orc::ThreadSafeModule> TmpTSM(unwrap(TSM));
return wrap(unwrap(JIT)->addExternalModule(*unwrap(JD), std::move(*TmpTSM)));
}

JL_DLLEXPORT_CODEGEN LLVMErrorRef
JLJITLookup_impl(JuliaOJITRef JIT, LLVMOrcExecutorAddress *Result,
const char *Name, int ExternalJDOnly)
{
auto Sym = unwrap(JIT)->findExternalJDSymbol(Name, ExternalJDOnly);
if (Sym) {
auto addr = Sym->getAddress();
*Result = orc::ExecutorAddr(addr).getValue();
return LLVMErrorSuccess;
}
else {
*Result = 0;
return wrap(Sym.takeError());
}
}

JL_DLLEXPORT_CODEGEN LLVMOrcSymbolStringPoolEntryRef
JLJITMangleAndIntern_impl(JuliaOJITRef JIT,
const char *Name)
{
return wrap(orc::OrcV2CAPIHelper::moveFromSymbolStringPtr(unwrap(JIT)->mangle(Name)));
}

JL_DLLEXPORT_CODEGEN const char *
JLJITGetTripleString_impl(JuliaOJITRef JIT)
{
return unwrap(JIT)->getTargetTriple().str().c_str();
}

JL_DLLEXPORT_CODEGEN const char
JLJITGetGlobalPrefix_impl(JuliaOJITRef JIT)
{
return unwrap(JIT)->getDataLayout().getGlobalPrefix();
}

JL_DLLEXPORT_CODEGEN const char *
JLJITGetDataLayoutString_impl(JuliaOJITRef JIT)
{
return unwrap(JIT)->getDataLayout().getStringRepresentation().c_str();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The C API supports DataLayout objects, so just return it directly?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I didn't know that, because the other functions (GetDataLayout,SetDatalayout etc) that move it around all use C strings

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see that the existing APIs use strings too (LLVMOrcLLJITGetDataLayoutStr), so I'm OK keeping it as strings for consistency.

}

JL_DLLEXPORT_CODEGEN LLVMOrcIRCompileLayerRef
JLJITGetIRCompileLayer_impl(JuliaOJITRef JIT)
{
return wrap(&unwrap(JIT)->getIRCompileLayer());
}

} // extern "C"
13 changes: 13 additions & 0 deletions test/llvmcall2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,16 @@ let err = ErrorException("llvmcall only supports intrinsic calls")
@test_throws err (@eval ccall("llvm.floor.f64", llvmcall, Float64, (Float64, Float64...,), 0.0)) === 0.0
@test_throws err (@eval ccall("llvm.floor", llvmcall, Float64, (Float64, Float64...,), 0.0)) === 0.0
end

@testset "JLJIT API" begin
function JLJITGetJuliaOJIT()
ccall(:JLJITGetJuliaOJIT, Ptr{Cvoid}, ())
end
function JLJITGetTripleString(JIT)
ccall(:JLJITGetTripleString, Cstring, (Ptr{Cvoid},), JIT)
end
jit = JLJITGetJuliaOJIT()
str = JLJITGetTripleString(jit)
jl_str = unsafe_string(str)
@test length(jl_str) > 4
end