Skip to content

Commit

Permalink
[Runtime][Object] Add String container
Browse files Browse the repository at this point in the history
  • Loading branch information
wweic committed Feb 4, 2020
1 parent cf173fd commit 2a5bbb9
Show file tree
Hide file tree
Showing 2 changed files with 233 additions and 0 deletions.
176 changes: 176 additions & 0 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>

#include <cstring>
#include <initializer_list>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -274,7 +276,181 @@ class ADT : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj);
};

/*! \brief An object representing string. It's POD type. */
class StringObj : public Object {
public:
/*! \brief The length of the string object. */
uint32_t size;

/*! \brief The pointer to string data. */
const char* data;

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "String";
TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object);

private:
/*! \brief String object which is moved from std::string container. */
class FromStd;

friend class String;
};

/*! \brief reference to string objects. */
class String : public ObjectRef {
public:
/*!
* \brief Construct a new String object
*
* \param other The moved/copied std::string object
*
* \note If user passes const reference, it will trigger copy. If it's rvalue,
* it will be moved into other.
*/
inline explicit String(std::string other);

/*!
* \brief Compare is equal to other std::string
*
* \param other The other string
*/
bool operator==(const std::string& other) const {
return size() == other.size() &&
other.compare(0, other.size(), get()->data, size()) == 0;
}

/*!
* \brief Compare is not equal to other std::string
*
* \param other The other string
*/
bool operator!=(const std::string& other) const { return !operator==(other); }

/*!
* \brief Compare is equal to other char string
*
* \param other The other char string
*/
inline bool operator==(const char* other) const;

/*!
* \brief Compare is not equal to other char string
*
* \param other The other char string
*/
inline bool operator!=(const char* other) const;

/*!
* \brief Compares this to other for at most len chars
*
* \return zero if both char sequences compare equal. negative if this appear
* before other, positive otherwise.
*/
int compare(const String& other, size_t len) const {
return compare(other.data(), len);
}

/*!
* \brief Compares this to other for at most len chars
*
* \return zero if both char sequences compare equal. negative if this appear
* before other, positive otherwise.
*/
int compare(const std::string& other, size_t len) const {
return compare(other.data(), len);
}

/*!
* \brief Compares this to other for at most len chars
*
* \return zero if both char sequences compare equal. negative if this appear
* before other, positive otherwise.
*/
int compare(const char* other, size_t len) const {
return std::strncmp(get()->data, other, len);
}

/*!
* \brief Returns a pointer to the char array in the string.
*
* \return const char*
*/
inline const char* c_str() const { return get()->data; }

/*!
* \brief Return the length of the string
*
* \return size_t string length
*/
inline size_t size() const {
const auto* ptr = get();
if (ptr == nullptr) {
return 0;
}
return ptr->size;
}

/*!
* \brief Return the length of the string
*
* \return size_t string length
*/
inline size_t length() const { return size(); }

/*!
* \brief Return the data pointer
*
* \return const char* data pointer
*/
inline const char* data() const { return get()->data; }

/*! \return the internal StringObj pointer */
const StringObj* get() const { return operator->(); }

/*!
* \brief Convert String to an std::sting object
*
* \return std::string
*/
operator std::string() const { return std::string{get()->data, size()}; }

TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
};

/*! \brief An object representing string moved from std::string. */
class StringObj::FromStd : public StringObj {
public:
/*! \brief Container that holds the memory. */
std::string data_container;
};

inline String::String(std::string other) {
auto ptr = make_object<StringObj::FromStd>();
ptr->data_container.swap(other);
ptr->size = ptr->data_container.size();
ptr->data = ptr->data_container.data();
data_ = std::move(ptr);
}

inline bool String::operator==(const char* other) const {
return !operator!=(other);
}

inline bool String::operator!=(const char* other) const {
return size() != std::strlen(other) || compare(other, size()) != 0;
}

} // namespace runtime
} // namespace tvm

namespace std {

template <>
struct hash<::tvm::runtime::String> {
std::size_t operator()(const ::tvm::runtime::String& str) const {
return std::hash<std::string>()((std::string)str);
}
};
} // namespace std

#endif // TVM_RUNTIME_CONTAINER_H_
57 changes: 57 additions & 0 deletions tests/cpp/container_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,63 @@ TEST(Map, Iterator) {
CHECK(map2[a].as<IntImmNode>()->value == 2);
}

TEST(String, MoveFromStd) {
using namespace std;
std::string source = "this is a string";
std::string expect = source;
String s(std::move(source));
std::string copy = (std::string)s;
CHECK_EQ(copy, expect);
CHECK_EQ(source.size(), 0);
}

TEST(String, CopyFromStd) {
using namespace std;
std::string source = "this is a string";
std::string expect = source;
String s{source};
std::string copy = (std::string)s;
CHECK_EQ(copy, expect);
CHECK_EQ(source.size(), expect.size());
}

TEST(String, Comparisons) {
using namespace std;
std::string source = "a string";
std::string mismatch = "a string but longer";
String s{source};

CHECK_EQ(s == source, true);
CHECK_EQ(s == mismatch, false);
CHECK_EQ(s == source.data(), true);
CHECK_EQ(s == mismatch.data(), false);
}

TEST(String, c_str) {
using namespace std;
std::string source = "this is a string";
std::string mismatch = "mismatch";
String s{source};

CHECK_EQ(std::strcmp(s.c_str(), source.data()), 0);
CHECK_NE(std::strcmp(s.c_str(), mismatch.data()), 0);
}

TEST(String, hash) {
using namespace std;
std::string source = "this is a string";
String s{source};
std::hash<String>()(s);
}

TEST(String, Cast) {
using namespace std;
std::string source = "this is a string";
String s{source};
ObjectRef r = s;
String s2 = Downcast<String>(r);
}

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
Expand Down

0 comments on commit 2a5bbb9

Please sign in to comment.