Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RUNTIME][IR] Allow non-nullable ObjectRef, introduce Optional<T>. #5314

Merged
merged 4 commits into from
Apr 13, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ namespace tvm {
*/
template<typename TObjectRef>
inline TObjectRef NullValue() {
static_assert(TObjectRef::_type_is_nullable,
"Can only get NullValue for nullable types");
return TObjectRef(ObjectPtr<Object>(nullptr));
}

Expand Down
67 changes: 63 additions & 4 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,47 @@ class FloatImm : public PrimExpr {
TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
};

/*!
* \brief Boolean constant.
*
* This reference type is useful to add additional compile-time
* type checks and helper functions for Integer equal comparisons.
*/
class Bool : public IntImm {
public:
explicit Bool(bool value)
: IntImm(DataType::Bool(), value) {
}
Bool operator!() const {
return Bool((*this)->value == 0);
}
operator bool() const {
return (*this)->value != 0;
}

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode);
};

// Overload operators to make sure we have the most fine grained types.
inline Bool operator||(const Bool& a, bool b) {
return Bool(a.operator bool() || b);
}
inline Bool operator||(bool a, const Bool& b) {
return Bool(a || b.operator bool());
}
inline Bool operator||(const Bool& a, const Bool& b) {
return Bool(a.operator bool() || b.operator bool());
}
inline Bool operator&&(const Bool& a, bool b) {
return Bool(a.operator bool() && b);
}
inline Bool operator&&(bool a, const Bool& b) {
return Bool(a && b.operator bool());
}
inline Bool operator&&(const Bool& a, const Bool& b) {
return Bool(a.operator bool() && b.operator bool());
}

/*!
* \brief Container of constant int that adds more constructors.
*
Expand Down Expand Up @@ -346,10 +387,10 @@ class Integer : public IntImm {
* \tparam Enum The enum type.
* \param value The enum value.
*/
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
explicit Integer(ENum value) : Integer(static_cast<int>(value)) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
template<typename Enum,
typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
explicit Integer(Enum value) : Integer(static_cast<int>(value)) {
static_assert(std::is_same<int, typename std::underlying_type<Enum>::type>::value,
"declare enum to be enum int to use visitor");
}
/*!
Expand All @@ -368,6 +409,24 @@ class Integer : public IntImm {
<< " Trying to reference a null Integer";
return (*this)->value;
}
// comparators
Bool operator==(int other) const {
if (data_ == nullptr) return Bool(false);
return Bool((*this)->value == other);
}
Bool operator!=(int other) const {
return !(*this == other);
}
template<typename Enum,
typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
Bool operator==(Enum other) const {
return *this == static_cast<int>(other);
}
template<typename Enum,
typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
Bool operator!=(Enum other) const {
return *this != static_cast<int>(other);
}
};

/*! \brief range over one dimension */
Expand Down
19 changes: 13 additions & 6 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/ir/expr.h>
#include <tvm/ir/attrs.h>
#include <tvm/runtime/container.h>
#include <type_traits>
#include <string>

Expand Down Expand Up @@ -90,25 +91,31 @@ class BaseFuncNode : public RelayExprNode {
* \code
*
* void GetAttrExample(const BaseFunc& f) {
* Integer value = f->GetAttr<Integer>("AttrKey", 0);
* auto value = f->GetAttr<Integer>("AttrKey", 0);
* }
*
* \endcode
*/
template<typename TObjectRef>
TObjectRef GetAttr(const std::string& attr_key,
TObjectRef default_value = NullValue<TObjectRef>()) const {
Optional<TObjectRef> GetAttr(
const std::string& attr_key,
Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
"Can only call GetAttr with ObjectRef types.");
if (!attrs.defined()) return default_value;
auto it = attrs->dict.find(attr_key);
if (it != attrs->dict.end()) {
return Downcast<TObjectRef>((*it).second);
return Downcast<Optional<TObjectRef>>((*it).second);
} else {
return default_value;
}
}

// variant that uses TObjectRef to enable implicit conversion to default value.
template<typename TObjectRef>
Optional<TObjectRef> GetAttr(
const std::string& attr_key, TObjectRef default_value) const {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}
/*!
* \brief Check whether the function has an non-zero integer attr.
*
Expand All @@ -129,7 +136,7 @@ class BaseFuncNode : public RelayExprNode {
* \endcode
*/
bool HasNonzeroAttr(const std::string& attr_key) const {
return GetAttr<Integer>(attr_key, 0)->value != 0;
return GetAttr<Integer>(attr_key, 0) != 0;
}

static constexpr const char* _type_key = "BaseFunc";
Expand Down
1 change: 0 additions & 1 deletion include/tvm/node/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ using runtime::make_object;
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::String;

} // namespace tvm
#endif // TVM_NODE_NODE_H_
138 changes: 134 additions & 4 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,10 @@ class StringObj : public Object {
*/
class String : public ObjectRef {
public:
/*!
* \brief Construct an empty string.
*/
String() : String(std::string()) {}
/*!
* \brief Construct a new String object
*
Expand Down Expand Up @@ -467,9 +471,6 @@ class String : public ObjectRef {
*/
size_t size() const {
const auto* ptr = get();
if (ptr == nullptr) {
return 0;
}
return ptr->size;
}

Expand Down Expand Up @@ -524,7 +525,7 @@ class String : public ObjectRef {
/*! \return the internal StringObj pointer */
const StringObj* get() const { return operator->(); }

TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);

private:
/*!
Expand Down Expand Up @@ -610,7 +611,136 @@ struct PackedFuncValueConverter<::tvm::runtime::String> {
}
};

/*!
* \brief Optional container that to represent to a Nullable variant of T.
* \tparam T The original ObjectRef.
*
* \code
*
* Optional<String> opt0 = nullptr;
* Optional<String> opt1 = String("xyz");
* CHECK(opt0 == nullptr);
* CHECK(opt1 == "xyz");
*
* \endcode
*/
template<typename T>
class Optional : public ObjectRef {
public:
using ContainerType = typename T::ContainerType;
static_assert(std::is_base_of<ObjectRef, T>::value,
"Optional is only defined for ObjectRef.");
// default constructors.
Optional() = default;
Optional(const Optional<T>&) = default;
Optional(Optional<T>&&) = default;
Optional<T>& operator=(const Optional<T>&) = default;
Optional<T>& operator=(Optional<T>&&) = default;
/*!
* \brief Construct from an ObjectPtr
* whose type already matches the ContainerType.
* \param ptr
*/
explicit Optional(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
// nullptr handling.
// disallow implicit conversion as 0 can be implicitly converted to nullptr_t
explicit Optional(std::nullptr_t) {}
Optional<T>& operator=(std::nullptr_t) {
data_ = nullptr;
return *this;
}
// normal value handling.
Optional(T other) // NOLINT(*)
: ObjectRef(std::move(other)) {
}
Optional<T>& operator=(T other) {
ObjectRef::operator=(std::move(other));
return *this;
}
// delete the int constructor
// since Optional<Integer>(0) is ambiguious
// 0 can be implicitly casted to nullptr_t
explicit Optional(int val) = delete;
Optional<T>& operator=(int val) = delete;
/*!
* \return A not-null container value in the optional.
* \note This function performs not-null checking.
*/
T value() const {
CHECK(data_ != nullptr);
return T(data_);
}
tqchen marked this conversation as resolved.
Show resolved Hide resolved
/*!
* \return The contained value if the Optional is not null
* otherwise return the default_value.
*/
T value_or(T default_value) const {
tqchen marked this conversation as resolved.
Show resolved Hide resolved
return data_ != nullptr ? T(data_) : default_value;
}
/*! \return Whether the container is not nullptr.*/
explicit operator bool() const {
tqchen marked this conversation as resolved.
Show resolved Hide resolved
return *this != nullptr;
}
// operator overloadings
bool operator==(std::nullptr_t) const {
return data_ == nullptr;
}
bool operator!=(std::nullptr_t) const {
return data_ != nullptr;
}
auto operator==(const Optional<T>& other) const {
// support case where sub-class returns a symbolic ref type.
using RetType = decltype(value() == other.value());
if (same_as(other)) return RetType(true);
if (*this != nullptr && other != nullptr) {
return value() == other.value();
} else {
// one of them is nullptr.
return RetType(false);
}
}
auto operator!=(const Optional<T>& other) const {
return !(*this == other);
}
auto operator==(const T& other) const {
using RetType = decltype(value() == other);
if (same_as(other)) return RetType(true);
if (*this != nullptr) return value() == other;
return RetType(false);
}
auto operator!=(const T& other) const {
return !(*this == other);
tqchen marked this conversation as resolved.
Show resolved Hide resolved
}
template<typename U>
auto operator==(const U& other) const {
using RetType = decltype(value() == other);
if (*this == nullptr) return RetType(false);
return value() == other;
}
template<typename U>
auto operator!=(const U& other) const {
return !(*this == other);
}
static constexpr bool _type_is_nullable = true;
};

template<typename T>
struct PackedFuncValueConverter<Optional<T>> {
static Optional<T> From(const TVMArgValue& val) {
if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
return PackedFuncValueConverter<T>::From(val);
}
static Optional<T> From(const TVMRetValue& val) {
if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
return PackedFuncValueConverter<T>::From(val);
}
};

} // namespace runtime

// expose the functions to the root namespace.
using runtime::String;
using runtime::Optional;
} // namespace tvm

namespace std {
Expand Down
Loading