Skip to content

Commit

Permalink
[RUNTIME][REFACTOR] Use object protocol to support runtime::Module
Browse files Browse the repository at this point in the history
Previously runtime::Module was supported using shared_ptr.
This PR refactors the codebase to use the Object protocol.

It will open doors to allow easier interpolation between
Object containers and module in the future.
  • Loading branch information
tqchen committed Nov 9, 2019
1 parent 281f643 commit 75ddf51
Show file tree
Hide file tree
Showing 58 changed files with 447 additions and 323 deletions.
6 changes: 3 additions & 3 deletions apps/android_deploy/app/src/main/jni/tvm_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -18,7 +18,6 @@
*/

/*!
* Copyright (c) 2018 by Contributors
* \file tvm_runtime.h
* \brief Pack all tvm runtime source files
*/
Expand All @@ -35,6 +34,7 @@
#include "../src/runtime/file_util.cc"
#include "../src/runtime/dso_module.cc"
#include "../src/runtime/thread_pool.cc"
#include "../src/runtime/object.cc"
#include "../src/runtime/threading_backend.cc"
#include "../src/runtime/ndarray.cc"

Expand Down
5 changes: 3 additions & 2 deletions apps/android_rpc/app/src/main/jni/tvm_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -55,6 +55,7 @@
#include "../src/runtime/threading_backend.cc"
#include "../src/runtime/graph/graph_runtime.cc"
#include "../src/runtime/ndarray.cc"
#include "../src/runtime/object.cc"

#ifdef TVM_OPENCL_RUNTIME
#include "../src/runtime/opencl/opencl_device_api.cc"
Expand Down
5 changes: 3 additions & 2 deletions apps/bundle_deploy/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -32,5 +32,6 @@
#include "../../src/runtime/threading_backend.cc"
#include "../../src/runtime/thread_pool.cc"
#include "../../src/runtime/ndarray.cc"
#include "../../src/runtime/object.cc"
#include "../../src/runtime/system_lib_module.cc"
#include "../../src/runtime/graph/graph_runtime.cc"
1 change: 1 addition & 0 deletions apps/howto_deploy/tvm_runtime_pack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "../../src/runtime/threading_backend.cc"
#include "../../src/runtime/thread_pool.cc"
#include "../../src/runtime/ndarray.cc"
#include "../../src/runtime/object.cc"

// NOTE: all the files after this are optional modules
// that you can include remove, depending on how much feature you use.
Expand Down
3 changes: 2 additions & 1 deletion apps/ios_rpc/tvmrpc/TVMRuntime.mm
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/

/*!
* Copyright (c) 2017 by Contributors
* \file TVMRuntime.mm
*/
#include "TVMRuntime.h"
Expand All @@ -35,6 +34,8 @@
#include "../../../src/runtime/file_util.cc"
#include "../../../src/runtime/dso_module.cc"
#include "../../../src/runtime/ndarray.cc"
#include "../../../src/runtime/object.cc"

// RPC server
#include "../../../src/runtime/rpc/rpc_session.cc"
#include "../../../src/runtime/rpc/rpc_server_env.cc"
Expand Down
2 changes: 1 addition & 1 deletion golang/src/tvm_runtime_pack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/

/*!
* Copyright (c) 2018 by Contributors
* \brief This is an all in one TVM runtime file.
* \file tvm_runtime_pack.cc
*/
Expand All @@ -32,6 +31,7 @@
#include "src/runtime/threading_backend.cc"
#include "src/runtime/thread_pool.cc"
#include "src/runtime/ndarray.cc"
#include "src/runtime/object.cc"

// NOTE: all the files after this are optional modules
// that you can include remove, depending on how much feature you use.
Expand Down
98 changes: 76 additions & 22 deletions include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,31 @@
#define TVM_RUNTIME_MODULE_H_

#include <dmlc/io.h>

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>

#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "c_runtime_api.h"

namespace tvm {
namespace runtime {

// The internal container of module.
class ModuleNode;
class PackedFunc;

/*!
* \brief Module container of TVM.
*/
class Module {
class Module : public ObjectRef {
public:
Module() {}
// constructor from container.
explicit Module(std::shared_ptr<ModuleNode> n)
: node_(n) {}
explicit Module(ObjectPtr<Object> n)
: ObjectRef(n) {}
/*!
* \brief Get packed function from current module by name.
*
Expand All @@ -59,10 +62,6 @@ class Module {
* \note Implemented in packed_func.cc
*/
inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
/*! \return internal container */
inline ModuleNode* operator->();
/*! \return internal container */
inline const ModuleNode* operator->() const;
// The following functions requires link with runtime.
/*!
* \brief Import another module into this module.
Expand All @@ -71,7 +70,11 @@ class Module {
* \note Cyclic dependency is not allowed among modules,
* An error will be thrown when cyclic dependency is detected.
*/
TVM_DLL void Import(Module other);
inline void Import(Module other);
/*! \return internal container */
inline ModuleNode* operator->();
/*! \return internal container */
inline const ModuleNode* operator->() const;
/*!
* \brief Load a module from file.
* \param file_name The name of the host function module.
Expand All @@ -81,20 +84,41 @@ class Module {
*/
TVM_DLL static Module LoadFromFile(const std::string& file_name,
const std::string& format = "");

private:
std::shared_ptr<ModuleNode> node_;
// refer to the corresponding container.
using ContainerType = ModuleNode;
friend class ModuleNode;
};

/*!
* \brief Base node container of module.
* Do not create this directly, instead use Module.
* \brief Base container of module.
*
* Please subclass ModuleNode to create a specific runtime module.
*
* \code
*
* class MyModuleNode : public ModuleNode {
* public:
* // implement the interface
* };
*
* // use make_object to create a specific
* // instace of MyModuleNode.
* Module CreateMyModule() {
* ObjectPtr<MyModuleNode> n =
* tvm::runtime::make_object<MyModuleNode>();
* return Module(n);
* }
*
* \endcode
*/
class ModuleNode {
class ModuleNode : public Object {
public:
/*! \brief virtual destructor */
virtual ~ModuleNode() {}
/*! \return The module type key */
/*!
* \return The per module type key.
* \note This key is used to for serializing custom modules.
*/
virtual const char* type_key() const = 0;
/*!
* \brief Get a PackedFunc from module.
Expand All @@ -105,7 +129,7 @@ class ModuleNode {
* For benchmarking, use prepare to eliminate
*
* \param name the name of the function.
* \param sptr_to_self The shared_ptr that points to this module node.
* \param sptr_to_self The ObjectPtr that points to this module node.
*
* \return PackedFunc(nullptr) when it is not available.
*
Expand All @@ -115,7 +139,7 @@ class ModuleNode {
*/
virtual PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) = 0;
const ObjectPtr<Object>& sptr_to_self) = 0;
/*!
* \brief Save the module to file.
* \param file_name The file to be saved to.
Expand All @@ -137,6 +161,24 @@ class ModuleNode {
* \return Possible source code when available.
*/
TVM_DLL virtual std::string GetSource(const std::string& format = "");
/*!
* \brief Get packed function from current module by name.
*
* \param name The name of the function.
* \param query_imports Whether also query dependency modules.
* \return The result function.
* This function will return PackedFunc(nullptr) if function do not exist.
* \note Implemented in packed_func.cc
*/
TVM_DLL PackedFunc GetFunction(const std::string& name, bool query_imports = false);
/*!
* \brief Import another module into this module.
* \param other The module to be imported.
*
* \note Cyclic dependency is not allowed among modules,
* An error will be thrown when cyclic dependency is detected.
*/
TVM_DLL void Import(Module other);
/*!
* \brief Get a function from current environment
* The environment includes all the imports as well as Global functions.
Expand All @@ -150,6 +192,13 @@ class ModuleNode {
return imports_;
}

// integration with the existing components.
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule;
static constexpr const char* _type_key = "runtime.Module";
// NOTE: ModuleNode can still be sub-classed
//
TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object);

protected:
friend class Module;
/*! \brief The modules this module depend on */
Expand Down Expand Up @@ -180,16 +229,21 @@ constexpr const char* tvm_module_main = "__tvm_main__";
} // namespace symbol

// implementations of inline functions.

inline void Module::Import(Module other) {
return (*this)->Import(other);
}

inline ModuleNode* Module::operator->() {
return node_.get();
return static_cast<ModuleNode*>(get_mutable());
}

inline const ModuleNode* Module::operator->() const {
return node_.get();
return static_cast<const ModuleNode*>(get());
}

} // namespace runtime
} // namespace tvm

#include "packed_func.h"
#include <tvm/runtime/packed_func.h> // NOLINT(*)
#endif // TVM_RUNTIME_MODULE_H_
32 changes: 26 additions & 6 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ enum TypeIndex {
kVMTensor = 1,
kVMClosure = 2,
kVMADT = 3,
kRuntimeModule = 4,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
Expand Down Expand Up @@ -302,19 +303,19 @@ class Object {
template<typename>
friend class ObjectPtr;
friend class TVMRetValue;
friend class TVMObjectCAPI;
friend class FFIUtil;
};

/*!
* \brief Get a reference type from a raw object ptr type
*
* It is always important to get a reference type
* if we want to return a value as reference or keep
* the node alive beyond the scope of the function.
* the object alive beyond the scope of the function.
*
* \param ptr The node pointer
* \param ptr The object pointer
* \tparam RefType The reference type
* \tparam ObjectType The node type
* \tparam ObjectType The object type
* \return The corresponding RefType
*/
template <typename RefType, typename ObjectType>
Expand Down Expand Up @@ -486,6 +487,8 @@ class ObjectPtr {
friend class TVMArgValue;
template <typename RefType, typename ObjType>
friend RefType GetRef(const ObjType* ptr);
template <typename BaseType, typename ObjType>
friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr);
};

/*! \brief Base class of all object reference */
Expand Down Expand Up @@ -513,7 +516,7 @@ class ObjectRef {
}
/*!
* \brief Comparator
* \param other Another node ref.
* \param other Another object ref.
* \return the compare result.
*/
bool operator!=(const ObjectRef& other) const {
Expand All @@ -535,7 +538,7 @@ class ObjectRef {
const Object* get() const {
return data_.get();
}
/*! \return the internal node pointer */
/*! \return the internal object pointer */
const Object* operator->() const {
return get();
}
Expand Down Expand Up @@ -595,6 +598,16 @@ class ObjectRef {
friend SubRef Downcast(BaseRef ref);
};

/*!
* \brief Get an object ptr type from a raw object ptr.
*
* \param ptr The object pointer
* \tparam BaseType The reference type
* \tparam ObjectType The object type
* \return The corresponding RefType
*/
template <typename BaseType, typename ObjectType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);

/*! \brief ObjectRef hash functor */
struct ObjectHash {
Expand Down Expand Up @@ -781,6 +794,13 @@ inline RefType GetRef(const ObjType* ptr) {
return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
}

template <typename BaseType, typename ObjType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
static_assert(std::is_base_of<BaseType, ObjType>::value,
"Can only cast to the ref of same container type");
return ObjectPtr<BaseType>(static_cast<Object*>(ptr));
}

template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
Expand Down
Loading

0 comments on commit 75ddf51

Please sign in to comment.