diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index ca4e8d38b318..58c4bba30cbc 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -371,6 +371,11 @@ struct SparseDenseAttrs : public tvm::AttrsNode<SparseDenseAttrs> {
   TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {}
 };
 
+/*! \brief Attributes for sparse_transpose operator */
+struct SparseTransposeAttrs : public tvm::AttrsNode<SparseTransposeAttrs> {
+  TVM_DECLARE_ATTRS(SparseTransposeAttrs, "relay.attrs.SparseTransposeAttrs") {}
+};
+
 /*! \brief Attributes for upsampling operator */
 struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
   int scale;
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index b50a27bd7267..0c374b82efc1 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -99,6 +99,20 @@ def schedule_sparse_dense(attrs, outputs, target):
 
 reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
+# sparse_transpose
+@reg.register_compute("nn.sparse_transpose")
+def compute_sparse_transpose(attrs, inputs, out_type, target):
+    """Compute definition of sparse_transpose"""
+    return topi.nn.sparse_transpose(inputs[0], inputs[1], inputs[2])
+
+@reg.register_schedule("nn.sparse_transpose")
+def schedule_sparse_transpose(attrs, outputs, target):
+    """Schedule definition of sparse_transpose"""
+    with target:
+        return topi.generic.schedule_sparse_transpose(outputs)
+
+reg.register_pattern("nn.sparse_transpose", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
+
 # conv2d
 def _find_conv2d_op(op):
     """Find the op with conv2d in its tag by traversing."""
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 46c01be2c23e..4a83ef233c24 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -954,6 +954,33 @@ def sparse_dense(data, weight):
     """
     return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
 
+def sparse_transpose(x):
+    r"""
+    Computes the fast matrix transpose of x,
+    where x is a sparse tensor in CSR format (represented as a namedtuple
+    with fields `data`, `indices`, and `indptr`).
+
+    ** Currently only support Square Matrices **
+
+    .. math::
+
+        \mbox{sparse_transpose}(x)[n, n] = (x^T)[n, n]
+
+    Please refer to https://github.com/scipy/scipy/blob/v1.3.0/scipy/sparse/csr.py
+    for the algorithm implemented in this operator.
+
+    Parameters
+    ----------
+    x : namedtuple.
+        The sparse weight matrix for the fast matrix transpose.
+
+    Returns
+    -------
+    result : relay.Tuple([tvm.relay.Expr, tvm.relay.Expr, tvm.relay.Expr])
+        Tuple of output sparse tensor (same shape and format as input),
+        i.e. if CSR then output is in ([data, indices, indptr]) form
+    """
+    return TupleWrapper(_make.sparse_transpose(x.data, x.indices, x.indptr), 3)
 
 def contrib_conv2d_winograd_without_weight_transform(data,
                                                      weight,
diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc
index 3e8178781317..48a9b11f7651 100644
--- a/src/relay/op/nn/sparse.cc
+++ b/src/relay/op/nn/sparse.cc
@@ -72,26 +72,78 @@ Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weig
 }
 
 TVM_REGISTER_API("relay.op.nn._make.sparse_dense")
-    .set_body([](const TVMArgs& args, TVMRetValue* rv) {
-      runtime::detail::unpack_call<Expr, 4>(MakeSparseDense, args, rv);
-    });
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+  runtime::detail::unpack_call<Expr, 4>(MakeSparseDense, args, rv);
+});
 
 RELAY_REGISTER_OP("nn.sparse_dense")
-    .describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse.
+.describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse.
 
 - **data**: `(x1, x2, ..., xn, input_dim)`
 - **weight**: `(units, input_dim)`
 - **out**: `(x1, x2, ..., xn, units)`.
 
 )code" TVM_ADD_FILELINE)
-    .set_attrs_type_key("relay.attrs.SparseDenseAttrs")
-    .set_num_inputs(4)
-    .add_argument("data", "nD Tensor", "Input data.")
-    .add_argument("weight_data", "1D Tensor", "Weight data matrix.")
-    .add_argument("weight_indices", "1D Tensor", "Weight indices matrix.")
-    .add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.")
-    .set_support_level(1)
-    .add_type_rel("SparseDense", SparseDenseRel);
+.set_attrs_type_key("relay.attrs.SparseDenseAttrs")
+.set_num_inputs(4)
+.add_argument("data", "nD Tensor", "Input data.")
+.add_argument("weight_data", "1D Tensor", "Weight data matrix.")
+.add_argument("weight_indices", "1D Tensor", "Weight indices matrix.")
+.add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.")
+.set_support_level(1)
+.add_type_rel("SparseDense", SparseDenseRel);
+
+// relay.nn.sparse_transpose
+TVM_REGISTER_NODE_TYPE(SparseTransposeAttrs);
+
+bool SparseTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  // types = [sparse_data, sparse_indices, sparse_indptr, result]
+  CHECK_EQ(types.size(), 4);
+  const auto* sparse_data = types[0].as<TensorTypeNode>();
+  const auto* sparse_indices = types[1].as<TensorTypeNode>();
+  const auto* sparse_indptr = types[2].as<TensorTypeNode>();
+  // Inputs may still be unresolved while inference is in progress; retry later.
+  if (sparse_data == nullptr || sparse_indices == nullptr || sparse_indptr == nullptr) {
+    return false;
+  }
+  CHECK_EQ(sparse_data->shape.size(), 1);
+  CHECK_EQ(sparse_indices->shape.size(), 1);
+  CHECK_EQ(sparse_indptr->shape.size(), 1);
+
+  std::vector<Type> output_types;
+  output_types.push_back(TensorTypeNode::make(sparse_data->shape, sparse_data->dtype));
+  output_types.push_back(TensorTypeNode::make(sparse_indices->shape, sparse_indices->dtype));
+  output_types.push_back(TensorTypeNode::make(sparse_indptr->shape, sparse_indptr->dtype));
+
+  reporter->Assign(types[3], TupleTypeNode::make(Array<Type>(output_types)));
+  return true;
+}
+
+Expr MakeSparseTranspose(Expr sparse_data, Expr sparse_indices, Expr sparse_indptr) {
+  auto attrs = make_node<SparseTransposeAttrs>();
+  static const Op& op = Op::Get("nn.sparse_transpose");
+  return CallNode::make(op, {sparse_data, sparse_indices, sparse_indptr}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op.nn._make.sparse_transpose")
+.set_body_typed(MakeSparseTranspose);
+
+
+RELAY_REGISTER_OP("nn.sparse_transpose")
+.describe(R"code(Transpose a sparse matrix X. Only support square sparse matrix
+
+- **input**: `(N, N)`
+- **out**: `(N, N)`.
+
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.SparseTransposeAttrs")
+.set_num_inputs(3)
+.add_argument("sparse_data", "1D Tensor", "Sparse data matrix.")
+.add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.")
+.add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer matrix.")
+.set_support_level(1)
+.add_type_rel("SparseTranspose", SparseTransposeRel);
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 59ee7001bfd2..38b66320b428 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -543,6 +543,23 @@ def schedule_sparse_dense(outs):
     """
     return _default_schedule(outs, False)
 
+@tvm.target.generic_func
+def schedule_sparse_transpose(outs):
+    """Schedule for sparse_transpose
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of sparse_transpose
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
 @tvm.target.generic_func
 def schedule_batch_matmul(outs):
     target = tvm.target.current_target(allow_none=False)
diff --git a/topi/python/topi/nn/sparse.py b/topi/python/topi/nn/sparse.py
index 17b30ad464fa..11116b2e6d2c 100644
--- a/topi/python/topi/nn/sparse.py
+++ b/topi/python/topi/nn/sparse.py
@@ -101,3 +101,111 @@ def _compute_block(i, nb_j, j):
         (m, num_blocks * bs_r),
         lambda m, n: bsrmm_block[m, n // bs_r, n % bs_r],
         tag="sparse_dense_bsrmm")
+
+@tvm.target.generic_func
+def sparse_transpose(sparse_data, sparse_indices, sparse_indptr):
+    """
+    Transpose a square sparse matrix,
+    `A` is an n-by-n sparse matrix in the CSR format.
+    ** Currently only support Square Matrices **
+
+    Parameters
+    ----------
+    sparse_data : tvm.Tensor
+        1-D with shape [nonzeros], e.g. dtype of 'float32'
+
+    sparse_indices : tvm.Tensor
+        1-D with shape [nonzeros], dtype of 'int32'
+
+    sparse_indptr : tvm.Tensor
+        1-D with shape [n+1], dtype of 'int32'
+
+    Returns
+    -------
+    out_data : tvm.Tensor
+        1-D with shape [nonzeros], same dtype as sparse_data
+
+    out_indices : tvm.Tensor
+        1-D with shape [nonzeros], dtype of 'int32'
+
+    out_indptr : tvm.Tensor
+        1-D with shape [n+1], dtype of 'int32'
+    """
+    assert len(sparse_data.shape) == 1, "error in data dimension"
+    assert len(sparse_indices.shape) == 1, "error in indices dimension"
+    assert len(sparse_indptr.shape) == 1, "error in indptr dimension"
+
+    nnz = get_const_tuple(sparse_data.shape)[0]
+    n = get_const_tuple(sparse_indptr.shape)[0] - 1
+    output_shape = [(nnz,), (nnz,), (n+1,)]
+
+    # TODO: Add BSR transpose support
+
+    output_data, output_indices, output_indptr = tvm.extern(
+        shape=output_shape,
+        inputs=[sparse_data, sparse_indices, sparse_indptr],
+        fcompute=lambda ins, outs:
+        csr_transpose_ir(ins[0], ins[1], ins[2], outs[0], outs[1], outs[2]),
+        tag="sparse_transpose_csr",
+        dtype=[sparse_data.dtype, 'int32', 'int32'],
+        name='out')
+
+    return [output_data, output_indices, output_indptr]
+
+def csr_transpose_ir(data, indices, indptr, out_data, out_indices, out_indptr):
+    """Build the IR that writes the CSR form of the transposed matrix into the out_* buffers."""
+    irb = tvm.ir_builder.create()
+
+    data_ptr = irb.buffer_ptr(data)
+    indices_ptr = irb.buffer_ptr(indices)
+    indptr_ptr = irb.buffer_ptr(indptr)
+
+    out_data_ptr = irb.buffer_ptr(out_data)
+    out_indices_ptr = irb.buffer_ptr(out_indices)
+    out_indptr_ptr = irb.buffer_ptr(out_indptr)
+
+    n = get_const_tuple(indptr.shape)[0] - 1
+    nnz = get_const_tuple(data.shape)[0]
+
+    # Zero the per-column counters (slot n is written after the prefix sum).
+    with irb.for_range(0, n, for_type="parallel", name='col') as col:
+        out_indptr_ptr[col] = 0
+
+    # Count the nonzeros in every input column.
+    with irb.for_range(0, nnz, for_type="serial", name='nz_idx') as nz_idx:
+        out_indptr_ptr[indices_ptr[nz_idx]] += 1
+
+    # Exclusive prefix sum: out_indptr[col] becomes the first output slot of column col.
+    cumsum = irb.allocate('int32', (1,), name='cumsum', scope='local')
+    temp = irb.allocate('int32', (1,), name='temp', scope='local')
+    cumsum[0] = 0
+    with irb.for_range(0, n, for_type="serial", name='col') as col:
+        temp[0] = out_indptr_ptr[col]
+        out_indptr_ptr[col] = cumsum[0]
+        cumsum[0] += temp[0]
+
+    out_indptr_ptr[n] = nnz
+
+    # Scatter each entry to its column's next free slot; this advances out_indptr[col].
+    with irb.for_range(0, n, for_type="serial", name='row') as row:
+        offset = indptr_ptr[row]
+        diff = indptr_ptr[row+1] - indptr_ptr[row]
+        with irb.for_range(0, diff, for_type="serial", name='idx') as idx:
+            real_idx = offset + idx
+            col = indices_ptr[real_idx]
+            dest = out_indptr_ptr[col]
+
+            out_indices_ptr[dest] = row
+            out_data_ptr[dest] = data_ptr[real_idx]
+            out_indptr_ptr[col] += 1
+
+    # The scatter left out_indptr[col] at column col's end; shift by one to restore the starts.
+    last = irb.allocate('int32', (1,), name='last', scope='local')
+    temp2 = irb.allocate('int32', (1,), name='temp2', scope='local')
+    last[0] = 0
+    with irb.for_range(0, n, for_type="serial", name="col") as col:
+        temp2[0] = out_indptr_ptr[col]
+        out_indptr_ptr[col] = last[0]
+        last[0] = temp2[0]
+
+    return irb.get()
diff --git a/topi/tests/python/test_topi_sparse.py b/topi/tests/python/test_topi_sparse.py
index 49324b74a3f7..1b40b130c75b 100644
--- a/topi/tests/python/test_topi_sparse.py
+++ b/topi/tests/python/test_topi_sparse.py
@@ -23,6 +23,7 @@
 import tvm.contrib.sparse as tvmsp
 from collections import namedtuple
 import time
+import scipy.sparse as sp
 
 def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True):
     nr, nc, n = tvm.var("nr"), tvm.var("nc"), tvm.var("n")
@@ -217,7 +218,6 @@ def test_dense():
 
 
 def test_sparse_dense_csr():
-    import scipy.sparse as sp
     M, N, K, density = 1, 17, 47, 0.2
     X_np = np.random.randn(M, K).astype("float32")
     W_sp_np = sp.random(N, K, density=density, format='csr', dtype="float32")
@@ -235,9 +235,34 @@ def test_sparse_dense_csr():
     func(tvm.ndarray.array(X_np), tvm.ndarray.array(W_sp_np.data), tvm.ndarray.array(W_sp_np.indices), tvm.ndarray.array(W_sp_np.indptr), Y_tvm)
     tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
 
+def test_sparse_transpose_csr():
+    N, density = 1023, 0.3
+
+    X_sp = sp.random(N, N, density=density, format='csr', dtype='float32')
+
+    X_sp_T = X_sp.transpose()
+    X_np_T = X_sp_T.todense()
+
+    X_data = tvm.placeholder(shape=X_sp.data.shape, dtype=str(X_sp.data.dtype))
+    X_indices = tvm.placeholder(shape=X_sp.indices.shape, dtype=str(X_sp.indices.dtype))
+    X_indptr = tvm.placeholder(shape=X_sp.indptr.shape, dtype=str(X_sp.indptr.dtype))
+
+    X_T_data, X_T_indices, X_T_indptr = topi.nn.sparse_transpose(X_data, X_indices, X_indptr)
+    s = tvm.create_schedule([X_T_data.op, X_T_indices.op, X_T_indptr.op])
+    func = tvm.build(s, [X_data, X_indices, X_indptr, X_T_data, X_T_indices, X_T_indptr])
+
+
+    X_T_data_tvm = tvm.ndarray.array(np.zeros(X_sp_T.data.shape, dtype=X_sp_T.data.dtype))
+    X_T_indices_tvm = tvm.ndarray.array(np.zeros(X_sp_T.indices.shape, dtype=X_sp_T.indices.dtype))
+    X_T_indptr_tvm = tvm.ndarray.array(np.zeros(X_sp_T.indptr.shape, dtype=X_sp_T.indptr.dtype))
+
+    func(tvm.ndarray.array(X_sp.data), tvm.ndarray.array(X_sp.indices), tvm.ndarray.array(X_sp.indptr),
+         X_T_data_tvm, X_T_indices_tvm, X_T_indptr_tvm)
+
+    X_T_out = sp.csr_matrix((X_T_data_tvm.asnumpy(), X_T_indices_tvm.asnumpy(), X_T_indptr_tvm.asnumpy()), shape=(N,N)).todense()
+    tvm.testing.assert_allclose(X_np_T, X_T_out, atol=1e-4, rtol=1e-4)
 
 def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype):
-    import scipy.sparse as sp
     import itertools
     Y = np.zeros((M, N), dtype=dtype)
     assert M % BS_R == 0
@@ -318,3 +343,4 @@ def test_sparse_dense():
     test_csrmm()
     test_dense()
     test_sparse_dense()
+    test_sparse_transpose_csr()