From b09e72b54b028ebe8896afc605baac8abde4419e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 24 Mar 2023 03:00:30 -0500 Subject: [PATCH] [TIR] Legalize dtype of constants in IndexMap (#14385) Previously, the legalization was only handled by propagating the dtype of the indices to the transformed indices. As a result, output indices whose value did not depend on the input index would be left with the incorrect dtype. --- .../primitive/layout_transformation.cc | 18 ++++++++-- .../test_tir_schedule_transform_layout.py | 36 +++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 2235af730214..bb2abc559d2c 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -1095,8 +1095,17 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array& Array initial_indices; Map var_map; + std::optional index_dtype = std::nullopt; for (size_t i = 0; i < args.size(); ++i) { + if (index_dtype.has_value()) { + ICHECK_EQ(*index_dtype, args[i]->dtype) + << "Buffer index " << args[i] << " has dtype " << args[i]->dtype + << ", but previous index for the same buffer access used index type " << *index_dtype; + } else { + index_dtype = args[i]->dtype; + } + if (args[i]->dtype != initial_indices_orig[i].dtype()) { auto new_idx = Var(initial_indices_orig[i]->name_hint, args[i]->dtype); initial_indices.push_back(new_idx); @@ -1108,8 +1117,13 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array& if (!var_map.empty()) { auto final_indices = index_map->final_indices.Map([&](PrimExpr index) { - return SubstituteWithDataTypeLegalization(index, - [&](const Var& var) { return var_map.Get(var); }); + if (auto* ptr = index.as()) { + ICHECK(index_dtype.has_value()); + return tir::make_const(*index_dtype, ptr->value); + } else { + return SubstituteWithDataTypeLegalization(index, + [&](const Var& var) { return var_map.Get(var); }); + } }); Optional opt_inverse_index_map = Downcast>(index_map->inverse_index_map); diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index c9a8f70ef7b3..8de11d8bd519 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -1049,5 +1049,41 @@ def func(A: T.Buffer(T.int64(58), "int32")): ) +def test_index_map_dtype_legalize_with_constant(): + """Legalization of inverse containing a constant output + + The index map `lambda i,j: [i, j//8, j % 8]` has an inverse `lambda i,j,k: [i, 8*j+k]`. + """ + + @T.prim_func + def func(A: T.Buffer(T.int64(16), "int32")): + for i in T.grid(T.int64(16)): + with T.block("block"): + vi = T.axis.remap("S", [i]) + A[vi] = 0 + + sch = tir.Schedule(func) + + # Triggering the error requires an IndexMap that introduces padding + func = lambda i: [ + # And a constant to be one of the output indices. + tir.const(0, i.dtype), + (i + 1) // 8, + (i + 1) % 8, + ] + + # Previously, the legalization was only handled by propagating the + # dtype of the indices to the transformed indices. As a result, + # output indices whose value did not depend on the input index + # would be left with the incorrect dtype. + + # Prior to the bugfix, this resulted in the following error is + # raised from the IterVar constructor. + # + # TVMError: Check failed: dom->extent.dtype() == var.dtype() (int64 vs. int32) : + # The dtype of the extent of an IterVar (int64) must match its associated Var's dtype (int32) + sch.transform_layout(block="block", buffer="A", index_map=func, pad_value=0) + + if __name__ == "__main__": tvm.testing.main()