From cbae698f2d5ace8da796c5ac75dd53b1b7d120e2 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 14 Aug 2020 18:33:37 -0700 Subject: [PATCH] [Target] Creating Target from JSON-like Configuration (#6218) * [Target] Creating Target from JSON-like Configuration * Address comments from Cody * fix unittest * More testcases as suggested by @comaniac --- include/tvm/target/target.h | 46 +++- include/tvm/target/target_kind.h | 20 +- src/target/target.cc | 413 +++++++++++++++++++++++++++---- src/target/target_kind.cc | 289 --------------------- tests/cpp/target_test.cc | 131 +++++++--- 5 files changed, 516 insertions(+), 383 deletions(-) diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 4a83579b81be..258b2d83ee72 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -31,6 +31,7 @@ #include #include +#include #include #include #include @@ -62,6 +63,13 @@ class TargetNode : public Object { v->Visit("attrs", &attrs); } + /*! + * \brief Get an entry from attrs of the target + * \tparam TObjectRef Type of the attribute + * \param attr_key The name of the attribute key + * \param default_value The value returned if the key is not present + * \return An optional, NullOpt if not found, otherwise the value found + */ template Optional GetAttr( const std::string& attr_key, @@ -75,15 +83,19 @@ class TargetNode : public Object { return default_value; } } - + /*! + * \brief Get an entry from attrs of the target + * \tparam TObjectRef Type of the attribute + * \param attr_key The name of the attribute key + * \param default_value The value returned if the key is not present + * \return An optional, NullOpt if not found, otherwise the value found + */ template Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { return GetAttr(attr_key, Optional(default_value)); } - /*! \brief Get the keys for this target as a vector of string */ TVM_DLL std::vector GetKeys() const; - /*! \brief Get the keys for this target as an unordered_set of string */ TVM_DLL std::unordered_set GetLibs() const; @@ -93,6 +105,26 @@ class TargetNode : public Object { private: /*! \brief Internal string repr. */ mutable std::string str_repr_; + /*! + * \brief Parsing TargetNode::attrs from a list of raw strings + * \param obj The attribute to be parsed + * \param info The runtime type information for parsing + * \return The attribute parsed + */ + ObjectRef ParseAttr(const ObjectRef& obj, const TargetKindNode::ValueTypeInfo& info) const; + /*! + * \brief Parsing TargetNode::attrs from a list of raw strings + * \param options The raw string of fields to be parsed + * \return The attributes parsed + */ + Map ParseAttrsFromRaw(const std::vector& options) const; + /*! + * \brief Serialize the attributes of a target to raw string + * \param attrs The attributes to be converted to string + * \return The string converted, NullOpt if attrs is empty + */ + Optional StringifyAttrsToRaw(const Map& attrs) const; + friend class Target; }; @@ -103,10 +135,18 @@ class TargetNode : public Object { class Target : public ObjectRef { public: Target() {} + /*! \brief Constructor from ObjectPtr */ explicit Target(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief Create a Target using a JSON-like configuration + * \param config The JSON-like configuration + * \return The target created + */ + TVM_DLL static Target FromConfig(const Map& config); /*! * \brief Create a Target given a string * \param target_str the string to parse + * \return The target created */ TVM_DLL static Target Create(const String& target_str); /*! diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index a661efad58f0..e4e7c2fa8a4d 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -45,9 +45,6 @@ struct ValueTypeInfoMaker; class Target; -/*! \brief Perform schema validation */ -TVM_DLL void TargetValidateSchema(const Map& config); - template class TargetKindAttrMap; @@ -67,14 +64,14 @@ class TargetKindNode : public Object { v->Visit("default_keys", &default_keys); } - Map ParseAttrsFromRaw(const std::vector& options) const; - - Optional StringifyAttrsToRaw(const Map& attrs) const; - static constexpr const char* _type_key = "TargetKind"; TVM_DECLARE_FINAL_OBJECT_INFO(TargetKindNode, Object); private: + /*! \brief Return the index stored in attr registry */ + uint32_t AttrRegistryIndex() const { return index_; } + /*! \brief Return the name stored in attr registry */ + String AttrRegistryName() const { return name; } /*! \brief Stores the required type_key and type_index of a specific attr of a target */ struct ValueTypeInfo { String type_key; @@ -82,21 +79,14 @@ class TargetKindNode : public Object { std::unique_ptr key; std::unique_ptr val; }; - - uint32_t AttrRegistryIndex() const { return index_; } - String AttrRegistryName() const { return name; } - /*! \brief Perform schema validation */ - void ValidateSchema(const Map& config) const; - /*! \brief Verify if the obj is consistent with the type info */ - void VerifyTypeInfo(const ObjectRef& obj, const TargetKindNode::ValueTypeInfo& info) const; /*! \brief A hash table that stores the type information of each attr of the target key */ std::unordered_map key2vtype_; /*! \brief A hash table that stores the default value of each attr of the target key */ std::unordered_map key2default_; /*! \brief Index used for internal lookup of attribute registry */ uint32_t index_; - friend void TargetValidateSchema(const Map&); friend class Target; + friend class TargetNode; friend class TargetKind; template friend class AttrRegistry; diff --git a/src/target/target.cc b/src/target/target.cc index 94b5b035de8c..6a245973315e 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -30,12 +30,201 @@ #include #include +#include "../runtime/object_internal.h" + namespace tvm { using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; +TVM_REGISTER_NODE_TYPE(TargetNode); + +static std::vector DeduplicateKeys(const std::vector& keys) { + std::vector new_keys; + for (size_t i = 0; i < keys.size(); ++i) { + bool found = false; + for (size_t j = 0; j < i; ++j) { + if (keys[i] == keys[j]) { + found = true; + break; + } + } + if (!found) { + new_keys.push_back(keys[i]); + } + } + return new_keys; +} + +static inline std::string RemovePrefixDashes(const std::string& s) { + size_t n_dashes = 0; + for (; n_dashes < s.length() && s[n_dashes] == '-'; ++n_dashes) { + } + CHECK(0 < n_dashes && n_dashes < s.size()) << "ValueError: Not an attribute key \"" << s << "\""; + return s.substr(n_dashes); +} + +static inline int FindUniqueSubstr(const std::string& str, const std::string& substr) { + size_t pos = str.find_first_of(substr); + if (pos == std::string::npos) { + return -1; + } + size_t next_pos = pos + substr.size(); + CHECK(next_pos >= str.size() || str.find_first_of(substr, next_pos) == std::string::npos) + << "ValueError: At most one \"" << substr << "\" is allowed in " + << "the the given string \"" << str << "\""; + return pos; +} + +static inline ObjectRef ParseAtomicType(uint32_t type_index, const std::string& str) { + std::istringstream is(str); + if (type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + int v; + is >> v; + return is.fail() ? ObjectRef(nullptr) : Integer(v); + } else if (type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + std::string v; + is >> v; + return is.fail() ? ObjectRef(nullptr) : String(v); + } + return ObjectRef(nullptr); +} + +Map TargetNode::ParseAttrsFromRaw( + const std::vector& options) const { + std::unordered_map attrs; + for (size_t iter = 0, end = options.size(); iter < end;) { + // remove the prefix dashes + std::string s = RemovePrefixDashes(options[iter++]); + // parse name-obj pair + std::string name; + std::string obj; + int pos; + if ((pos = FindUniqueSubstr(s, "=")) != -1) { + // case 1. --key=value + name = s.substr(0, pos); + obj = s.substr(pos + 1); + CHECK(!name.empty()) << "ValueError: Empty attribute key in \"" << options[iter - 1] << "\""; + CHECK(!obj.empty()) << "ValueError: Empty attribute in \"" << options[iter - 1] << "\""; + } else if (iter < end && options[iter][0] != '-') { + // case 2. --key value + name = s; + obj = options[iter++]; + } else { + // case 3. --boolean-key + name = s; + obj = "1"; + } + // check if `name` is invalid + auto it = this->kind->key2vtype_.find(name); + if (it == this->kind->key2vtype_.end()) { + std::ostringstream os; + os << "AttributeError: Invalid config option, cannot recognize \'" << name + << "\'. Candidates are:"; + for (const auto& kv : this->kind->key2vtype_) { + os << "\n " << kv.first; + } + LOG(FATAL) << os.str(); + } + // check if `name` has been set once + CHECK(!attrs.count(name)) << "AttributeError: key \"" << name + << "\" appears more than once in the target string"; + // then `name` is valid, let's parse them + // only several types are supported when parsing raw string + const auto& info = it->second; + ObjectRef parsed_obj(nullptr); + if (info.type_index != ArrayNode::_type_index) { + parsed_obj = ParseAtomicType(info.type_index, obj); + } else { + Array array; + std::string item; + bool failed = false; + uint32_t type_index = info.key->type_index; + for (std::istringstream is(obj); std::getline(is, item, ',');) { + ObjectRef parsed_obj = ParseAtomicType(type_index, item); + if (parsed_obj.defined()) { + array.push_back(parsed_obj); + } else { + failed = true; + break; + } + } + if (!failed) { + parsed_obj = std::move(array); + } + } + if (!parsed_obj.defined()) { + LOG(FATAL) << "ValueError: Cannot parse type \"" << info.type_key << "\"" + << ", where attribute key is \"" << name << "\"" + << ", and attribute is \"" << obj << "\""; + } + attrs[name] = std::move(parsed_obj); + } + // set default attribute values if they do not exist + for (const auto& kv : this->kind->key2default_) { + if (!attrs.count(kv.first)) { + attrs[kv.first] = kv.second; + } + } + return attrs; +} + +static inline Optional StringifyAtomicType(const ObjectRef& obj) { + if (const auto* p = obj.as()) { + return String(std::to_string(p->value)); + } + if (const auto* p = obj.as()) { + return GetRef(p); + } + return NullOpt; +} + +static inline Optional JoinString(const std::vector& array, char separator) { + if (array.empty()) { + return NullOpt; + } + std::ostringstream os; + os << array[0]; + for (size_t i = 1; i < array.size(); ++i) { + os << separator << array[i]; + } + return String(os.str()); +} + +Optional TargetNode::StringifyAttrsToRaw(const Map& attrs) const { + std::ostringstream os; + std::vector keys; + for (const auto& kv : attrs) { + keys.push_back(kv.first); + } + std::sort(keys.begin(), keys.end()); + std::vector result; + for (const auto& key : keys) { + const ObjectRef& obj = attrs[key]; + Optional value = NullOpt; + if (const auto* array = obj.as()) { + std::vector items; + for (const ObjectRef& item : *array) { + Optional str = StringifyAtomicType(item); + if (str.defined()) { + items.push_back(str.value()); + } else { + items.clear(); + break; + } + } + value = JoinString(items, ','); + } else { + value = StringifyAtomicType(obj); + } + if (value.defined()) { + result.push_back("-" + key + "=" + value.value()); + } + } + return JoinString(result, ' '); +} + Target Target::CreateTarget(const std::string& name, const std::vector& options) { TargetKind kind = TargetKind::Get(name); ObjectPtr target = make_object(); @@ -43,7 +232,7 @@ Target Target::CreateTarget(const std::string& name, const std::vectortag = ""; // parse attrs - target->attrs = kind->ParseAttrsFromRaw(options); + target->attrs = target->ParseAttrsFromRaw(options); String device_name = target->GetAttr("device", "").value(); // set up keys { @@ -62,48 +251,11 @@ Target Target::CreateTarget(const std::string& name, const std::vectorkeys = std::move(keys); + target->keys = DeduplicateKeys(keys); } return Target(target); } -TVM_REGISTER_NODE_TYPE(TargetNode); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->str(); - }); - -TVM_REGISTER_GLOBAL("target.TargetCreate").set_body([](TVMArgs args, TVMRetValue* ret) { - std::string name = args[0]; - std::vector options; - for (int i = 1; i < args.num_args; ++i) { - std::string arg = args[i]; - options.push_back(arg); - } - - *ret = Target::CreateTarget(name, options); -}); - -TVM_REGISTER_GLOBAL("target.TargetFromString").set_body([](TVMArgs args, TVMRetValue* ret) { - std::string target_str = args[0]; - *ret = Target::Create(target_str); -}); - std::vector TargetNode::GetKeys() const { std::vector result; for (auto& expr : keys) { @@ -140,7 +292,7 @@ const std::string& TargetNode::str() const { os << s; } } - if (Optional attrs_str = kind->StringifyAttrsToRaw(attrs)) { + if (Optional attrs_str = this->StringifyAttrsToRaw(attrs)) { os << ' ' << attrs_str.value(); } str_repr_ = os.str(); @@ -162,6 +314,160 @@ Target Target::Create(const String& target_str) { return CreateTarget(splits[0], {splits.begin() + 1, splits.end()}); } +ObjectRef TargetNode::ParseAttr(const ObjectRef& obj, + const TargetKindNode::ValueTypeInfo& info) const { + if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + const auto* v = obj.as(); + CHECK(v != nullptr) << "Expect type 'int', but get: " << obj->GetTypeKey(); + return GetRef(v); + } + if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + const auto* v = obj.as(); + CHECK(v != nullptr) << "Expect type 'str', but get: " << obj->GetTypeKey(); + return GetRef(v); + } + if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + CHECK(obj->IsInstance()) + << "Expect type 'dict' to construct Target, but get: " << obj->GetTypeKey(); + return Target::FromConfig(Downcast>(obj)); + } + if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) { + CHECK(obj->IsInstance()) << "Expect type 'list', but get: " << obj->GetTypeKey(); + Array array = Downcast>(obj); + std::vector result; + int i = 0; + for (const ObjectRef& e : array) { + ++i; + try { + result.push_back(TargetNode::ParseAttr(e, *info.key)); + } catch (const dmlc::Error& e) { + LOG(FATAL) << "Error occurred when parsing element " << i << " of the array: " << array + << ". Details:\n" + << e.what(); + } + } + return Array(result); + } + if (info.type_index == MapNode::_GetOrAllocRuntimeTypeIndex()) { + CHECK(obj->IsInstance()) << "Expect type 'dict', but get: " << obj->GetTypeKey(); + std::unordered_map result; + for (const auto& kv : Downcast>(obj)) { + ObjectRef key, val; + try { + key = TargetNode::ParseAttr(kv.first, *info.key); + } catch (const tvm::Error& e) { + LOG(FATAL) << "Error occurred when parsing a key of the dict: " << kv.first + << ". Details:\n" + << e.what(); + } + try { + val = TargetNode::ParseAttr(kv.second, *info.val); + } catch (const tvm::Error& e) { + LOG(FATAL) << "Error occurred when parsing a value of the dict: " << kv.second + << ". Details:\n" + << e.what(); + } + result[key] = val; + } + return Map(result); + } + LOG(FATAL) << "Unsupported type registered: \"" << info.type_key + << "\", and the type given is: " << obj->GetTypeKey(); + throw; +} + +Target Target::FromConfig(const Map& config_dict) { + const String kKind = "kind"; + const String kTag = "tag"; + const String kKeys = "keys"; + const String kDeviceName = "device"; + std::unordered_map config(config_dict.begin(), config_dict.end()); + ObjectPtr target = make_object(); + // parse 'kind' + if (config.count(kKind)) { + const auto* kind = config[kKind].as(); + CHECK(kind != nullptr) << "AttributeError: Expect type of field 'kind' is string, but get: " + << config[kKind]->GetTypeKey(); + target->kind = TargetKind::Get(GetRef(kind)); + config.erase(kKind); + } else { + LOG(FATAL) << "AttributeError: Field 'kind' is not found"; + } + // parse "tag" + if (config.count(kTag)) { + const auto* tag = config[kTag].as(); + CHECK(tag != nullptr) << "AttributeError: Expect type of field 'tag' is string, but get: " + << config[kTag]->GetTypeKey(); + target->tag = GetRef(tag); + config.erase(kTag); + } else { + target->tag = ""; + } + // parse "keys" + if (config.count(kKeys)) { + std::vector keys; + // user provided keys + const auto* cfg_keys = config[kKeys].as(); + CHECK(cfg_keys != nullptr) + << "AttributeError: Expect type of field 'keys' is an Array, but get: " + << config[kTag]->GetTypeKey(); + for (const ObjectRef& e : *cfg_keys) { + const auto* key = e.as(); + CHECK(key != nullptr) << "AttributeError: Expect 'keys' to be an array of strings, but it " + "contains an element of type: " + << e->GetTypeKey(); + keys.push_back(GetRef(key)); + } + // add device name + if (config_dict.count(kDeviceName)) { + if (const auto* device = config_dict.at(kDeviceName).as()) { + keys.push_back(GetRef(device)); + } + } + // add default keys + for (const auto& key : target->kind->default_keys) { + keys.push_back(key); + } + // de-duplicate keys + target->keys = DeduplicateKeys(keys); + config.erase(kKeys); + } else { + target->keys = {}; + } + // parse attrs + std::unordered_map attrs; + const auto& key2vtype = target->kind->key2vtype_; + for (const auto& cfg_kv : config) { + const String& name = cfg_kv.first; + const ObjectRef& obj = cfg_kv.second; + if (!key2vtype.count(name)) { + std::ostringstream os; + os << "AttributeError: Unrecognized config option: \"" << name << "\". Candidates are:"; + for (const auto& kv : key2vtype) { + os << " " << kv.first; + } + LOG(FATAL) << os.str(); + } + ObjectRef val; + try { + val = target->ParseAttr(obj, key2vtype.at(name)); + } catch (const dmlc::Error& e) { + LOG(FATAL) << "AttributeError: Error occurred in parsing the config key \"" << name + << "\". Details:\n" + << e.what(); + } + attrs[name] = val; + } + // set default attribute values if they do not exist + for (const auto& kv : target->kind->key2default_) { + if (!attrs.count(kv.first)) { + attrs[kv.first] = kv.second; + } + } + target->attrs = attrs; + return Target(target); +} + /*! \brief Entry to hold the Target context stack. */ struct TVMTargetThreadLocalEntry { /*! \brief The current target context */ @@ -169,7 +475,7 @@ struct TVMTargetThreadLocalEntry { }; /*! \brief Thread local store to hold the Target context stack. */ -typedef dmlc::ThreadLocalStore TVMTargetThreadLocalStore; +using TVMTargetThreadLocalStore = dmlc::ThreadLocalStore; void Target::EnterWithScope() { TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); @@ -194,20 +500,37 @@ tvm::Target Target::Current(bool allow_not_defined) { return Target(); } -TVM_REGISTER_GLOBAL("target.GetCurrentTarget").set_body([](TVMArgs args, TVMRetValue* ret) { - bool allow_not_defined = args[0]; - *ret = Target::Current(allow_not_defined); -}); class Target::Internal { public: static void EnterScope(Target target) { target.EnterWithScope(); } static void ExitScope(Target target) { target.ExitWithScope(); } }; +TVM_REGISTER_GLOBAL("target.TargetCreate").set_body([](TVMArgs args, TVMRetValue* ret) { + std::string name = args[0]; + std::vector options; + for (int i = 1; i < args.num_args; ++i) { + std::string arg = args[i]; + options.push_back(arg); + } + + *ret = Target::CreateTarget(name, options); +}); + TVM_REGISTER_GLOBAL("target.EnterTargetScope").set_body_typed(Target::Internal::EnterScope); TVM_REGISTER_GLOBAL("target.ExitTargetScope").set_body_typed(Target::Internal::ExitScope); +TVM_REGISTER_GLOBAL("target.GetCurrentTarget").set_body_typed(Target::Current); + +TVM_REGISTER_GLOBAL("target.TargetFromString").set_body_typed(Target::Create); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->str(); + }); + namespace target { std::vector MergeOptions(std::vector opts, const std::vector& new_opts) { diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index e6f7c5cdec13..3e35e5b3690d 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -26,7 +26,6 @@ #include #include "../node/attr_registry.h" -#include "../runtime/object_internal.h" namespace tvm { @@ -60,294 +59,6 @@ const TargetKind& TargetKind::Get(const String& target_kind_name) { return reg->kind_; } -void TargetKindNode::VerifyTypeInfo(const ObjectRef& obj, - const TargetKindNode::ValueTypeInfo& info) const { - CHECK(obj.defined()) << "Object is None"; - if (!runtime::ObjectInternal::DerivedFrom(obj.get(), info.type_index)) { - LOG(FATAL) << "AttributeError: expect type \"" << info.type_key << "\" but get " - << obj->GetTypeKey(); - throw; - } - if (info.type_index == ArrayNode::_type_index) { - int i = 0; - for (const auto& e : *obj.as()) { - try { - VerifyTypeInfo(e, *info.key); - } catch (const tvm::Error& e) { - LOG(FATAL) << "The i-th element of array failed type checking, where i = " << i - << ", and the error is:\n" - << e.what(); - throw; - } - ++i; - } - } else if (info.type_index == MapNode::_type_index) { - for (const auto& kv : *obj.as()) { - try { - VerifyTypeInfo(kv.first, *info.key); - } catch (const tvm::Error& e) { - LOG(FATAL) << "The key of map failed type checking, where key = \"" << kv.first - << "\", value = \"" << kv.second << "\", and the error is:\n" - << e.what(); - throw; - } - try { - VerifyTypeInfo(kv.second, *info.val); - } catch (const tvm::Error& e) { - LOG(FATAL) << "The value of map failed type checking, where key = \"" << kv.first - << "\", value = \"" << kv.second << "\", and the error is:\n" - << e.what(); - throw; - } - } - } -} - -void TargetKindNode::ValidateSchema(const Map& config) const { - const String kTargetKind = "kind"; - for (const auto& kv : config) { - const String& name = kv.first; - const ObjectRef& obj = kv.second; - if (name == kTargetKind) { - CHECK(obj->IsInstance()) - << "AttributeError: \"kind\" is not a string, but its type is \"" << obj->GetTypeKey() - << "\""; - CHECK(Downcast(obj) == this->name) - << "AttributeError: \"kind\" = \"" << obj << "\" is inconsistent with TargetKind \"" - << this->name << "\""; - continue; - } - auto it = key2vtype_.find(name); - if (it == key2vtype_.end()) { - std::ostringstream os; - os << "AttributeError: Invalid config option, cannot recognize \"" << name - << "\". Candidates are:"; - for (const auto& kv : key2vtype_) { - os << "\n " << kv.first; - } - LOG(FATAL) << os.str(); - throw; - } - const auto& info = it->second; - try { - VerifyTypeInfo(obj, info); - } catch (const tvm::Error& e) { - LOG(FATAL) << "AttributeError: Schema validation failed for TargetKind \"" << this->name - << "\", details:\n" - << e.what() << "\n" - << "The config is:\n" - << config; - throw; - } - } -} - -inline String GetKind(const Map& target, const char* name) { - const String kTargetKind = "kind"; - CHECK(target.count(kTargetKind)) - << "AttributeError: \"kind\" does not exist in \"" << name << "\"\n" - << name << " = " << target; - const ObjectRef& obj = target[kTargetKind]; - CHECK(obj->IsInstance()) << "AttributeError: \"kind\" is not a string in \"" << name - << "\", but its type is \"" << obj->GetTypeKey() << "\"\n" - << name << " = \"" << target << '"'; - return Downcast(obj); -} - -void TargetValidateSchema(const Map& config) { - try { - const String kTargetHost = "target_host"; - Map target = config; - Map target_host; - String target_kind = GetKind(target, "target"); - String target_host_kind; - if (config.count(kTargetHost)) { - target.erase(kTargetHost); - target_host = Downcast>(config[kTargetHost]); - target_host_kind = GetKind(target_host, "target_host"); - } - TargetKind::Get(target_kind)->ValidateSchema(target); - if (!target_host.empty()) { - TargetKind::Get(target_host_kind)->ValidateSchema(target_host); - } - } catch (const tvm::Error& e) { - LOG(FATAL) << "AttributeError: schedule validation fails:\n" - << e.what() << "\nThe configuration is:\n" - << config; - } -} - -static inline size_t CountNumPrefixDashes(const std::string& s) { - size_t i = 0; - for (; i < s.length() && s[i] == '-'; ++i) { - } - return i; -} - -static inline int FindUniqueSubstr(const std::string& str, const std::string& substr) { - size_t pos = str.find_first_of(substr); - if (pos == std::string::npos) { - return -1; - } - size_t next_pos = pos + substr.size(); - CHECK(next_pos >= str.size() || str.find_first_of(substr, next_pos) == std::string::npos) - << "ValueError: At most one \"" << substr << "\" is allowed in " - << "the the given string \"" << str << "\""; - return pos; -} - -static inline ObjectRef ParseScalar(uint32_t type_index, const std::string& str) { - std::istringstream is(str); - if (type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - int v; - is >> v; - return is.fail() ? ObjectRef(nullptr) : Integer(v); - } else if (type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - std::string v; - is >> v; - return is.fail() ? ObjectRef(nullptr) : String(v); - } - return ObjectRef(nullptr); -} - -static inline Optional StringifyScalar(const ObjectRef& obj) { - if (const auto* p = obj.as()) { - return String(std::to_string(p->value)); - } - if (const auto* p = obj.as()) { - return GetRef(p); - } - return NullOpt; -} - -static inline Optional Join(const std::vector& array, char separator) { - if (array.empty()) { - return NullOpt; - } - std::ostringstream os; - os << array[0]; - for (size_t i = 1; i < array.size(); ++i) { - os << separator << array[i]; - } - return String(os.str()); -} - -Map TargetKindNode::ParseAttrsFromRaw( - const std::vector& options) const { - std::unordered_map attrs; - for (size_t iter = 0, end = options.size(); iter < end;) { - std::string s = options[iter++]; - // remove the prefix dashes - size_t n_dashes = CountNumPrefixDashes(s); - CHECK(0 < n_dashes && n_dashes < s.size()) - << "ValueError: Not an attribute key \"" << s << "\""; - s = s.substr(n_dashes); - // parse name-obj pair - std::string name; - std::string obj; - int pos; - if ((pos = FindUniqueSubstr(s, "=")) != -1) { - // case 1. --key=value - name = s.substr(0, pos); - obj = s.substr(pos + 1); - CHECK(!name.empty()) << "ValueError: Empty attribute key in \"" << options[iter - 1] << "\""; - CHECK(!obj.empty()) << "ValueError: Empty attribute in \"" << options[iter - 1] << "\""; - } else if (iter < end && options[iter][0] != '-') { - // case 2. --key value - name = s; - obj = options[iter++]; - } else { - // case 3. --boolean-key - name = s; - obj = "1"; - } - // check if `name` is invalid - auto it = key2vtype_.find(name); - if (it == key2vtype_.end()) { - std::ostringstream os; - os << "AttributeError: Invalid config option, cannot recognize \'" << name - << "\'. Candidates are:"; - for (const auto& kv : key2vtype_) { - os << "\n " << kv.first; - } - LOG(FATAL) << os.str(); - } - // check if `name` has been set once - CHECK(!attrs.count(name)) << "AttributeError: key \"" << name - << "\" appears more than once in the target string"; - // then `name` is valid, let's parse them - // only several types are supported when parsing raw string - const auto& info = it->second; - ObjectRef parsed_obj(nullptr); - if (info.type_index != ArrayNode::_type_index) { - parsed_obj = ParseScalar(info.type_index, obj); - } else { - Array array; - std::string item; - bool failed = false; - uint32_t type_index = info.key->type_index; - for (std::istringstream is(obj); std::getline(is, item, ',');) { - ObjectRef parsed_obj = ParseScalar(type_index, item); - if (parsed_obj.defined()) { - array.push_back(parsed_obj); - } else { - failed = true; - break; - } - } - if (!failed) { - parsed_obj = std::move(array); - } - } - if (!parsed_obj.defined()) { - LOG(FATAL) << "ValueError: Cannot parse type \"" << info.type_key << "\"" - << ", where attribute key is \"" << name << "\"" - << ", and attribute is \"" << obj << "\""; - } - attrs[name] = std::move(parsed_obj); - } - // set default attribute values if they do not exist - for (const auto& kv : key2default_) { - if (!attrs.count(kv.first)) { - attrs[kv.first] = kv.second; - } - } - return attrs; -} - -Optional TargetKindNode::StringifyAttrsToRaw(const Map& attrs) const { - std::ostringstream os; - std::vector keys; - for (const auto& kv : attrs) { - keys.push_back(kv.first); - } - std::sort(keys.begin(), keys.end()); - std::vector result; - for (const auto& key : keys) { - const ObjectRef& obj = attrs[key]; - Optional value = NullOpt; - if (const auto* array = obj.as()) { - std::vector items; - for (const ObjectRef& item : *array) { - Optional str = StringifyScalar(item); - if (str.defined()) { - items.push_back(str.value()); - } else { - items.clear(); - break; - } - } - value = Join(items, ','); - } else { - value = StringifyScalar(obj); - } - if (value.defined()) { - result.push_back("-" + key + "=" + value.value()); - } - } - return Join(result, ' '); -} - // TODO(@junrushao1994): remove some redundant attributes TVM_REGISTER_TARGET_KIND("llvm") diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index 8bee7078d5cd..e8748e63295a 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -19,7 +19,7 @@ #include #include -#include +#include #include #include @@ -39,48 +39,117 @@ TEST(TargetKind, GetAttrMap) { CHECK_EQ(result, "Value1"); } -TEST(TargetKind, SchemaValidation) { - tvm::Map target; - { - tvm::Array your_names{"junru", "jian"}; - tvm::Map her_maps{ - {"a", 1}, - {"b", 2}, - }; - target.Set("my_bool", Bool(true)); - target.Set("your_names", your_names); - target.Set("her_maps", her_maps); - target.Set("kind", String("TestTargetKind")); +TEST(TargetCreation, NestedConfig) { + Map config = { + {"my_bool", Bool(true)}, + {"your_names", Array{"junru", "jian"}}, + {"kind", String("TestTargetKind")}, + { + "her_maps", + Map{ + {"a", 1}, + {"b", 2}, + }, + }, + }; + Target target = Target::FromConfig(config); + CHECK_EQ(target->kind, TargetKind::Get("TestTargetKind")); + CHECK_EQ(target->tag, ""); + CHECK(target->keys.empty()); + Bool my_bool = target->GetAttr("my_bool").value(); + CHECK_EQ(my_bool.operator bool(), true); + Array your_names = target->GetAttr>("your_names").value(); + CHECK_EQ(your_names.size(), 2U); + CHECK_EQ(your_names[0], "junru"); + CHECK_EQ(your_names[1], "jian"); + Map her_maps = target->GetAttr>("her_maps").value(); + CHECK_EQ(her_maps.size(), 2U); + CHECK_EQ(her_maps["a"], 1); + CHECK_EQ(her_maps["b"], 2); +} + +TEST(TargetCreationFail, UnrecognizedConfigOption) { + Map config = { + {"my_bool", Bool(true)}, + {"your_names", Array{"junru", "jian"}}, + {"kind", String("TestTargetKind")}, + {"bad", ObjectRef(nullptr)}, + { + "her_maps", + Map{ + {"a", 1}, + {"b", 2}, + }, + }, + }; + bool failed = false; + try { + Target::FromConfig(config); + } catch (...) { + failed = true; } - TargetValidateSchema(target); - tvm::Map target_host(target.begin(), target.end()); - target.Set("target_host", target_host); - TargetValidateSchema(target); + ASSERT_EQ(failed, true); } -TEST(TargetKind, SchemaValidationFail) { - tvm::Map target; - { - tvm::Array your_names{"junru", "jian"}; - tvm::Map her_maps{ - {"a", 1}, - {"b", 2}, - }; - target.Set("my_bool", Bool(true)); - target.Set("your_names", your_names); - target.Set("her_maps", her_maps); - target.Set("ok", ObjectRef(nullptr)); - target.Set("kind", String("TestTargetKind")); +TEST(TargetCreationFail, TypeMismatch) { + Map config = { + {"my_bool", String("true")}, + {"your_names", Array{"junru", "jian"}}, + {"kind", String("TestTargetKind")}, + { + "her_maps", + Map{ + {"a", 1}, + {"b", 2}, + }, + }, + }; + bool failed = false; + try { + Target::FromConfig(config); + } catch (...) { + failed = true; } + ASSERT_EQ(failed, true); +} + +TEST(TargetCreationFail, TargetKindNotFound) { + Map config = { + {"my_bool", Bool("true")}, + {"your_names", Array{"junru", "jian"}}, + { + "her_maps", + Map{ + {"a", 1}, + {"b", 2}, + }, + }, + }; bool failed = false; try { - TargetValidateSchema(target); + Target::FromConfig(config); } catch (...) { failed = true; } ASSERT_EQ(failed, true); } +TEST(TargetCreation, DeduplicateKeys) { + Map config = { + {"kind", String("llvm")}, + {"keys", Array{"cpu", "arm_cpu"}}, + {"device", String("arm_cpu")}, + }; + Target target = Target::FromConfig(config); + CHECK_EQ(target->kind, TargetKind::Get("llvm")); + CHECK_EQ(target->tag, ""); + CHECK_EQ(target->keys.size(), 2U); + CHECK_EQ(target->keys[0], "cpu"); + CHECK_EQ(target->keys[1], "arm_cpu"); + CHECK_EQ(target->attrs.size(), 1U); + CHECK_EQ(target->GetAttr("device"), "arm_cpu"); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe";