diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 819aafa0281cd..4e2e18375e224 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -97,7 +97,7 @@ struct AttrError : public dmlc::Error { * \brief constructor * \param msg error message */ - explicit AttrError(const std::string& msg) : dmlc::Error(msg) {} + explicit AttrError(std::string msg) : dmlc::Error("AttributeError:" + msg) {} }; /*! diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index a825b95294e97..f0399ef38b7ec 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -71,8 +71,8 @@ namespace transform { // Forward declare for TraceFunc. class PassInfo; -/*! \brief A callback for tracing passes, useful for debugging and logging. - * +/*! + * \brief A callback for tracing passes, useful for debugging and logging. */ using TraceFunc = runtime::TypedPackedFunc; @@ -96,19 +96,55 @@ class PassContextNode : public Object { int fallback_device{static_cast(kDLCPU)}; /*! \brief The list of required passes. */ - Array required_pass; + Array required_pass; /*! \brief The list of disabled passes. */ - Array disabled_pass; - + Array disabled_pass; + /*! \brief Trace function to be invoked before and after each pass. */ TraceFunc trace_func; + /*! \brief Pass specific configurations. */ + Map config; + PassContextNode() = default; + /*! + * \brief Get a config value from the pass context. + * + * \param key The config key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TOBjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef. + * + * \endcode + */ + template + Optional GetConfig(const std::string& key, Optional default_value = + Optional(nullptr)) const { + static_assert(std::is_base_of::value, + "Can only call GetAttr with ObjectRef types."); + if (!config.defined()) return default_value; + auto it = config.find(key); + if (it != config.end()) { + return Downcast>((*it).second); + } else { + return default_value; + } + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetConfig(const std::string& key, TObjectRef default_value) const { + return GetConfig(key, Optional(default_value)); + } + void VisitAttrs(AttrVisitor* v) { v->Visit("opt_level", &opt_level); v->Visit("fallback_device", &fallback_device); v->Visit("required_pass", &required_pass); v->Visit("disabled_pass", &disabled_pass); + v->Visit("config", &config); } static constexpr const char* _type_key = "transform.PassContext"; @@ -150,6 +186,7 @@ class PassContext : public ObjectRef { CHECK(get() != nullptr); return static_cast(get_mutable()); } + /*! * \brief Construct a PassContext containing the default configurations. * \return The new PassContext. @@ -169,6 +206,20 @@ class PassContext : public ObjectRef { */ TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const; + /*! + * \brief Register a valid configuration option and its ValueType for validation. + * + * \param key The configuration key. + * \tparam ValueNodeType The value type to be registered + */ + template + static uint32_t RegisterConfigOption(const char* key) { + // NOTE: we could further update the function later. + uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); + RegisterConfigOption(key, tindex); + return tindex; + } + // accessor. using ContainerType = PassContextNode; class Internal; @@ -178,12 +229,26 @@ class PassContext : public ObjectRef { TVM_DLL void EnterWithScope(); // The exit of a pass context scope. TVM_DLL void ExitWithScope(); + // Register configuration key value type. + TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index); // Classes to get the Python `with` like syntax. friend class Internal; friend class With; }; +#define TVM_PASS_CTX_CONFIG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_PassContext_tid + +/*! + * \brief Helper macro to register the object type to runtime. + * Makes sure that the runtime type table is correctly populated. + * + * Use this macro in the cc file for each terminal class. + */ +#define TVM_REGISTER_PASS_CONFIG_OPTION(Key, ValueType) \ + TVM_STR_CONCAT(TVM_PASS_CTX_CONFIG_VAR_DEF, __COUNTER__) = \ + ::tvm::transform::PassContext::RegisterConfigOption(Key) + /*! * \brief Meta data that will be used to help optimization and analysis. * \sa PassInfo diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 59e31897f967a..e97c5fd48bba7 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -146,6 +146,23 @@ class ReflectionVTable { */ TVM_DLL ObjectPtr CreateInitObject(const std::string& type_key, const std::string& repr_bytes = "") const; + /*! + * \brief Create an object by giving kwargs about its fields. + * + * \param type_key The type key. + * \param kwargs the arguments in format key1, value1, ..., key_n, value_n. + * \return The created object. + */ + TVM_DLL ObjectRef CreateObject(const std::string& type_key, const runtime::TVMArgs& kwargs); + /*! + * \brief Create an object by giving kwargs about its fields. + * + * \param type_key The type key. + * \param kwargs The field arguments. + * \return The created object. + */ + TVM_DLL ObjectRef CreateObject(const std::string& type_key, + const Map& kwargs); /*! * \brief Get an field object by the attr name. * \param self The pointer to the object. diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index af0be45b624e2..eb57e3404c3b9 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -70,13 +70,17 @@ class PassContext(tvm.runtime.Object): disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] The list of passes that are disabled. + + config : Optional[Dict[str, Object]] + Additional configurations for specific passes. """ def __init__(self, opt_level=2, fallback_device=_nd.cpu(), required_pass=None, disabled_pass=None, - trace=None): + trace=None, + config=None): if isinstance(fallback_device, str): fallback_device = _nd.context(fallback_device).device_type elif isinstance(fallback_device, tvm.runtime.TVMContext): @@ -97,7 +101,7 @@ def __init__(self, self.__init_handle_by_constructor__(_ffi_transform_api.PassContext, opt_level, fallback_device, required, - disabled, trace) + disabled, trace, config) def __enter__(self): _ffi_transform_api.EnterPassContext(self) diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 59e0c1c852764..1dbad1ac7e165 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -28,12 +28,11 @@ #include #include -// TODO(tqchen): Update to use String container after it is merged. -#include - #include #include +#include "../runtime/object_internal.h" + namespace tvm { namespace transform { @@ -75,6 +74,70 @@ PassContext PassContext::Current() { } } +class PassConfigManager { + public: + void Register(std::string key, uint32_t value_type_index) { + CHECK_EQ(key2vtype_.count(key), 0U); + ValueTypeInfo info; + info.type_index = value_type_index; + info.type_key = runtime::Object::TypeIndex2Key(value_type_index); + key2vtype_[key] = info; + } + + // Trying to validate and legalize a config. + void Legalize(Map* config) { + std::vector> update; + auto* reflection = ReflectionVTable::Global(); + + for (auto kv : *config) { + auto it = key2vtype_.find(kv.first); + if (it == key2vtype_.end()) { + std::ostringstream os; + os << "AttributeError: Invalid config option \'" << kv.first << "\' candidates are:"; + int counter = 0; + for (const auto& kv : key2vtype_) { + os << ' '; + if (counter++ != 0) os << ','; + os << kv.first; + } + LOG(FATAL) << os.str(); + } + const auto& info = it->second; + CHECK(kv.second.defined()) << "AttributeError: " << kv.first << " is None"; + if (kv.second->IsInstance::ContainerType>()) { + ObjectRef converted = reflection->CreateObject( + info.type_key, Downcast>(kv.second)); + update.emplace_back(kv.first, converted); + } else { + if (!runtime::ObjectInternal::DerivedFrom(kv.second.get(), info.type_index)) { + LOG(FATAL) << "AttributeError: expect config " << kv.first << " to have type " + << info.type_key << " but get " << kv.second->GetTypeKey(); + } + } + } + for (auto&& kv : update) { + config->Set(kv.first, kv.second); + } + } + + static PassConfigManager* Global() { + static auto* inst = new PassConfigManager(); + return inst; + } + + private: + struct ValueTypeInfo { + std::string type_key; + uint32_t type_index; + }; + + std::unordered_map key2vtype_; +}; + +void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_index) { + PassConfigManager::Global()->Register(key, value_type_index); +} + PassContext PassContext::Create() { return PassContext(make_object()); } void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { @@ -390,20 +453,23 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(PassContextNode); -TVM_REGISTER_GLOBAL("transform.PassContext").set_body([](TVMArgs args, TVMRetValue* ret) { - auto pctx = PassContext::Create(); - int opt_level = args[0]; - int fallback_device = args[1]; - tvm::Array required = args[2]; - tvm::Array disabled = args[3]; - TraceFunc trace_func = args[4]; - pctx->opt_level = opt_level; - pctx->fallback_device = fallback_device; - pctx->required_pass = std::move(required); - pctx->disabled_pass = std::move(disabled); - pctx->trace_func = std::move(trace_func); - *ret = pctx; -}); +TVM_REGISTER_GLOBAL("transform.PassContext") + .set_body_typed([](int opt_level, int fallback_device, Array required, + Array disabled, TraceFunc trace_func, + Optional> config) { + auto pctx = PassContext::Create(); + pctx->opt_level = opt_level; + pctx->fallback_device = fallback_device; + + pctx->required_pass = std::move(required); + pctx->disabled_pass = std::move(disabled); + pctx->trace_func = std::move(trace_func); + if (config.defined()) { + pctx->config = config.value(); + } + PassConfigManager::Global()->Legalize(&(pctx->config)); + return pctx; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -413,17 +479,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "\topt_level: " << node->opt_level << "\n"; p->stream << "\tfallback device: " << runtime::DeviceName(node->fallback_device) << "\n"; - p->stream << "\trequired passes: [" << node->opt_level; + p->stream << "\trequired passes: ["; for (const auto& it : node->required_pass) { p->stream << it << " "; } p->stream << "]\n"; - p->stream << "\tdisabled passes: [" << node->opt_level; + p->stream << "\tdisabled passes: ["; for (const auto& it : node->disabled_pass) { p->stream << it << " "; } - p->stream << "]"; + p->stream << "]\n"; + p->stream << "\tconfig: " << node->config; }); class PassContext::Internal { diff --git a/src/node/reflection.cc b/src/node/reflection.cc index c3397e7500c17..afe795f75c0f0 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -195,14 +195,13 @@ class NodeAttrSetter : public AttrVisitor { } }; -void InitNodeByPackedArgs(Object* n, const TVMArgs& args) { +void InitNodeByPackedArgs(ReflectionVTable* reflection, Object* n, const TVMArgs& args) { NodeAttrSetter setter; setter.type_key = n->GetTypeKey(); CHECK_EQ(args.size() % 2, 0); for (int i = 0; i < args.size(); i += 2) { setter.attrs.emplace(args[i].operator std::string(), args[i + 1]); } - auto* reflection = ReflectionVTable::Global(); reflection->VisitAttrs(n, &setter); if (setter.attrs.size() != 0) { @@ -215,6 +214,35 @@ void InitNodeByPackedArgs(Object* n, const TVMArgs& args) { } } +ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, const TVMArgs& kwargs) { + ObjectPtr n = this->CreateInitObject(type_key); + if (n->IsInstance()) { + static_cast(n.get())->InitByPackedArgs(kwargs); + } else { + InitNodeByPackedArgs(this, n.get(), kwargs); + } + return ObjectRef(n); +} + +ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, + const Map& kwargs) { + // Redirect to the TVMArgs version + // It is not the most efficient way, but CreateObject is not meant to be used + // in a fast code-path and is mainly reserved as a flexible API for frontends. + std::vector values(kwargs.size() * 2); + std::vector tcodes(kwargs.size() * 2); + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + int index = 0; + + for (auto& kv : static_cast(kwargs.get())->data) { + setter(index, kv.first); + setter(index + 1, kv.second); + index += 2; + } + + return CreateObject(type_key, runtime::TVMArgs(values.data(), tcodes.data(), kwargs.size() * 2)); +} + // Expose to FFI APIs. void NodeGetAttr(TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args[0].type_code(), kTVMObjectHandle); @@ -246,14 +274,7 @@ void MakeNode(const TVMArgs& args, TVMRetValue* rv) { std::string type_key = args[0]; std::string empty_str; TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1); - auto* reflection = ReflectionVTable::Global(); - ObjectPtr n = reflection->CreateInitObject(type_key); - if (n->IsInstance()) { - static_cast(n.get())->InitByPackedArgs(kwargs); - } else { - InitNodeByPackedArgs(n.get(), kwargs); - } - *rv = ObjectRef(n); + *rv = ReflectionVTable::Global()->CreateObject(type_key, kwargs); } TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body(NodeGetAttr); diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h index d56046cfde3ce..35642fbb731b4 100644 --- a/src/runtime/object_internal.h +++ b/src/runtime/object_internal.h @@ -46,6 +46,15 @@ class ObjectInternal { static_cast(obj)->DecRef(); } } + /*! + * \brief Check of obj derives from the type indicated by type index. + * \param obj The original object. + * \param type_index The type index of interest. + * \return The derivation checking result. + */ + static bool DerivedFrom(const Object* obj, uint32_t type_index) { + return obj->DerivedFrom(type_index); + } /*! * \brief Expose TypeKey2Index * \param type_key The original type key. diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index a69ccc59adf31..8378a88d80fe0 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -39,6 +39,31 @@ namespace tvm { namespace tir { +struct LoopUnrollConfig : public tvm::AttrsNode { + int auto_max_step; + int auto_max_depth; + int auto_max_extent; + int explicit_unroll; + + TVM_DECLARE_ATTRS(LoopUnrollConfig, "tir.transform.LoopUnrollConfig") { + TVM_ATTR_FIELD(auto_max_step) + .describe("Threshold of number of steps in the loop to be automatically unrolled") + .set_default(0); + TVM_ATTR_FIELD(auto_max_depth) + .describe("The maximum nested level of loops that can be automatically unrolled.") + .set_default(8); + TVM_ATTR_FIELD(auto_max_extent) + .describe("The maximum extent of loop that will be unrolled.") + .set_default(0); + TVM_ATTR_FIELD(explicit_unroll) + .describe("Whether to explicitly unroll the loop instead of setting a pragma") + .set_default(true); + } +}; + +TVM_REGISTER_NODE_TYPE(LoopUnrollConfig); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", LoopUnrollConfig); + class LoopUnroller : public StmtExprMutator { public: explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent, diff --git a/tests/python/unittest/test_node_reflection.py b/tests/python/unittest/test_node_reflection.py index 975192293d879..b10951691715d 100644 --- a/tests/python/unittest/test_node_reflection.py +++ b/tests/python/unittest/test_node_reflection.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm +import pytest from tvm import te def test_const_saveload_json(): @@ -101,6 +102,34 @@ def test_string(): tvm.ir.assert_structural_equal(s1, s2) +def test_pass_config(): + cfg = tvm.transform.PassContext(opt_level=1, config={ + "tir.UnrollLoop": { + "auto_max_step": 10, + } + }) + cfg.opt_level == 1 + + assert cfg.config["tir.UnrollLoop"].auto_max_step == 10 + # default option + assert cfg.config["tir.UnrollLoop"].explicit_unroll == True + + # schema checking for specific config key + with pytest.raises(AttributeError): + cfg = tvm.transform.PassContext(config={ + "tir.UnrollLoop": { "invalid": 1 } + }) + + # schema check for un-registered config + with pytest.raises(AttributeError): + cfg = tvm.transform.PassContext(config={ "inavlid-opt": True }) + + # schema check for wrong type + with pytest.raises(AttributeError): + cfg = tvm.transform.PassContext(config={ + "tir.UnrollLoop": 1 + }) + if __name__ == "__main__": test_string() test_env_func() @@ -108,3 +137,4 @@ def test_string(): test_make_smap() test_const_saveload_json() test_make_sum() + test_pass_config()