diff --git a/.gitignore b/.gitignore index 068cb87484a0..1fcb2dc2d3fc 100644 --- a/.gitignore +++ b/.gitignore @@ -2,9 +2,10 @@ __pycache__/ *.py[cod] *$py.class - +*.S # C extensions *.so +*.ll # Distribution / packaging .Python diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index 808f485387f9..ff3db4367a30 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit 808f485387f9a03f78fa9f1159f387d0d91b7a28 +Subproject commit ff3db4367a30f542aafb83b4af45e685b80102d0 diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index b5dc51b9e7ef..4a363cba5c79 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -67,7 +67,17 @@ namespace { namespace tvm { namespace runtime { RPCEnv::RPCEnv() { +#ifndef _WIN32 + char cwd[PATH_MAX]; + if (char *rc = getcwd(cwd, sizeof(cwd))) { + base_ = std::string(cwd) + "/rpc"; + } else { + base_ = "./rpc"; + } +#else base_ = "./rpc"; +#endif + mkdir(base_.c_str(), 0777); TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([](TVMArgs args, TVMRetValue* rv) { static RPCEnv env; diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index ea4ab00c113b..2c8bdfae0168 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -33,7 +33,7 @@ #include #include "../../src/support/socket.h" -#include "../../src/runtime/rpc/rpc_session.h" +#include "../../src/runtime/rpc/rpc_endpoint.h" #include "../../src/runtime/rpc/rpc_socket_impl.h" #include "rpc_env.h" #include "rpc_server.h" @@ -66,6 +66,22 @@ static pid_t waitPidEintr(int* status) { } #endif +#ifdef __ANDROID__ +static std::string getNextString(std::stringstream* iss) { + std::string str = iss->str(); + size_t start = iss->tellg(); + size_t len = str.size(); + // Skip leading spaces. + while (start < len && isspace(str[start])) start++; + + size_t end = start; + while (end < len && !isspace(str[end])) end++; + + iss->seekg(end); + return str.substr(start, end-start); +} +#endif + /*! * \brief RPCServer RPC Server class. * \param host The hostname of the server, Default=0.0.0.0 @@ -86,7 +102,7 @@ class RPCServer { tracker_addr_(std::move(tracker_addr)), key_(std::move(key)), custom_addr_(std::move(custom_addr)) { - + } /*! @@ -98,7 +114,7 @@ class RPCServer { tracker_sock_.Close(); listen_sock_.Close(); } catch(...) 
{ - + } } @@ -144,7 +160,7 @@ class RPCServer { } int timeout = GetTimeOutFromOpts(opts); -#if defined(__linux__) || defined(__ANDROID__) +#if defined(__linux__) || defined(__ANDROID__) // step 3: serving if (timeout != 0) { const pid_t timer_pid = fork(); @@ -164,9 +180,9 @@ class RPCServer { int status = 0; const pid_t finished_first = waitPidEintr(&status); if (finished_first == timer_pid) { - kill(worker_pid, SIGKILL); + kill(worker_pid, SIGTERM); } else if (finished_first == worker_pid) { - kill(timer_pid, SIGKILL); + kill(timer_pid, SIGTERM); } else { LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue."; } @@ -197,7 +213,7 @@ class RPCServer { try { SpawnRPCChild(conn.sockfd, seconds(timeout)); } catch (const std::exception&) { - + } auto dur = high_resolution_clock::now() - start_time; @@ -217,10 +233,10 @@ class RPCServer { * \param opts Parsed options for socket * \param ping_period Timeout for select call waiting */ - void AcceptConnection(TrackerClient* tracker, + void AcceptConnection(TrackerClient* tracker, support::TCPSocket* conn_sock, - support::SockAddr* addr, - std::string* opts, + support::SockAddr* addr, + std::string* opts, int ping_period = 2) { std::set old_keyset; std::string matchkey; @@ -260,7 +276,12 @@ class RPCServer { std::stringstream ssin(remote_key); std::string arg0; +#ifndef __ANDROID__ ssin >> arg0; +#else + arg0 = getNextString(&ssin); +#endif + if (arg0 != expect_header) { code = kRPCMismatch; CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); @@ -274,7 +295,11 @@ class RPCServer { CHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); CHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen); LOG(INFO) << "Connection success " << addr->AsString(); +#ifndef __ANDROID__ ssin >> *opts; +#else + *opts = getNextString(&ssin); +#endif *conn_sock = conn; return; } @@ -301,8 +326,9 @@ class RPCServer { int GetTimeOutFromOpts(const std::string& opts) const { const std::string option = "-timeout="; - if (opts.find(option) == 0) { - const std::string cmd = opts.substr(opts.find_last_of(option) + 1); + size_t pos = opts.rfind(option); + if (pos != std::string::npos) { + const std::string cmd = opts.substr(pos + option.size()); CHECK(support::IsNumber(cmd)) << "Timeout is not valid"; return std::stoi(cmd); } @@ -330,7 +356,7 @@ void ServerLoopFromChild(SOCKET socket) { tvm::support::TCPSocket sock(socket); const auto env = RPCEnv(); RPCServerLoop(int(sock.sockfd)); - + sock.Close(); env.CleanUp(); } @@ -357,7 +383,7 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track rpc.Start(); } -TVM_REGISTER_GLOBAL("rpc._ServerCreate") +TVM_REGISTER_GLOBAL("rpc.ServerCreate") .set_body([](TVMArgs args, TVMRetValue* rv) { RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); }); diff --git a/apps/cpp_rpc/rpc_tracker_client.h b/apps/cpp_rpc/rpc_tracker_client.h index dfd576f4c195..112f7d214a27 100644 --- a/apps/cpp_rpc/rpc_tracker_client.h +++ b/apps/cpp_rpc/rpc_tracker_client.h @@ -31,7 +31,7 @@ #include #include -#include "../../src/runtime/rpc/rpc_session.h" +#include "../../src/runtime/rpc/rpc_endpoint.h" #include "../../src/support/socket.h" namespace tvm { diff --git a/docs/api/python/contrib.rst b/docs/api/python/contrib.rst index b482d30515d4..8ac4e1ff7d3a 100644 --- a/docs/api/python/contrib.rst +++ b/docs/api/python/contrib.rst @@ -48,9 +48,9 @@ tvm.contrib.dlpack .. automodule:: tvm.contrib.dlpack :members: -tvm.contrib.emscripten -~~~~~~~~~~~~~~~~~~~~~~ -.. 
automodule:: tvm.contrib.emscripten +tvm.contrib.emcc +~~~~~~~~~~~~~~~~ +.. automodule:: tvm.contrib.emcc :members: tvm.contrib.miopen diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index c08c0d6d347b..340da7f36e32 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -107,6 +107,7 @@ class ConstIntBound : public ObjectRef { */ class ConstIntBoundAnalyzer { public: + using BoundMapType = std::unordered_map; /*! * \brief analyze the expr * \param expr The expression of interest. @@ -120,8 +121,7 @@ class ConstIntBoundAnalyzer { * \param bound The lookup table to store the intermediate results * \return the result of the analysis. */ - TVM_DLL ConstIntBound operator()(const PrimExpr& expr, - std::unordered_map* bound); + TVM_DLL ConstIntBound operator()(const PrimExpr& expr, BoundMapType* bound); /*! * \brief Update constant int bound information of var. diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index b0776dee661f..d113860ddbce 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -28,7 +28,7 @@ #include #include #include - +#include #include #include #include @@ -131,21 +131,21 @@ class IRModuleNode : public Object { * \param name The variable name. * \returns true if contains, otherise false. */ - TVM_DLL bool ContainGlobalVar(const std::string& name) const; + TVM_DLL bool ContainGlobalVar(const String& name) const; /*! * \brief Check if the global_type_var_map_ contains a global type variable. * \param name The variable name. * \returns true if contains, otherise false. */ - TVM_DLL bool ContainGlobalTypeVar(const std::string& name) const; + TVM_DLL bool ContainGlobalTypeVar(const String& name) const; /*! * \brief Lookup a global function by its variable. * \param str The unique string specifying the global variable. * \returns The global variable. */ - TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const; + TVM_DLL GlobalVar GetGlobalVar(const String& str) const; /*! * \brief Collect all global vars defined in this module. @@ -158,7 +158,7 @@ class IRModuleNode : public Object { * \param str The unique string specifying the global variable. * \returns The global variable. */ - TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const; + TVM_DLL GlobalTypeVar GetGlobalTypeVar(const String& str) const; /*! * \brief Collect all global type vars defined in this module. @@ -172,7 +172,7 @@ class IRModuleNode : public Object { * \param cons name of the constructor * \returns Constructor of ADT, error if not found */ - TVM_DLL Constructor GetConstructor(const std::string& adt, const std::string& cons) const; + TVM_DLL Constructor GetConstructor(const String& adt, const String& cons) const; /*! * \brief Look up a global function by its variable. @@ -186,7 +186,7 @@ class IRModuleNode : public Object { * \param name The name of the function. * \returns The function named by the argument. */ - TVM_DLL BaseFunc Lookup(const std::string& name) const; + TVM_DLL BaseFunc Lookup(const String& name) const; /*! * \brief Look up a global type definition by its variable. @@ -200,7 +200,7 @@ class IRModuleNode : public Object { * \param var The name of the global type definition. * \return The type definition. */ - TVM_DLL TypeData LookupTypeDef(const std::string& var) const; + TVM_DLL TypeData LookupTypeDef(const String& var) const; /*! * \brief Look up a constructor by its tag. 
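// A short usage sketch of the String-based lookups documented above; the module
// argument and the "main" global name are assumptions made for this example only.
#include <tvm/ir/module.h>

void InspectMain(const tvm::IRModule& mod) {
  // std::string values and C string literals convert to tvm::String, so existing
  // call sites keep compiling after the signature change.
  if (mod->ContainGlobalVar("main")) {
    tvm::GlobalVar gvar = mod->GetGlobalVar("main");
    tvm::BaseFunc func = mod->Lookup(gvar);
    (void)gvar;
    (void)func;
  }
}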
@@ -225,18 +225,18 @@ class IRModuleNode : public Object { * relative it will be resovled against the current * working directory. */ - TVM_DLL void Import(const std::string& path); + TVM_DLL void Import(const String& path); /*! * \brief Import Relay code from the file at path, relative to the standard library. * \param path The path of the Relay code to import. */ - TVM_DLL void ImportFromStd(const std::string& path); + TVM_DLL void ImportFromStd(const String& path); /*! * \brief The set of imported files. */ - TVM_DLL std::unordered_set Imports() const; + TVM_DLL std::unordered_set Imports() const; static constexpr const char* _type_key = "IRModule"; static constexpr const bool _type_has_method_sequal_reduce = true; @@ -265,7 +265,7 @@ class IRModuleNode : public Object { /*! \brief The files previously imported, required to ensure importing is idempotent for each module. */ - std::unordered_set import_set_; + std::unordered_set import_set_; friend class IRModule; }; @@ -283,7 +283,7 @@ class IRModule : public ObjectRef { */ TVM_DLL explicit IRModule(Map functions, Map type_definitions = {}, - std::unordered_set import_set = {}); + std::unordered_set import_set = {}); /*! \brief default constructor */ IRModule() {} /*! @@ -329,7 +329,7 @@ class IRModule : public ObjectRef { * \param source_path The path to the source file. * \return A Relay module. */ - TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path); + TVM_DLL static IRModule FromText(const String& text, const String& source_path); /*! \brief Declare the container type. */ using ContainerType = IRModuleNode; @@ -346,7 +346,7 @@ class IRModule : public ObjectRef { * Use AsText if you want to store the text. * \sa AsText. */ -TVM_DLL std::string PrettyPrint(const ObjectRef& node); +TVM_DLL String PrettyPrint(const ObjectRef& node); /*! * \brief Render the node as a string in the text format. @@ -362,8 +362,8 @@ TVM_DLL std::string PrettyPrint(const ObjectRef& node); * \sa PrettyPrint. * \return The text representation. */ -TVM_DLL std::string AsText(const ObjectRef& node, +TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true, - runtime::TypedPackedFunc annotate = nullptr); + runtime::TypedPackedFunc annotate = nullptr); } // namespace tvm #endif // TVM_IR_MODULE_H_ diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 920ecfbf9b13..79bcdc6c0573 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -550,6 +550,54 @@ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); */ TVM_DLL int TVMObjectFree(TVMObjectHandle obj); +/*! + * \brief Allocate a data space on device. + * \param ctx The device context to perform operation. + * \param nbytes The number of bytes in memory. + * \param alignment The alignment of the memory. + * \param type_hint The type of elements. Only needed by certain backends such + * as nbytes & alignment are sufficient for most backends. + * \param out_data The allocated device pointer. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMDeviceAllocDataSpace(DLContext ctx, + size_t nbytes, + size_t alignment, + DLDataType type_hint, + void** out_data); + +/*! + * \brief Free a data space on device. + * \param ctx The device context to perform operation. + * \param ptr The data space. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMDeviceFreeDataSpace(TVMContext ctx, void* ptr); + +/*! 
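// A minimal sketch of how the new device-space entry points above (together with
// TVMDeviceCopyDataFromTo, declared next) can be exercised: allocate device memory,
// copy host data into it, then free it. The CPU context, float32 type hint, 64-byte
// alignment and the simplified error handling are assumptions for illustration only.
#include <tvm/runtime/c_runtime_api.h>
#include <cstddef>

int CopyHostToDevice(const float* host, size_t n) {
  DLContext ctx = {kDLCPU, 0};
  DLDataType f32 = {kDLFloat, 32, 1};
  void* dev = nullptr;
  if (TVMDeviceAllocDataSpace(ctx, n * sizeof(float), 64, f32, &dev) != 0) return -1;
  if (TVMDeviceCopyDataFromTo(host, 0, dev, 0, n * sizeof(float),
                              ctx, ctx, f32, nullptr) != 0) return -1;
  return TVMDeviceFreeDataSpace(ctx, dev);
}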
+ * \brief Copy data from one place to another. + * \param from The source array. + * \param from_offset The byte offset in the from. + * \param to The target array. + * \param to_offset The byte offset in the to. + * \param num_bytes The size of the memory in bytes + * \param ctx_from The source context + * \param ctx_to The target context + * \param type_hint The type of elements, only needed by certain backends. + * can be useful for cross device endian conversion. + * \param stream Optional stream object. + * \return 0 when success, -1 when failure happens. + */ +TVM_DLL int TVMDeviceCopyDataFromTo(const void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t num_bytes, + TVMContext ctx_from, + TVMContext ctx_to, + DLDataType type_hint, + TVMStreamHandle stream); + #ifdef __cplusplus } // TVM_EXTERN_C #endif diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index f2ddc84e9f98..12069182354b 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -157,9 +157,9 @@ class TVM_DLL DeviceAPI { * \param event_dst The destination stream to synchronize. */ virtual void SyncStreamFromTo(TVMContext ctx, - TVMStreamHandle event_src, - TVMStreamHandle event_dst); - /*! + TVMStreamHandle event_src, + TVMStreamHandle event_dst); + /*! * \brief Allocate temporal workspace for backend execution. * * \note We have the following assumption about backend temporal * @@ -176,8 +176,8 @@ * as OpenGL, as nbytes is sufficient for most backends. */ virtual void* AllocWorkspace(TVMContext ctx, - size_t nbytes, - DLDataType type_hint = {}); + size_t nbytes, + DLDataType type_hint = {}); /*! * \brief Free temporal workspace in backend execution. * diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index cf6d5fab0e19..dfc21fcf233f 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -45,6 +45,14 @@ #define TVM_RUNTIME_HEADER_ONLY 0 #endif +// Always inline macro only used in template +// expansion cases where we know inline is important. +#ifdef _MSC_VER +#define TVM_ALWAYS_INLINE __forceinline inline +#else +#define TVM_ALWAYS_INLINE inline __attribute__((always_inline)) +#endif + namespace tvm { namespace runtime { @@ -273,7 +281,7 @@ class TypedPackedFunc { * \param args The arguments * \returns The return value. */ - inline R operator()(Args ...args) const; + TVM_ALWAYS_INLINE R operator()(Args ...args) const; /*! * \brief convert to PackedFunc * \return the internal PackedFunc @@ -728,7 +736,11 @@ class TVMRetValue : public TVMPODValue_ { return *this; } TVMRetValue& operator=(PackedFunc f) { - this->SwitchToClass(kTVMPackedFuncHandle, f); + if (f == nullptr) { + this->SwitchToPOD(kTVMNullptr); + } else { + this->SwitchToClass(kTVMPackedFuncHandle, f); + } return *this; } template @@ -1076,11 +1088,15 @@ struct func_signature_helper { template struct func_signature_helper { using FType = R(Args...); + static_assert(!std::is_reference::value, + "TypedPackedFunc return reference"); }; template struct func_signature_helper { using FType = R(Args...); + static_assert(!std::is_reference::value, + "TypedPackedFunc return reference"); }; /*! @@ -1096,12 +1112,16 @@ struct function_signature { template struct function_signature { using FType = R(Args...); + static_assert(!std::is_reference::value, + "TypedPackedFunc return reference"); }; // handle case of function ptr. 
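// The null-PackedFunc handling introduced in TVMRetValue above (and the matching
// TVMArgsSetter change below) can be seen end to end in this minimal sketch: a null
// PackedFunc now travels as kTVMNullptr on both the argument and the return path.
// The function name and lambda body are illustrative assumptions.
#include <tvm/runtime/packed_func.h>

void NullPackedFuncRoundTrip() {
  using namespace tvm::runtime;
  PackedFunc echo_null([](TVMArgs args, TVMRetValue* rv) {
    // A null callback argument arrives as kTVMNullptr instead of a
    // kTVMPackedFuncHandle wrapping an empty function object.
    CHECK_EQ(args[0].type_code(), kTVMNullptr);
    // Returning a null PackedFunc likewise stores kTVMNullptr.
    *rv = PackedFunc(nullptr);
  });
  TVMRetValue ret = echo_null(PackedFunc(nullptr));
  CHECK_EQ(ret.type_code(), kTVMNullptr);
}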
template struct function_signature { using FType = R(Args...); + static_assert(!std::is_reference::value, + "TypedPackedFunc return reference"); }; } // namespace detail @@ -1114,66 +1134,71 @@ class TVMArgsSetter { template::value>::type> - void operator()(size_t i, T value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, T value) const { values_[i].v_int64 = static_cast(value); type_codes_[i] = kDLInt; } - void operator()(size_t i, uint64_t value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const { values_[i].v_int64 = static_cast(value); CHECK_LE(value, static_cast(std::numeric_limits::max())); type_codes_[i] = kDLInt; } - void operator()(size_t i, double value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, double value) const { values_[i].v_float64 = value; type_codes_[i] = kDLFloat; } - void operator()(size_t i, std::nullptr_t value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const { values_[i].v_handle = value; type_codes_[i] = kTVMNullptr; } - void operator()(size_t i, const TVMArgValue& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue& value) const { values_[i] = value.value_; type_codes_[i] = value.type_code_; } - void operator()(size_t i, void* value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, void* value) const { values_[i].v_handle = value; type_codes_[i] = kTVMOpaqueHandle; } - void operator()(size_t i, DLTensor* value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor* value) const { values_[i].v_handle = value; type_codes_[i] = kTVMDLTensorHandle; } - void operator()(size_t i, TVMContext value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, TVMContext value) const { values_[i].v_ctx = value; type_codes_[i] = kTVMContext; } - void operator()(size_t i, DLDataType value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const { values_[i].v_type = value; type_codes_[i] = kTVMDataType; } - void operator()(size_t i, DataType dtype) const { + TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const { operator()(i, dtype.operator DLDataType()); } - void operator()(size_t i, const char* value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const char* value) const { values_[i].v_str = value; type_codes_[i] = kTVMStr; } // setters for container types - void operator()(size_t i, const std::string& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const std::string& value) const { values_[i].v_str = value.c_str(); type_codes_[i] = kTVMStr; } - void operator()(size_t i, const TVMByteArray& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray& value) const { values_[i].v_handle = const_cast(&value); type_codes_[i] = kTVMBytes; } - void operator()(size_t i, const PackedFunc& value) const { - values_[i].v_handle = const_cast(&value); - type_codes_[i] = kTVMPackedFuncHandle; + TVM_ALWAYS_INLINE void operator()(size_t i, const PackedFunc& value) const { + if (value != nullptr) { + values_[i].v_handle = const_cast(&value); + type_codes_[i] = kTVMPackedFuncHandle; + } else { + values_[i].v_handle = nullptr; + type_codes_[i] = kTVMNullptr; + } } template - void operator()(size_t i, const TypedPackedFunc& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc& value) const { operator()(i, value.packed()); } void operator()(size_t i, const TVMRetValue& value) const { @@ -1191,7 +1216,7 @@ class TVMArgsSetter { typename = typename std::enable_if< 
std::is_base_of::value> ::type> - void operator()(size_t i, const TObjectRef& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef& value) const { this->SetObject(i, value); } @@ -1200,7 +1225,7 @@ class TVMArgsSetter { std::is_base_of::type>::value> ::type> - void operator()(size_t i, TObjectRef&& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef&& value) const { this->SetObject(i, std::forward(value)); } @@ -1230,10 +1255,10 @@ namespace detail { template struct unpack_call_dispatcher { template - static void run(const F& f, - const TVMArgs& args_pack, - TVMRetValue* rv, - Args&&... unpacked_args) { + TVM_ALWAYS_INLINE static void run(const F& f, + const TVMArgs& args_pack, + TVMRetValue* rv, + Args&&... unpacked_args) { // construct a movable argument value // which allows potential move of argument to the input of F. unpack_call_dispatcher @@ -1247,27 +1272,33 @@ struct unpack_call_dispatcher { template struct unpack_call_dispatcher { template - static void run(const F& f, - const TVMArgs& args_pack, - TVMRetValue* rv, - Args&&... unpacked_args) { - *rv = R(f(std::forward(unpacked_args)...)); + TVM_ALWAYS_INLINE static void run(const F& f, + const TVMArgs& args_pack, + TVMRetValue* rv, + Args&&... unpacked_args) { + using RetType = decltype(f(std::forward(unpacked_args)...)); + if (std::is_same::value) { + *rv = f(std::forward(unpacked_args)...); + } else { + *rv = R(f(std::forward(unpacked_args)...)); + } } }; template struct unpack_call_dispatcher { template - static void run(const F& f, - const TVMArgs& args_pack, - TVMRetValue* rv, - Args&&... unpacked_args) { + TVM_ALWAYS_INLINE static void run(const F& f, + const TVMArgs& args_pack, + TVMRetValue* rv, + Args&&... unpacked_args) { f(std::forward(unpacked_args)...); } }; template -inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) { +TVM_ALWAYS_INLINE void unpack_call( + const F& f, const TVMArgs& args, TVMRetValue* rv) { CHECK_EQ(nargs, args.size()) << "Expect " << nargs << " arguments but get " << args.size(); unpack_call_dispatcher::run(f, args, rv); @@ -1280,22 +1311,23 @@ struct unpack_call_by_signature { template struct unpack_call_by_signature { template - static void run(const F& f, - const TVMArgs& args, - TVMRetValue* rv) { + TVM_ALWAYS_INLINE static void run( + const F& f, + const TVMArgs& args, + TVMRetValue* rv) { unpack_call(f, args, rv); } }; template -inline R call_packed(const PackedFunc& pf, Args&& ...args) { +TVM_ALWAYS_INLINE R call_packed(const PackedFunc& pf, Args&& ...args) { return R(pf(std::forward(args)...)); } template struct typed_packed_call_dispatcher { template - static inline R run(const PackedFunc& pf, Args&& ...args) { + TVM_ALWAYS_INLINE static R run(const PackedFunc& pf, Args&& ...args) { return pf(std::forward(args)...); } }; @@ -1303,7 +1335,7 @@ struct typed_packed_call_dispatcher { template<> struct typed_packed_call_dispatcher { template - static inline void run(const PackedFunc& pf, Args&& ...args) { + TVM_ALWAYS_INLINE static void run(const PackedFunc& pf, Args&& ...args) { pf(std::forward(args)...); } }; @@ -1334,7 +1366,7 @@ inline void TypedPackedFunc::AssignTypedLambda(FType flambda) { } template -inline R TypedPackedFunc::operator()(Args... args) const { +TVM_ALWAYS_INLINE R TypedPackedFunc::operator()(Args... 
args) const { return detail::typed_packed_call_dispatcher ::run(packed_, std::forward(args)...); } diff --git a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java index c31c67f283af..61ff966eaf38 100644 --- a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java +++ b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java @@ -38,53 +38,14 @@ public class GraphRuntime { * @return Runtime graph module that can be used to execute the graph. */ public static GraphModule create(String graphJson, Module libmod, TVMContext ctx) { - Module graphModule = null; - if (ctx.deviceType >= RPC.RPC_SESS_MASK) { - if (!(ctx instanceof TVMRemoteContext)) { - throw new IllegalArgumentException( - "Looks like you are using remote context with no RPCSession bind." - + "Use session.context instead."); - } - RPCSession rpcSession = ((TVMRemoteContext) ctx).rpcSession; - // check arguments - if (!"rpc".equals(libmod.typeKey())) { - throw new IllegalArgumentException("libmod.typeKey != rpc"); - } - final int sessIndex = (int) ((Function) reflectionStaticCall( - RPC.class, "getApi", "_SessTableIndex")) - .pushArg(libmod).invoke().asLong(); - if (sessIndex != (Integer) reflectionGetField(rpcSession, "tblIndex")) { - throw new IllegalArgumentException(String.format( - "libmod SessTableIndex=%d mismatch rpcSession.tblIndex=%d", - sessIndex, reflectionGetField(rpcSession, "tblIndex"))); - } - - Function rpcModuleHandle = (Function) reflectionStaticCall( - RPC.class, "getApi","_ModuleHandle"); - if (rpcModuleHandle == null) { - throw new RuntimeException("Cannot find global function tvm.rpc._ModuleHandle." - + "Did you compile tvm_runtime with the correct version?"); - } - - Function fcreate = Function.getFunction("tvm.graph_runtime.remote_create"); - if (fcreate == null) { - throw new RuntimeException("Cannot find global function tvm.graph_runtime.remote_create." - + "Did you compile tvm_runtime with correct version?"); - } - - TVMValue hmod = rpcModuleHandle.pushArg(libmod).invoke(); - graphModule = fcreate.call(graphJson, hmod, - ctx.deviceType % RPC.RPC_SESS_MASK, ctx.deviceId).asModule(); - } else { - Function fcreate = Function.getFunction("tvm.graph_runtime.create"); - if (fcreate == null) { - throw new RuntimeException("Cannot find global function tvm.graph_runtime.create." - + "Did you compile tvm_runtime with correct version?"); - } - graphModule = fcreate.pushArg(graphJson) - .pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId) - .invoke().asModule(); + Function fcreate = Function.getFunction("tvm.graph_runtime.create"); + if (fcreate == null) { + throw new RuntimeException("Cannot find global function tvm.graph_runtime.create." + + "Did you compile tvm_runtime with correct version?"); } + Module graphModule = fcreate.pushArg(graphJson) + .pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId) + .invoke().asModule(); return new GraphModule(graphModule, ctx); } diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java b/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java index 5178ac900a36..69321c3b51c8 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java @@ -29,7 +29,7 @@ public class Client { * @return The connected session. 
*/ public static RPCSession connect(String url, int port, String key) { - Function doConnect = RPC.getApi("_Connect"); + Function doConnect = RPC.getApi("Connect"); if (doConnect == null) { throw new RuntimeException("Please compile with USE_RPC=1"); } diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java b/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java index 29a457f39a40..1f3191fb2e8c 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java @@ -46,7 +46,7 @@ public NativeServerLoop(final Function fsend, final Function frecv) { try { tempDir = serverEnv(); System.err.println("starting server loop..."); - RPC.getApi("_ServerLoop").pushArg(fsend).pushArg(frecv).invoke(); + RPC.getApi("ServerLoop").pushArg(fsend).pushArg(frecv).invoke(); System.err.println("done server loop..."); } catch (IOException e) { e.printStackTrace(); diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java b/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java index 92b328488b40..b9f621473cf4 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java @@ -39,7 +39,7 @@ public class RPCSession { RPCSession(Module sess) { session = sess; - tblIndex = (int) RPC.getApi("_SessTableIndex").pushArg(session).invoke().asLong(); + tblIndex = (int) RPC.getApi("SessTableIndex").pushArg(session).invoke().asLong(); } /** @@ -237,7 +237,7 @@ public byte[] download(String path) { * @return The remote module containing remote function. */ public Module loadModule(String path) { - return RPC.getApi("_LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule(); + return RPC.getApi("LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule(); } diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index dc2dc1944f30..b17174a7c6bf 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -141,7 +141,13 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, TVMContext): values[i].v_int64 = _ctx_to_int64(arg) type_codes[i] = TypeCode.TVM_CONTEXT - elif isinstance(arg, bytearray): + elif isinstance(arg, (bytearray, bytes)): + # from_buffer only takes in bytearray. + if isinstance(arg, bytes): + byte_arr = bytearray(arg) + temp_args.append(byte_arr) + arg = byte_arr + arr = TVMByteArray() arr.data = ctypes.cast( (ctypes.c_byte * len(arg)).from_buffer(arg), diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 1f68df1885db..45bcf64a616d 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -142,7 +142,13 @@ cdef inline int make_arg(object arg, value[0].v_ctx = (( ctypes.addressof(arg)))[0] tcode[0] = kTVMContext - elif isinstance(arg, bytearray): + elif isinstance(arg, (bytes, bytearray)): + # from_buffer only takes in bytearray. 
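# A minimal stdlib-only illustration of why the bytes -> bytearray copy above is
# needed: ctypes' from_buffer() only accepts writable buffers, so an immutable
# bytes object is rejected. The payload value is made up for the example.
import ctypes

payload = b"\x01\x02\x03"
try:
    (ctypes.c_byte * len(payload)).from_buffer(payload)
except TypeError:
    pass  # bytes expose a read-only buffer and are rejected
writable = bytearray(payload)
arr = (ctypes.c_byte * len(writable)).from_buffer(writable)  # succeeds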
+ if isinstance(arg, bytes): + byte_arr = bytearray(arg) + temp_args.append(byte_arr) + arg = byte_arr + arr = TVMByteArray() arr.data = ctypes.cast( (ctypes.c_byte * len(arg)).from_buffer(arg), diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py index 8d3ce19f9444..8674e31c3b84 100644 --- a/python/tvm/_ffi/base.py +++ b/python/tvm/_ffi/base.py @@ -48,7 +48,6 @@ def _load_lib(): """Load libary by searching possible path.""" lib_path = libinfo.find_lib_path() lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL) - # DMatrix functions lib.TVMGetLastError.restype = ctypes.c_char_p return lib, os.path.basename(lib_path[0]) diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py index 0d1a4e214791..de8f7b565c09 100644 --- a/python/tvm/_ffi/libinfo.py +++ b/python/tvm/_ffi/libinfo.py @@ -88,6 +88,9 @@ def find_lib_path(name=None, search_path=None, optional=False): dll_path.append(install_lib_dir) + if os.path.isdir(source_dir): + dll_path.append(os.path.join(source_dir, "web", "dist", "wasm")) + dll_path = [os.path.realpath(x) for x in dll_path] if search_path is not None: if isinstance(search_path, list): @@ -154,6 +157,7 @@ def find_include_path(name=None, search_path=None, optional=False): ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) source_dir = os.path.join(ffi_dir, "..", "..", "..") install_include_dir = os.path.join(ffi_dir, "..", "..", "..", "..") + third_party_dir = os.path.join(source_dir, "3rdparty") header_path = [] diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index ae37923a1dcf..8ad47acfe989 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -90,7 +90,8 @@ def get_target_triple(): def cross_compiler(compile_func, options=None, output_format=None, - get_target_triple=None): + get_target_triple=None, + add_files=None): """Create a cross compiler function by specializing compile_func with options. This function can be used to construct compile functions that @@ -111,6 +112,10 @@ def cross_compiler(compile_func, get_target_triple: Optional[Callable] Function that can target triple according to dumpmachine option of compiler. + add_files: Optional[List[str]] + List of paths to additional object, source, library files + to pass as part of the compilation. + Returns ------- fcompile : Callable[[str, str, Optional[str]], None] @@ -133,6 +138,7 @@ def cross_compiler(compile_func, """ base_options = [] if options is None else options kwargs = {} + add_files = [] if add_files is None else add_files # handle case where compile_func is the name of the cc if isinstance(compile_func, str): @@ -144,7 +150,7 @@ def _fcompile(outputs, objects, options=None): all_options = base_options if options is not None: all_options += options - compile_func(outputs, objects, options=all_options, **kwargs) + compile_func(outputs, objects + add_files, options=all_options, **kwargs) if not output_format and hasattr(compile_func, "output_format"): output_format = compile_func.output_format diff --git a/python/tvm/contrib/emscripten.py b/python/tvm/contrib/emcc.py similarity index 65% rename from python/tvm/contrib/emscripten.py rename to python/tvm/contrib/emcc.py index 7f31273451f7..6df205a030bc 100644 --- a/python/tvm/contrib/emscripten.py +++ b/python/tvm/contrib/emcc.py @@ -16,18 +16,16 @@ # under the License. 
"""Util to invoke emscripten compilers in the system.""" # pylint: disable=invalid-name -from __future__ import absolute_import as _abs - import subprocess -from .._ffi.base import py_str -from .._ffi.libinfo import find_lib_path +from tvm._ffi.base import py_str +from tvm._ffi.libinfo import find_lib_path + -def create_js(output, - objects, - options=None, - side_module=False, - cc="emcc"): - """Create emscripten javascript library. +def create_tvmjs_wasm(output, + objects, + options=None, + cc="emcc"): + """Create wasm that is supposed to run with the tvmjs. Parameters ---------- @@ -44,25 +42,27 @@ def create_js(output, The compile string. """ cmd = [cc] - cmd += ["-Oz"] - if not side_module: - cmd += ["-s", "RESERVED_FUNCTION_POINTERS=2"] - cmd += ["-s", "NO_EXIT_RUNTIME=1"] - extra_methods = ['cwrap', 'getValue', 'setValue', 'addFunction'] - cfg = "[" + (','.join("\'%s\'" % x for x in extra_methods)) + "]" - cmd += ["-s", "EXTRA_EXPORTED_RUNTIME_METHODS=" + cfg] - else: - cmd += ["-s", "SIDE_MODULE=1"] - cmd += ["-o", output] + cmd += ["-O3"] + + cmd += ["-std=c++14"] + cmd += ["-s", "ERROR_ON_UNDEFINED_SYMBOLS=0"] + cmd += ["-s", "STANDALONE_WASM=1"] + cmd += ["-s", "ALLOW_MEMORY_GROWTH=1"] + + objects = [objects] if isinstance(objects, str) else objects + with_runtime = False for obj in objects: - if obj.find("libtvm_web_runtime.bc") != -1: + if obj.find("wasm_runtime.bc") != -1: with_runtime = True - if not with_runtime and not side_module: - objects += [find_lib_path("libtvm_web_runtime.bc")[0]] + if not with_runtime: + objects += [find_lib_path("wasm_runtime.bc")[0]] + objects += [find_lib_path("tvmjs_support.bc")[0]] + + cmd += ["-o", output] cmd += objects if options: @@ -79,4 +79,4 @@ def create_js(output, msg += py_str(out) raise RuntimeError(msg) -create_js.object_format = "bc" +create_tvmjs_wasm.object_format = "bc" diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 73235f71c77b..740d1c3f19f3 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -18,9 +18,10 @@ import numpy as np import tvm._ffi -from .._ffi.base import string_types -from .._ffi.runtime_ctypes import TVMContext -from ..rpc import base as rpc_base +from tvm.rpc import _ffi_api as _rpc_ffi_api +from tvm.rpc import base as rpc_base +from tvm._ffi.base import string_types +from tvm._ffi.runtime_ctypes import TVMContext def create(graph_json_str, libmod, ctx): @@ -99,7 +100,7 @@ def get_device_ctx(libmod, ctx): device_type = cur_ctx.device_type if device_type >= rpc_base.RPC_SESS_MASK: assert libmod.type_key == "rpc" - assert rpc_base._SessTableIndex( + assert _rpc_ffi_api.SessTableIndex( libmod) == cur_ctx._rpc_sess._tbl_index num_rpc_ctx += 1 device_type = cur_ctx.device_type % rpc_base.RPC_SESS_MASK diff --git a/python/tvm/error.py b/python/tvm/error.py index 4c3e6060c25a..b3502f6b0ead 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -57,6 +57,11 @@ def __init__(self, msg): register_error("KeyError", KeyError) +@register_error +class RPCError(RuntimeError): + """Error thrown by the remote server handling the RPC call.""" + + @register_error class OpError(TVMError): """Base class of all operator errors in frontends.""" diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py index 4cf341335ea7..59da8fa5a3bf 100644 --- a/python/tvm/exec/rpc_proxy.py +++ b/python/tvm/exec/rpc_proxy.py @@ -29,12 +29,11 @@ def find_example_resource(): """Find resource examples.""" curr_path = 
os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - base_path = os.path.join(curr_path, "../../../") - index_page = os.path.join(base_path, "web/example_rpc.html") + base_path = os.path.abspath(os.path.join(curr_path, "..", "..", "..")) + index_page = os.path.join(base_path, "web", "apps", "browser", "rpc_server.html") js_files = [ - os.path.join(base_path, "web/tvm_runtime.js"), - os.path.join(base_path, "build/libtvm_web_runtime.js"), - os.path.join(base_path, "build/libtvm_web_runtime.js.mem") + os.path.join(base_path, "web/dist/tvmjs.bundle.js"), + os.path.join(base_path, "web/dist/wasm/tvmjs_runtime.wasi.js") ] for fname in [index_page] + js_files: if not os.path.exists(fname): @@ -69,7 +68,7 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--host', type=str, default="0.0.0.0", + parser.add_argument('--host', type=str, default="localhost", help='the hostname of the server') parser.add_argument('--port', type=int, default=9090, help='The port of the RPC') diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 2ef94505b413..f4bcb6bf947a 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -557,6 +557,31 @@ def _impl_v2(cls, inputs, attr, params): }, )(inputs, attr, params) + @classmethod + def _impl_v11(cls, inputs, attr, params): + pad_width = [] + pads = infer_value_simulated(inputs[1], params).asnumpy() + if len(inputs) == 3: + value = infer_value_simulated(inputs[2], params).asnumpy().item() + else: + value = 0 + attr["pad_value"] = value + dims = int(len(pads) / 2) + for i in range(dims): + pad_width.append((pads[i], pads[i+dims])) + attr['pad_width'] = pad_width + pad_mode = attr.get('mode', b'constant').decode('utf-8') + if pad_mode in ['constant', 'edge', 'reflect']: + attr['pad_mode'] = pad_mode + attr.pop('mode', None) + else: + raise tvm.error.OpAttributeInvalid( + 'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.') + + return AttrCvt('pad')(inputs[:1], attr, params) + + + class ParametricSoftPlus(OnnxOpConverter): """ Operator converter for ParametricSoftPlus. @@ -576,7 +601,12 @@ class Prelu(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs)) - return _op.nn.prelu(inputs[0], inputs[1]) + alpha_shape = infer_shape(inputs[1]) + if len(alpha_shape) != 1: + alpha = _op.reshape(inputs[1], (-1,)) + else: + alpha = inputs[1] + return _op.nn.prelu(inputs[0], alpha) class Reciprocal(OnnxOpConverter): @@ -616,7 +646,7 @@ def _impl_v1(cls, inputs, attr, params): def _impl_v5(cls, inputs, attr, params): if get_name(inputs[1]) in params: # pop shape out of parameters since it wont be needed later. - shape = tuple(params.pop(inputs[1].name_hint).asnumpy()) + shape = tuple(params.pop(inputs[1].name_hint).asnumpy().astype("int32")) out = _op.reshape(inputs[0], shape) else: data, shape = inputs @@ -782,7 +812,10 @@ def _impl_v9(cls, inputs, attr, params): if not scales: #Here we are going to higher OPSET version. 
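# A small worked example of the pads -> pad_width mapping performed by the new
# Pad v11 converter above (the pad amounts are made up): ONNX lists every "begin"
# amount first and every "end" amount second, while Relay's pad expects one
# (before, after) pair per axis.
pads = [0, 1, 0, 3]          # 2-D case: begins = [0, 1], ends = [0, 3]
dims = int(len(pads) / 2)
pad_width = [(pads[i], pads[i + dims]) for i in range(dims)]
assert pad_width == [(0, 0), (1, 3)]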
assert len(inputs) == 2, "Upsample op take 2 inputs, {} given".format(len(inputs)) - scales = params[inputs[1].name_hint].asnumpy() + if get_name(inputs[1]) in params: + scales = params[inputs[1].name_hint].asnumpy() + else: + scales = infer_value_simulated(inputs[1], params).asnumpy() inputs = inputs[:1] assert scales[0] == 1.0 and scales[1] == 1.0 input_shape = infer_shape(inputs[0]) @@ -1068,6 +1101,11 @@ class ReduceProd(Reduce): """ name = 'prod' +class ReduceLogSumExp(Reduce): + """ Operator converter for ReduceLogSumExp. + """ + name = 'logsumexp' + class ArgMax(OnnxOpConverter): """ Operator converter for ArgMax. """ @@ -1477,6 +1515,8 @@ def _impl_v9(cls, inputs, attr, params): raise ValueError("Expect 1 input only") output = AttrCvt(op_name='argwhere')(inputs, attr, params) + # ONNX NonZero always outputs int64 + output = _op.cast(output, "int64") return _op.transpose(output, axes=(1, 0)) class TopK(OnnxOpConverter): @@ -1630,8 +1670,7 @@ def _get_convert_map(opset): 'ReduceSum': ReduceSum.get_converter(opset), 'ReduceMean': ReduceMean.get_converter(opset), 'ReduceProd': ReduceProd.get_converter(opset), - # 'ReduceProd' - # 'ReduceLogSumExp' + 'ReduceLogSumExp': ReduceLogSumExp.get_converter(opset), #defs/sorting 'ArgMax': ArgMax.get_converter(opset), diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 703ef9c8b6b0..5a645c67cf61 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -31,6 +31,7 @@ from ... import nd as _nd from .common import ExprTable from .common import infer_shape as _infer_shape +from .tflite_flexbuffer import FlexBufferDecoder __all__ = ['from_tflite'] @@ -64,6 +65,7 @@ def __init__(self, model, subgraph, exp_tab): self.convert_map = { 'ABS': self.convert_abs, 'ADD': self.convert_add, + 'ADD_N': self.convert_add_n, 'AVERAGE_POOL_2D': self.convert_average_pool2d, 'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd, 'CAST': self.convert_cast, @@ -99,7 +101,7 @@ def __init__(self, model, subgraph, exp_tab): 'LOGISTIC': self.convert_logistic, 'MAX_POOL_2D': self.convert_max_pool2d, 'MAXIMUM': self.convert_maximum, - 'MEAN': self._convert_reduce_mean, + 'MEAN': self.convert_reduce_mean, 'MINIMUM': self.convert_minimum, 'MIRROR_PAD': self.convert_mirror_pad, 'MUL': self.convert_mul, @@ -109,16 +111,17 @@ def __init__(self, model, subgraph, exp_tab): 'PAD': self.convert_pad, 'POW': self.convert_pow, 'PRELU': self.convert_prelu, - 'REDUCE_ANY': self._convert_reduce_any, - 'REDUCE_MAX': self._convert_reduce_max, - 'REDUCE_MIN': self._convert_reduce_min, - 'REDUCE_PROD': self._convert_reduce_prod, + 'REDUCE_ANY': self.convert_reduce_any, + 'REDUCE_MAX': self.convert_reduce_max, + 'REDUCE_MIN': self.convert_reduce_min, + 'REDUCE_PROD': self.convert_reduce_prod, 'RELU':self.convert_relu, 'RESHAPE': self.convert_reshape, 'RESIZE_BILINEAR': self.convert_resize_bilinear, 'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor, 'ROUND': self.convert_round, 'RSQRT': self.convert_rsqrt, + 'SELECT': self.convert_select, 'SIN': self.convert_sin, 'SLICE': self.convert_slice, 'SOFTMAX': self.convert_softmax, @@ -132,7 +135,7 @@ def __init__(self, model, subgraph, exp_tab): 'SQUEEZE': self.convert_squeeze, 'STRIDED_SLICE': self.convert_strided_slice, 'SUB': self.convert_sub, - 'SUM': self._convert_reduce_sum, + 'SUM': self.convert_reduce_sum, 'TAN': self.convert_tan, 'TANH':self.convert_tanh, 'TILE': self.convert_tile, @@ -140,6 +143,7 @@ def __init__(self, model, subgraph, exp_tab): 
'TRANSPOSE_CONV': self.convert_transpose_conv, 'TRANSPOSE': self.convert_transpose, 'UNPACK': self.convert_unpack, + 'WHERE': self.convert_select, 'ZEROS_LIKE': self.convert_zeros_like, } @@ -320,6 +324,45 @@ def dequantize(self, expr, tensor): input_zero_point=tensor.qnn_params['zero_point']) return dequantized + + def convert_qnn_fused_activation_function(self, expr, fused_activation_fn, + scale, zero_point, dtype): + """Convert TFLite fused activation function. The expr is an input quantized tensor with + scale and zero point """ + try: + from tflite.ActivationFunctionType import ActivationFunctionType + except ImportError: + raise ImportError("The tflite package must be installed") + + # Quantize a float value to an quantized integer value + quantize = lambda x: float(int(round(x / scale)) + zero_point) + + # Get min/max of the output dtype. This will be used to ensure that clip a_min/a_max are not + # beyond the dtype range. + qmin = float(tvm.tir.op.min_value(dtype).value) + qmax = float(tvm.tir.op.max_value(dtype).value) + + # The input expr is a quantized tensor with its scale and zero point. We calculate the + # suitable clip off points based on these scale and zero point. + if fused_activation_fn == ActivationFunctionType.NONE: + return expr + if fused_activation_fn == ActivationFunctionType.RELU6: + return _op.clip(expr, + a_min=max(qmin, quantize(0)), + a_max=min(qmax, quantize(6.0))) + if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: + return _op.clip(expr, + a_min=max(qmin, quantize(-1.0)), + a_max=min(qmax, quantize(1.0))) + if fused_activation_fn == ActivationFunctionType.RELU: + return _op.clip(expr, + a_min=max(qmin, quantize(0.0)), + a_max=qmax) + + fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] + raise tvm.error.OpNotImplemented( + 'Quantized activation {} is not supported yet.'.format(fused_activation_fn_str)) + def convert_conv2d(self, op): """Convert TFLite conv2d""" return self.convert_conv(op, "conv2d") @@ -428,7 +471,6 @@ def convert_l2_normalization(self, op): try: from tflite.BuiltinOptions import BuiltinOptions from tflite.L2NormOptions import L2NormOptions - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -453,17 +495,15 @@ def convert_l2_normalization(self, op): if self.is_quantized(op): raise tvm.error.OpNotImplemented( 'TFLite quantized L2_NORMALIZATION operator is not supported yet.') + # TFL uses only the default epsilon value out = _op.nn.l2_normalize(in_expr, eps=1e-12, axis=[input_tensor_rank - 1]) # if we have fused activation fn - if fused_activation_fn != ActivationFunctionType.NONE: - if not output_tensor.qnn_params: - out = self.convert_fused_activation_function(out, fused_activation_fn) - else: - raise tvm.error.OpNotImplemented( - 'TFLite quantized L2_NORMALIZATION operator\ - with fused activation function is not supported yet.') + if output_tensor.qnn_params: + raise tvm.error.OpNotImplemented( + 'TFLite quantized L2_NORMALIZATION operator is not supported yet.') + out = self.convert_fused_activation_function(out, fused_activation_fn) return out @@ -608,7 +648,6 @@ def convert_concatenation(self, op): try: from tflite.ConcatenationOptions import ConcatenationOptions from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -640,14 +679,20 @@ def convert_concatenation(self, 
op): output_zero_point=output_tensor.qnn_params['zero_point'], axis=concatenation_axis) - # if we have activation fn - if fused_activation_fn != ActivationFunctionType.NONE: - if not output_tensor.qnn_params: - out = self.convert_fused_activation_function(out, fused_activation_fn) - else: - raise tvm.error.OpNotImplemented( - 'Operator {} with fused activation is not supported yet.' - .format('qnn.op.concatenate')) + # Handle fused activations + if output_tensor.qnn_params: + scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=scale_val, + zero_point=zero_point_val, + dtype=output_tensor_type_str) + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) + return out def _convert_unary_elemwise(self, relay_op, op): @@ -790,7 +835,6 @@ def _convert_elemwise(self, relay_op, op): from tflite.MulOptions import MulOptions from tflite.DivOptions import DivOptions from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -798,28 +842,9 @@ def _convert_elemwise(self, relay_op, op): assert len(input_tensors) == 2, "input tensors length should be 2" lhs_tensor = input_tensors[0] - if self.has_expr(lhs_tensor.tensor_idx): - # In most cases, we can assume that TOCO fuses elemwise operators - # with constants - it means both will be tensors. - lhs_expr = self.get_expr(lhs_tensor.tensor_idx) - else: - # However, in some corner cases, the elemwise operator is not fused, - # we can receive as constant. - lhs_type_str = self.get_tensor_type_str(lhs_tensor.tensor.Type()) - lhs_expr = self.exp_tab.new_const(self.get_tensor_value(lhs_tensor), - dtype=lhs_type_str) - rhs_tensor = input_tensors[1] - if self.has_expr(rhs_tensor.tensor_idx): - # In most cases, we can assume that TOCO fuses elemwise operators - # with constants - it means both will be tensors. - rhs_expr = self.get_expr(rhs_tensor.tensor_idx) - else: - # However, in some corner cases, the elemwise operator is not fused, - # we can receive as constant. 
- rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type()) - rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor), - dtype=rhs_type_str) + lhs_expr = self.get_tensor_expr(lhs_tensor) + rhs_expr = self.get_tensor_expr(rhs_tensor) output_tensors = self.get_output_tensors(op) assert len(output_tensors) == 1, "output tensors length should be 1" @@ -855,13 +880,20 @@ def _convert_elemwise(self, relay_op, op): op_options = op.BuiltinOptions() options.Init(op_options.Bytes, op_options.Pos) fused_activation_fn = options.FusedActivationFunction() - # if we have activation fn - if fused_activation_fn != ActivationFunctionType.NONE: - if output_tensor.qnn_params: - raise tvm.error.OpNotImplemented( - 'Elemwise operators with fused activation are not supported yet.') - out = self.convert_fused_activation_function(out, fused_activation_fn) + # Handle fused activations + if output_tensor.qnn_params: + scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=scale_val, + zero_point=zero_point_val, + dtype=output_tensor_type_str) + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) return out def convert_add(self, op): @@ -871,6 +903,20 @@ def convert_add(self, op): return self._convert_elemwise(_qnn.op.add, op) return self._convert_elemwise(_op.add, op) + def convert_add_n(self, op): + """Convert TFLite ADD_N""" + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + + input_tensors = self.get_input_tensors(op) + assert not input_tensors[0].qnn_params, "TFLite does not support quantized ADD_N." 
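# A numeric sanity check of the clip bounds that convert_qnn_fused_activation_function
# (defined above) derives for a quantized RELU6; the uint8 scale and zero point below
# are made-up example values.
scale, zero_point = 0.05, 10
qmin, qmax = 0.0, 255.0                          # uint8 range
quantize = lambda x: float(int(round(x / scale)) + zero_point)
a_min = max(qmin, quantize(0.0))                 # 10.0
a_max = min(qmax, quantize(6.0))                 # 130.0, i.e. round(6 / 0.05) + 10
assert (a_min, a_max) == (10.0, 130.0)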
+ lhs_expr = self.get_tensor_expr(input_tensors[0]) + for rhs_tensor in input_tensors[1:]: + assert not rhs_tensor.qnn_params, "TFLite does not support quantized ADD_N" + rhs_expr = self.get_tensor_expr(rhs_tensor) + lhs_expr = _op.add(lhs_expr, rhs_expr) + return lhs_expr + def convert_sub(self, op): """Convert TFLite SUB""" # Check if the input tensor is quantized, call QNN op @@ -1241,7 +1287,7 @@ def convert_fill(self, op): return out def _convert_reduce(self, relay_op, op): - """Generic method to Convert TFLite MEAN operators""" + """Generic method to Convert TFLite REDUCE operators""" try: from tflite.BuiltinOptions import BuiltinOptions from tflite.ReducerOptions import ReducerOptions @@ -1285,22 +1331,22 @@ def _convert_reduce(self, relay_op, op): return out - def _convert_reduce_min(self, op): + def convert_reduce_min(self, op): return self._convert_reduce(_op.reduce.min, op) - def _convert_reduce_max(self, op): + def convert_reduce_max(self, op): return self._convert_reduce(_op.reduce.max, op) - def _convert_reduce_mean(self, op): + def convert_reduce_mean(self, op): return self._convert_reduce(_op.reduce.mean, op) - def _convert_reduce_prod(self, op): + def convert_reduce_prod(self, op): return self._convert_reduce(_op.reduce.prod, op) - def _convert_reduce_sum(self, op): + def convert_reduce_sum(self, op): return self._convert_reduce(_op.reduce.sum, op) - def _convert_reduce_any(self, op): + def convert_reduce_any(self, op): return self._convert_reduce(_op.reduce.any, op) def convert_fully_connected(self, op): @@ -1309,7 +1355,6 @@ def convert_fully_connected(self, op): from tflite.FullyConnectedOptions import FullyConnectedOptions from tflite.BuiltinOptions import BuiltinOptions from tflite.TensorType import TensorType - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -1329,16 +1374,28 @@ def convert_fully_connected(self, op): input_tensor_shape = input_tensor.tensor.ShapeAsNumpy() weight_tensor_shape = weight_tensor.tensor.ShapeAsNumpy() - # reshape input tensor from N H W C to N H*W*C - input_size_per_batch = 1 - for s in range(1, len(input_tensor_shape)): - input_size_per_batch *= input_tensor_shape[s] - assert input_size_per_batch == weight_tensor_shape[1], \ - "input size and weight size are mismatched" - target_shape = tuple((input_tensor_shape[0], input_size_per_batch)) + # Weight should have only 2 dimensions(TFLite convention) + assert len(weight_tensor_shape) == 2, "Weight should be only 2-dim" + + # Input shape: [i_batch_size, ..., n_inputs] + # Filter shape: [n_inputs, n_units] + # + # As we will transform Fully_Connected Input to Dense Op inputs as below + # Dense expected Input shape: [batch_size, n_units] + # Dense expected Weight shape: [out_dim, n_units] + # Dense output shape: [batch_size, out_dim] + # So it is evident that input shape: [batch_size = input_size / n_units, n_units] + input_size = 1 + for _, shape in enumerate(input_tensor_shape): + input_size *= shape + + # First get the batch size + batch_size = int(input_size / weight_tensor_shape[1]) + target_shape = tuple((batch_size, weight_tensor_shape[1])) in_expr = self.get_expr(input_tensor_idx) in_expr = _op.reshape(in_expr, target_shape) + #TODO: Change the output shape calculation based on keep_dim option assert op.BuiltinOptionsType() == BuiltinOptions.FullyConnectedOptions op_options = op.BuiltinOptions() fully_connected_options = FullyConnectedOptions() @@ -1350,8 +1407,11 @@ def 
convert_fully_connected(self, op): assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32) weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type) - weight_value = self.get_tensor_value(weight_tensor) - weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str) + if self.has_expr(weight_tensor.tensor_idx): + weight_expr = self.get_expr(weight_tensor.tensor_idx) + else: + weight_value = self.get_tensor_value(weight_tensor) + weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str) weight_shape = _infer_shape(weight_expr) if input_tensor.qnn_params: @@ -1376,15 +1436,6 @@ def convert_fully_connected(self, op): dtype=bias_tensor_type_str) out = _op.nn.bias_add(out, bias_expr) - # If we have fused activations - if fused_activation_fn != ActivationFunctionType.NONE: - if not output_tensor.qnn_params: - out = self.convert_fused_activation_function(out, fused_activation_fn) - else: - raise tvm.error.OpNotImplemented( - 'Operator {} with fused activation is not supported yet.' - .format('qnn.op.dense')) - # Finally if the dense is quantized. Add a requantize at the end. if output_tensor.qnn_params: data_scale = input_tensor.qnn_params['scale'] @@ -1394,6 +1445,8 @@ def convert_fully_connected(self, op): new_input_scale_val = data_scale_val * weight_scale_val new_input_scale = relay.const(new_input_scale_val, 'float32') new_input_zero_point = relay.const(0, 'int32') + + # Requantize out = _qnn.op.requantize(out, input_scale=new_input_scale, input_zero_point=new_input_zero_point, @@ -1401,6 +1454,19 @@ def convert_fully_connected(self, op): output_zero_point=output_tensor.qnn_params['zero_point'], out_dtype=output_tensor_type_str) + # Call activation function + output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=output_scale_val, + zero_point=output_zero_point_val, + dtype=output_tensor_type_str) + + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) + return out def convert_squeeze(self, op): @@ -1435,7 +1501,9 @@ def convert_fused_activation_function(self, in_expr, fused_activation_fn): from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") - assert fused_activation_fn != ActivationFunctionType.NONE + + if fused_activation_fn == ActivationFunctionType.NONE: + return in_expr if fused_activation_fn == ActivationFunctionType.RELU6: return _op.clip(in_expr, a_min=0, a_max=6) if fused_activation_fn == ActivationFunctionType.RELU: @@ -1446,13 +1514,12 @@ def convert_fused_activation_function(self, in_expr, fused_activation_fn): return _op.tanh(in_expr) fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] raise tvm.error.OpNotImplemented( - 'Operator {} is not supported for frontend TFLite.'.format(fused_activation_fn_str)) + 'Fused activation {} is not supported yet.'.format(fused_activation_fn_str)) def convert_conv(self, op, conv_type): """convolution implementation.""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType from tflite.TensorType import TensorType from tflite.Conv2DOptions import Conv2DOptions from tflite.DepthwiseConv2DOptions import DepthwiseConv2DOptions @@ -1583,17 +1650,9 @@ def 
convert_conv(self, op, conv_type): channel_axis = 3 out = _op.nn.bias_add(out, bias_expr, axis=channel_axis) - # If we have fused activations - if fused_activation_fn != ActivationFunctionType.NONE: - if not output_tensor.qnn_params: - out = self.convert_fused_activation_function(out, fused_activation_fn) - else: - raise tvm.error.OpNotImplemented( - 'Operator {} with fused activation is not supported yet.' - .format('qnn.op.conv2d')) - - # Finally if the conv is quantized. Add a requantize at the end. + # Handle fused activation. if output_tensor.qnn_params: + # Calculate the intermediate scale and zero point of the int32 output. data_scale = input_tensor.qnn_params['scale'] weight_scale = weight_tensor.qnn_params['scale'] data_scale_val = get_scalar_from_constant(data_scale) @@ -1601,6 +1660,8 @@ def convert_conv(self, op, conv_type): new_input_scale_val = data_scale_val * weight_scale_val new_input_scale = relay.const(new_input_scale_val, 'float32') new_input_zero_point = relay.const(0, 'int32') + + # Finally requantize out = _qnn.op.requantize(out, input_scale=new_input_scale, input_zero_point=new_input_zero_point, @@ -1608,6 +1669,18 @@ def convert_conv(self, op, conv_type): output_zero_point=output_tensor.qnn_params['zero_point'], out_dtype=output_tensor_type_str) + # Call activation function + output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=output_scale_val, + zero_point=output_zero_point_val, + dtype=output_tensor_type_str) + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) + return out def convert_split(self, op): @@ -1697,6 +1770,18 @@ def convert_slice(self, op): return out + def convert_select(self, op): + """Convert TFLite SELECT""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 3, "input tensors length should be == 3" + cond = self.get_tensor_expr(input_tensors[0]) + x = self.get_tensor_expr(input_tensors[1]) + y = self.get_tensor_expr(input_tensors[2]) + + out = _op.where(cond, x, y) + + return out + def convert_transpose(self, op): """transpose implementation.""" input_tensors = self.get_input_tensors(op) @@ -1771,7 +1856,6 @@ def convert_pool2d(self, op, pool_type): """pool2d implementation.""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType from tflite.Pool2DOptions import Pool2DOptions from tflite.Padding import Padding except ImportError: @@ -1846,13 +1930,19 @@ def convert_pool2d(self, op, pool_type): raise tvm.error.OpNotImplemented( 'Operator {} is not supported for frontend TFLite.'.format(pool_type + ' pool')) - # If we have fused activations - if fused_activation_fn != ActivationFunctionType.NONE: - if input_tensor.qnn_params: - raise tvm.error.OpNotImplemented( - 'Operator {} with fused activation is not supported yet.' 
- .format('qnn.op.pool2d')) + # Handle fused activations + if output_tensor.qnn_params: + scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=scale_val, + zero_point=zero_point_val, + dtype=output_tensor_type_str) + else: out = self.convert_fused_activation_function(out, fused_activation_fn) + return out def convert_pad(self, op): @@ -2241,28 +2331,15 @@ def convert_transpose_conv(self, op): def convert_detection_postprocess(self, op): """Convert TFLite_Detection_PostProcess""" - _option_names = [ - "w_scale", - "max_detections", - "_output_quantized", - "detections_per_class", - "x_scale", - "nms_score_threshold", - "num_classes", - "max_classes_per_detection", - "use_regular_nms", - "y_scale", - "h_scale", - "_support_output_type_float_in_quantized_op", - "nms_iou_threshold" - ] - - custom_options = get_custom_options(op, _option_names) - if custom_options["use_regular_nms"]: - raise tvm.error.OpAttributeUnImplemented( - "use_regular_nms=True is not yet supported for operator {}." - .format("TFLite_Detection_PostProcess") - ) + flexbuffer = op.CustomOptionsAsNumpy().tobytes() + custom_options = FlexBufferDecoder(flexbuffer).decode() + + if "use_regular_nms" in custom_options: + if custom_options["use_regular_nms"]: + raise tvm.error.OpAttributeUnImplemented( + "use_regular_nms=True is not yet supported for operator {}." + .format("TFLite_Detection_PostProcess") + ) inputs = self.get_input_tensors(op) assert len(inputs) == 3, "inputs length should be 3" @@ -2357,6 +2434,20 @@ def get_expr(self, input_tensor_idx): def has_expr(self, input_tensor_idx): return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx)) + def get_tensor_expr(self, tensor): + """ Returns constant expr for constant else a tensor expr""" + if self.has_expr(tensor.tensor_idx): + # In most cases, we can assume that TOCO fuses elemwise operators + # with constants - it means both will be tensors. + expr = self.get_expr(tensor.tensor_idx) + else: + # However, in some corner cases, the elemwise operator is not fused, + # we can receive as constant. + type_str = self.get_tensor_type_str(tensor.tensor.Type()) + expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str) + + return expr + def get_scalar_from_constant(expr): """ Returns scalar value from Relay constant scalar. """ @@ -2433,91 +2524,6 @@ def get_tensor_name(subgraph, tensor_idx): return subgraph.Tensors(tensor_idx).Name().decode("utf-8") -def get_custom_options(op, option_names): - """Get the options of a custom operator. - - This implements partial flexbuffer deserialization to be able - to read custom options. It is not intended to be a general - purpose flexbuffer deserializer and as such only supports a - limited number of types and assumes the data is a flat map. - - Parameters - ---------- - op: - A custom TFlite operator. - option_names: list - A complete list of the custom option names. - - Returns - ------- - options: dict - A dictionary of the custom options. - - """ - import struct - from enum import IntEnum - - class _FlexBufferType(IntEnum): - """Flexbuffer type schema from flexbuffers.h""" - FBT_NULL = 0 - FBT_INT = 1 - FBT_UINT = 2 - FBT_FLOAT = 3 - # Types above stored inline, types below store an offset. 
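
The detection post-process conversion above now decodes the operator's custom options with the generic FlexBufferDecoder instead of the hand-rolled reader being removed here. As a rough illustration of the fixed-width, little-endian decoding involved, a minimal sketch (the helper name is made up; the type ids 1, 3 and 26 are the FBT_INT, FBT_FLOAT and FBT_BOOL values listed in this enum):

    import struct

    def read_flex_scalar(buffer, offset, flex_type, byte_width=4):
        # Decode one inline flexbuffer scalar (illustrative helper, not part of the patch).
        raw = buffer[offset:offset + byte_width]
        if flex_type == 26:            # FBT_BOOL
            return bool(raw[0])
        if flex_type == 1:             # FBT_INT
            return struct.unpack("<i", raw)[0]
        if flex_type == 3:             # FBT_FLOAT
            return struct.unpack("<f", raw)[0]
        raise NotImplementedError("flexbuffer type %d not handled" % flex_type)

    # Example on a hand-packed buffer: an int32 10 followed by a float32 0.5.
    buf = struct.pack("<if", 10, 0.5)
    assert read_flex_scalar(buf, 0, 1) == 10
    assert abs(read_flex_scalar(buf, 4, 3) - 0.5) < 1e-6
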
- FBT_KEY = 4 - FBT_STRING = 5 - FBT_INDIRECT_INT = 6 - FBT_INDIRECT_UINT = 7 - FBT_INDIRECT_FLOAT = 8 - FBT_MAP = 9 - FBT_VECTOR = 10 # Untyped. - FBT_VECTOR_INT = 11 # Typed any size (stores no type table). - FBT_VECTOR_UINT = 12 - FBT_VECTOR_FLOAT = 13 - FBT_VECTOR_KEY = 14 - FBT_VECTOR_STRING = 15 - FBT_VECTOR_INT2 = 16 # Typed tuple (no type table, no size field). - FBT_VECTOR_UINT2 = 17 - FBT_VECTOR_FLOAT2 = 18 - FBT_VECTOR_INT3 = 19 # Typed triple (no type table, no size field). - FBT_VECTOR_UINT3 = 20 - FBT_VECTOR_FLOAT3 = 21 - FBT_VECTOR_INT4 = 22 # Typed quad (no type table, no size field). - FBT_VECTOR_UINT4 = 23 - FBT_VECTOR_FLOAT4 = 24 - FBT_BLOB = 25 - FBT_BOOL = 26 - FBT_VECTOR_BOOL = 36 # To Allow the same type of conversion of type to vector type - - buffer = op.CustomOptionsAsNumpy().tobytes() - value_vector_offset = buffer[-3] - buffer = buffer[:-3] - num_bytes = 4 # Assume all values are stored in 32 bit width - value_vector_size = struct.unpack( - "> 2) - value_offset = -value_vector_offset + i*num_bytes - value_bytes = buffer[value_offset:value_offset+num_bytes] - if flex_type == _FlexBufferType.FBT_BOOL: - value = bool(value_bytes[0]) - if flex_type == _FlexBufferType.FBT_INT: - value = struct.unpack("> 2) + value_bytes = self.buffer[end + i * byte_width: end + (i + 1) * byte_width] + if value_type == FlexBufferType.FBT_BOOL: + value = bool(value_bytes[0]) + elif value_type == FlexBufferType.FBT_INT: + value = struct.unpack("> 2) + byte_width = 1 << BitWidth(root_packed_type & 3) + + if root_type == FlexBufferType.FBT_MAP: + return self.decode_map(root_end, byte_width, root_byte_width) + raise NotImplementedError("Flexbuffer Decoding is partially imlpemented.") diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index d3226012e887..988c94928d33 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -18,7 +18,7 @@ # pylint: disable=redefined-builtin from . import _make -from .tensor import sqrt +from .tensor import sqrt, log, exp from .transform import squeeze from ..expr import Tuple, TupleWrapper @@ -475,3 +475,40 @@ def prod(data, axis=None, keepdims=False, exclude=False): """ axis = [axis] if isinstance(axis, int) else axis return _make.prod(data, axis, keepdims, exclude) + + +def logsumexp(data, axis=None, keepdims=False): + """Compute the log of the sum of exponentials of input elements over given axes. + + This function is more numerically stable than log(sum(exp(input))). + It avoids overflows caused by taking the exp of large inputs and underflows + caused by taking the log of small inputs. + + Parameters + ---------- + data : relay.Expr + The input data + + axis : None or int or tuple of int + Axis or axes along which a standard deviation operation is performed. + The default, axis=None, will compute the log of the sum of exponentials of all elements + in the input array. If axis is negative it counts from the last to the first axis. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + + Returns + ------- + result : relay.Expr + The computed result. 
+ """ + + axis = [axis] if isinstance(axis, int) else axis + max_x = max(data, axis, True) + exp_x = exp(data - max_x) + sum_x = sum(exp_x, axis, True) + out_x = log(sum_x) + max_x + if not keepdims: + out_x = squeeze(out_x, axis) + return out_x diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 9189b5edff83..83e4e40b53b9 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -136,8 +136,32 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.cuda.conv2d_nhwc), wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc), name="conv2d_nhwc.cuda") - N, _, _, _ = get_const_tuple(data.shape) - _, _, CI, CO = get_const_tuple(kernel.shape) + N, H, W, _ = get_const_tuple(data.shape) + KH, KW, CI, CO = get_const_tuple(kernel.shape) + # Winograd shape related judgment + judge_winograd_tensorcore, judge_winograd_shape = winograd_judge(N, H, W, KH, KW, + CI, CO, padding, + stride_h, stride_w, + dilation_h, dilation_w, + pre_flag=False) + if judge_winograd_shape: + if target.target_name == "cuda" and \ + nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \ + judge_winograd_tensorcore: + strategy.add_implementation( + wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_tensorcore), + wrap_topi_schedule( + topi.cuda.schedule_conv2d_nhwc_winograd_tensorcore), + name="conv2d_nhwc_winograd_tensorcore.cuda", + plevel=5) + else: + strategy.add_implementation( + wrap_compute_conv2d( + topi.cuda.conv2d_nhwc_winograd_direct), + wrap_topi_schedule( + topi.cuda.schedule_conv2d_nhwc_winograd_direct), + name="conv2d_nhwc_winograd_direct.cuda", + plevel=5) if target.target_name == "cuda": if nvcc.have_tensorcore(tvm.gpu(0).compute_version): if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \ @@ -220,6 +244,9 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty dilation = attrs.get_int_tuple("dilation") groups = attrs.get_int("groups") layout = attrs.data_layout + data, kernel = inputs + stride_h, stride_w = attrs.get_int_tuple("strides") + padding = attrs.get_int_tuple("padding") assert dilation == (1, 1), "Do not support dilate now" assert groups == 1, "Do not supoort arbitrary group number" strategy = _op.OpStrategy() @@ -229,6 +256,30 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty wrap_topi_schedule( topi.cuda.schedule_conv2d_nchw_winograd_without_weight_transform), name="conv2d_nchw_winograd_without_weight_transform.cuda") + elif layout == "NHWC": + N, H, W, _ = get_const_tuple(data.shape) + alpha, _, CI, CO = get_const_tuple(kernel.shape) + dilation_h, dilation_w = dilation + judge_winograd_tensorcore, _ = winograd_judge(N, H, W, alpha, alpha, CI, CO, + padding, stride_h, stride_w, + dilation_h, dilation_w, + pre_flag=True) + if target.target_name == "cuda" and \ + nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \ + judge_winograd_tensorcore: + strategy.add_implementation( + wrap_compute_conv2d( + topi.cuda.conv2d_nhwc_winograd_tensorcore_without_weight_transform), + wrap_topi_schedule( + topi.cuda.schedule_conv2d_nhwc_winograd_tensorcore_without_weight_transform), + name="conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda") + else: + strategy.add_implementation( + wrap_compute_conv2d( + topi.cuda.conv2d_nhwc_winograd_direct_without_weight_transform), + wrap_topi_schedule( + topi.cuda.schedule_conv2d_nhwc_winograd_direct_without_weight_transform), + name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda") 
else: raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}". format(layout)) @@ -516,3 +567,26 @@ def proposal_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_proposal), name="proposal.cuda") return strategy + +def winograd_judge(N, H, W, KH, KW, CI, CO, padding, stride_h, + stride_w, dilation_h, dilation_w, pre_flag): + """Winograd judgement about tensorcore and shape""" + if H % 8 == 0: + tile_size = 4 + else: + tile_size = 2 + if pre_flag: + alpha = KH + KH = KW = alpha + 1 - tile_size + pt, pl, pb, pr = topi.nn.get_pad_tuple(padding, (KH, KW)) + OH = (H + pt + pb - KH) // stride_h + 1 + OW = (W + pl + pr - KW) // stride_w + 1 + nH, nW = (OH + tile_size - 1) // tile_size, (OW + tile_size - 1) // tile_size + P = N * nH * nW + judge_winograd_tensorcore = (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \ + (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \ + (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0) + judge_winograd_shape = 2 < KH < 8 and 2 < KW < 8 and KH == KW and \ + stride_h == 1 and stride_w == 1 and \ + dilation_h == 1 and dilation_w == 1 + return judge_winograd_tensorcore, judge_winograd_shape diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 5c1baef4db94..5a3106d1e787 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -18,7 +18,7 @@ """QNN dialect operators.""" from __future__ import absolute_import as _abs -from tvm.relay.expr import Tuple +from tvm.relay.expr import Tuple, TupleWrapper from tvm.relay.op.nn.util import get_pad_tuple2d from . import _make @@ -156,7 +156,7 @@ def concatenate(data, Parameters ---------- - data : Union(List[relay.Expr], Tuple[relay.Expr]) + data : Union(List[relay.Expr], Tuple[relay.Expr], TupleWrapper[relay.Expr]) The list of quantized tensors. input_scales : List[relay.Expr] @@ -180,15 +180,16 @@ def concatenate(data, The concatenated quantized tensor. 
""" - data = list(data) - if not data: - raise ValueError("relay.concatenate requires data to be non-empty.") + if isinstance(data, (list, tuple)): + data = Tuple(data) + elif isinstance(data, TupleWrapper): + data = data.tuple_value if not isinstance(axis, int): raise ValueError("For now, we only support integer axis") input_scales = list(input_scales) input_zero_points = list(input_zero_points) - return _make.concatenate(Tuple(data), + return _make.concatenate(data, Tuple(input_scales), Tuple(input_zero_points), output_scale, diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py index 1a231eb1aaed..dc7937c0b346 100644 --- a/python/tvm/relay/testing/tf.py +++ b/python/tvm/relay/testing/tf.py @@ -183,11 +183,16 @@ def get_workload_official(model_url, model_sub_path): model_path = download_testdata(model_url, model_tar_name, module=['tf', 'official']) dir_path = os.path.dirname(model_path) - import tarfile if model_path.endswith("tgz") or model_path.endswith("gz"): + import tarfile tar = tarfile.open(model_path) tar.extractall(path=dir_path) tar.close() + elif model_path.endswith("zip"): + import zipfile + zip_object = zipfile.ZipFile(model_path) + zip_object.extractall(path=dir_path) + zip_object.close() else: raise RuntimeError('Could not decompress the file: ' + model_path) return os.path.join(dir_path, model_sub_path) diff --git a/python/tvm/rpc/__init__.py b/python/tvm/rpc/__init__.py index 5f959eb44745..b64ba33d9e09 100644 --- a/python/tvm/rpc/__init__.py +++ b/python/tvm/rpc/__init__.py @@ -26,4 +26,6 @@ """ from .server import Server -from .client import RPCSession, LocalSession, TrackerSession, connect, connect_tracker +from .client import connect, connect_tracker +from .client import RPCSession, LocalSession, PopenSession, TrackerSession +from .minrpc import with_minrpc diff --git a/python/tvm/rpc/_ffi_api.py b/python/tvm/rpc/_ffi_api.py new file mode 100644 index 000000000000..1a7cc739b5c1 --- /dev/null +++ b/python/tvm/rpc/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""FFI APIs for tvm.rpc""" +import tvm._ffi + + +tvm._ffi._init_api("rpc", __name__) diff --git a/python/tvm/rpc/base.py b/python/tvm/rpc/base.py index bc81534a12d9..f0e33f8503f2 100644 --- a/python/tvm/rpc/base.py +++ b/python/tvm/rpc/base.py @@ -17,8 +17,6 @@ """Base definitions for RPC.""" # pylint: disable=invalid-name -from __future__ import absolute_import - import socket import time import json @@ -26,7 +24,6 @@ import struct import random import logging -import tvm._ffi from .._ffi.base import py_str @@ -176,7 +173,3 @@ def connect_with_retry(addr, timeout=60, retry_period=5): logger.warning("Cannot connect to tracker %s, retry in %g secs...", str(addr), retry_period) time.sleep(retry_period) - - -# Still use tvm.rpc for the foreign functions -tvm._ffi._init_api("tvm.rpc", "tvm.rpc.base") diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index ed57e0d4276d..9997673f52c2 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -15,19 +15,20 @@ # specific language governing permissions and limitations # under the License. """RPC client tools""" -from __future__ import absolute_import - import os +import stat import socket import struct import time + import tvm._ffi from tvm.contrib import util from tvm._ffi.base import TVMError from tvm.runtime import ndarray as nd -from tvm.runtime import load_module as _load_module from . import base +from . import server +from . import _ffi_api class RPCSession(object): @@ -38,9 +39,23 @@ class RPCSession(object): # pylint: disable=invalid-name def __init__(self, sess): self._sess = sess - self._tbl_index = base._SessTableIndex(sess) + self._tbl_index = _ffi_api.SessTableIndex(sess) self._remote_funcs = {} + def system_lib(self): + """Get system-wide library module. + + Returns + ------- + module : runtime.Module + The system-wide library module. + + See Also + -------- + tvm.runtime.system_lib + """ + return self.get_function("runtime.SystemLib")() + def get_function(self, name): """Get function from the session. @@ -145,7 +160,7 @@ def load_module(self, path): m : Module The remote module containing remote function. """ - return base._LoadRemoteModule(self._sess, path) + return _ffi_api.LoadRemoteModule(self._sess, path) def cpu(self, dev_id=0): """Construct CPU device.""" @@ -183,28 +198,41 @@ class LocalSession(RPCSession): need to be ran both locally and remotely. 
""" def __init__(self): - # pylint: disable=super-init-not-called - self.context = nd.context - self.get_function = tvm._ffi.get_global_func - self._temp = util.tempdir() + self._temp = server._server_env([]) + RPCSession.__init__(self, _ffi_api.LocalSession()) - def upload(self, data, target=None): - if isinstance(data, bytearray): - if not target: - raise ValueError("target must present when file is a bytearray") - blob = data - else: - blob = bytearray(open(data, "rb").read()) - if not target: - target = os.path.basename(data) - with open(self._temp.relpath(target), "wb") as f: - f.write(blob) - def download(self, path): - return bytearray(open(self._temp.relpath(path), "rb").read()) +@tvm._ffi.register_func("rpc.PopenSession") +def _popen_session(binary): + temp = util.tempdir() - def load_module(self, path): - return _load_module(self._temp.relpath(path)) + if isinstance(binary, (bytes, bytearray)): + path_exec = temp.relpath("server.minrpc") + with open(path_exec, "wb") as outfile: + outfile.write(binary) + os.chmod(path_exec, stat.S_IXUSR | stat.S_IRUSR) + path_exec = os.path.abspath(path_exec) + else: + path_exec = os.path.abspath(binary) + if not os.path.isfile(path_exec): + raise RuntimeError(f"{path_exec} does not exist.") + if not os.access(path_exec, os.X_OK): + raise RuntimeError(f"{path_exec} is not executable.") + + sess = _ffi_api.CreatePipeClient(path_exec) + return sess + + +class PopenSession(RPCSession): + """RPCSession interface backed by popen. + + Parameters + ---------- + binary : List[Union[str, bytes]] + The binary to be executed. + """ + def __init__(self, binary): + RPCSession.__init__(self, _popen_session(binary)) class TrackerSession(object): @@ -378,7 +406,7 @@ def request_and_run(self, key, max_retry, str(last_err))) -def connect(url, port, key="", session_timeout=0): +def connect(url, port, key="", session_timeout=0, session_constructor_args=None): """Connect to RPC Server Parameters @@ -397,15 +425,43 @@ def connect(url, port, key="", session_timeout=0): the connection when duration is longer than this value. When duration is zero, it means the request must always be kept alive. + session_constructor_args: List + List of additional arguments to passed as the remote session constructor. + The first element of the list is always a string specifying the name of + the session constructor, the following args are the positional args to that function. + Returns ------- sess : RPCSession The connected session. + + Examples + -------- + Normal usage + .. code-block:: python + + client = rpc.connect(server_url, server_port, server_key) + + Session_constructor can be used to customize the session in the remote + The following code connects to a remote internal server via a proxy + by constructing another RPCClientSession on the proxy machine and use that + as the serving session of the proxy endpoint. + + .. 
code-block:: python + + client_via_proxy = rpc.connect( + proxy_server_url, proxy_server_port, proxy_server_key, + session_constructor_args=[ + "rpc.Connect", internal_url, internal_port, internal_key]) + """ try: if session_timeout: key += " -timeout=%s" % str(session_timeout) - sess = base._Connect(url, port, key) + session_constructor_args = session_constructor_args if session_constructor_args else [] + if not isinstance(session_constructor_args, (list, tuple)): + raise TypeError("Expect the session constructor to be a list or tuple") + sess = _ffi_api.Connect(url, port, key, *session_constructor_args) except NameError: raise RuntimeError("Please compile with USE_RPC=1") return RPCSession(sess) diff --git a/python/tvm/rpc/minrpc.py b/python/tvm/rpc/minrpc.py new file mode 100644 index 000000000000..760c5362f11d --- /dev/null +++ b/python/tvm/rpc/minrpc.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utils to path.""" +import os +from tvm._ffi import libinfo +from tvm.contrib import cc + + +def find_minrpc_server_libpath(server="posix_popen_server"): + """Get the path of minrpc server libary. + + Parameters + ---------- + server : str + The kind of built in minrpc server. + + Returns + ------- + path : str + The path to the min server library. + """ + curr_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) + source_dir = os.path.abspath(os.path.join(curr_dir, "..", "..", "..")) + + path = os.path.join( + source_dir, "src", "runtime", "rpc", "minrpc", ("%s.cc" % server)) + + candidates = [path] + if not os.path.isfile(path): + raise RuntimeError("Cannot find minserver %s, in candidates %s" % (server, candidates)) + return path + + +def with_minrpc(compile_func, + server="posix_popen_server", + runtime="libtvm"): + """Attach the compiler function with minrpc related options. + + Parameters + ---------- + compile_func : Union[str, Callable[[str, str, Optional[str]], None]] + The compilation function to decorate. + + server : str + The server type. + + runtime : str + The runtime library. + + Returns + ------- + fcompile : function + The return compilation. + """ + server_path = find_minrpc_server_libpath(server) + runtime_path = libinfo.find_lib_path( + [runtime, runtime + ".so", runtime + ".dylib"])[0] + + runtime_dir = os.path.abspath(os.path.dirname(runtime_path)) + options = ["-std=c++14"] + # Make sure the rpath to the libtvm is set so we can do local tests. + # Note that however, this approach won't work on remote. + # Always recommend to to link statically. 
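
The with_minrpc wrapper defined here pairs with the new PopenSession: the exported artifact embeds the min RPC server sources and is driven over stdin/stdout pipes. A hedged usage sketch (the compile callback and the output path are placeholders, not names introduced by this patch; the module has to be built with --system-lib, which is what the need_system_lib flag enforces):

    from tvm import rpc

    # fcompile = rpc.with_minrpc(compile_func)            # compile_func: a cc-style callback of your choice
    # lib.export_library("/tmp/server.minrpc", fcompile)  # produces the server artifact

    sess = rpc.PopenSession("/tmp/server.minrpc")  # spawn the binary, talk over pipes
    syslib = sess.system_lib()                     # the --system-lib module bundled into it
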
+ options += ["-Wl,-rpath=" + runtime_dir] + options += ["-I" + path for path in libinfo.find_include_path()] + fcompile = cc.cross_compiler( + compile_func, + options=options, + add_files=[server_path, runtime_path]) + fcompile.__name__ = "with_minrpc" + fcompile.need_system_lib = True + return fcompile diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index c3a3647948ee..03746dad6d62 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -42,6 +42,7 @@ raise ImportError( "RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg) +from . import _ffi_api from . import base from .base import TrackerCode from .server import _server_env @@ -549,7 +550,7 @@ def _fsend(data): data = bytes(data) conn.write_message(data, binary=True) return len(data) - on_message = base._CreateEventDrivenServer( + on_message = _ffi_api.CreateEventDrivenServer( _fsend, "WebSocketProxyServer", "%toinit") return on_message diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 03749c1c17e4..15a3c7de789d 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -43,6 +43,7 @@ from tvm._ffi.libinfo import find_lib_path from tvm.runtime.module import load_module as _load_module from tvm.contrib import util +from . import _ffi_api from . import base from . base import TrackerCode @@ -56,7 +57,7 @@ def _server_env(load_library, work_path=None): temp = util.tempdir() # pylint: disable=unused-variable - @tvm._ffi.register_func("tvm.rpc.server.workpath") + @tvm._ffi.register_func("tvm.rpc.server.workpath", override=True) def get_workpath(path): return temp.relpath(path) @@ -81,7 +82,7 @@ def _serve_loop(sock, addr, load_library, work_path=None): """Server loop""" sockfd = sock.fileno() temp = _server_env(load_library, work_path) - base._ServerLoop(sockfd) + _ffi_api.ServerLoop(sockfd) if not work_path: temp.remove() logger.info("Finish serving %s", addr) @@ -330,7 +331,7 @@ def __init__(self, utvm_dev_config_args=None, ): try: - if base._ServerLoop is None: + if _ffi_api.ServerLoop is None: raise RuntimeError("Please compile with USE_RPC=1") except NameError: raise RuntimeError("Please compile with USE_RPC=1") diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 716f87f33fc1..b580e3f6dc6d 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -244,6 +244,7 @@ def _dso_exportable(self): def export_library(self, file_name, fcompile=None, + addons=None, **kwargs): """Export the module and its imported device code one library. 
@@ -283,7 +284,7 @@ def export_library(self, modules = self._collect_dso_modules() temp = _util.tempdir() - files = [] + files = addons if addons else [] is_system_lib = False has_c_module = False llvm_target_triple = None @@ -313,6 +314,9 @@ def export_library(self, if llvm_target_triple is None and hasattr(fcompile, "get_target_triple"): llvm_target_triple = fcompile.get_target_triple() + if getattr(fcompile, "need_system_lib", False) and not is_system_lib: + raise ValueError("%s need --system-lib option" % str(fcompile)) + if self.imported_modules: if enabled("llvm") and llvm_target_triple: path_obj = temp.relpath("devc.o") diff --git a/rust/runtime/src/workspace.rs b/rust/runtime/src/workspace.rs index 8344dfbb1adf..65ad25324cae 100644 --- a/rust/runtime/src/workspace.rs +++ b/rust/runtime/src/workspace.rs @@ -64,7 +64,7 @@ impl WorkspacePool { .iter() .fold(None, |cur_ws_idx: Option, &idx| { let ws_size = self.workspaces[idx].size(); - if !ws_size >= size { + if ws_size < size { return cur_ws_idx; } cur_ws_idx.or(Some(idx)).and_then(|cur_idx| { @@ -92,9 +92,8 @@ impl WorkspacePool { break; } } - if let Some(ws_idx) = ws_idx { - self.free.push(ws_idx); - } + let ws_idx = ws_idx.ok_or_else(|| format_err!("Invalid pointer"))?; + self.free.push(ws_idx); Ok(()) } } @@ -135,6 +134,5 @@ pub extern "C" fn TVMBackendFreeWorkspace( Ok(()) => 0, Err(_) => -1, }) as c_int - }); - 0 + }) } diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 2bb0189890a6..a10db7a92020 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -870,6 +870,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { lhs->upper_factor != SplitExprNode::kPosInf) { auto updated = ToSplitExpr(this->VisitExpr(ModImpl( lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode))); + updated.CopyOnWrite()->scale = lhs->scale; // re-apply the lower_factor if (lhs->lower_factor != 1) { return SplitDivConst(updated, lhs->lower_factor, div_mode); diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index bb7c3dde17e0..4437225ffce4 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -147,17 +147,16 @@ class ConstIntBoundAnalyzer::Impl : } } if (bound_) { - const PrimExprNode* op = expr.as(); - auto val = bound_->find(op); + auto val = bound_->find(expr); if (val != bound_->end()) { - auto everything = Everything(op->dtype); + auto everything = Everything(expr->dtype); CHECK( (val->second->min_value == res.min_value && val->second->max_value == res.max_value) || (val->second->min_value == everything.min_value && val->second->max_value == everything.max_value)) << "Detected bound for " << expr << "conflicts with memorization"; } - (*bound_)[op] = ConstIntBound(res.min_value, res.max_value); + (*bound_)[expr] = ConstIntBound(res.min_value, res.max_value); } return res; } @@ -369,7 +368,7 @@ class ConstIntBoundAnalyzer::Impl : // additional bound info std::vector additional_info_; // look up table for memorization - std::unordered_map* bound_{nullptr}; + BoundMapType* bound_{nullptr}; // constants: the limit value means umlimited // NOTE: kNegInf/kPosInf are used to represent infinity. 
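
The Rust change earlier in this hunk fixes the free-workspace search: the old guard "!ws_size >= size" applied the bitwise NOT to the size instead of negating the comparison, so adequate buffers could be skipped; the corrected guard only rejects workspaces smaller than the request. A Python restatement of the intended selection rule (assumed from the surrounding fold, which prefers the smaller adequate buffer):

    def pick_workspace(free_sizes, size):
        # Return the index of the smallest free workspace holding at least `size` bytes.
        fits = [i for i, s in enumerate(free_sizes) if s >= size]
        return min(fits, key=lambda i: free_sizes[i]) if fits else None

    assert pick_workspace([256, 64, 128], 100) == 2   # 128 is the tightest adequate fit
    assert pick_workspace([16, 32], 100) is None      # caller falls back to a fresh allocation
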
static const constexpr int64_t kNegInf = ConstIntBound::kNegInf; @@ -563,7 +562,7 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) { } ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, - std::unordered_map* bound) { + BoundMapType* bound) { impl_->bound_ = bound; Entry ret = impl_->VisitExpr(expr); impl_->bound_ = nullptr; diff --git a/src/ir/module.cc b/src/ir/module.cc index 6262150556c7..1be58f3caded 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -40,7 +40,7 @@ namespace tvm { IRModule::IRModule(tvm::Map functions, tvm::Map type_definitions, - std::unordered_set import_set) { + std::unordered_set import_set) { auto n = make_object(); n->functions = std::move(functions); n->type_definitions = std::move(type_definitions); @@ -111,15 +111,15 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const { reduce_temp(); } -bool IRModuleNode::ContainGlobalVar(const std::string& name) const { +bool IRModuleNode::ContainGlobalVar(const String& name) const { return global_var_map_.find(name) != global_var_map_.end(); } -bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const { +bool IRModuleNode::ContainGlobalTypeVar(const String& name) const { return global_type_var_map_.find(name) != global_type_var_map_.end(); } -GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const { +GlobalVar IRModuleNode::GetGlobalVar(const String& name) const { auto it = global_var_map_.find(name); if (it == global_var_map_.end()) { std::ostringstream msg; @@ -146,7 +146,7 @@ tvm::Array IRModuleNode::GetGlobalVars() const { return tvm::Array(global_vars); } -GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const { +GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const String& name) const { CHECK(global_type_var_map_.defined()); auto it = global_type_var_map_.find(name); CHECK(it != global_type_var_map_.end()) @@ -154,7 +154,7 @@ GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const { return (*it).second; } -Constructor IRModuleNode::GetConstructor(const std::string& adt, const std::string& cons) const { +Constructor IRModuleNode::GetConstructor(const String& adt, const String& cons) const { TypeData typeDef = this->LookupTypeDef(adt); for (Constructor c : typeDef->constructors) { if (cons.compare(c->name_hint) == 0) { @@ -315,7 +315,7 @@ BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { return (*it).second; } -BaseFunc IRModuleNode::Lookup(const std::string& name) const { +BaseFunc IRModuleNode::Lookup(const String& name) const { GlobalVar id = this->GetGlobalVar(name); return this->Lookup(id); } @@ -327,7 +327,7 @@ TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const { return (*it).second; } -TypeData IRModuleNode::LookupTypeDef(const std::string& name) const { +TypeData IRModuleNode::LookupTypeDef(const String& name) const { GlobalTypeVar id = this->GetGlobalTypeVar(name); return this->LookupTypeDef(id); } @@ -379,7 +379,7 @@ IRModule IRModule::FromExpr( return mod; } -void IRModuleNode::Import(const std::string& path) { +void IRModuleNode::Import(const String& path) { if (this->import_set_.count(path) == 0) { this->import_set_.insert(path); DLOG(INFO) << "Importing: " << path; @@ -392,18 +392,18 @@ void IRModuleNode::Import(const std::string& path) { } } -void IRModuleNode::ImportFromStd(const std::string& path) { +void IRModuleNode::ImportFromStd(const String& path) { auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path"); CHECK(f != nullptr) << "The Relay 
std_path is not set, please register tvm.relay.std_path."; std::string std_path = (*f)(); - return this->Import(std_path + "/" + path); + this->Import(std_path + "/" + path.operator std::string()); } -std::unordered_set IRModuleNode::Imports() const { +std::unordered_set IRModuleNode::Imports() const { return this->import_set_; } -IRModule IRModule::FromText(const std::string& text, const std::string& source_path) { +IRModule IRModule::FromText(const String& text, const String& source_path) { auto* f = tvm::runtime::Registry::Get("relay.fromtext"); CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; IRModule mod = (*f)(text, source_path); @@ -467,7 +467,7 @@ TVM_REGISTER_GLOBAL("ir.Module_Lookup") }); TVM_REGISTER_GLOBAL("ir.Module_Lookup_str") -.set_body_typed([](IRModule mod, std::string var) { +.set_body_typed([](IRModule mod, String var) { return mod->Lookup(var); }); @@ -477,7 +477,7 @@ TVM_REGISTER_GLOBAL("ir.Module_LookupDef") }); TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str") -.set_body_typed([](IRModule mod, std::string var) { +.set_body_typed([](IRModule mod, String var) { return mod->LookupTypeDef(var); }); @@ -499,12 +499,12 @@ TVM_REGISTER_GLOBAL("ir.Module_Update") }); TVM_REGISTER_GLOBAL("ir.Module_Import") -.set_body_typed([](IRModule mod, std::string path) { +.set_body_typed([](IRModule mod, String path) { mod->Import(path); }); TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd") -.set_body_typed([](IRModule mod, std::string path) { +.set_body_typed([](IRModule mod, String path) { mod->ImportFromStd(path); });; diff --git a/src/ir/op.cc b/src/ir/op.cc index b024165c1a4c..bd8a6e22f70e 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -148,7 +148,10 @@ TVM_REGISTER_GLOBAL("relay.op._ListOpNames") return ret; }); -TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed(Op::Get); +TVM_REGISTER_GLOBAL("relay.op._GetOp") +.set_body_typed([](std::string name) -> Op { + return Op::Get(name); +}); TVM_REGISTER_GLOBAL("relay.op._OpGetAttr") .set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index bda997a59d4d..2e675c8ed8f4 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -918,22 +918,28 @@ static const char* kSemVer = "v0.0.4"; // - Implements AsText // - relay_text_printer.cc (specific printing logics for relay) // - tir_text_printer.cc (specific printing logics for TIR) -std::string PrettyPrint(const ObjectRef& node) { +String PrettyPrint(const ObjectRef& node) { Doc doc; doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node); return doc.str(); } -std::string AsText(const ObjectRef& node, +String AsText(const ObjectRef& node, bool show_meta_data, - runtime::TypedPackedFunc annotate) { + runtime::TypedPackedFunc annotate) { Doc doc; doc << kSemVer << Doc::NewLine(); - doc << relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node); + runtime::TypedPackedFunc ftyped = nullptr; + if (annotate != nullptr) { + ftyped = runtime::TypedPackedFunc( + [&annotate](const ObjectRef& expr) -> std::string { + return annotate(expr); + }); + } + doc << relay::RelayTextPrinter(show_meta_data, ftyped).PrintFinal(node); return doc.str(); } - TVM_REGISTER_GLOBAL("ir.PrettyPrint") .set_body_typed(PrettyPrint); diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 3c9b07712f0c..33a9235bd495 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -523,7 +523,6 @@ centered at that value (zero padding is added where 
necessary). .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .add_type_rel("Identity", IdentityRel); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 7aa8bf1863a1..6d78ba8d8e8b 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -829,9 +829,13 @@ bool TakeRel(const Array& types, // `types` contains: [data, indices, result] CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); - CHECK(data != nullptr); + if (data == nullptr) { + return false; + } const auto* indices = types[1].as(); - CHECK(indices != nullptr); + if (indices == nullptr) { + return false; + } CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer"; const auto param = attrs.as(); CHECK(param != nullptr); @@ -2325,7 +2329,12 @@ bool LayoutTransformRel(const Array& types, const Attrs& attrs, const TypeReporter& reporter) { const auto* data = types[0].as(); - CHECK(data != nullptr); + if (data == nullptr) { + CHECK(types[0].as()) + << "LayoutTransform: expect input data type to be TensorType but get " + << types[0]; + return false; + } const LayoutTransformAttrs* params = attrs.as(); Layout src_layout(params->src_layout); @@ -2333,7 +2342,6 @@ bool LayoutTransformRel(const Array& types, CHECK(src_layout.defined() && dst_layout.defined()) << "cannot convert from/to undefined layout"; - auto layout_converter = tir::BijectiveLayout(src_layout, dst_layout); CHECK(layout_converter.defined()) << "cannot convert from " << params->src_layout << " to " << params->dst_layout; diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 152c7693fd81..2d397aba333e 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -380,7 +380,9 @@ bool ShapeOfRel(const Array& types, const TypeReporter& reporter) { CHECK_EQ(num_inputs, 1); auto tt = types[0].as(); - CHECK(tt != nullptr); + if (tt == nullptr) { + return false; + } const auto* param = attrs.as(); CHECK(param != nullptr); auto rank_shape = RankShape(tt->shape); diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc index 650dcb962d44..338e7a1ff6ad 100644 --- a/src/relay/qnn/op/concatenate.cc +++ b/src/relay/qnn/op/concatenate.cc @@ -149,8 +149,16 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array& new_args, // If the output qnn params do not match the input qnn params, we can call requantize on the input // expr first, followed by a concatenate on the requantized input exprs. - auto tuple_data = data.as(); - CHECK(tuple_data != nullptr); + Array tuple_exprs; + if (data->IsInstance()) { + tuple_exprs = data.as()->fields; + } else if (data->IsInstance()) { // if the data is a CallNode, use TupleGetItems + auto call = Downcast(data); + for (size_t i = 0; i < tuple_type->fields.size(); i++) { + tuple_exprs.push_back(TupleGetItem(call, i)); + } + } + CHECK(!tuple_exprs.empty()); auto tuple_input_scales = input_scales.as(); CHECK(tuple_input_scales != nullptr); @@ -160,7 +168,7 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array& new_args, int idx = 0; Array requantized_exprs; - for (auto quantized_expr : tuple_data->fields) { + for (auto quantized_expr : tuple_exprs) { // Get the input scale for the idx quantized input tensor. 
auto input_scale = tuple_input_scales->fields[idx]; diff --git a/src/relay/quantize/quantize.cc b/src/relay/quantize/quantize.cc index 631d8c0fdf58..431e18b95356 100644 --- a/src/relay/quantize/quantize.cc +++ b/src/relay/quantize/quantize.cc @@ -135,7 +135,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); TVM_REGISTER_GLOBAL("relay._quantize._GetCurrentQConfig") -.set_body_typed(QConfig::Current); +.set_body_typed([]() -> QConfig { + return QConfig::Current(); +}); TVM_REGISTER_GLOBAL("relay._quantize._EnterQConfigScope") .set_body_typed(QConfig::EnterQConfigScope); diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index fb1f74da2103..32b3381eeb2e 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -460,6 +460,7 @@ int TVMFuncCall(TVMFunctionHandle func, TVMValue* ret_val, int* ret_type_code) { API_BEGIN(); + TVMRetValue rv; (*static_cast(func)).CallPacked( TVMArgs(args, arg_type_codes, num_args), &rv); @@ -585,6 +586,42 @@ int TVMCbArgToReturn(TVMValue* value, int* code) { API_END(); } + +int TVMDeviceAllocDataSpace(DLContext ctx, + size_t nbytes, + size_t alignment, + DLDataType type_hint, + void** out_data) { + API_BEGIN(); + out_data[0] = DeviceAPIManager::Get(ctx)->AllocDataSpace( + ctx, nbytes, alignment, type_hint); + API_END(); +} + +int TVMDeviceFreeDataSpace(DLContext ctx, void* ptr) { + API_BEGIN(); + DeviceAPIManager::Get(ctx)->FreeDataSpace(ctx, ptr); + API_END(); +} + +int TVMDeviceCopyDataFromTo(const void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t num_bytes, + TVMContext ctx_from, + TVMContext ctx_to, + DLDataType type_hint, + TVMStreamHandle stream) { + API_BEGIN(); + TVMContext ctx = ctx_from.device_type != kDLCPU ? ctx_from : ctx_to; + DeviceAPIManager::Get(ctx)->CopyDataFromTo( + from, from_offset, + to, to_offset, + num_bytes, ctx_from, ctx_to, type_hint, stream); + API_END(); +} + // set device api TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) .set_body([](TVMArgs args, TVMRetValue *ret) { diff --git a/src/runtime/module.cc b/src/runtime/module.cc index d2ed7ff9e2b7..813a79d43c06 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -36,7 +36,7 @@ void ModuleNode::Import(Module other) { if (!std::strcmp(this->type_key(), "rpc")) { static const PackedFunc* fimport_ = nullptr; if (fimport_ == nullptr) { - fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule"); + fimport_ = runtime::Registry::Get("rpc.ImportRemoteModule"); CHECK(fimport_ != nullptr); } (*fimport_)(GetRef(this), other); diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index 4717d89e33c1..855a342a7e97 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -37,7 +37,7 @@ struct Registry::Manager { // map storing the functions. // We delibrately used raw pointer // This is because PackedFunc can contain callbacks into the host languge(python) - // and the resource can become invalid because of indeterminstic order of destruction. + // and the resource can become invalid because of indeterminstic order of destruction and forking. // The resources will only be recycled during program exit. 
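
Registry::Register just below now rebuilds the map entry whenever a name is re-registered with can_override set, instead of handing back the previously registered object. From Python this is the override flag of register_func, which the updated RPC server code also relies on; a small sketch (the "demo.hello" name is made up):

    import tvm

    @tvm._ffi.register_func("demo.hello")
    def hello(name):
        return "hello, " + name

    # Re-registering the same global name requires override=True;
    # without it the registry raises "already registered".
    @tvm._ffi.register_func("demo.hello", override=True)
    def hello_v2(name):
        return "hi, " + name

    assert tvm.get_global_func("demo.hello")("tvm") == "hi, tvm"
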
std::unordered_map fmap; // mutex @@ -60,20 +60,18 @@ Registry& Registry::set_body(PackedFunc f) { // NOLINT(*) return *this; } -Registry& Registry::Register(const std::string& name, bool override) { // NOLINT(*) +Registry& Registry::Register(const std::string& name, bool can_override) { // NOLINT(*) Manager* m = Manager::Global(); std::lock_guard lock(m->mutex); - auto it = m->fmap.find(name); - if (it == m->fmap.end()) { - Registry* r = new Registry(); - r->name_ = name; - m->fmap[name] = r; - return *r; - } else { - CHECK(override) - << "Global PackedFunc " << name << " is already registered"; - return *it->second; + if (m->fmap.count(name)) { + CHECK(can_override) + << "Global PackedFunc " << name << " is already registered"; } + + Registry* r = new Registry(); + r->name_ = name; + m->fmap[name] = r; + return *r; } bool Registry::Remove(const std::string& name) { diff --git a/src/runtime/rpc/minrpc/minrpc_server.h b/src/runtime/rpc/minrpc/minrpc_server.h new file mode 100644 index 000000000000..a84042e2ef73 --- /dev/null +++ b/src/runtime/rpc/minrpc/minrpc_server.h @@ -0,0 +1,608 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file minrpc_server.h + * \brief Minimum RPC server implementation, + * redirects all the calls to C runtime API. + * + * \note This file do not depend on c++ std or c std, + * and only depends on TVM's C runtime API. + */ +#ifndef TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_ +#define TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_ + +#include +#include +#include "../rpc_protocol.h" +#include "../../../support/arena.h" + +/*! \brief Whether or not to enable glog style DLOG */ +#ifndef TVM_MINRPC_ENABLE_LOGGING +#define TVM_MINRPC_ENABLE_LOGGING 0 +#endif + +#ifndef MINRPC_CHECK +#define MINRPC_CHECK(cond) \ + if (!(cond)) this->ThrowError(RPCServerStatus::kCheckError); +#endif + +#if TVM_MINRPC_ENABLE_LOGGING +#include +#endif + + +namespace tvm { +namespace runtime { + +/*! + * \brief A minimum RPC server that only depends on the tvm C runtime.. + * + * All the dependencies are provided by the io arguments. + * + * \tparam TIOHandler IO provider to provide io handling. + * An IOHandler needs to provide the following functions: + * - PosixWrite, PosixRead, Close: posix style, read, write, close API. + * - Exit: exit with status code. + */ +template +class MinRPCServer { + public: + /*! + * \brief Constructor. + * \param io The IO handler. + */ + explicit MinRPCServer(TIOHandler io) + : io_(io), arena_(PageAllocator(io)) {} + + /*! \brief Run the server loop until shutdown signal is received. 
*/ + void ServerLoop() { + RPCCode code; + uint64_t packet_len; + + while (true) { + arena_.RecycleAll(); + allow_clean_shutdown_ = true; + + this->Read(&packet_len); + if (packet_len == 0) continue; + this->Read(&code); + + allow_clean_shutdown_ = false; + + if (code >= RPCCode::kSyscallCodeStart) { + this->HandleSyscallFunc(code); + } else { + switch (code) { + case RPCCode::kCallFunc: { + HandleNormalCallFunc(); + break; + } + case RPCCode::kInitServer: { + HandleInitServer(); + break; + } + case RPCCode::kCopyFromRemote: { + HandleCopyFromRemote(); + break; + } + case RPCCode::kCopyToRemote: { + HandleCopyToRemote(); + break; + } + case RPCCode::kShutdown: { + this->Shutdown(); + return; + } + default: { + this->ThrowError(RPCServerStatus::kUnknownRPCCode); + break; + } + } + } + } + } + + void Shutdown() { + arena_.FreeAll(); + io_.Close(); + } + + void HandleNormalCallFunc() { + uint64_t call_handle; + TVMValue* values; + int* tcodes; + int num_args; + TVMValue ret_value[3]; + int ret_tcode[3]; + + this->Read(&call_handle); + RecvPackedSeq(&values, &tcodes, &num_args); + + int call_ecode = TVMFuncCall( + reinterpret_cast(call_handle), + values, tcodes, num_args, + &(ret_value[1]), &(ret_tcode[1])); + + if (call_ecode == 0) { + // Return value encoding as in LocalSession + int rv_tcode = ret_tcode[1]; + ret_tcode[0] = kDLInt; + ret_value[0].v_int64 = rv_tcode; + if (rv_tcode == kTVMNDArrayHandle) { + ret_tcode[1] = kTVMDLTensorHandle; + ret_value[2].v_handle = ret_value[1].v_handle; + ret_tcode[2] = kTVMOpaqueHandle; + this->ReturnPackedSeq(ret_value, ret_tcode, 3); + } else if (rv_tcode == kTVMPackedFuncHandle || + rv_tcode == kTVMModuleHandle) { + ret_tcode[1] = kTVMOpaqueHandle; + this->ReturnPackedSeq(ret_value, ret_tcode, 2); + } else { + this->ReturnPackedSeq(ret_value, ret_tcode, 2); + } + } else { + this->ReturnLastTVMError(); + } + } + + void HandleCopyFromRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + + uint8_t* data_ptr; + int call_ecode = 0; + if (ctx.device_type == kDLCPU) { + data_ptr = reinterpret_cast(handle) + offset; + } else { + data_ptr = this->ArenaAlloc(num_bytes); + call_ecode = TVMDeviceCopyDataFromTo( + reinterpret_cast(handle), offset, + data_ptr, 0, num_bytes, + ctx, DLContext{kDLCPU, 0}, + type_hint, nullptr); + // need sync to make sure that the copy is completed. 
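
Every message in this min RPC server is framed as a little-endian 64-bit packet length followed by the RPCCode and the payload, which is what the Write calls in HandleCopyFromRemote assemble. A hedged Python sketch of that copy-acknowledgement frame (the 4-byte code width and the numeric kCopyAck value are assumptions made for illustration; the real values come from rpc_protocol.h):

    import struct

    def pack_copy_ack(payload, copy_ack_code=6):
        # u64 packet length (code + payload), then the code, then the raw bytes.
        code = struct.pack("<i", copy_ack_code)
        packet_nbytes = len(code) + len(payload)
        return struct.pack("<Q", packet_nbytes) + code + payload

    frame = pack_copy_ack(b"\x00" * 16)
    assert len(frame) == 8 + 4 + 16   # length prefix + code + payload
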
+ if (call_ecode == 0) { + call_ecode = TVMSynchronize( + ctx.device_type, ctx.device_id, nullptr); + } + } + + if (call_ecode == 0) { + RPCCode code = RPCCode::kCopyAck; + uint64_t packet_nbytes = sizeof(code) + num_bytes; + + this->Write(packet_nbytes); + this->Write(code); + this->WriteArray(data_ptr, num_bytes); + } else { + this->ReturnLastTVMError(); + } + } + + void HandleCopyToRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + int call_ecode = 0; + + if (ctx.device_type == kDLCPU) { + uint8_t* dptr = reinterpret_cast(handle) + offset; + this->ReadArray(dptr, num_bytes); + } else { + uint8_t* temp_data = this->ArenaAlloc(num_bytes); + this->ReadArray(temp_data, num_bytes); + + call_ecode = TVMDeviceCopyDataFromTo( + temp_data, 0, + reinterpret_cast(handle), offset, + num_bytes, + DLContext{kDLCPU, 0}, ctx, + type_hint, nullptr); + // need sync to make sure that the copy is completed. + if (call_ecode == 0) { + call_ecode = TVMSynchronize( + ctx.device_type, ctx.device_id, nullptr); + } + } + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void HandleSyscallFunc(RPCCode code) { + TVMValue* values; + int* tcodes; + int num_args; + RecvPackedSeq(&values, &tcodes, &num_args); + switch (code) { + case RPCCode::kFreeHandle: { + this->SyscallFreeHandle(values, tcodes, num_args); + break; + } + case RPCCode::kGetGlobalFunc: { + this->SyscallGetGlobalFunc(values, tcodes, num_args); + break; + } + case RPCCode::kDevSetDevice: { + this->ReturnException("SetDevice not supported"); + break; + } + case RPCCode::kDevGetAttr: { + this->ReturnException("GetAttr not supported"); + break; + } + case RPCCode::kDevAllocData: { + this->SyscallDevAllocData(values, tcodes, num_args); + break; + } + case RPCCode::kDevFreeData: { + this->SyscallDevFreeData(values, tcodes, num_args); + break; + } + case RPCCode::kDevStreamSync: { + this->SyscallDevStreamSync(values, tcodes, num_args); + break; + } + case RPCCode::kCopyAmongRemote: { + this->SyscallCopyAmongRemote(values, tcodes, num_args); + break; + } + default: { + this->ReturnException("Syscall not recognized"); + break; + } + } + } + + void HandleInitServer() { + uint64_t len; + this->Read(&len); + char* proto_ver = this->ArenaAlloc(len + 1); + this->ReadArray(proto_ver, len); + + TVMValue* values; + int* tcodes; + int num_args; + RecvPackedSeq(&values, &tcodes, &num_args); + MINRPC_CHECK(num_args == 0); + this->ReturnVoid(); + } + + void SyscallFreeHandle(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kTVMOpaqueHandle); + MINRPC_CHECK(tcodes[1] == kDLInt); + + void* handle = values[0].v_handle; + int64_t type_code = values[1].v_int64; + int call_ecode; + + if (type_code == kTVMNDArrayHandle) { + call_ecode = TVMArrayFree(static_cast(handle)); + } else if (type_code == kTVMPackedFuncHandle) { + call_ecode = TVMFuncFree(handle); + } else { + MINRPC_CHECK(type_code == kTVMModuleHandle); + call_ecode = TVMModFree(handle); + } + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallGetGlobalFunc(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 1); + MINRPC_CHECK(tcodes[0] == kTVMStr); + + void* handle; + int call_ecode = TVMFuncGetGlobal(values[0].v_str, &handle); + + if (call_ecode == 0) { + 
this->ReturnHandle(handle); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallCopyAmongRemote(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 9); + // from, from_offset + MINRPC_CHECK(tcodes[0] == kTVMOpaqueHandle); + MINRPC_CHECK(tcodes[1] == kDLInt); + // to, to_offset + MINRPC_CHECK(tcodes[2] == kTVMOpaqueHandle); + MINRPC_CHECK(tcodes[3] == kDLInt); + // size + MINRPC_CHECK(tcodes[4] == kDLInt); + // ctx_from, ctx_to + MINRPC_CHECK(tcodes[5] == kTVMContext); + MINRPC_CHECK(tcodes[6] == kTVMContext); + // type_hint, stream + MINRPC_CHECK(tcodes[7] == kTVMDataType); + MINRPC_CHECK(tcodes[8] == kTVMOpaqueHandle); + + void* from = values[0].v_handle; + int64_t from_offset = values[1].v_int64; + void* to = values[2].v_handle; + int64_t to_offset = values[3].v_int64; + int64_t size = values[4].v_int64; + TVMContext ctx_from = values[5].v_ctx; + TVMContext ctx_to = values[6].v_ctx; + DLDataType type_hint = values[7].v_type; + TVMStreamHandle stream = values[8].v_handle; + + int call_ecode = TVMDeviceCopyDataFromTo( + from, from_offset, + to, to_offset, size, + ctx_from, ctx_to, type_hint, stream); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallDevAllocData(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 4); + MINRPC_CHECK(tcodes[0] == kTVMContext); + MINRPC_CHECK(tcodes[1] == kDLInt); + MINRPC_CHECK(tcodes[2] == kDLInt); + MINRPC_CHECK(tcodes[3] == kTVMDataType); + + TVMContext ctx = values[0].v_ctx; + int64_t nbytes = values[1].v_int64; + int64_t alignment = values[2].v_int64; + DLDataType type_hint = values[3].v_type; + + void* handle; + int call_ecode = TVMDeviceAllocDataSpace( + ctx, nbytes, alignment, type_hint, &handle); + + if (call_ecode == 0) { + this->ReturnHandle(handle); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallDevFreeData(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kTVMContext); + MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); + + TVMContext ctx = values[0].v_ctx; + void* handle = values[1].v_handle; + + int call_ecode = TVMDeviceFreeDataSpace(ctx, handle); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallDevStreamSync(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kTVMContext); + MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); + + TVMContext ctx = values[0].v_ctx; + void* handle = values[1].v_handle; + + int call_ecode = TVMSynchronize(ctx.device_type, ctx.device_id, handle); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { + io_.Exit(static_cast(code)); + } + + template + T* ArenaAlloc(int count) { + static_assert(std::is_pod::value, "need to be trival"); + return arena_.template allocate_(count); + } + + template + void Read(T* data) { + static_assert(std::is_pod::value, "need to be trival"); + this->ReadRawBytes(data, sizeof(T)); + } + + template + void ReadArray(T* data, size_t count) { + static_assert(std::is_pod::value, "need to be trival"); + return this->ReadRawBytes(data, sizeof(T) * count); + } + + template + void Write(const T& data) { + static_assert(std::is_pod::value, "need to be trival"); + return this->WriteRawBytes(&data, sizeof(T)); + } + + template + void WriteArray(T* data, size_t count) { + 
static_assert(std::is_pod::value, "need to be trival"); + return this->WriteRawBytes(data, sizeof(T) * count); + } + + private: + // Internal allocator that redirects alloc to TVM's C API. + class PageAllocator { + public: + using ArenaPageHeader = tvm::support::ArenaPageHeader; + + explicit PageAllocator(TIOHandler io) + : io_(io) {} + + ArenaPageHeader* allocate(size_t min_size) { + size_t npages = ((min_size + kPageSize - 1) / kPageSize); + void* data; + + if (TVMDeviceAllocDataSpace( + DLContext{kDLCPU, 0}, npages * kPageSize, kPageAlign, + DLDataType{kDLInt, 1, 1}, &data) != 0) { + io_.Exit(static_cast(RPCServerStatus::kAllocError)); + } + + ArenaPageHeader* header = static_cast(data); + header->size = npages * kPageSize; + header->offset = sizeof(ArenaPageHeader); + return header; + } + + void deallocate(ArenaPageHeader* page) { + if (TVMDeviceFreeDataSpace(DLContext{kDLCPU, 0}, page) != 0) { + io_.Exit(static_cast(RPCServerStatus::kAllocError)); + } + } + + static const constexpr int kPageSize = 2 << 10; + static const constexpr int kPageAlign = 8; + + private: + TIOHandler io_; + }; + + void RecvPackedSeq(TVMValue** out_values, + int** out_tcodes, + int* out_num_args) { + RPCReference::RecvPackedSeq( + out_values, out_tcodes, out_num_args, this); + } + + void ReturnVoid() { + int32_t num_args = 1; + int32_t tcode = kTVMNullptr; + RPCCode code = RPCCode::kReturn; + + uint64_t packet_nbytes = + sizeof(code) + sizeof(num_args) + sizeof(tcode); + + this->Write(packet_nbytes); + this->Write(code); + this->Write(num_args); + this->Write(tcode); + } + + void ReturnHandle(void* handle) { + int32_t num_args = 1; + int32_t tcode = kTVMOpaqueHandle; + RPCCode code = RPCCode::kReturn; + uint64_t encode_handle = reinterpret_cast(handle); + + uint64_t packet_nbytes = + sizeof(code) + sizeof(num_args) + + sizeof(tcode) + sizeof(encode_handle); + + this->Write(packet_nbytes); + this->Write(code); + this->Write(num_args); + this->Write(tcode); + this->Write(encode_handle); + } + + void ReturnException(const char* msg) { + RPCReference::ReturnException(msg, this); + } + + void ReturnPackedSeq(const TVMValue* arg_values, + const int* type_codes, + int num_args) { + RPCReference::ReturnPackedSeq(arg_values, type_codes, num_args, this); + } + + void ReturnLastTVMError() { + this->ReturnException(TVMGetLastError()); + } + + void ReadRawBytes(void* data, size_t size) { + uint8_t* buf = reinterpret_cast(data); + size_t ndone = 0; + while (ndone < size) { + ssize_t ret = io_.PosixRead(buf, size - ndone); + if (ret == 0) { + if (allow_clean_shutdown_) { + this->Shutdown(); + io_.Exit(0); + } else { + this->ThrowError(RPCServerStatus::kReadError); + } + } + if (ret == -1) { + this->ThrowError(RPCServerStatus::kReadError); + } + ndone += ret; + buf += ret; + } + } + + void WriteRawBytes(const void* data, size_t size) { + const uint8_t *buf = reinterpret_cast(data); + size_t ndone = 0; + while (ndone < size) { + ssize_t ret = io_.PosixWrite(buf, size - ndone); + if (ret == 0 || ret == -1) { + this->ThrowError(RPCServerStatus::kWriteError); + } + buf += ret; + ndone += ret; + } + } + + /*! \brief IO handler. */ + TIOHandler io_; + /*! \brief internal arena. */ + support::GenericArena arena_; + /*! \brief Whether we are in a state that allows clean shutdown. 
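   *
   * When set, an EOF from the reader is treated as a normal disconnect
   * (Shutdown() followed by Exit(0)) rather than a read error.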
*/ + bool allow_clean_shutdown_{true}; + static_assert(DMLC_LITTLE_ENDIAN, "MinRPC only works on little endian."); +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_ diff --git a/src/runtime/rpc/minrpc/posix_popen_server.cc b/src/runtime/rpc/minrpc/posix_popen_server.cc new file mode 100644 index 000000000000..fdc57112f0b9 --- /dev/null +++ b/src/runtime/rpc/minrpc/posix_popen_server.cc @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// Disable constructor to bring minimum dep on c++ABI. +#define TVM_ARENA_HAS_DESTRUCTOR 0 + +#include +#include +#include "minrpc_server.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief IOHandler based on posix API. + */ +class PosixIOHandler { + public: + explicit PosixIOHandler(int read_fd = 0, int write_fd = 1) + : read_fd_(read_fd), write_fd_(write_fd) { + } + + ssize_t PosixRead(void* data, size_t size) { + return read(read_fd_, data, size); + } + + ssize_t PosixWrite(const void* data, size_t size) { + return write(write_fd_, data, size); + } + + void Exit(int code) { + exit(code); + } + + void Close() { + if (read_fd_ != 0) close(read_fd_); + if (write_fd_ != 0) close(write_fd_); + } + + private: + int read_fd_{0}; + int write_fd_{1}; +}; + +/*! \brief Type for the posix version of min rpc server. */ +using PosixMinRPCServer = MinRPCServer; + +} // namespace runtime +} // namespace tvm + +int main(int argc, char* argv[]) { + if (argc != 3) return -1; + // pass the descriptor via arguments. + tvm::runtime::PosixIOHandler handler(atoi(argv[1]), atoi(argv[2])); + tvm::runtime::PosixMinRPCServer server(handler); + server.ServerLoop(); + return 0; +} diff --git a/src/runtime/rpc/rpc_channel.cc b/src/runtime/rpc/rpc_channel.cc new file mode 100644 index 000000000000..f8dc6e636324 --- /dev/null +++ b/src/runtime/rpc/rpc_channel.cc @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file rpc_channel.cc + */ +#include +#include "rpc_channel.h" + +namespace tvm { +namespace runtime { + +size_t CallbackChannel::Send(const void* data, size_t size) { + TVMByteArray bytes; + bytes.data = static_cast(data); + bytes.size = size; + int64_t n = fsend_(bytes); + if (n == -1) { + LOG(FATAL) << "CallbackChannel::Send"; + } + return static_cast(n); +} + +size_t CallbackChannel::Recv(void* data, size_t size) { + TVMRetValue ret = frecv_(size); + + if (ret.type_code() != kTVMBytes) { + LOG(FATAL) << "CallbackChannel::Recv"; + } + std::string* bytes = ret.ptr(); + memcpy(static_cast(data), bytes->c_str(), bytes->length()); + return bytes->length(); +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/rpc/rpc_channel.h b/src/runtime/rpc/rpc_channel.h new file mode 100644 index 000000000000..be34a8b50440 --- /dev/null +++ b/src/runtime/rpc/rpc_channel.h @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_channel.h + * \brief Communication endpoints to connect local and remote RPC sessions. + */ +#ifndef TVM_RUNTIME_RPC_RPC_CHANNEL_H_ +#define TVM_RUNTIME_RPC_RPC_CHANNEL_H_ + +#include +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief Abstract channel interface used to create RPCEndpoint. + */ +class RPCChannel { + public: + /*! \brief virtual destructor */ + virtual ~RPCChannel() {} + /*! + * \brief Send data over to the channel. + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes sent. + */ + virtual size_t Send(const void* data, size_t size) = 0; + /*! + * \brief Recv data from channel. + * + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes received. + */ + virtual size_t Recv(void* data, size_t size) = 0; +}; + +/*! + * \brief RPC channel which callback + * frontend (Python/Java/etc.)'s send & recv function + */ +class CallbackChannel final : public RPCChannel { + public: + /*! + * \brief Constructor. + * + * \param fsend The send function, takes in a TVMByteArray and returns the + * number of bytes sent in that array. Returns -1 if error happens. + * \param frecv The recv function, takes an expected maximum size, and return + * a byte array with the actual amount of data received. + */ + explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv) + : fsend_(std::move(fsend)), frecv_(std::move(frecv)) {} + + ~CallbackChannel() {} + /*! + * \brief Send data over to the channel. + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes sent. + */ + size_t Send(const void* data, size_t size) final; + /*! + * \brief Recv data from channel. + * + * \param data The data pointer. + * \param size The size fo the data. 
+ * \return The actual bytes received. + */ + size_t Recv(void* data, size_t size) final; + + private: + PackedFunc fsend_; + PackedFunc frecv_; +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_CHANNEL_H_ diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index 9fd45acd14bf..93af4e268ba6 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include "rpc_session.h" namespace tvm { @@ -31,20 +32,24 @@ namespace runtime { class RPCDeviceAPI final : public DeviceAPI { public: void SetDevice(TVMContext ctx) final { - GetSess(ctx)->CallRemote( - RPCCode::kDevSetDevice, ctx); + auto remote_ctx = RemoveSessMask(ctx); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->SetDevice(remote_ctx); } + void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { - *rv = GetSess(ctx)->CallRemote( - RPCCode::kDevGetAttr, ctx, static_cast(kind)); + auto remote_ctx = RemoveSessMask(ctx); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->GetAttr(remote_ctx, kind, rv); } + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { auto sess = GetSess(ctx); - void *data = sess->CallRemote( - RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint); + auto remote_ctx = RemoveSessMask(ctx); + void *data = sess->GetDeviceAPI(remote_ctx)->AllocDataSpace( + remote_ctx, nbytes, alignment, type_hint); + RemoteSpace* space = new RemoteSpace(); space->data = data; space->sess = std::move(sess); @@ -52,9 +57,10 @@ class RPCDeviceAPI final : public DeviceAPI { } void FreeDataSpace(TVMContext ctx, void* ptr) final { RemoteSpace* space = static_cast(ptr); + auto remote_ctx = RemoveSessMask(ctx); try { - GetSess(ctx)->CallRemote( - RPCCode::kDevFreeData, ctx, space->data); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->FreeDataSpace( + remote_ctx, space->data); } catch (const dmlc::Error& e) { // fault tolerance to remote close. 
} @@ -75,29 +81,35 @@ class RPCDeviceAPI final : public DeviceAPI { to_dev_type > kRPCSessMask) { CHECK(ctx_from.device_type == ctx_to.device_type) << "Cannot copy across two different remote session"; - GetSess(ctx_from)->CallRemote( - RPCCode::kCopyAmongRemote, - static_cast(from)->data, from_offset, - static_cast(to)->data, to_offset, - size, ctx_from, ctx_to, type_hint, stream); + auto remote_ctx_from = RemoveSessMask(ctx_from); + auto remote_ctx_to = RemoveSessMask(ctx_to); + auto remote_ctx = remote_ctx_from; + if (remote_ctx.device_type == kDLCPU) remote_ctx = remote_ctx_to; + GetSess(ctx_from)->GetDeviceAPI(remote_ctx) + ->CopyDataFromTo(static_cast(from)->data, from_offset, + static_cast(to)->data, to_offset, + size, remote_ctx_from, remote_ctx_to, type_hint, stream); } else if (from_dev_type > kRPCSessMask && to_dev_type == kDLCPU) { + auto remote_ctx_from = RemoveSessMask(ctx_from); GetSess(ctx_from)->CopyFromRemote( static_cast(from)->data, from_offset, - to, to_offset, size, ctx_from, type_hint); + to, to_offset, size, remote_ctx_from, type_hint); } else if (from_dev_type == kDLCPU && to_dev_type > kRPCSessMask) { + auto remote_ctx_to = RemoveSessMask(ctx_to); GetSess(ctx_to)->CopyToRemote( - (void*)from, from_offset, // NOLINT(*) + const_cast(from), from_offset, static_cast(to)->data, to_offset, - size, ctx_to, type_hint); + size, remote_ctx_to, type_hint); } else { LOG(FATAL) << "expect copy from/to remote or between remote"; } } + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { - GetSess(ctx)->CallRemote( - RPCCode::kDevStreamSync, ctx, stream); + auto remote_ctx = RemoveSessMask(ctx); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->StreamSync(remote_ctx, stream); } private: @@ -107,6 +119,11 @@ class RPCDeviceAPI final : public DeviceAPI { int tbl_index = dev_type / kRPCSessMask - 1; return RPCSession::Get(tbl_index); } + + static TVMContext RemoveSessMask(TVMContext ctx) { + ctx.device_type = static_cast(ctx.device_type % kRPCSessMask); + return ctx; + } }; TVM_REGISTER_GLOBAL("device_api.rpc") diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc new file mode 100644 index 000000000000..8a7f11cbdea6 --- /dev/null +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -0,0 +1,1064 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_session.cc + * \brief RPC session for remote function call. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "rpc_endpoint.h" +#include "rpc_local_session.h" +#include "../object_internal.h" +#include "../../support/ring_buffer.h" +#include "../../support/arena.h" + +namespace tvm { +namespace runtime { + +/*! 
+ * Event-driven state-machine based handlers for RPCEndpoint. + * + * Key functions: + * + * - SendPackedSeq: send the arguments over to the peer + * - HandleNextEvent: handle the next request from the peer(RPCCode followed by per code protocol). + */ +class RPCEndpoint::EventHandler : public dmlc::Stream { + public: + EventHandler(support::RingBuffer* reader, + support::RingBuffer* writer, + std::string name, + std::string* remote_key) + : reader_(reader), + writer_(writer), + name_(name), + remote_key_(remote_key) { + this->Clear(); + + if (*remote_key == "%toinit") { + state_ = kInitHeader; + remote_key_->resize(0); + pending_request_bytes_ = sizeof(int32_t); + } + } + + /*! + * \brief Bytes needed to fulfill current request + */ + size_t BytesNeeded() const { + if (reader_->bytes_available() < pending_request_bytes_) { + return pending_request_bytes_ - reader_->bytes_available(); + } else { + return 0; + } + } + + /*! + * \brief Request number of bytes from the reader. + * \param nbytes The number of bytes + */ + void RequestBytes(size_t nbytes) { + pending_request_bytes_ += nbytes; + reader_->Reserve(pending_request_bytes_); + } + + /*! \return Whether we are ready to handle next request. */ + bool Ready() const { + return reader_->bytes_available() >= pending_request_bytes_; + } + + /*! \return Whether we can perform a clean shutdown */ + bool CanCleanShutdown() const { + return state_ == kRecvPacketNumBytes; + } + + /*! \brief Finish the copy ack stage. */ + void FinishCopyAck() { + this->SwitchToState(kRecvPacketNumBytes); + } + + /*! + * \brief Enter the io loop until the next event. + * \param client_mode Whether we are in the client. + * \param setreturn The function to set the return value encoding. + * \return The function to set return values when there is a return event. + */ + RPCCode HandleNextEvent(bool client_mode, RPCSession::FEncodeReturn setreturn) { + std::swap(client_mode_, client_mode); + + while (this->Ready()) { + switch (state_) { + case kInitHeader: HandleInitHeader(); break; + case kRecvPacketNumBytes: { + uint64_t packet_nbytes; + CHECK(this->Read(&packet_nbytes)); + if (packet_nbytes != 0) { + this->SwitchToState(kProcessPacket); + this->RequestBytes(packet_nbytes); + } else { + this->SwitchToState(kRecvPacketNumBytes); + } + break; + } + case kProcessPacket: { + this->HandleProcessPacket(setreturn); + break; + } + case kReturnReceived: { + this->SwitchToState(kRecvPacketNumBytes); + std::swap(client_mode_, client_mode); + return RPCCode::kReturn; + } + case kCopyAckReceived: { + std::swap(client_mode_, client_mode); + return RPCCode::kCopyAck; + } + case kShutdownReceived: { + std::swap(client_mode_, client_mode); + return RPCCode::kShutdown; + } + } + } + std::swap(client_mode_, client_mode); + return RPCCode::kNone; + } + + /*! \brief Clear all the states in the Handler.*/ + void Clear() { + state_ = kRecvPacketNumBytes; + pending_request_bytes_ = sizeof(uint64_t); + } + + /*! + * \brief Validate that the arguments can be sent through RPC. + * \param arg_values The argument values. + * \param type_codes The type codes. 
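   * \param num_args The number of arguments.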
+ */ + void ValidateArguments(const TVMValue* arg_values, + const int* type_codes, + int num_args) { + TVMArgs args(arg_values, type_codes, num_args); + for (int i = 0; i < num_args; ++i) { + int tcode = type_codes[i]; + if (tcode == kTVMObjectHandle || tcode == kTVMObjectRValueRefArg) { + LOG(FATAL) << "ValueError: Cannot pass argument " << i + << ", type " << args[i].AsObjectRef()->GetTypeKey() + << " is not supported by RPC"; + } else if (tcode == kTVMContext) { + DLContext ctx = args[i]; + CHECK_LT(static_cast(ctx.device_type), kRPCSessMask) + << "InternalError: cannot pass RPC context in the channel"; + } + } + } + + void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { + LOG(FATAL) << "RPCServerError:" << RPCServerStatusToString(code); + } + + uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, + const int* type_codes, + int num_args, + bool client_mode) { + return RPCReference::PackedSeqGetNumBytes( + arg_values, type_codes, num_args, client_mode, this); + } + + void SendPackedSeq(const TVMValue* arg_values, + const int* type_codes, + int num_args, + bool client_mode) { + RPCReference::SendPackedSeq( + arg_values, type_codes, num_args, client_mode, this); + } + + // Endian aware IO handling + using Stream::Read; + using Stream::Write; + using Stream::ReadArray; + using Stream::WriteArray; + + bool Read(RPCCode* code) { + int32_t cdata; + if (!this->Read(&cdata)) return false; + *code = static_cast(cdata); + return true; + } + void Write(RPCCode code) { + int32_t cdata = static_cast(code); + this->Write(cdata); + } + + template + T* ArenaAlloc(int count) { + static_assert(std::is_pod::value, "need to be trival"); + return arena_.template allocate_(count); + } + + protected: + enum State { + kInitHeader, + kRecvPacketNumBytes, + kProcessPacket, + kReturnReceived, + kCopyAckReceived, + kShutdownReceived + }; + // Current state; + State state_; + // Initialize remote header + bool init_header_step_{0}; + // Whether current handler is client or server mode. + bool client_mode_{false}; + // Internal arena + support::Arena arena_; + + // State switcher + void SwitchToState(State state) { + // invariant + if (state != kCopyAckReceived) { + CHECK_EQ(pending_request_bytes_, 0U) + << "state=" << state; + } + state_ = state; + CHECK(state != kInitHeader) + << "cannot switch to init header"; + if (state == kRecvPacketNumBytes) { + this->RequestBytes(sizeof(uint64_t)); + // recycle arena for the next session. + arena_.RecycleAll(); + } + } + + // handler for initial header read + void HandleInitHeader() { + if (init_header_step_ == 0) { + int32_t len; + this->Read(&len); + remote_key_->resize(len); + init_header_step_ = 1; + this->RequestBytes(len); + return; + } else { + CHECK_EQ(init_header_step_, 1); + this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length()); + this->SwitchToState(kRecvPacketNumBytes); + } + } + + // Handler for read code. 
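  // Every message handled here is framed on the wire as a length-prefixed
  // packet:
  //
  //   uint64_t packet_nbytes;  // number of bytes that follow
  //   int32_t  code;           // RPCCode of the request or reply
  //   ...                      // per-code payload
  //
  // Illustrative sketch of how a peer emits such a frame for a payload-free
  // message, mirroring what RPCEndpoint::Shutdown() does (`handler` stands
  // for this EventHandler instance):
  //
  // \code
  //   RPCCode code = RPCCode::kShutdown;
  //   uint64_t packet_nbytes = sizeof(code);
  //   handler->Write(packet_nbytes);
  //   handler->Write(code);
  // \endcode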
+ void HandleProcessPacket(RPCSession::FEncodeReturn setreturn) { + RPCCode code = RPCCode::kNone; + this->Read(&code); + + if (code >= RPCCode::kSyscallCodeStart) { + this->HandleSyscall(code); + } else { + switch (code) { + case RPCCode::kInitServer: { + this->HandleInitServer(); + break; + } + case RPCCode::kCallFunc: { + this->HandleNormalCallFunc(); + break; + } + case RPCCode::kCopyFromRemote: { + this->HandleCopyFromRemote(); + break; + } + case RPCCode::kCopyToRemote: { + this->HandleCopyToRemote(); + break; + } + case RPCCode::kException: + case RPCCode::kReturn: { + this->HandleReturn(code, setreturn); + break; + } + case RPCCode::kCopyAck: { + this->SwitchToState(kCopyAckReceived); + break; + } + case RPCCode::kShutdown: { + this->SwitchToState(kShutdownReceived); + break; + } + default: LOG(FATAL) << "Unknown event " << static_cast(code); + } + } + } + + /*! + * \brief Recive incoming packed seq from the stream. + * \return The received argments. + * \note The TVMArgs is available until we switchstate. + */ + TVMArgs RecvPackedSeq() { + TVMValue* values; + int* tcodes; + int num_args; + RPCReference::RecvPackedSeq(&values, &tcodes, &num_args, this); + return TVMArgs(values, tcodes, num_args); + } + + /*! + * \brief Return exception to the remote. + * \param err_msg The error message. + */ + void ReturnException(const char* err_msg) { + RPCReference::ReturnException(err_msg, this); + } + + /*! + * \brief Return nullptr to the remote. + * \param err_msg The error message. + */ + void ReturnVoid() { + RPCReference::ReturnVoid(this); + } + + /*! + * \brief Return a packed sequence to the remote. + * \param args The arguments. + */ + void ReturnPackedSeq(TVMArgs args) { + RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.size(), this); + } + + /*! + * \brief Handle the case when return/exception value is received. + * \param code The RPC code. + * \param setreturn The function to encode return. + */ + void HandleReturn(RPCCode code, RPCSession::FEncodeReturn setreturn) { + TVMArgs args = RecvPackedSeq(); + + if (code == RPCCode::kException) { + // switch to the state before sending exception. + this->SwitchToState(kRecvPacketNumBytes); + std::string msg = args[0]; + LOG(FATAL) << "RPCError: Error caught from RPC call:\n" << msg; + } + + CHECK(setreturn != nullptr) << "fsetreturn not available"; + setreturn(args); + + this->SwitchToState(kReturnReceived); + } + + void HandleSyscall(RPCCode code); + + void HandleCopyFromRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; + + char* data_ptr; + auto* sess = GetServingSession(); + + // When session is local, we can directly treat handle + // as the cpu pointer without allocating a temp space. 
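      // Otherwise the bytes are staged in an arena scratch buffer by the
      // serving session's CopyFromRemote, byte-swapped if an endian
      // conversion is needed, and then streamed back to the peer.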
+ if (ctx.device_type == kDLCPU && + sess->IsLocalSession() && + DMLC_IO_NO_ENDIAN_SWAP) { + data_ptr = reinterpret_cast(handle) + offset; + } else { + try { + data_ptr = this->ArenaAlloc(num_bytes); + sess->CopyFromRemote( + reinterpret_cast(handle), offset, + data_ptr, 0, + num_bytes, ctx, type_hint); + // endian aware handling + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(data_ptr, elem_bytes, num_bytes / elem_bytes); + } + } catch (const std::runtime_error &e) { + this->ReturnException(e.what()); + this->SwitchToState(kRecvPacketNumBytes); + return; + } + } + RPCCode code = RPCCode::kCopyAck; + uint64_t packet_nbytes = sizeof(code) + num_bytes; + + // Return Copy Ack + this->Write(packet_nbytes); + this->Write(code); + this->WriteArray(data_ptr, num_bytes); + + this->SwitchToState(kRecvPacketNumBytes); + } + + void HandleCopyToRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + + size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; + auto* sess = GetServingSession(); + + // When session is local, we can directly treat handle + // as the cpu pointer without allocating a temp space. + if (ctx.device_type == kDLCPU && sess->IsLocalSession()) { + char* dptr = reinterpret_cast(handle) + offset; + this->ReadArray(dptr, num_bytes); + + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(dptr, elem_bytes, num_bytes / elem_bytes); + } + } else { + char* temp_data = this->ArenaAlloc(num_bytes); + this->ReadArray(temp_data, num_bytes); + + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(temp_data, elem_bytes, num_bytes / elem_bytes); + } + + try { + sess->CopyToRemote( + temp_data, 0, + reinterpret_cast(handle), offset, + num_bytes, ctx, type_hint); + } catch (const std::runtime_error &e) { + this->ReturnException(e.what()); + this->SwitchToState(kRecvPacketNumBytes); + return; + } + } + + this->ReturnVoid(); + this->SwitchToState(kRecvPacketNumBytes); + } + + // Handle for packed call. 
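  // Wire layout of a kCallFunc request: the uint64 handle of the remote
  // PackedFunc followed by a packed argument sequence; the reply is either a
  // kReturn packet with the encoded result or a kException packet carrying
  // the error message.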
+ void HandleNormalCallFunc() { + uint64_t call_handle; + + this->Read(&call_handle); + TVMArgs args = RecvPackedSeq(); + + try { + GetServingSession()->CallFunc( + reinterpret_cast(call_handle), + args.values, args.type_codes, args.size(), + [this](TVMArgs ret) { this->ReturnPackedSeq(ret); }); + } catch (const std::runtime_error& e) { + this->ReturnException(e.what()); + } + + this->SwitchToState(kRecvPacketNumBytes); + } + + void HandleInitServer() { + std::string client_protocol_ver; + + uint64_t len; + this->Read(&len); + client_protocol_ver.resize(len); + this->Read(dmlc::BeginPtr(client_protocol_ver), len); + + TVMArgs args = RecvPackedSeq(); + + try { + CHECK(serving_session_ == nullptr) + << "Server has already been initialized"; + + std::string server_protocol_ver = kRPCProtocolVer; + CHECK_EQ(client_protocol_ver, server_protocol_ver) + << "Server[" << name_ << "]: Client protocol version mismatch with the server " + << " server protocol=" << server_protocol_ver + << ", client protocol=" << client_protocol_ver; + + if (args.size() == 0) { + serving_session_ = std::make_shared(); + } else { + std::string constructor_name = args[0]; + auto* fconstructor = Registry::Get(constructor_name); + CHECK(fconstructor != nullptr) + << " Cannot find session constructor " << constructor_name; + TVMRetValue con_ret; + + try { + fconstructor->CallPacked( + TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), &con_ret); + } catch (const dmlc::Error& e) { + LOG(FATAL) << "Server[" << name_ << "]:" + << " Error caught from session constructor " << constructor_name + << ":\n" << e.what(); + } + + CHECK_EQ(con_ret.type_code(), kTVMModuleHandle) + << "Server[" << name_ << "]:" + << " Constructor " << constructor_name + << " need to return an RPCModule"; + Module mod = con_ret; + std::string tkey = mod->type_key(); + CHECK_EQ(tkey, "rpc") + << "Constructor " << constructor_name << " to return an RPCModule"; + serving_session_ = RPCModuleGetSession(mod); + } + + this->ReturnVoid(); + } catch (const std::runtime_error &e) { + this->ReturnException(e.what()); + } + + this->SwitchToState(kRecvPacketNumBytes); + } + + // Handler for special syscalls that have a specific RPCCode. + template + void SysCallHandler(F f) { + TVMArgs args = RecvPackedSeq(); + try { + TVMRetValue rv; + f(GetServingSession(), args, &rv); + TVMValue ret_value; + int ret_tcode; + TVMArgsSetter setter(&ret_value, &ret_tcode); + setter(0, rv); + + this->ReturnPackedSeq(TVMArgs(&ret_value, &ret_tcode, 1)); + } catch (const std::runtime_error& e) { + this->ReturnException(e.what()); + } + this->SwitchToState(kRecvPacketNumBytes); + } + + private: + RPCSession* GetServingSession() const { + CHECK(serving_session_ != nullptr) + << "Need to call InitRemoteSession first before any further actions"; + return serving_session_.get(); + } + // Utility functions + // Internal read function, update pending_request_bytes_ + size_t Read(void* data, size_t size) final { + CHECK_LE(size, pending_request_bytes_); + reader_->Read(data, size); + pending_request_bytes_ -= size; + return size; + } + // wriite the data to the channel. + void Write(const void* data, size_t size) final { + writer_->Write(data, size); + } + // Number of pending bytes requests + size_t pending_request_bytes_{0}; + // The ring buffer to read data from. + support::RingBuffer* reader_; + // The ringr buffer to write reply to. + support::RingBuffer* writer_; + // The session used to serve the RPC requests. + std::shared_ptr serving_session_; + // Name of endpoint. 
+ std::string name_; + // remote key + std::string* remote_key_; +}; + +RPCCode RPCEndpoint::HandleUntilReturnEvent( + bool client_mode, RPCSession::FEncodeReturn setreturn) { + RPCCode code = RPCCode::kCallFunc; + while (code != RPCCode::kReturn && + code != RPCCode::kShutdown && + code != RPCCode::kCopyAck) { + while (writer_.bytes_available() != 0) { + writer_.ReadWithCallback([this](const void *data, size_t size) { + return channel_->Send(data, size); + }, writer_.bytes_available()); + } + size_t bytes_needed = handler_->BytesNeeded(); + if (bytes_needed != 0) { + size_t n = reader_.WriteWithCallback([this](void* data, size_t size) { + return channel_->Recv(data, size); + }, bytes_needed); + if (n == 0) { + if (handler_->CanCleanShutdown()) { + return RPCCode::kShutdown; + } else { + LOG(FATAL) << "Channel closes before we get neded bytes"; + } + } + } + code = handler_->HandleNextEvent(client_mode, setreturn); + } + return code; +} + +void RPCEndpoint::Init() { + // Event handler + handler_ = std::make_shared( + &reader_, &writer_, name_, &remote_key_); + // Quick function to for syscall remote. + syscall_remote_ = PackedFunc([this](TVMArgs all_args, TVMRetValue* rv) { + std::lock_guard lock(mutex_); + RPCCode code = static_cast(all_args[0].operator int()); + TVMArgs args(all_args.values + 1, all_args.type_codes +1, all_args.num_args -1); + + uint64_t packet_nbytes = + sizeof(code) + + handler_->PackedSeqGetNumBytes( + args.values, args.type_codes, args.num_args, true); + + // All packet begins with packet nbytes + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); + + code = HandleUntilReturnEvent(true, [rv](TVMArgs args) { + CHECK_EQ(args.size(), 1); + *rv = args[0]; + }); + CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); + }); +} + +std::shared_ptr RPCEndpoint::Create( + std::unique_ptr channel, + std::string name, + std::string remote_key) { + std::shared_ptr endpt = std::make_shared(); + endpt->channel_ = std::move(channel); + endpt->name_ = std::move(name); + endpt->remote_key_ = std::move(remote_key); + endpt->Init(); + return endpt; +} + +RPCEndpoint::~RPCEndpoint() { + this->Shutdown(); +} + +void RPCEndpoint::Shutdown() { + if (channel_ != nullptr) { + RPCCode code = RPCCode::kShutdown; + uint64_t packet_nbytes = sizeof(code); + + handler_->Write(packet_nbytes); + handler_->Write(code); + + // flush all writing buffer to output channel. 
+ try { + while (writer_.bytes_available() != 0) { + size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) { + return channel_->Send(data, size); + }, writer_.bytes_available()); + if (n == 0) break; + } + } catch (const dmlc::Error& e) { + } + channel_.reset(nullptr); + } +} + +void RPCEndpoint::ServerLoop() { + if (const auto* f = Registry::Get("tvm.rpc.server.start")) { + (*f)(); + } + TVMRetValue rv; + CHECK(HandleUntilReturnEvent(false, [](TVMArgs) {}) == RPCCode::kShutdown); + if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) { + (*f)(); + } + channel_.reset(nullptr); +} + +int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag) { + RPCCode code = RPCCode::kNone; + if (in_bytes.length() != 0) { + reader_.Write(in_bytes.c_str(), in_bytes.length()); + code = handler_->HandleNextEvent(false, [](TVMArgs) {}); + } + if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) { + writer_.ReadWithCallback([this](const void *data, size_t size) { + return channel_->Send(data, size); + }, writer_.bytes_available()); + } + CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck); + if (code == RPCCode::kShutdown) return 0; + if (writer_.bytes_available() != 0) return 2; + return 1; +} + +void RPCEndpoint::InitRemoteSession(TVMArgs args) { + std::lock_guard lock(mutex_); + RPCCode code = RPCCode::kInitServer; + std::string protocol_ver = kRPCProtocolVer; + uint64_t length = protocol_ver.length(); + + uint64_t packet_nbytes = + sizeof(code) + + sizeof(length) + + length + + handler_->PackedSeqGetNumBytes( + args.values, args.type_codes, args.num_args, true); + + // All packet begins with packet nbytes + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->Write(length); + handler_->WriteArray(protocol_ver.data(), length); + handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); + + code = HandleUntilReturnEvent(true, [](TVMArgs args) {}); + CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); +} + +// Get remote function with name +void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + RPCSession::FEncodeReturn encode_return) { + std::lock_guard lock(mutex_); + + handler_->ValidateArguments(arg_values, arg_type_codes, num_args); + RPCCode code = RPCCode::kCallFunc; + uint64_t handle = reinterpret_cast(h); + + uint64_t packet_nbytes = + sizeof(code) + + sizeof(handle) + + handler_->PackedSeqGetNumBytes( + arg_values, arg_type_codes, num_args, true); + + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->Write(handle); + handler_->SendPackedSeq( + arg_values, arg_type_codes, num_args, true); + + code = HandleUntilReturnEvent(true, encode_return); + CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); +} + +void RPCEndpoint::CopyToRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t data_size, + TVMContext ctx_to, + DLDataType type_hint) { + std::lock_guard lock(mutex_); + RPCCode code = RPCCode::kCopyToRemote; + uint64_t handle = reinterpret_cast(to); + uint64_t offset = static_cast(to_offset); + uint64_t size = static_cast(data_size); + + uint64_t packet_nbytes = + sizeof(code) + + sizeof(handle) + + sizeof(offset) + + sizeof(size) + + sizeof(ctx_to) + + sizeof(type_hint) + + data_size; + + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->Write(handle); + handler_->Write(offset); + handler_->Write(size); + 
handler_->Write(ctx_to); + handler_->Write(type_hint); + handler_->WriteArray(reinterpret_cast(from) + from_offset, data_size); + + CHECK(HandleUntilReturnEvent(true, [](TVMArgs){}) == RPCCode::kReturn); +} + +void RPCEndpoint::CopyFromRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t data_size, + TVMContext ctx_from, + DLDataType type_hint) { + std::lock_guard lock(mutex_); + RPCCode code = RPCCode::kCopyFromRemote; + uint64_t handle = reinterpret_cast(from); + uint64_t offset = static_cast(from_offset); + uint64_t size = static_cast(data_size); + + uint64_t packet_nbytes = + sizeof(code) + + sizeof(handle) + + sizeof(offset) + + sizeof(size) + + sizeof(ctx_from) + + sizeof(type_hint); + + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->Write(handle); + handler_->Write(offset); + handler_->Write(size); + handler_->Write(ctx_from); + handler_->Write(type_hint); + + TVMRetValue rv; + CHECK(HandleUntilReturnEvent(true, [](TVMArgs){}) == RPCCode::kCopyAck); + handler_->ReadArray(reinterpret_cast(to) + to_offset, data_size); + handler_->FinishCopyAck(); +} + +// SysCallEventHandler functions +void RPCGetGlobalFunc(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + std::string name = args[0]; + *rv = handler->GetFunction(name); +} + +void RPCFreeHandle(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + void* handle = args[0]; + int type_code = args[1]; + handler->FreeHandle(handle, type_code); +} + +void RPCDevSetDevice(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + TVMContext ctx = args[0]; + handler->GetDeviceAPI(ctx)->SetDevice(ctx); +} + +void RPCDevGetAttr(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + TVMContext ctx = args[0]; + DeviceAttrKind kind = static_cast(args[1].operator int()); + if (kind == kExist) { + DeviceAPI* api = handler->GetDeviceAPI(ctx, true); + if (api != nullptr) { + api->GetAttr(ctx, kind, rv); + } else { + *rv = 0; + } + } else { + handler->GetDeviceAPI(ctx)->GetAttr( + ctx, static_cast(kind), rv); + } +} + +void RPCDevAllocData(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + TVMContext ctx = args[0]; + uint64_t nbytes = args[1]; + uint64_t alignment = args[2]; + DLDataType type_hint = args[3]; + void* data = handler->GetDeviceAPI(ctx)->AllocDataSpace( + ctx, nbytes, alignment, type_hint); + *rv = data; +} + +void RPCDevFreeData(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + TVMContext ctx = args[0]; + void* ptr = args[1]; + handler->GetDeviceAPI(ctx)->FreeDataSpace(ctx, ptr); +} + +void RPCDevStreamSync(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + TVMContext ctx = args[0]; + TVMStreamHandle handle = args[1]; + handler->GetDeviceAPI(ctx)->StreamSync(ctx, handle); +} + +void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue *rv) { + void* from = args[0]; + uint64_t from_offset = args[1]; + void* to = args[2]; + uint64_t to_offset = args[3]; + uint64_t size = args[4]; + TVMContext ctx_from = args[5]; + TVMContext ctx_to = args[6]; + DLDataType type_hint = args[7]; + TVMStreamHandle stream = args[8]; + TVMContext ctx = ctx_from; + + if (ctx.device_type == kDLCPU) { + ctx = ctx_to; + } else { + CHECK(ctx_to.device_type == kDLCPU || + ctx_to.device_type == ctx_from.device_type) + << "Can not copy across different ctx types directly"; + } + handler->GetDeviceAPI(ctx)->CopyDataFromTo( + from, from_offset, + to, to_offset, + size, ctx_from, ctx_to, type_hint, stream); +} + +void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { + // Event 
handler sit at clean state at this point. + switch (code) { + // system functions + case RPCCode::kFreeHandle: SysCallHandler(RPCFreeHandle); break; + case RPCCode::kGetGlobalFunc: SysCallHandler(RPCGetGlobalFunc); break; + case RPCCode::kDevSetDevice: SysCallHandler(RPCDevSetDevice); break; + case RPCCode::kDevGetAttr: SysCallHandler(RPCDevGetAttr); break; + case RPCCode::kDevAllocData: SysCallHandler(RPCDevAllocData); break; + case RPCCode::kDevFreeData: SysCallHandler(RPCDevFreeData); break; + case RPCCode::kDevStreamSync: SysCallHandler(RPCDevStreamSync); break; + case RPCCode::kCopyAmongRemote: SysCallHandler(RPCCopyAmongRemote); break; + default: LOG(FATAL) << "Unknown event " << static_cast(code); + } + + CHECK_EQ(state_, kRecvPacketNumBytes); +} + +/*! + * \brief RPC client session that proxies all calls to an endpoint. + */ +class RPCClientSession : public RPCSession, + public DeviceAPI { + public: + /*! + * \brief param endpoint The client endpoint of the session. + */ + explicit RPCClientSession(std::shared_ptr endpoint) + : endpoint_(endpoint) {} + + // function overrides + PackedFuncHandle GetFunction(const std::string& name) final { + return endpoint_->SysCallRemote(RPCCode::kGetGlobalFunc, name); + } + + void CallFunc(PackedFuncHandle func, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + const FEncodeReturn& fencode_return) final { + endpoint_->CallFunc( + func, arg_values, arg_type_codes, num_args, fencode_return); + } + + void CopyToRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_to, + DLDataType type_hint) final { + endpoint_->CopyToRemote( + from, from_offset, to, to_offset, nbytes, ctx_to, type_hint); + } + + void CopyFromRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_from, + DLDataType type_hint) final { + endpoint_->CopyFromRemote( + from, from_offset, to, to_offset, nbytes, ctx_from, type_hint); + } + + void FreeHandle(void* handle, int type_code) final { + endpoint_->SysCallRemote(RPCCode::kFreeHandle, handle, type_code); + } + + + void SetDevice(TVMContext ctx) final { + endpoint_->SysCallRemote(RPCCode::kDevSetDevice, ctx); + } + + void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { + if (ctx.device_type == kDLCPU && kind == kExist) { + // cpu always exists. 
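      // so the kExist query can be answered locally; every other attribute
      // lookup is forwarded to the remote side via kDevGetAttr below.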
+ *rv = 1; + } else { + *rv = endpoint_->SysCallRemote(RPCCode::kDevGetAttr, ctx, static_cast(kind)); + } + } + + void* AllocDataSpace(TVMContext ctx, + size_t nbytes, + size_t alignment, + DLDataType type_hint) final { + return endpoint_->SysCallRemote( + RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint); + } + + void FreeDataSpace(TVMContext ctx, void* ptr) final { + endpoint_->SysCallRemote(RPCCode::kDevFreeData, ctx, ptr); + } + + void CopyDataFromTo(const void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t size, + TVMContext ctx_from, + TVMContext ctx_to, + DLDataType type_hint, + TVMStreamHandle stream) final { + endpoint_->SysCallRemote( + RPCCode::kCopyAmongRemote, + const_cast(from), from_offset, + to, to_offset, + size, + ctx_from, ctx_to, + type_hint, stream); + } + + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { + endpoint_->SysCallRemote(RPCCode::kDevStreamSync, ctx, stream); + } + + DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing) final { + return this; + } + + bool IsLocalSession() const final { + return false; + } + + private: + std::shared_ptr endpoint_; +}; + +std::shared_ptr +CreateClientSession(std::shared_ptr endpoint) { + return std::make_shared(endpoint); +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h new file mode 100644 index 000000000000..9a6afcdc1ca4 --- /dev/null +++ b/src/runtime/rpc/rpc_endpoint.h @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_endpoint.h + * \brief Communication endpoints to connect local and remote RPC sessions. + */ +#ifndef TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ +#define TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ + +#include +#include +#include +#include +#include +#include "rpc_session.h" +#include "rpc_channel.h" +#include "rpc_protocol.h" +#include "../../support/ring_buffer.h" + +namespace tvm { +namespace runtime { + +// Magic header for RPC data plane +const int kRPCMagic = 0xff271; +// magic header for RPC tracker(control plane) +const int kRPCTrackerMagic = 0x2f271; +// sucess response +const int kRPCSuccess = kRPCMagic + 0; +// cannot found matched key in server +const int kRPCMismatch = kRPCMagic + 2; + +/*! \brief Enumeration code for the RPC tracker */ +enum class TrackerCode : int { + kFail = -1, + kSuccess = 0, + kPing = 1, + kStop = 2, + kPut = 3, + kRequest = 4, + kUpdateInfo = 5, + kSummary = 6, + kGetPendingMatchKeys = 7 +}; + + +/*! + * \brief Communication endpoints to connect local and remote RPC sessions. + * An endpoint can either be a client or a server. + */ +class RPCEndpoint { + public: + /*! \brief virtual destructor */ + ~RPCEndpoint(); + /*! 
+ * \brief The server loop that server runs to handle RPC calls. + */ + void ServerLoop(); + /*! + * \brief Message handling function for an async IO event driven server. + * + * Called when the server receives a message or an IO event update. + * Event driven handler will never call recv on the channel + * and always relies on the ServerIOEventHandler to receive the data. + * + * \param in_bytes The incoming bytes. + * \param event_flag 1: read_available, 2: write_avaiable. + * \return State flag. + * 1: continue running, no need to write, + * 2: need to write + * 0: shutdown + */ + int ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag); + + /*! + * \brief Initalize the session on the remote that will be used to back all the RPC requests. + * + * If no session constructor arguments is passed, LocalSession will be used in the remote. + * Otherwise the remote serving session will be constructed using the arguments + * specified in the session_constructor_args. + * + * The construction rule can be summarized as follows: + * + * \code + * + * auto args = session_constructor_args; + * int n = args.size(); + * if (n != 0) { + * std::string constructor = args[0]; + * server.serving_session_ = GetGlobalFunc(constructor)( + * args[1], args[2] ... args[n - 1]) + * } else { + * server.serving_session_ = LocalSession(); + * } + * \endcode + * + * \param session_constructor_args Optional sequence of the remote sesssion constructor. + */ + void InitRemoteSession(TVMArgs session_constructor_args); + + /*! + * \brief Call into remote function + * \param handle The function handle + * \param arg_values The argument values. + * \param arg_type_codes the type codes of the argument. + * \param num_args Number of arguments. + * \param fencode_return The function to receive return value encodings. + */ + void CallFunc(RPCSession::PackedFuncHandle handle, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + RPCSession::FEncodeReturn encode_return); + /*! + * \brief Copy bytes into remote array content. + * \param from The source host data. + * \param from_offset The byte offeset in the from. + * \param to The target array. + * \param to_offset The byte offset in the to. + * \param nbytes The size of the memory in bytes. + * \param ctx_to The target context. + * \param type_hint Hint of content data type. + */ + void CopyToRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_to, + DLDataType type_hint); + /*! + * \brief Copy bytes from remote array content. + * \param from The source host data. + * \param from_offset The byte offeset in the from. + * \param to The target array. + * \param to_offset The byte offset in the to. + * \param nbytes The size of the memory in bytes. + * \param ctx_from The source context. + * \param type_hint Hint of content data type. + */ + void CopyFromRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_from, + DLDataType type_hint); + + /*! + * \brief Call a remote defined system function with arguments. + * \param fcode The function code. + * \param args The arguments + * \return The returned remote value. + */ + template + inline TVMRetValue SysCallRemote(RPCCode fcode, Args&& ...args); + /*! + * \brief Create a RPC session with given channel. + * \param channel The communication channel. 
+ * \param name The local name of the session, used for debug + * \param remote_key The remote key of the session + * if remote_key equals "%toinit", we need to re-intialize + * it by event handler. + */ + static std::shared_ptr Create( + std::unique_ptr channel, + std::string name, + std::string remote_key); + + private: + class EventHandler; + // Handle events until receives a return + // Also flushes channels so that the function advances. + RPCCode HandleUntilReturnEvent(bool client_mode, RPCSession::FEncodeReturn setreturn); + // Initalization + void Init(); + // Shutdown + void Shutdown(); + // Internal channel. + std::unique_ptr channel_; + // Internal mutex + std::mutex mutex_; + // Internal ring buffer. + support::RingBuffer reader_, writer_; + // Event handler. + std::shared_ptr handler_; + // syscall remote with specified function code. + PackedFunc syscall_remote_; + // The name of the session. + std::string name_; + // The remote key + std::string remote_key_; +}; + +/*! + * \brief Create an RPC client session from an RPC client endpoint. + * \param endpoint The endpoint. + * \return The created session. + */ +std::shared_ptr +CreateClientSession(std::shared_ptr endpoint); + +// implementation of inline functions +template +inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&& ...args) { + return syscall_remote_(static_cast(code), std::forward(args)...); +} +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ diff --git a/src/runtime/rpc/rpc_event_impl.cc b/src/runtime/rpc/rpc_event_impl.cc index 29adb0fed108..284dca5cce6b 100644 --- a/src/runtime/rpc/rpc_event_impl.cc +++ b/src/runtime/rpc/rpc_event_impl.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,11 +19,12 @@ /*! * \file rpc_event_impl.cc - * \brief Event based RPC server implementation. + * \brief Event driven RPC server implementation. */ #include #include -#include "rpc_session.h" +#include "rpc_endpoint.h" +#include "rpc_local_session.h" namespace tvm { namespace runtime { @@ -35,16 +36,17 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend, LOG(FATAL) << "Do not allow explicit receive"; return 0; }); + std::unique_ptr ch(new CallbackChannel(fsend, frecv)); - std::shared_ptr sess = - RPCSession::Create(std::move(ch), name, remote_key); + std::shared_ptr sess = + RPCEndpoint::Create(std::move(ch), name, remote_key); return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) { - int ret = sess->ServerEventHandler(args[0], args[1]); + int ret = sess->ServerAsyncIOEventHandler(args[0], args[1]); *rv = ret; }); } -TVM_REGISTER_GLOBAL("rpc._CreateEventDrivenServer") +TVM_REGISTER_GLOBAL("rpc.CreateEventDrivenServer") .set_body_typed(CreateEventDrivenServer); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc new file mode 100644 index 000000000000..351a9896d899 --- /dev/null +++ b/src/runtime/rpc/rpc_local_session.cc @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file local_session.cc + * \brief Local session that directs requests to local API. + */ +#include +#include +#include +#include "rpc_local_session.h" + +namespace tvm { +namespace runtime { + +RPCSession::PackedFuncHandle +LocalSession::GetFunction(const std::string& name) { + PackedFunc pf = this->GetFunctionInternal(name); + // return raw handl because the remote need to explicitly manage it. + if (pf != nullptr) return new PackedFunc(pf); + return nullptr; +} + +void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + const FEncodeReturn& encode_return) { + auto* pf = static_cast(func); + TVMRetValue rv; + + pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv); + int rv_tcode = rv.type_code(); + + // return value encoding. + TVMValue ret_value_pack[3]; + int ret_tcode_pack[3]; + TVMArgsSetter set_arg(ret_value_pack, ret_tcode_pack); + // first location always encode type code. + set_arg(0, rv_tcode); + + if (rv_tcode == kTVMNDArrayHandle) { + // We follow a special protocol to return NDArray to client side + // The first pack value is the NDArray handle as DLTensor + // The second pack value is a customized deleter that deletes the NDArray. + rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]); + ret_tcode_pack[1] = kTVMDLTensorHandle; + ret_value_pack[2].v_handle = ret_value_pack[1].v_handle; + ret_tcode_pack[2] = kTVMOpaqueHandle; + encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 3)); + } else if (rv_tcode == kTVMPackedFuncHandle || + rv_tcode == kTVMModuleHandle) { + // MoveToCHost means rv no longer manages the object. + // return handle instead. 
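      // The receiving client becomes the owner of that handle and releases it
      // later with a kFreeHandle request (see LocalSession::FreeHandle).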
+ rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]); + ret_tcode_pack[1] = kTVMOpaqueHandle; + encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2)); + } else if (rv_tcode == kTVMBytes) { + TVMByteArray byte_arr; + auto* sptr = rv.ptr(); + byte_arr.data = sptr->data(); + byte_arr.size = sptr->length(); + set_arg(1, byte_arr); + encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2)); + } else { + set_arg(1, rv); + encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2)); + } +} + +void LocalSession::CopyToRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_to, + DLDataType type_hint) { + TVMContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + this->GetDeviceAPI(ctx_to)->CopyDataFromTo( + from, from_offset, + to, to_offset, + nbytes, cpu_ctx, ctx_to, type_hint, nullptr); + // Copy can happen asynchrously + // synchronize to make sure that copy is completed + this->GetDeviceAPI(ctx_to)->StreamSync(ctx_to, nullptr); +} + +void LocalSession::CopyFromRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_from, + DLDataType type_hint) { + TVMContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + + this->GetDeviceAPI(ctx_from)->CopyDataFromTo( + from, from_offset, + to, to_offset, + nbytes, ctx_from, cpu_ctx, type_hint, nullptr); + // Copy can happen asynchrously + // synchronize to make sure that copy is completed + this->GetDeviceAPI(ctx_from)->StreamSync(ctx_from, nullptr); +} + +void LocalSession::FreeHandle(void* handle, int type_code) { + TVMValue value; + value.v_handle = handle; + // will trigger deleter once the rv goes out of the scope. + TVMRetValue rv = TVMRetValue::MoveFromCHost(value, type_code); +} + +DeviceAPI* LocalSession::GetDeviceAPI(TVMContext ctx, bool allow_missing) { + return DeviceAPI::Get(ctx, allow_missing); +} + +PackedFunc LocalSession::GetFunctionInternal(const std::string& name) { + auto* fp = tvm::runtime::Registry::Get(name); + if (fp != nullptr) { + return *fp; + } else { + return nullptr; + } +} + +TVM_REGISTER_GLOBAL("rpc.LocalSession") +.set_body_typed([]() { + return CreateRPCSessionModule(std::make_shared()); +}); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/rpc/rpc_local_session.h b/src/runtime/rpc/rpc_local_session.h new file mode 100644 index 000000000000..3b6e7d8ea6f0 --- /dev/null +++ b/src/runtime/rpc/rpc_local_session.h @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_local_session.h + * \brief Local session that directs all request to the local runtime API. 
+ */ +#ifndef TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_ +#define TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_ + +#include +#include +#include +#include +#include "rpc_session.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief A local session that directly use the handle repr of the + * local tvm runtime objects on the same process. + */ +class LocalSession : public RPCSession { + public: + // function overrides + PackedFuncHandle GetFunction(const std::string& name) final; + + void CallFunc(PackedFuncHandle func, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + const FEncodeReturn& fencode_return) final; + + void CopyToRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_to, + DLDataType type_hint) final; + + void CopyFromRemote(void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t nbytes, + TVMContext ctx_from, + DLDataType type_hint) final; + + void FreeHandle(void* handle, int type_code) final; + + DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) final; + + bool IsLocalSession() const final { + return true; + } + + protected: + /*! + * \brief Internal implementation of GetFunction. + * \param name The name of the function. + * \return The corresponding PackedFunc. + */ + virtual PackedFunc GetFunctionInternal(const std::string& name); +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_ diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 0e48e6fb2708..106230457fc5 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -18,63 +18,117 @@ */ /*! - * \file rpc_device_api.cc - * \brief RPC module. + * \file rpc_module.cc + * \brief RPC runtime module. */ #include +#include #include #include +#include "rpc_endpoint.h" #include "rpc_session.h" namespace tvm { namespace runtime { -// Wrapped remote function to packed func. -class RPCWrappedFunc { +/*! + * \brief A wrapped remote function as a PackedFunc. + */ +class RPCWrappedFunc : public Object { public: RPCWrappedFunc(void* handle, std::shared_ptr sess) : handle_(handle), sess_(sess) { - fwrap_ = PackedFunc([sess](TVMArgs args, TVMRetValue* rv) { - WrapRemote(sess, args, rv); - }); } - void operator()(TVMArgs args, TVMRetValue *rv) const { - sess_->CallFunc(handle_, args, rv, UnwrapRemote, &fwrap_); + void operator()(TVMArgs args, TVMRetValue* rv) const { + std::vector values(args.values, args.values + args.size()); + std::vector type_codes(args.type_codes, args.type_codes + args.size()); + std::vector> temp_dltensors; + + // scan and check whether we need rewrite these arguments + // to their remote variant. + for (int i = 0; i < args.size(); ++i) { + int tcode = type_codes[i]; + + switch (tcode) { + case kTVMDLTensorHandle: + case kTVMNDArrayHandle: { + // Pass NDArray as DLTensor, NDArray and DLTensor + // are compatible to each other, just need to change the index. 
+ type_codes[i] = kTVMDLTensorHandle; + // translate to a remote view of DLTensor + auto dptr = std::make_unique( + *static_cast(values[i].v_handle)); + dptr->ctx = RemoveSessMask(dptr->ctx); + dptr->data = static_cast(dptr->data)->data; + values[i].v_handle = dptr.get(); + temp_dltensors.emplace_back(std::move(dptr)); + break; + } + case kTVMContext: { + values[i].v_ctx = RemoveSessMask(values[i].v_ctx); + break; + } + case kTVMPackedFuncHandle: + case kTVMModuleHandle: { + values[i].v_handle = UnwrapRemoteValueToHandle( + TVMArgValue(values[i], tcode)); + break; + } + } + } + auto set_return = [this, rv](TVMArgs args) { + this->WrapRemoteReturnToValue(args, rv); + }; + sess_->CallFunc(handle_, values.data(), type_codes.data(), + args.size(), set_return); } + ~RPCWrappedFunc() { try { - sess_->CallRemote(RPCCode::kFreeFunc, handle_); + sess_->FreeHandle(handle_, kTVMPackedFuncHandle); } catch (const dmlc::Error& e) { // fault tolerance to remote close } } - static void WrapRemote(std::shared_ptr sess, - TVMArgs args, - TVMRetValue* rv); + private: + // remote function handle + void* handle_{nullptr}; + // pointer to the session. + std::shared_ptr sess_; - static void* UnwrapRemote(int rpc_sess_table_index, - const TVMArgValue& arg); + // unwrap a remote value to the underlying handle. + void* UnwrapRemoteValueToHandle(const TVMArgValue& arg) const; + // wrap a remote return via Set + void WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const; + + // remove a remote session mask + TVMContext RemoveSessMask(TVMContext ctx) const { + int dev_type = ctx.device_type; + CHECK_EQ(dev_type / kRPCSessMask, sess_->table_index() + 1) + << "Can not pass in local context or context with a different remote session"; + ctx.device_type = static_cast(ctx.device_type % kRPCSessMask); + return ctx; + } // deleter of RPC remote array static void RemoteNDArrayDeleter(Object* obj) { auto* ptr = static_cast(obj); RemoteSpace* space = static_cast(ptr->dl_tensor.data); - space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx); + space->sess->FreeHandle(ptr->manager_ctx, kTVMNDArrayHandle); delete space; delete ptr; } + // wrap return value as remote NDArray. - static NDArray WrapRemoteNDArray(std::shared_ptr sess, - DLTensor* tensor, - void* nd_handle) { + NDArray WrapRemoteNDArray(DLTensor* tensor, void* nd_handle) const { NDArray::Container* data = new NDArray::Container(); data->manager_ctx = nd_handle; data->SetDeleter(RemoteNDArrayDeleter); RemoteSpace* space = new RemoteSpace(); - space->sess = sess; + space->sess = sess_; space->data = tensor->data; data->dl_tensor.data = space; NDArray ret(GetObjectPtr(data)); @@ -89,18 +143,13 @@ class RPCWrappedFunc { data->dl_tensor.ctx.device_id = tensor->ctx.device_id; data->dl_tensor.ctx.device_type = static_cast( static_cast(tensor->ctx.device_type) + - kRPCSessMask * (sess->table_index() + 1)); + kRPCSessMask * (sess_->table_index() + 1)); // check strides. CHECK(tensor->strides == nullptr); // setup byteoffset data->dl_tensor.byte_offset = tensor->byte_offset; return ret; } - - private: - PackedFunc fwrap_; - void* handle_{nullptr}; - std::shared_ptr sess_; }; // RPC that represents a remote module session. 
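The remote-context convention used by RemoveSessMask and WrapRemoteNDArray above is easy to miss: a remote TVMContext packs both the real device type and the owning session into one field, device_type = local_type + kRPCSessMask * (table_index + 1), so the quotient identifies the session and the remainder recovers the device type. A minimal standalone sketch of just that arithmetic, with kSessMaskForDemo standing in for the runtime's kRPCSessMask constant and the device-type value chosen arbitrarily:

#include <cassert>

namespace demo {

// Stand-in for kRPCSessMask; the real constant lives in the TVM runtime headers.
constexpr int kSessMaskForDemo = 128;

// Encode a local device type into its remote view for session table_index.
inline int EncodeRemoteDevType(int local_dev_type, int table_index) {
  return local_dev_type + kSessMaskForDemo * (table_index + 1);
}

// Decode back: quotient -> session slot, remainder -> real device type.
inline void DecodeRemoteDevType(int remote_dev_type, int* table_index, int* local_dev_type) {
  *table_index = remote_dev_type / kSessMaskForDemo - 1;
  *local_dev_type = remote_dev_type % kSessMaskForDemo;
}

}  // namespace demo

int main() {
  int encoded = demo::EncodeRemoteDevType(/*local_dev_type=*/2, /*table_index=*/0);
  int idx = -1, dev = -1;
  demo::DecodeRemoteDevType(encoded, &idx, &dev);
  assert(idx == 0 && dev == 2);  // round-trips to the original session slot and device type
  return 0;
}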
@@ -109,10 +158,11 @@ class RPCModuleNode final : public ModuleNode { RPCModuleNode(void* module_handle, std::shared_ptr sess) : module_handle_(module_handle), sess_(sess) { } + ~RPCModuleNode() { if (module_handle_ != nullptr) { try { - sess_->CallRemote(RPCCode::kModuleFree, module_handle_); + sess_->FreeHandle(module_handle_, kTVMModuleHandle); } catch (const dmlc::Error& e) { // fault tolerance to remote close } @@ -127,31 +177,56 @@ class RPCModuleNode final : public ModuleNode { PackedFunc GetFunction( const std::string& name, const ObjectPtr& sptr_to_self) final { - RPCFuncHandle handle = GetFuncHandle(name); - return WrapRemote(handle); + if (module_handle_ == nullptr) { + return WrapRemoteFunc(sess_->GetFunction(name)); + } else { + InitRemoteFunc(&remote_mod_get_function_, "tvm.rpc.server.ModuleGetFunction"); + return remote_mod_get_function_(GetRef(this), name, false); + } } std::string GetSource(const std::string& format) final { - if (module_handle_ != nullptr) { - std::string ret = sess_->CallRemote( - RPCCode::kModuleGetSource, module_handle_, format); - } + LOG(FATAL) << "GetSource for rpc Module is not supported"; return ""; } - std::shared_ptr& sess() { - return sess_; - } - PackedFunc GetTimeEvaluator(const std::string& name, TVMContext ctx, int number, int repeat, int min_repeat_ms) { - RPCFuncHandle handle = GetFuncHandle(name); - if (handle == nullptr) return PackedFunc(); - handle = sess_->GetTimeEvaluator(handle, ctx, number, repeat, min_repeat_ms); - return WrapRemote(handle); + InitRemoteFunc(&remote_get_time_evaluator_, "runtime.RPCTimeEvaluator"); + // Remove session mask because we pass ctx by parts. + int dev_type = ctx.device_type; + CHECK_EQ(dev_type / kRPCSessMask, sess_->table_index() + 1) + << "ValueError: Need to pass the matched remote context to RPCModule.GetTimeEvaluator"; + ctx.device_type = static_cast(ctx.device_type % kRPCSessMask); + + if (module_handle_ != nullptr) { + return remote_get_time_evaluator_( + GetRef(this), name, + static_cast(ctx.device_type), ctx.device_id, + number, repeat, min_repeat_ms); + } else { + return remote_get_time_evaluator_( + Optional(nullptr), name, + static_cast(ctx.device_type), ctx.device_id, + number, repeat, min_repeat_ms); + } + } + + Module LoadModule(std::string name) { + InitRemoteFunc(&remote_load_module_, "tvm.rpc.server.load_module"); + return remote_load_module_(name); + } + + void ImportModule(Module other) { + InitRemoteFunc(&remote_import_module_, "tvm.rpc.server.ImportModule"); + remote_import_module_(GetRef(this), other); + } + + const std::shared_ptr& sess() { + return sess_; } void* module_handle() const { @@ -159,7 +234,15 @@ class RPCModuleNode final : public ModuleNode { } private: - PackedFunc WrapRemote(RPCFuncHandle handle) { + template + void InitRemoteFunc(FType* func, const std::string& name) { + if (*func != nullptr) return; + RPCSession::PackedFuncHandle handle = sess_->GetFunction(name); + CHECK(handle != nullptr) << "Cannot found remote function " << name; + *func = WrapRemoteFunc(handle); + } + + PackedFunc WrapRemoteFunc(RPCSession::PackedFuncHandle handle) { if (handle == nullptr) return PackedFunc(); auto wf = std::make_shared(handle, sess_); return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { @@ -167,33 +250,30 @@ class RPCModuleNode final : public ModuleNode { }); } - RPCFuncHandle GetFuncHandle(const std::string& name) { - RPCFuncHandle handle = nullptr; - if (module_handle_ == nullptr) { - handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name); - } else { - handle = 
sess_->CallRemote( - RPCCode::kModuleGetFunc, module_handle_, name); - } - return handle; - } // The module handle void* module_handle_{nullptr}; // The local channel std::shared_ptr sess_; - // Wrap function to wrap remote module/function. - PackedFunc fwrap_; + // remote function to get time evaluator + TypedPackedFunc, std::string, int, int, int, int, int)> + remote_get_time_evaluator_; + // remote function getter for modules. + TypedPackedFunc remote_mod_get_function_; + // remote function getter for load module + TypedPackedFunc remote_load_module_; + // remote function getter for load module + TypedPackedFunc remote_import_module_; }; -void* RPCWrappedFunc::UnwrapRemote(int rpc_sess_table_index, - const TVMArgValue& arg) { + +void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const TVMArgValue& arg) const { if (arg.type_code() == kTVMModuleHandle) { Module mod = arg; std::string tkey = mod->type_key(); CHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote"; auto* rmod = static_cast(mod.operator->()); - CHECK_EQ(rmod->sess()->table_index(), rpc_sess_table_index) + CHECK(rmod->sess() == sess_) << "ValueError: Cannot pass in module into a different remote session"; return rmod->module_handle(); } else { @@ -204,93 +284,173 @@ void* RPCWrappedFunc::UnwrapRemote(int rpc_sess_table_index, } } -void RPCWrappedFunc::WrapRemote(std::shared_ptr sess, - TVMArgs args, - TVMRetValue *rv) { - void* handle = args.values[0].v_handle; - int tcode = args.type_codes[0]; +void RPCWrappedFunc::WrapRemoteReturnToValue( + TVMArgs args, + TVMRetValue *rv) const { + int tcode = args[0]; - if (handle == nullptr) return; + if (tcode == kTVMNullptr) return; if (tcode == kTVMPackedFuncHandle) { - auto wf = std::make_shared(handle, sess); + CHECK_EQ(args.size(), 2); + void* handle = args[1]; + auto wf = std::make_shared(handle, sess_); *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { - return wf->operator()(args, rv); - }); + return wf->operator()(args, rv); + }); } else if (tcode == kTVMModuleHandle) { - auto n = make_object(handle, sess); + CHECK_EQ(args.size(), 2); + void* handle = args[1]; + auto n = make_object(handle, sess_); *rv = Module(n); } else if (tcode == kTVMDLTensorHandle || tcode == kTVMNDArrayHandle) { - CHECK_EQ(args.size(), 2); - DLTensor* tensor = args[0]; - void* nd_handle = args[1]; - *rv = WrapRemoteNDArray(sess, tensor, nd_handle); + CHECK_EQ(args.size(), 3); + DLTensor* tensor = args[1]; + void* nd_handle = args[2]; + *rv = WrapRemoteNDArray(tensor, nd_handle); } else { - LOG(FATAL) << "Cannot wrap tcode=" << tcode; + CHECK_EQ(args.size(), 2); + *rv = args[1]; } } -Module CreateRPCModule(std::shared_ptr sess) { +Module CreateRPCSessionModule(std::shared_ptr sess) { auto n = make_object(nullptr, sess); + RPCSession::InsertToSessionTable(sess); return Module(n); } +std::shared_ptr RPCModuleGetSession(Module mod) { + std::string tkey = mod->type_key(); + CHECK_EQ(tkey, "rpc") + << "ValueError: Cannot pass a non-RPC module to remote"; + auto* rmod = static_cast(mod.operator->()); + return rmod->sess(); +} + +PackedFunc WrapTimeEvaluator(PackedFunc pf, + TVMContext ctx, + int number, + int repeat, + int min_repeat_ms) { + CHECK(pf != nullptr); + + if (static_cast(ctx.device_type) == static_cast(kDLMicroDev)) { + auto get_micro_time_evaluator = runtime::Registry::Get("micro._GetMicroTimeEvaluator"); + CHECK(get_micro_time_evaluator != nullptr) << "micro backend not enabled"; + return (*get_micro_time_evaluator)(pf, ctx, number, repeat); + } + + auto ftimer = [pf, ctx, 
number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv) + mutable { + TVMRetValue temp; + std::ostringstream os; + // skip first time call, to activate lazy compilation components. + pf.CallPacked(args, &temp); + + DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); + + for (int i = 0; i < repeat; ++i) { + std::chrono::time_point< + std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend; + double duration_ms = 0.0; + + do { + if (duration_ms > 0.0) { + number = static_cast( + std::max((min_repeat_ms / (duration_ms / number) + 1), + number * 1.618)); // 1.618 is chosen by random + } + + tbegin = std::chrono::high_resolution_clock::now(); + // start timing + for (int i = 0; i < number; ++i) { + pf.CallPacked(args, &temp); + } + DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); + tend = std::chrono::high_resolution_clock::now(); + + duration_ms = std::chrono::duration_cast > + (tend - tbegin).count() * 1000; + } while (duration_ms < min_repeat_ms); + + double speed = std::chrono::duration_cast >( + tend - tbegin).count() / number; + os.write(reinterpret_cast(&speed), sizeof(speed)); + } + + std::string blob = os.str(); + TVMByteArray arr; + arr.size = blob.length(); + arr.data = blob.data(); + // return the time. + *rv = arr; + }; + return PackedFunc(ftimer); +} + + TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module m = args[0]; +.set_body_typed([](Optional opt_mod, + std::string name, + int device_type, + int device_id, + int number, + int repeat, + int min_repeat_ms) { + TVMContext ctx; + ctx.device_type = static_cast(device_type); + ctx.device_id = device_id; + if (opt_mod.defined()) { + Module m = opt_mod.value(); std::string tkey = m->type_key(); - TVMContext ctx; - ctx.device_type = static_cast(args[2].operator int()); - ctx.device_id = args[3]; if (tkey == "rpc") { - *rv = static_cast(m.operator->()) - ->GetTimeEvaluator(args[1], ctx, args[4], args[5], args[6]); + return static_cast(m.operator->()) + ->GetTimeEvaluator(name, ctx, number, repeat, min_repeat_ms); } else { - *rv = WrapTimeEvaluator( - m.GetFunction(args[1], false), ctx, args[4], args[5], args[6]); + return WrapTimeEvaluator( + m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms); } - }); + } else { + auto* pf = runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find " << name << " in the global function"; + return WrapTimeEvaluator( + *pf, ctx, number, repeat, min_repeat_ms); + } +}); -TVM_REGISTER_GLOBAL("rpc._LoadRemoteModule") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module m = args[0]; - std::string tkey = m->type_key(); - CHECK_EQ(tkey, "rpc"); - auto& sess = static_cast(m.operator->())->sess(); - void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]); - auto n = make_object(mhandle, sess); - *rv = Module(n); - }); +// server function registration. 
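The part of WrapTimeEvaluator that deserves a second read is the batch-size growth: duration_ms times a batch of number calls, and the next number is either the count needed to reach min_repeat_ms at the observed per-call cost or a 1.618x bump, whichever is larger. A self-contained simulation of that loop, where per_call_ms is an invented workload cost used purely for illustration:

#include <algorithm>
#include <cstdio>

int main() {
  const double per_call_ms = 0.05;  // assumed cost of one call, illustration only
  const int min_repeat_ms = 10;     // target wall time for one measured batch
  int number = 1;
  double duration_ms = 0.0;

  do {
    if (duration_ms > 0.0) {
      // Same update rule as the evaluator above.
      number = static_cast<int>(
          std::max(min_repeat_ms / (duration_ms / number) + 1, number * 1.618));
    }
    duration_ms = number * per_call_ms;  // stand-in for actually timing the calls
    std::printf("number=%d duration=%.2f ms\n", number, duration_ms);
  } while (duration_ms < min_repeat_ms);
  return 0;
}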
+TVM_REGISTER_GLOBAL("tvm.rpc.server.ImportModule") +.set_body_typed([](Module parent, Module child) { + parent->Import(child); +}); -TVM_REGISTER_GLOBAL("rpc._ImportRemoteModule") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module parent = args[0]; - Module child = args[1]; - CHECK(!std::strcmp(parent->type_key(), "rpc") && - !std::strcmp(child->type_key(), "rpc")); - auto* pmod = static_cast(parent.operator->()); - auto* cmod = static_cast(child.operator->()); - CHECK(pmod->sess().get() == cmod->sess().get()) - << "Import of remote module need to belong to same session."; - pmod->sess()->CallRemote(RPCCode::kModuleImport, - pmod->module_handle(), - cmod->module_handle()); - }); - -TVM_REGISTER_GLOBAL("rpc._ModuleHandle") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module m = args[0]; - std::string tkey = m->type_key(); - CHECK_EQ(tkey, "rpc"); - *rv = static_cast(m.operator->())->module_handle(); - }); +TVM_REGISTER_GLOBAL("tvm.rpc.server.ModuleGetFunction") +.set_body_typed([](Module parent, std::string name, bool query_imports) { + return parent->GetFunction(name, query_imports); +}); + +// functions to access an RPC module. +TVM_REGISTER_GLOBAL("rpc.LoadRemoteModule") +.set_body_typed([](Module sess, std::string name) { + std::string tkey = sess->type_key(); + CHECK_EQ(tkey, "rpc"); + return static_cast(sess.operator->())->LoadModule(name); +}); -TVM_REGISTER_GLOBAL("rpc._SessTableIndex") +TVM_REGISTER_GLOBAL("rpc.ImportRemoteModule") +.set_body_typed([](Module parent, Module child) { + std::string tkey = parent->type_key(); + CHECK_EQ(tkey, "rpc"); + static_cast(parent.operator->())->ImportModule(child); +}); + +TVM_REGISTER_GLOBAL("rpc.SessTableIndex") .set_body([](TVMArgs args, TVMRetValue* rv) { - Module m = args[0]; - std::string tkey = m->type_key(); - CHECK_EQ(tkey, "rpc"); - *rv = static_cast(m.operator->())->sess()->table_index(); - }); + Module m = args[0]; + std::string tkey = m->type_key(); + CHECK_EQ(tkey, "rpc"); + *rv = static_cast(m.operator->())->sess()->table_index(); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_pipe_impl.cc b/src/runtime/rpc/rpc_pipe_impl.cc new file mode 100644 index 000000000000..376b8b5dc61e --- /dev/null +++ b/src/runtime/rpc/rpc_pipe_impl.cc @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_pipe_impl.cc + * \brief Pipe-based RPC channel. + */ +// Linux only for now, as linux is the most common usecase. 
+#if defined(__linux__) || defined(__ANDROID__) + +#include +#include +#include +#include + +#include +#include +#include + +#include "rpc_endpoint.h" +#include "rpc_local_session.h" +#include "../../support/pipe.h" + +namespace tvm { +namespace runtime { + +class PipeChannel final : public RPCChannel { + public: + explicit PipeChannel(int readfd, int writefd, pid_t child_pid) + : readfd_(readfd), writefd_(writefd), child_pid_(child_pid) { + } + + ~PipeChannel() { + Close(); + } + + size_t Send(const void* data, size_t size) final { + ssize_t n = write(writefd_, data, size); + if (n == -1) { + LOG(FATAL) << "Pipe write error"; + } + return static_cast(n); + } + + size_t Recv(void* data, size_t size) final { + ssize_t n = read(readfd_, data, size); + if (n == -1) { + LOG(FATAL) << "Pipe read error"; + } + return static_cast(n); + } + + void Close() { + close(readfd_); + close(writefd_); + kill(child_pid_, SIGKILL); + } + + private: + int readfd_; + int writefd_; + pid_t child_pid_; +}; + + +Module CreatePipeClient(std::vector cmd) { + int parent2child[2]; + int child2parent[2]; + CHECK_EQ(pipe(parent2child), 0); + CHECK_EQ(pipe(child2parent), 0); + + int parent_read = child2parent[0]; + int parent_write = parent2child[1]; + int child_read = parent2child[0]; + int child_write = child2parent[1]; + + pid_t pid = fork(); + if (pid == 0) { + // child process + close(parent_read); + close(parent_write); + std::string sread_pipe = std::to_string(child_read); + std::string swrite_pipe = std::to_string(child_write); + std::vector argv; + for (auto& str : cmd) { + argv.push_back(dmlc::BeginPtr(str)); + } + argv.push_back(dmlc::BeginPtr(sread_pipe)); + argv.push_back(dmlc::BeginPtr(swrite_pipe)); + argv.push_back(nullptr); + execvp(argv[0], &argv[0]); + } + // parent process + close(child_read); + close(child_write); + + auto endpt = RPCEndpoint::Create( + std::unique_ptr( + new PipeChannel(parent_read, parent_write, pid)), + "pipe", "pipe"); + endpt->InitRemoteSession(TVMArgs(nullptr, nullptr, 0)); + return CreateRPCSessionModule(CreateClientSession(endpt)); +} + +TVM_REGISTER_GLOBAL("rpc.CreatePipeClient") +.set_body([](TVMArgs args, TVMRetValue* rv) { + std::vector cmd; + for (int i = 0; i < args.size(); ++i) { + cmd.push_back(args[i].operator std::string()); + } + *rv = CreatePipeClient(cmd); +}); + + +} // namespace runtime +} // namespace tvm +#endif diff --git a/src/runtime/rpc/rpc_protocol.h b/src/runtime/rpc/rpc_protocol.h new file mode 100644 index 000000000000..6221bfbe1e82 --- /dev/null +++ b/src/runtime/rpc/rpc_protocol.h @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_procotol.h + * \brief Common header defining the communication code used in the RPC protocol. 
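CreatePipeClient above boils down to two pipes, a fork, and an execvp that hands the child its read/write descriptors as trailing command-line arguments. A stripped-down, TVM-free sketch of the same pattern; the "./echo_server" binary and its fd-as-argv convention are assumptions made for illustration only:

#include <unistd.h>

#include <cstdio>
#include <string>
#include <vector>

int main() {
  int parent2child[2], child2parent[2];
  if (pipe(parent2child) != 0 || pipe(child2parent) != 0) return 1;

  pid_t pid = fork();
  if (pid == 0) {
    // Child: keep its own ends, close the parent's, pass the fds via argv.
    close(parent2child[1]);
    close(child2parent[0]);
    std::string rfd = std::to_string(parent2child[0]);
    std::string wfd = std::to_string(child2parent[1]);
    std::vector<char*> argv = {
        const_cast<char*>("./echo_server"),  // hypothetical child binary
        const_cast<char*>(rfd.c_str()),
        const_cast<char*>(wfd.c_str()),
        nullptr};
    execvp(argv[0], argv.data());
    std::perror("execvp");  // only reached if exec fails
    _exit(1);
  }
  // Parent: close the child's ends and talk over the remaining pair.
  close(parent2child[0]);
  close(child2parent[1]);
  // ... write requests to parent2child[1], read replies from child2parent[0] ...
  close(parent2child[1]);
  close(child2parent[0]);
  return 0;
}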
+ */ +#ifndef TVM_RUNTIME_RPC_RPC_PROTOCOL_H_ +#define TVM_RUNTIME_RPC_RPC_PROTOCOL_H_ + +namespace tvm { +namespace runtime { + +/*! \brief The current RPC procotol version. */ +constexpr const char* kRPCProtocolVer = "0.7.0"; + +/*! \brief The RPC code */ +enum class RPCCode : int { + kNone, + kShutdown, + kInitServer, + kCallFunc, + kReturn, + kException, + kCopyFromRemote, + kCopyToRemote, + kCopyAck, + // The following are syscall code that can send over CallRemote + kSyscallCodeStart, + kGetGlobalFunc = kSyscallCodeStart, + kFreeHandle, + kDevSetDevice, + kDevGetAttr, + kDevAllocData, + kDevFreeData, + kDevStreamSync, + kCopyAmongRemote, +}; + +/*! + * \brief List of potential error status during rpc communication. + */ +enum class RPCServerStatus : int { + kSuccess = 0, + kInvalidTypeCodeObject, + kInvalidTypeCodeNDArray, + kInvalidDLTensorFieldStride, + kInvalidDLTensorFieldByteOffset, + kUnknownTypeCode, + kUnknownRPCCode, + kRPCCodeNotSupported, + kUnknownRPCSyscall, + kCheckError, + kReadError, + kWriteError, + kAllocError +}; + +/*! + * \brief Convert RPC server status to string. + * \param status The status. + * \return The corresponding string. + */ +inline const char* RPCServerStatusToString(RPCServerStatus status) { + switch (status) { + case RPCServerStatus::kSuccess: return "kSuccess"; + case RPCServerStatus::kInvalidTypeCodeObject: return "kInvalidTypeCodeObject"; + case RPCServerStatus::kInvalidTypeCodeNDArray: return "kInvalidTypeCodeNDArray"; + case RPCServerStatus::kInvalidDLTensorFieldStride: return "kInvalidDLTensorFieldStride"; + case RPCServerStatus::kInvalidDLTensorFieldByteOffset: { + return "kInvalidDLTensorFieldByteOffset"; + } + case RPCServerStatus::kUnknownTypeCode: return "kUnknownTypeCode"; + case RPCServerStatus::kUnknownRPCCode: return "kUnknownRPCCode"; + case RPCServerStatus::kRPCCodeNotSupported: return "RPCCodeNotSupported"; + case RPCServerStatus::kUnknownRPCSyscall: return "kUnknownRPCSyscall"; + case RPCServerStatus::kCheckError: return "kCheckError"; + case RPCServerStatus::kReadError: return "kReadError"; + case RPCServerStatus::kWriteError: return "kWriteError"; + case RPCServerStatus::kAllocError: return "kAllocError"; + default: return ""; + } +} + +/*! + * \brief Reference implementation of the communication protocol. + * + * \note The implementation is intentionally written via template + * so it can be used in a dependency free setting. + * + * \sa src/runtime/rpc/device/min_rpc_server.h + */ +struct RPCReference { + /*! + * \brief Auxiliary class to get the packed sequence. + * \tparam TChannel The channel to throw errror. + */ + template + struct PackedSeqNumBytesGetter { + public: + explicit PackedSeqNumBytesGetter(TChannel* channel) + : channel_(channel) {} + + template + void Write(const T& value) { + num_bytes_ += sizeof(T); + } + + template + void WriteArray(const T* value, size_t num) { + num_bytes_ += sizeof(T) * num; + } + + void ThrowError(RPCServerStatus status) { + channel_->ThrowError(status); + } + + uint64_t num_bytes() const { + return num_bytes_; + } + + private: + TChannel* channel_; + uint64_t num_bytes_{0}; + }; + + /*! + * \return the length of the str. + * \param str the string. + * \return The length. + */ + static uint64_t StrLength(const char* str) { + uint64_t len = 0; + while (str[len] != '\0') ++len; + return len; + } + + /*! + * \brief Get the total nbytes to be sent in the packed sequence. + * + * \param arg_values The values to be sent over. + * \param type_codes The type codes to be sent over. 
+ * \param num_args Number of argument. + * \param client_mode Whether it is a client to server call. + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + * \return The total number of bytes. + */ + template + static uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, + const int* type_codes, + int num_args, + bool client_mode, + TChannel* channel) { + PackedSeqNumBytesGetter getter(channel); + SendPackedSeq(arg_values, type_codes, num_args, client_mode, &getter); + return getter.num_bytes(); + } + + /*! + * \brief Send packed argument sequnce to the other peer. + * + * This function serves as the foundational communication primitive between peers. + * + * TVMValue sequence encoding protocol(according to the type): + * + * - int/float/uint/bytes/str: Serialize all content. + * - DLTensor: send meta-data, send data handle as opaque handle(via uint64_t) + * - OpaqueHandle: send as uint64_t + * - ModuleHandle, PackedFuncHandle: send as uint64_t, + * The support to Module/PackedFuncHandle are reserved for arguments + * in the CallFunc from a client to server only. + * Note that we cannot simply take these argument out(as the handle) + * refers to a value on the remote(instead of local). + * + * \param arg_values The values to be sent over. + * \param type_codes The type codes to be sent over. + * \param num_args Number of argument. + * \param client_mode Whether it is a client to server call. + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + */ + template + static void SendPackedSeq(const TVMValue* arg_values, + const int* type_codes, + int num_args, + bool client_mode, + TChannel* channel) { + channel->Write(num_args); + channel->WriteArray(type_codes, num_args); + + // Argument packing. + for (int i = 0; i < num_args; ++i) { + int tcode = type_codes[i]; + TVMValue value = arg_values[i]; + switch (tcode) { + case kDLInt: + case kDLUInt: + case kDLFloat: { + channel->template Write(value.v_int64); + break; + } + case kTVMDataType: { + channel->Write(value.v_type); + // padding + int32_t padding = 0; + channel->template Write(padding); + break; + } + case kTVMContext: { + channel->Write(value.v_ctx); + break; + } + + case kTVMPackedFuncHandle: + case kTVMModuleHandle: { + if (!client_mode) { + channel->ThrowError(RPCServerStatus::kInvalidTypeCodeObject); + } + // always send handle in 64 bit. + uint64_t handle = reinterpret_cast(value.v_handle); + channel->Write(handle); + break; + } + case kTVMOpaqueHandle: { + // always send handle in 64 bit. 
+ uint64_t handle = reinterpret_cast(value.v_handle); + channel->Write(handle); + break; + } + case kTVMNDArrayHandle: { + channel->ThrowError(RPCServerStatus::kInvalidTypeCodeNDArray); + break; + } + case kTVMDLTensorHandle: { + DLTensor* arr = static_cast(value.v_handle); + TVMContext ctx; + uint64_t data; + // When we return NDArray, we directly return + // the space and the context + // The client will be further wrapping + ctx = arr->ctx; + data = reinterpret_cast(arr->data); + channel->Write(data); + channel->Write(ctx); + channel->Write(arr->ndim); + channel->Write(arr->dtype); + channel->WriteArray(arr->shape, arr->ndim); + if (arr->strides != nullptr) { + channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldStride); + } + if (arr->byte_offset != 0) { + channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldByteOffset); + } + break; + } + case kTVMNullptr: break; + case kTVMStr: { + const char* s = value.v_str; + uint64_t len = StrLength(s); + channel->Write(len); + channel->WriteArray(s, len); + break; + } + case kTVMBytes: { + TVMByteArray* bytes = static_cast(arg_values[i].v_handle); + uint64_t len = bytes->size; + channel->Write(len); + channel->WriteArray(bytes->data, len); + break; + } + default: { + channel->ThrowError(RPCServerStatus::kUnknownTypeCode); + break; + } + } + } + } + + /*! + * \brief Receive packed seq from the channel. + * + * \param out_arg_values The values to be received. + * \param out_tcodes The type codes to be received. + * \param out_num_args Number of argument. + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + * \note The temporary space are populated via an arena inside channel. + */ + template + static void RecvPackedSeq(TVMValue** out_values, + int** out_tcodes, + int* out_num_args, + TChannel* channel) { + // receive number of args + int num_args; + channel->Read(&num_args); + *out_num_args = num_args; + + if (num_args == 0) { + *out_values = nullptr; + *out_tcodes = nullptr; + return; + } + + TVMValue* values = channel->template ArenaAlloc(num_args); + int* tcodes = channel->template ArenaAlloc(num_args); + *out_values = values; + *out_tcodes = tcodes; + + // receive type code. + channel->ReadArray(tcodes, num_args); + + // receive arguments + for (int i = 0; i < num_args; ++i) { + auto& value = values[i]; + switch (tcodes[i]) { + case kDLInt: + case kDLUInt: + case kDLFloat: { + channel->template Read(&(value.v_int64)); + break; + } + case kTVMDataType: { + channel->Read(&(value.v_type)); + int32_t padding = 0; + channel->template Read(&padding); + break; + } + case kTVMContext: { + channel->Read(&(value.v_ctx)); + break; + } + case kTVMPackedFuncHandle: + case kTVMModuleHandle: + case kTVMOpaqueHandle: { + // always send handle in 64 bit. 
+ uint64_t handle; + channel->Read(&handle); + value.v_handle = reinterpret_cast(handle); + break; + } + case kTVMNullptr: { + value.v_handle = nullptr; + break; + } + case kTVMStr: { + uint64_t len; + channel->Read(&len); + char* str = channel->template ArenaAlloc(len + 1); + str[len] = '\0'; + channel->ReadArray(str, len); + value.v_str = str; + break; + } + case kTVMBytes: { + uint64_t len; + channel->Read(&len); + TVMByteArray* arr = channel->template ArenaAlloc(1); + char* data = channel->template ArenaAlloc(len); + arr->size = len; + arr->data = data; + channel->ReadArray(data, len); + value.v_handle = arr; + break; + } + case kTVMDLTensorHandle: { + uint64_t handle; + channel->Read(&handle); + DLTensor* arr = channel->template ArenaAlloc(1); + DLTensor& tensor = *arr; + tensor.data = reinterpret_cast(handle); + channel->Read(&(tensor.ctx)); + channel->Read(&(tensor.ndim)); + channel->Read(&(tensor.dtype)); + tensor.shape = channel->template ArenaAlloc(tensor.ndim); + channel->ReadArray(tensor.shape, tensor.ndim); + tensor.strides = nullptr; + tensor.byte_offset = 0; + value.v_handle = arr; + break; + } + default: { + channel->ThrowError(RPCServerStatus::kUnknownTypeCode); + break; + } + } + } + } + + /*! + * \brief Return an exception packet. + * + * \param msg The error message. + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + */ + template + static void ReturnException(const char* msg, TChannel* channel) { + RPCCode code = RPCCode::kException; + int32_t num_args = 1; + int32_t tcode = kTVMStr; + uint64_t len = StrLength(msg); + + uint64_t packet_nbytes = + sizeof(code) + + sizeof(num_args) + + sizeof(tcode) + + sizeof(len) + + len; + + channel->Write(packet_nbytes); + channel->Write(code); + channel->Write(num_args); + channel->Write(tcode); + channel->Write(len); + channel->WriteArray(msg, len); + } + + /*! + * \brief Return a normal packed sequence packet. + * + * \param msg The error message. + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + */ + template + static void ReturnPackedSeq(const TVMValue* arg_values, + const int* type_codes, + int num_args, + TChannel* channel) { + RPCCode code = RPCCode::kReturn; + + uint64_t packet_nbytes = + sizeof(code) + + PackedSeqGetNumBytes( + arg_values, type_codes, num_args, false, channel); + + channel->Write(packet_nbytes); + channel->Write(code); + SendPackedSeq( + arg_values, type_codes, num_args, false, channel); + } + + /*! + * \brief Return a null(void) packet. + * + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + */ + template + static void ReturnVoid(TChannel* channel) { + int32_t num_args = 1; + int32_t tcode = kTVMNullptr; + RPCCode code = RPCCode::kReturn; + + uint64_t packet_nbytes = + sizeof(code) + + sizeof(num_args) + + sizeof(tcode); + + channel->Write(packet_nbytes); + channel->Write(code); + channel->Write(num_args); + channel->Write(tcode); + } +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_PROTOCOL_H_ diff --git a/src/runtime/rpc/rpc_server_env.cc b/src/runtime/rpc/rpc_server_env.cc index f6a7fb60b5f4..612ca418e812 100644 --- a/src/runtime/rpc/rpc_server_env.cc +++ b/src/runtime/rpc/rpc_server_env.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
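Framing in this protocol is a uint64_t length prefix that covers everything written after it, which is what the packet_nbytes sums in ReturnVoid and ReturnException compute. As a sanity check, a void return carries code + num_args + tcode, i.e. 12 bytes given the 4-byte ints the runtime assumes; a tiny counting channel in the spirit of PackedSeqNumBytesGetter makes the arithmetic concrete (the enum value below is a placeholder, not the real wire code):

#include <cassert>
#include <cstdint>

namespace demo {

enum class RPCCode : int { kReturn = 0 };  // local stand-in so the snippet is self-contained

struct CountingChannel {
  uint64_t num_bytes = 0;
  template <typename T>
  void Write(const T&) { num_bytes += sizeof(T); }
};

}  // namespace demo

int main() {
  demo::CountingChannel ch;
  ch.Write(demo::RPCCode::kReturn);  // code
  ch.Write(int32_t{1});              // num_args
  ch.Write(int32_t{0});              // tcode stand-in for kTVMNullptr
  assert(ch.num_bytes == 12);        // matches packet_nbytes in ReturnVoid (4-byte int)
  return 0;
}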
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,7 +28,8 @@ namespace tvm { namespace runtime { std::string RPCGetPath(const std::string& name) { - static const PackedFunc* f = + // do live lookup everytime as workpath can change. + const PackedFunc* f = runtime::Registry::Get("tvm.rpc.server.workpath"); CHECK(f != nullptr) << "require tvm.rpc.server.workpath"; return (*f)(name); diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index ae293abfacdd..dd0afa0145d2 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -21,817 +21,16 @@ * \file rpc_session.cc * \brief RPC session for remote function call. */ -#include #include #include -#include -#include -#include +#include #include -#include -#include -#include -#include -#include -#include #include "rpc_session.h" -#include "../object_internal.h" -#include "../../support/ring_buffer.h" -#include "../../support/socket.h" -#include "../micro/micro_session.h" namespace tvm { namespace runtime { -// Temp buffer for data array -struct RPCByteArrayBuffer { - TVMByteArray arr; - std::string data; -}; -// Temp buffer for data array -struct RPCDataArrayBuffer { - DLTensor tensor; - std::vector shape; -}; -/*! - * \brief Temporal argument buffer. - */ -struct RPCArgBuffer { - // The argument values - std::vector value; - // The type codes. - std::vector tcode; - // Temporal resources. - std::vector > temp_bytes; - // Temporal array - std::vector > temp_array; - // convert buffer as TVMArgs - TVMArgs AsTVMArgs() const { - return TVMArgs(value.data(), tcode.data(), static_cast(value.size())); - } -}; - -// Event handler for RPC events. -class RPCSession::EventHandler : public dmlc::Stream { - public: - EventHandler(support::RingBuffer* reader, - support::RingBuffer* writer, - int rpc_sess_table_index, - std::string name, - std::string* remote_key) - : reader_(reader), - writer_(writer), - rpc_sess_table_index_(rpc_sess_table_index), - name_(name), - remote_key_(remote_key) { - this->Clear(); - if (*remote_key == "%toinit") { - state_ = kInitHeader; - remote_key_->resize(0); - pending_request_bytes_ = sizeof(int32_t); - } - } - // Bytes needed to fulfill current request - size_t BytesNeeded() { - if (reader_->bytes_available() < pending_request_bytes_) { - return pending_request_bytes_ - reader_->bytes_available(); - } else { - return 0; - } - } - // Request number of bytes from reader. - void RequestBytes(size_t nbytes) { - pending_request_bytes_ += nbytes; - reader_->Reserve(pending_request_bytes_); - } - // Whether we are ready to handle next request. 
- bool Ready() { - return reader_->bytes_available() >= pending_request_bytes_; - } - bool CanCleanShutdown() const { - return state_ == kRecvCode; - } - void FinishCopyAck() { - this->SwitchToState(kRecvCode); - } - RPCCode HandleNextEvent(TVMRetValue* rv, - bool client_mode, - const PackedFunc* fwrap) { - std::swap(client_mode_, client_mode); - while (this->Ready()) { - switch (state_) { - case kInitHeader: HandleInitHeader(); break; - case kRecvCode: HandleRecvCode(); break; - case kRecvCallHandle: { - CHECK(this->Read(&call_handle_)); - this->SwitchToState(kRecvPackedSeqNumArgs); - break; - } - case kRecvPackedSeqNumArgs: { - CHECK(this->Read(&num_packed_args_)); - arg_buf_.reset(new RPCArgBuffer()); - arg_buf_->value.resize(num_packed_args_); - arg_buf_->tcode.resize(num_packed_args_); - this->SwitchToState(kRecvPackedSeqTypeCode); - break; - } - case kRecvPackedSeqTypeCode: { - if (num_packed_args_ != 0) { - this->ReadArray(arg_buf_->tcode.data(), num_packed_args_); - } - arg_index_ = 0; - arg_recv_stage_ = 0; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kRecvPackedSeqArg: { - this->HandleRecvPackedSeqArg(); - break; - } - case kDoCopyFromRemote: { - this->HandleCopyFromRemote(); - break; - } - case kDoCopyToRemote: { - this->HandleCopyToRemote(); - break; - } - case kReturnReceived: { - CHECK_GE(arg_buf_->value.size(), 1U); - - TVMArgValue argv = arg_buf_->AsTVMArgs()[0]; - if (argv.type_code() == kTVMPackedFuncHandle || - argv.type_code() == kTVMModuleHandle || - argv.type_code() == kTVMDLTensorHandle) { - CHECK(fwrap != nullptr) << "function/module wrapper not available"; - fwrap->CallPacked(arg_buf_->AsTVMArgs(), rv); - } else { - CHECK_EQ(arg_buf_->value.size(), 1U); - *rv = argv; - } - arg_buf_.reset(); - this->SwitchToState(kRecvCode); - std::swap(client_mode_, client_mode); - return RPCCode::kReturn; - } - case kCopyAckReceived: { - std::swap(client_mode_, client_mode); - return RPCCode::kCopyAck; - } - case kShutdownReceived: { - std::swap(client_mode_, client_mode); - return RPCCode::kShutdown; - } - } - } - std::swap(client_mode_, client_mode); - return RPCCode::kNone; - } - // Reset and clear all states. - void Clear() { - state_ = kRecvCode; - pending_request_bytes_ = sizeof(RPCCode); - arg_recv_stage_ = 0; - arg_buf_.reset(); - } - // strip session on mask - TVMContext StripSessMask(TVMContext ctx) { - int dev_type = ctx.device_type; - CHECK_EQ(dev_type / kRPCSessMask, rpc_sess_table_index_ + 1) - << "Can not pass in local context or context with a different remote session"; - ctx.device_type = static_cast(dev_type % kRPCSessMask); - return ctx; - } - // Send Packed sequence to writer. - // - // client_mode: whether we are in client mode. - // - // funwrap: auxiliary function to unwrap remote Object - // when it is provided, we need to unwrap objects. - // - // return_ndarray is a special flag to handle returning of ndarray - // In this case, we return the shape, context and data of the array, - // as well as a customized PackedFunc that handles deletion of - // the array in the remote. - void SendPackedSeq(const TVMValue* arg_values, - const int* type_codes, - int num_args, - bool client_mode, - FUnwrapRemoteObject funwrap = nullptr, - bool return_ndarray = false) { - std::swap(client_mode_, client_mode); - - this->Write(num_args); - for (int i = 0; i < num_args; ++i) { - int tcode = type_codes[i]; - if (tcode == kTVMNDArrayHandle) tcode = kTVMDLTensorHandle; - this->Write(tcode); - } - - // Argument packing. 
- for (int i = 0; i < num_args; ++i) { - int tcode = type_codes[i]; - TVMValue value = arg_values[i]; - switch (tcode) { - case kDLInt: - case kDLUInt: - case kDLFloat: { - this->Write(value.v_int64); - break; - } - case kTVMDataType: { - this->Write(value.v_type); - // padding - int32_t padding = 0; - this->Write(padding); - break; - } - case kTVMContext: { - value.v_ctx = StripSessMask(value.v_ctx); - this->Write(value.v_ctx); - break; - } - case kTVMPackedFuncHandle: - case kTVMModuleHandle: { - // always send handle in 64 bit. - uint64_t handle; - // allow pass module as argument to remote. - if (funwrap != nullptr) { - void* remote_handle = (*funwrap)( - rpc_sess_table_index_, - runtime::TVMArgValue(value, tcode)); - handle = reinterpret_cast(remote_handle); - } else { - CHECK(!client_mode_) - << "Cannot directly pass remote object as argument"; - handle = reinterpret_cast(value.v_handle); - } - this->Write(handle); - break; - } - case kTVMOpaqueHandle: { - // always send handle in 64 bit. - uint64_t handle = reinterpret_cast(value.v_handle); - this->Write(handle); - break; - } - case kTVMNDArrayHandle: - case kTVMDLTensorHandle: { - DLTensor* arr = static_cast(value.v_handle); - TVMContext ctx; - uint64_t data; - if (!return_ndarray) { - // in the client mode - // ctx contains the remote table index - // the space is wrapped by an RemoteSpace - // that holds reference to the session. - ctx = StripSessMask(arr->ctx); - data = reinterpret_cast( - static_cast(arr->data)->data); - } else { - // When we return NDArray, we directly return - // the space and the context - // The client will be further wrapping - ctx = arr->ctx; - data = reinterpret_cast(arr->data); - } - this->Write(data); - this->Write(ctx); - this->Write(arr->ndim); - this->Write(arr->dtype); - this->WriteArray(arr->shape, arr->ndim); - CHECK(arr->strides == nullptr) - << "Do not support strided remote array"; - CHECK_EQ(arr->byte_offset, 0) - << "Do not support send byte offset"; - break; - } - case kTVMNullptr: break; - case kTVMStr: { - const char* s = value.v_str; - uint64_t len = strlen(s); - this->Write(len); - this->WriteArray(s, len); - break; - } - case kTVMBytes: { - TVMByteArray* bytes = static_cast(arg_values[i].v_handle); - uint64_t len = bytes->size; - this->Write(len); - this->WriteArray(bytes->data, len); - break; - } - default: { - LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); - break; - } - } - } - std::swap(client_mode_, client_mode); - } - - // Endian aware IO handling - using Stream::Read; - using Stream::Write; - using Stream::ReadArray; - using Stream::WriteArray; - - inline bool Read(RPCCode* code) { - int cdata; - if (!this->Read(&cdata)) return false; - *code = static_cast(cdata); - return true; - } - inline void Write(RPCCode code) { - int cdata = static_cast(code); - this->Write(cdata); - } - - protected: - enum State { - kInitHeader, - kRecvCode, - kRecvCallHandle, - kRecvPackedSeqNumArgs, - kRecvPackedSeqTypeCode, - kRecvPackedSeqArg, - kDoCopyFromRemote, - kDoCopyToRemote, - kReturnReceived, - kCopyAckReceived, - kShutdownReceived - }; - // Current state; - State state_; - // The RPCCode to be read. - RPCCode code_; - // Handle for the remote function call. - uint64_t call_handle_; - // Initialize remote header - bool init_header_step_{0}; - // Number of packed arguments. - int num_packed_args_; - // Current argument index. - int arg_index_; - // The stage of each argument receiver. - int arg_recv_stage_; - // Whether current handler is client or server mode. 
- bool client_mode_{false}; - // Argument buffer - std::unique_ptr arg_buf_; - // Temp byte buffer. - std::unique_ptr temp_bytes_; - // Temp array buffer. - std::unique_ptr temp_array_; - // Internal temporal data space. - std::string temp_data_; - // Temp variables for copy request state. - TVMContext copy_ctx_; - DLDataType copy_dtype_; - uint64_t copy_handle_, copy_offset_, copy_size_; - // State switcher - void SwitchToState(State state) { - // invariant - CHECK_EQ(pending_request_bytes_, 0U) - << "state=" << state; - state_ = state; - switch (state) { - case kInitHeader: { - LOG(FATAL) << "cannot switch to init header"; - break; - } - case kRecvCode: { - this->RequestBytes(sizeof(RPCCode)); - break; - } - case kRecvCallHandle: { - this->RequestBytes(sizeof(call_handle_)); - break; - } - case kRecvPackedSeqNumArgs: { - this->RequestBytes(sizeof(num_packed_args_)); - break; - } - case kRecvPackedSeqTypeCode: { - this->RequestBytes(sizeof(int) * num_packed_args_); - break; - } - case kRecvPackedSeqArg: { - CHECK_LE(arg_index_, num_packed_args_); - if (arg_index_ == num_packed_args_) { - // The function can change state_ again. - HandlePackedCall(); - } else { - RequestRecvPackedSeqArg(); - } - break; - } - case kDoCopyFromRemote: { - this->RequestBytes(sizeof(uint64_t) * 3); - this->RequestBytes(sizeof(TVMContext)); - this->RequestBytes(sizeof(DLDataType)); - break; - } - case kDoCopyToRemote: { - this->RequestBytes(sizeof(uint64_t) * 3); - this->RequestBytes(sizeof(TVMContext)); - this->RequestBytes(sizeof(DLDataType)); - break; - } - case kCopyAckReceived: - case kReturnReceived: - case kShutdownReceived: { - break; - } - } - } - // Requets bytes needed for next computation. - void RequestRecvPackedSeqArg() { - CHECK_EQ(arg_recv_stage_, 0); - int tcode = arg_buf_->tcode[arg_index_]; - static_assert(sizeof(TVMValue) == sizeof(uint64_t), "invariant"); - switch (tcode) { - case kDLInt: - case kDLUInt: - case kDLFloat: - case kTVMDataType: - case kTVMOpaqueHandle: - case kTVMStr: - case kTVMBytes: - case kTVMModuleHandle: - case kTVMContext: { - this->RequestBytes(sizeof(TVMValue)); break; - } - case kTVMPackedFuncHandle: { - CHECK(client_mode_) - << "Only client can receive remote functions"; - this->RequestBytes(sizeof(TVMValue)); break; - } - case kTVMNullptr: break; - case kTVMDLTensorHandle: { - this->RequestBytes(sizeof(uint64_t)); - this->RequestBytes(sizeof(TVMContext)); - this->RequestBytes(sizeof(int)); - this->RequestBytes(sizeof(DLDataType)); - break; - } - default: { - LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); - break; - } - } - } - // Handler for packed sequence argument receive. - void HandleRecvPackedSeqArg() { - CHECK_LT(arg_index_, num_packed_args_); - int tcode = arg_buf_->tcode[arg_index_]; - TVMValue& value = arg_buf_->value[arg_index_]; - if (arg_recv_stage_ == 0) { - switch (tcode) { - case kDLInt: - case kDLUInt: - case kDLFloat: { - this->Read(&(value.v_int64)); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMDataType: { - this->Read(&(value.v_type)); - int32_t padding = 0; - this->Read(&padding); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMContext: { - this->Read(&(value.v_ctx)); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMPackedFuncHandle: - case kTVMModuleHandle: - case kTVMOpaqueHandle: { - // always send handle in 64 bit. 
- uint64_t handle; - this->Read(&handle); - value.v_handle = reinterpret_cast(handle); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMNullptr: { - value.v_handle = nullptr; - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMStr: - case kTVMBytes: { - uint64_t len; - this->Read(&len); - temp_bytes_.reset( new RPCByteArrayBuffer()); - temp_bytes_->data.resize(len); - arg_recv_stage_ = 1; - this->RequestBytes(len); - break; - } - case kTVMDLTensorHandle: { - temp_array_.reset(new RPCDataArrayBuffer()); - uint64_t handle; - this->Read(&handle); - DLTensor& tensor = temp_array_->tensor; - tensor.data = reinterpret_cast(handle); - this->Read(&(tensor.ctx)); - this->Read(&(tensor.ndim)); - this->Read(&(tensor.dtype)); - temp_array_->shape.resize(tensor.ndim); - tensor.shape = temp_array_->shape.data(); - arg_recv_stage_ = 1; - tensor.strides = nullptr; - tensor.byte_offset = 0; - this->RequestBytes(sizeof(int64_t) * tensor.ndim); - break; - } - default: { - LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); - break; - } - } - } else { - CHECK_EQ(arg_recv_stage_, 1); - if (tcode == kTVMStr || tcode == kTVMBytes) { - if (temp_bytes_->data.size() != 0) { - this->ReadArray(&(temp_bytes_->data[0]), temp_bytes_->data.size()); - } - if (tcode == kTVMStr) { - value.v_str = temp_bytes_->data.c_str(); - } else { - temp_bytes_->arr.size = static_cast(temp_bytes_->data.size()); - temp_bytes_->arr.data = dmlc::BeginPtr(temp_bytes_->data); - value.v_handle = &(temp_bytes_->arr); - } - arg_buf_->temp_bytes.emplace_back(std::move(temp_bytes_)); - } else { - CHECK_EQ(tcode, kTVMDLTensorHandle); - DLTensor& tensor = temp_array_->tensor; - this->ReadArray(tensor.shape, tensor.ndim); - value.v_handle = &tensor; - arg_buf_->temp_array.emplace_back(std::move(temp_array_)); - } - ++arg_index_; - arg_recv_stage_ = 0; - this->SwitchToState(kRecvPackedSeqArg); - } - } - // handler for initial header read - void HandleInitHeader() { - if (init_header_step_ == 0) { - int32_t len; - this->Read(&len); - remote_key_->resize(len); - init_header_step_ = 1; - this->RequestBytes(len); - return; - } else { - CHECK_EQ(init_header_step_, 1); - this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length()); - this->SwitchToState(kRecvCode); - } - } - // Handler for read code. - void HandleRecvCode() { - this->Read(&code_); - if (code_ > RPCCode::kSystemFuncStart) { - SwitchToState(kRecvPackedSeqNumArgs); - return; - } - // invariant. 
- CHECK_EQ(arg_recv_stage_, 0); - switch (code_) { - case RPCCode::kCallFunc: { - SwitchToState(kRecvCallHandle); - break; - } - case RPCCode::kException: - case RPCCode::kReturn: { - SwitchToState(kRecvPackedSeqNumArgs); - break; - } - case RPCCode::kCopyFromRemote: { - SwitchToState(kDoCopyFromRemote); - break; - } - case RPCCode::kCopyToRemote: { - SwitchToState(kDoCopyToRemote); - break; - } - case RPCCode::kShutdown: { - SwitchToState(kShutdownReceived); - break; - } - case RPCCode::kCopyAck: { - SwitchToState(kCopyAckReceived); - break; - } - default: LOG(FATAL) << "Unknown event " << static_cast(code_); - } - } - - void HandleCopyFromRemote() { - uint64_t handle, offset, num_bytes; - TVMContext ctx; - DLDataType type_hint; - this->Read(&handle); - this->Read(&offset); - this->Read(&num_bytes); - this->Read(&ctx); - this->Read(&type_hint); - size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; - - if (ctx.device_type == kDLCPU) { - RPCCode code = RPCCode::kCopyAck; - this->Write(code); - char* dptr = reinterpret_cast(handle) + offset; - if (!DMLC_IO_NO_ENDIAN_SWAP) { - temp_data_.resize(0); - temp_data_.insert(temp_data_.end(), dptr, dptr + num_bytes); - dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes); - this->WriteArray(temp_data_.data(), num_bytes); - } else { - this->WriteArray(dptr, num_bytes); - } - } else { - temp_data_.resize(num_bytes + 1); - try { - TVMContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; - DeviceAPI::Get(ctx)->CopyDataFromTo( - reinterpret_cast(handle), offset, - dmlc::BeginPtr(temp_data_), 0, - num_bytes, ctx, cpu_ctx, type_hint, nullptr); - RPCCode code = RPCCode::kCopyAck; - this->Write(code); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes); - } - this->WriteArray(&temp_data_[0], num_bytes); - } catch (const std::runtime_error &e) { - RPCCode code = RPCCode::kException; - this->Write(code); - TVMValue ret_value; - ret_value.v_str = e.what(); - int ret_tcode = kTVMStr; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } - } - this->SwitchToState(kRecvCode); - } - - void HandleCopyToRemote() { - // use static variable to persist state. - // This only works if next stage is immediately after this. 
- if (arg_recv_stage_ == 0) { - CHECK(this->Read(©_handle_)); - CHECK(this->Read(©_offset_)); - CHECK(this->Read(©_size_)); - CHECK(this->Read(©_ctx_)); - CHECK(this->Read(©_dtype_)); - arg_recv_stage_ = 1; - CHECK_EQ(pending_request_bytes_, 0U); - this->RequestBytes(copy_size_); - } else { - CHECK_EQ(arg_recv_stage_, 1); - TVMValue ret_value; - ret_value.v_handle = nullptr; - int ret_tcode = kTVMNullptr; - RPCCode code = RPCCode::kReturn; - std::string errmsg; - - size_t elem_bytes = (copy_dtype_.bits * copy_dtype_.lanes + 7) / 8; - if (copy_ctx_.device_type == kDLCPU) { - char* dptr = reinterpret_cast(copy_handle_) + copy_offset_; - this->ReadArray(dptr, copy_size_); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dptr, elem_bytes, copy_size_ / elem_bytes); - } - } else { - temp_data_.resize(copy_size_ + 1); - this->ReadArray(&temp_data_[0], copy_size_); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, copy_size_ / elem_bytes); - } - try { - TVMContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; - DeviceAPI::Get(copy_ctx_)->CopyDataFromTo( - temp_data_.data(), 0, - reinterpret_cast(copy_handle_), copy_offset_, - copy_size_, cpu_ctx, copy_ctx_, copy_dtype_, nullptr); - } catch (const std::runtime_error &e) { - code = RPCCode::kException; - errmsg = e.what(); - ret_value.v_str = errmsg.c_str(); - ret_tcode = kTVMStr; - } - } - this->Write(code); - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - arg_recv_stage_ = 0; - this->SwitchToState(kRecvCode); - } - } - // Handle for packed call. - void HandlePackedCall(); - - template - void CallHandler(F f) { - TVMRetValue rv; - TVMValue ret_value; - int ret_tcode; - try { - // Need to move out, in case f itself need to call RecvPackedSeq - // Which will override argbuf again. - std::unique_ptr args = std::move(arg_buf_); - f(args->AsTVMArgs(), &rv); - RPCCode code = RPCCode::kReturn; - this->Write(code); - if (rv.type_code() == kTVMStr) { - ret_value.v_str = rv.ptr()->c_str(); - ret_tcode = kTVMStr; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } else if (rv.type_code() == kTVMBytes) { - std::string* bytes = rv.ptr(); - TVMByteArray arr; - arr.data = bytes->c_str(); - arr.size = bytes->length(); - ret_value.v_handle = &arr; - ret_tcode = kTVMBytes; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } else if (rv.type_code() == kTVMPackedFuncHandle || - rv.type_code() == kTVMModuleHandle) { - // always send handle in 64 bit. - CHECK(!client_mode_) - << "Only server can send function and module handle back."; - rv.MoveToCHost(&ret_value, &ret_tcode); - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } else if (rv.type_code() == kTVMNDArrayHandle) { - // always send handle in 64 bit. - CHECK(!client_mode_) - << "Only server can send NDArray back"; - // We follow a special protocol to return NDArray to client side - // The first pack value is the NDArray handle as DLTensor - // The second pack value is a customized deleter that deletes the NDArray. 
- TVMValue ret_value_pack[2]; - int ret_tcode_pack[2]; - rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]); - ret_value_pack[1].v_handle = ret_value_pack[0].v_handle; - ret_tcode_pack[1] = kTVMOpaqueHandle; - SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true); - } else { - ret_value = rv.value(); - ret_tcode = rv.type_code(); - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } - } catch (const std::runtime_error& e) { - RPCCode code = RPCCode::kException; - this->Write(code); - ret_value.v_str = e.what(); - ret_tcode = kTVMStr; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } - } - - private: - // Utility functions - // Internal read function, update pending_request_bytes_ - size_t Read(void* data, size_t size) final { - CHECK_LE(size, pending_request_bytes_); - reader_->Read(data, size); - pending_request_bytes_ -= size; - return size; - } - void Write(const void* data, size_t size) final { - writer_->Write(data, size); - } - // Number of pending bytes requests - size_t pending_request_bytes_; - // The ring buffer to read data from. - support::RingBuffer* reader_; - // The ringr buffer to write reply to. - support::RingBuffer* writer_; - // Session table index. - int rpc_sess_table_index_; - // Name of session. - std::string name_; - // remote key - std::string* remote_key_; -}; - -struct RPCSessTable { +class RPCSessTable { public: static constexpr int kMaxRPCSession = 32; // Get global singleton @@ -864,465 +63,13 @@ struct RPCSessTable { std::array, kMaxRPCSession> tbl_; }; -RPCCode RPCSession::HandleUntilReturnEvent( - TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap) { - RPCCode code = RPCCode::kCallFunc; - while (code != RPCCode::kReturn && - code != RPCCode::kShutdown && - code != RPCCode::kCopyAck) { - while (writer_.bytes_available() != 0) { - writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); - } - size_t bytes_needed = handler_->BytesNeeded(); - if (bytes_needed != 0) { - size_t n = reader_.WriteWithCallback([this](void* data, size_t size) { - return channel_->Recv(data, size); - }, bytes_needed); - if (n == 0) { - if (handler_->CanCleanShutdown()) { - return RPCCode::kShutdown; - } else { - LOG(FATAL) << "Channel closes before we get neded bytes"; - } - } - } - code = handler_->HandleNextEvent(rv, client_mode, fwrap); - } - return code; -} - -void RPCSession::Init() { - // Event handler - handler_ = std::make_shared( - &reader_, &writer_, table_index_, name_, &remote_key_); - // Quick function to call remote. 
- call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) { - handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); - RPCCode code = HandleUntilReturnEvent(rv, true, nullptr); - CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); - }); -} - -std::shared_ptr RPCSession::Create( - std::unique_ptr channel, - std::string name, - std::string remote_key) { - std::shared_ptr sess = std::make_shared(); - sess->channel_ = std::move(channel); - sess->name_ = std::move(name); - sess->remote_key_ = std::move(remote_key); - sess->table_index_ = RPCSessTable::Global()->Insert(sess); - sess->Init(); - return sess; -} - std::shared_ptr RPCSession::Get(int table_index) { return RPCSessTable::Global()->Get(table_index); } -RPCSession::~RPCSession() { - this->Shutdown(); -} - -void RPCSession::Shutdown() { - if (channel_ != nullptr) { - RPCCode code = RPCCode::kShutdown; - handler_->Write(code); - // flush all writing buffer to output channel. - try { - while (writer_.bytes_available() != 0) { - size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); - if (n == 0) break; - } - } catch (const dmlc::Error& e) { - } - channel_.reset(nullptr); - } -} - -void RPCSession::ServerLoop() { - std::lock_guard lock(mutex_); - if (const auto* f = Registry::Get("tvm.rpc.server.start")) { - (*f)(); - } - TVMRetValue rv; - CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown); - if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) { - (*f)(); - } - channel_.reset(nullptr); -} - -int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) { - std::lock_guard lock(mutex_); - RPCCode code = RPCCode::kNone; - if (bytes.length() != 0) { - reader_.Write(bytes.c_str(), bytes.length()); - TVMRetValue rv; - code = handler_->HandleNextEvent(&rv, false, nullptr); - } - if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) { - writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); - } - CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck); - if (code == RPCCode::kShutdown) return 0; - if (writer_.bytes_available() != 0) return 2; - return 1; -} - -// Get remote function with name -void RPCSession::CallFunc(void* h, - TVMArgs args, - TVMRetValue* rv, - FUnwrapRemoteObject funwrap, - const PackedFunc* fwrap) { - std::lock_guard lock(mutex_); - - RPCCode code = RPCCode::kCallFunc; - handler_->Write(code); - uint64_t handle = reinterpret_cast(h); - handler_->Write(handle); - handler_->SendPackedSeq( - args.values, args.type_codes, args.num_args, true, funwrap); - code = HandleUntilReturnEvent(rv, true, fwrap); - CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); -} - -void RPCSession::CopyToRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t data_size, - TVMContext ctx_to, - DLDataType type_hint) { - std::lock_guard lock(mutex_); - ctx_to = handler_->StripSessMask(ctx_to); - RPCCode code = RPCCode::kCopyToRemote; - handler_->Write(code); - uint64_t handle = reinterpret_cast(to); - handler_->Write(handle); - uint64_t offset = static_cast(to_offset); - handler_->Write(offset); - uint64_t size = static_cast(data_size); - handler_->Write(size); - handler_->Write(ctx_to); - handler_->Write(type_hint); - handler_->WriteArray(reinterpret_cast(from) + from_offset, data_size); - TVMRetValue rv; - CHECK(HandleUntilReturnEvent(&rv, true, 
nullptr) == RPCCode::kReturn); -} - -void RPCSession::CopyFromRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t data_size, - TVMContext ctx_from, - DLDataType type_hint) { - std::lock_guard lock(mutex_); - ctx_from = handler_->StripSessMask(ctx_from); - RPCCode code = RPCCode::kCopyFromRemote; - handler_->Write(code); - uint64_t handle = reinterpret_cast(from); - handler_->Write(handle); - uint64_t offset = static_cast(from_offset); - handler_->Write(offset); - uint64_t size = static_cast(data_size); - handler_->Write(size); - handler_->Write(ctx_from); - handler_->Write(type_hint); - TVMRetValue rv; - CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck); - reader_.Reserve(data_size); - handler_->RequestBytes(data_size); - while (!handler_->Ready()) { - size_t bytes_needed = handler_->BytesNeeded(); - reader_.WriteWithCallback([this](void* data, size_t size) { - size_t n = channel_->Recv(data, size); - CHECK_NE(n, 0U) << "Channel closes before we get neded bytes"; - return n; - }, bytes_needed); - } - handler_->ReadArray(reinterpret_cast(to) + to_offset, data_size); - handler_->FinishCopyAck(); -} - -RPCFuncHandle RPCSession::GetTimeEvaluator( - RPCFuncHandle fhandle, TVMContext ctx, int number, int repeat, int min_repeat_ms) { - return this->CallRemote( - RPCCode::kGetTimeEvaluator, fhandle, ctx, number, repeat, min_repeat_ms); -} - -// Event handler functions -void RPCGetGlobalFunc(TVMArgs args, TVMRetValue* rv) { - std::string name = args[0]; - auto *fp = tvm::runtime::Registry::Get(name); - if (fp != nullptr) { - *rv = static_cast(new tvm::runtime::PackedFunc(*fp)); - } else { - *rv = nullptr; - } -} - -void RPCFreeFunc(TVMArgs args, TVMRetValue *rv) { - void* handle = args[0]; - delete static_cast(handle); -} - -void RPCDevSetDevice(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - DeviceAPI::Get(ctx)->SetDevice(ctx); -} - -void RPCDevGetAttr(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - DeviceAttrKind kind = static_cast(args[1].operator int()); - if (kind == kExist) { - DeviceAPI* api = DeviceAPI::Get(ctx, true); - if (api != nullptr) { - api->GetAttr(ctx, kind, rv); - } else { - *rv = 0; - } - } else { - DeviceAPI::Get(ctx)->GetAttr( - ctx, static_cast(kind), rv); - } -} - -void RPCDevAllocData(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - uint64_t nbytes = args[1]; - uint64_t alignment = args[2]; - DLDataType type_hint = args[3]; - void* data = DeviceAPI::Get(ctx)->AllocDataSpace( - ctx, nbytes, alignment, type_hint); - *rv = data; -} - -void RPCDevFreeData(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - void* ptr = args[1]; - DeviceAPI::Get(ctx)->FreeDataSpace(ctx, ptr); -} - -void RPCDevStreamSync(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - TVMStreamHandle handle = args[1]; - DeviceAPI::Get(ctx)->StreamSync(ctx, handle); -} - -void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) { - void* from = args[0]; - uint64_t from_offset = args[1]; - void* to = args[2]; - uint64_t to_offset = args[3]; - uint64_t size = args[4]; - TVMContext ctx_from = args[5]; - TVMContext ctx_to = args[6]; - DLDataType type_hint = args[7]; - TVMStreamHandle stream = args[8]; - TVMContext ctx = ctx_from; - if (ctx.device_type == kDLCPU) { - ctx = ctx_to; - } else { - CHECK(ctx_to.device_type == kDLCPU || - ctx_to.device_type == ctx_from.device_type) - << "Can not copy across different ctx types directly"; - } - DeviceAPI::Get(ctx)->CopyDataFromTo( - from, from_offset, - to, 
to_offset, - size, ctx_from, ctx_to, type_hint, stream); -} - -void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) { - static const PackedFunc* fsys_load_ = nullptr; - if (fsys_load_ == nullptr) { - fsys_load_ = runtime::Registry::Get("tvm.rpc.server.load_module"); - CHECK(fsys_load_ != nullptr); - } - std::string file_name = args[0]; - TVMRetValue ret = (*fsys_load_)(file_name); - // pass via void* - TVMValue value; - int rcode; - ret.MoveToCHost(&value, &rcode); - CHECK_EQ(rcode, kTVMModuleHandle); - *rv = static_cast(value.v_handle); -} - -void RPCModuleImport(TVMArgs args, TVMRetValue *rv) { - void* pmod = args[0]; - void* cmod = args[1]; - ObjectInternal::GetModuleNode(pmod)->Import( - GetRef(ObjectInternal::GetModuleNode(cmod))); -} - -void RPCModuleFree(TVMArgs args, TVMRetValue *rv) { - void* mhandle = args[0]; - ObjectInternal::ObjectFree(mhandle); -} - -void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) { - void* mhandle = args[0]; - PackedFunc pf = ObjectInternal::GetModuleNode(mhandle)->GetFunction( - args[1], false); - if (pf != nullptr) { - *rv = static_cast(new PackedFunc(pf)); - } else { - *rv = nullptr; - } -} - -void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) { - void* mhandle = args[0]; - std::string fmt = args[1]; - *rv = ObjectInternal::GetModuleNode(mhandle)->GetSource(fmt); -} - -void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) { - void* handle = args[0]; - static_cast( - reinterpret_cast(handle))->DecRef(); -} - -void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) { - PackedFunc *pf = static_cast(args[0].operator void*()); - void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2], args[3], args[4])); - delete pf; - *rv = fhandle; -} - -void RPCSession::EventHandler::HandlePackedCall() { - CHECK_EQ(pending_request_bytes_, 0U); - if (code_ == RPCCode::kReturn) { - state_ = kReturnReceived; return; - } - // reset state to clean init state - state_ = kRecvCode; - this->RequestBytes(sizeof(RPCCode)); - // Event handler sit at clean state at this point. 
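`RPCCopyAmongRemote` above chooses which device API drives the copy: if the source is CPU it uses the destination's device, otherwise the destination must be CPU or share the source's device type. A small Python restatement of that rule; the `Ctx` tuple and dlpack codes are only for illustration:

```python
from collections import namedtuple

Ctx = namedtuple("Ctx", "device_type device_id")
KDL_CPU, KDL_GPU = 1, 2  # dlpack device type codes (CPU = 1, CUDA GPU = 2)

def pick_copy_device(ctx_from, ctx_to):
    # Drive the copy from the non-CPU side; copies across two different
    # non-CPU device types are rejected, as in the CHECK above.
    if ctx_from.device_type == KDL_CPU:
        return ctx_to
    assert ctx_to.device_type in (KDL_CPU, ctx_from.device_type), \
        "Can not copy across different ctx types directly"
    return ctx_from

assert pick_copy_device(Ctx(KDL_CPU, 0), Ctx(KDL_GPU, 0)) == Ctx(KDL_GPU, 0)
assert pick_copy_device(Ctx(KDL_GPU, 0), Ctx(KDL_CPU, 0)) == Ctx(KDL_GPU, 0)
```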
- switch (code_) { - case RPCCode::kCallFunc: { - PackedFunc* pf = reinterpret_cast(call_handle_); - CallHandler([pf](TVMArgs args, TVMRetValue* rv) { - pf->CallPacked(args, rv); - }); - break; - } - case RPCCode::kException: { - CHECK_EQ(arg_buf_->value.size(), 1U); - CHECK_EQ(arg_buf_->tcode[0], kTVMStr); - std::ostringstream os; - os << "Except caught from RPC call: " << arg_buf_->value[0].v_str; - arg_buf_.reset(); - throw dmlc::Error(os.str()); - break; - } - // system functions - case RPCCode::kGetTimeEvaluator: CallHandler(RPCGetTimeEvaluator); break; - case RPCCode::kFreeFunc: CallHandler(RPCFreeFunc); break; - case RPCCode::kGetGlobalFunc: CallHandler(RPCGetGlobalFunc); break; - case RPCCode::kDevSetDevice: CallHandler(RPCDevSetDevice); break; - case RPCCode::kDevGetAttr: CallHandler(RPCDevGetAttr); break; - case RPCCode::kDevAllocData: CallHandler(RPCDevAllocData); break; - case RPCCode::kDevFreeData: CallHandler(RPCDevFreeData); break; - case RPCCode::kDevStreamSync: CallHandler(RPCDevStreamSync); break; - case RPCCode::kCopyAmongRemote: CallHandler(RPCCopyAmongRemote); break; - case RPCCode::kModuleLoad: CallHandler(RPCModuleLoad); break; - case RPCCode::kModuleImport: CallHandler(RPCModuleImport); break; - case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break; - case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break; - case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break; - case RPCCode::kNDArrayFree: CallHandler(RPCNDArrayFree); break; - default: LOG(FATAL) << "Unknown event " << static_cast(code_); - } - CHECK_EQ(state_, kRecvCode); -} - -PackedFunc WrapTimeEvaluator(PackedFunc pf, - TVMContext ctx, - int number, - int repeat, - int min_repeat_ms) { - if (static_cast(ctx.device_type) == static_cast(kDLMicroDev)) { - auto get_micro_time_evaluator = runtime::Registry::Get("micro._GetMicroTimeEvaluator"); - CHECK(get_micro_time_evaluator != nullptr) << "micro backend not enabled"; - return (*get_micro_time_evaluator)(pf, ctx, number, repeat); - } - - auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv) mutable { - TVMRetValue temp; - std::ostringstream os; - // skip first time call, to activate lazy compilation components. - pf.CallPacked(args, &temp); - DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); - - for (int i = 0; i < repeat; ++i) { - std::chrono::time_point< - std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend; - double duration_ms = 0.0; - - do { - if (duration_ms > 0.0) { - number = static_cast( - std::max((min_repeat_ms / (duration_ms / number) + 1), - number * 1.618)); // 1.618 is chosen by random - } - - tbegin = std::chrono::high_resolution_clock::now(); - // start timing - for (int i = 0; i < number; ++i) { - pf.CallPacked(args, &temp); - } - DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); - tend = std::chrono::high_resolution_clock::now(); - - duration_ms = std::chrono::duration_cast > - (tend - tbegin).count() * 1000; - } while (duration_ms < min_repeat_ms); - - double speed = std::chrono::duration_cast >( - tend - tbegin).count() / number; - os.write(reinterpret_cast(&speed), sizeof(speed)); - } - std::string blob = os.str(); - TVMByteArray arr; - arr.size = blob.length(); - arr.data = blob.data(); - // return the time. 
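The do-while loop at the end of `WrapTimeEvaluator` above keeps enlarging the inner run count until one `repeat` lasts at least `min_repeat_ms`. The update it applies is easy to check in isolation:

```python
def next_number(number, duration_ms, min_repeat_ms):
    # Mirrors: number = max(min_repeat_ms / (duration_ms / number) + 1, number * 1.618)
    if duration_ms <= 0.0:
        return number
    per_run_ms = duration_ms / number
    return int(max(min_repeat_ms / per_run_ms + 1, number * 1.618))

# 10 runs took 2 ms in total, but one repeat must last at least 50 ms:
assert next_number(10, 2.0, 50) == 251     # 50 / 0.2 + 1
# Nearly long enough already: growth falls back to the 1.618x factor.
assert next_number(100, 90.0, 100) == 161
```

The initial warm-up call is excluded from timing, which is why the header documentation describes the total invocation count as `(1 + number x repeat)`.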
- *rv = arr; - }; - return PackedFunc(ftimer); -} - -size_t CallbackChannel::Send(const void* data, size_t size) { - TVMByteArray bytes; - bytes.data = static_cast(data); - bytes.size = size; - int64_t n = fsend_(bytes); - if (n == -1) { - support::Socket::Error("CallbackChannel::Send"); - } - return static_cast(n); -} - -size_t CallbackChannel::Recv(void* data, size_t size) { - TVMRetValue ret = frecv_(size); - - if (ret.type_code() != kTVMBytes) { - support::Socket::Error("CallbackChannel::Recv"); - } - std::string* bytes = ret.ptr(); - memcpy(static_cast(data), bytes->c_str(), bytes->length()); - return bytes->length(); +void RPCSession::InsertToSessionTable(std::shared_ptr sess) { + CHECK_EQ(sess->table_index_, 0); + sess->table_index_ = RPCSessTable::Global()->Insert(sess); } } // namespace runtime diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index db63be4be74d..e7e4433b1867 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -24,230 +24,178 @@ #ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_ #define TVM_RUNTIME_RPC_RPC_SESSION_H_ + #include #include -#include -#include +#include #include -#include -#include "../../support/ring_buffer.h" +#include namespace tvm { namespace runtime { -// Magic header for RPC data plane -const int kRPCMagic = 0xff271; -// magic header for RPC tracker(control plane) -const int kRPCTrackerMagic = 0x2f271; -// sucess response -const int kRPCSuccess = kRPCMagic + 0; -// cannot found matched key in server -const int kRPCMismatch = kRPCMagic + 2; - -/*! \brief Enumeration code for the RPC tracker */ -enum class TrackerCode : int { - kFail = -1, - kSuccess = 0, - kPing = 1, - kStop = 2, - kPut = 3, - kRequest = 4, - kUpdateInfo = 5, - kSummary = 6, - kGetPendingMatchKeys = 7 -}; -/*! \brief The remote functio handle */ -using RPCFuncHandle = void*; - -struct RPCArgBuffer; - -/*! \brief The RPC code */ -enum class RPCCode : int { - kNone, - kCallFunc, - kReturn, - kException, - kShutdown, - kCopyFromRemote, - kCopyToRemote, - kCopyAck, - // The following are code that can send over CallRemote - kSystemFuncStart, - kGetGlobalFunc, - kGetTimeEvaluator, - kFreeFunc, - kDevSetDevice, - kDevGetAttr, - kDevAllocData, - kDevFreeData, - kDevStreamSync, - kCopyAmongRemote, - kModuleLoad, - kModuleImport, - kModuleFree, - kModuleGetFunc, - kModuleGetSource, - kNDArrayFree -}; - /*! - * \brief Function that unwraps a remote object to its handle. - * \param rpc_sess_table_index RPC session table index for validation. - * \param obj Handle to the object argument. - * \return The corresponding handle. - */ -typedef void* (*FUnwrapRemoteObject)( - int rpc_sess_table_index, - const TVMArgValue& obj); - -/*! - * \brief Abstract channel interface used to create RPCSession. + * \brief The interface of all remote RPC sessions. + * + * It contains all the necessary interface to implement + * remote call and resource management. + * + * The interface is designed to allow easy proxy-chaining + * by forward requests to another RPCSession. */ -class RPCChannel { +class RPCSession { public: - /*! \brief virtual destructor */ - virtual ~RPCChannel() {} - /*! - * \brief Send data over to the channel. - * \param data The data pointer. - * \param size The size fo the data. - * \return The actual bytes sent. - */ - virtual size_t Send(const void* data, size_t size) = 0; + /*! \brief PackedFunc Handle in the remote. */ + using PackedFuncHandle = void*; + + /*! \brief Module handle in the remote. */ + using ModuleHandle = void*; + + /*! 
\brief NDArray handle in the remote. */ + using NDArrayHandle = void*; + /*! - * \brief Recv data from channel. + * \brief Callback to send an encoded return values via encode_args. + * + * \param encode_args The arguments that we can encode the return values into. + * \param ret_tcode The actual remote type code of the return value. * - * \param data The data pointer. - * \param size The size fo the data. - * \return The actual bytes received. + * Encoding convention (as list of arguments): + * - str/float/int/byte: [tcode: int, value: TVMValue] value follows PackedFunc convention. + * - PackedFunc/Module: [tcode: int, handle: void*] + * - NDArray: [tcode: int, meta: DLTensor*, nd_handle: void*] + * DLTensor* contains the meta-data as well as handle into the remote data. + * nd_handle can be used for deletion. */ - virtual size_t Recv(void* data, size_t size) = 0; -}; + using FEncodeReturn = std::function; + + /*! \brief Destructor.*/ + virtual ~RPCSession() {} -// Bidirectional Communication Session of PackedRPC -class RPCSession { - public: - /*! \brief virtual destructor */ - ~RPCSession(); /*! - * \brief The server loop that server runs to handle RPC calls. + * \brief Get function in the session. + * \param name The name of the function. + * \return The function handle. */ - void ServerLoop(); + virtual PackedFuncHandle GetFunction(const std::string& name) = 0; + /*! - * \brief Message handling function for event driven server. - * Called when the server receives a message. - * Event driven handler will never call recv on the channel - * and always relies on the ServerEventHandler. - * to receive the data. + * \brief Call into a remote Packed function. * - * \param in_bytes The incoming bytes. - * \param event_flag 1: read_available, 2: write_avaiable. - * \return State flag. - * 1: continue running, no need to write, - * 2: need to write - * 0: shutdown - */ - int ServerEventHandler(const std::string& in_bytes, - int event_flag); - /*! - * \brief Call into remote function - * \param handle The function handle - * \param args The arguments - * \param rv The return value. - * \param funpwrap Function that takes a remote object and returns the raw handle. - * \param fwrap Wrapper function to turn Function/Module handle into real return. + * Calling convention: + * + * - type_code is follows the PackedFunc convention. + * - int/float/string/bytes follows the PackedFunc convention, all data are local. + * - PackedFunc/Module and future remote objects: pass remote handle instead. + * - NDArray/DLTensor: pass a DLTensor pointer, the data field of DLTensor + * points to a remote data handle returned by the Device API. + * The meta-data of the DLTensor sits on local. + * + * The caller populates the arguments and manages these arguments. + * + * The callee can change the content of arg_values and arg_type_codes + * if they want to do inplace modify and forward. + * + * The callee need to store the return value into ret_value. + * - PackedFunc/Module are stored as void* + * - NDArray is stored as local NDArray, whose data field is a remote handle. + * Notably the NDArray's deleter won't delete remote handle. + * It is up to the user of the RPCSession to such wrapping. + * - In short, remote handles are "moved" as return values + * and the callee needs to explicitly manage them by calling + * the deleter functions when they are no longer needed. + * + * \param func The function handle. + * \param arg_values The argument values. + * \param arg_type_codes the type codes of the argument. 
+ * \param num_args Number of arguments. + * \param fencode_return The function to set the return value, + * if not called, return value is null. */ - void CallFunc(RPCFuncHandle handle, - TVMArgs args, - TVMRetValue* rv, - FUnwrapRemoteObject funwrap, - const PackedFunc* fwrap); + virtual void CallFunc(PackedFuncHandle func, + const TVMValue* arg_values, + const int* arg_type_codes, + int num_args, + const FEncodeReturn& fencode_return) = 0; + /*! * \brief Copy bytes into remote array content. - * \param from The source host data. - * \param from_offset The byte offeset in the from. - * \param to The target array. - * \param to_offset The byte offset in the to. + * \param local_from The source host data. + * \param local_from_offset The byte offeset in the from. + * \param remote_to The target array. + * \param remote_to_offset The byte offset in the to. * \param nbytes The size of the memory in bytes. - * \param ctx_to The target context. + * \param remote_ctx_to The target context. * \param type_hint Hint of content data type. */ - void CopyToRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t nbytes, - TVMContext ctx_to, - DLDataType type_hint); + virtual void CopyToRemote(void* local_from, + size_t local_from_offset, + void* remote_to, + size_t remote_to_offset, + size_t nbytes, + TVMContext remote_ctx_to, + DLDataType type_hint) = 0; /*! * \brief Copy bytes from remote array content. - * \param from The source host data. - * \param from_offset The byte offeset in the from. + * \param remote_from The source host data. + * \param remote_from_offset The byte offeset in the from. * \param to The target array. * \param to_offset The byte offset in the to. * \param nbytes The size of the memory in bytes. - * \param ctx_from The source context. + * \param remote_ctx_from The source context in the remote. * \param type_hint Hint of content data type. */ - void CopyFromRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t nbytes, - TVMContext ctx_from, - DLDataType type_hint); + virtual void CopyFromRemote(void* remote_from, + size_t remote_from_offset, + void* local_to, + size_t local_to_offset, + size_t nbytes, + TVMContext remote_ctx_from, + DLDataType type_hint) = 0; + + /*! + * \brief Free a remote function. + * \param handle The remote handle, can be NDArray/PackedFunc/Module + * \param type_code The type code of the underlying type. + */ + virtual void FreeHandle(void* handle, int type_code) = 0; + /*! - * \brief Get a remote timer function on ctx. - * This function consumes fhandle, caller should not call Free on fhandle. + * \brief Get device API that represents the remote + * actions that can be taken on the remote. * - * \param fhandle The function handle. - * \param ctx The ctx to run measurement on. - * \param number The number of times to run this function for taking average. - We call these runs as one `repeat` of measurement. - * \param repeat The number of times to repeat the measurement. - In total, the function will be invoked (1 + number x repeat) times, - where the first one is warm up and will be discarded. - The returned result contains `repeat` costs, - each of which is an average of `number` costs. - * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds. - By default, one `repeat` contains `number` runs. If this parameter is set, - the parameters `number` will be dynamically adjusted to meet the - minimum duration requirement of one `repeat`. 
- i.e., When the run time of one `repeat` falls below this time, - the `number` parameter will be automatically increased. - * \return A remote timer function + * The caller can then call into the Alloc/Free functions + * to allocate free spaces and taking the pointer as the handle. + * + * The device API is guaranteed to be alive during the + * lifetime of the Session. + * + * \param ctx The remote context. + * \param allow_missing Whether can we return nullptr if it is not available. + * + * \return The device API. */ - RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle, - TVMContext ctx, - int number, - int repeat, - int min_repeat_ms); + virtual DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) = 0; + /*! - * \brief Call a remote defined system function with arguments. - * \param fcode The function code. - * \param args The arguments - * \return The returned remote value. + * \brief Whether the session is a local session and we can directly + * the data handle returned by the session and treat it as pointer + * to the local memory. + * + * This information is useful for RPC server to directly copy into the + * local memory without creating a temporary buffer. + * + * \return Whether it is a local session. */ - template - inline TVMRetValue CallRemote(RPCCode fcode, Args&& ...args); + virtual bool IsLocalSession() const = 0; + /*! * \return The session table index of the session. */ int table_index() const { return table_index_; } - /*! - * \brief Create a RPC session with given channel. - * \param channel The communication channel. - * \param name The local name of the session, used for debug - * \param remote_key The remote key of the session - * if remote_key equals "%toinit", we need to re-intialize - * it by event handler. - */ - static std::shared_ptr Create( - std::unique_ptr channel, - std::string name, - std::string remote_key); + /*! * \brief Try get session from the global session table by table index. * \param table_index The table index of the session. @@ -256,62 +204,25 @@ class RPCSession { static std::shared_ptr Get(int table_index); private: - class EventHandler; - // Handle events until receives a return - // Also flushes channels so that the function advances. - RPCCode HandleUntilReturnEvent( - TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap); - // Initalization - void Init(); - // Shutdown - void Shutdown(); - // Internal channel. - std::unique_ptr channel_; - // Internal mutex - std::recursive_mutex mutex_; - // Internal ring buffer. - support::RingBuffer reader_, writer_; - // Event handler. - std::shared_ptr handler_; - // call remote with specified function code. - PackedFunc call_remote_; - // The index of this session in RPC session table. + /*! \brief index of this session in RPC session table */ int table_index_{0}; - // The name of the session. - std::string name_; - // The remote key - std::string remote_key_; + /*! \brief Insert the current session to the session table.*/ + static void InsertToSessionTable(std::shared_ptr sess); + // friend declaration + friend Module CreateRPCSessionModule(std::shared_ptr sess); }; /*! - * \brief RPC channel which callback - * frontend (Python/Java/etc.)'s send & recv function + * \brief Remote space handle cell used by the RPC runtime API. + * + * When we allocate space using a rpc context, the data pointer + * points to an allocated RemoteSpace. 
*/ -class CallbackChannel final : public RPCChannel { - public: - explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv) - : fsend_(std::move(fsend)), frecv_(std::move(frecv)) {} - - ~CallbackChannel() {} - /*! - * \brief Send data over to the channel. - * \param data The data pointer. - * \param size The size fo the data. - * \return The actual bytes sent. - */ - size_t Send(const void* data, size_t size) final; - /*! - * \brief Recv data from channel. - * - * \param data The data pointer. - * \param size The size fo the data. - * \return The actual bytes received. - */ - size_t Recv(void* data, size_t size) final; - - private: - PackedFunc fsend_; - PackedFunc frecv_; +struct RemoteSpace { + /*! \brief The remote data handle. */ + void* data; + /*! \brief Reference to the underlying RPC session. */ + std::shared_ptr sess; }; /*! @@ -319,18 +230,18 @@ class CallbackChannel final : public RPCChannel { * \param f The function argument. * \param ctx The context. * \param number The number of times to run this function for taking average. - We call these runs as one `repeat` of measurement. + * We call these runs as one `repeat` of measurement. * \param repeat The number of times to repeat the measurement. - In total, the function will be invoked (1 + number x repeat) times, - where the first one is warm up and will be discarded. - The returned result contains `repeat` costs, - each of which is an average of `number` costs. + * In total, the function will be invoked (1 + number x repeat) times, + * where the first one is warm up and will be discarded. + * The returned result contains `repeat` costs, + * each of which is an average of `number` costs. * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds. - By default, one `repeat` contains `number` runs. If this parameter is set, - the parameters `number` will be dynamically adjusted to meet the - minimum duration requirement of one `repeat`. - i.e., When the run time of one `repeat` falls below this time, - the `number` parameter will be automatically increased. + * By default, one `repeat` contains `number` runs. If this parameter is set, + * the parameters `number` will be dynamically adjusted to meet the + * minimum duration requirement of one `repeat`. + * i.e., When the run time of one `repeat` falls below this time, + * the `number` parameter will be automatically increased. * \return f_timer A timer function. */ PackedFunc WrapTimeEvaluator(PackedFunc f, @@ -344,21 +255,15 @@ PackedFunc WrapTimeEvaluator(PackedFunc f, * \param sess The RPC session of the global module. * \return The created module. */ -Module CreateRPCModule(std::shared_ptr sess); +Module CreateRPCSessionModule(std::shared_ptr sess); -// Remote space pointer. -struct RemoteSpace { - void* data; - std::shared_ptr sess; -}; +/*! + * \brief Get the session module from a RPC session Module. + * \param mod The input module(must be an RPCModule). + * \return The internal RPCSession. 
+ */ +std::shared_ptr RPCModuleGetSession(Module mod); -// implementation of inline functions -template -inline TVMRetValue RPCSession::CallRemote(RPCCode code, Args&& ...args) { - std::lock_guard lock(mutex_); - writer_.Write(&code, sizeof(code)); - return call_remote_(std::forward(args)...); -} } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_RPC_RPC_SESSION_H_ diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 642fbb8ec7f2..f3a30dd6c485 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -22,8 +22,11 @@ * \brief Socket based RPC implementation. */ #include +#include #include +#include "rpc_endpoint.h" #include "rpc_session.h" +#include "rpc_local_session.h" #include "../../support/socket.h" namespace tvm { @@ -61,8 +64,8 @@ class SockChannel final : public RPCChannel { support::TCPSocket sock_; }; -std::shared_ptr -RPCConnect(std::string url, int port, std::string key) { +std::shared_ptr +RPCConnect(std::string url, int port, std::string key, TVMArgs init_seq) { support::TCPSocket sock; support::SockAddr addr(url.c_str(), port); sock.Create(addr.ss_family()); @@ -96,42 +99,56 @@ RPCConnect(std::string url, int port, std::string key) { remote_key.resize(keylen); CHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen); } - return RPCSession::Create( + auto endpt = RPCEndpoint::Create( std::unique_ptr(new SockChannel(sock)), key, remote_key); + endpt->InitRemoteSession(init_seq); + return endpt; } -Module RPCClientConnect(std::string url, int port, std::string key) { - return CreateRPCModule(RPCConnect(url, port, "client:" + key)); +Module RPCClientConnect(std::string url, + int port, + std::string key, + TVMArgs init_seq) { + auto endpt = RPCConnect(url, port, "client:" + key, init_seq); + return CreateRPCSessionModule(CreateClientSession(endpt)); } // TVM_DLL needed for MSVC TVM_DLL void RPCServerLoop(int sockfd) { support::TCPSocket sock( static_cast(sockfd)); - RPCSession::Create( + RPCEndpoint::Create( std::unique_ptr(new SockChannel(sock)), "SockServerLoop", "")->ServerLoop(); } -void RPCServerLoop(PackedFunc fsend, PackedFunc frecv) { - RPCSession::Create(std::unique_ptr( - new CallbackChannel(fsend, frecv)), +void RPCServerLoop(PackedFunc fsend, + PackedFunc frecv) { + RPCEndpoint::Create( + std::unique_ptr(new CallbackChannel(fsend, frecv)), "SockServerLoop", "")->ServerLoop(); } -TVM_REGISTER_GLOBAL("rpc._Connect") -.set_body_typed(RPCClientConnect); +TVM_REGISTER_GLOBAL("rpc.Connect") +.set_body([](TVMArgs args, TVMRetValue *rv) { + std::string url = args[0]; + int port = args[1]; + std::string key = args[2]; + *rv = RPCClientConnect( + url, port, key, + TVMArgs(args.values + 3, args.type_codes + 3, args.size() - 3)); +}); -TVM_REGISTER_GLOBAL("rpc._ServerLoop") +TVM_REGISTER_GLOBAL("rpc.ServerLoop") .set_body([](TVMArgs args, TVMRetValue* rv) { - if (args.size() == 1) { - RPCServerLoop(args[0]); - } else { - CHECK_EQ(args.size(), 2); - RPCServerLoop( - args[0].operator tvm::runtime::PackedFunc(), - args[1].operator tvm::runtime::PackedFunc()); - } - }); + if (args[0].type_code() == kDLInt) { + RPCServerLoop(args[0]); + } else { + RPCServerLoop( + args[0].operator tvm::runtime::PackedFunc(), + args[1].operator tvm::runtime::PackedFunc()); + } +}); + } // namespace runtime } // namespace tvm diff --git a/src/support/arena.h b/src/support/arena.h index 744ff4f12188..b06227680808 100644 --- a/src/support/arena.h +++ b/src/support/arena.h @@ -26,42 +26,114 @@ #ifndef TVM_SUPPORT_ARENA_H_ 
#define TVM_SUPPORT_ARENA_H_ +#ifndef TVM_ARENA_HAS_DESTRUCTOR +#define TVM_ARENA_HAS_DESTRUCTOR 1 +#endif + +#include #include #include + namespace tvm { namespace support { -const constexpr int kArenaPageSize = 16 << 10; +/*! + * \brief An arena page header. + */ +struct ArenaPageHeader { + /*! \brief points to the next page. */ + ArenaPageHeader* next; + /*! + * \brief Total size of the page. + */ + size_t size; + /*! \brief memory allocator offset inside page. */ + size_t offset; +}; + +/*! + * \brief Simple page allocator that uses new and delete. + */ +class SimplePageAllocator { + public: + /*! + * \brief Allocate a new page. + * \param min_size Minimum size of the page. + * \return The allocated page. + * \note This function can return a bigger page to meet the min_size requirement. + */ + ArenaPageHeader* allocate(size_t min_size) { + size_t npages = ((min_size + kPageSize - 1) / kPageSize); + ArenaPageHeader* header = reinterpret_cast(new Page[npages]); + header->size = npages * kPageSize; + header->offset = sizeof(ArenaPageHeader); + return header; + } + /*! + * \brief De-allocate an allocate page. + * \param page The page to be de-allocated. + */ + void deallocate(ArenaPageHeader* page) { + delete [] reinterpret_cast(page); + } + + static const constexpr int kPageSize = 16 << 10; + static const constexpr int kPageAlign = 1024; + + private: + // page size 16 KB + // The page data type; + using Page = std::aligned_storage::type; +}; /*! * \brief Arena allocator that allocates memory from continuous * chunk and frees them all only during destruction. */ -class Arena { +template +class GenericArena { public: - Arena() { + explicit GenericArena(PageAllocator alloc = PageAllocator()) + : alloc_(alloc) { // eagerly allocate the first page. - head_ = reinterpret_cast(new Page()); + head_ = tail_ = alloc_.allocate(1); head_->next = nullptr; - head_->ptr = sizeof(PageHeader); } - ~Arena() { - // delete all the allocated pages. - while (head_ != nullptr) { - Page* page = reinterpret_cast(head_); - head_ = head_->next; - delete page; - } + +#if TVM_ARENA_HAS_DESTRUCTOR + ~GenericArena() { + this->FreeAll(); + } +#endif + + /*! \brief Free all pages. */ + void FreeAll() { + FreePageList(&head_); + FreePageList(&free_list_); + } + /*! \brief Recycle all the pages in the arena */ + void RecycleAll() { + // put all the current list to the free list. + tail_->next = free_list_; + // allocate the first in the free list to head + free_list_ = head_->next; + head_->next = nullptr; + // Reset the head. + head_->offset = sizeof(ArenaPageHeader); + tail_ = head_; } /*! * \brief Allocate a space from Arena for type T * \param T the data type to be allocated + * \param count Numberof elements * \note The space of T is not initialized. */ template - T* allocate_() { - return static_cast(Alloc(sizeof(T), alignof(T))); + T* allocate_(int count = 1) { + static_assert(PageAllocator::kPageAlign % alignof(T) == 0, + "To large alignment"); + return static_cast(Alloc(sizeof(T) * count, alignof(T))); } /*! * \brief Create a new instance of type T. @@ -82,25 +154,21 @@ class Arena { } private: - // page size 16 KB - // The page data type; - using Page = std::aligned_storage::type; - /*! \brief Page header */ - struct PageHeader { - /*! \brief points to the next page */ - PageHeader* next; - /*! \brief memory allocator ptr inside page */ - size_t ptr; - }; - /* \brief The page header */ - PageHeader* head_{nullptr}; + /*! \brief internal page allocator. 
*/ + PageAllocator alloc_; + /* \brief The the head of the allocated list. */ + ArenaPageHeader* head_{nullptr}; + /*! \brief The tail of the allocated list. */ + ArenaPageHeader* tail_{nullptr}; + /* \brief List of free pages. */ + ArenaPageHeader* free_list_{nullptr}; /*! * \brief Align ptr by upper bound. - * \param ptr The pointer value. + * \param offset The offset value. * \param align The alignment requirement. */ - size_t UpperAlign(size_t ptr, size_t align) { - return ptr + (align - (ptr % align)) % align; + size_t UpperAlign(size_t offset, size_t align) { + return offset + (align - (offset % align)) % align; } /*! * \brief Internal aligned alloc function. @@ -108,22 +176,41 @@ class Arena { * \param align The alignment requirement. */ void* Alloc(size_t size, size_t align) { - size_t ptr = UpperAlign(head_->ptr, align); - if (ptr + size <= kArenaPageSize) { - head_->ptr = ptr + size; - return reinterpret_cast(head_) + ptr; + size_t offset = UpperAlign(head_->offset, align); + if (offset + size <= head_->size) { + head_->offset = offset + size; + return reinterpret_cast(head_) + offset; } else { - PageHeader* new_head = reinterpret_cast(new Page()); + ArenaPageHeader* new_head; + offset = UpperAlign(sizeof(ArenaPageHeader), align); + if (free_list_ != nullptr && offset + size <= free_list_-> size) { + new_head = free_list_; + free_list_ = free_list_->next; + } else { + new_head = alloc_.allocate(offset + size); + } new_head->next = head_; - ptr = UpperAlign(sizeof(PageHeader), align); - CHECK_LE(ptr + size, kArenaPageSize); - new_head->ptr = ptr + size; + new_head->offset = offset + size; head_ = new_head; - return reinterpret_cast(head_) + ptr; + return reinterpret_cast(head_) + offset; + } + } + /*! + * \brief Free all the pages in the list. + * \param ptr The head ptr. + */ + void FreePageList(ArenaPageHeader** ptr) { + // delete all the allocated pages. + while (ptr[0] != nullptr) { + ArenaPageHeader* temp = ptr[0]; + ptr[0] = ptr[0]->next; + alloc_.deallocate(temp); } } }; +using Arena = GenericArena; + /*! * \brief Link list node * \tparam T the content data type diff --git a/src/support/ring_buffer.h b/src/support/ring_buffer.h index e6e3b04ec7a9..d3227adb1f9d 100644 --- a/src/support/ring_buffer.h +++ b/src/support/ring_buffer.h @@ -49,8 +49,12 @@ class RingBuffer { return ring_.size(); } /*! - * Reserve capacity to be at least n. - * Will only increase capacity if n is bigger than current capacity. + * Reserve capacity to be at least n. + * Will only increase capacity if n is bigger than current capacity. + * + * The effect of Reserve only lasts before the next call to Reserve. + * Other functions in the ring buffer can also call into the reserve. + * * \param n The size of capacity. */ void Reserve(size_t n) { @@ -63,19 +67,27 @@ class RingBuffer { size_t ncopy = head_ptr_ + bytes_available_ - old_size; memcpy(&ring_[0] + old_size, &ring_[0], ncopy); } - } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity && bytes_available_ > 0) { - // shrink too large temporary buffer to avoid out of memory on some embedded devices + } else if (ring_.size() > n * 8 && + ring_.size() > kInitCapacity) { + // shrink too large temporary buffer to + // avoid out of memory on some embedded devices + if (bytes_available_ != 0) { + // move existing bytes to the head. 
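The arena's `UpperAlign` helper earlier in this hunk rounds an offset up to the next multiple of the alignment. The same expression in Python, with a couple of spot checks:

```python
def upper_align(offset, align):
    # offset + (align - offset % align) % align, i.e. round up to a multiple of align.
    return offset + (align - offset % align) % align

assert upper_align(24, 8) == 24    # already aligned: unchanged
assert upper_align(25, 8) == 32    # bumped to the next 8-byte boundary
assert upper_align(0, 16) == 0
```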
size_t old_bytes = bytes_available_; - std::vector tmp(old_bytes); - Read(&tmp[0], old_bytes); - ring_.resize(kInitCapacity); - ring_.shrink_to_fit(); memcpy(&ring_[0], &tmp[0], old_bytes); - head_ptr_ = 0; bytes_available_ = old_bytes; + } + // shrink the ring. + size_t new_size = kInitCapacity; + new_size = std::max(new_size, n); + new_size = std::max(new_size, bytes_available_); + + ring_.resize(new_size); + ring_.shrink_to_fit(); + head_ptr_ = 0; } } @@ -137,7 +149,7 @@ class RingBuffer { bytes_available_ += size; } /*! - * \brief Writen data into the buffer by give it a non-blocking callback function. + * \brief Written data into the buffer by give it a non-blocking callback function. * * \param frecv A receive function handle * \param max_nbytes Maximum number of bytes can write. @@ -168,9 +180,9 @@ class RingBuffer { private: // buffer head size_t head_ptr_{0}; - // number of bytes in the buffer. + // number of bytes occupied in the buffer. size_t bytes_available_{0}; - // The internald ata ring. + // The internal data ring. std::vector ring_; }; } // namespace support diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 4cf5ccdd081c..796e39b01d58 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -76,11 +76,10 @@ class DataTypeVisitor final : public StmtExprVisitor { void VisitExpr(const PrimExpr& e) { if (e.dtype().is_int()) { int bits = max_bits_; - const PrimExprNode* op = e.as(); - if (bound_.find(op) == bound_.end()) { + if (bound_.find(e) == bound_.end()) { analyzer_.const_int_bound(e, &bound_); } - ConstIntBound bound = bound_[op]; + ConstIntBound bound = bound_[e]; int64_t ubound = Downcast(max_value(DataType::Int(target_bits_)))->value; int64_t lbound = Downcast(min_value(DataType::Int(target_bits_)))->value; if (e.dtype().bits() <= target_bits_ || @@ -187,7 +186,7 @@ class DataTypeVisitor final : public StmtExprVisitor { // the extent of vars to be rewritten std::unordered_map vextent_; // the memorized bound generated by ConstIntBoundAnalyzer - std::unordered_map bound_; + arith::ConstIntBoundAnalyzer::BoundMapType bound_; }; class DataTypeRewriter : public StmtExprMutator { diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index 04d6c94e10e8..da3a456dafb6 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -36,6 +36,7 @@ "scala", "java", "go", + "ts", "sh", "py", "pyi", @@ -81,6 +82,7 @@ # List of file names allowed ALLOW_FILE_NAME = { ".gitignore", + ".eslintignore", ".gitattributes", "README", "Makefile", @@ -107,8 +109,7 @@ "rust/runtime/tests/test_wasm32/.cargo/config", "apps/sgx/.cargo/config", # html for demo purposes - "tests/webgl/test_static_webgl_library.html", - "web/example_rpc.html", + "web/apps/browser/rpc_server.html", # images are normally not allowed # discuss with committers before add more images "apps/android_rpc/app/src/main/res/mipmap-hdpi/ic_launcher.png", diff --git a/tests/lint/rat-excludes b/tests/lint/rat-excludes index 5421d22a08aa..0714850287f3 100644 --- a/tests/lint/rat-excludes +++ b/tests/lint/rat-excludes @@ -28,6 +28,8 @@ core.cpp build _static _build +node_modules +dist .*~ \#..*\# \.#.* @@ -40,6 +42,7 @@ RelayVisitor.py # Specific files package-list MANIFEST +.eslintignore .gitignore .gitattributes .gitmodules diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 1185a5c6bf5d..8bc9f458153f 100644 --- 
a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -859,21 +859,22 @@ def _test_upsample_bilinear_opset9(): in_shape = (1, 1, 3, 3) out_shape = (1, 1, 3*scale, 3*scale) y = helper.make_node("Upsample", ['in', 'scales'], ['out'], mode='linear') - scales = [1.0, 1.0, 2.0, 2.0] + scales = [1, 1, 2, 2] in_array = np.random.uniform(size=in_shape).astype(np.float32) out_array = topi.testing.bilinear_resize_python( in_array, (3*scale, 3*scale), "NCHW") - ref_array = np.array(scales) ref_node = helper.make_node('Constant', inputs=[], - outputs=['scales'], + outputs=['const'], value=onnx.helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, - dims=ref_array.shape, - vals=ref_array.flatten().astype(float))) + dims=scales, + vals=np.random.random(scales).flatten().astype(float))) - graph = helper.make_graph([ref_node, y], + shape_node = helper.make_node("Shape", ['const'], ['scales']) + + graph = helper.make_graph([ref_node, shape_node, y], 'upsample_bilinear_opset9_test', inputs=[helper.make_tensor_value_info( "in", TensorProto.FLOAT, list(in_shape))], @@ -1278,7 +1279,63 @@ def verify_pad(indata, pads, mode='constant', value=0.0): # tvm result for target, ctx in ctx_list(): tvm_out = get_tvm_output( - model, indata, target, ctx, outdata.shape, 'float32') + model, indata, target, ctx, outdata.shape, 'float32', opset=2) + tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) + + +def verify_pad_v11(indata, pads, mode='constant', value=0.0): + indata = np.array(indata).astype(np.float32) + # numpy expect result + len_dim = len(pads) // 2 + np_pads = [(pads[i], pads[i+len_dim]) for i in range(len_dim)] + pads = np.array(pads) + # onnx graph + if mode in ['edge', 'reflect']: + inputs = [indata, pads] + outdata = np.pad(indata, pad_width=np_pads, mode=mode) + node = helper.make_node( + 'Pad', + inputs=['input', 'pads'], + outputs=['output'], + mode=mode + ) + graph = helper.make_graph([node], + 'pad_test', + inputs=[helper.make_tensor_value_info("input", + TensorProto.FLOAT, list(indata.shape)), + helper.make_tensor_value_info("pads", + TensorProto.INT64,(len(pads),))], + initializer=[helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads)], + outputs=[helper.make_tensor_value_info("output", + TensorProto.FLOAT, list(outdata.shape))]) + else: + inputs = [indata, pads, np.array([value])] + outdata = np.pad(indata, pad_width=np_pads, + mode='constant', constant_values=value) + node = helper.make_node( + 'Pad', + inputs=['input', 'pads', 'constant_value'], + outputs=['output'], + mode='constant' + ) + graph = helper.make_graph([node], + 'pad_test', + inputs=[helper.make_tensor_value_info("input", + TensorProto.FLOAT, list(indata.shape)), + helper.make_tensor_value_info("pads", + TensorProto.INT64,(len(pads),)), + helper.make_tensor_value_info("constant_value", + TensorProto.INT64,(1,)), + ], + initializer=[helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads), + helper.make_tensor("constant_value", TensorProto.FLOAT, (1,), [value])], + outputs=[helper.make_tensor_value_info("output", + TensorProto.FLOAT, list(outdata.shape))]) + model = helper.make_model(graph, producer_name='pad_test') + # tvm result + for target, ctx in ctx_list(): + tvm_out = get_tvm_output( + model, inputs, target, ctx, outdata.shape, 'float32', opset=11) tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) @@ -1294,6 +1351,17 @@ def test_pad(): verify_pad(np.random.randn(1, 3, 4, 5).astype( np.float32), [0, 0, 1, 1, 0, 0, 1, 
1], 'reflect') + verify_pad_v11(np.random.randn(2, 2).astype( + np.float32), [0, 1, 0, 0], 'constant', 0.0) + verify_pad_v11(np.random.randn(2, 3).astype( + np.float32), [1, 0, 0, 1], 'constant', 0.0) + verify_pad_v11(np.random.randn(3, 2).astype( + np.float32), [0, 0, 1, 0], 'constant', 5.0) + verify_pad_v11(np.random.randn(1, 3, 4, 5).astype( + np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'edge') + verify_pad_v11(np.random.randn(1, 3, 4, 5).astype( + np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect') + def verify_reduce_x(name, indata, axis, keepdims): indata = np.array(indata).astype(np.float32) @@ -1306,6 +1374,15 @@ def verify_reduce_x(name, indata, axis, keepdims): outdata = np.sum(indata, axis=axis, keepdims=keepdims == 1) elif name == 'ReduceMean': outdata = np.mean(indata, axis=axis, keepdims=keepdims == 1) + elif name == 'ReduceLogSumExp': + def _np_log_sum_exp(x, axis, keepdims=False): + max_x = np.max(x, axis=axis, keepdims=True) + x = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True)) + x = x + max_x + if not keepdims: + x = np.squeeze(x, axis=axis) + return x + outdata = _np_log_sum_exp(indata, axis=axis, keepdims=keepdims == 1) else: raise Exception('unsupport op: {}'.format(name)) if len(np.asarray(outdata).shape) == 0: @@ -1379,6 +1456,34 @@ def test_reduce_mean(): axis=(1,), keepdims=1) +def test_reduce_logsumexp(): + + for keepdims in [True, False]: + verify_reduce_x("ReduceLogSumExp", + np.random.randn(3, 2, 2).astype(np.float32), + axis=None, keepdims=keepdims) + + verify_reduce_x("ReduceLogSumExp", + np.random.randn(3, 2, 3).astype(np.float32), + axis=None, keepdims=keepdims) + + verify_reduce_x("ReduceLogSumExp", + np.random.randn(3, 3, 3).astype(np.float32), + axis=(1,), keepdims=keepdims) + + verify_reduce_x("ReduceLogSumExp", + np.random.randn(3, 3, 3, 1).astype(np.float32), + axis=(1, 2), keepdims=keepdims) + + verify_reduce_x("ReduceLogSumExp", + np.random.randn(3, 3, 3, 1).astype(np.float32), + axis=(1), keepdims=keepdims) + + verify_reduce_x("ReduceLogSumExp", + np.random.randn(1, 3, 4, 1).astype(np.float32), + axis=(1), keepdims=keepdims) + + def verify_split(indata, outdatas, split, axis=0): indata = np.array(indata).astype(np.float32) outdatas = [np.array(o).astype(np.float32) for o in outdatas] @@ -1532,6 +1637,34 @@ def selu_x(x, alpha, gamma): {'alpha': 0.25, 'gamma': 0.3}) +def test_prelu(): + def verify_prelu(x_shape, a_shape): + node = helper.make_node('PRelu', + inputs=['X', 'slope'], + outputs=['Y']) + + graph = helper.make_graph([node], + "prelu_test", + inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("slope", TensorProto.FLOAT, list(a_shape))], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(x_shape))]) + + model = helper.make_model(graph, producer_name='prelu_test') + + indata = np.random.uniform(-10, 10, x_shape).astype(np.float32) + slopedata = np.random.uniform(-10, 10, a_shape).astype(np.float32) + onnx_out = get_onnxruntime_output(model, [indata, slopedata]) + + for target, ctx in [('llvm', tvm.cpu())]: + tvm_out = get_tvm_output(model, [indata, slopedata], target, ctx, list(x_shape), + output_dtype='float32') + tvm.testing.assert_allclose(onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05) + + verify_prelu([3,4,5,6], [1, 4, 1, 1]) + verify_prelu([1,8,5,6], [1, 8, 1, 1]) + verify_prelu([2,12,16,16], [1, 12, 1, 1]) + + def test_ThresholdedRelu(): def ThresholdedRelu_x(x, alpha): out_np = np.clip(x, alpha, np.inf) @@ -2528,6 +2661,7 @@ def 
verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ test_reduce_min() test_reduce_sum() test_reduce_mean() + test_reduce_logsumexp() test_pad() test_split() test_binary_ops() @@ -2535,6 +2669,7 @@ def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ test_leaky_relu() test_elu() test_selu() + test_prelu() test_ThresholdedRelu() test_ScaledTanh() test_ParametricSoftplus() diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 283d87d5078a..da89a139c113 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -73,6 +73,16 @@ def get_real_image(im_height, im_width): data = np.reshape(x, (1, im_height, im_width, 3)) return data +def get_real_image_object_detection(im_height, im_width): + repo_base = 'https://github.com/dmlc/web-data/raw/master/gluoncv/detection/' + img_name = 'street_small.jpg' + image_url = os.path.join(repo_base, img_name) + img_path = download_testdata(image_url, img_name, module='data') + image = Image.open(img_path).resize((im_height, im_width)) + x = np.array(image).astype('uint8') + data = np.reshape(x, (1, im_height, im_width, 3)) + return data + def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm', out_names=None): """ Generic function to compile on relay and execute on tvm """ @@ -98,6 +108,7 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict) + with relay.build_config(opt_level=3): graph, lib, params = relay.build(mod, target, params=params) @@ -419,6 +430,29 @@ def test_forward_cast(): _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8) _test_cast(np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64) +####################################################################### +# Batch Mat Mul +# ---- +def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False): + with tf.Graph().as_default(): + A = array_ops.placeholder(shape=A_shape, dtype=dtype, name='A') + B = array_ops.placeholder(shape=B_shape, dtype=dtype, name='B') + result = math_ops.matmul(A, B, adjoint_a=adjoint_a, + adjoint_b=adjoint_b, name='batchmatmul') + + A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype) + B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) + compare_tflite_with_tvm([A_np, B_np], [A.name, B.name], [A, B], [result]) + + +def test_forward_batch_matmul(): + """ BATCH_MAT_MUL """ + _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32') + _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32', True, True) + _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', True, False) + _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True) + _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'float32') + ####################################################################### # Tile # ---- @@ -1176,6 +1210,43 @@ def test_all_elemwise(): _test_forward_elemwise(_test_floor_divide) _test_forward_elemwise(_test_floor_mod) + +####################################################################### +# AddN +# ---- + + +def _test_forward_add_n(inputs): + tf.reset_default_graph() + with tf.Graph().as_default(): + temp = [] + for each in inputs: + temp.append(tf.placeholder(shape=each.shape, dtype=each.dtype)) + output = tf.add_n(temp) + compare_tflite_with_tvm([each for each in inputs], 
[each.name for each in temp], + [each for each in temp], [output]) + + +def test_forward_add_n(): + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + x = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32) + y = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32) + z = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32) + m, n, o = x.astype(np.float32), y.astype(np.float32), z.astype(np.float32) + in0 = x + in1 = [x, y] + in2 = (x, y, z) + in3 = m + in4 = [m, n] + in5 = (m, n, o) + _test_forward_add_n(in0) + _test_forward_add_n(in1) + _test_forward_add_n(in2) + _test_forward_add_n(in3) + _test_forward_add_n(in4) + _test_forward_add_n(in5) + + ####################################################################### # Logical operators # ----------------- @@ -1378,6 +1449,27 @@ def test_all_reduce(): ####################################################################### +# Select, Where +# ------------- + +def test_forward_select(): + with tf.Graph().as_default(): + with tf.Session() as sess: + input1 = tf.placeholder( + tf.int32, shape=[1, 4, 4, 3], name='input1') + input2 = tf.placeholder( + tf.int32, shape=[1, 4, 4, 3], name='input2') + mask = input1 > input2 + out = tf.where(mask, input1 + 1, input2 * 2) + in_data1 = np.random.uniform( + 0, 10, size=(1, 4, 4, 3)).astype("int32") + in_data2 = np.random.uniform( + 0, 10, size=(1, 4, 4, 3)).astype("int32") + + compare_tflite_with_tvm([in_data1, in_data2], [ + 'input1:0', 'input2:0'], [input1, input2], [out]) + + # Squeeze # ------- @@ -1741,23 +1833,30 @@ def test_detection_postprocess(): tflite_output = run_tflite_graph(tflite_model, [box_encodings, class_predictions]) tvm_output = run_tvm_graph(tflite_model, [box_encodings, class_predictions], ["raw_outputs/box_encodings", "raw_outputs/class_predictions"], num_output=4) - # check valid count is the same + + # Check all output shapes are equal + assert all([tvm_tensor.shape == tflite_tensor.shape \ + for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)]) + + # Check valid count is the same assert tvm_output[3] == tflite_output[3] - # check all the output shapes are the same - assert tvm_output[0].shape == tflite_output[0].shape - assert tvm_output[1].shape == tflite_output[1].shape - assert tvm_output[2].shape == tflite_output[2].shape valid_count = tvm_output[3][0] - # only check the valid detections are the same - # tvm has a different convention to tflite for invalid detections, it uses all -1s whereas - # tflite appears to put in nonsense data instead - tvm_boxes = tvm_output[0][0][:valid_count] - tvm_classes = tvm_output[1][0][:valid_count] - tvm_scores = tvm_output[2][0][:valid_count] - # check the output data is correct - tvm.testing.assert_allclose(np.squeeze(tvm_boxes), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) - tvm.testing.assert_allclose(np.squeeze(tvm_classes), np.squeeze(tflite_output[1]), rtol=1e-5, atol=1e-5) - tvm.testing.assert_allclose(np.squeeze(tvm_scores), np.squeeze(tflite_output[2]), rtol=1e-5, atol=1e-5) + + # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare + # tflite and tvm tensors for only valid boxes. 
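For the `test_forward_select` case added earlier in this hunk, the expected output can be spelled out directly with numpy. A sketch of that reference computation, with illustrative array names:

```python
import numpy as np

in_data1 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("int32")
in_data2 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("int32")

# Same computation the TF graph expresses with tf.where(mask, input1 + 1, input2 * 2).
mask = in_data1 > in_data2
expected = np.where(mask, in_data1 + 1, in_data2 * 2)
```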
+ for i in range(0, valid_count): + # Check bounding box co-ords + tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]), + rtol=1e-5, atol=1e-5) + + # Check the class + # Stricter check to ensure class remains same + np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), + np.squeeze(tflite_output[1][0][i])) + + # Check the score + tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]), + rtol=1e-5, atol=1e-5) ####################################################################### @@ -1942,6 +2041,100 @@ def test_forward_qnn_mobilenet_v3_net(): tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) +####################################################################### +# Quantized SSD Mobilenet +# ----------------------- + +def test_forward_qnn_coco_ssd_mobilenet_v1(): + """Test the quantized Coco SSD Mobilenet V1 TF Lite model.""" + pytest.skip("LLVM bug - getExtendedVectorNumElements - " + + "https://discuss.tvm.ai/t/segfault-in-llvm/3567. The workaround is to use a " + + "specific target, for example, llvm -mpcu=core-avx2") + + tflite_model_file = tf_testing.get_workload_official( + "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip", + "detect.tflite") + + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + + data = get_real_image_object_detection(300, 300) + tflite_output = run_tflite_graph(tflite_model_buf, data) + tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=4) + + # Check all output shapes are equal + assert all([tvm_tensor.shape == tflite_tensor.shape \ + for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)]) + + # Check valid count is the same + assert tvm_output[3] == tflite_output[3] + valid_count = tvm_output[3][0] + + # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare + # tflite and tvm tensors for only valid boxes. + for i in range(0, valid_count): + # We compare the bounding boxes whose prediction score is above 60%. This is typical in end + # to end application where a low prediction score is discarded. This is also needed because + # multiple low score bounding boxes can have same score and TFlite and TVM can have + # different orderings for same score bounding boxes. Another reason for minor differences in + # low score bounding boxes is the difference between TVM and TFLite for requantize operator. + if tvm_output[2][0][i] > 0.6: + # Check bounding box co-ords. The tolerances have to be adjusted, from 1e-5 to 1e-2, + # because of differences between for requantiize operator in TFLite and TVM. 
+ tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), + np.squeeze(tflite_output[0][0][i]), + rtol=1e-2, atol=1e-2) + + # Check the class + # Stricter check to ensure class remains same + np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), + np.squeeze(tflite_output[1][0][i])) + + # Check the score + tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), + np.squeeze(tflite_output[2][0][i]), + rtol=1e-5, atol=1e-5) + + +####################################################################### +# SSD Mobilenet +# ------------- + +def test_forward_coco_ssd_mobilenet_v1(): + """Test the FP32 Coco SSD Mobilenet V1 TF Lite model.""" + tflite_model_file = tf_testing.get_workload_official( + "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tgz", + "ssd_mobilenet_v1_coco_2018_01_28.tflite") + + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + + np.random.seed(0) + data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32') + tflite_output = run_tflite_graph(tflite_model_buf, data) + tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=4) + + # Check all output shapes are equal + assert all([tvm_tensor.shape == tflite_tensor.shape \ + for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)]) + + # Check valid count is the same + assert tvm_output[3] == tflite_output[3] + valid_count = tvm_output[3][0] + + # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare + # tflite and tvm tensors for only valid boxes. + for i in range(0, valid_count): + # Check bounding box co-ords + tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]), + rtol=1e-5, atol=1e-5) + # Check the class + np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i])) + + # Check the score + tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]), + rtol=1e-5, atol=1e-5) + ####################################################################### # MediaPipe # ------------- @@ -1961,6 +2154,7 @@ def test_forward_mediapipe_hand_landmark(): tvm.testing.assert_allclose(np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]), rtol=1e-5, atol=1e-5) + ####################################################################### # Main # ---- @@ -1980,6 +2174,9 @@ def test_forward_mediapipe_hand_landmark(): # Cast test_forward_cast() + # BatchMatMul + test_forward_batch_matmul() + # Tile test_forward_tile() @@ -1997,6 +2194,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_stridedslice() test_forward_depthtospace() test_forward_spacetodepth() + test_forward_select() # NN test_forward_convolution() @@ -2014,6 +2212,7 @@ def test_forward_mediapipe_hand_landmark(): # Elemwise test_all_elemwise() + test_forward_add_n() # Unary elemwise test_all_unary_elemwise() @@ -2038,6 +2237,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_mobilenet_v3() test_forward_inception_v3_net() test_forward_inception_v4_net() + test_forward_coco_ssd_mobilenet_v1() test_forward_mediapipe_hand_landmark() # End to End quantized @@ -2047,3 +2247,4 @@ def test_forward_mediapipe_hand_landmark(): #This also fails with a segmentation fault in my run #with Tflite 1.15.2 test_forward_qnn_mobilenet_v3_net() + test_forward_qnn_coco_ssd_mobilenet_v1() diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py 
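The three detection tests above share the same comparison policy: match output shapes and valid_count, then compare only the valid boxes, optionally above a score cut-off for the quantized model. As a reading aid, that policy amounts to a helper roughly like the sketch below; compare_valid_detections and its parameters are hypothetical names, not part of this patch, and the module-level numpy/tvm imports of the test file are assumed.

def compare_valid_detections(tvm_output, tflite_output, box_rtol=1e-5, score_threshold=None):
    # Outputs are (boxes, classes, scores, valid_count); shapes must agree.
    assert all(t.shape == f.shape for t, f in zip(tvm_output, tflite_output))
    # Same number of valid detections.
    assert tvm_output[3] == tflite_output[3]
    valid_count = tvm_output[3][0]
    for i in range(valid_count):
        # Optionally skip low-score boxes (used for the quantized model, where
        # requantization differences can reorder boxes with near-identical low scores).
        if score_threshold is not None and tvm_output[2][0][i] <= score_threshold:
            continue
        # Bounding box coordinates, then class id, then score.
        tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]),
                                    np.squeeze(tflite_output[0][0][i]),
                                    rtol=box_rtol, atol=box_rtol)
        np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]),
                                np.squeeze(tflite_output[1][0][i]))
        tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]),
                                    np.squeeze(tflite_output[2][0][i]),
                                    rtol=1e-5, atol=1e-5)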
index bbe2c69d6294..947a4bfd0b3b 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -165,7 +165,10 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32") dtype = "bool" if ref_func in [np.all, np.any] else dtype x = relay.var("x", relay.TensorType(data, dtype)) - z = test_func(x, axis, keepdims, exclude) + if test_func == relay.logsumexp: + z = test_func(x, axis, keepdims) + else: + z = test_func(x, axis, keepdims, exclude) zz = run_infer_type(z) if axis: assert "axis=" in z.astext() @@ -215,6 +218,14 @@ def _wrapper(data, axis=None, keepdims=False): return func(data, axis=axis).reshape(out_shape) return _wrapper + def _np_log_sum_exp(x, axis, keepdims=False): + max_x = np.max(x, axis=axis, keepdims=True) + x = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True)) + x = x + max_x + if not keepdims: + x = np.squeeze(x, axis=axis) + return x + d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4") for func in [[relay.sum, np.sum], [relay.max, np.max], @@ -225,6 +236,7 @@ def _wrapper(data, axis=None, keepdims=False): [relay.prod, np.prod], [relay.all, np.all], [relay.any, np.any], + [relay.logsumexp, _np_log_sum_exp], [relay.argmin, _with_keepdims(np.argmin)], [relay.argmax, _with_keepdims(np.argmax)]]: verify_reduce(func, (d1, d2, d3, d4), None, False, False, ()) diff --git a/tests/python/relay/test_op_qnn_concatenate.py b/tests/python/relay/test_op_qnn_concatenate.py index 03ab9eeb1321..fb60e9805206 100644 --- a/tests/python/relay/test_op_qnn_concatenate.py +++ b/tests/python/relay/test_op_qnn_concatenate.py @@ -144,7 +144,32 @@ def test_same_i_qnn_params(): op_res = intrp.evaluate(func)(x_data, y_data) np.testing.assert_equal(op_res.asnumpy(), golden_output) +def test_call_input(): + # This tests the case where the input to concatenate is not explicitly a + # tuple node but is instead a call node. + x_data = np.ones(shape=(64,)).astype('uint8') + + x = relay.var("x", shape=(64,), dtype='uint8') + x_scale = relay.const(1, 'float32') + y_scale = relay.const(1, 'float32') + x_zero_point = relay.const(0, 'int32') + y_zero_point = relay.const(0, 'int32') + + tup = relay.split(x, 2, axis=0) + z = relay.qnn.op.concatenate(tup, + input_scales=(x_scale, y_scale), + input_zero_points=(x_zero_point, y_zero_point), + output_scale=y_scale, + output_zero_point=relay.const(0, 'int32'), + axis=0) + func = relay.Function([x], z) + + intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm") + op_res = intrp.evaluate(func)(x_data) + np.testing.assert_equal(op_res.asnumpy(), x_data) + if __name__ == '__main__': + test_call_input() test_same_io_qnn_params() test_different_io_qnn_params() test_few_same_io_qnn_params() diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 9b18f72cb6e7..bc0420f26d9b 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -153,6 +153,56 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_alter_layout_lrn(): + """Test alternating the layout of a conv2d. + The layout of broadcast operators and the weight should be changed accordingly. 
+ """ + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + bias = relay.var("bias") + weight = relay.var("weight") + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + y = relay.nn.max_pool2d(y, pool_size=(2, 2)) + y = relay.nn.lrn(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + def alter_conv2d(attrs, inputs, tinfos, out_type): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + new_attrs['kernel_layout'] = 'OIHW16i' + return relay.nn.conv2d(data, weight, **new_attrs) + + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + bias = relay.var("bias", shape=(64,)) + weight = relay.var("weight", shape=(64, 64, 3, 3)) + + y = relay.layout_transform(x, "NCHW", "NCHW16c") + w = relay.layout_transform(weight, "OIHW", "OIHW16i") + y = relay.nn.conv2d(y, w, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + kernel_layout="OIHW16i", + data_layout="NCHW16c") + y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NCHW16c") + y = relay.layout_transform(y, "NCHW16c", "NCHW") + y = relay.nn.lrn(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): + a = before() + a = run_opt_pass(a, [transform.CanonicalizeOps(), + transform.AlterOpLayout()]) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + def test_alter_layout_dual_path(): """ @@ -1027,6 +1077,7 @@ def expected(): test_alter_return_none() test_alter_layout() test_alter_layout_dual_path() + test_alter_layout_lrn() test_alter_layout_resnet() test_alter_layout_broadcast_op() test_alter_layout_broadcast_scalar_op() diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 0dcf1fb5344c..179152273c00 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -126,7 +126,7 @@ def test_floormod_simplify(): x, y = te.var("x"), te.var("y") ck.verify(flm(flm((x*4) + y - 466036, 24528) - 24512, 16), flm((x*4) + y + 12, 16)) - + ck.verify(flm(flm((x*4), 16), 8), flm(x, 2) * 4) def test_canonical_mixed(): diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index b61e6bb9fa01..17321bdeb293 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -18,10 +18,12 @@ from tvm import te import tvm.testing import os +import stat import logging import time import multiprocessing +import pytest import numpy as np from tvm import rpc from tvm.contrib import util @@ -77,11 +79,9 @@ def remotethrow(name): f1 = client.get_function("rpc.test.addone") assert f1(10) == 11 f3 = client.get_function("rpc.test.except") - try: + + with pytest.raises(tvm.error.RPCError): f3("abc") - assert False - except tvm.error.TVMError as e: - assert "abc" in str(e) f2 = client.get_function("rpc.test.strcat") assert f2("abc", 11) == "abc:11" @@ -101,6 +101,58 @@ def remote_array_func(y): fremote = remote.get_function("rpc.test.remote_array_func") fremote(r_cpu) + +def test_rpc_large_array(): + # testcase of large array creation + server = rpc.Server("localhost") + remote = rpc.connect(server.host, server.port) + ctx = remote.cpu(0) + a_np = np.ones((5041, 720)).astype('float32') + b_np = np.ones((720, 192)).astype('float32') + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + 
np.testing.assert_equal(a.asnumpy(), a_np) + np.testing.assert_equal(b.asnumpy(), b_np) + + +def test_rpc_echo(): + def check(remote): + fecho = remote.get_function("testing.echo") + assert(fecho(1, 2, 3) == 1) + assert(fecho(100, 2, 3) == 100) + assert(fecho("xyz") == "xyz") + assert(bytes(fecho(bytearray(b"123"))) == b"123") + + with pytest.raises(RuntimeError): + raise_err = remote.get_function( + "testing.test_raise_error_callback")("RuntimeError") + raise_err() + + remote.cpu().sync() + with pytest.raises(AttributeError): + f3 = remote.system_lib()["notexist"] + + + temp = rpc.server._server_env([]) + server = rpc.Server("localhost") + client = rpc.connect(server.host, server.port) + check(rpc.LocalSession()) + + check(client) + # Test minrpc server. + temp = util.tempdir() + minrpc_exec = temp.relpath("minrpc") + tvm.rpc.with_minrpc("g++")(minrpc_exec, []) + check(rpc.PopenSession(minrpc_exec)) + # minrpc on the remote + server = rpc.Server("localhost") + client = rpc.connect( + server.host, server.port, + session_constructor_args=["rpc.PopenSession", + open(minrpc_exec, "rb").read()]) + check(client) + + def test_rpc_file_exchange(): if not tvm.runtime.enabled("rpc"): return @@ -114,14 +166,20 @@ def test_rpc_file_exchange(): def test_rpc_remote_module(): if not tvm.runtime.enabled("rpc"): return - server = rpc.Server("localhost") - client = rpc.connect(server.host, server.port) # graph - n = tvm.runtime.convert(1024) + n = tvm.runtime.convert(102) A = te.placeholder((n,), name='A') B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') s = te.create_schedule(B.op) + server0 = rpc.Server("localhost", key="x0") + server1 = rpc.Server("localhost", key="x1") + + client = rpc.connect( + server0.host, server0.port, key="x0", + session_constructor_args=[ + "rpc.Connect", server1.host, server1.port, "x1"]) + def check_remote(remote): if not tvm.runtime.enabled("llvm"): print("Skip because llvm is not enabled") @@ -133,13 +191,45 @@ def check_remote(remote): f.export_library(path_dso) remote.upload(path_dso) f1 = remote.load_module("dev_lib.so") - a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10) cost = time_f(a, b).mean print('%g secs/op' % cost) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + def check_minrpc(): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + if tvm.get_global_func("rpc.PopenSession", allow_missing=True) is None: + return + # export to minrpc + temp = util.tempdir() + f = tvm.build(s, [A, B], "llvm --system-lib", name="myadd") + path_minrpc = temp.relpath("dev_lib.minrpc") + f.export_library(path_minrpc, rpc.with_minrpc("g++")) + + with pytest.raises(RuntimeError): + rpc.PopenSession("filenotexist") + + # statrt the minrpc session. 
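+ # (PopenSession is expected to launch the exported minrpc binary as a child
+ # process and run the RPC protocol over a pipe to it, so this check needs no
+ # TCP server or tracker.)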
+ remote = tvm.rpc.PopenSession(path_minrpc) + ctx = remote.cpu(0) + f1 = remote.system_lib() + + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) + time_f = f1.time_evaluator("myadd", remote.cpu(0), number=1) + cost = time_f(a, b).mean + np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + + # change to not executable + os.chmod(path_minrpc, stat.S_IRUSR) + with pytest.raises(RuntimeError): + rpc.PopenSession(path_minrpc) + + def check_remote_link_cl(remote): """Test function to run remote code such as cl @@ -174,8 +264,8 @@ def check_remote_link_cl(remote): fhost = remote.load_module("myadd.o") fdev = remote.load_module("myadd.cl") fhost.import_module(fdev) - a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) fhost(a, b) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) # Option 2: export library as a tar ball then handled by remote compiler @@ -183,13 +273,15 @@ def check_remote_link_cl(remote): f.export_library(path_tar) remote.upload(path_tar) fhost = remote.load_module("myadd.tar") - a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) fhost(a, b) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) - check_remote(client) check_remote(rpc.LocalSession()) + check_remote(client) + check_minrpc() + def test_rpc_return_func(): @@ -204,6 +296,37 @@ def addone(x): assert fadd(12) == 22 +def test_rpc_session_constructor_args(): + # start server + server0 = rpc.Server("localhost", key="x0") + server1 = rpc.Server("localhost", key="x1") + + def check_multi_hop(): + # use server0 as proxy to connect to server1 + client = rpc.connect( + server0.host, server0.port, key="x0", + session_constructor_args=[ + "rpc.Connect", server1.host, server1.port, "x1"]) + + fecho = client.get_function("testing.echo") + assert(fecho(1, 2, 3) == 1) + assert(fecho(100, 2, 3) == 100) + assert(fecho("xyz") == "xyz") + assert(bytes(fecho(bytearray(b"123"))) == b"123") + + nd = tvm.nd.array([1,2,3], ctx=client.cpu(0)) + assert(nd.asnumpy()[1] == 2) + + def check_error_handling(): + with pytest.raises(tvm.error.RPCError): + client = rpc.connect( + server0.host, server0.port, key="x0", + session_constructor_args=["rpc.NonExistingConstructor"]) + + check_multi_hop() + check_error_handling() + + def test_rpc_return_ndarray(): # Use closure to check the ref counter correctness nd = tvm.nd.array(np.zeros(10).astype("float32")) @@ -221,6 +344,7 @@ def my_module(name): # start server server = rpc.Server("localhost", key="x1") client = rpc.connect(server.host, server.port, key="x1") + m = client.get_function("rpc.test.remote_return_nd") get_arr = m("get_arr") ref_count = m("ref_count") @@ -315,6 +439,7 @@ def target(host, port, device_key, timeout): time.sleep(0.5) summary = client.summary() + assert summary['queue_info'][device_key]['free'] == 0 assert summary['queue_info'][device_key]['pending'] == 1 @@ -334,6 +459,8 @@ def target(host, port, device_key, timeout): if __name__ == "__main__": logging.basicConfig(level=logging.INFO) + test_rpc_echo() + test_rpc_session_constructor_args() test_rpc_return_ndarray() test_rpc_return_func() test_bigendian_rpc() @@ -344,3 
+471,4 @@ def target(host, port, device_key, timeout): test_local_func() test_rpc_tracker_register() test_rpc_tracker_request() + test_rpc_large_array() diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index 819961dc6ebd..41006f41f754 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -43,9 +43,6 @@ cd .. make doc rm -f docs/doxygen/html/*.map docs/doxygen/html/*.md5 -# JS doc -jsdoc -c web/.jsdoc_conf.json web/tvm_runtime.js web/README.md - # Java doc make javadoc @@ -54,7 +51,6 @@ rm -rf _docs mv docs/_build/html _docs rm -f _docs/.buildinfo mv docs/doxygen/html _docs/doxygen -mv out _docs/jsdoc mv jvm/core/target/site/apidocs _docs/javadoc echo "Start creating the docs tarball.." diff --git a/tests/web/test_packed_func.js b/tests/web/test_packed_func.js deleted file mode 100644 index d239f7346e74..000000000000 --- a/tests/web/test_packed_func.js +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -// Load Emscripten Module, need to change path to root/build -const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/libtvm_web_runtime.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); - -function testGetGlobal() { - var targs = [10, 10.0, "hello"] - tvm.registerFunc("my_packed_func", function () { - tvm.assert(Array.from(arguments).toString() == targs, "assert fail"); - return 10 - }); - var f = tvm.getGlobalFunc("my_packed_func") - tvm.assert(tvm.isPackedFunc(f)); - y = f.apply(null, targs); - tvm.assert(y == 10); - f.release(); -} - - -function testReturnFunc() { - function addy(y) { - function add(x) { - return x + y; - } - return add; - } - var myf = tvm.convertFunc(addy); - var f = myf(10); - tvm.assert(tvm.isPackedFunc(f)); - tvm.assert(f(11) == 21); - myf.release(); - f.release(); -} - -function testByteArray() { - var a = new Uint8Array(3); - a[0] = 1; - a[1] = 2; - function myfunc(ss){ - tvm.assert(ss instanceof Uint8Array); - tvm.assert(ss.toString() == a); - } - f = tvm.convertFunc(myfunc); - f(a); - f.release(); -} - -testGetGlobal(); -testReturnFunc(); -testByteArray(); diff --git a/tests/webgl/README.md b/tests/webgl/README.md deleted file mode 100644 index 5303cc059740..000000000000 --- a/tests/webgl/README.md +++ /dev/null @@ -1,24 +0,0 @@ - - - - - - - - - - - - - - - - - -## Test cases for the WebGL backend - -Any test case with name `test_local_...` tests the C++ OpenGL backend on the -local OS, which can be executed automatically. - -Any test case with name `test_remote_...` tests the WebGL backend within the -browser, which must be run manually. 
See instruction within the test. diff --git a/tests/webgl/test_local_gemm.py b/tests/webgl/test_local_gemm.py deleted file mode 100644 index 6bd22bf0057b..000000000000 --- a/tests/webgl/test_local_gemm.py +++ /dev/null @@ -1,58 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -from tvm import te -import numpy as np - -def test_local_gemm(): - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - - nn = 1024 - n = te.var('n') - n = tvm.runtime.convert(nn) - m = n - l = n - A = te.placeholder((n, l), name='A', dtype='int32') - B = te.placeholder((m, l), name='B', dtype='int32') - k = te.reduce_axis((0, l), name='k') - C = te.compute((n, m), lambda ii, jj: te.sum(A[ii, k] * B[jj, k], axis=k), - name='CC') - - s = te.create_schedule(C.op) - s[C].opengl() - print(tvm.lower(s, [A, B, C], simple_mode=True)) - - f = tvm.build(s, [A, B, C], "opengl", name="gemm") - print("------opengl code------") - print(f.imported_modules[0].get_source(fmt="gl")) - - ctx = tvm.opengl() - n, m, l = nn, nn, nn - a_np = np.random.uniform(low=0, high=10, size=(n, l)).astype(A.dtype) - b_np = np.random.uniform(low=0, high=10, size=(m, l)).astype(B.dtype) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx) - f(a, b, c) - - tvm.testing.assert_allclose(c.asnumpy(), np.dot(a_np, b_np.T)) - -if __name__ == "__main__": - test_local_gemm() diff --git a/tests/webgl/test_local_save_load.py b/tests/webgl/test_local_save_load.py deleted file mode 100644 index cca68020c0c2..000000000000 --- a/tests/webgl/test_local_save_load.py +++ /dev/null @@ -1,53 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import numpy as np -import tvm -from tvm import te -from tvm import rpc -from tvm.contrib import util, emscripten - -def test_local_save_load(): - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - - n = te.var("n") - A = te.placeholder((n,), name='A', dtype='int32') - B = te.placeholder((n,), name='B', dtype='int32') - C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") - s = te.create_schedule(C.op) - s[C].opengl() - - f = tvm.build(s, [A, B, C], "opengl", target_host="llvm", name="myadd") - - ctx = tvm.opengl(0) - n = 10 - a = tvm.nd.array(np.random.uniform(high=10, size=(n)).astype(A.dtype), ctx) - b = tvm.nd.array(np.random.uniform(high=10, size=(n)).astype(B.dtype), ctx) - c = tvm.nd.array(np.zeros((n), dtype=C.dtype), ctx) - f(a, b, c) - - temp = util.tempdir() - path_so = temp.relpath("myadd.so") - f.export_library(path_so) - f1 = tvm.runtime.load_module(path_so) - f1(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) - -if __name__ == "__main__": - test_local_save_load() diff --git a/tests/webgl/test_local_topi_conv2d_nchw.py b/tests/webgl/test_local_topi_conv2d_nchw.py deleted file mode 100644 index 0d9b7776096a..000000000000 --- a/tests/webgl/test_local_topi_conv2d_nchw.py +++ /dev/null @@ -1,99 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Example code to do convolution. -Copied from topi/tests/python/test_topi_conv2d_nchw.py. 
-Should be removed once we fix OpenGL testing on Jenkins.""" -import os -import numpy as np -import tvm -from tvm import te -import topi -from tvm.contrib.pickle_memoize import memoize -from topi.util import get_const_tuple - -def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding): - in_height = in_width = in_size - - A = te.placeholder((batch, in_channel, in_height, in_width), name='A') - W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W') - B = topi.nn.conv2d_nchw(A, W, stride, padding) - C = topi.nn.relu(B) - - a_shape = get_const_tuple(A.shape) - w_shape = get_const_tuple(W.shape) - dtype = A.dtype - - @memoize("topi.tests.test_topi_conv2d.verify_con2d_nchw") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s1 = topi.generic.schedule_conv2d_nchw([B]) - s2 = topi.generic.schedule_conv2d_nchw([C]) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) - with tvm.target.build_config(auto_unroll_max_step=1400, - unroll_explicit=(device != "cuda")): - func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding)) - func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding)) - func1(a, w, b) - func2(a, w, c) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - - -def test_conv2d_nchw(): - # ResNet18 worklaods - verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3) - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0) - verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1) - verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0) - verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1) - verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1) - verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0) - verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1) - verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1) - verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0) - verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1) - # Vgg16 workloads - verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1) - # Super resolution workloads - verify_conv2d_nchw(1, 1, 224, 64, 5, 1, 2) - verify_conv2d_nchw(1, 64, 224, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 224, 32, 3, 1, 1) - verify_conv2d_nchw(1, 32, 224, 9, 3, 1, 1) - -if __name__ == "__main__": - test_conv2d_nchw() diff --git a/tests/webgl/test_local_topi_dense.py b/tests/webgl/test_local_topi_dense.py deleted file mode 100644 index 60dfe1ff690f..000000000000 --- a/tests/webgl/test_local_topi_dense.py +++ /dev/null @@ -1,76 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test code for dense operator -Copied from topi/tests/python/test_topi_dense.py. -Should be removed once we fix OpenGL testing on Jenkins. -""" -import numpy as np -import tvm -from tvm import te -import topi -from topi.util import get_const_tuple -from tvm.contrib.pickle_memoize import memoize - - -def verify_dense(batch, in_dim, out_dim, use_bias=True): - A = te.placeholder((batch, in_dim), name='A') - B = te.placeholder((out_dim, in_dim), name='B') - C = te.placeholder((out_dim,), name='C') - D = topi.nn.dense(A, B, C if use_bias else None) - D = topi.nn.relu(D) - dtype = A.dtype - - # use memoize to pickle the test data for next time use - @memoize("topi.tests.test_topi_dense") - def get_ref_data(): - a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype) - b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype) - c_np = np.random.uniform(size=(out_dim,)).astype(dtype) - if use_bias: - d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0) - else: - d_np = np.maximum(np.dot(a_np, b_np.T), 0.0) - return (a_np, b_np, c_np, d_np) - # get the test data - a_np, b_np, c_np, d_np = get_ref_data() - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_dense(D) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(c_np, ctx) - d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx) - f = tvm.build(s, [A, B, C, D], device, name="dense") - f(a, b, c, d) - tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - -def test_dense(): - verify_dense(1, 1024, 1000, use_bias=True) - verify_dense(1, 1024, 1000, use_bias=False) - - -if __name__ == "__main__": - test_dense() diff --git a/tests/webgl/test_local_topi_pooling.py b/tests/webgl/test_local_topi_pooling.py deleted file mode 100644 index 3adae7bba51c..000000000000 --- a/tests/webgl/test_local_topi_pooling.py +++ /dev/null @@ -1,132 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-"""Test code for pooling -Copied from topi/tests/python/test_topi_pooling.py. -Should be removed once we fix OpenGL testing on Jenkins. -""" -import numpy as np -import tvm -from tvm import te -import topi -import math -from topi.util import get_const_tuple - -def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode): - iw = ih - kw = kh - sw = sh - ph, pw = padding - A = te.placeholder((n, ic, ih, iw), name='A') - B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, - pool_type=pool_type, ceil_mode=ceil_mode) - B = topi.nn.relu(B) - dtype = A.dtype - - bshape = get_const_tuple(B.shape) - ashape = get_const_tuple(A.shape) - if ceil_mode: - assert bshape[2] == int(math.ceil(float(ashape[2] - kh + ph * 2) / sh) + 1) - assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pw * 2) / sw) + 1) - else: - assert bshape[2] == int(math.floor(float(ashape[2] - kh + ph * 2) / sh) + 1) - assert bshape[3] == int(math.floor(float(ashape[3] - kw + pw * 2) / sw) + 1) - - - a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype) - pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype) - no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw))) - pad_np[np.ix_(*no_zero)] = a_np - _, oc, oh, ow = get_const_tuple(B.shape) - b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype) - - if pool_type == 'avg': - for i in range(oh): - for j in range(ow): - b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) - elif pool_type =='max': - for i in range(oh): - for j in range(ow): - b_np[:,:,i,j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) - b_np = np.maximum(b_np, 0.0) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_pool(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) - print(tvm.lower(s, [A, B], simple_mode=True)) - - f = tvm.build(s, [A, B], device) - f(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - -def test_pool(): - verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False) - verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False) - verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False) - verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False) - verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True) - - - -def verify_global_pool(n, c, h, w, pool_type): - A = te.placeholder((n, c, h, w), name='A') - B = topi.nn.global_pool(A, pool_type=pool_type) - B = topi.nn.relu(B) - - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - if pool_type == 'avg': - b_np = np.mean(a_np, axis=(2,3), keepdims=True) - elif pool_type =='max': - b_np = np.max(a_np, axis=(2,3), keepdims=True) - b_np = np.maximum(b_np, 0.0) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_global_pool(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - f = tvm.build(s, [A, B], device) - f(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - -def test_global_pool(): - 
verify_global_pool(1, 1024, 7, 7, 'avg') - verify_global_pool(4, 1024, 7, 7, 'avg') - verify_global_pool(1, 1024, 7, 7, 'max') - verify_global_pool(4, 1024, 7, 7, 'max') - - -if __name__ == "__main__": - test_pool() - test_global_pool() diff --git a/tests/webgl/test_local_topi_softmax.py b/tests/webgl/test_local_topi_softmax.py deleted file mode 100644 index c0ddbf21419a..000000000000 --- a/tests/webgl/test_local_topi_softmax.py +++ /dev/null @@ -1,96 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test code for softmax -Copied from topi/tests/python/test_topi_softmax.py. -Should be removed once we fix OpenGL testing on Jenkins. -""" - -import os -import numpy as np -import tvm -from tvm import te -import topi -import logging -from topi.util import get_const_tuple - -def verify_softmax(m, n): - A = te.placeholder((m, n), name='A') - B = topi.nn.softmax(A) - # confirm lower works - s = te.create_schedule([B.op]) - tvm.lower(s, [A, B], simple_mode=True) - - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - b_np = topi.testing.softmax_python(a_np) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_softmax(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - foo = tvm.build(s, [A, B], device, name="softmax") - foo(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ["opengl"]: - check_device(device) - -def test_softmax(): - verify_softmax(32, 10) - verify_softmax(3, 4) - - -def verify_log_softmax(m, n): - A = te.placeholder((m, n), name='A') - B = topi.nn.log_softmax(A) - # confirm lower works - s = te.create_schedule([B.op]) - tvm.lower(s, [A, B], simple_mode=True) - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - b_np = topi.testing.log_softmax_python(a_np) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_softmax(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - foo = tvm.build(s, [A, B], device, name="log_softmax") - foo(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ["opengl"]: - check_device(device) - - -def test_log_softmax(): - verify_log_softmax(32, 10) - verify_log_softmax(3, 4) - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - test_softmax() - test_log_softmax() 
diff --git a/tests/webgl/test_remote_save_load.py b/tests/webgl/test_remote_save_load.py deleted file mode 100644 index 34bbb3fa0f00..000000000000 --- a/tests/webgl/test_remote_save_load.py +++ /dev/null @@ -1,96 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -The following instruction is based on web/README.md. - -Setup an RPC server: -$ python -m tvm.exec.rpc_proxy --example-rpc=1 - -Go to http://localhost:9190 in browser. - -Click "Connect To Proxy". - -Run this test script: -$ python tests/webgl/test_remote_save_load.py -""" - -import numpy as np -import tvm -from tvm import te -from tvm import rpc -from tvm.contrib import util, emscripten - -proxy_host = "localhost" -proxy_port = 9090 - -def try_remote_save_load(): - if not tvm.runtime.enabled("rpc"): - return - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - - # Build the module. - n = te.var("n") - A = te.placeholder((n,), name='A') - B = te.placeholder((n,), name='B') - C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") - s = te.create_schedule(C.op) - s[C].opengl() - target_host = "llvm -target=asmjs-unknown-emscripten -system-lib" - f = tvm.build(s, [A, B, C], "opengl", target_host=target_host, name="myadd") - - remote = rpc.connect(proxy_host, proxy_port, key="js") - - temp = util.tempdir() - ctx = remote.opengl(0) - path_obj = temp.relpath("myadd.bc") - path_dso = temp.relpath("myadd.js") - path_gl = temp.relpath("myadd.gl") - path_json = temp.relpath("myadd.tvm_meta.json") - - f.save(path_obj) - emscripten.create_js(path_dso, path_obj, side_module=True) - f.imported_modules[0].save(path_gl) - - remote.upload(path_dso, "myadd.dso") - remote.upload(path_gl) - remote.upload(path_json) - - remote.download("myadd.dso") - remote.download("myadd.gl") - remote.download("myadd.tvm_meta.json") - - print('Loading myadd.dso') - fhost = remote.load_module("myadd.dso") - - print('Loading myadd.gl') - fdev = remote.load_module("myadd.gl") - - print('import_module') - fhost.import_module(fdev) - - print('running...') - a = tvm.nd.array(np.random.uniform(size=16).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(16, dtype=A.dtype), ctx) - c = tvm.nd.array(np.zeros(16, dtype=C.dtype), ctx) - fhost(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) - -if __name__ == "__main__": - try_remote_save_load() diff --git a/tests/webgl/test_static_webgl_library.html b/tests/webgl/test_static_webgl_library.html deleted file mode 100644 index f9268c65edf3..000000000000 --- a/tests/webgl/test_static_webgl_library.html +++ /dev/null @@ -1,72 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - TVM RPC Test Page - - - -
TVM Test Page
- - - - - - - - \ No newline at end of file diff --git a/tests/webgl/test_static_webgl_library.py b/tests/webgl/test_static_webgl_library.py deleted file mode 100644 index 929da4ca294c..000000000000 --- a/tests/webgl/test_static_webgl_library.py +++ /dev/null @@ -1,66 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Create a static WebGL library and run it in the browser.""" - -from __future__ import absolute_import, print_function - -import os, shutil, SimpleHTTPServer, SocketServer -import tvm -from tvm import te -from tvm.contrib import emscripten, util -import numpy as np - -def try_static_webgl_library(): - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - - # Change to lib/ which contains "libtvm_runtime.bc". - os.chdir(os.path.join(curr_path, "../../lib")) - - # Create OpenGL module. - n = te.var("n") - A = te.placeholder((n,), name='A', dtype="float") - B = te.compute((n,), lambda *i: A[i], name="B") - - s = te.create_schedule(B.op) - s[B].opengl() - - target_host = "llvm -target=asmjs-unknown-emscripten -system-lib" - f = tvm.build(s, [A, B], name="identity", target="opengl", - target_host=target_host) - - # Create a JS library that contains both the module and the tvm runtime. - path_dso = "identity_static.js" - f.export_library(path_dso, emscripten.create_js, options=[ - "-s", "USE_GLFW=3", - "-s", "USE_WEBGL2=1", - "-lglfw", - ]) - - # Create "tvm_runtime.js" and "identity_static.html" in lib/ - shutil.copyfile(os.path.join(curr_path, "../../web/tvm_runtime.js"), - "tvm_runtime.js") - shutil.copyfile(os.path.join(curr_path, "test_static_webgl_library.html"), - "identity_static.html") - - port = 8080 - handler = SimpleHTTPServer.SimpleHTTPRequestHandler - httpd = SocketServer.TCPServer(("", port), handler) - print("Please open http://localhost:" + str(port) + "/identity_static.html") - httpd.serve_forever() - -if __name__ == "__main__": - try_static_webgl_library() diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index 2b7a845cd9ec..8ccd80f38a91 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -25,6 +25,7 @@ from .conv2d_hwcn import * from .conv2d_int8 import * from .conv2d_winograd import * +from .conv2d_nhwc_winograd import * from .depthwise_conv2d import * from .group_conv2d_nchw import * from . 
import conv2d_alter_op diff --git a/topi/python/topi/cuda/conv2d_alter_op.py b/topi/python/topi/cuda/conv2d_alter_op.py index 8d9e86c192a0..c1e207cc2938 100644 --- a/topi/python/topi/cuda/conv2d_alter_op.py +++ b/topi/python/topi/cuda/conv2d_alter_op.py @@ -111,6 +111,42 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): return relay.nn.contrib_conv2d_winograd_without_weight_transform( inputs[0], weight, **new_attrs) + if topi_tmpl in ('conv2d_nhwc_winograd_direct.cuda', 'conv2d_nhwc_winograd_tensorcore.cuda'): + if dilation != (1, 1): + logger.warning("Does not support weight pre-transform for dilated convolution.") + return None + + assert data_layout == "NHWC" and kernel_layout == "HWIO" + N, H, W, CI = get_const_tuple(data.shape) + KH, KW, _, CO = get_const_tuple(kernel.shape) + + # Pre-compute weight transformation in winograd + if H % 8 == 0: + tile_size = 4 + else: + tile_size = 2 + kernel_transform = relay.transpose(inputs[1], axes=[3, 2, 0, 1]) + weight = relay.nn.contrib_conv2d_winograd_weight_transform(kernel_transform, + tile_size=tile_size) + weight = relay.transpose(weight, axes=[0, 1, 3, 2]) + new_attrs['tile_size'] = tile_size + new_attrs['channels'] = CO + # Store the same config for the altered operator (workload) + new_data = data + new_weight = te.placeholder((KH + tile_size - 1, KW + tile_size - 1, CI, CO), + dtype=kernel.dtype) + if topi_tmpl == "conv2d_nhwc_winograd_direct.cuda": + new_workload = autotvm.task.args_to_workload( + [new_data, new_weight, strides, padding, dilation, out_dtype], + "conv2d_nhwc_winograd_direct_without_weight_transform.cuda") + elif topi_tmpl == "conv2d_nhwc_winograd_tensorcore.cuda": + new_workload = autotvm.task.args_to_workload( + [new_data, new_weight, strides, padding, dilation, out_dtype], + "conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda") + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs) + if topi_tmpl == "group_conv2d_NCHWc_int8.cuda": assert data_layout == "NCHW" and kernel_layout == "OIHW" N, CI, H, W = get_const_tuple(data.shape) diff --git a/topi/python/topi/cuda/conv2d_nhwc_winograd.py b/topi/python/topi/cuda/conv2d_nhwc_winograd.py new file mode 100644 index 000000000000..2f5b85eed620 --- /dev/null +++ b/topi/python/topi/cuda/conv2d_nhwc_winograd.py @@ -0,0 +1,639 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument +# pylint: disable=too-many-arguments,too-many-locals +# pylint: disable=too-many-statements +"""Winograd template for cuda backend""" + +import tvm +from tvm import te +from tvm import autotvm +from .. 
import nn +from ..util import get_const_int, get_const_tuple, traverse_inline +from ..nn.winograd_util import winograd_transform_matrices +from .tensor_intrin import intrin_wmma_load_matrix_A +from .tensor_intrin import intrin_wmma_load_matrix_W +from .tensor_intrin import intrin_wmma_store_matrix +from .tensor_intrin import intrin_wmma_gemm + +def _infer_tile_size(data, kernel): + """Compute the tile size""" + N, H, W, CI = get_const_tuple(data.shape) + if H % 8 == 0: + return 4 + return 2 + + +def schedule_bgemm_tensorcore(cfg, s, bgemm, data_pack, kernel_pack): + """Schedule for bgemm tensorcore""" + A = data_pack + B = kernel_pack + C = bgemm + _, _, P, out_dim = get_const_tuple(C.shape) + out_dtype = C.dtype + + # Explicit memory access + AS = s.cache_read(A, 'shared', [C]) + BS = s.cache_read(B, 'shared', [C]) + AF = s.cache_read(AS, 'wmma.matrix_a', [C]) + BF = s.cache_read(BS, 'wmma.matrix_b', [C]) + CF = s.cache_write(C, 'wmma.accumulator') + CS = s.cache_read(CF, 'shared', [C]) + + # Create tuning space + cfg.define_knob("block_row_warps", [1, 2, 4]) + cfg.define_knob("block_col_warps", [1, 2, 4]) + cfg.define_knob("warp_row_tiles", [1, 2, 4, 8]) + cfg.define_knob("warp_col_tiles", [1, 2, 4, 8]) + cfg.define_knob("chunk", [1, 2, 4, 8]) + cfg.define_knob("offset", [0, 1, 2, 4, 8]) + cfg.define_knob("offsetCS", [0, 1, 2, 4, 8]) + cfg.define_knob("vec", [1, 2, 4, 8]) + + # Ensure that the default parameters are applicable when autotvm is not in use + if (P % 16 == 0 and out_dim % 16 == 0): + cfg.define_knob("wmma_m", [16, 8, 32]) + elif (P % 32 == 0 and out_dim % 8 == 0): + cfg.define_knob("wmma_m", [32, 16, 8]) + elif (P % 8 == 0 and out_dim % 32 == 0): + cfg.define_knob("wmma_m", [8, 16, 32]) + + warp_size = 32 + wmma_k = 16 + block_row_warps = cfg["block_row_warps"].val + block_col_warps = cfg["block_col_warps"].val + warp_row_tiles = cfg["warp_row_tiles"].val + warp_col_tiles = cfg["warp_col_tiles"].val + chunk = cfg["chunk"].val + offsetAB = cfg["offset"].val + offsetCS = cfg["offsetCS"].val + wmma_m = cfg["wmma_m"].val + vec = cfg["vec"].val + + if wmma_m == 16: + wmma_n = 16 + elif wmma_m == 8: + wmma_n = 32 + elif wmma_m == 32: + wmma_n = 8 + + # Define the stride of intrin functions + AS_align = chunk * wmma_k + offsetAB + BS_align = warp_col_tiles * block_col_warps * wmma_n + offsetAB + CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS + AS_stride = [AS_align, 1] + BS_stride = [BS_align, 1] + AF_stride = [wmma_k, 1] + BF_stride = [wmma_n * warp_col_tiles, 1] + CF_stride = [warp_col_tiles * wmma_n, 1] + CS_stride = [CS_align, 1] + block_x = te.thread_axis('blockIdx.x') + block_y = te.thread_axis('blockIdx.y') + block_z = te.thread_axis('blockIdx.z') + thread_x = te.thread_axis('threadIdx.x') + thread_y = te.thread_axis('threadIdx.y') + thread_z = te.thread_axis('threadIdx.z') + + # Schedule for computation + block_factor_b = wmma_m * warp_row_tiles * block_row_warps + block_factor_o = wmma_n * warp_col_tiles * block_col_warps + alpha_1, alpha_2, b, o = C.op.axis + block_k = s[C].fuse(alpha_1, alpha_2) + block_i, bc = s[C].split(b, factor=block_factor_b) + block_j, oc = s[C].split(o, factor=block_factor_o) + s[C].reorder(block_k, block_i, block_j, bc, oc) + t = s[C].fuse(bc, oc) + t, vi = s[C].split(t, factor=vec) + t, tx = s[C].split(t, factor=warp_size) + t, ty = s[C].split(t, factor=block_row_warps) + t, tz = s[C].split(t, factor=block_col_warps) + s[C].bind(block_k, block_z) + s[C].bind(block_i, block_x) + s[C].bind(block_j, block_y) + s[C].bind(tz, 
thread_z) + s[C].bind(ty, thread_y) + s[C].bind(tx, thread_x) + s[C].vectorize(vi) + + # Schedule for wmma store + s[CS].compute_at(s[C], block_j) + _, _, bb, oo = CS.op.axis + s[CS].storage_align(bb, CS_align - 1, CS_align) + bb, bbi = s[CS].split(bb, factor=wmma_m) + oo, ooi = s[CS].split(oo, factor=wmma_n) + bb, bbii = s[CS].split(bb, factor=warp_row_tiles) + oo, ooii = s[CS].split(oo, factor=warp_col_tiles) + s[CS].reorder(bb, oo, bbii, ooii, bbi, ooi) + + # Schedule for wmma computation + s[CF].compute_at(s[CS], oo) + _, _, warp_i, warp_j = CF.op.axis + warp_i, _ii = s[CF].split(warp_i, factor=wmma_m) + warp_j, _jj = s[CF].split(warp_j, factor=wmma_n) + k, = CF.op.reduce_axis + k, _k = s[CF].split(k, factor=wmma_k) + ko, ki = s[CF].split(k, factor=chunk) + s[CF].reorder(ko, ki, warp_i, warp_j, _ii, _jj, _k) + + # Schedule for wmma_matrix_a load + s[AF].compute_at(s[CF], ki) + _, _, b, i = AF.op.axis + b, b_ii = s[AF].split(b, factor=wmma_m) + i, i_jj = s[AF].split(i, factor=wmma_k) + s[AF].reorder(b, i, b_ii, i_jj) + + # Schedule for wmma_matrix_b load + s[BF].compute_at(s[CF], ki) + _, _, i, o = BF.op.axis + o, o_ii = s[BF].split(o, factor=wmma_n) + i, i_ii = s[BF].split(i, factor=wmma_k) + s[BF].reorder(i, o, i_ii, o_ii) + + # Schedule for A's(B's) shared memory load + def shared_shedule(stage, strides): + s[stage].compute_at(s[CF], ko) + _, _, xo, yo = stage.op.axis + s[stage].storage_align(xo, strides - 1, strides) + t = s[stage].fuse(xo, yo) + t, vi = s[stage].split(t, factor=vec) + t, tx = s[stage].split(t, factor=warp_size) + t, ty = s[stage].split(t, factor=block_row_warps) + _, tz = s[stage].split(t, factor=block_col_warps) + s[stage].bind(ty, thread_y) + s[stage].bind(tz, thread_z) + s[stage].bind(tx, thread_x) + s[stage].vectorize(vi) + + shared_shedule(AS, AS_align) + shared_shedule(BS, BS_align) + + shape = (wmma_m, wmma_n, wmma_k) + in_dtype = 'float16' + AL_gemm = te.placeholder((wmma_m, wmma_k), name='AL_gemm', dtype=in_dtype) + BL_gemm = te.placeholder((wmma_k, wmma_n), name='BL_gemm', dtype=in_dtype) + k_gemm = te.reduce_axis((0, wmma_k), name='k_gemm') + CL_compute = te.compute((wmma_m, wmma_n), lambda ii, jj: + te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) * + BL_gemm[k_gemm, jj].astype(out_dtype), + axis=k_gemm), name='CL_compute') + + # Lower the computation loops down to TensorCore hardware intrinsics + # by mapping the tensorcore to tensor intrinsics + s[AF].tensorize(b_ii, intrin_wmma_load_matrix_A(AF_stride, AS_stride, shape, "row_major", + (wmma_m, wmma_k), (wmma_m, wmma_k), 'float16')) + s[BF].tensorize(i_ii, intrin_wmma_load_matrix_W(BF_stride, BS_stride, shape, "row_major", + (wmma_k, wmma_n), (wmma_k, wmma_n), 'float16')) + s[CF].tensorize(_ii, intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, AF_stride, + BF_stride, CF_stride, shape)) + s[CS].tensorize(bbi, intrin_wmma_store_matrix(CS_stride, CF_stride, shape, out_dtype, + (wmma_m, wmma_n), (wmma_m, wmma_n))) + + +def schedule_bgemm_direct(cfg, s, bgemm, data_pack, kernel_pack): + """Schedule for bgemm direct""" + b1, b2, y, x = s[bgemm].op.axis + rc = s[bgemm].op.reduce_axis[0] + alpha = get_const_int(b1.dom.extent) + + # Create tuning space + cfg.define_split("tile_b", cfg.axis(alpha * alpha), num_outputs=4, + filter=lambda x: x.size[-3:] == [1, 1, 1]) + cfg.define_split("tile_y", y, num_outputs=4) + cfg.define_split("tile_x", x, num_outputs=4) + cfg.define_split("tile_rc", rc, num_outputs=2) + cfg.define_knob("offset_bgemm", [0, 1, 2, 4, 8]) + cfg.define_knob("vector_bgemm", [1, 2, 4, 8]) + offset_bgemm 
= cfg["offset_bgemm"].val + vector_bgemm = cfg["vector_bgemm"].val + + C = bgemm + A0, B0 = kernel_pack, data_pack + + # Designate the memory hierarchy + OL = s.cache_write(C, 'local') + AA = s.cache_read(A0, 'shared', [OL]) + BB = s.cache_read(B0, 'shared', [OL]) + + # Tile and bind spatial axes + b = s[bgemm].fuse(b1, b2) + bgemm_scope, b = s[bgemm].split(b, nparts=1) + bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b) + by, vy, ty, yi = cfg["tile_y"].apply(s, C, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x) + s[C].bind(bz, te.thread_axis("blockIdx.z")) + s[C].bind(by, te.thread_axis("blockIdx.y")) + s[C].bind(bx, te.thread_axis("blockIdx.x")) + s[C].bind(vz, te.thread_axis("vthread")) + s[C].bind(vy, te.thread_axis("vthread")) + s[C].bind(vx, te.thread_axis("vthread")) + s[C].bind(tz, te.thread_axis("threadIdx.z")) + s[C].bind(ty, te.thread_axis("threadIdx.y")) + s[C].bind(tx, te.thread_axis("threadIdx.x")) + s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi) + + # Tile reduction axes + s[OL].compute_at(s[C], tx) + b1, b2, y, x = s[OL].op.axis + b = s[OL].fuse(b1, b2) + rc, = s[OL].op.reduce_axis + rco, rci = cfg['tile_rc'].apply(s, OL, rc) + s[OL].reorder(rco, b, y, x, rci) + + s[AA].compute_at(s[OL], rco) + _, _, k, n = s[AA].op.axis + AA_align = offset_bgemm + cfg["tile_x"].size[1] * cfg["tile_x"].size[2] * cfg["tile_x"].size[3] + s[AA].storage_align(k, AA_align - 1, AA_align) + + s[BB].compute_at(s[OL], rco) + _, _, m, k = s[BB].op.axis + BB_align = offset_bgemm + cfg["tile_rc"].size[1] + s[BB].storage_align(m, BB_align - 1, BB_align) + + # Schedule for A and B shared memory load + for load in [AA, BB]: + fused = s[load].fuse(*list(s[load].op.axis)) + fused, ti = s[load].split(fused, factor=vector_bgemm) + fused, tx = s[load].split(fused, cfg["tile_x"].size[2]) + fused, ty = s[load].split(fused, cfg["tile_y"].size[2]) + fused, tz = s[load].split(fused, cfg["tile_b"].size[2]) + s[load].bind(tz, te.thread_axis("threadIdx.z")) + s[load].bind(ty, te.thread_axis("threadIdx.y")) + s[load].bind(tx, te.thread_axis("threadIdx.x")) + s[load].vectorize(ti) + + +def nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, + use_tensorcore, pre_computed): + """Compute declaration for winograd""" + tile_size = _infer_tile_size(data, kernel) + N, H, W, CI = get_const_tuple(data.shape) + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides + + if not pre_computed: # Kernel tensor is raw tensor, do strict check + if dilation_h != 1 or dilation_w != 1: + kernel = nn.dilate(kernel, (dilation_h, dilation_w, 1, 1)) + KH, KW, CI, CO = get_const_tuple(kernel.shape) + alpha = KW + tile_size - 1 + assert HSTR == 1 and WSTR == 1 and KH == KW + else: + # Kernel tensor is pre-transfomred. This op is created by conv2d_alter_op. 
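+        # The pre-packed kernel has shape (alpha, alpha, CI, CO); the original
+        # spatial size is recovered below as KH = KW = alpha + 1 - tile_size.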
+ # Dilation is not supported + alpha, _, CI, CO = get_const_tuple(kernel.shape) + KH = KW = alpha + 1 - tile_size + assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1 + + pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW)) + data_pad = nn.pad(data, (0, pt, pl, 0), (0, pb, pr, 0), name="data_pad") + + r = KW + m = tile_size + H = (H + pt + pb - KH) // HSTR + 1 + W = (W + pl + pr - KW) // WSTR + 1 + nH, nW = (H + m - 1) // m, (W + m - 1) // m + P = N * nH * nW + + # Determine whether the shape is available with tensorcore + shape_judge = (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \ + (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \ + (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0) + + if shape_judge and use_tensorcore: + trans_type = "float16" + else: + trans_type = data.dtype + + # Compute transform matrix + A, _, _ = winograd_transform_matrices(m, r, out_dtype) + _, B, G = winograd_transform_matrices(m, r, data.dtype) + + # Transform kernel + if not pre_computed: + # Check if we are currently tuning, if so we want to avoid counting + # prepacking in time costs. Just use a placeholder with the packed shape instead. + if autotvm.GLOBAL_SCOPE.in_tuning: + kernel_pack = te.placeholder((alpha, alpha, CI, CO), + dtype=kernel.dtype, + name='kernel_pack') + else: + r_kh = te.reduce_axis((0, KH), name='r_kh') + r_kw = te.reduce_axis((0, KW), name='r_kw') + kernel_pack = te.compute((alpha, alpha, CI, CO), lambda eps, nu, ci, co: + te.sum((kernel[r_kh][r_kw][ci][co]) * + G[eps][r_kh] * G[nu][r_kw], + axis=[r_kh, r_kw]), name='kernel_pack') + else: + kernel_pack = kernel + + idxdiv = tvm.tir.indexdiv + idxmod = tvm.tir.indexmod + + # Pack input tile + input_tile = te.compute((P, CI, alpha, alpha), lambda p, c, eps, nu: + data_pad[idxdiv(p, (nH * nW)), + idxmod(idxdiv(p, nW), nH) * m + eps, + idxmod(p, nW) * m + nu, + c], name='d') + + # Transform data + r_a = te.reduce_axis((0, alpha), 'r_a') + r_b = te.reduce_axis((0, alpha), 'r_b') + data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci: + te.sum(input_tile[p][ci][r_a][r_b] * B[r_a][eps] * B[r_b][nu], + axis=[r_a, r_b]), name='data_pack') + + # Convert data type of input feature maps and weights for tensorcore + Transdata = te.compute( + data_pack.shape, lambda eps, nu, p, ci: data_pack[eps, nu, p, ci].astype(trans_type)) + TransFilter = te.compute( + kernel_pack.shape, lambda eps, nu, ci, co: kernel_pack[eps, nu, ci, co].astype(trans_type)) + + # Do batch gemm + ci = te.reduce_axis((0, CI), name='ci') + bgemm = te.compute((alpha, alpha, P, CO), lambda eps, nu, p, co: + te.sum((Transdata[eps][nu][p][ci]).astype(out_dtype) * + (TransFilter[eps][nu][ci][co]).astype(out_dtype), + axis=[ci]), name='bgemm') + + # Inverse transform + r_a = te.reduce_axis((0, alpha), 'r_a') + r_b = te.reduce_axis((0, alpha), 'r_a') + inverse = te.compute((P, CO, m, m), lambda p, co, vh, vw: + te.sum(bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], + axis=[r_a, r_b]), name='inverse') + + # Output + output = te.compute((N, H, W, CO), lambda n, h, w, co: + inverse[n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), + co, + idxmod(h, m), + idxmod(w, m)], + name='output', tag='conv2d_nhwc_winograd') + cfg.add_flop(2 * N * CO * H * W * CI * KH * KW) + return output + + +def data_weight_transform(s, data_trans, input_tile, thread_num_trans, offset_trans, trans_tag): + """Schedule for data or kernel transform""" + kernel_align = thread_num_trans + offset_trans + indata_s = s.cache_read(input_tile, 'shared', [data_trans]) + data_l = 
s.cache_write(data_trans, 'local') + # Schedule for data or kernel transform + eps, nu, p, c = s[data_trans].op.axis + + block_x, thread_x = s[data_trans].split(c, thread_num_trans) + block_x = s[data_trans].fuse(p, block_x) + s[data_trans].reorder(block_x, thread_x, eps, nu) + s[data_trans].bind(thread_x, te.thread_axis("threadIdx.x")) + s[data_trans].bind(block_x, te.thread_axis("blockIdx.x")) + + s[data_l].compute_at(s[data_trans], thread_x) + eps_l, nu_l, p_l, c_l = s[data_l].op.axis + r_a, r_b = s[data_l].op.reduce_axis + block_x_l, thread_x_l = s[data_l].split(c_l, thread_num_trans) + block_x_l = s[data_l].fuse(p_l, block_x_l) + + s[data_l].reorder(block_x_l, thread_x_l, eps_l, nu_l, r_a, r_b) + + for axis in [eps_l, nu_l, r_a, r_b]: + s[data_l].unroll(axis) + + # Schedule for share memory load + s[indata_s].compute_at(s[data_l], block_x_l) + if trans_tag == "data": + p_is, c_is, eps_is, nu_is = s[indata_s].op.axis + data_align = get_const_int(eps_is.dom.extent) * \ + get_const_int(nu_is.dom.extent) + offset_trans + s[indata_s].storage_align(c_is, data_align - 1, data_align) + block_x_is, thread_x_is = s[indata_s].split(c_is, thread_num_trans) + s[indata_s].bind(thread_x_is, te.thread_axis("threadIdx.x")) + else: + eps_is, nu_is, ci_is, co_is = s[indata_s].op.axis + s[indata_s].storage_align(nu_is, kernel_align - 1, kernel_align) + block_x_is, thread_x_is = s[indata_s].split(co_is, thread_num_trans) + s[indata_s].reorder(ci_is, block_x_is, eps_is, nu_is, thread_x_is) + s[indata_s].bind(thread_x_is, te.thread_axis("threadIdx.x")) + + +def schedule_nhwc_winograd_cuda(cfg, s, output, use_tensorcore, pre_computed): + """Schedule winograd template""" + # Get stages + inverse = s[output].op.input_tensors[0] + bgemm, A = s[inverse].op.input_tensors + Transdata, TransFilter = s[bgemm].op.input_tensors + data_pack = s[Transdata].op.input_tensors[0] + kernel_pack = s[TransFilter].op.input_tensors[0] + s[Transdata].compute_inline() + s[TransFilter].compute_inline() + + input_tile, B = s[data_pack].op.input_tensors + pad_data = s[input_tile].op.input_tensors[0] + + # Define the stride of intrin functions + cfg.define_knob("thread_num_inverse", [1, 32, 64, 128, 256]) + cfg.define_knob("thread_num_data", [1, 32, 64, 128, 256]) + cfg.define_knob("thread_num_kernel", [1, 32, 64, 128, 256]) + cfg.define_knob("offset_inverse", [0, 2, 4]) + cfg.define_knob("offset_data", [0, 1, 2, 4]) + cfg.define_knob("offset_kernel", [0, 1, 2, 4]) + cfg.define_knob("inverse_in_vector", [1, 2, 4]) + + thread_num_data = cfg["thread_num_data"].val + thread_num_kernel = cfg["thread_num_kernel"].val + thread_num_inverse = cfg["thread_num_inverse"].val + offset_data = cfg["offset_data"].val + offset_kernel = cfg["offset_kernel"].val + offset_inverse = cfg["offset_inverse"].val + inverse_in_vector = cfg["inverse_in_vector"].val + + # Data transform + s[B].compute_inline() + data_weight_transform(s, data_pack, input_tile, thread_num_data, offset_data, trans_tag="data") + s[input_tile].compute_inline() + s[pad_data].compute_inline() + + # Kernel transform + if not pre_computed and not autotvm.GLOBAL_SCOPE.in_tuning: + kernel, G = s[kernel_pack].op.input_tensors + s[G].compute_inline() + data_weight_transform(s, kernel_pack, kernel, thread_num_kernel, + offset_kernel, trans_tag="kernel") + else: + kernel = kernel_pack + + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + + b1, b2, y, x = s[bgemm].op.axis + alpha = get_const_int(b1.dom.extent) + _, _, P, CI = 
get_const_tuple(Transdata.shape) + _, _, _, CO = get_const_tuple(TransFilter.shape) + + # Determine whether the shape is available with tensorcore + shape_judge = (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \ + (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \ + (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0) + + if shape_judge and use_tensorcore: + schedule_bgemm_tensorcore(cfg, s, bgemm, Transdata, TransFilter) + else: + schedule_bgemm_direct(cfg, s, bgemm, Transdata, TransFilter) + + # Schedule inverse, output and fusion + if output.op in s.outputs: + OL = None + else: + OL = output + s[OL].set_scope('local') + output = s.outputs[0] + + s[A].compute_inline() + inverse_s = s.cache_read(bgemm, 'shared', [inverse]) + + m = alpha - 3 + 1 + offset_inverse_in = offset_inverse + vector_width_inverse_in = inverse_in_vector + + # Schedule for output + n, h, w, co = s[output].op.axis + ho, wo, hi, wi = s[output].tile(h, w, m, m) + s[output].reorder(n, ho, wo, co, hi, wi) + fused = s[output].fuse(n, ho, wo) + + block_x_s, thread_x_s = s[output].split(co, thread_num_inverse) + block_x_s = s[output].fuse(fused, block_x_s) + s[output].reorder(block_x_s, thread_x_s, hi, wi) + + if OL is not None: + s[OL].compute_inline() + + # Schedule for inverse + s[inverse].compute_at(s[output], thread_x_s) + p_inv, co_inv, eps_inv, nu_inv = s[inverse].op.axis + block_x_inv, thread_x_inv = s[inverse].split(co_inv, thread_num_inverse) + r_a, r_b = s[inverse].op.reduce_axis + for axis in [eps_inv, nu_inv, r_a, r_b]: + s[inverse].unroll(axis) + + # Schedule for share memory load + s[inverse_s].compute_at(s[output], block_x_s) + eps_inv_s, nu_inv_s, p_inv_s, co_inv_s = s[inverse_s].op.axis + inverse_in_align = offset_inverse_in + thread_num_inverse + s[inverse_s].storage_align(p_inv_s, inverse_in_align - 1, inverse_in_align) + block_x_inv_s, thread_x_inv_s = s[inverse_s].split(co_inv_s, thread_num_inverse) + block_x_inv_s = s[inverse_s].fuse(p_inv_s, block_x_inv_s) + s[inverse_s].reorder(block_x_inv_s, eps_inv_s, nu_inv_s, thread_x_inv_s) + t = s[inverse_s].fuse(eps_inv_s, nu_inv_s, thread_x_inv_s) + t, ti = s[inverse_s].split(t, factor=vector_width_inverse_in) + t, tx = s[inverse_s].split(t, factor=thread_num_inverse) + s[inverse_s].bind(tx, te.thread_axis("threadIdx.x")) + s[inverse_s].vectorize(ti) + + s[output].bind(thread_x_s, te.thread_axis("threadIdx.x")) + s[output].bind(block_x_s, te.thread_axis("blockIdx.x")) + return s + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd_direct.cuda") +def conv2d_nhwc_winograd_direct(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute conv2d with winograd for NHWC layout""" + return nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, + use_tensorcore=False, pre_computed=False) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_direct.cuda") +def schedule_conv2d_nhwc_winograd_direct(cfg, outs): + """TOPI schedule callback""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'conv2d_nhwc_winograd' in op.tag: + schedule_nhwc_winograd_cuda(cfg, s, op.output(0), use_tensorcore=False, + pre_computed=False) + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd_tensorcore.cuda") +def conv2d_nhwc_winograd_tensorcore(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute conv2d with winograd for NHWC layout""" + return nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, + use_tensorcore=True, 
pre_computed=False) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_tensorcore.cuda") +def schedule_conv2d_nhwc_winograd_tensorcore(cfg, outs): + """TOPI schedule callback""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'conv2d_nhwc_winograd' in op.tag: + schedule_nhwc_winograd_cuda(cfg, s, op.output(0), use_tensorcore=True, + pre_computed=False) + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd_direct_without_weight_transform.cuda") +def conv2d_nhwc_winograd_direct_without_weight_transform(cfg, data, kernel, strides, + padding, dilation, out_dtype): + """Compute conv2d with winograd for NHWC layout""" + return nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, + use_tensorcore=False, pre_computed=True) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_direct_without_weight_transform.cuda") +def schedule_conv2d_nhwc_winograd_direct_without_weight_transform(cfg, outs): + """TOPI schedule callback""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'conv2d_nhwc_winograd' in op.tag: + schedule_nhwc_winograd_cuda(cfg, s, op.output(0), use_tensorcore=False, + pre_computed=True) + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda") +def conv2d_nhwc_winograd_tensorcore_without_weight_transform(cfg, data, kernel, strides, + padding, dilation, out_dtype): + """Compute conv2d with winograd for NHWC layout""" + return nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, + use_tensorcore=True, pre_computed=True) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda") +def schedule_conv2d_nhwc_winograd_tensorcore_without_weight_transform(cfg, outs): + """TOPI schedule callback""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'conv2d_nhwc_winograd' in op.tag: + schedule_nhwc_winograd_cuda(cfg, s, op.output(0), use_tensorcore=True, + pre_computed=True) + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/topi/tests/python/test_topi_conv2d_nhwc_winograd.py b/topi/tests/python/test_topi_conv2d_nhwc_winograd.py new file mode 100644 index 000000000000..a7e55320d6dd --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_nhwc_winograd.py @@ -0,0 +1,152 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name, too-many-locals, too-many-arguments +# pylint: disable=bad-whitespace +"""Example code to do convolution.""" + +import numpy as np +import tvm +import topi +import topi.testing +from tvm import te +from tvm.contrib.pickle_memoize import memoize +from tvm.contrib import nvcc +from topi.nn.util import get_pad_tuple +from topi.util import get_const_tuple + + +_conv2d_nhwc_winograd_tensorcore = { + "cuda": (topi.cuda.conv2d_nhwc_winograd_tensorcore, + topi.cuda.schedule_conv2d_nhwc_winograd_tensorcore) +} + +_conv2d_nhwc_winograd_direct = { + "cuda": (topi.cuda.conv2d_nhwc_winograd_direct, + topi.cuda.schedule_conv2d_nhwc_winograd_direct) +} + + +def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, + padding, dilation=1, add_bias=False, add_relu=False, + devices='cuda', bgemm="direct"): + """Test the conv2d with winograd for nhwc layout""" + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + + in_height = in_width = in_size + + A = te.placeholder((batch, in_height, in_width, in_channel), name='A') + W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W') + bias = te.placeholder((1, 1, 1, num_filter), name='bias') + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + bias_shape = get_const_tuple(bias.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_conv2d_nhwc.verify_conv2d_nhwc") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) + c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) + if add_bias: + b_np = np.random.uniform(size=bias_shape).astype(dtype) + c_np += b_np + if add_relu: + c_np = np.maximum(c_np, 0) + return a_np, w_np, b_np, c_np + + a_np, w_np, b_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + if bgemm == "direct": + fcompute, fschedule = topi.testing.dispatch(device, + _conv2d_nhwc_winograd_direct) + elif bgemm == "tensorcore": + fcompute, fschedule = topi.testing.dispatch(device, + _conv2d_nhwc_winograd_tensorcore) + C = fcompute(A, W, stride, padding, dilation, 'float32') + if add_bias: + C = topi.add(C, bias) + if add_relu: + C = topi.nn.relu(C) + s = fschedule([C]) + + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + if add_bias: + func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + func(a, w, b, c) + else: + func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + func(a, w, c) + + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=2e-3) + + check_device(devices) + + +def test_conv2d_nhwc_winograd_direct(): + """Test the conv2d with winograd for nhwc layout""" + # resnet 18 workloads + 
print("test_winograd_direct...") + verify_conv2d_nhwc(1, 64, 56, 64, 3, 1, 1, bgemm="direct") + verify_conv2d_nhwc(1, 128, 28, 128, 3, 1, 1) + verify_conv2d_nhwc(1, 256, 14, 256, 3, 1, 1) + verify_conv2d_nhwc(1, 512, 7, 512, 3, 1, 1) + verify_conv2d_nhwc(1, 48, 35, 64, 5, 1, 2) + + # weird workloads + verify_conv2d_nhwc(1, 1, 1, 1, 3, 1, 1) + verify_conv2d_nhwc(3, 3, 3, 3, 3, 1, 1) + verify_conv2d_nhwc(2, 13, 71, 59, 3, 1, 1) + + # Asymmetric padding + verify_conv2d_nhwc(1, 512, 7, 512, 3, 1, "SAME") + verify_conv2d_nhwc(2, 48, 56, 48, 3, 1, (1, 1), add_relu=True) + verify_conv2d_nhwc(2, 48, 56, 48, 3, 1, "SAME", add_relu=True, add_bias=True) + verify_conv2d_nhwc(1, 48, 35, 48, 5, 1, "VALID") + +def test_conv2d_nhwc_winograd_tensorcore(): + """Test the conv2d with winograd for nhwc layout""" + if not nvcc.have_tensorcore(tvm.gpu(0).compute_version): + return + verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1, bgemm="tensorcore") + verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1, bgemm="tensorcore") + verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1, bgemm="tensorcore") + + verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, (1, 1), add_relu=True, bgemm="tensorcore") + verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, "SAME", add_relu=True, bgemm="tensorcore") + + +if __name__ == "__main__": + test_conv2d_nhwc_winograd_direct() + test_conv2d_nhwc_winograd_tensorcore() diff --git a/web/.eslintignore b/web/.eslintignore new file mode 100644 index 000000000000..1521c8b7652b --- /dev/null +++ b/web/.eslintignore @@ -0,0 +1 @@ +dist diff --git a/web/.gitignore b/web/.gitignore new file mode 100644 index 000000000000..a3135cf24b9d --- /dev/null +++ b/web/.gitignore @@ -0,0 +1,6 @@ +.vscode +*~ +out +node_modules +package-lock.json +build diff --git a/web/.jsdoc_conf.json b/web/.jsdoc_conf.json deleted file mode 100644 index 33783b3bbb21..000000000000 --- a/web/.jsdoc_conf.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "templates": { - "default": { - "includeDate": false - } - } -} diff --git a/web/Makefile b/web/Makefile new file mode 100644 index 000000000000..be7fa193c04c --- /dev/null +++ b/web/Makefile @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
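+# Builds the TVM web runtime with emcc: dist/wasm/tvmjs_runtime.wasm and the
+# WASI-style JS wrapper dist/wasm/tvmjs_runtime.wasi.js (see the targets below).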
+ +TVM_ROOT=$(shell cd ..; pwd) + +INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\ + -I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include + +.PHONY: clean all + +all: dist/wasm/tvmjs_runtime.wasm dist/wasm/tvmjs_runtime.wasi.js + +EMCC = emcc + +EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++14 -Wno-ignored-attributes \ + -s ALLOW_MEMORY_GROWTH=1 -s STANDALONE_WASM=1 -s ERROR_ON_UNDEFINED_SYMBOLS=0 + +EMCC_LDFLAGS = --pre-js emcc/preload.js + +dist/wasm/%.bc: emcc/%.cc + @mkdir -p $(@D) + $(EMCC) $(EMCC_CFLAGS) -c -MM -MT dist/wasm/$*.bc $< >dist/wasm/$*.d + $(EMCC) $(EMCC_CFLAGS) -c -o dist/wasm/$*.bc $< + + +dist/wasm/tvmjs_runtime.wasm: dist/wasm/wasm_runtime.bc dist/wasm/tvmjs_support.bc + @mkdir -p $(@D) + $(EMCC) $(EMCC_CFLAGS) -o dist/wasm/tvmjs_runtime.js $+ $(EMCC_LDFLAGS) + + +dist/wasm/tvmjs_runtime.wasi.js: dist/wasm/tvmjs_runtime.wasm emcc/decorate_as_wasi.py + python3 emcc/decorate_as_wasi.py dist/wasm/tvmjs_runtime.js $@ + +clean: + @rm -rf dist/wasm + +-include dist/wasm/*.d diff --git a/web/README.md b/web/README.md index 5dfd6917934b..66a64a3d3d37 100644 --- a/web/README.md +++ b/web/README.md @@ -15,163 +15,70 @@ -# TVM WebAssembly and Javascript Backend +# TVM WebAssembly Runtime -This folder contains TVM WebAssembly and Javascript backend through Emscripten. +This folder contains TVM WebAssembly Runtime. ## Installation -While the LLVM main branch support webassembly as a target. We still need a good runtime with libc and other -system library support. Emscripten toolchain offers that nicely. The general idea is to build TVM against -the fastcomp LLVM backend in the Emscripten project and allow us to generate ```asmjs-unknown-emscripten``` -as a backend target. + +The LLVM main branch support webassembly as a target, we can directly +build TVM with LLVM mainline to generate wasm modules. +Note that, however, we still need emscripten to compile the runtime and provide system library support. + +Note that so far we requires everything to be in the source and setup PYTHONPATH(instead of use setup.py install). ### Setup Emscripten -Checkout [Emscripten Portable SDK Downloads](https://kripken.github.io/emscripten-site/docs/getting_started/downloads.html) -to download emsdk-portable and unzip it on a local folder. Follow the installation guide from emscripten document. -```bash -./emsdk update -./emsdk install latest -./emsdk activate latest -``` +We use emscripten to compile our runtime wasm library as well as a WASI variant that we can deploy +to the browser environment. -Because we need to compile against the LLVM backend of emscripten, we will need the source and llvm library. -Which can be installed via following command. +Follow [Emscripten](https://emscripten.org/) to download emsdk and install emcc on your local environment. -```bash -./emsdk install clang-incoming-64bit -./emsdk activate clang-incoming-64bit -``` +### Build TVM Wasm Runtime -### Setup Environment Variable +After the emcc is setup correctly. We can build tvm's wasm runtime by typing `make` in the web folder. -In normal setting, we can setup the necessary environment variable with the following command. ```bash -source /path-to-emsdk-portable/emsdk_env.sh +make ``` -However, this will put emscripten's clang and llvm path ahead of the current system path. -What you can do is to set the path manually, by putting emscripten's path after the PATH like the following ones. 
-You can get the detailed path by type ```./emsdk activate``` -```bash -export PATH=${PATH}:/emsdk-related-path-here +This command will create the follow files: +- `dist/wasm/libtvm_runtime.bc` bitcode library `tvm.contrib.emcc` will link into. +- `dist/wasm/tvmjs_runtime.wasm` a standalone wasm runtime for testing purposes. +- `dist/wasm/tvmjs_runtime.wasi.js` a WASI compatible library generated by emscripten that can be fed into runtime. -``` -### Build TVM with Fastcomp LLVM +### Build TVM Wasm JS Frontend -To build TVM with Emscripten's Fastcomp LLVM, we can modify the LLVM_CONFIG in ```config.mk``` -to point to fastcomp's llvm-config and build TVM normally. +Type the following command in the web folder. ```bash -LLVM_CONFIG = /path/to/emsdk-portable/clang/fastcomp/build_incoming_64/bin/llvm-config +npm run bundle ``` -### Build TVM Web Runtime +This command will create the tvmjs library that we can use to interface with the wasm runtime. -The above command gives us the TVM compiling environment. Now we need to build runtime, -to do so, make sure we set the environment correctly as in previous section and type -```bash -make web -``` +## Use TVM to Generate Wasm Library and Run it -This will create ```build/libtvm_web_runtime.bc``` and ```build/libtvm_web_runtime.js```. - -## Use TVM to Generate Javascript Library - -The general idea is to use TVM as normally and set target to be ```llvm -target=asmjs-unknown-emscripten -system-lib```. - -The following code snippet from [tests/web/prepare_test_libs.py](https://github.com/apache/incubator-tvm/tree/master/tests/web/prepare_test_libs.py) demonstrate -the compilation process. - -```python -import tvm -from tvm import te -from tvm.contrib import emscripten -import os -def prepare_test_libs(base_path): - target = "llvm -target=asmjs-unknown-emscripten -system-lib" - if not tvm.runtime.enabled(target): - raise RuntimeError("Target %s is not enbaled" % target) - n = te.var("n") - A = te.placeholder((n,), name='A') - B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') - s = te.create_schedule(B.op) - fadd1 = tvm.build(s, [A, B], target, name="add_one") - obj_path = os.path.join(base_path, "test_add_one.bc") - fadd1.save(obj_path) - emscripten.create_js(os.path.join(base_path, "test_module.js"), obj_path) - -if __name__ == "__main__": - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - prepare_test_libs(os.path.join(curr_path, "../../build")) -``` +Check code snippet in -In this workflow, we use TVM to generate a ```.bc``` file and statically link -that with the ```build/libtvm_web_runtime.bc```(emscripten.create_js will help you do that). -The result js library is a library that contains both TVM runtime and the compiled function. - - -## Run the Generated Library - -The following code snippet from [tests/web/test_module_load.js](https://github.com/apache/incubator-tvm/tree/master/tests/web/test_module_load.js) demonstrate -how to run the compiled library. - -```js -// Load Emscripten Module, need to change path to root/build -const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/test_module.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); - -// Load system library, the compiled function is registered in sysLib. 
-var sysLib = tvm.systemLib(); - -function randomArray(length, max) { - return Array.apply(null, Array(length)).map(function() { - return Math.random() * max; - }); -} - -function testAddOne() { - // grab pre-loaded function - var faddOne = sysLib.getFunction("add_one"); - var assert = require('assert'); - tvm.assert(tvm.isPackedFunc(faddOne)); - var n = 124; - var A = tvm.empty(n).copyFrom(randomArray(n, 1)); - var B = tvm.empty(n); - // call the function. - faddOne(A, B); - AA = A.asArray(); // retrieve values in js array - BB = B.asArray(); // retrieve values in js array - // verify - for (var i = 0; i < BB.length; ++i) { - assert(Math.abs(BB[i] - (AA[i] + 1)) < 1e-5); - } - faddOne.release(); -} - -testAddOne(); -sysLib.release(); -console.log("Finish verifying test_module_load"); -``` +- [tests/python/prepare_test_libs.py](https://github.com/apache/incubator-tvm/tree/master/web/tests/pythob/prepare_test_libs.py) + shows how to create a wasm library that links with tvm runtime. + - Note that all wasm libraries have to created using the `--system-lib` option + - emcc.create_wasm will automatically link the runtime library `dist/wasm/libtvm_runtime.bc` +- [tests/web/test_module_load.js](https://github.com/apache/incubator-tvm/tree/master/web/tests/node/test_module_load.js) demonstrate + how to run the generated library through tvmjs API. -Current example supports static linking, which is the preferred way to get more efficiency -in javascript backend. -## Proxy based RPC +## Run Wasm Remotely through WebSocket RPC. -We can now use javascript end to start an RPC server and connect to it from python side, +We can now use js side to start an RPC server and connect to it from python side, making the testing flow easier. -The following is an example to reproduce this. This requires everything to be in the git source and setup PYTHONPATH(instead of use setup.py install) -- run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy. -- Open broswer, goto the server webpage click Connect to proxy. - - Alternatively run "node web/example_rpc_node.js" -- run "python tests/web/websock_rpc_test.py" to run the rpc client. - -The general idea is to use Emscripten's dynamic linking to dynamically load modules. +The following is an example to reproduce this. +- run `python -m tvm.exec.rpc_proxy --example-rpc=1` to start proxy. +- Start the WebSocket RPC + - Browswer version: open https://localhost:8888, click connect to proxy + - NodeJS version: `npm run rpc` +- run `python tests/node/websock_rpc_test.py` to run the rpc client. diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html new file mode 100644 index 000000000000..22907f1561d1 --- /dev/null +++ b/web/apps/browser/rpc_server.html @@ -0,0 +1,79 @@ + + + + + + + + + + + + + + + + + + + + TVM RPC Test Page + + + + + +

+    (page body; HTML markup lost in extraction)
+    TVM WebSocket RPC Server
+    To use this page:
+      • Run "make" and "npm run bundle" to create the libraries.
+      • Run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start the proxy.
+      • Click "Connect to proxy".
+      • Run "python tests/python/websock_rpc_test.py" to run the RPC client.
+    Options: Proxy URL, RPC Server Key
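The last step above hands the actual test off to a Python RPC client. As a rough, illustrative sketch only (assuming the stock `tvm.rpc` client API, the proxy's default RPC port 9090, and the server key "wasm" used by the node example below; the real `websock_rpc_test.py` may wire up the wasm session differently):

```python
# Hypothetical client-side flow for the page above -- a sketch, not the actual test script.
import tvm
from tvm import rpc

# Assumptions: the proxy was started via `python -m tvm.exec.rpc_proxy --example-rpc=1`
# on localhost (RPC port 9090 by default) and the browser/node server registered
# itself under the key "wasm".
remote = rpc.connect("localhost", 9090, key="wasm")

# The wasm runtime registers a few testing helpers globally (see wasm_runtime.cc),
# so calling one of them remotely is a quick sanity check.
fecho = remote.get_function("testing.echo")
print(fecho("hello from the python client"))
```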
+ + + diff --git a/web/apps/node/example.js b/web/apps/node/example.js new file mode 100644 index 000000000000..f81a9c903e5d --- /dev/null +++ b/web/apps/node/example.js @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Example code to start the runtime. + */ +const path = require("path"); +const fs = require("fs"); +const tvmjs = require("../../dist"); + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); +// Here we pass the javascript module generated by emscripten as the +// LibraryProvider to provide WASI related libraries. +// the async version of the API. +tvmjs.instantiate(wasmSource, new EmccWASI()) +.then((tvm) => { + // List all the global functions from the runtime. + console.log("Runtime functions using EmccWASI\n", tvm.listGlobalFuncNames()); +}); + diff --git a/web/apps/node/wasi_example.js b/web/apps/node/wasi_example.js new file mode 100644 index 000000000000..95ec2e0b1d07 --- /dev/null +++ b/web/apps/node/wasi_example.js @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Example code to start the runtime. + */ +const { WASI } = require('wasi'); +const path = require("path"); +const fs = require("fs"); +const tvmjs = require("../../dist"); + +const wasmPath = tvmjs.wasmPath(); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); + +const wasi = new WASI({ args: process.argv, env: process.env }); +// Here we pass the javascript module generated by emscripten as the +// LibraryProvider to provide WASI related libraries. +const tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), wasi); + +// List all the global functions from the runtime. 
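+// Run via `npm run example:wasi`: the npm script passes
+// --experimental-wasi-unstable-preview1 and --experimental-wasm-bigint to node.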
+console.log("Runtime using WASI\n", tvm.listGlobalFuncNames()); diff --git a/web/example_rpc_node.js b/web/apps/node/wasi_rpc_server.js similarity index 60% rename from web/example_rpc_node.js rename to web/apps/node/wasi_rpc_server.js index 45f917a3234b..eb4c6ed52be9 100644 --- a/web/example_rpc_node.js +++ b/web/apps/node/wasi_rpc_server.js @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -17,17 +17,20 @@ * under the License. */ -// Javascript RPC server example -// Start and connect to websocket proxy. +/** + * Example code to start the RPC server on nodejs using WASI + */ +const { WASI } = require("wasi"); +const tvmjs = require("../../dist"); + +// Get import returns a fresh library in each call. +const getImports = () => { + return new WASI({ + args: process.argv, + env: process.env + }); +}; -// Load Emscripten Module, need to change path to root/lib -const path = require("path"); -process.chdir(path.join(__dirname, "../lib")); -var Module = require("../lib/libtvm_web_runtime.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); +const proxyUrl = "ws://localhost:8888/ws"; -var websock_proxy = "ws://localhost:9190/ws"; -var num_sess = 100; -tvm.startRPCServer(websock_proxy, "js", num_sess) +new tvmjs.RPCServer(proxyUrl, "wasm", getImports, console.log); diff --git a/tests/webgl/test_local_multi_stage.py b/web/emcc/decorate_as_wasi.py similarity index 50% rename from tests/webgl/test_local_multi_stage.py rename to web/emcc/decorate_as_wasi.py index 54a554b74ed9..741e33bb22ea 100644 --- a/tests/webgl/test_local_multi_stage.py +++ b/web/emcc/decorate_as_wasi.py @@ -14,34 +14,29 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-import tvm -from tvm import te -import numpy as np +"""Decorate emcc generated js to a WASI compatible API.""" -def test_local_multi_stage(): - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return +import sys - n = te.var("n") - A = te.placeholder((n,), name='A', dtype="int32") - B = te.compute((n,), lambda i: A[i] + 1, name="B") - C = te.compute((n,), lambda i: B[i] * 2, name="C") +template_head = """ +function EmccWASI() { +""" - s = te.create_schedule(C.op) - s[B].opengl() - s[C].opengl() +template_tail = """ + this.Module = Module; + this.start = Module.wasmLibraryProvider.start; + this.imports = Module.wasmLibraryProvider.imports; + this.wasiImport = this.imports["wasi_snapshot_preview1"]; +} - f = tvm.build(s, [A, C], "opengl", name="multi_stage") - - ctx = tvm.opengl(0) - n = 10 - a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx) - c = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx) - f(a, c) - - tvm.testing.assert_allclose(c.asnumpy(), (a.asnumpy() + 1) * 2) +if (typeof module !== "undefined" && module.exports) { + module.exports = EmccWASI; +} +""" if __name__ == "__main__": - test_local_multi_stage() + if len(sys.argv) != 3: + print("Usage ") + result = template_head + open(sys.argv[1]).read() + template_tail + with open(sys.argv[2], "w") as fo: + fo.write(result) diff --git a/web/emcc/preload.js b/web/emcc/preload.js new file mode 100644 index 000000000000..882280f9cac0 --- /dev/null +++ b/web/emcc/preload.js @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* eslint-disable no-unused-vars */ +/** + * JS config used by --pre-js in emcc. + * Wrap module as a LibraryProvider. + */ + +var __wasmLib = {}; + +function __wasmLibInstantiateWasm(imports, successCallback) { + __wasmLib.imports = imports; + __wasmLib.successCallback = successCallback; +} + +function __wasmLibStart(wasmInstance) { + __wasmLib.successCallback(wasmInstance); +} + +__wasmLib.start = __wasmLibStart; + +var Module = { + "instantiateWasm": __wasmLibInstantiateWasm, + "wasmLibraryProvider": __wasmLib +}; diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc new file mode 100644 index 000000000000..97099e75f16f --- /dev/null +++ b/web/emcc/tvmjs_support.cc @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file tvmjs_support.cc + * \brief Support functions to be linked with wasm_runtime to provide + * PackedFunc callbacks in tvmjs. + * We do not need to link this file in standalone wasm. + */ + +// configurations for the dmlc log. +#define DMLC_LOG_CUSTOMIZE 0 +#define DMLC_LOG_STACK_TRACE 0 +#define DMLC_LOG_DEBUG 0 +#define DMLC_LOG_NODATE 1 +#define DMLC_LOG_FATAL_THROW 0 + + +#include +#include +#include +#include +#include + +extern "C" { +// --- Additional C API for the Wasm runtime --- +/*! + * \brief Allocate space aligned to 64 bit. + * \param size The size of the space. + * \return The allocated space. + */ +TVM_DLL void* TVMWasmAllocSpace(int size); + +/*! + * \brief Free the space allocated by TVMWasmAllocSpace. + * \param data The data pointer. + */ +TVM_DLL void TVMWasmFreeSpace(void* data); + +/*! + * \brief Create PackedFunc from a resource handle. + * \param resource_handle The handle to the resource. + * \param out The output PackedFunc. + * \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer +3A * \return 0 if success. + */ +TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle, + TVMFunctionHandle *out); + +// --- APIs to be implemented by the frontend. --- +/*! + * \brief Wasm frontend packed function caller. + * + * \param args The arguments + * \param type_codes The type codes of the arguments + * \param num_args Number of arguments. + * \param ret The return value handle. + * \param resource_handle The handle additional resouce handle from fron-end. + * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. + */ +extern int TVMWasmPackedCFunc(TVMValue* args, + int* type_codes, + int num_args, + TVMRetValueHandle ret, + void* resource_handle); + +/*! + * \brief Wasm frontend resource finalizer. + * \param resource_handle The pointer to the external resource. + */ +extern void TVMWasmPackedCFuncFinalizer(void* resource_handle); +} // extern "C" + + +void* TVMWasmAllocSpace(int size) { + int num_count = (size + 7) / 8; + return new int64_t[num_count]; +} + +void TVMWasmFreeSpace(void* arr) { + delete[] static_cast(arr); +} + +int TVMWasmFuncCreateFromCFunc(void* resource_handle, + TVMFunctionHandle *out) { + return TVMFuncCreateFromCFunc( + TVMWasmPackedCFunc, resource_handle, + TVMWasmPackedCFuncFinalizer, out); +} + + +namespace tvm { +namespace runtime { + +// chrono in the WASI does not provide very accurate time support +// and also have problems in the i64 support in browser. 
+// We redirect the timer to a JS side time using performance.now +PackedFunc WrapWasmTimeEvaluator(PackedFunc pf, + TVMContext ctx, + int number, + int repeat, + int min_repeat_ms) { + auto ftimer = [pf, ctx, number, repeat, min_repeat_ms]( + TVMArgs args, TVMRetValue *rv) { + + TVMRetValue temp; + auto finvoke = [&](int n) { + // start timing + for (int i = 0; i < n; ++i) { + pf.CallPacked(args, &temp); + } + DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); + }; + + auto* get_timer = runtime::Registry::Get("wasm.GetTimer"); + CHECK(get_timer != nullptr) << "Cannot find wasm.GetTimer in the global function"; + TypedPackedFunc timer_ms = (*get_timer)( + TypedPackedFunc(finvoke)); + + std::ostringstream os; + finvoke(1); + + int setup_number = number; + + for (int i = 0; i < repeat; ++i) { + double duration_ms = 0.0; + + do { + if (duration_ms > 0.0) { + setup_number = static_cast( + std::max((min_repeat_ms / (duration_ms / number) + 1), + number * 1.618)); // 1.618 is chosen by random + } + duration_ms = timer_ms(setup_number); + } while (duration_ms < min_repeat_ms); + + double speed = duration_ms / setup_number / 1000; + os.write(reinterpret_cast(&speed), sizeof(speed)); + } + + std::string blob = os.str(); + TVMByteArray arr; + arr.size = blob.length(); + arr.data = blob.data(); + // return the time. + *rv = arr; + }; + return PackedFunc(ftimer); +} + +TVM_REGISTER_GLOBAL("wasm.RPCTimeEvaluator") +.set_body_typed([](Optional opt_mod, + std::string name, + int device_type, + int device_id, + int number, + int repeat, + int min_repeat_ms) { + TVMContext ctx; + ctx.device_type = static_cast(device_type); + ctx.device_id = device_id; + + if (opt_mod.defined()) { + Module m = opt_mod.value(); + std::string tkey = m->type_key(); + return WrapWasmTimeEvaluator( + m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms); + } else { + auto* pf = runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find " << name << " in the global function"; + return WrapWasmTimeEvaluator( + *pf, ctx, number, repeat, min_repeat_ms); + } +}); + +} // namespace runtime +} // namespace tvm diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc new file mode 100644 index 000000000000..6ff652cf408a --- /dev/null +++ b/web/emcc/wasm_runtime.cc @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file wasm_runtime.cc + * \brief TVM wasm runtime library pack. + */ + +// configurations for the dmlc log. 
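+// (stack traces, date prefixes and fatal throws are switched off for the wasm build)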
+#define DMLC_LOG_CUSTOMIZE 0 +#define DMLC_LOG_STACK_TRACE 0 +#define DMLC_LOG_DEBUG 0 +#define DMLC_LOG_NODATE 1 +#define DMLC_LOG_FATAL_THROW 0 + +#include +#include + +#include "src/runtime/c_runtime_api.cc" +#include "src/runtime/cpu_device_api.cc" +#include "src/runtime/workspace_pool.cc" +#include "src/runtime/library_module.cc" +#include "src/runtime/system_library.cc" + +#include "src/runtime/module.cc" +#include "src/runtime/ndarray.cc" +#include "src/runtime/object.cc" +#include "src/runtime/registry.cc" +#include "src/runtime/file_util.cc" +#include "src/runtime/graph/graph_runtime.cc" +#include "src/runtime/rpc/rpc_session.cc" +#include "src/runtime/rpc/rpc_endpoint.cc" +#include "src/runtime/rpc/rpc_event_impl.cc" +#include "src/runtime/rpc/rpc_channel.cc" +#include "src/runtime/rpc/rpc_local_session.cc" +#include "src/runtime/rpc/rpc_module.cc" + + +// --- Implementations of backend and wasm runtime API. --- + +int TVMBackendParallelLaunch(FTVMParallelLambda flambda, + void* cdata, + int num_task) { + TVMParallelGroupEnv env; + env.num_task = 1; + flambda(0, &env, cdata); + return 0; +} + +int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { + return 0; +} + +// --- Environment PackedFuncs for testing --- +namespace tvm { +namespace runtime { + +TVM_REGISTER_GLOBAL("testing.echo") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = args[0]; +}); + +TVM_REGISTER_GLOBAL("testing.add_one") +.set_body_typed([](int x) { + return x + 1; +}); + +TVM_REGISTER_GLOBAL("testing.wrap_callback") +.set_body([](TVMArgs args, TVMRetValue *ret) { + PackedFunc pf = args[0]; + *ret = runtime::TypedPackedFunc([pf](){ + pf(); + }); + }); +} // namespace runtime +} // namespace tvm diff --git a/web/example_rpc.html b/web/example_rpc.html deleted file mode 100644 index ae2b1dd9c44b..000000000000 --- a/web/example_rpc.html +++ /dev/null @@ -1,61 +0,0 @@ - - - - - - - - - - - - - - - - - - - TVM RPC Test Page - - - - -

-    (deleted page body; HTML markup lost in extraction)
-    TVM Test Page
-    To use this page, the easiest way is to do:
-      • run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy.
-      • Click Connect to proxy.
-      • run "python tests/web/websock_rpc_test.py" to run the rpc client.
-    Options: Proxy URL, RPC Server Key
- - - - diff --git a/web/package.json b/web/package.json new file mode 100644 index 000000000000..76aa111e2acf --- /dev/null +++ b/web/package.json @@ -0,0 +1,29 @@ +{ + "name": "tvmjs", + "displayName": "TVM Wasm JS runtime", + "license": "Apache-2.0", + "version": "0.7.0", + "scripts": { + "build": "tsc -b", + "watch": "tsc -b -w", + "lint": "eslint -c .eslintrc.json .", + "bundle": "npm run build && rollup -c rollup.config.js", + "example": "npm run bundle && node apps/node/example.js", + "example:wasi": "npm run bundle && node --experimental-wasi-unstable-preview1 --experimental-wasm-bigint apps/node/wasi_example.js", + "rpc": "npm run bundle && node --experimental-wasi-unstable-preview1 --experimental-wasm-bigint apps/node/wasi_rpc_server.js" + }, + "devDependencies": { + "typescript": "^3.8.3", + "@types/node": "^12.12.37", + "eslint": "^6.8.0", + "@typescript-eslint/eslint-plugin": "^2.29.0", + "@typescript-eslint/parser": "^2.29.0", + "typedoc": "^0.17.6", + "rollup": "^2.7.6", + "ws": "^7.2.5", + "@rollup/plugin-commonjs": "^11.1.0", + "@rollup/plugin-node-resolve": "^7.1.3", + "rollup-plugin-typescript2": "^0.27.0" + }, + "dependencies": {} +} diff --git a/web/.eslintrc.js b/web/rollup.config.js similarity index 69% rename from web/.eslintrc.js rename to web/rollup.config.js index 2e82ba50e3c4..0046e4434076 100644 --- a/web/.eslintrc.js +++ b/web/rollup.config.js @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -17,29 +17,18 @@ * under the License. */ -module.exports = { - "env": { - "browser": true, - "node": true, - "es6": true +import commonjs from '@rollup/plugin-commonjs'; +import resolve from '@rollup/plugin-node-resolve'; + +export default { + input: 'dist/index.js', + output: { + file: 'dist/tvmjs.bundle.js', + format: 'umd', + name: 'tvmjs', + exports: 'named', + globals: {'ws': 'ws'} }, - "extends": "eslint:recommended", - "rules": { - "indent": [ - "error", - 2 - ], - "linebreak-style": [ - "error", - "unix" - ], - "quotes": [ - "error", - "double" - ], - "semi": [ - "error", - "always" - ] - } + plugins: [commonjs(), resolve()], + external: ['ws'] }; diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts new file mode 100644 index 000000000000..f533b4e491a6 --- /dev/null +++ b/web/src/ctypes.ts @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Types for C API. 
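+ * Each typing below mirrors the C prototype quoted in its doc comment.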
+ */ + +/** A pointer to points to the raw address space. */ +export type Pointer = number; + +/** A pointer offset, need to add a base address to get a valid ptr. */ +export type PtrOffset = number; + +// -- TVM runtime C API -- +/** + * const char *TVMGetLastError(); + */ +export type FTVMGetLastError = () => Pointer; + +/** + * int TVMModGetFunction(TVMModuleHandle mod, + * const char* func_name, + * int query_imports, + * TVMFunctionHandle *out); + */ +export type FTVMModGetFunction = ( + mod: Pointer, funcName: Pointer, queryImports: number, out: Pointer) => number; +/** + * int TVMModImport(TVMModuleHandle mod, + * TVMModuleHandle dep); + */ +export type FTVMModImport = (mod: Pointer, dep: Pointer) => number; +/** + * int TVMModFree(TVMModuleHandle mod); + */ +export type FTVMModFree = (mod: Pointer) => number; + +/** + * int TVMFuncFree(TVMFunctionHandle func); + */ +export type FTVMFuncFree = (func: Pointer) => number; + +/** + * int TVMFuncCall(TVMFunctionHandle func, + * TVMValue* arg_values, + * int* type_codes, + * int num_args, + * TVMValue* ret_val, + * int* ret_type_code); + */ +export type FTVMFuncCall = ( + func: Pointer, argValues: Pointer, typeCode: Pointer, + nargs: number, retValue: Pointer, retCode: Pointer) => number; + +/** + * int TVMCFuncSetReturn(TVMRetValueHandle ret, + * TVMValue* value, + * int* type_code, + * int num_ret); + */ +export type FTVMCFuncSetReturn = ( + ret: Pointer, value: Pointer, typeCode: Pointer, numRet: number) => number; + +/** + * int TVMCbArgToReturn(TVMValue* value, int* code); + */ +export type FTVMCbArgToReturn = (value: Pointer, code: Pointer) => number; + +/** + * int TVMFuncListGlobalNames(int* outSize, const char*** outArray); + */ +export type FTVMFuncListGlobalNames = (outSize: Pointer, outArray: Pointer) => number; + +/** + * int TVMFuncRegisterGlobal( + * const char* name, TVMFunctionHandle f, int override); + */ +export type FTVMFuncRegisterGlobal = ( + name: Pointer, f: Pointer, override: number) => number; + +/** + *int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); + */ +export type FTVMFuncGetGlobal = (name: Pointer, out: Pointer) => number; + +/** + * int TVMArrayAlloc(const tvm_index_t* shape, + * int ndim, + * int dtype_code, + * int dtype_bits, + * int dtype_lanes, + * int device_type, + * int device_id, + * TVMArrayHandle* out); + */ +export type FTVMArrayAlloc = ( + shape: Pointer, ndim: number, + dtypeCode: number, dtypeBits: number, + dtypeLanes: number, deviceType: number, deviceId: number, + out: Pointer) => number; + +/** + * int TVMArrayFree(TVMArrayHandle handle); + */ +export type FTVMArrayFree = (handle: Pointer) => number; + +/** + * int TVMArrayCopyFromBytes(TVMArrayHandle handle, + * void* data, + * size_t nbytes); + */ +export type FTVMArrayCopyFromBytes = ( + handle: Pointer, data: Pointer, nbytes: number) => number; + +/** + * int TVMArrayCopyToBytes(TVMArrayHandle handle, + * void* data, + * size_t nbytes); + */ +export type FTVMArrayCopyToBytes = ( + handle: Pointer, data: Pointer, nbytes: number) => number; + +/** + * int TVMArrayCopyFromTo(TVMArrayHandle from, + * TVMArrayHandle to, + * TVMStreamHandle stream); + */ +export type FTVMArrayCopyFromTo = ( + from: Pointer, to: Pointer, stream: Pointer) => number; + +/** + * int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream); + */ +export type FTVMSynchronize = ( + deviceType: number, deviceId: number, stream: Pointer) => number; + +/** + * typedef int (*TVMBackendPackedCFunc)(TVMValue* args, + * int* type_codes, + * 
int num_args, + * TVMValue* out_ret_value, + * int* out_ret_tcode); + */ +export type FTVMBackendPackedCFunc = ( + argValues: Pointer, argCodes: Pointer, nargs: number, + outValue: Pointer, outCode: Pointer) => number; + +// -- TVM Wasm Auxiliary C API -- + +/** void* TVMWasmAllocSpace(int size); */ +export type FTVMWasmAllocSpace = (size: number) => Pointer; + +/** void TVMWasmFreeSpace(void* data); */ +export type FTVMWasmFreeSpace = (ptr: Pointer) => void; + +/** + * int TVMWasmPackedCFunc(TVMValue* args, + * int* type_codes, + * int num_args, + * TVMRetValueHandle ret, + * void* resource_handle); + */ +export type FTVMWasmPackedCFunc = ( + args: Pointer, typeCodes: Pointer, nargs: number, + ret: Pointer, resourceHandle: Pointer) => number; + +/** + * int TVMWasmFuncCreateFromCFunc(void* resource_handle, + * TVMFunctionHandle *out); + */ +export type FTVMWasmFuncCreateFromCFunc = ( + resource: Pointer, out: Pointer) => number; + +/** + * void TVMWasmPackedCFuncFinalizer(void* resource_handle); + */ +export type FTVMWasmPackedCFuncFinalizer = (resourceHandle: Pointer) => void; + +/** + * Size of common data types. + */ +export const enum SizeOf { + U8 = 1, + U16 = 2, + I32 = 4, + I64 = 8, + F32 = 4, + F64 = 8, + TVMValue = 8, + DLDataType = I32, + DLContext = I32 + I32, +} + +/** + * Type code in TVM FFI. + */ +export const enum TypeCode { + Int = 0, + UInt = 1, + Float = 2, + TVMOpaqueHandle = 3, + Null = 4, + TVMDataType = 5, + TVMContext = 6, + TVMDLTensorHandle = 7, + TVMObjectHandle = 8, + TVMModuleHandle = 9, + TVMPackedFuncHandle = 10, + TVMStr = 11, + TVMBytes = 12, + TVMNDArrayHandle = 13, + TVMObjectRValueRefArg = 14 +} \ No newline at end of file diff --git a/web/src/environment.ts b/web/src/environment.ts new file mode 100644 index 000000000000..df0fe68c81e0 --- /dev/null +++ b/web/src/environment.ts @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Runtime environment that provide js libaries calls. + */ +import { Pointer } from "./ctypes"; +import { LibraryProvider } from "./types"; +import { assert } from "./support"; +import * as ctypes from "./ctypes"; + +/** + * Detect library provider from the importObject. + * + * @param importObject The import object. + */ +function detectLibraryProvider( + importObject: Record +): LibraryProvider | undefined { + if ( + importObject["wasmLibraryProvider"] && + importObject["wasmLibraryProvider"]["start"] && + importObject["wasmLibraryProvider"]["imports"] !== undefined + ) { + const item = importObject as { wasmLibraryProvider: LibraryProvider }; + // create provider so that we capture imports in the provider. 
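+      // An importObject of this shape looks like the following (illustrative only):
+      //   { wasmLibraryProvider: { imports: {...}, start: (inst) => {...} } }
+      // The branches below also accept a bare LibraryProvider ({ imports, start })
+      // or a WASI object exposing `wasiImport` and `start`.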
+ return { + imports: item.wasmLibraryProvider.imports, + start: (inst: WebAssembly.Instance): void => { + item.wasmLibraryProvider.start(inst); + }, + }; + } else if (importObject["imports"] && importObject["start"] !== undefined) { + return importObject as LibraryProvider; + } else if (importObject["wasiImport"] && importObject["start"] !== undefined) { + // WASI + return { + imports: { + "wasi_snapshot_preview1": importObject["wasiImport"], + }, + start: (inst: WebAssembly.Instance): void => { + importObject["start"](inst); + } + }; + } else { + return undefined; + } +} + +/** + * Environment to impelement most of the JS library functions. + */ +export class Environment implements LibraryProvider { + logger: (msg: string) => void; + imports: Record; + /** + * Maintains a table of FTVMWasmPackedCFunc that the C part + * can call via TVMWasmPackedCFunc. + * + * We maintain a separate table so that we can have un-limited amount + * of functions that do not maps to the address space. + */ + packedCFuncTable: Array = [ + undefined, + ]; + /** + * Free table index that can be recycled. + */ + packedCFuncTableFreeId: Array = []; + + private libProvider?: LibraryProvider; + + constructor( + importObject: Record = {}, + logger: (msg: string) => void = console.log + ) { + this.logger = logger; + this.libProvider = detectLibraryProvider(importObject); + // get imports from the provider + if (this.libProvider !== undefined) { + this.imports = this.libProvider.imports; + } else { + this.imports = importObject; + } + // update with more functions + this.imports.env = this.environment(this.imports.env); + } + + /** Mark the start of the instance. */ + start(inst: WebAssembly.Instance): void { + if (this.libProvider !== undefined) { + this.libProvider.start(inst); + } + } + + private environment(initEnv: Record): Record { + // default env can be be overriden by libraries. + const defaultEnv = { + "__cxa_thread_atexit": (): void => {}, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + "emscripten_notify_memory_growth": (index: number): void => {} + }; + const wasmPackedCFunc: ctypes.FTVMWasmPackedCFunc = ( + args: Pointer, + typeCodes: Pointer, + nargs: number, + ret: Pointer, + resourceHandle: Pointer + ): number => { + const cfunc = this.packedCFuncTable[resourceHandle]; + assert(cfunc !== undefined); + return cfunc(args, typeCodes, nargs, ret, resourceHandle); + }; + + const wasmPackedCFuncFinalizer: ctypes.FTVMWasmPackedCFuncFinalizer = ( + resourceHandle: Pointer + ): void => { + this.packedCFuncTable[resourceHandle] = undefined; + this.packedCFuncTableFreeId.push(resourceHandle); + }; + + const newEnv = { + TVMWasmPackedCFunc: wasmPackedCFunc, + TVMWasmPackedCFuncFinalizer: wasmPackedCFuncFinalizer, + "__console_log": (msg: string): void => { + this.logger(msg); + } + }; + return Object.assign(defaultEnv, initEnv, newEnv); + } +} \ No newline at end of file diff --git a/web/src/index.ts b/web/src/index.ts new file mode 100644 index 000000000000..5d7d7ccc39cc --- /dev/null +++ b/web/src/index.ts @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +export { + Scalar, DLContext, DLDataType, + PackedFunc, Module, NDArray, Instance, + instantiate +} from "./runtime"; +export { Disposable, LibraryProvider } from "./types"; +export { RPCServer } from "./rpc_server"; +export { wasmPath } from "./support"; \ No newline at end of file diff --git a/web/src/memory.ts b/web/src/memory.ts new file mode 100644 index 000000000000..ac737b7c297d --- /dev/null +++ b/web/src/memory.ts @@ -0,0 +1,408 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Classes to manipulate Wasm memories. + */ +import { Pointer, PtrOffset, SizeOf } from "./ctypes"; +import { Disposable } from "./types"; +import { assert, StringToUint8Array } from "./support"; + +import * as ctypes from "./ctypes"; + +/** + * Wasm Memory wrapper to perform JS side raw memory access. + */ +export class Memory { + memory: WebAssembly.Memory; + wasm32 = true; + private buffer: ArrayBuffer | SharedArrayBuffer; + private viewU8: Uint8Array; + private viewU16: Uint16Array; + private viewI32: Int32Array; + private viewU32: Uint32Array; + private viewF32: Float32Array; + private viewF64: Float64Array; + + constructor(memory: WebAssembly.Memory) { + this.memory = memory; + this.buffer = this.memory.buffer; + this.viewU8 = new Uint8Array(this.buffer); + this.viewU16 = new Uint16Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF32 = new Float32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } + + loadU8(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU8[ptr >> 0]; + } + + loadU16(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU16[ptr >> 1]; + } + + loadU32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU32[ptr >> 2]; + } + + loadI32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewI32[ptr >> 2]; + } + + loadI64(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const base = ptr >> 2; + // assumes little endian, for now truncate high. 
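+    // A wider read would fold in the high word as well, e.g. (sketch; precision
+    // is lost past 2^53 once combined into a JS number):
+    //   const low = this.viewU32[base];
+    //   const high = this.viewI32[base + 1];
+    //   return high * 0x100000000 + low;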
+ return this.viewI32[base]; + } + + loadF32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewF32[ptr >> 2]; + } + + loadF64(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewF64[ptr >> 3]; + } + + loadPointer(ptr: Pointer): Pointer { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + if (this.wasm32) { + return this.loadU32(ptr); + } else { + return this.loadI64(ptr); + } + } + loadUSize(ptr: Pointer): Pointer { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + if (this.wasm32) { + return this.loadU32(ptr); + } else { + return this.loadI64(ptr); + } + } + sizeofPtr(): number { + return this.wasm32 ? SizeOf.I32 : SizeOf.I64; + } + /** + * Load raw bytes from ptr. + * @param ptr The head address + * @param numBytes The number + */ + loadRawBytes(ptr: Pointer, numBytes: number): Uint8Array { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const result = new Uint8Array(numBytes); + result.set(this.viewU8.slice(ptr, ptr + numBytes)); + return result; + } + /** + * Load TVMByteArray from ptr. + * + * @param ptr The address of the header. + */ + loadTVMBytes(ptr: Pointer): Uint8Array { + const data = this.loadPointer(ptr); + const length = this.loadUSize(ptr + this.sizeofPtr()); + return this.loadRawBytes(data, length); + } + /** + * Load null-terminated C-string from ptr. + * @param ptr The head address + */ + loadCString(ptr: Pointer): string { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + // NOTE: the views are still valid for read. + const ret = []; + let ch = 1; + while (ch != 0) { + ch = this.viewU8[ptr]; + if (ch != 0) { + ret.push(String.fromCharCode(ch)); + } + ++ptr; + } + return ret.join(""); + } + /** + * Store raw bytes to the ptr. + * @param ptr The head address. + * @param bytes The bytes content. + */ + storeRawBytes(ptr: Pointer, bytes: Uint8Array): void { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + this.viewU8.set(bytes, ptr); + } + + /** + * Update memory view after the memory growth. + */ + private updateViews(): void { + this.buffer = this.memory.buffer; + this.viewU8 = new Uint8Array(this.buffer); + this.viewU16 = new Uint16Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF32 = new Float32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } +} + +/** + * Auxiliary call stack for the FFI calls. + * + * Lifecyle of a call stack. + * - Calls into allocXX to allocate space, mixed with storeXXX to store data. + * - Calls into ptrFromOffset, no further allocation(as ptrFromOffset can change), + * can still call into storeXX + * - Calls into commitToWasmMemory once. + * - reset. + */ +export class CachedCallStack implements Disposable { + /** List of temporay arguments that can be disposed during reset. 
*/ + tempArgs: Array = []; + + private memory: Memory; + private cAllocSpace: ctypes.FTVMWasmAllocSpace; + private cFreeSpace: ctypes.FTVMWasmFreeSpace; + + private buffer: ArrayBuffer; + private viewU8: Uint8Array; + private viewI32: Int32Array; + private viewU32: Uint32Array; + private viewF64: Float64Array; + + private stackTop: PtrOffset = 0; + private basePtr: Pointer = 0; + + private addressToSetTargetValue: Array<[PtrOffset, PtrOffset]> = []; + + constructor( + memory: Memory, + allocSpace: ctypes.FTVMWasmAllocSpace, + freeSpace: ctypes.FTVMWasmFreeSpace + ) { + const initCallStackSize = 128; + this.memory = memory; + this.cAllocSpace = allocSpace; + this.cFreeSpace = freeSpace; + this.buffer = new ArrayBuffer(initCallStackSize); + this.basePtr = this.cAllocSpace(initCallStackSize); + this.viewU8 = new Uint8Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + this.updateViews(); + } + + dispose(): void { + if (this.basePtr != 0) { + this.cFreeSpace(this.basePtr); + this.basePtr = 0; + } + } + /** + * Rest the call stack so that it can be reused again. + */ + reset(): void { + this.stackTop = 0; + assert(this.addressToSetTargetValue.length == 0); + while (this.tempArgs.length != 0) { + (this.tempArgs.pop() as Disposable).dispose(); + } + } + + /** + * Commit all the cached data to WasmMemory. + * This function can only be called once. + * No further store function should be called. + * + * @param nbytes Number of bytes to be stored. + */ + commitToWasmMemory(nbytes: number = this.stackTop): void { + // commit all pointer values. + while (this.addressToSetTargetValue.length != 0) { + const [targetOffset, valueOffset] = this.addressToSetTargetValue.pop() as [ + number, + number + ]; + this.storePtr(targetOffset, this.ptrFromOffset(valueOffset)); + } + this.memory.storeRawBytes(this.basePtr, this.viewU8.slice(0, nbytes)); + } + + /** + * Allocate space by number of bytes + * @param nbytes Number of bytes. + * @note This function always allocate space that aligns to 64bit. + */ + allocRawBytes(nbytes: number): PtrOffset { + // always aligns to 64bit + nbytes = ((nbytes + 7) >> 3) << 3; + + if (this.stackTop + nbytes > this.buffer.byteLength) { + const newSize = Math.max( + this.buffer.byteLength * 2, + this.stackTop + nbytes + ); + const oldU8 = this.viewU8; + this.buffer = new ArrayBuffer(newSize); + this.updateViews(); + this.viewU8.set(oldU8); + if (this.basePtr != 0) { + this.cFreeSpace(this.basePtr); + } + this.basePtr = this.cAllocSpace(newSize); + } + const retOffset = this.stackTop; + this.stackTop += nbytes; + return retOffset; + } + + /** + * Allocate space for pointers. + * @param count Number of pointers. + * @returns The allocated pointer array. + */ + allocPtrArray(count: number): PtrOffset { + return this.allocRawBytes(this.memory.sizeofPtr() * count); + } + + /** + * Get the real pointer from offset values. + * Note that the returned value becomes obsolete if alloc is called on the stack. + * @param offset The allocated offset. 
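+   *
+   * A minimal usage sketch (assuming `stack` is a CachedCallStack and the FFI
+   * call happens right after the commit):
+   *   const off = stack.allocRawBytes(SizeOf.F64);
+   *   stack.storeF64(off, 3.14);
+   *   stack.commitToWasmMemory();
+   *   const ptr = stack.ptrFromOffset(off); // valid until the next alloc call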
+ */ + ptrFromOffset(offset: PtrOffset): Pointer { + return this.basePtr + offset; + } + + // Store APIs + storePtr(offset: PtrOffset, value: Pointer): void { + if (this.memory.wasm32) { + this.storeU32(offset, value); + } else { + this.storeI64(offset, value); + } + } + + storeUSize(offset: PtrOffset, value: Pointer): void { + if (this.memory.wasm32) { + this.storeU32(offset, value); + } else { + this.storeI64(offset, value); + } + } + + storeI32(offset: PtrOffset, value: number): void { + this.viewI32[offset >> 2] = value; + } + + storeU32(offset: PtrOffset, value: number): void { + this.viewU32[offset >> 2] = value; + } + + storeI64(offset: PtrOffset, value: number): void { + // For now, just store as 32bit + // NOTE: wasm always uses little endian. + const low = value & 0xffffffff; + const base = offset >> 2; + this.viewI32[base] = low; + this.viewI32[base + 1] = 0; + } + + storeF64(offset: PtrOffset, value: number): void { + this.viewF64[offset >> 3] = value; + } + + storeRawBytes(offset: PtrOffset, bytes: Uint8Array): void { + this.viewU8.set(bytes, offset); + } + + /** + * Allocate then set C-String pointer to the offset. + * This function will call into allocBytes to allocate necessary data. + * The address won't be set immediately(because the possible change of basePtr) + * and will be filled when we commit the data. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + allocThenSetArgString(offset: PtrOffset, data: string): void { + const strOffset = this.allocRawBytes(data.length + 1); + this.storeRawBytes(strOffset, StringToUint8Array(data)); + this.addressToSetTargetValue.push([offset, strOffset]); + } + /** + * Allocate then set the argument location with a TVMByteArray. + * Allocate new temporary space for bytes. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + allocThenSetArgBytes(offset: PtrOffset, data: Uint8Array): void { + // Note: size of size_t equals sizeof ptr. + const headerOffset = this.allocRawBytes(this.memory.sizeofPtr() * 2); + const dataOffset = this.allocRawBytes(data.length); + this.storeRawBytes(dataOffset, data); + this.storeUSize(headerOffset + this.memory.sizeofPtr(), data.length); + + this.addressToSetTargetValue.push([offset, headerOffset]); + this.addressToSetTargetValue.push([headerOffset, dataOffset]); + } + + /** + * Update internal cache views. + */ + private updateViews(): void { + this.viewU8 = new Uint8Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } +} diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts new file mode 100644 index 000000000000..054a1b6019cc --- /dev/null +++ b/web/src/rpc_server.ts @@ -0,0 +1,379 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { SizeOf, TypeCode } from "./ctypes"; +import { assert, StringToUint8Array, Uint8ArrayToString } from "./support"; +import * as runtime from "./runtime"; +import { Class } from "estree"; + +enum RPCServerState { + InitHeader, + InitHeaderKey, + InitServer, + WaitForCallback, + ReceivePacketHeader, + ReceivePacketBody, +} + +/** RPC magic header */ +const RPC_MAGIC = 0xff271; + +/** + * An utility class to read from binary bytes. + */ +class ByteStreamReader { + offset = 0; + bytes: Uint8Array; + + constructor(bytes: Uint8Array) { + this.bytes = bytes; + } + + readU32(): number { + const i = this.offset; + const b = this.bytes; + const val = b[i] | (b[i + 1] << 8) | (b[i + 2] << 16) | (b[i + 3] << 24); + this.offset += 4; + return val; + } + + readU64(): number { + const val = this.readU32(); + this.offset += 4; + return val; + } + + readByteArray(): Uint8Array { + const len = this.readU64(); + assert(this.offset + len <= this.bytes.byteLength); + const ret = new Uint8Array(len); + ret.set(this.bytes.slice(this.offset, this.offset + len)); + this.offset += len; + return ret; + } +} + +/** + * A websocket based RPC + */ +export class RPCServer { + url: string; + key: string; + socket: WebSocket; + state: RPCServerState = RPCServerState.InitHeader; + logger: (msg: string) => void; + getImports: () => Record; + private name: string; + private inst?: runtime.Instance = undefined; + private serverRecvData?: (header: Uint8Array, body: Uint8Array) => void; + private currPacketHeader?: Uint8Array; + private currPacketLength = 0; + private remoteKeyLength = 0; + private pendingBytes = 0; + private buffredBytes = 0; + private messageQueue: Array = []; + + constructor( + url: string, + key: string, + getImports: () => Record, + logger: (msg: string) => void = console.log + ) { + this.url = url; + this.key = key; + this.name = "WebSocketRPCServer[" + this.key + "]: "; + this.getImports = getImports; + this.logger = logger; + + this.checkLittleEndian(); + + if (typeof WebSocket == "undefined") { + // eslint-disable-next-line @typescript-eslint/no-var-requires + const WebSocket = require("ws"); + this.socket = new WebSocket(url); + } else { + this.socket = new (WebSocket as any)(url); + } + + //this.socket = this.getSocket(url); + this.socket.binaryType = "arraybuffer"; + + this.socket.addEventListener("open", (event: Event) => { + return this.onOpen(event); + }); + this.socket.addEventListener("message", (event: MessageEvent) => { + return this.onMessage(event); + }); + this.socket.addEventListener("close", (event: CloseEvent) => { + return this.onClose(event); + }); + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private onClose(_event: CloseEvent): void { + if (this.inst !== undefined) { + this.inst.dispose(); + } + if (this.state == RPCServerState.ReceivePacketHeader) { + this.log("Closing the server in clean state"); + } else { + this.log("Closing the server, final state=" + this.state); + } + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private onOpen(_event: Event): void { + // Send the headers + let bkey = StringToUint8Array("server:" + this.key); + bkey = bkey.slice(0, bkey.length - 1); + const intbuf = new Int32Array(1); + intbuf[0] = RPC_MAGIC; + this.socket.send(intbuf); + intbuf[0] = bkey.length; + this.socket.send(intbuf); + this.socket.send(bkey); + this.log("connected..."); + // request bytes: magic + keylen + 
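+    // The proxy replies with an i32 magic, an i32 key length, and then that many
+    // bytes of the remote key; handleInitHeader/handleInitHeaderKey below consume
+    // these in order.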
this.requestBytes(SizeOf.I32 + SizeOf.I32); + this.state = RPCServerState.InitHeader; + } + + /** Handler for raw message. */ + private onMessage(event: MessageEvent): void { + const buffer = event.data; + this.buffredBytes += buffer.byteLength; + this.messageQueue.push(new Uint8Array(buffer)); + this.processEvents(); + } + /** Process ready events. */ + private processEvents(): void { + while (this.buffredBytes >= this.pendingBytes && this.pendingBytes != 0) { + this.onDataReady(); + } + } + /** State machine to handle each request */ + private onDataReady(): void { + switch (this.state) { + case RPCServerState.InitHeader: { + this.handleInitHeader(); + break; + } + case RPCServerState.InitHeaderKey: { + this.handleInitHeaderKey(); + break; + } + case RPCServerState.ReceivePacketHeader: { + this.currPacketHeader = this.readFromBuffer(SizeOf.I64); + const reader = new ByteStreamReader(this.currPacketHeader); + this.currPacketLength = reader.readU64(); + assert(this.pendingBytes == 0); + this.requestBytes(this.currPacketLength); + this.state = RPCServerState.ReceivePacketBody; + break; + } + case RPCServerState.ReceivePacketBody: { + const body = this.readFromBuffer(this.currPacketLength); + assert(this.pendingBytes == 0); + assert(this.currPacketHeader !== undefined); + this.onPacketReady(this.currPacketHeader, body); + break; + } + case RPCServerState.WaitForCallback: { + assert(this.pendingBytes == 0); + break; + } + default: { + throw new Error("Cannot handle state " + this.state); + } + } + } + + private onPacketReady(header: Uint8Array, body: Uint8Array): void { + if (this.inst === undefined) { + // initialize server. + const reader = new ByteStreamReader(body); + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const code = reader.readU32(); + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const ver = Uint8ArrayToString(reader.readByteArray()); + const nargs = reader.readU32(); + const tcodes = []; + const args = []; + for (let i = 0; i < nargs; ++i) { + tcodes.push(reader.readU32()); + } + + for (let i = 0; i < nargs; ++i) { + const tcode = tcodes[i]; + if (tcode == TypeCode.TVMStr) { + const str = Uint8ArrayToString(reader.readByteArray()); + args.push(str); + } else if (tcode == TypeCode.TVMBytes) { + args.push(reader.readByteArray()); + } else { + throw new Error("cannot support type code " + tcode); + } + } + this.onInitServer(args, header, body); + } else { + assert(this.serverRecvData !== undefined); + this.serverRecvData(header, body); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + } + } + + /** Event handler during server initialization. 
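+   * Expects `args` to be ["rpc.WasmSession", <wasm binary bytes>], as asserted
+   * below; the header/body of this first packet are then replayed into the
+   * freshly created wasm-side session.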
*/ + private onInitServer( + args: Array, + header: Uint8Array, + body: Uint8Array + ): void { + // start the server + assert(args[0] == "rpc.WasmSession"); + assert(args[1] instanceof Uint8Array); + assert(this.pendingBytes == 0); + + runtime.instantiate(args[1].buffer, this.getImports()) + .then((inst: runtime.Instance) => { + this.inst = inst; + const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer"); + + const messageHandler = fcreate( + (cbytes: Uint8Array): runtime.Scalar => { + assert(this.inst !== undefined); + if (this.socket.readyState == 1) { + this.socket.send(cbytes); + return this.inst.scalar(cbytes.length, "int32"); + } else { + return this.inst.scalar(0, "int32"); + } + }, + this.name, + this.key + ); + + fcreate.dispose(); + const writeFlag = this.inst.scalar(3, "int32"); + + this.serverRecvData = (header: Uint8Array, body: Uint8Array): void => { + if (messageHandler(header, writeFlag) == 0) { + this.socket.close(); + } + if (messageHandler(body, writeFlag) == 0) { + this.socket.close(); + } + }; + + // Forward the same init sequence to the wasm RPC. + // The RPC will look for "rpc.wasmSession" + // and we will redirect it to the correct local session. + // register the callback to redirect the session to local. + const flocal = this.inst.getGlobalFunc("rpc.LocalSession"); + const localSession = flocal(); + flocal.dispose(); + assert(localSession instanceof runtime.Module); + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + this.inst.registerFunc( + "rpc.WasmSession", + // eslint-disable-next-line @typescript-eslint/no-unused-vars + (_args: unknown): runtime.Module => { + return localSession; + } + ); + messageHandler(header, writeFlag); + messageHandler(body, writeFlag); + localSession.dispose(); + + this.log("Finish initializing the Wasm Server.."); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + // call process events in case there are bufferred data. 
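+        // (packets that arrived while instantiate() was resolving are still
+        // queued in messageQueue; nothing is consumed in the WaitForCallback state).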
+ this.processEvents(); + }); + this.state = RPCServerState.WaitForCallback; + } + + private log(msg: string): void { + this.logger(this.name + msg); + } + + private handleInitHeader(): void { + const reader = new ByteStreamReader(this.readFromBuffer(SizeOf.I32 * 2)); + const magic = reader.readU32(); + if (magic == RPC_MAGIC + 1) { + throw new Error("key: " + this.key + " has already been used in proxy"); + } else if (magic == RPC_MAGIC + 2) { + throw new Error("RPCProxy do not have matching client key " + this.key); + } + assert(magic == RPC_MAGIC, this.url + " is not an RPC Proxy"); + this.remoteKeyLength = reader.readU32(); + assert(this.pendingBytes == 0); + this.requestBytes(this.remoteKeyLength); + this.state = RPCServerState.InitHeaderKey; + } + + private handleInitHeaderKey(): void { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const remoteKey = Uint8ArrayToString( + this.readFromBuffer(this.remoteKeyLength) + ); + assert(this.pendingBytes == 0); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + } + + private checkLittleEndian(): void { + const a = new ArrayBuffer(4); + const b = new Uint8Array(a); + const c = new Uint32Array(a); + b[0] = 0x11; + b[1] = 0x22; + b[2] = 0x33; + b[3] = 0x44; + assert(c[0] === 0x44332211, "RPCServer little endian to work"); + } + + private requestBytes(nbytes: number): void { + this.pendingBytes += nbytes; + } + + private readFromBuffer(nbytes: number): Uint8Array { + const ret = new Uint8Array(nbytes); + let ptr = 0; + while (ptr < nbytes) { + assert(this.messageQueue.length != 0); + const nleft = nbytes - ptr; + if (this.messageQueue[0].byteLength <= nleft) { + const buffer = this.messageQueue.shift() as Uint8Array; + ret.set(buffer, ptr); + ptr += buffer.byteLength; + } else { + const buffer = this.messageQueue[0]; + ret.set(buffer.slice(0, nleft), ptr); + this.messageQueue[0] = buffer.slice(nleft, buffer.byteLength); + ptr += nleft; + } + } + this.buffredBytes -= nbytes; + this.pendingBytes -= nbytes; + return ret; + } +} diff --git a/web/src/runtime.ts b/web/src/runtime.ts new file mode 100644 index 000000000000..cd9b967596af --- /dev/null +++ b/web/src/runtime.ts @@ -0,0 +1,1113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * TVM JS Wasm Runtime library. + */ +import { Pointer, PtrOffset, SizeOf, TypeCode } from "./ctypes"; +import { Disposable } from "./types"; +import { Memory, CachedCallStack } from "./memory"; +import { assert, StringToUint8Array } from "./support"; +import { Environment } from "./environment"; + +import * as ctypes from "./ctypes"; + +/** + * Type for PackedFunc inthe TVMRuntime. 
+ */ +export type PackedFunc = ((...args: any) => any) & + Disposable & { _tvmPackedCell: PackedFuncCell }; + +/** + * @internal + * FFI Library wrapper, maintains most runtime states. + */ +class FFILibrary implements Disposable { + wasm32: boolean; + memory: Memory; + exports: Record; + private wasmInstance: WebAssembly.Instance; + + private recycledCallStacks: Array = []; + + constructor( + wasmInstance: WebAssembly.Instance, + imports: Record + ) { + this.wasmInstance = wasmInstance; + this.memory = new Memory(this.detectWasmMemory(this.wasmInstance, imports)); + assert( + this.wasmInstance.exports !== undefined, + "Expect the library module contains exports" + ); + this.exports = this.wasmInstance.exports as Record; + this.wasm32 = this.memory.wasm32; + this.validateInstance(); + } + + dispose(): void { + while (this.recycledCallStacks.length != 0) { + (this.recycledCallStacks.pop() as Disposable).dispose(); + } + } + + sizeofPtr(): number { + return this.memory.sizeofPtr(); + } + + checkCall(code: number): void { + if (code != 0) { + const msgPtr = (this.exports + .TVMGetLastError as ctypes.FTVMGetLastError)(); + throw new Error("TVMError: " + this.memory.loadCString(msgPtr)); + } + } + + getOrAllocCallStack(): CachedCallStack { + if (this.recycledCallStacks.length != 0) { + return this.recycledCallStacks.pop() as CachedCallStack; + } + return new CachedCallStack( + this.memory, + this.exports.TVMWasmAllocSpace as ctypes.FTVMWasmAllocSpace, + this.exports.TVMWasmFreeSpace as ctypes.FTVMWasmFreeSpace + ); + } + + recycleCallStack(callstack: CachedCallStack): void { + callstack.reset(); + this.recycledCallStacks.push(callstack); + } + + private validateInstance(): void { + this.checkExports(["TVMWasmAllocSpace", "TVMWasmFreeSpace", "TVMFuncFree"]); + } + + private checkExports(funcNames: Array): void { + const missList = []; + for (const name of funcNames) { + const f = this.exports[name]; + if (!(f instanceof Function)) { + missList.push(name); + } + } + if (missList.length != 0) { + throw new Error("Cannot find " + missList + " in exports"); + } + } + + private detectWasmMemory( + instance: WebAssembly.Instance, + imports: Record + ): WebAssembly.Memory { + if (instance.exports.memory instanceof WebAssembly.Memory) { + return instance.exports.memory; + } + if (imports.env && imports.env.memory instanceof WebAssembly.Memory) { + return imports.env.memory; + } + + throw new Error( + "Cannt detect wasm memory from imports " + + imports + + " or exports" + + instance.exports + ); + } +} + +/** + * A typed scalar constant used to represent a typed number + * argument to PackedFunc calls. + */ +export class Scalar { + /** The value. */ + value: number; + /** The data type of the scalar. */ + dtype: string; + + constructor(value: number, dtype: string) { + this.value = value; + this.dtype = dtype; + } +} + +/** + * Cell holds the PackedFunc object. 
+ */ +class PackedFuncCell implements Disposable { + handle: Pointer; + private lib: FFILibrary; + + constructor(handle: Pointer, lib: FFILibrary) { + this.handle = handle; + this.lib = lib; + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMFuncFree as ctypes.FTVMFuncFree)(this.handle) + ); + this.handle = 0; + } + } +} + +const DeviceEnumToStr: Record = { + 1: "cpu", + 2: "gpu", + 4: "opencl", + 7: "vulkan", + 8: "metal", +}; + +const DeviceStrToEnum: Record = { + cpu: 1, + gpu: 2, + cuda: 2, + cl: 4, + opencl: 4, + vulkan: 7, + metal: 8, +}; + +/** + * Represent a runtime context where a NDArray can reside. + */ +export class DLContext { + /** The device type code of the context. */ + deviceType: number; + /** The device index. */ + deviceId: number; + + private lib: FFILibrary; + + constructor(deviceType: number | string, deviceId: number, lib: FFILibrary) { + const tp = typeof deviceType; + if (tp == "string") { + this.deviceType = DeviceStrToEnum[deviceType]; + } else if (tp == "number") { + this.deviceType = deviceType as number; + } else { + throw new Error("Cannot take type " + tp + " as deviceType"); + } + this.deviceId = deviceId; + this.lib = lib; + } + + /** + * Synchronize the context + */ + sync(): void { + this.lib.checkCall( + (this.lib.exports.TVMSynchronize as ctypes.FTVMSynchronize)( + this.deviceType, + this.deviceId, + 0 + ) + ); + } + + toString(): string { + return ( + DeviceEnumToStr[this.deviceType] + "(" + this.deviceId.toString() + ")" + ); + } +} + +const DLDataTypeCodeToStr: Record = { + 0: "int", + 1: "uint", + 2: "float", + 4: "handle", +}; + +/** + * Runtime data type of NDArray. + */ +export class DLDataType { + /** The type code */ + code: number; + /** Number of bits in the data type. */ + bits: number; + /** Number of vector lanes. */ + lanes: number; + + constructor(code: number, bits: number, lanes: number) { + this.code = code; + this.bits = bits; + this.lanes = lanes; + } + + toString(): string { + const ret = DLDataTypeCodeToStr[this.code] + this.bits.toString(); + if (this.lanes != 1) { + return ret + "x" + this.lanes.toString(); + } else { + return ret; + } + } + + numStorageBytes(): number { + return (this.bits * this.lanes + 7) >> 3; + } +} + +/** + * n-dimnesional array. + */ +export class NDArray implements Disposable { + /** Internal array handle. */ + handle: Pointer; + /** Number of dimensions. */ + ndim: number; + /** Data type of the array. */ + dtype: string; + /** Shape of the array. */ + shape: Array; + /** Context of the array. */ + context: DLContext; + + private byteOffset: number; + private dltensor: Pointer; + private lib: FFILibrary; + private dlDataType: DLDataType; + + constructor(handle: Pointer, lib: FFILibrary) { + this.handle = handle; + this.lib = lib; + + this.dltensor = this.getDLTensorFromArrayHandle(this.handle); + // constant offsets. 
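+    // The offsets below follow the DLTensor layout assumed here (wasm32):
+    //   data: ptr | ctx { device_type: i32, device_id: i32 } | ndim: i32 |
+    //   dtype { code: u8, bits: u8, lanes: u16 } | shape: ptr | strides: ptr |
+    //   byte_offset: i64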
+ const arrayOffsetData = 0; + const arrayOffsetContext = arrayOffsetData + this.lib.sizeofPtr(); + const arrayOffsetDevType = arrayOffsetContext; + const arrayOffsetDevId = arrayOffsetContext + SizeOf.I32; + const arrayOffsetNdim = arrayOffsetContext + SizeOf.DLContext; + const arrayOffsetDtype = arrayOffsetNdim + SizeOf.I32; + const arrayOffsetDtypeCode = arrayOffsetDtype; + const arrayOffsetDtypeBits = arrayOffsetDtype + SizeOf.U8; + const arrayOffsetDtypeLanes = arrayOffsetDtypeBits + SizeOf.U8; + const arrayOffsetShape = arrayOffsetDtype + SizeOf.DLDataType; + const arrayOffsetStrides = arrayOffsetShape + this.lib.sizeofPtr(); + const arrayOffsetByteOffset = arrayOffsetStrides + this.lib.sizeofPtr(); + // ndim + this.ndim = lib.memory.loadI32(this.dltensor + arrayOffsetNdim); + // shape + const cshapePtr = lib.memory.loadPointer(this.dltensor + arrayOffsetShape); + this.shape = []; + for (let i = 0; i < this.ndim; ++i) { + this.shape.push(lib.memory.loadI64(cshapePtr + i * SizeOf.I64)); + } + // dtype + const code = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeCode); + const bits = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeBits); + const lanes = lib.memory.loadU16(this.dltensor + arrayOffsetDtypeLanes); + this.dlDataType = new DLDataType(code, bits, lanes); + this.dtype = this.dlDataType.toString(); + + // ctx + const deviceType = lib.memory.loadI32(this.dltensor + arrayOffsetDevType); + const deviceId = lib.memory.loadI32(this.dltensor + arrayOffsetDevId); + this.context = new DLContext(deviceType, deviceId, lib); + + // byte_offset + this.byteOffset = lib.memory.loadI64(this.dltensor + arrayOffsetByteOffset); + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMArrayFree as ctypes.FTVMArrayFree)(this.handle) + ); + this.handle = 0; + } + } + /** + * Copy data from another NDArray or javascript array. + * The number of elements must match. + * + * @param data The source data array. + * @returns this + */ + copyFrom(data: NDArray | Array): this { + if (data instanceof NDArray) { + this.lib.checkCall( + (this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)( + data.handle, + this.handle, + 0 + ) + ); + return this; + } else { + const size = this.shape.reduce((a, b) => { + return a * b; + }, 1); + if (data.length != size) { + throw new Error( + "data size and shape mismatch data.length" + + data.length + + " vs " + + size + ); + } + let buffer: ArrayBuffer; + if (this.dtype == "float32") { + buffer = Float32Array.from(data).buffer; + } else if (this.dtype == "float64") { + buffer = Float64Array.from(data).buffer; + } else if (this.dtype == "int32") { + buffer = Int32Array.from(data).buffer; + } else if (this.dtype == "int8") { + buffer = Int8Array.from(data).buffer; + } else if (this.dtype == "uint8") { + buffer = Uint8Array.from(data).buffer; + } else { + throw new Error("Unsupported data type " + this.dtype); + } + return this.copyFromRawBytes(new Uint8Array(buffer)); + } + } + /** + * Copy data from raw bytes. + * @param data Uint8Array of bytes. 
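+   * The length must equal the dtype's numStorageBytes() times the number of
+   * elements, as checked below.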
+ * @returns this + */ + copyFromRawBytes(data: Uint8Array): this { + const size = this.shape.reduce((a, b) => { + return a * b; + }, 1); + const nbytes = this.dlDataType.numStorageBytes() * size; + if (nbytes != data.length) { + throw new Error("Expect the data's length equals nbytes=" + nbytes); + } + + const stack = this.lib.getOrAllocCallStack(); + + const tempOffset = stack.allocRawBytes(nbytes); + const tempPtr = stack.ptrFromOffset(tempOffset); + this.lib.memory.storeRawBytes(tempPtr, data); + this.lib.checkCall( + (this.lib.exports.TVMArrayCopyFromBytes as ctypes.FTVMArrayCopyFromBytes)( + this.handle, + tempPtr, + nbytes + ) + ); + + this.lib.recycleCallStack(stack); + return this; + } + /** + * Return a copied Uint8Array of the raw bytes in the NDArray. + * @returns The result array. + */ + toRawBytes(): Uint8Array { + const size = this.shape.reduce((a, b) => { + return a * b; + }, 1); + const nbytes = this.dlDataType.numStorageBytes() * size; + const stack = this.lib.getOrAllocCallStack(); + + const tempOffset = stack.allocRawBytes(nbytes); + const tempPtr = stack.ptrFromOffset(tempOffset); + this.lib.checkCall( + (this.lib.exports.TVMArrayCopyToBytes as ctypes.FTVMArrayCopyToBytes)( + this.handle, + tempPtr, + nbytes + ) + ); + const ret = this.lib.memory.loadRawBytes(tempPtr, nbytes); + + this.lib.recycleCallStack(stack); + return ret; + } + + /** + * Return a TypedArray copy of the NDArray, the specific type depends on + * the dtype of the NDArray. + * @returns The result array. + */ + toArray(): Float32Array | Float64Array | Int32Array | Int8Array | Uint8Array { + const stype = this.dtype; + if (stype == "float32") { + return new Float32Array(this.toRawBytes().buffer); + } else if (stype == "float64") { + return new Float64Array(this.toRawBytes().buffer); + } else if (stype == "int32") { + return new Int32Array(this.toRawBytes().buffer); + } else if (stype == "int8") { + return new Int8Array(this.toRawBytes().buffer); + } else if (stype == "uint8") { + return new Uint8Array(this.toRawBytes().buffer); + } else { + throw new Error("Unsupported data type " + this.dtype); + } + } + + private getDLTensorFromArrayHandle(handle: Pointer): Pointer { + // Note: this depends on the NDArray C ABI. + // keep this function in case of ABI change. + return handle; + } +} + +/** + * Runtime Module. + */ +export class Module implements Disposable { + handle: Pointer; + private lib: FFILibrary; + private makePackedFunc: (ptr: Pointer) => PackedFunc; + + constructor( + handle: Pointer, + lib: FFILibrary, + makePackedFunc: (ptr: Pointer) => PackedFunc + ) { + this.handle = handle; + this.lib = lib; + this.makePackedFunc = makePackedFunc; + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMModFree as ctypes.FTVMModFree)(this.handle) + ); + this.handle = 0; + } + } + + /** + * Get a function in the module. + * @param name The name of the function. + * @returns The result function. 
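+   *
+   * A usage sketch (the function name and arguments are assumed):
+   *   const faddOne = mod.getFunction("add_one");
+   *   faddOne(inputArr, outputArr); // NDArray arguments are packed automatically
+   *   faddOne.dispose();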
+ */ + getFunction(name: string): PackedFunc { + const stack = this.lib.getOrAllocCallStack(); + const nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.lib.exports.TVMModGetFunction as ctypes.FTVMModGetFunction)( + this.handle, + stack.ptrFromOffset(nameOffset), + 1, + outPtr + ) + ); + const handle = this.lib.memory.loadPointer(outPtr); + this.lib.recycleCallStack(stack); + if (handle == 0) { + throw Error("Cannot find function " + name); + } + const ret = this.makePackedFunc(handle); + return ret; + } + + /** + * Import another module into the current runtime module. + * @param mod The module to be imported. + */ + importModule(mod: Module): void { + this.lib.checkCall( + (this.lib.exports.TVMModImport as ctypes.FTVMModImport)( + this.handle, + mod.handle + ) + ); + } +} + +/** + * TVM runtime instance. + */ +export class Instance implements Disposable { + memory: Memory; + exports: Record; + private lib: FFILibrary; + private env: Environment; + + /** + * Internal function(registered by the runtime) + */ + private wasmCreateLibraryModule?: PackedFunc & + ((getFunc: PackedFunc, getGlobal: PackedFunc) => PackedFunc); + + /** + * Constructor + * + * importObject can also be a {@link LibraryProvider} object, + * a WASI object, or an object containing wasmLibraryProvider field. + * + * @param wasmModule The input module or instance. + * @param importObject The imports to initialize the wasmInstance if it is not provided. + * @param wasmInstance Additional wasm instance argument for deferred construction. + * @param env Directly specified environment module. + * + * @see Please use the async version {@link instantiate} when targeting browsers. + */ + constructor( + wasmModule: WebAssembly.Module, + importObject: Record = {}, + wasmInstance?: WebAssembly.Instance, + env?: Environment + ) { + if (wasmInstance instanceof WebAssembly.Instance) { + assert( + env instanceof Environment, + "env must be provided when passing in instance" + ); + } else { + assert(env === undefined); + env = new Environment(importObject); + wasmInstance = new WebAssembly.Instance(wasmModule, env.imports); + } + + env.start(wasmInstance); + this.env = env; + this.lib = new FFILibrary(wasmInstance, env.imports); + this.memory = this.lib.memory; + this.exports = this.lib.exports; + this.registerEnvGlobalPackedFuncs(); + } + + dispose(): void { + this.lib.dispose(); + } + /** + * Get system-wide library module in the wasm. + * System lib is a global module that contains self register functions in startup. + * @returns The system library module. + */ + systemLib(): Module { + const getSysLib = this.getGlobalFunc("runtime.SystemLib"); + const mod = getSysLib() as Module; + getSysLib.dispose(); + return mod; + } + /** + * List all the global function names registered in the runtime. + * @returns The name list. 
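+   *
+   * For example (sketch):
+   *   inst.listGlobalFuncNames().forEach((name) => console.log(name));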
+ */ + listGlobalFuncNames(): Array { + const stack = this.lib.getOrAllocCallStack(); + + const outSizeOffset = stack.allocPtrArray(2); + + const outSizePtr = stack.ptrFromOffset(outSizeOffset); + const outArrayPtr = stack.ptrFromOffset( + outSizeOffset + this.lib.sizeofPtr() + ); + + this.lib.checkCall( + (this.exports.TVMFuncListGlobalNames as ctypes.FTVMFuncListGlobalNames)( + outSizePtr, + outArrayPtr + ) + ); + + const size = this.memory.loadI32(outSizePtr); + const array = this.memory.loadPointer(outArrayPtr); + const names: Array = []; + + for (let i = 0; i < size; ++i) { + names.push( + this.memory.loadCString( + this.memory.loadPointer(array + this.lib.sizeofPtr() * i) + ) + ); + } + + this.lib.recycleCallStack(stack); + return names; + } + + /** + * Register function to be global function in tvm runtime. + * @param name The name of the function. + * @param f function to be registered. + * @param override Whether overwrite function in existing registry. + */ + registerFunc( + name: string, + func: PackedFunc | Function, + override = false + ): void { + const packedFunc = this.toPackedFunc(func); + const ioverride = override ? 1 : 0; + + const stack = this.lib.getOrAllocCallStack(); + const nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + stack.commitToWasmMemory(); + + this.lib.checkCall( + (this.lib.exports.TVMFuncRegisterGlobal as ctypes.FTVMFuncRegisterGlobal)( + stack.ptrFromOffset(nameOffset), + packedFunc._tvmPackedCell.handle, + ioverride + ) + ); + } + + /** + * Get global PackedFunc from the runtime. + * @param name The name of the function. + * @returns The result function. + */ + getGlobalFunc(name: string): PackedFunc { + const stack = this.lib.getOrAllocCallStack(); + const nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.exports.TVMFuncGetGlobal as ctypes.FTVMFuncGetGlobal)( + stack.ptrFromOffset(nameOffset), + outPtr + ) + ); + const handle = this.memory.loadPointer(outPtr); + this.lib.recycleCallStack(stack); + if (handle == 0) { + throw Error("Cannot find global function " + name); + } + const ret = this.makePackedFunc(handle); + return ret; + } + + /** + * Check if func is PackedFunc. + * + * @param func The input. + * @returns The check result. + */ + isPackedFunc(func: unknown): boolean { + // eslint-disable-next-line no-prototype-builtins + return typeof func == "function" && func.hasOwnProperty("_tvmPackedCell"); + } + + /** + * Convert func to PackedFunc + * + * @param func Input function. + * @returns The converted function. + */ + toPackedFunc(func: Function): PackedFunc { + if (this.isPackedFunc(func)) return func as PackedFunc; + return this.createPackedFuncFromCFunc(this.wrapJSFuncAsPackedCFunc(func)); + } + + /** + * Convert dtype to {@link DLDataType} + * + * @param dtype The input dtype string or DLDataType. + * @returns The converted result. 
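+   *
+   * For example (sketch): toDLDataType("float32x4") yields
+   * { code: TypeCode.Float, bits: 32, lanes: 4 }.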
+ */ + toDLDataType(dtype: string | DLDataType): DLDataType { + if (dtype instanceof DLDataType) return dtype; + if (typeof dtype == "string") { + let pattern = dtype; + let code, + bits = 32, + lanes = 1; + if (pattern.substring(0, 5) == "float") { + pattern = pattern.substring(5, pattern.length); + code = TypeCode.Float; + } else if (pattern.substring(0, 3) == "int") { + pattern = pattern.substring(3, pattern.length); + code = TypeCode.Int; + } else if (pattern.substring(0, 4) == "uint") { + pattern = pattern.substring(4, pattern.length); + code = TypeCode.UInt; + } else if (pattern.substring(0, 6) == "handle") { + pattern = pattern.substring(5, pattern.length); + code = TypeCode.TVMOpaqueHandle; + bits = 64; + } else { + throw new Error("Unknown dtype " + dtype); + } + + const arr = pattern.split("x"); + if (arr.length >= 1) { + const parsed = parseInt(arr[0]); + if (parsed + "" == arr[0]) { + bits = parsed; + } + } + if (arr.length >= 2) { + lanes = parseInt(arr[1]); + } + return new DLDataType(code, bits, lanes); + } else { + throw new Error("Unknown dtype " + dtype); + } + } + + /** + * Create a new {@link Scalar} that can be passed to a PackedFunc. + * @param value The number value. + * @param dtype The dtype string. + * @returns The created scalar. + */ + scalar(value: number, dtype: string): Scalar { + return new Scalar(value, dtype); + } + + /** + * Create a new {@link DLContext} + * @param deviceType The device type. + * @param deviceId The device index. + * @returns The created context. + */ + context(deviceType: number | string, deviceId: number): DLContext { + return new DLContext(deviceType, deviceId, this.lib); + } + + /** + * Create an empty {@link NDArray} with given shape and dtype. + * + * @param shape The shape of the array. + * @param dtype The data type of the array. + * @param ctx The context of the ndarray. + * @returns The created ndarray. + */ + empty( + shape: Array | number, + dtype: string | DLDataType = "float32", + ctx: DLContext = this.context("cpu", 0) + ): NDArray { + dtype = this.toDLDataType(dtype); + shape = typeof shape == "number" ? [shape] : shape; + + const stack = this.lib.getOrAllocCallStack(); + const shapeOffset = stack.allocRawBytes(shape.length * SizeOf.I64); + for (let i = 0; i < shape.length; ++i) { + stack.storeI64(shapeOffset + i * SizeOf.I64, shape[i]); + } + + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.exports.TVMArrayAlloc as ctypes.FTVMArrayAlloc)( + stack.ptrFromOffset(shapeOffset), + shape.length, + dtype.code, + dtype.bits, + dtype.lanes, + ctx.deviceType, + ctx.deviceId, + outPtr + ) + ); + const ret = new NDArray(this.memory.loadPointer(outPtr), this.lib); + this.lib.recycleCallStack(stack); + return ret; + } + + /** Register global packed functions needed by the backend to the env. */ + private registerEnvGlobalPackedFuncs(): void { + // Register the timer function to enable the time_evaluator. 
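+    // "wasm.GetTimer" wraps a PackedFunc into a closure that runs it with an
+    // int32 repeat count and returns the elapsed milliseconds from
+    // performance.now(); it is presumably consumed by the wasm-side
+    // "wasm.RPCTimeEvaluator" registered as "runtime.RPCTimeEvaluator" below.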
+ let perf: Performance; + if (typeof performance == "undefined") { + // eslint-disable-next-line @typescript-eslint/no-var-requires + const performanceNode = require('perf_hooks'); + perf = performanceNode.performance as Performance; + } else { + perf = performance as Performance; + } + + const getTimer = (func: PackedFunc) => { + return (n: number): number => { + const nscalar = this.scalar(n, "int32"); + const tstart: number = perf.now(); + func(nscalar); + const tend: number = perf.now(); + return tend - tstart; + } + }; + this.registerFunc("wasm.GetTimer", getTimer); + const rpcWrapTimeEvaluator = this.getGlobalFunc("wasm.RPCTimeEvaluator"); + this.registerFunc("runtime.RPCTimeEvaluator", rpcWrapTimeEvaluator, true); + rpcWrapTimeEvaluator.dispose(); + } + + private createPackedFuncFromCFunc( + func: ctypes.FTVMWasmPackedCFunc + ): PackedFunc { + let findex = this.env.packedCFuncTable.length; + if (this.env.packedCFuncTableFreeId.length != 0) { + findex = this.env.packedCFuncTableFreeId.pop() as number; + } else { + this.env.packedCFuncTable.push(undefined); + } + this.env.packedCFuncTable[findex] = func; + + const stack = this.lib.getOrAllocCallStack(); + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + this.lib.checkCall( + (this.exports + .TVMWasmFuncCreateFromCFunc as ctypes.FTVMWasmFuncCreateFromCFunc)( + findex, + outPtr + ) + ); + const ret = this.makePackedFunc(this.memory.loadPointer(outPtr)); + this.lib.recycleCallStack(stack); + return ret; + } + + /** + * Set packed function arguments into the location indicated by argsValue and argsCode. + * Allocate new temporary space from the stack if necessary. + * + * @parma stack The call stack + * @param args The input arguments. + * @param argsValue The offset of argsValue. + * @param argsCode The offset of argsCode. 
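+   *
+   * Supported argument kinds: NDArray, Scalar, plain number (stored as
+   * float64), string, Uint8Array, PackedFunc/Function, Module, and
+   * null/undefined.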
+ */ + setPackedArguments( + stack: CachedCallStack, + args: Array, + argsValue: PtrOffset, + argsCode: PtrOffset + ): void { + for (let i = 0; i < args.length; ++i) { + let val = args[i]; + const tp = typeof val; + const valueOffset = argsValue + i * SizeOf.TVMValue; + const codeOffset = argsCode + i * SizeOf.I32; + if (val instanceof NDArray) { + stack.storePtr(valueOffset, val.handle); + stack.storeI32(codeOffset, TypeCode.TVMNDArrayHandle); + } else if (val instanceof Scalar) { + if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) { + stack.storeI64(valueOffset, val.value); + stack.storeI32(codeOffset, TypeCode.Int); + } else if (val.dtype.startsWith("float")) { + stack.storeF64(valueOffset, val.value); + stack.storeI32(codeOffset, TypeCode.Float); + } else { + assert(val.dtype == "handle", "Expect handle"); + stack.storePtr(valueOffset, val.value); + stack.storeI32(codeOffset, TypeCode.TVMOpaqueHandle); + } + } else if (tp == "number") { + stack.storeF64(valueOffset, val); + stack.storeI32(codeOffset, TypeCode.Float); + // eslint-disable-next-line no-prototype-builtins + } else if (tp == "function" && val.hasOwnProperty("_tvmPackedCell")) { + stack.storePtr(valueOffset, val._tvmPackedCell.handle); + stack.storeI32(codeOffset, TypeCode.TVMPackedFuncHandle); + } else if (val === null || val == undefined) { + stack.storePtr(valueOffset, 0); + stack.storeI32(codeOffset, TypeCode.Null); + } else if (tp == "string") { + stack.allocThenSetArgString(valueOffset, val); + stack.storeI32(codeOffset, TypeCode.TVMStr); + } else if (val instanceof Uint8Array) { + stack.allocThenSetArgBytes(valueOffset, val); + stack.storeI32(codeOffset, TypeCode.TVMBytes); + } else if (val instanceof Function) { + val = this.toPackedFunc(val); + stack.tempArgs.push(val); + stack.storePtr(valueOffset, val._tvmPackedCell.handle); + stack.storeI32(codeOffset, TypeCode.TVMPackedFuncHandle); + } else if (val instanceof Module) { + stack.storePtr(valueOffset, val.handle); + stack.storeI32(codeOffset, TypeCode.TVMModuleHandle); + } else { + throw new Error("Unsupported argument type " + tp); + } + } + } + + private wrapJSFuncAsPackedCFunc(func: Function): ctypes.FTVMWasmPackedCFunc { + const lib = this.lib; + return ( + argValues: Pointer, + argCodes: Pointer, + nargs: number, + ret: Pointer, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + _handle: Pointer + ): number => { + const jsArgs = []; + for (let i = 0; i < nargs; ++i) { + const valuePtr = argValues + i * SizeOf.TVMValue; + const codePtr = argCodes + i * SizeOf.I32; + let tcode = lib.memory.loadI32(codePtr); + + if ( + tcode == TypeCode.TVMObjectHandle || + tcode == TypeCode.TVMObjectRValueRefArg || + tcode == TypeCode.TVMPackedFuncHandle || + tcode == TypeCode.TVMModuleHandle + ) { + lib.checkCall( + (lib.exports.TVMCbArgToReturn as ctypes.FTVMCbArgToReturn)( + valuePtr, + codePtr + ) + ); + } + tcode = lib.memory.loadI32(codePtr); + jsArgs.push(this.retValueToJS(valuePtr, tcode)); + } + + const rv = func(...jsArgs); + + if (rv !== undefined && rv !== null) { + const stack = lib.getOrAllocCallStack(); + const valueOffset = stack.allocRawBytes(SizeOf.TVMValue); + const codeOffset = stack.allocRawBytes(SizeOf.I32); + this.setPackedArguments(stack, [rv], valueOffset, codeOffset); + const valuePtr = stack.ptrFromOffset(valueOffset); + const codePtr = stack.ptrFromOffset(codeOffset); + stack.commitToWasmMemory(); + lib.checkCall( + (lib.exports.TVMCFuncSetReturn as ctypes.FTVMCFuncSetReturn)( + ret, + valuePtr, + codePtr, + 1 + ) + ); + 
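+        // TVMCFuncSetReturn takes its own copy of the packed return value, so
+        // the scratch stack can be recycled right away.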
lib.recycleCallStack(stack); + } + return 0; + }; + } + + private makePackedFunc(handle: Pointer): PackedFunc { + const cell = new PackedFuncCell(handle, this.lib); + + const packedFunc = (...args: any): any => { + const stack = this.lib.getOrAllocCallStack(); + + const valueOffset = stack.allocRawBytes(SizeOf.TVMValue * args.length); + const tcodeOffset = stack.allocRawBytes(SizeOf.I32 * args.length); + + this.setPackedArguments(stack, args, valueOffset, tcodeOffset); + + const rvalueOffset = stack.allocRawBytes(SizeOf.TVMValue); + const rcodeOffset = stack.allocRawBytes(SizeOf.I32); + const rvaluePtr = stack.ptrFromOffset(rvalueOffset); + const rcodePtr = stack.ptrFromOffset(rcodeOffset); + + // commit to wasm memory, till rvalueOffset (the return value don't need to be committed) + stack.commitToWasmMemory(rvalueOffset); + + this.lib.checkCall( + (this.exports.TVMFuncCall as ctypes.FTVMFuncCall)( + handle, + stack.ptrFromOffset(valueOffset), + stack.ptrFromOffset(tcodeOffset), + args.length, + rvaluePtr, + rcodePtr + ) + ); + + const ret = this.retValueToJS(rvaluePtr, this.memory.loadI32(rcodePtr)); + this.lib.recycleCallStack(stack); + return ret; + }; + // Attach attributes to the function type. + // This is because javascript do not allow us to overload call. + const ret: any = packedFunc; + ret.dispose = (): void => { + cell.dispose(); + }; + ret._tvmPackedCell = cell; + return ret as PackedFunc; + } + + private retValueToJS(rvaluePtr: Pointer, tcode: number): any { + switch (tcode) { + case TypeCode.Int: + case TypeCode.UInt: + return this.memory.loadI64(rvaluePtr); + case TypeCode.Float: + return this.memory.loadF64(rvaluePtr); + case TypeCode.TVMNDArrayHandle: { + return new NDArray(this.memory.loadPointer(rvaluePtr), this.lib); + } + case TypeCode.TVMPackedFuncHandle: { + return this.makePackedFunc(this.memory.loadPointer(rvaluePtr)); + } + case TypeCode.TVMModuleHandle: { + return new Module( + this.memory.loadPointer(rvaluePtr), + this.lib, + (ptr: Pointer) => { + return this.makePackedFunc(ptr); + } + ); + } + case TypeCode.Null: + return undefined; + case TypeCode.TVMStr: { + return this.memory.loadCString(this.memory.loadPointer(rvaluePtr)); + } + case TypeCode.TVMBytes: { + return this.memory.loadTVMBytes(this.memory.loadPointer(rvaluePtr)); + } + default: + throw new Error("Unsupported return type code=" + tcode); + } + } +} + +/** + * Asynchrously instantiate a new {@link Instance}. + * + * importObject can also be a {@link LibraryProvider} object, + * a WASI object, or an object containing wasmLibraryProvider field. + * We can take benefit of syslib implementations from the Emscripten + * by passing its generated js Module as the imports. + */ +export function instantiate( + bufferSource: ArrayBuffer, + importObject: Record = {} +): Promise { + const env = new Environment(importObject); + + return WebAssembly.instantiate(bufferSource, env.imports).then( + (result: WebAssembly.WebAssemblyInstantiatedSource): Instance => { + return new Instance(result.module, {}, result.instance, env); + } + ); +} diff --git a/web/src/support.ts b/web/src/support.ts new file mode 100644 index 000000000000..7a2667a2299f --- /dev/null +++ b/web/src/support.ts @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Convert string to Uint8array. + * @param str The string. + * @returns The corresponding Uint8Array. + */ +export function StringToUint8Array(str: string): Uint8Array { + const arr = new Uint8Array(str.length + 1); + for (let i = 0; i < str.length; ++i) { + arr[i] = str.charCodeAt(i); + } + arr[str.length] = 0; + return arr; +} + +/** + * Convert Uint8array to string. + * @param array The array. + * @returns The corresponding string. + */ +export function Uint8ArrayToString(arr: Uint8Array): string { + const ret = []; + for (const ch of arr) { + ret.push(String.fromCharCode(ch)); + } + return ret.join(""); +} + +/** + * Internal assert helper + * @param condition condition The condition to fail. + * @param msg msg The message. + */ +export function assert(condition: boolean, msg?: string): asserts condition { + if (!condition) { + throw new Error("AssertError:" + (msg || "")); + } +} + +/** + * Get the path to the wasm library in nodejs. + * @return The wasm path. + */ +export function wasmPath(): string { + return __dirname + "/wasm"; +} \ No newline at end of file diff --git a/web/src/types.ts b/web/src/types.ts new file mode 100644 index 000000000000..621375a23f5f --- /dev/null +++ b/web/src/types.ts @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** Common type definitions. */ + +/** + * Library interface provider that can provide + * syslibs(e.g. libs provided by WASI and beyond) for the Wasm runtime. + * + * It can be viewed as a generalization of imports used in WebAssembly instance creation. + * + * The {@link LibraryProvider.start} callback will be called + * to allow the library provider to initialize related resources during startup time. + * + * We can use Emscripten generated js Module as a { wasmLibraryProvider: LibraryProvider }. + */ +export interface LibraryProvider { + /** The imports that can be passed to WebAssembly instance creation. */ + imports: Record; + /** + * Callback function to notify the provider the created instance. + * @param inst The created instance. 
+ */ + start: (inst: WebAssembly.Instance) => void; +} + +/** + * Disposable classes that contains resources (WasmMemory, GPU buffer) + * which needs to be explicitly disposed. + */ +export interface Disposable { + /** + * Dispose the internal resource + * This function can be called multiple times, + * only the first call will take effect. + */ + dispose: () => void; +} diff --git a/tests/web/test_module_load.js b/web/tests/node/test_module_load.js similarity index 64% rename from tests/web/test_module_load.js rename to web/tests/node/test_module_load.js index f4c809536bb5..45e84fd404a9 100644 --- a/tests/web/test_module_load.js +++ b/web/tests/node/test_module_load.js @@ -19,14 +19,18 @@ // Load Emscripten Module, need to change path to root/lib const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/test_module.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); +const fs = require("fs"); +const assert = require("assert"); +const tvmjs = require("../../dist"); + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "test_addone.wasm")); + +const tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI()); // Load system library -var sysLib = tvm.systemLib(); +const sysLib = tvm.systemLib(); function randomArray(length, max) { return Array.apply(null, Array(length)).map(function() { @@ -36,23 +40,22 @@ function randomArray(length, max) { function testAddOne() { // grab pre-loaded function - var faddOne = sysLib.getFunction("add_one"); - var assert = require('assert'); - tvm.assert(tvm.isPackedFunc(faddOne)); - var n = 124; - var A = tvm.empty(n).copyFrom(randomArray(n, 1)); - var B = tvm.empty(n); + const faddOne = sysLib.getFunction("add_one"); + assert(tvm.isPackedFunc(faddOne)); + const n = 124; + const A = tvm.empty(n).copyFrom(randomArray(n, 1)); + const B = tvm.empty(n); // call the function. faddOne(A, B); - AA = A.asArray(); // retrieve values in js array - BB = B.asArray(); // retrieve values in js array + const AA = A.toArray(); // retrieve values in js array + const BB = B.toArray(); // retrieve values in js array // verify for (var i = 0; i < BB.length; ++i) { assert(Math.abs(BB[i] - (AA[i] + 1)) < 1e-5); } - faddOne.release(); + faddOne.dispose(); } testAddOne(); -sysLib.release(); +sysLib.dispose(); console.log("Finish verifying test_module_load"); diff --git a/tests/web/test_basic.js b/web/tests/node/test_ndarray.js similarity index 55% rename from tests/web/test_basic.js rename to web/tests/node/test_ndarray.js index 6852319dbc12..ba43621ecb05 100644 --- a/tests/web/test_basic.js +++ b/web/tests/node/test_ndarray.js @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -16,31 +16,34 @@ * specific language governing permissions and limitations * under the License. 
*/ - -// Load Emscripten Module, need to change path to root/build const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/libtvm_web_runtime.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); +const fs = require("fs"); +const assert = require("assert"); +const tvmjs = require("../../dist/tvmjs.bundle") + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); + +let tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI()); // Basic fields. -tvm.assert(tvm.float32 == "float32"); -tvm.assert(tvm.listGlobalFuncNames() !== "undefined"); -var sysLib = tvm.systemLib(); -tvm.assert(typeof sysLib.getFunction !== "undefined"); -sysLib.release(); +assert(tvm.listGlobalFuncNames() !== undefined); // Test ndarray -function testArrayCopy(dtype, arr) { - var data = [1, 2, 3, 4, 5, 6]; - var a = tvm.empty([2, 3], dtype); - a.copyFrom(data); - var ret = a.asArray(); - tvm.assert(ret instanceof arr); - tvm.assert(ret.toString() == arr.from(data)); - a.release(); +function testArrayCopy(dtype, arrayType) { + let data = [1, 2, 3, 4, 5, 6]; + let a = tvm.empty([2, 3], dtype).copyFrom(data); + + assert(a.context.toString() == "cpu(0)"); + assert(a.shape[0] == 2 && a.shape[1] == 3); + + let ret = a.toArray(); + assert(ret instanceof arrayType); + assert(ret.toString() == arrayType.from(data).toString()); + // test multiple dispose. + a.dispose(); + a.dispose(); } testArrayCopy("float32", Float32Array); @@ -48,8 +51,3 @@ testArrayCopy("int", Int32Array); testArrayCopy("int8", Int8Array); testArrayCopy("uint8", Uint8Array); testArrayCopy("float64", Float64Array); - -// Function registration -tvm.registerFunc("xyz", function(x, y) { - return x + y; -}); diff --git a/web/tests/node/test_packed_func.js b/web/tests/node/test_packed_func.js new file mode 100644 index 000000000000..c961f9576e3f --- /dev/null +++ b/web/tests/node/test_packed_func.js @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +const path = require("path"); +const fs = require("fs"); +const assert = require('assert'); +const tvmjs = require("../../dist") + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); + +let tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI()); + +function testGetGlobal() { + let flist = tvm.listGlobalFuncNames(); + let faddOne = tvm.getGlobalFunc("testing.add_one"); + let fecho = tvm.getGlobalFunc("testing.echo"); + + assert(faddOne(tvm.scalar(1, "int")) == 2); + // check function argument with different types. + assert(fecho(1123) == 1123); + assert(fecho("xyz") == "xyz"); + + let bytes = new Uint8Array([1, 2, 3]); + let rbytes = fecho(bytes); + assert(rbytes.length == bytes.length); + + for (let i = 0; i < bytes.length; ++i) { + assert(rbytes[i] == bytes[i]); + } + + assert(fecho(undefined) == undefined); + + let arr = tvm.empty([2, 2]).copyFrom([1, 2, 3, 4]); + let arr2 = fecho(arr); + assert(arr.handle == arr2.handle); + assert(arr2.toArray().toString() == arr.toArray().toString()); + + let mod = tvm.systemLib(); + let ret = fecho(mod); + assert(ret.handle == mod.handle); + assert(flist.length != 0); + + mod.dispose(); + ret.dispose(); + arr.dispose(); + arr2.dispose(); + fecho.dispose(); + faddOne.dispose(); +} + +function testReturnFunc() { + function addy(y) { + function add(x, z) { + return x + y + z; + } + return add; + } + + let fecho = tvm.getGlobalFunc("testing.echo"); + let myf = tvm.toPackedFunc(addy); + assert(tvm.isPackedFunc(myf)); + let myf2 = tvm.toPackedFunc(myf); + assert(myf2._tvmPackedCell.handle === myf._tvmPackedCell.handle); + let f = myf(10); + + assert(tvm.isPackedFunc(f)); + assert(f(11, 0) == 21); + assert(f("x", 1) == "x101"); + assert(f("x", "yz") == "x10yz"); + + fecho.dispose(); + myf.dispose(); + myf2.dispose(); + // test multiple dispose. + f.dispose(); + f.dispose(); +} + +function testRegisterGlobal() { + tvm.registerFunc("xyz", function (x, y) { + return x + y; + }); + + let f = tvm.getGlobalFunc("xyz"); + assert(f(1, 2) == 3); + f.dispose(); + + let syslib = tvm.systemLib(); + syslib.dispose(); +} + +function testTimer() { + const fecho = tvm.getGlobalFunc("testing.echo"); + const fgetTimer = tvm.getGlobalFunc("wasm.GetTimer"); + + let finvoke = (n) => { + let x = "xyz"; + for (let i = 0; i < n; ++i) { + x = fecho(x); + } + }; + const number = 10000; + const invokeTimer = fgetTimer(finvoke); + console.log("Time cost:", number / invokeTimer(number) * 1000, " ops/sec"); + fecho.dispose(); + invokeTimer.dispose(); + fgetTimer.dispose(); +} + +testGetGlobal(); +testRegisterGlobal(); +testReturnFunc(); +testTimer(); diff --git a/tests/web/prepare_test_libs.py b/web/tests/python/prepare_test_libs.py similarity index 69% rename from tests/web/prepare_test_libs.py rename to web/tests/python/prepare_test_libs.py index a0e2c13eab82..ec4eb5be1536 100644 --- a/tests/web/prepare_test_libs.py +++ b/web/tests/python/prepare_test_libs.py @@ -14,27 +14,28 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# Prepare test library for js. +# Prepare test library for standalone wasm runtime test. 
+ import tvm from tvm import te -from tvm.contrib import emscripten +from tvm.contrib import emcc import os + def prepare_test_libs(base_path): - target = "llvm -target=asmjs-unknown-emscripten -system-lib" + target = "llvm -target=wasm32-unknown-unknown-wasm -system-lib" if not tvm.runtime.enabled(target): raise RuntimeError("Target %s is not enbaled" % target) n = te.var("n") A = te.placeholder((n,), name='A') B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') s = te.create_schedule(B.op) - fadd1 = tvm.build(s, [A, B], target, name="add_one") - obj_path = os.path.join(base_path, "test_add_one.bc") - fadd1.save(obj_path) - emscripten.create_js(os.path.join(base_path, "test_module.js"), obj_path, - options=["-s", "WASM=0", "-s", "USE_GLFW=3", "-s", - "USE_WEBGL2=1", "-lglfw"]) + fadd = tvm.build(s, [A, B], target, name="add_one") + + wasm_path = os.path.join(base_path, "test_addone.wasm") + fadd.export_library(wasm_path, emcc.create_tvmjs_wasm) + if __name__ == "__main__": curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - prepare_test_libs(os.path.join(curr_path, "../../build")) + prepare_test_libs(os.path.join(curr_path, "../../dist/wasm")) diff --git a/tests/web/websock_rpc_test.py b/web/tests/python/websock_rpc_test.py similarity index 55% rename from tests/web/websock_rpc_test.py rename to web/tests/python/websock_rpc_test.py index 8be8ce04cb75..7fa0c6bdfb57 100644 --- a/tests/web/websock_rpc_test.py +++ b/web/tests/python/websock_rpc_test.py @@ -22,45 +22,61 @@ import tvm from tvm import te -import os from tvm import rpc -from tvm.contrib import util, emscripten +from tvm.contrib import util, emcc import numpy as np proxy_host = "localhost" proxy_port = 9090 -def test_rpc_array(): +def test_rpc(): if not tvm.runtime.enabled("rpc"): return - # graph - n = tvm.runtime.convert(1024) + # generate the wasm library + target = "llvm -target=wasm32-unknown-unknown-wasm -system-lib" + if not tvm.runtime.enabled(target): + raise RuntimeError("Target %s is not enbaled" % target) + n = te.var("n") A = te.placeholder((n,), name='A') B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') s = te.create_schedule(B.op) - remote = rpc.connect(proxy_host, proxy_port, key="js") - target = "llvm -target=asmjs-unknown-emscripten -system-lib" - def check_remote(): - if not tvm.runtime.enabled(target): - print("Skip because %s is not enabled" % target) - return - temp = util.tempdir() + + fadd = tvm.build(s, [A, B], target, name="addone") + temp = util.tempdir() + + wasm_path = temp.relpath("addone.wasm") + fadd.export_library(wasm_path, emcc.create_tvmjs_wasm) + + wasm_binary = open(wasm_path, "rb").read() + + remote = rpc.connect(proxy_host, proxy_port, key="wasm", + session_constructor_args=["rpc.WasmSession", wasm_binary]) + + def check(remote): + # basic function checks. + fecho = remote.get_function("testing.echo") + assert(fecho(1, 2, 3) == 1) + assert(fecho(100, 2, 3) == 100) + assert(fecho("xyz") == "xyz") + assert(bytes(fecho(bytearray(b"123"))) == b"123") + + # run the generated library. 
+ f1 = remote.system_lib() ctx = remote.cpu(0) - f = tvm.build(s, [A, B], target, name="myadd") - path_obj = temp.relpath("dev_lib.bc") - path_dso = temp.relpath("dev_lib.js") - f.save(path_obj) - emscripten.create_js(path_dso, path_obj, side_module=True) - # Upload to suffix as dso so it can be loaded remotely - remote.upload(path_dso, "dev_lib.dso") - data = remote.download("dev_lib.dso") - f1 = remote.load_module("dev_lib.dso") a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) - time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10) + # invoke the function + addone = f1.get_function("addone") + addone(a, b) + + # time evaluator + time_f = f1.time_evaluator("addone", ctx, number=10) + time_f(a, b) cost = time_f(a, b).mean print('%g secs/op' % cost) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) - check_remote() -test_rpc_array() + check(remote) + + +test_rpc() diff --git a/web/tsconfig.json b/web/tsconfig.json new file mode 100644 index 000000000000..3c20b3d20692 --- /dev/null +++ b/web/tsconfig.json @@ -0,0 +1,13 @@ +{ + "compilerOptions": { + "module": "commonjs", + "target": "es6", + "outDir": "dist", + "rootDir": "src", + "declaration": true, + "sourceMap": true, + "strict": true, + }, + "include": ["src"], + "exclude": ["node_modules"] +} diff --git a/web/tvm_runtime.js b/web/tvm_runtime.js deleted file mode 100644 index b62b298d969e..000000000000 --- a/web/tvm_runtime.js +++ /dev/null @@ -1,1274 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/** - * TVM Javascript web runtime library. - * - * @projectname tvm - * @version 0.7.dev1 - */ -/* eslint no-unused-vars: "off" */ -/* eslint no-unexpected-multiline: "off" */ -/* eslint indent: "off" */ -/* eslint no-console: "off" */ -/** - * TVM Runtime namespace. - * Provide tvm_runtime.create to create a {@link tvm.TVMRuntime}. - * - * @namespace tvm_runtime - */ -var tvm_runtime = tvm_runtime || {}; - -/** - * TVM root namespace. - * The classes inside this namespace need to be constructed by factory functions. - * Use {@link tvm_runtime}.create to get started. - * - * @namespace tvm - */ -(function() { - /** - * TVMRuntime object for interacting with TVM runtime. 
- * This object can be constructed using {@link tvm_runtime}.create - * - * @class - * @memberof tvm - */ - function TVMRuntime() { - "use strict"; - var runtime_ref = this; - // Utility function to throw error - function throwError(message) { - if (typeof runtime_ref.logger !== "undefined") { - runtime_ref.logger(message); - } - if (typeof Error !== "undefined") { - throw new Error(message); - } - throw message; - } - var Module = this.Module; - var Runtime = this.Runtime; - if (typeof Module === "undefined") { - throwError("Emscripten Module is not available"); - } - // constants - var SIZEOF_POINTER = 4; - var SIZEOF_SIZE_T = 4; - var SIZEOF_FLOAT = 4; - var SIZEOF_INT = 4; - var SIZEOF_INT8 = 1; - var SIZEOF_INT64 = 8; - var SIZEOF_DOUBLE = 8; - var SIZEOF_TYPE = 4; - var SIZEOF_CTX = SIZEOF_INT + SIZEOF_INT; - var SIZEOF_TVMVALUE = SIZEOF_DOUBLE; - var ARRAY_OFFSET_DATA = 0; - var ARRAY_OFFSET_CTX = ARRAY_OFFSET_DATA + SIZEOF_POINTER; - var ARRAY_OFFSET_DEV_TYPE = ARRAY_OFFSET_CTX; - var ARRAY_OFFSET_DEV_ID = ARRAY_OFFSET_CTX + SIZEOF_INT; - var ARRAY_OFFSET_NDIM = ARRAY_OFFSET_CTX + SIZEOF_CTX; - var ARRAY_OFFSET_DTYPE = ARRAY_OFFSET_NDIM + SIZEOF_INT; - var ARRAY_OFFSET_DTYPE_CODE = ARRAY_OFFSET_DTYPE; - var ARRAY_OFFSET_DTYPE_BITS = ARRAY_OFFSET_DTYPE_CODE + SIZEOF_INT8; - var ARRAY_OFFSET_DTYPE_LANES = ARRAY_OFFSET_DTYPE_BITS + SIZEOF_INT8; - var ARRAY_OFFSET_SHAPE = ARRAY_OFFSET_DTYPE + SIZEOF_TYPE; - var ARRAY_OFFSET_STRIDES = ARRAY_OFFSET_STRIDES + SIZEOF_POINTER; - var ARRAY_OFFSET_BYTE_OFFSET = ARRAY_OFFSET_STRIDES + SIZEOF_POINTER; - // Type codes - var kInt = 0; - var kUInt = 1; - var kFloat = 2; - var kTVMOpaqueHandle = 3; - var kNull = 4; - var kTVMDataType = 5; - var kTVMContext = 6; - var kTVMDLTensorHandle = 7; - var kTVMObjectHandle = 8; - var kTVMModuleHandle = 9; - var kTVMPackedFuncHandle = 10; - var kTVMStr = 11; - var kTVMBytes = 12; - var kTVMObjectRValueRefArg = 14; - //----------------------------------------- - // TVM CWrap library - // ---------------------------------------- - var TVMGetLastError = Module.cwrap( - "TVMGetLastError", - "string", // const char* - []); - - var TVMAPISetLastError = Module.cwrap - ("TVMAPISetLastError", - null, - ["string" // const char* - ]); - - var TVMModImport = Module.cwrap - ("TVMModImport", - "number", - ["number", // TVMModuleHandle mod - "number" // TVMModuleHandle dep - ]); - - var TVMModGetFunction = Module.cwrap - ("TVMModGetFunction", - "number", - ["number", // TVMModuleHandle mod - "string", // const char* func_name - "number", // int query_imports - "number" // TVMFunctionHandle *out - ]); - - var TVMModFree = Module.cwrap - ("TVMModFree", - "number", - ["number" // TVMModeHandle mod - ]); - - var TVMFuncFree = Module.cwrap - ("TVMFuncFree", - "number", - ["number" // TVMFunctionHandle func - ]); - - var TVMFuncCall = Module.cwrap - ("TVMFuncCall", - "number", - ["number", // TVMFunctionHandle func - "number", // TVMValue* arg_values - "number", // int* arg_tcodes - "number", // int num_args - "number", // int ret_val - "number" // int ret_type_code - ]); - - var TVMCFuncSetReturn = Module.cwrap - ("TVMCFuncSetReturn", - "number", - ["number", // TVMRetValueHandle ret - "number", // TVMValue* value - "number", // int* type_code - "number" // int num_ret - ]); - - var TVMCbArgToReturn = Module.cwrap - ("TVMCbArgToReturn", - "number", - ["number", // TVMValue* value - "number" // int* code - ]); - - var TVMFuncCreateFromCFunc = Module.cwrap - ("TVMFuncCreateFromCFunc", - "number", - ["number", // TVMPackedCFunc 
func, - "number", // void* resource_handle - "number", // TVMPackedCFuncFinalizer fin - "number" // TVMFunctionHandle *out - ]); - - var TVMFuncRegisterGlobal = Module.cwrap - ("TVMFuncRegisterGlobal", - "number", - ["string", // name - "number", // TVMFunctionHandle f - "number" // int override - ]); - - var TVMFuncGetGlobal = Module.cwrap - ("TVMFuncGetGlobal", - "number", - ["string", // const char* name - "number" // TVMFunctionHandle* out - ]); - - var TVMFuncListGlobalNames = Module.cwrap - ("TVMFuncListGlobalNames", - "number", - ["number", // int* out_size - "number" // const char*** out_array - ]); - - - var TVMArrayAlloc = Module.cwrap - ("TVMArrayAlloc", - "number", - ["number", // const tvm_index_t* shape - "number", // int ndim - "number", // int dtype_code - "number", // int dtype_bits - "number", // int dtype_lanes - "number", // int device_type - "number", // int device_id - "number" // int TVMArrayHandle* out - ]); - - var TVMArrayFree = Module.cwrap - ("TVMArrayFree", - "number", - ["number" // TVMArrayHandle handle - ]); - - var TVMArrayCopyFromTo = Module.cwrap - ("TVMArrayCopyFromTo", - "number", - ["number", // TVMArrayHandle from - "number" // TVMArrayHandle to - ]); - - var TVMArrayCopyFromBytes = Module.cwrap - ("TVMArrayCopyFromBytes", - "number", - ["number", // TVMArrayHandle handle - "number", // int data - "number" // size_t nbytes - ]); - - var TVMArrayCopyToBytes = Module.cwrap - ("TVMArrayCopyToBytes", - "number", - ["number", // TVMArrayHandle handle - "number", // int data - "number" // size_t nbytes - ]); - - var TVMModLoadFromFile = Module.cwrap - ("TVMModLoadFromFile", - "number", - ["string", // const char* file_name - "string", // const char* format - "number" // TVMModuleHandle* out - ]) - - //----------------------------------------- - // Static utility functions - // ---------------------------------------- - this.assert = function(condition, message) { - if (!condition) { - message = message || "assert failed"; - throwError(message); - } - }; - /** - * Logging function. - * Override this to change logger behavior. 
- * - * @param {string} message - */ - this.logger = function(message) { - console.log(message); - }; - - function logging(message) { - runtime_ref.logger(message); - } - // Override print error to logging - Module.printErr = logging; - var CHECK = this.assert; - - function TVM_CALL(ret) { - if (ret != 0) { - throwError(TVMGetLastError()); - } - } - - function CInt64ArrayToJS(ptr, size) { - var ret = []; - for (var i = 0; i < size; ++i) { - ret.push(Module.getValue(ptr + i * SIZEOF_INT64, "i64")); - } - return ret; - } - - function CStringToJS(ptr) { - var ret = []; - var ch = 1; - while (ch != 0) { - ch = Module.getValue(ptr, "i8"); - if (ch != 0) { - ret.push(String.fromCharCode(ch)); - } - ++ptr; - } - return ret.join(""); - } - - function CBytesToJS(ptr) { - var data = Module.getValue(ptr, "*"); - var size = Module.getValue(ptr + SIZEOF_POINTER, "i32"); - var ret = new Uint8Array(new ArrayBuffer(size)); - ret.set(new Uint8Array(Module.HEAPU8.buffer, data, size)); - return ret; - } - - function StringToUint8Array(str) { - var arr = new Uint8Array(str.length + 1); - for(var i = 0; i < str.length; ++i) { - arr[i] = str.charCodeAt(i); - } - arr[str.length] = 0; - return arr; - } - //----------------------------------------- - // Class declarations - // ---------------------------------------- - function CBuffer(nbytes) { - this.data = Module._malloc(nbytes); - } - - function RefTVMValue() { - this.data = Module._malloc(SIZEOF_TVMVALUE); - } - - function TVMArgs(nargs) { - this.nargs = nargs; - this.value = Module._malloc(SIZEOF_TVMVALUE * nargs); - this.tcode = Module._malloc(SIZEOF_INT * nargs); - this.temp = []; - } - - function TVMType(code, bits, lanes) { - this.code = code; - this.bits = bits; - this.lanes = lanes; - } - /** - * TVM device context. - * @class - * @memberof tvm - */ - function TVMContext(device_type, device_id) { - this.device_type = device_type; - this.device_id = device_id; - } - /** - * TVM n-dimensional array. - * - * Use {@link tvm.TVMRuntime}.empty to create an instance. - * @class - * @memberof tvm - */ - function NDArray(handle) { - this.handle = handle; - this.ndim = Module.getValue(this.handle + ARRAY_OFFSET_NDIM, "i32"); - // shape - var cshape = Module.getValue(this.handle + ARRAY_OFFSET_SHAPE, "*"); - this.shape = CInt64ArrayToJS(cshape, this.ndim); - // dtype - var code = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_CODE, "i8"); - var bits = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_BITS, "i8"); - var lanes = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_LANES, "i16"); - var dtype = new TVMType(code, bits, lanes); - this.dtype = dtype; - this.BYTES_PER_ELEMENT = (dtype.bits * dtype.lanes / 8); - // ctx - var device_type = Module.getValue(this.handle + ARRAY_OFFSET_DEV_TYPE, "i32"); - var device_id = Module.getValue(this.handle + ARRAY_OFFSET_DEV_ID, "i32"); - this.context = new TVMContext(device_type, device_id); - // byte_offset - this.byteOffset = Module.getValue(this.handle + ARRAY_OFFSET_BYTE_OFFSET, "i64"); - } - - function TVMFunction(handle) { - this.handle = handle; - } - /** - * Module container of TVM generated functions. - * - * @class - * @memberof tvm - */ - function TVMModule(handle) { - this.handle = handle; - } - /** - * A typed scalar constant. - * This can be used to pass number as integer types to tvm function. - * Use {@link tvm.TVMRuntime}.constant to create an instance. 
- * @class - * @memberof tvm - */ - function TVMConstant(value, dtype) { - this.value = value; - this.dtype = dtype; - } - //----------------------------------------- - // Private Functions - // ---------------------------------------- - function getTVMType(dtype) { - if (dtype instanceof TVMType) return dtype; - if (typeof dtype == "string") { - var pattern = dtype; - var code, bits = 32, lanes = 1; - if (pattern.substring(0, 5) == "float") { - pattern = pattern.substring(5, pattern.length); - code = kFloat; - } else if (pattern.substring(0, 3) == "int") { - pattern = pattern.substring(3, pattern.length); - code = kInt; - } else if (pattern.substring(0, 4) == "uint") { - pattern = pattern.substring(4, pattern.length); - code = kUInt; - } else if (pattern.substring(0, 6) == "handle") { - pattern = pattern.substring(5, pattern.length); - code = kTVMOpaqueHandle; - bits = 64; - } else { - throw throwError("Unknown dtype " + dtype); - } - var arr = pattern.split("x"); - if (arr.length >= 1) { - var parsed = parseInt(arr[0]); - if (parsed == arr[0]) { - bits = parsed; - } - } - if (arr.length >= 2) { - lanes = parseInt(arr[1]); - } - return new TVMType(code, bits, lanes); - } else { - throw throwError("Unknown dtype " + dtype); - } - } - - function TVMRetValueToJS(vptr, tcode) { - switch (tcode) { - case kInt: - case kUInt: return Module.getValue(vptr, "i64"); - case kFloat: return Module.getValue(vptr, "double"); - case kTVMPackedFuncHandle: return makeTVMFunction(Module.getValue(vptr, "*")); - case kTVMModuleHandle: return new TVMModule(Module.getValue(vptr, "*")); - case kNull: return null; - case kTVMStr: return CStringToJS(Module.getValue(vptr, "*")); - case kTVMBytes: return CBytesToJS(Module.getValue(vptr, "*")); - default: throwError("Unsupported return type code=" + tcode); - } - } - - function makeTVMFunction(handle) { - var func = new TVMFunction(handle); - var ret = function () { - // alloc - var args = new TVMArgs(arguments.length); - var rvalue = new RefTVMValue(); - var rtcode = new RefTVMValue(); - args.setArguments(arguments); - TVM_CALL(TVMFuncCall(handle, args.value, args.tcode, - args.nargs, rvalue.data, rtcode.data)); - var rv = TVMRetValueToJS(rvalue.data, rtcode.asInt()); - // release - args.release(); - rvalue.release(); - rtcode.release(); - return rv; - }; - var release = function() { - func.release(); - }; - ret._tvm_function = func; - ret.release = release; - return ret; - } - //----------------------------------------- - // Javascript PackedCallback System - // ---------------------------------------- - var funcTable = [0]; - var freeFuncId = []; - - function invokeCallback(arg_value, arg_tcode, nargs, ret, handle) { - var args = []; - for (var i = 0; i < nargs; ++i) { - var vptr = arg_value + i * SIZEOF_TVMVALUE; - var tcodeptr = arg_tcode + i * SIZEOF_INT; - var tcode = Module.getValue(tcodeptr, "i32"); - if (tcode == kTVMObjectHandle || - tcode == kTVMObjectRValueRefArg || - tcode == kTVMPackedFuncHandle || - tcode == kTVMModuleHandle) { - TVM_CALL(TVMCbArgToReturn(vptr, tcodeptr)); - } - tcode = Module.getValue(tcodeptr, "i32"); - args.push(TVMRetValueToJS(vptr, tcode)); - } - var rv = funcTable[handle].apply(null, args); - if (typeof rv !== "undefined") { - // alloc - var rarg = new TVMArgs(1); - rarg.setArguments([rv]); - TVM_CALL(TVMCFuncSetReturn(ret, rarg.value, rarg.tcode, 1)); - // release - rarg.release(); - } - return 0; - } - function freeCallback(handle) { - funcTable[handle] = 0; - freeFuncId.push(handle); - } - var fptrInvokeCallback = null; - var 
fptrFreeCallback = null; - if (typeof Runtime !== "undefined" && - typeof Runtime.addFunction !== "undefined") { - fptrInvokeCallback = Runtime.addFunction(invokeCallback); - fptrFreeCallback = Runtime.addFunction(freeCallback); - } - /** - * Check if a function is TVM PackedFunc - * @param {Function} f function to be checked. - * @return {boolean} Whether f is PackedFunc - */ - this.isPackedFunc = function(f) { - return (typeof f == "function") && f.hasOwnProperty("_tvm_function"); - }; - var isPackedFunc = this.isPackedFunc; - /** - * Convert a javascript function to TVM function. - * @param {Function} f javascript function. - * @return {Function} The created TVMFunction. - */ - this.convertFunc = function(f) { - if (isPackedFunc(f)) return f; - CHECK(fptrInvokeCallback !== null, - "Emscripten Runtime addFunction is not available"); - var fid; - if (freeFuncId.length != 0) { - fid = freeFuncId.pop(); - } else { - fid = funcTable.length; - funcTable.push(0); - } - funcTable[fid] = f; - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMFuncCreateFromCFunc( - fptrInvokeCallback, fid, fptrFreeCallback, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - return makeTVMFunction(out_handle); - }; - var convertFunc = this.convertFunc; - //----------------------------------------- - // Private Class declarations - // ---------------------------------------- - CBuffer.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.data != 0) { - Module._free(this.data); - this.data = 0; - } - }, - }; - // RefTVMValue - RefTVMValue.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.data != 0) { - Module._free(this.data); - this.data = 0; - } - }, - asInt : function() { - return Module.getValue(this.data, "i32"); - }, - asInt64 : function() { - return Module.getValue(this.data, "i64"); - }, - asDouble : function() { - return Module.getValue(this.data, "double"); - }, - asHandle : function() { - return Module.getValue(this.data, "*"); - } - }; - // TVMArgs - TVMArgs.prototype = { - release : function() { - if (this.value != 0) { - Module._free(this.value); - Module._free(this.tcode); - this.value = 0; - for (var i = 0; i< this.temp.length; ++i) { - if (this.temp[i].release instanceof Function) { - this.temp[i].release(); - } - } - } - }, - setInt : function(index, value) { - Module.setValue(this.tcode + index * SIZEOF_INT, kInt, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "i64"); - }, - setDouble : function(index, value) { - Module.setValue(this.tcode + index * SIZEOF_INT, kFloat, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "double"); - }, - setHandle : function(index, value, tcode) { - Module.setValue(this.tcode + index * SIZEOF_INT, tcode, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "*"); - }, - setString : function(index, value) { - var sdata = new CBuffer(value.length + 1); - Module.HEAPU8.set(StringToUint8Array(value), sdata.data); - this.temp.push(sdata); - Module.setValue(this.tcode + index * SIZEOF_INT, kTVMStr, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, sdata.data, "*"); - }, - setBytes : function(index, value) { - CHECK(value instanceof Uint8Array); - var sdata = new CBuffer(value.length); - var sheader = new CBuffer(SIZEOF_POINTER + SIZEOF_SIZE_T); - Module.HEAPU8.set(new Uint8Array(value), sdata.data); - Module.setValue(sheader.data, sdata.data, "*"); - 
Module.setValue(sheader.data + SIZEOF_POINTER, value.length, "i32"); - this.temp.push(sdata); - this.temp.push(sheader); - Module.setValue(this.tcode + index * SIZEOF_INT, kTVMBytes, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, sheader.data, "*"); - }, - setArguments : function(args) { - for (var i = 0; i < args.length; ++i) { - var v = args[i]; - var tp = typeof v; - if (v instanceof NDArray) { - this.setHandle(i, v.handle, kTVMDLTensorHandle); - } else if (v instanceof TVMConstant) { - var code = getTVMType(v.dtype).code; - if (code == kInt || code == kUInt) { - this.setInt(i, v.value); - } else if (code == kFloat) { - this.setDouble(i, v.value); - } else { - CHECK(code == kTVMOpaqueHandle); - this.setHandle(i, v.value, kTVMOpaqueHandle); - } - } else if (tp == "number") { - this.setDouble(i, v); - } else if (tp == "function" && v.hasOwnProperty("_tvm_function")) { - this.setString(i, v._tvm_function.handle, kTVMPackedFuncHandle); - } else if (v === null) { - this.setHandle(i, 0, kNull); - } else if (tp == "string") { - this.setString(i, v); - } else if (v instanceof Uint8Array) { - this.setBytes(i, v); - } else if (v instanceof Function) { - v = convertFunc(v); - this.temp.push(v); - this.setHandle(i, v._tvm_function.handle, kTVMPackedFuncHandle); - } else if (v instanceof TVMModule) { - this.setHandle(i, v.handle, kTVMModuleHandle); - } else { - throwError("Unsupported argument type " + tp); - } - } - } - }; - // TVMType - var TYPE_CODE2STR = { - 0 : "int", - 1 : "uint", - 2 : "float", - 4 : "handle" - }; - - TVMType.prototype = { - toString : function() { - var ret = TYPE_CODE2STR[this.code] + this.bits.toString(); - if (this.lanes != 1) { - return ret + "x" + this.lanes.toString(); - } else { - return ret; - } - } - }; - // TVMFunction - TVMFunction.prototype = { - release : function() { - if (this.handle != 0) { - TVM_CALL(TVMFuncFree(this.handle)); - this.handle = 0; - } - } - }; - // TVMContext - var CTX_MASK2STR = { - 1 : "cpu", - 2 : "gpu", - 4 : "opencl", - 7 : "vulkan", - 8 : "metal", - 9 : "vpi", - 11 : "opengl", - }; - var CTX_STR2MASK = { - "cpu": 1, - "gpu": 2, - "cuda": 2, - "cl": 4, - "opencl": 4, - "vulkan": 7, - "metal": 8, - "vpi": 9, - "opengl": 11, - }; - TVMContext.prototype = { - toString : function() { - return CTX_MASK2STR[this.device_type] + "(" + this.device_id.toString() + ")"; - } - }; - //----------------------------------------- - // Public Functions - // ---------------------------------------- - /** - * Construct a TVMContext given device type and id. - * - * @param {number} device_type, string or int, The device type. - * @param {number} device_id, the device id. - * @return {tvm.TVMContext} The created TVMContext - */ - this.context = function(device_type, device_id) { - if (typeof device_type == "string") { - device_type = CTX_STR2MASK[device_type]; - } - return new TVMContext(device_type, device_id); - }; - var context = this.context; - /** - * Create empty ndarray with given shape. - * - * @param {Array.} shape The shape of the array. - * @param {string} dtype The data type of the array, optional, default="float32" - * @param {tvm.TVMContext} ctx The context of the array, optional, default=cpu(0). - * @return {tvm.NDArray} The created ndarray. - */ - this.empty = function(shape, dtype, ctx) { - dtype = (typeof dtype !== "undefined") ? dtype: "float32"; - ctx = (typeof ctx !== "undefined") ? ctx : context("cpu", 0); - shape = (typeof shape == "number") ? 
[shape] : shape; - // alloc - var cshape = Module._malloc(SIZEOF_INT64 * shape.length); - var out = new RefTVMValue(); - for (var i = 0; i < shape.length; ++i) { - Module.setValue(cshape + i * SIZEOF_INT64, shape[i], "i64"); - } - dtype = getTVMType(dtype); - TVM_CALL(TVMArrayAlloc(cshape, shape.length, - dtype.code, dtype.bits, dtype.lanes, - ctx.device_type, ctx.device_id, - out.data)); - var out_handle = out.asHandle(); - // release - Module._free(cshape); - out.release(); - return new NDArray(out_handle); - }; - /** - * List all global function names in the TVM runtime. - * @return {Array.} List of global function names. - */ - this.listGlobalFuncNames = function() { - // alloc - var out_size = new RefTVMValue(); - var out_array = new RefTVMValue(); - TVM_CALL(TVMFuncListGlobalNames(out_size.data, out_array.data)); - var length = out_size.asInt(); - var base = out_array.asHandle(); - var names = []; - for (var i = 0 ; i < length; ++i) { - names.push( - CStringToJS(Module.getValue(base + i * SIZEOF_POINTER, "*"))); - } - // release - out_size.release(); - out_array.release(); - return names; - }; - var listGlobalFuncNames = this.listGlobalFuncNames; - /** - * Get a global function from TVM runtime. - * - * @param {string} The name of the function. - * @return {Function} The corresponding function, null if function do not exist - */ - this.getGlobalFunc = function (name) { - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMFuncGetGlobal(name, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - if (out_handle != 0) { - return makeTVMFunction(out_handle); - } else { - return null; - } - }; - var getGlobalFunc = this.getGlobalFunc; - /** - * Register function to be global function in tvm runtime. - * @param {string} name The name of the function. - * @param {Function} f function to be registered. - * @param {boolean} override Whether overwrite function in existing registry. - */ - this.registerFunc = function(name, f, override) { - f = convertFunc(f); - override = (typeof override !== "undefined") ? override: false; - var ioverride = override ? 1 : 0; - TVM_CALL(TVMFuncRegisterGlobal(name, f._tvm_function.handle, ioverride)); - }; - /** - * Create a typed scalar constant. - * This can be used to pass number as integer types to tvm function. - * - * @param {number} value The value of the data. - * @param {string} dtype The data type. - * @param {tvm.TVMConstant} The created typed scalar. - */ - this.constant = function(value, dtype) { - return new TVMConstant(value, dtype); - }; - //----------------------------------------- - // Wrap of TVM Functions. - // ---------------------------------------- - var systemFunc = {}; - /** - * Get system-wide library module singleton.5A - * System lib is a global module that contains self register functions in startup. - * @return {tvm.TVMModule} The system module singleton. 
- */ - this.systemLib = function() { - if (typeof systemFunc.fGetSystemLib === "undefined") { - systemFunc.fGetSystemLib = getGlobalFunc("runtime.SystemLib"); - } - return systemFunc.fGetSystemLib(); - }; - - this.startRPCServer = function(url, key, counter) { - if (typeof key === "undefined") { - key = ""; - } - if (typeof counter === "undefined") { - counter = 1; - } - // Node js, import websocket - var bkey = StringToUint8Array("server:" + key); - bkey = bkey.slice(0, bkey.length - 1); - var server_name = "WebSocketRPCServer[" + key + "]"; - var RPC_MAGIC = 0xff271; - function checkEndian() { - var a = new ArrayBuffer(4); - var b = new Uint8Array(a); - var c = new Uint32Array(a); - b[0] = 0x11; - b[1] = 0x22; - b[2] = 0x33; - b[3] = 0x44; - CHECK(c[0] === 0x44332211, "Need little endian to work"); - } - checkEndian(); - // start rpc - function RPCServer(counter) { - var socket; - if (typeof module !== "undefined" && module.exports) { - // WebSocket for nodejs - const WebSocket = require("ws"); - socket = new WebSocket(url); - } else { - socket = new WebSocket(url); - } - var self = this; - socket.binaryType = "arraybuffer"; - this.init = true; - this.counter = counter; - - if (typeof systemFunc.fcreateServer === "undefined") { - systemFunc.fcreateServer = - getGlobalFunc("rpc._CreateEventDrivenServer"); - } - if (systemFunc.fcreateServer == null) { - throwError("RPCServer is not included in runtime"); - } - - var message_handler = systemFunc.fcreateServer( - function(cbytes) { - if (socket.readyState == 1) { - socket.send(cbytes); - return new TVMConstant(cbytes.length, "int32"); - } else { - return new TVMConstant(0, "int32"); - } - } , server_name, "%toinit"); - - function on_open(event) { - var intbuf = new Int32Array(1); - intbuf[0] = RPC_MAGIC; - socket.send(intbuf); - intbuf[0] = bkey.length; - socket.send(intbuf); - socket.send(bkey); - logging(server_name + " connected..."); - } - - function on_message(event) { - if (self.init) { - var msg = new Uint8Array(event.data); - CHECK(msg.length >= 4, "Need message header to be bigger than 4"); - var magic = new Int32Array(event.data)[0]; - - if (magic == RPC_MAGIC + 1) { - throwError("key: " + key + " has already been used in proxy"); - } else if (magic == RPC_MAGIC + 2) { - logging(server_name + ": RPCProxy do not have matching client key " + key); - } else { - CHECK(magic == RPC_MAGIC, url + "is not RPC Proxy"); - self.init = false; - } - logging(server_name + "init end..."); - if (msg.length > 4) { - if (message_handler( - new Uint8Array(event.data, 4, msg.length -4), - new TVMConstant(3, "int32")) == 0) { - socket.close(); - } - } - } else { - if (message_handler(new Uint8Array(event.data), - new TVMConstant(3, "int32")) == 0) { - socket.close(); - } - } - } - function on_close(event) { - message_handler.release(); - logging(server_name + ": closed finish..."); - if (!self.init && self.counter != 0) { - logging(server_name + ":reconnect to serve another request, session left=" + counter); - // start a new server. - new RPCServer(counter - 1); - } - } - socket.addEventListener("open", on_open); - socket.addEventListener("message", on_message); - socket.addEventListener("close", on_close); - } - return new RPCServer(counter); - }; - - /** - * Load a TVM module from a library file. - * The file must be present in the Emscripten virtual file system. - * For example, you can pass "--preload-file file" or "--preload-file dir/" - * to "emcc" when compiling the TVM library, in order to populate files into - * the file system. 
- * For more detail, see: - * https://kripken.github.io/emscripten-site/docs/porting/files/packaging_files - * @param {string} file_name Path of the file to be loaded. The path refers - * to the Emscripten virtual file system. - * @param {string} format The format of the file. - * @return {tvm.TVMModule} The loaded module. - */ - this.loadModuleFromFile = function (file_name, format) { - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMModLoadFromFile(file_name, format, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - if (out_handle != 0) { - return new TVMModule(out_handle); - } else { - return null; - } - }; - var loadModuleFromFile = this.loadModuleFromFile; - - /** - * Wrapper runtime module. - * Wraps around set_input, load_params, run, and get_output. - * - * @class - * @memberof tvm - */ - function GraphModule(tvm_graph_module, ctx) { - CHECK(tvm_graph_module instanceof TVMModule, - "tvm_graph_module must be TVMModule"); - CHECK(ctx instanceof TVMContext, "ctx must be TVMContext"); - - this.tvm_graph_module = tvm_graph_module; - this.ctx = ctx; - this._set_input = tvm_graph_module.getFunction("set_input"); - this._load_params = tvm_graph_module.getFunction("load_params"); - this._run = tvm_graph_module.getFunction("run"); - this._get_output = tvm_graph_module.getFunction("get_output"); - }; - - GraphModule.prototype = { - /** - * Set input to graph module. - * - * @param {string} key The name of the input. - * @param {NDArray} value The input value. - */ - "set_input" : function(key, value) { - CHECK(typeof key == "string", "key must be string"); - CHECK(value instanceof NDArray, "value must be NDArray"); - this._set_input(key, value); - }, - - /** - * Load parameters from serialized byte array of parameter dict. - * - * @param {Uint8Array} params The serialized parameter dict. - */ - "load_params" : function(params) { - CHECK(params instanceof Uint8Array, "params must be Uint8Array"); - this._load_params(params); - }, - - /** - * Load parameters from serialized base64 string of parameter dict. - * - * @param {string} base64_params The serialized parameter dict. - */ - "load_base64_params" : function(base64_params) { - CHECK(typeof base64_params == "string", "base64_params must be string"); - var decoded_string = atob(base64_params); - var decoded_u8 = new Uint8Array(decoded_string.length); - for (var i = 0; i < decoded_string.length; i++) { - decoded_u8[i] = decoded_string[i].charCodeAt(0); - } - this.load_params(decoded_u8); - }, - - /** - * Run forward execution of the graph. - */ - "run" : function() { - this._run(); - }, - - /** - * Get index-th output to out. - * - * @param {NDArray} out The output array container. - * @return {NDArray} The output array container. - */ - "get_output" : function(index, out) { - CHECK(typeof index == "number", "index must be number"); - CHECK(out instanceof NDArray, "out must be NDArray"); - this._get_output(new TVMConstant(index, "int32"), out); - return out; - } - }; - - /** - * Create a runtime executor module given a graph and a module. - * @param {string} graph_json_str The Json string of the graph. - * @param {TVMModule} libmod The TVM module. - * @param {TVMContext} ctx The context to deploy the module. - * @return {GraphModule} Runtime graph module for executing the graph. 
- */ - this.createGraphRuntime = function(graph_json_str, libmod, ctx) { - CHECK(typeof graph_json_str == "string", "graph_json_str must be string"); - CHECK(libmod instanceof TVMModule, "libmod must be TVMModule"); - CHECK(ctx instanceof TVMContext, "ctx must be TVMContext"); - - var fcreate = getGlobalFunc("tvm.graph_runtime.create"); - CHECK(fcreate != null, "Cannot find tvm.graph_runtime.create"); - - var tvm_graph_module = fcreate(graph_json_str, libmod, - new TVMConstant(ctx.device_type, "int32"), - new TVMConstant(ctx.device_id, "int32")); - - return new GraphModule(tvm_graph_module, ctx); - }; - - //----------------------------------------- - // Class defintions - // ---------------------------------------- - // NDArray. - NDArray.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.handle != 0) { - TVM_CALL(TVMArrayFree(this.handle)); - this.handle = 0; - } - }, - /** - * Copy data from another NDArray or javascript array. - * The number of elements must match. - * - * @param {Array} data The source data array. - */ - copyFrom : function(data) { - if (data instanceof NDArray) { - TVM_CALL(TVMArrayCopyFromTo(data.handle, this.handle)); - } else { - var size = this.shape.reduce(function(a, b) { return a * b; }, 1); - if (data.length != size) { - throwError("data size and shape mismatch data.length" + data.length + " vs " + size); - } - if (this.dtype == "float32") { - data = Float32Array.from(data); - } else if (this.dtype == "float64") { - data = Float64Array.from(data); - } else if (this.dtype == "int32") { - data = Int32Array.from(data); - } else if (this.dtype == "int8") { - data = Int8Array.from(data); - } else if (this.dtype == "uint8") { - data = Uint8Array.from(data); - } else { - throwError("Unsupported data type " + this.dtype); - } - return this.copyFromRawBytes(new Uint8Array(data.buffer)); - } - }, - /** - * Copy data from raw bytes. - * @param {Uint8Array} data Uint8Array of bytes. - */ - copyFromRawBytes : function(data) { - var size = this.shape.reduce(function(a, b) { return a * b; }, 1); - var dtype = getTVMType(this.dtype); - var nbytes = this.BYTES_PER_ELEMENT * size; - CHECK(data instanceof Uint8Array); - CHECK(data.length == nbytes, - "Data length and bytes do not match " + data.length + - " vs " + nbytes); - var temp = Module._malloc(nbytes); - Module.HEAPU8.set(data, temp); - TVM_CALL(TVMArrayCopyFromBytes(this.handle, temp, nbytes)); - Module._free(temp); - return this; - }, - /** - * Return a copied Uint8Array of the raw bytes in the NDArray. - * @return {Uint8Array} The created array. - */ - asRawBytes : function() { - var size = this.shape.reduce(function(a, b) { return a * b; }, 1); - var nbytes = this.BYTES_PER_ELEMENT * size; - var temp = Module._malloc(nbytes); - TVM_CALL(TVMArrayCopyToBytes(this.handle, temp, nbytes)); - var ret = new Uint8Array(new ArrayBuffer(nbytes)); - ret.set(new Uint8Array(Module.HEAPU8.buffer, temp, nbytes)); - Module._free(temp); - return ret; - }, - /** - * Return Array data content as javascript typed array. - * @return {TypedArray} The created array. 
- */ - asArray : function() { - if (this.dtype == "float32") { - return new Float32Array(this.asRawBytes().buffer); - } else if (this.dtype == "float64") { - return new Float64Array(this.asRawBytes().buffer); - } else if (this.dtype == "int32") { - return new Int32Array(this.asRawBytes().buffer); - } else if (this.dtype == "int8") { - return new Int8Array(this.asRawBytes().buffer); - } else if (this.dtype == "uint8") { - return new Uint8Array(this.asRawBytes().buffer); - } else { - throwError("Unsupported data type " + this.dtype); - } - } - }; - - TVMModule.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.handle != 0) { - TVM_CALL(TVMModFree(this.handle)); - this.handle = 0; - } - }, - /** - * Get function from the module. - * @param {string} name The name of the function. - * @return {Function} The correspondin function. - */ - getFunction : function(name) { - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMModGetFunction(this.handle, name, 0, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - if (out_handle == 0) { - throwError("Module has no function " + name); - } - return makeTVMFunction(out_handle); - }, - /** - * Add module to the import list of current one. - * @param {tvm.TVMModule} mod The other module to be imported. - */ - import_module : function(mod) { - CHECK(mod instanceof TVMModule, "mod must be instance of TVMModule"); - TVM_CALL(TVMModImport(this.handle, mod.handle)); - } - }; - //----------------------------------------- - // Static variables. - // ---------------------------------------- - /** Float32 type */ - this.float32 = "float32"; - /** Int32 type */ - this.int32 = "int32"; - } - /** - * Create a TVM runtime given emscripten module. - * @property {string} create - * @memberof tvm_runtime - * @param Module The emscripten module. - * @return {tvm.TVMRuntime} The created TVM runtime. - */ - this.create = function(Module) { - var tvm = {}; - tvm.Module = Module; - if (typeof Module.addFunction !== "undefined") { - tvm.Runtime = Module; - } else { - tvm.Runtime = Module.Runtime; - } - TVMRuntime.apply(tvm); - return tvm; - }; -}).apply(tvm_runtime); - -// export things in node -if (typeof module !== "undefined" && module.exports) { - module.exports = tvm_runtime; -} diff --git a/web/web_runtime.cc b/web/web_runtime.cc deleted file mode 100644 index 701ded76288e..000000000000 --- a/web/web_runtime.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * \file web_runtime.cc - */ -#include -#include - -#include "../src/runtime/c_runtime_api.cc" -#include "../src/runtime/cpu_device_api.cc" -#include "../src/runtime/workspace_pool.cc" -#include "../src/runtime/library_module.cc" -#include "../src/runtime/system_library.cc" -#include "../src/runtime/module.cc" -#include "../src/runtime/ndarray.cc" -#include "../src/runtime/object.cc" -#include "../src/runtime/registry.cc" -#include "../src/runtime/file_util.cc" -#include "../src/runtime/dso_library.cc" -#include "../src/runtime/rpc/rpc_session.cc" -#include "../src/runtime/rpc/rpc_event_impl.cc" -#include "../src/runtime/rpc/rpc_server_env.cc" -#include "../src/runtime/graph/graph_runtime.cc" -#include "../src/runtime/opengl/opengl_device_api.cc" -#include "../src/runtime/opengl/opengl_module.cc" - -namespace tvm { -namespace contrib { - -struct RPCEnv { - public: - RPCEnv() { - base_ = "/rpc"; - mkdir(&base_[0], 0777); - } - // Get Path. - std::string GetPath(const std::string& file_name) { - return base_ + "/" + file_name; - } - - private: - std::string base_; -}; - -TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") -.set_body_typed([](std::string path) { - static RPCEnv env; - return env.GetPath(path); - }); - -TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") -.set_body_typed([](std::string path) { - std::string file_name = "/rpc/" + path; - LOG(INFO) << "Load module from " << file_name << " ..."; - return Module::LoadFromFile(file_name, ""); - }); -} // namespace contrib -} // namespace tvm - -// dummy parallel runtime -int TVMBackendParallelLaunch( - FTVMParallelLambda flambda, - void* cdata, - int num_task) { - TVMAPISetLastError("Parallel is not supported in Web runtime"); - return -1; -} - -int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { - return 0; -}
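The new TypeScript runtime under web/src together with the relocated Node tests replaces the deleted Emscripten-based web/tvm_runtime.js. A minimal sketch of driving the rebuilt tvmjs package from Node follows, written against the API the added tests exercise; the dist/ output directory (tsconfig.json sets outDir to dist), the dist/wasm artifacts emitted by prepare_test_libs.py, the tvmjs_runtime.wasi.js shim, and the add_one kernel name are assumptions about the local build, not guarantees made by this patch.

import * as fs from "fs";
import * as path from "path";

// Assumed layout: the TypeScript sources under web/src have been compiled to
// dist/, and web/tests/python/prepare_test_libs.py has written
// test_addone.wasm into dist/wasm. The require path is relative to web/.
const tvmjs = require("./dist");

const wasmDir = tvmjs.wasmPath();  // resolves to <dist>/wasm
const EmccWASI = require(path.join(wasmDir, "tvmjs_runtime.wasi.js"));
const wasmSource = fs.readFileSync(path.join(wasmDir, "test_addone.wasm"));

// Instantiate the runtime with the Emscripten-generated WASI shim.
const tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI());

// System-library modules self-register at startup; fetch the generated kernel.
const sysLib = tvm.systemLib();
const addOne = sysLib.getFunction("add_one");

const a = tvm.empty(4).copyFrom([1, 2, 3, 4]);
const b = tvm.empty(4);
addOne(a, b);
console.log(b.toArray());  // expect Float32Array [2, 3, 4, 5]

// Handles wrap wasm-side resources and should be disposed explicitly.
addOne.dispose();
a.dispose();
b.dispose();
sysLib.dispose();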
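The setPackedArguments and wrapJSFuncAsPackedCFunc pair introduced in the new web/src runtime is what lets plain JavaScript functions cross the PackedFunc boundary in both directions. The sketch below mirrors the added web/tests/node/test_packed_func.js; the setup repeats the build-layout assumptions from the previous sketch, and the behavior shown (numbers packed as float TVMValues, idempotent dispose) follows the code in this diff rather than being an independent specification.

import * as fs from "fs";
import * as path from "path";

// Same assumed layout as the previous sketch, but loading the bundled runtime
// wasm (tvmjs_runtime.wasm), as the added test_packed_func.js does.
const tvmjs = require("./dist");
const wasmDir = tvmjs.wasmPath();
const EmccWASI = require(path.join(wasmDir, "tvmjs_runtime.wasi.js"));
const wasmSource = fs.readFileSync(path.join(wasmDir, "tvmjs_runtime.wasm"));
const tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI());

// A JS closure that returns another closure.
function addy(y: number) {
  return (x: number, z: number): number => x + y + z;
}

// toPackedFunc routes the closure through wrapJSFuncAsPackedCFunc, so the
// resulting handle can be passed to and invoked from the wasm side.
const myf = tvm.toPackedFunc(addy);
console.assert(tvm.isPackedFunc(myf));

// Arguments travel out through setPackedArguments (JS numbers are packed as
// float TVMValues) and the returned closure comes back through retValueToJS
// as a PackedFunc of its own.
const f = myf(10);
console.assert(f(11, 0) == 21);

// registerFunc exposes a JS callback under a global name ("xyz" follows the test).
tvm.registerFunc("xyz", (x: number, y: number) => x + y);
const g = tvm.getGlobalFunc("xyz");
console.assert(g(1, 2) == 3);

// dispose() releases the wasm-side handle; repeated calls are safe no-ops.
f.dispose();
myf.dispose();
g.dispose();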