From a6f37a2bd543921ba3030a605e22f42980a33087 Mon Sep 17 00:00:00 2001
From: Zhi <5145158+zhiics@users.noreply.github.com>
Date: Thu, 17 Oct 2019 13:25:08 -0700
Subject: [PATCH 01/59] [relay][vm] Separate VM runtime with executable (#4100)

* [relay][vm] Separate VM runtime with executable

* Address comments

* move ctx back to vm

* make only vm related fields and methods protected

* integrate serialization/deserialization to executable

* create stream
---
 include/tvm/runtime/vm.h                      | 210 ++++-
 python/tvm/relay/__init__.py                  |   2 -
 python/tvm/relay/backend/deserializer.py      |  81 --
 python/tvm/relay/backend/profiler_vm.py       |  12 +-
 python/tvm/relay/backend/serializer.py        | 191 -----
 python/tvm/relay/backend/vm.py                | 232 +++++-
 src/relay/backend/vm/compiler.cc              |  20 +-
 src/relay/backend/vm/compiler.h               |  12 +-
 src/relay/backend/vm/deserializer.cc          | 324 --------
 src/relay/backend/vm/deserializer.h           | 102 ---
 src/relay/backend/vm/profiler/compiler.cc     |   1 -
 src/relay/backend/vm/serializer.cc            | 439 -----------
 src/relay/backend/vm/serializer.h             | 202 -----
 src/runtime/vm/executable.cc                  | 734 ++++++++++++++++++
 src/runtime/vm/profiler/vm.cc                 |  29 +-
 src/runtime/vm/profiler/vm.h                  |   2 +
 .../backend => runtime}/vm/serialize_util.h   |  12 +-
 src/runtime/vm/vm.cc                          |  92 +--
 tests/python/relay/test_vm.py                 |  30 +-
 tests/python/relay/test_vm_serialization.py   | 119 ++-
 .../unittest/test_runtime_vm_profiler.py      |   4 +-
 21 files changed, 1285 insertions(+), 1565 deletions(-)
 delete mode 100644 python/tvm/relay/backend/deserializer.py
 delete mode 100644 python/tvm/relay/backend/serializer.py
 delete mode 100644 src/relay/backend/vm/deserializer.cc
 delete mode 100644 src/relay/backend/vm/deserializer.h
 delete mode 100644 src/relay/backend/vm/serializer.cc
 delete mode 100644 src/relay/backend/vm/serializer.h
 create mode 100644 src/runtime/vm/executable.cc
 rename src/{relay/backend => runtime}/vm/serialize_util.h (95%)

diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h
index aa8543d569af..a276c658c496 100644
--- a/include/tvm/runtime/vm.h
+++ b/include/tvm/runtime/vm.h
@@ -26,6 +26,7 @@
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -430,15 +431,184 @@ struct VMFrame {
     caller_return_register(0) {}
 };
 
+/*! \brief The executable emitted by the VM compiler.
+ *
+ * The executable contains information (e.g. data in different memory regions)
+ * to run in a virtual machine.
+ *
+ *  - Global section, containing all globals.
+ *  - Constant section, storing the constant pool.
+ *  - Primitive name section, containing the function name of the primitive ops
+ *    used by the virtual machine.
+ *  - Code section, handling the VM functions and bytecode.
+ */
+class Executable : public ModuleNode {
+ public:
+  /*!
+   * \brief Get a PackedFunc from an executable module.
+   *
+   * \param name the name of the function.
+   * \param sptr_to_self The shared_ptr that points to this module node.
+   *
+   * \return PackedFunc or nullptr when it is not available.
+   */
+  PackedFunc GetFunction(const std::string& name,
+                         const std::shared_ptr& sptr_to_self) final;
+
+  /*!
+   * \brief Serialize the executable into global section, constant section, and
+   * code section.
+   *
+   * \return The binary representation of the VM.
+   */
+  TVMByteArray Save();
+
+  /*!
+   * \brief Load the saved VM executable.
+   *
+   * \param code The bytecode in string form.
+   * \param lib The compiled runtime library.
+   *
+   * \return exe The constructed executable.
+   */
+  static runtime::Module Load(const std::string& code, const runtime::Module lib);
+
+  /*! 
+   * \brief Get the serialized form of the `functions`. This is
+   * essentially bytecode serialization.
+   *
+   * \return The serialized vm bytecode.
+   *
+   * \note The bytecode is in the following format:
+   *   func_name reg_file_size num_instructions
+   *   param1 param2 ... paramM
+   *   instruction1
+   *   instruction2
+   *   ...
+   *   instructionN
+   *
+   * Each instruction is printed in the following format:
+   *   opcode num_fields field1 ... fieldX # The text format.
+   *
+   * Serializing an `Instruction` requires us to deal with the bytecode. Each line
+   * of the instructions could be serialized in the following format:
+   *   hash, opcode, f1, f2, ..., fX, field with variable length
+   *   1. hash: the hash of the instruction. This number will be used to help us
+   *   validate if an instruction is well-formed during deserialization.
+   *   2. opcode: the opcode of the instruction.
+   *   3. f1, f2, ..., fX. These fields together represent the fixed fields in
+   *   an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For
+   *   example, `DLDataType` will be unpacked into three fields (code, bits, lanes).
+   *   4. The rest of the line indicates the field with variable length, e.g.,
+   *   the shape of a tensor, the args used by an `InvokePacked` instruction, etc.
+   *
+   * The part starting from # is only used for debugging. The serialized code
+   * doesn't contain it, therefore the deserializer doesn't need to handle it.
+   */
+  std::string GetBytecode() const;
+
+  /*!
+   * \brief Print the detailed statistics of the given code, i.e. number of
+   * globals and constants, etc.
+   */
+  std::string Stats() const;
+
+  /*! \brief Get the `lib` module in an executable. Users have the flexibility to call
+   * `export_library` from the frontend to save the library to disk.
+   *
+   * \return The runtime module that contains the hardware dependent code.
+   */
+  runtime::Module GetLib() const { return lib; }
+
+  virtual ~Executable() {}
+
+  const char* type_key() const final {
+    return "VMExecutable";
+  }
+
+  /*! \brief The runtime module/library that contains both the host and the device
+   * code when executing on non-CPU devices. */
+  runtime::Module lib;
+  /*! \brief The global constant pool. */
+  std::vector constants;
+  /*! \brief A map from globals (as strings) to their index in the function map. */
+  std::unordered_map global_map;
+  /*! \brief A mapping from the packed function (as string) to the index that
+   * corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object.
+   */
+  std::unordered_map primitive_map;
+  /*! \brief The virtual machine's function table. */
+  std::vector functions;
+
+ private:
+  /*!
+   * \brief Save the globals.
+   *
+   * \param strm The output stream.
+   */
+  void SaveGlobalSection(dmlc::Stream* strm);
+
+  /*!
+   * \brief Save the constant pool.
+   *
+   * \param strm The output stream.
+   */
+  void SaveConstantSection(dmlc::Stream* strm);
+
+  /*!
+   * \brief Save primitive op names.
+   *
+   * \param strm The output stream.
+   */
+  void SavePrimitiveOpNames(dmlc::Stream* strm);
+
+  /*!
+   * \brief Save the vm functions.
+   *
+   * \param strm The output stream.
+   */
+  void SaveCodeSection(dmlc::Stream* strm);
+
+  /*!
+   * \brief Load the globals.
+   *
+   * \param strm The input stream.
+   */
+  void LoadGlobalSection(dmlc::Stream* strm);
+
+  /*!
+   * \brief Load the constant pool.
+   *
+   * \param strm The input stream.
+   */
+  void LoadConstantSection(dmlc::Stream* strm);
+
+  /*!
+   * \brief Load primitive op names.
+   *
+   * \param strm The input stream. 
+ */ + void LoadPrimitiveOpNames(dmlc::Stream* strm); + + /*! + * \brief Load the vm functions. + * + * \param strm The input stream. + */ + void LoadCodeSection(dmlc::Stream* strm); + + /*! \brief The serialized bytecode. */ + std::string code_; +}; + /*! \brief The virtual machine. * * The virtual machine contains all the current execution state, - * as well as the global view of functions, the global constant - * table, the compiled operators. + * as well as the executable. * * The goal is to have a single self-contained object, * enabling one to easily pass around VMs, execute them on - * multiple threads, or serialized them to disk or over the + * multiple threads, or serialize them to disk or over the * wire. */ class VirtualMachine : public runtime::ModuleNode { @@ -486,16 +656,18 @@ class VirtualMachine : public runtime::ModuleNode { return "VirtualMachine"; } - /*! \brief The runtime module/library that contains generated code. */ - runtime::Module lib; + VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {} + + /*! \brief load the executable for the virtual machine. + * \param exec The executable. + */ + void LoadExecutable(const Executable* exec); + + protected: /*! \brief The virtual machine's packed function table. */ std::vector packed_funcs; - /*! \brief The virtual machine's function table. */ - std::vector functions; /*! \brief The current stack of call frames. */ std::vector frames; - /*! \brief The global constant pool. */ - std::vector constants; /*! \brief The fuction table index of the current function. */ Index func_index; /*! \brief The current pointer to the code section. */ @@ -506,6 +678,9 @@ class VirtualMachine : public runtime::ModuleNode { /*! \brief The special return register. */ ObjectRef return_register; + /*! \brief The executable the VM will operate on. */ + const Executable* exec; + /*! \brief The set of TVM contexts the VM is currently executing on. */ std::vector ctxs; @@ -550,8 +725,6 @@ class VirtualMachine : public runtime::ModuleNode { */ ObjectRef Invoke(const std::string& name, const std::vector& args); - VirtualMachine() : functions(), frames(), func_index(0), code(nullptr), pc(0) {} - /*! \brief Initialize the virtual machine for a set of contexts. * \param contexts The set of TVM contexts. */ @@ -565,21 +738,6 @@ class VirtualMachine : public runtime::ModuleNode { */ TVMContext GetParamsContext() const; - /*! - * \brief Load parameters from the parameter bytearray. - * \param params The binary file that contains parameters. - */ - void LoadParams(const std::string& params); - - /*! \brief A map from globals (as strings) to their index in the function map. - */ - std::unordered_map global_map; - - /*! \brief A mapping from the packed function (as string) to the index that - * corresponds to the position of the `packed_funcs` list. - */ - std::unordered_map primitive_map; - private: /*! \brief Invoke a global setting up the VM state to execute. * diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index ceb98c4d251e..fff9c99e5007 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -37,8 +37,6 @@ from . 
import feature from .backend import vm from .backend import profiler_vm -from .backend import serializer -from .backend import deserializer from .backend import vmobj # Root operators diff --git a/python/tvm/relay/backend/deserializer.py b/python/tvm/relay/backend/deserializer.py deleted file mode 100644 index fde702b1cd04..000000000000 --- a/python/tvm/relay/backend/deserializer.py +++ /dev/null @@ -1,81 +0,0 @@ -# License .to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -""" -The Relay Virtual Machine deserializer. - -Python interface for deserializing a Relay VM. -""" -from tvm import module -from tvm._ffi.runtime_ctypes import TVMByteArray -from . import _vm -from . import vm as rly_vm - -def _create_deserializer(code, lib): - """Create a deserializer object. - - Parameters - ---------- - code : bytearray - The serialized virtual machine code. - - lib : :py:class:`~tvm.module.Module` - The serialized runtime module/library that contains the hardware - dependent binary code. - - Returns - ------- - ret : Deserializer - The created virtual machine deserializer. - """ - if isinstance(code, (bytes, str)): - code = bytearray(code) - elif not isinstance(code, (bytearray, TVMByteArray)): - raise TypeError("vm is expected to be the type of bytearray or " + - "TVMByteArray, but received {}".format(type(code))) - - if not isinstance(lib, module.Module): - raise TypeError("lib is expected to be the type of tvm.module.Module" + - ", but received {}".format(type(lib))) - return _vm._Deserializer(code, lib) - - -class Deserializer: - """Relay VM deserializer. - - Parameters - ---------- - code : bytearray - The serialized virtual machine code. - - lib : :py:class:`~tvm.module.Module` - The serialized runtime module/library that contains the hardware - dependent binary code. - """ - def __init__(self, code, lib): - self.mod = _create_deserializer(code, lib) - self._deserialize = self.mod["deserialize"] - - def deserialize(self): - """Deserialize the serialized bytecode into a Relay VM. - - Returns - ------- - ret : VirtualMachine - The deserialized Relay VM. - """ - return rly_vm.VirtualMachine(self._deserialize()) diff --git a/python/tvm/relay/backend/profiler_vm.py b/python/tvm/relay/backend/profiler_vm.py index 8ae3161e0b83..b36715249f0a 100644 --- a/python/tvm/relay/backend/profiler_vm.py +++ b/python/tvm/relay/backend/profiler_vm.py @@ -49,8 +49,8 @@ def compile(mod, target=None, target_host=None, params=None): Returns ------- - vm : VirtualMachineProfiler - The profile VM runtime. + exec : Executable + The executable with profiling code. 
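+
+    Examples
+    --------
+    A minimal sketch of the intended flow (illustrative only; it assumes an
+    "llvm" target, a CPU context, and a trivial Relay module):
+
+    .. code-block:: python
+
+        import numpy as np
+        import tvm
+        from tvm import relay
+        from tvm.relay.backend import profiler_vm
+
+        x = relay.var('x', shape=(10, 10))
+        mod = relay.Module({"main": relay.Function([x], x + x)})
+        exe = profiler_vm.compile(mod, "llvm")
+        vm = profiler_vm.VirtualMachineProfiler(exe)
+        vm.init(tvm.cpu())
+        vm.run(np.random.rand(10, 10).astype('float32'))
+        # print the per-op profile collected by the debug runtime
+        print(vm.get_stat())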
""" compiler = VMCompilerProfiler() target = compiler.update_target(target) @@ -60,7 +60,7 @@ def compile(mod, target=None, target_host=None, params=None): tophub_context = compiler.tophub_context(target) with tophub_context: compiler._compile(mod, target, target_host) - return VirtualMachineProfiler(compiler._get_vm()) + return vm.Executable(compiler._get_exec()) class VMCompilerProfiler(vm.VMCompiler): """Build Relay module to run on VM runtime.""" @@ -68,13 +68,17 @@ def __init__(self): super().__init__() self.mod = _vm._VMCompilerProfiler() self._compile = self.mod["compile"] - self._get_vm = self.mod["get_vm"] + self._get_exec = self.mod["get_executable"] self._set_params_func = self.mod["set_params"] class VirtualMachineProfiler(vm.VirtualMachine): """Relay profile VM runtime.""" def __init__(self, mod): super().__init__(mod) + m = mod.module if isinstance(mod, vm.Executable) else mod + self.mod = _vm._VirtualMachineDebug(m) + self._init = self.mod["init"] + self._invoke = self.mod["invoke"] self._get_stat = self.mod["get_stat"] def get_stat(self): diff --git a/python/tvm/relay/backend/serializer.py b/python/tvm/relay/backend/serializer.py deleted file mode 100644 index b45ba9116a15..000000000000 --- a/python/tvm/relay/backend/serializer.py +++ /dev/null @@ -1,191 +0,0 @@ -# License .to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -""" -The Relay Virtual Machine serializer. - -Python interface for serializing a Relay VM. -""" -import tvm -from . import _vm -from . import vm as rly_vm - -def _create_serializer(vm): - """Create a VM serializer. - - Parameters - ---------- - vm : Union[VirtualMachine, :py:class:`~tvm.module.Module`] - The virtual machine to be serialized. - - Returns - ------- - ret : Serializer - The created virtual machine serializer. - """ - if isinstance(vm, rly_vm.VirtualMachine): - vm = vm.module - elif not isinstance(vm, tvm.module.Module): - raise TypeError("vm is expected to be the type of VirtualMachine or " + - "tvm.Module, but received {}".format(type(vm))) - - return _vm._Serializer(vm) - - -class Serializer: - """Relay VM serializer.""" - def __init__(self, vm): - self.mod = _create_serializer(vm) - self._get_lib = self.mod["get_lib"] - self._get_bytecode = self.mod["get_bytecode"] - self._get_globals = self.mod["get_globals"] - self._get_stats = self.mod["get_stats"] - self._get_primitive_ops = self.mod["get_primitive_ops"] - self._serialize = self.mod["serialize"] - - @property - def stats(self): - """Get the statistics of the Relay VM. - - Returns - ------- - ret : String - The serialized statistic information. - """ - return self._get_stats() - - @property - def primitive_ops(self): - """Get the name of the primitive ops that are executed in the VM. 
- - Returns - ------- - ret : List[:py:class:`~tvm.expr.StringImm`] - The list of primitive ops. - """ - return [prim_op.value for prim_op in self._get_primitive_ops()] - - @property - def bytecode(self): - """Get the bytecode of the Relay VM. - - Returns - ------- - ret : String - The serialized bytecode. - - Notes - ----- - The bytecode is in the following format: - func_name reg_file_size num_instructions - param1 param2 ... paramM - instruction1 - instruction2 - ... - instructionN - - Each instruction is printed in the following format: - hash opcode field1 ... fieldX # The text format. - - The part starting from # is only used for visualization and debugging. - The real serialized code doesn't contain it, therefore the deserializer - doesn't need to deal with it as well. - """ - return self._get_bytecode() - - @property - def globals(self): - """Get the globals used by the Relay VM. - - Returns - ------- - ret : List[:py:class:`~tvm.expr.StringImm`] - The serialized globals. - """ - return [glb.value for glb in self._get_globals()] - - def serialize(self): - """Serialize the Relay VM. - - Returns - ------- - code : bytearray - The binary blob representing a serialized Relay VM. It can then be - saved to disk and later deserialized into a new VM. - - lib : :py:class:`~tvm.module.Module` - The runtime module that contains the generated code. It is - basically a library that is composed of hardware dependent code. - - Notes - ----- - The returned code is organized with the following sections in order. - - Global section. This section contains the globals used by the - virtual machine. - - Constant section. This section is used to store the constant pool of - a virtual machine. - - Primitive name section. This section is introduced to accommodate - the list of primitive operator names that will be invoked by the - virtual machine. - - Code section. The VM functions, including bytecode, are sitting in - this section. - - Examples - -------- - .. code-block:: python - - import numpy as np - import tvm - from tvm import relay - - # define a simple network. - x = relay.var('x', shape=(10, 10)) - f = relay.Function([x], x + x) - mod = relay.Module({"main": f}) - - # create a Relay VM. - ctx = tvm.cpu() - target = "llvm" - compiler = relay.vm.VMCompiler() - vm = compiler.compile(mod, target) - vm.init(ctx) - - # serialize. - ser = relay.serializer.Serializer(vm) - code, lib = ser.serialize() - - # save and load the code and lib file. - tmp = tvm.contrib.util.tempdir() - path_lib = tmp.relpath("lib.so") - lib.export_library(path_lib) - with open(tmp.relpath("code.bc"), "wb") as fo: - fo.write(code) - - loaded_lib = tvm.module.load(path_lib) - loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read()) - - # deserialize. - deser = relay.deserializer.Deserializer(loaded_code, loaded_lib) - des_vm = deser.deserialize() - - # execute the deserialized vm. - des_vm.init(ctx) - x_data = np.random.rand(10, 10).astype('float32') - res = des_vm.run(x_data) - print(res.asnumpy()) - """ - return self._serialize(), self._get_lib() diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index c24b16ca6437..942c93b866f4 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -24,8 +24,8 @@ import tvm from tvm import autotvm -from tvm._ffi.runtime_ctypes import TVMByteArray from tvm.relay import expr as _expr +from tvm._ffi.runtime_ctypes import TVMByteArray from . import _vm from . 
import vmobj as _obj
 from .interpreter import Executor
@@ -44,6 +44,7 @@ def _convert(arg, cargs):
     else:
         raise "unsupported type"
 
+
 def convert(args):
     cargs = []
     for arg in args:
@@ -52,12 +53,202 @@ def convert(args):
     return cargs
 
 
+class Executable(object):
+    """Relay VM executable"""
+    def __init__(self, mod):
+        self.mod = mod
+        self._save = self.mod["save"]
+        self._get_lib = self.mod["get_lib"]
+        self._get_bytecode = self.mod["get_bytecode"]
+        self._get_stats = self.mod["get_stats"]
+
+    def save(self):
+        """Save the Relay VM Executable.
+
+        Returns
+        -------
+        code : bytearray
+            The binary blob representing a serialized Relay VM executable. It
+            can then be saved to disk and later deserialized into a new
+            Executable.
+
+        lib : :py:class:`~tvm.module.Module`
+            The runtime module that contains the generated code. It is
+            basically a library that is composed of hardware dependent code.
+
+        Notes
+        -----
+        The returned code is organized with the following sections in order.
+         - Global section. This section contains the globals used by the
+           virtual machine.
+         - Constant section. This section is used to store the constant pool of
+           a virtual machine.
+         - Primitive name section. This section is introduced to accommodate
+           the list of primitive operator names that will be invoked by the
+           virtual machine.
+         - Code section. The VM functions, including bytecode, are sitting in
+           this section.
+
+        Examples
+        --------
+
+        .. code-block:: python
+
+            import numpy as np
+            import tvm
+            from tvm import relay
+            # define a simple network.
+            x = relay.var('x', shape=(10, 10))
+            f = relay.Function([x], x + x)
+            mod = relay.Module({"main": f})
+            # create a Relay VM.
+            ctx = tvm.cpu()
+            target = "llvm"
+            executable = relay.vm.compile(mod, target)
+            code, lib = executable.save()
+            # save and load the code and lib file.
+            tmp = tvm.contrib.util.tempdir()
+            path_lib = tmp.relpath("lib.so")
+            lib.export_library(path_lib)
+            with open(tmp.relpath("code.ro"), "wb") as fo:
+                fo.write(code)
+            loaded_lib = tvm.module.load(path_lib)
+            loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read())
+            # deserialize.
+            des_exec = relay.vm.Executable.load_exec(loaded_code, loaded_lib)
+            # execute the deserialized executable.
+            x_data = np.random.rand(10, 10).astype('float32')
+            des_vm = relay.vm.VirtualMachine(des_exec)
+            des_vm.init(ctx)
+            res = des_vm.run(x_data)
+            print(res.asnumpy())
+        """
+        return self._save(), self._get_lib()
+
+    @staticmethod
+    def load_exec(bytecode, lib):
+        """Construct an executable from saved artifacts.
+
+        Parameters
+        ----------
+        bytecode : bytearray
+            The binary blob representing the Relay VM bytecode.
+
+        lib : :py:class:`~tvm.module.Module`
+            The runtime module that contains the generated code.
+
+        Returns
+        -------
+        exec : Executable
+            An executable constructed using the provided artifacts.
+        """
+        if isinstance(bytecode, (bytes, str)):
+            bytecode = bytearray(bytecode)
+        elif not isinstance(bytecode, (bytearray, TVMByteArray)):
+            raise TypeError("bytecode is expected to be the type of bytearray " +
+                            "or TVMByteArray, but received {}".format(type(bytecode)))
+
+        if not isinstance(lib, tvm.module.Module):
+            raise TypeError("lib is expected to be the type of tvm.module.Module" +
+                            ", but received {}".format(type(lib)))
+
+        return Executable(_vm.Load_Executable(bytecode, lib))
+
+    @property
+    def lib(self):
+        """Get the library that contains hardware dependent code.
+
+        Returns
+        -------
+        ret : :py:class:`~tvm.Module`
+            The runtime module that contains hardware dependent code. 
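+
+        Examples
+        --------
+        A small sketch (illustrative only; the library file name is arbitrary):
+
+        .. code-block:: python
+
+            import tvm
+            from tvm import relay
+
+            x = relay.var('x', shape=(10, 10))
+            mod = relay.Module({"main": relay.Function([x], x + x)})
+            exe = relay.vm.compile(mod, "llvm")
+            # export the hardware dependent code as a shared library
+            exe.lib.export_library("vm_lib.so")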
+ """ + return self._get_lib() + + @property + def stats(self): + """Get the statistics of the Relay VM executable. + + Returns + ------- + ret : String + The statistic information of the VM executable. + """ + return self._get_stats() + + @property + def primitive_ops(self): + """Get the name of the primitive ops contained in the executable. + + Returns + ------- + ret : List[String] + The list of primitive ops. + """ + ret = [] + num_primitives = _vm.GetNumOfPrimitives(self.module) + for i in range(num_primitives): + ret.append(_vm.GetPrimitiveFields(self.module, i)) + return ret + + @property + def bytecode(self): + """Get the bytecode of the Relay VM executable. + + Returns + ------- + ret : String + The bytecode of the executable. + + Notes + ----- + The bytecode is in the following format: + func_name reg_file_size num_instructions + param1 param2 ... paramM + instruction1 + instruction2 + ... + instructionN + + Each instruction is printed in the following format: + hash opcode field1 ... fieldX # The text format. + + The part starting from # is only used for visualization and debugging. + The real serialized code doesn't contain it, therefore the deserializer + doesn't need to deal with it as well. + """ + return self._get_bytecode() + + @property + def globals(self): + """Get the globals used by the Relay VM executable. + + Returns + ------- + ret : List[String] + The globals contained in the executable. + """ + ret = [] + num_globals = _vm.GetNumOfGlobals(self.module) + for i in range(num_globals): + ret.append(_vm.GetGlobalFields(self.module, i)) + return ret + + @property + def module(self): + """Return the runtime module contained in a virtual machine executable.""" + return self.mod + + class VirtualMachine(object): """Relay VM runtime.""" def __init__(self, mod): - self.mod = mod + if not isinstance(mod, (Executable, tvm.module.Module)): + raise TypeError("mod is expected to be the type of Executable or " + + "tvm.Module, but received {}".format(type(mod))) + m = mod.module if isinstance(mod, Executable) else mod + self.mod = _vm._VirtualMachine(m) self._init = self.mod["init"] - self._load_params = self.mod["load_params"] self._invoke = self.mod["invoke"] def init(self, ctx): @@ -71,23 +262,6 @@ def init(self, ctx): args = [ctx.device_type, ctx.device_id] self._init(*args) - def load_params(self, params): - """Load parameters for the VM. - - Parameters - ---------- - params : Union[bytearray, Dict] - The dictionary that contains serialized parameters. - """ - if isinstance(params, dict): - params = tvm.relay.save_param_dict(params) - elif isinstance(params, (bytes, str)): - params = bytearray(params) - if not isinstance(params, (bytearray, TVMByteArray)): - raise TypeError("params must be a bytearray") - - self._load_params(bytearray(params)) - def invoke(self, func_name, *args): """Invoke a function. @@ -122,11 +296,6 @@ def run(self, *args): """ return self.invoke("main", *args) - @property - def module(self): - """Return the runtime module contained in a virtual machine.""" - return self.mod - def compile(mod, target=None, target_host=None, params=None): """ @@ -155,8 +324,8 @@ def compile(mod, target=None, target_host=None, params=None): Returns ------- - vm : VirtualMachine - The VM runtime. + exec : Executable + The VM executable that contains both library code and bytecode. 
""" compiler = VMCompiler() @@ -167,14 +336,14 @@ def compile(mod, target=None, target_host=None, params=None): tophub_context = compiler.tophub_context(target) with tophub_context: compiler._compile(mod, target, target_host) - return VirtualMachine(compiler._get_vm()) + return Executable(compiler._get_exec()) class VMCompiler(object): """Build Relay module to run on VM runtime.""" def __init__(self): self.mod = _vm._VMCompiler() self._compile = self.mod["compile"] - self._get_vm = self.mod["get_vm"] + self._get_exec = self.mod["get_executable"] self._set_params_func = self.mod["set_params"] def set_params(self, params): @@ -240,7 +409,7 @@ class VMExecutor(Executor): mod : :py:class:`~tvm.relay.module.Module` The module to support the execution. - ctx : :py:class:`TVMContext` + ctx : :py:class:`~tvm.TVMContext` The runtime context to run the code on. target : :py:class:`Target` @@ -252,7 +421,8 @@ def __init__(self, mod, ctx, target): self.mod = mod self.ctx = ctx self.target = target - self.vm = compile(mod, target) + self.executable = compile(mod, target) + self.vm = VirtualMachine(self.executable) self.vm.init(ctx) def _make_executor(self, expr=None): diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 0cfae374ab2c..f295ccd7a555 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -783,9 +783,9 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, Module mod = args[0]; this->Compile(mod, args[1], args[2]); }); - } else if (name == "get_vm") { + } else if (name == "get_executable") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = runtime::Module(vm_); + *rv = runtime::Module(exec_); }); } else if (name == "set_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -864,7 +864,7 @@ void VMCompiler::Compile(Module mod, // Next we get ready by allocating space for // the global state. 
- vm_->functions.resize(context_.module->functions.size()); + exec_->functions.resize(context_.module->functions.size()); for (auto named_func : context_.module->functions) { auto gvar = named_func.first; @@ -873,25 +873,25 @@ void VMCompiler::Compile(Module mod, auto vm_func = func_compiler.Compile(gvar, func); size_t func_index = context_.global_map.at(gvar); - CHECK(func_index < vm_->functions.size()); - vm_->functions[func_index] = vm_func; + CHECK(func_index < exec_->functions.size()); + exec_->functions[func_index] = vm_func; } #if USE_RELAY_DEBUG - for (auto vm_func : vm_->functions) { + for (auto vm_func : exec_->functions) { DLOG(INFO) << vm_func << "-------------"; } #endif // USE_RELAY_DEBUG // populate constants for (auto data : context_.constants) { - vm_->constants.push_back(runtime::vm::Tensor(data)); + exec_->constants.push_back(runtime::vm::Tensor(data)); } LibraryCodegen(); for (auto gv : context_.global_map) { - vm_->global_map.insert({gv.first->name_hint, gv.second}); + exec_->global_map.insert({gv.first->name_hint, gv.second}); } } @@ -987,13 +987,13 @@ void VMCompiler::LibraryCodegen() { // therefore target won't be used in the build function runtime::Module mod = (*f)(funcs, Target(), target_host_); CHECK(mod.operator->()); - vm_->lib = mod; + exec_->lib = mod; } else { LOG(FATAL) << "relay.backend.build is not registered"; } size_t primitive_index = 0; for (auto cfunc : cached_funcs) { - vm_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); + exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); } } diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index dff1ef7f4569..215cc12c4cdb 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -92,12 +92,8 @@ class VMCompiler : public runtime::ModuleNode { return "VMCompiler"; } - std::shared_ptr GetVirtualMachine() const { - return vm_; - } - - virtual void InitVM() { - vm_ = std::make_shared(); + void InitVM() { + exec_ = std::make_shared(); } /*! @@ -144,8 +140,8 @@ class VMCompiler : public runtime::ModuleNode { tvm::Target target_host_; /*! \brief Global shared meta data */ VMCompilerContext context_; - /*! \brief Compiled virtual machine. */ - std::shared_ptr vm_; + /*! \brief Compiled executable. */ + std::shared_ptr exec_; /*! \brief parameters */ std::unordered_map params_; }; diff --git a/src/relay/backend/vm/deserializer.cc b/src/relay/backend/vm/deserializer.cc deleted file mode 100644 index 777282782e99..000000000000 --- a/src/relay/backend/vm/deserializer.cc +++ /dev/null @@ -1,324 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/deserializer.cc - * \brief Implementation of APIs to deserialize the serialized VM bytecode. - */ - -#include "deserializer.h" - -#include -#include -#include - -#include "serialize_util.h" - -namespace tvm { -namespace relay { -namespace vm { - -#define STREAM_CHECK(val, section) \ - CHECK(val) << "Invalid VM file format in the " << section << " section." \ - << "\n"; - -void Deserializer::Init(const std::string& code, const runtime::Module& lib) { - code_ = code; - vm_ = std::make_shared(); - vm_->lib = lib; - strm_ = new dmlc::MemoryStringStream(&code_); -} - -runtime::PackedFunc Deserializer::GetFunction( - const std::string& name, - const std::shared_ptr& sptr_to_self) { - if (name == "deserialize") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->Deserialize(); - *rv = runtime::Module(vm_); - }); - } else { - LOG(FATAL) << "Unknown packed function: " << name; - return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); - } -} - -void Deserializer::Deserialize() { - // Check header. - uint64_t header; - STREAM_CHECK(strm_->Read(&header), "header"); - STREAM_CHECK(header == kTVMVMBytecodeMagic, "header"); - - // Check version. - std::string version; - STREAM_CHECK(strm_->Read(&version), "version"); - STREAM_CHECK(version == TVM_VERSION, "version"); - - // Global section. - DeserializeGlobalSection(); - - // Constant section. - DeserializeConstantSection(); - - // Primitive names that will be invoked by `InvokePacked` instructions. - DeserializePrimitiveOpNames(); - - // Code section. - DeserializeCodeSection(); -} - -void Deserializer::DeserializeGlobalSection() { - std::vector globals; - STREAM_CHECK(strm_->Read(&globals), "global"); - for (size_t i = 0; i < globals.size(); i++) { - vm_->global_map.insert({globals[i], i}); - } -} - -void Deserializer::DeserializeConstantSection() { - uint64_t sz; - // Load the number of constants. - STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "constant"); - - size_t size = static_cast(sz); - // Load each of the constants. - for (size_t i = 0; i < size; i++) { - runtime::NDArray constant; - STREAM_CHECK(constant.Load(strm_), "constant"); - runtime::ObjectRef obj = runtime::vm::Tensor(constant); - vm_->constants.push_back(obj); - } -} - -void Deserializer::DeserializePrimitiveOpNames() { - std::vector primitive_names; - STREAM_CHECK(strm_->Read(&primitive_names), "primitive name"); - for (size_t i = 0; i < primitive_names.size(); i++) { - vm_->primitive_map.insert({primitive_names[i], i}); - } -} - -// Extract the `cnt` number of fields started at `start` from the list -// `instr_fields`. 
-inline std::vector ExtractFields(const std::vector& instr_fields, - Index start, - Index cnt) { - CHECK_LE(static_cast(start + cnt), instr_fields.size()); - std::vector ret; - for (auto i = start; i < start + cnt; i++) { - ret.push_back(instr_fields[i]); - } - return ret; -} - -Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { - Opcode opcode = static_cast(instr.opcode); - switch (opcode) { - case Opcode::Move: { - // Number of fields = 2 - DCHECK_EQ(instr.fields.size(), 2U); - return Instruction::Move(instr.fields[0], instr.fields[1]); - } - case Opcode::Ret: { - // Number of fields = 1 - DCHECK_EQ(instr.fields.size(), 1U); - return Instruction::Ret(instr.fields[0]); - } - case Opcode::Fatal: { - // Number of fields = 0 - DCHECK(instr.fields.empty()); - return Instruction::Fatal(); - } - case Opcode::InvokePacked: { - // Number of fields = 3 + instr.arity - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index packed_index = instr.fields[0]; - Index arity = instr.fields[1]; - Index output_size = instr.fields[2]; - std::vector args = ExtractFields(instr.fields, 3, arity); - return Instruction::InvokePacked(packed_index, arity, output_size, args); - } - case Opcode::AllocTensor: { - // Number of fields = 5 + instr.alloc_tensor.ndim - DCHECK_GE(instr.fields.size(), 5U); - DCHECK_EQ(instr.fields.size(), 5U + static_cast(instr.fields[3])); - - DLDataType dtype; - dtype.code = instr.fields[0]; - dtype.bits = instr.fields[1]; - dtype.lanes = instr.fields[2]; - - Index ndim = instr.fields[3]; - RegName dst = instr.fields[4]; - - std::vector shape = ExtractFields(instr.fields, 5, ndim); - - return Instruction::AllocTensor(shape, dtype, dst); - } - case Opcode::AllocTensorReg: { - // Number of fields = 5 - DCHECK_EQ(instr.fields.size(), 5U); - Index shape_register = instr.fields[0]; - - DLDataType dtype; - dtype.code = instr.fields[1]; - dtype.bits = instr.fields[2]; - dtype.lanes = instr.fields[3]; - - RegName dst = instr.fields[4]; - - return Instruction::AllocTensorReg(shape_register, dtype, dst); - } - case Opcode::AllocDatatype: { - // Number of fields = 3 + instr.num_fields - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index constructor_tag = instr.fields[0]; - Index num_fields = instr.fields[1]; - RegName dst = instr.fields[2]; - std::vector fields = ExtractFields(instr.fields, 3, num_fields); - - return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst); - } - case Opcode::AllocClosure: { - // Number of fields = 3 + instr.num_freevar - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index clo_index = instr.fields[0]; - Index num_freevar = instr.fields[1]; - RegName dst = instr.fields[2]; - std::vector free_vars = ExtractFields(instr.fields, 3, num_freevar); - - return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst); - } - case Opcode::If: { - // Number of fields = 4 - DCHECK_EQ(instr.fields.size(), 4U); - Index test = instr.fields[0]; - Index target = instr.fields[1]; - Index true_offset = instr.fields[2]; - Index false_offset = instr.fields[3]; - - return Instruction::If(test, target, true_offset, false_offset); - } - case Opcode::Invoke: { - // Number of fields = 3 + instr.num_args - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index func_index = instr.fields[0]; - Index num_args = 
instr.fields[1]; - RegName dst = instr.fields[2]; - std::vector args = ExtractFields(instr.fields, 3, num_args); - - return Instruction::Invoke(func_index, args, dst); - } - case Opcode::InvokeClosure: { - // Number of fields = 3 + instr.num_closure_args - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index closure = instr.fields[0]; - Index num_closure_args = instr.fields[1]; - RegName dst = instr.fields[2]; - std::vector args = ExtractFields(instr.fields, 3, num_closure_args); - - return Instruction::InvokeClosure(closure, args, dst); - } - case Opcode::LoadConst: { - // Number of fields = 2 - DCHECK_EQ(instr.fields.size(), 2U); - return Instruction::LoadConst(instr.fields[0], instr.fields[1]); - } - case Opcode::LoadConsti: { - // Number of fields = 2 - DCHECK_EQ(instr.fields.size(), 2U); - return Instruction::LoadConsti(instr.fields[0], instr.fields[1]); - } - case Opcode::GetField: { - // Number of fields = 3 - DCHECK_EQ(instr.fields.size(), 3U); - return Instruction::GetField(instr.fields[0], instr.fields[1], instr.fields[2]); - } - case Opcode::GetTag: { - // Number of fields = 2 - DCHECK_EQ(instr.fields.size(), 2U); - return Instruction::GetTag(instr.fields[0], instr.fields[1]); - } - case Opcode::Goto: { - // Number of fields = 1 - DCHECK_EQ(instr.fields.size(), 1U); - return Instruction::Goto(instr.fields[0]); - } - default: - LOG(FATAL) << "Invalid opcode" << instr.opcode; - return Instruction(); - } -} - -void Deserializer::DeserializeCodeSection() { - // Load the number of functions. - uint64_t sz; - STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "code"); - - size_t num_funcs = static_cast(sz); - vm_->functions.resize(num_funcs); - for (size_t i = 0; i < num_funcs; i++) { - // Load the function info. - VMFunctionSerializer loaded_func; - STREAM_CHECK(loaded_func.Load(strm_), "code/function"); - - // Load the instructions. - std::vector instructions; - for (size_t j = 0; j < loaded_func.num_instructions; j++) { - VMInstructionSerializer instr; - std::vector instr_fields; - STREAM_CHECK(instr.Load(strm_), "code/instruction"); - instructions.push_back(DeserializeInstruction(instr)); - } - - // Create the VM function. - VMFunction vm_func = VMFunction(loaded_func.name, - loaded_func.params, - instructions, - loaded_func.register_file_size); - auto it = vm_->global_map.find(loaded_func.name); - CHECK(it != vm_->global_map.end()); - CHECK_LE(it->second, vm_->global_map.size()); - vm_->functions[it->second] = vm_func; - } -} - -runtime::Module CreateDeserializer(const std::string& code, const runtime::Module lib) { - std::shared_ptr exec = std::make_shared(); - exec->Init(code, lib); - return runtime::Module(exec); -} - -TVM_REGISTER_GLOBAL("relay._vm._Deserializer") -.set_body_typed(CreateDeserializer); - -} // namespace vm -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/vm/deserializer.h b/src/relay/backend/vm/deserializer.h deleted file mode 100644 index 0caf72bee92c..000000000000 --- a/src/relay/backend/vm/deserializer.h +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/deserializer.h - * \brief Define a deserializer for the serialized Relay VM. - */ - -#ifndef TVM_RELAY_BACKEND_VM_DESERIALIZER_H_ -#define TVM_RELAY_BACKEND_VM_DESERIALIZER_H_ - -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace relay { -namespace vm { - -using namespace tvm::runtime::vm; -namespace runtime = tvm::runtime; - -class Deserializer : public runtime::ModuleNode { - public: - /*! - * \brief Initialize the deserializer for creating a virtual machine object. - * - * \param code The serialized code. - * \param lib The serialized runtime module/library that contains the - * hardware dependent code. - */ - inline void Init(const std::string& code, const runtime::Module& lib); - - /*! - * \brief Return the member function to the frontend. - * - * \param name The name of the function. - * \param sptr_to_self The pointer to the module node. - * - * \return The corresponding member function. - */ - PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) final; - - const char* type_key() const final { return "Deserializer"; } - - /*! \brief Deserialize the serialized VM. */ - void Deserialize(); - - virtual ~Deserializer() { delete strm_; } - - private: - /*! \brief Deserialize the globals in `vm_`. */ - void DeserializeGlobalSection(); - - /*! \brief Deserialize the constant pool in `vm_`. */ - void DeserializeConstantSection(); - - /*! \brief Deserialize primitive op names in `vm_`. */ - void DeserializePrimitiveOpNames(); - - /*! \brief Deserialize the vm functions in `vm_`. */ - void DeserializeCodeSection(); - - /*! \brief The code to be serialized. */ - std::string code_; - - /*! \brief The stream used for serialization. */ - dmlc::Stream* strm_; - - /*! \brief The VM to be created. */ - std::shared_ptr vm_; -}; - -} // namespace vm -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_BACKEND_VM_DESERIALIZER_H_ diff --git a/src/relay/backend/vm/profiler/compiler.cc b/src/relay/backend/vm/profiler/compiler.cc index 9fd28e8c7f46..60c441a60cf0 100644 --- a/src/relay/backend/vm/profiler/compiler.cc +++ b/src/relay/backend/vm/profiler/compiler.cc @@ -33,7 +33,6 @@ namespace vm { class VMCompilerDebug : public VMCompiler { public: VMCompilerDebug() {} - void InitVM() override { vm_ = std::make_shared(); } virtual ~VMCompilerDebug() {} }; diff --git a/src/relay/backend/vm/serializer.cc b/src/relay/backend/vm/serializer.cc deleted file mode 100644 index 0040ef9db470..000000000000 --- a/src/relay/backend/vm/serializer.cc +++ /dev/null @@ -1,439 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/serializer.cc - * \brief Implementation of serializing APIs for the Relay VM. - */ -#include "serializer.h" - -#include -#include - -#include -#include -#include -#include -#include - -#include "serialize_util.h" - -namespace tvm { -namespace relay { -namespace vm { - -void Serializer::Init(const VirtualMachine* vm) { - vm_ = vm; - // Initialize the stream object. - strm_ = new dmlc::MemoryStringStream(&code_); -} - -runtime::PackedFunc Serializer::GetFunction( - const std::string& name, - const std::shared_ptr& sptr_to_self) { - if (name == "get_lib") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetLib(); - }); - } else if (name == "get_primitive_ops") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetPrimitiveOps(); - }); - } else if (name == "get_bytecode") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetBytecode(); - }); - } else if (name == "get_globals") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetGlobals(); - }); - } else if (name == "get_stats") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->Stats(); - }); - } else if (name == "serialize") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->Serialize(); - }); - } else { - LOG(FATAL) << "Unknown packed function: " << name; - return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); - } -} - -tvm::Array Serializer::GetPrimitiveOps() const { - std::vector ret; - for (const auto& it : vm_->primitive_map) { - auto packed_name = tvm::ir::StringImm::make(it.first); - auto packed_index = static_cast(it.second); - if (ret.size() <= packed_index) { - ret.resize(packed_index + 1); - } - ret[packed_index] = packed_name; - } - return ret; -} - -std::string Serializer::Stats() const { - std::ostringstream oss; - oss << "Relay VM statistics:" << std::endl; - - // Get the number of constants and the shape of each of them. - oss << " Constant shapes (# " << vm_->constants.size() << "): ["; - for (const auto& it : vm_->constants) { - auto* cell = it.as(); - CHECK(cell != nullptr); - runtime::NDArray data = cell->data; - const auto& shape = data.Shape(); - - // Scalar - if (shape.empty()) { - oss << "scalar, "; - continue; - } - - oss << "["; - for (auto s : shape) { - oss << s << ", "; - } - oss.seekp(-2, oss.cur); - oss << "], " << std::endl; - } - if (!vm_->constants.empty()) oss.seekp(-2, oss.cur); - oss << "]" << std::endl; - - // Get the number of globals and the name of each of them. - oss << " Globals (#" << vm_->global_map.size() << "): ["; - for (const auto& it : vm_->global_map) { - oss << "(\"" << it.first << "\", " << it.second << ")" << ", "; - } - if (!vm_->global_map.empty()) oss.seekp(-2, oss.cur); - oss << "]" << std::endl; - - // Get the number of primitive ops and the name of each of them. 
- oss << " Primitive ops (#" << vm_->primitive_map.size() << "): ["; - const auto& prim_ops = GetPrimitiveOps(); - for (const auto& it : prim_ops) { - oss << it << ", "; - } - if (!prim_ops.empty()) oss.seekp(-2, oss.cur); - oss << "]" << std::endl; - - return oss.str(); -} - -TVMByteArray Serializer::Serialize() { - uint64_t header = kTVMVMBytecodeMagic; - strm_->Write(header); - std::string version = TVM_VERSION; - strm_->Write(version); - - // Global section. - SerializeGlobalSection(); - - // Constant section. - SerializeConstantSection(); - - // Primitive names. - SerializePrimitiveOpNames(); - - // Code section. - SerializeCodeSection(); - - TVMByteArray arr; - arr.data = code_.c_str(); - arr.size = code_.length(); - return arr; -} - -void Serializer::SerializeGlobalSection() { - auto globals = GetGlobals(); - std::vector glbs; - for (const auto& it : globals) { - glbs.push_back(it.as()->value); - } - strm_->Write(glbs); -} - -void Serializer::SerializeConstantSection() { - std::vector arrays; - for (const auto& obj : vm_->constants) { - const auto* cell = obj.as(); - CHECK(cell != nullptr); - runtime::NDArray data = cell->data; - arrays.push_back(const_cast(data.operator->())); - } - strm_->Write(static_cast(vm_->constants.size())); - for (const auto& it : arrays) { - runtime::SaveDLTensor(strm_, it); - } -} - -void Serializer::SerializePrimitiveOpNames() { - auto names = GetPrimitiveOps(); - std::vector primitive_names; - for (const auto& it : names) { - primitive_names.push_back(it.as()->value); - } - strm_->Write(primitive_names); -} - -// Serialize a virtual machine instruction. It creates a list that contains the -// hash, opcode, and all fields of an instruction. -// -// For example, the function signature used to create an `AllocTensor` -// instruction is: -// Instruction AllocTensor(std::vector shape, DLDataType dtype, RegName dst) -// -// The serialized form will be: -// `hash 5 dtype.code dtype.bits dtype.lanes ndim dst_register val1 val2 ... valn` -// -// where hash is the hash of serialized instruction that is computed internally -// by the `VMInstructionSerializer`. It is used for sanity check before decoding. -// 5 shows opcode of `AllocTensor`, `(dtype.code dtype.bits dtype.lanes)` -// represents a `DLDataType`, `ndim` is the number of dimensions, `dst_register` -// is the destination register, and the rest of it together indicates the shape -// of the tensor to be allocated. -VMInstructionSerializer SerializeInstruction(const Instruction& instr) { - std::vector fields; - // Save the opcode. - DLOG(INFO) << "Serializing: " << instr << std::endl; - switch (instr.op) { - case Opcode::Move: { - // Number of fields = 2 - fields.assign({instr.from, instr.dst}); - break; - } - case Opcode::Ret: { - // Number of fields = 1 - fields.push_back(instr.result); - break; - } - case Opcode::Fatal: { - // Number of fields = 0 - break; - } - case Opcode::InvokePacked: { - // Number of fields = 3 + instr.arity - // Note that arity includes both input arguments and outputs. We will - // put all the `arity` number of fields in the end for serialization. - fields.assign({instr.packed_index, instr.arity, instr.output_size}); - // Save the args. - fields.insert(fields.end(), instr.packed_args, instr.packed_args + instr.arity); - break; - } - case Opcode::AllocTensor: { - // Number of fields = 5 + instr.alloc_tensor.ndim - // Save `DLDataType` and the dst register. 
- const auto& dtype = instr.alloc_tensor.dtype; - fields.assign({dtype.code, dtype.bits, dtype.lanes}); - - // The number of dimensions is not needed for constructing an - // `AllocTensor` instruction as it equals to the length of the `shape` - // vector. However, we save it to conveniently deserialize the instruction - // because we will know how many fields are needed by the `shape` argument. - fields.push_back(instr.alloc_tensor.ndim); - fields.push_back(instr.dst); - - // Save the shape of the tensor. - // Note that this field is rotated to the end of the list. - fields.insert(fields.end(), instr.alloc_tensor.shape, - instr.alloc_tensor.shape + instr.alloc_tensor.ndim); - break; - } - case Opcode::AllocTensorReg: { - // Number of fields = 5 - fields.push_back(instr.alloc_tensor_reg.shape_register); - // Save `DLDataType` and the dst register. - const auto& dtype = instr.alloc_tensor.dtype; - fields.assign({dtype.code, dtype.bits, dtype.lanes}); - fields.push_back(instr.dst); - break; - } - case Opcode::AllocDatatype: { - // Number of fields = 3 + instr.num_fields - fields.assign({instr.constructor_tag, instr.num_fields, instr.dst}); - - // Save the fields. - fields.insert(fields.end(), instr.datatype_fields, - instr.datatype_fields + instr.num_fields); - break; - } - case Opcode::AllocClosure: { - // Number of fields = 3 + instr.num_freevar - fields.assign({instr.clo_index, instr.num_freevar, instr.dst}); - - // Save the free vars. - fields.insert(fields.end(), instr.free_vars, - instr.free_vars + instr.num_freevar); - break; - } - case Opcode::If: { - // Number of fields = 4 - fields.assign({instr.if_op.test, - instr.if_op.target, - instr.if_op.true_offset, - instr.if_op.false_offset}); - break; - } - case Opcode::Invoke: { - // Number of fields = 3 + instr.num_args - fields.assign({instr.func_index, instr.num_args, instr.dst}); - - // Save the args. - fields.insert(fields.end(), instr.invoke_args_registers, - instr.invoke_args_registers + instr.num_args); - break; - } - case Opcode::InvokeClosure: { - // Number of fields = 3 + instr.num_closure_args - fields.assign({instr.closure, instr.num_closure_args, instr.dst}); - - // Save the args. - fields.insert(fields.end(), instr.closure_args, - instr.closure_args + instr.num_closure_args); - break; - } - case Opcode::LoadConst: { - // Number of fields = 2 - fields.assign({instr.const_index, instr.dst}); - break; - } - case Opcode::LoadConsti: { - // Number of fields = 2 - fields.assign({instr.load_consti.val, instr.dst}); - break; - } - case Opcode::GetField: { - // Number of fields = 3 - fields.assign({instr.object, instr.field_index, instr.dst}); - break; - } - case Opcode::GetTag: { - // Number of fields = 2 - fields.assign({instr.get_tag.object, instr.dst}); - break; - } - case Opcode::Goto: { - // Number of fields = 1 - fields.push_back(instr.pc_offset); - break; - } - default: - LOG(FATAL) << "Invalid opcode" << static_cast(instr.op); - break; - } - - return VMInstructionSerializer(static_cast(instr.op), fields); -} - -void Serializer::SerializeCodeSection() { - // Save the number of functions. - strm_->Write(static_cast(vm_->functions.size())); - for (const auto& func : vm_->functions) { - // Serialize the function info. - VMFunctionSerializer func_format(func.name, - func.register_file_size, - func.instructions.size(), - func.params); - func_format.Save(strm_); - - // Serialize each instruction. 
- for (const auto& instr : func.instructions) { - const auto& serialized_instr = SerializeInstruction(instr); - serialized_instr.Save(strm_); - } - } -} - -tvm::Array Serializer::GetGlobals() const { - tvm::Array ret; - std::vector > globals(vm_->global_map.begin(), - vm_->global_map.end()); - auto comp = [](const std::pair& a, - const std::pair& b) { - return a.second < b.second; - }; - std::sort(globals.begin(), globals.end(), comp); - for (const auto& it : globals) { - ret.push_back(tvm::ir::StringImm::make(it.first)); - } - return ret; -} - -std::string Serializer::GetBytecode() const { - std::ostringstream oss; - - for (const auto& func : vm_->functions) { - // Print the header of the function format. - oss << "# func name, reg file size, param count, inst count:" - << std::endl; - oss << func.name << " " - << func.register_file_size << " " - << func.params.size() << " " - << func.instructions.size() << std::endl; - - // Print pramams of a `VMFunction`. - oss << "# Parameters:"<< std::endl; - for (const auto& param : func.params) { - oss << param << " "; - } - oss << std::endl; - - // Print the instructions of a `VMFunction`. - // The part after ";" is the instruction in text format. - oss << "hash, opcode, fields # inst(text):"<< std::endl; - for (const auto& instr : func.instructions) { - const auto& serialized_instr = SerializeInstruction(instr); - oss << std::hex << "0x" << serialized_instr.Hash() << " " - << std::dec << serialized_instr.opcode << " "; - for (auto it : serialized_instr.fields) { - oss << it << " "; - } - oss << " # " << instr; - if (oss.str().back() != '\n') oss << std::endl; - } - } - - return oss.str(); -} - -runtime::Module Serializer::GetLib() const { - return vm_->lib; -} - -runtime::Module CreateSerializer(const VirtualMachine* vm) { - std::shared_ptr exec = std::make_shared(); - exec->Init(vm); - return runtime::Module(exec); -} - -TVM_REGISTER_GLOBAL("relay._vm._Serializer") -.set_body([](TVMArgs args, TVMRetValue* rv) { - runtime::Module mod = args[0]; - const auto* vm = dynamic_cast(mod.operator->()); - CHECK(vm) << "Virtual machine has not been defined yet." - << "\n"; - *rv = CreateSerializer(vm); -}); - -} // namespace vm -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/vm/serializer.h b/src/relay/backend/vm/serializer.h deleted file mode 100644 index 2371bb4c94f5..000000000000 --- a/src/relay/backend/vm/serializer.h +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/serializer.h - * \brief Define a serializer for the Relay VM. 
- * - * The following components of a Relay VM will be serialized: - * - The `constants`, e.g., the constant pool, that contains the - * constants used in a Relay program. - * - The `packed_funcs` that essentially contains the generated code for - * a specific target. We return it as a runtime module that can be exported as - * a library file (e.g., .so, .o, or .tar). - * - The `global_map` that contains the globals. - * - The `primitive_map` that contains the name of individual primitive operators. - * - The `functions`, e.g., the `VMFunction`. Each `VMFunction` is composed of - * a list of instructions/bytecode. - * - * Note that only the library is returned as a separate module. All othere parts - * are stored in a single serialized code that is organized with the following - * sections in order. - * - Global section, containing all globals. - * - Constant section, storing the constant pool. - * - Primitive name section, containing the function name of the primitive ops - * used by the virtual machine. - * - Code section, handling the VM functions and bytecode. - * - * The code section is again organized as follows for each VM function: - * func_name, register_file_size, num_instructions (N) - * param1, param2, ..., paramM - * instruction1 - * instruction2 - * ... - * instructionN - * - * Serializing an `Instruction` requires us to deal with the bytecode. Each line - * of the instructions could be serialized as the following format: - * hash, opcode, f1, f2, ..., fX, field with variable length - * 1. hash: the hash of the instruction. This number will be used to help us - * validate if an instruction is well-formed during deserialization. - * 2. opcode: the opcode code of the instruction. - * 3. f1, f2, ..., fX. These fields together represent the fixed fields in - * an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For - * example, `DLDataType` will be unpacked into three fields (code, bits, lanes). - * 4. The rest of the line indicates the field with variable length, e.g., - * the shape of a tensor, the args used by an `InvokPacked` instruction, etc. - */ - -#ifndef TVM_RELAY_BACKEND_VM_SERIALIZER_H_ -#define TVM_RELAY_BACKEND_VM_SERIALIZER_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace relay { -namespace vm { - -using namespace tvm::runtime; -using namespace tvm::runtime::vm; - -/*! - * \brief The Relay VM serializer. - */ -class Serializer : public runtime::ModuleNode { - public: - /*! - * \brief Initialize the serializer for a virtual machine. - * - * \param vm The Relay virtual machine. - */ - inline void Init(const VirtualMachine* vm); - - /*! - * \brief Return the member function to the frontend. - * - * \param name The name of the function. - * \param sptr_to_self The pointer to the module node. - * - * \return The corresponding member function. - */ - PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) final; - - const char* type_key() const final { return "Serializer"; } - - /*! - * \brief Print the detailed statistics of the given code, i.e. number of - * globls and constants, etc. - */ - std::string Stats() const; - - /*! - * \brief Serialize the `vm_` into global section, constant section, and code - * section. - * - * \return The binary representation of the VM. - */ - TVMByteArray Serialize(); - - /*! - * \brief Get a list of the globals used by the `_vm`. - * - * \return The global map in the form a list. 
- */ - tvm::Array GetGlobals() const; - - /*! - * \brief Get the primitive operators that are contained in the Relay VM. - * - * \return The list of primitve operators. - */ - tvm::Array GetPrimitiveOps() const; - - /*! - * \brief Get the serialized form of the `functions` in `vm_`. This is - * essentially bytecode serialization. - * - * \return The serialized vm bytecode. - * - * \note The bytecode is in the following format: - * func_name reg_file_size num_instructions - * param1 param2 ... paramM - * instruction1 - * instruction2 - * ... - * instructionN - * - * Each instruction is printed in the following format: - * opcode num_fields field1 ... fieldX # The text format. - * - * The field starting from # is only used for debugging. The serialized code - * doesn't contain it, therefore the deserializer doens't need to handle it. - */ - std::string GetBytecode() const; - - /*! \brief Get the `lib` module in vm_. Serialization of `runtime::module` - * has already been supported by TVM. Therefore, we only return the runtime - * module and let users have the flexibility to call `export_library` from - * the frontend to save the library to disk. - * - * \return The runtime module that contains the hardwre dependent code. - */ - inline runtime::Module GetLib() const; - - virtual ~Serializer() { delete strm_; } - - private: - /*! \brief Serialize the globals in vm_. */ - void SerializeGlobalSection(); - - /*! \brief Serialize the constant pool in vm_. */ - void SerializeConstantSection(); - - /*! \brief Serialize primitive op names in vm_. */ - void SerializePrimitiveOpNames(); - - /*! \brief Serialize the vm functions in vm_. */ - void SerializeCodeSection(); - - /*! \brief The Relay virtual machine for to be serialized. */ - const VirtualMachine* vm_; - - /*! \brief The stream used for serialization. */ - dmlc::Stream* strm_; - - /*! \brief The serialized code. */ - std::string code_; -}; - -} // namespace vm -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_BACKEND_VM_SERIALIZER_H_ diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc new file mode 100644 index 000000000000..21f71af4eb8c --- /dev/null +++ b/src/runtime/vm/executable.cc @@ -0,0 +1,734 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/runtime/vm/executable.cc + * \brief The implementation of a virtual machine executable APIs. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "serialize_util.h" + +namespace tvm { +namespace runtime { +namespace vm { + +#define STREAM_CHECK(val, section) \ + CHECK(val) << "Invalid VM file format in the " << section << " section." 
\
+             << "\n";
+
+// Helper to serialize a vm instruction.
+VMInstructionSerializer SerializeInstruction(const Instruction& instr);
+// Helper to deserialize a serialized vm instruction.
+Instruction DeserializeInstruction(const VMInstructionSerializer& instr);
+
+PackedFunc Executable::GetFunction(const std::string& name,
+                                   const std::shared_ptr<ModuleNode>& sptr_to_self) {
+  if (name == "get_lib") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      *rv = this->GetLib();
+    });
+  } else if (name == "get_bytecode") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      *rv = this->GetBytecode();
+    });
+  } else if (name == "get_stats") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      *rv = this->Stats();
+    });
+  } else if (name == "save") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      *rv = this->Save();
+    });
+  } else {
+    LOG(FATAL) << "Unknown packed function: " << name;
+    return PackedFunc(nullptr);
+  }
+}
+
+std::string Executable::GetBytecode() const {
+  std::ostringstream oss;
+
+  for (const auto& func : functions) {
+    // Print the header of the function format.
+    oss << "# func name, reg file size, param count, inst count:"
+        << std::endl;
+    oss << func.name << " "
+        << func.register_file_size << " "
+        << func.params.size() << " "
+        << func.instructions.size() << std::endl;
+
+    // Print the params of a `VMFunction`.
+    oss << "# Parameters: " << std::endl;
+    for (const auto& param : func.params) {
+      oss << param << " ";
+    }
+    oss << std::endl;
+
+    // Print the instructions of a `VMFunction`.
+    // The part after "#" is the instruction in text format.
+    oss << "hash, opcode, fields # inst(text):" << std::endl;
+    for (const auto& instr : func.instructions) {
+      const auto& serialized_instr = SerializeInstruction(instr);
+      oss << std::hex << "0x" << serialized_instr.Hash() << " "
+          << std::dec << serialized_instr.opcode << " ";
+      for (auto it : serialized_instr.fields) {
+        oss << it << " ";
+      }
+      oss << " # " << instr;
+      if (oss.str().back() != '\n') oss << std::endl;
+    }
+  }
+
+  return oss.str();
+}
+
+std::string Executable::Stats() const {
+  std::ostringstream oss;
+  oss << "Relay VM executable statistics:" << std::endl;
+
+  // Get the number of constants and the shape of each of them.
+  oss << "  Constant shapes (# " << constants.size() << "): [";
+  for (const auto& it : constants) {
+    const auto* cell = it.as<TensorObj>();
+    CHECK(cell);
+    runtime::NDArray data = cell->data;
+    const auto& shape = data.Shape();
+
+    // Scalar
+    if (shape.empty()) {
+      oss << "scalar, ";
+      continue;
+    }
+
+    oss << "[";
+    for (auto s : shape) {
+      oss << s << ", ";
+    }
+    oss.seekp(-2, oss.cur);
+    oss << "], " << std::endl;
+  }
+  if (!constants.empty()) oss.seekp(-2, oss.cur);
+  oss << "]" << std::endl;
+
+  // Get the number of globals and the name of each of them.
+  oss << "  Globals (#" << global_map.size() << "): [";
+  for (const auto& it : global_map) {
+    oss << "(\"" << it.first << "\", " << it.second << ")" << ", ";
+  }
+  if (!global_map.empty()) oss.seekp(-2, oss.cur);
+  oss << "]" << std::endl;
+
+  // Get the number of primitive ops and the name of each of them.
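+  // Illustratively, a module with a single fused kernel ends the report with
+  // a line like `Primitive ops (#1): [fused_add]` (made-up name); entries are
+  // ordered by their position in the packed function table.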
+ oss << " Primitive ops (#" << primitive_map.size() << "): ["; + std::vector prim_ops; + for (const auto& it : primitive_map) { + auto packed_index = static_cast(it.second); + if (prim_ops.size() <= packed_index) { + prim_ops.resize(packed_index + 1); + } + prim_ops[packed_index] = it.first; + } + for (const auto& it : prim_ops) { + oss << it << ", "; + } + if (!prim_ops.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + return oss.str(); +} + +void SaveHeader(dmlc::Stream* strm) { + uint64_t header = kTVMVMBytecodeMagic; + strm->Write(header); + std::string version = TVM_VERSION; + strm->Write(version); +} + +TVMByteArray Executable::Save() { + // Initialize the stream object. + code_.clear(); + dmlc::MemoryStringStream strm(&code_); + + // Save header + SaveHeader(&strm); + + // Global section. + SaveGlobalSection(&strm); + + // Constant section. + SaveConstantSection(&strm); + + // Primitive names. + SavePrimitiveOpNames(&strm); + + // Code section. + SaveCodeSection(&strm); + + TVMByteArray arr; + arr.data = code_.c_str(); + arr.size = code_.length(); + return arr; +} + +void Executable::SaveGlobalSection(dmlc::Stream* strm) { + std::vector > globals(this->global_map.begin(), + this->global_map.end()); + auto comp = [](const std::pair& a, + const std::pair& b) { + return a.second < b.second; + }; + std::sort(globals.begin(), globals.end(), comp); + + std::vector glbs; + for (const auto& it : globals) { + glbs.push_back(it.first); + } + strm->Write(glbs); +} + +void Executable::SaveConstantSection(dmlc::Stream* strm) { + std::vector arrays; + for (const auto& obj : this->constants) { + const auto* cell = obj.as(); + CHECK(cell != nullptr); + runtime::NDArray data = cell->data; + arrays.push_back(const_cast(data.operator->())); + } + strm->Write(static_cast(this->constants.size())); + for (const auto& it : arrays) { + runtime::SaveDLTensor(strm, it); + } +} + +void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { + std::vector primitive_names; + for (const auto& it : this->primitive_map) { + auto packed_index = static_cast(it.second); + if (primitive_names.size() <= packed_index) { + primitive_names.resize(packed_index + 1); + } + primitive_names[packed_index] = it.first; + } + strm->Write(primitive_names); +} + +// Serialize a virtual machine instruction. It creates a list that contains the +// hash, opcode, and all fields of an instruction. +// +// For example, the function signature used to create an `AllocTensor` +// instruction is: +// Instruction AllocTensor(std::vector shape, DLDataType dtype, RegName dst) +// +// The serialized form will be: +// `hash 5 dtype.code dtype.bits dtype.lanes ndim dst_register val1 val2 ... valn` +// +// where hash is the hash of serialized instruction that is computed internally +// by the `VMInstructionExecutable`. It is used for sanity check before decoding. +// 5 shows opcode of `AllocTensor`, `(dtype.code dtype.bits dtype.lanes)` +// represents a `DLDataType`, `ndim` is the number of dimensions, `dst_register` +// is the destination register, and the rest of it together indicates the shape +// of the tensor to be allocated. +VMInstructionSerializer SerializeInstruction(const Instruction& instr) { + std::vector fields; + // Save the opcode. 
+  DLOG(INFO) << "Serializing: " << instr << std::endl;
+  switch (instr.op) {
+    case Opcode::Move: {
+      // Number of fields = 2
+      fields.assign({instr.from, instr.dst});
+      break;
+    }
+    case Opcode::Ret: {
+      // Number of fields = 1
+      fields.push_back(instr.result);
+      break;
+    }
+    case Opcode::Fatal: {
+      // Number of fields = 0
+      break;
+    }
+    case Opcode::InvokePacked: {
+      // Number of fields = 3 + instr.arity
+      // Note that arity includes both input arguments and outputs. We will
+      // put all the `arity` number of fields in the end for serialization.
+      fields.assign({instr.packed_index, instr.arity, instr.output_size});
+      // Save the args.
+      fields.insert(fields.end(), instr.packed_args, instr.packed_args + instr.arity);
+      break;
+    }
+    case Opcode::AllocTensor: {
+      // Number of fields = 5 + instr.alloc_tensor.ndim
+      // Save `DLDataType` and the dst register.
+      const auto& dtype = instr.alloc_tensor.dtype;
+      fields.assign({dtype.code, dtype.bits, dtype.lanes});
+
+      // The number of dimensions is not needed for constructing an
+      // `AllocTensor` instruction as it equals the length of the `shape`
+      // vector. However, we save it to conveniently deserialize the instruction
+      // because we will know how many fields are needed by the `shape` argument.
+      fields.push_back(instr.alloc_tensor.ndim);
+      fields.push_back(instr.dst);
+
+      // Save the shape of the tensor.
+      // Note that this field is rotated to the end of the list.
+      fields.insert(fields.end(), instr.alloc_tensor.shape,
+                    instr.alloc_tensor.shape + instr.alloc_tensor.ndim);
+      break;
+    }
+    case Opcode::AllocTensorReg: {
+      // Number of fields = 5
+      fields.push_back(instr.alloc_tensor_reg.shape_register);
+      // Save `DLDataType` and the dst register. The dtype comes from the
+      // `alloc_tensor_reg` member and is appended after the shape register
+      // already saved above; `fields.assign` would discard it.
+      const auto& dtype = instr.alloc_tensor_reg.dtype;
+      fields.push_back(dtype.code);
+      fields.push_back(dtype.bits);
+      fields.push_back(dtype.lanes);
+      fields.push_back(instr.dst);
+      break;
+    }
+    case Opcode::AllocDatatype: {
+      // Number of fields = 3 + instr.num_fields
+      fields.assign({instr.constructor_tag, instr.num_fields, instr.dst});
+
+      // Save the fields.
+      fields.insert(fields.end(), instr.datatype_fields,
+                    instr.datatype_fields + instr.num_fields);
+      break;
+    }
+    case Opcode::AllocClosure: {
+      // Number of fields = 3 + instr.num_freevar
+      fields.assign({instr.clo_index, instr.num_freevar, instr.dst});
+
+      // Save the free vars.
+      fields.insert(fields.end(), instr.free_vars,
+                    instr.free_vars + instr.num_freevar);
+      break;
+    }
+    case Opcode::If: {
+      // Number of fields = 4
+      fields.assign({instr.if_op.test,
+                     instr.if_op.target,
+                     instr.if_op.true_offset,
+                     instr.if_op.false_offset});
+      break;
+    }
+    case Opcode::Invoke: {
+      // Number of fields = 3 + instr.num_args
+      fields.assign({instr.func_index, instr.num_args, instr.dst});
+
+      // Save the args.
+      fields.insert(fields.end(), instr.invoke_args_registers,
+                    instr.invoke_args_registers + instr.num_args);
+      break;
+    }
+    case Opcode::InvokeClosure: {
+      // Number of fields = 3 + instr.num_closure_args
+      fields.assign({instr.closure, instr.num_closure_args, instr.dst});
+
+      // Save the args.
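+      // (Illustrative: invoking closure register 5 with argument registers 1
+      // and 2 into dst register 7 serializes the field list as `5 2 7 1 2`,
+      // i.e. the three fixed fields followed by the args.)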
+ fields.insert(fields.end(), instr.closure_args, + instr.closure_args + instr.num_closure_args); + break; + } + case Opcode::LoadConst: { + // Number of fields = 2 + fields.assign({instr.const_index, instr.dst}); + break; + } + case Opcode::LoadConsti: { + // Number of fields = 2 + fields.assign({instr.load_consti.val, instr.dst}); + break; + } + case Opcode::GetField: { + // Number of fields = 3 + fields.assign({instr.object, instr.field_index, instr.dst}); + break; + } + case Opcode::GetTag: { + // Number of fields = 2 + fields.assign({instr.get_tag.object, instr.dst}); + break; + } + case Opcode::Goto: { + // Number of fields = 1 + fields.push_back(instr.pc_offset); + break; + } + default: + LOG(FATAL) << "Invalid opcode" << static_cast(instr.op); + break; + } + + return VMInstructionSerializer(static_cast(instr.op), fields); +} + +void Executable::SaveCodeSection(dmlc::Stream* strm) { + // Save the number of functions. + strm->Write(static_cast(this->functions.size())); + for (const auto& func : this->functions) { + // Save the function info. + VMFunctionSerializer func_format(func.name, + func.register_file_size, + func.instructions.size(), + func.params); + func_format.Save(strm); + + // Serialize each instruction. + for (const auto& instr : func.instructions) { + const auto& serialized_instr = SerializeInstruction(instr); + serialized_instr.Save(strm); + } + } +} + +void LoadHeader(dmlc::Stream* strm) { + // Check header. + uint64_t header; + STREAM_CHECK(strm->Read(&header), "header"); + STREAM_CHECK(header == kTVMVMBytecodeMagic, "header"); + + // Check version. + std::string version; + STREAM_CHECK(strm->Read(&version), "version"); + STREAM_CHECK(version == TVM_VERSION, "version"); +} + +runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) { + std::shared_ptr exec = std::make_shared(); + exec->lib = lib; + exec->code_ = code; + dmlc::MemoryStringStream strm(&exec->code_); + + // Load header. + LoadHeader(&strm); + + // Global section. + exec->LoadGlobalSection(&strm); + + // Constant section. + exec->LoadConstantSection(&strm); + + // Primitive names that will be invoked by `InvokePacked` instructions. + exec->LoadPrimitiveOpNames(&strm); + + // Code section. + exec->LoadCodeSection(&strm); + + return runtime::Module(exec); +} + +void Executable::LoadGlobalSection(dmlc::Stream* strm) { + std::vector globals; + STREAM_CHECK(strm->Read(&globals), "global"); + for (size_t i = 0; i < globals.size(); i++) { + this->global_map.insert({globals[i], i}); + } +} + +void Executable::LoadConstantSection(dmlc::Stream* strm) { + uint64_t sz; + // Load the number of constants. + STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant"); + + size_t size = static_cast(sz); + // Load each of the constants. + for (size_t i = 0; i < size; i++) { + runtime::NDArray constant; + STREAM_CHECK(constant.Load(strm), "constant"); + runtime::ObjectRef obj = runtime::vm::Tensor(constant); + this->constants.push_back(obj); + } +} + +void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { + std::vector primitive_names; + STREAM_CHECK(strm->Read(&primitive_names), "primitive name"); + for (size_t i = 0; i < primitive_names.size(); i++) { + this->primitive_map.insert({primitive_names[i], i}); + } +} + +// Extract the `cnt` number of fields started at `start` from the list +// `instr_fields`. 
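+// For example, with instr_fields = {5, 2, 7, 1, 2}, ExtractFields(instr_fields, 3, 2)
+// returns {1, 2}, i.e. the trailing variable-length argument registers.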
+inline std::vector ExtractFields(const std::vector& instr_fields, + Index start, + Index cnt) { + CHECK_LE(static_cast(start + cnt), instr_fields.size()); + std::vector ret; + for (auto i = start; i < start + cnt; i++) { + ret.push_back(instr_fields[i]); + } + return ret; +} + +Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { + Opcode opcode = static_cast(instr.opcode); + switch (opcode) { + case Opcode::Move: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::Move(instr.fields[0], instr.fields[1]); + } + case Opcode::Ret: { + // Number of fields = 1 + DCHECK_EQ(instr.fields.size(), 1U); + return Instruction::Ret(instr.fields[0]); + } + case Opcode::Fatal: { + // Number of fields = 0 + DCHECK(instr.fields.empty()); + return Instruction::Fatal(); + } + case Opcode::InvokePacked: { + // Number of fields = 3 + instr.arity + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index packed_index = instr.fields[0]; + Index arity = instr.fields[1]; + Index output_size = instr.fields[2]; + std::vector args = ExtractFields(instr.fields, 3, arity); + return Instruction::InvokePacked(packed_index, arity, output_size, args); + } + case Opcode::AllocTensor: { + // Number of fields = 5 + instr.alloc_tensor.ndim + DCHECK_GE(instr.fields.size(), 5U); + DCHECK_EQ(instr.fields.size(), 5U + static_cast(instr.fields[3])); + + DLDataType dtype; + dtype.code = instr.fields[0]; + dtype.bits = instr.fields[1]; + dtype.lanes = instr.fields[2]; + + Index ndim = instr.fields[3]; + RegName dst = instr.fields[4]; + + std::vector shape = ExtractFields(instr.fields, 5, ndim); + + return Instruction::AllocTensor(shape, dtype, dst); + } + case Opcode::AllocTensorReg: { + // Number of fields = 5 + DCHECK_EQ(instr.fields.size(), 5U); + Index shape_register = instr.fields[0]; + + DLDataType dtype; + dtype.code = instr.fields[1]; + dtype.bits = instr.fields[2]; + dtype.lanes = instr.fields[3]; + + RegName dst = instr.fields[4]; + + return Instruction::AllocTensorReg(shape_register, dtype, dst); + } + case Opcode::AllocDatatype: { + // Number of fields = 3 + instr.num_fields + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index constructor_tag = instr.fields[0]; + Index num_fields = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector fields = ExtractFields(instr.fields, 3, num_fields); + + return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst); + } + case Opcode::AllocClosure: { + // Number of fields = 3 + instr.num_freevar + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index clo_index = instr.fields[0]; + Index num_freevar = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector free_vars = ExtractFields(instr.fields, 3, num_freevar); + + return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst); + } + case Opcode::If: { + // Number of fields = 4 + DCHECK_EQ(instr.fields.size(), 4U); + Index test = instr.fields[0]; + Index target = instr.fields[1]; + Index true_offset = instr.fields[2]; + Index false_offset = instr.fields[3]; + + return Instruction::If(test, target, true_offset, false_offset); + } + case Opcode::Invoke: { + // Number of fields = 3 + instr.num_args + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index func_index = instr.fields[0]; + Index num_args = 
instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector args = ExtractFields(instr.fields, 3, num_args); + + return Instruction::Invoke(func_index, args, dst); + } + case Opcode::InvokeClosure: { + // Number of fields = 3 + instr.num_closure_args + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index closure = instr.fields[0]; + Index num_closure_args = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector args = ExtractFields(instr.fields, 3, num_closure_args); + + return Instruction::InvokeClosure(closure, args, dst); + } + case Opcode::LoadConst: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::LoadConst(instr.fields[0], instr.fields[1]); + } + case Opcode::LoadConsti: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::LoadConsti(instr.fields[0], instr.fields[1]); + } + case Opcode::GetField: { + // Number of fields = 3 + DCHECK_EQ(instr.fields.size(), 3U); + return Instruction::GetField(instr.fields[0], instr.fields[1], instr.fields[2]); + } + case Opcode::GetTag: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::GetTag(instr.fields[0], instr.fields[1]); + } + case Opcode::Goto: { + // Number of fields = 1 + DCHECK_EQ(instr.fields.size(), 1U); + return Instruction::Goto(instr.fields[0]); + } + default: + LOG(FATAL) << "Invalid opcode" << instr.opcode; + return Instruction(); + } +} + +void Executable::LoadCodeSection(dmlc::Stream* strm) { + // Load the number of functions. + uint64_t sz; + STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "code"); + + size_t num_funcs = static_cast(sz); + this->functions.resize(num_funcs); + for (size_t i = 0; i < num_funcs; i++) { + // Load the function info. + VMFunctionSerializer loaded_func; + STREAM_CHECK(loaded_func.Load(strm), "code/function"); + + // Load the instructions. + std::vector instructions; + for (size_t j = 0; j < loaded_func.num_instructions; j++) { + VMInstructionSerializer instr; + std::vector instr_fields; + STREAM_CHECK(instr.Load(strm), "code/instruction"); + instructions.push_back(DeserializeInstruction(instr)); + } + + // Create the VM function. 
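+    // (The function is stored at the slot its name occupies in global_map,
+    // so the function indices baked into Invoke instructions stay valid.)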
+ VMFunction vm_func = VMFunction(loaded_func.name, + loaded_func.params, + instructions, + loaded_func.register_file_size); + auto it = this->global_map.find(loaded_func.name); + CHECK(it != this->global_map.end()); + CHECK_LE(it->second, this->global_map.size()); + this->functions[it->second] = vm_func; + } +} + +TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec); + *rv = static_cast(exec->global_map.size()); +}); + +TVM_REGISTER_GLOBAL("relay._vm.GetGlobalFields") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec); + int idx = args[1]; + std::vector > globals(exec->global_map.begin(), + exec->global_map.end()); + auto comp = [](const std::pair& a, + const std::pair& b) { + return a.second < b.second; + }; + std::sort(globals.begin(), globals.end(), comp); + CHECK_LT(idx, globals.size()); + *rv = globals[idx].first; +}); + +TVM_REGISTER_GLOBAL("relay._vm.GetNumOfPrimitives") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec); + *rv = static_cast(exec->primitive_map.size()); +}); + + +TVM_REGISTER_GLOBAL("relay._vm.GetPrimitiveFields") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec); + int idx = args[1]; + CHECK_GE(idx, 0); + CHECK_LT(idx, exec->primitive_map.size()); + + for (const auto& it : exec->primitive_map) { + if (idx == static_cast(it.second)) { + *rv = it.first; + break; + } + } +}); + +TVM_REGISTER_GLOBAL("relay._vm.Load_Executable") +.set_body_typed([]( + std::string code, + runtime::Module lib) { + return Executable::Load(code, lib); +}); + +} // namespace vm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 80e0ce57a8ae..821de0bda245 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -85,19 +85,25 @@ PackedFunc VirtualMachineDebug::GetFunction( } } -void VirtualMachineDebug::Init(const std::vector& ctxs) { - VirtualMachine::Init(ctxs); - for (auto kv : primitive_map) { +void VirtualMachineDebug::LoadExecutable(const Executable* exec) { + VirtualMachine::LoadExecutable(exec); + CHECK(this->exec); + for (auto kv : this->exec->primitive_map) { packed_index_map[kv.second] = kv.first; op_invokes[kv.second] = 0; } } +void VirtualMachineDebug::Init(const std::vector& ctxs) { + VirtualMachine::Init(ctxs); +} + void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) { - auto ctx = VirtualMachine::GetParamsContext(); + CHECK(this->exec); + auto ctx = this->GetParamsContext(); // warmup VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args); @@ -117,6 +123,21 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, op_invokes[packed_index] += 1; } +runtime::Module CreateVirtualMachineDebug(const Executable* exec) { + std::shared_ptr vm = std::make_shared(); + vm->LoadExecutable(exec); + return runtime::Module(vm); +} + +TVM_REGISTER_GLOBAL("relay._vm._VirtualMachineDebug") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec) << "Virtual machine 
has not been defined yet." + << "\n"; + *rv = CreateVirtualMachineDebug(exec); +}); + } // namespace vm } // namespace runtime } // namespace tvm diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index 447967cafeb0..ff3296cb6c16 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -47,6 +47,8 @@ class VirtualMachineDebug : public VirtualMachine { void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) final; + void LoadExecutable(const Executable* exec); + ~VirtualMachineDebug() {} private: diff --git a/src/relay/backend/vm/serialize_util.h b/src/runtime/vm/serialize_util.h similarity index 95% rename from src/relay/backend/vm/serialize_util.h rename to src/runtime/vm/serialize_util.h index 3e7508ebee9b..3931f2f0e023 100644 --- a/src/relay/backend/vm/serialize_util.h +++ b/src/runtime/vm/serialize_util.h @@ -19,11 +19,11 @@ /*! * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/serialize_util.h + * \file src/runtime/vm/serialize_util.h * \brief Definitions of helpers for serializing and deserializing a Relay VM. */ -#ifndef TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ -#define TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ +#ifndef TVM_RUNTIME_VM_SERIALIZE_UTIL_H_ +#define TVM_RUNTIME_VM_SERIALIZE_UTIL_H_ #include #include @@ -34,7 +34,7 @@ #include namespace tvm { -namespace relay { +namespace runtime { namespace vm { /*! \brief The magic number for the serialized VM bytecode file */ @@ -158,7 +158,7 @@ struct VMInstructionSerializer { }; } // namespace vm -} // namespace relay +} // namespace runtime } // namespace tvm -#endif // TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ +#endif // TVM_RUNTIME_VM_SERIALIZE_UTIL_H_ diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 7dea9bdb95ea..78b74768b930 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -575,11 +575,12 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, const std::shared_ptr& sptr_to_self) { if (name == "invoke") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK(exec) << "The executable is not created yet."; std::string func_name = args[0]; - auto gvit = this->global_map.find(func_name); - CHECK(gvit != this->global_map.end()) << "Cannot find function " << func_name; + auto gvit = exec->global_map.find(func_name); + CHECK(gvit != exec->global_map.end()) << "Cannot find function " << func_name; auto func_index = gvit->second; - const auto& vm_func = this->functions[func_index]; + const auto& vm_func = exec->functions[func_index]; const auto& param_names = vm_func.params; auto ctx = this->GetParamsContext(); @@ -617,10 +618,6 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, } this->Init(contexts); }); - } else if (name == "load_params") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->LoadParams(args[0].operator std::string()); - }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); @@ -628,43 +625,20 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, } TVMContext VirtualMachine::GetParamsContext() const { + CHECK(!ctxs.empty()) << "Context has not been initialized yet." + << "\n"; + // Use the fallback device if no device index is available. 
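+  // For example, with ctxs = {cpu(0)} every parameter lands on the CPU, while
+  // with {gpu(0), cpu(0)} the first entry's device type (GPU) is preferred.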
int fallback_device_type = static_cast(ctxs[0].device_type); // TODO(wweic): For heterogeneous execution, get device information from byte const auto& cit = - std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) { - return fallback_device_type == static_cast(c.device_type); - }); + std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) { + return fallback_device_type == static_cast(c.device_type); + }); return (cit == ctxs.end() ? ctxs[0] : *cit); } -void VirtualMachine::LoadParams(const std::string& params) { - dmlc::MemoryStringStream mss(const_cast(¶ms)); - dmlc::Stream* strm = &mss; - uint64_t header, reserved; - CHECK(strm->Read(&header)) << "Invalid parameter file"; - CHECK(header == kTVMNDArrayListMagic) << "Invalid parameter file"; - CHECK(strm->Read(&reserved)) << "Invalid parameter file"; - - std::vector names; - CHECK(strm->Read(&names)) << "Invalid parameter file"; - - uint64_t sz; - strm->Read(&sz); - size_t size = static_cast(sz); - CHECK(size == names.size()) << "Invalid parameter file"; - - auto ctx = GetParamsContext(); - for (size_t i = 0; i < size; i++) { - NDArray arr; - CHECK(arr.Load(strm)) << "Invalid parameter file"; - ObjectRef obj = Tensor(arr); - auto copy = CopyTo(obj, ctx); - params_.emplace(std::make_pair(names[i], copy)); - } -} - void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) { auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size); frames.push_back(frame); @@ -699,15 +673,17 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vectorGetAllocator(ctxs[0]); DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B"; return return_register; } ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector& args) { - auto func_index = this->global_map[name]; + CHECK(exec) << "The executable has not been created yet."; + auto func_index = exec->global_map.at(name); DLOG(INFO) << "Invoke Global " << name << " at index " << func_index; - return Invoke(this->functions[func_index], args); + return Invoke(exec->functions[func_index], args); } void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, @@ -744,14 +720,16 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); } -void VirtualMachine::Init(const std::vector& ctxs) { - this->ctxs = ctxs; +void VirtualMachine::LoadExecutable(const Executable* exec) { + CHECK(exec) << "The executable is not created yet."; + this->exec = exec; + runtime::Module lib = this->exec->lib; // Get the list of packed functions. 
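+  // Each primitive is looked up in `lib` by name and stored at the slot given
+  // by its packed_index, so InvokePacked can dispatch on a plain integer.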
- CHECK(primitive_map.empty() || lib.operator->()) + CHECK(exec->primitive_map.empty() || lib.operator->()) << "runtime module should have been built for primitive functions" << "\n"; - for (const auto& it : primitive_map) { + for (const auto& it : this->exec->primitive_map) { const auto& packed_name = it.first; auto packed_index = static_cast(it.second); if (packed_funcs.size() <= packed_index) { @@ -761,6 +739,11 @@ void VirtualMachine::Init(const std::vector& ctxs) { } } + +void VirtualMachine::Init(const std::vector& ctxs) { + this->ctxs = ctxs; +} + inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) { frames.back().register_file[r] = val; } @@ -788,6 +771,7 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const { void VirtualMachine::RunLoop() { CHECK(this->code); + CHECK(this->exec); this->pc = 0; Index frame_start = frames.size(); while (true) { @@ -810,7 +794,8 @@ void VirtualMachine::RunLoop() { throw std::runtime_error("VM encountered fatal error"); } case Opcode::LoadConst: { - auto constant_obj = this->constants[instr.const_index]; + auto constant_obj = exec->constants[instr.const_index]; + // TODO(wweic) ctx could be obtained from the ctxs list. auto device_obj = CopyTo(constant_obj, ctxs[0]); WriteRegister(instr.dst, device_obj); pc++; @@ -828,7 +813,7 @@ void VirtualMachine::RunLoop() { for (Index i = 0; i < instr.num_args; ++i) { args.push_back(ReadRegister(instr.invoke_args_registers[i])); } - InvokeGlobal(this->functions[instr.func_index], args); + InvokeGlobal(exec->functions[instr.func_index], args); frames.back().caller_return_register = instr.dst; goto main_loop; } @@ -858,7 +843,7 @@ void VirtualMachine::RunLoop() { for (Index i = 0; i < instr.num_closure_args; ++i) { args.push_back(ReadRegister(instr.closure_args[i])); } - InvokeGlobal(this->functions[closure->func_index], args); + InvokeGlobal(exec->functions[closure->func_index], args); frames.back().caller_return_register = instr.dst; goto main_loop; } @@ -910,6 +895,7 @@ void VirtualMachine::RunLoop() { for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) { shape[i] = instr.alloc_tensor.shape[i]; } + // TODO(wweic) ctx could be obtained from the ctxs list. auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]); auto obj = Tensor(data); @@ -931,6 +917,7 @@ void VirtualMachine::RunLoop() { auto num_dims = shape_tensor->shape[0]; auto shape = std::vector(shape_tensor->shape[0]); shape.assign(dims, dims + num_dims); + // TODO(wweic) ctx could be obtained from the ctxs list. auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]); auto obj = Tensor(data); @@ -976,6 +963,21 @@ void VirtualMachine::RunLoop() { } } +runtime::Module CreateVirtualMachine(const Executable* exec) { + std::shared_ptr vm = std::make_shared(); + vm->LoadExecutable(exec); + return runtime::Module(vm); +} + +TVM_REGISTER_GLOBAL("relay._vm._VirtualMachine") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec) << "The virtual machine executable has not been defined yet." 
+ << "\n"; + *rv = CreateVirtualMachine(exec); +}); + } // namespace vm } // namespace runtime } // namespace tvm diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index cedbc4f71859..1b40f894db08 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -47,14 +47,16 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): if isinstance(f, relay.Expr): mod = relay.Module() mod["main"] = f - vm = relay.vm.compile(mod, target) - vm.init(tvm.cpu()) + exe = relay.vm.compile(mod, target) + vm = relay.vm.VirtualMachine(exe) + vm.init(ctx) return vm.invoke("main", *args) else: assert isinstance(f, relay.Module), "expected expression or module" mod = f - vm = relay.vm.compile(mod, target) - vm.init(tvm.cpu()) + exe = relay.vm.compile(mod, target) + vm = relay.vm.VirtualMachine(exe) + vm.init(ctx) ret = vm.invoke("main", *args) return ret @@ -573,25 +575,6 @@ def test_add_op_broadcast(): mod["main"] = func check_result([x_data, y_data], x_data + y_data, mod=mod) -def test_set_params(): - mod = relay.Module() - x = relay.var('x', shape=(10, 5)) - w = relay.var('w', shape=(6, 5)) - b = relay.var('b', shape=(6,)) - y = relay.nn.bias_add(relay.nn.dense(x, w), b) - mod["main"] = relay.Function([x, w, b], y) - vm = relay.vm.compile(mod, 'llvm') - vm.init(tvm.cpu()) - - x_np = np.random.uniform(size=(10, 5)).astype('float32') - w_np = np.random.uniform(size=(6, 5)).astype('float32') - b_np = np.random.uniform(size=(6,)).astype('float32') - ref_np = np.dot(x_np, w_np.T) + b_np - params = {'w': w_np} - vm.load_params(params) - out = vm.run(x_np, b_np) - tvm.testing.assert_allclose(out.asnumpy(), ref_np) - if __name__ == "__main__": test_id() @@ -626,4 +609,3 @@ def test_set_params(): test_add_op_scalar() test_add_op_tensor() test_add_op_broadcast() - test_set_params() diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index 3a317fc2d111..014648099aeb 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -22,29 +22,25 @@ from tvm import relay from tvm.relay.module import Module as rly_module from tvm.relay import vm as _vm -from tvm.relay import serializer, deserializer from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.prelude import Prelude from tvm.contrib import util from tvm.relay import testing -def create_vm(f, ctx=tvm.cpu(), target="llvm", params=None): +def create_exec(f, target="llvm", params=None): if isinstance(f, relay.Expr): mod = relay.Module() mod["main"] = f - vm = _vm.compile(mod, target=target, params=params) - vm.init(ctx) - return vm + executable = _vm.compile(mod, target=target, params=params) + return executable else: assert isinstance(f, relay.Module), "expected mod as relay.Module" - vm = _vm.compile(f, target=target, params=params) - vm.init(ctx) - return vm + executable = _vm.compile(f, target=target, params=params) + return executable def veval(vm, *args, ctx=tvm.cpu()): assert isinstance(vm, _vm.VirtualMachine), "expected VirtualMachine" - vm.init(ctx) ret = vm.run(*args) return ret @@ -59,13 +55,11 @@ def get_vm_output(mod, data, params, target, ctx, dtype='float32'): return result.asnumpy().astype(dtype) def get_serialized_output(mod, data, params, target, ctx, dtype='float32'): - vm = create_vm(mod, ctx, target, params=params) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(mod, target, 
params=params) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) des_vm.init(ctx) - des_vm.load_params(params) result = des_vm.run(data) return result.asnumpy().astype(dtype) @@ -99,26 +93,25 @@ def test_serializer(): main = relay.Function([x1, y1], glb_f1(x1) * glb_f2(y1)) mod["main"] = main - vm = create_vm(mod) - ser = serializer.Serializer(vm) + exe = create_exec(mod) - glbs = ser.globals + glbs = exe.globals assert len(glbs) == 3 assert "f1" in glbs assert "f2" in glbs assert "main" in glbs - prim_ops = ser.primitive_ops + prim_ops = exe.primitive_ops assert any(item.startswith('fused_add') for item in prim_ops) assert any(item.startswith('fused_subtract') for item in prim_ops) assert any(item.startswith('fused_multiply') for item in prim_ops) - code = ser.bytecode + code = exe.bytecode assert "main 5 2 5" in code assert "f1 2 1 3" in code assert "f2 2 1 3" in code - code, lib = ser.serialize() + code, lib = exe.save() assert isinstance(code, bytearray) assert isinstance(lib, tvm.module.Module) @@ -129,24 +122,24 @@ def test_save_load(): x_data = np.random.rand(10, 10).astype('float32') # serialize. - vm = create_vm(f) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() + vm = create_exec(f) + code, lib = vm.save() assert isinstance(code, bytearray) # save and load the code and lib file. tmp = util.tempdir() path_lib = tmp.relpath("lib.so") lib.export_library(path_lib) - with open(tmp.relpath("code.bc"), "wb") as fo: + with open(tmp.relpath("code.ro"), "wb") as fo: fo.write(code) loaded_lib = tvm.module.load(path_lib) - loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read()) + loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read()) # deserialize. - deser = deserializer.Deserializer(loaded_code, loaded_lib) - des_vm = deser.deserialize() + des_exec = _vm.Executable.load_exec(loaded_code, loaded_lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) res = veval(des_vm, x_data) tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data) @@ -156,12 +149,12 @@ def test_const(): c = relay.const(1.0, "float32") x = relay.var('x', shape=(10, 10), dtype='float32') f = relay.Function([x], x + c) - vm = create_vm(f) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() + exe = create_exec(f) + code, lib = exe.save() assert isinstance(code, bytearray) - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) x_data = np.random.rand(10, 10).astype('float32') res = veval(des_vm, x_data) tvm.testing.assert_allclose(res.asnumpy(), x_data + 1) @@ -177,11 +170,11 @@ def test_if(): x_data = np.random.rand(10, 10).astype('float32') y_data = np.random.rand(10, 10).astype('float32') - vm = create_vm(f) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(f) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) # same res = veval(des_vm, x_data, x_data) @@ -213,11 +206,11 @@ def test_loop(): aarg = relay.var('accum', shape=[], dtype='int32') mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg)) - vm = create_vm(mod) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() 
+ exe = create_exec(mod) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) result = veval(des_vm, i_data, accum_data) tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1))) @@ -230,11 +223,11 @@ def test_tuple(): i_data = np.random.rand(41).astype('float32') j_data = np.random.rand(10).astype('float32') - vm = create_vm(f) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(f) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) result = veval(des_vm, (i_data, j_data)) tvm.testing.assert_allclose(result.asnumpy(), j_data) @@ -251,11 +244,11 @@ def test_adt_list(): f = relay.Function([], l321) mod["main"] = f - vm = create_vm(mod) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(mod) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) result = veval(des_vm) assert len(result) == 2 @@ -297,11 +290,11 @@ def test_adt_compose(): f = relay.Function([y], add_two_body) mod["main"] = f - vm = create_vm(mod) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(mod) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) x_data = np.array(np.random.rand()).astype('float32') result = veval(des_vm, x_data) @@ -317,11 +310,11 @@ def test_closure(): clo = ff(relay.const(1.0)) main = clo(relay.const(2.0)) - vm = create_vm(main) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(main) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) res = veval(des_vm) tvm.testing.assert_allclose(res.asnumpy(), 3.0) diff --git a/tests/python/unittest/test_runtime_vm_profiler.py b/tests/python/unittest/test_runtime_vm_profiler.py index b5ce0ec70e51..53f573730576 100644 --- a/tests/python/unittest/test_runtime_vm_profiler.py +++ b/tests/python/unittest/test_runtime_vm_profiler.py @@ -26,9 +26,9 @@ def test_basic(): mod, params = resnet.get_workload() target = 'llvm' ctx = tvm.cpu() - vm = relay.profiler_vm.compile(mod, target) + exe = relay.profiler_vm.compile(mod, target, params=params) + vm = relay.profiler_vm.VirtualMachineProfiler(exe) vm.init(ctx) - vm.load_params(params) data = np.random.rand(1, 3, 224, 224).astype('float32') res = vm.invoke("main", [data]) From 5faa6f70d7d70e56cb3f44ff3f4e5699be287a3e Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Thu, 17 Oct 2019 22:41:34 -0700 Subject: [PATCH 02/59] [Relay][Frontend][TF] Add tensor array ops (#3798) * [Relay][Frontend][TF] Add tensor array ops * rename * delete test * Move utility function * Refactor * fix tensor array ops * fix test * fix rebase * Fix serializer bug * Improve tf convert name lookup to use prelude api * Fix lint * Fix test --- python/tvm/relay/frontend/tensorflow.py | 82 ++- python/tvm/relay/op/_tensor.py | 26 + python/tvm/relay/prelude.py | 520 ++++++++++++++++++ 
python/tvm/relay/testing/py_converter.py | 8 +- src/runtime/vm/executable.cc | 4 +- .../frontend/tensorflow/test_forward.py | 118 +++- tests/python/relay/test_adt.py | 148 +++++ tests/python/relay/test_feature.py | 3 +- 8 files changed, 899 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 38f9c523e0b1..eb67cf24b81e 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -22,10 +22,14 @@ import warnings from collections import defaultdict + # Numpy support import numpy as np import tvm + +from tvm.relay.prelude import Prelude + from .. import analysis from .. import expr as _expr from .. import op as _op @@ -508,6 +512,69 @@ def _impl(inputs, attr, params): return _op.concatenate(inputs_reshaped, axis) return _impl +def _tensor_array(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get('dtype').name + tensor_array_constructor = prelude.get_var('tensor_array', dtype_str) + return tensor_array_constructor(_op.take(inputs[0], tvm.relay.const(0))) + return _impl + +def _tensor_array_scatter(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get('T').name + values_rank = len(inputs[2].type_annotation.shape) + unstack_name = "tensor_array_unstack_tensor{}".format(values_rank) + unstack_function = prelude.get_var(unstack_name, dtype_str) + values = unstack_function(inputs[2]) + tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str) + return tensor_array_scatter_func(inputs[0], inputs[1], values) + return _impl + +def _tensor_array_gather(): + def _impl(inputs, attr, params, prelude): + return prelude.tensor_array_gather(inputs[2], inputs[1]) + return _impl + +def _tensor_array_size(): + def _impl(inputs, attr, params, prelude): + return prelude.length(inputs[0]) + return _impl + +def _tensor_array_write(): + def _impl(inputs, attr, params, prelude): + input_rank = len(inputs[2].type_annotation.shape) + dtype = attr.get('T').name + + tensor_name = 'tensor{}'.format(input_rank) + tensor_func = prelude.get_var(tensor_name, dtype) + v = tensor_func(inputs[2]) + write_func = prelude.get_var('tensor_array_write', dtype) + + return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v) + return _impl + +def _tensor_array_read(): + def _impl(inputs, attr, params, prelude): + read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name) + return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) + return _impl + +def _tensor_array_split(): + def _impl(inputs, attr, params, prelude): + input_rank = len(inputs[1].type_annotation.shape) + dtype_str = attr.get('T').name + v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1]) + lengths = _op.cast(inputs[2], 'int32') + split_var = prelude.get_var('tensor_array_split', dtype_str) + return split_var(inputs[0], v, lengths) + return _impl + +def _tensor_array_concat(): + def _impl(inputs, attr, params, prelude): + concat_func = prelude.get_var('tensor_array_concat', attr['dtype'].name) + return concat_func(inputs[1]) + return _impl + def _tile(): def _impl(inputs, attr, params): reps = _get_list_param(params, inputs.pop()) @@ -1313,6 +1380,14 @@ def _impl(inputs, attr, params): 'NotEqual' : _broadcast('not_equal'), 'OneHot' : _one_hot(), 'Pack' : _pack(), + 'TensorArrayV3' : _tensor_array(), + 'TensorArrayScatterV3' : _tensor_array_scatter(), + 'TensorArrayGatherV3' : _tensor_array_gather(), + 'TensorArraySizeV3' : 
_tensor_array_size(), + 'TensorArrayWriteV3' : _tensor_array_write(), + 'TensorArrayReadV3' : _tensor_array_read(), + 'TensorArraySplitV3' : _tensor_array_split(), + 'TensorArrayConcatV3' : _tensor_array_concat(), 'Pad' : _pad('Pad'), 'PadV2' : _pad('PadV2'), 'Pow' : _elemwise('power'), @@ -1860,6 +1935,7 @@ def __init__(self): self._loops = {} self._branches = {} self._mod = _module.Module({}) + self._prelude = Prelude(self._mod) def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): """Construct relay nodes from tensorflow graph definition - GraphDef. @@ -2335,7 +2411,11 @@ def _convert_operator(self, op_name, inputs, attrs, if op_name in identity_list: sym = get_relay_op(op_name)(*inputs, **attrs) elif op_name in convert_map: - sym = convert_map[op_name](inputs, attrs, self._params) + if 'TensorArray' in op_name: + sym = convert_map[op_name](inputs, attrs, self._params, self._prelude) + else: + sym = convert_map[op_name](inputs, attrs, self._params) + elif op_name in convert_map_rnn: sym = self._convert_rnn_operator(op_name, inputs, attrs, self._params, graph, diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index da5804906269..188b3bb15956 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -108,6 +108,29 @@ def clip_compute(attrs, inputs, output_type, target): register_schedule("clip", schedule_elemwise) +@script +def _cast_shape_function(x): + out_ndim = len(x) + out = output_tensor((out_ndim,), "int64") + for i in const_range(out_ndim): + out[i] = x[i] + return out + +def cast_shape_func(attrs, inputs, out_ndims): + return [_cast_shape_function(*inputs)] + +@script +def _expand_dims_shape_func(x): + ndim = len(x.shape) + out = output_tensor((ndim+1,), "int64") + out[0] = int64(1) + for i in const_range(0, ndim): + out[i+1] = int64(x.shape[i]) + return out + +def expand_dims_shape_func(attrs, inputs, out_ndims): + return [_expand_dims_shape_func(*inputs)] + # shape func @script def _broadcast_shape_func(x, y, ndim): @@ -140,6 +163,9 @@ def _broadcast_shape_func(x, y, ndim): def broadcast_shape_func(attrs, inputs, out_ndims): return [_broadcast_shape_func(*inputs, out_ndims[0])] +register_shape_func("expand_dims", False, expand_dims_shape_func) +register_shape_func("cast", False, cast_shape_func) + register_shape_func("add", False, broadcast_shape_func) register_shape_func("subtract", False, broadcast_shape_func) register_shape_func("multiply", False, broadcast_shape_func) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 803d8ef50db5..d27ffe512617 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -16,8 +16,513 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """A prelude containing useful global functions and ADT definitions.""" +from .ty import GlobalTypeVar, TensorType, Any, scalar_type +from .expr import Var, Function, GlobalVar, If, const +from .op.tensor import add, subtract, equal +from .adt import Constructor, TypeData, Clause, Match +from .adt import PatternConstructor, PatternVar, PatternWildcard +from . 
import op from .module import Module +class TensorArrayOps(object): + """Contains tensor array related ops""" + + def __init__(self, prelude, dtype): + """Create tensor array ops registry""" + self.prelude = prelude + self.dtype = dtype + + def get_name(self, canonical): + """Get name corresponding to the caninical name""" + return self.prelude.get_name(canonical, self.dtype) + + def get_var(self, canonical): + """Get var corresponding to the caninical name""" + return self.prelude.get_var(canonical, self.dtype) + + def define_tensor_adt(self): + """Defines the dynamic tensor ADT, which is the container for tensors + with variable shapes.""" + tensor_type_name = self.get_name('tensor_t') + tensor_type_var = GlobalTypeVar(tensor_type_name) + setattr(self.prelude, tensor_type_name, tensor_type_var) + tensor0_type = TensorType([], self.dtype) + tensor1_type = TensorType([Any()], self.dtype) + tensor2_type = TensorType([Any(), Any()], self.dtype) + tensor3_type = TensorType([Any(), Any(), Any()], self.dtype) + tensor4_type = TensorType([Any(), Any(), Any(), Any()], self.dtype) + tensor5_type = TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype) + tensor6_type = TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype) + tensor_nil_name = self.get_name('tensor_nil') + tensor0_name = self.get_name('tensor0') + tensor1_name = self.get_name('tensor1') + tensor2_name = self.get_name('tensor2') + tensor3_name = self.get_name('tensor3') + tensor4_name = self.get_name('tensor4') + tensor5_name = self.get_name('tensor5') + tensor6_name = self.get_name('tensor6') + tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var) + tensor0_case = Constructor(tensor0_name, [tensor0_type], tensor_type_var) + tensor1_case = Constructor(tensor1_name, [tensor1_type], tensor_type_var) + tensor2_case = Constructor(tensor2_name, [tensor2_type], tensor_type_var) + tensor3_case = Constructor(tensor3_name, [tensor3_type], tensor_type_var) + tensor4_case = Constructor(tensor4_name, [tensor4_type], tensor_type_var) + tensor5_case = Constructor(tensor5_name, [tensor5_type], tensor_type_var) + tensor6_case = Constructor(tensor6_name, [tensor6_type], tensor_type_var) + setattr(self.prelude, tensor_nil_name, tensor_nil_case) + setattr(self.prelude, tensor0_name, tensor0_case) + setattr(self.prelude, tensor1_name, tensor1_case) + setattr(self.prelude, tensor2_name, tensor2_case) + setattr(self.prelude, tensor3_name, tensor3_case) + setattr(self.prelude, tensor4_name, tensor4_case) + setattr(self.prelude, tensor5_name, tensor5_case) + setattr(self.prelude, tensor6_name, tensor6_case) + self.prelude.mod[tensor_type_var] = TypeData(tensor_type_var, [], [tensor_nil_case, + tensor0_case, + tensor1_case, + tensor2_case, + tensor3_case, + tensor4_case, + tensor5_case, + tensor6_case]) + + def define_tensor_take(self): + """Defines a function to return a range of tensor_t on axis 0. 
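+        Semantically this is `op.take` with an `arange(lower, upper)` index on
+        axis 0, lifted over every tensor constructor of rank 1 through 6.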
+ tensor_take(t, lower, upper) : + tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t + """ + take_name = self.get_name("tensor_take") + take_var = GlobalVar(take_name) + setattr(self.prelude, take_name, take_var) + tensor_t = self.get_var('tensor_t') + tensor1_var = self.get_var('tensor1') + tensor2_var = self.get_var('tensor2') + tensor3_var = self.get_var('tensor3') + tensor4_var = self.get_var('tensor4') + tensor5_var = self.get_var('tensor5') + tensor6_var = self.get_var('tensor6') + t = Var('tensor', tensor_t()) + lower = Var('lower', scalar_type('int32')) + upper = Var('upper', scalar_type('int32')) + t1 = Var('t1') + t2 = Var('t2') + t3 = Var('t3') + t4 = Var('t4') + t5 = Var('t5') + t6 = Var('t6') + tensor1_case =\ + Clause(PatternConstructor(tensor1_var, [PatternVar(t1)]), + tensor1_var(op.take(t1, op.arange(lower, upper, dtype='int32')))) + tensor2_case =\ + Clause(PatternConstructor(tensor2_var, [PatternVar(t2)]), + tensor2_var(op.take(t2, op.arange(lower, upper, dtype='int32'), axis=0))) + tensor3_case =\ + Clause(PatternConstructor(tensor3_var, [PatternVar(t3)]), + tensor3_var(op.take(t3, op.arange(lower, upper, dtype='int32'), axis=0))) + tensor4_case =\ + Clause(PatternConstructor(tensor4_var, [PatternVar(t4)]), + tensor4_var(op.take(t4, op.arange(lower, upper, dtype='int32'), axis=0))) + tensor5_case =\ + Clause(PatternConstructor(tensor5_var, [PatternVar(t5)]), + tensor5_var(op.take(t5, op.arange(lower, upper, dtype='int32'), axis=0))) + tensor6_case =\ + Clause(PatternConstructor(tensor6_var, [PatternVar(t6)]), + tensor6_var(op.take(t6, op.arange(lower, upper, dtype='int32'), axis=0))) + self.prelude.mod[take_var] =\ + Function([t, lower, upper], + Match(t, [tensor1_case, + tensor2_case, + tensor3_case, + tensor4_case, + tensor5_case, + tensor6_case], False), + tensor_t(), []) + + def define_tensor_expand_dims(self): + """Defines a function to grow a tensor_t's rank by adding one dimension in front + of the original tensor_t. 
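+
+        A rank-1 value of shape (n,) becomes a rank-2 value of shape (1, n),
+        mirroring numpy.expand_dims(x, axis=0). Usage sketch (illustrative
+        only; it matches the tests added later in this patch):
+
+            expand_dims = p.get_var('tensor_expand_dims', 'float32')
+            tensor1 = p.get_var('tensor1', 'float32')
+            y = expand_dims(tensor1(x))  # rank grows from 1 to 2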
+ tensor_expand_dims(t) : tensor_t -> tensor_t + """ + expand_dims_name = self.get_name("tensor_expand_dims") + expand_dims_var = GlobalVar(expand_dims_name) + setattr(self.prelude, expand_dims_name, expand_dims_var) + tensor_type_var = self.get_var('tensor_t') + x = Var("x", tensor_type_var()) + t0 = Var("t0") + t1 = Var("t1") + t2 = Var("t2") + t3 = Var("t3") + t4 = Var("t4") + t5 = Var("t5") + tensor0_var = self.get_var('tensor0') + tensor1_var = self.get_var('tensor1') + tensor2_var = self.get_var('tensor2') + tensor3_var = self.get_var('tensor3') + tensor4_var = self.get_var('tensor4') + tensor5_var = self.get_var('tensor5') + tensor6_var = self.get_var('tensor6') + tensor0_case = Clause(PatternConstructor(tensor0_var, [PatternVar(t0)]), + tensor1_var(op.expand_dims(t0, 0, 1))) + tensor1_case = Clause(PatternConstructor(tensor1_var, [PatternVar(t1)]), + tensor2_var(op.expand_dims(t1, 0, 1))) + tensor2_case = Clause(PatternConstructor(tensor2_var, [PatternVar(t2)]), + tensor3_var(op.expand_dims(t2, 0, 1))) + tensor3_case = Clause(PatternConstructor(tensor3_var, [PatternVar(t3)]), + tensor4_var(op.expand_dims(t3, 0, 1))) + tensor4_case = Clause(PatternConstructor(tensor4_var, [PatternVar(t4)]), + tensor5_var(op.expand_dims(t4, 0, 1))) + tensor5_case = Clause(PatternConstructor(tensor5_var, [PatternVar(t5)]), + tensor6_var(op.expand_dims(t5, 0, 1))) + self.prelude.mod[expand_dims_var] =\ + Function([x], + Match(x, [tensor0_case, + tensor1_case, + tensor2_case, + tensor3_case, + tensor4_case, + tensor5_case], False)) + + def define_tensor_concat(self): + """Defines a function to concatenate two tensor_t on the first axis + + tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t + """ + concat_name = self.get_name("tensor_concatenate") + concat_var = GlobalVar(concat_name) + setattr(self.prelude, concat_name, concat_var) + tensor_type_var = self.get_var('tensor_t') + x = Var("x", tensor_type_var()) + y = Var("y", tensor_type_var()) + + tensor1_var = self.get_var('tensor1') + tensor2_var = self.get_var('tensor2') + tensor3_var = self.get_var('tensor3') + tensor4_var = self.get_var('tensor4') + t11 = Var("t11") + t12 = Var("t12") + t21 = Var("t21") + t22 = Var("t22") + t31 = Var("t31") + t32 = Var("t32") + t41 = Var("t41") + t42 = Var("t42") + tensor1_case = Clause(PatternConstructor(tensor1_var, [PatternVar(t11)]), + Match(y, [Clause(PatternConstructor(tensor1_var, [PatternVar(t12)]), + tensor1_var(op.concatenate([t11, t12], axis=0)))], + False)) + tensor2_case = Clause(PatternConstructor(tensor2_var, [PatternVar(t21)]), + Match(y, [Clause(PatternConstructor(tensor2_var, [PatternVar(t22)]), + tensor2_var(op.concatenate([t21, t22], axis=0)))], + False)) + tensor3_case = Clause(PatternConstructor(tensor3_var, [PatternVar(t31)]), + Match(y, [Clause(PatternConstructor(tensor3_var, [PatternVar(t32)]), + tensor3_var(op.concatenate([t31, t32], axis=0)))], + False)) + tensor4_case = Clause(PatternConstructor(tensor4_var, [PatternVar(t41)]), + Match(y, [Clause(PatternConstructor(tensor4_var, [PatternVar(t42)]), + tensor4_var(op.concatenate([t41, t42], axis=0)))], + False)) + # op.concatenate does not support tensor with rank higher than 4 + self.prelude.mod[concat_var] =\ + Function([x, y], Match(x, [tensor1_case, + tensor2_case, + tensor3_case, + tensor4_case], False)) + + def define_tensor_array(self): + """Defines a function to create a tensor array with size n. 
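+
+        The result is a Prelude list of n tensor_nil placeholders; an
+        equivalent Python sketch of the recursion (illustrative only):
+
+            def tensor_array(n):
+                return nil() if n == 0 else cons(tensor_nil(), tensor_array(n - 1))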
+ tensor_array(n) : Tensor[(), int32] -> list[tensor_t] + """ + tensor_array_constructor_name = self.get_name("tensor_array") + tensor_array_constructor_var = GlobalVar(tensor_array_constructor_name) + setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var) + tensor_nil_var = self.get_var('tensor_nil') + tensor_type_var = self.get_var('tensor_t') + n = Var("x", scalar_type('int32')) + body = If(equal(n, const(0)), + self.prelude.nil(), + self.prelude.cons(tensor_nil_var(), + tensor_array_constructor_var(subtract(n, const(1))))) + self.prelude.mod[tensor_array_constructor_var] = \ + Function([n], body, self.prelude.l(tensor_type_var()), []) + + def define_tensor_array_read(self): + """Defines a function to get the head of a list. Assume the list has at least one + element. + + tensor_array_read(ta, n) : list[tensor_t] -> Tensor[(), int32] -> tensor_t + """ + read_name = self.get_name("tensor_array_read") + read_var = GlobalVar(read_name) + setattr(self.prelude, read_name, read_var) + tensor_type_var = self.get_var('tensor_t') + + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + n = Var("x", scalar_type('int32')) + self.prelude.mod[read_var] =\ + Function([tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), []) + + def define_tensor_array_write(self): + """Defines a function to update a tensor array at index n with value v. + tensor_array_write(ta, n, v) : + list[tensor_t] -> Tensor[(), int32] -> tensor_t -> list[tensor_t] + """ + write_name = self.get_name("tensor_array_write") + write_var = GlobalVar(write_name) + setattr(self.prelude, write_name, write_var) + tensor_type_var = self.get_var('tensor_t') + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + n = Var("x", scalar_type('int32')) + v = Var("v", tensor_type_var()) + self.prelude.mod[write_var] =\ + Function([tensor_array, n, v], self.prelude.update(tensor_array, n, v), + self.prelude.l(tensor_type_var()), []) + + def define_tensor_array_unstack_tensor1(self): + """Defines a function to unstack the values of a tensor_t with rank 1 in a tensor array. + tensor_array_unstack_tensor1(t) : tensor_t -> list[tensor_t] + """ + helper_name = self.get_name("tensor_array_unstack_tensor1_helper") + helper_var = GlobalVar(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor = Var("t", TensorType([Any()], self.dtype)) + up = Var("up", scalar_type('int32')) + i = Var("i", scalar_type('int32')) + tensor_type_var = self.get_var('tensor_t') + tensor0_var = self.get_var('tensor0') + helper_body =\ + If(equal(i, up), + self.prelude.nil(), + self.prelude.cons(tensor0_var(op.take(tensor, i)), + helper_var(add(i, const(1)), up, tensor))) + self.prelude.mod[helper_var] =\ + Function([i, up, tensor], helper_body, self.prelude.l(tensor_type_var()), []) + unstack_name = self.get_name("tensor_array_unstack_tensor1") + unstack_var = GlobalVar(unstack_name) + setattr(self.prelude, unstack_name, unstack_var) + tensor1 = Var("tensor", TensorType([Any()], self.dtype)) + shape = op.shape_of(tensor1) + ndim = op.take(shape, const(0)) + self.prelude.mod[unstack_var] =\ + Function([tensor1], helper_var(const(0), ndim, tensor1), + self.prelude.l(tensor_type_var()), []) + + def define_tensor_array_unstack_tensor2(self): + """Defines a function to unstack the values of a tensor_t with rank 2 in a tensor array. 
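+        Each row of the rank-2 input becomes one rank-1 element of the list,
+        so a (3, 4) input yields a list of three (4,) tensors. NumPy analogy
+        (illustrative only):
+
+            unstacked = [t[i] for i in range(t.shape[0])]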
+ + tensor_array_unstack_tensor2(t) : tensor_t -> list[tensor_t] + """ + helper_name = self.get_name("tensor_array_unstack_tensor2_helper") + helper_var = GlobalVar(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor = Var("t", TensorType([Any(), Any()], self.dtype)) + up = Var("up", scalar_type('int32')) + i = Var("i", scalar_type('int32')) + + helper_body = If(equal(i, up), + self.prelude.nil(), + self.prelude.cons(self.get_var('tensor1')(op.take(tensor, i, axis=0)), + helper_var(add(i, const(1)), up, tensor))) + self.prelude.mod[helper_var] =\ + Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), []) + + tensor_array_unstack_tensor2_name = self.get_name("tensor_array_unstack_tensor2") + tensor_array_unstack_tensor2_var = GlobalVar(tensor_array_unstack_tensor2_name) + setattr(self.prelude, tensor_array_unstack_tensor2_name, tensor_array_unstack_tensor2_var) + tensor2 = Var("tensor", TensorType([Any(), Any()], self.dtype)) + shape = op.shape_of(tensor2) + ndim = op.take(shape, const(0)) + self.prelude.mod[tensor_array_unstack_tensor2_var] =\ + Function([tensor2], helper_var(const(0), ndim, tensor2), + self.prelude.l(self.get_var('tensor_t')()), []) + + def define_tensor_array_scatter(self): + """Defines a function to scatter the values of a tensor_t in indices of a tensor array. + tensor_array_scatter(ta, indices, value) : + list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t] + """ + tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper") + tensor_array_scatter_helper_var = GlobalVar(tensor_array_scatter_helper_name) + tensor_t = self.get_var('tensor_t') + ta = Var("ta", self.prelude.l(tensor_t())) + current = Var("current", scalar_type('int32')) + limit = Var("limit", scalar_type('int32')) + indices_ = Var('indices_', TensorType([Any()], 'int32')) + values_ = Var('values_', self.prelude.l(tensor_t())) + write_var = self.get_var('tensor_array_write') + read_var = self.get_var('tensor_array_read') + helper_body = If(equal(current, limit), + ta, + tensor_array_scatter_helper_var( + write_var(ta, op.take(indices_, current), + read_var(values_, current)), + add(current, const(1)), + limit, indices_, values_)) + self.prelude.mod[tensor_array_scatter_helper_var] =\ + Function([ta, current, limit, indices_, values_], + helper_body, self.prelude.l(tensor_t()), []) + tensor_array_scatter_name = self.get_name("tensor_array_scatter") + tensor_array_scatter_var = GlobalVar(tensor_array_scatter_name) + setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var) + tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) + indices = Var('indices', TensorType([Any()], 'int32')) + values = Var('values', self.prelude.l(tensor_t())) + indices_shape = op.shape_of(indices) + limit = op.take(indices_shape, const(0)) + body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values) + self.prelude.mod[tensor_array_scatter_var] =\ + Function([tensor_array, indices, values], body, self.prelude.l(tensor_t()), []) + + def define_tensor_array_split(self): + """Defines a function to split the values of a tensor_t into a tensor array. 
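+
+        Semantically close to numpy.split with explicit section lengths; a
+        hedged sketch of the behavior (illustrative only):
+
+            # lengths [2, 2, 2, 2] cuts an (8, 1) value into four (2, 1)
+            # chunks, written at indices 0..3 of the tensor array
+            offsets = numpy.cumsum([0] + list(lengths[:-1]))
+            chunks = [value[o:o + n] for o, n in zip(offsets, lengths)]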
+ tensor_array_split(ta, value, lengths) : + list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t] + """ + tensor_t = self.get_var('tensor_t') + tensor_array_split_helper_name = self.get_name("ta_split_helper") + tensor_array_split_helper_var = GlobalVar(tensor_array_split_helper_name) + setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var) + ta1 = Var("tensor_array", self.prelude.l(tensor_t())) + value1 = Var('value1', tensor_t()) + offset1 = Var('offset1', scalar_type('int32')) + current1 = Var('current1', scalar_type('int32')) + limit1 = Var('limit1', scalar_type('int32')) + lengths1 = Var('lengths', TensorType([Any()], 'int32')) + write_var = self.get_var('tensor_array_write') + take_var = self.get_var('tensor_take') + helper1_body = If(equal(current1, limit1), + ta1, + write_var( + tensor_array_split_helper_var( + ta1, + value1, + add(offset1, op.take(lengths1, current1)), + add(current1, const(1)), + limit1, + lengths1 + ), + current1, + take_var(value1, + offset1, + add(op.take(lengths1, current1), offset1)))) + self.prelude.mod[tensor_array_split_helper_var] = \ + Function([ta1, value1, offset1, current1, limit1, lengths1], + helper1_body, self.prelude.l(tensor_t()), []) + split_name = self.get_name("tensor_array_split") + split_var = GlobalVar(split_name) + setattr(self.prelude, split_name, split_var) + tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) + value = Var('value', tensor_t()) + lengths = Var('lengths', TensorType([Any()], 'int32')) + lengths_shape = op.shape_of(lengths) + lengths_limit = op.take(lengths_shape, const(0)) + body = tensor_array_split_helper_var( + tensor_array, + value, + const(0), + const(0), + lengths_limit, + lengths) + self.prelude.mod[split_var] =\ + Function([tensor_array, value, lengths], body, self.prelude.l(tensor_t()), []) + + def define_tensor_array_concat(self): + """Defines a function to return the values in the tensor array as concatenated tensor_t. + tensor_array_concat(ta) : list[tensor_t] -> tensor_t + """ + concat_name = self.get_name("tensor_array_concat") + concat_var = GlobalVar(concat_name) + setattr(self.prelude, concat_name, concat_var) + tensor_concat_var = self.get_var('tensor_concatenate') + tensor_t = self.get_var('tensor_t') + tensor_nil_var = self.get_var('tensor_nil') + tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) + hd = Var("hd") + tl = Var("tl") + nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var()) + cons_case = Clause(PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]), + Match(tl, [ + Clause(PatternConstructor(self.prelude.nil), hd), + Clause(PatternWildcard(), + tensor_concat_var(hd, concat_var(tl))) + ], False)) + self.prelude.mod[concat_var] =\ + Function([tensor_array], + Match(tensor_array, [nil_case, cons_case], False), tensor_t(), []) + + def define_tensor_array_gather(self): + """Defines a function to return the selected values in a tensor array as tensor_t. 
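+
+        It reads each selected element and stacks the results into a single
+        tensor_t, roughly (NumPy sketch, illustrative only):
+
+            gathered = numpy.stack([tensor_array[i] for i in indices])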
+ tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t + """ + helper_name = self.get_name("tensor_array_gather_helper") + helper_var = GlobalVar(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor_type_var = self.get_var('tensor_t') + stack_var = self.get_var('tensor_array_stack') + read_var = self.get_var('tensor_array_read') + ta = Var("ta", self.prelude.l(tensor_type_var())) + accu = Var("accu", self.prelude.l(tensor_type_var())) + current = Var("current", scalar_type('int32')) + limit = Var("limit", scalar_type('int32')) + indices_ = Var('indices_', TensorType([Any()], 'int32')) + helper_body =\ + If(equal(current, const(0)), + stack_var(accu), + helper_var( + ta, + self.prelude.cons( + read_var( + ta, op.take(indices_, subtract(current, const(1)))), accu), + subtract(current, const(1)), + limit, indices_)) + self.prelude.mod[helper_var] = \ + Function([ta, accu, current, limit, indices_], helper_body, tensor_type_var(), []) + gather_name = self.get_name("tensor_array_gather") + gather_var = GlobalVar(gather_name) + setattr(self.prelude, gather_name, gather_var) + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + indices = Var('indices', TensorType([Any()], 'int32')) + indices_shape = op.shape_of(indices) + limit = op.take(indices_shape, const(0)) + body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices) + self.prelude.mod[gather_var] =\ + Function([tensor_array, indices], body, tensor_type_var(), []) + + def define_tensor_array_stack(self): + """Defines a function to get the values in the tensor array as a stack tensor_t. + tensor_array_stack(l) : list[tensor_t] -> tensor_t + """ + stack_name = self.get_name("tensor_array_stack") + stack_var = GlobalVar(stack_name) + setattr(self.prelude, stack_name, stack_var) + tensor_type_var = self.get_var('tensor_t') + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + expand_dims_var = self.get_var('tensor_expand_dims') + concat_var = self.get_var('tensor_concatenate') + tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array) + tensors = self.prelude.foldl(concat_var, + self.prelude.hd(tensor_array_expand_dims), + self.prelude.tl(tensor_array_expand_dims)) + self.prelude.mod[stack_var] = Function([tensor_array], tensors, tensor_type_var(), []) + + def register(self): + """Register all tensor array ops in Prelude""" + self.define_tensor_adt() + self.define_tensor_take() + self.define_tensor_expand_dims() + self.define_tensor_concat() + self.define_tensor_array() + self.define_tensor_array_read() + self.define_tensor_array_write() + self.define_tensor_array_unstack_tensor1() + self.define_tensor_array_unstack_tensor2() + self.define_tensor_array_scatter() + self.define_tensor_array_split() + self.define_tensor_array_concat() + self.define_tensor_array_stack() + # TODO(wweic): Gather fails in PartialEvaluate + # self.define_tensor_array_gather() + class Prelude: """Contains standard definitions.""" @@ -27,6 +532,17 @@ def __init__(self, mod=None): self.mod = mod self.load_prelude() + def get_name(self, canonical, dtype): + """Get name corresponding to the canonical name""" + if canonical == 'tensor_t': + return 'tensor_{}_t'.format(dtype) + return "{}_{}".format(canonical, dtype) + + def get_var(self, canonical, dtype): + """Get var corresponding to the canonical name""" + name = self.get_name(canonical, dtype) + return getattr(self, name) + def load_prelude(self): """Parses the Prelude from Relay's text format into 
a module.""" # TODO(@jroesch): we should remove this helper when we port over prelude @@ -74,3 +590,7 @@ def load_prelude(self): ] for global_def in GLOBAL_DEFS: setattr(self, global_def, self.mod.get_global_var(global_def)) + + for dtype in ['float32', 'int32']: + tensor_array_ops = TensorArrayOps(self, dtype) + tensor_array_ops.register() diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index d661be73ad02..d7b59922b89d 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -203,8 +203,12 @@ def convert_module(self): for var, func in self.mod.functions.items(): # optimize the definition so any operators used are lowered opt_func = self.optimize(func) - converted_func, _ = self.convert_func_node(opt_func, var) - defs.append(converted_func) + try: + converted_func, _ = self.convert_func_node(opt_func, var) + defs.append(converted_func) + except TypeError: + # TODO(wweic): fix conversion for Any + pass return defs diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 21f71af4eb8c..f85283094e91 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -309,7 +309,9 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.push_back(instr.alloc_tensor_reg.shape_register); // Save `DLDataType` and the dst register. const auto& dtype = instr.alloc_tensor.dtype; - fields.assign({dtype.code, dtype.bits, dtype.lanes}); + fields.push_back(dtype.code); + fields.push_back(dtype.bits); + fields.push_back(dtype.lanes); fields.push_back(instr.dst); break; } diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index c2cbbff24173..3321d71a2cb8 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -60,13 +60,19 @@ def vmobj_to_list(o): result.append(vmobj_to_list(f)) return result elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): - if o.constructor.name_hint == 'cons': + if o.constructor.name_hint == 'Cons': tl = vmobj_to_list(o.fields[1]) hd = vmobj_to_list(o.fields[0]) hd.extend(tl) return hd - elif o.constructor.name_hint == 'nil': + elif o.constructor.name_hint == 'Nil': return [] + elif 'tensor_nil' in o.constructor.name_hint: + return [0] + elif 'tensor' in o.constructor.name_hint: + return [o.fields[0].asnumpy()] + else: + raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint) elif isinstance(o, tvm.relay.backend.interpreter.TensorValue): return [o.data.asnumpy()] else: @@ -77,14 +83,11 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, """ Generic function to compile on relay and execute on tvm """ input_data = convert_to_list(input_data) input_node = convert_to_list(input_node) - layout = None if target == "cuda": layout = "NCHW" target_host = None - shape_dict = {e: i.shape for e, i in zip(input_node, input_data)} - mod, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict, @@ -581,6 +584,111 @@ def test_forward_squeeze(): _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5]) _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1]) +def test_tensor_array_constructor(): + def run(dtype_str): + with tf.Graph().as_default(): + dtype = { + 'float32': tf.float32, + 'int32' : tf.int32 + }[dtype_str] + t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype) + t2 = 
tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype) + ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False) + ta2 = ta1.write(0, t) + ta3 = ta2.write(1, t2) + out = ta3.read(0) + g = tf.get_default_graph() + compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug') + run('float32') + run('int32') + +def test_tensor_array_scatter(): + def run(dtype_str): + with tf.Graph().as_default(): + dtype = { + 'float32': tf.float32, + 'int32' : tf.int32 + }[dtype_str] + t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype) + indices = tf.constant([2, 1, 0]) + ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=False, dynamic_size=False) + ta2 = ta1.scatter(indices, t) + out0 = ta2.read(0) + out1 = ta2.read(1) + out2 = ta2.read(2) + g = tf.get_default_graph() + compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug') + run('float32') + run('int32') + +# TODO(wweic): Fix gather issue with PartialEvaluate +# def test_tensor_array_gather(): +# with tf.Graph().as_default(): +# dtype = 'float32' +# t = tf.constant([[1.0], [2.0], [3.0]]) +# scatter_indices = tf.constant([2, 1, 0]) +# gather_indices = tf.constant([1, 2]) +# ta1 = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, dynamic_size=False) +# ta2 = ta1.scatter(scatter_indices, t) +# t1 = ta2.gather(gather_indices) +# g = tf.get_default_graph() +# compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug') + +def test_tensor_array_split(): + def run(dtype_str): + with tf.Graph().as_default(): + dtype = { + 'float32': tf.float32, + 'int32' : tf.int32 + }[dtype_str] + t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype) + split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32) + ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=False, dynamic_size=False) + ta2 = ta1.split(t, split_length) + out0 = ta2.read(0) + out1 = ta2.read(1) + out2 = ta2.read(2) + out3 = ta2.read(3) + g = tf.get_default_graph() + compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug') + run('float32') + run('int32') + +def test_tensor_array_concat(): + def run(dtype_str): + with tf.Graph().as_default(): + dtype = { + 'float32': tf.float32, + 'int32' : tf.int32 + }[dtype_str] + t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype) + split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32) + ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=False, dynamic_size=False) + ta2 = ta1.split(t, split_length) + t = ta2.concat() + compare_tf_with_tvm([], [], ['TensorArrayConcatV3:0'], mode='debug') + run('float32') + run('int32') + +def test_tensor_array_size(): + def run(dtype_str): + with tf.Graph().as_default(): + dtype = { + 'float32': tf.float32, + 'int32' : tf.int32 + }[dtype_str] + ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False) + out = ta1.size() + g = tf.get_default_graph() + compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug') + run('float32') + run('int32') + ####################################################################### # 
ConcatV2 # -------- diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 7be7c75dfe64..390d3cd9f3c4 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -21,6 +21,8 @@ from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr +import numpy as np + mod = relay.Module() p = Prelude(mod) add_nat_definitions(p) @@ -683,6 +685,146 @@ def test_iterate(): res = intrp.evaluate(relay.Function([], expr)()) assert count(res) == 12 +def test_tensor_expand_dims(): + def run(dtype): + x = relay.var('x') + mod = relay.Module() + p = Prelude(mod) + expand_dims_func = p.get_var('tensor_expand_dims', dtype) + tensor1 = p.get_var('tensor1', dtype) + mod["main"] = relay.Function([x], expand_dims_func(tensor1(x))) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + x_np = np.random.uniform(size=(1,)).astype(dtype) + result = ex.evaluate()(x_np) + got = vmobj_to_list(result) + expected = [np.expand_dims(x_np, axis=0)] + tvm.testing.assert_allclose(expected, got) + run('float32') + run('int32') + +def test_tensor_array_constructor(): + def run(dtype): + x = relay.var('x') + mod = relay.Module() + p = Prelude(mod) + tensor_array = p.get_var('tensor_array', dtype) + mod["main"] = relay.Function([x], tensor_array(x)) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(5) + got = vmobj_to_list(result) + expected = np.array([0, 0, 0, 0, 0]) + tvm.testing.assert_allclose(expected, got) + run('float32') + run('int32') + +def test_tensor_array_read(): + def run(dtype): + mod = relay.Module() + p = Prelude(mod) + l = relay.var('l') + i = relay.var('i') + read_func = p.get_var('tensor_array_read', dtype) + tensor_array = p.get_var('tensor_array', dtype) + mod["main"] = relay.Function([l, i], read_func(tensor_array(l), i)) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(10, 5) + got = vmobj_to_list(result) + expected = [0] + tvm.testing.assert_allclose(expected, got) + run('float32') + run('int32') + +def vmobj_to_list(o): + if isinstance(o, tvm.relay.backend.vmobj.Tensor): + return [o.asnumpy().tolist()] + elif isinstance(o, tvm.relay.backend.interpreter.TensorValue): + return [o.asnumpy()] + elif isinstance(o, tvm.relay.backend.vmobj.Datatype): + result = [] + for f in o: + result.extend(vmobj_to_list(f)) + return result + elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): + if o.constructor.name_hint == 'Cons': + tl = vmobj_to_list(o.fields[1]) + hd = vmobj_to_list(o.fields[0]) + hd.extend(tl) + return hd + elif o.constructor.name_hint == 'Nil': + return [] + elif 'tensor_nil' in o.constructor.name_hint: + return [0] + elif 'tensor' in o.constructor.name_hint: + return [o.fields[0].asnumpy()] + else: + raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint) + else: + raise RuntimeError("Unknown object type: %s" % type(o)) + +def test_tensor_array_stack(): + def run(dtype): + mod = relay.Module() + p = Prelude(mod) + tensor_array = p.get_var('tensor_array', dtype) + tensor1 = p.get_var('tensor1', dtype) + write = p.get_var('tensor_array_write', dtype) + stack = p.get_var('tensor_array_stack', dtype) + l = relay.var('l') + v = relay.var('v') + init_tensor_array = tensor_array(relay.const(3)) + tensor_array1 = write(init_tensor_array, relay.const(0), 
tensor1(v)) + tensor_array2 = write(tensor_array1, relay.const(1), tensor1(v)) + tensor_array3 = write(tensor_array2, relay.const(2), tensor1(v)) + tensor_array4 = stack(tensor_array3) + mod["main"] = relay.Function([v], tensor_array4) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + t = np.random.uniform(size=(1,)).astype(dtype) + result = ex.evaluate()(t) + res = vmobj_to_list(result) + expected = [np.stack([t, t, t])] + tvm.testing.assert_allclose(expected, res) + run('float32') + run('int32') + +def test_tensor_array_unstack(): + def run(dtype): + mod = relay.Module() + p = Prelude(mod) + unstack_tensor1 = p.get_var('tensor_array_unstack_tensor1', dtype) + v = relay.var('v') + mod["main"] = relay.Function([v], unstack_tensor1(v)) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + t = np.random.uniform(size=(1,)).astype(dtype) + result = ex.evaluate()(t) + res = vmobj_to_list(result) + tvm.testing.assert_allclose(t, res) + run('float32') + run('int32') + +def test_tensor_take(): + def run(dtype): + mod = relay.Module() + p = Prelude(mod) + take = p.get_var('tensor_take', dtype) + tensor2 = p.get_var('tensor2', dtype) + v = relay.var('v') + lower = relay.var('lower') + upper = relay.var('upper') + mod["main"] = relay.Function([v, lower, upper], take(tensor2(v), lower, upper)) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + t = np.random.uniform(size=(10, 10)).astype(dtype) + result = ex.evaluate()(t, 2, 5) + res = vmobj_to_list(result) + expected = [np.take(t, range(2, 5), axis=0)] + tvm.testing.assert_allclose(expected, res) + run('float32') + run('int32') if __name__ == "__main__": test_nat_constructor() @@ -707,3 +849,9 @@ def test_iterate(): test_size() test_compose() test_iterate() + + test_tensor_expand_dims() + test_tensor_array_constructor() + test_tensor_array_read() + test_tensor_array_stack() + test_tensor_array_unstack() diff --git a/tests/python/relay/test_feature.py b/tests/python/relay/test_feature.py index 8f0e90de0315..64eda9d04e7c 100644 --- a/tests/python/relay/test_feature.py +++ b/tests/python/relay/test_feature.py @@ -38,7 +38,8 @@ def test_prelude(): Feature.fLet, Feature.fIf, Feature.fConstructor, - Feature.fMatch + Feature.fMatch, + Feature.fGraph ]) From cb5277f979e59f7a29bf9d1987a381a1c5143a3e Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Fri, 18 Oct 2019 08:19:32 -0700 Subject: [PATCH 03/59] Fix typo (#4144) --- src/pass/lower_tvm_builtin.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc index 69618985d50c..79329cbe717f 100644 --- a/src/pass/lower_tvm_builtin.cc +++ b/src/pass/lower_tvm_builtin.cc @@ -230,7 +230,7 @@ class BuiltinLower : public IRMutator { cast(Int(32), device_type_))); return TVMStructGet(Handle(), stack_array_, idx, intrinsic::kArrAddr); } - // call packled. + // call packed. 
Expr MakeCallPacked(const Call* op, const Expr& e) { size_t restore_shape_stack = run_shape_stack_; size_t restore_array_stack = run_array_stack_; From 86d445a9ccfb107ec33fd0d6d84b847473b3e038 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 18 Oct 2019 09:49:37 -0700 Subject: [PATCH 04/59] [CI] Pin NNPack pthreadtools version (#4152) --- docker/install/ubuntu_install_nnpack.sh | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/docker/install/ubuntu_install_nnpack.sh b/docker/install/ubuntu_install_nnpack.sh index 4f45f130e2e5..dc51fc28d492 100755 --- a/docker/install/ubuntu_install_nnpack.sh +++ b/docker/install/ubuntu_install_nnpack.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,11 +22,14 @@ set -o pipefail apt-get update && apt-get install -y --no-install-recommends git cmake -# TODO: specific tag? git clone https://github.com/Maratyszcza/NNPACK NNPACK +git clone https://github.com/Maratyszcza/pthreadpool NNPACK/pthreadpool + +# Use specific versioning tag. (cd NNPACK && git checkout 1e005b0c2) +(cd NNPACK/pthreadpool && git checkout 13da0b4c) mkdir -p NNPACK/build cd NNPACK/build -cmake -DCMAKE_INSTALL_PREFIX:PATH=. -DNNPACK_INFERENCE_ONLY=OFF -DNNPACK_CONVOLUTION_ONLY=OFF -DNNPACK_BUILD_TESTS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. && make -j4 && make install +cmake -DCMAKE_INSTALL_PREFIX:PATH=. -DNNPACK_INFERENCE_ONLY=OFF -DNNPACK_CONVOLUTION_ONLY=OFF -DNNPACK_BUILD_TESTS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DPTHREADPOOL_SOURCE_DIR=pthreadpool .. && make -j4 && make install cd - From c67bb94c210690cfc50019e19a807bfb72f71b82 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 18 Oct 2019 10:51:15 -0700 Subject: [PATCH 05/59] [QNN][TFLite] Parsing QNN Add op. Adding MobilenetV2. 
(#4142) --- python/tvm/relay/frontend/tflite.py | 66 +++++++++++++++++++- tests/python/frontend/tflite/test_forward.py | 22 ++++++- 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 35bc85e09fdd..b08dd6bf94e0 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -224,6 +224,18 @@ def has_same_qnn_params(self, lhs_tensor, rhs_tensor): return lhs_tensor.qnn_params['scale'] == rhs_tensor.qnn_params['scale'] and \ lhs_tensor.qnn_params['zero_point'] == rhs_tensor.qnn_params['zero_point'] + def is_quantized(self, op): + """Check if an input tensor is quantized.""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + first_tensor = input_tensors[0] + return first_tensor.qnn_params is not None + def convert_conv2d(self, op): """Convert TFLite conv2d""" return self.convert_conv(op, "conv2d") @@ -498,7 +510,25 @@ def _convert_elemwise(self, relay_op, op): rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type()) rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor), dtype=rhs_type_str) - out = relay_op(lhs_expr, rhs_expr) + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + output_tensor = output_tensors[0] + + # If quantized, extracts qnn params and call QNN add operator. + if lhs_tensor.qnn_params: + assert rhs_tensor.qnn_params, "Both tensors should be quantized." + assert output_tensor.qnn_params, "Output tensor should be quantized." + out = relay_op(lhs=lhs_expr, + rhs=rhs_expr, + lhs_scale=lhs_tensor.qnn_params['scale'], + lhs_zero_point=lhs_tensor.qnn_params['zero_point'], + rhs_scale=rhs_tensor.qnn_params['scale'], + rhs_zero_point=rhs_tensor.qnn_params['zero_point'], + output_scale=output_tensor.qnn_params['scale'], + output_zero_point=output_tensor.qnn_params['zero_point']) + else: + out = relay_op(lhs_expr, rhs_expr) # Options (fused_activation_function) options = None @@ -517,36 +547,70 @@ def _convert_elemwise(self, relay_op, op): fused_activation_fn = options.FusedActivationFunction() # if we have activation fn if fused_activation_fn != ActivationFunctionType.NONE: + if output_tensor.qnn_params: + raise tvm.error.OpNotImplemented( + 'Elemwise operators with fused activation are not supported yet.') out = self.convert_fused_activation_function(out, fused_activation_fn) return out def convert_add(self, op): """Convert TFLite ADD""" + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + return self._convert_elemwise(_qnn.op.add, op) return self._convert_elemwise(_op.add, op) def convert_sub(self, op): """Convert TFLite SUB""" + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized sub operator is not supported yet.') return self._convert_elemwise(_op.subtract, op) def convert_mul(self, op): """Convert TFLite MUL""" + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized mul operator is not supported yet.') return self._convert_elemwise(_op.multiply, op) def convert_div(self, op): """Convert TFLite DIV""" + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise 
tvm.error.OpNotImplemented( + 'TFlite quantized div operator is not supported yet.') return self._convert_elemwise(_op.divide, op) def convert_pow(self, op): + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized pow operator is not supported yet.') return self._convert_elemwise(_op.power, op) def convert_maximum(self, op): + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized maximum operator is not supported yet.') return self._convert_elemwise(_op.maximum, op) def convert_minimum(self, op): + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized minimum operator is not supported yet.') return self._convert_elemwise(_op.minimum, op) def convert_greater(self, op): + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized greater operator is not supported yet.') return self._convert_elemwise(_op.greater, op) def convert_zeros_like(self, op): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index a71a24ee0a4f..29b0c87c5b32 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1037,6 +1037,26 @@ def test_forward_qnn_mobilenet_v1_net(): tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) +def test_forward_qnn_mobilenet_v2_net(): + """Test the Quantized TFLite Mobilenet V2 model.""" + # MobilenetV2 + tflite_model_file = tf_testing.get_workload_official( + "https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz", + "mobilenet_v2_1.0_224_quant.tflite") + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + # Checking the labels because the requantize implementation is different between TFLite and + # Relay. This cause final output numbers to mismatch. So, testing accuracy via labels. + np.random.seed(0) + data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8') + tflite_output = run_tflite_graph(tflite_model_buf, data) + tflite_predictions = np.squeeze(tflite_output) + tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] + tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') + tvm_predictions = np.squeeze(tvm_output) + tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] + tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + ####################################################################### # SSD Mobilenet # ------------- @@ -1111,6 +1131,6 @@ def test_forward_ssd_mobilenet_v1(): test_forward_ssd_mobilenet_v1() # End to End quantized - # TODO - MobilenetV2 fails for now. Remove when fixed. 
test_forward_qnn_inception_v1_net() test_forward_qnn_mobilenet_v1_net() + test_forward_qnn_mobilenet_v2_net() From 7aae836007d8c1e0f8ef3f3a291f97d8231579c1 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Fri, 18 Oct 2019 15:22:37 -0700 Subject: [PATCH 06/59] Add lift_if_then_else pass (#3865) * Add LiftIfThenElse pass * Add more comments * Rename and refactor * Add description for internal data structure * Rename a test * Minor change * Address comments * Improve update_for --- include/tvm/ir_pass.h | 7 + src/api/api_pass.cc | 1 + src/pass/hoist_if_then_else.cc | 424 ++++++++++++++++++++ tests/python/unittest/test_pass_hoist_if.py | 185 +++++++++ 4 files changed, 617 insertions(+) create mode 100644 src/pass/hoist_if_then_else.cc create mode 100644 tests/python/unittest/test_pass_hoist_if.py diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 5ac71fdce47b..03078b8be41f 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -377,6 +377,13 @@ Stmt LowerStorageAccessInfo(Stmt stmt); */ Stmt DecorateDeviceScope(Stmt stmt); +/*! + * \brief Loop invariant code motion which locates and hoists if statements. + * \param stmt The stmt to do if statement hoisting. + * \return Transformed stmt. + */ +Stmt HoistIfThenElse(Stmt stmt); + /*! * \brief Make an user callable API LoweredFunc. * diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 25cd5838385f..d2352496c2b4 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -160,5 +160,6 @@ REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(VerifyCompactBuffer); +REGISTER_PASS(HoistIfThenElse); } // namespace ir } // namespace tvm diff --git a/src/pass/hoist_if_then_else.cc b/src/pass/hoist_if_then_else.cc new file mode 100644 index 000000000000..bbdb609e9a08 --- /dev/null +++ b/src/pass/hoist_if_then_else.cc @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file hoist_if_then_else.cc + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../arithmetic/int_set.h" +#include "../runtime/thread_storage_scope.h" + +namespace tvm { +namespace ir { + +using HoistMap = std::unordered_map>; +using VarMap = std::unordered_map>; + +/* + * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant. + * For example, given the following block: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * if (likely(i*2 < 4)) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt. 
+ * Then we hoist IfThenElse stmt by one For stmt each step: + * + * Step 1: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * if (likely(i*2 < 4)) + * for (k = 0; k < 5; k++) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * Step 2: + * for (i = 0; i < 3; i++) + * if (likely(i*2 < 4)) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * In this pass, we only continue detecting possible hoisting chance when visiting For, + * IfThenElse or AttrStmt Node. For example, for the following block: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * A[i + j] = A[i + j] - 1 + * for (k = 0; k < 5; k++) + * if (likely(i*2 < 4)) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * Only the For with k variable will be considered and the resulting stmt would be: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * A[i + j] = A[i + j] - 1 + * if (likely(i*2 < 4)) + * for (k = 0; k < 5; k++) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following + * block won't be optimized: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * if (likely(i*2 < 4)) + * A[3*i+2j+k] = B[7*i+3j+k] + * if (likely(j > 2)) + * A[i+j+k] = B[i+j+k] + * + */ +class IfThenElseHoist { + public: + Stmt VisitAndMutate(const Stmt& stmt) { + SelectCandidates(stmt); + LocateTopFor(); + return PostOrderMutate(stmt); + } + + private: + void SelectCandidates(const Stmt& stmt); + void LocateTopFor(); + Stmt PostOrderMutate(const Stmt& stmt); + size_t GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt); + Stmt HoistIf(const Stmt& if_stmt); + + // Map of all For nodes to all child IfThenElse nodes. + HoistMap for2if_map_; + // Map of all IfThenElse nodes to all For nodes which are loop invariant. + HoistMap if2for_map_; + // Map of highest loop invariant For to child IfThenElse. + HoistMap top_for_var_map_; + // Map of original For to list of update For nodes. + HoistMap for_tracking_map_; + // Map of all IfThenElse nodes to condition variable nodes. + VarMap cond_var_map_; + // List of For nodes added in post order DFS visiting. + std::vector ordered_for_list_; +}; + +// Check whether a given IfThenElse stmt is the first one appearing +// in a For stmt. +bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) { + std::vector if_node_list; + const For* for_node = for_stmt.as(); + CHECK(for_node); + CHECK(if_stmt.as()); + + PostOrderVisit(for_node->body, [&](const NodeRef& node) { + if (node.as()) { + if_node_list.push_back(node.get()); + } + }); + return if_node_list.empty() ? false : if_stmt.get() == if_node_list.back(); +} + +// Update upper level For node when current For node is modified. +// With this function we only need to visit and mutate top level For node +// in the main VisitAndMutate function. +Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { + const Node* top_for_node; + const For* parent_for_node = parent_for_stmt.as(); + CHECK(parent_for_node); + CHECK(new_if_stmt.as()); + + PostOrderVisit(parent_for_node->body, [&](const NodeRef& node) { + if (node.as()) { + top_for_node = node.get(); + } + }); + + PackedFunc replace_target_for = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& current_for = args[0]; + if (current_for.get() == top_for_node) { + *ret = new_if_stmt; + } + }); + + return IRTransform(parent_for_stmt, nullptr, replace_target_for, + {Expr("For")}); +} + +// Remove IfThenElse node from a For node. 
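+// E.g. (illustrative sketch):
+//   for (k) { if (c) A else B }
+// splits into
+//   for (k) { A }   and   for (k) { B }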
+// A pair of For nodes will be generated. +std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { + Stmt then_for; + Stmt else_for; + CHECK(if_stmt.as()); + + PackedFunc replace_then_case = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& node = args[0]; + if (node == if_stmt) { + *ret = node.as()->then_case; + } + }); + + PackedFunc replace_else_case = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& node = args[0]; + if (node == if_stmt) { + *ret = node.as()->else_case; + } + }); + + then_for = IRTransform(for_stmt, nullptr, replace_then_case, + {Expr("IfThenElse")}); + if (if_stmt.as()->else_case) { + else_for = IRTransform(for_stmt, nullptr, replace_else_case, + {Expr("IfThenElse")}); + } + + return std::make_pair(then_for, else_for); +} + +// Locate all For nodes and capture child IfThenElse nodes. +void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { + PostOrderVisit(stmt, [&](const NodeRef& node){ + const For* for_node = node.as(); + if (!for_node) return; + + std::queue tracker; + tracker.push(for_node->body); + Stmt for_stmt = Downcast(node); + for2if_map_.insert({for_stmt.get(), std::vector()}); + while (!tracker.empty()) { + Stmt head = tracker.front(); + tracker.pop(); + if (head->is_type()) { + for (const auto& if_stmt : for2if_map_.at(head.get())) { + for2if_map_[for_stmt.get()].push_back(if_stmt); + } + } else if (head->is_type()) { + const AttrStmt* attr_node = head.as(); + tracker.push(attr_node->body); + } else if (head->is_type()) { + for2if_map_[for_stmt.get()].push_back(head); + const IfThenElse* if_node = head.as(); + tracker.push(if_node->then_case); + if (if_node->else_case) { + tracker.push(if_node->else_case); + } + + // Record condition variables. + if (!cond_var_map_.count(head.get())) { + std::unordered_set new_var_set; + cond_var_map_.insert({head.get(), new_var_set}); + PostOrderVisit(if_node->condition, [&](const NodeRef& cond_node) { + if (cond_node.as()) { + cond_var_map_[head.get()].insert(cond_node.get()); + } + }); + } + } else { + continue; + } + } + ordered_for_list_.emplace_back(Downcast(node)); + }); +} + +// For each IfThenElse node, find the highest For node which +// meets loop invariant condition. +void IfThenElseHoist::LocateTopFor() { + std::unordered_map if_position_map; + std::unordered_set top_for_var_set; + + // Create IfThenElse -> For map. + for (const Stmt& for_stmt : ordered_for_list_) { + std::vector if_list = for2if_map_[for_stmt.get()]; + const For* for_node = for_stmt.as(); + CHECK(for_node); + top_for_var_map_.insert({for_node->loop_var.get(), if_list}); + for (const Stmt& if_stmt : if_list) { + const Node* if_node = if_stmt.get(); + if2for_map_[if_node].push_back(for_stmt); + } + } + + // Locate the highest For node which is loop invariant. 
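+  // E.g. (sketch, following the example in the file header): for
+  // `if (likely(i*2 < 4))` nested under loops i > j > k, the walk below
+  // stops at the `i` loop since the condition reads `i`; the `j` loop is
+  // then the highest For the if can be hoisted out of, leaving the if
+  // directly under the `i` loop.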
+ for (const auto& item : if2for_map_) { + Stmt top_for; + const Node* if_stmt = item.first; + std::vector for_list = item.second; + for (size_t i = 0; i < for_list.size(); ++i) { + const Stmt& for_stmt = for_list.at(i); + const For* for_node = for_stmt.as(); + CHECK(for_node); + std::vector new_for_list{for_stmt}; + for_tracking_map_.insert({for_stmt.get(), new_for_list}); + if (cond_var_map_[if_stmt] + .count(for_node->loop_var.get())) { + std::vector updated_for_list(for_list.begin(), + for_list.begin() + i); + if2for_map_[if_stmt] = updated_for_list; + break; + } else { + top_for = for_stmt; + } + } + if (top_for.as()) { + if_position_map.insert({if_stmt, top_for}); + } + } + + for (const auto& item : if_position_map) { + top_for_var_set.insert(item.second.as()->loop_var.get()); + } + + std::vector removed_for_var_list; + for (const auto& item : top_for_var_map_) { + const Node* top_for_var = item.first; + std::vector if_list = item.second; + if (!top_for_var_set.count(top_for_var)) { + removed_for_var_list.push_back(top_for_var); + } else { + std::vector actual_if_list; + for (const Stmt& if_stmt : if_list) { + if (if_position_map.count(if_stmt.get())) { + actual_if_list.push_back(if_stmt); + } + } + top_for_var_map_[top_for_var] = actual_if_list; + } + } + for (const Node* top_for_var : removed_for_var_list) { + top_for_var_map_.erase(top_for_var); + } +} + +// When we try to mutate a For node, some child For nodes can have already +// been mutated. This function is to get the updated For node and further +// hoisting can be done based on this new node. +// We keep all For nodes tracing in for_tracking_map_. When we get a +// hoisted IfThenElse, we match it with tracing For nodes to pick +// the updated one. +size_t IfThenElseHoist::GetUpdatedFor(const Stmt& for_stmt, + const Stmt& if_stmt) { + std::vector tracked_for_list = for_tracking_map_[for_stmt.get()]; + size_t updated_for_idx = 0; + for (size_t i = 0; i < tracked_for_list.size(); ++i) { + const Stmt& current_for = + tracked_for_list.at(tracked_for_list.size() - 1 - i); + if (is_first_if(current_for, if_stmt)) { + updated_for_idx = tracked_for_list.size() - 1 - i; + break; + } + } + return updated_for_idx; +} + +// Hoist an IfThenElse node as high as possible. +// This function iterates on all candidate For nodes. For each For node, +// it first removes IfThenElse nodes. Then it generates a new IfThenElse +// node using mutated For nodes. 
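+// One hoisting step, schematically (illustrative, matching Step 1 of the
+// header comment):
+//   before: for (j) { for (k) { if (c) A } }
+//   after:  for (j) { if (c) { for (k) A } }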
+Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { + Stmt new_if = if_stmt; + + for (size_t i = 0; i < if2for_map_[if_stmt.get()].size(); ++i) { + const Stmt& for_stmt = if2for_map_[if_stmt.get()].at(i); + size_t updated_for_idx = GetUpdatedFor(for_stmt, new_if); + const Stmt& updated_for_node = + for_tracking_map_[for_stmt.get()].at(updated_for_idx); + auto generated_for_pair = RemoveIf(updated_for_node, new_if); + const Stmt& then_for = generated_for_pair.first; + const Stmt& else_for = generated_for_pair.second;; + for_tracking_map_[for_stmt.get()].at(updated_for_idx) = then_for; + + if (else_for.get()) { + for_tracking_map_[for_stmt.get()].push_back(else_for); + } + + const IfThenElse* new_if_node = new_if.as(); + CHECK(new_if_node); + new_if = IfThenElse::make(new_if_node->condition, then_for, else_for); + if (i < if2for_map_[if_stmt.get()].size() - 1) { + const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1); + const Stmt& actual_next_for = + for_tracking_map_[original_next_for.get()].at(updated_for_idx); + Stmt update_for_stmt = update_for(actual_next_for, new_if); + + for_tracking_map_[original_next_for.get()]. + at(updated_for_idx) = update_for_stmt; + } + } + return new_if; +} + +// Mutate For nodes in post order DFS manner. +Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) { + PackedFunc replace_top_for = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& current_for = args[0]; + const For* for_node = current_for.as(); + if (!for_node) return; + + if (top_for_var_map_.count(for_node->loop_var.get())) { + std::vector new_if_list; + for (const Stmt& if_stmt : + top_for_var_map_[for_node->loop_var.get()]) { + new_if_list.emplace_back(HoistIf(if_stmt)); + } + + const IfThenElse* next_if_node; + const IfThenElse* current_if_node = + new_if_list.back().as(); + Stmt new_for = Stmt(); + for (size_t i = new_if_list.size() - 1; i > 0; --i) { + CHECK(current_if_node); + const Stmt current_if_stmt = + IfThenElse::make(current_if_node->condition, + current_if_node->then_case, + current_if_node->else_case); + next_if_node = new_if_list[i - 1].as(); + CHECK(next_if_node); + new_for = IfThenElse::make(next_if_node->condition, current_if_stmt, + next_if_node->else_case); + current_if_node = new_for.as(); + } + + if (!new_for.get()) { + const IfThenElse* first_if_node = new_if_list[0].as(); + CHECK(first_if_node); + new_for = IfThenElse::make(first_if_node->condition, + first_if_node->then_case, + first_if_node->else_case); + } + *ret = new_for; + } + }); + return IRTransform(stmt, nullptr, replace_top_for, {Expr("For")}); +} + +Stmt HoistIfThenElse(Stmt stmt) { + return IfThenElseHoist().VisitAndMutate(stmt); +} + +} // namespace ir +} // namespace tvm diff --git a/tests/python/unittest/test_pass_hoist_if.py b/tests/python/unittest/test_pass_hoist_if.py new file mode 100644 index 000000000000..4a28cf6b318a --- /dev/null +++ b/tests/python/unittest/test_pass_hoist_if.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm + + +var_list = [] + +def verify_structure(stmt, expected_struct): + node_dict = {} + struct = {} + def _extract_vars(op): + global var_list + if isinstance(op, tvm.expr.Var): + var_list.append(op.name) + + def _visit(op): + key = op + if isinstance(op, tvm.stmt.IfThenElse): + global var_list + tvm.ir_pass.PostOrderVisit(op.condition, _extract_vars) + val = [(op.then_case, op.else_case), ("IfThenElse", tuple(var_list))] + var_list.clear() + elif isinstance(op, tvm.stmt.For): + val = [(op.body,), ("For", op.loop_var.name)] + elif isinstance(op, tvm.stmt.AttrStmt): + val = [(op.body,), ("AttrStmt", op.attr_key, int(op.value))] + else: + return + node_dict[key] = val + + tvm.ir_pass.PostOrderVisit(stmt, _visit) + for key, val in node_dict.items(): + struct[val[1]] = tuple(node_dict[child][1] if child in node_dict + else None for child in val[0]) + + assert struct == expected_struct, "Structure mismatch: expect %s but got %s" \ + % (expected_struct, struct) + var_list.clear() + +def test_basic(): + ib = tvm.ir_builder.create() + l = tvm.var('l') + m = tvm.var('m') + n = tvm.var('n') + + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(ib.likely(i < 2)): + ib.emit(tvm.make.Evaluate(m)) + with ib.else_scope(): + ib.emit(tvm.make.Evaluate(n)) + + stmt = ib.get() + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) + expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), + ('IfThenElse', ('i',)): (('For', 'j'), ('For', 'j')), + ('For', 'i'): (('IfThenElse', ('i',)),)} + verify_structure(new_stmt, expected_struct) + +def test_no_else(): + ib = tvm.ir_builder.create() + l = tvm.var('l') + m = tvm.var('m') + n = tvm.var('n') + + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(ib.likely(i < 2)): + ib.emit(tvm.make.Evaluate(m)) + + stmt = ib.get() + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) + expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), + ('IfThenElse', ('i',)): (('For', 'j'), None), + ('For', 'i'): (('IfThenElse', ('i',)),)} + verify_structure(new_stmt, expected_struct) + +def test_attr_stmt(): + ib = tvm.ir_builder.create() + dshape = (32, 64) + data = ib.pointer("float32", name="data") + l = tvm.var('l') + m = tvm.var('m') + n = tvm.var('n') + + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(tvm.any(i < 4, j >= 8)): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.5 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.0 + + stmt = ib.get() + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) + expected_struct = {('For', 'k'): (None,), ('IfThenElse', ('i', 'j')): (('For', 'k'), ('For', 'k')), + ('For', 'j'): (('IfThenElse', ('i', 'j')),), ('For', 'i'): (('For', 'j'),), + ('AttrStmt', 'thread_extent', 
64): (('For', 'i'),), + ('AttrStmt', 'thread_extent', 32): (('AttrStmt', 'thread_extent', 64),)} + verify_structure(new_stmt, expected_struct) + +def test_nested_for(): + ib = tvm.ir_builder.create() + data = ib.pointer("float32", name="data") + + + with ib.for_range(0, 5, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.if_scope(i >= 3): + data[i * 3 + j] = data[i * 3 + j] + 0.5 + with ib.for_range(0, 15, "k") as k: + with ib.for_range(0, 20, "l") as l: + with ib.if_scope(tvm.any(i < 4, j >= 8)): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2 + with ib.else_scope(): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5 + + stmt = ib.get() + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) + expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('For', 'l'): (('IfThenElse', ('i', 'j')),), + ('For', 'k'): (('For', 'l'),), ('For', 'j'): (None,), ('IfThenElse', ('i',)): (('For', 'j'), None), + ('For', 'i'): (('IfThenElse', ('i',)),)} + verify_structure(new_stmt, expected_struct) + +def test_if_block(): + ib = tvm.ir_builder.create() + data = ib.pointer("float32", name="data") + n = tvm.var("n") + + + with ib.for_range(0, 5, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.if_scope(i >= 3): + data[i * 3 + j] = data[i * 3 + j] + 0.5 + with ib.for_range(0, 15, "k") as k: + with ib.for_range(0, 20, "l") as l: + with ib.if_scope(tvm.any(i < 4, j >= 8)): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2 + with ib.else_scope(): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5 + with ib.if_scope(j < 5): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] - 1 + + + with ib.for_range(0, 5, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.for_range(0, 15, "k") as k: + with ib.if_scope(n >= 3): + data[i * 3 + j + k] = data[i * 3 + j + k] + 0.6 + + stmt = ib.get() + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) + expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('IfThenElse', ('j',)): (None, None), + ('For', 'l'): (None,), ('For', 'k'): (None,), ('For', 'j'): (('For', 'j'),), + ('IfThenElse', ('i',)): (('For', 'j'), None), ('For', 'i'): (('IfThenElse', ('i',)),), + ('IfThenElse', ('n',)): (('For', 'j'), None)} + verify_structure(new_stmt, expected_struct) + + +if __name__ == "__main__": + test_basic() + test_no_else() + test_attr_stmt() + test_nested_for() + test_if_block() From 909900fc3ae3b366aa9ad3e4d33031240ac01883 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 18 Oct 2019 16:15:52 -0700 Subject: [PATCH 07/59] [CI] Update cpu docker (#4153) --- Jenkinsfile | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index 4b9ae9cafd88..c140d9c58ad2 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -38,9 +38,15 @@ // - Tag the new version as the latest // - Periodically cleanup the old versions on local workers // + +// Hashtag in the source to build current CI docker builds +// +// - ci-cpu:v0.54: e7c88a99f830de30814df14eaa980547ecbd61c1 +// + ci_lint = "tvmai/ci-lint:v0.51" ci_gpu = "tvmai/ci-gpu:v0.54" -ci_cpu = "tvmai/ci-cpu:v0.52" +ci_cpu = "tvmai/ci-cpu:v0.54" ci_i386 = "tvmai/ci-i386:v0.52" // tvm libraries From 6f5d9f206a2d6865efc48d1d61911e20b80a6c94 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Sat, 19 Oct 2019 21:57:50 -0700 Subject: [PATCH 08/59] [Refactor] Rename Datatype to ADT (#4156) We think this will reduce confusion about the meaning.
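As a quick orientation to the renamed API, a minimal C++ sketch of building these runtime objects after the patch is shown here; the header paths and the NDArray setup are illustrative assumptions, while the Tensor/ADT names and the ADTObj layout follow the diff below:

    #include <tvm/runtime/ndarray.h>
    #include <tvm/runtime/vm.h>

    using namespace tvm::runtime;

    void AdtSketch() {
      // Arbitrary tensor payload used as an ADT field (shape/dtype/ctx illustrative).
      NDArray nd = NDArray::Empty({2}, DLDataType{kDLFloat, 32, 1}, DLContext{kDLCPU, 0});
      vm::Tensor t(nd);
      // Before this patch these were vm::Datatype(tag, fields) and vm::Datatype::Tuple.
      vm::ADT adt(/*tag=*/0, {t, t});
      vm::ADT tup = vm::ADT::Tuple({t, t});
      size_t tag = adt->tag;  // tag and fields live on vm::ADTObj
      (void)tag; (void)tup;
    }

The motivating discussion is linked below.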
https://discuss.tvm.ai/t/discuss-consider-rename-vm-datatype/4339 --- docs/dev/virtual_machine.rst | 10 +++---- include/tvm/runtime/object.h | 2 +- include/tvm/runtime/vm.h | 24 ++++++++-------- python/tvm/relay/backend/vm.py | 2 +- python/tvm/relay/backend/vmobj.py | 20 ++++++------- src/relay/backend/vm/compiler.cc | 8 +++--- src/runtime/vm/executable.cc | 6 ++-- src/runtime/vm/object.cc | 28 +++++++++---------- src/runtime/vm/vm.cc | 24 ++++++++-------- .../frontend/tensorflow/test_forward.py | 2 +- tests/python/relay/test_adt.py | 2 +- tests/python/relay/test_vm.py | 2 +- tests/python/relay/test_vm_object.py | 8 +++--- 13 files changed, 69 insertions(+), 69 deletions(-) diff --git a/docs/dev/virtual_machine.rst b/docs/dev/virtual_machine.rst index 2791ee71177e..cb08cc14e56e 100644 --- a/docs/dev/virtual_machine.rst +++ b/docs/dev/virtual_machine.rst @@ -121,7 +121,7 @@ AllocTensor Allocate a tensor value of the appropriate shape (stored in `shape_register`) and `dtype`. The result is saved to register `dst`. -AllocDatatype +AllocADT ^^^^^^^^^^^^^ **Arguments**: :: @@ -176,7 +176,7 @@ GetTag RegName object RegName dst -Get the object tag for Datatype object in register `object`. And saves the reult to register `dst`. +Get the object tag for the ADT object in register `object` and save the result to register `dst`. Fatal ^^^^^ @@ -251,9 +251,9 @@ Currently, we support 3 types of objects: tensors, data types, and closures. :: - VMObject VMTensor(const tvm::runtime::NDArray& data); - VMObject VMDatatype(size_t tag, const std::vector<VMObject>& fields); - VMObject VMClosure(size_t func_index, std::vector<VMObject> free_vars); + Object Tensor(const tvm::runtime::NDArray& data); + Object ADT(size_t tag, const std::vector<Object>& fields); + Object Closure(size_t func_index, std::vector<Object> free_vars); Stack and State diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 0693b1f47b3c..7291510c16df 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -51,7 +51,7 @@ enum TypeIndex { kRoot = 0, kVMTensor = 1, kVMClosure = 2, - kVMDatatype = 3, + kVMADT = 3, kStaticIndexEnd, /*! \brief Type index is allocated during runtime. */ kDynamic = kStaticIndexEnd diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index a276c658c496..7d2df0b285b1 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -57,31 +57,31 @@ class Tensor : public ObjectRef { /*! \brief An object representing a structure or enumeration. */ -class DatatypeObj : public Object { +class ADTObj : public Object { public: /*! \brief The tag representing the constructor used. */ size_t tag; /*! \brief The fields of the structure. */ std::vector<ObjectRef> fields; - static constexpr const uint32_t _type_index = TypeIndex::kVMDatatype; - static constexpr const char* _type_key = "vm.Datatype"; - TVM_DECLARE_FINAL_OBJECT_INFO(DatatypeObj, Object); + static constexpr const uint32_t _type_index = TypeIndex::kVMADT; + static constexpr const char* _type_key = "vm.ADT"; + TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object); }; -/*! \brief reference to data type. */ -class Datatype : public ObjectRef { +/*! \brief reference to algebraic data type objects. */ +class ADT : public ObjectRef { public: - Datatype(size_t tag, std::vector<ObjectRef> fields); + ADT(size_t tag, std::vector<ObjectRef> fields); /*! * \brief construct a tuple object. * \param fields The fields of the tuple. * \return The constructed tuple type.
*/ - static Datatype Tuple(std::vector<ObjectRef> fields); + static ADT Tuple(std::vector<ObjectRef> fields); - TVM_DEFINE_OBJECT_REF_METHODS(Datatype, ObjectRef, DatatypeObj); + TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj); }; /*! \brief An object representing a closure. */ @@ -129,7 +129,7 @@ enum class Opcode { InvokePacked = 4U, AllocTensor = 5U, AllocTensorReg = 6U, - AllocDatatype = 7U, + AllocADT = 7U, AllocClosure = 8U, GetField = 9U, If = 10U, @@ -237,7 +237,7 @@ struct Instruction { /*! \brief The register to project from. */ RegName object; } get_tag; - struct /* AllocDatatype Operands */ { + struct /* AllocADT Operands */ { /*! \brief The datatype's constructor tag. */ Index constructor_tag; /*! \brief The number of fields to store in the datatype. */ @@ -294,7 +294,7 @@ struct Instruction { * \param dst The register name of the destination. * \return The allocate instruction tensor. */ - static Instruction AllocDatatype(Index tag, Index num_fields, const std::vector<RegName>& fields, + static Instruction AllocADT(Index tag, Index num_fields, const std::vector<RegName>& fields, RegName dst); /*! \brief Construct an allocate closure instruction. * \param func_index The index of the function table. diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 942c93b866f4..e190e3f1eb41 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -31,7 +31,7 @@ from .interpreter import Executor Tensor = _obj.Tensor -Datatype = _obj.Datatype +ADT = _obj.ADT def _convert(arg, cargs): if isinstance(arg, (np.ndarray, tvm.nd.NDArray)): diff --git a/python/tvm/relay/backend/vmobj.py b/python/tvm/relay/backend/vmobj.py index 939b122bf510..f3fdb763209d 100644 --- a/python/tvm/relay/backend/vmobj.py +++ b/python/tvm/relay/backend/vmobj.py @@ -61,14 +61,14 @@ def asnumpy(self): return self.data.asnumpy() -@register_object("vm.Datatype") -class Datatype(Object): - """Datatype object. +@register_object("vm.ADT") +class ADT(Object): + """Algebraic data type (ADT) object. Parameters ---------- tag : int - The tag of datatype. + The tag of the ADT. fields : list[Object] or tuple[Object] The source tuple. @@ -77,22 +77,22 @@ def __init__(self, tag, fields): for f in fields: assert isinstance(f, Object) self.__init_handle_by_constructor__( - _vmobj.Datatype, tag, *fields) + _vmobj.ADT, tag, *fields) @property def tag(self): - return _vmobj.GetDatatypeTag(self) + return _vmobj.GetADTTag(self) def __getitem__(self, idx): return getitem_helper( - self, _vmobj.GetDatatypeFields, len(self), idx) + self, _vmobj.GetADTFields, len(self), idx) def __len__(self): - return _vmobj.GetDatatypeNumberOfFields(self) + return _vmobj.GetADTNumberOfFields(self) def tuple_object(fields): - """Create a datatype object from source tuple. + """Create an ADT object from a source tuple. Parameters ---------- fields : list[Object] or tuple[Object] The source tuple. Returns ------- - ret : Datatype + ret : ADT The created object.
""" for f in fields: diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index f295ccd7a555..fab01bd40423 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -239,7 +239,7 @@ class VMFunctionCompiler : ExprFunctor { DLOG(INFO) << "VMCompiler::Emit: instr=" << instr; CHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op; switch (instr.op) { - case Opcode::AllocDatatype: + case Opcode::AllocADT: case Opcode::AllocTensor: case Opcode::AllocTensorReg: case Opcode::GetField: @@ -287,7 +287,7 @@ class VMFunctionCompiler : ExprFunctor { } // TODO(@jroesch): use correct tag - Emit(Instruction::AllocDatatype( + Emit(Instruction::AllocADT( 0, tuple->fields.size(), fields_registers, @@ -626,7 +626,7 @@ class VMFunctionCompiler : ExprFunctor { for (size_t i = arity - return_count; i < arity; ++i) { fields_registers.push_back(unpacked_arg_regs[i]); } - Emit(Instruction::AllocDatatype(0, return_count, fields_registers, NewRegister())); + Emit(Instruction::AllocADT(0, return_count, fields_registers, NewRegister())); } } @@ -659,7 +659,7 @@ class VMFunctionCompiler : ExprFunctor { } } else if (auto constructor_node = op.as()) { auto constructor = GetRef(constructor_node); - Emit(Instruction::AllocDatatype(constructor->tag, call_node->args.size(), args_registers, + Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers, NewRegister())); } else if (auto var_node = op.as()) { VisitExpr(GetRef(var_node)); diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index f85283094e91..32032b5a1e64 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -315,7 +315,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.push_back(instr.dst); break; } - case Opcode::AllocDatatype: { + case Opcode::AllocADT: { // Number of fields = 3 + instr.num_fields fields.assign({instr.constructor_tag, instr.num_fields, instr.dst}); @@ -551,7 +551,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { return Instruction::AllocTensorReg(shape_register, dtype, dst); } - case Opcode::AllocDatatype: { + case Opcode::AllocADT: { // Number of fields = 3 + instr.num_fields DCHECK_GE(instr.fields.size(), 3U); DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); @@ -561,7 +561,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { RegName dst = instr.fields[2]; std::vector fields = ExtractFields(instr.fields, 3, num_fields); - return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst); + return Instruction::AllocADT(constructor_tag, num_fields, fields, dst); } case Opcode::AllocClosure: { // Number of fields = 3 + instr.num_freevar diff --git a/src/runtime/vm/object.cc b/src/runtime/vm/object.cc index c20a1ce9de27..12edf511db66 100644 --- a/src/runtime/vm/object.cc +++ b/src/runtime/vm/object.cc @@ -39,15 +39,15 @@ Tensor::Tensor(NDArray data) { data_ = std::move(ptr); } -Datatype::Datatype(size_t tag, std::vector fields) { - auto ptr = make_object(); +ADT::ADT(size_t tag, std::vector fields) { + auto ptr = make_object(); ptr->tag = tag; ptr->fields = std::move(fields); data_ = std::move(ptr); } -Datatype Datatype::Tuple(std::vector fields) { - return Datatype(0, fields); +ADT ADT::Tuple(std::vector fields) { + return ADT(0, fields); } Closure::Closure(size_t func_index, std::vector free_vars) { @@ -66,28 +66,28 @@ TVM_REGISTER_GLOBAL("_vmobj.GetTensorData") *rv = cell->data; }); 
-TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeTag") +TVM_REGISTER_GLOBAL("_vmobj.GetADTTag") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; - const auto* cell = obj.as(); + const auto* cell = obj.as(); CHECK(cell != nullptr); *rv = static_cast(cell->tag); }); -TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeNumberOfFields") +TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; - const auto* cell = obj.as(); + const auto* cell = obj.as(); CHECK(cell != nullptr); *rv = static_cast(cell->fields.size()); }); -TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeFields") +TVM_REGISTER_GLOBAL("_vmobj.GetADTFields") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; int idx = args[1]; - const auto* cell = obj.as(); + const auto* cell = obj.as(); CHECK(cell != nullptr); CHECK_LT(idx, cell->fields.size()); *rv = cell->fields[idx]; @@ -104,10 +104,10 @@ TVM_REGISTER_GLOBAL("_vmobj.Tuple") for (auto i = 0; i < args.size(); ++i) { fields.push_back(args[i]); } - *rv = Datatype::Tuple(fields); + *rv = ADT::Tuple(fields); }); -TVM_REGISTER_GLOBAL("_vmobj.Datatype") +TVM_REGISTER_GLOBAL("_vmobj.ADT") .set_body([](TVMArgs args, TVMRetValue* rv) { int itag = args[0]; size_t tag = static_cast(itag); @@ -115,11 +115,11 @@ TVM_REGISTER_GLOBAL("_vmobj.Datatype") for (int i = 1; i < args.size(); i++) { fields.push_back(args[i]); } - *rv = Datatype(tag, fields); + *rv = ADT(tag, fields); }); TVM_REGISTER_OBJECT_TYPE(TensorObj); -TVM_REGISTER_OBJECT_TYPE(DatatypeObj); +TVM_REGISTER_OBJECT_TYPE(ADTObj); TVM_REGISTER_OBJECT_TYPE(ClosureObj); } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 78b74768b930..fd5ff64d5812 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -74,7 +74,7 @@ Instruction::Instruction(const Instruction& instr) { this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register; this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype; return; - case Opcode::AllocDatatype: + case Opcode::AllocADT: this->constructor_tag = instr.constructor_tag; this->num_fields = instr.num_fields; this->datatype_fields = Duplicate(instr.datatype_fields, instr.num_fields); @@ -159,7 +159,7 @@ Instruction& Instruction::operator=(const Instruction& instr) { this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register; this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype; return *this; - case Opcode::AllocDatatype: + case Opcode::AllocADT: this->constructor_tag = instr.constructor_tag; this->num_fields = instr.num_fields; FreeIf(this->datatype_fields); @@ -229,7 +229,7 @@ Instruction::~Instruction() { case Opcode::AllocTensor: delete this->alloc_tensor.shape; return; - case Opcode::AllocDatatype: + case Opcode::AllocADT: delete this->datatype_fields; return; case Opcode::AllocClosure: @@ -301,10 +301,10 @@ Instruction Instruction::AllocTensorReg(RegName shape_register, DLDataType dtype return instr; } -Instruction Instruction::AllocDatatype(Index tag, Index num_fields, +Instruction Instruction::AllocADT(Index tag, Index num_fields, const std::vector& datatype_fields, Index dst) { Instruction instr; - instr.op = Opcode::AllocDatatype; + instr.op = Opcode::AllocADT; instr.dst = dst; instr.constructor_tag = tag; instr.num_fields = num_fields; @@ -485,7 +485,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { DLDatatypePrint(os, instr.alloc_tensor_reg.dtype); break; } - case Opcode::AllocDatatype: { + case 
Opcode::AllocADT: { os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$" << StrJoin(instr.datatype_fields, 0, instr.num_fields, ",$") << "]"; break; @@ -691,7 +691,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, const std::vector& args) { size_t arity = 0; for (Index i = 0; i < arg_count; i++) { - if (const auto* obj = args[i].as()) { + if (const auto* obj = args[i].as()) { arity += obj->fields.size(); } else { ++arity; @@ -703,7 +703,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, runtime::TVMArgsSetter setter(values.data(), codes.data()); int idx = 0; for (Index i = 0; i < arg_count; i++) { - if (const auto* dt_cell = args[i].as()) { + if (const auto* dt_cell = args[i].as()) { for (auto obj : dt_cell->fields) { const auto* tensor = obj.as(); CHECK(tensor != nullptr); @@ -849,7 +849,7 @@ void VirtualMachine::RunLoop() { } case Opcode::GetField: { auto object = ReadRegister(instr.object); - const auto* tuple = object.as(); + const auto* tuple = object.as(); CHECK(tuple != nullptr) << "Object is not data type object, register " << instr.object << ", Object tag " << object->type_index(); @@ -860,7 +860,7 @@ void VirtualMachine::RunLoop() { } case Opcode::GetTag: { auto object = ReadRegister(instr.get_tag.object); - const auto* data = object.as(); + const auto* data = object.as(); CHECK(data != nullptr) << "Object is not data type object, register " << instr.get_tag.object << ", Object tag " @@ -925,12 +925,12 @@ void VirtualMachine::RunLoop() { pc++; goto main_loop; } - case Opcode::AllocDatatype: { + case Opcode::AllocADT: { std::vector fields; for (Index i = 0; i < instr.num_fields; ++i) { fields.push_back(ReadRegister(instr.datatype_fields[i])); } - ObjectRef obj = Datatype(instr.constructor_tag, fields); + ObjectRef obj = ADT(instr.constructor_tag, fields); WriteRegister(instr.dst, obj); pc++; goto main_loop; diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 3321d71a2cb8..420bcb72a4a2 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -49,7 +49,7 @@ def convert_to_list(x): def vmobj_to_list(o): if isinstance(o, tvm.relay.backend.vmobj.Tensor): return [o.asnumpy().tolist()] - elif isinstance(o, tvm.relay.backend.vmobj.Datatype): + elif isinstance(o, tvm.relay.backend.vmobj.ADT): result = [] for f in o: result.extend(vmobj_to_list(f)) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 390d3cd9f3c4..32bc22f9031a 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -742,7 +742,7 @@ def vmobj_to_list(o): return [o.asnumpy().tolist()] elif isinstance(o, tvm.relay.backend.interpreter.TensorValue): return [o.asnumpy()] - elif isinstance(o, tvm.relay.backend.vmobj.Datatype): + elif isinstance(o, tvm.relay.backend.vmobj.ADT): result = [] for f in o: result.extend(vmobj_to_list(f)) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 1b40f894db08..a3b251c38e00 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -63,7 +63,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): def vmobj_to_list(o): if isinstance(o, tvm.relay.backend.vm.Tensor): return [o.asnumpy().tolist()] - elif isinstance(o, tvm.relay.backend.vm.Datatype): + elif isinstance(o, tvm.relay.backend.vm.ADT): result = [] for f in o: result.extend(vmobj_to_list(f)) diff --git 
a/tests/python/relay/test_vm_object.py b/tests/python/relay/test_vm_object.py index ad21fff8e185..12d263d1125b 100644 --- a/tests/python/relay/test_vm_object.py +++ b/tests/python/relay/test_vm_object.py @@ -28,13 +28,13 @@ def test_tensor(): assert isinstance(x.data, tvm.nd.NDArray) -def test_datatype(): +def test_adt(): arr = tvm.nd.array([1,2,3]) x = vm.Tensor(arr) - y = vm.Datatype(0, [x, x]) + y = vm.ADT(0, [x, x]) assert len(y) == 2 - assert isinstance(y, vm.Datatype) + assert isinstance(y, vm.ADT) y[0:1][-1].data == x.data assert y.tag == 0 assert isinstance(x.data, tvm.nd.NDArray) @@ -43,4 +43,4 @@ def test_datatype(): if __name__ == "__main__": test_tensor() - test_datatype() + test_adt() From ffc11b758d9e786a18a47a8715356d0cad1a24d9 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Sun, 20 Oct 2019 10:40:10 -0700 Subject: [PATCH 09/59] [Runtime] Enable option to use OpenMP thread pool (#4089) --- CMakeLists.txt | 4 ++++ cmake/config.cmake | 4 ++++ cmake/modules/OpenMP.cmake | 48 ++++++++++++++++++++++++++++++++++++++ src/runtime/thread_pool.cc | 26 +++++++++++++++++++++ 4 files changed, 82 insertions(+) create mode 100644 cmake/modules/OpenMP.cmake diff --git a/CMakeLists.txt b/CMakeLists.txt index f44dd502e5de..248b39130e36 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,6 +33,7 @@ tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" O tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF) tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON) tvm_option(USE_GRAPH_RUNTIME_DEBUG "Build with tiny graph runtime debug mode" OFF) +tvm_option(USE_OPENMP "Build with OpenMP thread pool implementation" OFF) tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..." OFF) tvm_option(USE_SGX "Build with SGX" OFF) tvm_option(USE_RTTI "Build with RTTI" ON) @@ -155,6 +156,7 @@ list(APPEND COMPILER_SRCS ${RELAY_BACKEND_SRCS}) list(APPEND COMPILER_SRCS ${RELAY_IR_SRCS}) list(APPEND COMPILER_SRCS ${RELAY_QNN_SRCS}) + if(USE_VM_PROFILER) message(STATUS "Build compiler with Relay VM profiler support...") file(GLOB BACKEND_VM_PROFILER_SRCS src/relay/backend/vm/profiler/*.cc) @@ -234,6 +236,7 @@ include(cmake/modules/VTA.cmake) include(cmake/modules/CUDA.cmake) include(cmake/modules/OpenCL.cmake) include(cmake/modules/OpenGL.cmake) +include(cmake/modules/OpenMP.cmake) include(cmake/modules/Vulkan.cmake) include(cmake/modules/Metal.cmake) include(cmake/modules/ROCM.cmake) @@ -267,6 +270,7 @@ add_library(tvm_topi SHARED ${TOPI_SRCS}) add_library(tvm_runtime SHARED ${RUNTIME_SRCS}) add_library(tvm_runtime_static STATIC ${RUNTIME_SRCS}) + if(USE_RELAY_DEBUG) message(STATUS "Building Relay in debug mode...") set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "USE_RELAY_DEBUG") diff --git a/cmake/config.cmake b/cmake/config.cmake index b88d25b68700..f87dc8ab1d8f 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -115,6 +115,10 @@ set(USE_BLAS none) # set(USE_MKL_PATH ) if using `pip install mkl` set(USE_MKL_PATH none) +# Whether use OpenMP thread pool, choices: gnu, intel +# Note: "gnu" uses gomp library, "intel" uses iomp5 library +set(USE_OPENMP none) + # Whether use contrib.random in runtime set(USE_RANDOM OFF) diff --git a/cmake/modules/OpenMP.cmake b/cmake/modules/OpenMP.cmake new file mode 100644 index 000000000000..5dd9be508342 --- /dev/null +++ b/cmake/modules/OpenMP.cmake @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# OpenMP Module +if(USE_OPENMP STREQUAL "gnu") + find_package(OpenMP) + if(OPENMP_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + list(APPEND TVM_RUNTIME_LINKER_LIBS ${OpenMP_CXX_LIBRARIES}) + add_definitions(-DTVM_THREADPOOL_USE_OPENMP=1) + message(STATUS "Build with OpenMP ${OpenMP_CXX_LIBRARIES}") + else() + add_definitions(-DTVM_THREADPOOL_USE_OPENMP=0) + message(WARNING "OpenMP cannot be found, use TVM threadpool instead.") + endif() +elseif(USE_OPENMP STREQUAL "intel") + find_package(OpenMP) + if(OPENMP_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + if (MSVC) + find_library(OMP_LIBRARY NAMES libiomp5md) + else() + find_library(OMP_LIBRARY NAMES iomp5) + endif() + list(APPEND TVM_RUNTIME_LINKER_LIBS ${OMP_LIBRARY}) + add_definitions(-DTVM_THREADPOOL_USE_OPENMP=1) + message(STATUS "Build with OpenMP " ${OMP_LIBRARY}) + else() + add_definitions(-DTVM_THREADPOOL_USE_OPENMP=0) + message(WARNING "OpenMP cannot be found, use TVM threadpool instead.") + endif() +else() + add_definitions(-DTVM_THREADPOOL_USE_OPENMP=0) +endif() diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index 2e101364db2a..e9e6d03243e3 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -29,6 +29,9 @@ #include #include #include +#if TVM_THREADPOOL_USE_OPENMP +#include +#endif #include #include #include @@ -394,12 +397,34 @@ int TVMBackendParallelLaunch( FTVMParallelLambda flambda, void* cdata, int num_task) { +#if !TVM_THREADPOOL_USE_OPENMP int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch( flambda, cdata, num_task, 1); return res; +#else + int num_workers = tvm::runtime::threading::MaxConcurrency(); + if (num_task == 0) num_task = num_workers; + omp_set_num_threads(num_workers); + #pragma omp parallel num_threads(num_workers) + { + TVMParallelGroupEnv env; + env.num_task = num_task; + std::atomic* sync_counter = new std::atomic[num_task * tvm::runtime::kSyncStride]; + for (int i = 0; i < num_task; ++i) { + sync_counter[i * tvm::runtime::kSyncStride].store( + 0, std::memory_order_relaxed); + } + env.sync_handle = sync_counter; + (*flambda)(omp_get_thread_num(), &env, cdata); + } + return 0; +#endif } int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { +#if TVM_THREADPOOL_USE_OPENMP + #pragma omp barrier +#else using tvm::runtime::kSyncStride; int num_task = penv->num_task; std::atomic* sync_counter = @@ -415,5 +440,6 @@ int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { } } std::atomic_thread_fence(std::memory_order_acquire); +#endif return 0; } From 824e1d8182b97c399a975d0583e78b021d03f0a6 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 20 Oct 2019 18:30:41 -0700 Subject: [PATCH 10/59] [REFACTOR][NODE][RUNTIME] Move Node to the new Object protocol. 
(#4161) * [REFACTOR][NODE][RUNTIME] Move Node to the new Object protocol. This PR removes the original node system, and makes Node a subclass of Object. This is a major refactor towards a better unified runtime object system. List of changes in the refactor: - We now hide data_ field, use Downcast explicitly to get a sub-class object. - Removed the node system FFI in python. - Removed the node C API, instead use PackedFunc for list and get attrs. - Change relay::Op::set_attr_type_key(attr_key_name) to relay::Op::set_attr_type(). - This change was necessary because of the new Object registration mechanism. - Subsequent changes to the op registrations. - The change revealed a few previous problems that are now fixed. - Patched up a few missing node type registrations. - Now we will raise an error if we register an object that is not registered. - The original node.h and container.h are kept in the same location. - Calling convention: kObjectHandle now equals the old kNodeHandle, kNodeHandle is removed. - IRFunctor now dispatches on ObjectRef. - Update to the new type checking API: is_type, derived_from are replaced by IsInstance. - Removed .hash member function, instead use C++ convention hasher functors. * Address review comments --- golang/src/value.go | 4 +- include/tvm/api_registry.h | 8 +- include/tvm/arithmetic.h | 4 +- include/tvm/attrs.h | 24 +- include/tvm/base.h | 16 +- include/tvm/buffer.h | 4 +- include/tvm/build_module.h | 16 +- include/tvm/c_dsl_api.h | 98 ------ include/tvm/channel.h | 4 +- include/tvm/data_layout.h | 8 +- include/tvm/expr.h | 27 +- include/tvm/ir.h | 6 +- include/tvm/ir_functor_ext.h | 18 +- include/tvm/ir_mutator.h | 4 +- include/tvm/ir_visitor.h | 4 +- include/tvm/lowered_func.h | 9 +- include/tvm/node/container.h | 219 ++++++------- include/tvm/node/ir_functor.h | 50 +-- include/tvm/node/memory.h | 77 ----- include/tvm/node/node.h | 300 +++--------------- include/tvm/operation.h | 2 +- include/tvm/packed_func_ext.h | 183 +++++------ include/tvm/relay/adt.h | 2 +- include/tvm/relay/base.h | 10 +- include/tvm/relay/expr.h | 11 +- include/tvm/relay/expr_functor.h | 8 +- include/tvm/relay/interpreter.h | 4 +- include/tvm/relay/module.h | 6 +- include/tvm/relay/op.h | 19 +- include/tvm/relay/pattern_functor.h | 8 +- include/tvm/relay/transform.h | 8 +- include/tvm/relay/type.h | 7 +- include/tvm/runtime/c_runtime_api.h | 3 +- include/tvm/runtime/memory.h | 2 +- include/tvm/runtime/node_base.h | 259 --------------- include/tvm/runtime/object.h | 248 +++++++++++++-- include/tvm/runtime/packed_func.h | 65 ++-- include/tvm/schedule.h | 20 +- include/tvm/tensor.h | 16 +- include/tvm/tensor_intrin.h | 4 +- .../main/native/ml_dmlc_tvm_native_c_api.cc | 6 +- nnvm/include/nnvm/compiler/util.h | 6 +- nnvm/src/compiler/compile_engine.cc | 7 +- nnvm/src/compiler/compile_engine.h | 6 +- nnvm/src/compiler/graph_runtime.h | 5 +- nnvm/src/compiler/packed_func_ext.cc | 6 +- nnvm/src/top/tensor/transform.cc | 6 +- python/tvm/_ffi/_ctypes/function.py | 17 +- python/tvm/_ffi/_ctypes/node.py | 102 ------ python/tvm/_ffi/_ctypes/object.py | 13 +- python/tvm/_ffi/_cython/base.pxi | 17 +- python/tvm/_ffi/_cython/core.pyx | 2 +- python/tvm/_ffi/_cython/function.pxi | 23 +- python/tvm/_ffi/_cython/node.pxi | 110 ------- python/tvm/_ffi/_cython/object.pxi | 12 +- python/tvm/_ffi/node.py | 59 +--- python/tvm/_ffi/object.py | 23 +- python/tvm/_ffi/runtime_ctypes.py | 3 +- python/tvm/error.py | 1 + python/tvm/relay/backend/profiler_vm.py | 4 + python/tvm/relay/debug.py | 4 -
rust/common/src/packed_func.rs | 6 +- rust/frontend/src/function.rs | 2 +- src/api/api_arith.cc | 3 +- src/api/api_base.cc | 11 +- src/api/api_codegen.cc | 6 +- src/api/api_ir.cc | 1 - src/api/api_lang.cc | 93 +++--- src/api/api_pass.cc | 8 +- src/api/api_schedule.cc | 5 +- src/api/dsl_api.cc | 134 +++----- src/arithmetic/analyzer.cc | 7 +- src/arithmetic/canonical_simplify.cc | 6 +- src/arithmetic/const_int_bound.cc | 2 +- src/arithmetic/detect_linear_equation.cc | 2 +- src/arithmetic/int_set.cc | 4 +- src/arithmetic/ir_mutator_with_analyzer.cc | 2 +- src/arithmetic/ir_visitor_with_analyzer.h | 2 +- src/arithmetic/modular_set.cc | 2 +- src/codegen/build_module.cc | 24 +- src/codegen/codegen_c.cc | 2 +- src/codegen/llvm/codegen_llvm.cc | 2 +- src/codegen/spirv/codegen_spirv.cc | 2 +- src/contrib/hybrid/codegen_hybrid.cc | 4 +- src/contrib/hybrid/codegen_hybrid.h | 1 - src/lang/attr_functor.h | 80 ++--- src/lang/attrs.cc | 52 +-- src/lang/data_layout.cc | 8 +- src/lang/expr.cc | 4 +- src/lang/ir.cc | 8 +- src/lang/reflection.cc | 105 +++--- src/node/node.cc | 76 ----- src/op/compute_op.cc | 8 +- src/op/hybrid_op.cc | 4 +- src/op/op_util.cc | 2 +- src/op/tensorize.cc | 2 +- src/pass/arg_binder.cc | 2 +- src/pass/coproc_sync.cc | 6 +- src/pass/hoist_if_then_else.cc | 7 +- src/pass/inject_copy_intrin.cc | 10 +- src/pass/inject_double_buffer.cc | 2 +- src/pass/inject_prefetch.cc | 2 +- src/pass/inject_virtual_thread.cc | 5 +- src/pass/ir_mutator.cc | 2 +- src/pass/lift_attr_scope.cc | 6 +- src/pass/lower_thread_allreduce.cc | 6 +- src/pass/lower_warp_memory.cc | 6 +- src/pass/make_api.cc | 4 +- src/pass/narrow_channel_access.cc | 2 +- src/pass/remap_thread_axis.cc | 6 +- src/pass/split_host_device.cc | 10 +- src/pass/split_pipeline.cc | 8 +- src/pass/storage_access.cc | 12 +- src/pass/storage_flatten.cc | 16 +- src/pass/storage_rewrite.cc | 13 +- src/pass/storage_sync.cc | 6 +- src/pass/unroll_loop.cc | 3 +- src/pass/vectorize_loop.cc | 3 +- src/pass/verify_memory.cc | 6 +- src/relay/backend/compile_engine.cc | 8 +- src/relay/backend/compile_engine.h | 18 +- src/relay/backend/graph_runtime_codegen.cc | 6 +- src/relay/ir/alpha_equal.cc | 15 +- src/relay/ir/expr_functor.cc | 7 +- src/relay/ir/hash.cc | 21 +- src/relay/ir/module.cc | 9 +- src/relay/ir/op.cc | 12 +- src/relay/ir/pretty_printer.cc | 18 +- src/relay/ir/type_functor.h | 14 +- src/relay/op/algorithm/argsort.cc | 2 +- src/relay/op/algorithm/topk.cc | 2 +- src/relay/op/debug.cc | 13 +- src/relay/op/image/resize.cc | 6 +- src/relay/op/nn/bitserial.cc | 38 +-- src/relay/op/nn/convolution.cc | 20 +- src/relay/op/nn/nn.cc | 26 +- src/relay/op/nn/pad.cc | 4 +- src/relay/op/nn/pooling.cc | 16 +- src/relay/op/nn/sparse.cc | 4 +- src/relay/op/nn/upsampling.cc | 2 +- src/relay/op/tensor/reduce.cc | 18 +- src/relay/op/tensor/transform.cc | 75 +++-- src/relay/op/tensor/unary.cc | 10 +- src/relay/op/vision/multibox_op.cc | 8 +- src/relay/op/vision/yolo.cc | 2 +- src/relay/pass/alter_op_layout.cc | 15 +- src/relay/pass/device_annotation.cc | 10 +- src/relay/pass/eta_expand.cc | 4 +- src/relay/pass/fold_constant.cc | 2 +- src/relay/pass/fold_scale_axis.cc | 10 +- src/relay/pass/partial_eval.cc | 14 +- src/relay/pass/pass_manager.cc | 7 +- src/relay/pass/quantize/annotate.cc | 4 +- src/relay/pass/quantize/partition.cc | 3 + src/relay/pass/quantize/quantize.cc | 2 +- src/relay/pass/quantize/quantize.h | 8 +- src/relay/pass/quantize/realize.cc | 22 +- src/relay/pass/type_infer.cc | 21 +- src/relay/pass/type_solver.cc | 2 +- 
src/relay/qnn/op/concatenate.cc | 2 +- src/relay/qnn/op/convolution.cc | 2 +- src/relay/qnn/op/dense.cc | 2 +- src/relay/qnn/op/dequantize.cc | 2 +- src/relay/qnn/op/quantize.cc | 2 +- src/relay/qnn/op/requantize.cc | 2 +- src/runtime/c_dsl_api.cc | 91 ------ src/runtime/c_runtime_api.cc | 2 +- src/runtime/dsl_api.h | 59 ---- src/runtime/object.cc | 21 +- src/schedule/graph.cc | 2 +- src/schedule/schedule_dataflow_rewrite.cc | 18 +- src/schedule/schedule_lang.cc | 24 +- src/schedule/schedule_ops.cc | 6 +- tests/cpp/expr_test.cc | 4 +- tests/cpp/ir_functor_test.cc | 2 +- tests/cpp/object_protocol_test.cc | 6 +- tests/cpp/packed_func_test.cc | 2 +- tests/python/unittest/test_lang_schedule.py | 6 +- .../unittest/test_runtime_vm_profiler.py | 2 + topi/include/topi/cuda/pooling.h | 2 +- topi/include/topi/cuda/reduction.h | 2 +- topi/include/topi/detail/constant_utils.h | 15 +- topi/include/topi/generic/extern.h | 2 +- topi/src/topi.cc | 5 +- web/tvm_runtime.js | 8 +- 185 files changed, 1442 insertions(+), 2387 deletions(-) delete mode 100644 include/tvm/c_dsl_api.h delete mode 100644 include/tvm/node/memory.h delete mode 100644 include/tvm/runtime/node_base.h delete mode 100644 python/tvm/_ffi/_ctypes/node.py delete mode 100644 python/tvm/_ffi/_cython/node.pxi delete mode 100644 src/node/node.cc delete mode 100644 src/runtime/c_dsl_api.cc delete mode 100644 src/runtime/dsl_api.h diff --git a/golang/src/value.go b/golang/src/value.go index 576331a8cfa0..5e0f78270eaa 100644 --- a/golang/src/value.go +++ b/golang/src/value.go @@ -44,8 +44,8 @@ var KTVMType = int32(C.kTVMType) var KTVMContext = int32(C.kTVMContext) // KArrayHandle is golang type code for TVM kArrayHandle. var KArrayHandle = int32(C.kArrayHandle) -// KNodeHandle is golang type code for TVM kNodeHandle. -var KNodeHandle = int32(C.kNodeHandle) +// KObjectHandle is golang type code for TVM kObjectHandle. +var KObjectHandle = int32(C.kObjectHandle) // KModuleHandle is gonag type code for TVM kModuleHandle. var KModuleHandle = int32(C.kModuleHandle) // KFuncHandle is gonalg type code for TVM kFuncHandle. diff --git a/include/tvm/api_registry.h b/include/tvm/api_registry.h index e12d841519ca..dbd097293593 100644 --- a/include/tvm/api_registry.h +++ b/include/tvm/api_registry.h @@ -79,7 +79,7 @@ class EnvFunc : public NodeRef { explicit EnvFunc(NodePtr n) : NodeRef(n) {} /*! \return The internal global function pointer */ const EnvFuncNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! * \brief Invoke the function. @@ -124,19 +124,19 @@ class TypedEnvFunc : public NodeRef { /*! \brief short hand for this function type */ using TSelf = TypedEnvFunc; TypedEnvFunc() {} - explicit TypedEnvFunc(NodePtr n) : NodeRef(n) {} + explicit TypedEnvFunc(ObjectPtr n) : NodeRef(n) {} /*! * \brief Assign global function to a TypedEnvFunc * \param other Another global function. * \return reference to self. */ TSelf& operator=(const EnvFunc& other) { - this->node_ = other.node_; + ObjectRef::operator=(other); return *this; } /*! \return The internal global function pointer */ const EnvFuncNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! * \brief Invoke the function. diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 8be1c3604813..e81fa0afd254 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -362,7 +362,7 @@ class IntSet : public NodeRef { /*! \brief constructor */ IntSet() {} // constructor from not container. 
- explicit IntSet(NodePtr n) : NodeRef(n) {} + explicit IntSet(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -692,7 +692,7 @@ Array DetectClipBound(const Expr& e, // implementation inline const IntSetNode* IntSet::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace arith } // namespace tvm diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index 3b64d1f961e2..fb8927a75613 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -163,7 +163,7 @@ class AttrsEqual { return lhs == rhs; } // node comparator - TVM_DLL bool operator()(const NodeRef& lhs, const NodeRef& rhs) const; + TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const; protected: friend class AttrsEqualHandler; @@ -203,7 +203,7 @@ class AttrsHash { (static_cast(value.bits()) << 8) | (static_cast(value.lanes()) << 16)); } - TVM_DLL size_t operator()(const NodeRef& value) const; + TVM_DLL size_t operator()(const ObjectRef& value) const; private: friend class AttrsHashHandler; @@ -260,7 +260,7 @@ class BaseAttrsNode : public Node { * \return The comparison result. */ TVM_DLL virtual bool ContentEqual( - const Node* other, AttrsEqual equal) const = 0; + const Object* other, AttrsEqual equal) const = 0; /*! * \brief Content aware hash. * \param hasher The hasher to run the hash. @@ -290,7 +290,7 @@ class Attrs : public NodeRef { private: /*! \return the internal attribute node */ const BaseAttrsNode* ptr() const { - return static_cast(node_.get()); + return static_cast(get()); } }; @@ -315,7 +315,7 @@ class DictAttrsNode : public BaseAttrsNode { void VisitNonDefaultAttrs(AttrVisitor* v) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; Array ListFieldInfo() const final; - bool ContentEqual(const Node* other, AttrsEqual equal) const final; + bool ContentEqual(const Object* other, AttrsEqual equal) const final; size_t ContentHash(AttrsHash hasher) const final; // type info static constexpr const char* _type_key = "DictAttrs"; @@ -369,7 +369,7 @@ class AttrsEqualVisitor { public: bool result_{true}; // constructor - AttrsEqualVisitor(const Node* lhs, const Node* rhs, const AttrsEqual& equal) + AttrsEqualVisitor(const Object* lhs, const Object* rhs, const AttrsEqual& equal) : lhs_(lhs), rhs_(rhs), equal_(equal) { } template @@ -387,8 +387,8 @@ class AttrsEqualVisitor { } private: - const Node* lhs_; - const Node* rhs_; + const Object* lhs_; + const Object* rhs_; const AttrsEqual& equal_; }; @@ -488,7 +488,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) { } else if (const ir::UIntImm* op = expr.as()) { *ptr = static_cast(op->value); } else { - LOG(FATAL) << "Expect int value, but get " << expr->type_key(); + LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey(); } } } @@ -521,7 +521,7 @@ inline void SetValue(double* ptr, const TVMArgValue& val) { } else if (const ir::UIntImm* op = expr.as()) { *ptr = static_cast(op->value); } else { - LOG(FATAL) << "Expect float value, but get " << expr->type_key(); + LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey(); } } } @@ -827,7 +827,7 @@ class AttrsNode : public BaseAttrsNode { return visitor.fields_; } - bool ContentEqual(const Node* other, AttrsEqual equal) const final { + bool ContentEqual(const Object* other, AttrsEqual equal) const final { DerivedType* pself = self(); if (pself == other) return true; if (other == nullptr) return false; @@ -839,7 +839,7 @@ 
class AttrsNode : public BaseAttrsNode { size_t ContentHash(AttrsHash hasher) const final { ::tvm::detail::AttrsHashVisitor visitor(hasher); - visitor.result_ = std::hash()(this->type_key()); + visitor.result_ = this->GetTypeKeyHash(); self()->__VisitAttrs__(visitor); return visitor.result_; } diff --git a/include/tvm/base.h b/include/tvm/base.h index f358f7f5d447..a42de10abef2 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -47,9 +47,10 @@ using ::tvm::AttrVisitor; */ #define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \ TypeName() {} \ - explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \ + explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \ + : BaseTypeName(n) {} \ const NodeName* operator->() const { \ - return static_cast(node_.get()); \ + return static_cast(data_.get()); \ } \ operator bool() const { return this->defined(); } \ using ContainerType = NodeName; @@ -75,12 +76,12 @@ using ::tvm::AttrVisitor; */ #define TVM_DEFINE_NODE_REF_COW(NodeName) \ NodeName* CopyOnWrite() { \ - CHECK(node_ != nullptr); \ - if (!node_.unique()) { \ + CHECK(data_ != nullptr); \ + if (!data_.unique()) { \ NodePtr n = make_node(*(operator->())); \ - NodePtr(std::move(n)).swap(node_); \ + ObjectPtr(std::move(n)).swap(data_); \ } \ - return static_cast(node_.get()); \ + return static_cast(data_.get()); \ } /*! \brief Macro to make it easy to define node ref type given node */ @@ -160,7 +161,7 @@ std::string SaveJSON(const NodeRef& node); * * \return The shared_ptr of the Node. */ -NodePtr LoadJSON_(std::string json_str); +ObjectPtr LoadJSON_(std::string json_str); /*! * \brief Load the node from json string. @@ -233,6 +234,7 @@ struct NodeFactoryReg { * \note This is necessary to enable serialization of the Node. */ #define TVM_REGISTER_NODE_TYPE(TypeName) \ + TVM_REGISTER_OBJECT_TYPE(TypeName); \ static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \ ::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \ .set_creator([](const std::string&) { return ::tvm::make_node(); }) diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index 1233e9b0b89b..f18ed9206db3 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -51,7 +51,7 @@ enum BufferType : int { class Buffer : public NodeRef { public: Buffer() {} - explicit Buffer(NodePtr n) : NodeRef(n) {} + explicit Buffer(ObjectPtr n) : NodeRef(n) {} /*! * \brief Return a new buffer that is equivalent with current one * but always add stride field. @@ -171,7 +171,7 @@ class BufferNode : public Node { }; inline const BufferNode* Buffer::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 1d57d82e66c6..c985fbe17546 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -93,7 +93,7 @@ class TargetNode : public Node { class Target : public NodeRef { public: Target() {} - explicit Target(NodePtr n) : NodeRef(n) {} + explicit Target(ObjectPtr n) : NodeRef(n) {} /*! 
* \brief Create a Target given a string * \param target_str the string to parse @@ -110,7 +110,7 @@ class Target : public NodeRef { TVM_DLL static tvm::Target Current(bool allow_not_defined = true); const TargetNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } using ContainerType = TargetNode; @@ -256,12 +256,12 @@ class BuildConfigNode : public Node { class BuildConfig : public ::tvm::NodeRef { public: BuildConfig() {} - explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {} + explicit BuildConfig(ObjectPtr n) : NodeRef(n) {} const BuildConfigNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } BuildConfigNode* operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } /*! * \brief Construct a BuildConfig containing a empty build config node. @@ -371,7 +371,7 @@ class GenericFuncNode; class GenericFunc : public NodeRef { public: GenericFunc() {} - explicit GenericFunc(NodePtr n) : NodeRef(n) {} + explicit GenericFunc(ObjectPtr n) : NodeRef(n) {} /*! * \brief Set the default function implementaiton. @@ -478,10 +478,10 @@ class GenericFuncNode : public Node { }; inline GenericFuncNode* GenericFunc::operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } -#define TVM_GENERIC_FUNC_REG_VAR_DEF \ +#define TVM_GENERIC_FUNC_REG_VAR_DEF \ static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM /*! diff --git a/include/tvm/c_dsl_api.h b/include/tvm/c_dsl_api.h deleted file mode 100644 index bbbb84926e8e..000000000000 --- a/include/tvm/c_dsl_api.h +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/c_dsl_api.h - * - * \brief TVM DSL Node C API, used to interact to DSL compilation. - * - * These are only a few functions needed for DSL construction time. - * These function are only available when link libtvm. - * If only TVM runtime is linked, calling these function will trigger error. - * - * \note Most API functions are registerd as PackedFunc and - * can be grabbed via TVMFuncGetGlobal - */ -#ifndef TVM_C_DSL_API_H_ -#define TVM_C_DSL_API_H_ - -#include "runtime/c_runtime_api.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/*! \brief handle to node */ -typedef void* NodeHandle; - -/*! - * \brief free the node handle - * \param handle The node handle to be freed. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMNodeFree(NodeHandle handle); - -/*! - * \brief Convert type key to type index. - * \param type_key The key of the type. - * \param out_index the corresponding type index. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMNodeTypeKey2Index(const char* type_key, - int* out_index); - -/*! 
- * \brief Get runtime type index of the node. - * \param handle the node handle. - * \param out_index the corresponding type index. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMNodeGetTypeIndex(NodeHandle handle, - int* out_index); - -/*! - * \brief get attributes given key - * \param handle The node handle - * \param key The attribute name - * \param out_value The attribute value - * \param out_type_code The type code of the attribute. - * \param out_success Whether get is successful. - * \return 0 when success, -1 when failure happens - * \note API calls always exchanges with type bits=64, lanes=1 - */ -TVM_DLL int TVMNodeGetAttr(NodeHandle handle, - const char* key, - TVMValue* out_value, - int* out_type_code, - int* out_success); - -/*! - * \brief get attributes names in the node. - * \param handle The node handle - * \param out_size The number of functions - * \param out_array The array of function names. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMNodeListAttrNames(NodeHandle handle, - int *out_size, - const char*** out_array); -#ifdef __cplusplus -} // TVM_EXTERN_C -#endif -#endif // TVM_C_DSL_API_H_ diff --git a/include/tvm/channel.h b/include/tvm/channel.h index 143d4295f3e3..346291a6b06a 100644 --- a/include/tvm/channel.h +++ b/include/tvm/channel.h @@ -35,7 +35,7 @@ class Channel : public NodeRef { public: /*! \brief default constructor */ Channel() {} - explicit Channel(NodePtr n) : NodeRef(n) {} + explicit Channel(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -67,7 +67,7 @@ struct ChannelNode : public Node { // Inline implementations inline const ChannelNode* Channel::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm #endif // TVM_CHANNEL_H_ diff --git a/include/tvm/data_layout.h b/include/tvm/data_layout.h index c2ae572de818..ad3da6b347af 100644 --- a/include/tvm/data_layout.h +++ b/include/tvm/data_layout.h @@ -127,7 +127,7 @@ class LayoutNode : public Node { */ class Layout : public NodeRef { public: - explicit Layout(NodePtr n) : NodeRef(n) {} + explicit Layout(ObjectPtr n) : NodeRef(n) {} /*! \brief default constructor */ Layout() = default; @@ -152,7 +152,7 @@ class Layout : public NodeRef { * \return the pointer to the internal node container */ const LayoutNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! @@ -160,7 +160,7 @@ class Layout : public NodeRef { * \return the pointer to the internal node container */ LayoutNode* operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } /*! @@ -369,7 +369,7 @@ class BijectiveLayout : public NodeRef { }; inline const BijectiveLayoutNode* BijectiveLayout::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 201a2b485aa6..d884a4d61748 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -49,7 +49,7 @@ class ExprNode : public Node { class Expr : public NodeRef { public: Expr() {} - explicit Expr(NodePtr ptr) : NodeRef(ptr) {} + explicit Expr(ObjectPtr ptr) : NodeRef(ptr) {} /*! * \brief construct from integer. * \param value The value to be constructed. @@ -122,7 +122,7 @@ class Variable : public ExprNode { /*! 
\brief a named variable in TVM */ class Var : public Expr { public: - explicit Var(NodePtr n) : Expr(n) {} + explicit Var(ObjectPtr n) : Expr(n) {} TVM_DLL explicit Var(std::string name_hint = "v", Type t = Int(32)); /*! @@ -145,7 +145,7 @@ class Var : public Expr { * \return the corresponding Variable. */ const Variable* get() const { - return static_cast(node_.get()); + return static_cast(data_.get()); } /*! \brief type indicate the container type */ using ContainerType = Variable; @@ -187,7 +187,7 @@ class Integer : public Expr { /*! * \brief constructor from node. */ - explicit Integer(NodePtr node) : Expr(node) {} + explicit Integer(ObjectPtr node) : Expr(node) {} /*! * \brief Construct integer from int value. */ @@ -197,7 +197,7 @@ class Integer : public Expr { * \param other another expression. */ Integer& operator=(const Integer& other) { - node_ = other.node_; + data_ = other.data_; return *this; } /*! @@ -205,13 +205,13 @@ class Integer : public Expr { * \return the content of the integer. */ const IntImm* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! * \brief convert to int64_t */ operator int64_t() const { - CHECK(node_ != nullptr) + CHECK(data_ != nullptr) << " Trying to reference a null Integer"; return (*this)->value; } @@ -346,7 +346,7 @@ class IterVar : public NodeRef { // construct a new iter var without a domain IterVar() {} // construct from shared ptr. - explicit IterVar(NodePtr n) : NodeRef(n) {} + explicit IterVar(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -423,7 +423,7 @@ class IterVarNode : public Node { // inline implementations inline const IterVarNode* IterVar::operator->() const { - return static_cast(node_.get()); + return static_cast(data_.get()); } inline IterVar::operator Expr() const { @@ -481,11 +481,11 @@ class IRPrinter { : stream(stream) {} /*! \brief The node to be printed. */ - TVM_DLL void Print(const NodeRef& node); + TVM_DLL void Print(const ObjectRef& node); /*! \brief Print indent to the stream */ TVM_DLL void PrintIndent(); // Allow registration to be printer. - using FType = IRFunctor; + using FType = IRFunctor; TVM_DLL static FType& vtable(); }; @@ -498,10 +498,7 @@ inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT namespace std { template <> -struct hash<::tvm::IterVar> { - std::size_t operator()(const ::tvm::IterVar& k) const { - return k.hash(); - } +struct hash<::tvm::IterVar> : public ::tvm::NodeHash { }; } #endif // TVM_EXPR_H_ diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 079f05f5a7f2..b90804983cfb 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -664,10 +664,10 @@ class CommReducerNode : public Node { }; inline const CommReducerNode* CommReducer::get() const { - return static_cast(node_.get()); + return static_cast(data_.get()); } inline const CommReducerNode* CommReducer::operator->() const { - return static_cast(node_.get()); + return get(); } /*! 
\brief Reduction operator operator */ @@ -1576,7 +1576,7 @@ namespace std { template <> struct hash<::tvm::ir::TensorKey> { std::size_t operator()(const ::tvm::ir::TensorKey& k) const { - size_t lhs = k.f.hash(); + size_t lhs = ::tvm::NodeHash()(k.f); size_t rhs = static_cast(k.value_index); lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); return lhs; diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index a7d91eacf851..54a5eff6846b 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -84,19 +84,19 @@ class StmtFunctor; } #define STMT_FUNCTOR_DEFAULT { \ return VisitStmtDefault_(op, std::forward(args)...); \ -} + } #define IR_EXPR_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitExpr_(static_cast(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.get()), \ std::forward(args)...); \ }); \ #define IR_STMT_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitStmt_(static_cast(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitStmt_(static_cast(n.get()), \ std::forward(args)...); \ }); \ @@ -104,7 +104,7 @@ template class ExprFunctor { private: using TSelf = ExprFunctor; - using FType = IRFunctor; + using FType = IRFunctor; public: /*! \brief the result type of this functor */ @@ -164,7 +164,7 @@ class ExprFunctor { virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args ...) { - LOG(FATAL) << "Do not have a default for " << op->type_key(); + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); } @@ -213,7 +213,7 @@ template class StmtFunctor { private: using TSelf = StmtFunctor; - using FType = IRFunctor; + using FType = IRFunctor; public: /*! \brief the result type of this functor */ @@ -255,7 +255,7 @@ class StmtFunctor { virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmtDefault_(const Node* op, Args ...) { - LOG(FATAL) << "Do not have a default for " << op->type_key(); + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); } diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index b82a19d4689c..c910a48620c8 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -65,9 +65,9 @@ class TVM_DLL IRMutator { /*! \brief destructor */ virtual ~IRMutator() {} /*! \brief functor type of expr mutation */ - using FMutateExpr = IRFunctor; + using FMutateExpr = IRFunctor; /*! \brief functor type of stmt mutation */ - using FMutateStmt = IRFunctor; + using FMutateStmt = IRFunctor; /*! \return internal vtable of expr */ static FMutateExpr& vtable_expr(); // NOLINT(*) /*! \return internal stmt of expr */ diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index f20b91368587..bebf94585ed6 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -49,7 +49,7 @@ namespace ir { * // The use case is to count number of Variables in the ir tree. 
* class MyCounter : public IRVisitor { * public: - * int Count(const NodeRef& n) { + * int Count(const ObjectRef& n) { * ret_ = 0; * this->Visit(n); * return ret_; @@ -94,7 +94,7 @@ class TVM_DLL IRVisitor { /*! \brief destructor */ virtual ~IRVisitor() {} /*! \brief functor type of visitor */ - using FVisit = IRFunctor; + using FVisit = IRFunctor; /*! \return internal vtable*/ static FVisit& vtable(); // overloadable visit function. diff --git a/include/tvm/lowered_func.h b/include/tvm/lowered_func.h index 4da93b80c2ab..e2147d036587 100644 --- a/include/tvm/lowered_func.h +++ b/include/tvm/lowered_func.h @@ -44,7 +44,7 @@ class LoweredFuncNode; class LoweredFunc : public ir::FunctionRef { public: LoweredFunc() {} - explicit LoweredFunc(NodePtr n) : FunctionRef(n) {} + explicit LoweredFunc(ObjectPtr n) : FunctionRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container */ @@ -136,17 +136,14 @@ class LoweredFuncNode : public ir::FunctionBaseNode { // Implementations of inline functions inline const LoweredFuncNode* LoweredFunc::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm namespace std { template <> -struct hash<::tvm::LoweredFunc> { - std::size_t operator()(const ::tvm::LoweredFunc& k) const { - return k.hash(); - } +struct hash<::tvm::LoweredFunc> : public tvm::NodeHash { }; } diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index c2c639e374f5..2e1a978f4806 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -38,14 +38,14 @@ namespace tvm { class ArrayNode : public Node { public: /*! \brief the data content */ - std::vector > data; + std::vector data; void VisitAttrs(AttrVisitor* visitor) final { // Visitor to array has no effect. } static constexpr const char* _type_key = "Array"; - TVM_DECLARE_NODE_TYPE_INFO(ArrayNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Node); }; /*! \brief map node content */ @@ -54,32 +54,17 @@ class MapNode : public Node { void VisitAttrs(AttrVisitor* visitor) final { // Visitor to map has no effect. } - // hash function - struct Hash { - size_t operator()(const NodePtr& n) const { - return std::hash()(n.get()); - } - }; - // comparator - struct Equal { - bool operator()( - const NodePtr& a, - const NodePtr& b) const { - return a.get() == b.get(); - } - }; - /*! \brief The corresponding container type */ using ContainerType = std::unordered_map< - NodePtr, - NodePtr, - Hash, Equal>; + ObjectRef, + ObjectRef, + ObjectHash, ObjectEqual>; /*! \brief the data content */ ContainerType data; static constexpr const char* _type_key = "Map"; - TVM_DECLARE_NODE_TYPE_INFO(MapNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Node); }; @@ -90,15 +75,13 @@ class StrMapNode : public Node { // Visitor to map has no effect. } /*! \brief The corresponding container type */ - using ContainerType = std::unordered_map< - std::string, - NodePtr >; + using ContainerType = std::unordered_map; /*! \brief the data content */ ContainerType data; static constexpr const char* _type_key = "StrMap"; - TVM_DECLARE_NODE_TYPE_INFO(StrMapNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Node); }; /*! 
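MapNode can drop its hand-written Hash/Equal functors because ObjectHash and ObjectEqual compare by the address of the underlying Object, which is exactly what the old NodePtr-based functors did. A standalone sketch of those identity semantics, with a hypothetical DemoObj payload:

#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
#include <unordered_map>

using namespace tvm::runtime;

class DemoObj : public Object {
 public:
  static constexpr const char* _type_key = "demo.Obj";
  TVM_DECLARE_FINAL_OBJECT_INFO(DemoObj, Object);
};

int main() {
  // Same container shape as the new MapNode::ContainerType.
  std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> table;
  ObjectRef k1(make_object<DemoObj>());
  ObjectRef k2(make_object<DemoObj>());
  table[k1] = k2;
  // Keys hash and compare by pointer identity, not structural equality,
  // so a distinct allocation is a distinct key.
  return (table.count(k1) == 1 && table.count(k2) == 0) ? 0 : 1;
}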
@@ -111,9 +94,9 @@ template::difference_type; - using value_type = typename std::iterator_traits::value_type; - using pointer = typename std::iterator_traits::pointer; - using reference = typename std::iterator_traits::reference; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; // NOLINT(*) using iterator_category = typename std::iterator_traits::iterator_category; explicit IterAdapter(TIter iter) : iter_(iter) {} @@ -138,7 +121,7 @@ class IterAdapter { inline bool operator!=(IterAdapter other) const { return !(*this == other); } - inline const typename Converter::ResultType operator*() const { + inline const value_type operator*() const { return Converter::convert(*iter_); } @@ -162,26 +145,27 @@ class Array : public NodeRef { * \brief default constructor */ Array() { - node_ = make_node(); + data_ = make_node(); } /*! * \brief move constructor * \param other source */ Array(Array && other) { // NOLINT(*) - node_ = std::move(other.node_); + data_ = std::move(other.data_); } /*! * \brief copy constructor * \param other source */ - Array(const Array &other) : NodeRef(other.node_) { // NOLINT(*) + Array(const Array &other) { // NOLINT(*) + data_ = std::move(other.data_); } /*! * \brief constructor from pointer * \param n the container pointer */ - explicit Array(NodePtr n) : NodeRef(n) {} + explicit Array(ObjectPtr n) : NodeRef(n) {} /*! * \brief constructor from iterator * \param begin begin of iterator @@ -214,9 +198,9 @@ class Array : public NodeRef { explicit Array(size_t n, const T& val) { auto tmp_node = make_node(); for (size_t i = 0; i < n; ++i) { - tmp_node->data.push_back(val.node_); + tmp_node->data.push_back(val); } - node_ = std::move(tmp_node); + data_ = std::move(tmp_node); } /*! * \brief move assign operator @@ -224,7 +208,7 @@ class Array : public NodeRef { * \return reference to self. */ Array& operator=(Array && other) { - node_ = std::move(other.node_); + data_ = std::move(other.data_); return *this; } /*! @@ -233,7 +217,7 @@ class Array : public NodeRef { * \return reference to self. */ Array& operator=(const Array & other) { - node_ = other.node_; + data_ = other.data_; return *this; } /*! @@ -246,9 +230,9 @@ class Array : public NodeRef { void assign(IterType begin, IterType end) { auto n = make_node(); for (IterType it = begin; it != end; ++it) { - n->data.push_back((*it).node_); + n->data.push_back(T(*it)); } - node_ = std::move(n); + data_ = std::move(n); } /*! * \brief Read i-th element from array. @@ -256,12 +240,13 @@ class Array : public NodeRef { * \return the i-th element. */ inline const T operator[](size_t i) const { - return T(static_cast(node_.get())->data[i]); + return DowncastNoCheck( + static_cast(data_.get())->data[i]); } /*! \return The size of the array */ inline size_t size() const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.size(); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.size(); } /*! 
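The IterAdapter hunk above is the subtle piece of this change: value_type, pointer, and reference now derive from the Converter's ResultType rather than from the wrapped iterator's traits, which is what lets Array store plain ObjectRef yet yield T on dereference. A self-contained sketch of that mechanism (MiniIterAdapter and IntToString are illustrative names, not from the patch):

#include <iostream>
#include <string>
#include <vector>

// Adapter whose element type comes from the Converter, not the iterator.
template <typename Converter, typename TIter>
class MiniIterAdapter {
 public:
  using value_type = typename Converter::ResultType;
  explicit MiniIterAdapter(TIter iter) : iter_(iter) {}
  MiniIterAdapter& operator++() { ++iter_; return *this; }
  bool operator!=(MiniIterAdapter other) const { return iter_ != other.iter_; }
  value_type operator*() const { return Converter::convert(*iter_); }
 private:
  TIter iter_;
};

// Stand-in for Array's ValueConverter: storage type in, element type out.
struct IntToString {
  using ResultType = std::string;
  static std::string convert(int v) { return std::to_string(v); }
};

int main() {
  std::vector<int> xs{1, 2, 3};
  MiniIterAdapter<IntToString, std::vector<int>::const_iterator> it(xs.begin()),
      end(xs.end());
  for (; it != end; ++it) std::cout << *it << "\n";  // prints "1" "2" "3"
}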
* \brief copy on write semantics @@ -272,12 +257,12 @@ class Array : public NodeRef { * \return Handle to the internal node container(which ganrantees to be unique) */ inline ArrayNode* CopyOnWrite() { - if (node_.get() == nullptr || !node_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { NodePtr n = make_node(); - n->data = static_cast(node_.get())->data; - NodePtr(std::move(n)).swap(node_); + n->data = static_cast(data_.get())->data; + ObjectPtr(std::move(n)).swap(data_); } - return static_cast(node_.get()); + return static_cast(data_.get()); } /*! * \brief push a new item to the back of the list @@ -285,7 +270,7 @@ class Array : public NodeRef { */ inline void push_back(const T& item) { ArrayNode* n = this->CopyOnWrite(); - n->data.push_back(item.node_); + n->data.push_back(item); } /*! * \brief set i-th element of the array. @@ -294,7 +279,7 @@ class Array : public NodeRef { */ inline void Set(size_t i, const T& value) { ArrayNode* n = this->CopyOnWrite(); - n->data[i] = value.node_; + n->data[i] = value; } /*! \return whether array is empty */ inline bool empty() const { @@ -303,34 +288,34 @@ class Array : public NodeRef { /*! \brief specify container node */ using ContainerType = ArrayNode; - struct Ptr2NodeRef { + struct ValueConverter { using ResultType = T; - static inline T convert(const NodePtr& n) { - return T(n); + static inline T convert(const ObjectRef& n) { + return DowncastNoCheck(n); } }; - using iterator = IterAdapter >::const_iterator>; + using iterator = IterAdapter::const_iterator>; using reverse_iterator = IterAdapter< - Ptr2NodeRef, - std::vector >::const_reverse_iterator>; + ValueConverter, + std::vector::const_reverse_iterator>; /*! \return begin iterator */ inline iterator begin() const { - return iterator(static_cast(node_.get())->data.begin()); + return iterator(static_cast(data_.get())->data.begin()); } /*! \return end iterator */ inline iterator end() const { - return iterator(static_cast(node_.get())->data.end()); + return iterator(static_cast(data_.get())->data.end()); } /*! \return rbegin iterator */ inline reverse_iterator rbegin() const { - return reverse_iterator(static_cast(node_.get())->data.rbegin()); + return reverse_iterator(static_cast(data_.get())->data.rbegin()); } /*! \return rend iterator */ inline reverse_iterator rend() const { - return reverse_iterator(static_cast(node_.get())->data.rend()); + return reverse_iterator(static_cast(data_.get())->data.rend()); } }; @@ -355,26 +340,26 @@ class Map : public NodeRef { * \brief default constructor */ Map() { - node_ = make_node(); + data_ = make_node(); } /*! * \brief move constructor * \param other source */ Map(Map && other) { // NOLINT(*) - node_ = std::move(other.node_); + data_ = std::move(other.data_); } /*! * \brief copy constructor * \param other source */ - Map(const Map &other) : NodeRef(other.node_) { // NOLINT(*) + Map(const Map &other) : NodeRef(other.data_) { // NOLINT(*) } /*! * \brief constructor from pointer * \param n the container pointer */ - explicit Map(NodePtr n) : NodeRef(n) {} + explicit Map(ObjectPtr n) : NodeRef(n) {} /*! * \brief constructor from iterator * \param begin begin of iterator @@ -406,7 +391,7 @@ class Map : public NodeRef { * \return reference to self. */ Map& operator=(Map && other) { - node_ = std::move(other.node_); + data_ = std::move(other.data_); return *this; } /*! @@ -415,7 +400,7 @@ class Map : public NodeRef { * \return reference to self. 
*/ Map& operator=(const Map & other) { - node_ = other.node_; + data_ = other.data_; return *this; } /*! @@ -428,10 +413,9 @@ class Map : public NodeRef { void assign(IterType begin, IterType end) { NodePtr n = make_node(); for (IterType i = begin; i != end; ++i) { - n->data.emplace(std::make_pair(i->first.node_, - i->second.node_)); + n->data.emplace(std::make_pair(i->first, i->second)); } - node_ = std::move(n); + data_ = std::move(n); } /*! * \brief Read element from map. @@ -439,7 +423,8 @@ class Map : public NodeRef { * \return the corresonding element. */ inline const V operator[](const K& key) const { - return V(static_cast(node_.get())->data.at(key.node_)); + return DowncastNoCheck( + static_cast(data_.get())->data.at(key)); } /*! * \brief Read element from map. @@ -447,17 +432,18 @@ class Map : public NodeRef { * \return the corresonding element. */ inline const V at(const K& key) const { - return V(static_cast(node_.get())->data.at(key.node_)); + return DowncastNoCheck( + static_cast(data_.get())->data.at(key)); } /*! \return The size of the array */ inline size_t size() const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.size(); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.size(); } /*! \return The number of elements of the key */ inline size_t count(const K& key) const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.count(key.node_); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.count(key); } /*! * \brief copy on write semantics @@ -468,12 +454,12 @@ class Map : public NodeRef { * \return Handle to the internal node container(which ganrantees to be unique) */ inline MapNode* CopyOnWrite() { - if (node_.get() == nullptr || !node_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { NodePtr n = make_node(); - n->data = static_cast(node_.get())->data; - NodePtr(std::move(n)).swap(node_); + n->data = static_cast(data_.get())->data; + ObjectPtr(std::move(n)).swap(data_); } - return static_cast(node_.get()); + return static_cast(data_.get()); } /*! * \brief set the Map. @@ -482,7 +468,7 @@ class Map : public NodeRef { */ inline void Set(const K& key, const V& value) { MapNode* n = this->CopyOnWrite(); - n->data[key.node_] = value.node_; + n->data[key] = value; } /*! \return whether array is empty */ @@ -492,29 +478,31 @@ class Map : public NodeRef { /*! \brief specify container node */ using ContainerType = MapNode; - struct Ptr2NodeRef { + struct ValueConverter { using ResultType = std::pair; static inline ResultType convert(const std::pair< - NodePtr, - NodePtr >& n) { - return std::make_pair(K(n.first), V(n.second)); + ObjectRef, + ObjectRef>& n) { + return std::make_pair(DowncastNoCheck(n.first), + DowncastNoCheck(n.second)); } }; using iterator = IterAdapter< - Ptr2NodeRef, MapNode::ContainerType::const_iterator>; + ValueConverter, MapNode::ContainerType::const_iterator>; /*! \return begin iterator */ inline iterator begin() const { - return iterator(static_cast(node_.get())->data.begin()); + return iterator(static_cast(data_.get())->data.begin()); } /*! \return end iterator */ inline iterator end() const { - return iterator(static_cast(node_.get())->data.end()); + return iterator(static_cast(data_.get())->data.end()); } /*! 
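Array and Map both keep the copy-on-write contract spelled out above: every mutator first calls CopyOnWrite(), which clones the node unless this handle is the unique owner. A brief usage sketch against the post-patch Array API:

#include <tvm/expr.h>
#include <tvm/node/container.h>

using namespace tvm;

void cow_demo() {
  Array<Integer> a;
  a.push_back(Integer(1));
  Array<Integer> b = a;   // b shares a's ArrayNode (refcount is now 2)
  a.Set(0, Integer(42));  // CopyOnWrite() clones the node before mutating
  // b still observes 1 at index 0; a observes 42.
}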
\return begin iterator */ inline iterator find(const K& key) const { - return iterator(static_cast(node_.get())->data.find(key.node_)); + return iterator( + static_cast(data_.get())->data.find(key)); } }; @@ -524,14 +512,14 @@ class Map : public NodeRef { public: // for code reuse Map() { - node_ = make_node(); + data_ = make_node(); } Map(Map && other) { // NOLINT(*) - node_ = std::move(other.node_); + data_ = std::move(other.data_); } - Map(const Map &other) : NodeRef(other.node_) { // NOLINT(*) + Map(const Map &other) : NodeRef(other.data_) { // NOLINT(*) } - explicit Map(NodePtr n) : NodeRef(n) {} + explicit Map(ObjectPtr n) : NodeRef(n) {} template Map(IterType begin, IterType end) { assign(begin, end); @@ -545,76 +533,77 @@ class Map : public NodeRef { assign(init.begin(), init.end()); } Map& operator=(Map && other) { - node_ = std::move(other.node_); + data_ = std::move(other.data_); return *this; } Map& operator=(const Map & other) { - node_ = other.node_; + data_ = other.data_; return *this; } template void assign(IterType begin, IterType end) { auto n = make_node(); for (IterType i = begin; i != end; ++i) { - n->data.emplace(std::make_pair(i->first, - i->second.node_)); + n->data.emplace(std::make_pair(i->first, i->second)); } - node_ = std::move(n); + data_ = std::move(n); } inline const V operator[](const std::string& key) const { - return V(static_cast(node_.get())->data.at(key)); + return DowncastNoCheck( + static_cast(data_.get())->data.at(key)); } inline const V at(const std::string& key) const { - return V(static_cast(node_.get())->data.at(key)); + return DowncastNoCheck( + static_cast(data_.get())->data.at(key)); } inline size_t size() const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.size(); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.size(); } inline size_t count(const std::string& key) const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.count(key); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.count(key); } inline StrMapNode* CopyOnWrite() { - if (node_.get() == nullptr || !node_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { NodePtr n = make_node(); - n->data = static_cast(node_.get())->data; - NodePtr(std::move(n)).swap(node_); + n->data = static_cast(data_.get())->data; + ObjectPtr(std::move(n)).swap(data_); } - return static_cast(node_.get()); + return static_cast(data_.get()); } inline void Set(const std::string& key, const V& value) { StrMapNode* n = this->CopyOnWrite(); - n->data[key] = value.node_; + n->data[key] = value; } inline bool empty() const { return size() == 0; } using ContainerType = StrMapNode; - struct Ptr2NodeRef { + struct ValueConverter { using ResultType = std::pair; static inline ResultType convert(const std::pair< - std::string, - NodePtr >& n) { - return std::make_pair(n.first, V(n.second)); + std::string, + ObjectRef>& n) { + return std::make_pair(n.first, DowncastNoCheck(n.second)); } }; using iterator = IterAdapter< - Ptr2NodeRef, StrMapNode::ContainerType::const_iterator>; + ValueConverter, StrMapNode::ContainerType::const_iterator>; /*! \return begin iterator */ inline iterator begin() const { - return iterator(static_cast(node_.get())->data.begin()); + return iterator(static_cast(data_.get())->data.begin()); } /*! 
\return end iterator */ inline iterator end() const { - return iterator(static_cast(node_.get())->data.end()); + return iterator(static_cast(data_.get())->data.end()); } /*! \return begin iterator */ inline iterator find(const std::string& key) const { - return iterator(static_cast(node_.get())->data.find(key)); + return iterator(static_cast(data_.get())->data.find(key)); } }; diff --git a/include/tvm/node/ir_functor.h b/include/tvm/node/ir_functor.h index 23c5a3fafdab..e902e8fb6d44 100644 --- a/include/tvm/node/ir_functor.h +++ b/include/tvm/node/ir_functor.h @@ -34,10 +34,10 @@ namespace tvm { /*! - * \brief A dynamically dispatched functor on NodeRef in the first argument. + * \brief A dynamically dispatched functor on ObjectRef in the first argument. * * \code - * IRFunctor tostr; + * IRFunctor tostr; * tostr.set_dispatch([](const Add* op, std::string prefix) { * return prefix + "Add"; * }); @@ -60,10 +60,10 @@ template class IRFunctor; template -class IRFunctor { +class IRFunctor { private: - using Function = std::function; - using TSelf = IRFunctor; + using Function = std::function; + using TSelf = IRFunctor; /*! \brief internal function table */ std::vector func_; @@ -75,8 +75,8 @@ class IRFunctor { * \param n The node to be dispatched * \return Whether dispatching function is registered for n's type. */ - inline bool can_dispatch(const NodeRef& n) const { - uint32_t type_index = n.type_index(); + inline bool can_dispatch(const ObjectRef& n) const { + uint32_t type_index = n->type_index(); return type_index < func_.size() && func_[type_index] != nullptr; } /*! @@ -85,12 +85,12 @@ class IRFunctor { * \param args The additional arguments * \return The result. */ - inline R operator()(const NodeRef& n, Args... args) const { - uint32_t type_index = n.type_index(); + inline R operator()(const ObjectRef& n, Args... args) const { + uint32_t type_index = n->type_index(); CHECK(type_index < func_.size() && func_[type_index] != nullptr) << "IRFunctor calls un-registered function on type " - << Node::TypeIndex2Key(type_index); + << n->GetTypeKey(); return func_[type_index](n, std::forward(args)...); } /*! @@ -101,19 +101,19 @@ class IRFunctor { */ template inline TSelf& set_dispatch(Function f) { // NOLINT(*) - uint32_t tindex = Node::TypeKey2Index(TNode::_type_key); + uint32_t tindex = TNode::RuntimeTypeIndex(); if (func_.size() <= tindex) { func_.resize(tindex + 1, nullptr); } CHECK(func_[tindex] == nullptr) - << "Dispatch for " << Node::TypeIndex2Key(tindex) + << "Dispatch for " << TNode::_type_key << " is already set"; func_[tindex] = f; return *this; } /*! * \brief set the dispatcher for type TNode - * This allows f to used detailed const Node pointer to replace NodeRef + * This allows f to use a detailed const Node pointer to replace ObjectRef * * \param f The function to be set. * \tparam TNode the type of Node to be dispatched. * \return reference to self. */ template inline TSelf& set_dispatch(std::function f) { // NOLINT(*) - Function fun = [f](const ObjectRef& n, Args... 
args) { + return f(static_cast(n.get()), std::forward(args)...); }; return this->set_dispatch(fun); @@ -135,7 +135,7 @@ class IRFunctor { */ template inline TSelf& clear_dispatch() { // NOLINT(*) - uint32_t tindex = Node::TypeKey2Index(TNode::_type_key); + uint32_t tindex = TNode::RuntimeTypeIndex(); CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range"; func_[tindex] = nullptr; return *this; @@ -172,7 +172,7 @@ class IRFunctor { * f(e, this); * } * - * using FType = IRFunctor; + * using FType = IRFunctor; * // function to return global function table * static FType& vtable(); * }; @@ -232,15 +232,15 @@ template class IRFunctorStaticRegistry; template -class IRFunctorStaticRegistry { +class IRFunctorStaticRegistry { private: - IRFunctor *irf_; + IRFunctor *irf_; std::shared_ptr free_list; - using TSelf = IRFunctorStaticRegistry; + using TSelf = IRFunctorStaticRegistry; public: - IRFunctorStaticRegistry(IRFunctor *irf) { + IRFunctorStaticRegistry(IRFunctor *irf) { irf_ = irf; free_list = std::make_shared(); } @@ -261,12 +261,12 @@ class IRFunctorStaticRegistry { * the compiler to deduce the template types. */ template -IRFunctorStaticRegistry MakeIRFunctorStaticRegistry( - IRFunctor *irf) { - return IRFunctorStaticRegistry(irf); +IRFunctorStaticRegistry MakeIRFunctorStaticRegistry( + IRFunctor *irf) { + return IRFunctorStaticRegistry(irf); } -#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \ +#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \ static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName /*! diff --git a/include/tvm/node/memory.h b/include/tvm/node/memory.h deleted file mode 100644 index 1bba57144e19..000000000000 --- a/include/tvm/node/memory.h +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/node/memory.h - * \brief Node memory management. - */ -#ifndef TVM_NODE_MEMORY_H_ -#define TVM_NODE_MEMORY_H_ - -#include -#include "node.h" - -namespace tvm { -/*! - * \brief Allocate a node object. - * \param args arguments to the constructor. - * \tparam T the node type. - * \return The NodePtr to the allocated object. - */ -template -inline NodePtr make_node(Args&&... args); - -// Detail implementations after this -// -// The current design allows swapping the -// allocator pattern when necessary. -// -// Possible future allocator optimizations: -// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr) -// - Thread-local object pools: one pool per size and alignment requirement. -// - Can specialize by type of object to give the specific allocator to each object. -// -template -class SimpleNodeAllocator { - public: - template - static T* New(Args&&... 
args) { - return new T(std::forward(args)...); - } - static NodeBase::FDeleter Deleter() { - return Deleter_; - } - - private: - static void Deleter_(NodeBase* ptr) { - delete static_cast(ptr); - } -}; - -template -inline NodePtr make_node(Args&&... args) { - using Allocator = SimpleNodeAllocator; - static_assert(std::is_base_of::value, - "make_node can only be used to create NodeBase"); - T* node = Allocator::New(std::forward(args)...); - node->deleter_ = Allocator::Deleter(); - return NodePtr(node); -} - -} // namespace tvm -#endif // TVM_NODE_MEMORY_H_ diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index cb18e46e9a5c..8203ee69f686 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -25,7 +25,9 @@ #include #include -#include +#include +#include +#include #include #include #include @@ -38,13 +40,6 @@ class DataType; class Node; class NodeRef; -namespace runtime { -// forward declaration -class NDArray; -// forward declaration -class ObjectRef; -} // namespace runtime - /*! * \brief Visitor class to each node content. * The content is going to be called for each field. @@ -74,15 +69,17 @@ class TVM_DLL AttrVisitor { //! \endcond }; +/*! \brief Reuse the type index in the runtime. */ +using TypeIndex = runtime::TypeIndex; + /*! * \brief base class of node container in DSL AST. */ -class TVM_DLL Node : public NodeBase { +class Node : public runtime::Object { public: /*! \brief virtual destructor */ virtual ~Node() {} - /*! \return The unique type key of the node */ - virtual const char* type_key() const = 0; + /*! * \brief Apply visitor to each field of the Node * Visitor could mutate the content of the node. * \param visitor The visitor */ virtual void VisitAttrs(AttrVisitor* visitor) {} - /*! \return the type index of the node */ - virtual uint32_t type_index() const = 0; - /*! - * \brief Whether this node derives from node with type_index=tid. - * Implemented by TVM_DECLARE_NODE_TYPE_INFO - * - * \param tid The type index. - * \return the check result. - */ - virtual bool _DerivedFrom(uint32_t tid) const; - /*! - * \brief get a runtime unique type index given a type key - * \param type_key Type key of a type. - * \return the corresponding type index. - */ - static uint32_t TypeKey2Index(const char* type_key); - /*! - * \brief get type key from type index. - * \param index The type index - * \return the corresponding type key. - */ - static const char* TypeIndex2Key(uint32_t index); - /*! - * \return whether the type is derived from - */ - template - inline bool derived_from() const; - /*! - * \return whether the node is of type T - * \tparam The type to be checked. - */ - template - inline bool is_type() const; - /*! - * \brief Get a NodePtr that holds reference to this Node. - * \return the NodePtr - */ - inline NodePtr GetNodePtr() const; - // node ref can see this - friend class NodeRef; + static constexpr const char* _type_key = "Node"; + static constexpr uint32_t _type_index = TypeIndex::kDynamic; + + TVM_DECLARE_BASE_OBJECT_INFO(Node, runtime::Object); }; -/*! \brief Base class of all node reference object */ -class NodeRef { + +/*! + * \brief Base class of all node reference objects + * NodeRef is just an alias of ObjectRef. + */ +class NodeRef : public runtime::ObjectRef { public: /*! \brief type indicate the container type */ using ContainerType = Node; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. 
- */ - inline bool operator==(const NodeRef& other) const; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool same_as(const NodeRef& other) const; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool operator<(const NodeRef& other) const; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool operator!=(const NodeRef& other) const; - /*! \return the hash function for NodeRef */ - inline size_t hash() const; - /*! \return whether the expression is null */ - inline bool defined() const; - /*! \return the internal type index of IRNode */ - inline uint32_t type_index() const; + /*! \return the internal node pointer */ - inline const Node* get() const; + const Node* get() const { + return static_cast(ObjectRef::get()); + } /*! \return the internal node pointer */ - inline const Node* operator->() const; - /*! - * \brief Downcast this ir node to its actual type (e.g. Add, or - * Select). This returns nullptr if the node is not of the requested - * type. Example usage: - * - * if (const Add *add = node->as()) { - * // This is an add node - * } - * \tparam T the target type, must be subtype of IRNode - */ - template - inline const T *as() const; + const Node* operator->() const { + return get(); + } /*! * \brief A more powerful version of as that also works with * intermediate base types. * \tparam T the target type, must be subtype of IRNode */ template - inline const T *as_derived() const; + const T *as_derived() const { + return as(); + } /*! \brief default constructor */ NodeRef() = default; - explicit NodeRef(NodePtr node) : node_(node) {} - /*! \brief the internal node object, do not touch */ - NodePtr node_; + explicit NodeRef(runtime::ObjectPtr ptr) : ObjectRef(ptr) {} }; -/*! - * \brief Get a reference type from a Node ptr type - * - * It is always important to get a reference type - * if we want to return a value as reference or keep - * the node alive beyond the scope of the function. - * - * \param ptr The node pointer - * \tparam RefType The reference type - * \tparam NodeType The node type - * \return The corresponding RefType - */ -template -inline RefType GetRef(const NodeType* ptr); - -/*! - * \brief Downcast a base reference type to a more specific type. - * - * \param ref The inptut reference - * \return The corresponding SubRef. - * \tparam SubRef The target specific reference type. - * \tparam BaseRef the current reference type. - */ -template -inline SubRef Downcast(BaseRef ref); - /*! * \brief helper macro to declare type information in a base node. */ -#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \ - bool _DerivedFrom(uint32_t tid) const override { \ - static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ - if (tidx == tid) return true; \ - return Parent::_DerivedFrom(tid); \ - } +#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \ + TVM_DECLARE_BASE_OBJECT_INFO(TypeName, Parent) /*! 
* \brief helper macro to declare type information in a terminal node */ -#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \ - const char* type_key() const final { \ - return TypeName::_type_key; \ - } \ - uint32_t type_index() const final { \ - static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ - return tidx; \ - } \ - bool _DerivedFrom(uint32_t tid) const final { \ - static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ - if (tidx == tid) return true; \ - return Parent::_DerivedFrom(tid); \ - } - -// implementations of inline functions after this -template -inline bool Node::derived_from() const { - // use static field so query only happens once. - static uint32_t type_id = Node::TypeKey2Index(T::_type_key); - return this->_DerivedFrom(type_id); -} - - -template -inline bool Node::is_type() const { - // use static field so query only happens once. - static uint32_t type_id = Node::TypeKey2Index(T::_type_key); - return type_id == this->type_index(); -} +#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \ + TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, Parent); -inline NodePtr Node::GetNodePtr() const { - return NodePtr(const_cast(this)); -} +using runtime::Object; +using runtime::ObjectPtr; +using runtime::ObjectRef; +using runtime::GetRef; +using runtime::Downcast; +using runtime::make_object; +using runtime::ObjectHash; +using runtime::ObjectEqual; -template -inline RefType GetRef(const NodeType* ptr) { - static_assert(std::is_base_of::value, - "Can only cast to the ref of same container type"); - return RefType(ptr->GetNodePtr()); -} - -template -inline SubRef Downcast(BaseRef ref) { - CHECK(ref->template is_type() || - ref->template derived_from()) - << "Downcast from " << ref->type_key() << " to " - << SubRef::ContainerType::_type_key << " failed."; - return SubRef(std::move(ref.node_)); -} - -inline const Node* NodeRef::get() const { - return node_.get(); -} - -inline const Node* NodeRef::operator->() const { - return node_.get(); -} - -inline bool NodeRef::defined() const { - return node_.get() != nullptr; -} - -inline bool NodeRef::operator==(const NodeRef& other) const { - return node_.get() == other.node_.get(); -} +using NodeHash = ObjectHash; +using NodeEqual = ObjectEqual; -inline bool NodeRef::same_as(const NodeRef& other) const { - return node_.get() == other.node_.get(); -} - -inline bool NodeRef::operator<(const NodeRef& other) const { - return node_.get() < other.node_.get(); -} - -inline bool NodeRef::operator!=(const NodeRef& other) const { - return node_.get() != other.node_.get(); -} - -inline size_t NodeRef::hash() const { - return std::hash()(node_.get()); -} - -inline uint32_t NodeRef::type_index() const { - CHECK(node_.get() != nullptr) - << "null type"; - return get()->type_index(); -} - -template -inline const T* NodeRef::as() const { - const Node* ptr = static_cast(get()); - if (ptr && ptr->is_type()) { - return static_cast(ptr); - } - return nullptr; -} - -template -inline const T* NodeRef::as_derived() const { - const Node* ptr = static_cast(get()); - if (ptr && (ptr->is_type() || ptr->derived_from())) { - return static_cast(ptr); - } - return nullptr; +/*! + * \brief Allocate a node object. + * \param args arguments to the constructor. + * \tparam T the node type. + * \return The NodePtr to the allocated object. + */ +template +inline NodePtr make_node(Args&&... args) { + return runtime::make_object(std::forward(args)...); } - -/*! 
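Thanks to the aliases and forwarding macros above, legacy node definitions keep compiling unchanged: TVM_DECLARE_NODE_TYPE_INFO expands to TVM_DECLARE_FINAL_OBJECT_INFO, and make_node is now a thin wrapper over runtime::make_object. A hypothetical node written against the shim:

#include <tvm/node/node.h>

namespace tvm {

class DemoNode : public Node {
 public:
  int payload{0};
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("payload", &payload);
  }
  static constexpr const char* _type_key = "demo.DemoNode";
  TVM_DECLARE_NODE_TYPE_INFO(DemoNode, Node);  // now the Object macro underneath
};

inline NodePtr<DemoNode> MakeDemo(int x) {
  NodePtr<DemoNode> n = make_node<DemoNode>();  // allocates via make_object
  n->payload = x;
  return n;
}

}  // namespace tvm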
\brief The hash function for nodes */ -struct NodeHash { - size_t operator()(const NodeRef& a) const { - return a.hash(); - } -}; - -/*! \brief The equal comparator for nodes */ -struct NodeEqual { - bool operator()(const NodeRef& a, const NodeRef& b) const { - return a.get() == b.get(); - } -}; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/operation.h b/include/tvm/operation.h index b950aa952f04..b942464d4907 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -651,7 +651,7 @@ inline Tensor compute(Array shape, // inline function. inline const OperationNode* Operation::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm #endif // TVM_OPERATION_H_ diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index 5951594b873c..48d46fdf2fc6 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -37,6 +37,7 @@ #include "runtime/packed_func.h" namespace tvm { + using runtime::TVMArgs; using runtime::TVMRetValue; using runtime::PackedFunc; @@ -47,86 +48,82 @@ namespace runtime { * \tparam T the type to be checked. */ template -struct NodeTypeChecker { - static inline bool Check(Node* sptr) { - // This is the only place in the project where RTTI is used - // It can be turned off, but will make non strict checking. - // TODO(tqchen) possibly find alternative to turn of RTTI +struct ObjectTypeChecker { + static bool Check(const Object* ptr) { using ContainerType = typename T::ContainerType; - // always allow nullptr. - if (sptr == nullptr) return true; - return sptr->derived_from(); + if (ptr == nullptr) return true; + return ptr->IsInstance(); } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + static void PrintName(std::ostream& os) { // NOLINT(*) using ContainerType = typename T::ContainerType; os << ContainerType::_type_key; } }; template -struct NodeTypeChecker > { - static inline bool Check(Node* sptr) { - if (sptr == nullptr) return true; - if (!sptr->is_type()) return false; - ArrayNode* n = static_cast(sptr); +struct ObjectTypeChecker > { + static bool Check(const Object* ptr) { + if (ptr == nullptr) return true; + if (!ptr->IsInstance()) return false; + const ArrayNode* n = static_cast(ptr); for (const auto& p : n->data) { - if (!NodeTypeChecker::Check(p.get())) { + if (!ObjectTypeChecker::Check(p.get())) { return false; } } return true; } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "array<"; - NodeTypeChecker::PrintName(os); - os << ">"; + static void PrintName(std::ostream& os) { // NOLINT(*) + os << "List["; + ObjectTypeChecker::PrintName(os); + os << "]"; } }; template -struct NodeTypeChecker > { - static inline bool Check(Node* sptr) { - if (sptr == nullptr) return true; - if (!sptr->is_type()) return false; - StrMapNode* n = static_cast(sptr); +struct ObjectTypeChecker > { + static bool Check(const Object* ptr) { + if (ptr == nullptr) return true; + if (!ptr->IsInstance()) return false; + const StrMapNode* n = static_cast(ptr); for (const auto& kv : n->data) { - if (!NodeTypeChecker::Check(kv.second.get())) return false; + if (!ObjectTypeChecker::Check(kv.second.get())) return false; } return true; } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "map::PrintName(os); - os << '>'; + ObjectTypeChecker::PrintName(os); + os << ']'; } }; template -struct NodeTypeChecker > { - static inline bool Check(Node* sptr) { - if (sptr == nullptr) return true; - if 
(!sptr->is_type()) return false; - MapNode* n = static_cast(sptr); +struct ObjectTypeChecker > { + static bool Check(const Object* ptr) { + if (ptr == nullptr) return true; + if (!ptr->IsInstance()) return false; + const MapNode* n = static_cast(ptr); for (const auto& kv : n->data) { - if (!NodeTypeChecker::Check(kv.first.get())) return false; - if (!NodeTypeChecker::Check(kv.second.get())) return false; + if (!ObjectTypeChecker::Check(kv.first.get())) return false; + if (!ObjectTypeChecker::Check(kv.second.get())) return false; } return true; } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "map<"; - NodeTypeChecker::PrintName(os); + static void PrintName(std::ostringstream& os) { // NOLINT(*) + os << "Map["; + ObjectTypeChecker::PrintName(os); os << ','; - NodeTypeChecker::PrintName(os); - os << '>'; + ObjectTypeChecker::PrintName(os); + os << ']'; } }; template -inline std::string NodeTypeName() { +inline std::string ObjectTypeName() { std::ostringstream os; - NodeTypeChecker::PrintName(os); + ObjectTypeChecker::PrintName(os); return os.str(); } @@ -138,12 +135,12 @@ inline TNodeRef TVMArgValue::AsNodeRef() const { std::is_base_of::value, "Conversion only works for NodeRef"); if (type_code_ == kNull) return TNodeRef(NodePtr(nullptr)); - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - NodePtr& sptr = *ptr >(); - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Expected type " << NodeTypeName() - << " but get " << sptr->type_key(); - return TNodeRef(sptr); + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + Object* ptr = static_cast(value_.v_handle); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() + << " but get " << ptr->GetTypeKey(); + return TNodeRef(ObjectPtr(ptr)); } inline TVMArgValue::operator tvm::Expr() const { @@ -156,18 +153,20 @@ inline TVMArgValue::operator tvm::Expr() const { if (type_code_ == kDLFloat) { return Expr(static_cast(value_.v_float64)); } - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - NodePtr& sptr = *ptr >(); - if (sptr->is_type()) { - return IterVar(sptr)->var; + + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + Object* ptr = static_cast(value_.v_handle); + + if (ptr->IsInstance()) { + return IterVar(ObjectPtr(ptr))->var; } - if (sptr->is_type()) { - return Tensor(sptr)(); + if (ptr->IsInstance()) { + return Tensor(ObjectPtr(ptr))(); } - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Expected type " << NodeTypeName() - << " but get " << sptr->type_key(); - return Expr(sptr); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() + << " but get " << ptr->GetTypeKey(); + return Expr(ObjectPtr(ptr)); } inline TVMArgValue::operator tvm::Integer() const { @@ -177,68 +176,36 @@ inline TVMArgValue::operator tvm::Integer() const { CHECK_GE(value_.v_int64, std::numeric_limits::min()); return Integer(static_cast(value_.v_int64)); } - NodePtr& sptr = *ptr >(); - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Expected type " << NodeTypeName() - << " but get " << sptr->type_key(); - return Integer(sptr); -} - -inline NodePtr& TVMArgValue::node_sptr() { - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - return *ptr >(); + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + Object* ptr = static_cast(value_.v_handle); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() + << " but get " << ptr->GetTypeKey(); + return Integer(ObjectPtr(ptr)); } - template -inline bool TVMArgValue::IsNodeType() const { - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - NodePtr& 
sptr = - *ptr >(); - return NodeTypeChecker::Check(sptr.get()); +inline bool TVMPODValue_::IsObjectRef() const { + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + Object* ptr = static_cast(value_.v_handle); + return ObjectTypeChecker::Check(ptr); } // extensions for TVMRetValue -inline TVMRetValue& TVMRetValue::operator=( - const NodePtr& other) { - if (other.get() == nullptr) { - SwitchToPOD(kNull); - } else { - SwitchToClass >(kNodeHandle, other); - } - return *this; -} - -inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) { - if (!other.defined()) { - SwitchToPOD(kNull); - } else { - SwitchToClass >(kNodeHandle, other.node_); - } - return *this; -} - template inline TNodeRef TVMRetValue::AsNodeRef() const { static_assert( std::is_base_of::value, "Conversion only works for NodeRef"); if (type_code_ == kNull) return TNodeRef(); - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - NodePtr& sptr = *ptr >(); - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Expected type " << NodeTypeName() - << " but get " << sptr->type_key(); - return TNodeRef(sptr); -} + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); -inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*) - if (other.defined()) { - values_[i].v_handle = const_cast*>(&(other.node_)); - type_codes_[i] = kNodeHandle; - } else { - type_codes_[i] = kNull; - } + Object* ptr = static_cast(value_.v_handle); + + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() + << " but get " << ptr->GetTypeKey(); + return TNodeRef(ObjectPtr(ptr)); } // type related stuffs diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 4329c438e8a0..e54d88d5a393 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -52,7 +52,7 @@ class PatternNode : public RelayNode { class Pattern : public NodeRef { public: Pattern() {} - explicit Pattern(NodePtr p) : NodeRef(p) {} + explicit Pattern(ObjectPtr p) : NodeRef(p) {} using ContainerType = PatternNode; }; diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index f94ba5e26068..15330b00e961 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -83,10 +83,12 @@ using NodeEqual = ::tvm::NodeEqual; #define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ class TypeName : public NodeRefBase { \ public: \ - TypeName() {} \ - explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRefBase(n) {} \ + TypeName() {} \ + explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \ + : NodeRefBase(n) { \ + } \ const NodeName* operator->() const { \ - return static_cast(node_.get()); \ + return static_cast(get()); \ } \ operator bool() { return this->defined(); } \ using ContainerType = NodeName; \ @@ -127,7 +129,7 @@ class SourceName : public NodeRef { * \return the pointer to the internal node container */ inline const SourceNameNode* operator->() const { - return static_cast(this->node_.get()); + return static_cast(get()); } /*! 
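The net effect of the packed_func_ext.h rewrite is that every node crossing the PackedFunc boundary now travels as a bare Object* tagged kObjectHandle, and ObjectTypeChecker validates it before a reference is re-wrapped around it. A rough sketch of that round trip (the helper function is hypothetical; the conversion calls are the ones shown above):

#include <tvm/expr.h>
#include <tvm/packed_func_ext.h>

using namespace tvm;

// Pass an Expr through a PackedFunc and recover it with a checked conversion.
Expr RoundTrip(runtime::PackedFunc identity) {
  runtime::TVMRetValue rv = identity(Expr(1));  // travels as kObjectHandle
  // AsNodeRef runs ObjectTypeChecker<Expr>::Check on the stored Object*
  // and fails loudly if the payload is not an Expr container type.
  return rv.AsNodeRef<Expr>();
}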
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index b1b8d6a7154e..281b99297e78 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -541,10 +541,11 @@ RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr); // implementataions inline const Type& ExprNode::checked_type() const { - CHECK(checked_type_.defined()) << "internal error: the type checker has " - "not populated the checked_type " - "field for " - << GetRef(this); + CHECK(checked_type_.defined()) + << "internal error: the type checker has " + << "not populated the checked_type " + << "field for " + << GetRef(this); return this->checked_type_; } @@ -557,7 +558,7 @@ inline const TTypeNode* ExprNode::type_as() const { const TTypeNode* node = checked_type_.as(); CHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key - << ", but get " << checked_type_->type_key(); + << ", but get " << checked_type_->GetTypeKey(); return node; } diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index e0d940c5d1a5..8bc87a27f66f 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -57,8 +57,8 @@ class ExprFunctor; #define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitExpr_(static_cast(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.get()), \ std::forward(args)...); \ }); @@ -66,7 +66,7 @@ template class ExprFunctor { private: using TSelf = ExprFunctor; - using FType = tvm::IRFunctor; + using FType = tvm::IRFunctor; public: /*! \brief the result type of this functor */ @@ -117,7 +117,7 @@ class ExprFunctor { virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args...) { - LOG(FATAL) << "Do not have a default for " << op->type_key(); + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; } diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index d05099f781ac..a0422fa7f446 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -78,9 +78,9 @@ class ValueNode : public RelayNode { class Value : public NodeRef { public: Value() {} - explicit Value(NodePtr n) : NodeRef(n) {} + explicit Value(ObjectPtr n) : NodeRef(n) {} const ValueNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } using ContainerType = ValueNode; diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 8b17020a1132..10d72349d0f5 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -281,10 +281,10 @@ class ModuleNode : public RelayNode { struct Module : public NodeRef { Module() {} - explicit Module(NodePtr p) : NodeRef(p) {} + explicit Module(ObjectPtr<::tvm::Object> p) : NodeRef(p) {} - inline ModuleNode* operator->() const { - return static_cast(node_.get()); + ModuleNode* operator->() const { + return static_cast(get_mutable()); } using ContainerType = ModuleNode; diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 0a6d3725655f..572c194bc269 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -138,7 +138,7 @@ class Op : public relay::Expr { /*! \brief default constructor */ Op() {} /*! 
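Only the dispatch plumbing in RELAY_EXPR_FUNCTOR_DISPATCH changed above; concrete visitors are untouched by the NodeRef to ObjectRef switch because they already receive raw node pointers. An illustrative functor (not part of the patch):

#include <tvm/relay/expr_functor.h>

namespace tvm {
namespace relay {

// Returns 1 for a call node, 0 for anything else. Dispatch goes through the
// rewritten vtable, but the overrides look exactly as they did before.
class IsCallCheck : public ExprFunctor<int(const Expr&)> {
 public:
  int VisitExpr_(const CallNode* op) final { return 1; }
  int VisitExprDefault_(const Node* op) final { return 0; }
};

}  // namespace relay
}  // namespace tvm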
\brief constructor from node pointer */ - explicit Op(NodePtr n) : Expr(n) {} + explicit Op(ObjectPtr n) : Expr(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container */ @@ -221,11 +221,12 @@ class OpRegistry { const Attrs&, const TypeReporter&)> type_rel_func); /*! - * \brief Set the type key of attributes. - * \param type_key The type of of the attrs field. + * \brief Set the attrs type key and index to be AttrsType. + * \tparam AttrsType the attribute type to be set. * \return reference to self. */ - inline OpRegistry& set_attrs_type_key(const std::string& type_key); + template + inline OpRegistry& set_attrs_type(); /*! * \brief Set the num_inputs * \param n The number of inputs to be set. * \return reference to self. */ @@ -397,7 +398,7 @@ class OpMap { // implementations inline const OpNode* Op::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } template @@ -496,10 +497,10 @@ inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) return *this; } -inline OpRegistry& OpRegistry::set_attrs_type_key( // NOLINT(*) - const std::string& type_key) { - get()->attrs_type_key = type_key; - get()->attrs_type_index = Node::TypeKey2Index(type_key.c_str()); +template +inline OpRegistry& OpRegistry::set_attrs_type() { // NOLINT(*) + get()->attrs_type_key = AttrsType::_type_key; + get()->attrs_type_index = AttrsType::RuntimeTypeIndex(); return *this; } diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h index 7f1c47e03592..c15523cb25de 100644 --- a/include/tvm/relay/pattern_functor.h +++ b/include/tvm/relay/pattern_functor.h @@ -57,8 +57,8 @@ class PatternFunctor; #define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitPattern_(static_cast(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitPattern_(static_cast(n.get()), \ std::forward(args)...); \ }); @@ -66,7 +66,7 @@ template class PatternFunctor { private: using TSelf = PatternFunctor; - using FType = tvm::IRFunctor; + using FType = tvm::IRFunctor; public: /*! \brief the result type of this functor */ @@ -103,7 +103,7 @@ class PatternFunctor { virtual R VisitPattern_(const PatternTupleNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; virtual R VisitPatternDefault_(const Node* op, Args...) { - LOG(FATAL) << "Do not have a default for " << op->type_key(); + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; } diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index a2119c90f750..08ea3075cb83 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -134,16 +134,16 @@ class PassContext : public NodeRef { * \return const access pointer. */ const PassContextNode* operator->() const { - CHECK(node_.get() != nullptr); - return static_cast(node_.get()); + CHECK(get() != nullptr); + return static_cast(get()); } /*! * \brief mutable accessor. * \return mutable access pointer. */ PassContextNode* operator->() { - CHECK(node_.get() != nullptr); - return static_cast(node_.get()); + CHECK(get() != nullptr); + return static_cast(get_mutable()); } /*! * \brief Construct a PassContext containing the default configurations. 
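For downstream code, the visible API change above is that set_attrs_type<AttrsType>() now records both the attrs type key and its runtime type index, where set_attrs_type_key() took only the key string. A sketch of a registration migrated to the new form (MyDemoAttrs and the op name are made up for illustration):

#include <tvm/attrs.h>
#include <tvm/relay/op.h>

namespace tvm {
namespace relay {

// Hypothetical attrs node used only to illustrate the registration call.
struct MyDemoAttrs : public AttrsNode<MyDemoAttrs> {
  int axis;
  TVM_DECLARE_ATTRS(MyDemoAttrs, "relay.attrs.MyDemoAttrs") {
    TVM_ATTR_FIELD(axis).set_default(0).describe("Axis to operate on.");
  }
};

RELAY_REGISTER_OP("demo.my_op")
.describe("Hypothetical op illustrating set_attrs_type.")
.set_num_inputs(1)
.set_attrs_type<MyDemoAttrs>();  // was: set_attrs_type_key("relay.attrs.MyDemoAttrs")

}  // namespace relay
}  // namespace tvm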
diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 16e36785c533..a5cc3c83383e 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -58,7 +58,7 @@ class TypeNode : public RelayNode { class Type : public NodeRef { public: Type() {} - explicit Type(NodePtr p) : NodeRef(p) {} + explicit Type(ObjectPtr p) : NodeRef(p) {} using ContainerType = TypeNode; }; @@ -430,10 +430,11 @@ class TypeReporterNode : public Node { class TypeReporter : public NodeRef { public: TypeReporter() {} - explicit TypeReporter(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) { + explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : NodeRef(n) { } TypeReporterNode* operator->() const { - return static_cast(node_.get()); + return const_cast( + static_cast(get())); } using ContainerType = TypeReporterNode; }; diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index b058fd63a2f5..267504beb11a 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -98,13 +98,12 @@ typedef enum { kTVMType = 5U, kTVMContext = 6U, kArrayHandle = 7U, - kNodeHandle = 8U, + kObjectHandle = 8U, kModuleHandle = 9U, kFuncHandle = 10U, kStr = 11U, kBytes = 12U, kNDArrayContainer = 13U, - kObjectHandle = 14U, // Extension codes for other frameworks to integrate TVM PackedFunc. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h index 6b4f01e4ac9b..01c08d324fcb 100644 --- a/include/tvm/runtime/memory.h +++ b/include/tvm/runtime/memory.h @@ -69,7 +69,7 @@ class ObjAllocatorBase { "make_node can only be used to create NodeBase"); T* ptr = Handler::New(static_cast(this), std::forward(args)...); - ptr->type_index_ = T::type_index(); + ptr->type_index_ = T::RuntimeTypeIndex(); ptr->deleter_ = Handler::Deleter(); return ObjectPtr(ptr); } diff --git a/include/tvm/runtime/node_base.h b/include/tvm/runtime/node_base.h deleted file mode 100644 index 8b47c18a09a7..000000000000 --- a/include/tvm/runtime/node_base.h +++ /dev/null @@ -1,259 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/node_base.h - * \brief Base data structure for Node. - * - * \note Node is not a runtime feature. - * This file only exposes the signature of NodePtr for PackedFunc. - */ -#ifndef TVM_RUNTIME_NODE_BASE_H_ -#define TVM_RUNTIME_NODE_BASE_H_ - -#include -#include - -namespace tvm { - -// forward declarations -template -class NodePtr; -class Node; -class NodeRef; - -/*! - * \brief Base class of Node for runtime destructor purposes. - * - * Node is a reference counted object which is used to construct AST. 
- * Each node is backed by a custom deleter, which deletes the object. - * Do not call create raw Node pointer, always use tvm::make_node. - * - * \note In most cases, please inheritate tvm::Node. - * \sa Node, NodePtr, make_node - */ -class NodeBase { - public: - /*! - * \brief type of NodeBase deleter - * \param self pointer to the NodeBase. - */ - typedef void (*FDeleter)(NodeBase* self); - - protected: - // default constructor and copy constructor - NodeBase() {} - // override the copy and assign constructors to do nothing. - // This is to make sure only contents, but not deleter and ref_counter - // are copied when a child class copies itself. - NodeBase(const NodeBase& other) { // NOLINT(*) - } - NodeBase(NodeBase&& other) { // NOLINT(*) - } - NodeBase& operator=(const NodeBase& other) { //NOLINT(*) - return *this; - } - NodeBase& operator=(NodeBase&& other) { //NOLINT(*) - return *this; - } - - private: - /*! \brief Internal reference counter */ - std::atomic ref_counter_{0}; - /*! - * \brief deleter of this object to enable customized allocation. - * If the deleter is nullptr, no deletion will be performed. - * The creator of the Node must always set the deleter field properly. - */ - FDeleter deleter_ = nullptr; - // reference counting functions - void IncRef() { - ref_counter_.fetch_add(1, std::memory_order_relaxed); - } - void DecRef() { - if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { - std::atomic_thread_fence(std::memory_order_acquire); - if (this->deleter_ != nullptr) { - (*this->deleter_)(this); - } - } - } - int use_count() const { - return ref_counter_.load(std::memory_order_relaxed); - } - // friend declaration - template - friend class NodePtr; - template - friend NodePtr make_node(Args&&...); -}; - -/*! - * \brief Smart pointer for Node containers, - * must be subclass of NodeBase - * \tparam T the content data type. - */ -template -class NodePtr { - public: - /*! \brief default constructor */ - NodePtr() {} - /*! \brief default constructor */ - NodePtr(std::nullptr_t) {} // NOLINT(*) - /*! - * \brief copy constructor - * \param other The value to be moved - */ - NodePtr(const NodePtr& other) // NOLINT(*) - : NodePtr(other.data_) { - } - /*! - * \brief copy constructor - * \param other The value to be moved - */ - template - NodePtr(const NodePtr& other) // NOLINT(*) - : NodePtr(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class NodePtr to parent"); - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - NodePtr(NodePtr&& other) // NOLINT(*) - : data_(other.data_) { - other.data_ = nullptr; - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - template - NodePtr(NodePtr&& other) // NOLINT(*) - : data_(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class NodePtr to parent"); - other.data_ = nullptr; - } - /*! \brief destructor */ - ~NodePtr() { - this->reset(); - } - /*! - * \brief Swap this array with another NDArray - * \param other The other NDArray - */ - void swap(NodePtr& other) { // NOLINT(*) - std::swap(data_, other.data_); - } - /*! - * \return Get the content of the pointer - */ - T* get() const { - return static_cast(data_); - } - /*! - * \return The pointer - */ - T* operator->() const { - return get(); - } - /*! - * \return The reference - */ - T& operator*() const { // NOLINT(*) - return *get(); - } - /*! - * \brief copy assignmemt - * \param other The value to be assigned. - * \return reference to self. 
- */ - NodePtr& operator=(const NodePtr& other) { // NOLINT(*) - // takes in plane operator to enable copy elison. - // copy-and-swap idiom - NodePtr(other).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief move assignmemt - * \param other The value to be assigned. - * \return reference to self. - */ - NodePtr& operator=(NodePtr&& other) { // NOLINT(*) - // copy-and-swap idiom - NodePtr(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! \brief reset the content of ptr to be nullptr */ - void reset() { - if (data_ != nullptr) { - data_->DecRef(); - data_ = nullptr; - } - } - /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { - return data_ != nullptr ? data_->use_count() : 0; - } - /*! \return whether the reference is unique */ - bool unique() const { - return data_ != nullptr && data_->use_count() == 1; - } - /*! \return Whether two NodePtr do not equals each other */ - bool operator==(const NodePtr& other) const { - return data_ == other.data_; - } - /*! \return Whether two NodePtr equals each other */ - bool operator!=(const NodePtr& other) const { - return data_ != other.data_; - } - /*! \return Whether the pointer is nullptr */ - bool operator==(std::nullptr_t null) const { - return data_ == nullptr; - } - /*! \return Whether the pointer is not nullptr */ - bool operator!=(std::nullptr_t null) const { - return data_ != nullptr; - } - - private: - /*! \brief internal pointer field */ - NodeBase* data_{nullptr}; - /*! - * \brief constructor from NodeBase - * \param data The node base pointer - */ - explicit NodePtr(NodeBase* data) - : data_(data) { - if (data != nullptr) { - data_->IncRef(); - } - } - // friend declaration - friend class Node; - template - friend class NodePtr; - template - friend NodePtr make_node(Args&&...); -}; -} // namespace tvm - -#endif // TVM_RUNTIME_NODE_BASE_H_ diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 7291510c16df..143f3bb35220 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -65,7 +65,7 @@ enum TypeIndex { * - _type_index: * Static type index of the object, if assigned to TypeIndex::kDynamic * the type index will be assigned during runtime. - * Runtime type index can be accessed by ObjectType::type_index(); + * Runtime type index can be accessed by ObjectType::TypeIndex(); * - _type_key: * The unique string identifier of tyep type. * - _type_final: @@ -147,10 +147,23 @@ class Object { * \param self pointer to the Object. */ typedef void (*FDeleter)(Object* self); - /*! \return The internal type index of the object. */ + /*! \return The internal runtime type index of the object. */ uint32_t type_index() const { return type_index_; } + /*! + * \return the type key of the object. + * \note this operation is expensive, can be used for error reporting. + */ + std::string GetTypeKey() const { + return TypeIndex2Key(type_index_); + } + /*! + * \return A hash value of the return of GetTypeKey. + */ + size_t GetTypeKeyHash() const { + return TypeIndex2KeyHash(type_index_); + } /*! * Check if the object is an instance of TargetType. * \tparam TargetType The target type to be checked. @@ -159,6 +172,25 @@ class Object { template inline bool IsInstance() const; + /*! + * \brief Get the type key of the corresponding index from runtime. + * \param tindex The type index. + * \return the result. + */ + TVM_DLL static std::string TypeIndex2Key(uint32_t tindex); + /*! + * \brief Get the type key hash of the corresponding index from runtime. 
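The object.h hunk above pairs the fast, index-based IsInstance check with string-level reflection (GetTypeKey, GetTypeKeyHash) for error paths. A hedged usage sketch, assuming the standard TVM headers of this series:

    #include <tvm/node/container.h>   // ArrayNode
    #include <tvm/runtime/object.h>

    // GetTypeKey() is documented above as expensive (it maps type_index_
    // back to its string key), so keep it on the failure path only.
    std::string DescribeOrFail(const tvm::runtime::Object* obj) {
      if (obj->IsInstance<tvm::ArrayNode>()) {
        return "array";
      }
      LOG(FATAL) << "unexpected object of type " << obj->GetTypeKey();
      return "";
    }
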
+ * \param tindex The type index. + * \return the related key-hash. + */ + TVM_DLL static size_t TypeIndex2KeyHash(uint32_t tindex); + /*! + * \brief Get the type index of the corresponding key from runtime. + * \param key The type key. + * \return the result. + */ + TVM_DLL static uint32_t TypeKey2Index(const char* key); + #if TVM_OBJECT_ATOMIC_REF_COUNTER using RefCounterType = std::atomic; #else @@ -170,9 +202,30 @@ class Object { static constexpr bool _type_final = false; static constexpr uint32_t _type_child_slots = 0; static constexpr bool _type_child_slots_can_overflow = true; - static const uint32_t _GetOrAllocRuntimeTypeIndex() { + static uint32_t _GetOrAllocRuntimeTypeIndex() { return 0; } + static uint32_t RuntimeTypeIndex() { + return 0; + } + + // Default constructor and copy constructor + Object() {} + // Override the copy and assign constructors to do nothing. + // This is to make sure only contents, but not deleter and ref_counter + // are copied when a child class copies itself. + // This will enable us to use make_object(*obj_ptr) + // to copy an existing object. + Object(const Object& other) { // NOLINT(*) + } + Object(Object&& other) { // NOLINT(*) + } + Object& operator=(const Object& other) { //NOLINT(*) + return *this; + } + Object& operator=(Object&& other) { //NOLINT(*) + return *this; + } protected: // The fields of the base object cell. @@ -215,18 +268,6 @@ class Object { uint32_t type_child_slots, bool type_child_slots_can_overflow); - /*! - * \brief Get the type key of the corresponding index from runtime. - * \param tindex The type index. - */ - TVM_DLL static std::string TypeIndex2Key(uint32_t tindex); - - /*! - * \brief Get the type index of the corresponding key from runtime. - * \param key The type key. - */ - TVM_DLL static uint32_t TypeKey2Index(const char* key); - private: // reference counter related operations /*! \brief developer function, increases reference counter. */ @@ -256,6 +297,32 @@ class Object { friend class TVMObjectCAPI; }; +/*! + * \brief Get a reference type from a raw object ptr type + * + * It is always important to get a reference type + * if we want to return a value as reference or keep + * the node alive beyond the scope of the function. + * + * \param ptr The node pointer + * \tparam RefType The reference type + * \tparam ObjectType The node type + * \return The corresponding RefType + */ +template +inline RefType GetRef(const ObjectType* ptr); + +/*! + * \brief Downcast a base reference type to a more specific type. + * + * \param ref The inptut reference + * \return The corresponding SubRef. + * \tparam SubRef The target specific reference type. + * \tparam BaseRef the current reference type. + */ +template +inline SubRef Downcast(BaseRef ref); + /*! * \brief A custom smart pointer for Object. * \tparam T the content data type. @@ -389,7 +456,7 @@ class ObjectPtr { /*! \brief internal pointer field */ Object* data_{nullptr}; /*! - * \brief constructor from NodeBase + * \brief constructor from Object * \param data The data pointer */ explicit ObjectPtr(Object* data) : data_(data) { @@ -400,6 +467,7 @@ class ObjectPtr { // friend classes friend class Object; friend class ObjectRef; + friend struct ObjectHash; template friend class ObjectPtr; template @@ -407,6 +475,9 @@ class ObjectPtr { friend class TVMPODValue_; friend class TVMArgsSetter; friend class TVMRetValue; + friend class TVMArgValue; + template + friend RefType GetRef(const ObjType* ptr); }; /*! 
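This hunk declares the two conversion helpers, GetRef and Downcast, that replace the old NodePtr plumbing. A minimal sketch of their intended use; relay::Type/TypeNode are the types touched at the top of this patch:

    // A visitor usually holds a raw node pointer; GetRef turns it back
    // into an owning reference and bumps the refcount.
    tvm::relay::Type ToRef(const tvm::relay::TypeNode* raw) {
      return tvm::runtime::GetRef<tvm::relay::Type>(raw);
    }

    // Downcast converts a base reference to a more specific one, with a
    // CHECK on IsInstance (see its definition later in this file).
    tvm::relay::Type FromBase(tvm::NodeRef base) {
      return tvm::runtime::Downcast<tvm::relay::Type>(base);
    }
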
\brief Base class of all object reference */ @@ -416,10 +487,54 @@ class ObjectRef { ObjectRef() = default; /*! \brief Constructor from existing object ptr */ explicit ObjectRef(ObjectPtr data) : data_(data) {} + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool same_as(const ObjectRef& other) const { + return data_ == other.data_; + } + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool operator==(const ObjectRef& other) const { + return data_ == other.data_; + } + /*! + * \brief Comparator + * \param other Another node ref. + * \return the compare result. + */ + bool operator!=(const ObjectRef& other) const { + return data_ != other.data_; + } + /*! + * \brief Comparator + * \param other Another object ref by address. + * \return the compare result. + */ + bool operator<(const ObjectRef& other) const { + return data_.get() < other.data_.get(); + } + /*! \return whether the expression is null */ + bool defined() const { + return data_ != nullptr; + } /*! \return the internal object pointer */ - inline const Object* get() const; + const Object* get() const { + return data_.get(); + } /*! \return the internal node pointer */ - inline const Object* operator->() const; + const Object* operator->() const { + return get(); + } + /*! \return whether the reference is unique */ + bool unique() const { + return data_.unique(); + } /*! * \brief Try to downcast the internal Object to a * raw pointer of a corresponding type. @@ -434,25 +549,81 @@ class ObjectRef { template inline const ObjectType* as() const; - /*! \brief type indicate the container type */ + /*! \brief type indicate the container type. */ using ContainerType = Object; protected: /*! \brief Internal pointer that backs the reference. */ ObjectPtr data_; + /*! \return return a mutable internal ptr, can be used by sub-classes. */ + Object* get_mutable() const { + return data_.get(); + } + /*! + * \brief Internal helper function downcast a ref without check. + * \note Only used for internal dev purposes. + * \tparam T The target reference type. + * \return The casted result. + */ + template + static T DowncastNoCheck(ObjectRef ref) { + return T(std::move(ref.data_)); + } + /*! + * \brief Internal helper function get data_ as ObjectPtr of ObjectType. + * \note only used for internal dev purpose. + * \tparam ObjectType The corresponding object type. + * \return the corresponding type. + */ + template + static ObjectPtr GetDataPtr(const ObjectRef& ref) { + return ObjectPtr(ref.data_.data_); + } // friend classes. + friend struct ObjectHash; friend class TVMRetValue; friend class TVMArgsSetter; + template + friend SubRef Downcast(BaseRef ref); }; + +/*! \brief ObjectRef hash functor */ +struct ObjectHash { + size_t operator()(const ObjectRef& a) const { + return operator()(a.data_); + } + + template + size_t operator()(const ObjectPtr& a) const { + return std::hash()(a.get()); + } +}; + + +/*! \brief ObjectRef equal functor */ +struct ObjectEqual { + bool operator()(const ObjectRef& a, const ObjectRef& b) const { + return a.same_as(b); + } + + template + size_t operator()(const ObjectPtr& a, const ObjectPtr& b) const { + return a == b; + } +}; + + /*! * \brief helper macro to declare a base object type that can be inheritated. * \param TypeName The name of the current type. 
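The comparators plus the ObjectHash/ObjectEqual functors above exist so references can key STL containers by pointer identity. Sketch:

    #include <unordered_map>

    // Keys hash the underlying pointer and compare with same_as(), so two
    // refs to the same node land in the same bucket.
    using RefCount =
        std::unordered_map<tvm::runtime::ObjectRef, int,
                           tvm::runtime::ObjectHash, tvm::runtime::ObjectEqual>;

    void Count(RefCount* table, const tvm::runtime::ObjectRef& ref) {
      (*table)[ref] += 1;
    }
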
* \param ParentType The name of the ParentType */ #define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ - static const uint32_t type_index() { \ - if (_type_index != TypeIndex::kDynamic) return _type_index; \ + static const uint32_t RuntimeTypeIndex() { \ + if (_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \ + return _type_index; \ + } \ return _GetOrAllocRuntimeTypeIndex(); \ } \ static const uint32_t _GetOrAllocRuntimeTypeIndex() { \ @@ -551,11 +722,11 @@ inline bool Object::IsInstance() const { if (TargetType::_type_final) { // if the target type is a final type // then we only need to check the equivalence. - return self->type_index_ == TargetType::type_index(); + return self->type_index_ == TargetType::RuntimeTypeIndex(); } else { // if target type is a non-leaf type // Check if type index falls into the range of reserved slots. - uint32_t begin = TargetType::type_index(); + uint32_t begin = TargetType::RuntimeTypeIndex(); // The condition will be optimized by constant-folding. if (TargetType::_type_child_slots != 0) { uint32_t end = begin + TargetType::_type_child_slots; @@ -565,22 +736,15 @@ inline bool Object::IsInstance() const { } if (!TargetType::_type_child_slots_can_overflow) return false; // Invariance: parent index is always smaller than the child. - if (self->type_index_ < TargetType::type_index()) return false; + if (self->type_index_ < TargetType::RuntimeTypeIndex()) return false; // The rare slower-path, check type hierachy. - return self->DerivedFrom(TargetType::type_index()); + return self->DerivedFrom(TargetType::RuntimeTypeIndex()); } } else { return false; } } -inline const Object* ObjectRef::get() const { - return data_.data_; -} - -inline const Object* ObjectRef::operator->() const { - return get(); -} template inline const ObjectType* ObjectRef::as() const { @@ -591,7 +755,27 @@ inline const ObjectType* ObjectRef::as() const { return nullptr; } } + +template +inline RefType GetRef(const ObjType* ptr) { + static_assert(std::is_base_of::value, + "Can only cast to the ref of same container type"); + return RefType(ObjectPtr(const_cast(static_cast(ptr)))); +} + +template +inline SubRef Downcast(BaseRef ref) { + CHECK(ref->template IsInstance()) + << "Downcast from " << ref->GetTypeKey() << " to " + << SubRef::ContainerType::_type_key << " failed."; + return SubRef(std::move(ref.data_)); +} + } // namespace runtime + +template +using NodePtr = runtime::ObjectPtr; + } // namespace tvm #endif // TVM_RUNTIME_OBJECT_H_ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 2bfa3323e4f1..649a5058a9a5 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -40,7 +40,6 @@ #include "module.h" #include "ndarray.h" #include "object.h" -#include "node_base.h" // Whether use TVM runtime in header only mode. 
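The IsInstance rewrite above relies on type indices being allocated so that a parent's children occupy a contiguous slot range. A standalone model of the fast path, with made-up indices (say Base at index 10 reserving 4 child slots, so any index in [10, 14) is a Base); the slower DerivedFrom overflow path is omitted:

    #include <cstdint>

    bool IsInstanceFastPath(uint32_t self_index, uint32_t begin,
                            uint32_t child_slots) {
      if (child_slots != 0) {
        uint32_t end = begin + child_slots;
        return self_index >= begin && self_index < end;
      }
      // Types without reserved slots only match their own index here.
      return self_index == begin;
    }
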
#ifndef TVM_RUNTIME_HEADER_ONLY @@ -52,6 +51,8 @@ namespace tvm { class Integer; class DataType; class Expr; +class Node; +class NodeRef; namespace runtime { @@ -490,9 +491,12 @@ class TVMPODValue_ { return NDArray(static_cast(value_.v_handle)); } operator ObjectRef() const { - if (type_code_ == kNull) return ObjectRef(ObjectPtr(nullptr)); + if (type_code_ == kNull) { + return ObjectRef(ObjectPtr(nullptr)); + } TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); - return ObjectRef(ObjectPtr(static_cast(value_.v_handle))); + return ObjectRef( + ObjectPtr(static_cast(value_.v_handle))); } operator TVMContext() const { TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); @@ -512,9 +516,14 @@ class TVMPODValue_ { CHECK_LT(type_code_, kExtEnd); return static_cast(value_.v_handle)[0]; } + template::value>::type> + inline bool IsObjectRef() const; int type_code() const { return type_code_; } + /*! * \brief return handle as specific pointer type. * \tparam T the data type. @@ -567,6 +576,7 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator NDArray; using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator ObjectRef; + using TVMPODValue_::IsObjectRef; // conversion operator. operator std::string() const { @@ -616,15 +626,9 @@ class TVMArgValue : public TVMPODValue_ { typename = typename std::enable_if< std::is_class::value>::type> inline operator T() const; - template::value>::type> - inline bool IsNodeType() const; inline operator tvm::DataType() const; inline operator tvm::Expr() const; inline operator tvm::Integer() const; - // get internal node ptr, if it is node - inline NodePtr& node_sptr(); }; /*! @@ -663,6 +667,8 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator ObjectRef; + using TVMPODValue_::IsObjectRef; + TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } @@ -760,11 +766,19 @@ class TVMRetValue : public TVMPODValue_ { return *this; } TVMRetValue& operator=(ObjectRef other) { - this->Clear(); - type_code_ = kObjectHandle; - // move the handle out - value_.v_handle = other.data_.data_; - other.data_.data_ = nullptr; + return operator=(std::move(other.data_)); + } + template + TVMRetValue& operator=(ObjectPtr other) { + if (other.data_ != nullptr) { + this->Clear(); + type_code_ = kObjectHandle; + // move the handle out + value_.v_handle = other.data_; + other.data_ = nullptr; + } else { + SwitchToPOD(kNull); + } return *this; } TVMRetValue& operator=(PackedFunc f) { @@ -814,7 +828,7 @@ class TVMRetValue : public TVMPODValue_ { } /*! 
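With the TVMRetValue assignments above, a packed function can return object handles directly, and an empty ObjectPtr now degrades to kNull instead of a kObjectHandle wrapping a null pointer. A hedged sketch; the global name is made up:

    #include <tvm/node/container.h>
    #include <tvm/runtime/registry.h>

    TVM_REGISTER_GLOBAL("example.MakeArray")
    .set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) {
      auto node = tvm::make_node<tvm::ArrayNode>();
      // Uses the new operator=(ObjectPtr<T>): moves the handle out and
      // tags the slot kObjectHandle (kNull if the pointer were empty).
      *rv = node;
    });
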
\return The value field, if the data is POD */ const TVMValue& value() const { - CHECK(type_code_ != kNodeHandle && + CHECK(type_code_ != kObjectHandle && type_code_ != kFuncHandle && type_code_ != kModuleHandle && type_code_ != kStr) << "TVMRetValue.value can only be used for POD data"; @@ -827,8 +841,6 @@ class TVMRetValue : public TVMPODValue_ { inline operator T() const; template inline TNodeRef AsNodeRef() const; - inline TVMRetValue& operator=(const NodeRef& other); - inline TVMRetValue& operator=(const NodePtr& other); // type related inline operator tvm::DataType() const; inline TVMRetValue& operator=(const tvm::DataType& other); @@ -857,11 +869,6 @@ class TVMRetValue : public TVMPODValue_ { *this = other.operator NDArray(); break; } - case kNodeHandle: { - SwitchToClass >( - kNodeHandle, *other.template ptr >()); - break; - } case kObjectHandle: { *this = other.operator ObjectRef(); break; @@ -908,7 +915,6 @@ class TVMRetValue : public TVMPODValue_ { case kStr: delete ptr(); break; case kFuncHandle: delete ptr(); break; case kModuleHandle: delete ptr(); break; - case kNodeHandle: delete ptr >(); break; case kNDArrayContainer: { static_cast(value_.v_handle)->DecRef(); break; @@ -939,7 +945,6 @@ inline const char* TypeCode2Str(int type_code) { case kBytes: return "bytes"; case kHandle: return "handle"; case kNull: return "NULL"; - case kNodeHandle: return "NodeHandle"; case kArrayHandle: return "ArrayHandle"; case kTVMType: return "TVMType"; case kTVMContext: return "TVMContext"; @@ -1057,8 +1062,6 @@ inline PackedFunc::FType PackedFunc::body() const { return body_; } - - // internal namespace namespace detail { @@ -1163,8 +1166,12 @@ class TVMArgsSetter { type_codes_[i] = kNDArrayContainer; } void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*) - values_[i].v_handle = value.data_.data_; - type_codes_[i] = kObjectHandle; + if (value.defined()) { + values_[i].v_handle = value.data_.data_; + type_codes_[i] = kObjectHandle; + } else { + type_codes_[i] = kNull; + } } void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*) if (value.type_code() == kStr) { @@ -1181,8 +1188,6 @@ class TVMArgsSetter { typename = typename std::enable_if< extension_type_info::code != 0>::type> inline void operator()(size_t i, const T& value) const; - // NodeRef related extenstions: in tvm/packed_func_ext.h - inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*) inline void operator()(size_t i, const tvm::DataType& t) const; private: diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index af3e929ac3fa..36265667e5b6 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -56,7 +56,7 @@ enum AttachType : int { class Stage : public NodeRef { public: Stage() {} - explicit Stage(NodePtr n) : NodeRef(n) {} + explicit Stage(ObjectPtr n) : NodeRef(n) {} /*! * \brief create a new schedule for op. * \param op The operator in the schedule @@ -280,7 +280,7 @@ class Stage : public NodeRef { class Schedule : public NodeRef { public: Schedule() {} - explicit Schedule(NodePtr n) : NodeRef(n) {} + explicit Schedule(ObjectPtr n) : NodeRef(n) {} /*! * \brief Get a copy of current schedule. * \return The copied schedule. @@ -403,7 +403,7 @@ class Schedule : public NodeRef { class IterVarRelation : public NodeRef { public: IterVarRelation() {} - explicit IterVarRelation(NodePtr n) : NodeRef(n) {} + explicit IterVarRelation(ObjectPtr n) : NodeRef(n) {} /*! 
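The TVMArgsSetter hunk above makes an undefined ObjectRef travel as kNull, so callees keep their existing null checks instead of receiving an object handle that wraps nullptr. Sketch:

    void CallMaybeNull(const tvm::runtime::PackedFunc& f,
                       const tvm::runtime::ObjectRef& ref) {
      // If ref.defined() is false, the argument is packed with type code
      // kNull, exactly as if the caller had passed nullptr.
      f(ref);
    }
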
* \brief access the internal node container * \return the pointer to the internal node container @@ -417,7 +417,7 @@ class IterVarRelation : public NodeRef { class IterVarAttr : public NodeRef { public: IterVarAttr() {} - explicit IterVarAttr(NodePtr n) : NodeRef(n) {} + explicit IterVarAttr(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -745,25 +745,25 @@ class SingletonNode : public IterVarRelationNode { // implementations inline const StageNode* Stage::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } inline StageNode* Stage::operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } inline const ScheduleNode* Schedule::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } inline ScheduleNode* Schedule::operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } inline const IterVarRelationNode* IterVarRelation::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } inline const IterVarAttrNode* IterVarAttr::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm #endif // TVM_SCHEDULE_H_ diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index f37cc7bed7d1..6471c9c69a62 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -50,7 +50,7 @@ class Tensor : public NodeRef { public: /*! \brief default constructor, used internally */ Tensor() {} - explicit Tensor(NodePtr n) : NodeRef(n) {} + explicit Tensor(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -141,7 +141,7 @@ class Operation : public ir::FunctionRef { public: /*! \brief default constructor */ Operation() {} - explicit Operation(NodePtr n) : FunctionRef(n) {} + explicit Operation(ObjectPtr n) : FunctionRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -189,7 +189,7 @@ class TensorNode : public Node { // Implementations of inline functions inline const TensorNode* Tensor::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } inline size_t Tensor::ndim() const { @@ -250,19 +250,17 @@ DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*) namespace std { template <> -struct hash<::tvm::Operation> { - std::size_t operator()(const ::tvm::Operation& k) const { - return k.hash(); - } +struct hash<::tvm::Operation> : public ::tvm::NodeHash { }; template <> struct hash<::tvm::Tensor> { std::size_t operator()(const ::tvm::Tensor& k) const { + ::tvm::NodeHash hasher; if (k.defined() && k->op.defined()) { - return k->op.hash(); + return hasher(k->op); } else{ - return k.hash(); + return hasher(k); } } }; diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h index b5ca6eb4358b..152a27f6e2a9 100644 --- a/include/tvm/tensor_intrin.h +++ b/include/tvm/tensor_intrin.h @@ -112,7 +112,7 @@ class TensorIntrinNode : public Node { }; inline const TensorIntrinNode* TensorIntrin::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } // Internal node container of tensor intrinsic calling. 
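The tensor.h hunk above re-expresses std::hash<Tensor> through NodeHash, hashing via the defining op when one exists. Call sites are unchanged; sketch:

    #include <unordered_set>

    // Tensors from the same Operation keep colliding as before; undefined
    // tensors fall back to hashing the reference itself.
    bool Seen(std::unordered_set<tvm::Tensor>* cache, const tvm::Tensor& t) {
      return !cache->insert(t).second;
    }
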
@@ -170,7 +170,7 @@ class TensorIntrinCallNode : public Node { }; inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm diff --git a/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc b/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc index 1eff6c45e1fc..b4bfd4270775 100644 --- a/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc @@ -242,7 +242,7 @@ extern "C" int funcInvokeCallback(TVMValue *args, for (int i = 0; i < numArgs; ++i) { TVMValue arg = args[i]; int tcode = typeCodes[i]; - if (tcode == kNodeHandle || tcode == kFuncHandle || tcode == kModuleHandle) { + if (tcode == kObjectHandle || tcode == kFuncHandle || tcode == kModuleHandle) { TVMCbArgToReturn(&arg, tcode); } jobject jarg = tvmRetValueToJava(env, arg, tcode); @@ -259,8 +259,8 @@ extern "C" int funcInvokeCallback(TVMValue *args, reinterpret_cast(resourceHandle), jargs); TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); - const int prevNumStrArg = e->tvmFuncArgPushedStrs.size(); - const int prevNumBytesArg = e->tvmFuncArgPushedBytes.size(); + const size_t prevNumStrArg = e->tvmFuncArgPushedStrs.size(); + const size_t prevNumBytesArg = e->tvmFuncArgPushedBytes.size(); // convert returned (java) TVMValue to (C) TVMValue env->CallStaticVoidMethod(clsFunc, pushArgToStack, jretValue); diff --git a/nnvm/include/nnvm/compiler/util.h b/nnvm/include/nnvm/compiler/util.h index fa8b69f9b70a..9555c0e7b3ea 100644 --- a/nnvm/include/nnvm/compiler/util.h +++ b/nnvm/include/nnvm/compiler/util.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -56,7 +56,7 @@ inline tvm::Array ShapeToArray(TShape shape) { * \return An Array of Expr, where each element is a constant int32 */ inline tvm::Array ShapeToIntArray(TShape shape) { - return tvm::Array(ShapeToArray(shape).node_); + return tvm::Downcast >(ShapeToArray(shape)); } } // namespace compiler } // namespace nnvm diff --git a/nnvm/src/compiler/compile_engine.cc b/nnvm/src/compiler/compile_engine.cc index 3da95e879fa7..c9cdaef63935 100644 --- a/nnvm/src/compiler/compile_engine.cc +++ b/nnvm/src/compiler/compile_engine.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
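The nnvm hunk above replaces constructing an Array straight from another array's node_ with a checked Downcast. A sketch of the conversion; note the CHECK only verifies the container is an ArrayNode, element types remain the caller's responsibility:

    tvm::Array<tvm::Integer> AsIntArray(const tvm::Array<tvm::Expr>& arr) {
      // Elements are assumed valid ints, which ShapeToIntArray guarantees
      // by constructing each entry as a constant int32.
      return tvm::Downcast<tvm::Array<tvm::Integer>>(arr);
    }
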
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -392,6 +392,9 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs") *rv = ret; }); +TVM_REGISTER_NODE_TYPE(GraphFuncNode); +TVM_REGISTER_NODE_TYPE(GraphCacheEntryNode); + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const GraphFuncNode *op, IRPrinter *p) { p->stream << "GraphFunc(name=" << op->func_name diff --git a/nnvm/src/compiler/compile_engine.h b/nnvm/src/compiler/compile_engine.h index 35287f5a9358..e8d33cb4be7e 100644 --- a/nnvm/src/compiler/compile_engine.h +++ b/nnvm/src/compiler/compile_engine.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -92,7 +92,7 @@ class GraphCacheEntry : public ::tvm::NodeRef { GraphCacheEntry() {} explicit GraphCacheEntry(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {} GraphCacheEntryNode* operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } using ContainerType = GraphCacheEntryNode; }; diff --git a/nnvm/src/compiler/graph_runtime.h b/nnvm/src/compiler/graph_runtime.h index 3a847de83d9f..7b324ba100ad 100644 --- a/nnvm/src/compiler/graph_runtime.h +++ b/nnvm/src/compiler/graph_runtime.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,7 +28,6 @@ #include #include #include -#include #include #include #include diff --git a/nnvm/src/compiler/packed_func_ext.cc b/nnvm/src/compiler/packed_func_ext.cc index bbcc62a99ad8..45f1451663e6 100644 --- a/nnvm/src/compiler/packed_func_ext.cc +++ b/nnvm/src/compiler/packed_func_ext.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
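The two TVM_REGISTER_NODE_TYPE lines added to compile_engine.cc above are needed because, under the Object protocol, a node's dynamic type index must be registered before instances can be reflected on or checked with IsInstance. The general pattern, with a purely hypothetical node:

    struct DemoNode : public tvm::Node {
      static constexpr const char* _type_key = "demo.DemoNode";
      TVM_DECLARE_NODE_TYPE_INFO(DemoNode, tvm::Node);
    };

    // Allocates the runtime type index and registers a creator so the FFI
    // can construct and identify DemoNode by its type key.
    TVM_REGISTER_NODE_TYPE(DemoNode);
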
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -115,7 +115,7 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute") const Array& out_info) -> Array { TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, out_info); - if ((*ret.ptr<::tvm::NodePtr >())->derived_from()) { + if (ret.IsObjectRef()) { return {ret.operator Tensor()}; } else { return ret; diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index cafb99926bfa..ab18c2d7337a 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -1242,7 +1242,7 @@ Array GetIntArray(Array arr) { CHECK(!arr[i].defined() || arr[i].as()) << "Expect an int array"; } - return Array(arr.node_); + return Downcast >(arr); } NNVM_REGISTER_OP(slice_like) diff --git a/python/tvm/_ffi/_ctypes/function.py b/python/tvm/_ffi/_ctypes/function.py index 22fb6c335dcc..2f0b5babda4d 100644 --- a/python/tvm/_ffi/_ctypes/function.py +++ b/python/tvm/_ffi/_ctypes/function.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # coding: utf-8 -# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement +# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement, unused-import """Function configuration API.""" from __future__ import absolute_import @@ -32,9 +32,8 @@ from .types import TVMValue, TypeCode from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64 -from .node import NodeBase +from .object import ObjectBase, _set_class_node from . import object as _object -from . 
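IsObjectRef<T>(), declared in the packed_func.h hunk earlier, is the typed query these call sites now use in place of poking at a raw NodePtr. Conceptually (a sketch of its contract, not the verbatim implementation):

    // True iff the slot holds an object handle whose dynamic type is
    // TRef::ContainerType or a subtype of it.
    template <typename TRef>
    bool IsObjectRefLike(int type_code, const tvm::runtime::Object* ptr) {
      return type_code == kObjectHandle && ptr != nullptr &&
             ptr->IsInstance<typename TRef::ContainerType>();
    }
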
import node as _node FunctionHandle = ctypes.c_void_p ModuleHandle = ctypes.c_void_p @@ -108,9 +107,9 @@ def _make_tvm_args(args, temp_args): values = (TVMValue * num_args)() type_codes = (ctypes.c_int * num_args)() for i, arg in enumerate(args): - if isinstance(arg, NodeBase): + if isinstance(arg, ObjectBase): values[i].v_handle = arg.handle - type_codes[i] = TypeCode.NODE_HANDLE + type_codes[i] = TypeCode.OBJECT_HANDLE elif arg is None: values[i].v_handle = None type_codes[i] = TypeCode.NULL @@ -148,7 +147,7 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, (list, tuple, dict, NodeGeneric)): arg = convert_to_node(arg) values[i].v_handle = arg.handle - type_codes[i] = TypeCode.NODE_HANDLE + type_codes[i] = TypeCode.OBJECT_HANDLE temp_args.append(arg) elif isinstance(arg, _CLASS_MODULE): values[i].v_handle = arg.handle @@ -164,9 +163,6 @@ def _make_tvm_args(args, temp_args): values[i].v_handle = arg.handle type_codes[i] = TypeCode.FUNC_HANDLE temp_args.append(arg) - elif isinstance(arg, _CLASS_OBJECT): - values[i].v_handle = arg.handle - type_codes[i] = TypeCode.OBJECT_HANDLE else: raise TypeError("Don't know how to handle type %s" % type(arg)) return values, type_codes, num_args @@ -226,7 +222,7 @@ def __init_handle_by_constructor__(fconstructor, args): raise get_last_ffi_error() _ = temp_args _ = args - assert ret_tcode.value in (TypeCode.NODE_HANDLE, TypeCode.OBJECT_HANDLE) + assert ret_tcode.value == TypeCode.OBJECT_HANDLE handle = ret_val.v_handle return handle @@ -247,7 +243,6 @@ def _handle_return_func(x): return _CLASS_FUNCTION(handle, False) # setup return handle for function type -_node.__init_by_constructor__ = __init_handle_by_constructor__ _object.__init_by_constructor__ = __init_handle_by_constructor__ RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module diff --git a/python/tvm/_ffi/_ctypes/node.py b/python/tvm/_ffi/_ctypes/node.py deleted file mode 100644 index 39fe0ef35525..000000000000 --- a/python/tvm/_ffi/_ctypes/node.py +++ /dev/null @@ -1,102 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=invalid-name, protected-access -# pylint: disable=no-member, missing-docstring, not-callable -from __future__ import absolute_import - -import ctypes -from ..base import _LIB, check_call, c_str -from ..node_generic import _set_class_node_base -from .types import TVMValue, TypeCode -from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func - -NodeHandle = ctypes.c_void_p -__init_by_constructor__ = None - -"""Maps node type to its constructor""" -NODE_TYPE = {} - -def _register_node(index, cls): - """register node class""" - NODE_TYPE[index] = cls - -def _return_node(x): - """Return node function""" - handle = x.v_handle - if not isinstance(handle, NodeHandle): - handle = NodeHandle(handle) - tindex = ctypes.c_int() - check_call(_LIB.TVMNodeGetTypeIndex(handle, ctypes.byref(tindex))) - cls = NODE_TYPE.get(tindex.value, NodeBase) - # Avoid calling __init__ of cls, instead directly call __new__ - # This allows child class to implement their own __init__ - node = cls.__new__(cls) - node.handle = handle - return node - - -RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node -C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func( - _return_node, TypeCode.NODE_HANDLE) - - -class NodeBase(object): - __slots__ = ["handle"] - # pylint: disable=no-member - def __del__(self): - if _LIB is not None: - check_call(_LIB.TVMNodeFree(self.handle)) - - def __getattr__(self, name): - ret_val = TVMValue() - ret_type_code = ctypes.c_int() - ret_success = ctypes.c_int() - check_call(_LIB.TVMNodeGetAttr( - self.handle, c_str(name), - ctypes.byref(ret_val), - ctypes.byref(ret_type_code), - ctypes.byref(ret_success))) - if not ret_success.value: - raise AttributeError( - "'%s' object has no attribute '%s'" % (str(type(self)), name)) - return RETURN_SWITCH[ret_type_code.value](ret_val) - - def __init_handle_by_constructor__(self, fconstructor, *args): - """Initialize the handle by calling constructor function. - - Parameters - ---------- - fconstructor : Function - Constructor function. - - args: list of objects - The arguments to the constructor - - Note - ---- - We have a special calling convention to call constructor functions. - So the return handle is directly set into the Node object - instead of creating a new Node. 
- """ - # assign handle first to avoid error raising - self.handle = None - handle = __init_by_constructor__(fconstructor, args) - if not isinstance(handle, NodeHandle): - handle = NodeHandle(handle) - self.handle = handle - -_set_class_node_base(NodeBase) diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index 5ddceb166677..c3ae56822198 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -21,6 +21,7 @@ import ctypes from ..base import _LIB, check_call from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func +from ..node_generic import _set_class_node_base ObjectHandle = ctypes.c_void_p @@ -29,6 +30,13 @@ """Maps object type to its constructor""" OBJECT_TYPE = {} +_CLASS_NODE = None + +def _set_class_node(node_class): + global _CLASS_NODE + _CLASS_NODE = node_class + + def _register_object(index, cls): """register object class""" OBJECT_TYPE[index] = cls @@ -40,7 +48,7 @@ def _return_object(x): handle = ObjectHandle(handle) tindex = ctypes.c_uint() check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) - cls = OBJECT_TYPE.get(tindex.value, ObjectBase) + cls = OBJECT_TYPE.get(tindex.value, _CLASS_NODE) # Avoid calling __init__ of cls, instead directly call __new__ # This allows child class to implement their own __init__ obj = cls.__new__(cls) @@ -83,3 +91,6 @@ def __init_handle_by_constructor__(self, fconstructor, *args): if not isinstance(handle, ObjectHandle): handle = ObjectHandle(handle) self.handle = handle + + +_set_class_node_base(ObjectBase) diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 76fa96376b47..4b7b2c88ffa5 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -31,13 +31,12 @@ cdef enum TVMTypeCode: kTVMType = 5 kTVMContext = 6 kArrayHandle = 7 - kNodeHandle = 8 + kObjectHandle = 8 kModuleHandle = 9 kFuncHandle = 10 kStr = 11 kBytes = 12 kNDArrayContainer = 13 - kObjectHandle = 14 kExtBegin = 15 cdef extern from "tvm/runtime/c_runtime_api.h": @@ -78,7 +77,7 @@ ctypedef void* TVMStreamHandle ctypedef void* TVMRetValueHandle ctypedef void* TVMFunctionHandle ctypedef void* ObjectHandle -ctypedef void* NodeHandle + ctypedef struct TVMNDArrayContainer: DLTensor dl_tensor @@ -134,18 +133,6 @@ cdef extern from "tvm/runtime/c_runtime_api.h": int TVMObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index) -cdef extern from "tvm/c_dsl_api.h": - int TVMNodeFree(NodeHandle handle) - int TVMNodeTypeKey2Index(const char* type_key, - int* out_index) - int TVMNodeGetTypeIndex(NodeHandle handle, - int* out_index) - int TVMNodeGetAttr(NodeHandle handle, - const char* key, - TVMValue* out_value, - int* out_type_code, - int* out_success) - cdef inline py_str(const char* x): if PY_MAJOR_VERSION < 3: return x diff --git a/python/tvm/_ffi/_cython/core.pyx b/python/tvm/_ffi/_cython/core.pyx index a9349338fc6a..cbf9d5859046 100644 --- a/python/tvm/_ffi/_cython/core.pyx +++ b/python/tvm/_ffi/_cython/core.pyx @@ -17,7 +17,7 @@ include "./base.pxi" include "./object.pxi" -include "./node.pxi" +# include "./node.pxi" include "./function.pxi" include "./ndarray.pxi" diff --git a/python/tvm/_ffi/_cython/function.pxi b/python/tvm/_ffi/_cython/function.pxi index ceacf7407170..a2360427b6c7 100644 --- a/python/tvm/_ffi/_cython/function.pxi +++ b/python/tvm/_ffi/_cython/function.pxi @@ -41,10 +41,9 @@ cdef int tvm_callback(TVMValue* args, for i in range(num_args): value = args[i] tcode = type_codes[i] - if (tcode == kNodeHandle or 
+ if (tcode == kObjectHandle or tcode == kFuncHandle or tcode == kModuleHandle or - tcode == kObjectHandle or tcode > kExtBegin): CALL(TVMCbArgToReturn(&value, tcode)) @@ -98,9 +97,9 @@ cdef inline int make_arg(object arg, list temp_args) except -1: """Pack arguments into c args tvm call accept""" cdef unsigned long long ptr - if isinstance(arg, NodeBase): - value[0].v_handle = (arg).chandle - tcode[0] = kNodeHandle + if isinstance(arg, ObjectBase): + value[0].v_handle = (arg).chandle + tcode[0] = kObjectHandle elif isinstance(arg, NDArrayBase): value[0].v_handle = (arg).chandle tcode[0] = (kNDArrayContainer if @@ -152,12 +151,9 @@ cdef inline int make_arg(object arg, temp_args.append(tstr) elif isinstance(arg, (list, tuple, dict, NodeGeneric)): arg = convert_to_node(arg) - value[0].v_handle = (arg).chandle - tcode[0] = kNodeHandle - temp_args.append(arg) - elif isinstance(arg, _CLASS_OBJECT): value[0].v_handle = (arg).chandle tcode[0] = kObjectHandle + temp_args.append(arg) elif isinstance(arg, _CLASS_MODULE): value[0].v_handle = c_handle(arg.handle) tcode[0] = kModuleHandle @@ -188,9 +184,7 @@ cdef inline bytearray make_ret_bytes(void* chandle): cdef inline object make_ret(TVMValue value, int tcode): """convert result to return value.""" - if tcode == kNodeHandle: - return make_ret_node(value.v_handle) - elif tcode == kObjectHandle: + if tcode == kObjectHandle: return make_ret_object(value.v_handle) elif tcode == kNull: return None @@ -314,6 +308,7 @@ cdef class FunctionBase: _CLASS_FUNCTION = None _CLASS_MODULE = None _CLASS_OBJECT = None +_CLASS_NODE = None def _set_class_module(module_class): """Initialize the module.""" @@ -327,3 +322,7 @@ def _set_class_function(func_class): def _set_class_object(obj_class): global _CLASS_OBJECT _CLASS_OBJECT = obj_class + +def _set_class_node(node_class): + global _CLASS_NODE + _CLASS_NODE = node_class diff --git a/python/tvm/_ffi/_cython/node.pxi b/python/tvm/_ffi/_cython/node.pxi deleted file mode 100644 index 5e0c366e5600..000000000000 --- a/python/tvm/_ffi/_cython/node.pxi +++ /dev/null @@ -1,110 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from ... 
import _api_internal -from ..base import string_types -from ..node_generic import _set_class_node_base - -"""Maps node type to its constructor""" -NODE_TYPE = [] - -def _register_node(int index, object cls): - """register node class""" - while len(NODE_TYPE) <= index: - NODE_TYPE.append(None) - NODE_TYPE[index] = cls - - -cdef inline object make_ret_node(void* chandle): - global NODE_TYPE - cdef int tindex - cdef list node_type - cdef object cls - node_type = NODE_TYPE - CALL(TVMNodeGetTypeIndex(chandle, &tindex)) - if tindex < len(node_type): - cls = node_type[tindex] - if cls is not None: - obj = cls.__new__(cls) - else: - obj = NodeBase.__new__(NodeBase) - else: - obj = NodeBase.__new__(NodeBase) - (obj).chandle = chandle - return obj - - -cdef class NodeBase: - cdef void* chandle - - cdef _set_handle(self, handle): - cdef unsigned long long ptr - if handle is None: - self.chandle = NULL - else: - ptr = handle.value - self.chandle = (ptr) - - property handle: - def __get__(self): - if self.chandle == NULL: - return None - else: - return ctypes_handle(self.chandle) - - def __set__(self, value): - self._set_handle(value) - - def __dealloc__(self): - CALL(TVMNodeFree(self.chandle)) - - def __getattr__(self, name): - cdef TVMValue ret_val - cdef int ret_type_code, ret_succ - CALL(TVMNodeGetAttr(self.chandle, c_str(name), - &ret_val, &ret_type_code, &ret_succ)) - if ret_succ == 0: - raise AttributeError( - "'%s' object has no attribute '%s'" % (type(self), name)) - return make_ret(ret_val, ret_type_code) - - def __init_handle_by_constructor__(self, fconstructor, *args): - """Initialize the handle by calling constructor function. - - Parameters - ---------- - fconstructor : Function - Constructor function. - - args: list of objects - The arguments to the constructor - - Note - ---- - We have a special calling convention to call constructor functions. - So the return handle is directly set into the Node object - instead of creating a new Node. - """ - # avoid error raised during construction. - self.chandle = NULL - cdef void* chandle - ConstructorCall( - (fconstructor).chandle, - kNodeHandle, args, &chandle) - self.chandle = chandle - -_set_class_node_base(NodeBase) diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index 90be6a9c5b74..9561eab94ea2 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -16,6 +16,8 @@ # under the License. 
"""Maps object type to its constructor""" +from ..node_generic import _set_class_node_base + OBJECT_TYPE = [] def _register_object(int index, object cls): @@ -27,6 +29,7 @@ def _register_object(int index, object cls): cdef inline object make_ret_object(void* chandle): global OBJECT_TYPE + global _CLASS_NODE cdef unsigned tindex cdef list object_type cdef object cls @@ -39,9 +42,11 @@ cdef inline object make_ret_object(void* chandle): if cls is not None: obj = cls.__new__(cls) else: - obj = ObjectBase.__new__(ObjectBase) + # default use node base class + # TODO(tqchen) change to object after Node unifies with Object + obj = _CLASS_NODE.__new__(_CLASS_NODE) else: - obj = ObjectBase.__new__(ObjectBase) + obj = _CLASS_NODE.__new__(_CLASS_NODE) (obj).chandle = chandle return obj @@ -94,3 +99,6 @@ cdef class ObjectBase: (fconstructor).chandle, kObjectHandle, args, &chandle) self.chandle = chandle + + +_set_class_node_base(ObjectBase) diff --git a/python/tvm/_ffi/node.py b/python/tvm/_ffi/node.py index baca89d628b8..c6c151af9053 100644 --- a/python/tvm/_ffi/node.py +++ b/python/tvm/_ffi/node.py @@ -21,21 +21,8 @@ import ctypes import sys from .. import _api_internal +from .object import Object, register_object, _set_class_node from .node_generic import NodeGeneric, convert_to_node, const -from .base import _LIB, check_call, c_str, py_str, _FFI_MODE - -IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError -try: - # pylint: disable=wrong-import-position - if _FFI_MODE == "ctypes": - raise ImportError() - if sys.version_info >= (3, 0): - from ._cy3.core import _register_node, NodeBase as _NodeBase - else: - from ._cy2.core import _register_node, NodeBase as _NodeBase -except IMPORT_EXCEPT: - # pylint: disable=wrong-import-position - from ._ctypes.node import _register_node, NodeBase as _NodeBase def _new_object(cls): @@ -43,20 +30,22 @@ def _new_object(cls): return cls.__new__(cls) -class NodeBase(_NodeBase): +class NodeBase(Object): """NodeBase is the base class of all TVM language AST object.""" def __repr__(self): return _api_internal._format_str(self) def __dir__(self): - plist = ctypes.POINTER(ctypes.c_char_p)() - size = ctypes.c_uint() - check_call(_LIB.TVMNodeListAttrNames( - self.handle, ctypes.byref(size), ctypes.byref(plist))) - names = [] - for i in range(size.value): - names.append(py_str(plist[i])) - return names + fnames = _api_internal._NodeListAttrNames(self) + size = fnames(-1) + return [fnames(i) for i in range(size)] + + def __getattr__(self, name): + try: + return _api_internal._NodeGetAttr(self, name) + except AttributeError: + raise AttributeError( + "%s has no attribute %s" % (str(type(self)), name)) def __hash__(self): return _api_internal._raw_ptr(self) @@ -95,24 +84,6 @@ def same_as(self, other): return self.__hash__() == other.__hash__() -def register_node(type_key=None): - """register node type - - Parameters - ---------- - type_key : str or cls - The type key of the node - """ - node_name = type_key if isinstance(type_key, str) else type_key.__name__ - - def register(cls): - """internal register function""" - tindex = ctypes.c_int() - ret = _LIB.TVMNodeTypeKey2Index(c_str(node_name), ctypes.byref(tindex)) - if ret == 0: - _register_node(tindex.value, cls) - return cls - - if isinstance(type_key, str): - return register - return register(type_key) +# pylint: disable=invalid-name +register_node = register_object +_set_class_node(NodeBase) diff --git a/python/tvm/_ffi/object.py b/python/tvm/_ffi/object.py index be8b086a50f9..002fd27af0fd 100644 --- 
a/python/tvm/_ffi/object.py +++ b/python/tvm/_ffi/object.py @@ -20,25 +20,25 @@ import sys import ctypes -from .base import _FFI_MODE, check_call, _LIB, c_str +from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError try: - # pylint: disable=wrong-import-position + # pylint: disable=wrong-import-position,unused-import if _FFI_MODE == "ctypes": raise ImportError() if sys.version_info >= (3, 0): - from ._cy3.core import _set_class_object + from ._cy3.core import _set_class_object, _set_class_node from ._cy3.core import ObjectBase as _ObjectBase from ._cy3.core import _register_object else: - from ._cy2.core import _set_class_object + from ._cy2.core import _set_class_object, _set_class_node from ._cy2.core import ObjectBase as _ObjectBase from ._cy2.core import _register_object except IMPORT_EXCEPT: - # pylint: disable=wrong-import-position - from ._ctypes.function import _set_class_object + # pylint: disable=wrong-import-position,unused-import + from ._ctypes.function import _set_class_object, _set_class_node from ._ctypes.object import ObjectBase as _ObjectBase from ._ctypes.object import _register_object @@ -75,8 +75,15 @@ def register(cls): tindex = cls._type_index else: tidx = ctypes.c_uint() - check_call(_LIB.TVMObjectTypeKey2Index( - c_str(object_name), ctypes.byref(tidx))) + if not _RUNTIME_ONLY: + check_call(_LIB.TVMObjectTypeKey2Index( + c_str(object_name), ctypes.byref(tidx))) + else: + # directly skip unknown objects during runtime. + ret = _LIB.TVMObjectTypeKey2Index( + c_str(object_name), ctypes.byref(tidx)) + if ret != 0: + return cls tindex = tidx.value _register_object(tindex, cls) return cls diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 00e19459df76..2dbb67dfbf73 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -36,13 +36,12 @@ class TypeCode(object): TVM_TYPE = 5 TVM_CONTEXT = 6 ARRAY_HANDLE = 7 - NODE_HANDLE = 8 + OBJECT_HANDLE = 8 MODULE_HANDLE = 9 FUNC_HANDLE = 10 STR = 11 BYTES = 12 NDARRAY_CONTAINER = 13 - OBJECT_HANDLE = 14 EXT_BEGIN = 15 diff --git a/python/tvm/error.py b/python/tvm/error.py index b5a7ed2374b7..a6d4f701d2a6 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -49,6 +49,7 @@ def __init__(self, msg): register_error("ValueError", ValueError) register_error("TypeError", TypeError) +register_error("AttributeError", AttributeError) @register_error diff --git a/python/tvm/relay/backend/profiler_vm.py b/python/tvm/relay/backend/profiler_vm.py index b36715249f0a..ded5d0d13bd7 100644 --- a/python/tvm/relay/backend/profiler_vm.py +++ b/python/tvm/relay/backend/profiler_vm.py @@ -62,6 +62,10 @@ def compile(mod, target=None, target_host=None, params=None): compiler._compile(mod, target, target_host) return vm.Executable(compiler._get_exec()) +def enabled(): + """Whether vm profiler is enabled.""" + return hasattr(_vm, "_VMCompilerProfiler") + class VMCompilerProfiler(vm.VMCompiler): """Build Relay module to run on VM runtime.""" def __init__(self): diff --git a/python/tvm/relay/debug.py b/python/tvm/relay/debug.py index ee30f25d88c1..8887a7eb3c7c 100644 --- a/python/tvm/relay/debug.py +++ b/python/tvm/relay/debug.py @@ -17,12 +17,8 @@ # pylint: disable=wildcard-import, redefined-builtin, invalid-name """The Relay IR namespace containing the IR definition and compiler.""" from __future__ import absolute_import -from .base import NodeBase, register_relay_node from ..api import register_func 
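The new profiler_vm.enabled() above simply probes for an optional global. The same capability check from C++; the global name is inferred from the Python module prefix and should be treated as an assumption:

    #include <tvm/runtime/registry.h>

    bool VMProfilerEnabled() {
      return tvm::runtime::Registry::Get(
                 "relay._vm._VMCompilerProfiler") != nullptr;
    }
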
-@register_relay_node -class InterpreterState(NodeBase): - pass # pylint: disable=unused-argument def _debugger_init(expr, stack): diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs index d9399492264b..848d5c00ab3f 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/common/src/packed_func.rs @@ -71,7 +71,7 @@ macro_rules! TVMPODValue { Context(TVMContext), Handle(*mut c_void), ArrayHandle(TVMArrayHandle), - NodeHandle(*mut c_void), + ObjectHandle(*mut c_void), ModuleHandle(TVMModuleHandle), FuncHandle(TVMFunctionHandle), NDArrayContainer(*mut c_void), @@ -92,7 +92,7 @@ macro_rules! TVMPODValue { TVMTypeCode_kTVMContext => Context($value.v_ctx), TVMTypeCode_kHandle => Handle($value.v_handle), TVMTypeCode_kArrayHandle => ArrayHandle($value.v_handle as TVMArrayHandle), - TVMTypeCode_kNodeHandle => NodeHandle($value.v_handle), + TVMTypeCode_kObjectHandle => ObjectHandle($value.v_handle), TVMTypeCode_kModuleHandle => ModuleHandle($value.v_handle), TVMTypeCode_kFuncHandle => FuncHandle($value.v_handle), TVMTypeCode_kNDArrayContainer => NDArrayContainer($value.v_handle), @@ -124,7 +124,7 @@ macro_rules! TVMPODValue { TVMTypeCode_kArrayHandle, ) }, - NodeHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kNodeHandle), + ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kObjectHandle), ModuleHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kModuleHandle), FuncHandle(val) => ( diff --git a/rust/frontend/src/function.rs b/rust/frontend/src/function.rs index 948711276304..01d0c58cfc5d 100644 --- a/rust/frontend/src/function.rs +++ b/rust/frontend/src/function.rs @@ -264,7 +264,7 @@ unsafe extern "C" fn tvm_callback( for i in 0..len { value = args_list[i]; tcode = type_codes_list[i]; - if tcode == ffi::TVMTypeCode_kNodeHandle as c_int + if tcode == ffi::TVMTypeCode_kObjectHandle as c_int || tcode == ffi::TVMTypeCode_kFuncHandle as c_int || tcode == ffi::TVMTypeCode_kModuleHandle as c_int { diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index f31f02b1eaf4..c57e2afaa8eb 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -117,8 +117,7 @@ TVM_REGISTER_API("arith._CreateAnalyzer") }); } else if (name == "bind") { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - auto& sptr = args[1].node_sptr(); - if (sptr->is_type()) { + if (args[1].IsObjectRef()) { self->Bind(args[0], args[1].operator Range()); } else { self->Bind(args[0], args[1].operator Expr()); diff --git a/src/api/api_base.cc b/src/api/api_base.cc index 28ebb4d65005..c25c35f636e6 100644 --- a/src/api/api_base.cc +++ b/src/api/api_base.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,7 +30,7 @@ namespace tvm { TVM_REGISTER_API("_format_str") .set_body([](TVMArgs args, TVMRetValue *ret) { - CHECK(args[0].type_code() == kNodeHandle); + CHECK(args[0].type_code() == kObjectHandle); std::ostringstream os; os << args[0].operator NodeRef(); *ret = os.str(); @@ -38,9 +38,8 @@ TVM_REGISTER_API("_format_str") TVM_REGISTER_API("_raw_ptr") .set_body([](TVMArgs args, TVMRetValue *ret) { - CHECK(args[0].type_code() == kNodeHandle); - *ret = reinterpret_cast( - args[0].node_sptr().get()); + CHECK(args[0].type_code() == kObjectHandle); + *ret = reinterpret_cast(args[0].value().v_handle); }); TVM_REGISTER_API("_save_json") diff --git a/src/api/api_codegen.cc b/src/api/api_codegen.cc index 73e26719cf15..f2ca67e6e2f9 100644 --- a/src/api/api_codegen.cc +++ b/src/api/api_codegen.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -33,7 +33,7 @@ namespace codegen { TVM_REGISTER_API("codegen._Build") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsNodeType()) { + if (args[0].IsObjectRef()) { *ret = Build({args[0]}, args[1]); } else { *ret = Build(args[0], args[1]); diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index b8ee1441fe12..9312c5532302 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -18,7 +18,6 @@ */ /*! 
- * Copyright (c) 2016 by Contributors * Implementation of API functions related to IR build * \file api_ir.cc */ diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index aa0ce47b4a37..f3d6c5f6ab62 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -57,25 +57,26 @@ TVM_REGISTER_API("_str") TVM_REGISTER_API("_Array") .set_body([](TVMArgs args, TVMRetValue* ret) { - std::vector > data; + std::vector data; for (int i = 0; i < args.size(); ++i) { if (args[i].type_code() != kNull) { - data.push_back(args[i].node_sptr()); + data.push_back(args[i].operator ObjectRef()); } else { - data.push_back(NodePtr(nullptr)); + data.push_back(ObjectRef(nullptr)); } } auto node = make_node(); node->data = std::move(data); - *ret = node; + *ret = runtime::ObjectRef(node); }); TVM_REGISTER_API("_ArrayGetItem") .set_body([](TVMArgs args, TVMRetValue* ret) { int64_t i = args[1]; - auto& sptr = args[0].node_sptr(); - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); CHECK_LT(static_cast(i), n->data.size()) << "out of bound of array"; *ret = n->data[static_cast(i)]; @@ -83,10 +84,11 @@ TVM_REGISTER_API("_ArrayGetItem") TVM_REGISTER_API("_ArraySize") .set_body([](TVMArgs args, TVMRetValue* ret) { - auto& sptr = args[0].node_sptr(); - CHECK(sptr->is_type()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + CHECK(ptr->IsInstance()); *ret = static_cast( - static_cast(sptr.get())->data.size()); + static_cast(ptr)->data.size()); }); TVM_REGISTER_API("_Map") @@ -98,10 +100,10 @@ TVM_REGISTER_API("_Map") for (int i = 0; i < args.num_args; i += 2) { CHECK(args[i].type_code() == kStr) << "key of str map need to be str"; - CHECK(args[i + 1].type_code() == kNodeHandle) + CHECK(args[i + 1].type_code() == kObjectHandle) << "value of the map to be NodeRef"; data.emplace(std::make_pair(args[i].operator std::string(), - args[i + 1].node_sptr())); + args[i + 1].operator ObjectRef())); } auto node = make_node(); node->data = std::move(data); @@ -110,12 +112,12 @@ TVM_REGISTER_API("_Map") // Container node. 
      MapNode::ContainerType data;
       for (int i = 0; i < args.num_args; i += 2) {
-        CHECK(args[i].type_code() == kNodeHandle)
+        CHECK(args[i].type_code() == kObjectHandle)
             << "key of str map need to be str";
-        CHECK(args[i + 1].type_code() == kNodeHandle)
+        CHECK(args[i + 1].type_code() == kObjectHandle)
             << "value of map to be NodeRef";
-        data.emplace(std::make_pair(args[i].node_sptr(),
-                                    args[i + 1].node_sptr()));
+        data.emplace(std::make_pair(args[i].operator ObjectRef(),
+                                    args[i + 1].operator ObjectRef()));
       }
       auto node = make_node<MapNode>();
       node->data = std::move(data);
@@ -125,31 +127,33 @@ TVM_REGISTER_API("_Map")
 
 TVM_REGISTER_API("_MapSize")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    auto& sptr = args[0].node_sptr();
-    if (sptr->is_type<MapNode>()) {
-      auto* n = static_cast<const MapNode*>(sptr.get());
+    CHECK_EQ(args[0].type_code(), kObjectHandle);
+    Object* ptr = static_cast<Object*>(args[0].value().v_handle);
+    if (ptr->IsInstance<MapNode>()) {
+      auto* n = static_cast<const MapNode*>(ptr);
       *ret = static_cast<int64_t>(n->data.size());
     } else {
-      CHECK(sptr->is_type<StrMapNode>());
-      auto* n = static_cast<const StrMapNode*>(sptr.get());
+      CHECK(ptr->IsInstance<StrMapNode>());
+      auto* n = static_cast<const StrMapNode*>(ptr);
       *ret = static_cast<int64_t>(n->data.size());
     }
  });

 TVM_REGISTER_API("_MapGetItem")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    CHECK(args[0].type_code() == kNodeHandle);
-    auto& sptr = args[0].node_sptr();
-    if (sptr->is_type<MapNode>()) {
-      CHECK(args[1].type_code() == kNodeHandle);
-      auto* n = static_cast<const MapNode*>(sptr.get());
-      auto it = n->data.find(args[1].node_sptr());
+    CHECK_EQ(args[0].type_code(), kObjectHandle);
+    Object* ptr = static_cast<Object*>(args[0].value().v_handle);
+
+    if (ptr->IsInstance<MapNode>()) {
+      CHECK(args[1].type_code() == kObjectHandle);
+      auto* n = static_cast<const MapNode*>(ptr);
+      auto it = n->data.find(args[1].operator ObjectRef());
       CHECK(it != n->data.end())
          << "cannot find the corresponding key in the Map";
      *ret = (*it).second;
     } else {
-      CHECK(sptr->is_type<StrMapNode>());
-      auto* n = static_cast<const StrMapNode*>(sptr.get());
+      CHECK(ptr->IsInstance<StrMapNode>());
+      auto* n = static_cast<const StrMapNode*>(ptr);
       auto it = n->data.find(args[1].operator std::string());
       CHECK(it != n->data.end())
           << "cannot find the corresponding key in the Map";
@@ -159,16 +163,17 @@ TVM_REGISTER_API("_MapGetItem")
 
 TVM_REGISTER_API("_MapCount")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    CHECK(args[0].type_code() == kNodeHandle);
-    auto& sptr = args[0].node_sptr();
-    if (sptr->is_type<MapNode>()) {
-      auto* n = static_cast<const MapNode*>(sptr.get());
-      CHECK(args[1].type_code() == kNodeHandle);
+    CHECK_EQ(args[0].type_code(), kObjectHandle);
+    Object* ptr = static_cast<Object*>(args[0].value().v_handle);
+
+    if (ptr->IsInstance<MapNode>()) {
+      auto* n = static_cast<const MapNode*>(ptr);
+      CHECK_EQ(args[0].type_code(), kObjectHandle);
       *ret = static_cast<int64_t>(
-          n->data.count(args[1].node_sptr()));
+          n->data.count(args[1].operator ObjectRef()));
     } else {
-      CHECK(sptr->is_type<StrMapNode>());
-      auto* n = static_cast<const StrMapNode*>(sptr.get());
+      CHECK(ptr->IsInstance<StrMapNode>());
+      auto* n = static_cast<const StrMapNode*>(ptr);
       *ret = static_cast<int64_t>(
           n->data.count(args[1].operator std::string()));
     }
@@ -176,9 +181,11 @@ TVM_REGISTER_API("_MapCount")
 
 TVM_REGISTER_API("_MapItems")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    auto& sptr = args[0].node_sptr();
-    if (sptr->is_type<MapNode>()) {
-      auto* n = static_cast<const MapNode*>(sptr.get());
+    CHECK_EQ(args[0].type_code(), kObjectHandle);
+    Object* ptr = static_cast<Object*>(args[0].value().v_handle);
+
+    if (ptr->IsInstance<MapNode>()) {
+      auto* n = static_cast<const MapNode*>(ptr);
       auto rkvs = make_node<ArrayNode>();
       for (const auto& kv : n->data) {
         rkvs->data.push_back(kv.first);
@@ -186,10 +193,10 @@ TVM_REGISTER_API("_MapItems")
       }
       *ret = rkvs;
     } else {
-      auto* n = static_cast<const StrMapNode*>(sptr.get());
+      auto* n = static_cast<const StrMapNode*>(ptr);
      auto rkvs = make_node<ArrayNode>();
       for (const auto& kv : n->data) {
-        rkvs->data.push_back(ir::StringImm::make(kv.first).node_);
+        rkvs->data.push_back(ir::StringImm::make(kv.first));
         rkvs->data.push_back(kv.second);
       }
       *ret = rkvs;
@@ -426,7 +433,7 @@ TVM_REGISTER_API("_ScheduleCacheRead")
 
 TVM_REGISTER_API("_ScheduleCacheWrite")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    if (args[1].IsNodeType<Tensor>()) {
+    if (args[1].IsObjectRef<Tensor>()) {
       *ret = args[0].operator Schedule()
           .cache_write(args[1].operator Tensor(), args[2]);
     } else {
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index d2352496c2b4..dd0415afd9eb 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -35,7 +35,7 @@ namespace ir {
 
 TVM_REGISTER_API("ir_pass.Simplify")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
-    if (args[0].IsNodeType<Stmt>()) {
+    if (args[0].IsObjectRef<Stmt>()) {
       if (args.size() > 1) {
         *ret = Simplify(args[0].operator Stmt(), args[1]);
       } else {
@@ -52,7 +52,7 @@ TVM_REGISTER_API("ir_pass.Simplify")
 
 TVM_REGISTER_API("ir_pass.CanonicalSimplify")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
-    if (args[0].IsNodeType<Stmt>()) {
+    if (args[0].IsObjectRef<Stmt>()) {
       if (args.size() > 1) {
         *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]);
       } else {
@@ -69,7 +69,7 @@ TVM_REGISTER_API("ir_pass.CanonicalSimplify")
 
 TVM_REGISTER_API("ir_pass.Substitute")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
-    if (args[0].IsNodeType<Stmt>()) {
+    if (args[0].IsObjectRef<Stmt>()) {
       *ret = Substitute(args[0].operator Stmt(),
                         args[1].operator Map<Var, Expr>());
     } else {
       *ret = Substitute(args[0].operator Expr(),
                         args[1].operator Map<Var, Expr>());
@@ -78,7 +78,7 @@ TVM_REGISTER_API("ir_pass.Substitute")
 
 TVM_REGISTER_API("ir_pass.Equal")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
-    if (args[0].IsNodeType<Stmt>()) {
+    if (args[0].IsObjectRef<Stmt>()) {
       *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt());
     } else {
       *ret = Equal(args[0].operator Expr(), args[1].operator Expr());
diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc
index 177360bf2ebb..cf0e0f3c6b7a 100644
--- a/src/api/api_schedule.cc
+++ b/src/api/api_schedule.cc
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- * Copyright (c) 2017 by Contributors
  * Implementation of API functions related to schedule pass.
  * \file api_schedule.cc
  */
diff --git a/src/api/dsl_api.cc b/src/api/dsl_api.cc
index 89e999f73edb..64805c9e8aa0 100644
--- a/src/api/dsl_api.cc
+++ b/src/api/dsl_api.cc
@@ -18,36 +18,18 @@
  */
 
 /*!
- * Copyright (c) 2016 by Contributors
  * Implementation of DSL API
  * \file dsl_api.cc
  */
-#include
 #include
-#include
 #include
 #include
+#include
 #include
 #include
-#include
-#include "../runtime/dsl_api.h"
 
 namespace tvm {
 namespace runtime {
-/*! \brief entry to to easily hold returning information */
-struct TVMAPIThreadLocalEntry {
-  /*! \brief result holder for returning strings */
-  std::vector<std::string> ret_vec_str;
-  /*! \brief result holder for returning string pointers */
-  std::vector<const char *> ret_vec_charp;
-  /*! \brief result holder for retruning string */
-  std::string ret_str;
-};
-
-/*! \brief Thread local store that can be used to hold return values. */
-typedef dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry> TVMAPIThreadLocalStore;
-
-using TVMAPINode = NodePtr<Node>;
 
 struct APIAttrGetter : public AttrVisitor {
   std::string skey;
@@ -138,93 +120,71 @@ struct APIAttrDir : public AttrVisitor {
   }
 };
 
-class DSLAPIImpl : public DSLAPI {
- public:
-  void NodeFree(NodeHandle handle) const final {
-    delete static_cast<TVMAPINode*>(handle);
-  }
-  void NodeTypeKey2Index(const char* type_key,
-                         int* out_index) const final {
-    *out_index = static_cast<int>(Node::TypeKey2Index(type_key));
-  }
-  void NodeGetTypeIndex(NodeHandle handle,
-                        int* out_index) const final {
-    *out_index = static_cast<int>(
-        (*static_cast<TVMAPINode*>(handle))->type_index());
-  }
-  void NodeGetAttr(NodeHandle handle,
-                   const char* key,
-                   TVMValue* ret_val,
-                   int* ret_type_code,
-                   int* ret_success) const final {
-    TVMRetValue rv;
+struct NodeAPI {
+  static void GetAttr(TVMArgs args, TVMRetValue* ret) {
+    NodeRef ref = args[0];
+    Node* tnode = const_cast<Node*>(ref.get());
     APIAttrGetter getter;
-    TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
-    getter.skey = key;
-    getter.ret = &rv;
+    getter.skey = args[1].operator std::string();
+    getter.ret = ret;
+
+    bool success;
     if (getter.skey == "type_key") {
-      ret_val->v_str = (*tnode)->type_key();
-      *ret_type_code = kStr;
-      *ret_success = 1;
-      return;
-    } else if (!(*tnode)->is_type<DictAttrsNode>()) {
-      (*tnode)->VisitAttrs(&getter);
-      *ret_success = getter.found_ref_object || rv.type_code() != kNull;
+      *ret = tnode->GetTypeKey();
+      success = true;
+    } else if (!tnode->IsInstance<DictAttrsNode>()) {
+      tnode->VisitAttrs(&getter);
+      success = getter.found_ref_object || ret->type_code() != kNull;
     } else {
       // specially handle dict attr
-      DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode->get());
-      auto it = dnode->dict.find(key);
+      DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode);
+      auto it = dnode->dict.find(getter.skey);
       if (it != dnode->dict.end()) {
-        *ret_success = 1;
-        rv = (*it).second;
+        success = true;
+        *ret = (*it).second;
       } else {
-        *ret_success = 0;
+        success = false;
       }
     }
-    if (*ret_success) {
-      if (rv.type_code() == kStr ||
-          rv.type_code() == kTVMType) {
-        TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get();
-        e->ret_str = rv.operator std::string();
-        *ret_type_code = kStr;
-        ret_val->v_str = e->ret_str.c_str();
-      } else {
-        rv.MoveToCHost(ret_val, ret_type_code);
-      }
+    if (!success) {
+      LOG(FATAL) << "AttributeError: " << tnode->GetTypeKey()
+                 << " object has no attributed " << getter.skey;
     }
   }
-  void NodeListAttrNames(NodeHandle handle,
-                         int *out_size,
-                         const char*** out_array) const final {
-    TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
-    ret->ret_vec_str.clear();
-    TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
+
+  static void ListAttrNames(TVMArgs args, TVMRetValue* ret) {
+    NodeRef ref = args[0];
+    Node* tnode = const_cast<Node*>(ref.get());
+    auto names = std::make_shared<std::vector<std::string> >();
     APIAttrDir dir;
-    dir.names = &(ret->ret_vec_str);
+    dir.names = names.get();
 
-    if (!(*tnode)->is_type<DictAttrsNode>()) {
-      (*tnode)->VisitAttrs(&dir);
+    if (!tnode->IsInstance<DictAttrsNode>()) {
+      tnode->VisitAttrs(&dir);
     } else {
       // specially handle dict attr
-      DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode->get());
+      DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode);
       for (const auto& kv : dnode->dict) {
-        ret->ret_vec_str.push_back(kv.first);
+        names->push_back(kv.first);
       }
     }
-    ret->ret_vec_charp.clear();
-    for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
-      ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
-    }
-    *out_array = dmlc::BeginPtr(ret->ret_vec_charp);
-    *out_size = static_cast<int>(ret->ret_vec_str.size());
+
+    *ret = PackedFunc([names](TVMArgs args, TVMRetValue *rv) {
+      int64_t i = args[0];
+      if (i == -1) {
+        *rv = static_cast<int64_t>(names->size());
+      } else {
+        *rv = (*names)[i];
+      }
+    });
   }
 };
 
-TVM_REGISTER_GLOBAL("dsl_api.singleton")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    static DSLAPIImpl impl;
-    void* ptr = &impl;
-    *rv = ptr;
-  });
+TVM_REGISTER_GLOBAL("_NodeGetAttr")
+.set_body(NodeAPI::GetAttr);
+
+TVM_REGISTER_GLOBAL("_NodeListAttrNames")
+.set_body(NodeAPI::ListAttrNames);
+
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc
index acd964935c25..98e25742592d 100644
--- a/src/arithmetic/analyzer.cc
+++ b/src/arithmetic/analyzer.cc
@@ -36,9 +36,7 @@ Analyzer::Analyzer()
       int_set(this) {
 }
 
-void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
-  Var var(v.node_);
-
+void Analyzer::Bind(const VarExpr& var, const Expr& expr) {
   Expr new_expr = expr;
   new_expr = this->canonical_simplify(new_expr);
   new_expr = this->rewrite_simplify(new_expr);
@@ -49,9 +47,8 @@ void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
   this->canonical_simplify.Update(var, new_expr);
 }
 
-void Analyzer::Bind(const VarExpr& v, const Range& range) {
+void Analyzer::Bind(const VarExpr& var, const Range& range) {
   CHECK(range.defined());
-  Var var(v.node_);
   if (is_one(range->extent)) {
     this->Bind(var, range->min);
   } else {
diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc
index d80e4969d5c2..02e8079c9c7b 100644
--- a/src/arithmetic/canonical_simplify.cc
+++ b/src/arithmetic/canonical_simplify.cc
@@ -629,7 +629,7 @@ Mutate_(const Mul* op, const Expr& self) {
   }
   if (const auto* bconst = b.as<IntImm>()) {
     if (a.as<SumExprNode>()) {
-      SumExpr ret(std::move(a.node_));
+      SumExpr ret = Downcast<SumExpr>(std::move(a));
       ret.CopyOnWrite()->MulToSelf(bconst->value);
       return std::move(ret);
     } else {
@@ -931,7 +931,7 @@ Mutate_(const Mod* op, const Expr& self) {
         int64_t new_base = psum->base % cval;
         if (cbound->min_value >= 0 &&
             cbound->min_value - psum->base + new_base >= 0) {
-          SumExpr sum_expr(std::move(a.node_));
+          SumExpr sum_expr = Downcast<SumExpr>(a);
           sum_expr.CopyOnWrite()->base = new_base;
           return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kTruncDiv);
         }
@@ -992,7 +992,7 @@ Mutate_(const FloorMod* op, const Expr& self) {
         // Simplify the offset constant if necessary.
        // floormod(x - 5, 3) => floormod(x + 1, 3)
         int64_t new_base = floormod(psum->base, cval);
-        SumExpr sum_expr(std::move(a.node_));
+        SumExpr sum_expr = Downcast<SumExpr>(std::move(a));
         sum_expr.CopyOnWrite()->base = new_base;
         return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kFloorDiv);
       } else {
diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc
index d5c012d302dc..168486ee0018 100644
--- a/src/arithmetic/const_int_bound.cc
+++ b/src/arithmetic/const_int_bound.cc
@@ -39,7 +39,7 @@ ConstIntBound::ConstIntBound(
   auto node = make_node<ConstIntBoundNode>();
   node->min_value = min_value;
   node->max_value = max_value;
-  node_ = std::move(node);
+  data_ = std::move(node);
 }
 
 inline void PrintBoundValue(std::ostream& os, int64_t val) {
diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc
index 3c5f12a7379e..7da020efc42a 100644
--- a/src/arithmetic/detect_linear_equation.cc
+++ b/src/arithmetic/detect_linear_equation.cc
@@ -176,7 +176,7 @@ bool DetectClipBound(
   if (const Variable* v = n.as<Variable>()) {
     if (bmap->count(v)) {
       if (flag == 0) {
-        var = Var(n.node_);
+        var = Downcast<Var>(n);
         flag = 1;
       } else if (flag == 1) {
         if (!var.same_as(n)) {
diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc
index 0e24714daf1f..313b34ded034 100644
--- a/src/arithmetic/int_set.cc
+++ b/src/arithmetic/int_set.cc
@@ -40,7 +40,7 @@ IntervalSet::IntervalSet(Expr min_value, Expr max_value) {
   auto node = make_node<IntervalSetNode>();
   node->min_value = std::move(min_value);
   node->max_value = std::move(max_value);
-  node_ = std::move(node);
+  data_ = std::move(node);
 }
 
 IntervalSet MakeIntervalSet(Expr min_value, Expr max_value) {
@@ -506,7 +506,7 @@ class IntervalSetEvaluator :
   }
 
   IntervalSet VisitExprDefault_(const Node* op) final {
-    DLOG(WARNING) << "cannot evaluate set type " << op->type_key();
+    DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey();
     return IntervalSet::Everything();
   }
 
diff --git a/src/arithmetic/ir_mutator_with_analyzer.cc b/src/arithmetic/ir_mutator_with_analyzer.cc
index 04e166ae52c0..cda9d585ace1 100644
--- a/src/arithmetic/ir_mutator_with_analyzer.cc
+++ b/src/arithmetic/ir_mutator_with_analyzer.cc
@@ -87,7 +87,7 @@ Stmt IRMutatorWithAnalyzer::
 Mutate_(const AttrStmt* op, const Stmt& s) {
   if (op->attr_key == attr::thread_extent ||
       op->attr_key == attr::virtual_thread) {
-    IterVar iv(op->node.node_);
+    IterVar iv = Downcast<IterVar>(op->node);
     CHECK_NE(iv->thread_tag.length(), 0U);
     analyzer_->Bind(iv->var,
                     Range::make_by_min_extent(0, op->value));
diff --git a/src/arithmetic/ir_visitor_with_analyzer.h b/src/arithmetic/ir_visitor_with_analyzer.h
index 71eea50e4c72..918f2e89501f 100644
--- a/src/arithmetic/ir_visitor_with_analyzer.h
+++ b/src/arithmetic/ir_visitor_with_analyzer.h
@@ -47,7 +47,7 @@ class IRVisitorWithAnalyzer final : public IRVisitor {
   void Visit_(const AttrStmt* op) {
     if (op->attr_key == attr::thread_extent ||
         op->attr_key == attr::virtual_thread) {
-      IterVar iv(op->node.node_);
+      IterVar iv = Downcast<IterVar>(op->node);
       CHECK_NE(iv->thread_tag.length(), 0U);
       analyzer_.Bind(iv->var,
                      Range::make_by_min_extent(0, op->value));
diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc
index 08454dd0ef5a..9e363e7cf99a 100644
--- a/src/arithmetic/modular_set.cc
+++ b/src/arithmetic/modular_set.cc
@@ -41,7 +41,7 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) {
   node->coeff = coeff;
   node->base = base;
   // finish construction.
-  node_ = std::move(node);
+  data_ = std::move(node);
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc
index 3f1c32243a23..66340e9c9021 100644
--- a/src/codegen/build_module.cc
+++ b/src/codegen/build_module.cc
@@ -34,6 +34,7 @@
 namespace tvm {
 
 TVM_REGISTER_NODE_TYPE(TargetNode);
+TVM_REGISTER_NODE_TYPE(GenericFuncNode);
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<TargetNode>([](const TargetNode *op, IRPrinter *p) {
@@ -51,9 +52,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   */
 Target CreateTarget(const std::string& target_name,
                     const std::vector<std::string>& options) {
-  auto target = Target(make_node<TargetNode>());
-  auto t = static_cast<TargetNode*>(target.node_.get());
-
+  auto t = make_node<TargetNode>();
   t->target_name = target_name;
 
   std::string libs_flag = "-libs=";
@@ -137,7 +136,7 @@ Target CreateTarget(const std::string& target_name,
     return target::stackvm();
   }
 
-  return target;
+  return Target(t);
 }
 
 TVM_REGISTER_API("_TargetCreate")
@@ -674,7 +673,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 struct GenericFunc::Manager {
-  std::unordered_map<std::string, NodePtr<Node> > fmap;
+  std::unordered_map<std::string, GenericFunc> fmap;
   // mutex
   std::mutex mutex;
 
@@ -694,10 +693,11 @@ GenericFunc GenericFunc::Get(const std::string& name) {
   if (it == m->fmap.end()) {
     auto f = make_node<GenericFuncNode>();
     f->name_ = name;
-    m->fmap[name] = f;
-    return GenericFunc(f);
+    auto gf = GenericFunc(f);
+    m->fmap[name] = gf;
+    return gf;
   } else {
-    return GenericFunc(it->second);
+    return it->second;
   }
 }
 
@@ -707,12 +707,12 @@ void GenericFunc::RegisterGenericFunc(GenericFunc func,
                                       const std::string& name) {
   auto it = m->fmap.find(name);
   CHECK(it == m->fmap.end()) << "GenericFunc already registered " << name;
   func->name_ = name;
-  m->fmap[name] = func.node_;
+  m->fmap[name] = func;
 }
 
 GenericFunc& GenericFunc::set_default(const PackedFunc value,
-                                           bool allow_override) {
-  auto node = static_cast<GenericFuncNode*>(node_.get());
+                                      bool allow_override) {
+  auto node = static_cast<GenericFuncNode*>(operator->());
   if (!allow_override) {
     CHECK(node->generic_func_ == nullptr)
       << "Generic function already registered for " << node->name_;
@@ -736,7 +736,7 @@ GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags,
 }
 
 void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const {
-  auto node = static_cast<GenericFuncNode*>(node_.get());
+  auto node = static_cast<const GenericFuncNode*>(get());
 
   auto target = Target::Current(true);
   PackedFunc func;
diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc
index ecf62ab0cfac..ab203f2aa28a 100644
--- a/src/codegen/codegen_c.cc
+++ b/src/codegen/codegen_c.cc
@@ -806,7 +806,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) {
 
 void CodeGenC::VisitStmt_(const AttrStmt* op) {
   if (op->attr_key == ir::attr::thread_extent) {
-    IterVar iv(op->node.node_);
+    IterVar iv = Downcast<IterVar>(op->node);
     if (iv->thread_tag.length() != 0) {
       if (!var_idmap_.count(iv->var.get())) {
         BindThreadIndex(iv);
diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
index d009290bb2fe..de54e242ff40 100644
--- a/src/codegen/llvm/codegen_llvm.cc
+++ b/src/codegen/llvm/codegen_llvm.cc
@@ -1173,7 +1173,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
 
 void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
   if (op->attr_key == attr::thread_extent) {
-    IterVar iv(op->node.node_);
+    IterVar iv = Downcast<IterVar>(op->node);
     if (iv->thread_tag.length() != 0) {
       if (!var_map_.count(iv->var.get())) {
         var_map_[iv->var.get()] = GetThreadIndex(iv);
diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc
index 7caf3a258b6f..6a3b0571c9ab 100644
--- a/src/codegen/spirv/codegen_spirv.cc
+++ b/src/codegen/spirv/codegen_spirv.cc
@@ -606,7 +606,7 @@ void CodeGenSPIRV::VisitStmt_(const Allocate* op) {
 
 void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) {
   if (op->attr_key == attr::thread_extent) {
-    IterVar iv(op->node.node_);
+    IterVar iv = Downcast<IterVar>(op->node);
     if (iv->thread_tag.length() != 0) {
       if (!var_map_.count(iv->var.get())) {
         var_map_[iv->var.get()] = GetThreadIndex(iv, op->value);
diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc
index 54616adc214e..778b6b1a7811 100644
--- a/src/contrib/hybrid/codegen_hybrid.cc
+++ b/src/contrib/hybrid/codegen_hybrid.cc
@@ -300,7 +300,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmt* op) {
     PrintStmt(op->body);
     indent_ -= tab_;
   } else if (op->attr_key == ir::attr::realize_scope) {
-    auto v = FunctionRef(op->node.node_);
+    auto v = Downcast<FunctionRef>(op->node);
     alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
     PrintStmt(op->body);
   } else {
@@ -408,7 +408,7 @@ void CodeGenHybrid::PrintIndent() {
 
 std::string CodeGenHybrid::GetVarID(const Variable *v) {
   if (binds_.count(v)) return binds_[v];
-  auto key = std::make_pair(v->GetNodePtr().get(), 0);
+  auto key = std::make_pair(static_cast<const Node*>(v), 0);
   if (id_map_.count(key)) {
     return id_map_[key];
   }
diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h
index 498838fc908f..866756996f8d 100644
--- a/src/contrib/hybrid/codegen_hybrid.h
+++ b/src/contrib/hybrid/codegen_hybrid.h
@@ -18,7 +18,6 @@
  */
 
 /*!
- * Copyright (c) 2019 by Contributors
  * \file codegen_hybrid.h
  * \brief Common utilities to generated C style code.
  */
diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h
index 995dfb392e87..b9391e4895b9 100644
--- a/src/lang/attr_functor.h
+++ b/src/lang/attr_functor.h
@@ -44,17 +44,17 @@ class AttrFunctor;
 
 #define ATTR_FUNCTOR_DISPATCH(OP)                                       \
   vtable.template set_dispatch<OP>(                                     \
-      [](const NodeRef& n, TSelf* self, Args... args) {                 \
-        return self->VisitAttr_(static_cast<const OP*>(n.node_.get()),  \
+      [](const ObjectRef& n, TSelf* self, Args... args) {               \
+        return self->VisitAttr_(static_cast<const OP*>(n.get()),        \
                                 std::forward<Args>(args)...);           \
       });                                                               \
 
 // A functor for common attribute information.
 template <typename R, typename... Args>
-class AttrFunctor<R(const NodeRef& n, Args...)> {
+class AttrFunctor<R(const ObjectRef& n, Args...)> {
  private:
-  using TSelf = AttrFunctor<R(const NodeRef& n, Args...)>;
-  using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;
+  using TSelf = AttrFunctor<R(const ObjectRef& n, Args...)>;
+  using FType = tvm::IRFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
 
  public:
  /*! \brief the result type of this functor */
@@ -65,7 +65,7 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
    * \param args Additional arguments.
    * \return The result of the call
    */
-  virtual R VisitAttr(const NodeRef& n, Args... args) {
+  virtual R VisitAttr(const ObjectRef& n, Args... args) {
     static FType vtable = InitVTable();
     if (vtable.can_dispatch(n)) {
       return vtable(n, this, std::forward<Args>(args)...);
@@ -73,7 +73,7 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
       return VisitAttrDefault_(n.get(), std::forward<Args>(args)...);
     }
   }
-  virtual R VisitAttrDefault_(const Node* node, Args... args) = 0;
+  virtual R VisitAttrDefault_(const Object* node, Args... args) = 0;
   virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
@@ -143,60 +143,60 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
 };
 
 class AttrsEqualHandler :
-    protected AttrFunctor<bool(const NodeRef&, const NodeRef&)> {
+    protected AttrFunctor<bool(const ObjectRef&, const ObjectRef&)> {
  public:
  /*!
   * \brief Check if lhs equals rhs
   * \param lhs The left operand.
   * \param rhs The right operand.
  */
-  bool Equal(const NodeRef& lhs, const NodeRef& rhs);
+  bool Equal(const ObjectRef& lhs, const ObjectRef& rhs);
 
 protected:
-  bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::IntImm* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::UIntImm* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::FloatImm* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::StringImm* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::Add* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::Sub* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::FloorDiv* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::FloorMod* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::GT* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::LT* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::LE* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::EQ* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::NE* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::And* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::Or* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::Not* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::Cast* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::Call* lhs, const NodeRef& other) final;
-  bool VisitAttr_(const ir::Select* lhs, const NodeRef& other) final;
+  bool VisitAttrDefault_(const Object* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::IntImm* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::UIntImm* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::FloatImm* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::StringImm* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::Add* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::Sub* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::Mul* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::Div* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::Mod* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::FloorDiv* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::FloorMod* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::Min* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::Max* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::GE* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::GT* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::LT* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::LE* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::EQ* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::NE* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::And* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::Or* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::Not* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::Cast* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::Call* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::Select* lhs, const ObjectRef& other) final;
 };
 
 class AttrsHashHandler :
-    protected AttrFunctor<size_t(const NodeRef&)> {
+    protected AttrFunctor<size_t(const ObjectRef&)> {
  public:
   /*!
    * \brief Get hash value of node
    * \param node The node to be hashed.
    */
-  size_t Hash(const NodeRef& node) {
+  size_t Hash(const ObjectRef& node) {
     if (!node.defined()) return 0;
     return this->VisitAttr(node);
   }
 
 protected:
-  size_t VisitAttrDefault_(const Node* lhs) final;
+  size_t VisitAttrDefault_(const Object* lhs) final;
   size_t VisitAttr_(const ir::IntImm* lhs) final;
   size_t VisitAttr_(const ir::UIntImm* lhs) final;
   size_t VisitAttr_(const ir::FloatImm* lhs) final;
diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc
index c5b14ac577ec..a299e17996e0 100644
--- a/src/lang/attrs.cc
+++ b/src/lang/attrs.cc
@@ -40,7 +40,7 @@ void DictAttrsNode::InitByPackedArgs(
   for (int i = 0; i < args.size(); i += 2) {
     std::string key = args[i];
     runtime::TVMArgValue val = args[i + 1];
-    if (val.type_code() == kNodeHandle) {
+    if (val.type_code() == kObjectHandle) {
       dict.Set(key, val.operator NodeRef());
     } else if (val.type_code() == kStr) {
       dict.Set(key, Expr(val.operator std::string()));
@@ -72,14 +72,14 @@ TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode);
 
 using namespace ir;
 
 // Equal handler.
-bool AttrsEqualHandler::Equal(const NodeRef& lhs, const NodeRef& rhs) {
+bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
   if (lhs.same_as(rhs)) return true;
   if (!lhs.defined() || !rhs.defined()) return false;
   return this->VisitAttr(lhs, rhs);
 }
 
-bool AttrsEqualHandler::VisitAttrDefault_(const Node* lhs, const NodeRef& other) {
-  if (lhs->derived_from<BaseAttrsNode>()) {
+bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& other) {
+  if (lhs->IsInstance<BaseAttrsNode>()) {
     AttrsEqual equal;
     equal.handler_ = this;
     return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(
@@ -88,58 +88,58 @@ bool AttrsEqualHandler::VisitAttrDefault_(const Node* lhs, const NodeRef& other)
   return lhs == other.get();
 }
 
-bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const NodeRef& other) {
+bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const ObjectRef& other) {
   if (const auto* rhs = other.as<IntImm>()) {
     return lhs->value == rhs->value;
   }
   return false;
 }
 
-bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const NodeRef& other) {
+bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const ObjectRef& other) {
   if (const auto* rhs = other.as<UIntImm>()) {
     return lhs->value == rhs->value;
   }
   return false;
 }
 
-bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const NodeRef& other) {
+bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const ObjectRef& other) {
   if (const auto* rhs = other.as<FloatImm>()) {
     return lhs->value == rhs->value;
   }
   return false;
 }
 
-bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const NodeRef& other) {
+bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const ObjectRef& other) {
   if (const auto* rhs = other.as<StringImm>()) {
     return lhs->value == rhs->value;
   }
   return false;
 }
 
-bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const NodeRef& other) {
+bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) {
   if (const auto* rhs = other.as<ArrayNode>()) {
     if (rhs->data.size() != lhs->data.size()) return false;
    for (size_t i = 0; i < lhs->data.size(); ++i) {
-      if (!Equal(NodeRef(lhs->data[i]), NodeRef(rhs->data[i]))) return false;
+      if (!Equal(lhs->data[i], rhs->data[i])) return false;
     }
   }
   return true;
 }
 
-bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other) {
+bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) {
   if (const auto* rhs = other.as<StrMapNode>()) {
     if (rhs->data.size() != lhs->data.size()) return false;
     for (const auto& kv : lhs->data) {
       auto it = rhs->data.find(kv.first);
       if (it == rhs->data.end()) return false;
-      if (!Equal(NodeRef(kv.second), NodeRef(it->second))) return false;
+      if (!Equal(kv.second, it->second)) return false;
     }
   }
   return true;
 }
 
 #define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName)                          \
-  bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const NodeRef& other) { \
+  bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const ObjectRef& other) { \
     if (const auto* rhs = other.as<NodeName>()) {                       \
       if (!Equal(lhs->a, rhs->a)) return false;                        \
       if (!Equal(lhs->b, rhs->b)) return false;                        \
@@ -167,7 +167,7 @@ TVM_DEFINE_ATTRS_BINOP_EQUAL(NE);
 TVM_DEFINE_ATTRS_BINOP_EQUAL(And);
 TVM_DEFINE_ATTRS_BINOP_EQUAL(Or);
 
-bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const NodeRef& other) {
+bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const ObjectRef& other) {
   if (const auto* rhs = other.as<Not>()) {
     return Equal(lhs->a, rhs->a);
   } else {
@@ -175,7 +175,7 @@ bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const NodeRef& other) {
   }
 }
 
-bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const NodeRef& other) {
+bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const ObjectRef& other) {
   if (const auto* rhs = other.as<Cast>()) {
     if (lhs->type != rhs->type) return false;
     return Equal(lhs->value, rhs->value);
@@ -184,7 +184,7 @@ bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const NodeRef& other) {
   }
 }
 
-bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const NodeRef& other) {
+bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const ObjectRef& other) {
   if (const auto* rhs = other.as<Call>()) {
     return
         lhs->name == rhs->name &&
@@ -196,7 +196,7 @@ bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const NodeRef& other) {
   }
 }
 
-bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const NodeRef& other) {
+bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const ObjectRef& other) {
   if (const auto* rhs = other.as