[SPARSE] Improve sparse performance on ROCM (apache#7935)
* [SPARSE] Improve sparse performance on ROCM

The current sparse dense GPU kernel uses warp-level storage to handle
caching of data. Warp-level storage uses shuffle intrinsics, which are
slow on ROCm (because they actually read and write to shared memory).
ROCm does provide intrinsics for the right cross-lane data movement, but
they are not exposed through TVM. Instead, this PR switches to using
shared memory on ROCm devices, as sketched below the change list.
Performance is about 2x faster.

* default to shared mem

* formatting

* formatting
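
As a rough illustration of the approach (not the PR's literal code), here is a minimal sketch of the storage-scope dispatch using TVM's ir_builder. The helper name `allocate_lane_cache` and the shapes are hypothetical; the target query and the `ib.allocate` scopes are the ones the patch uses:

```python
import tvm

def allocate_lane_cache(ib, dtype, lanes, warps_per_block):
    # Hypothetical helper mirroring the patch's dispatch: warp storage on
    # CUDA, shared memory on ROCm and any target with unknown behavior.
    if tvm.target.Target.current(allow_none=False).kind.name == "cuda":
        # One slot per lane, exchanged via warp shuffle intrinsics.
        return ib.allocate(dtype, (lanes,), name="cache", scope="warp")
    # ROCm lowers TVM's shuffles through LDS (shared memory) anyway, so
    # address shared memory directly; the cache gains a leading warp
    # dimension because every warp in the block shares the same space.
    return ib.allocate(dtype, (warps_per_block, lanes), name="cache", scope="shared")
```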
tkonolige authored and trevor-m committed May 11, 2021
1 parent 77ed2f5 commit 231975c
Showing 3 changed files with 52 additions and 20 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/cuda.py
@@ -787,7 +787,7 @@ def sparse_reshape_strategy_cuda(attrs, inputs, out_type, target):
     return strategy


-@sparse_dense_padded_strategy.register(["cuda", "gpu"])
+@sparse_dense_padded_strategy.register(["cuda", "gpu", "rocm"])
 def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target):
     """sparse dense cuda strategy"""
     strategy = _op.OpStrategy()
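For context, a strategy registration of this shape builds an `OpStrategy` and attaches the padded compute/schedule pair; registering it for `"rocm"` is what lets ROCm targets select the padded implementation. A hedged sketch of such a body, assuming TVM's usual strategy wrappers (the `plevel` value is illustrative):

```python
from tvm import topi
from tvm.relay.op import op as _op
from tvm.relay.op.strategy.generic import wrap_compute_sparse_dense, wrap_topi_schedule

def sparse_dense_padded_strategy_sketch(attrs, inputs, out_type, target):
    """Sketch of a padded sparse_dense strategy body (not the file's exact code)."""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_sparse_dense(topi.cuda.sparse_dense_padded),
        wrap_topi_schedule(topi.cuda.schedule_sparse_dense_padded),
        name="sparse_dense_padded.cuda",
        plevel=10,  # illustrative priority
    )
    return strategy
```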
55 changes: 44 additions & 11 deletions python/tvm/topi/cuda/sparse.py
@@ -163,13 +163,24 @@ def sparse_dense_tir(data, w_data, w_indices, w_indptr):
     """

     def gen_ir(data, w_data, w_indices, w_indptr, out):
-        # pylint: disable=invalid-name
+        # pylint: disable=invalid-name, simplifiable-if-statement
         # TODO(tkonolige): use tensorcores for block multiply
         # TODO(tkonolige): use vectorize on loads
         # TODO(tkonolige): separate implementation if M is small
         # TODO(tkonolige): separate implementation for large block sizes
         ib = tvm.tir.ir_builder.create()

+        if tvm.target.Target.current(allow_none=False).kind.name == "cuda":
+            use_warp_storage = True
+        else:
+            # TVM's warp shuffle intrinsics are slow on ROCm because they use
+            # LDS (shared memory) to do the shuffling. Instead, we could use
+            # ROCm's support for accessing neighboring threads' memory, but
+            # those intrinsics aren't accessible from TVM. For now, we just
+            # use shared memory. We also default to shared memory on
+            # platforms where we do not know how warp storage performs.
+            use_warp_storage = False
+
         warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
         m = data.shape[1]
         nb = w_indptr.shape[0] - 1
@@ -218,11 +229,19 @@ def gen_ir(data, w_data, w_indices, w_indptr, out):

         # thread local storage for bs_m x bs_n block
         block = ib.allocate(data.dtype, (bs_m, bs_n), name="block", scope="local")
-        indices = ib.allocate(w_indices.dtype, (rowlength_bi,), name="indices", scope="warp")
         data_cache = ib.allocate(data.dtype, (mi, bs_m, bs_k), name="data_cache", scope="local")
-        w_data_cache = ib.allocate(
-            w_data.dtype, (rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="warp"
-        )
+        if use_warp_storage:
+            indices = ib.allocate(w_indices.dtype, (rowlength_bi,), name="indices", scope="warp")
+            w_data_cache = ib.allocate(
+                w_data.dtype, (rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="warp"
+            )
+        else:
+            indices = ib.allocate(
+                w_indices.dtype, (ni, rowlength_bi), name="indices", scope="shared"
+            )
+            w_data_cache = ib.allocate(
+                w_data.dtype, (ni, rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="shared"
+            )

         # zero block
         with ib.for_range(0, bs_m, name="x", kind="unroll") as x:
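The shape change above is the heart of the fallback: a warp-scope cache is per-lane and needs no warp index, while the shared-memory version must carve out one slab per warp, hence the leading `ni`/`warp` dimension used in the stores below. A self-contained sketch of the two layouts; the sizes and constant indices are illustrative stand-ins for the kernel's thread/warp indices:

```python
import tvm

ib = tvm.tir.ir_builder.create()
rowlength_bi, ni = 32, 4   # illustrative: lanes per warp, warps per block
tx, warp = 0, 0            # in the real kernel these are thread/warp indices
warp_cache = ib.allocate("int32", (rowlength_bi,), name="warp_cache", scope="warp")
shared_cache = ib.allocate("int32", (ni, rowlength_bi), name="shared_cache", scope="shared")
warp_cache[tx] = 1          # per-lane slot: no warp index needed
shared_cache[warp, tx] = 1  # shared slab: explicit warp dimension
print(ib.get())             # inspect the generated TIR
```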
@@ -232,7 +251,10 @@ def gen_ir(data, w_data, w_indices, w_indptr, out):
         with ib.for_range(0, rowlength_bo, name="bb") as bb:
             elem_idx = bb * rowlength_bi + tx
             # Cache indices. Guaranteed to be a multiple of warp_size.
-            indices[elem_idx] = w_indices_ptr[row_start + elem_idx]
+            if use_warp_storage:
+                indices[tx] = w_indices_ptr[row_start + elem_idx]
+            else:
+                indices[warp, tx] = w_indices_ptr[row_start + elem_idx]
             # cache dense matrix
             # each thread has a row
             # TODO: ideally we could vectorize this
@@ -242,18 +264,29 @@ def gen_ir(data, w_data, w_indices, w_indptr, out):
                         # This memory access should be out of bounds when
                         # m_index >= mb (which occurs when the dense matrix
                         # rows % 32 != 0), but it seems to work just fine...
-                        data_cache[bi, x, z] = data_ptr[indices[bi] * bs_k + z, m_index * bs_m + x]
+                        if use_warp_storage:
+                            ind = indices[bi]
+                        else:
+                            ind = indices[warp, bi]
+                        data_cache[bi, x, z] = data_ptr[ind * bs_k + z, m_index * bs_m + x]
             # cache w_data
             elem_idx = bb * rowlength_bi + tx
             with ib.for_range(0, bs_n, name="y", kind="unroll") as y:
                 with ib.for_range(0, bs_k, name="z", kind="unroll") as z:
-                    w_data_cache[tx, y, z] = w_data_ptr[row_start + elem_idx, y, z]
+                    if use_warp_storage:
+                        w_data_cache[tx, y, z] = w_data_ptr[row_start + elem_idx, y, z]
+                    else:
+                        w_data_cache[warp, tx, y, z] = w_data_ptr[row_start + elem_idx, y, z]
             with ib.for_range(0, mi, name="i") as i:
                 # thread local block matmul
                 with ib.for_range(0, bs_m, name="x", kind="unroll") as x:
                     with ib.for_range(0, bs_n, name="y", kind="unroll") as y:
                         with ib.for_range(0, bs_k, name="z", kind="unroll") as z:
-                            block[x, y] += data_cache[i, x, z] * w_data_cache[i, y, z]
+                            if use_warp_storage:
+                                w = w_data_cache[i, y, z]
+                            else:
+                                w = w_data_cache[warp, i, y, z]
+                            block[x, y] += data_cache[i, x, z] * w
         # store results
         with ib.for_range(0, bs_m, name="x", kind="unroll") as x:
             with ib.for_range(0, bs_n, name="y", kind="unroll") as y:
@@ -391,11 +424,11 @@ def pad_sparse_matrix(matrix, blocksize):
     return sp.bsr_matrix((data, indices, indptr), matrix.shape)


-@nn.sparse_dense_alter_layout.register(["cuda", "gpu"])
+@nn.sparse_dense_alter_layout.register(["cuda", "gpu", "rocm"])
 def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
     """With cuda, we use alter_op_layout to swap the default
     sparse_dense implementation for one that operates on a padded matrix. We
-    also padd the matrix.
+    also pad the matrix.
     """
     # TODO(ANSHUMAN87): Handle for sparse_lhs case too
     if (
15 changes: 7 additions & 8 deletions tests/python/topi/python/test_topi_sparse.py
@@ -463,8 +463,8 @@ def check_device(device):
         check_device(device)


-@tvm.testing.requires_cuda
-def test_sparse_dense_padded_cuda():
+@tvm.testing.parametrize_targets("cuda", "rocm")
+def test_sparse_dense_padded_gpu(target, dev):
     M = 128
     N = 1280
     K = 128
@@ -483,8 +483,7 @@ def test_sparse_dense_padded_cuda():
         shape=W_sp_np_padded.indptr.shape, dtype=str(W_sp_np_padded.indptr.dtype)
     )
     X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype))
-    with tvm.target.Target("cuda"):
-        dev = tvm.device("gpu")
+    with tvm.target.Target(target):
         Y = topi.cuda.sparse_dense_padded(X, W_data, W_indices, W_indptr)
         s = topi.cuda.schedule_sparse_dense_padded([Y])
         func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
@@ -499,9 +498,9 @@ def test_sparse_dense_padded_cuda():
     tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5)


-@tvm.testing.requires_cuda
-def test_sparse_dense_padded_alter_op():
-    with tvm.target.Target("cuda"):
+@tvm.testing.parametrize_targets("cuda", "rocm")
+def test_sparse_dense_padded_alter_op(target, dev):
+    with tvm.target.Target(target):
         M = 128
         N = 16
         K = 128
@@ -523,7 +522,7 @@ def test_sparse_dense_padded_alter_op():

     # build with cuda and AlterOpLayout to ensure that sparse_dense_padded is in action
     with tvm.transform.PassContext(opt_level=3, required_pass="AlterOpLayout"):
-        x = relay.build(tvm.IRModule.from_expr(f), target=tvm.target.Target("cuda"))
+        x = relay.build(tvm.IRModule.from_expr(f), target=target)


 def test_sparse_add_csr():
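The test changes ride on `tvm.testing.parametrize_targets`, which runs one test instance per listed target and supplies matching `target`/`dev` arguments, skipping targets that are not enabled in the build or have no device present. A minimal usage sketch (the test name and body are hypothetical):

```python
import numpy as np
import tvm
import tvm.testing

@tvm.testing.parametrize_targets("cuda", "rocm")
def test_roundtrip(target, dev):
    # One parametrized instance per target; `dev` is the matching device.
    x = np.arange(16, dtype="float32")
    a = tvm.nd.array(x, dev)  # copy to device and back
    tvm.testing.assert_allclose(a.asnumpy(), x)
```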
