diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 5bf86d468bfdf..82b3dd4695415 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -69,19 +69,20 @@ namespace runtime { TVM_DLL std::string GetCustomTypeName(uint8_t type_code); /*! - * \brief Runtime utility for getting custom type code from name - * \param type_name Custom type name - * \return Custom type code - */ -TVM_DLL uint8_t GetCustomTypeCode(const std::string& type_name); - - /*! - * \brief Runtime utility for checking whether custom type is registered - * \param type_code Custom type code - * \return Bool representing whether type is registered - */ + * \brief Runtime utility for checking whether custom type is registered + * \param type_code Custom type code + * \return Bool representing whether type is registered + */ TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code); +/*! + * \brief Runtime utility for parsing string of the form "custom[]" + * \param s String to parse + * \param scan pointer to parsing pointer, which is scanning across s + * \return type code of custom type parsed + */ +TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan); + // forward declarations class TVMArgs; class TVMArgValue; @@ -1025,22 +1026,7 @@ inline TVMType String2TVMType(std::string s) { t.lanes = 1; return t; } else if (s.substr(0, 6) == "custom") { - // TODO(gus) this should be separated out into its own parsing function and cleaned up, or - // replaced by a regex. - scan = s.c_str() + 6; - if (*scan != '[') - LOG(FATAL) << "expected opening brace after 'custom' type in" << s; - ++scan; - size_t custom_name_len = 0; - while (scan + custom_name_len <= s.c_str() + s.length() && - *(scan + custom_name_len) != ']') - ++custom_name_len; - if (*(scan + custom_name_len) != ']') - LOG(FATAL) << "expected closing brace after 'custom' type in" << s; - scan += custom_name_len + 1; - - auto type_name = s.substr(7, custom_name_len); - t.code = GetCustomTypeCode(type_name); + t.code = ParseCustomDatatype(s, &scan); } else { scan = s.c_str(); LOG(FATAL) << "unknown type " << s; diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index b93e3881fcbc1..20793b4618b3a 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -63,6 +63,34 @@ bool GetCustomTypeRegistered(uint8_t type_code) { return (*f)(type_code).operator bool(); } +uint8_t ParseCustomDatatype(const std::string& s, const char** scan) { + CHECK(s.substr(0, 6) == "custom") << "Not a valid custom datatype string"; + + auto tmp = s.c_str(); + + CHECK(s.c_str() == tmp); + *scan = s.c_str() + 6; + CHECK(s.c_str() == tmp); + if (**scan != '[') LOG(FATAL) << "expected opening brace after 'custom' type in" << s; + CHECK(s.c_str() == tmp); + *scan += 1; + CHECK(s.c_str() == tmp); + size_t custom_name_len = 0; + CHECK(s.c_str() == tmp); + while (*scan + custom_name_len <= s.c_str() + s.length() && *(*scan + custom_name_len) != ']') + ++custom_name_len; + CHECK(s.c_str() == tmp); + if (*(*scan + custom_name_len) != ']') + LOG(FATAL) << "expected closing brace after 'custom' type in" << s; + CHECK(s.c_str() == tmp); + *scan += custom_name_len + 1; + CHECK(s.c_str() == tmp); + + auto type_name = s.substr(7, custom_name_len); + CHECK(s.c_str() == tmp); + return GetCustomTypeCode(type_name); +} + class DeviceAPIManager { public: static const int kMaxDeviceAPI = 32;