Skip to content

Commit

Permalink
fix (apache#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
spectrometerHBH authored and jinhongyii committed Jun 16, 2022
1 parent c723660 commit f7002c5
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 755 deletions.
241 changes: 78 additions & 163 deletions python/tvm/meta_schedule/testing/tir_tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,94 +95,16 @@ def dot_product_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
)


# @T.prim_func
# def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
# A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_a")
# B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_b")
# C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=1, scope="wmma.accumulator")

# with T.block("root"):
# for i, j, k in T.grid(16, 16, 16):
# with T.block("update"):
# vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
# C[vii, vjj] = C[vii, vjj] + T.cast(A[vii, vkk], "float32") * T.cast(
# B[vkk, vjj], "float32"
# )


# @T.prim_func
# def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
# A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a")
# B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b")
# C = T.match_buffer(
# c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator"
# )

# with T.block("root"):
# T.reads(
# [
# C[0:16, 0:16],
# A[0:16, 0:16],
# B[0:16, 0:16],
# ]
# )
# T.writes(C[0:16, 0:16])
# T.evaluate(
# T.tvm_mma_sync(
# C.data,
# C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
# A.data,
# A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16),
# B.data,
# B.elem_offset // 256 + T.floordiv(T.floormod(B.elem_offset, 256), 16),
# C.data,
# C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
# dtype="handle",
# )
# )

# Description half of the 16x16x16 WMMA multiply-accumulate intrinsic.
# Reads float16 tiles A ("wmma.matrix_a") and B ("wmma.matrix_b") and
# accumulates into the float32 tile C ("wmma.accumulator"):
#     C[i, j] += float32(A[i, k]) * float32(B[j, k])
# B is indexed as [j, k], i.e. (spatial, reduction), not [k, j].
@T.prim_func
def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(
        a, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a"
    )
    B = T.match_buffer(
        b, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b"
    )
    C = T.match_buffer(
        c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator"
    )

    with T.block("root"):
        # C appears in both lists because the update is an accumulation.
        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j, k in T.grid(16, 16, 16):
            with T.block(""):
                # i and j are spatial axes ("S"), k is the reduction axis ("R").
                vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                C[vii, vjj] = C[vii, vjj] + T.cast(A[vii, vkk], "float32") * T.cast(
                    B[vjj, vkk], "float32"
                )


@T.prim_func
def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
def wmma_load_a_desc(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16,
scope="shared")
C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16,
scope="wmma.matrix_a")
B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16,
scope="wmma.matrix_b")
C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16,
scope="wmma.accumulator")

with T.block("root"):
T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0: 16, 0 : 16])
T.reads(A[0 : 16, 0 : 16])
T.writes(C[0 : 16, 0 : 16])
T.evaluate(T.tvm_mma_sync(C.data, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
A.data, A.elem_offset // 256,
B.data, B.elem_offset // 256,
C.data, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), dtype='handle'))


@T.prim_func
def wmma_load_a_desc(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a")

with T.block("root"):
for i, j in T.grid(16, 16):
with T.block("load"):
vii, vjj = T.axis.remap("SS", [i, j])
Expand All @@ -193,34 +115,27 @@ def wmma_load_a_desc(a: T.handle, c: T.handle) -> None:
def wmma_load_a_impl(a: T.handle, c: T.handle) -> None:
    # Implementation half of the matrix_a load intrinsic: replaces the
    # element-wise copy with a single tvm_load_matrix_sync call that moves a
    # 16x16 float16 tile from the "shared" buffer A into the "wmma.matrix_a"
    # buffer C.
    #
    # Each statement appears exactly once here; the previous text carried the
    # buffer binding, the region annotations and the evaluate call twice.
    s1 = T.var("int32")
    s0 = T.var("int32")
    # A is matched with symbolic strides; s1 (the first stride) is forwarded
    # to the load call below.
    A = T.match_buffer(
        a, (16, 16), "float16", align=128, offset_factor=16, scope="shared", strides=[s1, s0]
    )
    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a")

    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_load_matrix_sync(
                C.data,
                16,
                16,
                16,
                # elem_offset // 256 + (elem_offset % 256) // 16; presumably the
                # index of the 256-element fragment inside C -- TODO confirm.
                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                A.access_ptr("r"),
                s1,
                "row_major",
                dtype="handle",
            )
        )


@T.prim_func
def wmma_load_b_desc(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b")
A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16,
scope="shared")
C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16,
scope="wmma.matrix_b")

with T.block("root"):
T.reads(A[0 : 16, 0 : 16])
T.writes(C[0 : 16, 0 : 16])
for i, j in T.grid(16, 16):
with T.block("load"):
vii, vjj = T.axis.remap("SS", [i, j])
Expand All @@ -231,34 +146,60 @@ def wmma_load_b_desc(a: T.handle, c: T.handle) -> None:
def wmma_load_b_impl(a: T.handle, c: T.handle) -> None:
    # Implementation half of the matrix_b load intrinsic: a single
    # tvm_load_matrix_sync call that moves a 16x16 float16 tile from the
    # "shared" buffer A into the "wmma.matrix_b" buffer C.
    #
    # Each statement appears exactly once here; the previous text carried two
    # evaluate calls that disagreed on the layout argument ("row_major" and
    # "col_major"). Only the "col_major" call is kept.
    s1 = T.var("int32")
    s0 = T.var("int32")
    # A is matched with symbolic strides; s1 (the first stride) is forwarded
    # to the load call below.
    A = T.match_buffer(
        a, (16, 16), "float16", align=128, offset_factor=16, scope="shared", strides=[s1, s0]
    )
    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b")

    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_load_matrix_sync(
                C.data,
                16,
                16,
                16,
                # elem_offset // 256 + (elem_offset % 256) // 16; presumably the
                # index of the 256-element fragment inside C -- TODO confirm.
                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                A.access_ptr("r"),
                s1,
                # NOTE(review): presumably column-major because wmma_sync_desc
                # indexes B as [vjj, vkk] -- confirm against the load_b desc.
                "col_major",
                dtype="handle",
            )
        )


# Computation pattern for the WMMA multiply-accumulate intrinsic, spelled out
# as a 16x16x16 loop nest over three 16x16 tiles:
#     C[i, j] = C[i, j] + float32(A[i, k]) * float32(B[j, k])
# A and B are float16 tiles in the "wmma.matrix_a" / "wmma.matrix_b" scopes;
# C is a float32 tile in the "wmma.accumulator" scope.
@T.prim_func
def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(
        a, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a"
    )
    B = T.match_buffer(
        b, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b"
    )
    C = T.match_buffer(
        c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator"
    )

    with T.block("root"):
        # The accumulator is both read and written.
        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j, k in T.grid(16, 16, 16):
            with T.block(""):
                # "SSR": i, j are spatial; k is the reduction axis.
                vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                # B is addressed as [vjj, vkk] (spatial index first).
                C[vii, vjj] = C[vii, vjj] + T.cast(A[vii, vkk], "float32") * T.cast(
                    B[vjj, vkk], "float32"
                )


# Implementation half of the WMMA multiply-accumulate intrinsic: the loop nest
# of the description is replaced by one tvm_mma_sync call operating on the
# same three 16x16 tiles (float16 A and B, float32 accumulator C).
@T.prim_func
def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(
        a, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a"
    )
    B = T.match_buffer(
        b, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b"
    )
    C = T.match_buffer(
        c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator"
    )

    with T.block("root"):
        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_mma_sync(
                # C is passed twice: first here, and again as the last
                # (data, index) pair with the same index expression.
                C.data,
                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                # A and B use the plain elem_offset // 256 form, without the
                # extra floormod term used for C.
                A.data,
                A.elem_offset // 256,
                B.data,
                B.elem_offset // 256,
                C.data,
                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                dtype="handle",
            )
        )


@T.prim_func
def wmma_fill_desc(c: T.handle) -> None:
C = T.match_buffer(
c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator"
)
C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator")

with T.block("root"):
T.reads()
T.writes(C[0 : 16, 0 : 16])
for i, j in T.grid(16, 16):
with T.block("init"):
vii, vjj = T.axis.remap("SS", [i, j])
Expand All @@ -267,32 +208,20 @@ def wmma_fill_desc(c: T.handle) -> None:

# Implementation half of the accumulator fill intrinsic: a single
# tvm_fill_fragment call that sets the 16x16 float32 "wmma.accumulator" tile C
# to T.float32(0).
#
# Each statement appears exactly once here; the previous text carried the
# buffer binding, the region annotations and the fill call twice.
@T.prim_func
def wmma_fill_impl(c: T.handle) -> None:
    C = T.match_buffer(
        c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator"
    )
    with T.block("root"):
        # Nothing is read; the whole tile is overwritten.
        T.reads()
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_fill_fragment(
                C.data,
                16,
                16,
                16,
                # elem_offset // 256 + (elem_offset % 256) // 16; presumably the
                # index of the 256-element fragment inside C -- TODO confirm.
                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                T.float32(0),
                dtype="handle",
            )
        )


@T.prim_func
def wmma_store_desc(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(
a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator"
)
A = T.match_buffer(a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator")
C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="global")
with T.block("root"):
T.reads(A[0 : 16, 0 : 16])
T.writes(C[0 : 16, 0 : 16])
for i, j in T.grid(16, 16):
with T.block("store"):
vii, vjj = T.axis.remap("SS", [i, j])
Expand All @@ -303,28 +232,14 @@ def wmma_store_desc(a: T.handle, c: T.handle) -> None:
def wmma_store_impl(a: T.handle, c: T.handle) -> None:
    # Implementation half of the accumulator store intrinsic: a single
    # tvm_store_matrix_sync call that writes the 16x16 float32
    # "wmma.accumulator" tile A out to the "global" buffer C.
    #
    # Each statement appears exactly once here; the previous text carried the
    # buffer bindings, the region annotations and the store call twice.
    s1 = T.var("int32")
    s0 = T.var("int32")
    A = T.match_buffer(
        a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator"
    )
    # The destination is matched with symbolic strides; s1 (the first stride)
    # is forwarded to the store call below.
    C = T.match_buffer(
        c, (16, 16), "float32", align=128, offset_factor=16, scope="global", strides=[s1, s0]
    )
    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_store_matrix_sync(
                A.data,
                16,
                16,
                16,
                # elem_offset // 256 + (elem_offset % 256) // 16; presumably the
                # index of the 256-element fragment inside A -- TODO confirm.
                A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16),
                C.access_ptr("w"),
                s1,
                "row_major",
                dtype="handle",
            )
        )


# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks
Expand Down
Loading

0 comments on commit f7002c5

Please sign in to comment.