diff --git a/topi/python/topi/cuda/sparse.py b/topi/python/topi/cuda/sparse.py
index 037eea4477b7..fb875b749750 100644
--- a/topi/python/topi/cuda/sparse.py
+++ b/topi/python/topi/cuda/sparse.py
@@ -69,6 +69,11 @@ def _callback(op):
             y_bsrmm = op.input_tensors[0]
             assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
             out = s.outputs[0].output(0)
+
+            if op not in s.outputs:
+                y_reshape = op.output(0)
+                s[y_reshape].compute_at(s[out], s[out].op.axis[1])
+
             (_, c) = s[y_bsrmm].op.reduce_axis
 
             (m_o, n_o) = s[out].op.axis
diff --git a/topi/tests/python/test_topi_sparse.py b/topi/tests/python/test_topi_sparse.py
index 3290fc0b0941..748181dc650b 100644
--- a/topi/tests/python/test_topi_sparse.py
+++ b/topi/tests/python/test_topi_sparse.py
@@ -288,12 +288,13 @@ def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype):
     assert s.indptr.shape == (M // BS_R + 1, )
     return s
 
-def test_sparse_dense_bsr():
-    M, N, K, BS_R, BS_C, density = 1, 64, 128, 8, 16, 0.9
+def verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu):
     X_np = np.random.randn(M, K).astype("float32")
     W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32")
     W_np = W_sp_np.todense()
     Y_np = X_np.dot(W_np.T)
+    if use_relu:
+        Y_np = np.maximum(Y_np, 0.0)
 
     W_data = te.placeholder(shape=W_sp_np.data.shape, dtype=str(W_sp_np.data.dtype))
     W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype))
@@ -309,6 +310,8 @@ def check_device(device):
         fcompute, fschedule = topi.testing.dispatch(device, _sparse_dense_implement)
         with tvm.target.create(device):
             Y = fcompute(X, W_data, W_indices, W_indptr)
+            if use_relu:
+                Y = topi.nn.relu(Y)
             s = fschedule([Y])
             func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
             Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx)
@@ -322,6 +325,11 @@ def check_device(device):
     for device in ['llvm', 'cuda']:
         check_device(device)
 
+def test_sparse_dense_bsr():
+    M, N, K, BS_R, BS_C, density = 1, 64, 128, 8, 16, 0.9
+    verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu=True)
+    verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu=False)
+
 def test_sparse_dense_bsr_randomized():
     for _ in range(20):
         BS_R = np.random.randint(1, 16)