Skip to content

Commit

Permalink
[xla:cpu] Migrate CpuCompiler from SimpleOrcJit to JitCompiler
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 701149812
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Nov 29, 2024
1 parent 2235f1f commit 6949b21
Show file tree
Hide file tree
Showing 11 changed files with 170 additions and 274 deletions.
4 changes: 2 additions & 2 deletions xla/backends/cpu/codegen/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("//xla:xla.bzl", "xla_cc_test")
load("//xla:xla.bzl", "xla_cc_test", "xla_internal")
load(
"//xla/tsl/platform:build_config_root.bzl",
"if_llvm_aarch64_available",
Expand Down Expand Up @@ -131,7 +131,7 @@ cc_library(
"@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep
]) + if_llvm_x86_available([
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
]),
]) + xla_internal(["service/cpu:named_orc_jit_memory_mapper"]),
)

xla_cc_test(
Expand Down
27 changes: 16 additions & 11 deletions xla/backends/cpu/codegen/jit_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,15 @@ absl::StatusOr<JitCompiler> JitCompiler::Create(
options.max_cpu_feature);
TF_ASSIGN_OR_RETURN(auto target_machine, target_machine_builder());

// Dispatch compilation tasks using the provided task runner.
auto task_dispatcher =
std::make_unique<TaskDispatcher>(std::move(task_runner));
TaskDispatcher* task_dispatcher_ptr = task_dispatcher.get();

// LLVM execution session that holds jit-compiled functions.
auto execution_session = std::make_unique<llvm::orc::ExecutionSession>(
std::make_unique<llvm::orc::UnsupportedExecutorProcessControl>(
/*SSP=*/nullptr,
std::make_unique<TaskDispatcher>(std::move(task_runner))));
/*SSP=*/nullptr, std::move(task_dispatcher)));

execution_session->setErrorReporter([](llvm::Error err) {
LOG(ERROR) << "LLVM compilation error: " << llvm::toString(std::move(err));
Expand All @@ -136,10 +140,10 @@ absl::StatusOr<JitCompiler> JitCompiler::Create(
target_machine_builder, std::move(options.ir_compiler_options),
std::move(options.ir_compiler_hooks));

return JitCompiler(std::move(target_machine_builder),
std::move(target_machine), std::move(execution_session),
std::move(ir_compiler), options.num_dylibs,
std::move(options.definition_generator));
return JitCompiler(
std::move(target_machine_builder), std::move(target_machine),
task_dispatcher_ptr, std::move(execution_session), std::move(ir_compiler),
options.num_dylibs, std::move(options.definition_generator));
}

static std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer>
Expand All @@ -162,11 +166,13 @@ static std::unique_ptr<llvm::orc::IRCompileLayer> CreateCompileLayer(
JitCompiler::JitCompiler(
IrCompiler::TargetMachineBuilder target_machine_builder,
std::shared_ptr<llvm::TargetMachine> target_machine,
TaskDispatcher* task_dispatcher,
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
std::unique_ptr<IrCompiler> ir_compiler, size_t num_dylibs,
DefinitionGenerator definition_generator)
: target_machine_builder_(std::move(target_machine_builder)),
target_machine_(std::move(target_machine)),
task_dispatcher_(task_dispatcher),
execution_session_(std::move(execution_session)),
object_layer_(CreateObjectLinkingLayer(*execution_session_)),
compile_layer_(CreateCompileLayer(*execution_session_, *object_layer_,
Expand Down Expand Up @@ -267,6 +273,10 @@ absl::StatusOr<std::unique_ptr<FunctionLibrary>> JitCompiler::Compile(
// Look up all requested symbols in the execution session.
auto symbol_map = execution_session_->lookup(std::move(search_order),
std::move(lookup_set));

// Wait for all compilation tasks to finish.
task_dispatcher_->shutdown();

if (auto err = symbol_map.takeError()) {
return Internal("%s", llvm::toString(std::move(err)));
}
Expand Down Expand Up @@ -342,11 +352,6 @@ JitCompiler::CompiledFunctionLibrary::~CompiledFunctionLibrary() {
if (auto err = execution_session_->endSession()) {
execution_session_->reportError(std::move(err));
}
// Explicitly destroy the execution session to ensure that all tasks are
// finished, because otherwise object layer materialization running inside the
// task dispatched triggers use-after-free errors. This is super fishy, and we
// don't really understand why this is happening.
execution_session_.reset();
}

absl::StatusOr<void*> JitCompiler::CompiledFunctionLibrary::ResolveFunction(
Expand Down
15 changes: 9 additions & 6 deletions xla/backends/cpu/codegen/jit_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,6 @@ class JitCompiler {
llvm::TargetMachine* target_machine() { return target_machine_.get(); }

private:
JitCompiler(IrCompiler::TargetMachineBuilder target_machine_builder,
std::shared_ptr<llvm::TargetMachine> target_machine,
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
std::unique_ptr<IrCompiler> ir_compiler, size_t num_dylibs,
DefinitionGenerator definition_generator);

// LLVM ORC task dispatcher that uses `TaskRunner` to run compilation tasks.
class TaskDispatcher : public llvm::orc::TaskDispatcher {
public:
Expand Down Expand Up @@ -192,11 +186,20 @@ class JitCompiler {
absl::flat_hash_map<std::string, ResolvedSymbol> symbols_map_;
};

JitCompiler(IrCompiler::TargetMachineBuilder target_machine_builder,
std::shared_ptr<llvm::TargetMachine> target_machine,
TaskDispatcher* task_dispatcher,
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
std::unique_ptr<IrCompiler> ir_compiler, size_t num_dylibs,
DefinitionGenerator definition_generator);

// Target machine builder that is used to construct target machines for this
// instance of `JitCompiler` (when compiling LLVM modules in parallel).
IrCompiler::TargetMachineBuilder target_machine_builder_;
std::shared_ptr<llvm::TargetMachine> target_machine_;

TaskDispatcher* task_dispatcher_; // owned by `execution_session_`

std::unique_ptr<llvm::orc::ExecutionSession> execution_session_;
std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer> object_layer_;
std::unique_ptr<llvm::orc::IRCompileLayer> compile_layer_;
Expand Down
5 changes: 5 additions & 0 deletions xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ cc_library(
":onednn_contraction_rewriter",
":onednn_ops_rewriter",
":parallel_task_assignment",
":runtime_symbol_generator",
":simple_orc_jit",
":thunk_emitter",
":xla_framework",
Expand All @@ -244,6 +245,7 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
"//xla/backends/cpu/codegen:cpu_features",
"//xla/backends/cpu/codegen:function_library",
"//xla/backends/cpu/codegen:ir_compiler",
"//xla/backends/cpu/codegen:jit_compiler",
"//xla/backends/cpu/codegen:target_machine_features",
Expand Down Expand Up @@ -623,6 +625,7 @@ cc_library(
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/codegen:function_library",
"//xla/backends/cpu/runtime:buffer_allocations",
"//xla/backends/cpu/runtime:thread_pool_task_runner",
"//xla/backends/cpu/runtime:thunk",
Expand All @@ -632,6 +635,7 @@ cc_library(
"//xla/service:custom_call_status",
"//xla/service:custom_call_status_internal",
"//xla/service:executable",
"//xla/service:hlo_execution_profile",
"//xla/service:hlo_profile_printer_data_cc",
"//xla/service:hlo_value",
"//xla/service:maybe_owning_device_memory",
Expand All @@ -653,6 +657,7 @@ cc_library(
"@llvm-project//llvm:OrcShared",
"@llvm-project//llvm:Support",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
],
Expand Down
Loading

0 comments on commit 6949b21

Please sign in to comment.