Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][TIR][API-Change] Range/IntSet API style consistency. #5953

Merged
merged 1 commit into from
Jun 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/dev/inferbound.rst
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,16 @@ The Ranges of the inner and outer IterVars of the split are set based on the par

.. code:: cpp

rmap[split->inner] = Range::make_by_min_extent(0, split->factor)
rmap[split->outer] = Range::make_by_min_extent(0, DivCeil(rmap[split->parent]->extent, split->factor))
rmap[split->inner] = Range::FromMinExtent(0, split->factor)
rmap[split->outer] = Range::FromMinExtent(0, DivCeil(rmap[split->parent]->extent, split->factor))

There is an opportunity here to tighten the bounds produced by InferBound, when ``split->factor`` does not evenly divide the parent's extent. Suppose the parent's extent is 20, and the split factor is 16. Then on the second iteration of the outer loop, the inner loop only needs to perform 4 iterations, not 16. If PassDownDomain could set the extent of ``split->inner`` to ``min(split->factor, rmap[split->parent]->extent - (split->outer * split->factor))``, then the extent of the inner variable would properly adapt, based on which iteration of the outer loop is being executed.

For Fuse relations, the Range of the fused IterVar is set based on the known Ranges of the inner and outer IterVars, as follows:

.. code:: cpp

rmap[fuse->fused] = Range::make_by_min_extent(0, rmap[fuse->outer]->extent * rmap[fuse->inner]->extent)
rmap[fuse->fused] = Range::FromMinExtent(0, rmap[fuse->outer]->extent * rmap[fuse->inner]->extent)


InferRootBound
Expand Down
51 changes: 22 additions & 29 deletions include/tvm/arith/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,82 +65,75 @@ class IntSetNode : public Object {
*/
class IntSet : public ObjectRef {
public:
/*! \brief constructor */
IntSet() {}
// constructor from not container.
explicit IntSet(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
const IntSetNode* operator->() const { return static_cast<const IntSetNode*>(get()); }
/*!
* \brief Find a range that covers the region.
* \param max_range The range to be covered.
* \return The covering range.
*/
Range cover_range(Range max_range) const;
Range CoverRange(Range max_range) const;
/*! \return Lower bound of the set */
PrimExpr min() const;
/*! \return upper bound of the set */
PrimExpr max() const;
/*! \return The sign of the elements in the integer set */
SignType GetSignType() const;
/*! \return Whether the set represent nothing */
bool is_nothing() const;
bool IsNothing() const;
/*! \return Whether the set represent everything */
bool is_everything() const;
bool IsEverything() const;
/*! \return Whether the set is a single point */
bool is_single_point() const;
bool IsSinglePoint() const;
/*! \return Whether the set is proved to be bigger than 0 */
bool can_prove_positive() const;
bool CanProvePositive() const;
/*! \return Whether the set is proved to be smaller than 0 */
bool can_prove_negative() const;
bool CanProveNegative() const;
/*! \return Whether the set is proved to be smaller than or equal to 0 */
bool can_prove_non_positive() const;
bool CanProveNonPositive() const;
/*! \return Whether the set is proved to be larger than or equal to 0 */
bool can_prove_non_negative() const;
/*! \return The sign of the elements in the integer set */
SignType sign_type() const;
bool CanProveNonNegative() const;
/*!
* \brief The single point value, call only if is_single_point is true
* \brief The single point value, call only if IsSinglePoint is true
* \return The point value.
*/
PrimExpr point_value() const;
PrimExpr PointValue() const;
/*!
* \brief Try to match IntSet with range r.
*
* \note It is guanrateed that IntSet::range(r).match_range(r) == true
* \note It is guanrateed that IntSet::FromRange(r).MatchRange(r) == true
* \return true if we can prove they are the same.
*/
bool match_range(const Range& r) const;
bool MatchRange(const tvm::Range& r) const;
/*! \return The set contains nothing */
static IntSet nothing();
static IntSet Nothing();
/*! \return The set contains everything */
static IntSet everything();
static IntSet Everything();
/*!
* \brief construct a point set.
* \param point The point in the set.
* \return construct a single point set
*/
static IntSet single_point(PrimExpr point);
static IntSet SinglePoint(PrimExpr point);
/*!
* \brief construct a integer set from vector expression.
* \param vec The vector expression, can also be single point.
* \return The result set containing the indices in the vector.
*/
static IntSet vector(PrimExpr vec);
static IntSet Vector(PrimExpr vec);
/*!
* \brief Construct a set representing a range.
* \param r The range
* \return constructed set.
*/
static IntSet range(Range r);
static IntSet FromRange(tvm::Range r);
/*!
* \brief Construct a set representing a interval.
* \param min The minimum value of the interval.
* \param max The maximum value of the interval.
* \return constructed set.
*/
static IntSet interval(PrimExpr min, PrimExpr max);
static IntSet Interval(PrimExpr min, PrimExpr max);

TVM_DEFINE_OBJECT_REF_METHODS(IntSet, ObjectRef, IntSetNode);
};

//-----------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ class Range : public ObjectRef {
* \param min The minimum range.
* \param extent The extent of the range.
*/
static Range make_by_min_extent(PrimExpr min, PrimExpr extent);
static Range FromMinExtent(PrimExpr min, PrimExpr extent);
// declare range.
TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
};
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(self, begin, end=None):
_ffi_api.Range, begin, end)

@staticmethod
def make_by_min_extent(min_value, extent):
def from_min_extent(min_value, extent):
"""Construct a Range by min and extent.

This constructs a range in [min_value, min_value + extent)
Expand All @@ -136,4 +136,4 @@ def make_by_min_extent(min_value, extent):
rng : Range
The constructed range.
"""
return _ffi_api.range_by_min_extent(min_value, extent)
return _ffi_api.Range_from_min_extent(min_value, extent)
2 changes: 1 addition & 1 deletion python/tvm/te/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def wrap_up_realize(self, node, body):
if _scope == 'global':
body = self.wrap_up_binds(body)

_domain = [Range.make_by_min_extent(0, i) for i in _buf.shape]
_domain = [Range.from_min_extent(0, i) for i in _buf.shape]
_dtype = _buf.dtype
_true = tvm.runtime.convert(True)
body = tvm.tir.ProducerRealize(_buf, _domain, _true, body)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/te/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _get_region(tslice):
begin = idx.var
else:
begin = idx
region.append(Range.make_by_min_extent(begin, 1))
region.append(Range.from_min_extent(begin, 1))
return region


Expand Down
8 changes: 4 additions & 4 deletions src/arith/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class BoundDeducer : public ExprVisitor {
if (operand.dtype().is_uint()) {
sign_operand = kPositive;
} else {
sign_operand = expr_map_[operand].sign_type();
sign_operand = expr_map_[operand].GetSignType();
}

if (sign_operand == SignType::kNegative) {
Expand Down Expand Up @@ -315,7 +315,7 @@ void BoundDeducer::Deduce() {
void BoundDeducer::Relax() {
IntSet a = EvalSet(expr_, relax_map_);
IntSet b = EvalSet(result_, relax_map_);
if (a.is_everything() || b.is_everything()) {
if (a.IsEverything() || b.IsEverything()) {
success_ = false;
return;
}
Expand All @@ -336,7 +336,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& relax_map) {
BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce();
if (!d.success_) return IntSet::nothing();
if (!d.success_) return IntSet::Nothing();
PrimExpr min = neg_inf(), max = pos_inf();
if (d.comp_op == kEqual) {
min = d.result_;
Expand All @@ -346,7 +346,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e,
} else {
max = d.result_;
}
return IntSet::interval(min, max);
return IntSet::Interval(min, max);
}

// assuming e >= 0, deduce the bound of variable from it.
Expand Down
6 changes: 3 additions & 3 deletions src/arith/domain_touched.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ class BufferTouchedDomain final : public StmtExprVisitor {
Region ret;
Range none;
for (size_t i = 0; i < bounds_.size(); ++i) {
ret.push_back(arith::Union(bounds_[i]).cover_range(none));
ret.push_back(arith::Union(bounds_[i]).CoverRange(none));
}
return ret;
}

void VisitStmt_(const ForNode* op) final {
const VarNode* var = op->loop_var.get();
dom_map_[var] = IntSet::range(Range::make_by_min_extent(op->min, op->extent));
dom_map_[var] = IntSet::FromRange(Range::FromMinExtent(op->min, op->extent));
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
}
Expand All @@ -69,7 +69,7 @@ class BufferTouchedDomain final : public StmtExprVisitor {
const IterVarNode* thread_axis = op->node.as<IterVarNode>();
CHECK(thread_axis);
const VarNode* var = thread_axis->var.get();
dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value));
dom_map_[var] = IntSet::FromRange(Range(make_zero(op->value.dtype()), op->value));
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
} else {
Expand Down
Loading