Skip to content

Commit

Permalink
Breakpoint, expose the transformed axes for use in TE scheduling.
Browse files Browse the repository at this point in the history
Final step, exposing the axes generated in .transform_layout for use
in TE scheduling.
  • Loading branch information
Lunderberg committed Dec 13, 2021
2 parents 9868fd5 + 5e18278 commit b961ad1
Show file tree
Hide file tree
Showing 6 changed files with 455 additions and 13 deletions.
1 change: 1 addition & 0 deletions include/tvm/te/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ class ComputeOp : public Operation {
Array<IterVar> axis, Array<PrimExpr> body);

TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeOpNode);
};

/*!
Expand Down
58 changes: 55 additions & 3 deletions include/tvm/te/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,14 @@ class Stage : public ObjectRef {
* Expressions should be in terms of the variables given in
* initial_indices.
*
* \param out_iter_vars An optional output location for the updated
* loop iteration variables.
*
* \return reference to self
*/
TVM_DLL Stage& transform_layout(const Array<Var>& initial_indices,
const Array<PrimExpr>& final_indices);
const Array<PrimExpr>& final_indices,
Array<IterVar>* out_iter_vars = nullptr);
/*! \brief Defines separators between groups of axes.
*
* Used to define `BufferNode::axis_separators`, which has
Expand Down Expand Up @@ -494,9 +498,27 @@ class StageNode : public Object {
* while origin_op remains fixed.
*/
Operation origin_op;
/*! \brief All the nodes in the iter var */
/*! \brief All the nodes in the iter var
*
* Each element of all_iter_vars represents an iteration variable
* that may appear within this stage's computation. Any element
* of `all_iter_vars` that is in `leaf_iter_vars` represents a
* variable that is directly defined and usable within the stage's
* computation. All other elements of `all_iter_vars` represent
* variables whose value must be computed from the variables in
* `leaf_iter_vars`. (e.g. Support index k has been split by
* ``ko, ki = s.split(k, factor=4)``. ko and ki will appear in
* `leaf_iter_vars`, while k will not, and must be computed as
* `4*ko + ki`.
*/
Array<IterVar> all_iter_vars;
/*! \brief The current active leaf iter vars in the stage. */
/*! \brief The current active leaf iter vars in the stage.
*
* Each element of leaf_iter_vars will either be replaced with the
* bound index (e.g. threadIdx.x), or will be expanded into a loop
* over the variable's extent. `leaf_iter_vars` is a subset of
* `all_iter_vars`.
*/
Array<IterVar> leaf_iter_vars;
/*!
* \brief Specify threads to be launched at the stage.
Expand Down Expand Up @@ -809,6 +831,36 @@ class Singleton : public IterVarRelation {
TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode);
};

/*!
* \brief Transform iterator according to some arbitrary expression.
*/
class TransformNode : public IterVarRelationNode {
public:
Array<IterVar> original_variables;
Array<IterVar> transformed_variables;
IndexMap forward_transformation;
IndexMap inverse_transformation;

void VisitAttrs(AttrVisitor* v) {
v->Visit("original_variables", &original_variables);
v->Visit("transformed_variables", &transformed_variables);
v->Visit("forward_transformation", &forward_transformation);
v->Visit("inverse_transformation", &inverse_transformation);
}

static constexpr const char* _type_key = "Transform";
TVM_DECLARE_FINAL_OBJECT_INFO(TransformNode, IterVarRelationNode);
};

class Transform : public IterVarRelation {
public:
TVM_DLL explicit Transform(Array<IterVar> original_variables,
Array<IterVar> transformed_variables, IndexMap forward_transformation,
IndexMap inverse_transformation);

TVM_DEFINE_OBJECT_REF_METHODS(Transform, IterVarRelation, TransformNode);
};

/*! \brief Container for specialization conditions. */
class SpecializedConditionNode : public Object {
public:
Expand Down
46 changes: 39 additions & 7 deletions python/tvm/te/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,9 +527,15 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr
"""Defines the layout transformation for the current stage's tensor.
The map from initial_indices to final_indices must be an
invertible affine transformation.
invertible affine transformation. This method may be called
more than once for a given tensor, in which case each
transformation is applied sequentially.
This method may be called more than once for a given tensor, in which case each
If the stage is a ComputeOp, then the iteration order of the
compute stage is rewritten to be a row-major traversal of the
tensor, and the new loop iteration variables are returned.
For all other stages, the loop iteration order is unmodified,
and the return value is None.
Parameters
----------
Expand All @@ -543,6 +549,17 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr
the current stage's tensor, using the post-transformation
layout.
Returns
-------
new_iter_vars : Optional[List[tvm.tir.IterVar]]
If the stage is a ComputeOp, then the return will be the
updated loop iteration variables over the data array, in
the same order as the output values from the
`mapping_function`.
Otherwise, the return value is None.
Examples
--------
.. code-block:: python
Expand All @@ -557,15 +574,29 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr
.. code-block:: python
# ``A`` is a tensor whose compute definition is in format,
# and should be transformed such that the last index is
# split, with the slower-chan index of the split placed at the
# slowest changing dimension.
# ``A`` is a tensor whose compute definition is in an
# arbitrary format, and should be transformed such that
# the last index is split, with the slower-changing index
# of the split placed at the slowest changing dimension.
s[A].transform_layout(
lambda *indices, i: [i//4, *indices, i%4]
)
.. code-block:: python
# ``B`` is a tensor defined by te.compute to be a copy of
# ``A`, and should be transformed such that ``B``'s layout
# is a transpose of ``A``'s layout. The loop iteration
# that computes ``B`` will correspond to ``B``'s memory
# layout.
A = te.placeholder([n,m])
B = te.compute(A.shape, lambda i,j: A[i,j])
s = te.create_schedule(B.op)
s[B].transform_layout(lambda i,j: [j,i])
"""

args = []
Expand Down Expand Up @@ -626,9 +657,10 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr
"Instead received {val} of type {type(val)}."
)

_ffi_api.StageTransformLayout(self, initial_indices, final_indices)
new_iter_vars = _ffi_api.StageTransformLayout(self, initial_indices, final_indices)
_ffi_api.StageSetAxisSeparators(self, axis_separators)

return new_iter_vars or None


@tvm._ffi.register_object
Expand Down
140 changes: 140 additions & 0 deletions src/te/schedule/message_passing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,22 @@ void PassUpThreadBinding(const Stage& stage, std::unordered_map<IterVar, bool>*
} else if (const RebaseNode* s = rel.as<RebaseNode>()) {
state[s->parent] = state[s->rebased];
} else if (rel.as<SingletonNode>()) {
} else if (const TransformNode* s = rel.as<TransformNode>()) {
// Currently, this marks all original iter vars as deriving from
// a thread bind if any of the transformed variables are bound,
// even if the inverse expression for that iter var doesn't
// depend on the bound variable.

// TODO(Lunderberg): For each of original variable, check
// whether any variable in the inverse expression for it has a
// thread binding.
bool is_thread_binding = false;
for (const auto& iter_var : s->transformed_variables) {
is_thread_binding = is_thread_binding || state[iter_var];
}
for (const auto& iter_var : s->original_variables) {
state[iter_var] = is_thread_binding;
}
} else {
LOG(FATAL) << "unknown relation type";
}
Expand Down Expand Up @@ -157,6 +173,29 @@ void PassDownDomain(const Stage& stage, std::unordered_map<IterVar, Range>* p_st
Update(p_state, r->rebased, Range::FromMinExtent(0, state.at(r->parent)->extent), actx);
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
Update(p_state, s->iter, Range::FromMinExtent(0, 1), actx);
} else if (const TransformNode* s = rel.as<TransformNode>()) {
bool missing_originals = false;
for (const auto& iter_var : s->original_variables) {
if (!state.count(iter_var)) {
ICHECK(allow_missing);
missing_originals = true;
}
}
if (missing_originals) {
continue;
}

Array<Range> original_ranges;
for (const auto& iter_var : s->original_variables) {
original_ranges.push_back(state[iter_var]);
}
Array<Range> updated_ranges = s->forward_transformation->MapRanges(original_ranges);

ICHECK_EQ(updated_ranges.size(), s->transformed_variables.size());
for (size_t i = 0; i < updated_ranges.size(); i++) {
Update(p_state, s->transformed_variables[i], updated_ranges[i], actx);
}

} else {
LOG(FATAL) << "unknown relation type";
}
Expand Down Expand Up @@ -225,6 +264,39 @@ void PassUpIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
state[s->parent] = value;
}
} else if (rel.as<SingletonNode>()) {
} else if (const TransformNode* s = rel.as<TransformNode>()) {
bool missing_transformed = false;
for (const auto& iter_var : s->transformed_variables) {
if (!state.count(iter_var)) {
// for (const auto& kv : state) {
// std::cout << "Looking for " << tvm::PrettyPrint(iter_var) << std::endl;
// std::cout << "State contains " << tvm::PrettyPrint(kv.first) << " -> "
// << tvm::PrettyPrint(kv.second) << std::endl;
// }
// TODO: Decide whether to have this check, for similarity
// with other handlers. In this case, the indices may
// already be in terms of the pre-transformed variables, so
// no need to untransform them?

// ICHECK(allow_missing);
missing_transformed = true;
}
}
if (missing_transformed) {
continue;
}

Array<PrimExpr> transformed_indices;
for (const auto& iter_var : s->transformed_variables) {
transformed_indices.push_back(state[iter_var]);
}
Array<PrimExpr> original_indices = s->inverse_transformation->MapIndices(transformed_indices);

ICHECK_EQ(original_indices.size(), s->original_variables.size());
for (size_t i = 0; i < original_indices.size(); i++) {
state[s->original_variables[i]] = original_indices[i];
}

} else {
LOG(FATAL) << "unknown relation type";
}
Expand Down Expand Up @@ -270,6 +342,28 @@ void PassDownIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
state[s->rebased] = value;
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
state[s->iter] = make_zero(s->iter->var.dtype());
} else if (const TransformNode* s = rel.as<TransformNode>()) {
bool missing_originals = false;
for (const auto& iter_var : s->original_variables) {
if (!state.count(iter_var)) {
ICHECK(allow_missing);
missing_originals = true;
}
}
if (missing_originals) {
continue;
}

Array<PrimExpr> original_indices;
for (const auto& iter_var : s->original_variables) {
original_indices.push_back(state[iter_var]);
}
Array<PrimExpr> transformed_indices = s->forward_transformation->MapIndices(original_indices);

ICHECK_EQ(transformed_indices.size(), s->transformed_variables.size());
for (size_t i = 0; i < transformed_indices.size(); i++) {
state[s->transformed_variables[i]] = transformed_indices[i];
}
} else {
LOG(FATAL) << "unknown relation type";
}
Expand Down Expand Up @@ -351,6 +445,26 @@ void PassUpDomain(const RebaseNode* s, const std::unordered_map<IterVar, Range>&
*parent = arith::EvalSet(s->rebased->var + parent_min, {{s->rebased, rebased}});
}

Array<IntSet> PassUpDomain(const TransformNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const Map<IterVar, IntSet>& transformed_domains) {
Array<IntSet> output;

Array<PrimExpr> transformed_indices;
for (const auto& iter_var : s->transformed_variables) {
transformed_indices.push_back(iter_var->var);
}

Array<PrimExpr> transformed_exprs = s->inverse_transformation->MapIndices(transformed_indices);

ICHECK_EQ(transformed_exprs.size(), s->original_variables.size());
for (size_t i = 0; i < transformed_exprs.size(); i++) {
output.push_back(arith::EvalSet(transformed_exprs[i], transformed_domains));
}

return output;
}

void PassUpDomain(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, IntSet>* p_state) {
auto& state = *p_state;
Expand All @@ -370,6 +484,16 @@ void PassUpDomain(const Stage& stage, const std::unordered_map<IterVar, Range>&
PassUpDomain(r, dom_map, state.at(r->rebased), &parent);
state[r->parent] = parent;
} else if (rel.as<SingletonNode>()) {
} else if (const TransformNode* r = rel.as<TransformNode>()) {
Map<IterVar, IntSet> transformed_domains;
for (const auto& var : r->transformed_variables) {
transformed_domains.Set(var, state.at(var));
}
auto original_ranges = PassUpDomain(r, dom_map, transformed_domains);
ICHECK_EQ(original_ranges.size(), r->original_variables.size());
for (size_t i = 0; i < original_ranges.size(); i++) {
state[r->original_variables[i]] = original_ranges[i];
}
} else {
LOG(FATAL) << "unknown relation type";
}
Expand Down Expand Up @@ -509,6 +633,22 @@ void PassUpBoundCheck(const Stage& s, const Map<IterVar, Range>& dom_map,
state[s->parent] = state.at(s->rebased);
} else if (rel.as<SingletonNode>()) {
// nop
} else if (const TransformNode* s = rel.as<TransformNode>()) {
// Currently, this marks all original iter vars as requiring
// bounds checks if any of the transformed variables require
// bounds checks, even if the inverse expression for that iter
// var doesn't depend on the bound variable.

// TODO(Lunderberg): For each of original variable, check
// whether any variable in the inverse expression for it
// requires bounds checking.
bool needs_bounds_check = false;
for (const auto& iter_var : s->transformed_variables) {
needs_bounds_check = needs_bounds_check || state[iter_var];
}
for (const auto& iter_var : s->original_variables) {
state[iter_var] = needs_bounds_check;
}
} else {
LOG(FATAL) << "unknown relation type";
}
Expand Down
Loading

0 comments on commit b961ad1

Please sign in to comment.