diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 6379a1327514..9e92db92c6bd 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -27,6 +27,7 @@ #include #include #include +#include namespace tvm { namespace runtime { @@ -263,6 +264,141 @@ inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) { inline bool TypeEqual(DLDataType lhs, DLDataType rhs) { return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes; } + +/*! + * \brief Runtime utility for getting custom type name from code + * \param type_code Custom type code + * \return Custom type name + */ +TVM_DLL std::string GetCustomTypeName(uint8_t type_code); + +/*! + * \brief Runtime utility for checking whether custom type is registered + * \param type_code Custom type code + * \return Bool representing whether type is registered + */ +TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code); + +/*! + * \brief Runtime utility for parsing string of the form "custom[]" + * \param s String to parse + * \param scan pointer to parsing pointer, which is scanning across s + * \return type code of custom type parsed + */ +TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan); + +/*! + * \brief Convert type code to its name + * \param type_code The type code . + * \return The name of type code. + */ +inline const char* TypeCode2Str(int type_code); + +/*! + * \brief convert a string to TVM type. + * \param s The string to be converted. + * \return The corresponding tvm type. + */ +inline DLDataType String2DLDataType(std::string s); + +/*! + * \brief convert a TVM type to string. + * \param t The type to be converted. + * \return The corresponding tvm type in string. + */ +inline std::string DLDataType2String(DLDataType t); + +// implementation details +inline const char* TypeCode2Str(int type_code) { + switch (type_code) { + case kDLInt: return "int"; + case kDLUInt: return "uint"; + case kDLFloat: return "float"; + case kTVMStr: return "str"; + case kTVMBytes: return "bytes"; + case kTVMOpaqueHandle: return "handle"; + case kTVMNullptr: return "NULL"; + case kTVMDLTensorHandle: return "ArrayHandle"; + case kTVMDataType: return "DLDataType"; + case kTVMContext: return "TVMContext"; + case kTVMPackedFuncHandle: return "FunctionHandle"; + case kTVMModuleHandle: return "ModuleHandle"; + case kTVMNDArrayHandle: return "NDArrayContainer"; + case kTVMObjectHandle: return "Object"; + default: LOG(FATAL) << "unknown type_code=" + << static_cast(type_code); return ""; + } +} + +inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) + if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { + os << "bool"; return os; + } + if (t.code < kTVMCustomBegin) { + os << TypeCode2Str(t.code); + } else { + os << "custom[" << GetCustomTypeName(t.code) << "]"; + } + if (t.code == kTVMOpaqueHandle) return os; + os << static_cast(t.bits); + if (t.lanes != 1) { + os << 'x' << static_cast(t.lanes); + } + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) + return os << dtype.operator DLDataType(); +} + +inline std::string DLDataType2String(DLDataType t) { + if (t.bits == 0) return ""; + std::ostringstream os; + os << t; + return os.str(); +} + +inline DLDataType String2DLDataType(std::string s) { + DLDataType t; + // handle None type + if (s.length() == 0) { + t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle; + return t; + } + t.bits = 32; t.lanes = 1; + const char* scan; + if (s.substr(0, 3) == "int") { + t.code = kDLInt; scan = s.c_str() + 3; + } else if (s.substr(0, 4) == "uint") { + t.code = kDLUInt; scan = s.c_str() + 4; + } else if (s.substr(0, 5) == "float") { + t.code = kDLFloat; scan = s.c_str() + 5; + } else if (s.substr(0, 6) == "handle") { + t.code = kTVMOpaqueHandle; + t.bits = 64; // handle uses 64 bit by default. + scan = s.c_str() + 6; + } else if (s == "bool") { + t.code = kDLUInt; + t.bits = 1; + t.lanes = 1; + return t; + } else if (s.substr(0, 6) == "custom") { + t.code = ParseCustomDatatype(s, &scan); + } else { + scan = s.c_str(); + LOG(FATAL) << "unknown type " << s; + } + char* xdelim; // emulate sscanf("%ux%u", bits, lanes) + uint8_t bits = static_cast(strtoul(scan, &xdelim, 10)); + if (bits != 0) t.bits = bits; + char* endpt = xdelim; + if (*xdelim == 'x') { + t.lanes = static_cast(strtoul(xdelim + 1, &endpt, 10)); + } + CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s; + return t; +} + } // namespace runtime using DataType = runtime::DataType; diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 46fe1a189285..d5c017502426 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -52,28 +53,6 @@ class PrimExpr; namespace runtime { -/*! - * \brief Runtime utility for getting custom type name from code - * \param type_code Custom type code - * \return Custom type name - */ -TVM_DLL std::string GetCustomTypeName(uint8_t type_code); - -/*! - * \brief Runtime utility for checking whether custom type is registered - * \param type_code Custom type code - * \return Bool representing whether type is registered - */ -TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code); - -/*! - * \brief Runtime utility for parsing string of the form "custom[]" - * \param s String to parse - * \param scan pointer to parsing pointer, which is scanning across s - * \return type code of custom type parsed - */ -TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan); - // forward declarations class TVMArgs; class TVMArgValue; @@ -359,27 +338,6 @@ class TVMArgs { inline TVMArgValue operator[](int i) const; }; -/*! - * \brief Convert type code to its name - * \param type_code The type code . - * \return The name of type code. - */ -inline const char* TypeCode2Str(int type_code); - -/*! - * \brief convert a string to TVM type. - * \param s The string to be converted. - * \return The corresponding tvm type. - */ -inline DLDataType String2DLDataType(std::string s); - -/*! - * \brief convert a TVM type to string. - * \param t The type to be converted. - * \return The corresponding tvm type in string. - */ -inline std::string DLDataType2String(DLDataType t); - // macro to check type code. #define TVM_CHECK_TYPE_CODE(CODE, T) \ CHECK_EQ(CODE, T) << " expected " \ @@ -554,6 +512,10 @@ class TVMArgValue : public TVMPODValue_ { return std::string(value_.v_str); } } + operator tvm::runtime::String() const { + // directly use the std::string constructor for now. + return tvm::runtime::String(operator std::string()); + } operator DLDataType() const { if (type_code_ == kTVMStr) { return String2DLDataType(operator std::string()); @@ -642,6 +604,10 @@ class TVMRetValue : public TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMStr); return *ptr(); } + operator tvm::runtime::String() const { + // directly use the std::string constructor for now. + return tvm::runtime::String(operator std::string()); + } operator DLDataType() const { if (type_code_ == kTVMStr) { return String2DLDataType(operator std::string()); @@ -994,96 +960,6 @@ class TVMRetValue : public TVMPODValue_ { } \ } -// implementation details -inline const char* TypeCode2Str(int type_code) { - switch (type_code) { - case kDLInt: return "int"; - case kDLUInt: return "uint"; - case kDLFloat: return "float"; - case kTVMStr: return "str"; - case kTVMBytes: return "bytes"; - case kTVMOpaqueHandle: return "handle"; - case kTVMNullptr: return "NULL"; - case kTVMDLTensorHandle: return "ArrayHandle"; - case kTVMDataType: return "DLDataType"; - case kTVMContext: return "TVMContext"; - case kTVMPackedFuncHandle: return "FunctionHandle"; - case kTVMModuleHandle: return "ModuleHandle"; - case kTVMNDArrayHandle: return "NDArrayContainer"; - case kTVMObjectHandle: return "Object"; - default: LOG(FATAL) << "unknown type_code=" - << static_cast(type_code); return ""; - } -} - -inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) - if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { - os << "bool"; return os; - } - if (t.code < kTVMCustomBegin) { - os << TypeCode2Str(t.code); - } else { - os << "custom[" << GetCustomTypeName(t.code) << "]"; - } - if (t.code == kTVMOpaqueHandle) return os; - os << static_cast(t.bits); - if (t.lanes != 1) { - os << 'x' << static_cast(t.lanes); - } - return os; -} - -inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) - return os << dtype.operator DLDataType(); -} - -inline std::string DLDataType2String(DLDataType t) { - if (t.bits == 0) return ""; - std::ostringstream os; - os << t; - return os.str(); -} - -inline DLDataType String2DLDataType(std::string s) { - DLDataType t; - // handle None type - if (s.length() == 0) { - t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle; - return t; - } - t.bits = 32; t.lanes = 1; - const char* scan; - if (s.substr(0, 3) == "int") { - t.code = kDLInt; scan = s.c_str() + 3; - } else if (s.substr(0, 4) == "uint") { - t.code = kDLUInt; scan = s.c_str() + 4; - } else if (s.substr(0, 5) == "float") { - t.code = kDLFloat; scan = s.c_str() + 5; - } else if (s.substr(0, 6) == "handle") { - t.code = kTVMOpaqueHandle; - t.bits = 64; // handle uses 64 bit by default. - scan = s.c_str() + 6; - } else if (s == "bool") { - t.code = kDLUInt; - t.bits = 1; - t.lanes = 1; - return t; - } else if (s.substr(0, 6) == "custom") { - t.code = ParseCustomDatatype(s, &scan); - } else { - scan = s.c_str(); - LOG(FATAL) << "unknown type " << s; - } - char* xdelim; // emulate sscanf("%ux%u", bits, lanes) - uint8_t bits = static_cast(strtoul(scan, &xdelim, 10)); - if (bits != 0) t.bits = bits; - char* endpt = xdelim; - if (*xdelim == 'x') { - t.lanes = static_cast(strtoul(xdelim + 1, &endpt, 10)); - } - CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s; - return t; -} inline TVMArgValue TVMArgs::operator[](int i) const { CHECK_LT(i, num_args) diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index 8357a70b720b..4a815ffd5d7d 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -91,6 +91,8 @@ TEST(PackedFunc, str) { CHECK(args.num_args == 1); std::string x = args[0]; CHECK(x == "hello"); + String y = args[0]; + CHECK(y == "hello"); *rv = x; })("hello"); }