Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][RUNTIME] Update NDArray use the Unified Object System #4581

Merged
merged 2 commits into from
Dec 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions apps/extension/python/tvm_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,23 @@ def __getitem__(self, idx):

nd_create = tvm.get_global_func("tvm_ext.nd_create")
nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
nd_get_addtional_info = tvm.get_global_func("tvm_ext.nd_get_addtional_info")
nd_get_additional_info = tvm.get_global_func("tvm_ext.nd_get_additional_info")

@tvm.register_object("tvm_ext.NDSubClass")
class NDSubClass(tvm.nd.NDArrayBase):
"""Example for subclassing TVM's NDArray infrastructure.
By inheriting TMV's NDArray, external libraries could
leverage TVM's FFI without any modification.
"""
# Should be consistent with the type-trait set in the backend
_array_type_code = 1

@staticmethod
def create(addtional_info):
return nd_create(addtional_info)
def create(additional_info):
return nd_create(additional_info)

@property
def addtional_info(self):
return nd_get_addtional_info(self)
def additional_info(self):
return nd_get_additional_info(self)

def __add__(self, other):
return nd_add_two(self, other)

tvm.register_extension(NDSubClass, NDSubClass)
80 changes: 35 additions & 45 deletions apps/extension/src/tvm_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,6 @@
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>

namespace tvm_ext {
class NDSubClass;
} // namespace tvm_ext

namespace tvm {
namespace runtime {
template<>
struct array_type_info<tvm_ext::NDSubClass> {
static const int code = 1;
};
} // namespace tvm
} // namespace runtime

using namespace tvm;
using namespace tvm::runtime;

Expand All @@ -52,54 +39,55 @@ namespace tvm_ext {
* To use this extension, an external library should
*
* 1) Inherit TVM's NDArray and NDArray container,
* and define the trait `array_type_info` for this class.
*
* 2) Define a constructor in the inherited class that accepts
* a pointer to TVM's Container, which is nullable.
* 2) Follow the new object protocol to define new NDArray as a reference class.
*
* 3) On Python frontend, inherit `tvm.nd.NDArrayBase`,
* define the class attribute `_array_type_code` consistent to
* the C++ type trait, and register the subclass using `tvm.register_extension`.
* 3) On Python frontend, inherit `tvm.nd.NDArray`,
* register the type using tvm.register_object
*/
class NDSubClass : public tvm::runtime::NDArray {
public:
class SubContainer : public NDArray::Container {
public:
SubContainer(int addtional_info) :
addtional_info_(addtional_info) {
array_type_code_ = array_type_info<NDSubClass>::code;
}
static bool Is(NDArray::Container *container) {
SubContainer *c = static_cast<SubContainer*>(container);
return c->array_type_code_ == array_type_info<NDSubClass>::code;
SubContainer(int additional_info) :
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment between line 47 and line 49 needs to be updated.

additional_info_(additional_info) {
type_index_ = SubContainer::RuntimeTypeIndex();
}
int addtional_info_{0};
int additional_info_{0};

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "tvm_ext.NDSubClass";
TVM_DECLARE_FINAL_OBJECT_INFO(SubContainer, NDArray::Container);
};
NDSubClass(NDArray::Container *container) {
if (container == nullptr) {
data_ = nullptr;
return;
}
CHECK(SubContainer::Is(container));
container->IncRef();
data_ = container;

static void SubContainerDeleter(Object* obj) {
auto* ptr = static_cast<SubContainer*>(obj);
delete ptr;
}
~NDSubClass() {
this->reset();

NDSubClass() {}
explicit NDSubClass(ObjectPtr<Object> n) : NDArray(n) {}
explicit NDSubClass(int additional_info) {
SubContainer* ptr = new SubContainer(additional_info);
ptr->SetDeleter(SubContainerDeleter);
data_ = GetObjectPtr<Object>(ptr);
}

NDSubClass AddWith(const NDSubClass &other) const {
SubContainer *a = static_cast<SubContainer*>(data_);
SubContainer *b = static_cast<SubContainer*>(other.data_);
SubContainer *a = static_cast<SubContainer*>(get_mutable());
SubContainer *b = static_cast<SubContainer*>(other.get_mutable());
CHECK(a != nullptr && b != nullptr);
return NDSubClass(new SubContainer(a->addtional_info_ + b->addtional_info_));
return NDSubClass(a->additional_info_ + b->additional_info_);
}
int get_additional_info() const {
SubContainer *self = static_cast<SubContainer*>(data_);
SubContainer *self = static_cast<SubContainer*>(get_mutable());
CHECK(self != nullptr);
return self->addtional_info_;
return self->additional_info_;
}
using ContainerType = SubContainer;
};

TVM_REGISTER_OBJECT_TYPE(NDSubClass::SubContainer);

/*!
* \brief Introduce additional extension data structures
Expand Down Expand Up @@ -166,8 +154,10 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")

TVM_REGISTER_GLOBAL("tvm_ext.nd_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
int addtional_info = args[0];
*rv = NDSubClass(new NDSubClass::SubContainer(addtional_info));
int additional_info = args[0];
*rv = NDSubClass(additional_info);
CHECK_EQ(rv->type_code(), kNDArrayContainer);

});

TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
Expand All @@ -177,7 +167,7 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
*rv = a.AddWith(b);
});

TVM_REGISTER_GLOBAL("tvm_ext.nd_get_addtional_info")
TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info")
.set_body([](TVMArgs args, TVMRetValue *rv) {
NDSubClass a = args[0];
*rv = a.get_additional_info();
Expand Down
15 changes: 8 additions & 7 deletions apps/extension/tests/test_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,17 @@ def check_llvm():


def test_nd_subclass():
a = tvm_ext.NDSubClass.create(addtional_info=3)
b = tvm_ext.NDSubClass.create(addtional_info=5)
a = tvm_ext.NDSubClass.create(additional_info=3)
b = tvm_ext.NDSubClass.create(additional_info=5)
assert isinstance(a, tvm_ext.NDSubClass)
c = a + b
d = a + a
e = b + b
assert(a.addtional_info == 3)
assert(b.addtional_info == 5)
assert(c.addtional_info == 8)
assert(d.addtional_info == 6)
assert(e.addtional_info == 10)
assert(a.additional_info == 3)
assert(b.additional_info == 5)
assert(c.additional_info == 8)
assert(d.additional_info == 6)
assert(e.additional_info == 10)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/node/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
#ifndef TVM_NODE_CONTAINER_H_
#define TVM_NODE_CONTAINER_H_

#include <tvm/node/node.h>

#include <type_traits>
#include <vector>
#include <initializer_list>
#include <unordered_map>
#include <utility>
#include <string>
#include "node.h"
#include "memory.h"

namespace tvm {

Expand Down
98 changes: 15 additions & 83 deletions include/tvm/packed_func_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_

#include <sstream>
#include <string>
#include <memory>
#include <limits>
Expand All @@ -43,22 +42,7 @@ using runtime::TVMRetValue;
using runtime::PackedFunc;

namespace runtime {
/*!
* \brief Runtime type checker for node type.
* \tparam T the type to be checked.
*/
template<typename T>
struct ObjectTypeChecker {
static bool Check(const Object* ptr) {
using ContainerType = typename T::ContainerType;
if (ptr == nullptr) return true;
return ptr->IsInstance<ContainerType>();
}
static void PrintName(std::ostream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};


template<typename T>
struct ObjectTypeChecker<Array<T> > {
Expand All @@ -73,10 +57,8 @@ struct ObjectTypeChecker<Array<T> > {
}
return true;
}
static void PrintName(std::ostream& os) { // NOLINT(*)
os << "List[";
ObjectTypeChecker<T>::PrintName(os);
os << "]";
static std::string TypeName() {
return "List[" + ObjectTypeChecker<T>::TypeName() + "]";
}
};

Expand All @@ -91,11 +73,9 @@ struct ObjectTypeChecker<Map<std::string, V> > {
}
return true;
}
static void PrintName(std::ostream& os) { // NOLINT(*)
os << "Map[str";
os << ',';
ObjectTypeChecker<V>::PrintName(os);
os << ']';
static std::string TypeName() {
return "Map[str, " +
ObjectTypeChecker<V>::TypeName()+ ']';
}
};

Expand All @@ -111,39 +91,16 @@ struct ObjectTypeChecker<Map<K, V> > {
}
return true;
}
static void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "Map[";
ObjectTypeChecker<K>::PrintName(os);
os << ',';
ObjectTypeChecker<V>::PrintName(os);
os << ']';
static std::string TypeName() {
return "Map[" +
ObjectTypeChecker<K>::TypeName() +
", " +
ObjectTypeChecker<V>::TypeName()+ ']';
}
};

template<typename T>
inline std::string ObjectTypeName() {
std::ostringstream os;
ObjectTypeChecker<T>::PrintName(os);
return os.str();
}

// extensions for tvm arg value

template<typename TObjectRef>
inline TObjectRef TVMArgValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
if (type_code_ == kNull) return TObjectRef(NodePtr<Node>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TObjectRef>()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(ObjectPtr<Node>(ptr));
}

inline TVMArgValue::operator tvm::Expr() const {
inline TVMPODValue_::operator tvm::Expr() const {
if (type_code_ == kNull) return Expr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
Expand All @@ -164,12 +121,12 @@ inline TVMArgValue::operator tvm::Expr() const {
return Tensor(ObjectPtr<Node>(ptr))();
}
CHECK(ObjectTypeChecker<Expr>::Check(ptr))
<< "Expected type " << ObjectTypeName<Expr>()
<< "Expect type " << ObjectTypeChecker<Expr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return Expr(ObjectPtr<Node>(ptr));
}

inline TVMArgValue::operator tvm::Integer() const {
inline TVMPODValue_::operator tvm::Integer() const {
if (type_code_ == kNull) return Integer();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
Expand All @@ -179,35 +136,10 @@ inline TVMArgValue::operator tvm::Integer() const {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<Integer>::Check(ptr))
<< "Expected type " << ObjectTypeName<Expr>()
<< "Expect type " << ObjectTypeChecker<Expr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return Integer(ObjectPtr<Node>(ptr));
}

template<typename TObjectRef, typename>
inline bool TVMPODValue_::IsObjectRef() const {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
return ObjectTypeChecker<TObjectRef>::Check(ptr);
}

// extensions for TVMRetValue
template<typename TObjectRef>
inline TObjectRef TVMRetValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
if (type_code_ == kNull) return TObjectRef();
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);

Object* ptr = static_cast<Object*>(value_.v_handle);

CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TObjectRef>()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(ObjectPtr<Object>(ptr));
}

} // namespace runtime
} // namespace tvm
#endif // TVM_PACKED_FUNC_EXT_H_
1 change: 1 addition & 0 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
#ifndef TVM_RUNTIME_CONTAINER_H_
#define TVM_RUNTIME_CONTAINER_H_

#include <dmlc/logging.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
Expand Down
Loading