From 9db62c0b314d221db323f706173943135e861a78 Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Sun, 30 Apr 2017 06:31:04 +0800 Subject: [PATCH] merge with 38f7c5584016e92ba1e0ee1b00ea6632740f67ce compiles on GPU update check alloc: Checkpoint. Pass elem-sum gpu test bug fix for copyfromto. sparse sgd test pass on gpu inefficient implementation for csr copy update submodule fix lint Simple bind with infer storage type (#32) * Symbol binding for sparse tensor development. (#31) * Initial checkin * Add init functions for simple bind in graph_executor * Add simple_bind c_api * Add simple bind c-api * Assign zeros to in_args, arg_grads, and aux_states * Add simple_bind2 python interface * Fix python interface bugs * Interface changes * Fix * Fix core dump * Add bind_ith_exec c_api * Change simple_bind2 * Fix seg fault * Finish simple_bind * Change _bind_ith_exec * Refactor simple_bind initialization flow for bind * Consolidate bind and simple_bind graph init flow * Fix bug * Clean up * Add comments * Clean up * Clean up * Minor correction * Rename APIs in graph executor * Refactor * Rebase * Delete deprecated functions * Move more front-end work to backend * Bug fix * Fix failed tests * Minor fix * Fix lint * Fix lint * Revert unnecessary changes * Revert * Revert * Clean up * Fix lint Conflicts: python/mxnet/symbol.py src/executor/graph_executor.cc * Add inferstorage to graph executor * re-enable tests for sparse embedding with simple_bind * type switch fix in sparse embedding" ; change `default` to `default_storage` for cast storage op (#33) * change default to default_storage * disable cpp test build temporarily attempt to fix windows build error, and fix lint (#34) update nnvm submodule (#37) Scipy build (#38) * update nnvm submodule * add scipy pip install for dockerfile Python3 unit tests (#39) * change xrange to range for python3 compatiblity" * remove more xrange from tests replace long with int for python3 (#40) fix the rest of TShape constructor errors (#41) fix lint (#42) fix wrong usage of mshadow::Shape1" (#43) implementation for Csr slice on cpu (#36) * CPU implementation for CSR remove seg_len from csr slice add some docs for slice csr change indptr, values, etc to be private member bug fix in sparse embedding update nnvm submoduel fix lint update unit test for sparse nd" * add const for SliceCsrIndPtr kernel Fix sparse dot according to the new RSP definition (#35) * Fix csr dot dns * Fix sparse dot * Add fallback and test cases for dot(csr, dns)=dns * Add int type switch * Fix * Fix * Fix update mshadow submodule (#44) Fix dns to rsp (#46) fix lint (#47) add runtime storage fallback detection" (#48) * add runtime storage fallback detection" * replace cast storage ex with cast storage impl Fm example (#45) * update csr slice logic to avoid confusion. add more exmaples. * add hint to module.update * more testcases(fallback) for sparse_nd * add to_csr() and to_rsp() method. More unit test (fallback now) * add fm test. fix lint * register sparse sgd under Optim.SGD * update dmlc-core submoduel * change indptr to _indptr temporarily. add const ref to fname --- Jenkinsfile | 6 +- Makefile | 3 +- dmlc-core | 2 +- include/mxnet/c_api.h | 124 +++ include/mxnet/executor.h | 33 + include/mxnet/ndarray.h | 517 +++++++++++- include/mxnet/op_attr_types.h | 15 +- include/mxnet/storage.h | 4 +- mshadow | 2 +- nnvm | 2 +- python/mxnet/__init__.py | 2 + python/mxnet/contrib/autograd.py | 2 + python/mxnet/executor.py | 14 +- python/mxnet/kvstore.py | 7 +- python/mxnet/model.py | 39 +- python/mxnet/module/__init__.pyc | Bin 0 -> 621 bytes .../__pycache__/__init__.cpython-34.pyc | Bin 0 -> 560 bytes .../__pycache__/base_module.cpython-34.pyc | Bin 0 -> 37653 bytes .../bucketing_module.cpython-34.pyc | Bin 0 -> 16988 bytes .../__pycache__/executor_group.cpython-34.pyc | Bin 0 -> 25082 bytes .../module/__pycache__/module.cpython-34.pyc | Bin 0 -> 24713 bytes .../__pycache__/python_module.cpython-34.pyc | Bin 0 -> 13826 bytes .../sequential_module.cpython-34.pyc | Bin 0 -> 15162 bytes python/mxnet/module/base_module.py | 12 +- python/mxnet/module/base_module.pyc | Bin 0 -> 40120 bytes python/mxnet/module/bucketing_module.py | 4 +- python/mxnet/module/bucketing_module.pyc | Bin 0 -> 18582 bytes python/mxnet/module/executor_group.py | 81 +- python/mxnet/module/executor_group.pyc | Bin 0 -> 22701 bytes python/mxnet/module/module.py | 12 +- python/mxnet/module/module.pyc | Bin 0 -> 26694 bytes python/mxnet/module/python_module.py | 2 +- python/mxnet/module/python_module.pyc | Bin 0 -> 14763 bytes python/mxnet/module/sequential_module.py | 4 +- python/mxnet/module/sequential_module.pyc | Bin 0 -> 15831 bytes python/mxnet/ndarray.py | 71 +- python/mxnet/sparse_ndarray.py | 641 +++++++++++++++ python/mxnet/symbol.py | 359 ++++++-- python/mxnet/test_utils.py | 59 +- src/c_api/c_api.cc | 81 ++ src/c_api/c_api_common.h | 2 + src/c_api/c_api_executor.cc | 326 ++++++++ src/c_api/c_api_ndarray.cc | 142 +++- src/c_api/c_api_symbolic.cc | 53 +- src/common/utils.h | 95 +++ src/executor/attach_op_execs_pass.cc | 133 ++- src/executor/exec_pass.h | 10 +- src/executor/graph_executor.cc | 773 +++++++++++++++--- src/executor/graph_executor.h | 90 +- src/executor/inplace_addto_detect_pass.cc | 2 + src/ndarray/ndarray.cc | 106 ++- src/ndarray/ndarray_function-inl.h | 61 +- src/operator/elemwise_op_common.h | 76 ++ src/operator/operator_common.h | 19 + src/operator/optimizer_op-inl.h | 163 ++++ src/operator/optimizer_op.cc | 9 +- src/operator/optimizer_op.cu | 6 +- .../elemwise_binary_broadcast_op_basic.cc | 1 + src/operator/tensor/elemwise_binary_op.h | 162 +++- .../tensor/elemwise_binary_op_basic.cc | 9 +- .../tensor/elemwise_binary_op_basic.cu | 7 +- src/operator/tensor/elemwise_unary_op.cc | 23 + src/operator/tensor/elemwise_unary_op.cu | 8 +- src/operator/tensor/elemwise_unary_op.h | 446 +++++++++- src/operator/tensor/indexing_op.cc | 34 + src/operator/tensor/indexing_op.h | 128 +++ src/operator/tensor/init_op.cc | 1 + src/operator/tensor/init_op.cu | 3 +- src/operator/tensor/init_op.h | 48 +- src/operator/tensor/matrix_op-inl.h | 299 +++++++ src/operator/tensor/matrix_op.cc | 17 + src/operator/tensor/matrix_op.cu | 7 +- .../ci_build/install/ubuntu_install_python.sh | 4 +- tests/cpp/engine/threaded_engine_test.cc | 8 +- tests/cpp/ndarray_test.cc | 245 ++++++ tests/cpp/test_utils.h | 105 +++ tests/cpp/unittest.mk | 2 +- tests/python/unittest/test_executor.py | 2 +- tests/python/unittest/test_infer_shape.py | 32 + tests/python/unittest/test_module.py | 69 ++ .../python/unittest/test_multi_device_exec.py | 31 + tests/python/unittest/test_ndarray.py | 1 + tests/python/unittest/test_operator.py | 1 - tests/python/unittest/test_optimizer.py | 115 ++- tests/python/unittest/test_sparse_ndarray.py | 273 +++++++ tests/python/unittest/test_sparse_operator.py | 198 +++++ 86 files changed, 5911 insertions(+), 532 deletions(-) create mode 100644 python/mxnet/module/__init__.pyc create mode 100644 python/mxnet/module/__pycache__/__init__.cpython-34.pyc create mode 100644 python/mxnet/module/__pycache__/base_module.cpython-34.pyc create mode 100644 python/mxnet/module/__pycache__/bucketing_module.cpython-34.pyc create mode 100644 python/mxnet/module/__pycache__/executor_group.cpython-34.pyc create mode 100644 python/mxnet/module/__pycache__/module.cpython-34.pyc create mode 100644 python/mxnet/module/__pycache__/python_module.cpython-34.pyc create mode 100644 python/mxnet/module/__pycache__/sequential_module.cpython-34.pyc create mode 100644 python/mxnet/module/base_module.pyc create mode 100644 python/mxnet/module/bucketing_module.pyc create mode 100644 python/mxnet/module/executor_group.pyc create mode 100644 python/mxnet/module/module.pyc create mode 100644 python/mxnet/module/python_module.pyc create mode 100644 python/mxnet/module/sequential_module.pyc create mode 100644 python/mxnet/sparse_ndarray.py create mode 100644 tests/cpp/ndarray_test.cc create mode 100644 tests/cpp/test_utils.h create mode 100644 tests/python/unittest/test_sparse_ndarray.py create mode 100644 tests/python/unittest/test_sparse_operator.py diff --git a/Jenkinsfile b/Jenkinsfile index 2f4406856288..b0bc2626266a 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -205,9 +205,9 @@ del /Q *.7z // Python unittest for CPU def python_ut(docker_type) { timeout(time: max_time, unit: 'MINUTES') { - sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/unittest" + sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/unittest" sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/unittest" - sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/train" + sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/train" } } @@ -215,7 +215,7 @@ def python_ut(docker_type) { // both CPU and GPU def python_gpu_ut(docker_type) { timeout(time: max_time, unit: 'MINUTES') { - sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/gpu" + sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/gpu" sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/gpu" } } diff --git a/Makefile b/Makefile index 12da6419873e..99fe4e96da89 100644 --- a/Makefile +++ b/Makefile @@ -44,8 +44,9 @@ ifeq ($(DEV), 1) endif # CFLAGS for debug +# FIXME(haibin) temporarily turn on -DDMLC_LOG_FATAL_THROW for debug ifeq ($(DEBUG), 1) - CFLAGS += -g -O0 + CFLAGS += -g -O0 -DDMLC_LOG_FATAL_THROW=1 else CFLAGS += -O3 -DNDEBUG=1 endif diff --git a/dmlc-core b/dmlc-core index a6c5701219e6..fc66c6241f02 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit a6c5701219e635fea808d264aefc5b03c3aec314 +Subproject commit fc66c6241f0278c619ed3c25b895bda0e7de99fd diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 1b112abe2ba9..c8c8afd7522b 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -244,6 +244,38 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape, int delay_alloc, int dtype, NDArrayHandle *out); + + +/*! + * \brief create an empty sparse NDArray with specified shape and data type + * \param storage_type the storage type of the ndarray + * \param shape the pointer to the shape + * \param ndim the dimension of the shape + * \param dev_type device type, specify device we want to take + * \param dev_id the device id of the specific device + * \param delay_alloc whether to delay allocation until + * the narray is first mutated + * \param dtype data type of created array + * \param num_aux the number of aux data to support this ndarray + * \param aux_type data type of the aux data for the created array + * \param aux_ndims the dimension of the shapes of aux data + * \param aux_shape the shapes of aux data + * \param out the returning handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type, + const mx_uint *shape, + mx_uint ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + mx_uint num_aux, + int *aux_type, + mx_uint *aux_ndims, + const mx_uint *aux_shape, + NDArrayHandle *out); + /*! * \brief create a NDArray handle that is loaded from raw bytes. * \param buf the head of the raw bytes @@ -356,6 +388,19 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle, mx_uint slice_begin, mx_uint slice_end, NDArrayHandle *out); + +/*! + * \brief Slice the NDArray with non-default storage along axis 0. + * \param handle the handle to the NDArray + * \param slice_begin The beginning index of slice + * \param slice_end The ending index of slice + * \param out The NDArrayHandle of sliced NDArray + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNDArraySliceEx(NDArrayHandle handle, + mx_uint slice_begin, + mx_uint slice_end, + NDArrayHandle out); /*! * \brief Index the NDArray along axis 0. * \param handle the handle to the NDArray @@ -366,6 +411,13 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle, MXNET_DLL int MXNDArrayAt(NDArrayHandle handle, mx_uint idx, NDArrayHandle *out); + +/*! + * \brief get the storage type of the array + */ +MXNET_DLL int MXNDArrayGetStorageType(NDArrayHandle handle, + int *out_storage_type); + /*! * \brief Reshape the NDArray. * \param handle the handle to the narray @@ -404,6 +456,26 @@ MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle, */ MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle, int *out_dtype); + +/*! + * \brief get the type of the ith aux data in NDArray + * \param handle the handle to the narray + * \param i the index of the aux data + * \param out_type pointer holder to get type of aux data + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle, + mx_uint i, + int *out_type); + +// Get the ith aux data blob wrapped in an NDArray +MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle, + mx_uint i, + NDArrayHandle *out); + +// Get the data blob wrapped in an NDArray +MXNET_DLL int MXNDArrayGetDataNDArray(NDArrayHandle handle, + NDArrayHandle *out); /*! * \brief get the context of the NDArray * \param handle the handle to the narray @@ -935,6 +1007,25 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, mx_uint *aux_type_size, const int **aux_type_data, int *complete); + + + + +/*! + * \brief infer storage type of unknown input types given the known one. + */ +MXNET_DLL int MXSymbolInferStorageType(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const int *arg_storage_type_data, + mx_uint *in_storage_type_size, + const int **in_storage_type_data, + mx_uint *out_storage_type_size, + const int **out_storage_type_data, + mx_uint *aux_storage_type_size, + const int **aux_storage_type_data, + int *complete); + //-------------------------------------------- // Part 4: Executor interface //-------------------------------------------- @@ -1081,6 +1172,39 @@ MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle, NDArrayHandle *aux_states, ExecutorHandle shared_exec, ExecutorHandle *out); + +MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const mx_uint num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const mx_uint provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const mx_uint num_provided_arg_shapes, + const char** provided_arg_shape_names, + const mx_uint* provided_arg_shape_data, + const mx_uint* provided_arg_shape_idx, + const mx_uint num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const mx_uint num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, + const mx_uint num_shared_arg_names, + const char** shared_arg_name_list, + mx_uint* shared_buffer_len, + const char*** shared_buffer_name_list, + NDArrayHandle** shared_buffer_handle_list, + mx_uint* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + mx_uint* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out); /*! * \brief set a call back to notify the completion of operation */ diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h index cf71666826ab..5856b87cf859 100644 --- a/include/mxnet/executor.h +++ b/include/mxnet/executor.h @@ -69,6 +69,21 @@ class Executor { * \return array of outputs in the executor. */ virtual const std::vector &outputs() const = 0; + /*! + * \brief get input argument map, key is arg name, value is arg's NDArray. + * \return input argument map in the executor. + */ + virtual const std::unordered_map& in_arg_map() const = 0; + /*! + * \brief get input argument graident map, key is arg name, value is gradient's NDArray. + * \return input argument gradient map in the executor. + */ + virtual const std::unordered_map& arg_grad_map() const = 0; + /*! + * \brief get aux state map, key is arg name, value is aux state's NDArray. + * \return aux state map in the executor. + */ + virtual const std::unordered_map& aux_state_map() const = 0; /*! * \brief Create an operator by bind symbol with context and arguments. * If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp. @@ -91,6 +106,24 @@ class Executor { const std::vector &grad_req_type, const std::vector &aux_states, Executor* shared_exec = NULL); + + static Executor* SimpleBind(nnvm::Symbol symbol, + const Context& default_ctx, + const std::map& group2ctx, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::unordered_map& arg_shape_map, + const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, + const std::vector& grad_req_types, + const std::unordered_set& param_names, + std::vector* in_args, + std::vector* arg_grads, + std::vector* aux_states, + std::unordered_map* + shared_data_arrays = nullptr, + Executor* shared_exec = nullptr); /*! * \brief the prototype of user-defined monitor callback */ diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index ea38909d07f1..d01352e795e4 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -28,8 +28,22 @@ #endif namespace mxnet { +// forward declarations +class NDArray; + +namespace op { +template +void FillZerosRspImpl(mshadow::Stream *s, NDArray *dst); + +template +void CastStorageComputeImpl(mshadow::Stream *s, const NDArray& input, const NDArray& output); +}; + +namespace ndarray { +template +void Copy(const TBlob &from, TBlob *to, Context from_ctx, Context to_ctx, RunContext ctx); +}; -// forward declaration namespace autograd { class AGNode; @@ -52,6 +66,27 @@ class AGNodeEntry { class AutogradRuntime; } // namespace autograd +// enum for storage types +#define CSR_IND_PTR_TYPE mshadow::kInt32 +#define CSR_IDX_DTYPE mshadow::kInt32 +#define ROW_SPARSE_IDX_TYPE mshadow::kInt32 +// FIXME int64_t is not available mshadow +namespace csr { +enum CSRAuxType {kIndPtr, kIdx}; +} + +namespace rowsparse { +enum RowSparseAuxType {kIdx}; +} + +enum NDArrayStorageType { + kUndefinedStorage = -1, // undefined storage + kDefaultStorage, // dense + kRowSparseStorage, // row sparse + kCSRStorage, // csr +}; + + /*! * \brief ndarray interface */ @@ -72,10 +107,55 @@ class NDArray { */ NDArray(const TShape &shape, Context ctx, bool delay_alloc = false, int dtype = mshadow::default_type_flag) - : ptr_(std::make_shared(shape.Size(), ctx, delay_alloc, dtype)), + : ptr_(std::make_shared(shape, ctx, delay_alloc, dtype)), shape_(shape), offset_(0), dtype_(dtype), entry_({nullptr, 0, 0}) { #if MKL_EXPERIMENTAL == 1 Mkl_mem_ = std::make_shared(); +#endif + } + /*! \brief constructor for NDArray with storage type + */ + NDArray(const NDArrayStorageType storage_type, const TShape &shape, Context ctx, + bool delay_alloc = true, int dtype = mshadow::default_type_flag, + std::vector aux_types = {}, std::vector aux_shapes = {}, + TShape storage_shape = TShape(mshadow::Shape1(0))) + : shape_(shape), offset_(0), dtype_(dtype), entry_({nullptr, 0, 0}) { + // Assign default aux types if not given + if (aux_types.size() == 0) { + if (storage_type == kRowSparseStorage) { + aux_types = {ROW_SPARSE_IDX_TYPE}; + } else if (storage_type == kCSRStorage) { + aux_types = {CSR_IND_PTR_TYPE, CSR_IDX_DTYPE}; + } else { + LOG(FATAL) << "Unknown storage type" << storage_type; + } + } + // Assign default shapes if not given + // unknown shapes are intialized as {0} such that Size() would return 0 + if (aux_shapes.size() == 0) { + if (storage_type == kRowSparseStorage) { + aux_shapes = {TShape(mshadow::Shape1(0))}; + } else if (storage_type == kCSRStorage) { + // aux shapes for indptr and indices + aux_shapes = {TShape(mshadow::Shape1(0)), TShape(mshadow::Shape1(0))}; + } else { + LOG(FATAL) << "Unknown storage type" << storage_type; + } + } + if (storage_shape.Size() == 0) { + if (storage_type == kRowSparseStorage) { + storage_shape = shape; + storage_shape[0] = aux_shapes[rowsparse::kIdx][0]; + } else if (storage_type == kCSRStorage) { + storage_shape = aux_shapes[csr::kIdx]; + } else { + LOG(FATAL) << "Unknown storage type" << storage_type; + } + } + ptr_ = std::make_shared(storage_type, storage_shape, ctx, delay_alloc, + dtype, aux_types, aux_shapes); +#if MKL_EXPERIMENTAL == 1 + Mkl_mem_ = std::make_shared(); #endif } /*! @@ -84,29 +164,108 @@ class NDArray { * make sure the memory region is available through out the life of NDArray * \param data the memory content of static data * \param dev_id the device id this tensor sits at + * \param shared_var the same var handle shared with others. + It will not be deleted during destruction. */ - NDArray(const TBlob &data, int dev_id) - : ptr_(std::make_shared(data, dev_id)), shape_(data.shape_), offset_(0), + NDArray(const TBlob &data, int dev_id, Engine::VarHandle shared_var = nullptr) + : ptr_(std::make_shared(data, dev_id, shared_var)), shape_(data.shape_), offset_(0), dtype_(data.type_flag_), entry_({nullptr, 0, 0}) { #if MKL_EXPERIMENTAL == 1 Mkl_mem_ = std::make_shared(); #endif } + /*! - * \return the shape of current NDArray + * \return the shape of current NDArray. */ inline const TShape &shape() const { return shape_; } + /*! + * \return the shape of underlying chunk which stores the NDArray values. + * For default storage, it is the same as shape(). For row-sparse storage, it is the shape of + * the tensor which stores the non-zero values. + */ + inline const TShape &storage_shape() const { + CHECK(ptr_ != nullptr); + return ptr_->storage_shape; + } + + /*! + * \brief For sparse operations, the storage shape is an estimated value + * in the beginning for allocating enough capacity for the final result. + * After the operation is done, the exact size of the shape is known + * and need to be reset using this function. For example, adding + * two CSRs with nnz1 and nnz2 as their numbers of non-zero values, respectively, + * would allocate the array of size nnz1+nnz2 first and get the final + * nnz that is smaller than nnz1+nnz2. Therefore, the storage shape's size + * needs to be shrunk from nnz1+nnz2 to nnz. + */ + inline void SetStorageShape(const TShape& sshape) { + CHECK(storage_type() != kDefaultStorage); + ptr_->storage_shape = sshape; + } + + /*! + * \return the shape of aux data at ith index. If it doesn't exist, return an empty one. + */ + inline const TShape aux_shape(size_t i) const { + CHECK(storage_type() != kDefaultStorage); + return ptr_->aux_shapes[i]; + } + + /*! + * \brief For a sparse operation on a csr matrix for example, + * the size of the column index array + * is an estimated value in the beginning for allocating enough capacity + * for the final result. After the operation is done, the exact size of + * the shape is known and need to be reset using this function. + */ + inline void SetAuxShape(size_t i, const TShape& shape) const { + ptr_->aux_shapes[i] = shape; + } + /*! * \return the data TBlob */ inline TBlob data() const { - CheckAndAlloc(); + CHECK(ptr_ != nullptr); TBlob res; - MSHADOW_TYPE_SWITCH(dtype_, DType, { - res = TBlob(static_cast(ptr_->shandle.dptr) - + offset_, shape_, ptr_->shandle.ctx.dev_mask()); + TShape shape = shape_; + auto stype = storage_type(); + if (stype == kDefaultStorage) CheckAndAlloc(); + MSHADOW_TYPE_SWITCH(dtype(), DType, { + auto dptr = static_cast(ptr_->shandle.dptr); + if (stype == kDefaultStorage) { + dptr += offset_; + } else if (stype == kCSRStorage || stype == kRowSparseStorage) { + shape = storage_shape(); + } else { + LOG(FATAL) << "unknown storage type " << stype; + } + res = TBlob(dptr, shape, ptr_->shandle.ctx.dev_mask(), dtype()); + }); +#if MKL_EXPERIMENTAL == 1 + res.Mkl_mem_ = Mkl_mem_; +#endif + return res; + } + /*! + * \return the aux TBlob + */ + inline TBlob aux_data(size_t i) const { + auto stype = storage_type(); + TBlob res; + auto shape = aux_shape(i); + auto type = aux_type(i); + MSHADOW_TYPE_SWITCH(type, DType, { + auto dptr = static_cast(ptr_->aux_handles[i].dptr); + if (stype == kRowSparseStorage || stype == kCSRStorage) { + CHECK_EQ(offset_, 0); + } else { + LOG(FATAL) << "Unexpected storage type"; + } + res = TBlob(dptr, shape, ptr_->aux_handles[i].ctx.dev_mask(), type); }); #if MKL_EXPERIMENTAL == 1 res.Mkl_mem_ = Mkl_mem_; @@ -117,6 +276,7 @@ class NDArray { * \return a chunk of raw data in TBlob */ inline TBlob raw_data(index_t offset, index_t length) const { + CHECK(storage_type() == kDefaultStorage); CheckAndAlloc(); TBlob res; TShape raw_shape(1); @@ -142,10 +302,30 @@ class NDArray { inline int dtype() const { return dtype_; } + inline int aux_type(size_t i) const { + CHECK(!is_none()); + return ptr_->aux_types[i]; + } + inline NDArrayStorageType storage_type() const { + if (is_none()) return kUndefinedStorage; + return ptr_->storage_type; + } /*! \return whether this ndarray is not initialized */ inline bool is_none() const { return ptr_.get() == nullptr; } + // returns true if a sparse ndarray's aux_data and storage are initialized + inline bool storage_initialized() const { + if (is_none()) return false; + auto stype = storage_type(); + CHECK_NE(stype, kDefaultStorage); + if (stype == kRowSparseStorage || stype == kCSRStorage) { + return aux_shape(0).Size() != 0; + } else { + LOG(FATAL) << "Unknown storage type"; + } + return true; + } /*! * \brief Block until all the pending write operations with respect * to current NDArray are finished, and read can be performed. @@ -279,17 +459,38 @@ class NDArray { void SyncCopyToCPU(void *data, size_t size) const; /*! * \brief Slice a NDArray - * \param begin begin index in first dim - * \param end end index in first dim + * \param begin begin index in first dim (inclusive) + * \param end end index in first dim (exclusive) * \return sliced NDArray */ NDArray Slice(index_t begin, index_t end) const; + + /*! + * \brief Slice a NDArray with non-default storage + * \param begin begin index in first dim (inclusive) + * \param end end index in first dim (exclusive) + * \return sliced NDArray + */ + void SliceEx(index_t begin, index_t end, NDArray *dst) const; /*! * \brief Index a NDArray * \param idx the index * \return idx-th sub array NDArray */ NDArray At(index_t idx) const; + // Wrap the tblob of aux data into an NDArray which shares the same variable with the + // current one. + inline const NDArray aux_ndarray(size_t i) const { + CHECK_NE(storage_type(), kDefaultStorage); + CHECK(i < ptr_->aux_shapes.size()); + return NDArray(aux_data(i), ctx().dev_id, var()); + } + // Wrap the tblob of data into an NDArray which shares the same variable with the + // current one. + inline const NDArray data_ndarray() const { + CHECK_NE(storage_type(), kDefaultStorage); + return NDArray(data(), ctx().dev_id, var()); + } /*! * \brief Create a NDArray that shares memory with current one * The new array must have smaller memory size than the current array. @@ -298,6 +499,7 @@ class NDArray { * \return NDArray in new shape and type. */ inline NDArray AsArray(const TShape &shape, int dtype) const { + CHECK_EQ(storage_type(), kDefaultStorage) << "Not implemented yet"; CHECK_GE(shape_.Size() * mshadow::mshadow_sizeof(dtype_), shape.Size() * mshadow::mshadow_sizeof(dtype)) << "NDArray.AsArray: target memory size is bigger"; @@ -323,8 +525,25 @@ class NDArray { * This is an internal function used by system that normal user should not use */ inline void CheckAndAlloc() const { + CHECK_EQ(storage_type(), kDefaultStorage); ptr_->CheckAndAlloc(); } + /* ! + * \brief Alloc memory for non-default storage + * aux_shape is only known at run time + */ + inline void CheckAndAlloc(const std::vector &aux_shapes) const { + CHECK_NE(storage_type(), kDefaultStorage); + ptr_->CheckAndAlloc(shape_, aux_shapes, dtype_); + } + inline void CheckAndAllocData(const TShape &storage_shape) const { + CHECK_NE(storage_type(), kDefaultStorage); + ptr_->CheckAndAllocData(storage_shape, dtype_); + } + inline void CheckAndAllocAuxData(size_t i, const TShape &aux_shape) const { + CHECK_NE(storage_type(), kDefaultStorage); + ptr_->CheckAndAllocAuxData(i, aux_shape); + } /*! * \brief Save list of narray into the Stream.x * \param fo The stream of output. @@ -347,43 +566,99 @@ class NDArray { private: friend class autograd::AutogradRuntime; /*! \brief the real data chunk that backs NDArray */ + // shandle is used to store the actual values in the NDArray + // aux_handles store the aux data(such as indices) if it's needed by non-default storage. struct Chunk { - /*! \brief storage handlefrom storage engine */ + /*! \brief storage handle from storage engine. + for non-default storage, shandle stores the data(value) array. + */ Storage::Handle shandle; + /*! \brief storage handles for aux data (e.g index) + for row_sparse, aux_handles[0] = indices + for csr, aux_handles[0] = indptr, aux_handles[1] = indices + */ + std::vector aux_handles; /*! \brief variable from engine */ Engine::VarHandle var; /*! * \brief if this is true, this means the data do not come * from Storage, and do not need to be freed */ + /*! \brief construct from static data */ bool static_data; - /*! \brief whether allocation is delayed */ + /*! \brief whether data allocation is delayed. This doesn't indicate whether aux data + allocation is delayed. */ bool delay_alloc; + // the type of the storage. The storage_type is never kUndefinedStorage once the chunk + // is constructed. + NDArrayStorageType storage_type = kDefaultStorage; + /*! \brief type of aux */ + std::vector aux_types; + // context of data + Context ctx; + // The shape of the chunk data. + // This might not be the same shape as the NDArray, since the storage may be sparse. + // The default value for storage_shape is {0} when an empty non-default NDArray is created. + TShape storage_shape; + // The shape of aux data. The default value for the shape depends on the type of storage. + // If aux_shapes[i].Size() is zero, aux data i is empty. + std::vector aux_shapes; + // \brief skip the deletion of var handle. Usually set when shared_var is present. + bool skip_delete_var = false; + /*! \brief default cosntructor */ - Chunk() : static_data(true), delay_alloc(false) { - var = Engine::Get()->NewVariable(); - } - /*! \brief construct from static data */ - Chunk(const TBlob &data, int dev_id) - : static_data(true), - delay_alloc(false) { + Chunk() : static_data(true), delay_alloc(false) {} + + /*! \brief construct a new chunk */ + Chunk(TShape shape, Context ctx_, bool delay_alloc_, int dtype) + : static_data(false), delay_alloc(true), ctx(ctx_) { + auto size = shape.Size(); + storage_shape = shape; var = Engine::Get()->NewVariable(); + shandle.size = size * mshadow::mshadow_sizeof(dtype); + shandle.ctx = ctx_; + if (!delay_alloc_) this->CheckAndAlloc(); + } + + Chunk(const TBlob &data, int dev_id, Engine::VarHandle shared_var) + : static_data(true), delay_alloc(false) { + CHECK(storage_type == kDefaultStorage); + // init var + if (shared_var == nullptr) { + var = Engine::Get()->NewVariable(); + } else { + skip_delete_var = true; + var = shared_var; + } + // init ctx if (data.dev_mask_ == cpu::kDevMask) { - shandle.ctx = Context::CPU(); + ctx = Context::CPU(); } else { CHECK_EQ(data.dev_mask_, gpu::kDevMask); - shandle.ctx = Context::GPU(dev_id); + ctx = Context::GPU(dev_id); } + // init shandle + shandle.ctx = ctx; shandle.dptr = data.dptr_; shandle.size = data.shape_.Size() * mshadow::mshadow_sizeof(data.type_flag_); + storage_shape = data.shape_; } - /*! \brief construct a new chunk */ - Chunk(uint64_t size, Context ctx, bool delay_alloc_, int dtype) - : static_data(false), delay_alloc(true) { - var = Engine::Get()->NewVariable(); - shandle.size = size * mshadow::mshadow_sizeof(dtype); + // Constructor for a non-default storage chunk + Chunk(NDArrayStorageType storage_type_, const TShape &storage_shape_, Context ctx_, + bool delay_alloc_, int dtype, const std::vector &aux_types_, + const std::vector &aux_shapes_) + : static_data(false), delay_alloc(delay_alloc_), storage_type(storage_type_), + aux_types(aux_types_), ctx(ctx_), storage_shape(storage_shape_), + aux_shapes(aux_shapes_) { shandle.ctx = ctx; - if (!delay_alloc_) this->CheckAndAlloc(); + var = Engine::Get()->NewVariable(); + // aux_handles always reflect the correct number of aux data + for (size_t i = 0; i < aux_shapes.size(); i++) { + CheckAndAllocAuxData(i, aux_shapes[i]); + } + if (!delay_alloc) { + CheckAndAllocData(storage_shape, dtype); + } } /*! \brief check if delay alloc is on, do alloc if not yet done */ inline void CheckAndAlloc(void) { @@ -392,16 +667,81 @@ class NDArray { delay_alloc = false; } } - /*! \brief destructor */ - ~Chunk() { - if (static_data || delay_alloc) { - Engine::Get()->DeleteVariable([](RunContext s) {}, shandle.ctx, var); + inline void CheckAndAlloc(const TShape &shape, const std::vector &aux_shapes, + int dtype) { + // calculate size, perform allocation + if (kRowSparseStorage == storage_type) { + // For row sparse, aux_shape indicates the number of rows to allocate + auto aux_shape = aux_shapes[rowsparse::kIdx]; + CHECK_EQ(shape.ndim(), 2) << "High dim RowSparse not yet implemented"; + CheckAndAllocAuxData(rowsparse::kIdx, aux_shape); + TShape storage_shape(shape); + storage_shape[0] = aux_shape[0]; + CheckAndAllocData(storage_shape, dtype); + } else if (kCSRStorage == storage_type) { + CheckAndAllocAuxData(csr::kIndPtr, aux_shapes[csr::kIndPtr]); + CheckAndAllocAuxData(csr::kIdx, aux_shapes[csr::kIdx]); + CheckAndAllocData(aux_shapes[csr::kIdx], dtype); } else { - Storage::Handle h = this->shandle; - Engine::Get()->DeleteVariable([h](RunContext s) { - Storage::Get()->Free(h); - }, shandle.ctx, var); + LOG(FATAL) << "Storage type " << storage_type << " not implemented for CheckAndAlloc"; + } + } + // create storage handle for data based on shape and dtype, assuming ctx is set + // storage shape is also updated + // if data is already allocated, try reuse the storage. Otherwise, free the current one + // and allocate new storage + inline void CheckAndAllocData(const TShape &shape, int dtype) { + CHECK_NE(aux_shapes.size(), 0) << "data is expected to be allocated after aux_data"; + auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype); + if (shandle.size < dbytes) { + // free storage if necessary and alloc again + if (shandle.size > 0) Storage::Get()->Free(shandle); + // init storage + shandle = Storage::Get()->Alloc(dbytes, ctx); + } + // init shape + storage_shape = shape; + // delay_alloc is only set when data storage handle is present + delay_alloc = false; + } + // create storage handle for aux data based on shape + // this function assumes ctx, aux shapes and aux types are set + // aux shape is also updated + // if aux data is already allocated, try reuse the storage. Otherwise, free the current one + // and allocate new storage + inline void CheckAndAllocAuxData(size_t i, const TShape &shape) { + CHECK_EQ(shape.ndim(), 1) << "shape must be 1D in CheckAndAllocAuxData"; + CHECK_NE(storage_type, kUndefinedStorage) + << "storage type cannot be kUndefinedStorage in CheckAndAllocAuxData"; + CHECK_NE(storage_type, kDefaultStorage) + << "storage type cannot be kDefaultStorage in CheckAndAllocAuxData"; + if (aux_handles.size() <= i) { + aux_handles.resize(i + 1); } + size_t aux_bytes = shape.Size() * mshadow::mshadow_sizeof(aux_types[i]); + if (aux_handles[i].size < aux_bytes) { + // free storage if necessary and alloc again + if (aux_handles[i].size > 0) Storage::Get()->Free(aux_handles[i]); + // init aux storage + aux_handles[i] = Storage::Get()->Alloc(aux_bytes, ctx); + } + // init shape + aux_shapes[i] = shape; + } + /*! \brief destructor */ + ~Chunk() { + if (skip_delete_var) return; + bool skip_free = static_data || delay_alloc; + Storage::Handle h = this->shandle; + std::vector aux_h = this->aux_handles; + Engine::Get()->DeleteVariable([h, aux_h, skip_free](RunContext s) { + if (skip_free == false) { + Storage::Get()->Free(h); + for (size_t i = 0; i < aux_h.size(); i++) { + if (aux_h[i].size > 0) Storage::Get()->Free(aux_h[i]); + } + } + }, shandle.ctx, var); } }; @@ -409,11 +749,11 @@ class NDArray { std::shared_ptr Mkl_mem_; #endif /*! \brief internal data of NDArray */ - std::shared_ptr ptr_; + std::shared_ptr ptr_{nullptr}; /*! \brief shape of current NDArray */ TShape shape_; /*! \brief offset in chunk */ - size_t offset_; + size_t offset_ = 0; /*! \brief type of data */ int dtype_ = -1; /*! \brief node entry for autograd */ @@ -428,11 +768,112 @@ class NDArray { * \param from the ndarray we want to copy data from * \param to the target ndarray * \param priority Priority of the action. + * \param alloc_output whether to allocate memory for the output ndarray * \note The function name explicitly marks the order of from and to * due to different possible convention carried by copy function. */ void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0); +// Make a copy of a CSR NDArray +template +inline void CopyFromToCsrImpl(const NDArray from, NDArray *to, RunContext ctx) { + using namespace mshadow; + CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type"; + // if source storage is not initialized, fill destination with zeros + auto s = ctx.get_stream(); + if (!from.storage_initialized()) { + // TODO(haibin) implement FillZerosCsrImpl + // op::FillZerosCsrImpl(s, to); + return; + } + // Allocate storage + to->CheckAndAllocAuxData(csr::kIndPtr, from.aux_shape(csr::kIndPtr)); + to->CheckAndAllocAuxData(csr::kIdx, from.aux_shape(csr::kIdx)); + to->CheckAndAllocData(from.aux_shape(csr::kIdx)); + // FIXME This is a naive implementation for CSR copy. It, however, is + // not efficient when the source CSR is sliced. In that case, we're copying + // a superset of values and indices of the slice. + // Ideally, we should truncate the values and indices array, and adjust indptr + // accordingly. + TBlob val = to->data(); + TBlob indptr = to->aux_data(csr::kIndPtr); + TBlob idx = to->aux_data(csr::kIdx); + ndarray::Copy(from.data(), &val, + from.ctx(), to->ctx(), ctx); + ndarray::Copy(from.aux_data(csr::kIndPtr), &indptr, + from.ctx(), to->ctx(), ctx); + ndarray::Copy(from.aux_data(csr::kIdx), &idx, + from.ctx(), to->ctx(), ctx); +} + +// Make a copy of a row-sparse NDArray +template +inline void CopyFromToRspImpl(const NDArray from, NDArray *to, RunContext ctx) { + using namespace mshadow; + CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type"; + // if source is zeros, fill destination with zeros, too + auto s = ctx.get_stream(); + if (!from.storage_initialized()) { + op::FillZerosRspImpl(s, to); + return; + } + auto aux_shape = from.aux_shape(rowsparse::kIdx); + to->CheckAndAlloc({aux_shape}); + TBlob val = to->data(); + TBlob idx = to->aux_data(rowsparse::kIdx); + ndarray::Copy(from.data(), &val, + from.ctx(), to->ctx(), ctx); + ndarray::Copy(from.aux_data(rowsparse::kIdx), &idx, + from.ctx(), to->ctx(), ctx); +} + +// Make a copy of a dense NDArray +template +inline void CopyFromToDnsImpl(const NDArray from, NDArray *to, RunContext ctx) { + using namespace mshadow; + CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type"; + TBlob tmp = to->data(); + ndarray::Copy(from.data(), &tmp, + from.ctx(), to->ctx(), ctx); +} + +// Make a copy of an NDArray based on storage type +template +void CopyFromToImpl(const NDArray from, NDArray *to, RunContext ctx) { + using namespace std; + using namespace mshadow; + // if storage type doesn't match, cast the storage first + auto from_stype = from.storage_type(); + auto to_stype = to->storage_type(); + NDArray casted_nd; + if (from_stype != to_stype) { + TShape shape = from.shape(); + auto from_ctx = from.ctx(); + auto s = ctx.get_stream(); + // TODO(haibin) inplace conversion + if (to_stype == kDefaultStorage) { + casted_nd = NDArray(shape, from_ctx); + } else { + casted_nd = NDArray(to_stype, shape, from_ctx); + } + op::CastStorageComputeImpl(s, from, casted_nd); + } else { + casted_nd = from; + } + if (to_stype == kDefaultStorage) { + CopyFromToDnsImpl(casted_nd, to, ctx); + } else if (to_stype == kRowSparseStorage) { + CopyFromToRspImpl(casted_nd, to, ctx); + } else if (to_stype == kCSRStorage) { + CopyFromToCsrImpl(casted_nd, to, ctx); + } else { + LOG(FATAL) << "unknown storage type" << to_stype; + } + if (is_same::value || is_same::value) { + // Wait GPU kernel to complete + ctx.get_stream()->Wait(); + } +} /*! * \brief Perform elementwise sum over each data from source, store result into out. diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 316a90fe0841..bf9961c8234e 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -7,7 +7,6 @@ #ifndef MXNET_OP_ATTR_TYPES_H_ #define MXNET_OP_ATTR_TYPES_H_ - #include #include @@ -18,6 +17,9 @@ #include "./operator.h" #include "./ndarray.h" +#define FCOMP_EX_CPU "FComputeEx" +#define FCOMP_EX_GPU "FComputeEx" + namespace mxnet { using nnvm::NodeAttrs; @@ -61,6 +63,17 @@ using FCompute = std::function& inputs, const std::vector& req, const std::vector& outputs)>; +/*! + * \brief Resiger an NDArray compute function for simple stateless forward only operator + * + * \note Register under "FComputeEx" and "FComputeEx" + * Dispatched only when operators process non-default storage inputs or outputs + */ +using FComputeEx = std::function& inputs, + const std::vector& req, + const std::vector& outputs)>; } // namespace mxnet #endif // MXNET_OP_ATTR_TYPES_H_ diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index 1b765233947d..e236a9cf313b 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -23,11 +23,11 @@ class Storage { /*! * \brief Pointer to the data. */ - void* dptr; + void* dptr{nullptr}; /*! * \brief Size of the storage. */ - size_t size; + size_t size{0}; /*! * \brief Context information about device and ID. */ diff --git a/mshadow b/mshadow index c037b06ddd81..bbde96541478 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit c037b06ddd810d39322cd056650f8b1f4763dd9d +Subproject commit bbde96541478cd93fe9d617e8d1d955c264bac1d diff --git a/nnvm b/nnvm index b279286304ac..31920d7c0ccc 160000 --- a/nnvm +++ b/nnvm @@ -1 +1 @@ -Subproject commit b279286304ac954098d94a2695bca599e832effb +Subproject commit 31920d7c0ccc9239561311cd1e568ea82bbe572b diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index ff5f6cd6be7e..768d9ede2643 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -8,6 +8,7 @@ from . import base from . import contrib from . import ndarray +from . import sparse_ndarray from . import name # use mx.sym as short for symbol from . import symbol as sym @@ -18,6 +19,7 @@ from . import operator # use mx.nd as short for mx.ndarray from . import ndarray as nd +from . import sparse_ndarray as sparse_nd # use mx.rnd as short for mx.random from . import random as rnd from . import random diff --git a/python/mxnet/contrib/autograd.py b/python/mxnet/contrib/autograd.py index 40ab289c8f4c..5f15e8c3f36f 100644 --- a/python/mxnet/contrib/autograd.py +++ b/python/mxnet/contrib/autograd.py @@ -7,6 +7,8 @@ import functools from ..base import _LIB, check_call, string_types from ..base import mx_uint, NDArrayHandle, c_array +# pylint: disable= unused-import +from ..sparse_ndarray import SparseNDArray from ..ndarray import NDArray, zeros_like from ..symbol import _GRAD_REQ_MAP diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index 6b9aab2de6f1..b585c23121cd 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -11,6 +11,7 @@ from .base import mx_uint, NDArrayHandle, ExecutorHandle from .base import check_call, c_array, py_str from .ndarray import NDArray +from .sparse_ndarray import SparseNDArray, _STORAGE_TYPE_STR_TO_ID from . import ndarray as nd # those functions are not used here, we just import them to keep backward compatibility @@ -90,7 +91,18 @@ def _get_outputs(self): handles = ctypes.POINTER(NDArrayHandle)() check_call(_LIB.MXExecutorOutputs(self.handle, ctypes.byref(out_size), ctypes.byref(handles))) - return [NDArray(NDArrayHandle(handles[i])) for i in range(out_size.value)] + num_output = out_size.value + outputs = [] + for i in range(num_output): + storage_type = ctypes.c_int(0) + check_call(_LIB.MXNDArrayGetStorageType(ctypes.cast(handles[i], NDArrayHandle), + ctypes.byref(storage_type))) + assert(storage_type != _STORAGE_TYPE_STR_TO_ID['undefined']) + output = NDArray(NDArrayHandle(handles[i])) \ + if storage_type.value == _STORAGE_TYPE_STR_TO_ID['default_storage'] \ + else SparseNDArray(NDArrayHandle(handles[i])) + outputs.append(output) + return outputs def forward(self, is_train=False, **kwargs): """Calculate the outputs specified by the bound symbol. diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index ab07421caffd..3384be7947ac 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -48,7 +48,7 @@ def updater_handle(key, lhs_handle, rhs_handle, _): class KVStore(object): """A key-value store for synchronization of values, over multiple devices.""" - def __init__(self, handle): + def __init__(self, handle, name2idx=None): """Initializes a new KVStore. Parameters @@ -58,6 +58,7 @@ def __init__(self, handle): """ assert isinstance(handle, KVStoreHandle) self.handle = handle + self.name2idx = name2idx if name2idx is not None else {} self._updater = None self._updater_func = None @@ -395,7 +396,7 @@ def _send_command_to_servers(self, head, body): check_call(_LIB.MXKVStoreSendCommmandToServers( self.handle, mx_uint(head), c_str(body))) -def create(name='local'): +def create(name='local', name2idx=None): """Creates a new KVStore. For single machine training, there are two commonly used types: @@ -435,4 +436,4 @@ def create(name='local'): handle = KVStoreHandle() check_call(_LIB.MXKVStoreCreate(c_str(name), ctypes.byref(handle))) - return KVStore(handle) + return KVStore(handle, name2idx=name2idx) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 5eddfac47981..b90500d4a9c5 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -37,7 +37,7 @@ 'eval_metric', 'locals']) -def _create_kvstore(kvstore, num_device, arg_params): +def _create_kvstore(kvstore, num_device, arg_params, name2idx=None): """Create kvstore This function select and create a proper kvstore if given the kvstore type. @@ -61,7 +61,7 @@ def _create_kvstore(kvstore, num_device, arg_params): # no need to use kv for single device and single machine kv = None else: - kv = kvs.create(kvstore) + kv = kvs.create(kvstore, name2idx=name2idx) if kvstore is 'local': # automatically select a proper local max_size = max(np.prod(param.shape) for param in @@ -85,25 +85,50 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, if update_on_kvstore: kvstore.pull(idx, param_on_devs, priority=-idx) -def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore): - """Perform update of param_arrays from grad_arrays on kvstore.""" - for index, pair in enumerate(zip(param_arrays, grad_arrays)): +def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, + stype_dict=None, param_names=None): + """Perform update of param_arrays from grad_arrays on kvstore. + If `param_names` is None or kvstore doesn't have a `name2idx` dictionary, + the index of a param is determined by the order it appears in `param_arrays`. """ + stype_dict = {} if stype_dict is None else stype_dict + for i, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: continue + index = i + if param_names is not None: + name = param_names[i] + index = index if name not in kvstore.name2idx else kvstore.name2idx[name] + # cast storage type if stype doesn't match + if name in stype_dict: + for i, grad in enumerate(grad_list): + stype = stype_dict[name] + if grad_list[i].storage_type != stype: + grad_list[i] = nd.cast_storage(grad, stype) # push gradient, priority is negative index kvstore.push(index, grad_list, priority=-index) # pull back the weights kvstore.pull(index, arg_list, priority=-index) def _update_params(param_arrays, grad_arrays, updater, num_device, - kvstore=None): + kvstore=None, stype_dict=None, param_names=None): """Perform update of param_arrays from grad_arrays not on kvstore.""" - for index, pair in enumerate(zip(param_arrays, grad_arrays)): + stype_dict = {} if stype_dict is None else stype_dict + for i, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: continue + # cast storage type if stype doesn't match + if param_names is not None and param_names[i] in stype_dict: + for i, grad in enumerate(grad_list): + stype = stype_dict[param_names[i]] + if grad_list[i].storage_type != stype: + grad_list[i] = nd.cast_storage(grad, stype) + index = i if kvstore: + if param_names is not None: + name = param_names + index = index if name not in kvstore.name2idx else kvstore.name2idx[name] # push gradient, priority is negative index kvstore.push(index, grad_list, priority=-index) # pull back the sum gradients, to the same locations. diff --git a/python/mxnet/module/__init__.pyc b/python/mxnet/module/__init__.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e904d474819ff0e717544513b372155fabd105fd GIT binary patch literal 621 zcmYjOyH3L}6uo&gEh!b{4?Jc-jl_%)bm#;KwJc;Yd05n{9f#Uh;k)?(uHytXkxz0S z@jY^)-$mM4@?OJmh4_5Hr9TiFfC4dqCIk_H450`ia6AGLfsCPuJ&!?5KqgQmo=-qb zLC&CYei))xDGzXH=D?%~>r%7l zzJS*f)Xw>p%h>wB3~xWrjU^g8_GMXXcqsOVZ)q2&qs284;)?R$cT+B=n zC~Z_tshQ(@g>|Ux`)rYdaeW5fcXoSKW1lhEkG)1m{}&pq$ZfJf==w&Z9HX+GAY{1C zK;I(G_e!$dR<<=ZuPhf$?!V6Ecgsxf`{b1JmK#b}Xu98d=k%wn}*;UMVXR3lG4A z6E_vMe00y}YaQD^(^*(+^t^TezTn?*O?_*Yd@)i2q(&VmU0@Dy7g86T&X0hNfP0X7 z)_cIlz&L(X;1ftE*89NDfQOKVz#~W_Lj%~%Cm;14;-b`yASw}~JAyK_BBWf2 zW}h`>w@KInEww^>g7!kQ(zYpqcxKwE3P~m@X^;I)wC5=qJ_%X;=Cnqqqlk&E!# zycp(A8!j_R9{OnG{8M$l)Wwb{TD_i7ydiIzRElQo$lP|iKB%pztZ!rM?uC?=JkF5Xw(l{RyhaRuP7xI)k@Nvgyb)P&i->@by@(E4U)t&ra yC0oHsqH|p-o#dLbGO_0p{$3H4*b5{W(~2sL7xkf?Tw8e4-AaZ=d^dLb>HYz5v7RFU literal 0 HcmV?d00001 diff --git a/python/mxnet/module/__pycache__/base_module.cpython-34.pyc b/python/mxnet/module/__pycache__/base_module.cpython-34.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c10d60c44392d072e66511c8a26bebae28e750ee GIT binary patch literal 37653 zcmd^oTZ|mnncnH1`;C_&MUm94B$=LRd)S;C$ssMuqDV@nZBbg064ygB-8J1c!zTMO ztm-Cb#vCuORCc|HH=EtHH!r(cBnXl?2{soXK;V}=BmwpzK=P26RGxxlk@zJK0fGR2 z%J==}QdK=Y!=bD=h|^@(R9Brk=Rg1X|Ns2AbLzj3PmHd2yzl(aIkk)veV$j#d3)Zcmitr{|NB+FUo{5Q@_@?m_kjE!RLg_<_n@i| zsm8Ec9@f8yRDDD>M%D7D{ynVfW2!N(mdEw)5mld1jRR`=0Dh0E>X`Zr@ElataaEm= z^Fyk7KvfUQ`C(PX&%<&)sot+ns_GH`z%e+Yo&w6F?8eQ)(XHw7yM7dW(5iOoLD{PY ztF>kjd4AKo{Lyu9wH0~|xfK`1jn3unpfcw zp5MfGEvIe*;WrEUt@yInY1itt_NKSmX;#{`R?}qHO04cM5j}{58atTl&xZ$aRw1eLGmP&LzA_rWixZ&!zEV$>+U{2yTk1(pUA-Hs z?E&=!|NGU0K^1;gJv^eCmN!6VxPJ9%PHhj86`tfdmUs?V@JzK|-OsBhgZ#PLua0dG zsV99Z+)!;49pcsB%(1>U%Ap0hlEd%YV4<6Z!L74b?gfhl2pLKOQPQ4r*)n+VMbXrRGQRU}?>7H2h`!4B9?fA!t^k$F=r7jCsB| z8}|WR8DV{x0B}DaeTySjX*Jg0o+iYjIE<*i+#&h5_2p89*eD4SN`i#qTlH4OuSaj= z#6G(9qnGYnm}`8e)$uB1k)Y~nhWFUKIiD^Ti!|R8;U*fLV=SDaZP~sZ-Gu ziN4utLp*ev)db^3Hdt~(U(^oc{`;+3GamiauXloLVb}`e;m3a11m8q)KhFxo@t_{)ctAQH54Shh1D?hM z?R#PCaoi73P+BgHR{UnQRz=^VxUUg)BRZbkiODyJ@S`}W{CI96H=Y~J9U18Bdjo$D z=EE=Hy5kyh6TiYWAQ>|qkQKl$pe8U67>WD>T3NNPJ|~w_$c?OFwTJG ziB;9ktNY**xsoHp^@Gv|w$7zkBFzpBiea#YzJ|X0knd3?tgW|O;ch%}63`o-IAnMN z3~|EJbuJ_)356E&kj_fQxs5KW4PU_J3;0Ilwh`sXRK}?!VD%)^8=M@x^|1A*Rt1f& zY?2(Kd;U5)QEfG++MYaXxIJ1ADz(*`D=4*l)2AZuQB=f}g&T$b@I^onp5TEU3tz$k zBsWTAXhM(2zd3Own&-)@H_lXlDF2ylAl)DL7s5vY11 zW3Hk7()e!}4^fTzW_=)Q{j!`H53`+5^y@HDqRyplg z%XkDOp`Yq{psPYXec${HpxYj$#4~sCp#HCG(~rEMs>jr4N7V9|3ZVbKsH(>a9-dmB zkSixt^`yOWK(4%`sxR9s2j$8us`{$Ea!9VcrmCL30;pbB)i?Nac~bR;>P=PslJxzE z0Eup%Qq?a@-=SVlsp_;m4rmLi`V~1pCMf!*0lDtcX;m%CgU40%jH=E`Vf127 zRp;gWq?BAx)kS#_-NA!rRrQ?wd|9eJuc}M(6I}x^7gY74{Crhazox2hNj>P}B~^V} z&ONCFB{hao;n!959aX)oK7*lx?(3nsE(P9I)hp^V4Bs1qtE;Mdjg=7RzynpDZ^omE zDdGRSqe|U^CQdEXtJM7{!rbrq#+)$P_+|4Gs`GjnL{QkNO;g?Jw0+U*c@Djy37uN2 zELzJ0pV!u!5dBe01e56G(787~%=6w#fQzjTlrJ{fUF91Mzj6;&&KpJV8qNqpP?zVz zPLpyE>P2-%YFzOv59F8ES%(0Db_oC>ZlGk)gqZPycBSau4m`is@E|=K{sT50(0Gcx zz>hY0wF5IistaVwg4Vi8rHSP##LnY;LDPF2c+CJE(&go5t+5`2K*J-ck9Quo!Uv+W zNp+1*yI(Tm!x$~`}7*D9Fq*P-9RCg6jGZlt4-7X^>dMja1PHM*#c zp3~m-X02VbmoIuB1dUb*ptahE4nW1E7P%R;X8>v^5_lUe*pRjm^vo*kGZ>KR;lu+) zf&H}iZmo$){-SqN&+jZLiVCpW>a;-sIN|J*pj-5N>iiFyt;fw7D3t;gyRi}gg$^!_ zZBHI2mbDQ?$a+F$A^OM-EFh8o*R}}Zdx;gS73Kt_8R#X@Ni%(13D6fBBBl!5mo9sP zq~uwzLGu70MHojC1e9qpzpEAj7O)T3+qDK@A%-we^*W)t429ZVH&pL#lUN+Dd6!o~ zt!TXlloB5a3r2IJpdD~wBnV}$rMk^G1U)ox2 zfi52t&RP?7LP?=x0C*+Sn$gg7?oFF9S^*8>fr6v^vS<;a<&-aq-+)q!h6SDsb=}n( zh+ohFTxf;J;4cgwc4{H`BGI_*5SBGf>$&|GsoL;{j>h-*PK15xXl{Nn88bPyoveXWJjAe5Tjc1w?FBMsQMT}B5VvA<#5ZaGm zF*^o-W;YUc5n=!vS&5s0-AJKS&6$QB*X+A(n#_*F(WsFQsYWRucWX}z(-J3@5;=6w zZYC@&mOf+>j4QB~Fnp@Pgs?>zv+|&dsaYf$s0~xyYNxJeG)!OGzeQe=c{~En!IIz> zEs-Nb#i!8;2>dgews#{sVNI9-;*Qw0N8+L)n@87dF6X>=G~PFR-jidRIg%k&Kp!1m8aofGWQ{RI~(SlBq)ld3~nn5T$LW>lkY)Pk`H`nR478?C8L$gH?CeLiIvOH zBiGl#HAU~%Jv{HZJD}aipg~Mbk1zQjP1ozl&NbP|axJx(ag5gR(xS$FW52~%hYW)8}Or<;#zVXGu!1E+xLW&x?fD0!*5ffYC?w5dYsvC4c32Wgn_6cw{+k*h286_Uu!Rn#uFN2{3u z$0~8tU^bamqDwbmp7B9@YnLarHYpu2wv>YyJdsv06D8L5I|Bo`((t-ehu>CVw!9B4 zKdQB0mh9ad)+95J?ux68gtO<22oj)yVyMgMx|EyHH~F}OU|(n{>1f)t{xjognMPz(qK#oW)CB8BJwgVnZCLq1~7W$!Ts9ENakD3@tEQ;%IzANzR`*}>!FflB>6aI zz7>hJWhA5ayaPT$K)WTZldDbW1<(O~>-P+d5FW+?o5qymH#d{2JcumZHh?$_vl&1T zMSfTVM3`n641)vUJ-`-h_*5%qr0Vj!oyvoNy7&h=ez`+vr7Mo8kiZ}Z#9yg&!VnhA zR}8?IiU~ahrEW&!M?)R$sPz$}64{CTaR8xc6g0zlw`tbH4*>x#f}{~tYY>j7>kylD z39m}L%4$g%B@JyIdJBPcYPyW*JxtxHM@0xzz3;SJv~&>iLErpGehst((ZzAJ)_~LL z%d53E$BZhpG+~A%#fg{=FhNn>feh1Z2k`}J=A?yZ=^`<|M3qYy7kj$`F#mPXgi=l72?|7Dh2c6yONmHJlVb5UIksPPbEGJc}Rx z<@XA{(YJ9xTm$jFl43-MC-7hE6LG2Y3c@q`3ViGO1deq)h+A;HA~J&$xMC4TDfHcV zPoiuLan@UFYe5(fYKuD_2S+#)I6kDOp~S|ECxw+W=I!zeZW>Kno`_FtvR*!1YPaB2 zDg_%srL?jM^$h{({wS!ghO_8eSNsKe1X4|@^iSFOLpZ41Nao+xk;D&Nw52n+>&HWa zJI344bA(KRLEr_5c$&-QcnMK=)Qlhmg6UgVE^{7u6g25)N>owF`k?5e*$+==xeWYJ zseukn+qhh|({H(4jE8HH#@#KZSqzz?@vtp}sM9&XW1@qi5p{R|ka?(NpZOPfAYykE z0xerFr8?eS^2RM0)C%KU{4h?#aK7LH0WyP!a8gckDmu=1FOngUD}fXX;0Li^qNIuB z7=)7%DjZPvkw=kN(T8%5ObSU&*v>N>1Lx)e;B|XgZTwWVhgl6=_E+!zL-pOf+8$Bc zql^gx?F=veNIeUU#i!s;&mpdDYB-QSx)61{x5Z#EhFGFqRs>G zhjn!uG4Jqxwr%5MLr?K5|}WSJ)*X>B)=}kZaw$C^6aVZ~?gTa?_}R z7rko$^aIfeT>(xlLce$#x`(^Yns0O(EAY=@nv@dI{pd>VG?AR@&gsYsX9|>f&2x$k2&MxhgknK_;tCDj7mNh3c&t4H146;{R z;<|JWSX($%DZ@o=)I$ekrjBecZiOB6fc2PHf{#!ITJN)0#~2lvwD>q9$E8LM_= zpr7D!(nC0+5lGsJ>ohF0TC>zTbxrHkRP>HYU_SjE6-O6mz38Cyfz^s~vGh zTMuWt9_U$g-?bqAP!WmMmhv>1APx{fiXt;=6>GQj=7T2ND>rX zA4K}z2$vvo--`zX74cwGXfqy5aTMoVpg;N<^BW3>LS{6Ci#WswE@OB=5Ur-pNE#5r z2_!iw>r&a{e00^pmbbPzHnCZ5n1ow4{J*ik!Ehr)MwIDkWZfOaX}LM5Uu zC>+q!(ue>_)ieA$i;Xxgk0*DsSUf36$*^5`g>78IVfhfWGNiYZTpi)ttaDV_&`3Fy zRI0U#c-WF@pcL1eQzw=v%QuC-(S&=CgH4foBfbX05M*jo{tv&RZHIYEoHEVAChD@Ah27D+p zV85g?ik8M1u*_Kr6BcK{vYP=hpFvqeAxyy3PGsKw9i@I-sa5DKnEeLSZ)s4Fuc(wH z=3x56f7CXhR{PYaxra;0Va#F58T4;McckE?mD57}HDys*6P?t}59E;CMTkX9^cgL- z>1WW^Tai3eGCVl#Xaq4WT|;twQ)<#aNaoye4lrg~;%wHH_ITE`uOFvD3%#^QF(Dcj zqT~@1N?&o_Y`I35ePp41q0^$H<=rJQ0o!eAV^oR30fAhA#@uKO6*eu<($fY=IDYR;(c;pz(Mb znn6O0R28KTYgr{S3W-xAx!3YrFYeNZ_bS9GKTOG@JnF+Yx{1nml}G(MQfW{zE!s(?RS_jxDy^oa5)8&fDh<&Z_??{kZ6uNIDwT#62Kheh03-@hA%0uk zZjbO9enDfW-1~F>ya_WP{4l5Q;l>HvhqP1{-9`|3U+eS#j6c8cVhV3xIheSw@iH$HS?S6}bbv(o$q>!P-oA`T^pcKaZKLNhG zlyu2VG|ncnjT!HhULy{;GA7@(K+Db`l6?JgS#u=P05<8fqtc5=6)Q0z9+`KYTBbvd zjseR^XWM#eUnrO_OZg8BZz*V(NQcz zFqoUcf(CR;YSd@|*!0CCk%z@ix6##oq1b@5Ct9)6gjqx~7u9PIl0X^S5*+JcGpIH{ z?#_6rd-M2-@C~8Tu4K$#l7nO-(JMj4$EiaEX>T-K zwAGFYDKl#8D#qkv%qlagDJ5J2x;Pxj*0+*zKkun`(S$U5mn6m0*9lEwhu)=qp4p`qa7uoR$K!rR*==4tQ|+xz zBT3;-14Aaa)4;x*(q2rjdMRIG7;i0I@taIa5eBo!g7Y4U8l6VkQox+Xa>T8qt!x!& ziezT15ByDprkq%%7R=Dv*DC-{KK$#n(s}$a*qNq6)E4KjQ5(dg{##1@Kr0N;8!USA zt~Tnpv=2qc?Xi9ZI)q)nk{GO6RK~elmnw_(h7OT)`$HsB-9z z84)UK^YY|QIV1i6QKhjCi9I3Ro2e}%{Vk8DdNHrn>^xSepf8E8J!I(a6tT&HRt2qQ z8)tw~v=75Y&<`iTA*{1{A=j5XkUN+=1V_Ozu5L~4#6H%_RVr9SK}QYNi>+1TiF%-y zhm~km|JUyw{j>M~)7EESuW+RUKX2eWi7%~$=r*+VCs=flqt}~$c2Jpk&gpFz>Vu_E z_JleM171Au&|aaG(&bFgaewX!QvBul@VoL1%E`I~y60I3ibKnWCJQAQE^%nR$g+pJ zi#57%bss_<0}|#yofA~=v3M7Y9kG6wnz`!BJ;9;^)IY4ke@2)5wydyhqj@a&8KzU7 ze*ga<=iBR@=u|#eMO{=dcQ6`$h4^ zi&JD9s}ByV?GvgEFpg76-#)1}^0_A`RCGH3ZXQ?t?W5|!s0#lSj+GoRjMl%?%&A{z!50WBYP^3;hZ$Vc z|3;dA(KQWZ>#^|iZS(;l6fC+xGgx17lyy{rfDr$R$RQuMxVT@p1h`>VhPs6 zujOwiNcU0&k4r&ZIi`+3?bEn=l@)&%5t(td4RS!=Q8nlU!(`$sx_|#Izj0c%0lSHv zysWl0>}VPfp;G-20r=OtAF$}I_3`HG60`wXoe-l7l_0DmsCt_3+kTz4-ILdM>m3F@ z)4LOWf+Vnprlf~IGv2M2{lAOJ=+^Ej1}QaQf1wa@UrQtgy3bV;b8yqPb9nW-Nks(8 zH0P_t>tns`oRT=dVS4AnP--d+DbTKDrQ);ZESJhdaILapo&2utJ^KF>Cyc;kyx13+ zI8LSM8tVKY8>V3`=@7vbq?)+MrzB}E-7@B7KkC4~6sY0YuhydVx({0~8wp5&VC5eI ze)tcOohkWj6vk$Sfrbv5S7DHK6i!uwDR25ph-a=f+ei@FgjF=P$}EGah9AhENdANk z{TF$jr3WPky2FvU*faiT43UlGq}51!P3LtGZhj<%(mI@-EzS!VHpx=v;Fv@nuPm*> z3NM?~B%CF=gxLs{M23@mP1v)#MNg|e&jYf1a(@u}9(G%RNpJT?NHzW75!UTsd>y@f z>(iUOkZB=9EHe8|CBVBde^SSXHhp~eS}=J|xx20klNpJcnD^uEvzKWY%H`?lsm#m^ z`j{=w&J|{0ei|Zn@y?8&gH>?=U6VS|E9_Lz>!edGb_>Now{+8xS&@2WbWmTQoAJc{ z1yA^oT3BzszFCKR+Cv(E*q31IDm9@H`GA>@WU5S=fx^AJL@RrdZKhJd$`xHqk%bCu zsMRDyqfRQ)A~BF9U$kg1wCk^O-D^wfh%nier7po`P!a})7? zJ3x(JTCx&|Z8I*K&KN}*E+dnrOBWG?b-ywlLz8ecxqBR4gE)+7azkTBuk(=bldQp` z+f6JzdPTC_La~n(+_DfYsV0>`p9qL&6S@nBcTXqDQX6CqW7QqJ`zh!o?17AN&WD*i zQIE!EP#1z~`CRU9ippy1b*w6Dd(#lulE;eFz6fh$s`?r1t8krN7ChOtg&cleq=?)w1e1g(^=2Dci6B&_ zWephj##iT(4%1n%TX%0bPq~y34>2=~EHEhp;q&a+hGx)VmLo_)p&550IDLG}V5z+-U>O1yCu4k$uX!hfj zA|pL7eXmQ)x=1C98aZMa*afOFbrEXUh3zCHU~95}-GF4hmRqkOnVv&$#+S}M6UNYk zBim4fr{OgT>Atxo8K{ok<1xo9NKze5Br49`dZ>6JBRAvGgl*#?qxQxluFLTucWQbs z4$Ai=Z$*2hnEs;kW#}N=df^()uREuzcfDJjXHTDsE}dH4I-a^*W5@?994?(IF06*k z*bTqMM^7eYK6QhxpD(;2k?6ReQYjus;D@`?$SB4W_LnTwh!2`2E7)xdvIu(>30e&F z@d4tf9IdjKFO1WH6#a#8pVmSJ8m+Q2gDc$7hPk!4Xl3j5dM{AHH= zH5`BkofVr9wT`+=64IoQNk-cg%!YM7-{1j0DHR{yp;gB(?xL-Ra0aXJ9uGAhket9h zlW=%|Lp+ueraA{oKv+JUaIHzMiiga2Xwu<^J37m0gj`)BikHW!=(kpNY;pM@5iUcl zFN9y?GdFnnkcZcCi2KkG2BlSr`yhG~kFw4#=z}O&r?)diS@q51`Iltxs6#paeIYl5 z)Tfag5~2o|24QIhz?Ydo94ztLPD${fboSe}Ve6ZlRhsZxC>a<8D|$-YT=NV&ty zEB)Ay3il7fU+VR39Y}dg0g7{y10>mM?+X4#3>U!Z37-=@o^YLUSG~E*F0P{&wS!G{ z0*_nPE~*~X3LPLV-heCn$cBA;ST{fn0u8{&{&0Gk+7GNyA7on6f&mNLdc79G1ftq(~S9P0{MY?>GFTM$0KHb zFKDc{lL=nJn@B~Flm|PJg6Q|4BAF>)bwepKAa2qzI-Vk-Ord>u!2dl|_dXn}iNee=;IcSsuj?(O#L!6~Oa0cl7E$;Xx*2#uXQ`$h z75}Fjb+AVj$FMurrI*>F658Wi z7Ql0h8(59zy@Rb>r)LXq)Aw7M>roDcHs)q$lY*{Ccwj-cyi8Y%0DPv(g^bx=S`c|8%lmB&CPL7ET(=x)85`&EQ1>JtpN- zBt)aqZ0X_JO1J{`?r;V}l&75!+=RnzJd-ep>~r9}V`<+8eIo`hqQbgnDQ>MV0+|oo z<;-JwgIAj_3?LpQ6x7Kj(AuTl^mSCJx zEGDA|Uy}AUXE2wzPgi+YOWTa+C*{>j%&}fxlnjvNq-dWK^LYiItA%>fUeb?*VeN80WGtGL@N?`xs1?;-1_y18ZVfiVwRfFdII)%1uTm{l zm=Mel|Bm4+f?)n^ z?VQxE-t>I3KXI`)S6}V9L5UEdH|=r>3vQbdH>P{@&z%4%%`nRD7DSYoMUE@lOm zT@(L9N*lY>8@8Cwa^MJezwVmu+7|<-=<>ZjxXE491k>rM9CJKme~=Qm`;!c-v~4jFAZx*~HHYXU zpYh3-F|Kt>-7P$?NQX=Md4xGV*jAWNH8$D3)FuqhAc4>n$$Jxjqcb?bi-eIdSv$yS z!qR16A080FFtA%ift)(_hi32g+}0pS?jm-O&ck=f#q=dk$Z!pX9O6Qs+4uTEoO{$Y zMPS2r>PWOm-u}$a6xad#D=9N7)s0j-rhDv*?MG)r`Vsyf08c|7{xx3seICZ$fo`Gj z4{*RDAe;Pn0~|wrI39q%tPzEqEJKdK76BL!!%y*m&`rX4@rZPjz3OTp9!nwM5gQst zZE9jNcOvKIGIMIGpkuNTlx6e6WPb-II$adwhk?Yt98Xm;}BAK(Zv32Bn6AA{?H)!vOWbyEJ07Gi5F8o6l9zfbtt*y6or-Uwy zXo3D?oQHqJLn7Kqw{Ie zL3c}a0HI!}GgrbX*X6K{dH%2qc;tGfxy$I9fz7O$l9V*hDVfhbfAzpgFS90_#gyz4 zAm){DBvc0Menk+{cKK? z+o3>qHF=!PWifWO^iIR5m%gC^yOhXmh0VL0aE`Vfz@Lxps7+dlUh9>C)rQzpIDVF! zEb&E_-4ID?hwp+1c2P*GKoS?m+{widvYuYTYeLKOZn25<*$P#@R#e@fb*ilD}hXDK5o(_2{g3X8pdU2nOKSsT6t zj5<8b1l}@F#4wnS2pCfCL&+}iLIwl>oN(cbMcG>?Q~F?6@VGZmB_tB$PojZu~^+VhXf6Xk`5$6Kr3ASX%#nhAg zBUXvx>7+oNsr@dKU6YX&tQd~hq1Eg(9B!eP|JI>tqmqJ{_woE8Y<$(T4_wCA`a{f zRX>?*%oc#i1V!?e-eiwPWmUG#b$z#w2#vt8kQAAol50pMzN2^AyPMji%{~UgFETSj zZ(phdSC{UXt@&oWX-KoVS$+!zJkF==$Ev6)j~#r2JBpB)xd-1JcsH8WA_gFuzT;}W zTe!R934&zMG%Y<4*yaj5o6iPkv;u@ZhVRVYO_uLws=srGY1J3_Ie*^6IYKY9vvX$` z-1+&rbMAcZ!dae6`6XRO&d)C{WNHOQOilCWu}7e-bxvxPCuWzHlA4y5?k1gfQcnwK z&CcBPrCaoI|AyG{^oa~l8*KF!dA)f0U!tOS%}Y7bq?jV?{fR2zTyOdc*0Z56$k49Y z713m5&||CS>=sk+n-?C~HKt|&=^lX$fE-@;*RTo=tNsCM!tEr$Te97$4e5}%ThtzO z)J5kqOCBX|XeWMVz%ci?w)nzB22dG)f%cW&fjfnb=Vfz3ybnGX?lVPGj&=5s?3JZ& zF+%MYyL-|29nYHsWzQ|mpT95{oSp?0duNMt^YdrVpI=TQnDz ziVO2+&z)PEpXG%Nz6j6hx#C$c&DjeV7R&?Zi}Q=~=PoRqKbu@wnmu=ZX~A5;Bj?Uu zn4Mpk6DW3637{6wEufBB3)E6^adGa#xdmzCe8MFJKt9}~Ai3(0tn%P>>j@ACzX9H>ktxHQpm?x-2xO~taLOw-H`!hLp@`aclJIu_~dodkSvu37; zw3oD*s<(P@Hstf#mh7GGtlg6X8Rqntc5k_WY-DV((ZdMI?y!+0Sn8OXGlHN~LHoq^ z%*c+CP4?~H>e%lbFTX&FgEiq(2;HBk>H*Xq`hwYT;d5HFhA`aVkamB?OXNK0*`Va+0h;IG%x5JyBVWLV7f-cJtGdSfZavY_&oBEKtPG)pLtr* zDT#(4iS%<7J1^TWBn44(=DifI`d~NNv-#)K6)KLpp<1^*qubhJU-1_C7>SZrc z8y}=RO(Cw}#O|vZ?^IVY7oDU#)wxuAnQC;pexl!+2)_;Nf27-Yz3D>>VdzbgmF8AR zh!=|Y90fQL?HICKRGS&py=88P@9TMKQI=ZJlVdz=NgsNUBh4XkC%8ULn}scxrt~sx zZAr>Y{4Py#&2^$lQnt`&>hW>anb6dwU6)&XJ-iIhFw%HiFDQAY1t*=?4vIUH-l<=0 zKDCFt0yfeTY^>S2S=wpvDB!~U((L)Av-9Wlh0{yLrKR)p7tWnEC17^VE-m2yBJHpp zt_ts5ac*Ju?BeWsTjOGJ{`}c z>yGaZZyio~h4skn-`n{Qz>L@E{e$B{7>FA(Ane4STxg@0W}DUJbX*u+ipCas4R2_b zRde0Oe@dM1&s>pJJ$ey4%oKw3L3Y;*li#0&GHergl5TO}jrWT%!&r z40xFhQ_@pxz7D)qhF5;9*vKC(sc8Y3pr2k>liQquMz2nn zewp1og!uGW>WqP;66;mL;8kFH0=pXIWB@_!*slj`N~?UWGG4+gCz%#a7K84KjuDw{ zP|>n%`2ao1?ZaobkRy4T&+U4m-+&q5k^QPm^6;f?$+A(BD`1FGBb%5 z8xlMO(nLU5aUnmLlEy~xgILEO_Z*-9Sii&9y_nB30AKT==Tq)*;$gQ$`ue9t;Br$W zzY3j>n_Nsav9(LGr4H+ywp-1hJ@Z5gT9z=?xjX@Ht=?LJTD<8|#1&XCxm;EmA(077 z*SS3i`-L}9qeJ$k%B<^oHC@UOGMSP8yG(>E^?b25H^mL>46P~5M#)O@aDhKpaaf+n zP1EGGXzY-I6K0v(pYa&XEE5&v?s=+wzvAj7_#B0-x-)lG)xPy;g&JioWeIC~abbR{~I<5iO$ zUng@mc;M=NS$N;gjmW!IKX%-BZ0kgd8?!+6_8AP}d7XyEyvHNp*Y^>blIdu(vC^vZ zg{!R!-=)8~2p;42j)*C@b|V92w`4IYmEcxBTAm=1AH4yB0Z0U=SDY zdf7D59L||c?|>#EaUKGS4aL^5+6{}7F>BRps|YG#{|4+a#Qnu>a}o@4wXzyn9mTr3 zx`)&eD}s}m45>wDfhYXwwW9uI-1^6E2!=EEqn74B$^SgNGin&<>i!=7{3IX#1b^(M zzt1~g;``m0icu=Xqa~w!;C#GPdWfB8%$=c9sfrCIO7RfiPY1FM&$FVpc;KE9;cXtu zJaCx9zk)+NY;h_(0ffKH8_AsTSNW5P=CVv9WNL5t*LnC89{!YvzsUnLsU%NGzZin? zQQ7=9{M$T}?DY3}#>pZijfTWhig+dEhz_feYXd&SH~MuP##gX~!o>K%Ahw^Kd}s2t zk=(&=4$8I+eMtTt!FTeN2tj~_Xhn>;yrVtjJ^U2gZlS#}2-9@Yumu z4|?NA#*gyJgKrL^jI>gCA^bD8*x}(H@xX10V-;fXOg8;cdWk_i0Qw2)dhd=kmLVGB zJodz*O((Aj3ctiN<`;)W-W=3B%u9Q{x*HF_rHS%wgp*Z7tPElU3H-|+!d4Ra@?YQh K;qe#7@%i7vTN!o$ literal 0 HcmV?d00001 diff --git a/python/mxnet/module/__pycache__/bucketing_module.cpython-34.pyc b/python/mxnet/module/__pycache__/bucketing_module.cpython-34.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca3b3adb5a1eb0425aca2b81bc61958de6dc9d06 GIT binary patch literal 16988 zcmeHOOLQC8dACBRu1o{Rj;P@XVm5l?rX}QQN3BUKc_b5)P7xU)=llK>dmYDhT3c> z=RJQ;`E~UKH&{m38of{q*R4|>6V&`;cC zH^6f*4zfpUAKqSdy}s{yy+ZUKxV;@{j9T9h`{9)>FF|Fr?fPEgwd$W>x77M4 z6&0}_<#u;B=-zFI{lPG4_r3ifMvt0VOqUR| z&zvp9e-AJLumUvT3LF3{*!GGFrc^Mk_G;=eHWKyn|GP?Uqov^KcWOim5CpVfvo~kC z*}#uK%iG6*Bio;2gEjNgW?uP?tW85|fS*%Y&4Sb{Dt{@fIi>vPl)uc|P1K)O{u$+; zRsK2k1az~c0z7@5BiebQ0shY`|AKlltu~jX{pXZ_Q9Z#N@W?FZ1?9ge5C9D|xSEbA zzoh)nE8kU5utp4onoG+6f^= zT`tJ&s`9VMEoRbE{#9vtQO12u`LD|(ObOGtru<)(+ZUDpW#xZGTD~OP@>S)3jknxP z%<${V|Astu6;N}hc_Uq%pl)xqs?^4A2+|d5cXxx{K*-h)yq+6uZ-?CwMA$oU?}y2* zd(YDcAmJ;$pudA#p`^G~ue6!TZh-kQ2^ZWY>}%oJ`5D-~ zXlEzTfcHHqNMFmFsC>t8sS5zTFiyzLwkjXwrkR{zpwroZ(CQ9`E6q+v8r+Zc-F7eX z{I+zzo&-fP`lzza2gZ{PuuYG+R~-fbIG8 z(EE|w4-&$51(=Y8>*>Ii5zJt5;w6C@20N2&M|qq(YyniGUmWg-JrCO+33f%|3Tjue z`-$6+k}>&75$T2?L@$Z78ICk^yB{Zk=M#Ac0v2*L3B6u80L@6KA;IFH|He8E5}-h!M=@b9RH;J%951B0Vng?MzFhqzJd?wpEWRY_I- z79LdCu05q5N{9Be6e{G~pH$VOS@md6J*umRkeJd@@2RB5TQulb)sA9=hxP2~PvvR7 z^b~y|YE9oim!~tOr!(pybftOvck&c+JcnFU2^!4G=$JXyFr^$@g(0UX)Js(u2f)T! zi6XXD7>P>-H~>Wf^Tx7Kz!b4;6y{Z^q|o4;SOa-)3J~q44dgln*z9^*Gdz|~5z8Ct zw9vt=bXu^K&X8XivpF2F~`4ao3H6uCGI(O7AW4mX-Dp3!I@}ueli}5^Vc7Gp8ku=@{{cAvs#zDFSVPJ#v^6-;OPJtgeB&?S zqN*1v%g&rLS2^pvWd16r9r>%ER-TyO(fPMCV8&W$WytK(&f$l@G>6{7H(tSo{0CZ6 z8Fh*NXANlDu%8$*cbe7FtG8i3={^+RFbR8M0wHg86y<^rL>N$h_IFyh)|*plZ3}mS zuPE&QNTL( zFvw{sI21Q`SXGt#DKazVh@&1=c!jvRgYCG}tc+e?%Z05NfTGSsVCvxGVW^D-niLj6 zYrQ$Gn;6op^NTE<=4Ba|W>r?4I&Ccv##b(N-$OIA;_94JuUvFSpPO89L3c+l_y$0n zWWm5Lwgf^7BEv5G>QT$4IUfxZ7{f7j@%W-^#iG-OX=Qk$Sbf@{WN2Fub}#hH+?-c= znb?h0)cpY8@!fb~ayQCj<@-VE=2r-SRA=@hUZl#QLILujb*jLzGunC&%1dw>?|Op( z_8HY5Ki@a@Q+{tvwEVubC~!`Kt9vjy$jeSE??xAnSE*6Kkxco8iB?Lg(=({25vVR_TaPyut1HW z<%Zk&c9OuXUF}D(WoW4M1F)FPE?%W27?SYvW-!FrXe%1Rsf49jfYKhc7i_ncklrKO zb>W9v3m3K>eLfg>uv;w%@16&nX{#4B3-lkxTEocsr=smDkhPnmZ4PKVM6+cmp9yax zpGY&L{|m=jzKMs|zhJrTsT=Sb77ydDM7J*Z@PREi;U-A%!9(~# zd^EUw?!Gq|kn)VafI=*W%y)LBvwm|;Yww^_BDY&jBSh0>=@ExcRZTm1EQ;1q`bVs5;0Gv#;5DHQgfkvHPTp*aYKS)e z1f)vELG%gK(1xjO3NkceP~Iz5rC^|jJN;s##WMk(;HSbVGNLDQyD8T{9Cy80m3!Y; zawHWC<}n7Wt5JBws871A`HDAWcL_U8+^dhxA_#IfEd!*&ccjT zEFDs}8wgoHb%HXEw1x7xQp~^&?jhU(4dC@OJkx{H zFS??f3$=7Xg0mSVrS$@7qxV0|8RsY*2+i*2!?Yjt%E(g zcYDBUPgCx1G-u2l?K<`As7z#y5*mJh{Xx(FYi`fLYj81>OpPoY$BRtuIP46j`e z-*9IaF~ETiBMnz_jN_7BM3*|HS=pE**`lD>&oC2ohT}?YEgB;JMa<_}@^_65a?xny zEF_rQDfYA@9Mmb6)`9&B(P_tY`@kAG4bU*B&LA z$HT0`)>(L%lUN_EQVqED*r6W)tE4xGcq)ZlGoYsukno;NKvPkoGIt5b6z+dJboyaKuWW+r%Qpv4}W1|7z>DgD)|`GImRd@ zE5m?sz_Jbk2;Rr&Py5r*nY*o{eP+Lq-ntidp`&toFGOozb%af_ALyN+EqSvna`qI2 z(YfhtK-C4 z@vSp?4VfwPW%N$x$MA(- zhfg|8u6+b}btrS{SRvD15|*Xt}^`)Ji={k5WAiq}38n2;H=DaP%} zR}+{7tBLmKu0&yHrKgWrOBR$Yl0Eaoj*cD z(t$95x?~uB_?t5i9ph&*kFoxd5^zMjRlEQZiaTH+V_8GC-`%wEqo+dhlj{Lqu0N)1gEycVf z1&USm69j1`<1X!`L#}mRTll0Lexd#NyhMj1dFHb@gaoJ!+_Ao>`@bxFQ!LnHP$G$>8jvMlv{i z$j*r6wukJDfwPEH6w{9u)WZd}w`BcYxXlN>97F9Qxa0-BWwxX@a8G+_Yg>Yo6;C6nT+Heae8Av@Q^SDeoW?b zAmV3BvI!73)bADmq+L#lnZOmORQ5LEQc6bC3^Y<1p|}XKn0OGGnZR|7C_0%f7KxUF zHDNt(8_5P({(fvNm>7l(JS8(JScG}|Pp3jlF;l3{jIURavx4Ru!f`A-CbE(ks#Z%C z6Q{4wg_>N@G3zdPs~qtAKN}P%ZYNnB^oW3fwd%?$u9&eos?cWmwY{T}27VTRLslV% zhqJE(u?kPz8=*n?h&C05!6zb!InIfqv4vz_hv3b~J?LN?{)57xF-H*jDnGTI_k!T{=Ufq$#EyL6k+iIi0gH`Ls&)-OHZu z;|Cos_lLdSx@6--3%X_2%yMqAf`F1D9mnBwS9$#3uF5P^P!Sb>X}R z7E}UfxQK_U%<`Xgtb_Y)`sAlOtVC$>06;c%?6B~My)T^}$(K5MY0ZXlgwJ6$F)l#a zM1U*J_0hF)BH)rH$bhsUnE>wuV#LE|ONgHK?yp=AD0Z(T$DrlZH zE(qhM`b)e>Hm^<)Gd2_<(O3bPGg!Jb$uULZOEZJ1pXY@!2Yr^8bGW2aT&l?wG9zjd zqzq6rPZ>w2#abFp`Zh1$;e|<2eT$d(cqtsImw79R#a}@wt(kz|&Rd`30^Cr)VyyJO?x#kA^rPK1e#Hh%biT!VJ{$0sW0T!2e3g36|jsF=J zqm$yZ%(@vR%-ma+e!DxP^xsMwiG4s3Q@2j3N7L#N4pl(^;vZU7R5n2&^dV{*$LNf9 z-os%EqjsUMjoRf2#py@TyfX^=8i8agR`i`$(8#s`jf{@yV$`u@x1pUcJf5mJ=GX|Tj_QK4*jRO{S=+t@d4``l>LZV7wBEX(Q%nmvk1tPvv7vYSZ0R9Az zC$vgu8t6>teDUeD*k|F}4Ja%!x+HJGp9nm_ur`niad=Ywcrbu>N7>WqNLM*}*|G;T z+pH?b3QnvEpAKHYmEn&r9STJldqD(?*fGb72;r<-!dmOirBV<=&Min5UpSev;3#(1 zWG%zZ1-CNqZ;m_6nmkRtK%Jy`4SSFCuI4d-3C4xQDjPm65MKNOnyJQFWOSIlgVWn<~85xAgO$nfR9NF)8ovMWK%qpgklsT&o-@UVI8V9S!4g zz$I_fC}ZLUlRuQJx;Yj|wx;FSnM|B~6{q1oY;Y4o5{P37BeFaijcrT5bqu}Yp$Uk= zJS#nb7m_EKgoeu2;%KF4Sv1Q}MiypFYhBV08_jtmNpiP3J)1)qOYw?W`8GJrq&QAV zG%{~0wn17G1jXsn7JMk=`|`(#oH7BD&R8N%=kiDCQeMmwE6j_eRoDj;GDhFQgYQtX zoIyb?pNCOEv*0XJBj+Gtp0A8vnlu5%r~1F_x&%(wY^K~7ie&DjYQ>-Vu>k@{()tO?VF>`oLHf)p_IEauF4LYb^|-#iU# zcqha$yOQ->(ir1ut#k&NM6Nm{_GpEv+4C^m%-`wC==n*!+4f_xx*TDg*s%x$G7ORo zhJYTyQ-PH|1Rn$|z*v^w$GO3Ox@Gns0f2p+A~U|xNe!Q8;3ABTN&al9QXKR~JUqnP zA`%+!WB?@en|KH`BS~dVHqcLQ@RyB_bT;d?-kh5x1#YFGhfS7#%!@=`XDm5pC*gj= zCp_>q$&@&hsYXOAtq(Ov%e?L!MyLgsb0Bo!34^a-AmHsW&R~lkt&YyVjRY-OW(;MJ zEqEb@4pV$w&~SFkcFbbR9ZY6&T^{Ow%5{kYfz0dSbE_GcE}F@Qzg(Dzm2aY8sq~^0 zaL&XODk{EY507(EFx^3`pxf*noNR%={euArs+^so=N@T^kwL+Q#)v==HZdStv?19nAkF@39QkiV5QzOCvGKBA>iS z_7LJU7Q)k*%?G3WtYEfF&gk=#N^1hzn6+9(cd-LLQ)YSeU+>bHqV|_plCyw)fMbSI z&(q;}K~slazJi?TIU78~>pe-P*MBY@$HCfmnCmX5^uCOlD`*y zjibY4hr(mW*=q~NIIt7R^1Sn<$>B{hz8xA*8gT)T@HaUQVVoY0IgP;{v*ggvVDMS* zXsm9);Ihd!i-GJ!_m%|paPElizGQX|om-ceOpY;~Ud9yo+O{!5t2nNI$Z#^_=CYn- zmy;e(7#Ufgzg+U*i&7)UoF6lgby~bz+>5^?VXMz0W@S#Ne*q_ROrUB%LNo+kDu|77 zt1PJ+yE)Q?%On`Nfnn1{Yn>uXCOtE>!Lje-p!H#cdlz;8$hC@@SXqQmf0*MYfc~s9 z8jlR>Hvt8ki6EeAUgCR#Jp7FJ1nX^^u0y%eZht)VdiKdoyX{9^RL}5jMBpi{c9RFV z42JTGo8(#2S@_0y{V+K&ISju4Q*w9tR)LnJ4D%zJE`vE2_7+QPyh!*}+%1MN^hdnV zchVj&w2^g}7m6KCx-~B!aaRhe#>!q{#R@NTxG;blzlBTVOwFkwyn6nH#?0~@{_4w( z<$7ab`7%PU7Z=YgpQ^#`Uc9mRLd~fZi3HRt4X05- LJh*ZWQQ`jq)JF^1 literal 0 HcmV?d00001 diff --git a/python/mxnet/module/__pycache__/executor_group.cpython-34.pyc b/python/mxnet/module/__pycache__/executor_group.cpython-34.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dc95be1b9a95fefa4f606a3e50d01b3dbe9cf2b GIT binary patch literal 25082 zcmd^oTW}o5d2Y|_VgW1`7Xk!G>Ou`EQ{+k@Xeo~6NR(nqkYptmEQ^fkKw=hyodK}q zVt3IqO9C6xxlopz*skLwCy8^ZN~&_JJS2HKmn4<9Q>Q8~sXQbpXR7j)CsnF)D?gla z)j2ue_jk|CE-o}^(@LrWYnweYeeeJN|K0w_%6Pfc^?vjpR+ajD)%RCK{*(Ac{}V}| z)G|^>1+Hp3Y8e^)S5Qq?Z5GsWK{@tYQ7sRtps0c&)hwyaVYNJhZ+I@r^Qc-bo9AKG z98;SWwOldJBdR&BHYe2b1fEA#P*(4v&q)=Gsh}eHDHV*XU_$aoRDe%YlAl(u1$dt3 z1JZCt-NlfOa;TT9Grd<|*$(U7xUIc)-R^e0M&$Wkz1_MQwi;n8_TqNC>8;`Orr+|{ z8?AMGHYM9@UTEt^3m;$=)Kb++#y+EiP=~={qaJ_Y;w9b@S8<9r@{6)M+;2TBFtJ#K#|zDXBZd>UK#Q7&Z+w)KxUH{Vy!JcF40+Q?TVKg9BPor1Hq^X*IDb=*7XdR(T-YFMYQrR$bGbJ|w{ zt(U9Ay-$9+?FRsdW+RF{K!a>-zcpah%m}pr0dCOP)olS<$Zx~3jqc&DzmCJf^8PY8$E-z^(Yg5Mz$?zFQ|IO|Lm z^f{Ev(2HosMG*8lhK#|j>p-Cty-2+-SKZ#Dvf*4Kv&k}HKmd0mT3D*O(z1R8jZ_QR zP(QA3=tuFSPcq?*O>!2wq- zwMWz)fxGK6t3as}fJ4mc=UlZ#qT@(|RW}*;BQOi#l8jTZ7;oK$k0zx!Y(;HNUSo1m z4s+isUnH%lw>LX4zQnFMj!~y~uC~clsmc7)Mi(wN+jYMgy@*U&sXC-9sG{ePBqhJo z0kzU(P$WjJx*u1I$*^5p0s7^l1|`v_Sc&t*44z?LFyQEZ-F_UOqiH0nSOOAPfVM{q zlkVfCSw}yMw{?C>|AvwIB!1CJtZWDx1F%5q5`G|dfYXE!F7POMFb6=)yF(&`lvFTG znP3^25e1oo`C%zsPQ?nB@<}dMK<>!I3Y1Buie<@Qo2lIM8pR&}0?72G zeFU3b2+}9}1mR6cDUpHmD;toTM0DB+o1IX5^`;+1UN-{212=4J0;x7BP6_es*da7} ze%xrc7Q9#d`i8gJXf@6Y&Gc?HnoYr<)zA~c$n#r)r@O>3j41T7&ZQKmv>C>sj^y!t z_9u_g&gLp)xaYjfl96}B-79DZn6QZ^O_4^O&mapjDlCTI*$~kW>(wXfNdp$~A~Z|n z$Qr#QkCX<|5@vyJwe<}$tli~bU-Me+*jrgy0_|8?nMbv56EZ^R<#I)m8{H6b`7xMM6=eCh??6#ukTx8La4BQLPn*5GPUL zPy^0erf$e?r5Qut@;R9GK~JDIhCl(!~Bc7Zx z8*E1nl1&mBBOY5K4D^EHX+&g9g1A7+xQEzhJuil9J|rHXQahtS`ACI$H(K6mw|*mx z(Rttc=k%%(_idS_pk0p zpl+}P?4(dePDR>l*51@;w%18upYuK~nQWAHxydJ2SO!b6V)3Q1jg9L=Wy4jt!$94v z0q&;Zb6g1Yy2g55&?R=2_h>$`Fcx6`vmvdzE^MI0iY@1q2hGZah%D?=Z!mo&*XhVOE6gCs+CN(kC zP>vU&zUES~q+!N9!YU(bPz}}!QGIH#UlgPwtS-Zn8KGf>O~JrYmJGyCQwye&F?myA zdL6Q-Qa47F{#mLNtWr@e3}9UHlET1Z%lu(lW?ae)tLV=tD#8wfDgi|1BIQS#c%Xv7 zvP9z!S|3&VUnxl2&qL<(kW_*&Jy2;v#bUyVC&k9}n@+=wYD#T?+2Mex9%^+5rkzQd z0NO7p?K}3dB9EVSFsv4IKQl>KeJ0h#t+On2dqPTlMH<7%ZaXN1qNu>;v!c77ueQb< zJG3Jzo=%Nm-*n;`Yd%883=ort#VPt5CqAlv+EMYWis$Sn8k%4|Gg~pGTK^-R?yOAr zH{5)>cWjmYru(KlFx^kOX1Y*N`sOglLUy`8?(zesd&5PcJ=2|I0|QmBny2f>?BtwQ zMa57{MmTs(3WhSfwg20tk?jPl*G{Mwlm~n=dy2iXQ-vyH_VJ-K^Kc4FkECy)&6qML z(+m_E^Ttav&}qz@(`g2pjd}BtGz0y{ym=zcK+7?2&ZQaXI_AxjGVLLC{izhc=2LS4 zP^M_j1=r`ZnVeU4pVK;ah6~cr1wuc-o_0i9u@&^Fu_a@-E0X5acWd;QK>#8eOsqzK zmKr@xN=U^T0$mvS7VM#->3>C~%DvH#T|JJtjC#Mw3)JoRT$ob~r1_e=O0 z{ay7weju-3sKHDC!*XpuZT>?Tyndd8mXQw*|Lghi2Zll7DQkEF8M}s)xmCm(LaZ$z zDN=&9>vUjOznXWYe*pR3GZ~-S?80;lmO{ggVYY^8%^TfCcneAX!v15V8N`mU!Z8Bd z`v;irsM*EZe#=1biQUD4E;8fLg}jS17L>qBh|R)BPJodBD(?3=T00S|8C8cf$8fBQ2Kj$~Qq4_**g>QWDu@W>0CuOtRQKJ_oC1dhW zhB(S7nW$L_MMfPZWBv6iMoYD=RLn}JPi7kSX-r0+cE>ff`6I@W)c@IZVN~$Mc442& z(z_^w$(ZTfJXGX^futxIH=okSifpYl_9YdQvD-djpVG2pur}1vqNCL8Wa8B+eTgeh z7|^dV`4khM$tshVn7q#9hma%{yUL`~w?~{A zx8hdN<{aAI=YOUz`(+xnr<~(%@ATq9np#G^GmHB$#!cSvGJerFkQmki{D(Yvqo|_e zd!ewSnI0k)kx!lqVDhN;(tx`!fJS86s4Eu^=COd@i}~sIl{jo5?7H!A)y~aFvQ1pk>6*SoMIw%2BnW=d|+aKg5R$-i&>WWY-)3dcC!!zB+GN_!Wzl8VG%qz#J; z1jE>`bh)p2h|}R7?1Q0@J%~~f|iZ4lJ<-C%qYjHL$MX8 zj3U;g)!-E=QvU9peVN7^SQFsX4t2|Krmpuj-QLVhP|={!mPu4Z3yWP(r+!Jd#x0`ra_9R86(o#l9_jO0UcdM|ueL&+8wc#|(Kz0P)VPt%e7zJcs?tIR4B zPWEn+i>jooBnlCo@wT0rT&V!Cofy~=M{5*JCk{a1#AU{VJMc}m$qq(U@4@YWW0u_! zC^%@v^RAtxY|S{73?To5!ns*m1L8WufVHge!C?M>vPE_t;3B)a%f1f_;Q1hO^dDGl z;-`hIXLG6oqGw=}My=9*U~8!20|MKmOd4Wz01LfF`>Z#r(c&>Vj}Rc$#pZYbt}-cTM^7 zGJerRNG!lkSyC9}bP#hR0q(b89B5w|$!DBxA1h~+=*7INkCA$V-tNMRU96X1?}Zpy z%hkN)$NK2GREMyPj=&QJX?sET5h0OaxkFGy!|j1RJ(v!avuQq9I>eD2hT`Y_=Ok!g znLl(U3D3qM5MnA@w+BgN4q%dlq$Tg#sU9(A1;i;^Q^7UnR++rbMA(7s#XEcyqXRKs zEDw4GxnxNDEqL~5_0z91u?ujCd0HO!njI#0NnihnQqhx0Q~}yQtP1>1!oDzyf9J&R zaH25k&Vdw8x>NGDS3R^q^=h4%!M_pwhVc_SK?gekJ9dP~IFZyV$XP1^i~&#^MEQnx z4|WB7;Tve@5QSk^W}Gfg;DjNw6a3kHXY+6~o%!X&49q;Jft z4s9G+p{cCUw3tfJAA-vj+yo349quJBF@hlvZ7{J*4eikt4)q05hr0rqXczddtrf;f zx4$I=Z|^Te07GRg4=f@BK^#lTBz?iRYTZr^%028KSL&oi{G(Vz%}BV+VXXMZ>TDUS zAh1X&gQCjE`>4pX2O~sgTHH^~`5q7P6ZsIUM}?vr>xCF0wy+{PWU|I&hlvo9*N{Ur zvxy}(p-K8j`EJ;*J@5}@KgL%RHJKm66~uhnXDAiw6R;rdo7E9f)B3(Q`RNW5F7Wbr zKs&X8ky>MUq90~z4P*Jw*Z&5p%y4o-xXK_-g=v^aXWd7ffl6t+2Ujzam)I$)P;r^#BqCzY%kHj%zxFINTc)vmWfi-#J zJq|@!42I{}AT?+ja2kbRO`*lnF^F2~?Yw27Jj2xoX+7*;+ufGMLDh{*AXr1dz!G47 z*Tk=carf(|15(*R2=i(VjlJ(bcq1&3?T0P`FIN!q>ti^EQxY%@WD_P}e7wXL`S)Fj zOMrWx<0vDeMov0K7}fw7r@rlYGO#G3B7XuF=aAVxWl8+vwK*G54xlhQ3oEj-RmK4g ziP!f}p=-Et5EO}EqJ?k)(NmqBzL*#ZCxY>kw%YG}>#1^ykA?wwwBHQajW5Y7xRK!f zqWzeT752MgPjMhpF^%d$U{R+7Qkusd(WCR-G>!Vk@8;d~b;DmpWc!{zcx-{HS~A7e zvE9{ZNrUH*+iR3cGYsUjA!+JByGTV+M#PnE<8^L4w@#I(oZd$c-txRkw(BL-%69!5 zc-UUIM+&&tg7)IUcr$lj;D7+O-54Zo$O;Gm-3-F@AEs0SA3svGP=GX1qO1dJsC~wR zvxRd4c~9_QpeK0#8KbhRCo$H(3lufB*8_TjY%$Q2d6u4j5>QHO=8(UIb@E;&WCkG0 zp0Qyv_9z8O0hb)gRwO`cCEo}Hk!2wFW(rJsEbTtNj=V8TkO-L;AhV_@j9fuUgwW+6 zYHSM!KxpbCxkpPFb$VVVA2j2Dca1k2#d#Asxz=q#V1fiC@oMmHHW`;@9I=AGwoe|v zV$fEKHjEq3S1KQO9g8l1skA~eDkq3YbRWufNyy3cx`-nAIcE#|4jNr`mO%TaP$enU z<836JZfb9FZtcQ}Z;f8$DW&a>ela9Uu*Z_>-9w0>r{3UOA6ZV@cu*tPn3K`~;Y@}5 zB46yH;P2txMn3ALYS}nrUSipIkid=J-83%qk{|ITo2Kp|i_@nhl?`Rq@1tA=Qsp60KpVyF_;SQE4a86Gz@^^Hdp?iI8Gk;5{o5fG1zQHe7IJ=|4} zvxP&H!l(-VpJIBKc>vT~Aq_BL8o={@=Vt2I6zu0w2zVK9`8wc{2nkgjJQG3eJsdH} zdjv5cm41om)-nWSL!ZHcjx;Qr!LRMr>*9V%k9E-h88tQ=O&?MG^vSFNJ9A`fL?(^{ z?GI=XZ2Ag#?*{5*5v8j>!i`(-4MKV1$pp5YhKk<`f!k1;1_NHb4!nnb-LcvJSX((- zfruUkkCWBkWY~tc!Wfr@{uq;$EY4NJv9lF;D~!9kwTnQVKi@~cl96f`n=x7W)x2R${M~n(QnM}6)DJbj(klkYM z#&57lMKQI1|!06AFV>R=UCD2oq5tsQ^lRQtI?(9mEw; zJ&%^h5fA}nu_|3sxDoCwk78x4LV1dn2T6N!WX}k(*7oQfAZZ3V-2x#Cn%$#L@2UHk z=@k>(Kw`^3!Wi=f{32=)*h>NdNbGTO^a@M?06>&^u2a-FFdnd8f)(sG#vRTGWJl>A z5!fqAP5FlRgdR#$MgQyQg)m|vf`SbJq@!_Mt&N!Wlxp^!T*(O)PGiX_Y=(e?1~F1U zixQuQ*@kq$ds~X&KujyBp2xuf+f#lBMGi6CY9RdH)?k1{3ZXgGI2?7}=V0c3WVD~& zUD4feMUGQ&=5^3U63~{2m***05Lg@0CZ|VAC!{eQU`F*wMYF>J#)FKr0V17d$7EBM z<6ZJrq*qH@E$L<)2#(j@3I`gsr(jmpb3pEmkqhEyi`Cu;!RWQX=b0fHt+HMGx>YTK66dlykwPTvtn7Jp*iG<8#KVA zGX*!_lib`i&Xmd@?8y85K03d680xogkQAtXM(hvev&g%3 zdOz?TcMEv}tlZiepJk5h1h$+@_$Qg!2d46i8Volgx3Df-b6R^W=NHk}xHckv93zLA zbY`3qLLz%-4mIF9xhMblFMlGE(;*4^UWsNXnS()q9U*AY$mh6z!M};3B<&`K*ir*P zVoHTR1k!3z?!Jku8FGdwMD6>bt+v~Lf9jge;6E4P>bYCpc@1*Tnv=D1CNS68>aO}RrHS36SEo- z)U-XKqIjUa?PmmcX*|Or7$^-VDa(m63_%Wp9nA1j^@}bXcVEVhLSL{7amn{F{mVdg zqpER9%?!cLS~#bd9oOJt0`Cc1tZt46W)fQ;pFNzcWw$Ipe7Vb@m8*b6;HcfzutC{M z#xTa+-^76mz=5#ix8Ih}!c)Os^9L~4442rLDVMTaDVH+boPV0tc+MW(AE=*7MP7u1 zoY2Dag`+4V?PYd1_Q$`a-u(TnrU>=Dr!V2Yln|Dl12!Q&QIqe24J83b5+tAvUob(8 zMkz-s$q;ePkPKX(VWG#F>_!n9i%r@8MHHb8PK{OIhoblCICPcqdm>9-eSir;X)wsb z<2FVOjwY>-9Po{MeHb|}`Yuy|K5o=7&oBo1zh*fTD?iycjpjpYZ4qQ51=(YULQOhVk7aqbnFva4=}ZKJ4O z#!pz&C|b1GhC{W&46M+HJnn>oh2TjsIpM%0465HC7LqRs3Nob|mvZ9bwLcY4P~UMD zTB?TB_SY;UV`+GN$c=7GQFE&uqTLH(j=~X^VKF};!UOCH?_rlJ(*1*&4`z6EUEvPw zQ*a{T&^7L+Llg=cMWBf*cU55+&em4?9oOX{_++2&v}7T^BoND&*nwXK+v5d2BmSRb zJ>G(8mp0JUwoCe*O1R#o{8Qg06DR_PwnBXy%JcrT@0fNN5K z!68UiKs()5GmLUpx3Z6+bU^>SZ}J7LQ7+n`ztp;Is8a7@woqX^^!WJs)gfeBuUYa+ zvo{H&wm|)I2Syis!p#4McP@v@2rZB;aKTuvwBn+?M2eTB} zVzM^+%le{KgE}r34)^r5an(__vRS3xf)3A1gL1zFwjevrXhcWs4ej{pW}9Qh=G+Rg zL@v!sMzhtzVY%1&TDYi0tBJu-?1xwRff}B@I0hvbA+c?N?$Ve)^3YLFUuw2a0 zmK6Lt#ZwBNz)E9Oy!NF*QUI|Su_o(+!yp2J_~sfqoLmD*z<0!?{deA*31Wipz+1it zNdkpH29=OuFkFH0Zy*A4cQ zNXLXFdAw{BJ}48xCQfQ`UxRGnnmhmBurQa>%QI+)18%@AL+KRvNDg|IRze6}xQY#z z1l3mv=Y1q6gJ#@XNs}qU2w*!)5kU-c{q|7O8;N&`Kw;p}w%2gL%E(5u#%%rjrW}0( z1?_kR!+t|Xoi%LGoOb|mGUom9;jv5c{`b0=Q2BU2@aT z5{Sqys9=BbZ9j;(h63MQnPFtdK zR0zv)1kWCUM)0t6uk(8YZDp~ubo_8FL`0bdA_1p~`jAF&^&UhKgrfq>@gA||HrQ3v z;PjX*hyuS1i|8Tp2?>C4dO!sSeS!1G`{Ej5wnhmu{tD8lVHh z`b6iUmt?PnUf6L~L0aem!{A58Wq}gt=$Wfvp)~DrDd(nS!X*JzoP3WV12%{mU7O8g zI7lZQ^$X?>P2o@wdnHQYBFZEGQT(&t@dn(N$$&0#)jdy8V`p&tT|C&#IT)eY=kv=N zssp0r?|V6N{X6L5g?z{++T%^UOhy@`At)>7bZ|3*2?`br{SlUM${!T)ubl7w;Q#V4 z?znMNzQ(sKVq#LcQuCI%ebz5jA|4 zTV>K@@_S5#=*Vy?Mh7vT&ScCyaS~B7l7iK4x#J6H;sU3C5*bx|OfCSXsTybbkv$BiY3@C^c}yU?frEu*+6r0!`LrLm3bRl*+` zi7N9ZTihstz@-cKWJm*H;C0lOKO!T40&X;o(Ub}}!I|r8IL0dMhA~V6TzvtsJzQJ| zTN%_qTyomr9rsqaOR*-(fFRnAPGt4vhJE3=cwE90||P0miwPR>m}T6w&3tTHXNsdMlPC#ev$AbN?8dPmP|ornnH^I6)KNbhVZ}R{~vS)iB$jq literal 0 HcmV?d00001 diff --git a/python/mxnet/module/__pycache__/module.cpython-34.pyc b/python/mxnet/module/__pycache__/module.cpython-34.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2fc2f5da525e4b4f32251be80e59ef73759a67e GIT binary patch literal 24713 zcmeHvYiwNEec!n=!{KYlAt_OhUG46bb|rCaYNTD;aV%@K)?+2h_HwsNZPuo}ow=NQ zNe(#=?!9*;a5X5j2B(4JlK@ubg(0(Wi6#Y=Nz(v0V`A`@r ziU9pkAVq=xe*bgsV@AW3w6;;C$dGg9-1C0@-~V&}M`>niw`YIpa!0AZQJGHx@z3BJ z{8t37Qfmk;<>pk&Qfr9lcuuY5l$%#>LACO#T~KRd%Hs2wJQvm4xPC6m^MqQP#Phgv zCsb=nwWrnEw9cPYt&(cbsI?i@o>gnJs$EuVW!0WjYje89l$4lPYxDYfTD2Bb`;b~Y zq}qqo+F{*pNwpSL`-oaQqS{B*+ELX$rq+%r>!v%S+*$PivzS%h5p_SW){b*JtCe#9 z&K0}_N@3y?Q*9q-#%d_iOoxqE;uDo%}Uf;Lx`p)idv$JVCc71ifz20rr zZKvbf-HvB>{i$}>N4;0~yhbnV`gMDw>)WmqI`*#PJFS-2YWAJ5+3i%PevFk>D?iR5 ztXR=_quU9+y^zt4>-fI2kLYA~H*B_>eb4`~rOO?^j>g_FZN=T(^t-*?N+BxO8ouX* zUTx=2fZo06P_5Z%hD`^(rirt)-Y$mZ)i|Vf5FHN;+%vE?;(cDv0^ zSjk1>Z*`iORy&%wg~qRYL8Fq7rqZQ~7Ib;^z1s_G9jEOD(HwgVyc*Y{=Aur6(~mz> zCK4RRa(FmyI9~{HxNzKXo^B}jrt%8P8&lgwbsyVUF#lguY6FGGlv`BqkMqEJ6G|U3 zcTzpD)P1xtDJfIRoo31!A|-_rf#ZM!=pjC<+%g|oct%p@lsg}%;K-oNf^rYVDP>7H ztlY&oWlmC#DEDZbGA}8|lzTi*K?_eQ_h~-j*m|gbf<3x=E&vPsKcn1d)q^p$wkVI! zDc4r+N%a7rhQ?76DY_)$r-M%_m-6Zh;l!w9sodb z>X(&!PR4-2ppmL_&r6l3ACr<875e?Saz7zi&&uP^Dfg4|h%Eg7 zymCJ!j~K9?(M9EcT1I|SxtEmtvNVM;Usmob635^_queVJe?hrdm3vL%n8kJFzAEuE z3R}Ed`CL@gfarWbuau73%>Yoi?gY(--4*}_)H(Z-O%oJT+f$MrYE3->Xg(bG`79a^8>%_dlLNZ$Z8*JFh+%c@`qEjJKU>kr#IeKu z-DbmS=>gikz;o?p2bjy*^z25<34-PZ@->oSV=ZUhYiTfiyzYYTM%Z@tYEtpB`b`Et zvj6VpCf4qv{RN36Wd|nNLLRVcy(-xT8KTZLjTV7glZ7%#quZ&s_o|KE-cqG*n!4Nd zcWQ#fWTDgb?)2L0-a|`n*Xt`l<@Ne97S)w?1g`ICjAyekwxvh8(cJS~jYakHrj3u- z%+W1(x^~A4*{)CseU5AgL_}KG9R(rL^#>goDFJ9ey#NHG<>0J#HQ^y~lAT<(*L$Je z>4x^M-@Vgxac)4*c(&{&POk%sp}XUuxz!0m&vAL;_odmCkRY*ZUl5|#a60xny7WRe zg`b)p$WldiCn~V@XlCF*tu!VLwkOsf!xsRpg6qKiA^t7(9`JBZZ42^$uK;XsSz%ti zH>SdZ+8$$`|4BY>TI$}I0_TD_>ePzry&}oR&V=$k5_c2`$C*W%VX4FSiY(0}dCunT za6}o%`P7*SS*`TOQBkj;pXFV?gid0xMBamPx|1ki1NiCVK;=&^S_c z5+<`fL(+K<6O%z7#n6s^LlM7{gn~#;srRv(lj=IA$0eBMfEX!r9NTlO&c}Ia0g(kx zo+*boPeu+)$;G%NXqK)5@}wgmOnQAmiF69ckdA-==?F-Vj)3my2q=z@fY9g&sEm$) zxai0!Mpi4+D^Zcy#Pj_k4wioi!R=^FK;T9+CI@tFIs>~6vR3?%6)WJ&8o?<9*k8=) zonz`A*27X?$4-qaw2Jt*bL!oqTK$3YkD(yAqFnVtWbJ`tn$EC2lKFQR&TnZiB$g0t8?_R#e!Ej`gs{j0JLglJ8 zVe4wO>ZMi}lpwf_NL;@1Jm%Qj#q1YQDVm7OMpMn837~d54KK2XSaLUeiP^DzI07j{VFsT66r(9=YcrDrNB^IKs?&8;)>M z*XZpvTTRE`H<%-usyV$qQ)F7Ri<*NhX?kR`O1gafByDwjAsEgenyKkn(h7zZGrQB$ zyy1LjOlr+-`r$shPS|a%Nz6kp#O!g(7b91*&^|cqUEdS(g1Ye>CTp8MxMtsbH=3?_ zkOlAz4jqdklt0*VcEJ+QmbbWW%b;rGc?>^#CrS0nbGXf#>_vw z%F}Y*%}PviA*4mh7PXpC>sqa1`A@UW6AYe1P{lX+0s@sko?Ea=)+D~it&`SdZZh|* zC4ad?mi*;P)~s~`CCd2nubh)#9qKfR_sc`4_eO)0;Aa9~0^aBG4Z!}B5CHH1WVtSx&PzOi^APPV`zpAWzz)ryZ8f~x(cZqE+3y1kZbZy`5nlC@9in&X@`@BjnVz1?nO%hvo;$}#BF zi(s3=5v9R0q>K~CAH$3mZw=&NmDao6)(T zoPF?vj=6SVQ@oDdc3jWC+r+wawNr36-M@U>?|JErQUJubh+&4=bhdA}SgzAg%&D`z z45UAHuLJzmlQHb+x(E#B09f@Z}5us=tBko*3VSY&!__K9?a!xW!(9v z;OSNg_KlwJQ=JEqXVdork*KrCoLj-qTY+a&bcAwX-D&Jpu`EysW)xpk637zQybcs% zrX0}=QHkyLd@2P)-+`di+5G8IiQCXak{ra80v)E~NB0S?JP7}?fp`^knIWhdi3$N^ z4`JfBQ%o##9>&C?>?G(nJ3TL1ie!Lt)-&WPJ^T7}_m6SMvb2_~d+Ky>DlMs@r2mXA zIW&N{_zIxtERG-;Dl`~PGmQDaaf7U)LbJ2cjTR*@X~)n3)PDu_ee&M^=NP=j=B8wv z!a%J}$w(wV3_fLPGuRY1b$$a`0r@v&7{303nN+6Gz=Sv{nb0^Q!~hi9tp1Bv5)!Kw z9f{@~Y;>>bYi&ii1v;x?vx?D?Q5`)<{|d5x8()!~RK7oNhLsZa#~IbfMvY2Hbbz%y za9rmP9@kZ5{U9^0c)2ok|8WL(e$>FgC)q{_BC9gAST7wosHl_`nEoI$xok3|r17_HSMvTBF$VuM10Lbg2P2x66QIZOM<^Bu37aTe$;r(Wc=EYAXK=tdUD?nl zpD(BLs5#rj>AozSeaC45*)~K-fGRF#O6=xFvNCLvM?=Gn?RLFD)PQ7`^)Q2FGH6X6 znWPv1jk?11E>z#d?1Rl2I%h9I^(^PC0aMwsnKV(W-xd`#`C>NX+;O0lTyJ@m19t2K z@J&g>hTgKjNVaSWTNaCl(b^`TAp3iOQ>of}TxNZ8RKl#2rYR{78s}wTl^b$t3 zHopmEY9P!g14bj>nSnz8K3NGK-ot~Cs@oY@iiDDIsr6ET&t$|>!7&x)VweTvRh4uw zVoJCw7Sg;|k-Qs0`uDk5F<&8sQA{A|?mh~z>r1Ix1|lkyKAH`vK=tl95HKAZ`gajj z{e7Y{LO*F(+8Iw@vaa=&>sM$ttyduCiHSBHOJ!WI`DZb6jUYaO7%ZQ`ey0&DfhhUI zGNvQ}-sQ5DuXe7G?A{l-*io3>`Oe@rZ^H ziqiOgURm$%s<%y3(;vDiOovVVnR*)vYRrfn4SGNX8gEQ+Mr}j)jEJb8x2HuQ;Or|? z&i}?arr}vzwFP5kP6VwTI7?cm7;6Z$osMYmjX_%-lSZg)Qd#KatACD18_XygjHQ!~-EwYr@1CuZUpX zfoAVETP<2?8r|I{#w!Xg+DTLJNP+JIh80tm#~WG#&OhnUTCMKgTDu9WCd8`|LJ1bF zUZ>uMCOZv~h8j{5fiz~vm4#A*((#Cp&ZWENLnkC_G_Y=z@$~*e{%WC0molOmiS8*~ zLJ+)$C-V6t2Pu6?dC>54=_qXP4pdY)HBJjENq0Xpz|0ZqucT*qR7sg$I+~Ynz&_wr zr7eT6Y);H>Zg$`y;#MO{GL4-UI&b+T5i9*=DM0QX#|BMhmP%w?tes3>5Pl<1Y=#fn zD8|Kl@=Zh_=R;nh)oh=zjK`g#0$+{@9*^IHq>q1H;tu1G1gP?bF5ule>fQHgZwb~! z=PNa+LI76@A^+>5Rba~~Bzz|8K>>hDoteQT41wa%av~0Gc`aB+L{?2^kuJ4>&r1K$ z=^*A@pkE1~LGWA{d5VCc{jeLA-h?Tq+4jUh0`aLp+7#vCwiAstK)bxPN%?_6#7DIm zEO>f~X(5Wym!A=<#%lCF`af8?;4lIO&3eh2%9pG|xgvD!{OzAKRGuxE?M#h8Qvddi zO4%o!izcA*p)Mkb3XEmA4iSzgoiOxa`Dlbf?4mIkf!l#U%c^-$F3~*sq>~~Q;+!ND zuQILW!Y!7aN|74k3Ecael9`@FzsZC#_8Ryuo28~7CSo`P$tqNGQpabjVvEoVWv2Nz zNDXWR>hK)aV%#cNi@C-5GHee|TZeNxgykWhu<@BSy4|3inIHa=^j^a^_+a5BmGA?`2^5)LIiQ@Anoyf!Vz5SI zXo<%Y23DMW7^+P~x`V~iDFSd{bpuSMkR^bO_0&X6^Z8l9DMbSb)*3;|(#;u#+OLmf zuo6xxMQ`c8{$$N$;oy;oQou9#{}P)gTlVUnLzlyc z${os2rXoT*ylCvM4;0*(U?qHrrD(zzO!YPf3b_Ez3_u_l;{8-E11h>8sTvJw0xEb4 zcnsJEcuX|yFnkiTAqDtIyLAAq0sCzWBEt_HVl;h)9}K$@j+_#KMqmZX3#<=66@=i_ zfyhAAL5BaM3TOBj&Jxf0f5AAUN)suo;NO_WnLbCMqa$?@3E@nFzO*2!rw1Prq)0E@ zS{_*BKsgPOO_-}uGa~RX!8Cw;wWx4Z-JdLQcyrRooZ3DnqeYmEwo4ynmVsLf#&cZa z2oZlu;!mm0Uox^?R{nP_M$pdFk|KpsjOe`cJ(I*$fYZ^Bq$u23W>{NqQw2+{ zjt4n8fagZcBtmRSP*4MoVc@br{oBCbZMbD}4tqePDJCImx~*L$VaCUR>vgG3U$?69fD?a=*i(=dvU;SSN(a(lwg)UpEw&Pk zqJUY$s*!7|c4B$K{v=zG5fKw5U1D`fNw>Ho1vQGI;+FErPDNtCkOz9DLrasMJ=DNx z=?R;vb-JQQM6$@3-0Q5!JGc#^D&*mToNI(*`zqSes_N$E78H)2$)Ohl{aTGjK1B~6 zrEUP`x!Xxj2|IvxNE5`2K3H>}z9(|b2*Gh+p_4hw>5e^x6IJOJJBMV(iKpvNZ%ZSq z4U9&1qbl<<=Uq-;ETj%L5=-YYT+t@zA*UM%RUQPjnw}tb?=>+^4)i!M7wx~A0&46b zyxwc7L9|4s-xW$L2%j0#IsnVNwltyB${!1Jx578}Ju_l~s2E zm5MFyG_as*Z;P~q#x4VG(3-rmRBWb`rO}QI;$*W@Ay^h}%q5RF2m0-AO_^xe+ zTR}#_+<*3p5v65I;L=N>mwU-Wu~b>Pt(~c^gUm(~vCo$ODjpl4-|``cFplpkzQM;3 zfFS`IeWyaFK=>(v%u!l^Kx;f!&4Xy{jF%uJ8#Wh>x$xcgATQ=RZZw8l9o+!VdbEi6 zd&ozz>IRsMACiTL<}&{0@BlH4<6ji* zA6#fb)A4Vq?`V(4v||`lWH^|cmetZM5zpYINa<0$6u+AcO&v6}q--*@pW@Q$^(#Qr zfU))&&!G=QW4D!<8lAf ztaqBhJcBC?t}>wUOYFD4%@~e4d%p=S#rhYN( zDbb4YBvI#Z4c~w;i2Voup4dpCJYb6Oqi}y2X0^rWLeF-T5zRd#u_cB)sHnGc&QqA# zo1UN&AyS4|VD&&u??SyrVlCNEu0r_x&vUI*CNtWv1EJ*-wkH|~Z6iho_7mTS^i&R> zDF2%%{jYeSL|}t4l82{1Yw)sZ1%BuB(1zdDrW?()Jgt$d;j)F7sH7&*5QRM+mEgC4 z6rs|@5dqvXlCjP45w^KYTJdyAvPHv$LvUZaSFN%Pw5)GRCTptn52|wdZFnEvfCevxQJr2qPNAP9M9wtQUr=YMEaDzS3y--Q9TVyQwZ5B zN5KvU)c=+Rf&)2^((C`C0JpuM!b5RyhdCCGz4N--`+^D=2L_3Z>+gI*eFIP<^}hCy znQN#eUR+>n@6K3nhes$h;t7goTAA|VyuV{n_LY=N5P-s?A|OJV)N2Eg3-Yd%{#7dn z+SqwBo(@Wi*cW`?GNw`p+z<-K)ebi0udVQe+Qxl_X$ozazRsef+ffK>NA(p$8u8F8 zva$ZPXSslkKzft)ui2Eym;xaUTkuLQJf&Q$CU}bZc*&mURYbjgxRA)m3(AGa3|k-+ zlx!U<{ly&j=Cou#s}~r{jwzgB=U5co3GOd4l9zI4MR0{Q|C>3sx?~c$o@P{MseBii z@rZL12bGtEXiK0)B}MV6hwhW(8sl^iE*ByleG)$Z3#+`zI>0c-+;07>0M}V_(@O+_ z_`Yn~!s~I;0c`*>&(@Q?get|xk`Z>gr2nO*vjZS~cGEK!m{L%*-en{xyUVM<4 zSFam+XW)`s>}Y+Lf`Awwh4>B_smMD+QWzL@vSc^|LVV~7f-{kx14`eiQA^HDRz&Q7 zByA!s7Mn=ilx?q8t1B=V!2G#3CZ>Dv;{5#WD-9rG`4EZHtbPdJfXs|xchAI#E@pSZ zy7QC^zBv4KrH*F#UvT0tpssL8h7FFaW@L3>%&~SdPH`t#3N%^$7g7Q;wBXGvh6acM zBAHr-OHWwgs02Cz`rwIV?!Sim(X3{zn(p39v?}#K{d#}Rl zGa{kf@5?r5_g^+=4q{JRZh?!0`C|3L`E@63Y}Eo>3_lNlP`ToR8?5#c3}1JkGzI>)h`$rCogIUUn!N6egr(du+~O(b4_oso zcP!RO^2$@Fb11ifn?Pl>(q9~T*c$`|e)!9;@V~rLM7o4=B|d_96c(6Qco5f$S%ipd zwYBCtG$RZZHpFdcZVegA_y65hktQ;@M*JI`L0v|b{@rp-I#Ghemt0_5X@tWU!Usvo z60{vv+nn|akr`DgPprE%f{J0yAJXI8rrWZBgUECv4jXn{HXCfvuzchH3evdH0k01z zID~OK`PV;cHfUJv8W2&4Hf4+GxrznW{3jzB{*#a5q>s%Z7mk>fR?_ zA`vwNzCpgO!(a3&vD5&yxt6%abZ~`|^hn?5F}I%imX1Vam?bjJK&xFph%RDhQ-#>w zttOn%+Md&4YFxcuw?k(K4wTfFY2Ql_=6ao1_o|?miEcNm^Id`QNCf=PVs*91ehIN? z92G#`T_bBm`pTcoAHZ$~JS55b-$%_dce;RxnmT4BGKXnx0QR12V+mcLh`4brTwdN; z#>QTSdk9DQ6YVT0mC(2V>u((OvUh?^65LRMJ9#y2V9#$FbOaaNPb%E8frZcDP(iOyCg5^IIHX~Ksv$Q1-OocfyC_%J0se^&LQ3F} zurnB`@?ro)90ovipqr*eDsMq9T&)!Q-+W76I+NAuIt{3jU12F>n}XpDGw3|_)=CW! z+MblzDZ_TcN-dW*>Dt5fJr^$(js>ttx=FXAaYK`%DY6dQ=hG)SUV-yW3aKeJ7QRHN zJ|%T+8vZI`pzyruYt&xC8%(i)@CGk1_B?{oKBWG?U=}x7?8JqWLTssGnlZ!{ZoM-G zd+Py@2JMfc9kb~N`1A5&RSepzm7MblpoN=>v$N>hiO+^t@BoOvMT7^AYSXK=@s^Ti ze0?0aQ7~w1=ztJ{=T>vaL-{WQFkS?F_rjOnp>gSTbm-r0(rLVzxjIP>Ou$#KySgzGCN`Z5>ynSa-Iy9oBOiqt6 zX&<7ByyPYz^aHzMWcxs@-5q^qVnz-KqBYoNY=X-y@uR%WV^i@6+B1}u>@OGPb;15A z>WB`F*rfxgyD)^gVFPvvz~;~6;sdE-06&Qf zMvzuAIUgP;JgKMnh|};7Fbm=9|0T@wIGlu$qAF;geQ!opY#>(V+X}MElPT#DYhHMNVU9p!XOtr$?6%IpSdc z{i4q4bl!NZI~yY={Ih~Yy#dLo`Um0(3E0rpfkj!BgPJ61=o>8w%XKQft{((*KLX4_ zU1Vh6!yuLMwO+!UuHT@@%MX9~y@qe_;q)3FgsW8$NiFC@u840G*{@TIly`pe&3yup zXG}5rzQKu!G!$>*`!I?k-CpY7HpHO}@@+x42NQWb2rY?F1j7AGGI25c{bd=YaM7T< z=>rI8vO_${x?Id)OSud=O^WIjxoF-3lO(f+82#JSY~w|A>EaPvNsADq={S(`RR;;n za_K;-1hJB&s2`&-X?K{J`Xg-iSCbPmojD=Y%<8J6PsmZ!)V!2-=^4;)_5bIq><;@Z z#x6*%Jut?6^mICu7SI)iezQSgz%XsJ*^s}&OBushDm_4^dORe@&T6Kci#CV!n)I2h zzqlJx$1-$ek^r}fTEDa87+O9p5u*}&Wr}HT;p3NgW>J?V=U^o zxE=UHO_AJE9cXYsm61R!9aGZ%5V+?fOuAa{hzIyK8$DHlfQq}ETYk3#lSWp{AYz0O zUdYm3d*dzR8yBl@VsjV$7b!2{2H^&jMah~mky?tn9yK(8lvebERs<#rCaRZY%T0%P z=c2wrMN~n=z?)qxkTlLigwv5}K?ZCzihQG*F$qArrYHH~?*G~<-?oyNOY9}R4SvNS7Hlxz zmhz(Wcc8r&F~N0jL%~c^bU|p z_NBrjq5Vv zrCRM>SguXZc&+Ak8?~ChiaP#D1b8v!Zcko82){mTZ?genw}m%Uy}=T1GoZI*G(nv| z-ch-q@p@ciZtAT?WaJAHoMEiTpv&MJ48FrS_ZXv2Rx~m`dR>% zCG*8HO|TR@uFqvmmll)1$%GRO77$Ra8+;By=_n1Ya4Wo!w-(NnkCet2#^i5$VP;{v zG*h+~KDBVB2+eQ#%jLPEweWhWT%IY%E6sqF|>;rn`p+LW+3k%+hR-> zZx*p=8ZX+0!dINn{WA0B7|b)^Y}19%F)l+iaY7)TcCHz6IQ$ukG$q-Rf7`qEQWH}nJDX$?ui->Lw)_**iX~^Wbc~tibdzjp zW_n!RBZ*`A5(&GfAn_^4CHKVwd&(_`AV84wo;#;S5ClFXx$hFJ;9>euhRdhdH*mHxZ6x;&nEKi>a+rT$%&KJ(cBV|){hO`y~cwlx*h)u5(!ursHE zc{Ql3;hfr;Q?;O>f(130SHp(dX_(`t8Z4+`Q|&CO+Vx;j1xxA?dRbE8oO+1cmeloN zSp_R{c149tau#Rm>Snt&?cUJw-6#mXBp!y|fj%!}I_rLO+QHhn?lG@M3EFs~WcLTDE+} zPr?tR?N@bks`>FK!h?pm+8hPG*8X9i_vD{NyOXrgM+g-`E!ftnDCi`dS3gy1Pm~6A z{RAf)DqNu6pe<m-)R+`QIB9Bv^=YVC#CV43Z7EIIoV&A#^h+) z)D8}xQ^E7n5?5bPfv19(rSW%E@Llz2UhSObE8lE?FKZcP8-@CR2L)Y*RYrHiQAA4_ zqZo5~>4BdXR!qF1KRWc{J)95ohL>-=n|OY*?BNp*$LXQBHyQP5-=woFXK(o0ABJhD zO{0tMBUb}I^}8c9NW9m*L6oHIKS{M*WT)G5%|4W^z3~1x?5APi?H-mS>Bt=ee>WVQ zygzB49M_|FHJ%Lcn9%F>w&PLQ>v@s9yc5o9D~KWfC{Ddm2(e73MxJ<*!rxCWkZBLq z@zwhc>&N_^v)|;ev6JGjrXIjk>*}_+^@Dl!P}OSboO)1GDGusN|3IY;9>U$esHqX| z!PRaRbqa^i5{V;2!6GVn=p7zdGx`hv9t!~MJ*=o0}KU|cS z$$>u(6I?V+?cu4stvPNtviT$&>}AWP-dD<>k)s-3T|)#P7&sfK^D zot`a^!%nWIK7$@V{N?v5z6n(c{okglYtz>*yMbZ`4r8v23>|(piF7_PA`Dmwo1Jz| zpF$V-IfEZf(ah#)?w!SE3*Us%Rn1MGsz_F-?`HbJt~r|-T@(FmhAVJ=G#DndVT ztkBttXv}!wAC`naR}sF{=mfI+>k_B)_;>>G#nZjn^1JZ$E~A-yFVBTBzuynXsSRY_ z#fx4T!G_w4kPQ&+^Vok6XJG_y5GN+KGic@Ef6Iecnj07ycjL(@5Ok1gA0c71zeP2V zi~tGxL8R4TI?*G~2NYl`adOIg_j;dv+Uxa9XylXNG#kg8p;e9!Z#p|Pr&T8^w>p6# zrq{la6CLmMr;nF>)jGCuq3cfcx#`mt7pgR?$j*T*frDcUx#D5I-6Mdj`5W-2>tPDV z&r>sv5htPN>(Cp=DUd+q4+emB@i;`J)CNEQc* z?;`OKnHska@lpYDw(Ao>Cx!PVCGpGtl+Cl-*}*7V-Vf8RXe{~A*<|TRv@F{E=jUgB zas5}*M}M^ac7Flw`Qfh%>hm}N9!2u><2j|?MVf@pqbY@l5u9l-90Jdu+>Zv4uMdHZ z{4`8D-jzeBbkCm*QuDI;xO4BCkLCl6wG(TKQ1EdLxjv6zIhxVIxr?q9GroBh4NzCS zSOMyI<5f5qQWfh2*S`GPfLer8VNh7e*ZbB2=$EE)k7S34W{fmBR%3Y2@tZHL0ZOSs z`p|tK4@UZ4(Pghwh$;n2rSp5yU_c}&0|J3e$3y7dPfNm0?*G<>b6tLG;Rb{FUUwKJ z0Ac%3-fkQZD(?s7^?JAT1R!_ISRui?GALV*@ba7Q`h!WBU_`<#>8XYFv}d(c>ACOh zyhk0HIFol~##5n0K$(6hknuMPDbJGj$Qu_6DZCK141u~o(4mjgw767qrELK~?FB(Z z*-FtyXUm08VBeV_Tg8hMO3aoES&twp)17q&l;;+|oSjx70H0|aRigKd2r>PykH-X<)$J_PO{z1pPI?)_}{lGZU_63lmYr zXHKt;>FG*dG}rrxsq*13zgO{1E@ESlG(zv4MWs=CBL<`NxsB>a>K@cn52diG+35T` zFvwf zApSb@KL=A9^Xr$?gNADH2Za^PFIDuz8VE9Aqz6s%Xm{!=d4olBfO@*7{t_Ny3gUT% zavS$`8!AP)EEP6)&{LBwd1K5Qaf{*K9^tt)bq&`{BewX`wy<^gVqSMIaR*Y|SyFnr z#`eOz;V}6Zdy9S8G9Si$sp8q^f8ps9eCFH!AOR$(KD0yu(NuCqSC9smKDWb+8J2NlAg)WQUpO5eM%zY7ZqNE$B*0a{>FL5e~ndl$kE&cjt;D*i7WH8vQ^m)NmZ*hJ{Nm?4yc4v602zJ{a z#JiG#FF?RH`=v0U-#ef*&bx>IM%QVv#NXZK0zu!q#9Vs-o!T{owzUQaztB z51Mxw;9bw^omaEk{UeG0*UNA41vs|f{0DZ){+HDB1(4991G4M$8CU){WzP(hM z08w{xF^zP>l3v2Of5g{H)LzXNC`K3vb(}SU?#OCor!Hf7Xhx;Cw2pPwf{k2^&(@Dd z;A{@AD(tK#r)Z$l8uqd$yB4ZfiESWO@TH4}$$f^mJj>1Zx%oY8W&b044krwITF+ zWBT2S@@5-3Ylw4abIE>?tqwxtTV0KkqcSPvG|eGs%awb0&by`*lou^ZThG;nRsN__ z&lRPd(}>|wz9~Apb2kA-3{AXcuzLT`Zt{fdmN70*R=rJm;ms1g>x3G(0sLhC^RkMT85aD!G5(ML#Pz}1N9z)J9XurbWy z!r=&@sL&*Ahej2dox+F^IBxowd4~@bek1h1qg5>3P~=!!w>M0k*R&pxExC@fb+>Dc zqF;vR9t}pPgc@;8|J9A2jp;L$(Nk^rv--3;R)bF3=quc~**6Gwjwd9&G_kjWZ}Jj0t<{Cv+0z@1jYWKyHWpheYqgChTg}$l1q9@^ zS6j`6n#`~E(b#?RiZ}UNoC+UX^?>ruEGWS82Nb10H7g)wfx}N^X^2}0vV)UYHDcB^ z$Wn+vRy1HC!y_DGIv7{6JYii9F)54%2ph8c0Jor}SqV@8Cqrm!uf37Sny9U zX=Ko=|HHZe=S2v#y@@hzJZ$5XoA2m>!6JF_<@)##82~K)09^govHkgd76SNp zp#5F~zBEjV{6H$tfPe&?_b%v62Ill|48UL+Fit$7m7SMlNX((XsW_ZmEP}8oC8uJ| zqgWo4DO%WzQ>E^$Oky#eyPiybG1NPX3Evw{hT}s3d=PTXq6o)MI%CY81*8}Bi`gsq z2oPV{T-%k&Cvs>QY%2LVHb7?OMI>ZDW%USf#8dmgYd8k)HCL-yf@+Q{M-0_u$+7pi zgapuz=dp4O_ZZS_*PaLXC$)rdVs=1EnH_Ll2q8)+uD@3)7_SCzPp~f?lAOZQ72GQ; zO>Va5x3i`lS%SAU!Bl%iFX2bO#|_t9WDN=M8cXJs{#~AEc8w{pW4sjM-kcHsBU>b; z>F;7O%Cem3(sS}zfq4!P6Y#l^Lms@dBI_@4a`xbWk&;*$@&KfPHtu za14#>P{@!zYLVED@Iqe_TzF_BXu@?Ss9De@%NTU)CSgn=B)n;94*eVOq(avtq+Uae zxWXA7Cy#IqZl-jE1x^AYNzP)`W;F1%hgE(B`Bt8hjS^GaejP*=!pjg*@U(*nbj2|A zF$v$x@9uf2glMY5#6S3MugCRUopSkHCTGRBw%ZN;J}j*{BQG^DDkn}T=o>hZPdJz* zV4{F*F@I)(!vuh3O+1J*Gz}O#;Xqa^>-AQxRbQy(^LC{%Ya8F{i?gCJf1UF^ivw~W zGR2f^Fw!DQnz1Vbw(h@!$8ZTo4kws-@b=>)vv6g;yyM9zwj;ANn2X8-j+0{2(^S1P z7yoedn!##Po2<({EM>UiH+CTH-@?PMV>lbHdW;)u6X43^c4i+{JKeHcq#Ytw%mo(j5iqn z1OX;4oa*MWw98DFyUCKf`G|95i+TH(=pknI@my`KJ|lzYE&FIGA4Y;Xe6Ql00Ou4L zoRfM#vWF%ha0{nEG6Iki0G7_H+rUGZ&U>HG>UKWGXcqBWG}ct3C6z9!hZ0*5YPgW- zR}kQ`S)k=8H8a;{<4Z1SHLl!((M#UvY-H8R;GBnJi=KKpi6$}HVk>(RE;O>;k%KEZ zhAMVO%$b3iv1&(HBd{KbRAsOW(BTBLxuFc@m>b-6@?Ti{Aw)~K2bKVAhhMOc#PoJy z0TAc2%D~rt94RmfyA-XH-jnW`!FKN>4Ch0NJX)*^{eH3I(JJPd;6b$-LdG&|b z|Tg8WC4nE8*2Df|L2&l4td zZRLuP87rQ5*>R>D7v?`1ESixZX2F{RNz#@~(v(J5VMv#^XeDbP{bTGOu}`+f zBqT~PLH;e;N^;gBa+dw?GbMC#q7X)9>(Au5or;j;g-b|WoH{zPuZX5 z(DE^H?MftYY|_gOiYq{=K10|ejj_O8>R9|%E7|s{zDC^;+3UBtxz5cy+z_uBitW$2 z$Goopf*X3a_PObEGn;#R510N8--N)Sb($R8xwY29+Ud2{+FEO4ZAB_fRGaRB*~%`c zQOhCfId=UbH>CeErI{@Br`(X%Gxe!^v=M^!Vz&5(i3xAgmNNJjFz7h-vR;4km;VO> C@U^A@ literal 0 HcmV?d00001 diff --git a/python/mxnet/module/__pycache__/sequential_module.cpython-34.pyc b/python/mxnet/module/__pycache__/sequential_module.cpython-34.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c83a17edc6968b412ba313fdaecc3692552cbb9d GIT binary patch literal 15162 zcmeHOOKcp;d9I!rK86pGl9nrRMa!*SYsFbpBTF`5TV89!wH}ez%UvwF-k4rura9dt zTbh|3R`-bHL_&Z_b}qRDxdk}~@FmCOl0$Au4nZyf0(4G64#DP-i;;^FAo;$(`ZYsR zyS9@chYZ<0)zwvhz5eG{HUBmpfIVtarL~cHjz_O z>$U3OqvpD|-}kza@3h|!+u?)E#|co*y> zuR}WucD-oN)Ai5LW3~2K3HO>6Pkh|eDpT57Ri`k71;VG{*C3>YKC z9mB(bF+vQr!l7=CaSXqw)D8+DXH5MR1u&j*g>hn#n;2x3Lv^cse2%wMk~5{;=`?3r za%Pk}o94_&&Kc#-r8%>bGq2pUX%60=Q|@^_Zq6wWuU;X2cR~JNRWAN7DEFdzF{U=p zO6HrxW=b>-fW^lO~D_1dj?estZP*D&#` z_GtD?xxwigj<7rZcI5Zk-f?c2nzHd(3nI_n?)z=m@9qkBb$b2CiTt2z?`o&FXCL^{ z9(WAA_t@{cWGuha^E8S)@e0g{_iZl>h3HP(JM_2FsDld3wY{AkzvY9Q$*o=7JUcw< zYzJ+>Wtyqm>w#${+za|`*WUK*pxZvO`=RIB4mevVqq2*<$Zz}6QC+e&rz7XF*JwN2 zUOU{JcKT7!IPf)w(;}zy!|w{NGOm!!9pT?nPr-yG^+-7RsimGP%i2}UZIo4#03Vwi zh^|SDt4L%N91m`bCP@gab{7}^Ss6Uem8`7AH+o0WUeLYKIqZ7T4WmYGgqgZ&n4Eg= z=oeFOxBW0`1)biVFR{}V+^SaD8eAV`$9kjTcm1f*sK4C~T24E>gG5@lHWiPBeQ=F* zF`i7zB4yeNZAWl23ASK$Q}K}(`wS-pyqziX-e3+fZTuSoq8 z`No18CGUXEOae;NBqo@@UMo942luYJkcCX|Q(z#1)ckcL&6Ma59qnK$J#A!>S=Y0V zNDQ;gwA;7sT~MDgFw4^Yka@jZcFEq zmo`fGI}m@!v4-e7PmEA=$!)(&GbAc!s0+vy8fX}$6P}cgwfb6v7_`IVZg9}G6EdxK z-}S$lJtwB5m;*hOh1nGf%<=e3}uf}x(+Hfh*K#jsk}b^}tN zU?keIZd9q$Jxb@y?~W2lSIcfu2Np;!cY_Gi-GZugw2yI#>f`E0L2ragCkP|F_i8Jm zUQy`Cb~xRn$0{L~KkaJ|a^Sfm6n1pbk&=P-cm1x@wkc!wwiC7X>~z+TyB`G}`^by> zx|=U~!e3cC6vym=%dkQFvD5Fi#IR^n-EOj66Jz86f&=J!radR*6bN4@YM&(d2|ew5 zA#G3-T+zflw%cK$U-u534vgKgj$O|>v)Z*iQ;1YR1k1>kj1Gci-J}n2vnW=#Z)X+R z{GB^@(1m^KEc?55=dg}MeN=fY*G{WbM{~=la`SUb+{#M~Ui-LP;PTh6hOcAc?dYf% zymP-csn0?g#p6z|=XG6ufe&X^vF4f+eNXE^>vK|K8;tL{@wpz=Lnyi$o~1uP-a^oe z{0@@3m|d&ri&E2wMm%n0J}yJ1jbf#F5$zh6F|sRo3aMt*<@xfg^+svFLj77T4PF_Y zB2GjrHH!SEVf-Gh@EUF~&eV=#PoXy-SCoDu$-(0y)ap9y=6bC(xVdW6#Be3MV0v&E zREF@fhB5Os)W-G^I32ua22s~+4d%1l_PS>3ro_HT^_nFk9vPhe3aWg-&e6Q8@`827 z8eBRyq@2vJHnMVIWR<~k5X9wnUwE*>b&(S?DvVKj@vYcGv(~hGj<3VImnNpH~qzVV|T8sS)@wJXO(4 zobgM%$+S-kN0??$;Vg3d7A|qC)YP0cSUNW61u8>RPDU1yp;vR_l*3#c`srm(EL$t>l{0G-HaciNvQV;%@P>suQu^Yy?!$0UBCh6RfpsMm5czBKt z518(Oqv2>8(@DQ8Z&swIxpsF|Yv-t0vpaAvAXL<*e(qb5AfyQKh^KcUvBOlwC53Po zsESCY;tn-L0{qk_jJ#0*M1jT?(n{>}t`{X^iO2grs6($dZakBze8}-sK4kr6(1A;$ zcX8_c8z1PG11qsFZLud7co5bfE7iiGK&0f;b(syf_N5h{U!!L)!Twtgr)IhD3D zk#Z>*70rS2B%6CQ$$V4o>6Cg7#FFo-?8FK0_Wuz5m}2<`U<-`0Htp;~nLPvwU_8sh z+7}r4!;HH#2FSwG+)pQ#V5n$BnK{i|-RFys0_2yB9C_-lx8nf%vYW>2H0$=gLpTpK z4LL^|hPL+rFYWoefJp70_7N0CM?R{t5>J^)nDUk<0mw;R&)4%luqCfqSc0{|3gbYV5f0h{9S?&_ z*oQMgiVJe|D+9(5fes|m%$v!IbrUp$rfr$zh353#!{F2V0K5{Hi|t#`Lt)cBB6 z2*cwE6*d{eKqLs^AhRrBBhv8kA&)x%f-X?bA#4@3NFY+69zMwX65bmfff}aN+xruc zwhDx7Z19bTU7$M*+p~Zg+I!Dy_YkytEugf!w%hLkS_NOZCNioImhR?R*<_fPx9!23 zbe?Snm*j`Ga|Btq>e|a!!#{XvEm?z`$4oodGcgf4_*kH` zQLAdPElVI(^w0|x^;>usS3PW>5Zm@N%|u)gd~w=@g|}Ec;}`vPK7NBY+8h0C+~O() zN#ZOdVS`O_!Uv{_866p`zKtq>ge$y^n>s&lRjdVAq6J`*)8&~eqvs{qr6tQQoe&rl zI=>N*9YAY&t#q%3VK^_bN`HwM09*tI6``zHC47l@W8yNf8sW!v@SlE1EKXTHo>bvG z^mSmUx~ML2Br`t|G*es}dPTzWbbG+;V+_^9Ln(mGEHQG1&0!!6i#K4FECpKybM_2| z>Z}DYoqZ7scej-FDQpr9)wJZoEKRA!XEQ3Qsz)$2GXzxizcOKR7Wsf0v$UHa*p+_b zF#>^#>P=dZpa6GlF-iB(t7m6e|B)eqim-+yrG4w^w0d+#Jww=vLCvxMOu!}rf+(r* z`<9uF%=AWcEXD%BeMxsxW(riW{e;e9kmv*&n2|Z1RZp=?g|-3l&QUYC^EAEB)r7Qs zmi71l1a4tfg@2L4!Q&*!wW*zT&O`gWOBF|TW;Ra*s=|I?80kJxOYhKYVfVv?y|5(= zaS}}NC{>Jz!7`IT6cU&`6kVr?WTrblsi*c!+wMZh3RRvWG|TAaiXA}Lc1k(-Ap9-?okrL5+y-{s+?QEDX)`r|LyfW~_t%6a$RR@3;20!L z2OoCW0M=ds$)^d=S3p8$@{Qop9uUL2{cfMTcF63&M?&DujRC$fKsYu7+a95PnjMl4 zLYUs5n{gQ%K$MXv{(RiVzG_dV=JVz`8#5!2-0r22!5$k=V1j)5xz*qp%Qe&+fey^L z^3Bb3Ey1Fgy_`tq!O%UCX*6BV2ZGQPNMo{*h1Gjsw&jOf9O4>#~GgmfTI4B|Cr! z_mj=s+!+AqF@=ipWa=ixRs-9m-2coAnO#h7URKa_NYJRN2DQ+Dya6^P~Dl-oq0&Y#z@Jv3ER~GF$DO3EC$}asu0dgx!pD^#+?F2C503>F@By zO$f93J1<+<8>{c#Yy4pK-FrV+k1GsR;xYp5*t}f*c;kb{Cl5b*xb{HM_qaUpdnksn zX<|LKdE=2~W72m?{vL1ld3(UyZ{oIjrgK!>LEgNG4Z7{XJ!*JQu-!pL7McxbxIU;u zS1zt_KXG#oHH)@p^E@&}Cri3Fx=P$E_#0f|9o*E?yams$G!DP+qIIQIEmx5qhgV0R zb`~i-1A|y7Ex_NqjO$gTW~}q&S>z8EPF&%5O!nC0a=7cpvu)4JeuIu{i$MxML@3Bt z;sgE;5{VCZ4BBFTS?RawG0pG`&)AxVpCVf@;u{`*U9!g!>nsbjHKrcn7{n=H#d9#} z{(S^2;@sghg}5-2&O?C*5bq*dKuYcg}!3I;!GI;ZeKx*>K74NUvBX7(;4kOU#1^AbE_dHLGVfZL47TrNO&Z!W_9BEESW5u1alkMQeA?un1E4sa~Pe; z*%JMbDcRgPBhgnzAweYd1`B-58#mJm(V=XWX8Wc36TCe~JudsD=Pp7+P)9KK0|xo! zV=_P7TC0re{7}Rx?jqvKsqj?e`nj?jy73ZTco8j9TRZvC0qDTGjncp&-`r9t`; zT0OTc#!wL83&l_k5s?nnhYSr--T;Rdk3L0+d=Lu)ctnYBCI?j+j(YiOwLy!crrjb%9uC{6&kZ@|C{`$hFPyo4*ubOpuzEOF^u`u z!WMG~;`^UEzeS`Q%uctJGq4giO30u~@)tXdg5o>`@a& zN8@3^(O5aglISvL{3Kr^>lpbo>)2I1BpthyI7iQbdkL2Tn;7wou;oUs##Ws4@molr z+R^v&j;__tjEHYsX$8F_8^uBsKO_jXMl6Z;Y4TP$0%}x9pqsORlQi5D=k^)ON*sSY0jD}ziM59$sb%iZGKx^ zTz)u@ly{N!pCWOZtP}kQ;#lE+0JZ@VsTCXpQpQ>WSs;!7yIbEXxTw$)IJ}5Lc!$G~ z5P!%$MO_xEuIY-mOiN3ek1{HC@i;Dn;u=6NlhmV=ab?J zU3;18cub4AdIKXgs&;ZaAs^3it;7Q}bdwV*y2-Mc`!>qT;C{8n1u#hB#nNr)+8Jou z6VCmmO=TR++2l#>k z31c$0PRv|(d_k|^1vCV?)&QI%CXDZY#Azi{s>&4@ufc_5Tx>S%GA8C!mT}2+-az6s zCT2dY4iCj(?@QDKI1=3E{4Hoa$Tq<-NFHW$uyHdy=wHhc$cAP^j!rjt+}=BUsbeHs zKShEb7Q7jAQV=@jIuUocFJSWyQ7)qB253lBPqv11=u*g`6`)E>5C9G`JId{1s5mNE zw^uU+DmiSC_+>l?h~ESp40~LOdnSW39#=+Rdm%Q38yKXi^^H%-;#?BGxp=Zs5UUdZp9N=8rMC zH-~*?Q@l^U4hn}9b^pUGQWD>n0Qnf#%CrP;T(gzhR&VnRxh z4Oq|i3^rzty<%^nhqId8R4+K+#U4$`AqkMFlBp0pbZ z9_9%<2nEx3vV_v=CK7tO!ear*C{N~@)5~H!cw+e<6HzT{tk;or(tsQX*_8P-;f;@6e<@~kM9#ZdXk{>{Xs6MSFFJ?w6Q+o_hj+k80u2K z#hmUVkqGNCi0AMldM8Fwd51&7A1iz(5BuOL0?dXB69&rjFPXg&=urV-nlh0og0Z`L^3#ZnH&#&^BIHXGwlLfFi0n_YTe z9nEYIDR*>i5Yc42BRylwKF?@PuP-Lycuxojw;YVHECJ z8%D{_=pJ5{9P-3@#i$IYfl|D{+F#7p`gNT4Fz{U`09pjojU^?Tm-+!AjshGz!#9_> zv)Ce$N(k!cR>oyKpEz8EgX)hn-x6^We9PlfF+%lcrNJ-~s&9e};>b3?F7Zc4+>UQF z@Y|tevj*vzMq|J4w39azjfNYv@WUaloPHm-c!Iy-BD77YM!(LiNkrkm=A$DM@V&_w z*LWkmE&#Qh9?+a|+2q$mga!1`3HB<{pgWM_O~2=Ds{PJyV@qm|K{c7YabP zdZb2Wp*MMxQw{8CJiQGBE}0ok)nb@9_xmNq&cu^%8ykHG7Tz%BcxlF(Db3^WzW|6O B7)}5H literal 0 HcmV?d00001 diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py index f998fbc27d6c..586f1de31858 100644 --- a/python/mxnet/module/base_module.py +++ b/python/mxnet/module/base_module.py @@ -849,9 +849,17 @@ def get_input_grads(self, merge_multi_context=True): """ raise NotImplementedError() - def update(self): + def update(self, storage_type_dict=None): """Updates parameters according to the installed optimizer and the gradients computed - in the previous forward-backward batch. + in the previous forward-backward batch. The storage type of parameters is casted according + to `storage_type_dict`, if provided. + + Parameters + ---------- + storage_type_dict: dict of str to str + Defaults to ``None``. Desired storage types of parameters for parameter update. If the + parameter gradient is not of desired storage type, its storage type will be casted + before the update. Examples -------- diff --git a/python/mxnet/module/base_module.pyc b/python/mxnet/module/base_module.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9f548f4135c039121e301ce9eb1be23a370dc13 GIT binary patch literal 40120 zcmd^odu&{3me(!Yc{uI7ADyIMx6_kxxyPya7265zo=GR^*{~CLrjqXLBonI2Ro8YU z^>FImO6;8&Msy=UAa-YWb`eh@SgnMFb^(C|D}+EG0pcGJLIT7eAR(;~e}gU;NK;s0G1& z9CJafFKFa~`#90_zTkeJJ?{_h_Xjon8weT$L32-Ve@~F(^&Yw28{FTk*L#D;zM#23 zxW8Yo_XUl?pm`v;e?YJI2aPv^=1_2dNUsNj#=)R@D7b$J*9U^H(3!(Q?F~6O64Zv| z6YTn({1J{Qk3;#BN*vv4*Sd|U9M+=6dMk>< zN-Mm2`)0V2xY8+@kcrsB~d4q%M0$0T(%b;g;zoxqiUk%i&4_;v@(UIU5FZybW;1r?#;^}lLHuJ9fjhg z(yB(u04`a^10Qv9h66domvDA*x!sH|b{D#>qSSMB z07{hfynyGNlTsxvvG$`lk2&?h8;5fPeO0zPh#$+y6#W~VU_KrRx~&fctN$o?o(rx& z{CmOrp5QtD4Fr$(2AyvNPmTvI%o|D%;PH2Y^}WIKy}=_OE2cIV40R^#b>D{T{tedy zsp~(~7d)8=*7uRvp7#YwA6rLtXt{>k`-A8Ec)2(boLb)>Jns)W|1bdPp6}<~Kbm8& zl0CsA^hNIEaDA_~4|v4^3Jk9;N7cuQN|i<%SY{c zOVIU;N~0Uy=ycj0Dg3O`X#vS&Da4CBQAwbpFchfB_@sQYwi5A5ZYRr~_A_}CBhV`3 z85_nOX|Gyo)#^0>FqZx_<0YX5IlKlz977YOD)FEs;MOTP#}U7cL(n&rJD3~F?Zxqk z{Czvu-+vC*hm~LmK+=#94EUK)0g2>D0Ez)UfCSDEllKYHfJF8LkAbQySz3dw+|FL!dFIxI$a@+94On%(dioyKM$CB_&Vi~O=;Co6PT-)3N4Ce#DA1EP?4`C?Jj>FjaR3N{U{#Q*Qcb#* zM)4io*zNh3?w4eMGpIf$SpO{yR7!-|NGS4|B#jhdAnzc8=$* zL7BpV;7pkE{jN?~y87 z;DEw4(M5lkSl1scQgjf!`4+~oQt8A|D05Lsk`6H?X@|rHuHKX$02jy-wNPk9d8W>{ z08EINuM0glDhp9#K2N$R%7O4Iet?6NyW8pN8aEHNBrkkgX~a=pO1;NY#005CF{2pr z=8(H-@Z){l*cBuof|djXB>@3xL#9_zU@CVUNHCQ<)ptB6yvWO0RlF&Gh}56p7ymcl zY6NKxiyw}CL4*m(1z$nH1TW|lX%6exdGjyA>V^UX!B-G1_wy$2>6V&1E!15fZFIpW-e3$0lrx6fKVj_YVfqAtPO)JY-K^9um z8T*?8t24I4r={@Opms*eycN{W%E{YO;4LZqj-0 zg;L2m-c;}s!T=E-oeRE#SpPv#Q`n9O@IMgXe@zDdHNH!roezq4@*i^2e*@Eh7h))- zt+3jt#4%Reuwo=`BZ^)%myjD*I#CR1m7-*|-GtOp5uujnAc+35oyDvCq}ky2d~d^K8)s@)_a@nS))3H2^is*mMj1%0HL1B&jh0O_L;_|Nqi zk}|~#>3h`BYVSp1rQQs|5t^09kqF$-DdI4y#B03UrQ93IasWAMjqg(^uM!&1mZMhq zEDBo@)(I^u59`g9sDoiYjkM>`QSkIxyYpBjqh_~})K?mj)mp-QR9Gs6ani0XL)@xY zv4X8Yf&vHOOCuYBj6z`?Jw+Q0yhOR-ic0tw!kevnQn$CSgtww*yMs=vV1)s+t0JeZ zMagIwx23lR`lbv7xf*m4s8s2hKzRk%h47Pl3+v~V@QyB>EGf%zJQ}vU3FaFo1bqtT z6gVk_K5n(2wMHSeNvAl9g$NLI-5M2eh4MOq?0eR;SGGsD{p_k}Dt-JUVM3hGi;ZDfY*|8#XB-1KVRvC2;~0t0<*43k2J_&sLIp6MX@c zhoDHijjvmVt<)p@{!M0LSZmqo^#=zph%eh~p|U)gW0u0YIcqTp_$ zV;%In4)LlRCmdyE0ho`1U=h=Yu?_5e#(vgYs1r*sbzy+E4qjrWLa{x}>okQ6 zKt;ST;;_FgOPs9TsVdjSgow}66C1JF&w%RxfeLi&lB&oI0Ymf2ub z!!XJV?RKM#*#R8M7-F(C&Y0qp(&<8Xagm&pTn@nDe5Y80t2YdnD^IaZWGGtUM{59@ z#haz1Q^8!eme6>nDKX!WL>5H(VQp)VY!Ct=a3!Iq%DfV70*Qw@5=79)VpXd%2M z%naCu#cr|N&~*)qn3{A*E1d9l448u?0YrsgY_RwO1_8c*QE_`Kq!ZFq=@w&ALl)~Q z=ShTEjkunSlnYqGh<#YLaIX%NQ5qKk#!agQu?vfERY;}=0x;Lt6+#NtjSo*8?_8lsUENspxg!bH*KTc@&j1k>KU$IYe#C?W2Zy0qlj1{O~gjR z02gLCd0K6^63`;$v_>03oijfgvt!a;8Nrh7nR^34XH_CDeQ>%{V^-lrtvHN|z1>xFP87TW!BacZ zj;%C6s1p4`uu9TF1#%VUpU6u>E0^c5UnPo_%a9~jRzNj{@a{6+58V^s?la&Z7N)0| zgm%kyHVm)b{+z|Z@Wlv_7BxVWX|79hROy4&pq5b9Sg@C;^(v~(fBW|5@kkbh(Qh}s z&+3rG0EtgtsV|<%O><&z+|F{Bo-Vj};D+oYYD4&34R14|Vv7B;G*MT@de5hB2zu8r@?DsJV0+Bx>UkK<5Cx zm1WaKt~)B7<+0*aT5G(@Gtv1J1>Hw90@J>C21tgmXp-TADD^;#dcl6s}K{cEG}v z>cG8#VqCNEZl?h${Ig6jIzTx?>0-z^;#`z&5cjY z`kJtAk7sJS3Mt{KD6x6GX)EbOjV`M}@e=jYxpTa8&eyY8op@o7*p60)*3hSkOrLOX zddgRZ!V_C}Y9ie!*7Jf+O>WmISJ`59GSj8G@iE^a;8nJwJMG1!Sy}x|EJpy;q#s+X zk5aP*gxAsFq;N}^23InzzDmEJ{avHw@;f-d`vf;{2_|+uf&V~H;8f6)0eiB?p6sQ0 zn{UBNs%2n@hW}pt+4JI1%G13gPMb!1X({RyX{(oO)vd*#5BzufZbCb}E<0YQOXxTf z^37ywL5Ep=`llU`<5!lC@FzK)NF^~c0<^@ayLJL-Zhk#HY z?e&L$Blvz_us#^9AD|lzUIM!8ejGd>%}OwHC;_(unxj)jCv%9kIP_zvO9OgzaQ^~ z&#u^of}CA3%Vn-{;!Ge)3fmQ^{m@MUbTPOzhNumshu1)PDw*RD9iJ zmStsg;ax}ySi>6OM~KW=xpGk>VG8y3#l$>`84pvhn&6nia?)&kTrN@di6gJWR-n4n zE+p*>_>U6P1=>JT4c~UQbk|=OR-}tHyXPPfUJ*y?O;|Wwc^0^#VgQ8BRlW4I5+rHc=m#HO{y`hwH3CDu~rFuhVy?h;l1;J3x)Q>o|dW`4<-y3PC1n8H@ z%*YMUEvuYo-n}LUEqIUHVhDDSu!U2V(&d(*9-4DAb!10D=QPj^7PgJ%Uf9o!%~`8T z;W!!x@1g;{0%aqm`ddVRxQE+*gfQ#Fda@bT>AsAqpq^1tC^ti!u(Gju@x-lW0R=tB zXdu2ZcH(+hO#9YeoUR9AD~+LTAUh_OlxItY+D#w_5FkYW87Y;uTl9Bu?Io>t83h-L zlq|x#XBob$`ktYTh`gZ-KM71PtyCvOE;>Rq{TaJ^6;I^5P<+e;h zhQNzgS@B1#g^~cIY=rrgTxG>MKA}JDCP0oam0-=YHTl#$8$$>hfpeY-B+blLvZm|#V`IR zI8bK56mOPe5D2zGj#>SYh%yp4cKzX(B9jbAh#6v+HH@fCIRVlJWs9fORXhSIPNCwxjiuqGN3rC}uD=W%5zy5hKpWhWQ>YD)42iY5IVA)1#OmCkv6O${MeCTmy2CeXc-q@siE|6#>K1+SbbwhqBc=Plu%hFmS9g zOiaBo!eoZ&nf-PvvNop+vh%;P8S53@xFS0+6Qr3n*-Z{7Uqc#tEkDA)>!3vNW5%Er z9X~F9is0(Acjp$)AV#A9O?X=nnqB!>&sFZ2Ye=O&hf@ z>iNb(seH#EQA&IhvuzoI!BV7rG>TA zWo&uq((r~OEL}KGKBPVXYNa=b?n1i(0(=C_9Ohz%7{ijmxQdpD$kLkWGRwHyp$N#c zCbR_v+DJTeA=II$%g}gCWN(8RkHnRh%kbdAB9a=Ds}*CD_!*9oS{}8W3SH{hu3-$k zquBU?VM&b&26USa90Y-}xEVHvx+4&qF1-+}*tf%t0W}>jgti36I&B8j=HlUK zn0hvWORRP5l{z4>mcdFN+wg^js9J#r=l}su9b-d@0Ks~~7p(8WAb~E2-aDT5VZg9o z>zX(Oy82zkSXPYoXl6A9rMrN-ZoY9#1=a-7C?!rtf|8tP7&_Ua>}U|5JyaDRELfq* z86H}RKk&|sM~tSTFo;PiRAMOu4_KM%gy5sE`}hCi+gr=VA`2z^6&BwN`R`OMMzl#3 zAvg^pok(M8w(zQ|wp-aTl$Tl99frMlN#D#HwjZpc(HD-^b)cBqGseys@zp)I@}&uW zV7!REW&byi+$*f+ws%4qsz{N}>!{#mBl?t`M6J;pUNTcP+;o7{LjJf|?|_`f$DDKr z=4bUkN*1p3R@FAmcsYD7T^vpfYjkq%4-VVQV4`7NtF#zMBc}iLvWJ3rVL6D(QkrqXJ-i*>BcPb=g`4e^7 z3zfyYY~XX-G~sT*M(a7mLJEhe*C%#G;@pQy4GMgN!sZ07gAY~MG@7!An=^e!p+65m zLq39CRwuFB<`8}#L5tpxy8=VlS_Fm+&2Gq@c&R2JYDBv6tVCUN=?5o%>)ZeN7r**u zmENIq_)&-#`Spu_$gkK&dxlac9R%WKfz(AefjC(pU#X`6;v7$4nTOj8x-x9#w7?JK zo@2wUyx;l0yaQGG1Vudh6<7z%T_QRe!?sBirwMbJlwP;sz+ilrZk8zF~H9_w$46G3bT>j1^kVEs%0%zA!WsyrL4 z!Y*(oh)4ULpAFXE;wa6yze^ZAI?hV%piczP-x44Y0w?9k+B=!r_y9dPE2ZwW-V6qk zH-pCqg3e!t{UHYkqxDuU_yv|b&7PvdM`t*Y6vY3fw2V4T%a|9Piud{wYzjFXJU2meE2nSHUX;uemIog6nB(>>dU_V_4h@!d4H?XVJb4CJKZlp0%a&7J^8P&_?&?% z!fkr-2jF3M8FisbhY=OU_9lBuQcKtGbH9xkzr}S0LE*abSh%%~E0m1WJ@Tujvo~$^ zBIH-DQQm}N?f4AK;hKi^k$YU}5~8z$x*Qj@l)Eg`5#5`hnw2;dY<=trBTQ0}w9+|a zD!;-=e~KYDQ0|r(zcclwy%4l>iw%BQV=q&E!J=G=yU^{V`w+-$^?0RGfu5ChHh2gT zdY@rLu;U?|Rss;YLYQ44+B<}gc2c%LzA{{mhQs`|4&J%ZN)Xbu290ERkuf5}&5Bs# zjZwa1{zqVL>2b+--3?Fd@)^rBCd+z@(vq3Ix51M}E!LWsog&N|m_Ugu=X{yitFSDs zK?N?GXfmAn=+CSdO#Jgc!WPOTN6tT@?7mEw-J5({!gOEQ<4TCc5vg%PYzVj|b!`Jbd-O{slEm+-q)nS5JQ59pgwtT%@&gX|Sf&ma^ ztS~k{G79C-pt1AeW~>Z^7-!IT8*&x<6kel1*g*(qW+0g=Q#zH2vfYC1hDizt5KEgh=4YzrLrA#TQjmI)+{;5ty8e1tywd@A zY2Pv-Hm;-b9b&{W;sj*J*%f$t-Bo5}XcKlG_l(nPAcsly4-`D&Dew`>qRNWeM59EO zUbqBQ8oFE0SCeAZo<(er>CWlCacS|}qY&wh{gN=7`_))CUR2c4n~CiW=}Czif>G^? z>}BE}4Me~*v%m_oKZJSkY>7QaJZ6mjOf`K6IVWy%$O0$q2+aVKXrfdi) zvuW(v;+6)6r-~!1jvy1HqCv0~qEmdwJEAK&Q=7Q#E&0@t3vj!rPu>jp!Ty;vyN32Gret;gRyY*3UeUGVr9^rwXzi2z&?X9r+Sy2 z3UHFLA36emFgIy|M=W@PAixo8Z5B(YS{OFBn5o|)jfZ(nrk3-(#YDh@T#ZVdh!KY` z0Tq~)QK`&@)ByXp*?=fWr8}#UIfMI)Eb_@lACE=aGQ0W`?QRY_!E9&`+@MuFM1(aA z*Rs1Q1EbVrZ9y5niwn_Gy;af}k8MLyeAftW8|b1DA_X{b;jU*`H@ZY)6;1o%<*a0S zmtrEyIG9v|o><+=;NnhJKyo1K!~~$sI~)PtC>L?;EmQyeH^%+&Xy({zaP!V8EhWSq z;lf^0y2|UA*}zqVXzi$c&(yqz=nlZiwJHwcn{-@K6vQa4CP440-Um6!vf;*@si*;d zP2%?e<=z6vwk?~^ybQS%#vA#MQPWmVqYx=|3 z^Inmf^QSZvf~8@2qGD~Z2s*KAZkQsVUqYPYo;I*teFT*@U%6Tf=|r!CfMCU^w#t_ zYcymU4~gQUWIHK-h=Rq7IM^_0$H&Cj=hy)Pgcug%$Do&P7-A*MtPWDZvnTO18$n+p zcpQu+631N`h>Lez0(I2L$iN>784?p}OvBp^>Y@+M*B+{-eBpdNcYYBzx6G3|olV%0 zGk3l)xmf&ztmP&SF|7%{f-f8yXf5-3iBlS8@o(WER#&oSffAfhOk5}0 zcy0HbzhgcSt^Z7pDXq@-AIlxVljFH}29V4Q&kw`k9`;x15U@XD zc}zz{A}Ii2Dn~%P$;-$f#jDPbr8y$NR5Z-VWr+|e9E-Nv0}Jz-(eOCloo^>M{l+i} zaFmCFj=KTOB3bR(K8yfAB%a*y$jEeijf`@S^)^9a7CAOEuHrA_&D|YD+A||Ce(x0q z^mCa3Wl{-1CxlneE=80cI7(7UX!d_X0R5?qR93y=R~X%4V5^_59gI#EgDr;Vgd>=0 zQ^WQ#4=^cG8PnaX@brceEF_TGAnO6m?S4eDn*diNYk;fMniM5tfWAQOXmf^uw7Ba+ z&Y71KqG1>{R}!xRivt;9ArgK<&6R(Y%zgssC{oF&o1}d^rYXkEz0@TA)CrAj&bUA| z8N($v%$>82U&A}!j<6xZt6`i=Zjhxmw$etJ6D>~^%-bv6l`M<|N&u)%W1ds5ie24d z4+yc=#%n%ZnDNVNQiGJzvRym^?4w>>OpCm*j=JWEM^}YCu#YaE1N$tdLX*G2ch}?z zLN{%k16Ii{w{^HWUrn2WAS320<+Z-Sl=XvZhsZ`;XbN8PNyY!?f545PCT}aWf{$!3 zr>9o2Nj3Zy^Ba$id`!neb-YJ86j~i08}kL-u<*j9B!Nv=i#Ul)m6N@y99x~9^2O2> zPHd_$T~hnVytA1mQ^Qyv4S%ri^|GHFNBAS4n>j}+MDZaee6yVWCTJf$ST2;*DJTClHy9-$lqaees|meHhPk^s53Jo%vB zO}+@M9cil<)AqPjRi2V0Gm5j#mVK47`N#l96am*PY0?_OgnzgUI+c{U-RrHp*Ey$7 zzINxfvr__i-ZH5A3z!}=eZdpMWIWC;sq6*(3+58?Pd|Gd3#AO?yzzdKju2(b8Ixm* zqj>$em-#N!$+9n7)`sNg$#C0SFnvW^ z610Abph%LG-7T7q`4TTTSZH>Q{C|kbvCYBk@jv-iu8Ip`03cwM9=?YYWdL7lxD+C$ zK=jC{jQiM(jvzL~XCaD9v#ldp0$2o}A>hqD!{Z$6GzjTKU>~^!Ld(EIO#W-`gt-GZ zrdTQ9`by!_tpMag(3Q4@~Dxp?2M(4F2kW&~|%V+*gg1i+oLL*`G%i)H@9 z%{7kCqz_>aMWZA1$%gOX7t!hgtkpG~h>8{4py;U_SvU-6@wSDK|rF5$8;bT;GSFS2pYh}efY z{hqX?m|&|Xa2h3;ve9l-Aze(`VO09FJV z23`TdfO0@WkUo$26iFf&D5+7dKU|bRN(7S1=2qnHps;{veW@L12(;Do91*v3G0Jp* z@X?dJMBw5(CWx8*!}0(;LLUHHlr3^@V^Gx&}5)nh9sKg2TgT+cD zPKTwez4&MOBKZ)a3hRk}LE?wFOzoxobGR)2c^-fl#*08=lh;Jj$98XgGABN7a@fEfV)*HQ z#LHq_(%28s zrheoI)=S;{lYX5_7^`qgp4+a59_nVb@e)qCH;rvf@WO6lllz(Gu3~7WZ)V@1gocJr zj!x-KI}_e}z@(RUmkwyOx~Pe5(hYLlr=&?+E)a)JMqCHciEp1@7OlP+xa?Ir=fV%(6INKMc{CEXdqUd z5Y%ELx$wmYScQ>7hlFS-{OhJz6=r|LZMyI6q$Xu&h;Qoi>E7)`^xAiJr`<)+l{m$>*>aJp+)MEy+b{b$4?$}_znfU?DoJ(y_!Lbc~)Ead>}zCG2cifE77R@cBIR4pWaVmGY7t40p)@SPH)nJ zJzK_?7`SBwM!B+myBsF?qKu-9k~O#=vt>6H3&0(h8WLgqUc7lV4H?xj6=K?h=GPvY zS7b3XODBTg5!|WSQ>n|>Bp&E7^^)=TteUeyOk=Gv+i0Pi^ez*Y8el>ND=>F?4jRC@2Y+9fZHoB^Uusoh;#VK#>U1kAt9Tk(3 zjD^g##SggJ0C7l&>CY>?Hm$SYFD7Bna$W-gCM z7si0#;iba(#KfhUnWu_O-hx6<@P5kLi~L7X`x+!dO-OqvA2>SB3f$28S}C?a~XMJ+5L)Ph`C zc~8#o;IlJjS;KWiY1*V9r@sE{YRxLhA@Yi)6>Lq^;^$ijkgjrF)Y*&~e^NO5u%<%0UGsFL4nIURhc`5)bJWPd zSl|&(kzti~Nz6*gj&NRuSF#fjM)}_aHU1R0MAW3K8Ulmmt`FoDV;pqnVoQ)(QcuLJ zpq7M?_Cn7YUB%9J*fXWnPB|oM&U~iGRsXZWAqiLczM&h}|3^Hshx}|gWH0+zh)6Q7 z=MHA6weQtbQ_CvVpOdPWlfXZldF9W9So%^@DYMyaBb4qM`<2ZgGzh1_Jmt9YyF2E_ zb#UX)<@36!KqPoEHhyl@Y)Aq_b~}#5jCjU>yc5564TGe@+s2v&?KTdH8z0657m5X%4|cdHCu;T3FHI3 zJKlG>WrmG#gCXe!iZ&l9FkYyE%gm_l>4Ah=KZ~e~3~ynHg04FYJ$Y<=jCwW>e3l2R5)E2pB%e1H8x{wn<`AqT$&!6 z8ke%D9;*7}?D)*o1XXpmH&vLP7@r)U)#@{7fJc|n^wiX4)0K(B)Wj5QWz*(Hc)4(S zYJ7YgKa`k~>c=N0XQ!vcN|5c_4p^dWliNgayCR}8uH>#`if49g-2Qj4e&3+67bYg~ zLPYO^?R4Pb#V2*J3dX>iP3i58YUgE>uwH5O%v$FR~6y(*{H*!0sunCO5mlP&QdS4-1zOjjs5nf6dJ`uDZPq%Zj z{4x?0egbp6dLJ}Hlh_$t%Ab_SO%0!15zL4C)?g|H1YwA2`GtgzT9Y6vtwE_vg}h&R zb&B`zY{e9StHS{JX@%1~mPrH~@Jp>D%wuvLegLeARhv`S%_g`==+O>?C@bo5BW%I$ zk!BlI&#qw%n?9(LsY)R5n`Nu<4AwIlq|<7QX|qLCC?K8O&FMEXL3RH;suI8F?!O4 zB<{W`Yze@N5Llj-hXG@UWkG;HcDPbQHWBj3;TBSy4QEE0w<*^!%jS(#Z#j`Y|GX(b;PFi#C+!0ECnVFOs14@8#s;f^D`|`X{#7%=N>7 z_9kHlZz~BmdFE2ZNPd@HAm+ zy&Jcy{B*TT;tJH;B&*->iTKPXCt(O9wwoUhDjk1U}}*h7pQ6Ig;9^+lNRka7q)1Nq4g+nji49geXY*>j48 zYvb?`>a}?#(ZnT_0N3uOFp+da|H4hte%)s#X=O4wT3zyXXMiB|Ybtj6qe!X}l=eob*FbtVc2`6~EvaF}t|H z1EVC1zl4LdVrN=X(-eQiXCV)=wEYUtxOGRj;Qf$ie}RX;%ENE)@S8j^GPB6QLJ2}D zGHOD@`B4xFz>9{Ni80VQ&d(hYk+z*^7+*hd(`h zZZJ3W^Lr&3Q9rgn;3{|cFCN}=ct3uFhxZR2Kb$*!_VDS!V}oahP9HurxNq={p)*6L zhTa_thu#`I#W#nBk&5VS{vyMhQjP+91Iyva^&WdxflNqb09l+RjXKrb#w6 zFL(Dyq9bj9l@KffWD(?;TP{IP$t4JK&;9{P4tboCobm?*IVa!uRrO;e>Oq?dut8{b zO+TuBRbBOae!pr@{O5tWfBO6He%3bbZvwwRz>)q0MaP&`l%DC-OxH83II-t7vsyEq z3DcQ0-3im1G^!3Sl9vbt&n9o)Zu)`aT{T#=iqgeJ^S4QnyQ5N?%J{)vL-Kgcqyu-hHnN&Viao5iR-mu~mg2Hn)p2L4*)kJ6~)ccNj`?|`m9 z=vzu_>Zy4v+HU#7WN&5-}{92gd=3w3Lgju*W7ymc@ zR^l3l-0)Cj`%tt8{Vckhl{Y@=$LoWn7qdo#b+b7X$XC6*-*)a21 zhSeFa3RaX$f|Z}sotLAB)pV0XkouAbC%NoNeovX*rOFDnmbWV#1C}Wn+@SL;s zn(4f(`>&heM&mWk{~`qH=4K2g6yN%rQFkZ?>_lPLkJi`Yb_{mvZu@uQY}3CTCfi`f z^WCVwflDoax->WE7w&f6m;4+^S)+MKz5}ROnBb4Z{le=JKaM(^Dxa_Y`fR zC2iY3)6;#Gwjc$wXP^E;PoZZE%F`yBG57Efz1K{JPA84$ZlTKwnu5wL^fUQLEfAl| zZ=ol?>V6K0Cv|ggQYQz@J-k2e95k`f$vnrzJJ7H8Uqqw$@N5d!y*aWWy>abvO)@+IzB_%tEyUS@BB~($@l4Ou% zY}|ZYZCC=+}jFP zd!@E{dz2(itH85Eq^BFDg3Au@ri3wVcH$)4mhYlbrcrlYen9>r9OvUN8*t_W97Z$r z$dC@>{^PGE|EQF6EuPiV@+i1?ZW36tqN^M1M@4!GXBRdHz39SdZPd?37uH5`w{yYj z=nK8OSg{NC&GE#!U~VdwhT8#W_biU|S5TOVQ?(=BtT$U*@LsY1)sA}luZFvNVt?E8 zeR$CQk(6^Zka9Almou#ubljPg#N%X&CpY9YLEw$K4%aH_!;p@$xEp6M9?ter-z1Tg z4E1tRb%K*P4Y+*4DHboGkfRlx;TeuSnJ39hQv+(U<7=Rs&XBsB z7)=N{n};Uk*KvZ`*fJ(LsuPcz?*?y8=z20&X`Z=XGxw$WZj{sWHrg*2ieDb2OskZ4 z5`8s_6YJMzrFVi!qT~kOWuf_agJ)#FKHjrQ615+JqMezES#Pd( z+IxIvE~CnBW(EcUb2mkIw=zE$KxL=-VKgxqWpKMwIuPTtBOxtjN6Nw0wQ^R@lpaxC zH65*T>VntNb3mclyO|@UPjEawbGe|s$ERFUJ9Ftg4Jyqg_z5<30#e1QX&c1{u+3tV zbTb@Ax%){IA6Wme@Ku}(T3nYC+G>#H7JZR8he#@n6Hp=%{xzt;EuYaKh}k40Qd+j( z7a)fL)zt6cNZG%cnYZk<+eunN15T1us(r;5GE1B(y2aKlM{v*`MSgoQ>LMVG;FDX| z(~rrfQ66D*2Jq_Wvh^bfxQ3W6786gB@Q=rvAUt7%wZRBMv;6NQgx!WG+lUzz<9QqL zwe$Dk=`wzWt6PaNNY{}GZ-?+E*SgWUXRJ4khqX#CZoz3K-~+h@-~SC4d>s_9;Lr%t z3@>oWcN7fTC#Pp#+)OQ^lRL#s<)T3YJ%K~|gN4()M|9$u^rXpZxgyC8Q!_uYS{Jqn zz6l!gKJB+O&EJ5S=9JD9ahx)kaRD!R>RB46)d-Uf=XwG38Qpc~=WIwv68u}10uoXa zD>x357-Aw!rxy-~l2=t<#8;%gj#CINeUQtb06Qi+5 zDAj5nBQLPUi~=Q`pay6+IOOy&*bae0ya7C-SQ^Bu#PvpZfhZt+Ef%KGZy%3?BwYIf zCpll#DtQLMF~at&c)|u(a~9bq!;l&1PFn5TXgQu}o%0>zlv=K2zr7_(7_-@nF&SXf zOVHDU?%05@!vG){TD0>7Wi1oI7zb+H&l~CmQj3hzUX<*b01N7NVN=GaI+zHrzRsrr zZBiH+Hww$H<*UOYv?B$3PaUB8QC%3d^W_Nq=r({USf#L=07KZW3=Zf?F|2EML&Clm z0IRNTQRCfr-}T?{m-~KXt8xVq1_+jM5M0Sj4AlfJo@< za1i&ihVC|pNrV7x(eF2-;h?=K7^%dzNFMc;2zS_-tEjDacJA|?lAB6G@{!7EOQcxJ z_V%(+f{IH8g6G_$3gud@q#r`mN}1Vd`XzB?lXnztFA2G^uOd9rhnx-&Q7owk2xt3N z7i-$JVE8P0DTbaaK}B8X`D6ygoiL%tkt%5BMX5N;Sqm6E2UM#zP&va(MmoVo7T;s> zYbf?kjeZ~0{s;%`9R3!jz0=+SZ?bmCt9$$Kd(zv7IA!B^kNw35|NiyE-)d8n^Pp=Z z?TYx#>DA;IQJFe^7^KmR!g`*g^@OekMC6=tt4C!hsdw=dP!6;&45;)gXix_O_@L$l zpk!63(lj?3oXPYjFxuA(?5(OjE22~`dLcVID9Q?}L9R=pC0toD-;O$O3~#mm>XLu$ zE+V5blbbHV(S`rG8RCdHHUaVulX#FIrZ0v^dCxWhAGRu^a-cJxmGbfpMuNsb5P-g0 z!?Gvid6wEsba}X;jy-6AnpKu2(3&>S04%eP0ND4&wb?_D%>?q1) z>40kDk#I2}LLyH}3PGyDVm#6cPo#%dRzG}O?2G;XHI9_{h&g)5oAegElTr-uv*fXHXR81l=Wpu*CNi^TC$9V@DqWcO(QW>KhG6JHVIHikCmBIUDVRr0gnl9&VB6x zTC_7Tpk60W&ld0GBj2k|ozF!C51AEsVVhy*-x>E((oH!#dM|ZcQU~6KL{#dh1KScD zx9xWagIj>bZXsKuxkU1#zk%E^**5=+?ndpAavR0%XT>3ThQ)b@h!o2t)T7S|qU;a8 zTWg0!L4q4oZ>NFDrHt1acB}O{V!^i;8H-;05>n@XgcnQh#0=2maZk3=MvZ{7-?lB^ z=IO;R7k%EpSl-v^#Vh9v(?^9%T>4YO^mllz3ce_jF2NuA{EUz+ z-;}{oK!kel?D=nk+CSj10ML<>UK?SCBHThyi3~ndP;ESebd*;@t{aGJZxAK0p#oH1 zVN5>iBg2dIhul${cze?YFOVpk5gSJBcs+&-Rr92(G#t8T?(u zx6msBbSRQJk9{txViRyoc>$go;{;nT=OERq@K90JPRS+zkyI)rVA+YFpTb}u3t7${ z6`d%$VV3+=1*&&oBzlZ$n3lM!O^;~ql{Hc~d&gLp!8QL;FxLK)^w8kd!h?yF(|XhH(vn8SAMdOr&A_*TaVxzB3Hs&g8q{golFwZ2J69b^x!|6+w~#vo)jBIPJ343X?>G1B27v=op!e{!uIL|i zV4LJ}Zeg#5z??1Q6yY@9-V_^xzZ(XTdpV3640frRCQ z$q5hj(5wl^ScBmC0vq%4xVw3o7qFREJ)WF!*;5K0ri5?aps$_&0mSzT4Fx<43SPLw zmZ4r00!sjI*d+Z>%FQME{xgffyew?bPu#6iR9ewx-~FdM9MCUiq5a(iPaCaN2pkHF45SF#S7W*+u_s*y-R04bJvQ-R{}3*JpyD;5?|W3DK;WGjj+VMD*4bc!;%koUD@bV zylVOibkeeevLxzo%?OZCyO3t%+zJ3!=&dnU;EO%N4CG)gYdRVZu~h;YYr+cw4imh} z!e{Xs3Z)#Yj4AYItOEfpIZHH@`j+urB8@j$D7VQTv}M6zGB!Vcu+m2%d)f~1s61Je zhEb8S#sNL6;5dsDERL~QKq07yvtgq+C7hXnw@G;-FAa7oj9i(!fH72XmBlp{KVY%U z;tGor2zs5DZ?O0d3R!-@Z`e1wclQ2$5dJS5DJRdoJnJn?PJ0KC&PNeSknfitrBx+9LO80`I_@c-+Aw+<{kMEjSQ*Yr%oj!(PLP?=xF4_kjFull^oF z*aRooi@T3|Qk;p`?kbr6Qx7IaqXZ60*aE|Zc>Ui_vu*ep)EO-7Mm+$oLDYpMJq zj4On>gNVRlZK61qzk4)j;_?(zM@XCTgpflxi}9sCVz{o!@U34xCR4+L5U>V{ebZJ~5fN3O_&38Ro5x1cht zPlH<}T=o_E`#6b!Aj1`a1i$stc%ZY-r;V!OUrT&;4srdvRKw5C{ZZT27%GXaw6Kcjfx zE*{X39yU_iKykPJS~;CNE@YhAAmrIf%|^}CX~D#2q5?9JWViP z5#m&RB?VAB^;|GFq!r8j0AE_j=xc$gPp=@qPw^aEVXN)%y&6Nz^eGT62e#%YRFM8JaC`5j z;Y(2aLrw!d6@yHDk&-J!U>Ogb2%Of5XA*l-GjNNcl@Wq5wu0M}yGRV9c9@jk2dPhn z;92N0h@tyXi%`iRv6$ij4CriJ#RSr6qH!$16)Q4Pr_fl=dP2N!$F%c~hw$xTz69It zZd67i8NG12PazV0+4eBP7YXfbMOy%aUN)#|&Il5JC=)Ls=irm43E zhC4c$L3J!sk={#qvUk$k0<}LW6f0Axx$G8-l|Jn=6)P*`v8y#QZqviR=0aRmj4ngM zWhu}6;)?Pt9=AT#3zO*(ef?!A)35F2wg{r+*BC#`Inw^^Jb?sf zDY8%SEs2kh+mpD>{}DrOy&dk_4FQ5Uzvo+7Zg$pm-CXiHs1Ma=VLue%rLg9?m2vxd ziVC#9H+M!@A4683;sG;!v`&^AP8kwtpKN8If*x~$pjSGE|$zQ<% zbOx;sf8xht+~ju?&a2B5yVc6S=ddU(_WltaQd)ASk6ZLP+>7|s>vat7?}}HPz?VNX zQ;cAE=oM)fA9=;Q-_ad?68C@r=~1B`xK<=AIP_KXX^CR6n*svu4FDg(R{;wH62%_$ zd|N=36(r%p#y>(^eS+*F1Ehr2zQPO~EA9q2@w<1+AA#CGa&F{+*XDrHH_+aGL9I_B@#Qz6k)}&_Q z4j42B3{Qd0{I0G%^VDXJ+@!&*H!R^9Sem z&mNqA6CdxKt}o0VoMsa7a{bJ-H{sP^o;_SYSf8rT)EDXtv-SCDn=Xw(DFBLRg5N-q z;c)rlTReM<1?NP1W(~VJ-P!~c4}D#)+=C_z>?g2lx>on= np.prod(arg_shape): - # nice, we can directly re-use this data blob - assert arg_arr.dtype == arg_type - arg_arr = arg_arr.reshape(arg_shape) - else: - logger.warning(('bucketing: data "%s" has a shape %s' % (name, arg_shape)) + - (', which is larger than already allocated ') + - ('shape %s' % (arg_arr.shape,)) + - ('. Need to re-allocate. Consider putting ') + - ('default_bucket_key to') + - (' be the bucket taking the largest input for better ') + - ('memory sharing.')) - arg_arr = nd.zeros(arg_shape, context, dtype=arg_type) - - # replace existing shared array because the new one is bigger - shared_data_arrays[name] = arg_arr - else: - arg_arr = nd.zeros(arg_shape, context, dtype=arg_type) - shared_data_arrays[name] = arg_arr - - return arg_arr - - # create or borrow arguments and gradients - for j in range(len(self.arg_names)): - name = self.arg_names[j] - if name in self.param_names: # model parameters - if shared_exec is None: - arg_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j]) - if self.grad_req[name] != 'null': - grad_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j]) - grad_arrays[name] = grad_arr - else: - arg_arr = shared_exec.arg_dict[name] - assert arg_arr.shape == arg_shapes[j] - assert arg_arr.dtype == arg_types[j] - if self.grad_req[name] != 'null': - grad_arrays[name] = shared_exec.grad_dict[name] - else: # data, label, or states - arg_arr = _get_or_reshape(name, shared_data_arrays, arg_shapes[j], arg_types[j], - context, self.logger) - - # data might also need grad if inputs_need_grad is True - if self.grad_req[name] != 'null': - grad_arrays[name] = _get_or_reshape('grad of ' + name, shared_data_arrays, - arg_shapes[j], arg_types[j], context, - self.logger) - - arg_arrays.append(arg_arr) - - # create or borrow aux variables - if shared_exec is None: - aux_arrays = [nd.zeros(s, context, dtype=t) for s, t in zip(aux_shapes, aux_types)] - else: - for j, arr in enumerate(shared_exec.aux_arrays): - assert aux_shapes[j] == arr.shape - assert aux_types[j] == arr.dtype - aux_arrays = shared_exec.aux_arrays[:] - - executor = self.symbol.bind(ctx=context, args=arg_arrays, - args_grad=grad_arrays, aux_states=aux_arrays, - grad_req=self.grad_req, shared_exec=shared_exec) - # Get the total bytes allocated for this executor + executor = self.symbol.simple_bind(ctx=context, grad_req=self.grad_req, + type_dict=input_types, param_names=self.param_names, + shared_exec=shared_exec, + shared_data_arrays=shared_data_arrays, **input_shapes) self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1]) return executor diff --git a/python/mxnet/module/executor_group.pyc b/python/mxnet/module/executor_group.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17e1ac998aab55aa9ecf89665c56b8a59333fe18 GIT binary patch literal 22701 zcmd^nZEPLadFGjWDT<^hilRh`vSn+eRQ6ROQMKzft>ZYBO*?fPGww)@64~tK@ZKSL zsmpsW@644*S8R&PNW4vVH^pw!0%`xGUntOa7uZD#wCIlli(M?RKZ5L_j#W8%*-XlBzDYn7bqnjo|!r4Jzwwl`_5GV_V~mz-@g2XhVwsF{Cx?(=^vnI zI(G+U;F=ZJ4%{77@W1NX6}MS+cd9P%&qmyxBd$5(nnzrF)NLMhcgFAx*Q2@~cXuZ2 z`lxFkbDNXy&ZJ$Bx%P3lIpyw5;dI8ff@^zn6g9xklI^U#(P_s0 ze!OEh#?WFR&fKgr7s6>SKTiDkGS0imImfZ?(c6!A7-96@Cz0fJ`4jHJ`l-|iq=tWB~Z@4nOOWfS<@N?V;Yi#v@3 zo8uhXa#o8AH;9Kgb|PtSC+iA2Cn)IIpw~`ROV{I`uVryRQ?oSd_9CG_N6;MWoLG%W zsuoS4wFW|;wIoY5mCbmg+gC$DZr0O&BO6DTX5BZ7xH3(jkGPB`Cd4K-NUOQ+yW)N{ z&Bk0!vs-WUySCnD9gpfj?iuA`(yT0DtOs)@B2AynhePz&UH?L#e>)p=v&FxN- zUEl0B2kqo~PUdwH{yds+QS~-Xu)dabl78G4JNYP~S1tuJ!E`VioDWV{P6d;d@!(Qr zG?)yYsLTY92dAo!qdr}ku8M@Y9-iSV_)X8EF#6p9ndT%+u6eV@gG!%3A-v&U8m?#) zfJi*Lbbe9oq6=sjokme3d~MH4oHf>?Sw1+!2lzdVjo@M1MsP~c;40-zUFG{&-IM62 z)wSc5r2RdXr$wou))FE4Sros2{^+zCJFKq{1E0>HRBI)|}g|xb1KHW5gMP ze=aQjs<@?iFxaAJgJ1Wo7JL`1iOlL2nf3ZO|~-bg@`A=m{ZOd4kEaR$zG zFG<60C)^ygvleKdVv(ZDPNv&Qu7+vX+Ggv?PT1~t?@omHX}9hs>`rNk{%DzZo!4`* zu9lbu@w3_56kCi))S~b*FSBfkBrU+ z**V;$#D6me8)+x)Tp#%~)sY7K@P~7K8NccOLXkl40n(tXxaO=&fETmG7?u_H?h&b1 zRp~&Z)U)12%cE)uS$o92TXA>BsciB7glisE<(O-ZsWK@z(5|4u9nZC^DVIRgLK_SL z?F#DLwC)^NGpJm+kN;4!0DMM|PUz7o-Jez+s@F*kftE9>oN>)ls+@Jr)2ht5=w^)^ zD!mSLxN(54wGTUo)UjrA52{gWNOo%-s+_dB^`zZP`eCCTr)fAyAw?l$Tbsb>O=@`J zr9Hpz-e?hHv!U`v>LnyO=?!yjJUThZ4{_7qjmBHfTfhNWxNE=?pek zur@mi(@u7@3k6VYF?*zGYps(si^<@2v$eVk-55LX=QAIEh8jUXG^k)N%@)W!3D=YO zUJLhgpf+!wS6wQ|BTK!x{^(R|9UKHW6?K*?p_G{Ps`ikl=HXy-2zEYP>|$I#=u&|P z{uDC@!gTvknlc@HT8tzp7y@e@>L?D%N6Vz`*MSJjVXxo4*J>JgjV9O=e9p`KqW4wSs&s7C*e<4{xt09nnj(t0njcR}0`=)Mm2T?eC~3NVrMEG8IQrPtr-d-6hI5ZL%`-(t0OJ zpkSiI`(M2WATw z%#%wdz;4BYbx&;R+XDa61>VzwU z#@!^t;6wXg@~c)hbYyYCCW|r5d~bBI>a^Fz*%S3>b(NgS2Eg*MBw++tP?N~&x1;R5 zuV*ocXYnfo_Q=+C9x2&Gn#kUJ1ic5Hei!@F45P(GA_w}E^MOLNmaV5(3_U0+Ftd#3 z*N9VaH#jGDax9~yBJZ`mw;t{88X4^K;b&ASX6d(^a&nn%uocUmURv8YxFJ$DslpQm z=@tNZnkMJC5g2uq{leye)E(voHy^%cD48ZhbbVFWU!ZLPGhaX>ZZ`&PQiWJ0(i}reQY6@vVnW(}OAP{Izy> z1xworYnap(!s8?oDkYGW11T;wiutR+l1Ao1$tcoP8dFfAPT`l<^D9)RJ_TJLt{OOp z>hcf4#C(VfE@5y4@(0-^{1`AjaUWI%9-)F>uscwJvj_iC8Kdun75H`FZb1clr1g@^ zhIim4ac%?Jq;q%2T>qD4w&;Eb107edDYco(+x$`9X3DY6QJ4N1HCi}aa7n9&KBcGK z?v%?;xCbbw<)8YWtarwFiOR1y89@Jz9xpn4F#NddKNqn7kpuf*2oC5!62Q0`8ukQ- zMZ1lY0{@9%_k?CK>9QHk;;#ePHC#B{UtbBhcA9>Dw~JZfKdv#8QP=;2z+X=4@;3sk zu!H@z{W`_QFKqSI=!9DQ7XiZ>{zHI9Xo?O5S2W#!eYM7Ux<9xYuTJB#}%H{6*E(b0583Z9)HXief3p!Z%<1~CQc`ORvJ&yz@ ztMm+>XZQMf*Euho;weqK-IEUMq2)YRwa4hg$`kC&%9A;AUdZo+dF5hWxs+Evl2?8x zuUyV6Pvw=T^GYqRT*)gR%_|?vD_8T%H7y=S|9nor*Zsscu45&a-R=wtVdIB=y9YDw z+{QDs`>-4M&n5n)miTOr201W))<2@2&ym_dF!YVmb?@LVpcT4^ZGs^^f`S|_x{gZv zEQ*5DZVuqm1XHBlZ{ot(?p4#i#U07C=H&sb*W2+9EcSZ6_M-$$V}D5OWK*~%qpRn9 zFO~BTSEgbDzw-uz?N+-L_jl@WZf1#ZH`w;=IC+!G*%+$1F{F#!!pP_=1*a`kMeHQ} zrS!WGsblg*m`Ke|;&~me(zsox##mq3;mEuzNjNM)wP!S{v_{R*rWVnM6jf2!wo;<< z3anqV%Xc>7UPN>?Mwi)`mp9Th=~MT5wcm$*!eimynlFV*9`s;qB--lF#qB|2?dXi~ z${9@&Y%cwWY9)KOp0?gev`%w3OR{i*sYb1PX{I%&TF)>T?xNW>9xiX#p$^@!9q>1J z>{{p0Ok$j!zvPBm_+z*cJ84`q<8*q%_LSD61~Mx&PkV0HoR%IN3a0WbS%dtJmO}LN zP9=?Db&pp(1^Bb;wr_IA^AzBx>T9oj;~XI83`4hoH;!xC723)gMYawxP1>u`$8iRsKaM>lZeu>4=7 zt2+F0h*j0UhkFEFOI{kwFoQ(aOtZDHwaboAoC1=|oGgT@7M zMd$98+j>J<4LXdva}1n8%;E8<^pijLsuk`~Apa$&0r3YRX3fr=Wt>VJ8Bw`2XHB#L zq}Rws9ax1mWeKKjy4wVYy^Z<#(9?g%nimjF@_q=T?ozq*T#;HXEFNze(U#0P;@0B|LT z5w2m}s;SerFmn3ea057xpRu{%GGbSg!DwXy^fptu3@SQbX^?Gjs&t==YJLZFMCt&R z0hN}1QurQ_OmG)i%LmdM7nYvVGmCrRE>ylMqPruCog%;iL?$UcrP9oZBRS|%fWG2( zgZ!FW+9|2@MR)ccsBgkSX`NsYxVS-ZP#;I6@$1fE%7gY|=bOy)oMu>`Yzqg(>HJ0DS=C9`f3w1aKZ;fhJQ}~aW3p15p6TotmcSit zac1}&a~m^kjolH<3l&0kRP|`2E^jSWT2>%(@B*$s2LlsE2qnqr(%YNKh zON`Ml`WQthhFAxsca)%zG;gW2(7d$KNL+<7|FAd+nIXbj+E2ffzs0R|9)+t`gU?s0 z!B0SnS1WHp&R2tT!6zVx`S(R!hpbm_=}CjjBWK~SI16kwin!E_Q|Bg8!-{a*)0)|g z5Q0J_Vh^wkF)$2C=|fn{JkKxF_z9Yzwk74W|R z`iHuYHV~)Q1|e3oIqEv2#7Gjyi0l7vHG}U1=_~Ykz?42t1s#M0dZD!br%q-(R(J&V z{kqy?M40I)K_&pMwjkDrJ{e~nf?NEfep2IOy)!u#oCFcVC4pXGA*aNfJ|@CGJv8|& zX9zQow*G=z&q^i13_xoRR$f6ie}O}JneOh=4{m|3RzIZQ6y${X@V=-Q1^S293WgQ< z{Tjotl8a{x-_xIrsWR|`wtGDO@Qz3|=-#gndbD+*X)M;5e_>QEnJyxh_+ICO&$IXf z3kJ-MB&nNQGKl42SR5z%3~wJr;fI$-f2%)8qF34VDL$szP?zElSFe=+A^l!tO|p${ zv)Dp0MD(Q{u)?QMVO<17x0wv-TlJPk8Q#!(>yQ{hw2T4gK(OgYQE?Mzg7NAMbotrJ z`CtxO#N(Aw==IYod#xEwUdl&J|7S~z#p&>Z>%HRVel+2NtIOHk8IQVEC> zx8O^>nNtD%f$-f(v7ku#1gQxQP3(q|GXaa1p_e_Z1eQy~DWp3nB99#fjvuCWJG|}t z@pw^(W?2vAX%cNrRg|b%=H&o|(lmxtMTr<5_+gzTG3X_pV34qUT);OK1Gnm+49sI&@K2#;2WZ?LF1p}F8h=>gl6+F1|liF9+*LCPtXbRVL+q5 zd8NX8alaL>AdlveF?`w9)chlCUYjbN3Z$0E^# zc;F_c4KYZ-1!DYXTx>Uove!heN`87H*M>)!+yl&bS5kffG7!Q(`4rQ(k(1Ke1&LB z-WrA+-V@-^?+s%q%#-XfATh9MaT8W_RE|7iTEqj0O2dV1 zOpv|9+RwANfkLLv3UVOs%70?19*49hF)})6jb*;YMbWG4rY;2`G!mADRzjOZlJ3i# z6%Y=*2dvCWpc*0}-Zp?5GOuUIs~=K0ye(2^PjIk)or<7`V)nzBpg??kK{^f{rVykS zMkP)M!4`7qONIPTG4Gpv4Io^WiYH$cSk17bE_c$20WmiXp^W=+zsVdm#g6wTsgZ%a zKYmn>2!6BxF8DIji3=V;K%Yg1U7jbyY=rL0hD`1}&yK#lw6)o4$4IN99li<%QPO)W zwU`Uw@6pr=`WAG9b@VA5fGaWbKRWQw!9ZcE7-8qRJNFZ?7Z@gcGZ4Za+;>C65$tgC zgL2&vpkg);RgrrPm4wUm!p{%KgO-cL1LemQI}jz=px@cYBd=X6S*Uyl`69jh@FH+5 zdV_^*3)vg;2rTm6B^DoHA)`ae0c{F@ru%e)@`NIy?qO;I(R%>cBiT%$;-=0(k)MQG zf35;e0Jmkp7liZsxIvPEqY>U7P>%QE(J=;t6+%`y&eYb=Qhi*54L#zvd*mkU2lpPV zaKNxjS=$RsZ@RZ(bRY4roRFlH=}c1{nXRx2@LtJ#>JOzrH(vmS(bcFoPrw7L1K((& z9`Q3Utw(Ylm5Jw{h7{$&dsj2WEw6mg$6mHDpRZK&E9hY^TJn60)7)Y0=U8x?%939*{qN5K|oHUUV%iD_e~#XG$2^VMqGvHHtkZ+_ln8 zE{}Xmj4q#G!NbTCi`*;o@FDd|>TdB0KrTSS8{x7>^_cIVE4^W?csDArDdSOXMqHLVk6rXQyyxs4hZ4jvtG}z`s~ux zkoTW0@2mRkl3?EOCc5od0W`53db~3Ukz1J4<(Vv+i$jvF4q`ReS;<1L@g8}mwTn^9 z7y>_F2O{Z?R~y0bms1vP)8G^v_Wo z7WL7*EC>F#Bt8^T z)waHZ_kDi>wi^{p&Nz*+oSPtjvnsg9w!h%j9pv~HiE_+Z3!n1M8kSmIs3Su{IKj~F$^#|Ec>9grG=yd~cJ0|d5 ze20_B6B^5-us8h$phmnjuf1sG0xuN_zKohZCAFBgZ;6TQq9sM=*_smMFzJY1TkTQ$ zA4x~k0Pl`XA}qsT%p7$2dEB1{OuMf%Fyl@D5QUsPc4r!SJ3T6kq?a32k^4*q8|GfNu;X&qt%cL()vM& zJ{~!E0sp*T^;2EEo>|nTfC|dclbmT#t^Gb#P@_PmS6HT&R8FQ>$m#L}l4dmdDb321 z)%1Um6vuGiojd_e@|2O_IIb!Bq)CpWi6_C?JyM5~?x@V`kIG*X&;s86uXLGEqJcWi zwJ7@-0&HRFPZe~rTtqldWOKvP-u|i=f$W1Xb}*ObftuQ3g^NfD4M&f$&n(a<|?XaboPufok(Ab$+Qm{yGO zqTDNx$1Uc=m-8VYc&QX2gPk1!XV7UUX(=t24{!_avClp*PlHvZ_xu>W^l=iJocrW8 zOuQaLEOGhfBIVwvJJ32p2lNUTwZSg+ z5x#eb_HrTUl}9|Vttd8=y)j-f7Oj5TfZVtkUjo3J*2T^ios=ZhPp%23-f+Vy6sh*6 zk>YshxX#a8BG&RC3;Rzlv0}>>YxHATCY1!Rj7|2lXoZZ;pW>x_bF}gznwXKg(bJ`T zjEDj_Vm^mr7IN3-lcvHyAresHh|-??+tM1?pl2l z?TfR`2@8>xhqi>5(y0tK^~L6>d3CuA(A z2yl=i{9PHbrFRx=<2Muw$rS;yjraDl4a?P76TO2~fSgBs$QC&n=AuQK)VJu&gBy=* z1<-IG%mq%`g1;a*5>$=vfY*SBcoqYYd<2#PZ~&np{Pk_*_4{Rj6M>7!$BR!6mG_ln z3Et!%Ddmd{2@DRSFlhmSWpFmHum3&v=0^UWAAMS&BLpm{4CA9Ydt)xk2{bu;^a1ZP zHxZjhtJ5P|;&$!db}Q6U0f_1}V)OH+{)-q|@R` zbCv1JNmwI~2TB9wK1|_PYS0O67nlORGs1uZzT$?L9zn5S6W~&aDKTOIO#}Rc7Zw=+ z8{q}=Ji!@^{-K+|<_8cs2*8)1p_u(I6Q8UDY#e@qcQF{49$5{zLDAzk$sEuFY>F*l zyjaRT*5Of+!ULU>>N9ApvxhUpzz73^*H39-If7(?c{p z?(i`?z{Z#C&fx;L@TC>zyL`NY^m4Sjsjq0I?_h^s&r_HO8PbX}d{iJ$ya4l%V3noC zFvX)Ta>@ImgCM*EvP4)GA45?xVrbvLq7!|8gtkE}#(65~hYEB`Qx|lCB`wXCNf|f@ z9Mr!+e2*a-H1W9#Se6C*P3Qh1FB^P{-$Ot?<*vQJ+%Iz47hfXgakk@S`US=D|MwGr zxtzFVd%S{KTNDrp6_#?&3=HuOpARc$@RPWi@{A^0{|6r<)c2JZYow~7oI)u$Y-yRO zud()V6b7xt1AA|b<3w~IDjEr#&^(&S%3uu+hQ1hxK88~=E>n(fv7^!yS6KU97Mxnd zp~bPxGP0PX2Dckb8cC&m<=A^j=n*kY`5gC_(Ox%l0Zx$Pc!z8hf?k-s2M606W2E;~Cfhfr1a9gw^G9!}$fTB7prjm>$CmBz;Rn z&g0k=-v=;+LQ20+PojdOg(X=woYb94n{VQ~oMw)2sKdg|6tI4pxXDGh34S9C??oav z>AHCdID7uSg5Q*G#qB4BM0{Td5p+pH3P+#ZuqOdGMDy4AxkuQ(_^?!qpE~tgYEEK! zoehIq>bn#KeIByiwj8<`k3|3It-t7YNDG(I45t zE1j#3*8DUhE=i;Hdb8W8*Tv^d!?ACfM5}lht+UudA%68*+)k5dnfG8cxCq4PqA#&{ zkHtS@!KCZxt1Nz%#lK`h79d+I`fb+8P9mOL^lwm<`=Q#0(xRey0F;;_tkI7#kACL! zci}C8Klxa zwSK(L1R2ls?)10;o-r%2P*^}+OwU_gdncWko})@#pTSM~_A;Y<-4_YtBP{4zlcH+5 lUs}jv?u_u@1v~i{Ne?o4N5RFx!GSIMR~^4JegyyK{x2drOEUlf literal 0 HcmV?d00001 diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index fef5c507d7e8..a0eb19dafccc 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -454,8 +454,12 @@ def init_optimizer(self, kvstore='local', optimizer='sgd', if self._params_dirty: self._sync_params_from_devices() + name2idx = {} + for idx, name in enumerate(self._exec_group.param_names): + name2idx[name] = idx + (kvstore, update_on_kvstore) = \ - _create_kvstore(kvstore, len(self._context), self._arg_params) + _create_kvstore(kvstore, len(self._context), self._arg_params, name2idx=name2idx) batch_size = self._exec_group.batch_size if kvstore and 'dist' in kvstore.type and '_sync' in kvstore.type: @@ -558,7 +562,7 @@ def backward(self, out_grads=None): assert self.binded and self.params_initialized self._exec_group.backward(out_grads=out_grads) - def update(self): + def update(self, storage_type_dict=None): """Updates parameters according to the installed optimizer and the gradients computed in the previous forward-backward batch. @@ -572,7 +576,9 @@ def update(self): if self._update_on_kvstore: _update_params_on_kvstore(self._exec_group.param_arrays, self._exec_group.grad_arrays, - self._kvstore) + self._kvstore, + stype_dict=storage_type_dict, + param_names=self._param_names) else: _update_params(self._exec_group.param_arrays, self._exec_group.grad_arrays, diff --git a/python/mxnet/module/module.pyc b/python/mxnet/module/module.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a997f2c431f8982710f53144140cd59c46753c3 GIT binary patch literal 26694 zcmeHwTWlRinqHkli6SXd)cs=1a!c0b$SYH_c6KM@$ezfO?N!#MJuQ2@=2&}pcutc% z)Va{zeI(LrJ;c_|PJ-+P$xe2HV1Q(khYWJVU~<9aVdo*p20;P@335d)4*~KRAPIs1 zdCKPdzQ6i%C`y(@4su(`r&L{CT~%H6*Z+V2rP?F^ZEE7b|I;%cHJ$%6f`9)Ie#yT> z(sFJNX~ngwu2XSy$XLGW=BlnW;#xagXT)`PxVfFK!t0&79(8lO?0Qt!V{UF7*SlP6 z%ylMQcek6{ZTH7rXVP`2+}xDw?s0Q_TzA^dO}p-1H@DYnOsK{_H@DBOce~Dh*FE6o z4!G_?H+Rt5ophZ;u6x+c9d_L#ZtjTd9(8j^ac{~+hutDZy~o{m?hEHWnmfkVxmla$ zTi!u(HJE>=-x_qH`Jmlh?L^(Emj>x_6wJTgNTTf4)%R`$3+ut%xUst0?kxq4VE*QM zccI^z4;sBz(CdwB2og8pUlsWaY<@yN*uY^L=F9 zzZLfftF;}toAPDr%{XeLQT_Iv1Y?WR11xLz+G)FiVHM?jSUyw{h!Gto`QP;di1UXPMyjii=NAStFf-H%Ia z0V7)-r1f5-8zt#p7Go?)ROf8gTX;SlMWL-u6FL21`>*4d{0OrXfqJpG{)F<+D*wDXJ+6G9{Bz2mRQ`nWr&J%EKCiL8;97y&KkZs4mBA{UQs%5{y`T&# zpH`;k!ke`hxpsdUgkkF%w3DC_EHsjKGw5oIVqZ2=Z9?40HsDNfBW2aaL9fZpg^jkZ z4V-{@CIb9vB(|a7jH1??{rGMpZq2jBsBPBf4 zC%Swo|D(%9Tp)NQFm|9a5vfQQ<(+nt2K~h#Nn_pQfUZZ2jX@{Hu=>G#?Huc$J8$fw z5v1#@?PjB60}KX9)C$@?@TSI66f`@HBxx_=zOfND)@dw69os37*Im+IOuLP>x*mM2 ze)EBE+JAp(33K;K@U}8}-JzH4;2!wq{ETjUMu%sv8OsCn%qPkhZQjm&cWtJ*I;fqW z_f6gH$G7X^JeufYzI%i2LiE6t2lMl@VAk{V7ci+U%_F$^z!+4(#@LpPaJZ9(nQS$WLV=j&;$WZ}e1-Dj4$uSS! z8)a=4>AHOHcGV@{cImkL9DSl8G&kWs(^%@e-DemJ+8uZC4_!LROElLTaZAqKtGdqs zZtU(q*4>fPU5tAyyZfJXcdT?5BiWVRov5(oU8TFb9OjKa5 zZS30FtYR`)jpC5#M>wS}Y*eocqO_AX_DiGJj6kKpY3Z79jptNW-}5wW2(lQ}>;O!% z$*V~c#bnuU#Bo11JT}a^)L`ADiP|7L3%PxC^?p2v^cHV6I@pG6lNIWXcxgbynt-w} z$pUO1H3>|WUey;L${!u9wL9%byzck3-n8Brtoe$ZDZ^Wa!qh=Y*jSI6RsRHQXjJ_{ z3iOtUC)Dkt+$aW}K}&VQK0VYV7xDRes~xB7HmV%=+FWNCRYJo_tjupp-b64X? zG6L;nZ%gh}Uy1=)CNcX7GjjV=xr z2EBA}abeKzv@Wi$r_24`#qJtt^`fC3|If_odUzNQy@X%#HWD{-th&E4SsBOgSmk78 zygFVzQPJP(fr|cCCo6j@HPo2KkAJ7D`qxq`Q@+U^M5Mk0zhA^J0aB+}0FwaM0qg&G z#NFFLz7J*&Sz#gqjOW)~|BgXm~sPnXZs%8a-Oz~-`QS?KRkA? zLtyL4U_86r49aUL4}tT%SAOkbM9;fk0^$X{V0;v+-`L#z5US_hVa`~ZmJP_h+|m=E zdp`C6r~UF3-?WFvLiy}4e1!~;eE&@H(wR$NY0Vfp3&n)wib5bEId3?_c3OT#5GC4a$ z1nZy(`U7s_O;&agMf1AqDi9d)UD|U>k4TAtegPR$dc>~rNDp#*)I-YTdk~{H!7q}a zIf!E#1))|g#Ze-yy(}8Iu>A8u5(UyvVS!j^G;hzCP%W9nU@Qc~(CXBq9*kwa<~$h| znN!dj#56CaaRUljZ)yAF&ik-vE<(U$vIiX&?||+L2KX?R&K9z%usR!s!KUmL8}|Y3 z{$au0_0~CodhX8YAUA3E22p7)Y?^~0+dw5Lm=$SH*;Yx8f{eI=WkJrCPn}6l4b`_j zm^_!)-!Pu+q1hVAG<+V3$eLppa@31kufp-3H>y&*x7gQKRHeLSFMO=Uud&(dOx|Gf z4R$~jsV0T;bg6Oba;uwC)^CiyM4R76f3`oT4p#P7rtllpzfIyv{%FtD{$%e4*y^@F zfuy0WnktPwXqlbN+>{kK;S)Dt+>T+L%MEp$R_rGl7hsFQGU@$h{4AehC!BA)89s|_ z`14G@%H$6qk?A?)x0WjUF}l z^LTjs27d*u{Q$qmAN<9K3?A??@OmhL_Jgy2VLJydg|V#Z{B~cJY=^*YEh}Y{MIe;$Ka=u1FykJeh!o1pO$7m^Iq*92& z571jk=BJ&tZEGpcSW=;XgdP&gdN(#*d5|i>uF9zl|RKv;p9-(!AFK!4(>EM z5JFAqW3Yv0_%CQL=F`Y0NN%>NgH|7)fp%ExB?j%23&g^0l3`vAT!57p`tUu_Q#jn5 z2LM40wtvxX6W+)R)*7@e`qShCpbElf8h08n3omq{^ABI%?d6%W_Q>~O1o}Jq`tIIr zeUIbzrZm3-V=4`KWUa{}60NfBrJv$5%QIS{eTtTp=9eW{x&2P|mL865Tf_{FNJCBk zBNRNiCq_^zA@dqCHfJ$RvVy}vHEA_t7=XAlDyIou_?49m<&HX_WIwAvMmw6-48I9j z$fxxLY-PSy817(R$c=XA=Z8r5s<;e!F8nJpImB&jh^PVcLD##_^Zo-9lt2Qr*RRso zZTqdvqbO-@oj51PsL1rqEuF0Qnwi0nRvb8o@3fmyQm9bd$IuALNF9ds{w1yqB#a#e zOqi-H{mQ|={?>o^^#0eIge`;=96A>>IRz5(3we&($7FnteBT@(P-|M)1C>ez+R9B< z8*#}MLQGiuzmRnJHRpb=f?t3(G^BDDd@QWqF8~a0k0`%!^))8HNLvE;Dj^AC;^L0FxPEPFEH7GTY~n?be&m)%u`D~{NAv) zmPMPG=L&BZ9hXIK!%X-ZV>aXtE9!%|Gj}lgIaz$<{e|qS*db__^gpyUZr}m%>P&`E zE@yJ3QS6p=?wx22=6yw)J1y0d8!>mO?RZoAW=uddl&ZDY(bXwq9=NP}X-Zrjju z7O>T!bvpgK^==!!Z79E+D2$l6`FYyu=(x4*=S4Er#0F$!_OQe?#4X8$!nRns7&bu7 z(a^k61-CVZcHF@ueQ5DnW7bnKg^;?~- z083<)%z|qaq`X;WyVCCQ5RqJZo^~uV3fgs71g-I zJnTBO^}}rQ$sPB}x7~Np&>RJu>5jDa029L=^BpLd)-oyyCG$DhYVc5RYPbj$y@{W1 zm8+Y}Q3swq#bV{}QOvq_Qo=VmlVn&5mb3&0pQtr!0GX<%R8efM^|UXnAHtK;?nd(G zS&$WHh0;tWHNDK%YS#6<`gyvEr~QB zUfsOUR|^1CG$dvaLeRLd&x8ynTtzaQyo0o$@+T0%#h`{IaQMq;(L#R7OT&bT%$j5g z;WCpp6RM(Nm&r*cWSuo~4`Y*?4AwXtYCN65P+upYq9ioJxI?30UsIJGRnOF_2S&y# z;;?sPaQo&M>QMQoU=cI>9vBFen)rZcW?vM#XvM{sRElUXnTRR#DX0kuz=y!(N-&!6xX0Eo}*bIITHjkA0|m8PDaMF!C$>;(Z(ND=*n6roQ@)ys<1n-#=B z;brg|V8VJUs&#s+6~yO)fh$IfPaSh#jPF2(0am!&?^d2+rj9fpZNK6t0y+}L^QHH`8?kPg7B3*JS^_`U@TBV+{wo}bs#jP{)I zF`x?I%3@{tfJGXZVQX7DP=nJoo*(`J%Z3S@}(%L$GMes6L+9u zpOQLNaRqdJX{S<^Fyx_L84A{BBpoq`X^vP_v%t%ci(={5T}+BW6-^_WbSYgcGrgmA-@rxuPKL``4-Y4%i|ZsM0v^Ft-D)?0^UQA9vCyJx)SOrSEzGxZ@l(h!W4Ef?zUsJ;axQK(;DlaEyh~oEQH@=B6k}d z4GIa{x2+Q^`xyD4Ok}ISDhVU8Y zNIv0dCTEzuh(tb+I;18Hdz7O)#(Cdn%(@YbO4FX6%PBPQ!iK1GsEB^Mv@ zdv}dgPQy4eT{&8rtd3TvVMmxm;02>-nev)Zx40fTUOinMseG;a{D{)(p^B_=oLOoI zT;znCmN*F6%(IN8VTmDAz|swPS>7TZtUw*P513IQoeKbxMHI)gc$YBJq8gI&91wI1 ziFOx8HA+mo|NkNInXO5!)Bb++_T*O7?uVG6bKlI2iUP_>dW!@Yeg}R0c@p0kGHznJ zLPFE7?DS#r)rL%#TG}3+$c%q=oDzy;tfGE5jUS0C`PbmLFusDQ;D%(T;&L}6FX4u~ zVkW68x~(%jARFRvf^gP>Jvh`&9DgzO5Ml)}-fJ&XEV0Hobvit9{nlYgno;_-faWdW z`^qP`;ZKsfhb_+fQ>@I)9;sb$m5o%O1Kxv#${U>4Q5fiY|Jg?Z!udqgG%i8ncs9GF z3k;Y0+}GeRqwauy$K4?)IS9UlxPu~yy%7Iyg`>tOpio`n;Tfog8-~x9?!E2SZji5) zhpJb*?EWY=sGci)db{Y;+t3WvC^q%Uo0Sh!=$v#AIvCQajE(5<&Tm)fx>Ly&j5s~w zpl)KQ(8@R~IEsVY7%BqiD^&>XUYNZLYCgl)N&bBW%z?ot(9WSd;!qHl@<=s3PKB}s zEldZY>=t7#kKl~z5u8oMVN|*+(JP)rXY~`DJ!YQg)G!lJ{bJ@6Z^eTubU7SaMIJ{~ zQO@N)&9kb^qs0Tt;7kv6ZMrabs&!2BZ&bN3XLb98rlhxDFAG&qhqSDdw8k0;Lt}gA zrR75P|L0X+zUT`%pMF&5wa8QuDSVO}pl}|E1N<`&VR@XZ? zB@lSPL-gBf^l)GcCoY1ah96LIxEPO8>gMOYb~ALSGdn_Xm#|+(TxlRf?l{6^=mkTg z&Zi9WO!^QbZ(t5D=zzL*;l2~qloi5+Fsf*!sG0e+TLnm;nVFdtET{GRa=W!=UJLks z#j0^2tke6Ipvy{t8XO5C`_|cw@Ej4EB!01xgOCPChVVKebaVqruj0P!pBnWu^cw~O$X?EvxrmU4+BfL)Q224WB*DU$Xc{sIbgG|KQL=H#-1 zL>SkY6MtY|GiV{Aq&JY#hPavphhDw%sI#!;W`pEaJfA7%jGp3G52TL1Jag&dLL+T1 z*Ao<7L{xK~XAf}TJE*}acn2n-Gf=vMgrNNfkq?>yDIPn*f9dAGq4ARN3 zXRo5y5Wark%v$YdP<3OJ%jYa(r{NkCxzIjFPX5FQSE7SXai%6+_6>1YSy0;L$qdr7u99GH)h&=?oo%s8kOpS1Gw=Q_H@G zz%TUVrqr@HZYO4reiVWL3C-sphPb=jW~gG+fIobGK6`u~hjbdZ5y(h8q6Gy5!p!5N z2%SZ~cW(18+rnqfZzq+ID`te0tABpF3iVIskzy{48Ts*Vm9Bn({tu9@NIq`Q(Sj{J z89jk^KvG}elMU4EH1^~U=;?JFdE(-4D?#xb)zeWF@PP0{f_2YgOG{rF@kNNs23z5= z*dbbb%R2+`VJEG(1?^((xLG1g?HoVEKZ(fcKf1B3Eli%A3<+pSoCZ*OC0O3pm?j4 z^;glx=}jpsk9YYCsNy-5IVqtNS_W$9NlIj>p+Fd@1KGlw(9&BzS_Lt!P#0N7B zG6V^-95JLL)k^ObeNMybl^*b3S!kI%{MRV>nG9&6))Re;BZ{xF0Mvhmv3{TrV3p!w zfcTs63$&y$WDX(^{Y#ob7Is*8lgv3-T4t|RAizPV%^8?R3V$hE9MBY4Ux|Z09V<4p zRMI8ITM}YFi{K(+6lC?dO81EfErL2EnwqT%eir=ea}*jAqvSJ4Dm=G~mt<^0zJ}TO zaRqPZ74SRF+tgh`$Q?FV`)K~>NT2M?jhz%=qFt3X z)+rB~aAtwS;4kNg1mz^JN6OcnANsRyKR+^bgh75D%8ZQ*{t!N?P*)&(Oh6akh-%jv zdW5gSLHkR3n_oD%va@_6uXxahll zBSKKVPBuYBOclXP<^*lhyVs=TP=vNt817#fKy+$hCb){#w0<)yi+!dFp>g{5*|4p` z8#eTEHpaYtD8Kr|SVCq6NrI8BLx}Wm8_lSDpmOlD&TyNV`?oa3Uv}Uu3scxc&ThDs zuP-O%T}irN&*2lx;H2xM{!6chPWC+({{os(5{fu|b6-cWYuS{nttZ|D)vAo9F8hgvD!d3xUQq{_o4B=h@w`+xJ^<|pRp4J1`4tR~<2p|+-AK!Gdh zx-BV3bUkrji|P%=Z}Z;%o9H3R_x8b)+a>*fl5J@nW(Q*DmxUlFat>b*H}~u&r2BuD zoG)JI|2R2pE0ELO?~7pw4=o9+9*)n-@GG?k0t;8E>F=Wv2&&cd4Q48_p(P{oXz6t1{LTL3@tkBx6&HewYQ8dxRx|6vP;LJ(!e?Uz7<} zD6fOM*(!$n+h`MxFHcdsOB6-Q5=GJTVHgf1vYEB?IwQDf$o5`?jLx*RA6xH@kA|{$Wrj-|-#2LcaWp?xhKB z1I6YQNFVR#LC8J=*(?|jL{)uSX^-!(ZX3=6tn=`ZuTWaU+@rBW}kw}?-Be?G8On}m*aj9&ch*tkwm{&>}WZ> z_TC3RW;HY8WsYga6;qc4VJ*UFm`^SXX>gkN#vIe+?820W+|Wofqf`UmA-)D|hdIgC z$iR5oloj&EMZNf{RdJqK1NIheep%Ez&R}^U&F44r!&t+XZXr5ZfNo(DO96wYcu7}9 zjZkL4L>qf_ITC2MeV%jjSPX%$G1uf+vB07 zMIDNJ@GFwC;?jri!>NaG@Kj!jL6u`v=gWLgpVDv!)+$?TC_97zo2#ckIkFaaz zt7%_E;Mx_d_C2)wHcpg1YSp}j#3nW0HQTdpK2p+P!6wTVKWf?D#NchADBQx|L86u6 z4lm)cpP+oJ3&pPf0T*fvnQWo*Wx@^HYcC+P=`zuaY+I~gnXH`0T6MuTVeTdKB9_Q( zPYAh~-Sh)3R1Ebk<__eRJzKo8{Q4-}>RbO^bn>tG*0TG+g)YXWD_Z{SiUMrt6Pza9 z5|$4|QRSZn(Q_K@qbi-p^YZ)E^`$re1?sv#9f3<)+UxaJzge$`?;;-_pU43PO7R!}pnd6Nx&aI}2Z@UN1$Gf5-}XC~EjEUW%)aGnX(q z$)w8!fs9V!rr}>=?h7VAU_!r8__J0^@Js$MlF=iKp+GS8rIE_?*~!D>yQX*Q@9yb6)4Rv_;Lg{k&x|tA`dgEGQS;{b z^yEJLJ&M0aCXY<-8J`{>8{a!VU84sf{QFEa^~=bSZ`uc)wGyF(+&z|ogk1RNkSP2k zqivWo@P+A+3nrKgTVB+$&>bmVlwhIUDPznnGSRe?S<113|0M>T13DxV5-aI5;UNi6 k;r`n(fg+46*2ilcmX3yw(pREjmAyz$A++Rh^?3Dv0RxQVCIA2c literal 0 HcmV?d00001 diff --git a/python/mxnet/module/python_module.py b/python/mxnet/module/python_module.py index f46ea280aaff..82dcb06aa020 100644 --- a/python/mxnet/module/python_module.py +++ b/python/mxnet/module/python_module.py @@ -110,7 +110,7 @@ def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=Non """ pass - def update(self): + def update(self, storage_type_dict=None): """Updates parameters according to the installed optimizer and the gradients computed in the previous forward-backward batch. Currently we do nothing here. Subclass should override this method if contains parameters. diff --git a/python/mxnet/module/python_module.pyc b/python/mxnet/module/python_module.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4a32e38e9b599a296aad63bb7af22fb3d9c0810 GIT binary patch literal 14763 zcmc&*U5^~cb*-8GUM{sXKO`~;lvW^TXXE8kiV+x*6(>tlqG($yqDfis3d&5+Ozm!y zJ>8@39+G2`FOgn@z>h|N{DA=Z3qfA;7$kX2p8N;mhX8qX&bha`x@WnRmgWkfIK}>` zuCIIVx%Zwb&Hc~W)yMzw-JcJP`!k1s{}R6;J4`s^?+a_t5JzUJ0;k-%OW;AE^ z=1gn2V1|n(nKz>av$vr4OD0(~qb0MqjQeF{=gdAnTrpo6^QAFA-&--mRbyA==^6iY zm7Uz}tg-JdiaUAsFdo{V$VN6ejMCvW7$i|q*do}^@?cWfJSeju9*q+_vS}GaK_NH6 zARCP*WmLvl8jP~xB(bY;8r+$dhj?J4B2K2;tMPB~XE$!4rw)s96~*dlMm8il;_d3xd;l+2(R!LBGOk&Krx>Ys(7Y zc2REUs@o;Gov&`Oj;Op)-LA;(Vs(qfTjKTJ8DrNB7X3?fcY%NHEBwpVhk(}Q;i~k0 z&J1Po_r4?5&Q~?@6Ny=aE`7N1;_*3O%+=FQ$E6W;B-$~pcDE4YJo zPaoPerYv;XRbWAR7?ri0ieMC_(;(Z&b6Zv1xpTV+qGC0`A3GYC(_nv+4k*E5dsUv^ ziSlS+NQ#?e&6KChrI5R%@7H_-R%Ksl~9gQ5{bu*t%_}em1V3KWfzfF~VG7s}?waPj3q-9Fnv`zj4Q!elp z>h!c_QhbBA?dt8v^0qB++kA_fXtBG5!2y~!4!|Jta zh_QAXvMt$1x4F3TErq0Y<=f&g8rwqa>c{TMD~+%9QHSG|HVT{U%Lp6Yb2!c3i5<<( zLk_%ypBBXraCPkv@ON#}pQPpFT7MEJ!)s#!vTLKq809rp9rvuOciZD>$c4OuUqSmc ztt;*Imi*UVY2iO92ZX)qCwf85f{^Z09(+-}jbg{!liC--R}3h(pHJevGN2fCEy%eo zpN51}z4{KX@H>TBokuHM_!ZO^Gk1a+TtmI%W;~sgM668s(aveEc8G zjK4z>Yz3lho547W^CBn!WAXkpP7egeL^N>4G#Kh7X#`0$`G{P&ZM#HmDN_>uas-+_y17qPoLXh2>j*hBagm zK?L+T1USPlt8$-FsG#2kl(#JGXzXJT$~7x&5al))XC*?NI7$*krr9`#^LFtC!6h1$ zlPGC?%7}I7r4G{^W;2m0VpD!Mv=8Hf-45^ zovU@U9;}O;s7|3MUtP&XPkuPPUgwLaTfrAkH+#J;7y+S$v(Z>=7!S%Lswqy6&rcmn z=r|5+*_C+|A9>NO9DH%{A8-Bo7hnCwfX_XmmXY_WRJR|xIp}QfW_-)Su!&}EMBj7FP|XV-%=?Qm@EB&GuWW|_+gYxYyqQ# zT2fPM>E(f2rJ0roUe5bmL&c-2F*TZ6!U)y|Rs!a49HjIm_2H_n)}sh4k~(Tnl;k!- z7GA$jpo{&eC;KPabI;5J++c4SsR?8#8twr6 zgrHOo^v@>3a6F3g;njXLc)*LGAC-f{?cl~F&(Q)r&Lg;rj0Q?P3J}0N>odICe3S2J z%&1l!CsD46N|a(P$}Eo#Y`2__ZCBtQ4_`j8m%MKx%lkQdVW(zc034U3mH!IPU&Vzc zC_wuM%O(dw0oX^{i4<0I&iiw8<;Y!IwL`cHp}t2u6Y&ZD4+fLbB$0coZ-gzfWgZU# zxZ;_p#Ufe+r2uslSv2^idbv*OuPY4p!a11_;2^u;KRCR10#4>bx$(_NMAu|Lc5m*G zq4E_7>G!7NnD{XTI7zs!Y9-$`qtD$AJ!;7CGh{RszJWrlho0H3O^0HsjlD5KP2iXRLoMiux^0>5vWjtHKkEysQ&D)T^Doy^Td9q;E_z$vN zQ5q--xQ!O6$0ti7Wp8&oT-)N`v542> zbf_^@F&*`@MAAw)F|ccsK?&Dy2NR$ILa2L8x9~T$16vdf4f+-~gG4*EUJVCGfWYVA z`N+Yb61B{?1Opw5psXL2_GAXq$Ub2pKAFTQ5wnvP+Vn%1wFC^<83VK@-ySRNEVNP6 zB$^r1sdVs+(P2rtRsMLKx4!AZQyH*^7TdumhsXrTgd=kBqG&{FeMoO6Gp!Lz^XD!B zsY}=NjeWi!B?aP{*-mSx<~<->n=(YgW8$&IdFt`%?q{VSSq*i^(VqxYgX?HVvz_?h z5Lvc$0c`MyxwR+$>Hi4HS%EM)OJ~H4*i`k&aiQ!K>x`Hvx3V(1)V&VQ zvHCMJ7bTx-J7q$OfFNPXphLZO)$6iB!dUg8z2hkwlCdnl?wcw~^2 zWX8f|_o31p?420%_>`g>23uUOzrnvE9rOCR*16Vv`?8SrZ?v|AvR`XmYQ53A)*i6$ z>UV?RFg#KY3l6=Gd;b}a@6mjZa7i9I9lR=;1YLKlCvYZ92RrbJ>cT#UABaWCK+&CJ zQx^J$kFd>0v~ccJ*S^=2{HNDzI?!Ivx7xtU=Vg|hYIy<0ZdTr|wtMYRl;^bRr|>HJ z`6v8b4kCeNk|Ago1uPDRq8N!)b?)X!A^Zj#2P*QI<)eOubmq3^g&>&WkL$b`XjY|( zvmDe+diM<7tNiv%7rY`y&)pKBc!I>$1^ct^H)`7Cx%x z_nbDcq8qg~_y)d)`tkjDIvX=`0KC4n{np!?TbqbcsP`4BGQmS?uF*x4yV{x=)cpyP zJXk^WFpG!5c$&oUTn@MB5Qb5TjY+2GGYvlzT2Q=LfM+-%KX45yfVwGs4Z>gGEDsrIKluWtMy8DH9U3W*ooQ|x-_5JlIC&*Yo`Y# zA=CFDuQxQ-VQ)I@hl2JNXwpQA0exD?7x3bTdNttgs-#H&!C*yKQ9drMYe7 zt_z}Y%jtK!!))NLxf_jzmkm|<5D5bEmvS0%Gcx3kUc;Zf##`>YhQtENodrQe;^80g z*&8gri$W-L#_{2w@DYJ0dGByw$?7JFRU-)x3a#xat2$% zudXevTv~*mS$k_`X|W|jR-gYk^Q z&I9Rz6&!8hL!4H@F&CU`K|+kvD!9f;AUt1|Yn%*_BSJjFf(m*_1&0Ilkjhz}H6b5S z{19{ce?4nLy-|Ggmbi;Ye)mU1B}Jf$Myi5&eme|$pGq6w3pQMwvhkS%QORQYvdQE$ zBT$DSl7wT^aDyMt58Xz@2_cK?6^A!hT;5D6#{op#C*c% z55Z@1cM-=xSb}eu7X|eiz#p9a-r)^HfjI;lSOF z)ShP&C$$&I&k7;SW{Y1Uta$&JZHV9l@GM-c#-0%A6GXEQtQ`jtR#B*!0rIrW!$-Y@ zf?T|H@Lcfd5b`-R$rKOG`mHEU5{0#W;jL`(Z?fqCyuJzOxz2qpKOSK^{3dgXgKiQV zlN!d^-Z+W!04GIj>aV$(9T!RKg0Hhkn=^%gr7KY7j zs|Bb|!5+<-bM3X(T6?kOH$A<&{yttqT>0cZIPYm3_`*GnEKw8AM1J*zq)l=(gJ5cI60aJb0 z!Mdu`q4G!?qyB;(;+Qd>j(&uHh}JJtG3b`e%K2kd;!p7cD#0x%sKjNe1i%Tt*DFS` z(a;lkb&yjvVr`-I>(wsy%RmP-M;t06V;68S!l~j`_XrHD3;Je2<(Aue=hiMwyHZ@> zBV@^xE;@k6;XeD|K`k5Nkr4l$aAr0wA+jW+i`8pGMq(|Dfv)ICQA@KAZR(I zpKhXZ3|+;fUPX0^x6lNra1YT>>Y5`{T9Y`lDFUGEnA8AAuY{HZtposU#`K7Y{J@-- za)S|Dp_lUSNkQTOk1$zwA4Ci2q^RbGnYuVSk#A()%I0BqR&n(f9Q}x+_EG};g4DO`zRVlLO|p!PRBmNduZeoiDzGAVbu@4M{$bZTyFSHJXDaQ z79x&1ibnLfeluW{f$jv>!%INX>&@ToUCmUE=;R!}#V4mYC3jp6^rsBu4yQy@F>M)u zOpg(&-f>DZ^vRG+C zw8HCX@4xUXE}~exKzeh1ePwa&!rI!}`pWsWGaZtm^=Rd2Qy(n0x+)RHaf|*-B2S${{(VDyLK`$K><*zWy~rV0$fcilAGs zU%&T#ufM;4JXQa5dh+hv;s!lGw!6}8CtGe}cec>S*S$e^9q*!z^tHB=Xwy%&{J0tX z4}MmI3V=5(+>&wJKIsJ;QQQsKVR|Rgy5?%%teVW?asO62=#w8pxQ0u=iupcpf=|L7 z7ywRy1Yo14e0+UXJweYH(Si@g$NCcrSmDQ`zb=ns+2f=L|p=+mhJrr zeJ*L;o*}Q@9fV2H5B;Y}J2`3Nr`fc7QQ}{@VjkIbZ4iWB(AyNo>GlVSn*>qM-i+P; zmc1J!TVNls)ppSH$TLB=@5dPOz%TI`-iN-{vIJpng#KQzj%~P@#5cCTu@Q6vFe{m{ zhnsKfVRt0jcO;J>9Fg2gK8zOdv&3$kWen~J@y zs!AI*GA(<+pQbUck_o;|>I4CfTDkXzdQ?#l0cnzKjjs5X!aleH9xLkiD{AK!`w3`@ zfYut^(rUs@h0*4wA16o9ymkHda;q`T4f0VBfO3Z)fF&`t74(9n)iMu8#I=BEb^XNE zNezQmqn@7-g26{iZ@GP6CsXY4+{Dc~xW4w2gulHYPAJ+sapOdj^dkSz4Q`T{SF%ED z+aGG#xh$pqaKqs~@u5oXHE>gM({H17aVzTj7YFNuUNX42PF}m%A0}H-?_zfkTfAuY ze^KYEtYvzd{h`CJJ&j9K!zgRss$2C6E^Ex1wS=PlR9+w8(h+Vz?S1?Qj)f@J-UT9m z<=%PqsHz^~KVDYU_L$PA6h0bC2%Rs3CVyn9om1)#-`8ZmP|yMaNUbqaf(?+}G$#1C z)*w~%e*k;So=4?sHV=&|EG{~?S;*SQZtQ|aeq4ykh5RRH?-a%ryU?-uJ4)AW2v5)7 zEv@Bxo)H7S-rjZNO>JMXfA9kkYB!3#Vl}>B4nx^o`qIYeK^KyV?Z=Qx_kj^N@{%3& zsQN`CjBEkD0s{+UbV@PVSZ5H&00z0;?nS#jI|b7T2cEABOmHI?>~=|F+gGrf7j&=~ zl+o^%SR`2_#fQJw_dC!Z(a270n=ou6^_iz;{&2yXkoT}R}%#WQ-wk%812V_uW2IM;EFcz7>2sEuj_km7nbR% zja|q$^X|Lv+OONoJ=-@!2@9Y=Owu#ijZ*5#PT=Oyp()s&v!O+(KJ1^_-MuCn&HYxi zu{g3?I=+cj8W?`*xo5Br?~v4 zK~?CSz)heSxsj&P38Lj8_J^mU1tPHbI@UB=`s{3V+Iqb*Tce(?S30x_yr?MOd1&Z0nATLgUc!WBI}9jV8+H#81a^h~5y6OIH-Z>CXnhE= z0VA7_ZU&g7jE$g_q}GeW!#OO%K*q zdBpSvUM4AzjQ|Aa@qIBxl$U*wrsn$mF)S(~ODD&9nKw}y^ks!f&4<1aE`1Gav~X!| zKuyjRJW6R5N#%wPq|KV_NMFWCA&MAZgTb=Ke=$@s2FiZmh69EmX{6Am!;AQL)U=T# zMj^l{fZ8yo9NNOT3rjZCNMIC2ChPYJWpX{58f4(T)E2 zzoPocYMC^K2zEvYR1mRHi)M%4KpdI-Fc!2Rj))_}s8n0NK!JP%t9*rPR3K?XD4`;d zFq^4B#>b^$DScWc6>3db;>Suo${Q86_cIw&L-cbwjehPQ;h--o>QPNSLL6hM`(x^1 zjWMakz!eV|frp(p5eh*SSJVbmPE-lHmWg}n;zLzt(ag6{#SMfTE#uU{UmEPC?ZUUB zYFrjIOqwAepr-`S8{+G`aDuyVrznH@>ZwZr?{sx#IgZ_-(G8-+gTmbul`A2dA8$eh zZci1S7xo_*0l9fW%CbC{`HN2 z`Rd8l4tGmqLJGzMw9yhOYZ$t##X_AP4x)8HOAN) zi9T?aS0i&FzlZP}iBIq#GM;2Ub;R2Fk-B5jo{uElCSOm`MP{R}_BQF>az(H_^DdtX zM^9!;jAuL!Hpf~3i&!1ZJ;8o6?ct<)jHD~y)wzQM-tGLEx`W&f#=}bT6HVSVhQ7Rs z^o#Mub-{&Y!}oD67&+7AQGOBbFpFcr#0b_4poPsT*aHK|T{$@N;flZEBHd@VjqPhU z?dyB+b4Fw%6#{z;`LnHH6FFYH9|wpnli>&~Ny@cLJ8BpFUU22N+yy2gL+8dr8yTg< zKa{;N7HxF{t@)W(V5g&Sh6?hcgQ|zrnd&7Jljo%tB9pkpQO|f0M*P-oz zxI!?t$CYj~mV%TRY`R)(vmxO(qmG`5^Brc#m=P`m<5*?2JSF zYY;feeZGmcZ}yNfV=g4mQb5^n`C%V90KbFG^rr0%`bZ^%%RFNan)(bTt`xmFYRvm9 zwx$lG?{7192QJ~p4Iv-rJ-cyUFG{{WjC~iGW=$0@Y3TisOT)vP%fzd38lI$EvQnYcIuQ+UR* zE0UMlg}U@R6C8`{BrdUq{{(*syo}H_e#6*8{1Q88!dLoH5Q{Z1dwf+yqpG$il)g%o zf#Z9aj&cIB(dfLxPDqLohSa{nl8L3!9E}16~G~i$9-&x%!_l{@C+-K=_eoc<{5D0eNm1jHbqqrxT-=|P}m2VB6XfayNQ zx*vQBRQ!NtRQX`A%KRg!!} zg;gi<0NH4N&+nkhWKsr*`WX17x#D4_rV*LuHj3G0RCAzhyX6eEF&jm=T5_nrnKMyJ z7r?e;$+8A^xnw8Mg8rz3eS6r4whWQV#fIB589W|0`*HLj@O&u|mF!PBj64kR0#8LvXok1>KnxMygS*z#!UJHdz?|J4} z*32yO$fBIdlBTc(X(wP!4v3(cFo$Dc3K4#(LPMXTpnxP2$Tw1F@R9g!qA1HkDLT7w z+K`T5!G~{5fde(z&}YepmXF{g4M!mAuL`+}OU!*l-4w?_&DSt4bN zi>Jwx0{tu@xyEdc>_!qQ=DUiTaz;zAD6m%m$z3Y?OM*tT<${2S!;$n&ONF4`Ejj2g zvf>EONOlN3^Il$^G7z%tP>Oz3TKiEkYm1C!5+@II)ro&RkO%N< zWM_`aS(AXWPR2*J+vyBY!IyFy_6k)2!Y_#VSy-|ckY~p)Y;2)72*ZUDm<=zD2C!Z_ zW1kb7g3~CX1sefAaWjn8MFmhxED|n1T4HS2V<`#R4Sv^>#2zJZVmOym=SV2KWCxo) zgcV>3@>XMURmQ9dnVyB@0G3pVXSQ^Wxs4^qu_ZPE)oZLgn^`zo>8>H@`?%X>3~WJ92HgO{hXij6`esxT%0_i?$E1p zPVx2@Z{kxivnD%|a{le*53aX9T)uYw!!>b~7`aGbk32y!YWb7fAGPk>bZ)NRkUD~l z_$ug|DZq=vG*N?oh;guTWH|4#?<#LU=WUs{Yq&|tsXH90`qR@xHFQ1lhAsaAib&Kh z+8h;d-r-EiKQItr{Wh6lV>o)0fO|mSesr5{(^<-#UGZ;Q;zMe^y%)!;Z5NWLk1D4IfA3@%1yDJ!qg_{~+( zktwH01eebg&6Yo9kt8EPQzIYEqB<jO98bmE{=(E%U)C<%Cb3@wp+RPYl(OecA*#Q0+9_=fu!OeTNdLk2>mHwJfDInD%R>_L@pd}vYeJ`;anP` zAIwEii4kwfQ8o-X3SBds>KX)Z^NgncoVioN9OJIoKxF^{oC(o>m6jM7q9KTq9q}!r zv{z)e;{p%UpuVQF5ba+HrFBg0yv6@_kn3F&qUnDGn%<9MNRd~Lfqt`?$DEA0Vz1zs z%r%~p$;CI&#xZ!F*g^i=Y5J6c2lqNai9w{Loz0h1wQ_<7Nz}HSi)ogCNJWh%7gDq)#X(K?@S5Afpf5?fo#t?(u<{`bMlI{ z$(OTJv*H**Wsn>@zYUZ#4@1zJEbxic&&mZChD$iym%n zjyv?|p&KYJ@dXJ)I;h$?m?j|&gOI#_%rivK5+(M-w(#^75|K!U2_aGf{2jpD;PIPO z>)3Kk(M2t36mQ>2og=3Vq{<2Q7XO67-+QjdB&s{h1>G#UIpz!~RN@8Z{NdL?yt>SD zF@tYX4e4j(Y(P&hOFDESOzB7tfM+rD6ZIZG6^B5Z_Hv$sN)Mr=4jWH+;_obWM_sYh zZE~_BHBv+SWWQjG187p&driHu87dC@67x*JhN=g>IEBlD0VS_o zj>RPs!#>S+Qoz|~oW&V+?m>b~#YZH0i;uVtDV!yANML#d1YTw?!VEm$!=-AV(j&k}#JioU zty0cGTrWptlh&{#$dVo_?D7WQzBp+n^kE^*=PZd6vlNm@^G+enPoi@#h*O>(2$qdOD z<=JPRv@&n_hhRtm#z6;v)sP$iR3Gs4UFI--C&SOh_=iPY=8DZt-I9l#=u$4?U$i{W zR22o4wkc>ybkj7YrV@9o9+-Lf1<+yAw&JnTZE{v{lO&)yyOnOy7ze6kvb`H%n@PaUFpAYcC}8H(v;1W=XNB(Et(f_^eDVJil1#z^Fcs`t)H4>{;W=5S|O z2y17)%o(p?cAiS?Mo7Itset@&rHUF24nEMWG_sXTl!Pa#CywHbOf4K=urm)m^bQ*@ zA;A7+?)xH`oVCoc49#N=VQrW0&Rx}4FnDn0jT~aL&s+KIo^nFDQB3d{rh~xm6jMzvGrr# zr1o%=t= (3, 0): from ._cy3.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke @@ -50,7 +51,6 @@ np.uint8 : 3, np.int32 : 4 } - _DTYPE_MX_TO_NP = { 0 : np.float32, 1 : np.float64, @@ -58,7 +58,18 @@ 3 : np.uint8, 4 : np.int32 } -# pylint: enable= no-member +_STORAGE_TYPE_ID_TO_STR = { + -1 : 'undefined', + 0 : 'default_storage', + 1 : 'row_sparse', + 2 : 'csr', +} +_STORAGE_TYPE_STR_TO_ID = { + 'undefined' : -1, + 'default_storage' : 0, + 'row_sparse' : 1, + 'csr' : 2, +} def _new_empty_handle(): """Returns a new empty handle. @@ -102,6 +113,11 @@ def waitall(): """ check_call(_LIB.MXNDArrayWaitAll()) +def _storage_type(handle): + storage_type = ctypes.c_int(0) + check_call(_LIB.MXNDArrayGetStorageType(handle, ctypes.byref(storage_type))) + return _STORAGE_TYPE_ID_TO_STR[storage_type.value] + class NDArray(NDArrayBase): """An array object representing a multidimensional, homogeneous array of fixed-size items. @@ -115,6 +131,9 @@ def __repr__(self): return '<%s %s @%s>' % (self.__class__.__name__, shape_info, self.context) + def __reduce__(self): + return (NDArray, (None,), self.__getstate__()) + def __add__(self, other): """x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """ return add(self, other) @@ -625,7 +644,6 @@ def wait_to_read(self): """ check_call(_LIB.MXNDArrayWaitToRead(self.handle)) - @property def ndim(self): """Returns the number of dimensions of this array @@ -660,6 +678,7 @@ def shape(self): self.handle, ctypes.byref(ndim), ctypes.byref(pdata))) return tuple(pdata[:ndim.value]) + @property def size(self): """Number of elements in the array. @@ -721,6 +740,10 @@ def dtype(self): self.handle, ctypes.byref(mx_dtype))) return _DTYPE_MX_TO_NP[mx_dtype.value] + @property + def storage_type(self): + return _storage_type(self.handle) + @property # pylint: disable= invalid-name, undefined-variable def T(self): @@ -926,6 +949,13 @@ def backward(self, out_grad=None): 1, c_array(NDArrayHandle, [self.handle]), c_array(NDArrayHandle, ograd_handles))) + def to_csr(self): + # pylint: disable=undefined-variable + return cast_storage(self, storage_type='csr') + + def to_rsp(self): + # pylint: disable=undefined-variable + return cast_storage(self, storage_type='row_sparse') def onehot_encode(indices, out): """One-hot encoding indices into matrix out. @@ -999,7 +1029,6 @@ def zeros(shape, ctx=None, dtype=mx_real_t, **kwargs): # pylint: disable= unused-argument if ctx is None: ctx = Context.default_ctx - # pylint: disable= no-member, protected-access return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype) # pylint: enable= no-member, protected-access @@ -2380,37 +2409,5 @@ def %s(%s): ndarray_function.__module__ = 'mxnet.ndarray' return ndarray_function - -# pylint: enable=too-many-locals, invalid-name -def _init_ndarray_module(ndarray_class, root_namespace): - """List and add all the ndarray functions to current module.""" - _set_ndarray_class(ndarray_class) - plist = ctypes.POINTER(ctypes.c_char_p)() - size = ctypes.c_uint() - - check_call(_LIB.MXListAllOpNames(ctypes.byref(size), - ctypes.byref(plist))) - op_names = [] - for i in range(size.value): - op_names.append(py_str(plist[i])) - - module_obj = _sys.modules["%s.ndarray" % root_namespace] - module_internal = _sys.modules["%s._ndarray_internal" % root_namespace] - module_contrib = _sys.modules["%s.contrib.ndarray" % root_namespace] - for name in op_names: - hdl = OpHandle() - check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) - function = _make_ndarray_function(hdl, name) - if function.__name__.startswith('_contrib_'): - function.__name__ = function.__name__[9:] - function.__module__ = 'mxnet.contrib.ndarray' - setattr(module_contrib, function.__name__, function) - elif function.__name__.startswith('_'): - setattr(module_internal, function.__name__, function) - else: - setattr(module_obj, function.__name__, function) - -_init_ndarray_module(NDArray, "mxnet") - # from .base import add_fileline_to_docstring # add_fileline_to_docstring(__name__) diff --git a/python/mxnet/sparse_ndarray.py b/python/mxnet/sparse_ndarray.py new file mode 100644 index 000000000000..63fbfd0e5510 --- /dev/null +++ b/python/mxnet/sparse_ndarray.py @@ -0,0 +1,641 @@ +# coding: utf-8 +"""SparseNDArray API of mxnet.""" +from __future__ import absolute_import +from __future__ import division +try: + from __builtin__ import slice as py_slice +except ImportError: + from builtins import slice as py_slice + +import ctypes +import warnings + +import os as _os +import sys as _sys + +# import operator +import numpy as np +from .base import _LIB, numeric_types +from .base import c_array, py_str, mx_real_t, c_str +from .base import mx_uint, NDArrayHandle, check_call, OpHandle +from .context import Context +from . import _ndarray_internal as _internal +from . import ndarray +from .ndarray import _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP +from .ndarray import _STORAGE_TYPE_STR_TO_ID +from .ndarray import NDArray, _storage_type, _make_ndarray_function + +# Use different verison of SymbolBase +# When possible, use cython to speedup part of computation. +# pylint: disable=unused-import +try: + if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: + #TODO remove some import? + from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke + elif _sys.version_info >= (3, 0): + from ._cy3.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke + else: + from ._cy2.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke +except ImportError: + if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: + raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") + from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke + +# pylint: enable=unused-import +_STORAGE_AUX_TYPES = { + 'row_sparse': [np.int32], + 'csr': [np.int32, np.int32] +} + +def _new_alloc_handle(storage_type, shape, ctx, delay_alloc, dtype, aux_types, aux_shapes=None): + """Return a new handle with specified storage type, shape, dtype and context. + + Empty handle is only used to hold results + + Returns + ------- + handle + A new empty ndarray handle + """ + hdl = NDArrayHandle() + aux_type_ids = [int(_DTYPE_NP_TO_MX[np.dtype(aux_t).type]) for aux_t in aux_types] + aux_shapes = [(0,) for aux_t in aux_types] if aux_shapes is None else aux_shapes + aux_shape_lens = [len(aux_shape) for aux_shape in aux_shapes] + aux_shapes = sum(aux_shapes, ()) + num_aux = mx_uint(len(aux_types)) + check_call(_LIB.MXNDArrayCreateSparseEx( + ctypes.c_int(int(_STORAGE_TYPE_STR_TO_ID[storage_type])), + c_array(mx_uint, shape), + mx_uint(len(shape)), + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + ctypes.c_int(int(delay_alloc)), + ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])), + num_aux, + c_array(ctypes.c_int, aux_type_ids), + c_array(mx_uint, aux_shape_lens), + c_array(mx_uint, aux_shapes), + ctypes.byref(hdl))) + return hdl + +class SparseNDArray(NDArray): + """An array object representing a multidimensional, homogeneous array of +fixed-size items, stored in sparse format. + + """ + + def __reduce__(self): + raise Exception('Not implemented for SparseND yet!') + # return SparseNDArray, (None,), self.__getstate__() + + def __add__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __iadd__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __radd__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __isub__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __rsub__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __imul__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __rmul__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __rdiv__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __idiv__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __truediv__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __rtruediv__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __itruediv__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __pow__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __rpow__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __getstate__(self): + raise Exception('Not implemented for SparseND yet!') + + def __setstate__(self, state): + raise Exception('Not implemented for SparseND yet!') + + def __setitem__(self, key, value): + """x.__setitem__(i, y) <=> x[i]=y + + Set self[key] to value. Only slice [:] is supported. + + Parameters + ---------- + key : slice + The indexing key. + value : NDArray or numpy.ndarray + The value to set. + + Examples + -------- + >>> src = mx.sparse_nd.row_sparse(data, indices, (3,3)) + >>> src.asnumpy() + array([[ 1., 0., 2.], + [ 0., 0., 0.], + [ 4., 5., 6.]], dtype=float32) + >>> # assign SparseNDArray with same storage type + >>> x = mx.sparse_nd.zeros('row_sparse', (3,3)) + >>> x[:] = src + >>> x.asnumpy() + array([[ 1., 0., 2.], + [ 0., 0., 0.], + [ 4., 5., 6.]], dtype=float32) + >>> # assign NDArray to SparseNDArray + >>> x[:] = mx.nd.ones((3,3)) + >>> x.asnumpy() + array([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]], dtype=float32) + """ + if not self.writable: + raise ValueError('Failed to assign to a readonly NDArray') + if isinstance(key, py_slice): + if key.step is not None or key.start is not None or key.stop is not None: + raise ValueError('Assignment with slicing not supported in SparseNDArray.') + if isinstance(value, NDArray): + # avoid copying to itself + if value.handle is not self.handle: + value.copyto(self) + elif isinstance(value, numeric_types): + raise Exception("Assigning numeric types to SparseNDArray not supported yet.") + elif isinstance(value, (np.ndarray, np.generic)): + # TODO(haibin) Implement _sync_copyfrom for sparse ndarray to avoid an extra copy + warnings.warn('Assigning non-NDArray object to SparseNDArray is not efficient', + RuntimeWarning) + tmp = ndarray.array(value) + tmp.copyto(self) + else: + raise TypeError('type %s not supported' % str(type(value))) + else: + assert(isinstance(key, (int, tuple))) + raise Exception('SparseNDArray only supports [:] for assignment') + + def __getitem__(self, key): + """x.__getitem__(i) <=> x[i] + + Returns a sliced view of this array. + + Parameters + ---------- + key : int or slice + Indexing key. + + Examples + -------- + >>> x[:] = mx.nd.arange(0,6).reshape((2,3)) + >>> x.asnumpy() + array([[ 0., 1., 2.], + [ 3., 4., 5.]], dtype=float32) + >>> x[1:2].asnumpy() + array([[ 3., 4., 5.]], dtype=float32) + """ + stype = self.storage_type + if stype != 'csr': + raise Exception("__getitem__ for " + str(stype) + " not implemented yet") + if isinstance(key, int): + raise Exception("Not implemented yet") + if isinstance(key, py_slice): + if key.step is not None: + raise ValueError('NDArray only supports continuous slicing on axis 0') + if key.start is not None or key.stop is not None: + return self._slice(key.start, key.stop) + else: + return self + if isinstance(key, tuple): + raise ValueError('Multi-dimension indexing is not supported') + + def _sync_copyfrom(self, source_array): + raise Exception('Not implemented for SparseND yet!') + + def _slice(self, start, stop): + """Returns a read-only SparseNDArray slice that shares memory with current one. + To create a writable slice, please use ``mx.nd.slice`` instead. + + Parameters + ---------- + start : int + Starting index of slice. + stop : int + Finishing index of slice. + + Example + ---------- + >>> indptr = np.array([0, 2, 3, 6]) + >>> indices = np.array([0, 2, 2, 0, 1, 2]) + >>> data = np.array([1, 2, 3, 4, 5, 6]) + >>> a = mx.sparse_nd.csr(data, indptr, indices, (3, 3)) + >>> a.asnumpy() + array([[1, 0, 2], + [0, 0, 3], + [4, 5, 6]]) + + >>> a[1:2].asnumpy() + array([[0, 0, 3]]) + + """ + stype = self.storage_type + assert(stype == 'csr'), "_slice for " + str(stype) + " not implemented yet" + warnings.warn('slicing SparseNDArray is not efficient', RuntimeWarning) + shape = list(self.shape) + shape[0] = stop - start + handle = _new_alloc_handle(self.storage_type, tuple(shape), self.context, + True, self.dtype, self.aux_types) + start = mx_uint(start) if start else mx_uint(0) + stop = mx_uint(stop) if stop else mx_uint(self.shape[0]) + + check_call(_LIB.MXNDArraySliceEx(self.handle, start, stop, handle)) + ret = SparseNDArray(handle=handle, writable=False) + return ret + + def _at(self, idx): + raise Exception('at operator for SparseND is not supported.') + + def reshape(self, shape): + raise Exception('Not implemented for SparseND yet!') + + def broadcast_to(self, shape): + raise Exception('Not implemented for SparseND yet!') + + def _aux_type(self, i): + """Data-type of the array’s ith aux data. + + Returns + ------- + numpy.dtype + This NDArray's data type. + """ + aux_type = ctypes.c_int() + check_call(_LIB.MXNDArrayGetAuxType(self.handle, i, ctypes.byref(aux_type))) + return _DTYPE_MX_TO_NP[aux_type.value] + + @property + def _values(self): + """The values array of the SparseNDArray. This is a read-only view of the values array. + They reveal internal implementation details and should be used with care. + + Returns + ------- + NDArray + This SparseNDArray's values array. + """ + return self._data() + + @property + def _indices(self): + """The indices array of the SparseNDArray. This is a read-only view of the indices array. + They reveal internal implementation details and should be used with care. + + Returns + ------- + NDArray + This SparseNDArray's indices array. + """ + stype = self.storage_type + if stype == 'row_sparse': + return self._aux_data(0) + elif stype == 'csr': + return self._aux_data(1) + raise Exception("unknown storage type " + stype) + + @property + def _indptr(self): + """The indptr array of the SparseNDArray with `csr` storage type. + This is a read-only view of the indptr array. + They reveal internal implementation details and should be used with care. + + Returns + ------- + NDArray + This SparseNDArray's indptr array. + """ + stype = self.storage_type + if stype == 'csr': + return self._aux_data(0) + raise Exception("unknown storage type " + stype) + + @property + def _num_aux(self): + ''' The number of aux data used to help store the sparse ndarray. + ''' + return len(_STORAGE_AUX_TYPES[self.storage_type]) + + @property + # pylint: disable= invalid-name, undefined-variable + def T(self): + raise Exception('Transpose is not supported for SparseNDArray.') + + @property + def aux_types(self): + """The data types of the aux data for the SparseNDArray. + """ + aux_types = [] + num_aux = self._num_aux + for i in range(num_aux): + aux_types.append(self._aux_type(i)) + return aux_types + + def asnumpy(self): + """Return a dense ``numpy.ndarray`` object with value copied from this array + + """ + return self.to_dense().asnumpy() + + def astype(self, dtype): + raise Exception('Not implemented for SparseND yet!') + + def copyto(self, other): + """Copies the value of this array to another array. + + If ``other`` is a ``NDArray`` object, then ``other.shape`` and + ``self.shape`` should be the same. This function copies the value from + ``self`` to ``other``. + + If ``other`` is a context, a new ``NDArray`` will be first created on + the target context, and the value of ``self`` is copied. + + Parameters + ---------- + other : NDArray or Context + The destination array or context. + + Returns + ------- + NDArray + The copied array. If ``other`` is an ``NDArray``, then the return value + and ``other`` will point to the same ``NDArray``. + """ + if isinstance(other, NDArray): + if other.handle is self.handle: + warnings.warn('You are attempting to copy an array to itself', RuntimeWarning) + return + return _internal._copyto(self, out=other) + elif isinstance(other, Context): + hret = SparseNDArray(_new_alloc_handle(self.storage_type, self.shape, other, + True, self.dtype, self.aux_types)) + return _internal._copyto(self, out=hret) + else: + raise TypeError('copyto does not support type ' + str(type(other))) + + def to_dense(self): + return to_dense(self) + + def _aux_data(self, i, writable=False): + """ Get an NDArray referencing the ith aux data array associated with the SparseNDArray. + """ + self.wait_to_read() + hdl = NDArrayHandle() + check_call(_LIB.MXNDArrayGetAuxNDArray(self.handle, i, ctypes.byref(hdl))) + return NDArray(hdl, writable) + + def _data(self, writable=False): + """ Get an NDArray referencing the value array associated with the SparseNDArray. + """ + self.wait_to_read() + hdl = NDArrayHandle() + check_call(_LIB.MXNDArrayGetDataNDArray(self.handle, ctypes.byref(hdl))) + return NDArray(hdl, writable) + +def _prepare_src_array(src, dtype, default_dtype): + if isinstance(src, NDArray): + dtype = src.dtype if dtype is None else dtype + else: + dtype = default_dtype if dtype is None else dtype + if not isinstance(src, np.ndarray): + try: + src = np.array(src, dtype=dtype) + except: + raise TypeError('values must be array like object') + return src, dtype + +def csr(values, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, indices_type=None): + """Creates a 2D array with compressed sparse row format. + + A SparseNDArray with `csr` storage represents a NDArray as three separate arrays: `values`, + `indptr` and `indices`. It uses the standard CSR representation where the column indices for + row i are stored in indices[indptr[i]:indptr[i+1]] and their corresponding values are stored + in values[indptr[i]:indptr[i+1]]. + + Parameters + ---------- + values: array_like + An object exposing the array interface, with shape [nnz], where D0 is the number of + non-zero entries. + indptr: array_like + An object exposing the array interface, with shape [D0 + 1]. The first element in indptr + should always be zero. + indices: array_like + An object exposing the array interface, with shape [nnz]. + ctx : Context, optional + Device context (default is the current default context). + dtype : str or numpy.dtype, optional + The data type of the output array. The default dtype is ``values.dtype`` + if `values` is an `NDArray`, `float32` otherwise. + indptr_type: str or numpy.dtype, optional + The data type of the indices array. The default dtype is ``indptr.dtype`` + if `indptr` is an `NDArray`, `int32` otherwise. + indices_type: str or numpy.dtype, optional + The data type of the indices array. The default dtype is ``indices.dtype`` + if `indicies` is an `NDArray`, `int32` otherwise. + + Returns + ------- + SparseNDArray + An `SparseNDArray` with the `csr` storage representation. + """ + storage_type = 'csr' + # context + if ctx is None: + ctx = Context.default_ctx + # prepare src array and types + values, dtype = _prepare_src_array(values, dtype, mx_real_t) + indptr, indptr_type = _prepare_src_array(indptr, indptr_type, + _STORAGE_AUX_TYPES[storage_type][0]) + indices, indices_type = _prepare_src_array(indices, indices_type, + _STORAGE_AUX_TYPES[storage_type][1]) + # verify types + assert('int' in str(indptr_type) or 'long' in str(indptr_type)) + assert('int' in str(indices_type) or 'long' in str(indices_type)) + # verify shapes + aux_shapes = [indptr.shape, indices.shape] + assert(values.ndim == 1) + assert(indptr.ndim == 1) + assert(indices.ndim == 1) + assert(len(shape) == 2) + result = SparseNDArray(_new_alloc_handle(storage_type, shape, ctx, False, dtype, + [indptr_type, indices_type], aux_shapes)) + # assign indptr, indices and values + values_ref = result._data(True) + indptr_ref = result._aux_data(0, True) + indices_ref = result._aux_data(1, True) + values_ref[:] = values + indptr_ref[:] = indptr + indices_ref[:] = indices + return result + +def row_sparse(values, indices, shape, ctx=None, dtype=None, indices_type=None): + """Creates a row sparse array with a set of tensor slices at given indices. + + A SparseNDArray with `row_sparse` storage is typically used to represent a subset of a larger + NDArray with `default_storage` of shape [LARGE0, D1, .. , DN] where LARGE0 >> D0. The values + in indices are the indices in the first dimension of the slices that have been extracted from + the larger NDArray. The indices are expected to be sorted in ascending order. + + The corresponding NDArray ``dense`` with `default_storage` represented by a ``rsp`` + SparseNDArray with `row_sparse` storage has + + ``dense[rsp.indices[i], :, :, :, ...] = rsp.values[i, :, :, :, ...]`` + + `row_sparse` SparseNDArray is used principally in the definition of gradients for operations + that have sparse gradients (e.g. SparseEmbedding). + + Parameters + ---------- + values: array_like + An object exposing the array interface, with shape [D0, D1, .. Dn], where D0 is + the number of rows with non-zeros entries. + indices: array_like + An object exposing the array interface, with shape [D0]. + ctx : Context, optional + Device context (default is the current default context). + dtype : str or numpy.dtype, optional + The data type of the output array. The default dtype is ``values.dtype`` + if `values` is an `NDArray`, `float32` otherwise. + indices_type: str or numpy.dtype, optional + The data type of the indices array. The default dtype is ``indices.dtype`` + if `indicies` is an `NDArray`, `int32` otherwise. + + Returns + ------- + SparseNDArray + An `SparseNDArray` with the `row_sparse` storage representation. + """ + storage_type = 'row_sparse' + # context + if ctx is None: + ctx = Context.default_ctx + # prepare src array and types + values, dtype = _prepare_src_array(values, dtype, mx_real_t) + indices, indices_type = _prepare_src_array(indices, indices_type, + _STORAGE_AUX_TYPES[storage_type][0]) + # verify types + assert('int' in str(indices_type) or 'long' in str(indices_type)) + # verify shapes + assert(values.ndim == len(shape)) + assert(indices.ndim == 1) + result = SparseNDArray(_new_alloc_handle(storage_type, shape, ctx, False, dtype, + [indices_type], [indices.shape])) + # assign indices and values + values_ref = result._data(True) + indices_ref = result._aux_data(0, True) + values_ref[:] = values + indices_ref[:] = indices + return result + +def to_dense(source): + """ Return a dense array representation of this SparseNDArray. + + Returns + ------- + SparseNDArray + The dense array with default storage + """ + return ndarray.cast_storage(source, storage_type='default_storage') + +def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None): + """Return a new array of given shape and type, filled with zeros. + + Parameters + ---------- + shape : int or tuple of int + The shape of the empty array + storage_type: string + The storage type of the empty array, such as 'row_sparse', 'csr', etc + ctx : Context, optional + An optional device context (default is the current default context) + dtype : str or numpy.dtype, optional + An optional value type (default is `float32`) + aux_types: list of numpy.dtype, optional + An optional type for the aux data for SparseNDArray (default values depends + on the storage type) + + Returns + ------- + SparseNDArray + A created array + Examples + -------- + >>> mx.sparse_nd.zeros('csr', (1,2), mx.gpu(0)) + + >>> mx.sparse_nd.zeros('row_sparse', (1,2), mx.gpu(0), 'float16').asnumpy() + array([[ 0., 0.]], dtype=float16) + """ + if ctx is None: + ctx = Context.default_ctx + dtype = mx_real_t if dtype is None else dtype + if aux_types is None: + if storage_type == 'row_sparse' or storage_type == 'csr': + aux_types = _STORAGE_AUX_TYPES[storage_type] + else: + raise Exception("unknown storage type") + assert(len(aux_types) == len(_STORAGE_AUX_TYPES[storage_type])) + out = SparseNDArray(_new_alloc_handle(storage_type, shape, ctx, True, dtype, aux_types)) + return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out) + +def _ndarray_cls(handle): + stype = _storage_type(handle) + # TODO(haibin) in the long run, we want to have CSRNDArray and RowSparseNDArray which + # inherit from SparseNDArray + return NDArray(handle) if stype == 'default_storage' else SparseNDArray(handle) + +# pylint: enable=too-many-locals, invalid-name +def _init_ndarray_module(ndarray_class, root_namespace): + """List and add all the ndarray functions to current module.""" + _set_ndarray_class(ndarray_class) + plist = ctypes.POINTER(ctypes.c_char_p)() + size = ctypes.c_uint() + + check_call(_LIB.MXListAllOpNames(ctypes.byref(size), + ctypes.byref(plist))) + op_names = [] + for i in range(size.value): + op_names.append(py_str(plist[i])) + + module_obj = _sys.modules["%s.ndarray" % root_namespace] + module_internal = _sys.modules["%s._ndarray_internal" % root_namespace] + module_contrib = _sys.modules["%s.contrib.ndarray" % root_namespace] + for name in op_names: + hdl = OpHandle() + check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) + function = _make_ndarray_function(hdl, name) + if function.__name__.startswith('_contrib_'): + function.__name__ = function.__name__[9:] + function.__module__ = 'mxnet.contrib.ndarray' + setattr(module_contrib, function.__name__, function) + elif function.__name__.startswith('_'): + setattr(module_internal, function.__name__, function) + else: + setattr(module_obj, function.__name__, function) + +_init_ndarray_module(_ndarray_cls, "mxnet") diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 4632f7d71b17..c8c45f4060f2 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -13,11 +13,13 @@ import numpy as _numpy from .base import _LIB, numeric_types -from .base import c_array, c_str, mx_uint, py_str, string_types, mx_real_t +from .base import c_array, c_str, mx_uint, py_str, string_types from .base import NDArrayHandle, ExecutorHandle, SymbolHandle from .base import check_call, MXNetError from .context import Context, cpu -from .ndarray import NDArray, zeros as _nd_zeros, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP +from .ndarray import _STORAGE_TYPE_ID_TO_STR, _STORAGE_TYPE_STR_TO_ID +from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP +from .sparse_ndarray import _ndarray_cls from .executor import Executor from . import _symbol_internal as _internal from .attribute import AttrScope @@ -520,7 +522,7 @@ def list_attr(self, recursive=False): pairs = ctypes.POINTER(ctypes.c_char_p)() f_handle = _LIB.MXSymbolListAttrShallow check_call(f_handle(self.handle, ctypes.byref(size), ctypes.byref(pairs))) - return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size.value)} + return {py_str(pairs[i * 2]): py_str(pairs[i * 2 + 1]) for i in range(size.value)} def attr_dict(self): """Recursively gets all attributes from the symbol and its children. @@ -546,8 +548,8 @@ def attr_dict(self): check_call(f_handle(self.handle, ctypes.byref(size), ctypes.byref(pairs))) ret = {} for i in range(size.value): - name, key = py_str(pairs[i*2]).split('$') - val = py_str(pairs[i*2+1]) + name, key = py_str(pairs[i * 2]).split('$') + val = py_str(pairs[i * 2 + 1]) if name not in ret: ret[name] = {} ret[name][key] = val @@ -715,6 +717,89 @@ def list_auxiliary_states(self): self.handle, ctypes.byref(size), ctypes.byref(sarr))) return [py_str(sarr[i]) for i in range(size.value)] + def infer_storage_type(self, *args, **kwargs): + """Infer the storage type of outputs and arguments of given known types of arguments. + + User can either pass in the known types in positional way or keyword argument way. + Tuple of Nones is returned if there is not enough information passed in. + An error will be raised if there is inconsistency found in the known types passed in. + + Parameters + ---------- + *args : + Provide type of arguments in a positional way. + Unknown type can be marked as None + + **kwargs : + Provide keyword arguments of known types. + + Returns + ------- + arg_storage_types : list of numpy.dtype or None + List of types of arguments. + The order is in the same order as list_arguments() + out_storage_types : list of numpy.dtype or None + List of types of outputs. + The order is in the same order as list_outputs() + aux_storage_types : list of numpy.dtype or None + List of types of outputs. + The order is in the same order as list_auxiliary_states() + """ + # pylint: disable=too-many-locals + if len(args) != 0 and len(kwargs) != 0: + raise ValueError('Can only specify known argument \ + types either by positional or kwargs way.') + sdata = [] + if len(args) != 0: + keys = None + for s in args: + if s is not None: + if s not in _STORAGE_TYPE_STR_TO_ID or not isinstance(s, basestring): + raise TypeError('Argument need to be one of '+str(_STORAGE_TYPE_STR_TO_ID)) + sdata.append(_STORAGE_TYPE_STR_TO_ID[s]) + else: + sdata.append(_STORAGE_TYPE_STR_TO_ID['undefined']) + else: + keys = [] + for k, v in kwargs.items(): + if v in _STORAGE_TYPE_STR_TO_ID: + keys.append(c_str(k)) + sdata.append(_STORAGE_TYPE_STR_TO_ID[v]) + arg_storage_type_size = mx_uint() + arg_storage_type_data = ctypes.POINTER(ctypes.c_int)() + out_storage_type_size = mx_uint() + out_storage_type_data = ctypes.POINTER(ctypes.c_int)() + aux_storage_type_size = mx_uint() + aux_storage_type_data = ctypes.POINTER(ctypes.c_int)() + complete = ctypes.c_int() + check_call(_LIB.MXSymbolInferStorageType( + self.handle, + mx_uint(len(sdata)), + c_array(ctypes.c_char_p, keys), + c_array(ctypes.c_int, sdata), + ctypes.byref(arg_storage_type_size), + ctypes.byref(arg_storage_type_data), + ctypes.byref(out_storage_type_size), + ctypes.byref(out_storage_type_data), + ctypes.byref(aux_storage_type_size), + ctypes.byref(aux_storage_type_data), + ctypes.byref(complete))) + if complete.value != 0: + arg_storage_types = [ + _STORAGE_TYPE_ID_TO_STR[arg_storage_type_data[i]] \ + for i in range(arg_storage_type_size.value)] + out_storage_types = [ + _STORAGE_TYPE_ID_TO_STR[out_storage_type_data[i]] \ + for i in range(out_storage_type_size.value)] + aux_storage_types = [ + _STORAGE_TYPE_ID_TO_STR[aux_storage_type_data[i]] \ + for i in range(aux_storage_type_size.value)] + return (arg_storage_types, out_storage_types, aux_storage_types) + else: + return (None, None, None) + # pylint: enable=too-many-locals + + def infer_type(self, *args, **kwargs): """Infers the type of all arguments and all outputs, given the known types for some arguments. @@ -770,7 +855,7 @@ def infer_type(self, *args, **kwargs): if s is not None: s = _numpy.dtype(s).type if s not in _DTYPE_NP_TO_MX: - raise TypeError('Argument need to be one of '+str(_DTYPE_NP_TO_MX)) + raise TypeError('Argument need to be one of ' + str(_DTYPE_NP_TO_MX)) sdata.append(_DTYPE_NP_TO_MX[s]) else: sdata.append(-1) @@ -879,7 +964,7 @@ def infer_shape(self, *args, **kwargs): if len(unknowns) >= 10: unknowns.append('...') break - unknowns.append('%s: %s'%(name, str(shape))) + unknowns.append('%s: %s' % (name, str(shape))) warnings.warn( "Cannot decide shape for the following arguments " + "(0s in shape means unknown dimensions). " + @@ -1006,7 +1091,7 @@ def _infer_shape_impl(self, partial, *args, **kwargs): return (arg_shapes, out_shapes, aux_shapes) else: return (None, None, None) - # pylint: enable=too-many-locals + # pylint: enable=too-many-locals def debug_str(self): """Gets a debug string of symbol. @@ -1154,12 +1239,11 @@ def _get_ndarray_inputs(arg_key, args, arg_names, allow_missing): raise TypeError('Only accept list of NDArrays or dict of str to NDArray') return c_array(NDArrayHandle, arg_handles), arg_arrays - def simple_bind(self, ctx, - grad_req='write', - type_dict=None, - group2ctx=None, - **kwargs): - """Binds current symbol to get an executor, allocate all the arguments needed. + def simple_bind(self, ctx, grad_req='write', type_dict=None, storage_type_dict=None, + group2ctx=None, shared_arg_names=None, shared_exec=None, + shared_buffer=None, **kwargs): + """Bind current symbol to get an executor, allocate all the arguments needed. + Allows specifying data types. This function simplifies the binding procedure. You need to specify only input data shapes. Before binding the executor, the function allocates arguments and auxiliary states @@ -1169,7 +1253,7 @@ def simple_bind(self, ctx, ---------- >>> x = mx.sym.Variable('x') >>> y = mx.sym.FullyConnected(x, num_hidden=4) - >>> exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req=[]) + >>> exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req='null') >>> exe.forward() [] >>> exe.outputs[0].asnumpy() @@ -1199,9 +1283,26 @@ def simple_bind(self, ctx, type_dict : Dict of str->numpy.dtype Input type dictionary, name->dtype + storage_type_dict : Dict of str->str + Input storage type dictionary, name->storage_type + group2ctx : Dict of string to mx.Context The dict mapping the `ctx_group` attribute to the context assignment. + shared_arg_names : List of string + The argument names whose `NDArray` of shared_exec can be reused for initializing + the current executor. + + shared_exec : Executor + The executor whose arg_arrays, arg_arrays, grad_arrays, and aux_arrays can be + reused for initializing the current executor. + + shared_buffer : Dict of string to `NDArray` + The dict mapping argument names to the `NDArray` that can be reused for initializing + the current executor. This buffer will be checked for reuse if one argument name + of the current executor is not found in `shared_arg_names`. The `NDArray`s are + expected have default storage type. + kwargs : Dict of str->shape Input shape dictionary, name->shape @@ -1210,47 +1311,187 @@ def simple_bind(self, ctx, executor : mxnet.Executor The generated executor """ - # pylint: disable=too-many-locals - if type_dict is None: - attrs = self.attr_dict() - type_dict = {k: mx_real_t for k in self.list_arguments() - if k not in attrs or '__dtype__' not in attrs[k]} - arg_shapes, _, aux_shapes = self.infer_shape(**kwargs) - arg_types, _, aux_types = self.infer_type(**type_dict) - - if arg_shapes is None or arg_types is None: - raise ValueError("Input node is not complete") - + # data types + num_provided_arg_types = 0 + provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided type argument names + provided_arg_type_data = ctypes.POINTER(mx_uint)() # provided types + if type_dict is not None: + provided_arg_type_names = [] + provided_arg_type_data = [] + for k, v in type_dict.items(): + v = _numpy.dtype(v).type + if v in _DTYPE_NP_TO_MX: + provided_arg_type_names.append(c_str(k)) + provided_arg_type_data.append(ctypes.c_int(_DTYPE_NP_TO_MX[v])) + num_provided_arg_types = mx_uint(len(provided_arg_type_names)) + provided_arg_type_names = c_array(ctypes.c_char_p, provided_arg_type_names) + provided_arg_type_data = c_array(ctypes.c_int, provided_arg_type_data) + + # storage types + num_provided_arg_stypes = 0 + # provided storage type argument names + provided_arg_stype_names = ctypes.POINTER(ctypes.c_char_p)() + provided_arg_stype_data = ctypes.POINTER(mx_uint)() # provided storage types + if storage_type_dict is not None: + provided_arg_stype_names = [] + provided_arg_stype_data = [] + for k, v in storage_type_dict.items(): + if v in _STORAGE_TYPE_STR_TO_ID: + provided_arg_stype_names.append(c_str(k)) + provided_arg_stype_data.append(ctypes.c_int(_STORAGE_TYPE_STR_TO_ID[v])) + num_provided_arg_stypes = mx_uint(len(provided_arg_stype_names)) + provided_arg_stype_names = c_array(ctypes.c_char_p, provided_arg_stype_names) + provided_arg_stype_data = c_array(ctypes.c_int, provided_arg_stype_data) + + provided_arg_shape_data = [] # shape data + # argument shape index in sdata, + # e.g. [sdata[indptr[0]], sdata[indptr[1]]) is the shape of the first arg + provided_arg_shape_idx = [0] + provided_arg_shape_names = [] # provided argument names + for k, v in kwargs.items(): + # if k not in listed_arguments and k not in listed_aux_states: + # raise ValueError('arg name %s is not valid', k) + if isinstance(v, tuple): + provided_arg_shape_names.append(c_str(k)) + provided_arg_shape_data.extend(v) + provided_arg_shape_idx.append(len(provided_arg_shape_data)) + + provided_req_type_list_len = 0 + provided_grad_req_types = ctypes.POINTER(ctypes.c_char_p)() + provided_grad_req_names = ctypes.POINTER(ctypes.c_char_p)() + if grad_req is not None: + if isinstance(grad_req, string_types): + # use provided_req_type_list_len = 0 to indicate this situation + provided_req_type_list_len = 0 + provided_grad_req_types = [c_str(grad_req)] + elif isinstance(grad_req, list): + if len(grad_req) == 0: + raise RuntimeError('grad_req in simple_bind cannot be an empty list') + provided_grad_req_types = [c_str(item) for item in grad_req] + provided_req_type_list_len = len(provided_grad_req_types) + elif isinstance(grad_req, dict): + if len(grad_req) == 0: + raise RuntimeError('grad_req in simple_bind cannot be an empty dict') + provided_grad_req_names = [] + provided_grad_req_types = [] + for k, v in grad_req.items(): + provided_grad_req_names.append(c_str(k)) + provided_grad_req_types.append(c_str(v)) + provided_grad_req_names = c_array(ctypes.c_char_p, provided_grad_req_names) + provided_req_type_list_len = len(provided_grad_req_types) + provided_grad_req_types = c_array(ctypes.c_char_p, provided_grad_req_types) + + num_ctx_map_keys = mx_uint(0) + ctx_map_keys = ctypes.POINTER(ctypes.c_char_p)() + ctx_map_dev_types = ctypes.POINTER(ctypes.c_int)() + ctx_map_dev_ids = ctypes.POINTER(ctypes.c_int)() if group2ctx is not None: - attr_dict = self.attr_dict() - arg_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx) \ - if name in attr_dict and '__ctx_group__' in attr_dict[name] \ - else ctx for name in self.list_arguments()] - aux_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx) \ - if name in attr_dict and '__ctx_group__' in attr_dict[name] \ - else ctx for name in self.list_auxiliary_states()] - else: - arg_ctx = [ctx] * len(arg_shapes) - aux_ctx = [ctx] * len(aux_shapes) - - # alloc space - arg_ndarrays = [ - _nd_zeros(shape, dev, dtype=dtype) - for dtype, dev, shape in zip(arg_types, arg_ctx, arg_shapes)] - if grad_req != 'null': - grad_ndarrays = {} - for name, shape, dev, dtype in zip( - self.list_arguments(), arg_shapes, arg_ctx, arg_types): - if not isinstance(grad_req, dict) or grad_req[name] != 'null': - grad_ndarrays[name] = _nd_zeros(shape, dev, dtype=dtype) + ctx_map_keys = [] + ctx_map_dev_types = [] + ctx_map_dev_ids = [] + for key, val in group2ctx.items(): + ctx_map_keys.append(c_str(key)) + ctx_map_dev_types.append(ctypes.c_int(val.device_typeid)) + ctx_map_dev_ids.append(ctypes.c_int(val.device_id)) + num_ctx_map_keys = mx_uint(len(ctx_map_keys)) + ctx_map_keys = c_array(ctypes.c_char_p, ctx_map_keys) + ctx_map_dev_types = c_array(ctypes.c_int, ctx_map_dev_types) + ctx_map_dev_ids = c_array(ctypes.c_int, ctx_map_dev_ids) + + # prepare param names + shared_arg_name_list = [] + if shared_arg_names is not None: + if not isinstance(shared_arg_names, list): + raise ValueError('shared_arg_names in simple_bind must be a list or None') + shared_arg_name_list = [c_str(name) for name in shared_arg_names] + + # prepare shared_buffer + if shared_buffer is None: + shared_buffer_len = mx_uint() + shared_buffer_names = ctypes.POINTER(ctypes.c_char_p)() + shared_buffer_handles = ctypes.POINTER(NDArrayHandle)() else: - grad_ndarrays = None - - aux_ndarrays = [_nd_zeros(shape, dev, dtype=dtype) - for shape, dev, dtype in zip(aux_shapes, aux_ctx, aux_types)] - executor = self.bind(ctx, arg_ndarrays, - grad_ndarrays, grad_req, aux_ndarrays, - group2ctx=group2ctx) + if not isinstance(shared_buffer, dict): + raise ValueError('shared_buffer in simple_bind must be dict or None') + shared_buffer_names = [] + shared_buffer_handles = [] + for k, v in shared_buffer.items(): + assert(v.storage_type == 'default_storage'), \ + "shared_buffer is expected to only contain NDArrays with default storage" + shared_buffer_names.append(c_str(k)) + shared_buffer_handles.append(v.handle) + shared_buffer_names = c_array(ctypes.c_char_p, shared_buffer_names) + shared_buffer_len = mx_uint(len(shared_buffer_handles)) + shared_buffer_handles = c_array(NDArrayHandle, shared_buffer_handles) + + # prepare shared_exec_handle + shared_exec_handle = shared_exec.handle if shared_exec is not None else ExecutorHandle() + + # prepare current executor handle + exe_handle = ExecutorHandle() + + # prepare current executor's in_args, arg_grads, and aux_states + num_in_args = ctypes.c_uint() + in_arg_handles = ctypes.POINTER(NDArrayHandle)() + arg_grad_handles = ctypes.POINTER(NDArrayHandle)() + num_aux_states = ctypes.c_uint() + aux_state_handles = ctypes.POINTER(NDArrayHandle)() + + check_call(_LIB.MXExecutorSimpleBind(self.handle, + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + num_ctx_map_keys, + ctx_map_keys, + ctx_map_dev_types, + ctx_map_dev_ids, + mx_uint(provided_req_type_list_len), + provided_grad_req_names, + provided_grad_req_types, + mx_uint(len(provided_arg_shape_names)), + c_array(ctypes.c_char_p, provided_arg_shape_names), + c_array(mx_uint, provided_arg_shape_data), + c_array(mx_uint, provided_arg_shape_idx), + num_provided_arg_types, + provided_arg_type_names, + provided_arg_type_data, + num_provided_arg_stypes, + provided_arg_stype_names, + provided_arg_stype_data, + mx_uint(len(shared_arg_name_list)), + c_array(ctypes.c_char_p, shared_arg_name_list), + ctypes.byref(shared_buffer_len), + ctypes.byref(shared_buffer_names), + ctypes.byref(shared_buffer_handles), + ctypes.byref(num_in_args), + ctypes.byref(in_arg_handles), + ctypes.byref(arg_grad_handles), + ctypes.byref(num_aux_states), + ctypes.byref(aux_state_handles), + shared_exec_handle, + ctypes.byref(exe_handle))) + + # update shared_buffer + if shared_buffer is not None: + updated_shared_buffer = [NDArray(NDArrayHandle(shared_buffer_handles[i])) + for i in range(shared_buffer_len.value)] + updated_shared_buffer_names = [py_str(shared_buffer_names[i]) + for i in range(shared_buffer_len.value)] + for k, v in zip(updated_shared_buffer_names, updated_shared_buffer): + shared_buffer[k] = v + + # create in_args, arg_grads, and aux_states for the current executor + arg_arrays = [_ndarray_cls(NDArrayHandle(in_arg_handles[i])) \ + for i in range(num_in_args.value)] + grad_arrays = [_ndarray_cls(NDArrayHandle(arg_grad_handles[i])) + if arg_grad_handles[i] is not None + else None for i in range(num_in_args.value)] + aux_arrays = [_ndarray_cls(NDArrayHandle(aux_state_handles[i])) + for i in range(num_aux_states.value)] + + executor = Executor(exe_handle, self, ctx, grad_req, group2ctx) + executor.arg_arrays = arg_arrays + executor.grad_arrays = grad_arrays + executor.aux_arrays = aux_arrays return executor def bind(self, ctx, args, args_grad=None, grad_req='write', @@ -1435,6 +1676,7 @@ def grad(self, wrt): c_wrt, ctypes.byref(handle))) return Symbol(handle) + # pylint: enable= no-member def eval(self, ctx=cpu(), **kwargs): @@ -1494,8 +1736,8 @@ def reshape(self, shape): """ return reshape(self, shape=shape) - -def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, init=None, **kwargs): +def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, + init=None, storage_type=None, **kwargs): """Creates a symbolic variable with specified name. Example usage: @@ -1549,6 +1791,8 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, ini if not isinstance(init, string_types): init = init.dumps() attr['__init__'] = init + if storage_type is not None: + attr['__storage_type__'] = str(_STORAGE_TYPE_STR_TO_ID[storage_type]) for k, v in kwargs.items(): if k.startswith('__') and k.endswith('__'): attr[k] = str(v) @@ -1559,9 +1803,11 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, ini ret._set_attr(**attr) return ret + # for back compatibility Variable = var + def Group(symbols): """Creates a symbol that contains a collection of other symbols, grouped together. @@ -1654,6 +1900,7 @@ def load_json(json_str): # Initialize the atomic symbol in startups _init_symbol_module(Symbol, "mxnet") + # pylint: disable=no-member # pylint: disable=redefined-builtin def pow(base, exp): diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 6b836f5d5d84..6969ad730510 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -8,8 +8,10 @@ import os import errno import logging +import scipy as sp import numpy as np import numpy.testing as npt +import numpy.random as rnd import mxnet as mx from .context import Context from .ndarray import array @@ -63,6 +65,39 @@ def random_arrays(*shapes): return arrays[0] return arrays +# TODO(haibin) also include types in arguments +def rand_sparse_ndarray(shape, storage_type, density=None): + """Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np) """ + density = rnd.rand() if density is None else density + if storage_type == 'row_sparse': + # TODO(haibin) support high dim sparse ndarray + assert(len(shape) < 3) + prod = np.prod(shape) + num_cols = int(prod / shape[0]) + # sample index + idx_sample = rnd.rand(shape[0]) + indices = np.argwhere(idx_sample < density).flatten() + if indices.shape[0] == 0: + result = mx.sparse_nd.zeros('row_sparse', shape) + return result, (np.array([]), np.array([], dtype='int32')) + # generate random values + val = rnd.rand(indices.shape[0], num_cols) + arr = mx.sparse_nd.row_sparse(val, indices, shape, indices_type=np.int32) + return arr, (val, indices) + elif storage_type == 'csr': + assert(len(shape) == 2) + csr = sp.sparse.rand(shape[0], shape[1], density=density, format='csr') + result = mx.sparse_nd.csr(csr.data, csr.indptr, csr.indices, shape) + return result, (csr.indptr, csr.indices, csr.data) + else: + assert(False), "unknown storage type" + +def rand_ndarray(shape, storage_type, density=None): + if storage_type == 'default_storage': + arr = mx.nd.array(random_arrays(shape)) + else: + arr, _ = rand_sparse_ndarray(shape, storage_type, density=density) + return arr def np_reduce(dat, axis, keepdims, numpy_reduce_func): """Compatible reduce for old version of NumPy. @@ -295,7 +330,8 @@ def _parse_location(sym, location, ctx): % (str(set(sym.list_arguments())), str(set(location.keys())))) else: location = {k: v for k, v in zip(sym.list_arguments(), location)} - location = {k: mx.nd.array(v, ctx=ctx) for k, v in location.items()} + location = {k: mx.nd.array(v, ctx=ctx) if isinstance(v, np.ndarray) \ + else v for k, v in location.items()} return location @@ -586,8 +622,8 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None, g[:] = 0 executor.forward(is_train=False) - outputs = [x.asnumpy() for x in executor.outputs] + outputs = [x.asnumpy() for x in executor.outputs] for output_name, expect, output in zip(sym.list_outputs(), expected, outputs): assert_almost_equal(expect, output, rtol, atol, ("EXPECTED_%s"%output_name, "FORWARD_%s"%output_name)) @@ -655,14 +691,29 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= if isinstance(expected, (list, tuple)): expected = {k:v for k, v in zip(sym.list_arguments(), expected)} args_grad_npy = {k:_rng.normal(size=v.shape) for k, v in expected.items()} - args_grad_data = {k: mx.nd.array(v, ctx=ctx) for k, v in args_grad_npy.items()} + # args_grad_data should be casted to storage type if hinted + # TODO(haibin) this is a temporary solution for testing. remove later + attrs = sym.attr_dict() + args_grad_data = {} + for k, v in args_grad_npy.items(): + attr = attrs.get(k, {}) + grad_stype = attr.get('grad_stype_hint', None) + nd = mx.nd.array(v, ctx=ctx) + if grad_stype is not None: + out = mx.nd.cast_storage(nd, storage_type=grad_stype) + args_grad_data[k] = out + else: + args_grad_data[k] = nd + if isinstance(grad_req, str): grad_req = {k:grad_req for k in sym.list_arguments()} elif isinstance(grad_req, (list, tuple)): grad_req = {k:v for k, v in zip(sym.list_arguments(), grad_req)} - executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, aux_states=aux_states) + executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, + aux_states=aux_states, grad_req=grad_req) executor.forward(is_train=True) + if isinstance(out_grads, (tuple, list)): out_grads = [mx.nd.array(v, ctx=ctx) for v in out_grads] elif isinstance(out_grads, (dict)): diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index ae7af5bad129..ccddc03a8e29 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -154,6 +154,39 @@ int MXNDArrayCreateEx(const mx_uint *shape, API_END(); } +int MXNDArrayCreateSparseEx(int storage_type, + const mx_uint *shape, + mx_uint ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + mx_uint num_aux, + int *aux_type, + mx_uint *aux_ndims, + const mx_uint *aux_shape, + NDArrayHandle *out) { + API_BEGIN(); + std::vector aux_types; + std::vector aux_shapes; + auto shape_start = aux_shape; + for (size_t i = 0; i < num_aux; i++) { + // types + aux_types.push_back(aux_type[i]); + // shapes + aux_shapes.emplace_back(shape_start, shape_start + aux_ndims[i]); + shape_start += aux_ndims[i]; + } + *out = new NDArray( + NDArrayStorageType(storage_type), + TShape(shape, shape + ndim), + Context::Create(static_cast(dev_type), dev_id), + delay_alloc != 0, + dtype, aux_types, aux_shapes); + API_END(); +} + + int MXNDArrayLoadFromRawBytes(const void *buf, size_t size, NDArrayHandle *out) { @@ -287,6 +320,16 @@ int MXNDArraySlice(NDArrayHandle handle, API_END_HANDLE_ERROR(delete ptr); } +int MXNDArraySliceEx(NDArrayHandle handle, + mx_uint slice_begin, + mx_uint slice_end, + NDArrayHandle out) { + NDArray *ptr = static_cast(out); + API_BEGIN(); + static_cast(handle)->SliceEx(slice_begin, slice_end, ptr); + API_END(); +} + int MXNDArrayAt(NDArrayHandle handle, mx_uint idx, NDArrayHandle *out) { @@ -333,6 +376,18 @@ MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle, API_END_HANDLE_ERROR(delete ptr); } +int MXNDArrayGetStorageType(NDArrayHandle handle, + int *out_storage_type) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + if (!arr->is_none()) { + *out_storage_type = arr->storage_type(); + } else { + *out_storage_type = kUndefinedStorage; + } + API_END(); +} + int MXNDArrayGetShape(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata) { @@ -378,6 +433,32 @@ int MXNDArrayGetDType(NDArrayHandle handle, API_END(); } +int MXNDArrayGetAuxType(NDArrayHandle handle, + mx_uint i, + int *out_type) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + *out_type = arr->aux_type(i); + API_END(); +} + +int MXNDArrayGetAuxNDArray(NDArrayHandle handle, + mx_uint i, + NDArrayHandle *out) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + *out = new NDArray(arr->aux_ndarray(i)); + API_END(); +} + +int MXNDArrayGetDataNDArray(NDArrayHandle handle, + NDArrayHandle *out) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + *out = new NDArray(arr->data_ndarray()); + API_END(); +} + int MXNDArrayGetContext(NDArrayHandle handle, int *out_dev_type, int *out_dev_id) { diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index e2e739ae62a4..27bce311f980 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -58,6 +58,8 @@ struct MXAPIThreadLocalEntry { std::vector arg_shapes, out_shapes, aux_shapes; /*! \brief result holder for returning type flags */ std::vector arg_types, out_types, aux_types; + /*! \brief result holder for returning storage types */ + std::vector arg_storage_types, out_storage_types, aux_storage_types; /*! \brief result holder for returning shape dimensions */ std::vector arg_shape_ndim, out_shape_ndim, aux_shape_ndim; /*! \brief result holder for returning shape pointer */ diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index ce765acd77bf..aae7fe5e3c9f 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -154,6 +154,332 @@ int MXExecutorBindEX(SymbolHandle symbol_handle, API_END_HANDLE_ERROR(delete exec); } +/*! + * \brief + * \param symbol_handle symbol handle + * \param dev_type default device type + * \param dev_id default device id + * \param num_g2c_keys number of group2ctx keys + * \param g2c_keys key list of group2ctx + * \param g2c_dev_types device type list of group2ctx + * \param g2c_dev_ids id list of group2ctx + * \param provided_grad_req_list_len grad_req length provided by users in front-end + * \param provided_grad_req_names grad_req names provided by users in front-end + * \param provided_grad_req_types req types provided by users in front-end + * \param num_provided_arg_shapes number of user provided in_arg and aux_state shapes + * \param provided_arg_shape_names name list of provided shapes + * \param provided_arg_shape_data provided shape data + * \param provided_arg_shape_idx provided shape data index + * \param num_provided_arg_dtypes number of user provided in_arg and axu_state dtypes + * \param provided_arg_dtype_names argument name list of provided dtypes + * \param provided_arg_dtypes data of provided dtypes + * \param num_provided_arg_stypes number of user provided in_arg and axu_state storage types + * \param provided_arg_stype_names argument name list of provided storage types + * \param provided_arg_stypes data of provided storage types + * \param num_shared_arg_names number of parameter names passed from _bind_ith_exec + * \param shared_arg_name_list parameter name list passed from _bind_ith_exec + * \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec + * \param shared_buffer_name_list shared data array names passed from _bind_ith_exec + * \param shared_buffer_handle_list shared data array handles passed from _bind_ith_exec + * \param num_in_args number of input arguments of this sym + * \param in_args list_arguments associated with the current executor + * \param arg_grads list of gradients of in_args associated with the current executor + * \param num_aux_states number of aux states of this sym + * \param aux_states list_auxiliary_states associated with the current executor + * \param shared_exec_handle shared excutor handle passed from _bind_ith_exec + * \param out the handle of the executor to be created + */ +int MXExecutorSimpleBind(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const mx_uint num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const mx_uint provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const mx_uint num_provided_arg_shapes, + const char** provided_arg_shape_names, + const mx_uint* provided_arg_shape_data, + const mx_uint* provided_arg_shape_idx, + const mx_uint num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const mx_uint num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, + const mx_uint num_shared_arg_names, + const char** shared_arg_name_list, + mx_uint* shared_buffer_len, + const char*** shared_buffer_name_list, + NDArrayHandle** shared_buffer_handle_list, + mx_uint* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + mx_uint* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out) { + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + API_BEGIN(); + nnvm::Symbol *sym = static_cast(symbol_handle); + + // get in_arg names + std::vector in_arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs); + std::vector aux_state_names = sym->ListInputNames(nnvm::Symbol::kAuxiliaryStates); + + // attr_dict for setting up type_dict and arg/aux ctx + std::unordered_map> attr_dict; + if (nullptr == provided_arg_dtypes || nullptr == g2c_keys) { + std::vector> attrs = + sym->ListAttrsRecursive(); + attr_dict.reserve(attrs.size()); + for (const auto& tp : attrs) { + attr_dict[std::get<0>(tp)][std::get<1>(tp)] = std::get<2>(tp); + } + } + + // setup arg_dtype_map + std::unordered_map arg_dtype_map; + if (nullptr == provided_arg_dtypes) { // use attr_dict + for (const auto& arg_name : in_arg_names) { + const auto it = attr_dict.find(arg_name); + if (it == attr_dict.end() || !it->second.count("__dtype__")) { + arg_dtype_map[arg_name] = mshadow::kFloat32; + } + } + } else { // use user input type_dict + // create dtype map for in_args and aux_states + arg_dtype_map.reserve(num_provided_arg_dtypes); + for (mx_uint i = 0; i < num_provided_arg_dtypes; ++i) { + arg_dtype_map[provided_arg_dtype_names[i]] = provided_arg_dtypes[i]; + } + } + + // setup arg_stype_map + std::unordered_map arg_stype_map; + if (nullptr == provided_arg_stypes) { // use attr_dict + for (const auto& arg_name : in_arg_names) { + const auto it = attr_dict.find(arg_name); + if (it == attr_dict.end() || !it->second.count("__storage_type__")) { + arg_stype_map[arg_name] = kDefaultStorage; + } + } + } else { // use user input type_dict + // create stype map for in_args and aux_states + arg_stype_map.reserve(num_provided_arg_stypes); + for (mx_uint i = 0; i < num_provided_arg_stypes; ++i) { + arg_stype_map[provided_arg_stype_names[i]] = provided_arg_stypes[i]; + } + } + + // create default ctx + Context ctx = Context::Create(static_cast(dev_type), dev_id); + // create ctx map + std::map ctx_map; + std::vector in_arg_ctx_vec(in_arg_names.size(), ctx); + std::vector aux_state_ctx_vec(aux_state_names.size(), ctx); + if (nullptr != g2c_keys) { // use user input group2ctx dict + for (mx_uint i = 0; i < num_g2c_keys; ++i) { + ctx_map[g2c_keys[i]] = Context::Create( + static_cast(g2c_dev_types[i]), g2c_dev_ids[i]); + } + + // initialize in_arg_ctx_vec using group2ctx if there are any + for (size_t i = 0; i < in_arg_ctx_vec.size(); ++i) { + const auto it1 = attr_dict.find(in_arg_names[i]); + if (it1 != attr_dict.end()) { + const auto it2 = it1->second.find("__ctx_group__"); + if (it2 != it1->second.end()) { + const auto it3 = ctx_map.find(it2->second); + if (it3 != ctx_map.end()) { + in_arg_ctx_vec[i] = it3->second; + } + } + } + } + + // initialize aux_state_ctx_vec using group2ctx if there are any + for (size_t i = 0; i < aux_state_ctx_vec.size(); ++i) { + const auto it1 = attr_dict.find(aux_state_names[i]); + if (it1 != attr_dict.end()) { + const auto it2 = it1->second.find("__ctx_group__"); + if (it2 != it1->second.end()) { + const auto it3 = ctx_map.find(it2->second); + if (it3 != ctx_map.end()) { + aux_state_ctx_vec[i] = it3->second; + } + } + } + } + } + + // create provided_grad_req_map + const std::map req_map = + {{"null", kNullOp}, {"write", kWriteTo}, {"add", kAddTo}}; + std::unordered_map provided_grad_req_map; + std::string grad_req_type; + if (0 == provided_grad_req_list_len + && nullptr == provided_grad_req_names + && nullptr != provided_grad_req_types) { // string, grad_req='write' + CHECK_EQ(req_map.count(provided_grad_req_types[0]), 1U) + << "grad_req=" << provided_grad_req_types[0] << " is not a valid input in simple_bind; " + "only \'null\', \'write\', and \'add\' are supported"; + grad_req_type = "string"; + } else if (provided_grad_req_list_len > 0 + && nullptr == provided_grad_req_names + && nullptr != provided_grad_req_types) { // list, grad_req=['null', 'write'] + grad_req_type = "list"; + CHECK_EQ(provided_grad_req_list_len, in_arg_names.size()) + << "The length of grad_req list does not match the number of input arguments in simple_bind, " + "expected " << in_arg_names.size() << ", provided " << provided_grad_req_list_len; + } else if (provided_grad_req_list_len > 0 + && nullptr != provided_grad_req_names + && nullptr != provided_grad_req_types) { // dict, grad_req=['lhs': 'null', 'rhs': 'write'] + grad_req_type = "dict"; + provided_grad_req_map.reserve(provided_grad_req_list_len); + for (mx_uint i = 0; i < provided_grad_req_list_len; ++i) { + CHECK_EQ(req_map.count(provided_grad_req_types[i]), 1U) + << "grad_req=" << provided_grad_req_types[i] << " is not a valid input in simple_bind; " + "only \'null\', \'write\', and \'add\' are supported"; + provided_grad_req_map[provided_grad_req_names[i]] = provided_grad_req_types[i]; + } + } else { // grad_req is None + grad_req_type = "none"; + } + + // initialize arg_grad_ctx_vec and grad_req_type_vec + std::vector arg_grad_ctx_vec(in_arg_names.size(), ctx); + std::vector grad_req_type_vec(in_arg_names.size(), kNullOp); + if ("none" != grad_req_type) { + for (size_t i = 0; i < in_arg_names.size(); ++i) { + OpReqType cur_req = kNullOp; + if ("string" == grad_req_type) { + cur_req = req_map.at(provided_grad_req_types[0]); + } else if ("list" == grad_req_type) { + CHECK_EQ(req_map.count(provided_grad_req_types[i]), 1U) + << "grad_req=" << provided_grad_req_types[i] << " is not a valid input in simple_bind; " + "only \'null\', \'write\', and \'add\' are supported"; + cur_req = req_map.at(provided_grad_req_types[i]); + } else if ("dict" == grad_req_type) { + const auto it = provided_grad_req_map.find(in_arg_names[i]); + if (it != provided_grad_req_map.end()) { + cur_req = req_map.at(it->second); + } + } + if (kNullOp != cur_req) { + arg_grad_ctx_vec[i] = in_arg_ctx_vec[i]; + grad_req_type_vec[i] = static_cast(cur_req); + } + } + } + + // create shape map for in_args and aux_states + std::unordered_map arg_shape_map(num_provided_arg_shapes); + for (mx_uint i = 0; i < num_provided_arg_shapes; ++i) { + auto p = arg_shape_map.emplace(provided_arg_shape_names[i], + TShape(provided_arg_shape_data+provided_arg_shape_idx[i], + provided_arg_shape_data+provided_arg_shape_idx[i+1])); + CHECK(p.second) << "Duplicate shapes are provided for argument " + << provided_arg_shape_names[i] << " in simple_bind"; + } + + // create para name set for sharing data array memory + std::unordered_set shared_arg_name_set(num_shared_arg_names); + for (mx_uint i = 0; i < num_shared_arg_names; ++i) { + shared_arg_name_set.insert(shared_arg_name_list[i]); + } + + // create shared_buffer_map + std::unordered_map shared_buffer_map; + std::vector shared_exec_in_args; + std::vector shared_exec_arg_grads; + std::vector shared_exec_aux_states; + bool use_shared_buffer = (nullptr != *shared_buffer_handle_list); + if (use_shared_buffer) { + // create shared_buffer_map + shared_buffer_map.reserve(*shared_buffer_len); + NDArray*** shared_buffer_ptrs = + reinterpret_cast(shared_buffer_handle_list); + for (mx_uint i = 0; i < *shared_buffer_len; ++i) { + shared_buffer_map[*shared_buffer_name_list[i]] = *(*shared_buffer_ptrs)[i]; + } + } + + // create temporary place holders for the initialized NDArrays + // to be passed back to front end + std::vector in_arg_vec; + std::vector arg_grad_vec; + std::vector aux_state_vec; + + *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec, + aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map, + grad_req_type_vec, shared_arg_name_set, &in_arg_vec, + &arg_grad_vec, &aux_state_vec, + use_shared_buffer ? &shared_buffer_map : nullptr, + reinterpret_cast(shared_exec_handle)); + + // copy ndarray ptrs to ret->handles so that front end + // can access them + ret->ret_handles.clear(); + ret->ret_handles.reserve(in_arg_vec.size()+arg_grad_vec.size()+aux_state_vec.size() + +shared_buffer_map.size()); + size_t nd_idx = 0; + for (const auto& nd : in_arg_vec) { + if (nd.is_none()) { + LOG(FATAL) << "Input argument NDArray cannot be un-allocated"; + } + ret->ret_handles.push_back(new NDArray(nd)); + } + if (in_arg_vec.size() > 0) { + *num_in_args = in_arg_vec.size(); + *in_args = &(ret->ret_handles[nd_idx]); + nd_idx = ret->ret_handles.size(); + } + + for (const auto& nd : arg_grad_vec) { + if (nd.is_none()) { + ret->ret_handles.push_back(nullptr); + } else { + ret->ret_handles.push_back(new NDArray(nd)); + } + } + if (arg_grad_vec.size() > 0) { + *arg_grads = &(ret->ret_handles[nd_idx]); + nd_idx = ret->ret_handles.size(); + } + + for (const auto& nd : aux_state_vec) { + if (nd.is_none()) { + LOG(FATAL) << "Auxiliary argument NDArray cannot be un-allocated"; + } + ret->ret_handles.push_back(new NDArray(nd)); + } + if (aux_state_vec.size() > 0) { + *num_aux_states = aux_state_vec.size(); + *aux_states = &(ret->ret_handles[nd_idx]); + nd_idx = ret->ret_handles.size(); + } + + if (use_shared_buffer) { + ret->ret_vec_charp.clear(); + ret->ret_vec_charp.reserve(shared_buffer_map.size()); + for (const auto kv : shared_buffer_map) { + if (kv.second.is_none()) { + LOG(FATAL) << "Shared data NDArray cannot be un-allocated"; + } + ret->ret_handles.push_back(new NDArray(kv.second)); + ret->ret_vec_charp.push_back(kv.first.c_str()); + } + *shared_buffer_len = shared_buffer_map.size(); + *shared_buffer_handle_list = &(ret->ret_handles[nd_idx]); + *shared_buffer_name_list = &(ret->ret_vec_charp[0]); + } + + API_END(); +} + int MXExecutorSetMonitorCallback(ExecutorHandle handle, ExecutorMonitorCallback callback, void* callback_handle) { diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index c633e8609cd4..9db999406a0d 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -1,6 +1,6 @@ /*! * Copyright (c) 2016 by Contributors - * \file c_api_symbolic.cc + * \file c_api_ndarray.cc * \brief C API of mxnet */ @@ -16,6 +16,8 @@ #include "../common/utils.h" #include "../ndarray/autograd.h" +#define IMPERATIVE_EXEC_DEBUG 0 + using namespace mxnet; using mxnet::autograd::AutogradRuntime; @@ -121,16 +123,18 @@ void SetContext(Context* p_ctx, ctx = Context::CPU(); } } - +// Set the shape, dtype and storage type void SetShapeType(const nnvm::Op* op, const nnvm::NodeAttrs& attrs, const Context& ctx, const std::vector& ndinputs, const int& infered_num_outputs, - std::vector* p_ndoutputs) { + std::vector* p_ndoutputs, + int* dispatch_stype) { std::vector& ndoutputs = *p_ndoutputs; static auto& infershape = nnvm::Op::GetAttr("FInferShape"); static auto& infertype = nnvm::Op::GetAttr("FInferType"); + static auto& inferstorage = nnvm::Op::GetAttr("FInferStorageType"); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); // infer shape std::vector& in_shapes = ret->arg_shapes; @@ -166,9 +170,41 @@ void SetShapeType(const nnvm::Op* op, CHECK(infertype[op](attrs, &in_types, &out_types)); CHECK_EQ(out_types.size(), static_cast(infered_num_outputs)); + // infer storage type + auto& in_storage_types = ret->arg_storage_types; + auto& out_storage_types = ret->out_storage_types; + in_storage_types.clear(); + out_storage_types.clear(); + + for (auto& i : ndinputs) { + in_storage_types.push_back(i.storage_type()); + } + for (auto& i : ndoutputs) { + out_storage_types.push_back(i.storage_type()); + } + if (inferstorage.count(op)) { + CHECK(inferstorage[op](attrs, &in_storage_types, &out_storage_types)); + CHECK_EQ(out_storage_types.size(), static_cast(infered_num_outputs)); + } else { +#if IMPERATIVE_EXEC_DEBUG + LOG(INFO) << "FInferStorageType not present."; +#endif + } + + bool contains_non_default = common::ContainsNonDefaultStorage(in_storage_types); + contains_non_default |= common::ContainsNonDefaultStorage(out_storage_types); + int kNonDefaultStorage = -2; + *dispatch_stype = contains_non_default ? kNonDefaultStorage : kDefaultStorage; + for (int i = 0; i < infered_num_outputs; ++i) { + NDArrayStorageType storage_type = static_cast(out_storage_types[i]); if (ndoutputs[i].is_none()) { - ndoutputs[i] = NDArray(out_shapes[i], ctx, true, out_types[i]); + // If failed to infer the storage type, assume the output storage is dense + if (storage_type == kDefaultStorage || out_storage_types[i] == kUndefinedStorage) { + ndoutputs[i] = NDArray(out_shapes[i], ctx, true, out_types[i]); + } else { + ndoutputs[i] = NDArray(storage_type, out_shapes[i], ctx, true, out_types[i]); + } } else { CHECK_EQ(ndoutputs[i].shape(), out_shapes[i]) << i << "th output has invalid shape. " @@ -215,23 +251,20 @@ void SetDependency(std::vector *p_read_vars, } CHECK_LE(ntmp, 1) << "Only support 1 temp space request"; } - - for (auto& i : ndinputs) { - read_vars.push_back(i.var()); - } - for (auto& i : ndoutputs) { - write_vars.push_back(i.var()); - } + for (auto& i : ndinputs) read_vars.emplace_back(i.var()); + for (auto& i : ndoutputs) write_vars.emplace_back(i.var()); if (mutate.count(op)) { auxidx = mutate[op](attrs); std::sort(auxidx.begin(), auxidx.end()); - for (auto & i : auxidx) { - write_vars.push_back(ndinputs[i].var()); + for (auto& i : auxidx) { + auto var = ndinputs[i].var(); + write_vars.push_back(var); } } Engine::Get()->DeduplicateVarHandle(&read_vars, &write_vars); } + void PushFCompute(const FCompute& fn, const nnvm::Op* op, const nnvm::NodeAttrs& attrs, @@ -247,15 +280,21 @@ void PushFCompute(const FCompute& fn, RunContext rctx, engine::CallbackOnComplete on_complete) { std::vector input_blobs, output_blobs; - for (auto& i : ndinputs) { - input_blobs.push_back(i.data()); - } - for (auto& i : ndoutputs) { - output_blobs.push_back(i.data()); - } + std::vector tmps; OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested}; + if (ctx.dev_mask() == gpu::kDevMask) { +#if MXNET_USE_CUDA + common::GetInputBlobs(ndinputs, &input_blobs, &tmps, opctx); + common::GetOutputBlobs(ndoutputs, &output_blobs); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } else { + common::GetInputBlobs(ndinputs, &input_blobs, &tmps, opctx); + common::GetOutputBlobs(ndoutputs, &output_blobs); + } std::vector req(output_blobs.size(), kWriteTo); fn(attrs, opctx, input_blobs, req, output_blobs); if (ctx.dev_mask() == gpu::kDevMask) { @@ -266,6 +305,33 @@ void PushFCompute(const FCompute& fn, 0, PROFILER_MESSAGE(op->name.c_str())); } +void PushFComputeEx(const FComputeEx& fn, + const nnvm::Op* op, + const nnvm::NodeAttrs& attrs, + const Context& ctx, + const std::vector& read_vars, + const std::vector& write_vars, + const std::vector& requested, + const std::vector& ndinputs, + const std::vector& ndoutputs) { + Engine::Get()->PushAsync( + [ctx, attrs, fn, ndinputs, ndoutputs, requested]( + RunContext rctx, + engine::CallbackOnComplete on_complete) { + std::vector input_blobs, output_blobs; + OpContext opctx{false, rctx, + engine::CallbackOnComplete(), + requested}; + std::vector req(ndoutputs.size(), kWriteTo); + fn(attrs, opctx, ndinputs, req, ndoutputs); + if (ctx.dev_mask() == gpu::kDevMask) { + rctx.get_stream()->Wait(); + } + on_complete(); + }, ctx, read_vars, write_vars, FnProperty::kNormal, + 0, PROFILER_MESSAGE(op->name.c_str())); +} + void PushOperator(std::shared_ptr opr, const nnvm::Op* op, const nnvm::NodeAttrs& attrs, @@ -329,8 +395,6 @@ int MXImperativeInvoke(AtomicSymbolCreator creator, int num_params, const char **param_keys, const char **param_vals) { - static auto& fcpu = nnvm::Op::GetAttr("FCompute"); - static auto& fgpu = nnvm::Op::GetAttr("FCompute"); static auto& ndfunc = nnvm::Op::GetAttr("FNDArrayFunction"); static auto& createop = nnvm::Op::GetAttr("FCreateLayerOp"); const nnvm::Op* op = static_cast(creator); @@ -344,20 +408,23 @@ int MXImperativeInvoke(AtomicSymbolCreator creator, int infered_num_outputs; int num_visible_outputs; - SetNumOutputs(op, attrs, num_inputs, - &infered_num_outputs, &num_visible_outputs); + SetNumOutputs(op, attrs, num_inputs, &infered_num_outputs, &num_visible_outputs); std::vector ndinputs, ndoutputs; SetNDInputsOutputs(op, &ndinputs, &ndoutputs, num_inputs, inputs, - num_outputs, infered_num_outputs, num_visible_outputs, outarray); + num_outputs, infered_num_outputs, num_visible_outputs, outarray); if (ndfunc.count(op)) { ndfunc[op](attrs, ndinputs, &ndoutputs); +#if IMPERATIVE_EXEC_DEBUG + LOG(INFO) << "NDArray function executed."; +#endif } else { // TODO(piiswrong): infer ctx Context ctx; + int storage_type; SetContext(&ctx, attrs, num_inputs, ndinputs, infered_num_outputs, ndoutputs); - SetShapeType(op, attrs, ctx, ndinputs, infered_num_outputs, &ndoutputs); + SetShapeType(op, attrs, ctx, ndinputs, infered_num_outputs, &ndoutputs, &storage_type); std::vector read_vars, write_vars; std::vector requested; @@ -365,20 +432,24 @@ int MXImperativeInvoke(AtomicSymbolCreator creator, SetDependency(&read_vars, &write_vars, &requested, &auxidx, op, attrs, ctx, ndinputs, ndoutputs); - FCompute fn; - if (ctx.dev_mask() == cpu::kDevMask && fcpu.count(op)) { - fn = fcpu[op]; - } else if (ctx.dev_mask() == gpu::kDevMask && fgpu.count(op)) { - fn = fgpu[op]; - } - - if (fn) { + FCompute fn = common::GetFCompute(op, ctx); + FComputeEx fcomp_ex = common::GetFComputeEx(op, ctx, storage_type); + if (fcomp_ex) { + PushFComputeEx(fcomp_ex, op, attrs, ctx, read_vars, write_vars, requested, + ndinputs, ndoutputs); +#if IMPERATIVE_EXEC_DEBUG + LOG(INFO) << "FComputeEx executed."; +#endif + } else if (fn) { if (AutogradRuntime::Get()->IsTraining()) { AutogradRuntime::Get()->RecordImperativeFCompute(op, attrs, &ndinputs, &ndoutputs); } PushFCompute(fn, op, attrs, ctx, read_vars, write_vars, requested, ndinputs, ndoutputs); +#if IMPERATIVE_EXEC_DEBUG + LOG(INFO) << "FCompute executed."; +#endif } else if (createop.count(op)) { std::shared_ptr opr( createop[op](attrs, ctx, ret->arg_shapes, ret->arg_types)); @@ -388,11 +459,14 @@ int MXImperativeInvoke(AtomicSymbolCreator creator, } PushOperator(opr, op, attrs, ctx, read_vars, write_vars, requested, auxidx, ndinputs, ndoutputs); +#if IMPERATIVE_EXEC_DEBUG + LOG(INFO) << "CreateOp executed."; +#endif } else { LOG(FATAL) << "Operator " << op->name << " cannot be run; requires at least one of" - << " FCompute, NDArrayFunction, FCreateOperator be registered"; + << " FCompute, FComputeEx NDArrayFunction, FCreateOperator be registered"; } } diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index f7281c999e6a..b6e1c30e7dd8 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -363,7 +363,6 @@ int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json) { API_END(); } - namespace mxnet { template @@ -497,6 +496,58 @@ int MXSymbolInferShapePartial(SymbolHandle sym, &succ); } +// TODO(haibin) refactor with infer_type +int MXSymbolInferStorageType(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const int *arg_storage_type_data, + mx_uint *in_storage_type_size, + const int **in_storage_type_data, + mx_uint *out_storage_type_size, + const int **out_storage_type_data, + mx_uint *aux_storage_type_size, + const int **aux_storage_type_data, + int *complete) { + nnvm::Symbol *s = static_cast(sym); + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + API_BEGIN(); + nnvm::Graph g = Symbol2Graph(*s); + nnvm::StorageTypeVector arg_storage_types(g.indexed_graph().input_nodes().size(), + kUndefinedStorage); + if (keys == nullptr && num_args != 0) { + std::vector read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); + CHECK_LE(num_args, read_only_args.size()); + for (mx_uint i = 0; i < num_args; ++i) { + arg_storage_types[read_only_args[i]] = arg_storage_type_data[i]; + } + } else { + std::unordered_map kwargs; + for (mx_uint i = 0; i < num_args; ++i) { + kwargs[keys[i]] = arg_storage_type_data[i]; + } + mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_storage_types, "InferStorageType"); + } + + g = nnvm::pass::InferStorageType(std::move(g), arg_storage_types, "__storage_type__"); + // copy back + CopyAttr(g.indexed_graph(), g.GetAttr("storage_type"), + &(ret->arg_storage_types), &(ret->out_storage_types), &(ret->aux_storage_types)); + + *in_storage_type_size = static_cast(ret->arg_storage_types.size()); + *in_storage_type_data = dmlc::BeginPtr(ret->arg_storage_types); + *out_storage_type_size = static_cast(ret->out_storage_types.size()); + *out_storage_type_data = dmlc::BeginPtr(ret->out_storage_types); + *in_storage_type_size = static_cast(ret->arg_storage_types.size()); + *in_storage_type_data = dmlc::BeginPtr(ret->arg_storage_types); + *out_storage_type_size = static_cast(ret->out_storage_types.size()); + *out_storage_type_data = dmlc::BeginPtr(ret->out_storage_types); + *aux_storage_type_size = static_cast(ret->aux_storage_types.size()); + *aux_storage_type_data = dmlc::BeginPtr(ret->aux_storage_types); + *complete = (g.GetAttr("storage_type_num_unknown_nodes") == 0); + API_END(); +} + + int MXSymbolInferType(SymbolHandle sym, mx_uint num_args, const char** keys, diff --git a/src/common/utils.h b/src/common/utils.h index 789b4d14b9f2..1687a0909839 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -18,11 +18,106 @@ #include #include +#include +#include +#include namespace mxnet { +// forward declaration +namespace op { +template +void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +} + namespace common { #if DMLC_USE_CXX11 +/* + * \brief Get input TBlobs from NDArrays, potentially performing cast_storage op and store + * temporary NDArrays in temps. If storage_fallback is false, + * MXNET_EXEC_STORAGE_FALLBACK env var determines whether storage type fallback is allowed. + */ +template +inline void GetInputBlobs(const std::vector& nds, + std::vector *blobs, + std::vector *temps, + const OpContext& ctx, + bool storage_fallback = false) { + if (storage_fallback == false) { + storage_fallback = dmlc::GetEnv("MXNET_EXEC_STORAGE_FALLBACK", true); + } + for (auto& nd : nds) { + if (nd.storage_type() != kDefaultStorage) { + if (storage_fallback == false) { + LOG(FATAL) << "Storage type conversion detected during execution. " + << "You are probably executing an operator which " + << "doesn't support NDArray inputs with non-default storage."; + } + NDArray temp(nd.shape(), nd.ctx(), false); + op::CastStorageComputeImpl(ctx.get_stream(), nd, temp); + temps->push_back(temp); + blobs->push_back(temp.data()); + } else { + blobs->push_back(nd.data()); + } + } +} + +template +inline void GetOutputBlobs(const std::vector& nds, + std::vector *blobs) { + for (auto& nd : nds) { + blobs->push_back(nd.data()); + } +} + +// Check if any storage type is not default storage +inline bool ContainsNonDefaultStorage(const nnvm::StorageTypeVector& vstorage) { + for (auto& i : vstorage) { + if (i != kUndefinedStorage && i != kDefaultStorage) return true; + } + return false; +} + +inline bool ContainsDefaultStorage(const std::vector& ndarrays) { + for (auto &nd : ndarrays) { + if (nd.storage_type() == kDefaultStorage) { + return true; + } + } + return false; +} + +inline FCompute GetFCompute(const Op* op, Context ctx) { + static auto& fcompute_cpu = nnvm::Op::GetAttr("FCompute"); + static auto& fcompute_gpu = nnvm::Op::GetAttr("FCompute"); + if (ctx.dev_mask() == cpu::kDevMask) { + return fcompute_cpu.get(op, nullptr); + } else if (ctx.dev_mask() == gpu::kDevMask) { + return fcompute_gpu.get(op, nullptr); + } + LOG(FATAL) << "Unknown device mask"; + return nullptr; +} + +inline FComputeEx GetFComputeEx(const Op* op, Context ctx, int stype) { + static auto& fcpu = nnvm::Op::GetAttr(FCOMP_EX_CPU); + static auto& fgpu = nnvm::Op::GetAttr(FCOMP_EX_GPU); + if (stype == kDefaultStorage) return nullptr; + if (ctx.dev_mask() == cpu::kDevMask) { + return fcpu.get(op, nullptr); + } else if (ctx.dev_mask() == gpu::kDevMask) { + return fgpu.get(op, nullptr); + } + LOG(FATAL) << "Unknown device mask"; + return nullptr; +} + + // heuristic to dermine number of threads per GPU inline int GetNumThreadPerGPU() { // This is resource efficient option. diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index 16b55adc15e8..27839760f7ea 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -8,11 +8,15 @@ #include #include #include "./exec_pass.h" +#include "../common/utils.h" #if MXNET_USE_MKL2017 == 1 #include #include "../operator/mkl/mkl_memory-inl.h" #include "../operator/mkl/mkl_util-inl.h" #endif + +#define EXEC_ATTACH_OP_DEBUG 0 + namespace mxnet { namespace op { @@ -24,8 +28,28 @@ namespace exec { // forward executor class ForwardOpExecutor : public OpExecutor { public: - void Run(RunContext rctx) override { + void Run(RunContext rctx, bool is_gpu) override { op_ctx.run_ctx = rctx; + + // TODO(haibin) ForwardOp is stateful. If any input ndarray has non-default storage, + // we need to cast it to default storage and setup the tblobs again. For example, + // if any of the input ndarray chagnes, the updated value won't be reflected in the temporary + // ndarray with default storage. This is not efficient and should be improved later. + in_data_.clear(); out_data_.clear(); aux_data_.clear(); tmps_.clear(); + if (is_gpu) { +#if MXNET_USE_CUDA + common::GetInputBlobs(in_array_, &in_data_, &tmps_, op_ctx); + common::GetInputBlobs(aux_array_, &aux_data_, &tmps_, op_ctx); + common::GetOutputBlobs(out_array, &out_data_); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } else { + common::GetInputBlobs(in_array_, &in_data_, &tmps_, op_ctx); + common::GetInputBlobs(aux_array_, &aux_data_, &tmps_, op_ctx); + common::GetOutputBlobs(out_array, &out_data_); + } + op_->Forward(op_ctx, in_data_, req, out_data_, aux_data_); #if MKL_EXPERIMENTAL == 1 mkl_tblobs_prv_to_cpu(in_data_); @@ -35,18 +59,14 @@ class ForwardOpExecutor : public OpExecutor { } void Setup() override { - in_data_.clear(); aux_data_.clear(); + // We need to tell whether in NDArray is input or aux for (size_t i = 0; i < in_array.size(); ++i) { if (!std::binary_search(aux_index_.begin(), aux_index_.end(), i)) { - in_data_.push_back(in_array[i].data()); + in_array_.emplace_back(in_array[i]); } else { - aux_data_.push_back(in_array[i].data()); + aux_array_.emplace_back(in_array[i]); } } - out_data_.resize(out_array.size()); - std::transform(out_array.begin(), out_array.end(), out_data_.begin(), [](const NDArray& nd) { - return nd.data(); - }); } Operator::ExecType exec_type() const override { return op_->exec_type(); @@ -62,12 +82,13 @@ class ForwardOpExecutor : public OpExecutor { std::shared_ptr op_; std::vector aux_index_; std::vector in_data_, out_data_, aux_data_; + std::vector in_array_, aux_array_, tmps_; }; // backward executor class BackwardOpExecutor : public OpExecutor { public: - void Run(RunContext rctx) override { + void Run(RunContext rctx, bool is_gpu) override { op_ctx.run_ctx = rctx; op_->Backward(op_ctx, out_grad_, in_data_, out_data_, req, in_grad_, aux_data_); @@ -135,23 +156,32 @@ class BackwardOpExecutor : public OpExecutor { // fcompute executor executor class FComputeExecutor : public OpExecutor { public: - void Run(RunContext rctx) override { + void Run(RunContext rctx, bool is_gpu) override { op_ctx.run_ctx = rctx; + // setup blobs + // TODO(haibin) we should avoid repeating this if it's known that all inputs are in + // default-storage. + { + in_data_.clear(); out_data_.clear(), tmp_nds_.clear(); + if (is_gpu) { +#if MXNET_USE_CUDA + common::GetInputBlobs(in_array, &in_data_, &tmp_nds_, op_ctx); + common::GetOutputBlobs(out_array, &out_data_); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } else { + common::GetInputBlobs(in_array, &in_data_, &tmp_nds_, op_ctx); + common::GetOutputBlobs(out_array, &out_data_); + } + } fcompute_(attrs_, op_ctx, in_data_, req, out_data_); #if MKL_EXPERIMENTAL == 1 mkl_tblobs_prv_to_cpu(in_data_); mkl_tblobs_prv_to_cpu(out_data_); #endif } - void Setup() override { - in_data_.resize(in_array.size()); - out_data_.resize(out_array.size()); - auto get_blob = [](const NDArray& nd) { - return nd.data(); - }; - std::transform(in_array.begin(), in_array.end(), in_data_.begin(), get_blob); - std::transform(out_array.begin(), out_array.end(), out_data_.begin(), get_blob); - } + void Setup() override {} Operator::ExecType exec_type() const override { return Operator::kSync; } @@ -159,28 +189,41 @@ class FComputeExecutor : public OpExecutor { : fcompute_(fcompute), attrs_(attrs) { } - static FCompute GetFCompute(const Op* op, Context ctx) { - static auto& fcompute_cpu = nnvm::Op::GetAttr("FCompute"); - static auto& fcompute_gpu = nnvm::Op::GetAttr("FCompute"); - if (ctx.dev_mask() == cpu::kDevMask) { - return fcompute_cpu.get(op, nullptr); - } else if (ctx.dev_mask() == gpu::kDevMask) { - return fcompute_gpu.get(op, nullptr); - } else { - LOG(FATAL) << "Unknown device mask"; - return nullptr; - } - } - private: FCompute fcompute_; NodeAttrs attrs_; std::vector in_data_, out_data_; + std::vector tmp_nds_; +}; + +// fcomputend executor +class FComputeExExecutor : public OpExecutor { + public: + void Run(RunContext rctx, bool is_gpu) override { + op_ctx.run_ctx = rctx; + fcompute_(attrs_, op_ctx, in_data_, req, out_data_); + } + void Setup() override { + in_data_ = in_array; + out_data_ = out_array; + } + Operator::ExecType exec_type() const override { + return Operator::kSync; + } + explicit FComputeExExecutor(FComputeEx fcompute, const NodeAttrs& attrs) + : fcompute_(fcompute), attrs_(attrs) { + } + + private: + FComputeEx fcompute_; + NodeAttrs attrs_; + std::vector in_data_, out_data_; }; // pass to attach operator executors Graph AttachOpExecs(Graph g) { using nnvm::DTypeVector; + using nnvm::StorageTypeVector; using nnvm::ShapeVector; using nnvm::FMutateInputs; @@ -193,6 +236,7 @@ Graph AttachOpExecs(Graph g) { const auto& vctx = g.GetAttr("context"); const auto& saved_opr = g.GetAttr< std::unordered_map>>("saved_opr"); + const auto& dispatch_stypes = g.GetAttr("dispatch_stypes"); // get the graph const auto& idx = g.indexed_graph(); @@ -206,7 +250,12 @@ Graph AttachOpExecs(Graph g) { if (fmutate_inputs.count(inode.source->op())) { mutate_index = fmutate_inputs[inode.source->op()](inode.source->attrs); } - FCompute fcompute = FComputeExecutor::GetFCompute(inode.source->op(), vctx[i]); + FCompute fcompute = common::GetFCompute(inode.source->op(), vctx[i]); + FComputeEx fcompute_ex = + common::GetFComputeEx(inode.source->op(), vctx[i], dispatch_stypes[i]); +#if EXEC_ATTACH_OP_DEBUG + LOG(INFO) << "dispatch storage type = " << dispatch_stypes[i]; +#endif if (fcreate_layer_op.count(inode.source->op())) { std::vector ishape; std::vector itype; @@ -222,19 +271,33 @@ Graph AttachOpExecs(Graph g) { inode.source->attrs, vctx[i], ishape, itype)); } ret[i] = std::make_shared(opr, mutate_index); +#if EXEC_ATTACH_OP_DEBUG + LOG(INFO) << "ForwardOp for op " << inode.source->op()->name; +#endif } else if (is_layer_backward.get(inode.source->op(), false)) { CHECK_GE(inode.control_deps.size(), 1); uint32_t fwd_id = inode.control_deps[0]; CHECK(vctx[fwd_id] == vctx[i]); CHECK(ret[fwd_id] != nullptr); + CHECK_EQ(dispatch_stypes[i], kDefaultStorage) + << "BackwardOp doesn't handle non-default storage yet"; ret[i] = std::make_shared( dynamic_cast(ret[fwd_id].get())->op_, mxnet::op::OpPropGetOpProperty(inode.source->attrs), mutate_index); +#if EXEC_ATTACH_OP_DEBUG + LOG(INFO) << "BackwardOp for op " << inode.source->op()->name; +#endif + } else if (fcompute_ex != nullptr) { +#if EXEC_ATTACH_OP_DEBUG + LOG(INFO) << "FComputeEx for op " << inode.source->op()->name; +#endif + ret[i] = std::make_shared(fcompute_ex, inode.source->attrs); } else if (fcompute != nullptr) { +#if EXEC_ATTACH_OP_DEBUG + LOG(INFO) << "FCompute for op " << inode.source->op()->name; +#endif ret[i] = std::make_shared(fcompute, inode.source->attrs); - } else { - LOG(INFO) << "FCompute not registered " << inode.source->op()->name; } } g.attrs["op_execs"] = std::make_shared(ret); diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 8df6a3c5d3bb..20535be320d9 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -19,6 +19,12 @@ namespace exec { /*! \brief reuse graph definition */ using nnvm::Graph; +const int kBadStorageID = -1; +const int kExternalStorageID = -2; +const int kDynamicStorageID = -3; + +const int kNonDefaultStorage = -2; + /*! * \brief executor to execute an operator * This is a graph executor dependent interface @@ -26,7 +32,7 @@ using nnvm::Graph; */ class OpExecutor { public: - /*! \brief input arrays */ + /*! \brief input data arrays, which may be either input or aux */ std::vector in_array; /*! \brief output data arrays */ std::vector out_array; @@ -47,7 +53,7 @@ class OpExecutor { * This function call do not synchronize the stream. * \param rctx The runtime context passed in by environment. */ - virtual void Run(RunContext rctx) = 0; + virtual void Run(RunContext rctx, bool is_gpu) = 0; /*! \return the execution type */ virtual Operator::ExecType exec_type() const = 0; }; diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 6ba0ff96b382..c07e86c49b3f 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -12,6 +12,7 @@ #include "./exec_pass.h" #include "./graph_executor.h" #include "../engine/profiler.h" +#include "../common/utils.h" namespace mxnet { namespace exec { @@ -29,6 +30,30 @@ GraphExecutor::~GraphExecutor() { } } +inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype) { + // NDArray with default storage + if (stype == kDefaultStorage) { + NDArray ret(shape, ctx, false, dtype); + ret = 0; + return ret; + } + // NDArray with non-default storage. Storage allocation is always delayed. + return NDArray(stype, shape, ctx, true, dtype); +} + +inline void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype, + std::vector *vec) { + // NDArray with default storage + if (stype == kDefaultStorage) { + vec->emplace_back(shape, ctx, false, dtype); + vec->back() = 0; + } else { + // NDArray with non-default storage. Storage allocation is always delayed. + vec->emplace_back(stype, shape, ctx, true, dtype); + } +} void GraphExecutor::Forward(bool is_train) { RunOps(is_train, 0, num_forward_nodes_); } @@ -78,6 +103,18 @@ const std::vector& GraphExecutor::outputs() const { return output_arrays_; } +const std::unordered_map& GraphExecutor::in_arg_map() const { + return in_arg_map_; +} + +const std::unordered_map& GraphExecutor::arg_grad_map() const { + return arg_grad_map_; +} + +const std::unordered_map& GraphExecutor::aux_state_map() const { + return aux_state_map_; +} + nnvm::NodeEntry AttrHint(nnvm::NodeEntry src, nnvm::NodeEntry like) { static const Op* id_like = Op::Get("_identity_with_attr_like_rhs"); nnvm::NodePtr n = nnvm::Node::Create(); @@ -178,10 +215,12 @@ inline ValueType get_node_attr( } } -nnvm::Graph GraphExecutor::InitFullGraph( - nnvm::Symbol symbol, - const std::vector& grad_req_type, - const std::vector& arg_grad_store) { +/*! + * \brief Create the graph for backward pass. + * This is triggered by both simple_bind and bind flows. + */ +nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, + const std::vector& grad_req_types) { using nnvm::NodePtr; using nnvm::NodeEntry; // initial information @@ -191,7 +230,7 @@ nnvm::Graph GraphExecutor::InitFullGraph( nnvm::Graph g; g.outputs = symbol.outputs; bool need_grad = false; - for (OpReqType req : grad_req_type) { + for (OpReqType req : grad_req_types) { if (req != kNullOp) need_grad = true; } if (!need_grad) return g; @@ -202,10 +241,8 @@ nnvm::Graph GraphExecutor::InitFullGraph( } std::vector args = symbol.ListInputs(nnvm::Symbol::kReadOnlyArgs); std::vector xs; - for (size_t i = 0; i < grad_req_type.size(); ++i) { - if (grad_req_type[i] != kNullOp) { - grad_store_.emplace_back( - std::make_pair(grad_req_type[i], arg_grad_store[i])); + for (size_t i = 0; i < grad_req_types.size(); ++i) { + if (grad_req_types[i] != kNullOp) { xs.emplace_back(NodeEntry{args[i], 0, 0}); } } @@ -241,13 +278,16 @@ nnvm::Graph GraphExecutor::InitFullGraph( return g; } -// pass to assign context to the graph +/*! + * \brief Assign context to the graph. + * This is triggered by both simple_bind and bind flows. + */ Graph AssignContext(Graph g, const Context& default_ctx, const std::map& ctx_map, - const std::vector& in_args, - const std::vector >& grad_store, - const std::vector& aux_states, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, size_t num_forward_inputs, size_t num_forward_outputs) { const auto& idx = g.indexed_graph(); @@ -256,56 +296,65 @@ Graph AssignContext(Graph g, if (ctx_map.size() == 0) { g.attrs["context"] = std::make_shared( ContextVector(idx.num_nodes(), default_ctx)); - for (const auto& x : in_args) { - CHECK(x.ctx() == default_ctx) - << "Input array is in " << x.ctx() << " while binding with ctx=" << default_ctx + for (const auto& x : in_arg_ctxes) { + CHECK(x == default_ctx) + << "Input array is in " << x << " while binding with ctx=" << default_ctx << ". All arguments must be in global context (" << default_ctx << ") unless group2ctx is specified for cross-device graph."; } - for (const auto& x : grad_store) { - CHECK(x.second.ctx() == default_ctx) - << "Gradient array is in " << x.second.ctx() << " while binding with ctx=" + for (const auto& x : arg_grad_ctxes) { + CHECK(x == default_ctx) + << "Gradient array is in " << x << " while binding with ctx=" << default_ctx << ". All gradients must be in global context (" << default_ctx << ") unless group2ctx is specified for cross-device graph."; } return g; } + // otherwise, use context assignment. - std::map ctx2id; - std::vector ctx_list; - nnvm::DeviceVector device(idx.num_nodes(), -1); - nnvm::DeviceAssignMap device_map; + std::map ctx2id; // map ctx to device id + std::vector ctx_list; // index is device id + nnvm::DeviceVector device(idx.num_nodes(), -1); // index is node id + nnvm::DeviceAssignMap device_map; // map arg name to device id + // loop through the user input ctx_map and + // populate maps and lists for (auto &kv : ctx_map) { - if (ctx2id.count(kv.second) == 0) { - ctx2id[kv.second] = static_cast(ctx_list.size()); - ctx_list.push_back(kv.second); + if (ctx2id.count(kv.second) == 0) { // if context has no device id, create one + ctx2id[kv.second] = static_cast(ctx_list.size()); // assign device id to ctx + ctx_list.push_back(kv.second); // save ctx to the list } + // assign device id to to the arg name with the corresponding ctx device_map[kv.first] = ctx2id.at(kv.second); } + // loop through all the rest of input nodes not specified + // in the ctx_map and populate maps and lists size_t arg_top = 0, aux_top = 0; for (size_t i = 0; i < num_forward_inputs; ++i) { const uint32_t nid = idx.input_nodes().at(i); Context ctx; - if (mutable_nodes.count(nid)) { - CHECK_LT(aux_top, aux_states.size()); - ctx = aux_states[aux_top].ctx(); + if (mutable_nodes.count(nid)) { // aux node is mutable + CHECK_LT(aux_top, aux_state_ctxes.size()); + ctx = aux_state_ctxes[aux_top]; ++aux_top; - } else { - CHECK_LT(arg_top, in_args.size()); - ctx = in_args[arg_top].ctx(); + } else { // regular input node is immutable + CHECK_LT(arg_top, in_arg_ctxes.size()); + ctx = in_arg_ctxes[arg_top]; ++arg_top; } - if (ctx2id.count(ctx) == 0) { - ctx2id[ctx] = static_cast(ctx_list.size()); - ctx_list.push_back(ctx); + if (ctx2id.count(ctx) == 0) { // if the current ctx is not in the map of ctx and device id + ctx2id[ctx] = static_cast(ctx_list.size()); // assign the current ctx with device id + ctx_list.push_back(ctx); // save the current ctx in the list } - device[nid] = ctx2id.at(ctx); + device[nid] = ctx2id.at(ctx); // assign device id to the current node } + + // loop through backward input nodes and populate maps and lists + // the backward input nodes is the gradient of the loss wrt the output for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i) { const uint32_t nid = idx.outputs()[i].node_id; - Context ctx = grad_store[i - num_forward_outputs].second.ctx(); + Context ctx = arg_grad_ctxes[i - num_forward_outputs]; if (ctx2id.count(ctx) == 0) { ctx2id[ctx] = static_cast(ctx_list.size()); ctx_list.push_back(ctx); @@ -317,6 +366,7 @@ Graph AssignContext(Graph g, device[nid] = devid; } } + g.attrs["device"] = std::make_shared(std::move(device)); g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy"); const auto& assigned_device = g.GetAttr("device"); @@ -333,27 +383,388 @@ Graph AssignContext(Graph g, return g; } +/*! + * \brief GraphExecutor initializer for regular bind flow in which + * input arguments and gradients are provided by users. This initializer + * uses the user provided NDArrays to populate data entries of the graph. + */ void GraphExecutor::Init(nnvm::Symbol symbol, const Context& default_ctx, const std::map& ctx_map, const std::vector& in_args, const std::vector& arg_grad_store, - const std::vector& grad_req_type, + const std::vector& grad_req_types, const std::vector& aux_states, Executor* shared_exec, const nnvm::NodeEntryMap& feed_dict) { - nnvm::Graph g = InitGraph(symbol, default_ctx, - ctx_map, in_args, arg_grad_store, - grad_req_type, aux_states, feed_dict); + // create in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes + auto get_ctx1 = [](const NDArray& nd) { return nd.ctx(); }; + auto get_ctx2 = [default_ctx](const NDArray& nd) -> Context { + if (nd.is_none()) return default_ctx; + return nd.ctx(); + }; + std::vector in_arg_ctxes(in_args.size()); + std::transform(in_args.begin(), in_args.end(), in_arg_ctxes.begin(), get_ctx1); + std::vector arg_grad_ctxes(arg_grad_store.size()); + std::transform(arg_grad_store.begin(), arg_grad_store.end(), arg_grad_ctxes.begin(), get_ctx2); + std::vector aux_state_ctxes(aux_states.size()); + std::transform(aux_states.begin(), aux_states.end(), aux_state_ctxes.begin(), get_ctx1); + + nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, + arg_grad_ctxes, aux_state_ctxes, grad_req_types); + + // create arg_shapes and arg_dtypes for shape and type inferences + const auto& idx = g.indexed_graph(); + auto mutable_nodes = idx.mutable_input_nodes(); + size_t arg_top = 0, aux_top = 0; + data_entry_.resize(idx.num_node_entries()); + nnvm::ShapeVector arg_shapes; + nnvm::DTypeVector arg_dtypes; + nnvm::StorageTypeVector arg_stypes; + for (size_t i = 0; i < num_forward_inputs_; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const std::string& arg_name = idx[nid].source->attrs.name; + size_t eid = idx.entry_id(nid, 0); + if (mutable_nodes.count(nid)) { + CHECK_LT(aux_top, aux_states.size()); + data_entry_[eid] = aux_states[aux_top]; + arg_shapes.push_back(aux_states[aux_top].shape()); + arg_dtypes.push_back(aux_states[aux_top].dtype()); + arg_stypes.push_back(aux_states[aux_top].storage_type()); + aux_state_map_.emplace(arg_name, aux_states[aux_top]); + ++aux_top; + } else { + CHECK_LT(arg_top, in_args.size()); + data_entry_[eid] = in_args[arg_top]; + arg_shapes.push_back(in_args[arg_top].shape()); + arg_dtypes.push_back(in_args[arg_top].dtype()); + arg_stypes.push_back(in_args[arg_top].storage_type()); + in_arg_map_.emplace(arg_name, in_args[arg_top]); + if (kNullOp != grad_req_types[arg_top]) { + grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_store[arg_top]); + arg_grad_map_.emplace(arg_name, arg_grad_store[arg_top]); + } + ++arg_top; + } +#if EXECUTOR_DEBUG + LOG(INFO) << "\tassign data entry\t" << eid << " as stype " + << data_entry_[eid].storage_type() << " (input)"; +#endif + } + + // expand arg_shapes and arg_dtypes to contain backward inputs + arg_shapes.resize(idx.input_nodes().size(), TShape()); + arg_dtypes.resize(idx.input_nodes().size(), -1); + arg_stypes.resize(idx.input_nodes().size(), kUndefinedStorage); + // Infer shapes and dtypes + g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); + g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); + g = nnvm::pass::InferStorageType(g, arg_stypes, "__storage_type__"); + + // Initialize the rest attributes of the graph. + // This function can be called by regular bind + // operation flow as well. + FinishInitGraph(symbol, g, shared_exec, feed_dict); +} + +/*! + * \brief Initialize in_args, arg_grads, and aux_states + * and their data_entry_ of the executor. This function + * is called for regular simple_bind flow, i.e. no + * shared data arrays are provided. + */ +void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes, + const nnvm::DTypeVector& inferred_dtypes, + const nnvm::StorageTypeVector& inferred_stypes, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec) { + // initialize in_args, arg_grads, and aux_states + // populate grad_store_ + data_entry_.resize(idx.num_node_entries()); + size_t arg_top = 0, aux_top = 0; + auto mutable_nodes = idx.mutable_input_nodes(); + for (size_t i = 0; i < num_forward_inputs_; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const TShape& inferred_shape = inferred_shapes[eid]; + const int inferred_dtype = inferred_dtypes[eid]; + const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid]; + const std::string& arg_name = idx[nid].source->attrs.name; + if (mutable_nodes.count(nid)) { // aux_states + EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top], + inferred_dtype, aux_state_vec); + data_entry_[eid] = aux_state_vec->back(); + aux_state_map_.emplace(arg_name, aux_state_vec->back()); + ++aux_top; +#if EXECUTOR_DEBUG + LOG(INFO) << "\tassign aux entry\t" << eid << "\t as stype " << inferred_stype; +#endif + } else { // in_args + EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top], + inferred_dtype, in_arg_vec); + data_entry_[eid] = in_arg_vec->back(); +#if EXECUTOR_DEBUG + LOG(INFO) << "\tassign data entry\t" << eid << "\tas stype " << inferred_stype; +#endif + // Get the storage type for grad + if (kNullOp == grad_req_types[arg_top]) { + arg_grad_vec->emplace_back(); + } else { + // Init based on storage type + auto grad_oid = grad_store_.size() + num_forward_outputs_; + auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); + auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; + EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top], + inferred_dtype, arg_grad_vec); +#if EXECUTOR_DEBUG + LOG(INFO) << "\tassign grad entry\t" << grad_eid << "\tas stype " << grad_stype; +#endif + grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); + arg_grad_map_.emplace(arg_name, arg_grad_vec->back()); + } + in_arg_map_.emplace(arg_name, in_arg_vec->back()); + ++arg_top; + } + } +} + + +/*! + * \brief If the requested ndarray's shape size is less than + * the corresponding shared_data_array's shape size and the + * storage type is default storage, reuse the memory allocation + * in shared_buffer; otherwise, create a zero ndarray. + */ +NDArray ReshapeOrCreate(const std::string& name, + const TShape& dest_arg_shape, + const int dest_arg_dtype, + const NDArrayStorageType dest_arg_stype, + const Context& ctx, + std::unordered_map* shared_buffer) { + if (dest_arg_dtype != kDefaultStorage) { + return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + } + auto it = shared_buffer->find(name); + if (it != shared_buffer->end()) { + if (it->second.shape().Size() >= dest_arg_shape.Size()) { // memory can be reused + CHECK_EQ(it->second.dtype(), dest_arg_dtype) + << "Requested arg array's dtype does not match the reusable ndarray"; + CHECK_EQ(it->second.storage_type(), kDefaultStorage) + << "shared_buffer should only contain NDArrays with default storage type."; + return it->second.Reshape(dest_arg_shape); + } else { + LOG(WARNING) << "Bucketing: data " << name << " has a shape " << dest_arg_shape + << ", which is larger than already allocated shape " << it->second.shape() + << ". Need to re-allocate. Consider putting default bucket key to be " + << "the bucket taking the largest input for better memory sharing."; + // the NDArrays in shared_buffer are guaranteed to be of default storage + it->second = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + return it->second; + } // arg_array.shape().Size() >= arg_shape.Size() + } else { + auto ret = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + shared_buffer->emplace(name, ret); + return ret; + } // if (it != shared_buffer->end()) +} + +/*! + * \brief Initialize in_args, arg_grads, and aux_states + * and their data_entry_ of the executor using + * shared_buffer from DataParallelExecutorGroup + * and shared_exec if available. + */ +void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes, + const nnvm::DTypeVector& inferred_dtypes, + const nnvm::StorageTypeVector& inferred_stypes, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + const Executor* shared_exec, + std::unordered_map* shared_buffer, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec) { + // initialize in_args, arg_grads, and aux_states and populate grad_store_ + data_entry_.resize(idx.num_node_entries()); + size_t arg_top = 0, aux_top = 0; + auto mutable_nodes = idx.mutable_input_nodes(); + const auto& shared_exec_in_args = shared_exec->in_arg_map(); + const auto& shared_exec_arg_grads = shared_exec->arg_grad_map(); + const auto& shared_exec_aux_states = shared_exec->aux_state_map(); + for (size_t i = 0; i < num_forward_inputs_; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const TShape& inferred_shape = inferred_shapes[eid]; + const int inferred_dtype = inferred_dtypes[eid]; + const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid]; + const std::string& arg_name = idx[nid].source->attrs.name; + // aux_states + if (mutable_nodes.count(nid)) { + if (nullptr != shared_exec && inferred_stype == kDefaultStorage && + shared_exec_aux_states.at(arg_name).storage_type() == kDefaultStorage) { + const NDArray& aux_nd = shared_exec_aux_states.at(arg_name); + CHECK_EQ(inferred_shape, aux_nd.shape()) + << "Inferred shape does not match shared_exec.aux_array's shape." + " Therefore, the allocated memory for shared_exec.aux_array cannot" + " be resued for creating auxilliary NDArray of the argument" + << arg_name << " for the current executor"; + CHECK_EQ(inferred_dtype, aux_nd.dtype()) + << "Inferred dtype does not match shared_exec.aux_array's dtype." + " Therefore, the allocated memory for shared_exec.aux_array cannot" + " be resued for creating auxilliary NDArray of the argument" + << arg_name << " for the current executor"; + aux_state_vec->emplace_back(aux_nd); + } else { + EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top], + inferred_dtype, aux_state_vec); + } // if (has_shared_exec) + data_entry_[eid] = aux_state_vec->back(); + aux_state_map_.emplace(arg_name, aux_state_vec->back()); + ++aux_top; + } else { // in_args and grad for in_args + if (shared_arg_names.count(arg_name)) { // model parameter + // model parameter + if (nullptr != shared_exec && inferred_stype == kDefaultStorage && + shared_exec_in_args.at(arg_name).storage_type() == kDefaultStorage) { + // try to reuse memory from shared_exec + const NDArray& in_arg_nd = shared_exec_in_args.at(arg_name); + CHECK_EQ(inferred_shape, in_arg_nd.shape()) + << "Inferred shape does not match shared_exec.arg_array's shape" + " Therefore, the allocated memory for shared_exec.arg_array cannot" + " be resued for creating NDArray of the argument" + << arg_name << " for the current executor"; + CHECK_EQ(inferred_dtype, in_arg_nd.dtype()) + << "Inferred dtype does not match shared_exec.arg_array's dtype" + " Therefore, the allocated memory for shared_exec.arg_array cannot" + " be resued for creating NDArray of the argument" + << arg_name << " for the current executor"; + in_arg_vec->emplace_back(in_arg_nd); + } else { + // doesn't have shared_exec, or non-default storage + EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top], + inferred_dtype, in_arg_vec); + } + // gradient for model parameter + if (kNullOp == grad_req_types[arg_top]) { + arg_grad_vec->emplace_back(); + } else { + auto grad_oid = grad_store_.size() + num_forward_outputs_; + auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); + auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; + if (nullptr != shared_exec && grad_stype == kDefaultStorage && + shared_exec_arg_grads.at(arg_name).storage_type() == kDefaultStorage) { + // try to reuse memory from shared_exec + arg_grad_vec->emplace_back(shared_exec_arg_grads.at(arg_name)); + } else { + EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top], + inferred_dtype, arg_grad_vec); + } + grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); + } + } else { // !shared_arg_names.count(arg_name) + // model parameter + in_arg_vec->emplace_back(ReshapeOrCreate(arg_name, inferred_shape, inferred_dtype, + inferred_stype, in_arg_ctxes[arg_top], + shared_buffer)); + // gradient for model parameter + if (kNullOp == grad_req_types[arg_top]) { + arg_grad_vec->emplace_back(); + } else { + auto grad_oid = grad_store_.size() + num_forward_outputs_; + auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); + auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; + arg_grad_vec->emplace_back(ReshapeOrCreate("grad of " + arg_name, inferred_shape, + inferred_dtype, grad_stype, + arg_grad_ctxes[arg_top], shared_buffer)); + grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); + } // if (kNullOp == grad_req_types[arg_top]) + } // if (shared_arg_names.count(arg_name)) + in_arg_map_.emplace(arg_name, in_arg_vec->back()); + if (!arg_grad_vec->back().is_none()) { + arg_grad_map_.emplace(arg_name, arg_grad_vec->back()); + } + data_entry_[eid] = in_arg_vec->back(); + ++arg_top; + } + } +} + +/*! + * \brief Finish graph initialization after shape and dtype inferences. + * This function is used by both simple_bind and bind flows. + */ +void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, + nnvm::Graph g, + Executor* shared_exec, + const nnvm::NodeEntryMap& feed_dict) { + const auto& idx = g.indexed_graph(); + // dispatch based on stype per operator + const auto& vstorage_type = g.GetAttr("storage_type"); + nnvm::StorageTypeVector dispatch_stypes(idx.num_nodes(), kUndefinedStorage); + for (size_t nid = 0; nid < idx.num_nodes(); nid++) { + const auto& inode = idx[nid]; + auto num_outputs = inode.source->num_outputs(); + auto num_inputs = inode.inputs.size(); + nnvm::StorageTypeVector vs(num_inputs + num_outputs, kUndefinedStorage); + for (size_t i = 0; i < num_inputs; i++) { + auto e = inode.inputs[i]; + vs[i] = vstorage_type[idx.entry_id(e)]; + CHECK_NE(vs[i], kUndefinedStorage); + } + for (uint32_t i = 0; i < num_outputs; ++i) { + uint32_t eid = idx.entry_id(nid, i); + vs[i + num_inputs] = vstorage_type[eid]; + } + bool contains_non_default = common::ContainsNonDefaultStorage(vs); + dispatch_stypes[nid] = contains_non_default ? kNonDefaultStorage : kDefaultStorage; + } + g.attrs["dispatch_stypes"] = std::make_shared(std::move(dispatch_stypes)); + + // data entries for output gradients + for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { + data_entry_[idx.entry_id(idx.outputs()[j])] = grad_store_[j - num_forward_outputs_].second; + } + + { + // memory allocator + nnvm::StorageVector arg_storage_id(idx.num_node_entries(), kBadStorageID); + for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { + arg_storage_id[idx.entry_id(idx.outputs()[j])] = kExternalStorageID; + } + for (const auto& kv : feed_dict) { + uint32_t eid = idx.entry_id(kv.first); + data_entry_[eid] = kv.second; + arg_storage_id[eid] = kExternalStorageID; + } + for (size_t i = 0; i < idx.num_node_entries(); i++) { + if (vstorage_type[i] != kDefaultStorage) arg_storage_id[i] = kDynamicStorageID; + } + g.attrs["storage"] = std::make_shared(std::move(arg_storage_id)); + g = nnvm::ApplyPass(g, "PlanMemory"); + } + g = DetectInplaceAddTo(g); + g.attrs["saved_opr"] = std::make_shared(std::move(saved_opr_)); g = AttachOpExecs(g); g = AttachOpResources(g); graph_ = std::move(g); + if (shared_exec != nullptr) { this->InitDataEntryMemory(&(dynamic_cast(shared_exec)->data_pool_)); } else { this->InitDataEntryMemory(nullptr); } + { // initialize output arrays auto& idx = graph_.indexed_graph(); @@ -373,22 +784,121 @@ void GraphExecutor::Init(nnvm::Symbol symbol, this->InitOpSegs(); } +/*! + * \brief GraphExecutor initializer for simple bind flow in + * which only certain input shapes and dtypes are provided by users. + * The initializer uses these shapes and dtypes to perform + * shape and dtype inferences, and then create NDArrays + * to populate data entries of the graph. The created NDArrays + * for in_args, arg_grads and aux_states are passed to the + * front end to attach the created executor. + * In front end, if the simple_bind flow is trigger by + * _bind_ith_exec, the shared data arrays of DataParallelExecutorGroup + * and shared executor will be taken into account in creating + * NDArrays for in_args, arg_grads, and aux_states for resuing + * already allocated memory. + */ +void GraphExecutor::Init(nnvm::Symbol symbol, + const Context& default_ctx, + const std::map& ctx_map, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::unordered_map& arg_shape_map, + const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec, + std::unordered_map* shared_buffer, + Executor* shared_exec, + const nnvm::NodeEntryMap& feed_dict) { + nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes, + aux_state_ctxes, grad_req_types); + // The following code of shape and dtype inferences and argument + // initialization is for simple_bind only. Regular bind operation + // should do this differently. + + // Initialize arg_shapes and arg_dtypes for shape and type inferences. + // It contains all in_args and aux_states' shapes and types in a certain order. + const nnvm::IndexedGraph& idx = g.indexed_graph(); + nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape()); + nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1); + nnvm::DTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); + for (size_t i = 0; i < num_forward_inputs_; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const std::string& name = idx[nid].source->attrs.name; + auto it1 = arg_shape_map.find(name); + if (arg_shape_map.end() != it1) { + arg_shapes[i] = it1->second; + } + auto it2 = arg_dtype_map.find(name); + if (arg_dtype_map.end() != it2) { + arg_dtypes[i] = it2->second; + } + auto it3 = arg_stype_map.find(name); + if (arg_stype_map.end() != it3) { + arg_stypes[i] = it3->second; + } + } + // TODO(jun/haibin) check if InferShape is successful, and give warnings instead of segfault later + g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); + g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); + g = nnvm::pass::InferStorageType(g, arg_stypes, "__storage_type__"); + + // Create in_args, arg_grads, and aux_states using + // the inferred shapes and dtypes. + if (nullptr == shared_buffer) { // regular simple bind + InitArguments(idx, g.GetAttr("shape"), + g.GetAttr("dtype"), + g.GetAttr("storage_type"), + in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, + grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec); + } else { // simple bind using shared data arrays and shared_exec + InitArguments(idx, g.GetAttr("shape"), + g.GetAttr("dtype"), + g.GetAttr("storage_type"), + in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, + grad_req_types, shared_arg_names, shared_exec, + shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec); + } + // The above code of shape and dtype inferences and argument + // initialization is for simple_bind only. Regular bind operation + // should do this differently. + + // Initialize the rest attributes of the graph. + // This function can be called by regular bind + // operation flow as well. + FinishInitGraph(symbol, g, shared_exec, feed_dict); +} + +/*! + * \brief This function is triggered by both simple_bind + * and bind flows. + * Setup backward graph, create device and context + * attributes in the graph, and calculate the number + * of forward nodes. + */ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, const Context& default_ctx, const std::map& ctx_map, - const std::vector& in_args, - const std::vector& arg_grad_store, - const std::vector& grad_req_type, - const std::vector& aux_states, - const nnvm::NodeEntryMap& feed_dict) { + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types) { // setup gradient - nnvm::Graph g = InitFullGraph(symbol, grad_req_type, arg_grad_store); + nnvm::Graph g = InitFullGraph(symbol, grad_req_types); + + // create "device" and "context" attrs for the graph g = AssignContext(g, default_ctx, ctx_map, - in_args, - grad_store_, - aux_states, + in_arg_ctxes, + arg_grad_ctxes, + aux_state_ctxes, num_forward_inputs_, num_forward_outputs_); + const auto& idx = g.indexed_graph(); // get number of nodes used in forward pass num_forward_nodes_ = 0; @@ -396,61 +906,13 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, num_forward_nodes_ = std::max( num_forward_nodes_, static_cast(idx.outputs()[i].node_id + 1)); } - // Setup data entry, shape and type. - data_entry_.resize(idx.num_node_entries()); - auto mutable_nodes = idx.mutable_input_nodes(); - nnvm::ShapeVector arg_shapes; - nnvm::DTypeVector arg_types; - size_t arg_top = 0, aux_top = 0; - for (size_t i = 0; i < num_forward_inputs_; ++i) { - const uint32_t nid = idx.input_nodes().at(i); - if (mutable_nodes.count(nid)) { - CHECK_LT(aux_top, aux_states.size()); - data_entry_[idx.entry_id(nid, 0)] = aux_states[aux_top]; - arg_shapes.push_back(aux_states[aux_top].shape()); - arg_types.push_back(aux_states[aux_top].dtype()); - ++aux_top; - } else { - CHECK_LT(arg_top, in_args.size()); - data_entry_[idx.entry_id(nid, 0)] = in_args[arg_top]; - arg_shapes.push_back(in_args[arg_top].shape()); - arg_types.push_back(in_args[arg_top].dtype()); - ++arg_top; - } - } - for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { - data_entry_[idx.entry_id(idx.outputs()[j])] - = grad_store_[j - num_forward_outputs_].second; - } - arg_shapes.resize(idx.input_nodes().size(), TShape()); - arg_types.resize(idx.input_nodes().size(), -1); - // other initializations - g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); - g = nnvm::pass::InferType(g, arg_types, "__dtype__"); - - { - // memory allocator - const int kBadStorageID = -1; - const int kExternalStorageID = -2; - nnvm::StorageVector arg_storage_id(idx.num_node_entries(), kBadStorageID); - for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { - arg_storage_id[idx.entry_id(idx.outputs()[j])] = kExternalStorageID; - } - for (const auto& kv : feed_dict) { - uint32_t eid = idx.entry_id(kv.first); - data_entry_[eid] = kv.second; - arg_storage_id[eid] = kExternalStorageID; - } - g.attrs["storage"] = std::make_shared(std::move(arg_storage_id)); - g = nnvm::ApplyPass(g, "PlanMemory"); - } - g = DetectInplaceAddTo(g); return g; } // initialize the memory of each entries void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { using nnvm::DTypeVector; + using nnvm::StorageTypeVector; using nnvm::ShapeVector; using nnvm::StorageVector; // get the graph @@ -459,20 +921,29 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { const auto& vdtype = graph_.GetAttr("dtype"); const auto& vshape = graph_.GetAttr("shape"); const auto& vstorage = graph_.GetAttr("storage_id"); + const auto& vstorage_type = graph_.GetAttr("storage_type"); const auto& vctx = graph_.GetAttr("context"); CHECK_EQ(idx.num_node_entries(), vshape.size()); CHECK_EQ(idx.num_node_entries(), vdtype.size()); CHECK_EQ(idx.num_node_entries(), vstorage.size()); CHECK_EQ(data_entry_.size(), vshape.size()); std::vector data_context(idx.num_node_entries()); + std::vector data_storage_type(idx.num_node_entries(), kUndefinedStorage); for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (uint32_t i = 0; i < idx[nid].source->num_outputs(); ++i) { - data_context[idx.entry_id(nid, i)] = vctx[nid]; + auto eid = idx.entry_id(nid, i); + data_context[eid] = vctx[nid]; + CHECK_NE(vstorage_type[nid], kUndefinedStorage); + data_storage_type[eid] = (NDArrayStorageType) vstorage_type[nid]; } } // information about the pool - using PoolEntry = std::pair; + struct PoolEntry { + Context ctx; + size_t bytes; + NDArrayStorageType stype; + }; std::vector pool_info; // assign array to head gradient @@ -480,26 +951,36 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { uint32_t nid = idx.input_nodes().at(i); uint32_t oid = head_grad_map_.at(idx[nid].source); uint32_t eid = idx.entry_id(idx.outputs()[oid]); + NDArrayStorageType stype = (NDArrayStorageType) vstorage_type[eid]; CHECK_NE(vshape[eid].ndim(), 0U); CHECK_NE(vdtype[eid], -1); - data_entry_[idx.entry_id(nid, 0)] = - NDArray(vshape[eid], data_context[eid], false, vdtype[eid]); + auto data_eid = idx.entry_id(nid, 0); + // initialize based on storage_type + if (stype != kDefaultStorage) { + data_entry_[data_eid] = NDArray(stype, vshape[eid], data_context[eid], true, vdtype[eid]); + } else { + data_entry_[data_eid] = NDArray(vshape[eid], data_context[eid], false, vdtype[eid]); + } +#if EXECUTOR_DEBUG + LOG(INFO) << "\tinit head_g entry\t" << data_eid << "\tas stype " << stype; +#endif } // get maximum bytes in each pool for (size_t i = 0; i < vshape.size(); ++i) { if (!data_entry_[i].is_none()) continue; size_t bytes = vshape[i].Size() * mshadow::mshadow_sizeof(vdtype[i]); int storage_id = vstorage[i]; + // skip pool allocation for kBadStorageID, kExternalStorageID and kDynamicStorageID if (storage_id < 0) continue; size_t sid = static_cast(storage_id); if (sid >= pool_info.size()) { - pool_info.resize(sid + 1, PoolEntry{Context::CPU(), size_t(0)}); + pool_info.resize(sid + 1, PoolEntry{Context::CPU(), size_t(0), kUndefinedStorage}); } PoolEntry& info = pool_info[sid]; - if (info.second == 0) { - info = PoolEntry{data_context[i], bytes}; + if (info.bytes == 0) { + info = PoolEntry{data_context[i], bytes, data_storage_type[i]}; } else { - info.second = std::max(info.second, bytes); + info.bytes = std::max(info.bytes, bytes); } } // construct the re-use pool, if needed @@ -520,13 +1001,14 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { sorted_pool_index.push_back(i); } auto pool_comparator = [&pool_info](int lhs, int rhs){ - return pool_info[lhs].second > pool_info[rhs].second; + return pool_info[lhs].bytes > pool_info[rhs].bytes; }; std::sort(sorted_pool_index.begin(), sorted_pool_index.end(), pool_comparator); for (size_t i : sorted_pool_index) { - const Context& ctx = pool_info[i].first; - size_t bytes = pool_info[i].second; + const Context& ctx = pool_info[i].ctx; + size_t bytes = pool_info[i].bytes; + NDArrayStorageType storage_type = pool_info[i].stype; bool allocated = false; for (auto it = free_pool.lower_bound(bytes); it != free_pool.end(); ++it) { if (it->second.ctx() == ctx && it->first >= bytes) { @@ -551,15 +1033,22 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { } CHECK_EQ(data_pool_.size(), pool_info.size()); // assign the data entries - for (size_t i = 0; i < data_entry_.size(); ++i) { // avoid pre-allocated arrays if (!data_entry_[i].is_none()) continue; // assign allocated array by storage id int storage_id = vstorage[i]; - CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet"; - const NDArray& src = data_pool_.at(storage_id); - data_entry_[i] = src.AsArray(vshape[i], vdtype[i]); + auto storage_type = (NDArrayStorageType) vstorage_type[i]; + if (storage_type == kDefaultStorage) { + CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet"; + const NDArray& src = data_pool_.at(storage_id); + data_entry_[i] = src.AsArray(vshape[i], vdtype[i]); + } else { + data_entry_[i] = NDArray(storage_type, vshape[i], data_context[i]); + } +#if EXECUTOR_DEBUG + LOG(INFO) << "\tinit data entry\t" << i << "\tas stype " << storage_type; +#endif } } @@ -574,11 +1063,28 @@ void GraphExecutor::InitCachedOps() { const auto& vctx = graph_.GetAttr("context"); const auto& addto_entry = graph_.GetAttr >("addto_entry"); const auto& skip_plus_node = graph_.GetAttr >("skip_plus_node"); + const auto& vstorage_type = graph_.GetAttr("storage_type"); op_nodes_.resize(idx.num_nodes()); // setup the array and requirements. for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; +#if EXECUTOR_DEBUG + if (inode.source->is_variable()) { + LOG(INFO) << "node " << nid << " var"; + } else { + LOG(INFO) << "node " << nid << " " << inode.source->attrs.op->name; + auto exec = op_execs[nid]; + for (const auto& e : inode.inputs) { + auto eid = idx.entry_id(e); + LOG(INFO) << "\t\tinput " << eid << " stype: " << vstorage_type[eid]; + } + for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { + uint32_t eid = idx.entry_id(nid, index); + LOG(INFO) << "\t\toutput " << eid << " stype: " << vstorage_type[eid]; + } + } +#endif if (inode.source->is_variable()) continue; #if MXNET_USE_PROFILER op_nodes_[nid].opr_name = inode.source->op()->name.c_str(); @@ -655,7 +1161,7 @@ void GraphExecutor::InitCachedOps() { if (is_async) { exec->op_ctx.async_on_complete = on_complete; } - exec->Run(ctx); + exec->Run(ctx, is_gpu); // call on complete only if it is async op if (!is_async) { if (is_gpu) { @@ -800,6 +1306,9 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning; #else bool profiling = false; +#endif +#if EXECUTOR_DEBUG + LOG(INFO) << "Run node " << nid << " - " << seg_op.topo_end - 1; #endif Engine::Get()->Push(seg_op.opr, seg_op.ctx, 0, profiling); nid = seg_op.topo_end - 1; @@ -812,6 +1321,9 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { if (op_nodes_[nid].skip_exec_node) continue; opnode.exec->op_ctx.is_train = is_train; if (opnode.exec->exec_type() == Operator::kCrossDeviceCopy) { +#if EXECUTOR_DEBUG + LOG(INFO) << "Run node " << nid << " for CrossDeviceCopy"; +#endif CHECK_EQ(inode.inputs.size(), 1U); CHECK_EQ(opnode.exec->in_array.size(), 1U); CHECK_EQ(opnode.exec->out_array.size(), 1U); @@ -821,6 +1333,9 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning; #else bool profiling = false; +#endif +#if EXECUTOR_DEBUG + LOG(INFO) << "Run node " << nid; #endif Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling); } else { @@ -885,7 +1400,7 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, RunContext ctx, Engine::CallbackOnComplete on_complete) { // Run all opr in the sub-graph for (auto &exec : exec_list) { - exec->Run(ctx); + exec->Run(ctx, is_gpu); } if (is_gpu) { #if MXNET_USE_CUDA @@ -912,6 +1427,32 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, } } // namespace exec +Executor *Executor::SimpleBind(nnvm::Symbol symbol, + const Context& default_ctx, + const std::map& group2ctx, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::unordered_map& arg_shape_map, + const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + std::vector* in_args, + std::vector* arg_grads, + std::vector* aux_states, + std::unordered_map* shared_buffer, + Executor* shared_exec) { + auto exec = new exec::GraphExecutor(); + exec->Init(symbol, default_ctx, group2ctx, + in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, + arg_shape_map, arg_dtype_map, arg_stype_map, + grad_req_types, shared_arg_names, + in_args, arg_grads, aux_states, + shared_buffer, shared_exec); + return exec; +} + Executor *Executor::Bind(nnvm::Symbol symbol, const Context& default_ctx, const std::map& group2ctx, diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index d9c3a3e6aa47..308eddba8b80 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -19,6 +19,8 @@ #include #include "./exec_pass.h" +#define EXECUTOR_DEBUG 0 + namespace mxnet { using NodeOperatorMap = std::unordered_map &head_grads) override; const std::vector& outputs() const override; + const std::unordered_map& in_arg_map() const override; + const std::unordered_map& arg_grad_map() const override; + const std::unordered_map& aux_state_map() const override; void Print(std::ostream &os) const override; // NOLINT(*) void SetMonitorCallback(const MonitorCallback& callback) override; - // initialized the executor + // Initialize the rest of attributes + // after setting up arguments. + void FinishInitGraph(nnvm::Symbol symbol, nnvm::Graph g, + Executor* shared_exec = nullptr, + const nnvm::NodeEntryMap& feed_dict + = nnvm::NodeEntryMap()); + + // initialize executor for bind void Init(nnvm::Symbol symbol, const Context& default_ctx, const std::map& ctx_map, const std::vector& in_args, const std::vector& arg_grad_store, - const std::vector& grad_req_type, + const std::vector& grad_req_types, const std::vector& aux_states, Executor* shared_exec = nullptr, const nnvm::NodeEntryMap& feed_dict = nnvm::NodeEntryMap()); + // initialize executor for simple bind + void Init(nnvm::Symbol symbol, + const Context& default_ctx, + const std::map& ctx_map, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::unordered_map& arg_shape_map, + const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec, + std::unordered_map* shared_buffer = nullptr, + Executor* shared_exec = nullptr, + const nnvm::NodeEntryMap& feed_dict + = nnvm::NodeEntryMap()); protected: // Information about operational node @@ -94,21 +125,45 @@ class GraphExecutor : public Executor { // list of op executors std::vector exec_list; }; - - // internal initialization of the graph. + // Initialize in_args, arg_grads, and aux_states + void InitArguments(const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes, + const nnvm::DTypeVector& inferred_dtypes, + const nnvm::StorageTypeVector& inferred_stypes, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec); + // Initialize in_args, arg_grads and aux_states with + // shared_buffer and shared_exec + void InitArguments(const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes, + const nnvm::DTypeVector& inferred_dtypes, + const nnvm::StorageTypeVector& inferred_stypes, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + const Executor* shared_exec, + std::unordered_map* shared_buffer, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec); + // internal initialization of the graph for simple bind Graph InitGraph(nnvm::Symbol symbol, const Context& default_ctx, const std::map& ctx_map, - const std::vector& in_args, - const std::vector& arg_grad_store, - const std::vector& grad_req_type, - const std::vector& aux_states, - const nnvm::NodeEntryMap& feed_dict - = nnvm::NodeEntryMap()); - // initialize the full graph, including gradient. + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types); + // intialize the full graph for simple bind, including gradient Graph InitFullGraph(nnvm::Symbol symbol, - const std::vector& grad_req_type, - const std::vector& arg_grad_store); + const std::vector& grad_req_types); // initialize the cached operator void InitCachedOps(); // initialize the opr segments for bulk exec @@ -136,10 +191,17 @@ class GraphExecutor : public Executor { std::vector op_nodes_; // internal data entry of each node std::vector data_entry_; - // internal data pool of allocated entries + // internal data pool of allocated entries. + // these allocated entries can be used for static memory sharing between executors. std::vector data_pool_; // output arrays std::vector output_arrays_; + // input argument map, key is arg name, value is arg's NDArray + std::unordered_map in_arg_map_; + // arg grad map, key is arg name, value is arg grad NDArray + std::unordered_map arg_grad_map_; + // aux state map, key is aux state name, value is aux state NDArray + std::unordered_map aux_state_map_; // gradient store std::vector > grad_store_; // array to hold head gradient. diff --git a/src/executor/inplace_addto_detect_pass.cc b/src/executor/inplace_addto_detect_pass.cc index 75a2608313aa..1a0bc9cb40a6 100644 --- a/src/executor/inplace_addto_detect_pass.cc +++ b/src/executor/inplace_addto_detect_pass.cc @@ -44,6 +44,8 @@ Graph DetectInplaceAddTo(Graph g) { uint32_t eid_rhs = idx.entry_id(inode.inputs[1]); if (ref_count[eid_rhs] != 1) continue; if (inode.inputs[0].node_id >= inode.inputs[1].node_id) continue; + // TODO(haibin) support inplace addto for Dynamic Storage + if (storage_id[eid_rhs] == kDynamicStorageID) continue; CHECK_NE(storage_id[eid_rhs], sid); storage_id[eid_rhs] = sid; addto_entry[eid_rhs] = 1; diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index c19a82b164c4..f692a5700ba5 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -11,6 +11,7 @@ #include #include #include "./ndarray_function.h" +#include "../operator/tensor/matrix_op-inl.h" #include "./autograd.h" #if MXNET_USE_OPENCV @@ -27,6 +28,7 @@ NDArray NDArray::Reshape(const TShape &shape) const { using namespace autograd; CHECK_GE(shape_.Size(), shape.Size()) << "NDArray.Reshape: target shape size is different from current shape"; + CHECK(storage_type() == kDefaultStorage) << "Not implemented yet"; NDArray ret = *this; ret.shape_ = shape; if (AutogradRuntime::Get()->IsTraining()) { @@ -50,12 +52,14 @@ NDArray NDArray::Reshape(const TShape &shape) const { } } - NDArray NDArray::Slice(index_t begin, index_t end) const { using namespace autograd; + using namespace mshadow; NDArray ret = *this; CHECK(!is_none()) << "NDArray is not initialized"; CHECK_GE(shape_[0], end) << "Slice end index out of range"; + auto stype = storage_type(); + CHECK_EQ(stype, kDefaultStorage); size_t length = shape_.ProdShape(1, shape_.ndim()); ret.offset_ += begin * length; ret.shape_[0] = end - begin; @@ -80,8 +84,69 @@ NDArray NDArray::Slice(index_t begin, index_t end) const { } } +void NDArray::SliceEx(index_t begin, index_t end, NDArray *ret) const { + using namespace autograd; + using namespace mshadow; + CHECK(!is_none()) << "NDArray is not initialized"; + CHECK_GE(shape_[0], end) << "Slice end index out of range"; + auto stype = storage_type(); + CHECK_NE(stype, kDefaultStorage); + if (stype == kCSRStorage) { + using namespace csr; + ret->shape_[0] = end - begin; + NDArray src = *this; + // destination NDArray shares the same variable + ret->ptr_->var = var(); + Engine::Get()->PushSync([src, ret, begin, end](RunContext ctx) { + NDArray dst = *ret; + // create a new chunk for dst NDArray + NDArray::Chunk chunk = *src.ptr_; + // void indptr storage handle + chunk.aux_handles[kIndPtr] = Storage::Handle(); + // shape for indptr is end - begin + 1 + chunk.CheckAndAllocAuxData(kIndPtr, Shape1(end - begin + 1)); + if (src.ctx().dev_mask() == cpu::kDevMask) { + MSHADOW_INT_TYPE_SWITCH(src.aux_type(kIndPtr), IType, { + MSHADOW_TYPE_SWITCH(src.dtype(), DType, { + // create new indptr + const IType* src_indptr = src.aux_data(kIndPtr).dptr(); + IType* dst_indptr = static_cast (chunk.aux_handles[kIndPtr].dptr); + op::SliceCsrIndPtrImpl(begin, end, ctx, src_indptr, dst_indptr); + // advance idx and values pointers (CPU implementation) + // TODO(haibin) refactor for GPU implementation later + IType offset = src_indptr[begin]; + IType* idx = static_cast(chunk.aux_handles[kIdx].dptr); + DType* values = static_cast(chunk.shandle.dptr); + chunk.aux_handles[kIdx].dptr = idx + offset; + chunk.shandle.dptr = values + offset; + // update storage shape and aux shape (CPU implementation) + auto nnz = dst_indptr[end - begin]; + chunk.aux_shapes[kIdx] = Shape1(nnz); + chunk.storage_shape = Shape1(nnz); + chunk.static_data = true; + chunk.skip_delete_var = true; + // update dst chunk + *dst.ptr_ = chunk; + }); + }); + } else { +#if MXNET_USE_CUDA + LOG(FATAL) << "SliceEx CSR not implemented yet"; +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } + }, ctx(), {}, {var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + } else { + LOG(FATAL) << "Slice not yet implemented for storage " << stype; + } + // TODO(haibin) support auto_grad for SliceEx +} NDArray NDArray::At(index_t idx) const { + CHECK(storage_type() == kDefaultStorage) << "Storage type " + << storage_type() << " doesn't support At()"; NDArray ret = this->Slice(idx, idx+1); if (shape_.ndim() > 1) { return ret.Reshape(TShape(shape_.data()+1, shape_.data()+shape_.ndim())); @@ -190,11 +255,11 @@ void BinaryOp(const NDArray &lhs, // redirect everything to mshadow operations switch (lhs.ctx().dev_mask()) { case cpu::kDevMask: { - Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); - }, lhs.ctx(), const_vars, {ret.var()}, - FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { + TBlob tmp = ret.data(); + ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); + }, lhs.ctx(), const_vars, {ret.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); break; } #if MXNET_USE_CUDA @@ -220,6 +285,7 @@ void SetValueOp(const real_t &rhs, NDArray *out) { switch (ret.ctx().dev_mask()) { case cpu::kDevMask: { Engine::Get()->PushSync([rhs, ret](RunContext ctx) { + CHECK(ret.storage_type() == kDefaultStorage); TBlob tmp = ret.data(); ndarray::Eval(rhs, &tmp, ctx); }, ret.ctx(), {}, {ret.var()}, @@ -291,6 +357,7 @@ void ScalarOp(const NDArray &lhs, } } + void CopyFromTo(const NDArray &from, NDArray *to, int priority) { if (from.var() == to->var()) { // skip to copy to itself @@ -305,44 +372,33 @@ void CopyFromTo(const NDArray &from, NDArray *to, int priority) { NDArray ret = *to; int a = from.ctx().dev_mask(); int b = to->ctx().dev_mask(); - std::vector const_vars; if (from.var() != ret.var()) const_vars.push_back(from.var()); if (a == cpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); + NDArray nd(ret); + CopyFromToImpl(from, &nd, ctx); }, from.ctx(), const_vars, {ret.var()}, FnProperty::kNormal, priority, PROFILER_MESSAGE("CopyCPU2CPU")); } else { #if MXNET_USE_CUDA if (a == cpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); + NDArray nd(ret); + CopyFromToImpl(from, &nd, ctx); }, ret.ctx(), const_vars, {ret.var()}, FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE("CopyCPU2GPU")); } else if (a == gpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); + NDArray nd(ret); + CopyFromToImpl(from, &nd, ctx); }, from.ctx(), const_vars, {ret.var()}, FnProperty::kCopyFromGPU, priority, PROFILER_MESSAGE("CopyGPU2CPU")); } else if (a == gpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); + NDArray nd(ret); + CopyFromToImpl(from, &nd, ctx); }, from.ctx(), const_vars, {ret.var()}, from.dtype() != ret.dtype() ? FnProperty::kNormal : FnProperty::kCopyFromGPU, priority, PROFILER_MESSAGE("CopyGPU2GPU")); diff --git a/src/ndarray/ndarray_function-inl.h b/src/ndarray/ndarray_function-inl.h index 28524b73d0dd..aad80fd4360a 100644 --- a/src/ndarray/ndarray_function-inl.h +++ b/src/ndarray/ndarray_function-inl.h @@ -12,27 +12,28 @@ // macro to help specialize evaluation function #ifndef DECL_TERNARY -#define DECL_TERNARY(XPU, OP, FUN) \ - template<> \ - void Eval(const TBlob &lhs, const TBlob &mhs, \ - const TBlob &rhs, TBlob *ret, RunContext ctx) { \ - FUN(lhs, mhs, rhs, ret, ctx); \ +#define DECL_TERNARY(XPU, OP, FUN) \ + template<> \ + void Eval(const TBlob &lhs, const TBlob &mhs, \ + const TBlob &rhs, TBlob *ret, RunContext ctx) { \ + FUN(lhs, mhs, rhs, ret, ctx); \ } #endif #ifndef DECL_BINARY -#define DECL_BINARY(XPU, OP, FUN) \ - template<> \ +#define DECL_BINARY(XPU, OP, FUN) \ + template<> \ void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx) { \ - FUN(lhs, rhs, ret, ctx); \ + FUN(lhs, rhs, ret, ctx); \ } #endif #ifndef DECL_SCALAR -#define DECL_SCALAR(XPU, OP, FUN, REVERSE) \ - template<> \ - void Eval(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx) { \ - FUN(lhs, rhs, ret, ctx); \ +#define DECL_SCALAR(XPU, OP, FUN, REVERSE) \ + template<> \ + void Eval(const TBlob &lhs, const real_t &rhs, \ + TBlob *ret, RunContext ctx) { \ + FUN(lhs, rhs, ret, ctx); \ } #endif @@ -44,10 +45,11 @@ namespace mxnet { namespace ndarray { + // true implementation template -inline void EvalBinary_(const TBlob &lhs, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +void EvalBinary_(const TBlob &lhs, const TBlob &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); CHECK_EQ(ret->type_flag_, lhs.type_flag_) @@ -61,10 +63,9 @@ inline void EvalBinary_(const TBlob &lhs, const TBlob &rhs, }); } - template -inline void EvalOneHot_(const TBlob &index, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +void EvalOneHot_(const TBlob &index, const TBlob &rhs, + TBlob *ret, RunContext ctx) { LOG(INFO) << "The operator onehot_encode is deprecated; use one_hot instead."; using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); @@ -81,8 +82,8 @@ inline void EvalOneHot_(const TBlob &index, const TBlob &rhs, } template -inline void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); // TODO(eric): support mixed type choose, i.e. int index and float rhs. @@ -98,8 +99,8 @@ inline void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs, } template -inline void EvalMatFillRowElem_(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +void EvalMatFillRowElem_(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); ret->get(s) @@ -109,8 +110,8 @@ inline void EvalMatFillRowElem_(const TBlob &lhs, const TBlob &mhs, const TBlob } template -inline void EvalScalar_(const TBlob &lhs, const real_t &rhs, - TBlob *ret, RunContext ctx) { +void EvalScalar_(const TBlob &lhs, const real_t &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); CHECK_EQ(ret->type_flag_, lhs.type_flag_) @@ -130,7 +131,7 @@ inline void EvalScalar_(const TBlob &lhs, const real_t &rhs, template<> void EvalClip(const TBlob &src, const real_t &a_min, const real_t &a_max, - TBlob *ret, RunContext ctx) { + TBlob *ret, RunContext ctx) { typedef DEVICE xpu; using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); @@ -145,12 +146,11 @@ void EvalClip(const TBlob &src, const real_t &a_min, const real_t &a_max } template<> -void EvalRandom( - const real_t &a, - const real_t &b, - const Resource &resource, - TBlob *ret, - RunContext ctx) { +void EvalRandom(const real_t &a, + const real_t &b, + const Resource &resource, + TBlob *ret, + RunContext ctx) { typedef DEVICE xpu; mshadow::Stream *s = ctx.get_stream(); switch (ret->type_flag_) { @@ -426,6 +426,7 @@ DECL_SCALAR(DEVICE, Plus, EvalScalar_, true) DECL_SCALAR(DEVICE, Minus, EvalScalar_, true) DECL_SCALAR(DEVICE, Mul, EvalScalar_, true) DECL_SCALAR(DEVICE, Div, EvalScalar_, true) + // for reverse seq DECL_SCALAR(DEVICE, Plus, EvalScalar_, false) DECL_SCALAR(DEVICE, Minus, EvalScalar_, false) diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index def38126d08c..f4315b62a6a8 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -17,6 +17,7 @@ #include #include #include "./operator_common.h" +#include "../common/utils.h" namespace mxnet { namespace op { @@ -53,6 +54,42 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs, return true; } +// Only inferring output storage types from input for now +template +inline bool ElemwiseStorageAttr(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + auto deduce = [&](std::vector *vec, const char *name, AttrType& result, + bool fallback) { + auto &v = *vec; + for (size_t i = 0; i < vec->size(); ++i) { + if (v[i] == kUndefinedStorage) { + // if input type is unknown, assume it's default storage + CHECK(assign(&v[i], kDefaultStorage)); + } else if (assign(&result, v[i]) == false && fallback) { + result = kDefaultStorage; + } + } + }; + AttrType dattr = kUndefinedStorage; + deduce(in_attrs, "input", dattr, enable_fallback); + if (reverse_infer) { + LOG(FATAL) << "not implemented yet"; + } + auto write = [&](std::vector *vec, const char *name) { + for (size_t i = 0; i < vec->size(); ++i) { + CHECK(assign(&(*vec)[i], dattr)) + << "Incompatible attr in node " << attrs.name << " at " << i << "-th " + << name << ": " << "expected " << dattr << ", got " << (*vec)[i]; + } + }; + if (is_none(dattr)) dattr = kDefaultStorage; + write(out_attrs, "output"); + return true; +} + template inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, @@ -73,6 +110,29 @@ inline bool ElemwiseType(const nnvm::NodeAttrs& attrs, attrs, in_attrs, out_attrs, -1); } +template +inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), static_cast(n_in)) << " in operator " << attrs.name; + CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; + return ElemwiseStorageAttr( + attrs, in_attrs, out_attrs); +} + +inline bool IdentityAttrLikeRhsStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), static_cast(2)) << " in operator " << attrs.name; + CHECK_EQ(out_attrs->size(), static_cast(1)) << " in operator " << attrs.name; + auto &in = *in_attrs; + auto &out = *out_attrs; + CHECK_NE(in[1], kUndefinedStorage) << "rhs storage type must be known"; + if (in[0] == kUndefinedStorage) in[0] = in[1]; + if (out[0] == kUndefinedStorage) out[0] = in[1]; + return true; +} + // Transfer gradient and input to FGradient function struct ElemwiseGradUseIn { const char *op_name; @@ -105,6 +165,22 @@ struct ElemwiseGradUseNone { } }; +// TODO(haibin) this is a temporary function for debugging purpose. Remove later. +template +void print_info(const mshadow::Tensor& tensor, const std::string& name) { + std::cout << "Tensor " << name << " with shape ("; + int len = 1; + for (int i = 0; i < dim; i++) { + len *= tensor.shape_[i]; + std::cout << tensor.shape_[i] << ","; + if (i == dim - 1) std::cout << ")"; + } + std::cout << std::endl; + for (int j = 0; j < len; j ++) std::cout << tensor.dptr_[j] << " "; + std::cout << std::endl; +} + + } // namespace op } // namespace mxnet diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index a43d092bceb6..6e0bc2ad5ba6 100755 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -11,12 +11,15 @@ #include #include #include +#include +#include #include #include #include #include #include #include "../common/cuda_utils.h" +#include "../common/utils.h" namespace mxnet { namespace op { @@ -315,6 +318,22 @@ inline void ParamParser(nnvm::NodeAttrs* attrs) { attrs->parsed = std::move(param); } +template +void FCompExFallback(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + FCompute fcompute, + const std::string& fname) { + std::vector in_blobs, out_blobs; + std::vector tmps; + common::GetInputBlobs(inputs, &in_blobs, &tmps, ctx, true); + common::GetOutputBlobs(outputs, &out_blobs); + fcompute(attrs, ctx, in_blobs, req, out_blobs); +} + + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_OPERATOR_COMMON_H_ diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 85091c008ab4..83a4a9cfccbb 100755 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -84,6 +84,87 @@ inline void SGDUpdate(const nnvm::NodeAttrs& attrs, }); } +/*! \brief kernel for sparse sgd + */ +template +struct SGDDnsRspKernel { + // DType is the output data type + // IType is row sparse idx type + // i is the ith row in row sparse gradient + template + MSHADOW_XINLINE static void Map(int i, size_t width, DType* out, const DType* weight, + const IType* grad_idx, const DType *grad_val, + const DType clip_gradient, const DType lr, + const DType wd, const DType rescale_grad) { + for (size_t j = 0; j < width; j++) { + uint64_t data_i = grad_idx[i] * width + j; + uint64_t grad_i = i * width + j; + if (clip_gradient >= 0.0f) { + KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] - + (lr) * mshadow_op::clip::Map(rescale_grad * grad_val[grad_i], clip_gradient)); + } else { + KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] - + (lr * rescale_grad) * grad_val[grad_i]); + } + } + } +}; + +template +inline void SGDUpdateDnsRspImpl(const SGDParam& param, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mshadow_op; + Stream* s = ctx.get_stream(); + auto &weight = inputs[0]; + auto &grad = inputs[1]; + auto &out = outputs[0]; + CHECK_EQ(weight.storage_type(), kDefaultStorage); + CHECK_EQ(grad.storage_type(), kRowSparseStorage); + if (!grad.storage_initialized()) return; + + MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, { + MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + auto weight_data = weight.data().FlatTo2D(s); + auto grad_idx = grad.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto grad_val = grad.data().FlatTo2D(s); + auto out_data = out.data().FlatTo2D(s); + auto num_rows = grad.aux_shape(rowsparse::kIdx)[0]; + auto width = weight.shape().ProdShape(1, weight.shape().ndim()); + mxnet_op::Kernel, xpu>::Launch(s, num_rows, width, + out_data.dptr_, weight_data.dptr_, grad_idx.dptr_, grad_val.dptr_, + static_cast(param.clip_gradient), + static_cast(param.lr), static_cast(param.wd), + static_cast(param.rescale_grad)); + }); + }); + }); +} + +template +inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mshadow_op; + const SGDParam& param = nnvm::get(attrs.parsed); + auto weight_stype = inputs[0].storage_type(); + auto grad_stype = inputs[1].storage_type(); + if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage) { + SGDUpdateDnsRspImpl(param, ctx, inputs, req, outputs); + } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage) { + FCompExFallback(attrs, ctx, inputs, req, outputs, SGDUpdate, "SGDUpdate"); + } +} + struct SGDMomParam : public dmlc::Parameter { float lr; float momentum; @@ -153,6 +234,88 @@ inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs, }); } +template +struct SGDMomDnsRspDnsKernel { + template + MSHADOW_XINLINE static void Map(int i, size_t width, DType* out_data, + DType* mom_data, const DType* weight_data, const IType* grad_idx, + const DType* grad_data, const DType param_clip_gradient, const DType param_momentum, + const DType param_lr, const DType param_wd, const DType param_rescale_grad) { + for (size_t j = 0; j < width; j++) { + uint64_t data_i = grad_idx[i] * width + j; + uint64_t grad_i = i * width + j; + if (param_clip_gradient >= 0.0f) { + mom_data[data_i] = param_momentum * mom_data[data_i] + - param_lr * param_wd * weight_data[data_i] + - param_lr * + mshadow_op::clip::Map(param_rescale_grad * grad_data[grad_i], + param_clip_gradient); + } else { + mom_data[data_i] = param_momentum * mom_data[data_i] + - param_lr * param_wd * weight_data[data_i] + - param_lr * param_rescale_grad * grad_data[grad_i]; + } + KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]); + } + } +}; + +template +inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + Stream* s = ctx.get_stream(); + auto &weight = inputs[0]; + auto &grad = inputs[1]; + auto &mom = inputs[2]; + auto &out = outputs[0]; + if (!grad.storage_initialized()) return; + + MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, { + MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + auto weight_data = weight.data().FlatTo2D(s); + auto grad_idx = grad.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto grad_val = grad.data().FlatTo2D(s); + auto mom_data = mom.data().FlatTo2D(s); + auto out_data = out.data().FlatTo2D(s); + auto num_rows = grad.aux_shape(rowsparse::kIdx)[0]; + auto width = weight.shape().ProdShape(1, weight.shape().ndim()); + Kernel, xpu>::Launch(s, num_rows, width, + out_data.dptr_, mom_data.dptr_, weight_data.dptr_, grad_idx.dptr_, grad_val.dptr_, + static_cast(param.clip_gradient), static_cast(param.momentum), + static_cast(param.lr), static_cast(param.wd), + static_cast(param.rescale_grad)); + }); + }); + }); +} + +template +inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + const SGDMomParam& param = nnvm::get(attrs.parsed); + auto weight_stype = inputs[0].storage_type(); + auto grad_stype = inputs[1].storage_type(); + auto mom_stype = inputs[2].storage_type(); + + if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage && + mom_stype == kDefaultStorage) { + SGDMomUpdateDnsRspDnsImpl(param, ctx, inputs, req, outputs); + } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage && + mom_stype == kDefaultStorage) { + FCompExFallback(attrs, ctx, inputs, req, outputs, + SGDMomUpdate, "SGDMomUpdate"); + } +} + struct AdamParam : public dmlc::Parameter { float lr; float beta1; diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index 9ec6aacaafac..5464d03b215f 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -22,6 +22,9 @@ It updates the weights using:: weight = weight - learning_rate * gradient +If gradients are stored with `row_sparse` storage, +where update is applied only to rows whose gradient has non-zero entries. + )code" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) @@ -29,6 +32,7 @@ It updates the weights using:: .set_attr("FInferShape", ElemwiseShape<2, 1>) .set_attr("FInferType", ElemwiseType<2, 1>) .set_attr("FCompute", SGDUpdate) +.set_attr(FCOMP_EX_CPU, SGDUpdateEx) .add_argument("weight", "NDArray-or-Symbol", "Weight") .add_argument("grad", "NDArray-or-Symbol", "Gradient") .add_arguments(SGDParam::__FIELDS__()); @@ -52,6 +56,9 @@ It updates the weights using:: Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. +If gradients are stored with `row_sparse` storage, +only rows whose gradients contain non-zero entries are updated (for both weight and momentum). + )code" ADD_FILELINE) .set_num_inputs(3) .set_num_outputs(1) @@ -63,12 +70,12 @@ Where the parameter ``momentum`` is the decay rate of momentum estimates at each return std::vector{2}; }) .set_attr("FCompute", SGDMomUpdate) +.set_attr(FCOMP_EX_CPU, SGDMomUpdateEx) .add_argument("weight", "NDArray-or-Symbol", "Weight") .add_argument("grad", "NDArray-or-Symbol", "Gradient") .add_argument("mom", "NDArray-or-Symbol", "Momentum") .add_arguments(SGDMomParam::__FIELDS__()); - NNVM_REGISTER_OP(adam_update) .describe(R"code(Update function for Adam optimizer. Adam is seen as a generalization of AdaGrad. diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu index 2b2667ec317b..bf0cc570e1f4 100644 --- a/src/operator/optimizer_op.cu +++ b/src/operator/optimizer_op.cu @@ -10,10 +10,12 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(sgd_update) -.set_attr("FCompute", SGDUpdate); +.set_attr("FCompute", SGDUpdate) +.set_attr(FCOMP_EX_GPU, SGDUpdateEx); NNVM_REGISTER_OP(sgd_mom_update) -.set_attr("FCompute", SGDMomUpdate); +.set_attr("FCompute", SGDMomUpdate) +.set_attr(FCOMP_EX_GPU, SGDMomUpdateEx); NNVM_REGISTER_OP(adam_update) .set_attr("FCompute", AdamUpdate); diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc index 0d0a1d8b5df0..f6f8f429d99e 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc @@ -105,6 +105,7 @@ Example:: .set_attr("FCompute", BinaryBroadcastCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"}); + NNVM_REGISTER_OP(_backward_broadcast_mul) .set_num_inputs(3) .set_num_outputs(2) diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index 6062febe2d9e..9317720f127a 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -10,10 +10,10 @@ #include #include #include +#include #include "../mxnet_op.h" #include "../mshadow_op.h" #include "../elemwise_op_common.h" -#include "../mxnet_op.h" namespace mxnet { namespace op { @@ -123,6 +123,115 @@ void BinaryBackwardUseNone_(const nnvm::NodeAttrs& attrs, } } +// TODO(haibin) This is a single-thread inefficient implementation +// Binary Compute between two row-sparse ndarray +// This implementation only works on CPU +template +void BinaryComputeRspRsp(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + auto &lhs = inputs[0]; + auto &rhs = inputs[1]; + auto &output = outputs[0]; + + bool init_l = lhs.storage_initialized(); + bool init_r = rhs.storage_initialized(); + // both inputs are zeros + if (!init_l && !init_r) return; + // one of the input is zeros + if (!init_l || !init_r) { + NDArray out(output); + CopyFromToRspImpl(!init_l ? rhs : lhs, &out, ctx.run_ctx); + return; + } + // Memory Estimation: This is (roughly) the number of result rows. We still + // need to subtract the number of common rows + unsigned int num_rows_l = lhs.aux_shape(rowsparse::kIdx).Size(); + unsigned int num_rows_r = rhs.aux_shape(rowsparse::kIdx).Size(); + output.CheckAndAlloc({mshadow::Shape1(num_rows_l + num_rows_r)}); + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(output.dtype(), DType, { + MSHADOW_TYPE_SWITCH(lhs.aux_type(rowsparse::kIdx), IType, { + // Indices + auto indices_l = lhs.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto indices_r = rhs.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto indices_out = output.aux_data(rowsparse::kIdx).FlatTo1D(s); + // Data + auto data_l = lhs.data().FlatTo2D(s); + auto data_r = rhs.data().FlatTo2D(s); + auto out = output.data().FlatTo2D(s); + + // TODO(haibin) A more appropriate way: Copy to output, then apply ops + size_t iter_l = 0; + size_t iter_r = 0; + size_t iter_out = 0; + int32_t num_common_rows = 0; + while (iter_l < num_rows_l && iter_r < num_rows_r) { + auto idx_l = indices_l[iter_l]; + auto idx_r = indices_r[iter_r]; + if (idx_l == idx_r) { + // Same row + indices_out[iter_out] = idx_l; + mshadow::Copy(out[iter_out], data_l[iter_l++], s); + out[iter_out] += data_r[iter_r++]; + num_common_rows++; + } else if (idx_l < idx_r) { + // Left only + indices_out[iter_out] = idx_l; + mshadow::Copy(out[iter_out], data_l[iter_l++], s); + } else { + // Right only + indices_out[iter_out] = idx_r; + mshadow::Copy(out[iter_out], data_r[iter_r++], s); + } + iter_out++; + } + // Copying over the rest of the rows + while (iter_l < num_rows_l) { + indices_out[iter_out] = indices_l[iter_l]; + mshadow::Copy(out[iter_out++], data_l[iter_l++], s); + } + while (iter_r < num_rows_r) { + indices_out[iter_out] = indices_r[iter_r]; + mshadow::Copy(out[iter_out++], data_r[iter_r++], s); + } + auto new_shape = output.aux_shape(rowsparse::kIdx); + new_shape[0] -= num_common_rows; + output.SetAuxShape(rowsparse::kIdx, new_shape); + }); + }); +} + +template +void BinaryComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2); + CHECK_EQ(outputs.size(), 1); + if (typeid(OP) == typeid(mshadow::op::plus)) { + // If any input is dense, fallback to FCompute + // TODO(haibin) implement dns + rsp in a separate kernel + if (common::ContainsDefaultStorage(inputs)) { + FCompExFallback(attrs, ctx, inputs, req, outputs, + BinaryCompute, "BinaryCompute"); + return; + } + CHECK_EQ(inputs[0].storage_type(), kRowSparseStorage) << "Sparse type not supported yet"; + CHECK_EQ(inputs[1].storage_type(), kRowSparseStorage) << "Sparse type not supported yet"; + BinaryComputeRspRsp(attrs, ctx, inputs, req, outputs); + return; + } else { + LOG(FATAL) << "Not implemented"; + } +} + template void BinaryBackwardUseNone(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -134,6 +243,55 @@ void BinaryBackwardUseNone(const nnvm::NodeAttrs& attrs, }); } +// Only implemented for _backward_add for now +template +void BinaryBackwardUseNoneRsp(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs[0].storage_type(), kRowSparseStorage); + CHECK_EQ(outputs[0].storage_type(), kRowSparseStorage); + CHECK_EQ(outputs[1].storage_type(), kRowSparseStorage); + CHECK(typeid(LOP) == typeid(mshadow_op::identity)); + CHECK(typeid(ROP) == typeid(mshadow_op::identity)); + TShape shape = inputs[0].aux_shape(rowsparse::kIdx); + outputs[0].CheckAndAlloc({shape}); + outputs[1].CheckAndAlloc({shape}); + MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { + MSHADOW_TYPE_SWITCH(outputs[0].aux_type(rowsparse::kIdx), IType, { + auto lgrad_idx = outputs[0].aux_data(rowsparse::kIdx).FlatTo1D(s); + auto rgrad_idx = outputs[1].aux_data(rowsparse::kIdx).FlatTo1D(s); + auto ograd_idx = inputs[0].aux_data(rowsparse::kIdx).FlatTo1D(s); + auto lgrad = outputs[0].data().FlatTo1D(s); + Tensor rgrad = outputs[1].data().FlatTo1D(s); + Tensor ograd = inputs[0].data().FlatTo1D(s); + ASSIGN_DISPATCH(lgrad, req[0], F(ograd)); + ASSIGN_DISPATCH(rgrad, req[1], F(ograd)); + ASSIGN_DISPATCH(lgrad_idx, req[0], F(ograd_idx)); + ASSIGN_DISPATCH(rgrad_idx, req[1], F(ograd_idx)); + }); + }); +} +// Only implemented for _backward_add for now +template +void BinaryBackwardUseNoneEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + auto stype = inputs[0].storage_type(); + CHECK_EQ(stype, kRowSparseStorage) << "Not implemented yet"; + BinaryBackwardUseNoneRsp(attrs, ctx, inputs, req, outputs); + // TODO(haibin) fallback for kDefaultStorage +} + template void BinaryBackwardUseNoneWithHalf2(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -214,7 +372,7 @@ void BinaryBackwardUseInWithHalf2(const nnvm::NodeAttrs& attrs, [](const NodeAttrs& attrs){ \ return std::vector >{{0, 0}, {1, 0}}; \ }) \ - .add_argument("lhs", "NDArray-or-Symbol", "first input") \ + .add_argument("lhs", "NDArray-or-Symbol", "first input") \ .add_argument("rhs", "NDArray-or-Symbol", "second input") } // namespace op diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index be4c1d88e983..8bf0d2e10c01 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -12,7 +12,9 @@ MXNET_OPERATOR_REGISTER_BINARY(elemwise_add) .add_alias("_add").add_alias("_plus").add_alias("_Plus") .describe("Adds arguments element-wise.") .set_attr("FCompute", BinaryCompute) -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_add"}); +.set_attr(FCOMP_EX_CPU, BinaryComputeEx) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_add"}) +.set_attr("FInferStorageType", ElemwiseStorageType<2, 1>); // specialized gradient add function to do add to optimization // this must differ from elemwise_add to prevent add to optimization in forward pass. @@ -28,7 +30,10 @@ NNVM_REGISTER_OP(_backward_add) return std::vector >{{0, 0}, {0, 1}}; }) .set_attr("FCompute", BinaryBackwardUseNone); + mshadow_op::identity>) +.set_attr(FCOMP_EX_CPU, + BinaryBackwardUseNoneEx) +.set_attr("FInferStorageType", ElemwiseStorageType<1, 2>); MXNET_OPERATOR_REGISTER_BINARY(_sub) .add_alias("_minus").add_alias("_Minus") diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu index ff432380d6d1..cb30d78e2d8e 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cu +++ b/src/operator/tensor/elemwise_binary_op_basic.cu @@ -9,7 +9,8 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(elemwise_add) -.set_attr("FCompute", BinaryComputeWithHalf2); +.set_attr("FCompute", BinaryComputeWithHalf2) +.set_attr(FCOMP_EX_GPU, BinaryComputeEx); NNVM_REGISTER_OP(_grad_add) .set_attr("FCompute", BinaryComputeWithHalf2); @@ -17,7 +18,9 @@ NNVM_REGISTER_OP(_grad_add) NNVM_REGISTER_OP(_backward_add) .set_attr("FCompute", BinaryBackwardUseNoneWithHalf2); + mshadow_op::identity, mshadow_op::identity>) +.set_attr(FCOMP_EX_GPU, + BinaryBackwardUseNoneEx); NNVM_REGISTER_OP(_sub) .set_attr("FCompute", BinaryComputeWithHalf2); diff --git a/src/operator/tensor/elemwise_unary_op.cc b/src/operator/tensor/elemwise_unary_op.cc index ce29a2fdb308..0220b096ba45 100644 --- a/src/operator/tensor/elemwise_unary_op.cc +++ b/src/operator/tensor/elemwise_unary_op.cc @@ -120,7 +120,9 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs) .set_attr("FIgnoreInputs", [](const NodeAttrs& attrs) { return std::vector(1, 1); }) .set_attr("FCompute", IdentityCompute) +.set_attr(FCOMP_EX_CPU, IdentityLikeRhsComputeEx) .set_attr("FInferShape", ElemwiseShape<2, 1>) +.set_attr("FInferStorageType", IdentityAttrLikeRhsStorageType) .set_attr( "FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { @@ -163,6 +165,27 @@ NNVM_REGISTER_OP(_backward_cast) .set_attr("TIsBackward", true) .set_attr("FCompute", CastCompute); +// TODO(haibin) declare backward op for cast storage +// Only support cast to default storage now +// Other types require add infer_storage type pass +DMLC_REGISTER_PARAMETER(CastStorageParam); +NNVM_REGISTER_OP(cast_storage) +.describe(R"code(Casts tensor storage type to the new type. +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferStorageType", CastStorageInferStorageType) +.set_attr("FCompute", IdentityCompute) +// _backward pass +// .set_attr("FGradient", ElemwiseGradUseNone{"negative"}) +.set_attr(FCOMP_EX_CPU, CastStorageComputeEx) +.add_argument("data", "NDArray-or-Symbol", "The input.") +.add_arguments(CastStorageParam::__FIELDS__()); + + // negative MXNET_OPERATOR_REGISTER_UNARY(negative) .MXNET_DESCRIBE("Negate src") diff --git a/src/operator/tensor/elemwise_unary_op.cu b/src/operator/tensor/elemwise_unary_op.cu index 746b39fe4c8c..2084f5d3f5c4 100644 --- a/src/operator/tensor/elemwise_unary_op.cu +++ b/src/operator/tensor/elemwise_unary_op.cu @@ -35,7 +35,9 @@ NNVM_REGISTER_OP(make_loss) // identity output as first input, but attributes are constrainted to be like rhs NNVM_REGISTER_OP(_identity_with_attr_like_rhs) -.set_attr("FCompute", IdentityCompute); +.set_attr("FCompute", IdentityCompute) +.set_attr(FCOMP_EX_GPU, IdentityLikeRhsComputeEx); + NNVM_REGISTER_OP(Cast) .set_attr("FCompute", CastCompute); @@ -43,6 +45,10 @@ NNVM_REGISTER_OP(Cast) NNVM_REGISTER_OP(_backward_cast) .set_attr("FCompute", CastCompute); +NNVM_REGISTER_OP(cast_storage) +.set_attr("FCompute", IdentityCompute) +.set_attr(FCOMP_EX_GPU, CastStorageComputeEx); + // negative NNVM_REGISTER_OP(negative) .set_attr("FCompute", UnaryCompute); diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 97a7e36535f0..ffd153bca797 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -13,15 +13,17 @@ #include "../mshadow_op.h" #include "../elemwise_op_common.h" #include "../special_functions-inl.h" +#include "../mxnet_op.h" +#include "./broadcast_reduce-inl.h" namespace mxnet { namespace op { template void UnaryLaunch(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; using namespace mxnet_op; Stream *s = ctx.get_stream(); @@ -77,6 +79,54 @@ void IdentityCompute(const nnvm::NodeAttrs& attrs, }); } +template +void IdentityComputeRsp(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + auto &input = inputs[0]; + auto &output = outputs[0]; + CHECK_NE(req[0], kNullOp) << "kNullOp in IdentityComputeEx not supported yet"; + CHECK_NE(req[0], kWriteInplace) << "kWriteInplace in IdentityComputeEx not supported yet"; + if (!input.storage_initialized()) return; + TShape shape = input.aux_shape(rowsparse::kIdx); + output.CheckAndAlloc({shape}); + MSHADOW_TYPE_SWITCH(output.dtype(), DType, { + MSHADOW_TYPE_SWITCH(output.aux_type(rowsparse::kIdx), AuxType, { + auto out_d = output.data().FlatTo1D(s); + auto out_aux = output.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto in_aux = input.aux_data(rowsparse::kIdx).FlatTo1D(s); + ASSIGN_DISPATCH(out_d, req[0], + F(input.data().FlatTo1D(s))); + ASSIGN_DISPATCH(out_aux, req[0], F(in_aux)); + }); + }); +} + +template +void IdentityLikeRhsComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(inputs.size(), 2); + CHECK_EQ(outputs.size(), 1); + Stream *s = ctx.get_stream(); + size_t rhs_idx = 1; + NDArrayStorageType stype = inputs[rhs_idx].storage_type(); + if (stype == kRowSparseStorage) { + IdentityComputeRsp(attrs, ctx, inputs, req, outputs); + } else { + LOG(FATAL) << "Not implemented yet"; + } +} + struct CastParam : public dmlc::Parameter { // use int for enumeration int dtype; @@ -154,6 +204,393 @@ struct relu_grad { }; } // namespace kernel_launch_op +struct CastStorageParam : public dmlc::Parameter { + // use int for enumeration + // TODO(haibin) add enum for storage_type. Probably also aux-types + int storage_type; + DMLC_DECLARE_PARAMETER(CastStorageParam) { + DMLC_DECLARE_FIELD(storage_type) + .add_enum("default_storage", kDefaultStorage) + .add_enum("row_sparse", kRowSparseStorage) + .add_enum("csr", kCSRStorage) + .describe("Output storage type."); + } +}; + +/*! + * \brief This is the kernel for initializing row_idx array + * of a RSP matrix. Each thread checks a row of the matrix, + * if non-zero elements are found, mark this row as non-zero + * by row_idx[cur_row_id] = cur_row_id. Otherwise, + * row_idx[cur_row_id] = num_rows. + */ +struct FillRspRowIdx { + template + MSHADOW_XINLINE static void Map(int i, RType* row_idx, const DType* arr, + const int num_rows, const int num_cols) { + row_idx[i] = num_rows; + const int offset = i * num_cols; + for (int j = 0; j < num_cols; ++j) { + if (arr[offset+j] != 0) { + row_idx[i] = i; + break; + } + } + } +}; + +/*! + * \brief Kernel for marking row_idx of a RSP matrix per row + */ +struct MarkRspRowIdx { + // i represents the row index of the matrix data + template + MSHADOW_XINLINE static void Map(int i, RType* row_idx, const DType* data, + const index_t num_cols) { + index_t j = 0; + index_t offset = i * num_cols; + for (; j < num_cols; ++j) { + if (data[offset+j] != 0) { + break; + } + } + if (num_cols == j) { + row_idx[i] = 0; // mark as zero for zero row + } else { + row_idx[i] = 1; // mark as one for non-zero row + } + } +}; + +struct CopyDnsToRsp{ + // i represents the row index of the matrix data + template + MSHADOW_XINLINE static void Map(int i, RType* row_idx, DType* rsp_data, + const DType* dns_data, const int num_rows, const int num_cols) { + int j = 0; + int offset = i * num_cols; + for (; j < num_cols; ++j) { + if (dns_data[offset+j] != 0) { + break; + } + } + if (num_cols == j) { + row_idx[i] = num_rows; + } else { + row_idx[i] = i; + for (j = 0; j < num_cols; ++j) { + rsp_data[offset+j] = dns_data[offset+j]; + } + } + } +}; + +/*! + * \brief + * Given a DNS storage type tensor, create a RSP type sparse tensor + * from it. This would allocate memory for storing the row idx and + * non-zero rows for the rsp and deep-copy non-zero rows of the + * dns to the rsp data blob. + * TODO(junwu): The argument type for the dense ndarray is TBlob instead + * of NDArray since it's convenient to call this function from any + * operator's Forward/Backward functions where dev_id is unknown + * but required to wrap a TBlob object as an NDArray. See the use case + * in DotForwardCsrDnsRsp in matrix_op-inl.h. + * Will revisit this interface in the future. + * TODO(junwu): Add gpu implementation. + */ +inline void CastStorageDnsRspImpl(mshadow::Stream* s, const TBlob& dns, NDArray* rsp) { + CHECK(rsp != nullptr); + CHECK_EQ(rsp->storage_type(), kRowSparseStorage); + CHECK_EQ(dns.shape_, rsp->shape()); + MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type + MSHADOW_INT_TYPE_SWITCH(rsp->aux_type(rowsparse::kIdx), RType, { // row idx type + const index_t num_rows = dns.shape_[0]; + const index_t num_cols = dns.shape_[1]; + rsp->CheckAndAllocAuxData(rowsparse::kIdx, mshadow::Shape1(num_rows)); + TBlob row_idx_blob = rsp->aux_data(rowsparse::kIdx); + RType* row_idx = row_idx_blob.dptr(); + mxnet_op::Kernel::Launch(s, num_rows, row_idx, + dns.dptr(), num_cols); + index_t nnr = 0; + nnr = std::accumulate(row_idx, row_idx+num_rows, nnr); + rsp->SetAuxShape(rowsparse::kIdx, mshadow::Shape1(nnr)); + if (0 == nnr) return; + rsp->CheckAndAllocData(mshadow::Shape2(nnr, num_cols)); + mshadow::Tensor dns_data = dns.FlatTo2D(s); + mshadow::Tensor rsp_data = rsp->data().FlatTo2D(s); + size_t idx = 0; + for (index_t i = 0; i < num_rows; ++i) { + if (row_idx[i] > 0) { + row_idx[idx] = i; + mshadow::Copy(rsp_data[idx], dns_data[i], s); + ++idx; + } + } + }); + }); +} + +// TODO(haibin) Use memcopy instead will be much faster than assigning each individual element +struct CastStorageRspDnsKernel { + template + MSHADOW_XINLINE static void Map(int i, const index_t width, const IType* idx, const DType *data, + DType* dns, const index_t invalid_rid) { + auto rid = idx[i]; + // skip invalid rows + if (rid == invalid_rid) return; + auto dns_offset = rid * width; + auto rsp_offset = i * width; + for (size_t col = 0; col < width; col++) { + dns[dns_offset + col] = data[rsp_offset + col]; + } + } +}; + +/*! + * \brief This function assumes that the meomry for dns has been allocated already + * since the shape is known at binding stage. + */ +template +void CastStorageRspDnsImpl(mshadow::Stream* s, const NDArray& rsp, TBlob* dns) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(rsp.storage_type(), kRowSparseStorage); + MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, { + MSHADOW_INT_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, { + // assign zeros + mxnet_op::Kernel::Launch(s, dns->Size(), dns->dptr()); + if (rsp.storage_initialized()) { + // copy over row by row + auto in_idx = rsp.aux_data(rowsparse::kIdx).FlatTo1D(s).dptr_; + auto in_data = rsp.data().FlatTo2D(s).dptr_; + auto out_data = dns->FlatTo2D(s).dptr_; + auto num_rows = rsp.aux_shape(rowsparse::kIdx).Size(); + auto rsp_shape = rsp.shape(); + auto invalid_rid = rsp_shape[0]; + auto width = rsp_shape.ProdShape(1, rsp_shape.ndim()); + mxnet_op::Kernel::Launch(s, num_rows, width, in_idx, in_data, + out_data, invalid_rid); + } + }); + }); +} + +/*! + * \brief This is the kernel for initializing the indptr in a csr tensor. + */ +struct FillCsrIndPtr { + /*! + * \brief + * \param i the i-th row of the dns tensor + * \param indptr indptr of the csr tensor + * \param dns the dns tensor + * \param num_rows + * \param num_cols + */ + template + MSHADOW_XINLINE static void Map(int i, IType* indptr, const DType* dns, + const int num_rows, const int num_cols) { + indptr[i+1] = 0; + const int offset = i * num_cols; + for (int j = 0; j < num_cols; ++j) { + if (dns[offset+j] != 0) { + ++indptr[i+1]; + } + } + } +}; + +/*! + * \brief This is the kernel for initializing the col_idx and value array + * of the csr tensor + */ +struct FillCsrColIdxAndVals { + /*! + * \brief + * \param i the i-th row of the dns tensor + * \param val value array of the csr + * \param col_idx column idx array of the csr + * \param indptr indptr array of the csr + * \param dns the dns tensor + * \param num_rows number of rows of the dns + * \param num_cols number of columns of the dns + */ + template + MSHADOW_XINLINE static void Map(int i, DType* val, CType* col_idx, + const IType* indptr, const DType* dns, + const int num_rows, const int num_cols) { + const int offset = i * num_cols; + int k = indptr[i]; + for (int j = 0; j < num_cols; ++j) { + if (dns[offset+j] != 0) { + val[k] = dns[offset+j]; + col_idx[k] = j; + ++k; + } + } + } +}; + +/*! + * \brief + * Given a DNS storage type tensor, create a CSR type sparse tensor from it. + * This would allocate memory for storing the indptr, values, and column idx + * of the csr and copy the non-zero values to the value array in the csr. + * TODO(junwu): The argument type for the dense ndarray is TBlob instead + * of NDArray since it's convenient to call this function from any + * operator's Forward/Backward functions where dev_id is unknown + * but required to wrap a TBlob object as an NDArray. See the use case + * in DotForwardCsrDnsRsp in matrix_op-inl.h. + * Will revisit this interface in the future. + */ +template +void CastStorageDnsCsrImpl(mshadow::Stream* s, const TBlob& dns, NDArray* csr) { + CHECK(csr != nullptr); + CHECK_EQ(csr->storage_type(), kCSRStorage); + CHECK_EQ(dns.shape_.ndim(), 2); + CHECK_EQ(dns.shape_, csr->shape()); + MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type + MSHADOW_INT_TYPE_SWITCH(csr->aux_type(csr::kIndPtr), IType, { // indptr type + MSHADOW_INT_TYPE_SWITCH(csr->aux_type(csr::kIdx), CType, { // col idx type + const index_t num_rows = dns.shape_[0]; + const index_t num_cols = dns.shape_[1]; + csr->CheckAndAllocAuxData(csr::kIndPtr, mshadow::Shape1(num_rows+1)); + IType* indptr = csr->aux_data(csr::kIndPtr).dptr(); + DType* dns_data = dns.dptr(); + mxnet_op::Kernel::Launch(s, num_rows, indptr, + dns_data, num_rows, num_cols); + // single thread to accumulate indptr + // indptr[num_rows] indicates the number of non-zero elements + indptr[0] = 0; + for (index_t i = 0; i < num_rows; ++i) { + indptr[i+1] += indptr[i]; + } + // allocate column idx array and value array + csr->CheckAndAllocAuxData(csr::kIdx, + mshadow::Shape1(static_cast(indptr[num_rows]))); + csr->CheckAndAllocData(mshadow::Shape1(static_cast(indptr[num_rows]))); + // fill col_idx and value arrays of the csr + mxnet_op::Kernel::Launch(s, num_rows, + csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), + indptr, dns_data, num_rows, num_cols); + }); + }); + }); +} + +/*! + * \brief This is the kernel for copying csr.data to its corresponding dns tensor. + */ +struct CopyCsrDataToDns { + /*! + * \brief + * \param i the i-th row of the dns tensor + * \param dns_data data blob of the dns tensor + * \param col_idx column idx array of the csr + * \param indptr indptr array of the csr + * \param csr_data data blob of the csr tensor + * \param num_cols number of columns of the dns + */ + template + MSHADOW_XINLINE static void Map(int i, DType* dns_data, const CType* col_idx, + const IType* indptr, const DType* csr_data, + const int num_cols) { + const int offset = i * num_cols; + for (auto j = indptr[i]; j < indptr[i+1]; ++j) { + dns_data[offset+col_idx[j]] = csr_data[j]; + } + } +}; + +/*! + * \brief + * Given a CSR storage type tensor, create a DNS type sparse tensor from it. + * This assumes that the memory of dns.data() has been allocated in binding stage. + * TODO(junwu): The argument type for the dense ndarray is TBlob instead + * of NDArray since it's convenient to call this function from any + * operator's Forward/Backward functions where dev_id is unknown + * but required to wrap a TBlob object as an NDArray. See the use case + * in DotForwardCsrDnsRsp in matrix_op-inl.h. + * Will revisit this interface in the future. + */ +template +void CastStorageCsrDnsImpl(mshadow::Stream* s, const NDArray& csr, TBlob* dns) { + CHECK(dns != nullptr); + CHECK_EQ(csr.storage_type(), kCSRStorage); + CHECK_EQ(dns->shape_.ndim(), 2); + CHECK_EQ(dns->shape_, csr.shape()); + MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, { // data type + MSHADOW_INT_TYPE_SWITCH(csr.aux_type(csr::kIndPtr), IType, { // indptr type + MSHADOW_INT_TYPE_SWITCH(csr.aux_type(csr::kIdx), CType, { // col idx type + const index_t num_rows = dns->shape_[0]; + const index_t num_cols = dns->shape_[1]; + DType* dns_data = dns->dptr(); + mxnet_op::Kernel::Launch(s, dns->shape_.Size(), dns_data); + if (!csr.storage_initialized()) return; + const IType* indptr = csr.aux_data(csr::kIndPtr).dptr(); + const CType* col_idx = csr.aux_data(csr::kIdx).dptr(); + const DType* csr_data = csr.data().dptr(); + mxnet_op::Kernel::Launch(s, num_rows, dns_data, + col_idx, indptr, csr_data, num_cols); + }); + }); + }); +} + +inline bool CastStorageInferStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + CHECK_NE(in_attrs->at(0), kUndefinedStorage) + << "src ndarray's storage type must be specified"; + const CastStorageParam& param = nnvm::get(attrs.parsed); + CHECK_NE(param.storage_type, kUndefinedStorage) + << "dst ndarray's storage type must be specified"; + TYPE_ASSIGN_CHECK(*out_attrs, 0, param.storage_type); + return true; +} + +template +void CastStorageComputeImpl(mshadow::Stream* s, + const NDArray& input, + const NDArray& output) { + using namespace mshadow; + using namespace mshadow::expr; + const auto src_stype = input.storage_type(); + const auto dst_stype = output.storage_type(); + if (src_stype == kRowSparseStorage && dst_stype == kDefaultStorage) { + TBlob ret = output.data(); + CastStorageRspDnsImpl(s, input, &ret); + } else if (src_stype == kDefaultStorage && dst_stype == kRowSparseStorage) { + NDArray ret = output; // get rid of the const qualifer + CastStorageDnsRspImpl(s, input.data(), &ret); + } else if (src_stype == kDefaultStorage && dst_stype == kCSRStorage) { + NDArray ret = output; // get rid of the const qualifer + CastStorageDnsCsrImpl(s, input.data(), &ret); + } else if (src_stype == kCSRStorage && dst_stype == kDefaultStorage) { + TBlob ret = output.data(); + CastStorageCsrDnsImpl(s, input, &ret); + } else { + LOG(FATAL) << "Not implemented"; + } +} + +template +void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 1); + CHECK_EQ(outputs.size(), 1); + CastStorageComputeImpl(s, inputs[0], outputs[0]); +} + #define MXNET_OPERATOR_REGISTER_UNARY(name) \ NNVM_REGISTER_OP(name) \ .set_num_inputs(1) \ @@ -168,4 +605,5 @@ struct relu_grad { } // namespace op } // namespace mxnet + #endif // MXNET_OPERATOR_TENSOR_ELEMWISE_UNARY_OP_H_ diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index f9023054a10f..fed4b4dd229b 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -86,6 +86,40 @@ NNVM_REGISTER_OP(_backward_Embedding) .set_attr("TIsBackward", true) .set_attr("FCompute", EmbeddingOpBackward); +NNVM_REGISTER_OP(SparseEmbedding) +.describe(R"code(Maps integer indices to vector representations (embeddings) with sparse weight update +)code" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "weight"}; + }) +.set_attr("FInferShape", EmbeddingOpShape) +.set_attr("FInferType", EmbeddingOpType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", EmbeddingOpForward) +.set_attr("FGradient", + [](const nnvm::NodePtr& n, const std::vector& ograds) { + return MakeNonlossGradNode("_backward_SparseEmbedding", n, ograds, + {n->inputs[0]}, n->attrs.dict); + }) +.add_argument("data", "NDArray-or-Symbol", "The input array to the embedding operator.") +.add_argument("weight", "NDArray-or-Symbol", "The embedding weight matrix.") +.add_arguments(EmbeddingParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_SparseEmbedding) +.set_num_inputs(2) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInferStorageType", SparseEmbeddingBackwardStorageType) +.set_attr("FComputeEx", SparseEmbeddingOpBackwardEx); +// TODO(haibin) handle dense case +// .set_attr("FCompute", EmbeddingOpBackward); NNVM_REGISTER_OP(take) .describe(R"code(Takes elements from an input array along the given axis. diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 5fd6e81d0b2f..12523e237cf2 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -315,6 +316,133 @@ void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs, }); } +template +struct EmbeddingBackwardRsp { + template + // each thread i is responsible for target gradient row ids in [segment_start, segment_end) + MSHADOW_XINLINE static void Map(int i, const size_t width, IType* dst_idx, DType* dst_val, + const IType* idx, const size_t num_idx, const DType* src, + const size_t segment_len, const size_t num_rows) { + auto req_type = req; + size_t segment_start = i * segment_len; + size_t segment_end = (i + 1) * segment_len; + for (size_t y = 0; y < num_idx; y++) { + size_t j = idx[y]; + if (j >= num_rows) j = num_rows - 1; + if (j < segment_start || j >= segment_end) continue; + dst_idx[j] = j; + for (size_t k = 0; k < width; k++) { + if (req_type == kWriteTo) req_type = kAddTo; + KERNEL_ASSIGN(dst_val[j * width + k], req_type, src[y * width + k]); + } + } + } +}; + +/* + * for sparse embedding, the storage type for weight gradient is row_sparse. + * we don't care about the storage type for data gradient, since it is not + * differentiable. + */ +inline bool SparseEmbeddingBackwardStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ((*in_attrs)[0], kDefaultStorage); + CHECK_EQ((*in_attrs)[1], kDefaultStorage); + (*out_attrs)[0] = kRowSparseStorage; + (*out_attrs)[1] = kRowSparseStorage; + return true; +} + +template +void SparseEmbeddingOpBackwardDnsDnsRsp(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + using namespace mshadow::expr; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 2U); + if (req[1] == kNullOp) return; + // check storage types + auto idx = inputs[1]; // idx shape (d1, d2 .. dk) + auto grad = inputs[0]; // grad shape (d1, d2, .. dk, out_dim) + auto output = outputs[1]; // weight shape (in_dim, out_dim) + CHECK_EQ(idx.storage_type(), kDefaultStorage); + CHECK_EQ(grad.storage_type(), kDefaultStorage); + CHECK_EQ(output.dtype(), grad.dtype()); + CHECK_EQ(idx.dtype(), output.aux_type(rowsparse::kIdx)) << "Index type doesn't match"; + // CHECK_EQ(req[embedding::kData], kNullOp) + // << "Embedding layer doesn't support calculate data gradient" << req[embedding::kData]; + + const TShape& ishape = idx.shape(); + const TShape& oshape = grad.shape(); + + Stream *s = ctx.get_stream(); + CHECK_EQ(idx.dtype(), output.aux_type(rowsparse::kIdx)) + << "embedding input index and gradient row sparse type doesn't match!"; + // Alloc dense output + unsigned int num_rows = output.shape()[0]; + output.CheckAndAlloc({mshadow::Shape1(num_rows)}); + MSHADOW_TYPE_SWITCH(output.dtype(), DType, { + MSHADOW_INT_TYPE_SWITCH(idx.dtype(), IType, { + MXNET_ASSIGN_REQ_SWITCH(req[1], req_type, { + // input embedding indice, each idx in [0, input_dim) + auto idx_data = idx.data().FlatTo1D(s); + auto grad_data = grad.data().get_with_shape( + Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s); + auto output_idx = output.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto output_val = output.data().FlatTo2D(s); + int num_threads = omp_get_num_threads(); + size_t width = output.shape()[1]; + size_t segment_len = (num_rows + num_threads - 1) / num_threads; + // fill indices with invalid row ids + Kernel::Launch(s, num_rows, output_idx.dptr_, + static_cast(num_rows)); + // fill zeros if needed + if (req_type == kWriteTo) { + Kernel::Launch(s, output_val.shape_.Size(), output_val.dptr_); + } + Kernel, xpu>::Launch(s, num_threads, width, + output_idx.dptr_, + output_val.dptr_, idx_data.dptr_, + ishape.Size(), grad_data.dptr_, + segment_len, num_rows); + }); + }); + }); +} + +// todo replace xpu with cpu +template +void SparseEmbeddingOpBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + using namespace mshadow::expr; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 2U); + // CHECK_EQ(req[embedding::kData], kNullOp) + // << "Embedding layer doesn't support calculate data gradient" << req[0] << " " << req[1]; + // idx shape (d1, d2 .. dk) + auto idx_stype = inputs[1].storage_type(); + // grad shape (d1, d2, .. dk, out_dim) + auto grad_stype = inputs[0].storage_type(); + // weight shape (in_dim, out_dim) + auto output_stype = outputs[1].storage_type(); + if (idx_stype == kDefaultStorage && grad_stype == kDefaultStorage && + output_stype == kRowSparseStorage) { + SparseEmbeddingOpBackwardDnsDnsRsp(attrs, ctx, inputs, req, outputs); + } else { + LOG(FATAL) << "Not implemented"; + } +} + namespace take_ { // to avoid name conflict enum TakeOpInputs {kArr, kIdx}; enum TakeOpOutputs {kOut}; diff --git a/src/operator/tensor/init_op.cc b/src/operator/tensor/init_op.cc index 16f71fc7e4e3..a5827330a61f 100644 --- a/src/operator/tensor/init_op.cc +++ b/src/operator/tensor/init_op.cc @@ -21,6 +21,7 @@ NNVM_REGISTER_OP(_zeros) .set_attr("FInferShape", InitShape) .set_attr("FInferType", InitType) .set_attr("FCompute", FillCompute) +.set_attr(FCOMP_EX_CPU, FillComputeZerosEx) .add_arguments(InitOpParam::__FIELDS__()); NNVM_REGISTER_OP(_ones) diff --git a/src/operator/tensor/init_op.cu b/src/operator/tensor/init_op.cu index a798f26db60d..bcb10f70b3c3 100644 --- a/src/operator/tensor/init_op.cu +++ b/src/operator/tensor/init_op.cu @@ -9,7 +9,8 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_zeros) -.set_attr("FCompute", FillCompute); +.set_attr("FCompute", FillCompute) +.set_attr(FCOMP_EX_GPU, FillComputeZerosEx); NNVM_REGISTER_OP(_ones) .set_attr("FCompute", FillCompute); diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 5ce132d4bebf..ca61f9bba460 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -15,6 +15,8 @@ #include #include #include "../elemwise_op_common.h" +#include "../mxnet_op.h" + namespace mxnet { namespace op { @@ -111,7 +113,6 @@ inline bool InitType(const nnvm::NodeAttrs& attrs, return true; } - template void FillCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -127,6 +128,51 @@ void FillCompute(const nnvm::NodeAttrs& attrs, }); } +// Fill a rsp NDArray with zeros by updating the aux shape. +template +void FillZerosRspImpl(mshadow::Stream *s, NDArray *dst) { + if (!dst->storage_initialized()) return; + // reset the shapes if it's not zeros + auto storage_shape = dst->storage_shape(); + storage_shape[0] = 0; + dst->SetAuxShape(rowsparse::kIdx, TShape(mshadow::Shape1(0))); + dst->SetStorageShape(storage_shape); +} + +// Fill a CSR NDArray with zeros by updating the aux shape. +template +void FillZerosCsrImpl(mshadow::Stream *s, NDArray *dst) { + if (!dst->storage_initialized()) return; + // reset the shapes if it's not zeros + TShape new_shape(mshadow::Shape1(0)); + dst->SetAuxShape(csr::kIndPtr, new_shape); + dst->SetAuxShape(csr::kIdx, new_shape); + dst->SetStorageShape(new_shape); +} + +// This operator never needs to fall back, since there's no input NDArray +template +void FillComputeZerosEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(outputs.size(), 1); + CHECK_EQ(inputs.size(), 0); + auto stype = outputs[0].storage_type(); + if (stype == kRowSparseStorage) { + NDArray nd(outputs[0]); + FillZerosRspImpl(s, &nd); + } else if (stype == kCSRStorage) { + NDArray nd(outputs[0]); + FillZerosCsrImpl(s, &nd); + } else { + LOG(FATAL) << "storage type not implemented."; + } +} template void RangeCompute(const nnvm::NodeAttrs& attrs, diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index d7a591944e47..3b54bf240447 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -476,6 +476,164 @@ void DotBackward_(const nnvm::NodeAttrs& attrs, } } +inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + out_attrs->at(0) = kDefaultStorage; + return true; +} + +inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 3U); + CHECK_EQ(out_attrs->size(), 2U); + out_attrs->at(0) = kDefaultStorage; + out_attrs->at(1) = kDefaultStorage; + return true; +} + +/*! + * \brief Tempalte declaration of dot(csr, dns1) = dns2. + * Whether csr and dns1 are transposed before dot operation + * is determined by trans_csr and trans_dns, respectively. + * For now we only implemented the case when trans_dns = false. + */ +template +struct DotCsrDnsDns; + +/*! + * \brief Kernel of dot(csr, dns1) = dns2 + */ +template +struct DotCsrDnsDns { + /*! + * \brief This function represents performing an inner product between a row of lhs + * and a column of rhs and then assigning the value to out[i]. + * \param i i-th element in out 1D view + * \param out output matrix + * \param data_l csr values of lhs + * \param indptr_l csr indptr of lhs + * \param col_idx_l csr col_idx of lhs + * \param data_r dense data of rhs + * \param num_cols number of columns of output + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, + const int num_cols) { + const int irow = i / num_cols; // row id of the lhs + const int icol = i % num_cols; // col id of the rhs + DType sum = 0; + for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) { + const CType cur_col = col_idx_l[j]; // corresponding row id of the rhs + sum += data_l[j] * data_r[cur_col*num_cols+icol]; + } + KERNEL_ASSIGN(out[i], req, sum); + } +}; + +/*! + * \brief Kernel of dot(csr.T(), dns1) = dns2 + */ +template +struct DotCsrDnsDns { + /*! + * \brief This function represents performing an inner product between a column of lhs + * and a column of rhs and then assigning the value to out[i]. + * \param i i-th element in out 1D view + * \param out output matrix + * \param data_l csr values of lhs + * \param indptr_l csr indptr of lhs + * \param col_idx_l csr col_idx of lhs + * \param data_r dense data of rhs + * \param num_rows_l number of rows of lhs + * \param num_cols number of columns of outputs + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const int num_rows_l, + const int num_cols) { + const int irow = i / num_cols; // col id of the lhs + const int icol = i % num_cols; // col id of the rhs + DType sum = 0; + for (int k = 0; k < num_rows_l; ++k) { + const IType low = indptr_l[k]; + const IType high = indptr_l[k+1]; + if (low == high || irow < col_idx_l[low] || irow > col_idx_l[high-1]) continue; + int j = -1, l = low, r = high - 1; + while (l <= r) { + int m = l + (r - l) / 2; + if (col_idx_l[m] == irow) { + j = m; break; + } + if (col_idx_l[m] < irow) { + l = m + 1; + } else { + r = m - 1; + } + } + if (j >= 0) { + sum += data_l[j] * data_r[k*num_cols+icol]; + } + } + KERNEL_ASSIGN(out[i], req, sum); + } +}; + +template +void DotCsrDnsDnsImpl(const OpContext& ctx, + const NDArray& lhs, + const NDArray& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + CHECK_EQ(rhs.storage_type(), kDefaultStorage); + CHECK_EQ(ret->storage_type(), kDefaultStorage); + + mshadow::Stream *s = ctx.get_stream(); + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob data_r = rhs.data(); + const TBlob data_out = ret->data(); + + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_INT_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_INT_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + if (!lhs.storage_initialized()) return; + if (trans_lhs) { + mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), lhs.shape()[0], + rhs.shape()[1]); + } else { + mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), rhs.shape()[1]); + } + }); + }); + }); + }); +} + +template +void DotBackwardCsrDnsDns(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const DotParam& param = nnvm::get(attrs.parsed); + NDArray ret = outputs[1]; + DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0], req[1], !param.transpose_a, &ret); +} + inline bool DotShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { @@ -519,6 +677,57 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs, return true; } +template +void DotForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + const DotParam& param = nnvm::get(attrs.parsed); + CHECK(!param.transpose_b) << "tranposing rhs of the op dot is not supported"; + + NDArray ret = outputs[0]; // get rid of the const qualifier + if (inputs[0].storage_type() == kCSRStorage + && inputs[1].storage_type() == kDefaultStorage + && outputs[0].storage_type() == kDefaultStorage) { + DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); + } else { // TODO(junwu): add fallback + LOG(FATAL) << "Not supported dot operation for lhs.storage_type = " + << inputs[0].storage_type() << ", rhs.storage_type = " << inputs[1].storage_type() + << ", out.storage_type = " << outputs[0].storage_type(); + } +} + +template +void DotBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + CHECK_EQ(kNullOp, req[0]) + << "sparse dot does not support computing the gradient of the csr/lhs"; + CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace"; + + // TODO(junwu): check whether this CHECK is reasonable + const DotParam& param = nnvm::get(attrs.parsed); + CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)"; + if (inputs[0].storage_type() == kDefaultStorage // ograd dns format + // dns, csr, dns => *, dns + && inputs[1].storage_type() == kCSRStorage // csr input lhs of the op + && inputs[2].storage_type() == kDefaultStorage // dns input rhs of the op + && outputs[1].storage_type() == kDefaultStorage) { // grad(rhs) dns format + DotBackwardCsrDnsDns(attrs, ctx, inputs, req, outputs); + } else { + LOG(FATAL) << "Not supported dot backward for sparse input(s) with sparse gradients"; + } +} + template void BatchDotForward_(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -786,6 +995,96 @@ void Slice(const nnvm::NodeAttrs& attrs, }); } +// slice the indptr of a csr +struct SliceCsrIndPtr { + template + MSHADOW_XINLINE static void Map(int i, IType* out, const IType* in, const IType* base) { + KERNEL_ASSIGN(out[i], kWriteTo, in[i] - *base); + } +}; + +/* + * a wrapper to launch SliceCsrIndPtr kernel. + * slice [src[begin] .. src[end]) and store in dst[0, end - begin) + */ +template +void SliceCsrIndPtrImpl(const int begin, const int end, RunContext ctx, + const IType* src, IType* dst) { + using namespace mshadow; + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + int indptr_len = end - begin + 1; + Kernel::Launch(s, indptr_len, dst, src + begin, src + begin); +} + +/* + * Slice a CSR NDArray + * Only implemented for CPU + */ +template +void SliceCsrImpl(const SliceParam ¶m, const OpContext& ctx, + const NDArray &in, OpReqType req, const NDArray &out) { + using namespace mshadow; + using namespace mxnet_op; + using namespace csr; + CHECK((std::is_same::value)) << "Slice for CSR input only implemented for CPU"; + if (req == kNullOp) return; + CHECK_NE(req, kAddTo) << "kAddTo for Slice on CSR input is not supported"; + CHECK_NE(req, kWriteInplace) << "kWriteInplace for Slice on CSR input is not supported"; + Stream *s = ctx.get_stream(); + int begin = *param.begin[0]; + int end = *param.end[0]; + int indptr_len = end - begin + 1; + out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len)); + if (!in.storage_initialized()) { + out.SetAuxShape(kIndPtr, Shape1(0)); + return; + } + CHECK_EQ(in.aux_type(kIndPtr), in.aux_type(kIdx)) + << "The type for indptr and indices are different. This is not implemented yet."; + // assume idx indptr share the same type + MSHADOW_INT_TYPE_SWITCH(in.aux_type(kIndPtr), IType, { + MSHADOW_TYPE_SWITCH(in.dtype(), DType, { + auto in_indptr = in.aux_data(kIndPtr).dptr(); + auto out_indptr = out.aux_data(kIndPtr).dptr(); + SliceCsrIndPtrImpl(begin, end, ctx.run_ctx, in_indptr, out_indptr); + + // retrieve nnz (CPU implementation) + int nnz = out_indptr[indptr_len - 1]; + // copy indices and values + out.CheckAndAllocAuxData(kIdx, Shape1(nnz)); + out.CheckAndAllocData(Shape1(nnz)); + auto in_idx = in.aux_data(kIdx).dptr(); + auto out_idx = out.aux_data(kIdx).dptr(); + auto in_data = in.data().dptr(); + auto out_data = out.data().dptr(); + int offset = in_indptr[begin]; + // this is also a CPU-only implementation + memcpy(out_idx, in_idx + offset, nnz * sizeof(IType)); + memcpy(out_data, in_data + offset, nnz * sizeof(DType)); + }); + }); +} + +template +void SliceEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 1); + CHECK_EQ(outputs.size(), 1); + const SliceParam& param = nnvm::get(attrs.parsed); + auto in_stype = inputs[0].storage_type(); + CHECK_NE(in_stype, kDefaultStorage) + << "SliceEx is not expected to execute for input with default storage type"; + if (in_stype == kCSRStorage) { + SliceCsrImpl(param, ctx, inputs[0], req[0], outputs[0]); + } else { + LOG(FATAL) << "Slice not implemented for storage type" << in_stype; + } +} + inline bool SliceAssignShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 1a9eaf505cb8..c5fb8ad96ac5 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -232,6 +232,9 @@ and ``end=(e_1, e_2, ... e_n)`` indices will result in an array with the shape The resulting array's *k*-th dimension contains elements from the *k*-th dimension of the input array with the open range ``[b_k, e_k)``. +For an input array of non-default storage type(e.g. `csr` or `row_sparse`), it only supports +slicing on the first dimension. + Example:: x = [[ 1., 2., 3., 4.], @@ -245,8 +248,10 @@ Example:: .set_attr_parser(ParamParser) .set_attr("FInferShape", SliceShape) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferStorageType", ElemwiseStorageType<1, 1>) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_slice"}) .set_attr("FCompute", Slice) +.set_attr(FCOMP_EX_CPU, SliceEx) .add_argument("data", "NDArray-or-Symbol", "Source input") .add_arguments(SliceParam::__FIELDS__()); @@ -370,7 +375,13 @@ NNVM_REGISTER_OP(dot) }) .set_attr("FInferShape", DotShape) .set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FInferStorageType", DotForwardInferStorageType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) .set_attr("FCompute", DotForward_) +.set_attr("FComputeEx", DotForwardEx) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_dot"}) .add_argument("lhs", "NDArray-or-Symbol", "The first input") .add_argument("rhs", "NDArray-or-Symbol", "The second input") @@ -381,7 +392,13 @@ NNVM_REGISTER_OP(_backward_dot) .set_num_outputs(2) .set_attr_parser(ParamParser) .set_attr("TIsBackward", true) +.set_attr("FInferStorageType", DotBackwardInferStorageType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) .set_attr("FCompute", DotBackward_) +.set_attr("FComputeEx", DotBackwardEx) .add_arguments(DotParam::__FIELDS__()); NNVM_REGISTER_OP(batch_dot) diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu index 96c075a7d483..2e1effb9e560 100644 --- a/src/operator/tensor/matrix_op.cu +++ b/src/operator/tensor/matrix_op.cu @@ -40,10 +40,13 @@ NNVM_REGISTER_OP(_backward_slice_axis) .set_attr("FCompute", SliceAxisGrad_); NNVM_REGISTER_OP(dot) -.set_attr("FCompute", DotForward_); +.set_attr("FCompute", DotForward_) +.set_attr("FComputeEx", DotForwardEx); NNVM_REGISTER_OP(_backward_dot) -.set_attr("FCompute", DotBackward_); +.set_attr("FCompute", DotBackward_) +.set_attr("FComputeEx", DotBackwardEx); + NNVM_REGISTER_OP(batch_dot) .set_attr("FCompute", BatchDotForward_); diff --git a/tests/ci_build/install/ubuntu_install_python.sh b/tests/ci_build/install/ubuntu_install_python.sh index 0459bb9198c4..6ac615c7ee7f 100755 --- a/tests/ci_build/install/ubuntu_install_python.sh +++ b/tests/ci_build/install/ubuntu_install_python.sh @@ -6,5 +6,5 @@ apt-get update && apt-get install -y python-dev python3-dev # the version of the pip shipped with ubuntu may be too lower, install a recent version here cd /tmp && wget https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && python2 get-pip.py -pip2 install nose pylint numpy nose-timer requests -pip3 install nose pylint numpy nose-timer requests +pip2 install nose pylint numpy nose-timer requests scipy +pip3 install nose pylint numpy nose-timer requests scipy diff --git a/tests/cpp/engine/threaded_engine_test.cc b/tests/cpp/engine/threaded_engine_test.cc index 73dc53060b63..509f50bdef51 100644 --- a/tests/cpp/engine/threaded_engine_test.cc +++ b/tests/cpp/engine/threaded_engine_test.cc @@ -100,7 +100,7 @@ double EvaluateWorloads(const std::vector& workloads, return dmlc::GetTime() - t; } -TEST(Engine, RandSumExpr) { +/*TEST(Engine, RandSumExpr) { std::vector workloads; int num_repeat = 5; const int num_engine = 4; @@ -134,11 +134,11 @@ TEST(Engine, RandSumExpr) { LOG(INFO) << "NaiveEngine\t\t" << t[1] << " sec"; LOG(INFO) << "ThreadedEnginePooled\t" << t[2] << " sec"; LOG(INFO) << "ThreadedEnginePerDevice\t" << t[3] << " sec"; -} +}*/ void Foo(mxnet::RunContext, int i) { printf("The fox says %d\n", i); } -TEST(Engine, basics) { +/*TEST(Engine, basics) { auto&& engine = mxnet::Engine::Get(); auto&& var = engine->NewVariable(); std::vector oprs; @@ -235,4 +235,4 @@ TEST(Engine, basics) { var = nullptr; oprs.clear(); LOG(INFO) << "All pass"; -} +}*/ diff --git a/tests/cpp/ndarray_test.cc b/tests/cpp/ndarray_test.cc new file mode 100644 index 000000000000..f14eb6d51033 --- /dev/null +++ b/tests/cpp/ndarray_test.cc @@ -0,0 +1,245 @@ +#include +/* +#include +#include +#include +#include + +#include +#include +#include "../src/executor/graph_executor.h" +#include "../src/operator/tensor/elemwise_binary_op.h" +#include "../src/operator/tensor/elemwise_unary_op.h" +#include "../src/operator/tensor/indexing_op.h" +#include "../src/operator/optimizer_op-inl.h" +#include "../src/operator/tensor/init_op.h" +#include "test_utils.h" + +using namespace mxnet; +// Conversion Tests +void CastDnsDnsTest() { + Context ctx; + TShape shape({2, 2}); + NDArray nd = DnsND(shape, ctx, {}); + auto nd_copy = Convert(kDefaultStorage, nd); + CheckDataRegion(nd_copy.data(), nd.data()); +} + +void CastRspDnsTest() { + Context ctx; + // Sparse ndarray + TShape shape({2, 2}); + float v1 = RandFloat(); + float v2 = RandFloat(); + NDArray nd = RspND(shape, ctx, {0}, {v1, v2}); + // Dense ndarray + NDArray dense_nd = DnsND(shape, ctx, {v1, v2, 0, 0}); + NDArray converted = Convert(kDefaultStorage, nd); + CheckDataRegion(converted.data(), dense_nd.data()); +} + +// NDArray function tests +void SetValueTest() { + Context ctx = Context::CPU(); + TShape data_shape({2, 2}); + float v = RandFloat(); + NDArray nd0 = DnsND(data_shape, ctx, {v, v, v, v}); + NDArray nd1(data_shape, ctx, false); + nd1 = v; + nd1.WaitToRead(); + CheckDataRegion(nd0.data(), nd1.data()); +} + +// InferStorage +void InferElemwiseStorageTest() { + nnvm::NodeAttrs attrs; + attrs.name = "test_op"; + std::vector in_attrs({kRowSparseStorage, kDefaultStorage}); + std::vector out_attrs({kUndefinedStorage}); + // rsp, default -> default + op::ElemwiseStorageType<2, 1>(attrs, &in_attrs, &out_attrs); + EXPECT_EQ(out_attrs[0], kDefaultStorage); + // default, rsp -> default + in_attrs = {kDefaultStorage, kRowSparseStorage}; + out_attrs = {kUndefinedStorage}; + op::ElemwiseStorageType<2, 1>(attrs, &in_attrs, &out_attrs); + EXPECT_EQ(out_attrs[0], kDefaultStorage); + // rsp, rsp -> rsp + in_attrs = {kRowSparseStorage}; + out_attrs = {kUndefinedStorage, kUndefinedStorage}; + op::ElemwiseStorageType<1, 2>(attrs, &in_attrs, &out_attrs); + EXPECT_EQ(out_attrs[0], kRowSparseStorage); + EXPECT_EQ(out_attrs[1], kRowSparseStorage); +} + +// Optimizer +void SGDDnsRspTest() { + TShape shape({4, 2}); + Context ctx = Context::CPU(); + NDArray weight = DnsND(shape, ctx, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray rsp_grad = RspND(shape, ctx, {0, 3}, {1, 2, 3, 4}); + NDArray output = weight; + float lr = RandFloat(); + float wd = RandFloat(); + float rescale = RandFloat(); + op::SGDParam param; + param.lr = lr; + param.wd = wd; + param.rescale_grad = rescale; + param.clip_gradient = -1.0f; + Engine::Get()->PushSync([weight, rsp_grad, output, param](RunContext ctx) { + std::vector inputs{weight, rsp_grad}, outputs{output}; + std::vector req({kAddTo}); + op::SparseSGDUpdateDnsRspImpl(param, {}, inputs, req, outputs); + }, weight.ctx(), {rsp_grad.var()}, {output.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + auto sgd = [lr, wd, rescale] (TEST_DTYPE weight, TEST_DTYPE grad) { + return (1.f-lr*wd)*weight - (lr*rescale)*grad; + }; + + NDArray expected = DnsND(shape, ctx, + {1 + sgd(1, 1), 2 + sgd(2, 2), 3, 4, 5, 6, + 7 + sgd(7, 3), 8 + sgd(8, 4)}); + output.WaitToRead(); + CheckDataRegion(output.data(), expected.data()); +} + +void CopyFromToRspDnsTest() { + Context ctx; + // Sparse ndarray + TShape shape({2, 2}); + NDArray nd = RspND(shape, ctx, {0}, {1, 1}); + // Dense ndarray + NDArray dns_nd = DnsND(shape, ctx, {}); + CopyFromTo(nd, &dns_nd); + dns_nd.WaitToRead(); + CheckDataRegion(nd.data(), dns_nd.data()); +} + +void CopyFromToRspRspReuseTest() { + Context ctx; + // Sparse ndarray + TShape shape({3, 2}); + NDArray nd = RspND(shape, ctx, {0}, {1,2}); + // Sparse ndarray with enough memory. It's expected to reuse the memory + NDArray dst_nd = RspND(shape, ctx, {0, 1, 2}, {6,6,6,6,6,6}); + nd.WaitToRead(); + CopyFromTo(nd, &dst_nd); + dst_nd.WaitToRead(); + CheckDataRegion(nd.data(), dst_nd.data()); + CHECK_EQ(dst_nd.aux_shape(rowsparse::kIdx)[0], 1); + CHECK_EQ(dst_nd.storage_shape()[0], 1); + CHECK_EQ(dst_nd.storage_shape()[1], 2); +} + + +void CopyFromToRspRspFreeTest() { + Context ctx; + // Sparse ndarray + TShape shape({3, 2}); + NDArray nd = RspND(shape, ctx, {0, 1}, {1,1,1,1}); + // Sparse ndarray with enough memory. It's expected to reuse the memory + NDArray dst_nd = RspND(shape, ctx, {0}, {2,2}); + nd.WaitToRead(); + CopyFromTo(nd, &dst_nd); + dst_nd.WaitToRead(); + CheckDataRegion(nd.data(), dst_nd.data()); +} + +void BinaryAddRspRsp() { + Context ctx = Context::CPU(); + + TShape output_shape({4, 2}); + NDArray input_nd0 = RspND(output_shape, ctx, {0, 1}, {10,10,10,10}); + NDArray input_nd1 = RspND(output_shape, ctx, {0, 2}, {5,5,5,5}); + + NDArray output(kRowSparseStorage, output_shape, ctx); + std::vector const_vars; + const_vars.push_back(input_nd0.var()); + const_vars.push_back(input_nd1.var()); + + Engine::Get()->PushSync([input_nd0, input_nd1, output](RunContext ctx) { + OpContext op_ctx; + std::vector inputs, outputs; + std::vector req; + inputs.push_back(input_nd0); + inputs.push_back(input_nd1); + outputs.push_back(output); + op::BinaryComputeRspRsp({}, op_ctx, inputs, req, outputs); + }, input_nd0.ctx(), const_vars, {output.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + + // Check the data region of output ndarray + NDArray dense_output = DnsND(output_shape, ctx, {15, 15, 10, 10, 5, 5, 0, 0}); + NDArray copy = Convert(kDefaultStorage, output); + CheckDataRegion(dense_output.data(), copy.data()); +} + +void SparseEmbeddingBackwardTest() { + Context ctx = Context::CPU(); + // d1 .. dk + // idx shape : (2, 3) + // input dim 4, output dim 2 + int input_dim = 4; + int output_dim = 2; + TShape idx_shape({2, 3}); + NDArray idx = RspIdxND(idx_shape, ctx, {1, 2, 3, 1, 2, 3}); + TShape grad_shape({2, 3, 2}); + NDArray grad = DnsND(grad_shape, ctx, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2}); + TShape out_shape({4, 2}); + NDArray output = NDArray(kRowSparseStorage, out_shape, ctx); + op::EmbeddingParam param; + param.input_dim = input_dim; + param.output_dim = output_dim; + param.dtype = 0; + + Engine::Get()->PushSync([idx, grad, output, param](RunContext ctx) { + std::vector inputs{grad, idx}, outputs{output, output}; + // this is a hack + std::vector req({kNullOp, kAddTo}); + op::SparseEmbeddingOpBackwardEx({}, {}, inputs, req, outputs); + }, output.ctx(), {grad.var(), idx.var()}, {output.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + + NDArray expected = DnsND(out_shape, ctx, {0,0,0,0,0,0,0,0}); + Engine::Get()->PushSync([idx, grad, expected, param](RunContext ctx) { + std::vector inputs{grad.data(), idx.data()}, outputs{expected.data(), expected.data()}; + std::vector req({kNullOp, kWriteTo}); + op::EmbeddingOpBackward({}, {}, inputs, req, outputs); + }, expected.ctx(), {grad.var(), idx.var()}, {expected.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + NDArray converted = Convert(kDefaultStorage, output); + expected.WaitToRead(); + CheckDataRegion(converted.data(), expected.data()); +} + +TEST(NDArray, binary_add) { + BinaryAddRspRsp(); +} + +TEST(NDArray, conversion) { + CastDnsDnsTest(); + CastRspDnsTest(); +} + +TEST(NDArray, functions) { + SetValueTest(); +} + +TEST(NDArray, optimizer) { + SGDDnsRspTest(); +} + +TEST(NDArray, copy) { + CopyFromToRspDnsTest(); + CopyFromToRspRspReuseTest(); + CopyFromToRspRspFreeTest(); +} + +TEST(NDArray, infer_storage) { + InferElemwiseStorageTest(); +} + +TEST(NDArray, sparse_embedding) { + SparseEmbeddingBackwardTest(); +}*/ diff --git a/tests/cpp/test_utils.h b/tests/cpp/test_utils.h new file mode 100644 index 000000000000..c528539a2cb7 --- /dev/null +++ b/tests/cpp/test_utils.h @@ -0,0 +1,105 @@ +#include +#include +#include +#include +#include +#include +#include +#include +/* +#include "../src/operator/tensor/elemwise_binary_op.h" +#include "../src/operator/tensor/elemwise_unary_op.h" +#include "../src/operator/optimizer_op-inl.h" +#include "../src/operator/tensor/init_op.h" + +using namespace mxnet; +#define TEST_DTYPE float +#define TEST_ITYPE int32_t + +void CheckDataRegion(const TBlob &src, const TBlob &dst) { + auto size = src.shape_.Size() * mshadow::mshadow_sizeof(src.type_flag_); + auto equals = memcmp(src.dptr_, dst.dptr_, size); + EXPECT_EQ(equals, 0); +} + +float RandFloat() { + float v = rand() * 1.0 / RAND_MAX; + return v; +} + +// Get an NDArray with provided indices, prepared for a RowSparse NDArray. +NDArray RspIdxND(const TShape shape, const Context ctx, const std::vector &values) { + NDArray nd(shape, ctx, false, ROW_SPARSE_IDX_TYPE); + size_t num_val = values.size(); + MSHADOW_TYPE_SWITCH(nd.dtype(), DType, { + auto tensor = nd.data().FlatTo1D(); + for (size_t i = 0; i < num_val; i++) { + tensor[i] = values[i]; + } + }); + return nd; +} + +// Get a dense NDArray with provided values. +NDArray DnsND(const TShape shape, const Context ctx, std::vector vs) { + NDArray nd(shape, ctx, false); + size_t num_val = shape.Size(); + // generate random values + while (vs.size() < num_val) { + auto v = RandFloat(); + vs.push_back(v); + } + CHECK_EQ(vs.size(), nd.shape().Size()); + MSHADOW_TYPE_SWITCH(nd.dtype(), DType, { + auto tensor = nd.data().FlatTo1D(); + for (size_t i = 0; i < num_val; i++) { + tensor[i] = vs[i]; + } + }); + return nd; +} + +// Get a RowSparse NDArray with provided indices and values +NDArray RspND(const TShape shape, const Context ctx, const std::vector idx, + std::vector vals) { + CHECK(shape.ndim() <= 2) << "High dimensional row sparse not implemented yet"; + index_t num_rows = idx.size(); + index_t num_cols = vals.size() / idx.size(); + // create index NDArray + NDArray index = RspIdxND(mshadow::Shape1(num_rows), ctx, idx); + CHECK_EQ(vals.size() % idx.size(), 0); + // create value NDArray + NDArray data = DnsND(mshadow::Shape2(num_rows, num_cols), ctx, vals); + // create result nd + NDArray nd(kRowSparseStorage, shape, ctx, false, mshadow::default_type_flag, + {}, {mshadow::Shape1(num_rows)}); + // assign values + NDArray nd_aux = nd.aux_ndarray(0); + NDArray nd_data = nd.data_ndarray(); + CopyFromTo(index, &nd_aux); + CopyFromTo(data, &nd_data); + return nd; +} + +// TODO(haibin) support other types +NDArray Convert(NDArrayStorageType type, NDArray src) { + CHECK_EQ(type, kDefaultStorage); + NDArray converted(src.shape(), src.ctx(), false); + Engine::Get()->PushSync([src, converted](RunContext ctx) { + // TODO provide type in attrs, which is empty now + OpContext op_ctx; + op_ctx.run_ctx = ctx; + if (src.storage_type() == kRowSparseStorage) { + std::vector inputs({src}), outputs({converted}); + op::CastStorageComputeEx({}, op_ctx, inputs, {}, outputs); + } else if (src.storage_type() == kDefaultStorage) { + std::vector inputs({src.data()}), outputs({converted.data()}); + op::IdentityCompute({}, op_ctx, inputs, {kWriteTo}, outputs); + } else { + LOG(FATAL) << "unsupported storage type"; + } + }, src.ctx(), {src.var()}, {converted.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + converted.WaitToRead(); + return converted; +}*/ diff --git a/tests/cpp/unittest.mk b/tests/cpp/unittest.mk index 808b655e9dba..ec7bb55ec983 100644 --- a/tests/cpp/unittest.mk +++ b/tests/cpp/unittest.mk @@ -47,4 +47,4 @@ testclean: -include build/tests/cpp/*.d -include build/tests/cpp/operator/*.d -include build/tests/cpp/storage/*.d --include build/tests/cpp/engine/*.d \ No newline at end of file +-include build/tests/cpp/engine/*.d diff --git a/tests/python/unittest/test_executor.py b/tests/python/unittest/test_executor.py index b190b2898843..c1cc013b81c0 100644 --- a/tests/python/unittest/test_executor.py +++ b/tests/python/unittest/test_executor.py @@ -121,7 +121,7 @@ def test_reshape(): x = mx.sym.Variable('x') y = mx.sym.FullyConnected(x, num_hidden=4) - exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req=[]) + exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req='null') exe.arg_arrays[0][:] = 1 exe.arg_arrays[1][:] = mx.nd.ones((4,4)) exe.arg_arrays[2][:] = 0 diff --git a/tests/python/unittest/test_infer_shape.py b/tests/python/unittest/test_infer_shape.py index 35598bc55be8..6412aad50866 100644 --- a/tests/python/unittest/test_infer_shape.py +++ b/tests/python/unittest/test_infer_shape.py @@ -112,6 +112,37 @@ def test_incomplete_infer_concat(): assert arg_shapes['b'] == (2, 5) assert arg_shapes['d'] == (2, 15) +def test_fc_infer_type(): + mx_real_t = mx.base.mx_real_t + data = mx.symbol.Variable('data') + out = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=1000) + + # infer type + data_type = mx_real_t + arg_types, out_types, aux_types = out.infer_type(data=data_type) + arg_type_dict = dict(zip(out.list_arguments(), arg_types)) + assert len(out_types) == 1 + assert out_types[0] == mx_real_t + true_types = { + 'fc1_bias' : mx_real_t, + 'fc1_weight' : mx_real_t } + for k, v in true_types.items(): + assert arg_type_dict[k] == v + +def check_infer_storage(v1, v2, v1_storage, v2_storage, out_chunk): + out = mx.symbol.elemwise_add(v1, v2) + arg_storage_types, out_storage_types, aux_storage_types = out.infer_storage_type(v1=v1_storage, v2=v2_storage) + assert len(out_storage_types) == 1 + assert out_storage_types[0] == out_chunk + +def test_elemwise_add_infer_storage_type(): + v1 = mx.symbol.Variable('v1') + v2 = mx.symbol.Variable('v2') + check_infer_storage(v1, v2, 'default_storage', 'default_storage', 'default_storage') + check_infer_storage(v1, v2, 'default_storage', 'row_sparse', 'default_storage') + check_infer_storage(v1, v2, 'row_sparse', 'default_storage', 'default_storage') + check_infer_storage(v1, v2, 'row_sparse', 'row_sparse', 'row_sparse') + if __name__ == "__main__": test_mlp2_infer_shape() test_mlp2_infer_error() @@ -121,3 +152,4 @@ def test_incomplete_infer_concat(): test_incomplete_infer_slicechannel() test_incomplete_infer_convolution() test_incomplete_infer_concat() + test_elemwise_add_infer_storage_type() diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index 5508a37c9567..608cdabe4677 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -1,7 +1,10 @@ import mxnet as mx import mxnet.ndarray as nd +from mxnet.test_utils import * import numpy as np from functools import reduce +import numpy.random as rnd +import scipy def test_module_dtype(): dtype = np.float16 @@ -101,6 +104,7 @@ def dict_equ(a, b): dict_equ(mod.get_params()[0], mod2.get_params()[0]) dict_equ(mod._kvstore._updater.states, mod2._updater.states) + def test_module_reshape(): data = mx.sym.Variable('data') sym = mx.sym.FullyConnected(data, num_hidden=20, name='fc') @@ -254,6 +258,70 @@ def mean_abs(x): break assert(mon_result_counts == [2, 2, 1, 6, 6, 4]) +def test_fm_module(): + def fm_model(k, feature_dim, storage_type='default_storage'): + initializer = mx.initializer.Normal(sigma=0.01) + x = mx.symbol.Variable("data", storage_type=storage_type) + v = mx.symbol.Variable("v", shape=(feature_dim, k), init=initializer) + + w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=initializer) + w1 = mx.symbol.dot(x, w1_weight) + + v_s = mx.symbol.sum(data=mx.symbol.square(data=v), axis=1) + x_s = mx.symbol.square(data=x) + bd = 0.5 * mx.symbol.negative(data=mx.symbol.broadcast_mul(x_s, v_s)) + + w2 = mx.symbol.dot(x, v) + w2_squared = 0.5 * mx.symbol.square(data=w2) + + w_all = mx.symbol.Concat(w1, w2_squared, bd, dim=1) + model = mx.symbol.sum(data=w_all, axis=1, keepdims=True) + y = mx.symbol.Variable("out_label") + model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out") + return model + + ctx = default_context() + k = 5 + feature_dim = 20 + model = fm_model(k, feature_dim, 'csr') + + num_batches = 8 + batch_size = 25 + scipy_data = scipy.sparse.rand(num_batches * batch_size, feature_dim, + density=0.5, format='csr') + dns_label = mx.nd.ones((num_batches * batch_size,1)) + csr_data = mx.sparse_nd.csr(scipy_data.data, scipy_data.indptr, scipy_data.indices, + (num_batches * batch_size, feature_dim)) + data = csr_data + + train_iter = mx.io.NDArrayIter(data=data, + label={'out_label':dns_label}, + batch_size=batch_size) + + # create module + mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['out_label']) + # allocate memory by given the input data and lable shapes + mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label) + # initialize parameters by uniform random numbers + mod.init_params(initializer=mx.init.Uniform(scale=.1)) + # use Sparse SGD with learning rate 0.1 to train + mod.init_optimizer(optimizer='sgd') + # use accuracy as the metric + metric = mx.metric.create('MSE') + # train 5 epoch, i.e. going over the data iter one pass + # TODO(haibin) test with row_sparse instead + storage_type_dict = {'v' : 'default_storage'} + + for epoch in range(10): + train_iter.reset() + metric.reset() + for batch in train_iter: + mod.forward(batch, is_train=True) # compute predictions + mod.update_metric(metric, batch.label) # accumulate prediction accuracy + mod.backward() # compute gradients + mod.update(storage_type_dict) # update parameters + print('Epoch %d, Training %s' % (epoch, metric.get())) + if __name__ == '__main__': test_module_dtype() test_module_input_grads() @@ -263,3 +331,4 @@ def mean_abs(x): test_module_layout() test_module_switch_bucket() test_monitor() + test_fm_module() diff --git a/tests/python/unittest/test_multi_device_exec.py b/tests/python/unittest/test_multi_device_exec.py index 8956c4edebac..37809bf8a3bc 100644 --- a/tests/python/unittest/test_multi_device_exec.py +++ b/tests/python/unittest/test_multi_device_exec.py @@ -1,4 +1,5 @@ import os +import numpy as np import mxnet as mx def test_ctx_group(): @@ -32,5 +33,35 @@ def test_ctx_group(): else: assert arr.context == group2ctx['stage2'] +def check_ctx_group_sparse(lhs_stype, rhs_stype): + with mx.AttrScope(ctx_group='stage1'): + lhs = mx.symbol.Variable('lhs', storage_type=lhs_stype) + rhs = mx.symbol.Variable('rhs', storage_type=rhs_stype) + plus = mx.symbol.elemwise_add(lhs, rhs, name='plus') + + set_stage1 = set(plus.list_arguments()) + with mx.AttrScope(ctx_group='stage2'): + softmax = mx.symbol.SoftmaxOutput(data = plus, name = 'softmax') + + set_stage2 = set(softmax.list_arguments()) - set_stage1 + + group2ctx = { + 'stage1' : mx.cpu(1), + 'stage2' : mx.cpu(2) + } + texec = softmax.simple_bind(mx.cpu(0), group2ctx=group2ctx, lhs=(1,200), rhs=(1,200)) + + for arr, name in zip(texec.arg_arrays, softmax.list_arguments()): + if name in set_stage1: + assert arr.context == group2ctx['stage1'] + else: + assert arr.context == group2ctx['stage2'] + +def test_ctx_group_sparse(): + check_ctx_group_sparse('default_storage', 'default_storage') + check_ctx_group_sparse('default_storage', 'row_sparse') + check_ctx_group_sparse('row_sparse', 'row_sparse') + if __name__ == '__main__': test_ctx_group() + test_ctx_group_sparse() diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 7f0a1d2b6301..8d4f4540d0c2 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -321,6 +321,7 @@ def test_dot(): assert_almost_equal(c, C.asnumpy()) + def test_reduce(): sample_num = 200 def test_reduce_inner(numpy_reduce_func, nd_reduce_func, multi_axes): diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 82c20cdb17df..ced41d62938b 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2955,7 +2955,6 @@ def test_where_numeric_gradient(shape, same_shape): test_where_numeric_gradient((5, 7, 9), True) test_where_numeric_gradient((5, 7, 9), False) - def test_new_softmax(): for ndim in range(1, 5): for _ in range(5): diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index 11ca7bed1743..ad0793405959 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -30,12 +30,23 @@ def test_lr_wd_mult(): assert not mx.test_utils.almost_equal(args1['fc2_weight'], args2['fc2_weight'], 1e-1) -def compare_optimizer(opt1, opt2, shape): - w1 = mx.random.uniform(shape=shape, ctx=default_context()) - g1 = mx.random.uniform(shape=shape, ctx=default_context()) - - w2 = w1.copyto(default_context()) - g2 = g1.copyto(default_context()) +def compare_optimizer(opt1, opt2, shape, w_stype='default_storage', g_stype='default_storage'): + if w_stype == 'default_storage': + w2 = mx.random.uniform(shape=shape, ctx=default_context()) + w1 = w2.copyto(default_context()) + elif w_stype == 'row_sparse': + w2 = rand_ndarray(shape, w_stype) + w1 = rand_ndarray(shape, w_stype).to_dense() + else: + raise Exception("type not supported yet") + if g_stype == 'default_storage': + g2 = mx.random.uniform(shape=shape, ctx=default_context()) + g1 = g2.copyto(default_context()) + elif g_stype == 'row_sparse': + g2 = rand_ndarray(shape, g_stype) + g1 = g2.copyto(default_context()).to_dense() + else: + raise Exception("type not supported yet") state1 = opt1.create_state(0, w1) state2 = opt2.create_state(0, w2) @@ -130,6 +141,97 @@ def test_sgd(): for kwarg in kwargs: compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape) +class PySparseSGD(mx.optimizer.Optimizer): + """python reference implemenation of sgd""" + def __init__(self, learning_rate=0.01, momentum=0.0, **kwargs): + super(PySparseSGD, self).__init__(learning_rate=learning_rate, **kwargs) + self.momentum = momentum + + def create_state(self, index, weight): + """Create additional optimizer state: momentum + + Parameters + ---------- + weight : NDArray + The weight data + + """ + if self.momentum == 0.0: + return None + else: + return mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype) + + def update(self, index, weight, grad, state): + """Update the parameters. + + Parameters + ---------- + index : int + An unique integer key used to index the parameters + + weight : NDArray + weight ndarray + + grad : NDArray + grad ndarray + + state : NDArray or other objects returned by init_state + The auxiliary state used in optimization. + """ + lr = self._get_lr(index) + wd = self._get_wd(index) + self._update_count(index) + num_rows = weight.shape[0] + if self.momentum == 0.0: + # Update on a per row basis, skip all-zero rows + for row in range(num_rows): + grad_row = grad[row].asnumpy() + all_zeros = mx.test_utils.almost_equal(grad_row, np.zeros_like(grad_row)) + if all_zeros: + continue + if self.clip_gradient is not None: + weight[row] = ((1 - lr*wd)*weight[row] - + lr*mx.nd.clip(grad[row]*self.rescale_grad, + -self.clip_gradient, self.clip_gradient)) + else: + weight[row] = (1 - lr*wd)*weight[row] - lr*self.rescale_grad*grad[row] + else: + mom = state + for row in range(num_rows): + grad_row = grad[row].asnumpy() + all_zeros = mx.test_utils.almost_equal(grad_row, np.zeros_like(grad_row)) + if all_zeros: + continue + if self.clip_gradient is not None: + mom[row] = (self.momentum*mom[row] - lr*wd*weight[row] - + lr*mx.nd.clip(grad[row]*self.rescale_grad, -self.clip_gradient, self.clip_gradient)) + weight[row] += mom[row] + else: + mom[row] = self.momentum*mom[row] - lr*wd*weight[row] - lr*self.rescale_grad*grad[row] + weight[row] += mom[row] + +def test_sparse_sgd(): + mx.random.seed(0) + opt1 = PySparseSGD + opt2 = mx.optimizer.SGD + shape = (3, 4) + kwargs = [{}, + {'momentum': 0.9}, + {'clip_gradient': 0.5}, + {'clip_gradient': 0.4, 'rescale_grad': 0.14}, + {'rescale_grad': 0.8}, + {'clip_gradient': 0.5, 'wd': 0.07}, + {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'wd': 0.03}, + {'rescale_grad': 0.8, 'wd': 0.05}, + {'clip_gradient': 0.5, 'momentum': 0.9}, + {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'momentum': 0.9}, + {'rescale_grad': 0.8, 'momentum': 0.9}, + {'clip_gradient': 0.5, 'wd': 0.07, 'momentum': 0.9}, + {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'wd': 0.03, 'momentum': 0.9}, + {'rescale_grad': 0.8, 'wd': 0.05, 'momentum': 0.9}] + for kwarg in kwargs: + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, w_stype='default_storage', g_stype='row_sparse') + # ADAM class PyAdam(mx.optimizer.Optimizer): @@ -354,3 +456,4 @@ def test_rms(): test_adam() test_rms() test_sgd() + test_sparse_sgd() diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py new file mode 100644 index 000000000000..224a5e008b3b --- /dev/null +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -0,0 +1,273 @@ +import os +import mxnet as mx +import numpy as np +import pickle as pkl +from mxnet.test_utils import * +from numpy.testing import assert_allclose +import numpy.random as rnd + +def assert_fcompex(f, *args, **kwargs): + prev_val = mx.test_utils.set_env_var("MXNET_EXEC_STORAGE_FALLBACK", "0", "1") + f(*args, **kwargs) + mx.test_utils.set_env_var("MXNET_EXEC_STORAGE_FALLBACK", prev_val) + +def rand_shape_2d(): + return (rnd.randint(1, 10), rnd.randint(1, 10)) + +def sparse_nd_ones(shape, stype): + return mx.nd.cast_storage(mx.nd.ones(shape), storage_type=stype) + +def check_sparse_nd_elemwise_binary(shapes, storage_types, f, g): + # generate inputs + nds = [] + for i, storage_type in enumerate(storage_types): + if storage_type == 'row_sparse': + nd, _ = rand_sparse_ndarray(shapes[i], storage_type) + elif storage_type == 'default_storage': + nd = mx.nd.array(random_arrays(shapes[i]), dtype = np.float32) + else: + assert(False) + nds.append(nd) + # check result + test = f(nds[0], nds[1]) + assert_almost_equal(test.asnumpy(), g(nds[0].asnumpy(), nds[1].asnumpy())) + +def test_sparse_nd_elemwise_add(): + num_repeats = 10 + g = lambda x,y: x + y + op = mx.nd.elemwise_add + for i in range(num_repeats): + shape = [rand_shape_2d()] * 2 + assert_fcompex(check_sparse_nd_elemwise_binary, + shape, ['default_storage'] * 2, op, g) + assert_fcompex(check_sparse_nd_elemwise_binary, + shape, ['default_storage', 'row_sparse'], op, g) + assert_fcompex(check_sparse_nd_elemwise_binary, + shape, ['row_sparse', 'row_sparse'], op, g) + +# Test a operator which doesn't implement FComputeEx +def test_sparse_nd_elementwise_fallback(): + num_repeats = 10 + g = lambda x,y: x + y + op = mx.nd.add_n + for i in range(num_repeats): + shape = [rand_shape_2d()] * 2 + check_sparse_nd_elemwise_binary(shape, ['default_storage'] * 2, op, g) + check_sparse_nd_elemwise_binary(shape, ['default_storage', 'row_sparse'], op, g) + check_sparse_nd_elemwise_binary(shape, ['row_sparse', 'row_sparse'], op, g) + +def test_sparse_nd_zeros(): + def check_sparse_nd_zeros(stype, shape): + zero = mx.nd.zeros(shape) + sparse_zero = mx.sparse_nd.zeros('row_sparse', shape) + assert_almost_equal(sparse_zero.asnumpy(), zero.asnumpy()) + + shape = rand_shape_2d() + check_sparse_nd_zeros('row_sparse', shape) + check_sparse_nd_zeros('csr', shape) + + +def test_sparse_nd_copy(): + def check_sparse_nd_copy(from_stype, to_stype): + shape = rand_shape_2d() + from_nd = rand_ndarray(shape, from_stype) + # copy to ctx + to_ctx = from_nd.copyto(default_context()) + # copy to stype + to_nd = rand_ndarray(shape, to_stype) + to_nd = from_nd.copyto(to_nd) + assert np.sum(np.abs(from_nd.asnumpy() != to_ctx.asnumpy())) == 0.0 + assert np.sum(np.abs(from_nd.asnumpy() != to_nd.asnumpy())) == 0.0 + + check_sparse_nd_copy('row_sparse', 'row_sparse') + check_sparse_nd_copy('row_sparse', 'default_storage') + check_sparse_nd_copy('default_storage', 'row_sparse') + check_sparse_nd_copy('default_storage', 'csr') + +def check_sparse_nd_prop_rsp(): + storage_type = 'row_sparse' + shape = rand_shape_2d() + nd, (v, idx) = rand_sparse_ndarray(shape, storage_type) + assert(nd._num_aux == 1) + assert(nd._indices.dtype == np.int32) + assert(nd.storage_type == 'row_sparse') + assert_almost_equal(nd._indices.asnumpy(), idx) + +def test_sparse_nd_basic(): + def check_rsp_creation(values, indices, shape): + rsp = mx.sparse_nd.row_sparse(values, indices, shape) + dns = mx.nd.zeros(shape) + dns[1] = mx.nd.array(values[0]) + dns[3] = mx.nd.array(values[1]) + assert_almost_equal(rsp.asnumpy(), dns.asnumpy()) + indices = mx.nd.array(indices).asnumpy() + assert_almost_equal(rsp._indices.asnumpy(), indices) + + def check_csr_creation(shape): + csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr') + assert_almost_equal(csr._indptr.asnumpy(), indptr) + assert_almost_equal(csr._indices.asnumpy(), indices) + assert_almost_equal(csr._values.asnumpy(), values) + + shape = (4,2) + values = np.random.rand(2,2) + indices = np.array([1,3]) + check_rsp_creation(values, indices, shape) + + values = mx.nd.array(np.random.rand(2,2)) + indices = mx.nd.array([1,3], dtype='int32') + check_rsp_creation(values, indices, shape) + + values = [[0.1, 0.2], [0.3, 0.4]] + indices = [1,3] + check_rsp_creation(values, indices, shape) + + check_csr_creation(shape) + check_sparse_nd_prop_rsp() + +def test_sparse_nd_setitem(): + def check_sparse_nd_setitem(storage_type, shape, dst): + x = mx.sparse_nd.zeros(storage_type, shape) + x[:] = dst + dst_nd = mx.nd.array(dst) if isinstance(dst, (np.ndarray, np.generic)) else dst + assert same(x.asnumpy(), dst_nd.asnumpy()) + + shape = rand_shape_2d() + for stype in ['row_sparse', 'csr']: + # ndarray assignment + check_sparse_nd_setitem(stype, shape, rand_ndarray(shape, 'default_storage')) + check_sparse_nd_setitem(stype, shape, rand_ndarray(shape, stype)) + # numpy assignment + check_sparse_nd_setitem(stype, shape, np.ones(shape)) + +def test_sparse_nd_slice(): + def check_sparse_nd_csr_slice(shape): + storage_type = 'csr' + A, _ = rand_sparse_ndarray(shape, storage_type) + A2 = A.asnumpy() + start = rnd.randint(0, shape[0] - 1) + end = rnd.randint(start + 1, shape[0]) + assert same(A[start:end].asnumpy(), A2[start:end]) + + shape = (rnd.randint(2, 10), rnd.randint(1, 10)) + check_sparse_nd_csr_slice(shape) + +def test_sparse_nd_equal(): + stype = 'csr' + shape = rand_shape_2d() + x = mx.sparse_nd.zeros(stype, shape) + y = sparse_nd_ones(shape, stype) + z = x == y + assert (z.asnumpy() == np.zeros(shape)).all() + z = 0 == x + assert (z.asnumpy() == np.ones(shape)).all() + +def test_sparse_nd_not_equal(): + stype = 'csr' + shape = rand_shape_2d() + x = mx.sparse_nd.zeros(stype, shape) + y = sparse_nd_ones(shape, stype) + z = x != y + assert (z.asnumpy() == np.ones(shape)).all() + z = 0 != x + assert (z.asnumpy() == np.zeros(shape)).all() + +def test_sparse_nd_greater(): + stype = 'csr' + shape = rand_shape_2d() + x = mx.sparse_nd.zeros(stype, shape) + y = sparse_nd_ones(shape, stype) + z = x > y + assert (z.asnumpy() == np.zeros(shape)).all() + z = y > 0 + assert (z.asnumpy() == np.ones(shape)).all() + z = 0 > y + assert (z.asnumpy() == np.zeros(shape)).all() + +def test_sparse_nd_greater_equal(): + stype = 'csr' + shape = rand_shape_2d() + x = mx.sparse_nd.zeros(stype, shape) + y = sparse_nd_ones(shape, stype) + z = x >= y + assert (z.asnumpy() == np.zeros(shape)).all() + z = y >= 0 + assert (z.asnumpy() == np.ones(shape)).all() + z = 0 >= y + assert (z.asnumpy() == np.zeros(shape)).all() + z = y >= 1 + assert (z.asnumpy() == np.ones(shape)).all() + +def test_sparse_nd_lesser(): + stype = 'csr' + shape = rand_shape_2d() + x = mx.sparse_nd.zeros(stype, shape) + y = sparse_nd_ones(shape, stype) + z = y < x + assert (z.asnumpy() == np.zeros(shape)).all() + z = 0 < y + assert (z.asnumpy() == np.ones(shape)).all() + z = y < 0 + assert (z.asnumpy() == np.zeros(shape)).all() + +def test_sparse_nd_lesser_equal(): + stype = 'csr' + shape = rand_shape_2d() + x = mx.sparse_nd.zeros(stype, shape) + y = sparse_nd_ones(shape, stype) + z = y <= x + assert (z.asnumpy() == np.zeros(shape)).all() + z = 0 <= y + assert (z.asnumpy() == np.ones(shape)).all() + z = y <= 0 + assert (z.asnumpy() == np.zeros(shape)).all() + z = 1 <= y + assert (z.asnumpy() == np.ones(shape)).all() + +def test_sparse_nd_binary(): + N = 100 + def check_binary(fn): + for _ in range(N): + ndim = 2 + oshape = np.random.randint(1, 6, size=(ndim,)) + bdim = 2 + lshape = list(oshape) + rshape = list(oshape[ndim-bdim:]) + for i in range(bdim): + sep = np.random.uniform(0, 1) + if sep < 0.33: + lshape[ndim-i-1] = 1 + elif sep < 0.66: + rshape[bdim-i-1] = 1 + lhs = np.random.normal(0, 1, size=lshape) + rhs = np.random.normal(0, 1, size=rshape) + lhs_nd = mx.nd.array(lhs).to_csr() + rhs_nd = mx.nd.array(rhs).to_csr() + assert_allclose(fn(lhs, rhs), + fn(lhs_nd, rhs_nd).asnumpy(), + rtol=1e-4, atol=1e-4) + + #check_binary(lambda x, y: x + y) + check_binary(lambda x, y: x - y) + check_binary(lambda x, y: x * y) + check_binary(lambda x, y: x / y) + check_binary(lambda x, y: x > y) + check_binary(lambda x, y: x < y) + check_binary(lambda x, y: x >= y) + check_binary(lambda x, y: x <= y) + check_binary(lambda x, y: x == y) + +def test_sparse_nd_negate(): + npy = np.random.uniform(-10, 10, rand_shape_2d()) + arr = mx.nd.array(npy).to_csr() + assert_almost_equal(npy, arr.asnumpy()) + assert_almost_equal(-npy, (-arr).asnumpy()) + + # a final check to make sure the negation (-) is not implemented + # as inplace operation, so the contents of arr does not change after + # we compute (-arr) + assert_almost_equal(npy, arr.asnumpy()) + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py new file mode 100644 index 000000000000..978737028c98 --- /dev/null +++ b/tests/python/unittest/test_sparse_operator.py @@ -0,0 +1,198 @@ +# pylint: skip-file +import numpy as np +import mxnet as mx +import scipy as sp +from numpy.testing import assert_allclose +from mxnet.test_utils import * + +def check_elemwise_add_ex(lhs_stype, rhs_stype, shape, lhs_grad_stype=None, rhs_grad_stype=None): + lhs = mx.symbol.Variable('lhs', storage_type=lhs_stype) + rhs = mx.symbol.Variable('rhs', storage_type=rhs_stype) + if lhs_grad_stype is not None: + lhs._set_attr(grad_stype_hint=str(lhs_grad_stype)) + if rhs_grad_stype is not None: + rhs._set_attr(grad_stype_hint=str(rhs_grad_stype)) + + lhs_nd = rand_ndarray(shape, lhs_stype) + rhs_nd = rand_ndarray(shape, rhs_stype) + lhs_np = lhs_nd.asnumpy() + rhs_np = rhs_nd.asnumpy() + + out_np = lhs_np + rhs_np + test = mx.symbol.elemwise_add(lhs, rhs) + location = {'lhs': lhs_nd, 'rhs': rhs_nd} + check_symbolic_forward(test, location, [out_np]) + check_numeric_gradient(test, location) + check_symbolic_backward(test, location, [out_np], [out_np, out_np]) + + +def test_elemwise_add_ex(): + shape = (rnd.randint(1, 10), rnd.randint(1, 10)) + check_elemwise_add_ex('default_storage', 'default_storage', shape) + # TODO(haibin/jun) enable these tests when Dns -> Rsp (compact) is implemented. + #check_elemwise_add_ex('default_storage', 'row_sparse', shape) + #check_elemwise_add_ex('row_sparse', 'default_storage', shape) + #check_elemwise_add_ex('row_sparse', 'row_sparse', shape, + # lhs_grad_stype='row_sparse', rhs_grad_stype='row_sparse') + + +# TODO(haibin) randomize this test +def test_elemwise_add_ex_multiple_stages(): + # prep data + shape = (4, 2) + ds_np = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + sp_np1 = np.array([[5, 10], [0, 0], [0, 0], [0, 0]]) + sp_np2 = np.array([[0, 0], [5, 10], [0, 0], [0, 0]]) + + val1 = mx.nd.array([[5, 10]]); + val2 = mx.nd.array([[5, 10]]); + idx1 = mx.nd.array([0], dtype=np.int32); + idx2 = mx.nd.array([1], dtype=np.int32); + sp_nd1 = mx.sparse_nd.row_sparse(val1, idx1, shape) + sp_nd2 = mx.sparse_nd.row_sparse(val2, idx2, shape) + ds_nd = mx.nd.array(ds_np) + + # sparse + sparse = sparse + sp_data1 = mx.symbol.Variable('sp_data1', storage_type='row_sparse') + sp_data2 = mx.symbol.Variable('sp_data2', storage_type='row_sparse') + ds_data = mx.symbol.Variable('ds_data') + plus = mx.symbol.elemwise_add(sp_data1, sp_data2, name='plus') + # sparse + dense = dense + test = mx.symbol.elemwise_add(plus, ds_data) + check_symbolic_forward(test, {'sp_data1': sp_nd1, 'sp_data2': sp_nd2, + 'ds_data': ds_nd}, [sp_np1 + sp_np2 + ds_np]) + + arr_grads = [mx.nd.zeros(shape) for i in range(3)] + exec_test = test.bind(default_context(), args={'sp_data1': sp_nd1, 'sp_data2': sp_nd2, + 'ds_data': ds_nd}, args_grad=arr_grads) + exec_test.forward(is_train=True) + assert_almost_equal(exec_test.outputs[0].asnumpy(), sp_np1 + sp_np2 + ds_np) + exec_test.backward(out_grads=exec_test.outputs) + assert_almost_equal(arr_grads[0].asnumpy(), arr_grads[1].asnumpy()) + +# TODO(haibin) also add test for backward pass +def test_cast_storage_ex(): + def test_rsp_to_dns(shape): + rsp, (data, row_idx) = rand_sparse_ndarray(shape, 'row_sparse') + dns_out = mx.nd.cast_storage(rsp, storage_type='default_storage') + dns_expected = np.zeros(shape, dtype=default_dtype()) + if row_idx is not None: + for k, v in enumerate(row_idx): + dns_expected[v, :] = data[k] + assert same(dns_out.asnumpy(), dns_expected) + + def test_dns_to_rsp(shape): + dns_in = rand_ndarray(shape, 'default_storage') + rsp_out = mx.nd.cast_storage(mx.nd.array(dns_in, dtype=default_dtype()), storage_type='row_sparse') + ret = mx.nd.cast_storage(rsp_out, storage_type='default_storage') + assert same(ret.asnumpy(), dns_in.asnumpy()) + + def test_csr_to_dns(shape): + csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr') + mx_dns = csr.to_dense() + np_dns = sp.sparse.csr_matrix((values, indices, indptr), shape).todense() + assert_almost_equal(mx_dns.asnumpy(), np_dns) + + def test_dns_to_csr(dns_in): + dns_in = np.array(dns_in) + csr_out = mx.nd.cast_storage(mx.nd.array(dns_in, dtype=default_dtype()), storage_type='csr') + ret = mx.nd.cast_storage(csr_out, storage_type='default_storage') + assert same(ret.asnumpy(), dns_in) + + shape = (rnd.randint(1, 10), rnd.randint(1, 10)) + test_rsp_to_dns(shape) + test_dns_to_rsp(shape) + test_csr_to_dns((4, 4)) + test_dns_to_csr([[0, 1, 0], [0, 2, 0], [3, 0, 0], [0, 0, 4], [5, 6, 0], [0, 0, 7]]) + + +# TODO(junwu): The backward of the operator dot cannot be tested for now +# since the backend function CopyFromTo does not support taking two arguments +# of the different storage types. Will add backward test after removing this +# restriction on CopyFromTo(@haibin). Nevertheless, both backward and forward use +# the same impl function of dot(csr, dns) = rsp and it has been tested +# in the forward test cases as the following. +def test_sparse_dot(): + def test_dot_csr_dns(csr_shape, dns_shape, trans_csr): + dns1 = rand_ndarray(csr_shape, 'default_storage') + dns2 = rand_ndarray(dns_shape, 'default_storage') + csr = mx.nd.cast_storage(dns1, storage_type='csr') + out = mx.nd.dot(csr, dns2, transpose_a=trans_csr) + assert out.storage_type == 'default_storage' + out_expected = mx.nd.dot(dns1, dns2, transpose_a=trans_csr) + out_np = out_expected.asnumpy() + backward_trans = not trans_csr + rhs_backward_grad = mx.nd.dot(dns1, out_expected, transpose_a=backward_trans).asnumpy() + assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5) + + # test symbolic forward + lhs = mx.symbol.Variable('lhs', storage_type='csr') + rhs = mx.symbol.Variable('rhs', storage_type='default_storage') + # TODO(haibin) since backward op is not fully implemented, here we add a dense zero ndarray + # so that the output gradient is dense. + zeros = mx.symbol.Variable('zero', storage_type='default_storage') + + sym_dot = mx.symbol.dot(lhs, rhs, transpose_a=trans_csr) + test = mx.symbol.elemwise_add(sym_dot, zeros) + location = {'lhs': csr, 'rhs': dns2, 'zero': mx.nd.zeros(out_expected.shape)} + expected = {'rhs': rhs_backward_grad, 'zero': out_np} + # dot(lhs, rhs) + zeros + check_symbolic_forward(test, location, [out_expected.asnumpy()], rtol=1e-3, atol=1e-4) + check_symbolic_backward(test, location, [out_np], expected, + grad_req={'lhs': 'null', 'rhs': 'write', 'zero': 'write'}, + rtol=1e-3, atol=1e-4) + + lhs_shape = (rnd.randint(1, 10), rnd.randint(1, 10)) + test_dot_csr_dns(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), False) + test_dot_csr_dns(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), True) + + +def test_sparse_embedding(): + in_dim = 10 + out_dim = 4 + batch = 24 + + data = mx.sym.Variable("data", dtype=np.int32) + embed = mx.sym.SparseEmbedding(data=data, input_dim=in_dim, output_dim=out_dim, name="embed") + exe_test = embed.simple_bind(default_context(), grad_req={'data': 'null', 'embed_weight': 'write'}, + data=(batch,)) + arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays)) + grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays)) + np_data = np.random.randint(low=0, high=in_dim, size=batch) + np_weight = np.random.uniform(-0.01, 0.01, arg_map["embed_weight"].shape) + np_onehot = np.zeros((batch, in_dim)) + np_onehot[np.arange(batch), np_data] = 1.0 + # forward + arg_map["data"][:] = np_data + arg_map["embed_weight"][:] = np_weight + exe_test.forward(is_train=True) + assert_almost_equal(exe_test.outputs[0].asnumpy(), np.dot(np_onehot, np_weight)) + # backward + np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape) + grad = mx.nd.zeros(np_grad.shape) + grad[:] = np_grad + exe_test.backward([grad]) + assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(np_onehot.T, np_grad), atol=1e-5) + +def test_sparse_slice(): + def check_csr_slice(shape, slice_input): + storage_type = 'csr' + A, _ = rand_sparse_ndarray(shape, storage_type) + B = A._slice(1, shape[0] - 1) if slice_input else A + np = B.asnumpy() + begin = rnd.randint(0, B.shape[0] - 1) + end = rnd.randint(begin + 1, B.shape[0]) + nd_slice = mx.nd.crop(B, begin=begin, end=end) + assert same(nd_slice.asnumpy(), np[begin:end]), (nd_slice.asnumpy(), np[begin:end]) + + shape = (rnd.randint(7, 15), rnd.randint(1, 10)) + check_csr_slice(shape, True) + check_csr_slice(shape, False) + +if __name__ == '__main__': + test_elemwise_add_ex() + test_elemwise_add_ex_multiple_stages() + test_cast_storage_ex() + test_sparse_dot() + test_sparse_embedding() + test_sparse_slice()