[TIR] Support affine expressions as indices in reverse compute inline
vinx13 committed May 13, 2022
1 parent be65732 commit 5186f7c
Showing 2 changed files with 254 additions and 58 deletions.
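Before the diff, a sketch of what this change enables: `reverse_compute_inline` now accepts a consumer block whose `BufferLoad` indices are bijective affine expressions of its block iters, rather than requiring bare, distinct index variables. The driver below is illustrative only, adapted from the new tests added in this commit (it is not part of the diff):

import tvm
from tvm import tir
from tvm.script import tir as T

@T.prim_func
def before(
    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(8, 16, 128), "float32"]
) -> None:
    B = T.alloc_buffer((128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j, k in T.grid(8, 16, 128):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSS", [i, j, k])
            # Affine (and bijective) indices into B: rejected before this commit.
            C[vi, vj, vk] = B[vi * 16 + vj, vk] + 1.0

sch = tir.Schedule(before, debug_mask="all")
sch.reverse_compute_inline(sch.get_block("C"))
# Block "B" now stores through the inverted affine mapping:
#   C[vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + 1.0
print(sch.mod.script())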
172 changes: 114 additions & 58 deletions src/tir/schedule/primitive/compute_inline.cc
@@ -24,12 +24,13 @@ namespace tir {
static const char kErrBodyInline[] = R"(The body of the inlined block should be in form of
'A[i, j, k, ...] = f(i, j, k, ...)',
where the indices on the left are distinct atomic variables,
and there should not no variables other than the index variables)";
and there should be no variables other than the index variables)";

static const char kErrBodyReverseInline[] = R"(The body of the inlined block should be in form of
`B[...] = g(i, j, k, A[i, j, k, ...] ...)`,
`B[...] = g(i, j, k, A[f(i, j, k, ...)] ...)`,
where A is the only buffer the block consumes, whose indices are distinct atomic variables,
and there should not no variables other than the index variables)";
and there should be no variables other than the index variables, and f is a bijective affine
mapping)";

class HasInitBlock : public ScheduleError {
public:
@@ -257,57 +258,6 @@ class BaseInliner : public StmtExprMutator {
return std::move(tgt_block);
}

/*!
* \brief Check if the indices are atomic distinct variables and the access is n-dimensional.
* If so, set `self->idx_vars_` properly.
* \param indices The indices to be extracted
* \param expected_ndim The expected ndim of the access
* \return A boolean flag indicating if the check is successful
*/
bool UpdateAndCheckIndexVars(const Array<PrimExpr>& indices, int expected_ndim) {
int n = indices.size();
if (n != expected_ndim) {
// Failure: dimension mismatch
return false;
}
std::vector<const VarNode*> result;
result.reserve(n);
for (const PrimExpr& i : indices) {
if (const auto* var = i.as<VarNode>()) {
result.push_back(var);
} else {
// Failure: indexing expression is not a variable
return false;
}
}
using DistinctSet = std::unordered_set<const VarNode*>;
int n_distinct = DistinctSet(result.begin(), result.end()).size();
if (n != n_distinct) {
// Failure: indexing variables are not distinct
return false;
}
if (idx_vars_.empty()) {
idx_vars_ = std::move(result);
} else if (!support::ArrayWithSameContent(idx_vars_, result)) {
// Failure: indexing variables are not consistent in different BufferLoads
return false;
}
return true;
}

/*!
* \brief Set the mapping of index substitution `self->idx_sub_`
* \param indices The expressions that the corresponding index variables are replaced to
*/
void SetIndexSubstitution(const Array<PrimExpr>& indices) {
ICHECK_EQ(indices.size(), idx_vars_.size());
int n = idx_vars_.size();
idx_sub_.reserve(n);
for (int i = 0; i < n; ++i) {
idx_sub_[idx_vars_[i]] = indices[i];
}
}

/*!
* \brief Count the number of undefined variables that are not used
* as buffer objects.
Expand Down Expand Up @@ -490,6 +440,57 @@ class ComputeInliner : public BaseInliner {
SetIndexSubstitution(load->indices);
return Substitute(inlined_store_->value, idx_sub_);
}

/*!
* \brief Check if the indices are atomic distinct variables and the access is n-dimensional.
* If so, set `self->idx_vars_` properly.
* \param indices The indices to be extracted
* \param expected_ndim The expected ndim of the access
* \return A boolean flag indicating if the check is successful
*/
bool UpdateAndCheckIndexVars(const Array<PrimExpr>& indices, int expected_ndim) {
int n = indices.size();
if (n != expected_ndim) {
// Failure: dimension mismatch
return false;
}
std::vector<const VarNode*> result;
result.reserve(n);
for (const PrimExpr& i : indices) {
if (const auto* var = i.as<VarNode>()) {
result.push_back(var);
} else {
// Failure: indexing expression is not a variable
return false;
}
}
using DistinctSet = std::unordered_set<const VarNode*>;
int n_distinct = DistinctSet(result.begin(), result.end()).size();
if (n != n_distinct) {
// Failure: indexing variables are not distinct
return false;
}
if (idx_vars_.empty()) {
idx_vars_ = std::move(result);
} else if (!support::ArrayWithSameContent(idx_vars_, result)) {
// Failure: indexing variables are not consistent in different BufferLoads
return false;
}
return true;
}

/*!
* \brief Set the mapping of index substitution `self->idx_sub_`
* \param indices The expressions that the corresponding index variables are replaced to
*/
void SetIndexSubstitution(const Array<PrimExpr>& indices) {
ICHECK_EQ(indices.size(), idx_vars_.size());
int n = idx_vars_.size();
idx_sub_.reserve(n);
for (int i = 0; i < n; ++i) {
idx_sub_[idx_vars_[i]] = indices[i];
}
}
};

/*!
@@ -534,13 +535,33 @@ class ReverseComputeInliner : public BaseInliner {
// Failure: no BufferLoad from the `inlined_buffer_`
return false;
}
int n_vars = GetNumUndefinedNonpointerVars(GetRef<Stmt>(inlined_store_));

// Collect block iter domains and update the substitution map
Map<Var, Range> consumer_iter_doms;
for (const auto& iter_var : consumer_block->iter_vars) {
consumer_iter_doms.Set(iter_var->var, iter_var->dom);
// Set default mapping for unit iters
if (is_const_int(iter_var->dom->extent, 1) && is_const_int(iter_var->dom->min)) {
idx_sub_[iter_var->var.get()] = iter_var->dom->min;
}
}

for (const BufferLoadNode* load : loads) {
if (!UpdateAndCheckIndexVars(load->indices, n_vars)) {
// Failure: incorrect or inconsistent index vars
if (!UpdateAndCheckIndexExprs(load->indices)) {
return false;
}
}

buffer_load_iter_map_ = arith::DetectIterMap(
/*indices=*/buffer_load_indices_,
/*input_iters=*/consumer_iter_doms,
/*predicate=*/true,
/*require_bijective=*/true,
/*analyzer=*/&analyzer_);
if (buffer_load_iter_map_.empty()) {
// Failure: indices of BufferLoad are not bijective affine
return false;
}
return true;
}

@@ -556,8 +577,20 @@
return ReplaceInlinedBuffer(std::move(store));
}

/*!
* \brief Apply the inverse of `buffer_load_iter_map_` to producer indices. Update `idx_sub_` with
* the result, which will later be used to transform the BufferStore indices of the producer.
* \param producer_indices The BufferStore indices of the producer.
*/
void CreateInverseMapping(const Array<PrimExpr> producer_indices) {
auto inverse_iter_map = arith::InverseAffineIterMap(buffer_load_iter_map_, producer_indices);
for (const auto& pair : inverse_iter_map) {
idx_sub_[pair.first.get()] = pair.second;
}
}

Stmt ReplaceInlinedBuffer(BufferStore producer) {
SetIndexSubstitution(producer->indices);
CreateInverseMapping(producer->indices);
producer_rhs_ = producer->value;
return Substituter(this)(GetRef<BufferStore>(inlined_store_));
}
@@ -588,8 +621,31 @@ class ReverseComputeInliner : public BaseInliner {
return std::move(extractor.result);
}

/*!
* \brief Update `buffer_load_indices_` with the given indices. If `buffer_load_indices_` is
* already non-empty, check that it is consistent with the given indices.
* \param indices The indices
* \return A boolean flag indicating if the check is successful
*/
bool UpdateAndCheckIndexExprs(const Array<PrimExpr>& indices) {
if (buffer_load_indices_.empty()) {
buffer_load_indices_ = indices;
} else if (!std::equal(buffer_load_indices_.begin(), buffer_load_indices_.end(),
                       indices.begin(), indices.end(), ExprDeepEqual())) {
// Failure: indices are not consistent in different BufferLoads
return false;
}
return true;
}

/*! \brief The RHS value of the producer's BufferStore statement */
PrimExpr producer_rhs_{nullptr};
/*! \brief The indices of the consumer's BufferLoad */
Array<PrimExpr> buffer_load_indices_;
/*! \brief The IterMap representing the indices of the consumer's BufferLoad */
Array<arith::IterSumExpr> buffer_load_iter_map_{nullptr};
/*! \brief The arithmetic analyzer */
arith::Analyzer analyzer_;
};

void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
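The crux of the new `CanReverseInline` logic is the `DetectIterMap` call above with `require_bijective=true`: the consumer's load indices must form a bijective affine map over its block iter domains, otherwise inlining fails. A rough Python-side illustration of the same check via `tvm.arith.detect_iter_map` (a hedged sketch, assuming the `require_bijective` keyword of this era's binding; later versions changed the signature):

import tvm
from tvm import tir
from tvm.arith import detect_iter_map

vi = tir.Var("vi", "int32")
vj = tir.Var("vj", "int32")
vk = tir.Var("vk", "int32")
# Iter domains mirroring T.grid(8, 16, 128) in the tests below.
input_iters = {
    vi: tvm.ir.Range(0, 8),
    vj: tvm.ir.Range(0, 16),
    vk: tvm.ir.Range(0, 128),
}

# Bijective affine indices, as in B[vi * 16 + vj, vk]: detection succeeds.
ok = detect_iter_map([vi * 16 + vj, vk], input_iters, require_bijective=True)
assert len(ok) == 2

# Duplicated indices, as in B[vi * 16 + vj, vi * 16 + vj], are not bijective:
# detection yields an empty result and reverse_compute_inline raises.
bad = detect_iter_map(
    [vi * 16 + vj, vi * 16 + vj], input_iters, require_bijective=True
)
assert len(bad) == 0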
140 changes: 140 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_inline.py
@@ -169,6 +169,112 @@ def elementwise_multi_reverse_loads_inlined(a: T.handle, c: T.handle) -> None:
C[vi, vj] = (A[vi, vj] * 2.0 + 1.0) * (A[vi, vj] * 2.0 * 2.0) + 3.0


@T.prim_func
def elementwise_reverse_affine_load(
A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(8, 32, 8, 8), "float32"]
) -> None:
B = T.alloc_buffer((128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j, k, l in T.grid(8, 32, 8, 8):
with T.block("C"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
C[vi, vj, vk, vl] = B[
((((vi * 32) + vj) * 8 + vk) * 8 + vl) // 128,
((((vi * 32) + vj) * 8 + vk) * 8 + vl) % 128,
]


@T.prim_func
def elementwise_reverse_affine_load_inlined(
A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(8, 32, 8, 8), "float32"]
) -> None:
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
C[
(vj + vi * 128) // 2048,
(vj + vi * 128) // 64 % 32,
((vj + vi * 128) // 8) % 8,
(vj + vi * 128) % 8,
] = (
A[vi, vj] * 2.0
)


@T.prim_func
def elementwise_reverse_affine_load_unit_iter(
A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(8, 16, 1), "float32"],
D: T.Buffer[(1, 8, 16, 128), "float32"],
) -> None:
C = T.alloc_buffer((128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = A[vi, vj] * 2.0
for i, j, k in T.grid(8, 16, 128):
with T.block("C"):
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
D[0, vi, vj, vk] = C[vi * 16 + vj, vk] + B[vi, vj, 0]


@T.prim_func
def elementwise_reverse_affine_load_unit_iter_inlined(
A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(8, 16, 1), "float32"],
D: T.Buffer[(1, 8, 16, 128), "float32"],
) -> None:
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
D[0, vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + B[vi // 16, vi % 16, 0]


@T.prim_func
def elementwise_multi_reverse_affine_load(
A: T.Buffer[(128, 128), "float32"],
C: T.Buffer[(8, 16, 128), "float32"],
) -> None:
B = T.alloc_buffer((128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j, k in T.grid(8, 16, 128):
with T.block("C"):
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
C[vi, vj, vk] = B[vi * 16 + vj, vk] + B[vi * 16 + vj, vk]


@T.prim_func
def elementwise_multi_reverse_affine_load_inlined(
A: T.Buffer[(128, 128), "float32"],
C: T.Buffer[(8, 16, 128), "float32"],
) -> None:
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + A[vi, vj] * 2.0


@T.prim_func
def elementwise_reverse_non_affine_load(
A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(8, 16, 128), "float32"]
) -> None:
B = T.alloc_buffer((128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j, k in T.grid(8, 16, 128):
with T.block("C"):
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
C[vi, vj, vk] = B[vi * 16 + vj, vi * 16 + vj]


@T.prim_func
def opaque_access_load(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
@@ -520,6 +626,40 @@ def test_reverse_compute_multi_reverse_loads():
verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_loads)


def test_reverse_compute_inline_affine_load():
sch = tir.Schedule(elementwise_reverse_affine_load, debug_mask="all")
block_c = sch.get_block("C")
sch.reverse_compute_inline(block_c)
tvm.ir.assert_structural_equal(elementwise_reverse_affine_load_inlined, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load)


def test_reverse_compute_inline_multi_affine_load():
sch = tir.Schedule(elementwise_multi_reverse_affine_load, debug_mask="all")
block_c = sch.get_block("C")
sch.reverse_compute_inline(block_c)
tvm.ir.assert_structural_equal(elementwise_multi_reverse_affine_load_inlined, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_affine_load)


def test_reverse_compute_inline_affine_load_unit_iter():
sch = tir.Schedule(elementwise_reverse_affine_load_unit_iter, debug_mask="all")
block_c = sch.get_block("C")
sch.reverse_compute_inline(block_c)
tvm.ir.assert_structural_equal(
elementwise_reverse_affine_load_unit_iter_inlined, sch.mod["main"]
)
verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load_unit_iter)


def test_reverse_compute_fail_non_affine_load():
sch = tir.Schedule(elementwise_reverse_non_affine_load, debug_mask="all")
block_c = sch.get_block("C")
with pytest.raises(tvm.tir.ScheduleError):
sch.reverse_compute_inline(block_c)


def test_reverse_compute_fail_multi_reverse_loads():
sch = tir.Schedule(elementwise_multi_loads, debug_mask="all")
block_c = sch.get_block("C")
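For completeness, the inverse substitution built by `CreateInverseMapping` can be reproduced from Python with `tvm.arith.inverse_affine_iter_map`, the binding of the `InverseAffineIterMap` call used in the C++ above (a hedged sketch, assuming that binding is available in this era; the variables x and y stand in for the producer's store indices and are not from the commit):

import tvm
from tvm import tir
from tvm.arith import detect_iter_map, inverse_affine_iter_map

vi = tir.Var("vi", "int32")
vj = tir.Var("vj", "int32")
vk = tir.Var("vk", "int32")
input_iters = {
    vi: tvm.ir.Range(0, 8),
    vj: tvm.ir.Range(0, 16),
    vk: tvm.ir.Range(0, 128),
}
# The consumer loads B[vi * 16 + vj, vk]; require a bijective affine map.
iter_map = detect_iter_map([vi * 16 + vj, vk], input_iters, require_bijective=True)

# x, y stand in for the producer's BufferStore indices into B. The inverse
# rewrites the consumer iters in terms of them:
#   {vi: x // 16, vj: x % 16, vk: y}
x = tir.Var("x", "int32")
y = tir.Var("y", "int32")
for var, expr in inverse_affine_iter_map(iter_map, [x, y]).items():
    print(var, "->", expr)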
