From cffb161fe1f399bda1fed427087ba5b57e803b77 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Mon, 4 Dec 2017 17:18:27 +0000 Subject: [PATCH 01/11] Add operator for dot(dns, csr) = csr --- benchmark/python/sparse/dot.py | 53 ++++-- src/operator/tensor/dot-inl.h | 163 ++++++++++++++++++ tests/python/unittest/test_sparse_operator.py | 27 +++ 3 files changed, 227 insertions(+), 16 deletions(-) diff --git a/benchmark/python/sparse/dot.py b/benchmark/python/sparse/dot.py index 164e50aef051..bf4e4feaa106 100644 --- a/benchmark/python/sparse/dot.py +++ b/benchmark/python/sparse/dot.py @@ -275,7 +275,10 @@ def bench_dot(lhs_shape, rhs_shape, lhs_stype, rhs_stype, # Create matrix instances lhs_nd = rand_ndarray(lhs_shape, lhs_stype, density=lhs_den, distribution=distribution) # only uniform distribution supported for rhs - rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den, distribution="uniform") + if rhs_stype == 'csr': + rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den, distribution=distribution) + else: + rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den, distribution="uniform") lhs_dns = None rhs_dns = None dense_cost = None @@ -337,27 +340,41 @@ def print_benchmark_info(lhs, rhs, lhs_trans, fw): def run_benchmark(ctx=None, lhs="csr", lhs_trans=False, rhs="dns", fw="mxnet", rhs_density=1, distribution="uniform"): - if lhs != "csr": - raise ValueError("Value other than csr for lhs not supported") + if rhs_density > 1 or rhs_density < 0: - raise ValueError("rhs_density has to be between 0 and 1") + raise ValueError("Value other than csr for lhs not supported") print_benchmark_info(lhs, rhs, lhs_trans, fw) + if rhs == "csr": + lhs_stype = "default" + rhs_stype = "csr" + assert (lhs_stype == 'default'), "Only dot(default, csr) supported" + # Arrange dimensions according to use case. For below csr will have num_rows << num_cols + feature_dim_list = data_dict['batch_size'] + batch_size_list = data_dict['m'] + output_dim_list = data_dict['feature_dim'] + density_list = data_dict['density'] + default_output_index = data_dict['default_index']['feature_dim'] + default_density_index = data_dict['default_index']['density'] + default_feature_index = data_dict['default_index']['batch_size'] + default_batch_size_index = data_dict['default_index']['output_dim'] + num_repeat = data_dict['num_repeat'] - lhs_stype = "csr" - rhs_stype = "row_sparse" if rhs == "rsp" else "default" + else: + lhs_stype = "csr" + rhs_stype = "row_sparse" if rhs == "rsp" else "default" - feature_dim_list = data_dict['feature_dim'] - output_dim_list = data_dict['m'] - batch_size_list = data_dict['batch_size'] - density_list = data_dict['density'] + feature_dim_list = data_dict['feature_dim'] + output_dim_list = data_dict['m'] + batch_size_list = data_dict['batch_size'] + density_list = data_dict['density'] - default_output_index = data_dict['default_index']['output_dim'] - default_batch_size_index = data_dict['default_index']['batch_size'] - default_feature_index = data_dict['default_index']['feature_dim'] - default_density_index = data_dict['default_index']['density'] - num_repeat = data_dict['num_repeat'] + default_output_index = data_dict['default_index']['output_dim'] + default_batch_size_index = data_dict['default_index']['batch_size'] + default_feature_index = data_dict['default_index']['feature_dim'] + default_density_index = data_dict['default_index']['density'] + num_repeat = data_dict['num_repeat'] for output_dim in output_dim_list: if lhs_trans: @@ -403,7 +420,7 @@ def run_benchmark(ctx=None, lhs="csr", lhs_trans=False, rhs="dns", fw="mxnet", r feature_dim_list[default_feature_index]), (output_row_dim, output_dim_list[default_output_index]), - lhs_stype, rhs_stype, density, rhs_density, lhs_trans, ctx, + lhs_stype, rhs_stype, density, density, lhs_trans, ctx, num_repeat=num_repeat, fw=fw, distribution=distribution) check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(ARGS.num_omp_threads))) @@ -423,6 +440,10 @@ def run_benchmark(ctx=None, lhs="csr", lhs_trans=False, rhs="dns", fw="mxnet", r rhs="rsp", lhs_trans=False, fw="mxnet", rhs_density=0.05, distribution=distribution) + run_benchmark(context, lhs="default", + rhs="csr", lhs_trans=False, + fw="mxnet", rhs_density=0.001, + distribution=distribution) if not ARGS.gpu: run_benchmark(context, lhs="csr", rhs="default", lhs_trans=False, diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 2432703291f9..47455bcac044 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -231,6 +231,12 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx); } + if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && + !param.transpose_a && !param.transpose_b) { + // dns, csr -> csr + dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, + DispatchMode::kFComputeEx); + } if (!dispatched) { dispatch_fallback(out_attrs, dispatch_mode); LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs); @@ -527,6 +533,69 @@ struct DotCsrTransRspRspByRowBlocks { } }; +/*! + * \brief CPU Kernel of PopulateCsrForNNC + * Parallelization by individual rows + */ +struct PopulateCsrForNNC { + /*! + * \brief + * \param i the i-th thread + * \param nnc_idx all non zero column indexes + * \param indptr_out indptr array for output + * \param col_idx_out column indices for output + * \param nnc number of non zero columns in the output + * \param num_rows_l number of rows in lhs + */ + template + MSHADOW_CINLINE static void Map(int i, const CType* nnc_idx, + IType* indptr_out, CType* col_idx_out, + const nnvm::dim_t nnc, + const nnvm::dim_t num_rows_l) { + const CType start_idx = i * nnc; + nnvm::dim_t cur = 0; + indptr_out[i] = start_idx; + if (i == static_cast(num_rows_l - 1)) indptr_out[i + 1] = indptr_out[i] + nnc; + for (IType idx = start_idx; idx < (start_idx + nnc); idx++) { + col_idx_out[idx] = nnc_idx[cur++]; + } + } +}; + +/*! + * \brief CPU Impl of dot(dns, csr) = csr + */ +struct DotDnsCsrCsrByRowBlocks { + template + MSHADOW_CINLINE static void Map( + int i, DType* out, const DType* data_l, const IType* indptr_r, + const CType* col_idx_r, const DType* data_r, const nnvm::dim_t seg_len, + const IType num_rows_r, const IType num_rows_l, + const nnvm::dim_t num_cols, const nnvm::dim_t nnc, + const CType* prefix_sum) { + using nnvm::dim_t; + const dim_t seg_start = i * seg_len; + if (seg_start >= num_rows_l) return; + const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l); + + for (dim_t j = seg_start; j < seg_end; j++) { + for (dim_t k = 0; k < num_rows_r; k++) { + const dim_t working_idx = j * num_rows_r + k; + const DType val = data_l[working_idx]; + if (indptr_r[k] == indptr_r[k + 1]) continue; + const dim_t row_start = j * nnc; + for (dim_t cur = indptr_r[k]; cur < indptr_r[k + 1]; cur++) { + dim_t cur_col_idx_r = col_idx_r[cur]; + const dim_t out_idx = row_start + prefix_sum[cur_col_idx_r] - 1; + out[out_idx] += val * data_r[cur]; + } + } + } + } +}; + + + /*! * \brief CPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2 */ @@ -811,6 +880,95 @@ inline void DotCsrRspRspImpl(const OpContext& ctx, }); } +/* + * \brief CPU Impl of dot(dns, csr) = csr + */ +inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev, + const TBlob& lhs, const NDArray& rhs, + const OpReqType req, NDArray* ret) { + if (kNullOp == req) return; + CHECK_EQ(rhs.storage_type(), kCSRStorage); + if (!rhs.storage_initialized()) return; + + using namespace mshadow; + using namespace mshadow::expr; + using nnvm::dim_t; + + /*Initialize data structures*/ + mshadow::Stream* s = ctx.get_stream(); + const NDArray& out = *ret; + const TBlob data_l = lhs; + const TBlob data_r = rhs.data(); + const TBlob indptr_r = rhs.aux_data(csr::kIndPtr); + const TBlob col_idx_r = rhs.aux_data(csr::kIdx); + + MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, { // colidx type + /* Allocate workspace */ + CType num_cols_out = out.shape()[1]; + CType rhs_data_size = static_cast(col_idx_r.shape_.Size()); + size_t workspace_size = 2 * num_cols_out * sizeof(CType); + Tensor workspace = + ctx.requested[0].get_space_typed( + Shape1(workspace_size), s); + CType* col_flg = reinterpret_cast(workspace.dptr_); + + CType* prefix_sum = col_flg; + CType* nnc_idx = prefix_sum + num_cols_out; + + /* Set the column flags for nnz columns */ + mxnet_op::Kernel::Launch(s, num_cols_out, + col_flg); + mxnet_op::Kernel::Launch( + s, rhs_data_size, col_flg, col_idx_r.dptr()); + + /* 1. Calculate prefix sum from col flgs + * 2. Storage all non zero column indexes in nnc_idx + */ + CType cur = 0; + prefix_sum[0] = col_flg[0]; + if (prefix_sum[0]) nnc_idx[cur++] = 0; + for (CType i = 1; i < num_cols_out; i++) { + prefix_sum[i] = prefix_sum[i - 1] + col_flg[i]; + if (prefix_sum[i] > prefix_sum[i - 1]) nnc_idx[cur++] = i; + } + + /* Allocate aux data for out */ + IType num_rows_l = lhs.shape_[0]; + dim_t nnc = prefix_sum[num_cols_out - 1]; + dim_t nnz = nnc * num_rows_l; + out.CheckAndAllocAuxData(csr::kIndPtr, Shape1(num_rows_l + 1)); + out.CheckAndAllocAuxData(csr::kIdx, Shape1(nnz)); + out.CheckAndAllocData(Shape1(nnz)); + + /* Set csr indptr and index according to nnc_idx*/ + IType* indptr_out = out.aux_data(csr::kIndPtr).dptr(); + CType* col_idx_out = out.aux_data(csr::kIdx).dptr(); + DType* data_out = out.data().dptr(); + mxnet_op::Kernel::Launch( + s, num_rows_l, nnc_idx, indptr_out, col_idx_out, nnc, num_rows_l); + mxnet_op::Kernel::Launch(s, nnz, data_out); + + if (nnc == 0) { + return; + } + + dim_t num_threads = mxnet_op::get_num_threads(num_rows_l); + dim_t seg_len = (num_rows_l + num_threads - 1) / num_threads; + + IType num_rows_r = rhs.shape()[0]; + mxnet_op::Kernel::Launch( + s, num_threads, data_out, data_l.dptr(), + indptr_r.dptr(), col_idx_r.dptr(), + data_r.dptr(), seg_len, num_rows_r, num_rows_l, num_cols_out, + nnc, prefix_sum); + + }); + }); + }); +} + inline bool DotShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { @@ -886,6 +1044,11 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs, && out_stype == kRowSparseStorage && !param.transpose_b) { NDArray ret = outputs[0]; DotCsrRspRspImpl(ctx, xpu(), inputs[0], inputs[1], req[0], param.transpose_a, &ret); + } else if (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && + out_stype == kCSRStorage && + !(param.transpose_a || param.transpose_b)) { + NDArray ret = outputs[0]; + DotDnsCsrCsrImpl(ctx, xpu(), inputs[0].data(), inputs[1], req[0], &ret); } else { LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs); } diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index a08b6187bc7d..6712323b4a7b 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -1223,6 +1223,31 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, lhs_density, rhs_de grad_req={'lhs': 'null', 'rhs': 'write'}, rtol=1e-3, atol=1e-4) + def test_dot_dns_csr(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_lhs=False, trans_rhs=False): + lhs_nd = rand_ndarray(lhs_shape, stype='default', density=lhs_density) + rhs_nd = rand_ndarray(rhs_shape, stype='csr', density=rhs_density) + rhs_dns = rhs_nd.tostype('default') + + out = mx.nd.sparse.dot(lhs_nd, rhs_nd, transpose_a=trans_lhs, transpose_b=trans_rhs) + out_dns = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs, transpose_b=trans_rhs) + out_np = out_dns.asnumpy() + assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5) + + # test symbolic forward + lhs = mx.symbol.Variable('lhs', stype='default') + rhs = mx.symbol.Variable('rhs', stype='csr') + out = mx.symbol.sparse.dot(lhs, rhs, transpose_a=trans_lhs, transpose_b=trans_rhs) + location = {'lhs': lhs_nd, 'rhs': rhs_nd} + check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4) + + # test symbolic backward + backward_trans = not trans_lhs + rhs_backward_grad = mx.nd.dot(lhs_nd, out_dns, transpose_a=backward_trans).asnumpy() + expected = {'rhs': rhs_backward_grad} + check_symbolic_backward(out, location, [out_np], expected, + grad_req={'lhs': 'null', 'rhs': 'write'}, + rtol=1e-3, atol=1e-4) + def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols): """Test for nnr_out = 0. Before the fix, the test would fail.""" lhs = mx.nd.zeros(lhs_shape) @@ -1248,10 +1273,12 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols): test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True, lhs_d, rhs_d) # (vector kernel) test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False, lhs_d, rhs_d) # test gpu SpMM test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True, lhs_d, rhs_d) # (scalar kernel) + test_dot_dns_csr(lhs_shape, (lhs_shape[1], rnd.randint(500, 1000)), lhs_d, lhs_d) for rhs_d in density: test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, lhs_d, rhs_d) test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, lhs_d, rhs_d) + test_sparse_dot_zero_output(rand_shape_2d(50, 200), False, 40) test_sparse_dot_zero_output(rand_shape_2d(50, 200), True, 40) From 73a9f3dd89b7deb9d6213fad366be2fe8e6abd29 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Mon, 4 Dec 2017 17:24:35 +0000 Subject: [PATCH 02/11] Fix whitespace --- src/operator/tensor/dot-inl.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 47455bcac044..bdc05a681eb7 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -963,7 +963,6 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev, indptr_r.dptr(), col_idx_r.dptr(), data_r.dptr(), seg_len, num_rows_r, num_rows_l, num_cols_out, nnc, prefix_sum); - }); }); }); From 362af57dc418724528c1e550052c1154c5f9e96b Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Mon, 4 Dec 2017 17:31:00 +0000 Subject: [PATCH 03/11] Add comments --- src/operator/tensor/dot-inl.h | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index bdc05a681eb7..0df8db5ae688 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -566,6 +566,15 @@ struct PopulateCsrForNNC { * \brief CPU Impl of dot(dns, csr) = csr */ struct DotDnsCsrCsrByRowBlocks { + /*! + * \brief + * \param i the i-th thread + * \param num_rows_r number of rows in rhs + * \param num_rows_l number of rows in lhs + * \param num_cols number of columns in output + * \param nnc number of non zero columns + */ + template MSHADOW_CINLINE static void Map( int i, DType* out, const DType* data_l, const IType* indptr_r, From a1b3db0e0c9335e1a1d16674d0c803bcc1edc1a6 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Tue, 5 Dec 2017 05:35:31 +0000 Subject: [PATCH 04/11] Add comments and fix error message --- benchmark/python/sparse/dot.py | 2 +- src/operator/tensor/dot-inl.h | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmark/python/sparse/dot.py b/benchmark/python/sparse/dot.py index bf4e4feaa106..5cfd540c04be 100644 --- a/benchmark/python/sparse/dot.py +++ b/benchmark/python/sparse/dot.py @@ -342,7 +342,7 @@ def run_benchmark(ctx=None, lhs="csr", lhs_trans=False, rhs="dns", fw="mxnet", r distribution="uniform"): if rhs_density > 1 or rhs_density < 0: - raise ValueError("Value other than csr for lhs not supported") + raise ValueError("rhs_density has to be between 0 and 1") print_benchmark_info(lhs, rhs, lhs_trans, fw) diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 0df8db5ae688..d69f2acb679b 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -536,6 +536,8 @@ struct DotCsrTransRspRspByRowBlocks { /*! * \brief CPU Kernel of PopulateCsrForNNC * Parallelization by individual rows + * Populates the indptr and indices array + * based on number of non zero columns */ struct PopulateCsrForNNC { /*! From 3fad0200156220a85453dfc6de466412aa813588 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Sat, 9 Dec 2017 19:58:54 +0000 Subject: [PATCH 05/11] Fixes for dot dns csr --- src/operator/tensor/dot-inl.h | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index d69f2acb679b..a4e16136d7c2 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -234,8 +234,11 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && !param.transpose_a && !param.transpose_b) { // dns, csr -> csr + const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask; + const DispatchMode dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback + : DispatchMode::kFComputeEx; dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, - DispatchMode::kFComputeEx); + dispatch_ex); } if (!dispatched) { dispatch_fallback(out_attrs, dispatch_mode); @@ -898,9 +901,10 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev, const TBlob& lhs, const NDArray& rhs, const OpReqType req, NDArray* ret) { if (kNullOp == req) return; - CHECK_EQ(rhs.storage_type(), kCSRStorage); - if (!rhs.storage_initialized()) return; + CHECK_EQ(req, kWriteTo); + + CHECK_EQ(rhs.storage_type(), kCSRStorage); using namespace mshadow; using namespace mshadow::expr; using nnvm::dim_t; @@ -912,6 +916,11 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev, const TBlob data_r = rhs.data(); const TBlob indptr_r = rhs.aux_data(csr::kIndPtr); const TBlob col_idx_r = rhs.aux_data(csr::kIdx); + const dim_t out_data_size = lhs.shape_[0] * rhs.shape()[1]; + if (!rhs.storage_initialized()) { + FillZerosCsrImpl(s, *ret); + } + MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, { // data type MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type @@ -965,8 +974,8 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev, return; } - dim_t num_threads = mxnet_op::get_num_threads(num_rows_l); - dim_t seg_len = (num_rows_l + num_threads - 1) / num_threads; + const dim_t num_threads = mxnet_op::get_num_threads(num_rows_l); + const dim_t seg_len = (num_rows_l + num_threads - 1) / num_threads; IType num_rows_r = rhs.shape()[0]; mxnet_op::Kernel::Launch( From 2521627c954905848991b409955266ad56160841 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Sun, 10 Dec 2017 04:33:44 +0000 Subject: [PATCH 06/11] Fixes --- src/operator/tensor/dot-inl.h | 41 +++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index a4e16136d7c2..d8b7022b75ae 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -202,22 +202,25 @@ void DotBackward_(const nnvm::NodeAttrs& attrs, inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 2U); CHECK_EQ(out_attrs->size(), 1U); const DotParam& param = nnvm::get(attrs.parsed); - // csr has many zero columns, so the result of dot(csr.T, matrix) should be rsp + // csr has many zero columns, so the result of dot(csr.T, matrix) should be + // rsp const auto& lhs_stype = in_attrs->at(0); const auto& rhs_stype = in_attrs->at(1); auto& out_stype = out_attrs->at(0); bool dispatched = false; bool only_lhs_transpose = param.transpose_a && !param.transpose_b; - bool rhs_rsp_or_dns = rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage; - if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kDefaultStorage) { + bool rhs_rsp_or_dns = + rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage; + if (!dispatched && lhs_stype == kDefaultStorage && + rhs_stype == kDefaultStorage) { // dns, dns -> dns - dispatched = storage_type_assign(&out_stype, kDefaultStorage, - dispatch_mode, DispatchMode::kFCompute); + dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, + DispatchMode::kFCompute); } if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose && (rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage)) { @@ -228,17 +231,16 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, if (!dispatched && lhs_stype == kCSRStorage && rhs_rsp_or_dns && !param.transpose_a && !param.transpose_b) { // csr, rsp/dns -> dns - dispatched = storage_type_assign(&out_stype, kDefaultStorage, - dispatch_mode, DispatchMode::kFComputeEx); + dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, + DispatchMode::kFComputeEx); } if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && !param.transpose_a && !param.transpose_b) { // dns, csr -> csr - const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask; - const DispatchMode dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback - : DispatchMode::kFComputeEx; - dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, - dispatch_ex); + if (dev_mask == mshadow::cpu::kDevMask) { + dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, + DispatchMode::kFComputeEx); + } } if (!dispatched) { dispatch_fallback(out_attrs, dispatch_mode); @@ -897,14 +899,15 @@ inline void DotCsrRspRspImpl(const OpContext& ctx, /* * \brief CPU Impl of dot(dns, csr) = csr */ -inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev, +template +inline void DotDnsCsrCsrImpl(const OpContext& ctx, const TBlob& lhs, const NDArray& rhs, const OpReqType req, NDArray* ret) { if (kNullOp == req) return; CHECK_EQ(req, kWriteTo); - CHECK_EQ(rhs.storage_type(), kCSRStorage); + using namespace mshadow; using namespace mshadow::expr; using nnvm::dim_t; @@ -918,10 +921,10 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev, const TBlob col_idx_r = rhs.aux_data(csr::kIdx); const dim_t out_data_size = lhs.shape_[0] * rhs.shape()[1]; if (!rhs.storage_initialized()) { - FillZerosCsrImpl(s, *ret); + FillZerosCsrImpl(s, *ret); + return; } - MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, { // data type MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, { // colidx type @@ -1067,7 +1070,7 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs, out_stype == kCSRStorage && !(param.transpose_a || param.transpose_b)) { NDArray ret = outputs[0]; - DotDnsCsrCsrImpl(ctx, xpu(), inputs[0].data(), inputs[1], req[0], &ret); + DotDnsCsrCsrImpl(ctx, inputs[0].data(), inputs[1], req[0], &ret); } else { LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs); } From 7077c53645e8561b1880796f871f570d9253e6dc Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Sun, 10 Dec 2017 05:36:19 +0000 Subject: [PATCH 07/11] Remove non required statements --- src/operator/tensor/dot-inl.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index d8b7022b75ae..9eab0c2e9441 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -919,7 +919,6 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, const TBlob data_r = rhs.data(); const TBlob indptr_r = rhs.aux_data(csr::kIndPtr); const TBlob col_idx_r = rhs.aux_data(csr::kIdx); - const dim_t out_data_size = lhs.shape_[0] * rhs.shape()[1]; if (!rhs.storage_initialized()) { FillZerosCsrImpl(s, *ret); return; From fb20f4518e15f45c6495a83efd52044c9b9cbaa3 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Wed, 13 Dec 2017 01:20:12 +0000 Subject: [PATCH 08/11] Add fallback for GPU --- src/operator/tensor/dot-inl.h | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 9eab0c2e9441..eb7ba5b44366 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -237,10 +237,11 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && !param.transpose_a && !param.transpose_b) { // dns, csr -> csr - if (dev_mask == mshadow::cpu::kDevMask) { - dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, - DispatchMode::kFComputeEx); - } + const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask; + const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback + : DispatchMode::kFComputeEx; + dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, + dispatch_ex); } if (!dispatched) { dispatch_fallback(out_attrs, dispatch_mode); From 44789d288bfcb9d8225a0d304083e6761c104350 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Wed, 13 Dec 2017 01:35:14 +0000 Subject: [PATCH 09/11] Remove unused if --- include/mxnet/ndarray.h | 5 ++++- src/operator/tensor/dot-inl.h | 4 ---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 8398b7bf7291..b65f1acf9f20 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -305,7 +305,10 @@ class NDArray { bool fresh_out_grad() const; /*! \return updated grad state in entry_ */ void set_fresh_out_grad(bool state) const; - // returns true if a sparse ndarray's aux_data and storage are initialized + /*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized + * Returns false if the indices array shape is inconsistent + * or the indices array is empty(nnz = 0) for csr/row_sparse + */ inline bool storage_initialized() const { if (is_none()) return false; auto stype = storage_type(); diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index eb7ba5b44366..01b1387c57bb 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -973,10 +973,6 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, s, num_rows_l, nnc_idx, indptr_out, col_idx_out, nnc, num_rows_l); mxnet_op::Kernel::Launch(s, nnz, data_out); - if (nnc == 0) { - return; - } - const dim_t num_threads = mxnet_op::get_num_threads(num_rows_l); const dim_t seg_len = (num_rows_l + num_threads - 1) / num_threads; From d6f13a38470aade7e19608277ab6c303fc2762ee Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 4 Jan 2018 00:21:36 +0000 Subject: [PATCH 10/11] Fix comments and casting --- include/mxnet/ndarray.h | 4 ++-- src/operator/tensor/dot-inl.h | 6 ++++-- tests/python/unittest/test_sparse_operator.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index b65f1acf9f20..a18d2daec8c3 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -306,8 +306,8 @@ class NDArray { /*! \return updated grad state in entry_ */ void set_fresh_out_grad(bool state) const; /*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized - * Returns false if the indices array shape is inconsistent - * or the indices array is empty(nnz = 0) for csr/row_sparse + * Throws an exception if the indices array shape is inconsistent + * Returns false if the indices array is empty(nnz = 0) for csr/row_sparse */ inline bool storage_initialized() const { if (is_none()) return false; diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 01b1387c57bb..244f34e911ac 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -245,6 +245,8 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, } if (!dispatched) { dispatch_fallback(out_attrs, dispatch_mode); + } + if (static_cast(*dispatch_mode) == DispatchMode::kFComputeFallback) { LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs); } return true; @@ -563,7 +565,7 @@ struct PopulateCsrForNNC { const CType start_idx = i * nnc; nnvm::dim_t cur = 0; indptr_out[i] = start_idx; - if (i == static_cast(num_rows_l - 1)) indptr_out[i + 1] = indptr_out[i] + nnc; + if (static_cast(i) == (num_rows_l - 1)) indptr_out[i + 1] = indptr_out[i] + nnc; for (IType idx = start_idx; idx < (start_idx + nnc); idx++) { col_idx_out[idx] = nnc_idx[cur++]; } @@ -913,7 +915,7 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, using namespace mshadow::expr; using nnvm::dim_t; - /*Initialize data structures*/ + /* Initialize data structures */ mshadow::Stream* s = ctx.get_stream(); const NDArray& out = *ret; const TBlob data_l = lhs; diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index c6b5831f2e1b..134cb260436e 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -1273,7 +1273,7 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols): test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True, lhs_d, rhs_d) # (vector kernel) test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False, lhs_d, rhs_d) # test gpu SpMM test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True, lhs_d, rhs_d) # (scalar kernel) - test_dot_dns_csr(lhs_shape, (lhs_shape[1], rnd.randint(500, 1000)), lhs_d, lhs_d) + test_dot_dns_csr(lhs_shape, (lhs_shape[1], rnd.randint(50, 200)), lhs_d, lhs_d) for rhs_d in density: test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, lhs_d, rhs_d) test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, lhs_d, rhs_d) From 9a81b789c81c4043cb3400c2e5743bb98a1eca5b Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 4 Jan 2018 00:37:19 +0000 Subject: [PATCH 11/11] Add operator to the documentation --- src/operator/tensor/dot.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc index a7fa2c7933a5..834b559b86f6 100644 --- a/src/operator/tensor/dot.cc +++ b/src/operator/tensor/dot.cc @@ -56,6 +56,7 @@ The storage type of ``dot`` output depends on storage types of inputs and transp - dot(csr, default) = default - dot(csr.T, default) = row_sparse - dot(csr, row_sparse) = default +- dot(default, csr) = csr - otherwise, ``dot`` generates output with default storage )doc" ADD_FILELINE)