diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index a54f98a558f3..508b34b3517e 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -41,6 +41,35 @@ namespace tvm { namespace runtime { +/*! + * \brief Property of runtime module + * We classify the property of runtime module into the following categories. + */ +enum ModulePropertyMask : int { + /*! \brief kBinarySerializable + * we can serialize the module to the stream of bytes. CUDA/OpenCL/JSON + * runtime are representative examples. A binary exportable module can be integrated into final + * runtime artifact by being serialized as data into the artifact, then deserialized at runtime. + * This class of modules must implement SaveToBinary, and have a matching deserializer registered + * as 'runtime.module.loadbinary_'. + */ + kBinarySerializable = 0b001, + /*! \brief kRunnable + * we can run the module directly. LLVM/CUDA/JSON runtime, executors (e.g, + * virtual machine) runtimes are runnable. Non-runnable modules, such as CSourceModule, requires a + * few extra steps (e.g,. compilation, link) to make it runnable. + */ + kRunnable = 0b010, + /*! \brief kDSOExportable + * we can export the module as DSO. A DSO exportable module (e.g., a + * CSourceModuleNode of type_key 'c') can be incorporated into the final runtime artifact (ie + * shared library) by compilation and/or linking using the external compiler (llvm, nvcc, etc). + * DSO exportable modules must implement SaveToFile. In general, DSO exportable modules are not + * runnable unless there is a special support like JIT for `LLVMModule`. + */ + kDSOExportable = 0b100 +}; + class ModuleNode; class PackedFunc; @@ -193,20 +222,16 @@ class TVM_DLL ModuleNode : public Object { const std::vector& imports() const { return imports_; } /*! - * \brief Returns true if this module is 'DSO exportable'. - * - * A DSO exportable module (eg a CSourceModuleNode of type_key 'c') can be incorporated into the - * final runtime artifact (ie shared library) by compilation and/or linking using the external - * compiler (llvm, nvcc, etc). DSO exportable modules must implement SaveToFile. - * - * By contrast, non-DSO exportable modules (eg CUDAModuleNode of type_key 'cuda') typically must - * be incorporated into the final runtime artifact by being serialized as data into the - * artifact, then deserialized at runtime. Non-DSO exportable modules must implement SaveToBinary, - * and have a matching deserializer registered as 'runtime.module.loadbinary_'. - * - * The default implementation returns false. + * \brief Returns bitmap of property. + * By default, none of the property is set. Derived class can override this function and set its + * own property. */ - virtual bool IsDSOExportable() const; + virtual int GetPropertyMask() const { return 0b000; } + + /*! \brief Returns true if this module is 'DSO exportable'. */ + bool IsDSOExportable() const { + return (GetPropertyMask() & ModulePropertyMask::kDSOExportable) != 0; + } /*! * \brief Returns true if this module has a definition for a function of \p name. If diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index fdbc1769c353..4c24d7deadaa 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -66,6 +66,9 @@ class TVM_DLL Executable : public ModuleNode { */ PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; }; + /*! * \brief Write the Executable to the binary stream in serialized form. * diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 6fa91832a731..52d80d3fea48 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -178,6 +178,9 @@ class TVM_DLL VirtualMachine : public runtime::ModuleNode { */ virtual void LoadExecutable(const ObjectPtr& exec); + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; } + protected: /*! \brief Push a call frame on to the call stack. */ void PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func); diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 83b436939e9f..c78a6d9c3136 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -96,6 +96,14 @@ def __str__(self): ) +class ModulePropertyMask(object): + """Runtime Module Property Mask.""" + + BINARY_SERIALIZABLE = 0b001 + RUNNABLE = 0b010 + DSO_EXPORTABLE = 0b100 + + class Module(object): """Runtime Module.""" @@ -239,6 +247,40 @@ def imported_modules(self): nmod = _ffi_api.ModuleImportsSize(self) return [_ffi_api.ModuleGetImport(self, i) for i in range(nmod)] + def get_property_mask(self): + """Get the runtime module property mask. The mapping is stated in ModulePropertyMask. + + Returns + ------- + mask : int + Bitmask of runtime module property + """ + return _ffi_api.ModuleGetPropertyMask(self) + + @property + def is_binary_serializable(self): + """Returns true if module is 'binary serializable', ie can be serialzed into binary + stream and loaded back to the runtime module. + + Returns + ------- + b : Bool + True if the module is binary serializable. + """ + return (self.get_property_mask() & ModulePropertyMask.BINARY_SERIALIZABLE) != 0 + + @property + def is_runnable(self): + """Returns true if module is 'runnable'. ie can be executed without any extra + compilation/linking steps. + + Returns + ------- + b : Bool + True if the module is runnable. + """ + return (self.get_property_mask() & ModulePropertyMask.RUNNABLE) != 0 + @property def is_dso_exportable(self): """Returns true if module is 'DSO exportable', ie can be included in result of @@ -249,7 +291,7 @@ def is_dso_exportable(self): b : Bool True if the module is DSO exportable. """ - return _ffi_api.ModuleIsDSOExportable(self) + return (self.get_property_mask() & ModulePropertyMask.DSO_EXPORTABLE) != 0 def save(self, file_name, fmt=""): """Save the module to file. @@ -383,6 +425,10 @@ def _collect_from_import_tree(self, filter_func): stack.append(self) while stack: module = stack.pop() + assert ( + module.is_dso_exportable or module.is_binary_serializable + ), f"Module {module.type_key} should be either dso exportable or binary serializable." + if filter_func(module): dso_modules.append(module) for m in module.imported_modules: diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 6bbb43f50f21..8f7098c24aea 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1358,6 +1358,9 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return runtime::ModulePropertyMask::kRunnable; } + private: void init(void* mod, const Array& targets) { codegen_ = diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 0642c0c67253..856a7700784a 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -283,6 +283,9 @@ class RelayBuildModule : public runtime::ModuleNode { */ const char* type_key() const final { return "RelayBuildModule"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return runtime::ModulePropertyMask::kRunnable; } + /*! * \brief Build relay IRModule for graph executor * diff --git a/src/relay/backend/contrib/ethosu/source_module.cc b/src/relay/backend/contrib/ethosu/source_module.cc index f66ebd5ed2b2..a2662b9018bf 100644 --- a/src/relay/backend/contrib/ethosu/source_module.cc +++ b/src/relay/backend/contrib/ethosu/source_module.cc @@ -121,7 +121,8 @@ class EthosUModuleNode : public ModuleNode { return Module(n); } - bool IsDSOExportable() const final { return true; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const { return ModulePropertyMask::kDSOExportable; } bool ImplementsFunction(const String& name, bool query_imports) final { return std::find_if(compilation_artifacts_.begin(), compilation_artifacts_.end(), diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 78d4dde19a29..d8ce0e59b167 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -687,6 +687,9 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { const char* type_key() const final { return "RelayGraphExecutorCodegenModule"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return runtime::ModulePropertyMask::kRunnable; } + private: std::shared_ptr codegen_; LoweredOutput output_; diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 9160ce0e2e42..8a5faa40b55a 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -93,6 +93,9 @@ class VMCompiler : public runtime::ModuleNode { const char* type_key() const final { return "VMCompiler"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; } + /*! * \brief Set the parameters * diff --git a/src/relay/printer/model_library_format_printer.cc b/src/relay/printer/model_library_format_printer.cc index 76d0f1423d4f..994b3ae09c6e 100644 --- a/src/relay/printer/model_library_format_printer.cc +++ b/src/relay/printer/model_library_format_printer.cc @@ -37,6 +37,9 @@ class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode { const char* type_key() const final { return "model_library_format_printer"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return runtime::ModulePropertyMask::kRunnable; } + std::string Print(const ObjectRef& node) { std::ostringstream oss; oss << node; diff --git a/src/runtime/aot_executor/aot_executor.h b/src/runtime/aot_executor/aot_executor.h index cc86381624ce..ab30ab80269e 100644 --- a/src/runtime/aot_executor/aot_executor.h +++ b/src/runtime/aot_executor/aot_executor.h @@ -51,6 +51,9 @@ class TVM_DLL AotExecutor : public ModuleNode { */ const char* type_key() const final { return "AotExecutor"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; } + void Run(); /*! diff --git a/src/runtime/aot_executor/aot_executor_factory.h b/src/runtime/aot_executor/aot_executor_factory.h index ada63f0ba8ee..4c6e36fc1186 100644 --- a/src/runtime/aot_executor/aot_executor_factory.h +++ b/src/runtime/aot_executor/aot_executor_factory.h @@ -65,6 +65,9 @@ class TVM_DLL AotExecutorFactory : public runtime::ModuleNode { */ const char* type_key() const final { return "AotExecutorFactory"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; } + /*! * \brief Save the module to binary stream. * \param stream The binary stream to save to. diff --git a/src/runtime/const_loader_module.cc b/src/runtime/const_loader_module.cc index f57c7d11d51e..75e094a63a6e 100644 --- a/src/runtime/const_loader_module.cc +++ b/src/runtime/const_loader_module.cc @@ -100,6 +100,9 @@ class ConstLoaderModuleNode : public ModuleNode { const char* type_key() const final { return "const_loader"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; }; + /*! * \brief Get the list of constants that is required by the given module. * \param symbol The symbol that is being queried. diff --git a/src/runtime/contrib/coreml/coreml_runtime.h b/src/runtime/contrib/coreml/coreml_runtime.h index d1deb852c02b..80706425ba09 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.h +++ b/src/runtime/contrib/coreml/coreml_runtime.h @@ -105,6 +105,11 @@ class CoreMLRuntime : public ModuleNode { */ virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { + return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + } + /*! * \brief Serialize the content of the mlmodelc directory and save it to * binary stream. diff --git a/src/runtime/contrib/ethosn/ethosn_runtime.h b/src/runtime/contrib/ethosn/ethosn_runtime.h index 57dc464ab2af..b887b7348079 100644 --- a/src/runtime/contrib/ethosn/ethosn_runtime.h +++ b/src/runtime/contrib/ethosn/ethosn_runtime.h @@ -104,6 +104,11 @@ class EthosnModule : public ModuleNode { const char* type_key() const override { return "ethos-n"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { + return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + }; + private: /*! \brief A map between ext_symbols (function names) and ordered compiled networks. */ std::map network_map_; diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index c84e659c6bb7..51ce2cffd780 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -58,6 +58,11 @@ class JSONRuntimeBase : public ModuleNode { const char* type_key() const override { return "json"; } // May be overridden + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const { + return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + } + /*! \brief Initialize a specific json runtime. */ virtual void Init(const Array& consts) = 0; diff --git a/src/runtime/contrib/libtorch/libtorch_runtime.cc b/src/runtime/contrib/libtorch/libtorch_runtime.cc index e76d04389ec7..48ccfc749674 100644 --- a/src/runtime/contrib/libtorch/libtorch_runtime.cc +++ b/src/runtime/contrib/libtorch/libtorch_runtime.cc @@ -88,6 +88,10 @@ class TorchModuleNode : public ModuleNode { : symbol_name_(symbol_name), module_(module) {} const char* type_key() const { return "torch"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { + return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + } /*! * \brief Get a packed function. diff --git a/src/runtime/contrib/onnx/onnx_module.cc b/src/runtime/contrib/onnx/onnx_module.cc index 8732b700a218..384a368e287e 100644 --- a/src/runtime/contrib/onnx/onnx_module.cc +++ b/src/runtime/contrib/onnx/onnx_module.cc @@ -35,6 +35,9 @@ class ONNXSourceModuleNode : public runtime::ModuleNode { : code_(code), symbol_(symbol), const_vars_(const_vars) {} const char* type_key() const { return "onnx"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; }; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == "get_symbol") { return PackedFunc( diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index b51684b95eb8..e1f205e22f10 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -97,6 +97,11 @@ class TensorRTRuntime : public JSONRuntimeBase { */ const char* type_key() const final { return "tensorrt"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { + return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + } + /*! * \brief Initialize runtime. Create TensorRT layer from JSON * representation. diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index 759be24b94ec..2a524479593a 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -60,6 +60,9 @@ class TFLiteRuntime : public ModuleNode { */ const char* type_key() const { return "TFLiteRuntime"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; }; + /*! * \brief Invoke the internal tflite interpreter and run the whole model in * dependency order. diff --git a/src/runtime/contrib/vitis_ai/vitis_ai_runtime.h b/src/runtime/contrib/vitis_ai/vitis_ai_runtime.h index cad3b5e5a7ff..ccaa88c1ac42 100755 --- a/src/runtime/contrib/vitis_ai/vitis_ai_runtime.h +++ b/src/runtime/contrib/vitis_ai/vitis_ai_runtime.h @@ -86,6 +86,11 @@ class VitisAIRuntime : public ModuleNode { */ const char* type_key() const { return "VitisAIRuntime"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { + return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + }; + /*! * \brief Serialize the content of the pyxir directory and save it to * binary stream. diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 739875fe850f..240e1fe1aa7a 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -65,6 +65,11 @@ class CUDAModuleNode : public runtime::ModuleNode { const char* type_key() const final { return "cuda"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { + return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + } + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; void SaveToFile(const std::string& file_name, const std::string& format) final { diff --git a/src/runtime/graph_executor/graph_executor.h b/src/runtime/graph_executor/graph_executor.h index 9fce154870cd..0a7086c9f125 100644 --- a/src/runtime/graph_executor/graph_executor.h +++ b/src/runtime/graph_executor/graph_executor.h @@ -89,6 +89,9 @@ class TVM_DLL GraphExecutor : public ModuleNode { const char* type_key() const final { return "GraphExecutor"; } void Run(); + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; } + /*! * \brief Initialize the graph executor with graph and device. * \param graph_json The execution graph. diff --git a/src/runtime/graph_executor/graph_executor_factory.h b/src/runtime/graph_executor/graph_executor_factory.h index d8ebe44bb972..2766dfafc29d 100644 --- a/src/runtime/graph_executor/graph_executor_factory.h +++ b/src/runtime/graph_executor/graph_executor_factory.h @@ -67,6 +67,9 @@ class TVM_DLL GraphExecutorFactory : public runtime::ModuleNode { */ const char* type_key() const final { return "GraphExecutorFactory"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; } + /*! * \brief Save the module to binary stream. * \param stream The binary stream to save to. diff --git a/src/runtime/hexagon/hexagon_module.h b/src/runtime/hexagon/hexagon_module.h index aac75002c258..f5d1d0200c03 100644 --- a/src/runtime/hexagon/hexagon_module.h +++ b/src/runtime/hexagon/hexagon_module.h @@ -62,6 +62,10 @@ class HexagonModuleNode : public runtime::ModuleNode { PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override; std::string GetSource(const std::string& format) override; const char* type_key() const final { return "hexagon"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const { + return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + } void SaveToFile(const std::string& file_name, const std::string& format) override; void SaveToBinary(dmlc::Stream* stream) override; diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index d6c2f791deb9..eed41dfc2b99 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -42,6 +42,11 @@ class LibraryModuleNode final : public ModuleNode { const char* type_key() const final { return "library"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { + return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + }; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { TVMBackendPackedCFunc faddr; if (name == runtime::symbol::tvm_module_main) { diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc index 2120ffe40d67..946ebf1232d2 100644 --- a/src/runtime/metadata.cc +++ b/src/runtime/metadata.cc @@ -88,6 +88,9 @@ class MetadataModuleNode : public ::tvm::runtime::ModuleNode { const char* type_key() const final { return "metadata_module"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; } + static Module LoadFromBinary() { return Module(make_object(runtime::metadata::Metadata())); } diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 9ef57e905324..298fd588d5e1 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -132,8 +132,6 @@ std::string ModuleNode::GetFormat() { LOG(FATAL) << "Module[" << type_key() << "] does not support GetFormat"; } -bool ModuleNode::IsDSOExportable() const { return false; } - bool ModuleNode::ImplementsFunction(const String& name, bool query_imports) { return GetFunction(name, query_imports) != nullptr; } @@ -199,8 +197,8 @@ TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFro TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile") .set_body_typed([](Module mod, String name, tvm::String fmt) { mod->SaveToFile(name, fmt); }); -TVM_REGISTER_GLOBAL("runtime.ModuleIsDSOExportable").set_body_typed([](Module mod) { - return mod->IsDSOExportable(); +TVM_REGISTER_GLOBAL("runtime.ModuleGetPropertyMask").set_body_typed([](Module mod) { + return mod->GetPropertyMask(); }); TVM_REGISTER_GLOBAL("runtime.ModuleImplementsFunction") diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index cef49fdd50ca..a8a4cf3dc65c 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -431,6 +431,11 @@ class OpenCLModuleNode : public ModuleNode { const char* type_key() const final { return workspace_->type_key.c_str(); } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { + return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + } + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; void SaveToFile(const std::string& file_name, const std::string& format) final; void SaveToBinary(dmlc::Stream* stream) final; diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 968bd773e453..ed769d97ab36 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -175,6 +175,8 @@ class RPCModuleNode final : public ModuleNode { } const char* type_key() const final { return "rpc"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == "CloseRPCConnection") { diff --git a/src/runtime/static_library.cc b/src/runtime/static_library.cc index e845d0fac225..46038c2defab 100644 --- a/src/runtime/static_library.cc +++ b/src/runtime/static_library.cc @@ -62,7 +62,9 @@ class StaticLibraryNode final : public runtime::ModuleNode { SaveBinaryToFile(file_name, data_); } - bool IsDSOExportable() const final { return true; } + // TODO(tvm-team): Make this module serializable + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const { return ModulePropertyMask::kDSOExportable; } bool ImplementsFunction(const String& name, bool query_imports) final { return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 50dcd7402a47..b616737b0436 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -93,6 +93,12 @@ class LLVMModuleNode final : public runtime::ModuleNode { PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + /*! \brief Get the property of the runtime module .*/ + // TODO(tvm-team): Make it serializable + int GetPropertyMask() const { + return runtime::ModulePropertyMask::kRunnable | runtime::ModulePropertyMask::kDSOExportable; + } + void SaveToFile(const std::string& file_name, const std::string& format) final; void SaveToBinary(dmlc::Stream* stream) final; std::string GetSource(const std::string& format) final; @@ -100,7 +106,6 @@ class LLVMModuleNode final : public runtime::ModuleNode { void Init(const IRModule& mod, const Target& target); void Init(std::unique_ptr module, std::unique_ptr llvm_instance); void LoadIR(const std::string& file_name); - bool IsDSOExportable() const final { return true; } bool ImplementsFunction(const String& name, bool query_imports) final; diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index e4ccef88b62f..fd770007e243 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -489,6 +489,8 @@ class WebGPUSourceModuleNode final : public runtime::ModuleNode { : smap_(smap), fmap_(fmap) {} const char* type_key() const final { return "webgpu"; } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return runtime::ModulePropertyMask::kBinarySerializable; } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run through tvmjs"; diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index ee5a7cd33de9..9ec9dbbbfedb 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -130,7 +130,7 @@ class CSourceModuleNode : public runtime::ModuleNode { } } - bool IsDSOExportable() const final { return true; } + int GetPropertyMask() const { return runtime::ModulePropertyMask::kDSOExportable; } bool ImplementsFunction(const String& name, bool query_imports) final { return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); @@ -200,7 +200,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { } } - bool IsDSOExportable() const final { return true; } + int GetPropertyMask() const { return runtime::ModulePropertyMask::kDSOExportable; } bool ImplementsFunction(const String& name, bool query_imports) final { return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); @@ -1009,6 +1009,8 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { } const char* type_key() const final { return type_key_.c_str(); } + /*! \brief Get the property of the runtime module .*/ + int GetPropertyMask() const final { return runtime::ModulePropertyMask::kBinarySerializable; } void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index d4886456d98b..e4f8a4fcd73e 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -414,7 +414,7 @@ def test_export_non_dso_exportable(): temp_dir = utils.tempdir() - with pytest.raises(micro.UnsupportedInModelLibraryFormatError) as exc: + with pytest.raises(AssertionError) as exc: model_library_format._populate_codegen_dir([module], temp_dir.relpath("codegen")) assert str(exc.exception) == ( diff --git a/tests/python/unittest/test_runtime_module_property.py b/tests/python/unittest/test_runtime_module_property.py new file mode 100644 index 000000000000..30af8d086a42 --- /dev/null +++ b/tests/python/unittest/test_runtime_module_property.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te +import tvm.runtime._ffi_api +import tvm.target._ffi_api + + +def checker(mod, expected): + assert mod.is_binary_serializable == expected["is_binary_serializable"] + assert mod.is_runnable == expected["is_runnable"] + assert mod.is_dso_exportable == expected["is_dso_exportable"] + + +def create_csource_module(): + return tvm.runtime._ffi_api.CSourceModuleCreate("", "cc", [], None) + + +def create_llvm_module(): + A = te.placeholder((1024,), name="A") + B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B") + s = te.create_schedule(B.op) + return tvm.build(s, [A, B], "llvm", name="myadd0") + + +def create_aot_module(): + return tvm.get_global_func("relay.build_module._AOTExecutorCodegen")() + + +def test_property(): + checker( + create_csource_module(), + expected={"is_binary_serializable": False, "is_runnable": False, "is_dso_exportable": True}, + ) + + checker( + create_llvm_module(), + expected={"is_binary_serializable": False, "is_runnable": True, "is_dso_exportable": True}, + ) + + checker( + create_aot_module(), + expected={"is_binary_serializable": False, "is_runnable": True, "is_dso_exportable": False}, + ) + + +if __name__ == "__main__": + tvm.testing.main()