diff --git a/apps/extension/Makefile b/apps/extension/Makefile index 14c71d92ca203..1680a003e06f2 100644 --- a/apps/extension/Makefile +++ b/apps/extension/Makefile @@ -20,8 +20,7 @@ TVM_ROOT=$(shell cd ../..; pwd) PKG_CFLAGS = -std=c++11 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${TVM_ROOT}/3rdparty/dmlc-core/include\ - -I${TVM_ROOT}/3rdparty/dlpack/include\ - -I${TVM_ROOT}/3rdparty/HalideIR/src + -I${TVM_ROOT}/3rdparty/dlpack/include PKG_LDFLAGS =-L${TVM_ROOT}/build UNAME_S := $(shell uname -s) diff --git a/apps/extension/python/tvm_ext/__init__.py b/apps/extension/python/tvm_ext/__init__.py index 38d511eeb6170..7404a717f7788 100644 --- a/apps/extension/python/tvm_ext/__init__.py +++ b/apps/extension/python/tvm_ext/__init__.py @@ -38,18 +38,9 @@ def load_lib(): ivec_create = tvm.get_global_func("tvm_ext.ivec_create") ivec_get = tvm.get_global_func("tvm_ext.ivec_get") -class IntVec(object): +@tvm.register_object("tvm_ext.IntVector") +class IntVec(tvm.Object): """Example for using extension class in c++ """ - _tvm_tcode = 17 - - def __init__(self, handle): - self.handle = handle - - def __del__(self): - # You can also call your own customized - # deleter if you can free it via your own FFI. - tvm.nd.free_extension_handle(self.handle, self.__class__._tvm_tcode) - @property def _tvm_handle(self): return self.handle.value @@ -57,9 +48,6 @@ def _tvm_handle(self): def __getitem__(self, idx): return ivec_get(self, idx) -# Register IntVec extension on python side. -tvm.register_extension(IntVec, IntVec) - nd_create = tvm.get_global_func("tvm_ext.nd_create") nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two") diff --git a/apps/extension/src/tvm_ext.cc b/apps/extension/src/tvm_ext.cc index 8655fa7d0c305..788c28da18d34 100644 --- a/apps/extension/src/tvm_ext.cc +++ b/apps/extension/src/tvm_ext.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,17 +30,12 @@ #include namespace tvm_ext { -using IntVector = std::vector; class NDSubClass; } // namespace tvm_ext namespace tvm { namespace runtime { template<> -struct extension_type_info { - static const int code = 17; -}; -template<> struct array_type_info { static const int code = 1; }; @@ -104,24 +99,47 @@ class NDSubClass : public tvm::runtime::NDArray { return self->addtional_info_; } }; + + +/*! + * \brief Introduce additional extension data structures + * by sub-classing TVM's object system. + */ +class IntVectorObj : public Object { + public: + std::vector vec; + + static constexpr const char* _type_key = "tvm_ext.IntVector"; + TVM_DECLARE_FINAL_OBJECT_INFO(IntVectorObj, Object); +}; + +/*! + * \brief Int vector reference class. + */ +class IntVector : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(IntVector, ObjectRef, IntVectorObj); +}; + +TVM_REGISTER_OBJECT_TYPE(IntVectorObj); + } // namespace tvm_ext namespace tvm_ext { -TVM_REGISTER_EXT_TYPE(IntVector); - TVM_REGISTER_GLOBAL("tvm_ext.ivec_create") .set_body([](TVMArgs args, TVMRetValue *rv) { - IntVector vec; + auto n = tvm::runtime::make_object(); for (int i = 0; i < args.size(); ++i) { - vec.push_back(args[i].operator int()); + n->vec.push_back(args[i].operator int()); } - *rv = vec; + *rv = IntVector(n); }); TVM_REGISTER_GLOBAL("tvm_ext.ivec_get") .set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = args[0].AsExtension()[args[1].operator int()]; + IntVector p = args[0]; + *rv = p->vec[args[1].operator int()]; }); diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 5053326058bc6..dda2a98dac22b 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -234,14 +234,6 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, int query_imports, TVMFunctionHandle *out); -/*! - * \brief Free front-end extension type resource. - * \param handle The extension handle. - * \param type_code The type of of the extension type. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMExtTypeFree(void* handle, int type_code); - /*! * \brief Free the Module * \param mod The module to be freed. diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 1d7db66ec5709..27dcb4130b4c2 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -387,7 +387,6 @@ inline std::string TVMType2String(TVMType t); #define TVM_CHECK_TYPE_CODE(CODE, T) \ CHECK_EQ(CODE, T) << " expected " \ << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \ - /*! * \brief Type traits to mark if a class is tvm extension type. * @@ -404,34 +403,6 @@ struct extension_type_info { static const int code = 0; }; -/*! - * \brief Runtime function table about extension type. - */ -class ExtTypeVTable { - public: - /*! \brief function to be called to delete a handle */ - void (*destroy)(void* handle); - /*! \brief function to be called when clone a handle */ - void* (*clone)(void* handle); - /*! - * \brief Register type - * \tparam T The type to be register. - * \return The registered vtable. - */ - template - static inline ExtTypeVTable* Register_(); - /*! - * \brief Get a vtable based on type code. - * \param type_code The type code - * \return The registered vtable. - */ - TVM_DLL static ExtTypeVTable* Get(int type_code); - - private: - // Internal registration function. - TVM_DLL static ExtTypeVTable* RegisterInternal(int type_code, const ExtTypeVTable& vt); -}; - /*! * \brief Internal base class to * handle conversion to POD values. @@ -518,11 +489,6 @@ class TVMPODValue_ { CHECK_EQ(container->array_type_code_, array_type_info::code); return TNDArray(container); } - template - const TExtension& AsExtension() const { - CHECK_LT(type_code_, kExtEnd); - return static_cast(value_.v_handle)[0]; - } template::value>::type> @@ -867,20 +833,8 @@ class TVMRetValue : public TVMPODValue_ { break; } default: { - if (other.type_code() < kExtBegin) { - SwitchToPOD(other.type_code()); - value_ = other.value_; - } else { -#if TVM_RUNTIME_HEADER_ONLY - LOG(FATAL) << "Header only mode do not support ext type"; -#else - this->Clear(); - type_code_ = other.type_code(); - value_.v_handle = - (*(ExtTypeVTable::Get(other.type_code())->clone))( - other.value().v_handle); -#endif - } + SwitchToPOD(other.type_code()); + value_ = other.value_; break; } } @@ -931,13 +885,6 @@ class TVMRetValue : public TVMPODValue_ { break; } } - if (type_code_ > kExtBegin) { -#if TVM_RUNTIME_HEADER_ONLY - LOG(FATAL) << "Header only mode do not support ext type"; -#else - (*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle); -#endif - } type_code_ = kNull; } }; @@ -1317,23 +1264,16 @@ inline R TypedPackedFunc::operator()(Args... args) const { // extension and node type handling namespace detail { -template +template struct TVMValueCast { static T Apply(const TSrc* self) { - static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions"); + static_assert(!is_nd, "The default case accepts only non-extensions"); return self->template AsObjectRef(); } }; template -struct TVMValueCast { - static T Apply(const TSrc* self) { - return self->template AsExtension(); - } -}; - -template -struct TVMValueCast { +struct TVMValueCast { static T Apply(const TSrc* self) { return self->template AsNDArray(); } @@ -1345,7 +1285,6 @@ template inline TVMArgValue::operator T() const { return detail:: TVMValueCast::code != 0), (array_type_info::code > 0)> ::Apply(this); } @@ -1354,19 +1293,10 @@ template inline TVMRetValue::operator T() const { return detail:: TVMValueCast::code != 0), (array_type_info::code > 0)> ::Apply(this); } -template -inline void TVMArgsSetter::operator()(size_t i, const T& value) const { - static_assert(extension_type_info::code != 0, - "Need to have extesion code"); - type_codes_[i] = extension_type_info::code; - values_[i].v_handle = const_cast(&value); -} - // PackedFunc support inline TVMRetValue& TVMRetValue::operator=(const DataType& t) { return this->operator=(t.operator DLDataType()); @@ -1385,28 +1315,6 @@ inline void TVMArgsSetter::operator()( this->operator()(i, t.operator DLDataType()); } -// extension type handling -template -struct ExtTypeInfo { - static void destroy(void* handle) { - delete static_cast(handle); - } - static void* clone(void* handle) { - return new T(*static_cast(handle)); - } -}; - -template -inline ExtTypeVTable* ExtTypeVTable::Register_() { - const int code = extension_type_info::code; - static_assert(code != 0, - "require extension_type_info traits to be declared with non-zero code"); - ExtTypeVTable vt; - vt.clone = ExtTypeInfo::clone; - vt.destroy = ExtTypeInfo::destroy; - return ExtTypeVTable::RegisterInternal(code, vt); -} - inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) { return (*this)->GetFunction(name, query_imports); } diff --git a/python/tvm/_ffi/ndarray.py b/python/tvm/_ffi/ndarray.py index 56bf4a00080cb..1773d916722b5 100644 --- a/python/tvm/_ffi/ndarray.py +++ b/python/tvm/_ffi/ndarray.py @@ -299,20 +299,6 @@ def copyto(self, target): raise ValueError("Unsupported target type %s" % str(type(target))) -def free_extension_handle(handle, type_code): - """Free c++ extension type handle - - Parameters - ---------- - handle : ctypes.c_void_p - The handle to the extension type. - - type_code : int - The tyoe code - """ - check_call(_LIB.TVMExtTypeFree(handle, ctypes.c_int(type_code))) - - def register_extension(cls, fcreate=None): """Register a extension class to TVM. diff --git a/python/tvm/ndarray.py b/python/tvm/ndarray.py index 9a00f78eb77fa..2a7a532e660eb 100644 --- a/python/tvm/ndarray.py +++ b/python/tvm/ndarray.py @@ -26,7 +26,7 @@ from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase from ._ffi.ndarray import context, empty, from_dlpack from ._ffi.ndarray import _set_class_ndarray -from ._ffi.ndarray import register_extension, free_extension_handle +from ._ffi.ndarray import register_extension class NDArray(NDArrayBase): """Lightweight NDArray class of TVM runtime. diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index ce6a281a6ead3..4717d89e33c15 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -40,15 +40,10 @@ struct Registry::Manager { // and the resource can become invalid because of indeterminstic order of destruction. // The resources will only be recycled during program exit. std::unordered_map fmap; - // vtable for extension type - std::array ext_vtable; // mutex std::mutex mutex; Manager() { - for (auto& x : ext_vtable) { - x.destroy = nullptr; - } } static Manager* Global() { @@ -109,24 +104,6 @@ std::vector Registry::ListNames() { return keys; } -ExtTypeVTable* ExtTypeVTable::Get(int type_code) { - CHECK(type_code > kExtBegin && type_code < kExtEnd); - Registry::Manager* m = Registry::Manager::Global(); - ExtTypeVTable* vt = &(m->ext_vtable[type_code]); - CHECK(vt->destroy != nullptr) - << "Extension type not registered"; - return vt; -} - -ExtTypeVTable* ExtTypeVTable::RegisterInternal( - int type_code, const ExtTypeVTable& vt) { - CHECK(type_code > kExtBegin && type_code < kExtEnd); - Registry::Manager* m = Registry::Manager::Global(); - std::lock_guard lock(m->mutex); - ExtTypeVTable* pvt = &(m->ext_vtable[type_code]); - pvt[0] = vt; - return pvt; -} } // namespace runtime } // namespace tvm @@ -141,12 +118,6 @@ struct TVMFuncThreadLocalEntry { /*! \brief Thread local store that can be used to hold return values. */ typedef dmlc::ThreadLocalStore TVMFuncThreadLocalStore; -int TVMExtTypeFree(void* handle, int type_code) { - API_BEGIN(); - tvm::runtime::ExtTypeVTable::Get(type_code)->destroy(handle); - API_END(); -} - int TVMFuncRegisterGlobal( const char* name, TVMFunctionHandle f, int override) { API_BEGIN();