Skip to content

Commit

Permalink
[TIR] Legalize dtype of constants in IndexMap (#14385)
Browse files Browse the repository at this point in the history
Previously, the legalization was only handled by propagating the dtype
of the indices to the transformed indices.  As a result, output
indices whose value did not depend on the input index would be left
with the incorrect dtype.
  • Loading branch information
Lunderberg authored Mar 24, 2023
1 parent ad6fbec commit b09e72b
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
18 changes: 16 additions & 2 deletions src/tir/schedule/primitive/layout_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1095,8 +1095,17 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array<PrimExpr>&

Array<Var> initial_indices;
Map<Var, PrimExpr> var_map;
std::optional<DataType> index_dtype = std::nullopt;

for (size_t i = 0; i < args.size(); ++i) {
if (index_dtype.has_value()) {
ICHECK_EQ(*index_dtype, args[i]->dtype)
<< "Buffer index " << args[i] << " has dtype " << args[i]->dtype
<< ", but previous index for the same buffer access used index type " << *index_dtype;
} else {
index_dtype = args[i]->dtype;
}

if (args[i]->dtype != initial_indices_orig[i].dtype()) {
auto new_idx = Var(initial_indices_orig[i]->name_hint, args[i]->dtype);
initial_indices.push_back(new_idx);
Expand All @@ -1108,8 +1117,13 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array<PrimExpr>&

if (!var_map.empty()) {
auto final_indices = index_map->final_indices.Map([&](PrimExpr index) {
return SubstituteWithDataTypeLegalization(index,
[&](const Var& var) { return var_map.Get(var); });
if (auto* ptr = index.as<IntImmNode>()) {
ICHECK(index_dtype.has_value());
return tir::make_const(*index_dtype, ptr->value);
} else {
return SubstituteWithDataTypeLegalization(index,
[&](const Var& var) { return var_map.Get(var); });
}
});
Optional<IndexMap> opt_inverse_index_map =
Downcast<Optional<IndexMap>>(index_map->inverse_index_map);
Expand Down
36 changes: 36 additions & 0 deletions tests/python/unittest/test_tir_schedule_transform_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,5 +1049,41 @@ def func(A: T.Buffer(T.int64(58), "int32")):
)


def test_index_map_dtype_legalize_with_constant():
"""Legalization of inverse containing a constant output
The index map `lambda i,j: [i, j//8, j % 8]` has an inverse `lambda i,j,k: [i, 8*j+k]`.
"""

@T.prim_func
def func(A: T.Buffer(T.int64(16), "int32")):
for i in T.grid(T.int64(16)):
with T.block("block"):
vi = T.axis.remap("S", [i])
A[vi] = 0

sch = tir.Schedule(func)

# Triggering the error requires an IndexMap that introduces padding
func = lambda i: [
# And a constant to be one of the output indices.
tir.const(0, i.dtype),
(i + 1) // 8,
(i + 1) % 8,
]

# Previously, the legalization was only handled by propagating the
# dtype of the indices to the transformed indices. As a result,
# output indices whose value did not depend on the input index
# would be left with the incorrect dtype.

# Prior to the bugfix, this resulted in the following error is
# raised from the IterVar constructor.
#
# TVMError: Check failed: dom->extent.dtype() == var.dtype() (int64 vs. int32) :
# The dtype of the extent of an IterVar (int64) must match its associated Var's dtype (int32)
sch.transform_layout(block="block", buffer="A", index_map=func, pad_value=0)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit b09e72b

Please sign in to comment.