diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index e5aa8aa106207..a6775ae7bd200 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -787,7 +787,7 @@ def sparse_reshape_strategy_cuda(attrs, inputs, out_type, target):
     return strategy
 
 
-@sparse_dense_padded_strategy.register(["cuda", "gpu"])
+@sparse_dense_padded_strategy.register(["cuda", "gpu", "rocm"])
 def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target):
     """sparse dense cuda strategy"""
     strategy = _op.OpStrategy()
diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py
index f68b31ec30efb..1e846ebf53111 100644
--- a/python/tvm/topi/cuda/sparse.py
+++ b/python/tvm/topi/cuda/sparse.py
@@ -163,13 +163,24 @@ def sparse_dense_tir(data, w_data, w_indices, w_indptr):
     """
 
     def gen_ir(data, w_data, w_indices, w_indptr, out):
-        # pylint: disable=invalid-name
+        # pylint: disable=invalid-name, simplifiable-if-statement
         # TODO(tkonolige): use tensorcores for block multiply
         # TODO(tkonolige): use vectorize on loads
         # TODO(tkonolige): seperate implementation if M is small
         # TODO(tkonolige): seperate implementation for large block sizes
         ib = tvm.tir.ir_builder.create()
 
+        if tvm.target.Target.current(allow_none=False).kind.name == "cuda":
+            use_warp_storage = True
+        else:
+            # TVM's warp shuffle intrinsics are slow on ROCm because they use
+            # LDS (shared memory) to do the shuffling. Instead, we could use
+            # ROCm's support for accessing neighboring threads' memory, but
+            # those intrinsics aren't accessible from TVM. For now, we just use
+            # shared memory. We also default to shared memory on platforms
+            # where we do not know how warp storage performs.
+            use_warp_storage = False
+
         warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
         m = data.shape[1]
         nb = w_indptr.shape[0] - 1
@@ -218,11 +229,19 @@ def gen_ir(data, w_data, w_indices, w_indptr, out):
 
         # thread local storage for bs_m x bs_n block
        block = ib.allocate(data.dtype, (bs_m, bs_n), name="block", scope="local")
-        indices = ib.allocate(w_indices.dtype, (rowlength_bi,), name="indices", scope="warp")
         data_cache = ib.allocate(data.dtype, (mi, bs_m, bs_k), name="data_cache", scope="local")
-        w_data_cache = ib.allocate(
-            w_data.dtype, (rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="warp"
-        )
+        if use_warp_storage:
+            indices = ib.allocate(w_indices.dtype, (rowlength_bi,), name="indices", scope="warp")
+            w_data_cache = ib.allocate(
+                w_data.dtype, (rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="warp"
+            )
+        else:
+            indices = ib.allocate(
+                w_indices.dtype, (ni, rowlength_bi), name="indices", scope="shared"
+            )
+            w_data_cache = ib.allocate(
+                w_data.dtype, (ni, rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="shared"
+            )
 
         # zero block
         with ib.for_range(0, bs_m, name="x", kind="unroll") as x:
@@ -232,7 +251,10 @@ def gen_ir(data, w_data, w_indices, w_indptr, out):
         with ib.for_range(0, rowlength_bo, name="bb") as bb:
             elem_idx = bb * rowlength_bi + tx
             # Cache indices. Guaranteed to be multiple of warp_size.
-            indices[elem_idx] = w_indices_ptr[row_start + elem_idx]
+            if use_warp_storage:
+                indices[tx] = w_indices_ptr[row_start + elem_idx]
+            else:
+                indices[warp, tx] = w_indices_ptr[row_start + elem_idx]
             # cache dense matrix
             # each thread has a row
             # TODO: ideally we could vectorize this
@@ -242,18 +264,29 @@ def gen_ir(data, w_data, w_indices, w_indptr, out):
                         # This memory acces should be out of bounds when
                         # m_index >= mb (which occurs when the dense matrix
                         # rows % 32 != 0), but it seems to work just fine...
-                        data_cache[bi, x, z] = data_ptr[indices[bi] * bs_k + z, m_index * bs_m + x]
+                        if use_warp_storage:
+                            ind = indices[bi]
+                        else:
+                            ind = indices[warp, bi]
+                        data_cache[bi, x, z] = data_ptr[ind * bs_k + z, m_index * bs_m + x]
             # cache w_data
             elem_idx = bb * rowlength_bi + tx
             with ib.for_range(0, bs_n, name="y", kind="unroll") as y:
                 with ib.for_range(0, bs_k, name="z", kind="unroll") as z:
-                    w_data_cache[tx, y, z] = w_data_ptr[row_start + elem_idx, y, z]
+                    if use_warp_storage:
+                        w_data_cache[tx, y, z] = w_data_ptr[row_start + elem_idx, y, z]
+                    else:
+                        w_data_cache[warp, tx, y, z] = w_data_ptr[row_start + elem_idx, y, z]
             with ib.for_range(0, mi, name="i") as i:
                 # thread local block matmul
                 with ib.for_range(0, bs_m, name="x", kind="unroll") as x:
                     with ib.for_range(0, bs_n, name="y", kind="unroll") as y:
                         with ib.for_range(0, bs_k, name="z", kind="unroll") as z:
-                            block[x, y] += data_cache[i, x, z] * w_data_cache[i, y, z]
+                            if use_warp_storage:
+                                w = w_data_cache[i, y, z]
+                            else:
+                                w = w_data_cache[warp, i, y, z]
+                            block[x, y] += data_cache[i, x, z] * w
         # store results
         with ib.for_range(0, bs_m, name="x", kind="unroll") as x:
             with ib.for_range(0, bs_n, name="y", kind="unroll") as y:
@@ -391,11 +424,11 @@ def pad_sparse_matrix(matrix, blocksize):
     return sp.bsr_matrix((data, indices, indptr), matrix.shape)
 
 
-@nn.sparse_dense_alter_layout.register(["cuda", "gpu"])
+@nn.sparse_dense_alter_layout.register(["cuda", "gpu", "rocm"])
 def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
     """With cuda, we modify use alter_op_layout to swap the default
     sparse_dense implementation for one that operates on a padded matrix. We
-    also padd the matrix.
+    also pad the matrix.
""" # TODO(ANSHUMAN87): Handle for sparse_lhs case too if ( diff --git a/tests/python/topi/python/test_topi_sparse.py b/tests/python/topi/python/test_topi_sparse.py index 500384b23f2a5..98a3ec86180c5 100644 --- a/tests/python/topi/python/test_topi_sparse.py +++ b/tests/python/topi/python/test_topi_sparse.py @@ -463,8 +463,8 @@ def check_device(device): check_device(device) -@tvm.testing.requires_cuda -def test_sparse_dense_padded_cuda(): +@tvm.testing.parametrize_targets("cuda", "rocm") +def test_sparse_dense_padded_gpu(target, dev): M = 128 N = 1280 K = 128 @@ -483,8 +483,7 @@ def test_sparse_dense_padded_cuda(): shape=W_sp_np_padded.indptr.shape, dtype=str(W_sp_np_padded.indptr.dtype) ) X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype)) - with tvm.target.Target("cuda"): - dev = tvm.device("gpu") + with tvm.target.Target(target): Y = topi.cuda.sparse_dense_padded(X, W_data, W_indices, W_indptr) s = topi.cuda.schedule_sparse_dense_padded([Y]) func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) @@ -499,9 +498,9 @@ def test_sparse_dense_padded_cuda(): tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5) -@tvm.testing.requires_cuda -def test_sparse_dense_padded_alter_op(): - with tvm.target.Target("cuda"): +@tvm.testing.parametrize_targets("cuda", "rocm") +def test_sparse_dense_padded_alter_op(target, dev): + with tvm.target.Target(target): M = 128 N = 16 K = 128 @@ -523,7 +522,7 @@ def test_sparse_dense_padded_alter_op(): # build with cuda and AlterOpLayout to ensure that sparse_dense_padded is in action with tvm.transform.PassContext(opt_level=3, required_pass="AlterOpLayout"): - x = relay.build(tvm.IRModule.from_expr(f), target=tvm.target.Target("cuda")) + x = relay.build(tvm.IRModule.from_expr(f), target=target) def test_sparse_add_csr():