The current TIR syntax printer (introduced in #9680) fails when there are dynamic shapes in the script:
```python
from tvm.script import tir as T


@T.prim_func
def f(a: T.handle, b: T.handle, c: T.handle):
    # Symbolic (dynamic) shape variables
    N = T.var("int32")
    M = T.var("int32")
    K = T.var("int32")
    A = T.match_buffer(a, (N, K), "float32")
    B = T.match_buffer(b, (K, M), "float32")
    C = T.match_buffer(c, (N, M), "float32")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


print(f.script())
```
Expected behavior
The output script should be the same as the input.
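Actual behavior
The variables `M`, `N`, and `K` are used in the buffer annotations of the function signature before they are declared with `T.var` in the body: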
```python
# from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer[(N, K), "float32"], B: T.Buffer[(K, M), "float32"], C: T.Buffer[(N, M), "float32"]) -> None:
    K = T.var("int32")
    M = T.var("int32")
    N = T.var("int32")
    # body
    # with T.block("root")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
```
The same problem occurs if I pass the tensor shapes as parameters:
```python
from tvm.script import tir as T


@T.prim_func
def f(a: T.handle, b: T.handle, c: T.handle, N: T.int32, M: T.int32, K: T.int32):
    # Shape variables are now scalar parameters instead of T.var declarations
    A = T.match_buffer(a, (N, K), "float32")
    B = T.match_buffer(b, (K, M), "float32")
    C = T.match_buffer(c, (N, M), "float32")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


print(f.script())
```
@yzh119 The bug was introduced in the printer because it doesn't check whether a buffer's shape consists only of constant values before printing it as `T.Buffer[...]` in the function signature.
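The missing check can be sketched in plain Python. This is a hypothetical model, not TVM's actual printer code: a dimension is either an `int` constant or a `str` naming a symbolic variable, and the helpers `is_static_shape`, `print_buffer_param`, and `print_func` are invented for illustration. One possible fix it shows is to use the `T.Buffer[...]` signature sugar only for all-constant shapes, and otherwise fall back to a `T.handle` parameter plus a `T.match_buffer` call emitted after the `T.var` declarations. It models only the first reproducer (shape variables declared with `T.var`), not the case where they are scalar parameters.

```python
from typing import List, Tuple, Union

# Hypothetical model: a dimension is a constant (int) or a symbolic var name (str).
Dim = Union[int, str]


def is_static_shape(shape: List[Dim]) -> bool:
    """Return True only if every dimension is a compile-time constant."""
    return all(isinstance(dim, int) for dim in shape)


def print_buffer_param(param: str, buf: str, shape: List[Dim], dtype: str) -> Tuple[str, List[str]]:
    """Return (signature fragment, body lines) for one buffer parameter."""
    dims = ", ".join(str(d) for d in shape)
    if len(shape) == 1:
        dims += ","  # keep a one-element shape a tuple, e.g. (8,)
    if is_static_shape(shape):
        # All-constant shape: safe to print the sugar directly in the signature.
        return f'{buf}: T.Buffer[({dims}), "{dtype}"]', []
    # Dynamic shape: keep a handle and bind the buffer inside the body.
    return f"{param}: T.handle", [f'{buf} = T.match_buffer({param}, ({dims}), "{dtype}")']


def print_func(name: str, buffers: List[Tuple[str, str, List[Dim], str]]) -> str:
    """Print a function header and body prologue with declarations before uses."""
    params: List[str] = []
    decls: List[str] = []
    matches: List[str] = []
    seen: List[str] = []
    for param, buf, shape, dtype in buffers:
        # Declare each symbolic variable once, in order of first appearance.
        for dim in shape:
            if isinstance(dim, str) and dim not in seen:
                seen.append(dim)
                decls.append(f'{dim} = T.var("int32")')
        sig, lines = print_buffer_param(param, buf, shape, dtype)
        params.append(sig)
        matches.extend(lines)
    # T.var declarations come first, so match_buffer never uses an undeclared name.
    body = decls + matches
    header = f"def {name}({', '.join(params)}) -> None:"
    return "\n".join([header] + ["    " + line for line in body])


print(print_func("func", [
    ("a", "A", ["N", "K"], "float32"),
    ("b", "B", ["K", "M"], "float32"),
    ("c", "C", [16, 16], "float32"),
]))
```

In this sketch, `A` and `B` stay as `T.handle` parameters because their shapes are symbolic, while `C` keeps the `T.Buffer[...]` sugar because its shape is constant.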
CC @shingjan It's a printer bug we should fix.