diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py index c7c0daa30e2fc..756f3724bd994 100644 --- a/tests/python/unittest/test_aot_legalize_packed_call.py +++ b/tests/python/unittest/test_aot_legalize_packed_call.py @@ -26,15 +26,12 @@ class Module: @T.prim_func def tvm_test_cpacked( - A: T.handle, B: T.handle, C: T.handle, device_context: T.handle + A: T.Buffer[(1,), "float32"], + B: T.Buffer[(1,), "float32"], + C: T.Buffer[(1,), "float32"], + device_context: T.Buffer[(1,), "float32"], ) -> T.handle: - A_0 = T.match_buffer(A, (1,), dtype="float32") - A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32") - B_0 = T.match_buffer(B, (1,), dtype="float32") - B_0pre = T.preflattened_buffer(B_0, (1,), dtype="float32") - C_0 = T.match_buffer(C, (1,), dtype="float32") - C_0pre = T.preflattened_buffer(C_0, (1,), dtype="float32") - T.evaluate(C) + T.evaluate(C.data) @T.prim_func def tir_packed_call() -> None: @@ -59,15 +56,12 @@ def tir_packed_call() -> None: class Expected: @T.prim_func def tvm_test_cpacked( - A: T.handle, B: T.handle, C: T.handle, device_context: T.handle + A: T.Buffer[(1,), "float32"], + B: T.Buffer[(1,), "float32"], + C: T.Buffer[(1,), "float32"], + device_context: T.handle, ) -> T.handle: - A_0 = T.match_buffer(A, (1,), dtype="float32") - A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32") - B_0 = T.match_buffer(B, (1,), dtype="float32") - B_0pre = T.preflattened_buffer(B_0, (1,), dtype="float32") - C_0 = T.match_buffer(C, (1,), dtype="float32") - C_0pre = T.preflattened_buffer(C_0, (1,), dtype="float32") - T.evaluate(C) + T.evaluate(C.data) @T.prim_func def tir_packed_call() -> None: diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index ce8675f575ee5..98eb24e85f745 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -74,11 +74,8 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(placeholder_4, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(placeholder_5, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(T_subtract_1, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body for ax0_ax1_fused_1 in T.serial(0, 224): for ax2_1, ax3_inner_1 in T.grid(224, 3): @@ -89,13 +86,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(placeholder_65, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(placeholder_66, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(placeholder_67, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(T_cast_21, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body PaddedInput_7 = T.allocate([157323], "int16", "global") for i0_i1_fused_7 in T.serial(0, 229): @@ -115,9 +108,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(placeholder_29, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) - T.preflattened_buffer(T_cast_7, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body tensor_2 = T.allocate([200704], "uint8", "global") for ax0_ax1_fused_4 in T.serial(0, 56): @@ -164,13 +155,9 @@ def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr[T.uint8], slow_memory @T.prim_func def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.Ptr[T.uint8], slow_memory_7_var: T.Ptr[T.uint8]) -> None: placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8") - T.preflattened_buffer(placeholder_29, [802816], dtype="uint8") T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16") - T.preflattened_buffer(T_cast_7, [177], dtype="int16") fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(fast_memory_6_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(slow_memory_7_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body tensor_2_let = T.buffer_decl([200704], dtype="uint8") with T.let(tensor_2_let.data, T.address_of(fast_memory_6_buffer_var[0], dtype="handle")): @@ -185,15 +172,10 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: @T.prim_func def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.Ptr[T.uint8], slow_memory_3_var: T.Ptr[T.uint8]) -> None: placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8") - T.preflattened_buffer(placeholder_4, [150528], dtype="uint8") placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16") - T.preflattened_buffer(placeholder_5, [1], dtype="int16") T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16") - T.preflattened_buffer(T_subtract_1, [452], dtype="int16") fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(fast_memory_2_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(slow_memory_3_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3): T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5[0] @@ -201,17 +183,11 @@ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.Ptr[T.uint8], slow_memory_5_var: T.Ptr[T.uint8]) -> None: placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16") - T.preflattened_buffer(placeholder_65, [150528], dtype="int16") placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16") - T.preflattened_buffer(placeholder_66, [9408], dtype="int16") placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32") - T.preflattened_buffer(placeholder_67, [64], dtype="int32") T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8") - T.preflattened_buffer(T_cast_21, [289], dtype="uint8") fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(fast_memory_4_buffer_var, [200704], dtype="uint8", strides=[1], elem_offset=0, align=16) slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(slow_memory_5_buffer_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_7_let = T.buffer_decl([157323], "int16") with T.let(PaddedInput_7_let.data, T.address_of(slow_memory_5_buffer_var[802816], dtype="handle")): @@ -275,11 +251,8 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True}) placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") - T.preflattened_buffer(placeholder_2, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") - T.preflattened_buffer(placeholder_3, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") - T.preflattened_buffer(T_cast_1, [215], dtype="int16") # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @@ -289,13 +262,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") - T.preflattened_buffer(placeholder_13, [360000], dtype="int16") placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") - T.preflattened_buffer(placeholder_14, [36864], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") - T.preflattened_buffer(placeholder_15, [64], dtype="int32") T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") - T.preflattened_buffer(T_cast_5, [215], dtype="int16") # body PaddedInput_1 = T.allocate([379456], "int16", "global") for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): @@ -314,13 +283,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True}) placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") - T.preflattened_buffer(placeholder_19, [360000], dtype="int16") placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") - T.preflattened_buffer(placeholder_20, [16384], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") - T.preflattened_buffer(placeholder_21, [256], dtype="int32") T_add_1 = T.match_buffer(T_add, [407], dtype="int32") - T.preflattened_buffer(T_add_1, [407], dtype="int32") # body PaddedInput_2 = T.allocate([360000], "int16", "global") for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): @@ -340,15 +305,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True}) placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") - T.preflattened_buffer(placeholder_29, [360000], dtype="int16") placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") - T.preflattened_buffer(placeholder_27, [16384], dtype="int16") placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") - T.preflattened_buffer(placeholder_26, [256], dtype="int32") placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") - T.preflattened_buffer(placeholder_28, [1440000], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") - T.preflattened_buffer(T_cast_7, [407], dtype="uint8") # body PaddedInput_3 = T.allocate([360000], "int16", "global") for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): @@ -385,13 +345,9 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place # function attr dict T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") - T.preflattened_buffer(placeholder_7, [360000], dtype="int16") placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") - T.preflattened_buffer(placeholder_8, [4096], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") - T.preflattened_buffer(placeholder_9, [64], dtype="int32") T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") - T.preflattened_buffer(T_cast_3, [215], dtype="int16") # body PaddedInput = T.allocate([360000], "int16", "global") for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): @@ -413,13 +369,9 @@ class ResnetStructurePlanned: @T.prim_func def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.Ptr[T.uint8]) -> None: placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8") - T.preflattened_buffer(placeholder_2, [360000], dtype="uint8") placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") - T.preflattened_buffer(placeholder_3, [64], dtype="int32") T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16") - T.preflattened_buffer(T_cast_1, [215], dtype="int16") global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(global_workspace_1_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16") @@ -427,17 +379,11 @@ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(p @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.Ptr[T.uint8]) -> None: placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16") - T.preflattened_buffer(placeholder_29, [360000], dtype="int16") placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16") - T.preflattened_buffer(placeholder_27, [16384], dtype="int16") placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32") - T.preflattened_buffer(placeholder_26, [256], dtype="int32") placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32") - T.preflattened_buffer(placeholder_28, [1440000], dtype="int32") T_cast_7 = T.match_buffer(T_cast_6, [407], dtype="uint8") - T.preflattened_buffer(T_cast_7, [407], dtype="uint8") global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(global_workspace_5_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_3_let = T.buffer_decl([360000], 'int16') with T.let(PaddedInput_3_let.data, T.address_of(global_workspace_5_buffer_var[6480000], dtype="handle")): @@ -457,15 +403,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.Ptr[T.uint8]) -> None: placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16") - T.preflattened_buffer(placeholder_19, [360000], dtype="int16") placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16") - T.preflattened_buffer(placeholder_20, [16384], dtype="int16") placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32") - T.preflattened_buffer(placeholder_21, [256], dtype="int32") T_add_1 = T.match_buffer(T_add, [407], dtype="int32") - T.preflattened_buffer(T_add_1, [407], dtype="int32") global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(global_workspace_4_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_2_let = T.buffer_decl([360000], "int16") with T.let(PaddedInput_2_let.data, T.address_of(global_workspace_4_buffer_var[7200000], dtype="handle")): @@ -485,15 +426,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.Ptr[T.uint8]) -> None: placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16") - T.preflattened_buffer(placeholder_7, [360000], dtype="int16") placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16") - T.preflattened_buffer(placeholder_8, [4096], dtype="int16") placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32") - T.preflattened_buffer(placeholder_9, [64], dtype="int32") T_cast_3 = T.match_buffer(T_cast_2, [215], dtype="int16") - T.preflattened_buffer(T_cast_3, [215], dtype="int16") global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(global_workspace_2_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_let = T.buffer_decl([360000], "int16") with T.let(PaddedInput_let.data, T.address_of(global_workspace_2_buffer_var[7200000], dtype="handle")): @@ -512,15 +448,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place @T.prim_func def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.Ptr[T.uint8]) -> None: placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16") - T.preflattened_buffer(placeholder_13, [360000], dtype="int16") placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16") - T.preflattened_buffer(placeholder_14, [36864], dtype="int16") placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32") - T.preflattened_buffer(placeholder_15, [64], dtype="int32") T_cast_5 = T.match_buffer(T_cast_4, [215], dtype="int16") - T.preflattened_buffer(T_cast_5, [215], dtype="int16") global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) - T.preflattened_buffer(global_workspace_3_buffer_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_1_let = T.buffer_decl([379456], "int16") with T.let(PaddedInput_1_let.data, T.address_of(global_workspace_3_buffer_var[0], dtype="handle")): diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 0610559a05d89..73be9d8cdc58f 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -636,27 +636,5 @@ def test_non_integer_typed_block_iter(): check_error(non_integer_typed_block_iter, 3) -def preflattened_buffer_map_align_nonint(foo: T.handle): - foo_1 = T.match_buffer(foo, [1]) - T.preflattened_buffer( - foo_1, [1], align="bar" - ) # check_error: align: want int or IntImm, got 'bar' - - -def test_preflattened_buffer_map_align(): - check_error(preflattened_buffer_map_align_nonint, 3) - - -def preflattened_buffer_map_offset_factor_nonint(foo: T.handle): - foo_1 = T.match_buffer(foo, [1]) - T.preflattened_buffer( - foo_1, [1], offset_factor="bar" - ) # check_error: offset_factor: want int or IntImm, got 'bar' - - -def test_preflattened_buffer_map_offset_factor(): - check_error(preflattened_buffer_map_offset_factor_nonint, 3) - - if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 4a2482c11d226..26a6f4530bda2 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -181,23 +181,6 @@ def test_dynamic_shape_gemm(): assert_structural_equal(gemm_dyn_shape, gemm_dyn_shape_roundtrip) -@T.prim_func -def preflattened_buffer_map(A: T.handle, B: T.handle): - A_1 = T.match_buffer(A, [1]) - T.preflattened_buffer(A_1, [1], align=T.int32(1), offset_factor=T.int64(2)) - B_1 = T.match_buffer(B, [1]) - T.preflattened_buffer(B_1, [1]) - B_1[0] = A_1[0] - - -def test_preflattened_buffer_map(): - A_var = [ - k for k, _ in preflattened_buffer_map.preflattened_buffer_map.items() if k.name == "A" - ][0] - assert preflattened_buffer_map.preflattened_buffer_map[A_var].data_alignment == 1 - assert preflattened_buffer_map.preflattened_buffer_map[A_var].offset_factor == 2 - - @T.prim_func def match_buffer_int64(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (T.int64(128), T.int64(128)), dtype="float32")