diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 92d3e7149463f..2c264cc557962 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -28,7 +28,9 @@ #include #include +#include #include +#include #include #include #include @@ -274,7 +276,181 @@ class ADT : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj); }; +/*! \brief An object representing string. It's POD type. */ +class StringObj : public Object { + public: + /*! \brief The length of the string object. */ + uint32_t size; + + /*! \brief The pointer to string data. */ + const char* data; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "String"; + TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object); + + private: + /*! \brief String object which is moved from std::string container. */ + class FromStd; + + friend class String; +}; + +/*! \brief reference to string objects. */ +class String : public ObjectRef { + public: + /*! + * \brief Construct a new String object + * + * \param other The moved/copied std::string object + * + * \note If user passes const reference, it will trigger copy. If it's rvalue, + * it will be moved into other. + */ + inline explicit String(std::string other); + + /*! + * \brief Compare is equal to other std::string + * + * \param other The other string + */ + bool operator==(const std::string& other) const { + return size() == other.size() && + other.compare(0, other.size(), get()->data, size()) == 0; + } + + /*! + * \brief Compare is not equal to other std::string + * + * \param other The other string + */ + bool operator!=(const std::string& other) const { return !operator==(other); } + + /*! + * \brief Compare is equal to other char string + * + * \param other The other char string + */ + inline bool operator==(const char* other) const; + + /*! + * \brief Compare is not equal to other char string + * + * \param other The other char string + */ + inline bool operator!=(const char* other) const; + + /*! + * \brief Compares this to other for at most len chars + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const String& other, size_t len) const { + return compare(other.data(), len); + } + + /*! + * \brief Compares this to other for at most len chars + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const std::string& other, size_t len) const { + return compare(other.data(), len); + } + + /*! + * \brief Compares this to other for at most len chars + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const char* other, size_t len) const { + return std::strncmp(get()->data, other, len); + } + + /*! + * \brief Returns a pointer to the char array in the string. + * + * \return const char* + */ + inline const char* c_str() const { return get()->data; } + + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + inline size_t size() const { + const auto* ptr = get(); + if (ptr == nullptr) { + return 0; + } + return ptr->size; + } + + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + inline size_t length() const { return size(); } + + /*! + * \brief Return the data pointer + * + * \return const char* data pointer + */ + inline const char* data() const { return get()->data; } + + /*! \return the internal StringObj pointer */ + const StringObj* get() const { return operator->(); } + + /*! + * \brief Convert String to an std::sting object + * + * \return std::string + */ + operator std::string() const { return std::string{get()->data, size()}; } + + TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); +}; + +/*! \brief An object representing string moved from std::string. */ +class StringObj::FromStd : public StringObj { + public: + /*! \brief Container that holds the memory. */ + std::string data_container; +}; + +inline String::String(std::string other) { + auto ptr = make_object(); + ptr->data_container.swap(other); + ptr->size = ptr->data_container.size(); + ptr->data = ptr->data_container.data(); + data_ = std::move(ptr); +} + +inline bool String::operator==(const char* other) const { + return !operator!=(other); +} + +inline bool String::operator!=(const char* other) const { + return size() != std::strlen(other) || compare(other, size()) != 0; +} + } // namespace runtime } // namespace tvm +namespace std { + +template <> +struct hash<::tvm::runtime::String> { + std::size_t operator()(const ::tvm::runtime::String& str) const { + return std::hash()((std::string)str); + } +}; +} // namespace std + #endif // TVM_RUNTIME_CONTAINER_H_ diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 3e6ef21386250..b6f66a8838235 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -226,6 +226,63 @@ TEST(Map, Iterator) { CHECK(map2[a].as()->value == 2); } +TEST(String, MoveFromStd) { + using namespace std; + std::string source = "this is a string"; + std::string expect = source; + String s(std::move(source)); + std::string copy = (std::string)s; + CHECK_EQ(copy, expect); + CHECK_EQ(source.size(), 0); +} + +TEST(String, CopyFromStd) { + using namespace std; + std::string source = "this is a string"; + std::string expect = source; + String s{source}; + std::string copy = (std::string)s; + CHECK_EQ(copy, expect); + CHECK_EQ(source.size(), expect.size()); +} + +TEST(String, Comparisons) { + using namespace std; + std::string source = "a string"; + std::string mismatch = "a string but longer"; + String s{source}; + + CHECK_EQ(s == source, true); + CHECK_EQ(s == mismatch, false); + CHECK_EQ(s == source.data(), true); + CHECK_EQ(s == mismatch.data(), false); +} + +TEST(String, c_str) { + using namespace std; + std::string source = "this is a string"; + std::string mismatch = "mismatch"; + String s{source}; + + CHECK_EQ(std::strcmp(s.c_str(), source.data()), 0); + CHECK_NE(std::strcmp(s.c_str(), mismatch.data()), 0); +} + +TEST(String, hash) { + using namespace std; + std::string source = "this is a string"; + String s{source}; + std::hash()(s); +} + +TEST(String, Cast) { + using namespace std; + std::string source = "this is a string"; + String s{source}; + ObjectRef r = s; + String s2 = Downcast(r); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe";