From 562e93a6998fd09cf87fb5da7e0240fb9bea764c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 21 Jan 2021 18:21:57 -0800 Subject: [PATCH] Refactor the compile engine into a cleaner interface. Duplicate the CompileEngine interface. Refactor the graph_runtime_codegen to invoke the new LowerTE pass More changes Things appear to be working Some tracing to get Relay code to flow through too. Disable some assertions as exp. Tweak printing for now Fix a few bugs: (#13) 1. Don't add relay main function to list of lowered TIR functions 2. Don't skip visiting call to relay function in graph runtime codegen Remove debug prints. Start refactoring Split out shared data structures Fix implicit duplicate decl of IsDynamic Clean up handling of name + global prim fn Clean up the code and debug issue introduced by previous hack Clean up the debugging Do C++ lint clean up Update src/relay/backend/graph_executor_codegen.cc Co-authored-by: Chris Sullivan Clean up handling of external functions Add more error messages More clean up Update src/runtime/graph_executor/graph_executor.cc Co-authored-by: Chris Sullivan Update src/runtime/graph_executor/graph_executor.cc Co-authored-by: Chris Sullivan Update src/relay/backend/te_compiler.h Co-authored-by: Haichen Shen Update src/relay/backend/te_compiler.h Co-authored-by: Haichen Shen Fix CR More CR Format Fix lowering path for C++ Fix tests Remove uncessary change Clean up a few more things CI fix Fix the default context Fix Fix broken test cases Update Fix WIP Clean up storage data structures WIP WIP Fix build errors Remove TVMLower Fix lint Lint again fix black Move UpdateMainWorkspaceSize into te_compiler.cc Fix link errors Formatting Change UpdateMainWorkspaceSize to return Map Workaround for GCC 5 error caused by enums in maps (GCC 5 is on i386 CI) Testing how functions should be named Lint Change how function metadata is updated Attempt to update aot_executor_codegen to use new StaticMemoryPlan instead of storage_device_map Pass memory plan through LowerTE into UpdateMainWorkspaceSize so that we don't need to run GraphPlanMemory an extra time Fix return in UpdateMainWorkspaceSize Lint Try to fix UpdateMainWorkspaceSize Fix construction of static memory plan Clean up code while debugging Adding UpdateWorkspaceSize back Add closure + call to UpdateFunctionMetadata (WIP) UpdateFunctionMetadata builds; weird error with device ctx map though. Not sure if it came from this change or something else Add some debugging of UpdateMainWorkspaceSize Starting to move UpdateFunctionMetadata call to use process_fn infra UWhat target should be passed to UpdateFunctionMetadata? UpdateFunctionMetadata is not workinggg Added some comments about UpdateFunctionMetadata for Jared Fix the creation of function metadata Try another stab at cleaning up the information Fix Port StorageInfo and StaticMemoryPlan data structure (#8297) Restoring reshape opt Fix tests Caught a nasty typo from Lily, Map::Set does not mutate Format Disable stupid Google style warning Rebase cleanup Formatting Add docstring for storage info Black Post rebase fix Remove prints Disable assert that doesn't make sense for now Fix lint Add copying attrs from relay node to graph node; still need to figure out how to do this in the case of global vars Work with Lily to fix graph attrs Try to figure out where extra arguments are coming from; fix merge passes the profiling test Clean up Fix profile test Remove debugging Add attributes for BYOC uTVM case Format Dumb typo Another fix for byoc Format Fix last 3 failing tests Format Fix final two test cases Format Fix lint Fix again Fix Fix auto scheduler code Fix issue Address CR comment Format --- include/tvm/relay/attrs/annotation.h | 12 + .../tvm/auto_scheduler/relay_integration.py | 10 + python/tvm/auto_scheduler/task_scheduler.py | 2 +- python/tvm/relay/backend/compile_engine.py | 4 +- python/tvm/relay/expr.py | 24 +- src/driver/driver_api.cc | 10 +- src/relay/backend/aot_executor_codegen.cc | 18 +- src/relay/backend/compile_engine.cc | 663 ++-------------- src/relay/backend/compile_engine.h | 211 +---- src/relay/backend/graph_executor_codegen.cc | 392 +++------ src/relay/backend/graph_plan_memory.cc | 60 +- src/relay/backend/interpreter.cc | 3 +- src/relay/backend/te_compiler.cc | 743 ++++++++++++++++++ src/relay/backend/te_compiler.h | 196 +++++ src/relay/backend/te_compiler_cache.cc | 694 ++++++++++++++++ src/relay/backend/te_compiler_cache.h | 249 ++++++ src/relay/backend/utils.cc | 47 ++ src/relay/backend/utils.h | 4 + src/relay/backend/vm/compiler.cc | 7 +- src/relay/ir/function.cc | 14 +- .../auto_scheduler_layout_rewrite.cc | 2 +- src/relay/transforms/device_annotation.cc | 44 +- src/relay/transforms/memory_alloc.cc | 13 +- src/relay/transforms/type_infer.cc | 9 +- src/target/llvm/llvm_module.cc | 11 +- src/tir/transforms/legalize_packed_calls.cc | 2 +- .../test_auto_scheduler_task_extraction.py | 1 + .../relay/test_backend_graph_executor.py | 27 +- tests/python/relay/test_pass_annotation.py | 26 +- 29 files changed, 2340 insertions(+), 1158 deletions(-) create mode 100644 src/relay/backend/te_compiler.cc create mode 100644 src/relay/backend/te_compiler.h create mode 100644 src/relay/backend/te_compiler_cache.cc create mode 100644 src/relay/backend/te_compiler_cache.h diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index 4a2eb63c7e6a..8379e6471561 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -67,6 +67,18 @@ struct CompilerAttrs : public tvm::AttrsNode { } }; +/*! + * \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR. + */ +struct TIRCallAttrs : public tvm::AttrsNode { + /*! \brief The metadata attached to the call node. */ + Map metadata; + + TVM_DECLARE_ATTRS(TIRCallAttrs, "relay.attrs.TIRCallAttrs") { + TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function call."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_ANNOTATION_H_ diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 0d18bc08e5ed..4b402a916267 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -318,6 +318,7 @@ def auto_schedule_topi(func_name, outs): A tuned schedule or none (if not tuned) in the final build mode; None in the tracing mode so that the fallback topi schedule will be used. """ + # pylint: disable=import-outside-toplevel from tvm.auto_scheduler.measure import ( prepare_input_map, @@ -376,6 +377,15 @@ def auto_schedule_topi(func_name, outs): return schedule +@tvm._ffi.register_func("auto_scheduler.relay_integration.te_compiler_update_weights") +def te_compiler_update_weights(function_weights): + """A callback for updating the weights of extracted tasks.""" + env = TracingEnvironment.current + if env is not None: + for key in env.wkl_key_to_weight: + env.wkl_key_to_weight[key] = function_weights[key[0]] + + def tensor_no_check_call(self, *indices): """An indexing function without any check. This is the same as `tvm.te.Tensor::__call__` except that the safety diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py index dd5073331083..023fdc770a30 100644 --- a/python/tvm/auto_scheduler/task_scheduler.py +++ b/python/tvm/auto_scheduler/task_scheduler.py @@ -598,7 +598,7 @@ def pre_tune(self, task_scheduler, task_id): # overall info if all(cost < 1e9 for cost in task_scheduler.best_costs): - total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3) + total_latency_str = "%.3f" % (task_scheduler.cur_score.value * 1e3) else: total_latency_str = "-" print( diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 2db8c5a669f0..e9129db7b200 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -429,7 +429,7 @@ def dump(self): res += "------------------------------------\n" res += "target={}\n".format(k.target) res += "use_count={}\n".format(v.use_count) - res += "func_name={}\n".format(v.cached_func.func_name) + res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint) res += "----relay function----\n" res += k.source_func.astext() + "\n" res += "----tir function----- \n" @@ -444,7 +444,7 @@ def dump(self): res += "------------------------------------\n" res += "target={}\n".format(k.target) res += "use_count={}\n".format(v.use_count) - res += "func_name={}\n".format(v.cached_func.func_name) + res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint) res += "----relay function----\n" res += k.source_func.astext() + "\n" res += "----tir function----- \n" diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 8d73a090ed6f..8461885b38ce 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -23,7 +23,7 @@ import tvm._ffi from tvm._ffi import base as _base from tvm.runtime import NDArray, ndarray as _nd -from tvm.ir import RelayExpr, GlobalVar +from tvm.ir import RelayExpr, GlobalVar, Node from .base import RelayNode from . import _ffi_api @@ -538,3 +538,25 @@ def bind(expr, binds): The expression or function after binding. """ return _ffi_api.Bind(expr, binds) + + +@tvm._ffi.register_object("relay.StorageInfo") +class StorageInfo(Node): + """StorageInfo + + The static storage information produced by memory planning. + Contains the storage ids where expressions are stored, the + type of the "virtual devices" the expressions are stored on, + and the sizes of each storage element.""" + + @property + def storage_ids(self): + return _ffi_api.StorageInfoStorageIds(self) + + @property + def device_types(self): + return _ffi_api.StorageInfoDeviceTypes(self) + + @property + def storage_sizes(self): + return _ffi_api.StorageInfoStorageSizes(self) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index cd8173717d5f..50f00140df9b 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -437,14 +437,18 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target } if (target->kind->device_type == kDLCPU && target_host == target) { - ICHECK(mdevice->functions.empty()) << "No device code should be generated when target " - << "and host_target are both llvm target." - << "\n"; + // TODO(@jroesch): This check is no longer true we need to figure out if we care about this. + // We need to relax this check for just TIR functions. + // ICHECK(mdevice->functions.empty()) << "No device code should be generated when target " + // << "and host_target are both llvm target." + // << "\n"; } return {mhost, mdevice}; } +// Can we make this take one annotated IRModule? +// // Build for heterogeneous execution. runtime::Module build(const Map& inputs_arg, const Target& target_host_arg) { auto pass_ctx = transform::PassContext::Current(); diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 9b495adbdea8..84c17b53c83e 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -439,7 +439,7 @@ class AOTExecutorCodegen : public ExprVisitor { fi_node->tir_primfuncs.Set(primfunc_target, primfunc); fi_node->relay_primfuncs.Set(primfunc_target, relay_func); } - function_metadata_.Set(cfunc->func_name, FunctionInfo(fi_node)); + function_metadata_.Set(cfunc->prim_fn_var->name_hint, FunctionInfo(fi_node)); } void VisitExpr_(const CallNode* op) override { @@ -465,20 +465,18 @@ class AOTExecutorCodegen : public ExprVisitor { << "(i.e functions composed of fusable operator invocations)"; } - auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey"); - auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); Target target; // Handle external function if (func->GetAttr(attr::kCompiler).defined()) { target = Target("ext_dev"); - CCacheKey key = (*pf0)(func, target); - CachedFunc ext_func = (*pf1)(compile_engine_, key, mod_name_); + CCacheKey key = CCacheKey(func, target); + CachedFunc ext_func = compile_engine_->Lower(key, mod_name_); ICHECK(ext_func.defined()) << "External function is not defined."; UpdateConstants(func, ¶ms_); // Generate the TIR function call - CreateFuncCall(GetRef(op), ext_func->func_name); + CreateFuncCall(GetRef(op), ext_func->prim_fn_var->name_hint); return; } @@ -503,8 +501,10 @@ class AOTExecutorCodegen : public ExprVisitor { } target = targets_[call_dev_type]; } - CCacheKey key = (*pf0)(func, target); - CachedFunc lowered_func = (*pf1)(compile_engine_, key, mod_name_); + + CCacheKey key = CCacheKey(func, target); + CachedFunc lowered_func = compile_engine_->Lower(key, mod_name_); + if (!lowered_funcs_.count(target->str())) { lowered_funcs_[target->str()] = IRModule(Map({})); } @@ -513,7 +513,7 @@ class AOTExecutorCodegen : public ExprVisitor { UpdateFunctionMetadata(lowered_func, func, target); // Generate the TIR function call - CreateFuncCall(GetRef(op), lowered_func->func_name); + CreateFuncCall(GetRef(op), lowered_func->prim_fn_var->name_hint); } void VisitExpr_(const VarNode* op) override { diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index f0b43b14c650..6142e8323dea 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -46,569 +46,14 @@ #include "../../runtime/meta_data.h" #include "../transforms/pass_utils.h" +#include "te_compiler_cache.h" #include "utils.h" namespace tvm { namespace relay { -TVM_REGISTER_NODE_TYPE(LoweredOutputNode); -TVM_REGISTER_NODE_TYPE(CachedFuncNode); -TVM_REGISTER_NODE_TYPE(CCacheKeyNode); -TVM_REGISTER_NODE_TYPE(CCacheValueNode); TVM_REGISTER_OBJECT_TYPE(CompileEngineNode); -LoweredOutput::LoweredOutput(tvm::Array outputs, OpImplementation impl) { - auto n = make_object(); - n->outputs = std::move(outputs); - n->implementation = std::move(impl); - data_ = std::move(n); -} - -CCacheKey::CCacheKey(Function source_func, Target target) { - auto n = make_object(); - n->source_func = std::move(source_func); - n->target = std::move(target); - data_ = std::move(n); -} - -Array GetShape(const Array& shape) { - // for now, we always use int32 shape when possible - // even if the result of shape inference becomes int64. - Array res; - for (IndexExpr val : shape) { - const int64_t* pval = tir::as_const_int(val); - if (pval != nullptr) { -#ifndef TVM_INDEX_DEFAULT_I64 - ICHECK_LE(pval[0], std::numeric_limits::max()); - ICHECK_GE(pval[0], std::numeric_limits::min()); - res.push_back(IntImm(DataType::Int(32), *pval)); -#else - res.push_back(val); -#endif // TVM_INDEX_DEFAULT_I64 - } else if (val->IsInstance()) { - res.push_back(val.as()->ToVar()); - } else { - res.push_back(val); - } - } - return res; -} - -// The getter to get schedule from compile engine. -// Get schedule from functor. -class ScheduleGetter : public backend::MemoizedExprTranslator> { - public: - explicit ScheduleGetter(Target target) - : target_(target), device_copy_op_(Op::Get("device_copy")) { - // Whether to use auto_scheduler schedule. - use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); - } - - CachedFunc Create(const Function& prim_func) { - auto cache_node = make_object(); - cache_node->target = target_; - for (Var param : prim_func->params) { - Array inputs; - if (const auto* ttype = param->checked_type().as()) { - tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - cache_node->inputs.push_back(tensor); - inputs.push_back(tensor); - } else { - // flatten tuple of tensor type. - const auto* tuple_type = param->type_as(); - for (Type field : tuple_type->fields) { - const auto* ttype = field.as(); - // TODO(@icemelon): Allow recursive tuple - ICHECK(ttype != nullptr); - tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - cache_node->inputs.push_back(tensor); - inputs.push_back(tensor); - } - } - memo_[param] = inputs; - } - readable_name_stream_ << "fused"; - cache_node->outputs = this->VisitExpr(prim_func->body); - auto candidate_name = readable_name_stream_.str(); - constexpr static size_t kMaxFuncNameLength = 80; - if (candidate_name.size() > kMaxFuncNameLength) { - std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); - truncated_name << "_" << std::hash{}(candidate_name) << "_"; - candidate_name = truncated_name.str(); - } - cache_node->func_name = candidate_name; - ICHECK(anchor_op_.defined()); - // Fusion over tupled results may leave identity relationships - // between inputs and outputs, and those should not be scheduled. - // Hence schedule only non PlaceholderOp outputs. - tvm::Array tensor_outs; - for (const auto& tensor : cache_node->outputs) { - if (!tensor->op.as()) { - tensor_outs.push_back(tensor); - } - } - - te::Schedule schedule; - // No need to register schedule for device copy op. - if (anchor_attrs_.as() == nullptr) { - if (use_auto_scheduler_) { - const auto* fauto_schedule = - runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); - ICHECK(fauto_schedule != nullptr) - << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; - ObjectRef obj = (*fauto_schedule)(String(cache_node->func_name), tensor_outs); - if (obj.defined()) { - schedule = Downcast(obj); - } - } - - // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. - if (!schedule.defined()) { - ICHECK(anchor_implementation_.defined()); - schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); - } - for (const auto& scalar : scalars_) { - if (schedule->Contain(scalar)) { - schedule[scalar].compute_inline(); - } - } - } - cache_node->schedule = std::move(schedule); - return CachedFunc(cache_node); - } - - Array VisitExpr_(const VarNode* op) final { - LOG(FATAL) << "Free variable " << op->name_hint(); - return {}; - } - - Array VisitExpr_(const ConstantNode* op) final { - using tir::make_const; - ICHECK(op->is_scalar()); - void* data = op->data->data; - DataType dtype = DataType(op->data->dtype); - auto value = te::compute( - {}, - [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, - "compile_engine_const", topi::kBroadcast); - scalars_.push_back(value->op); - return {value}; - } - - Array VisitExpr_(const CallNode* call_node) final { - static auto fpattern = Op::GetAttrMap("TOpPattern"); - static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); - ICHECK(flower_call) << "relay.backend.lower_call is not registered."; - - Array inputs; - int count_tuple = 0; - for (Expr arg : call_node->args) { - if (arg->checked_type().as()) { - ++count_tuple; - } - for (te::Tensor tensor : VisitExpr(arg)) { - inputs.push_back(tensor); - } - } - if (count_tuple) { - ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; - } - - ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; - Op op = Downcast(call_node->op); - - Array outputs; - OpImplementation impl; - // Skip fcompute for device copy operators as it is not registered. - if (op == device_copy_op_) { - const auto* copy_input = inputs[0].operator->(); - outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0)); - } else { - LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); - outputs = lowered_out->outputs; - impl = lowered_out->implementation; - } - - int op_pattern = fpattern[op]; - if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { - ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) - << "Cannot apply TOPI schedule to a primitive function with two complicated ops" - << " anchor=" << anchor_op_ << " current=" << op; - } - if (op_pattern > anchor_op_pattern_) { - anchor_op_ = op; - anchor_attrs_ = call_node->attrs; - anchor_op_pattern_ = op_pattern; - anchor_implementation_ = impl; - } - if (outputs.size() != 1) { - const auto* tuple_type = call_node->checked_type().as(); - ICHECK(tuple_type) << "Expect output to be a tuple type"; - ICHECK_EQ(tuple_type->fields.size(), outputs.size()); - } - // Set the name to `__copy`. It will be detected in graph executor to perform - // data copy across devices. - if (op == device_copy_op_) { - readable_name_stream_.str(std::string()); - readable_name_stream_ << "__copy"; - } else { - readable_name_stream_ << '_' << op->name; - } - return outputs; - } - - Array VisitExpr_(const FunctionNode* op) final { - LOG(FATAL) << "Do not support sub function"; - return Array(); - } - - Array VisitExpr_(const LetNode* op) final { - Array val = VisitExpr(op->value); - ICHECK(!memo_.count(op->var)); - memo_[op->var] = val; - return VisitExpr(op->body); - } - - Array VisitExpr_(const TupleNode* op) final { - Array fields; - for (Expr field : op->fields) { - ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; - Array res = VisitExpr(field); - ICHECK_EQ(res.size(), 1); - fields.push_back(res[0]); - } - return fields; - } - - Array VisitExpr_(const TupleGetItemNode* op) final { - const auto* tuple_type = op->tuple->type_as(); - Array tuple = VisitExpr(op->tuple); - ICHECK_EQ(tuple_type->fields.size(), tuple.size()); - ICHECK_GE(op->index, 0); - ICHECK_LT(static_cast(op->index), tuple.size()); - return {tuple[op->index]}; - } - - private: - tvm::Target target_; - Op anchor_op_; - Attrs anchor_attrs_; - int anchor_op_pattern_{-1}; - OpImplementation anchor_implementation_; - std::ostringstream readable_name_stream_; - Array scalars_; - bool use_auto_scheduler_; - // Cache device copy op for equivalence checking to reduce registry lookup - // overhead for each invocation of call node when retrieving schedules. - const Op& device_copy_op_; -}; - -/*! - * \brief Create schedule for target. - * \param source_func The primitive function to be lowered. - * \param target The target we want to create schedule for. - * \return Pair of schedule and cache. - * The funcs field in cache is not yet populated. - */ -CachedFunc CreateSchedule(const Function& source_func, const Target& target) { - return ScheduleGetter(target).Create(source_func); -} - -// Creates shape function from functor. -class MakeShapeFunc : public backend::MemoizedExprTranslator> { - public: - MakeShapeFunc() {} - - std::pair Create(const Function& prim_func) { - for (auto param : prim_func->params) { - param_states_[param] = kNoNeed; - Array data_inputs; - Array shape_inputs; - - auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) { - // Add data placeholder - Shape shape = GetShape(ttype->shape); - tvm::te::Tensor data_tensor = tvm::te::placeholder(shape, ttype->dtype); - data_inputs.push_back(data_tensor); - // Add shape placeholder - int64_t ndim = shape.size(); - Shape sshape; - if (ndim > 0) { - sshape.push_back(tvm::Integer(ndim)); - } - tvm::te::Tensor shape_tensor = tvm::te::placeholder(sshape, DataType::Int(64)); - shape_inputs.push_back(shape_tensor); - }; - - if (const auto* ttype = param->checked_type().as()) { - add_placeholder(ttype); - } else { - // flatten tuple of tensor type. - const auto* tuple_type = param->type_as(); - // TODO(@icemelon): Support recursive tuple - ICHECK(tuple_type); - for (Type field : tuple_type->fields) { - const auto* ttype = field.as(); - ICHECK(ttype); - add_placeholder(ttype); - } - } - param_data_[param] = data_inputs; - param_shapes_[param] = shape_inputs; - } - readable_name_stream_ << "shape_func"; - auto cache_node = make_object(); - cache_node->outputs = VisitExpr(prim_func->body); - auto candidate_name = readable_name_stream_.str(); - constexpr static size_t kMaxFuncNameLength = 80; - if (candidate_name.size() > kMaxFuncNameLength) { - std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); - truncated_name << "_" << std::hash{}(candidate_name) << "_"; - candidate_name = truncated_name.str(); - } - cache_node->func_name = candidate_name; - - // set inputs - for (auto param : prim_func->params) { - int state = param_states_[param]; - cache_node->shape_func_param_states.push_back(IntImm(DataType::Int(32), state)); - if (state & kNeedInputData) { - for (auto t : param_data_[param]) { - cache_node->inputs.push_back(t); - } - } - if (state & kNeedInputShape) { - for (auto t : param_shapes_[param]) { - cache_node->inputs.push_back(t); - } - } - } - - CachedFunc cfunc(cache_node); - // generate schedule for shape func - Array out_ops; - for (auto t : cache_node->outputs) { - out_ops.push_back(t->op); - } - auto schedule = te::create_schedule(out_ops); - tvm::te::AutoInlineInjective(schedule); - for (const auto& scalar : scalars_) { - auto scalar_op = scalar->op; - if (schedule->Contain(scalar_op)) { - schedule[scalar_op].compute_inline(); - } - } - return std::make_pair(schedule, cfunc); - } - - Array VisitExpr(const Expr& expr) final { - if (expr.as()) { - // Do not memoize vars because shape functions could use either the data - // or the shape of a var each time. - return ExprFunctor::VisitExpr(expr); - } - // For other case, do memoized visit - return backend::MemoizedExprTranslator>::VisitExpr(expr); - } - - Array VisitExpr_(const VarNode* var_node) final { - auto var = GetRef(var_node); - auto it = param_states_.find(var); - if (it == param_states_.end()) { - LOG(FATAL) << "Free variable " << var->name_hint(); - return {}; - } else { - ICHECK(data_dependents_per_input_.size()); - auto data_dependent = data_dependents_per_input_.back(); - if (data_dependent) { - param_states_[var] |= kNeedInputData; - return param_data_[var]; - } else { - param_states_[var] |= kNeedInputShape; - return param_shapes_[var]; - } - } - } - - Array VisitExpr_(const ConstantNode* op) final { - using tir::make_const; - ICHECK(data_dependents_per_input_.size()); - bool data_dependent = data_dependents_per_input_.back(); - if (!op->is_scalar()) { - // This is a constant weight, extract the shape of the weight tensor. - // This can not be data dependent. - CHECK(!data_dependent); - auto ttype = op->checked_type().as(); - int ndim = static_cast(ttype->shape.size()); - Array out_shape{ndim}; - te::Tensor value = tvm::te::compute( - out_shape, - [&](const Array& indices) { - auto idx = indices[0]; - PrimExpr ret = make_const(DataType::Int(64), 0); - for (int i = 0; i < ndim; i++) { - ret = tvm::if_then_else(idx == i, ttype->shape[i], ret); - } - return ret; - }, - "shape_const", topi::kBroadcast); - scalars_.push_back(value); - return {value}; - } - if (data_dependent) { - void* data = op->data->data; - DataType dtype = DataType(op->data->dtype); - auto value = tvm::te::compute( - {}, - [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, - "data_const", topi::kBroadcast); - scalars_.push_back(value); - return {value}; - } else { - auto value = tvm::te::compute( - {}, [&](const Array&) { return tir::make_const(DataType::Int(64), 0); }, - "shape_const", topi::kBroadcast); - scalars_.push_back(value); - return {value}; - } - } - - Array VisitExpr_(const CallNode* call_node) final { - static auto fshape_func = Op::GetAttrMap("FShapeFunc"); - static auto tshape_data_dependent = Op::GetAttrMap("TShapeDataDependent"); - ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; - Op op = Downcast(call_node->op); - ICHECK(data_dependents_per_input_.empty() || !data_dependents_per_input_.back()) - << "Error in op fusion: output of the shape func is fed to a " - << "data-dependent shape func"; - ICHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name; - ICHECK_GT(tshape_data_dependent.count(op), 0) - << "Internal error, cannot find TShapeDataDependent for " << op->name; - - Array dep_spec = tshape_data_dependent[op]; - if (dep_spec.size() == 1) { - // This is for cases when data dependence is specified per op - // Replicate 0 or 1 flag to all arguments - for (size_t i = 1; i < call_node->args.size(); ++i) { - dep_spec.push_back(dep_spec[0]); - } - } - - // Visit all inputs - Array inputs; - int count_tuple = 0; - for (size_t i = 0; i < call_node->args.size(); ++i) { - Expr arg = call_node->args[i]; - if (arg->checked_type().as()) { - ++count_tuple; - } - data_dependents_per_input_.push_back(dep_spec[i]->value != 0); - for (te::Tensor tensor : VisitExpr(arg)) { - inputs.push_back(tensor); - } - data_dependents_per_input_.pop_back(); - } - if (count_tuple) { - ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; - } - // Get output ndims - auto ret_type = call_node->checked_type(); - Array out_ndims; - if (const auto* ttype = ret_type.as()) { - out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); - } else { - auto rtype = ret_type.as(); - // TODO(@icemelon): Allow recursive tuple - ICHECK(rtype); - for (size_t i = 0; i < rtype->fields.size(); ++i) { - auto ttype = rtype->fields[i].as(); - ICHECK(ttype); - out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); - } - } - // Call shape function - auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims); - readable_name_stream_ << "_" << op->name; - return outputs; - } - - Array VisitExpr_(const FunctionNode* op) final { - LOG(FATAL) << "Do not support sub function"; - return Array(); - } - - Array VisitExpr_(const LetNode* op) final { - Array val = VisitExpr(op->value); - ICHECK(!memo_.count(op->var)); - memo_[op->var] = val; - return VisitExpr(op->body); - } - - Array VisitExpr_(const TupleNode* op) final { - Array fields; - for (Expr field : op->fields) { - ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; - Array res = VisitExpr(field); - ICHECK_EQ(res.size(), 1); - fields.push_back(res[0]); - } - return fields; - } - - Array VisitExpr_(const TupleGetItemNode* op) final { - Array input_shapes = VisitExpr(op->tuple); - Array out; - out.push_back(input_shapes[op->index]); - return out; - } - - private: - /*! \brief String stream for function name */ - std::ostringstream readable_name_stream_; - /*! \brief Map from parameter to its shape function usage state */ - std::unordered_map param_states_; - /*! \brief Map from parameter to list of data placeholder */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_data_; - /*! \brief Map from parameter to list of shape placeholder */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_shapes_; - /*! \brief Stack of data dependencies for shape function, specified per each op input */ - std::vector data_dependents_per_input_; - /*! \brief Scalars used in the shape function */ - Array scalars_; -}; - class CompileEngineImpl : public CompileEngineNode { public: // Lower the function. @@ -616,19 +61,19 @@ class CompileEngineImpl : public CompileEngineNode { return LowerInternal(key, mangle_fn)->cached_func; } + CachedFunc Lower(const CCacheKey& key, const String mod_name) { + auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); }; + + return Lower(key, mangle_fn); + } + // For now, build one module per function. PackedFunc JIT(const CCacheKey& key) final { auto mangle_fn = [](String name) { return name; }; CCacheValue value = LowerInternal(key, mangle_fn); if (value->packed_func != nullptr) return value->packed_func; - // build the function. - tvm::runtime::Module m; - if (const auto* f = runtime::Registry::Get("relay.backend.build")) { - m = (*f)(value->cached_func->funcs, key->target); - } else { - m = build(value->cached_func->funcs, key->target, Target(nullptr)); - } - value->packed_func = m.GetFunction(value->cached_func->func_name); + auto m = build(value->cached_func->funcs, key->target, Target(nullptr)); + value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint); return value->packed_func; } @@ -643,6 +88,7 @@ class CompileEngineImpl : public CompileEngineNode { for (const auto& it : cache_) { auto src_func = it.first->source_func; ICHECK(src_func.defined()); + if (src_func->GetAttr(attr::kCompiler).defined()) { auto code_gen = src_func->GetAttr(attr::kCompiler); ICHECK(code_gen.defined()) << "No external codegen is set"; @@ -651,7 +97,9 @@ class CompileEngineImpl : public CompileEngineNode { auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" - << AsText(src_func, false); + << AsText(src_func, false) << "\n" + << "Functions with external codegen must have the " + << tvm::attr::kGlobalSymbol << " attr set."; std::string sn = symbol_name.value(); if (!cached_symbol.count(sn)) { @@ -669,7 +117,12 @@ class CompileEngineImpl : public CompileEngineNode { src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue()); runtime::Module ext_mod = (*pf)(src_func); - ICHECK(ext_mod.defined()) << "No external runtime is generated."; + // todo(@zhiics, @jroesch): Should this be a user visible error? + ICHECK(ext_mod.defined()) << "No external library was generated for " << ext_name + << "even though it was requested" + "by the annotated function " + << PrettyPrint(src_func); + ret.push_back(ext_mod); } } @@ -734,44 +187,49 @@ class CompileEngineImpl : public CompileEngineNode { // No need to lower external functions for now. We will invoke the external // codegen tool once and lower all functions together. if (key->source_func->GetAttr(attr::kCompiler).defined()) { - auto cache_node = make_object(); + auto ir_module = IRModule(); const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "External function has not been attached a name yet."; - cache_node->func_name = std::string(name_node.value()); - cache_node->target = Target("ext_dev"); - cache_node->funcs->Add(GlobalVar(cache_node->func_name), key->source_func); - value->cached_func = CachedFunc(cache_node); + auto func_name = std::string(name_node.value()); + auto target = Target("ext_dev"); + auto global_var = GlobalVar(func_name); + global_var->checked_type_ = key->source_func->checked_type(); + ir_module->Add(global_var, key->source_func); + value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); return value; } + // Enforce use the target. With target_scope(key->target); ICHECK(!value->cached_func.defined()); - auto cfunc = CreateSchedule(key->source_func, key->target); - auto cache_node = make_object(*(cfunc.operator->())); + auto cfunc = PrimFuncFor(key->source_func, key->target, [&](std::string name) { + return GetUniqueName(mangle_fn(name), &name_map_); + }); // Skip lowering for device copy node. const Expr body = (key->source_func)->body; if (const CallNode* call_node = body.as()) { if (call_node->attrs.as()) { - value->cached_func = CachedFunc(cache_node); + value->cached_func = cfunc; return value; } } - cache_node->func_name = GetUniqueName(mangle_fn(cache_node->func_name)); // NOTE: array will copy on write. - Array all_args = cache_node->inputs; - for (te::Tensor arg : cache_node->outputs) { + Array all_args = Array(cfunc->inputs); + for (te::Tensor arg : cfunc->outputs) { all_args.push_back(arg); } // lower the function std::unordered_map binds; - cache_node->funcs = tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds); + auto func_name = cfunc->prim_fn_var->name_hint; + cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); + value->cached_func = cfunc; - value->cached_func = CachedFunc(cache_node); return value; } + // implement lowered shape func CCacheValue LowerShapeFuncInternal(const CCacheKey& key) { std::lock_guard lock(mutex_); @@ -790,47 +248,17 @@ class CompileEngineImpl : public CompileEngineNode { With target_scope(key->target); ICHECK(!value->cached_func.defined()); - auto spair = MakeShapeFunc().Create(key->source_func); - auto cache_node = make_object(*(spair.second.operator->())); - cache_node->func_name = GetUniqueName(cache_node->func_name); - cache_node->target = key->target; - - Array all_args = cache_node->inputs; - for (te::Tensor arg : cache_node->outputs) { - all_args.push_back(arg); - } - using tvm::transform::PassContext; With fresh_pass_ctx_scope(PassContext::Create()); - std::unordered_map binds; - cache_node->funcs = tvm::LowerSchedule(spair.first, all_args, cache_node->func_name, binds); - value->cached_func = CachedFunc(cache_node); + auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) { + return GetUniqueName(name, &name_map_); + }); + + value->cached_func = cached_func; return value; } - /*! - * \brief Get unique name from name. - * \param name The orginal name. - * \return Updated name which is unique. - */ - std::string GetUniqueName(std::string name) { - for (size_t i = 0; i < name.length(); ++i) { - if (name[i] == '.') name[i] = '_'; - } - while (true) { - auto it = name_map_.find(name); - if (it == name_map_.end()) { - name_map_[name] = 1; - return name; - } else { - std::ostringstream os; - os << name << "_" << it->second; - ++(it->second); - name = os.str(); - } - } - return name; - } + /*! \brief compiler cache lock*/ std::mutex mutex_; /*! \brief internal name map to get an unique name */ @@ -874,10 +302,7 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](Compi TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") .set_body_typed([](CompileEngine self, CCacheKey key, const String mod_name) { - auto mangle_fn = [mod_name](String name) { - return runtime::get_name_mangled(mod_name, name); - }; - return self->Lower(key, mangle_fn); + return self->Lower(key, mod_name); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index f766fcf97ea7..4afdc6d30485 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -19,8 +19,12 @@ /*! * \file relay/backend/compile_engine.h - * \brief Internal compialtion engine handle function cache. - * and interface to low level code generation. + * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. + * + * This layer represents the older design of the Relay compilation flow and is being deprecated + * in favor of te_compiler.h which is a migration step towards a standard pass based lowering of + * Relay functions. + * */ #ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ #define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ @@ -36,157 +40,12 @@ #include #include +#include "te_compiler_cache.h" + namespace tvm { namespace relay { -/*! \brief Indicate whether the data or shape or both of a parameter is used in the shape func. */ -enum ShapeFuncParamState { - kNoNeed = 0, - kNeedInputData = 1, - kNeedInputShape = 2, - kNeedBoth = 3, -}; - -struct LoweredOutputNode : public Object { - /*! \brief The outputs to the function */ - tvm::Array outputs; - /*! \brief The implementation used to compute the output */ - OpImplementation implementation; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("outputs", &outputs); - v->Visit("implementation", &implementation); - } - - static constexpr const char* _type_key = "relay.LoweredOutput"; - TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object); -}; - -class LoweredOutput : public ObjectRef { - public: - TVM_DLL LoweredOutput(tvm::Array outputs, OpImplementation impl); - - TVM_DEFINE_OBJECT_REF_METHODS(LoweredOutput, ObjectRef, LoweredOutputNode); -}; - -/*! \brief Node container to represent a cached function. */ -struct CachedFuncNode : public Object { - /* \brief compiled target */ - tvm::Target target; - /*! \brief Function name */ - std::string func_name; - /* \brief The inputs to the function */ - tvm::Array inputs; - /* \brief The outputs to the function */ - tvm::Array outputs; - /*! \brief The schedule to the function */ - te::Schedule schedule; - /*! \brief The lowered functions to support the function. */ - IRModule funcs = IRModule(Map({})); - - /*! \brief Parameter usage states in the shape function. */ - tvm::Array shape_func_param_states; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("target", &target); - v->Visit("func_name", &func_name); - v->Visit("inputs", &inputs); - v->Visit("outputs", &outputs); - v->Visit("schedule", &schedule); - v->Visit("funcs", &funcs); - v->Visit("shape_func_param_states", &shape_func_param_states); - } - - static constexpr const char* _type_key = "relay.CachedFunc"; - TVM_DECLARE_FINAL_OBJECT_INFO(CachedFuncNode, Object); -}; - -class CachedFunc : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(CachedFunc, ObjectRef, CachedFuncNode); -}; - -class CCacheKey; -/*! \brief Compile cache key */ -class CCacheKeyNode : public Object { - public: - /*! \brief The source function to be lowered. */ - Function source_func; - /*! \brief The hardware target.*/ - Target target; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("source_func", &source_func); - v->Visit("target", &target); - } - /*! \return The hash value of CCacheKey. */ - inline size_t Hash() const; - /*! - * \brief check content equality - * \param other The other value. - * \return The result of equality check. - */ - inline bool Equal(const CCacheKeyNode* other) const; - - static constexpr const char* _type_key = "relay.CCacheKey"; - TVM_DECLARE_FINAL_OBJECT_INFO(CCacheKeyNode, tvm::Object); - - private: - /*! - * \brief internal cached hash value. - */ - mutable size_t hash_{0}; -}; - -/*! \brief cache entry used in compile engine */ -class CCacheKey : public ObjectRef { - public: - CCacheKey() {} - explicit CCacheKey(ObjectPtr n) : ObjectRef(n) {} - - /*! - * \brief The constructor - * \param source_func The source function. - * \param target The target device. - */ - TVM_DLL CCacheKey(Function source_func, Target target); - - const CCacheKeyNode* operator->() const { return static_cast(get()); } - // comparator - inline bool operator==(const CCacheKey& other) const { - ICHECK(defined() && other.defined()); - return (*this)->Equal(other.operator->()); - } - using ContainerType = CCacheKeyNode; -}; - -/*! \brief Node container for compile cache. */ -class CCacheValueNode : public Object { - public: - /*! \brief The corresponding function */ - CachedFunc cached_func; - /*! \brief Result of Packed function generated by JIT */ - PackedFunc packed_func; - /*! \brief usage statistics */ - int use_count{0}; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("cached_func", &cached_func); - v->Visit("use_count", &use_count); - } - static constexpr const char* _type_key = "relay.CCacheValue"; - TVM_DECLARE_FINAL_OBJECT_INFO(CCacheValueNode, tvm::Object); -}; - -/*! \brief cache entry used in compile engine */ -class CCacheValue : public ObjectRef { - public: - CCacheValue() {} - explicit CCacheValue(ObjectPtr n) : ObjectRef(n) {} - CCacheValueNode* operator->() { return static_cast(get_mutable()); } - const CCacheValueNode* operator->() const { return static_cast(get()); } - using ContainerType = CCacheValueNode; -}; +using namespace tvm::relay::tec; /*! * \brief Backend compilation engine for @@ -199,10 +58,18 @@ class CompileEngineNode : public Object { /*! * \brief Get lowered result. * \param key The key to the cached function. - * \param mod_name The module name to mangle the functions + * \param mod_name The mangling function for mangling names. * \return The result. */ virtual CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) = 0; + + /*! + * \brief Get lowered result. + * \param key The key to the cached function. + * \param mod_name The module name to mangle the functions. + * \return The result. + */ + virtual CachedFunc Lower(const CCacheKey& key, const String mangle_fn) = 0; /*! * \brief Just in time compile to get a PackedFunc. * \param key The key to the cached function. @@ -242,49 +109,7 @@ class CompileEngine : public ObjectRef { TVM_DLL static CompileEngine& Global(); }; -/*! - * \brief Create schedule for target. - * \param source_func The primitive function to be lowered. - * \param target The target we want to create schedule for. - * \return Pair of schedule and cache. - * The funcs field in cache is not yet populated. - */ -CachedFunc CreateSchedule(const Function& source_func, const Target& target); - -/*! - * \brief Check if the type is dynamic. - * \param ty The type to be checked. - * \return The result. - */ -bool IsDynamic(const Type& ty); - -// implementations -inline size_t CCacheKeyNode::Hash() const { - if (hash_ != 0) return hash_; - // do structral hash, avoid 0. - hash_ = tvm::StructuralHash()(this->source_func); - hash_ = dmlc::HashCombine(hash_, std::hash()(target->str())); - if (hash_ == 0) hash_ = 1; - return hash_; -} - -inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { - if (Hash() != other->Hash()) return false; - return this->target->str() == other->target->str() && - tvm::StructuralEqual()(this->source_func, other->source_func); -} - } // namespace relay } // namespace tvm -namespace std { -// overload hash -template <> -struct hash<::tvm::relay::CCacheKey> { - size_t operator()(const ::tvm::relay::CCacheKey& key) const { - ICHECK(key.defined()); - return key->Hash(); - } -}; -} // namespace std #endif // TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index bca8e8244093..9d59e8e5f3a8 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -36,10 +37,13 @@ #include #include "compile_engine.h" +#include "te_compiler.h" #include "utils.h" namespace tvm { namespace relay { +// TODO(@jroesch, @csullivan): declare directly elsewhere +backend::StaticMemoryPlan GraphPlanMemory(const Function& func); namespace backend { class GraphNode; @@ -52,7 +56,6 @@ using GraphAttrs = std::unordered_map; using GraphObjectPtr = std::shared_ptr; using GraphInputObjectPtr = std::shared_ptr; using GraphOpObjectPtr = std::shared_ptr; -using TargetsMap = std::unordered_map; /*! \brief Node types */ enum GraphNodeType { @@ -176,112 +179,86 @@ class GraphOpNode : public GraphNode { const std::string op_type_name_{"tvm_op"}; }; -/*! \brief Code generator for graph executor */ +/*! \brief Code generator for the graph executor, produces a module containing the graph JSON, + * module, and parameters. + */ class GraphExecutorCodegen : public backend::MemoizedExprTranslator> { public: - GraphExecutorCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) { - compile_engine_ = CompileEngine::Global(); + GraphExecutorCodegen(runtime::Module* mod, const TargetMap& targets) : mod_(mod) { targets_ = targets; } - /*! - * \brief Update the "main" control function's metadata - * - * \param func The main function that contains calls to relay primitive functions - */ - void UpdateMainWorkspaceSize(const Function& func) { - // This is a Map> - std::unordered_map> sid_workspace; - // This is a Map - std::unordered_map device_io; - // This is a Map - std::unordered_map device_consts; - - // Initialize the maps to zero - for (const auto& kv : storage_device_map_) { - auto sids = kv.second[0]; - auto devices = kv.second[1]; - CHECK_EQ(sids.size(), devices.size()); - for (uint32_t i = 0; i < sids.size(); i++) { - sid_workspace[devices[i]][sids[i]] = 0; - device_io[devices[i]] = 0; - device_consts[devices[i]] = 0; - } - } - - // Collect sizes of tensors - for (const auto& kv : storage_device_map_) { - auto size_bytes = CalculateRelayExprSizeBytes(kv.first->checked_type()); - auto sids = kv.second[0]; - auto devices = kv.second[1]; - if (kv.first->IsInstance()) { - for (const auto& dev : devices) { - device_consts[dev] += size_bytes; - } - continue; - } else if (kv.first->IsInstance() || kv.first == func->body) { - for (const auto& dev : devices) { - device_io[dev] += size_bytes; - } - continue; - } - for (uint32_t i = 0; i < sids.size(); i++) { - // Here we record the largest size of the tensor - // that share the same storage id, because storage_id will - // be shared between multiple tensors that are not live simultaneously. - if (size_bytes > sid_workspace[devices[i]][sids[i]]) { - sid_workspace[devices[i]][sids[i]] = size_bytes; - } - } - } - - // This is a Map - std::unordered_map device_workspace; - // Once we know the sizes of sids, we need to accumulate per device - for (const auto& dev_sid_size : sid_workspace) { - auto dev = dev_sid_size.first; - device_workspace[dev] = 0; - for (const auto& sid_size : dev_sid_size.second) { - device_workspace[dev] += sid_size.second; - } - } - - // Populate FunctionInfo - auto fi_node = make_object(); - // Initialize all target workspaces to zero - for (const auto& kv : targets_) { - auto tgt = kv.second; - fi_node->workspace_sizes.Set(tgt, 0); - } - for (const auto& dev_and_size : device_workspace) { - auto tgt = GetTargetFromInteger(dev_and_size.first); - fi_node->workspace_sizes.Set(tgt, dev_and_size.second); - fi_node->relay_primfuncs.Set(tgt, func); - } - for (const auto& dev_and_size : device_io) { - auto tgt = GetTargetFromInteger(dev_and_size.first); - fi_node->io_sizes.Set(tgt, dev_and_size.second); - } - for (const auto& dev_and_size : device_consts) { - auto tgt = GetTargetFromInteger(dev_and_size.first); - fi_node->constant_sizes.Set(tgt, dev_and_size.second); - } - - function_metadata_.Set(String(runtime::symbol::tvm_module_main), FunctionInfo(fi_node)); + StorageInfo GetStorageInfo(const Expr& e) { + size_t count = memory_plan_->expr_to_storage_info.count(e); + ICHECK_GT(count, 0) << "Expr is not existing in storage plan"; + auto storage_info = memory_plan_->expr_to_storage_info[e]; + return storage_info; } LoweredOutput Codegen(relay::Function func, String mod_name) { - auto pf = GetPackedFunc("relay.backend.GraphPlanMemory"); - storage_device_map_ = (*pf)(func); mod_name_ = mod_name; - UpdateMainWorkspaceSize(func); + + // TODO(@jroesch): we need to split device planning and memory planning + // first we run device assignment, then we perform lowering, and then + // storage planning in ideal world. + + memory_plan_ = GraphPlanMemory(func); + + // This first phase moves from implicit use of compile engine, + // to instead explicitly lowering the incoming IRModule, and then + // performing the preexisting graph executor code generation phase. + IRModule mod = IRModule::FromExpr(func); + + // Build a map from each operation to device. + tec::DeviceMap device_context_map; + for (const auto& it : memory_plan_->expr_to_storage_info) { + auto expr = it.first; + auto storage_info = it.second; + auto device_types = storage_info->device_types; + // CHECK_EQ(device_types.size(), 1); + tvm::Device dev; + dev.device_id = 0; + dev.device_type = device_types[0]; + device_context_map.insert({expr, dev}); + } + + auto lowered_module = tec::LowerTE(mod, targets_, device_context_map, memory_plan_, mod_name_, + [this](Function func) { + // We need to maintain the constant map for external + // functions so we pass this processing function which + // allows us to process each function as we lower it. + if (func->GetAttr(attr::kCompiler).defined()) { + UpdateConstants(func, ¶ms_); + } + + // TODO(@areusch, @jroesch): We should refactor this to + // execute as a further pass, instead writing data to the + // lowering process directly. + UpdateFunctionMetadata(func, this->function_metadata_); + }); + + function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info); + auto main_module = lowered_module.main_module; + main_module = relay::transform::InferType()(main_module); + relay::Function main_func = Downcast(main_module->Lookup("main")); + + // Now that we have lowered all operators to TIR code, we can proceed with compilation. + // + // We need to unfortunately re-plan as the previous results have been invalidated by lowering + // we will fix this in future refactors. + memory_plan_ = GraphPlanMemory(main_func); + + // The graph planner also can not handle planning calls to global variables to we must remap + // First we convert all the parameters into input nodes. - for (auto param : func->params) { + for (auto param : main_func->params) { auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs()); var_map_[param.get()] = AddNode(node_ptr, param); } - heads_ = VisitExpr(func->body); + + heads_ = VisitExpr(main_func->body); std::ostringstream os; + dmlc::JSONWriter writer(&os); GetJSON(&writer); LoweredOutput ret; @@ -292,17 +269,9 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator(param_storage_ids_[param.first]), param.second))); } - - for (auto& kv : lowered_funcs_) { - if (ret.lowered_funcs.count(kv.first) == 0) { - ret.lowered_funcs.Set(kv.first, IRModule(Map({}))); - } - auto& mod = ret.lowered_funcs[kv.first]; - mod->Update(kv.second); - ret.lowered_funcs.Set(kv.first, mod); - } - ret.external_mods = compile_engine_->LowerExternalFunctions(); ret.function_metadata = std::move(function_metadata_); + ret.lowered_funcs = lowered_module.per_target_module; + ret.external_mods = lowered_module.external_mods; return ret; } @@ -331,20 +300,18 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator AddNode(GraphObjectPtr node, Expr expr) { auto checked_type = expr->checked_type(); - size_t count = storage_device_map_.count(expr); - ICHECK_GT(count, 0) << "Expr is not existing in storage plan"; - auto storage_device_info = storage_device_map_[expr]; - ICHECK_EQ(storage_device_info.size(), 3); + + auto storage_info = GetStorageInfo(expr); // storage - std::vector storage_info; - for (auto& v : storage_device_info[0]) { - storage_info.push_back(v->value); + std::vector storage_ids; + for (auto v : storage_info->storage_ids) { + storage_ids.push_back(v); } - node->attrs_["storage_id"] = std::move(storage_info); + node->attrs_["storage_id"] = std::move(storage_ids); // type std::vector device_types; - for (auto& v : storage_device_info[1]) { - device_types.push_back(v->value); + for (auto v : storage_info->device_types) { + device_types.push_back(static_cast(v)); } size_t num_unknown_devices = std::count(device_types.begin(), device_types.end(), 0); if (num_unknown_devices != 0 && num_unknown_devices != device_types.size()) { @@ -404,7 +371,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorvalue; + param_storage_ids_[name] = GetStorageInfo(expr)->storage_ids[0]; params_[name] = op->data; return to_return; } @@ -420,8 +387,16 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator GraphAddCallNode(const CallNode* op, const std::string& op_name, - const std::string& func_name, GraphAttrs attrs) { + bool ShareSameStorage(const Expr& lhs, const Expr& rhs) { + StorageInfo lit = GetStorageInfo(lhs); + StorageInfo rit = GetStorageInfo(rhs); + int64_t lhs_storage_id = lit->storage_ids[0]; + int64_t rhs_storage_id = rit->storage_ids[0]; + return lhs_storage_id == rhs_storage_id; + } + + std::vector GraphAddCallNode(const CallNode* op, const std::string& func_name, + GraphAttrs attrs) { std::vector inputs; for (auto arg : op->args) { auto res = VisitExpr(arg); @@ -429,161 +404,52 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator(op)); - } - bool ShareSameStorage(const Expr& lhs, const Expr& rhs) { - auto lit = storage_device_map_.find(lhs); - auto rit = storage_device_map_.find(rhs); - ICHECK(lit != storage_device_map_.end()); - ICHECK(rit != storage_device_map_.end()); - int64_t lhs_storage_id = ((*lit).second)[0][0]->value; - int64_t rhs_storage_id = ((*rit).second)[0][0]->value; - return lhs_storage_id == rhs_storage_id; - } + /// An adapted version of the storage optimization for the time being. + bool reshape_only = false; + if (op->attrs.defined()) { + if (auto tir_call_attrs = op->attrs.as()) { + Map metadata = tir_call_attrs->metadata; + if (metadata.count(attr::kReshapeOnly) && + Downcast(metadata[attr::kReshapeOnly])->value == 1) { + reshape_only = true; + } - /*! - * \brief Obtain the Target from the device type. - * If homogenous compilation, this will return the only target. - * If heteregenous compilation, this will select associated using the targets_ Map. - * - * \param dev_type - * \return Target - */ - Target GetTargetFromInteger(int64_t dev_type) { - if (targets_.size() == 1) { - // homogeneous execution. - const auto& it = targets_.begin(); - return (*it).second; - } else { - // heterogeneous execution. - std::string call_dev_name; - if (dev_type == 0) { - call_dev_name = "llvm"; - } else { - call_dev_name = runtime::DeviceName(dev_type); - } - if (targets_.count(dev_type) == 0) { - LOG(FATAL) << "No target is provided for device " << call_dev_name; - } - return targets_[dev_type]; - } - } + auto relay_attrs = Downcast(tir_call_attrs->metadata["relay_attrs"]); - /*! - * \brief Update the function metadata for a given cached function and its relay - * primitive function. - * - * \param cfunc The cached function as provided the by the compile engine - * \param relay_func The source relay primitive function - * \param relay_target The target associated with relay primitive function - */ - void UpdateFunctionMetadata(const CachedFunc& cfunc, const Function& relay_func, - const Target& relay_target) { - auto fi_node = make_object(); - for (const auto& kv : cfunc->funcs->functions) { - auto primfunc = Downcast(kv.second); - auto workspace_byte_alignment = relay_target->GetAttr("workspace-byte-alignment") - .value_or(tvm::runtime::kDefaultWorkspaceAlignment); - Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment); - Target primfunc_target = relay_target; - if (primfunc->attrs->dict.count("target")) { - primfunc_target = Downcast(primfunc->attrs->dict["target"]); - } - fi_node->workspace_sizes.Set(primfunc_target, workspace_size); - // Calculating size for I/O - for (auto const& param : primfunc->params) { - auto p_shape = primfunc->buffer_map[param]->shape; - int num_of_elements = 1; - for (const auto& dim_index_expr : p_shape) { - if (dim_index_expr->IsInstance()) { - num_of_elements *= dim_index_expr.as()->value; - } else { - // If shape is dynamic, we cannot calculate workspace in compile time. - num_of_elements = 0; + for (auto p : relay_attrs->dict) { + if (p.second.as()) { + attrs[p.first] = std::string(Downcast(p.second)); } } - int element_size = primfunc->buffer_map[param]->dtype.bytes(); - fi_node->io_sizes.Set(primfunc_target, element_size * num_of_elements); - } - fi_node->constant_sizes.Set(primfunc_target, 0); - fi_node->tir_primfuncs.Set(primfunc_target, primfunc); - fi_node->relay_primfuncs.Set(primfunc_target, relay_func); - } - function_metadata_.Set(cfunc->func_name, FunctionInfo(fi_node)); - } - - std::vector VisitExpr_(const CallNode* op) override { - Expr expr = GetRef(op); - Function func; - if (op->op.as()) { - LOG(FATAL) << "Operators should be transformed away; try applying" - << "the fuse_ops transformation to the expression."; - } else if (op->op.as()) { - LOG(FATAL) << "Not implemented"; - } else if (op->op.as()) { - func = GetRef(op->op.as()); - } else { - LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey(); - } - if (!func->HasNonzeroAttr(attr::kPrimitive)) { - LOG(FATAL) << "TVM only support calls to primitive functions " - << "(i.e functions composed of fusable operator invocations)"; - } - - // Copy attrs from function into the graph node - // For now we only handle strings - GraphAttrs attrs; - for (auto p : func->attrs->dict) { - if (p.second.as()) { - attrs[p.first] = std::string(Downcast(p.second)); } } - auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey"); - auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); - Target target; - // Handle external function - if (func->GetAttr(attr::kCompiler).defined()) { - target = Target("ext_dev"); - CCacheKey key = (*pf0)(func, target); - CachedFunc ext_func = (*pf1)(compile_engine_, key, mod_name_); - ICHECK(ext_func.defined()) << "External function is not defined."; - UpdateConstants(func, ¶ms_); - return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name, attrs); + if (reshape_only && ShareSameStorage(GetRef(op), op->args[0])) { + auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), "__nop", inputs, attrs); + return AddNode(node, GetRef(op)); } - // In the current flat memory allocation scenario - // the flat memory allocator can always allocate input - // and output of the reshape to the same memory, we can turn reshape only - // function to a nop. - // - // NOTE that for non-flat memory this is not necessarily true. - // - // TODO(tvm-team) Update checks of flat memory enablement when we support - // opaque-nd memory planning to skip this path. - if (func->HasNonzeroAttr(attr::kReshapeOnly) && ShareSameStorage(expr, op->args[0])) { - return GraphAddCallNode(op, "reshape_nop", "__nop", attrs); - } + // Compute the operator name, because we used the get unique name when generating the kernel. + auto op_name = _GetUniqueName(func_name); + auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, attrs); + return AddNode(node, GetRef(op)); + } - ICHECK_GE(storage_device_map_.count(expr), 0); - auto& device_type = storage_device_map_[expr][1]; - auto call_dev_type = device_type[0]->value; - target = GetTargetFromInteger(call_dev_type); - // Normal Relay Function + std::vector VisitExpr_(const CallNode* call_node) override { + relay::Call call = GetRef(call_node); + if (auto global_node = call->op.as()) { + auto prim_fn_name = global_node->name_hint; - CCacheKey key = (*pf0)(func, target); - CachedFunc lowered_func = (*pf1)(compile_engine_, key, mod_name_); - if (!lowered_funcs_.count(target->str())) { - lowered_funcs_[target->str()] = IRModule(Map({})); + return GraphAddCallNode(call_node, prim_fn_name, GraphAttrs()); + } else { + ICHECK(false) << "Non-primitive-call nodes should have been transformed away.\n" + << "The graph executor code generator expects all calls to have their callee " + "normalized to a GlobalVar but found a " + << call->GetTypeKey() << "." + << "AST: " << PrettyPrint(call) << PrettyPrint(call) << std::endl; + return {}; } - lowered_funcs_[target->str()]->Update(lowered_func->funcs); - - // Update function metadata via looking at all primfuncs - UpdateFunctionMetadata(lowered_func, func, target); - return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name, - attrs); } std::vector VisitExpr_(const LetNode* op) override { @@ -714,7 +580,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator> var_map_; /*! \brief target device */ - TargetsMap targets_; + TargetMap targets_; /*! * \brief parameters (i.e. ConstantNodes found in the graph). * These are take as inputs to the GraphExecutor. @@ -724,7 +590,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator params_; std::unordered_map param_storage_ids_; /*! \brief plan memory of device result */ - Map> storage_device_map_; + StaticMemoryPlan memory_plan_; /*! \brief the module name we use to mangle the function names */ String mod_name_; /*! \brief lowered funcs */ @@ -733,8 +599,6 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator function_metadata_; /*! \brief name map */ std::unordered_map name_map_; - /*! \brief compile engine */ - CompileEngine compile_engine_; }; class GraphExecutorCodegenModule : public runtime::ModuleNode { @@ -747,11 +611,11 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { << "runtime::Module mod and Map targets"; void* mod = args[0]; Map tmp = args[1]; - TargetsMap targets; + TargetMap targets; for (const auto& it : tmp) { auto dev_type = it.first.as(); ICHECK(dev_type); - targets[dev_type->value] = it.second; + targets[static_cast(dev_type->value)] = it.second; } codegen_ = std::make_shared(reinterpret_cast(mod), targets); diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 351469d6e1ca..93c823d8a007 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -23,15 +23,20 @@ * the program in the graph executor. */ #include +#include #include #include +#include #include #include "../../support/arena.h" +#include "./utils.h" namespace tvm { namespace relay { +using backend::StaticMemoryPlan; +using backend::StorageInfo; using IntegerArray = Array; struct StorageToken { @@ -48,6 +53,18 @@ struct StorageToken { int64_t storage_id{-1}; }; +std::ostream& operator<<(std::ostream& os, StorageToken tok) { + return os << "StorageToken: " << std::endl + << "ref_counter: " << tok.ref_counter << std::endl + << "max_bytes: " << tok.max_bytes << std::endl + << "tttype: " << tok.ttype + << std::endl + // ok idk how to print this properly + << "tttype shape: " << tok.ttype->shape << std::endl + << "device_type: " << tok.device_type << std::endl + << "storage_id: " << tok.storage_id << std::endl; +} + class StorageAllocaBaseVisitor : public ExprVisitor { public: // run the visitor on a function. @@ -114,7 +131,8 @@ class StorageAllocaBaseVisitor : public ExprVisitor { const std::vector& GetToken(const Expr& expr) { this->VisitExpr(expr); auto it = token_map_.find(expr.operator->()); - ICHECK(it != token_map_.end()); + ICHECK(it != token_map_.end()) + << "Expression: `" << PrettyPrint(expr) << "` not found in storage map."; return it->second; } /*! @@ -168,6 +186,7 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { void VisitExpr_(const CallNode* op) final { // create token for the call node. CreateToken(op, true); + // for each input, visit argument token. for (Expr arg : op->args) { for (StorageToken* tok : GetToken(arg)) { @@ -196,31 +215,32 @@ class StorageAllocator : public StorageAllocaBaseVisitor { } // Run storage allocation for a function. - Map > Plan(const Function& func) { + StaticMemoryPlan Plan(const Function& func) { prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func); this->Run(func); // The value of smap contains two integer arrays where the first array // contains the planned storage ids and the second holds the device types. - Map > smap; + Map smap; int num_annotated_nodes = 0; int num_nodes = 0; for (const auto& kv : token_map_) { - std::vector storage_ids; - std::vector device_types; - std::vector sid_sizes_byte; + std::vector storage_ids; + std::vector device_types; + std::vector sid_sizes_byte; + for (StorageToken* tok : kv.second) { if (tok->device_type) { num_annotated_nodes++; } num_nodes++; storage_ids.push_back(tok->storage_id); - device_types.push_back(tok->device_type); + device_types.push_back(static_cast(tok->device_type)); sid_sizes_byte.push_back(GetMemorySize(tok)); } - smap.Set(GetRef(kv.first), - Array({storage_ids, device_types, sid_sizes_byte})); + auto storage_info = backend::StorageInfo(storage_ids, device_types, sid_sizes_byte); + smap.Set(GetRef(kv.first), storage_info); } // Either all or none of the nodes should be annotated. if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) { @@ -228,7 +248,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { << "expressions are assigned with virtual device types. Either all " "or none of the expressions are expected to be annotated."; } - return smap; + + return backend::StaticMemoryPlan(smap); } protected: @@ -279,6 +300,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { args.push_back(tok); } } + // Under the flat-memory setting. // we can force aliasing the input and output of reshape // to make it an nop. Note that this is not true @@ -288,12 +310,17 @@ class StorageAllocator : public StorageAllocaBaseVisitor { // TODO(tvm-team) Update checks of flat memory enablement when we support // opaque-nd memory planning to skip this path. if (IsReshape(op)) { + // TODO(@electriclilies, jroesch): This check is failing because the size of args is 3 + // I can't figure out where the extra args are coming from, I assume it must be related + // to the relay_attrs field we added to the TIRCallArgs, but I don't know where / how + // that's happening... ICHECK_EQ(args.size(), 1U); ReuseInputToken(op, args[0]); } else { // create token for the call node. CreateToken(op, true); } + // check if there is orphaned output that can be released immediately. for (StorageToken* tok : token_map_.at(op)) { CheckForRelease(tok); @@ -320,6 +347,15 @@ class StorageAllocator : public StorageAllocaBaseVisitor { if (const auto* fn = call->op.as()) { return fn->HasNonzeroAttr(attr::kReshapeOnly); } + + if (call->attrs.defined()) { + if (auto tir_call_attrs = call->attrs.as()) { + Map metadata = tir_call_attrs->metadata; + return metadata.count(attr::kReshapeOnly) && + (Downcast(metadata[attr::kReshapeOnly])->value == 1); + } + } + return false; } /*! @@ -419,9 +455,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { std::unordered_map > prototype_; }; -Map > GraphPlanMemory(const Function& func) { - return StorageAllocator().Plan(func); -} +StaticMemoryPlan GraphPlanMemory(const Function& func) { return StorageAllocator().Plan(func); } TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory").set_body_typed(GraphPlanMemory); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index eeba010dc164..53985c78a33c 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -32,6 +32,7 @@ #include #include +#include "../transforms/pass_utils.h" #include "compile_engine.h" namespace tvm { @@ -381,7 +382,7 @@ class Interpreter : public ExprFunctor, } else { m = build(cfunc->funcs, cfunc->target, Target(nullptr)); } - shape_func = m.GetFunction(cfunc->func_name); + shape_func = m.GetFunction(cfunc->prim_fn_var->name_hint); shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); // Get output shapes diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc new file mode 100644 index 000000000000..93b9c6fc1827 --- /dev/null +++ b/src/relay/backend/te_compiler.cc @@ -0,0 +1,743 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "te_compiler.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "../transforms/pass_utils.h" +#include "te_compiler.h" +#include "te_compiler_cache.h" +#include "utils.h" + +namespace tvm { +namespace relay { +// TODO(@jroesch, @csullivan): declare directly elsewhere +backend::StaticMemoryPlan GraphPlanMemory(const Function& func); + +namespace tec { + +using namespace tvm::relay::transform; + +TVM_REGISTER_OBJECT_TYPE(TECompilerNode); + +class TECompilerImpl : public TECompilerNode { + public: + // Lower the function. + CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) { + return LowerInternal(key, mangle_fn)->cached_func; + } + + CachedFunc Lower(const CCacheKey& key, const String mod_name) { + auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); }; + + return Lower(key, mangle_fn); + } + + // For now, build one module per function. + PackedFunc JIT(const CCacheKey& key) final { + auto mangle_fn = [](String name) { return name; }; + CCacheValue value = LowerInternal(key, mangle_fn); + if (value->packed_func != nullptr) { + return value->packed_func; + } + auto m = build(value->cached_func->funcs, key->target, Target(nullptr)); + value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint); + return value->packed_func; + } + + CachedFunc LowerShapeFunc(const CCacheKey& key) final { + return LowerShapeFuncInternal(key)->cached_func; + } + + Map GetLoweredFunctions() { + Map lowered_functions; + for (const auto& it : cache_) { + auto source_func = it.first; + auto lowered_func = it.second; + auto target = source_func->target; + + if (!lowered_functions.count(target->str())) { + lowered_functions.Set(target->str(), IRModule(Map({}))); + } + + lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); + } + return lowered_functions; + } + + Array LowerExternalFunctions() { + Array ret; + std::unordered_map cached_symbol; + std::vector cached_ext_funcs; + for (const auto& it : cache_) { + auto src_func = it.first->source_func; + ICHECK(src_func.defined()); + if (src_func->GetAttr(attr::kCompiler).defined()) { + auto code_gen = src_func->GetAttr(attr::kCompiler); + std::string code_gen_name = code_gen.value(); + cached_ext_funcs.push_back(it.first); + + auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" + << AsText(src_func, false); + + std::string sn = symbol_name.value(); + if (cached_symbol.count(sn)) { + cached_symbol[sn] = code_gen_name; + } else { + ICHECK_NE(sn, code_gen_name) + << "Found duplicated symbol: " << sn << " for: " << code_gen_name; + } + + std::string ext_name = "relay.ext." + code_gen_name; + auto pf = tvm::runtime::Registry::Get(ext_name); + ICHECK(pf) << "Failed to find the codegen tool for " << ext_name; + // No need to keep compiler attribute at this point, functions have been + // extracted for specific codegen. + src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue()); + runtime::Module ext_mod = (*pf)(src_func); + + ICHECK(ext_mod.defined()) << "No external runtime is generated."; + ret.push_back(ext_mod); + } + } + + // No need to cache external functions as we collected them all to create + // external runtime modules. + for (const auto& it : cached_ext_funcs) { + cache_.erase(it); + } + return ret; + } + + void Clear() final { cache_.clear(); } + + // List all items in the cache. + Array ListItems() { + std::lock_guard lock(mutex_); + Array items; + for (auto& kv : cache_) { + items.push_back(kv.first); + items.push_back(kv.second); + } + return items; + } + + /*! + * \brief Get the cache key of the function that is being lowered currently + * \return the cache key + */ + CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; } + + private: + // implement lowered func + CCacheValue LowerInternal(const CCacheKey& key, std::function mangle_fn) { + std::lock_guard lock(mutex_); + CCacheValue value; + auto it = cache_.find(key); + if (it != cache_.end()) { + it->second->use_count += 1; + if (it->second->cached_func.defined()) return it->second; + value = it->second; + } else { + value = CCacheValue(make_object()); + value->use_count = 1; + cache_[key] = value; + } + cur_ccache_key_ = key; + + // No need to lower external functions for now. We will invoke the external + // codegen tool once and lower all functions together. + if (key->source_func->GetAttr(attr::kCompiler).defined()) { + auto ir_module = IRModule(); + const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(name_node.defined()) << "External function has not been attached a name yet."; + auto func_name = GetUniqueName(name_node.value(), &name_map_); + auto target = Target("ext_dev"); + auto global_var = GlobalVar(func_name); + global_var->checked_type_ = key->source_func->checked_type(); + value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); + return value; + } + + // Enforce use the target. + With target_scope(key->target); + + ICHECK(!value->cached_func.defined()); + auto cfunc = PrimFuncFor(key->source_func, key->target, [&](std::string name) { + auto mangled = mangle_fn(name); + return GetUniqueName(mangled, &name_map_); + }); + + // Skip lowering for device copy node. + const Expr body = (key->source_func)->body; + if (const CallNode* call_node = body.as()) { + if (call_node->attrs.as()) { + value->cached_func = cfunc; + return value; + } + } + + // NOTE: array will copy on write. + Array all_args = Array(cfunc->inputs); + for (te::Tensor arg : cfunc->outputs) { + all_args.push_back(arg); + } + + std::unordered_map binds; + auto func_name = cfunc->prim_fn_var->name_hint; + cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); + value->cached_func = cfunc; + return value; + } + + // implement lowered shape func + CCacheValue LowerShapeFuncInternal(const CCacheKey& key) { + std::lock_guard lock(mutex_); + CCacheValue value; + auto it = shape_func_cache_.find(key); + if (it != shape_func_cache_.end()) { + it->second->use_count += 1; + if (it->second->cached_func.defined()) return it->second; + value = it->second; + } else { + value = CCacheValue(make_object()); + value->use_count = 0; + shape_func_cache_[key] = value; + } + // Enforce use the target. + With target_scope(key->target); + + ICHECK(!value->cached_func.defined()); + + using tvm::transform::PassContext; + With fresh_pass_ctx_scope(PassContext::Create()); + auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) { + return GetUniqueName(name, &name_map_); + }); + + value->cached_func = cached_func; + return value; + } + + std::unordered_map GetOpWeights() { + std::unordered_map weights; + for (auto pair : cache_) { + auto value = pair.second; + auto name = value->cached_func->prim_fn_var->name_hint; + weights[name] = value->use_count; + } + return weights; + } + + /*! \brief compiler cache lock*/ + std::mutex mutex_; + /*! \brief internal name map to get an unique name */ + std::unordered_map name_map_; + /*! \brief internal compiler cache */ + std::unordered_map cache_; + /*! \brief internal compiler cache for shape funcs */ + std::unordered_map shape_func_cache_; + /*! \brief the cache key of the function that is being lowered currently*/ + CCacheKey cur_ccache_key_; +}; + +TECompiler::TECompiler() { + auto object = make_object(); + data_ = object; +} + +using AnalysisRemapping = std::unordered_map; + +std::tuple IsDeviceCopy(const Function& func) { + if (auto call_node = func->body.as()) { + if (auto op_node = call_node->op.as()) { + if (op_node->name == "device_copy") { + auto attrs = call_node->attrs.as(); + auto dst = attrs->dst_dev_type; + auto src = attrs->src_dev_type; + return std::tuple(true, src, dst); + } + } + } + + return std::tuple(false, -1, -1); +} + +class LowerTensorExpr : public ExprMutator { + public: + LowerTensorExpr(const IRModule& module, const TargetMap& targets, const DeviceMap& device_ctx_map, + ProcessFn process_fn, const String& module_name, TECompiler compiler) + : module_(module), + targets_(targets), + device_context_map_(device_ctx_map), + process_fn(process_fn), + module_name_(module_name), + compiler_(compiler) {} + + Expr VisitExpr_(const CallNode* call) override { + Call expr = GetRef(call); + Function func; + + if (call->op.as()) { + func = GetRef(call->op.as()); + } else { + return ExprMutator::VisitExpr_(call); + } + + if (!func->HasNonzeroAttr(attr::kPrimitive)) { + // Provide a callback hook which allows one-level up code generators to + // act when we process a function. + this->process_fn(func); + return ExprMutator::VisitExpr_(call); + } + + // Process inputs. + Array args; + for (size_t i = 0; i < expr->args.size(); i++) { + args.push_back(VisitExpr(expr->args[i])); + } + + Target target; + + if (func->GetAttr(attr::kCompiler).defined()) { + target = Target("ext_dev"); + CCacheKey key = CCacheKey(func, target); + CachedFunc ext_func = compiler_->Lower(key, module_name_); + ICHECK(ext_func.defined()) << "Lowering returned undefined function for " + << ext_func->prim_fn_var->name_hint; + + Map prim_fns; + + for (auto prim_fn : ext_func->funcs->functions) { + CHECK(prim_fn.second.as()) << "must be a prim fn"; + prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); + } + + relay::Function func_with_metadata = func; + func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var); + func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); + func_with_metadata = WithAttr(func_with_metadata, "target", ext_func->target); + + // Provide a callback hook which allows one-level up code generators to + // act when we process a function. + this->process_fn(func_with_metadata); + + auto ret_call = Call(ext_func->prim_fn_var, args, {}); + return std::move(ret_call); + } + + ICHECK_GE(device_context_map_.count(expr), 0) + << "Could not find an entry in the device context map for " << PrettyPrint(expr) + << "The memory planning was either not performed for this precise node, or there is bug " + "in the memory planner."; + + auto& device_context = this->device_context_map_[expr]; + auto call_dev_type = device_context.device_type; + + // Non-External Relay Function + if (targets_.size() == 1) { + // The homogeneous execution case, we should only have one target + // so we just grab it. + const auto& it = targets_.begin(); + target = (*it).second; + } else { + // The heterogeneous execution case we have multiple targets + // in this case. + // + // We need to identify the target and translate. + std::string call_dev_name; + if (call_dev_type == 0) { + call_dev_name = "llvm"; + call_dev_type = kDLCPU; + } else { + call_dev_name = ::tvm::runtime::DeviceName(call_dev_type); + } + + if (targets_.count(call_dev_type) == 0) { + std::stringstream msg; + msg << "No target is specified for provided device name: `" << call_dev_name << "`\n\n"; + msg << call_dev_name << " mapped to device type (" << call_dev_type + << ") which was not found in the target map.\n"; + msg << "Availible targets: \n"; + for (auto target : targets_) { + msg << " " << target.first << "-> " << target.second << "\n"; + } + LOG(FATAL) << msg.str(); + } + + target = targets_[call_dev_type]; + } + + CCacheKey key = CCacheKey(func, target); + CachedFunc lowered_func = compiler_->Lower(key, module_name_); + + Map prim_fns; + + for (auto prim_fn : lowered_func->funcs->functions) { + CHECK(prim_fn.second.as()) << "must be a prim fn"; + prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); + } + + // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT + relay::Function func_with_metadata = func; + func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", lowered_func->prim_fn_var); + func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); + func_with_metadata = WithAttr(func_with_metadata, "target", lowered_func->target); + + // Provide a callback hook which allows one-level up code generators to + // act when we process a function. + this->process_fn(func_with_metadata); + + auto tir_call_attrs = make_object(); + if (func->HasNonzeroAttr(attr::kReshapeOnly)) { + tir_call_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); + } + + auto device_copy = IsDeviceCopy(func); + if (std::get<0>(device_copy)) { + auto source_device = std::get<1>(device_copy); + auto dst_device = std::get<2>(device_copy); + tir_call_attrs->metadata.Set("source_device", tvm::Integer(source_device)); + tir_call_attrs->metadata.Set("dst_device", tvm::Integer(dst_device)); + } + + tir_call_attrs->metadata.Set("relay_attrs", func->attrs); + + Expr ret_call = Call(lowered_func->prim_fn_var, args, Attrs(tir_call_attrs)); + return ret_call; + } + + IRModule module_; + TargetMap targets_; + DeviceMap device_context_map_; + ProcessFn process_fn; + String module_name_; + TECompiler compiler_; +}; + +/*! + * \brief Obtain the Target from the device type. + * If homogenous compilation, this will return the only target. + * If heteregenous compilation, this will select associated using the targets_ Map. + * + * \param dev_type + * \return Target + */ +Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) { + if (targets.size() == 1) { + // homogeneous execution. + const auto& it = targets.begin(); + return (*it).second; + } else { + // heterogeneous execution. + std::string call_dev_name; + if (dev_type == 0) { + call_dev_name = "llvm"; + } else { + call_dev_name = runtime::DeviceName(dev_type); + } + if (targets.count(dev_type) == 0) { + LOG(FATAL) << "No target is provided for device " << call_dev_name; + } + return targets[dev_type]; + } +} + +/*! + * \brief Update the "main" control function's metadata + * + * \param mod The module + * \param targets Map of targets + * \return function_infos Function info for each function in the module + */ + +backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap targets, + Map storage_info_map) { + CHECK_EQ(mod->functions.size(), 1) + << "There should only be one function in the module passed to UpdateMainWorkspaceSize"; + Function func = Downcast(mod->Lookup("main")); + + // This is a Map> + std::unordered_map, EnumClassHash> sid_workspace; + // This is a Map + std::unordered_map device_io; + // This is a Map + std::unordered_map device_consts; + + // Initialize the mapping from all storage identifiers to workspace sizes, + // the amount of device io, and the device constants. + for (const auto& kv : storage_info_map) { + backend::StorageInfo storage_info = kv.second; + std::vector storage_ids = storage_info->storage_ids; + std::vector devices = storage_info->device_types; + + CHECK_EQ(storage_ids.size(), devices.size()); + for (uint32_t i = 0; i < devices.size(); i++) { + sid_workspace[devices[i]][storage_ids[i]] = 0; + device_io[devices[i]] = 0; + device_consts[devices[i]] = 0; + } + } + + // Iterate the storage map to compute all the tensor sizes in the program. + // There are 3 cases in this code: + // + // First we need to compute the sizes of all + // inline constants. + // + // Second we compute the size of any bound variable as these are input and output + // sizes of the program. + // + // Finally for all other expressions we check which storage identifier they have + // been assigned and we compute the maximal size of the storage, as tensors can + // share storage with other tensors which are the same size or larger. + // + // In this final case there is only one allocation for all tensors which share storage + // which will be the maximal size of all tensors which were assigned to it. + for (const auto& kv : storage_info_map) { + Expr expr = kv.first; + int64_t size_bytes = backend::CalculateRelayExprSizeBytes(expr->checked_type()); + backend::StorageInfo storage_info = kv.second; + std::vector storage_ids = storage_info->storage_ids; + std::vector devices = storage_info->device_types; + + if (expr->IsInstance()) { + for (const auto& dev : devices) { + device_consts[dev] += size_bytes; + } + continue; + } else if (expr->IsInstance() || expr.same_as(func->body)) { + CHECK_GE(devices.size(), 1) << "must be at least one device"; + for (const auto& dev : devices) { + device_io[dev] += size_bytes; + } + continue; + } + + // TODO(@electriclilies): This code is never being called which means sid_workspace is not + // updated.. This means that storage info is probably not being created correctly. Or is not + // equivalent to what was here previously + for (uint32_t i = 0; i < storage_ids.size(); i++) { + // Here we record the largest size of the tensor + // that share the same storage id, because storage_id will + // be shared between multiple tensors that are not live simultaneously. + if (size_bytes > sid_workspace[devices[i]][storage_ids[i]]) { + sid_workspace[devices[i]][storage_ids[i]] = size_bytes; + } + } + } + + // This is a Map + std::unordered_map device_workspace; + // Once we know the sizes of sids, we need to accumulate per device + for (const auto& dev_sid_size : sid_workspace) { + auto dev = dev_sid_size.first; + device_workspace[dev] = 0; + for (const auto& sid_size : dev_sid_size.second) { + device_workspace[dev] += sid_size.second; + } + } + + Map workspace_sizes; + Map io_sizes; + Map constant_sizes; + Map tir_primfuncs; + Map relay_primfuncs; + + // Initialize all target workspaces to zero + for (const auto& kv : targets) { + auto tgt = kv.second; + workspace_sizes.Set(tgt, 0); + } + + for (const auto& dev_and_size : device_workspace) { + auto tgt = GetTargetFromInteger(dev_and_size.first, targets); + workspace_sizes.Set(tgt, dev_and_size.second); + relay_primfuncs.Set(tgt, func); + } + for (const auto& dev_and_size : device_io) { + auto tgt = GetTargetFromInteger(dev_and_size.first, targets); + io_sizes.Set(tgt, dev_and_size.second); + } + + for (const auto& dev_and_size : device_consts) { + auto tgt = GetTargetFromInteger(dev_and_size.first, targets); + constant_sizes.Set(tgt, dev_and_size.second); + } + + return backend::FunctionInfo(workspace_sizes, io_sizes, constant_sizes, tir_primfuncs, + relay_primfuncs); +} + +// TODO(@electriclilies): Is the function passed in here relay_func?? +// Also should this be inlined? +/*! + * \brief A function to create the function metadata for an input function (ie calculate buffer + * input/output sizes) + * \param relay_func The function to calculate function metadata for + * \param function_metadata The map that stores all the function metadatas + */ +void UpdateFunctionMetadata(Function relay_func, + Map& function_metadata) { // NOLINT(*) + // Originally UpdateFunctionMetadata took in CCachedFunc and looped through all the funcs stored + // there Now the goal is to take only one func because process_fn should be controlling the + // iteration However, to do the workspace calculations we need the primfuncs. So process_fn needs + // to either access the cached funcs or be directly passed primfuncs This is bad and ideally we + // don't want process_fn to look at primfuncs There's also the question now of what the function + // metadatas are and how they are used if we can do something else to replicate the behavior of + // the function metadatas that might be good (ie annotating functions or something). + Map workspace_sizes; + Map io_sizes; + Map constant_sizes; + Map tir_primfuncs; + Map relay_primfuncs; + + Optional> prim_fns = + relay_func->GetAttr>("prim_funcs"); + CHECK(prim_fns) << "primitive functions not set on Relay function by TECompiler."; + + Optional prim_fn_var = relay_func->GetAttr("prim_fn_var"); + CHECK(prim_fn_var) << "prim_fn_var must be set on Relay functions by TECompiler."; + + Optional relay_target = relay_func->GetAttr("target"); + CHECK(relay_target) << "target must be set on Relay functions by the TECompiler."; + + for (const auto& kv : prim_fns.value()) { + auto prim_fn = Downcast(kv.second); + CHECK(prim_fn.defined()) << "the primitive function must be defined"; + + auto workspace_byte_alignment = + relay_target.value()->GetAttr("workspace_byte_alignment").value_or(16); + + Integer workspace_size = CalculateWorkspaceBytes(prim_fn, workspace_byte_alignment); + + // Workspace sizes + Target prim_fn_target; + if (prim_fn->attrs->dict.count("target")) { + prim_fn_target = Downcast(prim_fn->attrs->dict["target"]); + } else { + prim_fn_target = relay_target.value(); + } + + workspace_sizes.Set(prim_fn_target, workspace_size); + + // Calculating size for I/O + for (auto const& param : prim_fn->params) { + auto p_shape = prim_fn->buffer_map[param]->shape; + int num_of_elements = 1; + for (const auto& dim_index_expr : p_shape) { + if (dim_index_expr->IsInstance()) { + num_of_elements *= dim_index_expr.as()->value; + } else { + // If shape is dynamic, we cannot calculate workspace in compile time. + num_of_elements = 0; + } + } + int element_size = prim_fn->buffer_map[param]->dtype.bytes(); + io_sizes.Set(prim_fn_target, element_size * num_of_elements); + } + + constant_sizes.Set(prim_fn_target, 0); + tir_primfuncs.Set(prim_fn_target, prim_fn); + relay_primfuncs.Set(prim_fn_target, relay_func); + } + + backend::FunctionInfo fi = backend::FunctionInfo(workspace_sizes, io_sizes, constant_sizes, + tir_primfuncs, relay_primfuncs); + + // The primitive function name here corresponds to the string we will use to generate + // this Relay function at the low level. + function_metadata.Set(prim_fn_var.value()->name_hint, fi); +} + +LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map, + backend::StaticMemoryPlan memory_plan, const String& module_name, + std::function process_fn) { + TECompiler compiler; + + CHECK_EQ(module->functions.size(), 1) + << "There should only be one function in the module passed to LowerTE"; + + auto pass = CreateFunctionPass( + [=](Function func, IRModule module, PassContext ctx) { + LowerTensorExpr lower_te(module, targets, device_context_map, process_fn, module_name, + compiler); + return Downcast(lower_te.VisitExpr(func)); + }, + 0, "LowerTensorExpr", {}); + + // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize + backend::FunctionInfo func_info = + UpdateMainWorkspaceSize(module, targets, memory_plan->expr_to_storage_info); + + auto updated_module = pass(module); + + // A temporary solution until we can rewrite the auto-scheduler task extraction code to work + // in a more reasonable way. + if (backend::IsAutoSchedulerEnabled()) { + const auto* te_compiler_update_weights = + runtime::Registry::Get("auto_scheduler.relay_integration.te_compiler_update_weights"); + + ICHECK(te_compiler_update_weights != nullptr) + << "auto_scheduler.relay_integration.te_compiler_update_weights"; + + Map weight_map; + + for (auto pair : compiler->GetOpWeights()) { + weight_map.Set(pair.first, pair.second); + } + + (*te_compiler_update_weights)(weight_map); + } + + LoweredModule lowered_module; + lowered_module.main_module = updated_module; + lowered_module.per_target_module = compiler->GetLoweredFunctions(); + lowered_module.external_mods = compiler->LowerExternalFunctions(); + lowered_module.main_func_info = func_info; + return lowered_module; +} + +} // namespace tec +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h new file mode 100644 index 000000000000..a32eefb5127f --- /dev/null +++ b/src/relay/backend/te_compiler.h @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/tir_compiler.h + * * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. + * + * + * This represents the new design of the Relay compilation flow and will replace the interface + * contained in compile_engine.h as we migrate towards a standard pass based lowering of + * Relay functions. + * + * This files provides an internal API which lowers Relay programs to components which + * can be combined with TVM produced kernels to compile an entire program. + * + * The result of lowering contains a combination of `runtime::Module`s produced by external + * compilers and a set of lowered PrimFns which can be code generated for targets. + */ +#ifndef TVM_RELAY_BACKEND_TE_COMPILER_H_ +#define TVM_RELAY_BACKEND_TE_COMPILER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../transforms/infer_layout_utils.h" +#include "../transforms/pass_utils.h" +#include "./te_compiler_cache.h" +#include "utils.h" + +namespace tvm { +namespace relay { +namespace tec { + +// This class is needed to avoid a GCC 5 bug that prevents maps containing enums +// from being compiled. If i386 GCC version is increased, we can remove it. +struct EnumClassHash { + template + std::size_t operator()(T t) const { + return static_cast(t); + } +}; + +// TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity sake +// we should a version of context which works in Map +using TargetMap = std::unordered_map; +using DeviceMap = + std::unordered_map; +using ProcessFn = std::function; + +/*! + * \brief A compiler which lowers primitive Relay functions to tensor expressions + * and schdules them into TIR functions. + */ +class TECompilerNode : public Object { + public: + /*! \brief destructor */ + virtual ~TECompilerNode() {} + /*! + * \brief Get lowered result. + * \param key The key to the cached function. + * \return The result. + */ + virtual CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) = 0; + + /*! + * \brief Get lowered result. + * \param key The key to the cached function. + * \return The result. + */ + virtual CachedFunc Lower(const CCacheKey& key, const String mod_name) = 0; + + /* Return all functions which have been lowered by the compiler, keyed by target. */ + virtual Map GetLoweredFunctions() = 0; + + /*! + * \brief Just in time compile to get a PackedFunc. + * \param key The key to the cached function. + * \return The result. + */ + virtual PackedFunc JIT(const CCacheKey& key) = 0; + /*! + * \brief Lower the shape function. + * \param key The key to the cached function. + * \return The result. + */ + virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0; + /*! + * \brief Lower the external function using external codegen tools. + * \return The runtime moduels for each needed external codegen tool. + */ + virtual tvm::Array LowerExternalFunctions() = 0; + + virtual std::unordered_map GetOpWeights() = 0; + + /*! \brief clear the cache. */ + virtual void Clear() = 0; + + void VisitAttrs(AttrVisitor*) {} + + static constexpr const char* _type_key = "relay.TECompiler"; + TVM_DECLARE_FINAL_OBJECT_INFO(TECompilerNode, Object); +}; + +/*! \brief cache entry used in compile engine */ +class TECompiler : public ObjectRef { + public: + TECompiler(); + explicit TECompiler(ObjectPtr n) : ObjectRef(n) {} + TECompilerNode* operator->() { return static_cast(get_mutable()); } + using ContainerType = TECompilerNode; +}; + +/*! \brief The result of lowering a module, for now we need to pass an aggregate data structure + * which contains more then a single module in order to interact with the today API. + */ +struct LoweredModule { + /*! \brief The module which contains the Relay code. */ + IRModule main_module; + /*! \brief The module which contains per target code. */ + Map per_target_module; + /*! \brief The external runtime modules which must be combined with the lowered code. */ + Array external_mods; + // TODO(@electriclilies): THis might need to become a map + /*! \brief The info for this function (not sure what a better description is??) + * + */ + backend::FunctionInfo main_func_info; +}; + +/*! + * \brief A function to create the function metadata for an input function (ie calculate buffer + * input/output sizes) + * \param relay_func The function to calculate function metadata for + * \param function_metadata The map that stores all the function metadatas + */ +void UpdateFunctionMetadata(Function relay_func, + Map& function_metadata); // NOLINT(*) + +/*! + * \brief Obtain the Target from the device type. + * If homogenous compilation, this will return the only target. + * If heteregenous compilation, this will select associated using the targets_ Map. + * + * \param dev_type + * \return Target + */ +Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets); + +/*! \brief Lower an IRModule's primitive functions to TIR. + * + * This is the "back half" of the Relay compiler which lowers "primitive functions" + * to TE expressions, schedules them, and then to TIR. + * + * /param module The IRModule. + * /param targets The mapping for devices to targets. + * /param device_map An analysis result mapping each sub-expression to a device. + * /return The lowered module, see above. + */ +// TODO(@electriclilies): Not sure if this default initialization is correct... +LoweredModule LowerTE( + const IRModule& module, TargetMap targets, DeviceMap device_map, + backend::StaticMemoryPlan memory_plan, const String& module_name, + ProcessFn process_fn = [](Function f) {}); + +} // namespace tec +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_TE_COMPILER_H_ diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc new file mode 100644 index 000000000000..bbe38f0426b4 --- /dev/null +++ b/src/relay/backend/te_compiler_cache.cc @@ -0,0 +1,694 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./te_compiler_cache.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../transforms/pass_utils.h" +#include "utils.h" + +namespace tvm { +namespace relay { +namespace tec { + +TVM_REGISTER_NODE_TYPE(LoweredOutputNode); +TVM_REGISTER_NODE_TYPE(CachedFuncNode); +TVM_REGISTER_NODE_TYPE(CCacheKeyNode); +TVM_REGISTER_NODE_TYPE(CCacheValueNode); + +LoweredOutput::LoweredOutput(tvm::Array outputs, OpImplementation impl) { + auto n = make_object(); + n->outputs = std::move(outputs); + n->implementation = std::move(impl); + data_ = std::move(n); +} + +CCacheKey::CCacheKey(Function source_func, Target target) { + auto n = make_object(); + n->source_func = std::move(source_func); + n->target = std::move(target); + data_ = std::move(n); +} + +CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array inputs, + tvm::Array outputs, te::Schedule schedule, + tvm::Array shape_func_param_states, IRModule funcs) { + auto n = make_object(); + n->target = target; + n->prim_fn_var = prim_fn_var; + n->inputs = inputs; + n->outputs = outputs; + n->schedule = schedule; + n->shape_func_param_states = shape_func_param_states; + n->funcs = funcs; + data_ = std::move(n); +} + +Array GetShape(const Array& shape) { + // for now, we always use int32 shape when possible + // even if the result of shape inference becomes int64. + Array res; + for (IndexExpr val : shape) { + const int64_t* pval = tir::as_const_int(val); + if (pval != nullptr) { +#ifndef TVM_INDEX_DEFAULT_I64 + ICHECK_LE(pval[0], std::numeric_limits::max()) + << "dimension must be less then int32_t's max value"; + ICHECK_GE(pval[0], std::numeric_limits::min()) + << "dimension must be less then int32_t's max value"; + res.push_back(IntImm(DataType::Int(32), *pval)); +#else + res.push_back(val); +#endif // TVM_INDEX_DEFAULT_I64 + } else if (val->IsInstance()) { + res.push_back(val.as()->ToVar()); + } else { + res.push_back(val); + } + } + return res; +} + +// Construct a schedule for a given Relay primitive function and target. +class ScheduleBuilder : public backend::MemoizedExprTranslator> { + public: + explicit ScheduleBuilder(Target target) + : target_(target), device_copy_op_(Op::Get("device_copy")) { + // Whether to use auto_scheduler schedule. + use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); + } + + CachedFunc Create(const Function& prim_func, std::function renamer) { + Array fn_inputs; + for (Var param : prim_func->params) { + Array inputs; + if (const auto* ttype = param->checked_type().as()) { + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); + fn_inputs.push_back(tensor); + inputs.push_back(tensor); + } else { + // flatten tuple of tensor type. + const auto* tuple_type = param->type_as(); + for (Type field : tuple_type->fields) { + const auto* ttype = field.as(); + // TODO(@icemelon): Allow recursive tuple + ICHECK(ttype != nullptr); + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); + fn_inputs.push_back(tensor); + inputs.push_back(tensor); + } + } + memo_[param] = inputs; + } + readable_name_stream_ << "fused"; + auto outputs = this->VisitExpr(prim_func->body); + auto candidate_name = readable_name_stream_.str(); + constexpr static size_t kMaxFuncNameLength = 80; + if (candidate_name.size() > kMaxFuncNameLength) { + std::stringstream truncated_name; + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << "_" << std::hash{}(candidate_name) << "_"; + candidate_name = truncated_name.str(); + } + + // NB(@jroesch): unfortunately the graph runtime deals with copy in + // a totally hacky way, we really need to rectify this but this will + // have to work for now. + std::string prim_fn_name = candidate_name; + if (prim_fn_name != "__copy") { + prim_fn_name = renamer(prim_fn_name); + } + auto prim_fn_var = GlobalVar(prim_fn_name); + prim_fn_var->checked_type_ = prim_func->checked_type(); + + ICHECK(anchor_op_.defined()); + // Fusion over tupled results may leave identity relationships + // between inputs and outputs, and those should not be scheduled. + // Hence schedule only non PlaceholderOp outputs. + tvm::Array tensor_outs; + for (const auto& tensor : outputs) { + if (!tensor->op.as()) { + tensor_outs.push_back(tensor); + } + } + + te::Schedule schedule; + // No need to register schedule for device copy op. + if (anchor_attrs_.as() == nullptr) { + if (use_auto_scheduler_) { + const auto* fauto_schedule = + runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); + ICHECK(fauto_schedule != nullptr) + << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; + ObjectRef obj = (*fauto_schedule)(prim_fn_name, tensor_outs); + if (obj.defined()) { + schedule = Downcast(obj); + } + } + + // Use TOPI schdule if user specificed, or the function has no auto_scheduler schedule. + if (!schedule.defined()) { + ICHECK(anchor_implementation_.defined()); + schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); + } + for (const auto& scalar : scalars_) { + if (schedule->Contain(scalar)) { + schedule[scalar].compute_inline(); + } + } + } + + return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {}); + } + + Array VisitExpr_(const VarNode* op) final { + LOG(FATAL) << "Unexpected free variable " << op->name_hint(); + return {}; + } + + Array VisitExpr_(const ConstantNode* op) final { + using tir::make_const; + ICHECK(op->is_scalar()); + void* data = op->data->data; + DataType dtype = DataType(op->data->dtype); + auto value = te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::PrimExpr(); + } + }, + "compile_engine_const", topi::kBroadcast); + scalars_.push_back(value->op); + return {value}; + } + + Array VisitExpr_(const CallNode* call_node) final { + static auto fpattern = Op::GetAttrMap("TOpPattern"); + static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); + ICHECK(flower_call) << "relay.backend.lower_call is not registered."; + + Array inputs; + int count_tuple = 0; + for (Expr arg : call_node->args) { + if (arg->checked_type().as()) { + ++count_tuple; + } + for (te::Tensor tensor : VisitExpr(arg)) { + inputs.push_back(tensor); + } + } + + if (count_tuple) { + ICHECK_EQ(call_node->args.size(), 1U) + << "Only functions with a single tuple input are allowed, but " << count_tuple + << " were provided."; + } + + ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; + Op op = Downcast(call_node->op); + + Array outputs; + OpImplementation impl; + // Skip fcompute for device copy operators as it is not registered. + if (op == device_copy_op_) { + const auto* copy_input = inputs[0].operator->(); + outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0)); + } else { + LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); + outputs = lowered_out->outputs; + impl = lowered_out->implementation; + } + + int op_pattern = fpattern[op]; + if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { + ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) + << "Cannot apply TOPI schedule to a primitive function with two complicated ops" + << " anchor=" << anchor_op_ << " current=" << op; + } + if (op_pattern >= anchor_op_pattern_) { + anchor_op_ = op; + anchor_attrs_ = call_node->attrs; + anchor_op_pattern_ = op_pattern; + anchor_implementation_ = impl; + } + if (outputs.size() != 1) { + const auto* tuple_type = call_node->checked_type().as(); + ICHECK(tuple_type) << "Expected output to be a tuple type " + << PrettyPrint(call_node->checked_type()); + + ICHECK_EQ(tuple_type->fields.size(), outputs.size()); + } + // Set the name to `__copy`. It will be detected in graph runtime to perform + // data copy across devices. + if (op == device_copy_op_) { + readable_name_stream_.str(std::string()); + readable_name_stream_ << "__copy"; + } else { + readable_name_stream_ << '_' << op->name; + } + return outputs; + } + + Array VisitExpr_(const FunctionNode* op) final { + LOG(FATAL) << "Primitive Functions can not contain nested functions."; + return Array(); + } + + Array VisitExpr_(const LetNode* op) final { + Array val = VisitExpr(op->value); + ICHECK(!memo_.count(op->var)); + memo_[op->var] = val; + return VisitExpr(op->body); + } + + Array VisitExpr_(const TupleNode* op) final { + Array fields; + for (Expr field : op->fields) { + ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; + Array res = VisitExpr(field); + ICHECK_EQ(res.size(), 1); + fields.push_back(res[0]); + } + return fields; + } + + Array VisitExpr_(const TupleGetItemNode* op) final { + const auto* tuple_type = op->tuple->type_as(); + Array tuple = VisitExpr(op->tuple); + ICHECK_EQ(tuple_type->fields.size(), tuple.size()); + ICHECK_GE(op->index, 0); + ICHECK_LT(static_cast(op->index), tuple.size()); + return {tuple[op->index]}; + } + + private: + tvm::Target target_; + Op anchor_op_; + Attrs anchor_attrs_; + int anchor_op_pattern_{0}; + OpImplementation anchor_implementation_; + std::ostringstream readable_name_stream_; + Array scalars_; + bool use_auto_scheduler_; + // Cache device copy op for equivalence checking to reduce registry lookup + // overhead for each invocation of call node when retrieving schedules. + const Op& device_copy_op_; +}; + +/*! + * \brief Create schedule for target. + * \param source_func The primitive function to be lowered. + * \param target The target we want to create schedule for. + * \return Pair of schedule and cache. + * The funcs field in cache is not yet populated. + */ +CachedFunc PrimFuncFor(const Function& source_func, const Target& target, + std::function renamer) { + return ScheduleBuilder(target).Create(source_func, renamer); +} + +// Creates shape function from functor. +class MakeShapeFunc : public backend::MemoizedExprTranslator> { + public: + MakeShapeFunc() {} + + CachedFunc Create(const Function& prim_func, const Target& target, + std::function renamer) { + Array inputs; + TShapeDataDependent shape_func_param_states; + + for (auto param : prim_func->params) { + param_states_[param] = kNoNeed; + Array data_inputs; + Array shape_inputs; + + auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) { + // Add data placeholder + Shape shape = GetShape(ttype->shape); + tvm::te::Tensor data_tensor = tvm::te::placeholder(shape, ttype->dtype); + data_inputs.push_back(data_tensor); + // Add shape placeholder + int64_t ndim = shape.size(); + Shape sshape; + if (ndim > 0) { + sshape.push_back(tvm::Integer(ndim)); + } + tvm::te::Tensor shape_tensor = tvm::te::placeholder(sshape, DataType::Int(64)); + shape_inputs.push_back(shape_tensor); + }; + + if (const auto* ttype = param->checked_type().as()) { + add_placeholder(ttype); + } else { + // flatten tuple of tensor type. + const auto* tuple_type = param->type_as(); + // TODO(@icemelon): Support recursive tuple + ICHECK(tuple_type); + for (Type field : tuple_type->fields) { + const auto* ttype = field.as(); + ICHECK(ttype); + add_placeholder(ttype); + } + } + param_data_[param] = data_inputs; + param_shapes_[param] = shape_inputs; + } + + // Setup the name; + readable_name_stream_ << "shape_func"; + + // Create the `te::Tensor`s which represent the output. + auto outputs = VisitExpr(prim_func->body); + + // Generate a name. + auto candidate_name = readable_name_stream_.str(); + constexpr static size_t kMaxFuncNameLength = 80; + if (candidate_name.size() > kMaxFuncNameLength) { + std::stringstream truncated_name; + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << "_" << std::hash{}(candidate_name) << "_"; + candidate_name = truncated_name.str(); + } + + // Set all the inputs correctly. + for (auto param : prim_func->params) { + int state = param_states_[param]; + shape_func_param_states.push_back(IntImm(DataType::Int(32), state)); + if (state & kNeedInputData) { + for (auto t : param_data_[param]) { + inputs.push_back(t); + } + } + if (state & kNeedInputShape) { + for (auto t : param_shapes_[param]) { + inputs.push_back(t); + } + } + } + + auto func_name = renamer(candidate_name); + auto prim_fn_gvar = GlobalVar(func_name); + prim_fn_gvar->checked_type_ = prim_func->checked_type(); + + // generate schedule for shape func + Array out_ops; + for (auto t : outputs) { + out_ops.push_back(t->op); + } + auto schedule = te::create_schedule(out_ops); + tvm::te::AutoInlineInjective(schedule); + for (const auto& scalar : scalars_) { + auto scalar_op = scalar->op; + if (schedule->Contain(scalar_op)) { + schedule[scalar_op].compute_inline(); + } + } + + Array all_args = Array(inputs); + for (te::Tensor arg : outputs) { + all_args.push_back(arg); + } + + using tvm::transform::PassContext; + With fresh_pass_ctx_scope(PassContext::Create()); + + std::unordered_map binds; + IRModule ir_module = tvm::LowerSchedule(schedule, all_args, func_name, binds); + + return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, shape_func_param_states, + ir_module); + } + + Array VisitExpr(const Expr& expr) final { + if (expr.as()) { + // Do not memoize vars because shape functions could use either the data + // or the shape of a var each time. + return ExprFunctor::VisitExpr(expr); + } + // For other case, do memoized visit + return backend::MemoizedExprTranslator>::VisitExpr(expr); + } + + Array VisitExpr_(const VarNode* var_node) final { + auto var = GetRef(var_node); + auto it = param_states_.find(var); + if (it == param_states_.end()) { + LOG(FATAL) << "Unexpected free variable " << var->name_hint(); + return {}; + } else { + ICHECK(data_dependents_per_input_.size()); + auto data_dependent = data_dependents_per_input_.back(); + if (data_dependent) { + param_states_[var] |= kNeedInputData; + return param_data_[var]; + } else { + param_states_[var] |= kNeedInputShape; + return param_shapes_[var]; + } + } + } + + Array VisitExpr_(const ConstantNode* op) final { + using tir::make_const; + ICHECK(data_dependents_per_input_.size()); + bool data_dependent = data_dependents_per_input_.back(); + if (!op->is_scalar()) { + // This is a constant weight, extract the shape of the weight tensor. + // This can not be data dependent. + CHECK(!data_dependent); + auto ttype = op->checked_type().as(); + int ndim = static_cast(ttype->shape.size()); + Array out_shape{ndim}; + te::Tensor value = tvm::te::compute( + out_shape, + [&](const Array& indices) { + auto idx = indices[0]; + PrimExpr ret = make_const(DataType::Int(64), 0); + for (int i = 0; i < ndim; i++) { + ret = tvm::if_then_else(idx == i, ttype->shape[i], ret); + } + return ret; + }, + "shape_const", topi::kBroadcast); + scalars_.push_back(value); + return {value}; + } + if (data_dependent) { + void* data = op->data->data; + DataType dtype = DataType(op->data->dtype); + auto value = tvm::te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::PrimExpr(); + } + }, + "data_const", topi::kBroadcast); + scalars_.push_back(value); + return {value}; + } else { + auto value = tvm::te::compute( + {}, [&](const Array&) { return tir::make_const(DataType::Int(64), 0); }, + "shape_const", topi::kBroadcast); + scalars_.push_back(value); + return {value}; + } + } + + Array VisitExpr_(const CallNode* call_node) final { + static auto fshape_func = Op::GetAttrMap("FShapeFunc"); + static auto tshape_data_dependent = Op::GetAttrMap("TShapeDataDependent"); + ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; + Op op = Downcast(call_node->op); + ICHECK(data_dependents_per_input_.empty() || !data_dependents_per_input_.back()) + << "Error in op fusion: output of the shape func is fed to a " + << "data-dependent shape func"; + ICHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name; + ICHECK_GT(tshape_data_dependent.count(op), 0) + << "Internal error, cannot find TShapeDataDependent for " << op->name; + + Array dep_spec = tshape_data_dependent[op]; + if (dep_spec.size() == 1) { + // This is for cases when data dependence is specified per op + // Replicate 0 or 1 flag to all arguments + for (size_t i = 1; i < call_node->args.size(); ++i) { + dep_spec.push_back(dep_spec[0]); + } + } + + // Visit all inputs + Array inputs; + int count_tuple = 0; + for (size_t i = 0; i < call_node->args.size(); ++i) { + Expr arg = call_node->args[i]; + if (arg->checked_type().as()) { + ++count_tuple; + } + data_dependents_per_input_.push_back(dep_spec[i]->value != 0); + for (te::Tensor tensor : VisitExpr(arg)) { + inputs.push_back(tensor); + } + data_dependents_per_input_.pop_back(); + } + if (count_tuple) { + ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; + } + // Get output ndims + auto ret_type = call_node->checked_type(); + Array out_ndims; + if (const auto* ttype = ret_type.as()) { + out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); + } else { + auto rtype = ret_type.as(); + // TODO(@icemelon): Allow recursive tuple + ICHECK(rtype); + for (size_t i = 0; i < rtype->fields.size(); ++i) { + auto ttype = rtype->fields[i].as(); + ICHECK(ttype); + out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); + } + } + // Call shape function + auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims); + readable_name_stream_ << "_" << op->name; + return outputs; + } + + Array VisitExpr_(const FunctionNode* op) final { + LOG(FATAL) << "Do not support sub function"; + return Array(); + } + + Array VisitExpr_(const LetNode* op) final { + Array val = VisitExpr(op->value); + ICHECK(!memo_.count(op->var)); + memo_[op->var] = val; + return VisitExpr(op->body); + } + + Array VisitExpr_(const TupleNode* op) final { + Array fields; + for (Expr field : op->fields) { + ICHECK(field->checked_type().as()) + << "Expected a Tuple of Tensor, but got " << PrettyPrint(field->checked_type()); + Array res = VisitExpr(field); + ICHECK_EQ(res.size(), 1); + fields.push_back(res[0]); + } + return fields; + } + + Array VisitExpr_(const TupleGetItemNode* op) final { + Array input_shapes = VisitExpr(op->tuple); + Array out; + out.push_back(input_shapes[op->index]); + return out; + } + + private: + /*! \brief String stream for function name */ + std::ostringstream readable_name_stream_; + /*! \brief Map from parameter to its shape function usage state */ + std::unordered_map param_states_; + /*! \brief Map from parameter to list of data placeholder */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_data_; + /*! \brief Map from parameter to list of shape placeholder */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_shapes_; + /*! \brief Stack of data dependencies for shape function, specified per each op input */ + std::vector data_dependents_per_input_; + /*! \brief Scalars used in the shape function */ + Array scalars_; +}; + +CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, + std::function renamer) { + return MakeShapeFunc().Create(prim_func, target, renamer); +} + +/*! + * \brief Get unique name from name. + * \param name The orginal name. + * \return Updated name which is unique. + */ +std::string GetUniqueName(std::string name, std::unordered_map* name_map_) { + for (size_t i = 0; i < name.length(); ++i) { + if (name[i] == '.') name[i] = '_'; + } + while (true) { + auto it = name_map_->find(name); + if (it == name_map_->end()) { + (*name_map_)[name] = 1; + return name; + } else { + std::ostringstream os; + os << name << "_" << it->second; + ++(it->second); + name = os.str(); + } + } + return name; +} + +} // namespace tec +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h new file mode 100644 index 000000000000..1c7511ffd7d2 --- /dev/null +++ b/src/relay/backend/te_compiler_cache.h @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/tec_compiler_cache.h + * \brief Utilities for compiling tensor expressions inside of the Relay compiler. + */ +#ifndef TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ +#define TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../transforms/infer_layout_utils.h" + +namespace tvm { +namespace relay { +namespace tec { + +/*! \brief Indicate whether the data or shape or both of a parameter is used in the shape func. */ +enum ShapeFuncParamState { + kNoNeed = 0, + kNeedInputData = 1, + kNeedInputShape = 2, + kNeedBoth = 3, +}; + +struct LoweredOutputNode : public Object { + /*! \brief The outputs to the function */ + tvm::Array outputs; + /*! \brief The implementation used to compute the output */ + OpImplementation implementation; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("outputs", &outputs); + v->Visit("implementation", &implementation); + } + + static constexpr const char* _type_key = "relay.LoweredOutput"; + TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object); +}; + +class LoweredOutput : public ObjectRef { + public: + TVM_DLL LoweredOutput(tvm::Array outputs, OpImplementation impl); + + TVM_DEFINE_OBJECT_REF_METHODS(LoweredOutput, ObjectRef, LoweredOutputNode); +}; + +class CCacheKey; +/*! \brief Compile cache key */ +class CCacheKeyNode : public Object { + public: + /*! \brief The source function to be lowered. */ + Function source_func; + /*! \brief The hardware target.*/ + Target target; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("source_func", &source_func); + v->Visit("target", &target); + } + /*! \return The hash value of CCacheKey. */ + inline size_t Hash() const; + /*! + * \brief check content equality + * \param other The other value. + * \return The result of equality check. + */ + inline bool Equal(const CCacheKeyNode* other) const; + + static constexpr const char* _type_key = "relay.CCacheKey"; + TVM_DECLARE_FINAL_OBJECT_INFO(CCacheKeyNode, tvm::Object); + + private: + /*! + * \brief internal cached hash value. + */ + mutable size_t hash_{0}; +}; + +/*! \brief cache entry used in compile engine */ +class CCacheKey : public ObjectRef { + public: + CCacheKey() {} + explicit CCacheKey(ObjectPtr n) : ObjectRef(n) {} + + /*! + * \brief The constructor + * \param source_func The source function. + * \param target The target device. + */ + TVM_DLL CCacheKey(Function source_func, Target target); + + const CCacheKeyNode* operator->() const { return static_cast(get()); } + // comparator + inline bool operator==(const CCacheKey& other) const { + ICHECK(defined() && other.defined()); + return (*this)->Equal(other.operator->()); + } + using ContainerType = CCacheKeyNode; +}; + +/*! \brief Node container to represent a cached function. */ +struct CachedFuncNode : public Object { + /* \brief compiled target */ + tvm::Target target; + /*! \brief Primitive Function Name */ + GlobalVar prim_fn_var; + /* \brief The inputs to the function */ + tvm::Array inputs; + /* \brief The outputs to the function */ + tvm::Array outputs; + /*! \brief The schedule to the function */ + te::Schedule schedule; + /*! \brief Parameter usage states in the shape function. */ + tvm::Array shape_func_param_states; + /*! \brief The lowered functions to support the function. */ + IRModule funcs = IRModule(Map({})); + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("target", &target); + v->Visit("prim_fn_var", &prim_fn_var); + v->Visit("inputs", &inputs); + v->Visit("outputs", &outputs); + v->Visit("schedule", &schedule); + v->Visit("funcs", &funcs); + v->Visit("shape_func_param_states", &shape_func_param_states); + } + + static constexpr const char* _type_key = "relay.CachedFunc"; + TVM_DECLARE_FINAL_OBJECT_INFO(CachedFuncNode, Object); +}; + +class CachedFunc : public ObjectRef { + public: + CachedFunc(tvm::Target target, GlobalVar prim_fn_name, tvm::Array inputs, + tvm::Array outputs, te::Schedule schedule, + tvm::Array shape_func_param_states, + IRModule funcs = IRModule(Map({}))); + + public: + TVM_DEFINE_OBJECT_REF_METHODS(CachedFunc, ObjectRef, CachedFuncNode); +}; + +/*! \brief Node container for compile cache. */ +class CCacheValueNode : public Object { + public: + /*! \brief The corresponding function */ + CachedFunc cached_func; + /*! \brief Result of Packed function generated by JIT */ + PackedFunc packed_func; + /*! \brief usage statistics */ + int use_count{0}; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("cached_func", &cached_func); + v->Visit("use_count", &use_count); + } + static constexpr const char* _type_key = "relay.CCacheValue"; + TVM_DECLARE_FINAL_OBJECT_INFO(CCacheValueNode, tvm::Object); +}; + +/*! \brief cache entry used in compile engine */ +class CCacheValue : public ObjectRef { + public: + CCacheValue() {} + explicit CCacheValue(ObjectPtr n) : ObjectRef(n) {} + CCacheValueNode* operator->() { return static_cast(get_mutable()); } + const CCacheValueNode* operator->() const { return static_cast(get()); } + using ContainerType = CCacheValueNode; +}; + +Array GetShape(const Array& shape); + +/*! + * \brief Create schedule for target. + * \param source_func The primitive function to be lowered. + * \param target The target we want to create schedule for. + * \return Pair of schedule and cache. + * The funcs field in cache is not yet populated. + */ +CachedFunc PrimFuncFor(const Function& source_func, const Target& target, + std::function renamer); + +CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, + std::function renamer); + +std::string GetUniqueName(std::string name, std::unordered_map* name_map); + +// implementations +inline size_t CCacheKeyNode::Hash() const { + if (hash_ != 0) return hash_; + // do structral hash, avoid 0. + hash_ = tvm::StructuralHash()(this->source_func); + hash_ = dmlc::HashCombine(hash_, std::hash()(target->str())); + if (hash_ == 0) hash_ = 1; + return hash_; +} + +inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { + if (Hash() != other->Hash()) return false; + return this->target->str() == other->target->str() && + tvm::StructuralEqual()(this->source_func, other->source_func); +} + +} // namespace tec +} // namespace relay +} // namespace tvm + +namespace std { +// overload hash +template <> +struct hash<::tvm::relay::tec::CCacheKey> { + size_t operator()(const ::tvm::relay::tec::CCacheKey& key) const { + ICHECK(key.defined()); + return key->Hash(); + } +}; +} // namespace std + +#endif // TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 3ea15438fe8f..f0c543f1244b 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -39,6 +39,30 @@ StorageInfo::StorageInfo(std::vector storage_ids, std::vector ids; + for (auto id : si->storage_ids) { + ids.push_back(id); + } + return ids; +}); + +TVM_REGISTER_GLOBAL("relay.ir.StorageInfoDeviceTypes").set_body_typed([](StorageInfo si) { + Array device_types; + for (auto id : si->device_types) { + device_types.push_back(id); + } + return device_types; +}); + +TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageSizes").set_body_typed([](StorageInfo si) { + Array storage_sizes_in_bytes; + for (auto id : si->storage_sizes_in_bytes) { + storage_sizes_in_bytes.push_back(id); + } + return storage_sizes_in_bytes; +}); + TVM_REGISTER_NODE_TYPE(StaticMemoryPlanNode); StaticMemoryPlan::StaticMemoryPlan(Map expr_to_storage_info) { @@ -73,6 +97,29 @@ int64_t CalculateRelayExprSizeBytes(const Type& expr_type) { TVM_REGISTER_NODE_TYPE(FunctionInfoNode); +FunctionInfo::FunctionInfo(Map workspace_sizes, Map io_sizes, + Map constant_sizes, + Map tir_primfuncs, + Map relay_primfuncs) { + ObjectPtr n = make_object(); + n->workspace_sizes = std::move(workspace_sizes); + n->io_sizes = std::move(io_sizes); + n->constant_sizes = std::move(constant_sizes); + n->tir_primfuncs = std::move(tir_primfuncs); + n->relay_primfuncs = std::move(relay_primfuncs); + data_ = std::move(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "FunctionInfoNode(\n" + << "workspace_sizes=" << node->workspace_sizes << ",\n io_sizes=" << node->io_sizes + << ",\n constant_sizes=" << node->constant_sizes + << ",\n tir_primfuncs=" << node->tir_primfuncs + << ",\n relay_primfuncs=" << node->relay_primfuncs << ")"; + }); + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 7d7f026c298e..d2a173a43f46 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -114,6 +114,10 @@ struct FunctionInfoNode : public Object { class FunctionInfo : public ObjectRef { public: + FunctionInfo(Map workspace_sizes, Map io_sizes, + Map constant_sizes, Map tir_primfuncs, + Map relay_primfuncs); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(FunctionInfo, ObjectRef, FunctionInfoNode); }; diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index c50f2f65f949..96aa77f286a9 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -978,7 +978,7 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe // update primitive function map size_t primitive_index = 0; for (const auto& cfunc : context_.cached_funcs) { - exec_->primitive_map.insert({cfunc->func_name, primitive_index++}); + exec_->primitive_map.insert({cfunc->prim_fn_var->name_hint, primitive_index++}); } } @@ -1173,8 +1173,9 @@ void VMCompiler::Codegen() { if (target->kind->device_type == kDLExtDev) { // Collect metadata in functions that are handled by external codegen. - ICHECK(mod->ContainGlobalVar(cfunc->func_name)); - Function func = Downcast(mod->Lookup(cfunc->func_name)); + auto name = cfunc->prim_fn_var->name_hint; + ICHECK(mod->ContainGlobalVar(name)); + Function func = Downcast(mod->Lookup(name)); backend::UpdateConstants(func, ¶ms_); } else if (funcs.count(target) == 0) { funcs.Set(target, mod); diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index c9920a621b56..83ac55fce085 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -62,9 +62,17 @@ TVM_REGISTER_GLOBAL("relay.ir.Function") TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << node->body - << ", " << node->type_params << ", " << node->attrs << ")"; + // TODO(@jroesch): previously this had a debug printer, the debug printer + // can cause exponential behavior and is currently dangerous, for these + // cases we need some kind of de-duping. + // + // See old implementation: + // + // auto* node = static_cast(ref.get()); + // p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << + // node->body + // << ", " << node->type_params << ", " << node->attrs << ")"; + p->stream << PrettyPrint(ref); }); } // namespace relay diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index da0bd35a332a..7a86af8aeffa 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -126,7 +126,7 @@ Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function."; (*f)(); - CreateSchedule(GetRef(func), Target::Current()); + PrimFuncFor(GetRef(func), Target::Current(), [](std::string name) { return name; }); f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite"); CHECK(f) << "Could not find ansor.exit_layout_rewrite function."; diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index e744fb51e0a6..02f9d474411a 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -53,7 +53,18 @@ bool IsOnDeviceNode(const ExprNode* node) { bool IsDeviceCopyNode(const ExprNode* node) { if (!node->IsInstance()) return false; const auto* call_node = static_cast(node); - return call_node->attrs.as(); + + if (call_node->attrs.as()) { + return true; + } + + auto tir_call_attrs = call_node->attrs.as(); + if (tir_call_attrs) { + auto metadata = tir_call_attrs->metadata; + return metadata.count("source_device") == 1 && metadata.count("dst_device") == 1; + } + + return false; } } // namespace @@ -395,16 +406,31 @@ class DeviceInfo { const auto* call_node = static_cast(node); auto attrs = call_node->attrs.as(); - num_device_copy_ops_++; - dev_type_ = attrs->src_dev_type; - for (auto& arg : call->args) { - Visit(arg); - // restore the type for remaining arguments + if (attrs) { + num_device_copy_ops_++; dev_type_ = attrs->src_dev_type; + for (auto& arg : call->args) { + Visit(arg); + // restore the type for remaining arguments + dev_type_ = attrs->src_dev_type; + } + device_tag_[call] = attrs->dst_dev_type; + // update the out_dev_type_, which should be the dst_dev_type of last copy + out_dev_type_ = attrs->dst_dev_type; + } else { + auto attrs = call_node->attrs.as(); + CHECK(attrs) << "must be non-null"; + num_device_copy_ops_++; + dev_type_ = Downcast(attrs->metadata["source_device"]); + for (auto& arg : call->args) { + Visit(arg); + // restore the type for remaining arguments + dev_type_ = Downcast(attrs->metadata["source_device"]); + } + device_tag_[call] = Downcast(attrs->metadata["dst_device"]); + // update the out_dev_type_, which should be the dst_dev_type of last copy + out_dev_type_ = Downcast(attrs->metadata["dst_device"]); } - device_tag_[call] = attrs->dst_dev_type; - // update the out_dev_type_, which should be the dst_dev_type of last copy - out_dev_type_ = attrs->dst_dev_type; } else { for (auto& arg : call->args) { int cur_dev_type = dev_type_; diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 03473b7d7455..b61567d0bae0 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -43,6 +44,7 @@ #include "../backend/compile_engine.h" #include "../op/memory/memory.h" #include "../op/vm/vm.h" +#include "./pass_utils.h" #include "let_list.h" #include "pattern_utils.h" @@ -66,9 +68,18 @@ inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dt // Check if the primitive function contains only reshape ops. bool IsReshapeOnly(const Expr& expr) { - if (auto* func = expr.as()) { + if (const FunctionNode* func = expr.as()) { return func->HasNonzeroAttr(attr::kReshapeOnly); } + if (const CallNode* call = expr.as()) { + if (call->attrs.defined()) { + if (auto tir_call_attrs = call->attrs.as()) { + Map metadata = tir_call_attrs->metadata; + return metadata.count(attr::kReshapeOnly) && + (Downcast(metadata[attr::kReshapeOnly])->value == 1); + } + } + } return false; } diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 4c6013792426..f29087dcc049 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -205,8 +205,13 @@ class TypeInferencer : private ExprFunctor, this->EmitFatal(Diagnostic::Error(op->span) << "Cannot do type inference on global variables " << "without a module"); } - relay::Function e = Downcast(mod_->Lookup(var)); - return e->checked_type(); + + if (mod_->ContainGlobalVar(var->name_hint)) { + relay::Function e = Downcast(mod_->Lookup(var)); + return e->checked_type(); + } else { + return op->checked_type_; + } } Type VisitExpr_(const ConstantNode* op) final { return op->tensor_type(); } diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 24fb3dc95819..15a1493b8585 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -223,8 +223,12 @@ class LLVMModuleNode final : public runtime::ModuleNode { found_linked_params = true; continue; } - ICHECK(kv.second->IsInstance()) - << "Can only lower IR Module with PrimFuncs, but got " << kv.second->GetTypeKey(); + if (!kv.second->IsInstance()) { + // (@jroesch): we relax constraints here, Relay functions will just be ignored. + DLOG(INFO) << "Can only lower IR Module with PrimFuncs, but got " + << kv.second->GetTypeKey(); + continue; + } auto f = Downcast(kv.second); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()); @@ -234,7 +238,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { } funcs.push_back(f); } - ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params)); + // TODO(@jroesch): follow up on this condition. + // ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params)); // TODO(tqchen): remove the entry function behavior as it does not // makes sense when we start to use multiple modules. cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime); diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc index 424da1e817b6..cb2b50260326 100644 --- a/src/tir/transforms/legalize_packed_calls.cc +++ b/src/tir/transforms/legalize_packed_calls.cc @@ -109,7 +109,7 @@ Pass LegalizePackedCalls() { inputs[i] = true; } n->body = PackedCallLegalizer().Legalize(inputs, std::move(n->body)); - return std::move(f); + return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LegalizePackedCalls", {}); } diff --git a/tests/python/relay/test_auto_scheduler_task_extraction.py b/tests/python/relay/test_auto_scheduler_task_extraction.py index cfbca40cf379..0acd8b0c87d6 100644 --- a/tests/python/relay/test_auto_scheduler_task_extraction.py +++ b/tests/python/relay/test_auto_scheduler_task_extraction.py @@ -101,6 +101,7 @@ def test_task_extraction_cuda(): mod, params = get_network("mlp") tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) + assert len(tasks) == 1 assert sum(task_weights) == 2 diff --git a/tests/python/relay/test_backend_graph_executor.py b/tests/python/relay/test_backend_graph_executor.py index 4ec1c21467fc..e7040f55f631 100644 --- a/tests/python/relay/test_backend_graph_executor.py +++ b/tests/python/relay/test_backend_graph_executor.py @@ -130,22 +130,22 @@ def test_plan_memory(): mod = relay.transform.FuseOps(0)(mod) func = mod["main"] mod = relay.transform.InferType()(mod) - smap = relay.backend._backend.GraphPlanMemory(func) + memory_plan = relay.backend._backend.GraphPlanMemory(func) storage_ids = set() device_types = set() storage_sizes = {} - for k, v in smap.items(): - assert len(v) == 3 - for x in v[0]: - storage_ids.add(x.value) - storage_sizes[x.value] = v[2] - for x in v[1]: - device_types.add(x.value) + + for k, v in memory_plan.expr_to_storage_info.items(): + for x in v.storage_ids: + storage_ids.add(x) + storage_sizes[x] = v.storage_sizes + for x in v.device_types: + device_types.add(x) # Current rule requires vars have unique storage id # because we don't do inplace, we will need another # two alternating temporary space. - assert len(storage_ids) == 4 + assert len(storage_ids) == 4, f"found storage_ids: {storage_ids}" assert len(device_types) == 1 assert len(storage_sizes) == 4 @@ -288,11 +288,4 @@ def test_graph_executor_nested_tuples(): if __name__ == "__main__": - test_reshape_nop() - test_plan_memory() - test_with_params() - test_add_op_scalar() - test_add_op_tensor() - test_add_op_broadcast() - test_gru_like() - test_compile_nested_tuples() + sys.exit(pytest.main([file] + sys.argv[1:])) diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index f0949ab19f9c..c33bd5792242 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -42,14 +42,14 @@ def check_graph_executor( target, ref_res, device, func, params, config, opt_level, expected_index=None ): with tvm.transform.PassContext(opt_level=opt_level, config=config): - graph, lib, new_params = relay.build(func, target, params=params) + graph_executor_factory = relay.build(func, target, params=params) + contexts = [tvm.cpu(0), tvm.device(device)] - graph_json = json.loads(graph) + graph_json = json.loads(graph_executor_factory.graph_json) if "device_index" in graph_json["attrs"]: device_index = graph_json["attrs"]["device_index"][1] assert device_index == expected_index - mod = graph_executor.create(graph, lib, contexts) - mod.set_input(**new_params) + mod = graph_executor.GraphModule(graph_executor_factory["default"](*contexts)) mod.run() res = mod.get_output(0).numpy() tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) @@ -272,12 +272,14 @@ def check_storage_and_device_types(): smap = relay.backend._backend.GraphPlanMemory(func) storage_ids = [] device_types = [] - for _, storage_dev_type in smap.items(): - assert len(storage_dev_type) == 3 - for sid in storage_dev_type[0]: + for _, storage_info in smap.expr_to_storage_info.items(): + + for sid in storage_info.storage_ids: storage_ids.append(sid.value) - for did in storage_dev_type[1]: + + for did in storage_info.device_types: device_types.append(did.value) + assert len(storage_ids) == 10 assert len(set(storage_ids)) == 8 assert len(set(device_types)) == 2 @@ -350,16 +352,16 @@ def expected(): assert tvm.ir.structural_equal(annotated_expr, expected_expr) smap = relay.backend._backend.GraphPlanMemory(annotated_expr) - for expr, storage_dev_type in smap.items(): + for expr, storage_info in smap.expr_to_storage_info.items(): # x is dev1 as output is dev1 if isinstance(expr, tvm.relay.expr.Var): - assert storage_dev_type[1][0] == dev1.device_type + assert storage_info.device_types[0] == dev1.device_type else: # device_copy op should be its dst_dev_type if isinstance(expr.attrs, tvm.relay.op.op_attrs.DeviceCopyAttrs): - assert storage_dev_type[1][0] == expr.attrs.dst_dev_type + assert storage_info.device_types[0] == expr.attrs.dst_dev_type else: - assert storage_dev_type[1][0] == expected_dev_type[expr.op.name].device_type + assert storage_info.device_types[0] == expected_dev_type[expr.op.name].device_type def run_fusible_network(dev, tgt):