From 231975c27b73438f430dcef81e1531768a3fe000 Mon Sep 17 00:00:00 2001
From: Tristan Konolige
Date: Tue, 4 May 2021 12:25:06 -0700
Subject: [PATCH] [SPARSE] Improve sparse performance on ROCM (#7935)

* [SPARSE] Improve sparse performance on ROCM

The current sparse dense GPU kernel uses warp level storage to handle
caching of data. Warp level storage uses shuffle intrinsics, which are
slow on ROCm (because they actually read and write to shared memory).
ROCm does provide intrinsics to do the correct memory management, but
they are not available through TVM. Instead, this PR switches to using
shared memory on ROCm devices (a short sketch follows the diffstat).
Performance is about 2x faster.

* default to shared mem

* formatting

* formatting
---
 python/tvm/relay/op/strategy/cuda.py         |  2 +-
 python/tvm/topi/cuda/sparse.py               | 55 ++++++++++++++++----
 tests/python/topi/python/test_topi_sparse.py | 15 +++---
 3 files changed, 52 insertions(+), 20 deletions(-)
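The following is an illustrative sketch (not part of the applied diff) of the
storage-scope dispatch that gen_ir in python/tvm/topi/cuda/sparse.py now
performs: query the current target and allocate the weight cache in "warp"
scope on CUDA, or in "shared" scope (LDS) elsewhere with an extra leading
warp dimension. The helper name, shapes, and dtype are placeholders chosen
for the example.

    import tvm

    def allocate_w_cache(ib, ni, rowlength_bi, bs_n, bs_k, dtype="float32"):
        # CUDA: one cache slice per lane, exchanged via warp shuffle intrinsics.
        if tvm.target.Target.current(allow_none=False).kind.name == "cuda":
            return ib.allocate(
                dtype, (rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="warp"
            )
        # ROCm and other targets: keep the cache in shared memory, indexed by
        # an explicit warp id instead of relying on shuffles.
        return ib.allocate(
            dtype, (ni, rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="shared"
        )

    with tvm.target.Target("rocm"):
        ib = tvm.tir.ir_builder.create()
        w_cache = allocate_w_cache(ib, ni=4, rowlength_bi=64, bs_n=1, bs_k=1)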
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index e5aa8aa10620..a6775ae7bd20 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -787,7 +787,7 @@ def sparse_reshape_strategy_cuda(attrs, inputs, out_type, target):
     return strategy
 
 
-@sparse_dense_padded_strategy.register(["cuda", "gpu"])
+@sparse_dense_padded_strategy.register(["cuda", "gpu", "rocm"])
 def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target):
     """sparse dense cuda strategy"""
     strategy = _op.OpStrategy()
diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py
index f68b31ec30ef..1e846ebf5311 100644
--- a/python/tvm/topi/cuda/sparse.py
+++ b/python/tvm/topi/cuda/sparse.py
@@ -163,13 +163,24 @@ def sparse_dense_tir(data, w_data, w_indices, w_indptr):
     """
 
     def gen_ir(data, w_data, w_indices, w_indptr, out):
-        # pylint: disable=invalid-name
+        # pylint: disable=invalid-name, simplifiable-if-statement
         # TODO(tkonolige): use tensorcores for block multiply
         # TODO(tkonolige): use vectorize on loads
         # TODO(tkonolige): seperate implementation if M is small
         # TODO(tkonolige): seperate implementation for large block sizes
         ib = tvm.tir.ir_builder.create()
 
+        if tvm.target.Target.current(allow_none=False).kind.name == "cuda":
+            use_warp_storage = True
+        else:
+            # TVM's warp shuffle intrinsics are slow on ROCm because they use
+            # LDS (shared memory) to do the shuffling. Instead, we could use
+            # ROCm's support for accessing neighboring threads' memory, but
+            # those intrinsics aren't accessible from TVM. For now, we just
+            # use shared memory. We also default to shared memory on platforms
+            # where we do not know how warp storage performs.
+            use_warp_storage = False
+
         warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
         m = data.shape[1]
         nb = w_indptr.shape[0] - 1
@@ -218,11 +229,19 @@ def gen_ir(data, w_data, w_indices, w_indptr, out):
 
             # thread local storage for bs_m x bs_n block
             block = ib.allocate(data.dtype, (bs_m, bs_n), name="block", scope="local")
-            indices = ib.allocate(w_indices.dtype, (rowlength_bi,), name="indices", scope="warp")
             data_cache = ib.allocate(data.dtype, (mi, bs_m, bs_k), name="data_cache", scope="local")
-            w_data_cache = ib.allocate(
-                w_data.dtype, (rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="warp"
-            )
+            if use_warp_storage:
+                indices = ib.allocate(w_indices.dtype, (rowlength_bi,), name="indices", scope="warp")
+                w_data_cache = ib.allocate(
+                    w_data.dtype, (rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="warp"
+                )
+            else:
+                indices = ib.allocate(
+                    w_indices.dtype, (ni, rowlength_bi), name="indices", scope="shared"
+                )
+                w_data_cache = ib.allocate(
+                    w_data.dtype, (ni, rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="shared"
+                )
 
             # zero block
             with ib.for_range(0, bs_m, name="x", kind="unroll") as x:
@@ -232,7 +251,10 @@ def gen_ir(data, w_data, w_indices, w_indptr, out):
             with ib.for_range(0, rowlength_bo, name="bb") as bb:
                 elem_idx = bb * rowlength_bi + tx
                 # Cache indices. Guaranteed to be multiple of warp_size.
-                indices[elem_idx] = w_indices_ptr[row_start + elem_idx]
+                if use_warp_storage:
+                    indices[tx] = w_indices_ptr[row_start + elem_idx]
+                else:
+                    indices[warp, tx] = w_indices_ptr[row_start + elem_idx]
                 # cache dense matrix
                 # each thread has a row
                 # TODO: ideally we could vectorize this
@@ -242,18 +264,29 @@ def gen_ir(data, w_data, w_indices, w_indptr, out):
                             # This memory acces should be out of bounds when
                             # m_index >= mb (which occurs when the dense matrix
                             # rows % 32 != 0), but it seems to work just fine...
-                            data_cache[bi, x, z] = data_ptr[indices[bi] * bs_k + z, m_index * bs_m + x]
+                            if use_warp_storage:
+                                ind = indices[bi]
+                            else:
+                                ind = indices[warp, bi]
+                            data_cache[bi, x, z] = data_ptr[ind * bs_k + z, m_index * bs_m + x]
                 # cache w_data
                 elem_idx = bb * rowlength_bi + tx
                 with ib.for_range(0, bs_n, name="y", kind="unroll") as y:
                     with ib.for_range(0, bs_k, name="z", kind="unroll") as z:
-                        w_data_cache[tx, y, z] = w_data_ptr[row_start + elem_idx, y, z]
+                        if use_warp_storage:
+                            w_data_cache[tx, y, z] = w_data_ptr[row_start + elem_idx, y, z]
+                        else:
+                            w_data_cache[warp, tx, y, z] = w_data_ptr[row_start + elem_idx, y, z]
                 with ib.for_range(0, mi, name="i") as i:
                     # thread local block matmul
                     with ib.for_range(0, bs_m, name="x", kind="unroll") as x:
                         with ib.for_range(0, bs_n, name="y", kind="unroll") as y:
                             with ib.for_range(0, bs_k, name="z", kind="unroll") as z:
-                                block[x, y] += data_cache[i, x, z] * w_data_cache[i, y, z]
+                                if use_warp_storage:
+                                    w = w_data_cache[i, y, z]
+                                else:
+                                    w = w_data_cache[warp, i, y, z]
+                                block[x, y] += data_cache[i, x, z] * w
             # store results
             with ib.for_range(0, bs_m, name="x", kind="unroll") as x:
                 with ib.for_range(0, bs_n, name="y", kind="unroll") as y:
@@ -391,11 +424,11 @@ def pad_sparse_matrix(matrix, blocksize):
     return sp.bsr_matrix((data, indices, indptr), matrix.shape)
 
 
-@nn.sparse_dense_alter_layout.register(["cuda", "gpu"])
+@nn.sparse_dense_alter_layout.register(["cuda", "gpu", "rocm"])
 def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
     """With cuda, we modify use alter_op_layout to swap the default
     sparse_dense implementation for one that operates on a padded matrix. We
-    also padd the matrix.
+    also pad the matrix.
     """
     # TODO(ANSHUMAN87): Handle for sparse_lhs case too
     if (
diff --git a/tests/python/topi/python/test_topi_sparse.py b/tests/python/topi/python/test_topi_sparse.py
index 500384b23f2a..98a3ec86180c 100644
--- a/tests/python/topi/python/test_topi_sparse.py
+++ b/tests/python/topi/python/test_topi_sparse.py
@@ -463,8 +463,8 @@ def check_device(device):
         check_device(device)
 
 
-@tvm.testing.requires_cuda
-def test_sparse_dense_padded_cuda():
+@tvm.testing.parametrize_targets("cuda", "rocm")
+def test_sparse_dense_padded_gpu(target, dev):
     M = 128
     N = 1280
     K = 128
@@ -483,8 +483,7 @@ def test_sparse_dense_padded_cuda():
         shape=W_sp_np_padded.indptr.shape, dtype=str(W_sp_np_padded.indptr.dtype)
     )
     X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype))
-    with tvm.target.Target("cuda"):
-        dev = tvm.device("gpu")
+    with tvm.target.Target(target):
         Y = topi.cuda.sparse_dense_padded(X, W_data, W_indices, W_indptr)
         s = topi.cuda.schedule_sparse_dense_padded([Y])
         func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
@@ -499,9 +498,9 @@ def test_sparse_dense_padded_cuda():
     tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5)
 
 
-@tvm.testing.requires_cuda
-def test_sparse_dense_padded_alter_op():
-    with tvm.target.Target("cuda"):
+@tvm.testing.parametrize_targets("cuda", "rocm")
+def test_sparse_dense_padded_alter_op(target, dev):
+    with tvm.target.Target(target):
         M = 128
         N = 16
         K = 128
@@ -523,7 +522,7 @@ def test_sparse_dense_padded_alter_op():
 
     # build with cuda and AlterOpLayout to ensure that sparse_dense_padded is in action
     with tvm.transform.PassContext(opt_level=3, required_pass="AlterOpLayout"):
-        x = relay.build(tvm.IRModule.from_expr(f), target=tvm.target.Target("cuda"))
+        x = relay.build(tvm.IRModule.from_expr(f), target=target)
 
 
 def test_sparse_add_csr():
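An illustrative usage sketch (not part of the diff): building the padded
sparse dense kernel for a ROCm target, mirroring test_sparse_dense_padded_gpu
above. It assumes a ROCm-enabled TVM build; the shapes, density, and the
padding multiple are arbitrary values chosen for the example.

    import numpy as np
    import scipy.sparse as sp
    import tvm
    from tvm import te, topi
    from tvm.topi.cuda.sparse import pad_sparse_matrix

    M, N, K, BS_R, density = 128, 1280, 128, 16, 0.01
    X_np = np.random.randn(M, K).astype("float32")
    # Random BSR weight matrix, padded so each block row holds a multiple of
    # the assumed ROCm wavefront size (64) of blocks, as the kernel expects.
    W_sp_np = sp.random(N, K, density=density, format="csr", dtype="float32").tobsr(
        blocksize=(BS_R, 1)
    )
    W_sp_np_padded = pad_sparse_matrix(W_sp_np, 64)

    W_data = te.placeholder(shape=W_sp_np_padded.data.shape, dtype="float32")
    W_indices = te.placeholder(
        shape=W_sp_np_padded.indices.shape, dtype=str(W_sp_np_padded.indices.dtype)
    )
    W_indptr = te.placeholder(
        shape=W_sp_np_padded.indptr.shape, dtype=str(W_sp_np_padded.indptr.dtype)
    )
    X = te.placeholder(shape=X_np.shape, dtype="float32")

    with tvm.target.Target("rocm"):
        Y = topi.cuda.sparse_dense_padded(X, W_data, W_indices, W_indptr)
        s = topi.cuda.schedule_sparse_dense_padded([Y])
        func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])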