From 72fa41173c500199688b1c61f786d5de5e9b2324 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 19 May 2020 15:54:46 -0700 Subject: [PATCH] [NODE][PASS] Introduce config to PassContext. This PR introduces a new config field to the PassContext to allow it store arbitary config values. To make sure that the config is validated, we allow each pass to register the config key they would expect and the corresponding types. We also introduce a CreateObject from Map to allow config creation from a json-nest(like in vscode) in python. We added an example of UnrollLoopConfig. Followup PR should migrate the passes to use the new config field. --- include/tvm/ir/attrs.h | 2 +- include/tvm/ir/transform.h | 73 +++++++++++- include/tvm/node/reflection.h | 17 +++ python/tvm/ir/transform.py | 8 +- src/ir/transform.cc | 107 ++++++++++++++---- src/node/reflection.cc | 41 +++++-- src/runtime/object_internal.h | 9 ++ src/tir/transforms/unroll_loop.cc | 25 ++++ tests/python/unittest/test_node_reflection.py | 30 +++++ 9 files changed, 274 insertions(+), 38 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 819aafa0281cd..4e2e18375e224 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -97,7 +97,7 @@ struct AttrError : public dmlc::Error { * \brief constructor * \param msg error message */ - explicit AttrError(const std::string& msg) : dmlc::Error(msg) {} + explicit AttrError(std::string msg) : dmlc::Error("AttributeError:" + msg) {} }; /*! diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index a825b95294e97..dc29b821a5d7f 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -71,8 +71,8 @@ namespace transform { // Forward declare for TraceFunc. class PassInfo; -/*! \brief A callback for tracing passes, useful for debugging and logging. - * +/*! + * \brief A callback for tracing passes, useful for debugging and logging. */ using TraceFunc = runtime::TypedPackedFunc; @@ -96,19 +96,53 @@ class PassContextNode : public Object { int fallback_device{static_cast(kDLCPU)}; /*! \brief The list of required passes. */ - Array required_pass; + Array required_pass; /*! \brief The list of disabled passes. */ - Array disabled_pass; - + Array disabled_pass; + /*! \brief Trace function to be invoked before and after each pass. */ TraceFunc trace_func; + /*! \brief Pass specific configurations. */ + Map config; + PassContextNode() = default; + /*! + * \brief Get a config value from the pass context. + * + * \param key The config key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TOBjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef. + */ + template + Optional GetConfig(const std::string& key, Optional default_value = + Optional(nullptr)) const { + static_assert(std::is_base_of::value, + "Can only call GetAttr with ObjectRef types."); + if (!config.defined()) return default_value; + auto it = config.find(key); + if (it != config.end()) { + return Downcast>((*it).second); + } else { + return default_value; + } + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetConfig(const std::string& key, TObjectRef default_value) const { + return GetConfig(key, Optional(default_value)); + } + void VisitAttrs(AttrVisitor* v) { v->Visit("opt_level", &opt_level); v->Visit("fallback_device", &fallback_device); v->Visit("required_pass", &required_pass); v->Visit("disabled_pass", &disabled_pass); + v->Visit("config", &config); } static constexpr const char* _type_key = "transform.PassContext"; @@ -150,6 +184,7 @@ class PassContext : public ObjectRef { CHECK(get() != nullptr); return static_cast(get_mutable()); } + /*! * \brief Construct a PassContext containing the default configurations. * \return The new PassContext. @@ -169,6 +204,20 @@ class PassContext : public ObjectRef { */ TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const; + /*! + * \brief Register a valid configuration option and its ValueType for validation. + * + * \param key The configuration key. + * \tparam ValueNodeType The value type to be registered + */ + template + static uint32_t RegisterConfigOption(const char* key) { + // NOTE: we could further update the function later. + uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); + RegisterConfigOption(key, tindex); + return tindex; + } + // accessor. using ContainerType = PassContextNode; class Internal; @@ -178,12 +227,26 @@ class PassContext : public ObjectRef { TVM_DLL void EnterWithScope(); // The exit of a pass context scope. TVM_DLL void ExitWithScope(); + // Register configuration key value type. + TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index); // Classes to get the Python `with` like syntax. friend class Internal; friend class With; }; +#define TVM_PASS_CTX_CONFIG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_PassContext_tid + +/*! + * \brief Helper macro to register the object type to runtime. + * Makes sure that the runtime type table is correctly populated. + * + * Use this macro in the cc file for each terminal class. + */ +#define TVM_REGISTER_PASS_CONFIG_OPTION(Key, ValueType) \ + TVM_STR_CONCAT(TVM_PASS_CTX_CONFIG_VAR_DEF, __COUNTER__) = \ + ::tvm::transform::PassContext::RegisterConfigOption(Key) + /*! * \brief Meta data that will be used to help optimization and analysis. * \sa PassInfo diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 59e31897f967a..e97c5fd48bba7 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -146,6 +146,23 @@ class ReflectionVTable { */ TVM_DLL ObjectPtr CreateInitObject(const std::string& type_key, const std::string& repr_bytes = "") const; + /*! + * \brief Create an object by giving kwargs about its fields. + * + * \param type_key The type key. + * \param kwargs the arguments in format key1, value1, ..., key_n, value_n. + * \return The created object. + */ + TVM_DLL ObjectRef CreateObject(const std::string& type_key, const runtime::TVMArgs& kwargs); + /*! + * \brief Create an object by giving kwargs about its fields. + * + * \param type_key The type key. + * \param kwargs The field arguments. + * \return The created object. + */ + TVM_DLL ObjectRef CreateObject(const std::string& type_key, + const Map& kwargs); /*! * \brief Get an field object by the attr name. * \param self The pointer to the object. diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index af0be45b624e2..eb57e3404c3b9 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -70,13 +70,17 @@ class PassContext(tvm.runtime.Object): disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] The list of passes that are disabled. + + config : Optional[Dict[str, Object]] + Additional configurations for specific passes. """ def __init__(self, opt_level=2, fallback_device=_nd.cpu(), required_pass=None, disabled_pass=None, - trace=None): + trace=None, + config=None): if isinstance(fallback_device, str): fallback_device = _nd.context(fallback_device).device_type elif isinstance(fallback_device, tvm.runtime.TVMContext): @@ -97,7 +101,7 @@ def __init__(self, self.__init_handle_by_constructor__(_ffi_transform_api.PassContext, opt_level, fallback_device, required, - disabled, trace) + disabled, trace, config) def __enter__(self): _ffi_transform_api.EnterPassContext(self) diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 59e0c1c852764..1dbad1ac7e165 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -28,12 +28,11 @@ #include #include -// TODO(tqchen): Update to use String container after it is merged. -#include - #include #include +#include "../runtime/object_internal.h" + namespace tvm { namespace transform { @@ -75,6 +74,70 @@ PassContext PassContext::Current() { } } +class PassConfigManager { + public: + void Register(std::string key, uint32_t value_type_index) { + CHECK_EQ(key2vtype_.count(key), 0U); + ValueTypeInfo info; + info.type_index = value_type_index; + info.type_key = runtime::Object::TypeIndex2Key(value_type_index); + key2vtype_[key] = info; + } + + // Trying to validate and legalize a config. + void Legalize(Map* config) { + std::vector> update; + auto* reflection = ReflectionVTable::Global(); + + for (auto kv : *config) { + auto it = key2vtype_.find(kv.first); + if (it == key2vtype_.end()) { + std::ostringstream os; + os << "AttributeError: Invalid config option \'" << kv.first << "\' candidates are:"; + int counter = 0; + for (const auto& kv : key2vtype_) { + os << ' '; + if (counter++ != 0) os << ','; + os << kv.first; + } + LOG(FATAL) << os.str(); + } + const auto& info = it->second; + CHECK(kv.second.defined()) << "AttributeError: " << kv.first << " is None"; + if (kv.second->IsInstance::ContainerType>()) { + ObjectRef converted = reflection->CreateObject( + info.type_key, Downcast>(kv.second)); + update.emplace_back(kv.first, converted); + } else { + if (!runtime::ObjectInternal::DerivedFrom(kv.second.get(), info.type_index)) { + LOG(FATAL) << "AttributeError: expect config " << kv.first << " to have type " + << info.type_key << " but get " << kv.second->GetTypeKey(); + } + } + } + for (auto&& kv : update) { + config->Set(kv.first, kv.second); + } + } + + static PassConfigManager* Global() { + static auto* inst = new PassConfigManager(); + return inst; + } + + private: + struct ValueTypeInfo { + std::string type_key; + uint32_t type_index; + }; + + std::unordered_map key2vtype_; +}; + +void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_index) { + PassConfigManager::Global()->Register(key, value_type_index); +} + PassContext PassContext::Create() { return PassContext(make_object()); } void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { @@ -390,20 +453,23 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(PassContextNode); -TVM_REGISTER_GLOBAL("transform.PassContext").set_body([](TVMArgs args, TVMRetValue* ret) { - auto pctx = PassContext::Create(); - int opt_level = args[0]; - int fallback_device = args[1]; - tvm::Array required = args[2]; - tvm::Array disabled = args[3]; - TraceFunc trace_func = args[4]; - pctx->opt_level = opt_level; - pctx->fallback_device = fallback_device; - pctx->required_pass = std::move(required); - pctx->disabled_pass = std::move(disabled); - pctx->trace_func = std::move(trace_func); - *ret = pctx; -}); +TVM_REGISTER_GLOBAL("transform.PassContext") + .set_body_typed([](int opt_level, int fallback_device, Array required, + Array disabled, TraceFunc trace_func, + Optional> config) { + auto pctx = PassContext::Create(); + pctx->opt_level = opt_level; + pctx->fallback_device = fallback_device; + + pctx->required_pass = std::move(required); + pctx->disabled_pass = std::move(disabled); + pctx->trace_func = std::move(trace_func); + if (config.defined()) { + pctx->config = config.value(); + } + PassConfigManager::Global()->Legalize(&(pctx->config)); + return pctx; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -413,17 +479,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "\topt_level: " << node->opt_level << "\n"; p->stream << "\tfallback device: " << runtime::DeviceName(node->fallback_device) << "\n"; - p->stream << "\trequired passes: [" << node->opt_level; + p->stream << "\trequired passes: ["; for (const auto& it : node->required_pass) { p->stream << it << " "; } p->stream << "]\n"; - p->stream << "\tdisabled passes: [" << node->opt_level; + p->stream << "\tdisabled passes: ["; for (const auto& it : node->disabled_pass) { p->stream << it << " "; } - p->stream << "]"; + p->stream << "]\n"; + p->stream << "\tconfig: " << node->config; }); class PassContext::Internal { diff --git a/src/node/reflection.cc b/src/node/reflection.cc index c3397e7500c17..afe795f75c0f0 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -195,14 +195,13 @@ class NodeAttrSetter : public AttrVisitor { } }; -void InitNodeByPackedArgs(Object* n, const TVMArgs& args) { +void InitNodeByPackedArgs(ReflectionVTable* reflection, Object* n, const TVMArgs& args) { NodeAttrSetter setter; setter.type_key = n->GetTypeKey(); CHECK_EQ(args.size() % 2, 0); for (int i = 0; i < args.size(); i += 2) { setter.attrs.emplace(args[i].operator std::string(), args[i + 1]); } - auto* reflection = ReflectionVTable::Global(); reflection->VisitAttrs(n, &setter); if (setter.attrs.size() != 0) { @@ -215,6 +214,35 @@ void InitNodeByPackedArgs(Object* n, const TVMArgs& args) { } } +ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, const TVMArgs& kwargs) { + ObjectPtr n = this->CreateInitObject(type_key); + if (n->IsInstance()) { + static_cast(n.get())->InitByPackedArgs(kwargs); + } else { + InitNodeByPackedArgs(this, n.get(), kwargs); + } + return ObjectRef(n); +} + +ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, + const Map& kwargs) { + // Redirect to the TVMArgs version + // It is not the most efficient way, but CreateObject is not meant to be used + // in a fast code-path and is mainly reserved as a flexible API for frontends. + std::vector values(kwargs.size() * 2); + std::vector tcodes(kwargs.size() * 2); + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + int index = 0; + + for (auto& kv : static_cast(kwargs.get())->data) { + setter(index, kv.first); + setter(index + 1, kv.second); + index += 2; + } + + return CreateObject(type_key, runtime::TVMArgs(values.data(), tcodes.data(), kwargs.size() * 2)); +} + // Expose to FFI APIs. void NodeGetAttr(TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args[0].type_code(), kTVMObjectHandle); @@ -246,14 +274,7 @@ void MakeNode(const TVMArgs& args, TVMRetValue* rv) { std::string type_key = args[0]; std::string empty_str; TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1); - auto* reflection = ReflectionVTable::Global(); - ObjectPtr n = reflection->CreateInitObject(type_key); - if (n->IsInstance()) { - static_cast(n.get())->InitByPackedArgs(kwargs); - } else { - InitNodeByPackedArgs(n.get(), kwargs); - } - *rv = ObjectRef(n); + *rv = ReflectionVTable::Global()->CreateObject(type_key, kwargs); } TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body(NodeGetAttr); diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h index d56046cfde3ce..35642fbb731b4 100644 --- a/src/runtime/object_internal.h +++ b/src/runtime/object_internal.h @@ -46,6 +46,15 @@ class ObjectInternal { static_cast(obj)->DecRef(); } } + /*! + * \brief Check of obj derives from the type indicated by type index. + * \param obj The original object. + * \param type_index The type index of interest. + * \return The derivation checking result. + */ + static bool DerivedFrom(const Object* obj, uint32_t type_index) { + return obj->DerivedFrom(type_index); + } /*! * \brief Expose TypeKey2Index * \param type_key The original type key. diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index a69ccc59adf31..8378a88d80fe0 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -39,6 +39,31 @@ namespace tvm { namespace tir { +struct LoopUnrollConfig : public tvm::AttrsNode { + int auto_max_step; + int auto_max_depth; + int auto_max_extent; + int explicit_unroll; + + TVM_DECLARE_ATTRS(LoopUnrollConfig, "tir.transform.LoopUnrollConfig") { + TVM_ATTR_FIELD(auto_max_step) + .describe("Threshold of number of steps in the loop to be automatically unrolled") + .set_default(0); + TVM_ATTR_FIELD(auto_max_depth) + .describe("The maximum nested level of loops that can be automatically unrolled.") + .set_default(8); + TVM_ATTR_FIELD(auto_max_extent) + .describe("The maximum extent of loop that will be unrolled.") + .set_default(0); + TVM_ATTR_FIELD(explicit_unroll) + .describe("Whether to explicitly unroll the loop instead of setting a pragma") + .set_default(true); + } +}; + +TVM_REGISTER_NODE_TYPE(LoopUnrollConfig); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", LoopUnrollConfig); + class LoopUnroller : public StmtExprMutator { public: explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent, diff --git a/tests/python/unittest/test_node_reflection.py b/tests/python/unittest/test_node_reflection.py index 975192293d879..b10951691715d 100644 --- a/tests/python/unittest/test_node_reflection.py +++ b/tests/python/unittest/test_node_reflection.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm +import pytest from tvm import te def test_const_saveload_json(): @@ -101,6 +102,34 @@ def test_string(): tvm.ir.assert_structural_equal(s1, s2) +def test_pass_config(): + cfg = tvm.transform.PassContext(opt_level=1, config={ + "tir.UnrollLoop": { + "auto_max_step": 10, + } + }) + cfg.opt_level == 1 + + assert cfg.config["tir.UnrollLoop"].auto_max_step == 10 + # default option + assert cfg.config["tir.UnrollLoop"].explicit_unroll == True + + # schema checking for specific config key + with pytest.raises(AttributeError): + cfg = tvm.transform.PassContext(config={ + "tir.UnrollLoop": { "invalid": 1 } + }) + + # schema check for un-registered config + with pytest.raises(AttributeError): + cfg = tvm.transform.PassContext(config={ "inavlid-opt": True }) + + # schema check for wrong type + with pytest.raises(AttributeError): + cfg = tvm.transform.PassContext(config={ + "tir.UnrollLoop": 1 + }) + if __name__ == "__main__": test_string() test_env_func() @@ -108,3 +137,4 @@ def test_string(): test_make_smap() test_const_saveload_json() test_make_sum() + test_pass_config()