[TIR] Support affine expressions as indices in reverse compute inline #11317

Merged · 2 commits · May 16, 2022
174 changes: 116 additions & 58 deletions src/tir/schedule/primitive/compute_inline.cc
@@ -24,12 +24,13 @@ namespace tir {
static const char kErrBodyInline[] = R"(The body of the inlined block should be in form of
'A[i, j, k, ...] = f(i, j, k, ...)',
where the indices on the left are distinct atomic variables,
and there should not no variables other than the index variables)";
and there should be no variables other than the index variables)";

static const char kErrBodyReverseInline[] = R"(The body of the inlined block should be in form of
`B[...] = g(i, j, k, A[i, j, k, ...] ...)`,
`B[...] = g(i, j, k, A[f(i, j, k, ...)] ...)`,
where A is the only buffer the block consumes, whose indices are distinct atomic variables,
and there should not no variables other than the index variables)";
and there should be no variables other than the index variables, and f is a bijective affine
mapping)";

class HasInitBlock : public ScheduleError {
public:
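
With this change, `reverse_compute_inline` accepts consumers whose load indices form a bijective affine map rather than only bare index variables. A minimal usage sketch via the Python schedule API (`mod` is assumed to be an IRModule shaped like `elementwise_reverse_affine_load` from the new tests, where block "C" reads the producer through such a mapping):

```python
import tvm
from tvm import tir

# `mod` is an assumed IRModule whose block "C" loads the producer's buffer
# through a bijective affine index, e.g. B[t // 128, t % 128].
sch = tir.Schedule(mod, debug_mask="all")
sch.reverse_compute_inline(sch.get_block("C"))  # previously raised ScheduleError
print(sch.mod.script())                         # inspect the inlined result
```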
@@ -257,57 +258,6 @@ class BaseInliner : public StmtExprMutator {
return std::move(tgt_block);
}

/*!
* \brief Check if the indices are atomic distinct variables and the access is n-dimensional.
* If so, set `self->idx_vars_` properly.
* \param indices The indices to be extracted
* \param expected_ndim The expected ndim of the access
* \return A boolean flag indicating if the check is successful
*/
bool UpdateAndCheckIndexVars(const Array<PrimExpr>& indices, int expected_ndim) {
int n = indices.size();
if (n != expected_ndim) {
// Failure: dimension mismatch
return false;
}
std::vector<const VarNode*> result;
result.reserve(n);
for (const PrimExpr& i : indices) {
if (const auto* var = i.as<VarNode>()) {
result.push_back(var);
} else {
// Failure: indexing expression is not a variable
return false;
}
}
using DistinctSet = std::unordered_set<const VarNode*>;
int n_distinct = DistinctSet(result.begin(), result.end()).size();
if (n != n_distinct) {
// Failure: indexing variables are not distinct
return false;
}
if (idx_vars_.empty()) {
idx_vars_ = std::move(result);
} else if (!support::ArrayWithSameContent(idx_vars_, result)) {
// Failure: indexing variables are not consistent in different BufferLoads
return false;
}
return true;
}

/*!
* \brief Set the mapping of index substitution `self->idx_sub_`
* \param indices The expressions that the corresponding index variables are replaced to
*/
void SetIndexSubstitution(const Array<PrimExpr>& indices) {
ICHECK_EQ(indices.size(), idx_vars_.size());
int n = idx_vars_.size();
idx_sub_.reserve(n);
for (int i = 0; i < n; ++i) {
idx_sub_[idx_vars_[i]] = indices[i];
}
}

/*!
* \brief Count the number of undefined variables that are not used
* as buffer objects.
Expand Down Expand Up @@ -490,6 +440,57 @@ class ComputeInliner : public BaseInliner {
SetIndexSubstitution(load->indices);
return Substitute(inlined_store_->value, idx_sub_);
}

/*!
* \brief Check if the indices are atomic distinct variables and the access is n-dimensional.
* If so, set `self->idx_vars_` properly.
* \param indices The indices to be extracted
* \param expected_ndim The expected ndim of the access
* \return A boolean flag indicating if the check is successful
*/
bool UpdateAndCheckIndexVars(const Array<PrimExpr>& indices, int expected_ndim) {
int n = indices.size();
if (n != expected_ndim) {
// Failure: dimension mismatch
return false;
}
std::vector<const VarNode*> result;
result.reserve(n);
for (const PrimExpr& i : indices) {
if (const auto* var = i.as<VarNode>()) {
result.push_back(var);
} else {
// Failure: indexing expression is not a variable
return false;
}
}
using DistinctSet = std::unordered_set<const VarNode*>;
int n_distinct = DistinctSet(result.begin(), result.end()).size();
if (n != n_distinct) {
// Failure: indexing variables are not distinct
return false;
}
if (idx_vars_.empty()) {
idx_vars_ = std::move(result);
} else if (!support::ArrayWithSameContent(idx_vars_, result)) {
// Failure: indexing variables are not consistent in different BufferLoads
return false;
}
return true;
}

/*!
* \brief Set the mapping of index substitution `self->idx_sub_`
* \param indices The expressions that the corresponding index variables are replaced to
*/
void SetIndexSubstitution(const Array<PrimExpr>& indices) {
ICHECK_EQ(indices.size(), idx_vars_.size());
int n = idx_vars_.size();
idx_sub_.reserve(n);
for (int i = 0; i < n; ++i) {
idx_sub_[idx_vars_[i]] = indices[i];
}
}
};
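
For reference, `UpdateAndCheckIndexVars` (now moved into `ComputeInliner`) still enforces the strict pattern for forward inlining. A toy Python model of the check, with indices represented as strings rather than TVM `PrimExpr`s (illustrative only):

```python
def indices_are_distinct_vars(indices, expected_ndim):
    """Toy model of UpdateAndCheckIndexVars over string 'indices'."""
    if len(indices) != expected_ndim:
        return False                              # dimension mismatch
    if not all(idx.isidentifier() for idx in indices):
        return False                              # index is not an atomic variable
    return len(set(indices)) == len(indices)      # variables must be pairwise distinct

assert indices_are_distinct_vars(["vi", "vj"], 2)           # A[vi, vj] passes
assert not indices_are_distinct_vars(["vi", "vi"], 2)       # repeated variable
assert not indices_are_distinct_vars(["vi + 1", "vj"], 2)   # affine, not atomic
```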

/*!
@@ -534,13 +535,34 @@ class ReverseComputeInliner : public BaseInliner {
// Failure: no BufferLoad from the `inlined_buffer_`
return false;
}
int n_vars = GetNumUndefinedNonpointerVars(GetRef<Stmt>(inlined_store_));

// Collect block iter domains and update the substitution map
Map<Var, Range> consumer_iter_doms;
for (const auto& iter_var : consumer_block->iter_vars) {
consumer_iter_doms.Set(iter_var->var, iter_var->dom);
// Set default mapping for unit iters
if (is_const_int(iter_var->dom->extent, 1) && is_const_int(iter_var->dom->min)) {
idx_sub_[iter_var->var.get()] = iter_var->dom->min;
}
}

for (const BufferLoadNode* load : loads) {
if (!UpdateAndCheckIndexVars(load->indices, n_vars)) {
// Failure: incorrect or inconsistent index vars
if (!UpdateAndCheckIndexExprs(load->indices)) {
return false;
}
}

buffer_load_iter_map_ = arith::DetectIterMap(
/*indices=*/buffer_load_indices_,
/*input_iters=*/consumer_iter_doms,
/*predicate=*/true,
/*require_bijective=*/true,
/*analyzer=*/&analyzer_,
/*simplify_trivial_iterators=*/false);
if (buffer_load_iter_map_.empty()) {
// Failure: indices of BufferLoad are not bijective affine
return false;
}
return true;
}
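
A worked instance of the bijectivity requirement, using the indices from `elementwise_reverse_affine_load` in the new tests: the consumer's flat index is

$$t = \big((32 v_i + v_j) \cdot 8 + v_k\big) \cdot 8 + v_l, \qquad (v_i, v_j, v_k, v_l) \in [0,8) \times [0,32) \times [0,8) \times [0,8),$$

and the load indices are $(\lfloor t / 128 \rfloor,\; t \bmod 128)$. Each iteration point produces a distinct $t \in [0, 16384)$, and $16384 = 128 \times 128$, so the consumer covers the producer buffer exactly once; `DetectIterMap` with `require_bijective=true` accepts it. An index map that skips or repeats cells (see `elementwise_reverse_non_affine_load` below) yields an empty result instead, and the primitive raises a `ScheduleError`.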

@@ -556,8 +578,20 @@
return ReplaceInlinedBuffer(std::move(store));
}

/*!
* \brief Apply the inverse of `buffer_load_iter_map_` to the producer indices and update
* `idx_sub_` with the result, which is later used to transform the producer's BufferStore indices.
* \param producer_indices The BufferStore indices of the producer.
*/
void CreateInverseMapping(const Array<PrimExpr>& producer_indices) {
auto inverse_iter_map = arith::InverseAffineIterMap(buffer_load_iter_map_, producer_indices);
for (const auto& pair : inverse_iter_map) {
idx_sub_[pair.first.get()] = pair.second;
}
}

Stmt ReplaceInlinedBuffer(BufferStore producer) {
SetIndexSubstitution(producer->indices);
CreateInverseMapping(producer->indices);
producer_rhs_ = producer->value;
return Substituter(this)(GetRef<BufferStore>(inlined_store_));
}
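
The inverse expressions that `InverseAffineIterMap` feeds into `idx_sub_` can be sanity-checked with plain integer arithmetic. A pure-Python round trip for the worked instance above (no TVM calls; the divisors 2048/64/8 are the consumer iters' strides):

```python
# consumer iters -> flat index t -> producer coords (i, j) -> consumer iters,
# using the inverse expressions visible in elementwise_reverse_affine_load_inlined.
for vi, vj, vk, vl in [(0, 0, 0, 0), (3, 17, 5, 2), (7, 31, 7, 7)]:
    t = ((vi * 32 + vj) * 8 + vk) * 8 + vl
    i, j = t // 128, t % 128                  # the producer's store coordinates
    flat = i * 128 + j                        # equals t by construction
    assert (flat // 2048, flat // 64 % 32, flat // 8 % 8, flat % 8) == (vi, vj, vk, vl)
```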
@@ -588,8 +622,32 @@
return std::move(extractor.result);
}

/*!
* \brief Update `buffer_load_indices_` with the given indices. If `buffer_load_indices_` is
* already non-empty, check that it is consistent with the given indices.
* \param indices The indices of a BufferLoad of the inlined buffer
* \return A boolean flag indicating if the check is successful
*/
bool UpdateAndCheckIndexExprs(const Array<PrimExpr>& indices) {
if (buffer_load_indices_.empty()) {
buffer_load_indices_ = indices;
} else if (!std::equal(buffer_load_indices_.begin(), buffer_load_indices_.end(),
indices.begin(), indices.end(), ExprDeepEqual())) {
// Failure: indices are not consistent in different BufferLoads
return false;
}
return true;
}

/*! \brief The RHS value of the producer's BufferStore statement */
PrimExpr producer_rhs_{nullptr};
/*! \brief The indices of the consumer's BufferLoad */
Array<PrimExpr> buffer_load_indices_;
/*! \brief The IterMap representing the indices of the consumer's BufferLoad */
Array<arith::IterSumExpr> buffer_load_iter_map_{nullptr};
/*! \brief The arithmetic analyzer */
arith::Analyzer analyzer_;
};
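
Unlike the old per-variable check, `UpdateAndCheckIndexExprs` allows arbitrary affine index expressions as long as every load of the inlined buffer uses structurally identical ones. A toy sketch of that rule (plain Python over strings, illustrative only):

```python
def loads_are_consistent(all_load_indices):
    """Toy model of UpdateAndCheckIndexExprs: indices as tuples of strings."""
    first = all_load_indices[0]
    return all(indices == first for indices in all_load_indices)

assert loads_are_consistent([("vi * 16 + vj", "vk"), ("vi * 16 + vj", "vk")])  # ok
assert not loads_are_consistent([("vi * 16 + vj", "vk"), ("vk", "vi")])        # reject
```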

void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
140 changes: 140 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_inline.py
@@ -169,6 +169,112 @@ def elementwise_multi_reverse_loads_inlined(a: T.handle, c: T.handle) -> None:
C[vi, vj] = (A[vi, vj] * 2.0 + 1.0) * (A[vi, vj] * 2.0 * 2.0) + 3.0


@T.prim_func
def elementwise_reverse_affine_load(
A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(8, 32, 8, 8), "float32"]
) -> None:
B = T.alloc_buffer((128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j, k, l in T.grid(8, 32, 8, 8):
with T.block("C"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
C[vi, vj, vk, vl] = B[
((((vi * 32) + vj) * 8 + vk) * 8 + vl) // 128,
((((vi * 32) + vj) * 8 + vk) * 8 + vl) % 128,
]


@T.prim_func
def elementwise_reverse_affine_load_inlined(
A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(8, 32, 8, 8), "float32"]
) -> None:
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
C[
(vj + vi * 128) // 2048,
(vj + vi * 128) // 64 % 32,
((vj + vi * 128) // 8) % 8,
(vj + vi * 128) % 8,
] = (
A[vi, vj] * 2.0
)


@T.prim_func
def elementwise_reverse_affine_load_unit_iter(
A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(8, 16, 1), "float32"],
D: T.Buffer[(1, 8, 16, 128), "float32"],
) -> None:
C = T.alloc_buffer((128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = A[vi, vj] * 2.0
for i, j, k in T.grid(8, 16, 128):
with T.block("C"):
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
D[0, vi, vj, vk] = C[vi * 16 + vj, vk] + B[vi, vj, 0]


@T.prim_func
def elementwise_reverse_affine_load_unit_iter_inlined(
A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(8, 16, 1), "float32"],
D: T.Buffer[(1, 8, 16, 128), "float32"],
) -> None:
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
D[0, vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + B[vi // 16, vi % 16, 0]
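
The literal `0` in the inlined store above comes from the unit-iter handling in `BodyPatternAllowReverseInline`: an iter whose domain has extent 1 is bound to its min up front rather than recovered from the inverse affine map. A minimal sketch of that rule (plain Python, hypothetical domains):

```python
# dom = (min, extent); extent == 1 means the iter can only ever equal its min,
# so it is substituted eagerly (simplify_trivial_iterators=False keeps it out
# of the detected affine map).
iter_doms = {"v0": (0, 1), "vi": (0, 8), "vj": (0, 16), "vk": (0, 128)}
idx_sub = {v: mn for v, (mn, ext) in iter_doms.items() if ext == 1}
assert idx_sub == {"v0": 0}   # v0 becomes the literal 0 in D[0, ...]
```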


@T.prim_func
def elementwise_multi_reverse_affine_load(
A: T.Buffer[(128, 128), "float32"],
C: T.Buffer[(8, 16, 128), "float32"],
) -> None:
B = T.alloc_buffer((128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j, k in T.grid(8, 16, 128):
with T.block("C"):
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
C[vi, vj, vk] = B[vi * 16 + vj, vk] + B[vi * 16 + vj, vk]


@T.prim_func
def elementwise_multi_reverse_affine_load_inlined(
A: T.Buffer[(128, 128), "float32"],
C: T.Buffer[(8, 16, 128), "float32"],
) -> None:
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + A[vi, vj] * 2.0


@T.prim_func
def elementwise_reverse_non_affine_load(
A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(8, 16, 128), "float32"]
) -> None:
B = T.alloc_buffer((128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j, k in T.grid(8, 16, 128):
with T.block("C"):
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
C[vi, vj, vk] = B[vi * 16 + vj, vi * 16 + vj]
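
This pattern must be rejected: the load `B[vi * 16 + vj, vi * 16 + vj]` ignores `vk` and touches only the diagonal of the 128×128 buffer, so no bijection onto the producer's domain exists and `DetectIterMap` returns empty. A plain-Python check of the failed coverage (not TVM code):

```python
touched = {(v, v) for v in range(128)}   # v = vi * 16 + vj ranges over [0, 128)
assert len(touched) == 128 < 128 * 128   # reads cover 128 of 16384 cells
```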


@T.prim_func
def opaque_access_load(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
@@ -520,6 +626,40 @@ def test_reverse_compute_multi_reverse_loads():
verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_loads)


def test_reverse_compute_inline_affine_load():
sch = tir.Schedule(elementwise_reverse_affine_load, debug_mask="all")
block_c = sch.get_block("C")
sch.reverse_compute_inline(block_c)
tvm.ir.assert_structural_equal(elementwise_reverse_affine_load_inlined, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load)


def test_reverse_compute_inline_multi_affine_load():
sch = tir.Schedule(elementwise_multi_reverse_affine_load, debug_mask="all")
block_c = sch.get_block("C")
sch.reverse_compute_inline(block_c)
tvm.ir.assert_structural_equal(elementwise_multi_reverse_affine_load_inlined, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_affine_load)


def test_reverse_compute_inline_affine_load_unit_iter():
sch = tir.Schedule(elementwise_reverse_affine_load_unit_iter, debug_mask="all")
block_c = sch.get_block("C")
sch.reverse_compute_inline(block_c)
tvm.ir.assert_structural_equal(
elementwise_reverse_affine_load_unit_iter_inlined, sch.mod["main"]
)
verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load_unit_iter)


def test_reverse_compute_fail_non_affine_load():
sch = tir.Schedule(elementwise_reverse_non_affine_load, debug_mask="all")
block_c = sch.get_block("C")
with pytest.raises(tvm.tir.ScheduleError):
sch.reverse_compute_inline(block_c)


def test_reverse_compute_fail_multi_reverse_loads():
sch = tir.Schedule(elementwise_multi_loads, debug_mask="all")
block_c = sch.get_block("C")