From 15bfdb4ba73e645571678f4d5f39bcf108465046 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 8 Nov 2019 14:57:54 -0800 Subject: [PATCH] [RUNTIME][REFACTOR] Use object protocol to support runtime::Module Previously runtime::Module was supported using shared_ptr. This PR refactors the codebase to use the Object protocol. It will open doors to allow easier interpolation between Object containers and module in the future. --- .../app/src/main/jni/tvm_runtime.h | 6 +- .../app/src/main/jni/tvm_runtime.h | 5 +- apps/bundle_deploy/runtime.cc | 5 +- apps/howto_deploy/tvm_runtime_pack.cc | 1 + apps/ios_rpc/tvmrpc/TVMRuntime.mm | 3 +- golang/src/tvm_runtime_pack.cc | 2 +- include/tvm/runtime/module.h | 98 ++++++++++++++----- include/tvm/runtime/object.h | 32 ++++-- include/tvm/runtime/packed_func.h | 75 +++++++------- include/tvm/runtime/vm.h | 4 +- python/tvm/relay/backend/vm.py | 2 +- src/codegen/llvm/llvm_module.cc | 7 +- src/codegen/source_module.cc | 20 ++-- src/relay/backend/build_module.cc | 4 +- src/relay/backend/graph_runtime_codegen.cc | 5 +- src/relay/backend/vm/compiler.cc | 5 +- src/relay/backend/vm/compiler.h | 6 +- src/relay/backend/vm/profiler/compiler.cc | 3 +- src/runtime/c_runtime_api.cc | 20 ++-- src/runtime/cuda/cuda_module.cc | 19 ++-- src/runtime/cuda/cuda_module.h | 5 +- src/runtime/dso_module.cc | 10 +- .../graph/debug/graph_runtime_debug.cc | 13 +-- src/runtime/graph/graph_runtime.cc | 16 +-- src/runtime/graph/graph_runtime.h | 4 +- src/runtime/metal/metal_module.mm | 12 +-- src/runtime/micro/micro_device_api.cc | 18 ++-- src/runtime/micro/micro_module.cc | 19 ++-- src/runtime/micro/micro_session.cc | 18 ++-- src/runtime/micro/micro_session.h | 9 +- src/runtime/micro/tcl_socket.h | 1 - src/runtime/module.cc | 30 ++++-- src/runtime/module_util.cc | 7 +- src/runtime/module_util.h | 7 +- src/runtime/object.cc | 17 +--- src/runtime/object_internal.h | 71 ++++++++++++++ src/runtime/opencl/aocl/aocl_common.h | 5 +- src/runtime/opencl/aocl/aocl_device_api.cc | 5 +- src/runtime/opencl/aocl/aocl_module.h | 5 +- src/runtime/opencl/opencl_common.h | 6 +- src/runtime/opencl/opencl_module.cc | 14 ++- src/runtime/opencl/opencl_module.h | 5 +- src/runtime/opengl/opengl_module.cc | 20 ++-- src/runtime/opengl/opengl_module.h | 4 +- src/runtime/rocm/rocm_module.cc | 16 ++- src/runtime/rpc/rpc_module.cc | 15 ++- src/runtime/rpc/rpc_session.cc | 23 +++-- src/runtime/rpc/rpc_session.h | 5 +- src/runtime/stackvm/stackvm.cc | 5 +- src/runtime/stackvm/stackvm_module.cc | 13 +-- src/runtime/system_lib_module.cc | 13 ++- src/runtime/vm/executable.cc | 5 +- src/runtime/vm/profiler/vm.cc | 5 +- src/runtime/vm/profiler/vm.h | 3 +- src/runtime/vm/vm.cc | 4 +- src/runtime/vulkan/vulkan.cc | 10 +- vta/src/dpi/module.cc | 5 +- web/web_runtime.cc | 5 +- 58 files changed, 447 insertions(+), 323 deletions(-) create mode 100644 src/runtime/object_internal.h diff --git a/apps/android_deploy/app/src/main/jni/tvm_runtime.h b/apps/android_deploy/app/src/main/jni/tvm_runtime.h index 3a909e0449ed..573612b93bc2 100644 --- a/apps/android_deploy/app/src/main/jni/tvm_runtime.h +++ b/apps/android_deploy/app/src/main/jni/tvm_runtime.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2018 by Contributors * \file tvm_runtime.h * \brief Pack all tvm runtime source files */ @@ -35,6 +34,7 @@ #include "../src/runtime/file_util.cc" #include "../src/runtime/dso_module.cc" #include "../src/runtime/thread_pool.cc" +#include "../src/runtime/object.cc" #include "../src/runtime/threading_backend.cc" #include "../src/runtime/ndarray.cc" diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index 73a8a484fd11..e30b31629e20 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -55,6 +55,7 @@ #include "../src/runtime/threading_backend.cc" #include "../src/runtime/graph/graph_runtime.cc" #include "../src/runtime/ndarray.cc" +#include "../src/runtime/object.cc" #ifdef TVM_OPENCL_RUNTIME #include "../src/runtime/opencl/opencl_device_api.cc" diff --git a/apps/bundle_deploy/runtime.cc b/apps/bundle_deploy/runtime.cc index 968554b5b34a..f1c2ba2f54ec 100644 --- a/apps/bundle_deploy/runtime.cc +++ b/apps/bundle_deploy/runtime.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -32,5 +32,6 @@ #include "../../src/runtime/threading_backend.cc" #include "../../src/runtime/thread_pool.cc" #include "../../src/runtime/ndarray.cc" +#include "../../src/runtime/object.cc" #include "../../src/runtime/system_lib_module.cc" #include "../../src/runtime/graph/graph_runtime.cc" diff --git a/apps/howto_deploy/tvm_runtime_pack.cc b/apps/howto_deploy/tvm_runtime_pack.cc index 6ebad8177cd5..67c9a9d64716 100644 --- a/apps/howto_deploy/tvm_runtime_pack.cc +++ b/apps/howto_deploy/tvm_runtime_pack.cc @@ -47,6 +47,7 @@ #include "../../src/runtime/threading_backend.cc" #include "../../src/runtime/thread_pool.cc" #include "../../src/runtime/ndarray.cc" +#include "../../src/runtime/object.cc" // NOTE: all the files after this are optional modules // that you can include remove, depending on how much feature you use. diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.mm b/apps/ios_rpc/tvmrpc/TVMRuntime.mm index 5d1d90e68b32..a98862abd94b 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.mm +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.mm @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file TVMRuntime.mm */ #include "TVMRuntime.h" @@ -35,6 +34,8 @@ #include "../../../src/runtime/file_util.cc" #include "../../../src/runtime/dso_module.cc" #include "../../../src/runtime/ndarray.cc" +#include "../../../src/runtime/object.cc" + // RPC server #include "../../../src/runtime/rpc/rpc_session.cc" #include "../../../src/runtime/rpc/rpc_server_env.cc" diff --git a/golang/src/tvm_runtime_pack.cc b/golang/src/tvm_runtime_pack.cc index cfbe237fd31c..c8be428c2fcb 100644 --- a/golang/src/tvm_runtime_pack.cc +++ b/golang/src/tvm_runtime_pack.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2018 by Contributors * \brief This is an all in one TVM runtime file. * \file tvm_runtime_pack.cc */ @@ -32,6 +31,7 @@ #include "src/runtime/threading_backend.cc" #include "src/runtime/thread_pool.cc" #include "src/runtime/ndarray.cc" +#include "src/runtime/object.cc" // NOTE: all the files after this are optional modules // that you can include remove, depending on how much feature you use. diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 7bbfa4dc2d5a..ff096eec5a43 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -27,28 +27,31 @@ #define TVM_RUNTIME_MODULE_H_ #include + +#include +#include +#include + #include #include #include #include -#include "c_runtime_api.h" namespace tvm { namespace runtime { -// The internal container of module. class ModuleNode; class PackedFunc; /*! * \brief Module container of TVM. */ -class Module { +class Module : public ObjectRef { public: Module() {} // constructor from container. - explicit Module(std::shared_ptr n) - : node_(n) {} + explicit Module(ObjectPtr n) + : ObjectRef(n) {} /*! * \brief Get packed function from current module by name. * @@ -59,10 +62,6 @@ class Module { * \note Implemented in packed_func.cc */ inline PackedFunc GetFunction(const std::string& name, bool query_imports = false); - /*! \return internal container */ - inline ModuleNode* operator->(); - /*! \return internal container */ - inline const ModuleNode* operator->() const; // The following functions requires link with runtime. /*! * \brief Import another module into this module. @@ -71,7 +70,11 @@ class Module { * \note Cyclic dependency is not allowed among modules, * An error will be thrown when cyclic dependency is detected. */ - TVM_DLL void Import(Module other); + inline void Import(Module other); + /*! \return internal container */ + inline ModuleNode* operator->(); + /*! \return internal container */ + inline const ModuleNode* operator->() const; /*! * \brief Load a module from file. * \param file_name The name of the host function module. @@ -81,20 +84,41 @@ class Module { */ TVM_DLL static Module LoadFromFile(const std::string& file_name, const std::string& format = ""); - - private: - std::shared_ptr node_; + // refer to the corresponding container. + using ContainerType = ModuleNode; + friend class ModuleNode; }; /*! - * \brief Base node container of module. - * Do not create this directly, instead use Module. + * \brief Base container of module. + * + * Please subclass ModuleNode to create a specific runtime module. + * + * \code + * + * class MyModuleNode : public ModuleNode { + * public: + * // implement the interface + * }; + * + * // use make_object to create a specific + * // instace of MyModuleNode. + * Module CreateMyModule() { + * ObjectPtr n = + * tvm::runtime::make_object(); + * return Module(n); + * } + * + * \endcode */ -class ModuleNode { +class ModuleNode : public Object { public: /*! \brief virtual destructor */ virtual ~ModuleNode() {} - /*! \return The module type key */ + /*! + * \return The per module type key. + * \note This key is used to for serializing custom modules. + */ virtual const char* type_key() const = 0; /*! * \brief Get a PackedFunc from module. @@ -105,7 +129,7 @@ class ModuleNode { * For benchmarking, use prepare to eliminate * * \param name the name of the function. - * \param sptr_to_self The shared_ptr that points to this module node. + * \param sptr_to_self The ObjectPtr that points to this module node. * * \return PackedFunc(nullptr) when it is not available. * @@ -115,7 +139,7 @@ class ModuleNode { */ virtual PackedFunc GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) = 0; + const ObjectPtr& sptr_to_self) = 0; /*! * \brief Save the module to file. * \param file_name The file to be saved to. @@ -137,6 +161,24 @@ class ModuleNode { * \return Possible source code when available. */ TVM_DLL virtual std::string GetSource(const std::string& format = ""); + /*! + * \brief Get packed function from current module by name. + * + * \param name The name of the function. + * \param query_imports Whether also query dependency modules. + * \return The result function. + * This function will return PackedFunc(nullptr) if function do not exist. + * \note Implemented in packed_func.cc + */ + TVM_DLL PackedFunc GetFunction(const std::string& name, bool query_imports = false); + /*! + * \brief Import another module into this module. + * \param other The module to be imported. + * + * \note Cyclic dependency is not allowed among modules, + * An error will be thrown when cyclic dependency is detected. + */ + TVM_DLL void Import(Module other); /*! * \brief Get a function from current environment * The environment includes all the imports as well as Global functions. @@ -150,6 +192,13 @@ class ModuleNode { return imports_; } + // integration with the existing components. + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule; + static constexpr const char* _type_key = "runtime.Module"; + // NOTE: ModuleNode can still be sub-classed + // + TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object); + protected: friend class Module; /*! \brief The modules this module depend on */ @@ -180,16 +229,21 @@ constexpr const char* tvm_module_main = "__tvm_main__"; } // namespace symbol // implementations of inline functions. + +inline void Module::Import(Module other) { + return (*this)->Import(other); +} + inline ModuleNode* Module::operator->() { - return node_.get(); + return static_cast(get_mutable()); } inline const ModuleNode* Module::operator->() const { - return node_.get(); + return static_cast(get()); } } // namespace runtime } // namespace tvm -#include "packed_func.h" +#include // NOLINT(*) #endif // TVM_RUNTIME_MODULE_H_ diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 0aa78150bf2b..20e6b5a0fb63 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -53,6 +53,7 @@ enum TypeIndex { kVMTensor = 1, kVMClosure = 2, kVMADT = 3, + kRuntimeModule = 4, kStaticIndexEnd, /*! \brief Type index is allocated during runtime. */ kDynamic = kStaticIndexEnd @@ -302,7 +303,7 @@ class Object { template friend class ObjectPtr; friend class TVMRetValue; - friend class TVMObjectCAPI; + friend class ObjectInternal; }; /*! @@ -310,11 +311,11 @@ class Object { * * It is always important to get a reference type * if we want to return a value as reference or keep - * the node alive beyond the scope of the function. + * the object alive beyond the scope of the function. * - * \param ptr The node pointer + * \param ptr The object pointer * \tparam RefType The reference type - * \tparam ObjectType The node type + * \tparam ObjectType The object type * \return The corresponding RefType */ template @@ -486,6 +487,8 @@ class ObjectPtr { friend class TVMArgValue; template friend RefType GetRef(const ObjType* ptr); + template + friend ObjectPtr GetObjectPtr(ObjType* ptr); }; /*! \brief Base class of all object reference */ @@ -513,7 +516,7 @@ class ObjectRef { } /*! * \brief Comparator - * \param other Another node ref. + * \param other Another object ref. * \return the compare result. */ bool operator!=(const ObjectRef& other) const { @@ -535,7 +538,7 @@ class ObjectRef { const Object* get() const { return data_.get(); } - /*! \return the internal node pointer */ + /*! \return the internal object pointer */ const Object* operator->() const { return get(); } @@ -595,6 +598,16 @@ class ObjectRef { friend SubRef Downcast(BaseRef ref); }; +/*! + * \brief Get an object ptr type from a raw object ptr. + * + * \param ptr The object pointer + * \tparam BaseType The reference type + * \tparam ObjectType The object type + * \return The corresponding RefType + */ +template +inline ObjectPtr GetObjectPtr(ObjectType* ptr); /*! \brief ObjectRef hash functor */ struct ObjectHash { @@ -781,6 +794,13 @@ inline RefType GetRef(const ObjType* ptr) { return RefType(ObjectPtr(const_cast(static_cast(ptr)))); } +template +inline ObjectPtr GetObjectPtr(ObjType* ptr) { + static_assert(std::is_base_of::value, + "Can only cast to the ref of same container type"); + return ObjectPtr(static_cast(ptr)); +} + template inline SubRef Downcast(BaseRef ref) { CHECK(ref->template IsInstance()) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 645f49979ef7..57c4291907c0 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -496,6 +496,14 @@ class TVMPODValue_ { return ObjectRef( ObjectPtr(static_cast(value_.v_handle))); } + operator Module() const { + if (type_code_ == kNull) { + return Module(ObjectPtr(nullptr)); + } + TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle); + return Module( + ObjectPtr(static_cast(value_.v_handle))); + } operator TVMContext() const { TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); return value_.v_ctx; @@ -574,6 +582,7 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator NDArray; using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator ObjectRef; + using TVMPODValue_::operator Module; using TVMPODValue_::IsObjectRef; // conversion operator. @@ -610,10 +619,6 @@ class TVMArgValue : public TVMPODValue_ { operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); } - operator Module() const { - TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle); - return *ptr(); - } const TVMValue& value() const { return value_; } @@ -665,6 +670,7 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator ObjectRef; + using TVMPODValue_::operator Module; using TVMPODValue_::IsObjectRef; TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { @@ -696,10 +702,6 @@ class TVMRetValue : public TVMPODValue_ { operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); } - operator Module() const { - TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle); - return *ptr(); - } // Assign operators TVMRetValue& operator=(TVMRetValue&& other) { this->Clear(); @@ -766,17 +768,13 @@ class TVMRetValue : public TVMPODValue_ { TVMRetValue& operator=(ObjectRef other) { return operator=(std::move(other.data_)); } + TVMRetValue& operator=(Module m) { + SwitchToObject(kModuleHandle, std::move(m.data_)); + return *this; + } template TVMRetValue& operator=(ObjectPtr other) { - if (other.data_ != nullptr) { - this->Clear(); - type_code_ = kObjectHandle; - // move the handle out - value_.v_handle = other.data_; - other.data_ = nullptr; - } else { - SwitchToPOD(kNull); - } + SwitchToObject(kObjectHandle, std::move(other)); return *this; } TVMRetValue& operator=(PackedFunc f) { @@ -787,10 +785,6 @@ class TVMRetValue : public TVMPODValue_ { TVMRetValue& operator=(const TypedPackedFunc& f) { return operator=(f.packed()); } - TVMRetValue& operator=(Module m) { - this->SwitchToClass(kModuleHandle, m); - return *this; - } TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0 this->Assign(other); return *this; @@ -860,7 +854,7 @@ class TVMRetValue : public TVMPODValue_ { break; } case kModuleHandle: { - SwitchToClass(kModuleHandle, other); + *this = other.operator Module(); break; } case kNDArrayContainer: { @@ -907,16 +901,30 @@ class TVMRetValue : public TVMPODValue_ { *static_cast(value_.v_handle) = v; } } + void SwitchToObject(int type_code, ObjectPtr other) { + if (other.data_ != nullptr) { + this->Clear(); + type_code_ = type_code; + // move the handle out + value_.v_handle = other.data_; + other.data_ = nullptr; + } else { + SwitchToPOD(kNull); + } + } void Clear() { if (type_code_ == kNull) return; switch (type_code_) { case kStr: delete ptr(); break; case kFuncHandle: delete ptr(); break; - case kModuleHandle: delete ptr(); break; case kNDArrayContainer: { static_cast(value_.v_handle)->DecRef(); break; } + case kModuleHandle: { + static_cast(value_.v_handle)->DecRef(); + break; + } case kObjectHandle: { static_cast(value_.v_handle)->DecRef(); break; @@ -1156,8 +1164,12 @@ class TVMArgsSetter { operator()(i, value.packed()); } void operator()(size_t i, const Module& value) const { // NOLINT(*) - values_[i].v_handle = const_cast(&value); - type_codes_[i] = kModuleHandle; + if (value.defined()) { + values_[i].v_handle = value.data_.data_; + type_codes_[i] = kModuleHandle; + } else { + type_codes_[i] = kNull; + } } void operator()(size_t i, const NDArray& value) const { // NOLINT(*) values_[i].v_handle = value.data_; @@ -1372,19 +1384,10 @@ inline ExtTypeVTable* ExtTypeVTable::Register_() { return ExtTypeVTable::RegisterInternal(code, vt); } -// Implement Module::GetFunction -// Put implementation in this file so we have seen the PackedFunc inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) { - PackedFunc pf = node_->GetFunction(name, node_); - if (pf != nullptr) return pf; - if (query_imports) { - for (const Module& m : node_->imports_) { - pf = m.node_->GetFunction(name, m.node_); - if (pf != nullptr) return pf; - } - } - return pf; + return (*this)->GetFunction(name, query_imports); } + } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_PACKED_FUNC_H_ diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index a196afdee2f3..317b53531c2d 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -480,7 +480,7 @@ class Executable : public ModuleNode { * \return PackedFunc or nullptr when it is not available. */ PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) final; + const ObjectPtr& sptr_to_self) final; /*! * \brief Serialize the executable into global section, constant section, and @@ -658,7 +658,7 @@ class VirtualMachine : public runtime::ModuleNode { * it should capture sptr_to_self. */ virtual PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self); + const ObjectPtr& sptr_to_self); /*! * \brief Invoke a PackedFunction diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index e190e3f1eb41..5a4c5f756054 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -148,7 +148,7 @@ def load_exec(bytecode, lib): raise TypeError("bytecode is expected to be the type of bytearray " + "or TVMByteArray, but received {}".format(type(code))) - if not isinstance(lib, tvm.module.Module): + if lib is not None and not isinstance(lib, tvm.module.Module): raise TypeError("lib is expected to be the type of tvm.module.Module" + ", but received {}".format(type(lib))) diff --git a/src/codegen/llvm/llvm_module.cc b/src/codegen/llvm/llvm_module.cc index 554aec328900..b8a38f595985 100644 --- a/src/codegen/llvm/llvm_module.cc +++ b/src/codegen/llvm/llvm_module.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file llvm_module.cc * \brief LLVM runtime module for TVM */ @@ -54,7 +53,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { PackedFunc GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) final { + const ObjectPtr& sptr_to_self) final { if (name == "__tvm_is_system_module") { bool flag = (mptr_->getFunction("__tvm_module_startup") != nullptr); @@ -325,7 +324,7 @@ TVM_REGISTER_API("codegen.llvm_lookup_intrinsic_id") TVM_REGISTER_API("codegen.build_llvm") .set_body([](TVMArgs args, TVMRetValue* rv) { - std::shared_ptr n = std::make_shared(); + auto n = make_object(); n->Init(args[0], args[1]); *rv = runtime::Module(n); }); @@ -339,7 +338,7 @@ TVM_REGISTER_API("codegen.llvm_version_major") TVM_REGISTER_API("module.loadfile_ll") .set_body([](TVMArgs args, TVMRetValue* rv) { - std::shared_ptr n = std::make_shared(); + auto n = make_object(); n->LoadIR(args[0]); *rv = runtime::Module(n); }); diff --git a/src/codegen/source_module.cc b/src/codegen/source_module.cc index 88be7fed448d..adbe7eaed451 100644 --- a/src/codegen/source_module.cc +++ b/src/codegen/source_module.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file source_module.cc * \brief Source code module, only for viewing */ @@ -51,7 +50,7 @@ class SourceModuleNode : public runtime::ModuleNode { PackedFunc GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) final { + const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; return PackedFunc(); @@ -67,8 +66,7 @@ class SourceModuleNode : public runtime::ModuleNode { }; runtime::Module SourceModuleCreate(std::string code, std::string fmt) { - std::shared_ptr n = - std::make_shared(code, fmt); + auto n = make_object(code, fmt); return runtime::Module(n); } @@ -84,7 +82,7 @@ class CSourceModuleNode : public runtime::ModuleNode { PackedFunc GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) final { + const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "C Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; return PackedFunc(); @@ -113,8 +111,7 @@ class CSourceModuleNode : public runtime::ModuleNode { }; runtime::Module CSourceModuleCreate(std::string code, std::string fmt) { - std::shared_ptr n = - std::make_shared(code, fmt); + auto n = make_object(code, fmt); return runtime::Module(n); } @@ -134,7 +131,7 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { PackedFunc GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) final { + const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; return PackedFunc(); @@ -182,8 +179,7 @@ runtime::Module DeviceSourceModuleCreate( std::unordered_map fmap, std::string type_key, std::function fget_source) { - std::shared_ptr n = - std::make_shared(data, fmt, fmap, type_key, fget_source); + auto n = make_object(data, fmt, fmap, type_key, fget_source); return runtime::Module(n); } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 73cf6c27877d..9254c7e3e7b9 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -115,7 +115,7 @@ class RelayBuildModule : public runtime::ModuleNode { * \return The corresponding member function. */ PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) final { + const ObjectPtr& sptr_to_self) final { if (name == "get_graph_json") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetGraphJSON(); @@ -489,7 +489,7 @@ class RelayBuildModule : public runtime::ModuleNode { }; runtime::Module RelayBuildCreate() { - std::shared_ptr exec = std::make_shared(); + auto exec = make_object(); return runtime::Module(exec); } diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 0342aa6ab1ba..e2881785766c 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -593,7 +593,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { public: GraphRuntimeCodegenModule() {} virtual PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) { + const ObjectPtr& sptr_to_self) { if (name == "init") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.num_args, 2) @@ -654,8 +654,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { }; runtime::Module CreateGraphCodegenMod() { - std::shared_ptr ptr = - std::make_shared(); + auto ptr = make_object(); return runtime::Module(ptr); } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 3cfea5c2e0db..7f828c473bbe 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file src/relay/backend/vm/compiler.cc * \brief A compiler from relay::Module to the VM byte code. */ @@ -745,7 +744,7 @@ class VMFunctionCompiler : ExprFunctor { PackedFunc VMCompiler::GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) { + const ObjectPtr& sptr_to_self) { if (name == "compile") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.num_args, 3); @@ -974,7 +973,7 @@ void VMCompiler::LibraryCodegen() { } runtime::Module CreateVMCompiler() { - std::shared_ptr exec = std::make_shared(); + auto exec = make_object(); return runtime::Module(exec); } diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 215cc12c4cdb..db319c49b2f3 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -86,14 +86,14 @@ class VMCompiler : public runtime::ModuleNode { virtual ~VMCompiler() {} virtual PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self); + const ObjectPtr& sptr_to_self); const char* type_key() const { return "VMCompiler"; } void InitVM() { - exec_ = std::make_shared(); + exec_ = make_object(); } /*! @@ -141,7 +141,7 @@ class VMCompiler : public runtime::ModuleNode { /*! \brief Global shared meta data */ VMCompilerContext context_; /*! \brief Compiled executable. */ - std::shared_ptr exec_; + ObjectPtr exec_; /*! \brief parameters */ std::unordered_map params_; }; diff --git a/src/relay/backend/vm/profiler/compiler.cc b/src/relay/backend/vm/profiler/compiler.cc index 60c441a60cf0..4727f151bd08 100644 --- a/src/relay/backend/vm/profiler/compiler.cc +++ b/src/relay/backend/vm/profiler/compiler.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file src/relay/backend/vm/profiler/compiler.cc * \brief A compiler from relay::Module to the VM byte code. */ @@ -37,7 +36,7 @@ class VMCompilerDebug : public VMCompiler { }; runtime::Module CreateVMCompilerDebug() { - std::shared_ptr exec = std::make_shared(); + auto exec = make_object(); return runtime::Module(exec); } diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 13181da7303a..3608fcea4aa1 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2016 by Contributors * \file c_runtime_api.cc * \brief Device specific implementations */ @@ -41,6 +40,7 @@ #include #include #include "runtime_base.h" +#include "object_internal.h" namespace tvm { namespace runtime { @@ -370,16 +370,20 @@ int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out) { API_BEGIN(); - Module m = Module::LoadFromFile(file_name, format); - *out = new Module(m); + TVMRetValue ret; + ret = Module::LoadFromFile(file_name, format); + TVMValue val; + int type_code; + ret.MoveToCHost(&val, &type_code); + *out = val.v_handle; API_END(); } int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep) { API_BEGIN(); - static_cast(mod)->Import( - *static_cast(dep)); + ObjectInternal::GetModuleNode(mod)->Import( + GetRef(ObjectInternal::GetModuleNode(dep))); API_END(); } @@ -388,7 +392,7 @@ int TVMModGetFunction(TVMModuleHandle mod, int query_imports, TVMFunctionHandle *func) { API_BEGIN(); - PackedFunc pf = static_cast(mod)->GetFunction( + PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction( func_name, query_imports != 0); if (pf != nullptr) { *func = new PackedFunc(pf); @@ -399,9 +403,7 @@ int TVMModGetFunction(TVMModuleHandle mod, } int TVMModFree(TVMModuleHandle mod) { - API_BEGIN(); - delete static_cast(mod); - API_END(); + return TVMObjectFree(mod); } int TVMBackendGetFuncFromEnv(void* mod_node, diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 55d9e648e154..e15356463bda 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -69,7 +69,7 @@ class CUDAModuleNode : public runtime::ModuleNode { PackedFunc GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) final; + const ObjectPtr& sptr_to_self) final; void SaveToFile(const std::string& file_name, const std::string& format) final { @@ -166,7 +166,7 @@ class CUDAWrappedFunc { public: // initialize the CUDA function. void Init(CUDAModuleNode* m, - std::shared_ptr sptr, + ObjectPtr sptr, const std::string& func_name, size_t num_void_args, const std::vector& thread_axis_tags) { @@ -220,7 +220,7 @@ class CUDAWrappedFunc { // internal module CUDAModuleNode* m_; // the resource holder - std::shared_ptr sptr_; + ObjectPtr sptr_; // The name of the function. std::string func_name_; // Device function cache per device. @@ -233,7 +233,7 @@ class CUDAWrappedFunc { class CUDAPrepGlobalBarrier { public: CUDAPrepGlobalBarrier(CUDAModuleNode* m, - std::shared_ptr sptr) + ObjectPtr sptr) : m_(m), sptr_(sptr) { std::fill(pcache_.begin(), pcache_.end(), 0); } @@ -252,14 +252,14 @@ class CUDAPrepGlobalBarrier { // internal module CUDAModuleNode* m_; // the resource holder - std::shared_ptr sptr_; + ObjectPtr sptr_; // mark as mutable, to enable lazy initialization mutable std::array pcache_; }; PackedFunc CUDAModuleNode::GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) { + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; @@ -279,8 +279,7 @@ Module CUDAModuleCreate( std::string fmt, std::unordered_map fmap, std::string cuda_source) { - std::shared_ptr n = - std::make_shared(data, fmt, fmap, cuda_source); + auto n = make_object(data, fmt, fmap, cuda_source); return Module(n); } diff --git a/src/runtime/cuda/cuda_module.h b/src/runtime/cuda/cuda_module.h index 54ff38da6f33..bce0d63e98a1 100644 --- a/src/runtime/cuda/cuda_module.h +++ b/src/runtime/cuda/cuda_module.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file cuda_module.h * \brief Execution handling of CUDA kernels */ diff --git a/src/runtime/dso_module.cc b/src/runtime/dso_module.cc index 4f69f2692d72..abbbe124a569 100644 --- a/src/runtime/dso_module.cc +++ b/src/runtime/dso_module.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,11 +18,11 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file dso_dll_module.cc * \brief Module to load from dynamic shared library. */ #include +#include #include #include #include "module_util.h" @@ -50,7 +50,7 @@ class DSOModuleNode final : public ModuleNode { PackedFunc GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) final { + const ObjectPtr& sptr_to_self) final { BackendPackedCFunc faddr; if (name == runtime::symbol::tvm_module_main) { const char* entry_name = reinterpret_cast( @@ -124,7 +124,7 @@ class DSOModuleNode final : public ModuleNode { TVM_REGISTER_GLOBAL("module.loadfile_so") .set_body([](TVMArgs args, TVMRetValue* rv) { - std::shared_ptr n = std::make_shared(); + auto n = make_object(); n->Init(args[0]); *rv = runtime::Module(n); }); diff --git a/src/runtime/graph/debug/graph_runtime_debug.cc b/src/runtime/graph/debug/graph_runtime_debug.cc index 2b26ae541b5f..ab28cb662f2a 100644 --- a/src/runtime/graph/debug/graph_runtime_debug.cc +++ b/src/runtime/graph/debug/graph_runtime_debug.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2018 by Contributors * \file graph_runtime_debug.cc */ #include @@ -28,6 +27,7 @@ #include #include #include "../graph_runtime.h" +#include "../../object_internal.h" namespace tvm { namespace runtime { @@ -121,7 +121,7 @@ class GraphRuntimeDebug : public GraphRuntime { * \param sptr_to_self Packed function pointer. */ PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self); + const ObjectPtr& sptr_to_self); /*! * \brief Get the node index given the name of node. @@ -169,7 +169,7 @@ void DebugGetNodeOutput(int index, DLTensor* data_out) { */ PackedFunc GraphRuntimeDebug::GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) { + const ObjectPtr& sptr_to_self) { // return member functions during query. if (name == "get_output_by_layer") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -207,7 +207,7 @@ PackedFunc GraphRuntimeDebug::GetFunction( Module GraphRuntimeDebugCreate(const std::string& sym_json, const tvm::runtime::Module& m, const std::vector& ctxs) { - std::shared_ptr exec = std::make_shared(); + auto exec = make_object(); exec->Init(sym_json, m, ctxs); return Module(exec); } @@ -222,15 +222,16 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create") }); TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.remote_create") - .set_body([](TVMArgs args, TVMRetValue* rv) { +.set_body([](TVMArgs args, TVMRetValue* rv) { CHECK_GE(args.num_args, 4) << "The expected number of arguments for " "graph_runtime.remote_create is " "at least 4, but it has " << args.num_args; void* mhandle = args[1]; + ModuleNode* mnode = ObjectInternal::GetModuleNode(mhandle); const auto& contexts = GetAllContext(args); *rv = GraphRuntimeDebugCreate( - args[0], *static_cast(mhandle), contexts); + args[0], GetRef(mnode), contexts); }); } // namespace runtime diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 38016ab87cbb..9ad10c1232c3 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -18,11 +18,8 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file graph_runtime.cc */ -#include "graph_runtime.h" - #include #include #include @@ -38,6 +35,9 @@ #include #include +#include "graph_runtime.h" +#include "../object_internal.h" + namespace tvm { namespace runtime { namespace details { @@ -411,7 +411,7 @@ std::pair, std::shared_ptr > GraphRu PackedFunc GraphRuntime::GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) { + const ObjectPtr& sptr_to_self) { // Return member functions during query. if (name == "set_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -478,7 +478,7 @@ PackedFunc GraphRuntime::GetFunction( Module GraphRuntimeCreate(const std::string& sym_json, const tvm::runtime::Module& m, const std::vector& ctxs) { - std::shared_ptr exec = std::make_shared(); + auto exec = make_object(); exec->Init(sym_json, m, ctxs); return Module(exec); } @@ -513,15 +513,17 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime.create") }); TVM_REGISTER_GLOBAL("tvm.graph_runtime.remote_create") - .set_body([](TVMArgs args, TVMRetValue* rv) { +.set_body([](TVMArgs args, TVMRetValue* rv) { CHECK_GE(args.num_args, 4) << "The expected number of arguments for " "graph_runtime.remote_create is " "at least 4, but it has " << args.num_args; void* mhandle = args[1]; + ModuleNode* mnode = ObjectInternal::GetModuleNode(mhandle); + const auto& contexts = GetAllContext(args); *rv = GraphRuntimeCreate( - args[0], *static_cast(mhandle), contexts); + args[0], GetRef(mnode), contexts); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index e8097a83b8dc..c83d68e08159 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -18,8 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors - * * \brief Tiny graph runtime that can run graph * containing only tvm PackedFunc. * \file graph_runtime.h @@ -83,7 +81,7 @@ class GraphRuntime : public ModuleNode { * \return The corresponding member function. */ virtual PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self); + const ObjectPtr& sptr_to_self); /*! * \return The type key of the executor. diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index af809d7619bd..d9b23fc55086 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file metal_module.cc */ #include @@ -54,7 +53,7 @@ explicit MetalModuleNode(std::string data, PackedFunc GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) final; + const ObjectPtr& sptr_to_self) final; void SaveToFile(const std::string& file_name, const std::string& format) final { @@ -187,7 +186,7 @@ void SaveToBinary(dmlc::Stream* stream) final { public: // initialize the METAL function. void Init(MetalModuleNode* m, - std::shared_ptr sptr, + ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, @@ -244,7 +243,7 @@ void operator()(TVMArgs args, // internal module MetalModuleNode* m_; // the resource holder - std::shared_ptr sptr_; + ObjectPtr sptr_; // The name of the function. std::string func_name_; // Number of buffer arguments @@ -260,7 +259,7 @@ void operator()(TVMArgs args, PackedFunc MetalModuleNode::GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) { + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; @@ -281,8 +280,7 @@ Module MetalModuleCreate( std::unordered_map fmap, std::string source) { metal::MetalWorkspace::Global()->Init(); - std::shared_ptr n = - std::make_shared(data, fmt, fmap, source); + auto n = make_object(data, fmt, fmap, source); return Module(n); } diff --git a/src/runtime/micro/micro_device_api.cc b/src/runtime/micro/micro_device_api.cc index 88328a2a4305..d1df67f00d9b 100644 --- a/src/runtime/micro/micro_device_api.cc +++ b/src/runtime/micro/micro_device_api.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file micro_device_api.cc */ @@ -50,7 +49,7 @@ class MicroDeviceAPI final : public DeviceAPI { size_t nbytes, size_t alignment, TVMType type_hint) final { - std::shared_ptr& session = MicroSession::Current(); + ObjectPtr& session = MicroSession::Current(); void* data = session->AllocateInSection(SectionKind::kHeap, nbytes).cast_to(); CHECK(data != nullptr) << "unable to allocate " << nbytes << " bytes on device heap"; MicroDevSpace* dev_space = new MicroDevSpace(); @@ -82,11 +81,12 @@ class MicroDeviceAPI final : public DeviceAPI { MicroDevSpace* from_space = static_cast(const_cast(from)); MicroDevSpace* to_space = static_cast(const_cast(to)); CHECK(from_space->session == to_space->session) - << "attempt to copy data between different micro sessions (" << from_space->session - << " != " << to_space->session << ")"; + << "attempt to copy data between different micro sessions (" + << from_space->session.get() + << " != " << to_space->session.get() << ")"; CHECK(ctx_from.device_id == ctx_to.device_id) << "can only copy between the same micro device"; - std::shared_ptr& session = from_space->session; + ObjectPtr& session = from_space->session; const std::shared_ptr& lld = session->low_level_device(); DevBaseOffset from_dev_offset = GetDevLoc(from_space, from_offset); @@ -99,7 +99,7 @@ class MicroDeviceAPI final : public DeviceAPI { // Reading from the device. MicroDevSpace* from_space = static_cast(const_cast(from)); - std::shared_ptr& session = from_space->session; + ObjectPtr& session = from_space->session; const std::shared_ptr& lld = session->low_level_device(); DevBaseOffset from_dev_offset = GetDevLoc(from_space, from_offset); @@ -109,7 +109,7 @@ class MicroDeviceAPI final : public DeviceAPI { // Writing to the device. MicroDevSpace* to_space = static_cast(const_cast(to)); - std::shared_ptr& session = to_space->session; + ObjectPtr& session = to_space->session; const std::shared_ptr& lld = session->low_level_device(); void* from_host_ptr = GetHostLoc(from, from_offset); @@ -124,7 +124,7 @@ class MicroDeviceAPI final : public DeviceAPI { } void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final { - std::shared_ptr& session = MicroSession::Current(); + ObjectPtr& session = MicroSession::Current(); void* data = session->AllocateInSection(SectionKind::kWorkspace, size).cast_to(); CHECK(data != nullptr) << "unable to allocate " << size << " bytes on device workspace"; @@ -136,7 +136,7 @@ class MicroDeviceAPI final : public DeviceAPI { void FreeWorkspace(TVMContext ctx, void* data) final { MicroDevSpace* dev_space = static_cast(data); - std::shared_ptr& session = dev_space->session; + ObjectPtr& session = dev_space->session; session->FreeInSection(SectionKind::kWorkspace, DevBaseOffset(reinterpret_cast(dev_space->data))); delete dev_space; diff --git a/src/runtime/micro/micro_module.cc b/src/runtime/micro/micro_module.cc index 85cd35982138..e66c45b3f063 100644 --- a/src/runtime/micro/micro_module.cc +++ b/src/runtime/micro/micro_module.cc @@ -18,9 +18,8 @@ */ /*! -* Copyright (c) 2019 by Contributors -* \file micro_module.cc -*/ + * \file micro_module.cc + */ #include #include @@ -48,7 +47,7 @@ class MicroModuleNode final : public ModuleNode { } PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) final; + const ObjectPtr& sptr_to_self) final; /*! * \brief initializes module by establishing device connection and loads binary @@ -76,13 +75,13 @@ class MicroModuleNode final : public ModuleNode { /*! \brief path to module binary */ std::string binary_path_; /*! \brief global session pointer */ - std::shared_ptr session_; + ObjectPtr session_; }; class MicroWrappedFunc { public: MicroWrappedFunc(MicroModuleNode* m, - std::shared_ptr session, + ObjectPtr session, const std::string& func_name, DevBaseOffset func_offset) { m_ = m; @@ -99,7 +98,7 @@ class MicroWrappedFunc { /*! \brief internal module */ MicroModuleNode* m_; /*! \brief reference to the session for this function (to keep the session alive) */ - std::shared_ptr session_; + ObjectPtr session_; /*! \brief name of the function */ std::string func_name_; /*! \brief offset of the function to be called */ @@ -108,7 +107,7 @@ class MicroWrappedFunc { PackedFunc MicroModuleNode::GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) { + const ObjectPtr& sptr_to_self) { DevBaseOffset func_offset = session_->low_level_device()->ToDevOffset(binary_info_.symbol_map[name]); MicroWrappedFunc f(this, session_, name, func_offset); @@ -118,9 +117,9 @@ PackedFunc MicroModuleNode::GetFunction( // register loadfile function to load module from Python frontend TVM_REGISTER_GLOBAL("module.loadfile_micro_dev") .set_body([](TVMArgs args, TVMRetValue* rv) { - std::shared_ptr n = std::make_shared(); + auto n = make_object(); n->InitMicroModule(args[0]); *rv = runtime::Module(n); - }); + }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/micro_session.cc b/src/runtime/micro/micro_session.cc index 9790154cb6f3..febf726184d9 100644 --- a/src/runtime/micro/micro_session.cc +++ b/src/runtime/micro/micro_session.cc @@ -18,13 +18,11 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file micro_session.cc */ #include #include -#include #include #include #include @@ -36,18 +34,18 @@ namespace tvm { namespace runtime { struct TVMMicroSessionThreadLocalEntry { - std::stack> session_stack; + std::stack> session_stack; }; typedef dmlc::ThreadLocalStore TVMMicroSessionThreadLocalStore; -std::shared_ptr& MicroSession::Current() { +ObjectPtr& MicroSession::Current() { TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get(); CHECK_GT(entry->session_stack.size(), 0) << "No current session"; return entry->session_stack.top(); } -void MicroSession::EnterWithScope(std::shared_ptr session) { +void MicroSession::EnterWithScope(ObjectPtr session) { TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get(); entry->session_stack.push(session); } @@ -121,7 +119,7 @@ void MicroSession::CreateSession(const std::string& device_type, void MicroSession::PushToExecQueue(DevBaseOffset func, const TVMArgs& args) { int32_t (*func_dev_addr)(void*, void*, int32_t) = reinterpret_cast( - low_level_device()->ToDevPtr(func).value()); + low_level_device()->ToDevPtr(func).value()); // Create an allocator stream for the memory region after the most recent // allocation in the args section. @@ -355,10 +353,10 @@ void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, PackedFunc MicroSession::GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) { + const ObjectPtr& sptr_to_self) { if (name == "enter") { - return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) { - MicroSession::EnterWithScope(std::dynamic_pointer_cast(sptr_to_self)); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + MicroSession::EnterWithScope(GetObjectPtr(this)); }); } else if (name == "exit") { return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) { @@ -378,7 +376,7 @@ TVM_REGISTER_GLOBAL("micro._CreateSession") uint64_t base_addr = args[3]; const std::string& server_addr = args[4]; int port = args[5]; - std::shared_ptr session = std::make_shared(); + ObjectPtr session = make_object(); session->CreateSession( device_type, binary_path, toolchain_prefix, base_addr, server_addr, port); *rv = Module(session); diff --git a/src/runtime/micro/micro_session.h b/src/runtime/micro/micro_session.h index 1400f74c4346..65b64218313b 100644 --- a/src/runtime/micro/micro_session.h +++ b/src/runtime/micro/micro_session.h @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file micro_session.h * \brief session to manage multiple micro modules * @@ -66,7 +65,7 @@ class MicroSession : public ModuleNode { * \return The corresponding member function. */ virtual PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self); + const ObjectPtr& sptr_to_self); /*! * \return The type key of the executor. @@ -85,7 +84,7 @@ class MicroSession : public ModuleNode { */ ~MicroSession(); - static std::shared_ptr& Current(); + static ObjectPtr& Current(); /*! * \brief creates session by setting up a low-level device and initting allocators for it @@ -240,7 +239,7 @@ class MicroSession : public ModuleNode { * \brief Push a new session context onto the thread-local stack. * The session on top of the stack is used as the current global session. */ - static void EnterWithScope(std::shared_ptr session); + static void EnterWithScope(ObjectPtr session); /*! * \brief Pop a session off the thread-local context stack, * restoring the previous session as the current context. @@ -258,7 +257,7 @@ struct MicroDevSpace { /*! \brief data being wrapped */ void* data; /*! \brief shared ptr to session where this data is valid */ - std::shared_ptr session; + ObjectPtr session; }; } // namespace runtime diff --git a/src/runtime/micro/tcl_socket.h b/src/runtime/micro/tcl_socket.h index 80ce185f4696..2123312ff585 100644 --- a/src/runtime/micro/tcl_socket.h +++ b/src/runtime/micro/tcl_socket.h @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file tcl_socket.h * \brief TCP socket wrapper for communicating using Tcl commands */ diff --git a/src/runtime/module.cc b/src/runtime/module.cc index c0acb315a04f..161675c7ca0c 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file module.cc * \brief TVM module system */ @@ -34,33 +33,46 @@ namespace tvm { namespace runtime { -void Module::Import(Module other) { +void ModuleNode::Import(Module other) { // specially handle rpc - if (!std::strcmp((*this)->type_key(), "rpc")) { + if (!std::strcmp(this->type_key(), "rpc")) { static const PackedFunc* fimport_ = nullptr; if (fimport_ == nullptr) { fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule"); CHECK(fimport_ != nullptr); } - (*fimport_)(*this, other); + (*fimport_)(GetRef(this), other); return; } // cyclic detection. - std::unordered_set visited{other.node_.get()}; - std::vector stack{other.node_.get()}; + std::unordered_set visited{other.operator->()}; + std::vector stack{other.operator->()}; while (!stack.empty()) { const ModuleNode* n = stack.back(); stack.pop_back(); for (const Module& m : n->imports_) { - const ModuleNode* next = m.node_.get(); + const ModuleNode* next = m.operator->(); if (visited.count(next)) continue; visited.insert(next); stack.push_back(next); } } - CHECK(!visited.count(node_.get())) + CHECK(!visited.count(this)) << "Cyclic dependency detected during import"; - node_->imports_.emplace_back(std::move(other)); + this->imports_.emplace_back(std::move(other)); +} + +PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) { + ModuleNode* self = this; + PackedFunc pf = self->GetFunction(name, GetObjectPtr(this)); + if (pf != nullptr) return pf; + if (query_imports) { + for (Module& m : self->imports_) { + pf = m->GetFunction(name, m.data_); + if (pf != nullptr) return pf; + } + } + return pf; } Module Module::LoadFromFile(const std::string& file_name, diff --git a/src/runtime/module_util.cc b/src/runtime/module_util.cc index 456d28278abc..445bfd343653 100644 --- a/src/runtime/module_util.cc +++ b/src/runtime/module_util.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file module_util.cc * \brief Utilities for module. */ @@ -64,7 +63,7 @@ void ImportModuleBlob(const char* mblob, std::vector* mlist) { } PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, - const std::shared_ptr& sptr_to_self) { + const ObjectPtr& sptr_to_self) { return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { int ret = (*faddr)( const_cast(args.values), diff --git a/src/runtime/module_util.h b/src/runtime/module_util.h index e5bbfe32766a..5f56c150588a 100644 --- a/src/runtime/module_util.h +++ b/src/runtime/module_util.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file module_util.h * \brief Helper utilities for module building */ @@ -45,7 +44,7 @@ namespace runtime { * \param faddr The function address * \param mptr The module pointer node. */ -PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr& mptr); +PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const ObjectPtr& mptr); /*! * \brief Load and append module blob to module list * \param mblob The module blob. diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 5d71c2fd2fa1..7a8aef8316f7 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -27,6 +27,7 @@ #include #include #include +#include "object_internal.h" #include "runtime_base.h" namespace tvm { @@ -200,18 +201,6 @@ uint32_t Object::TypeKey2Index(const std::string& key) { return TypeContext::Global()->TypeKey2Index(key); } -class TVMObjectCAPI { - public: - static void Free(TVMObjectHandle obj) { - if (obj != nullptr) { - static_cast(obj)->DecRef(); - } - } - - static uint32_t TypeKey2Index(const std::string& type_key) { - return Object::TypeKey2Index(type_key); - } -}; } // namespace runtime } // namespace tvm @@ -224,13 +213,13 @@ int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) { int TVMObjectFree(TVMObjectHandle obj) { API_BEGIN(); - tvm::runtime::TVMObjectCAPI::Free(obj); + tvm::runtime::ObjectInternal::ObjectFree(obj); API_END(); } int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { API_BEGIN(); - out_tindex[0] = tvm::runtime::TVMObjectCAPI::TypeKey2Index( + out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index( type_key); API_END(); } diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h new file mode 100644 index 000000000000..79551309d67c --- /dev/null +++ b/src/runtime/object_internal.h @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/runtime/object_internal.h + * \brief Expose a few functions for CFFI purposes. + * This file is not intended to be used + */ +#ifndef TVM_RUNTIME_OBJECT_INTERNAL_H_ +#define TVM_RUNTIME_OBJECT_INTERNAL_H_ + +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief Internal object namespace to expose + * certain util functions for FFI. + */ +class ObjectInternal { + public: + /*! + * \brief Free an object handle. + */ + static void ObjectFree(TVMObjectHandle obj) { + if (obj != nullptr) { + static_cast(obj)->DecRef(); + } + } + /*! + * \brief Expose TypeKey2Index + * \param type_key The original type key. + * \return the corresponding index. + */ + static uint32_t ObjectTypeKey2Index(const std::string& type_key) { + return Object::TypeKey2Index(type_key); + } + /*! + * \brief Convert ModuleHandle to module node pointer. + * \param handle The module handle. + * \return the corresponding module node pointer. + */ + static ModuleNode* GetModuleNode(TVMModuleHandle handle) { + // NOTE: we will need to convert to Object + // then to ModuleNode in order to get the correct + // address translation + return static_cast(static_cast(handle)); + } +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_OBJECT_INTERNAL_H_ diff --git a/src/runtime/opencl/aocl/aocl_common.h b/src/runtime/opencl/aocl/aocl_common.h index 48a6b8eac0a1..d9251f8aaf53 100644 --- a/src/runtime/opencl/aocl/aocl_common.h +++ b/src/runtime/opencl/aocl/aocl_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2018 by Contributors * \file aocl_common.h * \brief AOCL common header */ diff --git a/src/runtime/opencl/aocl/aocl_device_api.cc b/src/runtime/opencl/aocl/aocl_device_api.cc index 2442c4d2f1e3..84c29eea33ec 100644 --- a/src/runtime/opencl/aocl/aocl_device_api.cc +++ b/src/runtime/opencl/aocl/aocl_device_api.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2018 by Contributors * \file aocl_device_api.cc */ #include diff --git a/src/runtime/opencl/aocl/aocl_module.h b/src/runtime/opencl/aocl/aocl_module.h index 2e8322f75943..70955cc65528 100644 --- a/src/runtime/opencl/aocl/aocl_module.h +++ b/src/runtime/opencl/aocl/aocl_module.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2018 by Contributors * \file aocl_module.h * \brief Execution handling of OpenCL kernels for AOCL */ diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index ab84eef9d764..bd934736f235 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -278,7 +278,7 @@ class OpenCLModuleNode : public ModuleNode { PackedFunc GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) final; + const ObjectPtr& sptr_to_self) final; void SaveToFile(const std::string& file_name, const std::string& format) final; void SaveToBinary(dmlc::Stream* stream) final; diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 971ae3482014..24687db46ce6 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file opencl_module.cc */ #include @@ -36,7 +35,7 @@ class OpenCLWrappedFunc { public: // initialize the OpenCL function. void Init(OpenCLModuleNode* m, - std::shared_ptr sptr, + ObjectPtr sptr, OpenCLModuleNode::KTRefEntry entry, std::string func_name, std::vector arg_size, @@ -88,7 +87,7 @@ class OpenCLWrappedFunc { // The module OpenCLModuleNode* m_; // resource handle - std::shared_ptr sptr_; + ObjectPtr sptr_; // global kernel id in the kernel table. OpenCLModuleNode::KTRefEntry entry_; // The name of the function. @@ -122,7 +121,7 @@ const std::shared_ptr& OpenCLModuleNode::GetGlobalWorkspace PackedFunc OpenCLModuleNode::GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) { + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; @@ -251,8 +250,7 @@ Module OpenCLModuleCreate( std::string fmt, std::unordered_map fmap, std::string source) { - std::shared_ptr n = - std::make_shared(data, fmt, fmap, source); + auto n = make_object(data, fmt, fmap, source); n->Init(); return Module(n); } diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index cd63382d8851..3b7ebb9c1659 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file opencl_module.h * \brief Execution handling of OPENCL kernels */ diff --git a/src/runtime/opengl/opengl_module.cc b/src/runtime/opengl/opengl_module.cc index 9a1f77430c11..0d3f953a96e1 100644 --- a/src/runtime/opengl/opengl_module.cc +++ b/src/runtime/opengl/opengl_module.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -44,7 +44,7 @@ class OpenGLModuleNode final : public ModuleNode { const char* type_key() const final { return "opengl"; } PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) final; + const ObjectPtr& sptr_to_self) final; std::string GetSource(const std::string& format) final; @@ -74,7 +74,7 @@ class OpenGLModuleNode final : public ModuleNode { class OpenGLWrappedFunc { public: OpenGLWrappedFunc(OpenGLModuleNode* m, - std::shared_ptr sptr, + ObjectPtr sptr, std::string func_name, std::vector arg_size, const std::vector& thread_axis_tags); @@ -85,7 +85,7 @@ class OpenGLWrappedFunc { // The module OpenGLModuleNode* m_; // resource handle - std::shared_ptr sptr_; + ObjectPtr sptr_; // The name of the function. std::string func_name_; // convert code for void argument @@ -111,7 +111,7 @@ OpenGLModuleNode::OpenGLModuleNode( PackedFunc OpenGLModuleNode::GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) { + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; @@ -191,7 +191,7 @@ const FunctionInfo& OpenGLModuleNode::GetFunctionInfo( OpenGLWrappedFunc::OpenGLWrappedFunc( OpenGLModuleNode* m, - std::shared_ptr sptr, + ObjectPtr sptr, std::string func_name, std::vector arg_size, const std::vector& thread_axis_tags) @@ -251,9 +251,9 @@ void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, Module OpenGLModuleCreate(std::unordered_map shaders, std::string fmt, std::unordered_map fmap) { - auto n = std::make_shared(std::move(shaders), - std::move(fmt), - std::move(fmap)); + auto n = make_object(std::move(shaders), + std::move(fmt), + std::move(fmap)); return Module(n); } diff --git a/src/runtime/opengl/opengl_module.h b/src/runtime/opengl/opengl_module.h index b4459ae5a952..f1b712e8b20a 100644 --- a/src/runtime/opengl/opengl_module.h +++ b/src/runtime/opengl/opengl_module.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 96d19483bde0..c2bea8a8d745 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file rocm_module.cc */ #include @@ -68,7 +67,7 @@ class ROCMModuleNode : public runtime::ModuleNode { PackedFunc GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) final; + const ObjectPtr& sptr_to_self) final; void SaveToFile(const std::string& file_name, @@ -158,7 +157,7 @@ class ROCMWrappedFunc { public: // initialize the ROCM function. void Init(ROCMModuleNode* m, - std::shared_ptr sptr, + ObjectPtr sptr, const std::string& func_name, size_t num_void_args, const std::vector& thread_axis_tags) { @@ -204,7 +203,7 @@ class ROCMWrappedFunc { // internal module ROCMModuleNode* m_; // the resource holder - std::shared_ptr sptr_; + ObjectPtr sptr_; // The name of the function. std::string func_name_; // Device function cache per device. @@ -217,7 +216,7 @@ class ROCMWrappedFunc { PackedFunc ROCMModuleNode::GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) { + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; @@ -235,8 +234,7 @@ Module ROCMModuleCreate( std::unordered_map fmap, std::string hip_source, std::string assembly) { - std::shared_ptr n = - std::make_shared(data, fmt, fmap, hip_source, assembly); + auto n = make_object(data, fmt, fmap, hip_source, assembly); return Module(n); } diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 2e1e64a73283..8c4486e04bc5 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -123,7 +123,7 @@ class RPCModuleNode final : public ModuleNode { PackedFunc GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) final { + const ObjectPtr& sptr_to_self) final { RPCFuncHandle handle = GetFuncHandle(name); return WrapRemote(handle); } @@ -195,8 +195,7 @@ void RPCWrappedFunc::WrapRemote(std::shared_ptr sess, return wf->operator()(args, rv); }); } else if (tcode == kModuleHandle) { - std::shared_ptr n = - std::make_shared(handle, sess); + auto n = make_object(handle, sess); *rv = Module(n); } else if (tcode == kArrayHandle || tcode == kNDArrayContainer) { CHECK_EQ(args.size(), 2); @@ -209,8 +208,7 @@ void RPCWrappedFunc::WrapRemote(std::shared_ptr sess, } Module CreateRPCModule(std::shared_ptr sess) { - std::shared_ptr n = - std::make_shared(nullptr, sess); + auto n = make_object(nullptr, sess); return Module(n); } @@ -237,8 +235,7 @@ TVM_REGISTER_GLOBAL("rpc._LoadRemoteModule") CHECK_EQ(tkey, "rpc"); auto& sess = static_cast(m.operator->())->sess(); void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]); - std::shared_ptr n = - std::make_shared(mhandle, sess); + auto n = make_object(mhandle, sess); *rv = Module(n); }); diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index 39db150bd3a0..b5fec104a677 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -35,6 +35,7 @@ #include #include #include "rpc_session.h" +#include "../object_internal.h" #include "../../common/ring_buffer.h" #include "../../common/socket.h" @@ -1119,25 +1120,29 @@ void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) { } std::string file_name = args[0]; TVMRetValue ret = (*fsys_load_)(file_name); - Module m = ret; - *rv = static_cast(new Module(m)); + // pass via void* + TVMValue value; + int rcode; + ret.MoveToCHost(&value, &rcode); + CHECK_EQ(rcode, kModuleHandle); + *rv = static_cast(value.v_handle); } void RPCModuleImport(TVMArgs args, TVMRetValue *rv) { void* pmod = args[0]; void* cmod = args[1]; - static_cast(pmod)->Import( - *static_cast(cmod)); + ObjectInternal::GetModuleNode(pmod)->Import( + GetRef(ObjectInternal::GetModuleNode(cmod))); } void RPCModuleFree(TVMArgs args, TVMRetValue *rv) { void* mhandle = args[0]; - delete static_cast(mhandle); + ObjectInternal::ObjectFree(mhandle); } void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) { void* mhandle = args[0]; - PackedFunc pf = static_cast(mhandle)->GetFunction( + PackedFunc pf = ObjectInternal::GetModuleNode(mhandle)->GetFunction( args[1], false); if (pf != nullptr) { *rv = static_cast(new PackedFunc(pf)); @@ -1149,7 +1154,7 @@ void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) { void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) { void* mhandle = args[0]; std::string fmt = args[1]; - *rv = (*static_cast(mhandle))->GetSource(fmt); + *rv = ObjectInternal::GetModuleNode(mhandle)->GetSource(fmt); } void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) { diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index 3518455c83d1..ab5f16dadc46 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file rpc_session.h * \brief Base RPC session interface. */ diff --git a/src/runtime/stackvm/stackvm.cc b/src/runtime/stackvm/stackvm.cc index fe6913e6478d..07014a63110c 100644 --- a/src/runtime/stackvm/stackvm.cc +++ b/src/runtime/stackvm/stackvm.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * Implementation stack VM. * \file stackvm.cc */ diff --git a/src/runtime/stackvm/stackvm_module.cc b/src/runtime/stackvm/stackvm_module.cc index 4e7d42279001..4f86d0764ebf 100644 --- a/src/runtime/stackvm/stackvm_module.cc +++ b/src/runtime/stackvm/stackvm_module.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file stackvm_module.cc */ #include @@ -42,7 +41,7 @@ class StackVMModuleNode : public runtime::ModuleNode { PackedFunc GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) final { + const ObjectPtr& sptr_to_self) final { if (name == runtime::symbol::tvm_module_main) { return GetFunction(entry_func_, sptr_to_self); } @@ -89,8 +88,7 @@ class StackVMModuleNode : public runtime::ModuleNode { static Module Create(std::unordered_map fmap, std::string entry_func) { - std::shared_ptr n = - std::make_shared(); + auto n = make_object(); n->fmap_ = std::move(fmap); n->entry_func_ = std::move(entry_func); return Module(n); @@ -101,8 +99,7 @@ class StackVMModuleNode : public runtime::ModuleNode { std::string entry_func, data; strm->Read(&fmap); strm->Read(&entry_func); - std::shared_ptr n = - std::make_shared(); + auto n = make_object(); n->fmap_ = std::move(fmap); n->entry_func_ = std::move(entry_func); uint64_t num_imports; diff --git a/src/runtime/system_lib_module.cc b/src/runtime/system_lib_module.cc index 247fae110747..8a75a36ca49f 100644 --- a/src/runtime/system_lib_module.cc +++ b/src/runtime/system_lib_module.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,11 +18,11 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file system_lib_module.cc * \brief SystemLib module. */ #include +#include #include #include #include "module_util.h" @@ -40,7 +40,7 @@ class SystemLibModuleNode : public ModuleNode { PackedFunc GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) final { + const ObjectPtr& sptr_to_self) final { std::lock_guard lock(mutex_); if (module_blob_ != nullptr) { @@ -83,9 +83,8 @@ class SystemLibModuleNode : public ModuleNode { } } - static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static const ObjectPtr& Global() { + static auto inst = make_object(); return inst; } diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 4c4554cc6c86..2aeecc5061bb 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file tvm/runtime/vm/executable.cc * \brief The implementation of a virtual machine executable APIs. */ @@ -51,7 +50,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr); Instruction DeserializeInstruction(const VMInstructionSerializer& instr); PackedFunc Executable::GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) { + const ObjectPtr& sptr_to_self) { if (name == "get_lib") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLib(); @@ -440,7 +439,7 @@ void LoadHeader(dmlc::Stream* strm) { } runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) { - std::shared_ptr exec = std::make_shared(); + auto exec = make_object(); exec->lib = lib; exec->code_ = code; dmlc::MemoryStringStream strm(&exec->code_); diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 821de0bda245..ed6cddb25471 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file src/runtime/vm/profiler/vm.cc * \brief The Relay debug virtual machine. */ @@ -41,7 +40,7 @@ namespace runtime { namespace vm { PackedFunc VirtualMachineDebug::GetFunction( - const std::string& name, const std::shared_ptr& sptr_to_self) { + const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "get_stat") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { double total_duration = 0.0; @@ -124,7 +123,7 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, } runtime::Module CreateVirtualMachineDebug(const Executable* exec) { - std::shared_ptr vm = std::make_shared(); + auto vm = make_object(); vm->LoadExecutable(exec); return runtime::Module(vm); } diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index ff3296cb6c16..2e95a0768000 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file src/runtime/vm/profiler/vm.h * \brief The Relay debug virtual machine. */ @@ -42,7 +41,7 @@ class VirtualMachineDebug : public VirtualMachine { VirtualMachineDebug() : VirtualMachine() {} PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) final; + const ObjectPtr& sptr_to_self) final; void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) final; diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 05935b7833a5..463c5758ae02 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -627,7 +627,7 @@ ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) { } PackedFunc VirtualMachine::GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) { + const ObjectPtr& sptr_to_self) { if (name == "invoke") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK(exec) << "The executable is not created yet."; @@ -1052,7 +1052,7 @@ void VirtualMachine::RunLoop() { } runtime::Module CreateVirtualMachine(const Executable* exec) { - std::shared_ptr vm = std::make_shared(); + auto vm = make_object(); vm->LoadExecutable(exec); return runtime::Module(vm); } diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index e3b2ac8988db..daf4ae7c55f7 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -663,7 +663,9 @@ class VulkanModuleNode; // a wrapped function class to get packed func. class VulkanWrappedFunc { public: - void Init(VulkanModuleNode* m, std::shared_ptr sptr, const std::string& func_name, + void Init(VulkanModuleNode* m, + ObjectPtr sptr, + const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, const std::vector& thread_axis_tags) { m_ = m; @@ -680,7 +682,7 @@ class VulkanWrappedFunc { // internal module VulkanModuleNode* m_; // the resource holder - std::shared_ptr sptr_; + ObjectPtr sptr_; // v The name of the function. std::string func_name_; // Number of buffer arguments @@ -705,7 +707,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { const char* type_key() const final { return "vulkan"; } PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) final { + const ObjectPtr& sptr_to_self) final { CHECK_EQ(sptr_to_self.get(), this); CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); @@ -939,7 +941,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { Module VulkanModuleCreate(std::unordered_map smap, std::unordered_map fmap, std::string source) { - std::shared_ptr n = std::make_shared(smap, fmap, source); + auto n = make_object(smap, fmap, source); return Module(n); } diff --git a/vta/src/dpi/module.cc b/vta/src/dpi/module.cc index 6ef6af8fbcbd..27161c4b1bf8 100644 --- a/vta/src/dpi/module.cc +++ b/vta/src/dpi/module.cc @@ -226,7 +226,7 @@ class DPIModule final : public DPIModuleNode { PackedFunc GetFunction( const std::string& name, - const std::shared_ptr& sptr_to_self) final { + const ObjectPtr& sptr_to_self) final { if (name == "WriteReg") { return TypedPackedFunc( [this](int addr, int value){ @@ -413,8 +413,7 @@ class DPIModule final : public DPIModuleNode { }; Module DPIModuleNode::Load(std::string dll_name) { - std::shared_ptr n = - std::make_shared(); + auto n = make_object(); n->Init(dll_name); return Module(n); } diff --git a/web/web_runtime.cc b/web/web_runtime.cc index 12bc53cd3407..91f10c079088 100644 --- a/web/web_runtime.cc +++ b/web/web_runtime.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -31,6 +31,7 @@ #include "../src/runtime/system_lib_module.cc" #include "../src/runtime/module.cc" #include "../src/runtime/ndarray.cc" +#include "../src/runtime/object.cc" #include "../src/runtime/registry.cc" #include "../src/runtime/file_util.cc" #include "../src/runtime/dso_module.cc"