require ampere in test
masahi committed May 27, 2022 · 1 parent 9fad599 · commit b0b3a40
Showing 2 changed files with 19 additions and 12 deletions.
python/tvm/testing/tir.py: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ def mma_schedule(
     mma_store_intrin,
     shared_scope="shared",
 ):
-    """Create a tensorized schedule for 4k GEMM with MMA intrinsics."""
+    """Create a tensorized schedule for GEMM with MMA intrinsics."""
     ir_module = tvm.IRModule({"main": workload})
     sch = tvm.tir.Schedule(ir_module)
 
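For context, the body of mma_schedule visible in the hunk above follows the usual TVM pattern: wrap the workload PrimFunc in an IRModule under the "main" entry, then open a TIR schedule on it. Below is a minimal self-contained sketch of that pattern; the small 128x128 matmul PrimFunc is a hypothetical stand-in for the real GEMM workload, not code from this commit.

import tvm
from tvm.script import tir as T


# Hypothetical fp16 GEMM with fp32 accumulation, standing in for the
# `workload` argument that the test suite passes to mma_schedule.
@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float16")
    B = T.match_buffer(b, (128, 128), "float16")
    C = T.match_buffer(c, (128, 128), "float32")
    for i, j, k in T.grid(128, 128, 128):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + T.cast(A[vi, vk], "float32") * T.cast(
                B[vk, vj], "float32"
            )


# The same two steps as mma_schedule's visible body.
ir_module = tvm.IRModule({"main": matmul})
sch = tvm.tir.Schedule(ir_module)
print(sch.mod.script())  # inspect the module before applying transformations

The remaining two hunks below belong to the second changed file, which holds the software-pipeline test.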
@@ -1034,10 +1034,16 @@ def test_error_missing_annotation():
     _check_error(simple_compute_missing_annotation)
 
 
+@tvm.testing.requires_cuda
 def test_three_stage_gemm():
     N = K = M = 4096
     i_factors, j_factors, k_factors = [4, 8, 2, 4, 1], [1, 64, 2, 1, 2], [128, 2, 1]
+
+    def is_ampere_or_newer():
+        arch = tvm.contrib.nvcc.get_target_compute_version()
+        major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+        return major >= 8
 
     def index_map(i, j):
         return (
             i // 16,
@@ -1071,17 +1077,18 @@ def index_map(i, j):
     sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3])
     sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
 
-    f = tvm.build(sch.mod["main"], target="cuda")
-
-    dev = tvm.device("cuda", 0)
-    a_np = np.random.uniform(size=(N, K)).astype("float16")
-    b_np = np.random.uniform(size=(K, M)).astype("float16")
-    c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
-    a = tvm.nd.array(a_np, dev)
-    b = tvm.nd.array(b_np, dev)
-    c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev)
-    f(a, b, c)
-    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+    if is_ampere_or_newer():
+        f = tvm.build(sch.mod["main"], target="cuda")
+
+        dev = tvm.device("cuda", 0)
+        a_np = np.random.uniform(size=(N, K)).astype("float16")
+        b_np = np.random.uniform(size=(K, M)).astype("float16")
+        c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+        a = tvm.nd.array(a_np, dev)
+        b = tvm.nd.array(b_np, dev)
+        c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev)
+        f(a, b, c)
+        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
 
 
 if __name__ == "__main__":
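The guard added by this commit calls two existing helpers in tvm.contrib.nvcc: get_target_compute_version returns the compute capability of the current CUDA target as a string such as "8.0", and parse_compute_version splits it into (major, minor). Ampere corresponds to compute capability 8.x, which the async-copy stages used by the three-stage pipeline depend on. Below is a minimal sketch of the same check used as an explicit pytest skip; the skip-based variant (and the test name and placeholder body) is an illustration, not what the commit does.

import pytest
import tvm
import tvm.contrib.nvcc
import tvm.testing


def is_ampere_or_newer():
    # Compute capability of the detected CUDA device/target, e.g. "8.0" on A100.
    arch = tvm.contrib.nvcc.get_target_compute_version()
    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
    return major >= 8


@tvm.testing.requires_cuda
def test_pipelined_gemm():  # hypothetical test, for illustration only
    if not is_ampere_or_newer():
        pytest.skip("needs compute capability >= 8.0 (Ampere)")
    # ... build and run the tensorized, software-pipelined GEMM here ...

Skipping surfaces the missing hardware in the pytest report, whereas the if-guard in the commit passes silently on pre-Ampere GPUs; a plausible upside of the inline guard is that the schedule construction above it still runs everywhere, so scheduling regressions are caught even without Ampere hardware.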
