Skip to content

Commit

Permalink
[topi] fix sparse dense schedule on cuda (apache#5803)
Browse files Browse the repository at this point in the history
  • Loading branch information
ceruleangu authored and Trevor Morris committed Jun 15, 2020
1 parent f250700 commit 33fcf79
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
5 changes: 5 additions & 0 deletions topi/python/topi/cuda/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ def _callback(op):
y_bsrmm = op.input_tensors[0]
assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
out = s.outputs[0].output(0)

if op not in s.outputs:
y_reshape = op.output(0)
s[y_reshape].compute_at(s[out], s[out].op.axis[1])

(_, c) = s[y_bsrmm].op.reduce_axis

(m_o, n_o) = s[out].op.axis
Expand Down
12 changes: 10 additions & 2 deletions topi/tests/python/test_topi_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,13 @@ def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype):
assert s.indptr.shape == (M // BS_R + 1, )
return s

def test_sparse_dense_bsr():
M, N, K, BS_R, BS_C, density = 1, 64, 128, 8, 16, 0.9
def verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu):
X_np = np.random.randn(M, K).astype("float32")
W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32")
W_np = W_sp_np.todense()
Y_np = X_np.dot(W_np.T)
if use_relu:
Y_np = np.maximum(Y_np, 0.0)

W_data = te.placeholder(shape=W_sp_np.data.shape, dtype=str(W_sp_np.data.dtype))
W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype))
Expand All @@ -309,6 +310,8 @@ def check_device(device):
fcompute, fschedule = topi.testing.dispatch(device, _sparse_dense_implement)
with tvm.target.create(device):
Y = fcompute(X, W_data, W_indices, W_indptr)
if use_relu:
Y = topi.nn.relu(Y)
s = fschedule([Y])
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx)
Expand All @@ -322,6 +325,11 @@ def check_device(device):
for device in ['llvm', 'cuda']:
check_device(device)

def test_sparse_dense_bsr():
M, N, K, BS_R, BS_C, density = 1, 64, 128, 8, 16, 0.9
verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu=True)
verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu=False)

def test_sparse_dense_bsr_randomized():
for _ in range(20):
BS_R = np.random.randint(1, 16)
Expand Down

0 comments on commit 33fcf79

Please sign in to comment.