From 8dc164e6bcd99dc1e0e781f83fdebbaaa5e73ec2 Mon Sep 17 00:00:00 2001 From: sung Date: Mon, 27 Mar 2023 00:33:46 -0700 Subject: [PATCH] reflect feedback --- include/tvm/runtime/module.h | 38 +++++++----- include/tvm/runtime/vm/executable.h | 2 +- include/tvm/runtime/vm/vm.h | 2 +- python/tvm/runtime/module.py | 46 +++++++++++++- src/relay/backend/aot_executor_codegen.cc | 4 +- src/relay/backend/build_module.cc | 2 +- .../backend/contrib/ethosu/source_module.cc | 2 +- src/relay/backend/graph_executor_codegen.cc | 2 +- src/relay/backend/vm/compiler.h | 2 +- .../printer/model_library_format_printer.cc | 2 +- src/runtime/aot_executor/aot_executor.h | 2 +- .../aot_executor/aot_executor_factory.h | 2 +- src/runtime/const_loader_module.cc | 2 +- src/runtime/contrib/coreml/coreml_runtime.h | 4 +- src/runtime/contrib/json/json_runtime.h | 4 +- .../contrib/libtorch/libtorch_runtime.cc | 4 +- src/runtime/contrib/onnx/onnx_module.cc | 2 +- .../contrib/tensorrt/tensorrt_runtime.cc | 2 +- src/runtime/contrib/tflite/tflite_runtime.h | 2 +- .../contrib/vitis_ai/vitis_ai_runtime.h | 2 +- src/runtime/cuda/cuda_module.cc | 2 +- src/runtime/graph_executor/graph_executor.h | 2 +- .../graph_executor/graph_executor_factory.h | 2 +- src/runtime/hexagon/hexagon_module.h | 4 +- src/runtime/library_module.cc | 4 +- src/runtime/metadata.cc | 2 +- src/runtime/module.cc | 4 +- src/runtime/opencl/opencl_common.h | 2 +- src/runtime/rpc/rpc_module.cc | 2 +- src/runtime/static_library.cc | 2 +- src/target/llvm/llvm_module.cc | 4 +- src/target/source/codegen_webgpu.cc | 2 +- src/target/source/source_module.cc | 6 +- .../unittest/test_runtime_module_property.py | 62 +++++++++++++++++++ 34 files changed, 174 insertions(+), 54 deletions(-) create mode 100644 tests/python/unittest/test_runtime_module_property.py diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index b629e6176498..508b34b3517e 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -44,23 +44,29 @@ namespace runtime { /*! * \brief Property of runtime module * We classify the property of runtime module into the following categories. - * - kBinarySerializable: we can serialize the module to the stream of bytes. CUDA/OpenCL/JSON - * runtime are representative examples. A binary exportable module can be integrated into final - * runtime artifact by being serialized as data into the artifact, then deserialzied at runtime. - * This class of modules must implement SaveToBinary, and have a matching deserializer registered as - * 'runtime.module.loadbinary_'. - * - kRunnable: we can run the module directly. LLVM/CUDA/JSON runtime, executors (e.g, - * virtual machine) runtimes are runnable. Non-runnable modules, such as CSourceModule, requires a - * few extra steps (e.g,. compilation, link) to make it runnable. - * - kDSOExportable: we can export the module as DSO. A DSO exportable module (e.g., a - * CSourceModuleNode of type_key 'c') can be incorporated into the final runtime artifact (ie shared - * library) by compilation and/or linking using the external compiler (llvm, nvcc, etc). DSO - * exportable modules must implement SaveToFile. */ - enum ModulePropertyMask : int { + /*! \brief kBinarySerializable + * we can serialize the module to the stream of bytes. CUDA/OpenCL/JSON + * runtime are representative examples. A binary exportable module can be integrated into final + * runtime artifact by being serialized as data into the artifact, then deserialized at runtime. + * This class of modules must implement SaveToBinary, and have a matching deserializer registered + * as 'runtime.module.loadbinary_'. + */ kBinarySerializable = 0b001, + /*! \brief kRunnable + * we can run the module directly. LLVM/CUDA/JSON runtime, executors (e.g, + * virtual machine) runtimes are runnable. Non-runnable modules, such as CSourceModule, requires a + * few extra steps (e.g,. compilation, link) to make it runnable. + */ kRunnable = 0b010, + /*! \brief kDSOExportable + * we can export the module as DSO. A DSO exportable module (e.g., a + * CSourceModuleNode of type_key 'c') can be incorporated into the final runtime artifact (ie + * shared library) by compilation and/or linking using the external compiler (llvm, nvcc, etc). + * DSO exportable modules must implement SaveToFile. In general, DSO exportable modules are not + * runnable unless there is a special support like JIT for `LLVMModule`. + */ kDSOExportable = 0b100 }; @@ -220,10 +226,12 @@ class TVM_DLL ModuleNode : public Object { * By default, none of the property is set. Derived class can override this function and set its * own property. */ - virtual int GetProperty() const { return 0b000; } + virtual int GetPropertyMask() const { return 0b000; } /*! \brief Returns true if this module is 'DSO exportable'. */ - bool IsDSOExportable() const { return GetProperty() & ModulePropertyMask::kDSOExportable; }; + bool IsDSOExportable() const { + return (GetPropertyMask() & ModulePropertyMask::kDSOExportable) != 0; + } /*! * \brief Returns true if this module has a definition for a function of \p name. If diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index b51a7a95f4b6..4c24d7deadaa 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -67,7 +67,7 @@ class TVM_DLL Executable : public ModuleNode { PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return ModulePropertyMask::kBinarySerializable; }; + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; }; /*! * \brief Write the Executable to the binary stream in serialized form. diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 72c43b25af59..52d80d3fea48 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -179,7 +179,7 @@ class TVM_DLL VirtualMachine : public runtime::ModuleNode { virtual void LoadExecutable(const ObjectPtr& exec); /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return ModulePropertyMask::kRunnable; } + int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; } protected: /*! \brief Push a call frame on to the call stack. */ diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 83b436939e9f..c643308ead24 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -96,6 +96,14 @@ def __str__(self): ) +class ModulePropertyMask(object): + """Runtime Module Property Mask.""" + + BINARY_SERIALIZABLE = 0b001 + RUNNABLE = 0b010 + DSO_EXPORTABLE = 0b100 + + class Module(object): """Runtime Module.""" @@ -239,6 +247,40 @@ def imported_modules(self): nmod = _ffi_api.ModuleImportsSize(self) return [_ffi_api.ModuleGetImport(self, i) for i in range(nmod)] + def get_property_mask(self): + """Get the runtime module property mask. The mapping is stated in ModulePropertyMask. + + Returns + ------- + mask : int + Bitmask of runtime module property + """ + return _ffi_api.ModuleGetPropertyMask(self) + + @property + def is_binary_serializable(self): + """Returns true if module is 'binary serializable', ie can be serialzed into binary + stream and loaded back to the runtime module. + + Returns + ------- + b : Bool + True if the module is binary serializable. + """ + return (self.get_property_mask() & ModulePropertyMask.BINARY_SERIALIZABLE) != 0 + + @property + def is_runnable(self): + """Returns true if module is 'runnable'. ie can be executed without any extra + compilation/linking steps. + + Returns + ------- + b : Bool + True if the module is runnable. + """ + return (self.get_property_mask() & ModulePropertyMask.RUNNABLE) != 0 + @property def is_dso_exportable(self): """Returns true if module is 'DSO exportable', ie can be included in result of @@ -249,7 +291,7 @@ def is_dso_exportable(self): b : Bool True if the module is DSO exportable. """ - return _ffi_api.ModuleIsDSOExportable(self) + return (self.get_property_mask() & ModulePropertyMask.DSO_EXPORTABLE) != 0 def save(self, file_name, fmt=""): """Save the module to file. @@ -383,6 +425,8 @@ def _collect_from_import_tree(self, filter_func): stack.append(self) while stack: module = stack.pop() + assert module.is_dso_exportable or module.is_binary_serializable + if filter_func(module): dso_modules.append(module) for m in module.imported_modules: diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index c565ed2248d7..8f7098c24aea 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1357,9 +1357,9 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { } const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; } - + /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return runtime::ModulePropertyMask::kRunnable; } + int GetPropertyMask() const final { return runtime::ModulePropertyMask::kRunnable; } private: void init(void* mod, const Array& targets) { diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 7420b133fdea..856a7700784a 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -284,7 +284,7 @@ class RelayBuildModule : public runtime::ModuleNode { const char* type_key() const final { return "RelayBuildModule"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return runtime::ModulePropertyMask::kRunnable; } + int GetPropertyMask() const final { return runtime::ModulePropertyMask::kRunnable; } /*! * \brief Build relay IRModule for graph executor diff --git a/src/relay/backend/contrib/ethosu/source_module.cc b/src/relay/backend/contrib/ethosu/source_module.cc index e82b67417226..a2662b9018bf 100644 --- a/src/relay/backend/contrib/ethosu/source_module.cc +++ b/src/relay/backend/contrib/ethosu/source_module.cc @@ -122,7 +122,7 @@ class EthosUModuleNode : public ModuleNode { } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const { return ModulePropertyMask::kDSOExportable; } + int GetPropertyMask() const { return ModulePropertyMask::kDSOExportable; } bool ImplementsFunction(const String& name, bool query_imports) final { return std::find_if(compilation_artifacts_.begin(), compilation_artifacts_.end(), diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 1da8e142ee98..d8ce0e59b167 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -688,7 +688,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { const char* type_key() const final { return "RelayGraphExecutorCodegenModule"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return runtime::ModulePropertyMask::kRunnable; } + int GetPropertyMask() const final { return runtime::ModulePropertyMask::kRunnable; } private: std::shared_ptr codegen_; diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 81666882772e..8a5faa40b55a 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -94,7 +94,7 @@ class VMCompiler : public runtime::ModuleNode { const char* type_key() const final { return "VMCompiler"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return ModulePropertyMask::kRunnable; } + int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; } /*! * \brief Set the parameters diff --git a/src/relay/printer/model_library_format_printer.cc b/src/relay/printer/model_library_format_printer.cc index 0620590988f6..994b3ae09c6e 100644 --- a/src/relay/printer/model_library_format_printer.cc +++ b/src/relay/printer/model_library_format_printer.cc @@ -38,7 +38,7 @@ class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode { const char* type_key() const final { return "model_library_format_printer"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return runtime::ModulePropertyMask::kRunnable; } + int GetPropertyMask() const final { return runtime::ModulePropertyMask::kRunnable; } std::string Print(const ObjectRef& node) { std::ostringstream oss; diff --git a/src/runtime/aot_executor/aot_executor.h b/src/runtime/aot_executor/aot_executor.h index a72b219d77e4..ab30ab80269e 100644 --- a/src/runtime/aot_executor/aot_executor.h +++ b/src/runtime/aot_executor/aot_executor.h @@ -52,7 +52,7 @@ class TVM_DLL AotExecutor : public ModuleNode { const char* type_key() const final { return "AotExecutor"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return ModulePropertyMask::kRunnable; } + int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; } void Run(); diff --git a/src/runtime/aot_executor/aot_executor_factory.h b/src/runtime/aot_executor/aot_executor_factory.h index 15bfb2c7b7f4..4c6e36fc1186 100644 --- a/src/runtime/aot_executor/aot_executor_factory.h +++ b/src/runtime/aot_executor/aot_executor_factory.h @@ -66,7 +66,7 @@ class TVM_DLL AotExecutorFactory : public runtime::ModuleNode { const char* type_key() const final { return "AotExecutorFactory"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return ModulePropertyMask::kBinarySerializable; } + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; } /*! * \brief Save the module to binary stream. diff --git a/src/runtime/const_loader_module.cc b/src/runtime/const_loader_module.cc index a54d217dd58e..75e094a63a6e 100644 --- a/src/runtime/const_loader_module.cc +++ b/src/runtime/const_loader_module.cc @@ -101,7 +101,7 @@ class ConstLoaderModuleNode : public ModuleNode { const char* type_key() const final { return "const_loader"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return ModulePropertyMask::kBinarySerializable; }; + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; }; /*! * \brief Get the list of constants that is required by the given module. diff --git a/src/runtime/contrib/coreml/coreml_runtime.h b/src/runtime/contrib/coreml/coreml_runtime.h index 4be33cc87d9f..80706425ba09 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.h +++ b/src/runtime/contrib/coreml/coreml_runtime.h @@ -106,7 +106,9 @@ class CoreMLRuntime : public ModuleNode { virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; }; + int GetPropertyMask() const final { + return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + } /*! * \brief Serialize the content of the mlmodelc directory and save it to diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index 8df010f64280..51ce2cffd780 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -59,7 +59,9 @@ class JSONRuntimeBase : public ModuleNode { const char* type_key() const override { return "json"; } // May be overridden /*! \brief Get the property of the runtime module .*/ - int GetProperty() const { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; } + int GetPropertyMask() const { + return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + } /*! \brief Initialize a specific json runtime. */ virtual void Init(const Array& consts) = 0; diff --git a/src/runtime/contrib/libtorch/libtorch_runtime.cc b/src/runtime/contrib/libtorch/libtorch_runtime.cc index ff6a1a778dd1..48ccfc749674 100644 --- a/src/runtime/contrib/libtorch/libtorch_runtime.cc +++ b/src/runtime/contrib/libtorch/libtorch_runtime.cc @@ -89,7 +89,9 @@ class TorchModuleNode : public ModuleNode { const char* type_key() const { return "torch"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; }; + int GetPropertyMask() const final { + return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + } /*! * \brief Get a packed function. diff --git a/src/runtime/contrib/onnx/onnx_module.cc b/src/runtime/contrib/onnx/onnx_module.cc index 231f97cf0f8b..384a368e287e 100644 --- a/src/runtime/contrib/onnx/onnx_module.cc +++ b/src/runtime/contrib/onnx/onnx_module.cc @@ -36,7 +36,7 @@ class ONNXSourceModuleNode : public runtime::ModuleNode { const char* type_key() const { return "onnx"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return ModulePropertyMask::kRunnable; }; + int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; }; PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == "get_symbol") { diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index dcd9aeb93a9b..e1f205e22f10 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -98,7 +98,7 @@ class TensorRTRuntime : public JSONRuntimeBase { const char* type_key() const final { return "tensorrt"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; } diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index c6d9ef572607..2a524479593a 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -61,7 +61,7 @@ class TFLiteRuntime : public ModuleNode { const char* type_key() const { return "TFLiteRuntime"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return ModulePropertyMask::kRunnable; }; + int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; }; /*! * \brief Invoke the internal tflite interpreter and run the whole model in diff --git a/src/runtime/contrib/vitis_ai/vitis_ai_runtime.h b/src/runtime/contrib/vitis_ai/vitis_ai_runtime.h index 7366b0bf6e96..ccaa88c1ac42 100755 --- a/src/runtime/contrib/vitis_ai/vitis_ai_runtime.h +++ b/src/runtime/contrib/vitis_ai/vitis_ai_runtime.h @@ -87,7 +87,7 @@ class VitisAIRuntime : public ModuleNode { const char* type_key() const { return "VitisAIRuntime"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; }; diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 8a4efd56aad9..240e1fe1aa7a 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -66,7 +66,7 @@ class CUDAModuleNode : public runtime::ModuleNode { const char* type_key() const final { return "cuda"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; } diff --git a/src/runtime/graph_executor/graph_executor.h b/src/runtime/graph_executor/graph_executor.h index b41211bf5255..0a7086c9f125 100644 --- a/src/runtime/graph_executor/graph_executor.h +++ b/src/runtime/graph_executor/graph_executor.h @@ -90,7 +90,7 @@ class TVM_DLL GraphExecutor : public ModuleNode { void Run(); /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return ModulePropertyMask::kRunnable; } + int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; } /*! * \brief Initialize the graph executor with graph and device. diff --git a/src/runtime/graph_executor/graph_executor_factory.h b/src/runtime/graph_executor/graph_executor_factory.h index 937e13524c82..2766dfafc29d 100644 --- a/src/runtime/graph_executor/graph_executor_factory.h +++ b/src/runtime/graph_executor/graph_executor_factory.h @@ -68,7 +68,7 @@ class TVM_DLL GraphExecutorFactory : public runtime::ModuleNode { const char* type_key() const final { return "GraphExecutorFactory"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return ModulePropertyMask::kBinarySerializable; } + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; } /*! * \brief Save the module to binary stream. diff --git a/src/runtime/hexagon/hexagon_module.h b/src/runtime/hexagon/hexagon_module.h index 2c8228cf1a53..f5d1d0200c03 100644 --- a/src/runtime/hexagon/hexagon_module.h +++ b/src/runtime/hexagon/hexagon_module.h @@ -63,9 +63,9 @@ class HexagonModuleNode : public runtime::ModuleNode { std::string GetSource(const std::string& format) override; const char* type_key() const final { return "hexagon"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const { + int GetPropertyMask() const { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; - }; + } void SaveToFile(const std::string& file_name, const std::string& format) override; void SaveToBinary(dmlc::Stream* stream) override; diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 47ee2a94d7d7..eed41dfc2b99 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -41,9 +41,9 @@ class LibraryModuleNode final : public ModuleNode { : lib_(lib), packed_func_wrapper_(wrapper) {} const char* type_key() const final { return "library"; } - + /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; }; diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc index d6bc80820267..946ebf1232d2 100644 --- a/src/runtime/metadata.cc +++ b/src/runtime/metadata.cc @@ -89,7 +89,7 @@ class MetadataModuleNode : public ::tvm::runtime::ModuleNode { const char* type_key() const final { return "metadata_module"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return ModulePropertyMask::kBinarySerializable; } + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; } static Module LoadFromBinary() { return Module(make_object(runtime::metadata::Metadata())); diff --git a/src/runtime/module.cc b/src/runtime/module.cc index fd19ef5e26e1..298fd588d5e1 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -197,8 +197,8 @@ TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFro TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile") .set_body_typed([](Module mod, String name, tvm::String fmt) { mod->SaveToFile(name, fmt); }); -TVM_REGISTER_GLOBAL("runtime.ModuleIsDSOExportable").set_body_typed([](Module mod) { - return mod->IsDSOExportable(); +TVM_REGISTER_GLOBAL("runtime.ModuleGetPropertyMask").set_body_typed([](Module mod) { + return mod->GetPropertyMask(); }); TVM_REGISTER_GLOBAL("runtime.ModuleImplementsFunction") diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 9f835fafcf78..a8a4cf3dc65c 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -432,7 +432,7 @@ class OpenCLModuleNode : public ModuleNode { const char* type_key() const final { return workspace_->type_key.c_str(); } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { + int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; } diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index f7ba301f822d..ed769d97ab36 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -176,7 +176,7 @@ class RPCModuleNode final : public ModuleNode { const char* type_key() const final { return "rpc"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return ModulePropertyMask::kRunnable; } + int GetPropertyMask() const final { return ModulePropertyMask::kRunnable; } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == "CloseRPCConnection") { diff --git a/src/runtime/static_library.cc b/src/runtime/static_library.cc index d0553534fdcc..46038c2defab 100644 --- a/src/runtime/static_library.cc +++ b/src/runtime/static_library.cc @@ -64,7 +64,7 @@ class StaticLibraryNode final : public runtime::ModuleNode { // TODO(tvm-team): Make this module serializable /*! \brief Get the property of the runtime module .*/ - int GetProperty() const { return ModulePropertyMask::kDSOExportable; } + int GetPropertyMask() const { return ModulePropertyMask::kDSOExportable; } bool ImplementsFunction(const String& name, bool query_imports) final { return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 367989dc53d8..cd52694e2c50 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -95,9 +95,9 @@ class LLVMModuleNode final : public runtime::ModuleNode { /*! \brief Get the property of the runtime module .*/ // TODO(tvm-team): Make it serializable - int GetProperty() const { + int GetPropertyMask() const { return runtime::ModulePropertyMask::kRunnable | runtime::ModulePropertyMask::kDSOExportable; - }; + } void SaveToFile(const std::string& file_name, const std::string& format) final; void SaveToBinary(dmlc::Stream* stream) final; diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 902077f69188..fd770007e243 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -490,7 +490,7 @@ class WebGPUSourceModuleNode final : public runtime::ModuleNode { const char* type_key() const final { return "webgpu"; } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return runtime::ModulePropertyMask::kBinarySerializable; } + int GetPropertyMask() const final { return runtime::ModulePropertyMask::kBinarySerializable; } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run through tvmjs"; diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 86eb228be46c..9ec9dbbbfedb 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -130,7 +130,7 @@ class CSourceModuleNode : public runtime::ModuleNode { } } - int GetProperty() const { return runtime::ModulePropertyMask::kDSOExportable; } + int GetPropertyMask() const { return runtime::ModulePropertyMask::kDSOExportable; } bool ImplementsFunction(const String& name, bool query_imports) final { return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); @@ -200,7 +200,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { } } - int GetProperty() const { return runtime::ModulePropertyMask::kDSOExportable; } + int GetPropertyMask() const { return runtime::ModulePropertyMask::kDSOExportable; } bool ImplementsFunction(const String& name, bool query_imports) final { return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); @@ -1010,7 +1010,7 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { const char* type_key() const final { return type_key_.c_str(); } /*! \brief Get the property of the runtime module .*/ - int GetProperty() const final { return runtime::ModulePropertyMask::kBinarySerializable; } + int GetPropertyMask() const final { return runtime::ModulePropertyMask::kBinarySerializable; } void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); diff --git a/tests/python/unittest/test_runtime_module_property.py b/tests/python/unittest/test_runtime_module_property.py new file mode 100644 index 000000000000..30af8d086a42 --- /dev/null +++ b/tests/python/unittest/test_runtime_module_property.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te +import tvm.runtime._ffi_api +import tvm.target._ffi_api + + +def checker(mod, expected): + assert mod.is_binary_serializable == expected["is_binary_serializable"] + assert mod.is_runnable == expected["is_runnable"] + assert mod.is_dso_exportable == expected["is_dso_exportable"] + + +def create_csource_module(): + return tvm.runtime._ffi_api.CSourceModuleCreate("", "cc", [], None) + + +def create_llvm_module(): + A = te.placeholder((1024,), name="A") + B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B") + s = te.create_schedule(B.op) + return tvm.build(s, [A, B], "llvm", name="myadd0") + + +def create_aot_module(): + return tvm.get_global_func("relay.build_module._AOTExecutorCodegen")() + + +def test_property(): + checker( + create_csource_module(), + expected={"is_binary_serializable": False, "is_runnable": False, "is_dso_exportable": True}, + ) + + checker( + create_llvm_module(), + expected={"is_binary_serializable": False, "is_runnable": True, "is_dso_exportable": True}, + ) + + checker( + create_aot_module(), + expected={"is_binary_serializable": False, "is_runnable": True, "is_dso_exportable": False}, + ) + + +if __name__ == "__main__": + tvm.testing.main()