Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Runtime] Introduce runtime module property #14406

Merged
merged 8 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,31 @@
namespace tvm {
namespace runtime {

/*!
* \brief Property of runtime module
* We classify the property of runtime module into the following categories.
* - kBinarySerializable: we can serialize the module to the stream of bytes. CUDA/OpenCL/JSON
* runtime are representative examples. A binary exportable module can be integrated into final
* runtime artifact by being serialized as data into the artifact, then deserialzied at runtime.
* This class of modules must implement SaveToBinary, and have a matching deserializer registered as
* 'runtime.module.loadbinary_<type_key>'.
* - kRunnable: we can run the module directly. LLVM/CUDA/JSON runtime, executors (e.g,
sunggg marked this conversation as resolved.
Show resolved Hide resolved
* virtual machine) runtimes are runnable. Non-runnable modules, such as CSourceModule, requires a
* few extra steps (e.g,. compilation, link) to make it runnable.
* - kDSOExportable: we can export the module as DSO. A DSO exportable module (e.g., a
* CSourceModuleNode of type_key 'c') can be incorporated into the final runtime artifact (ie shared
* library) by compilation and/or linking using the external compiler (llvm, nvcc, etc). DSO
* exportable modules must implement SaveToFile.
*/
namespace property {
/*! \brief Binary serializable runtime module */
constexpr const uint8_t kBinarySerializable = 0b001;
Copy link
Member

@tqchen tqchen Mar 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let us use int to keep things simple, make it an enum

enum class ModulePropertyMask : int

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Used enum ModulePropertyMask : int instead for implicit casting.

/*! \brief Runnable runtime module */
constexpr const uint8_t kRunnable = 0b010;
/*! \brief DSO exportable runtime module */
constexpr const uint8_t kDSOExportable = 0b100;
}; // namespace property

class ModuleNode;
class PackedFunc;

Expand Down Expand Up @@ -193,20 +218,14 @@ class TVM_DLL ModuleNode : public Object {
const std::vector<Module>& imports() const { return imports_; }

/*!
* \brief Returns true if this module is 'DSO exportable'.
*
* A DSO exportable module (eg a CSourceModuleNode of type_key 'c') can be incorporated into the
* final runtime artifact (ie shared library) by compilation and/or linking using the external
* compiler (llvm, nvcc, etc). DSO exportable modules must implement SaveToFile.
*
* By contrast, non-DSO exportable modules (eg CUDAModuleNode of type_key 'cuda') typically must
* be incorporated into the final runtime artifact by being serialized as data into the
* artifact, then deserialized at runtime. Non-DSO exportable modules must implement SaveToBinary,
* and have a matching deserializer registered as 'runtime.module.loadbinary_<type_key>'.
*
* The default implementation returns false.
* \brief Returns bitmap of property.
* By default, none of the property is set. Derived class can override this function and set its
* own property.
*/
virtual bool IsDSOExportable() const;
virtual uint8_t GetProperty() const { return 0x000; }

/*! \brief Returns true if this module is 'DSO exportable'. */
bool IsDSOExportable() const { return GetProperty() == property::kDSOExportable; };

/*!
* \brief Returns true if this module has a definition for a function of \p name. If
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class TVM_DLL Executable : public ModuleNode {
*/
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kBinarySerializable; };

/*!
* \brief Write the Executable to the binary stream in serialized form.
*
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/runtime/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ class TVM_DLL VirtualMachine : public runtime::ModuleNode {
*/
virtual void LoadExecutable(const ObjectPtr<Executable>& exec);

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kRunnable; }

protected:
/*! \brief Push a call frame on to the call stack. */
void PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func);
Expand Down
2 changes: 2 additions & 0 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1357,6 +1357,8 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
}

const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; }
/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return runtime::property::kRunnable; }

private:
void init(void* mod, const Array<Target>& targets) {
Expand Down
3 changes: 3 additions & 0 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,9 @@ class RelayBuildModule : public runtime::ModuleNode {
*/
const char* type_key() const final { return "RelayBuildModule"; }

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return runtime::property::kRunnable; }

/*!
* \brief Build relay IRModule for graph executor
*
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/ethosu/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class EthosUModuleNode : public ModuleNode {
return Module(n);
}

bool IsDSOExportable() const final { return true; }
uint8_t GetProperty() const { return property::kDSOExportable; }

bool ImplementsFunction(const String& name, bool query_imports) final {
return std::find_if(compilation_artifacts_.begin(), compilation_artifacts_.end(),
Expand Down
3 changes: 3 additions & 0 deletions src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,9 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode {

const char* type_key() const final { return "RelayGraphExecutorCodegenModule"; }

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return runtime::property::kRunnable; }

private:
std::shared_ptr<GraphExecutorCodegen> codegen_;
LoweredOutput output_;
Expand Down
3 changes: 3 additions & 0 deletions src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ class VMCompiler : public runtime::ModuleNode {

const char* type_key() const final { return "VMCompiler"; }

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kRunnable; }

/*!
* \brief Set the parameters
*
Expand Down
3 changes: 3 additions & 0 deletions src/relay/printer/model_library_format_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode {

const char* type_key() const final { return "model_library_format_printer"; }

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return runtime::property::kRunnable; }

std::string Print(const ObjectRef& node) {
std::ostringstream oss;
oss << node;
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/aot_executor/aot_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class TVM_DLL AotExecutor : public ModuleNode {
*/
const char* type_key() const final { return "AotExecutor"; }

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kRunnable; }

void Run();

/*!
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/aot_executor/aot_executor_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class TVM_DLL AotExecutorFactory : public runtime::ModuleNode {
*/
const char* type_key() const final { return "AotExecutorFactory"; }

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kBinarySerializable; }

/*!
* \brief Save the module to binary stream.
* \param stream The binary stream to save to.
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/const_loader_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class ConstLoaderModuleNode : public ModuleNode {

const char* type_key() const final { return "const_loader"; }

uint8_t GetProperty() const final { return property::kBinarySerializable; };

/*!
* \brief Get the list of constants that is required by the given module.
* \param symbol The symbol that is being queried.
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/contrib/coreml/coreml_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ class CoreMLRuntime : public ModuleNode {
*/
virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kBinarySerializable | property::kRunnable; };

/*!
* \brief Serialize the content of the mlmodelc directory and save it to
* binary stream.
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/contrib/json/json_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class JSONRuntimeBase : public ModuleNode {

const char* type_key() const override { return "json"; } // May be overridden

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const { return property::kBinarySerializable | property::kRunnable; }

/*! \brief Initialize a specific json runtime. */
virtual void Init(const Array<NDArray>& consts) = 0;

Expand Down
2 changes: 2 additions & 0 deletions src/runtime/contrib/libtorch/libtorch_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class TorchModuleNode : public ModuleNode {
: symbol_name_(symbol_name), module_(module) {}

const char* type_key() const { return "torch"; }
/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kBinarySerializable | property::kRunnable; };

/*!
* \brief Get a packed function.
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/contrib/onnx/onnx_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class ONNXSourceModuleNode : public runtime::ModuleNode {
: code_(code), symbol_(symbol), const_vars_(const_vars) {}
const char* type_key() const { return "onnx"; }

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kRunnable; };

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
if (name == "get_symbol") {
return PackedFunc(
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/contrib/tensorrt/tensorrt_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ class TensorRTRuntime : public JSONRuntimeBase {
*/
const char* type_key() const final { return "tensorrt"; }

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kBinarySerializable | property::kRunnable; }

/*!
* \brief Initialize runtime. Create TensorRT layer from JSON
* representation.
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/contrib/tflite/tflite_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class TFLiteRuntime : public ModuleNode {
*/
const char* type_key() const { return "TFLiteRuntime"; }

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kRunnable; };
sunggg marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief Invoke the internal tflite interpreter and run the whole model in
* dependency order.
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/contrib/vitis_ai/vitis_ai_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ class VitisAIRuntime : public ModuleNode {
*/
const char* type_key() const { return "VitisAIRuntime"; }

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kBinarySerializable | property::kRunnable; };

/*!
* \brief Serialize the content of the pyxir directory and save it to
* binary stream.
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/cuda/cuda_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class CUDAModuleNode : public runtime::ModuleNode {

const char* type_key() const final { return "cuda"; }

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kBinarySerializable | property::kRunnable; }

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;

void SaveToFile(const std::string& file_name, const std::string& format) final {
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/graph_executor/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class TVM_DLL GraphExecutor : public ModuleNode {
const char* type_key() const final { return "GraphExecutor"; }
void Run();

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kRunnable; }

/*!
* \brief Initialize the graph executor with graph and device.
* \param graph_json The execution graph.
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/graph_executor/graph_executor_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ class TVM_DLL GraphExecutorFactory : public runtime::ModuleNode {
*/
const char* type_key() const final { return "GraphExecutorFactory"; }

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kBinarySerializable; }

/*!
* \brief Save the module to binary stream.
* \param stream The binary stream to save to.
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/hexagon/hexagon_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ class HexagonModuleNode : public runtime::ModuleNode {
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) override;
std::string GetSource(const std::string& format) override;
const char* type_key() const final { return "hexagon"; }
virtual uint8_t GetProperty() const {
return property::kBinarySerializable | property::kRunnable;
};
void SaveToFile(const std::string& file_name, const std::string& format) override;
void SaveToBinary(dmlc::Stream* stream) override;

Expand Down
1 change: 1 addition & 0 deletions src/runtime/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class LibraryModuleNode final : public ModuleNode {
: lib_(lib), packed_func_wrapper_(wrapper) {}

const char* type_key() const final { return "library"; }
uint8_t GetProperty() const final { return property::kBinarySerializable | property::kRunnable; };

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
TVMBackendPackedCFunc faddr;
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ class MetadataModuleNode : public ::tvm::runtime::ModuleNode {

const char* type_key() const final { return "metadata_module"; }

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kBinarySerializable; }

static Module LoadFromBinary() {
return Module(make_object<MetadataModuleNode>(runtime::metadata::Metadata()));
}
Expand Down
2 changes: 0 additions & 2 deletions src/runtime/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,6 @@ std::string ModuleNode::GetFormat() {
LOG(FATAL) << "Module[" << type_key() << "] does not support GetFormat";
}

bool ModuleNode::IsDSOExportable() const { return false; }

bool ModuleNode::ImplementsFunction(const String& name, bool query_imports) {
return GetFunction(name, query_imports) != nullptr;
}
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/opencl/opencl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,9 @@ class OpenCLModuleNode : public ModuleNode {

const char* type_key() const final { return workspace_->type_key.c_str(); }

/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kBinarySerializable | property::kRunnable; }

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
void SaveToFile(const std::string& file_name, const std::string& format) final;
void SaveToBinary(dmlc::Stream* stream) final;
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/rpc/rpc_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ class RPCModuleNode final : public ModuleNode {
}

const char* type_key() const final { return "rpc"; }
/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return property::kRunnable; }

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
if (name == "CloseRPCConnection") {
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/static_library.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class StaticLibraryNode final : public runtime::ModuleNode {
SaveBinaryToFile(file_name, data_);
}

bool IsDSOExportable() const final { return true; }
// TODO(tvm-team): Make this module serializable
uint8_t GetProperty() const { return property::kDSOExportable; }

bool ImplementsFunction(const String& name, bool query_imports) final {
return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end();
Expand Down
6 changes: 4 additions & 2 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,17 @@ class LLVMModuleNode final : public runtime::ModuleNode {
const char* type_key() const final { return "llvm"; }

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;

virtual uint8_t GetProperty() const {
return runtime::property::kBinarySerializable | runtime::property::kRunnable |
runtime::property::kDSOExportable;
};
void SaveToFile(const std::string& file_name, const std::string& format) final;
void SaveToBinary(dmlc::Stream* stream) final;
std::string GetSource(const std::string& format) final;

void Init(const IRModule& mod, const Target& target);
void Init(std::unique_ptr<llvm::Module> module, std::unique_ptr<LLVMInstance> llvm_instance);
void LoadIR(const std::string& file_name);
bool IsDSOExportable() const final { return true; }

bool ImplementsFunction(const String& name, bool query_imports) final;

Expand Down
2 changes: 2 additions & 0 deletions src/target/source/codegen_webgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,8 @@ class WebGPUSourceModuleNode final : public runtime::ModuleNode {
: smap_(smap), fmap_(fmap) {}

const char* type_key() const final { return "webgpu"; }
/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return runtime::property::kBinarySerializable; }

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run through tvmjs";
Expand Down
6 changes: 4 additions & 2 deletions src/target/source/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class CSourceModuleNode : public runtime::ModuleNode {
}
}

bool IsDSOExportable() const final { return true; }
uint8_t GetProperty() const { return runtime::property::kDSOExportable; }

bool ImplementsFunction(const String& name, bool query_imports) final {
return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end();
Expand Down Expand Up @@ -200,7 +200,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
}
}

bool IsDSOExportable() const final { return true; }
uint8_t GetProperty() const { return runtime::property::kDSOExportable; }

bool ImplementsFunction(const String& name, bool query_imports) final {
return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end();
Expand Down Expand Up @@ -1009,6 +1009,8 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode {
}

const char* type_key() const final { return type_key_.c_str(); }
/*! \brief Get the property of the runtime module .*/
uint8_t GetProperty() const final { return runtime::property::kBinarySerializable; }

void SaveToFile(const std::string& file_name, const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
Expand Down