diff --git a/.travis.yml b/.travis.yml index f01653aec86f..a02fbe658554 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,6 +10,7 @@ env: - TASK=build CXX=g++ - TASK=python CXX=g++ - TASK=python3 CXX=g++ + - TASK=python_naive CXX=g++ - TASK=unittest_gtest CXX=g++ # dependent apt packages diff --git a/Makefile b/Makefile index 5334e52ab52e..880ea7a73ab8 100644 --- a/Makefile +++ b/Makefile @@ -70,6 +70,10 @@ ifeq ($(USE_CUDNN), 1) LDFLAGS += -lcudnn endif +ifeq ($(USE_THREADED_ENGINE), 1) + CFLAGS += -DMXNET_USE_THREADED_ENGINE +endif + ifneq ($(ADD_CFLAGS), NONE) CFLAGS += $(ADD_CFLAGS) endif @@ -80,7 +84,7 @@ endif .PHONY: clean all test lint doc -BIN = tests/test_simple_engine +BIN = tests/test_threaded_engine all: lib/libmxnet.a lib/libmxnet.so $(BIN) SRC = $(wildcard src/*.cc src/*/*.cc) diff --git a/doc/developer-guide/engine.md b/doc/developer-guide/engine.md index 51508520a526..be55cdd30c8c 100644 --- a/doc/developer-guide/engine.md +++ b/doc/developer-guide/engine.md @@ -1,8 +1,79 @@ -DAG Engine -========== +Execution Engine +================ -NArray ------- +MXNet's engine is not only for deep learning or any domain-specific problem. Rather, it is designed to face a general problem: execute a bunch of functions following their dependencies. Execution of any two functions with dependencies should be serialized. Functions with no dependencies *may* be executed in parallel to boost performance. -Push Function -------------- +Interface +============== +The core interface of execution engine is: +```c++ +virtual void Push(Fn exec_fun, Context exec_ctx, + std::vector const& use_vars, + std::vector const& mutate_vars) = 0; +``` +This API allows users to push a function (`exec_fun`), along with its context information and dependencies to the engine. The `exec_ctx` is the context information in which the `exec_fun` should be executed. `use_vars` denotes the variables that the function would read from while `mutate_vars` are the variables that to be modified. Regardless of the details that would be explained later, the engine guarantees following order: + +>*The execution of any two functions that any one of them modifies at least one common variable would be serialized in their push order.* + +Function +-------- +The function type of the engine is: +```c++ +using Fn = std::function; +``` +The `RunContext` contains runtime information which is determined by the engine: +```c++ +struct RunContext { + // stream pointer which could be safely cast to + // cudaStream_t* type + void *stream; +}; +``` +Alternatively, one could use `mxnet::engine::DAGEngine::Fn` which is the same type defination. + +All the functions will be executed by the internal threads of the engine. In such model, it is usually not suggested to push *blocking* functions to the engine (usually for dealing with I/O tasks like disk, web service, UI, etc.) since it will occupy the execution thread and reduce the total throughput. In such case, we provide another *asynchronous* function type: +```c++ +using Callback = std::function; +using AsyncFn = std::function; +``` +In the `AsyncFn` function, user could pass the heavy part to their own threads and safely exit the function body. The engine will not consider the function to be finished until the `Callback` function is called. + +Context +-------- +User could specify the `Context` of the function to be executed within. This usually includes whether the function should be run on CPU or GPU, and if GPU, which GPU to use. `Context` is different from `RunContext`. `Context` contains device type (gpu/cpu) and device id while `RunContext` contains information that could only be decided during runtime like on which stream the function should be executed. + +Variable +-------- +`Variable` is used to specify the dependencies of functions. The design of MXNet engine is to decouple it with other modules in MXNet. So `Variable` is like an engine-given token for user to represent the external resources the functions may use or modified. It is designed to be light, so create, delete or copy a variable will incur little overhead. Upon pushing functions, users need to specify the variables that will be used (immutable) in `use_vars` vector and the variables to be modified (mutable) in `mutate_vars` vector. The only rule for the engine to resolve the dependencies among functions pushed is: + +>*The execution of any two functions that any one of them modifies at least one common variable would be serialized in their push order.* + +For example, if `Fn1`, `Fn2` both mutate `V2`, `Fn2` is guaranteed to be executed after `Fn1` if `Fn2` is pushed after `Fn1`. On the other hand, if `Fn1` and `Fn2` both use `V2`, their actual execution order could be any kind. + +This design allows the engine to schedule *non-functional* operations. For example, the weight update function in DNN could now use `+=` operator rather than generating a new weight array each time. + +To create a variable, use `NewVar()` API. To delete a variable, use `PushDelete` API. + +Push & Wait +----------- +**All `Push` APIs are asynchronous.** The API call will return immediately no matter the pushed `Fn` is finished or not. This allows engine to start computing at the same time user thread is pushing functions. All `Push` APIs are not thread-safe. To be specific, only one thread should make engine API calls at one time. + +If you want to wait for a specific `Fn` to be finished, include a callback function in the closure and call the function at the end of your `Fn`. + +If you want to wait for all `Fn` that involves (use/mutate) a certain variable to be finished, use `WaitForVar(var)` API. + +If you want to wait for all pushed `Fn` to be finished, use `WaitForAll()` API. + +Save Object Creation Cost +---------------------------- +In some cases, you need to push several functions to the engine but for tons of times. If the computation of these functions are light, the overhead of copying lambdas and creating use/mutate variable lists would become relatively high. We provide an API to create an `OprHandle` beforehand: +```c++ +virtual OprHandle NewOperator(AsyncFn fn, + std::vector const& use_vars, + std::vector const& mutate_vars) = 0; +``` +So you could keep pushing the `OprHandle` without repeatedly creating them: +```c++ +virtual void Push(OprHandle op, Context exec_ctx) = 0; +``` +To delete it, simply call `DeleteOperator(OprHandle op)` but please make sure the operator has finished computing. \ No newline at end of file diff --git a/include/mxnet/dag_engine.h b/include/mxnet/dag_engine.h deleted file mode 100644 index ca10dad85441..000000000000 --- a/include/mxnet/dag_engine.h +++ /dev/null @@ -1,173 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file dag_engine.h - * \brief DAG engine that schedules data. - */ -#ifndef MXNET_DAG_ENGINE_H_ -#define MXNET_DAG_ENGINE_H_ -#include - -#if DMLC_USE_CXX11 == 0 -#error "C++11 was required for DAG engine module." -#endif - -#include -#include -#include "base.h" -#include "context.h" - -namespace mxnet { - -/*! - * \brief Namespace of engine implementation. - */ -namespace engine { - -/*! - * \brief Inner representation of variable. - */ -struct Var; - -/*! - * \brief Inner representation of operator. - */ -struct Opr; - -} // namespace engine - -/*! - * \brief Dynamic dataflow DAG engine that schedules operations. - */ -class DAGEngine { - public: - /*! - * \brief Operation to pass to DAG engine. - */ - using Fn = std::function; - /*! - * \brief Callback function to notify operation complete. - */ - using Callback = std::function; - /*! - * \brief Asynchronous operation to pass to DAG engine. - */ - using AsyncFn = std::function; - /*! - * \brief Variable of dag engine, used to specify dependencies defined to be a - * pointer, that points to an internal data structure of the engine - * itself. - */ - using Variable = engine::Var*; - /*! - * \brief Operator of the engine. - */ - using OprHandle = engine::Opr*; - /*! - * \brief Allocate a new variable, the variable can then - * be used to schedule the operation concurrently via dependency - * patterns. - * \return The new variable allocated. - */ - virtual Variable NewVar() = 0; - /*! - * \brief Create a new operator. The returned operator could be saved - * externally so that it could be resued for scheduling. - * \param fn The execution function. - * \param use_vars The variables that current operation will use but not - * mutate. - * \param mutate_vars Teh variables that current operation will mutate. - * \return The new operator allocated. - */ - virtual OprHandle NewOperator(AsyncFn fn, - std::vector const& use_vars, - std::vector const& mutate_vars) = 0; - /*! - * \brief Delete the given operator. - * \param op The operator to delete. - */ - virtual void DeleteOperator(OprHandle op) = 0; - /*! - * \brief Push an operator to the engine. - * \param op The operator to push. - * \param exec_ctx Execution context. - */ - virtual void Push(OprHandle op, Context exec_ctx) = 0; - /*! - * \brief Push an synchronous operation to the DAG engine. - * \param exec_fun Execution function that executes the operation. - * \param exec_ctx Execution context. - * \param use_vars The variables that current operation will use but not - * mutate. - * \param mutate_vars The variables that current operation will mutate. - */ - virtual void Push(Fn exec_fun, Context exec_ctx, - std::vector const& use_vars, - std::vector const& mutate_vars) = 0; - /*! - * \brief Push an asynchronous operation to the DAG engine. - * \param exec_fun Execution function, this function takes a parameter - * on_complete that must be called when the execution - * completes. - * \param exec_ctx Execution context. - * \param use_vars The variables that current operation will use but not - * mutate. - * \param mutate_vars The variables that current operation will mutate. - */ - virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx, - std::vector const& use_vars, - std::vector const& mutate_vars) = 0; - /*! - * \brief Schedule the delete of a variable. - * - * The delete will not happen immediately, but will wait until all the - * operations depending on var is completed. - * - * \param delete_fun A function that will be called after the variable is - * deleted. - * \param exec_ctx Execution context. - * \param var The variable to be deleted. - */ - virtual void PushDelete(Fn delete_fun, Context exec_ctx, Variable var) = 0; - /*! - * \brief Wait to read a variable. - * - * The caller should read the content immediately in a synchronized way, - * before any subsequent write operations are issued. - * The subsequent write operations to the variable can destroy the content. - * - * \param var The variable we should wait for, - * This function returns when all the write operations to this - * var has been completed. - */ - virtual void WaitToRead(Variable var) = 0; - /*! - * \brief Wait to write a variable. - * - * The caller should rwrite the content immediately in a synchronized way, - * before any subsequent write operations are issued. - * The subsequent write operations to the variable can destroy the content. - * - * \param var The variable we should wait for, - * This function returns when all the read/write operations - * on var has been completed. - */ - virtual void WaitToWrite(Variable var) = 0; - /*! - * \brief Wait until all the activity of dag engine finishes. - */ - virtual void WaitForAll() = 0; - /*! - * \brief Virtual destructor. - */ - virtual ~DAGEngine() noexcept(false) {} - /*! - * \return DAG engine singleton. - */ - static DAGEngine* Get(); - - // remove DISALLOW_COPY_AND_ASSIGN since this is virtual class. -}; // class DAGEngine - -} // namespace mxnet - -#endif // MXNET_DAG_ENGINE_H_ diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h new file mode 100644 index 000000000000..91b9f2a72b8d --- /dev/null +++ b/include/mxnet/engine.h @@ -0,0 +1,168 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file engine.h + * \brief Engine that schedules data. + */ +#ifndef MXNET_ENGINE_H_ +#define MXNET_ENGINE_H_ +#include + +#if DMLC_USE_CXX11 == 0 +#error "C++11 was required for engine module." +#endif + +#include +#include +#include "base.h" +#include "context.h" + +namespace mxnet { + +/*! + * \brief Namespace of engine implementation. + */ +namespace engine { + +/*! + * \brief Inner representation of variable. + */ +struct Var; + +/*! + * \brief Inner representation of operator. + */ +struct Opr; + +} // namespace engine + +/*! + * \brief Function property. + */ +enum class FnProperty { kNormal, kIO, kAsync }; // enum class FnProperty + +/*! + * \brief Dynamic dataflow engine that schedules operations. + */ +class Engine { + public: + /*! + * \brief Operation to pass to engine. + */ + using Fn = std::function; + /*! + * \brief Callback function to notify operation complete. + */ + using Callback = std::function; + /*! + * \brief Asynchronous operation to pass to engine. + */ + using AsyncFn = std::function; + /*! + * \brief Variable of engine, used to specify dependencies defined to be a + * pointer, that points to an internal data structure of the engine + * itself. + */ + using VarHandle = engine::Var*; + /*! + * \brief Operator of the engine. + */ + using OprHandle = engine::Opr*; + /*! + * \brief Allocate a new variable, the variable can then + * be used to schedule the operation concurrently via dependency + * patterns. + * \return The new variable allocated. + */ + virtual VarHandle NewVariable() = 0; + /*! + * \brief Create a new operator. The returned operator could be saved + * externally so that it could be resued for scheduling. + * \param fn The execution function. + * \param const_vars The variables that current operation will use but not + * mutate. + * \param mutable_vars The variables that current operation will mutate. + * \param prop Property of the function. + * \return The new operator allocated. + */ + virtual OprHandle NewOperator(AsyncFn fn, + std::vector const& const_vars, + std::vector const& mutable_vars, + FnProperty prop = FnProperty::kNormal) = 0; + /*! + * \brief Delete the given operator. + * \param op The operator to delete. + * + * The delete will not happen immediately, but will wait until all the + * operations using this operator are completed. + */ + virtual void DeleteOperator(OprHandle op) = 0; + /*! + * \brief Push an operator to the engine. + * \param op The operator to push. + * \param exec_ctx Execution context. + */ + virtual void Push(OprHandle op, Context exec_ctx) = 0; + /*! + * \brief Push an synchronous operation to the engine. + * \param exec_fun Execution function that executes the operation. + * \param exec_ctx Execution context. + * \param const_vars The variables that current operation will use but not + * mutate. + * \param mutable_vars The variables that current operation will mutate. + * \param prop Property of the function. + */ + virtual void Push(Fn exec_fun, Context exec_ctx, + std::vector const& const_vars, + std::vector const& mutable_vars, + FnProperty prop = FnProperty::kNormal) = 0; + /*! + * \brief Push an asynchronous operation to the engine. + * \param exec_fun Execution function, this function takes a parameter + * on_complete that must be called when the execution + * completes. + * \param exec_ctx Execution context. + * \param const_vars The variables that current operation will use but not + * mutate. + * \param mutable_vars The variables that current operation will mutate. + * \param prop Property of the function. + */ + virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx, + std::vector const& const_vars, + std::vector const& mutable_vars, + FnProperty prop = FnProperty::kNormal) = 0; + /*! + * \brief Schedule the deletion of a variable. + * + * The delete will not happen immediately, but will wait until all the + * operations depending on var are completed. + * + * \param delete_fun A function that will be called after the variable is + * deleted. + * \param exec_ctx Execution context. + * \param var The variable to be deleted. + */ + virtual void DeleteVariable(Fn delete_fun, Context exec_ctx, + VarHandle var) = 0; + /*! + * \brief Wait for a variable. + * \param var The variable we should wait for. This function returns when the + * variable is ready. + */ + virtual void WaitForVar(VarHandle var) = 0; + /*! + * \brief Wait until all the activity of engine finishes. + */ + virtual void WaitForAll() = 0; + /*! + * \brief Virtual destructor. + */ + virtual ~Engine() noexcept(false); + /*! + * \return Engine singleton. + */ + static Engine* Get(); +}; // class Engine + +} // namespace mxnet + +#endif // MXNET_ENGINE_H_ diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index 26e796d55d39..fb8fc2b7484d 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -16,7 +16,7 @@ #include "./context.h" #include "./storage.h" #include "./context.h" -#include "./dag_engine.h" +#include "./engine.h" // check c++11 #if DMLC_USE_CXX11 == 0 #error "cxx11 was required for narray module" @@ -79,7 +79,7 @@ class NArray { */ inline void WaitToRead() const { if (is_none()) return; - DAGEngine::Get()->WaitToRead(ptr_->var); + Engine::Get()->WaitForVar(ptr_->var); } /*! * \brief Block until all the pending read/write operations with respect @@ -87,10 +87,15 @@ class NArray { */ inline void WaitToWrite() const { if (is_none()) return; - DAGEngine::Get()->WaitToWrite(ptr_->var); + /*! + * Push an empty mutable function to flush all preceding reads to the + * variable. + */ + Engine::Get()->Push([](RunContext) {}, Context{}, {}, {ptr_->var}); + Engine::Get()->WaitForVar(ptr_->var); } - /*! \return the associated DAG variable of the narray.*/ - inline DAGEngine::Variable var() const { + /*! \return the associated variable of the narray.*/ + inline Engine::VarHandle var() const { return ptr_->var; } /*! @@ -239,8 +244,8 @@ class NArray { struct Chunk { /*! \brief storage handlefrom storage engine */ Storage::Handle shandle; - /*! \brief variable from DAG engine */ - DAGEngine::Variable var; + /*! \brief variable from engine */ + Engine::VarHandle var; /*! * \brief if this is true, this means the data do not come * from Storage, and do not need to be freed @@ -250,13 +255,13 @@ class NArray { bool delay_alloc; /*! \brief default cosntructor */ Chunk() : static_data(true), delay_alloc(false) { - var = DAGEngine::Get()->NewVar(); + var = Engine::Get()->NewVariable(); } /*! \brief construct from static data */ Chunk(const TBlob &data, int dev_id) : static_data(true), delay_alloc(false) { - var = DAGEngine::Get()->NewVar(); + var = Engine::Get()->NewVariable(); shandle.ctx = Context(data.dev_mask_, dev_id); shandle.dptr = data.dptr_; shandle.size = data.shape_.Size() * sizeof(real_t); @@ -264,7 +269,7 @@ class NArray { /*! \brief construct a new chunk */ Chunk(uint64_t size, Context ctx, bool delay_alloc_) : static_data(false), delay_alloc(true) { - var = DAGEngine::Get()->NewVar(); + var = Engine::Get()->NewVariable(); shandle.size = size * sizeof(real_t); shandle.ctx = ctx; if (!delay_alloc_) this->CheckAndAlloc(); @@ -279,11 +284,11 @@ class NArray { /*! \brief destructor */ ~Chunk() { if (static_data) { - DAGEngine::Get()->PushDelete([](RunContext s) {}, shandle.ctx, var); + Engine::Get()->DeleteVariable([](RunContext s) {}, shandle.ctx, var); } else { CHECK(!delay_alloc) << "deleted before allocation"; Storage::Handle h = this->shandle; - DAGEngine::Get()->PushDelete([h](RunContext s) { + Engine::Get()->DeleteVariable([h](RunContext s) { Storage::Get()->Free(h); }, shandle.ctx, var); } diff --git a/make/config.mk b/make/config.mk index 3e93e240e493..3bc639ca1dba 100644 --- a/make/config.mk +++ b/make/config.mk @@ -55,6 +55,9 @@ PS_THIRD_PATH = NONE USE_RABIT_PS = 0 RABIT_PATH = rabit +# Whether to use threaded engine instead of naive one +# USE_THREADED_ENGINE =1 + # use openmp iterator USE_OPENMP_ITER = 1 # the additional link flags you want to add diff --git a/scripts/travis_script.sh b/scripts/travis_script.sh index 1046051464be..4b52c354df19 100755 --- a/scripts/travis_script.sh +++ b/scripts/travis_script.sh @@ -22,21 +22,29 @@ export CXX="g++-4.8" if [ ${TASK} == "build" ]; then echo "USE_CUDA=1" >> config.mk + echo "USE_THREADED_ENGINE=1" >> config.mk ./dmlc-core/scripts/setup_nvcc.sh $NVCC_PREFIX make all || exit -1 fi if [ ${TASK} == "python" ]; then echo "USE_CUDA=0" >> config.mk + echo "USE_THREADED_ENGINE=1" >> config.mk make all || exit -1 nosetests tests/python || exit -1 fi if [ ${TASK} == "python3" ]; then echo "USE_CUDA=0" >> config.mk + echo "USE_THREADED_ENGINE=1" >> config.mk make all || exit -1 nosetests3 tests/python || exit -1 fi +if [ ${TASK} == "python_naive" ]; then + echo "USE_CUDA=0" >> config.mk + make all || exit -1 + nosetests tests/python || exit -1 +fi # TODO(yutian): add unittest back diff --git a/src/c_api.cc b/src/c_api.cc index 8a1cde4da75e..546e858b3e5c 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -335,7 +335,7 @@ int MXNArrayListLoad(const char* fname, int MXNArrayWaitAll() { API_BEGIN(); - DAGEngine::Get()->WaitForAll(); + Engine::Get()->WaitForAll(); API_END(); } diff --git a/src/common/object_pool.h b/src/common/object_pool.h new file mode 100644 index 000000000000..052688ce601a --- /dev/null +++ b/src/common/object_pool.h @@ -0,0 +1,175 @@ +/*! + * Copyright (c) 2015 by Contributors + */ +#ifndef MXNET_COMMON_OBJECT_POOL_H_ +#define MXNET_COMMON_OBJECT_POOL_H_ +#include +#include +#include +#include +#include + +namespace common { + +/*! + * \brief Object pool for fast allocation and deallocation. + */ +template +class ObjectPool { + public: + /*! + * \brief Destructor. + */ + ~ObjectPool(); + /*! + * \brief Create new object. + * \return Pointer to the new object. + */ + template + T* New(Args&&... args); + /*! + * \brief Delete an existing object. + * \param ptr The pointer to delete. + * + * Make sure the pointer to delete is allocated from this pool. + */ + void Delete(T* ptr); + + /*! + * \brief Get singleton instance of pool. + * \return Object Pool. + */ + static ObjectPool* Get(); + + private: + /*! + * \brief Internal structure to hold pointers. + */ + struct LinkedList { + union { + LinkedList* next{nullptr}; + T t; + }; + }; + /*! + * \brief Page size of allocation. + * + * Currently defined to be 4KB. + */ + constexpr static std::size_t kPageSize = 1 << 12; + std::mutex m_; + /*! + * \brief Head of free list. + */ + LinkedList* head_{nullptr}; + /*! + * \brief Pages allocated. + */ + std::vector allocated_; + /*! + * \brief Private constructor. + */ + ObjectPool(); + /*! + * \brief Allocate a page of raw objects. + * + * This function is not protected and must be called with caution. + */ + void AllocateChunk(); + DISALLOW_COPY_AND_ASSIGN(ObjectPool); +}; // class ObjectPool + +/*! + * \brief Helper trait class for easy allocation and deallocation. + */ +template +struct ObjectPoolAllocatable { + /*! + * \brief Create new object. + * \return Pointer to the new object. + */ + template + static T* New(Args&&... args); + /*! + * \brief Delete an existing object. + * \param ptr The pointer to delete. + * + * Make sure the pointer to delete is allocated from this pool. + */ + static void Delete(T* ptr); +}; // struct ObjectPoolAllocatable + +template +ObjectPool::~ObjectPool() { + // TODO(hotpxl): mind destruction order + // for (auto i : allocated_) { + // free(i); + // } +} + +template +template +T* ObjectPool::New(Args&&... args) { + LinkedList* ret; + { + std::lock_guard lock{m_}; + if (head_->next == nullptr) { + AllocateChunk(); + } + ret = head_; + head_ = head_->next; + } + return new (static_cast(ret)) T(std::forward(args)...); +} + +template +void ObjectPool::Delete(T* ptr) { + ptr->~T(); + auto linked_list_ptr = reinterpret_cast(ptr); + { + std::lock_guard lock{m_}; + linked_list_ptr->next = head_; + head_ = linked_list_ptr; + } +} + +template +ObjectPool* ObjectPool::Get() { + static ObjectPool inst; + return &inst; +} + +template +ObjectPool::ObjectPool() { + AllocateChunk(); +} + +template +void ObjectPool::AllocateChunk() { + static_assert(sizeof(LinkedList) <= kPageSize, "Object too big."); + void* new_chunk_ptr; + int ret = posix_memalign(&new_chunk_ptr, kPageSize, kPageSize); + CHECK_EQ(ret, 0) << "Allocation failed"; + allocated_.emplace_back(new_chunk_ptr); + auto new_chunk = static_cast(new_chunk_ptr); + auto size = kPageSize / sizeof(LinkedList); + for (std::size_t i = 0; i < size - 1; ++i) { + new_chunk[i].next = &new_chunk[i + 1]; + } + new_chunk[size - 1].next = head_; + head_ = new_chunk; +} + +template +template +T* ObjectPoolAllocatable::New(Args&&... args) { + return ObjectPool::Get()->New(std::forward(args)...); +} + +template +void ObjectPoolAllocatable::Delete(T* ptr) { + ObjectPool::Get()->Delete(ptr); +} + +} // namespace common +#endif // MXNET_COMMON_OBJECT_POOL_H_ diff --git a/src/dag_engine/dag_egine.cc.bak b/src/dag_engine/dag_egine.cc.bak deleted file mode 100644 index c892e0eddc14..000000000000 --- a/src/dag_engine/dag_egine.cc.bak +++ /dev/null @@ -1,20 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - */ -#include "mxnet/dag_engine.h" -#include "simple_engine.h" -#include "dag_engine_impl.h" - -namespace mxnet { - -DAGEngine* DAGEngine::Get() { - /*! - * \brief Change specific engine to use. - */ - using EngineImplementation = engine::SimpleEngine; - - static EngineImplementation inst; - return &inst; -} - -} // namespace mxnet diff --git a/src/dag_engine/naive_engine.cc b/src/dag_engine/naive_engine.cc deleted file mode 100644 index bffeb474bfa6..000000000000 --- a/src/dag_engine/naive_engine.cc +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright by Contributors -#include - -namespace mxnet { -namespace engine { - -// The Naive engine interface -class NaiveEngine : public DAGEngine { - public: - NaiveEngine() { - #if MXNET_USE_CUDA - stream_ = mshadow::NewStream(true, false); - ctx_.stream = stream_; - #endif - } - - ~NaiveEngine() { - #if MXNET_USE_CUDA - mshadow::DeleteStream(stream_); - #endif - } - - Variable NewVar() override { - return nullptr; - } - - OprHandle NewOperator(AsyncFn fn, - std::vector const& use_vars, - std::vector const& mutate_vars) override { - LOG(FATAL) << "Not implemented"; - return nullptr; - } - - void DeleteOperator(OprHandle op) override { - LOG(FATAL) << "Not implemented"; - } - - void Push(OprHandle op, Context exec_ctx) override { - LOG(FATAL) << "Not implemented"; - } - - void Push(Fn exec_fun, Context exec_ctx, - std::vector const& use_vars, - std::vector const& mutate_vars) override { - if (exec_ctx.dev_mask == gpu::kDevMask) { -#if MXNET_USE_CUDA - mshadow::SetDevice(exec_ctx.dev_id); - ctx_.stream = stream_; - exec_fun(ctx_); - stream_->Wait(); -#else - LOG(FATAL) << "GPU is not enabled"; -#endif - } else { - exec_fun(ctx_); - } - } - - void PushAsync(AsyncFn exec_fun, Context exec_ctx, - std::vector const& use_vars, - std::vector const& mutate_vars) override { - LOG(FATAL) << "Not implemented"; - } - - void PushDelete(Fn delete_fun, Context exec_ctx, Variable var) override { - this->Push(delete_fun, exec_ctx, {}, {var}); - } - - void WaitToRead(Variable var) override { - } - - void WaitToWrite(Variable var) override { - } - - void WaitForAll() override { - } - - private: - RunContext ctx_; - #if MXNET_USE_CUDA - mshadow::Stream *stream_; - #endif -}; - -} // namespace engine - -DAGEngine* DAGEngine::Get() { - static mxnet::engine::NaiveEngine engine; - return &engine; -} -} // namespace mxnet diff --git a/src/dag_engine/object_pool.h b/src/dag_engine/object_pool.h deleted file mode 100644 index 1257cb540ccc..000000000000 --- a/src/dag_engine/object_pool.h +++ /dev/null @@ -1,108 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - */ -#ifndef MXNET_DAG_ENGINE_OBJECT_POOL_H_ -#define MXNET_DAG_ENGINE_OBJECT_POOL_H_ -#include -#include -#include "common.h" - -template -class SmallObjectPool { - public: - struct LinkedList { - union { - LinkedList* next{nullptr}; - T t; - }; - }; - ~SmallObjectPool() = default; - T* New(); - void Delete(T* ptr); - - static SmallObjectPool* Get(); - - private: - constexpr static std::size_t kPageSize = 1 << 12; - std::recursive_mutex m_; - LinkedList* head_{nullptr}; - SmallObjectPool(); - void AllocateChunk(); - - SmallObjectPool(SmallObjectPool const&) = delete; - SmallObjectPool(SmallObjectPool&&) = delete; - SmallObjectPool& operator=(SmallObjectPool const&) = delete; - SmallObjectPool& operator=(SmallObjectPool&&) = delete; -}; - -template -T* SmallObjectPool::New() { - LinkedList* ret; - { - std::lock_guard lock{m_}; - if (head_->next == nullptr) { - AllocateChunk(); - } - ret = head_; - head_ = head_->next; - } - return new(static_cast(ret)) T{}; -} - -template -void SmallObjectPool::Delete(T* ptr) { - ptr->~T(); - auto linked_list_ptr = reinterpret_cast(ptr); - { - std::lock_guard lock{m_}; - linked_list_ptr->next = head_; - head_ = linked_list_ptr; - } -} - -template -SmallObjectPool* SmallObjectPool::Get() { - static SmallObjectPool inst; - return &inst; -} - -template -SmallObjectPool::SmallObjectPool() { - AllocateChunk(); -} - -template -void SmallObjectPool::AllocateChunk() { - std::lock_guard lock{m_}; - static_assert(kPageSize % sizeof(LinkedList) == 0, - "Could not align to page size."); - auto&& new_chunk = static_cast(malloc(kPageSize)); - auto size = kPageSize / sizeof(LinkedList); - for (std::size_t i = 0 ; i < size - 1; ++i) { - new_chunk[i].next = &new_chunk[i + 1]; - } - new_chunk[size - 1].next = head_; - head_ = new_chunk; -} - - -struct A { - A() { - LOG("constructing"); - } - ~A() { - LOG("destructing"); - } -}; - -int main() { - auto&& pool = SmallObjectPool::Get(); - auto a = pool->New(); - auto b = pool->New(); - LOG("addresses %p %p", a, b); - pool->Delete(a); - a = pool->New(); - LOG("address again %p", a); - return 0; -} -#endif // MXNET_DAG_ENGINE_OBJECT_POOL_H_ diff --git a/src/dag_engine/simple_engine.cc.bak b/src/dag_engine/simple_engine.cc.bak deleted file mode 100644 index defc7404519e..000000000000 --- a/src/dag_engine/simple_engine.cc.bak +++ /dev/null @@ -1,287 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - */ -#include "simple_engine.h" -#include -#include -#include -#include -#include -#include -#include "../common/cuda_utils.h" - -namespace mxnet { - -namespace engine { - -#ifdef DAG_ENGINE_DEBUG -std::atomic OprBlock::counter{0}; -std::atomic VersionedVarBlock::counter{0}; -std::atomic SimpleVar::counter{0}; -std::atomic SimpleOpr::counter{0}; -#endif // DAG_ENGINE_DEBUG - -SimpleVar* SimpleVar::CastFromBase(Var* v) { return v->Cast(); } - -SimpleOpr* SimpleOpr::CastFromBase(Opr* o) { return o->Cast(); } - -SimpleEngine::SimpleEngine() - : pending_{0}, thread_pool_{[this]() { ThreadWorker(); }} {} - -SimpleEngine::~SimpleEngine() noexcept(false) { task_queue_.SignalForKill(); } - -SimpleVar* SimpleEngine::NewVar() { - auto ret = new SimpleVar{}; - ret->head = new VersionedVarBlock{}; - return ret; -} - -SimpleOpr* SimpleEngine::NewOperator(SimpleEngine::AsyncFn fn, - std::vector const& use_vars, - std::vector const& mutate_vars) { - auto ret = new SimpleOpr{}; - ret->fn = fn; - ret->use_vars.resize(use_vars.size()); - ret->mutate_vars.resize(mutate_vars.size()); - std::transform(use_vars.begin(), use_vars.end(), ret->use_vars.begin(), - SimpleVar::CastFromBase); - std::transform(mutate_vars.begin(), mutate_vars.end(), - ret->mutate_vars.begin(), SimpleVar::CastFromBase); -#ifdef DAG_ENGINE_DEBUG - // Check for duplicates. - auto use = use_vars; - auto mutate = mutate_vars; - auto use_size = use.size(); - auto mutate_size = mutate.size(); - std::sort(use.begin(), use.end()); - std::sort(mutate.begin(), mutate.end()); - for (std::size_t i = 0; i < use_size; ++i) { - if (i != 0 && use.at(i) == use.at(i - 1)) { - LOG(FATAL) << "duplicate items found in `use_vars`"; - } - } - for (std::size_t i = 0; i < mutate_size; ++i) { - if (i != 0 && mutate.at(i) == mutate.at(i - 1)) { - LOG(FATAL) << "duplicate items found in `mutate_vars`"; - } - } - std::size_t j = 0; - for (std::size_t i = 0; i < use_size; ++i) { - while (j < mutate_size && mutate.at(j) < use.at(i)) { - ++j; - } - if (j == mutate_size) { - break; - } - if (mutate.at(j) == use.at(i)) { - LOG(FATAL) - << "duplicate items found between `use_vars` and `mutate_vars`"; - } - } -#endif // DAG_ENGINE_DEBUG - return ret; -} - -void SimpleEngine::DeleteOperator(OprHandle op) { - auto&& simple_opr = SimpleOpr::CastFromBase(op); - std::vector deps{}; - deps.reserve(simple_opr->use_vars.size() + simple_opr->mutate_vars.size()); - deps.insert(deps.end(), simple_opr->use_vars.begin(), - simple_opr->use_vars.end()); - deps.insert(deps.end(), simple_opr->mutate_vars.begin(), - simple_opr->mutate_vars.end()); - auto&& func = [simple_opr](RunContext) { delete simple_opr; }; - Push(func, Context{}, {}, deps); -} - -void SimpleEngine::Push(Fn exec_fun, Context exec_ctx, - std::vector const& use_vars, - std::vector const& mutate_vars) { - auto f = [exec_fun](RunContext ctx, Callback on_complete) { - exec_fun(ctx); - on_complete(); - }; - PushAsync(f, exec_ctx, use_vars, mutate_vars); -} - -void SimpleEngine::Push(OprHandle op, Context exec_ctx) { - auto&& simple_opr = SimpleOpr::CastFromBase(op); - auto&& opr_block = new OprBlock{}; - opr_block->opr = simple_opr; - opr_block->wait.store(simple_opr->use_vars.size() + - simple_opr->mutate_vars.size() + 1); - opr_block->ctx = exec_ctx; - opr_block->rctx = RunContext{nullptr}; - ++pending_; - // Add read dependencies. - for (auto&& i : simple_opr->use_vars) { - std::lock_guard lock{i->m}; - if (i->ready_to_read) { - assert(i->pending_write == nullptr); - ++i->num_pending_reads; - --opr_block->wait; - } else { - auto&& new_var_block = new VersionedVarBlock{}; - assert(i->head->next == nullptr); - assert(i->head->trigger == nullptr); - assert(i->head->write == false); - i->head->next = new_var_block; - i->head->trigger = opr_block; - i->head = new_var_block; - } - } - // Add write dependencies. - for (auto&& i : simple_opr->mutate_vars) { - std::lock_guard lock{i->m}; - auto&& new_var_block = new VersionedVarBlock{}; - i->head->next = new_var_block; - i->head->trigger = opr_block; - i->head->write = true; - if (i->ready_to_read) { - /*! - * Raise `num_pending_reads` temporarily to avoid premature triggering. - */ - ++i->num_pending_reads; - i->pending_write = i->head; - if (--i->num_pending_reads == 0) { - --opr_block->wait; - } - i->ready_to_read = false; - } - i->head = new_var_block; - } - if (--opr_block->wait == 0) { - task_queue_.Push(opr_block); - } -} - -void SimpleEngine::PushAsync(AsyncFn fn, Context exec_ctx, - std::vector const& use_vars, - std::vector const& mutate_vars) { - auto&& opr = NewOperator(fn, use_vars, mutate_vars); - opr->temporary = true; - Push(opr, exec_ctx); -} - -void SimpleEngine::PushDelete(Fn delete_fn, Context exec_ctx, Variable var) { - auto&& simple_var = SimpleVar::CastFromBase(var); - auto&& func = [delete_fn, simple_var](RunContext ctx) { - /*! - * Mark variable as orphan, so during `SimpleEngine::OnComplete` it could be - * recycled. - */ - simple_var->to_delete = true; - delete_fn(ctx); - }; - Push(func, exec_ctx, {}, {var}); -} - -void SimpleEngine::WaitForVar(Variable var) { - std::unique_lock lock{finished_m_}; - std::atomic done{false}; - auto&& callback = [this, &done](RunContext) { - std::unique_lock lock{finished_m_}; - done.store(true); - finished_cv_.notify_all(); - }; - Push(callback, Context{}, {var}, {}); - finished_cv_.wait(lock, [&done]() { return done.load(); }); -} - -void SimpleEngine::WaitForAll() { - std::unique_lock lock{finished_m_}; - finished_cv_.wait(lock, [this]() { return pending_.load() == 0; }); -} - -void SimpleEngine::OnComplete(SimpleOpr* simple_opr) { - /*! - * Mark complete for read variables. - */ - for (auto&& i : simple_opr->use_vars) { - std::lock_guard lock{i->m}; - if (--i->num_pending_reads == 0) { - if (i->pending_write != nullptr && - --i->pending_write->trigger->wait == 0) { - task_queue_.Push(i->pending_write->trigger); - } - } - } - /*! - * Mark complete for write variables. - */ - for (auto&& i : simple_opr->mutate_vars) { - bool to_delete = false; - { - std::lock_guard lock{i->m}; - assert(i->ready_to_read == false); - auto head = i->pending_write->next; - delete i->pending_write; - i->pending_write = nullptr; - if (i->to_delete) { - assert(head->next == nullptr); - delete head; - to_delete = true; - } else { - while (true) { - if (head->write == true) { - ++i->num_pending_reads; - i->pending_write = head; - if (--i->num_pending_reads == 0) { - if (--head->trigger->wait == 0) { - task_queue_.Push(head->trigger); - } - } - break; - } else if (head->next == nullptr) { - i->ready_to_read = true; - break; - } else { - ++i->num_pending_reads; - if (--head->trigger->wait == 0) { - task_queue_.Push(head->trigger); - } - auto prev = head; - head = head->next; - delete prev; - } - } - } - } - if (to_delete) { - delete i; - } - } - { - std::unique_lock lock{finished_m_}; - if (--pending_ == 0) { - finished_cv_.notify_all(); - } - } -} - -void SimpleEngine::ThreadWorker() { - OprBlock* opr_block; - while (task_queue_.Pop(&opr_block)) { - assert(opr_block->wait.load() == 0); - auto simple_opr = opr_block->opr; - auto callback = [this, simple_opr]() { - OnComplete(simple_opr); - if (simple_opr->temporary) { - delete simple_opr; - } - }; - if (opr_block->ctx.dev_mask == gpu::kDevMask) { -#if MXNET_USE_CUDA - CUDA_CALL(cudaSetDevice(opr_block->ctx.dev_id)); -#else // MXNET_USE_CUDA - LOG(FATAL) << "Please compile with CUDA enabled"; -#endif // MXNET_USE_CUDA - } - simple_opr->fn(opr_block->rctx, callback); - delete opr_block; - } -} - -} // namespace engine - -} // namespace mxnet diff --git a/src/dag_engine/simple_engine.h b/src/dag_engine/simple_engine.h deleted file mode 100644 index 2a29fa5a9832..000000000000 --- a/src/dag_engine/simple_engine.h +++ /dev/null @@ -1,170 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - */ -#ifndef MXNET_DAG_ENGINE_SIMPLE_ENGINE_H_ -#define MXNET_DAG_ENGINE_SIMPLE_ENGINE_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include "dag_engine_impl.h" -#include "thread_pool.h" - -namespace mxnet { - -namespace engine { - -/*! - * \brief Forward declarations. - */ -struct SimpleOpr; - -/*! - * \brief Operation in the queue. - */ -struct OprBlock { -#ifdef DAG_ENGINE_DEBUG - static std::atomic counter; - OprBlock() { LOG(INFO) << __func__ << " " << ++counter; } - ~OprBlock() { LOG(INFO) << __func__ << " " << --counter; } -#endif // DAG_ENGINE_DEBUG - std::atomic wait{0}; - SimpleOpr* opr{nullptr}; - Context ctx; - RunContext rctx; -}; // struct OprBlock - -/*! - * \brief Variable with version information. - */ -struct VersionedVarBlock { -#ifdef DAG_ENGINE_DEBUG - static std::atomic counter; - VersionedVarBlock() { LOG(INFO) << __func__ << " " << ++counter; } - ~VersionedVarBlock() { LOG(INFO) << __func__ << " " << --counter; } -#endif // DAG_ENGINE_DEBUG - VersionedVarBlock* next{nullptr}; - OprBlock* trigger{nullptr}; - bool write{false}; -}; // struct VersionedVarBlock - -/*! - * \brief Variable implementation. - */ -struct SimpleVar final : public Var { -#ifdef DAG_ENGINE_DEBUG - static std::atomic counter; - SimpleVar() { LOG(INFO) << __func__ << " " << ++counter; } - ~SimpleVar() { LOG(INFO) << __func__ << " " << --counter; } -#endif // DAG_ENGINE_DEBUG - std::mutex m; - std::size_t num_pending_reads{0}; - VersionedVarBlock* head{nullptr}; - VersionedVarBlock* pending_write{nullptr}; - /*! - * If true, then there are no current or future processing of the chain. - */ - bool ready_to_read{true}; - /*! - * If true, delete after operation completes. - */ - bool to_delete{false}; - - static SimpleVar* CastFromBase(Var* ptr); -}; // struct SimpleVar - -/*! - * \brief Operator implementation. - */ -struct SimpleOpr final : public Opr { -#ifdef DAG_ENGINE_DEBUG - static std::atomic counter; - SimpleOpr() { LOG(INFO) << __func__ << " " << ++counter; } - ~SimpleOpr() { LOG(INFO) << __func__ << " " << --counter; } -#endif // DAG_ENGINE_DEBUG - DAGEngine::AsyncFn fn; - std::vector use_vars; - std::vector mutate_vars; - bool temporary{false}; - - static SimpleOpr* CastFromBase(Opr* ptr); -}; // struct SimpleOpr - -/*! - * \brief Engine implementation. - */ -class SimpleEngine final : public DAGEngine { - public: - /*! - * \brief Constructor and destructor. - */ - SimpleEngine(); - ~SimpleEngine() noexcept(false); - /*! - * \brief Overriding methods. - */ - SimpleVar* NewVar() override; - SimpleOpr* NewOperator(AsyncFn fn, std::vector const& use_vars, - std::vector const& mutate_vars) override; - void DeleteOperator(OprHandle op) override; - void Push(OprHandle op, Context exec_ctx) override; - void Push(Fn exec_fun, Context exec_ctx, - std::vector const& use_vars, - std::vector const& mutate_vars) override; - void PushAsync(AsyncFn exec_fun, Context exec_ctx, - std::vector const& use_vars, - std::vector const& mutate_vars) override; - void PushDelete(Fn delete_fn, Context exec_ctx, Variable var) override; - void WaitForVar(Variable var) override; - void WaitForAll() override; - /*! - * \brief Callback on operation completion. - * - * On operation completion, this will trigger subsequent operations. - */ - void OnComplete(SimpleOpr* simple_opr); - /*! - * \brief Worker. - * - * The method to pass to thread pool to parallelize. - */ - void ThreadWorker(); - - private: - /*! - * \brief Concurrency for thread pool. - */ - static constexpr std::size_t kNumWorkingThreads = 16; - /*! - * \brief Number of pending operations. - */ - std::atomic pending_; - /*! - * \brief Notify waits for single or all variables. - */ - std::mutex finished_m_; - std::condition_variable finished_cv_; - /*! - * \brief Task queue. - */ - dmlc::ConcurrentBlockingQueue task_queue_; - /*! - * \brief Thread pool. - */ - ThreadPool thread_pool_; - /*! - * \brief Disallow copy construction and assignment. - */ - DISALLOW_COPY_AND_ASSIGN(SimpleEngine); -}; // class SimpleEngine - -} // namespace engine - -} // namespace mxnet - -#endif // MXNET_DAG_ENGINE_SIMPLE_ENGINE_H_ diff --git a/src/dag_engine/threaded_engine.cc.bak b/src/dag_engine/threaded_engine.cc.bak deleted file mode 100644 index e5b44d5d1db2..000000000000 --- a/src/dag_engine/threaded_engine.cc.bak +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright (c) 2015 by Contributors -#include -#include -#include -#include -#include -#include -#include - -#include "dmlc/logging.h" -#include "mxnet/dag_engine.h" -#include "../common/spin_lock.h" -#include "../common/concurrent_blocking_queue.h" - -using namespace std; - -namespace mxnet { - -#define DEFAULT_NUM_WORKER_THREADS 4 - -class ThreadedEngine : public DAGEngine { - public: - explicit ThreadedEngine(int numthreads = DEFAULT_NUM_WORKER_THREADS): numthreads_(numthreads) { - for (int i = 0; i < numthreads; ++i) { - worker_queues_.push_back(new ConcurrentBlockingQueue()); - workers_.emplace_back(&ThreadedEngine::WorkerRoutine, this, i); - } - } - ~ThreadedEngine() { - for (int i = 0; i < numthreads_; ++i) { - worker_queues_[i]->SignalForKill(); - delete worker_queues_[i]; - workers_[i].join(); - } - } - void Push(AsyncOp exec_fun, - Context exec_ctx, - const vector &use_vars, - const vector &mutate_vars) override { - shared_ptr opd(new OpDescr{exec_fun, exec_ctx, use_vars, mutate_vars}, - [this] (OpDescr* o) { this->OnDepsResolved(o); }); - for ( Variable v : use_vars ) { // read - VarDescr* vard = static_cast(v); // safe to cast here - spin_lock(&vard->lock); - if (vard->rw < 0) { - vard->waitings.push(make_pair(opd, DepType::kRead)); - } else { - ++vard->rw; - } - spin_unlock(&vard->lock); - } - for ( Variable v : mutate_vars ) { // write - VarDescr* vard = static_cast(v); // safe to cast here - spin_lock(&vard->lock); - if (vard->rw != 0) { - vard->waitings.push(make_pair(opd, DepType::kWrite)); - } else { - vard->rw = -1; - } - spin_unlock(&vard->lock); - } - } - void Push(Op exec_fun, - Context exec_ctx, - const vector &use_vars, - const vector &mutate_vars) override { - this->Push([exec_fun](RunContext ctx, Callback on_complete) { - exec_fun(ctx); on_complete(); - }, exec_ctx, use_vars, mutate_vars); - } - void PushDelete(Op delete_fun, Context exec_ctx, Variable var) override { - this->Push([delete_fun, var] (RunContext ctx) { - delete_fun(ctx); - delete static_cast(var); // TODO(minjie): use variable pool instead - }, exec_ctx, {}, {var}); - } - Variable NewVar() override { - // in practice return a ptr to a cell - // that have the info about the variable - // use ptr directly instead of ID because this avoids an indirect mapping - // TODO(minjie): use variable pool instead - VarDescr* vd = new VarDescr; - vd->lock = SPINLOCK_INITIALIZER; - vd->rw = 0; - return vd; - } - void WaitForVar(Variable var) override { - // TODO(minjie): tbd - } - void WaitForAll() override { - // TODO(minjie): tbd - } - - private: - enum class DepType { - kRead = 0, - kWrite, - kDelete, - }; - struct OpDescr { - AsyncOp op; - Context exec_ctx; - vector read_vars; - vector write_vars; - }; - struct VarDescr { - spinlock lock; - int rw; // a semaphore-like count - // if rw > 0, the variable has several readers and the number - // means how many operators are currently reading it; - // if rw < 0, the varaible has one writer (should be -1) - queue, DepType>> waitings; - }; - void TriggerWaiting(VarDescr* vard) { - // ATTENTION: this function should be called with vard->lock held. - CHECK(vard->rw == 0) << "the variable should be free during triggering"; - if (!vard->waitings.empty()) { - // pop all reads first - while (vard->waitings.front().second == DepType::kRead) { - vard->waitings.pop(); - ++vard->rw; - } - if (vard->rw == 0) { - // pop the next write - vard->waitings.pop(); - vard->rw = -1; - } - } - } - void OnOpFinished(OpDescr* opd) { - CHECK(opd) << "completing a nullptr op!"; - for (Variable v : opd->read_vars) { - VarDescr* vard = static_cast(v); // safe to cast here - spin_lock(&vard->lock); - CHECK(vard->rw > 0) << "incorrect rw count (reader):" << vard->rw; - if (--vard->rw == 0) { - TriggerWaiting(vard); - } - spin_unlock(&vard->lock); - } - for (Variable v : opd->write_vars) { - VarDescr* vard = static_cast(v); // safe to cast here - spin_lock(&vard->lock); - CHECK(vard->rw == -1) << "incorrect rw count (writer):" << vard->rw; - vard->rw = 0; - TriggerWaiting(vard); - spin_unlock(&vard->lock); - } - delete opd; // delete the operator - } - RunContext GetRunContext(const Context& ctx) { - // TODO(minjie): get the correct runtime context - return RunContext(); - } - void OnDepsResolved(OpDescr* opd) { - static default_random_engine generator; - static uniform_int_distribution distribution(0, numthreads_ - 1); - int thrid = distribution(generator); - // LOG(INFO) << "schedule operator " << opd << " to thread #" << thrid; - worker_queues_[thrid]->Push(opd); - } - void WorkerRoutine(int thrid) { - OpDescr* opd = nullptr; - while (!worker_queues_[thrid]->Pop(opd)) { - // LOG(INFO) << "worker thread #" << thrid << " got operator " << opd; - opd->op(GetRunContext(opd->exec_ctx), [this, opd] () { this->OnOpFinished(opd); }); - opd = nullptr; - } - } - - private: - const int numthreads_; - vector*> worker_queues_; - vector workers_; -}; - -// implements the singleton factory -DAGEngine* DAGEngine::Get() { - static ThreadedEngine engine; - return &engine; -} -} // namespace mxnet diff --git a/src/engine/engine.cc b/src/engine/engine.cc new file mode 100644 index 000000000000..4047099d14cc --- /dev/null +++ b/src/engine/engine.cc @@ -0,0 +1,28 @@ +/*! + * Copyright (c) 2015 by Contributors + */ +#include "mxnet/engine.h" +#include "engine_impl.h" +#include "naive_engine.h" +#include "threaded_engine.h" + +namespace mxnet { + +Engine::~Engine() noexcept(false) {} + +Engine* Engine::Get() { + /*! + * \brief Change specific engine to use. + */ +#ifdef MXNET_USE_THREADED_ENGINE + using EngineImplementation = engine::ThreadedEngine; +#else // MXNET_USE_THREADED_ENGINE +#warning "Using naive engine."; + using EngineImplementation = engine::NaiveEngine; +#endif // MXNET_USE_THREADED_ENGINE + + static EngineImplementation inst; + return &inst; +} + +} // namespace mxnet diff --git a/src/dag_engine/dag_engine_impl.h b/src/engine/engine_impl.h similarity index 62% rename from src/dag_engine/dag_engine_impl.h rename to src/engine/engine_impl.h index aa090074710b..c1ebe2a042f1 100644 --- a/src/dag_engine/dag_engine_impl.h +++ b/src/engine/engine_impl.h @@ -1,29 +1,29 @@ /*! * Copyright (c) 2015 by Contributors */ -#ifndef MXNET_DAG_ENGINE_DAG_ENGINE_IMPL_H_ -#define MXNET_DAG_ENGINE_DAG_ENGINE_IMPL_H_ +#ifndef MXNET_ENGINE_ENGINE_IMPL_H_ +#define MXNET_ENGINE_ENGINE_IMPL_H_ #include -#include "mxnet/dag_engine.h" +#include "mxnet/engine.h" -// #define DAG_ENGINE_DEBUG +#define ENGINE_DEBUG 0 namespace mxnet { namespace engine { struct Var { -#ifdef DAG_ENGINE_DEBUG +#if ENGINE_DEBUG virtual ~Var() = default; -#endif // DAG_ENGINE_DEBUG +#endif // ENGINE_DEBUG template T* Cast(); }; // struct Var struct Opr { -#ifdef DAG_ENGINE_DEBUG +#if ENGINE_DEBUG virtual ~Opr() = default; -#endif // DAG_ENGINE_DEBUG +#endif // ENGINE_DEBUG template T* Cast(); }; // struct Opr @@ -32,25 +32,25 @@ template T* Var::Cast() { static_assert(std::is_base_of::value, "must inherit `mxnet::engine::Var`"); -#ifndef DAG_ENGINE_DEBUG - return static_cast(this); -#else // DAG_ENGINE_DEBUG +#if ENGINE_DEBUG return dynamic_cast(this); -#endif // DAG_ENGINE_DEBUG +#else // ENGINE_DEBUG + return static_cast(this); +#endif // ENGINE_DEBUG } template T* Opr::Cast() { static_assert(std::is_base_of::value, "must inherit `mxnet::engine::Opr`"); -#ifndef DAG_ENGINE_DEBUG - return static_cast(this); -#else // DAG_ENGINE_DEBUG +#if ENGINE_DEBUG return dynamic_cast(this); -#endif // DAG_ENGINE_DEBUG +#else // ENGINE_DEBUG + return static_cast(this); +#endif // ENGINE_DEBUG } } // namespace engine } // namespace mxnet -#endif // MXNET_DAG_ENGINE_DAG_ENGINE_IMPL_H_ +#endif // MXNET_ENGINE_ENGINE_IMPL_H_ diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc new file mode 100644 index 000000000000..ff3b0467b08e --- /dev/null +++ b/src/engine/naive_engine.cc @@ -0,0 +1,70 @@ +/*! + * Copyright (c) 2015 by Contributors + */ +#include "naive_engine.h" +#include + +namespace mxnet { +namespace engine { + +NaiveEngine::VarHandle NaiveEngine::NewVariable() { return nullptr; } + +NaiveEngine::NaiveEngine() { +#if MXNET_USE_CUDA + stream_ = mshadow::NewStream(true, false); + ctx_.stream = stream_; +#endif +} + +NaiveEngine::~NaiveEngine() { +#if MXNET_USE_CUDA + mshadow::DeleteStream(stream_); +#endif +} + +NaiveEngine::OprHandle NaiveEngine::NewOperator(AsyncFn, + std::vector const&, + std::vector const&, + FnProperty) { + LOG(FATAL) << "Not implemented"; + return nullptr; +} + +void NaiveEngine::DeleteOperator(OprHandle) { LOG(FATAL) << "Not implemented"; } + +void NaiveEngine::Push(OprHandle, Context) { LOG(FATAL) << "Not implemented"; } + +void NaiveEngine::Push(Fn exec_fun, Context exec_ctx, + std::vector const&, + std::vector const&, FnProperty) { + if (exec_ctx.dev_mask == gpu::kDevMask) { +#if MXNET_USE_CUDA + mshadow::SetDevice(exec_ctx.dev_id); + ctx_.stream = stream_; + exec_fun(ctx_); + stream_->Wait(); +#else + LOG(FATAL) << "GPU is not enabled"; +#endif + } else { + exec_fun(ctx_); + } +} + +void NaiveEngine::PushAsync(AsyncFn, Context, std::vector const&, + std::vector const&, FnProperty) { + LOG(FATAL) << "Not implemented"; +} + +void NaiveEngine::DeleteVariable(Fn delete_fun, Context exec_ctx, + VarHandle var) { + this->Push(delete_fun, exec_ctx, {}, {var}, FnProperty::kNormal); +} + +void NaiveEngine::WaitForVar(VarHandle) {} + +void NaiveEngine::WaitForAll() {} + +} // namespace engine + +} // namespace mxnet diff --git a/src/engine/naive_engine.h b/src/engine/naive_engine.h new file mode 100644 index 000000000000..bbcbc9d2c215 --- /dev/null +++ b/src/engine/naive_engine.h @@ -0,0 +1,47 @@ +/*! + * Copyright (c) 2015 by Contributors + */ +#ifndef MXNET_ENGINE_NAIVE_ENGINE_H_ +#define MXNET_ENGINE_NAIVE_ENGINE_H_ + +#include +#include "engine_impl.h" + +namespace mxnet { + +namespace engine { + +class NaiveEngine final : public Engine { + public: + NaiveEngine(); + ~NaiveEngine(); + VarHandle NewVariable() override; + OprHandle NewOperator(AsyncFn fn, std::vector const& const_vars, + std::vector const& mutable_vars, + FnProperty prop) override; + void DeleteOperator(OprHandle op) override; + void Push(OprHandle op, Context exec_ctx) override; + void Push(Fn exec_fun, Context exec_ctx, + std::vector const& const_vars, + std::vector const& mutable_vars, + FnProperty prop) override; + void PushAsync(AsyncFn exec_fun, Context exec_ctx, + std::vector const& const_vars, + std::vector const& mutable_vars, + FnProperty prop) override; + void DeleteVariable(Fn delete_fun, Context exec_ctx, VarHandle var) override; + void WaitForVar(VarHandle var) override; + void WaitForAll() override; + + private: + RunContext ctx_; +#if MXNET_USE_CUDA + mshadow::Stream* stream_; +#endif // MXNET_USE_CUDA +}; // class NaiveEngine + +} // namespace engine + +} // namespace mxnet + +#endif // MXNET_ENGINE_NAIVE_ENGINE_H_ diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h new file mode 100644 index 000000000000..75bca74935bc --- /dev/null +++ b/src/engine/stream_manager.h @@ -0,0 +1,128 @@ +/*! + * Copyright (c) 2015 by Contributors + */ +#ifndef MXNET_ENGINE_STREAM_MANAGER_H_ +#define MXNET_ENGINE_STREAM_MANAGER_H_ + +#include +#include +#include +#include +#include "mxnet/base.h" +#include "mxnet/context.h" +#include "../common/cuda_utils.h" + +namespace mxnet { + +namespace engine { + +/*! + * \brief Stream manager. + * + * Uses a basic round-robin algorithm to dispatch GPU streams. Returns default + * context on CPU. + */ +template +class StreamManager { + public: + StreamManager(); + ~StreamManager(); + RunContext GetRunContext(Context const& ctx); + RunContext GetIORunContext(Context const& ctx); + + private: + std::mutex m_; +#if MXNET_USE_CUDA + std::array*, kStreams>, kNumGpus> + gpu_streams_; + std::array*, kNumGpus> gpu_io_streams_; + std::array gpu_cnt_; +#endif // MXNET_USE_CUDA + DISALLOW_COPY_AND_ASSIGN(StreamManager); +}; // class StreamManager + +template +RunContext StreamManager::GetRunContext( + Context const& ctx) { + switch (ctx.dev_mask) { + case cpu::kDevMask: + return {nullptr}; + case gpu::kDevMask: { +#if MXNET_USE_CUDA + std::size_t use_counter; + CUDA_CALL(cudaSetDevice(ctx.dev_id)); + { + std::lock_guard lock{m_}; + auto&& counter = gpu_cnt_.at(ctx.dev_id); + if (counter == -1) { + for (auto&& i : gpu_streams_.at(ctx.dev_id)) { + i = mshadow::NewStream(true, false); + } + counter = 0; + } + use_counter = counter; + counter = (counter + 1) % kStreams; + } + return {gpu_streams_.at(ctx.dev_id).at(use_counter)}; +#else // MXNET_USE_CUDA + LOG(FATAL) << "Please compile with CUDA enabled"; +#endif // MXNET_USE_CUDA + } + } + return {nullptr}; +} + +template +RunContext StreamManager::GetIORunContext( + Context const& ctx) { + switch (ctx.dev_mask) { + case cpu::kDevMask: + return {nullptr}; + case gpu::kDevMask: { +#if MXNET_USE_CUDA + CUDA_CALL(cudaSetDevice(ctx.dev_id)); + { + std::lock_guard lock{m_}; + if (gpu_io_streams_.at(ctx.dev_id) == nullptr) { + gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream(true, false); + } + } + return {gpu_io_streams_.at(ctx.dev_id)}; +#else // MXNET_USE_CUDA + LOG(FATAL) << "Please compile with CUDA enabled"; +#endif // MXNET_USE_CUDA + } + } + return {nullptr}; +} + +template +StreamManager::StreamManager() { +#if MXNET_USE_CUDA + for (std::size_t i = 0; i < kNumGpus; ++i) { + gpu_cnt_.at(i) = -1; + } + for (auto&& i : gpu_io_streams_) { + i = nullptr; + } +#endif // MXNET_USE_CUDA +} + +template +StreamManager::~StreamManager() { +#if MXNET_USE_CUDA + for (std::size_t i = 0; i < kNumGpus; ++i) { + if (gpu_cnt_.at(i) != -1) { + for (auto&& j : gpu_streams_.at(i)) { + mshadow::DeleteStream(j); + } + } + } +#endif // MXNET_USE_CUDA +} + +} // namespace engine + +} // namespace mxnet + +#endif // MXNET_ENGINE_STREAM_MANAGER_H_ diff --git a/src/dag_engine/thread_pool.h b/src/engine/thread_pool.h similarity index 91% rename from src/dag_engine/thread_pool.h rename to src/engine/thread_pool.h index 4d5b67cc56a3..292b6c433d45 100644 --- a/src/dag_engine/thread_pool.h +++ b/src/engine/thread_pool.h @@ -1,8 +1,8 @@ /*! * Copyright (c) 2015 by Contributors */ -#ifndef MXNET_DAG_ENGINE_THREAD_POOL_H_ -#define MXNET_DAG_ENGINE_THREAD_POOL_H_ +#ifndef MXNET_ENGINE_THREAD_POOL_H_ +#define MXNET_ENGINE_THREAD_POOL_H_ #include #include @@ -64,4 +64,4 @@ ThreadPool::~ThreadPool() noexcept(false) { } // namespace engine } // namespace mxnet -#endif // MXNET_DAG_ENGINE_THREAD_POOL_H_ +#endif // MXNET_ENGINE_THREAD_POOL_H_ diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc new file mode 100644 index 000000000000..cd9758835346 --- /dev/null +++ b/src/engine/threaded_engine.cc @@ -0,0 +1,361 @@ +/*! + * Copyright (c) 2015 by Contributors + */ +#include "threaded_engine.h" +#include +#include +#include +#include +#include +#include +#include "../common/cuda_utils.h" + +namespace mxnet { + +namespace engine { + +#if ENGINE_DEBUG +std::atomic OprBlock::counter{0}; +std::atomic VersionedVarBlock::counter{0}; +std::atomic ThreadedVar::counter{0}; +std::atomic ThreadedOpr::counter{0}; +#endif // ENGINE_DEBUG + +ThreadedVar::ThreadedVar(VersionedVarBlock* head) : head_{head} { +#if ENGINE_DEBUG + LOG(INFO) << __func__ << " " << ++counter; +#endif // ENGINE_DEBUG +} + +void ThreadedVar::AppendReadDependency(OprBlock* opr_block) { + std::lock_guard lock{m_}; + if (ready_to_read_) { + assert(pending_write_ == nullptr); + ++num_pending_reads_; + --opr_block->wait; + } else { + auto&& new_var_block = VersionedVarBlock::New(); + assert(head_->next == nullptr); + assert(head_->trigger == nullptr); + assert(head_->write == false); + head_->next = new_var_block; + head_->trigger = opr_block; + head_ = new_var_block; + } +} + +void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) { + std::lock_guard lock{m_}; + auto&& new_var_block = VersionedVarBlock::New(); + head_->next = new_var_block; + head_->trigger = opr_block; + head_->write = true; + if (ready_to_read_) { + /*! + * Raise `num_pending_reads_` temporarily to avoid premature triggering. + */ + ++num_pending_reads_; + pending_write_ = head_; + if (--num_pending_reads_ == 0) { + --opr_block->wait; + } + ready_to_read_ = false; + } + head_ = new_var_block; +} + +template +void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) { + std::lock_guard lock{m_}; + if (--num_pending_reads_ == 0) { + if (pending_write_ != nullptr && --pending_write_->trigger->wait == 0) { + dispatcher(pending_write_->trigger); + } + } +} + +template +bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { + std::lock_guard lock{m_}; + assert(ready_to_read_ == false); + auto cur_head = pending_write_->next; + VersionedVarBlock::Delete(pending_write_); + pending_write_ = nullptr; + if (to_delete_) { + assert(cur_head->next == nullptr); + VersionedVarBlock::Delete(cur_head); + return true; + } else { + while (true) { + if (cur_head->write == true) { + ++num_pending_reads_; + pending_write_ = cur_head; + if (--num_pending_reads_ == 0) { + if (--cur_head->trigger->wait == 0) { + dispatcher(cur_head->trigger); + } + } + break; + } else if (cur_head->next == nullptr) { + ready_to_read_ = true; + break; + } else { + ++num_pending_reads_; + if (--cur_head->trigger->wait == 0) { + dispatcher(cur_head->trigger); + } + auto prev = cur_head; + cur_head = cur_head->next; + VersionedVarBlock::Delete(prev); + } + } + return false; + } +} + +void ThreadedVar::SetToDelete() { + std::lock_guard lock{m_}; + to_delete_ = true; +} + +bool ThreadedVar::ready_to_read() { + std::lock_guard lock{m_}; + return ready_to_read_; +} + +ThreadedVar* ThreadedVar::CastFromBase(Var* v) { + return v->Cast(); +} + +ThreadedOpr* ThreadedOpr::CastFromBase(Opr* o) { + return o->Cast(); +} + +ThreadedEngine::ThreadedEngine() + : pending_{0}, + thread_pool_{[this]() { ThreadWorker(&task_queue_); }}, + io_thread_pool_{[this]() { ThreadWorker(&io_task_queue_); }} {} + +ThreadedEngine::~ThreadedEngine() noexcept(false) { + task_queue_.SignalForKill(); + io_task_queue_.SignalForKill(); +} + +ThreadedVar* ThreadedEngine::NewVariable() { + auto ret = ThreadedVar::New(VersionedVarBlock::New()); + return ret; +} + +ThreadedOpr* ThreadedEngine::NewOperator( + ThreadedEngine::AsyncFn fn, std::vector const& const_vars, + std::vector const& mutable_vars, FnProperty prop) { + auto ret = ThreadedOpr::New(); + ret->fn = fn; + ret->prop = prop; + ret->const_vars.resize(const_vars.size()); + ret->mutable_vars.resize(mutable_vars.size()); + std::transform(const_vars.begin(), const_vars.end(), ret->const_vars.begin(), + ThreadedVar::CastFromBase); + std::transform(mutable_vars.begin(), mutable_vars.end(), + ret->mutable_vars.begin(), ThreadedVar::CastFromBase); +#if ENGINE_DEBUG + // Check for duplicates. + auto use = const_vars; + auto mutate = mutable_vars; + auto use_size = use.size(); + auto mutate_size = mutate.size(); + std::sort(use.begin(), use.end()); + std::sort(mutate.begin(), mutate.end()); + for (std::size_t i = 0; i < use_size; ++i) { + if (i != 0 && use.at(i) == use.at(i - 1)) { + LOG(FATAL) << "duplicate items found in `const_vars`"; + } + } + for (std::size_t i = 0; i < mutate_size; ++i) { + if (i != 0 && mutate.at(i) == mutate.at(i - 1)) { + LOG(FATAL) << "duplicate items found in `mutable_vars`"; + } + } + std::size_t j = 0; + for (std::size_t i = 0; i < use_size; ++i) { + while (j < mutate_size && mutate.at(j) < use.at(i)) { + ++j; + } + if (j == mutate_size) { + break; + } + if (mutate.at(j) == use.at(i)) { + LOG(FATAL) + << "duplicate items found between `const_vars` and `mutable_vars`"; + } + } +#endif // ENGINE_DEBUG + return ret; +} + +void ThreadedEngine::DeleteOperator(OprHandle op) { + auto&& threaded_opr = ThreadedOpr::CastFromBase(op); + std::vector deps{}; + deps.reserve(threaded_opr->const_vars.size() + + threaded_opr->mutable_vars.size()); + deps.insert(deps.end(), threaded_opr->const_vars.begin(), + threaded_opr->const_vars.end()); + deps.insert(deps.end(), threaded_opr->mutable_vars.begin(), + threaded_opr->mutable_vars.end()); + auto&& func = + [threaded_opr](RunContext) { ThreadedOpr::Delete(threaded_opr); }; + Push(func, Context{}, {}, deps, FnProperty::kAsync); +} + +void ThreadedEngine::Push(OprHandle op, Context exec_ctx) { + auto&& threaded_opr = ThreadedOpr::CastFromBase(op); + auto&& opr_block = OprBlock::New(); + opr_block->opr = threaded_opr; + opr_block->wait.store(threaded_opr->const_vars.size() + + threaded_opr->mutable_vars.size() + 1); + opr_block->ctx = exec_ctx; + ++pending_; + // Add read dependencies. + for (auto&& i : threaded_opr->const_vars) { + i->AppendReadDependency(opr_block); + } + // Add write dependencies. + for (auto&& i : threaded_opr->mutable_vars) { + i->AppendWriteDependency(opr_block); + } + if (--opr_block->wait == 0) { + if (opr_block->opr->prop == FnProperty::kAsync) { + DoExecute(opr_block); + } else { + DoPushToQueue(opr_block); + } + } +} + +void ThreadedEngine::Push(Fn exec_fun, Context exec_ctx, + std::vector const& const_vars, + std::vector const& mutable_vars, + FnProperty prop) { + auto f = [exec_fun](RunContext ctx, Callback on_complete) { + exec_fun(ctx); + on_complete(); + }; + PushAsync(f, exec_ctx, const_vars, mutable_vars, prop); +} + +void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx, + std::vector const& const_vars, + std::vector const& mutable_vars, + FnProperty prop) { + auto&& opr = NewOperator(fn, const_vars, mutable_vars, prop); + opr->temporary = true; + Push(opr, exec_ctx); +} + +void ThreadedEngine::DeleteVariable(Fn delete_fn, Context exec_ctx, + VarHandle var) { + auto&& threaded_var = ThreadedVar::CastFromBase(var); + auto&& func = [delete_fn, threaded_var](RunContext ctx) { + /*! + * Mark variable as orphan, so during `ThreadedEngine::OnComplete` it could + * be recycled. + */ + threaded_var->SetToDelete(); + delete_fn(ctx); + }; + Push(func, exec_ctx, {}, {var}, FnProperty::kAsync); +} + +void ThreadedEngine::WaitForVar(VarHandle var) { + auto&& threaded_var = ThreadedVar::CastFromBase(var); + if (threaded_var->ready_to_read()) { + return; + } + { + std::unique_lock lock{finished_m_}; + std::atomic done{false}; + auto&& callback = [this, &done](RunContext) { + std::unique_lock lock{finished_m_}; + done.store(true); + finished_cv_.notify_all(); + }; + Push(callback, Context{}, {var}, {}, FnProperty::kNormal); + finished_cv_.wait(lock, [&done]() { return done.load(); }); + } +} + +void ThreadedEngine::WaitForAll() { + std::unique_lock lock{finished_m_}; + finished_cv_.wait(lock, [this]() { return pending_.load() == 0; }); +} + +void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { + /*! + * Mark complete for read variables. + */ + for (auto&& i : threaded_opr->const_vars) { + i->CompleteReadDependency([this](OprBlock* opr) { DoPushToQueue(opr); }); + } + /*! + * Mark complete for write variables. + */ + for (auto&& i : threaded_opr->mutable_vars) { + bool to_delete = i->CompleteWriteDependency( + [this](OprBlock* opr) { DoPushToQueue(opr); }); + if (to_delete) { + ThreadedVar::Delete(i); + } + } + { + std::unique_lock lock{finished_m_}; + if (--pending_ == 0) { + finished_cv_.notify_all(); + } + } +} + +void ThreadedEngine::ThreadWorker( + dmlc::ConcurrentBlockingQueue* task_queue) { + OprBlock* opr_block; + while (task_queue->Pop(&opr_block)) { + DoExecute(opr_block); + } +} + +void ThreadedEngine::DoPushToQueue(OprBlock* opr_block) { + switch (opr_block->opr->prop) { + case FnProperty::kIO: + io_task_queue_.Push(opr_block); + break; + default: + task_queue_.Push(opr_block); + break; + } +} + +void ThreadedEngine::DoExecute(OprBlock* opr_block) { + assert(opr_block->wait.load() == 0); + auto threaded_opr = opr_block->opr; + auto callback = [this, threaded_opr]() { + OnComplete(threaded_opr); + if (threaded_opr->temporary) { + ThreadedOpr::Delete(threaded_opr); + } + }; + if (opr_block->ctx.dev_mask == gpu::kDevMask) { +#if MXNET_USE_CUDA + CUDA_CALL(cudaSetDevice(opr_block->ctx.dev_id)); +#else // MXNET_USE_CUDA + LOG(FATAL) << "Please compile with CUDA enabled"; +#endif // MXNET_USE_CUDA + } + auto&& rctx = opr_block->opr->prop == FnProperty::kIO + ? streams_.GetIORunContext(opr_block->ctx) + : streams_.GetRunContext(opr_block->ctx); + threaded_opr->fn(rctx, callback); + OprBlock::Delete(opr_block); +} + +} // namespace engine + +} // namespace mxnet diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h new file mode 100644 index 000000000000..e2f5835b6507 --- /dev/null +++ b/src/engine/threaded_engine.h @@ -0,0 +1,212 @@ +/*! + * Copyright (c) 2015 by Contributors + */ +#ifndef MXNET_ENGINE_THREADED_ENGINE_H_ +#define MXNET_ENGINE_THREADED_ENGINE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "engine_impl.h" +#include "thread_pool.h" +#include "stream_manager.h" +#include "../common/object_pool.h" + +namespace mxnet { + +namespace engine { + +/*! + * \brief Forward declarations. + */ +struct ThreadedOpr; + +/*! + * \brief Operation in the queue. + */ +struct OprBlock : public common::ObjectPoolAllocatable { +#if ENGINE_DEBUG + static std::atomic counter; + OprBlock() { LOG(INFO) << __func__ << " " << ++counter; } + ~OprBlock() { LOG(INFO) << __func__ << " " << --counter; } +#endif // ENGINE_DEBUG + std::atomic wait{0}; + ThreadedOpr* opr{nullptr}; + Context ctx; +}; // struct OprBlock + +/*! + * \brief Variable with version information. + */ +struct VersionedVarBlock + : public common::ObjectPoolAllocatable { +#if ENGINE_DEBUG + static std::atomic counter; + VersionedVarBlock() { LOG(INFO) << __func__ << " " << ++counter; } + ~VersionedVarBlock() { LOG(INFO) << __func__ << " " << --counter; } +#endif // ENGINE_DEBUG + VersionedVarBlock* next{nullptr}; + OprBlock* trigger{nullptr}; + bool write{false}; +}; // struct VersionedVarBlock + +/*! + * \brief Variable implementation. + */ +class ThreadedVar final : public Var, + public common::ObjectPoolAllocatable { + public: +#if ENGINE_DEBUG + static std::atomic counter; + ~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; } +#endif // ENGINE_DEBUG + explicit ThreadedVar(VersionedVarBlock* head); + void AppendReadDependency(OprBlock* opr_block); + void AppendWriteDependency(OprBlock* opr_block); + template + void CompleteReadDependency(Dispatcher dispatcher); + template + bool CompleteWriteDependency(Dispatcher dispatcher); + void SetToDelete(); + bool ready_to_read(); + + static ThreadedVar* CastFromBase(Var* ptr); + + private: + // TODO(hotpxl) change this to spinlock for faster runtime + std::mutex m_; + std::size_t num_pending_reads_{0}; + VersionedVarBlock* head_{nullptr}; + VersionedVarBlock* pending_write_{nullptr}; + /*! + * If true, then there are no current or future processing of the chain. + */ + bool ready_to_read_{true}; + /*! + * If true, delete after operation completes. + */ + bool to_delete_{false}; +}; // struct ThreadedVar + +/*! + * \brief Operator implementation. + */ +struct ThreadedOpr final : public Opr, + public common::ObjectPoolAllocatable { +#if ENGINE_DEBUG + static std::atomic counter; + ThreadedOpr() { LOG(INFO) << __func__ << " " << ++counter; } + ~ThreadedOpr() { LOG(INFO) << __func__ << " " << --counter; } +#endif // ENGINE_DEBUG + Engine::AsyncFn fn; + std::vector const_vars; + std::vector mutable_vars; + FnProperty prop; + bool temporary{false}; + + static ThreadedOpr* CastFromBase(Opr* ptr); +}; // struct ThreadedOpr + +/*! + * \brief Engine implementation. + */ +class ThreadedEngine final : public Engine { + public: + /*! + * \brief Constructor and destructor. + */ + ThreadedEngine(); + ~ThreadedEngine() noexcept(false); + /*! + * \brief Overriding methods. + */ + ThreadedVar* NewVariable() override; + ThreadedOpr* NewOperator(AsyncFn fn, std::vector const& const_vars, + std::vector const& mutable_vars, + FnProperty prop) override; + void DeleteOperator(OprHandle op) override; + void Push(OprHandle op, Context exec_ctx) override; + void Push(Fn exec_fun, Context exec_ctx, + std::vector const& const_vars, + std::vector const& mutable_vars, + FnProperty prop) override; + void PushAsync(AsyncFn exec_fun, Context exec_ctx, + std::vector const& const_vars, + std::vector const& mutable_vars, + FnProperty prop) override; + void DeleteVariable(Fn delete_fn, Context exec_ctx, VarHandle var) override; + void WaitForVar(VarHandle var) override; + void WaitForAll() override; + /*! + * \brief Callback on operation completion. + * + * On operation completion, this will trigger subsequent operations. + */ + void OnComplete(ThreadedOpr* threaded_opr); + /*! + * \brief Worker. + * \param task_queue Queue to work on. + * + * The method to pass to thread pool to parallelize. + */ + void ThreadWorker(dmlc::ConcurrentBlockingQueue* task_queue); + + private: + /*! + * \brief Concurrency for thread pool. + */ + static constexpr std::size_t kNumWorkingThreads = 16; + /*! + * \brief Constants for runtime context. + */ + static constexpr std::size_t kMaxNumGpus = 16; + static constexpr std::size_t kNumStreamsPerGpu = 16; + /*! + * \brief Number of pending operations. + */ + std::atomic pending_; + /*! + * \brief Notify waits for single or all variables. + */ + std::mutex finished_m_; + std::condition_variable finished_cv_; + /*! + * \brief Streams. + */ + StreamManager streams_; + /*! + * \brief Task queues. + */ + dmlc::ConcurrentBlockingQueue task_queue_; + dmlc::ConcurrentBlockingQueue io_task_queue_; + /*! + * \brief Thread pools. + */ + ThreadPool thread_pool_; + ThreadPool<1> io_thread_pool_; + /*! + * \brief Push to corresponding task queue. + * \param opr_block The operator block. + */ + void DoPushToQueue(OprBlock* opr_block); + /*! + * \brief Execute an operation. + * \param opr_block The operator block. + */ + void DoExecute(OprBlock* opr_block); + /*! + * \brief Disallow copy construction and assignment. + */ + DISALLOW_COPY_AND_ASSIGN(ThreadedEngine); +}; // class ThreadedEngine + +} // namespace engine + +} // namespace mxnet + +#endif // MXNET_ENGINE_THREADED_ENGINE_H_ diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index 3a6bc4169417..68a23e8253ab 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -9,7 +9,7 @@ #include #include #include "mxnet/kvstore.h" -#include "mxnet/dag_engine.h" +#include "mxnet/engine.h" namespace mxnet { @@ -18,7 +18,7 @@ namespace mxnet { */ class KVStoreLocal : public KVStore { public: - KVStoreLocal() : engine_(DAGEngine::Get()) { Clear(); } + KVStoreLocal() : engine_(Engine::Get()) { Clear(); } virtual ~KVStoreLocal() { Clear(); } virtual void InitDevices(const std::vector& devices) { @@ -121,7 +121,7 @@ class KVStoreLocal : public KVStore { local_.clear(); } - DAGEngine* engine_; + Engine* engine_; Updater updater_; bool aggregator_; diff --git a/src/narray/narray.cc b/src/narray/narray.cc index e65f4127e65d..7bce1d5c5243 100644 --- a/src/narray/narray.cc +++ b/src/narray/narray.cc @@ -51,13 +51,13 @@ inline void BinaryOp(const NArray &lhs, narray::Eval(lhs.data(), rhs.data(), &tmp, ctx); }; if (lhs.ptr_->var == ret.ptr_->var && rhs.ptr_->var == ret.ptr_->var) { - DAGEngine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); + Engine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); } else if (lhs.ptr_->var == ret.ptr_->var) { - DAGEngine::Get()->Push(func, lhs.ctx(), {rhs.ptr_->var}, {ret.ptr_->var}); + Engine::Get()->Push(func, lhs.ctx(), {rhs.ptr_->var}, {ret.ptr_->var}); } else if (rhs.ptr_->var == ret.ptr_->var) { - DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var}); + Engine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var}); } else { - DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); + Engine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); } break; } @@ -69,13 +69,13 @@ inline void BinaryOp(const NArray &lhs, narray::Eval(lhs.data(), rhs.data(), &tmp, ctx); }; if (lhs.ptr_->var == ret.ptr_->var && rhs.ptr_->var == ret.ptr_->var) { - DAGEngine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); + Engine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); } else if (lhs.ptr_->var == ret.ptr_->var) { - DAGEngine::Get()->Push(func, lhs.ctx(), {rhs.ptr_->var}, {ret.ptr_->var}); + Engine::Get()->Push(func, lhs.ctx(), {rhs.ptr_->var}, {ret.ptr_->var}); } else if (rhs.ptr_->var == ret.ptr_->var) { - DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var}); + Engine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var}); } else { - DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); + Engine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); } break; } @@ -95,7 +95,7 @@ inline void SetValueOp(const real_t &rhs, NArray *out) { TBlob tmp = ret.data(); narray::Eval(rhs, &tmp, ctx); }; - DAGEngine::Get()->Push(func, ret.ctx(), {}, {ret.ptr_->var}); + Engine::Get()->Push(func, ret.ctx(), {}, {ret.ptr_->var}); break; } #if MXNET_USE_CUDA @@ -105,7 +105,7 @@ inline void SetValueOp(const real_t &rhs, NArray *out) { TBlob tmp = ret.data(); narray::Eval(rhs, &tmp, ctx); }; - DAGEngine::Get()->Push(func, ret.ctx(), {}, {ret.ptr_->var}); + Engine::Get()->Push(func, ret.ctx(), {}, {ret.ptr_->var}); break; } #endif @@ -141,9 +141,9 @@ inline void ScalarOp(const NArray &lhs, narray::Eval(lhs.data(), rhs, &tmp, ctx); }; if (lhs.ptr_->var == ret.ptr_->var) { - DAGEngine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); + Engine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); } else { - DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var}); + Engine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var}); } break; } @@ -155,9 +155,9 @@ inline void ScalarOp(const NArray &lhs, narray::Eval(lhs.data(), rhs, &tmp, ctx); }; if (lhs.ptr_->var == ret.ptr_->var) { - DAGEngine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); + Engine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); } else { - DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var}); + Engine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var}); } break; } @@ -176,7 +176,7 @@ void CopyFromTo(const NArray &from, NArray *to) { int a = from.ctx().dev_mask; int b = to->ctx().dev_mask; if (a == cpu::kDevMask && b == cpu::kDevMask) { - DAGEngine::Get()->Push([from, ret](RunContext ctx) { + Engine::Get()->Push([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); narray::Copy(from.data(), &tmp, @@ -184,7 +184,7 @@ void CopyFromTo(const NArray &from, NArray *to) { }, from.ctx(), {from.ptr_->var}, {ret.ptr_->var}); } else if (a == cpu::kDevMask && b == gpu::kDevMask) { #if MXNET_USE_CUDA - DAGEngine::Get()->Push([from, ret](RunContext ctx) { + Engine::Get()->Push([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); narray::Copy(from.data(), &tmp, @@ -195,7 +195,7 @@ void CopyFromTo(const NArray &from, NArray *to) { #endif } else if (a == gpu::kDevMask && b == cpu::kDevMask) { #if MXNET_USE_CUDA - DAGEngine::Get()->Push([from, ret](RunContext ctx) { + Engine::Get()->Push([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); narray::Copy(from.data(), &tmp, @@ -206,7 +206,7 @@ void CopyFromTo(const NArray &from, NArray *to) { #endif } else if (a == gpu::kDevMask && b == gpu::kDevMask) { #if MXNET_USE_CUDA - DAGEngine::Get()->Push([from, ret](RunContext ctx) { + Engine::Get()->Push([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); narray::Copy(from.data(), &tmp, diff --git a/src/storage/cpu_device_storage.h b/src/storage/cpu_device_storage.h index 1241a56e0d12..82aa0e7573f6 100644 --- a/src/storage/cpu_device_storage.h +++ b/src/storage/cpu_device_storage.h @@ -38,12 +38,13 @@ class CPUDeviceStorage { }; // class CPUDeviceStorage inline void* CPUDeviceStorage::Alloc(size_t size) { -#ifdef __APPLE__ - return CHECK_NOTNULL(malloc(size)); -#elif _MSC_VER +#if _MSC_VER return CHECK_NOTNULL(_aligned_malloc(size, alignment_)); #else - return CHECK_NOTNULL(memalign(alignment_, size)); + void* ptr; + int ret = posix_memalign(&ptr, alignment_, size); + CHECK_EQ(ret, 0) << "Allocation failed"; + return ptr; #endif } diff --git a/src/storage/storage.cc b/src/storage/storage.cc index 98c7981b6a88..ad040ee88ad8 100644 --- a/src/storage/storage.cc +++ b/src/storage/storage.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include "storage_manager.h" #include "naive_storage_manager.h" #include "pooled_storage_manager.h" @@ -40,41 +41,45 @@ struct Storage::Impl { } } - // std::unordered_map< - // int, std::unordered_map>> - // storage_managers; std::array, kMaxNumberOfDeviceIDs>, kMaxNumberOfDevices> storage_managers; + std::mutex m; }; // struct Storage::Impl Storage::Handle Storage::Alloc(size_t size, Context ctx) { Handle hd; hd.ctx = ctx; - auto&& device = impl_->storage_managers.at(ctx.dev_mask); - auto&& device_id_it = device.at(ctx.dev_id); - // Allocate device if necessary. - if (!device_id_it) { - switch (ctx.dev_mask) { - case cpu::kDevMask: - device_id_it = common::MakeUnique< - Storage::Impl::CurrentStorageManager>(); - break; - case gpu::kDevMask: - device_id_it = common::MakeUnique< - Storage::Impl::CurrentStorageManager>(); - break; - default: - LOG(FATAL) << "Unimplemented device"; + hd.size = size; + { + std::lock_guard lock{impl_->m}; + auto&& device = impl_->storage_managers.at(ctx.dev_mask); + auto&& device_id_it = device.at(ctx.dev_id); + // Allocate device if necessary. + if (!device_id_it) { + switch (ctx.dev_mask) { + case cpu::kDevMask: + device_id_it = + common::MakeUnique>(); + break; + case gpu::kDevMask: + device_id_it = + common::MakeUnique>(); + break; + default: + LOG(FATAL) << "Unimplemented device"; + } } + Impl::ActivateDevice(ctx); + hd.dptr = device_id_it->Alloc(size); } - Impl::ActivateDevice(ctx); - hd.dptr = device_id_it->Alloc(size); - hd.size = size; return hd; } void Storage::Free(Storage::Handle handle) { + std::lock_guard lock{impl_->m}; Impl::ActivateDevice(handle.ctx); impl_->storage_managers.at(handle.ctx.dev_mask) .at(handle.ctx.dev_id) @@ -84,6 +89,7 @@ void Storage::Free(Storage::Handle handle) { Storage::~Storage() = default; Storage* Storage::Get() { + // This function is thread-safe in C++11 static Storage inst; return &inst; } diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc index aeff3427d8f3..914c8f8b8c9f 100644 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -450,14 +450,14 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { OpNode& opnode = op_nodes_[nid]; opnode.op_ctx.is_train = is_train; if (opnode.cached_exec.exec_fun != nullptr) { - DAGEngine::Get()->Push( + Engine::Get()->Push( opnode.cached_exec.exec_fun, opnode.ctx, opnode.cached_exec.use_vars, opnode.cached_exec.mutate_vars); } else { auto exec = GetOpExecEntry(nid); - DAGEngine::Get()->Push( + Engine::Get()->Push( exec.exec_fun, opnode.ctx, exec.use_vars, diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h index 66cd074b406b..af2160415e8d 100644 --- a/src/symbol/graph_executor.h +++ b/src/symbol/graph_executor.h @@ -91,11 +91,11 @@ class GraphExecutor : public Executor { // all the information needed to push the op to engine struct OpExecEntry { // execution function for - DAGEngine::Fn exec_fun; + Engine::Fn exec_fun; // variables to read from - std::vector use_vars; + std::vector use_vars; // variables to mutate - std::vector mutate_vars; + std::vector mutate_vars; // constructor OpExecEntry() : exec_fun(nullptr) {} }; diff --git a/tests/test_simple_engine.cc b/tests/test_simple_engine.cc deleted file mode 100644 index 453a13e11d4b..000000000000 --- a/tests/test_simple_engine.cc +++ /dev/null @@ -1,113 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - */ -#include -#include -#include -#include -#include -#include - -#include "mxnet/dag_engine.h" - -void Foo(mxnet::RunContext, int i) { printf("The fox says %d\n", i); } - -int main() { - auto&& engine = mxnet::DAGEngine::Get(); - auto&& var = engine->NewVar(); - std::vector oprs; - - // Test #1 - printf("============= Test #1 ==============\n"); - for (int i = 0; i < 10; ++i) { - oprs.push_back(engine->NewOperator( - [i](mxnet::RunContext ctx, mxnet::DAGEngine::Callback cb) { - Foo(ctx, i); - std::this_thread::sleep_for(std::chrono::seconds{1}); - cb(); - }, - {var}, {})); - engine->Push(oprs.at(i), mxnet::Context{}); - } - engine->WaitForAll(); - printf("Going to push delete\n"); - // std::this_thread::sleep_for(std::chrono::seconds{1}); - for (auto&& i : oprs) { - engine->DeleteOperator(i); - } - engine->PushDelete([](mxnet::RunContext) {}, mxnet::Context{}, var); - engine->WaitForAll(); - - printf("============= Test #2 ==============\n"); - var = engine->NewVar(); - oprs.clear(); - for (int i = 0; i < 10; ++i) { - oprs.push_back(engine->NewOperator( - [i](mxnet::RunContext ctx, mxnet::DAGEngine::Callback cb) { - Foo(ctx, i); - std::this_thread::sleep_for(std::chrono::milliseconds{500}); - cb(); - }, - {}, {var})); - engine->Push(oprs.at(i), mxnet::Context{}); - } - // std::this_thread::sleep_for(std::chrono::seconds{1}); - engine->WaitForAll(); - for (auto&& i : oprs) { - engine->DeleteOperator(i); - } - engine->PushDelete([](mxnet::RunContext) {}, mxnet::Context{}, var); - - printf("============= Test #3 ==============\n"); - var = engine->NewVar(); - oprs.clear(); - engine->WaitToWrite(var); - engine->PushDelete([](mxnet::RunContext) {}, mxnet::Context{}, var); - engine->WaitForAll(); - - printf("============= Test #4 ==============\n"); - var = engine->NewVar(); - oprs.clear(); - oprs.push_back(engine->NewOperator( - [](mxnet::RunContext ctx, mxnet::DAGEngine::Callback cb) { - std::this_thread::sleep_for(std::chrono::seconds{2}); - Foo(ctx, 42); - cb(); - }, - {}, {var})); - engine->Push(oprs.at(0), mxnet::Context{}); - LOG(INFO) << "Operator pushed, should wait for 2 seconds."; - engine->WaitToWrite(var); - LOG(INFO) << "OK, here I am."; - for (auto&& i : oprs) { - engine->DeleteOperator(i); - } - engine->PushDelete([](mxnet::RunContext) {}, mxnet::Context{}, var); - engine->WaitForAll(); - - printf("============= Test #5 ==============\n"); - var = engine->NewVar(); - oprs.clear(); - oprs.push_back(engine->NewOperator( - [](mxnet::RunContext ctx, mxnet::DAGEngine::Callback cb) { - Foo(ctx, 42); - std::this_thread::sleep_for(std::chrono::seconds{2}); - cb(); - }, - {var}, {})); - engine->Push(oprs.at(0), mxnet::Context{}); - LOG(INFO) << "Operator pushed, should not wait."; - engine->WaitToWrite(var); - LOG(INFO) << "OK, here I am."; - engine->WaitForAll(); - LOG(INFO) << "That was 2 seconds."; - for (auto&& i : oprs) { - engine->DeleteOperator(i); - } - engine->PushDelete([](mxnet::RunContext) {}, mxnet::Context{}, var); - engine->WaitForAll(); - var = nullptr; - oprs.clear(); - - return 0; -} diff --git a/tests/test_threaded_engine.cc b/tests/test_threaded_engine.cc index fecd552d1b50..d3708711779a 100644 --- a/tests/test_threaded_engine.cc +++ b/tests/test_threaded_engine.cc @@ -1,43 +1,113 @@ -// Copyright (c) 2015 by Contributors +/*! + * Copyright (c) 2015 by Contributors + */ #include -#include +#include +#include +#include +#include #include -#include "mxnet/dag_engine.h" +#include "mxnet/engine.h" -using namespace std; -using namespace mxnet; - -void Foo(RunContext rctx, int i) { - cout << "say: " << i << endl; -} +void Foo(mxnet::RunContext, int i) { printf("The fox says %d\n", i); } int main() { - DAGEngine* engine = DAGEngine::Get(); - Context exec_ctx; + auto&& engine = mxnet::Engine::Get(); + auto&& var = engine->NewVariable(); + std::vector oprs; // Test #1 - cout << "============= Test #1 ==============" << endl; - vector vars; + printf("============= Test #1 ==============\n"); for (int i = 0; i < 10; ++i) { - vars.push_back(engine->NewVar()); + oprs.push_back(engine->NewOperator( + [i](mxnet::RunContext ctx, mxnet::Engine::Callback cb) { + Foo(ctx, i); + std::this_thread::sleep_for(std::chrono::seconds{1}); + cb(); + }, + {var}, {})); + engine->Push(oprs.at(i), mxnet::Context{}); } + engine->WaitForAll(); + printf("Going to push delete\n"); + // std::this_thread::sleep_for(std::chrono::seconds{1}); + for (auto&& i : oprs) { + engine->DeleteOperator(i); + } + engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var); + engine->WaitForAll(); + + printf("============= Test #2 ==============\n"); + var = engine->NewVariable(); + oprs.clear(); for (int i = 0; i < 10; ++i) { - engine->Push([i] (RunContext rctx) { Foo(rctx, i); }, - exec_ctx, vars, {}); + oprs.push_back(engine->NewOperator( + [i](mxnet::RunContext ctx, mxnet::Engine::Callback cb) { + Foo(ctx, i); + std::this_thread::sleep_for(std::chrono::milliseconds{500}); + cb(); + }, + {}, {var})); + engine->Push(oprs.at(i), mxnet::Context{}); } + // std::this_thread::sleep_for(std::chrono::seconds{1}); + engine->WaitForAll(); + for (auto&& i : oprs) { + engine->DeleteOperator(i); + } + engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var); - usleep(1000000); + printf("============= Test #3 ==============\n"); + var = engine->NewVariable(); + oprs.clear(); + engine->WaitForVar(var); + engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var); + engine->WaitForAll(); - // Test #2 - cout << "============= Test #2 ==============" << endl; - for (int i = 0; i < 10; ++i) { - engine->Push([i] (RunContext rctx) { Foo(rctx, i); }, - exec_ctx, {}, vars); + printf("============= Test #4 ==============\n"); + var = engine->NewVariable(); + oprs.clear(); + oprs.push_back(engine->NewOperator( + [](mxnet::RunContext ctx, mxnet::Engine::Callback cb) { + std::this_thread::sleep_for(std::chrono::seconds{2}); + Foo(ctx, 42); + cb(); + }, + {}, {var}, mxnet::FnProperty::kIO)); + engine->Push(oprs.at(0), mxnet::Context{}); + LOG(INFO) << "IO operator pushed, should wait for 2 seconds."; + engine->WaitForVar(var); + LOG(INFO) << "OK, here I am."; + for (auto&& i : oprs) { + engine->DeleteOperator(i); } + engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var); + engine->WaitForAll(); - usleep(1000000); + printf("============= Test #5 ==============\n"); + var = engine->NewVariable(); + oprs.clear(); + oprs.push_back(engine->NewOperator( + [](mxnet::RunContext ctx, mxnet::Engine::Callback cb) { + Foo(ctx, 42); + std::this_thread::sleep_for(std::chrono::seconds{2}); + cb(); + }, + {var}, {})); + engine->Push(oprs.at(0), mxnet::Context{}); + LOG(INFO) << "Operator pushed, should not wait."; + engine->WaitForVar(var); + LOG(INFO) << "OK, here I am."; + engine->WaitForAll(); + LOG(INFO) << "That was 2 seconds."; + for (auto&& i : oprs) { + engine->DeleteOperator(i); + } + engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var); + engine->WaitForAll(); + var = nullptr; + oprs.clear(); - // Test #3 return 0; }