Skip to content

Commit

Permalink
Move to a pool of threadsafecontexts (#44605)
Browse files Browse the repository at this point in the history
* Use pooled contexts

* Allow move construction of the resource pool
  • Loading branch information
pchintalapudi authored Apr 10, 2022
1 parent 3d87815 commit 992b261
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 22 deletions.
2 changes: 2 additions & 0 deletions doc/src/devdocs/locks.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ The following are definitely leaf locks (level 1), and must not try to acquire a
> * flisp
> * jl_in_stackwalk (Win32)
> * PM_mutex[i]
> * ContextPool::mutex
>
> > flisp itself is already threadsafe; this lock only protects the `jl_ast_context_list_t` pool.
> > Likewise, each `orc::ThreadSafeContext` carries its own lock; `ContextPool::mutex` only protects the pool itself.
The following is a leaf lock (level 2), and only acquires level 1 locks (safepoint) internally:

Expand Down
20 changes: 13 additions & 7 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,17 +253,19 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
jl_native_code_desc_t *data = new jl_native_code_desc_t;
CompilationPolicy policy = (CompilationPolicy) _policy;
bool imaging = imaging_default() || policy == CompilationPolicy::ImagingMode;
orc::ThreadSafeModule backing;
if (!llvmmod) {
backing = jl_create_llvm_module("text", jl_ExecutionEngine->getContext(), imaging);
}
orc::ThreadSafeModule &clone = llvmmod ? *reinterpret_cast<orc::ThreadSafeModule*>(llvmmod) : backing;
auto ctxt = clone.getContext();
jl_workqueue_t emitted;
jl_method_instance_t *mi = NULL;
jl_code_info_t *src = NULL;
JL_GC_PUSH1(&src);
JL_LOCK(&jl_codegen_lock);
orc::ThreadSafeContext ctx;
orc::ThreadSafeModule backing;
if (!llvmmod) {
ctx = jl_ExecutionEngine->acquireContext();
backing = jl_create_llvm_module("text", ctx, imaging);
}
orc::ThreadSafeModule &clone = llvmmod ? *reinterpret_cast<orc::ThreadSafeModule*>(llvmmod) : backing;
auto ctxt = clone.getContext();
jl_codegen_params_t params(ctxt);
params.params = cgparams;
uint64_t compiler_start_time = 0;
Expand Down Expand Up @@ -402,6 +404,9 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
data->M = std::move(clone);
if (measure_compile_time_enabled)
jl_atomic_fetch_add_relaxed(&jl_cumulative_compile_time, (jl_hrtime() - compiler_start_time));
if (ctx.getContext()) {
jl_ExecutionEngine->releaseContext(std::move(ctx));
}
JL_UNLOCK(&jl_codegen_lock); // Might GC
return (void*)data;
}
Expand Down Expand Up @@ -1020,7 +1025,8 @@ void *jl_get_llvmf_defn_impl(jl_method_instance_t *mi, size_t world, char getwra
// emit this function into a new llvm module
if (src && jl_is_code_info(src)) {
JL_LOCK(&jl_codegen_lock);
jl_codegen_params_t output(jl_ExecutionEngine->getContext());
auto ctx = jl_ExecutionEngine->getContext();
jl_codegen_params_t output(*ctx);
output.world = world;
output.params = &params;
orc::ThreadSafeModule m = jl_create_llvm_module(name_from_method_instance(mi), output.tsctx, output.imaging);
Expand Down
2 changes: 1 addition & 1 deletion src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8372,7 +8372,7 @@ extern "C" void jl_init_llvm(void)
if (clopt && clopt->getNumOccurrences() == 0)
cl::ProvidePositionalOption(clopt, "4", 1);

jl_ExecutionEngine = new JuliaOJIT(new LLVMContext());
jl_ExecutionEngine = new JuliaOJIT();

bool jl_using_gdb_jitevents = false;
// Register GDB event listener
Expand Down
27 changes: 17 additions & 10 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,16 +225,21 @@ int jl_compile_extern_c_impl(LLVMOrcThreadSafeModuleRef llvmmod, void *p, void *
uint8_t measure_compile_time_enabled = jl_atomic_load_relaxed(&jl_measure_compile_time_enabled);
if (measure_compile_time_enabled)
compiler_start_time = jl_hrtime();
orc::ThreadSafeContext ctx;
auto into = reinterpret_cast<orc::ThreadSafeModule*>(llvmmod);
jl_codegen_params_t *pparams = (jl_codegen_params_t*)p;
orc::ThreadSafeModule backing;
if (into == NULL) {
backing = jl_create_llvm_module("cextern", pparams ? pparams->tsctx : jl_ExecutionEngine->getContext(), pparams ? pparams->imaging : imaging_default());
if (!pparams) {
ctx = jl_ExecutionEngine->acquireContext();
}
backing = jl_create_llvm_module("cextern", pparams ? pparams->tsctx : ctx, pparams ? pparams->imaging : imaging_default());
into = &backing;
}
jl_codegen_params_t params(into->getContext());
if (pparams == NULL)
pparams = &params;
assert(pparams->tsctx.getContext() == into->getContext().getContext());
const char *name = jl_generate_ccallable(reinterpret_cast<LLVMOrcThreadSafeModuleRef>(into), sysimg, declrt, sigt, *pparams);
bool success = true;
if (!sysimg) {
Expand All @@ -252,6 +257,9 @@ int jl_compile_extern_c_impl(LLVMOrcThreadSafeModuleRef llvmmod, void *p, void *
}
if (jl_codegen_lock.count == 1 && measure_compile_time_enabled)
jl_atomic_fetch_add_relaxed(&jl_cumulative_compile_time, (jl_hrtime() - compiler_start_time));
if (ctx.getContext()) {
jl_ExecutionEngine->releaseContext(std::move(ctx));
}
JL_UNLOCK(&jl_codegen_lock);
return success;
}
Expand Down Expand Up @@ -306,7 +314,8 @@ extern "C" JL_DLLEXPORT
jl_code_instance_t *jl_generate_fptr_impl(jl_method_instance_t *mi JL_PROPAGATES_ROOT, size_t world)
{
JL_LOCK(&jl_codegen_lock); // also disables finalizers, to prevent any unexpected recursion
auto &context = jl_ExecutionEngine->getContext();
auto ctx = jl_ExecutionEngine->getContext();
auto &context = *ctx;
uint64_t compiler_start_time = 0;
uint8_t measure_compile_time_enabled = jl_atomic_load_relaxed(&jl_measure_compile_time_enabled);
if (measure_compile_time_enabled)
Expand Down Expand Up @@ -363,7 +372,8 @@ void jl_generate_fptr_for_unspecialized_impl(jl_code_instance_t *unspec)
return;
}
JL_LOCK(&jl_codegen_lock);
auto &context = jl_ExecutionEngine->getContext();
auto ctx = jl_ExecutionEngine->getContext();
auto &context = *ctx;
uint64_t compiler_start_time = 0;
uint8_t measure_compile_time_enabled = jl_atomic_load_relaxed(&jl_measure_compile_time_enabled);
if (measure_compile_time_enabled)
Expand Down Expand Up @@ -417,7 +427,8 @@ jl_value_t *jl_dump_method_asm_impl(jl_method_instance_t *mi, size_t world,
// (using sentinel value `1` instead)
// so create an exception here so we can print pretty our lies
JL_LOCK(&jl_codegen_lock); // also disables finalizers, to prevent any unexpected recursion
auto &context = jl_ExecutionEngine->getContext();
auto ctx = jl_ExecutionEngine->getContext();
auto &context = *ctx;
uint64_t compiler_start_time = 0;
uint8_t measure_compile_time_enabled = jl_atomic_load_relaxed(&jl_measure_compile_time_enabled);
if (measure_compile_time_enabled)
Expand Down Expand Up @@ -909,7 +920,7 @@ llvm::DataLayout jl_create_datalayout(TargetMachine &TM) {
return jl_data_layout;
}

JuliaOJIT::JuliaOJIT(LLVMContext *LLVMCtx)
JuliaOJIT::JuliaOJIT()
: TM(createTargetMachine()),
DL(jl_create_datalayout(*TM)),
TMs{
Expand All @@ -918,14 +929,14 @@ JuliaOJIT::JuliaOJIT(LLVMContext *LLVMCtx)
cantFail(createJTMBFromTM(*TM, 2).createTargetMachine()),
cantFail(createJTMBFromTM(*TM, 3).createTargetMachine())
},
TSCtx(std::unique_ptr<LLVMContext>(LLVMCtx)),
#if JL_LLVM_VERSION >= 130000
ES(cantFail(orc::SelfExecutorProcessControl::Create())),
#else
ES(),
#endif
GlobalJD(ES.createBareJITDylib("JuliaGlobals")),
JD(ES.createBareJITDylib("JuliaOJIT")),
ContextPool([](){ return orc::ThreadSafeContext(std::make_unique<LLVMContext>()); }),
#ifdef JL_USE_JITLINK
// TODO: Port our memory management optimisations to JITLink instead of using the
// default InProcessMemoryManager.
Expand Down Expand Up @@ -1165,10 +1176,6 @@ void JuliaOJIT::RegisterJITEventListener(JITEventListener *L)
}
#endif

// Returns a reference to the TSCtx member rather than a copy.
// NOTE(review): this out-of-line definition appears to be the pre-pool
// accessor that the in-class getContext() replaces -- confirm it is removed
// together with the TSCtx member.
orc::ThreadSafeContext &JuliaOJIT::getContext() {
return TSCtx;
}

const DataLayout& JuliaOJIT::getDataLayout() const
{
return DL;
Expand Down
98 changes: 94 additions & 4 deletions src/jitlayers.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <functional>

#include <llvm/IR/Module.h>
#include <llvm/IR/Value.h>
#include <llvm/IR/PassManager.h>
#include "llvm/IR/LegacyPassManager.h"
#include <llvm/IR/LegacyPassManager.h>

#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
#include <llvm/ExecutionEngine/Orc/IRTransformLayer.h>
Expand Down Expand Up @@ -194,6 +194,87 @@ class JuliaOJIT {
typedef orc::IRTransformLayer OptimizeLayerT;
typedef object::OwningBinary<object::ObjectFile> OwningObj;
private:
// A mutex-guarded pool of reusable resources of type ResourceT.
//
// Resources are created lazily by calling `creator`. When `max` is 0 the pool
// creates a new resource whenever none is idle; when `max` is non-zero at most
// `max` resources are ever created and acquire_() blocks until one is released.
//
// acquire_() and release_() take the internal mutex, so they may be called
// concurrently. The mutex and condition variable live behind a unique_ptr so
// that the pool itself stays move-constructible.
template<typename ResourceT, size_t max = 0>
struct ResourcePool {
    public:
    // `creator` is invoked (with the pool mutex held) each time a brand new
    // resource is needed. It is stored as an owning std::function: the pool
    // outlives the constructor call, so a non-owning llvm::function_ref bound
    // to a temporary lambda would dangle as soon as the constructor returned.
    ResourcePool(std::function<ResourceT()> creator) : creator(std::move(creator)), mutex(std::make_unique<WNMutex>()) {}

    // Scoped handle for a pooled resource: the destructor returns the held
    // resource to the pool unless release() has handed it to the caller.
    // The handle keeps a reference to its pool, so the pool must outlive it
    // and must not be moved while handles are alive.
    class OwningResource {
        public:
        OwningResource(ResourcePool &pool, ResourceT resource) : pool(pool), resource(std::move(resource)) {}
        OwningResource(const OwningResource &) = delete;
        OwningResource &operator=(const OwningResource &) = delete;
        // Moving an engaged llvm::Optional leaves the source engaged (holding
        // a moved-from value), so the source must be disengaged explicitly;
        // otherwise its destructor would push a moved-from resource into the
        // pool in addition to the real one.
        OwningResource(OwningResource &&other) : pool(other.pool), resource(std::move(other.resource)) {
            other.resource.reset();
        }
        // The reference member cannot be rebound, so the previously defaulted
        // move assignment was already implicitly deleted; say so explicitly.
        OwningResource &operator=(OwningResource &&) = delete;
        ~OwningResource() {
            if (resource)
                pool.release_(std::move(*resource));
        }
        // Gives up ownership: the caller becomes responsible for eventually
        // passing the resource back to the pool's release_().
        ResourceT release() {
            assert(resource && "Cannot release an empty OwningResource!");
            ResourceT res(std::move(*resource));
            resource.reset();
            return res;
        }
        // Takes ownership of `res`. A resource that is currently held is
        // returned to the pool first instead of being silently dropped, and
        // the handle may be empty (e.g. after release()) when this is called.
        void reset(ResourceT res) {
            if (resource)
                pool.release_(std::move(*resource));
            resource = std::move(res);
        }
        ResourceT &operator*() {
            return *resource;
        }
        ResourceT *operator->() {
            return get();
        }
        ResourceT *get() {
            return resource.getPointer();
        }
        const ResourceT &operator*() const {
            return *resource;
        }
        const ResourceT *operator->() const {
            return get();
        }
        const ResourceT *get() const {
            return resource.getPointer();
        }
        // True while the handle still owns a resource.
        explicit operator bool() const {
            // Optional's bool conversion is explicit, so it needs a cast in a
            // return statement.
            return static_cast<bool>(resource);
        }
        private:
        ResourcePool &pool;
        llvm::Optional<ResourceT> resource;
    };

    // Acquires a resource wrapped in a handle that returns it automatically.
    OwningResource acquire() {
        return OwningResource(*this, acquire_());
    }

    // Acquires a bare resource; the caller must hand it back via release_().
    // Prefers an idle resource, then creates a new one if the limit allows,
    // and otherwise blocks until another thread releases one.
    ResourceT acquire_() {
        std::unique_lock<std::mutex> lock(mutex->mutex);
        if (!pool.empty()) {
            return pool.pop_back_val();
        }
        if (!max || created < max) {
            created++;
            return creator();
        }
        mutex->empty.wait(lock, [&](){ return !pool.empty(); });
        assert(!pool.empty() && "Expected resource pool to have a value!");
        return pool.pop_back_val();
    }
    // Returns a resource to the pool and wakes one waiter blocked in acquire_().
    void release_(ResourceT &&resource) {
        std::lock_guard<std::mutex> lock(mutex->mutex);
        pool.push_back(std::move(resource));
        mutex->empty.notify_one();
    }
    private:
    std::function<ResourceT()> creator;
    // Number of resources created so far; only consulted when max != 0.
    size_t created = 0;
    // Idle resources that are ready to be handed out again.
    llvm::SmallVector<ResourceT, max == 0 ? 8 : max> pool;
    // Mutex plus the condition variable that is signalled on every release_().
    struct WNMutex {
        std::mutex mutex;
        std::condition_variable empty;
    };

    std::unique_ptr<WNMutex> mutex;
};
struct OptimizerT {
OptimizerT(legacy::PassManager &PM, std::mutex &mutex, int optlevel) : optlevel(optlevel), PM(PM), mutex(mutex) {}

Expand Down Expand Up @@ -223,7 +304,7 @@ class JuliaOJIT {

public:

JuliaOJIT(LLVMContext *Ctx);
JuliaOJIT();

void enableJITDebuggingSupport();
#ifndef JL_USE_JITLINK
Expand All @@ -239,7 +320,15 @@ class JuliaOJIT {
uint64_t getGlobalValueAddress(StringRef Name);
uint64_t getFunctionAddress(StringRef Name);
StringRef getFunctionAtAddress(uint64_t Addr, jl_code_instance_t *codeinst);
orc::ThreadSafeContext &getContext();
// Acquires a context from ContextPool wrapped in an owning handle
// (ResourcePool::OwningResource). The handle's destructor returns the context
// to the pool, so keep it alive for as long as the context is in use and
// dereference it (`*ctx`) to reach the orc::ThreadSafeContext.
auto getContext() {
return ContextPool.acquire();
}
// Acquires a bare context from ContextPool without an owning handle.
// Nothing returns it automatically: the caller is expected to hand it back
// with releaseContext() once it is done with it.
orc::ThreadSafeContext acquireContext() {
return ContextPool.acquire_();
}
// Returns a context to ContextPool so that later acquisitions can reuse it.
// `ctx` is moved from and must not be used by the caller afterwards.
void releaseContext(orc::ThreadSafeContext &&ctx) {
ContextPool.release_(std::move(ctx));
}
const DataLayout& getDataLayout() const;
TargetMachine &getTargetMachine();
const Triple& getTargetTriple() const;
Expand All @@ -260,11 +349,12 @@ class JuliaOJIT {
std::mutex PM_mutexes[4];
std::unique_ptr<TargetMachine> TMs[4];

orc::ThreadSafeContext TSCtx;
orc::ExecutionSession ES;
orc::JITDylib &GlobalJD;
orc::JITDylib &JD;

ResourcePool<orc::ThreadSafeContext> ContextPool;

#ifndef JL_USE_JITLINK
std::shared_ptr<RTDyldMemoryManager> MemMgr;
#endif
Expand Down

0 comments on commit 992b261

Please sign in to comment.