diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index 4a2eb63c7e6a..8379e6471561 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -67,6 +67,18 @@ struct CompilerAttrs : public tvm::AttrsNode { } }; +/*! + * \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR. + */ +struct TIRCallAttrs : public tvm::AttrsNode { + /*! \brief The metadata attached to the call node. */ + Map metadata; + + TVM_DECLARE_ATTRS(TIRCallAttrs, "relay.attrs.TIRCallAttrs") { + TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function call."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_ANNOTATION_H_ diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 0d18bc08e5ed..4b402a916267 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -318,6 +318,7 @@ def auto_schedule_topi(func_name, outs): A tuned schedule or none (if not tuned) in the final build mode; None in the tracing mode so that the fallback topi schedule will be used. """ + # pylint: disable=import-outside-toplevel from tvm.auto_scheduler.measure import ( prepare_input_map, @@ -376,6 +377,15 @@ def auto_schedule_topi(func_name, outs): return schedule +@tvm._ffi.register_func("auto_scheduler.relay_integration.te_compiler_update_weights") +def te_compiler_update_weights(function_weights): + """A callback for updating the weights of extracted tasks.""" + env = TracingEnvironment.current + if env is not None: + for key in env.wkl_key_to_weight: + env.wkl_key_to_weight[key] = function_weights[key[0]] + + def tensor_no_check_call(self, *indices): """An indexing function without any check. This is the same as `tvm.te.Tensor::__call__` except that the safety diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py index dd5073331083..023fdc770a30 100644 --- a/python/tvm/auto_scheduler/task_scheduler.py +++ b/python/tvm/auto_scheduler/task_scheduler.py @@ -598,7 +598,7 @@ def pre_tune(self, task_scheduler, task_id): # overall info if all(cost < 1e9 for cost in task_scheduler.best_costs): - total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3) + total_latency_str = "%.3f" % (task_scheduler.cur_score.value * 1e3) else: total_latency_str = "-" print( diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 2db8c5a669f0..e9129db7b200 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -429,7 +429,7 @@ def dump(self): res += "------------------------------------\n" res += "target={}\n".format(k.target) res += "use_count={}\n".format(v.use_count) - res += "func_name={}\n".format(v.cached_func.func_name) + res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint) res += "----relay function----\n" res += k.source_func.astext() + "\n" res += "----tir function----- \n" @@ -444,7 +444,7 @@ def dump(self): res += "------------------------------------\n" res += "target={}\n".format(k.target) res += "use_count={}\n".format(v.use_count) - res += "func_name={}\n".format(v.cached_func.func_name) + res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint) res += "----relay function----\n" res += k.source_func.astext() + "\n" res += "----tir function----- \n" diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 8d73a090ed6f..8461885b38ce 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -23,7 +23,7 @@ import tvm._ffi from tvm._ffi import base as _base from tvm.runtime import NDArray, ndarray as _nd -from tvm.ir import RelayExpr, GlobalVar +from tvm.ir import RelayExpr, GlobalVar, Node from .base import RelayNode from . import _ffi_api @@ -538,3 +538,25 @@ def bind(expr, binds): The expression or function after binding. """ return _ffi_api.Bind(expr, binds) + + +@tvm._ffi.register_object("relay.StorageInfo") +class StorageInfo(Node): + """StorageInfo + + The static storage information produced by memory planning. + Contains the storage ids where expressions are stored, the + type of the "virtual devices" the expressions are stored on, + and the sizes of each storage element.""" + + @property + def storage_ids(self): + return _ffi_api.StorageInfoStorageIds(self) + + @property + def device_types(self): + return _ffi_api.StorageInfoDeviceTypes(self) + + @property + def storage_sizes(self): + return _ffi_api.StorageInfoStorageSizes(self) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index cd8173717d5f..50f00140df9b 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -437,14 +437,18 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target } if (target->kind->device_type == kDLCPU && target_host == target) { - ICHECK(mdevice->functions.empty()) << "No device code should be generated when target " - << "and host_target are both llvm target." - << "\n"; + // TODO(@jroesch): This check is no longer true we need to figure out if we care about this. + // We need to relax this check for just TIR functions. + // ICHECK(mdevice->functions.empty()) << "No device code should be generated when target " + // << "and host_target are both llvm target." + // << "\n"; } return {mhost, mdevice}; } +// Can we make this take one annotated IRModule? +// // Build for heterogeneous execution. runtime::Module build(const Map& inputs_arg, const Target& target_host_arg) { auto pass_ctx = transform::PassContext::Current(); diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 9b495adbdea8..84c17b53c83e 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -439,7 +439,7 @@ class AOTExecutorCodegen : public ExprVisitor { fi_node->tir_primfuncs.Set(primfunc_target, primfunc); fi_node->relay_primfuncs.Set(primfunc_target, relay_func); } - function_metadata_.Set(cfunc->func_name, FunctionInfo(fi_node)); + function_metadata_.Set(cfunc->prim_fn_var->name_hint, FunctionInfo(fi_node)); } void VisitExpr_(const CallNode* op) override { @@ -465,20 +465,18 @@ class AOTExecutorCodegen : public ExprVisitor { << "(i.e functions composed of fusable operator invocations)"; } - auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey"); - auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); Target target; // Handle external function if (func->GetAttr(attr::kCompiler).defined()) { target = Target("ext_dev"); - CCacheKey key = (*pf0)(func, target); - CachedFunc ext_func = (*pf1)(compile_engine_, key, mod_name_); + CCacheKey key = CCacheKey(func, target); + CachedFunc ext_func = compile_engine_->Lower(key, mod_name_); ICHECK(ext_func.defined()) << "External function is not defined."; UpdateConstants(func, ¶ms_); // Generate the TIR function call - CreateFuncCall(GetRef(op), ext_func->func_name); + CreateFuncCall(GetRef(op), ext_func->prim_fn_var->name_hint); return; } @@ -503,8 +501,10 @@ class AOTExecutorCodegen : public ExprVisitor { } target = targets_[call_dev_type]; } - CCacheKey key = (*pf0)(func, target); - CachedFunc lowered_func = (*pf1)(compile_engine_, key, mod_name_); + + CCacheKey key = CCacheKey(func, target); + CachedFunc lowered_func = compile_engine_->Lower(key, mod_name_); + if (!lowered_funcs_.count(target->str())) { lowered_funcs_[target->str()] = IRModule(Map({})); } @@ -513,7 +513,7 @@ class AOTExecutorCodegen : public ExprVisitor { UpdateFunctionMetadata(lowered_func, func, target); // Generate the TIR function call - CreateFuncCall(GetRef(op), lowered_func->func_name); + CreateFuncCall(GetRef(op), lowered_func->prim_fn_var->name_hint); } void VisitExpr_(const VarNode* op) override { diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index f0b43b14c650..6142e8323dea 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -46,569 +46,14 @@ #include "../../runtime/meta_data.h" #include "../transforms/pass_utils.h" +#include "te_compiler_cache.h" #include "utils.h" namespace tvm { namespace relay { -TVM_REGISTER_NODE_TYPE(LoweredOutputNode); -TVM_REGISTER_NODE_TYPE(CachedFuncNode); -TVM_REGISTER_NODE_TYPE(CCacheKeyNode); -TVM_REGISTER_NODE_TYPE(CCacheValueNode); TVM_REGISTER_OBJECT_TYPE(CompileEngineNode); -LoweredOutput::LoweredOutput(tvm::Array outputs, OpImplementation impl) { - auto n = make_object(); - n->outputs = std::move(outputs); - n->implementation = std::move(impl); - data_ = std::move(n); -} - -CCacheKey::CCacheKey(Function source_func, Target target) { - auto n = make_object(); - n->source_func = std::move(source_func); - n->target = std::move(target); - data_ = std::move(n); -} - -Array GetShape(const Array& shape) { - // for now, we always use int32 shape when possible - // even if the result of shape inference becomes int64. - Array res; - for (IndexExpr val : shape) { - const int64_t* pval = tir::as_const_int(val); - if (pval != nullptr) { -#ifndef TVM_INDEX_DEFAULT_I64 - ICHECK_LE(pval[0], std::numeric_limits::max()); - ICHECK_GE(pval[0], std::numeric_limits::min()); - res.push_back(IntImm(DataType::Int(32), *pval)); -#else - res.push_back(val); -#endif // TVM_INDEX_DEFAULT_I64 - } else if (val->IsInstance()) { - res.push_back(val.as()->ToVar()); - } else { - res.push_back(val); - } - } - return res; -} - -// The getter to get schedule from compile engine. -// Get schedule from functor. -class ScheduleGetter : public backend::MemoizedExprTranslator> { - public: - explicit ScheduleGetter(Target target) - : target_(target), device_copy_op_(Op::Get("device_copy")) { - // Whether to use auto_scheduler schedule. - use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); - } - - CachedFunc Create(const Function& prim_func) { - auto cache_node = make_object(); - cache_node->target = target_; - for (Var param : prim_func->params) { - Array inputs; - if (const auto* ttype = param->checked_type().as()) { - tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - cache_node->inputs.push_back(tensor); - inputs.push_back(tensor); - } else { - // flatten tuple of tensor type. - const auto* tuple_type = param->type_as(); - for (Type field : tuple_type->fields) { - const auto* ttype = field.as(); - // TODO(@icemelon): Allow recursive tuple - ICHECK(ttype != nullptr); - tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - cache_node->inputs.push_back(tensor); - inputs.push_back(tensor); - } - } - memo_[param] = inputs; - } - readable_name_stream_ << "fused"; - cache_node->outputs = this->VisitExpr(prim_func->body); - auto candidate_name = readable_name_stream_.str(); - constexpr static size_t kMaxFuncNameLength = 80; - if (candidate_name.size() > kMaxFuncNameLength) { - std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); - truncated_name << "_" << std::hash{}(candidate_name) << "_"; - candidate_name = truncated_name.str(); - } - cache_node->func_name = candidate_name; - ICHECK(anchor_op_.defined()); - // Fusion over tupled results may leave identity relationships - // between inputs and outputs, and those should not be scheduled. - // Hence schedule only non PlaceholderOp outputs. - tvm::Array tensor_outs; - for (const auto& tensor : cache_node->outputs) { - if (!tensor->op.as()) { - tensor_outs.push_back(tensor); - } - } - - te::Schedule schedule; - // No need to register schedule for device copy op. - if (anchor_attrs_.as() == nullptr) { - if (use_auto_scheduler_) { - const auto* fauto_schedule = - runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); - ICHECK(fauto_schedule != nullptr) - << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; - ObjectRef obj = (*fauto_schedule)(String(cache_node->func_name), tensor_outs); - if (obj.defined()) { - schedule = Downcast(obj); - } - } - - // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. - if (!schedule.defined()) { - ICHECK(anchor_implementation_.defined()); - schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); - } - for (const auto& scalar : scalars_) { - if (schedule->Contain(scalar)) { - schedule[scalar].compute_inline(); - } - } - } - cache_node->schedule = std::move(schedule); - return CachedFunc(cache_node); - } - - Array VisitExpr_(const VarNode* op) final { - LOG(FATAL) << "Free variable " << op->name_hint(); - return {}; - } - - Array VisitExpr_(const ConstantNode* op) final { - using tir::make_const; - ICHECK(op->is_scalar()); - void* data = op->data->data; - DataType dtype = DataType(op->data->dtype); - auto value = te::compute( - {}, - [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, - "compile_engine_const", topi::kBroadcast); - scalars_.push_back(value->op); - return {value}; - } - - Array VisitExpr_(const CallNode* call_node) final { - static auto fpattern = Op::GetAttrMap("TOpPattern"); - static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); - ICHECK(flower_call) << "relay.backend.lower_call is not registered."; - - Array inputs; - int count_tuple = 0; - for (Expr arg : call_node->args) { - if (arg->checked_type().as()) { - ++count_tuple; - } - for (te::Tensor tensor : VisitExpr(arg)) { - inputs.push_back(tensor); - } - } - if (count_tuple) { - ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; - } - - ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; - Op op = Downcast(call_node->op); - - Array outputs; - OpImplementation impl; - // Skip fcompute for device copy operators as it is not registered. - if (op == device_copy_op_) { - const auto* copy_input = inputs[0].operator->(); - outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0)); - } else { - LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); - outputs = lowered_out->outputs; - impl = lowered_out->implementation; - } - - int op_pattern = fpattern[op]; - if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { - ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) - << "Cannot apply TOPI schedule to a primitive function with two complicated ops" - << " anchor=" << anchor_op_ << " current=" << op; - } - if (op_pattern > anchor_op_pattern_) { - anchor_op_ = op; - anchor_attrs_ = call_node->attrs; - anchor_op_pattern_ = op_pattern; - anchor_implementation_ = impl; - } - if (outputs.size() != 1) { - const auto* tuple_type = call_node->checked_type().as(); - ICHECK(tuple_type) << "Expect output to be a tuple type"; - ICHECK_EQ(tuple_type->fields.size(), outputs.size()); - } - // Set the name to `__copy`. It will be detected in graph executor to perform - // data copy across devices. - if (op == device_copy_op_) { - readable_name_stream_.str(std::string()); - readable_name_stream_ << "__copy"; - } else { - readable_name_stream_ << '_' << op->name; - } - return outputs; - } - - Array VisitExpr_(const FunctionNode* op) final { - LOG(FATAL) << "Do not support sub function"; - return Array(); - } - - Array VisitExpr_(const LetNode* op) final { - Array val = VisitExpr(op->value); - ICHECK(!memo_.count(op->var)); - memo_[op->var] = val; - return VisitExpr(op->body); - } - - Array VisitExpr_(const TupleNode* op) final { - Array fields; - for (Expr field : op->fields) { - ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; - Array res = VisitExpr(field); - ICHECK_EQ(res.size(), 1); - fields.push_back(res[0]); - } - return fields; - } - - Array VisitExpr_(const TupleGetItemNode* op) final { - const auto* tuple_type = op->tuple->type_as(); - Array tuple = VisitExpr(op->tuple); - ICHECK_EQ(tuple_type->fields.size(), tuple.size()); - ICHECK_GE(op->index, 0); - ICHECK_LT(static_cast(op->index), tuple.size()); - return {tuple[op->index]}; - } - - private: - tvm::Target target_; - Op anchor_op_; - Attrs anchor_attrs_; - int anchor_op_pattern_{-1}; - OpImplementation anchor_implementation_; - std::ostringstream readable_name_stream_; - Array scalars_; - bool use_auto_scheduler_; - // Cache device copy op for equivalence checking to reduce registry lookup - // overhead for each invocation of call node when retrieving schedules. - const Op& device_copy_op_; -}; - -/*! - * \brief Create schedule for target. - * \param source_func The primitive function to be lowered. - * \param target The target we want to create schedule for. - * \return Pair of schedule and cache. - * The funcs field in cache is not yet populated. - */ -CachedFunc CreateSchedule(const Function& source_func, const Target& target) { - return ScheduleGetter(target).Create(source_func); -} - -// Creates shape function from functor. -class MakeShapeFunc : public backend::MemoizedExprTranslator> { - public: - MakeShapeFunc() {} - - std::pair Create(const Function& prim_func) { - for (auto param : prim_func->params) { - param_states_[param] = kNoNeed; - Array data_inputs; - Array shape_inputs; - - auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) { - // Add data placeholder - Shape shape = GetShape(ttype->shape); - tvm::te::Tensor data_tensor = tvm::te::placeholder(shape, ttype->dtype); - data_inputs.push_back(data_tensor); - // Add shape placeholder - int64_t ndim = shape.size(); - Shape sshape; - if (ndim > 0) { - sshape.push_back(tvm::Integer(ndim)); - } - tvm::te::Tensor shape_tensor = tvm::te::placeholder(sshape, DataType::Int(64)); - shape_inputs.push_back(shape_tensor); - }; - - if (const auto* ttype = param->checked_type().as()) { - add_placeholder(ttype); - } else { - // flatten tuple of tensor type. - const auto* tuple_type = param->type_as(); - // TODO(@icemelon): Support recursive tuple - ICHECK(tuple_type); - for (Type field : tuple_type->fields) { - const auto* ttype = field.as(); - ICHECK(ttype); - add_placeholder(ttype); - } - } - param_data_[param] = data_inputs; - param_shapes_[param] = shape_inputs; - } - readable_name_stream_ << "shape_func"; - auto cache_node = make_object(); - cache_node->outputs = VisitExpr(prim_func->body); - auto candidate_name = readable_name_stream_.str(); - constexpr static size_t kMaxFuncNameLength = 80; - if (candidate_name.size() > kMaxFuncNameLength) { - std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); - truncated_name << "_" << std::hash{}(candidate_name) << "_"; - candidate_name = truncated_name.str(); - } - cache_node->func_name = candidate_name; - - // set inputs - for (auto param : prim_func->params) { - int state = param_states_[param]; - cache_node->shape_func_param_states.push_back(IntImm(DataType::Int(32), state)); - if (state & kNeedInputData) { - for (auto t : param_data_[param]) { - cache_node->inputs.push_back(t); - } - } - if (state & kNeedInputShape) { - for (auto t : param_shapes_[param]) { - cache_node->inputs.push_back(t); - } - } - } - - CachedFunc cfunc(cache_node); - // generate schedule for shape func - Array out_ops; - for (auto t : cache_node->outputs) { - out_ops.push_back(t->op); - } - auto schedule = te::create_schedule(out_ops); - tvm::te::AutoInlineInjective(schedule); - for (const auto& scalar : scalars_) { - auto scalar_op = scalar->op; - if (schedule->Contain(scalar_op)) { - schedule[scalar_op].compute_inline(); - } - } - return std::make_pair(schedule, cfunc); - } - - Array VisitExpr(const Expr& expr) final { - if (expr.as()) { - // Do not memoize vars because shape functions could use either the data - // or the shape of a var each time. - return ExprFunctor::VisitExpr(expr); - } - // For other case, do memoized visit - return backend::MemoizedExprTranslator>::VisitExpr(expr); - } - - Array VisitExpr_(const VarNode* var_node) final { - auto var = GetRef(var_node); - auto it = param_states_.find(var); - if (it == param_states_.end()) { - LOG(FATAL) << "Free variable " << var->name_hint(); - return {}; - } else { - ICHECK(data_dependents_per_input_.size()); - auto data_dependent = data_dependents_per_input_.back(); - if (data_dependent) { - param_states_[var] |= kNeedInputData; - return param_data_[var]; - } else { - param_states_[var] |= kNeedInputShape; - return param_shapes_[var]; - } - } - } - - Array VisitExpr_(const ConstantNode* op) final { - using tir::make_const; - ICHECK(data_dependents_per_input_.size()); - bool data_dependent = data_dependents_per_input_.back(); - if (!op->is_scalar()) { - // This is a constant weight, extract the shape of the weight tensor. - // This can not be data dependent. - CHECK(!data_dependent); - auto ttype = op->checked_type().as(); - int ndim = static_cast(ttype->shape.size()); - Array out_shape{ndim}; - te::Tensor value = tvm::te::compute( - out_shape, - [&](const Array& indices) { - auto idx = indices[0]; - PrimExpr ret = make_const(DataType::Int(64), 0); - for (int i = 0; i < ndim; i++) { - ret = tvm::if_then_else(idx == i, ttype->shape[i], ret); - } - return ret; - }, - "shape_const", topi::kBroadcast); - scalars_.push_back(value); - return {value}; - } - if (data_dependent) { - void* data = op->data->data; - DataType dtype = DataType(op->data->dtype); - auto value = tvm::te::compute( - {}, - [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, - "data_const", topi::kBroadcast); - scalars_.push_back(value); - return {value}; - } else { - auto value = tvm::te::compute( - {}, [&](const Array&) { return tir::make_const(DataType::Int(64), 0); }, - "shape_const", topi::kBroadcast); - scalars_.push_back(value); - return {value}; - } - } - - Array VisitExpr_(const CallNode* call_node) final { - static auto fshape_func = Op::GetAttrMap("FShapeFunc"); - static auto tshape_data_dependent = Op::GetAttrMap("TShapeDataDependent"); - ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; - Op op = Downcast(call_node->op); - ICHECK(data_dependents_per_input_.empty() || !data_dependents_per_input_.back()) - << "Error in op fusion: output of the shape func is fed to a " - << "data-dependent shape func"; - ICHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name; - ICHECK_GT(tshape_data_dependent.count(op), 0) - << "Internal error, cannot find TShapeDataDependent for " << op->name; - - Array dep_spec = tshape_data_dependent[op]; - if (dep_spec.size() == 1) { - // This is for cases when data dependence is specified per op - // Replicate 0 or 1 flag to all arguments - for (size_t i = 1; i < call_node->args.size(); ++i) { - dep_spec.push_back(dep_spec[0]); - } - } - - // Visit all inputs - Array inputs; - int count_tuple = 0; - for (size_t i = 0; i < call_node->args.size(); ++i) { - Expr arg = call_node->args[i]; - if (arg->checked_type().as()) { - ++count_tuple; - } - data_dependents_per_input_.push_back(dep_spec[i]->value != 0); - for (te::Tensor tensor : VisitExpr(arg)) { - inputs.push_back(tensor); - } - data_dependents_per_input_.pop_back(); - } - if (count_tuple) { - ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; - } - // Get output ndims - auto ret_type = call_node->checked_type(); - Array out_ndims; - if (const auto* ttype = ret_type.as()) { - out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); - } else { - auto rtype = ret_type.as(); - // TODO(@icemelon): Allow recursive tuple - ICHECK(rtype); - for (size_t i = 0; i < rtype->fields.size(); ++i) { - auto ttype = rtype->fields[i].as(); - ICHECK(ttype); - out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); - } - } - // Call shape function - auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims); - readable_name_stream_ << "_" << op->name; - return outputs; - } - - Array VisitExpr_(const FunctionNode* op) final { - LOG(FATAL) << "Do not support sub function"; - return Array(); - } - - Array VisitExpr_(const LetNode* op) final { - Array val = VisitExpr(op->value); - ICHECK(!memo_.count(op->var)); - memo_[op->var] = val; - return VisitExpr(op->body); - } - - Array VisitExpr_(const TupleNode* op) final { - Array fields; - for (Expr field : op->fields) { - ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; - Array res = VisitExpr(field); - ICHECK_EQ(res.size(), 1); - fields.push_back(res[0]); - } - return fields; - } - - Array VisitExpr_(const TupleGetItemNode* op) final { - Array input_shapes = VisitExpr(op->tuple); - Array out; - out.push_back(input_shapes[op->index]); - return out; - } - - private: - /*! \brief String stream for function name */ - std::ostringstream readable_name_stream_; - /*! \brief Map from parameter to its shape function usage state */ - std::unordered_map param_states_; - /*! \brief Map from parameter to list of data placeholder */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_data_; - /*! \brief Map from parameter to list of shape placeholder */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_shapes_; - /*! \brief Stack of data dependencies for shape function, specified per each op input */ - std::vector data_dependents_per_input_; - /*! \brief Scalars used in the shape function */ - Array scalars_; -}; - class CompileEngineImpl : public CompileEngineNode { public: // Lower the function. @@ -616,19 +61,19 @@ class CompileEngineImpl : public CompileEngineNode { return LowerInternal(key, mangle_fn)->cached_func; } + CachedFunc Lower(const CCacheKey& key, const String mod_name) { + auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); }; + + return Lower(key, mangle_fn); + } + // For now, build one module per function. PackedFunc JIT(const CCacheKey& key) final { auto mangle_fn = [](String name) { return name; }; CCacheValue value = LowerInternal(key, mangle_fn); if (value->packed_func != nullptr) return value->packed_func; - // build the function. - tvm::runtime::Module m; - if (const auto* f = runtime::Registry::Get("relay.backend.build")) { - m = (*f)(value->cached_func->funcs, key->target); - } else { - m = build(value->cached_func->funcs, key->target, Target(nullptr)); - } - value->packed_func = m.GetFunction(value->cached_func->func_name); + auto m = build(value->cached_func->funcs, key->target, Target(nullptr)); + value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint); return value->packed_func; } @@ -643,6 +88,7 @@ class CompileEngineImpl : public CompileEngineNode { for (const auto& it : cache_) { auto src_func = it.first->source_func; ICHECK(src_func.defined()); + if (src_func->GetAttr(attr::kCompiler).defined()) { auto code_gen = src_func->GetAttr(attr::kCompiler); ICHECK(code_gen.defined()) << "No external codegen is set"; @@ -651,7 +97,9 @@ class CompileEngineImpl : public CompileEngineNode { auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" - << AsText(src_func, false); + << AsText(src_func, false) << "\n" + << "Functions with external codegen must have the " + << tvm::attr::kGlobalSymbol << " attr set."; std::string sn = symbol_name.value(); if (!cached_symbol.count(sn)) { @@ -669,7 +117,12 @@ class CompileEngineImpl : public CompileEngineNode { src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue()); runtime::Module ext_mod = (*pf)(src_func); - ICHECK(ext_mod.defined()) << "No external runtime is generated."; + // todo(@zhiics, @jroesch): Should this be a user visible error? + ICHECK(ext_mod.defined()) << "No external library was generated for " << ext_name + << "even though it was requested" + "by the annotated function " + << PrettyPrint(src_func); + ret.push_back(ext_mod); } } @@ -734,44 +187,49 @@ class CompileEngineImpl : public CompileEngineNode { // No need to lower external functions for now. We will invoke the external // codegen tool once and lower all functions together. if (key->source_func->GetAttr(attr::kCompiler).defined()) { - auto cache_node = make_object(); + auto ir_module = IRModule(); const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "External function has not been attached a name yet."; - cache_node->func_name = std::string(name_node.value()); - cache_node->target = Target("ext_dev"); - cache_node->funcs->Add(GlobalVar(cache_node->func_name), key->source_func); - value->cached_func = CachedFunc(cache_node); + auto func_name = std::string(name_node.value()); + auto target = Target("ext_dev"); + auto global_var = GlobalVar(func_name); + global_var->checked_type_ = key->source_func->checked_type(); + ir_module->Add(global_var, key->source_func); + value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); return value; } + // Enforce use the target. With target_scope(key->target); ICHECK(!value->cached_func.defined()); - auto cfunc = CreateSchedule(key->source_func, key->target); - auto cache_node = make_object(*(cfunc.operator->())); + auto cfunc = PrimFuncFor(key->source_func, key->target, [&](std::string name) { + return GetUniqueName(mangle_fn(name), &name_map_); + }); // Skip lowering for device copy node. const Expr body = (key->source_func)->body; if (const CallNode* call_node = body.as()) { if (call_node->attrs.as()) { - value->cached_func = CachedFunc(cache_node); + value->cached_func = cfunc; return value; } } - cache_node->func_name = GetUniqueName(mangle_fn(cache_node->func_name)); // NOTE: array will copy on write. - Array all_args = cache_node->inputs; - for (te::Tensor arg : cache_node->outputs) { + Array all_args = Array(cfunc->inputs); + for (te::Tensor arg : cfunc->outputs) { all_args.push_back(arg); } // lower the function std::unordered_map binds; - cache_node->funcs = tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds); + auto func_name = cfunc->prim_fn_var->name_hint; + cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); + value->cached_func = cfunc; - value->cached_func = CachedFunc(cache_node); return value; } + // implement lowered shape func CCacheValue LowerShapeFuncInternal(const CCacheKey& key) { std::lock_guard lock(mutex_); @@ -790,47 +248,17 @@ class CompileEngineImpl : public CompileEngineNode { With target_scope(key->target); ICHECK(!value->cached_func.defined()); - auto spair = MakeShapeFunc().Create(key->source_func); - auto cache_node = make_object(*(spair.second.operator->())); - cache_node->func_name = GetUniqueName(cache_node->func_name); - cache_node->target = key->target; - - Array all_args = cache_node->inputs; - for (te::Tensor arg : cache_node->outputs) { - all_args.push_back(arg); - } - using tvm::transform::PassContext; With fresh_pass_ctx_scope(PassContext::Create()); - std::unordered_map binds; - cache_node->funcs = tvm::LowerSchedule(spair.first, all_args, cache_node->func_name, binds); - value->cached_func = CachedFunc(cache_node); + auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) { + return GetUniqueName(name, &name_map_); + }); + + value->cached_func = cached_func; return value; } - /*! - * \brief Get unique name from name. - * \param name The orginal name. - * \return Updated name which is unique. - */ - std::string GetUniqueName(std::string name) { - for (size_t i = 0; i < name.length(); ++i) { - if (name[i] == '.') name[i] = '_'; - } - while (true) { - auto it = name_map_.find(name); - if (it == name_map_.end()) { - name_map_[name] = 1; - return name; - } else { - std::ostringstream os; - os << name << "_" << it->second; - ++(it->second); - name = os.str(); - } - } - return name; - } + /*! \brief compiler cache lock*/ std::mutex mutex_; /*! \brief internal name map to get an unique name */ @@ -874,10 +302,7 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](Compi TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") .set_body_typed([](CompileEngine self, CCacheKey key, const String mod_name) { - auto mangle_fn = [mod_name](String name) { - return runtime::get_name_mangled(mod_name, name); - }; - return self->Lower(key, mangle_fn); + return self->Lower(key, mod_name); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index f766fcf97ea7..4afdc6d30485 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -19,8 +19,12 @@ /*! * \file relay/backend/compile_engine.h - * \brief Internal compialtion engine handle function cache. - * and interface to low level code generation. + * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. + * + * This layer represents the older design of the Relay compilation flow and is being deprecated + * in favor of te_compiler.h which is a migration step towards a standard pass based lowering of + * Relay functions. + * */ #ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ #define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ @@ -36,157 +40,12 @@ #include #include +#include "te_compiler_cache.h" + namespace tvm { namespace relay { -/*! \brief Indicate whether the data or shape or both of a parameter is used in the shape func. */ -enum ShapeFuncParamState { - kNoNeed = 0, - kNeedInputData = 1, - kNeedInputShape = 2, - kNeedBoth = 3, -}; - -struct LoweredOutputNode : public Object { - /*! \brief The outputs to the function */ - tvm::Array outputs; - /*! \brief The implementation used to compute the output */ - OpImplementation implementation; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("outputs", &outputs); - v->Visit("implementation", &implementation); - } - - static constexpr const char* _type_key = "relay.LoweredOutput"; - TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object); -}; - -class LoweredOutput : public ObjectRef { - public: - TVM_DLL LoweredOutput(tvm::Array outputs, OpImplementation impl); - - TVM_DEFINE_OBJECT_REF_METHODS(LoweredOutput, ObjectRef, LoweredOutputNode); -}; - -/*! \brief Node container to represent a cached function. */ -struct CachedFuncNode : public Object { - /* \brief compiled target */ - tvm::Target target; - /*! \brief Function name */ - std::string func_name; - /* \brief The inputs to the function */ - tvm::Array inputs; - /* \brief The outputs to the function */ - tvm::Array outputs; - /*! \brief The schedule to the function */ - te::Schedule schedule; - /*! \brief The lowered functions to support the function. */ - IRModule funcs = IRModule(Map({})); - - /*! \brief Parameter usage states in the shape function. */ - tvm::Array shape_func_param_states; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("target", &target); - v->Visit("func_name", &func_name); - v->Visit("inputs", &inputs); - v->Visit("outputs", &outputs); - v->Visit("schedule", &schedule); - v->Visit("funcs", &funcs); - v->Visit("shape_func_param_states", &shape_func_param_states); - } - - static constexpr const char* _type_key = "relay.CachedFunc"; - TVM_DECLARE_FINAL_OBJECT_INFO(CachedFuncNode, Object); -}; - -class CachedFunc : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(CachedFunc, ObjectRef, CachedFuncNode); -}; - -class CCacheKey; -/*! \brief Compile cache key */ -class CCacheKeyNode : public Object { - public: - /*! \brief The source function to be lowered. */ - Function source_func; - /*! \brief The hardware target.*/ - Target target; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("source_func", &source_func); - v->Visit("target", &target); - } - /*! \return The hash value of CCacheKey. */ - inline size_t Hash() const; - /*! - * \brief check content equality - * \param other The other value. - * \return The result of equality check. - */ - inline bool Equal(const CCacheKeyNode* other) const; - - static constexpr const char* _type_key = "relay.CCacheKey"; - TVM_DECLARE_FINAL_OBJECT_INFO(CCacheKeyNode, tvm::Object); - - private: - /*! - * \brief internal cached hash value. - */ - mutable size_t hash_{0}; -}; - -/*! \brief cache entry used in compile engine */ -class CCacheKey : public ObjectRef { - public: - CCacheKey() {} - explicit CCacheKey(ObjectPtr n) : ObjectRef(n) {} - - /*! - * \brief The constructor - * \param source_func The source function. - * \param target The target device. - */ - TVM_DLL CCacheKey(Function source_func, Target target); - - const CCacheKeyNode* operator->() const { return static_cast(get()); } - // comparator - inline bool operator==(const CCacheKey& other) const { - ICHECK(defined() && other.defined()); - return (*this)->Equal(other.operator->()); - } - using ContainerType = CCacheKeyNode; -}; - -/*! \brief Node container for compile cache. */ -class CCacheValueNode : public Object { - public: - /*! \brief The corresponding function */ - CachedFunc cached_func; - /*! \brief Result of Packed function generated by JIT */ - PackedFunc packed_func; - /*! \brief usage statistics */ - int use_count{0}; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("cached_func", &cached_func); - v->Visit("use_count", &use_count); - } - static constexpr const char* _type_key = "relay.CCacheValue"; - TVM_DECLARE_FINAL_OBJECT_INFO(CCacheValueNode, tvm::Object); -}; - -/*! \brief cache entry used in compile engine */ -class CCacheValue : public ObjectRef { - public: - CCacheValue() {} - explicit CCacheValue(ObjectPtr n) : ObjectRef(n) {} - CCacheValueNode* operator->() { return static_cast(get_mutable()); } - const CCacheValueNode* operator->() const { return static_cast(get()); } - using ContainerType = CCacheValueNode; -}; +using namespace tvm::relay::tec; /*! * \brief Backend compilation engine for @@ -199,10 +58,18 @@ class CompileEngineNode : public Object { /*! * \brief Get lowered result. * \param key The key to the cached function. - * \param mod_name The module name to mangle the functions + * \param mod_name The mangling function for mangling names. * \return The result. */ virtual CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) = 0; + + /*! + * \brief Get lowered result. + * \param key The key to the cached function. + * \param mod_name The module name to mangle the functions. + * \return The result. + */ + virtual CachedFunc Lower(const CCacheKey& key, const String mangle_fn) = 0; /*! * \brief Just in time compile to get a PackedFunc. * \param key The key to the cached function. @@ -242,49 +109,7 @@ class CompileEngine : public ObjectRef { TVM_DLL static CompileEngine& Global(); }; -/*! - * \brief Create schedule for target. - * \param source_func The primitive function to be lowered. - * \param target The target we want to create schedule for. - * \return Pair of schedule and cache. - * The funcs field in cache is not yet populated. - */ -CachedFunc CreateSchedule(const Function& source_func, const Target& target); - -/*! - * \brief Check if the type is dynamic. - * \param ty The type to be checked. - * \return The result. - */ -bool IsDynamic(const Type& ty); - -// implementations -inline size_t CCacheKeyNode::Hash() const { - if (hash_ != 0) return hash_; - // do structral hash, avoid 0. - hash_ = tvm::StructuralHash()(this->source_func); - hash_ = dmlc::HashCombine(hash_, std::hash()(target->str())); - if (hash_ == 0) hash_ = 1; - return hash_; -} - -inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { - if (Hash() != other->Hash()) return false; - return this->target->str() == other->target->str() && - tvm::StructuralEqual()(this->source_func, other->source_func); -} - } // namespace relay } // namespace tvm -namespace std { -// overload hash -template <> -struct hash<::tvm::relay::CCacheKey> { - size_t operator()(const ::tvm::relay::CCacheKey& key) const { - ICHECK(key.defined()); - return key->Hash(); - } -}; -} // namespace std #endif // TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index bca8e8244093..9d59e8e5f3a8 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -36,10 +37,13 @@ #include #include "compile_engine.h" +#include "te_compiler.h" #include "utils.h" namespace tvm { namespace relay { +// TODO(@jroesch, @csullivan): declare directly elsewhere +backend::StaticMemoryPlan GraphPlanMemory(const Function& func); namespace backend { class GraphNode; @@ -52,7 +56,6 @@ using GraphAttrs = std::unordered_map; using GraphObjectPtr = std::shared_ptr; using GraphInputObjectPtr = std::shared_ptr; using GraphOpObjectPtr = std::shared_ptr; -using TargetsMap = std::unordered_map; /*! \brief Node types */ enum GraphNodeType { @@ -176,112 +179,86 @@ class GraphOpNode : public GraphNode { const std::string op_type_name_{"tvm_op"}; }; -/*! \brief Code generator for graph executor */ +/*! \brief Code generator for the graph executor, produces a module containing the graph JSON, + * module, and parameters. + */ class GraphExecutorCodegen : public backend::MemoizedExprTranslator> { public: - GraphExecutorCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) { - compile_engine_ = CompileEngine::Global(); + GraphExecutorCodegen(runtime::Module* mod, const TargetMap& targets) : mod_(mod) { targets_ = targets; } - /*! - * \brief Update the "main" control function's metadata - * - * \param func The main function that contains calls to relay primitive functions - */ - void UpdateMainWorkspaceSize(const Function& func) { - // This is a Map> - std::unordered_map> sid_workspace; - // This is a Map - std::unordered_map device_io; - // This is a Map - std::unordered_map device_consts; - - // Initialize the maps to zero - for (const auto& kv : storage_device_map_) { - auto sids = kv.second[0]; - auto devices = kv.second[1]; - CHECK_EQ(sids.size(), devices.size()); - for (uint32_t i = 0; i < sids.size(); i++) { - sid_workspace[devices[i]][sids[i]] = 0; - device_io[devices[i]] = 0; - device_consts[devices[i]] = 0; - } - } - - // Collect sizes of tensors - for (const auto& kv : storage_device_map_) { - auto size_bytes = CalculateRelayExprSizeBytes(kv.first->checked_type()); - auto sids = kv.second[0]; - auto devices = kv.second[1]; - if (kv.first->IsInstance()) { - for (const auto& dev : devices) { - device_consts[dev] += size_bytes; - } - continue; - } else if (kv.first->IsInstance() || kv.first == func->body) { - for (const auto& dev : devices) { - device_io[dev] += size_bytes; - } - continue; - } - for (uint32_t i = 0; i < sids.size(); i++) { - // Here we record the largest size of the tensor - // that share the same storage id, because storage_id will - // be shared between multiple tensors that are not live simultaneously. - if (size_bytes > sid_workspace[devices[i]][sids[i]]) { - sid_workspace[devices[i]][sids[i]] = size_bytes; - } - } - } - - // This is a Map - std::unordered_map device_workspace; - // Once we know the sizes of sids, we need to accumulate per device - for (const auto& dev_sid_size : sid_workspace) { - auto dev = dev_sid_size.first; - device_workspace[dev] = 0; - for (const auto& sid_size : dev_sid_size.second) { - device_workspace[dev] += sid_size.second; - } - } - - // Populate FunctionInfo - auto fi_node = make_object(); - // Initialize all target workspaces to zero - for (const auto& kv : targets_) { - auto tgt = kv.second; - fi_node->workspace_sizes.Set(tgt, 0); - } - for (const auto& dev_and_size : device_workspace) { - auto tgt = GetTargetFromInteger(dev_and_size.first); - fi_node->workspace_sizes.Set(tgt, dev_and_size.second); - fi_node->relay_primfuncs.Set(tgt, func); - } - for (const auto& dev_and_size : device_io) { - auto tgt = GetTargetFromInteger(dev_and_size.first); - fi_node->io_sizes.Set(tgt, dev_and_size.second); - } - for (const auto& dev_and_size : device_consts) { - auto tgt = GetTargetFromInteger(dev_and_size.first); - fi_node->constant_sizes.Set(tgt, dev_and_size.second); - } - - function_metadata_.Set(String(runtime::symbol::tvm_module_main), FunctionInfo(fi_node)); + StorageInfo GetStorageInfo(const Expr& e) { + size_t count = memory_plan_->expr_to_storage_info.count(e); + ICHECK_GT(count, 0) << "Expr is not existing in storage plan"; + auto storage_info = memory_plan_->expr_to_storage_info[e]; + return storage_info; } LoweredOutput Codegen(relay::Function func, String mod_name) { - auto pf = GetPackedFunc("relay.backend.GraphPlanMemory"); - storage_device_map_ = (*pf)(func); mod_name_ = mod_name; - UpdateMainWorkspaceSize(func); + + // TODO(@jroesch): we need to split device planning and memory planning + // first we run device assignment, then we perform lowering, and then + // storage planning in ideal world. + + memory_plan_ = GraphPlanMemory(func); + + // This first phase moves from implicit use of compile engine, + // to instead explicitly lowering the incoming IRModule, and then + // performing the preexisting graph executor code generation phase. + IRModule mod = IRModule::FromExpr(func); + + // Build a map from each operation to device. + tec::DeviceMap device_context_map; + for (const auto& it : memory_plan_->expr_to_storage_info) { + auto expr = it.first; + auto storage_info = it.second; + auto device_types = storage_info->device_types; + // CHECK_EQ(device_types.size(), 1); + tvm::Device dev; + dev.device_id = 0; + dev.device_type = device_types[0]; + device_context_map.insert({expr, dev}); + } + + auto lowered_module = tec::LowerTE(mod, targets_, device_context_map, memory_plan_, mod_name_, + [this](Function func) { + // We need to maintain the constant map for external + // functions so we pass this processing function which + // allows us to process each function as we lower it. + if (func->GetAttr(attr::kCompiler).defined()) { + UpdateConstants(func, ¶ms_); + } + + // TODO(@areusch, @jroesch): We should refactor this to + // execute as a further pass, instead writing data to the + // lowering process directly. + UpdateFunctionMetadata(func, this->function_metadata_); + }); + + function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info); + auto main_module = lowered_module.main_module; + main_module = relay::transform::InferType()(main_module); + relay::Function main_func = Downcast(main_module->Lookup("main")); + + // Now that we have lowered all operators to TIR code, we can proceed with compilation. + // + // We need to unfortunately re-plan as the previous results have been invalidated by lowering + // we will fix this in future refactors. + memory_plan_ = GraphPlanMemory(main_func); + + // The graph planner also can not handle planning calls to global variables to we must remap + // First we convert all the parameters into input nodes. - for (auto param : func->params) { + for (auto param : main_func->params) { auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs()); var_map_[param.get()] = AddNode(node_ptr, param); } - heads_ = VisitExpr(func->body); + + heads_ = VisitExpr(main_func->body); std::ostringstream os; + dmlc::JSONWriter writer(&os); GetJSON(&writer); LoweredOutput ret; @@ -292,17 +269,9 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator(param_storage_ids_[param.first]), param.second))); } - - for (auto& kv : lowered_funcs_) { - if (ret.lowered_funcs.count(kv.first) == 0) { - ret.lowered_funcs.Set(kv.first, IRModule(Map({}))); - } - auto& mod = ret.lowered_funcs[kv.first]; - mod->Update(kv.second); - ret.lowered_funcs.Set(kv.first, mod); - } - ret.external_mods = compile_engine_->LowerExternalFunctions(); ret.function_metadata = std::move(function_metadata_); + ret.lowered_funcs = lowered_module.per_target_module; + ret.external_mods = lowered_module.external_mods; return ret; } @@ -331,20 +300,18 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator AddNode(GraphObjectPtr node, Expr expr) { auto checked_type = expr->checked_type(); - size_t count = storage_device_map_.count(expr); - ICHECK_GT(count, 0) << "Expr is not existing in storage plan"; - auto storage_device_info = storage_device_map_[expr]; - ICHECK_EQ(storage_device_info.size(), 3); + + auto storage_info = GetStorageInfo(expr); // storage - std::vector storage_info; - for (auto& v : storage_device_info[0]) { - storage_info.push_back(v->value); + std::vector storage_ids; + for (auto v : storage_info->storage_ids) { + storage_ids.push_back(v); } - node->attrs_["storage_id"] = std::move(storage_info); + node->attrs_["storage_id"] = std::move(storage_ids); // type std::vector device_types; - for (auto& v : storage_device_info[1]) { - device_types.push_back(v->value); + for (auto v : storage_info->device_types) { + device_types.push_back(static_cast(v)); } size_t num_unknown_devices = std::count(device_types.begin(), device_types.end(), 0); if (num_unknown_devices != 0 && num_unknown_devices != device_types.size()) { @@ -404,7 +371,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorvalue; + param_storage_ids_[name] = GetStorageInfo(expr)->storage_ids[0]; params_[name] = op->data; return to_return; } @@ -420,8 +387,16 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator GraphAddCallNode(const CallNode* op, const std::string& op_name, - const std::string& func_name, GraphAttrs attrs) { + bool ShareSameStorage(const Expr& lhs, const Expr& rhs) { + StorageInfo lit = GetStorageInfo(lhs); + StorageInfo rit = GetStorageInfo(rhs); + int64_t lhs_storage_id = lit->storage_ids[0]; + int64_t rhs_storage_id = rit->storage_ids[0]; + return lhs_storage_id == rhs_storage_id; + } + + std::vector GraphAddCallNode(const CallNode* op, const std::string& func_name, + GraphAttrs attrs) { std::vector inputs; for (auto arg : op->args) { auto res = VisitExpr(arg); @@ -429,161 +404,52 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator(op)); - } - bool ShareSameStorage(const Expr& lhs, const Expr& rhs) { - auto lit = storage_device_map_.find(lhs); - auto rit = storage_device_map_.find(rhs); - ICHECK(lit != storage_device_map_.end()); - ICHECK(rit != storage_device_map_.end()); - int64_t lhs_storage_id = ((*lit).second)[0][0]->value; - int64_t rhs_storage_id = ((*rit).second)[0][0]->value; - return lhs_storage_id == rhs_storage_id; - } + /// An adapted version of the storage optimization for the time being. + bool reshape_only = false; + if (op->attrs.defined()) { + if (auto tir_call_attrs = op->attrs.as()) { + Map metadata = tir_call_attrs->metadata; + if (metadata.count(attr::kReshapeOnly) && + Downcast(metadata[attr::kReshapeOnly])->value == 1) { + reshape_only = true; + } - /*! - * \brief Obtain the Target from the device type. - * If homogenous compilation, this will return the only target. - * If heteregenous compilation, this will select associated using the targets_ Map. - * - * \param dev_type - * \return Target - */ - Target GetTargetFromInteger(int64_t dev_type) { - if (targets_.size() == 1) { - // homogeneous execution. - const auto& it = targets_.begin(); - return (*it).second; - } else { - // heterogeneous execution. - std::string call_dev_name; - if (dev_type == 0) { - call_dev_name = "llvm"; - } else { - call_dev_name = runtime::DeviceName(dev_type); - } - if (targets_.count(dev_type) == 0) { - LOG(FATAL) << "No target is provided for device " << call_dev_name; - } - return targets_[dev_type]; - } - } + auto relay_attrs = Downcast(tir_call_attrs->metadata["relay_attrs"]); - /*! - * \brief Update the function metadata for a given cached function and its relay - * primitive function. - * - * \param cfunc The cached function as provided the by the compile engine - * \param relay_func The source relay primitive function - * \param relay_target The target associated with relay primitive function - */ - void UpdateFunctionMetadata(const CachedFunc& cfunc, const Function& relay_func, - const Target& relay_target) { - auto fi_node = make_object(); - for (const auto& kv : cfunc->funcs->functions) { - auto primfunc = Downcast(kv.second); - auto workspace_byte_alignment = relay_target->GetAttr("workspace-byte-alignment") - .value_or(tvm::runtime::kDefaultWorkspaceAlignment); - Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment); - Target primfunc_target = relay_target; - if (primfunc->attrs->dict.count("target")) { - primfunc_target = Downcast(primfunc->attrs->dict["target"]); - } - fi_node->workspace_sizes.Set(primfunc_target, workspace_size); - // Calculating size for I/O - for (auto const& param : primfunc->params) { - auto p_shape = primfunc->buffer_map[param]->shape; - int num_of_elements = 1; - for (const auto& dim_index_expr : p_shape) { - if (dim_index_expr->IsInstance()) { - num_of_elements *= dim_index_expr.as()->value; - } else { - // If shape is dynamic, we cannot calculate workspace in compile time. - num_of_elements = 0; + for (auto p : relay_attrs->dict) { + if (p.second.as()) { + attrs[p.first] = std::string(Downcast(p.second)); } } - int element_size = primfunc->buffer_map[param]->dtype.bytes(); - fi_node->io_sizes.Set(primfunc_target, element_size * num_of_elements); - } - fi_node->constant_sizes.Set(primfunc_target, 0); - fi_node->tir_primfuncs.Set(primfunc_target, primfunc); - fi_node->relay_primfuncs.Set(primfunc_target, relay_func); - } - function_metadata_.Set(cfunc->func_name, FunctionInfo(fi_node)); - } - - std::vector VisitExpr_(const CallNode* op) override { - Expr expr = GetRef(op); - Function func; - if (op->op.as()) { - LOG(FATAL) << "Operators should be transformed away; try applying" - << "the fuse_ops transformation to the expression."; - } else if (op->op.as()) { - LOG(FATAL) << "Not implemented"; - } else if (op->op.as()) { - func = GetRef(op->op.as()); - } else { - LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey(); - } - if (!func->HasNonzeroAttr(attr::kPrimitive)) { - LOG(FATAL) << "TVM only support calls to primitive functions " - << "(i.e functions composed of fusable operator invocations)"; - } - - // Copy attrs from function into the graph node - // For now we only handle strings - GraphAttrs attrs; - for (auto p : func->attrs->dict) { - if (p.second.as()) { - attrs[p.first] = std::string(Downcast(p.second)); } } - auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey"); - auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); - Target target; - // Handle external function - if (func->GetAttr(attr::kCompiler).defined()) { - target = Target("ext_dev"); - CCacheKey key = (*pf0)(func, target); - CachedFunc ext_func = (*pf1)(compile_engine_, key, mod_name_); - ICHECK(ext_func.defined()) << "External function is not defined."; - UpdateConstants(func, ¶ms_); - return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name, attrs); + if (reshape_only && ShareSameStorage(GetRef(op), op->args[0])) { + auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), "__nop", inputs, attrs); + return AddNode(node, GetRef(op)); } - // In the current flat memory allocation scenario - // the flat memory allocator can always allocate input - // and output of the reshape to the same memory, we can turn reshape only - // function to a nop. - // - // NOTE that for non-flat memory this is not necessarily true. - // - // TODO(tvm-team) Update checks of flat memory enablement when we support - // opaque-nd memory planning to skip this path. - if (func->HasNonzeroAttr(attr::kReshapeOnly) && ShareSameStorage(expr, op->args[0])) { - return GraphAddCallNode(op, "reshape_nop", "__nop", attrs); - } + // Compute the operator name, because we used the get unique name when generating the kernel. + auto op_name = _GetUniqueName(func_name); + auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, attrs); + return AddNode(node, GetRef(op)); + } - ICHECK_GE(storage_device_map_.count(expr), 0); - auto& device_type = storage_device_map_[expr][1]; - auto call_dev_type = device_type[0]->value; - target = GetTargetFromInteger(call_dev_type); - // Normal Relay Function + std::vector VisitExpr_(const CallNode* call_node) override { + relay::Call call = GetRef(call_node); + if (auto global_node = call->op.as()) { + auto prim_fn_name = global_node->name_hint; - CCacheKey key = (*pf0)(func, target); - CachedFunc lowered_func = (*pf1)(compile_engine_, key, mod_name_); - if (!lowered_funcs_.count(target->str())) { - lowered_funcs_[target->str()] = IRModule(Map({})); + return GraphAddCallNode(call_node, prim_fn_name, GraphAttrs()); + } else { + ICHECK(false) << "Non-primitive-call nodes should have been transformed away.\n" + << "The graph executor code generator expects all calls to have their callee " + "normalized to a GlobalVar but found a " + << call->GetTypeKey() << "." + << "AST: " << PrettyPrint(call) << PrettyPrint(call) << std::endl; + return {}; } - lowered_funcs_[target->str()]->Update(lowered_func->funcs); - - // Update function metadata via looking at all primfuncs - UpdateFunctionMetadata(lowered_func, func, target); - return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name, - attrs); } std::vector VisitExpr_(const LetNode* op) override { @@ -714,7 +580,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator> var_map_; /*! \brief target device */ - TargetsMap targets_; + TargetMap targets_; /*! * \brief parameters (i.e. ConstantNodes found in the graph). * These are take as inputs to the GraphExecutor. @@ -724,7 +590,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator params_; std::unordered_map param_storage_ids_; /*! \brief plan memory of device result */ - Map> storage_device_map_; + StaticMemoryPlan memory_plan_; /*! \brief the module name we use to mangle the function names */ String mod_name_; /*! \brief lowered funcs */ @@ -733,8 +599,6 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator function_metadata_; /*! \brief name map */ std::unordered_map name_map_; - /*! \brief compile engine */ - CompileEngine compile_engine_; }; class GraphExecutorCodegenModule : public runtime::ModuleNode { @@ -747,11 +611,11 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { << "runtime::Module mod and Map targets"; void* mod = args[0]; Map tmp = args[1]; - TargetsMap targets; + TargetMap targets; for (const auto& it : tmp) { auto dev_type = it.first.as(); ICHECK(dev_type); - targets[dev_type->value] = it.second; + targets[static_cast(dev_type->value)] = it.second; } codegen_ = std::make_shared(reinterpret_cast(mod), targets); diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 351469d6e1ca..93c823d8a007 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -23,15 +23,20 @@ * the program in the graph executor. */ #include +#include #include #include +#include #include #include "../../support/arena.h" +#include "./utils.h" namespace tvm { namespace relay { +using backend::StaticMemoryPlan; +using backend::StorageInfo; using IntegerArray = Array; struct StorageToken { @@ -48,6 +53,18 @@ struct StorageToken { int64_t storage_id{-1}; }; +std::ostream& operator<<(std::ostream& os, StorageToken tok) { + return os << "StorageToken: " << std::endl + << "ref_counter: " << tok.ref_counter << std::endl + << "max_bytes: " << tok.max_bytes << std::endl + << "tttype: " << tok.ttype + << std::endl + // ok idk how to print this properly + << "tttype shape: " << tok.ttype->shape << std::endl + << "device_type: " << tok.device_type << std::endl + << "storage_id: " << tok.storage_id << std::endl; +} + class StorageAllocaBaseVisitor : public ExprVisitor { public: // run the visitor on a function. @@ -114,7 +131,8 @@ class StorageAllocaBaseVisitor : public ExprVisitor { const std::vector& GetToken(const Expr& expr) { this->VisitExpr(expr); auto it = token_map_.find(expr.operator->()); - ICHECK(it != token_map_.end()); + ICHECK(it != token_map_.end()) + << "Expression: `" << PrettyPrint(expr) << "` not found in storage map."; return it->second; } /*! @@ -168,6 +186,7 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { void VisitExpr_(const CallNode* op) final { // create token for the call node. CreateToken(op, true); + // for each input, visit argument token. for (Expr arg : op->args) { for (StorageToken* tok : GetToken(arg)) { @@ -196,31 +215,32 @@ class StorageAllocator : public StorageAllocaBaseVisitor { } // Run storage allocation for a function. - Map > Plan(const Function& func) { + StaticMemoryPlan Plan(const Function& func) { prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func); this->Run(func); // The value of smap contains two integer arrays where the first array // contains the planned storage ids and the second holds the device types. - Map > smap; + Map smap; int num_annotated_nodes = 0; int num_nodes = 0; for (const auto& kv : token_map_) { - std::vector storage_ids; - std::vector device_types; - std::vector sid_sizes_byte; + std::vector storage_ids; + std::vector device_types; + std::vector sid_sizes_byte; + for (StorageToken* tok : kv.second) { if (tok->device_type) { num_annotated_nodes++; } num_nodes++; storage_ids.push_back(tok->storage_id); - device_types.push_back(tok->device_type); + device_types.push_back(static_cast(tok->device_type)); sid_sizes_byte.push_back(GetMemorySize(tok)); } - smap.Set(GetRef(kv.first), - Array({storage_ids, device_types, sid_sizes_byte})); + auto storage_info = backend::StorageInfo(storage_ids, device_types, sid_sizes_byte); + smap.Set(GetRef(kv.first), storage_info); } // Either all or none of the nodes should be annotated. if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) { @@ -228,7 +248,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { << "expressions are assigned with virtual device types. Either all " "or none of the expressions are expected to be annotated."; } - return smap; + + return backend::StaticMemoryPlan(smap); } protected: @@ -279,6 +300,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { args.push_back(tok); } } + // Under the flat-memory setting. // we can force aliasing the input and output of reshape // to make it an nop. Note that this is not true @@ -288,12 +310,17 @@ class StorageAllocator : public StorageAllocaBaseVisitor { // TODO(tvm-team) Update checks of flat memory enablement when we support // opaque-nd memory planning to skip this path. if (IsReshape(op)) { + // TODO(@electriclilies, jroesch): This check is failing because the size of args is 3 + // I can't figure out where the extra args are coming from, I assume it must be related + // to the relay_attrs field we added to the TIRCallArgs, but I don't know where / how + // that's happening... ICHECK_EQ(args.size(), 1U); ReuseInputToken(op, args[0]); } else { // create token for the call node. CreateToken(op, true); } + // check if there is orphaned output that can be released immediately. for (StorageToken* tok : token_map_.at(op)) { CheckForRelease(tok); @@ -320,6 +347,15 @@ class StorageAllocator : public StorageAllocaBaseVisitor { if (const auto* fn = call->op.as()) { return fn->HasNonzeroAttr(attr::kReshapeOnly); } + + if (call->attrs.defined()) { + if (auto tir_call_attrs = call->attrs.as()) { + Map metadata = tir_call_attrs->metadata; + return metadata.count(attr::kReshapeOnly) && + (Downcast(metadata[attr::kReshapeOnly])->value == 1); + } + } + return false; } /*! @@ -419,9 +455,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { std::unordered_map > prototype_; }; -Map > GraphPlanMemory(const Function& func) { - return StorageAllocator().Plan(func); -} +StaticMemoryPlan GraphPlanMemory(const Function& func) { return StorageAllocator().Plan(func); } TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory").set_body_typed(GraphPlanMemory); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index eeba010dc164..53985c78a33c 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -32,6 +32,7 @@ #include #include +#include "../transforms/pass_utils.h" #include "compile_engine.h" namespace tvm { @@ -381,7 +382,7 @@ class Interpreter : public ExprFunctor, } else { m = build(cfunc->funcs, cfunc->target, Target(nullptr)); } - shape_func = m.GetFunction(cfunc->func_name); + shape_func = m.GetFunction(cfunc->prim_fn_var->name_hint); shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); // Get output shapes diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc new file mode 100644 index 000000000000..93b9c6fc1827 --- /dev/null +++ b/src/relay/backend/te_compiler.cc @@ -0,0 +1,743 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "te_compiler.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "../transforms/pass_utils.h" +#include "te_compiler.h" +#include "te_compiler_cache.h" +#include "utils.h" + +namespace tvm { +namespace relay { +// TODO(@jroesch, @csullivan): declare directly elsewhere +backend::StaticMemoryPlan GraphPlanMemory(const Function& func); + +namespace tec { + +using namespace tvm::relay::transform; + +TVM_REGISTER_OBJECT_TYPE(TECompilerNode); + +class TECompilerImpl : public TECompilerNode { + public: + // Lower the function. + CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) { + return LowerInternal(key, mangle_fn)->cached_func; + } + + CachedFunc Lower(const CCacheKey& key, const String mod_name) { + auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); }; + + return Lower(key, mangle_fn); + } + + // For now, build one module per function. + PackedFunc JIT(const CCacheKey& key) final { + auto mangle_fn = [](String name) { return name; }; + CCacheValue value = LowerInternal(key, mangle_fn); + if (value->packed_func != nullptr) { + return value->packed_func; + } + auto m = build(value->cached_func->funcs, key->target, Target(nullptr)); + value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint); + return value->packed_func; + } + + CachedFunc LowerShapeFunc(const CCacheKey& key) final { + return LowerShapeFuncInternal(key)->cached_func; + } + + Map GetLoweredFunctions() { + Map lowered_functions; + for (const auto& it : cache_) { + auto source_func = it.first; + auto lowered_func = it.second; + auto target = source_func->target; + + if (!lowered_functions.count(target->str())) { + lowered_functions.Set(target->str(), IRModule(Map({}))); + } + + lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); + } + return lowered_functions; + } + + Array LowerExternalFunctions() { + Array ret; + std::unordered_map cached_symbol; + std::vector cached_ext_funcs; + for (const auto& it : cache_) { + auto src_func = it.first->source_func; + ICHECK(src_func.defined()); + if (src_func->GetAttr(attr::kCompiler).defined()) { + auto code_gen = src_func->GetAttr(attr::kCompiler); + std::string code_gen_name = code_gen.value(); + cached_ext_funcs.push_back(it.first); + + auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" + << AsText(src_func, false); + + std::string sn = symbol_name.value(); + if (cached_symbol.count(sn)) { + cached_symbol[sn] = code_gen_name; + } else { + ICHECK_NE(sn, code_gen_name) + << "Found duplicated symbol: " << sn << " for: " << code_gen_name; + } + + std::string ext_name = "relay.ext." + code_gen_name; + auto pf = tvm::runtime::Registry::Get(ext_name); + ICHECK(pf) << "Failed to find the codegen tool for " << ext_name; + // No need to keep compiler attribute at this point, functions have been + // extracted for specific codegen. + src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue()); + runtime::Module ext_mod = (*pf)(src_func); + + ICHECK(ext_mod.defined()) << "No external runtime is generated."; + ret.push_back(ext_mod); + } + } + + // No need to cache external functions as we collected them all to create + // external runtime modules. + for (const auto& it : cached_ext_funcs) { + cache_.erase(it); + } + return ret; + } + + void Clear() final { cache_.clear(); } + + // List all items in the cache. + Array ListItems() { + std::lock_guard lock(mutex_); + Array items; + for (auto& kv : cache_) { + items.push_back(kv.first); + items.push_back(kv.second); + } + return items; + } + + /*! + * \brief Get the cache key of the function that is being lowered currently + * \return the cache key + */ + CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; } + + private: + // implement lowered func + CCacheValue LowerInternal(const CCacheKey& key, std::function mangle_fn) { + std::lock_guard lock(mutex_); + CCacheValue value; + auto it = cache_.find(key); + if (it != cache_.end()) { + it->second->use_count += 1; + if (it->second->cached_func.defined()) return it->second; + value = it->second; + } else { + value = CCacheValue(make_object()); + value->use_count = 1; + cache_[key] = value; + } + cur_ccache_key_ = key; + + // No need to lower external functions for now. We will invoke the external + // codegen tool once and lower all functions together. + if (key->source_func->GetAttr(attr::kCompiler).defined()) { + auto ir_module = IRModule(); + const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(name_node.defined()) << "External function has not been attached a name yet."; + auto func_name = GetUniqueName(name_node.value(), &name_map_); + auto target = Target("ext_dev"); + auto global_var = GlobalVar(func_name); + global_var->checked_type_ = key->source_func->checked_type(); + value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); + return value; + } + + // Enforce use the target. + With target_scope(key->target); + + ICHECK(!value->cached_func.defined()); + auto cfunc = PrimFuncFor(key->source_func, key->target, [&](std::string name) { + auto mangled = mangle_fn(name); + return GetUniqueName(mangled, &name_map_); + }); + + // Skip lowering for device copy node. + const Expr body = (key->source_func)->body; + if (const CallNode* call_node = body.as()) { + if (call_node->attrs.as()) { + value->cached_func = cfunc; + return value; + } + } + + // NOTE: array will copy on write. + Array all_args = Array(cfunc->inputs); + for (te::Tensor arg : cfunc->outputs) { + all_args.push_back(arg); + } + + std::unordered_map binds; + auto func_name = cfunc->prim_fn_var->name_hint; + cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); + value->cached_func = cfunc; + return value; + } + + // implement lowered shape func + CCacheValue LowerShapeFuncInternal(const CCacheKey& key) { + std::lock_guard lock(mutex_); + CCacheValue value; + auto it = shape_func_cache_.find(key); + if (it != shape_func_cache_.end()) { + it->second->use_count += 1; + if (it->second->cached_func.defined()) return it->second; + value = it->second; + } else { + value = CCacheValue(make_object()); + value->use_count = 0; + shape_func_cache_[key] = value; + } + // Enforce use the target. + With target_scope(key->target); + + ICHECK(!value->cached_func.defined()); + + using tvm::transform::PassContext; + With fresh_pass_ctx_scope(PassContext::Create()); + auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) { + return GetUniqueName(name, &name_map_); + }); + + value->cached_func = cached_func; + return value; + } + + std::unordered_map GetOpWeights() { + std::unordered_map weights; + for (auto pair : cache_) { + auto value = pair.second; + auto name = value->cached_func->prim_fn_var->name_hint; + weights[name] = value->use_count; + } + return weights; + } + + /*! \brief compiler cache lock*/ + std::mutex mutex_; + /*! \brief internal name map to get an unique name */ + std::unordered_map name_map_; + /*! \brief internal compiler cache */ + std::unordered_map cache_; + /*! \brief internal compiler cache for shape funcs */ + std::unordered_map shape_func_cache_; + /*! \brief the cache key of the function that is being lowered currently*/ + CCacheKey cur_ccache_key_; +}; + +TECompiler::TECompiler() { + auto object = make_object(); + data_ = object; +} + +using AnalysisRemapping = std::unordered_map; + +std::tuple IsDeviceCopy(const Function& func) { + if (auto call_node = func->body.as()) { + if (auto op_node = call_node->op.as()) { + if (op_node->name == "device_copy") { + auto attrs = call_node->attrs.as(); + auto dst = attrs->dst_dev_type; + auto src = attrs->src_dev_type; + return std::tuple(true, src, dst); + } + } + } + + return std::tuple(false, -1, -1); +} + +class LowerTensorExpr : public ExprMutator { + public: + LowerTensorExpr(const IRModule& module, const TargetMap& targets, const DeviceMap& device_ctx_map, + ProcessFn process_fn, const String& module_name, TECompiler compiler) + : module_(module), + targets_(targets), + device_context_map_(device_ctx_map), + process_fn(process_fn), + module_name_(module_name), + compiler_(compiler) {} + + Expr VisitExpr_(const CallNode* call) override { + Call expr = GetRef(call); + Function func; + + if (call->op.as()) { + func = GetRef(call->op.as()); + } else { + return ExprMutator::VisitExpr_(call); + } + + if (!func->HasNonzeroAttr(attr::kPrimitive)) { + // Provide a callback hook which allows one-level up code generators to + // act when we process a function. + this->process_fn(func); + return ExprMutator::VisitExpr_(call); + } + + // Process inputs. + Array args; + for (size_t i = 0; i < expr->args.size(); i++) { + args.push_back(VisitExpr(expr->args[i])); + } + + Target target; + + if (func->GetAttr(attr::kCompiler).defined()) { + target = Target("ext_dev"); + CCacheKey key = CCacheKey(func, target); + CachedFunc ext_func = compiler_->Lower(key, module_name_); + ICHECK(ext_func.defined()) << "Lowering returned undefined function for " + << ext_func->prim_fn_var->name_hint; + + Map prim_fns; + + for (auto prim_fn : ext_func->funcs->functions) { + CHECK(prim_fn.second.as()) << "must be a prim fn"; + prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); + } + + relay::Function func_with_metadata = func; + func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var); + func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); + func_with_metadata = WithAttr(func_with_metadata, "target", ext_func->target); + + // Provide a callback hook which allows one-level up code generators to + // act when we process a function. + this->process_fn(func_with_metadata); + + auto ret_call = Call(ext_func->prim_fn_var, args, {}); + return std::move(ret_call); + } + + ICHECK_GE(device_context_map_.count(expr), 0) + << "Could not find an entry in the device context map for " << PrettyPrint(expr) + << "The memory planning was either not performed for this precise node, or there is bug " + "in the memory planner."; + + auto& device_context = this->device_context_map_[expr]; + auto call_dev_type = device_context.device_type; + + // Non-External Relay Function + if (targets_.size() == 1) { + // The homogeneous execution case, we should only have one target + // so we just grab it. + const auto& it = targets_.begin(); + target = (*it).second; + } else { + // The heterogeneous execution case we have multiple targets + // in this case. + // + // We need to identify the target and translate. + std::string call_dev_name; + if (call_dev_type == 0) { + call_dev_name = "llvm"; + call_dev_type = kDLCPU; + } else { + call_dev_name = ::tvm::runtime::DeviceName(call_dev_type); + } + + if (targets_.count(call_dev_type) == 0) { + std::stringstream msg; + msg << "No target is specified for provided device name: `" << call_dev_name << "`\n\n"; + msg << call_dev_name << " mapped to device type (" << call_dev_type + << ") which was not found in the target map.\n"; + msg << "Availible targets: \n"; + for (auto target : targets_) { + msg << " " << target.first << "-> " << target.second << "\n"; + } + LOG(FATAL) << msg.str(); + } + + target = targets_[call_dev_type]; + } + + CCacheKey key = CCacheKey(func, target); + CachedFunc lowered_func = compiler_->Lower(key, module_name_); + + Map prim_fns; + + for (auto prim_fn : lowered_func->funcs->functions) { + CHECK(prim_fn.second.as()) << "must be a prim fn"; + prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); + } + + // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT + relay::Function func_with_metadata = func; + func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", lowered_func->prim_fn_var); + func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); + func_with_metadata = WithAttr(func_with_metadata, "target", lowered_func->target); + + // Provide a callback hook which allows one-level up code generators to + // act when we process a function. + this->process_fn(func_with_metadata); + + auto tir_call_attrs = make_object(); + if (func->HasNonzeroAttr(attr::kReshapeOnly)) { + tir_call_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); + } + + auto device_copy = IsDeviceCopy(func); + if (std::get<0>(device_copy)) { + auto source_device = std::get<1>(device_copy); + auto dst_device = std::get<2>(device_copy); + tir_call_attrs->metadata.Set("source_device", tvm::Integer(source_device)); + tir_call_attrs->metadata.Set("dst_device", tvm::Integer(dst_device)); + } + + tir_call_attrs->metadata.Set("relay_attrs", func->attrs); + + Expr ret_call = Call(lowered_func->prim_fn_var, args, Attrs(tir_call_attrs)); + return ret_call; + } + + IRModule module_; + TargetMap targets_; + DeviceMap device_context_map_; + ProcessFn process_fn; + String module_name_; + TECompiler compiler_; +}; + +/*! + * \brief Obtain the Target from the device type. + * If homogenous compilation, this will return the only target. + * If heteregenous compilation, this will select associated using the targets_ Map. + * + * \param dev_type + * \return Target + */ +Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) { + if (targets.size() == 1) { + // homogeneous execution. + const auto& it = targets.begin(); + return (*it).second; + } else { + // heterogeneous execution. + std::string call_dev_name; + if (dev_type == 0) { + call_dev_name = "llvm"; + } else { + call_dev_name = runtime::DeviceName(dev_type); + } + if (targets.count(dev_type) == 0) { + LOG(FATAL) << "No target is provided for device " << call_dev_name; + } + return targets[dev_type]; + } +} + +/*! + * \brief Update the "main" control function's metadata + * + * \param mod The module + * \param targets Map of targets + * \return function_infos Function info for each function in the module + */ + +backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap targets, + Map storage_info_map) { + CHECK_EQ(mod->functions.size(), 1) + << "There should only be one function in the module passed to UpdateMainWorkspaceSize"; + Function func = Downcast(mod->Lookup("main")); + + // This is a Map> + std::unordered_map, EnumClassHash> sid_workspace; + // This is a Map + std::unordered_map device_io; + // This is a Map + std::unordered_map device_consts; + + // Initialize the mapping from all storage identifiers to workspace sizes, + // the amount of device io, and the device constants. + for (const auto& kv : storage_info_map) { + backend::StorageInfo storage_info = kv.second; + std::vector storage_ids = storage_info->storage_ids; + std::vector devices = storage_info->device_types; + + CHECK_EQ(storage_ids.size(), devices.size()); + for (uint32_t i = 0; i < devices.size(); i++) { + sid_workspace[devices[i]][storage_ids[i]] = 0; + device_io[devices[i]] = 0; + device_consts[devices[i]] = 0; + } + } + + // Iterate the storage map to compute all the tensor sizes in the program. + // There are 3 cases in this code: + // + // First we need to compute the sizes of all + // inline constants. + // + // Second we compute the size of any bound variable as these are input and output + // sizes of the program. + // + // Finally for all other expressions we check which storage identifier they have + // been assigned and we compute the maximal size of the storage, as tensors can + // share storage with other tensors which are the same size or larger. + // + // In this final case there is only one allocation for all tensors which share storage + // which will be the maximal size of all tensors which were assigned to it. + for (const auto& kv : storage_info_map) { + Expr expr = kv.first; + int64_t size_bytes = backend::CalculateRelayExprSizeBytes(expr->checked_type()); + backend::StorageInfo storage_info = kv.second; + std::vector storage_ids = storage_info->storage_ids; + std::vector devices = storage_info->device_types; + + if (expr->IsInstance()) { + for (const auto& dev : devices) { + device_consts[dev] += size_bytes; + } + continue; + } else if (expr->IsInstance() || expr.same_as(func->body)) { + CHECK_GE(devices.size(), 1) << "must be at least one device"; + for (const auto& dev : devices) { + device_io[dev] += size_bytes; + } + continue; + } + + // TODO(@electriclilies): This code is never being called which means sid_workspace is not + // updated.. This means that storage info is probably not being created correctly. Or is not + // equivalent to what was here previously + for (uint32_t i = 0; i < storage_ids.size(); i++) { + // Here we record the largest size of the tensor + // that share the same storage id, because storage_id will + // be shared between multiple tensors that are not live simultaneously. + if (size_bytes > sid_workspace[devices[i]][storage_ids[i]]) { + sid_workspace[devices[i]][storage_ids[i]] = size_bytes; + } + } + } + + // This is a Map + std::unordered_map device_workspace; + // Once we know the sizes of sids, we need to accumulate per device + for (const auto& dev_sid_size : sid_workspace) { + auto dev = dev_sid_size.first; + device_workspace[dev] = 0; + for (const auto& sid_size : dev_sid_size.second) { + device_workspace[dev] += sid_size.second; + } + } + + Map workspace_sizes; + Map io_sizes; + Map constant_sizes; + Map tir_primfuncs; + Map relay_primfuncs; + + // Initialize all target workspaces to zero + for (const auto& kv : targets) { + auto tgt = kv.second; + workspace_sizes.Set(tgt, 0); + } + + for (const auto& dev_and_size : device_workspace) { + auto tgt = GetTargetFromInteger(dev_and_size.first, targets); + workspace_sizes.Set(tgt, dev_and_size.second); + relay_primfuncs.Set(tgt, func); + } + for (const auto& dev_and_size : device_io) { + auto tgt = GetTargetFromInteger(dev_and_size.first, targets); + io_sizes.Set(tgt, dev_and_size.second); + } + + for (const auto& dev_and_size : device_consts) { + auto tgt = GetTargetFromInteger(dev_and_size.first, targets); + constant_sizes.Set(tgt, dev_and_size.second); + } + + return backend::FunctionInfo(workspace_sizes, io_sizes, constant_sizes, tir_primfuncs, + relay_primfuncs); +} + +// TODO(@electriclilies): Is the function passed in here relay_func?? +// Also should this be inlined? +/*! + * \brief A function to create the function metadata for an input function (ie calculate buffer + * input/output sizes) + * \param relay_func The function to calculate function metadata for + * \param function_metadata The map that stores all the function metadatas + */ +void UpdateFunctionMetadata(Function relay_func, + Map& function_metadata) { // NOLINT(*) + // Originally UpdateFunctionMetadata took in CCachedFunc and looped through all the funcs stored + // there Now the goal is to take only one func because process_fn should be controlling the + // iteration However, to do the workspace calculations we need the primfuncs. So process_fn needs + // to either access the cached funcs or be directly passed primfuncs This is bad and ideally we + // don't want process_fn to look at primfuncs There's also the question now of what the function + // metadatas are and how they are used if we can do something else to replicate the behavior of + // the function metadatas that might be good (ie annotating functions or something). + Map workspace_sizes; + Map io_sizes; + Map constant_sizes; + Map tir_primfuncs; + Map relay_primfuncs; + + Optional> prim_fns = + relay_func->GetAttr>("prim_funcs"); + CHECK(prim_fns) << "primitive functions not set on Relay function by TECompiler."; + + Optional prim_fn_var = relay_func->GetAttr("prim_fn_var"); + CHECK(prim_fn_var) << "prim_fn_var must be set on Relay functions by TECompiler."; + + Optional relay_target = relay_func->GetAttr("target"); + CHECK(relay_target) << "target must be set on Relay functions by the TECompiler."; + + for (const auto& kv : prim_fns.value()) { + auto prim_fn = Downcast(kv.second); + CHECK(prim_fn.defined()) << "the primitive function must be defined"; + + auto workspace_byte_alignment = + relay_target.value()->GetAttr("workspace_byte_alignment").value_or(16); + + Integer workspace_size = CalculateWorkspaceBytes(prim_fn, workspace_byte_alignment); + + // Workspace sizes + Target prim_fn_target; + if (prim_fn->attrs->dict.count("target")) { + prim_fn_target = Downcast(prim_fn->attrs->dict["target"]); + } else { + prim_fn_target = relay_target.value(); + } + + workspace_sizes.Set(prim_fn_target, workspace_size); + + // Calculating size for I/O + for (auto const& param : prim_fn->params) { + auto p_shape = prim_fn->buffer_map[param]->shape; + int num_of_elements = 1; + for (const auto& dim_index_expr : p_shape) { + if (dim_index_expr->IsInstance()) { + num_of_elements *= dim_index_expr.as()->value; + } else { + // If shape is dynamic, we cannot calculate workspace in compile time. + num_of_elements = 0; + } + } + int element_size = prim_fn->buffer_map[param]->dtype.bytes(); + io_sizes.Set(prim_fn_target, element_size * num_of_elements); + } + + constant_sizes.Set(prim_fn_target, 0); + tir_primfuncs.Set(prim_fn_target, prim_fn); + relay_primfuncs.Set(prim_fn_target, relay_func); + } + + backend::FunctionInfo fi = backend::FunctionInfo(workspace_sizes, io_sizes, constant_sizes, + tir_primfuncs, relay_primfuncs); + + // The primitive function name here corresponds to the string we will use to generate + // this Relay function at the low level. + function_metadata.Set(prim_fn_var.value()->name_hint, fi); +} + +LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map, + backend::StaticMemoryPlan memory_plan, const String& module_name, + std::function process_fn) { + TECompiler compiler; + + CHECK_EQ(module->functions.size(), 1) + << "There should only be one function in the module passed to LowerTE"; + + auto pass = CreateFunctionPass( + [=](Function func, IRModule module, PassContext ctx) { + LowerTensorExpr lower_te(module, targets, device_context_map, process_fn, module_name, + compiler); + return Downcast(lower_te.VisitExpr(func)); + }, + 0, "LowerTensorExpr", {}); + + // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize + backend::FunctionInfo func_info = + UpdateMainWorkspaceSize(module, targets, memory_plan->expr_to_storage_info); + + auto updated_module = pass(module); + + // A temporary solution until we can rewrite the auto-scheduler task extraction code to work + // in a more reasonable way. + if (backend::IsAutoSchedulerEnabled()) { + const auto* te_compiler_update_weights = + runtime::Registry::Get("auto_scheduler.relay_integration.te_compiler_update_weights"); + + ICHECK(te_compiler_update_weights != nullptr) + << "auto_scheduler.relay_integration.te_compiler_update_weights"; + + Map weight_map; + + for (auto pair : compiler->GetOpWeights()) { + weight_map.Set(pair.first, pair.second); + } + + (*te_compiler_update_weights)(weight_map); + } + + LoweredModule lowered_module; + lowered_module.main_module = updated_module; + lowered_module.per_target_module = compiler->GetLoweredFunctions(); + lowered_module.external_mods = compiler->LowerExternalFunctions(); + lowered_module.main_func_info = func_info; + return lowered_module; +} + +} // namespace tec +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h new file mode 100644 index 000000000000..a32eefb5127f --- /dev/null +++ b/src/relay/backend/te_compiler.h @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/tir_compiler.h + * * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. + * + * + * This represents the new design of the Relay compilation flow and will replace the interface + * contained in compile_engine.h as we migrate towards a standard pass based lowering of + * Relay functions. + * + * This files provides an internal API which lowers Relay programs to components which + * can be combined with TVM produced kernels to compile an entire program. + * + * The result of lowering contains a combination of `runtime::Module`s produced by external + * compilers and a set of lowered PrimFns which can be code generated for targets. + */ +#ifndef TVM_RELAY_BACKEND_TE_COMPILER_H_ +#define TVM_RELAY_BACKEND_TE_COMPILER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../transforms/infer_layout_utils.h" +#include "../transforms/pass_utils.h" +#include "./te_compiler_cache.h" +#include "utils.h" + +namespace tvm { +namespace relay { +namespace tec { + +// This class is needed to avoid a GCC 5 bug that prevents maps containing enums +// from being compiled. If i386 GCC version is increased, we can remove it. +struct EnumClassHash { + template + std::size_t operator()(T t) const { + return static_cast(t); + } +}; + +// TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity sake +// we should a version of context which works in Map +using TargetMap = std::unordered_map; +using DeviceMap = + std::unordered_map; +using ProcessFn = std::function; + +/*! + * \brief A compiler which lowers primitive Relay functions to tensor expressions + * and schdules them into TIR functions. + */ +class TECompilerNode : public Object { + public: + /*! \brief destructor */ + virtual ~TECompilerNode() {} + /*! + * \brief Get lowered result. + * \param key The key to the cached function. + * \return The result. + */ + virtual CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) = 0; + + /*! + * \brief Get lowered result. + * \param key The key to the cached function. + * \return The result. + */ + virtual CachedFunc Lower(const CCacheKey& key, const String mod_name) = 0; + + /* Return all functions which have been lowered by the compiler, keyed by target. */ + virtual Map GetLoweredFunctions() = 0; + + /*! + * \brief Just in time compile to get a PackedFunc. + * \param key The key to the cached function. + * \return The result. + */ + virtual PackedFunc JIT(const CCacheKey& key) = 0; + /*! + * \brief Lower the shape function. + * \param key The key to the cached function. + * \return The result. + */ + virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0; + /*! + * \brief Lower the external function using external codegen tools. + * \return The runtime moduels for each needed external codegen tool. + */ + virtual tvm::Array LowerExternalFunctions() = 0; + + virtual std::unordered_map GetOpWeights() = 0; + + /*! \brief clear the cache. */ + virtual void Clear() = 0; + + void VisitAttrs(AttrVisitor*) {} + + static constexpr const char* _type_key = "relay.TECompiler"; + TVM_DECLARE_FINAL_OBJECT_INFO(TECompilerNode, Object); +}; + +/*! \brief cache entry used in compile engine */ +class TECompiler : public ObjectRef { + public: + TECompiler(); + explicit TECompiler(ObjectPtr n) : ObjectRef(n) {} + TECompilerNode* operator->() { return static_cast(get_mutable()); } + using ContainerType = TECompilerNode; +}; + +/*! \brief The result of lowering a module, for now we need to pass an aggregate data structure + * which contains more then a single module in order to interact with the today API. + */ +struct LoweredModule { + /*! \brief The module which contains the Relay code. */ + IRModule main_module; + /*! \brief The module which contains per target code. */ + Map per_target_module; + /*! \brief The external runtime modules which must be combined with the lowered code. */ + Array external_mods; + // TODO(@electriclilies): THis might need to become a map + /*! \brief The info for this function (not sure what a better description is??) + * + */ + backend::FunctionInfo main_func_info; +}; + +/*! + * \brief A function to create the function metadata for an input function (ie calculate buffer + * input/output sizes) + * \param relay_func The function to calculate function metadata for + * \param function_metadata The map that stores all the function metadatas + */ +void UpdateFunctionMetadata(Function relay_func, + Map& function_metadata); // NOLINT(*) + +/*! + * \brief Obtain the Target from the device type. + * If homogenous compilation, this will return the only target. + * If heteregenous compilation, this will select associated using the targets_ Map. + * + * \param dev_type + * \return Target + */ +Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets); + +/*! \brief Lower an IRModule's primitive functions to TIR. + * + * This is the "back half" of the Relay compiler which lowers "primitive functions" + * to TE expressions, schedules them, and then to TIR. + * + * /param module The IRModule. + * /param targets The mapping for devices to targets. + * /param device_map An analysis result mapping each sub-expression to a device. + * /return The lowered module, see above. + */ +// TODO(@electriclilies): Not sure if this default initialization is correct... +LoweredModule LowerTE( + const IRModule& module, TargetMap targets, DeviceMap device_map, + backend::StaticMemoryPlan memory_plan, const String& module_name, + ProcessFn process_fn = [](Function f) {}); + +} // namespace tec +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_TE_COMPILER_H_ diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc new file mode 100644 index 000000000000..bbe38f0426b4 --- /dev/null +++ b/src/relay/backend/te_compiler_cache.cc @@ -0,0 +1,694 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./te_compiler_cache.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../transforms/pass_utils.h" +#include "utils.h" + +namespace tvm { +namespace relay { +namespace tec { + +TVM_REGISTER_NODE_TYPE(LoweredOutputNode); +TVM_REGISTER_NODE_TYPE(CachedFuncNode); +TVM_REGISTER_NODE_TYPE(CCacheKeyNode); +TVM_REGISTER_NODE_TYPE(CCacheValueNode); + +LoweredOutput::LoweredOutput(tvm::Array outputs, OpImplementation impl) { + auto n = make_object(); + n->outputs = std::move(outputs); + n->implementation = std::move(impl); + data_ = std::move(n); +} + +CCacheKey::CCacheKey(Function source_func, Target target) { + auto n = make_object(); + n->source_func = std::move(source_func); + n->target = std::move(target); + data_ = std::move(n); +} + +CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array inputs, + tvm::Array outputs, te::Schedule schedule, + tvm::Array shape_func_param_states, IRModule funcs) { + auto n = make_object(); + n->target = target; + n->prim_fn_var = prim_fn_var; + n->inputs = inputs; + n->outputs = outputs; + n->schedule = schedule; + n->shape_func_param_states = shape_func_param_states; + n->funcs = funcs; + data_ = std::move(n); +} + +Array GetShape(const Array& shape) { + // for now, we always use int32 shape when possible + // even if the result of shape inference becomes int64. + Array res; + for (IndexExpr val : shape) { + const int64_t* pval = tir::as_const_int(val); + if (pval != nullptr) { +#ifndef TVM_INDEX_DEFAULT_I64 + ICHECK_LE(pval[0], std::numeric_limits::max()) + << "dimension must be less then int32_t's max value"; + ICHECK_GE(pval[0], std::numeric_limits::min()) + << "dimension must be less then int32_t's max value"; + res.push_back(IntImm(DataType::Int(32), *pval)); +#else + res.push_back(val); +#endif // TVM_INDEX_DEFAULT_I64 + } else if (val->IsInstance()) { + res.push_back(val.as()->ToVar()); + } else { + res.push_back(val); + } + } + return res; +} + +// Construct a schedule for a given Relay primitive function and target. +class ScheduleBuilder : public backend::MemoizedExprTranslator> { + public: + explicit ScheduleBuilder(Target target) + : target_(target), device_copy_op_(Op::Get("device_copy")) { + // Whether to use auto_scheduler schedule. + use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); + } + + CachedFunc Create(const Function& prim_func, std::function renamer) { + Array fn_inputs; + for (Var param : prim_func->params) { + Array inputs; + if (const auto* ttype = param->checked_type().as()) { + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); + fn_inputs.push_back(tensor); + inputs.push_back(tensor); + } else { + // flatten tuple of tensor type. + const auto* tuple_type = param->type_as(); + for (Type field : tuple_type->fields) { + const auto* ttype = field.as(); + // TODO(@icemelon): Allow recursive tuple + ICHECK(ttype != nullptr); + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); + fn_inputs.push_back(tensor); + inputs.push_back(tensor); + } + } + memo_[param] = inputs; + } + readable_name_stream_ << "fused"; + auto outputs = this->VisitExpr(prim_func->body); + auto candidate_name = readable_name_stream_.str(); + constexpr static size_t kMaxFuncNameLength = 80; + if (candidate_name.size() > kMaxFuncNameLength) { + std::stringstream truncated_name; + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << "_" << std::hash{}(candidate_name) << "_"; + candidate_name = truncated_name.str(); + } + + // NB(@jroesch): unfortunately the graph runtime deals with copy in + // a totally hacky way, we really need to rectify this but this will + // have to work for now. + std::string prim_fn_name = candidate_name; + if (prim_fn_name != "__copy") { + prim_fn_name = renamer(prim_fn_name); + } + auto prim_fn_var = GlobalVar(prim_fn_name); + prim_fn_var->checked_type_ = prim_func->checked_type(); + + ICHECK(anchor_op_.defined()); + // Fusion over tupled results may leave identity relationships + // between inputs and outputs, and those should not be scheduled. + // Hence schedule only non PlaceholderOp outputs. + tvm::Array tensor_outs; + for (const auto& tensor : outputs) { + if (!tensor->op.as()) { + tensor_outs.push_back(tensor); + } + } + + te::Schedule schedule; + // No need to register schedule for device copy op. + if (anchor_attrs_.as() == nullptr) { + if (use_auto_scheduler_) { + const auto* fauto_schedule = + runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); + ICHECK(fauto_schedule != nullptr) + << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; + ObjectRef obj = (*fauto_schedule)(prim_fn_name, tensor_outs); + if (obj.defined()) { + schedule = Downcast(obj); + } + } + + // Use TOPI schdule if user specificed, or the function has no auto_scheduler schedule. + if (!schedule.defined()) { + ICHECK(anchor_implementation_.defined()); + schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); + } + for (const auto& scalar : scalars_) { + if (schedule->Contain(scalar)) { + schedule[scalar].compute_inline(); + } + } + } + + return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {}); + } + + Array VisitExpr_(const VarNode* op) final { + LOG(FATAL) << "Unexpected free variable " << op->name_hint(); + return {}; + } + + Array VisitExpr_(const ConstantNode* op) final { + using tir::make_const; + ICHECK(op->is_scalar()); + void* data = op->data->data; + DataType dtype = DataType(op->data->dtype); + auto value = te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::PrimExpr(); + } + }, + "compile_engine_const", topi::kBroadcast); + scalars_.push_back(value->op); + return {value}; + } + + Array VisitExpr_(const CallNode* call_node) final { + static auto fpattern = Op::GetAttrMap("TOpPattern"); + static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); + ICHECK(flower_call) << "relay.backend.lower_call is not registered."; + + Array inputs; + int count_tuple = 0; + for (Expr arg : call_node->args) { + if (arg->checked_type().as()) { + ++count_tuple; + } + for (te::Tensor tensor : VisitExpr(arg)) { + inputs.push_back(tensor); + } + } + + if (count_tuple) { + ICHECK_EQ(call_node->args.size(), 1U) + << "Only functions with a single tuple input are allowed, but " << count_tuple + << " were provided."; + } + + ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; + Op op = Downcast(call_node->op); + + Array outputs; + OpImplementation impl; + // Skip fcompute for device copy operators as it is not registered. + if (op == device_copy_op_) { + const auto* copy_input = inputs[0].operator->(); + outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0)); + } else { + LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); + outputs = lowered_out->outputs; + impl = lowered_out->implementation; + } + + int op_pattern = fpattern[op]; + if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { + ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) + << "Cannot apply TOPI schedule to a primitive function with two complicated ops" + << " anchor=" << anchor_op_ << " current=" << op; + } + if (op_pattern >= anchor_op_pattern_) { + anchor_op_ = op; + anchor_attrs_ = call_node->attrs; + anchor_op_pattern_ = op_pattern; + anchor_implementation_ = impl; + } + if (outputs.size() != 1) { + const auto* tuple_type = call_node->checked_type().as(); + ICHECK(tuple_type) << "Expected output to be a tuple type " + << PrettyPrint(call_node->checked_type()); + + ICHECK_EQ(tuple_type->fields.size(), outputs.size()); + } + // Set the name to `__copy`. It will be detected in graph runtime to perform + // data copy across devices. + if (op == device_copy_op_) { + readable_name_stream_.str(std::string()); + readable_name_stream_ << "__copy"; + } else { + readable_name_stream_ << '_' << op->name; + } + return outputs; + } + + Array VisitExpr_(const FunctionNode* op) final { + LOG(FATAL) << "Primitive Functions can not contain nested functions."; + return Array(); + } + + Array VisitExpr_(const LetNode* op) final { + Array val = VisitExpr(op->value); + ICHECK(!memo_.count(op->var)); + memo_[op->var] = val; + return VisitExpr(op->body); + } + + Array VisitExpr_(const TupleNode* op) final { + Array fields; + for (Expr field : op->fields) { + ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; + Array res = VisitExpr(field); + ICHECK_EQ(res.size(), 1); + fields.push_back(res[0]); + } + return fields; + } + + Array VisitExpr_(const TupleGetItemNode* op) final { + const auto* tuple_type = op->tuple->type_as(); + Array tuple = VisitExpr(op->tuple); + ICHECK_EQ(tuple_type->fields.size(), tuple.size()); + ICHECK_GE(op->index, 0); + ICHECK_LT(static_cast(op->index), tuple.size()); + return {tuple[op->index]}; + } + + private: + tvm::Target target_; + Op anchor_op_; + Attrs anchor_attrs_; + int anchor_op_pattern_{0}; + OpImplementation anchor_implementation_; + std::ostringstream readable_name_stream_; + Array scalars_; + bool use_auto_scheduler_; + // Cache device copy op for equivalence checking to reduce registry lookup + // overhead for each invocation of call node when retrieving schedules. + const Op& device_copy_op_; +}; + +/*! + * \brief Create schedule for target. + * \param source_func The primitive function to be lowered. + * \param target The target we want to create schedule for. + * \return Pair of schedule and cache. + * The funcs field in cache is not yet populated. + */ +CachedFunc PrimFuncFor(const Function& source_func, const Target& target, + std::function renamer) { + return ScheduleBuilder(target).Create(source_func, renamer); +} + +// Creates shape function from functor. +class MakeShapeFunc : public backend::MemoizedExprTranslator> { + public: + MakeShapeFunc() {} + + CachedFunc Create(const Function& prim_func, const Target& target, + std::function renamer) { + Array inputs; + TShapeDataDependent shape_func_param_states; + + for (auto param : prim_func->params) { + param_states_[param] = kNoNeed; + Array data_inputs; + Array shape_inputs; + + auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) { + // Add data placeholder + Shape shape = GetShape(ttype->shape); + tvm::te::Tensor data_tensor = tvm::te::placeholder(shape, ttype->dtype); + data_inputs.push_back(data_tensor); + // Add shape placeholder + int64_t ndim = shape.size(); + Shape sshape; + if (ndim > 0) { + sshape.push_back(tvm::Integer(ndim)); + } + tvm::te::Tensor shape_tensor = tvm::te::placeholder(sshape, DataType::Int(64)); + shape_inputs.push_back(shape_tensor); + }; + + if (const auto* ttype = param->checked_type().as()) { + add_placeholder(ttype); + } else { + // flatten tuple of tensor type. + const auto* tuple_type = param->type_as(); + // TODO(@icemelon): Support recursive tuple + ICHECK(tuple_type); + for (Type field : tuple_type->fields) { + const auto* ttype = field.as(); + ICHECK(ttype); + add_placeholder(ttype); + } + } + param_data_[param] = data_inputs; + param_shapes_[param] = shape_inputs; + } + + // Setup the name; + readable_name_stream_ << "shape_func"; + + // Create the `te::Tensor`s which represent the output. + auto outputs = VisitExpr(prim_func->body); + + // Generate a name. + auto candidate_name = readable_name_stream_.str(); + constexpr static size_t kMaxFuncNameLength = 80; + if (candidate_name.size() > kMaxFuncNameLength) { + std::stringstream truncated_name; + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << "_" << std::hash{}(candidate_name) << "_"; + candidate_name = truncated_name.str(); + } + + // Set all the inputs correctly. + for (auto param : prim_func->params) { + int state = param_states_[param]; + shape_func_param_states.push_back(IntImm(DataType::Int(32), state)); + if (state & kNeedInputData) { + for (auto t : param_data_[param]) { + inputs.push_back(t); + } + } + if (state & kNeedInputShape) { + for (auto t : param_shapes_[param]) { + inputs.push_back(t); + } + } + } + + auto func_name = renamer(candidate_name); + auto prim_fn_gvar = GlobalVar(func_name); + prim_fn_gvar->checked_type_ = prim_func->checked_type(); + + // generate schedule for shape func + Array out_ops; + for (auto t : outputs) { + out_ops.push_back(t->op); + } + auto schedule = te::create_schedule(out_ops); + tvm::te::AutoInlineInjective(schedule); + for (const auto& scalar : scalars_) { + auto scalar_op = scalar->op; + if (schedule->Contain(scalar_op)) { + schedule[scalar_op].compute_inline(); + } + } + + Array all_args = Array(inputs); + for (te::Tensor arg : outputs) { + all_args.push_back(arg); + } + + using tvm::transform::PassContext; + With fresh_pass_ctx_scope(PassContext::Create()); + + std::unordered_map binds; + IRModule ir_module = tvm::LowerSchedule(schedule, all_args, func_name, binds); + + return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, shape_func_param_states, + ir_module); + } + + Array VisitExpr(const Expr& expr) final { + if (expr.as()) { + // Do not memoize vars because shape functions could use either the data + // or the shape of a var each time. + return ExprFunctor::VisitExpr(expr); + } + // For other case, do memoized visit + return backend::MemoizedExprTranslator>::VisitExpr(expr); + } + + Array VisitExpr_(const VarNode* var_node) final { + auto var = GetRef(var_node); + auto it = param_states_.find(var); + if (it == param_states_.end()) { + LOG(FATAL) << "Unexpected free variable " << var->name_hint(); + return {}; + } else { + ICHECK(data_dependents_per_input_.size()); + auto data_dependent = data_dependents_per_input_.back(); + if (data_dependent) { + param_states_[var] |= kNeedInputData; + return param_data_[var]; + } else { + param_states_[var] |= kNeedInputShape; + return param_shapes_[var]; + } + } + } + + Array VisitExpr_(const ConstantNode* op) final { + using tir::make_const; + ICHECK(data_dependents_per_input_.size()); + bool data_dependent = data_dependents_per_input_.back(); + if (!op->is_scalar()) { + // This is a constant weight, extract the shape of the weight tensor. + // This can not be data dependent. + CHECK(!data_dependent); + auto ttype = op->checked_type().as(); + int ndim = static_cast(ttype->shape.size()); + Array out_shape{ndim}; + te::Tensor value = tvm::te::compute( + out_shape, + [&](const Array& indices) { + auto idx = indices[0]; + PrimExpr ret = make_const(DataType::Int(64), 0); + for (int i = 0; i < ndim; i++) { + ret = tvm::if_then_else(idx == i, ttype->shape[i], ret); + } + return ret; + }, + "shape_const", topi::kBroadcast); + scalars_.push_back(value); + return {value}; + } + if (data_dependent) { + void* data = op->data->data; + DataType dtype = DataType(op->data->dtype); + auto value = tvm::te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::PrimExpr(); + } + }, + "data_const", topi::kBroadcast); + scalars_.push_back(value); + return {value}; + } else { + auto value = tvm::te::compute( + {}, [&](const Array&) { return tir::make_const(DataType::Int(64), 0); }, + "shape_const", topi::kBroadcast); + scalars_.push_back(value); + return {value}; + } + } + + Array VisitExpr_(const CallNode* call_node) final { + static auto fshape_func = Op::GetAttrMap("FShapeFunc"); + static auto tshape_data_dependent = Op::GetAttrMap("TShapeDataDependent"); + ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; + Op op = Downcast(call_node->op); + ICHECK(data_dependents_per_input_.empty() || !data_dependents_per_input_.back()) + << "Error in op fusion: output of the shape func is fed to a " + << "data-dependent shape func"; + ICHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name; + ICHECK_GT(tshape_data_dependent.count(op), 0) + << "Internal error, cannot find TShapeDataDependent for " << op->name; + + Array dep_spec = tshape_data_dependent[op]; + if (dep_spec.size() == 1) { + // This is for cases when data dependence is specified per op + // Replicate 0 or 1 flag to all arguments + for (size_t i = 1; i < call_node->args.size(); ++i) { + dep_spec.push_back(dep_spec[0]); + } + } + + // Visit all inputs + Array inputs; + int count_tuple = 0; + for (size_t i = 0; i < call_node->args.size(); ++i) { + Expr arg = call_node->args[i]; + if (arg->checked_type().as()) { + ++count_tuple; + } + data_dependents_per_input_.push_back(dep_spec[i]->value != 0); + for (te::Tensor tensor : VisitExpr(arg)) { + inputs.push_back(tensor); + } + data_dependents_per_input_.pop_back(); + } + if (count_tuple) { + ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; + } + // Get output ndims + auto ret_type = call_node->checked_type(); + Array out_ndims; + if (const auto* ttype = ret_type.as()) { + out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); + } else { + auto rtype = ret_type.as(); + // TODO(@icemelon): Allow recursive tuple + ICHECK(rtype); + for (size_t i = 0; i < rtype->fields.size(); ++i) { + auto ttype = rtype->fields[i].as(); + ICHECK(ttype); + out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); + } + } + // Call shape function + auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims); + readable_name_stream_ << "_" << op->name; + return outputs; + } + + Array VisitExpr_(const FunctionNode* op) final { + LOG(FATAL) << "Do not support sub function"; + return Array(); + } + + Array VisitExpr_(const LetNode* op) final { + Array val = VisitExpr(op->value); + ICHECK(!memo_.count(op->var)); + memo_[op->var] = val; + return VisitExpr(op->body); + } + + Array VisitExpr_(const TupleNode* op) final { + Array fields; + for (Expr field : op->fields) { + ICHECK(field->checked_type().as()) + << "Expected a Tuple of Tensor, but got " << PrettyPrint(field->checked_type()); + Array res = VisitExpr(field); + ICHECK_EQ(res.size(), 1); + fields.push_back(res[0]); + } + return fields; + } + + Array VisitExpr_(const TupleGetItemNode* op) final { + Array input_shapes = VisitExpr(op->tuple); + Array out; + out.push_back(input_shapes[op->index]); + return out; + } + + private: + /*! \brief String stream for function name */ + std::ostringstream readable_name_stream_; + /*! \brief Map from parameter to its shape function usage state */ + std::unordered_map param_states_; + /*! \brief Map from parameter to list of data placeholder */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_data_; + /*! \brief Map from parameter to list of shape placeholder */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_shapes_; + /*! \brief Stack of data dependencies for shape function, specified per each op input */ + std::vector data_dependents_per_input_; + /*! \brief Scalars used in the shape function */ + Array scalars_; +}; + +CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, + std::function renamer) { + return MakeShapeFunc().Create(prim_func, target, renamer); +} + +/*! + * \brief Get unique name from name. + * \param name The orginal name. + * \return Updated name which is unique. + */ +std::string GetUniqueName(std::string name, std::unordered_map* name_map_) { + for (size_t i = 0; i < name.length(); ++i) { + if (name[i] == '.') name[i] = '_'; + } + while (true) { + auto it = name_map_->find(name); + if (it == name_map_->end()) { + (*name_map_)[name] = 1; + return name; + } else { + std::ostringstream os; + os << name << "_" << it->second; + ++(it->second); + name = os.str(); + } + } + return name; +} + +} // namespace tec +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h new file mode 100644 index 000000000000..1c7511ffd7d2 --- /dev/null +++ b/src/relay/backend/te_compiler_cache.h @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/tec_compiler_cache.h + * \brief Utilities for compiling tensor expressions inside of the Relay compiler. + */ +#ifndef TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ +#define TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../transforms/infer_layout_utils.h" + +namespace tvm { +namespace relay { +namespace tec { + +/*! \brief Indicate whether the data or shape or both of a parameter is used in the shape func. */ +enum ShapeFuncParamState { + kNoNeed = 0, + kNeedInputData = 1, + kNeedInputShape = 2, + kNeedBoth = 3, +}; + +struct LoweredOutputNode : public Object { + /*! \brief The outputs to the function */ + tvm::Array outputs; + /*! \brief The implementation used to compute the output */ + OpImplementation implementation; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("outputs", &outputs); + v->Visit("implementation", &implementation); + } + + static constexpr const char* _type_key = "relay.LoweredOutput"; + TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object); +}; + +class LoweredOutput : public ObjectRef { + public: + TVM_DLL LoweredOutput(tvm::Array outputs, OpImplementation impl); + + TVM_DEFINE_OBJECT_REF_METHODS(LoweredOutput, ObjectRef, LoweredOutputNode); +}; + +class CCacheKey; +/*! \brief Compile cache key */ +class CCacheKeyNode : public Object { + public: + /*! \brief The source function to be lowered. */ + Function source_func; + /*! \brief The hardware target.*/ + Target target; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("source_func", &source_func); + v->Visit("target", &target); + } + /*! \return The hash value of CCacheKey. */ + inline size_t Hash() const; + /*! + * \brief check content equality + * \param other The other value. + * \return The result of equality check. + */ + inline bool Equal(const CCacheKeyNode* other) const; + + static constexpr const char* _type_key = "relay.CCacheKey"; + TVM_DECLARE_FINAL_OBJECT_INFO(CCacheKeyNode, tvm::Object); + + private: + /*! + * \brief internal cached hash value. + */ + mutable size_t hash_{0}; +}; + +/*! \brief cache entry used in compile engine */ +class CCacheKey : public ObjectRef { + public: + CCacheKey() {} + explicit CCacheKey(ObjectPtr n) : ObjectRef(n) {} + + /*! + * \brief The constructor + * \param source_func The source function. + * \param target The target device. + */ + TVM_DLL CCacheKey(Function source_func, Target target); + + const CCacheKeyNode* operator->() const { return static_cast(get()); } + // comparator + inline bool operator==(const CCacheKey& other) const { + ICHECK(defined() && other.defined()); + return (*this)->Equal(other.operator->()); + } + using ContainerType = CCacheKeyNode; +}; + +/*! \brief Node container to represent a cached function. */ +struct CachedFuncNode : public Object { + /* \brief compiled target */ + tvm::Target target; + /*! \brief Primitive Function Name */ + GlobalVar prim_fn_var; + /* \brief The inputs to the function */ + tvm::Array inputs; + /* \brief The outputs to the function */ + tvm::Array outputs; + /*! \brief The schedule to the function */ + te::Schedule schedule; + /*! \brief Parameter usage states in the shape function. */ + tvm::Array shape_func_param_states; + /*! \brief The lowered functions to support the function. */ + IRModule funcs = IRModule(Map({})); + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("target", &target); + v->Visit("prim_fn_var", &prim_fn_var); + v->Visit("inputs", &inputs); + v->Visit("outputs", &outputs); + v->Visit("schedule", &schedule); + v->Visit("funcs", &funcs); + v->Visit("shape_func_param_states", &shape_func_param_states); + } + + static constexpr const char* _type_key = "relay.CachedFunc"; + TVM_DECLARE_FINAL_OBJECT_INFO(CachedFuncNode, Object); +}; + +class CachedFunc : public ObjectRef { + public: + CachedFunc(tvm::Target target, GlobalVar prim_fn_name, tvm::Array inputs, + tvm::Array outputs, te::Schedule schedule, + tvm::Array shape_func_param_states, + IRModule funcs = IRModule(Map({}))); + + public: + TVM_DEFINE_OBJECT_REF_METHODS(CachedFunc, ObjectRef, CachedFuncNode); +}; + +/*! \brief Node container for compile cache. */ +class CCacheValueNode : public Object { + public: + /*! \brief The corresponding function */ + CachedFunc cached_func; + /*! \brief Result of Packed function generated by JIT */ + PackedFunc packed_func; + /*! \brief usage statistics */ + int use_count{0}; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("cached_func", &cached_func); + v->Visit("use_count", &use_count); + } + static constexpr const char* _type_key = "relay.CCacheValue"; + TVM_DECLARE_FINAL_OBJECT_INFO(CCacheValueNode, tvm::Object); +}; + +/*! \brief cache entry used in compile engine */ +class CCacheValue : public ObjectRef { + public: + CCacheValue() {} + explicit CCacheValue(ObjectPtr n) : ObjectRef(n) {} + CCacheValueNode* operator->() { return static_cast(get_mutable()); } + const CCacheValueNode* operator->() const { return static_cast(get()); } + using ContainerType = CCacheValueNode; +}; + +Array GetShape(const Array& shape); + +/*! + * \brief Create schedule for target. + * \param source_func The primitive function to be lowered. + * \param target The target we want to create schedule for. + * \return Pair of schedule and cache. + * The funcs field in cache is not yet populated. + */ +CachedFunc PrimFuncFor(const Function& source_func, const Target& target, + std::function renamer); + +CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, + std::function renamer); + +std::string GetUniqueName(std::string name, std::unordered_map* name_map); + +// implementations +inline size_t CCacheKeyNode::Hash() const { + if (hash_ != 0) return hash_; + // do structral hash, avoid 0. + hash_ = tvm::StructuralHash()(this->source_func); + hash_ = dmlc::HashCombine(hash_, std::hash()(target->str())); + if (hash_ == 0) hash_ = 1; + return hash_; +} + +inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { + if (Hash() != other->Hash()) return false; + return this->target->str() == other->target->str() && + tvm::StructuralEqual()(this->source_func, other->source_func); +} + +} // namespace tec +} // namespace relay +} // namespace tvm + +namespace std { +// overload hash +template <> +struct hash<::tvm::relay::tec::CCacheKey> { + size_t operator()(const ::tvm::relay::tec::CCacheKey& key) const { + ICHECK(key.defined()); + return key->Hash(); + } +}; +} // namespace std + +#endif // TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 3ea15438fe8f..f0c543f1244b 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -39,6 +39,30 @@ StorageInfo::StorageInfo(std::vector storage_ids, std::vector ids; + for (auto id : si->storage_ids) { + ids.push_back(id); + } + return ids; +}); + +TVM_REGISTER_GLOBAL("relay.ir.StorageInfoDeviceTypes").set_body_typed([](StorageInfo si) { + Array device_types; + for (auto id : si->device_types) { + device_types.push_back(id); + } + return device_types; +}); + +TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageSizes").set_body_typed([](StorageInfo si) { + Array storage_sizes_in_bytes; + for (auto id : si->storage_sizes_in_bytes) { + storage_sizes_in_bytes.push_back(id); + } + return storage_sizes_in_bytes; +}); + TVM_REGISTER_NODE_TYPE(StaticMemoryPlanNode); StaticMemoryPlan::StaticMemoryPlan(Map expr_to_storage_info) { @@ -73,6 +97,29 @@ int64_t CalculateRelayExprSizeBytes(const Type& expr_type) { TVM_REGISTER_NODE_TYPE(FunctionInfoNode); +FunctionInfo::FunctionInfo(Map workspace_sizes, Map io_sizes, + Map constant_sizes, + Map tir_primfuncs, + Map relay_primfuncs) { + ObjectPtr n = make_object(); + n->workspace_sizes = std::move(workspace_sizes); + n->io_sizes = std::move(io_sizes); + n->constant_sizes = std::move(constant_sizes); + n->tir_primfuncs = std::move(tir_primfuncs); + n->relay_primfuncs = std::move(relay_primfuncs); + data_ = std::move(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "FunctionInfoNode(\n" + << "workspace_sizes=" << node->workspace_sizes << ",\n io_sizes=" << node->io_sizes + << ",\n constant_sizes=" << node->constant_sizes + << ",\n tir_primfuncs=" << node->tir_primfuncs + << ",\n relay_primfuncs=" << node->relay_primfuncs << ")"; + }); + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 7d7f026c298e..d2a173a43f46 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -114,6 +114,10 @@ struct FunctionInfoNode : public Object { class FunctionInfo : public ObjectRef { public: + FunctionInfo(Map workspace_sizes, Map io_sizes, + Map constant_sizes, Map tir_primfuncs, + Map relay_primfuncs); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(FunctionInfo, ObjectRef, FunctionInfoNode); }; diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index c50f2f65f949..96aa77f286a9 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -978,7 +978,7 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe // update primitive function map size_t primitive_index = 0; for (const auto& cfunc : context_.cached_funcs) { - exec_->primitive_map.insert({cfunc->func_name, primitive_index++}); + exec_->primitive_map.insert({cfunc->prim_fn_var->name_hint, primitive_index++}); } } @@ -1173,8 +1173,9 @@ void VMCompiler::Codegen() { if (target->kind->device_type == kDLExtDev) { // Collect metadata in functions that are handled by external codegen. - ICHECK(mod->ContainGlobalVar(cfunc->func_name)); - Function func = Downcast(mod->Lookup(cfunc->func_name)); + auto name = cfunc->prim_fn_var->name_hint; + ICHECK(mod->ContainGlobalVar(name)); + Function func = Downcast(mod->Lookup(name)); backend::UpdateConstants(func, ¶ms_); } else if (funcs.count(target) == 0) { funcs.Set(target, mod); diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index c9920a621b56..83ac55fce085 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -62,9 +62,17 @@ TVM_REGISTER_GLOBAL("relay.ir.Function") TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << node->body - << ", " << node->type_params << ", " << node->attrs << ")"; + // TODO(@jroesch): previously this had a debug printer, the debug printer + // can cause exponential behavior and is currently dangerous, for these + // cases we need some kind of de-duping. + // + // See old implementation: + // + // auto* node = static_cast(ref.get()); + // p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << + // node->body + // << ", " << node->type_params << ", " << node->attrs << ")"; + p->stream << PrettyPrint(ref); }); } // namespace relay diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index da0bd35a332a..7a86af8aeffa 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -126,7 +126,7 @@ Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function."; (*f)(); - CreateSchedule(GetRef(func), Target::Current()); + PrimFuncFor(GetRef(func), Target::Current(), [](std::string name) { return name; }); f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite"); CHECK(f) << "Could not find ansor.exit_layout_rewrite function."; diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index e744fb51e0a6..02f9d474411a 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -53,7 +53,18 @@ bool IsOnDeviceNode(const ExprNode* node) { bool IsDeviceCopyNode(const ExprNode* node) { if (!node->IsInstance()) return false; const auto* call_node = static_cast(node); - return call_node->attrs.as(); + + if (call_node->attrs.as()) { + return true; + } + + auto tir_call_attrs = call_node->attrs.as(); + if (tir_call_attrs) { + auto metadata = tir_call_attrs->metadata; + return metadata.count("source_device") == 1 && metadata.count("dst_device") == 1; + } + + return false; } } // namespace @@ -395,16 +406,31 @@ class DeviceInfo { const auto* call_node = static_cast(node); auto attrs = call_node->attrs.as(); - num_device_copy_ops_++; - dev_type_ = attrs->src_dev_type; - for (auto& arg : call->args) { - Visit(arg); - // restore the type for remaining arguments + if (attrs) { + num_device_copy_ops_++; dev_type_ = attrs->src_dev_type; + for (auto& arg : call->args) { + Visit(arg); + // restore the type for remaining arguments + dev_type_ = attrs->src_dev_type; + } + device_tag_[call] = attrs->dst_dev_type; + // update the out_dev_type_, which should be the dst_dev_type of last copy + out_dev_type_ = attrs->dst_dev_type; + } else { + auto attrs = call_node->attrs.as(); + CHECK(attrs) << "must be non-null"; + num_device_copy_ops_++; + dev_type_ = Downcast(attrs->metadata["source_device"]); + for (auto& arg : call->args) { + Visit(arg); + // restore the type for remaining arguments + dev_type_ = Downcast(attrs->metadata["source_device"]); + } + device_tag_[call] = Downcast(attrs->metadata["dst_device"]); + // update the out_dev_type_, which should be the dst_dev_type of last copy + out_dev_type_ = Downcast(attrs->metadata["dst_device"]); } - device_tag_[call] = attrs->dst_dev_type; - // update the out_dev_type_, which should be the dst_dev_type of last copy - out_dev_type_ = attrs->dst_dev_type; } else { for (auto& arg : call->args) { int cur_dev_type = dev_type_; diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 03473b7d7455..b61567d0bae0 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -43,6 +44,7 @@ #include "../backend/compile_engine.h" #include "../op/memory/memory.h" #include "../op/vm/vm.h" +#include "./pass_utils.h" #include "let_list.h" #include "pattern_utils.h" @@ -66,9 +68,18 @@ inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dt // Check if the primitive function contains only reshape ops. bool IsReshapeOnly(const Expr& expr) { - if (auto* func = expr.as()) { + if (const FunctionNode* func = expr.as()) { return func->HasNonzeroAttr(attr::kReshapeOnly); } + if (const CallNode* call = expr.as()) { + if (call->attrs.defined()) { + if (auto tir_call_attrs = call->attrs.as()) { + Map metadata = tir_call_attrs->metadata; + return metadata.count(attr::kReshapeOnly) && + (Downcast(metadata[attr::kReshapeOnly])->value == 1); + } + } + } return false; } diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 4c6013792426..f29087dcc049 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -205,8 +205,13 @@ class TypeInferencer : private ExprFunctor, this->EmitFatal(Diagnostic::Error(op->span) << "Cannot do type inference on global variables " << "without a module"); } - relay::Function e = Downcast(mod_->Lookup(var)); - return e->checked_type(); + + if (mod_->ContainGlobalVar(var->name_hint)) { + relay::Function e = Downcast(mod_->Lookup(var)); + return e->checked_type(); + } else { + return op->checked_type_; + } } Type VisitExpr_(const ConstantNode* op) final { return op->tensor_type(); } diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 24fb3dc95819..15a1493b8585 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -223,8 +223,12 @@ class LLVMModuleNode final : public runtime::ModuleNode { found_linked_params = true; continue; } - ICHECK(kv.second->IsInstance()) - << "Can only lower IR Module with PrimFuncs, but got " << kv.second->GetTypeKey(); + if (!kv.second->IsInstance()) { + // (@jroesch): we relax constraints here, Relay functions will just be ignored. + DLOG(INFO) << "Can only lower IR Module with PrimFuncs, but got " + << kv.second->GetTypeKey(); + continue; + } auto f = Downcast(kv.second); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()); @@ -234,7 +238,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { } funcs.push_back(f); } - ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params)); + // TODO(@jroesch): follow up on this condition. + // ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params)); // TODO(tqchen): remove the entry function behavior as it does not // makes sense when we start to use multiple modules. cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime); diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc index 424da1e817b6..cb2b50260326 100644 --- a/src/tir/transforms/legalize_packed_calls.cc +++ b/src/tir/transforms/legalize_packed_calls.cc @@ -109,7 +109,7 @@ Pass LegalizePackedCalls() { inputs[i] = true; } n->body = PackedCallLegalizer().Legalize(inputs, std::move(n->body)); - return std::move(f); + return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LegalizePackedCalls", {}); } diff --git a/tests/python/relay/test_auto_scheduler_task_extraction.py b/tests/python/relay/test_auto_scheduler_task_extraction.py index cfbca40cf379..0acd8b0c87d6 100644 --- a/tests/python/relay/test_auto_scheduler_task_extraction.py +++ b/tests/python/relay/test_auto_scheduler_task_extraction.py @@ -101,6 +101,7 @@ def test_task_extraction_cuda(): mod, params = get_network("mlp") tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) + assert len(tasks) == 1 assert sum(task_weights) == 2 diff --git a/tests/python/relay/test_backend_graph_executor.py b/tests/python/relay/test_backend_graph_executor.py index 4ec1c21467fc..e7040f55f631 100644 --- a/tests/python/relay/test_backend_graph_executor.py +++ b/tests/python/relay/test_backend_graph_executor.py @@ -130,22 +130,22 @@ def test_plan_memory(): mod = relay.transform.FuseOps(0)(mod) func = mod["main"] mod = relay.transform.InferType()(mod) - smap = relay.backend._backend.GraphPlanMemory(func) + memory_plan = relay.backend._backend.GraphPlanMemory(func) storage_ids = set() device_types = set() storage_sizes = {} - for k, v in smap.items(): - assert len(v) == 3 - for x in v[0]: - storage_ids.add(x.value) - storage_sizes[x.value] = v[2] - for x in v[1]: - device_types.add(x.value) + + for k, v in memory_plan.expr_to_storage_info.items(): + for x in v.storage_ids: + storage_ids.add(x) + storage_sizes[x] = v.storage_sizes + for x in v.device_types: + device_types.add(x) # Current rule requires vars have unique storage id # because we don't do inplace, we will need another # two alternating temporary space. - assert len(storage_ids) == 4 + assert len(storage_ids) == 4, f"found storage_ids: {storage_ids}" assert len(device_types) == 1 assert len(storage_sizes) == 4 @@ -288,11 +288,4 @@ def test_graph_executor_nested_tuples(): if __name__ == "__main__": - test_reshape_nop() - test_plan_memory() - test_with_params() - test_add_op_scalar() - test_add_op_tensor() - test_add_op_broadcast() - test_gru_like() - test_compile_nested_tuples() + sys.exit(pytest.main([file] + sys.argv[1:])) diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index f0949ab19f9c..c33bd5792242 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -42,14 +42,14 @@ def check_graph_executor( target, ref_res, device, func, params, config, opt_level, expected_index=None ): with tvm.transform.PassContext(opt_level=opt_level, config=config): - graph, lib, new_params = relay.build(func, target, params=params) + graph_executor_factory = relay.build(func, target, params=params) + contexts = [tvm.cpu(0), tvm.device(device)] - graph_json = json.loads(graph) + graph_json = json.loads(graph_executor_factory.graph_json) if "device_index" in graph_json["attrs"]: device_index = graph_json["attrs"]["device_index"][1] assert device_index == expected_index - mod = graph_executor.create(graph, lib, contexts) - mod.set_input(**new_params) + mod = graph_executor.GraphModule(graph_executor_factory["default"](*contexts)) mod.run() res = mod.get_output(0).numpy() tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) @@ -272,12 +272,14 @@ def check_storage_and_device_types(): smap = relay.backend._backend.GraphPlanMemory(func) storage_ids = [] device_types = [] - for _, storage_dev_type in smap.items(): - assert len(storage_dev_type) == 3 - for sid in storage_dev_type[0]: + for _, storage_info in smap.expr_to_storage_info.items(): + + for sid in storage_info.storage_ids: storage_ids.append(sid.value) - for did in storage_dev_type[1]: + + for did in storage_info.device_types: device_types.append(did.value) + assert len(storage_ids) == 10 assert len(set(storage_ids)) == 8 assert len(set(device_types)) == 2 @@ -350,16 +352,16 @@ def expected(): assert tvm.ir.structural_equal(annotated_expr, expected_expr) smap = relay.backend._backend.GraphPlanMemory(annotated_expr) - for expr, storage_dev_type in smap.items(): + for expr, storage_info in smap.expr_to_storage_info.items(): # x is dev1 as output is dev1 if isinstance(expr, tvm.relay.expr.Var): - assert storage_dev_type[1][0] == dev1.device_type + assert storage_info.device_types[0] == dev1.device_type else: # device_copy op should be its dst_dev_type if isinstance(expr.attrs, tvm.relay.op.op_attrs.DeviceCopyAttrs): - assert storage_dev_type[1][0] == expr.attrs.dst_dev_type + assert storage_info.device_types[0] == expr.attrs.dst_dev_type else: - assert storage_dev_type[1][0] == expected_dev_type[expr.op.name].device_type + assert storage_info.device_types[0] == expected_dev_type[expr.op.name].device_type def run_fusible_network(dev, tgt):