Skip to content

Commit

Permalink
Expose the Julia JIT with a C API (#49858)
Browse files Browse the repository at this point in the history
  • Loading branch information
gbaraldi authored May 29, 2023
1 parent 1cc10a6 commit 957972e
Show file tree
Hide file tree
Showing 8 changed files with 237 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ CODEGEN_SRCS := codegen jitlayers aotcompile debuginfo disasm llvm-simdloop llvm
llvm-final-gc-lowering llvm-pass-helpers llvm-late-gc-lowering llvm-ptls \
llvm-lower-handlers llvm-gc-invariant-verifier llvm-propagate-addrspaces \
llvm-multiversioning llvm-alloc-opt llvm-alloc-helpers cgmemmgr llvm-remove-addrspaces \
llvm-remove-ni llvm-julia-licm llvm-demote-float16 llvm-cpufeatures pipeline
llvm-remove-ni llvm-julia-licm llvm-demote-float16 llvm-cpufeatures pipeline llvm_api
FLAGS += -I$(shell $(LLVM_CONFIG_HOST) --includedir)
CG_LLVM_LIBS := all
ifeq ($(USE_POLLY),1)
Expand Down
23 changes: 23 additions & 0 deletions src/codegen-stubs.c
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,26 @@ JL_DLLEXPORT void LLVMExtraAddGCInvariantVerifierPass_fallback(void *PM, bool_t
JL_DLLEXPORT void LLVMExtraAddDemoteFloat16Pass_fallback(void *PM) UNAVAILABLE

JL_DLLEXPORT void LLVMExtraAddCPUFeaturesPass_fallback(void *PM) UNAVAILABLE

//LLVM C api to the julia JIT
JL_DLLEXPORT void* JLJITGetLLVMOrcExecutionSession_fallback(void* JIT) UNAVAILABLE

JL_DLLEXPORT void* JLJITGetJuliaOJIT_fallback(void) UNAVAILABLE

JL_DLLEXPORT void* JLJITGetExternalJITDylib_fallback(void* JIT) UNAVAILABLE

JL_DLLEXPORT void* JLJITAddObjectFile_fallback(void* JIT, void* JD, void* ObjBuffer) UNAVAILABLE

JL_DLLEXPORT void* JLJITAddLLVMIRModule_fallback(void* JIT, void* JD, void* TSM) UNAVAILABLE

JL_DLLEXPORT void* JLJITLookup_fallback(void* JIT, void* Result, const char *Name) UNAVAILABLE

JL_DLLEXPORT void* JLJITMangleAndIntern_fallback(void* JIT, const char *Name) UNAVAILABLE

JL_DLLEXPORT const char *JLJITGetTripleString_fallback(void* JIT) UNAVAILABLE

JL_DLLEXPORT const char JLJITGetGlobalPrefix_fallback(void* JIT) UNAVAILABLE

JL_DLLEXPORT const char *JLJITGetDataLayoutString_fallback(void* JIT) UNAVAILABLE

JL_DLLEXPORT void* JLJITGetIRCompileLayer_fallback(void* JIT) UNAVAILABLE
44 changes: 41 additions & 3 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1299,6 +1299,7 @@ JuliaOJIT::JuliaOJIT()
ES(cantFail(orc::SelfExecutorProcessControl::Create())),
GlobalJD(ES.createBareJITDylib("JuliaGlobals")),
JD(ES.createBareJITDylib("JuliaOJIT")),
ExternalJD(ES.createBareJITDylib("JuliaExternal")),
ContextPool([](){
auto ctx = std::make_unique<LLVMContext>();
return orc::ThreadSafeContext(std::move(ctx));
Expand All @@ -1323,7 +1324,9 @@ JuliaOJIT::JuliaOJIT()
std::make_unique<PipelineT>(LockLayer, *TM, 2, PrintLLVMTimers),
std::make_unique<PipelineT>(LockLayer, *TM, 3, PrintLLVMTimers),
},
OptSelLayer(Pipelines)
OptSelLayer(Pipelines),
ExternalCompileLayer(ES, LockLayer,
std::make_unique<CompilerT>(orc::irManglingOptionsFromTargetOptions(TM->Options), *TM, 2))
{
#ifdef JL_USE_JITLINK
# if defined(LLVM_SHLIB)
Expand Down Expand Up @@ -1395,6 +1398,9 @@ JuliaOJIT::JuliaOJIT()
}

JD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
JD.addToLinkOrder(ExternalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
ExternalJD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
ExternalJD.addToLinkOrder(JD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);

#if JULIA_FLOAT16_ABI == 1
orc::SymbolAliasMap jl_crt = {
Expand Down Expand Up @@ -1494,10 +1500,34 @@ void JuliaOJIT::addModule(orc::ThreadSafeModule TSM)
}
}

Error JuliaOJIT::addExternalModule(orc::JITDylib &JD, orc::ThreadSafeModule TSM, bool ShouldOptimize)
{
if (auto Err = TSM.withModuleDo([&](Module &M) JL_NOTSAFEPOINT -> Error
{
if (M.getDataLayout().isDefault())
M.setDataLayout(DL);
if (M.getDataLayout() != DL)
return make_error<StringError>(
"Added modules have incompatible data layouts: " +
M.getDataLayout().getStringRepresentation() + " (module) vs " +
DL.getStringRepresentation() + " (jit)",
inconvertibleErrorCode());

return Error::success();
}))
return Err;
return ExternalCompileLayer.add(JD.getDefaultResourceTracker(), std::move(TSM));
}

Error JuliaOJIT::addObjectFile(orc::JITDylib &JD, std::unique_ptr<MemoryBuffer> Obj) {
assert(Obj && "Can not add null object");
return LockLayer.add(JD.getDefaultResourceTracker(), std::move(Obj));
}

JL_JITSymbol JuliaOJIT::findSymbol(StringRef Name, bool ExportedSymbolsOnly)
{
orc::JITDylib* SearchOrders[2] = {&JD, &GlobalJD};
ArrayRef<orc::JITDylib*> SearchOrder = makeArrayRef(&SearchOrders[0], ExportedSymbolsOnly ? 2 : 1);
orc::JITDylib* SearchOrders[3] = {&JD, &GlobalJD, &ExternalJD};
ArrayRef<orc::JITDylib*> SearchOrder = makeArrayRef(&SearchOrders[0], ExportedSymbolsOnly ? 3 : 1);
auto Sym = ES.lookup(SearchOrder, Name);
if (Sym)
return *Sym;
Expand All @@ -1509,6 +1539,14 @@ JL_JITSymbol JuliaOJIT::findUnmangledSymbol(StringRef Name)
return findSymbol(getMangledName(Name), true);
}

Expected<JITEvaluatedSymbol> JuliaOJIT::findExternalJDSymbol(StringRef Name, bool ExternalJDOnly)
{
orc::JITDylib* SearchOrders[3] = {&ExternalJD, &GlobalJD, &JD};
ArrayRef<orc::JITDylib*> SearchOrder = makeArrayRef(&SearchOrders[0], ExternalJDOnly ? 1 : 3);
auto Sym = ES.lookup(SearchOrder, getMangledName(Name));
return Sym;
}

uint64_t JuliaOJIT::getGlobalValueAddress(StringRef Name)
{
auto addr = findSymbol(getMangledName(Name), false);
Expand Down
14 changes: 13 additions & 1 deletion src/jitlayers.h
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,16 @@ class JuliaOJIT {
void addGlobalMapping(StringRef Name, uint64_t Addr) JL_NOTSAFEPOINT;
void addModule(orc::ThreadSafeModule M) JL_NOTSAFEPOINT;

//Methods for the C API
Error addExternalModule(orc::JITDylib &JD, orc::ThreadSafeModule TSM,
bool ShouldOptimize = false) JL_NOTSAFEPOINT;
Error addObjectFile(orc::JITDylib &JD,
std::unique_ptr<MemoryBuffer> Obj) JL_NOTSAFEPOINT;
Expected<JITEvaluatedSymbol> findExternalJDSymbol(StringRef Name, bool ExternalJDOnly) JL_NOTSAFEPOINT;
orc::IRCompileLayer &getIRCompileLayer() JL_NOTSAFEPOINT { return ExternalCompileLayer; };
orc::ExecutionSession &getExecutionSession() JL_NOTSAFEPOINT { return ES; }
orc::JITDylib &getExternalJITDylib() JL_NOTSAFEPOINT { return ExternalJD; }

JL_JITSymbol findSymbol(StringRef Name, bool ExportedSymbolsOnly) JL_NOTSAFEPOINT;
JL_JITSymbol findUnmangledSymbol(StringRef Name) JL_NOTSAFEPOINT;
uint64_t getGlobalValueAddress(StringRef Name) JL_NOTSAFEPOINT;
Expand Down Expand Up @@ -523,7 +533,7 @@ class JuliaOJIT {
orc::ExecutionSession ES;
orc::JITDylib &GlobalJD;
orc::JITDylib &JD;

orc::JITDylib &ExternalJD;
//Map and inc are guarded by RLST_mutex
std::mutex RLST_mutex{};
int RLST_inc = 0;
Expand All @@ -548,6 +558,8 @@ class JuliaOJIT {
LockLayerT LockLayer;
const std::array<std::unique_ptr<PipelineT>, 4> Pipelines;
OptSelLayerT OptSelLayer;
CompileLayerT ExternalCompileLayer;

};
extern JuliaOJIT *jl_ExecutionEngine;
std::unique_ptr<Module> jl_create_llvm_module(StringRef name, LLVMContext &ctx, bool imaging_mode, const DataLayout &DL = jl_ExecutionEngine->getDataLayout(), const Triple &triple = jl_ExecutionEngine->getTargetTriple()) JL_NOTSAFEPOINT;
Expand Down
12 changes: 12 additions & 0 deletions src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -573,5 +573,17 @@
YY(LLVMExtraAddGCInvariantVerifierPass) \
YY(LLVMExtraAddDemoteFloat16Pass) \
YY(LLVMExtraAddCPUFeaturesPass) \
YY(JLJITGetLLVMOrcExecutionSession) \
YY(JLJITGetJuliaOJIT) \
YY(JLJITGetExternalJITDylib) \
YY(JLJITAddObjectFile) \
YY(JLJITAddLLVMIRModule) \
YY(JLJITLookup) \
YY(JLJITMangleAndIntern) \
YY(JLJITGetTripleString) \
YY(JLJITGetGlobalPrefix) \
YY(JLJITGetDataLayoutString) \
YY(JLJITGetIRCompileLayer) \


// end of file
1 change: 1 addition & 0 deletions src/julia.expmap
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
_Z22jl_coverage_alloc_lineN4llvm9StringRefEi;
_Z22jl_malloc_data_pointerN4llvm9StringRefEi;
LLVMExtra*;
JLJIT*;
llvmGetPassPluginInfo;

/* Make visible so that linker will merge duplicate definitions across DSO boundaries */
Expand Down
133 changes: 133 additions & 0 deletions src/llvm_api.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// This file is a part of Julia. License is MIT: https://julialang.org/license

#undef DEBUG
#include "llvm-version.h"
#include "platform.h"

#ifndef __STDC_LIMIT_MACROS
#define __STDC_LIMIT_MACROS
#define __STDC_CONSTANT_MACROS
#endif

#include <jitlayers.h>

#include <llvm-c/Core.h>
#include <llvm-c/Error.h>
#include <llvm-c/Orc.h>
#include <llvm-c/OrcEE.h>
#include <llvm-c/TargetMachine.h>
#include <llvm-c/Types.h>
#include <llvm/Support/CBindingWrapping.h>
#include <llvm/Support/MemoryBuffer.h>

namespace llvm {
namespace orc {
class OrcV2CAPIHelper {
public:
using PoolEntry = orc::SymbolStringPtr::PoolEntry;
using PoolEntryPtr = orc::SymbolStringPtr::PoolEntryPtr;

// Move from SymbolStringPtr to PoolEntryPtr (no change in ref count).
static PoolEntryPtr moveFromSymbolStringPtr(SymbolStringPtr S)
{
PoolEntryPtr Result = nullptr;
std::swap(Result, S.S);
return Result;
}
};
} // namespace orc
} // namespace llvm


typedef struct JLOpaqueJuliaOJIT *JuliaOJITRef;
typedef struct LLVMOrcOpaqueIRCompileLayer *LLVMOrcIRCompileLayerRef;

DEFINE_SIMPLE_CONVERSION_FUNCTIONS(JuliaOJIT, JuliaOJITRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::JITDylib, LLVMOrcJITDylibRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::ExecutionSession, LLVMOrcExecutionSessionRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::OrcV2CAPIHelper::PoolEntry,
LLVMOrcSymbolStringPoolEntryRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::IRCompileLayer, LLVMOrcIRCompileLayerRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(orc::MaterializationResponsibility,
LLVMOrcMaterializationResponsibilityRef)
extern "C" {

JL_DLLEXPORT_CODEGEN JuliaOJITRef JLJITGetJuliaOJIT_impl(void)
{
return wrap(jl_ExecutionEngine);
}

JL_DLLEXPORT_CODEGEN LLVMOrcExecutionSessionRef
JLJITGetLLVMOrcExecutionSession_impl(JuliaOJITRef JIT)
{
return wrap(&unwrap(JIT)->getExecutionSession());
}

JL_DLLEXPORT_CODEGEN LLVMOrcJITDylibRef
JLJITGetExternalJITDylib_impl(JuliaOJITRef JIT)
{
return wrap(&unwrap(JIT)->getExternalJITDylib());
}

JL_DLLEXPORT_CODEGEN LLVMErrorRef JLJITAddObjectFile_impl(
JuliaOJITRef JIT, LLVMOrcJITDylibRef JD, LLVMMemoryBufferRef ObjBuffer)
{
return wrap(unwrap(JIT)->addObjectFile(
*unwrap(JD), std::unique_ptr<MemoryBuffer>(unwrap(ObjBuffer))));
}

JL_DLLEXPORT_CODEGEN LLVMErrorRef JLJITAddLLVMIRModule_impl(
JuliaOJITRef JIT, LLVMOrcJITDylibRef JD, LLVMOrcThreadSafeModuleRef TSM)
{
std::unique_ptr<orc::ThreadSafeModule> TmpTSM(unwrap(TSM));
return wrap(unwrap(JIT)->addExternalModule(*unwrap(JD), std::move(*TmpTSM)));
}

JL_DLLEXPORT_CODEGEN LLVMErrorRef
JLJITLookup_impl(JuliaOJITRef JIT, LLVMOrcExecutorAddress *Result,
const char *Name, int ExternalJDOnly)
{
auto Sym = unwrap(JIT)->findExternalJDSymbol(Name, ExternalJDOnly);
if (Sym) {
auto addr = Sym->getAddress();
*Result = orc::ExecutorAddr(addr).getValue();
return LLVMErrorSuccess;
}
else {
*Result = 0;
return wrap(Sym.takeError());
}
}

JL_DLLEXPORT_CODEGEN LLVMOrcSymbolStringPoolEntryRef
JLJITMangleAndIntern_impl(JuliaOJITRef JIT,
const char *Name)
{
return wrap(orc::OrcV2CAPIHelper::moveFromSymbolStringPtr(unwrap(JIT)->mangle(Name)));
}

JL_DLLEXPORT_CODEGEN const char *
JLJITGetTripleString_impl(JuliaOJITRef JIT)
{
return unwrap(JIT)->getTargetTriple().str().c_str();
}

JL_DLLEXPORT_CODEGEN const char
JLJITGetGlobalPrefix_impl(JuliaOJITRef JIT)
{
return unwrap(JIT)->getDataLayout().getGlobalPrefix();
}

JL_DLLEXPORT_CODEGEN const char *
JLJITGetDataLayoutString_impl(JuliaOJITRef JIT)
{
return unwrap(JIT)->getDataLayout().getStringRepresentation().c_str();
}

JL_DLLEXPORT_CODEGEN LLVMOrcIRCompileLayerRef
JLJITGetIRCompileLayer_impl(JuliaOJITRef JIT)
{
return wrap(&unwrap(JIT)->getIRCompileLayer());
}

} // extern "C"
13 changes: 13 additions & 0 deletions test/llvmcall2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,16 @@ let err = ErrorException("llvmcall only supports intrinsic calls")
@test_throws err (@eval ccall("llvm.floor.f64", llvmcall, Float64, (Float64, Float64...,), 0.0)) === 0.0
@test_throws err (@eval ccall("llvm.floor", llvmcall, Float64, (Float64, Float64...,), 0.0)) === 0.0
end

@testset "JLJIT API" begin
function JLJITGetJuliaOJIT()
ccall(:JLJITGetJuliaOJIT, Ptr{Cvoid}, ())
end
function JLJITGetTripleString(JIT)
ccall(:JLJITGetTripleString, Cstring, (Ptr{Cvoid},), JIT)
end
jit = JLJITGetJuliaOJIT()
str = JLJITGetTripleString(jit)
jl_str = unsafe_string(str)
@test length(jl_str) > 4
end

0 comments on commit 957972e

Please sign in to comment.