Skip to content

Commit

Permalink
Standardise on lowered_ir_mods and correct device_hook variable name
Browse files Browse the repository at this point in the history
  • Loading branch information
Mousius committed Nov 15, 2021
1 parent 4dbe678 commit a69a7ac
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 13 deletions.
8 changes: 4 additions & 4 deletions python/tvm/relay/backend/executor_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule):
----------
ir_mod : :py:class:`~tvm.IRModule`
The IR module to build.
built_ir_mods : dict[Target, IRModule]
The IR modules built per Target.
lowered_ir_mods : dict[Target, IRModule]
The IR modules lowered per Target.
target : tvm.Target
The Target used to build this module.
libmod : tvm.Module
Expand All @@ -92,10 +92,10 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule):
"""

def __init__(
self, ir_mod, built_ir_mods, target, libmod, libmod_name, params, function_metadata, devices
self, ir_mod, lowered_ir_mods, target, libmod, libmod_name, params, function_metadata, devices
):
self.ir_mod = ir_mod
self.built_ir_mods = built_ir_mods
self.lowered_ir_mods = lowered_ir_mods
self.target = target
self.lib = libmod
self.libmod_name = libmod_name
Expand Down
13 changes: 10 additions & 3 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def get_params(self):
return ret

def get_irmodule(self):
"""Returns the Target IRModule's from code generation"""
"""Returns the Target IRModule's post-lowering"""
return self._get_irmodule()


Expand Down Expand Up @@ -381,11 +381,18 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
)
func_metadata = bld_mod.get_function_metadata()
devices = bld_mod.get_devices()
final_ir_mods = bld_mod.get_irmodule()
lowered_ir_mods = bld_mod.get_irmodule()

if executor == "aot":
executor_factory = _executor_factory.AOTExecutorFactoryModule(
ir_mod, final_ir_mods, target, runtime_mod, mod_name, params, func_metadata, devices
ir_mod,
lowered_ir_mods,
target,
runtime_mod,
mod_name,
params,
func_metadata,
devices,
)
elif executor == "graph":
executor_factory = _executor_factory.GraphExecutorFactoryModule(
Expand Down
10 changes: 5 additions & 5 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -440,18 +440,18 @@ class AOTExecutorCodegen : public MixedModeVisitor {
* \return Statement with function calls for each device
*/
tir::Stmt GenerateAllDeviceHook(const String& hook) {
std::vector<tir::Stmt> device_activations;
std::vector<tir::Stmt> device_hooks;
for (const auto& it : devices_) {
const String& device_name = it.first;
const tir::Var& context = it.second;
Array<String> sections = {"Device", device_name, hook};
String device_activation = ToCFunctionStyle(PrefixName(sections));
String device_hook_name = ToCFunctionStyle(PrefixName(sections));

tir::Evaluate device_hook(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
{tvm::tir::StringImm(device_activation), context}));
device_activations.push_back(device_hook);
{tvm::tir::StringImm(device_hook_name), context}));
device_hooks.push_back(device_hook);
}
return tir::SeqStmt(device_activations);
return tir::SeqStmt(device_hooks);
}

/**
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/aot/test_crt_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def representative_dataset():
input_data,
output_data,
)
main_ir_module = list(compiled_models[0].executor_factory.built_ir_mods.values())[0]
main_ir_module = list(compiled_models[0].executor_factory.lowered_ir_mods.values())[0]
main_func = main_ir_module["run_model"]

# Activate Device
Expand Down

0 comments on commit a69a7ac

Please sign in to comment.