Skip to content

Commit

Permalink
Prepare for switching VM to LowerTEPass. (apache#9550)
Browse files Browse the repository at this point in the history
This is a grab bag of fallout changes from switching the VM to use LoweTEPass
which can be easily split out of the main apache#9483 PR.

- AnnotateSpans can be used from C++ (though, unfortunately, it didn't help
  me with debugging since spans are universally dropped in most passes).
- Can get a human readable dump of the VM's PackedFunc names and indexes for
  debugging.
- If TVM_LOG_DEBUG defined then include types and ids of GlobalVars. I had
  a lot of difficulty tracking down where duplicate GlobalVars for the same
  name_hint were getting created and propagated.
- GetCallLoweredProps follows same API as GetDeviceCopy and GetOnDevice
  where will return 'null' properties if call/expr is not of call_lowered
  form. Mildly more convenient, though switching all the above to ICHECK
  and push 'if (op == the relevant op)' into all use sites would also be just
  fine.
- Misc VLOG improvements made while tracking down issues in apache#9483.
  • Loading branch information
mbs-octoml authored and mehrdadh committed Dec 1, 2021
1 parent f831cc5 commit 6c2add9
Show file tree
Hide file tree
Showing 27 changed files with 311 additions and 112 deletions.
10 changes: 9 additions & 1 deletion include/tvm/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
#ifndef TVM_PARSER_PARSER_H_
#define TVM_PARSER_PARSER_H_
/*!
* \file parser.h
* \file include/tvm/parser/parser.h
* \brief A parser for TVM IR.
*/
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

Expand All @@ -39,6 +40,13 @@ IRModule ParseModule(const std::string& file_name, const std::string& file_conte
const Optional<IRModule>& init_module = Optional<IRModule>(),
const MetaTable& init_meta_table = MetaTable());

/*!
* \brief This pass pretty-prints mod then parses it back so as to establish spans and sources
* for all Relay sub-expressions. This improves error and debugging diagnostics downstream for
* modules constructed programaticaly rather than textually.
*/
transform::Pass AnnotateSpans();

} // namespace parser
} // namespace tvm

Expand Down
11 changes: 9 additions & 2 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ class Executable : public ModuleNode {
*/
std::string GetVirtualDevices() const;

/*!
* \brief Returns a description of all the 'primitive' (ie PackedFuncs) in the executable.
* These correspond to eithed PrimFuncs we've compiled locally, or functions compiled by
* a BYOC external codegen.
*/
std::string GetPrimitives() const;

/*!
* \brief Print the detailed statistics of the given code, i.e. number of
* globls and constants, etc.
Expand Down Expand Up @@ -201,9 +208,9 @@ class Executable : public ModuleNode {
int host_device_index = -1;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief A map from globals (as strings) to their index in the function map. */
/*! \brief A map from globals (as strings) to their index in the Relay function map. */
std::unordered_map<std::string, Index> global_map;
/*! \brief A mapping from the packed function (as string) to the index that
/*! \brief A mapping from the packed function's global name (as string) to the index that
* corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object.
*/
std::unordered_map<std::string, Index> primitive_map;
Expand Down
22 changes: 15 additions & 7 deletions include/tvm/target/compilation_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace tvm {

/*!
* \brief Gathers the \p Targets and distinguished \p SEScopes in canonical form needed to
* compile a Relay module. All centralizes any setup and validation logic needed to transition
* compile a Relay module. Centralizes any setup and validation logic needed to transition
* from configuration options conveyed implicitly (eg in \p PassContexts) or explicitly
* (eg a a list of \p Targets) to the configuration.
*
Expand All @@ -49,9 +49,12 @@ namespace tvm {
class CompilationConfigNode : public Object {
public:
/*!
* \brief The legacy targets map, mapping device type to \p Targets. Does not include any
* entry for the host target. Intended to give a unique \p Target for every \p DLDeviceType,
* though we want to get rid of that limitation.
* \brief The legacy targets map, mapping device type to the corresponding \p Target to use
* when compiling primitive functions. Does not include an entry for the host target, however
* each \p Target in this map will have it's \p host field set to the \p host_target.
*
* Currently we require at most one \p Target per \p DLDeviceType, though we want to get rid of
* that limitation.
*
* CAUTION: Since keys are \p Integers they are compared by object equality not integer
* value.
Expand All @@ -63,13 +66,18 @@ class CompilationConfigNode : public Object {
/*!
* \brief The host target. Used for 'scalar' data and code (such as shapes and shape
* functions) and residual Relay expressions and data (such as conditionals and ADTs).
*
* Note that it is possible for a \p Target used for primitive operations to be structurally
* equal to the host \p Target (up to the \p host field.) However the \p Target objects will
* be distinct, and can be used as keys within a \p Map without collision.
*/
Target host_target;

/*!
* \brief Vector of all available targets for primitive operators. May contain a \p Target
* for the same device type as for the \p host_target, however the \p host_target should
* be preferred for all host computations and data.
* \brief Vector of all available \p Targets for compiling primitive operators. May contain
* a \p Target for the same device type as for the \p host_target, however the \p host_target
* should be used for all host computations and data. Each \p Target will have \p host_target
* as its host.
*/
Array<Target> primitive_targets;

Expand Down
6 changes: 6 additions & 0 deletions include/tvm/target/se_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,12 @@ class SEScope : public ObjectRef {
return SEScope(device.device_type, device.device_id, std::move(target));
}

/*! \brief Returns the \p SEScope for \p target. */
static SEScope ForTarget(Target target) {
return SEScope(static_cast<DLDeviceType>(target->kind->device_type), /*virtual_device_id=*/0,
std::move(target));
}

/*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */
TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target,
MemoryScope memory_scope) {
Expand Down
7 changes: 7 additions & 0 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(self, mod):
self._get_bytecode = self.mod["get_bytecode"]
self._get_constants = self.mod["get_constants"]
self._get_virtual_devices = self.mod["get_virtual_devices"]
self._get_primitives = self.mod["get_primitives"]
self._get_stats = self.mod["get_stats"]
self._get_function_arity = self.mod["get_function_arity"]
self._get_function_param_name = self.mod["get_function_param_name"]
Expand Down Expand Up @@ -257,6 +258,12 @@ def virtual_devices(self):
"""Returns a human-readable description of all the (virtual) devices in the executable."""
return self._get_virtual_devices()

@property
def primitive(self):
"""Returns a human-readable dencription of all the primitives (ie PackedFuncs) in the
executable"""
return self._get_primitives()

@property
def globals(self):
"""Get the globals used by the Relay VM executable.
Expand Down
13 changes: 8 additions & 5 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
ICHECK_EQ((*it).second, var);
} else {
ICHECK(global_var_map_.count(var->name_hint) == 0)
<< "Duplicate global function name " << var->name_hint;
<< "Duplicate global function name " << PrettyPrint(var);
}

global_var_map_.Set(var->name_hint, var);
Expand Down Expand Up @@ -243,7 +243,7 @@ void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData&
if (!update) {
// set global type var map
ICHECK(global_type_var_map_.count(var->name_hint) == 0)
<< "Duplicate global type definition name " << var->name_hint;
<< "Duplicate global type definition name " << PrettyPrint(var);
}
global_type_var_map_.Set(var->name_hint, var);
RegisterConstructors(var, type);
Expand All @@ -266,7 +266,7 @@ void IRModuleNode::Remove(const GlobalVar& var) {

BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
auto it = functions.find(var);
ICHECK(it != functions.end()) << "There is no definition of " << var->name_hint;
ICHECK(it != functions.end()) << "There is no definition of " << PrettyPrint(var);
return (*it).second;
}

Expand All @@ -277,7 +277,7 @@ BaseFunc IRModuleNode::Lookup(const String& name) const {

TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
auto it = type_definitions.find(var);
ICHECK(it != type_definitions.end()) << "There is no definition of " << var->name_hint;
ICHECK(it != type_definitions.end()) << "There is no definition of " << PrettyPrint(var);
return (*it).second;
}

Expand Down Expand Up @@ -306,6 +306,10 @@ String IRModuleNode::GetUniqueName(const String& name) {
}
}

/*!
* \brief Renames global type/term variables to prefer the GlobalTypeVar/GlobalVar in the lhs
* ('one') side above the rhs ('two').
*/
struct Renamer : relay::ExprMutator, TypeMutator {
Map<String, GlobalVar> defs;
Map<String, GlobalTypeVar> types;
Expand Down Expand Up @@ -411,7 +415,6 @@ IRModule IRModule::FromExpr(const RelayExpr& expr, const Map<GlobalVar, BaseFunc
void IRModuleNode::Import(const String& path) {
if (this->import_set_.count(path) == 0) {
this->import_set_.insert(path);
DLOG(INFO) << "Importing: " << path;
std::fstream src_file(path, std::fstream::in);
std::string file_contents{std::istreambuf_iterator<char>(src_file),
std::istreambuf_iterator<char>()};
Expand Down
27 changes: 17 additions & 10 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1909,7 +1909,8 @@ Parser InitParser(const std::string& file_name, const std::string& file_content,

IRModule ParseModule(const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
VLOG(9) << "ParseModule";
VLOG_CONTEXT << "ParseModule";
VLOG(9) << "parsing and type-checking " << file_name;
auto parser = InitParser(file_name, file_content, init_module, init_meta_table);
auto mod = parser.ParseModule();
ICHECK(mod.defined()) << "The parser must return a non-null module.";
Expand Down Expand Up @@ -1952,15 +1953,21 @@ TVM_REGISTER_GLOBAL("parser.ParseExpr")
return ParseExpr(file_name, file_content);
});

TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed([]() {
return CreateModulePass(
[](const IRModule& mod, const PassContext& ctx) {
String text = AsText(mod, /*show_meta_data=*/true);
VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text;
return ParseModule("GeneratedSource", text);
},
0, "AnnotateSpans", {});
});
/*!
* \brief This pass pretty-prints mod then parses it back so as to establish spans and sources
* for all Relay sub-expressions. This improves error and debugging diagnostics downstream for
* modules constructed programaticaly rather than textually.
*/
Pass AnnotateSpans() {
auto pass_func = [](const IRModule& mod, const PassContext& ctx) {
String text = AsText(mod, /*show_meta_data=*/true);
VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text;
return ParseModule("GeneratedSource", text);
};
return CreateModulePass(pass_func, 0, "AnnotateSpans", {});
}

TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed(AnnotateSpans);

} // namespace parser
} // namespace tvm
12 changes: 11 additions & 1 deletion src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,17 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
}

Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return Doc::Text("@" + op->name_hint); }
Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
Doc doc;
doc << "@" << op->name_hint;
#if TVM_LOG_DEBUG
if (op->checked_type_.defined()) {
doc << " /* type=" << PrintType(op->checked_type_, /*meta=*/false) << " */";
}
doc << " /* id=" << reinterpret_cast<uint64_t>(op) << " */";
#endif
return doc;
}

Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); }

Expand Down
20 changes: 18 additions & 2 deletions src/printer/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,29 @@ Doc TextPrinter::PrintMod(const IRModule& mod) {
if (kv.second.as<relay::FunctionNode>()) {
std::ostringstream os;
os << "def @" << kv.first->name_hint;
#if TVM_LOG_DEBUG
os << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
#endif
doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second);
} else if (kv.second.as<tir::PrimFuncNode>()) {
doc << "@" << kv.first->name_hint << " = ";
doc << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
doc << "@" << kv.first->name_hint;
#if TVM_LOG_DEBUG
doc << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
#endif
doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
}
doc << Doc::NewLine();
}
#if TVM_LOG_DEBUG
// attributes
if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
doc << "attributes {" << Doc::NewLine();
for (const auto& kv : mod->attrs->dict) {
doc << " '" << kv.first << "' = " << PrettyPrint(kv.second) << Doc::NewLine();
}
doc << "}" << Doc::NewLine();
}
#endif
return doc;
}

Expand Down
33 changes: 27 additions & 6 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
Expr func;
Array<Expr> args;

if (call_node->op == CallLoweredOp()) {
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
if (call_lowered_props.lowered_func.defined()) {
func = call_lowered_props.lowered_func;
args = call_lowered_props.arguments;
} else { // Relay functions that have not been lowered and lowered extern functions
Expand Down Expand Up @@ -516,10 +516,11 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}
call_lowered_props = CallLoweredProps{GetRef<GlobalVar>(gvn), call_node->args, {}};
} else {
ICHECK(call_node->op == CallLoweredOp()) << "Operators should be transformed away; Try "
"applying the fuse_ops transformation to the "
"expression.";
call_lowered_props = GetCallLoweredProps(call_node);
ICHECK(call_lowered_props.lowered_func.defined())
<< "Operators should be transformed away; Try "
"applying the fuse_ops transformation to the "
"expression.";
for (const auto& arg : call_lowered_props.arguments) {
VisitExpr(arg);
}
Expand Down Expand Up @@ -717,6 +718,14 @@ class AOTExecutorCodegen : public MixedModeVisitor {
: mod_(mod), targets_(targets), target_host_(target_host), use_unpacked_api_(Bool(false)) {}

LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) {
VLOG_CONTEXT << "AOT";
for (const auto& kv : targets_) {
VLOG(1) << "target: " << kv.second->ToDebugString();
}
if (target_host_.defined()) {
VLOG(1) << "target host: " << target_host_->ToDebugString();
}

Executor executor_config = mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
String interface_api = executor_config->GetAttr<String>("interface-api").value_or("packed");
Integer workspace_byte_alignment =
Expand Down Expand Up @@ -793,10 +802,11 @@ class AOTExecutorCodegen : public MixedModeVisitor {
std::make_pair(static_cast<int>(param_storage_ids_[param.first]), param.second)));
}

// Build the TIR IRModule for the AOT function
// Build the TIR IRModule for the main AOT function
Map<GlobalVar, BaseFunc> symbol_map;
symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
IRModule mod_run(symbol_map, {}, {}, {}, mod->attrs);
VLOG(1) << "main module:" << std::endl << PrettyPrint(mod_run);

// Apply storage rewrite pass to the runner function to do memory planning
auto storage_rewrite = tir::transform::StorageRewrite();
Expand Down Expand Up @@ -827,12 +837,23 @@ class AOTExecutorCodegen : public MixedModeVisitor {
ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point.";

// This is the point where we separate the functions in the module by target
VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod);
ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
VLOG(1) << "per-target modules:";
for (const auto& kv : ret.lowered_funcs) {
VLOG(1) << "target:" << std::endl
<< kv.first->ToDebugString() << std::endl
<< "maps to:" << std::endl
<< PrettyPrint(kv.second);
}

ret.external_mods = external_modules.value();

if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) {
VLOG(1) << "merging main into existing module for host target";
ret.lowered_funcs[target_host_]->Update(mod_run);
} else {
VLOG(1) << "adding main into new module for host target";
ret.lowered_funcs.Set(target_host_, mod_run);
}

Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,9 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
std::vector<GraphNodeRef> inputs;
std::string func_name;

if (call->op == CallLoweredOp()) {
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
if (call_lowered_props.lowered_func.defined()) {
// Extract function and arguments from the call_lowered op
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);

func_name = call_lowered_props.lowered_func->name_hint;

Expand Down
5 changes: 3 additions & 2 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,9 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
}

ObjectRef VisitExpr_(const CallNode* call_node) final {
if (call_node->op == CallLoweredOp()) { // Special case: Call a lowered TIR function.
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
if (call_lowered_props.lowered_func.defined()) {
// Special case: Call a lowered TIR function.

// Evaluate only function args
std::vector<ObjectRef> args;
Expand Down
Loading

0 comments on commit 6c2add9

Please sign in to comment.