diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 8e8d9634d177..f837654d8a89 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -29,6 +29,7 @@ namespace mxnet { // forward declaration +class NDArray; namespace autograd { class AGNode; @@ -51,10 +52,11 @@ class AGNodeEntry { class AutogradRuntime; } // namespace autograd -// FIXME int64_t is not available mshadow +// enum for storage types #define CSR_IND_PTR_TYPE mshadow::kInt32 #define CSR_IDX_DTYPE mshadow::kInt32 #define ROW_SPARSE_IDX_TYPE mshadow::kInt32 +// FIXME int64_t is not available mshadow namespace csr { enum CSRAuxType {kIndPtr, kIdx}; } @@ -64,12 +66,26 @@ enum RowSparseAuxType {kIdx}; } enum NDArrayStorageType { - kUndefinedStorage, // undefined chunk + kUndefinedStorage, // undefined storage kDefaultStorage, // dense kRowSparseStorage, // row sparse kCSRStorage, // csr }; +/*! + * \brief issue an copy operation from one NDArray to another + * the two ndarray can sit on different devices + * this operation will be scheduled by the engine + * + * \param from the ndarray we want to copy data from + * \param to the target ndarray + * \param priority Priority of the action. + * \param alloc_output whether to allocate memory for the output ndarray + * \note The function name explicitly marks the order of from and to + * due to different possible convention carried by copy function. + */ +void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0, bool alloc_output = true); + /*! * \brief ndarray interface */ @@ -96,16 +112,17 @@ class NDArray { Mkl_mem_ = std::make_shared(); #endif } - /*! \brief constructor for NDArray with chunk type + /*! \brief constructor for NDArray with storage type */ NDArray(const NDArrayStorageType storage_type, const TShape &shape, Context ctx, bool delay_alloc = true, int dtype = mshadow::default_type_flag, std::vector aux_types = {}) : shape_(shape), offset_(0), dtype_(dtype), entry_({nullptr, 0, 0}) { + // Assign default aux types if not given if (aux_types.size() == 0) { if (storage_type == kRowSparseStorage) aux_types = {ROW_SPARSE_IDX_TYPE}; - if (storage_type == kCSRStorage) aux_types = {CSR_IND_PTR_TYPE, CSR_IDX_DTYPE}; - CHECK_NE(storage_type, kDefaultStorage); + else if (storage_type == kCSRStorage) aux_types = {CSR_IND_PTR_TYPE, CSR_IDX_DTYPE}; + else LOG(FATAL) << "Unknown storage type"; } ptr_ = std::make_shared(ctx, delay_alloc, aux_types, storage_type); #if MKL_EXPERIMENTAL == 1 @@ -119,15 +136,16 @@ class NDArray { * make sure the memory region is available through out the life of NDArray * \param data the memory content of static data * \param dev_id the device id this tensor sits at + * \param shared_var the same var handle shared with others. + It will not be deleted during destruction. */ - NDArray(const TBlob &data, int dev_id) - : ptr_(std::make_shared(data, dev_id)), shape_(data.shape_), offset_(0), + NDArray(const TBlob &data, int dev_id, Engine::VarHandle shared_var = nullptr) + : ptr_(std::make_shared(data, dev_id, shared_var)), shape_(data.shape_), offset_(0), dtype_(data.type_flag_), entry_({nullptr, 0, 0}) { #if MKL_EXPERIMENTAL == 1 Mkl_mem_ = std::make_shared(); #endif } - // TODO this constructor should be removed NDArray(NDArray data, const std::vector aux_data, Context ctx, NDArrayStorageType storage_type, const TShape &shape) : ptr_(std::make_shared(data, aux_data, ctx, storage_type)), shape_(shape), @@ -146,7 +164,7 @@ class NDArray { } /*! * \return the shape of underlying chunk which stores the NDArray values. 
- * For default storage, it is the same as shape(). For row-sparse chunks, it is the shape of + * For default storage, it is the same as shape(). For row-sparse storage, it is the shape of * the tensor which stores the non-zero values. */ inline const TShape &storage_shape() const { @@ -418,7 +436,7 @@ class NDArray { * \return NDArray in new shape and type. */ inline NDArray AsArray(const TShape &shape, int dtype) const { - CHECK(storage_type() == kDefaultStorage) << "Not implemented yet"; + CHECK_EQ(storage_type(), kDefaultStorage) << "Not implemented yet"; CHECK_GE(shape_.Size() * mshadow::mshadow_sizeof(dtype_), shape.Size() * mshadow::mshadow_sizeof(dtype)) << "NDArray.AsArray: target memory size is bigger"; @@ -451,17 +469,25 @@ class NDArray { * This is an internal function used by system that normal user should not use */ inline void CheckAndAlloc() const { + CHECK_EQ(storage_type(), kDefaultStorage); ptr_->CheckAndAlloc(); } /* ! - * \brief Alloc number of dense rows for kRowSparseStorage + * \brief Alloc memory for non-default storage * aux_shape is only known at run time */ inline void CheckAndAlloc(const std::vector &aux_shapes) const { - // probably should round up memory reservation + CHECK_NE(storage_type(), kDefaultStorage); ptr_->CheckAndAlloc(shape_, aux_shapes, dtype_); } - + inline void CheckAndAllocData(const TShape &storage_shape) const { + CHECK_NE(storage_type(), kDefaultStorage); + ptr_->CheckAndAllocData(storage_shape, dtype_); + } + inline void CheckAndAllocAuxData(size_t i, const TShape &aux_shape) const { + CHECK_NE(storage_type(), kDefaultStorage); + ptr_->CheckAndAllocAuxData(i, aux_shape); + } /*! * \brief Save list of narray into the Stream.x * \param fo The stream of output. @@ -487,13 +513,12 @@ class NDArray { // shandle is used to store the actual values in the NDArray // aux_handles store the aux data(such as indices) if it's needed by non-default storage. struct Chunk { - // every time a new element is added to a non default storage /*! \brief storage handle from storage engine. for non-default storage, shandle stores the data(value) array. */ Storage::Handle shandle; /*! \brief storage handles for aux data (e.g index) - for row_sparse, aux_handles[0] = indic + for row_sparse, aux_handles[0] = indices for csr, aux_handles[0] = indptr, aux_handles[1] = indices */ std::vector aux_handles; @@ -504,7 +529,7 @@ class NDArray { * from Storage, and do not need to be freed */ bool static_data; - /*! \brief whether allocation is delayed */ + /*! \brief whether allocation is delayed. */ bool delay_alloc; /*! \brief construct from static data */ NDArrayStorageType storage_type = kDefaultStorage; @@ -517,6 +542,8 @@ class NDArray { TShape storage_shape; // The shape of aux data. The default value for the shape is 0. std::vector aux_shapes; + // \brief skip the deletion of var handle. Usually set when shared_var is present. + bool skip_delete_var = false; /*! 
\brief construct a new chunk */ Chunk(TShape shape, Context ctx_, bool delay_alloc_, int dtype) @@ -528,33 +555,26 @@ class NDArray { shandle.ctx = ctx_; if (!delay_alloc_) this->CheckAndAlloc(); } - Chunk(const NDArray &nd_data, const std::vector &nd_aux, Context ctx_, + Chunk(const NDArray &nd, const std::vector &nd_aux, Context ctx_, NDArrayStorageType storage_type_) : static_data(false), delay_alloc(false), storage_type(storage_type_), ctx(ctx_) { // Vars var = Engine::Get()->NewVariable(); // Data Storage - const auto &data = nd_data.data(); + const auto &data = nd.data(); storage_shape = data.shape_; shandle.ctx = ctx; shandle.size = data.shape_.Size() * mshadow::mshadow_sizeof(data.type_flag_); shandle = Storage::Get()->Alloc(shandle.size, shandle.ctx); // Copy data - // TODO(haibin) refactor. Single threaded copy is slow. - nd_data.WaitToRead(); - CHECK_EQ(nd_data.storage_type(), kDefaultStorage); - CHECK_EQ(nd_data.dtype(), data.type_flag_); - CHECK_EQ(shandle.ctx.dev_mask(), cpu::kDevMask) - << "Sparse NDArray on GPU not yet supported"; - MSHADOW_TYPE_SWITCH(nd_data.dtype(), DType, { - auto copy = TBlob(static_cast(shandle.dptr), storage_shape, - shandle.ctx.dev_mask(), data.type_flag_); - mshadow::Copy(copy.FlatTo1D(), data.FlatTo1D()); - }); + // Single threaded copy may not saturate memory bandwidth + CHECK_EQ(nd.storage_type(), kDefaultStorage); + auto data_blob = TBlob(shandle.dptr, storage_shape, shandle.ctx.dev_mask(), data.type_flag_); + NDArray data_wrapper(data_blob, ctx.dev_id, var); + CopyFromTo(nd, &data_wrapper, 0, false); // Aux shapes, types and storage - storage_shape = data.shape_; CHECK_GT(storage_shape.ndim(), 0); for (size_t i = 0; i < nd_aux.size(); i++) { const auto &aux_d = nd_aux[i].data(); @@ -565,26 +585,23 @@ class NDArray { aux_handle.size = aux_shapes[i].Size() * mshadow::mshadow_sizeof(aux_types[i]); aux_handle = Storage::Get()->Alloc(aux_handle.size, aux_handle.ctx); aux_handles.emplace_back(aux_handle); - // Copy aux data - nd_aux[i].WaitToRead(); CHECK_EQ(nd_aux[i].storage_type(), kDefaultStorage); - CHECK_EQ(nd_aux[i].dtype(), aux_types[i]); - CHECK_EQ(aux_handle.ctx.dev_mask(), cpu::kDevMask) - << "Sparse NDArray on GPU not yet supported"; - MSHADOW_TYPE_SWITCH(nd_aux[i].dtype(), DType, { - auto copy = TBlob(static_cast(aux_handle.dptr), aux_shapes[i], - ctx.dev_mask(), aux_types[i]); - mshadow::Copy(copy.FlatTo1D(), aux_d.FlatTo1D()); - }); + TBlob aux_blob(aux_handle.dptr, aux_shapes[i], ctx.dev_mask(), aux_types[i]); + NDArray aux_wrapper(aux_blob, ctx.dev_id, var); + CopyFromTo(nd_aux[i], &aux_wrapper, 0, false); } } - Chunk(const TBlob &data, int dev_id) - : static_data(true), - delay_alloc(false) { + Chunk(const TBlob &data, int dev_id, Engine::VarHandle shared_var) + : static_data(true), delay_alloc(false) { CHECK(storage_type == kDefaultStorage); - var = Engine::Get()->NewVariable(); + if (shared_var == nullptr) { + var = Engine::Get()->NewVariable(); + } else { + skip_delete_var = true; + var = shared_var; + } if (data.dev_mask_ == cpu::kDevMask) { shandle.ctx = Context::CPU(); } else { @@ -609,37 +626,53 @@ class NDArray { } /*! 
\brief check if delay alloc is on, do alloc if not yet done */ inline void CheckAndAlloc(void) { - // Should only be used for kDefaultStorage - if (storage_type != kDefaultStorage) { - LOG(FATAL) << "CheckAndAlloc with " << storage_type; - } if (delay_alloc) { shandle = Storage::Get()->Alloc(shandle.size, shandle.ctx); delay_alloc = false; } } - inline void CheckAndAlloc(TShape shape, const std::vector &aux_shapes, int dtype) { - CHECK_EQ(storage_type, kRowSparseStorage) << "Not yet implemented"; + inline void CheckAndAlloc(const TShape &shape, const std::vector &aux_shapes, int dtype) { // calculate size, perform allocation if (delay_alloc) { - // For row sparse storage, aux_shape indicates the number of rows to allocate + CHECK_EQ(storage_type, kRowSparseStorage) << "Not yet implemented"; + // For row sparse, aux_shape indicates the number of rows to allocate auto aux_shape = aux_shapes[0]; - CHECK_EQ(aux_shape.ndim(), 1); - auto num_rows = aux_shape[0]; CHECK_EQ(shape.ndim(), 2) << "High dim RowSparse not yet implemented"; - auto dbytes = num_rows * shape[1] * mshadow::mshadow_sizeof(dtype); - auto aux_bytes = num_rows * mshadow::mshadow_sizeof(aux_types[0]); - shandle = Storage::Get()->Alloc(dbytes, ctx); - aux_handles.push_back(Storage::Get()->Alloc(aux_bytes, ctx)); - delay_alloc = false; - // Initialize shapes - this->aux_shapes = aux_shapes; - storage_shape = shape; - storage_shape[0] = num_rows; + CheckAndAllocAuxData(rowsparse::kIdx, aux_shape); + TShape storage_shape(shape); + storage_shape[0] = aux_shape[0]; + CheckAndAllocData(storage_shape, dtype); } } + inline void CheckAndAllocData(const TShape &shape, int dtype) { + CHECK_NE(aux_shapes.size(), 0) << "data is expected to be allocated after aux_data"; + storage_shape = shape; + auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype); + shandle = Storage::Get()->Alloc(dbytes, ctx); + // delay_alloc is only set when data storage handle is present + delay_alloc = false; + } + inline void CheckAndAllocAuxData(size_t i, const TShape &shape) { + CHECK_EQ(aux_shapes.size(), aux_handles.size()); + if (aux_shapes.size() <= i) { + aux_shapes.resize(i + 1); + aux_handles.resize(i + 1); + } + // Initialize shape + aux_shapes[i] = shape; + // Init aux storage + Storage::Handle aux_handle; + if (storage_type == kRowSparseStorage) { + auto aux_bytes = shape[0] * mshadow::mshadow_sizeof(aux_types[i]); + aux_handle = Storage::Get()->Alloc(aux_bytes, ctx); + } else if (storage_type == kCSRStorage) { + LOG(FATAL) << "Not implemented"; + } + aux_handles[i] = aux_handle; + } /*! \brief destructor */ ~Chunk() { + if (skip_delete_var) return; bool skip_free = static_data || delay_alloc; Storage::Handle h = this->shandle; std::vector aux_h = this->aux_handles; @@ -669,19 +702,6 @@ class NDArray { autograd::AGNodeEntry entry_; }; -/*! - * \brief issue an copy operation from one NDArray to another - * the two ndarray can sit on different devices - * this operation will be scheduled by the engine - * - * \param from the ndarray we want to copy data from - * \param to the target ndarray - * \param priority Priority of the action. - * \note The function name explicitly marks the order of from and to - * due to different possible convention carried by copy function. - */ -void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0); - /*! * \brief Perform elementwise sum over each data from source, store result into out. 
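The ndarray.h hunk above moves the CopyFromTo declaration ahead of the NDArray class, adds an alloc_output flag, and introduces a shared_var constructor that the Chunk copy constructor uses to wrap buffers it has already allocated. Below is a minimal sketch of that pattern, assuming the headers touched above and a CPU context; the function name, dev_id, and cleanup strategy are illustrative and not part of the patch.

// Sketch only: copy an existing dense NDArray into a caller-allocated buffer,
// using the new shared_var constructor and alloc_output flag. Mirrors the
// pattern in the Chunk copy constructor above; names here are illustrative.
#include <mxnet/base.h>
#include <mxnet/engine.h>
#include <mxnet/ndarray.h>
#include <mxnet/storage.h>

namespace {

void CopyIntoPreallocated(const mxnet::NDArray &src) {
  using namespace mxnet;
  CHECK_EQ(src.storage_type(), kDefaultStorage);
  // Allocate the destination buffer up front instead of letting CopyFromTo do it.
  Storage::Handle handle = Storage::Get()->Alloc(
      src.shape().Size() * mshadow::mshadow_sizeof(src.dtype()), Context::CPU());
  TBlob dst_blob(handle.dptr, src.shape(), cpu::kDevMask, src.dtype());
  // The wrapper shares a caller-owned var; skip_delete_var keeps the NDArray
  // destructor from deleting it.
  Engine::VarHandle var = Engine::Get()->NewVariable();
  NDArray dst(dst_blob, /*dev_id=*/0, /*shared_var=*/var);
  // alloc_output = false: the storage behind dst_blob is already allocated.
  CopyFromTo(src, &dst, /*priority=*/0, /*alloc_output=*/false);
  dst.WaitToRead();
  // Release the buffer here; deleting `var` through the engine is omitted
  // from this sketch for brevity.
  Storage::Get()->Free(handle);
}

}  // namespace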
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 51c921859e26..bf9961c8234e 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -7,7 +7,6 @@ #ifndef MXNET_OP_ATTR_TYPES_H_ #define MXNET_OP_ATTR_TYPES_H_ - #include #include @@ -18,6 +17,9 @@ #include "./operator.h" #include "./ndarray.h" +#define FCOMP_EX_CPU "FComputeEx" +#define FCOMP_EX_GPU "FComputeEx" + namespace mxnet { using nnvm::NodeAttrs; @@ -64,8 +66,8 @@ using FCompute = std::function" and "FComputeEx" - * e.g FComputeEx + * \note Register under "FComputeEx" and "FComputeEx" + * Dispatched only when operators process non-default storage inputs or outputs */ using FComputeEx = std::function& ndinputs, const int& infered_num_outputs, std::vector* p_ndoutputs, - NDArrayStorageType* contains_storage_type) { - *contains_storage_type = kDefaultStorage; + int* dispatch_stype) { std::vector& ndoutputs = *p_ndoutputs; static auto& infershape = nnvm::Op::GetAttr("FInferShape"); static auto& infertype = nnvm::Op::GetAttr("FInferType"); @@ -177,11 +176,7 @@ void SetShapeType(const nnvm::Op* op, in_storage_types.push_back(i.storage_type()); } for (auto& i : ndoutputs) { - int storage_type = i.storage_type(); - if (storage_type == kUndefinedStorage) { - storage_type = -1; - } - out_storage_types.push_back(storage_type); + out_storage_types.push_back(i.storage_type()); } if (inferstorage.count(op)) { CHECK(inferstorage[op](attrs, &in_storage_types, &out_storage_types)); @@ -190,26 +185,16 @@ void SetShapeType(const nnvm::Op* op, // LOG(INFO) << "FInferStorageType not present."; } - // TODO(haibin) replace with common:: - for (auto &i : in_storage_types) { - CHECK_NE(i, -1); - if (i != kDefaultStorage) { - *contains_storage_type = static_cast(i); - break; - } - } - for (auto &i : out_storage_types) { - if (i != kDefaultStorage && i != -1) { - *contains_storage_type = static_cast(i); - break; - } - } + bool contains_non_default = common::ContainsNonDefaultStorage(in_storage_types); + contains_non_default |= common::ContainsNonDefaultStorage(out_storage_types); + int kNonDefaultStorage = -2; + *dispatch_stype = contains_non_default ? 
kNonDefaultStorage : kDefaultStorage; for (int i = 0; i < infered_num_outputs; ++i) { NDArrayStorageType storage_type = static_cast(out_storage_types[i]); if (ndoutputs[i].is_none()) { // If failed to infer the storage type, assume the output storage is dense - if (storage_type == kDefaultStorage || out_storage_types[i] == -1) { + if (storage_type == kDefaultStorage || out_storage_types[i] == kUndefinedStorage) { ndoutputs[i] = NDArray(out_shapes[i], ctx, true, out_types[i]); } else { ndoutputs[i] = NDArray(storage_type, out_shapes[i], ctx, true, out_types[i]); @@ -265,7 +250,6 @@ void SetDependency(std::vector *p_read_vars, if (mutate.count(op)) { auxidx = mutate[op](attrs); std::sort(auxidx.begin(), auxidx.end()); - // TODO(haibin) replace with common::PrepVars for (auto& i : auxidx) { auto var = ndinputs[i].var(); write_vars.push_back(var); @@ -435,7 +419,7 @@ int MXImperativeInvoke(AtomicSymbolCreator creator, } else { // TODO(piiswrong): infer ctx Context ctx; - NDArrayStorageType storage_type; + int storage_type; SetContext(&ctx, attrs, num_inputs, ndinputs, infered_num_outputs, ndoutputs); SetShapeType(op, attrs, ctx, ndinputs, infered_num_outputs, &ndoutputs, &storage_type); diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 66adeae16bf0..da2953c4a110 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -513,7 +513,7 @@ int MXSymbolInferStorageType(SymbolHandle sym, MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); nnvm::Graph g = Symbol2Graph(*s); - nnvm::StorageTypeVector arg_storage_types(g.indexed_graph().input_nodes().size(), -1); + nnvm::StorageTypeVector arg_storage_types(g.indexed_graph().input_nodes().size()); if (keys == nullptr && num_args != 0) { std::vector read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); CHECK_LE(num_args, read_only_args.size()); diff --git a/src/common/utils.h b/src/common/utils.h index c76abe36545d..d99c097a84c8 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -36,7 +36,6 @@ void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, namespace common { #if DMLC_USE_CXX11 -// TODO move to op_utils.h template inline void PrepDefaultBlobs(const std::vector& ndinputs, const std::vector& ndoutputs, @@ -64,22 +63,25 @@ inline void PrepDefaultBlobs(const std::vector& ndinputs, inline void PrepVars(const std::vector &nds, std::vector *vars) { for (auto& i : nds) { - auto v = i.var(); - vars->push_back(v); + vars->push_back(i.var()); } } -// Only dispatch based on input storage type for now. 
-inline NDArrayStorageType GetDispatchStorageType(const nnvm::StorageTypeVector& vstorage_type) { - NDArrayStorageType dispatch_storage_type = kDefaultStorage; - for (auto& i : vstorage_type) { - if (i != kDefaultStorage) { - CHECK_NE(i, -1); - dispatch_storage_type = NDArrayStorageType(i); - break; +// Check if any storage type is not default storage +inline bool ContainsNonDefaultStorage(const nnvm::StorageTypeVector& vstorage) { + for (auto& i : vstorage) { + if (i != kUndefinedStorage && i != kDefaultStorage) return true; + } + return false; +} + +inline bool ContainsDefaultStorage(const std::vector& ndarrays) { + for (auto &nd : ndarrays) { + if (nd.storage_type() == kDefaultStorage) { + return true; } } - return dispatch_storage_type; + return false; } inline FCompute GetFCompute(const Op* op, Context ctx) { @@ -94,32 +96,19 @@ inline FCompute GetFCompute(const Op* op, Context ctx) { return nullptr; } -inline FComputeEx GetFComputeEx(const Op* op, Context ctx, - NDArrayStorageType storage_type) { - static auto& fcpu_rs = nnvm::Op::GetAttr("FComputeEx"); - static auto& fgpu_rs = nnvm::Op::GetAttr("FComputeEx"); - static auto& fcpu_csr = nnvm::Op::GetAttr("FComputeEx"); - static auto& fgpu_csr = nnvm::Op::GetAttr("FComputeEx"); - if (storage_type == kDefaultStorage) return nullptr; +inline FComputeEx GetFComputeEx(const Op* op, Context ctx, int stype) { + static auto& fcpu = nnvm::Op::GetAttr(FCOMP_EX_CPU); + static auto& fgpu = nnvm::Op::GetAttr(FCOMP_EX_GPU); + if (stype == kDefaultStorage) return nullptr; if (ctx.dev_mask() == cpu::kDevMask) { - if (storage_type == kRowSparseStorage) return fcpu_rs.get(op, nullptr); - if (storage_type == kCSRStorage) return fcpu_csr.get(op, nullptr); + return fcpu.get(op, nullptr); } else if (ctx.dev_mask() == gpu::kDevMask) { - if (storage_type == kRowSparseStorage) return fgpu_rs.get(op, nullptr); - if (storage_type == kCSRStorage) return fgpu_csr.get(op, nullptr); + return fgpu.get(op, nullptr); } LOG(FATAL) << "Unknown device mask"; return nullptr; } -inline bool HasDefaultStorage(const std::vector& ndarrays) { - for (auto &nd : ndarrays) { - if (nd.storage_type() == kDefaultStorage) { - return true; - } - } - return false; -} // heuristic to dermine number of threads per GPU inline int GetNumThreadPerGPU() { diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index 3cc8b01a3bba..70db4d16fee0 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -15,6 +15,7 @@ #endif #include "../common/utils.h" +#define EXEC_DISPATCH_DEBUG 0 namespace mxnet { namespace op { @@ -138,7 +139,6 @@ class BackwardOpExecutor : public OpExecutor { class FComputeExecutor : public OpExecutor { public: void Run(RunContext rctx, bool is_gpu) override { - // std::cout << "FCompute::Run" << std::endl; op_ctx.run_ctx = rctx; if (!initialized) { if (is_gpu) { @@ -159,7 +159,6 @@ class FComputeExecutor : public OpExecutor { mkl_tblobs_prv_to_cpu(in_data_); mkl_tblobs_prv_to_cpu(out_data_); #endif - // std::cout << "FCompute::Done" << std::endl; } void Setup() override { in_array_ = in_array; @@ -184,10 +183,8 @@ class FComputeExecutor : public OpExecutor { class FComputeExExecutor : public OpExecutor { public: void Run(RunContext rctx, bool is_gpu) override { - // std::cout << "FComputeExExecutor::Run" << std::endl; op_ctx.run_ctx = rctx; fcompute_(attrs_, op_ctx, in_data_, req, out_data_); - // std::cout << "FComputeExExecutor::Done" << std::endl; } void Setup() override { in_data_ = in_array; @@ 
-236,10 +233,12 @@ Graph AttachOpExecs(Graph g) { if (fmutate_inputs.count(inode.source->op())) { mutate_index = fmutate_inputs[inode.source->op()](inode.source->attrs); } - NDArrayStorageType dispatch_stype = static_cast(dispatch_stypes[i]); FCompute fcompute = common::GetFCompute(inode.source->op(), vctx[i]); FComputeEx fcompute_ex = - common::GetFComputeEx(inode.source->op(), vctx[i], dispatch_stype); + common::GetFComputeEx(inode.source->op(), vctx[i], dispatch_stypes[i]); +#if EXEC_DISPATCH_DEBUG + LOG(INFO) << "dispatch type = " << dispatch_stypes[i]; +#endif if (fcreate_layer_op.count(inode.source->op())) { std::vector ishape; std::vector itype; @@ -265,12 +264,12 @@ Graph AttachOpExecs(Graph g) { mxnet::op::OpPropGetOpProperty(inode.source->attrs), mutate_index); } else if (fcompute_ex != nullptr) { -#if EXECUTOR_DEBUG +#if EXEC_DISPATCH_DEBUG LOG(INFO) << "FComputeEx for op " << inode.source->op()->name; #endif ret[i] = std::make_shared(fcompute_ex, inode.source->attrs); } else if (fcompute != nullptr) { -#if EXECUTOR_DEBUG +#if EXEC_DISPATCH_DEBUG LOG(INFO) << "FCompute for op " << inode.source->op()->name; #endif ret[i] = std::make_shared(fcompute, inode.source->attrs); diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index f32908b428d2..b23f7fa47fc9 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -23,6 +23,8 @@ const int kBadStorageID = -1; const int kExternalStorageID = -2; const int kDynamicStorageID = -3; +const int kNonDefaultStorage = -2; + /*! * \brief executor to execute an operator * This is a graph executor dependent interface diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 6415db0b5c82..bd0f2f35f7e3 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -427,7 +427,7 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, } arg_shapes.resize(idx.input_nodes().size(), TShape()); arg_types.resize(idx.input_nodes().size(), -1); - arg_storage_types.resize(idx.input_nodes().size(), -1); + arg_storage_types.resize(idx.input_nodes().size(), kUndefinedStorage); // other initializations g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); g = nnvm::pass::InferType(g, arg_types, "__dtype__"); @@ -435,15 +435,23 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, const auto& vstorage_type = g.GetAttr("storage_type"); // dispatch on a per op basis - nnvm::StorageTypeVector dispatch_stypes(idx.num_nodes(), -1); + nnvm::StorageTypeVector dispatch_stypes(idx.num_nodes()); for (size_t nid = 0; nid < idx.num_nodes(); nid++) { const auto& inode = idx[nid]; - nnvm::StorageTypeVector vs; - for (const auto& e : inode.inputs) { - vs.emplace_back(vstorage_type[idx.entry_id(e)]); + auto num_outputs = inode.source->num_outputs(); + auto num_inputs = inode.inputs.size(); + nnvm::StorageTypeVector vs(num_inputs + num_outputs); + for (size_t i = 0; i < num_inputs; i++) { + auto e = inode.inputs[i]; + vs[i] = vstorage_type[idx.entry_id(e)]; + CHECK_NE(vs[i], kUndefinedStorage); + } + for (uint32_t i = 0; i < num_outputs; ++i) { + uint32_t eid = idx.entry_id(nid, i); + vs[i + num_inputs] = vstorage_type[eid]; } - int dispatch_storage_type = common::GetDispatchStorageType(vs); - dispatch_stypes[nid] = dispatch_storage_type; + bool contains_non_default = common::ContainsNonDefaultStorage(vs); + dispatch_stypes[nid] = contains_non_default ? 
kNonDefaultStorage : kDefaultStorage; } g.attrs["dispatch_storage_types"] = std::make_shared(std::move(dispatch_stypes)); @@ -509,13 +517,11 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { NDArrayStorageType storage_type = (NDArrayStorageType) vstorage_type[eid]; CHECK_NE(vshape[eid].ndim(), 0U); CHECK_NE(vdtype[eid], -1); - // enable sparse gradient update, init NDArray based on storage_type + // init NDArray based on storage_type if (storage_type != kDefaultStorage) { - // std::cout << "Sparse NDArray for head gradient " << idx.entry_id(nid, 0) << std::endl; data_entry_[idx.entry_id(nid, 0)] = NDArray(storage_type, vshape[eid], data_context[eid], true, vdtype[eid]); } else { - // std::cout << "Dense NDArray for head gradient " << idx.entry_id(nid, 0) << std::endl; data_entry_[idx.entry_id(nid, 0)] = NDArray(vshape[eid], data_context[eid], false, vdtype[eid]); } @@ -598,10 +604,8 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet"; const NDArray& src = data_pool_.at(storage_id); data_entry_[i] = src.AsArray(vshape[i], vdtype[i]); - // std::cout << "Dense AsNDArray " << i << "\n"; } else { data_entry_[i] = NDArray(storage_type, vshape[i], vctx[i]); - // std::cout << "Sparse NDArray " << i << "\n"; } } } diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 834193012780..55e1a84daeda 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -231,13 +231,13 @@ void ScalarOp(const NDArray &lhs, default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; } } -// FIXME NDArray storage types may be differnet -void CopyFromTo(const NDArray &from, NDArray *to, int priority) { + +void CopyFromTo(const NDArray &from, NDArray *to, int priority, bool alloc_output) { if (from.var() == to->var()) { // skip to copy to itself return; } - CHECK(from.storage_type() == to->storage_type()); + CHECK(from.storage_type() == to->storage_type()) << "Copying with different storage type"; CHECK(from.shape() == to->shape()) << "operands shape mismatch" << "from.shape = " << from.shape() << " to.shape=" << to->shape(); @@ -247,15 +247,15 @@ void CopyFromTo(const NDArray &from, NDArray *to, int priority) { NDArray ret = *to; int a = from.ctx().dev_mask(); int b = to->ctx().dev_mask(); - + bool alloc = alloc_output; std::vector const_vars; if (from.var() != ret.var()) const_vars.push_back(from.var()); if (a == cpu::kDevMask && b == cpu::kDevMask) { - Engine::Get()->PushSync([from, ret](RunContext ctx) { + Engine::Get()->PushSync([from, ret, alloc](RunContext ctx) { auto storage_type = from.storage_type(); if (storage_type == kDefaultStorage) { - ret.CheckAndAlloc(); + if (alloc) ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); @@ -265,7 +265,7 @@ void CopyFromTo(const NDArray &from, NDArray *to, int priority) { // All zeros return; } - ret.CheckAndAlloc({aux_shape}); + if (alloc) ret.CheckAndAlloc({aux_shape}); TBlob val = ret.data(); TBlob idx = ret.aux_data(rowsparse::kIdx); ndarray::Copy(from.data(), &val, @@ -280,9 +280,9 @@ void CopyFromTo(const NDArray &from, NDArray *to, int priority) { } else { #if MXNET_USE_CUDA if (a == cpu::kDevMask && b == gpu::kDevMask) { - Engine::Get()->PushSync([from, ret](RunContext ctx) { + Engine::Get()->PushSync([from, ret, alloc](RunContext ctx) { if (from.storage_type() != kDefaultStorage) LOG(FATAL) << "GPU not implemented yet"; - ret.CheckAndAlloc(); + if (alloc) ret.CheckAndAlloc(); TBlob tmp = 
ret.data(); ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); @@ -291,9 +291,9 @@ void CopyFromTo(const NDArray &from, NDArray *to, int priority) { }, ret.ctx(), const_vars, {ret.var()}, FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE("CopyCPU2GPU")); } else if (a == gpu::kDevMask && b == cpu::kDevMask) { - Engine::Get()->PushSync([from, ret](RunContext ctx) { + Engine::Get()->PushSync([from, ret, alloc](RunContext ctx) { if (from.storage_type() != kDefaultStorage) LOG(FATAL) << "GPU not implemented yet"; - ret.CheckAndAlloc(); + if (alloc) ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); @@ -302,9 +302,9 @@ void CopyFromTo(const NDArray &from, NDArray *to, int priority) { }, from.ctx(), const_vars, {ret.var()}, FnProperty::kCopyFromGPU, priority, PROFILER_MESSAGE("CopyGPU2CPU")); } else if (a == gpu::kDevMask && b == gpu::kDevMask) { - Engine::Get()->PushSync([from, ret](RunContext ctx) { + Engine::Get()->PushSync([from, ret, alloc](RunContext ctx) { if (from.storage_type() != kDefaultStorage) LOG(FATAL) << "GPU not implemented yet"; - ret.CheckAndAlloc(); + if (alloc) ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index d296d177abce..8969bb768823 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -56,22 +56,22 @@ template inline bool ElemwiseStorageAttr(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, - std::vector *out_attrs, - const AttrType& none) { + std::vector *out_attrs) { // LOG(INFO) << "ElemwiseStorageAttr for " << attrs.name; auto deduce = [&](std::vector *vec, const char *name, AttrType& result, bool fallback) { + auto &v = *vec; for (size_t i = 0; i < vec->size(); ++i) { // LOG(INFO) << "deduce " << (*vec)[i]; - CHECK_NE((*vec)[i], -1) << "ElemwiseStorageAttr assumes all input storage types are known"; - if (assign(&result, (*vec)[i]) == false && fallback) { + if (v[i] == kUndefinedStorage) { + // if input type is unknown, assume it's default storage + CHECK(assign(&v[i], kDefaultStorage)); + } else if (assign(&result, v[i]) == false && fallback) { result = kDefaultStorage; - // LOG(INFO) << "ElemwiseStorageAttr Fallback"; - return; } } }; - AttrType dattr = none; + AttrType dattr = kUndefinedStorage; deduce(in_attrs, "input", dattr, enable_fallback); if (reverse_infer) { LOG(FATAL) << "not implemented yet"; @@ -83,8 +83,8 @@ inline bool ElemwiseStorageAttr(const nnvm::NodeAttrs& attrs, << name << ": " << "expected " << dattr << ", got " << (*vec)[i]; } }; + if (is_none(dattr)) dattr = kDefaultStorage; write(out_attrs, "output"); - if (is_none(dattr)) return false; return true; } @@ -114,20 +114,8 @@ inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), static_cast(n_in)) << " in operator " << attrs.name; CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; - // TODO(haibin) not doing inverse infer yet - return ElemwiseStorageAttr( - attrs, in_attrs, out_attrs, -1); -} - -// Useful for binary multiplication / division -template -inline bool ElemwiseSameStorageType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), static_cast(n_in)) << " in operator " << attrs.name; - CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; - return 
ElemwiseStorageAttr( - attrs, in_attrs, out_attrs, -1); + return ElemwiseStorageAttr( + attrs, in_attrs, out_attrs); } inline bool IdentityAttrLikeRhsStorageType(const nnvm::NodeAttrs& attrs, @@ -135,8 +123,8 @@ inline bool IdentityAttrLikeRhsStorageType(const nnvm::NodeAttrs& attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), static_cast(2)) << " in operator " << attrs.name; CHECK_EQ(out_attrs->size(), static_cast(1)) << " in operator " << attrs.name; - return ElemwiseAttr( - attrs, in_attrs, out_attrs, -1); + return ElemwiseStorageAttr( + attrs, in_attrs, out_attrs); } // Transfer gradient and input to FGradient function diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index e9ba5f81c339..ca96eeec8eea 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -80,6 +80,10 @@ inline bool type_is_none(const int& x) { return x == -1; } +/*! \brief check if storage type is none (-1) */ +inline bool storage_type_is_none(const int& x) { + return x == kUndefinedStorage; +} /*! * \brief Assign x to y. Checks for compatiblity when y is not empty. * Allow missing dim in both x and y (as 0). @@ -121,6 +125,21 @@ inline bool type_assign(int *y, const int& x) { return true; } +/*! + * \brief Assign x to y. Checks for compatiblity when y is not -1. + * \param y target type. + * \param x source type. + * \return whether x and y are compatible. + */ +inline bool storage_type_assign(int *y, const int& x) { + if (*y == kUndefinedStorage) { + *y = x; + return true; + } else if (*y != x && x != kUndefinedStorage) { + return false; + } + return true; +} /*! * \brief macro assign shape to out if out is unknown otherwise check consistency * Use macro so we can see the error file more clearly diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index f04f6f8030b2..f6c60950c938 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -10,6 +10,7 @@ #include #include #include +#include #include "../mshadow_op.h" #include "../elemwise_op_common.h" @@ -35,7 +36,7 @@ void BinaryCompute(const nnvm::NodeAttrs& attrs, // TODO(haibin) This is an inefficient temporary implementation // Binary Compute between two row-sparse ndarray template -void BinaryComputeExRsRs(const nnvm::NodeAttrs& attrs, +void BinaryComputeRspRsp(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, @@ -52,7 +53,6 @@ void BinaryComputeExRsRs(const nnvm::NodeAttrs& attrs, auto num_rows_r = nd_r.aux_shape(rowsparse::kIdx)[0]; // This is (roughly) the number of result rows output.CheckAndAlloc({TShape({num_rows_l + num_rows_r})}); - // LOG(INFO) << "BinaryComputeExRsRs" << output.aux_shape(rowsparse::kIdx)[0]; // Indices mshadow::Stream *s = ctx.get_stream(); MSHADOW_TYPE_SWITCH(output.dtype(), DType, { @@ -115,14 +115,21 @@ void BinaryComputeEx(const nnvm::NodeAttrs& attrs, using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - // If any input is dense, fallback to FCompute - if (common::HasDefaultStorage(inputs)) { - FComputeExFallback(attrs, ctx, inputs, req, outputs, BinaryCompute); + CHECK_EQ(inputs.size(), 2); + CHECK_EQ(outputs.size(), 1); + if (typeid(OP) == typeid(mshadow::op::plus)) { + // If any input is dense, fallback to FCompute + if (common::ContainsDefaultStorage(inputs)) { + FComputeExFallback(attrs, ctx, inputs, req, outputs, BinaryCompute); + return; + } + CHECK_EQ(inputs[0].storage_type(), 
kRowSparseStorage) << "Sparse type not supported yet"; + CHECK_EQ(inputs[1].storage_type(), kRowSparseStorage) << "Sparse type not supported yet"; + BinaryComputeRspRsp(attrs, ctx, inputs, req, outputs); return; + } else { + LOG(FATAL) << "Not implemented"; } - // Call RsRs function - CHECK_EQ(inputs[0].storage_type(), kRowSparseStorage) << "Sparse type not supported yet"; - BinaryComputeExRsRs(attrs, ctx, inputs, req, outputs); } template @@ -145,7 +152,7 @@ void BinaryBackwardUseNone(const nnvm::NodeAttrs& attrs, // Only implemented for _backward_add for now template -void BinaryBackwardUseNoneEx(const nnvm::NodeAttrs& attrs, +void BinaryBackwardUseNoneRsp(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, @@ -153,12 +160,9 @@ void BinaryBackwardUseNoneEx(const nnvm::NodeAttrs& attrs, using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - if (inputs[0].storage_type() == kDefaultStorage) { - LOG(FATAL) << "BinaryBackwardUseNoneEx fallback not implemented yet"; - } - // LOG(INFO) << "BinaryBackwardUseNoneEx"; - // The following code assumes LOP == mshadow_op::identity == ROP CHECK_EQ(inputs[0].storage_type(), kRowSparseStorage); + CHECK(typeid(LOP) == typeid(mshadow_op::identity)); + CHECK(typeid(ROP) == typeid(mshadow_op::identity)); TShape shape = inputs[0].aux_shape(rowsparse::kIdx); outputs[0].CheckAndAlloc({shape}); outputs[1].CheckAndAlloc({shape}); @@ -177,6 +181,20 @@ void BinaryBackwardUseNoneEx(const nnvm::NodeAttrs& attrs, }); }); } +// Only implemented for _backward_add for now +template +void BinaryBackwardUseNoneEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + auto stype = inputs[0].storage_type(); + CHECK_EQ(stype, kRowSparseStorage) << "Not implemented yet"; + BinaryBackwardUseNoneRsp(attrs, ctx, inputs, req, outputs); +} template void BinaryBackwardUseOut(const nnvm::NodeAttrs& attrs, diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index 8edfacc66865..7f60cb455c21 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -11,7 +11,7 @@ namespace op { MXNET_OPERATOR_REGISTER_BINARY(elemwise_add) .add_alias("_add").add_alias("_plus").add_alias("_Plus") .set_attr("FCompute", BinaryCompute) -.set_attr("FComputeEx", BinaryComputeEx) +.set_attr(FCOMP_EX_CPU, BinaryComputeEx) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_add"}) .set_attr("FInferStorageType", ElemwiseStorageType<2, 1>); @@ -30,7 +30,7 @@ NNVM_REGISTER_OP(_backward_add) }) .set_attr("FCompute", BinaryBackwardUseNone) -.set_attr("FComputeEx", +.set_attr(FCOMP_EX_CPU, BinaryBackwardUseNoneEx) .set_attr("FInferStorageType", ElemwiseStorageType<1, 2>); diff --git a/src/operator/tensor/elemwise_unary_op.cc b/src/operator/tensor/elemwise_unary_op.cc index e7c38644380f..57414700d55a 100644 --- a/src/operator/tensor/elemwise_unary_op.cc +++ b/src/operator/tensor/elemwise_unary_op.cc @@ -59,7 +59,7 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs) .set_attr("FIgnoreInputs", [](const NodeAttrs& attrs) { return std::vector(1, 1); }) .set_attr("FCompute", IdentityCompute) -.set_attr("FComputeEx", IdentityComputeEx) +.set_attr(FCOMP_EX_CPU, IdentityComputeEx) .set_attr("FInferShape", ElemwiseShape<2, 1>) 
.set_attr("FInferStorageType", IdentityAttrLikeRhsStorageType) .set_attr( @@ -103,7 +103,9 @@ NNVM_REGISTER_OP(_backward_cast) .set_attr("TIsBackward", true) .set_attr("FCompute", CastCompute); -// TODO(haibin) declare backward op for cast storage. Also add FCompute(identity compute) +// TODO(haibin) declare backward op for cast storage +// Only support cast to default storage now +// Other types require add infer_storage type pass NNVM_REGISTER_OP(cast_storage) .describe(R"code(Casts tensor storage type to the new type. )code" ADD_FILELINE) @@ -112,7 +114,10 @@ NNVM_REGISTER_OP(cast_storage) .set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseType<1, 1>) -.set_attr("FComputeEx", CastStorageComputeEx) +.set_attr("FCompute", IdentityCompute) +// _backward pass +// .set_attr("FGradient", ElemwiseGradUseNone{"negative"}) +.set_attr(FCOMP_EX_CPU, CastStorageComputeEx) .add_argument("data", "NDArray-or-Symbol", "The input.") .add_arguments(CastStorageParam::__FIELDS__()); diff --git a/src/operator/tensor/elemwise_unary_op.cu b/src/operator/tensor/elemwise_unary_op.cu index c8ce17757990..c5a72b4e8c4f 100644 --- a/src/operator/tensor/elemwise_unary_op.cu +++ b/src/operator/tensor/elemwise_unary_op.cu @@ -31,6 +31,10 @@ NNVM_REGISTER_OP(Cast) NNVM_REGISTER_OP(_backward_cast) .set_attr("FCompute", CastCompute); +NNVM_REGISTER_OP(cast_storage) +.set_attr("FCompute", IdentityCompute) +.set_attr(FCOMP_EX_GPU, CastStorageComputeEx); + // negative NNVM_REGISTER_OP(negative) .set_attr("FCompute", UnaryCompute); diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 19587c9ee8ad..be2d086702b5 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -48,7 +48,6 @@ void IdentityCompute(const nnvm::NodeAttrs& attrs, using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - // LOG(INFO) << "IdentityCompute"; if (req[0] == kNullOp) return; if (req[0] == kWriteInplace) { CHECK_EQ(inputs[0].dptr_, outputs[0].dptr_); return; @@ -59,10 +58,8 @@ void IdentityCompute(const nnvm::NodeAttrs& attrs, }); } -// FIXME the index is hard coded for _identity_with_attr_like_rhs op -// Only implemented for row_sparse for now template -void IdentityComputeEx(const nnvm::NodeAttrs& attrs, +void IdentityComputeRsp(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, @@ -70,17 +67,13 @@ void IdentityComputeEx(const nnvm::NodeAttrs& attrs, using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - // LOG(INFO) << "IdentityComputeEx"; NDArrayStorageType storage_type = inputs[1].storage_type(); - CHECK_EQ(storage_type, kRowSparseStorage) - << "storage type " << storage_type << " not supported yet"; + CHECK_EQ(storage_type, kRowSparseStorage); if (req[0] == kNullOp) { LOG(FATAL) << "kNullOp in IdentityComputeEx not supported yet"; - return; } if (req[0] == kWriteInplace) { LOG(FATAL) << "kWriteInplace for sparse storage not supported yet"; - // CHECK_EQ(inputs[0].dptr_, outputs[0].dptr_); return; } TShape shape = inputs[1].aux_shape(rowsparse::kIdx); if (shape.ndim() == 0) return; @@ -97,6 +90,23 @@ void IdentityComputeEx(const nnvm::NodeAttrs& attrs, }); } +// FIXME the index is hard coded for _identity_with_attr_like_rhs op +template +void IdentityComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const 
std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(inputs.size(), 2); + CHECK_EQ(outputs.size(), 1); + Stream *s = ctx.get_stream(); + NDArrayStorageType stype = inputs[1].storage_type(); + CHECK_EQ(stype, kRowSparseStorage) << "Not implemented yet"; + IdentityComputeRsp(attrs, ctx, inputs, req, outputs); +} + struct CastParam : public dmlc::Parameter { // use int for enumeration int dtype; @@ -150,7 +160,7 @@ struct CastStorageParam : public dmlc::Parameter { }; template -void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, +void CastStorageComputeRspDns(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, @@ -162,9 +172,11 @@ void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 1); auto out = outputs[0]; auto in = inputs[0]; - CHECK(in.storage_type() == kRowSparseStorage); + auto stype = in.storage_type(); + CHECK_EQ(stype, kRowSparseStorage); + CHECK_EQ(out.storage_type(), kDefaultStorage); MSHADOW_TYPE_SWITCH(in.dtype(), DType, { - MSHADOW_TYPE_SWITCH(in.aux_type(rowsparse::kIdx), AuxType, { + MSHADOW_TYPE_SWITCH(in.aux_type(rowsparse::kIdx), IType, { // Fill in zeros. SLOW out.data().FlatTo1D(s) = 0; // data() is not empty @@ -173,7 +185,7 @@ void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, auto in_data = in.data().FlatTo2D(s); auto out_data = out.data().FlatTo2D(s); auto num_rows = in.aux_shape(rowsparse::kIdx)[0]; - auto in_idx = in.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto in_idx = in.aux_data(rowsparse::kIdx).FlatTo1D(s); for (size_t i = 0; i < num_rows; i += 1) { mshadow::Copy(out_data[in_idx[i]], in_data[i], s); } @@ -182,6 +194,25 @@ void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, }); } +template +void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 1); + CHECK_EQ(outputs.size(), 1); + auto stype = inputs[0].storage_type(); + if (stype == kRowSparseStorage) { + CastStorageComputeRspDns(attrs, ctx, inputs, req, outputs); + } else { + LOG(FATAL) << "Not implemented"; + } +} + #define MXNET_OPERATOR_REGISTER_UNARY(name) \ NNVM_REGISTER_OP(name) \ .set_num_inputs(1) \ diff --git a/src/operator/tensor/init_op.cc b/src/operator/tensor/init_op.cc index b105e950adb5..b091cbca2d9f 100644 --- a/src/operator/tensor/init_op.cc +++ b/src/operator/tensor/init_op.cc @@ -21,7 +21,7 @@ NNVM_REGISTER_OP(_zeros) .set_attr("FInferShape", InitShape) .set_attr("FInferType", InitType) .set_attr("FCompute", FillCompute) -.set_attr("FComputeEx", FillComputeEx) +.set_attr(FCOMP_EX_CPU, FillComputeEx) .add_arguments(InitOpParam::__FIELDS__()); NNVM_REGISTER_OP(_ones) diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 5f873dc21a89..1c96c1f2cf5f 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -111,24 +111,6 @@ inline bool InitType(const nnvm::NodeAttrs& attrs, return true; } - -template -void FillComputeEx(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mshadow::expr; - Stream *s = ctx.get_stream(); - if (value == 0 && outputs[0].storage_type() != kDefaultStorage) { - return; - } - CHECK_EQ(value, 0) << "Not implemented yet"; - 
CHECK_EQ(inputs.size(), 0); - CHECK_NE(outputs[0].storage_type(), kDefaultStorage); -} - template void FillCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -144,6 +126,20 @@ void FillCompute(const nnvm::NodeAttrs& attrs, }); } +template +void FillComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(outputs.size(), 1); + CHECK_EQ(inputs.size(), 0); + auto stype = outputs[0].storage_type(); + CHECK_EQ(value, 0) << "Not implemented yet"; +} template void RangeCompute(const nnvm::NodeAttrs& attrs, diff --git a/tests/cpp/ndarray_test.cc b/tests/cpp/ndarray_test.cc index e2aacfa9987e..d3f73ccde40d 100644 --- a/tests/cpp/ndarray_test.cc +++ b/tests/cpp/ndarray_test.cc @@ -52,21 +52,21 @@ NDArray Convert(NDArrayStorageType type, NDArray src) { // TODO provide type in attrs, which is empty now OpContext op_ctx; op_ctx.run_ctx = ctx; - std::vector inputs({src}), outputs({converted}); - op::CastStorageComputeEx({}, op_ctx, inputs, {}, outputs); + if (src.storage_type() == kRowSparseStorage) { + std::vector inputs({src}), outputs({converted}); + op::CastStorageComputeEx({}, op_ctx, inputs, {}, outputs); + } else if (src.storage_type() == kDefaultStorage) { + std::vector inputs({src.data()}), outputs({converted.data()}); + op::IdentityCompute({}, op_ctx, inputs, {kWriteTo}, outputs); + } else { + LOG(FATAL) << "unsupported storage type"; + } }, src.ctx(), {src.var()}, {converted.var()}, FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); converted.WaitToRead(); return converted; } -void BasicTest() { - Context ctx; - TShape shape({1, 2}); - NDArray nd(shape, ctx, false); - EXPECT_NE(nd.data().dptr_, nullptr); -} - void BinaryDenseSparseTest() { Context ctx = Context::CPU(); @@ -126,8 +126,6 @@ void BinaryRsRsTest() { NDArray input_nd0(raw_data0, {index0}, ctx, kRowSparseStorage, data_shape); NDArray input_nd1(raw_data1, {index1}, ctx, kRowSparseStorage, data_shape); - CheckDataRegion(input_nd0.data(), raw_data0.data()); - CheckDataRegion(input_nd1.data(), raw_data1.data()); TShape output_shape({4, 2}); NDArray output(kRowSparseStorage, output_shape, ctx); @@ -142,13 +140,16 @@ void BinaryRsRsTest() { inputs.push_back(input_nd0); inputs.push_back(input_nd1); outputs.push_back(output); - op::BinaryComputeExRsRs({}, op_ctx, inputs, req, outputs); + op::BinaryComputeRspRsp({}, op_ctx, inputs, req, outputs); }, input_nd0.ctx(), const_vars, {output.var()}, FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + // Check the data region of output ndarray NDArray dense_output = GetDenseND(output_shape, ctx, {15, 15, 10, 10, 5, 5, 0, 0}); NDArray copy = Convert(kDefaultStorage, output); + CheckDataRegion(input_nd0.data(), raw_data0.data()); + CheckDataRegion(input_nd1.data(), raw_data1.data()); CheckDataRegion(dense_output.data(), copy.data()); } @@ -156,18 +157,17 @@ void InferElemwiseStorageTest() { nnvm::NodeAttrs attrs; attrs.name = "Test op"; std::vector in_attrs({kRowSparseStorage, kDefaultStorage}); - std::vector out_attrs({-1}); + std::vector out_attrs({kUndefinedStorage}); op::ElemwiseStorageType<2, 1>(attrs, &in_attrs, &out_attrs); EXPECT_EQ(out_attrs[0], kDefaultStorage); in_attrs = {kDefaultStorage, kRowSparseStorage}; - out_attrs = {-1}; + out_attrs = {kUndefinedStorage}; op::ElemwiseStorageType<2, 1>(attrs, &in_attrs, &out_attrs); EXPECT_EQ(out_attrs[0], kDefaultStorage); } TEST(NDArray, 
basics) { - BasicTest(); BinaryRsRsTest(); //Wait for all operations to finish Engine::Get()->WaitForAll(); @@ -179,9 +179,8 @@ void TestDenseToDenseConversion() { Context ctx; TShape shape({2, 2}); NDArray nd = GetDenseND(shape, ctx, {1, 2, 3, 10}); - // TODO dense to dense conversion is not implemented yet - //auto nd_copy = Convert(kDefaultStorage, nd); - //CheckDataRegion(nd_copy.data(), nd.data()); + auto nd_copy = Convert(kDefaultStorage, nd); + CheckDataRegion(nd_copy.data(), nd.data()); } // sparse to dense conversion diff --git a/tests/python/unittest/test_multi_device_exec.py b/tests/python/unittest/test_multi_device_exec.py index a82141b71592..f80f40ba7c32 100644 --- a/tests/python/unittest/test_multi_device_exec.py +++ b/tests/python/unittest/test_multi_device_exec.py @@ -49,18 +49,12 @@ def check_ctx_group_sparse(mode='dense_sparse'): data2 = mx.symbol.Variable('data2') elif mode == 'dense_sparse': data1 = mx.symbol.Variable('data1') - #data1 = mx.symbol.Variable('data1', storage_type='row_sparse') data2 = mx.symbol.Variable('data2', storage_type='row_sparse') mlp = mx.symbol.elemwise_add(data1, data2, name='plus') texec = mlp.simple_bind(mx.cpu(0), data1=(3,2), data2=(3,2)) - print("Done simple_bind") output = texec.forward() - print(output[0].asnumpy()) - for arr, name in zip(texec.arg_arrays, mlp.list_arguments()): - pass - ''' This tests the simple bind function ''' diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index 1b1b3ac2c13b..0ce4d3aa0b7a 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -101,9 +101,22 @@ def test_elemwise_add_multiple_stages(): assert_almost_equal(exec_test.outputs[0].asnumpy(), sp_np1 + sp_np2 + ds_np) exec_test.backward(out_grads = exec_test.outputs) assert_almost_equal(arr_grads[0].asnumpy(), arr_grads[1].asnumpy()) +''' +def test_cast_storage(): + dns_np = np.array([[0, 0], [5, 10], [0, 0], [0, 0], [0, 0]]) + val = np.array([5, 10]) + idx = np.array([1]) + b = mx.nd.array(idx, dtype=np.int32) + sp_nd = mx.sparse_nd.array(val, [b], 'row_sparse', (5,2)) + var = mx.symbol.Variable('sp_data', storage_type='row_sparse') + # 1 for row_storage type + test = mx.symbol.cast_storage(var, storage_type=1) + check_symbolic_forward(test, {'sp_data':sp_nd}, [dns_np]) +''' if __name__ == '__main__': test_elemwise_add_dense() test_elemwise_add_dense_sparse() test_elemwise_add_sparse_sparse() test_elemwise_add_multiple_stages() + test_cast_storage()
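The Python tests above exercise the new dispatch path end to end. As a rough C++ sketch of how the pieces added in common/utils.h, exec_pass.h, and attach_op_execs_pass.cc fit together: only ContainsNonDefaultStorage, GetFCompute, and GetFComputeEx come from the patch, while the helper name, the include path, and the wrapper function itself are assumptions made for illustration.

// Sketch only: decide whether a node should run an FComputeEx kernel.
// Assumed to live under src/ so the relative include below resolves.
#include <vector>
#include <nnvm/op.h>
#include <mxnet/base.h>
#include <mxnet/op_attr_types.h>
#include "common/utils.h"  // path assumed; the patch uses "../common/utils.h"

namespace mxnet {

// Hypothetical helper, not part of the patch: mirrors the storage-type
// dispatch logic that SetShapeType and GraphExecutor::InitGraph now share.
inline FComputeEx PickFComputeEx(const nnvm::Op *op, const Context &ctx,
                                 const nnvm::StorageTypeVector &in_stypes,
                                 const nnvm::StorageTypeVector &out_stypes) {
  bool non_default = common::ContainsNonDefaultStorage(in_stypes) ||
                     common::ContainsNonDefaultStorage(out_stypes);
  // kNonDefaultStorage (-2) is the sentinel added to exec_pass.h; any value
  // other than kDefaultStorage makes GetFComputeEx consult the FComputeEx
  // registrations instead of returning nullptr immediately.
  int dispatch_stype = non_default ? -2 : kDefaultStorage;
  // nullptr for pure dense nodes or ops without an FComputeEx kernel, in
  // which case the executor falls back to the dense FCompute registration.
  return common::GetFComputeEx(op, ctx, dispatch_stype);
}

}  // namespace mxnet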