[Bug] New TIR syntax printer failed to handle dynamic shape. #9953

Closed
yzh119 opened this issue Jan 17, 2022 · 1 comment

Comments

@yzh119
Member

yzh119 commented Jan 17, 2022

The current TIR syntax printer (introduced in #9680) fails when the script contains dynamic shapes:

```python
from tvm.script import tir as T

@T.prim_func
def f(a: T.handle, b: T.handle, c: T.handle):
    # Symbolic (dynamic) extents declared inside the body
    N = T.var("int32")
    M = T.var("int32")
    K = T.var("int32")
    A = T.match_buffer(a, (N, K), "float32")
    B = T.match_buffer(b, (K, M), "float32")
    C = T.match_buffer(c, (N, M), "float32")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

print(f.script())
```

Expected behavior

The printed script should match the input (apart from formatting), so that it can be parsed back into the same function.

Actual behavior

`M`, `N`, and `K` are used in the function signature before they are declared with `T.var` in the body:

```python
# from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer[(N, K), "float32"], B: T.Buffer[(K, M), "float32"], C: T.Buffer[(N, M), "float32"]) -> None:
    K = T.var("int32")
    M = T.var("int32")
    N = T.var("int32")
    # body
    # with T.block("root")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
```

The same problem occurs when the tensor shapes are passed as parameters:

```python
from tvm.script import tir as T

@T.prim_func
def f(a: T.handle, b: T.handle, c: T.handle, N: T.int32, M: T.int32, K: T.int32):
    # Symbolic extents passed in as scalar parameters
    A = T.match_buffer(a, (N, K), "float32")
    B = T.match_buffer(b, (K, M), "float32")
    C = T.match_buffer(c, (N, M), "float32")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

print(f.script())
```
@junrushao
Member

@yzh119 The bug is in the printer: it does not check whether a buffer's shape consists only of constant values before printing the buffer as a `T.Buffer[...]` parameter annotation.

CC @shingjan: this is a printer bug we should fix.
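The missing check described above can be modeled in a few lines. This is a hypothetical Python sketch of the decision only, not TVM's printer (which is implemented elsewhere); the names `is_constant_shape` and `print_buffer_param`, and the choice of falling back to a `T.handle` parameter plus `T.match_buffer` (the form used in the original input), are assumptions for illustration. A shape is modeled as a list whose entries are either `int` (constant extent) or `str` (name of a symbolic var).

```python
from typing import List, Tuple, Union

# int = constant extent, str = name of a symbolic variable (hypothetical model)
Dim = Union[int, str]


def is_constant_shape(shape: List[Dim]) -> bool:
    """True only if every extent is a compile-time constant."""
    return all(isinstance(d, int) for d in shape)


def print_buffer_param(param: str, buf: str, shape: List[Dim],
                       dtype: str) -> Tuple[str, List[str]]:
    """Return (signature entry, body lines) for one buffer parameter."""
    dims = ", ".join(str(d) for d in shape)
    if len(shape) == 1:
        dims += ","  # keep a 1-D shape a valid Python tuple
    if is_constant_shape(shape):
        # Safe to inline: the annotation references no symbolic vars.
        return f'{buf}: T.Buffer[({dims}), "{dtype}"]', []
    # Symbolic extents: keep a handle in the signature and bind the buffer
    # in the body, where the vars can be declared first.
    return f"{param}: T.handle", [f'{buf} = T.match_buffer({param}, ({dims}), "{dtype}")']


print(print_buffer_param("a", "A", ["N", "K"], "float32"))
print(print_buffer_param("a", "A", [128, 64], "float32"))
```

The fallback form is only well-ordered if the printer also emits the `T.var` declarations before any `match_buffer` line that uses them; that part is left out of the sketch.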

3 participants