From 7c6eb9a28841c77002d37cbd6d282c860d99e565 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 15 Sep 2015 16:46:07 -0600 Subject: [PATCH] fix threaded engine per device --- src/c_api.cc | 6 ------ src/engine/threaded_engine_perdevice.cc | 8 ++++++-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/c_api.cc b/src/c_api.cc index 9b31b8e47641..6427a6357c90 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -180,12 +180,6 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e, API_END(); } -int MXEngineWaitAll() { - API_BEGIN(); - Engine::Get()->WaitForAll(); - API_END(); -} - // NOTE: return value is added in API_END int MXNDArrayCreateNone(NDArrayHandle *out) { API_BEGIN(); diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc index 606103f9e7ab..0a3da50e69be 100644 --- a/src/engine/threaded_engine_perdevice.cc +++ b/src/engine/threaded_engine_perdevice.cc @@ -42,13 +42,17 @@ class ThreadedEnginePerDevice : public ThreadedEngine { protected: void PushToExecute(OprBlock *opr_block, bool pusher_thread) override { + const Context& ctx = opr_block->ctx; if (opr_block->opr->prop == FnProperty::kAsync && pusher_thread) { - CHECK_EQ(opr_block->ctx.dev_mask, cpu::kDevMask); + if (ctx.dev_mask == gpu::kDevMask) { + #if MXNET_USE_CUDA + mshadow::SetDevice(ctx.dev_id); + #endif + } RunContext run_ctx; run_ctx.stream = nullptr; this->ExecuteOprBlock(run_ctx, opr_block); } else { - const Context& ctx = opr_block->ctx; if (ctx.dev_mask == cpu::kDevMask) { cpu_worker_.task_queue.Push(opr_block); } else {