Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RUNTIME][REFACTOR] Use object protocol to support runtime::Module #4289

Merged
merged 1 commit into from
Nov 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions apps/android_deploy/app/src/main/jni/tvm_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -18,7 +18,6 @@
*/

/*!
* Copyright (c) 2018 by Contributors
* \file tvm_runtime.h
* \brief Pack all tvm runtime source files
*/
Expand All @@ -35,6 +34,7 @@
#include "../src/runtime/file_util.cc"
#include "../src/runtime/dso_module.cc"
#include "../src/runtime/thread_pool.cc"
#include "../src/runtime/object.cc"
#include "../src/runtime/threading_backend.cc"
#include "../src/runtime/ndarray.cc"

Expand Down
5 changes: 3 additions & 2 deletions apps/android_rpc/app/src/main/jni/tvm_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -55,6 +55,7 @@
#include "../src/runtime/threading_backend.cc"
#include "../src/runtime/graph/graph_runtime.cc"
#include "../src/runtime/ndarray.cc"
#include "../src/runtime/object.cc"

#ifdef TVM_OPENCL_RUNTIME
#include "../src/runtime/opencl/opencl_device_api.cc"
Expand Down
5 changes: 3 additions & 2 deletions apps/bundle_deploy/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -32,5 +32,6 @@
#include "../../src/runtime/threading_backend.cc"
#include "../../src/runtime/thread_pool.cc"
#include "../../src/runtime/ndarray.cc"
#include "../../src/runtime/object.cc"
#include "../../src/runtime/system_lib_module.cc"
#include "../../src/runtime/graph/graph_runtime.cc"
1 change: 1 addition & 0 deletions apps/howto_deploy/tvm_runtime_pack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "../../src/runtime/threading_backend.cc"
#include "../../src/runtime/thread_pool.cc"
#include "../../src/runtime/ndarray.cc"
#include "../../src/runtime/object.cc"

// NOTE: all the files after this are optional modules
// that you can include remove, depending on how much feature you use.
Expand Down
3 changes: 2 additions & 1 deletion apps/ios_rpc/tvmrpc/TVMRuntime.mm
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/

/*!
* Copyright (c) 2017 by Contributors
* \file TVMRuntime.mm
*/
#include "TVMRuntime.h"
Expand All @@ -35,6 +34,8 @@
#include "../../../src/runtime/file_util.cc"
#include "../../../src/runtime/dso_module.cc"
#include "../../../src/runtime/ndarray.cc"
#include "../../../src/runtime/object.cc"

// RPC server
#include "../../../src/runtime/rpc/rpc_session.cc"
#include "../../../src/runtime/rpc/rpc_server_env.cc"
Expand Down
2 changes: 1 addition & 1 deletion golang/src/tvm_runtime_pack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/

/*!
* Copyright (c) 2018 by Contributors
* \brief This is an all in one TVM runtime file.
* \file tvm_runtime_pack.cc
*/
Expand All @@ -32,6 +31,7 @@
#include "src/runtime/threading_backend.cc"
#include "src/runtime/thread_pool.cc"
#include "src/runtime/ndarray.cc"
#include "src/runtime/object.cc"

// NOTE: all the files after this are optional modules
// that you can include remove, depending on how much feature you use.
Expand Down
98 changes: 76 additions & 22 deletions include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,31 @@
#define TVM_RUNTIME_MODULE_H_

#include <dmlc/io.h>

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>

#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "c_runtime_api.h"

namespace tvm {
namespace runtime {

// The internal container of module.
class ModuleNode;
class PackedFunc;

/*!
* \brief Module container of TVM.
*/
class Module {
class Module : public ObjectRef {
public:
Module() {}
// constructor from container.
explicit Module(std::shared_ptr<ModuleNode> n)
: node_(n) {}
explicit Module(ObjectPtr<Object> n)
: ObjectRef(n) {}
/*!
* \brief Get packed function from current module by name.
*
Expand All @@ -59,10 +62,6 @@ class Module {
* \note Implemented in packed_func.cc
*/
inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
/*! \return internal container */
inline ModuleNode* operator->();
/*! \return internal container */
inline const ModuleNode* operator->() const;
// The following functions requires link with runtime.
/*!
* \brief Import another module into this module.
Expand All @@ -71,7 +70,11 @@ class Module {
* \note Cyclic dependency is not allowed among modules,
* An error will be thrown when cyclic dependency is detected.
*/
TVM_DLL void Import(Module other);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we intentionally remove the TVM_DLL here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The DLL function was moved to the Node class

inline void Import(Module other);
/*! \return internal container */
inline ModuleNode* operator->();
/*! \return internal container */
inline const ModuleNode* operator->() const;
/*!
* \brief Load a module from file.
* \param file_name The name of the host function module.
Expand All @@ -81,20 +84,41 @@ class Module {
*/
TVM_DLL static Module LoadFromFile(const std::string& file_name,
const std::string& format = "");

private:
std::shared_ptr<ModuleNode> node_;
// refer to the corresponding container.
using ContainerType = ModuleNode;
friend class ModuleNode;
};

/*!
* \brief Base node container of module.
* Do not create this directly, instead use Module.
* \brief Base container of module.
*
* Please subclass ModuleNode to create a specific runtime module.
*
* \code
*
* class MyModuleNode : public ModuleNode {
* public:
* // implement the interface
* };
*
* // use make_object to create a specific
* // instace of MyModuleNode.
* Module CreateMyModule() {
* ObjectPtr<MyModuleNode> n =
* tvm::runtime::make_object<MyModuleNode>();
* return Module(n);
* }
*
* \endcode
*/
class ModuleNode {
class ModuleNode : public Object {
public:
/*! \brief virtual destructor */
virtual ~ModuleNode() {}
/*! \return The module type key */
/*!
* \return The per module type key.
* \note This key is used to for serializing custom modules.
*/
virtual const char* type_key() const = 0;
/*!
* \brief Get a PackedFunc from module.
Expand All @@ -105,7 +129,7 @@ class ModuleNode {
* For benchmarking, use prepare to eliminate
*
* \param name the name of the function.
* \param sptr_to_self The shared_ptr that points to this module node.
* \param sptr_to_self The ObjectPtr that points to this module node.
*
* \return PackedFunc(nullptr) when it is not available.
*
Expand All @@ -115,7 +139,7 @@ class ModuleNode {
*/
virtual PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) = 0;
const ObjectPtr<Object>& sptr_to_self) = 0;
/*!
* \brief Save the module to file.
* \param file_name The file to be saved to.
Expand All @@ -137,6 +161,24 @@ class ModuleNode {
* \return Possible source code when available.
*/
TVM_DLL virtual std::string GetSource(const std::string& format = "");
/*!
* \brief Get packed function from current module by name.
*
* \param name The name of the function.
* \param query_imports Whether also query dependency modules.
* \return The result function.
* This function will return PackedFunc(nullptr) if function do not exist.
* \note Implemented in packed_func.cc
*/
TVM_DLL PackedFunc GetFunction(const std::string& name, bool query_imports = false);
/*!
* \brief Import another module into this module.
* \param other The module to be imported.
*
* \note Cyclic dependency is not allowed among modules,
* An error will be thrown when cyclic dependency is detected.
*/
TVM_DLL void Import(Module other);
/*!
* \brief Get a function from current environment
* The environment includes all the imports as well as Global functions.
Expand All @@ -150,6 +192,13 @@ class ModuleNode {
return imports_;
}

// integration with the existing components.
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule;
static constexpr const char* _type_key = "runtime.Module";
// NOTE: ModuleNode can still be sub-classed
//
TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object);

protected:
friend class Module;
/*! \brief The modules this module depend on */
Expand Down Expand Up @@ -180,16 +229,21 @@ constexpr const char* tvm_module_main = "__tvm_main__";
} // namespace symbol

// implementations of inline functions.

inline void Module::Import(Module other) {
return (*this)->Import(other);
}

inline ModuleNode* Module::operator->() {
return node_.get();
return static_cast<ModuleNode*>(get_mutable());
}

inline const ModuleNode* Module::operator->() const {
return node_.get();
return static_cast<const ModuleNode*>(get());
}

} // namespace runtime
} // namespace tvm

#include "packed_func.h"
#include <tvm/runtime/packed_func.h> // NOLINT(*)
#endif // TVM_RUNTIME_MODULE_H_
32 changes: 26 additions & 6 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ enum TypeIndex {
kVMTensor = 1,
kVMClosure = 2,
kVMADT = 3,
kRuntimeModule = 4,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
Expand Down Expand Up @@ -302,19 +303,19 @@ class Object {
template<typename>
friend class ObjectPtr;
friend class TVMRetValue;
friend class TVMObjectCAPI;
friend class ObjectInternal;
};

/*!
* \brief Get a reference type from a raw object ptr type
*
* It is always important to get a reference type
* if we want to return a value as reference or keep
* the node alive beyond the scope of the function.
* the object alive beyond the scope of the function.
*
* \param ptr The node pointer
* \param ptr The object pointer
* \tparam RefType The reference type
* \tparam ObjectType The node type
* \tparam ObjectType The object type
* \return The corresponding RefType
*/
template <typename RefType, typename ObjectType>
Expand Down Expand Up @@ -486,6 +487,8 @@ class ObjectPtr {
friend class TVMArgValue;
template <typename RefType, typename ObjType>
friend RefType GetRef(const ObjType* ptr);
template <typename BaseType, typename ObjType>
friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr);
};

/*! \brief Base class of all object reference */
Expand Down Expand Up @@ -513,7 +516,7 @@ class ObjectRef {
}
/*!
* \brief Comparator
* \param other Another node ref.
* \param other Another object ref.
* \return the compare result.
*/
bool operator!=(const ObjectRef& other) const {
Expand All @@ -535,7 +538,7 @@ class ObjectRef {
const Object* get() const {
return data_.get();
}
/*! \return the internal node pointer */
/*! \return the internal object pointer */
const Object* operator->() const {
return get();
}
Expand Down Expand Up @@ -595,6 +598,16 @@ class ObjectRef {
friend SubRef Downcast(BaseRef ref);
};

/*!
* \brief Get an object ptr type from a raw object ptr.
*
* \param ptr The object pointer
* \tparam BaseType The reference type
* \tparam ObjectType The object type
* \return The corresponding RefType
*/
template <typename BaseType, typename ObjectType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);

/*! \brief ObjectRef hash functor */
struct ObjectHash {
Expand Down Expand Up @@ -781,6 +794,13 @@ inline RefType GetRef(const ObjType* ptr) {
return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
}

template <typename BaseType, typename ObjType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
static_assert(std::is_base_of<BaseType, ObjType>::value,
"Can only cast to the ref of same container type");
return ObjectPtr<BaseType>(static_cast<Object*>(ptr));
}

template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
Expand Down
Loading