Skip to content

Commit

Permalink
Removing more usage of preflattened from python files
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Apr 19, 2022
1 parent f1579c7 commit 7062789
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 124 deletions.
26 changes: 10 additions & 16 deletions tests/python/unittest/test_aot_legalize_packed_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,12 @@
class Module:
@T.prim_func
def tvm_test_cpacked(
A: T.handle, B: T.handle, C: T.handle, device_context: T.handle
A: T.Buffer[(1,), "float32"],
B: T.Buffer[(1,), "float32"],
C: T.Buffer[(1,), "float32"],
device_context: T.Buffer[(1,), "float32"],
) -> T.handle:
A_0 = T.match_buffer(A, (1,), dtype="float32")
A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32")
B_0 = T.match_buffer(B, (1,), dtype="float32")
B_0pre = T.preflattened_buffer(B_0, (1,), dtype="float32")
C_0 = T.match_buffer(C, (1,), dtype="float32")
C_0pre = T.preflattened_buffer(C_0, (1,), dtype="float32")
T.evaluate(C)
T.evaluate(C.data)

@T.prim_func
def tir_packed_call() -> None:
Expand All @@ -59,15 +56,12 @@ def tir_packed_call() -> None:
class Expected:
@T.prim_func
def tvm_test_cpacked(
A: T.handle, B: T.handle, C: T.handle, device_context: T.handle
A: T.Buffer[(1,), "float32"],
B: T.Buffer[(1,), "float32"],
C: T.Buffer[(1,), "float32"],
device_context: T.handle,
) -> T.handle:
A_0 = T.match_buffer(A, (1,), dtype="float32")
A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32")
B_0 = T.match_buffer(B, (1,), dtype="float32")
B_0pre = T.preflattened_buffer(B_0, (1,), dtype="float32")
C_0 = T.match_buffer(C, (1,), dtype="float32")
C_0pre = T.preflattened_buffer(C_0, (1,), dtype="float32")
T.evaluate(C)
T.evaluate(C.data)

@T.prim_func
def tir_packed_call() -> None:
Expand Down
Loading

0 comments on commit 7062789

Please sign in to comment.